Bump pyjwt from 2.6.0 to 2.8.0 (#2115)

* Bump pyjwt from 2.6.0 to 2.8.0

Bumps [pyjwt](https://github.com/jpadilla/pyjwt) from 2.6.0 to 2.8.0.
- [Release notes](https://github.com/jpadilla/pyjwt/releases)
- [Changelog](https://github.com/jpadilla/pyjwt/blob/master/CHANGELOG.rst)
- [Commits](https://github.com/jpadilla/pyjwt/compare/2.6.0...2.8.0)

---
updated-dependencies:
- dependency-name: pyjwt
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Update pyjwt==2.8.0

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com>

[skip ci]
This commit is contained in:
dependabot[bot] 2023-08-23 21:45:15 -07:00 committed by GitHub
parent 77f38bbf93
commit c93f470371
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 542 additions and 260 deletions

View file

@ -19,6 +19,7 @@ from .exceptions import (
InvalidSignatureError, InvalidSignatureError,
InvalidTokenError, InvalidTokenError,
MissingRequiredClaimError, MissingRequiredClaimError,
PyJWKClientConnectionError,
PyJWKClientError, PyJWKClientError,
PyJWKError, PyJWKError,
PyJWKSetError, PyJWKSetError,
@ -26,7 +27,7 @@ from .exceptions import (
) )
from .jwks_client import PyJWKClient from .jwks_client import PyJWKClient
__version__ = "2.6.0" __version__ = "2.8.0"
__title__ = "PyJWT" __title__ = "PyJWT"
__description__ = "JSON Web Token implementation in Python" __description__ = "JSON Web Token implementation in Python"
@ -65,6 +66,7 @@ __all__ = [
"InvalidSignatureError", "InvalidSignatureError",
"InvalidTokenError", "InvalidTokenError",
"MissingRequiredClaimError", "MissingRequiredClaimError",
"PyJWKClientConnectionError",
"PyJWKClientError", "PyJWKClientError",
"PyJWKError", "PyJWKError",
"PyJWKSetError", "PyJWKSetError",

View file

@ -1,8 +1,14 @@
from __future__ import annotations
import hashlib import hashlib
import hmac import hmac
import json import json
import sys
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, Union, cast, overload
from .exceptions import InvalidKeyError from .exceptions import InvalidKeyError
from .types import HashlibHash, JWKDict
from .utils import ( from .utils import (
base64url_decode, base64url_decode,
base64url_encode, base64url_encode,
@ -15,14 +21,28 @@ from .utils import (
to_base64url_uint, to_base64url_uint,
) )
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
try: try:
import cryptography.exceptions
from cryptography.exceptions import InvalidSignature from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec, padding from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.ec import ( from cryptography.hazmat.primitives.asymmetric.ec import (
ECDSA,
SECP256K1,
SECP256R1,
SECP384R1,
SECP521R1,
EllipticCurve,
EllipticCurvePrivateKey, EllipticCurvePrivateKey,
EllipticCurvePrivateNumbers,
EllipticCurvePublicKey, EllipticCurvePublicKey,
EllipticCurvePublicNumbers,
) )
from cryptography.hazmat.primitives.asymmetric.ed448 import ( from cryptography.hazmat.primitives.asymmetric.ed448 import (
Ed448PrivateKey, Ed448PrivateKey,
@ -56,6 +76,23 @@ try:
except ModuleNotFoundError: except ModuleNotFoundError:
has_crypto = False has_crypto = False
if TYPE_CHECKING:
# Type aliases for convenience in algorithms method signatures
AllowedRSAKeys = RSAPrivateKey | RSAPublicKey
AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey
AllowedOKPKeys = (
Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey
)
AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
AllowedPrivateKeys = (
RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey
)
AllowedPublicKeys = (
RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey
)
requires_cryptography = { requires_cryptography = {
"RS256", "RS256",
"RS384", "RS384",
@ -72,7 +109,7 @@ requires_cryptography = {
} }
def get_default_algorithms(): def get_default_algorithms() -> dict[str, Algorithm]:
""" """
Returns the algorithms that are implemented by the library. Returns the algorithms that are implemented by the library.
""" """
@ -106,45 +143,79 @@ def get_default_algorithms():
return default_algorithms return default_algorithms
class Algorithm: class Algorithm(ABC):
""" """
The interface for an algorithm used to sign and verify tokens. The interface for an algorithm used to sign and verify tokens.
""" """
def prepare_key(self, key): def compute_hash_digest(self, bytestr: bytes) -> bytes:
"""
Compute a hash digest using the specified algorithm's hash algorithm.
If there is no hash algorithm, raises a NotImplementedError.
"""
# lookup self.hash_alg if defined in a way that mypy can understand
hash_alg = getattr(self, "hash_alg", None)
if hash_alg is None:
raise NotImplementedError
if (
has_crypto
and isinstance(hash_alg, type)
and issubclass(hash_alg, hashes.HashAlgorithm)
):
digest = hashes.Hash(hash_alg(), backend=default_backend())
digest.update(bytestr)
return bytes(digest.finalize())
else:
return bytes(hash_alg(bytestr).digest())
@abstractmethod
def prepare_key(self, key: Any) -> Any:
""" """
Performs necessary validation and conversions on the key and returns Performs necessary validation and conversions on the key and returns
the key value in the proper format for sign() and verify(). the key value in the proper format for sign() and verify().
""" """
raise NotImplementedError
def sign(self, msg, key): @abstractmethod
def sign(self, msg: bytes, key: Any) -> bytes:
""" """
Returns a digital signature for the specified message Returns a digital signature for the specified message
using the specified key value. using the specified key value.
""" """
raise NotImplementedError
def verify(self, msg, key, sig): @abstractmethod
def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
""" """
Verifies that the specified digital signature is valid Verifies that the specified digital signature is valid
for the specified message and key values. for the specified message and key values.
""" """
raise NotImplementedError
@overload
@staticmethod
@abstractmethod
def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover
@overload
@staticmethod
@abstractmethod
def to_jwk(key_obj, as_dict: Literal[False] = False) -> str:
... # pragma: no cover
@staticmethod @staticmethod
def to_jwk(key_obj): @abstractmethod
def to_jwk(key_obj, as_dict: bool = False) -> Union[JWKDict, str]:
""" """
Serializes a given RSA key into a JWK Serializes a given key into a JWK
""" """
raise NotImplementedError
@staticmethod @staticmethod
def from_jwk(jwk): @abstractmethod
def from_jwk(jwk: str | JWKDict) -> Any:
""" """
Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object Deserializes a given key from JWK back into a key object
""" """
raise NotImplementedError
class NoneAlgorithm(Algorithm): class NoneAlgorithm(Algorithm):
@ -153,7 +224,7 @@ class NoneAlgorithm(Algorithm):
operations are required. operations are required.
""" """
def prepare_key(self, key): def prepare_key(self, key: str | None) -> None:
if key == "": if key == "":
key = None key = None
@ -162,12 +233,20 @@ class NoneAlgorithm(Algorithm):
return key return key
def sign(self, msg, key): def sign(self, msg: bytes, key: None) -> bytes:
return b"" return b""
def verify(self, msg, key, sig): def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
return False return False
@staticmethod
def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
raise NotImplementedError()
@staticmethod
def from_jwk(jwk: str | JWKDict) -> NoReturn:
raise NotImplementedError()
class HMACAlgorithm(Algorithm): class HMACAlgorithm(Algorithm):
""" """
@ -175,38 +254,51 @@ class HMACAlgorithm(Algorithm):
and the specified hash function. and the specified hash function.
""" """
SHA256 = hashlib.sha256 SHA256: ClassVar[HashlibHash] = hashlib.sha256
SHA384 = hashlib.sha384 SHA384: ClassVar[HashlibHash] = hashlib.sha384
SHA512 = hashlib.sha512 SHA512: ClassVar[HashlibHash] = hashlib.sha512
def __init__(self, hash_alg): def __init__(self, hash_alg: HashlibHash) -> None:
self.hash_alg = hash_alg self.hash_alg = hash_alg
def prepare_key(self, key): def prepare_key(self, key: str | bytes) -> bytes:
key = force_bytes(key) key_bytes = force_bytes(key)
if is_pem_format(key) or is_ssh_key(key): if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
raise InvalidKeyError( raise InvalidKeyError(
"The specified key is an asymmetric key or x509 certificate and" "The specified key is an asymmetric key or x509 certificate and"
" should not be used as an HMAC secret." " should not be used as an HMAC secret."
) )
return key return key_bytes
@overload
@staticmethod
def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover
@overload
@staticmethod
def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str:
... # pragma: no cover
@staticmethod @staticmethod
def to_jwk(key_obj): def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> Union[JWKDict, str]:
return json.dumps( jwk = {
{ "k": base64url_encode(force_bytes(key_obj)).decode(),
"k": base64url_encode(force_bytes(key_obj)).decode(), "kty": "oct",
"kty": "oct", }
}
) if as_dict:
return jwk
else:
return json.dumps(jwk)
@staticmethod @staticmethod
def from_jwk(jwk): def from_jwk(jwk: str | JWKDict) -> bytes:
try: try:
if isinstance(jwk, str): if isinstance(jwk, str):
obj = json.loads(jwk) obj: JWKDict = json.loads(jwk)
elif isinstance(jwk, dict): elif isinstance(jwk, dict):
obj = jwk obj = jwk
else: else:
@ -219,10 +311,10 @@ class HMACAlgorithm(Algorithm):
return base64url_decode(obj["k"]) return base64url_decode(obj["k"])
def sign(self, msg, key): def sign(self, msg: bytes, key: bytes) -> bytes:
return hmac.new(key, msg, self.hash_alg).digest() return hmac.new(key, msg, self.hash_alg).digest()
def verify(self, msg, key, sig): def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
return hmac.compare_digest(sig, self.sign(msg, key)) return hmac.compare_digest(sig, self.sign(msg, key))
@ -234,36 +326,49 @@ if has_crypto:
RSASSA-PKCS-v1_5 and the specified hash function. RSASSA-PKCS-v1_5 and the specified hash function.
""" """
SHA256 = hashes.SHA256 SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
SHA384 = hashes.SHA384 SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
SHA512 = hashes.SHA512 SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
def __init__(self, hash_alg): def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg self.hash_alg = hash_alg
def prepare_key(self, key): def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
if isinstance(key, (RSAPrivateKey, RSAPublicKey)): if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
return key return key
if not isinstance(key, (bytes, str)): if not isinstance(key, (bytes, str)):
raise TypeError("Expecting a PEM-formatted key.") raise TypeError("Expecting a PEM-formatted key.")
key = force_bytes(key) key_bytes = force_bytes(key)
try: try:
if key.startswith(b"ssh-rsa"): if key_bytes.startswith(b"ssh-rsa"):
key = load_ssh_public_key(key) return cast(RSAPublicKey, load_ssh_public_key(key_bytes))
else: else:
key = load_pem_private_key(key, password=None) return cast(
RSAPrivateKey, load_pem_private_key(key_bytes, password=None)
)
except ValueError: except ValueError:
key = load_pem_public_key(key) return cast(RSAPublicKey, load_pem_public_key(key_bytes))
return key
@overload
@staticmethod
def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover
@overload
@staticmethod
def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str:
... # pragma: no cover
@staticmethod @staticmethod
def to_jwk(key_obj): def to_jwk(
obj = None key_obj: AllowedRSAKeys, as_dict: bool = False
) -> Union[JWKDict, str]:
obj: dict[str, Any] | None = None
if getattr(key_obj, "private_numbers", None): if hasattr(key_obj, "private_numbers"):
# Private key # Private key
numbers = key_obj.private_numbers() numbers = key_obj.private_numbers()
@ -280,7 +385,7 @@ if has_crypto:
"qi": to_base64url_uint(numbers.iqmp).decode(), "qi": to_base64url_uint(numbers.iqmp).decode(),
} }
elif getattr(key_obj, "verify", None): elif hasattr(key_obj, "verify"):
# Public key # Public key
numbers = key_obj.public_numbers() numbers = key_obj.public_numbers()
@ -293,10 +398,13 @@ if has_crypto:
else: else:
raise InvalidKeyError("Not a public or private key") raise InvalidKeyError("Not a public or private key")
return json.dumps(obj) if as_dict:
return obj
else:
return json.dumps(obj)
@staticmethod @staticmethod
def from_jwk(jwk): def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
try: try:
if isinstance(jwk, str): if isinstance(jwk, str):
obj = json.loads(jwk) obj = json.loads(jwk)
@ -360,19 +468,17 @@ if has_crypto:
return numbers.private_key() return numbers.private_key()
elif "n" in obj and "e" in obj: elif "n" in obj and "e" in obj:
# Public key # Public key
numbers = RSAPublicNumbers( return RSAPublicNumbers(
from_base64url_uint(obj["e"]), from_base64url_uint(obj["e"]),
from_base64url_uint(obj["n"]), from_base64url_uint(obj["n"]),
) ).public_key()
return numbers.public_key()
else: else:
raise InvalidKeyError("Not a public or private key") raise InvalidKeyError("Not a public or private key")
def sign(self, msg, key): def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
return key.sign(msg, padding.PKCS1v15(), self.hash_alg()) return key.sign(msg, padding.PKCS1v15(), self.hash_alg())
def verify(self, msg, key, sig): def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
try: try:
key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg()) key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
return True return True
@ -385,63 +491,79 @@ if has_crypto:
ECDSA and the specified hash function ECDSA and the specified hash function
""" """
SHA256 = hashes.SHA256 SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
SHA384 = hashes.SHA384 SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
SHA512 = hashes.SHA512 SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
def __init__(self, hash_alg): def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg self.hash_alg = hash_alg
def prepare_key(self, key): def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
return key return key
if not isinstance(key, (bytes, str)): if not isinstance(key, (bytes, str)):
raise TypeError("Expecting a PEM-formatted key.") raise TypeError("Expecting a PEM-formatted key.")
key = force_bytes(key) key_bytes = force_bytes(key)
# Attempt to load key. We don't know if it's # Attempt to load key. We don't know if it's
# a Signing Key or a Verifying Key, so we try # a Signing Key or a Verifying Key, so we try
# the Verifying Key first. # the Verifying Key first.
try: try:
if key.startswith(b"ecdsa-sha2-"): if key_bytes.startswith(b"ecdsa-sha2-"):
key = load_ssh_public_key(key) crypto_key = load_ssh_public_key(key_bytes)
else: else:
key = load_pem_public_key(key) crypto_key = load_pem_public_key(key_bytes) # type: ignore[assignment]
except ValueError: except ValueError:
key = load_pem_private_key(key, password=None) crypto_key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
# Explicit check the key to prevent confusing errors from cryptography # Explicit check the key to prevent confusing errors from cryptography
if not isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): if not isinstance(
crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)
):
raise InvalidKeyError( raise InvalidKeyError(
"Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms" "Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
) )
return key return crypto_key
def sign(self, msg, key): def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
der_sig = key.sign(msg, ec.ECDSA(self.hash_alg())) der_sig = key.sign(msg, ECDSA(self.hash_alg()))
return der_to_raw_signature(der_sig, key.curve) return der_to_raw_signature(der_sig, key.curve)
def verify(self, msg, key, sig): def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool:
try: try:
der_sig = raw_to_der_signature(sig, key.curve) der_sig = raw_to_der_signature(sig, key.curve)
except ValueError: except ValueError:
return False return False
try: try:
if isinstance(key, EllipticCurvePrivateKey): public_key = (
key = key.public_key() key.public_key()
key.verify(der_sig, msg, ec.ECDSA(self.hash_alg())) if isinstance(key, EllipticCurvePrivateKey)
else key
)
public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
return True return True
except InvalidSignature: except InvalidSignature:
return False return False
@overload
@staticmethod @staticmethod
def to_jwk(key_obj): def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover
@overload
@staticmethod
def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str:
... # pragma: no cover
@staticmethod
def to_jwk(
key_obj: AllowedECKeys, as_dict: bool = False
) -> Union[JWKDict, str]:
if isinstance(key_obj, EllipticCurvePrivateKey): if isinstance(key_obj, EllipticCurvePrivateKey):
public_numbers = key_obj.public_key().public_numbers() public_numbers = key_obj.public_key().public_numbers()
elif isinstance(key_obj, EllipticCurvePublicKey): elif isinstance(key_obj, EllipticCurvePublicKey):
@ -449,18 +571,18 @@ if has_crypto:
else: else:
raise InvalidKeyError("Not a public or private key") raise InvalidKeyError("Not a public or private key")
if isinstance(key_obj.curve, ec.SECP256R1): if isinstance(key_obj.curve, SECP256R1):
crv = "P-256" crv = "P-256"
elif isinstance(key_obj.curve, ec.SECP384R1): elif isinstance(key_obj.curve, SECP384R1):
crv = "P-384" crv = "P-384"
elif isinstance(key_obj.curve, ec.SECP521R1): elif isinstance(key_obj.curve, SECP521R1):
crv = "P-521" crv = "P-521"
elif isinstance(key_obj.curve, ec.SECP256K1): elif isinstance(key_obj.curve, SECP256K1):
crv = "secp256k1" crv = "secp256k1"
else: else:
raise InvalidKeyError(f"Invalid curve: {key_obj.curve}") raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
obj = { obj: dict[str, Any] = {
"kty": "EC", "kty": "EC",
"crv": crv, "crv": crv,
"x": to_base64url_uint(public_numbers.x).decode(), "x": to_base64url_uint(public_numbers.x).decode(),
@ -472,10 +594,13 @@ if has_crypto:
key_obj.private_numbers().private_value key_obj.private_numbers().private_value
).decode() ).decode()
return json.dumps(obj) if as_dict:
return obj
else:
return json.dumps(obj)
@staticmethod @staticmethod
def from_jwk(jwk): def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
try: try:
if isinstance(jwk, str): if isinstance(jwk, str):
obj = json.loads(jwk) obj = json.loads(jwk)
@ -496,24 +621,26 @@ if has_crypto:
y = base64url_decode(obj.get("y")) y = base64url_decode(obj.get("y"))
curve = obj.get("crv") curve = obj.get("crv")
curve_obj: EllipticCurve
if curve == "P-256": if curve == "P-256":
if len(x) == len(y) == 32: if len(x) == len(y) == 32:
curve_obj = ec.SECP256R1() curve_obj = SECP256R1()
else: else:
raise InvalidKeyError("Coords should be 32 bytes for curve P-256") raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
elif curve == "P-384": elif curve == "P-384":
if len(x) == len(y) == 48: if len(x) == len(y) == 48:
curve_obj = ec.SECP384R1() curve_obj = SECP384R1()
else: else:
raise InvalidKeyError("Coords should be 48 bytes for curve P-384") raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
elif curve == "P-521": elif curve == "P-521":
if len(x) == len(y) == 66: if len(x) == len(y) == 66:
curve_obj = ec.SECP521R1() curve_obj = SECP521R1()
else: else:
raise InvalidKeyError("Coords should be 66 bytes for curve P-521") raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
elif curve == "secp256k1": elif curve == "secp256k1":
if len(x) == len(y) == 32: if len(x) == len(y) == 32:
curve_obj = ec.SECP256K1() curve_obj = SECP256K1()
else: else:
raise InvalidKeyError( raise InvalidKeyError(
"Coords should be 32 bytes for curve secp256k1" "Coords should be 32 bytes for curve secp256k1"
@ -521,7 +648,7 @@ if has_crypto:
else: else:
raise InvalidKeyError(f"Invalid curve: {curve}") raise InvalidKeyError(f"Invalid curve: {curve}")
public_numbers = ec.EllipticCurvePublicNumbers( public_numbers = EllipticCurvePublicNumbers(
x=int.from_bytes(x, byteorder="big"), x=int.from_bytes(x, byteorder="big"),
y=int.from_bytes(y, byteorder="big"), y=int.from_bytes(y, byteorder="big"),
curve=curve_obj, curve=curve_obj,
@ -536,7 +663,7 @@ if has_crypto:
"D should be {} bytes for curve {}", len(x), curve "D should be {} bytes for curve {}", len(x), curve
) )
return ec.EllipticCurvePrivateNumbers( return EllipticCurvePrivateNumbers(
int.from_bytes(d, byteorder="big"), public_numbers int.from_bytes(d, byteorder="big"), public_numbers
).private_key() ).private_key()
@ -545,24 +672,24 @@ if has_crypto:
Performs a signature using RSASSA-PSS with MGF1 Performs a signature using RSASSA-PSS with MGF1
""" """
def sign(self, msg, key): def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
return key.sign( return key.sign(
msg, msg,
padding.PSS( padding.PSS(
mgf=padding.MGF1(self.hash_alg()), mgf=padding.MGF1(self.hash_alg()),
salt_length=self.hash_alg.digest_size, salt_length=self.hash_alg().digest_size,
), ),
self.hash_alg(), self.hash_alg(),
) )
def verify(self, msg, key, sig): def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
try: try:
key.verify( key.verify(
sig, sig,
msg, msg,
padding.PSS( padding.PSS(
mgf=padding.MGF1(self.hash_alg()), mgf=padding.MGF1(self.hash_alg()),
salt_length=self.hash_alg.digest_size, salt_length=self.hash_alg().digest_size,
), ),
self.hash_alg(), self.hash_alg(),
) )
@ -577,21 +704,20 @@ if has_crypto:
This class requires ``cryptography>=2.6`` to be installed. This class requires ``cryptography>=2.6`` to be installed.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs: Any) -> None:
pass pass
def prepare_key(self, key): def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
if isinstance(key, (bytes, str)): if isinstance(key, (bytes, str)):
if isinstance(key, str): key_str = key.decode("utf-8") if isinstance(key, bytes) else key
key = key.encode("utf-8") key_bytes = key.encode("utf-8") if isinstance(key, str) else key
str_key = key.decode("utf-8")
if "-----BEGIN PUBLIC" in str_key: if "-----BEGIN PUBLIC" in key_str:
key = load_pem_public_key(key) key = load_pem_public_key(key_bytes) # type: ignore[assignment]
elif "-----BEGIN PRIVATE" in str_key: elif "-----BEGIN PRIVATE" in key_str:
key = load_pem_private_key(key, password=None) key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
elif str_key[0:4] == "ssh-": elif key_str[0:4] == "ssh-":
key = load_ssh_public_key(key) key = load_ssh_public_key(key_bytes) # type: ignore[assignment]
# Explicit check the key to prevent confusing errors from cryptography # Explicit check the key to prevent confusing errors from cryptography
if not isinstance( if not isinstance(
@ -604,7 +730,9 @@ if has_crypto:
return key return key
def sign(self, msg, key): def sign(
self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
) -> bytes:
""" """
Sign a message ``msg`` using the EdDSA private key ``key`` Sign a message ``msg`` using the EdDSA private key ``key``
:param str|bytes msg: Message to sign :param str|bytes msg: Message to sign
@ -612,10 +740,12 @@ if has_crypto:
or :class:`.Ed448PrivateKey` isinstance or :class:`.Ed448PrivateKey` isinstance
:return bytes signature: The signature, as bytes :return bytes signature: The signature, as bytes
""" """
msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
return key.sign(msg) return key.sign(msg_bytes)
def verify(self, msg, key, sig): def verify(
self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
) -> bool:
""" """
Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key`` Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``
@ -626,31 +756,48 @@ if has_crypto:
:return bool verified: True if signature is valid, False if not. :return bool verified: True if signature is valid, False if not.
""" """
try: try:
msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
sig = bytes(sig, "utf-8") if type(sig) is not bytes else sig sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig
if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)): public_key = (
key = key.public_key() key.public_key()
key.verify(sig, msg) if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
else key
)
public_key.verify(sig_bytes, msg_bytes)
return True # If no exception was raised, the signature is valid. return True # If no exception was raised, the signature is valid.
except cryptography.exceptions.InvalidSignature: except InvalidSignature:
return False return False
@overload
@staticmethod @staticmethod
def to_jwk(key): def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover
@overload
@staticmethod
def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str:
... # pragma: no cover
@staticmethod
def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> Union[JWKDict, str]:
if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)): if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
x = key.public_bytes( x = key.public_bytes(
encoding=Encoding.Raw, encoding=Encoding.Raw,
format=PublicFormat.Raw, format=PublicFormat.Raw,
) )
crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448" crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
return json.dumps(
{ obj = {
"x": base64url_encode(force_bytes(x)).decode(), "x": base64url_encode(force_bytes(x)).decode(),
"kty": "OKP", "kty": "OKP",
"crv": crv, "crv": crv,
} }
)
if as_dict:
return obj
else:
return json.dumps(obj)
if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)): if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
d = key.private_bytes( d = key.private_bytes(
@ -665,19 +812,22 @@ if has_crypto:
) )
crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448" crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
return json.dumps( obj = {
{ "x": base64url_encode(force_bytes(x)).decode(),
"x": base64url_encode(force_bytes(x)).decode(), "d": base64url_encode(force_bytes(d)).decode(),
"d": base64url_encode(force_bytes(d)).decode(), "kty": "OKP",
"kty": "OKP", "crv": crv,
"crv": crv, }
}
) if as_dict:
return obj
else:
return json.dumps(obj)
raise InvalidKeyError("Not a public or private key") raise InvalidKeyError("Not a public or private key")
@staticmethod @staticmethod
def from_jwk(jwk): def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
try: try:
if isinstance(jwk, str): if isinstance(jwk, str):
obj = json.loads(jwk) obj = json.loads(jwk)

View file

@ -2,13 +2,15 @@ from __future__ import annotations
import json import json
import time import time
from typing import Any
from .algorithms import get_default_algorithms from .algorithms import get_default_algorithms, has_crypto, requires_cryptography
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError, PyJWTError
from .types import JWKDict
class PyJWK: class PyJWK:
def __init__(self, jwk_data, algorithm=None): def __init__(self, jwk_data: JWKDict, algorithm: str | None = None) -> None:
self._algorithms = get_default_algorithms() self._algorithms = get_default_algorithms()
self._jwk_data = jwk_data self._jwk_data = jwk_data
@ -47,37 +49,40 @@ class PyJWK:
else: else:
raise InvalidKeyError(f"Unsupported kty: {kty}") raise InvalidKeyError(f"Unsupported kty: {kty}")
if not has_crypto and algorithm in requires_cryptography:
raise PyJWKError(f"{algorithm} requires 'cryptography' to be installed.")
self.Algorithm = self._algorithms.get(algorithm) self.Algorithm = self._algorithms.get(algorithm)
if not self.Algorithm: if not self.Algorithm:
raise PyJWKError(f"Unable to find a algorithm for key: {self._jwk_data}") raise PyJWKError(f"Unable to find an algorithm for key: {self._jwk_data}")
self.key = self.Algorithm.from_jwk(self._jwk_data) self.key = self.Algorithm.from_jwk(self._jwk_data)
@staticmethod @staticmethod
def from_dict(obj, algorithm=None): def from_dict(obj: JWKDict, algorithm: str | None = None) -> "PyJWK":
return PyJWK(obj, algorithm) return PyJWK(obj, algorithm)
@staticmethod @staticmethod
def from_json(data, algorithm=None): def from_json(data: str, algorithm: None = None) -> "PyJWK":
obj = json.loads(data) obj = json.loads(data)
return PyJWK.from_dict(obj, algorithm) return PyJWK.from_dict(obj, algorithm)
@property @property
def key_type(self): def key_type(self) -> str | None:
return self._jwk_data.get("kty", None) return self._jwk_data.get("kty", None)
@property @property
def key_id(self): def key_id(self) -> str | None:
return self._jwk_data.get("kid", None) return self._jwk_data.get("kid", None)
@property @property
def public_key_use(self): def public_key_use(self) -> str | None:
return self._jwk_data.get("use", None) return self._jwk_data.get("use", None)
class PyJWKSet: class PyJWKSet:
def __init__(self, keys: list[dict]) -> None: def __init__(self, keys: list[JWKDict]) -> None:
self.keys = [] self.keys = []
if not keys: if not keys:
@ -89,24 +94,26 @@ class PyJWKSet:
for key in keys: for key in keys:
try: try:
self.keys.append(PyJWK(key)) self.keys.append(PyJWK(key))
except PyJWKError: except PyJWTError:
# skip unusable keys # skip unusable keys
continue continue
if len(self.keys) == 0: if len(self.keys) == 0:
raise PyJWKSetError("The JWK Set did not contain any usable keys") raise PyJWKSetError(
"The JWK Set did not contain any usable keys. Perhaps 'cryptography' is not installed?"
)
@staticmethod @staticmethod
def from_dict(obj): def from_dict(obj: dict[str, Any]) -> "PyJWKSet":
keys = obj.get("keys", []) keys = obj.get("keys", [])
return PyJWKSet(keys) return PyJWKSet(keys)
@staticmethod @staticmethod
def from_json(data): def from_json(data: str) -> "PyJWKSet":
obj = json.loads(data) obj = json.loads(data)
return PyJWKSet.from_dict(obj) return PyJWKSet.from_dict(obj)
def __getitem__(self, kid): def __getitem__(self, kid: str) -> "PyJWK":
for key in self.keys: for key in self.keys:
if key.key_id == kid: if key.key_id == kid:
return key return key
@ -118,8 +125,8 @@ class PyJWTSetWithTimestamp:
self.jwk_set = jwk_set self.jwk_set = jwk_set
self.timestamp = time.monotonic() self.timestamp = time.monotonic()
def get_jwk_set(self): def get_jwk_set(self) -> PyJWKSet:
return self.jwk_set return self.jwk_set
def get_timestamp(self): def get_timestamp(self) -> float:
return self.timestamp return self.timestamp

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import binascii import binascii
import json import json
import warnings import warnings
from typing import Any, Type from typing import TYPE_CHECKING, Any
from .algorithms import ( from .algorithms import (
Algorithm, Algorithm,
@ -20,11 +20,18 @@ from .exceptions import (
from .utils import base64url_decode, base64url_encode from .utils import base64url_decode, base64url_encode
from .warnings import RemovedInPyjwt3Warning from .warnings import RemovedInPyjwt3Warning
if TYPE_CHECKING:
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
class PyJWS: class PyJWS:
header_typ = "JWT" header_typ = "JWT"
def __init__(self, algorithms=None, options=None) -> None: def __init__(
self,
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
) -> None:
self._algorithms = get_default_algorithms() self._algorithms = get_default_algorithms()
self._valid_algs = ( self._valid_algs = (
set(algorithms) if algorithms is not None else set(self._algorithms) set(algorithms) if algorithms is not None else set(self._algorithms)
@ -96,11 +103,12 @@ class PyJWS:
def encode( def encode(
self, self,
payload: bytes, payload: bytes,
key: str, key: AllowedPrivateKeys | str | bytes,
algorithm: str | None = "HS256", algorithm: str | None = "HS256",
headers: dict[str, Any] | None = None, headers: dict[str, Any] | None = None,
json_encoder: Type[json.JSONEncoder] | None = None, json_encoder: type[json.JSONEncoder] | None = None,
is_payload_detached: bool = False, is_payload_detached: bool = False,
sort_headers: bool = True,
) -> str: ) -> str:
segments = [] segments = []
@ -133,9 +141,8 @@ class PyJWS:
# True is the standard value for b64, so no need for it # True is the standard value for b64, so no need for it
del header["b64"] del header["b64"]
# Fix for headers misorder - issue #715
json_header = json.dumps( json_header = json.dumps(
header, separators=(",", ":"), cls=json_encoder, sort_keys=True header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
).encode() ).encode()
segments.append(base64url_encode(json_header)) segments.append(base64url_encode(json_header))
@ -164,8 +171,8 @@ class PyJWS:
def decode_complete( def decode_complete(
self, self,
jwt: str, jwt: str | bytes,
key: str = "", key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None, algorithms: list[str] | None = None,
options: dict[str, Any] | None = None, options: dict[str, Any] | None = None,
detached_payload: bytes | None = None, detached_payload: bytes | None = None,
@ -209,13 +216,13 @@ class PyJWS:
def decode( def decode(
self, self,
jwt: str, jwt: str | bytes,
key: str = "", key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None, algorithms: list[str] | None = None,
options: dict[str, Any] | None = None, options: dict[str, Any] | None = None,
detached_payload: bytes | None = None, detached_payload: bytes | None = None,
**kwargs, **kwargs,
) -> str: ) -> Any:
if kwargs: if kwargs:
warnings.warn( warnings.warn(
"passing additional kwargs to decode() is deprecated " "passing additional kwargs to decode() is deprecated "
@ -228,7 +235,7 @@ class PyJWS:
) )
return decoded["payload"] return decoded["payload"]
def get_unverified_header(self, jwt: str | bytes) -> dict: def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
"""Returns back the JWT header parameters as a dict() """Returns back the JWT header parameters as a dict()
Note: The signature is not verified so the header parameters Note: The signature is not verified so the header parameters
@ -239,7 +246,7 @@ class PyJWS:
return headers return headers
def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]: def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]:
if isinstance(jwt, str): if isinstance(jwt, str):
jwt = jwt.encode("utf-8") jwt = jwt.encode("utf-8")
@ -280,13 +287,15 @@ class PyJWS:
def _verify_signature( def _verify_signature(
self, self,
signing_input: bytes, signing_input: bytes,
header: dict, header: dict[str, Any],
signature: bytes, signature: bytes,
key: str = "", key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None, algorithms: list[str] | None = None,
) -> None: ) -> None:
try:
alg = header.get("alg") alg = header["alg"]
except KeyError:
raise InvalidAlgorithmError("Algorithm not specified")
if not alg or (algorithms is not None and alg not in algorithms): if not alg or (algorithms is not None and alg not in algorithms):
raise InvalidAlgorithmError("The specified alg value is not allowed") raise InvalidAlgorithmError("The specified alg value is not allowed")
@ -295,16 +304,16 @@ class PyJWS:
alg_obj = self.get_algorithm_by_name(alg) alg_obj = self.get_algorithm_by_name(alg)
except NotImplementedError as e: except NotImplementedError as e:
raise InvalidAlgorithmError("Algorithm not supported") from e raise InvalidAlgorithmError("Algorithm not supported") from e
key = alg_obj.prepare_key(key) prepared_key = alg_obj.prepare_key(key)
if not alg_obj.verify(signing_input, key, signature): if not alg_obj.verify(signing_input, prepared_key, signature):
raise InvalidSignatureError("Signature verification failed") raise InvalidSignatureError("Signature verification failed")
def _validate_headers(self, headers: dict[str, Any]) -> None: def _validate_headers(self, headers: dict[str, Any]) -> None:
if "kid" in headers: if "kid" in headers:
self._validate_kid(headers["kid"]) self._validate_kid(headers["kid"])
def _validate_kid(self, kid: str) -> None: def _validate_kid(self, kid: Any) -> None:
if not isinstance(kid, str): if not isinstance(kid, str):
raise InvalidTokenError("Key ID header parameter must be a string") raise InvalidTokenError("Key ID header parameter must be a string")

View file

@ -3,9 +3,9 @@ from __future__ import annotations
import json import json
import warnings import warnings
from calendar import timegm from calendar import timegm
from collections.abc import Iterable, Mapping from collections.abc import Iterable
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Type, Union from typing import TYPE_CHECKING, Any
from . import api_jws from . import api_jws
from .exceptions import ( from .exceptions import (
@ -19,15 +19,18 @@ from .exceptions import (
) )
from .warnings import RemovedInPyjwt3Warning from .warnings import RemovedInPyjwt3Warning
if TYPE_CHECKING:
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
class PyJWT: class PyJWT:
def __init__(self, options=None): def __init__(self, options: dict[str, Any] | None = None) -> None:
if options is None: if options is None:
options = {} options = {}
self.options = {**self._get_default_options(), **options} self.options: dict[str, Any] = {**self._get_default_options(), **options}
@staticmethod @staticmethod
def _get_default_options() -> Dict[str, Union[bool, List[str]]]: def _get_default_options() -> dict[str, bool | list[str]]:
return { return {
"verify_signature": True, "verify_signature": True,
"verify_exp": True, "verify_exp": True,
@ -40,16 +43,17 @@ class PyJWT:
def encode( def encode(
self, self,
payload: Dict[str, Any], payload: dict[str, Any],
key: str, key: AllowedPrivateKeys | str | bytes,
algorithm: Optional[str] = "HS256", algorithm: str | None = "HS256",
headers: Optional[Dict[str, Any]] = None, headers: dict[str, Any] | None = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None, json_encoder: type[json.JSONEncoder] | None = None,
sort_headers: bool = True,
) -> str: ) -> str:
# Check that we get a mapping # Check that we get a dict
if not isinstance(payload, Mapping): if not isinstance(payload, dict):
raise TypeError( raise TypeError(
"Expecting a mapping object, as JWT only supports " "Expecting a dict object, as JWT only supports "
"JSON objects as payloads." "JSON objects as payloads."
) )
@ -60,30 +64,57 @@ class PyJWT:
if isinstance(payload.get(time_claim), datetime): if isinstance(payload.get(time_claim), datetime):
payload[time_claim] = timegm(payload[time_claim].utctimetuple()) payload[time_claim] = timegm(payload[time_claim].utctimetuple())
json_payload = json.dumps( json_payload = self._encode_payload(
payload, separators=(",", ":"), cls=json_encoder payload,
).encode("utf-8") headers=headers,
json_encoder=json_encoder,
)
return api_jws.encode(json_payload, key, algorithm, headers, json_encoder) return api_jws.encode(
json_payload,
key,
algorithm,
headers,
json_encoder,
sort_headers=sort_headers,
)
def _encode_payload(
self,
payload: dict[str, Any],
headers: dict[str, Any] | None = None,
json_encoder: type[json.JSONEncoder] | None = None,
) -> bytes:
"""
Encode a given payload to the bytes to be signed.
This method is intended to be overridden by subclasses that need to
encode the payload in a different way, e.g. compress the payload.
"""
return json.dumps(
payload,
separators=(",", ":"),
cls=json_encoder,
).encode("utf-8")
def decode_complete( def decode_complete(
self, self,
jwt: str, jwt: str | bytes,
key: str = "", key: AllowedPublicKeys | str | bytes = "",
algorithms: Optional[List[str]] = None, algorithms: list[str] | None = None,
options: Optional[Dict[str, Any]] = None, options: dict[str, Any] | None = None,
# deprecated arg, remove in pyjwt3 # deprecated arg, remove in pyjwt3
verify: Optional[bool] = None, verify: bool | None = None,
# could be used as passthrough to api_jws, consider removal in pyjwt3 # could be used as passthrough to api_jws, consider removal in pyjwt3
detached_payload: Optional[bytes] = None, detached_payload: bytes | None = None,
# passthrough arguments to _validate_claims # passthrough arguments to _validate_claims
# consider putting in options # consider putting in options
audience: Optional[Union[str, Iterable[str]]] = None, audience: str | Iterable[str] | None = None,
issuer: Optional[str] = None, issuer: str | None = None,
leeway: Union[int, float, timedelta] = 0, leeway: float | timedelta = 0,
# kwargs # kwargs
**kwargs, **kwargs: Any,
) -> Dict[str, Any]: ) -> dict[str, Any]:
if kwargs: if kwargs:
warnings.warn( warnings.warn(
"passing additional kwargs to decode_complete() is deprecated " "passing additional kwargs to decode_complete() is deprecated "
@ -125,12 +156,7 @@ class PyJWT:
detached_payload=detached_payload, detached_payload=detached_payload,
) )
try: payload = self._decode_payload(decoded)
payload = json.loads(decoded["payload"])
except ValueError as e:
raise DecodeError(f"Invalid payload string: {e}")
if not isinstance(payload, dict):
raise DecodeError("Invalid payload string: must be a json object")
merged_options = {**self.options, **options} merged_options = {**self.options, **options}
self._validate_claims( self._validate_claims(
@ -140,24 +166,40 @@ class PyJWT:
decoded["payload"] = payload decoded["payload"] = payload
return decoded return decoded
def _decode_payload(self, decoded: dict[str, Any]) -> Any:
"""
Decode the payload from a JWS dictionary (payload, signature, header).
This method is intended to be overridden by subclasses that need to
decode the payload in a different way, e.g. decompress compressed
payloads.
"""
try:
payload = json.loads(decoded["payload"])
except ValueError as e:
raise DecodeError(f"Invalid payload string: {e}")
if not isinstance(payload, dict):
raise DecodeError("Invalid payload string: must be a json object")
return payload
def decode( def decode(
self, self,
jwt: str, jwt: str | bytes,
key: str = "", key: AllowedPublicKeys | str | bytes = "",
algorithms: Optional[List[str]] = None, algorithms: list[str] | None = None,
options: Optional[Dict[str, Any]] = None, options: dict[str, Any] | None = None,
# deprecated arg, remove in pyjwt3 # deprecated arg, remove in pyjwt3
verify: Optional[bool] = None, verify: bool | None = None,
# could be used as passthrough to api_jws, consider removal in pyjwt3 # could be used as passthrough to api_jws, consider removal in pyjwt3
detached_payload: Optional[bytes] = None, detached_payload: bytes | None = None,
# passthrough arguments to _validate_claims # passthrough arguments to _validate_claims
# consider putting in options # consider putting in options
audience: Optional[Union[str, Iterable[str]]] = None, audience: str | Iterable[str] | None = None,
issuer: Optional[str] = None, issuer: str | None = None,
leeway: Union[int, float, timedelta] = 0, leeway: float | timedelta = 0,
# kwargs # kwargs
**kwargs, **kwargs: Any,
) -> Dict[str, Any]: ) -> Any:
if kwargs: if kwargs:
warnings.warn( warnings.warn(
"passing additional kwargs to decode() is deprecated " "passing additional kwargs to decode() is deprecated "
@ -178,7 +220,14 @@ class PyJWT:
) )
return decoded["payload"] return decoded["payload"]
def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0): def _validate_claims(
self,
payload: dict[str, Any],
options: dict[str, Any],
audience=None,
issuer=None,
leeway: float | timedelta = 0,
) -> None:
if isinstance(leeway, timedelta): if isinstance(leeway, timedelta):
leeway = leeway.total_seconds() leeway = leeway.total_seconds()
@ -187,7 +236,7 @@ class PyJWT:
self._validate_required_claims(payload, options) self._validate_required_claims(payload, options)
now = timegm(datetime.now(tz=timezone.utc).utctimetuple()) now = datetime.now(tz=timezone.utc).timestamp()
if "iat" in payload and options["verify_iat"]: if "iat" in payload and options["verify_iat"]:
self._validate_iat(payload, now, leeway) self._validate_iat(payload, now, leeway)
@ -202,23 +251,38 @@ class PyJWT:
self._validate_iss(payload, issuer) self._validate_iss(payload, issuer)
if options["verify_aud"]: if options["verify_aud"]:
self._validate_aud(payload, audience) self._validate_aud(
payload, audience, strict=options.get("strict_aud", False)
)
def _validate_required_claims(self, payload, options): def _validate_required_claims(
self,
payload: dict[str, Any],
options: dict[str, Any],
) -> None:
for claim in options["require"]: for claim in options["require"]:
if payload.get(claim) is None: if payload.get(claim) is None:
raise MissingRequiredClaimError(claim) raise MissingRequiredClaimError(claim)
def _validate_iat(self, payload, now, leeway): def _validate_iat(
iat = payload["iat"] self,
payload: dict[str, Any],
now: float,
leeway: float,
) -> None:
try: try:
int(iat) iat = int(payload["iat"])
except ValueError: except ValueError:
raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.") raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.")
if iat > (now + leeway): if iat > (now + leeway):
raise ImmatureSignatureError("The token is not yet valid (iat)") raise ImmatureSignatureError("The token is not yet valid (iat)")
def _validate_nbf(self, payload, now, leeway): def _validate_nbf(
self,
payload: dict[str, Any],
now: float,
leeway: float,
) -> None:
try: try:
nbf = int(payload["nbf"]) nbf = int(payload["nbf"])
except ValueError: except ValueError:
@ -227,7 +291,12 @@ class PyJWT:
if nbf > (now + leeway): if nbf > (now + leeway):
raise ImmatureSignatureError("The token is not yet valid (nbf)") raise ImmatureSignatureError("The token is not yet valid (nbf)")
def _validate_exp(self, payload, now, leeway): def _validate_exp(
self,
payload: dict[str, Any],
now: float,
leeway: float,
) -> None:
try: try:
exp = int(payload["exp"]) exp = int(payload["exp"])
except ValueError: except ValueError:
@ -236,7 +305,13 @@ class PyJWT:
if exp <= (now - leeway): if exp <= (now - leeway):
raise ExpiredSignatureError("Signature has expired") raise ExpiredSignatureError("Signature has expired")
def _validate_aud(self, payload, audience): def _validate_aud(
self,
payload: dict[str, Any],
audience: str | Iterable[str] | None,
*,
strict: bool = False,
) -> None:
if audience is None: if audience is None:
if "aud" not in payload or not payload["aud"]: if "aud" not in payload or not payload["aud"]:
return return
@ -251,6 +326,22 @@ class PyJWT:
audience_claims = payload["aud"] audience_claims = payload["aud"]
# In strict mode, we forbid list matching: the supplied audience
# must be a string, and it must exactly match the audience claim.
if strict:
# Only a single audience is allowed in strict mode.
if not isinstance(audience, str):
raise InvalidAudienceError("Invalid audience (strict)")
# Only a single audience claim is allowed in strict mode.
if not isinstance(audience_claims, str):
raise InvalidAudienceError("Invalid claim format in token (strict)")
if audience != audience_claims:
raise InvalidAudienceError("Audience doesn't match (strict)")
return
if isinstance(audience_claims, str): if isinstance(audience_claims, str):
audience_claims = [audience_claims] audience_claims = [audience_claims]
if not isinstance(audience_claims, list): if not isinstance(audience_claims, list):
@ -262,9 +353,9 @@ class PyJWT:
audience = [audience] audience = [audience]
if all(aud not in audience_claims for aud in audience): if all(aud not in audience_claims for aud in audience):
raise InvalidAudienceError("Invalid audience") raise InvalidAudienceError("Audience doesn't match")
def _validate_iss(self, payload, issuer): def _validate_iss(self, payload: dict[str, Any], issuer: Any) -> None:
if issuer is None: if issuer is None:
return return

View file

@ -47,10 +47,10 @@ class InvalidAlgorithmError(InvalidTokenError):
class MissingRequiredClaimError(InvalidTokenError): class MissingRequiredClaimError(InvalidTokenError):
def __init__(self, claim): def __init__(self, claim: str) -> None:
self.claim = claim self.claim = claim
def __str__(self): def __str__(self) -> str:
return f'Token is missing the "{self.claim}" claim' return f'Token is missing the "{self.claim}" claim'
@ -64,3 +64,7 @@ class PyJWKSetError(PyJWTError):
class PyJWKClientError(PyJWTError): class PyJWKClientError(PyJWTError):
pass pass
class PyJWKClientConnectionError(PyJWKClientError):
pass

View file

@ -7,8 +7,10 @@ from . import __version__ as pyjwt_version
try: try:
import cryptography import cryptography
cryptography_version = cryptography.__version__
except ModuleNotFoundError: except ModuleNotFoundError:
cryptography = None cryptography_version = ""
def info() -> Dict[str, Dict[str, str]]: def info() -> Dict[str, Dict[str, str]]:
@ -29,7 +31,7 @@ def info() -> Dict[str, Dict[str, str]]:
if implementation == "CPython": if implementation == "CPython":
implementation_version = platform.python_version() implementation_version = platform.python_version()
elif implementation == "PyPy": elif implementation == "PyPy":
pypy_version_info = getattr(sys, "pypy_version_info") pypy_version_info = sys.pypy_version_info # type: ignore[attr-defined]
implementation_version = ( implementation_version = (
f"{pypy_version_info.major}." f"{pypy_version_info.major}."
f"{pypy_version_info.minor}." f"{pypy_version_info.minor}."
@ -48,7 +50,7 @@ def info() -> Dict[str, Dict[str, str]]:
"name": implementation, "name": implementation,
"version": implementation_version, "version": implementation_version,
}, },
"cryptography": {"version": getattr(cryptography, "__version__", "")}, "cryptography": {"version": cryptography_version},
"pyjwt": {"version": pyjwt_version}, "pyjwt": {"version": pyjwt_version},
} }

View file

@ -5,11 +5,11 @@ from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp
class JWKSetCache: class JWKSetCache:
def __init__(self, lifespan: int): def __init__(self, lifespan: int) -> None:
self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None
self.lifespan = lifespan self.lifespan = lifespan
def put(self, jwk_set: PyJWKSet): def put(self, jwk_set: PyJWKSet) -> None:
if jwk_set is not None: if jwk_set is not None:
self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set) self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set)
else: else:
@ -23,7 +23,6 @@ class JWKSetCache:
return self.jwk_set_with_timestamp.get_jwk_set() return self.jwk_set_with_timestamp.get_jwk_set()
def is_expired(self) -> bool: def is_expired(self) -> bool:
return ( return (
self.jwk_set_with_timestamp is not None self.jwk_set_with_timestamp is not None
and self.lifespan > -1 and self.lifespan > -1

View file

@ -1,12 +1,13 @@
import json import json
import urllib.request import urllib.request
from functools import lru_cache from functools import lru_cache
from typing import Any, List, Optional from ssl import SSLContext
from typing import Any, Dict, List, Optional
from urllib.error import URLError from urllib.error import URLError
from .api_jwk import PyJWK, PyJWKSet from .api_jwk import PyJWK, PyJWKSet
from .api_jwt import decode_complete as decode_token from .api_jwt import decode_complete as decode_token
from .exceptions import PyJWKClientError from .exceptions import PyJWKClientConnectionError, PyJWKClientError
from .jwk_set_cache import JWKSetCache from .jwk_set_cache import JWKSetCache
@ -18,9 +19,17 @@ class PyJWKClient:
max_cached_keys: int = 16, max_cached_keys: int = 16,
cache_jwk_set: bool = True, cache_jwk_set: bool = True,
lifespan: int = 300, lifespan: int = 300,
headers: Optional[Dict[str, Any]] = None,
timeout: int = 30,
ssl_context: Optional[SSLContext] = None,
): ):
if headers is None:
headers = {}
self.uri = uri self.uri = uri
self.jwk_set_cache: Optional[JWKSetCache] = None self.jwk_set_cache: Optional[JWKSetCache] = None
self.headers = headers
self.timeout = timeout
self.ssl_context = ssl_context
if cache_jwk_set: if cache_jwk_set:
# Init jwt set cache with default or given lifespan. # Init jwt set cache with default or given lifespan.
@ -41,10 +50,15 @@ class PyJWKClient:
def fetch_data(self) -> Any: def fetch_data(self) -> Any:
jwk_set: Any = None jwk_set: Any = None
try: try:
with urllib.request.urlopen(self.uri) as response: r = urllib.request.Request(url=self.uri, headers=self.headers)
with urllib.request.urlopen(
r, timeout=self.timeout, context=self.ssl_context
) as response:
jwk_set = json.load(response) jwk_set = json.load(response)
except URLError as e: except (URLError, TimeoutError) as e:
raise PyJWKClientError(f'Fail to fetch data from the url, err: "{e}"') raise PyJWKClientConnectionError(
f'Fail to fetch data from the url, err: "{e}"'
)
else: else:
return jwk_set return jwk_set
finally: finally:
@ -59,6 +73,9 @@ class PyJWKClient:
if data is None: if data is None:
data = self.fetch_data() data = self.fetch_data()
if not isinstance(data, dict):
raise PyJWKClientError("The JWKS endpoint did not return a JSON object")
return PyJWKSet.from_dict(data) return PyJWKSet.from_dict(data)
def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]: def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:

5
lib/jwt/types.py Normal file
View file

@ -0,0 +1,5 @@
from typing import Any, Callable, Dict
JWKDict = Dict[str, Any]
HashlibHash = Callable[..., Any]

View file

@ -10,10 +10,10 @@ try:
encode_dss_signature, encode_dss_signature,
) )
except ModuleNotFoundError: except ModuleNotFoundError:
EllipticCurve = None pass
def force_bytes(value: Union[str, bytes]) -> bytes: def force_bytes(value: Union[bytes, str]) -> bytes:
if isinstance(value, str): if isinstance(value, str):
return value.encode("utf-8") return value.encode("utf-8")
elif isinstance(value, bytes): elif isinstance(value, bytes):
@ -22,16 +22,15 @@ def force_bytes(value: Union[str, bytes]) -> bytes:
raise TypeError("Expected a string value") raise TypeError("Expected a string value")
def base64url_decode(input: Union[str, bytes]) -> bytes: def base64url_decode(input: Union[bytes, str]) -> bytes:
if isinstance(input, str): input_bytes = force_bytes(input)
input = input.encode("ascii")
rem = len(input) % 4 rem = len(input_bytes) % 4
if rem > 0: if rem > 0:
input += b"=" * (4 - rem) input_bytes += b"=" * (4 - rem)
return base64.urlsafe_b64decode(input) return base64.urlsafe_b64decode(input_bytes)
def base64url_encode(input: bytes) -> bytes: def base64url_encode(input: bytes) -> bytes:
@ -50,11 +49,8 @@ def to_base64url_uint(val: int) -> bytes:
return base64url_encode(int_bytes) return base64url_encode(int_bytes)
def from_base64url_uint(val: Union[str, bytes]) -> int: def from_base64url_uint(val: Union[bytes, str]) -> int:
if isinstance(val, str): data = base64url_decode(force_bytes(val))
val = val.encode("ascii")
data = base64url_decode(val)
return int.from_bytes(data, byteorder="big") return int.from_bytes(data, byteorder="big")
@ -78,7 +74,7 @@ def bytes_from_int(val: int) -> bytes:
return val.to_bytes(byte_length, "big", signed=False) return val.to_bytes(byte_length, "big", signed=False)
def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes: def der_to_raw_signature(der_sig: bytes, curve: "EllipticCurve") -> bytes:
num_bits = curve.key_size num_bits = curve.key_size
num_bytes = (num_bits + 7) // 8 num_bytes = (num_bits + 7) // 8
@ -87,7 +83,7 @@ def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes:
return number_to_bytes(r, num_bytes) + number_to_bytes(s, num_bytes) return number_to_bytes(r, num_bytes) + number_to_bytes(s, num_bytes)
def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes: def raw_to_der_signature(raw_sig: bytes, curve: "EllipticCurve") -> bytes:
num_bits = curve.key_size num_bits = curve.key_size
num_bytes = (num_bits + 7) // 8 num_bytes = (num_bits + 7) // 8
@ -97,7 +93,7 @@ def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes:
r = bytes_to_number(raw_sig[:num_bytes]) r = bytes_to_number(raw_sig[:num_bytes])
s = bytes_to_number(raw_sig[num_bytes:]) s = bytes_to_number(raw_sig[num_bytes:])
return encode_dss_signature(r, s) return bytes(encode_dss_signature(r, s))
# Based on https://github.com/hynek/pem/blob/7ad94db26b0bc21d10953f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252 # Based on https://github.com/hynek/pem/blob/7ad94db26b0bc21d10953f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252

View file

@ -31,7 +31,7 @@ paho-mqtt==1.6.1
plexapi==4.13.4 plexapi==4.13.4
portend==3.1.0 portend==3.1.0
profilehooks==1.12.0 profilehooks==1.12.0
PyJWT==2.6.0 PyJWT==2.8.0
pyparsing==3.0.9 pyparsing==3.0.9
python-dateutil==2.8.2 python-dateutil==2.8.2
python-twitter==3.5 python-twitter==3.5