mirror of
https://github.com/Tautulli/Tautulli.git
synced 2025-07-06 21:21:15 -07:00
Bump pyjwt from 2.6.0 to 2.8.0 (#2115)
* Bump pyjwt from 2.6.0 to 2.8.0 Bumps [pyjwt](https://github.com/jpadilla/pyjwt) from 2.6.0 to 2.8.0. - [Release notes](https://github.com/jpadilla/pyjwt/releases) - [Changelog](https://github.com/jpadilla/pyjwt/blob/master/CHANGELOG.rst) - [Commits](https://github.com/jpadilla/pyjwt/compare/2.6.0...2.8.0) --- updated-dependencies: - dependency-name: pyjwt dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> * Update pyjwt==2.8.0 --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> [skip ci]
This commit is contained in:
parent
77f38bbf93
commit
c93f470371
12 changed files with 542 additions and 260 deletions
|
@ -19,6 +19,7 @@ from .exceptions import (
|
||||||
InvalidSignatureError,
|
InvalidSignatureError,
|
||||||
InvalidTokenError,
|
InvalidTokenError,
|
||||||
MissingRequiredClaimError,
|
MissingRequiredClaimError,
|
||||||
|
PyJWKClientConnectionError,
|
||||||
PyJWKClientError,
|
PyJWKClientError,
|
||||||
PyJWKError,
|
PyJWKError,
|
||||||
PyJWKSetError,
|
PyJWKSetError,
|
||||||
|
@ -26,7 +27,7 @@ from .exceptions import (
|
||||||
)
|
)
|
||||||
from .jwks_client import PyJWKClient
|
from .jwks_client import PyJWKClient
|
||||||
|
|
||||||
__version__ = "2.6.0"
|
__version__ = "2.8.0"
|
||||||
|
|
||||||
__title__ = "PyJWT"
|
__title__ = "PyJWT"
|
||||||
__description__ = "JSON Web Token implementation in Python"
|
__description__ = "JSON Web Token implementation in Python"
|
||||||
|
@ -65,6 +66,7 @@ __all__ = [
|
||||||
"InvalidSignatureError",
|
"InvalidSignatureError",
|
||||||
"InvalidTokenError",
|
"InvalidTokenError",
|
||||||
"MissingRequiredClaimError",
|
"MissingRequiredClaimError",
|
||||||
|
"PyJWKClientConnectionError",
|
||||||
"PyJWKClientError",
|
"PyJWKClientError",
|
||||||
"PyJWKError",
|
"PyJWKError",
|
||||||
"PyJWKSetError",
|
"PyJWKSetError",
|
||||||
|
|
|
@ -1,8 +1,14 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
import json
|
import json
|
||||||
|
import sys
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, Union, cast, overload
|
||||||
|
|
||||||
from .exceptions import InvalidKeyError
|
from .exceptions import InvalidKeyError
|
||||||
|
from .types import HashlibHash, JWKDict
|
||||||
from .utils import (
|
from .utils import (
|
||||||
base64url_decode,
|
base64url_decode,
|
||||||
base64url_encode,
|
base64url_encode,
|
||||||
|
@ -15,14 +21,28 @@ from .utils import (
|
||||||
to_base64url_uint,
|
to_base64url_uint,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 8):
|
||||||
|
from typing import Literal
|
||||||
|
else:
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import cryptography.exceptions
|
|
||||||
from cryptography.exceptions import InvalidSignature
|
from cryptography.exceptions import InvalidSignature
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
from cryptography.hazmat.primitives import hashes
|
from cryptography.hazmat.primitives import hashes
|
||||||
from cryptography.hazmat.primitives.asymmetric import ec, padding
|
from cryptography.hazmat.primitives.asymmetric import padding
|
||||||
from cryptography.hazmat.primitives.asymmetric.ec import (
|
from cryptography.hazmat.primitives.asymmetric.ec import (
|
||||||
|
ECDSA,
|
||||||
|
SECP256K1,
|
||||||
|
SECP256R1,
|
||||||
|
SECP384R1,
|
||||||
|
SECP521R1,
|
||||||
|
EllipticCurve,
|
||||||
EllipticCurvePrivateKey,
|
EllipticCurvePrivateKey,
|
||||||
|
EllipticCurvePrivateNumbers,
|
||||||
EllipticCurvePublicKey,
|
EllipticCurvePublicKey,
|
||||||
|
EllipticCurvePublicNumbers,
|
||||||
)
|
)
|
||||||
from cryptography.hazmat.primitives.asymmetric.ed448 import (
|
from cryptography.hazmat.primitives.asymmetric.ed448 import (
|
||||||
Ed448PrivateKey,
|
Ed448PrivateKey,
|
||||||
|
@ -56,6 +76,23 @@ try:
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
has_crypto = False
|
has_crypto = False
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
# Type aliases for convenience in algorithms method signatures
|
||||||
|
AllowedRSAKeys = RSAPrivateKey | RSAPublicKey
|
||||||
|
AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey
|
||||||
|
AllowedOKPKeys = (
|
||||||
|
Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey
|
||||||
|
)
|
||||||
|
AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
|
||||||
|
AllowedPrivateKeys = (
|
||||||
|
RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey
|
||||||
|
)
|
||||||
|
AllowedPublicKeys = (
|
||||||
|
RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
requires_cryptography = {
|
requires_cryptography = {
|
||||||
"RS256",
|
"RS256",
|
||||||
"RS384",
|
"RS384",
|
||||||
|
@ -72,7 +109,7 @@ requires_cryptography = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_default_algorithms():
|
def get_default_algorithms() -> dict[str, Algorithm]:
|
||||||
"""
|
"""
|
||||||
Returns the algorithms that are implemented by the library.
|
Returns the algorithms that are implemented by the library.
|
||||||
"""
|
"""
|
||||||
|
@ -106,45 +143,79 @@ def get_default_algorithms():
|
||||||
return default_algorithms
|
return default_algorithms
|
||||||
|
|
||||||
|
|
||||||
class Algorithm:
|
class Algorithm(ABC):
|
||||||
"""
|
"""
|
||||||
The interface for an algorithm used to sign and verify tokens.
|
The interface for an algorithm used to sign and verify tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def prepare_key(self, key):
|
def compute_hash_digest(self, bytestr: bytes) -> bytes:
|
||||||
|
"""
|
||||||
|
Compute a hash digest using the specified algorithm's hash algorithm.
|
||||||
|
|
||||||
|
If there is no hash algorithm, raises a NotImplementedError.
|
||||||
|
"""
|
||||||
|
# lookup self.hash_alg if defined in a way that mypy can understand
|
||||||
|
hash_alg = getattr(self, "hash_alg", None)
|
||||||
|
if hash_alg is None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
if (
|
||||||
|
has_crypto
|
||||||
|
and isinstance(hash_alg, type)
|
||||||
|
and issubclass(hash_alg, hashes.HashAlgorithm)
|
||||||
|
):
|
||||||
|
digest = hashes.Hash(hash_alg(), backend=default_backend())
|
||||||
|
digest.update(bytestr)
|
||||||
|
return bytes(digest.finalize())
|
||||||
|
else:
|
||||||
|
return bytes(hash_alg(bytestr).digest())
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def prepare_key(self, key: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
Performs necessary validation and conversions on the key and returns
|
Performs necessary validation and conversions on the key and returns
|
||||||
the key value in the proper format for sign() and verify().
|
the key value in the proper format for sign() and verify().
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def sign(self, msg, key):
|
@abstractmethod
|
||||||
|
def sign(self, msg: bytes, key: Any) -> bytes:
|
||||||
"""
|
"""
|
||||||
Returns a digital signature for the specified message
|
Returns a digital signature for the specified message
|
||||||
using the specified key value.
|
using the specified key value.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def verify(self, msg, key, sig):
|
@abstractmethod
|
||||||
|
def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
|
||||||
"""
|
"""
|
||||||
Verifies that the specified digital signature is valid
|
Verifies that the specified digital signature is valid
|
||||||
for the specified message and key values.
|
for the specified message and key values.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
|
||||||
|
@overload
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict:
|
||||||
|
... # pragma: no cover
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def to_jwk(key_obj, as_dict: Literal[False] = False) -> str:
|
||||||
|
... # pragma: no cover
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def to_jwk(key_obj):
|
@abstractmethod
|
||||||
|
def to_jwk(key_obj, as_dict: bool = False) -> Union[JWKDict, str]:
|
||||||
"""
|
"""
|
||||||
Serializes a given RSA key into a JWK
|
Serializes a given key into a JWK
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_jwk(jwk):
|
@abstractmethod
|
||||||
|
def from_jwk(jwk: str | JWKDict) -> Any:
|
||||||
"""
|
"""
|
||||||
Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
|
Deserializes a given key from JWK back into a key object
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class NoneAlgorithm(Algorithm):
|
class NoneAlgorithm(Algorithm):
|
||||||
|
@ -153,7 +224,7 @@ class NoneAlgorithm(Algorithm):
|
||||||
operations are required.
|
operations are required.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def prepare_key(self, key):
|
def prepare_key(self, key: str | None) -> None:
|
||||||
if key == "":
|
if key == "":
|
||||||
key = None
|
key = None
|
||||||
|
|
||||||
|
@ -162,12 +233,20 @@ class NoneAlgorithm(Algorithm):
|
||||||
|
|
||||||
return key
|
return key
|
||||||
|
|
||||||
def sign(self, msg, key):
|
def sign(self, msg: bytes, key: None) -> bytes:
|
||||||
return b""
|
return b""
|
||||||
|
|
||||||
def verify(self, msg, key, sig):
|
def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_jwk(jwk: str | JWKDict) -> NoReturn:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class HMACAlgorithm(Algorithm):
|
class HMACAlgorithm(Algorithm):
|
||||||
"""
|
"""
|
||||||
|
@ -175,38 +254,51 @@ class HMACAlgorithm(Algorithm):
|
||||||
and the specified hash function.
|
and the specified hash function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SHA256 = hashlib.sha256
|
SHA256: ClassVar[HashlibHash] = hashlib.sha256
|
||||||
SHA384 = hashlib.sha384
|
SHA384: ClassVar[HashlibHash] = hashlib.sha384
|
||||||
SHA512 = hashlib.sha512
|
SHA512: ClassVar[HashlibHash] = hashlib.sha512
|
||||||
|
|
||||||
def __init__(self, hash_alg):
|
def __init__(self, hash_alg: HashlibHash) -> None:
|
||||||
self.hash_alg = hash_alg
|
self.hash_alg = hash_alg
|
||||||
|
|
||||||
def prepare_key(self, key):
|
def prepare_key(self, key: str | bytes) -> bytes:
|
||||||
key = force_bytes(key)
|
key_bytes = force_bytes(key)
|
||||||
|
|
||||||
if is_pem_format(key) or is_ssh_key(key):
|
if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
|
||||||
raise InvalidKeyError(
|
raise InvalidKeyError(
|
||||||
"The specified key is an asymmetric key or x509 certificate and"
|
"The specified key is an asymmetric key or x509 certificate and"
|
||||||
" should not be used as an HMAC secret."
|
" should not be used as an HMAC secret."
|
||||||
)
|
)
|
||||||
|
|
||||||
return key
|
return key_bytes
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@staticmethod
|
||||||
|
def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict:
|
||||||
|
... # pragma: no cover
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@staticmethod
|
||||||
|
def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str:
|
||||||
|
... # pragma: no cover
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def to_jwk(key_obj):
|
def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> Union[JWKDict, str]:
|
||||||
return json.dumps(
|
jwk = {
|
||||||
{
|
"k": base64url_encode(force_bytes(key_obj)).decode(),
|
||||||
"k": base64url_encode(force_bytes(key_obj)).decode(),
|
"kty": "oct",
|
||||||
"kty": "oct",
|
}
|
||||||
}
|
|
||||||
)
|
if as_dict:
|
||||||
|
return jwk
|
||||||
|
else:
|
||||||
|
return json.dumps(jwk)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_jwk(jwk):
|
def from_jwk(jwk: str | JWKDict) -> bytes:
|
||||||
try:
|
try:
|
||||||
if isinstance(jwk, str):
|
if isinstance(jwk, str):
|
||||||
obj = json.loads(jwk)
|
obj: JWKDict = json.loads(jwk)
|
||||||
elif isinstance(jwk, dict):
|
elif isinstance(jwk, dict):
|
||||||
obj = jwk
|
obj = jwk
|
||||||
else:
|
else:
|
||||||
|
@ -219,10 +311,10 @@ class HMACAlgorithm(Algorithm):
|
||||||
|
|
||||||
return base64url_decode(obj["k"])
|
return base64url_decode(obj["k"])
|
||||||
|
|
||||||
def sign(self, msg, key):
|
def sign(self, msg: bytes, key: bytes) -> bytes:
|
||||||
return hmac.new(key, msg, self.hash_alg).digest()
|
return hmac.new(key, msg, self.hash_alg).digest()
|
||||||
|
|
||||||
def verify(self, msg, key, sig):
|
def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
|
||||||
return hmac.compare_digest(sig, self.sign(msg, key))
|
return hmac.compare_digest(sig, self.sign(msg, key))
|
||||||
|
|
||||||
|
|
||||||
|
@ -234,36 +326,49 @@ if has_crypto:
|
||||||
RSASSA-PKCS-v1_5 and the specified hash function.
|
RSASSA-PKCS-v1_5 and the specified hash function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SHA256 = hashes.SHA256
|
SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
|
||||||
SHA384 = hashes.SHA384
|
SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
|
||||||
SHA512 = hashes.SHA512
|
SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
|
||||||
|
|
||||||
def __init__(self, hash_alg):
|
def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
|
||||||
self.hash_alg = hash_alg
|
self.hash_alg = hash_alg
|
||||||
|
|
||||||
def prepare_key(self, key):
|
def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
|
||||||
if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
|
if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
|
||||||
return key
|
return key
|
||||||
|
|
||||||
if not isinstance(key, (bytes, str)):
|
if not isinstance(key, (bytes, str)):
|
||||||
raise TypeError("Expecting a PEM-formatted key.")
|
raise TypeError("Expecting a PEM-formatted key.")
|
||||||
|
|
||||||
key = force_bytes(key)
|
key_bytes = force_bytes(key)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if key.startswith(b"ssh-rsa"):
|
if key_bytes.startswith(b"ssh-rsa"):
|
||||||
key = load_ssh_public_key(key)
|
return cast(RSAPublicKey, load_ssh_public_key(key_bytes))
|
||||||
else:
|
else:
|
||||||
key = load_pem_private_key(key, password=None)
|
return cast(
|
||||||
|
RSAPrivateKey, load_pem_private_key(key_bytes, password=None)
|
||||||
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
key = load_pem_public_key(key)
|
return cast(RSAPublicKey, load_pem_public_key(key_bytes))
|
||||||
return key
|
|
||||||
|
@overload
|
||||||
|
@staticmethod
|
||||||
|
def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict:
|
||||||
|
... # pragma: no cover
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@staticmethod
|
||||||
|
def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str:
|
||||||
|
... # pragma: no cover
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def to_jwk(key_obj):
|
def to_jwk(
|
||||||
obj = None
|
key_obj: AllowedRSAKeys, as_dict: bool = False
|
||||||
|
) -> Union[JWKDict, str]:
|
||||||
|
obj: dict[str, Any] | None = None
|
||||||
|
|
||||||
if getattr(key_obj, "private_numbers", None):
|
if hasattr(key_obj, "private_numbers"):
|
||||||
# Private key
|
# Private key
|
||||||
numbers = key_obj.private_numbers()
|
numbers = key_obj.private_numbers()
|
||||||
|
|
||||||
|
@ -280,7 +385,7 @@ if has_crypto:
|
||||||
"qi": to_base64url_uint(numbers.iqmp).decode(),
|
"qi": to_base64url_uint(numbers.iqmp).decode(),
|
||||||
}
|
}
|
||||||
|
|
||||||
elif getattr(key_obj, "verify", None):
|
elif hasattr(key_obj, "verify"):
|
||||||
# Public key
|
# Public key
|
||||||
numbers = key_obj.public_numbers()
|
numbers = key_obj.public_numbers()
|
||||||
|
|
||||||
|
@ -293,10 +398,13 @@ if has_crypto:
|
||||||
else:
|
else:
|
||||||
raise InvalidKeyError("Not a public or private key")
|
raise InvalidKeyError("Not a public or private key")
|
||||||
|
|
||||||
return json.dumps(obj)
|
if as_dict:
|
||||||
|
return obj
|
||||||
|
else:
|
||||||
|
return json.dumps(obj)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_jwk(jwk):
|
def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
|
||||||
try:
|
try:
|
||||||
if isinstance(jwk, str):
|
if isinstance(jwk, str):
|
||||||
obj = json.loads(jwk)
|
obj = json.loads(jwk)
|
||||||
|
@ -360,19 +468,17 @@ if has_crypto:
|
||||||
return numbers.private_key()
|
return numbers.private_key()
|
||||||
elif "n" in obj and "e" in obj:
|
elif "n" in obj and "e" in obj:
|
||||||
# Public key
|
# Public key
|
||||||
numbers = RSAPublicNumbers(
|
return RSAPublicNumbers(
|
||||||
from_base64url_uint(obj["e"]),
|
from_base64url_uint(obj["e"]),
|
||||||
from_base64url_uint(obj["n"]),
|
from_base64url_uint(obj["n"]),
|
||||||
)
|
).public_key()
|
||||||
|
|
||||||
return numbers.public_key()
|
|
||||||
else:
|
else:
|
||||||
raise InvalidKeyError("Not a public or private key")
|
raise InvalidKeyError("Not a public or private key")
|
||||||
|
|
||||||
def sign(self, msg, key):
|
def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
|
||||||
return key.sign(msg, padding.PKCS1v15(), self.hash_alg())
|
return key.sign(msg, padding.PKCS1v15(), self.hash_alg())
|
||||||
|
|
||||||
def verify(self, msg, key, sig):
|
def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
|
||||||
try:
|
try:
|
||||||
key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
|
key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
|
||||||
return True
|
return True
|
||||||
|
@ -385,63 +491,79 @@ if has_crypto:
|
||||||
ECDSA and the specified hash function
|
ECDSA and the specified hash function
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SHA256 = hashes.SHA256
|
SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
|
||||||
SHA384 = hashes.SHA384
|
SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
|
||||||
SHA512 = hashes.SHA512
|
SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
|
||||||
|
|
||||||
def __init__(self, hash_alg):
|
def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
|
||||||
self.hash_alg = hash_alg
|
self.hash_alg = hash_alg
|
||||||
|
|
||||||
def prepare_key(self, key):
|
def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
|
||||||
if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
|
if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
|
||||||
return key
|
return key
|
||||||
|
|
||||||
if not isinstance(key, (bytes, str)):
|
if not isinstance(key, (bytes, str)):
|
||||||
raise TypeError("Expecting a PEM-formatted key.")
|
raise TypeError("Expecting a PEM-formatted key.")
|
||||||
|
|
||||||
key = force_bytes(key)
|
key_bytes = force_bytes(key)
|
||||||
|
|
||||||
# Attempt to load key. We don't know if it's
|
# Attempt to load key. We don't know if it's
|
||||||
# a Signing Key or a Verifying Key, so we try
|
# a Signing Key or a Verifying Key, so we try
|
||||||
# the Verifying Key first.
|
# the Verifying Key first.
|
||||||
try:
|
try:
|
||||||
if key.startswith(b"ecdsa-sha2-"):
|
if key_bytes.startswith(b"ecdsa-sha2-"):
|
||||||
key = load_ssh_public_key(key)
|
crypto_key = load_ssh_public_key(key_bytes)
|
||||||
else:
|
else:
|
||||||
key = load_pem_public_key(key)
|
crypto_key = load_pem_public_key(key_bytes) # type: ignore[assignment]
|
||||||
except ValueError:
|
except ValueError:
|
||||||
key = load_pem_private_key(key, password=None)
|
crypto_key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
|
||||||
|
|
||||||
# Explicit check the key to prevent confusing errors from cryptography
|
# Explicit check the key to prevent confusing errors from cryptography
|
||||||
if not isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
|
if not isinstance(
|
||||||
|
crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)
|
||||||
|
):
|
||||||
raise InvalidKeyError(
|
raise InvalidKeyError(
|
||||||
"Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
|
"Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
|
||||||
)
|
)
|
||||||
|
|
||||||
return key
|
return crypto_key
|
||||||
|
|
||||||
def sign(self, msg, key):
|
def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
|
||||||
der_sig = key.sign(msg, ec.ECDSA(self.hash_alg()))
|
der_sig = key.sign(msg, ECDSA(self.hash_alg()))
|
||||||
|
|
||||||
return der_to_raw_signature(der_sig, key.curve)
|
return der_to_raw_signature(der_sig, key.curve)
|
||||||
|
|
||||||
def verify(self, msg, key, sig):
|
def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool:
|
||||||
try:
|
try:
|
||||||
der_sig = raw_to_der_signature(sig, key.curve)
|
der_sig = raw_to_der_signature(sig, key.curve)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if isinstance(key, EllipticCurvePrivateKey):
|
public_key = (
|
||||||
key = key.public_key()
|
key.public_key()
|
||||||
key.verify(der_sig, msg, ec.ECDSA(self.hash_alg()))
|
if isinstance(key, EllipticCurvePrivateKey)
|
||||||
|
else key
|
||||||
|
)
|
||||||
|
public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
|
||||||
return True
|
return True
|
||||||
except InvalidSignature:
|
except InvalidSignature:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@overload
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def to_jwk(key_obj):
|
def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict:
|
||||||
|
... # pragma: no cover
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@staticmethod
|
||||||
|
def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str:
|
||||||
|
... # pragma: no cover
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def to_jwk(
|
||||||
|
key_obj: AllowedECKeys, as_dict: bool = False
|
||||||
|
) -> Union[JWKDict, str]:
|
||||||
if isinstance(key_obj, EllipticCurvePrivateKey):
|
if isinstance(key_obj, EllipticCurvePrivateKey):
|
||||||
public_numbers = key_obj.public_key().public_numbers()
|
public_numbers = key_obj.public_key().public_numbers()
|
||||||
elif isinstance(key_obj, EllipticCurvePublicKey):
|
elif isinstance(key_obj, EllipticCurvePublicKey):
|
||||||
|
@ -449,18 +571,18 @@ if has_crypto:
|
||||||
else:
|
else:
|
||||||
raise InvalidKeyError("Not a public or private key")
|
raise InvalidKeyError("Not a public or private key")
|
||||||
|
|
||||||
if isinstance(key_obj.curve, ec.SECP256R1):
|
if isinstance(key_obj.curve, SECP256R1):
|
||||||
crv = "P-256"
|
crv = "P-256"
|
||||||
elif isinstance(key_obj.curve, ec.SECP384R1):
|
elif isinstance(key_obj.curve, SECP384R1):
|
||||||
crv = "P-384"
|
crv = "P-384"
|
||||||
elif isinstance(key_obj.curve, ec.SECP521R1):
|
elif isinstance(key_obj.curve, SECP521R1):
|
||||||
crv = "P-521"
|
crv = "P-521"
|
||||||
elif isinstance(key_obj.curve, ec.SECP256K1):
|
elif isinstance(key_obj.curve, SECP256K1):
|
||||||
crv = "secp256k1"
|
crv = "secp256k1"
|
||||||
else:
|
else:
|
||||||
raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
|
raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
|
||||||
|
|
||||||
obj = {
|
obj: dict[str, Any] = {
|
||||||
"kty": "EC",
|
"kty": "EC",
|
||||||
"crv": crv,
|
"crv": crv,
|
||||||
"x": to_base64url_uint(public_numbers.x).decode(),
|
"x": to_base64url_uint(public_numbers.x).decode(),
|
||||||
|
@ -472,10 +594,13 @@ if has_crypto:
|
||||||
key_obj.private_numbers().private_value
|
key_obj.private_numbers().private_value
|
||||||
).decode()
|
).decode()
|
||||||
|
|
||||||
return json.dumps(obj)
|
if as_dict:
|
||||||
|
return obj
|
||||||
|
else:
|
||||||
|
return json.dumps(obj)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_jwk(jwk):
|
def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
|
||||||
try:
|
try:
|
||||||
if isinstance(jwk, str):
|
if isinstance(jwk, str):
|
||||||
obj = json.loads(jwk)
|
obj = json.loads(jwk)
|
||||||
|
@ -496,24 +621,26 @@ if has_crypto:
|
||||||
y = base64url_decode(obj.get("y"))
|
y = base64url_decode(obj.get("y"))
|
||||||
|
|
||||||
curve = obj.get("crv")
|
curve = obj.get("crv")
|
||||||
|
curve_obj: EllipticCurve
|
||||||
|
|
||||||
if curve == "P-256":
|
if curve == "P-256":
|
||||||
if len(x) == len(y) == 32:
|
if len(x) == len(y) == 32:
|
||||||
curve_obj = ec.SECP256R1()
|
curve_obj = SECP256R1()
|
||||||
else:
|
else:
|
||||||
raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
|
raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
|
||||||
elif curve == "P-384":
|
elif curve == "P-384":
|
||||||
if len(x) == len(y) == 48:
|
if len(x) == len(y) == 48:
|
||||||
curve_obj = ec.SECP384R1()
|
curve_obj = SECP384R1()
|
||||||
else:
|
else:
|
||||||
raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
|
raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
|
||||||
elif curve == "P-521":
|
elif curve == "P-521":
|
||||||
if len(x) == len(y) == 66:
|
if len(x) == len(y) == 66:
|
||||||
curve_obj = ec.SECP521R1()
|
curve_obj = SECP521R1()
|
||||||
else:
|
else:
|
||||||
raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
|
raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
|
||||||
elif curve == "secp256k1":
|
elif curve == "secp256k1":
|
||||||
if len(x) == len(y) == 32:
|
if len(x) == len(y) == 32:
|
||||||
curve_obj = ec.SECP256K1()
|
curve_obj = SECP256K1()
|
||||||
else:
|
else:
|
||||||
raise InvalidKeyError(
|
raise InvalidKeyError(
|
||||||
"Coords should be 32 bytes for curve secp256k1"
|
"Coords should be 32 bytes for curve secp256k1"
|
||||||
|
@ -521,7 +648,7 @@ if has_crypto:
|
||||||
else:
|
else:
|
||||||
raise InvalidKeyError(f"Invalid curve: {curve}")
|
raise InvalidKeyError(f"Invalid curve: {curve}")
|
||||||
|
|
||||||
public_numbers = ec.EllipticCurvePublicNumbers(
|
public_numbers = EllipticCurvePublicNumbers(
|
||||||
x=int.from_bytes(x, byteorder="big"),
|
x=int.from_bytes(x, byteorder="big"),
|
||||||
y=int.from_bytes(y, byteorder="big"),
|
y=int.from_bytes(y, byteorder="big"),
|
||||||
curve=curve_obj,
|
curve=curve_obj,
|
||||||
|
@ -536,7 +663,7 @@ if has_crypto:
|
||||||
"D should be {} bytes for curve {}", len(x), curve
|
"D should be {} bytes for curve {}", len(x), curve
|
||||||
)
|
)
|
||||||
|
|
||||||
return ec.EllipticCurvePrivateNumbers(
|
return EllipticCurvePrivateNumbers(
|
||||||
int.from_bytes(d, byteorder="big"), public_numbers
|
int.from_bytes(d, byteorder="big"), public_numbers
|
||||||
).private_key()
|
).private_key()
|
||||||
|
|
||||||
|
@ -545,24 +672,24 @@ if has_crypto:
|
||||||
Performs a signature using RSASSA-PSS with MGF1
|
Performs a signature using RSASSA-PSS with MGF1
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def sign(self, msg, key):
|
def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
|
||||||
return key.sign(
|
return key.sign(
|
||||||
msg,
|
msg,
|
||||||
padding.PSS(
|
padding.PSS(
|
||||||
mgf=padding.MGF1(self.hash_alg()),
|
mgf=padding.MGF1(self.hash_alg()),
|
||||||
salt_length=self.hash_alg.digest_size,
|
salt_length=self.hash_alg().digest_size,
|
||||||
),
|
),
|
||||||
self.hash_alg(),
|
self.hash_alg(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def verify(self, msg, key, sig):
|
def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
|
||||||
try:
|
try:
|
||||||
key.verify(
|
key.verify(
|
||||||
sig,
|
sig,
|
||||||
msg,
|
msg,
|
||||||
padding.PSS(
|
padding.PSS(
|
||||||
mgf=padding.MGF1(self.hash_alg()),
|
mgf=padding.MGF1(self.hash_alg()),
|
||||||
salt_length=self.hash_alg.digest_size,
|
salt_length=self.hash_alg().digest_size,
|
||||||
),
|
),
|
||||||
self.hash_alg(),
|
self.hash_alg(),
|
||||||
)
|
)
|
||||||
|
@ -577,21 +704,20 @@ if has_crypto:
|
||||||
This class requires ``cryptography>=2.6`` to be installed.
|
This class requires ``cryptography>=2.6`` to be installed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def prepare_key(self, key):
|
def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
|
||||||
if isinstance(key, (bytes, str)):
|
if isinstance(key, (bytes, str)):
|
||||||
if isinstance(key, str):
|
key_str = key.decode("utf-8") if isinstance(key, bytes) else key
|
||||||
key = key.encode("utf-8")
|
key_bytes = key.encode("utf-8") if isinstance(key, str) else key
|
||||||
str_key = key.decode("utf-8")
|
|
||||||
|
|
||||||
if "-----BEGIN PUBLIC" in str_key:
|
if "-----BEGIN PUBLIC" in key_str:
|
||||||
key = load_pem_public_key(key)
|
key = load_pem_public_key(key_bytes) # type: ignore[assignment]
|
||||||
elif "-----BEGIN PRIVATE" in str_key:
|
elif "-----BEGIN PRIVATE" in key_str:
|
||||||
key = load_pem_private_key(key, password=None)
|
key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
|
||||||
elif str_key[0:4] == "ssh-":
|
elif key_str[0:4] == "ssh-":
|
||||||
key = load_ssh_public_key(key)
|
key = load_ssh_public_key(key_bytes) # type: ignore[assignment]
|
||||||
|
|
||||||
# Explicit check the key to prevent confusing errors from cryptography
|
# Explicit check the key to prevent confusing errors from cryptography
|
||||||
if not isinstance(
|
if not isinstance(
|
||||||
|
@ -604,7 +730,9 @@ if has_crypto:
|
||||||
|
|
||||||
return key
|
return key
|
||||||
|
|
||||||
def sign(self, msg, key):
|
def sign(
|
||||||
|
self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
|
||||||
|
) -> bytes:
|
||||||
"""
|
"""
|
||||||
Sign a message ``msg`` using the EdDSA private key ``key``
|
Sign a message ``msg`` using the EdDSA private key ``key``
|
||||||
:param str|bytes msg: Message to sign
|
:param str|bytes msg: Message to sign
|
||||||
|
@ -612,10 +740,12 @@ if has_crypto:
|
||||||
or :class:`.Ed448PrivateKey` isinstance
|
or :class:`.Ed448PrivateKey` isinstance
|
||||||
:return bytes signature: The signature, as bytes
|
:return bytes signature: The signature, as bytes
|
||||||
"""
|
"""
|
||||||
msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg
|
msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
|
||||||
return key.sign(msg)
|
return key.sign(msg_bytes)
|
||||||
|
|
||||||
def verify(self, msg, key, sig):
|
def verify(
|
||||||
|
self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
|
||||||
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``
|
Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``
|
||||||
|
|
||||||
|
@ -626,31 +756,48 @@ if has_crypto:
|
||||||
:return bool verified: True if signature is valid, False if not.
|
:return bool verified: True if signature is valid, False if not.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg
|
msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
|
||||||
sig = bytes(sig, "utf-8") if type(sig) is not bytes else sig
|
sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig
|
||||||
|
|
||||||
if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
|
public_key = (
|
||||||
key = key.public_key()
|
key.public_key()
|
||||||
key.verify(sig, msg)
|
if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
|
||||||
|
else key
|
||||||
|
)
|
||||||
|
public_key.verify(sig_bytes, msg_bytes)
|
||||||
return True # If no exception was raised, the signature is valid.
|
return True # If no exception was raised, the signature is valid.
|
||||||
except cryptography.exceptions.InvalidSignature:
|
except InvalidSignature:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@overload
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def to_jwk(key):
|
def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict:
|
||||||
|
... # pragma: no cover
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@staticmethod
|
||||||
|
def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str:
|
||||||
|
... # pragma: no cover
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> Union[JWKDict, str]:
|
||||||
if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
|
if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
|
||||||
x = key.public_bytes(
|
x = key.public_bytes(
|
||||||
encoding=Encoding.Raw,
|
encoding=Encoding.Raw,
|
||||||
format=PublicFormat.Raw,
|
format=PublicFormat.Raw,
|
||||||
)
|
)
|
||||||
crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
|
crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
|
||||||
return json.dumps(
|
|
||||||
{
|
obj = {
|
||||||
"x": base64url_encode(force_bytes(x)).decode(),
|
"x": base64url_encode(force_bytes(x)).decode(),
|
||||||
"kty": "OKP",
|
"kty": "OKP",
|
||||||
"crv": crv,
|
"crv": crv,
|
||||||
}
|
}
|
||||||
)
|
|
||||||
|
if as_dict:
|
||||||
|
return obj
|
||||||
|
else:
|
||||||
|
return json.dumps(obj)
|
||||||
|
|
||||||
if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
|
if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
|
||||||
d = key.private_bytes(
|
d = key.private_bytes(
|
||||||
|
@ -665,19 +812,22 @@ if has_crypto:
|
||||||
)
|
)
|
||||||
|
|
||||||
crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
|
crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
|
||||||
return json.dumps(
|
obj = {
|
||||||
{
|
"x": base64url_encode(force_bytes(x)).decode(),
|
||||||
"x": base64url_encode(force_bytes(x)).decode(),
|
"d": base64url_encode(force_bytes(d)).decode(),
|
||||||
"d": base64url_encode(force_bytes(d)).decode(),
|
"kty": "OKP",
|
||||||
"kty": "OKP",
|
"crv": crv,
|
||||||
"crv": crv,
|
}
|
||||||
}
|
|
||||||
)
|
if as_dict:
|
||||||
|
return obj
|
||||||
|
else:
|
||||||
|
return json.dumps(obj)
|
||||||
|
|
||||||
raise InvalidKeyError("Not a public or private key")
|
raise InvalidKeyError("Not a public or private key")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_jwk(jwk):
|
def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
|
||||||
try:
|
try:
|
||||||
if isinstance(jwk, str):
|
if isinstance(jwk, str):
|
||||||
obj = json.loads(jwk)
|
obj = json.loads(jwk)
|
||||||
|
|
|
@ -2,13 +2,15 @@ from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from .algorithms import get_default_algorithms
|
from .algorithms import get_default_algorithms, has_crypto, requires_cryptography
|
||||||
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError
|
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError, PyJWTError
|
||||||
|
from .types import JWKDict
|
||||||
|
|
||||||
|
|
||||||
class PyJWK:
|
class PyJWK:
|
||||||
def __init__(self, jwk_data, algorithm=None):
|
def __init__(self, jwk_data: JWKDict, algorithm: str | None = None) -> None:
|
||||||
self._algorithms = get_default_algorithms()
|
self._algorithms = get_default_algorithms()
|
||||||
self._jwk_data = jwk_data
|
self._jwk_data = jwk_data
|
||||||
|
|
||||||
|
@ -47,37 +49,40 @@ class PyJWK:
|
||||||
else:
|
else:
|
||||||
raise InvalidKeyError(f"Unsupported kty: {kty}")
|
raise InvalidKeyError(f"Unsupported kty: {kty}")
|
||||||
|
|
||||||
|
if not has_crypto and algorithm in requires_cryptography:
|
||||||
|
raise PyJWKError(f"{algorithm} requires 'cryptography' to be installed.")
|
||||||
|
|
||||||
self.Algorithm = self._algorithms.get(algorithm)
|
self.Algorithm = self._algorithms.get(algorithm)
|
||||||
|
|
||||||
if not self.Algorithm:
|
if not self.Algorithm:
|
||||||
raise PyJWKError(f"Unable to find a algorithm for key: {self._jwk_data}")
|
raise PyJWKError(f"Unable to find an algorithm for key: {self._jwk_data}")
|
||||||
|
|
||||||
self.key = self.Algorithm.from_jwk(self._jwk_data)
|
self.key = self.Algorithm.from_jwk(self._jwk_data)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_dict(obj, algorithm=None):
|
def from_dict(obj: JWKDict, algorithm: str | None = None) -> "PyJWK":
|
||||||
return PyJWK(obj, algorithm)
|
return PyJWK(obj, algorithm)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_json(data, algorithm=None):
|
def from_json(data: str, algorithm: None = None) -> "PyJWK":
|
||||||
obj = json.loads(data)
|
obj = json.loads(data)
|
||||||
return PyJWK.from_dict(obj, algorithm)
|
return PyJWK.from_dict(obj, algorithm)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def key_type(self):
|
def key_type(self) -> str | None:
|
||||||
return self._jwk_data.get("kty", None)
|
return self._jwk_data.get("kty", None)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def key_id(self):
|
def key_id(self) -> str | None:
|
||||||
return self._jwk_data.get("kid", None)
|
return self._jwk_data.get("kid", None)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def public_key_use(self):
|
def public_key_use(self) -> str | None:
|
||||||
return self._jwk_data.get("use", None)
|
return self._jwk_data.get("use", None)
|
||||||
|
|
||||||
|
|
||||||
class PyJWKSet:
|
class PyJWKSet:
|
||||||
def __init__(self, keys: list[dict]) -> None:
|
def __init__(self, keys: list[JWKDict]) -> None:
|
||||||
self.keys = []
|
self.keys = []
|
||||||
|
|
||||||
if not keys:
|
if not keys:
|
||||||
|
@ -89,24 +94,26 @@ class PyJWKSet:
|
||||||
for key in keys:
|
for key in keys:
|
||||||
try:
|
try:
|
||||||
self.keys.append(PyJWK(key))
|
self.keys.append(PyJWK(key))
|
||||||
except PyJWKError:
|
except PyJWTError:
|
||||||
# skip unusable keys
|
# skip unusable keys
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if len(self.keys) == 0:
|
if len(self.keys) == 0:
|
||||||
raise PyJWKSetError("The JWK Set did not contain any usable keys")
|
raise PyJWKSetError(
|
||||||
|
"The JWK Set did not contain any usable keys. Perhaps 'cryptography' is not installed?"
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_dict(obj):
|
def from_dict(obj: dict[str, Any]) -> "PyJWKSet":
|
||||||
keys = obj.get("keys", [])
|
keys = obj.get("keys", [])
|
||||||
return PyJWKSet(keys)
|
return PyJWKSet(keys)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_json(data):
|
def from_json(data: str) -> "PyJWKSet":
|
||||||
obj = json.loads(data)
|
obj = json.loads(data)
|
||||||
return PyJWKSet.from_dict(obj)
|
return PyJWKSet.from_dict(obj)
|
||||||
|
|
||||||
def __getitem__(self, kid):
|
def __getitem__(self, kid: str) -> "PyJWK":
|
||||||
for key in self.keys:
|
for key in self.keys:
|
||||||
if key.key_id == kid:
|
if key.key_id == kid:
|
||||||
return key
|
return key
|
||||||
|
@ -118,8 +125,8 @@ class PyJWTSetWithTimestamp:
|
||||||
self.jwk_set = jwk_set
|
self.jwk_set = jwk_set
|
||||||
self.timestamp = time.monotonic()
|
self.timestamp = time.monotonic()
|
||||||
|
|
||||||
def get_jwk_set(self):
|
def get_jwk_set(self) -> PyJWKSet:
|
||||||
return self.jwk_set
|
return self.jwk_set
|
||||||
|
|
||||||
def get_timestamp(self):
|
def get_timestamp(self) -> float:
|
||||||
return self.timestamp
|
return self.timestamp
|
||||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||||
import binascii
|
import binascii
|
||||||
import json
|
import json
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, Type
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from .algorithms import (
|
from .algorithms import (
|
||||||
Algorithm,
|
Algorithm,
|
||||||
|
@ -20,11 +20,18 @@ from .exceptions import (
|
||||||
from .utils import base64url_decode, base64url_encode
|
from .utils import base64url_decode, base64url_encode
|
||||||
from .warnings import RemovedInPyjwt3Warning
|
from .warnings import RemovedInPyjwt3Warning
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
|
||||||
|
|
||||||
|
|
||||||
class PyJWS:
|
class PyJWS:
|
||||||
header_typ = "JWT"
|
header_typ = "JWT"
|
||||||
|
|
||||||
def __init__(self, algorithms=None, options=None) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
algorithms: list[str] | None = None,
|
||||||
|
options: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
self._algorithms = get_default_algorithms()
|
self._algorithms = get_default_algorithms()
|
||||||
self._valid_algs = (
|
self._valid_algs = (
|
||||||
set(algorithms) if algorithms is not None else set(self._algorithms)
|
set(algorithms) if algorithms is not None else set(self._algorithms)
|
||||||
|
@ -96,11 +103,12 @@ class PyJWS:
|
||||||
def encode(
|
def encode(
|
||||||
self,
|
self,
|
||||||
payload: bytes,
|
payload: bytes,
|
||||||
key: str,
|
key: AllowedPrivateKeys | str | bytes,
|
||||||
algorithm: str | None = "HS256",
|
algorithm: str | None = "HS256",
|
||||||
headers: dict[str, Any] | None = None,
|
headers: dict[str, Any] | None = None,
|
||||||
json_encoder: Type[json.JSONEncoder] | None = None,
|
json_encoder: type[json.JSONEncoder] | None = None,
|
||||||
is_payload_detached: bool = False,
|
is_payload_detached: bool = False,
|
||||||
|
sort_headers: bool = True,
|
||||||
) -> str:
|
) -> str:
|
||||||
segments = []
|
segments = []
|
||||||
|
|
||||||
|
@ -133,9 +141,8 @@ class PyJWS:
|
||||||
# True is the standard value for b64, so no need for it
|
# True is the standard value for b64, so no need for it
|
||||||
del header["b64"]
|
del header["b64"]
|
||||||
|
|
||||||
# Fix for headers misorder - issue #715
|
|
||||||
json_header = json.dumps(
|
json_header = json.dumps(
|
||||||
header, separators=(",", ":"), cls=json_encoder, sort_keys=True
|
header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
|
||||||
).encode()
|
).encode()
|
||||||
|
|
||||||
segments.append(base64url_encode(json_header))
|
segments.append(base64url_encode(json_header))
|
||||||
|
@ -164,8 +171,8 @@ class PyJWS:
|
||||||
|
|
||||||
def decode_complete(
|
def decode_complete(
|
||||||
self,
|
self,
|
||||||
jwt: str,
|
jwt: str | bytes,
|
||||||
key: str = "",
|
key: AllowedPublicKeys | str | bytes = "",
|
||||||
algorithms: list[str] | None = None,
|
algorithms: list[str] | None = None,
|
||||||
options: dict[str, Any] | None = None,
|
options: dict[str, Any] | None = None,
|
||||||
detached_payload: bytes | None = None,
|
detached_payload: bytes | None = None,
|
||||||
|
@ -209,13 +216,13 @@ class PyJWS:
|
||||||
|
|
||||||
def decode(
|
def decode(
|
||||||
self,
|
self,
|
||||||
jwt: str,
|
jwt: str | bytes,
|
||||||
key: str = "",
|
key: AllowedPublicKeys | str | bytes = "",
|
||||||
algorithms: list[str] | None = None,
|
algorithms: list[str] | None = None,
|
||||||
options: dict[str, Any] | None = None,
|
options: dict[str, Any] | None = None,
|
||||||
detached_payload: bytes | None = None,
|
detached_payload: bytes | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> str:
|
) -> Any:
|
||||||
if kwargs:
|
if kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"passing additional kwargs to decode() is deprecated "
|
"passing additional kwargs to decode() is deprecated "
|
||||||
|
@ -228,7 +235,7 @@ class PyJWS:
|
||||||
)
|
)
|
||||||
return decoded["payload"]
|
return decoded["payload"]
|
||||||
|
|
||||||
def get_unverified_header(self, jwt: str | bytes) -> dict:
|
def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
|
||||||
"""Returns back the JWT header parameters as a dict()
|
"""Returns back the JWT header parameters as a dict()
|
||||||
|
|
||||||
Note: The signature is not verified so the header parameters
|
Note: The signature is not verified so the header parameters
|
||||||
|
@ -239,7 +246,7 @@ class PyJWS:
|
||||||
|
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]:
|
def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]:
|
||||||
if isinstance(jwt, str):
|
if isinstance(jwt, str):
|
||||||
jwt = jwt.encode("utf-8")
|
jwt = jwt.encode("utf-8")
|
||||||
|
|
||||||
|
@ -280,13 +287,15 @@ class PyJWS:
|
||||||
def _verify_signature(
|
def _verify_signature(
|
||||||
self,
|
self,
|
||||||
signing_input: bytes,
|
signing_input: bytes,
|
||||||
header: dict,
|
header: dict[str, Any],
|
||||||
signature: bytes,
|
signature: bytes,
|
||||||
key: str = "",
|
key: AllowedPublicKeys | str | bytes = "",
|
||||||
algorithms: list[str] | None = None,
|
algorithms: list[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
try:
|
||||||
alg = header.get("alg")
|
alg = header["alg"]
|
||||||
|
except KeyError:
|
||||||
|
raise InvalidAlgorithmError("Algorithm not specified")
|
||||||
|
|
||||||
if not alg or (algorithms is not None and alg not in algorithms):
|
if not alg or (algorithms is not None and alg not in algorithms):
|
||||||
raise InvalidAlgorithmError("The specified alg value is not allowed")
|
raise InvalidAlgorithmError("The specified alg value is not allowed")
|
||||||
|
@ -295,16 +304,16 @@ class PyJWS:
|
||||||
alg_obj = self.get_algorithm_by_name(alg)
|
alg_obj = self.get_algorithm_by_name(alg)
|
||||||
except NotImplementedError as e:
|
except NotImplementedError as e:
|
||||||
raise InvalidAlgorithmError("Algorithm not supported") from e
|
raise InvalidAlgorithmError("Algorithm not supported") from e
|
||||||
key = alg_obj.prepare_key(key)
|
prepared_key = alg_obj.prepare_key(key)
|
||||||
|
|
||||||
if not alg_obj.verify(signing_input, key, signature):
|
if not alg_obj.verify(signing_input, prepared_key, signature):
|
||||||
raise InvalidSignatureError("Signature verification failed")
|
raise InvalidSignatureError("Signature verification failed")
|
||||||
|
|
||||||
def _validate_headers(self, headers: dict[str, Any]) -> None:
|
def _validate_headers(self, headers: dict[str, Any]) -> None:
|
||||||
if "kid" in headers:
|
if "kid" in headers:
|
||||||
self._validate_kid(headers["kid"])
|
self._validate_kid(headers["kid"])
|
||||||
|
|
||||||
def _validate_kid(self, kid: str) -> None:
|
def _validate_kid(self, kid: Any) -> None:
|
||||||
if not isinstance(kid, str):
|
if not isinstance(kid, str):
|
||||||
raise InvalidTokenError("Key ID header parameter must be a string")
|
raise InvalidTokenError("Key ID header parameter must be a string")
|
||||||
|
|
||||||
|
|
|
@ -3,9 +3,9 @@ from __future__ import annotations
|
||||||
import json
|
import json
|
||||||
import warnings
|
import warnings
|
||||||
from calendar import timegm
|
from calendar import timegm
|
||||||
from collections.abc import Iterable, Mapping
|
from collections.abc import Iterable
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any, Dict, List, Optional, Type, Union
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from . import api_jws
|
from . import api_jws
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
|
@ -19,15 +19,18 @@ from .exceptions import (
|
||||||
)
|
)
|
||||||
from .warnings import RemovedInPyjwt3Warning
|
from .warnings import RemovedInPyjwt3Warning
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
|
||||||
|
|
||||||
|
|
||||||
class PyJWT:
|
class PyJWT:
|
||||||
def __init__(self, options=None):
|
def __init__(self, options: dict[str, Any] | None = None) -> None:
|
||||||
if options is None:
|
if options is None:
|
||||||
options = {}
|
options = {}
|
||||||
self.options = {**self._get_default_options(), **options}
|
self.options: dict[str, Any] = {**self._get_default_options(), **options}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_default_options() -> Dict[str, Union[bool, List[str]]]:
|
def _get_default_options() -> dict[str, bool | list[str]]:
|
||||||
return {
|
return {
|
||||||
"verify_signature": True,
|
"verify_signature": True,
|
||||||
"verify_exp": True,
|
"verify_exp": True,
|
||||||
|
@ -40,16 +43,17 @@ class PyJWT:
|
||||||
|
|
||||||
def encode(
|
def encode(
|
||||||
self,
|
self,
|
||||||
payload: Dict[str, Any],
|
payload: dict[str, Any],
|
||||||
key: str,
|
key: AllowedPrivateKeys | str | bytes,
|
||||||
algorithm: Optional[str] = "HS256",
|
algorithm: str | None = "HS256",
|
||||||
headers: Optional[Dict[str, Any]] = None,
|
headers: dict[str, Any] | None = None,
|
||||||
json_encoder: Optional[Type[json.JSONEncoder]] = None,
|
json_encoder: type[json.JSONEncoder] | None = None,
|
||||||
|
sort_headers: bool = True,
|
||||||
) -> str:
|
) -> str:
|
||||||
# Check that we get a mapping
|
# Check that we get a dict
|
||||||
if not isinstance(payload, Mapping):
|
if not isinstance(payload, dict):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
"Expecting a mapping object, as JWT only supports "
|
"Expecting a dict object, as JWT only supports "
|
||||||
"JSON objects as payloads."
|
"JSON objects as payloads."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -60,30 +64,57 @@ class PyJWT:
|
||||||
if isinstance(payload.get(time_claim), datetime):
|
if isinstance(payload.get(time_claim), datetime):
|
||||||
payload[time_claim] = timegm(payload[time_claim].utctimetuple())
|
payload[time_claim] = timegm(payload[time_claim].utctimetuple())
|
||||||
|
|
||||||
json_payload = json.dumps(
|
json_payload = self._encode_payload(
|
||||||
payload, separators=(",", ":"), cls=json_encoder
|
payload,
|
||||||
).encode("utf-8")
|
headers=headers,
|
||||||
|
json_encoder=json_encoder,
|
||||||
|
)
|
||||||
|
|
||||||
return api_jws.encode(json_payload, key, algorithm, headers, json_encoder)
|
return api_jws.encode(
|
||||||
|
json_payload,
|
||||||
|
key,
|
||||||
|
algorithm,
|
||||||
|
headers,
|
||||||
|
json_encoder,
|
||||||
|
sort_headers=sort_headers,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _encode_payload(
|
||||||
|
self,
|
||||||
|
payload: dict[str, Any],
|
||||||
|
headers: dict[str, Any] | None = None,
|
||||||
|
json_encoder: type[json.JSONEncoder] | None = None,
|
||||||
|
) -> bytes:
|
||||||
|
"""
|
||||||
|
Encode a given payload to the bytes to be signed.
|
||||||
|
|
||||||
|
This method is intended to be overridden by subclasses that need to
|
||||||
|
encode the payload in a different way, e.g. compress the payload.
|
||||||
|
"""
|
||||||
|
return json.dumps(
|
||||||
|
payload,
|
||||||
|
separators=(",", ":"),
|
||||||
|
cls=json_encoder,
|
||||||
|
).encode("utf-8")
|
||||||
|
|
||||||
def decode_complete(
|
def decode_complete(
|
||||||
self,
|
self,
|
||||||
jwt: str,
|
jwt: str | bytes,
|
||||||
key: str = "",
|
key: AllowedPublicKeys | str | bytes = "",
|
||||||
algorithms: Optional[List[str]] = None,
|
algorithms: list[str] | None = None,
|
||||||
options: Optional[Dict[str, Any]] = None,
|
options: dict[str, Any] | None = None,
|
||||||
# deprecated arg, remove in pyjwt3
|
# deprecated arg, remove in pyjwt3
|
||||||
verify: Optional[bool] = None,
|
verify: bool | None = None,
|
||||||
# could be used as passthrough to api_jws, consider removal in pyjwt3
|
# could be used as passthrough to api_jws, consider removal in pyjwt3
|
||||||
detached_payload: Optional[bytes] = None,
|
detached_payload: bytes | None = None,
|
||||||
# passthrough arguments to _validate_claims
|
# passthrough arguments to _validate_claims
|
||||||
# consider putting in options
|
# consider putting in options
|
||||||
audience: Optional[Union[str, Iterable[str]]] = None,
|
audience: str | Iterable[str] | None = None,
|
||||||
issuer: Optional[str] = None,
|
issuer: str | None = None,
|
||||||
leeway: Union[int, float, timedelta] = 0,
|
leeway: float | timedelta = 0,
|
||||||
# kwargs
|
# kwargs
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
if kwargs:
|
if kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"passing additional kwargs to decode_complete() is deprecated "
|
"passing additional kwargs to decode_complete() is deprecated "
|
||||||
|
@ -125,12 +156,7 @@ class PyJWT:
|
||||||
detached_payload=detached_payload,
|
detached_payload=detached_payload,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
payload = self._decode_payload(decoded)
|
||||||
payload = json.loads(decoded["payload"])
|
|
||||||
except ValueError as e:
|
|
||||||
raise DecodeError(f"Invalid payload string: {e}")
|
|
||||||
if not isinstance(payload, dict):
|
|
||||||
raise DecodeError("Invalid payload string: must be a json object")
|
|
||||||
|
|
||||||
merged_options = {**self.options, **options}
|
merged_options = {**self.options, **options}
|
||||||
self._validate_claims(
|
self._validate_claims(
|
||||||
|
@ -140,24 +166,40 @@ class PyJWT:
|
||||||
decoded["payload"] = payload
|
decoded["payload"] = payload
|
||||||
return decoded
|
return decoded
|
||||||
|
|
||||||
|
def _decode_payload(self, decoded: dict[str, Any]) -> Any:
|
||||||
|
"""
|
||||||
|
Decode the payload from a JWS dictionary (payload, signature, header).
|
||||||
|
|
||||||
|
This method is intended to be overridden by subclasses that need to
|
||||||
|
decode the payload in a different way, e.g. decompress compressed
|
||||||
|
payloads.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
payload = json.loads(decoded["payload"])
|
||||||
|
except ValueError as e:
|
||||||
|
raise DecodeError(f"Invalid payload string: {e}")
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
raise DecodeError("Invalid payload string: must be a json object")
|
||||||
|
return payload
|
||||||
|
|
||||||
def decode(
|
def decode(
|
||||||
self,
|
self,
|
||||||
jwt: str,
|
jwt: str | bytes,
|
||||||
key: str = "",
|
key: AllowedPublicKeys | str | bytes = "",
|
||||||
algorithms: Optional[List[str]] = None,
|
algorithms: list[str] | None = None,
|
||||||
options: Optional[Dict[str, Any]] = None,
|
options: dict[str, Any] | None = None,
|
||||||
# deprecated arg, remove in pyjwt3
|
# deprecated arg, remove in pyjwt3
|
||||||
verify: Optional[bool] = None,
|
verify: bool | None = None,
|
||||||
# could be used as passthrough to api_jws, consider removal in pyjwt3
|
# could be used as passthrough to api_jws, consider removal in pyjwt3
|
||||||
detached_payload: Optional[bytes] = None,
|
detached_payload: bytes | None = None,
|
||||||
# passthrough arguments to _validate_claims
|
# passthrough arguments to _validate_claims
|
||||||
# consider putting in options
|
# consider putting in options
|
||||||
audience: Optional[Union[str, Iterable[str]]] = None,
|
audience: str | Iterable[str] | None = None,
|
||||||
issuer: Optional[str] = None,
|
issuer: str | None = None,
|
||||||
leeway: Union[int, float, timedelta] = 0,
|
leeway: float | timedelta = 0,
|
||||||
# kwargs
|
# kwargs
|
||||||
**kwargs,
|
**kwargs: Any,
|
||||||
) -> Dict[str, Any]:
|
) -> Any:
|
||||||
if kwargs:
|
if kwargs:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"passing additional kwargs to decode() is deprecated "
|
"passing additional kwargs to decode() is deprecated "
|
||||||
|
@ -178,7 +220,14 @@ class PyJWT:
|
||||||
)
|
)
|
||||||
return decoded["payload"]
|
return decoded["payload"]
|
||||||
|
|
||||||
def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0):
|
def _validate_claims(
|
||||||
|
self,
|
||||||
|
payload: dict[str, Any],
|
||||||
|
options: dict[str, Any],
|
||||||
|
audience=None,
|
||||||
|
issuer=None,
|
||||||
|
leeway: float | timedelta = 0,
|
||||||
|
) -> None:
|
||||||
if isinstance(leeway, timedelta):
|
if isinstance(leeway, timedelta):
|
||||||
leeway = leeway.total_seconds()
|
leeway = leeway.total_seconds()
|
||||||
|
|
||||||
|
@ -187,7 +236,7 @@ class PyJWT:
|
||||||
|
|
||||||
self._validate_required_claims(payload, options)
|
self._validate_required_claims(payload, options)
|
||||||
|
|
||||||
now = timegm(datetime.now(tz=timezone.utc).utctimetuple())
|
now = datetime.now(tz=timezone.utc).timestamp()
|
||||||
|
|
||||||
if "iat" in payload and options["verify_iat"]:
|
if "iat" in payload and options["verify_iat"]:
|
||||||
self._validate_iat(payload, now, leeway)
|
self._validate_iat(payload, now, leeway)
|
||||||
|
@ -202,23 +251,38 @@ class PyJWT:
|
||||||
self._validate_iss(payload, issuer)
|
self._validate_iss(payload, issuer)
|
||||||
|
|
||||||
if options["verify_aud"]:
|
if options["verify_aud"]:
|
||||||
self._validate_aud(payload, audience)
|
self._validate_aud(
|
||||||
|
payload, audience, strict=options.get("strict_aud", False)
|
||||||
|
)
|
||||||
|
|
||||||
def _validate_required_claims(self, payload, options):
|
def _validate_required_claims(
|
||||||
|
self,
|
||||||
|
payload: dict[str, Any],
|
||||||
|
options: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
for claim in options["require"]:
|
for claim in options["require"]:
|
||||||
if payload.get(claim) is None:
|
if payload.get(claim) is None:
|
||||||
raise MissingRequiredClaimError(claim)
|
raise MissingRequiredClaimError(claim)
|
||||||
|
|
||||||
def _validate_iat(self, payload, now, leeway):
|
def _validate_iat(
|
||||||
iat = payload["iat"]
|
self,
|
||||||
|
payload: dict[str, Any],
|
||||||
|
now: float,
|
||||||
|
leeway: float,
|
||||||
|
) -> None:
|
||||||
try:
|
try:
|
||||||
int(iat)
|
iat = int(payload["iat"])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.")
|
raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.")
|
||||||
if iat > (now + leeway):
|
if iat > (now + leeway):
|
||||||
raise ImmatureSignatureError("The token is not yet valid (iat)")
|
raise ImmatureSignatureError("The token is not yet valid (iat)")
|
||||||
|
|
||||||
def _validate_nbf(self, payload, now, leeway):
|
def _validate_nbf(
|
||||||
|
self,
|
||||||
|
payload: dict[str, Any],
|
||||||
|
now: float,
|
||||||
|
leeway: float,
|
||||||
|
) -> None:
|
||||||
try:
|
try:
|
||||||
nbf = int(payload["nbf"])
|
nbf = int(payload["nbf"])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
@ -227,7 +291,12 @@ class PyJWT:
|
||||||
if nbf > (now + leeway):
|
if nbf > (now + leeway):
|
||||||
raise ImmatureSignatureError("The token is not yet valid (nbf)")
|
raise ImmatureSignatureError("The token is not yet valid (nbf)")
|
||||||
|
|
||||||
def _validate_exp(self, payload, now, leeway):
|
def _validate_exp(
|
||||||
|
self,
|
||||||
|
payload: dict[str, Any],
|
||||||
|
now: float,
|
||||||
|
leeway: float,
|
||||||
|
) -> None:
|
||||||
try:
|
try:
|
||||||
exp = int(payload["exp"])
|
exp = int(payload["exp"])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
@ -236,7 +305,13 @@ class PyJWT:
|
||||||
if exp <= (now - leeway):
|
if exp <= (now - leeway):
|
||||||
raise ExpiredSignatureError("Signature has expired")
|
raise ExpiredSignatureError("Signature has expired")
|
||||||
|
|
||||||
def _validate_aud(self, payload, audience):
|
def _validate_aud(
|
||||||
|
self,
|
||||||
|
payload: dict[str, Any],
|
||||||
|
audience: str | Iterable[str] | None,
|
||||||
|
*,
|
||||||
|
strict: bool = False,
|
||||||
|
) -> None:
|
||||||
if audience is None:
|
if audience is None:
|
||||||
if "aud" not in payload or not payload["aud"]:
|
if "aud" not in payload or not payload["aud"]:
|
||||||
return
|
return
|
||||||
|
@ -251,6 +326,22 @@ class PyJWT:
|
||||||
|
|
||||||
audience_claims = payload["aud"]
|
audience_claims = payload["aud"]
|
||||||
|
|
||||||
|
# In strict mode, we forbid list matching: the supplied audience
|
||||||
|
# must be a string, and it must exactly match the audience claim.
|
||||||
|
if strict:
|
||||||
|
# Only a single audience is allowed in strict mode.
|
||||||
|
if not isinstance(audience, str):
|
||||||
|
raise InvalidAudienceError("Invalid audience (strict)")
|
||||||
|
|
||||||
|
# Only a single audience claim is allowed in strict mode.
|
||||||
|
if not isinstance(audience_claims, str):
|
||||||
|
raise InvalidAudienceError("Invalid claim format in token (strict)")
|
||||||
|
|
||||||
|
if audience != audience_claims:
|
||||||
|
raise InvalidAudienceError("Audience doesn't match (strict)")
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
if isinstance(audience_claims, str):
|
if isinstance(audience_claims, str):
|
||||||
audience_claims = [audience_claims]
|
audience_claims = [audience_claims]
|
||||||
if not isinstance(audience_claims, list):
|
if not isinstance(audience_claims, list):
|
||||||
|
@ -262,9 +353,9 @@ class PyJWT:
|
||||||
audience = [audience]
|
audience = [audience]
|
||||||
|
|
||||||
if all(aud not in audience_claims for aud in audience):
|
if all(aud not in audience_claims for aud in audience):
|
||||||
raise InvalidAudienceError("Invalid audience")
|
raise InvalidAudienceError("Audience doesn't match")
|
||||||
|
|
||||||
def _validate_iss(self, payload, issuer):
|
def _validate_iss(self, payload: dict[str, Any], issuer: Any) -> None:
|
||||||
if issuer is None:
|
if issuer is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -47,10 +47,10 @@ class InvalidAlgorithmError(InvalidTokenError):
|
||||||
|
|
||||||
|
|
||||||
class MissingRequiredClaimError(InvalidTokenError):
|
class MissingRequiredClaimError(InvalidTokenError):
|
||||||
def __init__(self, claim):
|
def __init__(self, claim: str) -> None:
|
||||||
self.claim = claim
|
self.claim = claim
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return f'Token is missing the "{self.claim}" claim'
|
return f'Token is missing the "{self.claim}" claim'
|
||||||
|
|
||||||
|
|
||||||
|
@ -64,3 +64,7 @@ class PyJWKSetError(PyJWTError):
|
||||||
|
|
||||||
class PyJWKClientError(PyJWTError):
|
class PyJWKClientError(PyJWTError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PyJWKClientConnectionError(PyJWKClientError):
|
||||||
|
pass
|
||||||
|
|
|
@ -7,8 +7,10 @@ from . import __version__ as pyjwt_version
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import cryptography
|
import cryptography
|
||||||
|
|
||||||
|
cryptography_version = cryptography.__version__
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
cryptography = None
|
cryptography_version = ""
|
||||||
|
|
||||||
|
|
||||||
def info() -> Dict[str, Dict[str, str]]:
|
def info() -> Dict[str, Dict[str, str]]:
|
||||||
|
@ -29,7 +31,7 @@ def info() -> Dict[str, Dict[str, str]]:
|
||||||
if implementation == "CPython":
|
if implementation == "CPython":
|
||||||
implementation_version = platform.python_version()
|
implementation_version = platform.python_version()
|
||||||
elif implementation == "PyPy":
|
elif implementation == "PyPy":
|
||||||
pypy_version_info = getattr(sys, "pypy_version_info")
|
pypy_version_info = sys.pypy_version_info # type: ignore[attr-defined]
|
||||||
implementation_version = (
|
implementation_version = (
|
||||||
f"{pypy_version_info.major}."
|
f"{pypy_version_info.major}."
|
||||||
f"{pypy_version_info.minor}."
|
f"{pypy_version_info.minor}."
|
||||||
|
@ -48,7 +50,7 @@ def info() -> Dict[str, Dict[str, str]]:
|
||||||
"name": implementation,
|
"name": implementation,
|
||||||
"version": implementation_version,
|
"version": implementation_version,
|
||||||
},
|
},
|
||||||
"cryptography": {"version": getattr(cryptography, "__version__", "")},
|
"cryptography": {"version": cryptography_version},
|
||||||
"pyjwt": {"version": pyjwt_version},
|
"pyjwt": {"version": pyjwt_version},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,11 +5,11 @@ from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp
|
||||||
|
|
||||||
|
|
||||||
class JWKSetCache:
|
class JWKSetCache:
|
||||||
def __init__(self, lifespan: int):
|
def __init__(self, lifespan: int) -> None:
|
||||||
self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None
|
self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None
|
||||||
self.lifespan = lifespan
|
self.lifespan = lifespan
|
||||||
|
|
||||||
def put(self, jwk_set: PyJWKSet):
|
def put(self, jwk_set: PyJWKSet) -> None:
|
||||||
if jwk_set is not None:
|
if jwk_set is not None:
|
||||||
self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set)
|
self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set)
|
||||||
else:
|
else:
|
||||||
|
@ -23,7 +23,6 @@ class JWKSetCache:
|
||||||
return self.jwk_set_with_timestamp.get_jwk_set()
|
return self.jwk_set_with_timestamp.get_jwk_set()
|
||||||
|
|
||||||
def is_expired(self) -> bool:
|
def is_expired(self) -> bool:
|
||||||
|
|
||||||
return (
|
return (
|
||||||
self.jwk_set_with_timestamp is not None
|
self.jwk_set_with_timestamp is not None
|
||||||
and self.lifespan > -1
|
and self.lifespan > -1
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
import json
|
import json
|
||||||
import urllib.request
|
import urllib.request
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Any, List, Optional
|
from ssl import SSLContext
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
from urllib.error import URLError
|
from urllib.error import URLError
|
||||||
|
|
||||||
from .api_jwk import PyJWK, PyJWKSet
|
from .api_jwk import PyJWK, PyJWKSet
|
||||||
from .api_jwt import decode_complete as decode_token
|
from .api_jwt import decode_complete as decode_token
|
||||||
from .exceptions import PyJWKClientError
|
from .exceptions import PyJWKClientConnectionError, PyJWKClientError
|
||||||
from .jwk_set_cache import JWKSetCache
|
from .jwk_set_cache import JWKSetCache
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,9 +19,17 @@ class PyJWKClient:
|
||||||
max_cached_keys: int = 16,
|
max_cached_keys: int = 16,
|
||||||
cache_jwk_set: bool = True,
|
cache_jwk_set: bool = True,
|
||||||
lifespan: int = 300,
|
lifespan: int = 300,
|
||||||
|
headers: Optional[Dict[str, Any]] = None,
|
||||||
|
timeout: int = 30,
|
||||||
|
ssl_context: Optional[SSLContext] = None,
|
||||||
):
|
):
|
||||||
|
if headers is None:
|
||||||
|
headers = {}
|
||||||
self.uri = uri
|
self.uri = uri
|
||||||
self.jwk_set_cache: Optional[JWKSetCache] = None
|
self.jwk_set_cache: Optional[JWKSetCache] = None
|
||||||
|
self.headers = headers
|
||||||
|
self.timeout = timeout
|
||||||
|
self.ssl_context = ssl_context
|
||||||
|
|
||||||
if cache_jwk_set:
|
if cache_jwk_set:
|
||||||
# Init jwt set cache with default or given lifespan.
|
# Init jwt set cache with default or given lifespan.
|
||||||
|
@ -41,10 +50,15 @@ class PyJWKClient:
|
||||||
def fetch_data(self) -> Any:
|
def fetch_data(self) -> Any:
|
||||||
jwk_set: Any = None
|
jwk_set: Any = None
|
||||||
try:
|
try:
|
||||||
with urllib.request.urlopen(self.uri) as response:
|
r = urllib.request.Request(url=self.uri, headers=self.headers)
|
||||||
|
with urllib.request.urlopen(
|
||||||
|
r, timeout=self.timeout, context=self.ssl_context
|
||||||
|
) as response:
|
||||||
jwk_set = json.load(response)
|
jwk_set = json.load(response)
|
||||||
except URLError as e:
|
except (URLError, TimeoutError) as e:
|
||||||
raise PyJWKClientError(f'Fail to fetch data from the url, err: "{e}"')
|
raise PyJWKClientConnectionError(
|
||||||
|
f'Fail to fetch data from the url, err: "{e}"'
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return jwk_set
|
return jwk_set
|
||||||
finally:
|
finally:
|
||||||
|
@ -59,6 +73,9 @@ class PyJWKClient:
|
||||||
if data is None:
|
if data is None:
|
||||||
data = self.fetch_data()
|
data = self.fetch_data()
|
||||||
|
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise PyJWKClientError("The JWKS endpoint did not return a JSON object")
|
||||||
|
|
||||||
return PyJWKSet.from_dict(data)
|
return PyJWKSet.from_dict(data)
|
||||||
|
|
||||||
def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:
|
def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:
|
||||||
|
|
5
lib/jwt/types.py
Normal file
5
lib/jwt/types.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
from typing import Any, Callable, Dict
|
||||||
|
|
||||||
|
JWKDict = Dict[str, Any]
|
||||||
|
|
||||||
|
HashlibHash = Callable[..., Any]
|
|
@ -10,10 +10,10 @@ try:
|
||||||
encode_dss_signature,
|
encode_dss_signature,
|
||||||
)
|
)
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
EllipticCurve = None
|
pass
|
||||||
|
|
||||||
|
|
||||||
def force_bytes(value: Union[str, bytes]) -> bytes:
|
def force_bytes(value: Union[bytes, str]) -> bytes:
|
||||||
if isinstance(value, str):
|
if isinstance(value, str):
|
||||||
return value.encode("utf-8")
|
return value.encode("utf-8")
|
||||||
elif isinstance(value, bytes):
|
elif isinstance(value, bytes):
|
||||||
|
@ -22,16 +22,15 @@ def force_bytes(value: Union[str, bytes]) -> bytes:
|
||||||
raise TypeError("Expected a string value")
|
raise TypeError("Expected a string value")
|
||||||
|
|
||||||
|
|
||||||
def base64url_decode(input: Union[str, bytes]) -> bytes:
|
def base64url_decode(input: Union[bytes, str]) -> bytes:
|
||||||
if isinstance(input, str):
|
input_bytes = force_bytes(input)
|
||||||
input = input.encode("ascii")
|
|
||||||
|
|
||||||
rem = len(input) % 4
|
rem = len(input_bytes) % 4
|
||||||
|
|
||||||
if rem > 0:
|
if rem > 0:
|
||||||
input += b"=" * (4 - rem)
|
input_bytes += b"=" * (4 - rem)
|
||||||
|
|
||||||
return base64.urlsafe_b64decode(input)
|
return base64.urlsafe_b64decode(input_bytes)
|
||||||
|
|
||||||
|
|
||||||
def base64url_encode(input: bytes) -> bytes:
|
def base64url_encode(input: bytes) -> bytes:
|
||||||
|
@ -50,11 +49,8 @@ def to_base64url_uint(val: int) -> bytes:
|
||||||
return base64url_encode(int_bytes)
|
return base64url_encode(int_bytes)
|
||||||
|
|
||||||
|
|
||||||
def from_base64url_uint(val: Union[str, bytes]) -> int:
|
def from_base64url_uint(val: Union[bytes, str]) -> int:
|
||||||
if isinstance(val, str):
|
data = base64url_decode(force_bytes(val))
|
||||||
val = val.encode("ascii")
|
|
||||||
|
|
||||||
data = base64url_decode(val)
|
|
||||||
return int.from_bytes(data, byteorder="big")
|
return int.from_bytes(data, byteorder="big")
|
||||||
|
|
||||||
|
|
||||||
|
@ -78,7 +74,7 @@ def bytes_from_int(val: int) -> bytes:
|
||||||
return val.to_bytes(byte_length, "big", signed=False)
|
return val.to_bytes(byte_length, "big", signed=False)
|
||||||
|
|
||||||
|
|
||||||
def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes:
|
def der_to_raw_signature(der_sig: bytes, curve: "EllipticCurve") -> bytes:
|
||||||
num_bits = curve.key_size
|
num_bits = curve.key_size
|
||||||
num_bytes = (num_bits + 7) // 8
|
num_bytes = (num_bits + 7) // 8
|
||||||
|
|
||||||
|
@ -87,7 +83,7 @@ def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes:
|
||||||
return number_to_bytes(r, num_bytes) + number_to_bytes(s, num_bytes)
|
return number_to_bytes(r, num_bytes) + number_to_bytes(s, num_bytes)
|
||||||
|
|
||||||
|
|
||||||
def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes:
|
def raw_to_der_signature(raw_sig: bytes, curve: "EllipticCurve") -> bytes:
|
||||||
num_bits = curve.key_size
|
num_bits = curve.key_size
|
||||||
num_bytes = (num_bits + 7) // 8
|
num_bytes = (num_bits + 7) // 8
|
||||||
|
|
||||||
|
@ -97,7 +93,7 @@ def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes:
|
||||||
r = bytes_to_number(raw_sig[:num_bytes])
|
r = bytes_to_number(raw_sig[:num_bytes])
|
||||||
s = bytes_to_number(raw_sig[num_bytes:])
|
s = bytes_to_number(raw_sig[num_bytes:])
|
||||||
|
|
||||||
return encode_dss_signature(r, s)
|
return bytes(encode_dss_signature(r, s))
|
||||||
|
|
||||||
|
|
||||||
# Based on https://github.com/hynek/pem/blob/7ad94db26b0bc21d10953f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252
|
# Based on https://github.com/hynek/pem/blob/7ad94db26b0bc21d10953f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252
|
||||||
|
|
|
@ -31,7 +31,7 @@ paho-mqtt==1.6.1
|
||||||
plexapi==4.13.4
|
plexapi==4.13.4
|
||||||
portend==3.1.0
|
portend==3.1.0
|
||||||
profilehooks==1.12.0
|
profilehooks==1.12.0
|
||||||
PyJWT==2.6.0
|
PyJWT==2.8.0
|
||||||
pyparsing==3.0.9
|
pyparsing==3.0.9
|
||||||
python-dateutil==2.8.2
|
python-dateutil==2.8.2
|
||||||
python-twitter==3.5
|
python-twitter==3.5
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue