Bump pyjwt from 2.4.0 to 2.6.0 (#1897)

* Bump pyjwt from 2.4.0 to 2.6.0

Bumps [pyjwt](https://github.com/jpadilla/pyjwt) from 2.4.0 to 2.6.0.
- [Release notes](https://github.com/jpadilla/pyjwt/releases)
- [Changelog](https://github.com/jpadilla/pyjwt/blob/master/CHANGELOG.rst)
- [Commits](https://github.com/jpadilla/pyjwt/compare/2.4.0...2.6.0)

---
updated-dependencies:
- dependency-name: pyjwt
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Update pyjwt==2.6.0

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com>

[skip ci]
This commit is contained in:
dependabot[bot] 2022-11-14 11:27:25 -08:00 committed by GitHub
parent 79cf61c53e
commit 60da559332
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 326 additions and 103 deletions

View file

@ -1,6 +1,7 @@
from .api_jwk import PyJWK, PyJWKSet from .api_jwk import PyJWK, PyJWKSet
from .api_jws import ( from .api_jws import (
PyJWS, PyJWS,
get_algorithm_by_name,
get_unverified_header, get_unverified_header,
register_algorithm, register_algorithm,
unregister_algorithm, unregister_algorithm,
@ -25,7 +26,7 @@ from .exceptions import (
) )
from .jwks_client import PyJWKClient from .jwks_client import PyJWKClient
__version__ = "2.4.0" __version__ = "2.6.0"
__title__ = "PyJWT" __title__ = "PyJWT"
__description__ = "JSON Web Token implementation in Python" __description__ = "JSON Web Token implementation in Python"
@ -51,6 +52,7 @@ __all__ = [
"get_unverified_header", "get_unverified_header",
"register_algorithm", "register_algorithm",
"unregister_algorithm", "unregister_algorithm",
"get_algorithm_by_name",
# Exceptions # Exceptions
"DecodeError", "DecodeError",
"ExpiredSignatureError", "ExpiredSignatureError",

View file

@ -439,6 +439,41 @@ if has_crypto:
except InvalidSignature: except InvalidSignature:
return False return False
@staticmethod
def to_jwk(key_obj):
if isinstance(key_obj, EllipticCurvePrivateKey):
public_numbers = key_obj.public_key().public_numbers()
elif isinstance(key_obj, EllipticCurvePublicKey):
public_numbers = key_obj.public_numbers()
else:
raise InvalidKeyError("Not a public or private key")
if isinstance(key_obj.curve, ec.SECP256R1):
crv = "P-256"
elif isinstance(key_obj.curve, ec.SECP384R1):
crv = "P-384"
elif isinstance(key_obj.curve, ec.SECP521R1):
crv = "P-521"
elif isinstance(key_obj.curve, ec.SECP256K1):
crv = "secp256k1"
else:
raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
obj = {
"kty": "EC",
"crv": crv,
"x": to_base64url_uint(public_numbers.x).decode(),
"y": to_base64url_uint(public_numbers.y).decode(),
}
if isinstance(key_obj, EllipticCurvePrivateKey):
obj["d"] = to_base64url_uint(
key_obj.private_numbers().private_value
).decode()
return json.dumps(obj)
@staticmethod @staticmethod
def from_jwk(jwk): def from_jwk(jwk):
try: try:
@ -574,7 +609,7 @@ if has_crypto:
Sign a message ``msg`` using the EdDSA private key ``key`` Sign a message ``msg`` using the EdDSA private key ``key``
:param str|bytes msg: Message to sign :param str|bytes msg: Message to sign
:param Ed25519PrivateKey}Ed448PrivateKey key: A :class:`.Ed25519PrivateKey` :param Ed25519PrivateKey}Ed448PrivateKey key: A :class:`.Ed25519PrivateKey`
or :class:`.Ed448PrivateKey` iinstance or :class:`.Ed448PrivateKey` isinstance
:return bytes signature: The signature, as bytes :return bytes signature: The signature, as bytes
""" """
msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg

View file

@ -1,4 +1,7 @@
from __future__ import annotations
import json import json
import time
from .algorithms import get_default_algorithms from .algorithms import get_default_algorithms
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError
@ -74,17 +77,24 @@ class PyJWK:
class PyJWKSet: class PyJWKSet:
def __init__(self, keys): def __init__(self, keys: list[dict]) -> None:
self.keys = [] self.keys = []
if not keys or not isinstance(keys, list): if not keys:
raise PyJWKSetError("Invalid JWK Set value")
if len(keys) == 0:
raise PyJWKSetError("The JWK Set did not contain any keys") raise PyJWKSetError("The JWK Set did not contain any keys")
if not isinstance(keys, list):
raise PyJWKSetError("Invalid JWK Set value")
for key in keys: for key in keys:
try:
self.keys.append(PyJWK(key)) self.keys.append(PyJWK(key))
except PyJWKError:
# skip unusable keys
continue
if len(self.keys) == 0:
raise PyJWKSetError("The JWK Set did not contain any usable keys")
@staticmethod @staticmethod
def from_dict(obj): def from_dict(obj):
@ -101,3 +111,15 @@ class PyJWKSet:
if key.key_id == kid: if key.key_id == kid:
return key return key
raise KeyError(f"keyset has no key for kid: {kid}") raise KeyError(f"keyset has no key for kid: {kid}")
class PyJWTSetWithTimestamp:
def __init__(self, jwk_set: PyJWKSet):
self.jwk_set = jwk_set
self.timestamp = time.monotonic()
def get_jwk_set(self):
return self.jwk_set
def get_timestamp(self):
return self.timestamp

View file

@ -1,7 +1,9 @@
from __future__ import annotations
import binascii import binascii
import json import json
from collections.abc import Mapping import warnings
from typing import Any, Dict, List, Optional, Type from typing import Any, Type
from .algorithms import ( from .algorithms import (
Algorithm, Algorithm,
@ -16,12 +18,13 @@ from .exceptions import (
InvalidTokenError, InvalidTokenError,
) )
from .utils import base64url_decode, base64url_encode from .utils import base64url_decode, base64url_encode
from .warnings import RemovedInPyjwt3Warning
class PyJWS: class PyJWS:
header_typ = "JWT" header_typ = "JWT"
def __init__(self, algorithms=None, options=None): def __init__(self, algorithms=None, options=None) -> None:
self._algorithms = get_default_algorithms() self._algorithms = get_default_algorithms()
self._valid_algs = ( self._valid_algs = (
set(algorithms) if algorithms is not None else set(self._algorithms) set(algorithms) if algorithms is not None else set(self._algorithms)
@ -37,10 +40,10 @@ class PyJWS:
self.options = {**self._get_default_options(), **options} self.options = {**self._get_default_options(), **options}
@staticmethod @staticmethod
def _get_default_options(): def _get_default_options() -> dict[str, bool]:
return {"verify_signature": True} return {"verify_signature": True}
def register_algorithm(self, alg_id, alg_obj): def register_algorithm(self, alg_id: str, alg_obj: Algorithm) -> None:
""" """
Registers a new Algorithm for use when creating and verifying tokens. Registers a new Algorithm for use when creating and verifying tokens.
""" """
@ -53,7 +56,7 @@ class PyJWS:
self._algorithms[alg_id] = alg_obj self._algorithms[alg_id] = alg_obj
self._valid_algs.add(alg_id) self._valid_algs.add(alg_id)
def unregister_algorithm(self, alg_id): def unregister_algorithm(self, alg_id: str) -> None:
""" """
Unregisters an Algorithm for use when creating and verifying tokens Unregisters an Algorithm for use when creating and verifying tokens
Throws KeyError if algorithm is not registered. Throws KeyError if algorithm is not registered.
@ -67,38 +70,55 @@ class PyJWS:
del self._algorithms[alg_id] del self._algorithms[alg_id]
self._valid_algs.remove(alg_id) self._valid_algs.remove(alg_id)
def get_algorithms(self): def get_algorithms(self) -> list[str]:
""" """
Returns a list of supported values for the 'alg' parameter. Returns a list of supported values for the 'alg' parameter.
""" """
return list(self._valid_algs) return list(self._valid_algs)
def get_algorithm_by_name(self, alg_name: str) -> Algorithm:
"""
For a given string name, return the matching Algorithm object.
Example usage:
>>> jws_obj.get_algorithm_by_name("RS256")
"""
try:
return self._algorithms[alg_name]
except KeyError as e:
if not has_crypto and alg_name in requires_cryptography:
raise NotImplementedError(
f"Algorithm '{alg_name}' could not be found. Do you have cryptography installed?"
) from e
raise NotImplementedError("Algorithm not supported") from e
def encode( def encode(
self, self,
payload: bytes, payload: bytes,
key: str, key: str,
algorithm: Optional[str] = "HS256", algorithm: str | None = "HS256",
headers: Optional[Dict] = None, headers: dict[str, Any] | None = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None, json_encoder: Type[json.JSONEncoder] | None = None,
is_payload_detached: bool = False, is_payload_detached: bool = False,
) -> str: ) -> str:
segments = [] segments = []
if algorithm is None: # declare a new var to narrow the type for type checkers
algorithm = "none" algorithm_: str = algorithm if algorithm is not None else "none"
# Prefer headers values if present to function parameters. # Prefer headers values if present to function parameters.
if headers: if headers:
headers_alg = headers.get("alg") headers_alg = headers.get("alg")
if headers_alg: if headers_alg:
algorithm = headers["alg"] algorithm_ = headers["alg"]
headers_b64 = headers.get("b64") headers_b64 = headers.get("b64")
if headers_b64 is False: if headers_b64 is False:
is_payload_detached = True is_payload_detached = True
# Header # Header
header = {"typ": self.header_typ, "alg": algorithm} # type: Dict[str, Any] header: dict[str, Any] = {"typ": self.header_typ, "alg": algorithm_}
if headers: if headers:
self._validate_headers(headers) self._validate_headers(headers)
@ -113,8 +133,9 @@ class PyJWS:
# True is the standard value for b64, so no need for it # True is the standard value for b64, so no need for it
del header["b64"] del header["b64"]
# Fix for headers misorder - issue #715
json_header = json.dumps( json_header = json.dumps(
header, separators=(",", ":"), cls=json_encoder header, separators=(",", ":"), cls=json_encoder, sort_keys=True
).encode() ).encode()
segments.append(base64url_encode(json_header)) segments.append(base64url_encode(json_header))
@ -128,18 +149,10 @@ class PyJWS:
# Segments # Segments
signing_input = b".".join(segments) signing_input = b".".join(segments)
try: alg_obj = self.get_algorithm_by_name(algorithm_)
alg_obj = self._algorithms[algorithm]
key = alg_obj.prepare_key(key) key = alg_obj.prepare_key(key)
signature = alg_obj.sign(signing_input, key) signature = alg_obj.sign(signing_input, key)
except KeyError as e:
if not has_crypto and algorithm in requires_cryptography:
raise NotImplementedError(
f"Algorithm '{algorithm}' could not be found. Do you have cryptography installed?"
) from e
raise NotImplementedError("Algorithm not supported") from e
segments.append(base64url_encode(signature)) segments.append(base64url_encode(signature))
# Don't put the payload content inside the encoded token when detached # Don't put the payload content inside the encoded token when detached
@ -153,11 +166,18 @@ class PyJWS:
self, self,
jwt: str, jwt: str,
key: str = "", key: str = "",
algorithms: Optional[List[str]] = None, algorithms: list[str] | None = None,
options: Optional[Dict] = None, options: dict[str, Any] | None = None,
detached_payload: Optional[bytes] = None, detached_payload: bytes | None = None,
**kwargs, **kwargs,
) -> Dict[str, Any]: ) -> dict[str, Any]:
if kwargs:
warnings.warn(
"passing additional kwargs to decode_complete() is deprecated "
"and will be removed in pyjwt version 3. "
f"Unsupported kwargs: {tuple(kwargs.keys())}",
RemovedInPyjwt3Warning,
)
if options is None: if options is None:
options = {} options = {}
merged_options = {**self.options, **options} merged_options = {**self.options, **options}
@ -191,14 +211,24 @@ class PyJWS:
self, self,
jwt: str, jwt: str,
key: str = "", key: str = "",
algorithms: Optional[List[str]] = None, algorithms: list[str] | None = None,
options: Optional[Dict] = None, options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
**kwargs, **kwargs,
) -> str: ) -> str:
decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs) if kwargs:
warnings.warn(
"passing additional kwargs to decode() is deprecated "
"and will be removed in pyjwt version 3. "
f"Unsupported kwargs: {tuple(kwargs.keys())}",
RemovedInPyjwt3Warning,
)
decoded = self.decode_complete(
jwt, key, algorithms, options, detached_payload=detached_payload
)
return decoded["payload"] return decoded["payload"]
def get_unverified_header(self, jwt): def get_unverified_header(self, jwt: str | bytes) -> dict:
"""Returns back the JWT header parameters as a dict() """Returns back the JWT header parameters as a dict()
Note: The signature is not verified so the header parameters Note: The signature is not verified so the header parameters
@ -209,7 +239,7 @@ class PyJWS:
return headers return headers
def _load(self, jwt): def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]:
if isinstance(jwt, str): if isinstance(jwt, str):
jwt = jwt.encode("utf-8") jwt = jwt.encode("utf-8")
@ -232,7 +262,7 @@ class PyJWS:
except ValueError as e: except ValueError as e:
raise DecodeError(f"Invalid header string: {e}") from e raise DecodeError(f"Invalid header string: {e}") from e
if not isinstance(header, Mapping): if not isinstance(header, dict):
raise DecodeError("Invalid header string: must be a json object") raise DecodeError("Invalid header string: must be a json object")
try: try:
@ -249,33 +279,32 @@ class PyJWS:
def _verify_signature( def _verify_signature(
self, self,
signing_input, signing_input: bytes,
header, header: dict,
signature, signature: bytes,
key="", key: str = "",
algorithms=None, algorithms: list[str] | None = None,
): ) -> None:
alg = header.get("alg") alg = header.get("alg")
if algorithms is not None and alg not in algorithms: if not alg or (algorithms is not None and alg not in algorithms):
raise InvalidAlgorithmError("The specified alg value is not allowed") raise InvalidAlgorithmError("The specified alg value is not allowed")
try: try:
alg_obj = self._algorithms[alg] alg_obj = self.get_algorithm_by_name(alg)
except NotImplementedError as e:
raise InvalidAlgorithmError("Algorithm not supported") from e
key = alg_obj.prepare_key(key) key = alg_obj.prepare_key(key)
if not alg_obj.verify(signing_input, key, signature): if not alg_obj.verify(signing_input, key, signature):
raise InvalidSignatureError("Signature verification failed") raise InvalidSignatureError("Signature verification failed")
except KeyError as e: def _validate_headers(self, headers: dict[str, Any]) -> None:
raise InvalidAlgorithmError("Algorithm not supported") from e
def _validate_headers(self, headers):
if "kid" in headers: if "kid" in headers:
self._validate_kid(headers["kid"]) self._validate_kid(headers["kid"])
def _validate_kid(self, kid): def _validate_kid(self, kid: str) -> None:
if not isinstance(kid, str): if not isinstance(kid, str):
raise InvalidTokenError("Key ID header parameter must be a string") raise InvalidTokenError("Key ID header parameter must be a string")
@ -286,4 +315,5 @@ decode_complete = _jws_global_obj.decode_complete
decode = _jws_global_obj.decode decode = _jws_global_obj.decode
register_algorithm = _jws_global_obj.register_algorithm register_algorithm = _jws_global_obj.register_algorithm
unregister_algorithm = _jws_global_obj.unregister_algorithm unregister_algorithm = _jws_global_obj.unregister_algorithm
get_algorithm_by_name = _jws_global_obj.get_algorithm_by_name
get_unverified_header = _jws_global_obj.get_unverified_header get_unverified_header = _jws_global_obj.get_unverified_header

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import json import json
import warnings import warnings
from calendar import timegm from calendar import timegm
@ -15,6 +17,7 @@ from .exceptions import (
InvalidIssuerError, InvalidIssuerError,
MissingRequiredClaimError, MissingRequiredClaimError,
) )
from .warnings import RemovedInPyjwt3Warning
class PyJWT: class PyJWT:
@ -40,7 +43,7 @@ class PyJWT:
payload: Dict[str, Any], payload: Dict[str, Any],
key: str, key: str,
algorithm: Optional[str] = "HS256", algorithm: Optional[str] = "HS256",
headers: Optional[Dict] = None, headers: Optional[Dict[str, Any]] = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None, json_encoder: Optional[Type[json.JSONEncoder]] = None,
) -> str: ) -> str:
# Check that we get a mapping # Check that we get a mapping
@ -68,16 +71,33 @@ class PyJWT:
jwt: str, jwt: str,
key: str = "", key: str = "",
algorithms: Optional[List[str]] = None, algorithms: Optional[List[str]] = None,
options: Optional[Dict] = None, options: Optional[Dict[str, Any]] = None,
# deprecated arg, remove in pyjwt3
verify: Optional[bool] = None,
# could be used as passthrough to api_jws, consider removal in pyjwt3
detached_payload: Optional[bytes] = None,
# passthrough arguments to _validate_claims
# consider putting in options
audience: Optional[Union[str, Iterable[str]]] = None,
issuer: Optional[str] = None,
leeway: Union[int, float, timedelta] = 0,
# kwargs
**kwargs, **kwargs,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
if kwargs:
warnings.warn(
"passing additional kwargs to decode_complete() is deprecated "
"and will be removed in pyjwt version 3. "
f"Unsupported kwargs: {tuple(kwargs.keys())}",
RemovedInPyjwt3Warning,
)
options = dict(options or {}) # shallow-copy or initialize an empty dict options = dict(options or {}) # shallow-copy or initialize an empty dict
options.setdefault("verify_signature", True) options.setdefault("verify_signature", True)
# If the user has set the legacy `verify` argument, and it doesn't match # If the user has set the legacy `verify` argument, and it doesn't match
# what the relevant `options` entry for the argument is, inform the user # what the relevant `options` entry for the argument is, inform the user
# that they're likely making a mistake. # that they're likely making a mistake.
if "verify" in kwargs and kwargs["verify"] != options["verify_signature"]: if verify is not None and verify != options["verify_signature"]:
warnings.warn( warnings.warn(
"The `verify` argument to `decode` does nothing in PyJWT 2.0 and newer. " "The `verify` argument to `decode` does nothing in PyJWT 2.0 and newer. "
"The equivalent is setting `verify_signature` to False in the `options` dictionary. " "The equivalent is setting `verify_signature` to False in the `options` dictionary. "
@ -102,7 +122,7 @@ class PyJWT:
key=key, key=key,
algorithms=algorithms, algorithms=algorithms,
options=options, options=options,
**kwargs, detached_payload=detached_payload,
) )
try: try:
@ -113,7 +133,9 @@ class PyJWT:
raise DecodeError("Invalid payload string: must be a json object") raise DecodeError("Invalid payload string: must be a json object")
merged_options = {**self.options, **options} merged_options = {**self.options, **options}
self._validate_claims(payload, merged_options, **kwargs) self._validate_claims(
payload, merged_options, audience=audience, issuer=issuer, leeway=leeway
)
decoded["payload"] = payload decoded["payload"] = payload
return decoded return decoded
@ -123,20 +145,45 @@ class PyJWT:
jwt: str, jwt: str,
key: str = "", key: str = "",
algorithms: Optional[List[str]] = None, algorithms: Optional[List[str]] = None,
options: Optional[Dict] = None, options: Optional[Dict[str, Any]] = None,
# deprecated arg, remove in pyjwt3
verify: Optional[bool] = None,
# could be used as passthrough to api_jws, consider removal in pyjwt3
detached_payload: Optional[bytes] = None,
# passthrough arguments to _validate_claims
# consider putting in options
audience: Optional[Union[str, Iterable[str]]] = None,
issuer: Optional[str] = None,
leeway: Union[int, float, timedelta] = 0,
# kwargs
**kwargs, **kwargs,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
decoded = self.decode_complete(jwt, key, algorithms, options, **kwargs) if kwargs:
warnings.warn(
"passing additional kwargs to decode() is deprecated "
"and will be removed in pyjwt version 3. "
f"Unsupported kwargs: {tuple(kwargs.keys())}",
RemovedInPyjwt3Warning,
)
decoded = self.decode_complete(
jwt,
key,
algorithms,
options,
verify=verify,
detached_payload=detached_payload,
audience=audience,
issuer=issuer,
leeway=leeway,
)
return decoded["payload"] return decoded["payload"]
def _validate_claims( def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0):
self, payload, options, audience=None, issuer=None, leeway=0, **kwargs
):
if isinstance(leeway, timedelta): if isinstance(leeway, timedelta):
leeway = leeway.total_seconds() leeway = leeway.total_seconds()
if not isinstance(audience, (bytes, str, type(None), Iterable)): if audience is not None and not isinstance(audience, (str, Iterable)):
raise TypeError("audience must be a string, iterable, or None") raise TypeError("audience must be a string, iterable or None")
self._validate_required_claims(payload, options) self._validate_required_claims(payload, options)
@ -163,10 +210,13 @@ class PyJWT:
raise MissingRequiredClaimError(claim) raise MissingRequiredClaimError(claim)
def _validate_iat(self, payload, now, leeway): def _validate_iat(self, payload, now, leeway):
iat = payload["iat"]
try: try:
int(payload["iat"]) int(iat)
except ValueError: except ValueError:
raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.") raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.")
if iat > (now + leeway):
raise ImmatureSignatureError("The token is not yet valid (iat)")
def _validate_nbf(self, payload, now, leeway): def _validate_nbf(self, payload, now, leeway):
try: try:
@ -183,7 +233,7 @@ class PyJWT:
except ValueError: except ValueError:
raise DecodeError("Expiration Time claim (exp) must be an" " integer.") raise DecodeError("Expiration Time claim (exp) must be an" " integer.")
if exp < (now - leeway): if exp <= (now - leeway):
raise ExpiredSignatureError("Signature has expired") raise ExpiredSignatureError("Signature has expired")
def _validate_aud(self, payload, audience): def _validate_aud(self, payload, audience):

