diff --git a/lib/cloudinary/__init__.py b/lib/cloudinary/__init__.py index f9b79955..877340c9 100644 --- a/lib/cloudinary/__init__.py +++ b/lib/cloudinary/__init__.py @@ -38,7 +38,7 @@ CL_BLANK = "data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAA URI_SCHEME = "cloudinary" API_VERSION = "v1_1" -VERSION = "1.40.0" +VERSION = "1.41.0" _USER_PLATFORM_DETAILS = "; ".join((platform(), "Python {}".format(python_version()))) diff --git a/lib/cloudinary/api.py b/lib/cloudinary/api.py index 0ae0fb33..95543d32 100644 --- a/lib/cloudinary/api.py +++ b/lib/cloudinary/api.py @@ -543,10 +543,6 @@ def create_upload_preset(**options): return call_api("post", uri, params, **options) -def create_folder(path, **options): - return call_api("post", ["folders", path], {}, **options) - - def root_folders(**options): return call_api("get", ["folders"], only(options, "next_cursor", "max_results"), **options) @@ -555,6 +551,24 @@ def subfolders(of_folder_path, **options): return call_api("get", ["folders", of_folder_path], only(options, "next_cursor", "max_results"), **options) +def create_folder(path, **options): + return call_api("post", ["folders", path], {}, **options) + + +def rename_folder(from_path, to_path, **options): + """ + Renames folder + + :param from_path: The full path of an existing asset folder. + :param to_path: The full path of the new asset folder. + :param options: Additional options + + :rtype: Response + """ + params = {"to_folder": to_path} + return call_api("put", ["folders", from_path], params, **options) + + def delete_folder(path, **options): """Deletes folder @@ -727,7 +741,7 @@ def update_metadata_field(field_external_id, field, **options): def __metadata_field_params(field): return only(field, "type", "external_id", "label", "mandatory", "restrictions", - "default_value", "validation", "datasource") + "default_value", "default_disabled", "validation", "datasource") def delete_metadata_field(field_external_id, **options): diff --git a/lib/cloudinary/auth_token.py b/lib/cloudinary/auth_token.py index 4f6c1fe1..8ddeacdf 100644 --- a/lib/cloudinary/auth_token.py +++ b/lib/cloudinary/auth_token.py @@ -11,7 +11,7 @@ AUTH_TOKEN_UNSAFE_RE = r'([ "#%&\'\/:;<=>?@\[\\\]^`{\|}~]+)' def generate(url=None, acl=None, start_time=None, duration=None, - expiration=None, ip=None, key=None, token_name=AUTH_TOKEN_NAME): + expiration=None, ip=None, key=None, token_name=AUTH_TOKEN_NAME, **_): if expiration is None: if duration is not None: diff --git a/lib/cloudinary/utils.py b/lib/cloudinary/utils.py index 1b7b7215..297bdb7a 100644 --- a/lib/cloudinary/utils.py +++ b/lib/cloudinary/utils.py @@ -820,7 +820,7 @@ def cloudinary_url(source, **options): transformation = re.sub(r'([^:])/+', r'\1/', transformation) signature = None - if sign_url and not auth_token: + if sign_url and (not auth_token or auth_token.pop('set_url_signature', False)): to_sign = "/".join(__compact([transformation, source_to_sign])) if long_url_signature: # Long signature forces SHA256 diff --git a/lib/importlib_metadata/__init__.py b/lib/importlib_metadata/__init__.py index ed481355..2c71d33c 100644 --- a/lib/importlib_metadata/__init__.py +++ b/lib/importlib_metadata/__init__.py @@ -25,7 +25,7 @@ from ._compat import ( install, ) from ._functools import method_cache, pass_none -from ._itertools import always_iterable, unique_everseen +from ._itertools import always_iterable, bucket, unique_everseen from ._meta import PackageMetadata, SimplePath from contextlib import suppress @@ -39,6 +39,7 @@ __all__ = [ 'DistributionFinder', 'PackageMetadata', 'PackageNotFoundError', + 'SimplePath', 'distribution', 'distributions', 'entry_points', @@ -388,7 +389,7 @@ class Distribution(metaclass=abc.ABCMeta): if not name: raise ValueError("A distribution name is required.") try: - return next(iter(cls.discover(name=name))) + return next(iter(cls._prefer_valid(cls.discover(name=name)))) except StopIteration: raise PackageNotFoundError(name) @@ -412,6 +413,16 @@ class Distribution(metaclass=abc.ABCMeta): resolver(context) for resolver in cls._discover_resolvers() ) + @staticmethod + def _prefer_valid(dists: Iterable[Distribution]) -> Iterable[Distribution]: + """ + Prefer (move to the front) distributions that have metadata. + + Ref python/importlib_resources#489. + """ + buckets = bucket(dists, lambda dist: bool(dist.metadata)) + return itertools.chain(buckets[True], buckets[False]) + @staticmethod def at(path: str | os.PathLike[str]) -> Distribution: """Return a Distribution for the indicated metadata path. diff --git a/lib/importlib_metadata/_itertools.py b/lib/importlib_metadata/_itertools.py index d4ca9b91..79d37198 100644 --- a/lib/importlib_metadata/_itertools.py +++ b/lib/importlib_metadata/_itertools.py @@ -1,3 +1,4 @@ +from collections import defaultdict, deque from itertools import filterfalse @@ -71,3 +72,100 @@ def always_iterable(obj, base_type=(str, bytes)): return iter(obj) except TypeError: return iter((obj,)) + + +# Copied from more_itertools 10.3 +class bucket: + """Wrap *iterable* and return an object that buckets the iterable into + child iterables based on a *key* function. + + >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3'] + >>> s = bucket(iterable, key=lambda x: x[0]) # Bucket by 1st character + >>> sorted(list(s)) # Get the keys + ['a', 'b', 'c'] + >>> a_iterable = s['a'] + >>> next(a_iterable) + 'a1' + >>> next(a_iterable) + 'a2' + >>> list(s['b']) + ['b1', 'b2', 'b3'] + + The original iterable will be advanced and its items will be cached until + they are used by the child iterables. This may require significant storage. + + By default, attempting to select a bucket to which no items belong will + exhaust the iterable and cache all values. + If you specify a *validator* function, selected buckets will instead be + checked against it. + + >>> from itertools import count + >>> it = count(1, 2) # Infinite sequence of odd numbers + >>> key = lambda x: x % 10 # Bucket by last digit + >>> validator = lambda x: x in {1, 3, 5, 7, 9} # Odd digits only + >>> s = bucket(it, key=key, validator=validator) + >>> 2 in s + False + >>> list(s[2]) + [] + + """ + + def __init__(self, iterable, key, validator=None): + self._it = iter(iterable) + self._key = key + self._cache = defaultdict(deque) + self._validator = validator or (lambda x: True) + + def __contains__(self, value): + if not self._validator(value): + return False + + try: + item = next(self[value]) + except StopIteration: + return False + else: + self._cache[value].appendleft(item) + + return True + + def _get_values(self, value): + """ + Helper to yield items from the parent iterator that match *value*. + Items that don't match are stored in the local cache as they + are encountered. + """ + while True: + # If we've cached some items that match the target value, emit + # the first one and evict it from the cache. + if self._cache[value]: + yield self._cache[value].popleft() + # Otherwise we need to advance the parent iterator to search for + # a matching item, caching the rest. + else: + while True: + try: + item = next(self._it) + except StopIteration: + return + item_value = self._key(item) + if item_value == value: + yield item + break + elif self._validator(item_value): + self._cache[item_value].append(item) + + def __iter__(self): + for item in self._it: + item_value = self._key(item) + if self._validator(item_value): + self._cache[item_value].append(item) + + yield from self._cache.keys() + + def __getitem__(self, value): + if not self._validator(value): + return iter(()) + + return self._get_values(value) diff --git a/lib/jaraco/functools/__init__.py b/lib/jaraco/functools/__init__.py index ca6c22fa..d66f1af1 100644 --- a/lib/jaraco/functools/__init__.py +++ b/lib/jaraco/functools/__init__.py @@ -14,6 +14,14 @@ def compose(*funcs): """ Compose any number of unary functions into a single unary function. + Comparable to + `function composition `_ + in mathematics: + + ``h = g ∘ f`` implies ``h(x) = g(f(x))``. + + In Python, ``h = compose(g, f)``. + >>> import textwrap >>> expected = str.strip(textwrap.dedent(compose.__doc__)) >>> strip_and_dedent = compose(str.strip, textwrap.dedent) diff --git a/lib/jwt/__init__.py b/lib/jwt/__init__.py index 68d09c1c..b7a258d7 100644 --- a/lib/jwt/__init__.py +++ b/lib/jwt/__init__.py @@ -27,7 +27,7 @@ from .exceptions import ( ) from .jwks_client import PyJWKClient -__version__ = "2.8.0" +__version__ = "2.9.0" __title__ = "PyJWT" __description__ = "JSON Web Token implementation in Python" diff --git a/lib/jwt/algorithms.py b/lib/jwt/algorithms.py index ed187152..9be50b20 100644 --- a/lib/jwt/algorithms.py +++ b/lib/jwt/algorithms.py @@ -3,9 +3,8 @@ from __future__ import annotations import hashlib import hmac import json -import sys from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, ClassVar, NoReturn, Union, cast, overload +from typing import TYPE_CHECKING, Any, ClassVar, Literal, NoReturn, cast, overload from .exceptions import InvalidKeyError from .types import HashlibHash, JWKDict @@ -21,14 +20,8 @@ from .utils import ( to_base64url_uint, ) -if sys.version_info >= (3, 8): - from typing import Literal -else: - from typing_extensions import Literal - - try: - from cryptography.exceptions import InvalidSignature + from cryptography.exceptions import InvalidSignature, UnsupportedAlgorithm from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import padding @@ -194,18 +187,16 @@ class Algorithm(ABC): @overload @staticmethod @abstractmethod - def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: - ... # pragma: no cover + def to_jwk(key_obj, as_dict: Literal[True]) -> JWKDict: ... # pragma: no cover @overload @staticmethod @abstractmethod - def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: - ... # pragma: no cover + def to_jwk(key_obj, as_dict: Literal[False] = False) -> str: ... # pragma: no cover @staticmethod @abstractmethod - def to_jwk(key_obj, as_dict: bool = False) -> Union[JWKDict, str]: + def to_jwk(key_obj, as_dict: bool = False) -> JWKDict | str: """ Serializes a given key into a JWK """ @@ -274,16 +265,18 @@ class HMACAlgorithm(Algorithm): @overload @staticmethod - def to_jwk(key_obj: str | bytes, as_dict: Literal[True]) -> JWKDict: - ... # pragma: no cover + def to_jwk( + key_obj: str | bytes, as_dict: Literal[True] + ) -> JWKDict: ... # pragma: no cover @overload @staticmethod - def to_jwk(key_obj: str | bytes, as_dict: Literal[False] = False) -> str: - ... # pragma: no cover + def to_jwk( + key_obj: str | bytes, as_dict: Literal[False] = False + ) -> str: ... # pragma: no cover @staticmethod - def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> Union[JWKDict, str]: + def to_jwk(key_obj: str | bytes, as_dict: bool = False) -> JWKDict | str: jwk = { "k": base64url_encode(force_bytes(key_obj)).decode(), "kty": "oct", @@ -350,22 +343,25 @@ if has_crypto: RSAPrivateKey, load_pem_private_key(key_bytes, password=None) ) except ValueError: - return cast(RSAPublicKey, load_pem_public_key(key_bytes)) + try: + return cast(RSAPublicKey, load_pem_public_key(key_bytes)) + except (ValueError, UnsupportedAlgorithm): + raise InvalidKeyError("Could not parse the provided public key.") @overload - @staticmethod - def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[True]) -> JWKDict: - ... # pragma: no cover - - @overload - @staticmethod - def to_jwk(key_obj: AllowedRSAKeys, as_dict: Literal[False] = False) -> str: - ... # pragma: no cover - @staticmethod def to_jwk( - key_obj: AllowedRSAKeys, as_dict: bool = False - ) -> Union[JWKDict, str]: + key_obj: AllowedRSAKeys, as_dict: Literal[True] + ) -> JWKDict: ... # pragma: no cover + + @overload + @staticmethod + def to_jwk( + key_obj: AllowedRSAKeys, as_dict: Literal[False] = False + ) -> str: ... # pragma: no cover + + @staticmethod + def to_jwk(key_obj: AllowedRSAKeys, as_dict: bool = False) -> JWKDict | str: obj: dict[str, Any] | None = None if hasattr(key_obj, "private_numbers"): @@ -533,7 +529,7 @@ if has_crypto: return der_to_raw_signature(der_sig, key.curve) - def verify(self, msg: bytes, key: "AllowedECKeys", sig: bytes) -> bool: + def verify(self, msg: bytes, key: AllowedECKeys, sig: bytes) -> bool: try: der_sig = raw_to_der_signature(sig, key.curve) except ValueError: @@ -552,18 +548,18 @@ if has_crypto: @overload @staticmethod - def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[True]) -> JWKDict: - ... # pragma: no cover + def to_jwk( + key_obj: AllowedECKeys, as_dict: Literal[True] + ) -> JWKDict: ... # pragma: no cover @overload @staticmethod - def to_jwk(key_obj: AllowedECKeys, as_dict: Literal[False] = False) -> str: - ... # pragma: no cover + def to_jwk( + key_obj: AllowedECKeys, as_dict: Literal[False] = False + ) -> str: ... # pragma: no cover @staticmethod - def to_jwk( - key_obj: AllowedECKeys, as_dict: bool = False - ) -> Union[JWKDict, str]: + def to_jwk(key_obj: AllowedECKeys, as_dict: bool = False) -> JWKDict | str: if isinstance(key_obj, EllipticCurvePrivateKey): public_numbers = key_obj.public_key().public_numbers() elif isinstance(key_obj, EllipticCurvePublicKey): @@ -771,16 +767,18 @@ if has_crypto: @overload @staticmethod - def to_jwk(key: AllowedOKPKeys, as_dict: Literal[True]) -> JWKDict: - ... # pragma: no cover + def to_jwk( + key: AllowedOKPKeys, as_dict: Literal[True] + ) -> JWKDict: ... # pragma: no cover @overload @staticmethod - def to_jwk(key: AllowedOKPKeys, as_dict: Literal[False] = False) -> str: - ... # pragma: no cover + def to_jwk( + key: AllowedOKPKeys, as_dict: Literal[False] = False + ) -> str: ... # pragma: no cover @staticmethod - def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> Union[JWKDict, str]: + def to_jwk(key: AllowedOKPKeys, as_dict: bool = False) -> JWKDict | str: if isinstance(key, (Ed25519PublicKey, Ed448PublicKey)): x = key.public_bytes( encoding=Encoding.Raw, diff --git a/lib/jwt/api_jwk.py b/lib/jwt/api_jwk.py index 456c7f4d..02f4679c 100644 --- a/lib/jwt/api_jwk.py +++ b/lib/jwt/api_jwk.py @@ -5,7 +5,13 @@ import time from typing import Any from .algorithms import get_default_algorithms, has_crypto, requires_cryptography -from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError, PyJWTError +from .exceptions import ( + InvalidKeyError, + MissingCryptographyError, + PyJWKError, + PyJWKSetError, + PyJWTError, +) from .types import JWKDict @@ -50,21 +56,25 @@ class PyJWK: raise InvalidKeyError(f"Unsupported kty: {kty}") if not has_crypto and algorithm in requires_cryptography: - raise PyJWKError(f"{algorithm} requires 'cryptography' to be installed.") + raise MissingCryptographyError( + f"{algorithm} requires 'cryptography' to be installed." + ) - self.Algorithm = self._algorithms.get(algorithm) + self.algorithm_name = algorithm - if not self.Algorithm: + if algorithm in self._algorithms: + self.Algorithm = self._algorithms[algorithm] + else: raise PyJWKError(f"Unable to find an algorithm for key: {self._jwk_data}") self.key = self.Algorithm.from_jwk(self._jwk_data) @staticmethod - def from_dict(obj: JWKDict, algorithm: str | None = None) -> "PyJWK": + def from_dict(obj: JWKDict, algorithm: str | None = None) -> PyJWK: return PyJWK(obj, algorithm) @staticmethod - def from_json(data: str, algorithm: None = None) -> "PyJWK": + def from_json(data: str, algorithm: None = None) -> PyJWK: obj = json.loads(data) return PyJWK.from_dict(obj, algorithm) @@ -94,7 +104,9 @@ class PyJWKSet: for key in keys: try: self.keys.append(PyJWK(key)) - except PyJWTError: + except PyJWTError as error: + if isinstance(error, MissingCryptographyError): + raise error # skip unusable keys continue @@ -104,16 +116,16 @@ class PyJWKSet: ) @staticmethod - def from_dict(obj: dict[str, Any]) -> "PyJWKSet": + def from_dict(obj: dict[str, Any]) -> PyJWKSet: keys = obj.get("keys", []) return PyJWKSet(keys) @staticmethod - def from_json(data: str) -> "PyJWKSet": + def from_json(data: str) -> PyJWKSet: obj = json.loads(data) return PyJWKSet.from_dict(obj) - def __getitem__(self, kid: str) -> "PyJWK": + def __getitem__(self, kid: str) -> PyJWK: for key in self.keys: if key.key_id == kid: return key diff --git a/lib/jwt/api_jws.py b/lib/jwt/api_jws.py index fa6708cc..5822ebf6 100644 --- a/lib/jwt/api_jws.py +++ b/lib/jwt/api_jws.py @@ -11,6 +11,7 @@ from .algorithms import ( has_crypto, requires_cryptography, ) +from .api_jwk import PyJWK from .exceptions import ( DecodeError, InvalidAlgorithmError, @@ -172,7 +173,7 @@ class PyJWS: def decode_complete( self, jwt: str | bytes, - key: AllowedPublicKeys | str | bytes = "", + key: AllowedPublicKeys | PyJWK | str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, @@ -190,7 +191,7 @@ class PyJWS: merged_options = {**self.options, **options} verify_signature = merged_options["verify_signature"] - if verify_signature and not algorithms: + if verify_signature and not algorithms and not isinstance(key, PyJWK): raise DecodeError( 'It is required that you pass in a value for the "algorithms" argument when calling decode().' ) @@ -217,7 +218,7 @@ class PyJWS: def decode( self, jwt: str | bytes, - key: AllowedPublicKeys | str | bytes = "", + key: AllowedPublicKeys | PyJWK | str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, detached_payload: bytes | None = None, @@ -289,9 +290,11 @@ class PyJWS: signing_input: bytes, header: dict[str, Any], signature: bytes, - key: AllowedPublicKeys | str | bytes = "", + key: AllowedPublicKeys | PyJWK | str | bytes = "", algorithms: list[str] | None = None, ) -> None: + if algorithms is None and isinstance(key, PyJWK): + algorithms = [key.algorithm_name] try: alg = header["alg"] except KeyError: @@ -300,11 +303,15 @@ class PyJWS: if not alg or (algorithms is not None and alg not in algorithms): raise InvalidAlgorithmError("The specified alg value is not allowed") - try: - alg_obj = self.get_algorithm_by_name(alg) - except NotImplementedError as e: - raise InvalidAlgorithmError("Algorithm not supported") from e - prepared_key = alg_obj.prepare_key(key) + if isinstance(key, PyJWK): + alg_obj = key.Algorithm + prepared_key = key.key + else: + try: + alg_obj = self.get_algorithm_by_name(alg) + except NotImplementedError as e: + raise InvalidAlgorithmError("Algorithm not supported") from e + prepared_key = alg_obj.prepare_key(key) if not alg_obj.verify(signing_input, prepared_key, signature): raise InvalidSignatureError("Signature verification failed") diff --git a/lib/jwt/api_jwt.py b/lib/jwt/api_jwt.py index 48d739ad..7a07c336 100644 --- a/lib/jwt/api_jwt.py +++ b/lib/jwt/api_jwt.py @@ -5,7 +5,7 @@ import warnings from calendar import timegm from collections.abc import Iterable from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, List from . import api_jws from .exceptions import ( @@ -21,6 +21,7 @@ from .warnings import RemovedInPyjwt3Warning if TYPE_CHECKING: from .algorithms import AllowedPrivateKeys, AllowedPublicKeys + from .api_jwk import PyJWK class PyJWT: @@ -100,7 +101,7 @@ class PyJWT: def decode_complete( self, jwt: str | bytes, - key: AllowedPublicKeys | str | bytes = "", + key: AllowedPublicKeys | PyJWK | str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, # deprecated arg, remove in pyjwt3 @@ -110,7 +111,7 @@ class PyJWT: # passthrough arguments to _validate_claims # consider putting in options audience: str | Iterable[str] | None = None, - issuer: str | None = None, + issuer: str | List[str] | None = None, leeway: float | timedelta = 0, # kwargs **kwargs: Any, @@ -185,7 +186,7 @@ class PyJWT: def decode( self, jwt: str | bytes, - key: AllowedPublicKeys | str | bytes = "", + key: AllowedPublicKeys | PyJWK | str | bytes = "", algorithms: list[str] | None = None, options: dict[str, Any] | None = None, # deprecated arg, remove in pyjwt3 @@ -195,7 +196,7 @@ class PyJWT: # passthrough arguments to _validate_claims # consider putting in options audience: str | Iterable[str] | None = None, - issuer: str | None = None, + issuer: str | List[str] | None = None, leeway: float | timedelta = 0, # kwargs **kwargs: Any, @@ -300,7 +301,7 @@ class PyJWT: try: exp = int(payload["exp"]) except ValueError: - raise DecodeError("Expiration Time claim (exp) must be an" " integer.") + raise DecodeError("Expiration Time claim (exp) must be an integer.") if exp <= (now - leeway): raise ExpiredSignatureError("Signature has expired") @@ -362,8 +363,12 @@ class PyJWT: if "iss" not in payload: raise MissingRequiredClaimError("iss") - if payload["iss"] != issuer: - raise InvalidIssuerError("Invalid issuer") + if isinstance(issuer, list): + if payload["iss"] not in issuer: + raise InvalidIssuerError("Invalid issuer") + else: + if payload["iss"] != issuer: + raise InvalidIssuerError("Invalid issuer") _jwt_global_obj = PyJWT() diff --git a/lib/jwt/exceptions.py b/lib/jwt/exceptions.py index 8ac6ecf7..0d985882 100644 --- a/lib/jwt/exceptions.py +++ b/lib/jwt/exceptions.py @@ -58,6 +58,10 @@ class PyJWKError(PyJWTError): pass +class MissingCryptographyError(PyJWKError): + pass + + class PyJWKSetError(PyJWTError): pass diff --git a/lib/jwt/utils.py b/lib/jwt/utils.py index 81c5ee41..d469139b 100644 --- a/lib/jwt/utils.py +++ b/lib/jwt/utils.py @@ -131,26 +131,15 @@ def is_pem_format(key: bytes) -> bool: # Based on https://github.com/pyca/cryptography/blob/bcb70852d577b3f490f015378c75cba74986297b/src/cryptography/hazmat/primitives/serialization/ssh.py#L40-L46 -_CERT_SUFFIX = b"-cert-v01@openssh.com" -_SSH_PUBKEY_RC = re.compile(rb"\A(\S+)[ \t]+(\S+)") -_SSH_KEY_FORMATS = [ +_SSH_KEY_FORMATS = ( b"ssh-ed25519", b"ssh-rsa", b"ssh-dss", b"ecdsa-sha2-nistp256", b"ecdsa-sha2-nistp384", b"ecdsa-sha2-nistp521", -] +) def is_ssh_key(key: bytes) -> bool: - if any(string_value in key for string_value in _SSH_KEY_FORMATS): - return True - - ssh_pubkey_match = _SSH_PUBKEY_RC.match(key) - if ssh_pubkey_match: - key_type = ssh_pubkey_match.group(1) - if _CERT_SUFFIX == key_type[-len(_CERT_SUFFIX) :]: - return True - - return False + return key.startswith(_SSH_KEY_FORMATS) diff --git a/lib/more_itertools/__init__.py b/lib/more_itertools/__init__.py index 9c4662fc..2e2fcbbe 100644 --- a/lib/more_itertools/__init__.py +++ b/lib/more_itertools/__init__.py @@ -3,4 +3,4 @@ from .more import * # noqa from .recipes import * # noqa -__version__ = '10.3.0' +__version__ = '10.4.0' diff --git a/lib/more_itertools/more.py b/lib/more_itertools/more.py index 7b481907..3bf2c76b 100755 --- a/lib/more_itertools/more.py +++ b/lib/more_itertools/more.py @@ -3,8 +3,9 @@ import warnings from collections import Counter, defaultdict, deque, abc from collections.abc import Sequence +from contextlib import suppress from functools import cached_property, partial, reduce, wraps -from heapq import heapify, heapreplace, heappop +from heapq import heapify, heapreplace from itertools import ( chain, combinations, @@ -21,10 +22,10 @@ from itertools import ( zip_longest, product, ) -from math import comb, e, exp, factorial, floor, fsum, log, perm, tau +from math import comb, e, exp, factorial, floor, fsum, log, log1p, perm, tau from queue import Empty, Queue -from random import random, randrange, uniform -from operator import itemgetter, mul, sub, gt, lt, ge, le +from random import random, randrange, shuffle, uniform +from operator import itemgetter, mul, sub, gt, lt, le from sys import hexversion, maxsize from time import monotonic @@ -34,7 +35,6 @@ from .recipes import ( UnequalIterablesError, consume, flatten, - pairwise, powerset, take, unique_everseen, @@ -473,12 +473,10 @@ def ilen(iterable): This consumes the iterable, so handle with care. """ - # This approach was selected because benchmarks showed it's likely the - # fastest of the known implementations at the time of writing. - # See GitHub tracker: #236, #230. - counter = count() - deque(zip(iterable, counter), maxlen=0) - return next(counter) + # This is the "most beautiful of the fast variants" of this function. + # If you think you can improve on it, please ensure that your version + # is both 10x faster and 10x more beautiful. + return sum(compress(repeat(1), zip(iterable))) def iterate(func, start): @@ -666,9 +664,9 @@ def distinct_permutations(iterable, r=None): >>> sorted(distinct_permutations([1, 0, 1])) [(0, 1, 1), (1, 0, 1), (1, 1, 0)] - Equivalent to ``set(permutations(iterable))``, except duplicates are not - generated and thrown away. For larger input sequences this is much more - efficient. + Equivalent to yielding from ``set(permutations(iterable))``, except + duplicates are not generated and thrown away. For larger input sequences + this is much more efficient. Duplicate permutations arise when there are duplicated elements in the input iterable. The number of items returned is @@ -683,6 +681,25 @@ def distinct_permutations(iterable, r=None): >>> sorted(distinct_permutations(range(3), r=2)) [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)] + *iterable* need not be sortable, but note that using equal (``x == y``) + but non-identical (``id(x) != id(y)``) elements may produce surprising + behavior. For example, ``1`` and ``True`` are equal but non-identical: + + >>> list(distinct_permutations([1, True, '3'])) # doctest: +SKIP + [ + (1, True, '3'), + (1, '3', True), + ('3', 1, True) + ] + >>> list(distinct_permutations([1, 2, '3'])) # doctest: +SKIP + [ + (1, 2, '3'), + (1, '3', 2), + (2, 1, '3'), + (2, '3', 1), + ('3', 1, 2), + ('3', 2, 1) + ] """ # Algorithm: https://w.wiki/Qai @@ -749,14 +766,44 @@ def distinct_permutations(iterable, r=None): i += 1 head[i:], tail[:] = tail[: r - i], tail[r - i :] - items = sorted(iterable) + items = list(iterable) + + try: + items.sort() + sortable = True + except TypeError: + sortable = False + + indices_dict = defaultdict(list) + + for item in items: + indices_dict[items.index(item)].append(item) + + indices = [items.index(item) for item in items] + indices.sort() + + equivalent_items = {k: cycle(v) for k, v in indices_dict.items()} + + def permuted_items(permuted_indices): + return tuple( + next(equivalent_items[index]) for index in permuted_indices + ) size = len(items) if r is None: r = size + # functools.partial(_partial, ... ) + algorithm = _full if (r == size) else partial(_partial, r=r) + if 0 < r <= size: - return _full(items) if (r == size) else _partial(items, r) + if sortable: + return algorithm(items) + else: + return ( + permuted_items(permuted_indices) + for permuted_indices in algorithm(indices) + ) return iter(() if r else ((),)) @@ -1743,7 +1790,9 @@ def zip_offset(*iterables, offsets, longest=False, fillvalue=None): return zip(*staggered) -def sort_together(iterables, key_list=(0,), key=None, reverse=False): +def sort_together( + iterables, key_list=(0,), key=None, reverse=False, strict=False +): """Return the input iterables sorted together, with *key_list* as the priority for sorting. All iterables are trimmed to the length of the shortest one. @@ -1782,6 +1831,10 @@ def sort_together(iterables, key_list=(0,), key=None, reverse=False): >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True) [(3, 2, 1), ('a', 'b', 'c')] + If the *strict* keyword argument is ``True``, then + ``UnequalIterablesError`` will be raised if any of the iterables have + different lengths. + """ if key is None: # if there is no key function, the key argument to sorted is an @@ -1804,8 +1857,9 @@ def sort_together(iterables, key_list=(0,), key=None, reverse=False): *get_key_items(zipped_items) ) + zipper = zip_equal if strict else zip return list( - zip(*sorted(zip(*iterables), key=key_argument, reverse=reverse)) + zipper(*sorted(zipper(*iterables), key=key_argument, reverse=reverse)) ) @@ -2747,8 +2801,6 @@ class seekable: >>> it.seek(0) >>> next(it), next(it), next(it) ('0', '1', '2') - >>> next(it) - '3' You can also seek forward: @@ -2756,15 +2808,29 @@ class seekable: >>> it.seek(10) >>> next(it) '10' - >>> it.relative_seek(-2) # Seeking relative to the current position - >>> next(it) - '9' >>> it.seek(20) # Seeking past the end of the source isn't a problem >>> list(it) [] >>> it.seek(0) # Resetting works even after hitting the end + >>> next(it) + '0' + + Call :meth:`relative_seek` to seek relative to the source iterator's + current position. + + >>> it = seekable((str(n) for n in range(20))) >>> next(it), next(it), next(it) ('0', '1', '2') + >>> it.relative_seek(2) + >>> next(it) + '5' + >>> it.relative_seek(-3) # Source is at '6', we move back to '3' + >>> next(it) + '3' + >>> it.relative_seek(-3) # Source is at '4', we move back to '1' + >>> next(it) + '1' + Call :meth:`peek` to look ahead one item without advancing the iterator: @@ -2873,8 +2939,10 @@ class seekable: consume(self, remainder) def relative_seek(self, count): - index = len(self._cache) - self.seek(max(index + count, 0)) + if self._index is None: + self._index = len(self._cache) + + self.seek(max(self._index + count, 0)) class run_length: @@ -2903,7 +2971,7 @@ class run_length: @staticmethod def decode(iterable): - return chain.from_iterable(repeat(k, n) for k, n in iterable) + return chain.from_iterable(starmap(repeat, iterable)) def exactly_n(iterable, n, predicate=bool): @@ -2924,14 +2992,34 @@ def exactly_n(iterable, n, predicate=bool): return len(take(n + 1, filter(predicate, iterable))) == n -def circular_shifts(iterable): - """Return a list of circular shifts of *iterable*. +def circular_shifts(iterable, steps=1): + """Yield the circular shifts of *iterable*. - >>> circular_shifts(range(4)) + >>> list(circular_shifts(range(4))) [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)] + + Set *steps* to the number of places to rotate to the left + (or to the right if negative). Defaults to 1. + + >>> list(circular_shifts(range(4), 2)) + [(0, 1, 2, 3), (2, 3, 0, 1)] + + >>> list(circular_shifts(range(4), -1)) + [(0, 1, 2, 3), (3, 0, 1, 2), (2, 3, 0, 1), (1, 2, 3, 0)] + """ - lst = list(iterable) - return take(len(lst), windowed(cycle(lst), len(lst))) + buffer = deque(iterable) + if steps == 0: + raise ValueError('Steps should be a non-zero integer') + + buffer.rotate(steps) + steps = -steps + n = len(buffer) + n //= math.gcd(n, steps) + + for __ in repeat(None, n): + buffer.rotate(steps) + yield tuple(buffer) def make_decorator(wrapping_func, result_index=0): @@ -3191,7 +3279,7 @@ def partitions(iterable): yield [sequence[i:j] for i, j in zip((0,) + i, i + (n,))] -def set_partitions(iterable, k=None): +def set_partitions(iterable, k=None, min_size=None, max_size=None): """ Yield the set partitions of *iterable* into *k* parts. Set partitions are not order-preserving. @@ -3215,6 +3303,20 @@ def set_partitions(iterable, k=None): ['b', 'ac'] ['a', 'b', 'c'] + if *min_size* and/or *max_size* are given, the minimum and/or maximum size + per block in partition is set. + + >>> iterable = 'abc' + >>> for part in set_partitions(iterable, min_size=2): + ... print([''.join(p) for p in part]) + ['abc'] + >>> for part in set_partitions(iterable, max_size=2): + ... print([''.join(p) for p in part]) + ['a', 'bc'] + ['ab', 'c'] + ['b', 'ac'] + ['a', 'b', 'c'] + """ L = list(iterable) n = len(L) @@ -3226,6 +3328,11 @@ def set_partitions(iterable, k=None): elif k > n: return + min_size = min_size if min_size is not None else 0 + max_size = max_size if max_size is not None else n + if min_size > max_size: + return + def set_partitions_helper(L, k): n = len(L) if k == 1: @@ -3242,9 +3349,15 @@ def set_partitions(iterable, k=None): if k is None: for k in range(1, n + 1): - yield from set_partitions_helper(L, k) + yield from filter( + lambda z: all(min_size <= len(bk) <= max_size for bk in z), + set_partitions_helper(L, k), + ) else: - yield from set_partitions_helper(L, k) + yield from filter( + lambda z: all(min_size <= len(bk) <= max_size for bk in z), + set_partitions_helper(L, k), + ) class time_limited: @@ -3535,32 +3648,27 @@ def map_if(iterable, pred, func, func_else=lambda x: x): yield func(item) if pred(item) else func_else(item) -def _sample_unweighted(iterable, k): - # Implementation of "Algorithm L" from the 1994 paper by Kim-Hung Li: +def _sample_unweighted(iterator, k, strict): + # Algorithm L in the 1994 paper by Kim-Hung Li: # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))". - # Fill up the reservoir (collection of samples) with the first `k` samples - reservoir = take(k, iterable) + reservoir = list(islice(iterator, k)) + if strict and len(reservoir) < k: + raise ValueError('Sample larger than population') + W = 1.0 - # Generate random number that's the largest in a sample of k U(0,1) numbers - # Largest order statistic: https://en.wikipedia.org/wiki/Order_statistic - W = exp(log(random()) / k) - - # The number of elements to skip before changing the reservoir is a random - # number with a geometric distribution. Sample it using random() and logs. - next_index = k + floor(log(random()) / log(1 - W)) - - for index, element in enumerate(iterable, k): - if index == next_index: - reservoir[randrange(k)] = element - # The new W is the largest in a sample of k U(0, `old_W`) numbers + with suppress(StopIteration): + while True: W *= exp(log(random()) / k) - next_index += floor(log(random()) / log(1 - W)) + 1 + skip = floor(log(random()) / log1p(-W)) + element = next(islice(iterator, skip, None)) + reservoir[randrange(k)] = element + shuffle(reservoir) return reservoir -def _sample_weighted(iterable, k, weights): +def _sample_weighted(iterator, k, weights, strict): # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. : # "Weighted random sampling with a reservoir". @@ -3569,7 +3677,10 @@ def _sample_weighted(iterable, k, weights): # Fill up the reservoir (collection of samples) with the first `k` # weight-keys and elements, then heapify the list. - reservoir = take(k, zip(weight_keys, iterable)) + reservoir = take(k, zip(weight_keys, iterator)) + if strict and len(reservoir) < k: + raise ValueError('Sample larger than population') + heapify(reservoir) # The number of jumps before changing the reservoir is a random variable @@ -3577,7 +3688,7 @@ def _sample_weighted(iterable, k, weights): smallest_weight_key, _ = reservoir[0] weights_to_skip = log(random()) / smallest_weight_key - for weight, element in zip(weights, iterable): + for weight, element in zip(weights, iterator): if weight >= weights_to_skip: # The notation here is consistent with the paper, but we store # the weight-keys in log-space for better numerical stability. @@ -3591,44 +3702,103 @@ def _sample_weighted(iterable, k, weights): else: weights_to_skip -= weight - # Equivalent to [element for weight_key, element in sorted(reservoir)] - return [heappop(reservoir)[1] for _ in range(k)] + ret = [element for weight_key, element in reservoir] + shuffle(ret) + return ret -def sample(iterable, k, weights=None): +def _sample_counted(population, k, counts, strict): + element = None + remaining = 0 + + def feed(i): + # Advance *i* steps ahead and consume an element + nonlocal element, remaining + + while i + 1 > remaining: + i = i - remaining + element = next(population) + remaining = next(counts) + remaining -= i + 1 + return element + + with suppress(StopIteration): + reservoir = [] + for _ in range(k): + reservoir.append(feed(0)) + if strict and len(reservoir) < k: + raise ValueError('Sample larger than population') + + W = 1.0 + while True: + W *= exp(log(random()) / k) + skip = floor(log(random()) / log1p(-W)) + element = feed(skip) + reservoir[randrange(k)] = element + + shuffle(reservoir) + return reservoir + + +def sample(iterable, k, weights=None, *, counts=None, strict=False): """Return a *k*-length list of elements chosen (without replacement) - from the *iterable*. Like :func:`random.sample`, but works on iterables - of unknown length. + from the *iterable*. Similar to :func:`random.sample`, but works on + iterables of unknown length. >>> iterable = range(100) >>> sample(iterable, 5) # doctest: +SKIP [81, 60, 96, 16, 4] - An iterable with *weights* may also be given: + For iterables with repeated elements, you may supply *counts* to + indicate the repeats. + + >>> iterable = ['a', 'b'] + >>> counts = [3, 4] # Equivalent to 'a', 'a', 'a', 'b', 'b', 'b', 'b' + >>> sample(iterable, k=3, counts=counts) # doctest: +SKIP + ['a', 'a', 'b'] + + An iterable with *weights* may be given: >>> iterable = range(100) >>> weights = (i * i + 1 for i in range(100)) >>> sampled = sample(iterable, 5, weights=weights) # doctest: +SKIP [79, 67, 74, 66, 78] - The algorithm can also be used to generate weighted random permutations. - The relative weight of each item determines the probability that it - appears late in the permutation. + Weighted selections are made without replacement. + After an element is selected, it is removed from the pool and the + relative weights of the other elements increase (this + does not match the behavior of :func:`random.sample`'s *counts* + parameter). Note that *weights* may not be used with *counts*. - >>> data = "abcdefgh" - >>> weights = range(1, len(data) + 1) - >>> sample(data, k=len(data), weights=weights) # doctest: +SKIP - ['c', 'a', 'b', 'e', 'g', 'd', 'h', 'f'] + If the length of *iterable* is less than *k*, + ``ValueError`` is raised if *strict* is ``True`` and + all elements are returned (in shuffled order) if *strict* is ``False``. + + By default, the `Algorithm L `__ reservoir sampling + technique is used. When *weights* are provided, + `Algorithm A-ExpJ `__ is used. """ + iterator = iter(iterable) + + if k < 0: + raise ValueError('k must be non-negative') + if k == 0: return [] - iterable = iter(iterable) - if weights is None: - return _sample_unweighted(iterable, k) - else: + if weights is not None and counts is not None: + raise TypeError('weights and counts are mutally exclusive') + + elif weights is not None: weights = iter(weights) - return _sample_weighted(iterable, k, weights) + return _sample_weighted(iterator, k, weights, strict) + + elif counts is not None: + counts = iter(counts) + return _sample_counted(iterator, k, counts, strict) + + else: + return _sample_unweighted(iterator, k, strict) def is_sorted(iterable, key=None, reverse=False, strict=False): @@ -3650,12 +3820,16 @@ def is_sorted(iterable, key=None, reverse=False, strict=False): False The function returns ``False`` after encountering the first out-of-order - item. If there are no out-of-order items, the iterable is exhausted. + item, which means it may produce results that differ from the built-in + :func:`sorted` function for objects with unusual comparison dynamics. + If there are no out-of-order items, the iterable is exhausted. """ + compare = le if strict else lt + it = iterable if (key is None) else map(key, iterable) + it_1, it_2 = tee(it) + next(it_2 if reverse else it_1, None) - compare = (le if reverse else ge) if strict else (lt if reverse else gt) - it = iterable if key is None else map(key, iterable) - return not any(starmap(compare, pairwise(it))) + return not any(map(compare, it_1, it_2)) class AbortThread(BaseException): diff --git a/lib/more_itertools/more.pyi b/lib/more_itertools/more.pyi index e9460232..f1a155dc 100644 --- a/lib/more_itertools/more.pyi +++ b/lib/more_itertools/more.pyi @@ -2,6 +2,8 @@ from __future__ import annotations +import sys + from types import TracebackType from typing import ( Any, @@ -28,6 +30,9 @@ from typing_extensions import Protocol _T = TypeVar('_T') _T1 = TypeVar('_T1') _T2 = TypeVar('_T2') +_T3 = TypeVar('_T3') +_T4 = TypeVar('_T4') +_T5 = TypeVar('_T5') _U = TypeVar('_U') _V = TypeVar('_V') _W = TypeVar('_W') @@ -35,6 +40,12 @@ _T_co = TypeVar('_T_co', covariant=True) _GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[Any]]) _Raisable = BaseException | Type[BaseException] +# The type of isinstance's second argument (from typeshed builtins) +if sys.version_info >= (3, 10): + _ClassInfo = type | UnionType | tuple[_ClassInfo, ...] +else: + _ClassInfo = type | tuple[_ClassInfo, ...] + @type_check_only class _SizedIterable(Protocol[_T_co], Sized, Iterable[_T_co]): ... @@ -135,7 +146,7 @@ def interleave_evenly( ) -> Iterator[_T]: ... def collapse( iterable: Iterable[Any], - base_type: type | None = ..., + base_type: _ClassInfo | None = ..., levels: int | None = ..., ) -> Iterator[Any]: ... @overload @@ -213,6 +224,7 @@ def stagger( class UnequalIterablesError(ValueError): def __init__(self, details: tuple[int, int, int] | None = ...) -> None: ... +# zip_equal @overload def zip_equal(__iter1: Iterable[_T1]) -> Iterator[tuple[_T1]]: ... @overload @@ -221,11 +233,35 @@ def zip_equal( ) -> Iterator[tuple[_T1, _T2]]: ... @overload def zip_equal( - __iter1: Iterable[_T], - __iter2: Iterable[_T], - __iter3: Iterable[_T], - *iterables: Iterable[_T], -) -> Iterator[tuple[_T, ...]]: ... + __iter1: Iterable[_T1], __iter2: Iterable[_T2], __iter3: Iterable[_T3] +) -> Iterator[tuple[_T1, _T2, _T3]]: ... +@overload +def zip_equal( + __iter1: Iterable[_T1], + __iter2: Iterable[_T2], + __iter3: Iterable[_T3], + __iter4: Iterable[_T4], +) -> Iterator[tuple[_T1, _T2, _T3, _T4]]: ... +@overload +def zip_equal( + __iter1: Iterable[_T1], + __iter2: Iterable[_T2], + __iter3: Iterable[_T3], + __iter4: Iterable[_T4], + __iter5: Iterable[_T5], +) -> Iterator[tuple[_T1, _T2, _T3, _T4, _T5]]: ... +@overload +def zip_equal( + __iter1: Iterable[Any], + __iter2: Iterable[Any], + __iter3: Iterable[Any], + __iter4: Iterable[Any], + __iter5: Iterable[Any], + __iter6: Iterable[Any], + *iterables: Iterable[Any], +) -> Iterator[tuple[Any, ...]]: ... + +# zip_offset @overload def zip_offset( __iter1: Iterable[_T1], @@ -285,12 +321,13 @@ def sort_together( key_list: Iterable[int] = ..., key: Callable[..., Any] | None = ..., reverse: bool = ..., + strict: bool = ..., ) -> list[tuple[_T, ...]]: ... def unzip(iterable: Iterable[Sequence[_T]]) -> tuple[Iterator[_T], ...]: ... def divide(n: int, iterable: Iterable[_T]) -> list[Iterator[_T]]: ... def always_iterable( obj: object, - base_type: type | tuple[type | tuple[Any, ...], ...] | None = ..., + base_type: _ClassInfo | None = ..., ) -> Iterator[Any]: ... def adjacent( predicate: Callable[[_T], bool], @@ -454,7 +491,9 @@ class run_length: def exactly_n( iterable: Iterable[_T], n: int, predicate: Callable[[_T], object] = ... ) -> bool: ... -def circular_shifts(iterable: Iterable[_T]) -> list[tuple[_T, ...]]: ... +def circular_shifts( + iterable: Iterable[_T], steps: int = 1 +) -> list[tuple[_T, ...]]: ... def make_decorator( wrapping_func: Callable[..., _U], result_index: int = ... ) -> Callable[..., Callable[[Callable[..., Any]], Callable[..., _U]]]: ... @@ -500,7 +539,10 @@ def replace( ) -> Iterator[_T | _U]: ... def partitions(iterable: Iterable[_T]) -> Iterator[list[list[_T]]]: ... def set_partitions( - iterable: Iterable[_T], k: int | None = ... + iterable: Iterable[_T], + k: int | None = ..., + min_size: int | None = ..., + max_size: int | None = ..., ) -> Iterator[list[list[_T]]]: ... class time_limited(Generic[_T], Iterator[_T]): @@ -538,10 +580,22 @@ def map_if( func: Callable[[Any], Any], func_else: Callable[[Any], Any] | None = ..., ) -> Iterator[Any]: ... +def _sample_unweighted( + iterator: Iterator[_T], k: int, strict: bool +) -> list[_T]: ... +def _sample_counted( + population: Iterator[_T], k: int, counts: Iterable[int], strict: bool +) -> list[_T]: ... +def _sample_weighted( + iterator: Iterator[_T], k: int, weights, strict +) -> list[_T]: ... def sample( iterable: Iterable[_T], k: int, weights: Iterable[float] | None = ..., + *, + counts: Iterable[int] | None = ..., + strict: bool = False, ) -> list[_T]: ... def is_sorted( iterable: Iterable[_T], @@ -577,7 +631,7 @@ class callback_iter(Generic[_T], Iterator[_T]): def windowed_complete( iterable: Iterable[_T], n: int -) -> Iterator[tuple[_T, ...]]: ... +) -> Iterator[tuple[tuple[_T, ...], tuple[_T, ...], tuple[_T, ...]]]: ... def all_unique( iterable: Iterable[_T], key: Callable[[_T], _U] | None = ... ) -> bool: ... @@ -608,9 +662,61 @@ class countable(Generic[_T], Iterator[_T]): items_seen: int def chunked_even(iterable: Iterable[_T], n: int) -> Iterator[list[_T]]: ... +@overload def zip_broadcast( + __obj1: _T | Iterable[_T], + *, + scalar_types: _ClassInfo | None = ..., + strict: bool = ..., +) -> Iterable[tuple[_T, ...]]: ... +@overload +def zip_broadcast( + __obj1: _T | Iterable[_T], + __obj2: _T | Iterable[_T], + *, + scalar_types: _ClassInfo | None = ..., + strict: bool = ..., +) -> Iterable[tuple[_T, ...]]: ... +@overload +def zip_broadcast( + __obj1: _T | Iterable[_T], + __obj2: _T | Iterable[_T], + __obj3: _T | Iterable[_T], + *, + scalar_types: _ClassInfo | None = ..., + strict: bool = ..., +) -> Iterable[tuple[_T, ...]]: ... +@overload +def zip_broadcast( + __obj1: _T | Iterable[_T], + __obj2: _T | Iterable[_T], + __obj3: _T | Iterable[_T], + __obj4: _T | Iterable[_T], + *, + scalar_types: _ClassInfo | None = ..., + strict: bool = ..., +) -> Iterable[tuple[_T, ...]]: ... +@overload +def zip_broadcast( + __obj1: _T | Iterable[_T], + __obj2: _T | Iterable[_T], + __obj3: _T | Iterable[_T], + __obj4: _T | Iterable[_T], + __obj5: _T | Iterable[_T], + *, + scalar_types: _ClassInfo | None = ..., + strict: bool = ..., +) -> Iterable[tuple[_T, ...]]: ... +@overload +def zip_broadcast( + __obj1: _T | Iterable[_T], + __obj2: _T | Iterable[_T], + __obj3: _T | Iterable[_T], + __obj4: _T | Iterable[_T], + __obj5: _T | Iterable[_T], + __obj6: _T | Iterable[_T], *objects: _T | Iterable[_T], - scalar_types: type | tuple[type | tuple[Any, ...], ...] | None = ..., + scalar_types: _ClassInfo | None = ..., strict: bool = ..., ) -> Iterable[tuple[_T, ...]]: ... def unique_in_window( diff --git a/lib/more_itertools/recipes.py b/lib/more_itertools/recipes.py index b32fa955..a21a1f5d 100644 --- a/lib/more_itertools/recipes.py +++ b/lib/more_itertools/recipes.py @@ -795,8 +795,30 @@ def triplewise(iterable): [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')] """ - for (a, _), (b, c) in pairwise(pairwise(iterable)): - yield a, b, c + # This deviates from the itertools documentation reciple - see + # https://github.com/more-itertools/more-itertools/issues/889 + t1, t2, t3 = tee(iterable, 3) + next(t3, None) + next(t3, None) + next(t2, None) + return zip(t1, t2, t3) + + +def _sliding_window_islice(iterable, n): + # Fast path for small, non-zero values of n. + iterators = tee(iterable, n) + for i, iterator in enumerate(iterators): + next(islice(iterator, i, i), None) + return zip(*iterators) + + +def _sliding_window_deque(iterable, n): + # Normal path for other values of n. + it = iter(iterable) + window = deque(islice(it, n - 1), maxlen=n) + for x in it: + window.append(x) + yield tuple(window) def sliding_window(iterable, n): @@ -812,11 +834,16 @@ def sliding_window(iterable, n): For a variant with more features, see :func:`windowed`. """ - it = iter(iterable) - window = deque(islice(it, n - 1), maxlen=n) - for x in it: - window.append(x) - yield tuple(window) + if n > 20: + return _sliding_window_deque(iterable, n) + elif n > 2: + return _sliding_window_islice(iterable, n) + elif n == 2: + return pairwise(iterable) + elif n == 1: + return zip(iterable) + else: + raise ValueError(f'n should be at least one, not {n}') def subslices(iterable): @@ -1038,9 +1065,6 @@ def totient(n): >>> totient(12) 4 """ - # The itertools docs use unique_justseen instead of set; see - # https://github.com/more-itertools/more-itertools/issues/823 - for p in set(factor(n)): - n = n // p * (p - 1) - + for prime in set(factor(n)): + n -= n // prime return n diff --git a/lib/tempora/__init__.py b/lib/tempora/__init__.py index b2690a74..a41d02c2 100644 --- a/lib/tempora/__init__.py +++ b/lib/tempora/__init__.py @@ -10,6 +10,9 @@ from numbers import Number from typing import Union, Tuple, Iterable from typing import cast +import dateutil.parser +import dateutil.tz + # some useful constants osc_per_year = 290_091_329_207_984_000 @@ -611,3 +614,40 @@ def date_range(start=None, stop=None, step=None): while start < stop: yield start start += step + + +tzinfos = dict( + AEST=dateutil.tz.gettz("Australia/Sydney"), + AEDT=dateutil.tz.gettz("Australia/Sydney"), + ACST=dateutil.tz.gettz("Australia/Darwin"), + ACDT=dateutil.tz.gettz("Australia/Adelaide"), + AWST=dateutil.tz.gettz("Australia/Perth"), + EST=dateutil.tz.gettz("America/New_York"), + EDT=dateutil.tz.gettz("America/New_York"), + CST=dateutil.tz.gettz("America/Chicago"), + CDT=dateutil.tz.gettz("America/Chicago"), + MST=dateutil.tz.gettz("America/Denver"), + MDT=dateutil.tz.gettz("America/Denver"), + PST=dateutil.tz.gettz("America/Los_Angeles"), + PDT=dateutil.tz.gettz("America/Los_Angeles"), + GMT=dateutil.tz.gettz("Etc/GMT"), + UTC=dateutil.tz.gettz("UTC"), + CET=dateutil.tz.gettz("Europe/Berlin"), + CEST=dateutil.tz.gettz("Europe/Berlin"), + IST=dateutil.tz.gettz("Asia/Kolkata"), + BST=dateutil.tz.gettz("Europe/London"), + MSK=dateutil.tz.gettz("Europe/Moscow"), + EET=dateutil.tz.gettz("Europe/Helsinki"), + EEST=dateutil.tz.gettz("Europe/Helsinki"), + # Add more mappings as needed +) + + +def parse(*args, **kwargs): + """ + Parse the input using dateutil.parser.parse with friendly tz support. + + >>> parse('2024-07-26 12:59:00 EDT') + datetime.datetime(...America/New_York...) + """ + return dateutil.parser.parse(*args, tzinfos=tzinfos, **kwargs) diff --git a/package/requirements-package.txt b/package/requirements-package.txt index 2f21fea4..0eed49c9 100644 --- a/package/requirements-package.txt +++ b/package/requirements-package.txt @@ -1,5 +1,5 @@ apscheduler==3.10.1 -importlib-metadata==8.0.0 +importlib-metadata==8.2.0 importlib-resources==6.4.0 pyinstaller==6.8.0 pyopenssl==24.1.0 diff --git a/plexpy/__init__.py b/plexpy/__init__.py index 0366fcde..21b5ef6d 100644 --- a/plexpy/__init__.py +++ b/plexpy/__init__.py @@ -1504,6 +1504,18 @@ def dbcheck(): except sqlite3.OperationalError: logger.warn("Unable to capitalize Windows platform values in session_history table.") + # Upgrade session_history table from earlier versions + try: + result = c_db.execute("SELECT platform FROM session_history " + "WHERE platform = 'macos'").fetchall() + if len(result) > 0: + logger.debug("Altering database. Capitalizing macOS platform values in session_history table.") + c_db.execute( + "UPDATE session_history SET platform = 'macOS' WHERE platform = 'macos' " + ) + except sqlite3.OperationalError: + logger.warn("Unable to capitalize macOS platform values in session_history table.") + # Upgrade session_history_metadata table from earlier versions try: c_db.execute("SELECT full_title FROM session_history_metadata") diff --git a/plexpy/common.py b/plexpy/common.py index fb35beb3..8d68e2ad 100644 --- a/plexpy/common.py +++ b/plexpy/common.py @@ -102,7 +102,8 @@ PLATFORM_NAME_OVERRIDES = { 'Mystery 5': 'Xbox 360', 'WebMAF': 'Playstation 4', 'windows': 'Windows', - 'osx': 'macOS' + 'osx': 'macOS', + 'macos': 'macOS', } PMS_PLATFORM_NAME_OVERRIDES = { diff --git a/requirements.txt b/requirements.txt index f7fb1581..d0ee1941 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ bleach==6.1.0 certifi==2024.7.4 cheroot==10.0.1 cherrypy==18.10.0 -cloudinary==1.40.0 +cloudinary==1.41.0 distro==1.9.0 dnspython==2.6.1 facebook-sdk==3.1.0 @@ -16,7 +16,7 @@ gntp==1.0.3 html5lib==1.1 httpagentparser==1.9.5 idna==3.7 -importlib-metadata==8.0.0 +importlib-metadata==8.2.0 importlib-resources==6.4.0 git+https://github.com/Tautulli/ipwhois.git@master#egg=ipwhois IPy==1.01 @@ -29,7 +29,7 @@ platformdirs==4.2.2 plexapi==4.15.15 portend==3.2.0 profilehooks==1.12.0 -PyJWT==2.8.0 +PyJWT==2.9.0 pyparsing==3.1.2 python-dateutil==2.9.0.post0 python-twitter==3.5 @@ -39,7 +39,7 @@ requests-oauthlib==2.0.0 rumps==0.4.0; platform_system == "Darwin" simplejson==3.19.2 six==1.16.0 -tempora==5.6.0 +tempora==5.7.0 tokenize-rt==6.0.0 tzdata==2024.1 tzlocal==5.0.1