diff --git a/lib/jwt/__init__.py b/lib/jwt/__init__.py index 68d09c1c..b7a258d7 100644 --- a/lib/jwt/__init__.py +++ b/lib/jwt/__init__.py @@ -27,7 +27,7 @@ from .exceptions import ( ) from .jwks_client import PyJWKClient -__version__ = "2.8.0" +__version__ = "2.9.0" __title__ = "PyJWT" __description__ = "JSON Web Token implementation in Python" diff --git a/lib/jwt/algorithms.py b/lib/jwt/algorithms.py index ed187152..9be50b20 100644 --- a/lib/jwt/algorithms.py +++ b/lib/jwt/algorithms.py @@ -3,9 +3,8 @@ from __future__ import annotations import hashlib import hmac import json -import sys from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, Union, cast, overload +from typing import TYPE_CHECKING, Any, ClassVar, Literal, NoReturn, cast, overload from .exceptions import InvalidKeyError from .types import HashlibHash, JWKDict @@ -21,14 +20,8 @@ from .utils import ( to_base64url_uint, ) -if sys.version_info >= (3, 8): - from typing import Literal -else: - from typing_extensions import Literal - - try: - from cryptography.exceptions import InvalidSignature + from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import padding @@ -194,18 +187,16 @@ class Algorithm(ABC): @overload @staticmethod @abstractmethod - def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: - ... # pragma: no cover + def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: ... # pragma: no cover @overload @staticmethod @abstractmethod - def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: - ... # pragma: no cover + def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: ... # pragma: no cover @staticmethod @abstractmethod - def to_jwk(key_obj, as_dict: bool = False) -> Union[JWKDict, str]: + def to_jwk(key_obj, as_dict: bool = False) -> JWKDict | str: """ Serializes a given key into a JWK """ @@ -274,16 +265,18 @@ class HMACAlgorithm(Algorithm): @overload @staticmethod - def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict: - ... # pragma: no cover + def to_jwk( + key_obj: str | bytes, as_dict: Literal[True] + ) -> JWKDict: ... # pragma: no cover @overload @staticmethod - def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str: - ... # pragma: no cover + def to_jwk( + key_obj: str | bytes, as_dict: Literal[False] = False + ) -> str: ... # pragma: no cover @staticmethod - def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> Union[JWKDict, str]: + def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str: jwk = { "k": base64url_encode(force_bytes(key_obj)).decode(), "kty": "oct", @@ -350,22 +343,25 @@ if has_crypto: RSAPrivateKey, load_pem_private_key(key_bytes, password=None) ) except ValueError: - return cast(RSAPublicKey, load_pem_public_key(key_bytes)) + try: + return cast(RSAPublicKey, load_pem_public_key(key_bytes)) + except (ValueError, UnsupportedAlgorithm): + raise InvalidKeyError("Could not parse the provided public key.") @overload - @staticmethod - def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict: - ... # pragma: no cover - - @overload - @staticmethod - def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str: - ... # pragma: no cover - @staticmethod def to_jwk( - key_obj: AllowedRSAKeys, as_dict: bool = False - ) -> Union[JWKDict, str]: + key_obj: AllowedRSAKeys, as_dict: Literal[True] + ) -> JWKDict: ... # pragma: no cover + + @overload + @staticmethod + def to_jwk( + key_obj: AllowedRSAKeys, as_dict: Literal[False] = False + ) -> str: ... # pragma: no cover + + @staticmethod + def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str: obj: dict[str, Any] | None = None if hasattr(key_obj, "private_numbers"): @@ -533,7 +529,7 @@ if has_crypto: return der_to_raw_signature(der_sig, key.curve) - def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool: + def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool: try: der_sig = raw_to_der_signature(sig, key.curve) except ValueError: @@ -552,18 +548,18 @@ if has_crypto: @overload @staticmethod - def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict: - ... # pragma: no cover + def to_jwk( + key_obj: AllowedECKeys, as_dict: Literal[True] + ) -> JWKDict: ... # pragma: no cover @overload @staticmethod - def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str: - ... # pragma: no cover + def to_jwk( + key_obj: AllowedECKeys, as_dict: Literal[False] = False + ) -> str: ... # pragma: no cover @staticmethod - def to_jwk( - key_obj: AllowedECKeys, as_dict: bool = False - ) -> Union[JWKDict, str]: + def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str: if isinstance(key_obj, EllipticCurvePrivateKey): public_numbers = key_obj.public_key().public_numbers() elif isinstance(key_obj, EllipticCurvePublicKey): @@ -771,16 +767,18 @@ if has_crypto: @overload @staticmethod - def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict: - ... # pragma: no cover + def to_jwk( + key: AllowedOKPKeys, as_dict: Literal[True] + ) -> JWKDict: ... # pragma: no cover @overload @staticmethod - def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str: - ... # pragma: no cover + def to_jwk( + key: AllowedOKPKeys, as_dict: Literal[False] = False + ) -> str: ... # pragma: no cover @staticmethod - def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> Union[JWKDict, str]: + def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str: if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)): x = key.public_bytes( encoding=Encoding.Raw, diff --git a/lib/jwt/api_jwk.py b/lib/jwt/api_jwk.py index 456c7f4d..02f4679c 100644 --- a/lib/jwt/api_jwk.py +++ b/lib/jwt/api_jwk.py @@ -5,7 +5,13 @@ import time from typing import Any from .algorithms import get_default_algorithms, has_crypto, requires_cryptography -from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError, PyJWTError +from .exceptions import ( + InvalidKeyError, + MissingCryptographyError, + PyJWKError, + PyJWKSetError, + PyJWTError, +) from .types import JWKDict @@ -50,21 +56,25 @@ class PyJWK: raise InvalidKeyError(f"Unsupported kty: {kty}") if not has_crypto and algorithm in requires_cryptography: - raise PyJWKError(f"{algorithm} requires 'cryptography' to be installed.") + raise MissingCryptographyError( + f"{algorithm} requires 'cryptography' to be installed." + ) - self.Algorithm = self._algorithms.get(algorithm) + self.algorithm_name = algorithm - if not self.Algorithm: + if algorithm in self._algorithms: + self.Algorithm = self._algorithms[algorithm] + else: raise PyJWKError(f"Unable to find an algorithm for key: {self._jwk_data}") self.key = self.Algorithm.from_jwk(self._jwk_data) @staticmethod - def from_dict(obj: JWKDict, algorithm: str | None = None) -> "PyJWK": + def from_dict(obj: JWKDict, algorithm: str | None = None) -> PyJWK: return PyJWK(obj, algorithm) @staticmethod - def from_json(data: str, algorithm: None = None) -> "PyJWK": + def from_json(data: str, algorithm: None = None) -> PyJWK: obj = json.loads(data) return PyJWK.from_dict(obj, algorithm) @@ -94,7 +104,9 @@ class PyJWKSet: for key in keys: try: self.keys.append(PyJWK(key)) - except PyJWTError: + except PyJWTError as error: + if isinstance(error, MissingCryptographyError): + raise error # skip unusable keys continue @@ -104,16 +116,16 @@ class PyJWKSet: ) @staticmethod - def from_dict(obj: dict[str, Any]) -> "PyJWKSet": + def from_dict(obj: dict[str, Any]) -> PyJWKSet: keys = obj.get("keys", []) return PyJWKSet(keys) @staticmethod - def from_json(data: str) -> "PyJWKSet": + def from_json(data: str) -> PyJWKSet: obj = json.loads(data) return PyJWKSet.from_dict(obj) - def __getitem__(self, kid: str) -> "PyJWK": + def __getitem__(self, kid: str) -> PyJWK: for key in self.keys: if key.key_id == kid: return key diff --git a/lib/jwt/api_jws.py b/lib/jwt/api_jws.py index fa6708cc..5822ebf6 100644 --- a/lib/jwt/api_jws.py +++ b/lib/jwt/api_jws.py @@ -11,6 +11,7 @@ from .algorithms import ( has_crypto, requires_cryptography, ) +from .api_jwk import PyJWK from .exceptions import ( DecodeError, InvalidAlgorithmError, @@ -172,7 +173,7 @@ class PyJWS: def decode_complete( self, jwt: str | bytes, - key: AllowedPublicKeys | str | bytes = "", + key: AllowedPublicKeys | PyJWK | str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, @@ -190,7 +191,7 @@ class PyJWS: merged_options = {**self.options, **options} verify_signature = merged_options["verify_signature"] - if verify_signature and not algorithms: + if verify_signature and not algorithms and not isinstance(key, PyJWK): raise DecodeError( 'It is required that you pass in a value for the "algorithms" argument when calling decode().' ) @@ -217,7 +218,7 @@ class PyJWS: def decode( self, jwt: str | bytes, - key: AllowedPublicKeys | str | bytes = "", + key: AllowedPublicKeys | PyJWK | str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, @@ -289,9 +290,11 @@ class PyJWS: signing_input: bytes, header: dict[str, Any], signature: bytes, - key: AllowedPublicKeys | str | bytes = "", + key: AllowedPublicKeys | PyJWK | str | bytes = "", algorithms: list[str] | None = None, ) -> None: + if algorithms is None and isinstance(key, PyJWK): + algorithms = [key.algorithm_name] try: alg = header["alg"] except KeyError: @@ -300,11 +303,15 @@ class PyJWS: if not alg or (algorithms is not None and alg not in algorithms): raise InvalidAlgorithmError("The specified alg value is not allowed") - try: - alg_obj = self.get_algorithm_by_name(alg) - except NotImplementedError as e: - raise InvalidAlgorithmError("Algorithm not supported") from e - prepared_key = alg_obj.prepare_key(key) + if isinstance(key, PyJWK): + alg_obj = key.Algorithm + prepared_key = key.key + else: + try: + alg_obj = self.get_algorithm_by_name(alg) + except NotImplementedError as e: + raise InvalidAlgorithmError("Algorithm not supported") from e + prepared_key = alg_obj.prepare_key(key) if not alg_obj.verify(signing_input, prepared_key, signature): raise InvalidSignatureError("Signature verification failed") diff --git a/lib/jwt/api_jwt.py b/lib/jwt/api_jwt.py index 48d739ad..7a07c336 100644 --- a/lib/jwt/api_jwt.py +++ b/lib/jwt/api_jwt.py @@ -5,7 +5,7 @@ import warnings from calendar import timegm from collections.abc import Iterable from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, List from . import api_jws from .exceptions import ( @@ -21,6 +21,7 @@ from .warnings import RemovedInPyjwt3Warning if TYPE_CHECKING: from .algorithms import AllowedPrivateKeys, AllowedPublicKeys + from .api_jwk import PyJWK class PyJWT: @@ -100,7 +101,7 @@ class PyJWT: def decode_complete( self, jwt: str | bytes, - key: AllowedPublicKeys | str | bytes = "", + key: AllowedPublicKeys | PyJWK | str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, # deprecated arg, remove in pyjwt3 @@ -110,7 +111,7 @@ class PyJWT: # passthrough arguments to _validate_claims # consider putting in options audience: str | Iterable[str] | None = None, - issuer: str | None = None, + issuer: str | List[str] | None = None, leeway: float | timedelta = 0, # kwargs **kwargs: Any, @@ -185,7 +186,7 @@ class PyJWT: def decode( self, jwt: str | bytes, - key: AllowedPublicKeys | str | bytes = "", + key: AllowedPublicKeys | PyJWK | str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, # deprecated arg, remove in pyjwt3 @@ -195,7 +196,7 @@ class PyJWT: # passthrough arguments to _validate_claims # consider putting in options audience: str | Iterable[str] | None = None, - issuer: str | None = None, + issuer: str | List[str] | None = None, leeway: float | timedelta = 0, # kwargs **kwargs: Any, @@ -300,7 +301,7 @@ class PyJWT: try: exp = int(payload["exp"]) except ValueError: - raise DecodeError("Expiration Time claim (exp) must be an" " integer.") + raise DecodeError("Expiration Time claim (exp) must be an integer.") if exp <= (now - leeway): raise ExpiredSignatureError("Signature has expired") @@ -362,8 +363,12 @@ class PyJWT: if "iss" not in payload: raise MissingRequiredClaimError("iss") - if payload["iss"] != issuer: - raise InvalidIssuerError("Invalid issuer") + if isinstance(issuer, list): + if payload["iss"] not in issuer: + raise InvalidIssuerError("Invalid issuer") + else: + if payload["iss"] != issuer: + raise InvalidIssuerError("Invalid issuer") _jwt_global_obj = PyJWT() diff --git a/lib/jwt/exceptions.py b/lib/jwt/exceptions.py index 8ac6ecf7..0d985882 100644 --- a/lib/jwt/exceptions.py +++ b/lib/jwt/exceptions.py @@ -58,6 +58,10 @@ class PyJWKError(PyJWTError): pass +class MissingCryptographyError(PyJWKError): + pass + + class PyJWKSetError(PyJWTError): pass diff --git a/lib/jwt/utils.py b/lib/jwt/utils.py index 81c5ee41..d469139b 100644 --- a/lib/jwt/utils.py +++ b/lib/jwt/utils.py @@ -131,26 +131,15 @@ def is_pem_format(key: bytes) -> bool: # Based on https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b/src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46 -_CERT_SUFFIX = b"-cert-v01@openssh.com" -_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)") -_SSH_KEY_FORMATS = [ +_SSH_KEY_FORMATS = ( b"ssh-ed25519", b"ssh-rsa", b"ssh-dss", b"ecdsa-sha2-nistp256", b"ecdsa-sha2-nistp384", b"ecdsa-sha2-nistp521", -] +) def is_ssh_key(key: bytes) -> bool: - if any(string_value in key for string_value in _SSH_KEY_FORMATS): - return True - - ssh_pubkey_match = _SSH_PUBKEY_RC.match(key) - if ssh_pubkey_match: - key_type = ssh_pubkey_match.group(1) - if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]: - return True - - return False + return key.startswith(_SSH_KEY_FORMATS) diff --git a/requirements.txt b/requirements.txt index bbbfe81e..0e7f8cf0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,7 +29,7 @@ platformdirs==4.2.2 plexapi==4.15.15 portend==3.2.0 profilehooks==1.12.0 -PyJWT==2.8.0 +PyJWT==2.9.0 pyparsing==3.1.2 python-dateutil==2.9.0.post0 python-twitter==3.5