From 4ac151d7deab9289c5ce1012b060320760b0bc67 Mon Sep 17 00:00:00 2001 From: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> Date: Thu, 14 Oct 2021 21:13:45 -0700 Subject: [PATCH] Update more-itertools-8.10.0 --- lib/more_itertools/__init__.py | 6 +- lib/more_itertools/__init__.pyi | 2 + lib/more_itertools/more.py | 2259 ++++++++++++++++++--- lib/more_itertools/more.pyi | 556 ++++++ lib/more_itertools/py.typed | 0 lib/more_itertools/recipes.py | 247 ++- lib/more_itertools/recipes.pyi | 105 + lib/more_itertools/tests/test_more.py | 2313 ---------------------- lib/more_itertools/tests/test_recipes.py | 616 ------ 9 files changed, 2807 insertions(+), 3297 deletions(-) create mode 100644 lib/more_itertools/__init__.pyi create mode 100644 lib/more_itertools/more.pyi create mode 100644 lib/more_itertools/py.typed create mode 100644 lib/more_itertools/recipes.pyi delete mode 100644 lib/more_itertools/tests/test_more.py delete mode 100644 lib/more_itertools/tests/test_recipes.py diff --git a/lib/more_itertools/__init__.py b/lib/more_itertools/__init__.py index bba462c3..e2d7d91d 100644 --- a/lib/more_itertools/__init__.py +++ b/lib/more_itertools/__init__.py @@ -1,2 +1,4 @@ -from more_itertools.more import * # noqa -from more_itertools.recipes import * # noqa +from .more import * # noqa +from .recipes import * # noqa + +__version__ = '8.10.0' diff --git a/lib/more_itertools/__init__.pyi b/lib/more_itertools/__init__.pyi new file mode 100644 index 00000000..96f6e36c --- /dev/null +++ b/lib/more_itertools/__init__.pyi @@ -0,0 +1,2 @@ +from .more import * +from .recipes import * diff --git a/lib/more_itertools/more.py b/lib/more_itertools/more.py index bd32a261..edef3854 100644 --- a/lib/more_itertools/more.py +++ b/lib/more_itertools/more.py @@ -1,8 +1,10 @@ -from __future__ import print_function +import warnings -from collections import Counter, defaultdict, deque -from functools import partial, wraps -from heapq import merge +from collections import Counter, defaultdict, deque, abc +from collections.abc import Sequence +from concurrent.futures import ThreadPoolExecutor +from functools import partial, reduce, wraps +from heapq import merge, heapify, heapreplace, heappop from itertools import ( chain, compress, @@ -14,58 +16,84 @@ from itertools import ( repeat, starmap, takewhile, - tee + tee, + zip_longest, ) -from operator import itemgetter, lt, gt, sub -from sys import maxsize, version_info -try: - from collections.abc import Sequence -except ImportError: - from collections import Sequence +from math import exp, factorial, floor, log +from queue import Empty, Queue +from random import random, randrange, uniform +from operator import itemgetter, mul, sub, gt, lt +from sys import hexversion, maxsize +from time import monotonic -from six import binary_type, string_types, text_type -from six.moves import filter, map, range, zip, zip_longest - -from .recipes import consume, flatten, take +from .recipes import ( + consume, + flatten, + pairwise, + powerset, + take, + unique_everseen, +) __all__ = [ + 'AbortThread', 'adjacent', 'always_iterable', 'always_reversible', 'bucket', + 'callback_iter', 'chunked', + 'chunked_even', 'circular_shifts', 'collapse', 'collate', 'consecutive_groups', 'consumer', + 'countable', 'count_cycle', + 'mark_ends', 'difference', + 'distinct_combinations', 'distinct_permutations', 'distribute', 'divide', 'exactly_n', + 'filter_except', 'first', 'groupby_transform', 'ilen', 'interleave_longest', 'interleave', + 'interleave_evenly', 'intersperse', 'islice_extended', 'iterate', + 'ichunked', + 'is_sorted', 'last', 'locate', 'lstrip', 'make_decorator', + 'map_except', + 'map_if', 'map_reduce', + 'nth_or_last', + 'nth_permutation', + 'nth_product', 'numeric_range', 'one', + 'only', 'padded', + 'partitions', + 'set_partitions', 'peekable', + 'repeat_each', + 'repeat_last', 'replace', 'rlocate', 'rstrip', 'run_length', + 'sample', 'seekable', 'SequenceView', 'side_effect', @@ -74,43 +102,66 @@ __all__ = [ 'split_at', 'split_after', 'split_before', + 'split_when', 'split_into', 'spy', 'stagger', 'strip', 'substrings', + 'substrings_indexes', + 'time_limited', 'unique_to_each', 'unzip', 'windowed', 'with_iter', + 'UnequalIterablesError', + 'zip_equal', 'zip_offset', + 'windowed_complete', + 'all_unique', + 'value_chain', + 'product_index', + 'combination_index', + 'permutation_index', + 'zip_broadcast', ] _marker = object() -def chunked(iterable, n): +def chunked(iterable, n, strict=False): """Break *iterable* into lists of length *n*: >>> list(chunked([1, 2, 3, 4, 5, 6], 3)) [[1, 2, 3], [4, 5, 6]] - If the length of *iterable* is not evenly divisible by *n*, the last - returned list will be shorter: + By the default, the last yielded list will have fewer than *n* elements + if the length of *iterable* is not divisible by *n*: >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3)) [[1, 2, 3], [4, 5, 6], [7, 8]] To use a fill-in value instead, see the :func:`grouper` recipe. - :func:`chunked` is useful for splitting up a computation on a large number - of keys into batches, to be pickled and sent off to worker processes. One - example is operations on rows in MySQL, which does not implement - server-side cursors properly and would otherwise load the entire dataset - into RAM on the client. + If the length of *iterable* is not divisible by *n* and *strict* is + ``True``, then ``ValueError`` will be raised before the last + list is yielded. """ - return iter(partial(take, n, iter(iterable)), []) + iterator = iter(partial(take, n, iter(iterable)), []) + if strict: + if n is None: + raise ValueError('n must not be None when using strict mode.') + + def ret(): + for chunk in iterator: + if len(chunk) != n: + raise ValueError('iterable is not divisible by n.') + yield chunk + + return iter(ret()) + else: + return iterator def first(iterable, default=_marker): @@ -132,14 +183,12 @@ def first(iterable, default=_marker): """ try: return next(iter(iterable)) - except StopIteration: - # I'm on the edge about raising ValueError instead of StopIteration. At - # the moment, ValueError wins, because the caller could conceivably - # want to do something different with flow control when I raise the - # exception, and it's weird to explicitly catch StopIteration. + except StopIteration as e: if default is _marker: - raise ValueError('first() was called on an empty iterable, and no ' - 'default value was provided.') + raise ValueError( + 'first() was called on an empty iterable, and no ' + 'default value was provided.' + ) from e return default @@ -156,20 +205,40 @@ def last(iterable, default=_marker): raise ``ValueError``. """ try: - try: - # Try to access the last item directly + if isinstance(iterable, Sequence): return iterable[-1] - except (TypeError, AttributeError, KeyError): - # If not slice-able, iterate entirely using length-1 deque - return deque(iterable, maxlen=1)[0] - except IndexError: # If the iterable was empty + # Work around https://bugs.python.org/issue38525 + elif hasattr(iterable, '__reversed__') and (hexversion != 0x030800F0): + return next(reversed(iterable)) + else: + return deque(iterable, maxlen=1)[-1] + except (IndexError, TypeError, StopIteration): if default is _marker: - raise ValueError('last() was called on an empty iterable, and no ' - 'default value was provided.') + raise ValueError( + 'last() was called on an empty iterable, and no default was ' + 'provided.' + ) return default -class peekable(object): +def nth_or_last(iterable, n, default=_marker): + """Return the nth or the last item of *iterable*, + or *default* if *iterable* is empty. + + >>> nth_or_last([0, 1, 2, 3], 2) + 2 + >>> nth_or_last([0, 1], 2) + 1 + >>> nth_or_last([], 0, 'some default') + 'some default' + + If *default* is not provided and there are no items in the iterable, + raise ``ValueError``. + """ + return last(islice(iterable, n + 1), default=default) + + +class peekable: """Wrap an iterator to allow lookahead and prepending elements. Call :meth:`peek` on the result to get the value that will be returned @@ -222,11 +291,12 @@ class peekable(object): >>> if p: # peekable has items ... list(p) ['a', 'b'] - >>> if not p: # peekable is exhaused + >>> if not p: # peekable is exhausted ... list(p) [] """ + def __init__(self, iterable): self._it = iter(iterable) self._cache = deque() @@ -241,10 +311,6 @@ class peekable(object): return False return True - def __nonzero__(self): - # For Python 2 compatibility - return self.__bool__() - def peek(self, default=_marker): """Return the item that will be next returned from ``next()``. @@ -298,8 +364,6 @@ class peekable(object): return next(self._it) - next = __next__ # For Python 2 compatibility - def _get_slice(self, index): # Normalize the slice's arguments step = 1 if (index.step is None) else index.step @@ -339,23 +403,6 @@ class peekable(object): return self._cache[index] -def _collate(*iterables, **kwargs): - """Helper for ``collate()``, called when the user is using the ``reverse`` - or ``key`` keyword arguments on Python versions below 3.5. - - """ - key = kwargs.pop('key', lambda a: a) - reverse = kwargs.pop('reverse', False) - - min_or_max = partial(max if reverse else min, key=itemgetter(0)) - peekables = [peekable(it) for it in iterables] - peekables = [p for p in peekables if p] # Kill empties. - while peekables: - _, p = min_or_max((key(p.peek()), p) for p in peekables) - yield next(p) - peekables = [x for x in peekables if x] - - def collate(*iterables, **kwargs): """Return a sorted merge of the items from each of several already-sorted *iterables*. @@ -384,23 +431,14 @@ def collate(*iterables, **kwargs): If the elements of the passed-in iterables are out of order, you might get unexpected results. - On Python 2.7, this function delegates to :func:`heapq.merge` if neither - of the keyword arguments are specified. On Python 3.5+, this function - is an alias for :func:`heapq.merge`. + On Python 3.5+, this function is an alias for :func:`heapq.merge`. """ - if not kwargs: - return merge(*iterables) - - return _collate(*iterables, **kwargs) - - -# If using Python version 3.5 or greater, heapq.merge() will be faster than -# collate - use that instead. -if version_info >= (3, 5, 0): - _collate_docstring = collate.__doc__ - collate = partial(merge) - collate.__doc__ = _collate_docstring + warnings.warn( + "collate is no longer part of more_itertools, use heapq.merge", + DeprecationWarning, + ) + return merge(*iterables, **kwargs) def consumer(func): @@ -425,11 +463,13 @@ def consumer(func): ``t.send()`` could be used. """ + @wraps(func) def wrapper(*args, **kwargs): gen = func(*args, **kwargs) next(gen) return gen + return wrapper @@ -453,9 +493,9 @@ def ilen(iterable): def iterate(func, start): """Return ``start``, ``func(start)``, ``func(func(start))``, ... - >>> from itertools import islice - >>> list(islice(iterate(lambda x: 2*x, 1), 10)) - [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + >>> from itertools import islice + >>> list(islice(iterate(lambda x: 2*x, 1), 10)) + [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] """ while True: @@ -475,8 +515,7 @@ def with_iter(context_manager): """ with context_manager as iterable: - for item in iterable: - yield item + yield from iterable def one(iterable, too_short=None, too_long=None): @@ -510,7 +549,8 @@ def one(iterable, too_short=None, too_long=None): >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... - ValueError: too many items in iterable (expected 1)' + ValueError: Expected exactly one item in iterable, but got 'too', + 'many', and perhaps more. >>> too_long = RuntimeError >>> one(it, too_long=too_long) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): @@ -518,29 +558,34 @@ def one(iterable, too_short=None, too_long=None): RuntimeError Note that :func:`one` attempts to advance *iterable* twice to ensure there - is only one item. If there is more than one, both items will be discarded. - See :func:`spy` or :func:`peekable` to check iterable contents less - destructively. + is only one item. See :func:`spy` or :func:`peekable` to check iterable + contents less destructively. """ it = iter(iterable) try: - value = next(it) - except StopIteration: - raise too_short or ValueError('too few items in iterable (expected 1)') + first_value = next(it) + except StopIteration as e: + raise ( + too_short or ValueError('too few items in iterable (expected 1)') + ) from e try: - next(it) + second_value = next(it) except StopIteration: pass else: - raise too_long or ValueError('too many items in iterable (expected 1)') + msg = ( + 'Expected exactly one item in iterable, but got {!r}, {!r}, ' + 'and perhaps more.'.format(first_value, second_value) + ) + raise too_long or ValueError(msg) - return value + return first_value -def distinct_permutations(iterable): +def distinct_permutations(iterable, r=None): """Yield successive distinct permutations of the elements in *iterable*. >>> sorted(distinct_permutations([1, 0, 1])) @@ -556,34 +601,88 @@ def distinct_permutations(iterable): items input, and each `x_i` is the count of a distinct item in the input sequence. + If *r* is given, only the *r*-length permutations are yielded. + + >>> sorted(distinct_permutations([1, 0, 1], r=2)) + [(0, 1), (1, 0), (1, 1)] + >>> sorted(distinct_permutations(range(3), r=2)) + [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)] + """ - def perm_unique_helper(item_counts, perm, i): - """Internal helper function + # Algorithm: https://w.wiki/Qai + def _full(A): + while True: + # Yield the permutation we have + yield tuple(A) - :arg item_counts: Stores the unique items in ``iterable`` and how many - times they are repeated - :arg perm: The permutation that is being built for output - :arg i: The index of the permutation being modified + # Find the largest index i such that A[i] < A[i + 1] + for i in range(size - 2, -1, -1): + if A[i] < A[i + 1]: + break + # If no such index exists, this permutation is the last one + else: + return - The output permutations are built up recursively; the distinct items - are placed until their repetitions are exhausted. - """ - if i < 0: - yield tuple(perm) - else: - for item in item_counts: - if item_counts[item] <= 0: - continue - perm[i] = item - item_counts[item] -= 1 - for x in perm_unique_helper(item_counts, perm, i - 1): - yield x - item_counts[item] += 1 + # Find the largest index j greater than j such that A[i] < A[j] + for j in range(size - 1, i, -1): + if A[i] < A[j]: + break - item_counts = Counter(iterable) - length = sum(item_counts.values()) + # Swap the value of A[i] with that of A[j], then reverse the + # sequence from A[i + 1] to form the new permutation + A[i], A[j] = A[j], A[i] + A[i + 1 :] = A[: i - size : -1] # A[i + 1:][::-1] - return perm_unique_helper(item_counts, [None] * length, length - 1) + # Algorithm: modified from the above + def _partial(A, r): + # Split A into the first r items and the last r items + head, tail = A[:r], A[r:] + right_head_indexes = range(r - 1, -1, -1) + left_tail_indexes = range(len(tail)) + + while True: + # Yield the permutation we have + yield tuple(head) + + # Starting from the right, find the first index of the head with + # value smaller than the maximum value of the tail - call it i. + pivot = tail[-1] + for i in right_head_indexes: + if head[i] < pivot: + break + pivot = head[i] + else: + return + + # Starting from the left, find the first value of the tail + # with a value greater than head[i] and swap. + for j in left_tail_indexes: + if tail[j] > head[i]: + head[i], tail[j] = tail[j], head[i] + break + # If we didn't find one, start from the right and find the first + # index of the head with a value greater than head[i] and swap. + else: + for j in right_head_indexes: + if head[j] > head[i]: + head[i], head[j] = head[j], head[i] + break + + # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)] + tail += head[: i - r : -1] # head[i + 1:][::-1] + i += 1 + head[i:], tail[:] = tail[: r - i], tail[r - i :] + + items = sorted(iterable) + + size = len(items) + if r is None: + r = size + + if 0 < r <= size: + return _full(items) if (r == size) else _partial(items, r) + + return iter(() if r else ((),)) def intersperse(e, iterable, n=1): @@ -653,7 +752,7 @@ def windowed(seq, n, fillvalue=None, step=1): [(1, 2, 3), (2, 3, 4), (3, 4, 5)] When the window is larger than the iterable, *fillvalue* is used in place - of missing values:: + of missing values: >>> list(windowed([1, 2, 3], 4)) [(1, 2, 3, None)] @@ -663,6 +762,14 @@ def windowed(seq, n, fillvalue=None, step=1): >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2)) [(1, 2, 3), (3, 4, 5), (5, 6, '!')] + To slide into the iterable's items, use :func:`chain` to add filler items + to the left: + + >>> iterable = [1, 2, 3, 4] + >>> n = 3 + >>> padding = [None] * (n - 1) + >>> list(windowed(chain(padding, iterable), 3)) + [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)] """ if n < 0: raise ValueError('n must be >= 0') @@ -672,32 +779,23 @@ def windowed(seq, n, fillvalue=None, step=1): if step < 1: raise ValueError('step must be >= 1') - it = iter(seq) - window = deque([], n) - append = window.append - - # Initial deque fill - for _ in range(n): - append(next(it, fillvalue)) - yield tuple(window) - - # Appending new items to the right causes old items to fall off the left - i = 0 - for item in it: - append(item) - i = (i + 1) % step - if i % step == 0: + window = deque(maxlen=n) + i = n + for _ in map(window.append, seq): + i -= 1 + if not i: + i = step yield tuple(window) - # If there are items from the iterable in the window, pad with the given - # value and emit them. - if (i % step) and (step - i < n): - for _ in range(step - i): - append(fillvalue) + size = len(window) + if size < n: + yield tuple(chain(window, repeat(fillvalue, n - size))) + elif 0 < i < min(step, n): + window += (fillvalue,) * i yield tuple(window) -def substrings(iterable, join_func=None): +def substrings(iterable): """Yield all of the substrings of *iterable*. >>> [''.join(s) for s in substrings('more')] @@ -720,15 +818,51 @@ def substrings(iterable, join_func=None): # And the rest for n in range(2, item_count + 1): for i in range(item_count - n + 1): - yield seq[i:i + n] + yield seq[i : i + n] -class bucket(object): +def substrings_indexes(seq, reverse=False): + """Yield all substrings and their positions in *seq* + + The items yielded will be a tuple of the form ``(substr, i, j)``, where + ``substr == seq[i:j]``. + + This function only works for iterables that support slicing, such as + ``str`` objects. + + >>> for item in substrings_indexes('more'): + ... print(item) + ('m', 0, 1) + ('o', 1, 2) + ('r', 2, 3) + ('e', 3, 4) + ('mo', 0, 2) + ('or', 1, 3) + ('re', 2, 4) + ('mor', 0, 3) + ('ore', 1, 4) + ('more', 0, 4) + + Set *reverse* to ``True`` to yield the same items in the opposite order. + + + """ + r = range(1, len(seq) + 1) + if reverse: + r = reversed(r) + return ( + (seq[i : i + L], i, i + L) for L in r for i in range(len(seq) - L + 1) + ) + + +class bucket: """Wrap *iterable* and return an object that buckets it iterable into child iterables based on a *key* function. >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3'] - >>> s = bucket(iterable, key=lambda x: x[0]) + >>> s = bucket(iterable, key=lambda x: x[0]) # Bucket by 1st character + >>> sorted(list(s)) # Get the keys + ['a', 'b', 'c'] >>> a_iterable = s['a'] >>> next(a_iterable) 'a1' @@ -756,6 +890,7 @@ class bucket(object): [] """ + def __init__(self, iterable, key, validator=None): self._it = iter(iterable) self._key = key @@ -801,6 +936,14 @@ class bucket(object): elif self._validator(item_value): self._cache[item_value].append(item) + def __iter__(self): + for item in self._it: + item_value = self._key(item) + if self._validator(item_value): + self._cache[item_value].append(item) + + yield from self._cache.keys() + def __getitem__(self, value): if not self._validator(value): return iter(()) @@ -848,7 +991,7 @@ def spy(iterable, n=1): it = iter(iterable) head = take(n, it) - return head, chain(head, it) + return head.copy(), chain(head, it) def interleave(*iterables): @@ -881,6 +1024,72 @@ def interleave_longest(*iterables): return (x for x in i if x is not _marker) +def interleave_evenly(iterables, lengths=None): + """ + Interleave multiple iterables so that their elements are evenly distributed + throughout the output sequence. + + >>> iterables = [1, 2, 3, 4, 5], ['a', 'b'] + >>> list(interleave_evenly(iterables)) + [1, 2, 'a', 3, 4, 'b', 5] + + >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]] + >>> list(interleave_evenly(iterables)) + [1, 6, 4, 2, 7, 3, 8, 5] + + This function requires iterables of known length. Iterables without + ``__len__()`` can be used by manually specifying lengths with *lengths*: + + >>> from itertools import combinations, repeat + >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']] + >>> lengths = [4 * (4 - 1) // 2, 3] + >>> list(interleave_evenly(iterables, lengths=lengths)) + [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c'] + + Based on Bresenham's algorithm. + """ + if lengths is None: + try: + lengths = [len(it) for it in iterables] + except TypeError: + raise ValueError( + 'Iterable lengths could not be determined automatically. ' + 'Specify them with the lengths keyword.' + ) + elif len(iterables) != len(lengths): + raise ValueError('Mismatching number of iterables and lengths.') + + dims = len(lengths) + + # sort iterables by length, descending + lengths_permute = sorted( + range(dims), key=lambda i: lengths[i], reverse=True + ) + lengths_desc = [lengths[i] for i in lengths_permute] + iters_desc = [iter(iterables[i]) for i in lengths_permute] + + # the longest iterable is the primary one (Bresenham: the longest + # distance along an axis) + delta_primary, deltas_secondary = lengths_desc[0], lengths_desc[1:] + iter_primary, iters_secondary = iters_desc[0], iters_desc[1:] + errors = [delta_primary // dims] * len(deltas_secondary) + + to_yield = sum(lengths) + while to_yield: + yield next(iter_primary) + to_yield -= 1 + # update errors for each secondary iterable + errors = [e - delta for e, delta in zip(errors, deltas_secondary)] + + # those iterables for which the error is negative are yielded + # ("diagonal step" in Bresenham) + for i, e in enumerate(errors): + if e < 0: + yield next(iters_secondary[i]) + to_yield -= 1 + errors[i] += delta_primary + + def collapse(iterable, base_type=None, levels=None): """Flatten an iterable with multiple levels of nesting (e.g., a list of lists of tuples) into non-iterable types. @@ -889,7 +1098,9 @@ def collapse(iterable, base_type=None, levels=None): >>> list(collapse(iterable)) [1, 2, 3, 4, 5, 6] - String types are not considered iterable and will not be collapsed. + Binary and text strings are not considered iterable and + will not be collapsed. + To avoid collapsing other types, specify *base_type*: >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']] @@ -905,11 +1116,12 @@ def collapse(iterable, base_type=None, levels=None): ['a', ['b'], 'c', ['d']] """ + def walk(node, level): if ( - ((levels is not None) and (level > levels)) or - isinstance(node, string_types) or - ((base_type is not None) and isinstance(node, base_type)) + ((levels is not None) and (level > levels)) + or isinstance(node, (str, bytes)) + or ((base_type is not None) and isinstance(node, base_type)) ): yield node return @@ -921,11 +1133,9 @@ def collapse(iterable, base_type=None, levels=None): return else: for child in tree: - for x in walk(child, level + 1): - yield x + yield from walk(child, level + 1) - for x in walk(iterable, 0): - yield x + yield from walk(iterable, 0) def side_effect(func, iterable, chunk_size=None, before=None, after=None): @@ -983,56 +1193,93 @@ def side_effect(func, iterable, chunk_size=None, before=None, after=None): else: for chunk in chunked(iterable, chunk_size): func(chunk) - for item in chunk: - yield item + yield from chunk finally: if after is not None: after() -def sliced(seq, n): +def sliced(seq, n, strict=False): """Yield slices of length *n* from the sequence *seq*. - >>> list(sliced((1, 2, 3, 4, 5, 6), 3)) - [(1, 2, 3), (4, 5, 6)] + >>> list(sliced((1, 2, 3, 4, 5, 6), 3)) + [(1, 2, 3), (4, 5, 6)] - If the length of the sequence is not divisible by the requested slice - length, the last slice will be shorter. + By the default, the last yielded slice will have fewer than *n* elements + if the length of *seq* is not divisible by *n*: - >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3)) - [(1, 2, 3), (4, 5, 6), (7, 8)] + >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3)) + [(1, 2, 3), (4, 5, 6), (7, 8)] + + If the length of *seq* is not divisible by *n* and *strict* is + ``True``, then ``ValueError`` will be raised before the last + slice is yielded. This function will only work for iterables that support slicing. For non-sliceable iterables, see :func:`chunked`. """ - return takewhile(bool, (seq[i: i + n] for i in count(0, n))) + iterator = takewhile(len, (seq[i : i + n] for i in count(0, n))) + if strict: + + def ret(): + for _slice in iterator: + if len(_slice) != n: + raise ValueError("seq is not divisible by n.") + yield _slice + + return iter(ret()) + else: + return iterator -def split_at(iterable, pred): +def split_at(iterable, pred, maxsplit=-1, keep_separator=False): """Yield lists of items from *iterable*, where each list is delimited by - an item where callable *pred* returns ``True``. The lists do not include - the delimiting items. + an item where callable *pred* returns ``True``. >>> list(split_at('abcdcba', lambda x: x == 'b')) [['a'], ['c', 'd', 'c'], ['a']] >>> list(split_at(range(10), lambda n: n % 2 == 1)) [[0], [2], [4], [6], [8], []] + + At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2)) + [[0], [2], [4, 5, 6, 7, 8, 9]] + + By default, the delimiting items are not included in the output. + The include them, set *keep_separator* to ``True``. + + >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True)) + [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']] + """ + if maxsplit == 0: + yield list(iterable) + return + buf = [] - for item in iterable: + it = iter(iterable) + for item in it: if pred(item): yield buf + if keep_separator: + yield [item] + if maxsplit == 1: + yield list(it) + return buf = [] + maxsplit -= 1 else: buf.append(item) yield buf -def split_before(iterable, pred): - """Yield lists of items from *iterable*, where each list starts with an - item where callable *pred* returns ``True``: +def split_before(iterable, pred, maxsplit=-1): + """Yield lists of items from *iterable*, where each list ends just before + an item for which callable *pred* returns ``True``: >>> list(split_before('OneTwo', lambda s: s.isupper())) [['O', 'n', 'e'], ['T', 'w', 'o']] @@ -1040,17 +1287,32 @@ def split_before(iterable, pred): >>> list(split_before(range(10), lambda n: n % 3 == 0)) [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]] """ + if maxsplit == 0: + yield list(iterable) + return + buf = [] - for item in iterable: + it = iter(iterable) + for item in it: if pred(item) and buf: yield buf + if maxsplit == 1: + yield [item] + list(it) + return buf = [] + maxsplit -= 1 buf.append(item) - yield buf + if buf: + yield buf -def split_after(iterable, pred): +def split_after(iterable, pred, maxsplit=-1): """Yield lists of items from *iterable*, where each list ends with an item where callable *pred* returns ``True``: @@ -1060,17 +1322,77 @@ def split_after(iterable, pred): >>> list(split_after(range(10), lambda n: n % 3 == 0)) [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]] + At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2)) + [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]] + """ + if maxsplit == 0: + yield list(iterable) + return + buf = [] - for item in iterable: + it = iter(iterable) + for item in it: buf.append(item) if pred(item) and buf: yield buf + if maxsplit == 1: + yield list(it) + return buf = [] + maxsplit -= 1 if buf: yield buf +def split_when(iterable, pred, maxsplit=-1): + """Split *iterable* into pieces based on the output of *pred*. + *pred* should be a function that takes successive pairs of items and + returns ``True`` if the iterable should be split in between them. + + For example, to find runs of increasing numbers, split the iterable when + element ``i`` is larger than element ``i + 1``: + + >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y)) + [[1, 2, 3, 3], [2, 5], [2, 4], [2]] + + At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], + ... lambda x, y: x > y, maxsplit=2)) + [[1, 2, 3, 3], [2, 5], [2, 4, 2]] + + """ + if maxsplit == 0: + yield list(iterable) + return + + it = iter(iterable) + try: + cur_item = next(it) + except StopIteration: + return + + buf = [cur_item] + for next_item in it: + if pred(cur_item, next_item): + yield buf + if maxsplit == 1: + yield [next_item] + list(it) + return + buf = [] + maxsplit -= 1 + + buf.append(next_item) + cur_item = next_item + + yield buf + + def split_into(iterable, sizes): """Yield a list of sequential items from *iterable* of length 'n' for each integer 'n' in *sizes*. @@ -1134,8 +1456,7 @@ def padded(iterable, fillvalue=None, n=None, next_multiple=False): """ it = iter(iterable) if n is None: - for item in chain(it, repeat(fillvalue)): - yield item + yield from chain(it, repeat(fillvalue)) elif n < 1: raise ValueError('n must be at least 1') else: @@ -1149,6 +1470,34 @@ def padded(iterable, fillvalue=None, n=None, next_multiple=False): yield fillvalue +def repeat_each(iterable, n=2): + """Repeat each element in *iterable* *n* times. + + >>> list(repeat_each('ABC', 3)) + ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C'] + """ + return chain.from_iterable(map(repeat, iterable, repeat(n))) + + +def repeat_last(iterable, default=None): + """After the *iterable* is exhausted, keep yielding its last element. + + >>> list(islice(repeat_last(range(3)), 5)) + [0, 1, 2, 2, 2] + + If the iterable is empty, yield *default* forever:: + + >>> list(islice(repeat_last(range(0), 42), 5)) + [42, 42, 42, 42, 42] + + """ + item = _marker + for item in iterable: + yield item + final = default if item is _marker else item + yield from repeat(final) + + def distribute(n, iterable): """Distribute the items from *iterable* among *n* smaller iterables. @@ -1212,7 +1561,72 @@ def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None): ) -def zip_offset(*iterables, **kwargs): +class UnequalIterablesError(ValueError): + def __init__(self, details=None): + msg = 'Iterables have different lengths' + if details is not None: + msg += (': index 0 has length {}; index {} has length {}').format( + *details + ) + + super().__init__(msg) + + +def _zip_equal_generator(iterables): + for combo in zip_longest(*iterables, fillvalue=_marker): + for val in combo: + if val is _marker: + raise UnequalIterablesError() + yield combo + + +def zip_equal(*iterables): + """``zip`` the input *iterables* together, but raise + ``UnequalIterablesError`` if they aren't all the same length. + + >>> it_1 = range(3) + >>> it_2 = iter('abc') + >>> list(zip_equal(it_1, it_2)) + [(0, 'a'), (1, 'b'), (2, 'c')] + + >>> it_1 = range(3) + >>> it_2 = iter('abcd') + >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + more_itertools.more.UnequalIterablesError: Iterables have different + lengths + + """ + if hexversion >= 0x30A00A6: + warnings.warn( + ( + 'zip_equal will be removed in a future version of ' + 'more-itertools. Use the builtin zip function with ' + 'strict=True instead.' + ), + DeprecationWarning, + ) + # Check whether the iterables are all the same size. + try: + first_size = len(iterables[0]) + for i, it in enumerate(iterables[1:], 1): + size = len(it) + if size != first_size: + break + else: + # If we didn't break out, we can use the built-in zip. + return zip(*iterables) + + # If we did break out, there was a mismatch. + raise UnequalIterablesError(details=(first_size, i, size)) + # If any one of the iterables didn't have a length, start reading + # them until one runs out. + except TypeError: + return _zip_equal_generator(iterables) + + +def zip_offset(*iterables, offsets, longest=False, fillvalue=None): """``zip`` the input *iterables* together, but offset the `i`-th iterable by the `i`-th item in *offsets*. @@ -1233,10 +1647,6 @@ def zip_offset(*iterables, **kwargs): sequence. Specify *fillvalue* to use some other value. """ - offsets = kwargs['offsets'] - longest = kwargs.get('longest', False) - fillvalue = kwargs.get('fillvalue', None) - if len(iterables) != len(offsets): raise ValueError("Number of iterables and offsets didn't match") @@ -1255,7 +1665,7 @@ def zip_offset(*iterables, **kwargs): return zip(*staggered) -def sort_together(iterables, key_list=(0,), reverse=False): +def sort_together(iterables, key_list=(0,), key=None, reverse=False): """Return the input iterables sorted together, with *key_list* as the priority for sorting. All iterables are trimmed to the length of the shortest one. @@ -1277,15 +1687,48 @@ def sort_together(iterables, key_list=(0,), reverse=False): >>> sort_together(iterables, key_list=(1, 2)) [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')] + To sort by a function of the elements of the iterable, pass a *key* + function. Its arguments are the elements of the iterables corresponding to + the key list:: + + >>> names = ('a', 'b', 'c') + >>> lengths = (1, 2, 3) + >>> widths = (5, 2, 1) + >>> def area(length, width): + ... return length * width + >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area) + [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)] + Set *reverse* to ``True`` to sort in descending order. >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True) [(3, 2, 1), ('a', 'b', 'c')] """ - return list(zip(*sorted(zip(*iterables), - key=itemgetter(*key_list), - reverse=reverse))) + if key is None: + # if there is no key function, the key argument to sorted is an + # itemgetter + key_argument = itemgetter(*key_list) + else: + # if there is a key function, call it with the items at the offsets + # specified by the key function as arguments + key_list = list(key_list) + if len(key_list) == 1: + # if key_list contains a single item, pass the item at that offset + # as the only argument to the key function + key_offset = key_list[0] + key_argument = lambda zipped_items: key(zipped_items[key_offset]) + else: + # if key_list contains multiple items, use itemgetter to return a + # tuple of items, which we pass as *args to the key function + get_key_items = itemgetter(*key_list) + key_argument = lambda zipped_items: key( + *get_key_items(zipped_items) + ) + + return list( + zip(*sorted(zip(*iterables), key=key_argument, reverse=reverse)) + ) def unzip(iterable): @@ -1331,6 +1774,7 @@ def unzip(iterable): # which just stops the unzipped iterables # at first length mismatch raise StopIteration + return getter return tuple(map(itemgetter(i), it) for i, it in enumerate(iterables)) @@ -1368,19 +1812,26 @@ def divide(n, iterable): if n < 1: raise ValueError('n must be at least 1') - seq = tuple(iterable) + try: + iterable[:0] + except TypeError: + seq = tuple(iterable) + else: + seq = iterable + q, r = divmod(len(seq), n) ret = [] - for i in range(n): - start = (i * q) + (i if i < r else r) - stop = ((i + 1) * q) + (i + 1 if i + 1 < r else r) + stop = 0 + for i in range(1, n + 1): + start = stop + stop += q + 1 if i <= r else q ret.append(iter(seq[start:stop])) return ret -def always_iterable(obj, base_type=(text_type, binary_type)): +def always_iterable(obj, base_type=(str, bytes)): """If *obj* is iterable, return an iterator over its items:: >>> obj = (1, 2, 3) @@ -1472,21 +1923,23 @@ def adjacent(predicate, iterable, distance=1): return zip(adjacent_to_selected, i2) -def groupby_transform(iterable, keyfunc=None, valuefunc=None): - """An extension of :func:`itertools.groupby` that transforms the values of - *iterable* after grouping them. - *keyfunc* is a function used to compute a grouping key for each item. - *valuefunc* is a function for transforming the items after grouping. +def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None): + """An extension of :func:`itertools.groupby` that can apply transformations + to the grouped data. - >>> iterable = 'AaaABbBCcA' - >>> keyfunc = lambda x: x.upper() - >>> valuefunc = lambda x: x.lower() - >>> grouper = groupby_transform(iterable, keyfunc, valuefunc) - >>> [(k, ''.join(g)) for k, g in grouper] - [('A', 'aaaa'), ('B', 'bbb'), ('C', 'cc'), ('A', 'a')] + * *keyfunc* is a function computing a key value for each item in *iterable* + * *valuefunc* is a function that transforms the individual items from + *iterable* after grouping + * *reducefunc* is a function that transforms each group of items - *keyfunc* and *valuefunc* default to identity functions if they are not - specified. + >>> iterable = 'aAAbBBcCC' + >>> keyfunc = lambda k: k.upper() + >>> valuefunc = lambda v: v.lower() + >>> reducefunc = lambda g: ''.join(g) + >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc)) + [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')] + + Each optional argument defaults to an identity function if not specified. :func:`groupby_transform` is useful when grouping elements of an iterable using a separate iterable as the key. To do this, :func:`zip` the iterables @@ -1506,11 +1959,16 @@ def groupby_transform(iterable, keyfunc=None, valuefunc=None): duplicate groups, you should sort the iterable by the key function. """ - valuefunc = (lambda x: x) if valuefunc is None else valuefunc - return ((k, map(valuefunc, g)) for k, g in groupby(iterable, keyfunc)) + ret = groupby(iterable, keyfunc) + if valuefunc: + ret = ((k, map(valuefunc, g)) for k, g in ret) + if reducefunc: + ret = ((k, reducefunc(g)) for k, g in ret) + + return ret -def numeric_range(*args): +class numeric_range(abc.Sequence, abc.Hashable): """An extension of the built-in ``range()`` function whose arguments can be any orderable numeric type. @@ -1547,28 +2005,184 @@ def numeric_range(*args): Be aware of the limitations of floating point numbers; the representation of the yielded numbers may be surprising. - """ - argc = len(args) - if argc == 1: - stop, = args - start = type(stop)(0) - step = 1 - elif argc == 2: - start, stop = args - step = 1 - elif argc == 3: - start, stop, step = args - else: - err_msg = 'numeric_range takes at most 3 arguments, got {}' - raise TypeError(err_msg.format(argc)) + ``datetime.datetime`` objects can be used for *start* and *stop*, if *step* + is a ``datetime.timedelta`` object: - values = (start + (step * n) for n in count()) - if step > 0: - return takewhile(partial(gt, stop), values) - elif step < 0: - return takewhile(partial(lt, stop), values) - else: - raise ValueError('numeric_range arg 3 must not be zero') + >>> import datetime + >>> start = datetime.datetime(2019, 1, 1) + >>> stop = datetime.datetime(2019, 1, 3) + >>> step = datetime.timedelta(days=1) + >>> items = iter(numeric_range(start, stop, step)) + >>> next(items) + datetime.datetime(2019, 1, 1, 0, 0) + >>> next(items) + datetime.datetime(2019, 1, 2, 0, 0) + + """ + + _EMPTY_HASH = hash(range(0, 0)) + + def __init__(self, *args): + argc = len(args) + if argc == 1: + (self._stop,) = args + self._start = type(self._stop)(0) + self._step = type(self._stop - self._start)(1) + elif argc == 2: + self._start, self._stop = args + self._step = type(self._stop - self._start)(1) + elif argc == 3: + self._start, self._stop, self._step = args + elif argc == 0: + raise TypeError( + 'numeric_range expected at least ' + '1 argument, got {}'.format(argc) + ) + else: + raise TypeError( + 'numeric_range expected at most ' + '3 arguments, got {}'.format(argc) + ) + + self._zero = type(self._step)(0) + if self._step == self._zero: + raise ValueError('numeric_range() arg 3 must not be zero') + self._growing = self._step > self._zero + self._init_len() + + def __bool__(self): + if self._growing: + return self._start < self._stop + else: + return self._start > self._stop + + def __contains__(self, elem): + if self._growing: + if self._start <= elem < self._stop: + return (elem - self._start) % self._step == self._zero + else: + if self._start >= elem > self._stop: + return (self._start - elem) % (-self._step) == self._zero + + return False + + def __eq__(self, other): + if isinstance(other, numeric_range): + empty_self = not bool(self) + empty_other = not bool(other) + if empty_self or empty_other: + return empty_self and empty_other # True if both empty + else: + return ( + self._start == other._start + and self._step == other._step + and self._get_by_index(-1) == other._get_by_index(-1) + ) + else: + return False + + def __getitem__(self, key): + if isinstance(key, int): + return self._get_by_index(key) + elif isinstance(key, slice): + step = self._step if key.step is None else key.step * self._step + + if key.start is None or key.start <= -self._len: + start = self._start + elif key.start >= self._len: + start = self._stop + else: # -self._len < key.start < self._len + start = self._get_by_index(key.start) + + if key.stop is None or key.stop >= self._len: + stop = self._stop + elif key.stop <= -self._len: + stop = self._start + else: # -self._len < key.stop < self._len + stop = self._get_by_index(key.stop) + + return numeric_range(start, stop, step) + else: + raise TypeError( + 'numeric range indices must be ' + 'integers or slices, not {}'.format(type(key).__name__) + ) + + def __hash__(self): + if self: + return hash((self._start, self._get_by_index(-1), self._step)) + else: + return self._EMPTY_HASH + + def __iter__(self): + values = (self._start + (n * self._step) for n in count()) + if self._growing: + return takewhile(partial(gt, self._stop), values) + else: + return takewhile(partial(lt, self._stop), values) + + def __len__(self): + return self._len + + def _init_len(self): + if self._growing: + start = self._start + stop = self._stop + step = self._step + else: + start = self._stop + stop = self._start + step = -self._step + distance = stop - start + if distance <= self._zero: + self._len = 0 + else: # distance > 0 and step > 0: regular euclidean division + q, r = divmod(distance, step) + self._len = int(q) + int(r != self._zero) + + def __reduce__(self): + return numeric_range, (self._start, self._stop, self._step) + + def __repr__(self): + if self._step == 1: + return "numeric_range({}, {})".format( + repr(self._start), repr(self._stop) + ) + else: + return "numeric_range({}, {}, {})".format( + repr(self._start), repr(self._stop), repr(self._step) + ) + + def __reversed__(self): + return iter( + numeric_range( + self._get_by_index(-1), self._start - self._step, -self._step + ) + ) + + def count(self, value): + return int(value in self) + + def index(self, value): + if self._growing: + if self._start <= value < self._stop: + q, r = divmod(value - self._start, self._step) + if r == self._zero: + return int(q) + else: + if self._start >= value > self._stop: + q, r = divmod(self._start - value, -self._step) + if r == self._zero: + return int(q) + + raise ValueError("{} is not in numeric range".format(value)) + + def _get_by_index(self, i): + if i < 0: + i += self._len + if i < 0 or i >= self._len: + raise IndexError("numeric range object index out of range") + return self._start + i * self._step def count_cycle(iterable, n=None): @@ -1587,6 +2201,43 @@ def count_cycle(iterable, n=None): return ((i, item) for i in counter for item in iterable) +def mark_ends(iterable): + """Yield 3-tuples of the form ``(is_first, is_last, item)``. + + >>> list(mark_ends('ABC')) + [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')] + + Use this when looping over an iterable to take special action on its first + and/or last items: + + >>> iterable = ['Header', 100, 200, 'Footer'] + >>> total = 0 + >>> for is_first, is_last, item in mark_ends(iterable): + ... if is_first: + ... continue # Skip the header + ... if is_last: + ... continue # Skip the footer + ... total += item + >>> print(total) + 300 + """ + it = iter(iterable) + + try: + b = next(it) + except StopIteration: + return + + try: + for i in count(): + a = b + b = next(it) + yield i == 0, False, a + + except StopIteration: + yield i == 0, True, a + + def locate(iterable, pred=bool, window_size=None): """Yield the index of each item in *iterable* for which *pred* returns ``True``. @@ -1669,13 +2320,13 @@ def rstrip(iterable, pred): """ cache = [] cache_append = cache.append + cache_clear = cache.clear for x in iterable: if pred(x): cache_append(x) else: - for y in cache: - yield y - del cache[:] + yield from cache + cache_clear() yield x @@ -1696,7 +2347,7 @@ def strip(iterable, pred): return rstrip(lstrip(iterable, pred), pred) -def islice_extended(iterable, *args): +class islice_extended: """An extension of :func:`itertools.islice` that supports negative values for *stop*, *start*, and *step*. @@ -1713,20 +2364,46 @@ def islice_extended(iterable, *args): >>> list(islice_extended(count(), 110, 99, -2)) [110, 108, 106, 104, 102, 100] + You can also use slice notation directly: + + >>> iterable = map(str, count()) + >>> it = islice_extended(iterable)[10:20:2] + >>> list(it) + ['10', '12', '14', '16', '18'] + """ - s = slice(*args) + + def __init__(self, iterable, *args): + it = iter(iterable) + if args: + self._iterable = _islice_helper(it, slice(*args)) + else: + self._iterable = it + + def __iter__(self): + return self + + def __next__(self): + return next(self._iterable) + + def __getitem__(self, key): + if isinstance(key, slice): + return islice_extended(_islice_helper(self._iterable, key)) + + raise TypeError('islice_extended.__getitem__ argument must be a slice') + + +def _islice_helper(it, s): start = s.start stop = s.stop if s.step == 0: raise ValueError('step argument must be a non-zero integer or None.') step = s.step or 1 - it = iter(iterable) - if step > 0: start = 0 if (start is None) else start - if (start < 0): + if start < 0: # Consume all but the last -start items cache = deque(enumerate(it, 1), maxlen=-start) len_iter = cache[-1][0] if cache else 0 @@ -1764,8 +2441,7 @@ def islice_extended(iterable, *args): cache.append(item) else: # When both start and stop are positive we have the normal case - for item in islice(it, start, stop, step): - yield item + yield from islice(it, start, stop, step) else: start = -1 if (start is None) else start @@ -1810,8 +2486,7 @@ def islice_extended(iterable, *args): cache = list(islice(it, n)) - for item in cache[i::step]: - yield item + yield from cache[i::step] def always_reversible(iterable): @@ -1862,6 +2537,17 @@ def consecutive_groups(iterable, ordering=lambda x: x): ['i'] ['l', 'm', 'n', 'o', 'p'] + Each group of consecutive items is an iterator that shares it source with + *iterable*. When an an output group is advanced, the previous group is + no longer available unless its elements are copied (e.g., into a ``list``). + + >>> iterable = [1, 2, 11, 12, 21, 22] + >>> saved_groups = [] + >>> for group in consecutive_groups(iterable): + ... saved_groups.append(list(group)) # Copy group elements + >>> saved_groups + [[1, 2], [11, 12], [21, 22]] + """ for k, g in groupby( enumerate(iterable), key=lambda x: x[0] - ordering(x[1]) @@ -1869,42 +2555,46 @@ def consecutive_groups(iterable, ordering=lambda x: x): yield map(itemgetter(1), g) -def difference(iterable, func=sub): - """By default, compute the first difference of *iterable* using - :func:`operator.sub`. +def difference(iterable, func=sub, *, initial=None): + """This function is the inverse of :func:`itertools.accumulate`. By default + it will compute the first difference of *iterable* using + :func:`operator.sub`: - >>> iterable = [0, 1, 3, 6, 10] + >>> from itertools import accumulate + >>> iterable = accumulate([0, 1, 2, 3, 4]) # produces 0, 1, 3, 6, 10 >>> list(difference(iterable)) [0, 1, 2, 3, 4] - This is the opposite of :func:`accumulate`'s default behavior: - - >>> from more_itertools import accumulate - >>> iterable = [0, 1, 2, 3, 4] - >>> list(accumulate(iterable)) - [0, 1, 3, 6, 10] - >>> list(difference(accumulate(iterable))) - [0, 1, 2, 3, 4] - - By default *func* is :func:`operator.sub`, but other functions can be + *func* defaults to :func:`operator.sub`, but other functions can be specified. They will be applied as follows:: A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ... For example, to do progressive division: - >>> iterable = [1, 2, 6, 24, 120] # Factorial sequence + >>> iterable = [1, 2, 6, 24, 120] >>> func = lambda x, y: x // y >>> list(difference(iterable, func)) [1, 2, 3, 4, 5] + If the *initial* keyword is set, the first element will be skipped when + computing successive differences. + + >>> it = [10, 11, 13, 16] # from accumulate([1, 2, 3], initial=10) + >>> list(difference(it, initial=10)) + [1, 2, 3] + """ a, b = tee(iterable) try: - item = next(b) + first = [next(b)] except StopIteration: return iter([]) - return chain([item], map(lambda x: func(x[1], x[0]), zip(a, b))) + + if initial is not None: + first = [] + + return chain(first, starmap(func, zip(b, a))) class SequenceView(Sequence): @@ -1936,6 +2626,7 @@ class SequenceView(Sequence): require (much) extra storage. """ + def __init__(self, target): if not isinstance(target, Sequence): raise TypeError @@ -1951,7 +2642,7 @@ class SequenceView(Sequence): return '{}({})'.format(self.__class__.__name__, repr(self._target)) -class seekable(object): +class seekable: """Wrap an iterator to allow for seeking backward and forward. This progressively caches the items in the source iterable so they can be re-visited. @@ -1984,8 +2675,26 @@ class seekable(object): >>> next(it), next(it), next(it) ('0', '1', '2') - The cache grows as the source iterable progresses, so beware of wrapping - very large or infinite iterables. + Call :meth:`peek` to look ahead one item without advancing the iterator: + + >>> it = seekable('1234') + >>> it.peek() + '1' + >>> list(it) + ['1', '2', '3', '4'] + >>> it.peek(default='empty') + 'empty' + + Before the iterator is at its end, calling :func:`bool` on it will return + ``True``. After it will return ``False``: + + >>> it = seekable('5678') + >>> bool(it) + True + >>> list(it) + ['5', '6', '7', '8'] + >>> bool(it) + False You may view the contents of the cache with the :meth:`elements` method. That returns a :class:`SequenceView`, a view that updates automatically: @@ -2001,11 +2710,30 @@ class seekable(object): >>> elements SequenceView(['0', '1', '2', '3']) + By default, the cache grows as the source iterable progresses, so beware of + wrapping very large or infinite iterables. Supply *maxlen* to limit the + size of the cache (this of course limits how far back you can seek). + + >>> from itertools import count + >>> it = seekable((str(n) for n in count()), maxlen=2) + >>> next(it), next(it), next(it), next(it) + ('0', '1', '2', '3') + >>> list(it.elements()) + ['2', '3'] + >>> it.seek(0) + >>> next(it), next(it), next(it), next(it) + ('2', '3', '4', '5') + >>> next(it) + '6' + """ - def __init__(self, iterable): + def __init__(self, iterable, maxlen=None): self._source = iter(iterable) - self._cache = [] + if maxlen is None: + self._cache = [] + else: + self._cache = deque([], maxlen) self._index = None def __iter__(self): @@ -2025,7 +2753,24 @@ class seekable(object): self._cache.append(item) return item - next = __next__ + def __bool__(self): + try: + self.peek() + except StopIteration: + return False + return True + + def peek(self, default=_marker): + try: + peeked = next(self) + except StopIteration: + if default is _marker: + raise + return default + if self._index is None: + self._index = len(self._cache) + self._index -= 1 + return peeked def elements(self): return SequenceView(self._cache) @@ -2037,7 +2782,7 @@ class seekable(object): consume(self, remainder) -class run_length(object): +class run_length: """ :func:`run_length.encode` compresses an iterable with run-length encoding. It yields groups of repeated items with the count of how many times they @@ -2087,8 +2832,8 @@ def exactly_n(iterable, n, predicate=bool): def circular_shifts(iterable): """Return a list of circular shifts of *iterable*. - >>> circular_shifts(range(4)) - [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)] + >>> circular_shifts(range(4)) + [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)] """ lst = list(iterable) return take(len(lst), windowed(cycle(lst), len(lst))) @@ -2262,9 +3007,7 @@ def rlocate(iterable, pred=bool, window_size=None): if window_size is None: try: len_iter = len(iterable) - return ( - len_iter - i - 1 for i in locate(reversed(iterable), pred) - ) + return (len_iter - i - 1 for i in locate(reversed(iterable), pred)) except TypeError: pass @@ -2322,8 +3065,7 @@ def replace(iterable, pred, substitutes, count=None, window_size=1): if pred(*w): if (count is None) or (n < count): n += 1 - for s in substitutes: - yield s + yield from substitutes consume(windows, window_size - 1) continue @@ -2331,3 +3073,982 @@ def replace(iterable, pred, substitutes, count=None, window_size=1): # yield the first item from the window. if w and (w[0] is not _marker): yield w[0] + + +def partitions(iterable): + """Yield all possible order-preserving partitions of *iterable*. + + >>> iterable = 'abc' + >>> for part in partitions(iterable): + ... print([''.join(p) for p in part]) + ['abc'] + ['a', 'bc'] + ['ab', 'c'] + ['a', 'b', 'c'] + + This is unrelated to :func:`partition`. + + """ + sequence = list(iterable) + n = len(sequence) + for i in powerset(range(1, n)): + yield [sequence[i:j] for i, j in zip((0,) + i, i + (n,))] + + +def set_partitions(iterable, k=None): + """ + Yield the set partitions of *iterable* into *k* parts. Set partitions are + not order-preserving. + + >>> iterable = 'abc' + >>> for part in set_partitions(iterable, 2): + ... print([''.join(p) for p in part]) + ['a', 'bc'] + ['ab', 'c'] + ['b', 'ac'] + + + If *k* is not given, every set partition is generated. + + >>> iterable = 'abc' + >>> for part in set_partitions(iterable): + ... print([''.join(p) for p in part]) + ['abc'] + ['a', 'bc'] + ['ab', 'c'] + ['b', 'ac'] + ['a', 'b', 'c'] + + """ + L = list(iterable) + n = len(L) + if k is not None: + if k < 1: + raise ValueError( + "Can't partition in a negative or zero number of groups" + ) + elif k > n: + return + + def set_partitions_helper(L, k): + n = len(L) + if k == 1: + yield [L] + elif n == k: + yield [[s] for s in L] + else: + e, *M = L + for p in set_partitions_helper(M, k - 1): + yield [[e], *p] + for p in set_partitions_helper(M, k): + for i in range(len(p)): + yield p[:i] + [[e] + p[i]] + p[i + 1 :] + + if k is None: + for k in range(1, n + 1): + yield from set_partitions_helper(L, k) + else: + yield from set_partitions_helper(L, k) + + +class time_limited: + """ + Yield items from *iterable* until *limit_seconds* have passed. + If the time limit expires before all items have been yielded, the + ``timed_out`` parameter will be set to ``True``. + + >>> from time import sleep + >>> def generator(): + ... yield 1 + ... yield 2 + ... sleep(0.2) + ... yield 3 + >>> iterable = time_limited(0.1, generator()) + >>> list(iterable) + [1, 2] + >>> iterable.timed_out + True + + Note that the time is checked before each item is yielded, and iteration + stops if the time elapsed is greater than *limit_seconds*. If your time + limit is 1 second, but it takes 2 seconds to generate the first item from + the iterable, the function will run for 2 seconds and not yield anything. + + """ + + def __init__(self, limit_seconds, iterable): + if limit_seconds < 0: + raise ValueError('limit_seconds must be positive') + self.limit_seconds = limit_seconds + self._iterable = iter(iterable) + self._start_time = monotonic() + self.timed_out = False + + def __iter__(self): + return self + + def __next__(self): + item = next(self._iterable) + if monotonic() - self._start_time > self.limit_seconds: + self.timed_out = True + raise StopIteration + + return item + + +def only(iterable, default=None, too_long=None): + """If *iterable* has only one item, return it. + If it has zero items, return *default*. + If it has more than one item, raise the exception given by *too_long*, + which is ``ValueError`` by default. + + >>> only([], default='missing') + 'missing' + >>> only([1]) + 1 + >>> only([1, 2]) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: Expected exactly one item in iterable, but got 1, 2, + and perhaps more.' + >>> only([1, 2], too_long=TypeError) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError + + Note that :func:`only` attempts to advance *iterable* twice to ensure there + is only one item. See :func:`spy` or :func:`peekable` to check + iterable contents less destructively. + """ + it = iter(iterable) + first_value = next(it, default) + + try: + second_value = next(it) + except StopIteration: + pass + else: + msg = ( + 'Expected exactly one item in iterable, but got {!r}, {!r}, ' + 'and perhaps more.'.format(first_value, second_value) + ) + raise too_long or ValueError(msg) + + return first_value + + +def ichunked(iterable, n): + """Break *iterable* into sub-iterables with *n* elements each. + :func:`ichunked` is like :func:`chunked`, but it yields iterables + instead of lists. + + If the sub-iterables are read in order, the elements of *iterable* + won't be stored in memory. + If they are read out of order, :func:`itertools.tee` is used to cache + elements as necessary. + + >>> from itertools import count + >>> all_chunks = ichunked(count(), 4) + >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks) + >>> list(c_2) # c_1's elements have been cached; c_3's haven't been + [4, 5, 6, 7] + >>> list(c_1) + [0, 1, 2, 3] + >>> list(c_3) + [8, 9, 10, 11] + + """ + source = iter(iterable) + + while True: + # Check to see whether we're at the end of the source iterable + item = next(source, _marker) + if item is _marker: + return + + # Clone the source and yield an n-length slice + source, it = tee(chain([item], source)) + yield islice(it, n) + + # Advance the source iterable + consume(source, n) + + +def distinct_combinations(iterable, r): + """Yield the distinct combinations of *r* items taken from *iterable*. + + >>> list(distinct_combinations([0, 0, 1], 2)) + [(0, 0), (0, 1)] + + Equivalent to ``set(combinations(iterable))``, except duplicates are not + generated and thrown away. For larger input sequences this is much more + efficient. + + """ + if r < 0: + raise ValueError('r must be non-negative') + elif r == 0: + yield () + return + pool = tuple(iterable) + generators = [unique_everseen(enumerate(pool), key=itemgetter(1))] + current_combo = [None] * r + level = 0 + while generators: + try: + cur_idx, p = next(generators[-1]) + except StopIteration: + generators.pop() + level -= 1 + continue + current_combo[level] = p + if level + 1 == r: + yield tuple(current_combo) + else: + generators.append( + unique_everseen( + enumerate(pool[cur_idx + 1 :], cur_idx + 1), + key=itemgetter(1), + ) + ) + level += 1 + + +def filter_except(validator, iterable, *exceptions): + """Yield the items from *iterable* for which the *validator* function does + not raise one of the specified *exceptions*. + + *validator* is called for each item in *iterable*. + It should be a function that accepts one argument and raises an exception + if that item is not valid. + + >>> iterable = ['1', '2', 'three', '4', None] + >>> list(filter_except(int, iterable, ValueError, TypeError)) + ['1', '2', '4'] + + If an exception other than one given by *exceptions* is raised by + *validator*, it is raised like normal. + """ + for item in iterable: + try: + validator(item) + except exceptions: + pass + else: + yield item + + +def map_except(function, iterable, *exceptions): + """Transform each item from *iterable* with *function* and yield the + result, unless *function* raises one of the specified *exceptions*. + + *function* is called to transform each item in *iterable*. + It should accept one argument. + + >>> iterable = ['1', '2', 'three', '4', None] + >>> list(map_except(int, iterable, ValueError, TypeError)) + [1, 2, 4] + + If an exception other than one given by *exceptions* is raised by + *function*, it is raised like normal. + """ + for item in iterable: + try: + yield function(item) + except exceptions: + pass + + +def map_if(iterable, pred, func, func_else=lambda x: x): + """Evaluate each item from *iterable* using *pred*. If the result is + equivalent to ``True``, transform the item with *func* and yield it. + Otherwise, transform the item with *func_else* and yield it. + + *pred*, *func*, and *func_else* should each be functions that accept + one argument. By default, *func_else* is the identity function. + + >>> from math import sqrt + >>> iterable = list(range(-5, 5)) + >>> iterable + [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4] + >>> list(map_if(iterable, lambda x: x > 3, lambda x: 'toobig')) + [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig'] + >>> list(map_if(iterable, lambda x: x >= 0, + ... lambda x: f'{sqrt(x):.2f}', lambda x: None)) + [None, None, None, None, None, '0.00', '1.00', '1.41', '1.73', '2.00'] + """ + for item in iterable: + yield func(item) if pred(item) else func_else(item) + + +def _sample_unweighted(iterable, k): + # Implementation of "Algorithm L" from the 1994 paper by Kim-Hung Li: + # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))". + + # Fill up the reservoir (collection of samples) with the first `k` samples + reservoir = take(k, iterable) + + # Generate random number that's the largest in a sample of k U(0,1) numbers + # Largest order statistic: https://en.wikipedia.org/wiki/Order_statistic + W = exp(log(random()) / k) + + # The number of elements to skip before changing the reservoir is a random + # number with a geometric distribution. Sample it using random() and logs. + next_index = k + floor(log(random()) / log(1 - W)) + + for index, element in enumerate(iterable, k): + + if index == next_index: + reservoir[randrange(k)] = element + # The new W is the largest in a sample of k U(0, `old_W`) numbers + W *= exp(log(random()) / k) + next_index += floor(log(random()) / log(1 - W)) + 1 + + return reservoir + + +def _sample_weighted(iterable, k, weights): + # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. : + # "Weighted random sampling with a reservoir". + + # Log-transform for numerical stability for weights that are small/large + weight_keys = (log(random()) / weight for weight in weights) + + # Fill up the reservoir (collection of samples) with the first `k` + # weight-keys and elements, then heapify the list. + reservoir = take(k, zip(weight_keys, iterable)) + heapify(reservoir) + + # The number of jumps before changing the reservoir is a random variable + # with an exponential distribution. Sample it using random() and logs. + smallest_weight_key, _ = reservoir[0] + weights_to_skip = log(random()) / smallest_weight_key + + for weight, element in zip(weights, iterable): + if weight >= weights_to_skip: + # The notation here is consistent with the paper, but we store + # the weight-keys in log-space for better numerical stability. + smallest_weight_key, _ = reservoir[0] + t_w = exp(weight * smallest_weight_key) + r_2 = uniform(t_w, 1) # generate U(t_w, 1) + weight_key = log(r_2) / weight + heapreplace(reservoir, (weight_key, element)) + smallest_weight_key, _ = reservoir[0] + weights_to_skip = log(random()) / smallest_weight_key + else: + weights_to_skip -= weight + + # Equivalent to [element for weight_key, element in sorted(reservoir)] + return [heappop(reservoir)[1] for _ in range(k)] + + +def sample(iterable, k, weights=None): + """Return a *k*-length list of elements chosen (without replacement) + from the *iterable*. Like :func:`random.sample`, but works on iterables + of unknown length. + + >>> iterable = range(100) + >>> sample(iterable, 5) # doctest: +SKIP + [81, 60, 96, 16, 4] + + An iterable with *weights* may also be given: + + >>> iterable = range(100) + >>> weights = (i * i + 1 for i in range(100)) + >>> sampled = sample(iterable, 5, weights=weights) # doctest: +SKIP + [79, 67, 74, 66, 78] + + The algorithm can also be used to generate weighted random permutations. + The relative weight of each item determines the probability that it + appears late in the permutation. + + >>> data = "abcdefgh" + >>> weights = range(1, len(data) + 1) + >>> sample(data, k=len(data), weights=weights) # doctest: +SKIP + ['c', 'a', 'b', 'e', 'g', 'd', 'h', 'f'] + """ + if k == 0: + return [] + + iterable = iter(iterable) + if weights is None: + return _sample_unweighted(iterable, k) + else: + weights = iter(weights) + return _sample_weighted(iterable, k, weights) + + +def is_sorted(iterable, key=None, reverse=False): + """Returns ``True`` if the items of iterable are in sorted order, and + ``False`` otherwise. *key* and *reverse* have the same meaning that they do + in the built-in :func:`sorted` function. + + >>> is_sorted(['1', '2', '3', '4', '5'], key=int) + True + >>> is_sorted([5, 4, 3, 1, 2], reverse=True) + False + + The function returns ``False`` after encountering the first out-of-order + item. If there are no out-of-order items, the iterable is exhausted. + """ + + compare = lt if reverse else gt + it = iterable if (key is None) else map(key, iterable) + return not any(starmap(compare, pairwise(it))) + + +class AbortThread(BaseException): + pass + + +class callback_iter: + """Convert a function that uses callbacks to an iterator. + + Let *func* be a function that takes a `callback` keyword argument. + For example: + + >>> def func(callback=None): + ... for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]: + ... if callback: + ... callback(i, c) + ... return 4 + + + Use ``with callback_iter(func)`` to get an iterator over the parameters + that are delivered to the callback. + + >>> with callback_iter(func) as it: + ... for args, kwargs in it: + ... print(args) + (1, 'a') + (2, 'b') + (3, 'c') + + The function will be called in a background thread. The ``done`` property + indicates whether it has completed execution. + + >>> it.done + True + + If it completes successfully, its return value will be available + in the ``result`` property. + + >>> it.result + 4 + + Notes: + + * If the function uses some keyword argument besides ``callback``, supply + *callback_kwd*. + * If it finished executing, but raised an exception, accessing the + ``result`` property will raise the same exception. + * If it hasn't finished executing, accessing the ``result`` + property from within the ``with`` block will raise ``RuntimeError``. + * If it hasn't finished executing, accessing the ``result`` property from + outside the ``with`` block will raise a + ``more_itertools.AbortThread`` exception. + * Provide *wait_seconds* to adjust how frequently the it is polled for + output. + + """ + + def __init__(self, func, callback_kwd='callback', wait_seconds=0.1): + self._func = func + self._callback_kwd = callback_kwd + self._aborted = False + self._future = None + self._wait_seconds = wait_seconds + self._executor = ThreadPoolExecutor(max_workers=1) + self._iterator = self._reader() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._aborted = True + self._executor.shutdown() + + def __iter__(self): + return self + + def __next__(self): + return next(self._iterator) + + @property + def done(self): + if self._future is None: + return False + return self._future.done() + + @property + def result(self): + if not self.done: + raise RuntimeError('Function has not yet completed') + + return self._future.result() + + def _reader(self): + q = Queue() + + def callback(*args, **kwargs): + if self._aborted: + raise AbortThread('canceled by user') + + q.put((args, kwargs)) + + self._future = self._executor.submit( + self._func, **{self._callback_kwd: callback} + ) + + while True: + try: + item = q.get(timeout=self._wait_seconds) + except Empty: + pass + else: + q.task_done() + yield item + + if self._future.done(): + break + + remaining = [] + while True: + try: + item = q.get_nowait() + except Empty: + break + else: + q.task_done() + remaining.append(item) + q.join() + yield from remaining + + +def windowed_complete(iterable, n): + """ + Yield ``(beginning, middle, end)`` tuples, where: + + * Each ``middle`` has *n* items from *iterable* + * Each ``beginning`` has the items before the ones in ``middle`` + * Each ``end`` has the items after the ones in ``middle`` + + >>> iterable = range(7) + >>> n = 3 + >>> for beginning, middle, end in windowed_complete(iterable, n): + ... print(beginning, middle, end) + () (0, 1, 2) (3, 4, 5, 6) + (0,) (1, 2, 3) (4, 5, 6) + (0, 1) (2, 3, 4) (5, 6) + (0, 1, 2) (3, 4, 5) (6,) + (0, 1, 2, 3) (4, 5, 6) () + + Note that *n* must be at least 0 and most equal to the length of + *iterable*. + + This function will exhaust the iterable and may require significant + storage. + """ + if n < 0: + raise ValueError('n must be >= 0') + + seq = tuple(iterable) + size = len(seq) + + if n > size: + raise ValueError('n must be <= len(seq)') + + for i in range(size - n + 1): + beginning = seq[:i] + middle = seq[i : i + n] + end = seq[i + n :] + yield beginning, middle, end + + +def all_unique(iterable, key=None): + """ + Returns ``True`` if all the elements of *iterable* are unique (no two + elements are equal). + + >>> all_unique('ABCB') + False + + If a *key* function is specified, it will be used to make comparisons. + + >>> all_unique('ABCb') + True + >>> all_unique('ABCb', str.lower) + False + + The function returns as soon as the first non-unique element is + encountered. Iterables with a mix of hashable and unhashable items can + be used, but the function will be slower for unhashable items. + """ + seenset = set() + seenset_add = seenset.add + seenlist = [] + seenlist_add = seenlist.append + for element in map(key, iterable) if key else iterable: + try: + if element in seenset: + return False + seenset_add(element) + except TypeError: + if element in seenlist: + return False + seenlist_add(element) + return True + + +def nth_product(index, *args): + """Equivalent to ``list(product(*args))[index]``. + + The products of *args* can be ordered lexicographically. + :func:`nth_product` computes the product at sort position *index* without + computing the previous products. + + >>> nth_product(8, range(2), range(2), range(2), range(2)) + (1, 0, 0, 0) + + ``IndexError`` will be raised if the given *index* is invalid. + """ + pools = list(map(tuple, reversed(args))) + ns = list(map(len, pools)) + + c = reduce(mul, ns) + + if index < 0: + index += c + + if not 0 <= index < c: + raise IndexError + + result = [] + for pool, n in zip(pools, ns): + result.append(pool[index % n]) + index //= n + + return tuple(reversed(result)) + + +def nth_permutation(iterable, r, index): + """Equivalent to ``list(permutations(iterable, r))[index]``` + + The subsequences of *iterable* that are of length *r* where order is + important can be ordered lexicographically. :func:`nth_permutation` + computes the subsequence at sort position *index* directly, without + computing the previous subsequences. + + >>> nth_permutation('ghijk', 2, 5) + ('h', 'i') + + ``ValueError`` will be raised If *r* is negative or greater than the length + of *iterable*. + ``IndexError`` will be raised if the given *index* is invalid. + """ + pool = list(iterable) + n = len(pool) + + if r is None or r == n: + r, c = n, factorial(n) + elif not 0 <= r < n: + raise ValueError + else: + c = factorial(n) // factorial(n - r) + + if index < 0: + index += c + + if not 0 <= index < c: + raise IndexError + + if c == 0: + return tuple() + + result = [0] * r + q = index * factorial(n) // c if r < n else index + for d in range(1, n + 1): + q, i = divmod(q, d) + if 0 <= n - d < r: + result[n - d] = i + if q == 0: + break + + return tuple(map(pool.pop, result)) + + +def value_chain(*args): + """Yield all arguments passed to the function in the same order in which + they were passed. If an argument itself is iterable then iterate over its + values. + + >>> list(value_chain(1, 2, 3, [4, 5, 6])) + [1, 2, 3, 4, 5, 6] + + Binary and text strings are not considered iterable and are emitted + as-is: + + >>> list(value_chain('12', '34', ['56', '78'])) + ['12', '34', '56', '78'] + + + Multiple levels of nesting are not flattened. + + """ + for value in args: + if isinstance(value, (str, bytes)): + yield value + continue + try: + yield from value + except TypeError: + yield value + + +def product_index(element, *args): + """Equivalent to ``list(product(*args)).index(element)`` + + The products of *args* can be ordered lexicographically. + :func:`product_index` computes the first index of *element* without + computing the previous products. + + >>> product_index([8, 2], range(10), range(5)) + 42 + + ``ValueError`` will be raised if the given *element* isn't in the product + of *args*. + """ + index = 0 + + for x, pool in zip_longest(element, args, fillvalue=_marker): + if x is _marker or pool is _marker: + raise ValueError('element is not a product of args') + + pool = tuple(pool) + index = index * len(pool) + pool.index(x) + + return index + + +def combination_index(element, iterable): + """Equivalent to ``list(combinations(iterable, r)).index(element)`` + + The subsequences of *iterable* that are of length *r* can be ordered + lexicographically. :func:`combination_index` computes the index of the + first *element*, without computing the previous combinations. + + >>> combination_index('adf', 'abcdefg') + 10 + + ``ValueError`` will be raised if the given *element* isn't one of the + combinations of *iterable*. + """ + element = enumerate(element) + k, y = next(element, (None, None)) + if k is None: + return 0 + + indexes = [] + pool = enumerate(iterable) + for n, x in pool: + if x == y: + indexes.append(n) + tmp, y = next(element, (None, None)) + if tmp is None: + break + else: + k = tmp + else: + raise ValueError('element is not a combination of iterable') + + n, _ = last(pool, default=(n, None)) + + # Python versiosn below 3.8 don't have math.comb + index = 1 + for i, j in enumerate(reversed(indexes), start=1): + j = n - j + if i <= j: + index += factorial(j) // (factorial(i) * factorial(j - i)) + + return factorial(n + 1) // (factorial(k + 1) * factorial(n - k)) - index + + +def permutation_index(element, iterable): + """Equivalent to ``list(permutations(iterable, r)).index(element)``` + + The subsequences of *iterable* that are of length *r* where order is + important can be ordered lexicographically. :func:`permutation_index` + computes the index of the first *element* directly, without computing + the previous permutations. + + >>> permutation_index([1, 3, 2], range(5)) + 19 + + ``ValueError`` will be raised if the given *element* isn't one of the + permutations of *iterable*. + """ + index = 0 + pool = list(iterable) + for i, x in zip(range(len(pool), -1, -1), element): + r = pool.index(x) + index = index * i + r + del pool[r] + + return index + + +class countable: + """Wrap *iterable* and keep a count of how many items have been consumed. + + The ``items_seen`` attribute starts at ``0`` and increments as the iterable + is consumed: + + >>> iterable = map(str, range(10)) + >>> it = countable(iterable) + >>> it.items_seen + 0 + >>> next(it), next(it) + ('0', '1') + >>> list(it) + ['2', '3', '4', '5', '6', '7', '8', '9'] + >>> it.items_seen + 10 + """ + + def __init__(self, iterable): + self._it = iter(iterable) + self.items_seen = 0 + + def __iter__(self): + return self + + def __next__(self): + item = next(self._it) + self.items_seen += 1 + + return item + + +def chunked_even(iterable, n): + """Break *iterable* into lists of approximately length *n*. + Items are distributed such the lengths of the lists differ by at most + 1 item. + + >>> iterable = [1, 2, 3, 4, 5, 6, 7] + >>> n = 3 + >>> list(chunked_even(iterable, n)) # List lengths: 3, 2, 2 + [[1, 2, 3], [4, 5], [6, 7]] + >>> list(chunked(iterable, n)) # List lengths: 3, 3, 1 + [[1, 2, 3], [4, 5, 6], [7]] + + """ + + len_method = getattr(iterable, '__len__', None) + + if len_method is None: + return _chunked_even_online(iterable, n) + else: + return _chunked_even_finite(iterable, len_method(), n) + + +def _chunked_even_online(iterable, n): + buffer = [] + maxbuf = n + (n - 2) * (n - 1) + for x in iterable: + buffer.append(x) + if len(buffer) == maxbuf: + yield buffer[:n] + buffer = buffer[n:] + yield from _chunked_even_finite(buffer, len(buffer), n) + + +def _chunked_even_finite(iterable, N, n): + if N < 1: + return + + # Lists are either size `full_size <= n` or `partial_size = full_size - 1` + q, r = divmod(N, n) + num_lists = q + (1 if r > 0 else 0) + q, r = divmod(N, num_lists) + full_size = q + (1 if r > 0 else 0) + partial_size = full_size - 1 + num_full = N - partial_size * num_lists + num_partial = num_lists - num_full + + buffer = [] + iterator = iter(iterable) + + # Yield num_full lists of full_size + for x in iterator: + buffer.append(x) + if len(buffer) == full_size: + yield buffer + buffer = [] + num_full -= 1 + if num_full <= 0: + break + + # Yield num_partial lists of partial_size + for x in iterator: + buffer.append(x) + if len(buffer) == partial_size: + yield buffer + buffer = [] + num_partial -= 1 + + +def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False): + """A version of :func:`zip` that "broadcasts" any scalar + (i.e., non-iterable) items into output tuples. + + >>> iterable_1 = [1, 2, 3] + >>> iterable_2 = ['a', 'b', 'c'] + >>> scalar = '_' + >>> list(zip_broadcast(iterable_1, iterable_2, scalar)) + [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')] + + The *scalar_types* keyword argument determines what types are considered + scalar. It is set to ``(str, bytes)`` by default. Set it to ``None`` to + treat strings and byte strings as iterable: + + >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None)) + [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')] + + If the *strict* keyword argument is ``True``, then + ``UnequalIterablesError`` will be raised if any of the iterables have + different lengths. + + """ + if not objects: + return + + iterables = [] + all_scalar = True + for obj in objects: + # If the object is one of our scalar types, turn it into an iterable + # by wrapping it with itertools.repeat + if scalar_types and isinstance(obj, scalar_types): + iterables.append((repeat(obj), False)) + # Otherwise, test to see whether the object is iterable. + # If it is, collect it. If it's not, treat it as a scalar. + else: + try: + iterables.append((iter(obj), True)) + except TypeError: + iterables.append((repeat(obj), False)) + else: + all_scalar = False + + # If all the objects were scalar, we just emit them as a tuple. + # Otherwise we zip the collected iterable objects. + if all_scalar: + yield tuple(objects) + else: + yield from zip(*(it for it, is_it in iterables)) + + # For strict mode, we ensure that all the iterable objects have been + # exhausted. + if strict: + for it, is_it in filter(itemgetter(1), iterables): + if next(it, _marker) is not _marker: + raise UnequalIterablesError diff --git a/lib/more_itertools/more.pyi b/lib/more_itertools/more.pyi new file mode 100644 index 00000000..6525d349 --- /dev/null +++ b/lib/more_itertools/more.pyi @@ -0,0 +1,556 @@ +"""Stubs for more_itertools.more""" + +from typing import ( + Any, + Callable, + Container, + Dict, + Generic, + Hashable, + Iterable, + Iterator, + List, + Optional, + Reversible, + Sequence, + Sized, + Tuple, + Union, + TypeVar, + type_check_only, +) +from types import TracebackType +from typing_extensions import ContextManager, Protocol, Type, overload + +# Type and type variable definitions +_T = TypeVar('_T') +_T1 = TypeVar('_T1') +_T2 = TypeVar('_T2') +_U = TypeVar('_U') +_V = TypeVar('_V') +_W = TypeVar('_W') +_T_co = TypeVar('_T_co', covariant=True) +_GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[object]]) +_Raisable = Union[BaseException, 'Type[BaseException]'] + +@type_check_only +class _SizedIterable(Protocol[_T_co], Sized, Iterable[_T_co]): ... + +@type_check_only +class _SizedReversible(Protocol[_T_co], Sized, Reversible[_T_co]): ... + +def chunked( + iterable: Iterable[_T], n: Optional[int], strict: bool = ... +) -> Iterator[List[_T]]: ... +@overload +def first(iterable: Iterable[_T]) -> _T: ... +@overload +def first(iterable: Iterable[_T], default: _U) -> Union[_T, _U]: ... +@overload +def last(iterable: Iterable[_T]) -> _T: ... +@overload +def last(iterable: Iterable[_T], default: _U) -> Union[_T, _U]: ... +@overload +def nth_or_last(iterable: Iterable[_T], n: int) -> _T: ... +@overload +def nth_or_last( + iterable: Iterable[_T], n: int, default: _U +) -> Union[_T, _U]: ... + +class peekable(Generic[_T], Iterator[_T]): + def __init__(self, iterable: Iterable[_T]) -> None: ... + def __iter__(self) -> peekable[_T]: ... + def __bool__(self) -> bool: ... + @overload + def peek(self) -> _T: ... + @overload + def peek(self, default: _U) -> Union[_T, _U]: ... + def prepend(self, *items: _T) -> None: ... + def __next__(self) -> _T: ... + @overload + def __getitem__(self, index: int) -> _T: ... + @overload + def __getitem__(self, index: slice) -> List[_T]: ... + +def collate(*iterables: Iterable[_T], **kwargs: Any) -> Iterable[_T]: ... +def consumer(func: _GenFn) -> _GenFn: ... +def ilen(iterable: Iterable[object]) -> int: ... +def iterate(func: Callable[[_T], _T], start: _T) -> Iterator[_T]: ... +def with_iter( + context_manager: ContextManager[Iterable[_T]], +) -> Iterator[_T]: ... +def one( + iterable: Iterable[_T], + too_short: Optional[_Raisable] = ..., + too_long: Optional[_Raisable] = ..., +) -> _T: ... +def distinct_permutations( + iterable: Iterable[_T], r: Optional[int] = ... +) -> Iterator[Tuple[_T, ...]]: ... +def intersperse( + e: _U, iterable: Iterable[_T], n: int = ... +) -> Iterator[Union[_T, _U]]: ... +def unique_to_each(*iterables: Iterable[_T]) -> List[List[_T]]: ... +@overload +def windowed( + seq: Iterable[_T], n: int, *, step: int = ... +) -> Iterator[Tuple[Optional[_T], ...]]: ... +@overload +def windowed( + seq: Iterable[_T], n: int, fillvalue: _U, step: int = ... +) -> Iterator[Tuple[Union[_T, _U], ...]]: ... +def substrings(iterable: Iterable[_T]) -> Iterator[Tuple[_T, ...]]: ... +def substrings_indexes( + seq: Sequence[_T], reverse: bool = ... +) -> Iterator[Tuple[Sequence[_T], int, int]]: ... + +class bucket(Generic[_T, _U], Container[_U]): + def __init__( + self, + iterable: Iterable[_T], + key: Callable[[_T], _U], + validator: Optional[Callable[[object], object]] = ..., + ) -> None: ... + def __contains__(self, value: object) -> bool: ... + def __iter__(self) -> Iterator[_U]: ... + def __getitem__(self, value: object) -> Iterator[_T]: ... + +def spy( + iterable: Iterable[_T], n: int = ... +) -> Tuple[List[_T], Iterator[_T]]: ... +def interleave(*iterables: Iterable[_T]) -> Iterator[_T]: ... +def interleave_longest(*iterables: Iterable[_T]) -> Iterator[_T]: ... +def interleave_evenly( + iterables: List[Iterable[_T]], lengths: Optional[List[int]] = ... +) -> Iterator[_T]: ... +def collapse( + iterable: Iterable[Any], + base_type: Optional[type] = ..., + levels: Optional[int] = ..., +) -> Iterator[Any]: ... +@overload +def side_effect( + func: Callable[[_T], object], + iterable: Iterable[_T], + chunk_size: None = ..., + before: Optional[Callable[[], object]] = ..., + after: Optional[Callable[[], object]] = ..., +) -> Iterator[_T]: ... +@overload +def side_effect( + func: Callable[[List[_T]], object], + iterable: Iterable[_T], + chunk_size: int, + before: Optional[Callable[[], object]] = ..., + after: Optional[Callable[[], object]] = ..., +) -> Iterator[_T]: ... +def sliced( + seq: Sequence[_T], n: int, strict: bool = ... +) -> Iterator[Sequence[_T]]: ... +def split_at( + iterable: Iterable[_T], + pred: Callable[[_T], object], + maxsplit: int = ..., + keep_separator: bool = ..., +) -> Iterator[List[_T]]: ... +def split_before( + iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ... +) -> Iterator[List[_T]]: ... +def split_after( + iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ... +) -> Iterator[List[_T]]: ... +def split_when( + iterable: Iterable[_T], + pred: Callable[[_T, _T], object], + maxsplit: int = ..., +) -> Iterator[List[_T]]: ... +def split_into( + iterable: Iterable[_T], sizes: Iterable[Optional[int]] +) -> Iterator[List[_T]]: ... +@overload +def padded( + iterable: Iterable[_T], + *, + n: Optional[int] = ..., + next_multiple: bool = ... +) -> Iterator[Optional[_T]]: ... +@overload +def padded( + iterable: Iterable[_T], + fillvalue: _U, + n: Optional[int] = ..., + next_multiple: bool = ..., +) -> Iterator[Union[_T, _U]]: ... +@overload +def repeat_last(iterable: Iterable[_T]) -> Iterator[_T]: ... +@overload +def repeat_last( + iterable: Iterable[_T], default: _U +) -> Iterator[Union[_T, _U]]: ... +def distribute(n: int, iterable: Iterable[_T]) -> List[Iterator[_T]]: ... +@overload +def stagger( + iterable: Iterable[_T], + offsets: _SizedIterable[int] = ..., + longest: bool = ..., +) -> Iterator[Tuple[Optional[_T], ...]]: ... +@overload +def stagger( + iterable: Iterable[_T], + offsets: _SizedIterable[int] = ..., + longest: bool = ..., + fillvalue: _U = ..., +) -> Iterator[Tuple[Union[_T, _U], ...]]: ... + +class UnequalIterablesError(ValueError): + def __init__( + self, details: Optional[Tuple[int, int, int]] = ... + ) -> None: ... + +@overload +def zip_equal(__iter1: Iterable[_T1]) -> Iterator[Tuple[_T1]]: ... +@overload +def zip_equal( + __iter1: Iterable[_T1], __iter2: Iterable[_T2] +) -> Iterator[Tuple[_T1, _T2]]: ... +@overload +def zip_equal( + __iter1: Iterable[_T], + __iter2: Iterable[_T], + __iter3: Iterable[_T], + *iterables: Iterable[_T] +) -> Iterator[Tuple[_T, ...]]: ... +@overload +def zip_offset( + __iter1: Iterable[_T1], + *, + offsets: _SizedIterable[int], + longest: bool = ..., + fillvalue: None = None +) -> Iterator[Tuple[Optional[_T1]]]: ... +@overload +def zip_offset( + __iter1: Iterable[_T1], + __iter2: Iterable[_T2], + *, + offsets: _SizedIterable[int], + longest: bool = ..., + fillvalue: None = None +) -> Iterator[Tuple[Optional[_T1], Optional[_T2]]]: ... +@overload +def zip_offset( + __iter1: Iterable[_T], + __iter2: Iterable[_T], + __iter3: Iterable[_T], + *iterables: Iterable[_T], + offsets: _SizedIterable[int], + longest: bool = ..., + fillvalue: None = None +) -> Iterator[Tuple[Optional[_T], ...]]: ... +@overload +def zip_offset( + __iter1: Iterable[_T1], + *, + offsets: _SizedIterable[int], + longest: bool = ..., + fillvalue: _U, +) -> Iterator[Tuple[Union[_T1, _U]]]: ... +@overload +def zip_offset( + __iter1: Iterable[_T1], + __iter2: Iterable[_T2], + *, + offsets: _SizedIterable[int], + longest: bool = ..., + fillvalue: _U, +) -> Iterator[Tuple[Union[_T1, _U], Union[_T2, _U]]]: ... +@overload +def zip_offset( + __iter1: Iterable[_T], + __iter2: Iterable[_T], + __iter3: Iterable[_T], + *iterables: Iterable[_T], + offsets: _SizedIterable[int], + longest: bool = ..., + fillvalue: _U, +) -> Iterator[Tuple[Union[_T, _U], ...]]: ... +def sort_together( + iterables: Iterable[Iterable[_T]], + key_list: Iterable[int] = ..., + key: Optional[Callable[..., Any]] = ..., + reverse: bool = ..., +) -> List[Tuple[_T, ...]]: ... +def unzip(iterable: Iterable[Sequence[_T]]) -> Tuple[Iterator[_T], ...]: ... +def divide(n: int, iterable: Iterable[_T]) -> List[Iterator[_T]]: ... +def always_iterable( + obj: object, + base_type: Union[ + type, Tuple[Union[type, Tuple[Any, ...]], ...], None + ] = ..., +) -> Iterator[Any]: ... +def adjacent( + predicate: Callable[[_T], bool], + iterable: Iterable[_T], + distance: int = ..., +) -> Iterator[Tuple[bool, _T]]: ... +def groupby_transform( + iterable: Iterable[_T], + keyfunc: Optional[Callable[[_T], _U]] = ..., + valuefunc: Optional[Callable[[_T], _V]] = ..., + reducefunc: Optional[Callable[..., _W]] = ..., +) -> Iterator[Tuple[_T, _W]]: ... + +class numeric_range(Generic[_T, _U], Sequence[_T], Hashable, Reversible[_T]): + @overload + def __init__(self, __stop: _T) -> None: ... + @overload + def __init__(self, __start: _T, __stop: _T) -> None: ... + @overload + def __init__(self, __start: _T, __stop: _T, __step: _U) -> None: ... + def __bool__(self) -> bool: ... + def __contains__(self, elem: object) -> bool: ... + def __eq__(self, other: object) -> bool: ... + @overload + def __getitem__(self, key: int) -> _T: ... + @overload + def __getitem__(self, key: slice) -> numeric_range[_T, _U]: ... + def __hash__(self) -> int: ... + def __iter__(self) -> Iterator[_T]: ... + def __len__(self) -> int: ... + def __reduce__( + self, + ) -> Tuple[Type[numeric_range[_T, _U]], Tuple[_T, _T, _U]]: ... + def __repr__(self) -> str: ... + def __reversed__(self) -> Iterator[_T]: ... + def count(self, value: _T) -> int: ... + def index(self, value: _T) -> int: ... # type: ignore + +def count_cycle( + iterable: Iterable[_T], n: Optional[int] = ... +) -> Iterable[Tuple[int, _T]]: ... +def mark_ends( + iterable: Iterable[_T], +) -> Iterable[Tuple[bool, bool, _T]]: ... +def locate( + iterable: Iterable[object], + pred: Callable[..., Any] = ..., + window_size: Optional[int] = ..., +) -> Iterator[int]: ... +def lstrip( + iterable: Iterable[_T], pred: Callable[[_T], object] +) -> Iterator[_T]: ... +def rstrip( + iterable: Iterable[_T], pred: Callable[[_T], object] +) -> Iterator[_T]: ... +def strip( + iterable: Iterable[_T], pred: Callable[[_T], object] +) -> Iterator[_T]: ... + +class islice_extended(Generic[_T], Iterator[_T]): + def __init__( + self, iterable: Iterable[_T], *args: Optional[int] + ) -> None: ... + def __iter__(self) -> islice_extended[_T]: ... + def __next__(self) -> _T: ... + def __getitem__(self, index: slice) -> islice_extended[_T]: ... + +def always_reversible(iterable: Iterable[_T]) -> Iterator[_T]: ... +def consecutive_groups( + iterable: Iterable[_T], ordering: Callable[[_T], int] = ... +) -> Iterator[Iterator[_T]]: ... +@overload +def difference( + iterable: Iterable[_T], + func: Callable[[_T, _T], _U] = ..., + *, + initial: None = ... +) -> Iterator[Union[_T, _U]]: ... +@overload +def difference( + iterable: Iterable[_T], func: Callable[[_T, _T], _U] = ..., *, initial: _U +) -> Iterator[_U]: ... + +class SequenceView(Generic[_T], Sequence[_T]): + def __init__(self, target: Sequence[_T]) -> None: ... + @overload + def __getitem__(self, index: int) -> _T: ... + @overload + def __getitem__(self, index: slice) -> Sequence[_T]: ... + def __len__(self) -> int: ... + +class seekable(Generic[_T], Iterator[_T]): + def __init__( + self, iterable: Iterable[_T], maxlen: Optional[int] = ... + ) -> None: ... + def __iter__(self) -> seekable[_T]: ... + def __next__(self) -> _T: ... + def __bool__(self) -> bool: ... + @overload + def peek(self) -> _T: ... + @overload + def peek(self, default: _U) -> Union[_T, _U]: ... + def elements(self) -> SequenceView[_T]: ... + def seek(self, index: int) -> None: ... + +class run_length: + @staticmethod + def encode(iterable: Iterable[_T]) -> Iterator[Tuple[_T, int]]: ... + @staticmethod + def decode(iterable: Iterable[Tuple[_T, int]]) -> Iterator[_T]: ... + +def exactly_n( + iterable: Iterable[_T], n: int, predicate: Callable[[_T], object] = ... +) -> bool: ... +def circular_shifts(iterable: Iterable[_T]) -> List[Tuple[_T, ...]]: ... +def make_decorator( + wrapping_func: Callable[..., _U], result_index: int = ... +) -> Callable[..., Callable[[Callable[..., Any]], Callable[..., _U]]]: ... +@overload +def map_reduce( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: None = ..., + reducefunc: None = ..., +) -> Dict[_U, List[_T]]: ... +@overload +def map_reduce( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: Callable[[_T], _V], + reducefunc: None = ..., +) -> Dict[_U, List[_V]]: ... +@overload +def map_reduce( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: None = ..., + reducefunc: Callable[[List[_T]], _W] = ..., +) -> Dict[_U, _W]: ... +@overload +def map_reduce( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: Callable[[_T], _V], + reducefunc: Callable[[List[_V]], _W], +) -> Dict[_U, _W]: ... +def rlocate( + iterable: Iterable[_T], + pred: Callable[..., object] = ..., + window_size: Optional[int] = ..., +) -> Iterator[int]: ... +def replace( + iterable: Iterable[_T], + pred: Callable[..., object], + substitutes: Iterable[_U], + count: Optional[int] = ..., + window_size: int = ..., +) -> Iterator[Union[_T, _U]]: ... +def partitions(iterable: Iterable[_T]) -> Iterator[List[List[_T]]]: ... +def set_partitions( + iterable: Iterable[_T], k: Optional[int] = ... +) -> Iterator[List[List[_T]]]: ... + +class time_limited(Generic[_T], Iterator[_T]): + def __init__( + self, limit_seconds: float, iterable: Iterable[_T] + ) -> None: ... + def __iter__(self) -> islice_extended[_T]: ... + def __next__(self) -> _T: ... + +@overload +def only( + iterable: Iterable[_T], *, too_long: Optional[_Raisable] = ... +) -> Optional[_T]: ... +@overload +def only( + iterable: Iterable[_T], default: _U, too_long: Optional[_Raisable] = ... +) -> Union[_T, _U]: ... +def ichunked(iterable: Iterable[_T], n: int) -> Iterator[Iterator[_T]]: ... +def distinct_combinations( + iterable: Iterable[_T], r: int +) -> Iterator[Tuple[_T, ...]]: ... +def filter_except( + validator: Callable[[Any], object], + iterable: Iterable[_T], + *exceptions: Type[BaseException] +) -> Iterator[_T]: ... +def map_except( + function: Callable[[Any], _U], + iterable: Iterable[_T], + *exceptions: Type[BaseException] +) -> Iterator[_U]: ... +def map_if( + iterable: Iterable[Any], + pred: Callable[[Any], bool], + func: Callable[[Any], Any], + func_else: Optional[Callable[[Any], Any]] = ..., +) -> Iterator[Any]: ... +def sample( + iterable: Iterable[_T], + k: int, + weights: Optional[Iterable[float]] = ..., +) -> List[_T]: ... +def is_sorted( + iterable: Iterable[_T], + key: Optional[Callable[[_T], _U]] = ..., + reverse: bool = False, +) -> bool: ... + +class AbortThread(BaseException): + pass + +class callback_iter(Generic[_T], Iterator[_T]): + def __init__( + self, + func: Callable[..., Any], + callback_kwd: str = ..., + wait_seconds: float = ..., + ) -> None: ... + def __enter__(self) -> callback_iter[_T]: ... + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> Optional[bool]: ... + def __iter__(self) -> callback_iter[_T]: ... + def __next__(self) -> _T: ... + def _reader(self) -> Iterator[_T]: ... + @property + def done(self) -> bool: ... + @property + def result(self) -> Any: ... + +def windowed_complete( + iterable: Iterable[_T], n: int +) -> Iterator[Tuple[_T, ...]]: ... +def all_unique( + iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = ... +) -> bool: ... +def nth_product(index: int, *args: Iterable[_T]) -> Tuple[_T, ...]: ... +def nth_permutation( + iterable: Iterable[_T], r: int, index: int +) -> Tuple[_T, ...]: ... +def value_chain(*args: Union[_T, Iterable[_T]]) -> Iterable[_T]: ... +def product_index(element: Iterable[_T], *args: Iterable[_T]) -> int: ... +def combination_index( + element: Iterable[_T], iterable: Iterable[_T] +) -> int: ... +def permutation_index( + element: Iterable[_T], iterable: Iterable[_T] +) -> int: ... +def repeat_each(iterable: Iterable[_T], n: int = ...) -> Iterator[_T]: ... + +class countable(Generic[_T], Iterator[_T]): + def __init__(self, iterable: Iterable[_T]) -> None: ... + def __iter__(self) -> countable[_T]: ... + def __next__(self) -> _T: ... + +def chunked_even(iterable: Iterable[_T], n: int) -> Iterator[List[_T]]: ... +def zip_broadcast( + *objects: Union[_T, Iterable[_T]], + scalar_types: Union[ + type, Tuple[Union[type, Tuple[Any, ...]], ...], None + ] = ..., + strict: bool = ... +) -> Iterable[Tuple[_T, ...]]: ... diff --git a/lib/more_itertools/py.typed b/lib/more_itertools/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/lib/more_itertools/recipes.py b/lib/more_itertools/recipes.py index 3b455d4e..e470f3bd 100644 --- a/lib/more_itertools/recipes.py +++ b/lib/more_itertools/recipes.py @@ -7,20 +7,27 @@ Some backward-compatible usability improvements have been made. .. [1] http://docs.python.org/library/itertools.html#recipes """ +import warnings from collections import deque from itertools import ( - chain, combinations, count, cycle, groupby, islice, repeat, starmap, tee + chain, + combinations, + count, + cycle, + groupby, + islice, + repeat, + starmap, + tee, + zip_longest, ) import operator from random import randrange, sample, choice -from six import PY2 -from six.moves import filter, filterfalse, map, range, zip, zip_longest - __all__ = [ - 'accumulate', 'all_equal', 'consume', + 'convolve', 'dotproduct', 'first_true', 'flatten', @@ -30,6 +37,7 @@ __all__ = [ 'nth', 'nth_combination', 'padnone', + 'pad_none', 'pairwise', 'partition', 'powerset', @@ -49,46 +57,17 @@ __all__ = [ ] -def accumulate(iterable, func=operator.add): - """ - Return an iterator whose items are the accumulated results of a function - (specified by the optional *func* argument) that takes two arguments. - By default, returns accumulated sums with :func:`operator.add`. - - >>> list(accumulate([1, 2, 3, 4, 5])) # Running sum - [1, 3, 6, 10, 15] - >>> list(accumulate([1, 2, 3], func=operator.mul)) # Running product - [1, 2, 6] - >>> list(accumulate([0, 1, -1, 2, 3, 2], func=max)) # Running maximum - [0, 1, 1, 2, 3, 3] - - This function is available in the ``itertools`` module for Python 3.2 and - greater. - - """ - it = iter(iterable) - try: - total = next(it) - except StopIteration: - return - else: - yield total - - for element in it: - total = func(total, element) - yield total - - def take(n, iterable): """Return first *n* items of the iterable as a list. >>> take(3, range(10)) [0, 1, 2] - >>> take(5, range(3)) - [0, 1, 2] - Effectively a short replacement for ``next`` based iterator consumption - when you want more than one item, but less than the whole iterator. + If there are fewer than *n* items in the iterable, all of them are + returned. + + >>> take(10, range(3)) + [0, 1, 2] """ return list(islice(iterable, n)) @@ -115,9 +94,9 @@ def tabulate(function, start=0): def tail(n, iterable): """Return an iterator over the last *n* items of *iterable*. - >>> t = tail(3, 'ABCDEFG') - >>> list(t) - ['E', 'F', 'G'] + >>> t = tail(3, 'ABCDEFG') + >>> list(t) + ['E', 'F', 'G'] """ return iter(deque(iterable, maxlen=n)) @@ -166,11 +145,11 @@ def consume(iterator, n=None): def nth(iterable, n, default=None): """Returns the nth item or a default value. - >>> l = range(10) - >>> nth(l, 3) - 3 - >>> nth(l, 20, "zebra") - 'zebra' + >>> l = range(10) + >>> nth(l, 3) + 3 + >>> nth(l, 20, "zebra") + 'zebra' """ return next(islice(iterable, n, None), default) @@ -193,17 +172,17 @@ def all_equal(iterable): def quantify(iterable, pred=bool): """Return the how many times the predicate is true. - >>> quantify([True, False, True]) - 2 + >>> quantify([True, False, True]) + 2 """ return sum(map(pred, iterable)) -def padnone(iterable): +def pad_none(iterable): """Returns the sequence of elements and then returns ``None`` indefinitely. - >>> take(5, padnone(range(3))) + >>> take(5, pad_none(range(3))) [0, 1, 2, None, None] Useful for emulating the behavior of the built-in :func:`map` function. @@ -214,11 +193,14 @@ def padnone(iterable): return chain(iterable, repeat(None)) +padnone = pad_none + + def ncycles(iterable, n): """Returns the sequence elements *n* times - >>> list(ncycles(["a", "b"], 3)) - ['a', 'b', 'a', 'b', 'a', 'b'] + >>> list(ncycles(["a", "b"], 3)) + ['a', 'b', 'a', 'b', 'a', 'b'] """ return chain.from_iterable(repeat(tuple(iterable), n)) @@ -227,8 +209,8 @@ def ncycles(iterable, n): def dotproduct(vec1, vec2): """Returns the dot product of the two iterables. - >>> dotproduct([10, 10], [20, 20]) - 400 + >>> dotproduct([10, 10], [20, 20]) + 400 """ return sum(map(operator.mul, vec1, vec2)) @@ -273,25 +255,44 @@ def repeatfunc(func, times=None, *args): return starmap(func, repeat(args, times)) -def pairwise(iterable): +def _pairwise(iterable): """Returns an iterator of paired items, overlapping, from the original - >>> take(4, pairwise(count())) - [(0, 1), (1, 2), (2, 3), (3, 4)] + >>> take(4, pairwise(count())) + [(0, 1), (1, 2), (2, 3), (3, 4)] + + On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`. """ a, b = tee(iterable) next(b, None) - return zip(a, b) + yield from zip(a, b) -def grouper(n, iterable, fillvalue=None): +try: + from itertools import pairwise as itertools_pairwise +except ImportError: + pairwise = _pairwise +else: + + def pairwise(iterable): + yield from itertools_pairwise(iterable) + + pairwise.__doc__ = _pairwise.__doc__ + + +def grouper(iterable, n, fillvalue=None): """Collect data into fixed-length chunks or blocks. - >>> list(grouper(3, 'ABCDEFG', 'x')) - [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')] + >>> list(grouper('ABCDEFG', 3, 'x')) + [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')] """ + if isinstance(iterable, int): + warnings.warn( + "grouper expects iterable as first parameter", DeprecationWarning + ) + n, iterable = iterable, n args = [iter(iterable)] * n return zip_longest(fillvalue=fillvalue, *args) @@ -309,10 +310,7 @@ def roundrobin(*iterables): """ # Recipe credited to George Sakkis pending = len(iterables) - if PY2: - nexts = cycle(iter(it).next for it in iterables) - else: - nexts = cycle(iter(it).__next__ for it in iterables) + nexts = cycle(iter(it).__next__ for it in iterables) while pending: try: for next in nexts: @@ -334,10 +332,23 @@ def partition(pred, iterable): >>> list(even_items), list(odd_items) ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9]) + If *pred* is None, :func:`bool` is used. + + >>> iterable = [0, 1, False, True, '', ' '] + >>> false_items, true_items = partition(None, iterable) + >>> list(false_items), list(true_items) + ([0, False, ''], [1, True, ' ']) + """ - # partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9 - t1, t2 = tee(iterable) - return filterfalse(pred, t1), filter(pred, t2) + if pred is None: + pred = bool + + evaluations = ((pred(x), x) for x in iterable) + t1, t2 = tee(evaluations) + return ( + (x for (cond, x) in t1 if not cond), + (x for (cond, x) in t2 if cond), + ) def powerset(iterable): @@ -375,41 +386,46 @@ def unique_everseen(iterable, key=None): Sequences with a mix of hashable and unhashable items can be used. The function will be slower (i.e., `O(n^2)`) for unhashable items. + Remember that ``list`` objects are unhashable - you can use the *key* + parameter to transform the list to a tuple (which is hashable) to + avoid a slowdown. + + >>> iterable = ([1, 2], [2, 3], [1, 2]) + >>> list(unique_everseen(iterable)) # Slow + [[1, 2], [2, 3]] + >>> list(unique_everseen(iterable, key=tuple)) # Faster + [[1, 2], [2, 3]] + + Similary, you may want to convert unhashable ``set`` objects with + ``key=frozenset``. For ``dict`` objects, + ``key=lambda x: frozenset(x.items())`` can be used. + """ seenset = set() seenset_add = seenset.add seenlist = [] seenlist_add = seenlist.append - if key is None: - for element in iterable: - try: - if element not in seenset: - seenset_add(element) - yield element - except TypeError: - if element not in seenlist: - seenlist_add(element) - yield element - else: - for element in iterable: - k = key(element) - try: - if k not in seenset: - seenset_add(k) - yield element - except TypeError: - if k not in seenlist: - seenlist_add(k) - yield element + use_key = key is not None + + for element in iterable: + k = key(element) if use_key else element + try: + if k not in seenset: + seenset_add(k) + yield element + except TypeError: + if k not in seenlist: + seenlist_add(k) + yield element def unique_justseen(iterable, key=None): """Yields elements in order, ignoring serial duplicates - >>> list(unique_justseen('AAAABBBCCDAABBB')) - ['A', 'B', 'C', 'D', 'A', 'B'] - >>> list(unique_justseen('ABBCcAD', str.lower)) - ['A', 'B', 'C', 'A', 'D'] + >>> list(unique_justseen('AAAABBBCCDAABBB')) + ['A', 'B', 'C', 'D', 'A', 'B'] + >>> list(unique_justseen('ABBCcAD', str.lower)) + ['A', 'B', 'C', 'A', 'D'] """ return map(next, map(operator.itemgetter(1), groupby(iterable, key))) @@ -426,6 +442,16 @@ def iter_except(func, exception, first=None): >>> list(iter_except(l.pop, IndexError)) [2, 1, 0] + Multiple exceptions can be specified as a stopping condition: + + >>> l = [1, 2, 3, '...', 4, 5, 6] + >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError))) + [7, 6, 5] + >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError))) + [4, 3, 2] + >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError))) + [] + """ try: if first is not None: @@ -456,7 +482,7 @@ def first_true(iterable, default=None, pred=None): return next(filter(pred, iterable), default) -def random_product(*args, **kwds): +def random_product(*args, repeat=1): """Draw an item at random from each of the input iterables. >>> random_product('abc', range(4), 'XYZ') # doctest:+SKIP @@ -472,7 +498,7 @@ def random_product(*args, **kwds): ``itertools.product(*args, **kwarg)``. """ - pools = [tuple(pool) for pool in args] * kwds.get('repeat', 1) + pools = [tuple(pool) for pool in args] * repeat return tuple(choice(pool) for pool in pools) @@ -535,6 +561,12 @@ def nth_combination(iterable, r, index): sort position *index* directly, without computing the previous subsequences. + >>> nth_combination(range(5), 3, 5) + (0, 3, 4) + + ``ValueError`` will be raised If *r* is negative or greater than the length + of *iterable*. + ``IndexError`` will be raised if the given *index* is invalid. """ pool = tuple(iterable) n = len(pool) @@ -571,7 +603,28 @@ def prepend(value, iterator): >>> list(prepend(value, iterator)) ['0', '1', '2', '3'] - To prepend multiple values, see :func:`itertools.chain`. + To prepend multiple values, see :func:`itertools.chain` + or :func:`value_chain`. """ return chain([value], iterator) + + +def convolve(signal, kernel): + """Convolve the iterable *signal* with the iterable *kernel*. + + >>> signal = (1, 2, 3, 4, 5) + >>> kernel = [3, 2, 1] + >>> list(convolve(signal, kernel)) + [3, 8, 14, 20, 26, 14, 5] + + Note: the input arguments are not interchangeable, as the *kernel* + is immediately consumed and stored. + + """ + kernel = tuple(kernel)[::-1] + n = len(kernel) + window = deque([0], maxlen=n) * n + for x in chain(signal, repeat(0, n - 1)): + window.append(x) + yield sum(map(operator.mul, kernel, window)) diff --git a/lib/more_itertools/recipes.pyi b/lib/more_itertools/recipes.pyi new file mode 100644 index 00000000..69ff32d7 --- /dev/null +++ b/lib/more_itertools/recipes.pyi @@ -0,0 +1,105 @@ +"""Stubs for more_itertools.recipes""" +from typing import ( + Any, + Callable, + Iterable, + Iterator, + List, + Optional, + Tuple, + TypeVar, + Union, +) +from typing_extensions import overload, Type + +# Type and type variable definitions +_T = TypeVar('_T') +_U = TypeVar('_U') + +def take(n: int, iterable: Iterable[_T]) -> List[_T]: ... +def tabulate( + function: Callable[[int], _T], start: int = ... +) -> Iterator[_T]: ... +def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]: ... +def consume(iterator: Iterable[object], n: Optional[int] = ...) -> None: ... +@overload +def nth(iterable: Iterable[_T], n: int) -> Optional[_T]: ... +@overload +def nth(iterable: Iterable[_T], n: int, default: _U) -> Union[_T, _U]: ... +def all_equal(iterable: Iterable[object]) -> bool: ... +def quantify( + iterable: Iterable[_T], pred: Callable[[_T], bool] = ... +) -> int: ... +def pad_none(iterable: Iterable[_T]) -> Iterator[Optional[_T]]: ... +def padnone(iterable: Iterable[_T]) -> Iterator[Optional[_T]]: ... +def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]: ... +def dotproduct(vec1: Iterable[object], vec2: Iterable[object]) -> object: ... +def flatten(listOfLists: Iterable[Iterable[_T]]) -> Iterator[_T]: ... +def repeatfunc( + func: Callable[..., _U], times: Optional[int] = ..., *args: Any +) -> Iterator[_U]: ... +def pairwise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T]]: ... +@overload +def grouper( + iterable: Iterable[_T], n: int +) -> Iterator[Tuple[Optional[_T], ...]]: ... +@overload +def grouper( + iterable: Iterable[_T], n: int, fillvalue: _U +) -> Iterator[Tuple[Union[_T, _U], ...]]: ... +@overload +def grouper( # Deprecated interface + iterable: int, n: Iterable[_T] +) -> Iterator[Tuple[Optional[_T], ...]]: ... +@overload +def grouper( # Deprecated interface + iterable: int, n: Iterable[_T], fillvalue: _U +) -> Iterator[Tuple[Union[_T, _U], ...]]: ... +def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]: ... +def partition( + pred: Optional[Callable[[_T], object]], iterable: Iterable[_T] +) -> Tuple[Iterator[_T], Iterator[_T]]: ... +def powerset(iterable: Iterable[_T]) -> Iterator[Tuple[_T, ...]]: ... +def unique_everseen( + iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = ... +) -> Iterator[_T]: ... +def unique_justseen( + iterable: Iterable[_T], key: Optional[Callable[[_T], object]] = ... +) -> Iterator[_T]: ... +@overload +def iter_except( + func: Callable[[], _T], + exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]], + first: None = ..., +) -> Iterator[_T]: ... +@overload +def iter_except( + func: Callable[[], _T], + exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]], + first: Callable[[], _U], +) -> Iterator[Union[_T, _U]]: ... +@overload +def first_true( + iterable: Iterable[_T], *, pred: Optional[Callable[[_T], object]] = ... +) -> Optional[_T]: ... +@overload +def first_true( + iterable: Iterable[_T], + default: _U, + pred: Optional[Callable[[_T], object]] = ..., +) -> Union[_T, _U]: ... +def random_product( + *args: Iterable[_T], repeat: int = ... +) -> Tuple[_T, ...]: ... +def random_permutation( + iterable: Iterable[_T], r: Optional[int] = ... +) -> Tuple[_T, ...]: ... +def random_combination(iterable: Iterable[_T], r: int) -> Tuple[_T, ...]: ... +def random_combination_with_replacement( + iterable: Iterable[_T], r: int +) -> Tuple[_T, ...]: ... +def nth_combination( + iterable: Iterable[_T], r: int, index: int +) -> Tuple[_T, ...]: ... +def prepend(value: _T, iterator: Iterable[_U]) -> Iterator[Union[_T, _U]]: ... +def convolve(signal: Iterable[_T], kernel: Iterable[_T]) -> Iterator[_T]: ... diff --git a/lib/more_itertools/tests/test_more.py b/lib/more_itertools/tests/test_more.py deleted file mode 100644 index eacf8a8a..00000000 --- a/lib/more_itertools/tests/test_more.py +++ /dev/null @@ -1,2313 +0,0 @@ -from __future__ import division, print_function, unicode_literals - -from collections import OrderedDict -from decimal import Decimal -from doctest import DocTestSuite -from fractions import Fraction -from functools import partial, reduce -from heapq import merge -from io import StringIO -from itertools import ( - chain, - count, - groupby, - islice, - permutations, - product, - repeat, -) -from operator import add, mul, itemgetter -from unittest import TestCase - -from six.moves import filter, map, range, zip - -import more_itertools as mi - - -def load_tests(loader, tests, ignore): - # Add the doctests - tests.addTests(DocTestSuite('more_itertools.more')) - return tests - - -class CollateTests(TestCase): - """Unit tests for ``collate()``""" - # Also accidentally tests peekable, though that could use its own tests - - def test_default(self): - """Test with the default `key` function.""" - iterables = [range(4), range(7), range(3, 6)] - self.assertEqual( - sorted(reduce(list.__add__, [list(it) for it in iterables])), - list(mi.collate(*iterables)) - ) - - def test_key(self): - """Test using a custom `key` function.""" - iterables = [range(5, 0, -1), range(4, 0, -1)] - actual = sorted( - reduce(list.__add__, [list(it) for it in iterables]), reverse=True - ) - expected = list(mi.collate(*iterables, key=lambda x: -x)) - self.assertEqual(actual, expected) - - def test_empty(self): - """Be nice if passed an empty list of iterables.""" - self.assertEqual([], list(mi.collate())) - - def test_one(self): - """Work when only 1 iterable is passed.""" - self.assertEqual([0, 1], list(mi.collate(range(2)))) - - def test_reverse(self): - """Test the `reverse` kwarg.""" - iterables = [range(4, 0, -1), range(7, 0, -1), range(3, 6, -1)] - - actual = sorted( - reduce(list.__add__, [list(it) for it in iterables]), reverse=True - ) - expected = list(mi.collate(*iterables, reverse=True)) - self.assertEqual(actual, expected) - - def test_alias(self): - self.assertNotEqual(merge.__doc__, mi.collate.__doc__) - self.assertNotEqual(partial.__doc__, mi.collate.__doc__) - - -class ChunkedTests(TestCase): - """Tests for ``chunked()``""" - - def test_even(self): - """Test when ``n`` divides evenly into the length of the iterable.""" - self.assertEqual( - list(mi.chunked('ABCDEF', 3)), [['A', 'B', 'C'], ['D', 'E', 'F']] - ) - - def test_odd(self): - """Test when ``n`` does not divide evenly into the length of the - iterable. - - """ - self.assertEqual( - list(mi.chunked('ABCDE', 3)), [['A', 'B', 'C'], ['D', 'E']] - ) - - -class FirstTests(TestCase): - """Tests for ``first()``""" - - def test_many(self): - """Test that it works on many-item iterables.""" - # Also try it on a generator expression to make sure it works on - # whatever those return, across Python versions. - self.assertEqual(mi.first(x for x in range(4)), 0) - - def test_one(self): - """Test that it doesn't raise StopIteration prematurely.""" - self.assertEqual(mi.first([3]), 3) - - def test_empty_stop_iteration(self): - """It should raise StopIteration for empty iterables.""" - self.assertRaises(ValueError, lambda: mi.first([])) - - def test_default(self): - """It should return the provided default arg for empty iterables.""" - self.assertEqual(mi.first([], 'boo'), 'boo') - - -class IterOnlyRange: - """User-defined iterable class which only support __iter__. - - It is not specified to inherit ``object``, so indexing on a instance will - raise an ``AttributeError`` rather than ``TypeError`` in Python 2. - - >>> r = IterOnlyRange(5) - >>> r[0] - AttributeError: IterOnlyRange instance has no attribute '__getitem__' - - Note: In Python 3, ``TypeError`` will be raised because ``object`` is - inherited implicitly by default. - - >>> r[0] - TypeError: 'IterOnlyRange' object does not support indexing - """ - def __init__(self, n): - """Set the length of the range.""" - self.n = n - - def __iter__(self): - """Works same as range().""" - return iter(range(self.n)) - - -class LastTests(TestCase): - """Tests for ``last()``""" - - def test_many_nonsliceable(self): - """Test that it works on many-item non-slice-able iterables.""" - # Also try it on a generator expression to make sure it works on - # whatever those return, across Python versions. - self.assertEqual(mi.last(x for x in range(4)), 3) - - def test_one_nonsliceable(self): - """Test that it doesn't raise StopIteration prematurely.""" - self.assertEqual(mi.last(x for x in range(1)), 0) - - def test_empty_stop_iteration_nonsliceable(self): - """It should raise ValueError for empty non-slice-able iterables.""" - self.assertRaises(ValueError, lambda: mi.last(x for x in range(0))) - - def test_default_nonsliceable(self): - """It should return the provided default arg for empty non-slice-able - iterables. - """ - self.assertEqual(mi.last((x for x in range(0)), 'boo'), 'boo') - - def test_many_sliceable(self): - """Test that it works on many-item slice-able iterables.""" - self.assertEqual(mi.last([0, 1, 2, 3]), 3) - - def test_one_sliceable(self): - """Test that it doesn't raise StopIteration prematurely.""" - self.assertEqual(mi.last([3]), 3) - - def test_empty_stop_iteration_sliceable(self): - """It should raise ValueError for empty slice-able iterables.""" - self.assertRaises(ValueError, lambda: mi.last([])) - - def test_default_sliceable(self): - """It should return the provided default arg for empty slice-able - iterables. - """ - self.assertEqual(mi.last([], 'boo'), 'boo') - - def test_dict(self): - """last(dic) and last(dic.keys()) should return same result.""" - dic = {'a': 1, 'b': 2, 'c': 3} - self.assertEqual(mi.last(dic), mi.last(dic.keys())) - - def test_ordereddict(self): - """last(dic) should return the last key.""" - od = OrderedDict() - od['a'] = 1 - od['b'] = 2 - od['c'] = 3 - self.assertEqual(mi.last(od), 'c') - - def test_customrange(self): - """It should work on custom class where [] raises AttributeError.""" - self.assertEqual(mi.last(IterOnlyRange(5)), 4) - - -class PeekableTests(TestCase): - """Tests for ``peekable()`` behavor not incidentally covered by testing - ``collate()`` - - """ - def test_peek_default(self): - """Make sure passing a default into ``peek()`` works.""" - p = mi.peekable([]) - self.assertEqual(p.peek(7), 7) - - def test_truthiness(self): - """Make sure a ``peekable`` tests true iff there are items remaining in - the iterable. - - """ - p = mi.peekable([]) - self.assertFalse(p) - - p = mi.peekable(range(3)) - self.assertTrue(p) - - def test_simple_peeking(self): - """Make sure ``next`` and ``peek`` advance and don't advance the - iterator, respectively. - - """ - p = mi.peekable(range(10)) - self.assertEqual(next(p), 0) - self.assertEqual(p.peek(), 1) - self.assertEqual(next(p), 1) - - def test_indexing(self): - """ - Indexing into the peekable shouldn't advance the iterator. - """ - p = mi.peekable('abcdefghijkl') - - # The 0th index is what ``next()`` will return - self.assertEqual(p[0], 'a') - self.assertEqual(next(p), 'a') - - # Indexing further into the peekable shouldn't advance the itertor - self.assertEqual(p[2], 'd') - self.assertEqual(next(p), 'b') - - # The 0th index moves up with the iterator; the last index follows - self.assertEqual(p[0], 'c') - self.assertEqual(p[9], 'l') - - self.assertEqual(next(p), 'c') - self.assertEqual(p[8], 'l') - - # Negative indexing should work too - self.assertEqual(p[-2], 'k') - self.assertEqual(p[-9], 'd') - self.assertRaises(IndexError, lambda: p[-10]) - - def test_slicing(self): - """Slicing the peekable shouldn't advance the iterator.""" - seq = list('abcdefghijkl') - p = mi.peekable(seq) - - # Slicing the peekable should just be like slicing a re-iterable - self.assertEqual(p[1:4], seq[1:4]) - - # Advancing the iterator moves the slices up also - self.assertEqual(next(p), 'a') - self.assertEqual(p[1:4], seq[1:][1:4]) - - # Implicit starts and stop should work - self.assertEqual(p[:5], seq[1:][:5]) - self.assertEqual(p[:], seq[1:][:]) - - # Indexing past the end should work - self.assertEqual(p[:100], seq[1:][:100]) - - # Steps should work, including negative - self.assertEqual(p[::2], seq[1:][::2]) - self.assertEqual(p[::-1], seq[1:][::-1]) - - def test_slicing_reset(self): - """Test slicing on a fresh iterable each time""" - iterable = ['0', '1', '2', '3', '4', '5'] - indexes = list(range(-4, len(iterable) + 4)) + [None] - steps = [1, 2, 3, 4, -1, -2, -3, 4] - for slice_args in product(indexes, indexes, steps): - it = iter(iterable) - p = mi.peekable(it) - next(p) - index = slice(*slice_args) - actual = p[index] - expected = iterable[1:][index] - self.assertEqual(actual, expected, slice_args) - - def test_slicing_error(self): - iterable = '01234567' - p = mi.peekable(iter(iterable)) - - # Prime the cache - p.peek() - old_cache = list(p._cache) - - # Illegal slice - with self.assertRaises(ValueError): - p[1:-1:0] - - # Neither the cache nor the iteration should be affected - self.assertEqual(old_cache, list(p._cache)) - self.assertEqual(list(p), list(iterable)) - - def test_passthrough(self): - """Iterating a peekable without using ``peek()`` or ``prepend()`` - should just give the underlying iterable's elements (a trivial test but - useful to set a baseline in case something goes wrong)""" - expected = [1, 2, 3, 4, 5] - actual = list(mi.peekable(expected)) - self.assertEqual(actual, expected) - - # prepend() behavior tests - - def test_prepend(self): - """Tests intersperesed ``prepend()`` and ``next()`` calls""" - it = mi.peekable(range(2)) - actual = [] - - # Test prepend() before next() - it.prepend(10) - actual += [next(it), next(it)] - - # Test prepend() between next()s - it.prepend(11) - actual += [next(it), next(it)] - - # Test prepend() after source iterable is consumed - it.prepend(12) - actual += [next(it)] - - expected = [10, 0, 11, 1, 12] - self.assertEqual(actual, expected) - - def test_multi_prepend(self): - """Tests prepending multiple items and getting them in proper order""" - it = mi.peekable(range(5)) - actual = [next(it), next(it)] - it.prepend(10, 11, 12) - it.prepend(20, 21) - actual += list(it) - expected = [0, 1, 20, 21, 10, 11, 12, 2, 3, 4] - self.assertEqual(actual, expected) - - def test_empty(self): - """Tests prepending in front of an empty iterable""" - it = mi.peekable([]) - it.prepend(10) - actual = list(it) - expected = [10] - self.assertEqual(actual, expected) - - def test_prepend_truthiness(self): - """Tests that ``__bool__()`` or ``__nonzero__()`` works properly - with ``prepend()``""" - it = mi.peekable(range(5)) - self.assertTrue(it) - actual = list(it) - self.assertFalse(it) - it.prepend(10) - self.assertTrue(it) - actual += [next(it)] - self.assertFalse(it) - expected = [0, 1, 2, 3, 4, 10] - self.assertEqual(actual, expected) - - def test_multi_prepend_peek(self): - """Tests prepending multiple elements and getting them in reverse order - while peeking""" - it = mi.peekable(range(5)) - actual = [next(it), next(it)] - self.assertEqual(it.peek(), 2) - it.prepend(10, 11, 12) - self.assertEqual(it.peek(), 10) - it.prepend(20, 21) - self.assertEqual(it.peek(), 20) - actual += list(it) - self.assertFalse(it) - expected = [0, 1, 20, 21, 10, 11, 12, 2, 3, 4] - self.assertEqual(actual, expected) - - def test_prepend_after_stop(self): - """Test resuming iteration after a previous exhaustion""" - it = mi.peekable(range(3)) - self.assertEqual(list(it), [0, 1, 2]) - self.assertRaises(StopIteration, lambda: next(it)) - it.prepend(10) - self.assertEqual(next(it), 10) - self.assertRaises(StopIteration, lambda: next(it)) - - def test_prepend_slicing(self): - """Tests interaction between prepending and slicing""" - seq = list(range(20)) - p = mi.peekable(seq) - - p.prepend(30, 40, 50) - pseq = [30, 40, 50] + seq # pseq for prepended_seq - - # adapt the specific tests from test_slicing - self.assertEqual(p[0], 30) - self.assertEqual(p[1:8], pseq[1:8]) - self.assertEqual(p[1:], pseq[1:]) - self.assertEqual(p[:5], pseq[:5]) - self.assertEqual(p[:], pseq[:]) - self.assertEqual(p[:100], pseq[:100]) - self.assertEqual(p[::2], pseq[::2]) - self.assertEqual(p[::-1], pseq[::-1]) - - def test_prepend_indexing(self): - """Tests interaction between prepending and indexing""" - seq = list(range(20)) - p = mi.peekable(seq) - - p.prepend(30, 40, 50) - - self.assertEqual(p[0], 30) - self.assertEqual(next(p), 30) - self.assertEqual(p[2], 0) - self.assertEqual(next(p), 40) - self.assertEqual(p[0], 50) - self.assertEqual(p[9], 8) - self.assertEqual(next(p), 50) - self.assertEqual(p[8], 8) - self.assertEqual(p[-2], 18) - self.assertEqual(p[-9], 11) - self.assertRaises(IndexError, lambda: p[-21]) - - def test_prepend_iterable(self): - """Tests prepending from an iterable""" - it = mi.peekable(range(5)) - # Don't directly use the range() object to avoid any range-specific - # optimizations - it.prepend(*(x for x in range(5))) - actual = list(it) - expected = list(chain(range(5), range(5))) - self.assertEqual(actual, expected) - - def test_prepend_many(self): - """Tests that prepending a huge number of elements works""" - it = mi.peekable(range(5)) - # Don't directly use the range() object to avoid any range-specific - # optimizations - it.prepend(*(x for x in range(20000))) - actual = list(it) - expected = list(chain(range(20000), range(5))) - self.assertEqual(actual, expected) - - def test_prepend_reversed(self): - """Tests prepending from a reversed iterable""" - it = mi.peekable(range(3)) - it.prepend(*reversed((10, 11, 12))) - actual = list(it) - expected = [12, 11, 10, 0, 1, 2] - self.assertEqual(actual, expected) - - -class ConsumerTests(TestCase): - """Tests for ``consumer()``""" - - def test_consumer(self): - @mi.consumer - def eater(): - while True: - x = yield # noqa - - e = eater() - e.send('hi') # without @consumer, would raise TypeError - - -class DistinctPermutationsTests(TestCase): - def test_distinct_permutations(self): - """Make sure the output for ``distinct_permutations()`` is the same as - set(permutations(it)). - - """ - iterable = ['z', 'a', 'a', 'q', 'q', 'q', 'y'] - test_output = sorted(mi.distinct_permutations(iterable)) - ref_output = sorted(set(permutations(iterable))) - self.assertEqual(test_output, ref_output) - - def test_other_iterables(self): - """Make sure ``distinct_permutations()`` accepts a different type of - iterables. - - """ - # a generator - iterable = (c for c in ['z', 'a', 'a', 'q', 'q', 'q', 'y']) - test_output = sorted(mi.distinct_permutations(iterable)) - # "reload" it - iterable = (c for c in ['z', 'a', 'a', 'q', 'q', 'q', 'y']) - ref_output = sorted(set(permutations(iterable))) - self.assertEqual(test_output, ref_output) - - # an iterator - iterable = iter(['z', 'a', 'a', 'q', 'q', 'q', 'y']) - test_output = sorted(mi.distinct_permutations(iterable)) - # "reload" it - iterable = iter(['z', 'a', 'a', 'q', 'q', 'q', 'y']) - ref_output = sorted(set(permutations(iterable))) - self.assertEqual(test_output, ref_output) - - -class IlenTests(TestCase): - def test_ilen(self): - """Sanity-checks for ``ilen()``.""" - # Non-empty - self.assertEqual( - mi.ilen(filter(lambda x: x % 10 == 0, range(101))), 11 - ) - - # Empty - self.assertEqual(mi.ilen((x for x in range(0))), 0) - - # Iterable with __len__ - self.assertEqual(mi.ilen(list(range(6))), 6) - - -class WithIterTests(TestCase): - def test_with_iter(self): - s = StringIO('One fish\nTwo fish') - initial_words = [line.split()[0] for line in mi.with_iter(s)] - - # Iterable's items should be faithfully represented - self.assertEqual(initial_words, ['One', 'Two']) - # The file object should be closed - self.assertTrue(s.closed) - - -class OneTests(TestCase): - def test_basic(self): - it = iter(['item']) - self.assertEqual(mi.one(it), 'item') - - def test_too_short(self): - it = iter([]) - self.assertRaises(ValueError, lambda: mi.one(it)) - self.assertRaises(IndexError, lambda: mi.one(it, too_short=IndexError)) - - def test_too_long(self): - it = count() - self.assertRaises(ValueError, lambda: mi.one(it)) # burn 0 and 1 - self.assertEqual(next(it), 2) - self.assertRaises( - OverflowError, lambda: mi.one(it, too_long=OverflowError) - ) - - -class IntersperseTest(TestCase): - """ Tests for intersperse() """ - - def test_even(self): - iterable = (x for x in '01') - self.assertEqual( - list(mi.intersperse(None, iterable)), ['0', None, '1'] - ) - - def test_odd(self): - iterable = (x for x in '012') - self.assertEqual( - list(mi.intersperse(None, iterable)), ['0', None, '1', None, '2'] - ) - - def test_nested(self): - element = ('a', 'b') - iterable = (x for x in '012') - actual = list(mi.intersperse(element, iterable)) - expected = ['0', ('a', 'b'), '1', ('a', 'b'), '2'] - self.assertEqual(actual, expected) - - def test_not_iterable(self): - self.assertRaises(TypeError, lambda: mi.intersperse('x', 1)) - - def test_n(self): - for n, element, expected in [ - (1, '_', ['0', '_', '1', '_', '2', '_', '3', '_', '4', '_', '5']), - (2, '_', ['0', '1', '_', '2', '3', '_', '4', '5']), - (3, '_', ['0', '1', '2', '_', '3', '4', '5']), - (4, '_', ['0', '1', '2', '3', '_', '4', '5']), - (5, '_', ['0', '1', '2', '3', '4', '_', '5']), - (6, '_', ['0', '1', '2', '3', '4', '5']), - (7, '_', ['0', '1', '2', '3', '4', '5']), - (3, ['a', 'b'], ['0', '1', '2', ['a', 'b'], '3', '4', '5']), - ]: - iterable = (x for x in '012345') - actual = list(mi.intersperse(element, iterable, n=n)) - self.assertEqual(actual, expected) - - def test_n_zero(self): - self.assertRaises( - ValueError, lambda: list(mi.intersperse('x', '012', n=0)) - ) - - -class UniqueToEachTests(TestCase): - """Tests for ``unique_to_each()``""" - - def test_all_unique(self): - """When all the input iterables are unique the output should match - the input.""" - iterables = [[1, 2], [3, 4, 5], [6, 7, 8]] - self.assertEqual(mi.unique_to_each(*iterables), iterables) - - def test_duplicates(self): - """When there are duplicates in any of the input iterables that aren't - in the rest, those duplicates should be emitted.""" - iterables = ["mississippi", "missouri"] - self.assertEqual( - mi.unique_to_each(*iterables), [['p', 'p'], ['o', 'u', 'r']] - ) - - def test_mixed(self): - """When the input iterables contain different types the function should - still behave properly""" - iterables = ['x', (i for i in range(3)), [1, 2, 3], tuple()] - self.assertEqual(mi.unique_to_each(*iterables), [['x'], [0], [3], []]) - - -class WindowedTests(TestCase): - """Tests for ``windowed()``""" - - def test_basic(self): - actual = list(mi.windowed([1, 2, 3, 4, 5], 3)) - expected = [(1, 2, 3), (2, 3, 4), (3, 4, 5)] - self.assertEqual(actual, expected) - - def test_large_size(self): - """ - When the window size is larger than the iterable, and no fill value is - given,``None`` should be filled in. - """ - actual = list(mi.windowed([1, 2, 3, 4, 5], 6)) - expected = [(1, 2, 3, 4, 5, None)] - self.assertEqual(actual, expected) - - def test_fillvalue(self): - """ - When sizes don't match evenly, the given fill value should be used. - """ - iterable = [1, 2, 3, 4, 5] - - for n, kwargs, expected in [ - (6, {}, [(1, 2, 3, 4, 5, '!')]), # n > len(iterable) - (3, {'step': 3}, [(1, 2, 3), (4, 5, '!')]), # using ``step`` - ]: - actual = list(mi.windowed(iterable, n, fillvalue='!', **kwargs)) - self.assertEqual(actual, expected) - - def test_zero(self): - """When the window size is zero, an empty tuple should be emitted.""" - actual = list(mi.windowed([1, 2, 3, 4, 5], 0)) - expected = [tuple()] - self.assertEqual(actual, expected) - - def test_negative(self): - """When the window size is negative, ValueError should be raised.""" - with self.assertRaises(ValueError): - list(mi.windowed([1, 2, 3, 4, 5], -1)) - - def test_step(self): - """The window should advance by the number of steps provided""" - iterable = [1, 2, 3, 4, 5, 6, 7] - for n, step, expected in [ - (3, 2, [(1, 2, 3), (3, 4, 5), (5, 6, 7)]), # n > step - (3, 3, [(1, 2, 3), (4, 5, 6), (7, None, None)]), # n == step - (3, 4, [(1, 2, 3), (5, 6, 7)]), # line up nicely - (3, 5, [(1, 2, 3), (6, 7, None)]), # off by one - (3, 6, [(1, 2, 3), (7, None, None)]), # off by two - (3, 7, [(1, 2, 3)]), # step past the end - (7, 8, [(1, 2, 3, 4, 5, 6, 7)]), # step > len(iterable) - ]: - actual = list(mi.windowed(iterable, n, step=step)) - self.assertEqual(actual, expected) - - # Step must be greater than or equal to 1 - with self.assertRaises(ValueError): - list(mi.windowed(iterable, 3, step=0)) - - -class SubstringsTests(TestCase): - def test_basic(self): - iterable = (x for x in range(4)) - actual = list(mi.substrings(iterable)) - expected = [ - (0,), - (1,), - (2,), - (3,), - (0, 1), - (1, 2), - (2, 3), - (0, 1, 2), - (1, 2, 3), - (0, 1, 2, 3), - ] - self.assertEqual(actual, expected) - - def test_strings(self): - iterable = 'abc' - actual = list(mi.substrings(iterable)) - expected = [ - ('a',), ('b',), ('c',), ('a', 'b'), ('b', 'c'), ('a', 'b', 'c') - ] - self.assertEqual(actual, expected) - - def test_empty(self): - iterable = iter([]) - actual = list(mi.substrings(iterable)) - expected = [] - self.assertEqual(actual, expected) - - def test_order(self): - iterable = [2, 0, 1] - actual = list(mi.substrings(iterable)) - expected = [(2,), (0,), (1,), (2, 0), (0, 1), (2, 0, 1)] - self.assertEqual(actual, expected) - - -class BucketTests(TestCase): - """Tests for ``bucket()``""" - - def test_basic(self): - iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33] - D = mi.bucket(iterable, key=lambda x: 10 * (x // 10)) - - # In-order access - self.assertEqual(list(D[10]), [10, 11, 12]) - - # Out of order access - self.assertEqual(list(D[30]), [30, 31, 33]) - self.assertEqual(list(D[20]), [20, 21, 22, 23]) - - self.assertEqual(list(D[40]), []) # Nothing in here! - - def test_in(self): - iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33] - D = mi.bucket(iterable, key=lambda x: 10 * (x // 10)) - - self.assertIn(10, D) - self.assertNotIn(40, D) - self.assertIn(20, D) - self.assertNotIn(21, D) - - # Checking in-ness shouldn't advance the iterator - self.assertEqual(next(D[10]), 10) - - def test_validator(self): - iterable = count(0) - key = lambda x: int(str(x)[0]) # First digit of each number - validator = lambda x: 0 < x < 10 # No leading zeros - D = mi.bucket(iterable, key, validator=validator) - self.assertEqual(mi.take(3, D[1]), [1, 10, 11]) - self.assertNotIn(0, D) # Non-valid entries don't return True - self.assertNotIn(0, D._cache) # Don't store non-valid entries - self.assertEqual(list(D[0]), []) - - -class SpyTests(TestCase): - """Tests for ``spy()``""" - - def test_basic(self): - original_iterable = iter('abcdefg') - head, new_iterable = mi.spy(original_iterable) - self.assertEqual(head, ['a']) - self.assertEqual( - list(new_iterable), ['a', 'b', 'c', 'd', 'e', 'f', 'g'] - ) - - def test_unpacking(self): - original_iterable = iter('abcdefg') - (first, second, third), new_iterable = mi.spy(original_iterable, 3) - self.assertEqual(first, 'a') - self.assertEqual(second, 'b') - self.assertEqual(third, 'c') - self.assertEqual( - list(new_iterable), ['a', 'b', 'c', 'd', 'e', 'f', 'g'] - ) - - def test_too_many(self): - original_iterable = iter('abc') - head, new_iterable = mi.spy(original_iterable, 4) - self.assertEqual(head, ['a', 'b', 'c']) - self.assertEqual(list(new_iterable), ['a', 'b', 'c']) - - def test_zero(self): - original_iterable = iter('abc') - head, new_iterable = mi.spy(original_iterable, 0) - self.assertEqual(head, []) - self.assertEqual(list(new_iterable), ['a', 'b', 'c']) - - -class InterleaveTests(TestCase): - def test_even(self): - actual = list(mi.interleave([1, 4, 7], [2, 5, 8], [3, 6, 9])) - expected = [1, 2, 3, 4, 5, 6, 7, 8, 9] - self.assertEqual(actual, expected) - - def test_short(self): - actual = list(mi.interleave([1, 4], [2, 5, 7], [3, 6, 8])) - expected = [1, 2, 3, 4, 5, 6] - self.assertEqual(actual, expected) - - def test_mixed_types(self): - it_list = ['a', 'b', 'c', 'd'] - it_str = '12345' - it_inf = count() - actual = list(mi.interleave(it_list, it_str, it_inf)) - expected = ['a', '1', 0, 'b', '2', 1, 'c', '3', 2, 'd', '4', 3] - self.assertEqual(actual, expected) - - -class InterleaveLongestTests(TestCase): - def test_even(self): - actual = list(mi.interleave_longest([1, 4, 7], [2, 5, 8], [3, 6, 9])) - expected = [1, 2, 3, 4, 5, 6, 7, 8, 9] - self.assertEqual(actual, expected) - - def test_short(self): - actual = list(mi.interleave_longest([1, 4], [2, 5, 7], [3, 6, 8])) - expected = [1, 2, 3, 4, 5, 6, 7, 8] - self.assertEqual(actual, expected) - - def test_mixed_types(self): - it_list = ['a', 'b', 'c', 'd'] - it_str = '12345' - it_gen = (x for x in range(3)) - actual = list(mi.interleave_longest(it_list, it_str, it_gen)) - expected = ['a', '1', 0, 'b', '2', 1, 'c', '3', 2, 'd', '4', '5'] - self.assertEqual(actual, expected) - - -class TestCollapse(TestCase): - """Tests for ``collapse()``""" - - def test_collapse(self): - l = [[1], 2, [[3], 4], [[[5]]]] - self.assertEqual(list(mi.collapse(l)), [1, 2, 3, 4, 5]) - - def test_collapse_to_string(self): - l = [["s1"], "s2", [["s3"], "s4"], [[["s5"]]]] - self.assertEqual(list(mi.collapse(l)), ["s1", "s2", "s3", "s4", "s5"]) - - def test_collapse_flatten(self): - l = [[1], [2], [[3], 4], [[[5]]]] - self.assertEqual(list(mi.collapse(l, levels=1)), list(mi.flatten(l))) - - def test_collapse_to_level(self): - l = [[1], 2, [[3], 4], [[[5]]]] - self.assertEqual(list(mi.collapse(l, levels=2)), [1, 2, 3, 4, [5]]) - self.assertEqual( - list(mi.collapse(mi.collapse(l, levels=1), levels=1)), - list(mi.collapse(l, levels=2)) - ) - - def test_collapse_to_list(self): - l = (1, [2], (3, [4, (5,)], 'ab')) - actual = list(mi.collapse(l, base_type=list)) - expected = [1, [2], 3, [4, (5,)], 'ab'] - self.assertEqual(actual, expected) - - -class SideEffectTests(TestCase): - """Tests for ``side_effect()``""" - - def test_individual(self): - # The function increments the counter for each call - counter = [0] - - def func(arg): - counter[0] += 1 - - result = list(mi.side_effect(func, range(10))) - self.assertEqual(result, list(range(10))) - self.assertEqual(counter[0], 10) - - def test_chunked(self): - # The function increments the counter for each call - counter = [0] - - def func(arg): - counter[0] += 1 - - result = list(mi.side_effect(func, range(10), 2)) - self.assertEqual(result, list(range(10))) - self.assertEqual(counter[0], 5) - - def test_before_after(self): - f = StringIO() - collector = [] - - def func(item): - print(item, file=f) - collector.append(f.getvalue()) - - def it(): - yield 'a' - yield 'b' - raise RuntimeError('kaboom') - - before = lambda: print('HEADER', file=f) - after = f.close - - try: - mi.consume(mi.side_effect(func, it(), before=before, after=after)) - except RuntimeError: - pass - - # The iterable should have been written to the file - self.assertEqual(collector, ['HEADER\na\n', 'HEADER\na\nb\n']) - - # The file should be closed even though something bad happened - self.assertTrue(f.closed) - - def test_before_fails(self): - f = StringIO() - func = lambda x: print(x, file=f) - - def before(): - raise RuntimeError('ouch') - - try: - mi.consume( - mi.side_effect(func, 'abc', before=before, after=f.close) - ) - except RuntimeError: - pass - - # The file should be closed even though something bad happened in the - # before function - self.assertTrue(f.closed) - - -class SlicedTests(TestCase): - """Tests for ``sliced()``""" - - def test_even(self): - """Test when the length of the sequence is divisible by *n*""" - seq = 'ABCDEFGHI' - self.assertEqual(list(mi.sliced(seq, 3)), ['ABC', 'DEF', 'GHI']) - - def test_odd(self): - """Test when the length of the sequence is not divisible by *n*""" - seq = 'ABCDEFGHI' - self.assertEqual(list(mi.sliced(seq, 4)), ['ABCD', 'EFGH', 'I']) - - def test_not_sliceable(self): - seq = (x for x in 'ABCDEFGHI') - - with self.assertRaises(TypeError): - list(mi.sliced(seq, 3)) - - -class SplitAtTests(TestCase): - """Tests for ``split()``""" - - def comp_with_str_split(self, str_to_split, delim): - pred = lambda c: c == delim - actual = list(map(''.join, mi.split_at(str_to_split, pred))) - expected = str_to_split.split(delim) - self.assertEqual(actual, expected) - - def test_seperators(self): - test_strs = ['', 'abcba', 'aaabbbcccddd', 'e'] - for s, delim in product(test_strs, 'abcd'): - self.comp_with_str_split(s, delim) - - -class SplitBeforeTest(TestCase): - """Tests for ``split_before()``""" - - def test_starts_with_sep(self): - actual = list(mi.split_before('xooxoo', lambda c: c == 'x')) - expected = [['x', 'o', 'o'], ['x', 'o', 'o']] - self.assertEqual(actual, expected) - - def test_ends_with_sep(self): - actual = list(mi.split_before('ooxoox', lambda c: c == 'x')) - expected = [['o', 'o'], ['x', 'o', 'o'], ['x']] - self.assertEqual(actual, expected) - - def test_no_sep(self): - actual = list(mi.split_before('ooo', lambda c: c == 'x')) - expected = [['o', 'o', 'o']] - self.assertEqual(actual, expected) - - -class SplitAfterTest(TestCase): - """Tests for ``split_after()``""" - - def test_starts_with_sep(self): - actual = list(mi.split_after('xooxoo', lambda c: c == 'x')) - expected = [['x'], ['o', 'o', 'x'], ['o', 'o']] - self.assertEqual(actual, expected) - - def test_ends_with_sep(self): - actual = list(mi.split_after('ooxoox', lambda c: c == 'x')) - expected = [['o', 'o', 'x'], ['o', 'o', 'x']] - self.assertEqual(actual, expected) - - def test_no_sep(self): - actual = list(mi.split_after('ooo', lambda c: c == 'x')) - expected = [['o', 'o', 'o']] - self.assertEqual(actual, expected) - - -class SplitIntoTests(TestCase): - """Tests for ``split_into()``""" - - def test_iterable_just_right(self): - """Size of ``iterable`` equals the sum of ``sizes``.""" - iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] - sizes = [2, 3, 4] - expected = [[1, 2], [3, 4, 5], [6, 7, 8, 9]] - actual = list(mi.split_into(iterable, sizes)) - self.assertEqual(actual, expected) - - def test_iterable_too_small(self): - """Size of ``iterable`` is smaller than sum of ``sizes``. Last return - list is shorter as a result.""" - iterable = [1, 2, 3, 4, 5, 6, 7] - sizes = [2, 3, 4] - expected = [[1, 2], [3, 4, 5], [6, 7]] - actual = list(mi.split_into(iterable, sizes)) - self.assertEqual(actual, expected) - - def test_iterable_too_small_extra(self): - """Size of ``iterable`` is smaller than sum of ``sizes``. Second last - return list is shorter and last return list is empty as a result.""" - iterable = [1, 2, 3, 4, 5, 6, 7] - sizes = [2, 3, 4, 5] - expected = [[1, 2], [3, 4, 5], [6, 7], []] - actual = list(mi.split_into(iterable, sizes)) - self.assertEqual(actual, expected) - - def test_iterable_too_large(self): - """Size of ``iterable`` is larger than sum of ``sizes``. Not all - items of iterable are returned.""" - iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] - sizes = [2, 3, 2] - expected = [[1, 2], [3, 4, 5], [6, 7]] - actual = list(mi.split_into(iterable, sizes)) - self.assertEqual(actual, expected) - - def test_using_none_with_leftover(self): - """Last item of ``sizes`` is None when items still remain in - ``iterable``. Last list returned stretches to fit all remaining items - of ``iterable``.""" - iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] - sizes = [2, 3, None] - expected = [[1, 2], [3, 4, 5], [6, 7, 8, 9]] - actual = list(mi.split_into(iterable, sizes)) - self.assertEqual(actual, expected) - - def test_using_none_without_leftover(self): - """Last item of ``sizes`` is None when no items remain in - ``iterable``. Last list returned is empty.""" - iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] - sizes = [2, 3, 4, None] - expected = [[1, 2], [3, 4, 5], [6, 7, 8, 9], []] - actual = list(mi.split_into(iterable, sizes)) - self.assertEqual(actual, expected) - - def test_using_none_mid_sizes(self): - """None is present in ``sizes`` but is not the last item. Last list - returned stretches to fit all remaining items of ``iterable`` but - all items in ``sizes`` after None are ignored.""" - iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] - sizes = [2, 3, None, 4] - expected = [[1, 2], [3, 4, 5], [6, 7, 8, 9]] - actual = list(mi.split_into(iterable, sizes)) - self.assertEqual(actual, expected) - - def test_iterable_empty(self): - """``iterable`` argument is empty but ``sizes`` is not. An empty - list is returned for each item in ``sizes``.""" - iterable = [] - sizes = [2, 4, 2] - expected = [[], [], []] - actual = list(mi.split_into(iterable, sizes)) - self.assertEqual(actual, expected) - - def test_iterable_empty_using_none(self): - """``iterable`` argument is empty but ``sizes`` is not. An empty - list is returned for each item in ``sizes`` that is not after a - None item.""" - iterable = [] - sizes = [2, 4, None, 2] - expected = [[], [], []] - actual = list(mi.split_into(iterable, sizes)) - self.assertEqual(actual, expected) - - def test_sizes_empty(self): - """``sizes`` argument is empty but ``iterable`` is not. An empty - generator is returned.""" - iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] - sizes = [] - expected = [] - actual = list(mi.split_into(iterable, sizes)) - self.assertEqual(actual, expected) - - def test_both_empty(self): - """Both ``sizes`` and ``iterable`` arguments are empty. An empty - generator is returned.""" - iterable = [] - sizes = [] - expected = [] - actual = list(mi.split_into(iterable, sizes)) - self.assertEqual(actual, expected) - - def test_bool_in_sizes(self): - """A bool object is present in ``sizes`` is treated as a 1 or 0 for - ``True`` or ``False`` due to bool being an instance of int.""" - iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] - sizes = [3, True, 2, False] - expected = [[1, 2, 3], [4], [5, 6], []] - actual = list(mi.split_into(iterable, sizes)) - self.assertEqual(actual, expected) - - def test_invalid_in_sizes(self): - """A ValueError is raised if an object in ``sizes`` is neither ``None`` - or an integer.""" - iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] - sizes = [1, [], 3] - with self.assertRaises(ValueError): - list(mi.split_into(iterable, sizes)) - - def test_invalid_in_sizes_after_none(self): - """A item in ``sizes`` that is invalid will not raise a TypeError if it - comes after a ``None`` item.""" - iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] - sizes = [3, 4, None, []] - expected = [[1, 2, 3], [4, 5, 6, 7], [8, 9]] - actual = list(mi.split_into(iterable, sizes)) - self.assertEqual(actual, expected) - - def test_generator_iterable_integrity(self): - """Check that if ``iterable`` is an iterator, it is consumed only by as - many items as the sum of ``sizes``.""" - iterable = (i for i in range(10)) - sizes = [2, 3] - - expected = [[0, 1], [2, 3, 4]] - actual = list(mi.split_into(iterable, sizes)) - self.assertEqual(actual, expected) - - iterable_expected = [5, 6, 7, 8, 9] - iterable_actual = list(iterable) - self.assertEqual(iterable_actual, iterable_expected) - - def test_generator_sizes_integrity(self): - """Check that if ``sizes`` is an iterator, it is consumed only until a - ``None`` item is reached""" - iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9] - sizes = (i for i in [1, 2, None, 3, 4]) - - expected = [[1], [2, 3], [4, 5, 6, 7, 8, 9]] - actual = list(mi.split_into(iterable, sizes)) - self.assertEqual(actual, expected) - - sizes_expected = [3, 4] - sizes_actual = list(sizes) - self.assertEqual(sizes_actual, sizes_expected) - - -class PaddedTest(TestCase): - """Tests for ``padded()``""" - - def test_no_n(self): - seq = [1, 2, 3] - - # No fillvalue - self.assertEqual(mi.take(5, mi.padded(seq)), [1, 2, 3, None, None]) - - # With fillvalue - self.assertEqual( - mi.take(5, mi.padded(seq, fillvalue='')), [1, 2, 3, '', ''] - ) - - def test_invalid_n(self): - self.assertRaises(ValueError, lambda: list(mi.padded([1, 2, 3], n=-1))) - self.assertRaises(ValueError, lambda: list(mi.padded([1, 2, 3], n=0))) - - def test_valid_n(self): - seq = [1, 2, 3, 4, 5] - - # No need for padding: len(seq) <= n - self.assertEqual(list(mi.padded(seq, n=4)), [1, 2, 3, 4, 5]) - self.assertEqual(list(mi.padded(seq, n=5)), [1, 2, 3, 4, 5]) - - # No fillvalue - self.assertEqual( - list(mi.padded(seq, n=7)), [1, 2, 3, 4, 5, None, None] - ) - - # With fillvalue - self.assertEqual( - list(mi.padded(seq, fillvalue='', n=7)), [1, 2, 3, 4, 5, '', ''] - ) - - def test_next_multiple(self): - seq = [1, 2, 3, 4, 5, 6] - - # No need for padding: len(seq) % n == 0 - self.assertEqual( - list(mi.padded(seq, n=3, next_multiple=True)), [1, 2, 3, 4, 5, 6] - ) - - # Padding needed: len(seq) < n - self.assertEqual( - list(mi.padded(seq, n=8, next_multiple=True)), - [1, 2, 3, 4, 5, 6, None, None] - ) - - # No padding needed: len(seq) == n - self.assertEqual( - list(mi.padded(seq, n=6, next_multiple=True)), [1, 2, 3, 4, 5, 6] - ) - - # Padding needed: len(seq) > n - self.assertEqual( - list(mi.padded(seq, n=4, next_multiple=True)), - [1, 2, 3, 4, 5, 6, None, None] - ) - - # With fillvalue - self.assertEqual( - list(mi.padded(seq, fillvalue='', n=4, next_multiple=True)), - [1, 2, 3, 4, 5, 6, '', ''] - ) - - -class DistributeTest(TestCase): - """Tests for distribute()""" - - def test_invalid_n(self): - self.assertRaises(ValueError, lambda: mi.distribute(-1, [1, 2, 3])) - self.assertRaises(ValueError, lambda: mi.distribute(0, [1, 2, 3])) - - def test_basic(self): - iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - - for n, expected in [ - (1, [iterable]), - (2, [[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]), - (3, [[1, 4, 7, 10], [2, 5, 8], [3, 6, 9]]), - (10, [[n] for n in range(1, 10 + 1)]), - ]: - self.assertEqual( - [list(x) for x in mi.distribute(n, iterable)], expected - ) - - def test_large_n(self): - iterable = [1, 2, 3, 4] - self.assertEqual( - [list(x) for x in mi.distribute(6, iterable)], - [[1], [2], [3], [4], [], []] - ) - - -class StaggerTest(TestCase): - """Tests for ``stagger()``""" - - def test_default(self): - iterable = [0, 1, 2, 3] - actual = list(mi.stagger(iterable)) - expected = [(None, 0, 1), (0, 1, 2), (1, 2, 3)] - self.assertEqual(actual, expected) - - def test_offsets(self): - iterable = [0, 1, 2, 3] - for offsets, expected in [ - ((-2, 0, 2), [('', 0, 2), ('', 1, 3)]), - ((-2, -1), [('', ''), ('', 0), (0, 1), (1, 2), (2, 3)]), - ((1, 2), [(1, 2), (2, 3)]), - ]: - all_groups = mi.stagger(iterable, offsets=offsets, fillvalue='') - self.assertEqual(list(all_groups), expected) - - def test_longest(self): - iterable = [0, 1, 2, 3] - for offsets, expected in [ - ( - (-1, 0, 1), - [('', 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, ''), (3, '', '')] - ), - ((-2, -1), [('', ''), ('', 0), (0, 1), (1, 2), (2, 3), (3, '')]), - ((1, 2), [(1, 2), (2, 3), (3, '')]), - ]: - all_groups = mi.stagger( - iterable, offsets=offsets, fillvalue='', longest=True - ) - self.assertEqual(list(all_groups), expected) - - -class ZipOffsetTest(TestCase): - """Tests for ``zip_offset()``""" - - def test_shortest(self): - a_1 = [0, 1, 2, 3] - a_2 = [0, 1, 2, 3, 4, 5] - a_3 = [0, 1, 2, 3, 4, 5, 6, 7] - actual = list( - mi.zip_offset(a_1, a_2, a_3, offsets=(-1, 0, 1), fillvalue='') - ) - expected = [('', 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5)] - self.assertEqual(actual, expected) - - def test_longest(self): - a_1 = [0, 1, 2, 3] - a_2 = [0, 1, 2, 3, 4, 5] - a_3 = [0, 1, 2, 3, 4, 5, 6, 7] - actual = list( - mi.zip_offset(a_1, a_2, a_3, offsets=(-1, 0, 1), longest=True) - ) - expected = [ - (None, 0, 1), - (0, 1, 2), - (1, 2, 3), - (2, 3, 4), - (3, 4, 5), - (None, 5, 6), - (None, None, 7), - ] - self.assertEqual(actual, expected) - - def test_mismatch(self): - iterables = [0, 1, 2], [2, 3, 4] - offsets = (-1, 0, 1) - self.assertRaises( - ValueError, - lambda: list(mi.zip_offset(*iterables, offsets=offsets)) - ) - - -class UnzipTests(TestCase): - """Tests for unzip()""" - - def test_empty_iterable(self): - self.assertEqual(list(mi.unzip([])), []) - # in reality zip([], [], []) is equivalent to iter([]) - # but it doesn't hurt to test both - self.assertEqual(list(mi.unzip(zip([], [], []))), []) - - def test_length_one_iterable(self): - xs, ys, zs = mi.unzip(zip([1], [2], [3])) - self.assertEqual(list(xs), [1]) - self.assertEqual(list(ys), [2]) - self.assertEqual(list(zs), [3]) - - def test_normal_case(self): - xs, ys, zs = range(10), range(1, 11), range(2, 12) - zipped = zip(xs, ys, zs) - xs, ys, zs = mi.unzip(zipped) - self.assertEqual(list(xs), list(range(10))) - self.assertEqual(list(ys), list(range(1, 11))) - self.assertEqual(list(zs), list(range(2, 12))) - - def test_improperly_zipped(self): - zipped = iter([(1, 2, 3), (4, 5), (6,)]) - xs, ys, zs = mi.unzip(zipped) - self.assertEqual(list(xs), [1, 4, 6]) - self.assertEqual(list(ys), [2, 5]) - self.assertEqual(list(zs), [3]) - - def test_increasingly_zipped(self): - zipped = iter([(1, 2), (3, 4, 5), (6, 7, 8, 9)]) - unzipped = mi.unzip(zipped) - # from the docstring: - # len(first tuple) is the number of iterables zipped - self.assertEqual(len(unzipped), 2) - xs, ys = unzipped - self.assertEqual(list(xs), [1, 3, 6]) - self.assertEqual(list(ys), [2, 4, 7]) - - -class SortTogetherTest(TestCase): - """Tests for sort_together()""" - - def test_key_list(self): - """tests `key_list` including default, iterables include duplicates""" - iterables = [ - ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'], - ['May', 'Aug.', 'May', 'June', 'July', 'July'], - [97, 20, 100, 70, 100, 20] - ] - - self.assertEqual( - mi.sort_together(iterables), - [ - ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), - ('June', 'July', 'July', 'May', 'Aug.', 'May'), - (70, 100, 20, 97, 20, 100) - ] - ) - - self.assertEqual( - mi.sort_together(iterables, key_list=(0, 1)), - [ - ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), - ('July', 'July', 'June', 'Aug.', 'May', 'May'), - (100, 20, 70, 20, 97, 100) - ] - ) - - self.assertEqual( - mi.sort_together(iterables, key_list=(0, 1, 2)), - [ - ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), - ('July', 'July', 'June', 'Aug.', 'May', 'May'), - (20, 100, 70, 20, 97, 100) - ] - ) - - self.assertEqual( - mi.sort_together(iterables, key_list=(2,)), - [ - ('GA', 'CT', 'CT', 'GA', 'GA', 'CT'), - ('Aug.', 'July', 'June', 'May', 'May', 'July'), - (20, 20, 70, 97, 100, 100) - ] - ) - - def test_invalid_key_list(self): - """tests `key_list` for indexes not available in `iterables`""" - iterables = [ - ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'], - ['May', 'Aug.', 'May', 'June', 'July', 'July'], - [97, 20, 100, 70, 100, 20] - ] - - self.assertRaises( - IndexError, lambda: mi.sort_together(iterables, key_list=(5,)) - ) - - def test_reverse(self): - """tests `reverse` to ensure a reverse sort for `key_list` iterables""" - iterables = [ - ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'], - ['May', 'Aug.', 'May', 'June', 'July', 'July'], - [97, 20, 100, 70, 100, 20] - ] - - self.assertEqual( - mi.sort_together(iterables, key_list=(0, 1, 2), reverse=True), - [('GA', 'GA', 'GA', 'CT', 'CT', 'CT'), - ('May', 'May', 'Aug.', 'June', 'July', 'July'), - (100, 97, 20, 70, 100, 20)] - ) - - def test_uneven_iterables(self): - """tests trimming of iterables to the shortest length before sorting""" - iterables = [['GA', 'GA', 'GA', 'CT', 'CT', 'CT', 'MA'], - ['May', 'Aug.', 'May', 'June', 'July', 'July'], - [97, 20, 100, 70, 100, 20, 0]] - - self.assertEqual( - mi.sort_together(iterables), - [ - ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), - ('June', 'July', 'July', 'May', 'Aug.', 'May'), - (70, 100, 20, 97, 20, 100) - ] - ) - - -class DivideTest(TestCase): - """Tests for divide()""" - - def test_invalid_n(self): - self.assertRaises(ValueError, lambda: mi.divide(-1, [1, 2, 3])) - self.assertRaises(ValueError, lambda: mi.divide(0, [1, 2, 3])) - - def test_basic(self): - iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - - for n, expected in [ - (1, [iterable]), - (2, [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), - (3, [[1, 2, 3, 4], [5, 6, 7], [8, 9, 10]]), - (10, [[n] for n in range(1, 10 + 1)]), - ]: - self.assertEqual( - [list(x) for x in mi.divide(n, iterable)], expected - ) - - def test_large_n(self): - iterable = [1, 2, 3, 4] - self.assertEqual( - [list(x) for x in mi.divide(6, iterable)], - [[1], [2], [3], [4], [], []] - ) - - -class TestAlwaysIterable(TestCase): - """Tests for always_iterable()""" - def test_single(self): - self.assertEqual(list(mi.always_iterable(1)), [1]) - - def test_strings(self): - for obj in ['foo', b'bar', 'baz']: - actual = list(mi.always_iterable(obj)) - expected = [obj] - self.assertEqual(actual, expected) - - def test_base_type(self): - dict_obj = {'a': 1, 'b': 2} - str_obj = '123' - - # Default: dicts are iterable like they normally are - default_actual = list(mi.always_iterable(dict_obj)) - default_expected = list(dict_obj) - self.assertEqual(default_actual, default_expected) - - # Unitary types set: dicts are not iterable - custom_actual = list(mi.always_iterable(dict_obj, base_type=dict)) - custom_expected = [dict_obj] - self.assertEqual(custom_actual, custom_expected) - - # With unitary types set, strings are iterable - str_actual = list(mi.always_iterable(str_obj, base_type=None)) - str_expected = list(str_obj) - self.assertEqual(str_actual, str_expected) - - def test_iterables(self): - self.assertEqual(list(mi.always_iterable([0, 1])), [0, 1]) - self.assertEqual( - list(mi.always_iterable([0, 1], base_type=list)), [[0, 1]] - ) - self.assertEqual( - list(mi.always_iterable(iter('foo'))), ['f', 'o', 'o'] - ) - self.assertEqual(list(mi.always_iterable([])), []) - - def test_none(self): - self.assertEqual(list(mi.always_iterable(None)), []) - - def test_generator(self): - def _gen(): - yield 0 - yield 1 - - self.assertEqual(list(mi.always_iterable(_gen())), [0, 1]) - - -class AdjacentTests(TestCase): - def test_typical(self): - actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10))) - expected = [(True, 0), (True, 1), (False, 2), (False, 3), (True, 4), - (True, 5), (True, 6), (False, 7), (False, 8), (False, 9)] - self.assertEqual(actual, expected) - - def test_empty_iterable(self): - actual = list(mi.adjacent(lambda x: x % 5 == 0, [])) - expected = [] - self.assertEqual(actual, expected) - - def test_length_one(self): - actual = list(mi.adjacent(lambda x: x % 5 == 0, [0])) - expected = [(True, 0)] - self.assertEqual(actual, expected) - - actual = list(mi.adjacent(lambda x: x % 5 == 0, [1])) - expected = [(False, 1)] - self.assertEqual(actual, expected) - - def test_consecutive_true(self): - """Test that when the predicate matches multiple consecutive elements - it doesn't repeat elements in the output""" - actual = list(mi.adjacent(lambda x: x % 5 < 2, range(10))) - expected = [(True, 0), (True, 1), (True, 2), (False, 3), (True, 4), - (True, 5), (True, 6), (True, 7), (False, 8), (False, 9)] - self.assertEqual(actual, expected) - - def test_distance(self): - actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10), distance=2)) - expected = [(True, 0), (True, 1), (True, 2), (True, 3), (True, 4), - (True, 5), (True, 6), (True, 7), (False, 8), (False, 9)] - self.assertEqual(actual, expected) - - actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10), distance=3)) - expected = [(True, 0), (True, 1), (True, 2), (True, 3), (True, 4), - (True, 5), (True, 6), (True, 7), (True, 8), (False, 9)] - self.assertEqual(actual, expected) - - def test_large_distance(self): - """Test distance larger than the length of the iterable""" - iterable = range(10) - actual = list(mi.adjacent(lambda x: x % 5 == 4, iterable, distance=20)) - expected = list(zip(repeat(True), iterable)) - self.assertEqual(actual, expected) - - actual = list(mi.adjacent(lambda x: False, iterable, distance=20)) - expected = list(zip(repeat(False), iterable)) - self.assertEqual(actual, expected) - - def test_zero_distance(self): - """Test that adjacent() reduces to zip+map when distance is 0""" - iterable = range(1000) - predicate = lambda x: x % 4 == 2 - actual = mi.adjacent(predicate, iterable, 0) - expected = zip(map(predicate, iterable), iterable) - self.assertTrue(all(a == e for a, e in zip(actual, expected))) - - def test_negative_distance(self): - """Test that adjacent() raises an error with negative distance""" - pred = lambda x: x - self.assertRaises( - ValueError, lambda: mi.adjacent(pred, range(1000), -1) - ) - self.assertRaises( - ValueError, lambda: mi.adjacent(pred, range(10), -10) - ) - - def test_grouping(self): - """Test interaction of adjacent() with groupby_transform()""" - iterable = mi.adjacent(lambda x: x % 5 == 0, range(10)) - grouper = mi.groupby_transform(iterable, itemgetter(0), itemgetter(1)) - actual = [(k, list(g)) for k, g in grouper] - expected = [ - (True, [0, 1]), - (False, [2, 3]), - (True, [4, 5, 6]), - (False, [7, 8, 9]), - ] - self.assertEqual(actual, expected) - - def test_call_once(self): - """Test that the predicate is only called once per item.""" - already_seen = set() - iterable = range(10) - - def predicate(item): - self.assertNotIn(item, already_seen) - already_seen.add(item) - return True - - actual = list(mi.adjacent(predicate, iterable)) - expected = [(True, x) for x in iterable] - self.assertEqual(actual, expected) - - -class GroupByTransformTests(TestCase): - def assertAllGroupsEqual(self, groupby1, groupby2): - """Compare two groupby objects for equality, both keys and groups.""" - for a, b in zip(groupby1, groupby2): - key1, group1 = a - key2, group2 = b - self.assertEqual(key1, key2) - self.assertListEqual(list(group1), list(group2)) - self.assertRaises(StopIteration, lambda: next(groupby1)) - self.assertRaises(StopIteration, lambda: next(groupby2)) - - def test_default_funcs(self): - """Test that groupby_transform() with default args mimics groupby()""" - iterable = [(x // 5, x) for x in range(1000)] - actual = mi.groupby_transform(iterable) - expected = groupby(iterable) - self.assertAllGroupsEqual(actual, expected) - - def test_valuefunc(self): - iterable = [(int(x / 5), int(x / 3), x) for x in range(10)] - - # Test the standard usage of grouping one iterable using another's keys - grouper = mi.groupby_transform( - iterable, keyfunc=itemgetter(0), valuefunc=itemgetter(-1) - ) - actual = [(k, list(g)) for k, g in grouper] - expected = [(0, [0, 1, 2, 3, 4]), (1, [5, 6, 7, 8, 9])] - self.assertEqual(actual, expected) - - grouper = mi.groupby_transform( - iterable, keyfunc=itemgetter(1), valuefunc=itemgetter(-1) - ) - actual = [(k, list(g)) for k, g in grouper] - expected = [(0, [0, 1, 2]), (1, [3, 4, 5]), (2, [6, 7, 8]), (3, [9])] - self.assertEqual(actual, expected) - - # and now for something a little different - d = dict(zip(range(10), 'abcdefghij')) - grouper = mi.groupby_transform( - range(10), keyfunc=lambda x: x // 5, valuefunc=d.get - ) - actual = [(k, ''.join(g)) for k, g in grouper] - expected = [(0, 'abcde'), (1, 'fghij')] - self.assertEqual(actual, expected) - - def test_no_valuefunc(self): - iterable = range(1000) - - def key(x): - return x // 5 - - actual = mi.groupby_transform(iterable, key, valuefunc=None) - expected = groupby(iterable, key) - self.assertAllGroupsEqual(actual, expected) - - actual = mi.groupby_transform(iterable, key) # default valuefunc - expected = groupby(iterable, key) - self.assertAllGroupsEqual(actual, expected) - - -class NumericRangeTests(TestCase): - def test_basic(self): - for args, expected in [ - ((4,), [0, 1, 2, 3]), - ((4.0,), [0.0, 1.0, 2.0, 3.0]), - ((1.0, 4), [1.0, 2.0, 3.0]), - ((1, 4.0), [1, 2, 3]), - ((1.0, 5), [1.0, 2.0, 3.0, 4.0]), - ((0, 20, 5), [0, 5, 10, 15]), - ((0, 20, 5.0), [0.0, 5.0, 10.0, 15.0]), - ((0, 10, 3), [0, 3, 6, 9]), - ((0, 10, 3.0), [0.0, 3.0, 6.0, 9.0]), - ((0, -5, -1), [0, -1, -2, -3, -4]), - ((0.0, -5, -1), [0.0, -1.0, -2.0, -3.0, -4.0]), - ((1, 2, Fraction(1, 2)), [Fraction(1, 1), Fraction(3, 2)]), - ((0,), []), - ((0.0,), []), - ((1, 0), []), - ((1.0, 0.0), []), - ((Fraction(2, 1),), [Fraction(0, 1), Fraction(1, 1)]), - ((Decimal('2.0'),), [Decimal('0.0'), Decimal('1.0')]), - ]: - actual = list(mi.numeric_range(*args)) - self.assertEqual(actual, expected) - self.assertTrue( - all(type(a) == type(e) for a, e in zip(actual, expected)) - ) - - def test_arg_count(self): - self.assertRaises(TypeError, lambda: list(mi.numeric_range())) - self.assertRaises( - TypeError, lambda: list(mi.numeric_range(0, 1, 2, 3)) - ) - - def test_zero_step(self): - self.assertRaises( - ValueError, lambda: list(mi.numeric_range(1, 2, 0)) - ) - - -class CountCycleTests(TestCase): - def test_basic(self): - expected = [ - (0, 'a'), (0, 'b'), (0, 'c'), - (1, 'a'), (1, 'b'), (1, 'c'), - (2, 'a'), (2, 'b'), (2, 'c'), - ] - for actual in [ - mi.take(9, mi.count_cycle('abc')), # n=None - list(mi.count_cycle('abc', 3)), # n=3 - ]: - self.assertEqual(actual, expected) - - def test_empty(self): - self.assertEqual(list(mi.count_cycle('')), []) - self.assertEqual(list(mi.count_cycle('', 2)), []) - - def test_negative(self): - self.assertEqual(list(mi.count_cycle('abc', -3)), []) - - -class LocateTests(TestCase): - def test_default_pred(self): - iterable = [0, 1, 1, 0, 1, 0, 0] - actual = list(mi.locate(iterable)) - expected = [1, 2, 4] - self.assertEqual(actual, expected) - - def test_no_matches(self): - iterable = [0, 0, 0] - actual = list(mi.locate(iterable)) - expected = [] - self.assertEqual(actual, expected) - - def test_custom_pred(self): - iterable = ['0', 1, 1, '0', 1, '0', '0'] - pred = lambda x: x == '0' - actual = list(mi.locate(iterable, pred)) - expected = [0, 3, 5, 6] - self.assertEqual(actual, expected) - - def test_window_size(self): - iterable = ['0', 1, 1, '0', 1, '0', '0'] - pred = lambda *args: args == ('0', 1) - actual = list(mi.locate(iterable, pred, window_size=2)) - expected = [0, 3] - self.assertEqual(actual, expected) - - def test_window_size_large(self): - iterable = [1, 2, 3, 4] - pred = lambda a, b, c, d, e: True - actual = list(mi.locate(iterable, pred, window_size=5)) - expected = [0] - self.assertEqual(actual, expected) - - def test_window_size_zero(self): - iterable = [1, 2, 3, 4] - pred = lambda: True - with self.assertRaises(ValueError): - list(mi.locate(iterable, pred, window_size=0)) - - -class StripFunctionTests(TestCase): - def test_hashable(self): - iterable = list('www.example.com') - pred = lambda x: x in set('cmowz.') - - self.assertEqual(list(mi.lstrip(iterable, pred)), list('example.com')) - self.assertEqual(list(mi.rstrip(iterable, pred)), list('www.example')) - self.assertEqual(list(mi.strip(iterable, pred)), list('example')) - - def test_not_hashable(self): - iterable = [ - list('http://'), list('www'), list('.example'), list('.com') - ] - pred = lambda x: x in [list('http://'), list('www'), list('.com')] - - self.assertEqual(list(mi.lstrip(iterable, pred)), iterable[2:]) - self.assertEqual(list(mi.rstrip(iterable, pred)), iterable[:3]) - self.assertEqual(list(mi.strip(iterable, pred)), iterable[2: 3]) - - def test_math(self): - iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] - pred = lambda x: x <= 2 - - self.assertEqual(list(mi.lstrip(iterable, pred)), iterable[3:]) - self.assertEqual(list(mi.rstrip(iterable, pred)), iterable[:-3]) - self.assertEqual(list(mi.strip(iterable, pred)), iterable[3:-3]) - - -class IsliceExtendedTests(TestCase): - def test_all(self): - iterable = ['0', '1', '2', '3', '4', '5'] - indexes = list(range(-4, len(iterable) + 4)) + [None] - steps = [1, 2, 3, 4, -1, -2, -3, 4] - for slice_args in product(indexes, indexes, steps): - try: - actual = list(mi.islice_extended(iterable, *slice_args)) - except Exception as e: - self.fail((slice_args, e)) - - expected = iterable[slice(*slice_args)] - self.assertEqual(actual, expected, slice_args) - - def test_zero_step(self): - with self.assertRaises(ValueError): - list(mi.islice_extended([1, 2, 3], 0, 1, 0)) - - -class ConsecutiveGroupsTest(TestCase): - def test_numbers(self): - iterable = [-10, -8, -7, -6, 1, 2, 4, 5, -1, 7] - actual = [list(g) for g in mi.consecutive_groups(iterable)] - expected = [[-10], [-8, -7, -6], [1, 2], [4, 5], [-1], [7]] - self.assertEqual(actual, expected) - - def test_custom_ordering(self): - iterable = ['1', '10', '11', '20', '21', '22', '30', '31'] - ordering = lambda x: int(x) - actual = [list(g) for g in mi.consecutive_groups(iterable, ordering)] - expected = [['1'], ['10', '11'], ['20', '21', '22'], ['30', '31']] - self.assertEqual(actual, expected) - - def test_exotic_ordering(self): - iterable = [ - ('a', 'b', 'c', 'd'), - ('a', 'c', 'b', 'd'), - ('a', 'c', 'd', 'b'), - ('a', 'd', 'b', 'c'), - ('d', 'b', 'c', 'a'), - ('d', 'c', 'a', 'b'), - ] - ordering = list(permutations('abcd')).index - actual = [list(g) for g in mi.consecutive_groups(iterable, ordering)] - expected = [ - [('a', 'b', 'c', 'd')], - [('a', 'c', 'b', 'd'), ('a', 'c', 'd', 'b'), ('a', 'd', 'b', 'c')], - [('d', 'b', 'c', 'a'), ('d', 'c', 'a', 'b')], - ] - self.assertEqual(actual, expected) - - -class DifferenceTest(TestCase): - def test_normal(self): - iterable = [10, 20, 30, 40, 50] - actual = list(mi.difference(iterable)) - expected = [10, 10, 10, 10, 10] - self.assertEqual(actual, expected) - - def test_custom(self): - iterable = [10, 20, 30, 40, 50] - actual = list(mi.difference(iterable, add)) - expected = [10, 30, 50, 70, 90] - self.assertEqual(actual, expected) - - def test_roundtrip(self): - original = list(range(100)) - accumulated = mi.accumulate(original) - actual = list(mi.difference(accumulated)) - self.assertEqual(actual, original) - - def test_one(self): - self.assertEqual(list(mi.difference([0])), [0]) - - def test_empty(self): - self.assertEqual(list(mi.difference([])), []) - - -class SeekableTest(TestCase): - def test_exhaustion_reset(self): - iterable = [str(n) for n in range(10)] - - s = mi.seekable(iterable) - self.assertEqual(list(s), iterable) # Normal iteration - self.assertEqual(list(s), []) # Iterable is exhausted - - s.seek(0) - self.assertEqual(list(s), iterable) # Back in action - - def test_partial_reset(self): - iterable = [str(n) for n in range(10)] - - s = mi.seekable(iterable) - self.assertEqual(mi.take(5, s), iterable[:5]) # Normal iteration - - s.seek(1) - self.assertEqual(list(s), iterable[1:]) # Get the rest of the iterable - - def test_forward(self): - iterable = [str(n) for n in range(10)] - - s = mi.seekable(iterable) - self.assertEqual(mi.take(1, s), iterable[:1]) # Normal iteration - - s.seek(3) # Skip over index 2 - self.assertEqual(list(s), iterable[3:]) # Result is similar to slicing - - s.seek(0) # Back to 0 - self.assertEqual(list(s), iterable) # No difference in result - - def test_past_end(self): - iterable = [str(n) for n in range(10)] - - s = mi.seekable(iterable) - self.assertEqual(mi.take(1, s), iterable[:1]) # Normal iteration - - s.seek(20) - self.assertEqual(list(s), []) # Iterable is exhausted - - s.seek(0) # Back to 0 - self.assertEqual(list(s), iterable) # No difference in result - - def test_elements(self): - iterable = map(str, count()) - - s = mi.seekable(iterable) - mi.take(10, s) - - elements = s.elements() - self.assertEqual( - [elements[i] for i in range(10)], [str(n) for n in range(10)] - ) - self.assertEqual(len(elements), 10) - - mi.take(10, s) - self.assertEqual(list(elements), [str(n) for n in range(20)]) - - -class SequenceViewTests(TestCase): - def test_init(self): - view = mi.SequenceView((1, 2, 3)) - self.assertEqual(repr(view), "SequenceView((1, 2, 3))") - self.assertRaises(TypeError, lambda: mi.SequenceView({})) - - def test_update(self): - seq = [1, 2, 3] - view = mi.SequenceView(seq) - self.assertEqual(len(view), 3) - self.assertEqual(repr(view), "SequenceView([1, 2, 3])") - - seq.pop() - self.assertEqual(len(view), 2) - self.assertEqual(repr(view), "SequenceView([1, 2])") - - def test_indexing(self): - seq = ('a', 'b', 'c', 'd', 'e', 'f') - view = mi.SequenceView(seq) - for i in range(-len(seq), len(seq)): - self.assertEqual(view[i], seq[i]) - - def test_slicing(self): - seq = ('a', 'b', 'c', 'd', 'e', 'f') - view = mi.SequenceView(seq) - n = len(seq) - indexes = list(range(-n - 1, n + 1)) + [None] - steps = list(range(-n, n + 1)) - steps.remove(0) - for slice_args in product(indexes, indexes, steps): - i = slice(*slice_args) - self.assertEqual(view[i], seq[i]) - - def test_abc_methods(self): - # collections.Sequence should provide all of this functionality - seq = ('a', 'b', 'c', 'd', 'e', 'f', 'f') - view = mi.SequenceView(seq) - - # __contains__ - self.assertIn('b', view) - self.assertNotIn('g', view) - - # __iter__ - self.assertEqual(list(iter(view)), list(seq)) - - # __reversed__ - self.assertEqual(list(reversed(view)), list(reversed(seq))) - - # index - self.assertEqual(view.index('b'), 1) - - # count - self.assertEqual(seq.count('f'), 2) - - -class RunLengthTest(TestCase): - def test_encode(self): - iterable = (int(str(n)[0]) for n in count(800)) - actual = mi.take(4, mi.run_length.encode(iterable)) - expected = [(8, 100), (9, 100), (1, 1000), (2, 1000)] - self.assertEqual(actual, expected) - - def test_decode(self): - iterable = [('d', 4), ('c', 3), ('b', 2), ('a', 1)] - actual = ''.join(mi.run_length.decode(iterable)) - expected = 'ddddcccbba' - self.assertEqual(actual, expected) - - -class ExactlyNTests(TestCase): - """Tests for ``exactly_n()``""" - - def test_true(self): - """Iterable has ``n`` ``True`` elements""" - self.assertTrue(mi.exactly_n([True, False, True], 2)) - self.assertTrue(mi.exactly_n([1, 1, 1, 0], 3)) - self.assertTrue(mi.exactly_n([False, False], 0)) - self.assertTrue(mi.exactly_n(range(100), 10, lambda x: x < 10)) - - def test_false(self): - """Iterable does not have ``n`` ``True`` elements""" - self.assertFalse(mi.exactly_n([True, False, False], 2)) - self.assertFalse(mi.exactly_n([True, True, False], 1)) - self.assertFalse(mi.exactly_n([False], 1)) - self.assertFalse(mi.exactly_n([True], -1)) - self.assertFalse(mi.exactly_n(repeat(True), 100)) - - def test_empty(self): - """Return ``True`` if the iterable is empty and ``n`` is 0""" - self.assertTrue(mi.exactly_n([], 0)) - self.assertFalse(mi.exactly_n([], 1)) - - -class AlwaysReversibleTests(TestCase): - """Tests for ``always_reversible()``""" - - def test_regular_reversed(self): - self.assertEqual(list(reversed(range(10))), - list(mi.always_reversible(range(10)))) - self.assertEqual(list(reversed([1, 2, 3])), - list(mi.always_reversible([1, 2, 3]))) - self.assertEqual(reversed([1, 2, 3]).__class__, - mi.always_reversible([1, 2, 3]).__class__) - - def test_nonseq_reversed(self): - # Create a non-reversible generator from a sequence - with self.assertRaises(TypeError): - reversed(x for x in range(10)) - - self.assertEqual(list(reversed(range(10))), - list(mi.always_reversible(x for x in range(10)))) - self.assertEqual(list(reversed([1, 2, 3])), - list(mi.always_reversible(x for x in [1, 2, 3]))) - self.assertNotEqual(reversed((1, 2)).__class__, - mi.always_reversible(x for x in (1, 2)).__class__) - - -class CircularShiftsTests(TestCase): - def test_empty(self): - # empty iterable -> empty list - self.assertEqual(list(mi.circular_shifts([])), []) - - def test_simple_circular_shifts(self): - # test the a simple iterator case - self.assertEqual( - mi.circular_shifts(range(4)), - [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)] - ) - - def test_duplicates(self): - # test non-distinct entries - self.assertEqual( - mi.circular_shifts([0, 1, 0, 1]), - [(0, 1, 0, 1), (1, 0, 1, 0), (0, 1, 0, 1), (1, 0, 1, 0)] - ) - - -class MakeDecoratorTests(TestCase): - def test_basic(self): - slicer = mi.make_decorator(islice) - - @slicer(1, 10, 2) - def user_function(arg_1, arg_2, kwarg_1=None): - self.assertEqual(arg_1, 'arg_1') - self.assertEqual(arg_2, 'arg_2') - self.assertEqual(kwarg_1, 'kwarg_1') - return map(str, count()) - - it = user_function('arg_1', 'arg_2', kwarg_1='kwarg_1') - actual = list(it) - expected = ['1', '3', '5', '7', '9'] - self.assertEqual(actual, expected) - - def test_result_index(self): - def stringify(*args, **kwargs): - self.assertEqual(args[0], 'arg_0') - iterable = args[1] - self.assertEqual(args[2], 'arg_2') - self.assertEqual(kwargs['kwarg_1'], 'kwarg_1') - return map(str, iterable) - - stringifier = mi.make_decorator(stringify, result_index=1) - - @stringifier('arg_0', 'arg_2', kwarg_1='kwarg_1') - def user_function(n): - return count(n) - - it = user_function(1) - actual = mi.take(5, it) - expected = ['1', '2', '3', '4', '5'] - self.assertEqual(actual, expected) - - def test_wrap_class(self): - seeker = mi.make_decorator(mi.seekable) - - @seeker() - def user_function(n): - return map(str, range(n)) - - it = user_function(5) - self.assertEqual(list(it), ['0', '1', '2', '3', '4']) - - it.seek(0) - self.assertEqual(list(it), ['0', '1', '2', '3', '4']) - - -class MapReduceTests(TestCase): - def test_default(self): - iterable = (str(x) for x in range(5)) - keyfunc = lambda x: int(x) // 2 - actual = sorted(mi.map_reduce(iterable, keyfunc).items()) - expected = [(0, ['0', '1']), (1, ['2', '3']), (2, ['4'])] - self.assertEqual(actual, expected) - - def test_valuefunc(self): - iterable = (str(x) for x in range(5)) - keyfunc = lambda x: int(x) // 2 - valuefunc = int - actual = sorted(mi.map_reduce(iterable, keyfunc, valuefunc).items()) - expected = [(0, [0, 1]), (1, [2, 3]), (2, [4])] - self.assertEqual(actual, expected) - - def test_reducefunc(self): - iterable = (str(x) for x in range(5)) - keyfunc = lambda x: int(x) // 2 - valuefunc = int - reducefunc = lambda value_list: reduce(mul, value_list, 1) - actual = sorted( - mi.map_reduce(iterable, keyfunc, valuefunc, reducefunc).items() - ) - expected = [(0, 0), (1, 6), (2, 4)] - self.assertEqual(actual, expected) - - def test_ret(self): - d = mi.map_reduce([1, 0, 2, 0, 1, 0], bool) - self.assertEqual(d, {False: [0, 0, 0], True: [1, 2, 1]}) - self.assertRaises(KeyError, lambda: d[None].append(1)) - - -class RlocateTests(TestCase): - def test_default_pred(self): - iterable = [0, 1, 1, 0, 1, 0, 0] - for it in (iterable[:], iter(iterable)): - actual = list(mi.rlocate(it)) - expected = [4, 2, 1] - self.assertEqual(actual, expected) - - def test_no_matches(self): - iterable = [0, 0, 0] - for it in (iterable[:], iter(iterable)): - actual = list(mi.rlocate(it)) - expected = [] - self.assertEqual(actual, expected) - - def test_custom_pred(self): - iterable = ['0', 1, 1, '0', 1, '0', '0'] - pred = lambda x: x == '0' - for it in (iterable[:], iter(iterable)): - actual = list(mi.rlocate(it, pred)) - expected = [6, 5, 3, 0] - self.assertEqual(actual, expected) - - def test_efficient_reversal(self): - iterable = range(9 ** 9) # Is efficiently reversible - target = 9 ** 9 - 2 - pred = lambda x: x == target # Find-able from the right - actual = next(mi.rlocate(iterable, pred)) - self.assertEqual(actual, target) - - def test_window_size(self): - iterable = ['0', 1, 1, '0', 1, '0', '0'] - pred = lambda *args: args == ('0', 1) - for it in (iterable, iter(iterable)): - actual = list(mi.rlocate(it, pred, window_size=2)) - expected = [3, 0] - self.assertEqual(actual, expected) - - def test_window_size_large(self): - iterable = [1, 2, 3, 4] - pred = lambda a, b, c, d, e: True - for it in (iterable, iter(iterable)): - actual = list(mi.rlocate(iterable, pred, window_size=5)) - expected = [0] - self.assertEqual(actual, expected) - - def test_window_size_zero(self): - iterable = [1, 2, 3, 4] - pred = lambda: True - for it in (iterable, iter(iterable)): - with self.assertRaises(ValueError): - list(mi.locate(iterable, pred, window_size=0)) - - -class ReplaceTests(TestCase): - def test_basic(self): - iterable = range(10) - pred = lambda x: x % 2 == 0 - substitutes = [] - actual = list(mi.replace(iterable, pred, substitutes)) - expected = [1, 3, 5, 7, 9] - self.assertEqual(actual, expected) - - def test_count(self): - iterable = range(10) - pred = lambda x: x % 2 == 0 - substitutes = [] - actual = list(mi.replace(iterable, pred, substitutes, count=4)) - expected = [1, 3, 5, 7, 8, 9] - self.assertEqual(actual, expected) - - def test_window_size(self): - iterable = range(10) - pred = lambda *args: args == (0, 1, 2) - substitutes = [] - actual = list(mi.replace(iterable, pred, substitutes, window_size=3)) - expected = [3, 4, 5, 6, 7, 8, 9] - self.assertEqual(actual, expected) - - def test_window_size_end(self): - iterable = range(10) - pred = lambda *args: args == (7, 8, 9) - substitutes = [] - actual = list(mi.replace(iterable, pred, substitutes, window_size=3)) - expected = [0, 1, 2, 3, 4, 5, 6] - self.assertEqual(actual, expected) - - def test_window_size_count(self): - iterable = range(10) - pred = lambda *args: (args == (0, 1, 2)) or (args == (7, 8, 9)) - substitutes = [] - actual = list( - mi.replace(iterable, pred, substitutes, count=1, window_size=3) - ) - expected = [3, 4, 5, 6, 7, 8, 9] - self.assertEqual(actual, expected) - - def test_window_size_large(self): - iterable = range(4) - pred = lambda a, b, c, d, e: True - substitutes = [5, 6, 7] - actual = list(mi.replace(iterable, pred, substitutes, window_size=5)) - expected = [5, 6, 7] - self.assertEqual(actual, expected) - - def test_window_size_zero(self): - iterable = range(10) - pred = lambda *args: True - substitutes = [] - with self.assertRaises(ValueError): - list(mi.replace(iterable, pred, substitutes, window_size=0)) - - def test_iterable_substitutes(self): - iterable = range(5) - pred = lambda x: x % 2 == 0 - substitutes = iter('__') - actual = list(mi.replace(iterable, pred, substitutes)) - expected = ['_', '_', 1, '_', '_', 3, '_', '_'] - self.assertEqual(actual, expected) diff --git a/lib/more_itertools/tests/test_recipes.py b/lib/more_itertools/tests/test_recipes.py deleted file mode 100644 index b3cfb62f..00000000 --- a/lib/more_itertools/tests/test_recipes.py +++ /dev/null @@ -1,616 +0,0 @@ -from doctest import DocTestSuite -from unittest import TestCase - -from itertools import combinations -from six.moves import range - -import more_itertools as mi - - -def load_tests(loader, tests, ignore): - # Add the doctests - tests.addTests(DocTestSuite('more_itertools.recipes')) - return tests - - -class AccumulateTests(TestCase): - """Tests for ``accumulate()``""" - - def test_empty(self): - """Test that an empty input returns an empty output""" - self.assertEqual(list(mi.accumulate([])), []) - - def test_default(self): - """Test accumulate with the default function (addition)""" - self.assertEqual(list(mi.accumulate([1, 2, 3])), [1, 3, 6]) - - def test_bogus_function(self): - """Test accumulate with an invalid function""" - with self.assertRaises(TypeError): - list(mi.accumulate([1, 2, 3], func=lambda x: x)) - - def test_custom_function(self): - """Test accumulate with a custom function""" - self.assertEqual( - list(mi.accumulate((1, 2, 3, 2, 1), func=max)), [1, 2, 3, 3, 3] - ) - - -class TakeTests(TestCase): - """Tests for ``take()``""" - - def test_simple_take(self): - """Test basic usage""" - t = mi.take(5, range(10)) - self.assertEqual(t, [0, 1, 2, 3, 4]) - - def test_null_take(self): - """Check the null case""" - t = mi.take(0, range(10)) - self.assertEqual(t, []) - - def test_negative_take(self): - """Make sure taking negative items results in a ValueError""" - self.assertRaises(ValueError, lambda: mi.take(-3, range(10))) - - def test_take_too_much(self): - """Taking more than an iterator has remaining should return what the - iterator has remaining. - - """ - t = mi.take(10, range(5)) - self.assertEqual(t, [0, 1, 2, 3, 4]) - - -class TabulateTests(TestCase): - """Tests for ``tabulate()``""" - - def test_simple_tabulate(self): - """Test the happy path""" - t = mi.tabulate(lambda x: x) - f = tuple([next(t) for _ in range(3)]) - self.assertEqual(f, (0, 1, 2)) - - def test_count(self): - """Ensure tabulate accepts specific count""" - t = mi.tabulate(lambda x: 2 * x, -1) - f = (next(t), next(t), next(t)) - self.assertEqual(f, (-2, 0, 2)) - - -class TailTests(TestCase): - """Tests for ``tail()``""" - - def test_greater(self): - """Length of iterable is greater than requested tail""" - self.assertEqual(list(mi.tail(3, 'ABCDEFG')), ['E', 'F', 'G']) - - def test_equal(self): - """Length of iterable is equal to the requested tail""" - self.assertEqual( - list(mi.tail(7, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G'] - ) - - def test_less(self): - """Length of iterable is less than requested tail""" - self.assertEqual( - list(mi.tail(8, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G'] - ) - - -class ConsumeTests(TestCase): - """Tests for ``consume()``""" - - def test_sanity(self): - """Test basic functionality""" - r = (x for x in range(10)) - mi.consume(r, 3) - self.assertEqual(3, next(r)) - - def test_null_consume(self): - """Check the null case""" - r = (x for x in range(10)) - mi.consume(r, 0) - self.assertEqual(0, next(r)) - - def test_negative_consume(self): - """Check that negative consumsion throws an error""" - r = (x for x in range(10)) - self.assertRaises(ValueError, lambda: mi.consume(r, -1)) - - def test_total_consume(self): - """Check that iterator is totally consumed by default""" - r = (x for x in range(10)) - mi.consume(r) - self.assertRaises(StopIteration, lambda: next(r)) - - -class NthTests(TestCase): - """Tests for ``nth()``""" - - def test_basic(self): - """Make sure the nth item is returned""" - l = range(10) - for i, v in enumerate(l): - self.assertEqual(mi.nth(l, i), v) - - def test_default(self): - """Ensure a default value is returned when nth item not found""" - l = range(3) - self.assertEqual(mi.nth(l, 100, "zebra"), "zebra") - - def test_negative_item_raises(self): - """Ensure asking for a negative item raises an exception""" - self.assertRaises(ValueError, lambda: mi.nth(range(10), -3)) - - -class AllEqualTests(TestCase): - """Tests for ``all_equal()``""" - - def test_true(self): - """Everything is equal""" - self.assertTrue(mi.all_equal('aaaaaa')) - self.assertTrue(mi.all_equal([0, 0, 0, 0])) - - def test_false(self): - """Not everything is equal""" - self.assertFalse(mi.all_equal('aaaaab')) - self.assertFalse(mi.all_equal([0, 0, 0, 1])) - - def test_tricky(self): - """Not everything is identical, but everything is equal""" - items = [1, complex(1, 0), 1.0] - self.assertTrue(mi.all_equal(items)) - - def test_empty(self): - """Return True if the iterable is empty""" - self.assertTrue(mi.all_equal('')) - self.assertTrue(mi.all_equal([])) - - def test_one(self): - """Return True if the iterable is singular""" - self.assertTrue(mi.all_equal('0')) - self.assertTrue(mi.all_equal([0])) - - -class QuantifyTests(TestCase): - """Tests for ``quantify()``""" - - def test_happy_path(self): - """Make sure True count is returned""" - q = [True, False, True] - self.assertEqual(mi.quantify(q), 2) - - def test_custom_predicate(self): - """Ensure non-default predicates return as expected""" - q = range(10) - self.assertEqual(mi.quantify(q, lambda x: x % 2 == 0), 5) - - -class PadnoneTests(TestCase): - """Tests for ``padnone()``""" - - def test_happy_path(self): - """wrapper iterator should return None indefinitely""" - r = range(2) - p = mi.padnone(r) - self.assertEqual([0, 1, None, None], [next(p) for _ in range(4)]) - - -class NcyclesTests(TestCase): - """Tests for ``nyclces()``""" - - def test_happy_path(self): - """cycle a sequence three times""" - r = ["a", "b", "c"] - n = mi.ncycles(r, 3) - self.assertEqual( - ["a", "b", "c", "a", "b", "c", "a", "b", "c"], - list(n) - ) - - def test_null_case(self): - """asking for 0 cycles should return an empty iterator""" - n = mi.ncycles(range(100), 0) - self.assertRaises(StopIteration, lambda: next(n)) - - def test_pathalogical_case(self): - """asking for negative cycles should return an empty iterator""" - n = mi.ncycles(range(100), -10) - self.assertRaises(StopIteration, lambda: next(n)) - - -class DotproductTests(TestCase): - """Tests for ``dotproduct()``'""" - - def test_happy_path(self): - """simple dotproduct example""" - self.assertEqual(400, mi.dotproduct([10, 10], [20, 20])) - - -class FlattenTests(TestCase): - """Tests for ``flatten()``""" - - def test_basic_usage(self): - """ensure list of lists is flattened one level""" - f = [[0, 1, 2], [3, 4, 5]] - self.assertEqual(list(range(6)), list(mi.flatten(f))) - - def test_single_level(self): - """ensure list of lists is flattened only one level""" - f = [[0, [1, 2]], [[3, 4], 5]] - self.assertEqual([0, [1, 2], [3, 4], 5], list(mi.flatten(f))) - - -class RepeatfuncTests(TestCase): - """Tests for ``repeatfunc()``""" - - def test_simple_repeat(self): - """test simple repeated functions""" - r = mi.repeatfunc(lambda: 5) - self.assertEqual([5, 5, 5, 5, 5], [next(r) for _ in range(5)]) - - def test_finite_repeat(self): - """ensure limited repeat when times is provided""" - r = mi.repeatfunc(lambda: 5, times=5) - self.assertEqual([5, 5, 5, 5, 5], list(r)) - - def test_added_arguments(self): - """ensure arguments are applied to the function""" - r = mi.repeatfunc(lambda x: x, 2, 3) - self.assertEqual([3, 3], list(r)) - - def test_null_times(self): - """repeat 0 should return an empty iterator""" - r = mi.repeatfunc(range, 0, 3) - self.assertRaises(StopIteration, lambda: next(r)) - - -class PairwiseTests(TestCase): - """Tests for ``pairwise()``""" - - def test_base_case(self): - """ensure an iterable will return pairwise""" - p = mi.pairwise([1, 2, 3]) - self.assertEqual([(1, 2), (2, 3)], list(p)) - - def test_short_case(self): - """ensure an empty iterator if there's not enough values to pair""" - p = mi.pairwise("a") - self.assertRaises(StopIteration, lambda: next(p)) - - -class GrouperTests(TestCase): - """Tests for ``grouper()``""" - - def test_even(self): - """Test when group size divides evenly into the length of - the iterable. - - """ - self.assertEqual( - list(mi.grouper(3, 'ABCDEF')), [('A', 'B', 'C'), ('D', 'E', 'F')] - ) - - def test_odd(self): - """Test when group size does not divide evenly into the length of the - iterable. - - """ - self.assertEqual( - list(mi.grouper(3, 'ABCDE')), [('A', 'B', 'C'), ('D', 'E', None)] - ) - - def test_fill_value(self): - """Test that the fill value is used to pad the final group""" - self.assertEqual( - list(mi.grouper(3, 'ABCDE', 'x')), - [('A', 'B', 'C'), ('D', 'E', 'x')] - ) - - -class RoundrobinTests(TestCase): - """Tests for ``roundrobin()``""" - - def test_even_groups(self): - """Ensure ordered output from evenly populated iterables""" - self.assertEqual( - list(mi.roundrobin('ABC', [1, 2, 3], range(3))), - ['A', 1, 0, 'B', 2, 1, 'C', 3, 2] - ) - - def test_uneven_groups(self): - """Ensure ordered output from unevenly populated iterables""" - self.assertEqual( - list(mi.roundrobin('ABCD', [1, 2], range(0))), - ['A', 1, 'B', 2, 'C', 'D'] - ) - - -class PartitionTests(TestCase): - """Tests for ``partition()``""" - - def test_bool(self): - """Test when pred() returns a boolean""" - lesser, greater = mi.partition(lambda x: x > 5, range(10)) - self.assertEqual(list(lesser), [0, 1, 2, 3, 4, 5]) - self.assertEqual(list(greater), [6, 7, 8, 9]) - - def test_arbitrary(self): - """Test when pred() returns an integer""" - divisibles, remainders = mi.partition(lambda x: x % 3, range(10)) - self.assertEqual(list(divisibles), [0, 3, 6, 9]) - self.assertEqual(list(remainders), [1, 2, 4, 5, 7, 8]) - - -class PowersetTests(TestCase): - """Tests for ``powerset()``""" - - def test_combinatorics(self): - """Ensure a proper enumeration""" - p = mi.powerset([1, 2, 3]) - self.assertEqual( - list(p), - [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)] - ) - - -class UniqueEverseenTests(TestCase): - """Tests for ``unique_everseen()``""" - - def test_everseen(self): - """ensure duplicate elements are ignored""" - u = mi.unique_everseen('AAAABBBBCCDAABBB') - self.assertEqual( - ['A', 'B', 'C', 'D'], - list(u) - ) - - def test_custom_key(self): - """ensure the custom key comparison works""" - u = mi.unique_everseen('aAbACCc', key=str.lower) - self.assertEqual(list('abC'), list(u)) - - def test_unhashable(self): - """ensure things work for unhashable items""" - iterable = ['a', [1, 2, 3], [1, 2, 3], 'a'] - u = mi.unique_everseen(iterable) - self.assertEqual(list(u), ['a', [1, 2, 3]]) - - def test_unhashable_key(self): - """ensure things work for unhashable items with a custom key""" - iterable = ['a', [1, 2, 3], [1, 2, 3], 'a'] - u = mi.unique_everseen(iterable, key=lambda x: x) - self.assertEqual(list(u), ['a', [1, 2, 3]]) - - -class UniqueJustseenTests(TestCase): - """Tests for ``unique_justseen()``""" - - def test_justseen(self): - """ensure only last item is remembered""" - u = mi.unique_justseen('AAAABBBCCDABB') - self.assertEqual(list('ABCDAB'), list(u)) - - def test_custom_key(self): - """ensure the custom key comparison works""" - u = mi.unique_justseen('AABCcAD', str.lower) - self.assertEqual(list('ABCAD'), list(u)) - - -class IterExceptTests(TestCase): - """Tests for ``iter_except()``""" - - def test_exact_exception(self): - """ensure the exact specified exception is caught""" - l = [1, 2, 3] - i = mi.iter_except(l.pop, IndexError) - self.assertEqual(list(i), [3, 2, 1]) - - def test_generic_exception(self): - """ensure the generic exception can be caught""" - l = [1, 2] - i = mi.iter_except(l.pop, Exception) - self.assertEqual(list(i), [2, 1]) - - def test_uncaught_exception_is_raised(self): - """ensure a non-specified exception is raised""" - l = [1, 2, 3] - i = mi.iter_except(l.pop, KeyError) - self.assertRaises(IndexError, lambda: list(i)) - - def test_first(self): - """ensure first is run before the function""" - l = [1, 2, 3] - f = lambda: 25 - i = mi.iter_except(l.pop, IndexError, f) - self.assertEqual(list(i), [25, 3, 2, 1]) - - -class FirstTrueTests(TestCase): - """Tests for ``first_true()``""" - - def test_something_true(self): - """Test with no keywords""" - self.assertEqual(mi.first_true(range(10)), 1) - - def test_nothing_true(self): - """Test default return value.""" - self.assertIsNone(mi.first_true([0, 0, 0])) - - def test_default(self): - """Test with a default keyword""" - self.assertEqual(mi.first_true([0, 0, 0], default='!'), '!') - - def test_pred(self): - """Test with a custom predicate""" - self.assertEqual( - mi.first_true([2, 4, 6], pred=lambda x: x % 3 == 0), 6 - ) - - -class RandomProductTests(TestCase): - """Tests for ``random_product()`` - - Since random.choice() has different results with the same seed across - python versions 2.x and 3.x, these tests use highly probably events to - create predictable outcomes across platforms. - """ - - def test_simple_lists(self): - """Ensure that one item is chosen from each list in each pair. - Also ensure that each item from each list eventually appears in - the chosen combinations. - - Odds are roughly 1 in 7.1 * 10e16 that one item from either list will - not be chosen after 100 samplings of one item from each list. Just to - be safe, better use a known random seed, too. - - """ - nums = [1, 2, 3] - lets = ['a', 'b', 'c'] - n, m = zip(*[mi.random_product(nums, lets) for _ in range(100)]) - n, m = set(n), set(m) - self.assertEqual(n, set(nums)) - self.assertEqual(m, set(lets)) - self.assertEqual(len(n), len(nums)) - self.assertEqual(len(m), len(lets)) - - def test_list_with_repeat(self): - """ensure multiple items are chosen, and that they appear to be chosen - from one list then the next, in proper order. - - """ - nums = [1, 2, 3] - lets = ['a', 'b', 'c'] - r = list(mi.random_product(nums, lets, repeat=100)) - self.assertEqual(2 * 100, len(r)) - n, m = set(r[::2]), set(r[1::2]) - self.assertEqual(n, set(nums)) - self.assertEqual(m, set(lets)) - self.assertEqual(len(n), len(nums)) - self.assertEqual(len(m), len(lets)) - - -class RandomPermutationTests(TestCase): - """Tests for ``random_permutation()``""" - - def test_full_permutation(self): - """ensure every item from the iterable is returned in a new ordering - - 15 elements have a 1 in 1.3 * 10e12 of appearing in sorted order, so - we fix a seed value just to be sure. - - """ - i = range(15) - r = mi.random_permutation(i) - self.assertEqual(set(i), set(r)) - if i == r: - raise AssertionError("Values were not permuted") - - def test_partial_permutation(self): - """ensure all returned items are from the iterable, that the returned - permutation is of the desired length, and that all items eventually - get returned. - - Sampling 100 permutations of length 5 from a set of 15 leaves a - (2/3)^100 chance that an item will not be chosen. Multiplied by 15 - items, there is a 1 in 2.6e16 chance that at least 1 item will not - show up in the resulting output. Using a random seed will fix that. - - """ - items = range(15) - item_set = set(items) - all_items = set() - for _ in range(100): - permutation = mi.random_permutation(items, 5) - self.assertEqual(len(permutation), 5) - permutation_set = set(permutation) - self.assertLessEqual(permutation_set, item_set) - all_items |= permutation_set - self.assertEqual(all_items, item_set) - - -class RandomCombinationTests(TestCase): - """Tests for ``random_combination()``""" - - def test_pseudorandomness(self): - """ensure different subsets of the iterable get returned over many - samplings of random combinations""" - items = range(15) - all_items = set() - for _ in range(50): - combination = mi.random_combination(items, 5) - all_items |= set(combination) - self.assertEqual(all_items, set(items)) - - def test_no_replacement(self): - """ensure that elements are sampled without replacement""" - items = range(15) - for _ in range(50): - combination = mi.random_combination(items, len(items)) - self.assertEqual(len(combination), len(set(combination))) - self.assertRaises( - ValueError, lambda: mi.random_combination(items, len(items) + 1) - ) - - -class RandomCombinationWithReplacementTests(TestCase): - """Tests for ``random_combination_with_replacement()``""" - - def test_replacement(self): - """ensure that elements are sampled with replacement""" - items = range(5) - combo = mi.random_combination_with_replacement(items, len(items) * 2) - self.assertEqual(2 * len(items), len(combo)) - if len(set(combo)) == len(combo): - raise AssertionError("Combination contained no duplicates") - - def test_pseudorandomness(self): - """ensure different subsets of the iterable get returned over many - samplings of random combinations""" - items = range(15) - all_items = set() - for _ in range(50): - combination = mi.random_combination_with_replacement(items, 5) - all_items |= set(combination) - self.assertEqual(all_items, set(items)) - - -class NthCombinationTests(TestCase): - def test_basic(self): - iterable = 'abcdefg' - r = 4 - for index, expected in enumerate(combinations(iterable, r)): - actual = mi.nth_combination(iterable, r, index) - self.assertEqual(actual, expected) - - def test_long(self): - actual = mi.nth_combination(range(180), 4, 2000000) - expected = (2, 12, 35, 126) - self.assertEqual(actual, expected) - - def test_invalid_r(self): - for r in (-1, 3): - with self.assertRaises(ValueError): - mi.nth_combination([], r, 0) - - def test_invalid_index(self): - with self.assertRaises(IndexError): - mi.nth_combination('abcdefg', 3, -36) - - -class PrependTests(TestCase): - def test_basic(self): - value = 'a' - iterator = iter('bcdefg') - actual = list(mi.prepend(value, iterator)) - expected = list('abcdefg') - self.assertEqual(actual, expected) - - def test_multiple(self): - value = 'ab' - iterator = iter('cdefg') - actual = tuple(mi.prepend(value, iterator)) - expected = ('ab',) + tuple('cdefg') - self.assertEqual(actual, expected)