Merge branch 'nightly' into dependabot/pip/nightly/tokenize-rt-6.0.0

This commit is contained in:
JonnyWong16 2024-08-10 19:19:20 -07:00 committed by GitHub
commit a1ee20a7c5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 705 additions and 202 deletions

View file

@ -38,7 +38,7 @@ CL_BLANK = "data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAA
URI_SCHEME = "cloudinary" URI_SCHEME = "cloudinary"
API_VERSION = "v1_1" API_VERSION = "v1_1"
VERSION = "1.40.0" VERSION = "1.41.0"
_USER_PLATFORM_DETAILS = "; ".join((platform(), "Python {}".format(python_version()))) _USER_PLATFORM_DETAILS = "; ".join((platform(), "Python {}".format(python_version())))

View file

@ -543,10 +543,6 @@ def create_upload_preset(**options):
return call_api("post", uri, params, **options) return call_api("post", uri, params, **options)
def create_folder(path, **options):
return call_api("post", ["folders", path], {}, **options)
def root_folders(**options): def root_folders(**options):
return call_api("get", ["folders"], only(options, "next_cursor", "max_results"), **options) return call_api("get", ["folders"], only(options, "next_cursor", "max_results"), **options)
@ -555,6 +551,24 @@ def subfolders(of_folder_path, **options):
return call_api("get", ["folders", of_folder_path], only(options, "next_cursor", "max_results"), **options) return call_api("get", ["folders", of_folder_path], only(options, "next_cursor", "max_results"), **options)
def create_folder(path, **options):
return call_api("post", ["folders", path], {}, **options)
def rename_folder(from_path, to_path, **options):
"""
Renames folder
:param from_path: The full path of an existing asset folder.
:param to_path: The full path of the new asset folder.
:param options: Additional options
:rtype: Response
"""
params = {"to_folder": to_path}
return call_api("put", ["folders", from_path], params, **options)
def delete_folder(path, **options): def delete_folder(path, **options):
"""Deletes folder """Deletes folder
@ -727,7 +741,7 @@ def update_metadata_field(field_external_id, field, **options):
def __metadata_field_params(field): def __metadata_field_params(field):
return only(field, "type", "external_id", "label", "mandatory", "restrictions", return only(field, "type", "external_id", "label", "mandatory", "restrictions",
"default_value", "validation", "datasource") "default_value", "default_disabled", "validation", "datasource")
def delete_metadata_field(field_external_id, **options): def delete_metadata_field(field_external_id, **options):

View file

@ -11,7 +11,7 @@ AUTH_TOKEN_UNSAFE_RE = r'([ "#%&\'\/:;<=>?@\[\\\]^`{\|}~]+)'
def generate(url=None, acl=None, start_time=None, duration=None, def generate(url=None, acl=None, start_time=None, duration=None,
expiration=None, ip=None, key=None, token_name=AUTH_TOKEN_NAME): expiration=None, ip=None, key=None, token_name=AUTH_TOKEN_NAME, **_):
if expiration is None: if expiration is None:
if duration is not None: if duration is not None:

View file

@ -820,7 +820,7 @@ def cloudinary_url(source, **options):
transformation = re.sub(r'([^:])/+', r'\1/', transformation) transformation = re.sub(r'([^:])/+', r'\1/', transformation)
signature = None signature = None
if sign_url and not auth_token: if sign_url and (not auth_token or auth_token.pop('set_url_signature', False)):
to_sign = "/".join(__compact([transformation, source_to_sign])) to_sign = "/".join(__compact([transformation, source_to_sign]))
if long_url_signature: if long_url_signature:
# Long signature forces SHA256 # Long signature forces SHA256

View file

@ -25,7 +25,7 @@ from ._compat import (
install, install,
) )
from ._functools import method_cache, pass_none from ._functools import method_cache, pass_none
from ._itertools import always_iterable, unique_everseen from ._itertools import always_iterable, bucket, unique_everseen
from ._meta import PackageMetadata, SimplePath from ._meta import PackageMetadata, SimplePath
from contextlib import suppress from contextlib import suppress
@ -39,6 +39,7 @@ __all__ = [
'DistributionFinder', 'DistributionFinder',
'PackageMetadata', 'PackageMetadata',
'PackageNotFoundError', 'PackageNotFoundError',
'SimplePath',
'distribution', 'distribution',
'distributions', 'distributions',
'entry_points', 'entry_points',
@ -388,7 +389,7 @@ class Distribution(metaclass=abc.ABCMeta):
if not name: if not name:
raise ValueError("A distribution name is required.") raise ValueError("A distribution name is required.")
try: try:
return next(iter(cls.discover(name=name))) return next(iter(cls._prefer_valid(cls.discover(name=name))))
except StopIteration: except StopIteration:
raise PackageNotFoundError(name) raise PackageNotFoundError(name)
@ -412,6 +413,16 @@ class Distribution(metaclass=abc.ABCMeta):
resolver(context) for resolver in cls._discover_resolvers() resolver(context) for resolver in cls._discover_resolvers()
) )
@staticmethod
def _prefer_valid(dists: Iterable[Distribution]) -> Iterable[Distribution]:
"""
Prefer (move to the front) distributions that have metadata.
Ref python/importlib_resources#489.
"""
buckets = bucket(dists, lambda dist: bool(dist.metadata))
return itertools.chain(buckets[True], buckets[False])
@staticmethod @staticmethod
def at(path: str | os.PathLike[str]) -> Distribution: def at(path: str | os.PathLike[str]) -> Distribution:
"""Return a Distribution for the indicated metadata path. """Return a Distribution for the indicated metadata path.

View file

@ -1,3 +1,4 @@
from collections import defaultdict, deque
from itertools import filterfalse from itertools import filterfalse
@ -71,3 +72,100 @@ def always_iterable(obj, base_type=(str, bytes)):
return iter(obj) return iter(obj)
except TypeError: except TypeError:
return iter((obj,)) return iter((obj,))
# Copied from more_itertools 10.3
class bucket:
"""Wrap *iterable* and return an object that buckets the iterable into
child iterables based on a *key* function.
>>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
>>> s = bucket(iterable, key=lambda x: x[0]) # Bucket by 1st character
>>> sorted(list(s)) # Get the keys
['a', 'b', 'c']
>>> a_iterable = s['a']
>>> next(a_iterable)
'a1'
>>> next(a_iterable)
'a2'
>>> list(s['b'])
['b1', 'b2', 'b3']
The original iterable will be advanced and its items will be cached until
they are used by the child iterables. This may require significant storage.
By default, attempting to select a bucket to which no items belong will
exhaust the iterable and cache all values.
If you specify a *validator* function, selected buckets will instead be
checked against it.
>>> from itertools import count
>>> it = count(1, 2) # Infinite sequence of odd numbers
>>> key = lambda x: x % 10 # Bucket by last digit
>>> validator = lambda x: x in {1, 3, 5, 7, 9} # Odd digits only
>>> s = bucket(it, key=key, validator=validator)
>>> 2 in s
False
>>> list(s[2])
[]
"""
def __init__(self, iterable, key, validator=None):
self._it = iter(iterable)
self._key = key
self._cache = defaultdict(deque)
self._validator = validator or (lambda x: True)
def __contains__(self, value):
if not self._validator(value):
return False
try:
item = next(self[value])
except StopIteration:
return False
else:
self._cache[value].appendleft(item)
return True
def _get_values(self, value):
"""
Helper to yield items from the parent iterator that match *value*.
Items that don't match are stored in the local cache as they
are encountered.
"""
while True:
# If we've cached some items that match the target value, emit
# the first one and evict it from the cache.
if self._cache[value]:
yield self._cache[value].popleft()
# Otherwise we need to advance the parent iterator to search for
# a matching item, caching the rest.
else:
while True:
try:
item = next(self._it)
except StopIteration:
return
item_value = self._key(item)
if item_value == value:
yield item
break
elif self._validator(item_value):
self._cache[item_value].append(item)
def __iter__(self):
for item in self._it:
item_value = self._key(item)
if self._validator(item_value):
self._cache[item_value].append(item)
yield from self._cache.keys()
def __getitem__(self, value):
if not self._validator(value):
return iter(())
return self._get_values(value)

