Bump pyjwt from 2.6.0 to 2.8.0 (#2115)

* Bump pyjwt from 2.6.0 to 2.8.0

Bumps [pyjwt](https://github.com/jpadilla/pyjwt) from 2.6.0 to 2.8.0.
- [Release notes](https://github.com/jpadilla/pyjwt/releases)
- [Changelog](https://github.com/jpadilla/pyjwt/blob/master/CHANGELOG.rst)
- [Commits](https://github.com/jpadilla/pyjwt/compare/2.6.0...2.8.0)

---
updated-dependencies:
- dependency-name: pyjwt
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Update pyjwt==2.8.0

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com>

[skip ci]
This commit is contained in:
dependabot[bot] 2023-08-23 21:45:15 -07:00 committed by GitHub
parent 77f38bbf93
commit c93f470371
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 542 additions and 260 deletions

View file

@ -19,6 +19,7 @@ from .exceptions import (
InvalidSignatureError,
InvalidTokenError,
MissingRequiredClaimError,
PyJWKClientConnectionError,
PyJWKClientError,
PyJWKError,
PyJWKSetError,
@ -26,7 +27,7 @@ from .exceptions import (
)
from .jwks_client import PyJWKClient
__version__ = "2.6.0"
__version__ = "2.8.0"
__title__ = "PyJWT"
__description__ = "JSON Web Token implementation in Python"
@ -65,6 +66,7 @@ __all__ = [
"InvalidSignatureError",
"InvalidTokenError",
"MissingRequiredClaimError",
"PyJWKClientConnectionError",
"PyJWKClientError",
"PyJWKError",
"PyJWKSetError",

View file

@ -1,8 +1,14 @@
from __future__ import annotations
import hashlib
import hmac
import json
import sys
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, Union, cast, overload
from .exceptions import InvalidKeyError
from .types import HashlibHash, JWKDict
from .utils import (
base64url_decode,
base64url_encode,
@ -15,14 +21,28 @@ from .utils import (
to_base64url_uint,
)
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
try:
import cryptography.exceptions
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec, padding
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.asymmetric.ec import (
ECDSA,
SECP256K1,
SECP256R1,
SECP384R1,
SECP521R1,
EllipticCurve,
EllipticCurvePrivateKey,
EllipticCurvePrivateNumbers,
EllipticCurvePublicKey,
EllipticCurvePublicNumbers,
)
from cryptography.hazmat.primitives.asymmetric.ed448 import (
Ed448PrivateKey,
@ -56,6 +76,23 @@ try:
except ModuleNotFoundError:
has_crypto = False
if TYPE_CHECKING:
# Type aliases for convenience in algorithms method signatures
AllowedRSAKeys = RSAPrivateKey | RSAPublicKey
AllowedECKeys = EllipticCurvePrivateKey | EllipticCurvePublicKey
AllowedOKPKeys = (
Ed25519PrivateKey | Ed25519PublicKey | Ed448PrivateKey | Ed448PublicKey
)
AllowedKeys = AllowedRSAKeys | AllowedECKeys | AllowedOKPKeys
AllowedPrivateKeys = (
RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey | Ed448PrivateKey
)
AllowedPublicKeys = (
RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey | Ed448PublicKey
)
requires_cryptography = {
"RS256",
"RS384",
@ -72,7 +109,7 @@ requires_cryptography = {
}
def get_default_algorithms():
def get_default_algorithms() -> dict[str, Algorithm]:
"""
Returns the algorithms that are implemented by the library.
"""
@ -106,45 +143,79 @@ def get_default_algorithms():
return default_algorithms
class Algorithm:
class Algorithm(ABC):
"""
The interface for an algorithm used to sign and verify tokens.
"""
def prepare_key(self, key):
def compute_hash_digest(self, bytestr: bytes) -> bytes:
"""
Compute a hash digest using the specified algorithm's hash algorithm.
If there is no hash algorithm, raises a NotImplementedError.
"""
# lookup self.hash_alg if defined in a way that mypy can understand
hash_alg = getattr(self, "hash_alg", None)
if hash_alg is None:
raise NotImplementedError
if (
has_crypto
and isinstance(hash_alg, type)
and issubclass(hash_alg, hashes.HashAlgorithm)
):
digest = hashes.Hash(hash_alg(), backend=default_backend())
digest.update(bytestr)
return bytes(digest.finalize())
else:
return bytes(hash_alg(bytestr).digest())
@abstractmethod
def prepare_key(self, key: Any) -> Any:
"""
Performs necessary validation and conversions on the key and returns
the key value in the proper format for sign() and verify().
"""
raise NotImplementedError
def sign(self, msg, key):
@abstractmethod
def sign(self, msg: bytes, key: Any) -> bytes:
"""
Returns a digital signature for the specified message
using the specified key value.
"""
raise NotImplementedError
def verify(self, msg, key, sig):
@abstractmethod
def verify(self, msg: bytes, key: Any, sig: bytes) -> bool:
"""
Verifies that the specified digital signature is valid
for the specified message and key values.
"""
raise NotImplementedError
@overload
@staticmethod
@abstractmethod
def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover
@overload
@staticmethod
@abstractmethod
def to_jwk(key_obj, as_dict: Literal[False] = False) -> str:
... # pragma: no cover
@staticmethod
def to_jwk(key_obj):
@abstractmethod
def to_jwk(key_obj, as_dict: bool = False) -> Union[JWKDict, str]:
"""
Serializes a given RSA key into a JWK
Serializes a given key into a JWK
"""
raise NotImplementedError
@staticmethod
def from_jwk(jwk):
@abstractmethod
def from_jwk(jwk: str | JWKDict) -> Any:
"""
Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
Deserializes a given key from JWK back into a key object
"""
raise NotImplementedError
class NoneAlgorithm(Algorithm):
@ -153,7 +224,7 @@ class NoneAlgorithm(Algorithm):
operations are required.
"""
def prepare_key(self, key):
def prepare_key(self, key: str | None) -> None:
if key == "":
key = None
@ -162,12 +233,20 @@ class NoneAlgorithm(Algorithm):
return key
def sign(self, msg, key):
def sign(self, msg: bytes, key: None) -> bytes:
return b""
def verify(self, msg, key, sig):
def verify(self, msg: bytes, key: None, sig: bytes) -> bool:
return False
@staticmethod
def to_jwk(key_obj: Any, as_dict: bool = False) -> NoReturn:
raise NotImplementedError()
@staticmethod
def from_jwk(jwk: str | JWKDict) -> NoReturn:
raise NotImplementedError()
class HMACAlgorithm(Algorithm):
"""
@ -175,38 +254,51 @@ class HMACAlgorithm(Algorithm):
and the specified hash function.
"""
SHA256 = hashlib.sha256
SHA384 = hashlib.sha384
SHA512 = hashlib.sha512
SHA256: ClassVar[HashlibHash] = hashlib.sha256
SHA384: ClassVar[HashlibHash] = hashlib.sha384
SHA512: ClassVar[HashlibHash] = hashlib.sha512
def __init__(self, hash_alg):
def __init__(self, hash_alg: HashlibHash) -> None:
self.hash_alg = hash_alg
def prepare_key(self, key):
key = force_bytes(key)
def prepare_key(self, key: str | bytes) -> bytes:
key_bytes = force_bytes(key)
if is_pem_format(key) or is_ssh_key(key):
if is_pem_format(key_bytes) or is_ssh_key(key_bytes):
raise InvalidKeyError(
"The specified key is an asymmetric key or x509 certificate and"
" should not be used as an HMAC secret."
)
return key
return key_bytes
@overload
@staticmethod
def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover
@overload
@staticmethod
def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str:
... # pragma: no cover
@staticmethod
def to_jwk(key_obj):
return json.dumps(
{
"k": base64url_encode(force_bytes(key_obj)).decode(),
"kty": "oct",
}
)
def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> Union[JWKDict, str]:
jwk = {
"k": base64url_encode(force_bytes(key_obj)).decode(),
"kty": "oct",
}
if as_dict:
return jwk
else:
return json.dumps(jwk)
@staticmethod
def from_jwk(jwk):
def from_jwk(jwk: str | JWKDict) -> bytes:
try:
if isinstance(jwk, str):
obj = json.loads(jwk)
obj: JWKDict = json.loads(jwk)
elif isinstance(jwk, dict):
obj = jwk
else:
@ -219,10 +311,10 @@ class HMACAlgorithm(Algorithm):
return base64url_decode(obj["k"])
def sign(self, msg, key):
def sign(self, msg: bytes, key: bytes) -> bytes:
return hmac.new(key, msg, self.hash_alg).digest()
def verify(self, msg, key, sig):
def verify(self, msg: bytes, key: bytes, sig: bytes) -> bool:
return hmac.compare_digest(sig, self.sign(msg, key))
@ -234,36 +326,49 @@ if has_crypto:
RSASSA-PKCS-v1_5 and the specified hash function.
"""
SHA256 = hashes.SHA256
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512
SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
def __init__(self, hash_alg):
def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg
def prepare_key(self, key):
def prepare_key(self, key: AllowedRSAKeys | str | bytes) -> AllowedRSAKeys:
if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
return key
if not isinstance(key, (bytes, str)):
raise TypeError("Expecting a PEM-formatted key.")
key = force_bytes(key)
key_bytes = force_bytes(key)
try:
if key.startswith(b"ssh-rsa"):
key = load_ssh_public_key(key)
if key_bytes.startswith(b"ssh-rsa"):
return cast(RSAPublicKey, load_ssh_public_key(key_bytes))
else:
key = load_pem_private_key(key, password=None)
return cast(
RSAPrivateKey, load_pem_private_key(key_bytes, password=None)
)
except ValueError:
key = load_pem_public_key(key)
return key
return cast(RSAPublicKey, load_pem_public_key(key_bytes))
@overload
@staticmethod
def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover
@overload
@staticmethod
def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str:
... # pragma: no cover
@staticmethod
def to_jwk(key_obj):
obj = None
def to_jwk(
key_obj: AllowedRSAKeys, as_dict: bool = False
) -> Union[JWKDict, str]:
obj: dict[str, Any] | None = None
if getattr(key_obj, "private_numbers", None):
if hasattr(key_obj, "private_numbers"):
# Private key
numbers = key_obj.private_numbers()
@ -280,7 +385,7 @@ if has_crypto:
"qi": to_base64url_uint(numbers.iqmp).decode(),
}
elif getattr(key_obj, "verify", None):
elif hasattr(key_obj, "verify"):
# Public key
numbers = key_obj.public_numbers()
@ -293,10 +398,13 @@ if has_crypto:
else:
raise InvalidKeyError("Not a public or private key")
return json.dumps(obj)
if as_dict:
return obj
else:
return json.dumps(obj)
@staticmethod
def from_jwk(jwk):
def from_jwk(jwk: str | JWKDict) -> AllowedRSAKeys:
try:
if isinstance(jwk, str):
obj = json.loads(jwk)
@ -360,19 +468,17 @@ if has_crypto:
return numbers.private_key()
elif "n" in obj and "e" in obj:
# Public key
numbers = RSAPublicNumbers(
return RSAPublicNumbers(
from_base64url_uint(obj["e"]),
from_base64url_uint(obj["n"]),
)
return numbers.public_key()
).public_key()
else:
raise InvalidKeyError("Not a public or private key")
def sign(self, msg, key):
def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
return key.sign(msg, padding.PKCS1v15(), self.hash_alg())
def verify(self, msg, key, sig):
def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
try:
key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
return True
@ -385,63 +491,79 @@ if has_crypto:
ECDSA and the specified hash function
"""
SHA256 = hashes.SHA256
SHA384 = hashes.SHA384
SHA512 = hashes.SHA512
SHA256: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA256
SHA384: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA384
SHA512: ClassVar[type[hashes.HashAlgorithm]] = hashes.SHA512
def __init__(self, hash_alg):
def __init__(self, hash_alg: type[hashes.HashAlgorithm]) -> None:
self.hash_alg = hash_alg
def prepare_key(self, key):
def prepare_key(self, key: AllowedECKeys | str | bytes) -> AllowedECKeys:
if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
return key
if not isinstance(key, (bytes, str)):
raise TypeError("Expecting a PEM-formatted key.")
key = force_bytes(key)
key_bytes = force_bytes(key)
# Attempt to load key. We don't know if it's
# a Signing Key or a Verifying Key, so we try
# the Verifying Key first.
try:
if key.startswith(b"ecdsa-sha2-"):
key = load_ssh_public_key(key)
if key_bytes.startswith(b"ecdsa-sha2-"):
crypto_key = load_ssh_public_key(key_bytes)
else:
key = load_pem_public_key(key)
crypto_key = load_pem_public_key(key_bytes) # type: ignore[assignment]
except ValueError:
key = load_pem_private_key(key, password=None)
crypto_key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
# Explicit check the key to prevent confusing errors from cryptography
if not isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
if not isinstance(
crypto_key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)
):
raise InvalidKeyError(
"Expecting a EllipticCurvePrivateKey/EllipticCurvePublicKey. Wrong key provided for ECDSA algorithms"
)
return key
return crypto_key
def sign(self, msg, key):
der_sig = key.sign(msg, ec.ECDSA(self.hash_alg()))
def sign(self, msg: bytes, key: EllipticCurvePrivateKey) -> bytes:
der_sig = key.sign(msg, ECDSA(self.hash_alg()))
return der_to_raw_signature(der_sig, key.curve)
def verify(self, msg, key, sig):
def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool:
try:
der_sig = raw_to_der_signature(sig, key.curve)
except ValueError:
return False
try:
if isinstance(key, EllipticCurvePrivateKey):
key = key.public_key()
key.verify(der_sig, msg, ec.ECDSA(self.hash_alg()))
public_key = (
key.public_key()
if isinstance(key, EllipticCurvePrivateKey)
else key
)
public_key.verify(der_sig, msg, ECDSA(self.hash_alg()))
return True
except InvalidSignature:
return False
@overload
@staticmethod
def to_jwk(key_obj):
def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover
@overload
@staticmethod
def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str:
... # pragma: no cover
@staticmethod
def to_jwk(
key_obj: AllowedECKeys, as_dict: bool = False
) -> Union[JWKDict, str]:
if isinstance(key_obj, EllipticCurvePrivateKey):
public_numbers = key_obj.public_key().public_numbers()
elif isinstance(key_obj, EllipticCurvePublicKey):
@ -449,18 +571,18 @@ if has_crypto:
else:
raise InvalidKeyError("Not a public or private key")
if isinstance(key_obj.curve, ec.SECP256R1):
if isinstance(key_obj.curve, SECP256R1):
crv = "P-256"
elif isinstance(key_obj.curve, ec.SECP384R1):
elif isinstance(key_obj.curve, SECP384R1):
crv = "P-384"
elif isinstance(key_obj.curve, ec.SECP521R1):
elif isinstance(key_obj.curve, SECP521R1):
crv = "P-521"
elif isinstance(key_obj.curve, ec.SECP256K1):
elif isinstance(key_obj.curve, SECP256K1):
crv = "secp256k1"
else:
raise InvalidKeyError(f"Invalid curve: {key_obj.curve}")
obj = {
obj: dict[str, Any] = {
"kty": "EC",
"crv": crv,
"x": to_base64url_uint(public_numbers.x).decode(),
@ -472,10 +594,13 @@ if has_crypto:
key_obj.private_numbers().private_value
).decode()
return json.dumps(obj)
if as_dict:
return obj
else:
return json.dumps(obj)
@staticmethod
def from_jwk(jwk):
def from_jwk(jwk: str | JWKDict) -> AllowedECKeys:
try:
if isinstance(jwk, str):
obj = json.loads(jwk)
@ -496,24 +621,26 @@ if has_crypto:
y = base64url_decode(obj.get("y"))
curve = obj.get("crv")
curve_obj: EllipticCurve
if curve == "P-256":
if len(x) == len(y) == 32:
curve_obj = ec.SECP256R1()
curve_obj = SECP256R1()
else:
raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
elif curve == "P-384":
if len(x) == len(y) == 48:
curve_obj = ec.SECP384R1()
curve_obj = SECP384R1()
else:
raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
elif curve == "P-521":
if len(x) == len(y) == 66:
curve_obj = ec.SECP521R1()
curve_obj = SECP521R1()
else:
raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
elif curve == "secp256k1":
if len(x) == len(y) == 32:
curve_obj = ec.SECP256K1()
curve_obj = SECP256K1()
else:
raise InvalidKeyError(
"Coords should be 32 bytes for curve secp256k1"
@ -521,7 +648,7 @@ if has_crypto:
else:
raise InvalidKeyError(f"Invalid curve: {curve}")
public_numbers = ec.EllipticCurvePublicNumbers(
public_numbers = EllipticCurvePublicNumbers(
x=int.from_bytes(x, byteorder="big"),
y=int.from_bytes(y, byteorder="big"),
curve=curve_obj,
@ -536,7 +663,7 @@ if has_crypto:
"D should be {} bytes for curve {}", len(x), curve
)
return ec.EllipticCurvePrivateNumbers(
return EllipticCurvePrivateNumbers(
int.from_bytes(d, byteorder="big"), public_numbers
).private_key()
@ -545,24 +672,24 @@ if has_crypto:
Performs a signature using RSASSA-PSS with MGF1
"""
def sign(self, msg, key):
def sign(self, msg: bytes, key: RSAPrivateKey) -> bytes:
return key.sign(
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()),
salt_length=self.hash_alg.digest_size,
salt_length=self.hash_alg().digest_size,
),
self.hash_alg(),
)
def verify(self, msg, key, sig):
def verify(self, msg: bytes, key: RSAPublicKey, sig: bytes) -> bool:
try:
key.verify(
sig,
msg,
padding.PSS(
mgf=padding.MGF1(self.hash_alg()),
salt_length=self.hash_alg.digest_size,
salt_length=self.hash_alg().digest_size,
),
self.hash_alg(),
)
@ -577,21 +704,20 @@ if has_crypto:
This class requires ``cryptography>=2.6`` to be installed.
"""
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
pass
def prepare_key(self, key):
def prepare_key(self, key: AllowedOKPKeys | str | bytes) -> AllowedOKPKeys:
if isinstance(key, (bytes, str)):
if isinstance(key, str):
key = key.encode("utf-8")
str_key = key.decode("utf-8")
key_str = key.decode("utf-8") if isinstance(key, bytes) else key
key_bytes = key.encode("utf-8") if isinstance(key, str) else key
if "-----BEGIN PUBLIC" in str_key:
key = load_pem_public_key(key)
elif "-----BEGIN PRIVATE" in str_key:
key = load_pem_private_key(key, password=None)
elif str_key[0:4] == "ssh-":
key = load_ssh_public_key(key)
if "-----BEGIN PUBLIC" in key_str:
key = load_pem_public_key(key_bytes) # type: ignore[assignment]
elif "-----BEGIN PRIVATE" in key_str:
key = load_pem_private_key(key_bytes, password=None) # type: ignore[assignment]
elif key_str[0:4] == "ssh-":
key = load_ssh_public_key(key_bytes) # type: ignore[assignment]
# Explicit check the key to prevent confusing errors from cryptography
if not isinstance(
@ -604,7 +730,9 @@ if has_crypto:
return key
def sign(self, msg, key):
def sign(
self, msg: str | bytes, key: Ed25519PrivateKey | Ed448PrivateKey
) -> bytes:
"""
Sign a message ``msg`` using the EdDSA private key ``key``
:param str|bytes msg: Message to sign
@ -612,10 +740,12 @@ if has_crypto:
or :class:`.Ed448PrivateKey` isinstance
:return bytes signature: The signature, as bytes
"""
msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg
return key.sign(msg)
msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
return key.sign(msg_bytes)
def verify(self, msg, key, sig):
def verify(
self, msg: str | bytes, key: AllowedOKPKeys, sig: str | bytes
) -> bool:
"""
Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``
@ -626,31 +756,48 @@ if has_crypto:
:return bool verified: True if signature is valid, False if not.
"""
try:
msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg
sig = bytes(sig, "utf-8") if type(sig) is not bytes else sig
msg_bytes = msg.encode("utf-8") if isinstance(msg, str) else msg
sig_bytes = sig.encode("utf-8") if isinstance(sig, str) else sig
if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
key = key.public_key()
key.verify(sig, msg)
public_key = (
key.public_key()
if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey))
else key
)
public_key.verify(sig_bytes, msg_bytes)
return True # If no exception was raised, the signature is valid.
except cryptography.exceptions.InvalidSignature:
except InvalidSignature:
return False
@overload
@staticmethod
def to_jwk(key):
def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover
@overload
@staticmethod
def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str:
... # pragma: no cover
@staticmethod
def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> Union[JWKDict, str]:
if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
x = key.public_bytes(
encoding=Encoding.Raw,
format=PublicFormat.Raw,
)
crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
return json.dumps(
{
"x": base64url_encode(force_bytes(x)).decode(),
"kty": "OKP",
"crv": crv,
}
)
obj = {
"x": base64url_encode(force_bytes(x)).decode(),
"kty": "OKP",
"crv": crv,
}
if as_dict:
return obj
else:
return json.dumps(obj)
if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
d = key.private_bytes(
@ -665,19 +812,22 @@ if has_crypto:
)
crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
return json.dumps(
{
"x": base64url_encode(force_bytes(x)).decode(),
"d": base64url_encode(force_bytes(d)).decode(),
"kty": "OKP",
"crv": crv,
}
)
obj = {
"x": base64url_encode(force_bytes(x)).decode(),
"d": base64url_encode(force_bytes(d)).decode(),
"kty": "OKP",
"crv": crv,
}
if as_dict:
return obj
else:
return json.dumps(obj)
raise InvalidKeyError("Not a public or private key")
@staticmethod
def from_jwk(jwk):
def from_jwk(jwk: str | JWKDict) -> AllowedOKPKeys:
try:
if isinstance(jwk, str):
obj = json.loads(jwk)

View file

@ -2,13 +2,15 @@ from __future__ import annotations
import json
import time
from typing import Any
from .algorithms import get_default_algorithms
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError
from .algorithms import get_default_algorithms, has_crypto, requires_cryptography
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError, PyJWTError
from .types import JWKDict
class PyJWK:
def __init__(self, jwk_data, algorithm=None):
def __init__(self, jwk_data: JWKDict, algorithm: str | None = None) -> None:
self._algorithms = get_default_algorithms()
self._jwk_data = jwk_data
@ -47,37 +49,40 @@ class PyJWK:
else:
raise InvalidKeyError(f"Unsupported kty: {kty}")
if not has_crypto and algorithm in requires_cryptography:
raise PyJWKError(f"{algorithm} requires 'cryptography' to be installed.")
self.Algorithm = self._algorithms.get(algorithm)
if not self.Algorithm:
raise PyJWKError(f"Unable to find a algorithm for key: {self._jwk_data}")
raise PyJWKError(f"Unable to find an algorithm for key: {self._jwk_data}")
self.key = self.Algorithm.from_jwk(self._jwk_data)
@staticmethod
def from_dict(obj, algorithm=None):
def from_dict(obj: JWKDict, algorithm: str | None = None) -> "PyJWK":
return PyJWK(obj, algorithm)
@staticmethod
def from_json(data, algorithm=None):
def from_json(data: str, algorithm: None = None) -> "PyJWK":
obj = json.loads(data)
return PyJWK.from_dict(obj, algorithm)
@property
def key_type(self):
def key_type(self) -> str | None:
return self._jwk_data.get("kty", None)
@property
def key_id(self):
def key_id(self) -> str | None:
return self._jwk_data.get("kid", None)
@property
def public_key_use(self):
def public_key_use(self) -> str | None:
return self._jwk_data.get("use", None)
class PyJWKSet:
def __init__(self, keys: list[dict]) -> None:
def __init__(self, keys: list[JWKDict]) -> None:
self.keys = []
if not keys:
@ -89,24 +94,26 @@ class PyJWKSet:
for key in keys:
try:
self.keys.append(PyJWK(key))
except PyJWKError:
except PyJWTError:
# skip unusable keys
continue
if len(self.keys) == 0:
raise PyJWKSetError("The JWK Set did not contain any usable keys")
raise PyJWKSetError(
"The JWK Set did not contain any usable keys. Perhaps 'cryptography' is not installed?"
)
@staticmethod
def from_dict(obj):
def from_dict(obj: dict[str, Any]) -> "PyJWKSet":
keys = obj.get("keys", [])
return PyJWKSet(keys)
@staticmethod
def from_json(data):
def from_json(data: str) -> "PyJWKSet":
obj = json.loads(data)
return PyJWKSet.from_dict(obj)
def __getitem__(self, kid):
def __getitem__(self, kid: str) -> "PyJWK":
for key in self.keys:
if key.key_id == kid:
return key
@ -118,8 +125,8 @@ class PyJWTSetWithTimestamp:
self.jwk_set = jwk_set
self.timestamp = time.monotonic()
def get_jwk_set(self):
def get_jwk_set(self) -> PyJWKSet:
return self.jwk_set
def get_timestamp(self):
def get_timestamp(self) -> float:
return self.timestamp

View file

@ -3,7 +3,7 @@ from __future__ import annotations
import binascii
import json
import warnings
from typing import Any, Type
from typing import TYPE_CHECKING, Any
from .algorithms import (
Algorithm,
@ -20,11 +20,18 @@ from .exceptions import (
from .utils import base64url_decode, base64url_encode
from .warnings import RemovedInPyjwt3Warning
if TYPE_CHECKING:
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
class PyJWS:
header_typ = "JWT"
def __init__(self, algorithms=None, options=None) -> None:
def __init__(
self,
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
) -> None:
self._algorithms = get_default_algorithms()
self._valid_algs = (
set(algorithms) if algorithms is not None else set(self._algorithms)
@ -96,11 +103,12 @@ class PyJWS:
def encode(
self,
payload: bytes,
key: str,
key: AllowedPrivateKeys | str | bytes,
algorithm: str | None = "HS256",
headers: dict[str, Any] | None = None,
json_encoder: Type[json.JSONEncoder] | None = None,
json_encoder: type[json.JSONEncoder] | None = None,
is_payload_detached: bool = False,
sort_headers: bool = True,
) -> str:
segments = []
@ -133,9 +141,8 @@ class PyJWS:
# True is the standard value for b64, so no need for it
del header["b64"]
# Fix for headers misorder - issue #715
json_header = json.dumps(
header, separators=(",", ":"), cls=json_encoder, sort_keys=True
header, separators=(",", ":"), cls=json_encoder, sort_keys=sort_headers
).encode()
segments.append(base64url_encode(json_header))
@ -164,8 +171,8 @@ class PyJWS:
def decode_complete(
self,
jwt: str,
key: str = "",
jwt: str | bytes,
key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
@ -209,13 +216,13 @@ class PyJWS:
def decode(
self,
jwt: str,
key: str = "",
jwt: str | bytes,
key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
detached_payload: bytes | None = None,
**kwargs,
) -> str:
) -> Any:
if kwargs:
warnings.warn(
"passing additional kwargs to decode() is deprecated "
@ -228,7 +235,7 @@ class PyJWS:
)
return decoded["payload"]
def get_unverified_header(self, jwt: str | bytes) -> dict:
def get_unverified_header(self, jwt: str | bytes) -> dict[str, Any]:
"""Returns back the JWT header parameters as a dict()
Note: The signature is not verified so the header parameters
@ -239,7 +246,7 @@ class PyJWS:
return headers
def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict, bytes]:
def _load(self, jwt: str | bytes) -> tuple[bytes, bytes, dict[str, Any], bytes]:
if isinstance(jwt, str):
jwt = jwt.encode("utf-8")
@ -280,13 +287,15 @@ class PyJWS:
def _verify_signature(
self,
signing_input: bytes,
header: dict,
header: dict[str, Any],
signature: bytes,
key: str = "",
key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None,
) -> None:
alg = header.get("alg")
try:
alg = header["alg"]
except KeyError:
raise InvalidAlgorithmError("Algorithm not specified")
if not alg or (algorithms is not None and alg not in algorithms):
raise InvalidAlgorithmError("The specified alg value is not allowed")
@ -295,16 +304,16 @@ class PyJWS:
alg_obj = self.get_algorithm_by_name(alg)
except NotImplementedError as e:
raise InvalidAlgorithmError("Algorithm not supported") from e
key = alg_obj.prepare_key(key)
prepared_key = alg_obj.prepare_key(key)
if not alg_obj.verify(signing_input, key, signature):
if not alg_obj.verify(signing_input, prepared_key, signature):
raise InvalidSignatureError("Signature verification failed")
def _validate_headers(self, headers: dict[str, Any]) -> None:
if "kid" in headers:
self._validate_kid(headers["kid"])
def _validate_kid(self, kid: str) -> None:
def _validate_kid(self, kid: Any) -> None:
if not isinstance(kid, str):
raise InvalidTokenError("Key ID header parameter must be a string")

View file

@ -3,9 +3,9 @@ from __future__ import annotations
import json
import warnings
from calendar import timegm
from collections.abc import Iterable, Mapping
from collections.abc import Iterable
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Type, Union
from typing import TYPE_CHECKING, Any
from . import api_jws
from .exceptions import (
@ -19,15 +19,18 @@ from .exceptions import (
)
from .warnings import RemovedInPyjwt3Warning
if TYPE_CHECKING:
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
class PyJWT:
def __init__(self, options=None):
def __init__(self, options: dict[str, Any] | None = None) -> None:
if options is None:
options = {}
self.options = {**self._get_default_options(), **options}
self.options: dict[str, Any] = {**self._get_default_options(), **options}
@staticmethod
def _get_default_options() -> Dict[str, Union[bool, List[str]]]:
def _get_default_options() -> dict[str, bool | list[str]]:
return {
"verify_signature": True,
"verify_exp": True,
@ -40,16 +43,17 @@ class PyJWT:
def encode(
self,
payload: Dict[str, Any],
key: str,
algorithm: Optional[str] = "HS256",
headers: Optional[Dict[str, Any]] = None,
json_encoder: Optional[Type[json.JSONEncoder]] = None,
payload: dict[str, Any],
key: AllowedPrivateKeys | str | bytes,
algorithm: str | None = "HS256",
headers: dict[str, Any] | None = None,
json_encoder: type[json.JSONEncoder] | None = None,
sort_headers: bool = True,
) -> str:
# Check that we get a mapping
if not isinstance(payload, Mapping):
# Check that we get a dict
if not isinstance(payload, dict):
raise TypeError(
"Expecting a mapping object, as JWT only supports "
"Expecting a dict object, as JWT only supports "
"JSON objects as payloads."
)
@ -60,30 +64,57 @@ class PyJWT:
if isinstance(payload.get(time_claim), datetime):
payload[time_claim] = timegm(payload[time_claim].utctimetuple())
json_payload = json.dumps(
payload, separators=(",", ":"), cls=json_encoder
).encode("utf-8")
json_payload = self._encode_payload(
payload,
headers=headers,
json_encoder=json_encoder,
)
return api_jws.encode(json_payload, key, algorithm, headers, json_encoder)
return api_jws.encode(
json_payload,
key,
algorithm,
headers,
json_encoder,
sort_headers=sort_headers,
)
def _encode_payload(
self,
payload: dict[str, Any],
headers: dict[str, Any] | None = None,
json_encoder: type[json.JSONEncoder] | None = None,
) -> bytes:
"""
Encode a given payload to the bytes to be signed.
This method is intended to be overridden by subclasses that need to
encode the payload in a different way, e.g. compress the payload.
"""
return json.dumps(
payload,
separators=(",", ":"),
cls=json_encoder,
).encode("utf-8")
def decode_complete(
self,
jwt: str,
key: str = "",
algorithms: Optional[List[str]] = None,
options: Optional[Dict[str, Any]] = None,
jwt: str | bytes,
key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
# deprecated arg, remove in pyjwt3
verify: Optional[bool] = None,
verify: bool | None = None,
# could be used as passthrough to api_jws, consider removal in pyjwt3
detached_payload: Optional[bytes] = None,
detached_payload: bytes | None = None,
# passthrough arguments to _validate_claims
# consider putting in options
audience: Optional[Union[str, Iterable[str]]] = None,
issuer: Optional[str] = None,
leeway: Union[int, float, timedelta] = 0,
audience: str | Iterable[str] | None = None,
issuer: str | None = None,
leeway: float | timedelta = 0,
# kwargs
**kwargs,
) -> Dict[str, Any]:
**kwargs: Any,
) -> dict[str, Any]:
if kwargs:
warnings.warn(
"passing additional kwargs to decode_complete() is deprecated "
@ -125,12 +156,7 @@ class PyJWT:
detached_payload=detached_payload,
)
try:
payload = json.loads(decoded["payload"])
except ValueError as e:
raise DecodeError(f"Invalid payload string: {e}")
if not isinstance(payload, dict):
raise DecodeError("Invalid payload string: must be a json object")
payload = self._decode_payload(decoded)
merged_options = {**self.options, **options}
self._validate_claims(
@ -140,24 +166,40 @@ class PyJWT:
decoded["payload"] = payload
return decoded
def _decode_payload(self, decoded: dict[str, Any]) -> Any:
"""
Decode the payload from a JWS dictionary (payload, signature, header).
This method is intended to be overridden by subclasses that need to
decode the payload in a different way, e.g. decompress compressed
payloads.
"""
try:
payload = json.loads(decoded["payload"])
except ValueError as e:
raise DecodeError(f"Invalid payload string: {e}")
if not isinstance(payload, dict):
raise DecodeError("Invalid payload string: must be a json object")
return payload
def decode(
self,
jwt: str,
key: str = "",
algorithms: Optional[List[str]] = None,
options: Optional[Dict[str, Any]] = None,
jwt: str | bytes,
key: AllowedPublicKeys | str | bytes = "",
algorithms: list[str] | None = None,
options: dict[str, Any] | None = None,
# deprecated arg, remove in pyjwt3
verify: Optional[bool] = None,
verify: bool | None = None,
# could be used as passthrough to api_jws, consider removal in pyjwt3
detached_payload: Optional[bytes] = None,
detached_payload: bytes | None = None,
# passthrough arguments to _validate_claims
# consider putting in options
audience: Optional[Union[str, Iterable[str]]] = None,
issuer: Optional[str] = None,
leeway: Union[int, float, timedelta] = 0,
audience: str | Iterable[str] | None = None,
issuer: str | None = None,
leeway: float | timedelta = 0,
# kwargs
**kwargs,
) -> Dict[str, Any]:
**kwargs: Any,
) -> Any:
if kwargs:
warnings.warn(
"passing additional kwargs to decode() is deprecated "
@ -178,7 +220,14 @@ class PyJWT:
)
return decoded["payload"]
def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0):
def _validate_claims(
self,
payload: dict[str, Any],
options: dict[str, Any],
audience=None,
issuer=None,
leeway: float | timedelta = 0,
) -> None:
if isinstance(leeway, timedelta):
leeway = leeway.total_seconds()
@ -187,7 +236,7 @@ class PyJWT:
self._validate_required_claims(payload, options)
now = timegm(datetime.now(tz=timezone.utc).utctimetuple())
now = datetime.now(tz=timezone.utc).timestamp()
if "iat" in payload and options["verify_iat"]:
self._validate_iat(payload, now, leeway)
@ -202,23 +251,38 @@ class PyJWT:
self._validate_iss(payload, issuer)
if options["verify_aud"]:
self._validate_aud(payload, audience)
self._validate_aud(
payload, audience, strict=options.get("strict_aud", False)
)
def _validate_required_claims(self, payload, options):
def _validate_required_claims(
self,
payload: dict[str, Any],
options: dict[str, Any],
) -> None:
for claim in options["require"]:
if payload.get(claim) is None:
raise MissingRequiredClaimError(claim)
def _validate_iat(self, payload, now, leeway):
iat = payload["iat"]
def _validate_iat(
self,
payload: dict[str, Any],
now: float,
leeway: float,
) -> None:
try:
int(iat)
iat = int(payload["iat"])
except ValueError:
raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.")
if iat > (now + leeway):
raise ImmatureSignatureError("The token is not yet valid (iat)")
def _validate_nbf(self, payload, now, leeway):
def _validate_nbf(
self,
payload: dict[str, Any],
now: float,
leeway: float,
) -> None:
try:
nbf = int(payload["nbf"])
except ValueError:
@ -227,7 +291,12 @@ class PyJWT:
if nbf > (now + leeway):
raise ImmatureSignatureError("The token is not yet valid (nbf)")
def _validate_exp(self, payload, now, leeway):
def _validate_exp(
self,
payload: dict[str, Any],
now: float,
leeway: float,
) -> None:
try:
exp = int(payload["exp"])
except ValueError:
@ -236,7 +305,13 @@ class PyJWT:
if exp <= (now - leeway):
raise ExpiredSignatureError("Signature has expired")
def _validate_aud(self, payload, audience):
def _validate_aud(
self,
payload: dict[str, Any],
audience: str | Iterable[str] | None,
*,
strict: bool = False,
) -> None:
if audience is None:
if "aud" not in payload or not payload["aud"]:
return
@ -251,6 +326,22 @@ class PyJWT:
audience_claims = payload["aud"]
# In strict mode, we forbid list matching: the supplied audience
# must be a string, and it must exactly match the audience claim.
if strict:
# Only a single audience is allowed in strict mode.
if not isinstance(audience, str):
raise InvalidAudienceError("Invalid audience (strict)")
# Only a single audience claim is allowed in strict mode.
if not isinstance(audience_claims, str):
raise InvalidAudienceError("Invalid claim format in token (strict)")
if audience != audience_claims:
raise InvalidAudienceError("Audience doesn't match (strict)")
return
if isinstance(audience_claims, str):
audience_claims = [audience_claims]
if not isinstance(audience_claims, list):
@ -262,9 +353,9 @@ class PyJWT:
audience = [audience]
if all(aud not in audience_claims for aud in audience):
raise InvalidAudienceError("Invalid audience")
raise InvalidAudienceError("Audience doesn't match")
def _validate_iss(self, payload, issuer):
def _validate_iss(self, payload: dict[str, Any], issuer: Any) -> None:
if issuer is None:
return

View file

@ -47,10 +47,10 @@ class InvalidAlgorithmError(InvalidTokenError):
class MissingRequiredClaimError(InvalidTokenError):
def __init__(self, claim):
def __init__(self, claim: str) -> None:
self.claim = claim
def __str__(self):
def __str__(self) -> str:
return f'Token is missing the "{self.claim}" claim'
@ -64,3 +64,7 @@ class PyJWKSetError(PyJWTError):
class PyJWKClientError(PyJWTError):
pass
class PyJWKClientConnectionError(PyJWKClientError):
pass

View file

@ -7,8 +7,10 @@ from . import __version__ as pyjwt_version
try:
import cryptography
cryptography_version = cryptography.__version__
except ModuleNotFoundError:
cryptography = None
cryptography_version = ""
def info() -> Dict[str, Dict[str, str]]:
@ -29,7 +31,7 @@ def info() -> Dict[str, Dict[str, str]]:
if implementation == "CPython":
implementation_version = platform.python_version()
elif implementation == "PyPy":
pypy_version_info = getattr(sys, "pypy_version_info")
pypy_version_info = sys.pypy_version_info # type: ignore[attr-defined]
implementation_version = (
f"{pypy_version_info.major}."
f"{pypy_version_info.minor}."
@ -48,7 +50,7 @@ def info() -> Dict[str, Dict[str, str]]:
"name": implementation,
"version": implementation_version,
},
"cryptography": {"version": getattr(cryptography, "__version__", "")},
"cryptography": {"version": cryptography_version},
"pyjwt": {"version": pyjwt_version},
}

View file

@ -5,11 +5,11 @@ from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp
class JWKSetCache:
def __init__(self, lifespan: int):
def __init__(self, lifespan: int) -> None:
self.jwk_set_with_timestamp: Optional[PyJWTSetWithTimestamp] = None
self.lifespan = lifespan
def put(self, jwk_set: PyJWKSet):
def put(self, jwk_set: PyJWKSet) -> None:
if jwk_set is not None:
self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set)
else:
@ -23,7 +23,6 @@ class JWKSetCache:
return self.jwk_set_with_timestamp.get_jwk_set()
def is_expired(self) -> bool:
return (
self.jwk_set_with_timestamp is not None
and self.lifespan > -1

View file

@ -1,12 +1,13 @@
import json
import urllib.request
from functools import lru_cache
from typing import Any, List, Optional
from ssl import SSLContext
from typing import Any, Dict, List, Optional
from urllib.error import URLError
from .api_jwk import PyJWK, PyJWKSet
from .api_jwt import decode_complete as decode_token
from .exceptions import PyJWKClientError
from .exceptions import PyJWKClientConnectionError, PyJWKClientError
from .jwk_set_cache import JWKSetCache
@ -18,9 +19,17 @@ class PyJWKClient:
max_cached_keys: int = 16,
cache_jwk_set: bool = True,
lifespan: int = 300,
headers: Optional[Dict[str, Any]] = None,
timeout: int = 30,
ssl_context: Optional[SSLContext] = None,
):
if headers is None:
headers = {}
self.uri = uri
self.jwk_set_cache: Optional[JWKSetCache] = None
self.headers = headers
self.timeout = timeout
self.ssl_context = ssl_context
if cache_jwk_set:
# Init jwt set cache with default or given lifespan.
@ -41,10 +50,15 @@ class PyJWKClient:
def fetch_data(self) -> Any:
jwk_set: Any = None
try:
with urllib.request.urlopen(self.uri) as response:
r = urllib.request.Request(url=self.uri, headers=self.headers)
with urllib.request.urlopen(
r, timeout=self.timeout, context=self.ssl_context
) as response:
jwk_set = json.load(response)
except URLError as e:
raise PyJWKClientError(f'Fail to fetch data from the url, err: "{e}"')
except (URLError, TimeoutError) as e:
raise PyJWKClientConnectionError(
f'Fail to fetch data from the url, err: "{e}"'
)
else:
return jwk_set
finally:
@ -59,6 +73,9 @@ class PyJWKClient:
if data is None:
data = self.fetch_data()
if not isinstance(data, dict):
raise PyJWKClientError("The JWKS endpoint did not return a JSON object")
return PyJWKSet.from_dict(data)
def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:

5
lib/jwt/types.py Normal file
View file

@ -0,0 +1,5 @@
from typing import Any, Callable, Dict
JWKDict = Dict[str, Any]
HashlibHash = Callable[..., Any]

View file

@ -10,10 +10,10 @@ try:
encode_dss_signature,
)
except ModuleNotFoundError:
EllipticCurve = None
pass
def force_bytes(value: Union[str, bytes]) -> bytes:
def force_bytes(value: Union[bytes, str]) -> bytes:
if isinstance(value, str):
return value.encode("utf-8")
elif isinstance(value, bytes):
@ -22,16 +22,15 @@ def force_bytes(value: Union[str, bytes]) -> bytes:
raise TypeError("Expected a string value")
def base64url_decode(input: Union[str, bytes]) -> bytes:
if isinstance(input, str):
input = input.encode("ascii")
def base64url_decode(input: Union[bytes, str]) -> bytes:
input_bytes = force_bytes(input)
rem = len(input) % 4
rem = len(input_bytes) % 4
if rem > 0:
input += b"=" * (4 - rem)
input_bytes += b"=" * (4 - rem)
return base64.urlsafe_b64decode(input)
return base64.urlsafe_b64decode(input_bytes)
def base64url_encode(input: bytes) -> bytes:
@ -50,11 +49,8 @@ def to_base64url_uint(val: int) -> bytes:
return base64url_encode(int_bytes)
def from_base64url_uint(val: Union[str, bytes]) -> int:
if isinstance(val, str):
val = val.encode("ascii")
data = base64url_decode(val)
def from_base64url_uint(val: Union[bytes, str]) -> int:
data = base64url_decode(force_bytes(val))
return int.from_bytes(data, byteorder="big")
@ -78,7 +74,7 @@ def bytes_from_int(val: int) -> bytes:
return val.to_bytes(byte_length, "big", signed=False)
def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes:
def der_to_raw_signature(der_sig: bytes, curve: "EllipticCurve") -> bytes:
num_bits = curve.key_size
num_bytes = (num_bits + 7) // 8
@ -87,7 +83,7 @@ def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes:
return number_to_bytes(r, num_bytes) + number_to_bytes(s, num_bytes)
def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes:
def raw_to_der_signature(raw_sig: bytes, curve: "EllipticCurve") -> bytes:
num_bits = curve.key_size
num_bytes = (num_bits + 7) // 8
@ -97,7 +93,7 @@ def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes:
r = bytes_to_number(raw_sig[:num_bytes])
s = bytes_to_number(raw_sig[num_bytes:])
return encode_dss_signature(r, s)
return bytes(encode_dss_signature(r, s))
# Based on https://github.com/hynek/pem/blob/7ad94db26b0bc21d10953f5dbad3acfdfacf57aa/src/pem/_core.py#L224-L252

View file

@ -31,7 +31,7 @@ paho-mqtt==1.6.1
plexapi==4.13.4
portend==3.1.0
profilehooks==1.12.0
PyJWT==2.6.0
PyJWT==2.8.0
pyparsing==3.0.9
python-dateutil==2.8.2
python-twitter==3.5