diff --git a/lib/jwt/__init__.py b/lib/jwt/__init__.py index 2840c040..0631c99c 100644 --- a/lib/jwt/__init__.py +++ b/lib/jwt/__init__.py @@ -1,29 +1,70 @@ -# -*- coding: utf-8 -*- -# flake8: noqa - -""" -JSON Web Token implementation - -Minimum implementation based on this spec: -http://self-issued.info/docs/draft-jones-json-web-token-01.html -""" - - -__title__ = 'pyjwt' -__version__ = '1.4.0' -__author__ = 'José Padilla' -__license__ = 'MIT' -__copyright__ = 'Copyright 2015 José Padilla' - - -from .api_jwt import ( - encode, decode, register_algorithm, unregister_algorithm, - get_unverified_header, PyJWT +from .api_jwk import PyJWK, PyJWKSet +from .api_jws import ( + PyJWS, + get_unverified_header, + register_algorithm, + unregister_algorithm, ) -from .api_jws import PyJWS +from .api_jwt import PyJWT, decode, encode from .exceptions import ( - InvalidTokenError, DecodeError, InvalidAudienceError, - ExpiredSignatureError, ImmatureSignatureError, InvalidIssuedAtError, - InvalidIssuerError, ExpiredSignature, InvalidAudience, InvalidIssuer, - MissingRequiredClaimError + DecodeError, + ExpiredSignatureError, + ImmatureSignatureError, + InvalidAlgorithmError, + InvalidAudienceError, + InvalidIssuedAtError, + InvalidIssuerError, + InvalidKeyError, + InvalidSignatureError, + InvalidTokenError, + MissingRequiredClaimError, + PyJWKClientError, + PyJWKError, + PyJWKSetError, + PyJWTError, ) +from .jwks_client import PyJWKClient + +__version__ = "2.2.0" + +__title__ = "PyJWT" +__description__ = "JSON Web Token implementation in Python" +__url__ = "https://pyjwt.readthedocs.io" +__uri__ = __url__ +__doc__ = __description__ + " <" + __uri__ + ">" + +__author__ = "José Padilla" +__email__ = "hello@jpadilla.com" + +__license__ = "MIT" +__copyright__ = "Copyright 2015-2020 José Padilla" + + +__all__ = [ + "PyJWS", + "PyJWT", + "PyJWKClient", + "PyJWK", + "PyJWKSet", + "decode", + "encode", + "get_unverified_header", + "register_algorithm", + "unregister_algorithm", + # Exceptions + "DecodeError", + "ExpiredSignatureError", + "ImmatureSignatureError", + "InvalidAlgorithmError", + "InvalidAudienceError", + "InvalidIssuedAtError", + "InvalidIssuerError", + "InvalidKeyError", + "InvalidSignatureError", + "InvalidTokenError", + "MissingRequiredClaimError", + "PyJWKClientError", + "PyJWKError", + "PyJWKSetError", + "PyJWTError", +] diff --git a/lib/jwt/__main__.py b/lib/jwt/__main__.py deleted file mode 100644 index 2aa70b3d..00000000 --- a/lib/jwt/__main__.py +++ /dev/null @@ -1,135 +0,0 @@ -#!/usr/bin/env python - -from __future__ import absolute_import, print_function - -import json -import optparse -import sys -import time - -from . import DecodeError, __package__, __version__, decode, encode - - -def main(): - - usage = '''Encodes or decodes JSON Web Tokens based on input. - - %prog [options] input - -Decoding examples: - - %prog --key=secret json.web.token - %prog --no-verify json.web.token - -Encoding requires the key option and takes space separated key/value pairs -separated by equals (=) as input. Examples: - - %prog --key=secret iss=me exp=1302049071 - %prog --key=secret foo=bar exp=+10 - -The exp key is special and can take an offset to current Unix time.\ -''' - p = optparse.OptionParser( - usage=usage, - prog=__package__, - version='%s %s' % (__package__, __version__), - ) - - p.add_option( - '-n', '--no-verify', - action='store_false', - dest='verify', - default=True, - help='ignore signature verification on decode' - ) - - p.add_option( - '--key', - dest='key', - metavar='KEY', - default=None, - help='set the secret key to sign with' - ) - - p.add_option( - '--alg', - dest='algorithm', - metavar='ALG', - default='HS256', - help='set crypto algorithm to sign with. default=HS256' - ) - - options, arguments = p.parse_args() - - if len(arguments) > 0 or not sys.stdin.isatty(): - if len(arguments) == 1 and (not options.verify or options.key): - # Try to decode - try: - if not sys.stdin.isatty(): - token = sys.stdin.read() - else: - token = arguments[0] - - token = token.encode('utf-8') - data = decode(token, key=options.key, verify=options.verify) - - print(json.dumps(data)) - sys.exit(0) - except DecodeError as e: - print(e) - sys.exit(1) - - # Try to encode - if options.key is None: - print('Key is required when encoding. See --help for usage.') - sys.exit(1) - - # Build payload object to encode - payload = {} - - for arg in arguments: - try: - k, v = arg.split('=', 1) - - # exp +offset special case? - if k == 'exp' and v[0] == '+' and len(v) > 1: - v = str(int(time.time()+int(v[1:]))) - - # Cast to integer? - if v.isdigit(): - v = int(v) - else: - # Cast to float? - try: - v = float(v) - except ValueError: - pass - - # Cast to true, false, or null? - constants = {'true': True, 'false': False, 'null': None} - - if v in constants: - v = constants[v] - - payload[k] = v - except ValueError: - print('Invalid encoding input at {}'.format(arg)) - sys.exit(1) - - try: - token = encode( - payload, - key=options.key, - algorithm=options.algorithm - ) - - print(token) - sys.exit(0) - except Exception as e: - print(e) - sys.exit(1) - else: - p.print_help() - -if __name__ == '__main__': - main() diff --git a/lib/jwt/algorithms.py b/lib/jwt/algorithms.py index 9c1a7e80..1f8865af 100644 --- a/lib/jwt/algorithms.py +++ b/lib/jwt/algorithms.py @@ -1,61 +1,114 @@ import hashlib import hmac +import json -from .compat import constant_time_compare, string_types, text_type from .exceptions import InvalidKeyError -from .utils import der_to_raw_signature, raw_to_der_signature +from .utils import ( + base64url_decode, + base64url_encode, + der_to_raw_signature, + force_bytes, + from_base64url_uint, + raw_to_der_signature, + to_base64url_uint, +) try: + import cryptography.exceptions + from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives import hashes - from cryptography.hazmat.primitives.serialization import ( - load_pem_private_key, load_pem_public_key, load_ssh_public_key + from cryptography.hazmat.primitives.asymmetric import ec, padding + from cryptography.hazmat.primitives.asymmetric.ec import ( + EllipticCurvePrivateKey, + EllipticCurvePublicKey, + ) + from cryptography.hazmat.primitives.asymmetric.ed448 import ( + Ed448PrivateKey, + Ed448PublicKey, + ) + from cryptography.hazmat.primitives.asymmetric.ed25519 import ( + Ed25519PrivateKey, + Ed25519PublicKey, ) from cryptography.hazmat.primitives.asymmetric.rsa import ( - RSAPrivateKey, RSAPublicKey + RSAPrivateKey, + RSAPrivateNumbers, + RSAPublicKey, + RSAPublicNumbers, + rsa_crt_dmp1, + rsa_crt_dmq1, + rsa_crt_iqmp, + rsa_recover_prime_factors, ) - from cryptography.hazmat.primitives.asymmetric.ec import ( - EllipticCurvePrivateKey, EllipticCurvePublicKey + from cryptography.hazmat.primitives.serialization import ( + Encoding, + NoEncryption, + PrivateFormat, + PublicFormat, + load_pem_private_key, + load_pem_public_key, + load_ssh_public_key, ) - from cryptography.hazmat.primitives.asymmetric import ec, padding - from cryptography.hazmat.backends import default_backend - from cryptography.exceptions import InvalidSignature has_crypto = True -except ImportError: +except ModuleNotFoundError: has_crypto = False +requires_cryptography = { + "RS256", + "RS384", + "RS512", + "ES256", + "ES256K", + "ES384", + "ES521", + "ES512", + "PS256", + "PS384", + "PS512", + "EdDSA", +} + def get_default_algorithms(): """ Returns the algorithms that are implemented by the library. """ default_algorithms = { - 'none': NoneAlgorithm(), - 'HS256': HMACAlgorithm(HMACAlgorithm.SHA256), - 'HS384': HMACAlgorithm(HMACAlgorithm.SHA384), - 'HS512': HMACAlgorithm(HMACAlgorithm.SHA512) + "none": NoneAlgorithm(), + "HS256": HMACAlgorithm(HMACAlgorithm.SHA256), + "HS384": HMACAlgorithm(HMACAlgorithm.SHA384), + "HS512": HMACAlgorithm(HMACAlgorithm.SHA512), } if has_crypto: - default_algorithms.update({ - 'RS256': RSAAlgorithm(RSAAlgorithm.SHA256), - 'RS384': RSAAlgorithm(RSAAlgorithm.SHA384), - 'RS512': RSAAlgorithm(RSAAlgorithm.SHA512), - 'ES256': ECAlgorithm(ECAlgorithm.SHA256), - 'ES384': ECAlgorithm(ECAlgorithm.SHA384), - 'ES512': ECAlgorithm(ECAlgorithm.SHA512), - 'PS256': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256), - 'PS384': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384), - 'PS512': RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512) - }) + default_algorithms.update( + { + "RS256": RSAAlgorithm(RSAAlgorithm.SHA256), + "RS384": RSAAlgorithm(RSAAlgorithm.SHA384), + "RS512": RSAAlgorithm(RSAAlgorithm.SHA512), + "ES256": ECAlgorithm(ECAlgorithm.SHA256), + "ES256K": ECAlgorithm(ECAlgorithm.SHA256), + "ES384": ECAlgorithm(ECAlgorithm.SHA384), + "ES521": ECAlgorithm(ECAlgorithm.SHA512), + "ES512": ECAlgorithm( + ECAlgorithm.SHA512 + ), # Backward compat for #219 fix + "PS256": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA256), + "PS384": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA384), + "PS512": RSAPSSAlgorithm(RSAPSSAlgorithm.SHA512), + "EdDSA": OKPAlgorithm(), + } + ) return default_algorithms -class Algorithm(object): +class Algorithm: """ The interface for an algorithm used to sign and verify tokens. """ + def prepare_key(self, key): """ Performs necessary validation and conversions on the key and returns @@ -77,14 +130,29 @@ class Algorithm(object): """ raise NotImplementedError + @staticmethod + def to_jwk(key_obj): + """ + Serializes a given RSA key into a JWK + """ + raise NotImplementedError + + @staticmethod + def from_jwk(jwk): + """ + Deserializes a given RSA key from JWK back into a PublicKey or PrivateKey object + """ + raise NotImplementedError + class NoneAlgorithm(Algorithm): """ Placeholder for use when no signing or verification operations are required. """ + def prepare_key(self, key): - if key == '': + if key == "": key = None if key is not None: @@ -93,7 +161,7 @@ class NoneAlgorithm(Algorithm): return key def sign(self, msg, key): - return b'' + return b"" def verify(self, msg, key, sig): return False @@ -104,6 +172,7 @@ class HMACAlgorithm(Algorithm): Performs signing and verification operations using HMAC and the specified hash function. """ + SHA256 = hashlib.sha256 SHA384 = hashlib.sha384 SHA512 = hashlib.sha512 @@ -112,30 +181,55 @@ class HMACAlgorithm(Algorithm): self.hash_alg = hash_alg def prepare_key(self, key): - if not isinstance(key, string_types) and not isinstance(key, bytes): - raise TypeError('Expecting a string- or bytes-formatted key.') - - if isinstance(key, text_type): - key = key.encode('utf-8') + key = force_bytes(key) invalid_strings = [ - b'-----BEGIN PUBLIC KEY-----', - b'-----BEGIN CERTIFICATE-----', - b'ssh-rsa' + b"-----BEGIN PUBLIC KEY-----", + b"-----BEGIN CERTIFICATE-----", + b"-----BEGIN RSA PUBLIC KEY-----", + b"ssh-rsa", ] - if any([string_value in key for string_value in invalid_strings]): + if any(string_value in key for string_value in invalid_strings): raise InvalidKeyError( - 'The specified key is an asymmetric key or x509 certificate and' - ' should not be used as an HMAC secret.') + "The specified key is an asymmetric key or x509 certificate and" + " should not be used as an HMAC secret." + ) return key + @staticmethod + def to_jwk(key_obj): + return json.dumps( + { + "k": base64url_encode(force_bytes(key_obj)).decode(), + "kty": "oct", + } + ) + + @staticmethod + def from_jwk(jwk): + try: + if isinstance(jwk, str): + obj = json.loads(jwk) + elif isinstance(jwk, dict): + obj = jwk + else: + raise ValueError + except ValueError: + raise InvalidKeyError("Key is not valid JSON") + + if obj.get("kty") != "oct": + raise InvalidKeyError("Not an HMAC key") + + return base64url_decode(obj["k"]) + def sign(self, msg, key): return hmac.new(key, msg, self.hash_alg).digest() def verify(self, msg, key, sig): - return constant_time_compare(sig, self.sign(msg, key)) + return hmac.compare_digest(sig, self.sign(msg, key)) + if has_crypto: @@ -144,6 +238,7 @@ if has_crypto: Performs signing and verification operations using RSASSA-PKCS-v1_5 and the specified hash function. """ + SHA256 = hashes.SHA256 SHA384 = hashes.SHA384 SHA512 = hashes.SHA512 @@ -152,46 +247,139 @@ if has_crypto: self.hash_alg = hash_alg def prepare_key(self, key): - if isinstance(key, RSAPrivateKey) or \ - isinstance(key, RSAPublicKey): + if isinstance(key, (RSAPrivateKey, RSAPublicKey)): return key - if isinstance(key, string_types): - if isinstance(key, text_type): - key = key.encode('utf-8') + if not isinstance(key, (bytes, str)): + raise TypeError("Expecting a PEM-formatted key.") - try: - if key.startswith(b'ssh-rsa'): - key = load_ssh_public_key(key, backend=default_backend()) - else: - key = load_pem_private_key(key, password=None, backend=default_backend()) - except ValueError: - key = load_pem_public_key(key, backend=default_backend()) - else: - raise TypeError('Expecting a PEM-formatted key.') - - return key - - def sign(self, msg, key): - signer = key.signer( - padding.PKCS1v15(), - self.hash_alg() - ) - - signer.update(msg) - return signer.finalize() - - def verify(self, msg, key, sig): - verifier = key.verifier( - sig, - padding.PKCS1v15(), - self.hash_alg() - ) - - verifier.update(msg) + key = force_bytes(key) try: - verifier.verify() + if key.startswith(b"ssh-rsa"): + key = load_ssh_public_key(key) + else: + key = load_pem_private_key(key, password=None) + except ValueError: + key = load_pem_public_key(key) + return key + + @staticmethod + def to_jwk(key_obj): + obj = None + + if getattr(key_obj, "private_numbers", None): + # Private key + numbers = key_obj.private_numbers() + + obj = { + "kty": "RSA", + "key_ops": ["sign"], + "n": to_base64url_uint(numbers.public_numbers.n).decode(), + "e": to_base64url_uint(numbers.public_numbers.e).decode(), + "d": to_base64url_uint(numbers.d).decode(), + "p": to_base64url_uint(numbers.p).decode(), + "q": to_base64url_uint(numbers.q).decode(), + "dp": to_base64url_uint(numbers.dmp1).decode(), + "dq": to_base64url_uint(numbers.dmq1).decode(), + "qi": to_base64url_uint(numbers.iqmp).decode(), + } + + elif getattr(key_obj, "verify", None): + # Public key + numbers = key_obj.public_numbers() + + obj = { + "kty": "RSA", + "key_ops": ["verify"], + "n": to_base64url_uint(numbers.n).decode(), + "e": to_base64url_uint(numbers.e).decode(), + } + else: + raise InvalidKeyError("Not a public or private key") + + return json.dumps(obj) + + @staticmethod + def from_jwk(jwk): + try: + if isinstance(jwk, str): + obj = json.loads(jwk) + elif isinstance(jwk, dict): + obj = jwk + else: + raise ValueError + except ValueError: + raise InvalidKeyError("Key is not valid JSON") + + if obj.get("kty") != "RSA": + raise InvalidKeyError("Not an RSA key") + + if "d" in obj and "e" in obj and "n" in obj: + # Private key + if "oth" in obj: + raise InvalidKeyError( + "Unsupported RSA private key: > 2 primes not supported" + ) + + other_props = ["p", "q", "dp", "dq", "qi"] + props_found = [prop in obj for prop in other_props] + any_props_found = any(props_found) + + if any_props_found and not all(props_found): + raise InvalidKeyError( + "RSA key must include all parameters if any are present besides d" + ) + + public_numbers = RSAPublicNumbers( + from_base64url_uint(obj["e"]), + from_base64url_uint(obj["n"]), + ) + + if any_props_found: + numbers = RSAPrivateNumbers( + d=from_base64url_uint(obj["d"]), + p=from_base64url_uint(obj["p"]), + q=from_base64url_uint(obj["q"]), + dmp1=from_base64url_uint(obj["dp"]), + dmq1=from_base64url_uint(obj["dq"]), + iqmp=from_base64url_uint(obj["qi"]), + public_numbers=public_numbers, + ) + else: + d = from_base64url_uint(obj["d"]) + p, q = rsa_recover_prime_factors( + public_numbers.n, d, public_numbers.e + ) + + numbers = RSAPrivateNumbers( + d=d, + p=p, + q=q, + dmp1=rsa_crt_dmp1(d, p), + dmq1=rsa_crt_dmq1(d, q), + iqmp=rsa_crt_iqmp(p, q), + public_numbers=public_numbers, + ) + + return numbers.private_key() + elif "n" in obj and "e" in obj: + # Public key + numbers = RSAPublicNumbers( + from_base64url_uint(obj["e"]), + from_base64url_uint(obj["n"]), + ) + + return numbers.public_key() + else: + raise InvalidKeyError("Not a public or private key") + + def sign(self, msg, key): + return key.sign(msg, padding.PKCS1v15(), self.hash_alg()) + + def verify(self, msg, key, sig): + try: + key.verify(sig, msg, padding.PKCS1v15(), self.hash_alg()) return True except InvalidSignature: return False @@ -201,6 +389,7 @@ if has_crypto: Performs signing and verification operations using ECDSA and the specified hash function """ + SHA256 = hashes.SHA256 SHA384 = hashes.SHA384 SHA512 = hashes.SHA512 @@ -209,32 +398,29 @@ if has_crypto: self.hash_alg = hash_alg def prepare_key(self, key): - if isinstance(key, EllipticCurvePrivateKey) or \ - isinstance(key, EllipticCurvePublicKey): + if isinstance(key, (EllipticCurvePrivateKey, EllipticCurvePublicKey)): return key - if isinstance(key, string_types): - if isinstance(key, text_type): - key = key.encode('utf-8') + if not isinstance(key, (bytes, str)): + raise TypeError("Expecting a PEM-formatted key.") - # Attempt to load key. We don't know if it's - # a Signing Key or a Verifying Key, so we try - # the Verifying Key first. - try: - key = load_pem_public_key(key, backend=default_backend()) - except ValueError: - key = load_pem_private_key(key, password=None, backend=default_backend()) + key = force_bytes(key) - else: - raise TypeError('Expecting a PEM-formatted key.') + # Attempt to load key. We don't know if it's + # a Signing Key or a Verifying Key, so we try + # the Verifying Key first. + try: + if key.startswith(b"ecdsa-sha2-"): + key = load_ssh_public_key(key) + else: + key = load_pem_public_key(key) + except ValueError: + key = load_pem_private_key(key, password=None) return key def sign(self, msg, key): - signer = key.signer(ec.ECDSA(self.hash_alg())) - - signer.update(msg) - der_sig = signer.finalize() + der_sig = key.sign(msg, ec.ECDSA(self.hash_alg())) return der_to_raw_signature(der_sig, key.curve) @@ -244,47 +430,245 @@ if has_crypto: except ValueError: return False - verifier = key.verifier(der_sig, ec.ECDSA(self.hash_alg())) - - verifier.update(msg) - try: - verifier.verify() + if isinstance(key, EllipticCurvePrivateKey): + key = key.public_key() + key.verify(der_sig, msg, ec.ECDSA(self.hash_alg())) return True except InvalidSignature: return False + @staticmethod + def from_jwk(jwk): + try: + if isinstance(jwk, str): + obj = json.loads(jwk) + elif isinstance(jwk, dict): + obj = jwk + else: + raise ValueError + except ValueError: + raise InvalidKeyError("Key is not valid JSON") + + if obj.get("kty") != "EC": + raise InvalidKeyError("Not an Elliptic curve key") + + if "x" not in obj or "y" not in obj: + raise InvalidKeyError("Not an Elliptic curve key") + + x = base64url_decode(obj.get("x")) + y = base64url_decode(obj.get("y")) + + curve = obj.get("crv") + if curve == "P-256": + if len(x) == len(y) == 32: + curve_obj = ec.SECP256R1() + else: + raise InvalidKeyError("Coords should be 32 bytes for curve P-256") + elif curve == "P-384": + if len(x) == len(y) == 48: + curve_obj = ec.SECP384R1() + else: + raise InvalidKeyError("Coords should be 48 bytes for curve P-384") + elif curve == "P-521": + if len(x) == len(y) == 66: + curve_obj = ec.SECP521R1() + else: + raise InvalidKeyError("Coords should be 66 bytes for curve P-521") + elif curve == "secp256k1": + if len(x) == len(y) == 32: + curve_obj = ec.SECP256K1() + else: + raise InvalidKeyError( + "Coords should be 32 bytes for curve secp256k1" + ) + else: + raise InvalidKeyError(f"Invalid curve: {curve}") + + public_numbers = ec.EllipticCurvePublicNumbers( + x=int.from_bytes(x, byteorder="big"), + y=int.from_bytes(y, byteorder="big"), + curve=curve_obj, + ) + + if "d" not in obj: + return public_numbers.public_key() + + d = base64url_decode(obj.get("d")) + if len(d) != len(x): + raise InvalidKeyError( + "D should be {} bytes for curve {}", len(x), curve + ) + + return ec.EllipticCurvePrivateNumbers( + int.from_bytes(d, byteorder="big"), public_numbers + ).private_key() + class RSAPSSAlgorithm(RSAAlgorithm): """ Performs a signature using RSASSA-PSS with MGF1 """ def sign(self, msg, key): - signer = key.signer( + return key.sign( + msg, padding.PSS( mgf=padding.MGF1(self.hash_alg()), - salt_length=self.hash_alg.digest_size + salt_length=self.hash_alg.digest_size, ), - self.hash_alg() + self.hash_alg(), ) - signer.update(msg) - return signer.finalize() - def verify(self, msg, key, sig): - verifier = key.verifier( - sig, - padding.PSS( - mgf=padding.MGF1(self.hash_alg()), - salt_length=self.hash_alg.digest_size - ), - self.hash_alg() - ) - - verifier.update(msg) - try: - verifier.verify() + key.verify( + sig, + msg, + padding.PSS( + mgf=padding.MGF1(self.hash_alg()), + salt_length=self.hash_alg.digest_size, + ), + self.hash_alg(), + ) return True except InvalidSignature: return False + + class OKPAlgorithm(Algorithm): + """ + Performs signing and verification operations using EdDSA + + This class requires ``cryptography>=2.6`` to be installed. + """ + + def __init__(self, **kwargs): + pass + + def prepare_key(self, key): + + if isinstance( + key, + (Ed25519PrivateKey, Ed25519PublicKey, Ed448PrivateKey, Ed448PublicKey), + ): + return key + + if isinstance(key, (bytes, str)): + if isinstance(key, str): + key = key.encode("utf-8") + str_key = key.decode("utf-8") + + if "-----BEGIN PUBLIC" in str_key: + return load_pem_public_key(key) + if "-----BEGIN PRIVATE" in str_key: + return load_pem_private_key(key, password=None) + if str_key[0:4] == "ssh-": + return load_ssh_public_key(key) + + raise TypeError("Expecting a PEM-formatted or OpenSSH key.") + + def sign(self, msg, key): + """ + Sign a message ``msg`` using the EdDSA private key ``key`` + :param str|bytes msg: Message to sign + :param Ed25519PrivateKey}Ed448PrivateKey key: A :class:`.Ed25519PrivateKey` + or :class:`.Ed448PrivateKey` iinstance + :return bytes signature: The signature, as bytes + """ + msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg + return key.sign(msg) + + def verify(self, msg, key, sig): + """ + Verify a given ``msg`` against a signature ``sig`` using the EdDSA key ``key`` + + :param str|bytes sig: EdDSA signature to check ``msg`` against + :param str|bytes msg: Message to sign + :param Ed25519PrivateKey|Ed25519PublicKey|Ed448PrivateKey|Ed448PublicKey key: + A private or public EdDSA key instance + :return bool verified: True if signature is valid, False if not. + """ + try: + msg = bytes(msg, "utf-8") if type(msg) is not bytes else msg + sig = bytes(sig, "utf-8") if type(sig) is not bytes else sig + + if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)): + key = key.public_key() + key.verify(sig, msg) + return True # If no exception was raised, the signature is valid. + except cryptography.exceptions.InvalidSignature: + return False + + @staticmethod + def to_jwk(key): + if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)): + x = key.public_bytes( + encoding=Encoding.Raw, + format=PublicFormat.Raw, + ) + crv = "Ed25519" if isinstance(key, Ed25519PublicKey) else "Ed448" + return json.dumps( + { + "x": base64url_encode(force_bytes(x)).decode(), + "kty": "OKP", + "crv": crv, + } + ) + + if isinstance(key, (Ed25519PrivateKey, Ed448PrivateKey)): + d = key.private_bytes( + encoding=Encoding.Raw, + format=PrivateFormat.Raw, + encryption_algorithm=NoEncryption(), + ) + + x = key.public_key().public_bytes( + encoding=Encoding.Raw, + format=PublicFormat.Raw, + ) + + crv = "Ed25519" if isinstance(key, Ed25519PrivateKey) else "Ed448" + return json.dumps( + { + "x": base64url_encode(force_bytes(x)).decode(), + "d": base64url_encode(force_bytes(d)).decode(), + "kty": "OKP", + "crv": crv, + } + ) + + raise InvalidKeyError("Not a public or private key") + + @staticmethod + def from_jwk(jwk): + try: + if isinstance(jwk, str): + obj = json.loads(jwk) + elif isinstance(jwk, dict): + obj = jwk + else: + raise ValueError + except ValueError: + raise InvalidKeyError("Key is not valid JSON") + + if obj.get("kty") != "OKP": + raise InvalidKeyError("Not an Octet Key Pair") + + curve = obj.get("crv") + if curve != "Ed25519" and curve != "Ed448": + raise InvalidKeyError(f"Invalid curve: {curve}") + + if "x" not in obj: + raise InvalidKeyError('OKP should have "x" parameter') + x = base64url_decode(obj.get("x")) + + try: + if "d" not in obj: + if curve == "Ed25519": + return Ed25519PublicKey.from_public_bytes(x) + return Ed448PublicKey.from_public_bytes(x) + d = base64url_decode(obj.get("d")) + if curve == "Ed25519": + return Ed25519PrivateKey.from_private_bytes(d) + return Ed448PrivateKey.from_private_bytes(d) + except ValueError as err: + raise InvalidKeyError("Invalid key parameter") from err diff --git a/lib/jwt/api_jwk.py b/lib/jwt/api_jwk.py new file mode 100644 index 00000000..a0f6364d --- /dev/null +++ b/lib/jwt/api_jwk.py @@ -0,0 +1,97 @@ +import json + +from .algorithms import get_default_algorithms +from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError + + +class PyJWK: + def __init__(self, jwk_data, algorithm=None): + self._algorithms = get_default_algorithms() + self._jwk_data = jwk_data + + kty = self._jwk_data.get("kty", None) + if not kty: + raise InvalidKeyError("kty is not found: %s" % self._jwk_data) + + if not algorithm and isinstance(self._jwk_data, dict): + algorithm = self._jwk_data.get("alg", None) + + if not algorithm: + # Determine alg with kty (and crv). + crv = self._jwk_data.get("crv", None) + if kty == "EC": + if crv == "P-256" or not crv: + algorithm = "ES256" + elif crv == "P-384": + algorithm = "ES384" + elif crv == "P-521": + algorithm = "ES512" + elif crv == "secp256k1": + algorithm = "ES256K" + else: + raise InvalidKeyError("Unsupported crv: %s" % crv) + elif kty == "RSA": + algorithm = "RS256" + elif kty == "oct": + algorithm = "HS256" + elif kty == "OKP": + if not crv: + raise InvalidKeyError("crv is not found: %s" % self._jwk_data) + if crv == "Ed25519": + algorithm = "EdDSA" + else: + raise InvalidKeyError("Unsupported crv: %s" % crv) + else: + raise InvalidKeyError("Unsupported kty: %s" % kty) + + self.Algorithm = self._algorithms.get(algorithm) + + if not self.Algorithm: + raise PyJWKError("Unable to find a algorithm for key: %s" % self._jwk_data) + + self.key = self.Algorithm.from_jwk(self._jwk_data) + + @staticmethod + def from_dict(obj, algorithm=None): + return PyJWK(obj, algorithm) + + @staticmethod + def from_json(data, algorithm=None): + obj = json.loads(data) + return PyJWK.from_dict(obj, algorithm) + + @property + def key_type(self): + return self._jwk_data.get("kty", None) + + @property + def key_id(self): + return self._jwk_data.get("kid", None) + + @property + def public_key_use(self): + return self._jwk_data.get("use", None) + + +class PyJWKSet: + def __init__(self, keys): + self.keys = [] + + if not keys or not isinstance(keys, list): + raise PyJWKSetError("Invalid JWK Set value") + + if len(keys) == 0: + raise PyJWKSetError("The JWK Set did not contain any keys") + + for key in keys: + self.keys.append(PyJWK(key)) + + @staticmethod + def from_dict(obj): + keys = obj.get("keys", []) + return PyJWKSet(keys) + + @staticmethod + def from_json(data): + obj = json.loads(data) + return PyJWKSet.from_dict(obj) diff --git a/lib/jwt/api_jws.py b/lib/jwt/api_jws.py index 0c61c7df..a61d2277 100644 --- a/lib/jwt/api_jws.py +++ b/lib/jwt/api_jws.py @@ -1,48 +1,54 @@ import binascii import json -import warnings +from collections.abc import Mapping +from typing import Any, Dict, List, Optional, Type -from collections import Mapping - -from .algorithms import Algorithm, get_default_algorithms # NOQA -from .compat import text_type -from .exceptions import DecodeError, InvalidAlgorithmError -from .utils import base64url_decode, base64url_encode, merge_dict +from .algorithms import ( + Algorithm, + get_default_algorithms, + has_crypto, + requires_cryptography, +) +from .exceptions import ( + DecodeError, + InvalidAlgorithmError, + InvalidSignatureError, + InvalidTokenError, +) +from .utils import base64url_decode, base64url_encode -class PyJWS(object): - header_typ = 'JWT' +class PyJWS: + header_typ = "JWT" def __init__(self, algorithms=None, options=None): self._algorithms = get_default_algorithms() - self._valid_algs = (set(algorithms) if algorithms is not None - else set(self._algorithms)) + self._valid_algs = ( + set(algorithms) if algorithms is not None else set(self._algorithms) + ) # Remove algorithms that aren't on the whitelist for key in list(self._algorithms.keys()): if key not in self._valid_algs: del self._algorithms[key] - if not options: + if options is None: options = {} - - self.options = merge_dict(self._get_default_options(), options) + self.options = {**self._get_default_options(), **options} @staticmethod def _get_default_options(): - return { - 'verify_signature': True - } + return {"verify_signature": True} def register_algorithm(self, alg_id, alg_obj): """ Registers a new Algorithm for use when creating and verifying tokens. """ if alg_id in self._algorithms: - raise ValueError('Algorithm already has a handler.') + raise ValueError("Algorithm already has a handler.") if not isinstance(alg_obj, Algorithm): - raise TypeError('Object is not of type `Algorithm`') + raise TypeError("Object is not of type `Algorithm`") self._algorithms[alg_id] = alg_obj self._valid_algs.add(alg_id) @@ -53,8 +59,10 @@ class PyJWS(object): Throws KeyError if algorithm is not registered. """ if alg_id not in self._algorithms: - raise KeyError('The specified algorithm could not be removed' - ' because it is not registered.') + raise KeyError( + "The specified algorithm could not be removed" + " because it is not registered." + ) del self._algorithms[alg_id] self._valid_algs.remove(alg_id) @@ -65,59 +73,98 @@ class PyJWS(object): """ return list(self._valid_algs) - def encode(self, payload, key, algorithm='HS256', headers=None, - json_encoder=None): + def encode( + self, + payload: bytes, + key: str, + algorithm: Optional[str] = "HS256", + headers: Optional[Dict] = None, + json_encoder: Optional[Type[json.JSONEncoder]] = None, + ) -> str: segments = [] if algorithm is None: - algorithm = 'none' + algorithm = "none" - if algorithm not in self._valid_algs: - pass + # Prefer headers["alg"] if present to algorithm parameter. + if headers and "alg" in headers and headers["alg"]: + algorithm = headers["alg"] # Header - header = {'typ': self.header_typ, 'alg': algorithm} + header = {"typ": self.header_typ, "alg": algorithm} if headers: + self._validate_headers(headers) header.update(headers) + if not header["typ"]: + del header["typ"] json_header = json.dumps( - header, - separators=(',', ':'), - cls=json_encoder - ).encode('utf-8') + header, separators=(",", ":"), cls=json_encoder + ).encode() segments.append(base64url_encode(json_header)) segments.append(base64url_encode(payload)) # Segments - signing_input = b'.'.join(segments) + signing_input = b".".join(segments) try: alg_obj = self._algorithms[algorithm] key = alg_obj.prepare_key(key) signature = alg_obj.sign(signing_input, key) except KeyError: - raise NotImplementedError('Algorithm not supported') + if not has_crypto and algorithm in requires_cryptography: + raise NotImplementedError( + "Algorithm '%s' could not be found. Do you have cryptography " + "installed?" % algorithm + ) + else: + raise NotImplementedError("Algorithm not supported") segments.append(base64url_encode(signature)) - return b'.'.join(segments) + encoded_string = b".".join(segments) - def decode(self, jws, key='', verify=True, algorithms=None, options=None, - **kwargs): - payload, signing_input, header, signature = self._load(jws) + return encoded_string.decode("utf-8") - if verify: - merged_options = merge_dict(self.options, options) - if merged_options.get('verify_signature'): - self._verify_signature(payload, signing_input, header, signature, - key, algorithms) - else: - warnings.warn('The verify parameter is deprecated. ' - 'Please use options instead.', DeprecationWarning) + def decode_complete( + self, + jwt: str, + key: str = "", + algorithms: List[str] = None, + options: Dict = None, + ) -> Dict[str, Any]: + if options is None: + options = {} + merged_options = {**self.options, **options} + verify_signature = merged_options["verify_signature"] - return payload + if verify_signature and not algorithms: + raise DecodeError( + 'It is required that you pass in a value for the "algorithms" argument when calling decode().' + ) + + payload, signing_input, header, signature = self._load(jwt) + + if verify_signature: + self._verify_signature(signing_input, header, signature, key, algorithms) + + return { + "payload": payload, + "header": header, + "signature": signature, + } + + def decode( + self, + jwt: str, + key: str = "", + algorithms: List[str] = None, + options: Dict = None, + ) -> str: + decoded = self.decode_complete(jwt, key, algorithms, options) + return decoded["payload"] def get_unverified_header(self, jwt): """Returns back the JWT header parameters as a dict() @@ -125,64 +172,85 @@ class PyJWS(object): Note: The signature is not verified so the header parameters should not be fully trusted until signature verification is complete """ - return self._load(jwt)[2] + headers = self._load(jwt)[2] + self._validate_headers(headers) + + return headers def _load(self, jwt): - if isinstance(jwt, text_type): - jwt = jwt.encode('utf-8') + if isinstance(jwt, str): + jwt = jwt.encode("utf-8") + + if not isinstance(jwt, bytes): + raise DecodeError(f"Invalid token type. Token must be a {bytes}") try: - signing_input, crypto_segment = jwt.rsplit(b'.', 1) - header_segment, payload_segment = signing_input.split(b'.', 1) - except ValueError: - raise DecodeError('Not enough segments') + signing_input, crypto_segment = jwt.rsplit(b".", 1) + header_segment, payload_segment = signing_input.split(b".", 1) + except ValueError as err: + raise DecodeError("Not enough segments") from err try: header_data = base64url_decode(header_segment) - except (TypeError, binascii.Error): - raise DecodeError('Invalid header padding') + except (TypeError, binascii.Error) as err: + raise DecodeError("Invalid header padding") from err try: - header = json.loads(header_data.decode('utf-8')) + header = json.loads(header_data) except ValueError as e: - raise DecodeError('Invalid header string: %s' % e) + raise DecodeError("Invalid header string: %s" % e) from e if not isinstance(header, Mapping): - raise DecodeError('Invalid header string: must be a json object') + raise DecodeError("Invalid header string: must be a json object") try: payload = base64url_decode(payload_segment) - except (TypeError, binascii.Error): - raise DecodeError('Invalid payload padding') + except (TypeError, binascii.Error) as err: + raise DecodeError("Invalid payload padding") from err try: signature = base64url_decode(crypto_segment) - except (TypeError, binascii.Error): - raise DecodeError('Invalid crypto padding') + except (TypeError, binascii.Error) as err: + raise DecodeError("Invalid crypto padding") from err return (payload, signing_input, header, signature) - def _verify_signature(self, payload, signing_input, header, signature, - key='', algorithms=None): + def _verify_signature( + self, + signing_input, + header, + signature, + key="", + algorithms=None, + ): - alg = header.get('alg') + alg = header.get("alg") if algorithms is not None and alg not in algorithms: - raise InvalidAlgorithmError('The specified alg value is not allowed') + raise InvalidAlgorithmError("The specified alg value is not allowed") try: alg_obj = self._algorithms[alg] key = alg_obj.prepare_key(key) if not alg_obj.verify(signing_input, key, signature): - raise DecodeError('Signature verification failed') + raise InvalidSignatureError("Signature verification failed") except KeyError: - raise InvalidAlgorithmError('Algorithm not supported') + raise InvalidAlgorithmError("Algorithm not supported") + + def _validate_headers(self, headers): + if "kid" in headers: + self._validate_kid(headers["kid"]) + + def _validate_kid(self, kid): + if not isinstance(kid, str): + raise InvalidTokenError("Key ID header parameter must be a string") _jws_global_obj = PyJWS() encode = _jws_global_obj.encode +decode_complete = _jws_global_obj.decode_complete decode = _jws_global_obj.decode register_algorithm = _jws_global_obj.register_algorithm unregister_algorithm = _jws_global_obj.unregister_algorithm diff --git a/lib/jwt/api_jwt.py b/lib/jwt/api_jwt.py index 9703b8d6..7e21b754 100644 --- a/lib/jwt/api_jwt.py +++ b/lib/jwt/api_jwt.py @@ -1,187 +1,224 @@ import json -import warnings - from calendar import timegm -from collections import Mapping -from datetime import datetime, timedelta +from collections.abc import Iterable, Mapping +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, List, Optional, Type, Union -from .api_jws import PyJWS -from .algorithms import Algorithm, get_default_algorithms # NOQA -from .compat import string_types, timedelta_total_seconds +from . import api_jws from .exceptions import ( - DecodeError, ExpiredSignatureError, ImmatureSignatureError, - InvalidAudienceError, InvalidIssuedAtError, - InvalidIssuerError, MissingRequiredClaimError + DecodeError, + ExpiredSignatureError, + ImmatureSignatureError, + InvalidAudienceError, + InvalidIssuedAtError, + InvalidIssuerError, + MissingRequiredClaimError, ) -from .utils import merge_dict -class PyJWT(PyJWS): - header_type = 'JWT' +class PyJWT: + def __init__(self, options=None): + if options is None: + options = {} + self.options = {**self._get_default_options(), **options} @staticmethod - def _get_default_options(): + def _get_default_options() -> Dict[str, Union[bool, List[str]]]: return { - 'verify_signature': True, - 'verify_exp': True, - 'verify_nbf': True, - 'verify_iat': True, - 'verify_aud': True, - 'verify_iss': True, - 'require_exp': False, - 'require_iat': False, - 'require_nbf': False + "verify_signature": True, + "verify_exp": True, + "verify_nbf": True, + "verify_iat": True, + "verify_aud": True, + "verify_iss": True, + "require": [], } - def encode(self, payload, key, algorithm='HS256', headers=None, - json_encoder=None): + def encode( + self, + payload: Dict[str, Any], + key: str, + algorithm: Optional[str] = "HS256", + headers: Optional[Dict] = None, + json_encoder: Optional[Type[json.JSONEncoder]] = None, + ) -> str: # Check that we get a mapping if not isinstance(payload, Mapping): - raise TypeError('Expecting a mapping object, as JWT only supports ' - 'JSON objects as payloads.') + raise TypeError( + "Expecting a mapping object, as JWT only supports " + "JSON objects as payloads." + ) # Payload - for time_claim in ['exp', 'iat', 'nbf']: + payload = payload.copy() + for time_claim in ["exp", "iat", "nbf"]: # Convert datetime to a intDate value in known time-format claims if isinstance(payload.get(time_claim), datetime): payload[time_claim] = timegm(payload[time_claim].utctimetuple()) json_payload = json.dumps( - payload, - separators=(',', ':'), - cls=json_encoder - ).encode('utf-8') + payload, separators=(",", ":"), cls=json_encoder + ).encode("utf-8") - return super(PyJWT, self).encode( - json_payload, key, algorithm, headers, json_encoder + return api_jws.encode(json_payload, key, algorithm, headers, json_encoder) + + def decode_complete( + self, + jwt: str, + key: str = "", + algorithms: List[str] = None, + options: Dict = None, + audience: Optional[Union[str, List[str]]] = None, + issuer: Optional[str] = None, + leeway: Union[float, timedelta] = 0, + ) -> Dict[str, Any]: + if options is None: + options = {"verify_signature": True} + else: + options.setdefault("verify_signature", True) + + if not options["verify_signature"]: + options.setdefault("verify_exp", False) + options.setdefault("verify_nbf", False) + options.setdefault("verify_iat", False) + options.setdefault("verify_aud", False) + options.setdefault("verify_iss", False) + + if options["verify_signature"] and not algorithms: + raise DecodeError( + 'It is required that you pass in a value for the "algorithms" argument when calling decode().' + ) + + decoded = api_jws.decode_complete( + jwt, + key=key, + algorithms=algorithms, + options=options, ) - def decode(self, jwt, key='', verify=True, algorithms=None, options=None, - **kwargs): - payload, signing_input, header, signature = self._load(jwt) - - decoded = super(PyJWT, self).decode(jwt, key, verify, algorithms, - options, **kwargs) - try: - payload = json.loads(decoded.decode('utf-8')) + payload = json.loads(decoded["payload"]) except ValueError as e: - raise DecodeError('Invalid payload string: %s' % e) - if not isinstance(payload, Mapping): - raise DecodeError('Invalid payload string: must be a json object') + raise DecodeError("Invalid payload string: %s" % e) + if not isinstance(payload, dict): + raise DecodeError("Invalid payload string: must be a json object") - if verify: - merged_options = merge_dict(self.options, options) - self._validate_claims(payload, merged_options, **kwargs) + merged_options = {**self.options, **options} + self._validate_claims(payload, merged_options, audience, issuer, leeway) - return payload + decoded["payload"] = payload + return decoded - def _validate_claims(self, payload, options, audience=None, issuer=None, - leeway=0, **kwargs): - - if 'verify_expiration' in kwargs: - options['verify_exp'] = kwargs.get('verify_expiration', True) - warnings.warn('The verify_expiration parameter is deprecated. ' - 'Please use options instead.', DeprecationWarning) + def decode( + self, + jwt: str, + key: str = "", + algorithms: List[str] = None, + options: Dict = None, + audience: Optional[Union[str, List[str]]] = None, + issuer: Optional[str] = None, + leeway: Union[float, timedelta] = 0, + ) -> Dict[str, Any]: + decoded = self.decode_complete( + jwt, key, algorithms, options, audience, issuer, leeway + ) + return decoded["payload"] + def _validate_claims(self, payload, options, audience, issuer, leeway): if isinstance(leeway, timedelta): - leeway = timedelta_total_seconds(leeway) + leeway = leeway.total_seconds() - if not isinstance(audience, (string_types, type(None))): - raise TypeError('audience must be a string or None') + if not isinstance(audience, (str, type(None), Iterable)): + raise TypeError("audience must be a string, iterable, or None") self._validate_required_claims(payload, options) - now = timegm(datetime.utcnow().utctimetuple()) + now = timegm(datetime.now(tz=timezone.utc).utctimetuple()) - if 'iat' in payload and options.get('verify_iat'): + if "iat" in payload and options["verify_iat"]: self._validate_iat(payload, now, leeway) - if 'nbf' in payload and options.get('verify_nbf'): + if "nbf" in payload and options["verify_nbf"]: self._validate_nbf(payload, now, leeway) - if 'exp' in payload and options.get('verify_exp'): + if "exp" in payload and options["verify_exp"]: self._validate_exp(payload, now, leeway) - if options.get('verify_iss'): + if options["verify_iss"]: self._validate_iss(payload, issuer) - if options.get('verify_aud'): + if options["verify_aud"]: self._validate_aud(payload, audience) def _validate_required_claims(self, payload, options): - if options.get('require_exp') and payload.get('exp') is None: - raise MissingRequiredClaimError('exp') - - if options.get('require_iat') and payload.get('iat') is None: - raise MissingRequiredClaimError('iat') - - if options.get('require_nbf') and payload.get('nbf') is None: - raise MissingRequiredClaimError('nbf') + for claim in options["require"]: + if payload.get(claim) is None: + raise MissingRequiredClaimError(claim) def _validate_iat(self, payload, now, leeway): try: - iat = int(payload['iat']) + int(payload["iat"]) except ValueError: - raise DecodeError('Issued At claim (iat) must be an integer.') - - if iat > (now + leeway): - raise InvalidIssuedAtError('Issued At claim (iat) cannot be in' - ' the future.') + raise InvalidIssuedAtError("Issued At claim (iat) must be an integer.") def _validate_nbf(self, payload, now, leeway): try: - nbf = int(payload['nbf']) + nbf = int(payload["nbf"]) except ValueError: - raise DecodeError('Not Before claim (nbf) must be an integer.') + raise DecodeError("Not Before claim (nbf) must be an integer.") if nbf > (now + leeway): - raise ImmatureSignatureError('The token is not yet valid (nbf)') + raise ImmatureSignatureError("The token is not yet valid (nbf)") def _validate_exp(self, payload, now, leeway): try: - exp = int(payload['exp']) + exp = int(payload["exp"]) except ValueError: - raise DecodeError('Expiration Time claim (exp) must be an' - ' integer.') + raise DecodeError("Expiration Time claim (exp) must be an" " integer.") if exp < (now - leeway): - raise ExpiredSignatureError('Signature has expired') + raise ExpiredSignatureError("Signature has expired") def _validate_aud(self, payload, audience): - if audience is None and 'aud' not in payload: - return + if audience is None: + if "aud" not in payload or not payload["aud"]: + return + # Application did not specify an audience, but + # the token has the 'aud' claim + raise InvalidAudienceError("Invalid audience") - if audience is not None and 'aud' not in payload: + if "aud" not in payload or not payload["aud"]: # Application specified an audience, but it could not be # verified since the token does not contain a claim. - raise MissingRequiredClaimError('aud') + raise MissingRequiredClaimError("aud") - audience_claims = payload['aud'] + audience_claims = payload["aud"] - if isinstance(audience_claims, string_types): + if isinstance(audience_claims, str): audience_claims = [audience_claims] if not isinstance(audience_claims, list): - raise InvalidAudienceError('Invalid claim format in token') - if any(not isinstance(c, string_types) for c in audience_claims): - raise InvalidAudienceError('Invalid claim format in token') - if audience not in audience_claims: - raise InvalidAudienceError('Invalid audience') + raise InvalidAudienceError("Invalid claim format in token") + if any(not isinstance(c, str) for c in audience_claims): + raise InvalidAudienceError("Invalid claim format in token") + + if isinstance(audience, str): + audience = [audience] + + if all(aud not in audience_claims for aud in audience): + raise InvalidAudienceError("Invalid audience") def _validate_iss(self, payload, issuer): if issuer is None: return - if 'iss' not in payload: - raise MissingRequiredClaimError('iss') + if "iss" not in payload: + raise MissingRequiredClaimError("iss") - if payload['iss'] != issuer: - raise InvalidIssuerError('Invalid issuer') + if payload["iss"] != issuer: + raise InvalidIssuerError("Invalid issuer") _jwt_global_obj = PyJWT() encode = _jwt_global_obj.encode +decode_complete = _jwt_global_obj.decode_complete decode = _jwt_global_obj.decode -register_algorithm = _jwt_global_obj.register_algorithm -unregister_algorithm = _jwt_global_obj.unregister_algorithm -get_unverified_header = _jwt_global_obj.get_unverified_header diff --git a/lib/jwt/compat.py b/lib/jwt/compat.py deleted file mode 100644 index 11d423bd..00000000 --- a/lib/jwt/compat.py +++ /dev/null @@ -1,52 +0,0 @@ -""" -The `compat` module provides support for backwards compatibility with older -versions of python, and compatibility wrappers around optional packages. -""" -# flake8: noqa -import sys -import hmac - - -PY3 = sys.version_info[0] == 3 - - -if PY3: - string_types = str, - text_type = str -else: - string_types = basestring, - text_type = unicode - - -def timedelta_total_seconds(delta): - try: - delta.total_seconds - except AttributeError: - # On Python 2.6, timedelta instances do not have - # a .total_seconds() method. - total_seconds = delta.days * 24 * 60 * 60 + delta.seconds - else: - total_seconds = delta.total_seconds() - - return total_seconds - - -try: - constant_time_compare = hmac.compare_digest -except AttributeError: - # Fallback for Python < 2.7 - def constant_time_compare(val1, val2): - """ - Returns True if the two strings are equal, False otherwise. - - The time taken is independent of the number of characters that match. - """ - if len(val1) != len(val2): - return False - - result = 0 - - for x, y in zip(val1, val2): - result |= ord(x) ^ ord(y) - - return result == 0 diff --git a/lib/jwt/contrib/algorithms/__init__.py b/lib/jwt/contrib/algorithms/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/lib/jwt/contrib/algorithms/py_ecdsa.py b/lib/jwt/contrib/algorithms/py_ecdsa.py deleted file mode 100644 index bf0dea5a..00000000 --- a/lib/jwt/contrib/algorithms/py_ecdsa.py +++ /dev/null @@ -1,60 +0,0 @@ -# Note: This file is named py_ecdsa.py because import behavior in Python 2 -# would cause ecdsa.py to squash the ecdsa library that it depends upon. - -import hashlib - -import ecdsa - -from jwt.algorithms import Algorithm -from jwt.compat import string_types, text_type - - -class ECAlgorithm(Algorithm): - """ - Performs signing and verification operations using - ECDSA and the specified hash function - - This class requires the ecdsa package to be installed. - - This is based off of the implementation in PyJWT 0.3.2 - """ - SHA256 = hashlib.sha256 - SHA384 = hashlib.sha384 - SHA512 = hashlib.sha512 - - def __init__(self, hash_alg): - self.hash_alg = hash_alg - - def prepare_key(self, key): - - if isinstance(key, ecdsa.SigningKey) or \ - isinstance(key, ecdsa.VerifyingKey): - return key - - if isinstance(key, string_types): - if isinstance(key, text_type): - key = key.encode('utf-8') - - # Attempt to load key. We don't know if it's - # a Signing Key or a Verifying Key, so we try - # the Verifying Key first. - try: - key = ecdsa.VerifyingKey.from_pem(key) - except ecdsa.der.UnexpectedDER: - key = ecdsa.SigningKey.from_pem(key) - - else: - raise TypeError('Expecting a PEM-formatted key.') - - return key - - def sign(self, msg, key): - return key.sign(msg, hashfunc=self.hash_alg, - sigencode=ecdsa.util.sigencode_string) - - def verify(self, msg, key, sig): - try: - return key.verify(sig, msg, hashfunc=self.hash_alg, - sigdecode=ecdsa.util.sigdecode_string) - except AssertionError: - return False diff --git a/lib/jwt/contrib/algorithms/pycrypto.py b/lib/jwt/contrib/algorithms/pycrypto.py deleted file mode 100644 index e6afaa59..00000000 --- a/lib/jwt/contrib/algorithms/pycrypto.py +++ /dev/null @@ -1,47 +0,0 @@ -import Crypto.Hash.SHA256 -import Crypto.Hash.SHA384 -import Crypto.Hash.SHA512 - -from Crypto.PublicKey import RSA -from Crypto.Signature import PKCS1_v1_5 - -from jwt.algorithms import Algorithm -from jwt.compat import string_types, text_type - - -class RSAAlgorithm(Algorithm): - """ - Performs signing and verification operations using - RSASSA-PKCS-v1_5 and the specified hash function. - - This class requires PyCrypto package to be installed. - - This is based off of the implementation in PyJWT 0.3.2 - """ - SHA256 = Crypto.Hash.SHA256 - SHA384 = Crypto.Hash.SHA384 - SHA512 = Crypto.Hash.SHA512 - - def __init__(self, hash_alg): - self.hash_alg = hash_alg - - def prepare_key(self, key): - - if isinstance(key, RSA._RSAobj): - return key - - if isinstance(key, string_types): - if isinstance(key, text_type): - key = key.encode('utf-8') - - key = RSA.importKey(key) - else: - raise TypeError('Expecting a PEM- or RSA-formatted key.') - - return key - - def sign(self, msg, key): - return PKCS1_v1_5.new(key).sign(self.hash_alg.new(msg)) - - def verify(self, msg, key, sig): - return PKCS1_v1_5.new(key).verify(self.hash_alg.new(msg), sig) diff --git a/lib/jwt/exceptions.py b/lib/jwt/exceptions.py index 31177a0a..308899aa 100644 --- a/lib/jwt/exceptions.py +++ b/lib/jwt/exceptions.py @@ -1,4 +1,12 @@ -class InvalidTokenError(Exception): +class PyJWTError(Exception): + """ + Base class for all exceptions + """ + + pass + + +class InvalidTokenError(PyJWTError): pass @@ -6,6 +14,10 @@ class DecodeError(InvalidTokenError): pass +class InvalidSignatureError(DecodeError): + pass + + class ExpiredSignatureError(InvalidTokenError): pass @@ -26,7 +38,7 @@ class ImmatureSignatureError(InvalidTokenError): pass -class InvalidKeyError(Exception): +class InvalidKeyError(PyJWTError): pass @@ -42,7 +54,13 @@ class MissingRequiredClaimError(InvalidTokenError): return 'Token is missing the "%s" claim' % self.claim -# Compatibility aliases (deprecated) -ExpiredSignature = ExpiredSignatureError -InvalidAudience = InvalidAudienceError -InvalidIssuer = InvalidIssuerError +class PyJWKError(PyJWTError): + pass + + +class PyJWKSetError(PyJWTError): + pass + + +class PyJWKClientError(PyJWTError): + pass diff --git a/lib/jwt/help.py b/lib/jwt/help.py new file mode 100644 index 00000000..d8f23024 --- /dev/null +++ b/lib/jwt/help.py @@ -0,0 +1,60 @@ +import json +import platform +import sys + +from . import __version__ as pyjwt_version + +try: + import cryptography +except ModuleNotFoundError: + cryptography = None # type: ignore + + +def info(): + """ + Generate information for a bug report. + Based on the requests package help utility module. + """ + try: + platform_info = { + "system": platform.system(), + "release": platform.release(), + } + except OSError: + platform_info = {"system": "Unknown", "release": "Unknown"} + + implementation = platform.python_implementation() + + if implementation == "CPython": + implementation_version = platform.python_version() + elif implementation == "PyPy": + implementation_version = "{}.{}.{}".format( + sys.pypy_version_info.major, + sys.pypy_version_info.minor, + sys.pypy_version_info.micro, + ) + if sys.pypy_version_info.releaselevel != "final": + implementation_version = "".join( + [implementation_version, sys.pypy_version_info.releaselevel] + ) + else: + implementation_version = "Unknown" + + return { + "platform": platform_info, + "implementation": { + "name": implementation, + "version": implementation_version, + }, + "cryptography": {"version": getattr(cryptography, "__version__", "")}, + "pyjwt": {"version": pyjwt_version}, + } + + +def main(): + """Pretty-print the bug information as JSON.""" + print(json.dumps(info(), sort_keys=True, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/lib/jwt/jwks_client.py b/lib/jwt/jwks_client.py new file mode 100644 index 00000000..767b7179 --- /dev/null +++ b/lib/jwt/jwks_client.py @@ -0,0 +1,59 @@ +import json +import urllib.request +from functools import lru_cache +from typing import Any, List + +from .api_jwk import PyJWK, PyJWKSet +from .api_jwt import decode_complete as decode_token +from .exceptions import PyJWKClientError + + +class PyJWKClient: + def __init__(self, uri: str, cache_keys: bool = True, max_cached_keys: int = 16): + self.uri = uri + if cache_keys: + # Cache signing keys + # Ignore mypy (https://github.com/python/mypy/issues/2427) + self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore + + def fetch_data(self) -> Any: + with urllib.request.urlopen(self.uri) as response: + return json.load(response) + + def get_jwk_set(self) -> PyJWKSet: + data = self.fetch_data() + return PyJWKSet.from_dict(data) + + def get_signing_keys(self) -> List[PyJWK]: + jwk_set = self.get_jwk_set() + signing_keys = [ + jwk_set_key + for jwk_set_key in jwk_set.keys + if jwk_set_key.public_key_use in ["sig", None] and jwk_set_key.key_id + ] + + if not signing_keys: + raise PyJWKClientError("The JWKS endpoint did not contain any signing keys") + + return signing_keys + + def get_signing_key(self, kid: str) -> PyJWK: + signing_keys = self.get_signing_keys() + signing_key = None + + for key in signing_keys: + if key.key_id == kid: + signing_key = key + break + + if not signing_key: + raise PyJWKClientError( + f'Unable to find a signing key that matches: "{kid}"' + ) + + return signing_key + + def get_signing_key_from_jwt(self, token: str) -> PyJWK: + unverified = decode_token(token, options={"verify_signature": False}) + header = unverified["header"] + return self.get_signing_key(header.get("kid")) diff --git a/lib/jwt/contrib/__init__.py b/lib/jwt/py.typed similarity index 100% rename from lib/jwt/contrib/__init__.py rename to lib/jwt/py.typed diff --git a/lib/jwt/utils.py b/lib/jwt/utils.py index 637b8929..9dde10cf 100644 --- a/lib/jwt/utils.py +++ b/lib/jwt/utils.py @@ -1,67 +1,99 @@ import base64 import binascii +from typing import Any, Union try: + from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve from cryptography.hazmat.primitives.asymmetric.utils import ( - decode_rfc6979_signature, encode_rfc6979_signature + decode_dss_signature, + encode_dss_signature, ) -except ImportError: - pass +except ModuleNotFoundError: + EllipticCurve = Any # type: ignore -def base64url_decode(input): +def force_bytes(value: Union[str, bytes]) -> bytes: + if isinstance(value, str): + return value.encode("utf-8") + elif isinstance(value, bytes): + return value + else: + raise TypeError("Expected a string value") + + +def base64url_decode(input: Union[str, bytes]) -> bytes: + if isinstance(input, str): + input = input.encode("ascii") + rem = len(input) % 4 if rem > 0: - input += b'=' * (4 - rem) + input += b"=" * (4 - rem) return base64.urlsafe_b64decode(input) -def base64url_encode(input): - return base64.urlsafe_b64encode(input).replace(b'=', b'') +def base64url_encode(input: bytes) -> bytes: + return base64.urlsafe_b64encode(input).replace(b"=", b"") -def merge_dict(original, updates): - if not updates: - return original +def to_base64url_uint(val: int) -> bytes: + if val < 0: + raise ValueError("Must be a positive integer") - try: - merged_options = original.copy() - merged_options.update(updates) - except (AttributeError, ValueError) as e: - raise TypeError('original and updates must be a dictionary: %s' % e) + int_bytes = bytes_from_int(val) - return merged_options + if len(int_bytes) == 0: + int_bytes = b"\x00" + + return base64url_encode(int_bytes) -def number_to_bytes(num, num_bytes): - padded_hex = '%0*x' % (2 * num_bytes, num) - big_endian = binascii.a2b_hex(padded_hex.encode('ascii')) - return big_endian +def from_base64url_uint(val: Union[str, bytes]) -> int: + if isinstance(val, str): + val = val.encode("ascii") + + data = base64url_decode(val) + return int.from_bytes(data, byteorder="big") -def bytes_to_number(string): +def number_to_bytes(num: int, num_bytes: int) -> bytes: + padded_hex = "%0*x" % (2 * num_bytes, num) + return binascii.a2b_hex(padded_hex.encode("ascii")) + + +def bytes_to_number(string: bytes) -> int: return int(binascii.b2a_hex(string), 16) -def der_to_raw_signature(der_sig, curve): +def bytes_from_int(val: int) -> bytes: + remaining = val + byte_length = 0 + + while remaining != 0: + remaining >>= 8 + byte_length += 1 + + return val.to_bytes(byte_length, "big", signed=False) + + +def der_to_raw_signature(der_sig: bytes, curve: EllipticCurve) -> bytes: num_bits = curve.key_size num_bytes = (num_bits + 7) // 8 - r, s = decode_rfc6979_signature(der_sig) + r, s = decode_dss_signature(der_sig) return number_to_bytes(r, num_bytes) + number_to_bytes(s, num_bytes) -def raw_to_der_signature(raw_sig, curve): +def raw_to_der_signature(raw_sig: bytes, curve: EllipticCurve) -> bytes: num_bits = curve.key_size num_bytes = (num_bits + 7) // 8 if len(raw_sig) != 2 * num_bytes: - raise ValueError('Invalid signature') + raise ValueError("Invalid signature") r = bytes_to_number(raw_sig[:num_bytes]) s = bytes_to_number(raw_sig[num_bytes:]) - return encode_rfc6979_signature(r, s) + return encode_dss_signature(r, s)