View file

@ -14,6 +14,14 @@ def compose(*funcs):
""" """
Compose any number of unary functions into a single unary function. Compose any number of unary functions into a single unary function.
Comparable to
`function composition <https://en.wikipedia.org/wiki/Function_composition>`_
in mathematics:
``h = g f`` implies ``h(x) = g(f(x))``.
In Python, ``h = compose(g, f)``.
>>> import textwrap >>> import textwrap
>>> expected = str.strip(textwrap.dedent(compose.__doc__)) >>> expected = str.strip(textwrap.dedent(compose.__doc__))
>>> strip_and_dedent = compose(str.strip, textwrap.dedent) >>> strip_and_dedent = compose(str.strip, textwrap.dedent)

View file

@ -27,7 +27,7 @@ from .exceptions import (
) )
from .jwks_client import PyJWKClient from .jwks_client import PyJWKClient
__version__ = "2.8.0" __version__ = "2.9.0"
__title__ = "PyJWT" __title__ = "PyJWT"
__description__ = "JSON Web Token implementation in Python" __description__ = "JSON Web Token implementation in Python"

View file

@ -3,9 +3,8 @@ from __future__ import annotations
import hashlib import hashlib
import hmac import hmac
import json import json
import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, Union, cast, overload from typing import TYPE_CHECKING, Any, ClassVar, Literal, NoReturn, cast, overload
from .exceptions import InvalidKeyError from .exceptions import InvalidKeyError
from .types import HashlibHash, JWKDict from .types import HashlibHash, JWKDict
@ -21,14 +20,8 @@ from .utils import (
to_base64url_uint, to_base64url_uint,
) )
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
try: try:
from cryptography.exceptions import InvalidSignature from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding from cryptography.hazmat.primitives.asymmetric import padding
@ -194,18 +187,16 @@ class Algorithm(ABC):
@overload @overload
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: ... # pragma: no cover
... # pragma: no cover
@overload @overload
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: ... # pragma: no cover
... # pragma: no cover
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def to_jwk(key_obj, as_dict: bool = False) -> Union[JWKDict, str]: def to_jwk(key_obj, as_dict: bool = False) -> JWKDict | str:
""" """
Serializes a given key into a JWK Serializes a given key into a JWK
""" """
@ -274,16 +265,18 @@ class HMACAlgorithm(Algorithm):
@overload @overload
@staticmethod @staticmethod
def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict: def to_jwk(
... # pragma: no cover key_obj: str | bytes, as_dict: Literal[True]
) -> JWKDict: ... # pragma: no cover
@overload @overload
@staticmethod @staticmethod
def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str: def to_jwk(
... # pragma: no cover key_obj: str | bytes, as_dict: Literal[False] = False
) -> str: ... # pragma: no cover
@staticmethod @staticmethod
def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> Union[JWKDict, str]: def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str:
jwk = { jwk = {
"k": base64url_encode(force_bytes(key_obj)).decode(), "k": base64url_encode(force_bytes(key_obj)).decode(),
"kty": "oct", "kty": "oct",
@ -350,22 +343,25 @@ if has_crypto:
RSAPrivateKey, load_pem_private_key(key_bytes, password=None) RSAPrivateKey, load_pem_private_key(key_bytes, password=None)
) )
except ValueError: except ValueError:
try:
return cast(RSAPublicKey, load_pem_public_key(key_bytes)) return cast(RSAPublicKey, load_pem_public_key(key_bytes))
except (ValueError, UnsupportedAlgorithm):
raise InvalidKeyError("Could not parse the provided public key.")
@overload @overload
@staticmethod
def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict:
... # pragma: no cover
@overload
@staticmethod
def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str:
... # pragma: no cover
@staticmethod @staticmethod
def to_jwk( def to_jwk(
key_obj: AllowedRSAKeys, as_dict: bool = False key_obj: AllowedRSAKeys, as_dict: Literal[True]
) -> Union[JWKDict, str]: ) -> JWKDict: ... # pragma: no cover
@overload
@staticmethod
def to_jwk(
key_obj: AllowedRSAKeys, as_dict: Literal[False] = False
) -> str: ... # pragma: no cover
@staticmethod
def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str:
obj: dict[str, Any] | None = None obj: dict[str, Any] | None = None
if hasattr(key_obj, "private_numbers"): if hasattr(key_obj, "private_numbers"):
@ -533,7 +529,7 @@ if has_crypto:
return der_to_raw_signature(der_sig, key.curve) return der_to_raw_signature(der_sig, key.curve)
def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool: def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool:
try: try:
der_sig = raw_to_der_signature(sig, key.curve) der_sig = raw_to_der_signature(sig, key.curve)
except ValueError: except ValueError:
@ -552,18 +548,18 @@ if has_crypto:
@overload @overload
@staticmethod @staticmethod
def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict: def to_jwk(
... # pragma: no cover key_obj: AllowedECKeys, as_dict: Literal[True]
) -> JWKDict: ... # pragma: no cover
@overload @overload
@staticmethod @staticmethod
def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str: def to_jwk(
... # pragma: no cover key_obj: AllowedECKeys, as_dict: Literal[False] = False
) -> str: ... # pragma: no cover
@staticmethod @staticmethod
def to_jwk( def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str:
key_obj: AllowedECKeys, as_dict: bool = False
) -> Union[JWKDict, str]:
if isinstance(key_obj, EllipticCurvePrivateKey): if isinstance(key_obj, EllipticCurvePrivateKey):
public_numbers = key_obj.public_key().public_numbers() public_numbers = key_obj.public_key().public_numbers()
elif isinstance(key_obj, EllipticCurvePublicKey): elif isinstance(key_obj, EllipticCurvePublicKey):
@ -771,16 +767,18 @@ if has_crypto:
@overload @overload
@staticmethod @staticmethod
def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict: def to_jwk(
... # pragma: no cover key: AllowedOKPKeys, as_dict: Literal[True]
) -> JWKDict: ... # pragma: no cover
@overload @overload
@staticmethod @staticmethod
def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str: def to_jwk(
... # pragma: no cover key: AllowedOKPKeys, as_dict: Literal[False] = False
) -> str: ... # pragma: no cover
@staticmethod @staticmethod
def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> Union[JWKDict, str]: def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str:
if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)): if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)):
x = key.public_bytes( x = key.public_bytes(
encoding=Encoding.Raw, encoding=Encoding.Raw,

View file

@ -5,7 +5,13 @@ import time
from typing import Any from typing import Any
from .algorithms import get_default_algorithms, has_crypto, requires_cryptography from .algorithms import get_default_algorithms, has_crypto, requires_cryptography
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError, PyJWTError from .exceptions import (
InvalidKeyError,
MissingCryptographyError,
PyJWKError,
PyJWKSetError,
PyJWTError,
)
from .types import JWKDict from .types import JWKDict
@ -50,21 +56,25 @@ class PyJWK:
raise InvalidKeyError(f"Unsupported kty: {kty}") raise InvalidKeyError(f"Unsupported kty: {kty}")
if not has_crypto and algorithm in requires_cryptography: if not has_crypto and algorithm in requires_cryptography:
raise PyJWKError(f"{algorithm} requires 'cryptography' to be installed.") raise MissingCryptographyError(
f"{algorithm} requires 'cryptography' to be installed."
)
self.Algorithm = self._algorithms.get(algorithm) self.algorithm_name = algorithm
if not self.Algorithm: if algorithm in self._algorithms:
self.Algorithm = self._algorithms[algorithm]
else:
raise PyJWKError(f"Unable to find an algorithm for key: {self._jwk_data}") raise PyJWKError(f"Unable to find an algorithm for key: {self._jwk_data}")
self.key = self.Algorithm.from_jwk(self._jwk_data) self.key = self.Algorithm.from_jwk(self._jwk_data)
@staticmethod @staticmethod
def from_dict(obj: JWKDict, algorithm: str | None = None) -> "PyJWK": def from_dict(obj: JWKDict, algorithm: str | None = None) -> PyJWK:
return PyJWK(obj, algorithm) return PyJWK(obj, algorithm)
@staticmethod @staticmethod
def from_json(data: str, algorithm: None = None) -> "PyJWK": def from_json(data: str, algorithm: None = None) -> PyJWK:
obj = json.loads(data) obj = json.loads(data)
return PyJWK.from_dict(obj, algorithm) return PyJWK.from_dict(obj, algorithm)
@ -94,7 +104,9 @@ class PyJWKSet:
for key in keys: for key in keys:
try: try:
self.keys.append(PyJWK(key)) self.keys.append(PyJWK(key))
except PyJWTError: except PyJWTError as error:
if isinstance(error, MissingCryptographyError):
raise error
# skip unusable keys # skip unusable keys
continue continue
@ -104,16 +116,16 @@ class PyJWKSet:
) )
@staticmethod @staticmethod
def from_dict(obj: dict[str, Any]) -> "PyJWKSet": def from_dict(obj: dict[str, Any]) -> PyJWKSet:
keys = obj.get("keys", []) keys = obj.get("keys", [])
return PyJWKSet(keys) return PyJWKSet(keys)
@staticmethod @staticmethod
def from_json(data: str) -> "PyJWKSet": def from_json(data: str) -> PyJWKSet:
obj = json.loads(data) obj = json.loads(data)
return PyJWKSet.from_dict(obj) return PyJWKSet.from_dict(obj)
def __getitem__(self, kid: str) -> "PyJWK": def __getitem__(self, kid: str) -> PyJWK:
for key in self.keys: for key in self.keys:
if key.key_id == kid: if key.key_id == kid:
return key return key