View file

@ -1,16 +1,17 @@
import json import json
import platform import platform
import sys import sys
from typing import Dict
from . import __version__ as pyjwt_version from . import __version__ as pyjwt_version
try: try:
import cryptography import cryptography
except ModuleNotFoundError: except ModuleNotFoundError:
cryptography = None # type: ignore cryptography = None
def info(): def info() -> Dict[str, Dict[str, str]]:
""" """
Generate information for a bug report. Generate information for a bug report.
Based on the requests package help utility module. Based on the requests package help utility module.
@ -28,14 +29,15 @@ def info():
if implementation == "CPython": if implementation == "CPython":
implementation_version = platform.python_version() implementation_version = platform.python_version()
elif implementation == "PyPy": elif implementation == "PyPy":
pypy_version_info = getattr(sys, "pypy_version_info")
implementation_version = ( implementation_version = (
f"{sys.pypy_version_info.major}." f"{pypy_version_info.major}."
f"{sys.pypy_version_info.minor}." f"{pypy_version_info.minor}."
f"{sys.pypy_version_info.micro}" f"{pypy_version_info.micro}"
) )
if sys.pypy_version_info.releaselevel != "final": if pypy_version_info.releaselevel != "final":
implementation_version = "".join( implementation_version = "".join(
[implementation_version, sys.pypy_version_info.releaselevel] [implementation_version, pypy_version_info.releaselevel]
) )
else: else:
implementation_version = "Unknown" implementation_version = "Unknown"
@ -51,7 +53,7 @@ def info():
} }
def main(): def main() -> None:
"""Pretty-print the bug information as JSON.""" """Pretty-print the bug information as JSON."""
print(json.dumps(info(), sort_keys=True, indent=2)) print(json.dumps(info(), sort_keys=True, indent=2))

32
lib/jwt/jwk_set_cache.py Normal file
View file

@ -0,0 +1,32 @@
import time
from typing import Optional
from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp
class JWKSetCache:
def __init__(self, lifespan: int):
self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None
self.lifespan = lifespan
def put(self, jwk_set: PyJWKSet):
if jwk_set is not None:
self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set)
else:
# clear cache
self.jwk_set_with_timestamp = None
def get(self) -> Optional[PyJWKSet]:
if self.jwk_set_with_timestamp is None or self.is_expired():
return None
return self.jwk_set_with_timestamp.get_jwk_set()
def is_expired(self) -> bool:
return (
self.jwk_set_with_timestamp is not None
and self.lifespan > -1
and time.monotonic()
> self.jwk_set_with_timestamp.get_timestamp() + self.lifespan
)

View file

@ -1,31 +1,68 @@
import json import json
import urllib.request import urllib.request
from functools import lru_cache from functools import lru_cache
from typing import Any, List from typing import Any, List, Optional
from urllib.error import URLError
from .api_jwk import PyJWK, PyJWKSet from .api_jwk import PyJWK, PyJWKSet
from .api_jwt import decode_complete as decode_token from .api_jwt import decode_complete as decode_token
from .exceptions import PyJWKClientError from .exceptions import PyJWKClientError
from .jwk_set_cache import JWKSetCache
class PyJWKClient: class PyJWKClient:
def __init__(self, uri: str, cache_keys: bool = True, max_cached_keys: int = 16): def __init__(
self,
uri: str,
cache_keys: bool = False,
max_cached_keys: int = 16,
cache_jwk_set: bool = True,
lifespan: int = 300,
):
self.uri = uri self.uri = uri
self.jwk_set_cache: Optional[JWKSetCache] = None
if cache_jwk_set:
# Init jwt set cache with default or given lifespan.
# Default lifespan is 300 seconds (5 minutes).
if lifespan <= 0:
raise PyJWKClientError(
f'Lifespan must be greater than 0, the input is "{lifespan}"'
)
self.jwk_set_cache = JWKSetCache(lifespan)
else:
self.jwk_set_cache = None
if cache_keys: if cache_keys:
# Cache signing keys # Cache signing keys
# Ignore mypy (https://github.com/python/mypy/issues/2427) # Ignore mypy (https://github.com/python/mypy/issues/2427)
self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore
def fetch_data(self) -> Any: def fetch_data(self) -> Any:
jwk_set: Any = None
try:
with urllib.request.urlopen(self.uri) as response: with urllib.request.urlopen(self.uri) as response:
return json.load(response) jwk_set = json.load(response)
except URLError as e:
raise PyJWKClientError(f'Fail to fetch data from the url, err: "{e}"')
else:
return jwk_set
finally:
if self.jwk_set_cache is not None:
self.jwk_set_cache.put(jwk_set)
def get_jwk_set(self) -> PyJWKSet: def get_jwk_set(self, refresh: bool = False) -> PyJWKSet:
data = None
if self.jwk_set_cache is not None and not refresh:
data = self.jwk_set_cache.get()
if data is None:
data = self.fetch_data() data = self.fetch_data()
return PyJWKSet.from_dict(data) return PyJWKSet.from_dict(data)
def get_signing_keys(self) -> List[PyJWK]: def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:
jwk_set = self.get_jwk_set() jwk_set = self.get_jwk_set(refresh)
signing_keys = [ signing_keys = [
jwk_set_key jwk_set_key
for jwk_set_key in jwk_set.keys for jwk_set_key in jwk_set.keys
@ -39,12 +76,12 @@ class PyJWKClient:
def get_signing_key(self, kid: str) -> PyJWK: def get_signing_key(self, kid: str) -> PyJWK:
signing_keys = self.get_signing_keys() signing_keys = self.get_signing_keys()
signing_key = None signing_key = self.match_kid(signing_keys, kid)
for key in signing_keys: if not signing_key:
if key.key_id == kid: # If no matching signing key from the jwk set, refresh the jwk set and try again.
signing_key = key signing_keys = self.get_signing_keys(refresh=True)
break signing_key = self.match_kid(signing_keys, kid)
if not signing_key: if not signing_key:
raise PyJWKClientError( raise PyJWKClientError(
@ -57,3 +94,14 @@ class PyJWKClient:
unverified = decode_token(token, options={"verify_signature": False}) unverified = decode_token(token, options={"verify_signature": False})
header = unverified["header"] header = unverified["header"]
return self.get_signing_key(header.get("kid")) return self.get_signing_key(header.get("kid"))
@staticmethod
def match_kid(signing_keys: List[PyJWK], kid: str) -> Optional[PyJWK]:
signing_key = None
for key in signing_keys:
if key.key_id == kid:
signing_key = key
break
return signing_key

View file

@ -1,7 +1,7 @@
import base64 import base64
import binascii import binascii
import re import re
from typing import Any, Union from typing import Union
try: try:
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
@ -10,7 +10,7 @@ try:
encode_dss_signature, encode_dss_signature,
) )
except ModuleNotFoundError: except ModuleNotFoundError:
EllipticCurve = Any # type: ignore EllipticCurve = None
def force_bytes(value: Union[str, bytes]) -> bytes: def force_bytes(value: Union[str, bytes]) -> bytes:
@ -136,7 +136,7 @@ def is_pem_format(key: bytes) -> bool:
# Based on https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b/src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46 # Based on https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b/src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46
_CERT_SUFFIX = b"-cert-v01@openssh.com" _CERT_SUFFIX = b"-cert-v01@openssh.com"
_SSH_PUBKEY_RC = re.compile(br"\A(\S+)[ \t]+(\S+)") _SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
_SSH_KEY_FORMATS = [ _SSH_KEY_FORMATS = [
b"ssh-ed25519", b"ssh-ed25519",
b"ssh-rsa", b"ssh-rsa",

2
lib/jwt/warnings.py Normal file
View file

@ -0,0 +1,2 @@
class RemovedInPyjwt3Warning(DeprecationWarning):
pass

View file

@ -30,7 +30,7 @@ paho-mqtt==1.6.1
plexapi==4.13.1 plexapi==4.13.1
portend==3.1.0 portend==3.1.0
profilehooks==1.12.0 profilehooks==1.12.0
PyJWT==2.4.0 PyJWT==2.6.0
pyparsing==3.0.9 pyparsing==3.0.9
python-dateutil==2.8.2 python-dateutil==2.8.2
python-twitter==3.5 python-twitter==3.5