From dd9a35df51234dbb19ff7f55d6ff788f4242171a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 19 Nov 2024 10:01:27 -0800 Subject: [PATCH] Bump pyjwt from 2.9.0 to 2.10.0 (#2441) * Bump pyjwt from 2.9.0 to 2.10.0 Bumps [pyjwt](https://github.com/jpadilla/pyjwt) from 2.9.0 to 2.10.0. - [Release notes](https://github.com/jpadilla/pyjwt/releases) - [Changelog](https://github.com/jpadilla/pyjwt/blob/master/CHANGELOG.rst) - [Commits](https://github.com/jpadilla/pyjwt/compare/2.9.0...2.10.0) --- updated-dependencies: - dependency-name: pyjwt dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] * Update pyjwt==2.10.0 --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> [skip ci] --- lib/jwt/__init__.py | 5 ++- lib/jwt/algorithms.py | 47 ++++++++++++++------- lib/jwt/api_jws.py | 27 ++++++++---- lib/jwt/api_jwt.py | 94 +++++++++++++++++++++++++++++++++--------- lib/jwt/exceptions.py | 8 ++++ lib/jwt/help.py | 5 ++- lib/jwt/jwks_client.py | 6 ++- lib/jwt/utils.py | 17 ++++---- requirements.txt | 2 +- 9 files changed, 152 insertions(+), 59 deletions(-) diff --git a/lib/jwt/__init__.py b/lib/jwt/__init__.py index b7a258d7..9d4b6744 100644 --- a/lib/jwt/__init__.py +++ b/lib/jwt/__init__.py @@ -6,7 +6,7 @@ from .api_jws import ( register_algorithm, unregister_algorithm, ) -from .api_jwt import PyJWT, decode, encode +from .api_jwt import PyJWT, decode, decode_complete, encode from .exceptions import ( DecodeError, ExpiredSignatureError, @@ -27,7 +27,7 @@ from .exceptions import ( ) from .jwks_client import PyJWKClient -__version__ = "2.9.0" +__version__ = "2.10.0" __title__ = "PyJWT" __description__ = "JSON Web Token implementation in Python" @@ -49,6 +49,7 @@ __all__ = [ "PyJWK", "PyJWKSet", "decode", + "decode_complete", "encode", "get_unverified_header", "register_algorithm", diff --git a/lib/jwt/algorithms.py b/lib/jwt/algorithms.py index 9be50b20..ccb1500f 100644 --- a/lib/jwt/algorithms.py +++ b/lib/jwt/algorithms.py @@ -297,7 +297,7 @@ class HMACAlgorithm(Algorithm): else: raise ValueError except ValueError: - raise InvalidKeyError("Key is not valid JSON") + raise InvalidKeyError("Key is not valid JSON") from None if obj.get("kty") != "oct": raise InvalidKeyError("Not an HMAC key") @@ -346,7 +346,9 @@ if has_crypto: try: return cast(RSAPublicKey, load_pem_public_key(key_bytes)) except (ValueError, UnsupportedAlgorithm): - raise InvalidKeyError("Could not parse the provided public key.") + raise InvalidKeyError( + "Could not parse the provided public key." + ) from None @overload @staticmethod @@ -409,10 +411,10 @@ if has_crypto: else: raise ValueError except ValueError: - raise InvalidKeyError("Key is not valid JSON") + raise InvalidKeyError("Key is not valid JSON") from None if obj.get("kty") != "RSA": - raise InvalidKeyError("Not an RSA key") + raise InvalidKeyError("Not an RSA key") from None if "d" in obj and "e" in obj and "n" in obj: # Private key @@ -428,7 +430,7 @@ if has_crypto: if any_props_found and not all(props_found): raise InvalidKeyError( "RSA key must include all parameters if any are present besides d" - ) + ) from None public_numbers = RSAPublicNumbers( from_base64url_uint(obj["e"]), @@ -520,7 +522,7 @@ if has_crypto: ): raise InvalidKeyError( "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms" - ) + ) from None return crypto_key @@ -581,13 +583,20 @@ if has_crypto: obj: dict[str, Any] = { "kty": "EC", "crv": crv, - "x": to_base64url_uint(public_numbers.x).decode(), - "y": to_base64url_uint(public_numbers.y).decode(), + "x": to_base64url_uint( + public_numbers.x, + bit_length=key_obj.curve.key_size, + ).decode(), + "y": to_base64url_uint( + public_numbers.y, + bit_length=key_obj.curve.key_size, + ).decode(), } if isinstance(key_obj, EllipticCurvePrivateKey): obj["d"] = to_base64url_uint( - key_obj.private_numbers().private_value + key_obj.private_numbers().private_value, + bit_length=key_obj.curve.key_size, ).decode() if as_dict: @@ -605,13 +614,13 @@ if has_crypto: else: raise ValueError except ValueError: - raise InvalidKeyError("Key is not valid JSON") + raise InvalidKeyError("Key is not valid JSON") from None if obj.get("kty") != "EC": - raise InvalidKeyError("Not an Elliptic curve key") + raise InvalidKeyError("Not an Elliptic curve key") from None if "x" not in obj or "y" not in obj: - raise InvalidKeyError("Not an Elliptic curve key") + raise InvalidKeyError("Not an Elliptic curve key") from None x = base64url_decode(obj.get("x")) y = base64url_decode(obj.get("y")) @@ -623,17 +632,23 @@ if has_crypto: if len(x) == len(y) == 32: curve_obj = SECP256R1() else: - raise InvalidKeyError("Coords should be 32 bytes for curve P-256") + raise InvalidKeyError( + "Coords should be 32 bytes for curve P-256" + ) from None elif curve == "P-384": if len(x) == len(y) == 48: curve_obj = SECP384R1() else: - raise InvalidKeyError("Coords should be 48 bytes for curve P-384") + raise InvalidKeyError( + "Coords should be 48 bytes for curve P-384" + ) from None elif curve == "P-521": if len(x) == len(y) == 66: curve_obj = SECP521R1() else: - raise InvalidKeyError("Coords should be 66 bytes for curve P-521") + raise InvalidKeyError( + "Coords should be 66 bytes for curve P-521" + ) from None elif curve == "secp256k1": if len(x) == len(y) == 32: curve_obj = SECP256K1() @@ -834,7 +849,7 @@ if has_crypto: else: raise ValueError except ValueError: - raise InvalidKeyError("Key is not valid JSON") + raise InvalidKeyError("Key is not valid JSON") from None if obj.get("kty") != "OKP": raise InvalidKeyError("Not an Octet Key Pair") diff --git a/lib/jwt/api_jws.py b/lib/jwt/api_jws.py index 5822ebf6..654ee0b7 100644 --- a/lib/jwt/api_jws.py +++ b/lib/jwt/api_jws.py @@ -3,6 +3,7 @@ from __future__ import annotations import binascii import json import warnings +from collections.abc import Sequence from typing import TYPE_CHECKING, Any from .algorithms import ( @@ -30,7 +31,7 @@ class PyJWS: def __init__( self, - algorithms: list[str] | None = None, + algorithms: Sequence[str] | None = None, options: dict[str, Any] | None = None, ) -> None: self._algorithms = get_default_algorithms() @@ -104,8 +105,8 @@ class PyJWS: def encode( self, payload: bytes, - key: AllowedPrivateKeys | str | bytes, - algorithm: str | None = "HS256", + key: AllowedPrivateKeys | PyJWK | str | bytes, + algorithm: str | None = None, headers: dict[str, Any] | None = None, json_encoder: type[json.JSONEncoder] | None = None, is_payload_detached: bool = False, @@ -114,7 +115,13 @@ class PyJWS: segments = [] # declare a new var to narrow the type for type checkers - algorithm_: str = algorithm if algorithm is not None else "none" + if algorithm is None: + if isinstance(key, PyJWK): + algorithm_ = key.algorithm_name + else: + algorithm_ = "HS256" + else: + algorithm_ = algorithm # Prefer headers values if present to function parameters. if headers: @@ -158,6 +165,8 @@ class PyJWS: signing_input = b".".join(segments) alg_obj = self.get_algorithm_by_name(algorithm_) + if isinstance(key, PyJWK): + key = key.key key = alg_obj.prepare_key(key) signature = alg_obj.sign(signing_input, key) @@ -174,7 +183,7 @@ class PyJWS: self, jwt: str | bytes, key: AllowedPublicKeys | PyJWK | str | bytes = "", - algorithms: list[str] | None = None, + algorithms: Sequence[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, **kwargs, @@ -185,6 +194,7 @@ class PyJWS: "and will be removed in pyjwt version 3. " f"Unsupported kwargs: {tuple(kwargs.keys())}", RemovedInPyjwt3Warning, + stacklevel=2, ) if options is None: options = {} @@ -219,7 +229,7 @@ class PyJWS: self, jwt: str | bytes, key: AllowedPublicKeys | PyJWK | str | bytes = "", - algorithms: list[str] | None = None, + algorithms: Sequence[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, **kwargs, @@ -230,6 +240,7 @@ class PyJWS: "and will be removed in pyjwt version 3. " f"Unsupported kwargs: {tuple(kwargs.keys())}", RemovedInPyjwt3Warning, + stacklevel=2, ) decoded = self.decode_complete( jwt, key, algorithms, options, detached_payload=detached_payload @@ -291,14 +302,14 @@ class PyJWS: header: dict[str, Any], signature: bytes, key: AllowedPublicKeys | PyJWK | str | bytes = "", - algorithms: list[str] | None = None, + algorithms: Sequence[str] | None = None, ) -> None: if algorithms is None and isinstance(key, PyJWK): algorithms = [key.algorithm_name] try: alg = header["alg"] except KeyError: - raise InvalidAlgorithmError("Algorithm not specified") + raise InvalidAlgorithmError("Algorithm not specified") from None if not alg or (algorithms is not None and alg not in algorithms): raise InvalidAlgorithmError("The specified alg value is not allowed") diff --git a/lib/jwt/api_jwt.py b/lib/jwt/api_jwt.py index 7a07c336..fa4d5e6f 100644 --- a/lib/jwt/api_jwt.py +++ b/lib/jwt/api_jwt.py @@ -3,9 +3,9 @@ from __future__ import annotations import json import warnings from calendar import timegm -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, List +from typing import TYPE_CHECKING, Any from . import api_jws from .exceptions import ( @@ -15,6 +15,8 @@ from .exceptions import ( InvalidAudienceError, InvalidIssuedAtError, InvalidIssuerError, + InvalidJTIError, + InvalidSubjectError, MissingRequiredClaimError, ) from .warnings import RemovedInPyjwt3Warning @@ -39,14 +41,16 @@ class PyJWT: "verify_iat": True, "verify_aud": True, "verify_iss": True, + "verify_sub": True, + "verify_jti": True, "require": [], } def encode( self, payload: dict[str, Any], - key: AllowedPrivateKeys | str | bytes, - algorithm: str | None = "HS256", + key: AllowedPrivateKeys | PyJWK | str | bytes, + algorithm: str | None = None, headers: dict[str, Any] | None = None, json_encoder: type[json.JSONEncoder] | None = None, sort_headers: bool = True, @@ -102,7 +106,7 @@ class PyJWT: self, jwt: str | bytes, key: AllowedPublicKeys | PyJWK | str | bytes = "", - algorithms: list[str] | None = None, + algorithms: Sequence[str] | None = None, options: dict[str, Any] | None = None, # deprecated arg, remove in pyjwt3 verify: bool | None = None, @@ -111,7 +115,8 @@ class PyJWT: # passthrough arguments to _validate_claims # consider putting in options audience: str | Iterable[str] | None = None, - issuer: str | List[str] | None = None, + issuer: str | Sequence[str] | None = None, + subject: str | None = None, leeway: float | timedelta = 0, # kwargs **kwargs: Any, @@ -122,6 +127,7 @@ class PyJWT: "and will be removed in pyjwt version 3. " f"Unsupported kwargs: {tuple(kwargs.keys())}", RemovedInPyjwt3Warning, + stacklevel=2, ) options = dict(options or {}) # shallow-copy or initialize an empty dict options.setdefault("verify_signature", True) @@ -135,6 +141,7 @@ class PyJWT: "The equivalent is setting `verify_signature` to False in the `options` dictionary. " "This invocation has a mismatch between the kwarg and the option entry.", category=DeprecationWarning, + stacklevel=2, ) if not options["verify_signature"]: @@ -143,11 +150,8 @@ class PyJWT: options.setdefault("verify_iat", False) options.setdefault("verify_aud", False) options.setdefault("verify_iss", False) - - if options["verify_signature"] and not algorithms: - raise DecodeError( - 'It is required that you pass in a value for the "algorithms" argument when calling decode().' - ) + options.setdefault("verify_sub", False) + options.setdefault("verify_jti", False) decoded = api_jws.decode_complete( jwt, @@ -161,7 +165,12 @@ class PyJWT: merged_options = {**self.options, **options} self._validate_claims( - payload, merged_options, audience=audience, issuer=issuer, leeway=leeway + payload, + merged_options, + audience=audience, + issuer=issuer, + leeway=leeway, + subject=subject, ) decoded["payload"] = payload @@ -178,7 +187,7 @@ class PyJWT: try: payload = json.loads(decoded["payload"]) except ValueError as e: - raise DecodeError(f"Invalid payload string: {e}") + raise DecodeError(f"Invalid payload string: {e}") from e if not isinstance(payload, dict): raise DecodeError("Invalid payload string: must be a json object") return payload @@ -187,7 +196,7 @@ class PyJWT: self, jwt: str | bytes, key: AllowedPublicKeys | PyJWK | str | bytes = "", - algorithms: list[str] | None = None, + algorithms: Sequence[str] | None = None, options: dict[str, Any] | None = None, # deprecated arg, remove in pyjwt3 verify: bool | None = None, @@ -196,7 +205,8 @@ class PyJWT: # passthrough arguments to _validate_claims # consider putting in options audience: str | Iterable[str] | None = None, - issuer: str | List[str] | None = None, + subject: str | None = None, + issuer: str | Sequence[str] | None = None, leeway: float | timedelta = 0, # kwargs **kwargs: Any, @@ -207,6 +217,7 @@ class PyJWT: "and will be removed in pyjwt version 3. " f"Unsupported kwargs: {tuple(kwargs.keys())}", RemovedInPyjwt3Warning, + stacklevel=2, ) decoded = self.decode_complete( jwt, @@ -216,6 +227,7 @@ class PyJWT: verify=verify, detached_payload=detached_payload, audience=audience, + subject=subject, issuer=issuer, leeway=leeway, ) @@ -227,6 +239,7 @@ class PyJWT: options: dict[str, Any], audience=None, issuer=None, + subject: str | None = None, leeway: float | timedelta = 0, ) -> None: if isinstance(leeway, timedelta): @@ -256,6 +269,12 @@ class PyJWT: payload, audience, strict=options.get("strict_aud", False) ) + if options["verify_sub"]: + self._validate_sub(payload, subject) + + if options["verify_jti"]: + self._validate_jti(payload) + def _validate_required_claims( self, payload: dict[str, Any], @@ -265,6 +284,39 @@ class PyJWT: if payload.get(claim) is None: raise MissingRequiredClaimError(claim) + def _validate_sub(self, payload: dict[str, Any], subject=None) -> None: + """ + Checks whether "sub" if in the payload is valid ot not. + This is an Optional claim + + :param payload(dict): The payload which needs to be validated + :param subject(str): The subject of the token + """ + + if "sub" not in payload: + return + + if not isinstance(payload["sub"], str): + raise InvalidSubjectError("Subject must be a string") + + if subject is not None: + if payload.get("sub") != subject: + raise InvalidSubjectError("Invalid subject") + + def _validate_jti(self, payload: dict[str, Any]) -> None: + """ + Checks whether "jti" if in the payload is valid ot not + This is an Optional claim + + :param payload(dict): The payload which needs to be validated + """ + + if "jti" not in payload: + return + + if not isinstance(payload.get("jti"), str): + raise InvalidJTIError("JWT ID must be a string") + def _validate_iat( self, payload: dict[str, Any], @@ -274,7 +326,9 @@ class PyJWT: try: iat = int(payload["iat"]) except ValueError: - raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.") + raise InvalidIssuedAtError( + "Issued At claim (iat) must be an integer." + ) from None if iat > (now + leeway): raise ImmatureSignatureError("The token is not yet valid (iat)") @@ -287,7 +341,7 @@ class PyJWT: try: nbf = int(payload["nbf"]) except ValueError: - raise DecodeError("Not Before claim (nbf) must be an integer.") + raise DecodeError("Not Before claim (nbf) must be an integer.") from None if nbf > (now + leeway): raise ImmatureSignatureError("The token is not yet valid (nbf)") @@ -301,7 +355,9 @@ class PyJWT: try: exp = int(payload["exp"]) except ValueError: - raise DecodeError("Expiration Time claim (exp) must be an integer.") + raise DecodeError( + "Expiration Time claim (exp) must be an integer." + ) from None if exp <= (now - leeway): raise ExpiredSignatureError("Signature has expired") @@ -363,7 +419,7 @@ class PyJWT: if "iss" not in payload: raise MissingRequiredClaimError("iss") - if isinstance(issuer, list): + if isinstance(issuer, Sequence): if payload["iss"] not in issuer: raise InvalidIssuerError("Invalid issuer") else: diff --git a/lib/jwt/exceptions.py b/lib/jwt/exceptions.py index 0d985882..9b45ae48 100644 --- a/lib/jwt/exceptions.py +++ b/lib/jwt/exceptions.py @@ -72,3 +72,11 @@ class PyJWKClientError(PyJWTError): class PyJWKClientConnectionError(PyJWKClientError): pass + + +class InvalidSubjectError(InvalidTokenError): + pass + + +class InvalidJTIError(InvalidTokenError): + pass diff --git a/lib/jwt/help.py b/lib/jwt/help.py index 80b0ca56..8e1c2286 100644 --- a/lib/jwt/help.py +++ b/lib/jwt/help.py @@ -39,7 +39,10 @@ def info() -> Dict[str, Dict[str, str]]: ) if pypy_version_info.releaselevel != "final": implementation_version = "".join( - [implementation_version, pypy_version_info.releaselevel] + [ + implementation_version, + pypy_version_info.releaselevel, + ] ) else: implementation_version = "Unknown" diff --git a/lib/jwt/jwks_client.py b/lib/jwt/jwks_client.py index f19b10ac..9a8992ca 100644 --- a/lib/jwt/jwks_client.py +++ b/lib/jwt/jwks_client.py @@ -45,7 +45,9 @@ class PyJWKClient: if cache_keys: # Cache signing keys # Ignore mypy (https://github.com/python/mypy/issues/2427) - self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore + self.get_signing_key = lru_cache(maxsize=max_cached_keys)( + self.get_signing_key + ) # type: ignore def fetch_data(self) -> Any: jwk_set: Any = None @@ -58,7 +60,7 @@ class PyJWKClient: except (URLError, TimeoutError) as e: raise PyJWKClientConnectionError( f'Fail to fetch data from the url, err: "{e}"' - ) + ) from e else: return jwk_set finally: diff --git a/lib/jwt/utils.py b/lib/jwt/utils.py index d469139b..56e89bb7 100644 --- a/lib/jwt/utils.py +++ b/lib/jwt/utils.py @@ -1,7 +1,7 @@ import base64 import binascii import re -from typing import Union +from typing import Optional, Union try: from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve @@ -37,11 +37,11 @@ def base64url_encode(input: bytes) -> bytes: return base64.urlsafe_b64encode(input).replace(b"=", b"") -def to_base64url_uint(val: int) -> bytes: +def to_base64url_uint(val: int, *, bit_length: Optional[int] = None) -> bytes: if val < 0: raise ValueError("Must be a positive integer") - int_bytes = bytes_from_int(val) + int_bytes = bytes_from_int(val, bit_length=bit_length) if len(int_bytes) == 0: int_bytes = b"\x00" @@ -63,13 +63,10 @@ def bytes_to_number(string: bytes) -> int: return int(binascii.b2a_hex(string), 16) -def bytes_from_int(val: int) -> bytes: - remaining = val - byte_length = 0 - - while remaining != 0: - remaining >>= 8 - byte_length += 1 +def bytes_from_int(val: int, *, bit_length: Optional[int] = None) -> bytes: + if bit_length is None: + bit_length = val.bit_length() + byte_length = (bit_length + 7) // 8 return val.to_bytes(byte_length, "big", signed=False) diff --git a/requirements.txt b/requirements.txt index 9413a3f6..e6a85bdb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,7 +28,7 @@ platformdirs==4.3.6 plexapi==4.16.0 portend==3.2.0 profilehooks==1.13.0 -PyJWT==2.9.0 +PyJWT==2.10.0 pyparsing==3.2.0 python-dateutil==2.9.0.post0 python-twitter==3.5