View file

@ -11,6 +11,7 @@ from .algorithms import (
has_crypto, has_crypto,
requires_cryptography, requires_cryptography,
) )
from .api_jwk import PyJWK
from .exceptions import ( from .exceptions import (
DecodeError, DecodeError,
InvalidAlgorithmError, InvalidAlgorithmError,
@ -172,7 +173,7 @@ class PyJWS:
def decode_complete( def decode_complete(
self, self,
jwt: str | bytes, jwt: str | bytes,
key: AllowedPublicKeys | str | bytes = "", key: AllowedPublicKeys | PyJWK | str | bytes = "",
algorithms: list[str] | None = None, algorithms: list[str] | None = None,
options: dict[str, Any] | None = None, options: dict[str, Any] | None = None,
detached_payload: bytes | None = None, detached_payload: bytes | None = None,
@ -190,7 +191,7 @@ class PyJWS:
merged_options = {**self.options, **options} merged_options = {**self.options, **options}
verify_signature = merged_options["verify_signature"] verify_signature = merged_options["verify_signature"]
if verify_signature and not algorithms: if verify_signature and not algorithms and not isinstance(key, PyJWK):
raise DecodeError( raise DecodeError(
'It is required that you pass in a value for the "algorithms" argument when calling decode().' 'It is required that you pass in a value for the "algorithms" argument when calling decode().'
) )
@ -217,7 +218,7 @@ class PyJWS:
def decode( def decode(
self, self,
jwt: str | bytes, jwt: str | bytes,
key: AllowedPublicKeys | str | bytes = "", key: AllowedPublicKeys | PyJWK | str | bytes = "",
algorithms: list[str] | None = None, algorithms: list[str] | None = None,
options: dict[str, Any] | None = None, options: dict[str, Any] | None = None,
detached_payload: bytes | None = None, detached_payload: bytes | None = None,
@ -289,9 +290,11 @@ class PyJWS:
signing_input: bytes, signing_input: bytes,
header: dict[str, Any], header: dict[str, Any],
signature: bytes, signature: bytes,
key: AllowedPublicKeys | str | bytes = "", key: AllowedPublicKeys | PyJWK | str | bytes = "",
algorithms: list[str] | None = None, algorithms: list[str] | None = None,
) -> None: ) -> None:
if algorithms is None and isinstance(key, PyJWK):
algorithms = [key.algorithm_name]
try: try:
alg = header["alg"] alg = header["alg"]
except KeyError: except KeyError:
@ -300,6 +303,10 @@ class PyJWS:
if not alg or (algorithms is not None and alg not in algorithms): if not alg or (algorithms is not None and alg not in algorithms):
raise InvalidAlgorithmError("The specified alg value is not allowed") raise InvalidAlgorithmError("The specified alg value is not allowed")
if isinstance(key, PyJWK):
alg_obj = key.Algorithm
prepared_key = key.key
else:
try: try:
alg_obj = self.get_algorithm_by_name(alg) alg_obj = self.get_algorithm_by_name(alg)
except NotImplementedError as e: except NotImplementedError as e:

View file

@ -5,7 +5,7 @@ import warnings
from calendar import timegm from calendar import timegm
from collections.abc import Iterable from collections.abc import Iterable
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, List
from . import api_jws from . import api_jws
from .exceptions import ( from .exceptions import (
@ -21,6 +21,7 @@ from .warnings import RemovedInPyjwt3Warning
if TYPE_CHECKING: if TYPE_CHECKING:
from .algorithms import AllowedPrivateKeys, AllowedPublicKeys from .algorithms import AllowedPrivateKeys, AllowedPublicKeys
from .api_jwk import PyJWK
class PyJWT: class PyJWT:
@ -100,7 +101,7 @@ class PyJWT:
def decode_complete( def decode_complete(
self, self,
jwt: str | bytes, jwt: str | bytes,
key: AllowedPublicKeys | str | bytes = "", key: AllowedPublicKeys | PyJWK | str | bytes = "",
algorithms: list[str] | None = None, algorithms: list[str] | None = None,
options: dict[str, Any] | None = None, options: dict[str, Any] | None = None,
# deprecated arg, remove in pyjwt3 # deprecated arg, remove in pyjwt3
@ -110,7 +111,7 @@ class PyJWT:
# passthrough arguments to _validate_claims # passthrough arguments to _validate_claims
# consider putting in options # consider putting in options
audience: str | Iterable[str] | None = None, audience: str | Iterable[str] | None = None,
issuer: str | None = None, issuer: str | List[str] | None = None,
leeway: float | timedelta = 0, leeway: float | timedelta = 0,
# kwargs # kwargs
**kwargs: Any, **kwargs: Any,
@ -185,7 +186,7 @@ class PyJWT:
def decode( def decode(
self, self,
jwt: str | bytes, jwt: str | bytes,
key: AllowedPublicKeys | str | bytes = "", key: AllowedPublicKeys | PyJWK | str | bytes = "",
algorithms: list[str] | None = None, algorithms: list[str] | None = None,
options: dict[str, Any] | None = None, options: dict[str, Any] | None = None,
# deprecated arg, remove in pyjwt3 # deprecated arg, remove in pyjwt3
@ -195,7 +196,7 @@ class PyJWT:
# passthrough arguments to _validate_claims # passthrough arguments to _validate_claims
# consider putting in options # consider putting in options
audience: str | Iterable[str] | None = None, audience: str | Iterable[str] | None = None,
issuer: str | None = None, issuer: str | List[str] | None = None,
leeway: float | timedelta = 0, leeway: float | timedelta = 0,
# kwargs # kwargs
**kwargs: Any, **kwargs: Any,
@ -300,7 +301,7 @@ class PyJWT:
try: try:
exp = int(payload["exp"]) exp = int(payload["exp"])
except ValueError: except ValueError:
raise DecodeError("Expiration Time claim (exp) must be an" " integer.") raise DecodeError("Expiration Time claim (exp) must be an integer.")
if exp <= (now - leeway): if exp <= (now - leeway):
raise ExpiredSignatureError("Signature has expired") raise ExpiredSignatureError("Signature has expired")
@ -362,6 +363,10 @@ class PyJWT:
if "iss" not in payload: if "iss" not in payload:
raise MissingRequiredClaimError("iss") raise MissingRequiredClaimError("iss")
if isinstance(issuer, list):
if payload["iss"] not in issuer:
raise InvalidIssuerError("Invalid issuer")
else:
if payload["iss"] != issuer: if payload["iss"] != issuer:
raise InvalidIssuerError("Invalid issuer") raise InvalidIssuerError("Invalid issuer")

View file

@ -58,6 +58,10 @@ class PyJWKError(PyJWTError):
pass pass
class MissingCryptographyError(PyJWKError):
pass
class PyJWKSetError(PyJWTError): class PyJWKSetError(PyJWTError):
pass pass

View file

@ -131,26 +131,15 @@ def is_pem_format(key: bytes) -> bool:
# Based on https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b/src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46 # Based on https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b/src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46
_CERT_SUFFIX = b"-cert-v01@openssh.com" _SSH_KEY_FORMATS = (
_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)")
_SSH_KEY_FORMATS = [
b"ssh-ed25519", b"ssh-ed25519",
b"ssh-rsa", b"ssh-rsa",
b"ssh-dss", b"ssh-dss",
b"ecdsa-sha2-nistp256", b"ecdsa-sha2-nistp256",
b"ecdsa-sha2-nistp384", b"ecdsa-sha2-nistp384",
b"ecdsa-sha2-nistp521", b"ecdsa-sha2-nistp521",
] )
def is_ssh_key(key: bytes) -> bool: def is_ssh_key(key: bytes) -> bool:
if any(string_value in key for string_value in _SSH_KEY_FORMATS): return key.startswith(_SSH_KEY_FORMATS)
return True
ssh_pubkey_match = _SSH_PUBKEY_RC.match(key)
if ssh_pubkey_match:
key_type = ssh_pubkey_match.group(1)
if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]:
return True
return False

View file

@ -3,4 +3,4 @@
from .more import * # noqa from .more import * # noqa
from .recipes import * # noqa from .recipes import * # noqa
__version__ = '10.3.0' __version__ = '10.4.0'

View file

@ -3,8 +3,9 @@ import warnings
from collections import Counter, defaultdict, deque, abc from collections import Counter, defaultdict, deque, abc
from collections.abc import Sequence from collections.abc import Sequence
from contextlib import suppress
from functools import cached_property, partial, reduce, wraps from functools import cached_property, partial, reduce, wraps
from heapq import heapify, heapreplace, heappop from heapq import heapify, heapreplace
from itertools import ( from itertools import (
chain, chain,
combinations, combinations,
@ -21,10 +22,10 @@ from itertools import (
zip_longest, zip_longest,
product, product,
) )
from math import comb, e, exp, factorial, floor, fsum, log, perm, tau from math import comb, e, exp, factorial, floor, fsum, log, log1p, perm, tau
from queue import Empty, Queue from queue import Empty, Queue
from random import random, randrange, uniform from random import random, randrange, shuffle, uniform
from operator import itemgetter, mul, sub, gt, lt, ge, le from operator import itemgetter, mul, sub, gt, lt, le
from sys import hexversion, maxsize from sys import hexversion, maxsize
from time import monotonic from time import monotonic
@ -34,7 +35,6 @@ from .recipes import (
UnequalIterablesError, UnequalIterablesError,
consume, consume,
flatten, flatten,
pairwise,
powerset, powerset,
take, take,
unique_everseen, unique_everseen,
@ -473,12 +473,10 @@ def ilen(iterable):
This consumes the iterable, so handle with care. This consumes the iterable, so handle with care.
""" """
# This approach was selected because benchmarks showed it's likely the # This is the "most beautiful of the fast variants" of this function.
# fastest of the known implementations at the time of writing. # If you think you can improve on it, please ensure that your version
# See GitHub tracker: #236, #230. # is both 10x faster and 10x more beautiful.
counter = count() return sum(compress(repeat(1), zip(iterable)))
deque(zip(iterable, counter), maxlen=0)
return next(counter)
def iterate(func, start): def iterate(func, start):
@ -666,9 +664,9 @@ def distinct_permutations(iterable, r=None):
>>> sorted(distinct_permutations([1, 0, 1])) >>> sorted(distinct_permutations([1, 0, 1]))
[(0, 1, 1), (1, 0, 1), (1, 1, 0)] [(0, 1, 1), (1, 0, 1), (1, 1, 0)]
Equivalent to ``set(permutations(iterable))``, except duplicates are not Equivalent to yielding from ``set(permutations(iterable))``, except
generated and thrown away. For larger input sequences this is much more duplicates are not generated and thrown away. For larger input sequences
efficient. this is much more efficient.
Duplicate permutations arise when there are duplicated elements in the Duplicate permutations arise when there are duplicated elements in the
input iterable. The number of items returned is input iterable. The number of items returned is
@ -683,6 +681,25 @@ def distinct_permutations(iterable, r=None):
>>> sorted(distinct_permutations(range(3), r=2)) >>> sorted(distinct_permutations(range(3), r=2))
[(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)] [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
*iterable* need not be sortable, but note that using equal (``x == y``)
but non-identical (``id(x) != id(y)``) elements may produce surprising
behavior. For example, ``1`` and ``True`` are equal but non-identical:
>>> list(distinct_permutations([1, True, '3'])) # doctest: +SKIP
[
(1, True, '3'),
(1, '3', True),
('3', 1, True)
]
>>> list(distinct_permutations([1, 2, '3'])) # doctest: +SKIP
[
(1, 2, '3'),
(1, '3', 2),
(2, 1, '3'),
(2, '3', 1),
('3', 1, 2),
('3', 2, 1)
]
""" """
# Algorithm: https://w.wiki/Qai # Algorithm: https://w.wiki/Qai
@ -749,14 +766,44 @@ def distinct_permutations(iterable, r=None):
i += 1 i += 1
head[i:], tail[:] = tail[: r - i], tail[r - i :] head[i:], tail[:] = tail[: r - i], tail[r - i :]
items = sorted(iterable) items = list(iterable)
try:
items.sort()
sortable = True
except TypeError:
sortable = False
indices_dict = defaultdict(list)
for item in items:
indices_dict[items.index(item)].append(item)
indices = [items.index(item) for item in items]
indices.sort()
equivalent_items = {k: cycle(v) for k, v in indices_dict.items()}
def permuted_items(permuted_indices):
return tuple(
next(equivalent_items[index]) for index in permuted_indices
)
size = len(items) size = len(items)
if r is None: if r is None:
r = size r = size
# functools.partial(_partial, ... )
algorithm = _full if (r == size) else partial(_partial, r=r)
if 0 < r <= size: if 0 < r <= size:
return _full(items) if (r == size) else _partial(items, r) if sortable:
return algorithm(items)
else:
return (
permuted_items(permuted_indices)
for permuted_indices in algorithm(indices)
)
return iter(() if r else ((),)) return iter(() if r else ((),))
@ -1743,7 +1790,9 @@ def zip_offset(*iterables, offsets, longest=False, fillvalue=None):
return zip(*staggered) return zip(*staggered)
def sort_together(iterables, key_list=(0,), key=None, reverse=False): def sort_together(
iterables, key_list=(0,), key=None, reverse=False, strict=False
):
"""Return the input iterables sorted together, with *key_list* as the """Return the input iterables sorted together, with *key_list* as the
priority for sorting. All iterables are trimmed to the length of the priority for sorting. All iterables are trimmed to the length of the
shortest one. shortest one.
@ -1782,6 +1831,10 @@ def sort_together(iterables, key_list=(0,), key=None, reverse=False):
>>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True) >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True)
[(3, 2, 1), ('a', 'b', 'c')] [(3, 2, 1), ('a', 'b', 'c')]
If the *strict* keyword argument is ``True``, then
``UnequalIterablesError`` will be raised if any of the iterables have
different lengths.
""" """
if key is None: if key is None:
# if there is no key function, the key argument to sorted is an # if there is no key function, the key argument to sorted is an
@ -1804,8 +1857,9 @@ def sort_together(iterables, key_list=(0,), key=None, reverse=False):
*get_key_items(zipped_items) *get_key_items(zipped_items)
) )
zipper = zip_equal if strict else zip
return list( return list(
zip(*sorted(zip(*iterables), key=key_argument, reverse=reverse)) zipper(*sorted(zipper(*iterables), key=key_argument, reverse=reverse))
) )
@ -2747,8 +2801,6 @@ class seekable:
>>> it.seek(0) >>> it.seek(0)
>>> next(it), next(it), next(it) >>> next(it), next(it), next(it)
('0', '1', '2') ('0', '1', '2')
>>> next(it)
'3'
You can also seek forward: You can also seek forward:
@ -2756,15 +2808,29 @@ class seekable:
>>> it.seek(10) >>> it.seek(10)
>>> next(it) >>> next(it)
'10' '10'
>>> it.relative_seek(-2) # Seeking relative to the current position
>>> next(it)
'9'
>>> it.seek(20) # Seeking past the end of the source isn't a problem >>> it.seek(20) # Seeking past the end of the source isn't a problem
>>> list(it) >>> list(it)
[] []
>>> it.seek(0) # Resetting works even after hitting the end >>> it.seek(0) # Resetting works even after hitting the end
>>> next(it)
'0'
Call :meth:`relative_seek` to seek relative to the source iterator's
current position.
>>> it = seekable((str(n) for n in range(20)))
>>> next(it), next(it), next(it) >>> next(it), next(it), next(it)
('0', '1', '2') ('0', '1', '2')
>>> it.relative_seek(2)
>>> next(it)
'5'
>>> it.relative_seek(-3) # Source is at '6', we move back to '3'
>>> next(it)
'3'
>>> it.relative_seek(-3) # Source is at '4', we move back to '1'
>>> next(it)
'1'
Call :meth:`peek` to look ahead one item without advancing the iterator: Call :meth:`peek` to look ahead one item without advancing the iterator:
@ -2873,8 +2939,10 @@ class seekable:
consume(self, remainder) consume(self, remainder)
def relative_seek(self, count): def relative_seek(self, count):
index = len(self._cache) if self._index is None:
self.seek(max(index + count, 0)) self._index = len(self._cache)
self.seek(max(self._index + count, 0))
class run_length: class run_length:
@ -2903,7 +2971,7 @@ class run_length:
@staticmethod @staticmethod
def decode(iterable): def decode(iterable):
return chain.from_iterable(repeat(k, n) for k, n in iterable) return chain.from_iterable(starmap(repeat, iterable))
def exactly_n(iterable, n, predicate=bool): def exactly_n(iterable, n, predicate=bool):
@ -2924,14 +2992,34 @@ def exactly_n(iterable, n, predicate=bool):
return len(take(n + 1, filter(predicate, iterable))) == n return len(take(n + 1, filter(predicate, iterable))) == n
def circular_shifts(iterable): def circular_shifts(iterable, steps=1):
"""Return a list of circular shifts of *iterable*. """Yield the circular shifts of *iterable*.
>>> circular_shifts(range(4)) >>> list(circular_shifts(range(4)))
[(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)] [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]
Set *steps* to the number of places to rotate to the left
(or to the right if negative). Defaults to 1.
>>> list(circular_shifts(range(4), 2))
[(0, 1, 2, 3), (2, 3, 0, 1)]
>>> list(circular_shifts(range(4), -1))
[(0, 1, 2, 3), (3, 0, 1, 2), (2, 3, 0, 1), (1, 2, 3, 0)]
""" """
lst = list(iterable) buffer = deque(iterable)
return take(len(lst), windowed(cycle(lst), len(lst))) if steps == 0:
raise ValueError('Steps should be a non-zero integer')
buffer.rotate(steps)
steps = -steps
n = len(buffer)
n //= math.gcd(n, steps)
for __ in repeat(None, n):
buffer.rotate(steps)
yield tuple(buffer)
def make_decorator(wrapping_func, result_index=0): def make_decorator(wrapping_func, result_index=0):
@ -3191,7 +3279,7 @@ def partitions(iterable):
yield [sequence[i:j] for i, j in zip((0,) + i, i + (n,))] yield [sequence[i:j] for i, j in zip((0,) + i, i + (n,))]
def set_partitions(iterable, k=None): def set_partitions(iterable, k=None, min_size=None, max_size=None):
""" """
Yield the set partitions of *iterable* into *k* parts. Set partitions are Yield the set partitions of *iterable* into *k* parts. Set partitions are
not order-preserving. not order-preserving.
@ -3215,6 +3303,20 @@ def set_partitions(iterable, k=None):
['b', 'ac'] ['b', 'ac']
['a', 'b', 'c'] ['a', 'b', 'c']
if *min_size* and/or *max_size* are given, the minimum and/or maximum size
per block in partition is set.
>>> iterable = 'abc'
>>> for part in set_partitions(iterable, min_size=2):
... print([''.join(p) for p in part])
['abc']
>>> for part in set_partitions(iterable, max_size=2):
... print([''.join(p) for p in part])
['a', 'bc']
['ab', 'c']
['b', 'ac']
['a', 'b', 'c']
""" """
L = list(iterable) L = list(iterable)
n = len(L) n = len(L)
@ -3226,6 +3328,11 @@ def set_partitions(iterable, k=None):
elif k > n: elif k > n:
return return
min_size = min_size if min_size is not None else 0
max_size = max_size if max_size is not None else n
if min_size > max_size:
return
def set_partitions_helper(L, k): def set_partitions_helper(L, k):
n = len(L) n = len(L)
if k == 1: if k == 1:
@ -3242,9 +3349,15 @@ def set_partitions(iterable, k=None):
if k is None: if k is None:
for k in range(1, n + 1): for k in range(1, n + 1):
yield from set_partitions_helper(L, k) yield from filter(
lambda z: all(min_size <= len(bk) <= max_size for bk in z),
set_partitions_helper(L, k),
)
else: else:
yield from set_partitions_helper(L, k) yield from filter(
lambda z: all(min_size <= len(bk) <= max_size for bk in z),
set_partitions_helper(L, k),
)
class time_limited: class time_limited:
@ -3535,32 +3648,27 @@ def map_if(iterable, pred, func, func_else=lambda x: x):
yield func(item) if pred(item) else func_else(item) yield func(item) if pred(item) else func_else(item)
def _sample_unweighted(iterable, k): def _sample_unweighted(iterator, k, strict):
# Implementation of "Algorithm L" from the 1994 paper by Kim-Hung Li: # Algorithm L in the 1994 paper by Kim-Hung Li:
# "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))". # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))".
# Fill up the reservoir (collection of samples) with the first `k` samples reservoir = list(islice(iterator, k))
reservoir = take(k, iterable) if strict and len(reservoir) < k:
raise ValueError('Sample larger than population')
W = 1.0
# Generate random number that's the largest in a sample of k U(0,1) numbers with suppress(StopIteration):
# Largest order statistic: https://en.wikipedia.org/wiki/Order_statistic while True:
W = exp(log(random()) / k)
# The number of elements to skip before changing the reservoir is a random
# number with a geometric distribution. Sample it using random() and logs.
next_index = k + floor(log(random()) / log(1 - W))
for index, element in enumerate(iterable, k):
if index == next_index:
reservoir[randrange(k)] = element
# The new W is the largest in a sample of k U(0, `old_W`) numbers
W *= exp(log(random()) / k) W *= exp(log(random()) / k)
next_index += floor(log(random()) / log(1 - W)) + 1 skip = floor(log(random()) / log1p(-W))
element = next(islice(iterator, skip, None))
reservoir[randrange(k)] = element
shuffle(reservoir)
return reservoir return reservoir
def _sample_weighted(iterable, k, weights): def _sample_weighted(iterator, k, weights, strict):
# Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. : # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. :
# "Weighted random sampling with a reservoir". # "Weighted random sampling with a reservoir".
@ -3569,7 +3677,10 @@ def _sample_weighted(iterable, k, weights):
# Fill up the reservoir (collection of samples) with the first `k` # Fill up the reservoir (collection of samples) with the first `k`
# weight-keys and elements, then heapify the list. # weight-keys and elements, then heapify the list.
reservoir = take(k, zip(weight_keys, iterable)) reservoir = take(k, zip(weight_keys, iterator))
if strict and len(reservoir) < k:
raise ValueError('Sample larger than population')
heapify(reservoir) heapify(reservoir)
# The number of jumps before changing the reservoir is a random variable # The number of jumps before changing the reservoir is a random variable
@ -3577,7 +3688,7 @@ def _sample_weighted(iterable, k, weights):
smallest_weight_key, _ = reservoir[0] smallest_weight_key, _ = reservoir[0]
weights_to_skip = log(random()) / smallest_weight_key weights_to_skip = log(random()) / smallest_weight_key
for weight, element in zip(weights, iterable): for weight, element in zip(weights, iterator):
if weight >= weights_to_skip: if weight >= weights_to_skip:
# The notation here is consistent with the paper, but we store # The notation here is consistent with the paper, but we store
# the weight-keys in log-space for better numerical stability. # the weight-keys in log-space for better numerical stability.
@ -3591,44 +3702,103 @@ def _sample_weighted(iterable, k, weights):
else: else:
weights_to_skip -= weight weights_to_skip -= weight
# Equivalent to [element for weight_key, element in sorted(reservoir)] ret = [element for weight_key, element in reservoir]
return [heappop(reservoir)[1] for _ in range(k)] shuffle(ret)
return ret
def sample(iterable, k, weights=None): def _sample_counted(population, k, counts, strict):
element = None
remaining = 0
def feed(i):
# Advance *i* steps ahead and consume an element
nonlocal element, remaining
while i + 1 > remaining:
i = i - remaining
element = next(population)
remaining = next(counts)
remaining -= i + 1
return element
with suppress(StopIteration):
reservoir = []
for _ in range(k):
reservoir.append(feed(0))
if strict and len(reservoir) < k:
raise ValueError('Sample larger than population')
W = 1.0
while True:
W *= exp(log(random()) / k)
skip = floor(log(random()) / log1p(-W))
element = feed(skip)
reservoir[randrange(k)] = element
shuffle(reservoir)
return reservoir
def sample(iterable, k, weights=None, *, counts=None, strict=False):
"""Return a *k*-length list of elements chosen (without replacement) """Return a *k*-length list of elements chosen (without replacement)
from the *iterable*. Like :func:`random.sample`, but works on iterables from the *iterable*. Similar to :func:`random.sample`, but works on
of unknown length. iterables of unknown length.
>>> iterable = range(100) >>> iterable = range(100)
>>> sample(iterable, 5) # doctest: +SKIP >>> sample(iterable, 5) # doctest: +SKIP
[81, 60, 96, 16, 4] [81, 60, 96, 16, 4]
An iterable with *weights* may also be given: For iterables with repeated elements, you may supply *counts* to
indicate the repeats.
>>> iterable = ['a', 'b']
>>> counts = [3, 4] # Equivalent to 'a', 'a', 'a', 'b', 'b', 'b', 'b'
>>> sample(iterable, k=3, counts=counts) # doctest: +SKIP
['a', 'a', 'b']
An iterable with *weights* may be given:
>>> iterable = range(100) >>> iterable = range(100)
>>> weights = (i * i + 1 for i in range(100)) >>> weights = (i * i + 1 for i in range(100))
>>> sampled = sample(iterable, 5, weights=weights) # doctest: +SKIP >>> sampled = sample(iterable, 5, weights=weights) # doctest: +SKIP
[79, 67, 74, 66, 78] [79, 67, 74, 66, 78]
The algorithm can also be used to generate weighted random permutations. Weighted selections are made without replacement.
The relative weight of each item determines the probability that it After an element is selected, it is removed from the pool and the
appears late in the permutation. relative weights of the other elements increase (this
does not match the behavior of :func:`random.sample`'s *counts*
parameter). Note that *weights* may not be used with *counts*.
>>> data = "abcdefgh" If the length of *iterable* is less than *k*,
>>> weights = range(1, len(data) + 1) ``ValueError`` is raised if *strict* is ``True`` and
>>> sample(data, k=len(data), weights=weights) # doctest: +SKIP all elements are returned (in shuffled order) if *strict* is ``False``.
['c', 'a', 'b', 'e', 'g', 'd', 'h', 'f']
By default, the `Algorithm L <https://w.wiki/ANrM>`__ reservoir sampling
technique is used. When *weights* are provided,
`Algorithm A-ExpJ <https://w.wiki/ANrS>`__ is used.
""" """
iterator = iter(iterable)
if k < 0:
raise ValueError('k must be non-negative')
if k == 0: if k == 0:
return [] return []
iterable = iter(iterable) if weights is not None and counts is not None:
if weights is None: raise TypeError('weights and counts are mutally exclusive')
return _sample_unweighted(iterable, k)
else: elif weights is not None:
weights = iter(weights) weights = iter(weights)
return _sample_weighted(iterable, k, weights) return _sample_weighted(iterator, k, weights, strict)
elif counts is not None:
counts = iter(counts)
return _sample_counted(iterator, k, counts, strict)
else:
return _sample_unweighted(iterator, k, strict)
def is_sorted(iterable, key=None, reverse=False, strict=False): def is_sorted(iterable, key=None, reverse=False, strict=False):
@ -3650,12 +3820,16 @@ def is_sorted(iterable, key=None, reverse=False, strict=False):
False False
The function returns ``False`` after encountering the first out-of-order The function returns ``False`` after encountering the first out-of-order
item. If there are no out-of-order items, the iterable is exhausted. item, which means it may produce results that differ from the built-in
:func:`sorted` function for objects with unusual comparison dynamics.
If there are no out-of-order items, the iterable is exhausted.
""" """
compare = le if strict else lt
it = iterable if (key is None) else map(key, iterable)
it_1, it_2 = tee(it)
next(it_2 if reverse else it_1, None)
compare = (le if reverse else ge) if strict else (lt if reverse else gt) return not any(map(compare, it_1, it_2))
it = iterable if key is None else map(key, iterable)
return not any(starmap(compare, pairwise(it)))
class AbortThread(BaseException): class AbortThread(BaseException):

View file

@ -2,6 +2,8 @@
from __future__ import annotations from __future__ import annotations
import sys
from types import TracebackType from types import TracebackType
from typing import ( from typing import (
Any, Any,
@ -28,6 +30,9 @@ from typing_extensions import Protocol
_T = TypeVar('_T') _T = TypeVar('_T')
_T1 = TypeVar('_T1') _T1 = TypeVar('_T1')
_T2 = TypeVar('_T2') _T2 = TypeVar('_T2')
_T3 = TypeVar('_T3')
_T4 = TypeVar('_T4')
_T5 = TypeVar('_T5')
_U = TypeVar('_U') _U = TypeVar('_U')
_V = TypeVar('_V') _V = TypeVar('_V')
_W = TypeVar('_W') _W = TypeVar('_W')
@ -35,6 +40,12 @@ _T_co = TypeVar('_T_co', covariant=True)
_GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[Any]]) _GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[Any]])
_Raisable = BaseException | Type[BaseException] _Raisable = BaseException | Type[BaseException]
# The type of isinstance's second argument (from typeshed builtins)
if sys.version_info >= (3, 10):
_ClassInfo = type | UnionType | tuple[_ClassInfo, ...]
else:
_ClassInfo = type | tuple[_ClassInfo, ...]
@type_check_only @type_check_only
class _SizedIterable(Protocol[_T_co], Sized, Iterable[_T_co]): ... class _SizedIterable(Protocol[_T_co], Sized, Iterable[_T_co]): ...
@ -135,7 +146,7 @@ def interleave_evenly(
) -> Iterator[_T]: ... ) -> Iterator[_T]: ...
def collapse( def collapse(
iterable: Iterable[Any], iterable: Iterable[Any],
base_type: type | None = ..., base_type: _ClassInfo | None = ...,
levels: int | None = ..., levels: int | None = ...,
) -> Iterator[Any]: ... ) -> Iterator[Any]: ...
@overload @overload
@ -213,6 +224,7 @@ def stagger(
class UnequalIterablesError(ValueError): class UnequalIterablesError(ValueError):
def __init__(self, details: tuple[int, int, int] | None = ...) -> None: ... def __init__(self, details: tuple[int, int, int] | None = ...) -> None: ...
# zip_equal
@overload @overload
def zip_equal(__iter1: Iterable[_T1]) -> Iterator[tuple[_T1]]: ... def zip_equal(__iter1: Iterable[_T1]) -> Iterator[tuple[_T1]]: ...
@overload @overload
@ -221,11 +233,35 @@ def zip_equal(
) -> Iterator[tuple[_T1, _T2]]: ... ) -> Iterator[tuple[_T1, _T2]]: ...
@overload @overload
def zip_equal( def zip_equal(
__iter1: Iterable[_T], __iter1: Iterable[_T1], __iter2: Iterable[_T2], __iter3: Iterable[_T3]
__iter2: Iterable[_T], ) -> Iterator[tuple[_T1, _T2, _T3]]: ...
__iter3: Iterable[_T], @overload
*iterables: Iterable[_T], def zip_equal(
) -> Iterator[tuple[_T, ...]]: ... __iter1: Iterable[_T1],
__iter2: Iterable[_T2],
__iter3: Iterable[_T3],
__iter4: Iterable[_T4],
) -> Iterator[tuple[_T1, _T2, _T3, _T4]]: ...
@overload
def zip_equal(
__iter1: Iterable[_T1],
__iter2: Iterable[_T2],
__iter3: Iterable[_T3],
__iter4: Iterable[_T4],
__iter5: Iterable[_T5],
) -> Iterator[tuple[_T1, _T2, _T3, _T4, _T5]]: ...
@overload
def zip_equal(
__iter1: Iterable[Any],
__iter2: Iterable[Any],
__iter3: Iterable[Any],
__iter4: Iterable[Any],
__iter5: Iterable[Any],
__iter6: Iterable[Any],
*iterables: Iterable[Any],
) -> Iterator[tuple[Any, ...]]: ...
# zip_offset
@overload @overload
def zip_offset( def zip_offset(
__iter1: Iterable[_T1], __iter1: Iterable[_T1],
@ -285,12 +321,13 @@ def sort_together(
key_list: Iterable[int] = ..., key_list: Iterable[int] = ...,
key: Callable[..., Any] | None = ..., key: Callable[..., Any] | None = ...,
reverse: bool = ..., reverse: bool = ...,
strict: bool = ...,
) -> list[tuple[_T, ...]]: ... ) -> list[tuple[_T, ...]]: ...
def unzip(iterable: Iterable[Sequence[_T]]) -> tuple[Iterator[_T], ...]: ... def unzip(iterable: Iterable[Sequence[_T]]) -> tuple[Iterator[_T], ...]: ...
def divide(n: int, iterable: Iterable[_T]) -> list[Iterator[_T]]: ... def divide(n: int, iterable: Iterable[_T]) -> list[Iterator[_T]]: ...
def always_iterable( def always_iterable(
obj: object, obj: object,
base_type: type | tuple[type | tuple[Any, ...], ...] | None = ..., base_type: _ClassInfo | None = ...,
) -> Iterator[Any]: ... ) -> Iterator[Any]: ...
def adjacent( def adjacent(
predicate: Callable[[_T], bool], predicate: Callable[[_T], bool],
@ -454,7 +491,9 @@ class run_length:
def exactly_n( def exactly_n(
iterable: Iterable[_T], n: int, predicate: Callable[[_T], object] = ... iterable: Iterable[_T], n: int, predicate: Callable[[_T], object] = ...
) -> bool: ... ) -> bool: ...
def circular_shifts(iterable: Iterable[_T]) -> list[tuple[_T, ...]]: ... def circular_shifts(
iterable: Iterable[_T], steps: int = 1
) -> list[tuple[_T, ...]]: ...
def make_decorator( def make_decorator(
wrapping_func: Callable[..., _U], result_index: int = ... wrapping_func: Callable[..., _U], result_index: int = ...
) -> Callable[..., Callable[[Callable[..., Any]], Callable[..., _U]]]: ... ) -> Callable[..., Callable[[Callable[..., Any]], Callable[..., _U]]]: ...
@ -500,7 +539,10 @@ def replace(
) -> Iterator[_T | _U]: ... ) -> Iterator[_T | _U]: ...
def partitions(iterable: Iterable[_T]) -> Iterator[list[list[_T]]]: ... def partitions(iterable: Iterable[_T]) -> Iterator[list[list[_T]]]: ...
def set_partitions( def set_partitions(
iterable: Iterable[_T], k: int | None = ... iterable: Iterable[_T],
k: int | None = ...,
min_size: int | None = ...,
max_size: int | None = ...,
) -> Iterator[list[list[_T]]]: ... ) -> Iterator[list[list[_T]]]: ...
class time_limited(Generic[_T], Iterator[_T]): class time_limited(Generic[_T], Iterator[_T]):
@ -538,10 +580,22 @@ def map_if(
func: Callable[[Any], Any], func: Callable[[Any], Any],
func_else: Callable[[Any], Any] | None = ..., func_else: Callable[[Any], Any] | None = ...,
) -> Iterator[Any]: ... ) -> Iterator[Any]: ...
def _sample_unweighted(
iterator: Iterator[_T], k: int, strict: bool
) -> list[_T]: ...
def _sample_counted(
population: Iterator[_T], k: int, counts: Iterable[int], strict: bool
) -> list[_T]: ...
def _sample_weighted(
iterator: Iterator[_T], k: int, weights, strict
) -> list[_T]: ...
def sample( def sample(
iterable: Iterable[_T], iterable: Iterable[_T],
k: int, k: int,
weights: Iterable[float] | None = ..., weights: Iterable[float] | None = ...,
*,
counts: Iterable[int] | None = ...,
strict: bool = False,
) -> list[_T]: ... ) -> list[_T]: ...
def is_sorted( def is_sorted(
iterable: Iterable[_T], iterable: Iterable[_T],
@ -577,7 +631,7 @@ class callback_iter(Generic[_T], Iterator[_T]):
def windowed_complete( def windowed_complete(
iterable: Iterable[_T], n: int iterable: Iterable[_T], n: int
) -> Iterator[tuple[_T, ...]]: ... ) -> Iterator[tuple[tuple[_T, ...], tuple[_T, ...], tuple[_T, ...]]]: ...
def all_unique( def all_unique(
iterable: Iterable[_T], key: Callable[[_T], _U] | None = ... iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
) -> bool: ... ) -> bool: ...
@ -608,9 +662,61 @@ class countable(Generic[_T], Iterator[_T]):
items_seen: int items_seen: int
def chunked_even(iterable: Iterable[_T], n: int) -> Iterator[list[_T]]: ... def chunked_even(iterable: Iterable[_T], n: int) -> Iterator[list[_T]]: ...
@overload
def zip_broadcast( def zip_broadcast(
__obj1: _T | Iterable[_T],
*,
scalar_types: _ClassInfo | None = ...,
strict: bool = ...,
) -> Iterable[tuple[_T, ...]]: ...
@overload
def zip_broadcast(
__obj1: _T | Iterable[_T],
__obj2: _T | Iterable[_T],
*,
scalar_types: _ClassInfo | None = ...,
strict: bool = ...,
) -> Iterable[tuple[_T, ...]]: ...
@overload
def zip_broadcast(
__obj1: _T | Iterable[_T],
__obj2: _T | Iterable[_T],
__obj3: _T | Iterable[_T],
*,
scalar_types: _ClassInfo | None = ...,
strict: bool = ...,
) -> Iterable[tuple[_T, ...]]: ...
@overload
def zip_broadcast(
__obj1: _T | Iterable[_T],
__obj2: _T | Iterable[_T],
__obj3: _T | Iterable[_T],
__obj4: _T | Iterable[_T],
*,
scalar_types: _ClassInfo | None = ...,
strict: bool = ...,
) -> Iterable[tuple[_T, ...]]: ...
@overload
def zip_broadcast(
__obj1: _T | Iterable[_T],
__obj2: _T | Iterable[_T],
__obj3: _T | Iterable[_T],
__obj4: _T | Iterable[_T],
__obj5: _T | Iterable[_T],
*,
scalar_types: _ClassInfo | None = ...,
strict: bool = ...,
) -> Iterable[tuple[_T, ...]]: ...
@overload
def zip_broadcast(
__obj1: _T | Iterable[_T],
__obj2: _T | Iterable[_T],
__obj3: _T | Iterable[_T],
__obj4: _T | Iterable[_T],
__obj5: _T | Iterable[_T],
__obj6: _T | Iterable[_T],
*objects: _T | Iterable[_T], *objects: _T | Iterable[_T],
scalar_types: type | tuple[type | tuple[Any, ...], ...] | None = ..., scalar_types: _ClassInfo | None = ...,
strict: bool = ..., strict: bool = ...,
) -> Iterable[tuple[_T, ...]]: ... ) -> Iterable[tuple[_T, ...]]: ...
def unique_in_window( def unique_in_window(

View file

@ -795,8 +795,30 @@ def triplewise(iterable):
[('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')] [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]
""" """
for (a, _), (b, c) in pairwise(pairwise(iterable)): # This deviates from the itertools documentation reciple - see
yield a, b, c # https://github.com/more-itertools/more-itertools/issues/889
t1, t2, t3 = tee(iterable, 3)
next(t3, None)
next(t3, None)
next(t2, None)
return zip(t1, t2, t3)
def _sliding_window_islice(iterable, n):
# Fast path for small, non-zero values of n.
iterators = tee(iterable, n)
for i, iterator in enumerate(iterators):
next(islice(iterator, i, i), None)
return zip(*iterators)
def _sliding_window_deque(iterable, n):
# Normal path for other values of n.
it = iter(iterable)
window = deque(islice(it, n - 1), maxlen=n)
for x in it:
window.append(x)
yield tuple(window)
def sliding_window(iterable, n): def sliding_window(iterable, n):
@ -812,11 +834,16 @@ def sliding_window(iterable, n):
For a variant with more features, see :func:`windowed`. For a variant with more features, see :func:`windowed`.
""" """
it = iter(iterable) if n > 20:
window = deque(islice(it, n - 1), maxlen=n) return _sliding_window_deque(iterable, n)
for x in it: elif n > 2:
window.append(x) return _sliding_window_islice(iterable, n)
yield tuple(window) elif n == 2:
return pairwise(iterable)
elif n == 1:
return zip(iterable)
else:
raise ValueError(f'n should be at least one, not {n}')
def subslices(iterable): def subslices(iterable):
@ -1038,9 +1065,6 @@ def totient(n):
>>> totient(12) >>> totient(12)
4 4
""" """
# The itertools docs use unique_justseen instead of set; see for prime in set(factor(n)):
# https://github.com/more-itertools/more-itertools/issues/823 n -= n // prime
for p in set(factor(n)):
n = n // p * (p - 1)
return n return n

View file

@ -10,6 +10,9 @@ from numbers import Number
from typing import Union, Tuple, Iterable from typing import Union, Tuple, Iterable
from typing import cast from typing import cast
import dateutil.parser
import dateutil.tz
# some useful constants # some useful constants
osc_per_year = 290_091_329_207_984_000 osc_per_year = 290_091_329_207_984_000
@ -611,3 +614,40 @@ def date_range(start=None, stop=None, step=None):
while start < stop: while start < stop:
yield start yield start
start += step start += step
tzinfos = dict(
AEST=dateutil.tz.gettz("Australia/Sydney"),
AEDT=dateutil.tz.gettz("Australia/Sydney"),
ACST=dateutil.tz.gettz("Australia/Darwin"),
ACDT=dateutil.tz.gettz("Australia/Adelaide"),
AWST=dateutil.tz.gettz("Australia/Perth"),
EST=dateutil.tz.gettz("America/New_York"),
EDT=dateutil.tz.gettz("America/New_York"),
CST=dateutil.tz.gettz("America/Chicago"),
CDT=dateutil.tz.gettz("America/Chicago"),
MST=dateutil.tz.gettz("America/Denver"),
MDT=dateutil.tz.gettz("America/Denver"),
PST=dateutil.tz.gettz("America/Los_Angeles"),
PDT=dateutil.tz.gettz("America/Los_Angeles"),
GMT=dateutil.tz.gettz("Etc/GMT"),
UTC=dateutil.tz.gettz("UTC"),
CET=dateutil.tz.gettz("Europe/Berlin"),
CEST=dateutil.tz.gettz("Europe/Berlin"),
IST=dateutil.tz.gettz("Asia/Kolkata"),
BST=dateutil.tz.gettz("Europe/London"),
MSK=dateutil.tz.gettz("Europe/Moscow"),
EET=dateutil.tz.gettz("Europe/Helsinki"),
EEST=dateutil.tz.gettz("Europe/Helsinki"),
# Add more mappings as needed
)
def parse(*args, **kwargs):
"""
Parse the input using dateutil.parser.parse with friendly tz support.
>>> parse('2024-07-26 12:59:00 EDT')
datetime.datetime(...America/New_York...)
"""
return dateutil.parser.parse(*args, tzinfos=tzinfos, **kwargs)

View file

@ -1,5 +1,5 @@
apscheduler==3.10.1 apscheduler==3.10.1
importlib-metadata==8.0.0 importlib-metadata==8.2.0
importlib-resources==6.4.0 importlib-resources==6.4.0
pyinstaller==6.8.0 pyinstaller==6.8.0
pyopenssl==24.1.0 pyopenssl==24.1.0

View file

@ -1504,6 +1504,18 @@ def dbcheck():
except sqlite3.OperationalError: except sqlite3.OperationalError:
logger.warn("Unable to capitalize Windows platform values in session_history table.") logger.warn("Unable to capitalize Windows platform values in session_history table.")
# Upgrade session_history table from earlier versions
try:
result = c_db.execute("SELECT platform FROM session_history "
"WHERE platform = 'macos'").fetchall()
if len(result) > 0:
logger.debug("Altering database. Capitalizing macOS platform values in session_history table.")
c_db.execute(
"UPDATE session_history SET platform = 'macOS' WHERE platform = 'macos' "
)
except sqlite3.OperationalError:
logger.warn("Unable to capitalize macOS platform values in session_history table.")
# Upgrade session_history_metadata table from earlier versions # Upgrade session_history_metadata table from earlier versions
try: try:
c_db.execute("SELECT full_title FROM session_history_metadata") c_db.execute("SELECT full_title FROM session_history_metadata")

View file

@ -102,7 +102,8 @@ PLATFORM_NAME_OVERRIDES = {
'Mystery 5': 'Xbox 360', 'Mystery 5': 'Xbox 360',
'WebMAF': 'Playstation 4', 'WebMAF': 'Playstation 4',
'windows': 'Windows', 'windows': 'Windows',
'osx': 'macOS' 'osx': 'macOS',
'macos': 'macOS',
} }
PMS_PLATFORM_NAME_OVERRIDES = { PMS_PLATFORM_NAME_OVERRIDES = {

View file

@ -6,7 +6,7 @@ bleach==6.1.0
certifi==2024.7.4 certifi==2024.7.4
cheroot==10.0.1 cheroot==10.0.1
cherrypy==18.10.0 cherrypy==18.10.0
cloudinary==1.40.0 cloudinary==1.41.0
distro==1.9.0 distro==1.9.0
dnspython==2.6.1 dnspython==2.6.1
facebook-sdk==3.1.0 facebook-sdk==3.1.0
@ -16,7 +16,7 @@ gntp==1.0.3
html5lib==1.1 html5lib==1.1
httpagentparser==1.9.5 httpagentparser==1.9.5
idna==3.7 idna==3.7
importlib-metadata==8.0.0 importlib-metadata==8.2.0
importlib-resources==6.4.0 importlib-resources==6.4.0
git+https://github.com/Tautulli/ipwhois.git@master#egg=ipwhois git+https://github.com/Tautulli/ipwhois.git@master#egg=ipwhois
IPy==1.01 IPy==1.01
@ -29,7 +29,7 @@ platformdirs==4.2.2
plexapi==4.15.15 plexapi==4.15.15
portend==3.2.0 portend==3.2.0
profilehooks==1.12.0 profilehooks==1.12.0
PyJWT==2.8.0 PyJWT==2.9.0
pyparsing==3.1.2 pyparsing==3.1.2
python-dateutil==2.9.0.post0 python-dateutil==2.9.0.post0
python-twitter==3.5 python-twitter==3.5
@ -39,7 +39,7 @@ requests-oauthlib==2.0.0
rumps==0.4.0; platform_system == "Darwin" rumps==0.4.0; platform_system == "Darwin"
simplejson==3.19.2 simplejson==3.19.2
six==1.16.0 six==1.16.0
tempora==5.6.0 tempora==5.7.0
tokenize-rt==6.0.0 tokenize-rt==6.0.0
tzdata==2024.1 tzdata==2024.1
tzlocal==5.0.1 tzlocal==5.0.1