mirror of
https://gitlab.sectorq.eu/jaydee/omv_backup.git
synced 2025-07-04 00:45:50 +02:00
added v3
This commit is contained in:
@ -0,0 +1,23 @@
|
||||
# This file is dual licensed under the terms of the Apache License, Version
|
||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
||||
# for complete details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
|
||||
|
||||
class KeyDerivationFunction(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def derive(self, key_material: bytes) -> bytes:
|
||||
"""
|
||||
Deterministically generates and returns a new key based on the existing
|
||||
key material.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def verify(self, key_material: bytes, expected_key: bytes) -> None:
|
||||
"""
|
||||
Checks whether the key generated by the key material matches the
|
||||
expected derived key. Raises an exception if they do not match.
|
||||
"""
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,13 @@
|
||||
# This file is dual licensed under the terms of the Apache License, Version
|
||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
||||
# for complete details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
|
||||
from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
|
||||
|
||||
Argon2id = rust_openssl.kdf.Argon2id
|
||||
KeyDerivationFunction.register(Argon2id)
|
||||
|
||||
__all__ = ["Argon2id"]
|
@ -0,0 +1,124 @@
|
||||
# This file is dual licensed under the terms of the Apache License, Version
|
||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
||||
# for complete details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from cryptography import utils
|
||||
from cryptography.exceptions import AlreadyFinalized, InvalidKey
|
||||
from cryptography.hazmat.primitives import constant_time, hashes, hmac
|
||||
from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
|
||||
|
||||
|
||||
def _int_to_u32be(n: int) -> bytes:
|
||||
return n.to_bytes(length=4, byteorder="big")
|
||||
|
||||
|
||||
def _common_args_checks(
|
||||
algorithm: hashes.HashAlgorithm,
|
||||
length: int,
|
||||
otherinfo: bytes | None,
|
||||
) -> None:
|
||||
max_length = algorithm.digest_size * (2**32 - 1)
|
||||
if length > max_length:
|
||||
raise ValueError(f"Cannot derive keys larger than {max_length} bits.")
|
||||
if otherinfo is not None:
|
||||
utils._check_bytes("otherinfo", otherinfo)
|
||||
|
||||
|
||||
def _concatkdf_derive(
|
||||
key_material: bytes,
|
||||
length: int,
|
||||
auxfn: typing.Callable[[], hashes.HashContext],
|
||||
otherinfo: bytes,
|
||||
) -> bytes:
|
||||
utils._check_byteslike("key_material", key_material)
|
||||
output = [b""]
|
||||
outlen = 0
|
||||
counter = 1
|
||||
|
||||
while length > outlen:
|
||||
h = auxfn()
|
||||
h.update(_int_to_u32be(counter))
|
||||
h.update(key_material)
|
||||
h.update(otherinfo)
|
||||
output.append(h.finalize())
|
||||
outlen += len(output[-1])
|
||||
counter += 1
|
||||
|
||||
return b"".join(output)[:length]
|
||||
|
||||
|
||||
class ConcatKDFHash(KeyDerivationFunction):
|
||||
def __init__(
|
||||
self,
|
||||
algorithm: hashes.HashAlgorithm,
|
||||
length: int,
|
||||
otherinfo: bytes | None,
|
||||
backend: typing.Any = None,
|
||||
):
|
||||
_common_args_checks(algorithm, length, otherinfo)
|
||||
self._algorithm = algorithm
|
||||
self._length = length
|
||||
self._otherinfo: bytes = otherinfo if otherinfo is not None else b""
|
||||
|
||||
self._used = False
|
||||
|
||||
def _hash(self) -> hashes.Hash:
|
||||
return hashes.Hash(self._algorithm)
|
||||
|
||||
def derive(self, key_material: bytes) -> bytes:
|
||||
if self._used:
|
||||
raise AlreadyFinalized
|
||||
self._used = True
|
||||
return _concatkdf_derive(
|
||||
key_material, self._length, self._hash, self._otherinfo
|
||||
)
|
||||
|
||||
def verify(self, key_material: bytes, expected_key: bytes) -> None:
|
||||
if not constant_time.bytes_eq(self.derive(key_material), expected_key):
|
||||
raise InvalidKey
|
||||
|
||||
|
||||
class ConcatKDFHMAC(KeyDerivationFunction):
|
||||
def __init__(
|
||||
self,
|
||||
algorithm: hashes.HashAlgorithm,
|
||||
length: int,
|
||||
salt: bytes | None,
|
||||
otherinfo: bytes | None,
|
||||
backend: typing.Any = None,
|
||||
):
|
||||
_common_args_checks(algorithm, length, otherinfo)
|
||||
self._algorithm = algorithm
|
||||
self._length = length
|
||||
self._otherinfo: bytes = otherinfo if otherinfo is not None else b""
|
||||
|
||||
if algorithm.block_size is None:
|
||||
raise TypeError(f"{algorithm.name} is unsupported for ConcatKDF")
|
||||
|
||||
if salt is None:
|
||||
salt = b"\x00" * algorithm.block_size
|
||||
else:
|
||||
utils._check_bytes("salt", salt)
|
||||
|
||||
self._salt = salt
|
||||
|
||||
self._used = False
|
||||
|
||||
def _hmac(self) -> hmac.HMAC:
|
||||
return hmac.HMAC(self._salt, self._algorithm)
|
||||
|
||||
def derive(self, key_material: bytes) -> bytes:
|
||||
if self._used:
|
||||
raise AlreadyFinalized
|
||||
self._used = True
|
||||
return _concatkdf_derive(
|
||||
key_material, self._length, self._hmac, self._otherinfo
|
||||
)
|
||||
|
||||
def verify(self, key_material: bytes, expected_key: bytes) -> None:
|
||||
if not constant_time.bytes_eq(self.derive(key_material), expected_key):
|
||||
raise InvalidKey
|
@ -0,0 +1,101 @@
|
||||
# This file is dual licensed under the terms of the Apache License, Version
|
||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
||||
# for complete details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from cryptography import utils
|
||||
from cryptography.exceptions import AlreadyFinalized, InvalidKey
|
||||
from cryptography.hazmat.primitives import constant_time, hashes, hmac
|
||||
from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
|
||||
|
||||
|
||||
class HKDF(KeyDerivationFunction):
|
||||
def __init__(
|
||||
self,
|
||||
algorithm: hashes.HashAlgorithm,
|
||||
length: int,
|
||||
salt: bytes | None,
|
||||
info: bytes | None,
|
||||
backend: typing.Any = None,
|
||||
):
|
||||
self._algorithm = algorithm
|
||||
|
||||
if salt is None:
|
||||
salt = b"\x00" * self._algorithm.digest_size
|
||||
else:
|
||||
utils._check_bytes("salt", salt)
|
||||
|
||||
self._salt = salt
|
||||
|
||||
self._hkdf_expand = HKDFExpand(self._algorithm, length, info)
|
||||
|
||||
def _extract(self, key_material: bytes) -> bytes:
|
||||
h = hmac.HMAC(self._salt, self._algorithm)
|
||||
h.update(key_material)
|
||||
return h.finalize()
|
||||
|
||||
def derive(self, key_material: bytes) -> bytes:
|
||||
utils._check_byteslike("key_material", key_material)
|
||||
return self._hkdf_expand.derive(self._extract(key_material))
|
||||
|
||||
def verify(self, key_material: bytes, expected_key: bytes) -> None:
|
||||
if not constant_time.bytes_eq(self.derive(key_material), expected_key):
|
||||
raise InvalidKey
|
||||
|
||||
|
||||
class HKDFExpand(KeyDerivationFunction):
|
||||
def __init__(
|
||||
self,
|
||||
algorithm: hashes.HashAlgorithm,
|
||||
length: int,
|
||||
info: bytes | None,
|
||||
backend: typing.Any = None,
|
||||
):
|
||||
self._algorithm = algorithm
|
||||
|
||||
max_length = 255 * algorithm.digest_size
|
||||
|
||||
if length > max_length:
|
||||
raise ValueError(
|
||||
f"Cannot derive keys larger than {max_length} octets."
|
||||
)
|
||||
|
||||
self._length = length
|
||||
|
||||
if info is None:
|
||||
info = b""
|
||||
else:
|
||||
utils._check_bytes("info", info)
|
||||
|
||||
self._info = info
|
||||
|
||||
self._used = False
|
||||
|
||||
def _expand(self, key_material: bytes) -> bytes:
|
||||
output = [b""]
|
||||
counter = 1
|
||||
|
||||
while self._algorithm.digest_size * (len(output) - 1) < self._length:
|
||||
h = hmac.HMAC(key_material, self._algorithm)
|
||||
h.update(output[-1])
|
||||
h.update(self._info)
|
||||
h.update(bytes([counter]))
|
||||
output.append(h.finalize())
|
||||
counter += 1
|
||||
|
||||
return b"".join(output)[: self._length]
|
||||
|
||||
def derive(self, key_material: bytes) -> bytes:
|
||||
utils._check_byteslike("key_material", key_material)
|
||||
if self._used:
|
||||
raise AlreadyFinalized
|
||||
|
||||
self._used = True
|
||||
return self._expand(key_material)
|
||||
|
||||
def verify(self, key_material: bytes, expected_key: bytes) -> None:
|
||||
if not constant_time.bytes_eq(self.derive(key_material), expected_key):
|
||||
raise InvalidKey
|
@ -0,0 +1,302 @@
|
||||
# This file is dual licensed under the terms of the Apache License, Version
|
||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
||||
# for complete details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from cryptography import utils
|
||||
from cryptography.exceptions import (
|
||||
AlreadyFinalized,
|
||||
InvalidKey,
|
||||
UnsupportedAlgorithm,
|
||||
_Reasons,
|
||||
)
|
||||
from cryptography.hazmat.primitives import (
|
||||
ciphers,
|
||||
cmac,
|
||||
constant_time,
|
||||
hashes,
|
||||
hmac,
|
||||
)
|
||||
from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
|
||||
|
||||
|
||||
class Mode(utils.Enum):
|
||||
CounterMode = "ctr"
|
||||
|
||||
|
||||
class CounterLocation(utils.Enum):
|
||||
BeforeFixed = "before_fixed"
|
||||
AfterFixed = "after_fixed"
|
||||
MiddleFixed = "middle_fixed"
|
||||
|
||||
|
||||
class _KBKDFDeriver:
|
||||
def __init__(
|
||||
self,
|
||||
prf: typing.Callable,
|
||||
mode: Mode,
|
||||
length: int,
|
||||
rlen: int,
|
||||
llen: int | None,
|
||||
location: CounterLocation,
|
||||
break_location: int | None,
|
||||
label: bytes | None,
|
||||
context: bytes | None,
|
||||
fixed: bytes | None,
|
||||
):
|
||||
assert callable(prf)
|
||||
|
||||
if not isinstance(mode, Mode):
|
||||
raise TypeError("mode must be of type Mode")
|
||||
|
||||
if not isinstance(location, CounterLocation):
|
||||
raise TypeError("location must be of type CounterLocation")
|
||||
|
||||
if break_location is None and location is CounterLocation.MiddleFixed:
|
||||
raise ValueError("Please specify a break_location")
|
||||
|
||||
if (
|
||||
break_location is not None
|
||||
and location != CounterLocation.MiddleFixed
|
||||
):
|
||||
raise ValueError(
|
||||
"break_location is ignored when location is not"
|
||||
" CounterLocation.MiddleFixed"
|
||||
)
|
||||
|
||||
if break_location is not None and not isinstance(break_location, int):
|
||||
raise TypeError("break_location must be an integer")
|
||||
|
||||
if break_location is not None and break_location < 0:
|
||||
raise ValueError("break_location must be a positive integer")
|
||||
|
||||
if (label or context) and fixed:
|
||||
raise ValueError(
|
||||
"When supplying fixed data, label and context are ignored."
|
||||
)
|
||||
|
||||
if rlen is None or not self._valid_byte_length(rlen):
|
||||
raise ValueError("rlen must be between 1 and 4")
|
||||
|
||||
if llen is None and fixed is None:
|
||||
raise ValueError("Please specify an llen")
|
||||
|
||||
if llen is not None and not isinstance(llen, int):
|
||||
raise TypeError("llen must be an integer")
|
||||
|
||||
if llen == 0:
|
||||
raise ValueError("llen must be non-zero")
|
||||
|
||||
if label is None:
|
||||
label = b""
|
||||
|
||||
if context is None:
|
||||
context = b""
|
||||
|
||||
utils._check_bytes("label", label)
|
||||
utils._check_bytes("context", context)
|
||||
self._prf = prf
|
||||
self._mode = mode
|
||||
self._length = length
|
||||
self._rlen = rlen
|
||||
self._llen = llen
|
||||
self._location = location
|
||||
self._break_location = break_location
|
||||
self._label = label
|
||||
self._context = context
|
||||
self._used = False
|
||||
self._fixed_data = fixed
|
||||
|
||||
@staticmethod
|
||||
def _valid_byte_length(value: int) -> bool:
|
||||
if not isinstance(value, int):
|
||||
raise TypeError("value must be of type int")
|
||||
|
||||
value_bin = utils.int_to_bytes(1, value)
|
||||
if not 1 <= len(value_bin) <= 4:
|
||||
return False
|
||||
return True
|
||||
|
||||
def derive(self, key_material: bytes, prf_output_size: int) -> bytes:
|
||||
if self._used:
|
||||
raise AlreadyFinalized
|
||||
|
||||
utils._check_byteslike("key_material", key_material)
|
||||
self._used = True
|
||||
|
||||
# inverse floor division (equivalent to ceiling)
|
||||
rounds = -(-self._length // prf_output_size)
|
||||
|
||||
output = [b""]
|
||||
|
||||
# For counter mode, the number of iterations shall not be
|
||||
# larger than 2^r-1, where r <= 32 is the binary length of the counter
|
||||
# This ensures that the counter values used as an input to the
|
||||
# PRF will not repeat during a particular call to the KDF function.
|
||||
r_bin = utils.int_to_bytes(1, self._rlen)
|
||||
if rounds > pow(2, len(r_bin) * 8) - 1:
|
||||
raise ValueError("There are too many iterations.")
|
||||
|
||||
fixed = self._generate_fixed_input()
|
||||
|
||||
if self._location == CounterLocation.BeforeFixed:
|
||||
data_before_ctr = b""
|
||||
data_after_ctr = fixed
|
||||
elif self._location == CounterLocation.AfterFixed:
|
||||
data_before_ctr = fixed
|
||||
data_after_ctr = b""
|
||||
else:
|
||||
if isinstance(
|
||||
self._break_location, int
|
||||
) and self._break_location > len(fixed):
|
||||
raise ValueError("break_location offset > len(fixed)")
|
||||
data_before_ctr = fixed[: self._break_location]
|
||||
data_after_ctr = fixed[self._break_location :]
|
||||
|
||||
for i in range(1, rounds + 1):
|
||||
h = self._prf(key_material)
|
||||
|
||||
counter = utils.int_to_bytes(i, self._rlen)
|
||||
input_data = data_before_ctr + counter + data_after_ctr
|
||||
|
||||
h.update(input_data)
|
||||
|
||||
output.append(h.finalize())
|
||||
|
||||
return b"".join(output)[: self._length]
|
||||
|
||||
def _generate_fixed_input(self) -> bytes:
|
||||
if self._fixed_data and isinstance(self._fixed_data, bytes):
|
||||
return self._fixed_data
|
||||
|
||||
l_val = utils.int_to_bytes(self._length * 8, self._llen)
|
||||
|
||||
return b"".join([self._label, b"\x00", self._context, l_val])
|
||||
|
||||
|
||||
class KBKDFHMAC(KeyDerivationFunction):
|
||||
def __init__(
|
||||
self,
|
||||
algorithm: hashes.HashAlgorithm,
|
||||
mode: Mode,
|
||||
length: int,
|
||||
rlen: int,
|
||||
llen: int | None,
|
||||
location: CounterLocation,
|
||||
label: bytes | None,
|
||||
context: bytes | None,
|
||||
fixed: bytes | None,
|
||||
backend: typing.Any = None,
|
||||
*,
|
||||
break_location: int | None = None,
|
||||
):
|
||||
if not isinstance(algorithm, hashes.HashAlgorithm):
|
||||
raise UnsupportedAlgorithm(
|
||||
"Algorithm supplied is not a supported hash algorithm.",
|
||||
_Reasons.UNSUPPORTED_HASH,
|
||||
)
|
||||
|
||||
from cryptography.hazmat.backends.openssl.backend import (
|
||||
backend as ossl,
|
||||
)
|
||||
|
||||
if not ossl.hmac_supported(algorithm):
|
||||
raise UnsupportedAlgorithm(
|
||||
"Algorithm supplied is not a supported hmac algorithm.",
|
||||
_Reasons.UNSUPPORTED_HASH,
|
||||
)
|
||||
|
||||
self._algorithm = algorithm
|
||||
|
||||
self._deriver = _KBKDFDeriver(
|
||||
self._prf,
|
||||
mode,
|
||||
length,
|
||||
rlen,
|
||||
llen,
|
||||
location,
|
||||
break_location,
|
||||
label,
|
||||
context,
|
||||
fixed,
|
||||
)
|
||||
|
||||
def _prf(self, key_material: bytes) -> hmac.HMAC:
|
||||
return hmac.HMAC(key_material, self._algorithm)
|
||||
|
||||
def derive(self, key_material: bytes) -> bytes:
|
||||
return self._deriver.derive(key_material, self._algorithm.digest_size)
|
||||
|
||||
def verify(self, key_material: bytes, expected_key: bytes) -> None:
|
||||
if not constant_time.bytes_eq(self.derive(key_material), expected_key):
|
||||
raise InvalidKey
|
||||
|
||||
|
||||
class KBKDFCMAC(KeyDerivationFunction):
|
||||
def __init__(
|
||||
self,
|
||||
algorithm,
|
||||
mode: Mode,
|
||||
length: int,
|
||||
rlen: int,
|
||||
llen: int | None,
|
||||
location: CounterLocation,
|
||||
label: bytes | None,
|
||||
context: bytes | None,
|
||||
fixed: bytes | None,
|
||||
backend: typing.Any = None,
|
||||
*,
|
||||
break_location: int | None = None,
|
||||
):
|
||||
if not issubclass(
|
||||
algorithm, ciphers.BlockCipherAlgorithm
|
||||
) or not issubclass(algorithm, ciphers.CipherAlgorithm):
|
||||
raise UnsupportedAlgorithm(
|
||||
"Algorithm supplied is not a supported cipher algorithm.",
|
||||
_Reasons.UNSUPPORTED_CIPHER,
|
||||
)
|
||||
|
||||
self._algorithm = algorithm
|
||||
self._cipher: ciphers.BlockCipherAlgorithm | None = None
|
||||
|
||||
self._deriver = _KBKDFDeriver(
|
||||
self._prf,
|
||||
mode,
|
||||
length,
|
||||
rlen,
|
||||
llen,
|
||||
location,
|
||||
break_location,
|
||||
label,
|
||||
context,
|
||||
fixed,
|
||||
)
|
||||
|
||||
def _prf(self, _: bytes) -> cmac.CMAC:
|
||||
assert self._cipher is not None
|
||||
|
||||
return cmac.CMAC(self._cipher)
|
||||
|
||||
def derive(self, key_material: bytes) -> bytes:
|
||||
self._cipher = self._algorithm(key_material)
|
||||
|
||||
assert self._cipher is not None
|
||||
|
||||
from cryptography.hazmat.backends.openssl.backend import (
|
||||
backend as ossl,
|
||||
)
|
||||
|
||||
if not ossl.cmac_algorithm_supported(self._cipher):
|
||||
raise UnsupportedAlgorithm(
|
||||
"Algorithm supplied is not a supported cipher algorithm.",
|
||||
_Reasons.UNSUPPORTED_CIPHER,
|
||||
)
|
||||
|
||||
return self._deriver.derive(key_material, self._cipher.block_size // 8)
|
||||
|
||||
def verify(self, key_material: bytes, expected_key: bytes) -> None:
|
||||
if not constant_time.bytes_eq(self.derive(key_material), expected_key):
|
||||
raise InvalidKey
|
@ -0,0 +1,62 @@
|
||||
# This file is dual licensed under the terms of the Apache License, Version
|
||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
||||
# for complete details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from cryptography import utils
|
||||
from cryptography.exceptions import (
|
||||
AlreadyFinalized,
|
||||
InvalidKey,
|
||||
UnsupportedAlgorithm,
|
||||
_Reasons,
|
||||
)
|
||||
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
|
||||
from cryptography.hazmat.primitives import constant_time, hashes
|
||||
from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
|
||||
|
||||
|
||||
class PBKDF2HMAC(KeyDerivationFunction):
|
||||
def __init__(
|
||||
self,
|
||||
algorithm: hashes.HashAlgorithm,
|
||||
length: int,
|
||||
salt: bytes,
|
||||
iterations: int,
|
||||
backend: typing.Any = None,
|
||||
):
|
||||
from cryptography.hazmat.backends.openssl.backend import (
|
||||
backend as ossl,
|
||||
)
|
||||
|
||||
if not ossl.pbkdf2_hmac_supported(algorithm):
|
||||
raise UnsupportedAlgorithm(
|
||||
f"{algorithm.name} is not supported for PBKDF2.",
|
||||
_Reasons.UNSUPPORTED_HASH,
|
||||
)
|
||||
self._used = False
|
||||
self._algorithm = algorithm
|
||||
self._length = length
|
||||
utils._check_bytes("salt", salt)
|
||||
self._salt = salt
|
||||
self._iterations = iterations
|
||||
|
||||
def derive(self, key_material: bytes) -> bytes:
|
||||
if self._used:
|
||||
raise AlreadyFinalized("PBKDF2 instances can only be used once.")
|
||||
self._used = True
|
||||
|
||||
return rust_openssl.kdf.derive_pbkdf2_hmac(
|
||||
key_material,
|
||||
self._algorithm,
|
||||
self._salt,
|
||||
self._iterations,
|
||||
self._length,
|
||||
)
|
||||
|
||||
def verify(self, key_material: bytes, expected_key: bytes) -> None:
|
||||
derived_key = self.derive(key_material)
|
||||
if not constant_time.bytes_eq(derived_key, expected_key):
|
||||
raise InvalidKey("Keys do not match.")
|
@ -0,0 +1,19 @@
|
||||
# This file is dual licensed under the terms of the Apache License, Version
|
||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
||||
# for complete details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
from cryptography.hazmat.bindings._rust import openssl as rust_openssl
|
||||
from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
|
||||
|
||||
# This is used by the scrypt tests to skip tests that require more memory
|
||||
# than the MEM_LIMIT
|
||||
_MEM_LIMIT = sys.maxsize // 2
|
||||
|
||||
Scrypt = rust_openssl.kdf.Scrypt
|
||||
KeyDerivationFunction.register(Scrypt)
|
||||
|
||||
__all__ = ["Scrypt"]
|
@ -0,0 +1,61 @@
|
||||
# This file is dual licensed under the terms of the Apache License, Version
|
||||
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
|
||||
# for complete details.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from cryptography import utils
|
||||
from cryptography.exceptions import AlreadyFinalized, InvalidKey
|
||||
from cryptography.hazmat.primitives import constant_time, hashes
|
||||
from cryptography.hazmat.primitives.kdf import KeyDerivationFunction
|
||||
|
||||
|
||||
def _int_to_u32be(n: int) -> bytes:
|
||||
return n.to_bytes(length=4, byteorder="big")
|
||||
|
||||
|
||||
class X963KDF(KeyDerivationFunction):
|
||||
def __init__(
|
||||
self,
|
||||
algorithm: hashes.HashAlgorithm,
|
||||
length: int,
|
||||
sharedinfo: bytes | None,
|
||||
backend: typing.Any = None,
|
||||
):
|
||||
max_len = algorithm.digest_size * (2**32 - 1)
|
||||
if length > max_len:
|
||||
raise ValueError(f"Cannot derive keys larger than {max_len} bits.")
|
||||
if sharedinfo is not None:
|
||||
utils._check_bytes("sharedinfo", sharedinfo)
|
||||
|
||||
self._algorithm = algorithm
|
||||
self._length = length
|
||||
self._sharedinfo = sharedinfo
|
||||
self._used = False
|
||||
|
||||
def derive(self, key_material: bytes) -> bytes:
|
||||
if self._used:
|
||||
raise AlreadyFinalized
|
||||
self._used = True
|
||||
utils._check_byteslike("key_material", key_material)
|
||||
output = [b""]
|
||||
outlen = 0
|
||||
counter = 1
|
||||
|
||||
while self._length > outlen:
|
||||
h = hashes.Hash(self._algorithm)
|
||||
h.update(key_material)
|
||||
h.update(_int_to_u32be(counter))
|
||||
if self._sharedinfo is not None:
|
||||
h.update(self._sharedinfo)
|
||||
output.append(h.finalize())
|
||||
outlen += len(output[-1])
|
||||
counter += 1
|
||||
|
||||
return b"".join(output)[: self._length]
|
||||
|
||||
def verify(self, key_material: bytes, expected_key: bytes) -> None:
|
||||
if not constant_time.bytes_eq(self.derive(key_material), expected_key):
|
||||
raise InvalidKey
|
Reference in New Issue
Block a user