From 684cca8c9b72d88e6ee03c6b3b5c149758e18fc2 Mon Sep 17 00:00:00 2001 From: Labrys of Knossos Date: Tue, 29 Nov 2022 01:26:33 -0500 Subject: [PATCH] Updated more-itertools to 5.0.0 --- libs/common/bin/beet.exe | Bin 108369 -> 108369 bytes libs/common/bin/chardetect.exe | Bin 108383 -> 108383 bytes libs/common/bin/easy_install-3.7.exe | Bin 108392 -> 108392 bytes libs/common/bin/easy_install.exe | Bin 108392 -> 108392 bytes libs/common/bin/guessit.exe | Bin 108377 -> 108377 bytes libs/common/bin/mid3cp.exe | Bin 108396 -> 108396 bytes libs/common/bin/mid3iconv.exe | Bin 108399 -> 108399 bytes libs/common/bin/mid3v2.exe | Bin 108396 -> 108396 bytes libs/common/bin/moggsplit.exe | Bin 108399 -> 108399 bytes libs/common/bin/mutagen-inspect.exe | Bin 108405 -> 108405 bytes libs/common/bin/mutagen-pony.exe | Bin 108402 -> 108402 bytes libs/common/bin/pbr.exe | Bin 108373 -> 108373 bytes libs/common/bin/srt.exe | Bin 108375 -> 108375 bytes libs/common/bin/subliminal.exe | Bin 108387 -> 108387 bytes libs/common/bin/unidecode.exe | Bin 108375 -> 108375 bytes libs/common/more_itertools/__init__.py | 8 +- libs/common/more_itertools/__init__.pyi | 2 - libs/common/more_itertools/more.py | 2630 ++--------------- libs/common/more_itertools/more.pyi | 674 ----- libs/common/more_itertools/recipes.py | 464 +-- libs/common/more_itertools/recipes.pyi | 110 - .../{py.typed => tests/__init__.py} | 0 libs/common/more_itertools/tests/test_more.py | 2313 +++++++++++++++ .../more_itertools/tests/test_recipes.py | 616 ++++ 24 files changed, 3339 insertions(+), 3478 deletions(-) delete mode 100644 libs/common/more_itertools/__init__.pyi delete mode 100644 libs/common/more_itertools/more.pyi delete mode 100644 libs/common/more_itertools/recipes.pyi rename libs/common/more_itertools/{py.typed => tests/__init__.py} (100%) create mode 100644 libs/common/more_itertools/tests/test_more.py create mode 100644 libs/common/more_itertools/tests/test_recipes.py diff --git a/libs/common/bin/beet.exe b/libs/common/bin/beet.exe index b749e3ea4c425fbcdcd83832694fb0b4b19b868b..56dacbbb8a01e7ec29d0d7abe225dd5d42310f23 100644 GIT binary patch delta 27 hcmcb3j_u+(wuUW?>dTl+xudTnAxu$C_V|)st1(q{90sy1Z3ReIC diff --git a/libs/common/bin/chardetect.exe b/libs/common/bin/chardetect.exe index 01d5a18f7f534aa4319da9dd996705d554681347..c2e6db4de8f68a98c5d8274c46b31b8a65086bbf 100644 GIT binary patch delta 27 hcmcbAj_v+AwuUW?>dTl+xTk9_WBd%F<(4x#0sx=v3F80& delta 27 hcmcbAj_v+AwuUW?>dTnAxTb3^WBd%F<(4x#0syBx3XK2& diff --git a/libs/common/bin/easy_install-3.7.exe b/libs/common/bin/easy_install-3.7.exe index 04194cd016fc9200017a2512f82baddd8478e5a0..86b781b0042091569b3c229be1d29950150904ee 100644 GIT binary patch delta 27 hcmaEHj_t)cwuUW?>dTmnxudTlqxu$C_WBdc6)t56m0syII3adTmnxudTlqxu$C_WBdc6)t56m0syII3adTl+xudTnAxu$C_V|)vuC6+Tf0sy7b3U>ei diff --git a/libs/common/bin/mid3cp.exe b/libs/common/bin/mid3cp.exe index 9a70d062dfbffba14ef55f050778784888d6a1a0..e681f9dd09890fa25f657eb9667658e285a17cc6 100644 GIT binary patch delta 27 icmaEJj_u7kwuUW?>dTl+xR+}#V`K!<)3ui~IsyQj1qmns delta 27 icmaEJj_u7kwuUW?>dTlqxt41#V`K!<)3ui~IsyQqKMBwP diff --git a/libs/common/bin/mid3iconv.exe b/libs/common/bin/mid3iconv.exe index 44ea559e00c2b89c92c2bc7d97be237964ef0fec..7811e559b6e655cae6236d63aef8657bc9145945 100644 GIT binary patch delta 28 jcmaEVj_v(9wuUW?>dTl+xVLLAV`OFoGN$V-XLJMrqv;7m delta 28 jcmaEVj_v(9wuUW?>dTlqxwdOAV`OFoGN$V-XLJMrt9J?P diff --git a/libs/common/bin/mid3v2.exe b/libs/common/bin/mid3v2.exe index 
38e087cf870527693520df767bc27454f14e63e7..ffbf60c8a6d1189fc63e165bf4d3e04155da274b 100644 GIT binary patch delta 27 icmaEJj_u7kwuUW?>dTl+xR+}#V`K!<)3ui~IsyQj1qmns delta 27 icmaEJj_u7kwuUW?>dTlqxt41#V`K!<)3ui~IsyQqKMBwP diff --git a/libs/common/bin/moggsplit.exe b/libs/common/bin/moggsplit.exe index 060ba5d6815b30b6bafbb148527e9a4fb25a5b6c..cba468f2e6fcea277cb6d936ef4b5a5880d48bdf 100644 GIT binary patch delta 28 jcmaEVj_v(9wuUW?>dTl+xVLLAV`OFoGN$V-XLJMrqv;7m delta 28 jcmaEVj_v(9wuUW?>dTlqxwdOAV`OFoGN$V-XLJMrt9J?P diff --git a/libs/common/bin/mutagen-inspect.exe b/libs/common/bin/mutagen-inspect.exe index 639e3339a506e4852fc491aff869718e9bcc6c48..3e9d5294a00580c02a03bef7d785507c20705fca 100644 GIT binary patch delta 28 jcmex*j_vC?wuUW?>dTl+xVLLAW8`E6GNzj>XLJMrrF#in delta 28 jcmex*j_vC?wuUW?>dTlqxwdOAW8`E6GNzj>XLJMrtqBSQ diff --git a/libs/common/bin/mutagen-pony.exe b/libs/common/bin/mutagen-pony.exe index 9961a3583fb8356b2f836be2655dcf1e9c99e8c3..1544315ba95bb25e206826f5345763026a84d09d 100644 GIT binary patch delta 28 jcmex#j_uPqwuUW?>dTl+xVLLAV`O6lGNv0YXLJMrq_PQ6 delta 28 jcmex#j_uPqwuUW?>dTlqxwdOAV`O6lGNv0YXLJMrtUw9) diff --git a/libs/common/bin/pbr.exe b/libs/common/bin/pbr.exe index 28e739c3e226fa1c8ebcbdd7b9db0c9f7b5dbe0e..805e3677e4f8c740ed161fc4d6e7ba03950fbc4f 100644 GIT binary patch delta 27 hcmcb5j_v9>wuUW?>dTmnxu3Aq3O delta 27 hcmcb5j_v9>wuUW?>dTlqxu$C_V|)psMV2!<0sy3@3S$5O diff --git a/libs/common/bin/srt.exe b/libs/common/bin/srt.exe index f3a3e58618d90890d6ac531abb178bc752821fc1..6a910ec5766773a425f85ff28e678a86a2e1eec5 100644 GIT binary patch delta 27 hcmcb9j_vw6wuUW?>dTl+xTk9_V|)#w#g;QV0sx)t3Bv#Y delta 27 hcmcb9j_vw6wuUW?>dTnAxTb3^V|)#w#g;QV0sy5v3T*%Y diff --git a/libs/common/bin/subliminal.exe b/libs/common/bin/subliminal.exe index e4144d13367cd0a6020513979dcacc0f27af2273..c638b62b26ec2bb73bbe5642ae49b6d0e726abe5 100644 GIT binary patch delta 27 hcmaESj_vU|wuUW?>dTnSxTk9_WBd-Hm6kI)0sx^H3HJa1 delta 27 hcmaESj_vU|wuUW?>dTnAxu$C_WBd-Hm6kI)0syE|3ZDP~ diff --git a/libs/common/bin/unidecode.exe b/libs/common/bin/unidecode.exe index 8fe25e42d7ba78c322ad95e55a256a25194dfecd..9d73f7d0967faa9931addc4fb5e64cf1b1a91628 100644 GIT binary patch delta 27 hcmcb9j_vw6wuUW?>dTmnxudTlqxu$C_V|)#w#g;QV0sy5Z3TprW diff --git a/libs/common/more_itertools/__init__.py b/libs/common/more_itertools/__init__.py index 557bfc20..bba462c3 100644 --- a/libs/common/more_itertools/__init__.py +++ b/libs/common/more_itertools/__init__.py @@ -1,6 +1,2 @@ -"""More routines for operating on iterables, beyond itertools""" - -from .more import * # noqa -from .recipes import * # noqa - -__version__ = '9.0.0' +from more_itertools.more import * # noqa +from more_itertools.recipes import * # noqa diff --git a/libs/common/more_itertools/__init__.pyi b/libs/common/more_itertools/__init__.pyi deleted file mode 100644 index 96f6e36c..00000000 --- a/libs/common/more_itertools/__init__.pyi +++ /dev/null @@ -1,2 +0,0 @@ -from .more import * -from .recipes import * diff --git a/libs/common/more_itertools/more.py b/libs/common/more_itertools/more.py index 7f73d6ed..bd32a261 100644 --- a/libs/common/more_itertools/more.py +++ b/libs/common/more_itertools/more.py @@ -1,9 +1,8 @@ -import warnings +from __future__ import print_function -from collections import Counter, defaultdict, deque, abc -from collections.abc import Sequence -from functools import partial, reduce, wraps -from heapq import heapify, heapreplace, heappop +from collections import Counter, defaultdict, deque +from functools import partial, wraps +from heapq import 
merge from itertools import ( chain, compress, @@ -15,162 +14,103 @@ from itertools import ( repeat, starmap, takewhile, - tee, - zip_longest, + tee ) -from math import exp, factorial, floor, log -from queue import Empty, Queue -from random import random, randrange, uniform -from operator import itemgetter, mul, sub, gt, lt, ge, le -from sys import hexversion, maxsize -from time import monotonic +from operator import itemgetter, lt, gt, sub +from sys import maxsize, version_info +try: + from collections.abc import Sequence +except ImportError: + from collections import Sequence -from .recipes import ( - _marker, - _zip_equal, - UnequalIterablesError, - consume, - flatten, - pairwise, - powerset, - take, - unique_everseen, - all_equal, -) +from six import binary_type, string_types, text_type +from six.moves import filter, map, range, zip, zip_longest + +from .recipes import consume, flatten, take __all__ = [ - 'AbortThread', - 'SequenceView', - 'UnequalIterablesError', 'adjacent', - 'all_unique', 'always_iterable', 'always_reversible', 'bucket', - 'callback_iter', 'chunked', - 'chunked_even', 'circular_shifts', 'collapse', - 'combination_index', + 'collate', 'consecutive_groups', - 'constrained_batches', 'consumer', 'count_cycle', - 'countable', 'difference', - 'distinct_combinations', 'distinct_permutations', 'distribute', 'divide', - 'duplicates_everseen', - 'duplicates_justseen', 'exactly_n', - 'filter_except', 'first', 'groupby_transform', - 'ichunked', - 'iequals', 'ilen', - 'interleave', - 'interleave_evenly', 'interleave_longest', + 'interleave', 'intersperse', - 'is_sorted', 'islice_extended', 'iterate', 'last', 'locate', - 'longest_common_prefix', 'lstrip', 'make_decorator', - 'map_except', - 'map_if', 'map_reduce', - 'mark_ends', - 'minmax', - 'nth_or_last', - 'nth_permutation', - 'nth_product', 'numeric_range', 'one', - 'only', 'padded', - 'partitions', 'peekable', - 'permutation_index', - 'product_index', - 'raise_', - 'repeat_each', - 'repeat_last', 'replace', 'rlocate', 'rstrip', 'run_length', - 'sample', 'seekable', - 'set_partitions', + 'SequenceView', 'side_effect', 'sliced', 'sort_together', - 'split_after', 'split_at', + 'split_after', 'split_before', 'split_into', - 'split_when', 'spy', 'stagger', 'strip', - 'strictly_n', 'substrings', - 'substrings_indexes', - 'time_limited', - 'unique_in_window', 'unique_to_each', 'unzip', - 'value_chain', 'windowed', - 'windowed_complete', 'with_iter', - 'zip_broadcast', - 'zip_equal', 'zip_offset', ] +_marker = object() -def chunked(iterable, n, strict=False): + +def chunked(iterable, n): """Break *iterable* into lists of length *n*: >>> list(chunked([1, 2, 3, 4, 5, 6], 3)) [[1, 2, 3], [4, 5, 6]] - By the default, the last yielded list will have fewer than *n* elements - if the length of *iterable* is not divisible by *n*: + If the length of *iterable* is not evenly divisible by *n*, the last + returned list will be shorter: >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3)) [[1, 2, 3], [4, 5, 6], [7, 8]] To use a fill-in value instead, see the :func:`grouper` recipe. - If the length of *iterable* is not divisible by *n* and *strict* is - ``True``, then ``ValueError`` will be raised before the last - list is yielded. + :func:`chunked` is useful for splitting up a computation on a large number + of keys into batches, to be pickled and sent off to worker processes. One + example is operations on rows in MySQL, which does not implement + server-side cursors properly and would otherwise load the entire dataset + into RAM on the client. 
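A minimal sketch of the batching pattern the chunked() docstring above describes; process_batch and the key range are hypothetical stand-ins for a real worker call and a real key set:

    from more_itertools import chunked

    keys = range(100000)                 # hypothetical set of row keys
    for batch in chunked(keys, 500):
        process_batch(batch)             # hypothetical worker call; each batch is a list of at most 500 keys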
""" - iterator = iter(partial(take, n, iter(iterable)), []) - if strict: - if n is None: - raise ValueError('n must not be None when using strict mode.') - - def ret(): - for chunk in iterator: - if len(chunk) != n: - raise ValueError('iterable is not divisible by n.') - yield chunk - - return iter(ret()) - else: - return iterator + return iter(partial(take, n, iter(iterable)), []) def first(iterable, default=_marker): @@ -192,12 +132,14 @@ def first(iterable, default=_marker): """ try: return next(iter(iterable)) - except StopIteration as e: + except StopIteration: + # I'm on the edge about raising ValueError instead of StopIteration. At + # the moment, ValueError wins, because the caller could conceivably + # want to do something different with flow control when I raise the + # exception, and it's weird to explicitly catch StopIteration. if default is _marker: - raise ValueError( - 'first() was called on an empty iterable, and no ' - 'default value was provided.' - ) from e + raise ValueError('first() was called on an empty iterable, and no ' + 'default value was provided.') return default @@ -214,40 +156,20 @@ def last(iterable, default=_marker): raise ``ValueError``. """ try: - if isinstance(iterable, Sequence): + try: + # Try to access the last item directly return iterable[-1] - # Work around https://bugs.python.org/issue38525 - elif hasattr(iterable, '__reversed__') and (hexversion != 0x030800F0): - return next(reversed(iterable)) - else: - return deque(iterable, maxlen=1)[-1] - except (IndexError, TypeError, StopIteration): + except (TypeError, AttributeError, KeyError): + # If not slice-able, iterate entirely using length-1 deque + return deque(iterable, maxlen=1)[0] + except IndexError: # If the iterable was empty if default is _marker: - raise ValueError( - 'last() was called on an empty iterable, and no default was ' - 'provided.' - ) + raise ValueError('last() was called on an empty iterable, and no ' + 'default value was provided.') return default -def nth_or_last(iterable, n, default=_marker): - """Return the nth or the last item of *iterable*, - or *default* if *iterable* is empty. - - >>> nth_or_last([0, 1, 2, 3], 2) - 2 - >>> nth_or_last([0, 1], 2) - 1 - >>> nth_or_last([], 0, 'some default') - 'some default' - - If *default* is not provided and there are no items in the iterable, - raise ``ValueError``. - """ - return last(islice(iterable, n + 1), default=default) - - -class peekable: +class peekable(object): """Wrap an iterator to allow lookahead and prepending elements. Call :meth:`peek` on the result to get the value that will be returned @@ -300,12 +222,11 @@ class peekable: >>> if p: # peekable has items ... list(p) ['a', 'b'] - >>> if not p: # peekable is exhausted + >>> if not p: # peekable is exhaused ... list(p) [] """ - def __init__(self, iterable): self._it = iter(iterable) self._cache = deque() @@ -320,6 +241,10 @@ class peekable: return False return True + def __nonzero__(self): + # For Python 2 compatibility + return self.__bool__() + def peek(self, default=_marker): """Return the item that will be next returned from ``next()``. 
@@ -373,6 +298,8 @@ class peekable: return next(self._it) + next = __next__ # For Python 2 compatibility + def _get_slice(self, index): # Normalize the slice's arguments step = 1 if (index.step is None) else index.step @@ -412,6 +339,70 @@ class peekable: return self._cache[index] +def _collate(*iterables, **kwargs): + """Helper for ``collate()``, called when the user is using the ``reverse`` + or ``key`` keyword arguments on Python versions below 3.5. + + """ + key = kwargs.pop('key', lambda a: a) + reverse = kwargs.pop('reverse', False) + + min_or_max = partial(max if reverse else min, key=itemgetter(0)) + peekables = [peekable(it) for it in iterables] + peekables = [p for p in peekables if p] # Kill empties. + while peekables: + _, p = min_or_max((key(p.peek()), p) for p in peekables) + yield next(p) + peekables = [x for x in peekables if x] + + +def collate(*iterables, **kwargs): + """Return a sorted merge of the items from each of several already-sorted + *iterables*. + + >>> list(collate('ACDZ', 'AZ', 'JKL')) + ['A', 'A', 'C', 'D', 'J', 'K', 'L', 'Z', 'Z'] + + Works lazily, keeping only the next value from each iterable in memory. Use + :func:`collate` to, for example, perform a n-way mergesort of items that + don't fit in memory. + + If a *key* function is specified, the iterables will be sorted according + to its result: + + >>> key = lambda s: int(s) # Sort by numeric value, not by string + >>> list(collate(['1', '10'], ['2', '11'], key=key)) + ['1', '2', '10', '11'] + + + If the *iterables* are sorted in descending order, set *reverse* to + ``True``: + + >>> list(collate([5, 3, 1], [4, 2, 0], reverse=True)) + [5, 4, 3, 2, 1, 0] + + If the elements of the passed-in iterables are out of order, you might get + unexpected results. + + On Python 2.7, this function delegates to :func:`heapq.merge` if neither + of the keyword arguments are specified. On Python 3.5+, this function + is an alias for :func:`heapq.merge`. + + """ + if not kwargs: + return merge(*iterables) + + return _collate(*iterables, **kwargs) + + +# If using Python version 3.5 or greater, heapq.merge() will be faster than +# collate - use that instead. +if version_info >= (3, 5, 0): + _collate_docstring = collate.__doc__ + collate = partial(merge) + collate.__doc__ = _collate_docstring + + def consumer(func): """Decorator that automatically advances a PEP-342-style "reverse iterator" to its first yield point so you don't have to call ``next()`` on it @@ -434,13 +425,11 @@ def consumer(func): ``t.send()`` could be used. """ - @wraps(func) def wrapper(*args, **kwargs): gen = func(*args, **kwargs) next(gen) return gen - return wrapper @@ -464,9 +453,9 @@ def ilen(iterable): def iterate(func, start): """Return ``start``, ``func(start)``, ``func(func(start))``, ... - >>> from itertools import islice - >>> list(islice(iterate(lambda x: 2*x, 1), 10)) - [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + >>> from itertools import islice + >>> list(islice(iterate(lambda x: 2*x, 1), 10)) + [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] """ while True: @@ -486,7 +475,8 @@ def with_iter(context_manager): """ with context_manager as iterable: - yield from iterable + for item in iterable: + yield item def one(iterable, too_short=None, too_long=None): @@ -520,8 +510,7 @@ def one(iterable, too_short=None, too_long=None): >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... - ValueError: Expected exactly one item in iterable, but got 'too', - 'many', and perhaps more. 
+ ValueError: too many items in iterable (expected 1)' >>> too_long = RuntimeError >>> one(it, too_long=too_long) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): @@ -529,112 +518,29 @@ def one(iterable, too_short=None, too_long=None): RuntimeError Note that :func:`one` attempts to advance *iterable* twice to ensure there - is only one item. See :func:`spy` or :func:`peekable` to check iterable - contents less destructively. + is only one item. If there is more than one, both items will be discarded. + See :func:`spy` or :func:`peekable` to check iterable contents less + destructively. """ it = iter(iterable) try: - first_value = next(it) - except StopIteration as e: - raise ( - too_short or ValueError('too few items in iterable (expected 1)') - ) from e - - try: - second_value = next(it) + value = next(it) except StopIteration: - pass - else: - msg = ( - 'Expected exactly one item in iterable, but got {!r}, {!r}, ' - 'and perhaps more.'.format(first_value, second_value) - ) - raise too_long or ValueError(msg) - - return first_value - - -def raise_(exception, *args): - raise exception(*args) - - -def strictly_n(iterable, n, too_short=None, too_long=None): - """Validate that *iterable* has exactly *n* items and return them if - it does. If it has fewer than *n* items, call function *too_short* - with those items. If it has more than *n* items, call function - *too_long* with the first ``n + 1`` items. - - >>> iterable = ['a', 'b', 'c', 'd'] - >>> n = 4 - >>> list(strictly_n(iterable, n)) - ['a', 'b', 'c', 'd'] - - By default, *too_short* and *too_long* are functions that raise - ``ValueError``. - - >>> list(strictly_n('ab', 3)) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError: too few items in iterable (got 2) - - >>> list(strictly_n('abc', 2)) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError: too many items in iterable (got at least 3) - - You can instead supply functions that do something else. - *too_short* will be called with the number of items in *iterable*. - *too_long* will be called with `n + 1`. - - >>> def too_short(item_count): - ... raise RuntimeError - >>> it = strictly_n('abcd', 6, too_short=too_short) - >>> list(it) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - RuntimeError - - >>> def too_long(item_count): - ... print('The boss is going to hear about this') - >>> it = strictly_n('abcdef', 4, too_long=too_long) - >>> list(it) - The boss is going to hear about this - ['a', 'b', 'c', 'd'] - - """ - if too_short is None: - too_short = lambda item_count: raise_( - ValueError, - 'Too few items in iterable (got {})'.format(item_count), - ) - - if too_long is None: - too_long = lambda item_count: raise_( - ValueError, - 'Too many items in iterable (got at least {})'.format(item_count), - ) - - it = iter(iterable) - for i in range(n): - try: - item = next(it) - except StopIteration: - too_short(i) - return - else: - yield item + raise too_short or ValueError('too few items in iterable (expected 1)') try: next(it) except StopIteration: pass else: - too_long(n + 1) + raise too_long or ValueError('too many items in iterable (expected 1)') + + return value -def distinct_permutations(iterable, r=None): +def distinct_permutations(iterable): """Yield successive distinct permutations of the elements in *iterable*. 
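A sketch of how the one() helper changed above is typically used; rows and the id filter are hypothetical:

    from more_itertools import one

    matches = (row for row in rows if row['id'] == 42)   # hypothetical single-result query
    record = one(matches)   # returns the sole item, or raises ValueError for zero or multiple items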
>>> sorted(distinct_permutations([1, 0, 1])) @@ -650,88 +556,34 @@ def distinct_permutations(iterable, r=None): items input, and each `x_i` is the count of a distinct item in the input sequence. - If *r* is given, only the *r*-length permutations are yielded. - - >>> sorted(distinct_permutations([1, 0, 1], r=2)) - [(0, 1), (1, 0), (1, 1)] - >>> sorted(distinct_permutations(range(3), r=2)) - [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)] - """ - # Algorithm: https://w.wiki/Qai - def _full(A): - while True: - # Yield the permutation we have - yield tuple(A) + def perm_unique_helper(item_counts, perm, i): + """Internal helper function - # Find the largest index i such that A[i] < A[i + 1] - for i in range(size - 2, -1, -1): - if A[i] < A[i + 1]: - break - # If no such index exists, this permutation is the last one - else: - return + :arg item_counts: Stores the unique items in ``iterable`` and how many + times they are repeated + :arg perm: The permutation that is being built for output + :arg i: The index of the permutation being modified - # Find the largest index j greater than j such that A[i] < A[j] - for j in range(size - 1, i, -1): - if A[i] < A[j]: - break + The output permutations are built up recursively; the distinct items + are placed until their repetitions are exhausted. + """ + if i < 0: + yield tuple(perm) + else: + for item in item_counts: + if item_counts[item] <= 0: + continue + perm[i] = item + item_counts[item] -= 1 + for x in perm_unique_helper(item_counts, perm, i - 1): + yield x + item_counts[item] += 1 - # Swap the value of A[i] with that of A[j], then reverse the - # sequence from A[i + 1] to form the new permutation - A[i], A[j] = A[j], A[i] - A[i + 1 :] = A[: i - size : -1] # A[i + 1:][::-1] + item_counts = Counter(iterable) + length = sum(item_counts.values()) - # Algorithm: modified from the above - def _partial(A, r): - # Split A into the first r items and the last r items - head, tail = A[:r], A[r:] - right_head_indexes = range(r - 1, -1, -1) - left_tail_indexes = range(len(tail)) - - while True: - # Yield the permutation we have - yield tuple(head) - - # Starting from the right, find the first index of the head with - # value smaller than the maximum value of the tail - call it i. - pivot = tail[-1] - for i in right_head_indexes: - if head[i] < pivot: - break - pivot = head[i] - else: - return - - # Starting from the left, find the first value of the tail - # with a value greater than head[i] and swap. - for j in left_tail_indexes: - if tail[j] > head[i]: - head[i], tail[j] = tail[j], head[i] - break - # If we didn't find one, start from the right and find the first - # index of the head with a value greater than head[i] and swap. - else: - for j in right_head_indexes: - if head[j] > head[i]: - head[i], head[j] = head[j], head[i] - break - - # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)] - tail += head[: i - r : -1] # head[i + 1:][::-1] - i += 1 - head[i:], tail[:] = tail[: r - i], tail[r - i :] - - items = sorted(iterable) - - size = len(items) - if r is None: - r = size - - if 0 < r <= size: - return _full(items) if (r == size) else _partial(items, r) - - return iter(() if r else ((),)) + return perm_unique_helper(item_counts, [None] * length, length - 1) def intersperse(e, iterable, n=1): @@ -748,8 +600,8 @@ def intersperse(e, iterable, n=1): if n == 0: raise ValueError('n must be > 0') elif n == 1: - # interleave(repeat(e), iterable) -> e, x_0, e, x_1, e, x_2... - # islice(..., 1, None) -> x_0, e, x_1, e, x_2... 
+ # interleave(repeat(e), iterable) -> e, x_0, e, e, x_1, e, x_2... + # islice(..., 1, None) -> x_0, e, e, x_1, e, x_2... return islice(interleave(repeat(e), iterable), 1, None) else: # interleave(filler, chunks) -> [e], [x_0, x_1], [e], [x_2, x_3]... @@ -801,7 +653,7 @@ def windowed(seq, n, fillvalue=None, step=1): [(1, 2, 3), (2, 3, 4), (3, 4, 5)] When the window is larger than the iterable, *fillvalue* is used in place - of missing values: + of missing values:: >>> list(windowed([1, 2, 3], 4)) [(1, 2, 3, None)] @@ -811,14 +663,6 @@ def windowed(seq, n, fillvalue=None, step=1): >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2)) [(1, 2, 3), (3, 4, 5), (5, 6, '!')] - To slide into the iterable's items, use :func:`chain` to add filler items - to the left: - - >>> iterable = [1, 2, 3, 4] - >>> n = 3 - >>> padding = [None] * (n - 1) - >>> list(windowed(chain(padding, iterable), 3)) - [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)] """ if n < 0: raise ValueError('n must be >= 0') @@ -828,25 +672,32 @@ def windowed(seq, n, fillvalue=None, step=1): if step < 1: raise ValueError('step must be >= 1') - window = deque(maxlen=n) - i = n - for _ in map(window.append, seq): - i -= 1 - if not i: - i = step + it = iter(seq) + window = deque([], n) + append = window.append + + # Initial deque fill + for _ in range(n): + append(next(it, fillvalue)) + yield tuple(window) + + # Appending new items to the right causes old items to fall off the left + i = 0 + for item in it: + append(item) + i = (i + 1) % step + if i % step == 0: yield tuple(window) - size = len(window) - if size == 0: - return - elif size < n: - yield tuple(chain(window, repeat(fillvalue, n - size))) - elif 0 < i < min(step, n): - window += (fillvalue,) * i + # If there are items from the iterable in the window, pad with the given + # value and emit them. + if (i % step) and (step - i < n): + for _ in range(step - i): + append(fillvalue) yield tuple(window) -def substrings(iterable): +def substrings(iterable, join_func=None): """Yield all of the substrings of *iterable*. >>> [''.join(s) for s in substrings('more')] @@ -869,51 +720,15 @@ def substrings(iterable): # And the rest for n in range(2, item_count + 1): for i in range(item_count - n + 1): - yield seq[i : i + n] + yield seq[i:i + n] -def substrings_indexes(seq, reverse=False): - """Yield all substrings and their positions in *seq* - - The items yielded will be a tuple of the form ``(substr, i, j)``, where - ``substr == seq[i:j]``. - - This function only works for iterables that support slicing, such as - ``str`` objects. - - >>> for item in substrings_indexes('more'): - ... print(item) - ('m', 0, 1) - ('o', 1, 2) - ('r', 2, 3) - ('e', 3, 4) - ('mo', 0, 2) - ('or', 1, 3) - ('re', 2, 4) - ('mor', 0, 3) - ('ore', 1, 4) - ('more', 0, 4) - - Set *reverse* to ``True`` to yield the same items in the opposite order. - - - """ - r = range(1, len(seq) + 1) - if reverse: - r = reversed(r) - return ( - (seq[i : i + L], i, i + L) for L in r for i in range(len(seq) - L + 1) - ) - - -class bucket: +class bucket(object): """Wrap *iterable* and return an object that buckets it iterable into child iterables based on a *key* function. 
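A small example of the sliding-window behaviour from the windowed() hunk above; the readings list is made up:

    from more_itertools import windowed

    readings = [3, 4, 7, 5, 9]                           # hypothetical sensor samples
    averages = [sum(w) / 3 for w in windowed(readings, 3)]
    # approximately [4.67, 5.33, 7.0]: one mean per 3-sample window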
>>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3'] - >>> s = bucket(iterable, key=lambda x: x[0]) # Bucket by 1st character - >>> sorted(list(s)) # Get the keys - ['a', 'b', 'c'] + >>> s = bucket(iterable, key=lambda x: x[0]) >>> a_iterable = s['a'] >>> next(a_iterable) 'a1' @@ -941,7 +756,6 @@ class bucket: [] """ - def __init__(self, iterable, key, validator=None): self._it = iter(iterable) self._key = key @@ -987,14 +801,6 @@ class bucket: elif self._validator(item_value): self._cache[item_value].append(item) - def __iter__(self): - for item in self._it: - item_value = self._key(item) - if self._validator(item_value): - self._cache[item_value].append(item) - - yield from self._cache.keys() - def __getitem__(self, value): if not self._validator(value): return iter(()) @@ -1042,7 +848,7 @@ def spy(iterable, n=1): it = iter(iterable) head = take(n, it) - return head.copy(), chain(head, it) + return head, chain(head, it) def interleave(*iterables): @@ -1075,72 +881,6 @@ def interleave_longest(*iterables): return (x for x in i if x is not _marker) -def interleave_evenly(iterables, lengths=None): - """ - Interleave multiple iterables so that their elements are evenly distributed - throughout the output sequence. - - >>> iterables = [1, 2, 3, 4, 5], ['a', 'b'] - >>> list(interleave_evenly(iterables)) - [1, 2, 'a', 3, 4, 'b', 5] - - >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]] - >>> list(interleave_evenly(iterables)) - [1, 6, 4, 2, 7, 3, 8, 5] - - This function requires iterables of known length. Iterables without - ``__len__()`` can be used by manually specifying lengths with *lengths*: - - >>> from itertools import combinations, repeat - >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']] - >>> lengths = [4 * (4 - 1) // 2, 3] - >>> list(interleave_evenly(iterables, lengths=lengths)) - [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c'] - - Based on Bresenham's algorithm. - """ - if lengths is None: - try: - lengths = [len(it) for it in iterables] - except TypeError: - raise ValueError( - 'Iterable lengths could not be determined automatically. ' - 'Specify them with the lengths keyword.' - ) - elif len(iterables) != len(lengths): - raise ValueError('Mismatching number of iterables and lengths.') - - dims = len(lengths) - - # sort iterables by length, descending - lengths_permute = sorted( - range(dims), key=lambda i: lengths[i], reverse=True - ) - lengths_desc = [lengths[i] for i in lengths_permute] - iters_desc = [iter(iterables[i]) for i in lengths_permute] - - # the longest iterable is the primary one (Bresenham: the longest - # distance along an axis) - delta_primary, deltas_secondary = lengths_desc[0], lengths_desc[1:] - iter_primary, iters_secondary = iters_desc[0], iters_desc[1:] - errors = [delta_primary // dims] * len(deltas_secondary) - - to_yield = sum(lengths) - while to_yield: - yield next(iter_primary) - to_yield -= 1 - # update errors for each secondary iterable - errors = [e - delta for e, delta in zip(errors, deltas_secondary)] - - # those iterables for which the error is negative are yielded - # ("diagonal step" in Bresenham) - for i, e in enumerate(errors): - if e < 0: - yield next(iters_secondary[i]) - to_yield -= 1 - errors[i] += delta_primary - - def collapse(iterable, base_type=None, levels=None): """Flatten an iterable with multiple levels of nesting (e.g., a list of lists of tuples) into non-iterable types. 
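A usage sketch for the bucket class defined above; the file names are arbitrary examples:

    from more_itertools import bucket

    files = ['a.py', 'b.txt', 'c.py', 'd.cfg']           # hypothetical file names
    by_ext = bucket(files, key=lambda name: name.rsplit('.', 1)[-1])
    list(by_ext['py'])     # ['a.py', 'c.py']
    list(by_ext['txt'])    # ['b.txt']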
@@ -1149,9 +889,7 @@ def collapse(iterable, base_type=None, levels=None): >>> list(collapse(iterable)) [1, 2, 3, 4, 5, 6] - Binary and text strings are not considered iterable and - will not be collapsed. - + String types are not considered iterable and will not be collapsed. To avoid collapsing other types, specify *base_type*: >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']] @@ -1167,12 +905,11 @@ def collapse(iterable, base_type=None, levels=None): ['a', ['b'], 'c', ['d']] """ - def walk(node, level): if ( - ((levels is not None) and (level > levels)) - or isinstance(node, (str, bytes)) - or ((base_type is not None) and isinstance(node, base_type)) + ((levels is not None) and (level > levels)) or + isinstance(node, string_types) or + ((base_type is not None) and isinstance(node, base_type)) ): yield node return @@ -1184,9 +921,11 @@ def collapse(iterable, base_type=None, levels=None): return else: for child in tree: - yield from walk(child, level + 1) + for x in walk(child, level + 1): + yield x - yield from walk(iterable, 0) + for x in walk(iterable, 0): + yield x def side_effect(func, iterable, chunk_size=None, before=None, after=None): @@ -1244,93 +983,56 @@ def side_effect(func, iterable, chunk_size=None, before=None, after=None): else: for chunk in chunked(iterable, chunk_size): func(chunk) - yield from chunk + for item in chunk: + yield item finally: if after is not None: after() -def sliced(seq, n, strict=False): +def sliced(seq, n): """Yield slices of length *n* from the sequence *seq*. - >>> list(sliced((1, 2, 3, 4, 5, 6), 3)) - [(1, 2, 3), (4, 5, 6)] + >>> list(sliced((1, 2, 3, 4, 5, 6), 3)) + [(1, 2, 3), (4, 5, 6)] - By the default, the last yielded slice will have fewer than *n* elements - if the length of *seq* is not divisible by *n*: + If the length of the sequence is not divisible by the requested slice + length, the last slice will be shorter. - >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3)) - [(1, 2, 3), (4, 5, 6), (7, 8)] - - If the length of *seq* is not divisible by *n* and *strict* is - ``True``, then ``ValueError`` will be raised before the last - slice is yielded. + >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3)) + [(1, 2, 3), (4, 5, 6), (7, 8)] This function will only work for iterables that support slicing. For non-sliceable iterables, see :func:`chunked`. """ - iterator = takewhile(len, (seq[i : i + n] for i in count(0, n))) - if strict: - - def ret(): - for _slice in iterator: - if len(_slice) != n: - raise ValueError("seq is not divisible by n.") - yield _slice - - return iter(ret()) - else: - return iterator + return takewhile(bool, (seq[i: i + n] for i in count(0, n))) -def split_at(iterable, pred, maxsplit=-1, keep_separator=False): +def split_at(iterable, pred): """Yield lists of items from *iterable*, where each list is delimited by - an item where callable *pred* returns ``True``. + an item where callable *pred* returns ``True``. The lists do not include + the delimiting items. >>> list(split_at('abcdcba', lambda x: x == 'b')) [['a'], ['c', 'd', 'c'], ['a']] >>> list(split_at(range(10), lambda n: n % 2 == 1)) [[0], [2], [4], [6], [8], []] - - At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, - then there is no limit on the number of splits: - - >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2)) - [[0], [2], [4, 5, 6, 7, 8, 9]] - - By default, the delimiting items are not included in the output. - The include them, set *keep_separator* to ``True``. 
- - >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True)) - [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']] - """ - if maxsplit == 0: - yield list(iterable) - return - buf = [] - it = iter(iterable) - for item in it: + for item in iterable: if pred(item): yield buf - if keep_separator: - yield [item] - if maxsplit == 1: - yield list(it) - return buf = [] - maxsplit -= 1 else: buf.append(item) yield buf -def split_before(iterable, pred, maxsplit=-1): - """Yield lists of items from *iterable*, where each list ends just before - an item for which callable *pred* returns ``True``: +def split_before(iterable, pred): + """Yield lists of items from *iterable*, where each list starts with an + item where callable *pred* returns ``True``: >>> list(split_before('OneTwo', lambda s: s.isupper())) [['O', 'n', 'e'], ['T', 'w', 'o']] @@ -1338,32 +1040,17 @@ def split_before(iterable, pred, maxsplit=-1): >>> list(split_before(range(10), lambda n: n % 3 == 0)) [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] - At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, - then there is no limit on the number of splits: - - >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2)) - [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]] """ - if maxsplit == 0: - yield list(iterable) - return - buf = [] - it = iter(iterable) - for item in it: + for item in iterable: if pred(item) and buf: yield buf - if maxsplit == 1: - yield [item] + list(it) - return buf = [] - maxsplit -= 1 buf.append(item) - if buf: - yield buf + yield buf -def split_after(iterable, pred, maxsplit=-1): +def split_after(iterable, pred): """Yield lists of items from *iterable*, where each list ends with an item where callable *pred* returns ``True``: @@ -1373,77 +1060,17 @@ def split_after(iterable, pred, maxsplit=-1): >>> list(split_after(range(10), lambda n: n % 3 == 0)) [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]] - At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, - then there is no limit on the number of splits: - - >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2)) - [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]] - """ - if maxsplit == 0: - yield list(iterable) - return - buf = [] - it = iter(iterable) - for item in it: + for item in iterable: buf.append(item) if pred(item) and buf: yield buf - if maxsplit == 1: - yield list(it) - return buf = [] - maxsplit -= 1 if buf: yield buf -def split_when(iterable, pred, maxsplit=-1): - """Split *iterable* into pieces based on the output of *pred*. - *pred* should be a function that takes successive pairs of items and - returns ``True`` if the iterable should be split in between them. - - For example, to find runs of increasing numbers, split the iterable when - element ``i`` is larger than element ``i + 1``: - - >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y)) - [[1, 2, 3, 3], [2, 5], [2, 4], [2]] - - At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, - then there is no limit on the number of splits: - - >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], - ... 
lambda x, y: x > y, maxsplit=2)) - [[1, 2, 3, 3], [2, 5], [2, 4, 2]] - - """ - if maxsplit == 0: - yield list(iterable) - return - - it = iter(iterable) - try: - cur_item = next(it) - except StopIteration: - return - - buf = [cur_item] - for next_item in it: - if pred(cur_item, next_item): - yield buf - if maxsplit == 1: - yield [next_item] + list(it) - return - buf = [] - maxsplit -= 1 - - buf.append(next_item) - cur_item = next_item - - yield buf - - def split_into(iterable, sizes): """Yield a list of sequential items from *iterable* of length 'n' for each integer 'n' in *sizes*. @@ -1507,7 +1134,8 @@ def padded(iterable, fillvalue=None, n=None, next_multiple=False): """ it = iter(iterable) if n is None: - yield from chain(it, repeat(fillvalue)) + for item in chain(it, repeat(fillvalue)): + yield item elif n < 1: raise ValueError('n must be at least 1') else: @@ -1521,34 +1149,6 @@ def padded(iterable, fillvalue=None, n=None, next_multiple=False): yield fillvalue -def repeat_each(iterable, n=2): - """Repeat each element in *iterable* *n* times. - - >>> list(repeat_each('ABC', 3)) - ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C'] - """ - return chain.from_iterable(map(repeat, iterable, repeat(n))) - - -def repeat_last(iterable, default=None): - """After the *iterable* is exhausted, keep yielding its last element. - - >>> list(islice(repeat_last(range(3)), 5)) - [0, 1, 2, 2, 2] - - If the iterable is empty, yield *default* forever:: - - >>> list(islice(repeat_last(range(0), 42), 5)) - [42, 42, 42, 42, 42] - - """ - item = _marker - for item in iterable: - yield item - final = default if item is _marker else item - yield from repeat(final) - - def distribute(n, iterable): """Distribute the items from *iterable* among *n* smaller iterables. @@ -1612,38 +1212,7 @@ def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None): ) -def zip_equal(*iterables): - """``zip`` the input *iterables* together, but raise - ``UnequalIterablesError`` if they aren't all the same length. - - >>> it_1 = range(3) - >>> it_2 = iter('abc') - >>> list(zip_equal(it_1, it_2)) - [(0, 'a'), (1, 'b'), (2, 'c')] - - >>> it_1 = range(3) - >>> it_2 = iter('abcd') - >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - more_itertools.more.UnequalIterablesError: Iterables have different - lengths - - """ - if hexversion >= 0x30A00A6: - warnings.warn( - ( - 'zip_equal will be removed in a future version of ' - 'more-itertools. Use the builtin zip function with ' - 'strict=True instead.' - ), - DeprecationWarning, - ) - - return _zip_equal(*iterables) - - -def zip_offset(*iterables, offsets, longest=False, fillvalue=None): +def zip_offset(*iterables, **kwargs): """``zip`` the input *iterables* together, but offset the `i`-th iterable by the `i`-th item in *offsets*. @@ -1664,6 +1233,10 @@ def zip_offset(*iterables, offsets, longest=False, fillvalue=None): sequence. Specify *fillvalue* to use some other value. """ + offsets = kwargs['offsets'] + longest = kwargs.get('longest', False) + fillvalue = kwargs.get('fillvalue', None) + if len(iterables) != len(offsets): raise ValueError("Number of iterables and offsets didn't match") @@ -1682,7 +1255,7 @@ def zip_offset(*iterables, offsets, longest=False, fillvalue=None): return zip(*staggered) -def sort_together(iterables, key_list=(0,), key=None, reverse=False): +def sort_together(iterables, key_list=(0,), reverse=False): """Return the input iterables sorted together, with *key_list* as the priority for sorting. 
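A sketch of the offset pairing that zip_offset() (changed above) performs; here a string is zipped against a shifted copy of itself:

    from more_itertools import zip_offset

    list(zip_offset('abcde', 'abcde', offsets=(0, 1)))
    # [('a', 'b'), ('b', 'c'), ('c', 'd'), ('d', 'e')]: truncated at the shortest sequence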
All iterables are trimmed to the length of the shortest one. @@ -1704,48 +1277,15 @@ def sort_together(iterables, key_list=(0,), key=None, reverse=False): >>> sort_together(iterables, key_list=(1, 2)) [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')] - To sort by a function of the elements of the iterable, pass a *key* - function. Its arguments are the elements of the iterables corresponding to - the key list:: - - >>> names = ('a', 'b', 'c') - >>> lengths = (1, 2, 3) - >>> widths = (5, 2, 1) - >>> def area(length, width): - ... return length * width - >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area) - [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)] - Set *reverse* to ``True`` to sort in descending order. >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True) [(3, 2, 1), ('a', 'b', 'c')] """ - if key is None: - # if there is no key function, the key argument to sorted is an - # itemgetter - key_argument = itemgetter(*key_list) - else: - # if there is a key function, call it with the items at the offsets - # specified by the key function as arguments - key_list = list(key_list) - if len(key_list) == 1: - # if key_list contains a single item, pass the item at that offset - # as the only argument to the key function - key_offset = key_list[0] - key_argument = lambda zipped_items: key(zipped_items[key_offset]) - else: - # if key_list contains multiple items, use itemgetter to return a - # tuple of items, which we pass as *args to the key function - get_key_items = itemgetter(*key_list) - key_argument = lambda zipped_items: key( - *get_key_items(zipped_items) - ) - - return list( - zip(*sorted(zip(*iterables), key=key_argument, reverse=reverse)) - ) + return list(zip(*sorted(zip(*iterables), + key=itemgetter(*key_list), + reverse=reverse))) def unzip(iterable): @@ -1753,7 +1293,7 @@ def unzip(iterable): of the zipped *iterable*. The ``i``-th iterable contains the ``i``-th element from each element - of the zipped iterable. The first element is used to determine the + of the zipped iterable. The first element is used to to determine the length of the remaining elements. >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)] @@ -1791,7 +1331,6 @@ def unzip(iterable): # which just stops the unzipped iterables # at first length mismatch raise StopIteration - return getter return tuple(map(itemgetter(i), it) for i, it in enumerate(iterables)) @@ -1829,26 +1368,19 @@ def divide(n, iterable): if n < 1: raise ValueError('n must be at least 1') - try: - iterable[:0] - except TypeError: - seq = tuple(iterable) - else: - seq = iterable - + seq = tuple(iterable) q, r = divmod(len(seq), n) ret = [] - stop = 0 - for i in range(1, n + 1): - start = stop - stop += q + 1 if i <= r else q + for i in range(n): + start = (i * q) + (i if i < r else r) + stop = ((i + 1) * q) + (i + 1 if i + 1 < r else r) ret.append(iter(seq[start:stop])) return ret -def always_iterable(obj, base_type=(str, bytes)): +def always_iterable(obj, base_type=(text_type, binary_type)): """If *obj* is iterable, return an iterator over its items:: >>> obj = (1, 2, 3) @@ -1940,23 +1472,21 @@ def adjacent(predicate, iterable, distance=1): return zip(adjacent_to_selected, i2) -def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None): - """An extension of :func:`itertools.groupby` that can apply transformations - to the grouped data. +def groupby_transform(iterable, keyfunc=None, valuefunc=None): + """An extension of :func:`itertools.groupby` that transforms the values of + *iterable* after grouping them. 
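A quick example of the parallel sorting the sort_together() docstring above describes; the names and scores are invented:

    from more_itertools import sort_together

    names = ['carol', 'alice', 'bob']                    # hypothetical data
    scores = [72, 95, 81]
    sort_together([scores, names], reverse=True)
    # [(95, 81, 72), ('alice', 'bob', 'carol')]: both sequences reordered by score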
+ *keyfunc* is a function used to compute a grouping key for each item. + *valuefunc* is a function for transforming the items after grouping. - * *keyfunc* is a function computing a key value for each item in *iterable* - * *valuefunc* is a function that transforms the individual items from - *iterable* after grouping - * *reducefunc* is a function that transforms each group of items + >>> iterable = 'AaaABbBCcA' + >>> keyfunc = lambda x: x.upper() + >>> valuefunc = lambda x: x.lower() + >>> grouper = groupby_transform(iterable, keyfunc, valuefunc) + >>> [(k, ''.join(g)) for k, g in grouper] + [('A', 'aaaa'), ('B', 'bbb'), ('C', 'cc'), ('A', 'a')] - >>> iterable = 'aAAbBBcCC' - >>> keyfunc = lambda k: k.upper() - >>> valuefunc = lambda v: v.lower() - >>> reducefunc = lambda g: ''.join(g) - >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc)) - [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')] - - Each optional argument defaults to an identity function if not specified. + *keyfunc* and *valuefunc* default to identity functions if they are not + specified. :func:`groupby_transform` is useful when grouping elements of an iterable using a separate iterable as the key. To do this, :func:`zip` the iterables @@ -1976,16 +1506,11 @@ def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None): duplicate groups, you should sort the iterable by the key function. """ - ret = groupby(iterable, keyfunc) - if valuefunc: - ret = ((k, map(valuefunc, g)) for k, g in ret) - if reducefunc: - ret = ((k, reducefunc(g)) for k, g in ret) - - return ret + valuefunc = (lambda x: x) if valuefunc is None else valuefunc + return ((k, map(valuefunc, g)) for k, g in groupby(iterable, keyfunc)) -class numeric_range(abc.Sequence, abc.Hashable): +def numeric_range(*args): """An extension of the built-in ``range()`` function whose arguments can be any orderable numeric type. @@ -2022,184 +1547,28 @@ class numeric_range(abc.Sequence, abc.Hashable): Be aware of the limitations of floating point numbers; the representation of the yielded numbers may be surprising. 
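A small sketch of the non-integer ranges the numeric_range docstring above refers to, using Decimal steps so the endpoints stay exact:

    from decimal import Decimal
    from more_itertools import numeric_range

    list(numeric_range(Decimal('0'), Decimal('1'), Decimal('0.25')))
    # yields Decimal values 0, 0.25, 0.5 and 0.75; the stop value 1 is excluded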
- ``datetime.datetime`` objects can be used for *start* and *stop*, if *step* - is a ``datetime.timedelta`` object: - - >>> import datetime - >>> start = datetime.datetime(2019, 1, 1) - >>> stop = datetime.datetime(2019, 1, 3) - >>> step = datetime.timedelta(days=1) - >>> items = iter(numeric_range(start, stop, step)) - >>> next(items) - datetime.datetime(2019, 1, 1, 0, 0) - >>> next(items) - datetime.datetime(2019, 1, 2, 0, 0) - """ + argc = len(args) + if argc == 1: + stop, = args + start = type(stop)(0) + step = 1 + elif argc == 2: + start, stop = args + step = 1 + elif argc == 3: + start, stop, step = args + else: + err_msg = 'numeric_range takes at most 3 arguments, got {}' + raise TypeError(err_msg.format(argc)) - _EMPTY_HASH = hash(range(0, 0)) - - def __init__(self, *args): - argc = len(args) - if argc == 1: - (self._stop,) = args - self._start = type(self._stop)(0) - self._step = type(self._stop - self._start)(1) - elif argc == 2: - self._start, self._stop = args - self._step = type(self._stop - self._start)(1) - elif argc == 3: - self._start, self._stop, self._step = args - elif argc == 0: - raise TypeError( - 'numeric_range expected at least ' - '1 argument, got {}'.format(argc) - ) - else: - raise TypeError( - 'numeric_range expected at most ' - '3 arguments, got {}'.format(argc) - ) - - self._zero = type(self._step)(0) - if self._step == self._zero: - raise ValueError('numeric_range() arg 3 must not be zero') - self._growing = self._step > self._zero - self._init_len() - - def __bool__(self): - if self._growing: - return self._start < self._stop - else: - return self._start > self._stop - - def __contains__(self, elem): - if self._growing: - if self._start <= elem < self._stop: - return (elem - self._start) % self._step == self._zero - else: - if self._start >= elem > self._stop: - return (self._start - elem) % (-self._step) == self._zero - - return False - - def __eq__(self, other): - if isinstance(other, numeric_range): - empty_self = not bool(self) - empty_other = not bool(other) - if empty_self or empty_other: - return empty_self and empty_other # True if both empty - else: - return ( - self._start == other._start - and self._step == other._step - and self._get_by_index(-1) == other._get_by_index(-1) - ) - else: - return False - - def __getitem__(self, key): - if isinstance(key, int): - return self._get_by_index(key) - elif isinstance(key, slice): - step = self._step if key.step is None else key.step * self._step - - if key.start is None or key.start <= -self._len: - start = self._start - elif key.start >= self._len: - start = self._stop - else: # -self._len < key.start < self._len - start = self._get_by_index(key.start) - - if key.stop is None or key.stop >= self._len: - stop = self._stop - elif key.stop <= -self._len: - stop = self._start - else: # -self._len < key.stop < self._len - stop = self._get_by_index(key.stop) - - return numeric_range(start, stop, step) - else: - raise TypeError( - 'numeric range indices must be ' - 'integers or slices, not {}'.format(type(key).__name__) - ) - - def __hash__(self): - if self: - return hash((self._start, self._get_by_index(-1), self._step)) - else: - return self._EMPTY_HASH - - def __iter__(self): - values = (self._start + (n * self._step) for n in count()) - if self._growing: - return takewhile(partial(gt, self._stop), values) - else: - return takewhile(partial(lt, self._stop), values) - - def __len__(self): - return self._len - - def _init_len(self): - if self._growing: - start = self._start - stop = self._stop - step = 
self._step - else: - start = self._stop - stop = self._start - step = -self._step - distance = stop - start - if distance <= self._zero: - self._len = 0 - else: # distance > 0 and step > 0: regular euclidean division - q, r = divmod(distance, step) - self._len = int(q) + int(r != self._zero) - - def __reduce__(self): - return numeric_range, (self._start, self._stop, self._step) - - def __repr__(self): - if self._step == 1: - return "numeric_range({}, {})".format( - repr(self._start), repr(self._stop) - ) - else: - return "numeric_range({}, {}, {})".format( - repr(self._start), repr(self._stop), repr(self._step) - ) - - def __reversed__(self): - return iter( - numeric_range( - self._get_by_index(-1), self._start - self._step, -self._step - ) - ) - - def count(self, value): - return int(value in self) - - def index(self, value): - if self._growing: - if self._start <= value < self._stop: - q, r = divmod(value - self._start, self._step) - if r == self._zero: - return int(q) - else: - if self._start >= value > self._stop: - q, r = divmod(self._start - value, -self._step) - if r == self._zero: - return int(q) - - raise ValueError("{} is not in numeric range".format(value)) - - def _get_by_index(self, i): - if i < 0: - i += self._len - if i < 0 or i >= self._len: - raise IndexError("numeric range object index out of range") - return self._start + i * self._step + values = (start + (step * n) for n in count()) + if step > 0: + return takewhile(partial(gt, stop), values) + elif step < 0: + return takewhile(partial(lt, stop), values) + else: + raise ValueError('numeric_range arg 3 must not be zero') def count_cycle(iterable, n=None): @@ -2218,43 +1587,6 @@ def count_cycle(iterable, n=None): return ((i, item) for i in counter for item in iterable) -def mark_ends(iterable): - """Yield 3-tuples of the form ``(is_first, is_last, item)``. - - >>> list(mark_ends('ABC')) - [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')] - - Use this when looping over an iterable to take special action on its first - and/or last items: - - >>> iterable = ['Header', 100, 200, 'Footer'] - >>> total = 0 - >>> for is_first, is_last, item in mark_ends(iterable): - ... if is_first: - ... continue # Skip the header - ... if is_last: - ... continue # Skip the footer - ... total += item - >>> print(total) - 300 - """ - it = iter(iterable) - - try: - b = next(it) - except StopIteration: - return - - try: - for i in count(): - a = b - b = next(it) - yield i == 0, False, a - - except StopIteration: - yield i == 0, True, a - - def locate(iterable, pred=bool, window_size=None): """Yield the index of each item in *iterable* for which *pred* returns ``True``. @@ -2303,16 +1635,6 @@ def locate(iterable, pred=bool, window_size=None): return compress(count(), starmap(pred, it)) -def longest_common_prefix(iterables): - """Yield elements of the longest common prefix amongst given *iterables*. - - >>> ''.join(longest_common_prefix(['abcd', 'abc', 'abf'])) - 'ab' - - """ - return (c[0] for c in takewhile(all_equal, zip(*iterables))) - - def lstrip(iterable, pred): """Yield the items from *iterable*, but strip any from the beginning for which *pred* returns ``True``. 
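A short example of the index lookup that locate() (above) provides; the list of flags is arbitrary:

    from more_itertools import locate

    flags = [0, 1, 1, 0, 1, 0]
    list(locate(flags))                      # [1, 2, 4]: positions of truthy items
    list(locate(flags, lambda x: x == 0))    # [0, 3, 5]: positions matching a predicate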
@@ -2347,13 +1669,13 @@ def rstrip(iterable, pred): """ cache = [] cache_append = cache.append - cache_clear = cache.clear for x in iterable: if pred(x): cache_append(x) else: - yield from cache - cache_clear() + for y in cache: + yield y + del cache[:] yield x @@ -2374,7 +1696,7 @@ def strip(iterable, pred): return rstrip(lstrip(iterable, pred), pred) -class islice_extended: +def islice_extended(iterable, *args): """An extension of :func:`itertools.islice` that supports negative values for *stop*, *start*, and *step*. @@ -2391,46 +1713,20 @@ class islice_extended: >>> list(islice_extended(count(), 110, 99, -2)) [110, 108, 106, 104, 102, 100] - You can also use slice notation directly: - - >>> iterable = map(str, count()) - >>> it = islice_extended(iterable)[10:20:2] - >>> list(it) - ['10', '12', '14', '16', '18'] - """ - - def __init__(self, iterable, *args): - it = iter(iterable) - if args: - self._iterable = _islice_helper(it, slice(*args)) - else: - self._iterable = it - - def __iter__(self): - return self - - def __next__(self): - return next(self._iterable) - - def __getitem__(self, key): - if isinstance(key, slice): - return islice_extended(_islice_helper(self._iterable, key)) - - raise TypeError('islice_extended.__getitem__ argument must be a slice') - - -def _islice_helper(it, s): + s = slice(*args) start = s.start stop = s.stop if s.step == 0: raise ValueError('step argument must be a non-zero integer or None.') step = s.step or 1 + it = iter(iterable) + if step > 0: start = 0 if (start is None) else start - if start < 0: + if (start < 0): # Consume all but the last -start items cache = deque(enumerate(it, 1), maxlen=-start) len_iter = cache[-1][0] if cache else 0 @@ -2468,7 +1764,8 @@ def _islice_helper(it, s): cache.append(item) else: # When both start and stop are positive we have the normal case - yield from islice(it, start, stop, step) + for item in islice(it, start, stop, step): + yield item else: start = -1 if (start is None) else start @@ -2513,7 +1810,8 @@ def _islice_helper(it, s): cache = list(islice(it, n)) - yield from cache[i::step] + for item in cache[i::step]: + yield item def always_reversible(iterable): @@ -2564,17 +1862,6 @@ def consecutive_groups(iterable, ordering=lambda x: x): ['i'] ['l', 'm', 'n', 'o', 'p'] - Each group of consecutive items is an iterator that shares it source with - *iterable*. When an an output group is advanced, the previous group is - no longer available unless its elements are copied (e.g., into a ``list``). - - >>> iterable = [1, 2, 11, 12, 21, 22] - >>> saved_groups = [] - >>> for group in consecutive_groups(iterable): - ... saved_groups.append(list(group)) # Copy group elements - >>> saved_groups - [[1, 2], [11, 12], [21, 22]] - """ for k, g in groupby( enumerate(iterable), key=lambda x: x[0] - ordering(x[1]) @@ -2582,46 +1869,42 @@ def consecutive_groups(iterable, ordering=lambda x: x): yield map(itemgetter(1), g) -def difference(iterable, func=sub, *, initial=None): - """This function is the inverse of :func:`itertools.accumulate`. By default - it will compute the first difference of *iterable* using - :func:`operator.sub`: +def difference(iterable, func=sub): + """By default, compute the first difference of *iterable* using + :func:`operator.sub`. 
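A sketch of the run detection performed by consecutive_groups() above; the numbers stand in for record IDs:

    from more_itertools import consecutive_groups

    ids = [1, 2, 3, 10, 11, 20]                          # hypothetical record IDs
    [list(g) for g in consecutive_groups(ids)]
    # [[1, 2, 3], [10, 11], [20]]: each run of consecutive integers becomes one group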
- >>> from itertools import accumulate - >>> iterable = accumulate([0, 1, 2, 3, 4]) # produces 0, 1, 3, 6, 10 + >>> iterable = [0, 1, 3, 6, 10] >>> list(difference(iterable)) [0, 1, 2, 3, 4] - *func* defaults to :func:`operator.sub`, but other functions can be + This is the opposite of :func:`accumulate`'s default behavior: + + >>> from more_itertools import accumulate + >>> iterable = [0, 1, 2, 3, 4] + >>> list(accumulate(iterable)) + [0, 1, 3, 6, 10] + >>> list(difference(accumulate(iterable))) + [0, 1, 2, 3, 4] + + By default *func* is :func:`operator.sub`, but other functions can be specified. They will be applied as follows:: A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ... For example, to do progressive division: - >>> iterable = [1, 2, 6, 24, 120] + >>> iterable = [1, 2, 6, 24, 120] # Factorial sequence >>> func = lambda x, y: x // y >>> list(difference(iterable, func)) [1, 2, 3, 4, 5] - If the *initial* keyword is set, the first element will be skipped when - computing successive differences. - - >>> it = [10, 11, 13, 16] # from accumulate([1, 2, 3], initial=10) - >>> list(difference(it, initial=10)) - [1, 2, 3] - """ a, b = tee(iterable) try: - first = [next(b)] + item = next(b) except StopIteration: return iter([]) - - if initial is not None: - first = [] - - return chain(first, map(func, b, a)) + return chain([item], map(lambda x: func(x[1], x[0]), zip(a, b))) class SequenceView(Sequence): @@ -2653,7 +1936,6 @@ class SequenceView(Sequence): require (much) extra storage. """ - def __init__(self, target): if not isinstance(target, Sequence): raise TypeError @@ -2669,7 +1951,7 @@ class SequenceView(Sequence): return '{}({})'.format(self.__class__.__name__, repr(self._target)) -class seekable: +class seekable(object): """Wrap an iterator to allow for seeking backward and forward. This progressively caches the items in the source iterable so they can be re-visited. @@ -2702,26 +1984,8 @@ class seekable: >>> next(it), next(it), next(it) ('0', '1', '2') - Call :meth:`peek` to look ahead one item without advancing the iterator: - - >>> it = seekable('1234') - >>> it.peek() - '1' - >>> list(it) - ['1', '2', '3', '4'] - >>> it.peek(default='empty') - 'empty' - - Before the iterator is at its end, calling :func:`bool` on it will return - ``True``. After it will return ``False``: - - >>> it = seekable('5678') - >>> bool(it) - True - >>> list(it) - ['5', '6', '7', '8'] - >>> bool(it) - False + The cache grows as the source iterable progresses, so beware of wrapping + very large or infinite iterables. You may view the contents of the cache with the :meth:`elements` method. That returns a :class:`SequenceView`, a view that updates automatically: @@ -2737,30 +2001,11 @@ class seekable: >>> elements SequenceView(['0', '1', '2', '3']) - By default, the cache grows as the source iterable progresses, so beware of - wrapping very large or infinite iterables. Supply *maxlen* to limit the - size of the cache (this of course limits how far back you can seek). 
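A round-trip sketch of the difference()/accumulate relationship described in the hunk above, using the standard-library accumulate on Python 3:

    from itertools import accumulate
    from more_itertools import difference

    totals = list(accumulate([10, 5, 30]))   # running totals: [10, 15, 45]
    list(difference(totals))                 # recovers the original deltas: [10, 5, 30]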
- - >>> from itertools import count - >>> it = seekable((str(n) for n in count()), maxlen=2) - >>> next(it), next(it), next(it), next(it) - ('0', '1', '2', '3') - >>> list(it.elements()) - ['2', '3'] - >>> it.seek(0) - >>> next(it), next(it), next(it), next(it) - ('2', '3', '4', '5') - >>> next(it) - '6' - """ - def __init__(self, iterable, maxlen=None): + def __init__(self, iterable): self._source = iter(iterable) - if maxlen is None: - self._cache = [] - else: - self._cache = deque([], maxlen) + self._cache = [] self._index = None def __iter__(self): @@ -2780,24 +2025,7 @@ class seekable: self._cache.append(item) return item - def __bool__(self): - try: - self.peek() - except StopIteration: - return False - return True - - def peek(self, default=_marker): - try: - peeked = next(self) - except StopIteration: - if default is _marker: - raise - return default - if self._index is None: - self._index = len(self._cache) - self._index -= 1 - return peeked + next = __next__ def elements(self): return SequenceView(self._cache) @@ -2809,7 +2037,7 @@ class seekable: consume(self, remainder) -class run_length: +class run_length(object): """ :func:`run_length.encode` compresses an iterable with run-length encoding. It yields groups of repeated items with the count of how many times they @@ -2859,8 +2087,8 @@ def exactly_n(iterable, n, predicate=bool): def circular_shifts(iterable): """Return a list of circular shifts of *iterable*. - >>> circular_shifts(range(4)) - [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)] + >>> circular_shifts(range(4)) + [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)] """ lst = list(iterable) return take(len(lst), windowed(cycle(lst), len(lst))) @@ -3034,7 +2262,9 @@ def rlocate(iterable, pred=bool, window_size=None): if window_size is None: try: len_iter = len(iterable) - return (len_iter - i - 1 for i in locate(reversed(iterable), pred)) + return ( + len_iter - i - 1 for i in locate(reversed(iterable), pred) + ) except TypeError: pass @@ -3092,7 +2322,8 @@ def replace(iterable, pred, substitutes, count=None, window_size=1): if pred(*w): if (count is None) or (n < count): n += 1 - yield from substitutes + for s in substitutes: + yield s consume(windows, window_size - 1) continue @@ -3100,1248 +2331,3 @@ def replace(iterable, pred, substitutes, count=None, window_size=1): # yield the first item from the window. if w and (w[0] is not _marker): yield w[0] - - -def partitions(iterable): - """Yield all possible order-preserving partitions of *iterable*. - - >>> iterable = 'abc' - >>> for part in partitions(iterable): - ... print([''.join(p) for p in part]) - ['abc'] - ['a', 'bc'] - ['ab', 'c'] - ['a', 'b', 'c'] - - This is unrelated to :func:`partition`. - - """ - sequence = list(iterable) - n = len(sequence) - for i in powerset(range(1, n)): - yield [sequence[i:j] for i, j in zip((0,) + i, i + (n,))] - - -def set_partitions(iterable, k=None): - """ - Yield the set partitions of *iterable* into *k* parts. Set partitions are - not order-preserving. - - >>> iterable = 'abc' - >>> for part in set_partitions(iterable, 2): - ... print([''.join(p) for p in part]) - ['a', 'bc'] - ['ab', 'c'] - ['b', 'ac'] - - - If *k* is not given, every set partition is generated. - - >>> iterable = 'abc' - >>> for part in set_partitions(iterable): - ... 
print([''.join(p) for p in part]) - ['abc'] - ['a', 'bc'] - ['ab', 'c'] - ['b', 'ac'] - ['a', 'b', 'c'] - - """ - L = list(iterable) - n = len(L) - if k is not None: - if k < 1: - raise ValueError( - "Can't partition in a negative or zero number of groups" - ) - elif k > n: - return - - def set_partitions_helper(L, k): - n = len(L) - if k == 1: - yield [L] - elif n == k: - yield [[s] for s in L] - else: - e, *M = L - for p in set_partitions_helper(M, k - 1): - yield [[e], *p] - for p in set_partitions_helper(M, k): - for i in range(len(p)): - yield p[:i] + [[e] + p[i]] + p[i + 1 :] - - if k is None: - for k in range(1, n + 1): - yield from set_partitions_helper(L, k) - else: - yield from set_partitions_helper(L, k) - - -class time_limited: - """ - Yield items from *iterable* until *limit_seconds* have passed. - If the time limit expires before all items have been yielded, the - ``timed_out`` parameter will be set to ``True``. - - >>> from time import sleep - >>> def generator(): - ... yield 1 - ... yield 2 - ... sleep(0.2) - ... yield 3 - >>> iterable = time_limited(0.1, generator()) - >>> list(iterable) - [1, 2] - >>> iterable.timed_out - True - - Note that the time is checked before each item is yielded, and iteration - stops if the time elapsed is greater than *limit_seconds*. If your time - limit is 1 second, but it takes 2 seconds to generate the first item from - the iterable, the function will run for 2 seconds and not yield anything. - - """ - - def __init__(self, limit_seconds, iterable): - if limit_seconds < 0: - raise ValueError('limit_seconds must be positive') - self.limit_seconds = limit_seconds - self._iterable = iter(iterable) - self._start_time = monotonic() - self.timed_out = False - - def __iter__(self): - return self - - def __next__(self): - item = next(self._iterable) - if monotonic() - self._start_time > self.limit_seconds: - self.timed_out = True - raise StopIteration - - return item - - -def only(iterable, default=None, too_long=None): - """If *iterable* has only one item, return it. - If it has zero items, return *default*. - If it has more than one item, raise the exception given by *too_long*, - which is ``ValueError`` by default. - - >>> only([], default='missing') - 'missing' - >>> only([1]) - 1 - >>> only([1, 2]) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - ValueError: Expected exactly one item in iterable, but got 1, 2, - and perhaps more.' - >>> only([1, 2], too_long=TypeError) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - TypeError - - Note that :func:`only` attempts to advance *iterable* twice to ensure there - is only one item. See :func:`spy` or :func:`peekable` to check - iterable contents less destructively. - """ - it = iter(iterable) - first_value = next(it, default) - - try: - second_value = next(it) - except StopIteration: - pass - else: - msg = ( - 'Expected exactly one item in iterable, but got {!r}, {!r}, ' - 'and perhaps more.'.format(first_value, second_value) - ) - raise too_long or ValueError(msg) - - return first_value - - -class _IChunk: - def __init__(self, iterable, n): - self._it = islice(iterable, n) - self._cache = deque() - - def fill_cache(self): - self._cache.extend(self._it) - - def __iter__(self): - return self - - def __next__(self): - try: - return next(self._it) - except StopIteration: - if self._cache: - return self._cache.popleft() - else: - raise - - -def ichunked(iterable, n): - """Break *iterable* into sub-iterables with *n* elements each. 
- :func:`ichunked` is like :func:`chunked`, but it yields iterables - instead of lists. - - If the sub-iterables are read in order, the elements of *iterable* - won't be stored in memory. - If they are read out of order, :func:`itertools.tee` is used to cache - elements as necessary. - - >>> from itertools import count - >>> all_chunks = ichunked(count(), 4) - >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks) - >>> list(c_2) # c_1's elements have been cached; c_3's haven't been - [4, 5, 6, 7] - >>> list(c_1) - [0, 1, 2, 3] - >>> list(c_3) - [8, 9, 10, 11] - - """ - source = peekable(iter(iterable)) - ichunk_marker = object() - while True: - # Check to see whether we're at the end of the source iterable - item = source.peek(ichunk_marker) - if item is ichunk_marker: - return - - chunk = _IChunk(source, n) - yield chunk - - # Advance the source iterable and fill previous chunk's cache - chunk.fill_cache() - - -def iequals(*iterables): - """Return ``True`` if all given *iterables* are equal to each other, - which means that they contain the same elements in the same order. - - The function is useful for comparing iterables of different data types - or iterables that do not support equality checks. - - >>> iequals("abc", ['a', 'b', 'c'], ('a', 'b', 'c'), iter("abc")) - True - - >>> iequals("abc", "acb") - False - - Not to be confused with :func:`all_equals`, which checks whether all - elements of iterable are equal to each other. - - """ - return all(map(all_equal, zip_longest(*iterables, fillvalue=object()))) - - -def distinct_combinations(iterable, r): - """Yield the distinct combinations of *r* items taken from *iterable*. - - >>> list(distinct_combinations([0, 0, 1], 2)) - [(0, 0), (0, 1)] - - Equivalent to ``set(combinations(iterable))``, except duplicates are not - generated and thrown away. For larger input sequences this is much more - efficient. - - """ - if r < 0: - raise ValueError('r must be non-negative') - elif r == 0: - yield () - return - pool = tuple(iterable) - generators = [unique_everseen(enumerate(pool), key=itemgetter(1))] - current_combo = [None] * r - level = 0 - while generators: - try: - cur_idx, p = next(generators[-1]) - except StopIteration: - generators.pop() - level -= 1 - continue - current_combo[level] = p - if level + 1 == r: - yield tuple(current_combo) - else: - generators.append( - unique_everseen( - enumerate(pool[cur_idx + 1 :], cur_idx + 1), - key=itemgetter(1), - ) - ) - level += 1 - - -def filter_except(validator, iterable, *exceptions): - """Yield the items from *iterable* for which the *validator* function does - not raise one of the specified *exceptions*. - - *validator* is called for each item in *iterable*. - It should be a function that accepts one argument and raises an exception - if that item is not valid. - - >>> iterable = ['1', '2', 'three', '4', None] - >>> list(filter_except(int, iterable, ValueError, TypeError)) - ['1', '2', '4'] - - If an exception other than one given by *exceptions* is raised by - *validator*, it is raised like normal. - """ - for item in iterable: - try: - validator(item) - except exceptions: - pass - else: - yield item - - -def map_except(function, iterable, *exceptions): - """Transform each item from *iterable* with *function* and yield the - result, unless *function* raises one of the specified *exceptions*. - - *function* is called to transform each item in *iterable*. - It should accept one argument. 
- - >>> iterable = ['1', '2', 'three', '4', None] - >>> list(map_except(int, iterable, ValueError, TypeError)) - [1, 2, 4] - - If an exception other than one given by *exceptions* is raised by - *function*, it is raised like normal. - """ - for item in iterable: - try: - yield function(item) - except exceptions: - pass - - -def map_if(iterable, pred, func, func_else=lambda x: x): - """Evaluate each item from *iterable* using *pred*. If the result is - equivalent to ``True``, transform the item with *func* and yield it. - Otherwise, transform the item with *func_else* and yield it. - - *pred*, *func*, and *func_else* should each be functions that accept - one argument. By default, *func_else* is the identity function. - - >>> from math import sqrt - >>> iterable = list(range(-5, 5)) - >>> iterable - [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4] - >>> list(map_if(iterable, lambda x: x > 3, lambda x: 'toobig')) - [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig'] - >>> list(map_if(iterable, lambda x: x >= 0, - ... lambda x: f'{sqrt(x):.2f}', lambda x: None)) - [None, None, None, None, None, '0.00', '1.00', '1.41', '1.73', '2.00'] - """ - for item in iterable: - yield func(item) if pred(item) else func_else(item) - - -def _sample_unweighted(iterable, k): - # Implementation of "Algorithm L" from the 1994 paper by Kim-Hung Li: - # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))". - - # Fill up the reservoir (collection of samples) with the first `k` samples - reservoir = take(k, iterable) - - # Generate random number that's the largest in a sample of k U(0,1) numbers - # Largest order statistic: https://en.wikipedia.org/wiki/Order_statistic - W = exp(log(random()) / k) - - # The number of elements to skip before changing the reservoir is a random - # number with a geometric distribution. Sample it using random() and logs. - next_index = k + floor(log(random()) / log(1 - W)) - - for index, element in enumerate(iterable, k): - - if index == next_index: - reservoir[randrange(k)] = element - # The new W is the largest in a sample of k U(0, `old_W`) numbers - W *= exp(log(random()) / k) - next_index += floor(log(random()) / log(1 - W)) + 1 - - return reservoir - - -def _sample_weighted(iterable, k, weights): - # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. : - # "Weighted random sampling with a reservoir". - - # Log-transform for numerical stability for weights that are small/large - weight_keys = (log(random()) / weight for weight in weights) - - # Fill up the reservoir (collection of samples) with the first `k` - # weight-keys and elements, then heapify the list. - reservoir = take(k, zip(weight_keys, iterable)) - heapify(reservoir) - - # The number of jumps before changing the reservoir is a random variable - # with an exponential distribution. Sample it using random() and logs. - smallest_weight_key, _ = reservoir[0] - weights_to_skip = log(random()) / smallest_weight_key - - for weight, element in zip(weights, iterable): - if weight >= weights_to_skip: - # The notation here is consistent with the paper, but we store - # the weight-keys in log-space for better numerical stability. 
- smallest_weight_key, _ = reservoir[0] - t_w = exp(weight * smallest_weight_key) - r_2 = uniform(t_w, 1) # generate U(t_w, 1) - weight_key = log(r_2) / weight - heapreplace(reservoir, (weight_key, element)) - smallest_weight_key, _ = reservoir[0] - weights_to_skip = log(random()) / smallest_weight_key - else: - weights_to_skip -= weight - - # Equivalent to [element for weight_key, element in sorted(reservoir)] - return [heappop(reservoir)[1] for _ in range(k)] - - -def sample(iterable, k, weights=None): - """Return a *k*-length list of elements chosen (without replacement) - from the *iterable*. Like :func:`random.sample`, but works on iterables - of unknown length. - - >>> iterable = range(100) - >>> sample(iterable, 5) # doctest: +SKIP - [81, 60, 96, 16, 4] - - An iterable with *weights* may also be given: - - >>> iterable = range(100) - >>> weights = (i * i + 1 for i in range(100)) - >>> sampled = sample(iterable, 5, weights=weights) # doctest: +SKIP - [79, 67, 74, 66, 78] - - The algorithm can also be used to generate weighted random permutations. - The relative weight of each item determines the probability that it - appears late in the permutation. - - >>> data = "abcdefgh" - >>> weights = range(1, len(data) + 1) - >>> sample(data, k=len(data), weights=weights) # doctest: +SKIP - ['c', 'a', 'b', 'e', 'g', 'd', 'h', 'f'] - """ - if k == 0: - return [] - - iterable = iter(iterable) - if weights is None: - return _sample_unweighted(iterable, k) - else: - weights = iter(weights) - return _sample_weighted(iterable, k, weights) - - -def is_sorted(iterable, key=None, reverse=False, strict=False): - """Returns ``True`` if the items of iterable are in sorted order, and - ``False`` otherwise. *key* and *reverse* have the same meaning that they do - in the built-in :func:`sorted` function. - - >>> is_sorted(['1', '2', '3', '4', '5'], key=int) - True - >>> is_sorted([5, 4, 3, 1, 2], reverse=True) - False - - If *strict*, tests for strict sorting, that is, returns ``False`` if equal - elements are found: - - >>> is_sorted([1, 2, 2]) - True - >>> is_sorted([1, 2, 2], strict=True) - False - - The function returns ``False`` after encountering the first out-of-order - item. If there are no out-of-order items, the iterable is exhausted. - """ - - compare = (le if reverse else ge) if strict else (lt if reverse else gt) - it = iterable if key is None else map(key, iterable) - return not any(starmap(compare, pairwise(it))) - - -class AbortThread(BaseException): - pass - - -class callback_iter: - """Convert a function that uses callbacks to an iterator. - - Let *func* be a function that takes a `callback` keyword argument. - For example: - - >>> def func(callback=None): - ... for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]: - ... if callback: - ... callback(i, c) - ... return 4 - - - Use ``with callback_iter(func)`` to get an iterator over the parameters - that are delivered to the callback. - - >>> with callback_iter(func) as it: - ... for args, kwargs in it: - ... print(args) - (1, 'a') - (2, 'b') - (3, 'c') - - The function will be called in a background thread. The ``done`` property - indicates whether it has completed execution. - - >>> it.done - True - - If it completes successfully, its return value will be available - in the ``result`` property. - - >>> it.result - 4 - - Notes: - - * If the function uses some keyword argument besides ``callback``, supply - *callback_kwd*. - * If it finished executing, but raised an exception, accessing the - ``result`` property will raise the same exception. 
- * If it hasn't finished executing, accessing the ``result`` - property from within the ``with`` block will raise ``RuntimeError``. - * If it hasn't finished executing, accessing the ``result`` property from - outside the ``with`` block will raise a - ``more_itertools.AbortThread`` exception. - * Provide *wait_seconds* to adjust how frequently the it is polled for - output. - - """ - - def __init__(self, func, callback_kwd='callback', wait_seconds=0.1): - self._func = func - self._callback_kwd = callback_kwd - self._aborted = False - self._future = None - self._wait_seconds = wait_seconds - # Lazily import concurrent.future - self._executor = __import__( - 'concurrent.futures' - ).futures.ThreadPoolExecutor(max_workers=1) - self._iterator = self._reader() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self._aborted = True - self._executor.shutdown() - - def __iter__(self): - return self - - def __next__(self): - return next(self._iterator) - - @property - def done(self): - if self._future is None: - return False - return self._future.done() - - @property - def result(self): - if not self.done: - raise RuntimeError('Function has not yet completed') - - return self._future.result() - - def _reader(self): - q = Queue() - - def callback(*args, **kwargs): - if self._aborted: - raise AbortThread('canceled by user') - - q.put((args, kwargs)) - - self._future = self._executor.submit( - self._func, **{self._callback_kwd: callback} - ) - - while True: - try: - item = q.get(timeout=self._wait_seconds) - except Empty: - pass - else: - q.task_done() - yield item - - if self._future.done(): - break - - remaining = [] - while True: - try: - item = q.get_nowait() - except Empty: - break - else: - q.task_done() - remaining.append(item) - q.join() - yield from remaining - - -def windowed_complete(iterable, n): - """ - Yield ``(beginning, middle, end)`` tuples, where: - - * Each ``middle`` has *n* items from *iterable* - * Each ``beginning`` has the items before the ones in ``middle`` - * Each ``end`` has the items after the ones in ``middle`` - - >>> iterable = range(7) - >>> n = 3 - >>> for beginning, middle, end in windowed_complete(iterable, n): - ... print(beginning, middle, end) - () (0, 1, 2) (3, 4, 5, 6) - (0,) (1, 2, 3) (4, 5, 6) - (0, 1) (2, 3, 4) (5, 6) - (0, 1, 2) (3, 4, 5) (6,) - (0, 1, 2, 3) (4, 5, 6) () - - Note that *n* must be at least 0 and most equal to the length of - *iterable*. - - This function will exhaust the iterable and may require significant - storage. - """ - if n < 0: - raise ValueError('n must be >= 0') - - seq = tuple(iterable) - size = len(seq) - - if n > size: - raise ValueError('n must be <= len(seq)') - - for i in range(size - n + 1): - beginning = seq[:i] - middle = seq[i : i + n] - end = seq[i + n :] - yield beginning, middle, end - - -def all_unique(iterable, key=None): - """ - Returns ``True`` if all the elements of *iterable* are unique (no two - elements are equal). - - >>> all_unique('ABCB') - False - - If a *key* function is specified, it will be used to make comparisons. - - >>> all_unique('ABCb') - True - >>> all_unique('ABCb', str.lower) - False - - The function returns as soon as the first non-unique element is - encountered. Iterables with a mix of hashable and unhashable items can - be used, but the function will be slower for unhashable items. 
- """ - seenset = set() - seenset_add = seenset.add - seenlist = [] - seenlist_add = seenlist.append - for element in map(key, iterable) if key else iterable: - try: - if element in seenset: - return False - seenset_add(element) - except TypeError: - if element in seenlist: - return False - seenlist_add(element) - return True - - -def nth_product(index, *args): - """Equivalent to ``list(product(*args))[index]``. - - The products of *args* can be ordered lexicographically. - :func:`nth_product` computes the product at sort position *index* without - computing the previous products. - - >>> nth_product(8, range(2), range(2), range(2), range(2)) - (1, 0, 0, 0) - - ``IndexError`` will be raised if the given *index* is invalid. - """ - pools = list(map(tuple, reversed(args))) - ns = list(map(len, pools)) - - c = reduce(mul, ns) - - if index < 0: - index += c - - if not 0 <= index < c: - raise IndexError - - result = [] - for pool, n in zip(pools, ns): - result.append(pool[index % n]) - index //= n - - return tuple(reversed(result)) - - -def nth_permutation(iterable, r, index): - """Equivalent to ``list(permutations(iterable, r))[index]``` - - The subsequences of *iterable* that are of length *r* where order is - important can be ordered lexicographically. :func:`nth_permutation` - computes the subsequence at sort position *index* directly, without - computing the previous subsequences. - - >>> nth_permutation('ghijk', 2, 5) - ('h', 'i') - - ``ValueError`` will be raised If *r* is negative or greater than the length - of *iterable*. - ``IndexError`` will be raised if the given *index* is invalid. - """ - pool = list(iterable) - n = len(pool) - - if r is None or r == n: - r, c = n, factorial(n) - elif not 0 <= r < n: - raise ValueError - else: - c = factorial(n) // factorial(n - r) - - if index < 0: - index += c - - if not 0 <= index < c: - raise IndexError - - if c == 0: - return tuple() - - result = [0] * r - q = index * factorial(n) // c if r < n else index - for d in range(1, n + 1): - q, i = divmod(q, d) - if 0 <= n - d < r: - result[n - d] = i - if q == 0: - break - - return tuple(map(pool.pop, result)) - - -def value_chain(*args): - """Yield all arguments passed to the function in the same order in which - they were passed. If an argument itself is iterable then iterate over its - values. - - >>> list(value_chain(1, 2, 3, [4, 5, 6])) - [1, 2, 3, 4, 5, 6] - - Binary and text strings are not considered iterable and are emitted - as-is: - - >>> list(value_chain('12', '34', ['56', '78'])) - ['12', '34', '56', '78'] - - - Multiple levels of nesting are not flattened. - - """ - for value in args: - if isinstance(value, (str, bytes)): - yield value - continue - try: - yield from value - except TypeError: - yield value - - -def product_index(element, *args): - """Equivalent to ``list(product(*args)).index(element)`` - - The products of *args* can be ordered lexicographically. - :func:`product_index` computes the first index of *element* without - computing the previous products. - - >>> product_index([8, 2], range(10), range(5)) - 42 - - ``ValueError`` will be raised if the given *element* isn't in the product - of *args*. 
- """ - index = 0 - - for x, pool in zip_longest(element, args, fillvalue=_marker): - if x is _marker or pool is _marker: - raise ValueError('element is not a product of args') - - pool = tuple(pool) - index = index * len(pool) + pool.index(x) - - return index - - -def combination_index(element, iterable): - """Equivalent to ``list(combinations(iterable, r)).index(element)`` - - The subsequences of *iterable* that are of length *r* can be ordered - lexicographically. :func:`combination_index` computes the index of the - first *element*, without computing the previous combinations. - - >>> combination_index('adf', 'abcdefg') - 10 - - ``ValueError`` will be raised if the given *element* isn't one of the - combinations of *iterable*. - """ - element = enumerate(element) - k, y = next(element, (None, None)) - if k is None: - return 0 - - indexes = [] - pool = enumerate(iterable) - for n, x in pool: - if x == y: - indexes.append(n) - tmp, y = next(element, (None, None)) - if tmp is None: - break - else: - k = tmp - else: - raise ValueError('element is not a combination of iterable') - - n, _ = last(pool, default=(n, None)) - - # Python versions below 3.8 don't have math.comb - index = 1 - for i, j in enumerate(reversed(indexes), start=1): - j = n - j - if i <= j: - index += factorial(j) // (factorial(i) * factorial(j - i)) - - return factorial(n + 1) // (factorial(k + 1) * factorial(n - k)) - index - - -def permutation_index(element, iterable): - """Equivalent to ``list(permutations(iterable, r)).index(element)``` - - The subsequences of *iterable* that are of length *r* where order is - important can be ordered lexicographically. :func:`permutation_index` - computes the index of the first *element* directly, without computing - the previous permutations. - - >>> permutation_index([1, 3, 2], range(5)) - 19 - - ``ValueError`` will be raised if the given *element* isn't one of the - permutations of *iterable*. - """ - index = 0 - pool = list(iterable) - for i, x in zip(range(len(pool), -1, -1), element): - r = pool.index(x) - index = index * i + r - del pool[r] - - return index - - -class countable: - """Wrap *iterable* and keep a count of how many items have been consumed. - - The ``items_seen`` attribute starts at ``0`` and increments as the iterable - is consumed: - - >>> iterable = map(str, range(10)) - >>> it = countable(iterable) - >>> it.items_seen - 0 - >>> next(it), next(it) - ('0', '1') - >>> list(it) - ['2', '3', '4', '5', '6', '7', '8', '9'] - >>> it.items_seen - 10 - """ - - def __init__(self, iterable): - self._it = iter(iterable) - self.items_seen = 0 - - def __iter__(self): - return self - - def __next__(self): - item = next(self._it) - self.items_seen += 1 - - return item - - -def chunked_even(iterable, n): - """Break *iterable* into lists of approximately length *n*. - Items are distributed such the lengths of the lists differ by at most - 1 item. 
- - >>> iterable = [1, 2, 3, 4, 5, 6, 7] - >>> n = 3 - >>> list(chunked_even(iterable, n)) # List lengths: 3, 2, 2 - [[1, 2, 3], [4, 5], [6, 7]] - >>> list(chunked(iterable, n)) # List lengths: 3, 3, 1 - [[1, 2, 3], [4, 5, 6], [7]] - - """ - - len_method = getattr(iterable, '__len__', None) - - if len_method is None: - return _chunked_even_online(iterable, n) - else: - return _chunked_even_finite(iterable, len_method(), n) - - -def _chunked_even_online(iterable, n): - buffer = [] - maxbuf = n + (n - 2) * (n - 1) - for x in iterable: - buffer.append(x) - if len(buffer) == maxbuf: - yield buffer[:n] - buffer = buffer[n:] - yield from _chunked_even_finite(buffer, len(buffer), n) - - -def _chunked_even_finite(iterable, N, n): - if N < 1: - return - - # Lists are either size `full_size <= n` or `partial_size = full_size - 1` - q, r = divmod(N, n) - num_lists = q + (1 if r > 0 else 0) - q, r = divmod(N, num_lists) - full_size = q + (1 if r > 0 else 0) - partial_size = full_size - 1 - num_full = N - partial_size * num_lists - num_partial = num_lists - num_full - - buffer = [] - iterator = iter(iterable) - - # Yield num_full lists of full_size - for x in iterator: - buffer.append(x) - if len(buffer) == full_size: - yield buffer - buffer = [] - num_full -= 1 - if num_full <= 0: - break - - # Yield num_partial lists of partial_size - for x in iterator: - buffer.append(x) - if len(buffer) == partial_size: - yield buffer - buffer = [] - num_partial -= 1 - - -def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False): - """A version of :func:`zip` that "broadcasts" any scalar - (i.e., non-iterable) items into output tuples. - - >>> iterable_1 = [1, 2, 3] - >>> iterable_2 = ['a', 'b', 'c'] - >>> scalar = '_' - >>> list(zip_broadcast(iterable_1, iterable_2, scalar)) - [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')] - - The *scalar_types* keyword argument determines what types are considered - scalar. It is set to ``(str, bytes)`` by default. Set it to ``None`` to - treat strings and byte strings as iterable: - - >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None)) - [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')] - - If the *strict* keyword argument is ``True``, then - ``UnequalIterablesError`` will be raised if any of the iterables have - different lengths. - """ - - def is_scalar(obj): - if scalar_types and isinstance(obj, scalar_types): - return True - try: - iter(obj) - except TypeError: - return True - else: - return False - - size = len(objects) - if not size: - return - - iterables, iterable_positions = [], [] - scalars, scalar_positions = [], [] - for i, obj in enumerate(objects): - if is_scalar(obj): - scalars.append(obj) - scalar_positions.append(i) - else: - iterables.append(iter(obj)) - iterable_positions.append(i) - - if len(scalars) == size: - yield tuple(objects) - return - - zipper = _zip_equal if strict else zip - for item in zipper(*iterables): - new_item = [None] * size - - for i, elem in zip(iterable_positions, item): - new_item[i] = elem - - for i, elem in zip(scalar_positions, scalars): - new_item[i] = elem - - yield tuple(new_item) - - -def unique_in_window(iterable, n, key=None): - """Yield the items from *iterable* that haven't been seen recently. - *n* is the size of the lookback window. 
- - >>> iterable = [0, 1, 0, 2, 3, 0] - >>> n = 3 - >>> list(unique_in_window(iterable, n)) - [0, 1, 2, 3, 0] - - The *key* function, if provided, will be used to determine uniqueness: - - >>> list(unique_in_window('abAcda', 3, key=lambda x: x.lower())) - ['a', 'b', 'c', 'd', 'a'] - - The items in *iterable* must be hashable. - - """ - if n <= 0: - raise ValueError('n must be greater than 0') - - window = deque(maxlen=n) - uniques = set() - use_key = key is not None - - for item in iterable: - k = key(item) if use_key else item - if k in uniques: - continue - - if len(uniques) == n: - uniques.discard(window[0]) - - uniques.add(k) - window.append(k) - - yield item - - -def duplicates_everseen(iterable, key=None): - """Yield duplicate elements after their first appearance. - - >>> list(duplicates_everseen('mississippi')) - ['s', 'i', 's', 's', 'i', 'p', 'i'] - >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower)) - ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a'] - - This function is analagous to :func:`unique_everseen` and is subject to - the same performance considerations. - - """ - seen_set = set() - seen_list = [] - use_key = key is not None - - for element in iterable: - k = key(element) if use_key else element - try: - if k not in seen_set: - seen_set.add(k) - else: - yield element - except TypeError: - if k not in seen_list: - seen_list.append(k) - else: - yield element - - -def duplicates_justseen(iterable, key=None): - """Yields serially-duplicate elements after their first appearance. - - >>> list(duplicates_justseen('mississippi')) - ['s', 's', 'p'] - >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower)) - ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a'] - - This function is analagous to :func:`unique_justseen`. - - """ - return flatten( - map( - lambda group_tuple: islice_extended(group_tuple[1])[1:], - groupby(iterable, key), - ) - ) - - -def minmax(iterable_or_value, *others, key=None, default=_marker): - """Returns both the smallest and largest items in an iterable - or the largest of two or more arguments. - - >>> minmax([3, 1, 5]) - (1, 5) - - >>> minmax(4, 2, 6) - (2, 6) - - If a *key* function is provided, it will be used to transform the input - items for comparison. - - >>> minmax([5, 30], key=str) # '30' sorts before '5' - (30, 5) - - If a *default* value is provided, it will be returned if there are no - input items. - - >>> minmax([], default=(0, 0)) - (0, 0) - - Otherwise ``ValueError`` is raised. - - This function is based on the - `recipe `__ by - Raymond Hettinger and takes care to minimize the number of comparisons - performed. - """ - iterable = (iterable_or_value, *others) if others else iterable_or_value - - it = iter(iterable) - - try: - lo = hi = next(it) - except StopIteration as e: - if default is _marker: - raise ValueError( - '`minmax()` argument is an empty iterable. ' - 'Provide a `default` value to suppress this error.' - ) from e - return default - - # Different branches depending on the presence of key. This saves a lot - # of unimportant copies which would slow the "key=None" branch - # significantly down. 
- if key is None: - for x, y in zip_longest(it, it, fillvalue=lo): - if y < x: - x, y = y, x - if x < lo: - lo = x - if hi < y: - hi = y - - else: - lo_key = hi_key = key(lo) - - for x, y in zip_longest(it, it, fillvalue=lo): - - x_key, y_key = key(x), key(y) - - if y_key < x_key: - x, y, x_key, y_key = y, x, y_key, x_key - if x_key < lo_key: - lo, lo_key = x, x_key - if hi_key < y_key: - hi, hi_key = y, y_key - - return lo, hi - - -def constrained_batches( - iterable, max_size, max_count=None, get_len=len, strict=True -): - """Yield batches of items from *iterable* with a combined size limited by - *max_size*. - - >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1'] - >>> list(constrained_batches(iterable, 10)) - [(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')] - - If a *max_count* is supplied, the number of items per batch is also - limited: - - >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1'] - >>> list(constrained_batches(iterable, 10, max_count = 2)) - [(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)] - - If a *get_len* function is supplied, use that instead of :func:`len` to - determine item size. - - If *strict* is ``True``, raise ``ValueError`` if any single item is bigger - than *max_size*. Otherwise, allow single items to exceed *max_size*. - """ - if max_size <= 0: - raise ValueError('maximum size must be greater than zero') - - batch = [] - batch_size = 0 - batch_count = 0 - for item in iterable: - item_len = get_len(item) - if strict and item_len > max_size: - raise ValueError('item size exceeds maximum size') - - reached_count = batch_count == max_count - reached_size = item_len + batch_size > max_size - if batch_count and (reached_size or reached_count): - yield tuple(batch) - batch.clear() - batch_size = 0 - batch_count = 0 - - batch.append(item) - batch_size += item_len - batch_count += 1 - - if batch: - yield tuple(batch) diff --git a/libs/common/more_itertools/more.pyi b/libs/common/more_itertools/more.pyi deleted file mode 100644 index 1413fae7..00000000 --- a/libs/common/more_itertools/more.pyi +++ /dev/null @@ -1,674 +0,0 @@ -"""Stubs for more_itertools.more""" - -from typing import ( - Any, - Callable, - Container, - Dict, - Generic, - Hashable, - Iterable, - Iterator, - List, - Optional, - Reversible, - Sequence, - Sized, - Tuple, - Union, - TypeVar, - type_check_only, -) -from types import TracebackType -from typing_extensions import ContextManager, Protocol, Type, overload - -# Type and type variable definitions -_T = TypeVar('_T') -_T1 = TypeVar('_T1') -_T2 = TypeVar('_T2') -_U = TypeVar('_U') -_V = TypeVar('_V') -_W = TypeVar('_W') -_T_co = TypeVar('_T_co', covariant=True) -_GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[object]]) -_Raisable = Union[BaseException, 'Type[BaseException]'] - -@type_check_only -class _SizedIterable(Protocol[_T_co], Sized, Iterable[_T_co]): ... - -@type_check_only -class _SizedReversible(Protocol[_T_co], Sized, Reversible[_T_co]): ... - -def chunked( - iterable: Iterable[_T], n: Optional[int], strict: bool = ... -) -> Iterator[List[_T]]: ... -@overload -def first(iterable: Iterable[_T]) -> _T: ... -@overload -def first(iterable: Iterable[_T], default: _U) -> Union[_T, _U]: ... -@overload -def last(iterable: Iterable[_T]) -> _T: ... -@overload -def last(iterable: Iterable[_T], default: _U) -> Union[_T, _U]: ... -@overload -def nth_or_last(iterable: Iterable[_T], n: int) -> _T: ... 
-@overload -def nth_or_last( - iterable: Iterable[_T], n: int, default: _U -) -> Union[_T, _U]: ... - -class peekable(Generic[_T], Iterator[_T]): - def __init__(self, iterable: Iterable[_T]) -> None: ... - def __iter__(self) -> peekable[_T]: ... - def __bool__(self) -> bool: ... - @overload - def peek(self) -> _T: ... - @overload - def peek(self, default: _U) -> Union[_T, _U]: ... - def prepend(self, *items: _T) -> None: ... - def __next__(self) -> _T: ... - @overload - def __getitem__(self, index: int) -> _T: ... - @overload - def __getitem__(self, index: slice) -> List[_T]: ... - -def consumer(func: _GenFn) -> _GenFn: ... -def ilen(iterable: Iterable[object]) -> int: ... -def iterate(func: Callable[[_T], _T], start: _T) -> Iterator[_T]: ... -def with_iter( - context_manager: ContextManager[Iterable[_T]], -) -> Iterator[_T]: ... -def one( - iterable: Iterable[_T], - too_short: Optional[_Raisable] = ..., - too_long: Optional[_Raisable] = ..., -) -> _T: ... -def raise_(exception: _Raisable, *args: Any) -> None: ... -def strictly_n( - iterable: Iterable[_T], - n: int, - too_short: Optional[_GenFn] = ..., - too_long: Optional[_GenFn] = ..., -) -> List[_T]: ... -def distinct_permutations( - iterable: Iterable[_T], r: Optional[int] = ... -) -> Iterator[Tuple[_T, ...]]: ... -def intersperse( - e: _U, iterable: Iterable[_T], n: int = ... -) -> Iterator[Union[_T, _U]]: ... -def unique_to_each(*iterables: Iterable[_T]) -> List[List[_T]]: ... -@overload -def windowed( - seq: Iterable[_T], n: int, *, step: int = ... -) -> Iterator[Tuple[Optional[_T], ...]]: ... -@overload -def windowed( - seq: Iterable[_T], n: int, fillvalue: _U, step: int = ... -) -> Iterator[Tuple[Union[_T, _U], ...]]: ... -def substrings(iterable: Iterable[_T]) -> Iterator[Tuple[_T, ...]]: ... -def substrings_indexes( - seq: Sequence[_T], reverse: bool = ... -) -> Iterator[Tuple[Sequence[_T], int, int]]: ... - -class bucket(Generic[_T, _U], Container[_U]): - def __init__( - self, - iterable: Iterable[_T], - key: Callable[[_T], _U], - validator: Optional[Callable[[object], object]] = ..., - ) -> None: ... - def __contains__(self, value: object) -> bool: ... - def __iter__(self) -> Iterator[_U]: ... - def __getitem__(self, value: object) -> Iterator[_T]: ... - -def spy( - iterable: Iterable[_T], n: int = ... -) -> Tuple[List[_T], Iterator[_T]]: ... -def interleave(*iterables: Iterable[_T]) -> Iterator[_T]: ... -def interleave_longest(*iterables: Iterable[_T]) -> Iterator[_T]: ... -def interleave_evenly( - iterables: List[Iterable[_T]], lengths: Optional[List[int]] = ... -) -> Iterator[_T]: ... -def collapse( - iterable: Iterable[Any], - base_type: Optional[type] = ..., - levels: Optional[int] = ..., -) -> Iterator[Any]: ... -@overload -def side_effect( - func: Callable[[_T], object], - iterable: Iterable[_T], - chunk_size: None = ..., - before: Optional[Callable[[], object]] = ..., - after: Optional[Callable[[], object]] = ..., -) -> Iterator[_T]: ... -@overload -def side_effect( - func: Callable[[List[_T]], object], - iterable: Iterable[_T], - chunk_size: int, - before: Optional[Callable[[], object]] = ..., - after: Optional[Callable[[], object]] = ..., -) -> Iterator[_T]: ... -def sliced( - seq: Sequence[_T], n: int, strict: bool = ... -) -> Iterator[Sequence[_T]]: ... -def split_at( - iterable: Iterable[_T], - pred: Callable[[_T], object], - maxsplit: int = ..., - keep_separator: bool = ..., -) -> Iterator[List[_T]]: ... -def split_before( - iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ... 
-) -> Iterator[List[_T]]: ... -def split_after( - iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ... -) -> Iterator[List[_T]]: ... -def split_when( - iterable: Iterable[_T], - pred: Callable[[_T, _T], object], - maxsplit: int = ..., -) -> Iterator[List[_T]]: ... -def split_into( - iterable: Iterable[_T], sizes: Iterable[Optional[int]] -) -> Iterator[List[_T]]: ... -@overload -def padded( - iterable: Iterable[_T], - *, - n: Optional[int] = ..., - next_multiple: bool = ..., -) -> Iterator[Optional[_T]]: ... -@overload -def padded( - iterable: Iterable[_T], - fillvalue: _U, - n: Optional[int] = ..., - next_multiple: bool = ..., -) -> Iterator[Union[_T, _U]]: ... -@overload -def repeat_last(iterable: Iterable[_T]) -> Iterator[_T]: ... -@overload -def repeat_last( - iterable: Iterable[_T], default: _U -) -> Iterator[Union[_T, _U]]: ... -def distribute(n: int, iterable: Iterable[_T]) -> List[Iterator[_T]]: ... -@overload -def stagger( - iterable: Iterable[_T], - offsets: _SizedIterable[int] = ..., - longest: bool = ..., -) -> Iterator[Tuple[Optional[_T], ...]]: ... -@overload -def stagger( - iterable: Iterable[_T], - offsets: _SizedIterable[int] = ..., - longest: bool = ..., - fillvalue: _U = ..., -) -> Iterator[Tuple[Union[_T, _U], ...]]: ... - -class UnequalIterablesError(ValueError): - def __init__( - self, details: Optional[Tuple[int, int, int]] = ... - ) -> None: ... - -@overload -def zip_equal(__iter1: Iterable[_T1]) -> Iterator[Tuple[_T1]]: ... -@overload -def zip_equal( - __iter1: Iterable[_T1], __iter2: Iterable[_T2] -) -> Iterator[Tuple[_T1, _T2]]: ... -@overload -def zip_equal( - __iter1: Iterable[_T], - __iter2: Iterable[_T], - __iter3: Iterable[_T], - *iterables: Iterable[_T], -) -> Iterator[Tuple[_T, ...]]: ... -@overload -def zip_offset( - __iter1: Iterable[_T1], - *, - offsets: _SizedIterable[int], - longest: bool = ..., - fillvalue: None = None, -) -> Iterator[Tuple[Optional[_T1]]]: ... -@overload -def zip_offset( - __iter1: Iterable[_T1], - __iter2: Iterable[_T2], - *, - offsets: _SizedIterable[int], - longest: bool = ..., - fillvalue: None = None, -) -> Iterator[Tuple[Optional[_T1], Optional[_T2]]]: ... -@overload -def zip_offset( - __iter1: Iterable[_T], - __iter2: Iterable[_T], - __iter3: Iterable[_T], - *iterables: Iterable[_T], - offsets: _SizedIterable[int], - longest: bool = ..., - fillvalue: None = None, -) -> Iterator[Tuple[Optional[_T], ...]]: ... -@overload -def zip_offset( - __iter1: Iterable[_T1], - *, - offsets: _SizedIterable[int], - longest: bool = ..., - fillvalue: _U, -) -> Iterator[Tuple[Union[_T1, _U]]]: ... -@overload -def zip_offset( - __iter1: Iterable[_T1], - __iter2: Iterable[_T2], - *, - offsets: _SizedIterable[int], - longest: bool = ..., - fillvalue: _U, -) -> Iterator[Tuple[Union[_T1, _U], Union[_T2, _U]]]: ... -@overload -def zip_offset( - __iter1: Iterable[_T], - __iter2: Iterable[_T], - __iter3: Iterable[_T], - *iterables: Iterable[_T], - offsets: _SizedIterable[int], - longest: bool = ..., - fillvalue: _U, -) -> Iterator[Tuple[Union[_T, _U], ...]]: ... -def sort_together( - iterables: Iterable[Iterable[_T]], - key_list: Iterable[int] = ..., - key: Optional[Callable[..., Any]] = ..., - reverse: bool = ..., -) -> List[Tuple[_T, ...]]: ... -def unzip(iterable: Iterable[Sequence[_T]]) -> Tuple[Iterator[_T], ...]: ... -def divide(n: int, iterable: Iterable[_T]) -> List[Iterator[_T]]: ... 
-def always_iterable( - obj: object, - base_type: Union[ - type, Tuple[Union[type, Tuple[Any, ...]], ...], None - ] = ..., -) -> Iterator[Any]: ... -def adjacent( - predicate: Callable[[_T], bool], - iterable: Iterable[_T], - distance: int = ..., -) -> Iterator[Tuple[bool, _T]]: ... -@overload -def groupby_transform( - iterable: Iterable[_T], - keyfunc: None = None, - valuefunc: None = None, - reducefunc: None = None, -) -> Iterator[Tuple[_T, Iterator[_T]]]: ... -@overload -def groupby_transform( - iterable: Iterable[_T], - keyfunc: Callable[[_T], _U], - valuefunc: None, - reducefunc: None, -) -> Iterator[Tuple[_U, Iterator[_T]]]: ... -@overload -def groupby_transform( - iterable: Iterable[_T], - keyfunc: None, - valuefunc: Callable[[_T], _V], - reducefunc: None, -) -> Iterable[Tuple[_T, Iterable[_V]]]: ... -@overload -def groupby_transform( - iterable: Iterable[_T], - keyfunc: Callable[[_T], _U], - valuefunc: Callable[[_T], _V], - reducefunc: None, -) -> Iterable[Tuple[_U, Iterator[_V]]]: ... -@overload -def groupby_transform( - iterable: Iterable[_T], - keyfunc: None, - valuefunc: None, - reducefunc: Callable[[Iterator[_T]], _W], -) -> Iterable[Tuple[_T, _W]]: ... -@overload -def groupby_transform( - iterable: Iterable[_T], - keyfunc: Callable[[_T], _U], - valuefunc: None, - reducefunc: Callable[[Iterator[_T]], _W], -) -> Iterable[Tuple[_U, _W]]: ... -@overload -def groupby_transform( - iterable: Iterable[_T], - keyfunc: None, - valuefunc: Callable[[_T], _V], - reducefunc: Callable[[Iterable[_V]], _W], -) -> Iterable[Tuple[_T, _W]]: ... -@overload -def groupby_transform( - iterable: Iterable[_T], - keyfunc: Callable[[_T], _U], - valuefunc: Callable[[_T], _V], - reducefunc: Callable[[Iterable[_V]], _W], -) -> Iterable[Tuple[_U, _W]]: ... - -class numeric_range(Generic[_T, _U], Sequence[_T], Hashable, Reversible[_T]): - @overload - def __init__(self, __stop: _T) -> None: ... - @overload - def __init__(self, __start: _T, __stop: _T) -> None: ... - @overload - def __init__(self, __start: _T, __stop: _T, __step: _U) -> None: ... - def __bool__(self) -> bool: ... - def __contains__(self, elem: object) -> bool: ... - def __eq__(self, other: object) -> bool: ... - @overload - def __getitem__(self, key: int) -> _T: ... - @overload - def __getitem__(self, key: slice) -> numeric_range[_T, _U]: ... - def __hash__(self) -> int: ... - def __iter__(self) -> Iterator[_T]: ... - def __len__(self) -> int: ... - def __reduce__( - self, - ) -> Tuple[Type[numeric_range[_T, _U]], Tuple[_T, _T, _U]]: ... - def __repr__(self) -> str: ... - def __reversed__(self) -> Iterator[_T]: ... - def count(self, value: _T) -> int: ... - def index(self, value: _T) -> int: ... # type: ignore - -def count_cycle( - iterable: Iterable[_T], n: Optional[int] = ... -) -> Iterable[Tuple[int, _T]]: ... -def mark_ends( - iterable: Iterable[_T], -) -> Iterable[Tuple[bool, bool, _T]]: ... -def locate( - iterable: Iterable[object], - pred: Callable[..., Any] = ..., - window_size: Optional[int] = ..., -) -> Iterator[int]: ... -def lstrip( - iterable: Iterable[_T], pred: Callable[[_T], object] -) -> Iterator[_T]: ... -def rstrip( - iterable: Iterable[_T], pred: Callable[[_T], object] -) -> Iterator[_T]: ... -def strip( - iterable: Iterable[_T], pred: Callable[[_T], object] -) -> Iterator[_T]: ... - -class islice_extended(Generic[_T], Iterator[_T]): - def __init__( - self, iterable: Iterable[_T], *args: Optional[int] - ) -> None: ... - def __iter__(self) -> islice_extended[_T]: ... - def __next__(self) -> _T: ... 
- def __getitem__(self, index: slice) -> islice_extended[_T]: ... - -def always_reversible(iterable: Iterable[_T]) -> Iterator[_T]: ... -def consecutive_groups( - iterable: Iterable[_T], ordering: Callable[[_T], int] = ... -) -> Iterator[Iterator[_T]]: ... -@overload -def difference( - iterable: Iterable[_T], - func: Callable[[_T, _T], _U] = ..., - *, - initial: None = ..., -) -> Iterator[Union[_T, _U]]: ... -@overload -def difference( - iterable: Iterable[_T], func: Callable[[_T, _T], _U] = ..., *, initial: _U -) -> Iterator[_U]: ... - -class SequenceView(Generic[_T], Sequence[_T]): - def __init__(self, target: Sequence[_T]) -> None: ... - @overload - def __getitem__(self, index: int) -> _T: ... - @overload - def __getitem__(self, index: slice) -> Sequence[_T]: ... - def __len__(self) -> int: ... - -class seekable(Generic[_T], Iterator[_T]): - def __init__( - self, iterable: Iterable[_T], maxlen: Optional[int] = ... - ) -> None: ... - def __iter__(self) -> seekable[_T]: ... - def __next__(self) -> _T: ... - def __bool__(self) -> bool: ... - @overload - def peek(self) -> _T: ... - @overload - def peek(self, default: _U) -> Union[_T, _U]: ... - def elements(self) -> SequenceView[_T]: ... - def seek(self, index: int) -> None: ... - -class run_length: - @staticmethod - def encode(iterable: Iterable[_T]) -> Iterator[Tuple[_T, int]]: ... - @staticmethod - def decode(iterable: Iterable[Tuple[_T, int]]) -> Iterator[_T]: ... - -def exactly_n( - iterable: Iterable[_T], n: int, predicate: Callable[[_T], object] = ... -) -> bool: ... -def circular_shifts(iterable: Iterable[_T]) -> List[Tuple[_T, ...]]: ... -def make_decorator( - wrapping_func: Callable[..., _U], result_index: int = ... -) -> Callable[..., Callable[[Callable[..., Any]], Callable[..., _U]]]: ... -@overload -def map_reduce( - iterable: Iterable[_T], - keyfunc: Callable[[_T], _U], - valuefunc: None = ..., - reducefunc: None = ..., -) -> Dict[_U, List[_T]]: ... -@overload -def map_reduce( - iterable: Iterable[_T], - keyfunc: Callable[[_T], _U], - valuefunc: Callable[[_T], _V], - reducefunc: None = ..., -) -> Dict[_U, List[_V]]: ... -@overload -def map_reduce( - iterable: Iterable[_T], - keyfunc: Callable[[_T], _U], - valuefunc: None = ..., - reducefunc: Callable[[List[_T]], _W] = ..., -) -> Dict[_U, _W]: ... -@overload -def map_reduce( - iterable: Iterable[_T], - keyfunc: Callable[[_T], _U], - valuefunc: Callable[[_T], _V], - reducefunc: Callable[[List[_V]], _W], -) -> Dict[_U, _W]: ... -def rlocate( - iterable: Iterable[_T], - pred: Callable[..., object] = ..., - window_size: Optional[int] = ..., -) -> Iterator[int]: ... -def replace( - iterable: Iterable[_T], - pred: Callable[..., object], - substitutes: Iterable[_U], - count: Optional[int] = ..., - window_size: int = ..., -) -> Iterator[Union[_T, _U]]: ... -def partitions(iterable: Iterable[_T]) -> Iterator[List[List[_T]]]: ... -def set_partitions( - iterable: Iterable[_T], k: Optional[int] = ... -) -> Iterator[List[List[_T]]]: ... - -class time_limited(Generic[_T], Iterator[_T]): - def __init__( - self, limit_seconds: float, iterable: Iterable[_T] - ) -> None: ... - def __iter__(self) -> islice_extended[_T]: ... - def __next__(self) -> _T: ... - -@overload -def only( - iterable: Iterable[_T], *, too_long: Optional[_Raisable] = ... -) -> Optional[_T]: ... -@overload -def only( - iterable: Iterable[_T], default: _U, too_long: Optional[_Raisable] = ... -) -> Union[_T, _U]: ... -def ichunked(iterable: Iterable[_T], n: int) -> Iterator[Iterator[_T]]: ... 
-def distinct_combinations( - iterable: Iterable[_T], r: int -) -> Iterator[Tuple[_T, ...]]: ... -def filter_except( - validator: Callable[[Any], object], - iterable: Iterable[_T], - *exceptions: Type[BaseException], -) -> Iterator[_T]: ... -def map_except( - function: Callable[[Any], _U], - iterable: Iterable[_T], - *exceptions: Type[BaseException], -) -> Iterator[_U]: ... -def map_if( - iterable: Iterable[Any], - pred: Callable[[Any], bool], - func: Callable[[Any], Any], - func_else: Optional[Callable[[Any], Any]] = ..., -) -> Iterator[Any]: ... -def sample( - iterable: Iterable[_T], - k: int, - weights: Optional[Iterable[float]] = ..., -) -> List[_T]: ... -def is_sorted( - iterable: Iterable[_T], - key: Optional[Callable[[_T], _U]] = ..., - reverse: bool = False, - strict: bool = False, -) -> bool: ... - -class AbortThread(BaseException): - pass - -class callback_iter(Generic[_T], Iterator[_T]): - def __init__( - self, - func: Callable[..., Any], - callback_kwd: str = ..., - wait_seconds: float = ..., - ) -> None: ... - def __enter__(self) -> callback_iter[_T]: ... - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], - ) -> Optional[bool]: ... - def __iter__(self) -> callback_iter[_T]: ... - def __next__(self) -> _T: ... - def _reader(self) -> Iterator[_T]: ... - @property - def done(self) -> bool: ... - @property - def result(self) -> Any: ... - -def windowed_complete( - iterable: Iterable[_T], n: int -) -> Iterator[Tuple[_T, ...]]: ... -def all_unique( - iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = ... -) -> bool: ... -def nth_product(index: int, *args: Iterable[_T]) -> Tuple[_T, ...]: ... -def nth_permutation( - iterable: Iterable[_T], r: int, index: int -) -> Tuple[_T, ...]: ... -def value_chain(*args: Union[_T, Iterable[_T]]) -> Iterable[_T]: ... -def product_index(element: Iterable[_T], *args: Iterable[_T]) -> int: ... -def combination_index( - element: Iterable[_T], iterable: Iterable[_T] -) -> int: ... -def permutation_index( - element: Iterable[_T], iterable: Iterable[_T] -) -> int: ... -def repeat_each(iterable: Iterable[_T], n: int = ...) -> Iterator[_T]: ... - -class countable(Generic[_T], Iterator[_T]): - def __init__(self, iterable: Iterable[_T]) -> None: ... - def __iter__(self) -> countable[_T]: ... - def __next__(self) -> _T: ... - -def chunked_even(iterable: Iterable[_T], n: int) -> Iterator[List[_T]]: ... -def zip_broadcast( - *objects: Union[_T, Iterable[_T]], - scalar_types: Union[ - type, Tuple[Union[type, Tuple[Any, ...]], ...], None - ] = ..., - strict: bool = ..., -) -> Iterable[Tuple[_T, ...]]: ... -def unique_in_window( - iterable: Iterable[_T], n: int, key: Optional[Callable[[_T], _U]] = ... -) -> Iterator[_T]: ... -def duplicates_everseen( - iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = ... -) -> Iterator[_T]: ... -def duplicates_justseen( - iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = ... -) -> Iterator[_T]: ... - -class _SupportsLessThan(Protocol): - def __lt__(self, __other: Any) -> bool: ... - -_SupportsLessThanT = TypeVar("_SupportsLessThanT", bound=_SupportsLessThan) - -@overload -def minmax( - iterable_or_value: Iterable[_SupportsLessThanT], *, key: None = None -) -> Tuple[_SupportsLessThanT, _SupportsLessThanT]: ... -@overload -def minmax( - iterable_or_value: Iterable[_T], *, key: Callable[[_T], _SupportsLessThan] -) -> Tuple[_T, _T]: ... 
-@overload -def minmax( - iterable_or_value: Iterable[_SupportsLessThanT], - *, - key: None = None, - default: _U, -) -> Union[_U, Tuple[_SupportsLessThanT, _SupportsLessThanT]]: ... -@overload -def minmax( - iterable_or_value: Iterable[_T], - *, - key: Callable[[_T], _SupportsLessThan], - default: _U, -) -> Union[_U, Tuple[_T, _T]]: ... -@overload -def minmax( - iterable_or_value: _SupportsLessThanT, - __other: _SupportsLessThanT, - *others: _SupportsLessThanT, -) -> Tuple[_SupportsLessThanT, _SupportsLessThanT]: ... -@overload -def minmax( - iterable_or_value: _T, - __other: _T, - *others: _T, - key: Callable[[_T], _SupportsLessThan], -) -> Tuple[_T, _T]: ... -def longest_common_prefix( - iterables: Iterable[Iterable[_T]], -) -> Iterator[_T]: ... -def iequals(*iterables: Iterable[object]) -> bool: ... -def constrained_batches( - iterable: Iterable[object], - max_size: int, - max_count: Optional[int] = ..., - get_len: Callable[[_T], object] = ..., - strict: bool = ..., -) -> Iterator[Tuple[_T]]: ... diff --git a/libs/common/more_itertools/recipes.py b/libs/common/more_itertools/recipes.py index 85796207..3b455d4e 100644 --- a/libs/common/more_itertools/recipes.py +++ b/libs/common/more_itertools/recipes.py @@ -7,33 +7,20 @@ Some backward-compatible usability improvements have been made. .. [1] http://docs.python.org/library/itertools.html#recipes """ -import math -import operator - from collections import deque -from collections.abc import Sized -from functools import reduce from itertools import ( - chain, - combinations, - compress, - count, - cycle, - groupby, - islice, - repeat, - starmap, - tee, - zip_longest, + chain, combinations, count, cycle, groupby, islice, repeat, starmap, tee ) +import operator from random import randrange, sample, choice +from six import PY2 +from six.moves import filter, filterfalse, map, range, zip, zip_longest + __all__ = [ + 'accumulate', 'all_equal', - 'batched', - 'before_and_after', 'consume', - 'convolve', 'dotproduct', 'first_true', 'flatten', @@ -43,10 +30,8 @@ __all__ = [ 'nth', 'nth_combination', 'padnone', - 'pad_none', 'pairwise', 'partition', - 'polynomial_from_roots', 'powerset', 'prepend', 'quantify', @@ -56,18 +41,42 @@ __all__ = [ 'random_product', 'repeatfunc', 'roundrobin', - 'sieve', - 'sliding_window', - 'subslices', 'tabulate', 'tail', 'take', - 'triplewise', 'unique_everseen', 'unique_justseen', ] -_marker = object() + +def accumulate(iterable, func=operator.add): + """ + Return an iterator whose items are the accumulated results of a function + (specified by the optional *func* argument) that takes two arguments. + By default, returns accumulated sums with :func:`operator.add`. + + >>> list(accumulate([1, 2, 3, 4, 5])) # Running sum + [1, 3, 6, 10, 15] + >>> list(accumulate([1, 2, 3], func=operator.mul)) # Running product + [1, 2, 6] + >>> list(accumulate([0, 1, -1, 2, 3, 2], func=max)) # Running maximum + [0, 1, 1, 2, 3, 3] + + This function is available in the ``itertools`` module for Python 3.2 and + greater. + + """ + it = iter(iterable) + try: + total = next(it) + except StopIteration: + return + else: + yield total + + for element in it: + total = func(total, element) + yield total def take(n, iterable): @@ -75,13 +84,12 @@ def take(n, iterable): >>> take(3, range(10)) [0, 1, 2] - - If there are fewer than *n* items in the iterable, all of them are - returned. 
- - >>> take(10, range(3)) + >>> take(5, range(3)) [0, 1, 2] + Effectively a short replacement for ``next`` based iterator consumption + when you want more than one item, but less than the whole iterator. + """ return list(islice(iterable, n)) @@ -107,19 +115,12 @@ def tabulate(function, start=0): def tail(n, iterable): """Return an iterator over the last *n* items of *iterable*. - >>> t = tail(3, 'ABCDEFG') - >>> list(t) - ['E', 'F', 'G'] + >>> t = tail(3, 'ABCDEFG') + >>> list(t) + ['E', 'F', 'G'] """ - # If the given iterable has a length, then we can use islice to get its - # final elements. Note that if the iterable is not actually Iterable, - # either islice or deque will throw a TypeError. This is why we don't - # check if it is Iterable. - if isinstance(iterable, Sized): - yield from islice(iterable, max(0, len(iterable) - n), None) - else: - yield from iter(deque(iterable, maxlen=n)) + return iter(deque(iterable, maxlen=n)) def consume(iterator, n=None): @@ -165,11 +166,11 @@ def consume(iterator, n=None): def nth(iterable, n, default=None): """Returns the nth item or a default value. - >>> l = range(10) - >>> nth(l, 3) - 3 - >>> nth(l, 20, "zebra") - 'zebra' + >>> l = range(10) + >>> nth(l, 3) + 3 + >>> nth(l, 20, "zebra") + 'zebra' """ return next(islice(iterable, n, None), default) @@ -192,17 +193,17 @@ def all_equal(iterable): def quantify(iterable, pred=bool): """Return the how many times the predicate is true. - >>> quantify([True, False, True]) - 2 + >>> quantify([True, False, True]) + 2 """ return sum(map(pred, iterable)) -def pad_none(iterable): +def padnone(iterable): """Returns the sequence of elements and then returns ``None`` indefinitely. - >>> take(5, pad_none(range(3))) + >>> take(5, padnone(range(3))) [0, 1, 2, None, None] Useful for emulating the behavior of the built-in :func:`map` function. @@ -213,14 +214,11 @@ def pad_none(iterable): return chain(iterable, repeat(None)) -padnone = pad_none - - def ncycles(iterable, n): """Returns the sequence elements *n* times - >>> list(ncycles(["a", "b"], 3)) - ['a', 'b', 'a', 'b', 'a', 'b'] + >>> list(ncycles(["a", "b"], 3)) + ['a', 'b', 'a', 'b', 'a', 'b'] """ return chain.from_iterable(repeat(tuple(iterable), n)) @@ -229,8 +227,8 @@ def ncycles(iterable, n): def dotproduct(vec1, vec2): """Returns the dot product of the two iterables. - >>> dotproduct([10, 10], [20, 20]) - 400 + >>> dotproduct([10, 10], [20, 20]) + 400 """ return sum(map(operator.mul, vec1, vec2)) @@ -275,109 +273,27 @@ def repeatfunc(func, times=None, *args): return starmap(func, repeat(args, times)) -def _pairwise(iterable): +def pairwise(iterable): """Returns an iterator of paired items, overlapping, from the original - >>> take(4, pairwise(count())) - [(0, 1), (1, 2), (2, 3), (3, 4)] - - On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`. + >>> take(4, pairwise(count())) + [(0, 1), (1, 2), (2, 3), (3, 4)] """ a, b = tee(iterable) next(b, None) - yield from zip(a, b) + return zip(a, b) -try: - from itertools import pairwise as itertools_pairwise -except ImportError: - pairwise = _pairwise -else: +def grouper(n, iterable, fillvalue=None): + """Collect data into fixed-length chunks or blocks. 
- def pairwise(iterable): - yield from itertools_pairwise(iterable) - - pairwise.__doc__ = _pairwise.__doc__ - - -class UnequalIterablesError(ValueError): - def __init__(self, details=None): - msg = 'Iterables have different lengths' - if details is not None: - msg += (': index 0 has length {}; index {} has length {}').format( - *details - ) - - super().__init__(msg) - - -def _zip_equal_generator(iterables): - for combo in zip_longest(*iterables, fillvalue=_marker): - for val in combo: - if val is _marker: - raise UnequalIterablesError() - yield combo - - -def _zip_equal(*iterables): - # Check whether the iterables are all the same size. - try: - first_size = len(iterables[0]) - for i, it in enumerate(iterables[1:], 1): - size = len(it) - if size != first_size: - break - else: - # If we didn't break out, we can use the built-in zip. - return zip(*iterables) - - # If we did break out, there was a mismatch. - raise UnequalIterablesError(details=(first_size, i, size)) - # If any one of the iterables didn't have a length, start reading - # them until one runs out. - except TypeError: - return _zip_equal_generator(iterables) - - -def grouper(iterable, n, incomplete='fill', fillvalue=None): - """Group elements from *iterable* into fixed-length groups of length *n*. - - >>> list(grouper('ABCDEF', 3)) - [('A', 'B', 'C'), ('D', 'E', 'F')] - - The keyword arguments *incomplete* and *fillvalue* control what happens for - iterables whose length is not a multiple of *n*. - - When *incomplete* is `'fill'`, the last group will contain instances of - *fillvalue*. - - >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x')) - [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')] - - When *incomplete* is `'ignore'`, the last group will not be emitted. - - >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x')) - [('A', 'B', 'C'), ('D', 'E', 'F')] - - When *incomplete* is `'strict'`, a subclass of `ValueError` will be raised. - - >>> it = grouper('ABCDEFG', 3, incomplete='strict') - >>> list(it) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ... - UnequalIterablesError + >>> list(grouper(3, 'ABCDEFG', 'x')) + [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')] """ args = [iter(iterable)] * n - if incomplete == 'fill': - return zip_longest(*args, fillvalue=fillvalue) - if incomplete == 'strict': - return _zip_equal(*args) - if incomplete == 'ignore': - return zip(*args) - else: - raise ValueError('Expected fill, strict, or ignore') + return zip_longest(fillvalue=fillvalue, *args) def roundrobin(*iterables): @@ -393,7 +309,10 @@ def roundrobin(*iterables): """ # Recipe credited to George Sakkis pending = len(iterables) - nexts = cycle(iter(it).__next__ for it in iterables) + if PY2: + nexts = cycle(iter(it).next for it in iterables) + else: + nexts = cycle(iter(it).__next__ for it in iterables) while pending: try: for next in nexts: @@ -415,23 +334,10 @@ def partition(pred, iterable): >>> list(even_items), list(odd_items) ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9]) - If *pred* is None, :func:`bool` is used. 
- - >>> iterable = [0, 1, False, True, '', ' '] - >>> false_items, true_items = partition(None, iterable) - >>> list(false_items), list(true_items) - ([0, False, ''], [1, True, ' ']) - """ - if pred is None: - pred = bool - - evaluations = ((pred(x), x) for x in iterable) - t1, t2 = tee(evaluations) - return ( - (x for (cond, x) in t1 if not cond), - (x for (cond, x) in t2 if cond), - ) + # partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9 + t1, t2 = tee(iterable) + return filterfalse(pred, t1), filter(pred, t2) def powerset(iterable): @@ -469,46 +375,41 @@ def unique_everseen(iterable, key=None): Sequences with a mix of hashable and unhashable items can be used. The function will be slower (i.e., `O(n^2)`) for unhashable items. - Remember that ``list`` objects are unhashable - you can use the *key* - parameter to transform the list to a tuple (which is hashable) to - avoid a slowdown. - - >>> iterable = ([1, 2], [2, 3], [1, 2]) - >>> list(unique_everseen(iterable)) # Slow - [[1, 2], [2, 3]] - >>> list(unique_everseen(iterable, key=tuple)) # Faster - [[1, 2], [2, 3]] - - Similary, you may want to convert unhashable ``set`` objects with - ``key=frozenset``. For ``dict`` objects, - ``key=lambda x: frozenset(x.items())`` can be used. - """ seenset = set() seenset_add = seenset.add seenlist = [] seenlist_add = seenlist.append - use_key = key is not None - - for element in iterable: - k = key(element) if use_key else element - try: - if k not in seenset: - seenset_add(k) - yield element - except TypeError: - if k not in seenlist: - seenlist_add(k) - yield element + if key is None: + for element in iterable: + try: + if element not in seenset: + seenset_add(element) + yield element + except TypeError: + if element not in seenlist: + seenlist_add(element) + yield element + else: + for element in iterable: + k = key(element) + try: + if k not in seenset: + seenset_add(k) + yield element + except TypeError: + if k not in seenlist: + seenlist_add(k) + yield element def unique_justseen(iterable, key=None): """Yields elements in order, ignoring serial duplicates - >>> list(unique_justseen('AAAABBBCCDAABBB')) - ['A', 'B', 'C', 'D', 'A', 'B'] - >>> list(unique_justseen('ABBCcAD', str.lower)) - ['A', 'B', 'C', 'A', 'D'] + >>> list(unique_justseen('AAAABBBCCDAABBB')) + ['A', 'B', 'C', 'D', 'A', 'B'] + >>> list(unique_justseen('ABBCcAD', str.lower)) + ['A', 'B', 'C', 'A', 'D'] """ return map(next, map(operator.itemgetter(1), groupby(iterable, key))) @@ -525,16 +426,6 @@ def iter_except(func, exception, first=None): >>> list(iter_except(l.pop, IndexError)) [2, 1, 0] - Multiple exceptions can be specified as a stopping condition: - - >>> l = [1, 2, 3, '...', 4, 5, 6] - >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError))) - [7, 6, 5] - >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError))) - [4, 3, 2] - >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError))) - [] - """ try: if first is not None: @@ -565,7 +456,7 @@ def first_true(iterable, default=None, pred=None): return next(filter(pred, iterable), default) -def random_product(*args, repeat=1): +def random_product(*args, **kwds): """Draw an item at random from each of the input iterables. >>> random_product('abc', range(4), 'XYZ') # doctest:+SKIP @@ -581,7 +472,7 @@ def random_product(*args, repeat=1): ``itertools.product(*args, **kwarg)``. 
""" - pools = [tuple(pool) for pool in args] * repeat + pools = [tuple(pool) for pool in args] * kwds.get('repeat', 1) return tuple(choice(pool) for pool in pools) @@ -644,12 +535,6 @@ def nth_combination(iterable, r, index): sort position *index* directly, without computing the previous subsequences. - >>> nth_combination(range(5), 3, 5) - (0, 3, 4) - - ``ValueError`` will be raised If *r* is negative or greater than the length - of *iterable*. - ``IndexError`` will be raised if the given *index* is invalid. """ pool = tuple(iterable) n = len(pool) @@ -686,156 +571,7 @@ def prepend(value, iterator): >>> list(prepend(value, iterator)) ['0', '1', '2', '3'] - To prepend multiple values, see :func:`itertools.chain` - or :func:`value_chain`. + To prepend multiple values, see :func:`itertools.chain`. """ return chain([value], iterator) - - -def convolve(signal, kernel): - """Convolve the iterable *signal* with the iterable *kernel*. - - >>> signal = (1, 2, 3, 4, 5) - >>> kernel = [3, 2, 1] - >>> list(convolve(signal, kernel)) - [3, 8, 14, 20, 26, 14, 5] - - Note: the input arguments are not interchangeable, as the *kernel* - is immediately consumed and stored. - - """ - kernel = tuple(kernel)[::-1] - n = len(kernel) - window = deque([0], maxlen=n) * n - for x in chain(signal, repeat(0, n - 1)): - window.append(x) - yield sum(map(operator.mul, kernel, window)) - - -def before_and_after(predicate, it): - """A variant of :func:`takewhile` that allows complete access to the - remainder of the iterator. - - >>> it = iter('ABCdEfGhI') - >>> all_upper, remainder = before_and_after(str.isupper, it) - >>> ''.join(all_upper) - 'ABC' - >>> ''.join(remainder) # takewhile() would lose the 'd' - 'dEfGhI' - - Note that the first iterator must be fully consumed before the second - iterator can generate valid results. - """ - it = iter(it) - transition = [] - - def true_iterator(): - for elem in it: - if predicate(elem): - yield elem - else: - transition.append(elem) - return - - # Note: this is different from itertools recipes to allow nesting - # before_and_after remainders into before_and_after again. See tests - # for an example. - remainder_iterator = chain(transition, it) - - return true_iterator(), remainder_iterator - - -def triplewise(iterable): - """Return overlapping triplets from *iterable*. - - >>> list(triplewise('ABCDE')) - [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')] - - """ - for (a, _), (b, c) in pairwise(pairwise(iterable)): - yield a, b, c - - -def sliding_window(iterable, n): - """Return a sliding window of width *n* over *iterable*. - - >>> list(sliding_window(range(6), 4)) - [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)] - - If *iterable* has fewer than *n* items, then nothing is yielded: - - >>> list(sliding_window(range(3), 4)) - [] - - For a variant with more features, see :func:`windowed`. - """ - it = iter(iterable) - window = deque(islice(it, n), maxlen=n) - if len(window) == n: - yield tuple(window) - for x in it: - window.append(x) - yield tuple(window) - - -def subslices(iterable): - """Return all contiguous non-empty subslices of *iterable*. - - >>> list(subslices('ABC')) - [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']] - - This is similar to :func:`substrings`, but emits items in a different - order. - """ - seq = list(iterable) - slices = starmap(slice, combinations(range(len(seq) + 1), 2)) - return map(operator.getitem, repeat(seq), slices) - - -def polynomial_from_roots(roots): - """Compute a polynomial's coefficients from its roots. 
- - >>> roots = [5, -4, 3] # (x - 5) * (x + 4) * (x - 3) - >>> polynomial_from_roots(roots) # x^3 - 4 * x^2 - 17 * x + 60 - [1, -4, -17, 60] - """ - # Use math.prod for Python 3.8+, - prod = getattr(math, 'prod', lambda x: reduce(operator.mul, x, 1)) - roots = list(map(operator.neg, roots)) - return [ - sum(map(prod, combinations(roots, k))) for k in range(len(roots) + 1) - ] - - -def sieve(n): - """Yield the primes less than n. - - >>> list(sieve(30)) - [2, 3, 5, 7, 11, 13, 17, 19, 23, 29] - """ - isqrt = getattr(math, 'isqrt', lambda x: int(math.sqrt(x))) - limit = isqrt(n) + 1 - data = bytearray([1]) * n - data[:2] = 0, 0 - for p in compress(range(limit), data): - data[p + p : n : p] = bytearray(len(range(p + p, n, p))) - - return compress(count(), data) - - -def batched(iterable, n): - """Batch data into lists of length *n*. The last batch may be shorter. - - >>> list(batched('ABCDEFG', 3)) - [['A', 'B', 'C'], ['D', 'E', 'F'], ['G']] - - This recipe is from the ``itertools`` docs. This library also provides - :func:`chunked`, which has a different implementation. - """ - it = iter(iterable) - while True: - batch = list(islice(it, n)) - if not batch: - break - yield batch diff --git a/libs/common/more_itertools/recipes.pyi b/libs/common/more_itertools/recipes.pyi deleted file mode 100644 index 29415c5a..00000000 --- a/libs/common/more_itertools/recipes.pyi +++ /dev/null @@ -1,110 +0,0 @@ -"""Stubs for more_itertools.recipes""" -from typing import ( - Any, - Callable, - Iterable, - Iterator, - List, - Optional, - Sequence, - Tuple, - TypeVar, - Union, -) -from typing_extensions import overload, Type - -# Type and type variable definitions -_T = TypeVar('_T') -_U = TypeVar('_U') - -def take(n: int, iterable: Iterable[_T]) -> List[_T]: ... -def tabulate( - function: Callable[[int], _T], start: int = ... -) -> Iterator[_T]: ... -def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]: ... -def consume(iterator: Iterable[object], n: Optional[int] = ...) -> None: ... -@overload -def nth(iterable: Iterable[_T], n: int) -> Optional[_T]: ... -@overload -def nth(iterable: Iterable[_T], n: int, default: _U) -> Union[_T, _U]: ... -def all_equal(iterable: Iterable[object]) -> bool: ... -def quantify( - iterable: Iterable[_T], pred: Callable[[_T], bool] = ... -) -> int: ... -def pad_none(iterable: Iterable[_T]) -> Iterator[Optional[_T]]: ... -def padnone(iterable: Iterable[_T]) -> Iterator[Optional[_T]]: ... -def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]: ... -def dotproduct(vec1: Iterable[object], vec2: Iterable[object]) -> object: ... -def flatten(listOfLists: Iterable[Iterable[_T]]) -> Iterator[_T]: ... -def repeatfunc( - func: Callable[..., _U], times: Optional[int] = ..., *args: Any -) -> Iterator[_U]: ... -def pairwise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T]]: ... -def grouper( - iterable: Iterable[_T], - n: int, - incomplete: str = ..., - fillvalue: _U = ..., -) -> Iterator[Tuple[Union[_T, _U], ...]]: ... -def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]: ... -def partition( - pred: Optional[Callable[[_T], object]], iterable: Iterable[_T] -) -> Tuple[Iterator[_T], Iterator[_T]]: ... -def powerset(iterable: Iterable[_T]) -> Iterator[Tuple[_T, ...]]: ... -def unique_everseen( - iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = ... -) -> Iterator[_T]: ... -def unique_justseen( - iterable: Iterable[_T], key: Optional[Callable[[_T], object]] = ... -) -> Iterator[_T]: ... 
-@overload -def iter_except( - func: Callable[[], _T], - exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]], - first: None = ..., -) -> Iterator[_T]: ... -@overload -def iter_except( - func: Callable[[], _T], - exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]], - first: Callable[[], _U], -) -> Iterator[Union[_T, _U]]: ... -@overload -def first_true( - iterable: Iterable[_T], *, pred: Optional[Callable[[_T], object]] = ... -) -> Optional[_T]: ... -@overload -def first_true( - iterable: Iterable[_T], - default: _U, - pred: Optional[Callable[[_T], object]] = ..., -) -> Union[_T, _U]: ... -def random_product( - *args: Iterable[_T], repeat: int = ... -) -> Tuple[_T, ...]: ... -def random_permutation( - iterable: Iterable[_T], r: Optional[int] = ... -) -> Tuple[_T, ...]: ... -def random_combination(iterable: Iterable[_T], r: int) -> Tuple[_T, ...]: ... -def random_combination_with_replacement( - iterable: Iterable[_T], r: int -) -> Tuple[_T, ...]: ... -def nth_combination( - iterable: Iterable[_T], r: int, index: int -) -> Tuple[_T, ...]: ... -def prepend(value: _T, iterator: Iterable[_U]) -> Iterator[Union[_T, _U]]: ... -def convolve(signal: Iterable[_T], kernel: Iterable[_T]) -> Iterator[_T]: ... -def before_and_after( - predicate: Callable[[_T], bool], it: Iterable[_T] -) -> Tuple[Iterator[_T], Iterator[_T]]: ... -def triplewise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T, _T]]: ... -def sliding_window( - iterable: Iterable[_T], n: int -) -> Iterator[Tuple[_T, ...]]: ... -def subslices(iterable: Iterable[_T]) -> Iterator[List[_T]]: ... -def polynomial_from_roots(roots: Sequence[int]) -> List[int]: ... -def sieve(n: int) -> Iterator[int]: ... -def batched( - iterable: Iterable[_T], - n: int, -) -> Iterator[List[_T]]: ... 
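A minimal usage sketch of the 5.0.0-era recipes restored in the hunks above (this snippet is illustrative only, not part of the patch): it assumes the vendored package imports as ``more_itertools`` and uses only the signatures and doctest values shown above, namely that this ``grouper`` takes the group size before the iterable and that ``padnone`` is the only spelling of that helper.

from more_itertools import grouper, padnone, take

# 5.0.0 signature: grouper(n, iterable, fillvalue=None)
print(list(grouper(3, 'ABCDEFG', 'x')))
# -> [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]

# padnone() extends an iterable with None values indefinitely
print(take(5, padnone(range(3))))
# -> [0, 1, 2, None, None]

The newer code removed above flips ``grouper``'s argument order (iterable first, with an ``incomplete`` mode) and renames ``padnone`` to ``pad_none``, so call sites written against one version are not interchangeable with the other.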
diff --git a/libs/common/more_itertools/py.typed b/libs/common/more_itertools/tests/__init__.py similarity index 100% rename from libs/common/more_itertools/py.typed rename to libs/common/more_itertools/tests/__init__.py diff --git a/libs/common/more_itertools/tests/test_more.py b/libs/common/more_itertools/tests/test_more.py new file mode 100644 index 00000000..eacf8a8a --- /dev/null +++ b/libs/common/more_itertools/tests/test_more.py @@ -0,0 +1,2313 @@ +from __future__ import division, print_function, unicode_literals + +from collections import OrderedDict +from decimal import Decimal +from doctest import DocTestSuite +from fractions import Fraction +from functools import partial, reduce +from heapq import merge +from io import StringIO +from itertools import ( + chain, + count, + groupby, + islice, + permutations, + product, + repeat, +) +from operator import add, mul, itemgetter +from unittest import TestCase + +from six.moves import filter, map, range, zip + +import more_itertools as mi + + +def load_tests(loader, tests, ignore): + # Add the doctests + tests.addTests(DocTestSuite('more_itertools.more')) + return tests + + +class CollateTests(TestCase): + """Unit tests for ``collate()``""" + # Also accidentally tests peekable, though that could use its own tests + + def test_default(self): + """Test with the default `key` function.""" + iterables = [range(4), range(7), range(3, 6)] + self.assertEqual( + sorted(reduce(list.__add__, [list(it) for it in iterables])), + list(mi.collate(*iterables)) + ) + + def test_key(self): + """Test using a custom `key` function.""" + iterables = [range(5, 0, -1), range(4, 0, -1)] + actual = sorted( + reduce(list.__add__, [list(it) for it in iterables]), reverse=True + ) + expected = list(mi.collate(*iterables, key=lambda x: -x)) + self.assertEqual(actual, expected) + + def test_empty(self): + """Be nice if passed an empty list of iterables.""" + self.assertEqual([], list(mi.collate())) + + def test_one(self): + """Work when only 1 iterable is passed.""" + self.assertEqual([0, 1], list(mi.collate(range(2)))) + + def test_reverse(self): + """Test the `reverse` kwarg.""" + iterables = [range(4, 0, -1), range(7, 0, -1), range(3, 6, -1)] + + actual = sorted( + reduce(list.__add__, [list(it) for it in iterables]), reverse=True + ) + expected = list(mi.collate(*iterables, reverse=True)) + self.assertEqual(actual, expected) + + def test_alias(self): + self.assertNotEqual(merge.__doc__, mi.collate.__doc__) + self.assertNotEqual(partial.__doc__, mi.collate.__doc__) + + +class ChunkedTests(TestCase): + """Tests for ``chunked()``""" + + def test_even(self): + """Test when ``n`` divides evenly into the length of the iterable.""" + self.assertEqual( + list(mi.chunked('ABCDEF', 3)), [['A', 'B', 'C'], ['D', 'E', 'F']] + ) + + def test_odd(self): + """Test when ``n`` does not divide evenly into the length of the + iterable. + + """ + self.assertEqual( + list(mi.chunked('ABCDE', 3)), [['A', 'B', 'C'], ['D', 'E']] + ) + + +class FirstTests(TestCase): + """Tests for ``first()``""" + + def test_many(self): + """Test that it works on many-item iterables.""" + # Also try it on a generator expression to make sure it works on + # whatever those return, across Python versions. 
+        self.assertEqual(mi.first(x for x in range(4)), 0)
+
+    def test_one(self):
+        """Test that it doesn't raise StopIteration prematurely."""
+        self.assertEqual(mi.first([3]), 3)
+
+    def test_empty_stop_iteration(self):
+        """It should raise ValueError for empty iterables."""
+        self.assertRaises(ValueError, lambda: mi.first([]))
+
+    def test_default(self):
+        """It should return the provided default arg for empty iterables."""
+        self.assertEqual(mi.first([], 'boo'), 'boo')
+
+
+class IterOnlyRange:
+    """User-defined iterable class which only supports __iter__.
+
+    It is not specified to inherit ``object``, so indexing on an instance will
+    raise an ``AttributeError`` rather than ``TypeError`` in Python 2.
+
+    >>> r = IterOnlyRange(5)
+    >>> r[0]
+    AttributeError: IterOnlyRange instance has no attribute '__getitem__'
+
+    Note: In Python 3, ``TypeError`` will be raised because ``object`` is
+    inherited implicitly by default.
+
+    >>> r[0]
+    TypeError: 'IterOnlyRange' object does not support indexing
+    """
+    def __init__(self, n):
+        """Set the length of the range."""
+        self.n = n
+
+    def __iter__(self):
+        """Works same as range()."""
+        return iter(range(self.n))
+
+
+class LastTests(TestCase):
+    """Tests for ``last()``"""
+
+    def test_many_nonsliceable(self):
+        """Test that it works on many-item non-slice-able iterables."""
+        # Also try it on a generator expression to make sure it works on
+        # whatever those return, across Python versions.
+        self.assertEqual(mi.last(x for x in range(4)), 3)
+
+    def test_one_nonsliceable(self):
+        """Test that it doesn't raise StopIteration prematurely."""
+        self.assertEqual(mi.last(x for x in range(1)), 0)
+
+    def test_empty_stop_iteration_nonsliceable(self):
+        """It should raise ValueError for empty non-slice-able iterables."""
+        self.assertRaises(ValueError, lambda: mi.last(x for x in range(0)))
+
+    def test_default_nonsliceable(self):
+        """It should return the provided default arg for empty non-slice-able
+        iterables.
+        """
+        self.assertEqual(mi.last((x for x in range(0)), 'boo'), 'boo')
+
+    def test_many_sliceable(self):
+        """Test that it works on many-item slice-able iterables."""
+        self.assertEqual(mi.last([0, 1, 2, 3]), 3)
+
+    def test_one_sliceable(self):
+        """Test that it doesn't raise StopIteration prematurely."""
+        self.assertEqual(mi.last([3]), 3)
+
+    def test_empty_stop_iteration_sliceable(self):
+        """It should raise ValueError for empty slice-able iterables."""
+        self.assertRaises(ValueError, lambda: mi.last([]))
+
+    def test_default_sliceable(self):
+        """It should return the provided default arg for empty slice-able
+        iterables.
+ """ + self.assertEqual(mi.last([], 'boo'), 'boo') + + def test_dict(self): + """last(dic) and last(dic.keys()) should return same result.""" + dic = {'a': 1, 'b': 2, 'c': 3} + self.assertEqual(mi.last(dic), mi.last(dic.keys())) + + def test_ordereddict(self): + """last(dic) should return the last key.""" + od = OrderedDict() + od['a'] = 1 + od['b'] = 2 + od['c'] = 3 + self.assertEqual(mi.last(od), 'c') + + def test_customrange(self): + """It should work on custom class where [] raises AttributeError.""" + self.assertEqual(mi.last(IterOnlyRange(5)), 4) + + +class PeekableTests(TestCase): + """Tests for ``peekable()`` behavor not incidentally covered by testing + ``collate()`` + + """ + def test_peek_default(self): + """Make sure passing a default into ``peek()`` works.""" + p = mi.peekable([]) + self.assertEqual(p.peek(7), 7) + + def test_truthiness(self): + """Make sure a ``peekable`` tests true iff there are items remaining in + the iterable. + + """ + p = mi.peekable([]) + self.assertFalse(p) + + p = mi.peekable(range(3)) + self.assertTrue(p) + + def test_simple_peeking(self): + """Make sure ``next`` and ``peek`` advance and don't advance the + iterator, respectively. + + """ + p = mi.peekable(range(10)) + self.assertEqual(next(p), 0) + self.assertEqual(p.peek(), 1) + self.assertEqual(next(p), 1) + + def test_indexing(self): + """ + Indexing into the peekable shouldn't advance the iterator. + """ + p = mi.peekable('abcdefghijkl') + + # The 0th index is what ``next()`` will return + self.assertEqual(p[0], 'a') + self.assertEqual(next(p), 'a') + + # Indexing further into the peekable shouldn't advance the itertor + self.assertEqual(p[2], 'd') + self.assertEqual(next(p), 'b') + + # The 0th index moves up with the iterator; the last index follows + self.assertEqual(p[0], 'c') + self.assertEqual(p[9], 'l') + + self.assertEqual(next(p), 'c') + self.assertEqual(p[8], 'l') + + # Negative indexing should work too + self.assertEqual(p[-2], 'k') + self.assertEqual(p[-9], 'd') + self.assertRaises(IndexError, lambda: p[-10]) + + def test_slicing(self): + """Slicing the peekable shouldn't advance the iterator.""" + seq = list('abcdefghijkl') + p = mi.peekable(seq) + + # Slicing the peekable should just be like slicing a re-iterable + self.assertEqual(p[1:4], seq[1:4]) + + # Advancing the iterator moves the slices up also + self.assertEqual(next(p), 'a') + self.assertEqual(p[1:4], seq[1:][1:4]) + + # Implicit starts and stop should work + self.assertEqual(p[:5], seq[1:][:5]) + self.assertEqual(p[:], seq[1:][:]) + + # Indexing past the end should work + self.assertEqual(p[:100], seq[1:][:100]) + + # Steps should work, including negative + self.assertEqual(p[::2], seq[1:][::2]) + self.assertEqual(p[::-1], seq[1:][::-1]) + + def test_slicing_reset(self): + """Test slicing on a fresh iterable each time""" + iterable = ['0', '1', '2', '3', '4', '5'] + indexes = list(range(-4, len(iterable) + 4)) + [None] + steps = [1, 2, 3, 4, -1, -2, -3, 4] + for slice_args in product(indexes, indexes, steps): + it = iter(iterable) + p = mi.peekable(it) + next(p) + index = slice(*slice_args) + actual = p[index] + expected = iterable[1:][index] + self.assertEqual(actual, expected, slice_args) + + def test_slicing_error(self): + iterable = '01234567' + p = mi.peekable(iter(iterable)) + + # Prime the cache + p.peek() + old_cache = list(p._cache) + + # Illegal slice + with self.assertRaises(ValueError): + p[1:-1:0] + + # Neither the cache nor the iteration should be affected + self.assertEqual(old_cache, list(p._cache)) + 
self.assertEqual(list(p), list(iterable)) + + def test_passthrough(self): + """Iterating a peekable without using ``peek()`` or ``prepend()`` + should just give the underlying iterable's elements (a trivial test but + useful to set a baseline in case something goes wrong)""" + expected = [1, 2, 3, 4, 5] + actual = list(mi.peekable(expected)) + self.assertEqual(actual, expected) + + # prepend() behavior tests + + def test_prepend(self): + """Tests intersperesed ``prepend()`` and ``next()`` calls""" + it = mi.peekable(range(2)) + actual = [] + + # Test prepend() before next() + it.prepend(10) + actual += [next(it), next(it)] + + # Test prepend() between next()s + it.prepend(11) + actual += [next(it), next(it)] + + # Test prepend() after source iterable is consumed + it.prepend(12) + actual += [next(it)] + + expected = [10, 0, 11, 1, 12] + self.assertEqual(actual, expected) + + def test_multi_prepend(self): + """Tests prepending multiple items and getting them in proper order""" + it = mi.peekable(range(5)) + actual = [next(it), next(it)] + it.prepend(10, 11, 12) + it.prepend(20, 21) + actual += list(it) + expected = [0, 1, 20, 21, 10, 11, 12, 2, 3, 4] + self.assertEqual(actual, expected) + + def test_empty(self): + """Tests prepending in front of an empty iterable""" + it = mi.peekable([]) + it.prepend(10) + actual = list(it) + expected = [10] + self.assertEqual(actual, expected) + + def test_prepend_truthiness(self): + """Tests that ``__bool__()`` or ``__nonzero__()`` works properly + with ``prepend()``""" + it = mi.peekable(range(5)) + self.assertTrue(it) + actual = list(it) + self.assertFalse(it) + it.prepend(10) + self.assertTrue(it) + actual += [next(it)] + self.assertFalse(it) + expected = [0, 1, 2, 3, 4, 10] + self.assertEqual(actual, expected) + + def test_multi_prepend_peek(self): + """Tests prepending multiple elements and getting them in reverse order + while peeking""" + it = mi.peekable(range(5)) + actual = [next(it), next(it)] + self.assertEqual(it.peek(), 2) + it.prepend(10, 11, 12) + self.assertEqual(it.peek(), 10) + it.prepend(20, 21) + self.assertEqual(it.peek(), 20) + actual += list(it) + self.assertFalse(it) + expected = [0, 1, 20, 21, 10, 11, 12, 2, 3, 4] + self.assertEqual(actual, expected) + + def test_prepend_after_stop(self): + """Test resuming iteration after a previous exhaustion""" + it = mi.peekable(range(3)) + self.assertEqual(list(it), [0, 1, 2]) + self.assertRaises(StopIteration, lambda: next(it)) + it.prepend(10) + self.assertEqual(next(it), 10) + self.assertRaises(StopIteration, lambda: next(it)) + + def test_prepend_slicing(self): + """Tests interaction between prepending and slicing""" + seq = list(range(20)) + p = mi.peekable(seq) + + p.prepend(30, 40, 50) + pseq = [30, 40, 50] + seq # pseq for prepended_seq + + # adapt the specific tests from test_slicing + self.assertEqual(p[0], 30) + self.assertEqual(p[1:8], pseq[1:8]) + self.assertEqual(p[1:], pseq[1:]) + self.assertEqual(p[:5], pseq[:5]) + self.assertEqual(p[:], pseq[:]) + self.assertEqual(p[:100], pseq[:100]) + self.assertEqual(p[::2], pseq[::2]) + self.assertEqual(p[::-1], pseq[::-1]) + + def test_prepend_indexing(self): + """Tests interaction between prepending and indexing""" + seq = list(range(20)) + p = mi.peekable(seq) + + p.prepend(30, 40, 50) + + self.assertEqual(p[0], 30) + self.assertEqual(next(p), 30) + self.assertEqual(p[2], 0) + self.assertEqual(next(p), 40) + self.assertEqual(p[0], 50) + self.assertEqual(p[9], 8) + self.assertEqual(next(p), 50) + self.assertEqual(p[8], 8) + 
self.assertEqual(p[-2], 18) + self.assertEqual(p[-9], 11) + self.assertRaises(IndexError, lambda: p[-21]) + + def test_prepend_iterable(self): + """Tests prepending from an iterable""" + it = mi.peekable(range(5)) + # Don't directly use the range() object to avoid any range-specific + # optimizations + it.prepend(*(x for x in range(5))) + actual = list(it) + expected = list(chain(range(5), range(5))) + self.assertEqual(actual, expected) + + def test_prepend_many(self): + """Tests that prepending a huge number of elements works""" + it = mi.peekable(range(5)) + # Don't directly use the range() object to avoid any range-specific + # optimizations + it.prepend(*(x for x in range(20000))) + actual = list(it) + expected = list(chain(range(20000), range(5))) + self.assertEqual(actual, expected) + + def test_prepend_reversed(self): + """Tests prepending from a reversed iterable""" + it = mi.peekable(range(3)) + it.prepend(*reversed((10, 11, 12))) + actual = list(it) + expected = [12, 11, 10, 0, 1, 2] + self.assertEqual(actual, expected) + + +class ConsumerTests(TestCase): + """Tests for ``consumer()``""" + + def test_consumer(self): + @mi.consumer + def eater(): + while True: + x = yield # noqa + + e = eater() + e.send('hi') # without @consumer, would raise TypeError + + +class DistinctPermutationsTests(TestCase): + def test_distinct_permutations(self): + """Make sure the output for ``distinct_permutations()`` is the same as + set(permutations(it)). + + """ + iterable = ['z', 'a', 'a', 'q', 'q', 'q', 'y'] + test_output = sorted(mi.distinct_permutations(iterable)) + ref_output = sorted(set(permutations(iterable))) + self.assertEqual(test_output, ref_output) + + def test_other_iterables(self): + """Make sure ``distinct_permutations()`` accepts a different type of + iterables. 
+ + """ + # a generator + iterable = (c for c in ['z', 'a', 'a', 'q', 'q', 'q', 'y']) + test_output = sorted(mi.distinct_permutations(iterable)) + # "reload" it + iterable = (c for c in ['z', 'a', 'a', 'q', 'q', 'q', 'y']) + ref_output = sorted(set(permutations(iterable))) + self.assertEqual(test_output, ref_output) + + # an iterator + iterable = iter(['z', 'a', 'a', 'q', 'q', 'q', 'y']) + test_output = sorted(mi.distinct_permutations(iterable)) + # "reload" it + iterable = iter(['z', 'a', 'a', 'q', 'q', 'q', 'y']) + ref_output = sorted(set(permutations(iterable))) + self.assertEqual(test_output, ref_output) + + +class IlenTests(TestCase): + def test_ilen(self): + """Sanity-checks for ``ilen()``.""" + # Non-empty + self.assertEqual( + mi.ilen(filter(lambda x: x % 10 == 0, range(101))), 11 + ) + + # Empty + self.assertEqual(mi.ilen((x for x in range(0))), 0) + + # Iterable with __len__ + self.assertEqual(mi.ilen(list(range(6))), 6) + + +class WithIterTests(TestCase): + def test_with_iter(self): + s = StringIO('One fish\nTwo fish') + initial_words = [line.split()[0] for line in mi.with_iter(s)] + + # Iterable's items should be faithfully represented + self.assertEqual(initial_words, ['One', 'Two']) + # The file object should be closed + self.assertTrue(s.closed) + + +class OneTests(TestCase): + def test_basic(self): + it = iter(['item']) + self.assertEqual(mi.one(it), 'item') + + def test_too_short(self): + it = iter([]) + self.assertRaises(ValueError, lambda: mi.one(it)) + self.assertRaises(IndexError, lambda: mi.one(it, too_short=IndexError)) + + def test_too_long(self): + it = count() + self.assertRaises(ValueError, lambda: mi.one(it)) # burn 0 and 1 + self.assertEqual(next(it), 2) + self.assertRaises( + OverflowError, lambda: mi.one(it, too_long=OverflowError) + ) + + +class IntersperseTest(TestCase): + """ Tests for intersperse() """ + + def test_even(self): + iterable = (x for x in '01') + self.assertEqual( + list(mi.intersperse(None, iterable)), ['0', None, '1'] + ) + + def test_odd(self): + iterable = (x for x in '012') + self.assertEqual( + list(mi.intersperse(None, iterable)), ['0', None, '1', None, '2'] + ) + + def test_nested(self): + element = ('a', 'b') + iterable = (x for x in '012') + actual = list(mi.intersperse(element, iterable)) + expected = ['0', ('a', 'b'), '1', ('a', 'b'), '2'] + self.assertEqual(actual, expected) + + def test_not_iterable(self): + self.assertRaises(TypeError, lambda: mi.intersperse('x', 1)) + + def test_n(self): + for n, element, expected in [ + (1, '_', ['0', '_', '1', '_', '2', '_', '3', '_', '4', '_', '5']), + (2, '_', ['0', '1', '_', '2', '3', '_', '4', '5']), + (3, '_', ['0', '1', '2', '_', '3', '4', '5']), + (4, '_', ['0', '1', '2', '3', '_', '4', '5']), + (5, '_', ['0', '1', '2', '3', '4', '_', '5']), + (6, '_', ['0', '1', '2', '3', '4', '5']), + (7, '_', ['0', '1', '2', '3', '4', '5']), + (3, ['a', 'b'], ['0', '1', '2', ['a', 'b'], '3', '4', '5']), + ]: + iterable = (x for x in '012345') + actual = list(mi.intersperse(element, iterable, n=n)) + self.assertEqual(actual, expected) + + def test_n_zero(self): + self.assertRaises( + ValueError, lambda: list(mi.intersperse('x', '012', n=0)) + ) + + +class UniqueToEachTests(TestCase): + """Tests for ``unique_to_each()``""" + + def test_all_unique(self): + """When all the input iterables are unique the output should match + the input.""" + iterables = [[1, 2], [3, 4, 5], [6, 7, 8]] + self.assertEqual(mi.unique_to_each(*iterables), iterables) + + def test_duplicates(self): + """When there are 
duplicates in any of the input iterables that aren't + in the rest, those duplicates should be emitted.""" + iterables = ["mississippi", "missouri"] + self.assertEqual( + mi.unique_to_each(*iterables), [['p', 'p'], ['o', 'u', 'r']] + ) + + def test_mixed(self): + """When the input iterables contain different types the function should + still behave properly""" + iterables = ['x', (i for i in range(3)), [1, 2, 3], tuple()] + self.assertEqual(mi.unique_to_each(*iterables), [['x'], [0], [3], []]) + + +class WindowedTests(TestCase): + """Tests for ``windowed()``""" + + def test_basic(self): + actual = list(mi.windowed([1, 2, 3, 4, 5], 3)) + expected = [(1, 2, 3), (2, 3, 4), (3, 4, 5)] + self.assertEqual(actual, expected) + + def test_large_size(self): + """ + When the window size is larger than the iterable, and no fill value is + given,``None`` should be filled in. + """ + actual = list(mi.windowed([1, 2, 3, 4, 5], 6)) + expected = [(1, 2, 3, 4, 5, None)] + self.assertEqual(actual, expected) + + def test_fillvalue(self): + """ + When sizes don't match evenly, the given fill value should be used. + """ + iterable = [1, 2, 3, 4, 5] + + for n, kwargs, expected in [ + (6, {}, [(1, 2, 3, 4, 5, '!')]), # n > len(iterable) + (3, {'step': 3}, [(1, 2, 3), (4, 5, '!')]), # using ``step`` + ]: + actual = list(mi.windowed(iterable, n, fillvalue='!', **kwargs)) + self.assertEqual(actual, expected) + + def test_zero(self): + """When the window size is zero, an empty tuple should be emitted.""" + actual = list(mi.windowed([1, 2, 3, 4, 5], 0)) + expected = [tuple()] + self.assertEqual(actual, expected) + + def test_negative(self): + """When the window size is negative, ValueError should be raised.""" + with self.assertRaises(ValueError): + list(mi.windowed([1, 2, 3, 4, 5], -1)) + + def test_step(self): + """The window should advance by the number of steps provided""" + iterable = [1, 2, 3, 4, 5, 6, 7] + for n, step, expected in [ + (3, 2, [(1, 2, 3), (3, 4, 5), (5, 6, 7)]), # n > step + (3, 3, [(1, 2, 3), (4, 5, 6), (7, None, None)]), # n == step + (3, 4, [(1, 2, 3), (5, 6, 7)]), # line up nicely + (3, 5, [(1, 2, 3), (6, 7, None)]), # off by one + (3, 6, [(1, 2, 3), (7, None, None)]), # off by two + (3, 7, [(1, 2, 3)]), # step past the end + (7, 8, [(1, 2, 3, 4, 5, 6, 7)]), # step > len(iterable) + ]: + actual = list(mi.windowed(iterable, n, step=step)) + self.assertEqual(actual, expected) + + # Step must be greater than or equal to 1 + with self.assertRaises(ValueError): + list(mi.windowed(iterable, 3, step=0)) + + +class SubstringsTests(TestCase): + def test_basic(self): + iterable = (x for x in range(4)) + actual = list(mi.substrings(iterable)) + expected = [ + (0,), + (1,), + (2,), + (3,), + (0, 1), + (1, 2), + (2, 3), + (0, 1, 2), + (1, 2, 3), + (0, 1, 2, 3), + ] + self.assertEqual(actual, expected) + + def test_strings(self): + iterable = 'abc' + actual = list(mi.substrings(iterable)) + expected = [ + ('a',), ('b',), ('c',), ('a', 'b'), ('b', 'c'), ('a', 'b', 'c') + ] + self.assertEqual(actual, expected) + + def test_empty(self): + iterable = iter([]) + actual = list(mi.substrings(iterable)) + expected = [] + self.assertEqual(actual, expected) + + def test_order(self): + iterable = [2, 0, 1] + actual = list(mi.substrings(iterable)) + expected = [(2,), (0,), (1,), (2, 0), (0, 1), (2, 0, 1)] + self.assertEqual(actual, expected) + + +class BucketTests(TestCase): + """Tests for ``bucket()``""" + + def test_basic(self): + iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33] + D = mi.bucket(iterable, 
key=lambda x: 10 * (x // 10)) + + # In-order access + self.assertEqual(list(D[10]), [10, 11, 12]) + + # Out of order access + self.assertEqual(list(D[30]), [30, 31, 33]) + self.assertEqual(list(D[20]), [20, 21, 22, 23]) + + self.assertEqual(list(D[40]), []) # Nothing in here! + + def test_in(self): + iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33] + D = mi.bucket(iterable, key=lambda x: 10 * (x // 10)) + + self.assertIn(10, D) + self.assertNotIn(40, D) + self.assertIn(20, D) + self.assertNotIn(21, D) + + # Checking in-ness shouldn't advance the iterator + self.assertEqual(next(D[10]), 10) + + def test_validator(self): + iterable = count(0) + key = lambda x: int(str(x)[0]) # First digit of each number + validator = lambda x: 0 < x < 10 # No leading zeros + D = mi.bucket(iterable, key, validator=validator) + self.assertEqual(mi.take(3, D[1]), [1, 10, 11]) + self.assertNotIn(0, D) # Non-valid entries don't return True + self.assertNotIn(0, D._cache) # Don't store non-valid entries + self.assertEqual(list(D[0]), []) + + +class SpyTests(TestCase): + """Tests for ``spy()``""" + + def test_basic(self): + original_iterable = iter('abcdefg') + head, new_iterable = mi.spy(original_iterable) + self.assertEqual(head, ['a']) + self.assertEqual( + list(new_iterable), ['a', 'b', 'c', 'd', 'e', 'f', 'g'] + ) + + def test_unpacking(self): + original_iterable = iter('abcdefg') + (first, second, third), new_iterable = mi.spy(original_iterable, 3) + self.assertEqual(first, 'a') + self.assertEqual(second, 'b') + self.assertEqual(third, 'c') + self.assertEqual( + list(new_iterable), ['a', 'b', 'c', 'd', 'e', 'f', 'g'] + ) + + def test_too_many(self): + original_iterable = iter('abc') + head, new_iterable = mi.spy(original_iterable, 4) + self.assertEqual(head, ['a', 'b', 'c']) + self.assertEqual(list(new_iterable), ['a', 'b', 'c']) + + def test_zero(self): + original_iterable = iter('abc') + head, new_iterable = mi.spy(original_iterable, 0) + self.assertEqual(head, []) + self.assertEqual(list(new_iterable), ['a', 'b', 'c']) + + +class InterleaveTests(TestCase): + def test_even(self): + actual = list(mi.interleave([1, 4, 7], [2, 5, 8], [3, 6, 9])) + expected = [1, 2, 3, 4, 5, 6, 7, 8, 9] + self.assertEqual(actual, expected) + + def test_short(self): + actual = list(mi.interleave([1, 4], [2, 5, 7], [3, 6, 8])) + expected = [1, 2, 3, 4, 5, 6] + self.assertEqual(actual, expected) + + def test_mixed_types(self): + it_list = ['a', 'b', 'c', 'd'] + it_str = '12345' + it_inf = count() + actual = list(mi.interleave(it_list, it_str, it_inf)) + expected = ['a', '1', 0, 'b', '2', 1, 'c', '3', 2, 'd', '4', 3] + self.assertEqual(actual, expected) + + +class InterleaveLongestTests(TestCase): + def test_even(self): + actual = list(mi.interleave_longest([1, 4, 7], [2, 5, 8], [3, 6, 9])) + expected = [1, 2, 3, 4, 5, 6, 7, 8, 9] + self.assertEqual(actual, expected) + + def test_short(self): + actual = list(mi.interleave_longest([1, 4], [2, 5, 7], [3, 6, 8])) + expected = [1, 2, 3, 4, 5, 6, 7, 8] + self.assertEqual(actual, expected) + + def test_mixed_types(self): + it_list = ['a', 'b', 'c', 'd'] + it_str = '12345' + it_gen = (x for x in range(3)) + actual = list(mi.interleave_longest(it_list, it_str, it_gen)) + expected = ['a', '1', 0, 'b', '2', 1, 'c', '3', 2, 'd', '4', '5'] + self.assertEqual(actual, expected) + + +class TestCollapse(TestCase): + """Tests for ``collapse()``""" + + def test_collapse(self): + l = [[1], 2, [[3], 4], [[[5]]]] + self.assertEqual(list(mi.collapse(l)), [1, 2, 3, 4, 5]) + + def 
test_collapse_to_string(self): + l = [["s1"], "s2", [["s3"], "s4"], [[["s5"]]]] + self.assertEqual(list(mi.collapse(l)), ["s1", "s2", "s3", "s4", "s5"]) + + def test_collapse_flatten(self): + l = [[1], [2], [[3], 4], [[[5]]]] + self.assertEqual(list(mi.collapse(l, levels=1)), list(mi.flatten(l))) + + def test_collapse_to_level(self): + l = [[1], 2, [[3], 4], [[[5]]]] + self.assertEqual(list(mi.collapse(l, levels=2)), [1, 2, 3, 4, [5]]) + self.assertEqual( + list(mi.collapse(mi.collapse(l, levels=1), levels=1)), + list(mi.collapse(l, levels=2)) + ) + + def test_collapse_to_list(self): + l = (1, [2], (3, [4, (5,)], 'ab')) + actual = list(mi.collapse(l, base_type=list)) + expected = [1, [2], 3, [4, (5,)], 'ab'] + self.assertEqual(actual, expected) + + +class SideEffectTests(TestCase): + """Tests for ``side_effect()``""" + + def test_individual(self): + # The function increments the counter for each call + counter = [0] + + def func(arg): + counter[0] += 1 + + result = list(mi.side_effect(func, range(10))) + self.assertEqual(result, list(range(10))) + self.assertEqual(counter[0], 10) + + def test_chunked(self): + # The function increments the counter for each call + counter = [0] + + def func(arg): + counter[0] += 1 + + result = list(mi.side_effect(func, range(10), 2)) + self.assertEqual(result, list(range(10))) + self.assertEqual(counter[0], 5) + + def test_before_after(self): + f = StringIO() + collector = [] + + def func(item): + print(item, file=f) + collector.append(f.getvalue()) + + def it(): + yield 'a' + yield 'b' + raise RuntimeError('kaboom') + + before = lambda: print('HEADER', file=f) + after = f.close + + try: + mi.consume(mi.side_effect(func, it(), before=before, after=after)) + except RuntimeError: + pass + + # The iterable should have been written to the file + self.assertEqual(collector, ['HEADER\na\n', 'HEADER\na\nb\n']) + + # The file should be closed even though something bad happened + self.assertTrue(f.closed) + + def test_before_fails(self): + f = StringIO() + func = lambda x: print(x, file=f) + + def before(): + raise RuntimeError('ouch') + + try: + mi.consume( + mi.side_effect(func, 'abc', before=before, after=f.close) + ) + except RuntimeError: + pass + + # The file should be closed even though something bad happened in the + # before function + self.assertTrue(f.closed) + + +class SlicedTests(TestCase): + """Tests for ``sliced()``""" + + def test_even(self): + """Test when the length of the sequence is divisible by *n*""" + seq = 'ABCDEFGHI' + self.assertEqual(list(mi.sliced(seq, 3)), ['ABC', 'DEF', 'GHI']) + + def test_odd(self): + """Test when the length of the sequence is not divisible by *n*""" + seq = 'ABCDEFGHI' + self.assertEqual(list(mi.sliced(seq, 4)), ['ABCD', 'EFGH', 'I']) + + def test_not_sliceable(self): + seq = (x for x in 'ABCDEFGHI') + + with self.assertRaises(TypeError): + list(mi.sliced(seq, 3)) + + +class SplitAtTests(TestCase): + """Tests for ``split()``""" + + def comp_with_str_split(self, str_to_split, delim): + pred = lambda c: c == delim + actual = list(map(''.join, mi.split_at(str_to_split, pred))) + expected = str_to_split.split(delim) + self.assertEqual(actual, expected) + + def test_seperators(self): + test_strs = ['', 'abcba', 'aaabbbcccddd', 'e'] + for s, delim in product(test_strs, 'abcd'): + self.comp_with_str_split(s, delim) + + +class SplitBeforeTest(TestCase): + """Tests for ``split_before()``""" + + def test_starts_with_sep(self): + actual = list(mi.split_before('xooxoo', lambda c: c == 'x')) + expected = [['x', 'o', 'o'], ['x', 
'o', 'o']] + self.assertEqual(actual, expected) + + def test_ends_with_sep(self): + actual = list(mi.split_before('ooxoox', lambda c: c == 'x')) + expected = [['o', 'o'], ['x', 'o', 'o'], ['x']] + self.assertEqual(actual, expected) + + def test_no_sep(self): + actual = list(mi.split_before('ooo', lambda c: c == 'x')) + expected = [['o', 'o', 'o']] + self.assertEqual(actual, expected) + + +class SplitAfterTest(TestCase): + """Tests for ``split_after()``""" + + def test_starts_with_sep(self): + actual = list(mi.split_after('xooxoo', lambda c: c == 'x')) + expected = [['x'], ['o', 'o', 'x'], ['o', 'o']] + self.assertEqual(actual, expected) + + def test_ends_with_sep(self): + actual = list(mi.split_after('ooxoox', lambda c: c == 'x')) + expected = [['o', 'o', 'x'], ['o', 'o', 'x']] + self.assertEqual(actual, expected) + + def test_no_sep(self): + actual = list(mi.split_after('ooo', lambda c: c == 'x')) + expected = [['o', 'o', 'o']] + self.assertEqual(actual, expected) + + +class SplitIntoTests(TestCase): + """Tests for ``split_into()``""" + + def test_iterable_just_right(self): + """Size of ``iterable`` equals the sum of ``sizes``.""" + iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] + sizes = [2, 3, 4] + expected = [[1, 2], [3, 4, 5], [6, 7, 8, 9]] + actual = list(mi.split_into(iterable, sizes)) + self.assertEqual(actual, expected) + + def test_iterable_too_small(self): + """Size of ``iterable`` is smaller than sum of ``sizes``. Last return + list is shorter as a result.""" + iterable = [1, 2, 3, 4, 5, 6, 7] + sizes = [2, 3, 4] + expected = [[1, 2], [3, 4, 5], [6, 7]] + actual = list(mi.split_into(iterable, sizes)) + self.assertEqual(actual, expected) + + def test_iterable_too_small_extra(self): + """Size of ``iterable`` is smaller than sum of ``sizes``. Second last + return list is shorter and last return list is empty as a result.""" + iterable = [1, 2, 3, 4, 5, 6, 7] + sizes = [2, 3, 4, 5] + expected = [[1, 2], [3, 4, 5], [6, 7], []] + actual = list(mi.split_into(iterable, sizes)) + self.assertEqual(actual, expected) + + def test_iterable_too_large(self): + """Size of ``iterable`` is larger than sum of ``sizes``. Not all + items of iterable are returned.""" + iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] + sizes = [2, 3, 2] + expected = [[1, 2], [3, 4, 5], [6, 7]] + actual = list(mi.split_into(iterable, sizes)) + self.assertEqual(actual, expected) + + def test_using_none_with_leftover(self): + """Last item of ``sizes`` is None when items still remain in + ``iterable``. Last list returned stretches to fit all remaining items + of ``iterable``.""" + iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] + sizes = [2, 3, None] + expected = [[1, 2], [3, 4, 5], [6, 7, 8, 9]] + actual = list(mi.split_into(iterable, sizes)) + self.assertEqual(actual, expected) + + def test_using_none_without_leftover(self): + """Last item of ``sizes`` is None when no items remain in + ``iterable``. Last list returned is empty.""" + iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] + sizes = [2, 3, 4, None] + expected = [[1, 2], [3, 4, 5], [6, 7, 8, 9], []] + actual = list(mi.split_into(iterable, sizes)) + self.assertEqual(actual, expected) + + def test_using_none_mid_sizes(self): + """None is present in ``sizes`` but is not the last item. 
Last list + returned stretches to fit all remaining items of ``iterable`` but + all items in ``sizes`` after None are ignored.""" + iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] + sizes = [2, 3, None, 4] + expected = [[1, 2], [3, 4, 5], [6, 7, 8, 9]] + actual = list(mi.split_into(iterable, sizes)) + self.assertEqual(actual, expected) + + def test_iterable_empty(self): + """``iterable`` argument is empty but ``sizes`` is not. An empty + list is returned for each item in ``sizes``.""" + iterable = [] + sizes = [2, 4, 2] + expected = [[], [], []] + actual = list(mi.split_into(iterable, sizes)) + self.assertEqual(actual, expected) + + def test_iterable_empty_using_none(self): + """``iterable`` argument is empty but ``sizes`` is not. An empty + list is returned for each item in ``sizes`` that is not after a + None item.""" + iterable = [] + sizes = [2, 4, None, 2] + expected = [[], [], []] + actual = list(mi.split_into(iterable, sizes)) + self.assertEqual(actual, expected) + + def test_sizes_empty(self): + """``sizes`` argument is empty but ``iterable`` is not. An empty + generator is returned.""" + iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] + sizes = [] + expected = [] + actual = list(mi.split_into(iterable, sizes)) + self.assertEqual(actual, expected) + + def test_both_empty(self): + """Both ``sizes`` and ``iterable`` arguments are empty. An empty + generator is returned.""" + iterable = [] + sizes = [] + expected = [] + actual = list(mi.split_into(iterable, sizes)) + self.assertEqual(actual, expected) + + def test_bool_in_sizes(self): + """A bool object is present in ``sizes`` is treated as a 1 or 0 for + ``True`` or ``False`` due to bool being an instance of int.""" + iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] + sizes = [3, True, 2, False] + expected = [[1, 2, 3], [4], [5, 6], []] + actual = list(mi.split_into(iterable, sizes)) + self.assertEqual(actual, expected) + + def test_invalid_in_sizes(self): + """A ValueError is raised if an object in ``sizes`` is neither ``None`` + or an integer.""" + iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] + sizes = [1, [], 3] + with self.assertRaises(ValueError): + list(mi.split_into(iterable, sizes)) + + def test_invalid_in_sizes_after_none(self): + """A item in ``sizes`` that is invalid will not raise a TypeError if it + comes after a ``None`` item.""" + iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] + sizes = [3, 4, None, []] + expected = [[1, 2, 3], [4, 5, 6, 7], [8, 9]] + actual = list(mi.split_into(iterable, sizes)) + self.assertEqual(actual, expected) + + def test_generator_iterable_integrity(self): + """Check that if ``iterable`` is an iterator, it is consumed only by as + many items as the sum of ``sizes``.""" + iterable = (i for i in range(10)) + sizes = [2, 3] + + expected = [[0, 1], [2, 3, 4]] + actual = list(mi.split_into(iterable, sizes)) + self.assertEqual(actual, expected) + + iterable_expected = [5, 6, 7, 8, 9] + iterable_actual = list(iterable) + self.assertEqual(iterable_actual, iterable_expected) + + def test_generator_sizes_integrity(self): + """Check that if ``sizes`` is an iterator, it is consumed only until a + ``None`` item is reached""" + iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] + sizes = (i for i in [1, 2, None, 3, 4]) + + expected = [[1], [2, 3], [4, 5, 6, 7, 8, 9]] + actual = list(mi.split_into(iterable, sizes)) + self.assertEqual(actual, expected) + + sizes_expected = [3, 4] + sizes_actual = list(sizes) + self.assertEqual(sizes_actual, sizes_expected) + + +class PaddedTest(TestCase): + """Tests for ``padded()``""" + + def test_no_n(self): + seq = [1, 2, 3] + 
+ # No fillvalue + self.assertEqual(mi.take(5, mi.padded(seq)), [1, 2, 3, None, None]) + + # With fillvalue + self.assertEqual( + mi.take(5, mi.padded(seq, fillvalue='')), [1, 2, 3, '', ''] + ) + + def test_invalid_n(self): + self.assertRaises(ValueError, lambda: list(mi.padded([1, 2, 3], n=-1))) + self.assertRaises(ValueError, lambda: list(mi.padded([1, 2, 3], n=0))) + + def test_valid_n(self): + seq = [1, 2, 3, 4, 5] + + # No need for padding: len(seq) <= n + self.assertEqual(list(mi.padded(seq, n=4)), [1, 2, 3, 4, 5]) + self.assertEqual(list(mi.padded(seq, n=5)), [1, 2, 3, 4, 5]) + + # No fillvalue + self.assertEqual( + list(mi.padded(seq, n=7)), [1, 2, 3, 4, 5, None, None] + ) + + # With fillvalue + self.assertEqual( + list(mi.padded(seq, fillvalue='', n=7)), [1, 2, 3, 4, 5, '', ''] + ) + + def test_next_multiple(self): + seq = [1, 2, 3, 4, 5, 6] + + # No need for padding: len(seq) % n == 0 + self.assertEqual( + list(mi.padded(seq, n=3, next_multiple=True)), [1, 2, 3, 4, 5, 6] + ) + + # Padding needed: len(seq) < n + self.assertEqual( + list(mi.padded(seq, n=8, next_multiple=True)), + [1, 2, 3, 4, 5, 6, None, None] + ) + + # No padding needed: len(seq) == n + self.assertEqual( + list(mi.padded(seq, n=6, next_multiple=True)), [1, 2, 3, 4, 5, 6] + ) + + # Padding needed: len(seq) > n + self.assertEqual( + list(mi.padded(seq, n=4, next_multiple=True)), + [1, 2, 3, 4, 5, 6, None, None] + ) + + # With fillvalue + self.assertEqual( + list(mi.padded(seq, fillvalue='', n=4, next_multiple=True)), + [1, 2, 3, 4, 5, 6, '', ''] + ) + + +class DistributeTest(TestCase): + """Tests for distribute()""" + + def test_invalid_n(self): + self.assertRaises(ValueError, lambda: mi.distribute(-1, [1, 2, 3])) + self.assertRaises(ValueError, lambda: mi.distribute(0, [1, 2, 3])) + + def test_basic(self): + iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + for n, expected in [ + (1, [iterable]), + (2, [[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]), + (3, [[1, 4, 7, 10], [2, 5, 8], [3, 6, 9]]), + (10, [[n] for n in range(1, 10 + 1)]), + ]: + self.assertEqual( + [list(x) for x in mi.distribute(n, iterable)], expected + ) + + def test_large_n(self): + iterable = [1, 2, 3, 4] + self.assertEqual( + [list(x) for x in mi.distribute(6, iterable)], + [[1], [2], [3], [4], [], []] + ) + + +class StaggerTest(TestCase): + """Tests for ``stagger()``""" + + def test_default(self): + iterable = [0, 1, 2, 3] + actual = list(mi.stagger(iterable)) + expected = [(None, 0, 1), (0, 1, 2), (1, 2, 3)] + self.assertEqual(actual, expected) + + def test_offsets(self): + iterable = [0, 1, 2, 3] + for offsets, expected in [ + ((-2, 0, 2), [('', 0, 2), ('', 1, 3)]), + ((-2, -1), [('', ''), ('', 0), (0, 1), (1, 2), (2, 3)]), + ((1, 2), [(1, 2), (2, 3)]), + ]: + all_groups = mi.stagger(iterable, offsets=offsets, fillvalue='') + self.assertEqual(list(all_groups), expected) + + def test_longest(self): + iterable = [0, 1, 2, 3] + for offsets, expected in [ + ( + (-1, 0, 1), + [('', 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, ''), (3, '', '')] + ), + ((-2, -1), [('', ''), ('', 0), (0, 1), (1, 2), (2, 3), (3, '')]), + ((1, 2), [(1, 2), (2, 3), (3, '')]), + ]: + all_groups = mi.stagger( + iterable, offsets=offsets, fillvalue='', longest=True + ) + self.assertEqual(list(all_groups), expected) + + +class ZipOffsetTest(TestCase): + """Tests for ``zip_offset()``""" + + def test_shortest(self): + a_1 = [0, 1, 2, 3] + a_2 = [0, 1, 2, 3, 4, 5] + a_3 = [0, 1, 2, 3, 4, 5, 6, 7] + actual = list( + mi.zip_offset(a_1, a_2, a_3, offsets=(-1, 0, 1), fillvalue='') + ) + expected = 
[('', 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5)] + self.assertEqual(actual, expected) + + def test_longest(self): + a_1 = [0, 1, 2, 3] + a_2 = [0, 1, 2, 3, 4, 5] + a_3 = [0, 1, 2, 3, 4, 5, 6, 7] + actual = list( + mi.zip_offset(a_1, a_2, a_3, offsets=(-1, 0, 1), longest=True) + ) + expected = [ + (None, 0, 1), + (0, 1, 2), + (1, 2, 3), + (2, 3, 4), + (3, 4, 5), + (None, 5, 6), + (None, None, 7), + ] + self.assertEqual(actual, expected) + + def test_mismatch(self): + iterables = [0, 1, 2], [2, 3, 4] + offsets = (-1, 0, 1) + self.assertRaises( + ValueError, + lambda: list(mi.zip_offset(*iterables, offsets=offsets)) + ) + + +class UnzipTests(TestCase): + """Tests for unzip()""" + + def test_empty_iterable(self): + self.assertEqual(list(mi.unzip([])), []) + # in reality zip([], [], []) is equivalent to iter([]) + # but it doesn't hurt to test both + self.assertEqual(list(mi.unzip(zip([], [], []))), []) + + def test_length_one_iterable(self): + xs, ys, zs = mi.unzip(zip([1], [2], [3])) + self.assertEqual(list(xs), [1]) + self.assertEqual(list(ys), [2]) + self.assertEqual(list(zs), [3]) + + def test_normal_case(self): + xs, ys, zs = range(10), range(1, 11), range(2, 12) + zipped = zip(xs, ys, zs) + xs, ys, zs = mi.unzip(zipped) + self.assertEqual(list(xs), list(range(10))) + self.assertEqual(list(ys), list(range(1, 11))) + self.assertEqual(list(zs), list(range(2, 12))) + + def test_improperly_zipped(self): + zipped = iter([(1, 2, 3), (4, 5), (6,)]) + xs, ys, zs = mi.unzip(zipped) + self.assertEqual(list(xs), [1, 4, 6]) + self.assertEqual(list(ys), [2, 5]) + self.assertEqual(list(zs), [3]) + + def test_increasingly_zipped(self): + zipped = iter([(1, 2), (3, 4, 5), (6, 7, 8, 9)]) + unzipped = mi.unzip(zipped) + # from the docstring: + # len(first tuple) is the number of iterables zipped + self.assertEqual(len(unzipped), 2) + xs, ys = unzipped + self.assertEqual(list(xs), [1, 3, 6]) + self.assertEqual(list(ys), [2, 4, 7]) + + +class SortTogetherTest(TestCase): + """Tests for sort_together()""" + + def test_key_list(self): + """tests `key_list` including default, iterables include duplicates""" + iterables = [ + ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'], + ['May', 'Aug.', 'May', 'June', 'July', 'July'], + [97, 20, 100, 70, 100, 20] + ] + + self.assertEqual( + mi.sort_together(iterables), + [ + ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), + ('June', 'July', 'July', 'May', 'Aug.', 'May'), + (70, 100, 20, 97, 20, 100) + ] + ) + + self.assertEqual( + mi.sort_together(iterables, key_list=(0, 1)), + [ + ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), + ('July', 'July', 'June', 'Aug.', 'May', 'May'), + (100, 20, 70, 20, 97, 100) + ] + ) + + self.assertEqual( + mi.sort_together(iterables, key_list=(0, 1, 2)), + [ + ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), + ('July', 'July', 'June', 'Aug.', 'May', 'May'), + (20, 100, 70, 20, 97, 100) + ] + ) + + self.assertEqual( + mi.sort_together(iterables, key_list=(2,)), + [ + ('GA', 'CT', 'CT', 'GA', 'GA', 'CT'), + ('Aug.', 'July', 'June', 'May', 'May', 'July'), + (20, 20, 70, 97, 100, 100) + ] + ) + + def test_invalid_key_list(self): + """tests `key_list` for indexes not available in `iterables`""" + iterables = [ + ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'], + ['May', 'Aug.', 'May', 'June', 'July', 'July'], + [97, 20, 100, 70, 100, 20] + ] + + self.assertRaises( + IndexError, lambda: mi.sort_together(iterables, key_list=(5,)) + ) + + def test_reverse(self): + """tests `reverse` to ensure a reverse sort for `key_list` iterables""" + iterables = [ + ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'], + 
['May', 'Aug.', 'May', 'June', 'July', 'July'], + [97, 20, 100, 70, 100, 20] + ] + + self.assertEqual( + mi.sort_together(iterables, key_list=(0, 1, 2), reverse=True), + [('GA', 'GA', 'GA', 'CT', 'CT', 'CT'), + ('May', 'May', 'Aug.', 'June', 'July', 'July'), + (100, 97, 20, 70, 100, 20)] + ) + + def test_uneven_iterables(self): + """tests trimming of iterables to the shortest length before sorting""" + iterables = [['GA', 'GA', 'GA', 'CT', 'CT', 'CT', 'MA'], + ['May', 'Aug.', 'May', 'June', 'July', 'July'], + [97, 20, 100, 70, 100, 20, 0]] + + self.assertEqual( + mi.sort_together(iterables), + [ + ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), + ('June', 'July', 'July', 'May', 'Aug.', 'May'), + (70, 100, 20, 97, 20, 100) + ] + ) + + +class DivideTest(TestCase): + """Tests for divide()""" + + def test_invalid_n(self): + self.assertRaises(ValueError, lambda: mi.divide(-1, [1, 2, 3])) + self.assertRaises(ValueError, lambda: mi.divide(0, [1, 2, 3])) + + def test_basic(self): + iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + for n, expected in [ + (1, [iterable]), + (2, [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), + (3, [[1, 2, 3, 4], [5, 6, 7], [8, 9, 10]]), + (10, [[n] for n in range(1, 10 + 1)]), + ]: + self.assertEqual( + [list(x) for x in mi.divide(n, iterable)], expected + ) + + def test_large_n(self): + iterable = [1, 2, 3, 4] + self.assertEqual( + [list(x) for x in mi.divide(6, iterable)], + [[1], [2], [3], [4], [], []] + ) + + +class TestAlwaysIterable(TestCase): + """Tests for always_iterable()""" + def test_single(self): + self.assertEqual(list(mi.always_iterable(1)), [1]) + + def test_strings(self): + for obj in ['foo', b'bar', 'baz']: + actual = list(mi.always_iterable(obj)) + expected = [obj] + self.assertEqual(actual, expected) + + def test_base_type(self): + dict_obj = {'a': 1, 'b': 2} + str_obj = '123' + + # Default: dicts are iterable like they normally are + default_actual = list(mi.always_iterable(dict_obj)) + default_expected = list(dict_obj) + self.assertEqual(default_actual, default_expected) + + # Unitary types set: dicts are not iterable + custom_actual = list(mi.always_iterable(dict_obj, base_type=dict)) + custom_expected = [dict_obj] + self.assertEqual(custom_actual, custom_expected) + + # With unitary types set, strings are iterable + str_actual = list(mi.always_iterable(str_obj, base_type=None)) + str_expected = list(str_obj) + self.assertEqual(str_actual, str_expected) + + def test_iterables(self): + self.assertEqual(list(mi.always_iterable([0, 1])), [0, 1]) + self.assertEqual( + list(mi.always_iterable([0, 1], base_type=list)), [[0, 1]] + ) + self.assertEqual( + list(mi.always_iterable(iter('foo'))), ['f', 'o', 'o'] + ) + self.assertEqual(list(mi.always_iterable([])), []) + + def test_none(self): + self.assertEqual(list(mi.always_iterable(None)), []) + + def test_generator(self): + def _gen(): + yield 0 + yield 1 + + self.assertEqual(list(mi.always_iterable(_gen())), [0, 1]) + + +class AdjacentTests(TestCase): + def test_typical(self): + actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10))) + expected = [(True, 0), (True, 1), (False, 2), (False, 3), (True, 4), + (True, 5), (True, 6), (False, 7), (False, 8), (False, 9)] + self.assertEqual(actual, expected) + + def test_empty_iterable(self): + actual = list(mi.adjacent(lambda x: x % 5 == 0, [])) + expected = [] + self.assertEqual(actual, expected) + + def test_length_one(self): + actual = list(mi.adjacent(lambda x: x % 5 == 0, [0])) + expected = [(True, 0)] + self.assertEqual(actual, expected) + + actual = 
list(mi.adjacent(lambda x: x % 5 == 0, [1])) + expected = [(False, 1)] + self.assertEqual(actual, expected) + + def test_consecutive_true(self): + """Test that when the predicate matches multiple consecutive elements + it doesn't repeat elements in the output""" + actual = list(mi.adjacent(lambda x: x % 5 < 2, range(10))) + expected = [(True, 0), (True, 1), (True, 2), (False, 3), (True, 4), + (True, 5), (True, 6), (True, 7), (False, 8), (False, 9)] + self.assertEqual(actual, expected) + + def test_distance(self): + actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10), distance=2)) + expected = [(True, 0), (True, 1), (True, 2), (True, 3), (True, 4), + (True, 5), (True, 6), (True, 7), (False, 8), (False, 9)] + self.assertEqual(actual, expected) + + actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10), distance=3)) + expected = [(True, 0), (True, 1), (True, 2), (True, 3), (True, 4), + (True, 5), (True, 6), (True, 7), (True, 8), (False, 9)] + self.assertEqual(actual, expected) + + def test_large_distance(self): + """Test distance larger than the length of the iterable""" + iterable = range(10) + actual = list(mi.adjacent(lambda x: x % 5 == 4, iterable, distance=20)) + expected = list(zip(repeat(True), iterable)) + self.assertEqual(actual, expected) + + actual = list(mi.adjacent(lambda x: False, iterable, distance=20)) + expected = list(zip(repeat(False), iterable)) + self.assertEqual(actual, expected) + + def test_zero_distance(self): + """Test that adjacent() reduces to zip+map when distance is 0""" + iterable = range(1000) + predicate = lambda x: x % 4 == 2 + actual = mi.adjacent(predicate, iterable, 0) + expected = zip(map(predicate, iterable), iterable) + self.assertTrue(all(a == e for a, e in zip(actual, expected))) + + def test_negative_distance(self): + """Test that adjacent() raises an error with negative distance""" + pred = lambda x: x + self.assertRaises( + ValueError, lambda: mi.adjacent(pred, range(1000), -1) + ) + self.assertRaises( + ValueError, lambda: mi.adjacent(pred, range(10), -10) + ) + + def test_grouping(self): + """Test interaction of adjacent() with groupby_transform()""" + iterable = mi.adjacent(lambda x: x % 5 == 0, range(10)) + grouper = mi.groupby_transform(iterable, itemgetter(0), itemgetter(1)) + actual = [(k, list(g)) for k, g in grouper] + expected = [ + (True, [0, 1]), + (False, [2, 3]), + (True, [4, 5, 6]), + (False, [7, 8, 9]), + ] + self.assertEqual(actual, expected) + + def test_call_once(self): + """Test that the predicate is only called once per item.""" + already_seen = set() + iterable = range(10) + + def predicate(item): + self.assertNotIn(item, already_seen) + already_seen.add(item) + return True + + actual = list(mi.adjacent(predicate, iterable)) + expected = [(True, x) for x in iterable] + self.assertEqual(actual, expected) + + +class GroupByTransformTests(TestCase): + def assertAllGroupsEqual(self, groupby1, groupby2): + """Compare two groupby objects for equality, both keys and groups.""" + for a, b in zip(groupby1, groupby2): + key1, group1 = a + key2, group2 = b + self.assertEqual(key1, key2) + self.assertListEqual(list(group1), list(group2)) + self.assertRaises(StopIteration, lambda: next(groupby1)) + self.assertRaises(StopIteration, lambda: next(groupby2)) + + def test_default_funcs(self): + """Test that groupby_transform() with default args mimics groupby()""" + iterable = [(x // 5, x) for x in range(1000)] + actual = mi.groupby_transform(iterable) + expected = groupby(iterable) + self.assertAllGroupsEqual(actual, expected) + + def 
test_valuefunc(self): + iterable = [(int(x / 5), int(x / 3), x) for x in range(10)] + + # Test the standard usage of grouping one iterable using another's keys + grouper = mi.groupby_transform( + iterable, keyfunc=itemgetter(0), valuefunc=itemgetter(-1) + ) + actual = [(k, list(g)) for k, g in grouper] + expected = [(0, [0, 1, 2, 3, 4]), (1, [5, 6, 7, 8, 9])] + self.assertEqual(actual, expected) + + grouper = mi.groupby_transform( + iterable, keyfunc=itemgetter(1), valuefunc=itemgetter(-1) + ) + actual = [(k, list(g)) for k, g in grouper] + expected = [(0, [0, 1, 2]), (1, [3, 4, 5]), (2, [6, 7, 8]), (3, [9])] + self.assertEqual(actual, expected) + + # and now for something a little different + d = dict(zip(range(10), 'abcdefghij')) + grouper = mi.groupby_transform( + range(10), keyfunc=lambda x: x // 5, valuefunc=d.get + ) + actual = [(k, ''.join(g)) for k, g in grouper] + expected = [(0, 'abcde'), (1, 'fghij')] + self.assertEqual(actual, expected) + + def test_no_valuefunc(self): + iterable = range(1000) + + def key(x): + return x // 5 + + actual = mi.groupby_transform(iterable, key, valuefunc=None) + expected = groupby(iterable, key) + self.assertAllGroupsEqual(actual, expected) + + actual = mi.groupby_transform(iterable, key) # default valuefunc + expected = groupby(iterable, key) + self.assertAllGroupsEqual(actual, expected) + + +class NumericRangeTests(TestCase): + def test_basic(self): + for args, expected in [ + ((4,), [0, 1, 2, 3]), + ((4.0,), [0.0, 1.0, 2.0, 3.0]), + ((1.0, 4), [1.0, 2.0, 3.0]), + ((1, 4.0), [1, 2, 3]), + ((1.0, 5), [1.0, 2.0, 3.0, 4.0]), + ((0, 20, 5), [0, 5, 10, 15]), + ((0, 20, 5.0), [0.0, 5.0, 10.0, 15.0]), + ((0, 10, 3), [0, 3, 6, 9]), + ((0, 10, 3.0), [0.0, 3.0, 6.0, 9.0]), + ((0, -5, -1), [0, -1, -2, -3, -4]), + ((0.0, -5, -1), [0.0, -1.0, -2.0, -3.0, -4.0]), + ((1, 2, Fraction(1, 2)), [Fraction(1, 1), Fraction(3, 2)]), + ((0,), []), + ((0.0,), []), + ((1, 0), []), + ((1.0, 0.0), []), + ((Fraction(2, 1),), [Fraction(0, 1), Fraction(1, 1)]), + ((Decimal('2.0'),), [Decimal('0.0'), Decimal('1.0')]), + ]: + actual = list(mi.numeric_range(*args)) + self.assertEqual(actual, expected) + self.assertTrue( + all(type(a) == type(e) for a, e in zip(actual, expected)) + ) + + def test_arg_count(self): + self.assertRaises(TypeError, lambda: list(mi.numeric_range())) + self.assertRaises( + TypeError, lambda: list(mi.numeric_range(0, 1, 2, 3)) + ) + + def test_zero_step(self): + self.assertRaises( + ValueError, lambda: list(mi.numeric_range(1, 2, 0)) + ) + + +class CountCycleTests(TestCase): + def test_basic(self): + expected = [ + (0, 'a'), (0, 'b'), (0, 'c'), + (1, 'a'), (1, 'b'), (1, 'c'), + (2, 'a'), (2, 'b'), (2, 'c'), + ] + for actual in [ + mi.take(9, mi.count_cycle('abc')), # n=None + list(mi.count_cycle('abc', 3)), # n=3 + ]: + self.assertEqual(actual, expected) + + def test_empty(self): + self.assertEqual(list(mi.count_cycle('')), []) + self.assertEqual(list(mi.count_cycle('', 2)), []) + + def test_negative(self): + self.assertEqual(list(mi.count_cycle('abc', -3)), []) + + +class LocateTests(TestCase): + def test_default_pred(self): + iterable = [0, 1, 1, 0, 1, 0, 0] + actual = list(mi.locate(iterable)) + expected = [1, 2, 4] + self.assertEqual(actual, expected) + + def test_no_matches(self): + iterable = [0, 0, 0] + actual = list(mi.locate(iterable)) + expected = [] + self.assertEqual(actual, expected) + + def test_custom_pred(self): + iterable = ['0', 1, 1, '0', 1, '0', '0'] + pred = lambda x: x == '0' + actual = list(mi.locate(iterable, pred)) + expected = [0, 
3, 5, 6] + self.assertEqual(actual, expected) + + def test_window_size(self): + iterable = ['0', 1, 1, '0', 1, '0', '0'] + pred = lambda *args: args == ('0', 1) + actual = list(mi.locate(iterable, pred, window_size=2)) + expected = [0, 3] + self.assertEqual(actual, expected) + + def test_window_size_large(self): + iterable = [1, 2, 3, 4] + pred = lambda a, b, c, d, e: True + actual = list(mi.locate(iterable, pred, window_size=5)) + expected = [0] + self.assertEqual(actual, expected) + + def test_window_size_zero(self): + iterable = [1, 2, 3, 4] + pred = lambda: True + with self.assertRaises(ValueError): + list(mi.locate(iterable, pred, window_size=0)) + + +class StripFunctionTests(TestCase): + def test_hashable(self): + iterable = list('www.example.com') + pred = lambda x: x in set('cmowz.') + + self.assertEqual(list(mi.lstrip(iterable, pred)), list('example.com')) + self.assertEqual(list(mi.rstrip(iterable, pred)), list('www.example')) + self.assertEqual(list(mi.strip(iterable, pred)), list('example')) + + def test_not_hashable(self): + iterable = [ + list('http://'), list('www'), list('.example'), list('.com') + ] + pred = lambda x: x in [list('http://'), list('www'), list('.com')] + + self.assertEqual(list(mi.lstrip(iterable, pred)), iterable[2:]) + self.assertEqual(list(mi.rstrip(iterable, pred)), iterable[:3]) + self.assertEqual(list(mi.strip(iterable, pred)), iterable[2: 3]) + + def test_math(self): + iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] + pred = lambda x: x <= 2 + + self.assertEqual(list(mi.lstrip(iterable, pred)), iterable[3:]) + self.assertEqual(list(mi.rstrip(iterable, pred)), iterable[:-3]) + self.assertEqual(list(mi.strip(iterable, pred)), iterable[3:-3]) + + +class IsliceExtendedTests(TestCase): + def test_all(self): + iterable = ['0', '1', '2', '3', '4', '5'] + indexes = list(range(-4, len(iterable) + 4)) + [None] + steps = [1, 2, 3, 4, -1, -2, -3, 4] + for slice_args in product(indexes, indexes, steps): + try: + actual = list(mi.islice_extended(iterable, *slice_args)) + except Exception as e: + self.fail((slice_args, e)) + + expected = iterable[slice(*slice_args)] + self.assertEqual(actual, expected, slice_args) + + def test_zero_step(self): + with self.assertRaises(ValueError): + list(mi.islice_extended([1, 2, 3], 0, 1, 0)) + + +class ConsecutiveGroupsTest(TestCase): + def test_numbers(self): + iterable = [-10, -8, -7, -6, 1, 2, 4, 5, -1, 7] + actual = [list(g) for g in mi.consecutive_groups(iterable)] + expected = [[-10], [-8, -7, -6], [1, 2], [4, 5], [-1], [7]] + self.assertEqual(actual, expected) + + def test_custom_ordering(self): + iterable = ['1', '10', '11', '20', '21', '22', '30', '31'] + ordering = lambda x: int(x) + actual = [list(g) for g in mi.consecutive_groups(iterable, ordering)] + expected = [['1'], ['10', '11'], ['20', '21', '22'], ['30', '31']] + self.assertEqual(actual, expected) + + def test_exotic_ordering(self): + iterable = [ + ('a', 'b', 'c', 'd'), + ('a', 'c', 'b', 'd'), + ('a', 'c', 'd', 'b'), + ('a', 'd', 'b', 'c'), + ('d', 'b', 'c', 'a'), + ('d', 'c', 'a', 'b'), + ] + ordering = list(permutations('abcd')).index + actual = [list(g) for g in mi.consecutive_groups(iterable, ordering)] + expected = [ + [('a', 'b', 'c', 'd')], + [('a', 'c', 'b', 'd'), ('a', 'c', 'd', 'b'), ('a', 'd', 'b', 'c')], + [('d', 'b', 'c', 'a'), ('d', 'c', 'a', 'b')], + ] + self.assertEqual(actual, expected) + + +class DifferenceTest(TestCase): + def test_normal(self): + iterable = [10, 20, 30, 40, 50] + actual = list(mi.difference(iterable)) + expected = [10, 10, 10, 
10, 10] + self.assertEqual(actual, expected) + + def test_custom(self): + iterable = [10, 20, 30, 40, 50] + actual = list(mi.difference(iterable, add)) + expected = [10, 30, 50, 70, 90] + self.assertEqual(actual, expected) + + def test_roundtrip(self): + original = list(range(100)) + accumulated = mi.accumulate(original) + actual = list(mi.difference(accumulated)) + self.assertEqual(actual, original) + + def test_one(self): + self.assertEqual(list(mi.difference([0])), [0]) + + def test_empty(self): + self.assertEqual(list(mi.difference([])), []) + + +class SeekableTest(TestCase): + def test_exhaustion_reset(self): + iterable = [str(n) for n in range(10)] + + s = mi.seekable(iterable) + self.assertEqual(list(s), iterable) # Normal iteration + self.assertEqual(list(s), []) # Iterable is exhausted + + s.seek(0) + self.assertEqual(list(s), iterable) # Back in action + + def test_partial_reset(self): + iterable = [str(n) for n in range(10)] + + s = mi.seekable(iterable) + self.assertEqual(mi.take(5, s), iterable[:5]) # Normal iteration + + s.seek(1) + self.assertEqual(list(s), iterable[1:]) # Get the rest of the iterable + + def test_forward(self): + iterable = [str(n) for n in range(10)] + + s = mi.seekable(iterable) + self.assertEqual(mi.take(1, s), iterable[:1]) # Normal iteration + + s.seek(3) # Skip over index 2 + self.assertEqual(list(s), iterable[3:]) # Result is similar to slicing + + s.seek(0) # Back to 0 + self.assertEqual(list(s), iterable) # No difference in result + + def test_past_end(self): + iterable = [str(n) for n in range(10)] + + s = mi.seekable(iterable) + self.assertEqual(mi.take(1, s), iterable[:1]) # Normal iteration + + s.seek(20) + self.assertEqual(list(s), []) # Iterable is exhausted + + s.seek(0) # Back to 0 + self.assertEqual(list(s), iterable) # No difference in result + + def test_elements(self): + iterable = map(str, count()) + + s = mi.seekable(iterable) + mi.take(10, s) + + elements = s.elements() + self.assertEqual( + [elements[i] for i in range(10)], [str(n) for n in range(10)] + ) + self.assertEqual(len(elements), 10) + + mi.take(10, s) + self.assertEqual(list(elements), [str(n) for n in range(20)]) + + +class SequenceViewTests(TestCase): + def test_init(self): + view = mi.SequenceView((1, 2, 3)) + self.assertEqual(repr(view), "SequenceView((1, 2, 3))") + self.assertRaises(TypeError, lambda: mi.SequenceView({})) + + def test_update(self): + seq = [1, 2, 3] + view = mi.SequenceView(seq) + self.assertEqual(len(view), 3) + self.assertEqual(repr(view), "SequenceView([1, 2, 3])") + + seq.pop() + self.assertEqual(len(view), 2) + self.assertEqual(repr(view), "SequenceView([1, 2])") + + def test_indexing(self): + seq = ('a', 'b', 'c', 'd', 'e', 'f') + view = mi.SequenceView(seq) + for i in range(-len(seq), len(seq)): + self.assertEqual(view[i], seq[i]) + + def test_slicing(self): + seq = ('a', 'b', 'c', 'd', 'e', 'f') + view = mi.SequenceView(seq) + n = len(seq) + indexes = list(range(-n - 1, n + 1)) + [None] + steps = list(range(-n, n + 1)) + steps.remove(0) + for slice_args in product(indexes, indexes, steps): + i = slice(*slice_args) + self.assertEqual(view[i], seq[i]) + + def test_abc_methods(self): + # collections.Sequence should provide all of this functionality + seq = ('a', 'b', 'c', 'd', 'e', 'f', 'f') + view = mi.SequenceView(seq) + + # __contains__ + self.assertIn('b', view) + self.assertNotIn('g', view) + + # __iter__ + self.assertEqual(list(iter(view)), list(seq)) + + # __reversed__ + self.assertEqual(list(reversed(view)), list(reversed(seq))) + + # 
index
+        self.assertEqual(view.index('b'), 1)
+
+        # count
+        self.assertEqual(view.count('f'), 2)
+
+
+class RunLengthTest(TestCase):
+    def test_encode(self):
+        iterable = (int(str(n)[0]) for n in count(800))
+        actual = mi.take(4, mi.run_length.encode(iterable))
+        expected = [(8, 100), (9, 100), (1, 1000), (2, 1000)]
+        self.assertEqual(actual, expected)
+
+    def test_decode(self):
+        iterable = [('d', 4), ('c', 3), ('b', 2), ('a', 1)]
+        actual = ''.join(mi.run_length.decode(iterable))
+        expected = 'ddddcccbba'
+        self.assertEqual(actual, expected)
+
+
+class ExactlyNTests(TestCase):
+    """Tests for ``exactly_n()``"""
+
+    def test_true(self):
+        """Iterable has ``n`` ``True`` elements"""
+        self.assertTrue(mi.exactly_n([True, False, True], 2))
+        self.assertTrue(mi.exactly_n([1, 1, 1, 0], 3))
+        self.assertTrue(mi.exactly_n([False, False], 0))
+        self.assertTrue(mi.exactly_n(range(100), 10, lambda x: x < 10))
+
+    def test_false(self):
+        """Iterable does not have ``n`` ``True`` elements"""
+        self.assertFalse(mi.exactly_n([True, False, False], 2))
+        self.assertFalse(mi.exactly_n([True, True, False], 1))
+        self.assertFalse(mi.exactly_n([False], 1))
+        self.assertFalse(mi.exactly_n([True], -1))
+        self.assertFalse(mi.exactly_n(repeat(True), 100))
+
+    def test_empty(self):
+        """Return ``True`` if the iterable is empty and ``n`` is 0"""
+        self.assertTrue(mi.exactly_n([], 0))
+        self.assertFalse(mi.exactly_n([], 1))
+
+
+class AlwaysReversibleTests(TestCase):
+    """Tests for ``always_reversible()``"""
+
+    def test_regular_reversed(self):
+        self.assertEqual(list(reversed(range(10))),
+                         list(mi.always_reversible(range(10))))
+        self.assertEqual(list(reversed([1, 2, 3])),
+                         list(mi.always_reversible([1, 2, 3])))
+        self.assertEqual(reversed([1, 2, 3]).__class__,
+                         mi.always_reversible([1, 2, 3]).__class__)
+
+    def test_nonseq_reversed(self):
+        # Create a non-reversible generator from a sequence
+        with self.assertRaises(TypeError):
+            reversed(x for x in range(10))
+
+        self.assertEqual(list(reversed(range(10))),
+                         list(mi.always_reversible(x for x in range(10))))
+        self.assertEqual(list(reversed([1, 2, 3])),
+                         list(mi.always_reversible(x for x in [1, 2, 3])))
+        self.assertNotEqual(reversed((1, 2)).__class__,
+                            mi.always_reversible(x for x in (1, 2)).__class__)
+
+
+class CircularShiftsTests(TestCase):
+    def test_empty(self):
+        # empty iterable -> empty list
+        self.assertEqual(list(mi.circular_shifts([])), [])
+
+    def test_simple_circular_shifts(self):
+        # test a simple iterator case
+        self.assertEqual(
+            mi.circular_shifts(range(4)),
+            [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]
+        )
+
+    def test_duplicates(self):
+        # test non-distinct entries
+        self.assertEqual(
+            mi.circular_shifts([0, 1, 0, 1]),
+            [(0, 1, 0, 1), (1, 0, 1, 0), (0, 1, 0, 1), (1, 0, 1, 0)]
+        )
+
+
+class MakeDecoratorTests(TestCase):
+    def test_basic(self):
+        slicer = mi.make_decorator(islice)
+
+        @slicer(1, 10, 2)
+        def user_function(arg_1, arg_2, kwarg_1=None):
+            self.assertEqual(arg_1, 'arg_1')
+            self.assertEqual(arg_2, 'arg_2')
+            self.assertEqual(kwarg_1, 'kwarg_1')
+            return map(str, count())
+
+        it = user_function('arg_1', 'arg_2', kwarg_1='kwarg_1')
+        actual = list(it)
+        expected = ['1', '3', '5', '7', '9']
+        self.assertEqual(actual, expected)
+
+    def test_result_index(self):
+        def stringify(*args, **kwargs):
+            self.assertEqual(args[0], 'arg_0')
+            iterable = args[1]
+            self.assertEqual(args[2], 'arg_2')
+            self.assertEqual(kwargs['kwarg_1'], 'kwarg_1')
+            return map(str, iterable)
+
+        stringifier = mi.make_decorator(stringify, result_index=1)
+
+        @stringifier('arg_0', 'arg_2', kwarg_1='kwarg_1')
+        def user_function(n):
+            return count(n)
+
+        it = user_function(1)
+        actual = mi.take(5, it)
+        expected = ['1', '2', '3', '4', '5']
+        self.assertEqual(actual, expected)
+
+    def test_wrap_class(self):
+        seeker = mi.make_decorator(mi.seekable)
+
+        @seeker()
+        def user_function(n):
+            return map(str, range(n))
+
+        it = user_function(5)
+        self.assertEqual(list(it), ['0', '1', '2', '3', '4'])
+
+        it.seek(0)
+        self.assertEqual(list(it), ['0', '1', '2', '3', '4'])
+
+
+class MapReduceTests(TestCase):
+    def test_default(self):
+        iterable = (str(x) for x in range(5))
+        keyfunc = lambda x: int(x) // 2
+        actual = sorted(mi.map_reduce(iterable, keyfunc).items())
+        expected = [(0, ['0', '1']), (1, ['2', '3']), (2, ['4'])]
+        self.assertEqual(actual, expected)
+
+    def test_valuefunc(self):
+        iterable = (str(x) for x in range(5))
+        keyfunc = lambda x: int(x) // 2
+        valuefunc = int
+        actual = sorted(mi.map_reduce(iterable, keyfunc, valuefunc).items())
+        expected = [(0, [0, 1]), (1, [2, 3]), (2, [4])]
+        self.assertEqual(actual, expected)
+
+    def test_reducefunc(self):
+        iterable = (str(x) for x in range(5))
+        keyfunc = lambda x: int(x) // 2
+        valuefunc = int
+        reducefunc = lambda value_list: reduce(mul, value_list, 1)
+        actual = sorted(
+            mi.map_reduce(iterable, keyfunc, valuefunc, reducefunc).items()
+        )
+        expected = [(0, 0), (1, 6), (2, 4)]
+        self.assertEqual(actual, expected)
+
+    def test_ret(self):
+        d = mi.map_reduce([1, 0, 2, 0, 1, 0], bool)
+        self.assertEqual(d, {False: [0, 0, 0], True: [1, 2, 1]})
+        self.assertRaises(KeyError, lambda: d[None].append(1))
+
+
+class RlocateTests(TestCase):
+    def test_default_pred(self):
+        iterable = [0, 1, 1, 0, 1, 0, 0]
+        for it in (iterable[:], iter(iterable)):
+            actual = list(mi.rlocate(it))
+            expected = [4, 2, 1]
+            self.assertEqual(actual, expected)
+
+    def test_no_matches(self):
+        iterable = [0, 0, 0]
+        for it in (iterable[:], iter(iterable)):
+            actual = list(mi.rlocate(it))
+            expected = []
+            self.assertEqual(actual, expected)
+
+    def test_custom_pred(self):
+        iterable = ['0', 1, 1, '0', 1, '0', '0']
+        pred = lambda x: x == '0'
+        for it in (iterable[:], iter(iterable)):
+            actual = list(mi.rlocate(it, pred))
+            expected = [6, 5, 3, 0]
+            self.assertEqual(actual, expected)
+
+    def test_efficient_reversal(self):
+        iterable = range(9 ** 9)  # Is efficiently reversible
+        target = 9 ** 9 - 2
+        pred = lambda x: x == target  # Find-able from the right
+        actual = next(mi.rlocate(iterable, pred))
+        self.assertEqual(actual, target)
+
+    def test_window_size(self):
+        iterable = ['0', 1, 1, '0', 1, '0', '0']
+        pred = lambda *args: args == ('0', 1)
+        for it in (iterable, iter(iterable)):
+            actual = list(mi.rlocate(it, pred, window_size=2))
+            expected = [3, 0]
+            self.assertEqual(actual, expected)
+
+    def test_window_size_large(self):
+        iterable = [1, 2, 3, 4]
+        pred = lambda a, b, c, d, e: True
+        for it in (iterable, iter(iterable)):
+            actual = list(mi.rlocate(it, pred, window_size=5))
+            expected = [0]
+            self.assertEqual(actual, expected)
+
+    def test_window_size_zero(self):
+        iterable = [1, 2, 3, 4]
+        pred = lambda: True
+        for it in (iterable, iter(iterable)):
+            with self.assertRaises(ValueError):
+                list(mi.rlocate(it, pred, window_size=0))
+
+
+class ReplaceTests(TestCase):
+    def test_basic(self):
+        iterable = range(10)
+        pred = lambda x: x % 2 == 0
+        substitutes = []
+        actual = list(mi.replace(iterable, pred, substitutes))
+        expected = [1, 3, 5, 7, 9]
+        self.assertEqual(actual, expected)
+
+    def
test_count(self): + iterable = range(10) + pred = lambda x: x % 2 == 0 + substitutes = [] + actual = list(mi.replace(iterable, pred, substitutes, count=4)) + expected = [1, 3, 5, 7, 8, 9] + self.assertEqual(actual, expected) + + def test_window_size(self): + iterable = range(10) + pred = lambda *args: args == (0, 1, 2) + substitutes = [] + actual = list(mi.replace(iterable, pred, substitutes, window_size=3)) + expected = [3, 4, 5, 6, 7, 8, 9] + self.assertEqual(actual, expected) + + def test_window_size_end(self): + iterable = range(10) + pred = lambda *args: args == (7, 8, 9) + substitutes = [] + actual = list(mi.replace(iterable, pred, substitutes, window_size=3)) + expected = [0, 1, 2, 3, 4, 5, 6] + self.assertEqual(actual, expected) + + def test_window_size_count(self): + iterable = range(10) + pred = lambda *args: (args == (0, 1, 2)) or (args == (7, 8, 9)) + substitutes = [] + actual = list( + mi.replace(iterable, pred, substitutes, count=1, window_size=3) + ) + expected = [3, 4, 5, 6, 7, 8, 9] + self.assertEqual(actual, expected) + + def test_window_size_large(self): + iterable = range(4) + pred = lambda a, b, c, d, e: True + substitutes = [5, 6, 7] + actual = list(mi.replace(iterable, pred, substitutes, window_size=5)) + expected = [5, 6, 7] + self.assertEqual(actual, expected) + + def test_window_size_zero(self): + iterable = range(10) + pred = lambda *args: True + substitutes = [] + with self.assertRaises(ValueError): + list(mi.replace(iterable, pred, substitutes, window_size=0)) + + def test_iterable_substitutes(self): + iterable = range(5) + pred = lambda x: x % 2 == 0 + substitutes = iter('__') + actual = list(mi.replace(iterable, pred, substitutes)) + expected = ['_', '_', 1, '_', '_', 3, '_', '_'] + self.assertEqual(actual, expected) diff --git a/libs/common/more_itertools/tests/test_recipes.py b/libs/common/more_itertools/tests/test_recipes.py new file mode 100644 index 00000000..b3cfb62f --- /dev/null +++ b/libs/common/more_itertools/tests/test_recipes.py @@ -0,0 +1,616 @@ +from doctest import DocTestSuite +from unittest import TestCase + +from itertools import combinations +from six.moves import range + +import more_itertools as mi + + +def load_tests(loader, tests, ignore): + # Add the doctests + tests.addTests(DocTestSuite('more_itertools.recipes')) + return tests + + +class AccumulateTests(TestCase): + """Tests for ``accumulate()``""" + + def test_empty(self): + """Test that an empty input returns an empty output""" + self.assertEqual(list(mi.accumulate([])), []) + + def test_default(self): + """Test accumulate with the default function (addition)""" + self.assertEqual(list(mi.accumulate([1, 2, 3])), [1, 3, 6]) + + def test_bogus_function(self): + """Test accumulate with an invalid function""" + with self.assertRaises(TypeError): + list(mi.accumulate([1, 2, 3], func=lambda x: x)) + + def test_custom_function(self): + """Test accumulate with a custom function""" + self.assertEqual( + list(mi.accumulate((1, 2, 3, 2, 1), func=max)), [1, 2, 3, 3, 3] + ) + + +class TakeTests(TestCase): + """Tests for ``take()``""" + + def test_simple_take(self): + """Test basic usage""" + t = mi.take(5, range(10)) + self.assertEqual(t, [0, 1, 2, 3, 4]) + + def test_null_take(self): + """Check the null case""" + t = mi.take(0, range(10)) + self.assertEqual(t, []) + + def test_negative_take(self): + """Make sure taking negative items results in a ValueError""" + self.assertRaises(ValueError, lambda: mi.take(-3, range(10))) + + def test_take_too_much(self): + """Taking more than an iterator has 
remaining should return what the + iterator has remaining. + + """ + t = mi.take(10, range(5)) + self.assertEqual(t, [0, 1, 2, 3, 4]) + + +class TabulateTests(TestCase): + """Tests for ``tabulate()``""" + + def test_simple_tabulate(self): + """Test the happy path""" + t = mi.tabulate(lambda x: x) + f = tuple([next(t) for _ in range(3)]) + self.assertEqual(f, (0, 1, 2)) + + def test_count(self): + """Ensure tabulate accepts specific count""" + t = mi.tabulate(lambda x: 2 * x, -1) + f = (next(t), next(t), next(t)) + self.assertEqual(f, (-2, 0, 2)) + + +class TailTests(TestCase): + """Tests for ``tail()``""" + + def test_greater(self): + """Length of iterable is greater than requested tail""" + self.assertEqual(list(mi.tail(3, 'ABCDEFG')), ['E', 'F', 'G']) + + def test_equal(self): + """Length of iterable is equal to the requested tail""" + self.assertEqual( + list(mi.tail(7, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G'] + ) + + def test_less(self): + """Length of iterable is less than requested tail""" + self.assertEqual( + list(mi.tail(8, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G'] + ) + + +class ConsumeTests(TestCase): + """Tests for ``consume()``""" + + def test_sanity(self): + """Test basic functionality""" + r = (x for x in range(10)) + mi.consume(r, 3) + self.assertEqual(3, next(r)) + + def test_null_consume(self): + """Check the null case""" + r = (x for x in range(10)) + mi.consume(r, 0) + self.assertEqual(0, next(r)) + + def test_negative_consume(self): + """Check that negative consumsion throws an error""" + r = (x for x in range(10)) + self.assertRaises(ValueError, lambda: mi.consume(r, -1)) + + def test_total_consume(self): + """Check that iterator is totally consumed by default""" + r = (x for x in range(10)) + mi.consume(r) + self.assertRaises(StopIteration, lambda: next(r)) + + +class NthTests(TestCase): + """Tests for ``nth()``""" + + def test_basic(self): + """Make sure the nth item is returned""" + l = range(10) + for i, v in enumerate(l): + self.assertEqual(mi.nth(l, i), v) + + def test_default(self): + """Ensure a default value is returned when nth item not found""" + l = range(3) + self.assertEqual(mi.nth(l, 100, "zebra"), "zebra") + + def test_negative_item_raises(self): + """Ensure asking for a negative item raises an exception""" + self.assertRaises(ValueError, lambda: mi.nth(range(10), -3)) + + +class AllEqualTests(TestCase): + """Tests for ``all_equal()``""" + + def test_true(self): + """Everything is equal""" + self.assertTrue(mi.all_equal('aaaaaa')) + self.assertTrue(mi.all_equal([0, 0, 0, 0])) + + def test_false(self): + """Not everything is equal""" + self.assertFalse(mi.all_equal('aaaaab')) + self.assertFalse(mi.all_equal([0, 0, 0, 1])) + + def test_tricky(self): + """Not everything is identical, but everything is equal""" + items = [1, complex(1, 0), 1.0] + self.assertTrue(mi.all_equal(items)) + + def test_empty(self): + """Return True if the iterable is empty""" + self.assertTrue(mi.all_equal('')) + self.assertTrue(mi.all_equal([])) + + def test_one(self): + """Return True if the iterable is singular""" + self.assertTrue(mi.all_equal('0')) + self.assertTrue(mi.all_equal([0])) + + +class QuantifyTests(TestCase): + """Tests for ``quantify()``""" + + def test_happy_path(self): + """Make sure True count is returned""" + q = [True, False, True] + self.assertEqual(mi.quantify(q), 2) + + def test_custom_predicate(self): + """Ensure non-default predicates return as expected""" + q = range(10) + self.assertEqual(mi.quantify(q, lambda x: x % 2 == 0), 5) + + 
+class PadnoneTests(TestCase):
+    """Tests for ``padnone()``"""
+
+    def test_happy_path(self):
+        """wrapper iterator should return None indefinitely"""
+        r = range(2)
+        p = mi.padnone(r)
+        self.assertEqual([0, 1, None, None], [next(p) for _ in range(4)])
+
+
+class NcyclesTests(TestCase):
+    """Tests for ``ncycles()``"""
+
+    def test_happy_path(self):
+        """cycle a sequence three times"""
+        r = ["a", "b", "c"]
+        n = mi.ncycles(r, 3)
+        self.assertEqual(
+            ["a", "b", "c", "a", "b", "c", "a", "b", "c"],
+            list(n)
+        )
+
+    def test_null_case(self):
+        """asking for 0 cycles should return an empty iterator"""
+        n = mi.ncycles(range(100), 0)
+        self.assertRaises(StopIteration, lambda: next(n))
+
+    def test_pathalogical_case(self):
+        """asking for negative cycles should return an empty iterator"""
+        n = mi.ncycles(range(100), -10)
+        self.assertRaises(StopIteration, lambda: next(n))
+
+
+class DotproductTests(TestCase):
+    """Tests for ``dotproduct()``"""
+
+    def test_happy_path(self):
+        """simple dotproduct example"""
+        self.assertEqual(400, mi.dotproduct([10, 10], [20, 20]))
+
+
+class FlattenTests(TestCase):
+    """Tests for ``flatten()``"""
+
+    def test_basic_usage(self):
+        """ensure list of lists is flattened one level"""
+        f = [[0, 1, 2], [3, 4, 5]]
+        self.assertEqual(list(range(6)), list(mi.flatten(f)))
+
+    def test_single_level(self):
+        """ensure list of lists is flattened only one level"""
+        f = [[0, [1, 2]], [[3, 4], 5]]
+        self.assertEqual([0, [1, 2], [3, 4], 5], list(mi.flatten(f)))
+
+
+class RepeatfuncTests(TestCase):
+    """Tests for ``repeatfunc()``"""
+
+    def test_simple_repeat(self):
+        """test simple repeated functions"""
+        r = mi.repeatfunc(lambda: 5)
+        self.assertEqual([5, 5, 5, 5, 5], [next(r) for _ in range(5)])
+
+    def test_finite_repeat(self):
+        """ensure limited repeat when times is provided"""
+        r = mi.repeatfunc(lambda: 5, times=5)
+        self.assertEqual([5, 5, 5, 5, 5], list(r))
+
+    def test_added_arguments(self):
+        """ensure arguments are applied to the function"""
+        r = mi.repeatfunc(lambda x: x, 2, 3)
+        self.assertEqual([3, 3], list(r))
+
+    def test_null_times(self):
+        """repeat 0 times should return an empty iterator"""
+        r = mi.repeatfunc(range, 0, 3)
+        self.assertRaises(StopIteration, lambda: next(r))
+
+
+class PairwiseTests(TestCase):
+    """Tests for ``pairwise()``"""
+
+    def test_base_case(self):
+        """ensure an iterable will return pairwise"""
+        p = mi.pairwise([1, 2, 3])
+        self.assertEqual([(1, 2), (2, 3)], list(p))
+
+    def test_short_case(self):
+        """ensure an empty iterator if there are not enough values to pair"""
+        p = mi.pairwise("a")
+        self.assertRaises(StopIteration, lambda: next(p))
+
+
+class GrouperTests(TestCase):
+    """Tests for ``grouper()``"""
+
+    def test_even(self):
+        """Test when group size divides evenly into the length of
+        the iterable.
+
+        """
+        self.assertEqual(
+            list(mi.grouper(3, 'ABCDEF')), [('A', 'B', 'C'), ('D', 'E', 'F')]
+        )
+
+    def test_odd(self):
+        """Test when group size does not divide evenly into the length of the
+        iterable.
+ + """ + self.assertEqual( + list(mi.grouper(3, 'ABCDE')), [('A', 'B', 'C'), ('D', 'E', None)] + ) + + def test_fill_value(self): + """Test that the fill value is used to pad the final group""" + self.assertEqual( + list(mi.grouper(3, 'ABCDE', 'x')), + [('A', 'B', 'C'), ('D', 'E', 'x')] + ) + + +class RoundrobinTests(TestCase): + """Tests for ``roundrobin()``""" + + def test_even_groups(self): + """Ensure ordered output from evenly populated iterables""" + self.assertEqual( + list(mi.roundrobin('ABC', [1, 2, 3], range(3))), + ['A', 1, 0, 'B', 2, 1, 'C', 3, 2] + ) + + def test_uneven_groups(self): + """Ensure ordered output from unevenly populated iterables""" + self.assertEqual( + list(mi.roundrobin('ABCD', [1, 2], range(0))), + ['A', 1, 'B', 2, 'C', 'D'] + ) + + +class PartitionTests(TestCase): + """Tests for ``partition()``""" + + def test_bool(self): + """Test when pred() returns a boolean""" + lesser, greater = mi.partition(lambda x: x > 5, range(10)) + self.assertEqual(list(lesser), [0, 1, 2, 3, 4, 5]) + self.assertEqual(list(greater), [6, 7, 8, 9]) + + def test_arbitrary(self): + """Test when pred() returns an integer""" + divisibles, remainders = mi.partition(lambda x: x % 3, range(10)) + self.assertEqual(list(divisibles), [0, 3, 6, 9]) + self.assertEqual(list(remainders), [1, 2, 4, 5, 7, 8]) + + +class PowersetTests(TestCase): + """Tests for ``powerset()``""" + + def test_combinatorics(self): + """Ensure a proper enumeration""" + p = mi.powerset([1, 2, 3]) + self.assertEqual( + list(p), + [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)] + ) + + +class UniqueEverseenTests(TestCase): + """Tests for ``unique_everseen()``""" + + def test_everseen(self): + """ensure duplicate elements are ignored""" + u = mi.unique_everseen('AAAABBBBCCDAABBB') + self.assertEqual( + ['A', 'B', 'C', 'D'], + list(u) + ) + + def test_custom_key(self): + """ensure the custom key comparison works""" + u = mi.unique_everseen('aAbACCc', key=str.lower) + self.assertEqual(list('abC'), list(u)) + + def test_unhashable(self): + """ensure things work for unhashable items""" + iterable = ['a', [1, 2, 3], [1, 2, 3], 'a'] + u = mi.unique_everseen(iterable) + self.assertEqual(list(u), ['a', [1, 2, 3]]) + + def test_unhashable_key(self): + """ensure things work for unhashable items with a custom key""" + iterable = ['a', [1, 2, 3], [1, 2, 3], 'a'] + u = mi.unique_everseen(iterable, key=lambda x: x) + self.assertEqual(list(u), ['a', [1, 2, 3]]) + + +class UniqueJustseenTests(TestCase): + """Tests for ``unique_justseen()``""" + + def test_justseen(self): + """ensure only last item is remembered""" + u = mi.unique_justseen('AAAABBBCCDABB') + self.assertEqual(list('ABCDAB'), list(u)) + + def test_custom_key(self): + """ensure the custom key comparison works""" + u = mi.unique_justseen('AABCcAD', str.lower) + self.assertEqual(list('ABCAD'), list(u)) + + +class IterExceptTests(TestCase): + """Tests for ``iter_except()``""" + + def test_exact_exception(self): + """ensure the exact specified exception is caught""" + l = [1, 2, 3] + i = mi.iter_except(l.pop, IndexError) + self.assertEqual(list(i), [3, 2, 1]) + + def test_generic_exception(self): + """ensure the generic exception can be caught""" + l = [1, 2] + i = mi.iter_except(l.pop, Exception) + self.assertEqual(list(i), [2, 1]) + + def test_uncaught_exception_is_raised(self): + """ensure a non-specified exception is raised""" + l = [1, 2, 3] + i = mi.iter_except(l.pop, KeyError) + self.assertRaises(IndexError, lambda: list(i)) + + def test_first(self): + """ensure 
first is run before the function"""
+        l = [1, 2, 3]
+        f = lambda: 25
+        i = mi.iter_except(l.pop, IndexError, f)
+        self.assertEqual(list(i), [25, 3, 2, 1])
+
+
+class FirstTrueTests(TestCase):
+    """Tests for ``first_true()``"""
+
+    def test_something_true(self):
+        """Test with no keywords"""
+        self.assertEqual(mi.first_true(range(10)), 1)
+
+    def test_nothing_true(self):
+        """Test default return value."""
+        self.assertIsNone(mi.first_true([0, 0, 0]))
+
+    def test_default(self):
+        """Test with a default keyword"""
+        self.assertEqual(mi.first_true([0, 0, 0], default='!'), '!')
+
+    def test_pred(self):
+        """Test with a custom predicate"""
+        self.assertEqual(
+            mi.first_true([2, 4, 6], pred=lambda x: x % 3 == 0), 6
+        )
+
+
+class RandomProductTests(TestCase):
+    """Tests for ``random_product()``
+
+    Since random.choice() has different results with the same seed across
+    Python versions 2.x and 3.x, these tests use highly probable events to
+    create predictable outcomes across platforms.
+    """
+
+    def test_simple_lists(self):
+        """Ensure that one item is chosen from each list in each pair.
+        Also ensure that each item from each list eventually appears in
+        the chosen combinations.
+
+        Odds are roughly 1 in 7.1e16 that one item from either list will not
+        be chosen after 100 samplings of one item from each list. Just to be
+        safe, it's better to use a known random seed, too.
+
+        """
+        nums = [1, 2, 3]
+        lets = ['a', 'b', 'c']
+        n, m = zip(*[mi.random_product(nums, lets) for _ in range(100)])
+        n, m = set(n), set(m)
+        self.assertEqual(n, set(nums))
+        self.assertEqual(m, set(lets))
+        self.assertEqual(len(n), len(nums))
+        self.assertEqual(len(m), len(lets))
+
+    def test_list_with_repeat(self):
+        """ensure multiple items are chosen, and that they appear to be chosen
+        from one list then the next, in proper order.
+
+        """
+        nums = [1, 2, 3]
+        lets = ['a', 'b', 'c']
+        r = list(mi.random_product(nums, lets, repeat=100))
+        self.assertEqual(2 * 100, len(r))
+        n, m = set(r[::2]), set(r[1::2])
+        self.assertEqual(n, set(nums))
+        self.assertEqual(m, set(lets))
+        self.assertEqual(len(n), len(nums))
+        self.assertEqual(len(m), len(lets))
+
+
+class RandomPermutationTests(TestCase):
+    """Tests for ``random_permutation()``"""
+
+    def test_full_permutation(self):
+        """ensure every item from the iterable is returned in a new ordering
+
+        15 elements have a 1 in 1.3e12 chance of appearing in sorted order,
+        so we fix a seed value just to be sure.
+
+        """
+        i = range(15)
+        r = mi.random_permutation(i)
+        self.assertEqual(set(i), set(r))
+        if i == r:
+            raise AssertionError("Values were not permuted")
+
+    def test_partial_permutation(self):
+        """ensure all returned items are from the iterable, that the returned
+        permutation is of the desired length, and that all items eventually
+        get returned.
+
+        Sampling 100 permutations of length 5 from a set of 15 leaves a
+        (2/3)^100 chance that an item will not be chosen. Multiplied by 15
+        items, there is a 1 in 2.6e16 chance that at least 1 item will not
+        show up in the resulting output. Using a random seed will fix that.
+ + """ + items = range(15) + item_set = set(items) + all_items = set() + for _ in range(100): + permutation = mi.random_permutation(items, 5) + self.assertEqual(len(permutation), 5) + permutation_set = set(permutation) + self.assertLessEqual(permutation_set, item_set) + all_items |= permutation_set + self.assertEqual(all_items, item_set) + + +class RandomCombinationTests(TestCase): + """Tests for ``random_combination()``""" + + def test_pseudorandomness(self): + """ensure different subsets of the iterable get returned over many + samplings of random combinations""" + items = range(15) + all_items = set() + for _ in range(50): + combination = mi.random_combination(items, 5) + all_items |= set(combination) + self.assertEqual(all_items, set(items)) + + def test_no_replacement(self): + """ensure that elements are sampled without replacement""" + items = range(15) + for _ in range(50): + combination = mi.random_combination(items, len(items)) + self.assertEqual(len(combination), len(set(combination))) + self.assertRaises( + ValueError, lambda: mi.random_combination(items, len(items) + 1) + ) + + +class RandomCombinationWithReplacementTests(TestCase): + """Tests for ``random_combination_with_replacement()``""" + + def test_replacement(self): + """ensure that elements are sampled with replacement""" + items = range(5) + combo = mi.random_combination_with_replacement(items, len(items) * 2) + self.assertEqual(2 * len(items), len(combo)) + if len(set(combo)) == len(combo): + raise AssertionError("Combination contained no duplicates") + + def test_pseudorandomness(self): + """ensure different subsets of the iterable get returned over many + samplings of random combinations""" + items = range(15) + all_items = set() + for _ in range(50): + combination = mi.random_combination_with_replacement(items, 5) + all_items |= set(combination) + self.assertEqual(all_items, set(items)) + + +class NthCombinationTests(TestCase): + def test_basic(self): + iterable = 'abcdefg' + r = 4 + for index, expected in enumerate(combinations(iterable, r)): + actual = mi.nth_combination(iterable, r, index) + self.assertEqual(actual, expected) + + def test_long(self): + actual = mi.nth_combination(range(180), 4, 2000000) + expected = (2, 12, 35, 126) + self.assertEqual(actual, expected) + + def test_invalid_r(self): + for r in (-1, 3): + with self.assertRaises(ValueError): + mi.nth_combination([], r, 0) + + def test_invalid_index(self): + with self.assertRaises(IndexError): + mi.nth_combination('abcdefg', 3, -36) + + +class PrependTests(TestCase): + def test_basic(self): + value = 'a' + iterator = iter('bcdefg') + actual = list(mi.prepend(value, iterator)) + expected = list('abcdefg') + self.assertEqual(actual, expected) + + def test_multiple(self): + value = 'ab' + iterator = iter('cdefg') + actual = tuple(mi.prepend(value, iterator)) + expected = ('ab',) + tuple('cdefg') + self.assertEqual(actual, expected)