diff --git a/libs/bin/enver.exe b/libs/bin/enver.exe new file mode 100644 index 00000000..19390fed Binary files /dev/null and b/libs/bin/enver.exe differ diff --git a/libs/bin/find-symlinks.exe b/libs/bin/find-symlinks.exe new file mode 100644 index 00000000..cf6d37a8 Binary files /dev/null and b/libs/bin/find-symlinks.exe differ diff --git a/libs/bin/gclip.exe b/libs/bin/gclip.exe new file mode 100644 index 00000000..e3b9a0a6 Binary files /dev/null and b/libs/bin/gclip.exe differ diff --git a/libs/bin/mklink.exe b/libs/bin/mklink.exe new file mode 100644 index 00000000..f467f718 Binary files /dev/null and b/libs/bin/mklink.exe differ diff --git a/libs/bin/pclip.exe b/libs/bin/pclip.exe new file mode 100644 index 00000000..fc8c364b Binary files /dev/null and b/libs/bin/pclip.exe differ diff --git a/libs/bin/xmouse.exe b/libs/bin/xmouse.exe new file mode 100644 index 00000000..b9d66505 Binary files /dev/null and b/libs/bin/xmouse.exe differ diff --git a/libs/importlib_metadata/__init__.py b/libs/importlib_metadata/__init__.py new file mode 100644 index 00000000..f594c6f7 --- /dev/null +++ b/libs/importlib_metadata/__init__.py @@ -0,0 +1,17 @@ +from .api import distribution, Distribution, PackageNotFoundError # noqa: F401 +from .api import metadata, entry_points, resolve, version, read_text + +# Import for installation side-effects. +from . 
import _hooks # noqa: F401 + + +__all__ = [ + 'metadata', + 'entry_points', + 'resolve', + 'version', + 'read_text', + ] + + +__version__ = version(__name__) diff --git a/libs/importlib_metadata/_hooks.py b/libs/importlib_metadata/_hooks.py new file mode 100644 index 00000000..1fd62698 --- /dev/null +++ b/libs/importlib_metadata/_hooks.py @@ -0,0 +1,148 @@ +from __future__ import unicode_literals, absolute_import + +import re +import sys +import itertools + +from .api import Distribution +from zipfile import ZipFile + +if sys.version_info >= (3,): # pragma: nocover + from contextlib import suppress + from pathlib import Path +else: # pragma: nocover + from contextlib2 import suppress # noqa + from itertools import imap as map # type: ignore + from pathlib2 import Path + + FileNotFoundError = IOError, OSError + __metaclass__ = type + + +def install(cls): + """Class decorator for installation on sys.meta_path.""" + sys.meta_path.append(cls) + return cls + + +class NullFinder: + @staticmethod + def find_spec(*args, **kwargs): + return None + + # In Python 2, the import system requires finders + # to have a find_module() method, but this usage + # is deprecated in Python 3 in favor of find_spec(). + # For the purposes of this finder (i.e. being present + # on sys.meta_path but having no other import + # system functionality), the two methods are identical. + find_module = find_spec + + +@install +class MetadataPathFinder(NullFinder): + """A degenerate finder for distribution packages on the file system. + + This finder supplies only a find_distribution() method for versions + of Python that do not have a PathFinder find_distribution(). + """ + search_template = r'{name}(-.*)?\.(dist|egg)-info' + + @classmethod + def find_distribution(cls, name): + paths = cls._search_paths(name) + dists = map(PathDistribution, paths) + return next(dists, None) + + @classmethod + def _search_paths(cls, name): + """ + Find metadata directories in sys.path heuristically. 
+ """ + return itertools.chain.from_iterable( + cls._search_path(path, name) + for path in map(Path, sys.path) + ) + + @classmethod + def _search_path(cls, root, name): + if not root.is_dir(): + return () + normalized = name.replace('-', '_') + return ( + item + for item in root.iterdir() + if item.is_dir() + and re.match( + cls.search_template.format(name=normalized), + str(item.name), + flags=re.IGNORECASE, + ) + ) + + +class PathDistribution(Distribution): + def __init__(self, path): + """Construct a distribution from a path to the metadata directory.""" + self._path = path + + def read_text(self, filename): + with suppress(FileNotFoundError): + with self._path.joinpath(filename).open(encoding='utf-8') as fp: + return fp.read() + return None + read_text.__doc__ = Distribution.read_text.__doc__ + + +@install +class WheelMetadataFinder(NullFinder): + """A degenerate finder for distribution packages in wheels. + + This finder supplies only a find_distribution() method for versions + of Python that do not have a PathFinder find_distribution(). 
+ """ + search_template = r'{name}(-.*)?\.whl' + + @classmethod + def find_distribution(cls, name): + paths = cls._search_paths(name) + dists = map(WheelDistribution, paths) + return next(dists, None) + + @classmethod + def _search_paths(cls, name): + return ( + item + for item in map(Path, sys.path) + if re.match( + cls.search_template.format(name=name), + str(item.name), + flags=re.IGNORECASE, + ) + ) + + +class WheelDistribution(Distribution): + def __init__(self, archive): + self._archive = archive + name, version = archive.name.split('-')[0:2] + self._dist_info = '{}-{}.dist-info'.format(name, version) + + def read_text(self, filename): + with ZipFile(_path_to_filename(self._archive)) as zf: + with suppress(KeyError): + as_bytes = zf.read('{}/{}'.format(self._dist_info, filename)) + return as_bytes.decode('utf-8') + return None + read_text.__doc__ = Distribution.read_text.__doc__ + + +def _path_to_filename(path): # pragma: nocover + """ + On non-compliant systems, ensure a path-like object is + a string. + """ + try: + return path.__fspath__() + except AttributeError: + return str(path) diff --git a/libs/importlib_metadata/api.py b/libs/importlib_metadata/api.py new file mode 100644 index 00000000..41942a39 --- /dev/null +++ b/libs/importlib_metadata/api.py @@ -0,0 +1,146 @@ +import io +import abc +import sys +import email + +from importlib import import_module + +if sys.version_info > (3,): # pragma: nocover + from configparser import ConfigParser +else: # pragma: nocover + from ConfigParser import SafeConfigParser as ConfigParser + +try: + BaseClass = ModuleNotFoundError +except NameError: # pragma: nocover + BaseClass = ImportError # type: ignore + + +__metaclass__ = type + + +class PackageNotFoundError(BaseClass): + """The package was not found.""" + + +class Distribution: + """A Python distribution package.""" + + @abc.abstractmethod + def read_text(self, filename): + """Attempt to load metadata file given by the name. 
+ + :param filename: The name of the file in the distribution info. + :return: The text if found, otherwise None. + """ + + @classmethod + def from_name(cls, name): + """Return the Distribution for the given package name. + + :param name: The name of the distribution package to search for. + :return: The Distribution instance (or subclass thereof) for the named + package, if found. + :raises PackageNotFoundError: When the named package's distribution + metadata cannot be found. + """ + for resolver in cls._discover_resolvers(): + resolved = resolver(name) + if resolved is not None: + return resolved + else: + raise PackageNotFoundError(name) + + @staticmethod + def _discover_resolvers(): + """Search the meta_path for resolvers.""" + declared = ( + getattr(finder, 'find_distribution', None) + for finder in sys.meta_path + ) + return filter(None, declared) + + @property + def metadata(self): + """Return the parsed metadata for this Distribution. + + The returned object will have keys that name the various bits of + metadata. See PEP 566 for details. + """ + return email.message_from_string( + self.read_text('METADATA') or self.read_text('PKG-INFO') + ) + + @property + def version(self): + """Return the 'Version' metadata for the distribution package.""" + return self.metadata['Version'] + + +def distribution(package): + """Get the ``Distribution`` instance for the given package. + + :param package: The name of the package as a string. + :return: A ``Distribution`` instance (or subclass thereof). + """ + return Distribution.from_name(package) + + +def metadata(package): + """Get the metadata for the package. + + :param package: The name of the distribution package to query. + :return: An email.Message containing the parsed metadata. + """ + return Distribution.from_name(package).metadata + + +def version(package): + """Get the version string for the named package. + + :param package: The name of the distribution package to query. 
+ :return: The version string for the package as defined in the package's + "Version" metadata key. + """ + return distribution(package).version + + +def entry_points(name): + """Return the entry points for the named distribution package. + + :param name: The name of the distribution package to query. + :return: A ConfigParser instance where the sections and keys are taken + from the entry_points.txt ini-style contents. + """ + as_string = read_text(name, 'entry_points.txt') + # 2018-09-10(barry): Should we provide any options here, or let the caller + # send options to the underlying ConfigParser? For now, YAGNI. + config = ConfigParser() + try: + config.read_string(as_string) + except AttributeError: # pragma: nocover + # Python 2 has no read_string + config.readfp(io.StringIO(as_string)) + return config + + +def resolve(entry_point): + """Resolve an entry point string into the named callable. + + :param entry_point: An entry point string of the form + `path.to.module:callable`. + :return: The actual callable object `path.to.module.callable` + :raises ValueError: When `entry_point` doesn't have the proper format. + """ + path, colon, name = entry_point.rpartition(':') + if colon != ':': + raise ValueError('Not an entry point: {}'.format(entry_point)) + module = import_module(path) + return getattr(module, name) + + +def read_text(package, filename): + """ + Read the text of the file in the distribution info directory. 
+ """ + return distribution(package).read_text(filename) diff --git a/libs/importlib_metadata/docs/__init__.py b/libs/importlib_metadata/docs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/importlib_metadata/docs/changelog.rst b/libs/importlib_metadata/docs/changelog.rst new file mode 100644 index 00000000..f8f1fedc --- /dev/null +++ b/libs/importlib_metadata/docs/changelog.rst @@ -0,0 +1,57 @@ +========================= + importlib_metadata NEWS +========================= + +0.7 (2018-11-27) +================ +* Fixed issue where packages with dashes in their names would + not be discovered. Closes #21. +* Distribution lookup is now case-insensitive. Closes #20. +* Wheel distributions can no longer be discovered by their module + name. Like Path distributions, they must be indicated by their + distribution package name. + +0.6 (2018-10-07) +================ +* Removed ``importlib_metadata.distribution`` function. Now + the public interface is primarily the utility functions exposed + in ``importlib_metadata.__all__``. Closes #14. +* Added two new utility functions ``read_text`` and + ``metadata``. + +0.5 (2018-09-18) +================ +* Updated README and removed details about Distribution + class, now considered private. Closes #15. +* Added test suite support for Python 3.4+. +* Fixed SyntaxErrors on Python 3.4 and 3.5. !12 +* Fixed errors on Windows joining Path elements. !15 + +0.4 (2018-09-14) +================ +* Housekeeping. + +0.3 (2018-09-14) +================ +* Added usage documentation. Closes #8 +* Add support for getting metadata from wheels on ``sys.path``. Closes #9 + +0.2 (2018-09-11) +================ +* Added ``importlib_metadata.entry_points()``. Closes #1 +* Added ``importlib_metadata.resolve()``. Closes #12 +* Add support for Python 2.7. Closes #4 + +0.1 (2018-09-10) +================ +* Initial release. + + +.. 
+ Local Variables: + mode: change-log-mode + indent-tabs-mode: nil + sentence-end-double-space: t + fill-column: 78 + coding: utf-8 + End: diff --git a/libs/importlib_metadata/docs/conf.py b/libs/importlib_metadata/docs/conf.py new file mode 100644 index 00000000..c87fc4f2 --- /dev/null +++ b/libs/importlib_metadata/docs/conf.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# flake8: noqa +# +# importlib_metadata documentation build configuration file, created by +# sphinx-quickstart on Thu Nov 30 10:21:00 2017. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +# import os +# import sys +# sys.path.insert(0, os.path.abspath('.')) + + +# -- General configuration ------------------------------------------------ + +# If your documentation needs a minimal Sphinx version, state it here. +# +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = ['sphinx.ext.autodoc', + 'sphinx.ext.doctest', + 'sphinx.ext.intersphinx', + 'sphinx.ext.coverage', + 'sphinx.ext.viewcode'] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +# source_suffix = ['.rst', '.md'] +source_suffix = '.rst' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. 
+project = 'importlib_metadata' +copyright = '2017-2018, Jason Coombs, Barry Warsaw' +author = 'Jason Coombs, Barry Warsaw' + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +version = '0.1' +# The full version, including alpha/beta/rc tags. +release = '0.1' + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = False + + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'default' + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +# html_theme_options = {} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# Custom sidebar templates, must be a dictionary that maps document names +# to template names. 
+# +# This is required for the alabaster theme +# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars +html_sidebars = { + '**': [ + 'relations.html', # needs 'show_related': True theme option to display + 'searchbox.html', + ] +} + + +# -- Options for HTMLHelp output ------------------------------------------ + +# Output file base name for HTML help builder. +htmlhelp_basename = 'importlib_metadatadoc' + + +# -- Options for LaTeX output --------------------------------------------- + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, 'importlib_metadata.tex', 'importlib\\_metadata Documentation', + 'Brett Cannon, Barry Warsaw', 'manual'), +] + + +# -- Options for manual page output --------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + (master_doc, 'importlib_metadata', 'importlib_metadata Documentation', + [author], 1) +] + + +# -- Options for Texinfo output ------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + (master_doc, 'importlib_metadata', 'importlib_metadata Documentation', + author, 'importlib_metadata', 'One line description of project.', + 'Miscellaneous'), +] + + + + +# Example configuration for intersphinx: refer to the Python standard library. 
+intersphinx_mapping = { + 'python': ('https://docs.python.org/3', None), + } diff --git a/libs/importlib_metadata/docs/index.rst b/libs/importlib_metadata/docs/index.rst new file mode 100644 index 00000000..21da1ed6 --- /dev/null +++ b/libs/importlib_metadata/docs/index.rst @@ -0,0 +1,53 @@ +=============================== + Welcome to importlib_metadata +=============================== + +``importlib_metadata`` is a library which provides an API for accessing an +installed package's `metadata`_, such as its entry points or its top-level +name. This functionality intends to replace most uses of ``pkg_resources`` +`entry point API`_ and `metadata API`_. Along with ``importlib.resources`` in +`Python 3.7 and newer`_ (backported as `importlib_resources`_ for older +versions of Python), this can eliminate the need to use the older and less +efficient ``pkg_resources`` package. + +``importlib_metadata`` is a backport of Python 3.8's standard library +`importlib.metadata`_ module for Python 2.7, and 3.4 through 3.7. Users of +Python 3.8 and beyond are encouraged to use the standard library module, and +in fact for these versions, ``importlib_metadata`` just shadows that module. +Developers looking for detailed API descriptions should refer to the Python +3.8 standard library documentation. + +The documentation here includes a general :ref:`usage ` guide. + + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + using.rst + changelog.rst + + +Project details +=============== + + * Project home: https://gitlab.com/python-devs/importlib_metadata + * Report bugs at: https://gitlab.com/python-devs/importlib_metadata/issues + * Code hosting: https://gitlab.com/python-devs/importlib_metadata.git + * Documentation: http://importlib_metadata.readthedocs.io/ + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` + + +.. _`metadata`: https://www.python.org/dev/peps/pep-0566/ +.. 
_`entry point API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#entry-points +.. _`metadata API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#metadata-api +.. _`Python 3.7 and newer`: https://docs.python.org/3/library/importlib.html#module-importlib.resources +.. _`importlib_resources`: https://importlib-resources.readthedocs.io/en/latest/index.html +.. _`importlib.metadata`: TBD diff --git a/libs/importlib_metadata/docs/using.rst b/libs/importlib_metadata/docs/using.rst new file mode 100644 index 00000000..2af6c822 --- /dev/null +++ b/libs/importlib_metadata/docs/using.rst @@ -0,0 +1,133 @@ +.. _using: + +========================== + Using importlib_metadata +========================== + +``importlib_metadata`` is a library that provides for access to installed +package metadata. Built in part on Python's import system, this library +intends to replace similar functionality in ``pkg_resources`` `entry point +API`_ and `metadata API`_. Along with ``importlib.resources`` in `Python 3.7 +and newer`_ (backported as `importlib_resources`_ for older versions of +Python), this can eliminate the need to use the older and less efficient +``pkg_resources`` package. + +By "installed package" we generally mean a third party package installed into +Python's ``site-packages`` directory via tools such as ``pip``. Specifically, +it means a package with either a discoverable ``dist-info`` or ``egg-info`` +directory, and metadata defined by `PEP 566`_ or its older specifications. +By default, package metadata can live on the file system or in wheels on +``sys.path``. Through an extension mechanism, the metadata can live almost +anywhere. + + +Overview +======== + +Let's say you wanted to get the version string for a package you've installed +using ``pip``. 
We start by creating a virtual environment and installing +something into it:: + + $ python3 -m venv example + $ source example/bin/activate + (example) $ pip install importlib_metadata + (example) $ pip install wheel + +You can get the version string for ``wheel`` by running the following:: + + (example) $ python + >>> from importlib_metadata import version + >>> version('wheel') + '0.31.1' + +You can also get the set of entry points for the ``wheel`` package. Since the +``entry_points.txt`` file is an ``.ini``-style, the ``entry_points()`` +function returns a `ConfigParser instance`_. To get the list of command line +entry points, extract the ``console_scripts`` section:: + + >>> cp = entry_points('wheel') + >>> cp.options('console_scripts') + ['wheel'] + +You can also get the callable that the entry point is mapped to:: + + >>> cp.get('console_scripts', 'wheel') + 'wheel.tool:main' + +Even more conveniently, you can resolve this entry point to the actual +callable:: + + >>> from importlib_metadata import resolve + >>> ep = cp.get('console_scripts', 'wheel') + >>> resolve(ep) + + + +Distributions +============= + +While the above API is the most common and convenient usage, you can get all +of that information from the ``Distribution`` class. A ``Distribution`` is an +abstract object that represents the metadata for a Python package. You can +get the ``Distribution`` instance:: + + >>> from importlib_metadata import distribution + >>> dist = distribution('wheel') + +Thus, an alternative way to get the version number is through the +``Distribution`` instance:: + + >>> dist.version + '0.31.1' + +There are all kinds of additional metadata available on the ``Distribution`` +instance:: + + >>> d.metadata['Requires-Python'] + '>=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*' + >>> d.metadata['License'] + 'MIT' + +The full set of available metadata is not described here. See PEP 566 for +additional details. 
+ + +Extending the search algorithm +============================== + +Because package metadata is not available through ``sys.path`` searches, or +package loaders directly, the metadata for a package is found through import +system `finders`_. To find a distribution package's metadata, +``importlib_metadata`` queries the list of `meta path finders`_ on +`sys.meta_path`_. + +By default ``importlib_metadata`` installs a finder for packages found on the +file system. This finder doesn't actually find any *packages*, but it cany +find the package's metadata. + +The abstract class :py:class:`importlib.abc.MetaPathFinder` defines the +interface expected of finders by Python's import system. +``importlib_metadata`` extends this protocol by looking for an optional +``find_distribution()`` ``@classmethod`` on the finders from +``sys.meta_path``. If the finder has this method, it takes a single argument +which is the name of the distribution package to find. The method returns +``None`` if it cannot find the distribution package, otherwise it returns an +instance of the ``Distribution`` abstract class. + +What this means in practice is that to support finding distribution package +metadata in locations other than the file system, you should derive from +``Distribution`` and implement the ``load_metadata()`` method. This takes a +single argument which is the name of the package whose metadata is being +found. This instance of the ``Distribution`` base abstract class is what your +finder's ``find_distribution()`` method should return. + + +.. _`entry point API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#entry-points +.. _`metadata API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#metadata-api +.. _`Python 3.7 and newer`: https://docs.python.org/3/library/importlib.html#module-importlib.resources +.. _`importlib_resources`: https://importlib-resources.readthedocs.io/en/latest/index.html +.. 
_`PEP 566`: https://www.python.org/dev/peps/pep-0566/ +.. _`ConfigParser instance`: https://docs.python.org/3/library/configparser.html#configparser.ConfigParser +.. _`finders`: https://docs.python.org/3/reference/import.html#finders-and-loaders +.. _`meta path finders`: https://docs.python.org/3/glossary.html#term-meta-path-finder +.. _`sys.meta_path`: https://docs.python.org/3/library/sys.html#sys.meta_path diff --git a/libs/importlib_metadata/tests/__init__.py b/libs/importlib_metadata/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/importlib_metadata/tests/data/__init__.py b/libs/importlib_metadata/tests/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/importlib_metadata/tests/test_api.py b/libs/importlib_metadata/tests/test_api.py new file mode 100644 index 00000000..82c61f51 --- /dev/null +++ b/libs/importlib_metadata/tests/test_api.py @@ -0,0 +1,44 @@ +import re +import unittest + +import importlib_metadata + + +class APITests(unittest.TestCase): + version_pattern = r'\d+\.\d+(\.\d)?' + + def test_retrieves_version_of_self(self): + version = importlib_metadata.version('importlib_metadata') + assert isinstance(version, str) + assert re.match(self.version_pattern, version) + + def test_retrieves_version_of_pip(self): + # Assume pip is installed and retrieve the version of pip. 
+ version = importlib_metadata.version('pip') + assert isinstance(version, str) + assert re.match(self.version_pattern, version) + + def test_for_name_does_not_exist(self): + with self.assertRaises(importlib_metadata.PackageNotFoundError): + importlib_metadata.distribution('does-not-exist') + + def test_for_top_level(self): + distribution = importlib_metadata.distribution('importlib_metadata') + self.assertEqual( + distribution.read_text('top_level.txt').strip(), + 'importlib_metadata') + + def test_entry_points(self): + parser = importlib_metadata.entry_points('pip') + # We should probably not be dependent on a third party package's + # internal API staying stable. + entry_point = parser.get('console_scripts', 'pip') + self.assertEqual(entry_point, 'pip._internal:main') + + def test_metadata_for_this_package(self): + md = importlib_metadata.metadata('importlib_metadata') + assert md['author'] == 'Barry Warsaw' + assert md['LICENSE'] == 'Apache Software License' + assert md['Name'] == 'importlib-metadata' + classifiers = md.get_all('Classifier') + assert 'Topic :: Software Development :: Libraries' in classifiers diff --git a/libs/importlib_metadata/tests/test_main.py b/libs/importlib_metadata/tests/test_main.py new file mode 100644 index 00000000..381e4dae --- /dev/null +++ b/libs/importlib_metadata/tests/test_main.py @@ -0,0 +1,121 @@ +from __future__ import unicode_literals + +import re +import sys +import shutil +import tempfile +import unittest +import importlib +import contextlib +import importlib_metadata + +try: + from contextlib import ExitStack +except ImportError: + from contextlib2 import ExitStack + +try: + import pathlib +except ImportError: + import pathlib2 as pathlib + +from importlib_metadata import _hooks + + +class BasicTests(unittest.TestCase): + version_pattern = r'\d+\.\d+(\.\d)?' + + def test_retrieves_version_of_pip(self): + # Assume pip is installed and retrieve the version of pip. 
+ dist = importlib_metadata.Distribution.from_name('pip') + assert isinstance(dist.version, str) + assert re.match(self.version_pattern, dist.version) + + def test_for_name_does_not_exist(self): + with self.assertRaises(importlib_metadata.PackageNotFoundError): + importlib_metadata.Distribution.from_name('does-not-exist') + + def test_new_style_classes(self): + self.assertIsInstance(importlib_metadata.Distribution, type) + self.assertIsInstance(_hooks.MetadataPathFinder, type) + self.assertIsInstance(_hooks.WheelMetadataFinder, type) + self.assertIsInstance(_hooks.WheelDistribution, type) + + +class ImportTests(unittest.TestCase): + def test_import_nonexistent_module(self): + # Ensure that the MetadataPathFinder does not crash an import of a + # non-existant module. + with self.assertRaises(ImportError): + importlib.import_module('does_not_exist') + + def test_resolve(self): + entry_points = importlib_metadata.entry_points('pip') + main = importlib_metadata.resolve( + entry_points.get('console_scripts', 'pip')) + import pip._internal + self.assertEqual(main, pip._internal.main) + + def test_resolve_invalid(self): + self.assertRaises(ValueError, importlib_metadata.resolve, 'bogus.ep') + + +class NameNormalizationTests(unittest.TestCase): + @staticmethod + def pkg_with_dashes(site_dir): + """ + Create minimal metadata for a package with dashes + in the name (and thus underscores in the filename). 
+ """ + metadata_dir = site_dir / 'my_pkg.dist-info' + metadata_dir.mkdir() + metadata = metadata_dir / 'METADATA' + with metadata.open('w') as strm: + strm.write('Version: 1.0\n') + return 'my-pkg' + + @staticmethod + @contextlib.contextmanager + def site_dir(): + tmpdir = tempfile.mkdtemp() + sys.path[:0] = [tmpdir] + try: + yield pathlib.Path(tmpdir) + finally: + sys.path.remove(tmpdir) + shutil.rmtree(tmpdir) + + def setUp(self): + self.fixtures = ExitStack() + self.addCleanup(self.fixtures.close) + self.site_dir = self.fixtures.enter_context(self.site_dir()) + + def test_dashes_in_dist_name_found_as_underscores(self): + """ + For a package with a dash in the name, the dist-info metadata + uses underscores in the name. Ensure the metadata loads. + """ + pkg_name = self.pkg_with_dashes(self.site_dir) + assert importlib_metadata.version(pkg_name) == '1.0' + + @staticmethod + def pkg_with_mixed_case(site_dir): + """ + Create minimal metadata for a package with mixed case + in the name. + """ + metadata_dir = site_dir / 'CherryPy.dist-info' + metadata_dir.mkdir() + metadata = metadata_dir / 'METADATA' + with metadata.open('w') as strm: + strm.write('Version: 1.0\n') + return 'CherryPy' + + def test_dist_name_found_as_any_case(self): + """ + Ensure the metadata loads when queried with any case. 
+ """ + pkg_name = self.pkg_with_mixed_case(self.site_dir) + assert importlib_metadata.version(pkg_name) == '1.0' + assert importlib_metadata.version(pkg_name.lower()) == '1.0' + assert importlib_metadata.version(pkg_name.upper()) == '1.0' diff --git a/libs/importlib_metadata/tests/test_zip.py b/libs/importlib_metadata/tests/test_zip.py new file mode 100644 index 00000000..7bdf55a9 --- /dev/null +++ b/libs/importlib_metadata/tests/test_zip.py @@ -0,0 +1,42 @@ +import sys +import unittest +import importlib_metadata + +try: + from contextlib import ExitStack +except ImportError: + from contextlib2 import ExitStack + +from importlib_resources import path + + +class BespokeLoader: + archive = 'bespoke' + + +class TestZip(unittest.TestCase): + def setUp(self): + # Find the path to the example.*.whl so we can add it to the front of + # sys.path, where we'll then try to find the metadata thereof. + self.resources = ExitStack() + self.addCleanup(self.resources.close) + wheel = self.resources.enter_context( + path('importlib_metadata.tests.data', + 'example-21.12-py3-none-any.whl')) + sys.path.insert(0, str(wheel)) + self.resources.callback(sys.path.pop, 0) + + def test_zip_version(self): + self.assertEqual(importlib_metadata.version('example'), '21.12') + + def test_zip_entry_points(self): + parser = importlib_metadata.entry_points('example') + entry_point = parser.get('console_scripts', 'example') + self.assertEqual(entry_point, 'example:main') + + def test_missing_metadata(self): + distribution = importlib_metadata.distribution('example') + self.assertIsNone(distribution.read_text('does not exist')) + + def test_case_insensitive(self): + self.assertEqual(importlib_metadata.version('Example'), '21.12') diff --git a/libs/importlib_metadata/version.txt b/libs/importlib_metadata/version.txt new file mode 100644 index 00000000..eb49d7c7 --- /dev/null +++ b/libs/importlib_metadata/version.txt @@ -0,0 +1 @@ +0.7 diff --git a/libs/jaraco.classes-1.5-py3.6-nspkg.pth 
b/libs/jaraco.classes-1.5-py3.6-nspkg.pth new file mode 100644 index 00000000..61cb14f9 --- /dev/null +++ b/libs/jaraco.classes-1.5-py3.6-nspkg.pth @@ -0,0 +1 @@ +import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p) diff --git a/libs/jaraco.collections-1.3.2-py3.5-nspkg.pth b/libs/jaraco.collections-1.3.2-py3.5-nspkg.pth deleted file mode 100644 index c8127a57..00000000 --- a/libs/jaraco.collections-1.3.2-py3.5-nspkg.pth +++ /dev/null @@ -1 +0,0 @@ -import sys, types, os;p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));ie = os.path.exists(os.path.join(p,'__init__.py'));m = not ie and sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p) diff --git a/libs/jaraco.collections-1.6.0-py3.7-nspkg.pth b/libs/jaraco.collections-1.6.0-py3.7-nspkg.pth new file mode 100644 index 00000000..61cb14f9 --- /dev/null +++ b/libs/jaraco.collections-1.6.0-py3.7-nspkg.pth @@ -0,0 +1 @@ +import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in 
mp) and mp.append(p) diff --git a/libs/jaraco.functools-1.11-py2.7-nspkg.pth b/libs/jaraco.functools-1.11-py2.7-nspkg.pth deleted file mode 100644 index c8127a57..00000000 --- a/libs/jaraco.functools-1.11-py2.7-nspkg.pth +++ /dev/null @@ -1 +0,0 @@ -import sys, types, os;p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));ie = os.path.exists(os.path.join(p,'__init__.py'));m = not ie and sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p) diff --git a/libs/jaraco.functools-1.20-py3.6-nspkg.pth b/libs/jaraco.functools-1.20-py3.6-nspkg.pth new file mode 100644 index 00000000..61cb14f9 --- /dev/null +++ b/libs/jaraco.functools-1.20-py3.6-nspkg.pth @@ -0,0 +1 @@ +import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p) diff --git a/libs/jaraco.structures-1.1.2-py3.6-nspkg.pth b/libs/jaraco.structures-1.1.2-py3.6-nspkg.pth new file mode 100644 index 00000000..61cb14f9 --- /dev/null +++ b/libs/jaraco.structures-1.1.2-py3.6-nspkg.pth @@ -0,0 +1 @@ +import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', 
types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p) diff --git a/libs/jaraco.text-1.10.1-py3.6-nspkg.pth b/libs/jaraco.text-1.10.1-py3.6-nspkg.pth new file mode 100644 index 00000000..61cb14f9 --- /dev/null +++ b/libs/jaraco.text-1.10.1-py3.6-nspkg.pth @@ -0,0 +1 @@ +import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p) diff --git a/libs/jaraco.text-1.7-py3.5-nspkg.pth b/libs/jaraco.text-1.7-py3.5-nspkg.pth deleted file mode 100644 index c8127a57..00000000 --- a/libs/jaraco.text-1.7-py3.5-nspkg.pth +++ /dev/null @@ -1 +0,0 @@ -import sys, types, os;p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));ie = os.path.exists(os.path.join(p,'__init__.py'));m = not ie and sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p) diff --git a/libs/jaraco.ui-1.6-py3.6-nspkg.pth b/libs/jaraco.ui-1.6-py3.6-nspkg.pth new file mode 100644 index 00000000..61cb14f9 --- /dev/null +++ b/libs/jaraco.ui-1.6-py3.6-nspkg.pth @@ -0,0 +1 @@ +import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or 
sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p) diff --git a/libs/jaraco.windows-3.6-py3.5-nspkg.pth b/libs/jaraco.windows-3.6-py3.5-nspkg.pth deleted file mode 100644 index c8127a57..00000000 --- a/libs/jaraco.windows-3.6-py3.5-nspkg.pth +++ /dev/null @@ -1 +0,0 @@ -import sys, types, os;p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));ie = os.path.exists(os.path.join(p,'__init__.py'));m = not ie and sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p) diff --git a/libs/jaraco.windows-3.9.2-py3.7-nspkg.pth b/libs/jaraco.windows-3.9.2-py3.7-nspkg.pth new file mode 100644 index 00000000..61cb14f9 --- /dev/null +++ b/libs/jaraco.windows-3.9.2-py3.7-nspkg.pth @@ -0,0 +1 @@ +import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p) diff --git a/libs/jaraco/__init__.py b/libs/jaraco/__init__.py deleted file mode 100644 index 5284146e..00000000 --- a/libs/jaraco/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__import__("pkg_resources").declare_namespace(__name__) diff --git a/libs/jaraco/classes/ancestry.py b/libs/jaraco/classes/ancestry.py index 905c18fd..040ce612 100644 --- a/libs/jaraco/classes/ancestry.py +++ b/libs/jaraco/classes/ancestry.py @@ -5,6 +5,7 @@ of an object and its parent classes. 
from __future__ import unicode_literals + def all_bases(c): """ return a tuple of all base classes the class c has as a parent. @@ -13,6 +14,7 @@ def all_bases(c): """ return c.mro()[1:] + def all_classes(c): """ return a tuple of all classes to which c belongs @@ -21,7 +23,10 @@ def all_classes(c): """ return c.mro() -# borrowed from http://code.activestate.com/recipes/576949-find-all-subclasses-of-a-given-class/ +# borrowed from +# http://code.activestate.com/recipes/576949-find-all-subclasses-of-a-given-class/ + + def iter_subclasses(cls, _seen=None): """ Generator over all subclasses of a given class, in depth-first order. @@ -51,9 +56,12 @@ def iter_subclasses(cls, _seen=None): """ if not isinstance(cls, type): - raise TypeError('iter_subclasses must be called with ' - 'new-style classes, not %.100r' % cls) - if _seen is None: _seen = set() + raise TypeError( + 'iter_subclasses must be called with ' + 'new-style classes, not %.100r' % cls + ) + if _seen is None: + _seen = set() try: subs = cls.__subclasses__() except TypeError: # fails only when cls is type diff --git a/libs/jaraco/classes/meta.py b/libs/jaraco/classes/meta.py index cdb744d7..c26f7dc2 100644 --- a/libs/jaraco/classes/meta.py +++ b/libs/jaraco/classes/meta.py @@ -6,6 +6,7 @@ Some useful metaclasses. from __future__ import unicode_literals + class LeafClassesMeta(type): """ A metaclass for classes that keeps track of all of them that diff --git a/libs/jaraco/classes/properties.py b/libs/jaraco/classes/properties.py index d64262a3..57f9054f 100644 --- a/libs/jaraco/classes/properties.py +++ b/libs/jaraco/classes/properties.py @@ -2,8 +2,10 @@ from __future__ import unicode_literals import six +__metaclass__ = type -class NonDataProperty(object): + +class NonDataProperty: """Much like the property builtin, but only implements __get__, making it a non-data property, and can be subsequently reset. 
@@ -34,7 +36,7 @@ class NonDataProperty(object): # from http://stackoverflow.com/a/5191224 -class ClassPropertyDescriptor(object): +class ClassPropertyDescriptor: def __init__(self, fget, fset=None): self.fget = fget diff --git a/libs/jaraco/collections.py b/libs/jaraco/collections.py index 6af6ad45..bb463deb 100644 --- a/libs/jaraco/collections.py +++ b/libs/jaraco/collections.py @@ -7,12 +7,63 @@ import operator import collections import itertools import copy +import functools + +try: + import collections.abc +except ImportError: + # Python 2.7 + collections.abc = collections import six from jaraco.classes.properties import NonDataProperty import jaraco.text +class Projection(collections.abc.Mapping): + """ + Project a set of keys over a mapping + + >>> sample = {'a': 1, 'b': 2, 'c': 3} + >>> prj = Projection(['a', 'c', 'd'], sample) + >>> prj == {'a': 1, 'c': 3} + True + + Keys should only appear if they were specified and exist in the space. + + >>> sorted(list(prj.keys())) + ['a', 'c'] + + Use the projection to update another dict. + + >>> target = {'a': 2, 'b': 2} + >>> target.update(prj) + >>> target == {'a': 1, 'b': 2, 'c': 3} + True + + Also note that Projection keeps a reference to the original dict, so + if you modify the original dict, that could modify the Projection. + + >>> del sample['a'] + >>> dict(prj) + {'c': 3} + """ + def __init__(self, keys, space): + self._keys = tuple(keys) + self._space = space + + def __getitem__(self, key): + if key not in self._keys: + raise KeyError(key) + return self._space[key] + + def __iter__(self): + return iter(set(self._keys).intersection(self._space)) + + def __len__(self): + return len(tuple(iter(self))) + + class DictFilter(object): """ Takes a dict, and simulates a sub-dict based on the keys. 
@@ -52,7 +103,6 @@ class DictFilter(object): self.pattern_keys = set() def get_pattern_keys(self): - #key_matches = lambda k, v: self.include_pattern.match(k) keys = filter(self.include_pattern.match, self.dict.keys()) return set(keys) pattern_keys = NonDataProperty(get_pattern_keys) @@ -70,7 +120,7 @@ class DictFilter(object): return values def __getitem__(self, i): - if not i in self.include_keys: + if i not in self.include_keys: return KeyError, i return self.dict[i] @@ -162,7 +212,7 @@ class RangeMap(dict): >>> r.get(7, 'not found') 'not found' """ - def __init__(self, source, sort_params = {}, key_match_comparator = operator.le): + def __init__(self, source, sort_params={}, key_match_comparator=operator.le): dict.__init__(self, source) self.sort_params = sort_params self.match = key_match_comparator @@ -190,7 +240,7 @@ class RangeMap(dict): return default def _find_first_match_(self, keys, item): - is_match = lambda k: self.match(item, k) + is_match = functools.partial(self.match, item) matches = list(filter(is_match, keys)) if matches: return matches[0] @@ -205,12 +255,15 @@ class RangeMap(dict): # some special values for the RangeMap undefined_value = type(str('RangeValueUndefined'), (object,), {})() - class Item(int): pass + + class Item(int): + "RangeMap Item" first_item = Item(0) last_item = Item(-1) -__identity = lambda x: x +def __identity(x): + return x def sorted_items(d, key=__identity, reverse=False): @@ -229,7 +282,8 @@ def sorted_items(d, key=__identity, reverse=False): (('foo', 20), ('baz', 10), ('bar', 42)) """ # wrap the key func so it operates on the first element of each item - pairkey_key = lambda item: key(item[0]) + def pairkey_key(item): + return key(item[0]) return sorted(d.items(), key=pairkey_key, reverse=reverse) @@ -414,7 +468,11 @@ class ItemsAsAttributes(object): It also works on dicts that customize __getitem__ >>> missing_func = lambda self, key: 'missing item' - >>> C = type(str('C'), (dict, ItemsAsAttributes), dict(__missing__ 
= missing_func)) + >>> C = type( + ... str('C'), + ... (dict, ItemsAsAttributes), + ... dict(__missing__ = missing_func), + ... ) >>> i = C() >>> i.missing 'missing item' @@ -428,6 +486,7 @@ class ItemsAsAttributes(object): # attempt to get the value from the mapping (return self[key]) # but be careful not to lose the original exception context. noval = object() + def _safe_getitem(cont, key, missing_result): try: return cont[key] @@ -460,7 +519,7 @@ def invert_map(map): ... ValueError: Key conflict in inverted mapping """ - res = dict((v,k) for k, v in map.items()) + res = dict((v, k) for k, v in map.items()) if not len(res) == len(map): raise ValueError('Key conflict in inverted mapping') return res @@ -483,7 +542,7 @@ class IdentityOverrideMap(dict): return key -class DictStack(list, collections.Mapping): +class DictStack(list, collections.abc.Mapping): """ A stack of dictionaries that behaves as a view on those dictionaries, giving preference to the last. @@ -506,6 +565,7 @@ class DictStack(list, collections.Mapping): >>> d = stack.pop() >>> stack['a'] 1 + >>> stack.get('b', None) """ def keys(self): @@ -513,7 +573,8 @@ class DictStack(list, collections.Mapping): def __getitem__(self, key): for scope in reversed(self): - if key in scope: return scope[key] + if key in scope: + return scope[key] raise KeyError(key) push = list.append @@ -553,6 +614,10 @@ class BijectiveMap(dict): Traceback (most recent call last): ValueError: Key/Value pairs may not overlap + >>> m['e'] = 'd' + Traceback (most recent call last): + ValueError: Key/Value pairs may not overlap + >>> print(m.pop('d')) c @@ -583,7 +648,12 @@ class BijectiveMap(dict): def __setitem__(self, item, value): if item == value: raise ValueError("Key cannot map to itself") - if (value in self or item in self) and self[item] != value: + overlap = ( + item in self and self[item] != value + or + value in self and self[value] != item + ) + if overlap: raise ValueError("Key/Value pairs may not overlap") 
super(BijectiveMap, self).__setitem__(item, value) super(BijectiveMap, self).__setitem__(value, item) @@ -607,7 +677,7 @@ class BijectiveMap(dict): self.__setitem__(*item) -class FrozenDict(collections.Mapping, collections.Hashable): +class FrozenDict(collections.abc.Mapping, collections.abc.Hashable): """ An immutable mapping. @@ -641,8 +711,8 @@ class FrozenDict(collections.Mapping, collections.Hashable): >>> isinstance(copy.copy(a), FrozenDict) True - FrozenDict supplies .copy(), even though collections.Mapping doesn't - demand it. + FrozenDict supplies .copy(), even though + collections.abc.Mapping doesn't demand it. >>> a.copy() == a True @@ -747,6 +817,9 @@ class Everything(object): >>> import random >>> random.randint(1, 999) in Everything() True + + >>> random.choice([None, 'foo', 42, ('a', 'b', 'c')]) in Everything() + True """ def __contains__(self, other): return True @@ -771,3 +844,63 @@ class InstrumentedDict(six.moves.UserDict): def __init__(self, data): six.moves.UserDict.__init__(self) self.data = data + + +class Least(object): + """ + A value that is always lesser than any other + + >>> least = Least() + >>> 3 < least + False + >>> 3 > least + True + >>> least < 3 + True + >>> least <= 3 + True + >>> least > 3 + False + >>> 'x' > least + True + >>> None > least + True + """ + + def __le__(self, other): + return True + __lt__ = __le__ + + def __ge__(self, other): + return False + __gt__ = __ge__ + + +class Greatest(object): + """ + A value that is always greater than any other + + >>> greatest = Greatest() + >>> 3 < greatest + True + >>> 3 > greatest + False + >>> greatest < 3 + False + >>> greatest > 3 + True + >>> greatest >= 3 + True + >>> 'x' > greatest + False + >>> None > greatest + False + """ + + def __ge__(self, other): + return True + __gt__ = __ge__ + + def __le__(self, other): + return False + __lt__ = __le__ diff --git a/libs/jaraco/functools.py b/libs/jaraco/functools.py index d9ccf3a6..134102a7 100644 --- a/libs/jaraco/functools.py 
+++ b/libs/jaraco/functools.py @@ -1,8 +1,16 @@ -from __future__ import absolute_import, unicode_literals, print_function, division +from __future__ import ( + absolute_import, unicode_literals, print_function, division, +) import functools import time import warnings +import inspect +import collections +from itertools import count + +__metaclass__ = type + try: from functools import lru_cache @@ -16,13 +24,17 @@ except ImportError: warnings.warn("No lru_cache available") +import more_itertools.recipes + + def compose(*funcs): """ Compose any number of unary functions into a single unary function. >>> import textwrap >>> from six import text_type - >>> text_type.strip(textwrap.dedent(compose.__doc__)) == compose(text_type.strip, textwrap.dedent)(compose.__doc__) + >>> stripped = text_type.strip(textwrap.dedent(compose.__doc__)) + >>> compose(text_type.strip, textwrap.dedent)(compose.__doc__) == stripped True Compose also allows the innermost function to take arbitrary arguments. @@ -33,7 +45,8 @@ def compose(*funcs): [1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7] """ - compose_two = lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs)) + def compose_two(f1, f2): + return lambda *args, **kwargs: f1(f2(*args, **kwargs)) return functools.reduce(compose_two, funcs) @@ -60,19 +73,36 @@ def once(func): This decorator can ensure that an expensive or non-idempotent function will not be expensive on subsequent calls and is idempotent. - >>> func = once(lambda a: a+3) - >>> func(3) + >>> add_three = once(lambda a: a+3) + >>> add_three(3) 6 - >>> func(9) + >>> add_three(9) 6 - >>> func('12') + >>> add_three('12') 6 + + To reset the stored value, simply clear the property ``saved_result``. + + >>> del add_three.saved_result + >>> add_three(9) + 12 + >>> add_three(8) + 12 + + Or invoke 'reset()' on it. 
+ + >>> add_three.reset() + >>> add_three(-3) + 0 + >>> add_three(0) + 0 """ @functools.wraps(func) def wrapper(*args, **kwargs): - if not hasattr(func, 'always_returns'): - func.always_returns = func(*args, **kwargs) - return func.always_returns + if not hasattr(wrapper, 'saved_result'): + wrapper.saved_result = func(*args, **kwargs) + return wrapper.saved_result + wrapper.reset = lambda: vars(wrapper).__delitem__('saved_result') return wrapper @@ -131,17 +161,22 @@ def method_cache(method, cache_wrapper=None): >>> a.method2() 3 + Caution - do not subsequently wrap the method with another decorator, such + as ``@property``, which changes the semantics of the function. + See also http://code.activestate.com/recipes/577452-a-memoize-decorator-for-instance-methods/ for another implementation and additional justification. """ cache_wrapper = cache_wrapper or lru_cache() + def wrapper(self, *args, **kwargs): # it's the first call, replace the method with a cached, bound method bound_method = functools.partial(method, self) cached_method = cache_wrapper(bound_method) setattr(self, method.__name__, cached_method) return cached_method(*args, **kwargs) + return _special_method_cache(method, cache_wrapper) or wrapper @@ -191,6 +226,29 @@ def apply(transform): return wrap +def result_invoke(action): + r""" + Decorate a function with an action function that is + invoked on the results returned from the decorated + function (for its side-effect), then return the original + result. + + >>> @result_invoke(print) + ... def add_two(a, b): + ... return a + b + >>> x = add_two(2, 3) + 5 + """ + def wrap(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + action(result) + return result + return wrapper + return wrap + + def call_aside(f, *args, **kwargs): """ Call a function for its side effect after initialization. 
@@ -211,7 +269,7 @@ def call_aside(f, *args, **kwargs): return f -class Throttler(object): +class Throttler: """ Rate-limit a function (or other callable) """ @@ -259,10 +317,143 @@ def retry_call(func, cleanup=lambda: None, retries=0, trap=()): exception. On the final attempt, allow any exceptions to propagate. """ - for attempt in range(retries): + attempts = count() if retries == float('inf') else range(retries) + for attempt in attempts: try: return func() except trap: cleanup() return func() + + +def retry(*r_args, **r_kwargs): + """ + Decorator wrapper for retry_call. Accepts arguments to retry_call + except func and then returns a decorator for the decorated function. + + Ex: + + >>> @retry(retries=3) + ... def my_func(a, b): + ... "this is my funk" + ... print(a, b) + >>> my_func.__doc__ + 'this is my funk' + """ + def decorate(func): + @functools.wraps(func) + def wrapper(*f_args, **f_kwargs): + bound = functools.partial(func, *f_args, **f_kwargs) + return retry_call(bound, *r_args, **r_kwargs) + return wrapper + return decorate + + +def print_yielded(func): + """ + Convert a generator into a function that prints all yielded elements + + >>> @print_yielded + ... def x(): + ... yield 3; yield None + >>> x() + 3 + None + """ + print_all = functools.partial(map, print) + print_results = compose(more_itertools.recipes.consume, print_all, func) + return functools.wraps(func)(print_results) + + +def pass_none(func): + """ + Wrap func so it's not called if its first param is None + + >>> print_text = pass_none(print) + >>> print_text('text') + text + >>> print_text(None) + """ + @functools.wraps(func) + def wrapper(param, *args, **kwargs): + if param is not None: + return func(param, *args, **kwargs) + return wrapper + + +def assign_params(func, namespace): + """ + Assign parameters from namespace where func solicits. + + >>> def func(x, y=3): + ... 
print(x, y) + >>> assigned = assign_params(func, dict(x=2, z=4)) + >>> assigned() + 2 3 + + The usual errors are raised if a function doesn't receive + its required parameters: + + >>> assigned = assign_params(func, dict(y=3, z=4)) + >>> assigned() + Traceback (most recent call last): + TypeError: func() ...argument... + """ + try: + sig = inspect.signature(func) + params = sig.parameters.keys() + except AttributeError: + spec = inspect.getargspec(func) + params = spec.args + call_ns = { + k: namespace[k] + for k in params + if k in namespace + } + return functools.partial(func, **call_ns) + + +def save_method_args(method): + """ + Wrap a method such that when it is called, the args and kwargs are + saved on the method. + + >>> class MyClass: + ... @save_method_args + ... def method(self, a, b): + ... print(a, b) + >>> my_ob = MyClass() + >>> my_ob.method(1, 2) + 1 2 + >>> my_ob._saved_method.args + (1, 2) + >>> my_ob._saved_method.kwargs + {} + >>> my_ob.method(a=3, b='foo') + 3 foo + >>> my_ob._saved_method.args + () + >>> my_ob._saved_method.kwargs == dict(a=3, b='foo') + True + + The arguments are stored on the instance, allowing for + different instance to save different args. 
+ + >>> your_ob = MyClass() + >>> your_ob.method({str('x'): 3}, b=[4]) + {'x': 3} [4] + >>> your_ob._saved_method.args + ({'x': 3},) + >>> my_ob._saved_method.args + () + """ + args_and_kwargs = collections.namedtuple('args_and_kwargs', 'args kwargs') + + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + attr_name = '_saved_' + method.__name__ + attr = args_and_kwargs(args, kwargs) + setattr(self, attr_name, attr) + return method(self, *args, **kwargs) + return wrapper diff --git a/libs/jaraco/structures/binary.py b/libs/jaraco/structures/binary.py index e4db2c65..be57cc76 100644 --- a/libs/jaraco/structures/binary.py +++ b/libs/jaraco/structures/binary.py @@ -1,5 +1,6 @@ from __future__ import absolute_import, unicode_literals +import numbers from functools import reduce @@ -25,6 +26,7 @@ def get_bit_values(number, size=32): number += 2**size return list(map(int, bin(number)[-size:])) + def gen_bit_values(number): """ Return a zero or one for each bit of a numeric value up to the most @@ -36,6 +38,7 @@ def gen_bit_values(number): digits = bin(number)[2:] return map(int, reversed(digits)) + def coalesce(bits): """ Take a sequence of bits, most significant first, and @@ -47,6 +50,7 @@ def coalesce(bits): operation = lambda a, b: (a << 1 | b) return reduce(operation, bits) + class Flags(object): """ Subclasses should define _names, a list of flag names beginning @@ -96,6 +100,7 @@ class Flags(object): index = self._names.index(key) return self._values[index] + class BitMask(type): """ A metaclass to create a bitmask with attributes. Subclass an int and @@ -119,12 +124,28 @@ class BitMask(type): >>> b2 = MyBits(8) >>> any([b2.a, b2.b, b2.c]) False + + If the instance defines methods, they won't be wrapped in + properties. 
+ + >>> ns['get_value'] = classmethod(lambda cls: 'some value') + >>> ns['prop'] = property(lambda self: 'a property') + >>> MyBits = BitMask(str('MyBits'), (int,), ns) + + >>> MyBits(3).get_value() + 'some value' + >>> MyBits(3).prop + 'a property' """ def __new__(cls, name, bases, attrs): + def make_property(name, value): + if name.startswith('_') or not isinstance(value, numbers.Number): + return value + return property(lambda self, value=value: bool(self & value)) + newattrs = dict( - (attr, property(lambda self, value=value: bool(self & value))) - for attr, value in attrs.items() - if not attr.startswith('_') + (name, make_property(name, value)) + for name, value in attrs.items() ) return type.__new__(cls, name, bases, newattrs) diff --git a/libs/jaraco/text.py b/libs/jaraco/text.py index c459e6e0..71b4b0bc 100644 --- a/libs/jaraco/text.py +++ b/libs/jaraco/text.py @@ -39,6 +39,7 @@ class FoldedCase(six.text_type): """ A case insensitive string class; behaves just like str except compares equal when the only variation is case. + >>> s = FoldedCase('hello world') >>> s == 'Hello World' @@ -47,6 +48,9 @@ class FoldedCase(six.text_type): >>> 'Hello World' == s True + >>> s != 'Hello World' + False + >>> s.index('O') 4 @@ -55,6 +59,38 @@ class FoldedCase(six.text_type): >>> sorted(map(FoldedCase, ['GAMMA', 'alpha', 'Beta'])) ['alpha', 'Beta', 'GAMMA'] + + Sequence membership is straightforward. + + >>> "Hello World" in [s] + True + >>> s in ["Hello World"] + True + + You may test for set inclusion, but candidate and elements + must both be folded. + + >>> FoldedCase("Hello World") in {s} + True + >>> s in {FoldedCase("Hello World")} + True + + String inclusion works as long as the FoldedCase object + is on the right. 
+ + >>> "hello" in FoldedCase("Hello World") + True + + But not if the FoldedCase object is on the left: + + >>> FoldedCase('hello') in 'Hello World' + False + + In that case, use in_: + + >>> FoldedCase('hello').in_('Hello World') + True + """ def __lt__(self, other): return self.lower() < other.lower() @@ -65,14 +101,23 @@ class FoldedCase(six.text_type): def __eq__(self, other): return self.lower() == other.lower() + def __ne__(self, other): + return self.lower() != other.lower() + def __hash__(self): return hash(self.lower()) + def __contains__(self, other): + return super(FoldedCase, self).lower().__contains__(other.lower()) + + def in_(self, other): + "Does self appear in other?" + return self in FoldedCase(other) + # cache lower since it's likely to be called frequently. + @jaraco.functools.method_cache def lower(self): - self._lower = super(FoldedCase, self).lower() - self.lower = lambda: self._lower - return self._lower + return super(FoldedCase, self).lower() def index(self, sub): return self.lower().index(sub.lower()) @@ -147,6 +192,7 @@ def is_decodable(value): return False return True + def is_binary(value): """ Return True if the value appears to be binary (that is, it's a byte @@ -154,6 +200,7 @@ def is_binary(value): """ return isinstance(value, bytes) and not is_decodable(value) + def trim(s): r""" Trim something like a docstring to remove the whitespace that @@ -164,8 +211,10 @@ def trim(s): """ return textwrap.dedent(s).strip() + class Splitter(object): """object that will split a string with the given arguments for each call + >>> s = Splitter(',') >>> s('hello, world, this is your, master calling') ['hello', ' world', ' this is your', ' master calling'] @@ -176,9 +225,11 @@ class Splitter(object): def __call__(self, s): return s.split(*self.args) + def indent(string, prefix=' ' * 4): return prefix + string + class WordSet(tuple): """ Given a Python identifier, return the words that identifier represents, @@ -269,6 +320,7 @@ class 
WordSet(tuple): def from_class_name(cls, subject): return cls.parse(subject.__class__.__name__) + # for backward compatibility words = WordSet.parse @@ -318,6 +370,7 @@ class SeparatedValues(six.text_type): parts = self.split(self.separator) return six.moves.filter(None, (part.strip() for part in parts)) + class Stripper: r""" Given a series of lines, find the common prefix and strip it from them. @@ -369,3 +422,31 @@ class Stripper: while s1[:index] != s2[:index]: index -= 1 return s1[:index] + + +def remove_prefix(text, prefix): + """ + Remove the prefix from the text if it exists. + + >>> remove_prefix('underwhelming performance', 'underwhelming ') + 'performance' + + >>> remove_prefix('something special', 'sample') + 'something special' + """ + null, prefix, rest = text.rpartition(prefix) + return rest + + +def remove_suffix(text, suffix): + """ + Remove the suffix from the text if it exists. + + >>> remove_suffix('name.git', '.git') + 'name' + + >>> remove_suffix('something special', 'sample') + 'something special' + """ + rest, suffix, null = text.partition(suffix) + return rest diff --git a/libs/jaraco/ui/cmdline.py b/libs/jaraco/ui/cmdline.py index 0634f21d..a7982ddb 100644 --- a/libs/jaraco/ui/cmdline.py +++ b/libs/jaraco/ui/cmdline.py @@ -60,3 +60,18 @@ class Command(object): cls.add_subparsers(parser) args = parser.parse_args() args.action.run(args) + + +class Extend(argparse.Action): + """ + Argparse action to take an nargs=* argument + and add any values to the existing value. 
+ + >>> parser = argparse.ArgumentParser() + >>> _ = parser.add_argument('--foo', nargs='*', default=[], action=Extend) + >>> args = parser.parse_args(['--foo', 'a=1', '--foo', 'b=2', 'c=3']) + >>> args.foo + ['a=1', 'b=2', 'c=3'] + """ + def __call__(self, parser, namespace, values, option_string=None): + getattr(namespace, self.dest).extend(values) diff --git a/libs/jaraco/ui/progress.py b/libs/jaraco/ui/progress.py index a00adf47..d083310b 100644 --- a/libs/jaraco/ui/progress.py +++ b/libs/jaraco/ui/progress.py @@ -1,3 +1,5 @@ +# deprecated -- use TQDM + from __future__ import (print_function, absolute_import, unicode_literals, division) diff --git a/libs/jaraco/windows/api/clipboard.py b/libs/jaraco/windows/api/clipboard.py index ac6e4dd8..d871aaa9 100644 --- a/libs/jaraco/windows/api/clipboard.py +++ b/libs/jaraco/windows/api/clipboard.py @@ -45,3 +45,9 @@ GetClipboardData.restype = ctypes.wintypes.HANDLE SetClipboardData = ctypes.windll.user32.SetClipboardData SetClipboardData.argtypes = ctypes.wintypes.UINT, ctypes.wintypes.HANDLE SetClipboardData.restype = ctypes.wintypes.HANDLE + +OpenClipboard = ctypes.windll.user32.OpenClipboard +OpenClipboard.argtypes = ctypes.wintypes.HANDLE, +OpenClipboard.restype = ctypes.wintypes.BOOL + +ctypes.windll.user32.CloseClipboard.restype = ctypes.wintypes.BOOL diff --git a/libs/jaraco/windows/api/credential.py b/libs/jaraco/windows/api/credential.py index c8bf1399..003c3cb3 100644 --- a/libs/jaraco/windows/api/credential.py +++ b/libs/jaraco/windows/api/credential.py @@ -10,9 +10,11 @@ try: except ImportError: LPBYTE = ctypes.POINTER(ctypes.wintypes.BYTE) + class CredentialAttribute(ctypes.Structure): _fields_ = [] + class Credential(ctypes.Structure): _fields_ = [ ('flags', DWORD), @@ -32,28 +34,29 @@ class Credential(ctypes.Structure): def __del__(self): ctypes.windll.advapi32.CredFree(ctypes.byref(self)) + PCREDENTIAL = ctypes.POINTER(Credential) CredRead = ctypes.windll.advapi32.CredReadW CredRead.argtypes = ( - 
LPCWSTR, # TargetName - DWORD, # Type - DWORD, # Flags - ctypes.POINTER(PCREDENTIAL), # Credential + LPCWSTR, # TargetName + DWORD, # Type + DWORD, # Flags + ctypes.POINTER(PCREDENTIAL), # Credential ) CredRead.restype = BOOL CredWrite = ctypes.windll.advapi32.CredWriteW CredWrite.argtypes = ( - PCREDENTIAL, # Credential - DWORD, # Flags + PCREDENTIAL, # Credential + DWORD, # Flags ) CredWrite.restype = BOOL CredDelete = ctypes.windll.advapi32.CredDeleteW CredDelete.argtypes = ( - LPCWSTR, # TargetName - DWORD, # Type - DWORD, # Flags + LPCWSTR, # TargetName + DWORD, # Type + DWORD, # Flags ) CredDelete.restype = BOOL diff --git a/libs/jaraco/windows/api/environ.py b/libs/jaraco/windows/api/environ.py index b4fb3e41..f394da02 100644 --- a/libs/jaraco/windows/api/environ.py +++ b/libs/jaraco/windows/api/environ.py @@ -2,7 +2,7 @@ import ctypes.wintypes SetEnvironmentVariable = ctypes.windll.kernel32.SetEnvironmentVariableW SetEnvironmentVariable.restype = ctypes.wintypes.BOOL -SetEnvironmentVariable.argtypes = [ctypes.wintypes.LPCWSTR]*2 +SetEnvironmentVariable.argtypes = [ctypes.wintypes.LPCWSTR] * 2 GetEnvironmentVariable = ctypes.windll.kernel32.GetEnvironmentVariableW GetEnvironmentVariable.restype = ctypes.wintypes.BOOL diff --git a/libs/jaraco/windows/api/event.py b/libs/jaraco/windows/api/event.py index 4b141a31..5d2818c6 100644 --- a/libs/jaraco/windows/api/event.py +++ b/libs/jaraco/windows/api/event.py @@ -1,16 +1,11 @@ -from ctypes import ( - Structure, windll, POINTER, byref, cast, create_unicode_buffer, - c_size_t, c_int, create_string_buffer, c_uint64, c_ushort, c_short, - c_uint, - ) +from ctypes import windll, POINTER from ctypes.wintypes import ( - BOOLEAN, LPWSTR, DWORD, LPVOID, HANDLE, FILETIME, - WCHAR, BOOL, HWND, WORD, UINT, - ) + LPWSTR, DWORD, LPVOID, HANDLE, BOOL, +) CreateEvent = windll.kernel32.CreateEventW CreateEvent.argtypes = ( - LPVOID, # LPSECURITY_ATTRIBUTES + LPVOID, # LPSECURITY_ATTRIBUTES BOOL, BOOL, LPWSTR, @@ -29,13 +24,15 @@ 
_WaitForMultipleObjects = windll.kernel32.WaitForMultipleObjects _WaitForMultipleObjects.argtypes = DWORD, POINTER(HANDLE), BOOL, DWORD _WaitForMultipleObjects.restype = DWORD + def WaitForMultipleObjects(handles, wait_all=False, timeout=0): n_handles = len(handles) - handle_array = (HANDLE*n_handles)() + handle_array = (HANDLE * n_handles)() for index, handle in enumerate(handles): handle_array[index] = handle return _WaitForMultipleObjects(n_handles, handle_array, wait_all, timeout) + WAIT_OBJECT_0 = 0 INFINITE = -1 WAIT_TIMEOUT = 0x102 diff --git a/libs/jaraco/windows/api/filesystem.py b/libs/jaraco/windows/api/filesystem.py index 5b3cdbb8..fbd999de 100644 --- a/libs/jaraco/windows/api/filesystem.py +++ b/libs/jaraco/windows/api/filesystem.py @@ -5,15 +5,15 @@ CreateSymbolicLink.argtypes = ( ctypes.wintypes.LPWSTR, ctypes.wintypes.LPWSTR, ctypes.wintypes.DWORD, - ) +) CreateSymbolicLink.restype = ctypes.wintypes.BOOLEAN CreateHardLink = ctypes.windll.kernel32.CreateHardLinkW CreateHardLink.argtypes = ( ctypes.wintypes.LPWSTR, ctypes.wintypes.LPWSTR, - ctypes.wintypes.LPVOID, # reserved for LPSECURITY_ATTRIBUTES - ) + ctypes.wintypes.LPVOID, # reserved for LPSECURITY_ATTRIBUTES +) CreateHardLink.restype = ctypes.wintypes.BOOLEAN GetFileAttributes = ctypes.windll.kernel32.GetFileAttributesW @@ -28,16 +28,20 @@ MAX_PATH = 260 GetFinalPathNameByHandle = ctypes.windll.kernel32.GetFinalPathNameByHandleW GetFinalPathNameByHandle.argtypes = ( - ctypes.wintypes.HANDLE, ctypes.wintypes.LPWSTR, ctypes.wintypes.DWORD, ctypes.wintypes.DWORD, - ) + ctypes.wintypes.HANDLE, ctypes.wintypes.LPWSTR, ctypes.wintypes.DWORD, + ctypes.wintypes.DWORD, +) GetFinalPathNameByHandle.restype = ctypes.wintypes.DWORD + class SECURITY_ATTRIBUTES(ctypes.Structure): _fields_ = ( ('length', ctypes.wintypes.DWORD), ('p_security_descriptor', ctypes.wintypes.LPVOID), ('inherit_handle', ctypes.wintypes.BOOLEAN), - ) + ) + + LPSECURITY_ATTRIBUTES = ctypes.POINTER(SECURITY_ATTRIBUTES) CreateFile = 
ctypes.windll.kernel32.CreateFileW @@ -49,7 +53,7 @@ CreateFile.argtypes = ( ctypes.wintypes.DWORD, ctypes.wintypes.DWORD, ctypes.wintypes.HANDLE, - ) +) CreateFile.restype = ctypes.wintypes.HANDLE FILE_SHARE_READ = 1 FILE_SHARE_WRITE = 2 @@ -77,23 +81,61 @@ CloseHandle = ctypes.windll.kernel32.CloseHandle CloseHandle.argtypes = (ctypes.wintypes.HANDLE,) CloseHandle.restype = ctypes.wintypes.BOOLEAN -class WIN32_FIND_DATA(ctypes.Structure): + +class WIN32_FIND_DATA(ctypes.wintypes.WIN32_FIND_DATAW): + """ _fields_ = [ - ('file_attributes', ctypes.wintypes.DWORD), - ('creation_time', ctypes.wintypes.FILETIME), - ('last_access_time', ctypes.wintypes.FILETIME), - ('last_write_time', ctypes.wintypes.FILETIME), - ('file_size_words', ctypes.wintypes.DWORD*2), - ('reserved', ctypes.wintypes.DWORD*2), - ('filename', ctypes.wintypes.WCHAR*MAX_PATH), - ('alternate_filename', ctypes.wintypes.WCHAR*14), + ("dwFileAttributes", DWORD), + ("ftCreationTime", FILETIME), + ("ftLastAccessTime", FILETIME), + ("ftLastWriteTime", FILETIME), + ("nFileSizeHigh", DWORD), + ("nFileSizeLow", DWORD), + ("dwReserved0", DWORD), + ("dwReserved1", DWORD), + ("cFileName", WCHAR * MAX_PATH), + ("cAlternateFileName", WCHAR * 14)] ] + """ + + @property + def file_attributes(self): + return self.dwFileAttributes + + @property + def creation_time(self): + return self.ftCreationTime + + @property + def last_access_time(self): + return self.ftLastAccessTime + + @property + def last_write_time(self): + return self.ftLastWriteTime + + @property + def file_size_words(self): + return [self.nFileSizeHigh, self.nFileSizeLow] + + @property + def reserved(self): + return [self.dwReserved0, self.dwReserved1] + + @property + def filename(self): + return self.cFileName + + @property + def alternate_filename(self): + return self.cAlternateFileName @property def file_size(self): - return ctypes.cast(self.file_size_words, ctypes.POINTER(ctypes.c_uint64)).contents + return self.nFileSizeHigh << 32 + self.nFileSizeLow 
-LPWIN32_FIND_DATA = ctypes.POINTER(WIN32_FIND_DATA) + +LPWIN32_FIND_DATA = ctypes.POINTER(ctypes.wintypes.WIN32_FIND_DATAW) FindFirstFile = ctypes.windll.kernel32.FindFirstFileW FindFirstFile.argtypes = (ctypes.wintypes.LPWSTR, LPWIN32_FIND_DATA) @@ -102,19 +144,56 @@ FindNextFile = ctypes.windll.kernel32.FindNextFileW FindNextFile.argtypes = (ctypes.wintypes.HANDLE, LPWIN32_FIND_DATA) FindNextFile.restype = ctypes.wintypes.BOOLEAN -SCS_32BIT_BINARY = 0 # A 32-bit Windows-based application -SCS_64BIT_BINARY = 6 # A 64-bit Windows-based application -SCS_DOS_BINARY = 1 # An MS-DOS-based application -SCS_OS216_BINARY = 5 # A 16-bit OS/2-based application -SCS_PIF_BINARY = 3 # A PIF file that executes an MS-DOS-based application -SCS_POSIX_BINARY = 4 # A POSIX-based application -SCS_WOW_BINARY = 2 # A 16-bit Windows-based application +ctypes.windll.kernel32.FindClose.argtypes = ctypes.wintypes.HANDLE, + +SCS_32BIT_BINARY = 0 # A 32-bit Windows-based application +SCS_64BIT_BINARY = 6 # A 64-bit Windows-based application +SCS_DOS_BINARY = 1 # An MS-DOS-based application +SCS_OS216_BINARY = 5 # A 16-bit OS/2-based application +SCS_PIF_BINARY = 3 # A PIF file that executes an MS-DOS-based application +SCS_POSIX_BINARY = 4 # A POSIX-based application +SCS_WOW_BINARY = 2 # A 16-bit Windows-based application _GetBinaryType = ctypes.windll.kernel32.GetBinaryTypeW -_GetBinaryType.argtypes = (ctypes.wintypes.LPWSTR, ctypes.POINTER(ctypes.wintypes.DWORD)) +_GetBinaryType.argtypes = ( + ctypes.wintypes.LPWSTR, ctypes.POINTER(ctypes.wintypes.DWORD), +) _GetBinaryType.restype = ctypes.wintypes.BOOL FILEOP_FLAGS = ctypes.wintypes.WORD + + +class BY_HANDLE_FILE_INFORMATION(ctypes.Structure): + _fields_ = [ + ('file_attributes', ctypes.wintypes.DWORD), + ('creation_time', ctypes.wintypes.FILETIME), + ('last_access_time', ctypes.wintypes.FILETIME), + ('last_write_time', ctypes.wintypes.FILETIME), + ('volume_serial_number', ctypes.wintypes.DWORD), + ('file_size_high', 
ctypes.wintypes.DWORD), + ('file_size_low', ctypes.wintypes.DWORD), + ('number_of_links', ctypes.wintypes.DWORD), + ('file_index_high', ctypes.wintypes.DWORD), + ('file_index_low', ctypes.wintypes.DWORD), + ] + + @property + def file_size(self): + return (self.file_size_high << 32) + self.file_size_low + + @property + def file_index(self): + return (self.file_index_high << 32) + self.file_index_low + + +GetFileInformationByHandle = ctypes.windll.kernel32.GetFileInformationByHandle +GetFileInformationByHandle.restype = ctypes.wintypes.BOOL +GetFileInformationByHandle.argtypes = ( + ctypes.wintypes.HANDLE, + ctypes.POINTER(BY_HANDLE_FILE_INFORMATION), +) + + class SHFILEOPSTRUCT(ctypes.Structure): _fields_ = [ ('status_dialog', ctypes.wintypes.HWND), @@ -126,6 +205,8 @@ class SHFILEOPSTRUCT(ctypes.Structure): ('name_mapping_handles', ctypes.wintypes.LPVOID), ('progress_title', ctypes.wintypes.LPWSTR), ] + + _SHFileOperation = ctypes.windll.shell32.SHFileOperationW _SHFileOperation.argtypes = [ctypes.POINTER(SHFILEOPSTRUCT)] _SHFileOperation.restype = ctypes.c_int @@ -143,12 +224,13 @@ ReplaceFile.argtypes = [ ctypes.wintypes.DWORD, ctypes.wintypes.LPVOID, ctypes.wintypes.LPVOID, - ] +] REPLACEFILE_WRITE_THROUGH = 0x1 REPLACEFILE_IGNORE_MERGE_ERRORS = 0x2 REPLACEFILE_IGNORE_ACL_ERRORS = 0x4 + class STAT_STRUCT(ctypes.Structure): _fields_ = [ ('dev', ctypes.c_uint), @@ -165,17 +247,22 @@ class STAT_STRUCT(ctypes.Structure): ('ctime', ctypes.c_uint), ] + _wstat = ctypes.windll.msvcrt._wstat _wstat.argtypes = [ctypes.wintypes.LPWSTR, ctypes.POINTER(STAT_STRUCT)] _wstat.restype = ctypes.c_int FILE_NOTIFY_CHANGE_LAST_WRITE = 0x10 -FindFirstChangeNotification = ctypes.windll.kernel32.FindFirstChangeNotificationW -FindFirstChangeNotification.argtypes = ctypes.wintypes.LPWSTR, ctypes.wintypes.BOOL, ctypes.wintypes.DWORD +FindFirstChangeNotification = ( + ctypes.windll.kernel32.FindFirstChangeNotificationW) +FindFirstChangeNotification.argtypes = ( + ctypes.wintypes.LPWSTR, 
ctypes.wintypes.BOOL, ctypes.wintypes.DWORD, +) FindFirstChangeNotification.restype = ctypes.wintypes.HANDLE -FindCloseChangeNotification = ctypes.windll.kernel32.FindCloseChangeNotification +FindCloseChangeNotification = ( + ctypes.windll.kernel32.FindCloseChangeNotification) FindCloseChangeNotification.argtypes = ctypes.wintypes.HANDLE, FindCloseChangeNotification.restype = ctypes.wintypes.BOOL @@ -200,9 +287,10 @@ DeviceIoControl.argtypes = [ ctypes.wintypes.DWORD, LPDWORD, LPOVERLAPPED, - ] +] DeviceIoControl.restype = ctypes.wintypes.BOOL + class REPARSE_DATA_BUFFER(ctypes.Structure): _fields_ = [ ('tag', ctypes.c_ulong), @@ -213,16 +301,17 @@ class REPARSE_DATA_BUFFER(ctypes.Structure): ('print_name_offset', ctypes.c_ushort), ('print_name_length', ctypes.c_ushort), ('flags', ctypes.c_ulong), - ('path_buffer', ctypes.c_byte*1), + ('path_buffer', ctypes.c_byte * 1), ] + def get_print_name(self): wchar_size = ctypes.sizeof(ctypes.wintypes.WCHAR) - arr_typ = ctypes.wintypes.WCHAR*(self.print_name_length//wchar_size) + arr_typ = ctypes.wintypes.WCHAR * (self.print_name_length // wchar_size) data = ctypes.byref(self.path_buffer, self.print_name_offset) return ctypes.cast(data, ctypes.POINTER(arr_typ)).contents.value def get_substitute_name(self): wchar_size = ctypes.sizeof(ctypes.wintypes.WCHAR) - arr_typ = ctypes.wintypes.WCHAR*(self.substitute_name_length//wchar_size) + arr_typ = ctypes.wintypes.WCHAR * (self.substitute_name_length // wchar_size) data = ctypes.byref(self.path_buffer, self.substitute_name_offset) return ctypes.cast(data, ctypes.POINTER(arr_typ)).contents.value diff --git a/libs/jaraco/windows/api/inet.py b/libs/jaraco/windows/api/inet.py index 9d821e8d..36c8e37c 100644 --- a/libs/jaraco/windows/api/inet.py +++ b/libs/jaraco/windows/api/inet.py @@ -2,6 +2,7 @@ import struct import ctypes.wintypes from ctypes.wintypes import DWORD, WCHAR, BYTE, BOOL + # from mprapi.h MAX_INTERFACE_NAME_LEN = 2**8 @@ -13,15 +14,16 @@ MAXLEN_IFDESCR = 2**8 
MAX_ADAPTER_ADDRESS_LENGTH = 8 MAX_DHCPV6_DUID_LENGTH = 130 + class MIB_IFROW(ctypes.Structure): _fields_ = ( - ('name', WCHAR*MAX_INTERFACE_NAME_LEN), + ('name', WCHAR * MAX_INTERFACE_NAME_LEN), ('index', DWORD), ('type', DWORD), ('MTU', DWORD), ('speed', DWORD), ('physical_address_length', DWORD), - ('physical_address_raw', BYTE*MAXLEN_PHYSADDR), + ('physical_address_raw', BYTE * MAXLEN_PHYSADDR), ('admin_status', DWORD), ('operational_status', DWORD), ('last_change', DWORD), @@ -38,7 +40,7 @@ class MIB_IFROW(ctypes.Structure): ('outgoing_errors', DWORD), ('outgoing_queue_length', DWORD), ('description_length', DWORD), - ('description_raw', ctypes.c_char*MAXLEN_IFDESCR), + ('description_raw', ctypes.c_char * MAXLEN_IFDESCR), ) def _get_binary_property(self, name): @@ -46,7 +48,7 @@ class MIB_IFROW(ctypes.Structure): val = getattr(self, val_prop) len_prop = '{0}_length'.format(name) length = getattr(self, len_prop) - return str(buffer(val))[:length] + return str(memoryview(val))[:length] @property def physical_address(self): @@ -56,12 +58,14 @@ class MIB_IFROW(ctypes.Structure): def description(self): return self._get_binary_property('description') + class MIB_IFTABLE(ctypes.Structure): _fields_ = ( ('num_entries', DWORD), # dwNumEntries - ('entries', MIB_IFROW*0), # table + ('entries', MIB_IFROW * 0), # table ) + class MIB_IPADDRROW(ctypes.Structure): _fields_ = ( ('address_num', DWORD), @@ -79,40 +83,49 @@ class MIB_IPADDRROW(ctypes.Structure): _ = struct.pack('L', self.address_num) return struct.unpack('!L', _)[0] + class MIB_IPADDRTABLE(ctypes.Structure): _fields_ = ( ('num_entries', DWORD), - ('entries', MIB_IPADDRROW*0), + ('entries', MIB_IPADDRROW * 0), ) + class SOCKADDR(ctypes.Structure): _fields_ = ( ('family', ctypes.c_ushort), - ('data', ctypes.c_byte*14), - ) + ('data', ctypes.c_byte * 14), + ) + + LPSOCKADDR = ctypes.POINTER(SOCKADDR) + class SOCKET_ADDRESS(ctypes.Structure): _fields_ = [ ('address', LPSOCKADDR), ('length', ctypes.c_int), - ] + ] + 
class _IP_ADAPTER_ADDRESSES_METRIC(ctypes.Structure): _fields_ = [ ('length', ctypes.c_ulong), ('interface_index', DWORD), - ] + ] + class _IP_ADAPTER_ADDRESSES_U1(ctypes.Union): _fields_ = [ ('alignment', ctypes.c_ulonglong), ('metric', _IP_ADAPTER_ADDRESSES_METRIC), - ] + ] + class IP_ADAPTER_ADDRESSES(ctypes.Structure): pass + LP_IP_ADAPTER_ADDRESSES = ctypes.POINTER(IP_ADAPTER_ADDRESSES) # for now, just use void * for pointers to unused structures @@ -125,19 +138,20 @@ PIP_ADAPTER_WINS_SERVER_ADDRESS_LH = ctypes.c_void_p PIP_ADAPTER_GATEWAY_ADDRESS_LH = ctypes.c_void_p PIP_ADAPTER_DNS_SUFFIX = ctypes.c_void_p -IF_OPER_STATUS = ctypes.c_uint # this is an enum, consider http://code.activestate.com/recipes/576415/ +IF_OPER_STATUS = ctypes.c_uint # this is an enum, consider +# http://code.activestate.com/recipes/576415/ IF_LUID = ctypes.c_uint64 NET_IF_COMPARTMENT_ID = ctypes.c_uint32 -GUID = ctypes.c_byte*16 +GUID = ctypes.c_byte * 16 NET_IF_NETWORK_GUID = GUID -NET_IF_CONNECTION_TYPE = ctypes.c_uint # enum -TUNNEL_TYPE = ctypes.c_uint # enum +NET_IF_CONNECTION_TYPE = ctypes.c_uint # enum +TUNNEL_TYPE = ctypes.c_uint # enum IP_ADAPTER_ADDRESSES._fields_ = [ - #('u', _IP_ADAPTER_ADDRESSES_U1), - ('length', ctypes.c_ulong), - ('interface_index', DWORD), + # ('u', _IP_ADAPTER_ADDRESSES_U1), + ('length', ctypes.c_ulong), + ('interface_index', DWORD), ('next', LP_IP_ADAPTER_ADDRESSES), ('adapter_name', ctypes.c_char_p), ('first_unicast_address', PIP_ADAPTER_UNICAST_ADDRESS), @@ -147,7 +161,7 @@ IP_ADAPTER_ADDRESSES._fields_ = [ ('dns_suffix', ctypes.c_wchar_p), ('description', ctypes.c_wchar_p), ('friendly_name', ctypes.c_wchar_p), - ('byte', BYTE*MAX_ADAPTER_ADDRESS_LENGTH), + ('byte', BYTE * MAX_ADAPTER_ADDRESS_LENGTH), ('physical_address_length', DWORD), ('flags', DWORD), ('mtu', DWORD), @@ -169,11 +183,11 @@ IP_ADAPTER_ADDRESSES._fields_ = [ ('connection_type', NET_IF_CONNECTION_TYPE), ('tunnel_type', TUNNEL_TYPE), ('dhcpv6_server', SOCKET_ADDRESS), - 
('dhcpv6_client_duid', ctypes.c_byte*MAX_DHCPV6_DUID_LENGTH), + ('dhcpv6_client_duid', ctypes.c_byte * MAX_DHCPV6_DUID_LENGTH), ('dhcpv6_client_duid_length', ctypes.c_ulong), ('dhcpv6_iaid', ctypes.c_ulong), ('first_dns_suffix', PIP_ADAPTER_DNS_SUFFIX), - ] +] # define some parameters to the API Functions GetIfTable = ctypes.windll.iphlpapi.GetIfTable @@ -181,7 +195,7 @@ GetIfTable.argtypes = [ ctypes.POINTER(MIB_IFTABLE), ctypes.POINTER(ctypes.c_ulong), BOOL, - ] +] GetIfTable.restype = DWORD GetIpAddrTable = ctypes.windll.iphlpapi.GetIpAddrTable @@ -189,7 +203,7 @@ GetIpAddrTable.argtypes = [ ctypes.POINTER(MIB_IPADDRTABLE), ctypes.POINTER(ctypes.c_ulong), BOOL, - ] +] GetIpAddrTable.restype = DWORD GetAdaptersAddresses = ctypes.windll.iphlpapi.GetAdaptersAddresses @@ -199,5 +213,5 @@ GetAdaptersAddresses.argtypes = [ ctypes.c_void_p, ctypes.POINTER(IP_ADAPTER_ADDRESSES), ctypes.POINTER(ctypes.c_ulong), - ] +] GetAdaptersAddresses.restype = ctypes.c_ulong diff --git a/libs/jaraco/windows/api/memory.py b/libs/jaraco/windows/api/memory.py index 0e25dc4a..0371099c 100644 --- a/libs/jaraco/windows/api/memory.py +++ b/libs/jaraco/windows/api/memory.py @@ -3,7 +3,7 @@ import ctypes.wintypes GMEM_MOVEABLE = 0x2 GlobalAlloc = ctypes.windll.kernel32.GlobalAlloc -GlobalAlloc.argtypes = ctypes.wintypes.UINT, ctypes.c_ssize_t +GlobalAlloc.argtypes = ctypes.wintypes.UINT, ctypes.c_size_t GlobalAlloc.restype = ctypes.wintypes.HANDLE GlobalLock = ctypes.windll.kernel32.GlobalLock @@ -31,3 +31,15 @@ CreateFileMapping.restype = ctypes.wintypes.HANDLE MapViewOfFile = ctypes.windll.kernel32.MapViewOfFile MapViewOfFile.restype = ctypes.wintypes.HANDLE + +UnmapViewOfFile = ctypes.windll.kernel32.UnmapViewOfFile +UnmapViewOfFile.argtypes = ctypes.wintypes.HANDLE, + +RtlMoveMemory = ctypes.windll.kernel32.RtlMoveMemory +RtlMoveMemory.argtypes = ( + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_size_t, +) + +ctypes.windll.kernel32.LocalFree.argtypes = ctypes.wintypes.HLOCAL, diff --git 
a/libs/jaraco/windows/api/message.py b/libs/jaraco/windows/api/message.py index 8fddcba2..5ce2d808 100644 --- a/libs/jaraco/windows/api/message.py +++ b/libs/jaraco/windows/api/message.py @@ -13,6 +13,7 @@ import six LRESULT = LPARAM + class LPARAM_wstr(LPARAM): """ A special instance of LPARAM that can be constructed from a string @@ -25,14 +26,16 @@ class LPARAM_wstr(LPARAM): return LPVOID.from_param(six.text_type(param)) return LPARAM.from_param(param) + SendMessage = ctypes.windll.user32.SendMessageW SendMessage.argtypes = (HWND, UINT, WPARAM, LPARAM_wstr) SendMessage.restype = LRESULT -HWND_BROADCAST=0xFFFF -WM_SETTINGCHANGE=0x1A +HWND_BROADCAST = 0xFFFF +WM_SETTINGCHANGE = 0x1A -# constants from http://msdn.microsoft.com/en-us/library/ms644952%28v=vs.85%29.aspx +# constants from http://msdn.microsoft.com +# /en-us/library/ms644952%28v=vs.85%29.aspx SMTO_ABORTIFHUNG = 0x02 SMTO_BLOCK = 0x01 SMTO_NORMAL = 0x00 @@ -45,6 +48,7 @@ SendMessageTimeout.argtypes = SendMessage.argtypes + ( ) SendMessageTimeout.restype = LRESULT + def unicode_as_lparam(source): pointer = ctypes.cast(ctypes.c_wchar_p(source), ctypes.c_void_p) return LPARAM(pointer.value) diff --git a/libs/jaraco/windows/api/net.py b/libs/jaraco/windows/api/net.py index 38918c2d..ce693319 100644 --- a/libs/jaraco/windows/api/net.py +++ b/libs/jaraco/windows/api/net.py @@ -5,6 +5,7 @@ mpr = ctypes.windll.mpr RESOURCETYPE_ANY = 0 + class NETRESOURCE(ctypes.Structure): _fields_ = [ ('scope', ctypes.wintypes.DWORD), @@ -16,6 +17,8 @@ class NETRESOURCE(ctypes.Structure): ('comment', ctypes.wintypes.LPWSTR), ('provider', ctypes.wintypes.LPWSTR), ] + + LPNETRESOURCE = ctypes.POINTER(NETRESOURCE) WNetAddConnection2 = mpr.WNetAddConnection2W diff --git a/libs/jaraco/windows/api/power.py b/libs/jaraco/windows/api/power.py index 279c1133..77253a8a 100644 --- a/libs/jaraco/windows/api/power.py +++ b/libs/jaraco/windows/api/power.py @@ -1,5 +1,6 @@ import ctypes.wintypes + class SYSTEM_POWER_STATUS(ctypes.Structure): 
_fields_ = ( ('ac_line_status', ctypes.wintypes.BYTE), @@ -8,11 +9,13 @@ class SYSTEM_POWER_STATUS(ctypes.Structure): ('reserved', ctypes.wintypes.BYTE), ('battery_life_time', ctypes.wintypes.DWORD), ('battery_full_life_time', ctypes.wintypes.DWORD), - ) + ) @property def ac_line_status_string(self): - return {0:'offline', 1: 'online', 255: 'unknown'}[self.ac_line_status] + return { + 0: 'offline', 1: 'online', 255: 'unknown'}[self.ac_line_status] + LPSYSTEM_POWER_STATUS = ctypes.POINTER(SYSTEM_POWER_STATUS) GetSystemPowerStatus = ctypes.windll.kernel32.GetSystemPowerStatus @@ -23,6 +26,7 @@ SetThreadExecutionState = ctypes.windll.kernel32.SetThreadExecutionState SetThreadExecutionState.argtypes = [ctypes.c_uint] SetThreadExecutionState.restype = ctypes.c_uint + class ES: """ Execution state constants diff --git a/libs/jaraco/windows/api/privilege.py b/libs/jaraco/windows/api/privilege.py index 0016a7dd..b841311e 100644 --- a/libs/jaraco/windows/api/privilege.py +++ b/libs/jaraco/windows/api/privilege.py @@ -1,5 +1,6 @@ import ctypes.wintypes + class LUID(ctypes.Structure): _fields_ = [ ('low_part', ctypes.wintypes.DWORD), @@ -10,35 +11,39 @@ class LUID(ctypes.Structure): return ( self.high_part == other.high_part and self.low_part == other.low_part - ) + ) def __ne__(self, other): - return not (self==other) + return not (self == other) + LookupPrivilegeValue = ctypes.windll.advapi32.LookupPrivilegeValueW LookupPrivilegeValue.argtypes = ( - ctypes.wintypes.LPWSTR, # system name - ctypes.wintypes.LPWSTR, # name + ctypes.wintypes.LPWSTR, # system name + ctypes.wintypes.LPWSTR, # name ctypes.POINTER(LUID), - ) +) LookupPrivilegeValue.restype = ctypes.wintypes.BOOL + class TOKEN_INFORMATION_CLASS: TokenUser = 1 TokenGroups = 2 TokenPrivileges = 3 # ... 
see http://msdn.microsoft.com/en-us/library/aa379626%28VS.85%29.aspx + SE_PRIVILEGE_ENABLED_BY_DEFAULT = 0x00000001 SE_PRIVILEGE_ENABLED = 0x00000002 SE_PRIVILEGE_REMOVED = 0x00000004 SE_PRIVILEGE_USED_FOR_ACCESS = 0x80000000 + class LUID_AND_ATTRIBUTES(ctypes.Structure): _fields_ = [ ('LUID', LUID), ('attributes', ctypes.wintypes.DWORD), - ] + ] def is_enabled(self): return bool(self.attributes & SE_PRIVILEGE_ENABLED) @@ -50,47 +55,53 @@ class LUID_AND_ATTRIBUTES(ctypes.Structure): size = ctypes.wintypes.DWORD(10240) buf = ctypes.create_unicode_buffer(size.value) res = LookupPrivilegeName(None, self.LUID, buf, size) - if res == 0: raise RuntimeError + if res == 0: + raise RuntimeError return buf[:size.value] def __str__(self): res = self.get_name() - if self.is_enabled(): res += ' (enabled)' + if self.is_enabled(): + res += ' (enabled)' return res + LookupPrivilegeName = ctypes.windll.advapi32.LookupPrivilegeNameW LookupPrivilegeName.argtypes = ( - ctypes.wintypes.LPWSTR, # lpSystemName - ctypes.POINTER(LUID), # lpLuid - ctypes.wintypes.LPWSTR, # lpName - ctypes.POINTER(ctypes.wintypes.DWORD), # cchName - ) + ctypes.wintypes.LPWSTR, # lpSystemName + ctypes.POINTER(LUID), # lpLuid + ctypes.wintypes.LPWSTR, # lpName + ctypes.POINTER(ctypes.wintypes.DWORD), # cchName +) LookupPrivilegeName.restype = ctypes.wintypes.BOOL + class TOKEN_PRIVILEGES(ctypes.Structure): _fields_ = [ ('count', ctypes.wintypes.DWORD), - ('privileges', LUID_AND_ATTRIBUTES*0), - ] + ('privileges', LUID_AND_ATTRIBUTES * 0), + ] def get_array(self): - array_type = LUID_AND_ATTRIBUTES*self.count - privileges = ctypes.cast(self.privileges, ctypes.POINTER(array_type)).contents + array_type = LUID_AND_ATTRIBUTES * self.count + privileges = ctypes.cast( + self.privileges, ctypes.POINTER(array_type)).contents return privileges def __iter__(self): return iter(self.get_array()) + PTOKEN_PRIVILEGES = ctypes.POINTER(TOKEN_PRIVILEGES) GetTokenInformation = ctypes.windll.advapi32.GetTokenInformation 
GetTokenInformation.argtypes = [ - ctypes.wintypes.HANDLE, # TokenHandle - ctypes.c_uint, # TOKEN_INFORMATION_CLASS value - ctypes.c_void_p, # TokenInformation - ctypes.wintypes.DWORD, # TokenInformationLength - ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength - ] + ctypes.wintypes.HANDLE, # TokenHandle + ctypes.c_uint, # TOKEN_INFORMATION_CLASS value + ctypes.c_void_p, # TokenInformation + ctypes.wintypes.DWORD, # TokenInformationLength + ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength +] GetTokenInformation.restype = ctypes.wintypes.BOOL # http://msdn.microsoft.com/en-us/library/aa375202%28VS.85%29.aspx @@ -102,5 +113,5 @@ AdjustTokenPrivileges.argtypes = [ PTOKEN_PRIVILEGES, # NewState (optional) ctypes.wintypes.DWORD, # BufferLength of PreviousState PTOKEN_PRIVILEGES, # PreviousState (out, optional) - ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength - ] + ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength +] diff --git a/libs/jaraco/windows/api/process.py b/libs/jaraco/windows/api/process.py index e3a6ae4b..56ce7852 100644 --- a/libs/jaraco/windows/api/process.py +++ b/libs/jaraco/windows/api/process.py @@ -5,5 +5,7 @@ TOKEN_ALL_ACCESS = 0xf01ff GetCurrentProcess = ctypes.windll.kernel32.GetCurrentProcess GetCurrentProcess.restype = ctypes.wintypes.HANDLE OpenProcessToken = ctypes.windll.advapi32.OpenProcessToken -OpenProcessToken.argtypes = (ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD, ctypes.POINTER(ctypes.wintypes.HANDLE)) +OpenProcessToken.argtypes = ( + ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD, + ctypes.POINTER(ctypes.wintypes.HANDLE)) OpenProcessToken.restype = ctypes.wintypes.BOOL diff --git a/libs/jaraco/windows/api/security.py b/libs/jaraco/windows/api/security.py index 28a50972..c9e7b58e 100644 --- a/libs/jaraco/windows/api/security.py +++ b/libs/jaraco/windows/api/security.py @@ -60,12 +60,15 @@ POLICY_EXECUTE = ( POLICY_VIEW_LOCAL_INFORMATION | POLICY_LOOKUP_NAMES) + class TokenAccess: TOKEN_QUERY = 0x8 + class 
TokenInformationClass: TokenUser = 1 + class TOKEN_USER(ctypes.Structure): num = 1 _fields_ = [ @@ -100,6 +103,7 @@ class SECURITY_DESCRIPTOR(ctypes.Structure): ('Dacl', ctypes.c_void_p), ] + class SECURITY_ATTRIBUTES(ctypes.Structure): """ typedef struct _SECURITY_ATTRIBUTES { @@ -126,3 +130,10 @@ class SECURITY_ATTRIBUTES(ctypes.Structure): def descriptor(self, value): self._descriptor = value self.lpSecurityDescriptor = ctypes.addressof(value) + + +ctypes.windll.advapi32.SetSecurityDescriptorOwner.argtypes = ( + ctypes.POINTER(SECURITY_DESCRIPTOR), + ctypes.c_void_p, + ctypes.wintypes.BOOL, +) diff --git a/libs/jaraco/windows/api/shell.py b/libs/jaraco/windows/api/shell.py index 7f89fa72..1d428c87 100644 --- a/libs/jaraco/windows/api/shell.py +++ b/libs/jaraco/windows/api/shell.py @@ -1,6 +1,7 @@ import ctypes.wintypes BOOL = ctypes.wintypes.BOOL + class SHELLSTATE(ctypes.Structure): _fields_ = [ ('show_all_objects', BOOL, 1), @@ -34,6 +35,7 @@ class SHELLSTATE(ctypes.Structure): ('spare_flags', ctypes.wintypes.UINT, 13), ] + SSF_SHOWALLOBJECTS = 0x00000001 "The fShowAllObjects member is being requested." @@ -62,7 +64,13 @@ SSF_SHOWATTRIBCOL = 0x00000100 "The fShowAttribCol member is being requested. (Windows Vista: Not used.)" SSF_DESKTOPHTML = 0x00000200 -"The fDesktopHTML member is being requested. Set is not available. Instead, for versions of Microsoft Windows prior to Windows XP, enable Desktop HTML by IActiveDesktop. The use of IActiveDesktop for this purpose, however, is not recommended for Windows XP and later versions of Windows, and is deprecated in Windows Vista." +""" +The fDesktopHTML member is being requested. Set is not available. +Instead, for versions of Microsoft Windows prior to Windows XP, +enable Desktop HTML by IActiveDesktop. The use of IActiveDesktop +for this purpose, however, is not recommended for Windows XP and +later versions of Windows, and is deprecated in Windows Vista. 
+""" SSF_WIN95CLASSIC = 0x00000400 "The fWin95Classic member is being requested." @@ -117,6 +125,6 @@ SHGetSetSettings = ctypes.windll.shell32.SHGetSetSettings SHGetSetSettings.argtypes = [ ctypes.POINTER(SHELLSTATE), ctypes.wintypes.DWORD, - ctypes.wintypes.BOOL, # get or set (True: set) - ] + ctypes.wintypes.BOOL, # get or set (True: set) +] SHGetSetSettings.restype = None diff --git a/libs/jaraco/windows/api/system.py b/libs/jaraco/windows/api/system.py index f9439fef..6a09f5ad 100644 --- a/libs/jaraco/windows/api/system.py +++ b/libs/jaraco/windows/api/system.py @@ -6,7 +6,7 @@ SystemParametersInfo.argtypes = ( ctypes.wintypes.UINT, ctypes.c_void_p, ctypes.wintypes.UINT, - ) +) SPI_GETACTIVEWINDOWTRACKING = 0x1000 SPI_SETACTIVEWINDOWTRACKING = 0x1001 diff --git a/libs/jaraco/windows/clipboard.py b/libs/jaraco/windows/clipboard.py index ec677ff9..2f4bbc3a 100644 --- a/libs/jaraco/windows/clipboard.py +++ b/libs/jaraco/windows/clipboard.py @@ -5,20 +5,22 @@ import re import itertools from contextlib import contextmanager import io - -import six import ctypes from ctypes import windll +import six +from six.moves import map + from jaraco.windows.api import clipboard, memory from jaraco.windows.error import handle_nonzero_success, WindowsError from jaraco.windows.memory import LockedMemory __all__ = ( - 'CF_TEXT', 'GetClipboardData', 'CloseClipboard', + 'GetClipboardData', 'CloseClipboard', 'SetClipboardData', 'OpenClipboard', ) + def OpenClipboard(owner=None): """ Open the clipboard. 
@@ -30,9 +32,14 @@ def OpenClipboard(owner=None): """ handle_nonzero_success(windll.user32.OpenClipboard(owner)) -CloseClipboard = lambda: handle_nonzero_success(windll.user32.CloseClipboard()) + +def CloseClipboard(): + handle_nonzero_success(windll.user32.CloseClipboard()) + data_handlers = dict() + + def handles(*formats): def register(func): for format in formats: @@ -40,36 +47,43 @@ def handles(*formats): return func return register -def nts(s): + +def nts(buffer): """ Null Terminated String - Get the portion of s up to a null character. + Get the portion of bytestring buffer up to a null character. """ - result, null, rest = s.partition('\x00') + result, null, rest = buffer.partition('\x00') return result + @handles(clipboard.CF_DIBV5, clipboard.CF_DIB) def raw_data(handle): return LockedMemory(handle).data + @handles(clipboard.CF_TEXT) def text_string(handle): return nts(raw_data(handle)) + @handles(clipboard.CF_UNICODETEXT) def unicode_string(handle): return nts(raw_data(handle).decode('utf-16')) + @handles(clipboard.CF_BITMAP) def as_bitmap(handle): # handle is HBITMAP raise NotImplementedError("Can't convert to DIB") - # todo: use GetDIBits http://msdn.microsoft.com/en-us/library/dd144879%28v=VS.85%29.aspx + # todo: use GetDIBits http://msdn.microsoft.com + # /en-us/library/dd144879%28v=VS.85%29.aspx + @handles(clipboard.CF_HTML) class HTMLSnippet(object): def __init__(self, handle): - self.data = text_string(handle) + self.data = nts(raw_data(handle).decode('utf-8')) self.headers = self.parse_headers(self.data) @property @@ -79,11 +93,13 @@ class HTMLSnippet(object): @staticmethod def parse_headers(data): d = io.StringIO(data) + def header_line(line): return re.match('(\w+):(.*)', line) - headers = itertools.imap(header_line, d) + headers = map(header_line, d) # grab headers until they no longer match headers = itertools.takewhile(bool, headers) + def best_type(value): try: return int(value) @@ -101,26 +117,34 @@ class HTMLSnippet(object): ) return 
dict(pairs) + def GetClipboardData(type=clipboard.CF_UNICODETEXT): - if not type in data_handlers: + if type not in data_handlers: raise NotImplementedError("No support for data of type %d" % type) handle = clipboard.GetClipboardData(type) if handle is None: raise TypeError("No clipboard data of type %d" % type) return data_handlers[type](handle) -EmptyClipboard = lambda: handle_nonzero_success(windll.user32.EmptyClipboard()) + +def EmptyClipboard(): + handle_nonzero_success(windll.user32.EmptyClipboard()) + def SetClipboardData(type, content): """ - Modeled after http://msdn.microsoft.com/en-us/library/ms649016%28VS.85%29.aspx#_win32_Copying_Information_to_the_Clipboard + Modeled after http://msdn.microsoft.com + /en-us/library/ms649016%28VS.85%29.aspx + #_win32_Copying_Information_to_the_Clipboard """ allocators = { clipboard.CF_TEXT: ctypes.create_string_buffer, clipboard.CF_UNICODETEXT: ctypes.create_unicode_buffer, + clipboard.CF_HTML: ctypes.create_string_buffer, } - if not type in allocators: - raise NotImplementedError("Only text types are supported at this time") + if type not in allocators: + raise NotImplementedError( + "Only text and HTML types are supported at this time") # allocate the memory for the data content = allocators[type](content) flags = memory.GMEM_MOVEABLE @@ -132,47 +156,57 @@ def SetClipboardData(type, content): if result is None: raise WindowsError() + def set_text(source): with context(): EmptyClipboard() SetClipboardData(clipboard.CF_TEXT, source) + def get_text(): with context(): result = GetClipboardData(clipboard.CF_TEXT) return result + def set_unicode_text(source): with context(): EmptyClipboard() SetClipboardData(clipboard.CF_UNICODETEXT, source) + def get_unicode_text(): with context(): return GetClipboardData() + def get_html(): with context(): result = GetClipboardData(clipboard.CF_HTML) return result + def set_html(source): with context(): EmptyClipboard() SetClipboardData(clipboard.CF_UNICODETEXT, source) + def 
get_image(): with context(): return GetClipboardData(clipboard.CF_DIB) + def paste_stdout(): getter = get_unicode_text if six.PY3 else get_text sys.stdout.write(getter()) + def stdin_copy(): setter = set_unicode_text if six.PY3 else set_text setter(sys.stdin.read()) + @contextmanager def context(): OpenClipboard() @@ -181,10 +215,12 @@ def context(): finally: CloseClipboard() + def get_formats(): with context(): format_index = 0 while True: format_index = clipboard.EnumClipboardFormats(format_index) - if format_index == 0: break + if format_index == 0: + break yield format_index diff --git a/libs/jaraco/windows/cred.py b/libs/jaraco/windows/cred.py index 61096309..14c3255a 100644 --- a/libs/jaraco/windows/cred.py +++ b/libs/jaraco/windows/cred.py @@ -3,17 +3,20 @@ import ctypes import jaraco.windows.api.credential as api from . import error -CRED_TYPE_GENERIC=1 +CRED_TYPE_GENERIC = 1 + def CredDelete(TargetName, Type, Flags=0): error.handle_nonzero_success(api.CredDelete(TargetName, Type, Flags)) + def CredRead(TargetName, Type, Flags=0): cred_pointer = api.PCREDENTIAL() res = api.CredRead(TargetName, Type, Flags, ctypes.byref(cred_pointer)) error.handle_nonzero_success(res) return cred_pointer.contents + def CredWrite(Credential, Flags=0): res = api.CredWrite(Credential, Flags) error.handle_nonzero_success(res) diff --git a/libs/jaraco/windows/dpapi.py b/libs/jaraco/windows/dpapi.py index 6e0771c4..3b7348aa 100644 --- a/libs/jaraco/windows/dpapi.py +++ b/libs/jaraco/windows/dpapi.py @@ -14,6 +14,11 @@ import ctypes from ctypes import wintypes from jaraco.windows.error import handle_nonzero_success + +# for type declarations +__import__('jaraco.windows.api.memory') + + class DATA_BLOB(ctypes.Structure): r""" A data blob structure for use with MS DPAPI functions. 
@@ -29,7 +34,7 @@ class DATA_BLOB(ctypes.Structure): _fields_ = [ ('data_size', wintypes.DWORD), ('data', ctypes.c_void_p), - ] + ] def __init__(self, data=None): super(DATA_BLOB, self).__init__() @@ -48,7 +53,7 @@ class DATA_BLOB(ctypes.Structure): def get_data(self): "Get the data for this blob" - array = ctypes.POINTER(ctypes.c_char*len(self)) + array = ctypes.POINTER(ctypes.c_char * len(self)) return ctypes.cast(self.data, array).contents.raw def __len__(self): @@ -65,38 +70,40 @@ class DATA_BLOB(ctypes.Structure): """ ctypes.windll.kernel32.LocalFree(self.data) + p_DATA_BLOB = ctypes.POINTER(DATA_BLOB) _CryptProtectData = ctypes.windll.crypt32.CryptProtectData _CryptProtectData.argtypes = [ p_DATA_BLOB, # data in - wintypes.LPCWSTR, # data description + wintypes.LPCWSTR, # data description p_DATA_BLOB, # optional entropy ctypes.c_void_p, # reserved ctypes.c_void_p, # POINTER(CRYPTPROTECT_PROMPTSTRUCT), # prompt struct wintypes.DWORD, # flags p_DATA_BLOB, # data out - ] +] _CryptProtectData.restype = wintypes.BOOL _CryptUnprotectData = ctypes.windll.crypt32.CryptUnprotectData _CryptUnprotectData.argtypes = [ p_DATA_BLOB, # data in - ctypes.POINTER(wintypes.LPWSTR), # data description + ctypes.POINTER(wintypes.LPWSTR), # data description p_DATA_BLOB, # optional entropy ctypes.c_void_p, # reserved ctypes.c_void_p, # POINTER(CRYPTPROTECT_PROMPTSTRUCT), # prompt struct wintypes.DWORD, # flags p_DATA_BLOB, # data out - ] +] _CryptUnprotectData.restype = wintypes.BOOL CRYPTPROTECT_UI_FORBIDDEN = 0x01 + def CryptProtectData( data, description=None, optional_entropy=None, prompt_struct=None, flags=0, - ): +): """ Encrypt data """ @@ -108,17 +115,19 @@ def CryptProtectData( data_in, description, entropy, - None, # reserved + None, # reserved prompt_struct, flags, data_out, - ) + ) handle_nonzero_success(res) res = data_out.get_data() data_out.free() return res -def CryptUnprotectData(data, optional_entropy=None, prompt_struct=None, flags=0): + +def CryptUnprotectData( + 
data, optional_entropy=None, prompt_struct=None, flags=0): """ Returns a tuple of (description, data) where description is the the description that was passed to the CryptProtectData call and @@ -132,11 +141,11 @@ def CryptUnprotectData(data, optional_entropy=None, prompt_struct=None, flags=0) data_in, ctypes.byref(ptr_description), entropy, - None, # reserved + None, # reserved prompt_struct, flags | CRYPTPROTECT_UI_FORBIDDEN, data_out, - ) + ) handle_nonzero_success(res) description = ptr_description.value if ptr_description.value is not None: diff --git a/libs/jaraco/windows/environ.py b/libs/jaraco/windows/environ.py index 54c925f9..a014ae35 100644 --- a/libs/jaraco/windows/environ.py +++ b/libs/jaraco/windows/environ.py @@ -7,7 +7,7 @@ import ctypes import ctypes.wintypes import six -winreg = six.moves.winreg +from six.moves import winreg from jaraco.ui.editor import EditableFile @@ -19,17 +19,21 @@ from .registry import key_values as registry_key_values def SetEnvironmentVariable(name, value): error.handle_nonzero_success(environ.SetEnvironmentVariable(name, value)) + def ClearEnvironmentVariable(name): error.handle_nonzero_success(environ.SetEnvironmentVariable(name, None)) + def GetEnvironmentVariable(name): - max_size = 2**15-1 + max_size = 2**15 - 1 buffer = ctypes.create_unicode_buffer(max_size) - error.handle_nonzero_success(environ.GetEnvironmentVariable(name, buffer, max_size)) + error.handle_nonzero_success( + environ.GetEnvironmentVariable(name, buffer, max_size)) return buffer.value ### + class RegisteredEnvironment(object): """ Manages the environment variables as set in the Windows Registry. 
@@ -68,7 +72,7 @@ class RegisteredEnvironment(object): return class_.delete(name) do_append = options.append or ( name.upper() in ('PATH', 'PATHEXT') and not options.replace - ) + ) if do_append: sep = ';' values = class_.get_values_list(name, sep) + [value] @@ -86,8 +90,8 @@ class RegisteredEnvironment(object): if value in values: return new_value = sep.join(values + [value]) - winreg.SetValueEx(class_.key, name, 0, winreg.REG_EXPAND_SZ, - new_value) + winreg.SetValueEx( + class_.key, name, 0, winreg.REG_EXPAND_SZ, new_value) class_.notify() @classmethod @@ -98,7 +102,7 @@ class RegisteredEnvironment(object): value for value in values if value_substring.lower() not in value.lower() - ] + ] values = sep.join(new_values) winreg.SetValueEx(class_.key, name, 0, winreg.REG_EXPAND_SZ, values) class_.notify() @@ -133,32 +137,38 @@ class RegisteredEnvironment(object): res = message.SendMessageTimeout( message.HWND_BROADCAST, message.WM_SETTINGCHANGE, - 0, # wparam must be null + 0, # wparam must be null 'Environment', message.SMTO_ABORTIFHUNG, - 5000, # timeout in ms + 5000, # timeout in ms return_val, ) error.handle_nonzero_success(res) + class MachineRegisteredEnvironment(RegisteredEnvironment): path = r'SYSTEM\CurrentControlSet\Control\Session Manager\Environment' hklm = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) try: - key = winreg.OpenKey(hklm, path, 0, + key = winreg.OpenKey( + hklm, path, 0, winreg.KEY_READ | winreg.KEY_WRITE) except WindowsError: key = winreg.OpenKey(hklm, path, 0, winreg.KEY_READ) + class UserRegisteredEnvironment(RegisteredEnvironment): hkcu = winreg.ConnectRegistry(None, winreg.HKEY_CURRENT_USER) - key = winreg.OpenKey(hkcu, 'Environment', 0, + key = winreg.OpenKey( + hkcu, 'Environment', 0, winreg.KEY_READ | winreg.KEY_WRITE) + def trim(s): from textwrap import dedent return dedent(s).strip() + def enver(*args): """ %prog [=[value]] @@ -200,23 +210,27 @@ def enver(*args): default=MachineRegisteredEnvironment, dest='class_', 
help="Use the current user's environment", - ) - parser.add_option('-a', '--append', + ) + parser.add_option( + '-a', '--append', action='store_true', default=False, - help="Append the value to any existing value (default for PATH and PATHEXT)",) + help="Append the value to any existing value (default for PATH and PATHEXT)", + ) parser.add_option( '-r', '--replace', action='store_true', default=False, - help="Replace any existing value (used to override default append for PATH and PATHEXT)", - ) + help="Replace any existing value (used to override default append " + "for PATH and PATHEXT)", + ) parser.add_option( '--remove-value', action='store_true', default=False, - help="Remove any matching values from a semicolon-separated multi-value variable", - ) + help="Remove any matching values from a semicolon-separated " + "multi-value variable", + ) parser.add_option( '-e', '--edit', action='store_true', default=False, help="Edit the value in a local editor", - ) + ) options, args = parser.parse_args(*args) try: @@ -224,7 +238,7 @@ def enver(*args): if args: parser.error("Too many parameters specified") raise SystemExit(1) - if not '=' in param and not options.edit: + if '=' not in param and not options.edit: parser.error("Expected = or =") raise SystemExit(2) name, sep, value = param.partition('=') @@ -238,5 +252,6 @@ def enver(*args): except IndexError: options.class_.show() + if __name__ == '__main__': enver() diff --git a/libs/jaraco/windows/error.py b/libs/jaraco/windows/error.py index f78b9017..48e0b6cc 100644 --- a/libs/jaraco/windows/error.py +++ b/libs/jaraco/windows/error.py @@ -6,9 +6,12 @@ import ctypes import ctypes.wintypes import six - builtins = six.moves.builtins + +__import__('jaraco.windows.api.memory') + + def format_system_message(errno): """ Call FormatMessage with a system error number to retrieve @@ -35,7 +38,7 @@ def format_system_message(errno): ctypes.byref(result_buffer), buffer_size, arguments, - ) + ) # note the following will cause an 
infinite loop if GetLastError # repeatedly returns an error that cannot be formatted, although # this should not happen. @@ -46,13 +49,16 @@ def format_system_message(errno): class WindowsError(builtins.WindowsError): - "more info about errors at http://msdn.microsoft.com/en-us/library/ms681381(VS.85).aspx" + """ + More info about errors at + http://msdn.microsoft.com/en-us/library/ms681381(VS.85).aspx + """ def __init__(self, value=None): if value is None: value = ctypes.windll.kernel32.GetLastError() strerror = format_system_message(value) - if sys.version_info > (3,3): + if sys.version_info > (3, 3): args = 0, strerror, None, value else: args = value, strerror @@ -72,6 +78,7 @@ class WindowsError(builtins.WindowsError): def __repr__(self): return '{self.__class__.__name__}({self.winerror})'.format(**vars()) + def handle_nonzero_success(result): if result == 0: raise WindowsError() diff --git a/libs/jaraco/windows/eventlog.py b/libs/jaraco/windows/eventlog.py index ea9eed7e..7fed221a 100644 --- a/libs/jaraco/windows/eventlog.py +++ b/libs/jaraco/windows/eventlog.py @@ -6,7 +6,8 @@ import win32api import win32evtlog import win32evtlogutil -error = win32api.error # The error the evtlog module raises. +error = win32api.error # The error the evtlog module raises. 
+ class EventLog(object): def __init__(self, name="Application", machine_name=None): @@ -29,6 +30,7 @@ class EventLog(object): win32evtlog.EVENTLOG_BACKWARDS_READ | win32evtlog.EVENTLOG_SEQUENTIAL_READ ) + def get_records(self, flags=_default_flags): with self: while True: diff --git a/libs/jaraco/windows/filesystem/__init__.py b/libs/jaraco/windows/filesystem/__init__.py index 35f7a4c4..14e8c76d 100644 --- a/libs/jaraco/windows/filesystem/__init__.py +++ b/libs/jaraco/windows/filesystem/__init__.py @@ -7,19 +7,24 @@ import sys import operator import collections import functools -from ctypes import (POINTER, byref, cast, create_unicode_buffer, +import stat +from ctypes import ( + POINTER, byref, cast, create_unicode_buffer, create_string_buffer, windll) +from ctypes.wintypes import LPWSTR +import nt +import posixpath import six from six.moves import builtins, filter, map from jaraco.structures import binary -from jaraco.text import local_format as lf from jaraco.windows.error import WindowsError, handle_nonzero_success import jaraco.windows.api.filesystem as api from jaraco.windows import reparse + def mklink(): """ Like cmd.exe's mklink except it will infer directory status of the @@ -27,7 +32,8 @@ def mklink(): """ from optparse import OptionParser parser = OptionParser(usage="usage: %prog [options] link target") - parser.add_option('-d', '--directory', + parser.add_option( + '-d', '--directory', help="Target is a directory (only necessary if not present)", action="store_true") options, args = parser.parse_args() @@ -38,6 +44,7 @@ def mklink(): symlink(target, link, options.directory) sys.stdout.write("Symbolic link created: %(link)s --> %(target)s\n" % vars()) + def _is_target_a_directory(link, rel_target): """ If creating a symlink from link to a target, determine if target @@ -46,15 +53,20 @@ def _is_target_a_directory(link, rel_target): target = os.path.join(os.path.dirname(link), rel_target) return os.path.isdir(target) -def symlink(target, link, 
target_is_directory = False): + +def symlink(target, link, target_is_directory=False): """ An implementation of os.symlink for Windows (Vista and greater) """ - target_is_directory = (target_is_directory or - _is_target_a_directory(link, target)) + target_is_directory = ( + target_is_directory or + _is_target_a_directory(link, target) + ) # normalize the target (MS symlinks don't respect forward slashes) target = os.path.normpath(target) - handle_nonzero_success(api.CreateSymbolicLink(link, target, target_is_directory)) + handle_nonzero_success( + api.CreateSymbolicLink(link, target, target_is_directory)) + def link(target, link): """ @@ -62,6 +74,7 @@ def link(target, link): """ handle_nonzero_success(api.CreateHardLink(link, target, None)) + def is_reparse_point(path): """ Determine if the given path is a reparse point. @@ -74,10 +87,12 @@ def is_reparse_point(path): and bool(res & api.FILE_ATTRIBUTE_REPARSE_POINT) ) + def islink(path): "Determine if the given path is a symlink" return is_reparse_point(path) and is_symlink(path) + def _patch_path(path): """ Paths have a max length of api.MAX_PATH characters (260). If a target path @@ -86,13 +101,15 @@ def _patch_path(path): See http://msdn.microsoft.com/en-us/library/aa365247%28v=vs.85%29.aspx for details. """ - if path.startswith('\\\\?\\'): return path + if path.startswith('\\\\?\\'): + return path abs_path = os.path.abspath(path) if not abs_path[1] == ':': # python doesn't include the drive letter, but \\?\ requires it abs_path = os.getcwd()[:2] + abs_path return '\\\\?\\' + abs_path + def is_symlink(path): """ Assuming path is a reparse point, determine if it's a symlink. 
@@ -102,11 +119,13 @@ def is_symlink(path): return _is_symlink(next(find_files(path))) except WindowsError as orig_error: tmpl = "Error accessing {path}: {orig_error.message}" - raise builtins.WindowsError(lf(tmpl)) + raise builtins.WindowsError(tmpl.format(**locals())) + def _is_symlink(find_data): return find_data.reserved[0] == api.IO_REPARSE_TAG_SYMLINK + def find_files(spec): """ A pythonic wrapper around the FindFirstFile/FindNextFile win32 api. @@ -133,11 +152,13 @@ def find_files(spec): error = WindowsError() if error.code == api.ERROR_NO_MORE_FILES: break - else: raise error + else: + raise error # todo: how to close handle when generator is destroyed? # hint: catch GeneratorExit windll.kernel32.FindClose(handle) + def get_final_path(path): """ For a given path, determine the ultimate location of that path. @@ -150,7 +171,9 @@ def get_final_path(path): trace_symlink_target instead. """ desired_access = api.NULL - share_mode = api.FILE_SHARE_READ | api.FILE_SHARE_WRITE | api.FILE_SHARE_DELETE + share_mode = ( + api.FILE_SHARE_READ | api.FILE_SHARE_WRITE | api.FILE_SHARE_DELETE + ) security_attributes = api.LPSECURITY_ATTRIBUTES() # NULL pointer hFile = api.CreateFile( path, @@ -160,15 +183,17 @@ def get_final_path(path): api.OPEN_EXISTING, api.FILE_FLAG_BACKUP_SEMANTICS, api.NULL, - ) + ) if hFile == api.INVALID_HANDLE_VALUE: raise WindowsError() - buf_size = api.GetFinalPathNameByHandle(hFile, api.LPWSTR(), 0, api.VOLUME_NAME_DOS) + buf_size = api.GetFinalPathNameByHandle( + hFile, LPWSTR(), 0, api.VOLUME_NAME_DOS) handle_nonzero_success(buf_size) buf = create_unicode_buffer(buf_size) - result_length = api.GetFinalPathNameByHandle(hFile, buf, len(buf), api.VOLUME_NAME_DOS) + result_length = api.GetFinalPathNameByHandle( + hFile, buf, len(buf), api.VOLUME_NAME_DOS) assert result_length < len(buf) handle_nonzero_success(result_length) @@ -176,22 +201,83 @@ def get_final_path(path): return buf[:result_length] + +def compat_stat(path): + """ + Generate stat as 
found on Python 3.2 and later. + """ + stat = os.stat(path) + info = get_file_info(path) + # rewrite st_ino, st_dev, and st_nlink based on file info + return nt.stat_result( + (stat.st_mode,) + + (info.file_index, info.volume_serial_number, info.number_of_links) + + stat[4:] + ) + + +def samefile(f1, f2): + """ + Backport of samefile from Python 3.2 with support for Windows. + """ + return posixpath.samestat(compat_stat(f1), compat_stat(f2)) + + +def get_file_info(path): + # open the file the same way CPython does in posixmodule.c + desired_access = api.FILE_READ_ATTRIBUTES + share_mode = 0 + security_attributes = None + creation_disposition = api.OPEN_EXISTING + flags_and_attributes = ( + api.FILE_ATTRIBUTE_NORMAL | + api.FILE_FLAG_BACKUP_SEMANTICS | + api.FILE_FLAG_OPEN_REPARSE_POINT + ) + template_file = None + + handle = api.CreateFile( + path, + desired_access, + share_mode, + security_attributes, + creation_disposition, + flags_and_attributes, + template_file, + ) + + if handle == api.INVALID_HANDLE_VALUE: + raise WindowsError() + + info = api.BY_HANDLE_FILE_INFORMATION() + res = api.GetFileInformationByHandle(handle, info) + handle_nonzero_success(res) + handle_nonzero_success(api.CloseHandle(handle)) + + return info + + def GetBinaryType(filepath): res = api.DWORD() handle_nonzero_success(api._GetBinaryType(filepath, res)) return res + def _make_null_terminated_list(obs): obs = _makelist(obs) - if obs is None: return + if obs is None: + return return u'\x00'.join(obs) + u'\x00\x00' + def _makelist(ob): - if ob is None: return + if ob is None: + return if not isinstance(ob, (list, tuple, set)): return [ob] return ob + def SHFileOperation(operation, from_, to=None, flags=[]): flags = functools.reduce(operator.or_, flags, 0) from_ = _make_null_terminated_list(from_) @@ -201,6 +287,7 @@ def SHFileOperation(operation, from_, to=None, flags=[]): if res != 0: raise RuntimeError("SHFileOperation returned %d" % res) + def join(*paths): r""" Wrapper around 
os.path.join that works with Windows drive letters. @@ -214,17 +301,17 @@ def join(*paths): drive = next(filter(None, reversed(drives)), '') return os.path.join(drive, os.path.join(*paths)) + def resolve_path(target, start=os.path.curdir): r""" Find a path from start to target where target is relative to start. - >>> orig_wd = os.getcwd() - >>> os.chdir('c:\\windows') # so we know what the working directory is + >>> tmp = str(getfixture('tmpdir_as_cwd')) >>> findpath('d:\\') 'd:\\' - >>> findpath('d:\\', 'c:\\windows') + >>> findpath('d:\\', tmp) 'd:\\' >>> findpath('\\bar', 'd:\\') @@ -239,11 +326,11 @@ def resolve_path(target, start=os.path.curdir): >>> findpath('\\baz', 'd:\\foo\\bar') # fails with '\\baz' 'd:\\baz' - >>> os.path.abspath(findpath('\\bar')) + >>> os.path.abspath(findpath('\\bar')).lower() 'c:\\bar' >>> os.path.abspath(findpath('bar')) - 'c:\\windows\\bar' + '...\\bar' >>> findpath('..', 'd:\\foo\\bar') 'd:\\foo' @@ -254,8 +341,10 @@ def resolve_path(target, start=os.path.curdir): """ return os.path.normpath(join(start, target)) + findpath = resolve_path + def trace_symlink_target(link): """ Given a file that is known to be a symlink, trace it to its ultimate @@ -273,6 +362,7 @@ def trace_symlink_target(link): link = resolve_path(link, orig) return link + def readlink(link): """ readlink(link) -> target @@ -286,12 +376,13 @@ def readlink(link): api.OPEN_EXISTING, api.FILE_FLAG_OPEN_REPARSE_POINT | api.FILE_FLAG_BACKUP_SEMANTICS, None, - ) + ) if handle == api.INVALID_HANDLE_VALUE: raise WindowsError() - res = reparse.DeviceIoControl(handle, api.FSCTL_GET_REPARSE_POINT, None, 10240) + res = reparse.DeviceIoControl( + handle, api.FSCTL_GET_REPARSE_POINT, None, 10240) bytes = create_string_buffer(res) p_rdb = cast(bytes, POINTER(api.REPARSE_DATA_BUFFER)) @@ -302,6 +393,7 @@ def readlink(link): handle_nonzero_success(api.CloseHandle(handle)) return rdb.get_substitute_name() + def patch_os_module(): """ jaraco.windows provides the os.symlink and 
os.readlink functions. @@ -313,6 +405,7 @@ def patch_os_module(): if not hasattr(os, 'readlink'): os.readlink = readlink + def find_symlinks(root): for dirpath, dirnames, filenames in os.walk(root): for name in dirnames + filenames: @@ -323,6 +416,7 @@ def find_symlinks(root): if name in dirnames: dirnames.remove(name) + def find_symlinks_cmd(): """ %prog [start-path] @@ -333,7 +427,8 @@ def find_symlinks_cmd(): from textwrap import dedent parser = OptionParser(usage=dedent(find_symlinks_cmd.__doc__).strip()) options, args = parser.parse_args() - if not args: args = ['.'] + if not args: + args = ['.'] root = args.pop() if args: parser.error("unexpected argument(s)") @@ -346,8 +441,19 @@ def find_symlinks_cmd(): except KeyboardInterrupt: pass + @six.add_metaclass(binary.BitMask) class FileAttributes(int): + + # extract the values from the stat module on Python 3.5 + # and later. + locals().update( + (name.split('FILE_ATTRIBUTES_')[1].lower(), value) + for name, value in vars(stat).items() + if name.startswith('FILE_ATTRIBUTES_') + ) + + # For Python 3.4 and earlier, define the constants here archive = 0x20 compressed = 0x800 hidden = 0x2 @@ -364,11 +470,16 @@ class FileAttributes(int): temporary = 0x100 virtual = 0x10000 -def GetFileAttributes(filepath): - attrs = api.GetFileAttributes(filepath) - if attrs == api.INVALID_FILE_ATTRIBUTES: - raise WindowsError() - return FileAttributes(attrs) + @classmethod + def get(cls, filepath): + attrs = api.GetFileAttributes(filepath) + if attrs == api.INVALID_FILE_ATTRIBUTES: + raise WindowsError() + return cls(attrs) + + +GetFileAttributes = FileAttributes.get + def SetFileAttributes(filepath, *attrs): """ @@ -382,8 +493,8 @@ def SetFileAttributes(filepath, *attrs): """ nice_names = collections.defaultdict( lambda key: key, - hidden = 'FILE_ATTRIBUTE_HIDDEN', - read_only = 'FILE_ATTRIBUTE_READONLY', + hidden='FILE_ATTRIBUTE_HIDDEN', + read_only='FILE_ATTRIBUTE_READONLY', ) flags = (getattr(api, nice_names[attr], attr) for attr 
in attrs) flags = functools.reduce(operator.or_, flags) diff --git a/libs/jaraco/windows/filesystem/backports.py b/libs/jaraco/windows/filesystem/backports.py new file mode 100644 index 00000000..abb45d07 --- /dev/null +++ b/libs/jaraco/windows/filesystem/backports.py @@ -0,0 +1,109 @@ +from __future__ import unicode_literals + +import os.path + + +# realpath taken from https://bugs.python.org/file38057/issue9949-v4.patch +def realpath(path): + if isinstance(path, str): + prefix = '\\\\?\\' + unc_prefix = prefix + 'UNC' + new_unc_prefix = '\\' + cwd = os.getcwd() + else: + prefix = b'\\\\?\\' + unc_prefix = prefix + b'UNC' + new_unc_prefix = b'\\' + cwd = os.getcwdb() + had_prefix = path.startswith(prefix) + path, ok = _resolve_path(cwd, path, {}) + # The path returned by _getfinalpathname will always start with \\?\ - + # strip off that prefix unless it was already provided on the original + # path. + if not had_prefix: + # For UNC paths, the prefix will actually be \\?\UNC - handle that + # case as well. + if path.startswith(unc_prefix): + path = new_unc_prefix + path[len(unc_prefix):] + elif path.startswith(prefix): + path = path[len(prefix):] + return path + + +def _resolve_path(path, rest, seen): + # Windows normalizes the path before resolving symlinks; be sure to + # follow the same behavior. + rest = os.path.normpath(rest) + + if isinstance(rest, str): + sep = '\\' + else: + sep = b'\\' + + if os.path.isabs(rest): + drive, rest = os.path.splitdrive(rest) + path = drive + sep + rest = rest[1:] + + while rest: + name, _, rest = rest.partition(sep) + new_path = os.path.join(path, name) if path else name + if os.path.exists(new_path): + if not rest: + # The whole path exists. Resolve it using the OS. + path = os.path._getfinalpathname(new_path) + else: + # The OS can resolve `new_path`; keep traversing the path. + path = new_path + elif not os.path.lexists(new_path): + # `new_path` does not exist on the filesystem at all. 
Use the + # OS to resolve `path`, if it exists, and then append the + # remainder. + if os.path.exists(path): + path = os.path._getfinalpathname(path) + rest = os.path.join(name, rest) if rest else name + return os.path.join(path, rest), True + else: + # We have a symbolic link that the OS cannot resolve. Try to + # resolve it ourselves. + + # On Windows, symbolic link resolution can be partially or + # fully disabled [1]. The end result of a disabled symlink + # appears the same as a broken symlink (lexists() returns True + # but exists() returns False). And in both cases, the link can + # still be read using readlink(). Call stat() and check the + # resulting error code to ensure we don't circumvent the + # Windows symbolic link restrictions. + # [1] https://technet.microsoft.com/en-us/library/cc754077.aspx + try: + os.stat(new_path) + except OSError as e: + # WinError 1463: The symbolic link cannot be followed + # because its type is disabled. + if e.winerror == 1463: + raise + + key = os.path.normcase(new_path) + if key in seen: + # This link has already been seen; try to use the + # previously resolved value. + path = seen[key] + if path is None: + # It has not yet been resolved, which means we must + # have a symbolic link loop. Return what we have + # resolved so far plus the remainder of the path (who + # cares about the Zen of Python?). + path = os.path.join(new_path, rest) if rest else new_path + return path, False + else: + # Mark this link as in the process of being resolved. + seen[key] = None + # Try to resolve it. + path, ok = _resolve_path(path, os.readlink(new_path), seen) + if ok: + # Resolution succeded; store the resolved value. + seen[key] = path + else: + # Resolution failed; punt. 
+ return (os.path.join(path, rest) if rest else path), False + return path, True diff --git a/libs/jaraco/windows/filesystem/change.py b/libs/jaraco/windows/filesystem/change.py index 50074f3e..620d9272 100644 --- a/libs/jaraco/windows/filesystem/change.py +++ b/libs/jaraco/windows/filesystem/change.py @@ -17,6 +17,8 @@ from threading import Thread import itertools import logging +import six + from more_itertools.recipes import consume import jaraco.text @@ -25,9 +27,11 @@ from jaraco.windows.api import event log = logging.getLogger(__name__) + class NotifierException(Exception): pass + class FileFilter(object): def set_root(self, root): self.root = root @@ -35,9 +39,11 @@ class FileFilter(object): def _get_file_path(self, filename): try: filename = os.path.join(self.root, filename) - except AttributeError: pass + except AttributeError: + pass return filename + class ModifiedTimeFilter(FileFilter): """ Returns true for each call where the modified time of the file is after @@ -53,6 +59,7 @@ class ModifiedTimeFilter(FileFilter): log.debug('{filepath} last modified at {last_mod}.'.format(**vars())) return last_mod > self.cutoff + class PatternFilter(FileFilter): """ Filter that returns True for files that match pattern (a regular @@ -60,13 +67,14 @@ class PatternFilter(FileFilter): """ def __init__(self, pattern): self.pattern = ( - re.compile(pattern) if isinstance(pattern, basestring) + re.compile(pattern) if isinstance(pattern, six.string_types) else pattern ) def __call__(self, file): return bool(self.pattern.match(file, re.I)) + class GlobFilter(PatternFilter): """ Filter that returns True for files that match the pattern (a glob @@ -102,6 +110,7 @@ class AggregateFilter(FileFilter): def __call__(self, file): return all(fil(file) for fil in self.filters) + class OncePerModFilter(FileFilter): def __init__(self): self.history = list() @@ -115,15 +124,18 @@ class OncePerModFilter(FileFilter): del self.history[-50:] return result + def files_with_path(files, path): 
return (os.path.join(path, file) for file in files) + def get_file_paths(walk_result): root, dirs, files = walk_result return files_with_path(files, root) + class Notifier(object): - def __init__(self, root = '.', filters = []): + def __init__(self, root='.', filters=[]): # assign the root, verify it exists self.root = root if not os.path.isdir(self.root): @@ -138,7 +150,8 @@ class Notifier(object): def __del__(self): try: fs.FindCloseChangeNotification(self.hChange) - except: pass + except Exception: + pass def _get_change_handle(self): # set up to monitor the directory tree specified @@ -151,8 +164,8 @@ class Notifier(object): # make sure it worked; if not, bail INVALID_HANDLE_VALUE = fs.INVALID_HANDLE_VALUE if self.hChange == INVALID_HANDLE_VALUE: - raise NotifierException('Could not set up directory change ' - 'notification') + raise NotifierException( + 'Could not set up directory change notification') @staticmethod def _filtered_walk(path, file_filter): @@ -171,6 +184,7 @@ class Notifier(object): def quit(self): event.SetEvent(self.quit_event) + class BlockingNotifier(Notifier): @staticmethod @@ -215,17 +229,18 @@ class BlockingNotifier(Notifier): result = next(results) return result + class ThreadedNotifier(BlockingNotifier, Thread): r""" ThreadedNotifier provides a simple interface that calls the handler for each file rooted in root that passes the filters. It runs as its own - thread, so must be started as such. + thread, so must be started as such:: - >>> notifier = ThreadedNotifier('c:\\', handler = StreamHandler()) # doctest: +SKIP - >>> notifier.start() # doctest: +SKIP - C:\Autoexec.bat changed. + notifier = ThreadedNotifier('c:\\', handler = StreamHandler()) + notifier.start() + C:\Autoexec.bat changed. 
""" - def __init__(self, root = '.', filters = [], handler = lambda file: None): + def __init__(self, root='.', filters=[], handler=lambda file: None): # init notifier stuff BlockingNotifier.__init__(self, root, filters) # init thread stuff @@ -242,13 +257,14 @@ class ThreadedNotifier(BlockingNotifier, Thread): for file in self.get_changed_files(): self.handle(file) + class StreamHandler(object): """ StreamHandler: a sample handler object for use with the threaded notifier that will announce by writing to the supplied stream (stdout by default) the name of the file. """ - def __init__(self, output = sys.stdout): + def __init__(self, output=sys.stdout): self.output = output def __call__(self, filename): diff --git a/libs/jaraco/windows/inet.py b/libs/jaraco/windows/inet.py index 5acd65ee..37c40cda 100644 --- a/libs/jaraco/windows/inet.py +++ b/libs/jaraco/windows/inet.py @@ -14,20 +14,21 @@ from jaraco.windows.api import errors, inet def GetAdaptersAddresses(): size = ctypes.c_ulong() - res = inet.GetAdaptersAddresses(0,0,None, None,size) + res = inet.GetAdaptersAddresses(0, 0, None, None, size) if res != errors.ERROR_BUFFER_OVERFLOW: raise RuntimeError("Error getting structure length (%d)" % res) print(size.value) pointer_type = ctypes.POINTER(inet.IP_ADAPTER_ADDRESSES) buffer = ctypes.create_string_buffer(size.value) struct_p = ctypes.cast(buffer, pointer_type) - res = inet.GetAdaptersAddresses(0,0,None, struct_p, size) + res = inet.GetAdaptersAddresses(0, 0, None, struct_p, size) if res != errors.NO_ERROR: raise RuntimeError("Error retrieving table (%d)" % res) while struct_p: yield struct_p.contents struct_p = struct_p.contents.next + class AllocatedTable(object): """ Both the interface table and the ip address table use the same @@ -79,20 +80,23 @@ class AllocatedTable(object): on the table size. 
""" table = self.get_table() - entries_array = self.row_structure*table.num_entries + entries_array = self.row_structure * table.num_entries pointer_type = ctypes.POINTER(entries_array) return ctypes.cast(table.entries, pointer_type).contents + class InterfaceTable(AllocatedTable): method = inet.GetIfTable structure = inet.MIB_IFTABLE row_structure = inet.MIB_IFROW + class AddressTable(AllocatedTable): method = inet.GetIpAddrTable structure = inet.MIB_IPADDRTABLE row_structure = inet.MIB_IPADDRROW + class AddressManager(object): @staticmethod def hardware_address_to_string(addr): @@ -100,7 +104,8 @@ class AddressManager(object): return ':'.join(hex_bytes) def get_host_mac_address_strings(self): - return (self.hardware_address_to_string(addr) + return ( + self.hardware_address_to_string(addr) for addr in self.get_host_mac_addresses()) def get_host_ip_address_strings(self): @@ -110,10 +115,10 @@ class AddressManager(object): return ( entry.physical_address for entry in InterfaceTable().entries - ) + ) def get_host_ip_addresses(self): return ( entry.address for entry in AddressTable().entries - ) + ) diff --git a/libs/jaraco/windows/lib.py b/libs/jaraco/windows/lib.py index 7cca12a0..0602c8e0 100644 --- a/libs/jaraco/windows/lib.py +++ b/libs/jaraco/windows/lib.py @@ -2,6 +2,7 @@ import ctypes from .api import library + def find_lib(lib): r""" Find the DLL for a given library. 
diff --git a/libs/jaraco/windows/memory.py b/libs/jaraco/windows/memory.py index 3455f17d..d4bcb83c 100644 --- a/libs/jaraco/windows/memory.py +++ b/libs/jaraco/windows/memory.py @@ -3,6 +3,7 @@ from ctypes import WinError from .api import memory + class LockedMemory(object): def __init__(self, handle): self.handle = handle diff --git a/libs/jaraco/windows/mmap.py b/libs/jaraco/windows/mmap.py index 8ae7d7ca..11460894 100644 --- a/libs/jaraco/windows/mmap.py +++ b/libs/jaraco/windows/mmap.py @@ -5,6 +5,7 @@ import six from .error import handle_nonzero_success from .api import memory + class MemoryMap(object): """ A memory map object which can have security attributes overridden. diff --git a/libs/jaraco/windows/msie.py b/libs/jaraco/windows/msie.py index 2e2223ce..c4b5793c 100644 --- a/libs/jaraco/windows/msie.py +++ b/libs/jaraco/windows/msie.py @@ -10,13 +10,15 @@ import itertools import six + class CookieMonster(object): "Read cookies out of a user's IE cookies file" @property def cookie_dir(self): import _winreg as winreg - key = winreg.OpenKeyEx(winreg.HKEY_CURRENT_USER, 'Software' + key = winreg.OpenKeyEx( + winreg.HKEY_CURRENT_USER, 'Software' '\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders') cookie_dir, type = winreg.QueryValueEx(key, 'Cookies') return cookie_dir @@ -24,10 +26,12 @@ class CookieMonster(object): def entries(self, filename): with open(os.path.join(self.cookie_dir, filename)) as cookie_file: while True: - entry = itertools.takewhile(self.is_not_cookie_delimiter, + entry = itertools.takewhile( + self.is_not_cookie_delimiter, cookie_file) entry = list(map(six.text_type.rstrip, entry)) - if not entry: break + if not entry: + break cookie = self.make_cookie(*entry) yield cookie @@ -36,8 +40,9 @@ class CookieMonster(object): return s != '*\n' @staticmethod - def make_cookie(key, value, domain, flags, ExpireLow, ExpireHigh, - CreateLow, CreateHigh): + def make_cookie( + key, value, domain, flags, ExpireLow, ExpireHigh, + CreateLow, 
CreateHigh): expires = (int(ExpireHigh) << 32) | int(ExpireLow) created = (int(CreateHigh) << 32) | int(CreateLow) flags = int(flags) @@ -51,4 +56,4 @@ class CookieMonster(object): expires=expires, created=created, path=path, - ) + ) diff --git a/libs/jaraco/windows/net.py b/libs/jaraco/windows/net.py index e2057aba..709f0dbf 100644 --- a/libs/jaraco/windows/net.py +++ b/libs/jaraco/windows/net.py @@ -7,8 +7,10 @@ __all__ = ('AddConnection') from jaraco.windows.error import WindowsError from .api import net -def AddConnection(remote_name, type=net.RESOURCETYPE_ANY, local_name=None, - provider_name=None, user=None, password=None, flags=0): + +def AddConnection( + remote_name, type=net.RESOURCETYPE_ANY, local_name=None, + provider_name=None, user=None, password=None, flags=0): resource = net.NETRESOURCE( type=type, remote_name=remote_name, diff --git a/libs/jaraco/windows/power.py b/libs/jaraco/windows/power.py index 07f5c779..8d8276fa 100644 --- a/libs/jaraco/windows/power.py +++ b/libs/jaraco/windows/power.py @@ -1,10 +1,9 @@ -#-*- coding: utf-8 -*- +# -*- coding: utf-8 -*- from __future__ import print_function import itertools import contextlib -import ctypes from more_itertools.recipes import consume, unique_justseen try: @@ -15,11 +14,13 @@ except ImportError: from jaraco.windows.error import handle_nonzero_success from .api import power + def GetSystemPowerStatus(): stat = power.SYSTEM_POWER_STATUS() handle_nonzero_success(GetSystemPowerStatus(stat)) return stat + def _init_power_watcher(): global power_watcher if 'power_watcher' not in globals(): @@ -27,18 +28,22 @@ def _init_power_watcher(): query = 'SELECT * from Win32_PowerManagementEvent' power_watcher = wmi.ExecNotificationQuery(query) + def get_power_management_events(): _init_power_watcher() while True: yield power_watcher.NextEvent() + def wait_for_power_status_change(): EVT_POWER_STATUS_CHANGE = 10 + def not_power_status_change(evt): return evt.EventType != EVT_POWER_STATUS_CHANGE events = 
get_power_management_events() consume(itertools.takewhile(not_power_status_change, events)) + def get_unique_power_states(): """ Just like get_power_states, but ensures values are returned only @@ -46,6 +51,7 @@ def get_unique_power_states(): """ return unique_justseen(get_power_states()) + def get_power_states(): """ Continuously return the power state of the system when it changes. @@ -57,6 +63,7 @@ def get_power_states(): yield state.ac_line_status_string wait_for_power_status_change() + @contextlib.contextmanager def no_sleep(): """ diff --git a/libs/jaraco/windows/privilege.py b/libs/jaraco/windows/privilege.py index 5a1e2bbb..848a526d 100644 --- a/libs/jaraco/windows/privilege.py +++ b/libs/jaraco/windows/privilege.py @@ -7,24 +7,31 @@ from .api import security from .api import privilege from .api import process + def get_process_token(): """ Get the current process token """ token = wintypes.HANDLE() - res = process.OpenProcessToken(process.GetCurrentProcess(), process.TOKEN_ALL_ACCESS, token) - if not res > 0: raise RuntimeError("Couldn't get process token") + res = process.OpenProcessToken( + process.GetCurrentProcess(), process.TOKEN_ALL_ACCESS, token) + if not res > 0: + raise RuntimeError("Couldn't get process token") return token + def get_symlink_luid(): """ Get the LUID for the SeCreateSymbolicLinkPrivilege """ symlink_luid = privilege.LUID() - res = privilege.LookupPrivilegeValue(None, "SeCreateSymbolicLinkPrivilege", symlink_luid) - if not res > 0: raise RuntimeError("Couldn't lookup privilege value") + res = privilege.LookupPrivilegeValue( + None, "SeCreateSymbolicLinkPrivilege", symlink_luid) + if not res > 0: + raise RuntimeError("Couldn't lookup privilege value") return symlink_luid + def get_privilege_information(): """ Get all privileges associated with the current process. 
@@ -38,7 +45,7 @@ def get_privilege_information(): None, 0, return_length, - ] + ] res = privilege.GetTokenInformation(*params) @@ -51,9 +58,11 @@ def get_privilege_information(): res = privilege.GetTokenInformation(*params) assert res > 0, "Error in second GetTokenInformation (%d)" % res - privileges = ctypes.cast(buffer, ctypes.POINTER(privilege.TOKEN_PRIVILEGES)).contents + privileges = ctypes.cast( + buffer, ctypes.POINTER(privilege.TOKEN_PRIVILEGES)).contents return privileges + def report_privilege_information(): """ Report all privilege information assigned to the current process. @@ -62,6 +71,7 @@ def report_privilege_information(): print("found {0} privileges".format(privileges.count)) tuple(map(print, privileges)) + def enable_symlink_privilege(): """ Try to assign the symlink privilege to the current process token. @@ -84,9 +94,11 @@ def enable_symlink_privilege(): ERROR_NOT_ALL_ASSIGNED = 1300 return ctypes.windll.kernel32.GetLastError() != ERROR_NOT_ALL_ASSIGNED + class PolicyHandle(wintypes.HANDLE): pass + class LSA_UNICODE_STRING(ctypes.Structure): _fields_ = [ ('length', ctypes.c_ushort), @@ -94,15 +106,20 @@ class LSA_UNICODE_STRING(ctypes.Structure): ('buffer', ctypes.wintypes.LPWSTR), ] + def OpenPolicy(system_name, object_attributes, access_mask): policy = PolicyHandle() - raise NotImplementedError("Need to construct structures for parameters " - "(see http://msdn.microsoft.com/en-us/library/windows/desktop/aa378299%28v=vs.85%29.aspx)") - res = ctypes.windll.advapi32.LsaOpenPolicy(system_name, object_attributes, + raise NotImplementedError( + "Need to construct structures for parameters " + "(see http://msdn.microsoft.com/en-us/library/windows" + "/desktop/aa378299%28v=vs.85%29.aspx)") + res = ctypes.windll.advapi32.LsaOpenPolicy( + system_name, object_attributes, access_mask, ctypes.byref(policy)) assert res == 0, "Error status {res}".format(**vars()) return policy + def grant_symlink_privilege(who, machine=''): """ Grant the 'create symlink' 
privilege to who. @@ -113,10 +130,13 @@ def grant_symlink_privilege(who, machine=''): policy = OpenPolicy(machine, flags) return policy + def main(): assigned = enable_symlink_privilege() msg = ['failure', 'success'][assigned] print("Symlink privilege assignment completed with {0}".format(msg)) -if __name__ == '__main__': main() + +if __name__ == '__main__': + main() diff --git a/libs/jaraco/windows/registry.py b/libs/jaraco/windows/registry.py index 96f16c7b..b6f3b239 100644 --- a/libs/jaraco/windows/registry.py +++ b/libs/jaraco/windows/registry.py @@ -3,6 +3,7 @@ from itertools import count import six winreg = six.moves.winreg + def key_values(key): for index in count(): try: @@ -10,6 +11,7 @@ def key_values(key): except WindowsError: break + def key_subkeys(key): for index in count(): try: diff --git a/libs/jaraco/windows/reparse.py b/libs/jaraco/windows/reparse.py index 67b5e2d7..2751e967 100644 --- a/libs/jaraco/windows/reparse.py +++ b/libs/jaraco/windows/reparse.py @@ -5,7 +5,9 @@ import ctypes.wintypes from .error import handle_nonzero_success from .api import filesystem -def DeviceIoControl(device, io_control_code, in_buffer, out_buffer, overlapped=None): + +def DeviceIoControl( + device, io_control_code, in_buffer, out_buffer, overlapped=None): if overlapped is not None: raise NotImplementedError("overlapped handles not yet supported") @@ -25,7 +27,7 @@ def DeviceIoControl(device, io_control_code, in_buffer, out_buffer, overlapped=N out_buffer, out_buffer_size, returned_bytes, overlapped, - ) + ) handle_nonzero_success(res) handle_nonzero_success(returned_bytes) diff --git a/libs/jaraco/windows/security.py b/libs/jaraco/windows/security.py index f5859462..7c481ed6 100644 --- a/libs/jaraco/windows/security.py +++ b/libs/jaraco/windows/security.py @@ -3,20 +3,24 @@ import ctypes.wintypes from jaraco.windows.error import handle_nonzero_success from .api import security + def GetTokenInformation(token, information_class): """ Given a token, get the token 
information for it. """ data_size = ctypes.wintypes.DWORD() - ctypes.windll.advapi32.GetTokenInformation(token, information_class.num, + ctypes.windll.advapi32.GetTokenInformation( + token, information_class.num, 0, 0, ctypes.byref(data_size)) data = ctypes.create_string_buffer(data_size.value) - handle_nonzero_success(ctypes.windll.advapi32.GetTokenInformation(token, + handle_nonzero_success(ctypes.windll.advapi32.GetTokenInformation( + token, information_class.num, ctypes.byref(data), ctypes.sizeof(data), ctypes.byref(data_size))) return ctypes.cast(data, ctypes.POINTER(security.TOKEN_USER)).contents + def OpenProcessToken(proc_handle, access): result = ctypes.wintypes.HANDLE() proc_handle = ctypes.wintypes.HANDLE(proc_handle) @@ -24,6 +28,7 @@ def OpenProcessToken(proc_handle, access): proc_handle, access, ctypes.byref(result))) return result + def get_current_user(): """ Return a TOKEN_USER for the owner of this process. @@ -34,6 +39,7 @@ def get_current_user(): ) return GetTokenInformation(process, security.TOKEN_USER) + def get_security_attributes_for_user(user=None): """ Return a SECURITY_ATTRIBUTES structure with the SID set to the @@ -42,7 +48,8 @@ def get_security_attributes_for_user(user=None): if user is None: user = get_current_user() - assert isinstance(user, security.TOKEN_USER), "user must be TOKEN_USER instance" + assert isinstance(user, security.TOKEN_USER), ( + "user must be TOKEN_USER instance") SD = security.SECURITY_DESCRIPTOR() SA = security.SECURITY_ATTRIBUTES() @@ -51,8 +58,10 @@ def get_security_attributes_for_user(user=None): SA.descriptor = SD SA.bInheritHandle = 1 - ctypes.windll.advapi32.InitializeSecurityDescriptor(ctypes.byref(SD), + ctypes.windll.advapi32.InitializeSecurityDescriptor( + ctypes.byref(SD), security.SECURITY_DESCRIPTOR.REVISION) - ctypes.windll.advapi32.SetSecurityDescriptorOwner(ctypes.byref(SD), + ctypes.windll.advapi32.SetSecurityDescriptorOwner( + ctypes.byref(SD), user.SID, 0) return SA diff --git 
a/libs/jaraco/windows/services.py b/libs/jaraco/windows/services.py index 5d8a41fa..97cea7ab 100644 --- a/libs/jaraco/windows/services.py +++ b/libs/jaraco/windows/services.py @@ -1,7 +1,8 @@ """ Windows Services support for controlling Windows Services. -Based on http://code.activestate.com/recipes/115875-controlling-windows-services/ +Based on http://code.activestate.com +/recipes/115875-controlling-windows-services/ """ from __future__ import print_function @@ -13,6 +14,7 @@ import win32api import win32con import win32service + class Service(object): """ The Service Class is used for controlling Windows @@ -47,7 +49,8 @@ class Service(object): pause: Pauses service (Only if service supports feature). resume: Resumes service that has been paused. status: Queries current status of service. - fetchstatus: Continually queries service until requested status(STARTING, RUNNING, + fetchstatus: Continually queries service until requested + status(STARTING, RUNNING, STOPPING & STOPPED) is met or timeout value(in seconds) reached. Default timeout value is infinite. infotype: Queries service for process type. 
(Single, shared and/or @@ -64,18 +67,21 @@ class Service(object): def __init__(self, service, machinename=None, dbname=None): self.userv = service - self.scmhandle = win32service.OpenSCManager(machinename, dbname, win32service.SC_MANAGER_ALL_ACCESS) + self.scmhandle = win32service.OpenSCManager( + machinename, dbname, win32service.SC_MANAGER_ALL_ACCESS) self.sserv, self.lserv = self.getname() if (self.sserv or self.lserv) is None: sys.exit() - self.handle = win32service.OpenService(self.scmhandle, self.sserv, win32service.SERVICE_ALL_ACCESS) + self.handle = win32service.OpenService( + self.scmhandle, self.sserv, win32service.SERVICE_ALL_ACCESS) self.sccss = "SYSTEM\\CurrentControlSet\\Services\\" def start(self): win32service.StartService(self.handle, None) def stop(self): - self.stat = win32service.ControlService(self.handle, win32service.SERVICE_CONTROL_STOP) + self.stat = win32service.ControlService( + self.handle, win32service.SERVICE_CONTROL_STOP) def restart(self): self.stop() @@ -83,29 +89,31 @@ class Service(object): self.start() def pause(self): - self.stat = win32service.ControlService(self.handle, win32service.SERVICE_CONTROL_PAUSE) + self.stat = win32service.ControlService( + self.handle, win32service.SERVICE_CONTROL_PAUSE) def resume(self): - self.stat = win32service.ControlService(self.handle, win32service.SERVICE_CONTROL_CONTINUE) + self.stat = win32service.ControlService( + self.handle, win32service.SERVICE_CONTROL_CONTINUE) - def status(self, prn = 0): + def status(self, prn=0): self.stat = win32service.QueryServiceStatus(self.handle) - if self.stat[1]==win32service.SERVICE_STOPPED: + if self.stat[1] == win32service.SERVICE_STOPPED: if prn == 1: print("The %s service is stopped." % self.lserv) else: return "STOPPED" - elif self.stat[1]==win32service.SERVICE_START_PENDING: + elif self.stat[1] == win32service.SERVICE_START_PENDING: if prn == 1: print("The %s service is starting." 
% self.lserv) else: return "STARTING" - elif self.stat[1]==win32service.SERVICE_STOP_PENDING: + elif self.stat[1] == win32service.SERVICE_STOP_PENDING: if prn == 1: print("The %s service is stopping." % self.lserv) else: return "STOPPING" - elif self.stat[1]==win32service.SERVICE_RUNNING: + elif self.stat[1] == win32service.SERVICE_RUNNING: if prn == 1: print("The %s service is running." % self.lserv) else: @@ -116,6 +124,7 @@ class Service(object): if timeout is not None: timeout = int(timeout) timeout *= 2 + def to(timeout): time.sleep(.5) if timeout is not None: @@ -127,11 +136,11 @@ class Service(object): if self.fstatus == "STOPPED": while 1: self.stat = win32service.QueryServiceStatus(self.handle) - if self.stat[1]==win32service.SERVICE_STOPPED: + if self.stat[1] == win32service.SERVICE_STOPPED: self.fstate = "STOPPED" break else: - timeout=to(timeout) + timeout = to(timeout) if timeout == "TO": return "TIMEDOUT" break diff --git a/libs/jaraco/windows/shell.py b/libs/jaraco/windows/shell.py index 133bab44..58333359 100644 --- a/libs/jaraco/windows/shell.py +++ b/libs/jaraco/windows/shell.py @@ -1,10 +1,12 @@ from .api import shell + def get_recycle_bin_confirm(): settings = shell.SHELLSTATE() shell.SHGetSetSettings(settings, shell.SSF_NOCONFIRMRECYCLE, False) return not settings.no_confirm_recycle + def set_recycle_bin_confirm(confirm=False): settings = shell.SHELLSTATE() settings.no_confirm_recycle = not confirm diff --git a/libs/jaraco/windows/timers.py b/libs/jaraco/windows/timers.py index 14f85427..626118a9 100644 --- a/libs/jaraco/windows/timers.py +++ b/libs/jaraco/windows/timers.py @@ -14,6 +14,7 @@ from jaraco.windows.api import event as win32event __author__ = 'Jason R. 
Coombs ' + class WaitableTimer: """ t = WaitableTimer() @@ -32,12 +33,12 @@ class WaitableTimer: def stop(self): win32event.SetEvent(self.stop_event) - def wait_for_signal(self, timeout = None): + def wait_for_signal(self, timeout=None): """ wait for the signal; return after the signal has occurred or the timeout in seconds elapses. """ - timeout_ms = int(timeout*1000) if timeout else win32event.INFINITE + timeout_ms = int(timeout * 1000) if timeout else win32event.INFINITE win32event.WaitForSingleObject(self.signal_event, timeout_ms) def _signal_loop(self, due_time, period): @@ -54,14 +55,14 @@ class WaitableTimer: except Exception: pass - #we're done here, just quit - def _wait(self, seconds): - milliseconds = int(seconds*1000) + milliseconds = int(seconds * 1000) if milliseconds > 0: res = win32event.WaitForSingleObject(self.stop_event, milliseconds) - if res == win32event.WAIT_OBJECT_0: raise Exception - if res == win32event.WAIT_TIMEOUT: pass + if res == win32event.WAIT_OBJECT_0: + raise Exception + if res == win32event.WAIT_TIMEOUT: + pass win32event.SetEvent(self.signal_event) @staticmethod diff --git a/libs/jaraco/windows/timezone.py b/libs/jaraco/windows/timezone.py index 9492143b..7eedcf0b 100644 --- a/libs/jaraco/windows/timezone.py +++ b/libs/jaraco/windows/timezone.py @@ -8,6 +8,7 @@ from ctypes.wintypes import WORD, WCHAR, BOOL, LONG from jaraco.windows.util import Extended from jaraco.collections import RangeMap + class AnyDict(object): "A dictionary that returns the same value regardless of key" @@ -17,6 +18,7 @@ class AnyDict(object): def __getitem__(self, key): return self.value + class SYSTEMTIME(Extended, ctypes.Structure): _fields_ = [ ('year', WORD), @@ -29,6 +31,7 @@ class SYSTEMTIME(Extended, ctypes.Structure): ('millisecond', WORD), ] + class REG_TZI_FORMAT(Extended, ctypes.Structure): _fields_ = [ ('bias', LONG), @@ -38,17 +41,19 @@ class REG_TZI_FORMAT(Extended, ctypes.Structure): ('daylight_start', SYSTEMTIME), ] + class 
TIME_ZONE_INFORMATION(Extended, ctypes.Structure): _fields_ = [ ('bias', LONG), - ('standard_name', WCHAR*32), + ('standard_name', WCHAR * 32), ('standard_start', SYSTEMTIME), ('standard_bias', LONG), - ('daylight_name', WCHAR*32), + ('daylight_name', WCHAR * 32), ('daylight_start', SYSTEMTIME), ('daylight_bias', LONG), ] + class DYNAMIC_TIME_ZONE_INFORMATION(TIME_ZONE_INFORMATION): """ Because the structure of the DYNAMIC_TIME_ZONE_INFORMATION extends @@ -70,7 +75,7 @@ class DYNAMIC_TIME_ZONE_INFORMATION(TIME_ZONE_INFORMATION): """ _fields_ = [ # ctypes automatically includes the fields from the parent - ('key_name', WCHAR*128), + ('key_name', WCHAR * 128), ('dynamic_daylight_time_disabled', BOOL), ] @@ -89,6 +94,7 @@ class DYNAMIC_TIME_ZONE_INFORMATION(TIME_ZONE_INFORMATION): kwargs[field_name] = arg super(DYNAMIC_TIME_ZONE_INFORMATION, self).__init__(*self_args, **kwargs) + class Info(DYNAMIC_TIME_ZONE_INFORMATION): """ A time zone definition class based on the win32 @@ -114,7 +120,7 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION): self.__init_from_other, self.__init_from_reg_tzi, self.__init_from_bytes, - ) + ) for func in funcs: try: func(*args, **kwargs) @@ -126,7 +132,7 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION): def __init_from_bytes(self, bytes, **kwargs): reg_tzi = REG_TZI_FORMAT() # todo: use buffer API in Python 3 - buffer = buffer(bytes) + buffer = memoryview(bytes) ctypes.memmove(ctypes.addressof(reg_tzi), buffer, len(buffer)) self.__init_from_reg_tzi(self, reg_tzi, **kwargs) @@ -146,12 +152,14 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION): value = super(Info, other).__getattribute__(other, name) setattr(self, name, value) # consider instead of the loop above just copying the memory directly - #size = max(ctypes.sizeof(DYNAMIC_TIME_ZONE_INFO), ctypes.sizeof(other)) - #ctypes.memmove(ctypes.addressof(self), other, size) + # size = max(ctypes.sizeof(DYNAMIC_TIME_ZONE_INFO), ctypes.sizeof(other)) + # ctypes.memmove(ctypes.addressof(self), other, size) def 
__getattribute__(self, attr): value = super(Info, self).__getattribute__(attr) - make_minute_timedelta = lambda m: datetime.timedelta(minutes = m) + + def make_minute_timedelta(m): + datetime.timedelta(minutes=m) if 'bias' in attr: value = make_minute_timedelta(value) return value @@ -205,10 +213,12 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION): def _locate_day(year, cutoff): """ Takes a SYSTEMTIME object, such as retrieved from a TIME_ZONE_INFORMATION - structure or call to GetTimeZoneInformation and interprets it based on the given + structure or call to GetTimeZoneInformation and interprets + it based on the given year to identify the actual day. - This method is necessary because the SYSTEMTIME structure refers to a day by its + This method is necessary because the SYSTEMTIME structure + refers to a day by its day of the week and week of the month (e.g. 4th saturday in March). >>> SATURDAY = 6 @@ -227,9 +237,11 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION): week_of_month = cutoff.day # so the following is the first day of that week day = (week_of_month - 1) * 7 + 1 - result = datetime.datetime(year, cutoff.month, day, + result = datetime.datetime( + year, cutoff.month, day, cutoff.hour, cutoff.minute, cutoff.second, cutoff.millisecond) - # now the result is the correct week, but not necessarily the correct day of the week + # now the result is the correct week, but not necessarily + # the correct day of the week days_to_go = (target_weekday - result.weekday()) % 7 result += datetime.timedelta(days_to_go) # if we selected a day in the month following the target month, @@ -238,5 +250,5 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION): # to be the last week in a month and adding the time delta might have # pushed the result into the next month. 
while result.month == cutoff.month + 1: - result -= datetime.timedelta(weeks = 1) + result -= datetime.timedelta(weeks=1) return result diff --git a/libs/jaraco/windows/ui.py b/libs/jaraco/windows/ui.py index 4be063db..20f948f3 100644 --- a/libs/jaraco/windows/ui.py +++ b/libs/jaraco/windows/ui.py @@ -3,6 +3,7 @@ import ctypes from jaraco.windows.util import ensure_unicode + def MessageBox(text, caption=None, handle=None, type=None): text, caption = map(ensure_unicode, (text, caption)) ctypes.windll.user32.MessageBoxW(handle, text, caption, type) diff --git a/libs/jaraco/windows/user.py b/libs/jaraco/windows/user.py index 0cb6e903..9b574777 100644 --- a/libs/jaraco/windows/user.py +++ b/libs/jaraco/windows/user.py @@ -3,6 +3,7 @@ from .api import errors from .api.user import GetUserName from .error import WindowsError, handle_nonzero_success + def get_user_name(): size = ctypes.wintypes.DWORD() try: diff --git a/libs/jaraco/windows/util.py b/libs/jaraco/windows/util.py index 9c6ae1aa..5524df85 100644 --- a/libs/jaraco/windows/util.py +++ b/libs/jaraco/windows/util.py @@ -2,17 +2,19 @@ import ctypes + def ensure_unicode(param): try: param = ctypes.create_unicode_buffer(param) except TypeError: - pass # just return the param as is + pass # just return the param as is return param + class Extended(object): "Used to add extended capability to structures" def __eq__(self, other): - return buffer(self) == buffer(other) + return memoryview(self) == memoryview(other) def __ne__(self, other): - return buffer(self) != buffer(other) + return memoryview(self) != memoryview(other) diff --git a/libs/jaraco/windows/vpn.py b/libs/jaraco/windows/vpn.py index a76f7363..9cf31dc1 100644 --- a/libs/jaraco/windows/vpn.py +++ b/libs/jaraco/windows/vpn.py @@ -1,11 +1,14 @@ import os -from path import path +from path import Path + def install_pptp(name, param_lines): """ """ - # or consider using the API: http://msdn.microsoft.com/en-us/library/aa446739%28v=VS.85%29.aspx - pbk_path = 
(path(os.environ['PROGRAMDATA']) + # or consider using the API: + # http://msdn.microsoft.com/en-us/library/aa446739%28v=VS.85%29.aspx + pbk_path = ( + Path(os.environ['PROGRAMDATA']) / 'Microsoft' / 'Network' / 'Connections' / 'pbk' / 'rasphone.pbk') pbk_path.dirname().makedirs_p() with open(pbk_path, 'a') as pbk: diff --git a/libs/jaraco/windows/xmouse.py b/libs/jaraco/windows/xmouse.py index 54b71e00..20b19435 100644 --- a/libs/jaraco/windows/xmouse.py +++ b/libs/jaraco/windows/xmouse.py @@ -17,6 +17,7 @@ def set(value): ) handle_nonzero_success(result) + def get(): value = ctypes.wintypes.BOOL() result = system.SystemParametersInfo( @@ -28,6 +29,7 @@ def get(): handle_nonzero_success(result) return bool(value) + def set_delay(milliseconds): result = system.SystemParametersInfo( system.SPI_SETACTIVEWNDTRKTIMEOUT, @@ -37,6 +39,7 @@ def set_delay(milliseconds): ) handle_nonzero_success(result) + def get_delay(): value = ctypes.wintypes.DWORD() result = system.SystemParametersInfo( diff --git a/libs/more_itertools/__init__.py b/libs/more_itertools/__init__.py index 5a3467fe..bba462c3 100644 --- a/libs/more_itertools/__init__.py +++ b/libs/more_itertools/__init__.py @@ -1,2 +1,2 @@ -from more_itertools.more import * -from more_itertools.recipes import * +from more_itertools.more import * # noqa +from more_itertools.recipes import * # noqa diff --git a/libs/more_itertools/more.py b/libs/more_itertools/more.py index 56512ce4..05e851ee 100644 --- a/libs/more_itertools/more.py +++ b/libs/more_itertools/more.py @@ -1,54 +1,130 @@ +from __future__ import print_function + +from collections import Counter, defaultdict, deque from functools import partial, wraps -from itertools import izip_longest -from recipes import * +from heapq import merge +from itertools import ( + chain, + compress, + count, + cycle, + dropwhile, + groupby, + islice, + repeat, + starmap, + takewhile, + tee +) +from operator import itemgetter, lt, gt, sub +from sys import maxsize, version_info +try: + 
from collections.abc import Sequence +except ImportError: + from collections import Sequence -__all__ = ['chunked', 'first', 'peekable', 'collate', 'consumer', 'ilen', - 'iterate', 'with_iter'] +from six import binary_type, string_types, text_type +from six.moves import filter, map, range, zip, zip_longest +from .recipes import consume, flatten, take + +__all__ = [ + 'adjacent', + 'always_iterable', + 'always_reversible', + 'bucket', + 'chunked', + 'circular_shifts', + 'collapse', + 'collate', + 'consecutive_groups', + 'consumer', + 'count_cycle', + 'difference', + 'distinct_permutations', + 'distribute', + 'divide', + 'exactly_n', + 'first', + 'groupby_transform', + 'ilen', + 'interleave_longest', + 'interleave', + 'intersperse', + 'islice_extended', + 'iterate', + 'last', + 'locate', + 'lstrip', + 'make_decorator', + 'map_reduce', + 'numeric_range', + 'one', + 'padded', + 'peekable', + 'replace', + 'rlocate', + 'rstrip', + 'run_length', + 'seekable', + 'SequenceView', + 'side_effect', + 'sliced', + 'sort_together', + 'split_at', + 'split_after', + 'split_before', + 'spy', + 'stagger', + 'strip', + 'unique_to_each', + 'windowed', + 'with_iter', + 'zip_offset', +] _marker = object() def chunked(iterable, n): - """Break an iterable into lists of a given length:: + """Break *iterable* into lists of length *n*: - >>> list(chunked([1, 2, 3, 4, 5, 6, 7], 3)) - [[1, 2, 3], [4, 5, 6], [7]] + >>> list(chunked([1, 2, 3, 4, 5, 6], 3)) + [[1, 2, 3], [4, 5, 6]] - If the length of ``iterable`` is not evenly divisible by ``n``, the last - returned list will be shorter. + If the length of *iterable* is not evenly divisible by *n*, the last + returned list will be shorter: - This is useful for splitting up a computation on a large number of keys - into batches, to be pickled and sent off to worker processes. One example - is operations on rows in MySQL, which does not implement server-side - cursors properly and would otherwise load the entire dataset into RAM on - the client. 
+ >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3)) + [[1, 2, 3], [4, 5, 6], [7, 8]] + + To use a fill-in value instead, see the :func:`grouper` recipe. + + :func:`chunked` is useful for splitting up a computation on a large number + of keys into batches, to be pickled and sent off to worker processes. One + example is operations on rows in MySQL, which does not implement + server-side cursors properly and would otherwise load the entire dataset + into RAM on the client. """ - # Doesn't seem to run into any number-of-args limits. - for group in (list(g) for g in izip_longest(*[iter(iterable)] * n, - fillvalue=_marker)): - if group[-1] is _marker: - # If this is the last group, shuck off the padding: - del group[group.index(_marker):] - yield group + return iter(partial(take, n, iter(iterable)), []) def first(iterable, default=_marker): - """Return the first item of an iterable, ``default`` if there is none. + """Return the first item of *iterable*, or *default* if *iterable* is + empty. - >>> first(xrange(4)) + >>> first([0, 1, 2, 3]) 0 - >>> first(xrange(0), 'some default') + >>> first([], 'some default') 'some default' - If ``default`` is not provided and there are no items in the iterable, + If *default* is not provided and there are no items in the iterable, raise ``ValueError``. - ``first()`` is useful when you have a generator of expensive-to-retrieve + :func:`first` is useful when you have a generator of expensive-to-retrieve values and want any arbitrary one. It is marginally shorter than - ``next(iter(...))`` but saves you an entire ``try``/``except`` when you - want to provide a fallback value. + ``next(iter(iterable), default)``. """ try: @@ -64,55 +140,108 @@ def first(iterable, default=_marker): return default +def last(iterable, default=_marker): + """Return the last item of *iterable*, or *default* if *iterable* is + empty. 
+ + >>> last([0, 1, 2, 3]) + 3 + >>> last([], 'some default') + 'some default' + + If *default* is not provided and there are no items in the iterable, + raise ``ValueError``. + """ + try: + try: + # Try to access the last item directly + return iterable[-1] + except (TypeError, AttributeError, KeyError): + # If not slice-able, iterate entirely using length-1 deque + return deque(iterable, maxlen=1)[0] + except IndexError: # If the iterable was empty + if default is _marker: + raise ValueError('last() was called on an empty iterable, and no ' + 'default value was provided.') + return default + + class peekable(object): - """Wrapper for an iterator to allow 1-item lookahead + """Wrap an iterator to allow lookahead and prepending elements. - Call ``peek()`` on the result to get the value that will next pop out of - ``next()``, without advancing the iterator: + Call :meth:`peek` on the result to get the value that will be returned + by :func:`next`. This won't advance the iterator: - >>> p = peekable(xrange(2)) + >>> p = peekable(['a', 'b']) >>> p.peek() - 0 - >>> p.next() - 0 - >>> p.peek() - 1 - >>> p.next() - 1 + 'a' + >>> next(p) + 'a' - Pass ``peek()`` a default value, and it will be returned in the case where - the iterator is exhausted: + Pass :meth:`peek` a default value to return that instead of raising + ``StopIteration`` when the iterator is exhausted. >>> p = peekable([]) >>> p.peek('hi') 'hi' - If no default is provided, ``peek()`` raises ``StopIteration`` when there - are no items left. + peekables also offer a :meth:`prepend` method, which "inserts" items + at the head of the iterable: - To test whether there are more items in the iterator, examine the - peekable's truth value. If it is truthy, there are more items. + >>> p = peekable([1, 2, 3]) + >>> p.prepend(10, 11, 12) + >>> next(p) + 10 + >>> p.peek() + 11 + >>> list(p) + [11, 12, 1, 2, 3] - >>> assert peekable(xrange(1)) - >>> assert not peekable([]) + peekables can be indexed. 
Index 0 is the item that will be returned by + :func:`next`, index 1 is the item after that, and so on: + The values up to the given index will be cached. + + >>> p = peekable(['a', 'b', 'c', 'd']) + >>> p[0] + 'a' + >>> p[1] + 'b' + >>> next(p) + 'a' + + Negative indexes are supported, but be aware that they will cache the + remaining items in the source iterator, which may require significant + storage. + + To check whether a peekable is exhausted, check its truth value: + + >>> p = peekable(['a', 'b']) + >>> if p: # peekable has items + ... list(p) + ['a', 'b'] + >>> if not p: # peekable is exhaused + ... list(p) + [] """ - # Lowercase to blend in with itertools. The fact that it's a class is an - # implementation detail. - def __init__(self, iterable): self._it = iter(iterable) + self._cache = deque() def __iter__(self): return self - def __nonzero__(self): + def __bool__(self): try: self.peek() except StopIteration: return False return True + def __nonzero__(self): + # For Python 2 compatibility + return self.__bool__() + def peek(self, default=_marker): """Return the item that will be next returned from ``next()``. @@ -120,52 +249,155 @@ class peekable(object): provided, raise ``StopIteration``. """ - if not hasattr(self, '_peek'): + if not self._cache: try: - self._peek = self._it.next() + self._cache.append(next(self._it)) except StopIteration: if default is _marker: raise return default - return self._peek + return self._cache[0] - def next(self): - ret = self.peek() - del self._peek - return ret + def prepend(self, *items): + """Stack up items to be the next ones returned from ``next()`` or + ``self.peek()``. The items will be returned in + first in, first out order:: + + >>> p = peekable([1, 2, 3]) + >>> p.prepend(10, 11, 12) + >>> next(p) + 10 + >>> list(p) + [11, 12, 1, 2, 3] + + It is possible, by prepending items, to "resurrect" a peekable that + previously raised ``StopIteration``. 
+ + >>> p = peekable([]) + >>> next(p) + Traceback (most recent call last): + ... + StopIteration + >>> p.prepend(1) + >>> next(p) + 1 + >>> next(p) + Traceback (most recent call last): + ... + StopIteration + + """ + self._cache.extendleft(reversed(items)) + + def __next__(self): + if self._cache: + return self._cache.popleft() + + return next(self._it) + + next = __next__ # For Python 2 compatibility + + def _get_slice(self, index): + # Normalize the slice's arguments + step = 1 if (index.step is None) else index.step + if step > 0: + start = 0 if (index.start is None) else index.start + stop = maxsize if (index.stop is None) else index.stop + elif step < 0: + start = -1 if (index.start is None) else index.start + stop = (-maxsize - 1) if (index.stop is None) else index.stop + else: + raise ValueError('slice step cannot be zero') + + # If either the start or stop index is negative, we'll need to cache + # the rest of the iterable in order to slice from the right side. + if (start < 0) or (stop < 0): + self._cache.extend(self._it) + # Otherwise we'll need to find the rightmost index and cache to that + # point. + else: + n = min(max(start, stop) + 1, maxsize) + cache_len = len(self._cache) + if n >= cache_len: + self._cache.extend(islice(self._it, n - cache_len)) + + return list(self._cache)[index] + + def __getitem__(self, index): + if isinstance(index, slice): + return self._get_slice(index) + + cache_len = len(self._cache) + if index < 0: + self._cache.extend(self._it) + elif index >= cache_len: + self._cache.extend(islice(self._it, index + 1 - cache_len)) + + return self._cache[index] -def collate(*iterables, **kwargs): - """Return a sorted merge of the items from each of several already-sorted - ``iterables``. - - >>> list(collate('ACDZ', 'AZ', 'JKL')) - ['A', 'A', 'C', 'D', 'J', 'K', 'L', 'Z', 'Z'] - - Works lazily, keeping only the next value from each iterable in memory. 
Use - ``collate()`` to, for example, perform a n-way mergesort of items that - don't fit in memory. - - :arg key: A function that returns a comparison value for an item. Defaults - to the identity function. - :arg reverse: If ``reverse=True``, yield results in descending order - rather than ascending. ``iterables`` must also yield their elements in - descending order. - - If the elements of the passed-in iterables are out of order, you might get - unexpected results. +def _collate(*iterables, **kwargs): + """Helper for ``collate()``, called when the user is using the ``reverse`` + or ``key`` keyword arguments on Python versions below 3.5. """ key = kwargs.pop('key', lambda a: a) reverse = kwargs.pop('reverse', False) - min_or_max = partial(max if reverse else min, key=lambda (a, b): a) + min_or_max = partial(max if reverse else min, key=itemgetter(0)) peekables = [peekable(it) for it in iterables] peekables = [p for p in peekables if p] # Kill empties. while peekables: _, p = min_or_max((key(p.peek()), p) for p in peekables) - yield p.next() - peekables = [p for p in peekables if p] + yield next(p) + peekables = [x for x in peekables if x] + + +def collate(*iterables, **kwargs): + """Return a sorted merge of the items from each of several already-sorted + *iterables*. + + >>> list(collate('ACDZ', 'AZ', 'JKL')) + ['A', 'A', 'C', 'D', 'J', 'K', 'L', 'Z', 'Z'] + + Works lazily, keeping only the next value from each iterable in memory. Use + :func:`collate` to, for example, perform a n-way mergesort of items that + don't fit in memory. 
+ + If a *key* function is specified, the iterables will be sorted according + to its result: + + >>> key = lambda s: int(s) # Sort by numeric value, not by string + >>> list(collate(['1', '10'], ['2', '11'], key=key)) + ['1', '2', '10', '11'] + + + If the *iterables* are sorted in descending order, set *reverse* to + ``True``: + + >>> list(collate([5, 3, 1], [4, 2, 0], reverse=True)) + [5, 4, 3, 2, 1, 0] + + If the elements of the passed-in iterables are out of order, you might get + unexpected results. + + On Python 2.7, this function delegates to :func:`heapq.merge` if neither + of the keyword arguments are specified. On Python 3.5+, this function + is an alias for :func:`heapq.merge`. + + """ + if not kwargs: + return merge(*iterables) + + return _collate(*iterables, **kwargs) + + +# If using Python version 3.5 or greater, heapq.merge() will be faster than +# collate - use that instead. +if version_info >= (3, 5, 0): + _collate_docstring = collate.__doc__ + collate = partial(merge) + collate.__doc__ = _collate_docstring def consumer(func): @@ -173,50 +405,54 @@ def consumer(func): to its first yield point so you don't have to call ``next()`` on it manually. - >>> @consumer - ... def tally(): - ... i = 0 - ... while True: - ... print 'Thing number %s is %s.' % (i, (yield)) - ... i += 1 - ... - >>> t = tally() - >>> t.send('red') - Thing number 0 is red. - >>> t.send('fish') - Thing number 1 is fish. + >>> @consumer + ... def tally(): + ... i = 0 + ... while True: + ... print('Thing number %s is %s.' % (i, (yield))) + ... i += 1 + ... + >>> t = tally() + >>> t.send('red') + Thing number 0 is red. + >>> t.send('fish') + Thing number 1 is fish. - Without the decorator, you would have to call ``t.next()`` before + Without the decorator, you would have to call ``next(t)`` before ``t.send()`` could be used. 
""" @wraps(func) def wrapper(*args, **kwargs): gen = func(*args, **kwargs) - gen.next() + next(gen) return gen return wrapper def ilen(iterable): - """Return the number of items in ``iterable``. + """Return the number of items in *iterable*. - >>> from itertools import ifilter - >>> ilen(ifilter(lambda x: x % 3 == 0, xrange(1000000))) - 333334 + >>> ilen(x for x in range(1000000) if x % 3 == 0) + 333334 - This does, of course, consume the iterable, so handle it with care. + This consumes the iterable, so handle with care. """ - return sum(1 for _ in iterable) + # maxlen=1 only stores the last item in the deque + d = deque(enumerate(iterable, 1), maxlen=1) + # since we started enumerate at 1, + # the first item of the last pair will be the length of the iterable + # (assuming there were items) + return d[0][0] if d else 0 def iterate(func, start): """Return ``start``, ``func(start)``, ``func(func(start))``, ... - >>> from itertools import islice - >>> list(islice(iterate(lambda x: 2*x, 1), 10)) - [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + >>> from itertools import islice + >>> list(islice(iterate(lambda x: 2*x, 1), 10)) + [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] """ while True: @@ -227,11 +463,1749 @@ def iterate(func, start): def with_iter(context_manager): """Wrap an iterable in a ``with`` statement, so it closes once exhausted. - Example:: + For example, this will close the file when the iterator is exhausted:: upper_lines = (line.upper() for line in with_iter(open('foo'))) + Any context manager which returns an iterable is a candidate for + ``with_iter``. + """ with context_manager as iterable: for item in iterable: yield item + + +def one(iterable, too_short=None, too_long=None): + """Return the first item from *iterable*, which is expected to contain only + that item. Raise an exception if *iterable* is empty or has more than one + item. + + :func:`one` is useful for ensuring that an iterable contains only one item. 
+ + For example, it can be used to retrieve the result of a database query + that is expected to return a single row. + + If *iterable* is empty, ``ValueError`` will be raised. You may specify a + different exception with the *too_short* keyword: + + >>> it = [] + >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: too few items in iterable (expected 1) + >>> too_short = IndexError('too few items') + >>> one(it, too_short=too_short) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + IndexError: too few items + + Similarly, if *iterable* contains more than one item, ``ValueError`` will + be raised. You may specify a different exception with the *too_long* + keyword: + + >>> it = ['too', 'many'] + >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: too many items in iterable (expected 1) + >>> too_long = RuntimeError + >>> one(it, too_long=too_long) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + RuntimeError + + Note that :func:`one` attempts to advance *iterable* twice to ensure there + is only one item. If there is more than one, both items will be discarded. + See :func:`spy` or :func:`peekable` to check iterable contents less + destructively. + + """ + it = iter(iterable) + + try: + value = next(it) + except StopIteration: + raise too_short or ValueError('too few items in iterable (expected 1)') + + try: + next(it) + except StopIteration: + pass + else: + raise too_long or ValueError('too many items in iterable (expected 1)') + + return value + + +def distinct_permutations(iterable): + """Yield successive distinct permutations of the elements in *iterable*. + + >>> sorted(distinct_permutations([1, 0, 1])) + [(0, 1, 1), (1, 0, 1), (1, 1, 0)] + + Equivalent to ``set(permutations(iterable))``, except duplicates are not + generated and thrown away. 
For larger input sequences this is much more + efficient. + + Duplicate permutations arise when there are duplicated elements in the + input iterable. The number of items returned is + `n! / (x_1! * x_2! * ... * x_n!)`, where `n` is the total number of + items input, and each `x_i` is the count of a distinct item in the input + sequence. + + """ + def perm_unique_helper(item_counts, perm, i): + """Internal helper function + + :arg item_counts: Stores the unique items in ``iterable`` and how many + times they are repeated + :arg perm: The permutation that is being built for output + :arg i: The index of the permutation being modified + + The output permutations are built up recursively; the distinct items + are placed until their repetitions are exhausted. + """ + if i < 0: + yield tuple(perm) + else: + for item in item_counts: + if item_counts[item] <= 0: + continue + perm[i] = item + item_counts[item] -= 1 + for x in perm_unique_helper(item_counts, perm, i - 1): + yield x + item_counts[item] += 1 + + item_counts = Counter(iterable) + length = sum(item_counts.values()) + + return perm_unique_helper(item_counts, [None] * length, length - 1) + + +def intersperse(e, iterable, n=1): + """Intersperse filler element *e* among the items in *iterable*, leaving + *n* items between each filler element. + + >>> list(intersperse('!', [1, 2, 3, 4, 5])) + [1, '!', 2, '!', 3, '!', 4, '!', 5] + + >>> list(intersperse(None, [1, 2, 3, 4, 5], n=2)) + [1, 2, None, 3, 4, None, 5] + + """ + if n == 0: + raise ValueError('n must be > 0') + elif n == 1: + # interleave(repeat(e), iterable) -> e, x_0, e, e, x_1, e, x_2... + # islice(..., 1, None) -> x_0, e, e, x_1, e, x_2... + return islice(interleave(repeat(e), iterable), 1, None) + else: + # interleave(filler, chunks) -> [e], [x_0, x_1], [e], [x_2, x_3]... + # islice(..., 1, None) -> [x_0, x_1], [e], [x_2, x_3]... + # flatten(...) -> x_0, x_1, e, x_2, x_3... 
+ filler = repeat([e]) + chunks = chunked(iterable, n) + return flatten(islice(interleave(filler, chunks), 1, None)) + + +def unique_to_each(*iterables): + """Return the elements from each of the input iterables that aren't in the + other input iterables. + + For example, suppose you have a set of packages, each with a set of + dependencies:: + + {'pkg_1': {'A', 'B'}, 'pkg_2': {'B', 'C'}, 'pkg_3': {'B', 'D'}} + + If you remove one package, which dependencies can also be removed? + + If ``pkg_1`` is removed, then ``A`` is no longer necessary - it is not + associated with ``pkg_2`` or ``pkg_3``. Similarly, ``C`` is only needed for + ``pkg_2``, and ``D`` is only needed for ``pkg_3``:: + + >>> unique_to_each({'A', 'B'}, {'B', 'C'}, {'B', 'D'}) + [['A'], ['C'], ['D']] + + If there are duplicates in one input iterable that aren't in the others + they will be duplicated in the output. Input order is preserved:: + + >>> unique_to_each("mississippi", "missouri") + [['p', 'p'], ['o', 'u', 'r']] + + It is assumed that the elements of each iterable are hashable. + + """ + pool = [list(it) for it in iterables] + counts = Counter(chain.from_iterable(map(set, pool))) + uniques = {element for element in counts if counts[element] == 1} + return [list(filter(uniques.__contains__, it)) for it in pool] + + +def windowed(seq, n, fillvalue=None, step=1): + """Return a sliding window of width *n* over the given iterable. 
+ + >>> all_windows = windowed([1, 2, 3, 4, 5], 3) + >>> list(all_windows) + [(1, 2, 3), (2, 3, 4), (3, 4, 5)] + + When the window is larger than the iterable, *fillvalue* is used in place + of missing values:: + + >>> list(windowed([1, 2, 3], 4)) + [(1, 2, 3, None)] + + Each window will advance in increments of *step*: + + >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2)) + [(1, 2, 3), (3, 4, 5), (5, 6, '!')] + + """ + if n < 0: + raise ValueError('n must be >= 0') + if n == 0: + yield tuple() + return + if step < 1: + raise ValueError('step must be >= 1') + + it = iter(seq) + window = deque([], n) + append = window.append + + # Initial deque fill + for _ in range(n): + append(next(it, fillvalue)) + yield tuple(window) + + # Appending new items to the right causes old items to fall off the left + i = 0 + for item in it: + append(item) + i = (i + 1) % step + if i % step == 0: + yield tuple(window) + + # If there are items from the iterable in the window, pad with the given + # value and emit them. + if (i % step) and (step - i < n): + for _ in range(step - i): + append(fillvalue) + yield tuple(window) + + +class bucket(object): + """Wrap *iterable* and return an object that buckets it iterable into + child iterables based on a *key* function. + + >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3'] + >>> s = bucket(iterable, key=lambda x: x[0]) + >>> a_iterable = s['a'] + >>> next(a_iterable) + 'a1' + >>> next(a_iterable) + 'a2' + >>> list(s['b']) + ['b1', 'b2', 'b3'] + + The original iterable will be advanced and its items will be cached until + they are used by the child iterables. This may require significant storage. + + By default, attempting to select a bucket to which no items belong will + exhaust the iterable and cache all values. + If you specify a *validator* function, selected buckets will instead be + checked against it. 
+ + >>> from itertools import count + >>> it = count(1, 2) # Infinite sequence of odd numbers + >>> key = lambda x: x % 10 # Bucket by last digit + >>> validator = lambda x: x in {1, 3, 5, 7, 9} # Odd digits only + >>> s = bucket(it, key=key, validator=validator) + >>> 2 in s + False + >>> list(s[2]) + [] + + """ + def __init__(self, iterable, key, validator=None): + self._it = iter(iterable) + self._key = key + self._cache = defaultdict(deque) + self._validator = validator or (lambda x: True) + + def __contains__(self, value): + if not self._validator(value): + return False + + try: + item = next(self[value]) + except StopIteration: + return False + else: + self._cache[value].appendleft(item) + + return True + + def _get_values(self, value): + """ + Helper to yield items from the parent iterator that match *value*. + Items that don't match are stored in the local cache as they + are encountered. + """ + while True: + # If we've cached some items that match the target value, emit + # the first one and evict it from the cache. + if self._cache[value]: + yield self._cache[value].popleft() + # Otherwise we need to advance the parent iterator to search for + # a matching item, caching the rest. + else: + while True: + try: + item = next(self._it) + except StopIteration: + return + item_value = self._key(item) + if item_value == value: + yield item + break + elif self._validator(item_value): + self._cache[item_value].append(item) + + def __getitem__(self, value): + if not self._validator(value): + return iter(()) + + return self._get_values(value) + + +def spy(iterable, n=1): + """Return a 2-tuple with a list containing the first *n* elements of + *iterable*, and an iterator with the same items as *iterable*. + This allows you to "look ahead" at the items in the iterable without + advancing it. 
+ + There is one item in the list by default: + + >>> iterable = 'abcdefg' + >>> head, iterable = spy(iterable) + >>> head + ['a'] + >>> list(iterable) + ['a', 'b', 'c', 'd', 'e', 'f', 'g'] + + You may use unpacking to retrieve items instead of lists: + + >>> (head,), iterable = spy('abcdefg') + >>> head + 'a' + >>> (first, second), iterable = spy('abcdefg', 2) + >>> first + 'a' + >>> second + 'b' + + The number of items requested can be larger than the number of items in + the iterable: + + >>> iterable = [1, 2, 3, 4, 5] + >>> head, iterable = spy(iterable, 10) + >>> head + [1, 2, 3, 4, 5] + >>> list(iterable) + [1, 2, 3, 4, 5] + + """ + it = iter(iterable) + head = take(n, it) + + return head, chain(head, it) + + +def interleave(*iterables): + """Return a new iterable yielding from each iterable in turn, + until the shortest is exhausted. + + >>> list(interleave([1, 2, 3], [4, 5], [6, 7, 8])) + [1, 4, 6, 2, 5, 7] + + For a version that doesn't terminate after the shortest iterable is + exhausted, see :func:`interleave_longest`. + + """ + return chain.from_iterable(zip(*iterables)) + + +def interleave_longest(*iterables): + """Return a new iterable yielding from each iterable in turn, + skipping any that are exhausted. + + >>> list(interleave_longest([1, 2, 3], [4, 5], [6, 7, 8])) + [1, 4, 6, 2, 5, 7, 3, 8] + + This function produces the same output as :func:`roundrobin`, but may + perform better for some inputs (in particular when the number of iterables + is large). + + """ + i = chain.from_iterable(zip_longest(*iterables, fillvalue=_marker)) + return (x for x in i if x is not _marker) + + +def collapse(iterable, base_type=None, levels=None): + """Flatten an iterable with multiple levels of nesting (e.g., a list of + lists of tuples) into non-iterable types. + + >>> iterable = [(1, 2), ([3, 4], [[5], [6]])] + >>> list(collapse(iterable)) + [1, 2, 3, 4, 5, 6] + + String types are not considered iterable and will not be collapsed. 
+ To avoid collapsing other types, specify *base_type*: + + >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']] + >>> list(collapse(iterable, base_type=tuple)) + ['ab', ('cd', 'ef'), 'gh', 'ij'] + + Specify *levels* to stop flattening after a certain level: + + >>> iterable = [('a', ['b']), ('c', ['d'])] + >>> list(collapse(iterable)) # Fully flattened + ['a', 'b', 'c', 'd'] + >>> list(collapse(iterable, levels=1)) # Only one level flattened + ['a', ['b'], 'c', ['d']] + + """ + def walk(node, level): + if ( + ((levels is not None) and (level > levels)) or + isinstance(node, string_types) or + ((base_type is not None) and isinstance(node, base_type)) + ): + yield node + return + + try: + tree = iter(node) + except TypeError: + yield node + return + else: + for child in tree: + for x in walk(child, level + 1): + yield x + + for x in walk(iterable, 0): + yield x + + +def side_effect(func, iterable, chunk_size=None, before=None, after=None): + """Invoke *func* on each item in *iterable* (or on each *chunk_size* group + of items) before yielding the item. + + `func` must be a function that takes a single argument. Its return value + will be discarded. + + *before* and *after* are optional functions that take no arguments. They + will be executed before iteration starts and after it ends, respectively. + + `side_effect` can be used for logging, updating progress bars, or anything + that is not functionally "pure." 
+ + Emitting a status message: + + >>> from more_itertools import consume + >>> func = lambda item: print('Received {}'.format(item)) + >>> consume(side_effect(func, range(2))) + Received 0 + Received 1 + + Operating on chunks of items: + + >>> pair_sums = [] + >>> func = lambda chunk: pair_sums.append(sum(chunk)) + >>> list(side_effect(func, [0, 1, 2, 3, 4, 5], 2)) + [0, 1, 2, 3, 4, 5] + >>> list(pair_sums) + [1, 5, 9] + + Writing to a file-like object: + + >>> from io import StringIO + >>> from more_itertools import consume + >>> f = StringIO() + >>> func = lambda x: print(x, file=f) + >>> before = lambda: print(u'HEADER', file=f) + >>> after = f.close + >>> it = [u'a', u'b', u'c'] + >>> consume(side_effect(func, it, before=before, after=after)) + >>> f.closed + True + + """ + try: + if before is not None: + before() + + if chunk_size is None: + for item in iterable: + func(item) + yield item + else: + for chunk in chunked(iterable, chunk_size): + func(chunk) + for item in chunk: + yield item + finally: + if after is not None: + after() + + +def sliced(seq, n): + """Yield slices of length *n* from the sequence *seq*. + + >>> list(sliced((1, 2, 3, 4, 5, 6), 3)) + [(1, 2, 3), (4, 5, 6)] + + If the length of the sequence is not divisible by the requested slice + length, the last slice will be shorter. + + >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3)) + [(1, 2, 3), (4, 5, 6), (7, 8)] + + This function will only work for iterables that support slicing. + For non-sliceable iterables, see :func:`chunked`. + + """ + return takewhile(bool, (seq[i: i + n] for i in count(0, n))) + + +def split_at(iterable, pred): + """Yield lists of items from *iterable*, where each list is delimited by + an item where callable *pred* returns ``True``. The lists do not include + the delimiting items. 
+ + >>> list(split_at('abcdcba', lambda x: x == 'b')) + [['a'], ['c', 'd', 'c'], ['a']] + + >>> list(split_at(range(10), lambda n: n % 2 == 1)) + [[0], [2], [4], [6], [8], []] + """ + buf = [] + for item in iterable: + if pred(item): + yield buf + buf = [] + else: + buf.append(item) + yield buf + + +def split_before(iterable, pred): + """Yield lists of items from *iterable*, where each list starts with an + item where callable *pred* returns ``True``: + + >>> list(split_before('OneTwo', lambda s: s.isupper())) + [['O', 'n', 'e'], ['T', 'w', 'o']] + + >>> list(split_before(range(10), lambda n: n % 3 == 0)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + + """ + buf = [] + for item in iterable: + if pred(item) and buf: + yield buf + buf = [] + buf.append(item) + yield buf + + +def split_after(iterable, pred): + """Yield lists of items from *iterable*, where each list ends with an + item where callable *pred* returns ``True``: + + >>> list(split_after('one1two2', lambda s: s.isdigit())) + [['o', 'n', 'e', '1'], ['t', 'w', 'o', '2']] + + >>> list(split_after(range(10), lambda n: n % 3 == 0)) + [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]] + + """ + buf = [] + for item in iterable: + buf.append(item) + if pred(item) and buf: + yield buf + buf = [] + if buf: + yield buf + + +def padded(iterable, fillvalue=None, n=None, next_multiple=False): + """Yield the elements from *iterable*, followed by *fillvalue*, such that + at least *n* items are emitted. + + >>> list(padded([1, 2, 3], '?', 5)) + [1, 2, 3, '?', '?'] + + If *next_multiple* is ``True``, *fillvalue* will be emitted until the + number of items emitted is a multiple of *n*:: + + >>> list(padded([1, 2, 3, 4], n=3, next_multiple=True)) + [1, 2, 3, 4, None, None] + + If *n* is ``None``, *fillvalue* will be emitted indefinitely. 
+ + """ + it = iter(iterable) + if n is None: + for item in chain(it, repeat(fillvalue)): + yield item + elif n < 1: + raise ValueError('n must be at least 1') + else: + item_count = 0 + for item in it: + yield item + item_count += 1 + + remaining = (n - item_count) % n if next_multiple else n - item_count + for _ in range(remaining): + yield fillvalue + + +def distribute(n, iterable): + """Distribute the items from *iterable* among *n* smaller iterables. + + >>> group_1, group_2 = distribute(2, [1, 2, 3, 4, 5, 6]) + >>> list(group_1) + [1, 3, 5] + >>> list(group_2) + [2, 4, 6] + + If the length of *iterable* is not evenly divisible by *n*, then the + length of the returned iterables will not be identical: + + >>> children = distribute(3, [1, 2, 3, 4, 5, 6, 7]) + >>> [list(c) for c in children] + [[1, 4, 7], [2, 5], [3, 6]] + + If the length of *iterable* is smaller than *n*, then the last returned + iterables will be empty: + + >>> children = distribute(5, [1, 2, 3]) + >>> [list(c) for c in children] + [[1], [2], [3], [], []] + + This function uses :func:`itertools.tee` and may require significant + storage. If you need the order items in the smaller iterables to match the + original iterable, see :func:`divide`. + + """ + if n < 1: + raise ValueError('n must be at least 1') + + children = tee(iterable, n) + return [islice(it, index, None, n) for index, it in enumerate(children)] + + +def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None): + """Yield tuples whose elements are offset from *iterable*. + The amount by which the `i`-th item in each tuple is offset is given by + the `i`-th item in *offsets*. + + >>> list(stagger([0, 1, 2, 3])) + [(None, 0, 1), (0, 1, 2), (1, 2, 3)] + >>> list(stagger(range(8), offsets=(0, 2, 4))) + [(0, 2, 4), (1, 3, 5), (2, 4, 6), (3, 5, 7)] + + By default, the sequence will end when the final element of a tuple is the + last item in the iterable. 
To continue until the first element of a tuple + is the last item in the iterable, set *longest* to ``True``:: + + >>> list(stagger([0, 1, 2, 3], longest=True)) + [(None, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, None), (3, None, None)] + + By default, ``None`` will be used to replace offsets beyond the end of the + sequence. Specify *fillvalue* to use some other value. + + """ + children = tee(iterable, len(offsets)) + + return zip_offset( + *children, offsets=offsets, longest=longest, fillvalue=fillvalue + ) + + +def zip_offset(*iterables, **kwargs): + """``zip`` the input *iterables* together, but offset the `i`-th iterable + by the `i`-th item in *offsets*. + + >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1))) + [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')] + + This can be used as a lightweight alternative to SciPy or pandas to analyze + data sets in which somes series have a lead or lag relationship. + + By default, the sequence will end when the shortest iterable is exhausted. + To continue until the longest iterable is exhausted, set *longest* to + ``True``. + + >>> list(zip_offset('0123', 'abcdef', offsets=(0, 1), longest=True)) + [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e'), (None, 'f')] + + By default, ``None`` will be used to replace offsets beyond the end of the + sequence. Specify *fillvalue* to use some other value. 
+ + """ + offsets = kwargs['offsets'] + longest = kwargs.get('longest', False) + fillvalue = kwargs.get('fillvalue', None) + + if len(iterables) != len(offsets): + raise ValueError("Number of iterables and offsets didn't match") + + staggered = [] + for it, n in zip(iterables, offsets): + if n < 0: + staggered.append(chain(repeat(fillvalue, -n), it)) + elif n > 0: + staggered.append(islice(it, n, None)) + else: + staggered.append(it) + + if longest: + return zip_longest(*staggered, fillvalue=fillvalue) + + return zip(*staggered) + + +def sort_together(iterables, key_list=(0,), reverse=False): + """Return the input iterables sorted together, with *key_list* as the + priority for sorting. All iterables are trimmed to the length of the + shortest one. + + This can be used like the sorting function in a spreadsheet. If each + iterable represents a column of data, the key list determines which + columns are used for sorting. + + By default, all iterables are sorted using the ``0``-th iterable:: + + >>> iterables = [(4, 3, 2, 1), ('a', 'b', 'c', 'd')] + >>> sort_together(iterables) + [(1, 2, 3, 4), ('d', 'c', 'b', 'a')] + + Set a different key list to sort according to another iterable. + Specifying mutliple keys dictates how ties are broken:: + + >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')] + >>> sort_together(iterables, key_list=(1, 2)) + [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')] + + Set *reverse* to ``True`` to sort in descending order. + + >>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True) + [(3, 2, 1), ('a', 'b', 'c')] + + """ + return list(zip(*sorted(zip(*iterables), + key=itemgetter(*key_list), + reverse=reverse))) + + +def divide(n, iterable): + """Divide the elements from *iterable* into *n* parts, maintaining + order. 
+ + >>> group_1, group_2 = divide(2, [1, 2, 3, 4, 5, 6]) + >>> list(group_1) + [1, 2, 3] + >>> list(group_2) + [4, 5, 6] + + If the length of *iterable* is not evenly divisible by *n*, then the + length of the returned iterables will not be identical: + + >>> children = divide(3, [1, 2, 3, 4, 5, 6, 7]) + >>> [list(c) for c in children] + [[1, 2, 3], [4, 5], [6, 7]] + + If the length of the iterable is smaller than n, then the last returned + iterables will be empty: + + >>> children = divide(5, [1, 2, 3]) + >>> [list(c) for c in children] + [[1], [2], [3], [], []] + + This function will exhaust the iterable before returning and may require + significant storage. If order is not important, see :func:`distribute`, + which does not first pull the iterable into memory. + + """ + if n < 1: + raise ValueError('n must be at least 1') + + seq = tuple(iterable) + q, r = divmod(len(seq), n) + + ret = [] + for i in range(n): + start = (i * q) + (i if i < r else r) + stop = ((i + 1) * q) + (i + 1 if i + 1 < r else r) + ret.append(iter(seq[start:stop])) + + return ret + + +def always_iterable(obj, base_type=(text_type, binary_type)): + """If *obj* is iterable, return an iterator over its items:: + + >>> obj = (1, 2, 3) + >>> list(always_iterable(obj)) + [1, 2, 3] + + If *obj* is not iterable, return a one-item iterable containing *obj*:: + + >>> obj = 1 + >>> list(always_iterable(obj)) + [1] + + If *obj* is ``None``, return an empty iterable: + + >>> obj = None + >>> list(always_iterable(None)) + [] + + By default, binary and text strings are not considered iterable:: + + >>> obj = 'foo' + >>> list(always_iterable(obj)) + ['foo'] + + If *base_type* is set, objects for which ``isinstance(obj, base_type)`` + returns ``True`` won't be considered iterable. 
+ + >>> obj = {'a': 1} + >>> list(always_iterable(obj)) # Iterate over the dict's keys + ['a'] + >>> list(always_iterable(obj, base_type=dict)) # Treat dicts as a unit + [{'a': 1}] + + Set *base_type* to ``None`` to avoid any special handling and treat objects + Python considers iterable as iterable: + + >>> obj = 'foo' + >>> list(always_iterable(obj, base_type=None)) + ['f', 'o', 'o'] + """ + if obj is None: + return iter(()) + + if (base_type is not None) and isinstance(obj, base_type): + return iter((obj,)) + + try: + return iter(obj) + except TypeError: + return iter((obj,)) + + +def adjacent(predicate, iterable, distance=1): + """Return an iterable over `(bool, item)` tuples where the `item` is + drawn from *iterable* and the `bool` indicates whether + that item satisfies the *predicate* or is adjacent to an item that does. + + For example, to find whether items are adjacent to a ``3``:: + + >>> list(adjacent(lambda x: x == 3, range(6))) + [(False, 0), (False, 1), (True, 2), (True, 3), (True, 4), (False, 5)] + + Set *distance* to change what counts as adjacent. For example, to find + whether items are two places away from a ``3``: + + >>> list(adjacent(lambda x: x == 3, range(6), distance=2)) + [(False, 0), (True, 1), (True, 2), (True, 3), (True, 4), (True, 5)] + + This is useful for contextualizing the results of a search function. + For example, a code comparison tool might want to identify lines that + have changed, but also surrounding lines to give the viewer of the diff + context. + + The predicate function will only be called once for each item in the + iterable. + + See also :func:`groupby_transform`, which can be used with this function + to group ranges of items with the same `bool` value. 
+ + """ + # Allow distance=0 mainly for testing that it reproduces results with map() + if distance < 0: + raise ValueError('distance must be at least 0') + + i1, i2 = tee(iterable) + padding = [False] * distance + selected = chain(padding, map(predicate, i1), padding) + adjacent_to_selected = map(any, windowed(selected, 2 * distance + 1)) + return zip(adjacent_to_selected, i2) + + +def groupby_transform(iterable, keyfunc=None, valuefunc=None): + """An extension of :func:`itertools.groupby` that transforms the values of + *iterable* after grouping them. + *keyfunc* is a function used to compute a grouping key for each item. + *valuefunc* is a function for transforming the items after grouping. + + >>> iterable = 'AaaABbBCcA' + >>> keyfunc = lambda x: x.upper() + >>> valuefunc = lambda x: x.lower() + >>> grouper = groupby_transform(iterable, keyfunc, valuefunc) + >>> [(k, ''.join(g)) for k, g in grouper] + [('A', 'aaaa'), ('B', 'bbb'), ('C', 'cc'), ('A', 'a')] + + *keyfunc* and *valuefunc* default to identity functions if they are not + specified. + + :func:`groupby_transform` is useful when grouping elements of an iterable + using a separate iterable as the key. To do this, :func:`zip` the iterables + and pass a *keyfunc* that extracts the first element and a *valuefunc* + that extracts the second element:: + + >>> from operator import itemgetter + >>> keys = [0, 0, 1, 1, 1, 2, 2, 2, 3] + >>> values = 'abcdefghi' + >>> iterable = zip(keys, values) + >>> grouper = groupby_transform(iterable, itemgetter(0), itemgetter(1)) + >>> [(k, ''.join(g)) for k, g in grouper] + [(0, 'ab'), (1, 'cde'), (2, 'fgh'), (3, 'i')] + + Note that the order of items in the iterable is significant. + Only adjacent items are grouped together, so if you don't want any + duplicate groups, you should sort the iterable by the key function. 
+ + """ + valuefunc = (lambda x: x) if valuefunc is None else valuefunc + return ((k, map(valuefunc, g)) for k, g in groupby(iterable, keyfunc)) + + +def numeric_range(*args): + """An extension of the built-in ``range()`` function whose arguments can + be any orderable numeric type. + + With only *stop* specified, *start* defaults to ``0`` and *step* + defaults to ``1``. The output items will match the type of *stop*: + + >>> list(numeric_range(3.5)) + [0.0, 1.0, 2.0, 3.0] + + With only *start* and *stop* specified, *step* defaults to ``1``. The + output items will match the type of *start*: + + >>> from decimal import Decimal + >>> start = Decimal('2.1') + >>> stop = Decimal('5.1') + >>> list(numeric_range(start, stop)) + [Decimal('2.1'), Decimal('3.1'), Decimal('4.1')] + + With *start*, *stop*, and *step* specified the output items will match + the type of ``start + step``: + + >>> from fractions import Fraction + >>> start = Fraction(1, 2) # Start at 1/2 + >>> stop = Fraction(5, 2) # End at 5/2 + >>> step = Fraction(1, 2) # Count by 1/2 + >>> list(numeric_range(start, stop, step)) + [Fraction(1, 2), Fraction(1, 1), Fraction(3, 2), Fraction(2, 1)] + + If *step* is zero, ``ValueError`` is raised. Negative steps are supported: + + >>> list(numeric_range(3, -1, -1.0)) + [3.0, 2.0, 1.0, 0.0] + + Be aware of the limitations of floating point numbers; the representation + of the yielded numbers may be surprising. 
+ + """ + argc = len(args) + if argc == 1: + stop, = args + start = type(stop)(0) + step = 1 + elif argc == 2: + start, stop = args + step = 1 + elif argc == 3: + start, stop, step = args + else: + err_msg = 'numeric_range takes at most 3 arguments, got {}' + raise TypeError(err_msg.format(argc)) + + values = (start + (step * n) for n in count()) + if step > 0: + return takewhile(partial(gt, stop), values) + elif step < 0: + return takewhile(partial(lt, stop), values) + else: + raise ValueError('numeric_range arg 3 must not be zero') + + +def count_cycle(iterable, n=None): + """Cycle through the items from *iterable* up to *n* times, yielding + the number of completed cycles along with each item. If *n* is omitted the + process repeats indefinitely. + + >>> list(count_cycle('AB', 3)) + [(0, 'A'), (0, 'B'), (1, 'A'), (1, 'B'), (2, 'A'), (2, 'B')] + + """ + iterable = tuple(iterable) + if not iterable: + return iter(()) + counter = count() if n is None else range(n) + return ((i, item) for i in counter for item in iterable) + + +def locate(iterable, pred=bool, window_size=None): + """Yield the index of each item in *iterable* for which *pred* returns + ``True``. + + *pred* defaults to :func:`bool`, which will select truthy items: + + >>> list(locate([0, 1, 1, 0, 1, 0, 0])) + [1, 2, 4] + + Set *pred* to a custom function to, e.g., find the indexes for a particular + item. + + >>> list(locate(['a', 'b', 'c', 'b'], lambda x: x == 'b')) + [1, 3] + + If *window_size* is given, then the *pred* function will be called with + that many items. 
This enables searching for sub-sequences: + + >>> iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3] + >>> pred = lambda *args: args == (1, 2, 3) + >>> list(locate(iterable, pred=pred, window_size=3)) + [1, 5, 9] + + Use with :func:`seekable` to find indexes and then retrieve the associated + items: + + >>> from itertools import count + >>> from more_itertools import seekable + >>> source = (3 * n + 1 if (n % 2) else n // 2 for n in count()) + >>> it = seekable(source) + >>> pred = lambda x: x > 100 + >>> indexes = locate(it, pred=pred) + >>> i = next(indexes) + >>> it.seek(i) + >>> next(it) + 106 + + """ + if window_size is None: + return compress(count(), map(pred, iterable)) + + if window_size < 1: + raise ValueError('window size must be at least 1') + + it = windowed(iterable, window_size, fillvalue=_marker) + return compress(count(), starmap(pred, it)) + + +def lstrip(iterable, pred): + """Yield the items from *iterable*, but strip any from the beginning + for which *pred* returns ``True``. + + For example, to remove a set of items from the start of an iterable: + + >>> iterable = (None, False, None, 1, 2, None, 3, False, None) + >>> pred = lambda x: x in {None, False, ''} + >>> list(lstrip(iterable, pred)) + [1, 2, None, 3, False, None] + + This function is analogous to to :func:`str.lstrip`, and is essentially + an wrapper for :func:`itertools.dropwhile`. + + """ + return dropwhile(pred, iterable) + + +def rstrip(iterable, pred): + """Yield the items from *iterable*, but strip any from the end + for which *pred* returns ``True``. + + For example, to remove a set of items from the end of an iterable: + + >>> iterable = (None, False, None, 1, 2, None, 3, False, None) + >>> pred = lambda x: x in {None, False, ''} + >>> list(rstrip(iterable, pred)) + [None, False, None, 1, 2, None, 3] + + This function is analogous to :func:`str.rstrip`. 
+ + """ + cache = [] + cache_append = cache.append + for x in iterable: + if pred(x): + cache_append(x) + else: + for y in cache: + yield y + del cache[:] + yield x + + +def strip(iterable, pred): + """Yield the items from *iterable*, but strip any from the + beginning and end for which *pred* returns ``True``. + + For example, to remove a set of items from both ends of an iterable: + + >>> iterable = (None, False, None, 1, 2, None, 3, False, None) + >>> pred = lambda x: x in {None, False, ''} + >>> list(strip(iterable, pred)) + [1, 2, None, 3] + + This function is analogous to :func:`str.strip`. + + """ + return rstrip(lstrip(iterable, pred), pred) + + +def islice_extended(iterable, *args): + """An extension of :func:`itertools.islice` that supports negative values + for *stop*, *start*, and *step*. + + >>> iterable = iter('abcdefgh') + >>> list(islice_extended(iterable, -4, -1)) + ['e', 'f', 'g'] + + Slices with negative values require some caching of *iterable*, but this + function takes care to minimize the amount of memory required. 
+ + For example, you can use a negative step with an infinite iterator: + + >>> from itertools import count + >>> list(islice_extended(count(), 110, 99, -2)) + [110, 108, 106, 104, 102, 100] + + """ + s = slice(*args) + start = s.start + stop = s.stop + if s.step == 0: + raise ValueError('step argument must be a non-zero integer or None.') + step = s.step or 1 + + it = iter(iterable) + + if step > 0: + start = 0 if (start is None) else start + + if (start < 0): + # Consume all but the last -start items + cache = deque(enumerate(it, 1), maxlen=-start) + len_iter = cache[-1][0] if cache else 0 + + # Adjust start to be positive + i = max(len_iter + start, 0) + + # Adjust stop to be positive + if stop is None: + j = len_iter + elif stop >= 0: + j = min(stop, len_iter) + else: + j = max(len_iter + stop, 0) + + # Slice the cache + n = j - i + if n <= 0: + return + + for index, item in islice(cache, 0, n, step): + yield item + elif (stop is not None) and (stop < 0): + # Advance to the start position + next(islice(it, start, start), None) + + # When stop is negative, we have to carry -stop items while + # iterating + cache = deque(islice(it, -stop), maxlen=-stop) + + for index, item in enumerate(it): + cached_item = cache.popleft() + if index % step == 0: + yield cached_item + cache.append(item) + else: + # When both start and stop are positive we have the normal case + for item in islice(it, start, stop, step): + yield item + else: + start = -1 if (start is None) else start + + if (stop is not None) and (stop < 0): + # Consume all but the last items + n = -stop - 1 + cache = deque(enumerate(it, 1), maxlen=n) + len_iter = cache[-1][0] if cache else 0 + + # If start and stop are both negative they are comparable and + # we can just slice. Otherwise we can adjust start to be negative + # and then slice. 
+ if start < 0: + i, j = start, stop + else: + i, j = min(start - len_iter, -1), None + + for index, item in list(cache)[i:j:step]: + yield item + else: + # Advance to the stop position + if stop is not None: + m = stop + 1 + next(islice(it, m, m), None) + + # stop is positive, so if start is negative they are not comparable + # and we need the rest of the items. + if start < 0: + i = start + n = None + # stop is None and start is positive, so we just need items up to + # the start index. + elif stop is None: + i = None + n = start + 1 + # Both stop and start are positive, so they are comparable. + else: + i = None + n = start - stop + if n <= 0: + return + + cache = list(islice(it, n)) + + for item in cache[i::step]: + yield item + + +def always_reversible(iterable): + """An extension of :func:`reversed` that supports all iterables, not + just those which implement the ``Reversible`` or ``Sequence`` protocols. + + >>> print(*always_reversible(x for x in range(3))) + 2 1 0 + + If the iterable is already reversible, this function returns the + result of :func:`reversed()`. If the iterable is not reversible, + this function will cache the remaining items in the iterable and + yield them in reverse order, which may require significant storage. + """ + try: + return reversed(iterable) + except TypeError: + return reversed(list(iterable)) + + +def consecutive_groups(iterable, ordering=lambda x: x): + """Yield groups of consecutive items using :func:`itertools.groupby`. + The *ordering* function determines whether two items are adjacent by + returning their position. + + By default, the ordering function is the identity function. This is + suitable for finding runs of numbers: + + >>> iterable = [1, 10, 11, 12, 20, 30, 31, 32, 33, 40] + >>> for group in consecutive_groups(iterable): + ... 
def difference(iterable, func=sub):
    """Yield the first item of *iterable*, then ``func(B, A)`` for each pair
    of consecutive items ``A, B`` that follows. With the default
    :func:`operator.sub` this computes first differences — the inverse of
    :func:`accumulate`'s default behavior:

        >>> list(difference([0, 1, 3, 6, 10]))
        [0, 1, 2, 3, 4]

    Other binary functions may be supplied, e.g. progressive division:

        >>> list(difference([1, 2, 6, 24, 120], lambda x, y: x // y))
        [1, 2, 3, 4, 5]
    """
    left, right = tee(iterable)
    try:
        head = next(right)
    except StopIteration:
        # Empty input: nothing to difference.
        return iter([])
    diffs = map(lambda pair: func(pair[1], pair[0]), zip(left, right))
    return chain([head], diffs)


class SequenceView(Sequence):
    """A dynamic, read-only view of the sequence object *target* — analogous
    to the built-in dictionary view types. The view reflects later mutations
    of the underlying sequence:

        >>> seq = ['0', '1', '2']
        >>> view = SequenceView(seq)
        >>> seq.append('3')
        >>> view
        SequenceView(['0', '1', '2', '3'])

    Indexing, slicing, and length queries are supported, but not assignment;
    a useful alternative to copying since (almost) no extra storage is used.
    """
    def __init__(self, target):
        # Fail fast at construction rather than on first access.
        if not isinstance(target, Sequence):
            raise TypeError
        self._target = target

    def __getitem__(self, index):
        return self._target[index]

    def __len__(self):
        return len(self._target)

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, repr(self._target))


class seekable(object):
    """Wrap an iterator so it can be rewound and fast-forwarded. Items are
    cached as they are consumed, so they can be re-visited with
    :meth:`seek`:

        >>> it = seekable(str(n) for n in range(10))
        >>> next(it), next(it), next(it)
        ('0', '1', '2')
        >>> it.seek(0)
        >>> next(it)
        '0'

    Seeking past the end of the source is not an error, and resetting works
    even after exhaustion. The cache grows with the source iterable, so
    beware of wrapping very large or infinite iterables. :meth:`elements`
    exposes the cache as an auto-updating :class:`SequenceView`.
    """

    def __init__(self, iterable):
        self._source = iter(iterable)
        self._cache = []
        self._index = None  # None means "reading live from the source"

    def __iter__(self):
        return self

    def __next__(self):
        if self._index is not None:
            try:
                item = self._cache[self._index]
            except IndexError:
                # Replay exhausted the cache; resume the live source.
                self._index = None
            else:
                self._index += 1
                return item

        item = next(self._source)
        self._cache.append(item)
        return item

    next = __next__  # Python 2 compatibility

    def elements(self):
        return SequenceView(self._cache)

    def seek(self, index):
        self._index = index
        remainder = index - len(self._cache)
        if remainder > 0:
            # Seeking beyond the cache: pull the shortfall from the source.
            consume(self, remainder)
class run_length(object):
    """Namespace for run-length encoding helpers.

    :func:`run_length.encode` compresses an iterable into ``(item, count)``
    pairs:

        >>> list(run_length.encode('abbcccdddd'))
        [('a', 1), ('b', 2), ('c', 3), ('d', 4)]

    :func:`run_length.decode` reverses the transformation:

        >>> list(run_length.decode([('a', 1), ('b', 2)]))
        ['a', 'b', 'b']
    """

    @staticmethod
    def encode(iterable):
        # Each groupby() group is measured with ilen() to get the run count.
        return ((value, ilen(group)) for value, group in groupby(iterable))

    @staticmethod
    def decode(iterable):
        return chain.from_iterable(repeat(value, times) for value, times in iterable)


def exactly_n(iterable, n, predicate=bool):
    """Return ``True`` if exactly *n* items of *iterable* are true according
    to *predicate*:

        >>> exactly_n([True, True, False], 2)
        True
        >>> exactly_n([0, 1, 2, 3, 4, 5], 3, lambda x: x < 3)
        True

    Advances until ``n + 1`` truthy items are seen, so avoid calling it on
    infinite iterables.
    """
    # Taking n + 1 matches distinguishes "exactly n" from "more than n"
    # without consuming the whole iterable.
    return len(take(n + 1, filter(predicate, iterable))) == n


def circular_shifts(iterable):
    """Return a list of the circular shifts of *iterable*:

        >>> circular_shifts(range(4))
        [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)]
    """
    items = list(iterable)
    # A full-length window slid over the cycled sequence produces each
    # rotation exactly once.
    return take(len(items), windowed(cycle(items), len(items)))


def make_decorator(wrapping_func, result_index=0):
    """Return a decorator-factory version of *wrapping_func*, a function
    that modifies an iterable. *result_index* is the position in
    *wrapping_func*'s signature where the decorated function's return value
    is inserted.

    This lets you apply itertools-style transforms at function definition
    time without changing the function's code. For example, to only allow
    truthy items to be returned:

        >>> truth_serum = make_decorator(filter, result_index=1)
        >>> @truth_serum(bool)
        ... def boolean_test():
        ...     return [0, 1, '', ' ', False, True]
        >>> list(boolean_test())
        [1, ' ', True]

    The :func:`peekable` and :func:`seekable` wrappers also make practical
    decorators.
    """
    # See https://sites.google.com/site/bbayles/index/decorator_factory for
    # notes on how this works.
    def decorator(*wrapping_args, **wrapping_kwargs):
        def outer_wrapper(f):
            def inner_wrapper(*args, **kwargs):
                result = f(*args, **kwargs)
                call_args = list(wrapping_args)
                call_args.insert(result_index, result)
                return wrapping_func(*call_args, **wrapping_kwargs)

            return inner_wrapper

        return outer_wrapper

    return decorator
def map_reduce(iterable, keyfunc, valuefunc=None, reducefunc=None):
    """Group the items of *iterable* by *keyfunc*, transform each item with
    *valuefunc* (identity by default), and optionally collapse each group
    with *reducefunc*:

        >>> result = map_reduce('abbccc', lambda x: x.upper())
        >>> sorted(result.items())
        [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]

        >>> result = map_reduce('abbccc', lambda x: x.upper(), lambda x: 1, sum)
        >>> sorted(result.items())
        [('A', 1), ('B', 2), ('C', 3)]

    All items are gathered into lists before the summarization step, which
    may require significant storage.

    Returns a :obj:`collections.defaultdict` whose ``default_factory`` is
    set to ``None``, so it behaves like a normal dictionary afterwards.
    """
    transform = (lambda x: x) if (valuefunc is None) else valuefunc

    grouped = defaultdict(list)
    for item in iterable:
        grouped[keyfunc(item)].append(transform(item))

    if reducefunc is not None:
        # Replacing values for existing keys is safe during items().
        for key, values in grouped.items():
            grouped[key] = reducefunc(values)

    grouped.default_factory = None
    return grouped


def rlocate(iterable, pred=bool, window_size=None):
    """Yield the index of each item in *iterable* for which *pred* returns
    ``True``, starting from the right and moving left:

        >>> list(rlocate([0, 1, 1, 0, 1, 0, 0]))
        [4, 2, 1]

    When *window_size* is given, *pred* receives that many consecutive
    items, enabling sub-sequence searches (see :func:`locate`).

    If *iterable* is reversible it is reversed and searched from the right;
    otherwise it is searched from the left and the results are reversed, so
    nothing is returned for infinite iterables.
    """
    if window_size is None:
        try:
            # Sized + reversible input: search the reversal lazily and map
            # the indexes back. reversed() may raise TypeError here too,
            # in which case we fall through to the general path.
            len_iter = len(iterable)
            return (
                len_iter - i - 1 for i in locate(reversed(iterable), pred)
            )
        except TypeError:
            pass

    return reversed(list(locate(iterable, pred, window_size)))


def replace(iterable, pred, substitutes, count=None, window_size=1):
    """Yield the items from *iterable*, replacing those for which *pred*
    returns ``True`` with the items from *substitutes*:

        >>> list(replace([1, 1, 0, 1, 1, 0], lambda x: x == 0, (2, 3)))
        [1, 1, 2, 3, 1, 1, 2, 3]

    *count* limits the number of replacements. *window_size* controls how
    many consecutive items are passed to *pred*, allowing whole
    sub-sequences to be located and replaced.
    """
    if window_size < 1:
        raise ValueError('window_size must be at least 1')

    # substitutes is iterated once per replacement, so materialize it.
    substitutes = tuple(substitutes)

    # Pad so the number of windows matches the length of the iterable.
    padded = chain(iterable, [_marker] * (window_size - 1))
    windows = windowed(padded, window_size)

    replacements = 0
    for window in windows:
        # On a match (within the replacement limit), splice in the
        # substitutes and skip the windows overlapping the matched span.
        if pred(*window):
            if (count is None) or (replacements < count):
                replacements += 1
                for substitute in substitutes:
                    yield substitute
                consume(windows, window_size - 1)
                continue

        # No match (or limit reached): emit the window's first item.
        if window and (window[0] is not _marker):
            yield window[0]
+ # For example, if the iterable is (0, 1, 2, 3, 4...) + # and the window size is 2, we have (0, 1), (1, 2), (2, 3)... + # If the predicate matches on (0, 1), we need to zap (0, 1) and (1, 2) + if pred(*w): + if (count is None) or (n < count): + n += 1 + for s in substitutes: + yield s + consume(windows, window_size - 1) + continue + + # If there was no match (or we've reached the replacement limit), + # yield the first item from the window. + if w and (w[0] is not _marker): + yield w[0] diff --git a/libs/more_itertools/recipes.py b/libs/more_itertools/recipes.py index c92373c6..3a7706cb 100644 --- a/libs/more_itertools/recipes.py +++ b/libs/more_itertools/recipes.py @@ -8,21 +8,79 @@ Some backward-compatible usability improvements have been made. """ from collections import deque -from itertools import chain, combinations, count, cycle, groupby, ifilterfalse, imap, islice, izip, izip_longest, repeat, starmap, tee # Wrapping breaks 2to3. +from itertools import ( + chain, combinations, count, cycle, groupby, islice, repeat, starmap, tee +) import operator from random import randrange, sample, choice +from six import PY2 +from six.moves import filter, filterfalse, map, range, zip, zip_longest -__all__ = ['take', 'tabulate', 'consume', 'nth', 'quantify', 'padnone', - 'ncycles', 'dotproduct', 'flatten', 'repeatfunc', 'pairwise', - 'grouper', 'roundrobin', 'powerset', 'unique_everseen', - 'unique_justseen', 'iter_except', 'random_product', - 'random_permutation', 'random_combination', - 'random_combination_with_replacement'] +__all__ = [ + 'accumulate', + 'all_equal', + 'consume', + 'dotproduct', + 'first_true', + 'flatten', + 'grouper', + 'iter_except', + 'ncycles', + 'nth', + 'nth_combination', + 'padnone', + 'pairwise', + 'partition', + 'powerset', + 'prepend', + 'quantify', + 'random_combination_with_replacement', + 'random_combination', + 'random_permutation', + 'random_product', + 'repeatfunc', + 'roundrobin', + 'tabulate', + 'tail', + 'take', + 'unique_everseen', 
+ 'unique_justseen', +] + + +def accumulate(iterable, func=operator.add): + """ + Return an iterator whose items are the accumulated results of a function + (specified by the optional *func* argument) that takes two arguments. + By default, returns accumulated sums with :func:`operator.add`. + + >>> list(accumulate([1, 2, 3, 4, 5])) # Running sum + [1, 3, 6, 10, 15] + >>> list(accumulate([1, 2, 3], func=operator.mul)) # Running product + [1, 2, 6] + >>> list(accumulate([0, 1, -1, 2, 3, 2], func=max)) # Running maximum + [0, 1, 1, 2, 3, 3] + + This function is available in the ``itertools`` module for Python 3.2 and + greater. + + """ + it = iter(iterable) + try: + total = next(it) + except StopIteration: + return + else: + yield total + + for element in it: + total = func(total, element) + yield total def take(n, iterable): - """Return first n items of the iterable as a list + """Return first *n* items of the iterable as a list. >>> take(3, range(10)) [0, 1, 2] @@ -37,21 +95,37 @@ def take(n, iterable): def tabulate(function, start=0): - """Return an iterator mapping the function over linear input. + """Return an iterator over the results of ``func(start)``, + ``func(start + 1)``, ``func(start + 2)``... - The start argument will be increased by 1 each time the iterator is called - and fed into the function. + *func* should be a function that accepts one integer argument. - >>> t = tabulate(lambda x: x**2, -3) - >>> take(3, t) - [9, 4, 1] + If *start* is not specified it defaults to 0. It will be incremented each + time the iterator is advanced. + + >>> square = lambda x: x ** 2 + >>> iterator = tabulate(square, -3) + >>> take(4, iterator) + [9, 4, 1, 0] """ - return imap(function, count(start)) + return map(function, count(start)) + + +def tail(n, iterable): + """Return an iterator over the last *n* items of *iterable*. 
+ + >>> t = tail(3, 'ABCDEFG') + >>> list(t) + ['E', 'F', 'G'] + + """ + return iter(deque(iterable, maxlen=n)) def consume(iterator, n=None): - """Advance the iterator n-steps ahead. If n is none, consume entirely. + """Advance *iterable* by *n* steps. If *n* is ``None``, consume it + entirely. Efficiently exhausts an iterator without returning values. Defaults to consuming the whole iterator, but an optional second argument may be @@ -90,7 +164,7 @@ def consume(iterator, n=None): def nth(iterable, n, default=None): - """Returns the nth item or a default value + """Returns the nth item or a default value. >>> l = range(10) >>> nth(l, 3) @@ -102,30 +176,46 @@ def nth(iterable, n, default=None): return next(islice(iterable, n, None), default) +def all_equal(iterable): + """ + Returns ``True`` if all the elements are equal to each other. + + >>> all_equal('aaaa') + True + >>> all_equal('aaab') + False + + """ + g = groupby(iterable) + return next(g, True) and not next(g, False) + + def quantify(iterable, pred=bool): - """Return the how many times the predicate is true + """Return the how many times the predicate is true. >>> quantify([True, False, True]) 2 """ - return sum(imap(pred, iterable)) + return sum(map(pred, iterable)) def padnone(iterable): - """Returns the sequence of elements and then returns None indefinitely. + """Returns the sequence of elements and then returns ``None`` indefinitely. >>> take(5, padnone(range(3))) [0, 1, 2, None, None] - Useful for emulating the behavior of the built-in map() function. + Useful for emulating the behavior of the built-in :func:`map` function. + + See also :func:`padded`. 
""" return chain(iterable, repeat(None)) def ncycles(iterable, n): - """Returns the sequence elements n times + """Returns the sequence elements *n* times >>> list(ncycles(["a", "b"], 3)) ['a', 'b', 'a', 'b', 'a', 'b'] @@ -135,32 +225,47 @@ def ncycles(iterable, n): def dotproduct(vec1, vec2): - """Returns the dot product of the two iterables + """Returns the dot product of the two iterables. >>> dotproduct([10, 10], [20, 20]) 400 """ - return sum(imap(operator.mul, vec1, vec2)) + return sum(map(operator.mul, vec1, vec2)) def flatten(listOfLists): - """Return an iterator flattening one level of nesting in a list of lists + """Return an iterator flattening one level of nesting in a list of lists. >>> list(flatten([[0, 1], [2, 3]])) [0, 1, 2, 3] + See also :func:`collapse`, which can flatten multiple levels of nesting. + """ return chain.from_iterable(listOfLists) def repeatfunc(func, times=None, *args): - """Repeat calls to func with specified arguments. + """Call *func* with *args* repeatedly, returning an iterable over the + results. - >>> list(repeatfunc(lambda: 5, 3)) - [5, 5, 5] - >>> list(repeatfunc(lambda x: x ** 2, 3, 3)) - [9, 9, 9] + If *times* is specified, the iterable will terminate after that many + repetitions: + + >>> from operator import add + >>> times = 4 + >>> args = 3, 5 + >>> list(repeatfunc(add, times, *args)) + [8, 8, 8, 8] + + If *times* is ``None`` the iterable will not terminate: + + >>> from random import randrange + >>> times = None + >>> args = 1, 11 + >>> take(6, repeatfunc(randrange, times, *args)) # doctest:+SKIP + [2, 4, 8, 1, 8, 4] """ if times is None: @@ -177,30 +282,37 @@ def pairwise(iterable): """ a, b = tee(iterable) next(b, None) - return izip(a, b) + return zip(a, b) def grouper(n, iterable, fillvalue=None): - """Collect data into fixed-length chunks or blocks + """Collect data into fixed-length chunks or blocks. 
>>> list(grouper(3, 'ABCDEFG', 'x')) [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')] """ args = [iter(iterable)] * n - return izip_longest(fillvalue=fillvalue, *args) + return zip_longest(fillvalue=fillvalue, *args) def roundrobin(*iterables): - """Yields an item from each iterable, alternating between them + """Yields an item from each iterable, alternating between them. >>> list(roundrobin('ABC', 'D', 'EF')) ['A', 'D', 'E', 'B', 'F', 'C'] + This function produces the same output as :func:`interleave_longest`, but + may perform better for some inputs (in particular when the number of + iterables is small). + """ # Recipe credited to George Sakkis pending = len(iterables) - nexts = cycle(iter(it).next for it in iterables) + if PY2: + nexts = cycle(iter(it).next for it in iterables) + else: + nexts = cycle(iter(it).__next__ for it in iterables) while pending: try: for next in nexts: @@ -210,38 +322,73 @@ def roundrobin(*iterables): nexts = cycle(islice(nexts, pending)) +def partition(pred, iterable): + """ + Returns a 2-tuple of iterables derived from the input iterable. + The first yields the items that have ``pred(item) == False``. + The second yields the items that have ``pred(item) == True``. + + >>> is_odd = lambda x: x % 2 != 0 + >>> iterable = range(10) + >>> even_items, odd_items = partition(is_odd, iterable) + >>> list(even_items), list(odd_items) + ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9]) + + """ + # partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9 + t1, t2 = tee(iterable) + return filterfalse(pred, t1), filter(pred, t2) + + def powerset(iterable): - """Yields all possible subsets of the iterable + """Yields all possible subsets of the iterable. 
>>> list(powerset([1,2,3])) [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)] """ s = list(iterable) - return chain.from_iterable(combinations(s, r) for r in range(len(s)+1)) + return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1)) def unique_everseen(iterable, key=None): - """Yield unique elements, preserving order. + """ + Yield unique elements, preserving order. >>> list(unique_everseen('AAAABBBCCDAABBB')) ['A', 'B', 'C', 'D'] >>> list(unique_everseen('ABBCcAD', str.lower)) ['A', 'B', 'C', 'D'] + Sequences with a mix of hashable and unhashable items can be used. + The function will be slower (i.e., `O(n^2)`) for unhashable items. + """ - seen = set() - seen_add = seen.add + seenset = set() + seenset_add = seenset.add + seenlist = [] + seenlist_add = seenlist.append if key is None: - for element in ifilterfalse(seen.__contains__, iterable): - seen_add(element) - yield element + for element in iterable: + try: + if element not in seenset: + seenset_add(element) + yield element + except TypeError: + if element not in seenlist: + seenlist_add(element) + yield element else: for element in iterable: k = key(element) - if k not in seen: - seen_add(k) - yield element + try: + if k not in seenset: + seenset_add(k) + yield element + except TypeError: + if k not in seenlist: + seenlist_add(k) + yield element def unique_justseen(iterable, key=None): @@ -253,17 +400,17 @@ def unique_justseen(iterable, key=None): ['A', 'B', 'C', 'A', 'D'] """ - return imap(next, imap(operator.itemgetter(1), groupby(iterable, key))) + return map(next, map(operator.itemgetter(1), groupby(iterable, key))) def iter_except(func, exception, first=None): """Yields results from a function repeatedly until an exception is raised. Converts a call-until-exception interface to an iterator interface. - Like __builtin__.iter(func, sentinel) but uses an exception instead - of a sentinel to end the loop. 
+ Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel + to end the loop. - >>> l = range(3) + >>> l = [0, 1, 2] >>> list(iter_except(l.pop, IndexError)) [2, 1, 0] @@ -277,28 +424,58 @@ def iter_except(func, exception, first=None): pass -def random_product(*args, **kwds): - """Returns a random pairing of items from each iterable argument +def first_true(iterable, default=False, pred=None): + """ + Returns the first true value in the iterable. - If `repeat` is provided as a kwarg, it's value will be used to indicate - how many pairings should be chosen. + If no true value is found, returns *default* - >>> random_product(['a', 'b', 'c'], [1, 2], repeat=2) # doctest:+SKIP - ('b', '2', 'c', '2') + If *pred* is not None, returns the first item for which + ``pred(item) == True`` . + + >>> first_true(range(10)) + 1 + >>> first_true(range(10), pred=lambda x: x > 5) + 6 + >>> first_true(range(10), default='missing', pred=lambda x: x > 9) + 'missing' """ - pools = map(tuple, args) * kwds.get('repeat', 1) + return next(filter(pred, iterable), default) + + +def random_product(*args, **kwds): + """Draw an item at random from each of the input iterables. + + >>> random_product('abc', range(4), 'XYZ') # doctest:+SKIP + ('c', 3, 'Z') + + If *repeat* is provided as a keyword argument, that many items will be + drawn from each iterable. + + >>> random_product('abcd', range(4), repeat=2) # doctest:+SKIP + ('a', 2, 'd', 3) + + This equivalent to taking a random selection from + ``itertools.product(*args, **kwarg)``. + + """ + pools = [tuple(pool) for pool in args] * kwds.get('repeat', 1) return tuple(choice(pool) for pool in pools) def random_permutation(iterable, r=None): - """Returns a random permutation. + """Return a random *r* length permutation of the elements in *iterable*. - If r is provided, the permutation is truncated to length r. + If *r* is not specified or is ``None``, then *r* defaults to the length of + *iterable*. 
- >>> random_permutation(range(5)) # doctest:+SKIP + >>> random_permutation(range(5)) # doctest:+SKIP (3, 4, 0, 1, 2) + This equivalent to taking a random selection from + ``itertools.permutations(iterable, r)``. + """ pool = tuple(iterable) r = len(pool) if r is None else r @@ -306,26 +483,83 @@ def random_permutation(iterable, r=None): def random_combination(iterable, r): - """Returns a random combination of length r, chosen without replacement. + """Return a random *r* length subsequence of the elements in *iterable*. - >>> random_combination(range(5), 3) # doctest:+SKIP + >>> random_combination(range(5), 3) # doctest:+SKIP (2, 3, 4) + This equivalent to taking a random selection from + ``itertools.combinations(iterable, r)``. + """ pool = tuple(iterable) n = len(pool) - indices = sorted(sample(xrange(n), r)) + indices = sorted(sample(range(n), r)) return tuple(pool[i] for i in indices) def random_combination_with_replacement(iterable, r): - """Returns a random combination of length r, chosen with replacement. + """Return a random *r* length subsequence of elements in *iterable*, + allowing individual elements to be repeated. - >>> random_combination_with_replacement(range(3), 5) # # doctest:+SKIP + >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP (0, 0, 1, 2, 2) + This equivalent to taking a random selection from + ``itertools.combinations_with_replacement(iterable, r)``. + """ pool = tuple(iterable) n = len(pool) - indices = sorted(randrange(n) for i in xrange(r)) + indices = sorted(randrange(n) for i in range(r)) return tuple(pool[i] for i in indices) + + +def nth_combination(iterable, r, index): + """Equivalent to ``list(combinations(iterable, r))[index]``. + + The subsequences of *iterable* that are of length *r* can be ordered + lexicographically. :func:`nth_combination` computes the subsequence at + sort position *index* directly, without computing the previous + subsequences. 
+ + """ + pool = tuple(iterable) + n = len(pool) + if (r < 0) or (r > n): + raise ValueError + + c = 1 + k = min(r, n - r) + for i in range(1, k + 1): + c = c * (n - k + i) // i + + if index < 0: + index += c + + if (index < 0) or (index >= c): + raise IndexError + + result = [] + while r: + c, n, r = c * r // n, n - 1, r - 1 + while index >= c: + index -= c + c, n = c * (n - r) // n, n - 1 + result.append(pool[-1 - n]) + + return tuple(result) + + +def prepend(value, iterator): + """Yield *value*, followed by the elements in *iterator*. + + >>> value = '0' + >>> iterator = ['1', '2', '3'] + >>> list(prepend(value, iterator)) + ['0', '1', '2', '3'] + + To prepend multiple values, see :func:`itertools.chain`. + + """ + return chain([value], iterator) diff --git a/libs/more_itertools/tests/test_more.py b/libs/more_itertools/tests/test_more.py index 53b10618..a1b1e431 100644 --- a/libs/more_itertools/tests/test_more.py +++ b/libs/more_itertools/tests/test_more.py @@ -1,11 +1,33 @@ -from contextlib import closing -from itertools import islice, ifilter -from StringIO import StringIO +from __future__ import division, print_function, unicode_literals + +from collections import OrderedDict +from decimal import Decimal +from doctest import DocTestSuite +from fractions import Fraction +from functools import partial, reduce +from heapq import merge +from io import StringIO +from itertools import ( + chain, + count, + groupby, + islice, + permutations, + product, + repeat, +) +from operator import add, mul, itemgetter from unittest import TestCase -from nose.tools import eq_, assert_raises +from six.moves import filter, map, range, zip -from more_itertools import * # Test all the symbols are in __all__. 
+import more_itertools as mi + + +def load_tests(loader, tests, ignore): + # Add the doctests + tests.addTests(DocTestSuite('more_itertools.more')) + return tests class CollateTests(TestCase): @@ -14,32 +36,42 @@ class CollateTests(TestCase): def test_default(self): """Test with the default `key` function.""" - iterables = [xrange(4), xrange(7), xrange(3, 6)] - eq_(sorted(reduce(list.__add__, [list(it) for it in iterables])), - list(collate(*iterables))) + iterables = [range(4), range(7), range(3, 6)] + self.assertEqual( + sorted(reduce(list.__add__, [list(it) for it in iterables])), + list(mi.collate(*iterables)) + ) def test_key(self): """Test using a custom `key` function.""" - iterables = [xrange(5, 0, -1), xrange(4, 0, -1)] - eq_(list(sorted(reduce(list.__add__, - [list(it) for it in iterables]), - reverse=True)), - list(collate(*iterables, key=lambda x: -x))) + iterables = [range(5, 0, -1), range(4, 0, -1)] + actual = sorted( + reduce(list.__add__, [list(it) for it in iterables]), reverse=True + ) + expected = list(mi.collate(*iterables, key=lambda x: -x)) + self.assertEqual(actual, expected) def test_empty(self): """Be nice if passed an empty list of iterables.""" - eq_([], list(collate())) + self.assertEqual([], list(mi.collate())) def test_one(self): """Work when only 1 iterable is passed.""" - eq_([0, 1], list(collate(xrange(2)))) + self.assertEqual([0, 1], list(mi.collate(range(2)))) def test_reverse(self): """Test the `reverse` kwarg.""" - iterables = [xrange(4, 0, -1), xrange(7, 0, -1), xrange(3, 6, -1)] - eq_(sorted(reduce(list.__add__, [list(it) for it in iterables]), - reverse=True), - list(collate(*iterables, reverse=True))) + iterables = [range(4, 0, -1), range(7, 0, -1), range(3, 6, -1)] + + actual = sorted( + reduce(list.__add__, [list(it) for it in iterables]), reverse=True + ) + expected = list(mi.collate(*iterables, reverse=True)) + self.assertEqual(actual, expected) + + def test_alias(self): + self.assertNotEqual(merge.__doc__, 
mi.collate.__doc__) + self.assertNotEqual(partial.__doc__, mi.collate.__doc__) class ChunkedTests(TestCase): @@ -47,14 +79,18 @@ class ChunkedTests(TestCase): def test_even(self): """Test when ``n`` divides evenly into the length of the iterable.""" - eq_(list(chunked('ABCDEF', 3)), [['A', 'B', 'C'], ['D', 'E', 'F']]) + self.assertEqual( + list(mi.chunked('ABCDEF', 3)), [['A', 'B', 'C'], ['D', 'E', 'F']] + ) def test_odd(self): """Test when ``n`` does not divide evenly into the length of the iterable. """ - eq_(list(chunked('ABCDE', 3)), [['A', 'B', 'C'], ['D', 'E']]) + self.assertEqual( + list(mi.chunked('ABCDE', 3)), [['A', 'B', 'C'], ['D', 'E']] + ) class FirstTests(TestCase): @@ -64,19 +100,103 @@ class FirstTests(TestCase): """Test that it works on many-item iterables.""" # Also try it on a generator expression to make sure it works on # whatever those return, across Python versions. - eq_(first(x for x in xrange(4)), 0) + self.assertEqual(mi.first(x for x in range(4)), 0) def test_one(self): """Test that it doesn't raise StopIteration prematurely.""" - eq_(first([3]), 3) + self.assertEqual(mi.first([3]), 3) def test_empty_stop_iteration(self): """It should raise StopIteration for empty iterables.""" - assert_raises(ValueError, first, []) + self.assertRaises(ValueError, lambda: mi.first([])) def test_default(self): """It should return the provided default arg for empty iterables.""" - eq_(first([], 'boo'), 'boo') + self.assertEqual(mi.first([], 'boo'), 'boo') + + +class IterOnlyRange: + """User-defined iterable class which only support __iter__. + + It is not specified to inherit ``object``, so indexing on a instance will + raise an ``AttributeError`` rather than ``TypeError`` in Python 2. + + >>> r = IterOnlyRange(5) + >>> r[0] + AttributeError: IterOnlyRange instance has no attribute '__getitem__' + + Note: In Python 3, ``TypeError`` will be raised because ``object`` is + inherited implicitly by default. 
+ + >>> r[0] + TypeError: 'IterOnlyRange' object does not support indexing + """ + def __init__(self, n): + """Set the length of the range.""" + self.n = n + + def __iter__(self): + """Works same as range().""" + return iter(range(self.n)) + + +class LastTests(TestCase): + """Tests for ``last()``""" + + def test_many_nonsliceable(self): + """Test that it works on many-item non-slice-able iterables.""" + # Also try it on a generator expression to make sure it works on + # whatever those return, across Python versions. + self.assertEqual(mi.last(x for x in range(4)), 3) + + def test_one_nonsliceable(self): + """Test that it doesn't raise StopIteration prematurely.""" + self.assertEqual(mi.last(x for x in range(1)), 0) + + def test_empty_stop_iteration_nonsliceable(self): + """It should raise ValueError for empty non-slice-able iterables.""" + self.assertRaises(ValueError, lambda: mi.last(x for x in range(0))) + + def test_default_nonsliceable(self): + """It should return the provided default arg for empty non-slice-able + iterables. + """ + self.assertEqual(mi.last((x for x in range(0)), 'boo'), 'boo') + + def test_many_sliceable(self): + """Test that it works on many-item slice-able iterables.""" + self.assertEqual(mi.last([0, 1, 2, 3]), 3) + + def test_one_sliceable(self): + """Test that it doesn't raise StopIteration prematurely.""" + self.assertEqual(mi.last([3]), 3) + + def test_empty_stop_iteration_sliceable(self): + """It should raise ValueError for empty slice-able iterables.""" + self.assertRaises(ValueError, lambda: mi.last([])) + + def test_default_sliceable(self): + """It should return the provided default arg for empty slice-able + iterables. 
+ """ + self.assertEqual(mi.last([], 'boo'), 'boo') + + def test_dict(self): + """last(dic) and last(dic.keys()) should return same result.""" + dic = {'a': 1, 'b': 2, 'c': 3} + self.assertEqual(mi.last(dic), mi.last(dic.keys())) + + def test_ordereddict(self): + """last(dic) should return the last key.""" + od = OrderedDict() + od['a'] = 1 + od['b'] = 2 + od['c'] = 3 + self.assertEqual(mi.last(od), 'c') + + def test_customrange(self): + """It should work on custom class where [] raises AttributeError.""" + self.assertEqual(mi.last(IterOnlyRange(5)), 4) class PeekableTests(TestCase): @@ -86,58 +206,1869 @@ class PeekableTests(TestCase): """ def test_peek_default(self): """Make sure passing a default into ``peek()`` works.""" - p = peekable([]) - eq_(p.peek(7), 7) + p = mi.peekable([]) + self.assertEqual(p.peek(7), 7) def test_truthiness(self): """Make sure a ``peekable`` tests true iff there are items remaining in the iterable. """ - p = peekable([]) - self.failIf(p) - p = peekable(xrange(3)) - self.failUnless(p) + p = mi.peekable([]) + self.assertFalse(p) + + p = mi.peekable(range(3)) + self.assertTrue(p) def test_simple_peeking(self): """Make sure ``next`` and ``peek`` advance and don't advance the iterator, respectively. """ - p = peekable(xrange(10)) - eq_(p.next(), 0) - eq_(p.peek(), 1) - eq_(p.next(), 1) + p = mi.peekable(range(10)) + self.assertEqual(next(p), 0) + self.assertEqual(p.peek(), 1) + self.assertEqual(next(p), 1) + + def test_indexing(self): + """ + Indexing into the peekable shouldn't advance the iterator. 
+ """ + p = mi.peekable('abcdefghijkl') + + # The 0th index is what ``next()`` will return + self.assertEqual(p[0], 'a') + self.assertEqual(next(p), 'a') + + # Indexing further into the peekable shouldn't advance the itertor + self.assertEqual(p[2], 'd') + self.assertEqual(next(p), 'b') + + # The 0th index moves up with the iterator; the last index follows + self.assertEqual(p[0], 'c') + self.assertEqual(p[9], 'l') + + self.assertEqual(next(p), 'c') + self.assertEqual(p[8], 'l') + + # Negative indexing should work too + self.assertEqual(p[-2], 'k') + self.assertEqual(p[-9], 'd') + self.assertRaises(IndexError, lambda: p[-10]) + + def test_slicing(self): + """Slicing the peekable shouldn't advance the iterator.""" + seq = list('abcdefghijkl') + p = mi.peekable(seq) + + # Slicing the peekable should just be like slicing a re-iterable + self.assertEqual(p[1:4], seq[1:4]) + + # Advancing the iterator moves the slices up also + self.assertEqual(next(p), 'a') + self.assertEqual(p[1:4], seq[1:][1:4]) + + # Implicit starts and stop should work + self.assertEqual(p[:5], seq[1:][:5]) + self.assertEqual(p[:], seq[1:][:]) + + # Indexing past the end should work + self.assertEqual(p[:100], seq[1:][:100]) + + # Steps should work, including negative + self.assertEqual(p[::2], seq[1:][::2]) + self.assertEqual(p[::-1], seq[1:][::-1]) + + def test_slicing_reset(self): + """Test slicing on a fresh iterable each time""" + iterable = ['0', '1', '2', '3', '4', '5'] + indexes = list(range(-4, len(iterable) + 4)) + [None] + steps = [1, 2, 3, 4, -1, -2, -3, 4] + for slice_args in product(indexes, indexes, steps): + it = iter(iterable) + p = mi.peekable(it) + next(p) + index = slice(*slice_args) + actual = p[index] + expected = iterable[1:][index] + self.assertEqual(actual, expected, slice_args) + + def test_slicing_error(self): + iterable = '01234567' + p = mi.peekable(iter(iterable)) + + # Prime the cache + p.peek() + old_cache = list(p._cache) + + # Illegal slice + with 
self.assertRaises(ValueError): + p[1:-1:0] + + # Neither the cache nor the iteration should be affected + self.assertEqual(old_cache, list(p._cache)) + self.assertEqual(list(p), list(iterable)) + + def test_passthrough(self): + """Iterating a peekable without using ``peek()`` or ``prepend()`` + should just give the underlying iterable's elements (a trivial test but + useful to set a baseline in case something goes wrong)""" + expected = [1, 2, 3, 4, 5] + actual = list(mi.peekable(expected)) + self.assertEqual(actual, expected) + + # prepend() behavior tests + + def test_prepend(self): + """Tests intersperesed ``prepend()`` and ``next()`` calls""" + it = mi.peekable(range(2)) + actual = [] + + # Test prepend() before next() + it.prepend(10) + actual += [next(it), next(it)] + + # Test prepend() between next()s + it.prepend(11) + actual += [next(it), next(it)] + + # Test prepend() after source iterable is consumed + it.prepend(12) + actual += [next(it)] + + expected = [10, 0, 11, 1, 12] + self.assertEqual(actual, expected) + + def test_multi_prepend(self): + """Tests prepending multiple items and getting them in proper order""" + it = mi.peekable(range(5)) + actual = [next(it), next(it)] + it.prepend(10, 11, 12) + it.prepend(20, 21) + actual += list(it) + expected = [0, 1, 20, 21, 10, 11, 12, 2, 3, 4] + self.assertEqual(actual, expected) + + def test_empty(self): + """Tests prepending in front of an empty iterable""" + it = mi.peekable([]) + it.prepend(10) + actual = list(it) + expected = [10] + self.assertEqual(actual, expected) + + def test_prepend_truthiness(self): + """Tests that ``__bool__()`` or ``__nonzero__()`` works properly + with ``prepend()``""" + it = mi.peekable(range(5)) + self.assertTrue(it) + actual = list(it) + self.assertFalse(it) + it.prepend(10) + self.assertTrue(it) + actual += [next(it)] + self.assertFalse(it) + expected = [0, 1, 2, 3, 4, 10] + self.assertEqual(actual, expected) + + def test_multi_prepend_peek(self): + """Tests prepending 
multiple elements and getting them in reverse order + while peeking""" + it = mi.peekable(range(5)) + actual = [next(it), next(it)] + self.assertEqual(it.peek(), 2) + it.prepend(10, 11, 12) + self.assertEqual(it.peek(), 10) + it.prepend(20, 21) + self.assertEqual(it.peek(), 20) + actual += list(it) + self.assertFalse(it) + expected = [0, 1, 20, 21, 10, 11, 12, 2, 3, 4] + self.assertEqual(actual, expected) + + def test_prepend_after_stop(self): + """Test resuming iteration after a previous exhaustion""" + it = mi.peekable(range(3)) + self.assertEqual(list(it), [0, 1, 2]) + self.assertRaises(StopIteration, lambda: next(it)) + it.prepend(10) + self.assertEqual(next(it), 10) + self.assertRaises(StopIteration, lambda: next(it)) + + def test_prepend_slicing(self): + """Tests interaction between prepending and slicing""" + seq = list(range(20)) + p = mi.peekable(seq) + + p.prepend(30, 40, 50) + pseq = [30, 40, 50] + seq # pseq for prepended_seq + + # adapt the specific tests from test_slicing + self.assertEqual(p[0], 30) + self.assertEqual(p[1:8], pseq[1:8]) + self.assertEqual(p[1:], pseq[1:]) + self.assertEqual(p[:5], pseq[:5]) + self.assertEqual(p[:], pseq[:]) + self.assertEqual(p[:100], pseq[:100]) + self.assertEqual(p[::2], pseq[::2]) + self.assertEqual(p[::-1], pseq[::-1]) + + def test_prepend_indexing(self): + """Tests interaction between prepending and indexing""" + seq = list(range(20)) + p = mi.peekable(seq) + + p.prepend(30, 40, 50) + + self.assertEqual(p[0], 30) + self.assertEqual(next(p), 30) + self.assertEqual(p[2], 0) + self.assertEqual(next(p), 40) + self.assertEqual(p[0], 50) + self.assertEqual(p[9], 8) + self.assertEqual(next(p), 50) + self.assertEqual(p[8], 8) + self.assertEqual(p[-2], 18) + self.assertEqual(p[-9], 11) + self.assertRaises(IndexError, lambda: p[-21]) + + def test_prepend_iterable(self): + """Tests prepending from an iterable""" + it = mi.peekable(range(5)) + # Don't directly use the range() object to avoid any range-specific + # 
optimizations + it.prepend(*(x for x in range(5))) + actual = list(it) + expected = list(chain(range(5), range(5))) + self.assertEqual(actual, expected) + + def test_prepend_many(self): + """Tests that prepending a huge number of elements works""" + it = mi.peekable(range(5)) + # Don't directly use the range() object to avoid any range-specific + # optimizations + it.prepend(*(x for x in range(20000))) + actual = list(it) + expected = list(chain(range(20000), range(5))) + self.assertEqual(actual, expected) + + def test_prepend_reversed(self): + """Tests prepending from a reversed iterable""" + it = mi.peekable(range(3)) + it.prepend(*reversed((10, 11, 12))) + actual = list(it) + expected = [12, 11, 10, 0, 1, 2] + self.assertEqual(actual, expected) class ConsumerTests(TestCase): """Tests for ``consumer()``""" def test_consumer(self): - @consumer + @mi.consumer def eater(): while True: - x = yield + x = yield # noqa e = eater() e.send('hi') # without @consumer, would raise TypeError -def test_ilen(): - """Sanity-check ``ilen()``.""" - eq_(ilen(ifilter(lambda x: x % 10 == 0, range(101))), 11) +class DistinctPermutationsTests(TestCase): + def test_distinct_permutations(self): + """Make sure the output for ``distinct_permutations()`` is the same as + set(permutations(it)). + + """ + iterable = ['z', 'a', 'a', 'q', 'q', 'q', 'y'] + test_output = sorted(mi.distinct_permutations(iterable)) + ref_output = sorted(set(permutations(iterable))) + self.assertEqual(test_output, ref_output) + + def test_other_iterables(self): + """Make sure ``distinct_permutations()`` accepts a different type of + iterables. 
+ + """ + # a generator + iterable = (c for c in ['z', 'a', 'a', 'q', 'q', 'q', 'y']) + test_output = sorted(mi.distinct_permutations(iterable)) + # "reload" it + iterable = (c for c in ['z', 'a', 'a', 'q', 'q', 'q', 'y']) + ref_output = sorted(set(permutations(iterable))) + self.assertEqual(test_output, ref_output) + + # an iterator + iterable = iter(['z', 'a', 'a', 'q', 'q', 'q', 'y']) + test_output = sorted(mi.distinct_permutations(iterable)) + # "reload" it + iterable = iter(['z', 'a', 'a', 'q', 'q', 'q', 'y']) + ref_output = sorted(set(permutations(iterable))) + self.assertEqual(test_output, ref_output) -def test_with_iter(): - """Make sure ``with_iter`` iterates over and closes things correctly.""" - s = StringIO('One fish\nTwo fish') - initial_words = [line.split()[0] for line in with_iter(closing(s))] - eq_(initial_words, ['One', 'Two']) +class IlenTests(TestCase): + def test_ilen(self): + """Sanity-checks for ``ilen()``.""" + # Non-empty + self.assertEqual( + mi.ilen(filter(lambda x: x % 10 == 0, range(101))), 11 + ) - # Make sure closing happened: - try: - list(s) - except ValueError: # "I/O operation on closed file" - pass - else: - raise AssertionError('StringIO object was not closed.') + # Empty + self.assertEqual(mi.ilen((x for x in range(0))), 0) + + # Iterable with __len__ + self.assertEqual(mi.ilen(list(range(6))), 6) + + +class WithIterTests(TestCase): + def test_with_iter(self): + s = StringIO('One fish\nTwo fish') + initial_words = [line.split()[0] for line in mi.with_iter(s)] + + # Iterable's items should be faithfully represented + self.assertEqual(initial_words, ['One', 'Two']) + # The file object should be closed + self.assertEqual(s.closed, True) + + +class OneTests(TestCase): + def test_basic(self): + it = iter(['item']) + self.assertEqual(mi.one(it), 'item') + + def test_too_short(self): + it = iter([]) + self.assertRaises(ValueError, lambda: mi.one(it)) + self.assertRaises(IndexError, lambda: mi.one(it, too_short=IndexError)) + + def 
test_too_long(self): + it = count() + self.assertRaises(ValueError, lambda: mi.one(it)) # burn 0 and 1 + self.assertEqual(next(it), 2) + self.assertRaises( + OverflowError, lambda: mi.one(it, too_long=OverflowError) + ) + + +class IntersperseTest(TestCase): + """ Tests for intersperse() """ + + def test_even(self): + iterable = (x for x in '01') + self.assertEqual( + list(mi.intersperse(None, iterable)), ['0', None, '1'] + ) + + def test_odd(self): + iterable = (x for x in '012') + self.assertEqual( + list(mi.intersperse(None, iterable)), ['0', None, '1', None, '2'] + ) + + def test_nested(self): + element = ('a', 'b') + iterable = (x for x in '012') + actual = list(mi.intersperse(element, iterable)) + expected = ['0', ('a', 'b'), '1', ('a', 'b'), '2'] + self.assertEqual(actual, expected) + + def test_not_iterable(self): + self.assertRaises(TypeError, lambda: mi.intersperse('x', 1)) + + def test_n(self): + for n, element, expected in [ + (1, '_', ['0', '_', '1', '_', '2', '_', '3', '_', '4', '_', '5']), + (2, '_', ['0', '1', '_', '2', '3', '_', '4', '5']), + (3, '_', ['0', '1', '2', '_', '3', '4', '5']), + (4, '_', ['0', '1', '2', '3', '_', '4', '5']), + (5, '_', ['0', '1', '2', '3', '4', '_', '5']), + (6, '_', ['0', '1', '2', '3', '4', '5']), + (7, '_', ['0', '1', '2', '3', '4', '5']), + (3, ['a', 'b'], ['0', '1', '2', ['a', 'b'], '3', '4', '5']), + ]: + iterable = (x for x in '012345') + actual = list(mi.intersperse(element, iterable, n=n)) + self.assertEqual(actual, expected) + + def test_n_zero(self): + self.assertRaises( + ValueError, lambda: list(mi.intersperse('x', '012', n=0)) + ) + + +class UniqueToEachTests(TestCase): + """Tests for ``unique_to_each()``""" + + def test_all_unique(self): + """When all the input iterables are unique the output should match + the input.""" + iterables = [[1, 2], [3, 4, 5], [6, 7, 8]] + self.assertEqual(mi.unique_to_each(*iterables), iterables) + + def test_duplicates(self): + """When there are duplicates in any of the input 
iterables that aren't + in the rest, those duplicates should be emitted.""" + iterables = ["mississippi", "missouri"] + self.assertEqual( + mi.unique_to_each(*iterables), [['p', 'p'], ['o', 'u', 'r']] + ) + + def test_mixed(self): + """When the input iterables contain different types the function should + still behave properly""" + iterables = ['x', (i for i in range(3)), [1, 2, 3], tuple()] + self.assertEqual(mi.unique_to_each(*iterables), [['x'], [0], [3], []]) + + +class WindowedTests(TestCase): + """Tests for ``windowed()``""" + + def test_basic(self): + actual = list(mi.windowed([1, 2, 3, 4, 5], 3)) + expected = [(1, 2, 3), (2, 3, 4), (3, 4, 5)] + self.assertEqual(actual, expected) + + def test_large_size(self): + """ + When the window size is larger than the iterable, and no fill value is + given,``None`` should be filled in. + """ + actual = list(mi.windowed([1, 2, 3, 4, 5], 6)) + expected = [(1, 2, 3, 4, 5, None)] + self.assertEqual(actual, expected) + + def test_fillvalue(self): + """ + When sizes don't match evenly, the given fill value should be used. 
+ """ + iterable = [1, 2, 3, 4, 5] + + for n, kwargs, expected in [ + (6, {}, [(1, 2, 3, 4, 5, '!')]), # n > len(iterable) + (3, {'step': 3}, [(1, 2, 3), (4, 5, '!')]), # using ``step`` + ]: + actual = list(mi.windowed(iterable, n, fillvalue='!', **kwargs)) + self.assertEqual(actual, expected) + + def test_zero(self): + """When the window size is zero, an empty tuple should be emitted.""" + actual = list(mi.windowed([1, 2, 3, 4, 5], 0)) + expected = [tuple()] + self.assertEqual(actual, expected) + + def test_negative(self): + """When the window size is negative, ValueError should be raised.""" + with self.assertRaises(ValueError): + list(mi.windowed([1, 2, 3, 4, 5], -1)) + + def test_step(self): + """The window should advance by the number of steps provided""" + iterable = [1, 2, 3, 4, 5, 6, 7] + for n, step, expected in [ + (3, 2, [(1, 2, 3), (3, 4, 5), (5, 6, 7)]), # n > step + (3, 3, [(1, 2, 3), (4, 5, 6), (7, None, None)]), # n == step + (3, 4, [(1, 2, 3), (5, 6, 7)]), # line up nicely + (3, 5, [(1, 2, 3), (6, 7, None)]), # off by one + (3, 6, [(1, 2, 3), (7, None, None)]), # off by two + (3, 7, [(1, 2, 3)]), # step past the end + (7, 8, [(1, 2, 3, 4, 5, 6, 7)]), # step > len(iterable) + ]: + actual = list(mi.windowed(iterable, n, step=step)) + self.assertEqual(actual, expected) + + # Step must be greater than or equal to 1 + with self.assertRaises(ValueError): + list(mi.windowed(iterable, 3, step=0)) + + +class BucketTests(TestCase): + """Tests for ``bucket()``""" + + def test_basic(self): + iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33] + D = mi.bucket(iterable, key=lambda x: 10 * (x // 10)) + + # In-order access + self.assertEqual(list(D[10]), [10, 11, 12]) + + # Out of order access + self.assertEqual(list(D[30]), [30, 31, 33]) + self.assertEqual(list(D[20]), [20, 21, 22, 23]) + + self.assertEqual(list(D[40]), []) # Nothing in here! 
+ + def test_in(self): + iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33] + D = mi.bucket(iterable, key=lambda x: 10 * (x // 10)) + + self.assertTrue(10 in D) + self.assertFalse(40 in D) + self.assertTrue(20 in D) + self.assertFalse(21 in D) + + # Checking in-ness shouldn't advance the iterator + self.assertEqual(next(D[10]), 10) + + def test_validator(self): + iterable = count(0) + key = lambda x: int(str(x)[0]) # First digit of each number + validator = lambda x: 0 < x < 10 # No leading zeros + D = mi.bucket(iterable, key, validator=validator) + self.assertEqual(mi.take(3, D[1]), [1, 10, 11]) + self.assertNotIn(0, D) # Non-valid entries don't return True + self.assertNotIn(0, D._cache) # Don't store non-valid entries + self.assertEqual(list(D[0]), []) + + +class SpyTests(TestCase): + """Tests for ``spy()``""" + + def test_basic(self): + original_iterable = iter('abcdefg') + head, new_iterable = mi.spy(original_iterable) + self.assertEqual(head, ['a']) + self.assertEqual( + list(new_iterable), ['a', 'b', 'c', 'd', 'e', 'f', 'g'] + ) + + def test_unpacking(self): + original_iterable = iter('abcdefg') + (first, second, third), new_iterable = mi.spy(original_iterable, 3) + self.assertEqual(first, 'a') + self.assertEqual(second, 'b') + self.assertEqual(third, 'c') + self.assertEqual( + list(new_iterable), ['a', 'b', 'c', 'd', 'e', 'f', 'g'] + ) + + def test_too_many(self): + original_iterable = iter('abc') + head, new_iterable = mi.spy(original_iterable, 4) + self.assertEqual(head, ['a', 'b', 'c']) + self.assertEqual(list(new_iterable), ['a', 'b', 'c']) + + def test_zero(self): + original_iterable = iter('abc') + head, new_iterable = mi.spy(original_iterable, 0) + self.assertEqual(head, []) + self.assertEqual(list(new_iterable), ['a', 'b', 'c']) + + +class InterleaveTests(TestCase): + def test_even(self): + actual = list(mi.interleave([1, 4, 7], [2, 5, 8], [3, 6, 9])) + expected = [1, 2, 3, 4, 5, 6, 7, 8, 9] + self.assertEqual(actual, expected) + + def 
test_short(self): + actual = list(mi.interleave([1, 4], [2, 5, 7], [3, 6, 8])) + expected = [1, 2, 3, 4, 5, 6] + self.assertEqual(actual, expected) + + def test_mixed_types(self): + it_list = ['a', 'b', 'c', 'd'] + it_str = '12345' + it_inf = count() + actual = list(mi.interleave(it_list, it_str, it_inf)) + expected = ['a', '1', 0, 'b', '2', 1, 'c', '3', 2, 'd', '4', 3] + self.assertEqual(actual, expected) + + +class InterleaveLongestTests(TestCase): + def test_even(self): + actual = list(mi.interleave_longest([1, 4, 7], [2, 5, 8], [3, 6, 9])) + expected = [1, 2, 3, 4, 5, 6, 7, 8, 9] + self.assertEqual(actual, expected) + + def test_short(self): + actual = list(mi.interleave_longest([1, 4], [2, 5, 7], [3, 6, 8])) + expected = [1, 2, 3, 4, 5, 6, 7, 8] + self.assertEqual(actual, expected) + + def test_mixed_types(self): + it_list = ['a', 'b', 'c', 'd'] + it_str = '12345' + it_gen = (x for x in range(3)) + actual = list(mi.interleave_longest(it_list, it_str, it_gen)) + expected = ['a', '1', 0, 'b', '2', 1, 'c', '3', 2, 'd', '4', '5'] + self.assertEqual(actual, expected) + + +class TestCollapse(TestCase): + """Tests for ``collapse()``""" + + def test_collapse(self): + l = [[1], 2, [[3], 4], [[[5]]]] + self.assertEqual(list(mi.collapse(l)), [1, 2, 3, 4, 5]) + + def test_collapse_to_string(self): + l = [["s1"], "s2", [["s3"], "s4"], [[["s5"]]]] + self.assertEqual(list(mi.collapse(l)), ["s1", "s2", "s3", "s4", "s5"]) + + def test_collapse_flatten(self): + l = [[1], [2], [[3], 4], [[[5]]]] + self.assertEqual(list(mi.collapse(l, levels=1)), list(mi.flatten(l))) + + def test_collapse_to_level(self): + l = [[1], 2, [[3], 4], [[[5]]]] + self.assertEqual(list(mi.collapse(l, levels=2)), [1, 2, 3, 4, [5]]) + self.assertEqual( + list(mi.collapse(mi.collapse(l, levels=1), levels=1)), + list(mi.collapse(l, levels=2)) + ) + + def test_collapse_to_list(self): + l = (1, [2], (3, [4, (5,)], 'ab')) + actual = list(mi.collapse(l, base_type=list)) + expected = [1, [2], 3, [4, (5,)], 'ab'] 
+ self.assertEqual(actual, expected) + + +class SideEffectTests(TestCase): + """Tests for ``side_effect()``""" + + def test_individual(self): + # The function increments the counter for each call + counter = [0] + + def func(arg): + counter[0] += 1 + + result = list(mi.side_effect(func, range(10))) + self.assertEqual(result, list(range(10))) + self.assertEqual(counter[0], 10) + + def test_chunked(self): + # The function increments the counter for each call + counter = [0] + + def func(arg): + counter[0] += 1 + + result = list(mi.side_effect(func, range(10), 2)) + self.assertEqual(result, list(range(10))) + self.assertEqual(counter[0], 5) + + def test_before_after(self): + f = StringIO() + collector = [] + + def func(item): + print(item, file=f) + collector.append(f.getvalue()) + + def it(): + yield u'a' + yield u'b' + raise RuntimeError('kaboom') + + before = lambda: print('HEADER', file=f) + after = f.close + + try: + mi.consume(mi.side_effect(func, it(), before=before, after=after)) + except RuntimeError: + pass + + # The iterable should have been written to the file + self.assertEqual(collector, [u'HEADER\na\n', u'HEADER\na\nb\n']) + + # The file should be closed even though something bad happened + self.assertTrue(f.closed) + + def test_before_fails(self): + f = StringIO() + func = lambda x: print(x, file=f) + + def before(): + raise RuntimeError('ouch') + + try: + mi.consume( + mi.side_effect(func, u'abc', before=before, after=f.close) + ) + except RuntimeError: + pass + + # The file should be closed even though something bad happened in the + # before function + self.assertTrue(f.closed) + + +class SlicedTests(TestCase): + """Tests for ``sliced()``""" + + def test_even(self): + """Test when the length of the sequence is divisible by *n*""" + seq = 'ABCDEFGHI' + self.assertEqual(list(mi.sliced(seq, 3)), ['ABC', 'DEF', 'GHI']) + + def test_odd(self): + """Test when the length of the sequence is not divisible by *n*""" + seq = 'ABCDEFGHI' + 
self.assertEqual(list(mi.sliced(seq, 4)), ['ABCD', 'EFGH', 'I']) + + def test_not_sliceable(self): + seq = (x for x in 'ABCDEFGHI') + + with self.assertRaises(TypeError): + list(mi.sliced(seq, 3)) + + +class SplitAtTests(TestCase): + """Tests for ``split()``""" + + def comp_with_str_split(self, str_to_split, delim): + pred = lambda c: c == delim + actual = list(map(''.join, mi.split_at(str_to_split, pred))) + expected = str_to_split.split(delim) + self.assertEqual(actual, expected) + + def test_seperators(self): + test_strs = ['', 'abcba', 'aaabbbcccddd', 'e'] + for s, delim in product(test_strs, 'abcd'): + self.comp_with_str_split(s, delim) + + +class SplitBeforeTest(TestCase): + """Tests for ``split_before()``""" + + def test_starts_with_sep(self): + actual = list(mi.split_before('xooxoo', lambda c: c == 'x')) + expected = [['x', 'o', 'o'], ['x', 'o', 'o']] + self.assertEqual(actual, expected) + + def test_ends_with_sep(self): + actual = list(mi.split_before('ooxoox', lambda c: c == 'x')) + expected = [['o', 'o'], ['x', 'o', 'o'], ['x']] + self.assertEqual(actual, expected) + + def test_no_sep(self): + actual = list(mi.split_before('ooo', lambda c: c == 'x')) + expected = [['o', 'o', 'o']] + self.assertEqual(actual, expected) + + +class SplitAfterTest(TestCase): + """Tests for ``split_after()``""" + + def test_starts_with_sep(self): + actual = list(mi.split_after('xooxoo', lambda c: c == 'x')) + expected = [['x'], ['o', 'o', 'x'], ['o', 'o']] + self.assertEqual(actual, expected) + + def test_ends_with_sep(self): + actual = list(mi.split_after('ooxoox', lambda c: c == 'x')) + expected = [['o', 'o', 'x'], ['o', 'o', 'x']] + self.assertEqual(actual, expected) + + def test_no_sep(self): + actual = list(mi.split_after('ooo', lambda c: c == 'x')) + expected = [['o', 'o', 'o']] + self.assertEqual(actual, expected) + + +class PaddedTest(TestCase): + """Tests for ``padded()``""" + + def test_no_n(self): + seq = [1, 2, 3] + + # No fillvalue + self.assertEqual(mi.take(5, 
mi.padded(seq)), [1, 2, 3, None, None]) + + # With fillvalue + self.assertEqual( + mi.take(5, mi.padded(seq, fillvalue='')), [1, 2, 3, '', ''] + ) + + def test_invalid_n(self): + self.assertRaises(ValueError, lambda: list(mi.padded([1, 2, 3], n=-1))) + self.assertRaises(ValueError, lambda: list(mi.padded([1, 2, 3], n=0))) + + def test_valid_n(self): + seq = [1, 2, 3, 4, 5] + + # No need for padding: len(seq) <= n + self.assertEqual(list(mi.padded(seq, n=4)), [1, 2, 3, 4, 5]) + self.assertEqual(list(mi.padded(seq, n=5)), [1, 2, 3, 4, 5]) + + # No fillvalue + self.assertEqual( + list(mi.padded(seq, n=7)), [1, 2, 3, 4, 5, None, None] + ) + + # With fillvalue + self.assertEqual( + list(mi.padded(seq, fillvalue='', n=7)), [1, 2, 3, 4, 5, '', ''] + ) + + def test_next_multiple(self): + seq = [1, 2, 3, 4, 5, 6] + + # No need for padding: len(seq) % n == 0 + self.assertEqual( + list(mi.padded(seq, n=3, next_multiple=True)), [1, 2, 3, 4, 5, 6] + ) + + # Padding needed: len(seq) < n + self.assertEqual( + list(mi.padded(seq, n=8, next_multiple=True)), + [1, 2, 3, 4, 5, 6, None, None] + ) + + # No padding needed: len(seq) == n + self.assertEqual( + list(mi.padded(seq, n=6, next_multiple=True)), [1, 2, 3, 4, 5, 6] + ) + + # Padding needed: len(seq) > n + self.assertEqual( + list(mi.padded(seq, n=4, next_multiple=True)), + [1, 2, 3, 4, 5, 6, None, None] + ) + + # With fillvalue + self.assertEqual( + list(mi.padded(seq, fillvalue='', n=4, next_multiple=True)), + [1, 2, 3, 4, 5, 6, '', ''] + ) + + +class DistributeTest(TestCase): + """Tests for distribute()""" + + def test_invalid_n(self): + self.assertRaises(ValueError, lambda: mi.distribute(-1, [1, 2, 3])) + self.assertRaises(ValueError, lambda: mi.distribute(0, [1, 2, 3])) + + def test_basic(self): + iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + for n, expected in [ + (1, [iterable]), + (2, [[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]), + (3, [[1, 4, 7, 10], [2, 5, 8], [3, 6, 9]]), + (10, [[n] for n in range(1, 10 + 1)]), + ]: + 
self.assertEqual( + [list(x) for x in mi.distribute(n, iterable)], expected + ) + + def test_large_n(self): + iterable = [1, 2, 3, 4] + self.assertEqual( + [list(x) for x in mi.distribute(6, iterable)], + [[1], [2], [3], [4], [], []] + ) + + +class StaggerTest(TestCase): + """Tests for ``stagger()``""" + + def test_default(self): + iterable = [0, 1, 2, 3] + actual = list(mi.stagger(iterable)) + expected = [(None, 0, 1), (0, 1, 2), (1, 2, 3)] + self.assertEqual(actual, expected) + + def test_offsets(self): + iterable = [0, 1, 2, 3] + for offsets, expected in [ + ((-2, 0, 2), [('', 0, 2), ('', 1, 3)]), + ((-2, -1), [('', ''), ('', 0), (0, 1), (1, 2), (2, 3)]), + ((1, 2), [(1, 2), (2, 3)]), + ]: + all_groups = mi.stagger(iterable, offsets=offsets, fillvalue='') + self.assertEqual(list(all_groups), expected) + + def test_longest(self): + iterable = [0, 1, 2, 3] + for offsets, expected in [ + ( + (-1, 0, 1), + [('', 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, ''), (3, '', '')] + ), + ((-2, -1), [('', ''), ('', 0), (0, 1), (1, 2), (2, 3), (3, '')]), + ((1, 2), [(1, 2), (2, 3), (3, '')]), + ]: + all_groups = mi.stagger( + iterable, offsets=offsets, fillvalue='', longest=True + ) + self.assertEqual(list(all_groups), expected) + + +class ZipOffsetTest(TestCase): + """Tests for ``zip_offset()``""" + + def test_shortest(self): + a_1 = [0, 1, 2, 3] + a_2 = [0, 1, 2, 3, 4, 5] + a_3 = [0, 1, 2, 3, 4, 5, 6, 7] + actual = list( + mi.zip_offset(a_1, a_2, a_3, offsets=(-1, 0, 1), fillvalue='') + ) + expected = [('', 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5)] + self.assertEqual(actual, expected) + + def test_longest(self): + a_1 = [0, 1, 2, 3] + a_2 = [0, 1, 2, 3, 4, 5] + a_3 = [0, 1, 2, 3, 4, 5, 6, 7] + actual = list( + mi.zip_offset(a_1, a_2, a_3, offsets=(-1, 0, 1), longest=True) + ) + expected = [ + (None, 0, 1), + (0, 1, 2), + (1, 2, 3), + (2, 3, 4), + (3, 4, 5), + (None, 5, 6), + (None, None, 7), + ] + self.assertEqual(actual, expected) + + def test_mismatch(self): + iterables 
= [0, 1, 2], [2, 3, 4] + offsets = (-1, 0, 1) + self.assertRaises( + ValueError, + lambda: list(mi.zip_offset(*iterables, offsets=offsets)) + ) + + +class SortTogetherTest(TestCase): + """Tests for sort_together()""" + + def test_key_list(self): + """tests `key_list` including default, iterables include duplicates""" + iterables = [ + ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'], + ['May', 'Aug.', 'May', 'June', 'July', 'July'], + [97, 20, 100, 70, 100, 20] + ] + + self.assertEqual( + mi.sort_together(iterables), + [ + ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), + ('June', 'July', 'July', 'May', 'Aug.', 'May'), + (70, 100, 20, 97, 20, 100) + ] + ) + + self.assertEqual( + mi.sort_together(iterables, key_list=(0, 1)), + [ + ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), + ('July', 'July', 'June', 'Aug.', 'May', 'May'), + (100, 20, 70, 20, 97, 100) + ] + ) + + self.assertEqual( + mi.sort_together(iterables, key_list=(0, 1, 2)), + [ + ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), + ('July', 'July', 'June', 'Aug.', 'May', 'May'), + (20, 100, 70, 20, 97, 100) + ] + ) + + self.assertEqual( + mi.sort_together(iterables, key_list=(2,)), + [ + ('GA', 'CT', 'CT', 'GA', 'GA', 'CT'), + ('Aug.', 'July', 'June', 'May', 'May', 'July'), + (20, 20, 70, 97, 100, 100) + ] + ) + + def test_invalid_key_list(self): + """tests `key_list` for indexes not available in `iterables`""" + iterables = [ + ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'], + ['May', 'Aug.', 'May', 'June', 'July', 'July'], + [97, 20, 100, 70, 100, 20] + ] + + self.assertRaises( + IndexError, lambda: mi.sort_together(iterables, key_list=(5,)) + ) + + def test_reverse(self): + """tests `reverse` to ensure a reverse sort for `key_list` iterables""" + iterables = [ + ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'], + ['May', 'Aug.', 'May', 'June', 'July', 'July'], + [97, 20, 100, 70, 100, 20] + ] + + self.assertEqual( + mi.sort_together(iterables, key_list=(0, 1, 2), reverse=True), + [('GA', 'GA', 'GA', 'CT', 'CT', 'CT'), + ('May', 'May', 'Aug.', 'June', 'July', 'July'), 
+ (100, 97, 20, 70, 100, 20)] + ) + + def test_uneven_iterables(self): + """tests trimming of iterables to the shortest length before sorting""" + iterables = [['GA', 'GA', 'GA', 'CT', 'CT', 'CT', 'MA'], + ['May', 'Aug.', 'May', 'June', 'July', 'July'], + [97, 20, 100, 70, 100, 20, 0]] + + self.assertEqual( + mi.sort_together(iterables), + [ + ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), + ('June', 'July', 'July', 'May', 'Aug.', 'May'), + (70, 100, 20, 97, 20, 100) + ] + ) + + +class DivideTest(TestCase): + """Tests for divide()""" + + def test_invalid_n(self): + self.assertRaises(ValueError, lambda: mi.divide(-1, [1, 2, 3])) + self.assertRaises(ValueError, lambda: mi.divide(0, [1, 2, 3])) + + def test_basic(self): + iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + for n, expected in [ + (1, [iterable]), + (2, [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), + (3, [[1, 2, 3, 4], [5, 6, 7], [8, 9, 10]]), + (10, [[n] for n in range(1, 10 + 1)]), + ]: + self.assertEqual( + [list(x) for x in mi.divide(n, iterable)], expected + ) + + def test_large_n(self): + iterable = [1, 2, 3, 4] + self.assertEqual( + [list(x) for x in mi.divide(6, iterable)], + [[1], [2], [3], [4], [], []] + ) + + +class TestAlwaysIterable(TestCase): + """Tests for always_iterable()""" + def test_single(self): + self.assertEqual(list(mi.always_iterable(1)), [1]) + + def test_strings(self): + for obj in ['foo', b'bar', u'baz']: + actual = list(mi.always_iterable(obj)) + expected = [obj] + self.assertEqual(actual, expected) + + def test_base_type(self): + dict_obj = {'a': 1, 'b': 2} + str_obj = '123' + + # Default: dicts are iterable like they normally are + default_actual = list(mi.always_iterable(dict_obj)) + default_expected = list(dict_obj) + self.assertEqual(default_actual, default_expected) + + # Unitary types set: dicts are not iterable + custom_actual = list(mi.always_iterable(dict_obj, base_type=dict)) + custom_expected = [dict_obj] + self.assertEqual(custom_actual, custom_expected) + + # With unitary types set, 
strings are iterable + str_actual = list(mi.always_iterable(str_obj, base_type=None)) + str_expected = list(str_obj) + self.assertEqual(str_actual, str_expected) + + def test_iterables(self): + self.assertEqual(list(mi.always_iterable([0, 1])), [0, 1]) + self.assertEqual( + list(mi.always_iterable([0, 1], base_type=list)), [[0, 1]] + ) + self.assertEqual( + list(mi.always_iterable(iter('foo'))), ['f', 'o', 'o'] + ) + self.assertEqual(list(mi.always_iterable([])), []) + + def test_none(self): + self.assertEqual(list(mi.always_iterable(None)), []) + + def test_generator(self): + def _gen(): + yield 0 + yield 1 + + self.assertEqual(list(mi.always_iterable(_gen())), [0, 1]) + + +class AdjacentTests(TestCase): + def test_typical(self): + actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10))) + expected = [(True, 0), (True, 1), (False, 2), (False, 3), (True, 4), + (True, 5), (True, 6), (False, 7), (False, 8), (False, 9)] + self.assertEqual(actual, expected) + + def test_empty_iterable(self): + actual = list(mi.adjacent(lambda x: x % 5 == 0, [])) + expected = [] + self.assertEqual(actual, expected) + + def test_length_one(self): + actual = list(mi.adjacent(lambda x: x % 5 == 0, [0])) + expected = [(True, 0)] + self.assertEqual(actual, expected) + + actual = list(mi.adjacent(lambda x: x % 5 == 0, [1])) + expected = [(False, 1)] + self.assertEqual(actual, expected) + + def test_consecutive_true(self): + """Test that when the predicate matches multiple consecutive elements + it doesn't repeat elements in the output""" + actual = list(mi.adjacent(lambda x: x % 5 < 2, range(10))) + expected = [(True, 0), (True, 1), (True, 2), (False, 3), (True, 4), + (True, 5), (True, 6), (True, 7), (False, 8), (False, 9)] + self.assertEqual(actual, expected) + + def test_distance(self): + actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10), distance=2)) + expected = [(True, 0), (True, 1), (True, 2), (True, 3), (True, 4), + (True, 5), (True, 6), (True, 7), (False, 8), (False, 9)] + 
self.assertEqual(actual, expected) + + actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10), distance=3)) + expected = [(True, 0), (True, 1), (True, 2), (True, 3), (True, 4), + (True, 5), (True, 6), (True, 7), (True, 8), (False, 9)] + self.assertEqual(actual, expected) + + def test_large_distance(self): + """Test distance larger than the length of the iterable""" + iterable = range(10) + actual = list(mi.adjacent(lambda x: x % 5 == 4, iterable, distance=20)) + expected = list(zip(repeat(True), iterable)) + self.assertEqual(actual, expected) + + actual = list(mi.adjacent(lambda x: False, iterable, distance=20)) + expected = list(zip(repeat(False), iterable)) + self.assertEqual(actual, expected) + + def test_zero_distance(self): + """Test that adjacent() reduces to zip+map when distance is 0""" + iterable = range(1000) + predicate = lambda x: x % 4 == 2 + actual = mi.adjacent(predicate, iterable, 0) + expected = zip(map(predicate, iterable), iterable) + self.assertTrue(all(a == e for a, e in zip(actual, expected))) + + def test_negative_distance(self): + """Test that adjacent() raises an error with negative distance""" + pred = lambda x: x + self.assertRaises( + ValueError, lambda: mi.adjacent(pred, range(1000), -1) + ) + self.assertRaises( + ValueError, lambda: mi.adjacent(pred, range(10), -10) + ) + + def test_grouping(self): + """Test interaction of adjacent() with groupby_transform()""" + iterable = mi.adjacent(lambda x: x % 5 == 0, range(10)) + grouper = mi.groupby_transform(iterable, itemgetter(0), itemgetter(1)) + actual = [(k, list(g)) for k, g in grouper] + expected = [ + (True, [0, 1]), + (False, [2, 3]), + (True, [4, 5, 6]), + (False, [7, 8, 9]), + ] + self.assertEqual(actual, expected) + + def test_call_once(self): + """Test that the predicate is only called once per item.""" + already_seen = set() + iterable = range(10) + + def predicate(item): + self.assertNotIn(item, already_seen) + already_seen.add(item) + return True + + actual = 
list(mi.adjacent(predicate, iterable)) + expected = [(True, x) for x in iterable] + self.assertEqual(actual, expected) + + +class GroupByTransformTests(TestCase): + def assertAllGroupsEqual(self, groupby1, groupby2): + """Compare two groupby objects for equality, both keys and groups.""" + for a, b in zip(groupby1, groupby2): + key1, group1 = a + key2, group2 = b + self.assertEqual(key1, key2) + self.assertListEqual(list(group1), list(group2)) + self.assertRaises(StopIteration, lambda: next(groupby1)) + self.assertRaises(StopIteration, lambda: next(groupby2)) + + def test_default_funcs(self): + """Test that groupby_transform() with default args mimics groupby()""" + iterable = [(x // 5, x) for x in range(1000)] + actual = mi.groupby_transform(iterable) + expected = groupby(iterable) + self.assertAllGroupsEqual(actual, expected) + + def test_valuefunc(self): + iterable = [(int(x / 5), int(x / 3), x) for x in range(10)] + + # Test the standard usage of grouping one iterable using another's keys + grouper = mi.groupby_transform( + iterable, keyfunc=itemgetter(0), valuefunc=itemgetter(-1) + ) + actual = [(k, list(g)) for k, g in grouper] + expected = [(0, [0, 1, 2, 3, 4]), (1, [5, 6, 7, 8, 9])] + self.assertEqual(actual, expected) + + grouper = mi.groupby_transform( + iterable, keyfunc=itemgetter(1), valuefunc=itemgetter(-1) + ) + actual = [(k, list(g)) for k, g in grouper] + expected = [(0, [0, 1, 2]), (1, [3, 4, 5]), (2, [6, 7, 8]), (3, [9])] + self.assertEqual(actual, expected) + + # and now for something a little different + d = dict(zip(range(10), 'abcdefghij')) + grouper = mi.groupby_transform( + range(10), keyfunc=lambda x: x // 5, valuefunc=d.get + ) + actual = [(k, ''.join(g)) for k, g in grouper] + expected = [(0, 'abcde'), (1, 'fghij')] + self.assertEqual(actual, expected) + + def test_no_valuefunc(self): + iterable = range(1000) + + def key(x): + return x // 5 + + actual = mi.groupby_transform(iterable, key, valuefunc=None) + expected = groupby(iterable, 
key) + self.assertAllGroupsEqual(actual, expected) + + actual = mi.groupby_transform(iterable, key) # default valuefunc + expected = groupby(iterable, key) + self.assertAllGroupsEqual(actual, expected) + + +class NumericRangeTests(TestCase): + def test_basic(self): + for args, expected in [ + ((4,), [0, 1, 2, 3]), + ((4.0,), [0.0, 1.0, 2.0, 3.0]), + ((1.0, 4), [1.0, 2.0, 3.0]), + ((1, 4.0), [1, 2, 3]), + ((1.0, 5), [1.0, 2.0, 3.0, 4.0]), + ((0, 20, 5), [0, 5, 10, 15]), + ((0, 20, 5.0), [0.0, 5.0, 10.0, 15.0]), + ((0, 10, 3), [0, 3, 6, 9]), + ((0, 10, 3.0), [0.0, 3.0, 6.0, 9.0]), + ((0, -5, -1), [0, -1, -2, -3, -4]), + ((0.0, -5, -1), [0.0, -1.0, -2.0, -3.0, -4.0]), + ((1, 2, Fraction(1, 2)), [Fraction(1, 1), Fraction(3, 2)]), + ((0,), []), + ((0.0,), []), + ((1, 0), []), + ((1.0, 0.0), []), + ((Fraction(2, 1),), [Fraction(0, 1), Fraction(1, 1)]), + ((Decimal('2.0'),), [Decimal('0.0'), Decimal('1.0')]), + ]: + actual = list(mi.numeric_range(*args)) + self.assertEqual(actual, expected) + self.assertTrue( + all(type(a) == type(e) for a, e in zip(actual, expected)) + ) + + def test_arg_count(self): + self.assertRaises(TypeError, lambda: list(mi.numeric_range())) + self.assertRaises( + TypeError, lambda: list(mi.numeric_range(0, 1, 2, 3)) + ) + + def test_zero_step(self): + self.assertRaises( + ValueError, lambda: list(mi.numeric_range(1, 2, 0)) + ) + + +class CountCycleTests(TestCase): + def test_basic(self): + expected = [ + (0, 'a'), (0, 'b'), (0, 'c'), + (1, 'a'), (1, 'b'), (1, 'c'), + (2, 'a'), (2, 'b'), (2, 'c'), + ] + for actual in [ + mi.take(9, mi.count_cycle('abc')), # n=None + list(mi.count_cycle('abc', 3)), # n=3 + ]: + self.assertEqual(actual, expected) + + def test_empty(self): + self.assertEqual(list(mi.count_cycle('')), []) + self.assertEqual(list(mi.count_cycle('', 2)), []) + + def test_negative(self): + self.assertEqual(list(mi.count_cycle('abc', -3)), []) + + +class LocateTests(TestCase): + def test_default_pred(self): + iterable = [0, 1, 1, 0, 1, 0, 
0] + actual = list(mi.locate(iterable)) + expected = [1, 2, 4] + self.assertEqual(actual, expected) + + def test_no_matches(self): + iterable = [0, 0, 0] + actual = list(mi.locate(iterable)) + expected = [] + self.assertEqual(actual, expected) + + def test_custom_pred(self): + iterable = ['0', 1, 1, '0', 1, '0', '0'] + pred = lambda x: x == '0' + actual = list(mi.locate(iterable, pred)) + expected = [0, 3, 5, 6] + self.assertEqual(actual, expected) + + def test_window_size(self): + iterable = ['0', 1, 1, '0', 1, '0', '0'] + pred = lambda *args: args == ('0', 1) + actual = list(mi.locate(iterable, pred, window_size=2)) + expected = [0, 3] + self.assertEqual(actual, expected) + + def test_window_size_large(self): + iterable = [1, 2, 3, 4] + pred = lambda a, b, c, d, e: True + actual = list(mi.locate(iterable, pred, window_size=5)) + expected = [0] + self.assertEqual(actual, expected) + + def test_window_size_zero(self): + iterable = [1, 2, 3, 4] + pred = lambda: True + with self.assertRaises(ValueError): + list(mi.locate(iterable, pred, window_size=0)) + + +class StripFunctionTests(TestCase): + def test_hashable(self): + iterable = list('www.example.com') + pred = lambda x: x in set('cmowz.') + + self.assertEqual(list(mi.lstrip(iterable, pred)), list('example.com')) + self.assertEqual(list(mi.rstrip(iterable, pred)), list('www.example')) + self.assertEqual(list(mi.strip(iterable, pred)), list('example')) + + def test_not_hashable(self): + iterable = [ + list('http://'), list('www'), list('.example'), list('.com') + ] + pred = lambda x: x in [list('http://'), list('www'), list('.com')] + + self.assertEqual(list(mi.lstrip(iterable, pred)), iterable[2:]) + self.assertEqual(list(mi.rstrip(iterable, pred)), iterable[:3]) + self.assertEqual(list(mi.strip(iterable, pred)), iterable[2: 3]) + + def test_math(self): + iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] + pred = lambda x: x <= 2 + + self.assertEqual(list(mi.lstrip(iterable, pred)), iterable[3:]) + 
self.assertEqual(list(mi.rstrip(iterable, pred)), iterable[:-3]) + self.assertEqual(list(mi.strip(iterable, pred)), iterable[3:-3]) + + +class IsliceExtendedTests(TestCase): + def test_all(self): + iterable = ['0', '1', '2', '3', '4', '5'] + indexes = list(range(-4, len(iterable) + 4)) + [None] + steps = [1, 2, 3, 4, -1, -2, -3, 4] + for slice_args in product(indexes, indexes, steps): + try: + actual = list(mi.islice_extended(iterable, *slice_args)) + except Exception as e: + self.fail((slice_args, e)) + + expected = iterable[slice(*slice_args)] + self.assertEqual(actual, expected, slice_args) + + def test_zero_step(self): + with self.assertRaises(ValueError): + list(mi.islice_extended([1, 2, 3], 0, 1, 0)) + + +class ConsecutiveGroupsTest(TestCase): + def test_numbers(self): + iterable = [-10, -8, -7, -6, 1, 2, 4, 5, -1, 7] + actual = [list(g) for g in mi.consecutive_groups(iterable)] + expected = [[-10], [-8, -7, -6], [1, 2], [4, 5], [-1], [7]] + self.assertEqual(actual, expected) + + def test_custom_ordering(self): + iterable = ['1', '10', '11', '20', '21', '22', '30', '31'] + ordering = lambda x: int(x) + actual = [list(g) for g in mi.consecutive_groups(iterable, ordering)] + expected = [['1'], ['10', '11'], ['20', '21', '22'], ['30', '31']] + self.assertEqual(actual, expected) + + def test_exotic_ordering(self): + iterable = [ + ('a', 'b', 'c', 'd'), + ('a', 'c', 'b', 'd'), + ('a', 'c', 'd', 'b'), + ('a', 'd', 'b', 'c'), + ('d', 'b', 'c', 'a'), + ('d', 'c', 'a', 'b'), + ] + ordering = list(permutations('abcd')).index + actual = [list(g) for g in mi.consecutive_groups(iterable, ordering)] + expected = [ + [('a', 'b', 'c', 'd')], + [('a', 'c', 'b', 'd'), ('a', 'c', 'd', 'b'), ('a', 'd', 'b', 'c')], + [('d', 'b', 'c', 'a'), ('d', 'c', 'a', 'b')], + ] + self.assertEqual(actual, expected) + + +class DifferenceTest(TestCase): + def test_normal(self): + iterable = [10, 20, 30, 40, 50] + actual = list(mi.difference(iterable)) + expected = [10, 10, 10, 10, 10] + 
self.assertEqual(actual, expected) + + def test_custom(self): + iterable = [10, 20, 30, 40, 50] + actual = list(mi.difference(iterable, add)) + expected = [10, 30, 50, 70, 90] + self.assertEqual(actual, expected) + + def test_roundtrip(self): + original = list(range(100)) + accumulated = mi.accumulate(original) + actual = list(mi.difference(accumulated)) + self.assertEqual(actual, original) + + def test_one(self): + self.assertEqual(list(mi.difference([0])), [0]) + + def test_empty(self): + self.assertEqual(list(mi.difference([])), []) + + +class SeekableTest(TestCase): + def test_exhaustion_reset(self): + iterable = [str(n) for n in range(10)] + + s = mi.seekable(iterable) + self.assertEqual(list(s), iterable) # Normal iteration + self.assertEqual(list(s), []) # Iterable is exhausted + + s.seek(0) + self.assertEqual(list(s), iterable) # Back in action + + def test_partial_reset(self): + iterable = [str(n) for n in range(10)] + + s = mi.seekable(iterable) + self.assertEqual(mi.take(5, s), iterable[:5]) # Normal iteration + + s.seek(1) + self.assertEqual(list(s), iterable[1:]) # Get the rest of the iterable + + def test_forward(self): + iterable = [str(n) for n in range(10)] + + s = mi.seekable(iterable) + self.assertEqual(mi.take(1, s), iterable[:1]) # Normal iteration + + s.seek(3) # Skip over index 2 + self.assertEqual(list(s), iterable[3:]) # Result is similar to slicing + + s.seek(0) # Back to 0 + self.assertEqual(list(s), iterable) # No difference in result + + def test_past_end(self): + iterable = [str(n) for n in range(10)] + + s = mi.seekable(iterable) + self.assertEqual(mi.take(1, s), iterable[:1]) # Normal iteration + + s.seek(20) + self.assertEqual(list(s), []) # Iterable is exhausted + + s.seek(0) # Back to 0 + self.assertEqual(list(s), iterable) # No difference in result + + def test_elements(self): + iterable = map(str, count()) + + s = mi.seekable(iterable) + mi.take(10, s) + + elements = s.elements() + self.assertEqual( + [elements[i] for i in 
range(10)], [str(n) for n in range(10)] + ) + self.assertEqual(len(elements), 10) + + mi.take(10, s) + self.assertEqual(list(elements), [str(n) for n in range(20)]) + + +class SequenceViewTests(TestCase): + def test_init(self): + view = mi.SequenceView((1, 2, 3)) + self.assertEqual(repr(view), "SequenceView((1, 2, 3))") + self.assertRaises(TypeError, lambda: mi.SequenceView({})) + + def test_update(self): + seq = [1, 2, 3] + view = mi.SequenceView(seq) + self.assertEqual(len(view), 3) + self.assertEqual(repr(view), "SequenceView([1, 2, 3])") + + seq.pop() + self.assertEqual(len(view), 2) + self.assertEqual(repr(view), "SequenceView([1, 2])") + + def test_indexing(self): + seq = ('a', 'b', 'c', 'd', 'e', 'f') + view = mi.SequenceView(seq) + for i in range(-len(seq), len(seq)): + self.assertEqual(view[i], seq[i]) + + def test_slicing(self): + seq = ('a', 'b', 'c', 'd', 'e', 'f') + view = mi.SequenceView(seq) + n = len(seq) + indexes = list(range(-n - 1, n + 1)) + [None] + steps = list(range(-n, n + 1)) + steps.remove(0) + for slice_args in product(indexes, indexes, steps): + i = slice(*slice_args) + self.assertEqual(view[i], seq[i]) + + def test_abc_methods(self): + # collections.Sequence should provide all of this functionality + seq = ('a', 'b', 'c', 'd', 'e', 'f', 'f') + view = mi.SequenceView(seq) + + # __contains__ + self.assertIn('b', view) + self.assertNotIn('g', view) + + # __iter__ + self.assertEqual(list(iter(view)), list(seq)) + + # __reversed__ + self.assertEqual(list(reversed(view)), list(reversed(seq))) + + # index + self.assertEqual(view.index('b'), 1) + + # count + self.assertEqual(seq.count('f'), 2) + + +class RunLengthTest(TestCase): + def test_encode(self): + iterable = (int(str(n)[0]) for n in count(800)) + actual = mi.take(4, mi.run_length.encode(iterable)) + expected = [(8, 100), (9, 100), (1, 1000), (2, 1000)] + self.assertEqual(actual, expected) + + def test_decode(self): + iterable = [('d', 4), ('c', 3), ('b', 2), ('a', 1)] + actual = 
''.join(mi.run_length.decode(iterable)) + expected = 'ddddcccbba' + self.assertEqual(actual, expected) + + +class ExactlyNTests(TestCase): + """Tests for ``exactly_n()``""" + + def test_true(self): + """Iterable has ``n`` ``True`` elements""" + self.assertTrue(mi.exactly_n([True, False, True], 2)) + self.assertTrue(mi.exactly_n([1, 1, 1, 0], 3)) + self.assertTrue(mi.exactly_n([False, False], 0)) + self.assertTrue(mi.exactly_n(range(100), 10, lambda x: x < 10)) + + def test_false(self): + """Iterable does not have ``n`` ``True`` elements""" + self.assertFalse(mi.exactly_n([True, False, False], 2)) + self.assertFalse(mi.exactly_n([True, True, False], 1)) + self.assertFalse(mi.exactly_n([False], 1)) + self.assertFalse(mi.exactly_n([True], -1)) + self.assertFalse(mi.exactly_n(repeat(True), 100)) + + def test_empty(self): + """Return ``True`` if the iterable is empty and ``n`` is 0""" + self.assertTrue(mi.exactly_n([], 0)) + self.assertFalse(mi.exactly_n([], 1)) + + +class AlwaysReversibleTests(TestCase): + """Tests for ``always_reversible()``""" + + def test_regular_reversed(self): + self.assertEqual(list(reversed(range(10))), + list(mi.always_reversible(range(10)))) + self.assertEqual(list(reversed([1, 2, 3])), + list(mi.always_reversible([1, 2, 3]))) + self.assertEqual(reversed([1, 2, 3]).__class__, + mi.always_reversible([1, 2, 3]).__class__) + + def test_nonseq_reversed(self): + # Create a non-reversible generator from a sequence + with self.assertRaises(TypeError): + reversed(x for x in range(10)) + + self.assertEqual(list(reversed(range(10))), + list(mi.always_reversible(x for x in range(10)))) + self.assertEqual(list(reversed([1, 2, 3])), + list(mi.always_reversible(x for x in [1, 2, 3]))) + self.assertNotEqual(reversed((1, 2)).__class__, + mi.always_reversible(x for x in (1, 2)).__class__) + + +class CircularShiftsTests(TestCase): + def test_empty(self): + # empty iterable -> empty list + self.assertEqual(list(mi.circular_shifts([])), []) + + def 
test_simple_circular_shifts(self): + # test the a simple iterator case + self.assertEqual( + mi.circular_shifts(range(4)), + [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)] + ) + + def test_duplicates(self): + # test non-distinct entries + self.assertEqual( + mi.circular_shifts([0, 1, 0, 1]), + [(0, 1, 0, 1), (1, 0, 1, 0), (0, 1, 0, 1), (1, 0, 1, 0)] + ) + + +class MakeDecoratorTests(TestCase): + def test_basic(self): + slicer = mi.make_decorator(islice) + + @slicer(1, 10, 2) + def user_function(arg_1, arg_2, kwarg_1=None): + self.assertEqual(arg_1, 'arg_1') + self.assertEqual(arg_2, 'arg_2') + self.assertEqual(kwarg_1, 'kwarg_1') + return map(str, count()) + + it = user_function('arg_1', 'arg_2', kwarg_1='kwarg_1') + actual = list(it) + expected = ['1', '3', '5', '7', '9'] + self.assertEqual(actual, expected) + + def test_result_index(self): + def stringify(*args, **kwargs): + self.assertEqual(args[0], 'arg_0') + iterable = args[1] + self.assertEqual(args[2], 'arg_2') + self.assertEqual(kwargs['kwarg_1'], 'kwarg_1') + return map(str, iterable) + + stringifier = mi.make_decorator(stringify, result_index=1) + + @stringifier('arg_0', 'arg_2', kwarg_1='kwarg_1') + def user_function(n): + return count(n) + + it = user_function(1) + actual = mi.take(5, it) + expected = ['1', '2', '3', '4', '5'] + self.assertEqual(actual, expected) + + def test_wrap_class(self): + seeker = mi.make_decorator(mi.seekable) + + @seeker() + def user_function(n): + return map(str, range(n)) + + it = user_function(5) + self.assertEqual(list(it), ['0', '1', '2', '3', '4']) + + it.seek(0) + self.assertEqual(list(it), ['0', '1', '2', '3', '4']) + + +class MapReduceTests(TestCase): + def test_default(self): + iterable = (str(x) for x in range(5)) + keyfunc = lambda x: int(x) // 2 + actual = sorted(mi.map_reduce(iterable, keyfunc).items()) + expected = [(0, ['0', '1']), (1, ['2', '3']), (2, ['4'])] + self.assertEqual(actual, expected) + + def test_valuefunc(self): + iterable = (str(x) for x 
in range(5)) + keyfunc = lambda x: int(x) // 2 + valuefunc = int + actual = sorted(mi.map_reduce(iterable, keyfunc, valuefunc).items()) + expected = [(0, [0, 1]), (1, [2, 3]), (2, [4])] + self.assertEqual(actual, expected) + + def test_reducefunc(self): + iterable = (str(x) for x in range(5)) + keyfunc = lambda x: int(x) // 2 + valuefunc = int + reducefunc = lambda value_list: reduce(mul, value_list, 1) + actual = sorted( + mi.map_reduce(iterable, keyfunc, valuefunc, reducefunc).items() + ) + expected = [(0, 0), (1, 6), (2, 4)] + self.assertEqual(actual, expected) + + def test_ret(self): + d = mi.map_reduce([1, 0, 2, 0, 1, 0], bool) + self.assertEqual(d, {False: [0, 0, 0], True: [1, 2, 1]}) + self.assertRaises(KeyError, lambda: d[None].append(1)) + + +class RlocateTests(TestCase): + def test_default_pred(self): + iterable = [0, 1, 1, 0, 1, 0, 0] + for it in (iterable[:], iter(iterable)): + actual = list(mi.rlocate(it)) + expected = [4, 2, 1] + self.assertEqual(actual, expected) + + def test_no_matches(self): + iterable = [0, 0, 0] + for it in (iterable[:], iter(iterable)): + actual = list(mi.rlocate(it)) + expected = [] + self.assertEqual(actual, expected) + + def test_custom_pred(self): + iterable = ['0', 1, 1, '0', 1, '0', '0'] + pred = lambda x: x == '0' + for it in (iterable[:], iter(iterable)): + actual = list(mi.rlocate(it, pred)) + expected = [6, 5, 3, 0] + self.assertEqual(actual, expected) + + def test_efficient_reversal(self): + iterable = range(10 ** 10) # Is efficiently reversible + target = 10 ** 10 - 2 + pred = lambda x: x == target # Find-able from the right + actual = next(mi.rlocate(iterable, pred)) + self.assertEqual(actual, target) + + def test_window_size(self): + iterable = ['0', 1, 1, '0', 1, '0', '0'] + pred = lambda *args: args == ('0', 1) + for it in (iterable, iter(iterable)): + actual = list(mi.rlocate(it, pred, window_size=2)) + expected = [3, 0] + self.assertEqual(actual, expected) + + def test_window_size_large(self): + iterable = [1, 
2, 3, 4] + pred = lambda a, b, c, d, e: True + for it in (iterable, iter(iterable)): + actual = list(mi.rlocate(iterable, pred, window_size=5)) + expected = [0] + self.assertEqual(actual, expected) + + def test_window_size_zero(self): + iterable = [1, 2, 3, 4] + pred = lambda: True + for it in (iterable, iter(iterable)): + with self.assertRaises(ValueError): + list(mi.locate(iterable, pred, window_size=0)) + + +class ReplaceTests(TestCase): + def test_basic(self): + iterable = range(10) + pred = lambda x: x % 2 == 0 + substitutes = [] + actual = list(mi.replace(iterable, pred, substitutes)) + expected = [1, 3, 5, 7, 9] + self.assertEqual(actual, expected) + + def test_count(self): + iterable = range(10) + pred = lambda x: x % 2 == 0 + substitutes = [] + actual = list(mi.replace(iterable, pred, substitutes, count=4)) + expected = [1, 3, 5, 7, 8, 9] + self.assertEqual(actual, expected) + + def test_window_size(self): + iterable = range(10) + pred = lambda *args: args == (0, 1, 2) + substitutes = [] + actual = list(mi.replace(iterable, pred, substitutes, window_size=3)) + expected = [3, 4, 5, 6, 7, 8, 9] + self.assertEqual(actual, expected) + + def test_window_size_end(self): + iterable = range(10) + pred = lambda *args: args == (7, 8, 9) + substitutes = [] + actual = list(mi.replace(iterable, pred, substitutes, window_size=3)) + expected = [0, 1, 2, 3, 4, 5, 6] + self.assertEqual(actual, expected) + + def test_window_size_count(self): + iterable = range(10) + pred = lambda *args: (args == (0, 1, 2)) or (args == (7, 8, 9)) + substitutes = [] + actual = list( + mi.replace(iterable, pred, substitutes, count=1, window_size=3) + ) + expected = [3, 4, 5, 6, 7, 8, 9] + self.assertEqual(actual, expected) + + def test_window_size_large(self): + iterable = range(4) + pred = lambda a, b, c, d, e: True + substitutes = [5, 6, 7] + actual = list(mi.replace(iterable, pred, substitutes, window_size=5)) + expected = [5, 6, 7] + self.assertEqual(actual, expected) + + def 
test_window_size_zero(self): + iterable = range(10) + pred = lambda *args: True + substitutes = [] + with self.assertRaises(ValueError): + list(mi.replace(iterable, pred, substitutes, window_size=0)) + + def test_iterable_substitutes(self): + iterable = range(5) + pred = lambda x: x % 2 == 0 + substitutes = iter('__') + actual = list(mi.replace(iterable, pred, substitutes)) + expected = ['_', '_', 1, '_', '_', 3, '_', '_'] + self.assertEqual(actual, expected) diff --git a/libs/more_itertools/tests/test_recipes.py b/libs/more_itertools/tests/test_recipes.py index 485d9d30..98981fe8 100644 --- a/libs/more_itertools/tests/test_recipes.py +++ b/libs/more_itertools/tests/test_recipes.py @@ -1,13 +1,39 @@ -from random import seed +from doctest import DocTestSuite from unittest import TestCase -from nose.tools import eq_, assert_raises, ok_ +from itertools import combinations +from six.moves import range -from more_itertools import * +import more_itertools as mi -def setup_module(): - seed(1337) +def load_tests(loader, tests, ignore): + # Add the doctests + tests.addTests(DocTestSuite('more_itertools.recipes')) + return tests + + +class AccumulateTests(TestCase): + """Tests for ``accumulate()``""" + + def test_empty(self): + """Test that an empty input returns an empty output""" + self.assertEqual(list(mi.accumulate([])), []) + + def test_default(self): + """Test accumulate with the default function (addition)""" + self.assertEqual(list(mi.accumulate([1, 2, 3])), [1, 3, 6]) + + def test_bogus_function(self): + """Test accumulate with an invalid function""" + with self.assertRaises(TypeError): + list(mi.accumulate([1, 2, 3], func=lambda x: x)) + + def test_custom_function(self): + """Test accumulate with a custom function""" + self.assertEqual( + list(mi.accumulate((1, 2, 3, 2, 1), func=max)), [1, 2, 3, 3, 3] + ) class TakeTests(TestCase): @@ -15,25 +41,25 @@ class TakeTests(TestCase): def test_simple_take(self): """Test basic usage""" - t = take(5, xrange(10)) - eq_(t, 
[0, 1, 2, 3, 4]) + t = mi.take(5, range(10)) + self.assertEqual(t, [0, 1, 2, 3, 4]) def test_null_take(self): """Check the null case""" - t = take(0, xrange(10)) - eq_(t, []) + t = mi.take(0, range(10)) + self.assertEqual(t, []) def test_negative_take(self): """Make sure taking negative items results in a ValueError""" - assert_raises(ValueError, take, -3, xrange(10)) + self.assertRaises(ValueError, lambda: mi.take(-3, range(10))) def test_take_too_much(self): """Taking more than an iterator has remaining should return what the iterator has remaining. """ - t = take(10, xrange(5)) - eq_(t, [0, 1, 2, 3, 4]) + t = mi.take(10, range(5)) + self.assertEqual(t, [0, 1, 2, 3, 4]) class TabulateTests(TestCase): @@ -41,15 +67,35 @@ class TabulateTests(TestCase): def test_simple_tabulate(self): """Test the happy path""" - t = tabulate(lambda x: x) + t = mi.tabulate(lambda x: x) f = tuple([next(t) for _ in range(3)]) - eq_(f, (0, 1, 2)) + self.assertEqual(f, (0, 1, 2)) def test_count(self): """Ensure tabulate accepts specific count""" - t = tabulate(lambda x: 2 * x, -1) + t = mi.tabulate(lambda x: 2 * x, -1) f = (next(t), next(t), next(t)) - eq_(f, (-2, 0, 2)) + self.assertEqual(f, (-2, 0, 2)) + + +class TailTests(TestCase): + """Tests for ``tail()``""" + + def test_greater(self): + """Length of iterable is greather than requested tail""" + self.assertEqual(list(mi.tail(3, 'ABCDEFG')), ['E', 'F', 'G']) + + def test_equal(self): + """Length of iterable is equal to the requested tail""" + self.assertEqual( + list(mi.tail(7, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G'] + ) + + def test_less(self): + """Length of iterable is less than requested tail""" + self.assertEqual( + list(mi.tail(8, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G'] + ) class ConsumeTests(TestCase): @@ -58,25 +104,25 @@ class ConsumeTests(TestCase): def test_sanity(self): """Test basic functionality""" r = (x for x in range(10)) - consume(r, 3) - eq_(3, next(r)) + mi.consume(r, 3) + self.assertEqual(3, 
next(r)) def test_null_consume(self): """Check the null case""" r = (x for x in range(10)) - consume(r, 0) - eq_(0, next(r)) + mi.consume(r, 0) + self.assertEqual(0, next(r)) def test_negative_consume(self): """Check that negative consumsion throws an error""" r = (x for x in range(10)) - assert_raises(ValueError, consume, r, -1) + self.assertRaises(ValueError, lambda: mi.consume(r, -1)) def test_total_consume(self): """Check that iterator is totally consumed by default""" r = (x for x in range(10)) - consume(r) - assert_raises(StopIteration, next, r) + mi.consume(r) + self.assertRaises(StopIteration, lambda: next(r)) class NthTests(TestCase): @@ -86,16 +132,45 @@ class NthTests(TestCase): """Make sure the nth item is returned""" l = range(10) for i, v in enumerate(l): - eq_(nth(l, i), v) + self.assertEqual(mi.nth(l, i), v) def test_default(self): """Ensure a default value is returned when nth item not found""" l = range(3) - eq_(nth(l, 100, "zebra"), "zebra") + self.assertEqual(mi.nth(l, 100, "zebra"), "zebra") def test_negative_item_raises(self): """Ensure asking for a negative item raises an exception""" - assert_raises(ValueError, nth, range(10), -3) + self.assertRaises(ValueError, lambda: mi.nth(range(10), -3)) + + +class AllEqualTests(TestCase): + """Tests for ``all_equal()``""" + + def test_true(self): + """Everything is equal""" + self.assertTrue(mi.all_equal('aaaaaa')) + self.assertTrue(mi.all_equal([0, 0, 0, 0])) + + def test_false(self): + """Not everything is equal""" + self.assertFalse(mi.all_equal('aaaaab')) + self.assertFalse(mi.all_equal([0, 0, 0, 1])) + + def test_tricky(self): + """Not everything is identical, but everything is equal""" + items = [1, complex(1, 0), 1.0] + self.assertTrue(mi.all_equal(items)) + + def test_empty(self): + """Return True if the iterable is empty""" + self.assertTrue(mi.all_equal('')) + self.assertTrue(mi.all_equal([])) + + def test_one(self): + """Return True if the iterable is singular""" + 
self.assertTrue(mi.all_equal('0')) + self.assertTrue(mi.all_equal([0])) class QuantifyTests(TestCase): @@ -104,12 +179,12 @@ class QuantifyTests(TestCase): def test_happy_path(self): """Make sure True count is returned""" q = [True, False, True] - eq_(quantify(q), 2) + self.assertEqual(mi.quantify(q), 2) def test_custom_predicate(self): """Ensure non-default predicates return as expected""" q = range(10) - eq_(quantify(q, lambda x: x % 2 == 0), 5) + self.assertEqual(mi.quantify(q, lambda x: x % 2 == 0), 5) class PadnoneTests(TestCase): @@ -118,8 +193,8 @@ class PadnoneTests(TestCase): def test_happy_path(self): """wrapper iterator should return None indefinitely""" r = range(2) - p = padnone(r) - eq_([0, 1, None, None], [next(p) for _ in range(4)]) + p = mi.padnone(r) + self.assertEqual([0, 1, None, None], [next(p) for _ in range(4)]) class NcyclesTests(TestCase): @@ -128,19 +203,21 @@ class NcyclesTests(TestCase): def test_happy_path(self): """cycle a sequence three times""" r = ["a", "b", "c"] - n = ncycles(r, 3) - eq_(["a", "b", "c", "a", "b", "c", "a", "b", "c"], - list(n)) + n = mi.ncycles(r, 3) + self.assertEqual( + ["a", "b", "c", "a", "b", "c", "a", "b", "c"], + list(n) + ) def test_null_case(self): """asking for 0 cycles should return an empty iterator""" - n = ncycles(range(100), 0) - assert_raises(StopIteration, next, n) + n = mi.ncycles(range(100), 0) + self.assertRaises(StopIteration, lambda: next(n)) def test_pathalogical_case(self): """asking for negative cycles should return an empty iterator""" - n = ncycles(range(100), -10) - assert_raises(StopIteration, next, n) + n = mi.ncycles(range(100), -10) + self.assertRaises(StopIteration, lambda: next(n)) class DotproductTests(TestCase): @@ -148,7 +225,7 @@ class DotproductTests(TestCase): def test_happy_path(self): """simple dotproduct example""" - eq_(400, dotproduct([10, 10], [20, 20])) + self.assertEqual(400, mi.dotproduct([10, 10], [20, 20])) class FlattenTests(TestCase): @@ -157,12 +234,12 @@ class 
FlattenTests(TestCase): def test_basic_usage(self): """ensure list of lists is flattened one level""" f = [[0, 1, 2], [3, 4, 5]] - eq_(range(6), list(flatten(f))) + self.assertEqual(list(range(6)), list(mi.flatten(f))) def test_single_level(self): """ensure list of lists is flattened only one level""" f = [[0, [1, 2]], [[3, 4], 5]] - eq_([0, [1, 2], [3, 4], 5], list(flatten(f))) + self.assertEqual([0, [1, 2], [3, 4], 5], list(mi.flatten(f))) class RepeatfuncTests(TestCase): @@ -170,23 +247,23 @@ class RepeatfuncTests(TestCase): def test_simple_repeat(self): """test simple repeated functions""" - r = repeatfunc(lambda: 5) - eq_([5, 5, 5, 5, 5], [next(r) for _ in range(5)]) + r = mi.repeatfunc(lambda: 5) + self.assertEqual([5, 5, 5, 5, 5], [next(r) for _ in range(5)]) def test_finite_repeat(self): """ensure limited repeat when times is provided""" - r = repeatfunc(lambda: 5, times=5) - eq_([5, 5, 5, 5, 5], list(r)) + r = mi.repeatfunc(lambda: 5, times=5) + self.assertEqual([5, 5, 5, 5, 5], list(r)) def test_added_arguments(self): """ensure arguments are applied to the function""" - r = repeatfunc(lambda x: x, 2, 3) - eq_([3, 3], list(r)) + r = mi.repeatfunc(lambda x: x, 2, 3) + self.assertEqual([3, 3], list(r)) def test_null_times(self): """repeat 0 should return an empty iterator""" - r = repeatfunc(range, 0, 3) - assert_raises(StopIteration, next, r) + r = mi.repeatfunc(range, 0, 3) + self.assertRaises(StopIteration, lambda: next(r)) class PairwiseTests(TestCase): @@ -194,13 +271,13 @@ class PairwiseTests(TestCase): def test_base_case(self): """ensure an iterable will return pairwise""" - p = pairwise([1, 2, 3]) - eq_([(1, 2), (2, 3)], list(p)) + p = mi.pairwise([1, 2, 3]) + self.assertEqual([(1, 2), (2, 3)], list(p)) def test_short_case(self): """ensure an empty iterator if there's not enough values to pair""" - p = pairwise("a") - assert_raises(StopIteration, next, p) + p = mi.pairwise("a") + self.assertRaises(StopIteration, lambda: next(p)) class 
GrouperTests(TestCase): @@ -211,18 +288,25 @@ class GrouperTests(TestCase): the iterable. """ - eq_(list(grouper(3, 'ABCDEF')), [('A', 'B', 'C'), ('D', 'E', 'F')]) + self.assertEqual( + list(mi.grouper(3, 'ABCDEF')), [('A', 'B', 'C'), ('D', 'E', 'F')] + ) def test_odd(self): """Test when group size does not divide evenly into the length of the iterable. """ - eq_(list(grouper(3, 'ABCDE')), [('A', 'B', 'C'), ('D', 'E', None)]) + self.assertEqual( + list(mi.grouper(3, 'ABCDE')), [('A', 'B', 'C'), ('D', 'E', None)] + ) def test_fill_value(self): """Test that the fill value is used to pad the final group""" - eq_(list(grouper(3, 'ABCDE', 'x')), [('A', 'B', 'C'), ('D', 'E', 'x')]) + self.assertEqual( + list(mi.grouper(3, 'ABCDE', 'x')), + [('A', 'B', 'C'), ('D', 'E', 'x')] + ) class RoundrobinTests(TestCase): @@ -230,13 +314,33 @@ class RoundrobinTests(TestCase): def test_even_groups(self): """Ensure ordered output from evenly populated iterables""" - eq_(list(roundrobin('ABC', [1, 2, 3], range(3))), - ['A', 1, 0, 'B', 2, 1, 'C', 3, 2]) + self.assertEqual( + list(mi.roundrobin('ABC', [1, 2, 3], range(3))), + ['A', 1, 0, 'B', 2, 1, 'C', 3, 2] + ) def test_uneven_groups(self): """Ensure ordered output from unevenly populated iterables""" - eq_(list(roundrobin('ABCD', [1, 2], range(0))), - ['A', 1, 'B', 2, 'C', 'D']) + self.assertEqual( + list(mi.roundrobin('ABCD', [1, 2], range(0))), + ['A', 1, 'B', 2, 'C', 'D'] + ) + + +class PartitionTests(TestCase): + """Tests for ``partition()``""" + + def test_bool(self): + """Test when pred() returns a boolean""" + lesser, greater = mi.partition(lambda x: x > 5, range(10)) + self.assertEqual(list(lesser), [0, 1, 2, 3, 4, 5]) + self.assertEqual(list(greater), [6, 7, 8, 9]) + + def test_arbitrary(self): + """Test when pred() returns an integer""" + divisibles, remainders = mi.partition(lambda x: x % 3, range(10)) + self.assertEqual(list(divisibles), [0, 3, 6, 9]) + self.assertEqual(list(remainders), [1, 2, 4, 5, 7, 8]) class 
PowersetTests(TestCase): @@ -244,9 +348,11 @@ class PowersetTests(TestCase): def test_combinatorics(self): """Ensure a proper enumeration""" - p = powerset([1, 2, 3]) - eq_(list(p), - [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]) + p = mi.powerset([1, 2, 3]) + self.assertEqual( + list(p), + [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)] + ) class UniqueEverseenTests(TestCase): @@ -254,14 +360,28 @@ class UniqueEverseenTests(TestCase): def test_everseen(self): """ensure duplicate elements are ignored""" - u = unique_everseen('AAAABBBBCCDAABBB') - eq_(['A', 'B', 'C', 'D'], - list(u)) + u = mi.unique_everseen('AAAABBBBCCDAABBB') + self.assertEqual( + ['A', 'B', 'C', 'D'], + list(u) + ) def test_custom_key(self): """ensure the custom key comparison works""" - u = unique_everseen('aAbACCc', key=str.lower) - eq_(list('abC'), list(u)) + u = mi.unique_everseen('aAbACCc', key=str.lower) + self.assertEqual(list('abC'), list(u)) + + def test_unhashable(self): + """ensure things work for unhashable items""" + iterable = ['a', [1, 2, 3], [1, 2, 3], 'a'] + u = mi.unique_everseen(iterable) + self.assertEqual(list(u), ['a', [1, 2, 3]]) + + def test_unhashable_key(self): + """ensure things work for unhashable items with a custom key""" + iterable = ['a', [1, 2, 3], [1, 2, 3], 'a'] + u = mi.unique_everseen(iterable, key=lambda x: x) + self.assertEqual(list(u), ['a', [1, 2, 3]]) class UniqueJustseenTests(TestCase): @@ -269,13 +389,13 @@ class UniqueJustseenTests(TestCase): def test_justseen(self): """ensure only last item is remembered""" - u = unique_justseen('AAAABBBCCDABB') - eq_(list('ABCDAB'), list(u)) + u = mi.unique_justseen('AAAABBBCCDABB') + self.assertEqual(list('ABCDAB'), list(u)) def test_custom_key(self): """ensure the custom key comparison works""" - u = unique_justseen('AABCcAD', str.lower) - eq_(list('ABCAD'), list(u)) + u = mi.unique_justseen('AABCcAD', str.lower) + self.assertEqual(list('ABCAD'), list(u)) class IterExceptTests(TestCase): @@ 
-284,27 +404,49 @@ class IterExceptTests(TestCase): def test_exact_exception(self): """ensure the exact specified exception is caught""" l = [1, 2, 3] - i = iter_except(l.pop, IndexError) - eq_(list(i), [3, 2, 1]) + i = mi.iter_except(l.pop, IndexError) + self.assertEqual(list(i), [3, 2, 1]) def test_generic_exception(self): """ensure the generic exception can be caught""" l = [1, 2] - i = iter_except(l.pop, Exception) - eq_(list(i), [2, 1]) + i = mi.iter_except(l.pop, Exception) + self.assertEqual(list(i), [2, 1]) def test_uncaught_exception_is_raised(self): """ensure a non-specified exception is raised""" l = [1, 2, 3] - i = iter_except(l.pop, KeyError) - assert_raises(IndexError, list, i) + i = mi.iter_except(l.pop, KeyError) + self.assertRaises(IndexError, lambda: list(i)) def test_first(self): """ensure first is run before the function""" l = [1, 2, 3] f = lambda: 25 - i = iter_except(l.pop, IndexError, f) - eq_(list(i), [25, 3, 2, 1]) + i = mi.iter_except(l.pop, IndexError, f) + self.assertEqual(list(i), [25, 3, 2, 1]) + + +class FirstTrueTests(TestCase): + """Tests for ``first_true()``""" + + def test_something_true(self): + """Test with no keywords""" + self.assertEqual(mi.first_true(range(10)), 1) + + def test_nothing_true(self): + """Test default return value.""" + self.assertEqual(mi.first_true([0, 0, 0]), False) + + def test_default(self): + """Test with a default keyword""" + self.assertEqual(mi.first_true([0, 0, 0], default='!'), '!') + + def test_pred(self): + """Test with a custom predicate""" + self.assertEqual( + mi.first_true([2, 4, 6], pred=lambda x: x % 3 == 0), 6 + ) class RandomProductTests(TestCase): @@ -327,12 +469,12 @@ class RandomProductTests(TestCase): """ nums = [1, 2, 3] lets = ['a', 'b', 'c'] - n, m = zip(*[random_product(nums, lets) for _ in range(100)]) + n, m = zip(*[mi.random_product(nums, lets) for _ in range(100)]) n, m = set(n), set(m) - eq_(n, set(nums)) - eq_(m, set(lets)) - eq_(len(n), len(nums)) - eq_(len(m), len(lets)) + 
self.assertEqual(n, set(nums)) + self.assertEqual(m, set(lets)) + self.assertEqual(len(n), len(nums)) + self.assertEqual(len(m), len(lets)) def test_list_with_repeat(self): """ensure multiple items are chosen, and that they appear to be chosen @@ -341,13 +483,13 @@ class RandomProductTests(TestCase): """ nums = [1, 2, 3] lets = ['a', 'b', 'c'] - r = list(random_product(nums, lets, repeat=100)) - eq_(2 * 100, len(r)) + r = list(mi.random_product(nums, lets, repeat=100)) + self.assertEqual(2 * 100, len(r)) n, m = set(r[::2]), set(r[1::2]) - eq_(n, set(nums)) - eq_(m, set(lets)) - eq_(len(n), len(nums)) - eq_(len(m), len(lets)) + self.assertEqual(n, set(nums)) + self.assertEqual(m, set(lets)) + self.assertEqual(len(n), len(nums)) + self.assertEqual(len(m), len(lets)) class RandomPermutationTests(TestCase): @@ -361,8 +503,8 @@ class RandomPermutationTests(TestCase): """ i = range(15) - r = random_permutation(i) - eq_(set(i), set(r)) + r = mi.random_permutation(i) + self.assertEqual(set(i), set(r)) if i == r: raise AssertionError("Values were not permuted") @@ -380,13 +522,13 @@ class RandomPermutationTests(TestCase): items = range(15) item_set = set(items) all_items = set() - for _ in xrange(100): - permutation = random_permutation(items, 5) - eq_(len(permutation), 5) + for _ in range(100): + permutation = mi.random_permutation(items, 5) + self.assertEqual(len(permutation), 5) permutation_set = set(permutation) - ok_(permutation_set <= item_set) + self.assertLessEqual(permutation_set, item_set) all_items |= permutation_set - eq_(all_items, item_set) + self.assertEqual(all_items, item_set) class RandomCombinationTests(TestCase): @@ -397,18 +539,20 @@ class RandomCombinationTests(TestCase): samplings of random combinations""" items = range(15) all_items = set() - for _ in xrange(50): - combination = random_combination(items, 5) + for _ in range(50): + combination = mi.random_combination(items, 5) all_items |= set(combination) - eq_(all_items, set(items)) + 
self.assertEqual(all_items, set(items)) def test_no_replacement(self): """ensure that elements are sampled without replacement""" items = range(15) - for _ in xrange(50): - combination = random_combination(items, len(items)) - eq_(len(combination), len(set(combination))) - assert_raises(ValueError, random_combination, items, len(items) + 1) + for _ in range(50): + combination = mi.random_combination(items, len(items)) + self.assertEqual(len(combination), len(set(combination))) + self.assertRaises( + ValueError, lambda: mi.random_combination(items, len(items) + 1) + ) class RandomCombinationWithReplacementTests(TestCase): @@ -417,17 +561,56 @@ class RandomCombinationWithReplacementTests(TestCase): def test_replacement(self): """ensure that elements are sampled with replacement""" items = range(5) - combo = random_combination_with_replacement(items, len(items) * 2) - eq_(2 * len(items), len(combo)) + combo = mi.random_combination_with_replacement(items, len(items) * 2) + self.assertEqual(2 * len(items), len(combo)) if len(set(combo)) == len(combo): raise AssertionError("Combination contained no duplicates") - def test_psuedorandomness(self): + def test_pseudorandomness(self): """ensure different subsets of the iterable get returned over many samplings of random combinations""" items = range(15) all_items = set() - for _ in xrange(50): - combination = random_combination_with_replacement(items, 5) + for _ in range(50): + combination = mi.random_combination_with_replacement(items, 5) all_items |= set(combination) - eq_(all_items, set(items)) + self.assertEqual(all_items, set(items)) + + +class NthCombinationTests(TestCase): + def test_basic(self): + iterable = 'abcdefg' + r = 4 + for index, expected in enumerate(combinations(iterable, r)): + actual = mi.nth_combination(iterable, r, index) + self.assertEqual(actual, expected) + + def test_long(self): + actual = mi.nth_combination(range(180), 4, 2000000) + expected = (2, 12, 35, 126) + self.assertEqual(actual, expected) + 
+ def test_invalid_r(self): + for r in (-1, 3): + with self.assertRaises(ValueError): + mi.nth_combination([], r, 0) + + def test_invalid_index(self): + with self.assertRaises(IndexError): + mi.nth_combination('abcdefg', 3, -36) + + +class PrependTests(TestCase): + def test_basic(self): + value = 'a' + iterator = iter('bcdefg') + actual = list(mi.prepend(value, iterator)) + expected = list('abcdefg') + self.assertEqual(actual, expected) + + def test_multiple(self): + value = 'ab' + iterator = iter('cdefg') + actual = tuple(mi.prepend(value, iterator)) + expected = ('ab',) + tuple('cdefg') + self.assertEqual(actual, expected) diff --git a/libs/path.py b/libs/path.py index 1e92a490..69ac5c13 100644 --- a/libs/path.py +++ b/libs/path.py @@ -1,25 +1,3 @@ -# -# Copyright (c) 2010 Mikhail Gusarov -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. -# - """ path.py - An object representing a path to a file or directory. 
@@ -29,8 +7,18 @@ Example:: from path import Path d = Path('/home/guido/bin') + + # Globbing for f in d.files('*.py'): f.chmod(0o755) + + # Changing the working directory: + with Path("somewhere"): + # cwd in now `somewhere` + ... + + # Concatenate paths with / + foo_txt = Path("bar") / "foo.txt" """ from __future__ import unicode_literals @@ -41,7 +29,6 @@ import os import fnmatch import glob import shutil -import codecs import hashlib import errno import tempfile @@ -50,8 +37,10 @@ import operator import re import contextlib import io -from distutils import dir_util import importlib +import itertools +import platform +import ntpath try: import win32security @@ -77,22 +66,17 @@ string_types = str, text_type = str getcwdu = os.getcwd -def surrogate_escape(error): - """ - Simulate the Python 3 ``surrogateescape`` handler, but for Python 2 only. - """ - chars = error.object[error.start:error.end] - assert len(chars) == 1 - val = ord(chars) - val += 0xdc00 - return __builtin__.unichr(val), error.end if PY2: import __builtin__ string_types = __builtin__.basestring, text_type = __builtin__.unicode getcwdu = os.getcwdu - codecs.register_error('surrogateescape', surrogate_escape) + map = itertools.imap + filter = itertools.ifilter + FileNotFoundError = OSError + itertools.filterfalse = itertools.ifilterfalse + @contextlib.contextmanager def io_error_compat(): @@ -107,7 +91,8 @@ def io_error_compat(): ############################################################################## -__all__ = ['Path', 'CaseInsensitivePattern'] + +__all__ = ['Path', 'TempDir', 'CaseInsensitivePattern'] LINESEPS = ['\r\n', '\r', '\n'] @@ -119,8 +104,8 @@ U_NL_END = re.compile(r'(?:{0})$'.format(U_NEWLINE.pattern)) try: - import pkg_resources - __version__ = pkg_resources.require('path.py')[0].version + import importlib_metadata + __version__ = importlib_metadata.version('path.py') except Exception: __version__ = 'unknown' @@ -131,7 +116,7 @@ class TreeWalkWarning(Warning): # from 
jaraco.functools def compose(*funcs): - compose_two = lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs)) + compose_two = lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs)) # noqa return functools.reduce(compose_two, funcs) @@ -170,6 +155,60 @@ class multimethod(object): ) +class matchers(object): + # TODO: make this class a module + + @staticmethod + def load(param): + """ + If the supplied parameter is a string, assum it's a simple + pattern. + """ + return ( + matchers.Pattern(param) if isinstance(param, string_types) + else param if param is not None + else matchers.Null() + ) + + class Base(object): + pass + + class Null(Base): + def __call__(self, path): + return True + + class Pattern(Base): + def __init__(self, pattern): + self.pattern = pattern + + def get_pattern(self, normcase): + try: + return self._pattern + except AttributeError: + pass + self._pattern = normcase(self.pattern) + return self._pattern + + def __call__(self, path): + normcase = getattr(self, 'normcase', path.module.normcase) + pattern = self.get_pattern(normcase) + return fnmatch.fnmatchcase(normcase(path.name), pattern) + + class CaseInsensitive(Pattern): + """ + A Pattern with a ``'normcase'`` property, suitable for passing to + :meth:`listdir`, :meth:`dirs`, :meth:`files`, :meth:`walk`, + :meth:`walkdirs`, or :meth:`walkfiles` to match case-insensitive. + + For example, to get all files ending in .py, .Py, .pY, or .PY in the + current directory:: + + from path import Path, matchers + Path('.').files(matchers.CaseInsensitive('*.py')) + """ + normcase = staticmethod(ntpath.normcase) + + class Path(text_type): """ Represents a filesystem path. @@ -214,16 +253,6 @@ class Path(text_type): """ return cls - @classmethod - def _always_unicode(cls, path): - """ - Ensure the path as retrieved from a Python API, such as :func:`os.listdir`, - is a proper Unicode string. 
- """ - if PY3 or isinstance(path, text_type): - return path - return path.decode(sys.getfilesystemencoding(), 'surrogateescape') - # --- Special Python methods. def __repr__(self): @@ -277,6 +306,9 @@ class Path(text_type): def __exit__(self, *_): os.chdir(self._old_dir) + def __fspath__(self): + return self + @classmethod def getcwd(cls): """ Return the current working directory as a path object. @@ -330,23 +362,45 @@ class Path(text_type): return self.expandvars().expanduser().normpath() @property - def namebase(self): + def stem(self): """ The same as :meth:`name`, but with one file extension stripped off. - For example, - ``Path('/home/guido/python.tar.gz').name == 'python.tar.gz'``, - but - ``Path('/home/guido/python.tar.gz').namebase == 'python.tar'``. + >>> Path('/home/guido/python.tar.gz').stem + 'python.tar' """ base, ext = self.module.splitext(self.name) return base + @property + def namebase(self): + warnings.warn("Use .stem instead of .namebase", DeprecationWarning) + return self.stem + @property def ext(self): """ The file extension, for example ``'.py'``. """ f, ext = self.module.splitext(self) return ext + def with_suffix(self, suffix): + """ Return a new path with the file suffix changed (or added, if none) + + >>> Path('/home/guido/python.tar.gz').with_suffix(".foo") + Path('/home/guido/python.tar.foo') + + >>> Path('python').with_suffix('.zip') + Path('python.zip') + + >>> Path('filename.ext').with_suffix('zip') + Traceback (most recent call last): + ... + ValueError: Invalid suffix 'zip' + """ + if not suffix.startswith('.'): + raise ValueError("Invalid suffix {suffix!r}".format(**locals())) + + return self.stripext() + suffix + @property def drive(self): """ The drive specifier, for example ``'C:'``. @@ -437,8 +491,9 @@ class Path(text_type): @multimethod def joinpath(cls, first, *others): """ - Join first to zero or more :class:`Path` components, adding a separator - character (:samp:`{first}.module.sep`) if needed. 
Returns a new instance of + Join first to zero or more :class:`Path` components, + adding a separator character (:samp:`{first}.module.sep`) + if needed. Returns a new instance of :samp:`{first}._next_class`. .. seealso:: :func:`os.path.join` @@ -516,7 +571,7 @@ class Path(text_type): # --- Listing, searching, walking, and matching - def listdir(self, pattern=None): + def listdir(self, match=None): """ D.listdir() -> List of items in this directory. Use :meth:`files` or :meth:`dirs` instead if you want a listing @@ -524,46 +579,39 @@ class Path(text_type): The elements of the list are Path objects. - With the optional `pattern` argument, this only lists - items whose names match the given pattern. + With the optional `match` argument, a callable, + only return items whose names match the given pattern. .. seealso:: :meth:`files`, :meth:`dirs` """ - if pattern is None: - pattern = '*' - return [ - self / child - for child in map(self._always_unicode, os.listdir(self)) - if self._next_class(child).fnmatch(pattern) - ] + match = matchers.load(match) + return list(filter(match, ( + self / child for child in os.listdir(self) + ))) - def dirs(self, pattern=None): + def dirs(self, *args, **kwargs): """ D.dirs() -> List of this directory's subdirectories. The elements of the list are Path objects. This does not walk recursively into subdirectories (but see :meth:`walkdirs`). - With the optional `pattern` argument, this only lists - directories whose names match the given pattern. For - example, ``d.dirs('build-*')``. + Accepts parameters to :meth:`listdir`. """ - return [p for p in self.listdir(pattern) if p.isdir()] + return [p for p in self.listdir(*args, **kwargs) if p.isdir()] - def files(self, pattern=None): + def files(self, *args, **kwargs): """ D.files() -> List of the files in this directory. The elements of the list are Path objects. This does not walk into subdirectories (see :meth:`walkfiles`). 
- With the optional `pattern` argument, this only lists files - whose names match the given pattern. For example, - ``d.files('*.pyc')``. + Accepts parameters to :meth:`listdir`. """ - return [p for p in self.listdir(pattern) if p.isfile()] + return [p for p in self.listdir(*args, **kwargs) if p.isfile()] - def walk(self, pattern=None, errors='strict'): + def walk(self, match=None, errors='strict'): """ D.walk() -> iterator over files and subdirs, recursively. The iterator yields Path objects naming each child item of @@ -593,6 +641,8 @@ class Path(text_type): raise ValueError("invalid errors parameter") errors = vars(Handlers).get(errors, errors) + match = matchers.load(match) + try: childList = self.listdir() except Exception: @@ -603,7 +653,7 @@ class Path(text_type): return for child in childList: - if pattern is None or child.fnmatch(pattern): + if match(child): yield child try: isdir = child.isdir() @@ -615,92 +665,26 @@ class Path(text_type): isdir = False if isdir: - for item in child.walk(pattern, errors): + for item in child.walk(errors=errors, match=match): yield item - def walkdirs(self, pattern=None, errors='strict'): + def walkdirs(self, *args, **kwargs): """ D.walkdirs() -> iterator over subdirs, recursively. - - With the optional `pattern` argument, this yields only - directories whose names match the given pattern. For - example, ``mydir.walkdirs('*test')`` yields only directories - with names ending in ``'test'``. - - The `errors=` keyword argument controls behavior when an - error occurs. The default is ``'strict'``, which causes an - exception. The other allowed values are ``'warn'`` (which - reports the error via :func:`warnings.warn()`), and ``'ignore'``. 
""" - if errors not in ('strict', 'warn', 'ignore'): - raise ValueError("invalid errors parameter") + return ( + item + for item in self.walk(*args, **kwargs) + if item.isdir() + ) - try: - dirs = self.dirs() - except Exception: - if errors == 'ignore': - return - elif errors == 'warn': - warnings.warn( - "Unable to list directory '%s': %s" - % (self, sys.exc_info()[1]), - TreeWalkWarning) - return - else: - raise - - for child in dirs: - if pattern is None or child.fnmatch(pattern): - yield child - for subsubdir in child.walkdirs(pattern, errors): - yield subsubdir - - def walkfiles(self, pattern=None, errors='strict'): + def walkfiles(self, *args, **kwargs): """ D.walkfiles() -> iterator over files in D, recursively. - - The optional argument `pattern` limits the results to files - with names that match the pattern. For example, - ``mydir.walkfiles('*.tmp')`` yields only files with the ``.tmp`` - extension. """ - if errors not in ('strict', 'warn', 'ignore'): - raise ValueError("invalid errors parameter") - - try: - childList = self.listdir() - except Exception: - if errors == 'ignore': - return - elif errors == 'warn': - warnings.warn( - "Unable to list directory '%s': %s" - % (self, sys.exc_info()[1]), - TreeWalkWarning) - return - else: - raise - - for child in childList: - try: - isfile = child.isfile() - isdir = not isfile and child.isdir() - except: - if errors == 'ignore': - continue - elif errors == 'warn': - warnings.warn( - "Unable to access '%s': %s" - % (self, sys.exc_info()[1]), - TreeWalkWarning) - continue - else: - raise - - if isfile: - if pattern is None or child.fnmatch(pattern): - yield child - elif isdir: - for f in child.walkfiles(pattern, errors): - yield f + return ( + item + for item in self.walk(*args, **kwargs) + if item.isfile() + ) def fnmatch(self, pattern, normcase=None): """ Return ``True`` if `self.name` matches the given `pattern`. 
@@ -710,8 +694,8 @@ class Path(text_type): attribute, it is applied to the name and path prior to comparison. `normcase` - (optional) A function used to normalize the pattern and - filename before matching. Defaults to :meth:`self.module`, which defaults - to :meth:`os.path.normcase`. + filename before matching. Defaults to :meth:`self.module`, which + defaults to :meth:`os.path.normcase`. .. seealso:: :func:`fnmatch.fnmatch` """ @@ -730,10 +714,32 @@ class Path(text_type): of all the files users have in their :file:`bin` directories. .. seealso:: :func:`glob.glob` + + .. note:: Glob is **not** recursive, even when using ``**``. + To do recursive globbing see :func:`walk`, + :func:`walkdirs` or :func:`walkfiles`. """ cls = self._next_class return [cls(s) for s in glob.glob(self / pattern)] + def iglob(self, pattern): + """ Return an iterator of Path objects that match the pattern. + + `pattern` - a path relative to this directory, with wildcards. + + For example, ``Path('/users').iglob('*/bin/*')`` returns an + iterator of all the files users have in their :file:`bin` + directories. + + .. seealso:: :func:`glob.iglob` + + .. note:: Glob is **not** recursive, even when using ``**``. + To do recursive globbing see :func:`walk`, + :func:`walkdirs` or :func:`walkfiles`. + """ + cls = self._next_class + return (cls(s) for s in glob.iglob(self / pattern)) + # # --- Reading or writing an entire file at once. @@ -882,15 +888,9 @@ class Path(text_type): translated to ``'\n'``. If ``False``, newline characters are stripped off. Default is ``True``. - This uses ``'U'`` mode. - .. seealso:: :meth:`text` """ - if encoding is None and retain: - with self.open('U') as f: - return f.readlines() - else: - return self.text(encoding, errors).splitlines(retain) + return self.text(encoding, errors).splitlines(retain) def write_lines(self, lines, encoding=None, errors='strict', linesep=os.linesep, append=False): @@ -931,14 +931,15 @@ class Path(text_type): to read the file later. 
""" with self.open('ab' if append else 'wb') as f: - for l in lines: - isUnicode = isinstance(l, text_type) + for line in lines: + isUnicode = isinstance(line, text_type) if linesep is not None: pattern = U_NL_END if isUnicode else NL_END - l = pattern.sub('', l) + linesep + line = pattern.sub('', line) + linesep if isUnicode: - l = l.encode(encoding or sys.getdefaultencoding(), errors) - f.write(l) + line = line.encode( + encoding or sys.getdefaultencoding(), errors) + f.write(line) def read_md5(self): """ Calculate the md5 hash for this file. @@ -952,8 +953,8 @@ class Path(text_type): def _hash(self, hash_name): """ Returns a hash object for the file at the current path. - `hash_name` should be a hash algo name (such as ``'md5'`` or ``'sha1'``) - that's available in the :mod:`hashlib` module. + `hash_name` should be a hash algo name (such as ``'md5'`` + or ``'sha1'``) that's available in the :mod:`hashlib` module. """ m = hashlib.new(hash_name) for chunk in self.chunks(8192, mode="rb"): @@ -1176,7 +1177,8 @@ class Path(text_type): gid = grp.getgrnam(gid).gr_gid os.chown(self, uid, gid) else: - raise NotImplementedError("Ownership not available on this platform.") + msg = "Ownership not available on this platform." + raise NotImplementedError(msg) return self def rename(self, new): @@ -1236,7 +1238,8 @@ class Path(text_type): self.rmdir() except OSError: _, e, _ = sys.exc_info() - if e.errno != errno.ENOTEMPTY and e.errno != errno.EEXIST: + bypass_codes = errno.ENOTEMPTY, errno.EEXIST, errno.ENOENT + if e.errno not in bypass_codes: raise return self @@ -1277,9 +1280,8 @@ class Path(text_type): file does not exist. 
""" try: self.unlink() - except OSError: - _, e, _ = sys.exc_info() - if e.errno != errno.ENOENT: + except FileNotFoundError as exc: + if PY2 and exc.errno != errno.ENOENT: raise return self @@ -1306,11 +1308,16 @@ class Path(text_type): return self._next_class(newpath) if hasattr(os, 'symlink'): - def symlink(self, newlink): + def symlink(self, newlink=None): """ Create a symbolic link at `newlink`, pointing here. + If newlink is not supplied, the symbolic link will assume + the name self.basename(), creating the link in the cwd. + .. seealso:: :func:`os.symlink` """ + if newlink is None: + newlink = self.basename() os.symlink(self, newlink) return self._next_class(newlink) @@ -1368,30 +1375,60 @@ class Path(text_type): cd = chdir - def merge_tree(self, dst, symlinks=False, *args, **kwargs): + def merge_tree( + self, dst, symlinks=False, + # * + update=False, + copy_function=shutil.copy2, + ignore=lambda dir, contents: []): """ Copy entire contents of self to dst, overwriting existing contents in dst with those in self. - If the additional keyword `update` is True, each - `src` will only be copied if `dst` does not exist, - or `src` is newer than `dst`. + Pass ``symlinks=True`` to copy symbolic links as links. - Note that the technique employed stages the files in a temporary - directory first, so this function is not suitable for merging - trees with large files, especially if the temporary directory - is not capable of storing a copy of the entire source tree. + Accepts a ``copy_function``, similar to copytree. + + To avoid overwriting newer files, supply a copy function + wrapped in ``only_newer``. For example:: + + src.merge_tree(dst, copy_function=only_newer(shutil.copy2)) """ - update = kwargs.pop('update', False) - with tempdir() as _temp_dir: - # first copy the tree to a stage directory to support - # the parameters and behavior of copytree. 
- stage = _temp_dir / str(hash(self)) - self.copytree(stage, symlinks, *args, **kwargs) - # now copy everything from the stage directory using - # the semantics of dir_util.copy_tree - dir_util.copy_tree(stage, dst, preserve_symlinks=symlinks, - update=update) + dst = self._next_class(dst) + dst.makedirs_p() + + if update: + warnings.warn( + "Update is deprecated; " + "use copy_function=only_newer(shutil.copy2)", + DeprecationWarning, + stacklevel=2, + ) + copy_function = only_newer(copy_function) + + sources = self.listdir() + _ignored = ignore(self, [item.name for item in sources]) + + def ignored(item): + return item.name in _ignored + + for source in itertools.filterfalse(ignored, sources): + dest = dst / source.name + if symlinks and source.islink(): + target = source.readlink() + target.symlink(dest) + elif source.isdir(): + source.merge_tree( + dest, + symlinks=symlinks, + update=update, + copy_function=copy_function, + ignore=ignore, + ) + else: + copy_function(source, dest) + + self.copystat(dst) # # --- Special stuff from os @@ -1410,19 +1447,23 @@ class Path(text_type): # in-place re-writing, courtesy of Martijn Pieters # http://www.zopatista.com/python/2013/11/26/inplace-file-rewriting/ @contextlib.contextmanager - def in_place(self, mode='r', buffering=-1, encoding=None, errors=None, - newline=None, backup_extension=None): + def in_place( + self, mode='r', buffering=-1, encoding=None, errors=None, + newline=None, backup_extension=None, + ): """ - A context in which a file may be re-written in-place with new content. + A context in which a file may be re-written in-place with + new content. - Yields a tuple of :samp:`({readable}, {writable})` file objects, where `writable` - replaces `readable`. + Yields a tuple of :samp:`({readable}, {writable})` file + objects, where `writable` replaces `readable`. If an exception occurs, the old file is restored, removing the written data. 
- Mode *must not* use ``'w'``, ``'a'``, or ``'+'``; only read-only-modes are - allowed. A :exc:`ValueError` is raised on invalid modes. + Mode *must not* use ``'w'``, ``'a'``, or ``'+'``; only + read-only-modes are allowed. A :exc:`ValueError` is raised + on invalid modes. For example, to add line numbers to a file:: @@ -1448,22 +1489,28 @@ class Path(text_type): except os.error: pass os.rename(self, backup_fn) - readable = io.open(backup_fn, mode, buffering=buffering, - encoding=encoding, errors=errors, newline=newline) + readable = io.open( + backup_fn, mode, buffering=buffering, + encoding=encoding, errors=errors, newline=newline, + ) try: perm = os.fstat(readable.fileno()).st_mode except OSError: - writable = open(self, 'w' + mode.replace('r', ''), + writable = open( + self, 'w' + mode.replace('r', ''), buffering=buffering, encoding=encoding, errors=errors, - newline=newline) + newline=newline, + ) else: os_mode = os.O_CREAT | os.O_WRONLY | os.O_TRUNC if hasattr(os, 'O_BINARY'): os_mode |= os.O_BINARY fd = os.open(self, os_mode, perm) - writable = io.open(fd, "w" + mode.replace('r', ''), + writable = io.open( + fd, "w" + mode.replace('r', ''), buffering=buffering, encoding=encoding, errors=errors, - newline=newline) + newline=newline, + ) try: if hasattr(os, 'chmod'): os.chmod(self, perm) @@ -1516,6 +1563,23 @@ class Path(text_type): return functools.partial(SpecialResolver, cls) +def only_newer(copy_func): + """ + Wrap a copy function (like shutil.copy2) to return + the dst if it's newer than the source. 
+ """ + @functools.wraps(copy_func) + def wrapper(src, dst, *args, **kwargs): + is_newer_dst = ( + dst.exists() + and dst.getmtime() >= src.getmtime() + ) + if is_newer_dst: + return dst + return copy_func(src, dst, *args, **kwargs) + return wrapper + + class SpecialResolver(object): class ResolverScope: def __init__(self, paths, scope): @@ -1584,14 +1648,15 @@ class Multi: ) -class tempdir(Path): +class TempDir(Path): """ - A temporary directory via :func:`tempfile.mkdtemp`, and constructed with the - same parameters that you can use as a context manager. + A temporary directory via :func:`tempfile.mkdtemp`, and + constructed with the same parameters that you can use + as a context manager. - Example: + Example:: - with tempdir() as d: + with TempDir() as d: # do stuff with the Path object "d" # here the directory is deleted automatically @@ -1606,19 +1671,27 @@ class tempdir(Path): def __new__(cls, *args, **kwargs): dirname = tempfile.mkdtemp(*args, **kwargs) - return super(tempdir, cls).__new__(cls, dirname) + return super(TempDir, cls).__new__(cls, dirname) def __init__(self, *args, **kwargs): pass def __enter__(self): - return self + # TempDir should return a Path version of itself and not itself + # so that a second context manager does not create a second + # temporary directory, but rather changes CWD to the location + # of the temporary directory. + return self._next_class(self) def __exit__(self, exc_type, exc_value, traceback): if not exc_value: self.rmtree() +# For backwards compatibility. +tempdir = TempDir + + def _multi_permission_mask(mode): """ Support multiple, comma-separated Unix chmod symbolic modes. 
@@ -1626,7 +1699,8 @@ def _multi_permission_mask(mode): >>> _multi_permission_mask('a=r,u+w')(0) == 0o644 True """ - compose = lambda f, g: lambda *args, **kwargs: g(f(*args, **kwargs)) + def compose(f, g): + return lambda *args, **kwargs: g(f(*args, **kwargs)) return functools.reduce(compose, map(_permission_mask, mode.split(','))) @@ -1692,31 +1766,56 @@ def _permission_mask(mode): return functools.partial(op_map[op], mask) -class CaseInsensitivePattern(text_type): +class CaseInsensitivePattern(matchers.CaseInsensitive): + def __init__(self, value): + warnings.warn( + "Use matchers.CaseInsensitive instead", + DeprecationWarning, + stacklevel=2, + ) + super(CaseInsensitivePattern, self).__init__(value) + + +class FastPath(Path): + def __init__(self, *args, **kwargs): + warnings.warn( + "Use Path, as FastPath no longer holds any advantage", + DeprecationWarning, + stacklevel=2, + ) + super(FastPath, self).__init__(*args, **kwargs) + + +def patch_for_linux_python2(): """ - A string with a ``'normcase'`` property, suitable for passing to - :meth:`listdir`, :meth:`dirs`, :meth:`files`, :meth:`walk`, - :meth:`walkdirs`, or :meth:`walkfiles` to match case-insensitive. - - For example, to get all files ending in .py, .Py, .pY, or .PY in the - current directory:: - - from path import Path, CaseInsensitivePattern as ci - Path('.').files(ci('*.py')) + As reported in #130, when Linux users create filenames + not in the file system encoding, it creates problems on + Python 2. This function attempts to patch the os module + to make it behave more like that on Python 3. """ + if not PY2 or platform.system() != 'Linux': + return - @property - def normcase(self): - return __import__('ntpath').normcase + try: + import backports.os + except ImportError: + return -######################## -# Backward-compatibility -class path(Path): - def __new__(cls, *args, **kwargs): - msg = "path is deprecated. Use Path instead." 
- warnings.warn(msg, DeprecationWarning) - return Path.__new__(cls, *args, **kwargs) + class OS: + """ + The proxy to the os module + """ + def __init__(self, wrapped): + self._orig = wrapped + + def __getattr__(self, name): + return getattr(self._orig, name) + + def listdir(self, *args, **kwargs): + items = self._orig.listdir(*args, **kwargs) + return list(map(backports.os.fsdecode, items)) + + globals().update(os=OS(os)) -__all__ += ['path'] -######################## +patch_for_linux_python2() diff --git a/libs/test_path.py b/libs/test_path.py index f6aa1b67..2a7ddb8f 100644 --- a/libs/test_path.py +++ b/libs/test_path.py @@ -22,25 +22,54 @@ import os import sys import shutil import time +import types import ntpath import posixpath import textwrap import platform import importlib +import operator +import datetime +import subprocess +import re import pytest +import packaging.version -from path import Path, tempdir -from path import CaseInsensitivePattern as ci +import path +from path import TempDir +from path import matchers from path import SpecialResolver from path import Multi +Path = None + def p(**choices): """ Choose a value from several possible values, based on os.name """ return choices[os.name] +@pytest.fixture(autouse=True, params=[path.Path]) +def path_class(request, monkeypatch): + """ + Invoke tests on any number of Path classes. + """ + monkeypatch.setitem(globals(), 'Path', request.param) + + +def mac_version(target, comparator=operator.ge): + """ + Return True if on a Mac whose version passes the comparator. 
+ """ + current_ver = packaging.version.parse(platform.mac_ver()[0]) + target_ver = packaging.version.parse(target) + return ( + platform.system() == 'Darwin' + and comparator(current_ver, target_ver) + ) + + class TestBasics: def test_relpath(self): root = Path(p(nt='C:\\', posix='/')) @@ -51,14 +80,14 @@ class TestBasics: up = Path(os.pardir) # basics - assert root.relpathto(boz) == Path('foo')/'bar'/'Baz'/'Boz' - assert bar.relpathto(boz) == Path('Baz')/'Boz' - assert quux.relpathto(boz) == up/'bar'/'Baz'/'Boz' - assert boz.relpathto(quux) == up/up/up/'quux' - assert boz.relpathto(bar) == up/up + assert root.relpathto(boz) == Path('foo') / 'bar' / 'Baz' / 'Boz' + assert bar.relpathto(boz) == Path('Baz') / 'Boz' + assert quux.relpathto(boz) == up / 'bar' / 'Baz' / 'Boz' + assert boz.relpathto(quux) == up / up / up / 'quux' + assert boz.relpathto(bar) == up / up # Path is not the first element in concatenation - assert root.relpathto(boz) == 'foo'/Path('bar')/'Baz'/'Boz' + assert root.relpathto(boz) == 'foo' / Path('bar') / 'Baz' / 'Boz' # x.relpathto(x) == curdir assert root.relpathto(root) == os.curdir @@ -112,7 +141,7 @@ class TestBasics: # Test p1/p1. p1 = Path("foo") p2 = Path("bar") - assert p1/p2 == p(nt='foo\\bar', posix='foo/bar') + assert p1 / p2 == p(nt='foo\\bar', posix='foo/bar') def test_properties(self): # Create sample path object. @@ -207,6 +236,30 @@ class TestBasics: assert res2 == 'foo/bar' +class TestPerformance: + @pytest.mark.skipif( + path.PY2, + reason="Tests fail frequently on Python 2; see #153") + def test_import_time(self, monkeypatch): + """ + Import of path.py should take less than 100ms. + + Run tests in a subprocess to isolate from test suite overhead. 
+ """ + cmd = [ + sys.executable, + '-m', 'timeit', + '-n', '1', + '-r', '1', + 'import path', + ] + res = subprocess.check_output(cmd, universal_newlines=True) + dur = re.search(r'(\d+) msec per loop', res).group(1) + limit = datetime.timedelta(milliseconds=100) + duration = datetime.timedelta(milliseconds=int(dur)) + assert duration < limit + + class TestSelfReturn: """ Some methods don't necessarily return any value (e.g. makedirs, @@ -246,7 +299,7 @@ class TestSelfReturn: class TestScratchDir: """ - Tests that run in a temporary directory (does not test tempdir class) + Tests that run in a temporary directory (does not test TempDir class) """ def test_context_manager(self, tmpdir): """Can be used as context manager for chdir.""" @@ -282,12 +335,12 @@ class TestScratchDir: ct = f.ctime assert t0 <= ct <= t1 - time.sleep(threshold*2) + time.sleep(threshold * 2) fobj = open(f, 'ab') fobj.write('some bytes'.encode('utf-8')) fobj.close() - time.sleep(threshold*2) + time.sleep(threshold * 2) t2 = time.time() - threshold f.touch() t3 = time.time() + threshold @@ -305,9 +358,12 @@ class TestScratchDir: assert ct == ct2 assert ct2 < t2 else: - # On other systems, it might be the CHANGE time - # (especially on Unix, time of inode changes) - assert ct == ct2 or ct2 == f.mtime + assert ( + # ctime is unchanged + ct == ct2 or + # ctime is approximately the mtime + ct2 == pytest.approx(f.mtime, 0.001) + ) def test_listing(self, tmpdir): d = Path(tmpdir) @@ -330,6 +386,11 @@ class TestScratchDir: assert d.glob('*') == [af] assert d.glob('*.html') == [] assert d.glob('testfile') == [] + + # .iglob matches .glob but as an iterator. 
+ assert list(d.iglob('*')) == d.glob('*') + assert isinstance(d.iglob('*'), types.GeneratorType) + finally: af.remove() @@ -348,9 +409,17 @@ class TestScratchDir: for f in files: try: f.remove() - except: + except Exception: pass + @pytest.mark.xfail( + mac_version('10.13'), + reason="macOS disallows invalid encodings", + ) + @pytest.mark.xfail( + platform.system() == 'Windows' and path.PY3, + reason="Can't write latin characters. See #133", + ) def test_listdir_other_encoding(self, tmpdir): """ Some filesystems allow non-character sequences in path names. @@ -498,28 +567,28 @@ class TestScratchDir: def test_patterns(self, tmpdir): d = Path(tmpdir) names = ['x.tmp', 'x.xtmp', 'x2g', 'x22', 'x.txt'] - dirs = [d, d/'xdir', d/'xdir.tmp', d/'xdir.tmp'/'xsubdir'] + dirs = [d, d / 'xdir', d / 'xdir.tmp', d / 'xdir.tmp' / 'xsubdir'] for e in dirs: if not e.isdir(): e.makedirs() for name in names: - (e/name).touch() - self.assertList(d.listdir('*.tmp'), [d/'x.tmp', d/'xdir.tmp']) - self.assertList(d.files('*.tmp'), [d/'x.tmp']) - self.assertList(d.dirs('*.tmp'), [d/'xdir.tmp']) + (e / name).touch() + self.assertList(d.listdir('*.tmp'), [d / 'x.tmp', d / 'xdir.tmp']) + self.assertList(d.files('*.tmp'), [d / 'x.tmp']) + self.assertList(d.dirs('*.tmp'), [d / 'xdir.tmp']) self.assertList(d.walk(), [e for e in dirs - if e != d] + [e/n for e in dirs + if e != d] + [e / n for e in dirs for n in names]) self.assertList(d.walk('*.tmp'), - [e/'x.tmp' for e in dirs] + [d/'xdir.tmp']) - self.assertList(d.walkfiles('*.tmp'), [e/'x.tmp' for e in dirs]) - self.assertList(d.walkdirs('*.tmp'), [d/'xdir.tmp']) + [e / 'x.tmp' for e in dirs] + [d / 'xdir.tmp']) + self.assertList(d.walkfiles('*.tmp'), [e / 'x.tmp' for e in dirs]) + self.assertList(d.walkdirs('*.tmp'), [d / 'xdir.tmp']) def test_unicode(self, tmpdir): d = Path(tmpdir) - p = d/'unicode.txt' + p = d / 'unicode.txt' def test(enc): """ Test that path works with the specified encoding, @@ -527,18 +596,22 @@ class TestScratchDir: 
Unicode codepoints. """ - given = ('Hello world\n' - '\u0d0a\u0a0d\u0d15\u0a15\r\n' - '\u0d0a\u0a0d\u0d15\u0a15\x85' - '\u0d0a\u0a0d\u0d15\u0a15\u2028' - '\r' - 'hanging') - clean = ('Hello world\n' - '\u0d0a\u0a0d\u0d15\u0a15\n' - '\u0d0a\u0a0d\u0d15\u0a15\n' - '\u0d0a\u0a0d\u0d15\u0a15\n' - '\n' - 'hanging') + given = ( + 'Hello world\n' + '\u0d0a\u0a0d\u0d15\u0a15\r\n' + '\u0d0a\u0a0d\u0d15\u0a15\x85' + '\u0d0a\u0a0d\u0d15\u0a15\u2028' + '\r' + 'hanging' + ) + clean = ( + 'Hello world\n' + '\u0d0a\u0a0d\u0d15\u0a15\n' + '\u0d0a\u0a0d\u0d15\u0a15\n' + '\u0d0a\u0a0d\u0d15\u0a15\n' + '\n' + 'hanging' + ) givenLines = [ ('Hello world\n'), ('\u0d0a\u0a0d\u0d15\u0a15\r\n'), @@ -581,8 +654,9 @@ class TestScratchDir: return # Write Unicode to file using path.write_text(). - cleanNoHanging = clean + '\n' # This test doesn't work with a - # hanging line. + # This test doesn't work with a hanging line. + cleanNoHanging = clean + '\n' + p.write_text(cleanNoHanging, enc) p.write_text(cleanNoHanging, enc, append=True) # Check the result. 
@@ -641,7 +715,7 @@ class TestScratchDir: test('UTF-16') def test_chunks(self, tmpdir): - p = (tempdir() / 'test.txt').touch() + p = (TempDir() / 'test.txt').touch() txt = "0123456789" size = 5 p.write_text(txt) @@ -650,16 +724,18 @@ class TestScratchDir: assert i == len(txt) / size - 1 - @pytest.mark.skipif(not hasattr(os.path, 'samefile'), - reason="samefile not present") + @pytest.mark.skipif( + not hasattr(os.path, 'samefile'), + reason="samefile not present", + ) def test_samefile(self, tmpdir): - f1 = (tempdir() / '1.txt').touch() + f1 = (TempDir() / '1.txt').touch() f1.write_text('foo') - f2 = (tempdir() / '2.txt').touch() + f2 = (TempDir() / '2.txt').touch() f1.write_text('foo') - f3 = (tempdir() / '3.txt').touch() + f3 = (TempDir() / '3.txt').touch() f1.write_text('bar') - f4 = (tempdir() / '4.txt') + f4 = (TempDir() / '4.txt') f1.copyfile(f4) assert os.path.samefile(f1, f2) == f1.samefile(f2) @@ -680,6 +756,26 @@ class TestScratchDir: self.fail("Calling `rmtree_p` on non-existent directory " "should not raise an exception.") + def test_rmdir_p_exists(self, tmpdir): + """ + Invocation of rmdir_p on an existant directory should + remove the directory. + """ + d = Path(tmpdir) + sub = d / 'subfolder' + sub.mkdir() + sub.rmdir_p() + assert not sub.exists() + + def test_rmdir_p_nonexistent(self, tmpdir): + """ + A non-existent file should not raise an exception. 
+ """ + d = Path(tmpdir) + sub = d / 'subfolder' + assert not sub.exists() + sub.rmdir_p() + class TestMergeTree: @pytest.fixture(autouse=True) @@ -701,6 +797,11 @@ class TestMergeTree: else: self.test_file.copy(self.test_link) + def check_link(self): + target = Path(self.subdir_b / self.test_link.name) + check = target.islink if hasattr(os, 'symlink') else target.isfile + assert check() + def test_with_nonexisting_dst_kwargs(self): self.subdir_a.merge_tree(self.subdir_b, symlinks=True) assert self.subdir_b.isdir() @@ -709,7 +810,7 @@ class TestMergeTree: self.subdir_b / self.test_link.name, )) assert set(self.subdir_b.listdir()) == expected - assert Path(self.subdir_b / self.test_link.name).islink() + self.check_link() def test_with_nonexisting_dst_args(self): self.subdir_a.merge_tree(self.subdir_b, True) @@ -719,7 +820,7 @@ class TestMergeTree: self.subdir_b / self.test_link.name, )) assert set(self.subdir_b.listdir()) == expected - assert Path(self.subdir_b / self.test_link.name).islink() + self.check_link() def test_with_existing_dst(self): self.subdir_b.rmtree() @@ -740,7 +841,7 @@ class TestMergeTree: self.subdir_b / test_new.name, )) assert set(self.subdir_b.listdir()) == expected - assert Path(self.subdir_b / self.test_link.name).islink() + self.check_link() assert len(Path(self.subdir_b / self.test_file.name).bytes()) == 5000 def test_copytree_parameters(self): @@ -753,6 +854,20 @@ class TestMergeTree: assert self.subdir_b.isdir() assert self.subdir_b.listdir() == [self.subdir_b / self.test_file.name] + def test_only_newer(self): + """ + merge_tree should accept a copy_function in which only + newer files are copied and older files do not overwrite + newer copies in the dest. 
+ """ + target = self.subdir_b / 'testfile.txt' + target.write_text('this is newer') + self.subdir_a.merge_tree( + self.subdir_b, + copy_function=path.only_newer(shutil.copy2), + ) + assert target.text() == 'this is newer' + class TestChdir: def test_chdir_or_cd(self, tmpdir): @@ -781,17 +896,17 @@ class TestChdir: class TestSubclass: - class PathSubclass(Path): - pass def test_subclass_produces_same_class(self): """ When operations are invoked on a subclass, they should produce another instance of that subclass. """ - p = self.PathSubclass('/foo') + class PathSubclass(Path): + pass + p = PathSubclass('/foo') subdir = p / 'bar' - assert isinstance(subdir, self.PathSubclass) + assert isinstance(subdir, PathSubclass) class TestTempDir: @@ -800,8 +915,8 @@ class TestTempDir: """ One should be able to readily construct a temporary directory """ - d = tempdir() - assert isinstance(d, Path) + d = TempDir() + assert isinstance(d, path.Path) assert d.exists() assert d.isdir() d.rmdir() @@ -809,24 +924,24 @@ class TestTempDir: def test_next_class(self): """ - It should be possible to invoke operations on a tempdir and get + It should be possible to invoke operations on a TempDir and get Path classes. """ - d = tempdir() + d = TempDir() sub = d / 'subdir' - assert isinstance(sub, Path) + assert isinstance(sub, path.Path) d.rmdir() def test_context_manager(self): """ - One should be able to use a tempdir object as a context, which will + One should be able to use a TempDir object as a context, which will clean up the contents after. """ - d = tempdir() + d = TempDir() res = d.__enter__() - assert res is d + assert res == path.Path(d) (d / 'somefile.txt').touch() - assert not isinstance(d / 'somefile.txt', tempdir) + assert not isinstance(d / 'somefile.txt', TempDir) d.__exit__(None, None, None) assert not d.exists() @@ -834,10 +949,10 @@ class TestTempDir: """ The context manager will not clean up if an exception occurs. 
""" - d = tempdir() + d = TempDir() d.__enter__() (d / 'somefile.txt').touch() - assert not isinstance(d / 'somefile.txt', tempdir) + assert not isinstance(d / 'somefile.txt', TempDir) d.__exit__(TypeError, TypeError('foo'), None) assert d.exists() @@ -847,7 +962,7 @@ class TestTempDir: provide a temporry directory that will be deleted after that. """ - with tempdir() as d: + with TempDir() as d: assert d.isdir() assert not d.isdir() @@ -876,7 +991,8 @@ class TestPatternMatching: assert p.fnmatch('FOO[ABC]AR') def test_fnmatch_custom_normcase(self): - normcase = lambda path: path.upper() + def normcase(path): + return path.upper() p = Path('FooBar') assert p.fnmatch('foobar', normcase=normcase) assert p.fnmatch('FOO[ABC]AR', normcase=normcase) @@ -891,8 +1007,8 @@ class TestPatternMatching: def test_listdir_patterns(self, tmpdir): p = Path(tmpdir) - (p/'sub').mkdir() - (p/'File').touch() + (p / 'sub').mkdir() + (p / 'File').touch() assert p.listdir('s*') == [p / 'sub'] assert len(p.listdir('*')) == 2 @@ -903,14 +1019,14 @@ class TestPatternMatching: """ always_unix = Path.using_module(posixpath) p = always_unix(tmpdir) - (p/'sub').mkdir() - (p/'File').touch() + (p / 'sub').mkdir() + (p / 'File').touch() assert p.listdir('S*') == [] always_win = Path.using_module(ntpath) p = always_win(tmpdir) - assert p.listdir('S*') == [p/'sub'] - assert p.listdir('f*') == [p/'File'] + assert p.listdir('S*') == [p / 'sub'] + assert p.listdir('f*') == [p / 'File'] def test_listdir_case_insensitive(self, tmpdir): """ @@ -918,27 +1034,30 @@ class TestPatternMatching: used by that Path class. 
""" p = Path(tmpdir) - (p/'sub').mkdir() - (p/'File').touch() - assert p.listdir(ci('S*')) == [p/'sub'] - assert p.listdir(ci('f*')) == [p/'File'] - assert p.files(ci('S*')) == [] - assert p.dirs(ci('f*')) == [] + (p / 'sub').mkdir() + (p / 'File').touch() + assert p.listdir(matchers.CaseInsensitive('S*')) == [p / 'sub'] + assert p.listdir(matchers.CaseInsensitive('f*')) == [p / 'File'] + assert p.files(matchers.CaseInsensitive('S*')) == [] + assert p.dirs(matchers.CaseInsensitive('f*')) == [] def test_walk_case_insensitive(self, tmpdir): p = Path(tmpdir) - (p/'sub1'/'foo').makedirs_p() - (p/'sub2'/'foo').makedirs_p() - (p/'sub1'/'foo'/'bar.Txt').touch() - (p/'sub2'/'foo'/'bar.TXT').touch() - (p/'sub2'/'foo'/'bar.txt.bz2').touch() - files = list(p.walkfiles(ci('*.txt'))) + (p / 'sub1' / 'foo').makedirs_p() + (p / 'sub2' / 'foo').makedirs_p() + (p / 'sub1' / 'foo' / 'bar.Txt').touch() + (p / 'sub2' / 'foo' / 'bar.TXT').touch() + (p / 'sub2' / 'foo' / 'bar.txt.bz2').touch() + files = list(p.walkfiles(matchers.CaseInsensitive('*.txt'))) assert len(files) == 2 - assert p/'sub2'/'foo'/'bar.TXT' in files - assert p/'sub1'/'foo'/'bar.Txt' in files + assert p / 'sub2' / 'foo' / 'bar.TXT' in files + assert p / 'sub1' / 'foo' / 'bar.Txt' in files -@pytest.mark.skipif(sys.version_info < (2, 6), - reason="in_place requires io module in Python 2.6") + +@pytest.mark.skipif( + sys.version_info < (2, 6), + reason="in_place requires io module in Python 2.6", +) class TestInPlace: reference_content = textwrap.dedent(""" The quick brown fox jumped over the lazy dog. 
@@ -959,7 +1078,7 @@ class TestInPlace: @classmethod def create_reference(cls, tmpdir): - p = Path(tmpdir)/'document' + p = Path(tmpdir) / 'document' with p.open('w') as stream: stream.write(cls.reference_content) return p @@ -984,7 +1103,7 @@ class TestInPlace: assert "some error" in str(exc) with doc.open() as stream: data = stream.read() - assert not 'Lorem' in data + assert 'Lorem' not in data assert 'lazy dog' in data @@ -1023,8 +1142,9 @@ class TestSpecialPaths: def test_unix_paths_fallback(self, tmpdir, monkeypatch, feign_linux): "Without XDG_CONFIG_HOME set, ~/.config should be used." fake_home = tmpdir / '_home' + monkeypatch.delitem(os.environ, 'XDG_CONFIG_HOME', raising=False) monkeypatch.setitem(os.environ, 'HOME', str(fake_home)) - expected = str(tmpdir / '_home' / '.config') + expected = Path('~/.config').expanduser() assert SpecialResolver(Path).user.config == expected def test_property(self): @@ -1075,7 +1195,8 @@ class TestMultiPath: cls = Multi.for_class(Path) assert issubclass(cls, Path) assert issubclass(cls, Multi) - assert cls.__name__ == 'MultiPath' + expected_name = 'Multi' + Path.__name__ + assert cls.__name__ == expected_name def test_detect_no_pathsep(self): """ @@ -1115,5 +1236,23 @@ class TestMultiPath: assert path == input -if __name__ == '__main__': - pytest.main() +@pytest.mark.xfail('path.PY2', reason="Python 2 has no __future__") +def test_no_dependencies(): + """ + Path.py guarantees that the path module can be + transplanted into an environment without any dependencies. + """ + cmd = [ + sys.executable, + '-S', + '-c', 'import path', + ] + subprocess.check_call(cmd) + + +def test_version(): + """ + Under normal circumstances, path should present a + __version__. + """ + assert re.match(r'\d+\.\d+.*', path.__version__)