mirror of
https://github.com/Tautulli/Tautulli.git
synced 2025-07-06 05:01:14 -07:00
Update PyJWT-2.2.0
This commit is contained in:
parent
b55b053b1e
commit
4eb0fea423
15 changed files with 1143 additions and 641 deletions
|
@ -1,29 +1,70 @@
|
||||||
# -*- coding: utf-8 -*-
|
from .api_jwk import PyJWK, PyJWKSet
|
||||||
# flake8: noqa
|
from .api_jws import (
|
||||||
|
PyJWS,
|
||||||
"""
|
get_unverified_header,
|
||||||
JSON Web Token implementation
|
register_algorithm,
|
||||||
|
unregister_algorithm,
|
||||||
Minimum implementation based on this spec:
|
|
||||||
http://self-issued.info/docs/draft-jones-json-web-token-01.html
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
__title__ = 'pyjwt'
|
|
||||||
__version__ = '1.4.0'
|
|
||||||
__author__ = 'José Padilla'
|
|
||||||
__license__ = 'MIT'
|
|
||||||
__copyright__ = 'Copyright 2015 José Padilla'
|
|
||||||
|
|
||||||
|
|
||||||
from .api_jwt import (
|
|
||||||
encode, decode, register_algorithm, unregister_algorithm,
|
|
||||||
get_unverified_header, PyJWT
|
|
||||||
)
|
)
|
||||||
from .api_jws import PyJWS
|
from .api_jwt import PyJWT, decode, encode
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
InvalidTokenError, DecodeError, InvalidAudienceError,
|
DecodeError,
|
||||||
ExpiredSignatureError, ImmatureSignatureError, InvalidIssuedAtError,
|
ExpiredSignatureError,
|
||||||
InvalidIssuerError, ExpiredSignature, InvalidAudience, InvalidIssuer,
|
ImmatureSignatureError,
|
||||||
MissingRequiredClaimError
|
InvalidAlgorithmError,
|
||||||
|
InvalidAudienceError,
|
||||||
|
InvalidIssuedAtError,
|
||||||
|
InvalidIssuerError,
|
||||||
|
InvalidKeyError,
|
||||||
|
InvalidSignatureError,
|
||||||
|
InvalidTokenError,
|
||||||
|
MissingRequiredClaimError,
|
||||||
|
PyJWKClientError,
|
||||||
|
PyJWKError,
|
||||||
|
PyJWKSetError,
|
||||||
|
PyJWTError,
|
||||||
)
|
)
|
||||||
|
from .jwks_client import PyJWKClient
|
||||||
|
|
||||||
|
__version__ = "2.2.0"
|
||||||
|
|
||||||
|
__title__ = "PyJWT"
|
||||||
|
__description__ = "JSON Web Token implementation in Python"
|
||||||
|
__url__ = "https://pyjwt.readthedocs.io"
|
||||||
|
__uri__ = __url__
|
||||||
|
__doc__ = __description__ + " <" + __uri__ + ">"
|
||||||
|
|
||||||
|
__author__ = "José Padilla"
|
||||||
|
__email__ = "hello@jpadilla.com"
|
||||||
|
|
||||||
|
__license__ = "MIT"
|
||||||
|
__copyright__ = "Copyright 2015-2020 José Padilla"
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PyJWS",
|
||||||
|
"PyJWT",
|
||||||
|
"PyJWKClient",
|
||||||
|
"PyJWK",
|
||||||
|
"PyJWKSet",
|
||||||
|
"decode",
|
||||||
|
"encode",
|
||||||
|
"get_unverified_header",
|
||||||
|
"register_algorithm",
|
||||||
|
"unregister_algorithm",
|
||||||
|
# Exceptions
|
||||||
|
"DecodeError",
|
||||||
|
"ExpiredSignatureError",
|
||||||
|
"ImmatureSignatureError",
|
||||||
|
"InvalidAlgorithmError",
|
||||||
|
"InvalidAudienceError",
|
||||||
|
"InvalidIssuedAtError",
|
||||||
|
"InvalidIssuerError",
|
||||||
|
"InvalidKeyError",
|
||||||
|
"InvalidSignatureError",
|
||||||
|
"InvalidTokenError",
|
||||||
|
"MissingRequiredClaimError",
|
||||||
|
"PyJWKClientError",
|
||||||
|
"PyJWKError",
|
||||||
|
"PyJWKSetError",
|
||||||
|
"PyJWTError",
|
||||||
|
]
|
||||||
|
|
|
@ -1,135 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
from __future__ import absolute_import, print_function
|
|
||||||
|
|
||||||
import json
|
|
||||||
import optparse
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
|
|
||||||
from . import DecodeError, __package__, __version__, decode, encode
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
|
|
||||||
usage = '''Encodes or decodes JSON Web Tokens based on input.
|
|
||||||
|
|
||||||
%prog [options] input
|
|
||||||
|
|
||||||
Decoding examples:
|
|
||||||
|
|
||||||
%prog --key=secret json.web.token
|
|
||||||
%prog --no-verify json.web.token
|
|
||||||
|
|
||||||
Encoding requires the key option and takes space separated key/value pairs
|
|
||||||
separated by equals (=) as input. Examples:
|
|
||||||
|
|
||||||
%prog --key=secret iss=me exp=1302049071
|
|
||||||
%prog --key=secret foo=bar exp=+10
|
|
||||||
|
|
||||||
The exp key is special and can take an offset to current Unix time.\
|
|
||||||
'''
|
|
||||||
p = optparse.OptionParser(
|
|
||||||
usage=usage,
|
|
||||||
prog=__package__,
|
|
||||||
version='%s %s' % (__package__, __version__),
|
|
||||||
)
|
|
||||||
|
|
||||||
p.add_option(
|
|
||||||
'-n', '--no-verify',
|
|
||||||
action='store_false',
|
|
||||||
dest='verify',
|
|
||||||
default=True,
|
|
||||||
help='ignore signature verification on decode'
|
|
||||||
)
|
|
||||||
|
|
||||||
p.add_option(
|
|
||||||
'--key',
|
|
||||||
dest='key',
|
|
||||||
metavar='KEY',
|
|
||||||
default=None,
|
|
||||||
help='set the secret key to sign with'
|
|
||||||
)
|
|
||||||
|
|
||||||
p.add_option(
|
|
||||||
'--alg',
|
|
||||||
dest='algorithm',
|
|
||||||
metavar='ALG',
|
|
||||||
default='HS256',
|
|
||||||
help='set crypto algorithm to sign with. default=HS256'
|
|
||||||
)
|
|
||||||
|
|
||||||
options, arguments = p.parse_args()
|
|
||||||
|
|
||||||
if len(arguments) > 0 or not sys.stdin.isatty():
|
|
||||||
if len(arguments) == 1 and (not options.verify or options.key):
|
|
||||||
# Try to decode
|
|
||||||
try:
|
|
||||||
if not sys.stdin.isatty():
|
|
||||||
token = sys.stdin.read()
|
|
||||||
else:
|
|
||||||
token = arguments[0]
|
|
||||||
|
|
||||||
token = token.encode('utf-8')
|
|
||||||
data = decode(token, key=options.key, verify=options.verify)
|
|
||||||
|
|
||||||
print(json.dumps(data))
|
|
||||||
sys.exit(0)
|
|
||||||
except DecodeError as e:
|
|
||||||
print(e)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# Try to encode
|
|
||||||
if options.key is None:
|
|
||||||
print('Key is required when encoding. See --help for usage.')
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# Build payload object to encode
|
|
||||||
payload = {}
|
|
||||||
|
|
||||||
for arg in arguments:
|
|
||||||
try:
|
|
||||||
k, v = arg.split('=', 1)
|
|
||||||
|
|
||||||
# exp +offset special case?
|
|
||||||
if k == 'exp' and v[0] == '+' and len(v) > 1:
|
|
||||||
v = str(int(time.time()+int(v[1:])))
|
|
||||||
|
|
||||||
# Cast to integer?
|
|
||||||
if v.isdigit():
|
|
||||||
v = int(v)
|
|
||||||
else:
|
|
||||||
# Cast to float?
|
|
||||||
try:
|
|
||||||
v = float(v)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Cast to true, false, or null?
|
|
||||||
constants = {'true': True, 'false': False, 'null': None}
|
|
||||||
|
|
||||||
if v in constants:
|
|
||||||
v = constants[v]
|
|
||||||
|
|
||||||
payload[k] = v
|
|
||||||
except ValueError:
|
|
||||||
print('Invalid encoding input at {}'.format(arg))
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
try:
|
|
||||||
token = encode(
|
|
||||||
payload,
|
|
||||||
key=options.key,
|
|
||||||
algorithm=options.algorithm
|
|
||||||
)
|
|
||||||
|
|
||||||
print(token)
|
|
||||||
sys.exit(0)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
sys.exit(1)
|
|
||||||
else:
|
|
||||||
p.print_help()
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
|
@ -1,61 +1,114 @@
|
||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
|
import json
|
||||||
|
|
||||||
from .compat import constant_time_compare, string_types, text_type
|
|
||||||
from .exceptions import InvalidKeyError
|
from .exceptions import InvalidKeyError
|
||||||
from .utils import der_to_raw_signature, raw_to_der_signature
|
from .utils import (
|
||||||
|
base64url_decode,
|
||||||
|
base64url_encode,
|
||||||
|
der_to_raw_signature,
|
||||||
|
force_bytes,
|
||||||
|
from_base64url_uint,
|
||||||
|
raw_to_der_signature,
|
||||||
|
to_base64url_uint,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
import cryptography.exceptions
|
||||||
|
from cryptography.exceptions import InvalidSignature
|
||||||
from cryptography.hazmat.primitives import hashes
|
from cryptography.hazmat.primitives import hashes
|
||||||
from cryptography.hazmat.primitives.serialization import (
|
from cryptography.hazmat.primitives.asymmetric import ec, padding
|
||||||
load_pem_private_key, load_pem_public_key, load_ssh_public_key
|
from cryptography.hazmat.primitives.asymmetric.ec import (
|
||||||
|
EllipticCurvePrivateKey,
|
||||||
|
EllipticCurvePublicKey,
|
||||||
|
)
|
||||||
|
from cryptography.hazmat.primitives.asymmetric.ed448 import (
|
||||||
|
Ed448PrivateKey,
|
||||||
|
Ed448PublicKey,
|
||||||
|
)
|
||||||
|
from cryptography.hazmat.primitives.asymmetric.ed25519 import (
|
||||||
|
Ed25519PrivateKey,
|
||||||
|
Ed25519PublicKey,
|
||||||
)
|
)
|
||||||
from cryptography.hazmat.primitives.asymmetric.rsa import (
|
from cryptography.hazmat.primitives.asymmetric.rsa import (
|
||||||
RSAPrivateKey, RSAPublicKey
|
RSAPrivateKey,
|
||||||
|
RSAPrivateNumbers,
|
||||||
|
RSAPublicKey,
|
||||||
|
RSAPublicNumbers,
|
||||||
|
rsa_crt_dmp1,
|
||||||
|
rsa_crt_dmq1,
|
||||||
|
rsa_crt_iqmp,
|
||||||
|
rsa_recover_prime_factors,
|
||||||
)
|
)
|
||||||
from cryptography.hazmat.primitives.asymmetric.ec import (
|
from cryptography.hazmat.primitives.serialization import (
|
||||||
EllipticCurvePrivateKey, EllipticCurvePublicKey
|
Encoding,
|
||||||
|
NoEncryption,
|
||||||
|
PrivateFormat,
|
||||||
|
PublicFormat,
|
||||||
|
load_pem_private_key,
|
||||||
|
load_pem_public_key,
|
||||||
|
load_ssh_public_key,
|
||||||
)
|
)
|
||||||
from cryptography.hazmat.primitives.asymmetric import ec, padding
|
|
||||||
from cryptography.hazmat.backends import default_backend
|
|
||||||
from cryptography.exceptions import InvalidSignature
|
|
||||||
|
|
||||||
has_crypto = True
|
has_crypto = True
|
||||||
except ImportError:
|
except ModuleNotFoundError:
|
||||||
has_crypto = False
|
has_crypto = False
|
||||||
|
|
||||||
|
requires_cryptography = {
|
||||||
|
"RS256",
|
||||||
|
"RS384",
|
||||||
|
"RS512",
|
||||||
|
"ES256",
|
||||||
|
"ES256K",
|
||||||
|
"ES384",
|
||||||
|
"ES521",
|
||||||
|
"ES512",
|
||||||
|
"PS256",
|
||||||
|
"PS384",
|
||||||
|
"PS512",
|
||||||
|
"EdDSA",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_default_algorithms():
|
def get_default_algorithms():
|
||||||
"""
|
"""
|
||||||
Returns the algorithms that are implemented by the library.
|
Returns the algorithms that are implemented by the library.
|
||||||
"""
|
"""
|
||||||
default_algorithms = {
|
default_algorithms = {
|
||||||
'none': NoneAlgorithm(),
|
"none": NoneAlgorithm(),
|
||||||
'HS256': HMACAlgorithm(HMACAlgorithm.SHA256),
|
"HS256": HMACAlgorithm(HMACAlgorithm.SHA256),
|
||||||
'HS384': HMACAlgorithm(HMACAlgorithm.SHA384),
|
"HS384": HMACAlgorithm(HMACAlgorithm.SHA384),
|
||||||
'HS512': HMACAlgorithm(HMACAlgorithm.SHA512)
|
"HS512": HMACAlgorithm(HMACAlgorithm.SHA512),
|
||||||
}
|
}
|
||||||
|
|
||||||
if has_crypto:
|
if has_crypto:
|
||||||
default_algorithms.update({
|
default_algorithms.update(
|
||||||
'RS256': RSAAlgorithm(RSAAlgorithm.SHA256),
|
{
|
||||||
'RS384': RSAAlgorithm(RSAAlgorithm.SHA384),
|
"RS256": RSAAlgorithm(RSAAlgorithm.SHA256),
|
||||||
'RS512': RSAAlgorithm(RSAAlgorithm.SHA512),
|
"RS384": RSAAlgorithm(RSAAlgorithm.SHA384),
|
||||||
'ES256': ECAlgorithm(ECAlgorithm.SHA256),
|
"RS512": RSAAlgorithm(RSAAlgorithm.SHA512),
|
||||||
'ES384': ECAlgorithm(ECAlgorithm.SHA384),
|
"ES256": ECAlgorithm(ECAlgorithm.SHA256),
|
||||||
'ES512': ECAlgorithm(ECAlgorithm.SHA512),
|
"ES256K": ECAlgorithm(ECAlgorithm.SHA256),
|
||||||
'PS256': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
|
"ES384": ECAlgorithm(ECAlgorithm.SHA384),
|
||||||
'PS384': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
|
"ES521": ECAlgorithm(ECAlgorithm.SHA512),
|
||||||
'PS512': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512)
|
"ES512": ECAlgorithm(
|
||||||
})
|
ECAlgorithm.SHA512
|
||||||
|
), # Backward compat for #219 fix
|
||||||
|
"PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256),
|
||||||
|
"PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384),
|
||||||
|
"PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512),
|
||||||
|
"EdDSA": OKPAlgorithm(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return default_algorithms
|
return default_algorithms
|
||||||
|
|
||||||
|
|
||||||
class Algorithm(object):
|
class Algorithm:
|
||||||
"""
|
"""
|
||||||
The interface for an algorithm used to sign and verify tokens.
|
The interface for an algorithm used to sign and verify tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def prepare_key(self, key):
|
def prepare_key(self, key):
|
||||||
"""
|
"""
|
||||||
Performs necessary validation and conversions on the key and returns
|
Performs necessary validation and conversions on the key and returns
|
||||||
|
@ -77,14 +130,29 @@ class Algorithm(object):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def to_jwk(key_obj):
|
||||||
|
"""
|
||||||
|
Serializes a given RSA key into a JWK
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_jwk(jwk):
|
||||||
|
"""
|
||||||
|
Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class NoneAlgorithm(Algorithm):
|
class NoneAlgorithm(Algorithm):
|
||||||
"""
|
"""
|
||||||
Placeholder for use when no signing or verification
|
Placeholder for use when no signing or verification
|
||||||
operations are required.
|
operations are required.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def prepare_key(self, key):
|
def prepare_key(self, key):
|
||||||
if key == '':
|
if key == "":
|
||||||
key = None
|
key = None
|
||||||
|
|
||||||
if key is not None:
|
if key is not None:
|
||||||
|
@ -93,7 +161,7 @@ class NoneAlgorithm(Algorithm):
|
||||||
return key
|
return key
|
||||||
|
|
||||||
def sign(self, msg, key):
|
def sign(self, msg, key):
|
||||||
return b''
|
return b""
|
||||||
|
|
||||||
def verify(self, msg, key, sig):
|
def verify(self, msg, key, sig):
|
||||||
return False
|
return False
|
||||||
|
@ -104,6 +172,7 @@ class HMACAlgorithm(Algorithm):
|
||||||
Performs signing and verification operations using HMAC
|
Performs signing and verification operations using HMAC
|
||||||
and the specified hash function.
|
and the specified hash function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SHA256 = hashlib.sha256
|
SHA256 = hashlib.sha256
|
||||||
SHA384 = hashlib.sha384
|
SHA384 = hashlib.sha384
|
||||||
SHA512 = hashlib.sha512
|
SHA512 = hashlib.sha512
|
||||||
|
@ -112,30 +181,55 @@ class HMACAlgorithm(Algorithm):
|
||||||
self.hash_alg = hash_alg
|
self.hash_alg = hash_alg
|
||||||
|
|
||||||
def prepare_key(self, key):
|
def prepare_key(self, key):
|
||||||
if not isinstance(key, string_types) and not isinstance(key, bytes):
|
key = force_bytes(key)
|
||||||
raise TypeError('Expecting a string- or bytes-formatted key.')
|
|
||||||
|
|
||||||
if isinstance(key, text_type):
|
|
||||||
key = key.encode('utf-8')
|
|
||||||
|
|
||||||
invalid_strings = [
|
invalid_strings = [
|
||||||
b'-----BEGIN PUBLIC KEY-----',
|
b"-----BEGIN PUBLIC KEY-----",
|
||||||
b'-----BEGIN CERTIFICATE-----',
|
b"-----BEGIN CERTIFICATE-----",
|
||||||
b'ssh-rsa'
|
b"-----BEGIN RSA PUBLIC KEY-----",
|
||||||
|
b"ssh-rsa",
|
||||||
]
|
]
|
||||||
|
|
||||||
if any([string_value in key for string_value in invalid_strings]):
|
if any(string_value in key for string_value in invalid_strings):
|
||||||
raise InvalidKeyError(
|
raise InvalidKeyError(
|
||||||
'The specified key is an asymmetric key or x509 certificate and'
|
"The specified key is an asymmetric key or x509 certificate and"
|
||||||
' should not be used as an HMAC secret.')
|
" should not be used as an HMAC secret."
|
||||||
|
)
|
||||||
|
|
||||||
return key
|
return key
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def to_jwk(key_obj):
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"k": base64url_encode(force_bytes(key_obj)).decode(),
|
||||||
|
"kty": "oct",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_jwk(jwk):
|
||||||
|
try:
|
||||||
|
if isinstance(jwk, str):
|
||||||
|
obj = json.loads(jwk)
|
||||||
|
elif isinstance(jwk, dict):
|
||||||
|
obj = jwk
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
except ValueError:
|
||||||
|
raise InvalidKeyError("Key is not valid JSON")
|
||||||
|
|
||||||
|
if obj.get("kty") != "oct":
|
||||||
|
raise InvalidKeyError("Not an HMAC key")
|
||||||
|
|
||||||
|
return base64url_decode(obj["k"])
|
||||||
|
|
||||||
def sign(self, msg, key):
|
def sign(self, msg, key):
|
||||||
return hmac.new(key, msg, self.hash_alg).digest()
|
return hmac.new(key, msg, self.hash_alg).digest()
|
||||||
|
|
||||||
def verify(self, msg, key, sig):
|
def verify(self, msg, key, sig):
|
||||||
return constant_time_compare(sig, self.sign(msg, key))
|
return hmac.compare_digest(sig, self.sign(msg, key))
|
||||||
|
|
||||||
|
|
||||||
if has_crypto:
|
if has_crypto:
|
||||||
|
|
||||||
|
@ -144,6 +238,7 @@ if has_crypto:
|
||||||
Performs signing and verification operations using
|
Performs signing and verification operations using
|
||||||
RSASSA-PKCS-v1_5 and the specified hash function.
|
RSASSA-PKCS-v1_5 and the specified hash function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SHA256 = hashes.SHA256
|
SHA256 = hashes.SHA256
|
||||||
SHA384 = hashes.SHA384
|
SHA384 = hashes.SHA384
|
||||||
SHA512 = hashes.SHA512
|
SHA512 = hashes.SHA512
|
||||||
|
@ -152,46 +247,139 @@ if has_crypto:
|
||||||
self.hash_alg = hash_alg
|
self.hash_alg = hash_alg
|
||||||
|
|
||||||
def prepare_key(self, key):
|
def prepare_key(self, key):
|
||||||
if isinstance(key, RSAPrivateKey) or \
|
if isinstance(key, (RSAPrivateKey, RSAPublicKey)):
|
||||||
isinstance(key, RSAPublicKey):
|
|
||||||
return key
|
return key
|
||||||
|
|
||||||
if isinstance(key, string_types):
|
if not isinstance(key, (bytes, str)):
|
||||||
if isinstance(key, text_type):
|
raise TypeError("Expecting a PEM-formatted key.")
|
||||||
key = key.encode('utf-8')
|
|
||||||
|
|
||||||
try:
|
key = force_bytes(key)
|
||||||
if key.startswith(b'ssh-rsa'):
|
|
||||||
key = load_ssh_public_key(key, backend=default_backend())
|
|
||||||
else:
|
|
||||||
key = load_pem_private_key(key, password=None, backend=default_backend())
|
|
||||||
except ValueError:
|
|
||||||
key = load_pem_public_key(key, backend=default_backend())
|
|
||||||
else:
|
|
||||||
raise TypeError('Expecting a PEM-formatted key.')
|
|
||||||
|
|
||||||
return key
|
|
||||||
|
|
||||||
def sign(self, msg, key):
|
|
||||||
signer = key.signer(
|
|
||||||
padding.PKCS1v15(),
|
|
||||||
self.hash_alg()
|
|
||||||
)
|
|
||||||
|
|
||||||
signer.update(msg)
|
|
||||||
return signer.finalize()
|
|
||||||
|
|
||||||
def verify(self, msg, key, sig):
|
|
||||||
verifier = key.verifier(
|
|
||||||
sig,
|
|
||||||
padding.PKCS1v15(),
|
|
||||||
self.hash_alg()
|
|
||||||
)
|
|
||||||
|
|
||||||
verifier.update(msg)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
verifier.verify()
|
if key.startswith(b"ssh-rsa"):
|
||||||
|
key = load_ssh_public_key(key)
|
||||||
|
else:
|
||||||
|
key = load_pem_private_key(key, password=None)
|
||||||
|
except ValueError:
|
||||||
|
key = load_pem_public_key(key)
|
||||||
|
return key
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def to_jwk(key_obj):
|
||||||
|
obj = None
|
||||||
|
|
||||||
|
if getattr(key_obj, "private_numbers", None):
|
||||||
|
# Private key
|
||||||
|
numbers = key_obj.private_numbers()
|
||||||
|
|
||||||
|
obj = {
|
||||||
|
"kty": "RSA",
|
||||||
|
"key_ops": ["sign"],
|
||||||
|
"n": to_base64url_uint(numbers.public_numbers.n).decode(),
|
||||||
|
"e": to_base64url_uint(numbers.public_numbers.e).decode(),
|
||||||
|
"d": to_base64url_uint(numbers.d).decode(),
|
||||||
|
"p": to_base64url_uint(numbers.p).decode(),
|
||||||
|
"q": to_base64url_uint(numbers.q).decode(),
|
||||||
|
"dp": to_base64url_uint(numbers.dmp1).decode(),
|
||||||
|
"dq": to_base64url_uint(numbers.dmq1).decode(),
|
||||||
|
"qi": to_base64url_uint(numbers.iqmp).decode(),
|
||||||
|
}
|
||||||
|
|
||||||
|
elif getattr(key_obj, "verify", None):
|
||||||
|
# Public key
|
||||||
|
numbers = key_obj.public_numbers()
|
||||||
|
|
||||||
|
obj = {
|
||||||
|
"kty": "RSA",
|
||||||
|
"key_ops": ["verify"],
|
||||||
|
"n": to_base64url_uint(numbers.n).decode(),
|
||||||
|
"e": to_base64url_uint(numbers.e).decode(),
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise InvalidKeyError("Not a public or private key")
|
||||||
|
|
||||||
|
return json.dumps(obj)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_jwk(jwk):
|
||||||
|
try:
|
||||||
|
if isinstance(jwk, str):
|
||||||
|
obj = json.loads(jwk)
|
||||||
|
elif isinstance(jwk, dict):
|
||||||
|
obj = jwk
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
except ValueError:
|
||||||
|
raise InvalidKeyError("Key is not valid JSON")
|
||||||
|
|
||||||
|
if obj.get("kty") != "RSA":
|
||||||
|
raise InvalidKeyError("Not an RSA key")
|
||||||
|
|
||||||
|
if "d" in obj and "e" in obj and "n" in obj:
|
||||||
|
# Private key
|
||||||
|
if "oth" in obj:
|
||||||
|
raise InvalidKeyError(
|
||||||
|
"Unsupported RSA private key: > 2 primes not supported"
|
||||||
|
)
|
||||||
|
|
||||||
|
other_props = ["p", "q", "dp", "dq", "qi"]
|
||||||
|
props_found = [prop in obj for prop in other_props]
|
||||||
|
any_props_found = any(props_found)
|
||||||
|
|
||||||
|
if any_props_found and not all(props_found):
|
||||||
|
raise InvalidKeyError(
|
||||||
|
"RSA key must include all parameters if any are present besides d"
|
||||||
|
)
|
||||||
|
|
||||||
|
public_numbers = RSAPublicNumbers(
|
||||||
|
from_base64url_uint(obj["e"]),
|
||||||
|
from_base64url_uint(obj["n"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
if any_props_found:
|
||||||
|
numbers = RSAPrivateNumbers(
|
||||||
|
d=from_base64url_uint(obj["d"]),
|
||||||
|
p=from_base64url_uint(obj["p"]),
|
||||||
|
q=from_base64url_uint(obj["q"]),
|
||||||
|
dmp1=from_base64url_uint(obj["dp"]),
|
||||||
|
dmq1=from_base64url_uint(obj["dq"]),
|
||||||
|
iqmp=from_base64url_uint(obj["qi"]),
|
||||||
|
public_numbers=public_numbers,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
d = from_base64url_uint(obj["d"])
|
||||||
|
p, q = rsa_recover_prime_factors(
|
||||||
|
public_numbers.n, d, public_numbers.e
|
||||||
|
)
|
||||||
|
|
||||||
|
numbers = RSAPrivateNumbers(
|
||||||
|
d=d,
|
||||||
|
p=p,
|
||||||
|
q=q,
|
||||||
|
dmp1=rsa_crt_dmp1(d, p),
|
||||||
|
dmq1=rsa_crt_dmq1(d, q),
|
||||||
|
iqmp=rsa_crt_iqmp(p, q),
|
||||||
|
public_numbers=public_numbers,
|
||||||
|
)
|
||||||
|
|
||||||
|
return numbers.private_key()
|
||||||
|
elif "n" in obj and "e" in obj:
|
||||||
|
# Public key
|
||||||
|
numbers = RSAPublicNumbers(
|
||||||
|
from_base64url_uint(obj["e"]),
|
||||||
|
from_base64url_uint(obj["n"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
return numbers.public_key()
|
||||||
|
else:
|
||||||
|
raise InvalidKeyError("Not a public or private key")
|
||||||
|
|
||||||
|
def sign(self, msg, key):
|
||||||
|
return key.sign(msg, padding.PKCS1v15(), self.hash_alg())
|
||||||
|
|
||||||
|
def verify(self, msg, key, sig):
|
||||||
|
try:
|
||||||
|
key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg())
|
||||||
return True
|
return True
|
||||||
except InvalidSignature:
|
except InvalidSignature:
|
||||||
return False
|
return False
|
||||||
|
@ -201,6 +389,7 @@ if has_crypto:
|
||||||
Performs signing and verification operations using
|
Performs signing and verification operations using
|
||||||
ECDSA and the specified hash function
|
ECDSA and the specified hash function
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SHA256 = hashes.SHA256
|
SHA256 = hashes.SHA256
|
||||||
SHA384 = hashes.SHA384
|
SHA384 = hashes.SHA384
|
||||||
SHA512 = hashes.SHA512
|
SHA512 = hashes.SHA512
|
||||||
|
@ -209,32 +398,29 @@ if has_crypto:
|
||||||
self.hash_alg = hash_alg
|
self.hash_alg = hash_alg
|
||||||
|
|
||||||
def prepare_key(self, key):
|
def prepare_key(self, key):
|
||||||
if isinstance(key, EllipticCurvePrivateKey) or \
|
if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)):
|
||||||
isinstance(key, EllipticCurvePublicKey):
|
|
||||||
return key
|
return key
|
||||||
|
|
||||||
if isinstance(key, string_types):
|
if not isinstance(key, (bytes, str)):
|
||||||
if isinstance(key, text_type):
|
raise TypeError("Expecting a PEM-formatted key.")
|
||||||
key = key.encode('utf-8')
|
|
||||||
|
|
||||||
# Attempt to load key. We don't know if it's
|
key = force_bytes(key)
|
||||||
# a Signing Key or a Verifying Key, so we try
|
|
||||||
# the Verifying Key first.
|
|
||||||
try:
|
|
||||||
key = load_pem_public_key(key, backend=default_backend())
|
|
||||||
except ValueError:
|
|
||||||
key = load_pem_private_key(key, password=None, backend=default_backend())
|
|
||||||
|
|
||||||
else:
|
# Attempt to load key. We don't know if it's
|
||||||
raise TypeError('Expecting a PEM-formatted key.')
|
# a Signing Key or a Verifying Key, so we try
|
||||||
|
# the Verifying Key first.
|
||||||
|
try:
|
||||||
|
if key.startswith(b"ecdsa-sha2-"):
|
||||||
|
key = load_ssh_public_key(key)
|
||||||
|
else:
|
||||||
|
key = load_pem_public_key(key)
|
||||||
|
except ValueError:
|
||||||
|
key = load_pem_private_key(key, password=None)
|
||||||
|
|
||||||
return key
|
return key
|
||||||
|
|
||||||
def sign(self, msg, key):
|
def sign(self, msg, key):
|
||||||
signer = key.signer(ec.ECDSA(self.hash_alg()))
|
der_sig = key.sign(msg, ec.ECDSA(self.hash_alg()))
|
||||||
|
|
||||||
signer.update(msg)
|
|
||||||
der_sig = signer.finalize()
|
|
||||||
|
|
||||||
return der_to_raw_signature(der_sig, key.curve)
|
return der_to_raw_signature(der_sig, key.curve)
|
||||||
|
|
||||||
|
@ -244,47 +430,245 @@ if has_crypto:
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
verifier = key.verifier(der_sig, ec.ECDSA(self.hash_alg()))
|
|
||||||
|
|
||||||
verifier.update(msg)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
verifier.verify()
|
if isinstance(key, EllipticCurvePrivateKey):
|
||||||
|
key = key.public_key()
|
||||||
|
key.verify(der_sig, msg, ec.ECDSA(self.hash_alg()))
|
||||||
return True
|
return True
|
||||||
except InvalidSignature:
|
except InvalidSignature:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_jwk(jwk):
|
||||||
|
try:
|
||||||
|
if isinstance(jwk, str):
|
||||||
|
obj = json.loads(jwk)
|
||||||
|
elif isinstance(jwk, dict):
|
||||||
|
obj = jwk
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
except ValueError:
|
||||||
|
raise InvalidKeyError("Key is not valid JSON")
|
||||||
|
|
||||||
|
if obj.get("kty") != "EC":
|
||||||
|
raise InvalidKeyError("Not an Elliptic curve key")
|
||||||
|
|
||||||
|
if "x" not in obj or "y" not in obj:
|
||||||
|
raise InvalidKeyError("Not an Elliptic curve key")
|
||||||
|
|
||||||
|
x = base64url_decode(obj.get("x"))
|
||||||
|
y = base64url_decode(obj.get("y"))
|
||||||
|
|
||||||
|
curve = obj.get("crv")
|
||||||
|
if curve == "P-256":
|
||||||
|
if len(x) == len(y) == 32:
|
||||||
|
curve_obj = ec.SECP256R1()
|
||||||
|
else:
|
||||||
|
raise InvalidKeyError("Coords should be 32 bytes for curve P-256")
|
||||||
|
elif curve == "P-384":
|
||||||
|
if len(x) == len(y) == 48:
|
||||||
|
curve_obj = ec.SECP384R1()
|
||||||
|
else:
|
||||||
|
raise InvalidKeyError("Coords should be 48 bytes for curve P-384")
|
||||||
|
elif curve == "P-521":
|
||||||
|
if len(x) == len(y) == 66:
|
||||||
|
curve_obj = ec.SECP521R1()
|
||||||
|
else:
|
||||||
|
raise InvalidKeyError("Coords should be 66 bytes for curve P-521")
|
||||||
|
elif curve == "secp256k1":
|
||||||
|
if len(x) == len(y) == 32:
|
||||||
|
curve_obj = ec.SECP256K1()
|
||||||
|
else:
|
||||||
|
raise InvalidKeyError(
|
||||||
|
"Coords should be 32 bytes for curve secp256k1"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise InvalidKeyError(f"Invalid curve: {curve}")
|
||||||
|
|
||||||
|
public_numbers = ec.EllipticCurvePublicNumbers(
|
||||||
|
x=int.from_bytes(x, byteorder="big"),
|
||||||
|
y=int.from_bytes(y, byteorder="big"),
|
||||||
|
curve=curve_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
if "d" not in obj:
|
||||||
|
return public_numbers.public_key()
|
||||||
|
|
||||||
|
d = base64url_decode(obj.get("d"))
|
||||||
|
if len(d) != len(x):
|
||||||
|
raise InvalidKeyError(
|
||||||
|
"D should be {} bytes for curve {}", len(x), curve
|
||||||
|
)
|
||||||
|
|
||||||
|
return ec.EllipticCurvePrivateNumbers(
|
||||||
|
int.from_bytes(d, byteorder="big"), public_numbers
|
||||||
|
).private_key()
|
||||||
|
|
||||||
class RSAPSSAlgorithm(RSAAlgorithm):
|
class RSAPSSAlgorithm(RSAAlgorithm):
|
||||||
"""
|
"""
|
||||||
Performs a signature using RSASSA-PSS with MGF1
|
Performs a signature using RSASSA-PSS with MGF1
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def sign(self, msg, key):
|
def sign(self, msg, key):
|
||||||
signer = key.signer(
|
return key.sign(
|
||||||
|
msg,
|
||||||
padding.PSS(
|
padding.PSS(
|
||||||
mgf=padding.MGF1(self.hash_alg()),
|
mgf=padding.MGF1(self.hash_alg()),
|
||||||
salt_length=self.hash_alg.digest_size
|
salt_length=self.hash_alg.digest_size,
|
||||||
),
|
),
|
||||||
self.hash_alg()
|
self.hash_alg(),
|
||||||
)
|
)
|
||||||
|
|
||||||
signer.update(msg)
|
|
||||||
return signer.finalize()
|
|
||||||
|
|
||||||
def verify(self, msg, key, sig):
|
def verify(self, msg, key, sig):
|
||||||
verifier = key.verifier(
|
|
||||||
sig,
|
|
||||||
padding.PSS(
|
|
||||||
mgf=padding.MGF1(self.hash_alg()),
|
|
||||||
salt_length=self.hash_alg.digest_size
|
|
||||||
),
|
|
||||||
self.hash_alg()
|
|
||||||
)
|
|
||||||
|
|
||||||
verifier.update(msg)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
verifier.verify()
|
key.verify(
|
||||||
|
sig,
|
||||||
|
msg,
|
||||||
|
padding.PSS(
|
||||||
|
mgf=padding.MGF1(self.hash_alg()),
|
||||||
|
salt_length=self.hash_alg.digest_size,
|
||||||
|
),
|
||||||
|
self.hash_alg(),
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
except InvalidSignature:
|
except InvalidSignature:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
class OKPAlgorithm(Algorithm):
|
||||||
|
"""
|
||||||
|
Performs signing and verification operations using EdDSA
|
||||||
|
|
||||||
|
This class requires ``cryptography>=2.6`` to be installed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def prepare_key(self, key):
|
||||||
|
|
||||||
|
if isinstance(
|
||||||
|
key,
|
||||||
|
(Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey),
|
||||||
|
):
|
||||||
|
return key
|
||||||
|
|
||||||
|
if isinstance(key, (bytes, str)):
|
||||||
|
if isinstance(key, str):
|
||||||
|
key = key.encode("utf-8")
|
||||||
|
str_key = key.decode("utf-8")
|
||||||
|
|
||||||
|
if "-----BEGIN PUBLIC" in str_key:
|
||||||
|
return load_pem_public_key(key)
|
||||||
|
if "-----BEGIN PRIVATE" in str_key:
|
||||||
|
return load_pem_private_key(key, password=None)
|
||||||
|
if str_key[0:4] == "ssh-":
|
||||||
|
return load_ssh_public_key(key)
|
||||||
|
|
||||||
|
raise TypeError("Expecting a PEM-formatted or OpenSSH key.")
|
||||||
|
|
||||||
|
def sign(self, msg, key):
|
||||||
|
"""
|
||||||
|
Sign a message ``msg`` using the EdDSA private key ``key``
|
||||||
|
:param str|bytes msg: Message to sign
|
||||||
|
:param Ed25519PrivateKey}Ed448PrivateKey key: A :class:`.Ed25519PrivateKey`
|
||||||
|
or :class:`.Ed448PrivateKey` iinstance
|
||||||
|
:return bytes signature: The signature, as bytes
|
||||||
|
"""
|
||||||
|
msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg
|
||||||
|
return key.sign(msg)
|
||||||
|
|
||||||
|
def verify(self, msg, key, sig):
|
||||||
|
"""
|
||||||
|
Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key``
|
||||||
|
|
||||||
|
:param str|bytes sig: EdDSA signature to check ``msg`` against
|
||||||
|
:param str|bytes msg: Message to sign
|
||||||
|
:param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key:
|
||||||
|
A private or public EdDSA key instance
|
||||||
|
:return bool verified: True if signature is valid, False if not.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg
|
||||||
|
sig = bytes(sig, "utf-8") if type(sig) is not bytes else sig
|
||||||
|
|
||||||
|
if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
|
||||||
|
key = key.public_key()
|
||||||
|
key.verify(sig, msg)
|
||||||
|
return True # If no exception was raised, the signature is valid.
|
||||||
|
except cryptography.exceptions.InvalidSignature:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def to_jwk(key):
|
||||||
|
if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
|
||||||
|
x = key.public_bytes(
|
||||||
|
encoding=Encoding.Raw,
|
||||||
|
format=PublicFormat.Raw,
|
||||||
|
)
|
||||||
|
crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448"
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"x": base64url_encode(force_bytes(x)).decode(),
|
||||||
|
"kty": "OKP",
|
||||||
|
"crv": crv,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)):
|
||||||
|
d = key.private_bytes(
|
||||||
|
encoding=Encoding.Raw,
|
||||||
|
format=PrivateFormat.Raw,
|
||||||
|
encryption_algorithm=NoEncryption(),
|
||||||
|
)
|
||||||
|
|
||||||
|
x = key.public_key().public_bytes(
|
||||||
|
encoding=Encoding.Raw,
|
||||||
|
format=PublicFormat.Raw,
|
||||||
|
)
|
||||||
|
|
||||||
|
crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448"
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"x": base64url_encode(force_bytes(x)).decode(),
|
||||||
|
"d": base64url_encode(force_bytes(d)).decode(),
|
||||||
|
"kty": "OKP",
|
||||||
|
"crv": crv,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
raise InvalidKeyError("Not a public or private key")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_jwk(jwk):
|
||||||
|
try:
|
||||||
|
if isinstance(jwk, str):
|
||||||
|
obj = json.loads(jwk)
|
||||||
|
elif isinstance(jwk, dict):
|
||||||
|
obj = jwk
|
||||||
|
else:
|
||||||
|
raise ValueError
|
||||||
|
except ValueError:
|
||||||
|
raise InvalidKeyError("Key is not valid JSON")
|
||||||
|
|
||||||
|
if obj.get("kty") != "OKP":
|
||||||
|
raise InvalidKeyError("Not an Octet Key Pair")
|
||||||
|
|
||||||
|
curve = obj.get("crv")
|
||||||
|
if curve != "Ed25519" and curve != "Ed448":
|
||||||
|
raise InvalidKeyError(f"Invalid curve: {curve}")
|
||||||
|
|
||||||
|
if "x" not in obj:
|
||||||
|
raise InvalidKeyError('OKP should have "x" parameter')
|
||||||
|
x = base64url_decode(obj.get("x"))
|
||||||
|
|
||||||
|
try:
|
||||||
|
if "d" not in obj:
|
||||||
|
if curve == "Ed25519":
|
||||||
|
return Ed25519PublicKey.from_public_bytes(x)
|
||||||
|
return Ed448PublicKey.from_public_bytes(x)
|
||||||
|
d = base64url_decode(obj.get("d"))
|
||||||
|
if curve == "Ed25519":
|
||||||
|
return Ed25519PrivateKey.from_private_bytes(d)
|
||||||
|
return Ed448PrivateKey.from_private_bytes(d)
|
||||||
|
except ValueError as err:
|
||||||
|
raise InvalidKeyError("Invalid key parameter") from err
|
||||||
|
|
97
lib/jwt/api_jwk.py
Normal file
97
lib/jwt/api_jwk.py
Normal file
|
@ -0,0 +1,97 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
from .algorithms import get_default_algorithms
|
||||||
|
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError
|
||||||
|
|
||||||
|
|
||||||
|
class PyJWK:
|
||||||
|
def __init__(self, jwk_data, algorithm=None):
|
||||||
|
self._algorithms = get_default_algorithms()
|
||||||
|
self._jwk_data = jwk_data
|
||||||
|
|
||||||
|
kty = self._jwk_data.get("kty", None)
|
||||||
|
if not kty:
|
||||||
|
raise InvalidKeyError("kty is not found: %s" % self._jwk_data)
|
||||||
|
|
||||||
|
if not algorithm and isinstance(self._jwk_data, dict):
|
||||||
|
algorithm = self._jwk_data.get("alg", None)
|
||||||
|
|
||||||
|
if not algorithm:
|
||||||
|
# Determine alg with kty (and crv).
|
||||||
|
crv = self._jwk_data.get("crv", None)
|
||||||
|
if kty == "EC":
|
||||||
|
if crv == "P-256" or not crv:
|
||||||
|
algorithm = "ES256"
|
||||||
|
elif crv == "P-384":
|
||||||
|
algorithm = "ES384"
|
||||||
|
elif crv == "P-521":
|
||||||
|
algorithm = "ES512"
|
||||||
|
elif crv == "secp256k1":
|
||||||
|
algorithm = "ES256K"
|
||||||
|
else:
|
||||||
|
raise InvalidKeyError("Unsupported crv: %s" % crv)
|
||||||
|
elif kty == "RSA":
|
||||||
|
algorithm = "RS256"
|
||||||
|
elif kty == "oct":
|
||||||
|
algorithm = "HS256"
|
||||||
|
elif kty == "OKP":
|
||||||
|
if not crv:
|
||||||
|
raise InvalidKeyError("crv is not found: %s" % self._jwk_data)
|
||||||
|
if crv == "Ed25519":
|
||||||
|
algorithm = "EdDSA"
|
||||||
|
else:
|
||||||
|
raise InvalidKeyError("Unsupported crv: %s" % crv)
|
||||||
|
else:
|
||||||
|
raise InvalidKeyError("Unsupported kty: %s" % kty)
|
||||||
|
|
||||||
|
self.Algorithm = self._algorithms.get(algorithm)
|
||||||
|
|
||||||
|
if not self.Algorithm:
|
||||||
|
raise PyJWKError("Unable to find a algorithm for key: %s" % self._jwk_data)
|
||||||
|
|
||||||
|
self.key = self.Algorithm.from_jwk(self._jwk_data)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_dict(obj, algorithm=None):
|
||||||
|
return PyJWK(obj, algorithm)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_json(data, algorithm=None):
|
||||||
|
obj = json.loads(data)
|
||||||
|
return PyJWK.from_dict(obj, algorithm)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def key_type(self):
|
||||||
|
return self._jwk_data.get("kty", None)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def key_id(self):
|
||||||
|
return self._jwk_data.get("kid", None)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def public_key_use(self):
|
||||||
|
return self._jwk_data.get("use", None)
|
||||||
|
|
||||||
|
|
||||||
|
class PyJWKSet:
|
||||||
|
def __init__(self, keys):
|
||||||
|
self.keys = []
|
||||||
|
|
||||||
|
if not keys or not isinstance(keys, list):
|
||||||
|
raise PyJWKSetError("Invalid JWK Set value")
|
||||||
|
|
||||||
|
if len(keys) == 0:
|
||||||
|
raise PyJWKSetError("The JWK Set did not contain any keys")
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
self.keys.append(PyJWK(key))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_dict(obj):
|
||||||
|
keys = obj.get("keys", [])
|
||||||
|
return PyJWKSet(keys)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_json(data):
|
||||||
|
obj = json.loads(data)
|
||||||
|
return PyJWKSet.from_dict(obj)
|
|
@ -1,48 +1,54 @@
|
||||||
import binascii
|
import binascii
|
||||||
import json
|
import json
|
||||||
import warnings
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Dict, List, Optional, Type
|
||||||
|
|
||||||
from collections import Mapping
|
from .algorithms import (
|
||||||
|
Algorithm,
|
||||||
from .algorithms import Algorithm, get_default_algorithms # NOQA
|
get_default_algorithms,
|
||||||
from .compat import text_type
|
has_crypto,
|
||||||
from .exceptions import DecodeError, InvalidAlgorithmError
|
requires_cryptography,
|
||||||
from .utils import base64url_decode, base64url_encode, merge_dict
|
)
|
||||||
|
from .exceptions import (
|
||||||
|
DecodeError,
|
||||||
|
InvalidAlgorithmError,
|
||||||
|
InvalidSignatureError,
|
||||||
|
InvalidTokenError,
|
||||||
|
)
|
||||||
|
from .utils import base64url_decode, base64url_encode
|
||||||
|
|
||||||
|
|
||||||
class PyJWS(object):
|
class PyJWS:
|
||||||
header_typ = 'JWT'
|
header_typ = "JWT"
|
||||||
|
|
||||||
def __init__(self, algorithms=None, options=None):
|
def __init__(self, algorithms=None, options=None):
|
||||||
self._algorithms = get_default_algorithms()
|
self._algorithms = get_default_algorithms()
|
||||||
self._valid_algs = (set(algorithms) if algorithms is not None
|
self._valid_algs = (
|
||||||
else set(self._algorithms))
|
set(algorithms) if algorithms is not None else set(self._algorithms)
|
||||||
|
)
|
||||||
|
|
||||||
# Remove algorithms that aren't on the whitelist
|
# Remove algorithms that aren't on the whitelist
|
||||||
for key in list(self._algorithms.keys()):
|
for key in list(self._algorithms.keys()):
|
||||||
if key not in self._valid_algs:
|
if key not in self._valid_algs:
|
||||||
del self._algorithms[key]
|
del self._algorithms[key]
|
||||||
|
|
||||||
if not options:
|
if options is None:
|
||||||
options = {}
|
options = {}
|
||||||
|
self.options = {**self._get_default_options(), **options}
|
||||||
self.options = merge_dict(self._get_default_options(), options)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_default_options():
|
def _get_default_options():
|
||||||
return {
|
return {"verify_signature": True}
|
||||||
'verify_signature': True
|
|
||||||
}
|
|
||||||
|
|
||||||
def register_algorithm(self, alg_id, alg_obj):
|
def register_algorithm(self, alg_id, alg_obj):
|
||||||
"""
|
"""
|
||||||
Registers a new Algorithm for use when creating and verifying tokens.
|
Registers a new Algorithm for use when creating and verifying tokens.
|
||||||
"""
|
"""
|
||||||
if alg_id in self._algorithms:
|
if alg_id in self._algorithms:
|
||||||
raise ValueError('Algorithm already has a handler.')
|
raise ValueError("Algorithm already has a handler.")
|
||||||
|
|
||||||
if not isinstance(alg_obj, Algorithm):
|
if not isinstance(alg_obj, Algorithm):
|
||||||
raise TypeError('Object is not of type `Algorithm`')
|
raise TypeError("Object is not of type `Algorithm`")
|
||||||
|
|
||||||
self._algorithms[alg_id] = alg_obj
|
self._algorithms[alg_id] = alg_obj
|
||||||
self._valid_algs.add(alg_id)
|
self._valid_algs.add(alg_id)
|
||||||
|
@ -53,8 +59,10 @@ class PyJWS(object):
|
||||||
Throws KeyError if algorithm is not registered.
|
Throws KeyError if algorithm is not registered.
|
||||||
"""
|
"""
|
||||||
if alg_id not in self._algorithms:
|
if alg_id not in self._algorithms:
|
||||||
raise KeyError('The specified algorithm could not be removed'
|
raise KeyError(
|
||||||
' because it is not registered.')
|
"The specified algorithm could not be removed"
|
||||||
|
" because it is not registered."
|
||||||
|
)
|
||||||
|
|
||||||
del self._algorithms[alg_id]
|
del self._algorithms[alg_id]
|
||||||
self._valid_algs.remove(alg_id)
|
self._valid_algs.remove(alg_id)
|
||||||
|
@ -65,59 +73,98 @@ class PyJWS(object):
|
||||||
"""
|
"""
|
||||||
return list(self._valid_algs)
|
return list(self._valid_algs)
|
||||||
|
|
||||||
def encode(self, payload, key, algorithm='HS256', headers=None,
|
def encode(
|
||||||
json_encoder=None):
|
self,
|
||||||
|
payload: bytes,
|
||||||
|
key: str,
|
||||||
|
algorithm: Optional[str] = "HS256",
|
||||||
|
headers: Optional[Dict] = None,
|
||||||
|
json_encoder: Optional[Type[json.JSONEncoder]] = None,
|
||||||
|
) -> str:
|
||||||
segments = []
|
segments = []
|
||||||
|
|
||||||
if algorithm is None:
|
if algorithm is None:
|
||||||
algorithm = 'none'
|
algorithm = "none"
|
||||||
|
|
||||||
if algorithm not in self._valid_algs:
|
# Prefer headers["alg"] if present to algorithm parameter.
|
||||||
pass
|
if headers and "alg" in headers and headers["alg"]:
|
||||||
|
algorithm = headers["alg"]
|
||||||
|
|
||||||
# Header
|
# Header
|
||||||
header = {'typ': self.header_typ, 'alg': algorithm}
|
header = {"typ": self.header_typ, "alg": algorithm}
|
||||||
|
|
||||||
if headers:
|
if headers:
|
||||||
|
self._validate_headers(headers)
|
||||||
header.update(headers)
|
header.update(headers)
|
||||||
|
if not header["typ"]:
|
||||||
|
del header["typ"]
|
||||||
|
|
||||||
json_header = json.dumps(
|
json_header = json.dumps(
|
||||||
header,
|
header, separators=(",", ":"), cls=json_encoder
|
||||||
separators=(',', ':'),
|
).encode()
|
||||||
cls=json_encoder
|
|
||||||
).encode('utf-8')
|
|
||||||
|
|
||||||
segments.append(base64url_encode(json_header))
|
segments.append(base64url_encode(json_header))
|
||||||
segments.append(base64url_encode(payload))
|
segments.append(base64url_encode(payload))
|
||||||
|
|
||||||
# Segments
|
# Segments
|
||||||
signing_input = b'.'.join(segments)
|
signing_input = b".".join(segments)
|
||||||
try:
|
try:
|
||||||
alg_obj = self._algorithms[algorithm]
|
alg_obj = self._algorithms[algorithm]
|
||||||
key = alg_obj.prepare_key(key)
|
key = alg_obj.prepare_key(key)
|
||||||
signature = alg_obj.sign(signing_input, key)
|
signature = alg_obj.sign(signing_input, key)
|
||||||
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise NotImplementedError('Algorithm not supported')
|
if not has_crypto and algorithm in requires_cryptography:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Algorithm '%s' could not be found. Do you have cryptography "
|
||||||
|
"installed?" % algorithm
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Algorithm not supported")
|
||||||
|
|
||||||
segments.append(base64url_encode(signature))
|
segments.append(base64url_encode(signature))
|
||||||
|
|
||||||
return b'.'.join(segments)
|
encoded_string = b".".join(segments)
|
||||||
|
|
||||||
def decode(self, jws, key='', verify=True, algorithms=None, options=None,
|
return encoded_string.decode("utf-8")
|
||||||
**kwargs):
|
|
||||||
payload, signing_input, header, signature = self._load(jws)
|
|
||||||
|
|
||||||
if verify:
|
def decode_complete(
|
||||||
merged_options = merge_dict(self.options, options)
|
self,
|
||||||
if merged_options.get('verify_signature'):
|
jwt: str,
|
||||||
self._verify_signature(payload, signing_input, header, signature,
|
key: str = "",
|
||||||
key, algorithms)
|
algorithms: List[str] = None,
|
||||||
else:
|
options: Dict = None,
|
||||||
warnings.warn('The verify parameter is deprecated. '
|
) -> Dict[str, Any]:
|
||||||
'Please use options instead.', DeprecationWarning)
|
if options is None:
|
||||||
|
options = {}
|
||||||
|
merged_options = {**self.options, **options}
|
||||||
|
verify_signature = merged_options["verify_signature"]
|
||||||
|
|
||||||
return payload
|
if verify_signature and not algorithms:
|
||||||
|
raise DecodeError(
|
||||||
|
'It is required that you pass in a value for the "algorithms" argument when calling decode().'
|
||||||
|
)
|
||||||
|
|
||||||
|
payload, signing_input, header, signature = self._load(jwt)
|
||||||
|
|
||||||
|
if verify_signature:
|
||||||
|
self._verify_signature(signing_input, header, signature, key, algorithms)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"payload": payload,
|
||||||
|
"header": header,
|
||||||
|
"signature": signature,
|
||||||
|
}
|
||||||
|
|
||||||
|
def decode(
|
||||||
|
self,
|
||||||
|
jwt: str,
|
||||||
|
key: str = "",
|
||||||
|
algorithms: List[str] = None,
|
||||||
|
options: Dict = None,
|
||||||
|
) -> str:
|
||||||
|
decoded = self.decode_complete(jwt, key, algorithms, options)
|
||||||
|
return decoded["payload"]
|
||||||
|
|
||||||
def get_unverified_header(self, jwt):
|
def get_unverified_header(self, jwt):
|
||||||
"""Returns back the JWT header parameters as a dict()
|
"""Returns back the JWT header parameters as a dict()
|
||||||
|
@ -125,64 +172,85 @@ class PyJWS(object):
|
||||||
Note: The signature is not verified so the header parameters
|
Note: The signature is not verified so the header parameters
|
||||||
should not be fully trusted until signature verification is complete
|
should not be fully trusted until signature verification is complete
|
||||||
"""
|
"""
|
||||||
return self._load(jwt)[2]
|
headers = self._load(jwt)[2]
|
||||||
|
self._validate_headers(headers)
|
||||||
|
|
||||||
|
return headers
|
||||||
|
|
||||||
def _load(self, jwt):
|
def _load(self, jwt):
|
||||||
if isinstance(jwt, text_type):
|
if isinstance(jwt, str):
|
||||||
jwt = jwt.encode('utf-8')
|
jwt = jwt.encode("utf-8")
|
||||||
|
|
||||||
|
if not isinstance(jwt, bytes):
|
||||||
|
raise DecodeError(f"Invalid token type. Token must be a {bytes}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
signing_input, crypto_segment = jwt.rsplit(b'.', 1)
|
signing_input, crypto_segment = jwt.rsplit(b".", 1)
|
||||||
header_segment, payload_segment = signing_input.split(b'.', 1)
|
header_segment, payload_segment = signing_input.split(b".", 1)
|
||||||
except ValueError:
|
except ValueError as err:
|
||||||
raise DecodeError('Not enough segments')
|
raise DecodeError("Not enough segments") from err
|
||||||
|
|
||||||
try:
|
try:
|
||||||
header_data = base64url_decode(header_segment)
|
header_data = base64url_decode(header_segment)
|
||||||
except (TypeError, binascii.Error):
|
except (TypeError, binascii.Error) as err:
|
||||||
raise DecodeError('Invalid header padding')
|
raise DecodeError("Invalid header padding") from err
|
||||||
|
|
||||||
try:
|
try:
|
||||||
header = json.loads(header_data.decode('utf-8'))
|
header = json.loads(header_data)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise DecodeError('Invalid header string: %s' % e)
|
raise DecodeError("Invalid header string: %s" % e) from e
|
||||||
|
|
||||||
if not isinstance(header, Mapping):
|
if not isinstance(header, Mapping):
|
||||||
raise DecodeError('Invalid header string: must be a json object')
|
raise DecodeError("Invalid header string: must be a json object")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = base64url_decode(payload_segment)
|
payload = base64url_decode(payload_segment)
|
||||||
except (TypeError, binascii.Error):
|
except (TypeError, binascii.Error) as err:
|
||||||
raise DecodeError('Invalid payload padding')
|
raise DecodeError("Invalid payload padding") from err
|
||||||
|
|
||||||
try:
|
try:
|
||||||
signature = base64url_decode(crypto_segment)
|
signature = base64url_decode(crypto_segment)
|
||||||
except (TypeError, binascii.Error):
|
except (TypeError, binascii.Error) as err:
|
||||||
raise DecodeError('Invalid crypto padding')
|
raise DecodeError("Invalid crypto padding") from err
|
||||||
|
|
||||||
return (payload, signing_input, header, signature)
|
return (payload, signing_input, header, signature)
|
||||||
|
|
||||||
def _verify_signature(self, payload, signing_input, header, signature,
|
def _verify_signature(
|
||||||
key='', algorithms=None):
|
self,
|
||||||
|
signing_input,
|
||||||
|
header,
|
||||||
|
signature,
|
||||||
|
key="",
|
||||||
|
algorithms=None,
|
||||||
|
):
|
||||||
|
|
||||||
alg = header.get('alg')
|
alg = header.get("alg")
|
||||||
|
|
||||||
if algorithms is not None and alg not in algorithms:
|
if algorithms is not None and alg not in algorithms:
|
||||||
raise InvalidAlgorithmError('The specified alg value is not allowed')
|
raise InvalidAlgorithmError("The specified alg value is not allowed")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
alg_obj = self._algorithms[alg]
|
alg_obj = self._algorithms[alg]
|
||||||
key = alg_obj.prepare_key(key)
|
key = alg_obj.prepare_key(key)
|
||||||
|
|
||||||
if not alg_obj.verify(signing_input, key, signature):
|
if not alg_obj.verify(signing_input, key, signature):
|
||||||
raise DecodeError('Signature verification failed')
|
raise InvalidSignatureError("Signature verification failed")
|
||||||
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise InvalidAlgorithmError('Algorithm not supported')
|
raise InvalidAlgorithmError("Algorithm not supported")
|
||||||
|
|
||||||
|
def _validate_headers(self, headers):
|
||||||
|
if "kid" in headers:
|
||||||
|
self._validate_kid(headers["kid"])
|
||||||
|
|
||||||
|
def _validate_kid(self, kid):
|
||||||
|
if not isinstance(kid, str):
|
||||||
|
raise InvalidTokenError("Key ID header parameter must be a string")
|
||||||
|
|
||||||
|
|
||||||
_jws_global_obj = PyJWS()
|
_jws_global_obj = PyJWS()
|
||||||
encode = _jws_global_obj.encode
|
encode = _jws_global_obj.encode
|
||||||
|
decode_complete = _jws_global_obj.decode_complete
|
||||||
decode = _jws_global_obj.decode
|
decode = _jws_global_obj.decode
|
||||||
register_algorithm = _jws_global_obj.register_algorithm
|
register_algorithm = _jws_global_obj.register_algorithm
|
||||||
unregister_algorithm = _jws_global_obj.unregister_algorithm
|
unregister_algorithm = _jws_global_obj.unregister_algorithm
|
||||||
|
|
|
@ -1,187 +1,224 @@
|
||||||
import json
|
import json
|
||||||
import warnings
|
|
||||||
|
|
||||||
from calendar import timegm
|
from calendar import timegm
|
||||||
from collections import Mapping
|
from collections.abc import Iterable, Mapping
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any, Dict, List, Optional, Type, Union
|
||||||
|
|
||||||
from .api_jws import PyJWS
|
from . import api_jws
|
||||||
from .algorithms import Algorithm, get_default_algorithms # NOQA
|
|
||||||
from .compat import string_types, timedelta_total_seconds
|
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
DecodeError, ExpiredSignatureError, ImmatureSignatureError,
|
DecodeError,
|
||||||
InvalidAudienceError, InvalidIssuedAtError,
|
ExpiredSignatureError,
|
||||||
InvalidIssuerError, MissingRequiredClaimError
|
ImmatureSignatureError,
|
||||||
|
InvalidAudienceError,
|
||||||
|
InvalidIssuedAtError,
|
||||||
|
InvalidIssuerError,
|
||||||
|
MissingRequiredClaimError,
|
||||||
)
|
)
|
||||||
from .utils import merge_dict
|
|
||||||
|
|
||||||
|
|
||||||
class PyJWT(PyJWS):
|
class PyJWT:
|
||||||
header_type = 'JWT'
|
def __init__(self, options=None):
|
||||||
|
if options is None:
|
||||||
|
options = {}
|
||||||
|
self.options = {**self._get_default_options(), **options}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_default_options():
|
def _get_default_options() -> Dict[str, Union[bool, List[str]]]:
|
||||||
return {
|
return {
|
||||||
'verify_signature': True,
|
"verify_signature": True,
|
||||||
'verify_exp': True,
|
"verify_exp": True,
|
||||||
'verify_nbf': True,
|
"verify_nbf": True,
|
||||||
'verify_iat': True,
|
"verify_iat": True,
|
||||||
'verify_aud': True,
|
"verify_aud": True,
|
||||||
'verify_iss': True,
|
"verify_iss": True,
|
||||||
'require_exp': False,
|
"require": [],
|
||||||
'require_iat': False,
|
|
||||||
'require_nbf': False
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def encode(self, payload, key, algorithm='HS256', headers=None,
|
def encode(
|
||||||
json_encoder=None):
|
self,
|
||||||
|
payload: Dict[str, Any],
|
||||||
|
key: str,
|
||||||
|
algorithm: Optional[str] = "HS256",
|
||||||
|
headers: Optional[Dict] = None,
|
||||||
|
json_encoder: Optional[Type[json.JSONEncoder]] = None,
|
||||||
|
) -> str:
|
||||||
# Check that we get a mapping
|
# Check that we get a mapping
|
||||||
if not isinstance(payload, Mapping):
|
if not isinstance(payload, Mapping):
|
||||||
raise TypeError('Expecting a mapping object, as JWT only supports '
|
raise TypeError(
|
||||||
'JSON objects as payloads.')
|
"Expecting a mapping object, as JWT only supports "
|
||||||
|
"JSON objects as payloads."
|
||||||
|
)
|
||||||
|
|
||||||
# Payload
|
# Payload
|
||||||
for time_claim in ['exp', 'iat', 'nbf']:
|
payload = payload.copy()
|
||||||
|
for time_claim in ["exp", "iat", "nbf"]:
|
||||||
# Convert datetime to a intDate value in known time-format claims
|
# Convert datetime to a intDate value in known time-format claims
|
||||||
if isinstance(payload.get(time_claim), datetime):
|
if isinstance(payload.get(time_claim), datetime):
|
||||||
payload[time_claim] = timegm(payload[time_claim].utctimetuple())
|
payload[time_claim] = timegm(payload[time_claim].utctimetuple())
|
||||||
|
|
||||||
json_payload = json.dumps(
|
json_payload = json.dumps(
|
||||||
payload,
|
payload, separators=(",", ":"), cls=json_encoder
|
||||||
separators=(',', ':'),
|
).encode("utf-8")
|
||||||
cls=json_encoder
|
|
||||||
).encode('utf-8')
|
|
||||||
|
|
||||||
return super(PyJWT, self).encode(
|
return api_jws.encode(json_payload, key, algorithm, headers, json_encoder)
|
||||||
json_payload, key, algorithm, headers, json_encoder
|
|
||||||
|
def decode_complete(
|
||||||
|
self,
|
||||||
|
jwt: str,
|
||||||
|
key: str = "",
|
||||||
|
algorithms: List[str] = None,
|
||||||
|
options: Dict = None,
|
||||||
|
audience: Optional[Union[str, List[str]]] = None,
|
||||||
|
issuer: Optional[str] = None,
|
||||||
|
leeway: Union[float, timedelta] = 0,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
if options is None:
|
||||||
|
options = {"verify_signature": True}
|
||||||
|
else:
|
||||||
|
options.setdefault("verify_signature", True)
|
||||||
|
|
||||||
|
if not options["verify_signature"]:
|
||||||
|
options.setdefault("verify_exp", False)
|
||||||
|
options.setdefault("verify_nbf", False)
|
||||||
|
options.setdefault("verify_iat", False)
|
||||||
|
options.setdefault("verify_aud", False)
|
||||||
|
options.setdefault("verify_iss", False)
|
||||||
|
|
||||||
|
if options["verify_signature"] and not algorithms:
|
||||||
|
raise DecodeError(
|
||||||
|
'It is required that you pass in a value for the "algorithms" argument when calling decode().'
|
||||||
|
)
|
||||||
|
|
||||||
|
decoded = api_jws.decode_complete(
|
||||||
|
jwt,
|
||||||
|
key=key,
|
||||||
|
algorithms=algorithms,
|
||||||
|
options=options,
|
||||||
)
|
)
|
||||||
|
|
||||||
def decode(self, jwt, key='', verify=True, algorithms=None, options=None,
|
|
||||||
**kwargs):
|
|
||||||
payload, signing_input, header, signature = self._load(jwt)
|
|
||||||
|
|
||||||
decoded = super(PyJWT, self).decode(jwt, key, verify, algorithms,
|
|
||||||
options, **kwargs)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = json.loads(decoded.decode('utf-8'))
|
payload = json.loads(decoded["payload"])
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise DecodeError('Invalid payload string: %s' % e)
|
raise DecodeError("Invalid payload string: %s" % e)
|
||||||
if not isinstance(payload, Mapping):
|
if not isinstance(payload, dict):
|
||||||
raise DecodeError('Invalid payload string: must be a json object')
|
raise DecodeError("Invalid payload string: must be a json object")
|
||||||
|
|
||||||
if verify:
|
merged_options = {**self.options, **options}
|
||||||
merged_options = merge_dict(self.options, options)
|
self._validate_claims(payload, merged_options, audience, issuer, leeway)
|
||||||
self._validate_claims(payload, merged_options, **kwargs)
|
|
||||||
|
|
||||||
return payload
|
decoded["payload"] = payload
|
||||||
|
return decoded
|
||||||
|
|
||||||
def _validate_claims(self, payload, options, audience=None, issuer=None,
|
def decode(
|
||||||
leeway=0, **kwargs):
|
self,
|
||||||
|
jwt: str,
|
||||||
if 'verify_expiration' in kwargs:
|
key: str = "",
|
||||||
options['verify_exp'] = kwargs.get('verify_expiration', True)
|
algorithms: List[str] = None,
|
||||||
warnings.warn('The verify_expiration parameter is deprecated. '
|
options: Dict = None,
|
||||||
'Please use options instead.', DeprecationWarning)
|
audience: Optional[Union[str, List[str]]] = None,
|
||||||
|
issuer: Optional[str] = None,
|
||||||
|
leeway: Union[float, timedelta] = 0,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
decoded = self.decode_complete(
|
||||||
|
jwt, key, algorithms, options, audience, issuer, leeway
|
||||||
|
)
|
||||||
|
return decoded["payload"]
|
||||||
|
|
||||||
|
def _validate_claims(self, payload, options, audience, issuer, leeway):
|
||||||
if isinstance(leeway, timedelta):
|
if isinstance(leeway, timedelta):
|
||||||
leeway = timedelta_total_seconds(leeway)
|
leeway = leeway.total_seconds()
|
||||||
|
|
||||||
if not isinstance(audience, (string_types, type(None))):
|
if not isinstance(audience, (str, type(None), Iterable)):
|
||||||
raise TypeError('audience must be a string or None')
|
raise TypeError("audience must be a string, iterable, or None")
|
||||||
|
|
||||||
self._validate_required_claims(payload, options)
|
self._validate_required_claims(payload, options)
|
||||||
|
|
||||||
now = timegm(datetime.utcnow().utctimetuple())
|
now = timegm(datetime.now(tz=timezone.utc).utctimetuple())
|
||||||
|
|
||||||
if 'iat' in payload and options.get('verify_iat'):
|
if "iat" in payload and options["verify_iat"]:
|
||||||
self._validate_iat(payload, now, leeway)
|
self._validate_iat(payload, now, leeway)
|
||||||
|
|
||||||
if 'nbf' in payload and options.get('verify_nbf'):
|
if "nbf" in payload and options["verify_nbf"]:
|
||||||
self._validate_nbf(payload, now, leeway)
|
self._validate_nbf(payload, now, leeway)
|
||||||
|
|
||||||
if 'exp' in payload and options.get('verify_exp'):
|
if "exp" in payload and options["verify_exp"]:
|
||||||
self._validate_exp(payload, now, leeway)
|
self._validate_exp(payload, now, leeway)
|
||||||
|
|
||||||
if options.get('verify_iss'):
|
if options["verify_iss"]:
|
||||||
self._validate_iss(payload, issuer)
|
self._validate_iss(payload, issuer)
|
||||||
|
|
||||||
if options.get('verify_aud'):
|
if options["verify_aud"]:
|
||||||
self._validate_aud(payload, audience)
|
self._validate_aud(payload, audience)
|
||||||
|
|
||||||
def _validate_required_claims(self, payload, options):
|
def _validate_required_claims(self, payload, options):
|
||||||
if options.get('require_exp') and payload.get('exp') is None:
|
for claim in options["require"]:
|
||||||
raise MissingRequiredClaimError('exp')
|
if payload.get(claim) is None:
|
||||||
|
raise MissingRequiredClaimError(claim)
|
||||||
if options.get('require_iat') and payload.get('iat') is None:
|
|
||||||
raise MissingRequiredClaimError('iat')
|
|
||||||
|
|
||||||
if options.get('require_nbf') and payload.get('nbf') is None:
|
|
||||||
raise MissingRequiredClaimError('nbf')
|
|
||||||
|
|
||||||
def _validate_iat(self, payload, now, leeway):
|
def _validate_iat(self, payload, now, leeway):
|
||||||
try:
|
try:
|
||||||
iat = int(payload['iat'])
|
int(payload["iat"])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise DecodeError('Issued At claim (iat) must be an integer.')
|
raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.")
|
||||||
|
|
||||||
if iat > (now + leeway):
|
|
||||||
raise InvalidIssuedAtError('Issued At claim (iat) cannot be in'
|
|
||||||
' the future.')
|
|
||||||
|
|
||||||
def _validate_nbf(self, payload, now, leeway):
|
def _validate_nbf(self, payload, now, leeway):
|
||||||
try:
|
try:
|
||||||
nbf = int(payload['nbf'])
|
nbf = int(payload["nbf"])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise DecodeError('Not Before claim (nbf) must be an integer.')
|
raise DecodeError("Not Before claim (nbf) must be an integer.")
|
||||||
|
|
||||||
if nbf > (now + leeway):
|
if nbf > (now + leeway):
|
||||||
raise ImmatureSignatureError('The token is not yet valid (nbf)')
|
raise ImmatureSignatureError("The token is not yet valid (nbf)")
|
||||||
|
|
||||||
def _validate_exp(self, payload, now, leeway):
|
def _validate_exp(self, payload, now, leeway):
|
||||||
try:
|
try:
|
||||||
exp = int(payload['exp'])
|
exp = int(payload["exp"])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise DecodeError('Expiration Time claim (exp) must be an'
|
raise DecodeError("Expiration Time claim (exp) must be an" " integer.")
|
||||||
' integer.')
|
|
||||||
|
|
||||||
if exp < (now - leeway):
|
if exp < (now - leeway):
|
||||||
raise ExpiredSignatureError('Signature has expired')
|
raise ExpiredSignatureError("Signature has expired")
|
||||||
|
|
||||||
def _validate_aud(self, payload, audience):
|
def _validate_aud(self, payload, audience):
|
||||||
if audience is None and 'aud' not in payload:
|
if audience is None:
|
||||||
return
|
if "aud" not in payload or not payload["aud"]:
|
||||||
|
return
|
||||||
|
# Application did not specify an audience, but
|
||||||
|
# the token has the 'aud' claim
|
||||||
|
raise InvalidAudienceError("Invalid audience")
|
||||||
|
|
||||||
if audience is not None and 'aud' not in payload:
|
if "aud" not in payload or not payload["aud"]:
|
||||||
# Application specified an audience, but it could not be
|
# Application specified an audience, but it could not be
|
||||||
# verified since the token does not contain a claim.
|
# verified since the token does not contain a claim.
|
||||||
raise MissingRequiredClaimError('aud')
|
raise MissingRequiredClaimError("aud")
|
||||||
|
|
||||||
audience_claims = payload['aud']
|
audience_claims = payload["aud"]
|
||||||
|
|
||||||
if isinstance(audience_claims, string_types):
|
if isinstance(audience_claims, str):
|
||||||
audience_claims = [audience_claims]
|
audience_claims = [audience_claims]
|
||||||
if not isinstance(audience_claims, list):
|
if not isinstance(audience_claims, list):
|
||||||
raise InvalidAudienceError('Invalid claim format in token')
|
raise InvalidAudienceError("Invalid claim format in token")
|
||||||
if any(not isinstance(c, string_types) for c in audience_claims):
|
if any(not isinstance(c, str) for c in audience_claims):
|
||||||
raise InvalidAudienceError('Invalid claim format in token')
|
raise InvalidAudienceError("Invalid claim format in token")
|
||||||
if audience not in audience_claims:
|
|
||||||
raise InvalidAudienceError('Invalid audience')
|
if isinstance(audience, str):
|
||||||
|
audience = [audience]
|
||||||
|
|
||||||
|
if all(aud not in audience_claims for aud in audience):
|
||||||
|
raise InvalidAudienceError("Invalid audience")
|
||||||
|
|
||||||
def _validate_iss(self, payload, issuer):
|
def _validate_iss(self, payload, issuer):
|
||||||
if issuer is None:
|
if issuer is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if 'iss' not in payload:
|
if "iss" not in payload:
|
||||||
raise MissingRequiredClaimError('iss')
|
raise MissingRequiredClaimError("iss")
|
||||||
|
|
||||||
if payload['iss'] != issuer:
|
if payload["iss"] != issuer:
|
||||||
raise InvalidIssuerError('Invalid issuer')
|
raise InvalidIssuerError("Invalid issuer")
|
||||||
|
|
||||||
|
|
||||||
_jwt_global_obj = PyJWT()
|
_jwt_global_obj = PyJWT()
|
||||||
encode = _jwt_global_obj.encode
|
encode = _jwt_global_obj.encode
|
||||||
|
decode_complete = _jwt_global_obj.decode_complete
|
||||||
decode = _jwt_global_obj.decode
|
decode = _jwt_global_obj.decode
|
||||||
register_algorithm = _jwt_global_obj.register_algorithm
|
|
||||||
unregister_algorithm = _jwt_global_obj.unregister_algorithm
|
|
||||||
get_unverified_header = _jwt_global_obj.get_unverified_header
|
|
||||||
|
|
|
@ -1,52 +0,0 @@
|
||||||
"""
|
|
||||||
The `compat` module provides support for backwards compatibility with older
|
|
||||||
versions of python, and compatibility wrappers around optional packages.
|
|
||||||
"""
|
|
||||||
# flake8: noqa
|
|
||||||
import sys
|
|
||||||
import hmac
|
|
||||||
|
|
||||||
|
|
||||||
PY3 = sys.version_info[0] == 3
|
|
||||||
|
|
||||||
|
|
||||||
if PY3:
|
|
||||||
string_types = str,
|
|
||||||
text_type = str
|
|
||||||
else:
|
|
||||||
string_types = basestring,
|
|
||||||
text_type = unicode
|
|
||||||
|
|
||||||
|
|
||||||
def timedelta_total_seconds(delta):
|
|
||||||
try:
|
|
||||||
delta.total_seconds
|
|
||||||
except AttributeError:
|
|
||||||
# On Python 2.6, timedelta instances do not have
|
|
||||||
# a .total_seconds() method.
|
|
||||||
total_seconds = delta.days * 24 * 60 * 60 + delta.seconds
|
|
||||||
else:
|
|
||||||
total_seconds = delta.total_seconds()
|
|
||||||
|
|
||||||
return total_seconds
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
constant_time_compare = hmac.compare_digest
|
|
||||||
except AttributeError:
|
|
||||||
# Fallback for Python < 2.7
|
|
||||||
def constant_time_compare(val1, val2):
|
|
||||||
"""
|
|
||||||
Returns True if the two strings are equal, False otherwise.
|
|
||||||
|
|
||||||
The time taken is independent of the number of characters that match.
|
|
||||||
"""
|
|
||||||
if len(val1) != len(val2):
|
|
||||||
return False
|
|
||||||
|
|
||||||
result = 0
|
|
||||||
|
|
||||||
for x, y in zip(val1, val2):
|
|
||||||
result |= ord(x) ^ ord(y)
|
|
||||||
|
|
||||||
return result == 0
|
|
|
@ -1,60 +0,0 @@
|
||||||
# Note: This file is named py_ecdsa.py because import behavior in Python 2
|
|
||||||
# would cause ecdsa.py to squash the ecdsa library that it depends upon.
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
|
|
||||||
import ecdsa
|
|
||||||
|
|
||||||
from jwt.algorithms import Algorithm
|
|
||||||
from jwt.compat import string_types, text_type
|
|
||||||
|
|
||||||
|
|
||||||
class ECAlgorithm(Algorithm):
|
|
||||||
"""
|
|
||||||
Performs signing and verification operations using
|
|
||||||
ECDSA and the specified hash function
|
|
||||||
|
|
||||||
This class requires the ecdsa package to be installed.
|
|
||||||
|
|
||||||
This is based off of the implementation in PyJWT 0.3.2
|
|
||||||
"""
|
|
||||||
SHA256 = hashlib.sha256
|
|
||||||
SHA384 = hashlib.sha384
|
|
||||||
SHA512 = hashlib.sha512
|
|
||||||
|
|
||||||
def __init__(self, hash_alg):
|
|
||||||
self.hash_alg = hash_alg
|
|
||||||
|
|
||||||
def prepare_key(self, key):
|
|
||||||
|
|
||||||
if isinstance(key, ecdsa.SigningKey) or \
|
|
||||||
isinstance(key, ecdsa.VerifyingKey):
|
|
||||||
return key
|
|
||||||
|
|
||||||
if isinstance(key, string_types):
|
|
||||||
if isinstance(key, text_type):
|
|
||||||
key = key.encode('utf-8')
|
|
||||||
|
|
||||||
# Attempt to load key. We don't know if it's
|
|
||||||
# a Signing Key or a Verifying Key, so we try
|
|
||||||
# the Verifying Key first.
|
|
||||||
try:
|
|
||||||
key = ecdsa.VerifyingKey.from_pem(key)
|
|
||||||
except ecdsa.der.UnexpectedDER:
|
|
||||||
key = ecdsa.SigningKey.from_pem(key)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise TypeError('Expecting a PEM-formatted key.')
|
|
||||||
|
|
||||||
return key
|
|
||||||
|
|
||||||
def sign(self, msg, key):
|
|
||||||
return key.sign(msg, hashfunc=self.hash_alg,
|
|
||||||
sigencode=ecdsa.util.sigencode_string)
|
|
||||||
|
|
||||||
def verify(self, msg, key, sig):
|
|
||||||
try:
|
|
||||||
return key.verify(sig, msg, hashfunc=self.hash_alg,
|
|
||||||
sigdecode=ecdsa.util.sigdecode_string)
|
|
||||||
except AssertionError:
|
|
||||||
return False
|
|
|
@ -1,47 +0,0 @@
|
||||||
import Crypto.Hash.SHA256
|
|
||||||
import Crypto.Hash.SHA384
|
|
||||||
import Crypto.Hash.SHA512
|
|
||||||
|
|
||||||
from Crypto.PublicKey import RSA
|
|
||||||
from Crypto.Signature import PKCS1_v1_5
|
|
||||||
|
|
||||||
from jwt.algorithms import Algorithm
|
|
||||||
from jwt.compat import string_types, text_type
|
|
||||||
|
|
||||||
|
|
||||||
class RSAAlgorithm(Algorithm):
|
|
||||||
"""
|
|
||||||
Performs signing and verification operations using
|
|
||||||
RSASSA-PKCS-v1_5 and the specified hash function.
|
|
||||||
|
|
||||||
This class requires PyCrypto package to be installed.
|
|
||||||
|
|
||||||
This is based off of the implementation in PyJWT 0.3.2
|
|
||||||
"""
|
|
||||||
SHA256 = Crypto.Hash.SHA256
|
|
||||||
SHA384 = Crypto.Hash.SHA384
|
|
||||||
SHA512 = Crypto.Hash.SHA512
|
|
||||||
|
|
||||||
def __init__(self, hash_alg):
|
|
||||||
self.hash_alg = hash_alg
|
|
||||||
|
|
||||||
def prepare_key(self, key):
|
|
||||||
|
|
||||||
if isinstance(key, RSA._RSAobj):
|
|
||||||
return key
|
|
||||||
|
|
||||||
if isinstance(key, string_types):
|
|
||||||
if isinstance(key, text_type):
|
|
||||||
key = key.encode('utf-8')
|
|
||||||
|
|
||||||
key = RSA.importKey(key)
|
|
||||||
else:
|
|
||||||
raise TypeError('Expecting a PEM- or RSA-formatted key.')
|
|
||||||
|
|
||||||
return key
|
|
||||||
|
|
||||||
def sign(self, msg, key):
|
|
||||||
return PKCS1_v1_5.new(key).sign(self.hash_alg.new(msg))
|
|
||||||
|
|
||||||
def verify(self, msg, key, sig):
|
|
||||||
return PKCS1_v1_5.new(key).verify(self.hash_alg.new(msg), sig)
|
|
|
@ -1,4 +1,12 @@
|
||||||
class InvalidTokenError(Exception):
|
class PyJWTError(Exception):
|
||||||
|
"""
|
||||||
|
Base class for all exceptions
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidTokenError(PyJWTError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,6 +14,10 @@ class DecodeError(InvalidTokenError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidSignatureError(DecodeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ExpiredSignatureError(InvalidTokenError):
|
class ExpiredSignatureError(InvalidTokenError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -26,7 +38,7 @@ class ImmatureSignatureError(InvalidTokenError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class InvalidKeyError(Exception):
|
class InvalidKeyError(PyJWTError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,7 +54,13 @@ class MissingRequiredClaimError(InvalidTokenError):
|
||||||
return 'Token is missing the "%s" claim' % self.claim
|
return 'Token is missing the "%s" claim' % self.claim
|
||||||
|
|
||||||
|
|
||||||
# Compatibility aliases (deprecated)
|
class PyJWKError(PyJWTError):
|
||||||
ExpiredSignature = ExpiredSignatureError
|
pass
|
||||||
InvalidAudience = InvalidAudienceError
|
|
||||||
InvalidIssuer = InvalidIssuerError
|
|
||||||
|
class PyJWKSetError(PyJWTError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PyJWKClientError(PyJWTError):
|
||||||
|
pass
|
||||||
|
|
60
lib/jwt/help.py
Normal file
60
lib/jwt/help.py
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
import json
|
||||||
|
import platform
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from . import __version__ as pyjwt_version
|
||||||
|
|
||||||
|
try:
|
||||||
|
import cryptography
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
cryptography = None # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def info():
|
||||||
|
"""
|
||||||
|
Generate information for a bug report.
|
||||||
|
Based on the requests package help utility module.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
platform_info = {
|
||||||
|
"system": platform.system(),
|
||||||
|
"release": platform.release(),
|
||||||
|
}
|
||||||
|
except OSError:
|
||||||
|
platform_info = {"system": "Unknown", "release": "Unknown"}
|
||||||
|
|
||||||
|
implementation = platform.python_implementation()
|
||||||
|
|
||||||
|
if implementation == "CPython":
|
||||||
|
implementation_version = platform.python_version()
|
||||||
|
elif implementation == "PyPy":
|
||||||
|
implementation_version = "{}.{}.{}".format(
|
||||||
|
sys.pypy_version_info.major,
|
||||||
|
sys.pypy_version_info.minor,
|
||||||
|
sys.pypy_version_info.micro,
|
||||||
|
)
|
||||||
|
if sys.pypy_version_info.releaselevel != "final":
|
||||||
|
implementation_version = "".join(
|
||||||
|
[implementation_version, sys.pypy_version_info.releaselevel]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
implementation_version = "Unknown"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"platform": platform_info,
|
||||||
|
"implementation": {
|
||||||
|
"name": implementation,
|
||||||
|
"version": implementation_version,
|
||||||
|
},
|
||||||
|
"cryptography": {"version": getattr(cryptography, "__version__", "")},
|
||||||
|
"pyjwt": {"version": pyjwt_version},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Pretty-print the bug information as JSON."""
|
||||||
|
print(json.dumps(info(), sort_keys=True, indent=2))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
59
lib/jwt/jwks_client.py
Normal file
59
lib/jwt/jwks_client.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
import json
|
||||||
|
import urllib.request
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
from .api_jwk import PyJWK, PyJWKSet
|
||||||
|
from .api_jwt import decode_complete as decode_token
|
||||||
|
from .exceptions import PyJWKClientError
|
||||||
|
|
||||||
|
|
||||||
|
class PyJWKClient:
|
||||||
|
def __init__(self, uri: str, cache_keys: bool = True, max_cached_keys: int = 16):
|
||||||
|
self.uri = uri
|
||||||
|
if cache_keys:
|
||||||
|
# Cache signing keys
|
||||||
|
# Ignore mypy (https://github.com/python/mypy/issues/2427)
|
||||||
|
self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore
|
||||||
|
|
||||||
|
def fetch_data(self) -> Any:
|
||||||
|
with urllib.request.urlopen(self.uri) as response:
|
||||||
|
return json.load(response)
|
||||||
|
|
||||||
|
def get_jwk_set(self) -> PyJWKSet:
|
||||||
|
data = self.fetch_data()
|
||||||
|
return PyJWKSet.from_dict(data)
|
||||||
|
|
||||||
|
def get_signing_keys(self) -> List[PyJWK]:
|
||||||
|
jwk_set = self.get_jwk_set()
|
||||||
|
signing_keys = [
|
||||||
|
jwk_set_key
|
||||||
|
for jwk_set_key in jwk_set.keys
|
||||||
|
if jwk_set_key.public_key_use in ["sig", None] and jwk_set_key.key_id
|
||||||
|
]
|
||||||
|
|
||||||
|
if not signing_keys:
|
||||||
|
raise PyJWKClientError("The JWKS endpoint did not contain any signing keys")
|
||||||
|
|
||||||
|
return signing_keys
|
||||||
|
|
||||||
|
def get_signing_key(self, kid: str) -> PyJWK:
|
||||||
|
signing_keys = self.get_signing_keys()
|
||||||
|
signing_key = None
|
||||||
|
|
||||||
|
for key in signing_keys:
|
||||||
|
if key.key_id == kid:
|
||||||
|
signing_key = key
|
||||||
|
break
|
||||||
|
|
||||||
|
if not signing_key:
|
||||||
|
raise PyJWKClientError(
|
||||||
|
f'Unable to find a signing key that matches: "{kid}"'
|
||||||
|
)
|
||||||
|
|
||||||
|
return signing_key
|
||||||
|
|
||||||
|
def get_signing_key_from_jwt(self, token: str) -> PyJWK:
|
||||||
|
unverified = decode_token(token, options={"verify_signature": False})
|
||||||
|
header = unverified["header"]
|
||||||
|
return self.get_signing_key(header.get("kid"))
|
|
@ -1,67 +1,99 @@
|
||||||
import base64
|
import base64
|
||||||
import binascii
|
import binascii
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
|
||||||
from cryptography.hazmat.primitives.asymmetric.utils import (
|
from cryptography.hazmat.primitives.asymmetric.utils import (
|
||||||
decode_rfc6979_signature, encode_rfc6979_signature
|
decode_dss_signature,
|
||||||
|
encode_dss_signature,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ModuleNotFoundError:
|
||||||
pass
|
EllipticCurve = Any # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def base64url_decode(input):
|
def force_bytes(value: Union[str, bytes]) -> bytes:
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value.encode("utf-8")
|
||||||
|
elif isinstance(value, bytes):
|
||||||
|
return value
|
||||||
|
else:
|
||||||
|
raise TypeError("Expected a string value")
|
||||||
|
|
||||||
|
|
||||||
|
def base64url_decode(input: Union[str, bytes]) -> bytes:
|
||||||
|
if isinstance(input, str):
|
||||||
|
input = input.encode("ascii")
|
||||||
|
|
||||||
rem = len(input) % 4
|
rem = len(input) % 4
|
||||||
|
|
||||||
if rem > 0:
|
if rem > 0:
|
||||||
input += b'=' * (4 - rem)
|
input += b"=" * (4 - rem)
|
||||||
|
|
||||||
return base64.urlsafe_b64decode(input)
|
return base64.urlsafe_b64decode(input)
|
||||||
|
|
||||||
|
|
||||||
def base64url_encode(input):
|
def base64url_encode(input: bytes) -> bytes:
|
||||||
return base64.urlsafe_b64encode(input).replace(b'=', b'')
|
return base64.urlsafe_b64encode(input).replace(b"=", b"")
|
||||||
|
|
||||||
|
|
||||||
def merge_dict(original, updates):
|
def to_base64url_uint(val: int) -> bytes:
|
||||||
if not updates:
|
if val < 0:
|
||||||
return original
|
raise ValueError("Must be a positive integer")
|
||||||
|
|
||||||
try:
|
int_bytes = bytes_from_int(val)
|
||||||
merged_options = original.copy()
|
|
||||||
merged_options.update(updates)
|
|
||||||
except (AttributeError, ValueError) as e:
|
|
||||||
raise TypeError('original and updates must be a dictionary: %s' % e)
|
|
||||||
|
|
||||||
return merged_options
|
if len(int_bytes) == 0:
|
||||||
|
int_bytes = b"\x00"
|
||||||
|
|
||||||
|
return base64url_encode(int_bytes)
|
||||||
|
|
||||||
|
|
||||||
def number_to_bytes(num, num_bytes):
|
def from_base64url_uint(val: Union[str, bytes]) -> int:
|
||||||
padded_hex = '%0*x' % (2 * num_bytes, num)
|
if isinstance(val, str):
|
||||||
big_endian = binascii.a2b_hex(padded_hex.encode('ascii'))
|
val = val.encode("ascii")
|
||||||
return big_endian
|
|
||||||
|
data = base64url_decode(val)
|
||||||
|
return int.from_bytes(data, byteorder="big")
|
||||||
|
|
||||||
|
|
||||||
def bytes_to_number(string):
|
def number_to_bytes(num: int, num_bytes: int) -> bytes:
|
||||||
|
padded_hex = "%0*x" % (2 * num_bytes, num)
|
||||||
|
return binascii.a2b_hex(padded_hex.encode("ascii"))
|
||||||
|
|
||||||
|
|
||||||
|
def bytes_to_number(string: bytes) -> int:
|
||||||
return int(binascii.b2a_hex(string), 16)
|
return int(binascii.b2a_hex(string), 16)
|
||||||
|
|
||||||
|
|
||||||
def der_to_raw_signature(der_sig, curve):
|
def bytes_from_int(val: int) -> bytes:
|
||||||
|
remaining = val
|
||||||
|
byte_length = 0
|
||||||
|
|
||||||
|
while remaining != 0:
|
||||||
|
remaining >>= 8
|
||||||
|
byte_length += 1
|
||||||
|
|
||||||
|
return val.to_bytes(byte_length, "big", signed=False)
|
||||||
|
|
||||||
|
|
||||||
|
def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes:
|
||||||
num_bits = curve.key_size
|
num_bits = curve.key_size
|
||||||
num_bytes = (num_bits + 7) // 8
|
num_bytes = (num_bits + 7) // 8
|
||||||
|
|
||||||
r, s = decode_rfc6979_signature(der_sig)
|
r, s = decode_dss_signature(der_sig)
|
||||||
|
|
||||||
return number_to_bytes(r, num_bytes) + number_to_bytes(s, num_bytes)
|
return number_to_bytes(r, num_bytes) + number_to_bytes(s, num_bytes)
|
||||||
|
|
||||||
|
|
||||||
def raw_to_der_signature(raw_sig, curve):
|
def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes:
|
||||||
num_bits = curve.key_size
|
num_bits = curve.key_size
|
||||||
num_bytes = (num_bits + 7) // 8
|
num_bytes = (num_bits + 7) // 8
|
||||||
|
|
||||||
if len(raw_sig) != 2 * num_bytes:
|
if len(raw_sig) != 2 * num_bytes:
|
||||||
raise ValueError('Invalid signature')
|
raise ValueError("Invalid signature")
|
||||||
|
|
||||||
r = bytes_to_number(raw_sig[:num_bytes])
|
r = bytes_to_number(raw_sig[:num_bytes])
|
||||||
s = bytes_to_number(raw_sig[num_bytes:])
|
s = bytes_to_number(raw_sig[num_bytes:])
|
||||||
|
|
||||||
return encode_rfc6979_signature(r, s)
|
return encode_dss_signature(r, s)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue