diff --git a/lib/jwt/__init__.py b/lib/jwt/__init__.py index 26c79b24..68d09c1c 100644 --- a/lib/jwt/__init__.py +++ b/lib/jwt/__init__.py @@ -19,6 +19,7 @@ from .exceptions import ( InvalidSignatureError, InvalidTokenError, MissingRequiredClaimError, + PyJWKClientConnectionError, PyJWKClientError, PyJWKError, PyJWKSetError, @@ -26,7 +27,7 @@ from .exceptions import ( ) from .jwks_client import PyJWKClient -__version__ = "2.6.0" +__version__ = "2.8.0" __title__ = "PyJWT" __description__ = "JSON Web Token implementation in Python" @@ -65,6 +66,7 @@ __all__ = [ "InvalidSignatureError", "InvalidTokenError", "MissingRequiredClaimError", + "PyJWKClientConnectionError", "PyJWKClientError", "PyJWKError", "PyJWKSetError", diff --git a/lib/jwt/algorithms.py b/lib/jwt/algorithms.py index 93fadf4c..ed187152 100644 --- a/lib/jwt/algorithms.py +++ b/lib/jwt/algorithms.py @@ -1,8 +1,14 @@ +from __future__ import annotations + import hashlib import hmac import json +import sys +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, Union, cast, overload from .exceptions import InvalidKeyError +from .types import HashlibHash, JWKDict from .utils import ( base64url_decode, base64url_encode, @@ -15,14 +21,28 @@ from .utils import ( to_base64url_uint, ) +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + + try: - import cryptography.exceptions from cryptography.exceptions import InvalidSignature + from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes - from cryptography.hazmat.primitives.asymmetric import ec, padding + from cryptography.hazmat.primitives.asymmetric import padding from cryptography.hazmat.primitives.asymmetric.ec import ( + ECDSA, + SECP256K1, + SECP256R1, + SECP384R1, + SECP521R1, + EllipticCurve, EllipticCurvePrivateKey, + EllipticCurvePrivateNumbers, EllipticCurvePublicKey, + EllipticCurvePublicNumbers, ) from cryptography.hazmat.primitives.asymmetric.ed448 import ( Ed448PrivateKey, @@ -56,6 +76,23 @@ try: except ModuleNotFoundError: has_crypto = False + +if TYPE_CHECKING: + # Type aliases for convenience in algorithms method signatures + AllowedRSAKeys = RSAPrivateKey | RSAPublicKey + AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey + AllowedOKPKeys = ( + Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey + ) + AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys + AllowedPrivateKeys = ( + RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey + ) + AllowedPublicKeys = ( + RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey + ) + + requires_cryptography = { "RS256", "RS384", @@ -72,7 +109,7 @@ requires_cryptography = { } -def get_default_algorithms(): +def get_default_algorithms() -> dict[str, Algorithm]: """ Returns the algorithms that are implemented by the library. """ @@ -106,45 +143,79 @@ def get_default_algorithms(): return default_algorithms -class Algorithm: +class Algorithm(ABC): """ The interface for an algorithm used to sign and verify tokens. """ - def prepare_key(self, key): + def compute_hash_digest(self, bytestr: bytes) -> bytes: + """ + Compute a hash digest using the specified algorithm's hash algorithm. + + If there is no hash algorithm, raises a NotImplementedError. + """ + # lookup self.hash_alg if defined in a way that mypy can understand + hash_alg = getattr(self, "hash_alg", None) + if hash_alg is None: + raise NotImplementedError + + if ( + has_crypto + and isinstance(hash_alg, type) + and issubclass(hash_alg, hashes.HashAlgorithm) + ): + digest = hashes.Hash(hash_alg(), backend=default_backend()) + digest.update(bytestr) + return bytes(digest.finalize()) + else: + return bytes(hash_alg(bytestr).digest()) + + @abstractmethod + def prepare_key(self, key: Any) -> Any: """ Performs necessary validation and conversions on the key and returns the key value in the proper format for sign() and verify(). """ - raise NotImplementedError - def sign(self, msg, key): + @abstractmethod + def sign(self, msg: bytes, key: Any) -> bytes: """ Returns a digital signature for the specified message using the specified key value. """ - raise NotImplementedError - def verify(self, msg, key, sig): + @abstractmethod + def verify(self, msg: bytes, key: Any, sig: bytes) -> bool: """ Verifies that the specified digital signature is valid for the specified message and key values. """ - raise NotImplementedError + + @overload + @staticmethod + @abstractmethod + def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: + ... # pragma: no cover + + @overload + @staticmethod + @abstractmethod + def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: + ... # pragma: no cover @staticmethod - def to_jwk(key_obj): + @abstractmethod + def to_jwk(key_obj, as_dict: bool = False) -> Union[JWKDict, str]: """ - Serializes a given RSA key into a JWK + Serializes a given key into a JWK """ - raise NotImplementedError @staticmethod - def from_jwk(jwk): + @abstractmethod + def from_jwk(jwk: str | JWKDict) -> Any: """ - Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object + Deserializes a given key from JWK back into a key object """ - raise NotImplementedError class NoneAlgorithm(Algorithm): @@ -153,7 +224,7 @@ class NoneAlgorithm(Algorithm): operations are required. """ - def prepare_key(self, key): + def prepare_key(self, key: str | None) -> None: if key == "": key = None @@ -162,12 +233,20 @@ class NoneAlgorithm(Algorithm): return key - def sign(self, msg, key): + def sign(self, msg: bytes, key: None) -> bytes: return b"" - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: None, sig: bytes) -> bool: return False + @staticmethod + def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn: + raise NotImplementedError() + + @staticmethod + def from_jwk(jwk: str | JWKDict) -> NoReturn: + raise NotImplementedError() + class HMACAlgorithm(Algorithm): """ @@ -175,38 +254,51 @@ class HMACAlgorithm(Algorithm): and the specified hash function. """ - SHA256 = hashlib.sha256 - SHA384 = hashlib.sha384 - SHA512 = hashlib.sha512 + SHA256: ClassVar[HashlibHash] = hashlib.sha256 + SHA384: ClassVar[HashlibHash] = hashlib.sha384 + SHA512: ClassVar[HashlibHash] = hashlib.sha512 - def __init__(self, hash_alg): + def __init__(self, hash_alg: HashlibHash) -> None: self.hash_alg = hash_alg - def prepare_key(self, key): - key = force_bytes(key) + def prepare_key(self, key: str | bytes) -> bytes: + key_bytes = force_bytes(key) - if is_pem_format(key) or is_ssh_key(key): + if is_pem_format(key_bytes) or is_ssh_key(key_bytes): raise InvalidKeyError( "The specified key is an asymmetric key or x509 certificate and" " should not be used as an HMAC secret." ) - return key + return key_bytes + + @overload + @staticmethod + def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict: + ... # pragma: no cover + + @overload + @staticmethod + def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str: + ... # pragma: no cover @staticmethod - def to_jwk(key_obj): - return json.dumps( - { - "k": base64url_encode(force_bytes(key_obj)).decode(), - "kty": "oct", - } - ) + def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> Union[JWKDict, str]: + jwk = { + "k": base64url_encode(force_bytes(key_obj)).decode(), + "kty": "oct", + } + + if as_dict: + return jwk + else: + return json.dumps(jwk) @staticmethod - def from_jwk(jwk): + def from_jwk(jwk: str | JWKDict) -> bytes: try: if isinstance(jwk, str): - obj = json.loads(jwk) + obj: JWKDict = json.loads(jwk) elif isinstance(jwk, dict): obj = jwk else: @@ -219,10 +311,10 @@ class HMACAlgorithm(Algorithm): return base64url_decode(obj["k"]) - def sign(self, msg, key): + def sign(self, msg: bytes, key: bytes) -> bytes: return hmac.new(key, msg, self.hash_alg).digest() - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool: return hmac.compare_digest(sig, self.sign(msg, key)) @@ -234,36 +326,49 @@ if has_crypto: RSASSA-PKCS-v1_5 and the specified hash function. """ - SHA256 = hashes.SHA256 - SHA384 = hashes.SHA384 - SHA512 = hashes.SHA512 + SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256 + SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384 + SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512 - def __init__(self, hash_alg): + def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None: self.hash_alg = hash_alg - def prepare_key(self, key): + def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys: if isinstance(key, (RSAPrivateKey, RSAPublicKey)): return key if not isinstance(key, (bytes, str)): raise TypeError("Expecting a PEM-formatted key.") - key = force_bytes(key) + key_bytes = force_bytes(key) try: - if key.startswith(b"ssh-rsa"): - key = load_ssh_public_key(key) + if key_bytes.startswith(b"ssh-rsa"): + return cast(RSAPublicKey, load_ssh_public_key(key_bytes)) else: - key = load_pem_private_key(key, password=None) + return cast( + RSAPrivateKey, load_pem_private_key(key_bytes, password=None) + ) except ValueError: - key = load_pem_public_key(key) - return key + return cast(RSAPublicKey, load_pem_public_key(key_bytes)) + + @overload + @staticmethod + def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict: + ... # pragma: no cover + + @overload + @staticmethod + def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str: + ... # pragma: no cover @staticmethod - def to_jwk(key_obj): - obj = None + def to_jwk( + key_obj: AllowedRSAKeys, as_dict: bool = False + ) -> Union[JWKDict, str]: + obj: dict[str, Any] | None = None - if getattr(key_obj, "private_numbers", None): + if hasattr(key_obj, "private_numbers"): # Private key numbers = key_obj.private_numbers() @@ -280,7 +385,7 @@ if has_crypto: "qi": to_base64url_uint(numbers.iqmp).decode(), } - elif getattr(key_obj, "verify", None): + elif hasattr(key_obj, "verify"): # Public key numbers = key_obj.public_numbers() @@ -293,10 +398,13 @@ if has_crypto: else: raise InvalidKeyError("Not a public or private key") - return json.dumps(obj) + if as_dict: + return obj + else: + return json.dumps(obj) @staticmethod - def from_jwk(jwk): + def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys: try: if isinstance(jwk, str): obj = json.loads(jwk) @@ -360,19 +468,17 @@ if has_crypto: return numbers.private_key() elif "n" in obj and "e" in obj: # Public key - numbers = RSAPublicNumbers( + return RSAPublicNumbers( from_base64url_uint(obj["e"]), from_base64url_uint(obj["n"]), - ) - - return numbers.public_key() + ).public_key() else: raise InvalidKeyError("Not a public or private key") - def sign(self, msg, key): + def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes: return key.sign(msg, padding.PKCS1v15(), self.hash_alg()) - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool: try: key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg()) return True @@ -385,63 +491,79 @@ if has_crypto: ECDSA and the specified hash function """ - SHA256 = hashes.SHA256 - SHA384 = hashes.SHA384 - SHA512 = hashes.SHA512 + SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256 + SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384 + SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512 - def __init__(self, hash_alg): + def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None: self.hash_alg = hash_alg - def prepare_key(self, key): + def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys: if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): return key if not isinstance(key, (bytes, str)): raise TypeError("Expecting a PEM-formatted key.") - key = force_bytes(key) + key_bytes = force_bytes(key) # Attempt to load key. We don't know if it's # a Signing Key or a Verifying Key, so we try # the Verifying Key first. try: - if key.startswith(b"ecdsa-sha2-"): - key = load_ssh_public_key(key) + if key_bytes.startswith(b"ecdsa-sha2-"): + crypto_key = load_ssh_public_key(key_bytes) else: - key = load_pem_public_key(key) + crypto_key = load_pem_public_key(key_bytes) # type: ignore[assignment] except ValueError: - key = load_pem_private_key(key, password=None) + crypto_key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment] # Explicit check the key to prevent confusing errors from cryptography - if not isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): + if not isinstance( + crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey) + ): raise InvalidKeyError( "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms" ) - return key + return crypto_key - def sign(self, msg, key): - der_sig = key.sign(msg, ec.ECDSA(self.hash_alg())) + def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes: + der_sig = key.sign(msg, ECDSA(self.hash_alg())) return der_to_raw_signature(der_sig, key.curve) - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool: try: der_sig = raw_to_der_signature(sig, key.curve) except ValueError: return False try: - if isinstance(key, EllipticCurvePrivateKey): - key = key.public_key() - key.verify(der_sig, msg, ec.ECDSA(self.hash_alg())) + public_key = ( + key.public_key() + if isinstance(key, EllipticCurvePrivateKey) + else key + ) + public_key.verify(der_sig, msg, ECDSA(self.hash_alg())) return True except InvalidSignature: return False + @overload @staticmethod - def to_jwk(key_obj): + def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict: + ... # pragma: no cover + @overload + @staticmethod + def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str: + ... # pragma: no cover + + @staticmethod + def to_jwk( + key_obj: AllowedECKeys, as_dict: bool = False + ) -> Union[JWKDict, str]: if isinstance(key_obj, EllipticCurvePrivateKey): public_numbers = key_obj.public_key().public_numbers() elif isinstance(key_obj, EllipticCurvePublicKey): @@ -449,18 +571,18 @@ if has_crypto: else: raise InvalidKeyError("Not a public or private key") - if isinstance(key_obj.curve, ec.SECP256R1): + if isinstance(key_obj.curve, SECP256R1): crv = "P-256" - elif isinstance(key_obj.curve, ec.SECP384R1): + elif isinstance(key_obj.curve, SECP384R1): crv = "P-384" - elif isinstance(key_obj.curve, ec.SECP521R1): + elif isinstance(key_obj.curve, SECP521R1): crv = "P-521" - elif isinstance(key_obj.curve, ec.SECP256K1): + elif isinstance(key_obj.curve, SECP256K1): crv = "secp256k1" else: raise InvalidKeyError(f"Invalid curve: {key_obj.curve}") - obj = { + obj: dict[str, Any] = { "kty": "EC", "crv": crv, "x": to_base64url_uint(public_numbers.x).decode(), @@ -472,10 +594,13 @@ if has_crypto: key_obj.private_numbers().private_value ).decode() - return json.dumps(obj) + if as_dict: + return obj + else: + return json.dumps(obj) @staticmethod - def from_jwk(jwk): + def from_jwk(jwk: str | JWKDict) -> AllowedECKeys: try: if isinstance(jwk, str): obj = json.loads(jwk) @@ -496,24 +621,26 @@ if has_crypto: y = base64url_decode(obj.get("y")) curve = obj.get("crv") + curve_obj: EllipticCurve + if curve == "P-256": if len(x) == len(y) == 32: - curve_obj = ec.SECP256R1() + curve_obj = SECP256R1() else: raise InvalidKeyError("Coords should be 32 bytes for curve P-256") elif curve == "P-384": if len(x) == len(y) == 48: - curve_obj = ec.SECP384R1() + curve_obj = SECP384R1() else: raise InvalidKeyError("Coords should be 48 bytes for curve P-384") elif curve == "P-521": if len(x) == len(y) == 66: - curve_obj = ec.SECP521R1() + curve_obj = SECP521R1() else: raise InvalidKeyError("Coords should be 66 bytes for curve P-521") elif curve == "secp256k1": if len(x) == len(y) == 32: - curve_obj = ec.SECP256K1() + curve_obj = SECP256K1() else: raise InvalidKeyError( "Coords should be 32 bytes for curve secp256k1" @@ -521,7 +648,7 @@ if has_crypto: else: raise InvalidKeyError(f"Invalid curve: {curve}") - public_numbers = ec.EllipticCurvePublicNumbers( + public_numbers = EllipticCurvePublicNumbers( x=int.from_bytes(x, byteorder="big"), y=int.from_bytes(y, byteorder="big"), curve=curve_obj, @@ -536,7 +663,7 @@ if has_crypto: "D should be {} bytes for curve {}", len(x), curve ) - return ec.EllipticCurvePrivateNumbers( + return EllipticCurvePrivateNumbers( int.from_bytes(d, byteorder="big"), public_numbers ).private_key() @@ -545,24 +672,24 @@ if has_crypto: Performs a signature using RSASSA-PSS with MGF1 """ - def sign(self, msg, key): + def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes: return key.sign( msg, padding.PSS( mgf=padding.MGF1(self.hash_alg()), - salt_length=self.hash_alg.digest_size, + salt_length=self.hash_alg().digest_size, ), self.hash_alg(), ) - def verify(self, msg, key, sig): + def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool: try: key.verify( sig, msg, padding.PSS( mgf=padding.MGF1(self.hash_alg()), - salt_length=self.hash_alg.digest_size, + salt_length=self.hash_alg().digest_size, ), self.hash_alg(), ) @@ -577,21 +704,20 @@ if has_crypto: This class requires ``cryptography>=2.6`` to be installed. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: pass - def prepare_key(self, key): + def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys: if isinstance(key, (bytes, str)): - if isinstance(key, str): - key = key.encode("utf-8") - str_key = key.decode("utf-8") + key_str = key.decode("utf-8") if isinstance(key, bytes) else key + key_bytes = key.encode("utf-8") if isinstance(key, str) else key - if "-----BEGIN PUBLIC" in str_key: - key = load_pem_public_key(key) - elif "-----BEGIN PRIVATE" in str_key: - key = load_pem_private_key(key, password=None) - elif str_key[0:4] == "ssh-": - key = load_ssh_public_key(key) + if "-----BEGIN PUBLIC" in key_str: + key = load_pem_public_key(key_bytes) # type: ignore[assignment] + elif "-----BEGIN PRIVATE" in key_str: + key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment] + elif key_str[0:4] == "ssh-": + key = load_ssh_public_key(key_bytes) # type: ignore[assignment] # Explicit check the key to prevent confusing errors from cryptography if not isinstance( @@ -604,7 +730,9 @@ if has_crypto: return key - def sign(self, msg, key): + def sign( + self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey + ) -> bytes: """ Sign a message ``msg`` using the EdDSA private key ``key`` :param str|bytes msg: Message to sign @@ -612,10 +740,12 @@ if has_crypto: or :class:`.Ed448PrivateKey` isinstance :return bytes signature: The signature, as bytes """ - msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg - return key.sign(msg) + msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg + return key.sign(msg_bytes) - def verify(self, msg, key, sig): + def verify( + self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes + ) -> bool: """ Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key`` @@ -626,31 +756,48 @@ if has_crypto: :return bool verified: True if signature is valid, False if not. """ try: - msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg - sig = bytes(sig, "utf-8") if type(sig) is not bytes else sig + msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg + sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig - if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)): - key = key.public_key() - key.verify(sig, msg) + public_key = ( + key.public_key() + if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)) + else key + ) + public_key.verify(sig_bytes, msg_bytes) return True # If no exception was raised, the signature is valid. - except cryptography.exceptions.InvalidSignature: + except InvalidSignature: return False + @overload @staticmethod - def to_jwk(key): + def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict: + ... # pragma: no cover + + @overload + @staticmethod + def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str: + ... # pragma: no cover + + @staticmethod + def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> Union[JWKDict, str]: if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)): x = key.public_bytes( encoding=Encoding.Raw, format=PublicFormat.Raw, ) crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448" - return json.dumps( - { - "x": base64url_encode(force_bytes(x)).decode(), - "kty": "OKP", - "crv": crv, - } - ) + + obj = { + "x": base64url_encode(force_bytes(x)).decode(), + "kty": "OKP", + "crv": crv, + } + + if as_dict: + return obj + else: + return json.dumps(obj) if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)): d = key.private_bytes( @@ -665,19 +812,22 @@ if has_crypto: ) crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448" - return json.dumps( - { - "x": base64url_encode(force_bytes(x)).decode(), - "d": base64url_encode(force_bytes(d)).decode(), - "kty": "OKP", - "crv": crv, - } - ) + obj = { + "x": base64url_encode(force_bytes(x)).decode(), + "d": base64url_encode(force_bytes(d)).decode(), + "kty": "OKP", + "crv": crv, + } + + if as_dict: + return obj + else: + return json.dumps(obj) raise InvalidKeyError("Not a public or private key") @staticmethod - def from_jwk(jwk): + def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys: try: if isinstance(jwk, str): obj = json.loads(jwk) diff --git a/lib/jwt/api_jwk.py b/lib/jwt/api_jwk.py index aa3dd321..456c7f4d 100644 --- a/lib/jwt/api_jwk.py +++ b/lib/jwt/api_jwk.py @@ -2,13 +2,15 @@ from __future__ import annotations import json import time +from typing import Any -from .algorithms import get_default_algorithms -from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError +from .algorithms import get_default_algorithms, has_crypto, requires_cryptography +from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError, PyJWTError +from .types import JWKDict class PyJWK: - def __init__(self, jwk_data, algorithm=None): + def __init__(self, jwk_data: JWKDict, algorithm: str | None = None) -> None: self._algorithms = get_default_algorithms() self._jwk_data = jwk_data @@ -47,37 +49,40 @@ class PyJWK: else: raise InvalidKeyError(f"Unsupported kty: {kty}") + if not has_crypto and algorithm in requires_cryptography: + raise PyJWKError(f"{algorithm} requires 'cryptography' to be installed.") + self.Algorithm = self._algorithms.get(algorithm) if not self.Algorithm: - raise PyJWKError(f"Unable to find a algorithm for key: {self._jwk_data}") + raise PyJWKError(f"Unable to find an algorithm for key: {self._jwk_data}") self.key = self.Algorithm.from_jwk(self._jwk_data) @staticmethod - def from_dict(obj, algorithm=None): + def from_dict(obj: JWKDict, algorithm: str | None = None) -> "PyJWK": return PyJWK(obj, algorithm) @staticmethod - def from_json(data, algorithm=None): + def from_json(data: str, algorithm: None = None) -> "PyJWK": obj = json.loads(data) return PyJWK.from_dict(obj, algorithm) @property - def key_type(self): + def key_type(self) -> str | None: return self._jwk_data.get("kty", None) @property - def key_id(self): + def key_id(self) -> str | None: return self._jwk_data.get("kid", None) @property - def public_key_use(self): + def public_key_use(self) -> str | None: return self._jwk_data.get("use", None) class PyJWKSet: - def __init__(self, keys: list[dict]) -> None: + def __init__(self, keys: list[JWKDict]) -> None: self.keys = [] if not keys: @@ -89,24 +94,26 @@ class PyJWKSet: for key in keys: try: self.keys.append(PyJWK(key)) - except PyJWKError: + except PyJWTError: # skip unusable keys continue if len(self.keys) == 0: - raise PyJWKSetError("The JWK Set did not contain any usable keys") + raise PyJWKSetError( + "The JWK Set did not contain any usable keys. Perhaps 'cryptography' is not installed?" + ) @staticmethod - def from_dict(obj): + def from_dict(obj: dict[str, Any]) -> "PyJWKSet": keys = obj.get("keys", []) return PyJWKSet(keys) @staticmethod - def from_json(data): + def from_json(data: str) -> "PyJWKSet": obj = json.loads(data) return PyJWKSet.from_dict(obj) - def __getitem__(self, kid): + def __getitem__(self, kid: str) -> "PyJWK": for key in self.keys: if key.key_id == kid: return key @@ -118,8 +125,8 @@ class PyJWTSetWithTimestamp: self.jwk_set = jwk_set self.timestamp = time.monotonic() - def get_jwk_set(self): + def get_jwk_set(self) -> PyJWKSet: return self.jwk_set - def get_timestamp(self): + def get_timestamp(self) -> float: return self.timestamp diff --git a/lib/jwt/api_jws.py b/lib/jwt/api_jws.py index ab8490f9..fa6708cc 100644 --- a/lib/jwt/api_jws.py +++ b/lib/jwt/api_jws.py @@ -3,7 +3,7 @@ from __future__ import annotations import binascii import json import warnings -from typing import Any, Type +from typing import TYPE_CHECKING, Any from .algorithms import ( Algorithm, @@ -20,11 +20,18 @@ from .exceptions import ( from .utils import base64url_decode, base64url_encode from .warnings import RemovedInPyjwt3Warning +if TYPE_CHECKING: + from .algorithms import AllowedPrivateKeys, AllowedPublicKeys + class PyJWS: header_typ = "JWT" - def __init__(self, algorithms=None, options=None) -> None: + def __init__( + self, + algorithms: list[str] | None = None, + options: dict[str, Any] | None = None, + ) -> None: self._algorithms = get_default_algorithms() self._valid_algs = ( set(algorithms) if algorithms is not None else set(self._algorithms) @@ -96,11 +103,12 @@ class PyJWS: def encode( self, payload: bytes, - key: str, + key: AllowedPrivateKeys | str | bytes, algorithm: str | None = "HS256", headers: dict[str, Any] | None = None, - json_encoder: Type[json.JSONEncoder] | None = None, + json_encoder: type[json.JSONEncoder] | None = None, is_payload_detached: bool = False, + sort_headers: bool = True, ) -> str: segments = [] @@ -133,9 +141,8 @@ class PyJWS: # True is the standard value for b64, so no need for it del header["b64"] - # Fix for headers misorder - issue #715 json_header = json.dumps( - header, separators=(",", ":"), cls=json_encoder, sort_keys=True + header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers ).encode() segments.append(base64url_encode(json_header)) @@ -164,8 +171,8 @@ class PyJWS: def decode_complete( self, - jwt: str, - key: str = "", + jwt: str | bytes, + key: AllowedPublicKeys | str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, @@ -209,13 +216,13 @@ class PyJWS: def decode( self, - jwt: str, - key: str = "", + jwt: str | bytes, + key: AllowedPublicKeys | str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, **kwargs, - ) -> str: + ) -> Any: if kwargs: warnings.warn( "passing additional kwargs to decode() is deprecated " @@ -228,7 +235,7 @@ class PyJWS: ) return decoded["payload"] - def get_unverified_header(self, jwt: str | bytes) -> dict: + def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]: """Returns back the JWT header parameters as a dict() Note: The signature is not verified so the header parameters @@ -239,7 +246,7 @@ class PyJWS: return headers - def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]: + def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]: if isinstance(jwt, str): jwt = jwt.encode("utf-8") @@ -280,13 +287,15 @@ class PyJWS: def _verify_signature( self, signing_input: bytes, - header: dict, + header: dict[str, Any], signature: bytes, - key: str = "", + key: AllowedPublicKeys | str | bytes = "", algorithms: list[str] | None = None, ) -> None: - - alg = header.get("alg") + try: + alg = header["alg"] + except KeyError: + raise InvalidAlgorithmError("Algorithm not specified") if not alg or (algorithms is not None and alg not in algorithms): raise InvalidAlgorithmError("The specified alg value is not allowed") @@ -295,16 +304,16 @@ class PyJWS: alg_obj = self.get_algorithm_by_name(alg) except NotImplementedError as e: raise InvalidAlgorithmError("Algorithm not supported") from e - key = alg_obj.prepare_key(key) + prepared_key = alg_obj.prepare_key(key) - if not alg_obj.verify(signing_input, key, signature): + if not alg_obj.verify(signing_input, prepared_key, signature): raise InvalidSignatureError("Signature verification failed") def _validate_headers(self, headers: dict[str, Any]) -> None: if "kid" in headers: self._validate_kid(headers["kid"]) - def _validate_kid(self, kid: str) -> None: + def _validate_kid(self, kid: Any) -> None: if not isinstance(kid, str): raise InvalidTokenError("Key ID header parameter must be a string") diff --git a/lib/jwt/api_jwt.py b/lib/jwt/api_jwt.py index 4bb1ee1f..48d739ad 100644 --- a/lib/jwt/api_jwt.py +++ b/lib/jwt/api_jwt.py @@ -3,9 +3,9 @@ from __future__ import annotations import json import warnings from calendar import timegm -from collections.abc import Iterable, Mapping +from collections.abc import Iterable from datetime import datetime, timedelta, timezone -from typing import Any, Dict, List, Optional, Type, Union +from typing import TYPE_CHECKING, Any from . import api_jws from .exceptions import ( @@ -19,15 +19,18 @@ from .exceptions import ( ) from .warnings import RemovedInPyjwt3Warning +if TYPE_CHECKING: + from .algorithms import AllowedPrivateKeys, AllowedPublicKeys + class PyJWT: - def __init__(self, options=None): + def __init__(self, options: dict[str, Any] | None = None) -> None: if options is None: options = {} - self.options = {**self._get_default_options(), **options} + self.options: dict[str, Any] = {**self._get_default_options(), **options} @staticmethod - def _get_default_options() -> Dict[str, Union[bool, List[str]]]: + def _get_default_options() -> dict[str, bool | list[str]]: return { "verify_signature": True, "verify_exp": True, @@ -40,16 +43,17 @@ class PyJWT: def encode( self, - payload: Dict[str, Any], - key: str, - algorithm: Optional[str] = "HS256", - headers: Optional[Dict[str, Any]] = None, - json_encoder: Optional[Type[json.JSONEncoder]] = None, + payload: dict[str, Any], + key: AllowedPrivateKeys | str | bytes, + algorithm: str | None = "HS256", + headers: dict[str, Any] | None = None, + json_encoder: type[json.JSONEncoder] | None = None, + sort_headers: bool = True, ) -> str: - # Check that we get a mapping - if not isinstance(payload, Mapping): + # Check that we get a dict + if not isinstance(payload, dict): raise TypeError( - "Expecting a mapping object, as JWT only supports " + "Expecting a dict object, as JWT only supports " "JSON objects as payloads." ) @@ -60,30 +64,57 @@ class PyJWT: if isinstance(payload.get(time_claim), datetime): payload[time_claim] = timegm(payload[time_claim].utctimetuple()) - json_payload = json.dumps( - payload, separators=(",", ":"), cls=json_encoder - ).encode("utf-8") + json_payload = self._encode_payload( + payload, + headers=headers, + json_encoder=json_encoder, + ) - return api_jws.encode(json_payload, key, algorithm, headers, json_encoder) + return api_jws.encode( + json_payload, + key, + algorithm, + headers, + json_encoder, + sort_headers=sort_headers, + ) + + def _encode_payload( + self, + payload: dict[str, Any], + headers: dict[str, Any] | None = None, + json_encoder: type[json.JSONEncoder] | None = None, + ) -> bytes: + """ + Encode a given payload to the bytes to be signed. + + This method is intended to be overridden by subclasses that need to + encode the payload in a different way, e.g. compress the payload. + """ + return json.dumps( + payload, + separators=(",", ":"), + cls=json_encoder, + ).encode("utf-8") def decode_complete( self, - jwt: str, - key: str = "", - algorithms: Optional[List[str]] = None, - options: Optional[Dict[str, Any]] = None, + jwt: str | bytes, + key: AllowedPublicKeys | str | bytes = "", + algorithms: list[str] | None = None, + options: dict[str, Any] | None = None, # deprecated arg, remove in pyjwt3 - verify: Optional[bool] = None, + verify: bool | None = None, # could be used as passthrough to api_jws, consider removal in pyjwt3 - detached_payload: Optional[bytes] = None, + detached_payload: bytes | None = None, # passthrough arguments to _validate_claims # consider putting in options - audience: Optional[Union[str, Iterable[str]]] = None, - issuer: Optional[str] = None, - leeway: Union[int, float, timedelta] = 0, + audience: str | Iterable[str] | None = None, + issuer: str | None = None, + leeway: float | timedelta = 0, # kwargs - **kwargs, - ) -> Dict[str, Any]: + **kwargs: Any, + ) -> dict[str, Any]: if kwargs: warnings.warn( "passing additional kwargs to decode_complete() is deprecated " @@ -125,12 +156,7 @@ class PyJWT: detached_payload=detached_payload, ) - try: - payload = json.loads(decoded["payload"]) - except ValueError as e: - raise DecodeError(f"Invalid payload string: {e}") - if not isinstance(payload, dict): - raise DecodeError("Invalid payload string: must be a json object") + payload = self._decode_payload(decoded) merged_options = {**self.options, **options} self._validate_claims( @@ -140,24 +166,40 @@ class PyJWT: decoded["payload"] = payload return decoded + def _decode_payload(self, decoded: dict[str, Any]) -> Any: + """ + Decode the payload from a JWS dictionary (payload, signature, header). + + This method is intended to be overridden by subclasses that need to + decode the payload in a different way, e.g. decompress compressed + payloads. + """ + try: + payload = json.loads(decoded["payload"]) + except ValueError as e: + raise DecodeError(f"Invalid payload string: {e}") + if not isinstance(payload, dict): + raise DecodeError("Invalid payload string: must be a json object") + return payload + def decode( self, - jwt: str, - key: str = "", - algorithms: Optional[List[str]] = None, - options: Optional[Dict[str, Any]] = None, + jwt: str | bytes, + key: AllowedPublicKeys | str | bytes = "", + algorithms: list[str] | None = None, + options: dict[str, Any] | None = None, # deprecated arg, remove in pyjwt3 - verify: Optional[bool] = None, + verify: bool | None = None, # could be used as passthrough to api_jws, consider removal in pyjwt3 - detached_payload: Optional[bytes] = None, + detached_payload: bytes | None = None, # passthrough arguments to _validate_claims # consider putting in options - audience: Optional[Union[str, Iterable[str]]] = None, - issuer: Optional[str] = None, - leeway: Union[int, float, timedelta] = 0, + audience: str | Iterable[str] | None = None, + issuer: str | None = None, + leeway: float | timedelta = 0, # kwargs - **kwargs, - ) -> Dict[str, Any]: + **kwargs: Any, + ) -> Any: if kwargs: warnings.warn( "passing additional kwargs to decode() is deprecated " @@ -178,7 +220,14 @@ class PyJWT: ) return decoded["payload"] - def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0): + def _validate_claims( + self, + payload: dict[str, Any], + options: dict[str, Any], + audience=None, + issuer=None, + leeway: float | timedelta = 0, + ) -> None: if isinstance(leeway, timedelta): leeway = leeway.total_seconds() @@ -187,7 +236,7 @@ class PyJWT: self._validate_required_claims(payload, options) - now = timegm(datetime.now(tz=timezone.utc).utctimetuple()) + now = datetime.now(tz=timezone.utc).timestamp() if "iat" in payload and options["verify_iat"]: self._validate_iat(payload, now, leeway) @@ -202,23 +251,38 @@ class PyJWT: self._validate_iss(payload, issuer) if options["verify_aud"]: - self._validate_aud(payload, audience) + self._validate_aud( + payload, audience, strict=options.get("strict_aud", False) + ) - def _validate_required_claims(self, payload, options): + def _validate_required_claims( + self, + payload: dict[str, Any], + options: dict[str, Any], + ) -> None: for claim in options["require"]: if payload.get(claim) is None: raise MissingRequiredClaimError(claim) - def _validate_iat(self, payload, now, leeway): - iat = payload["iat"] + def _validate_iat( + self, + payload: dict[str, Any], + now: float, + leeway: float, + ) -> None: try: - int(iat) + iat = int(payload["iat"]) except ValueError: raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.") if iat > (now + leeway): raise ImmatureSignatureError("The token is not yet valid (iat)") - def _validate_nbf(self, payload, now, leeway): + def _validate_nbf( + self, + payload: dict[str, Any], + now: float, + leeway: float, + ) -> None: try: nbf = int(payload["nbf"]) except ValueError: @@ -227,7 +291,12 @@ class PyJWT: if nbf > (now + leeway): raise ImmatureSignatureError("The token is not yet valid (nbf)") - def _validate_exp(self, payload, now, leeway): + def _validate_exp( + self, + payload: dict[str, Any], + now: float, + leeway: float, + ) -> None: try: exp = int(payload["exp"]) except ValueError: @@ -236,7 +305,13 @@ class PyJWT: if exp <= (now - leeway): raise ExpiredSignatureError("Signature has expired") - def _validate_aud(self, payload, audience): + def _validate_aud( + self, + payload: dict[str, Any], + audience: str | Iterable[str] | None, + *, + strict: bool = False, + ) -> None: if audience is None: if "aud" not in payload or not payload["aud"]: return @@ -251,6 +326,22 @@ class PyJWT: audience_claims = payload["aud"] + # In strict mode, we forbid list matching: the supplied audience + # must be a string, and it must exactly match the audience claim. + if strict: + # Only a single audience is allowed in strict mode. + if not isinstance(audience, str): + raise InvalidAudienceError("Invalid audience (strict)") + + # Only a single audience claim is allowed in strict mode. + if not isinstance(audience_claims, str): + raise InvalidAudienceError("Invalid claim format in token (strict)") + + if audience != audience_claims: + raise InvalidAudienceError("Audience doesn't match (strict)") + + return + if isinstance(audience_claims, str): audience_claims = [audience_claims] if not isinstance(audience_claims, list): @@ -262,9 +353,9 @@ class PyJWT: audience = [audience] if all(aud not in audience_claims for aud in audience): - raise InvalidAudienceError("Invalid audience") + raise InvalidAudienceError("Audience doesn't match") - def _validate_iss(self, payload, issuer): + def _validate_iss(self, payload: dict[str, Any], issuer: Any) -> None: if issuer is None: return diff --git a/lib/jwt/exceptions.py b/lib/jwt/exceptions.py index ee201add..8ac6ecf7 100644 --- a/lib/jwt/exceptions.py +++ b/lib/jwt/exceptions.py @@ -47,10 +47,10 @@ class InvalidAlgorithmError(InvalidTokenError): class MissingRequiredClaimError(InvalidTokenError): - def __init__(self, claim): + def __init__(self, claim: str) -> None: self.claim = claim - def __str__(self): + def __str__(self) -> str: return f'Token is missing the "{self.claim}" claim' @@ -64,3 +64,7 @@ class PyJWKSetError(PyJWTError): class PyJWKClientError(PyJWTError): pass + + +class PyJWKClientConnectionError(PyJWKClientError): + pass diff --git a/lib/jwt/help.py b/lib/jwt/help.py index 0c02eb92..80b0ca56 100644 --- a/lib/jwt/help.py +++ b/lib/jwt/help.py @@ -7,8 +7,10 @@ from . import __version__ as pyjwt_version try: import cryptography + + cryptography_version = cryptography.__version__ except ModuleNotFoundError: - cryptography = None + cryptography_version = "" def info() -> Dict[str, Dict[str, str]]: @@ -29,7 +31,7 @@ def info() -> Dict[str, Dict[str, str]]: if implementation == "CPython": implementation_version = platform.python_version() elif implementation == "PyPy": - pypy_version_info = getattr(sys, "pypy_version_info") + pypy_version_info = sys.pypy_version_info # type: ignore[attr-defined] implementation_version = ( f"{pypy_version_info.major}." f"{pypy_version_info.minor}." @@ -48,7 +50,7 @@ def info() -> Dict[str, Dict[str, str]]: "name": implementation, "version": implementation_version, }, - "cryptography": {"version": getattr(cryptography, "__version__", "")}, + "cryptography": {"version": cryptography_version}, "pyjwt": {"version": pyjwt_version}, } diff --git a/lib/jwt/jwk_set_cache.py b/lib/jwt/jwk_set_cache.py index e8c2a7e0..24325630 100644 --- a/lib/jwt/jwk_set_cache.py +++ b/lib/jwt/jwk_set_cache.py @@ -5,11 +5,11 @@ from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp class JWKSetCache: - def __init__(self, lifespan: int): + def __init__(self, lifespan: int) -> None: self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None self.lifespan = lifespan - def put(self, jwk_set: PyJWKSet): + def put(self, jwk_set: PyJWKSet) -> None: if jwk_set is not None: self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set) else: @@ -23,7 +23,6 @@ class JWKSetCache: return self.jwk_set_with_timestamp.get_jwk_set() def is_expired(self) -> bool: - return ( self.jwk_set_with_timestamp is not None and self.lifespan > -1 diff --git a/lib/jwt/jwks_client.py b/lib/jwt/jwks_client.py index b4e98007..f19b10ac 100644 --- a/lib/jwt/jwks_client.py +++ b/lib/jwt/jwks_client.py @@ -1,12 +1,13 @@ import json import urllib.request from functools import lru_cache -from typing import Any, List, Optional +from ssl import SSLContext +from typing import Any, Dict, List, Optional from urllib.error import URLError from .api_jwk import PyJWK, PyJWKSet from .api_jwt import decode_complete as decode_token -from .exceptions import PyJWKClientError +from .exceptions import PyJWKClientConnectionError, PyJWKClientError from .jwk_set_cache import JWKSetCache @@ -18,9 +19,17 @@ class PyJWKClient: max_cached_keys: int = 16, cache_jwk_set: bool = True, lifespan: int = 300, + headers: Optional[Dict[str, Any]] = None, + timeout: int = 30, + ssl_context: Optional[SSLContext] = None, ): + if headers is None: + headers = {} self.uri = uri self.jwk_set_cache: Optional[JWKSetCache] = None + self.headers = headers + self.timeout = timeout + self.ssl_context = ssl_context if cache_jwk_set: # Init jwt set cache with default or given lifespan. @@ -41,10 +50,15 @@ class PyJWKClient: def fetch_data(self) -> Any: jwk_set: Any = None try: - with urllib.request.urlopen(self.uri) as response: + r = urllib.request.Request(url=self.uri, headers=self.headers) + with urllib.request.urlopen( + r, timeout=self.timeout, context=self.ssl_context + ) as response: jwk_set = json.load(response) - except URLError as e: - raise PyJWKClientError(f'Fail to fetch data from the url, err: "{e}"') + except (URLError, TimeoutError) as e: + raise PyJWKClientConnectionError( + f'Fail to fetch data from the url, err: "{e}"' + ) else: return jwk_set finally: @@ -59,6 +73,9 @@ class PyJWKClient: if data is None: data = self.fetch_data() + if not isinstance(data, dict): + raise PyJWKClientError("The JWKS endpoint did not return a JSON object") + return PyJWKSet.from_dict(data) def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]: diff --git a/lib/jwt/types.py b/lib/jwt/types.py new file mode 100644 index 00000000..7d993520 --- /dev/null +++ b/lib/jwt/types.py @@ -0,0 +1,5 @@ +from typing import Any, Callable, Dict + +JWKDict = Dict[str, Any] + +HashlibHash = Callable[..., Any] diff --git a/lib/jwt/utils.py b/lib/jwt/utils.py index 16cae066..81c5ee41 100644 --- a/lib/jwt/utils.py +++ b/lib/jwt/utils.py @@ -10,10 +10,10 @@ try: encode_dss_signature, ) except ModuleNotFoundError: - EllipticCurve = None + pass -def force_bytes(value: Union[str, bytes]) -> bytes: +def force_bytes(value: Union[bytes, str]) -> bytes: if isinstance(value, str): return value.encode("utf-8") elif isinstance(value, bytes): @@ -22,16 +22,15 @@ def force_bytes(value: Union[str, bytes]) -> bytes: raise TypeError("Expected a string value") -def base64url_decode(input: Union[str, bytes]) -> bytes: - if isinstance(input, str): - input = input.encode("ascii") +def base64url_decode(input: Union[bytes, str]) -> bytes: + input_bytes = force_bytes(input) - rem = len(input) % 4 + rem = len(input_bytes) % 4 if rem > 0: - input += b"=" * (4 - rem) + input_bytes += b"=" * (4 - rem) - return base64.urlsafe_b64decode(input) + return base64.urlsafe_b64decode(input_bytes) def base64url_encode(input: bytes) -> bytes: @@ -50,11 +49,8 @@ def to_base64url_uint(val: int) -> bytes: return base64url_encode(int_bytes) -def from_base64url_uint(val: Union[str, bytes]) -> int: - if isinstance(val, str): - val = val.encode("ascii") - - data = base64url_decode(val) +def from_base64url_uint(val: Union[bytes, str]) -> int: + data = base64url_decode(force_bytes(val)) return int.from_bytes(data, byteorder="big") @@ -78,7 +74,7 @@ def bytes_from_int(val: int) -> bytes: return val.to_bytes(byte_length, "big", signed=False) -def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes: +def der_to_raw_signature(der_sig: bytes, curve: "EllipticCurve") -> bytes: num_bits = curve.key_size num_bytes = (num_bits + 7) // 8 @@ -87,7 +83,7 @@ def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes: return number_to_bytes(r, num_bytes) + number_to_bytes(s, num_bytes) -def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes: +def raw_to_der_signature(raw_sig: bytes, curve: "EllipticCurve") -> bytes: num_bits = curve.key_size num_bytes = (num_bits + 7) // 8 @@ -97,7 +93,7 @@ def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes: r = bytes_to_number(raw_sig[:num_bytes]) s = bytes_to_number(raw_sig[num_bytes:]) - return encode_dss_signature(r, s) + return bytes(encode_dss_signature(r, s)) # Based on https://github.com/hynek/pem/blob/7ad94db26b0bc21d10953f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252 diff --git a/requirements.txt b/requirements.txt index 755f9a21..0ac5222b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,7 +31,7 @@ paho-mqtt==1.6.1 plexapi==4.13.4 portend==3.1.0 profilehooks==1.12.0 -PyJWT==2.6.0 +PyJWT==2.8.0 pyparsing==3.0.9 python-dateutil==2.8.2 python-twitter==3.5