Merge pull request #1417 from clinton-hall/libs/jaraco

Update jaraco-windows to 3.9.2
This commit is contained in:
Labrys of Knossos 2018-12-15 15:58:46 -05:00 committed by GitHub
commit 76763e4b76
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
92 changed files with 7515 additions and 996 deletions

BIN
libs/bin/enver.exe Normal file

Binary file not shown.

BIN
libs/bin/find-symlinks.exe Normal file

Binary file not shown.

BIN
libs/bin/gclip.exe Normal file

Binary file not shown.

BIN
libs/bin/mklink.exe Normal file

Binary file not shown.

BIN
libs/bin/pclip.exe Normal file

Binary file not shown.

BIN
libs/bin/xmouse.exe Normal file

Binary file not shown.

View file

@ -0,0 +1,17 @@
from .api import distribution, Distribution, PackageNotFoundError # noqa: F401
from .api import metadata, entry_points, resolve, version, read_text
# Import for installation side-effects.
from . import _hooks # noqa: F401
__all__ = [
'metadata',
'entry_points',
'resolve',
'version',
'read_text',
]
__version__ = version(__name__)

View file

@ -0,0 +1,148 @@
from __future__ import unicode_literals, absolute_import
import re
import sys
import itertools
from .api import Distribution
from zipfile import ZipFile
if sys.version_info >= (3,): # pragma: nocover
from contextlib import suppress
from pathlib import Path
else: # pragma: nocover
from contextlib2 import suppress # noqa
from itertools import imap as map # type: ignore
from pathlib2 import Path
FileNotFoundError = IOError, OSError
__metaclass__ = type
def install(cls):
"""Class decorator for installation on sys.meta_path."""
sys.meta_path.append(cls)
return cls
class NullFinder:
@staticmethod
def find_spec(*args, **kwargs):
return None
# In Python 2, the import system requires finders
# to have a find_module() method, but this usage
# is deprecated in Python 3 in favor of find_spec().
# For the purposes of this finder (i.e. being present
# on sys.meta_path but having no other import
# system functionality), the two methods are identical.
find_module = find_spec
@install
class MetadataPathFinder(NullFinder):
"""A degenerate finder for distribution packages on the file system.
This finder supplies only a find_distribution() method for versions
of Python that do not have a PathFinder find_distribution().
"""
search_template = r'{name}(-.*)?\.(dist|egg)-info'
@classmethod
def find_distribution(cls, name):
paths = cls._search_paths(name)
dists = map(PathDistribution, paths)
return next(dists, None)
@classmethod
def _search_paths(cls, name):
"""
Find metadata directories in sys.path heuristically.
"""
return itertools.chain.from_iterable(
cls._search_path(path, name)
for path in map(Path, sys.path)
)
@classmethod
def _search_path(cls, root, name):
if not root.is_dir():
return ()
normalized = name.replace('-', '_')
return (
item
for item in root.iterdir()
if item.is_dir()
and re.match(
cls.search_template.format(name=normalized),
str(item.name),
flags=re.IGNORECASE,
)
)
class PathDistribution(Distribution):
def __init__(self, path):
"""Construct a distribution from a path to the metadata directory."""
self._path = path
def read_text(self, filename):
with suppress(FileNotFoundError):
with self._path.joinpath(filename).open(encoding='utf-8') as fp:
return fp.read()
return None
read_text.__doc__ = Distribution.read_text.__doc__
@install
class WheelMetadataFinder(NullFinder):
"""A degenerate finder for distribution packages in wheels.
This finder supplies only a find_distribution() method for versions
of Python that do not have a PathFinder find_distribution().
"""
search_template = r'{name}(-.*)?\.whl'
@classmethod
def find_distribution(cls, name):
paths = cls._search_paths(name)
dists = map(WheelDistribution, paths)
return next(dists, None)
@classmethod
def _search_paths(cls, name):
return (
item
for item in map(Path, sys.path)
if re.match(
cls.search_template.format(name=name),
str(item.name),
flags=re.IGNORECASE,
)
)
class WheelDistribution(Distribution):
def __init__(self, archive):
self._archive = archive
name, version = archive.name.split('-')[0:2]
self._dist_info = '{}-{}.dist-info'.format(name, version)
def read_text(self, filename):
with ZipFile(_path_to_filename(self._archive)) as zf:
with suppress(KeyError):
as_bytes = zf.read('{}/{}'.format(self._dist_info, filename))
return as_bytes.decode('utf-8')
return None
read_text.__doc__ = Distribution.read_text.__doc__
def _path_to_filename(path): # pragma: nocover
"""
On non-compliant systems, ensure a path-like object is
a string.
"""
try:
return path.__fspath__()
except AttributeError:
return str(path)

View file

@ -0,0 +1,146 @@
import io
import abc
import sys
import email
from importlib import import_module
if sys.version_info > (3,): # pragma: nocover
from configparser import ConfigParser
else: # pragma: nocover
from ConfigParser import SafeConfigParser as ConfigParser
try:
BaseClass = ModuleNotFoundError
except NameError: # pragma: nocover
BaseClass = ImportError # type: ignore
__metaclass__ = type
class PackageNotFoundError(BaseClass):
"""The package was not found."""
class Distribution:
"""A Python distribution package."""
@abc.abstractmethod
def read_text(self, filename):
"""Attempt to load metadata file given by the name.
:param filename: The name of the file in the distribution info.
:return: The text if found, otherwise None.
"""
@classmethod
def from_name(cls, name):
"""Return the Distribution for the given package name.
:param name: The name of the distribution package to search for.
:return: The Distribution instance (or subclass thereof) for the named
package, if found.
:raises PackageNotFoundError: When the named package's distribution
metadata cannot be found.
"""
for resolver in cls._discover_resolvers():
resolved = resolver(name)
if resolved is not None:
return resolved
else:
raise PackageNotFoundError(name)
@staticmethod
def _discover_resolvers():
"""Search the meta_path for resolvers."""
declared = (
getattr(finder, 'find_distribution', None)
for finder in sys.meta_path
)
return filter(None, declared)
@property
def metadata(self):
"""Return the parsed metadata for this Distribution.
The returned object will have keys that name the various bits of
metadata. See PEP 566 for details.
"""
return email.message_from_string(
self.read_text('METADATA') or self.read_text('PKG-INFO')
)
@property
def version(self):
"""Return the 'Version' metadata for the distribution package."""
return self.metadata['Version']
def distribution(package):
"""Get the ``Distribution`` instance for the given package.
:param package: The name of the package as a string.
:return: A ``Distribution`` instance (or subclass thereof).
"""
return Distribution.from_name(package)
def metadata(package):
"""Get the metadata for the package.
:param package: The name of the distribution package to query.
:return: An email.Message containing the parsed metadata.
"""
return Distribution.from_name(package).metadata
def version(package):
"""Get the version string for the named package.
:param package: The name of the distribution package to query.
:return: The version string for the package as defined in the package's
"Version" metadata key.
"""
return distribution(package).version
def entry_points(name):
"""Return the entry points for the named distribution package.
:param name: The name of the distribution package to query.
:return: A ConfigParser instance where the sections and keys are taken
from the entry_points.txt ini-style contents.
"""
as_string = read_text(name, 'entry_points.txt')
# 2018-09-10(barry): Should we provide any options here, or let the caller
# send options to the underlying ConfigParser? For now, YAGNI.
config = ConfigParser()
try:
config.read_string(as_string)
except AttributeError: # pragma: nocover
# Python 2 has no read_string
config.readfp(io.StringIO(as_string))
return config
def resolve(entry_point):
"""Resolve an entry point string into the named callable.
:param entry_point: An entry point string of the form
`path.to.module:callable`.
:return: The actual callable object `path.to.module.callable`
:raises ValueError: When `entry_point` doesn't have the proper format.
"""
path, colon, name = entry_point.rpartition(':')
if colon != ':':
raise ValueError('Not an entry point: {}'.format(entry_point))
module = import_module(path)
return getattr(module, name)
def read_text(package, filename):
"""
Read the text of the file in the distribution info directory.
"""
return distribution(package).read_text(filename)

View file

View file

@ -0,0 +1,57 @@
=========================
importlib_metadata NEWS
=========================
0.7 (2018-11-27)
================
* Fixed issue where packages with dashes in their names would
not be discovered. Closes #21.
* Distribution lookup is now case-insensitive. Closes #20.
* Wheel distributions can no longer be discovered by their module
name. Like Path distributions, they must be indicated by their
distribution package name.
0.6 (2018-10-07)
================
* Removed ``importlib_metadata.distribution`` function. Now
the public interface is primarily the utility functions exposed
in ``importlib_metadata.__all__``. Closes #14.
* Added two new utility functions ``read_text`` and
``metadata``.
0.5 (2018-09-18)
================
* Updated README and removed details about Distribution
class, now considered private. Closes #15.
* Added test suite support for Python 3.4+.
* Fixed SyntaxErrors on Python 3.4 and 3.5. !12
* Fixed errors on Windows joining Path elements. !15
0.4 (2018-09-14)
================
* Housekeeping.
0.3 (2018-09-14)
================
* Added usage documentation. Closes #8
* Add support for getting metadata from wheels on ``sys.path``. Closes #9
0.2 (2018-09-11)
================
* Added ``importlib_metadata.entry_points()``. Closes #1
* Added ``importlib_metadata.resolve()``. Closes #12
* Add support for Python 2.7. Closes #4
0.1 (2018-09-10)
================
* Initial release.
..
Local Variables:
mode: change-log-mode
indent-tabs-mode: nil
sentence-end-double-space: t
fill-column: 78
coding: utf-8
End:

View file

@ -0,0 +1,180 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# flake8: noqa
#
# importlib_metadata documentation build configuration file, created by
# sphinx-quickstart on Thu Nov 30 10:21:00 2017.
#
# This file is execfile()d with the current directory set to its
# containing dir.
#
# Note that not all possible configuration values are present in this
# autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ['sphinx.ext.autodoc',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.coverage',
'sphinx.ext.viewcode']
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
# The master toctree document.
master_doc = 'index'
# General information about the project.
project = 'importlib_metadata'
copyright = '2017-2018, Jason Coombs, Barry Warsaw'
author = 'Jason Coombs, Barry Warsaw'
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
version = '0.1'
# The full version, including alpha/beta/rc tags.
release = '0.1'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False
# -- Options for HTML output ----------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'default'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# This is required for the alabaster theme
# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
html_sidebars = {
'**': [
'relations.html', # needs 'show_related': True theme option to display
'searchbox.html',
]
}
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'importlib_metadatadoc'
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'importlib_metadata.tex', 'importlib\\_metadata Documentation',
'Brett Cannon, Barry Warsaw', 'manual'),
]
# -- Options for manual page output ---------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'importlib_metadata', 'importlib_metadata Documentation',
[author], 1)
]
# -- Options for Texinfo output -------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'importlib_metadata', 'importlib_metadata Documentation',
author, 'importlib_metadata', 'One line description of project.',
'Miscellaneous'),
]
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
'python': ('https://docs.python.org/3', None),
}

View file

@ -0,0 +1,53 @@
===============================
Welcome to importlib_metadata
===============================
``importlib_metadata`` is a library which provides an API for accessing an
installed package's `metadata`_, such as its entry points or its top-level
name. This functionality intends to replace most uses of ``pkg_resources``
`entry point API`_ and `metadata API`_. Along with ``importlib.resources`` in
`Python 3.7 and newer`_ (backported as `importlib_resources`_ for older
versions of Python), this can eliminate the need to use the older and less
efficient ``pkg_resources`` package.
``importlib_metadata`` is a backport of Python 3.8's standard library
`importlib.metadata`_ module for Python 2.7, and 3.4 through 3.7. Users of
Python 3.8 and beyond are encouraged to use the standard library module, and
in fact for these versions, ``importlib_metadata`` just shadows that module.
Developers looking for detailed API descriptions should refer to the Python
3.8 standard library documentation.
The documentation here includes a general :ref:`usage <using>` guide.
.. toctree::
:maxdepth: 2
:caption: Contents:
using.rst
changelog.rst
Project details
===============
* Project home: https://gitlab.com/python-devs/importlib_metadata
* Report bugs at: https://gitlab.com/python-devs/importlib_metadata/issues
* Code hosting: https://gitlab.com/python-devs/importlib_metadata.git
* Documentation: http://importlib_metadata.readthedocs.io/
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
.. _`metadata`: https://www.python.org/dev/peps/pep-0566/
.. _`entry point API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#entry-points
.. _`metadata API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#metadata-api
.. _`Python 3.7 and newer`: https://docs.python.org/3/library/importlib.html#module-importlib.resources
.. _`importlib_resources`: https://importlib-resources.readthedocs.io/en/latest/index.html
.. _`importlib.metadata`: TBD

View file

@ -0,0 +1,133 @@
.. _using:
==========================
Using importlib_metadata
==========================
``importlib_metadata`` is a library that provides for access to installed
package metadata. Built in part on Python's import system, this library
intends to replace similar functionality in ``pkg_resources`` `entry point
API`_ and `metadata API`_. Along with ``importlib.resources`` in `Python 3.7
and newer`_ (backported as `importlib_resources`_ for older versions of
Python), this can eliminate the need to use the older and less efficient
``pkg_resources`` package.
By "installed package" we generally mean a third party package installed into
Python's ``site-packages`` directory via tools such as ``pip``. Specifically,
it means a package with either a discoverable ``dist-info`` or ``egg-info``
directory, and metadata defined by `PEP 566`_ or its older specifications.
By default, package metadata can live on the file system or in wheels on
``sys.path``. Through an extension mechanism, the metadata can live almost
anywhere.
Overview
========
Let's say you wanted to get the version string for a package you've installed
using ``pip``. We start by creating a virtual environment and installing
something into it::
$ python3 -m venv example
$ source example/bin/activate
(example) $ pip install importlib_metadata
(example) $ pip install wheel
You can get the version string for ``wheel`` by running the following::
(example) $ python
>>> from importlib_metadata import version
>>> version('wheel')
'0.31.1'
You can also get the set of entry points for the ``wheel`` package. Since the
``entry_points.txt`` file is an ``.ini``-style, the ``entry_points()``
function returns a `ConfigParser instance`_. To get the list of command line
entry points, extract the ``console_scripts`` section::
>>> cp = entry_points('wheel')
>>> cp.options('console_scripts')
['wheel']
You can also get the callable that the entry point is mapped to::
>>> cp.get('console_scripts', 'wheel')
'wheel.tool:main'
Even more conveniently, you can resolve this entry point to the actual
callable::
>>> from importlib_metadata import resolve
>>> ep = cp.get('console_scripts', 'wheel')
>>> resolve(ep)
<function main at 0x111b91bf8>
Distributions
=============
While the above API is the most common and convenient usage, you can get all
of that information from the ``Distribution`` class. A ``Distribution`` is an
abstract object that represents the metadata for a Python package. You can
get the ``Distribution`` instance::
>>> from importlib_metadata import distribution
>>> dist = distribution('wheel')
Thus, an alternative way to get the version number is through the
``Distribution`` instance::
>>> dist.version
'0.31.1'
There are all kinds of additional metadata available on the ``Distribution``
instance::
>>> d.metadata['Requires-Python']
'>=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*'
>>> d.metadata['License']
'MIT'
The full set of available metadata is not described here. See PEP 566 for
additional details.
Extending the search algorithm
==============================
Because package metadata is not available through ``sys.path`` searches, or
package loaders directly, the metadata for a package is found through import
system `finders`_. To find a distribution package's metadata,
``importlib_metadata`` queries the list of `meta path finders`_ on
`sys.meta_path`_.
By default ``importlib_metadata`` installs a finder for packages found on the
file system. This finder doesn't actually find any *packages*, but it cany
find the package's metadata.
The abstract class :py:class:`importlib.abc.MetaPathFinder` defines the
interface expected of finders by Python's import system.
``importlib_metadata`` extends this protocol by looking for an optional
``find_distribution()`` ``@classmethod`` on the finders from
``sys.meta_path``. If the finder has this method, it takes a single argument
which is the name of the distribution package to find. The method returns
``None`` if it cannot find the distribution package, otherwise it returns an
instance of the ``Distribution`` abstract class.
What this means in practice is that to support finding distribution package
metadata in locations other than the file system, you should derive from
``Distribution`` and implement the ``load_metadata()`` method. This takes a
single argument which is the name of the package whose metadata is being
found. This instance of the ``Distribution`` base abstract class is what your
finder's ``find_distribution()`` method should return.
.. _`entry point API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#entry-points
.. _`metadata API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#metadata-api
.. _`Python 3.7 and newer`: https://docs.python.org/3/library/importlib.html#module-importlib.resources
.. _`importlib_resources`: https://importlib-resources.readthedocs.io/en/latest/index.html
.. _`PEP 566`: https://www.python.org/dev/peps/pep-0566/
.. _`ConfigParser instance`: https://docs.python.org/3/library/configparser.html#configparser.ConfigParser
.. _`finders`: https://docs.python.org/3/reference/import.html#finders-and-loaders
.. _`meta path finders`: https://docs.python.org/3/glossary.html#term-meta-path-finder
.. _`sys.meta_path`: https://docs.python.org/3/library/sys.html#sys.meta_path

View file

@ -0,0 +1,44 @@
import re
import unittest
import importlib_metadata
class APITests(unittest.TestCase):
version_pattern = r'\d+\.\d+(\.\d)?'
def test_retrieves_version_of_self(self):
version = importlib_metadata.version('importlib_metadata')
assert isinstance(version, str)
assert re.match(self.version_pattern, version)
def test_retrieves_version_of_pip(self):
# Assume pip is installed and retrieve the version of pip.
version = importlib_metadata.version('pip')
assert isinstance(version, str)
assert re.match(self.version_pattern, version)
def test_for_name_does_not_exist(self):
with self.assertRaises(importlib_metadata.PackageNotFoundError):
importlib_metadata.distribution('does-not-exist')
def test_for_top_level(self):
distribution = importlib_metadata.distribution('importlib_metadata')
self.assertEqual(
distribution.read_text('top_level.txt').strip(),
'importlib_metadata')
def test_entry_points(self):
parser = importlib_metadata.entry_points('pip')
# We should probably not be dependent on a third party package's
# internal API staying stable.
entry_point = parser.get('console_scripts', 'pip')
self.assertEqual(entry_point, 'pip._internal:main')
def test_metadata_for_this_package(self):
md = importlib_metadata.metadata('importlib_metadata')
assert md['author'] == 'Barry Warsaw'
assert md['LICENSE'] == 'Apache Software License'
assert md['Name'] == 'importlib-metadata'
classifiers = md.get_all('Classifier')
assert 'Topic :: Software Development :: Libraries' in classifiers

View file

@ -0,0 +1,121 @@
from __future__ import unicode_literals
import re
import sys
import shutil
import tempfile
import unittest
import importlib
import contextlib
import importlib_metadata
try:
from contextlib import ExitStack
except ImportError:
from contextlib2 import ExitStack
try:
import pathlib
except ImportError:
import pathlib2 as pathlib
from importlib_metadata import _hooks
class BasicTests(unittest.TestCase):
version_pattern = r'\d+\.\d+(\.\d)?'
def test_retrieves_version_of_pip(self):
# Assume pip is installed and retrieve the version of pip.
dist = importlib_metadata.Distribution.from_name('pip')
assert isinstance(dist.version, str)
assert re.match(self.version_pattern, dist.version)
def test_for_name_does_not_exist(self):
with self.assertRaises(importlib_metadata.PackageNotFoundError):
importlib_metadata.Distribution.from_name('does-not-exist')
def test_new_style_classes(self):
self.assertIsInstance(importlib_metadata.Distribution, type)
self.assertIsInstance(_hooks.MetadataPathFinder, type)
self.assertIsInstance(_hooks.WheelMetadataFinder, type)
self.assertIsInstance(_hooks.WheelDistribution, type)
class ImportTests(unittest.TestCase):
def test_import_nonexistent_module(self):
# Ensure that the MetadataPathFinder does not crash an import of a
# non-existant module.
with self.assertRaises(ImportError):
importlib.import_module('does_not_exist')
def test_resolve(self):
entry_points = importlib_metadata.entry_points('pip')
main = importlib_metadata.resolve(
entry_points.get('console_scripts', 'pip'))
import pip._internal
self.assertEqual(main, pip._internal.main)
def test_resolve_invalid(self):
self.assertRaises(ValueError, importlib_metadata.resolve, 'bogus.ep')
class NameNormalizationTests(unittest.TestCase):
@staticmethod
def pkg_with_dashes(site_dir):
"""
Create minimal metadata for a package with dashes
in the name (and thus underscores in the filename).
"""
metadata_dir = site_dir / 'my_pkg.dist-info'
metadata_dir.mkdir()
metadata = metadata_dir / 'METADATA'
with metadata.open('w') as strm:
strm.write('Version: 1.0\n')
return 'my-pkg'
@staticmethod
@contextlib.contextmanager
def site_dir():
tmpdir = tempfile.mkdtemp()
sys.path[:0] = [tmpdir]
try:
yield pathlib.Path(tmpdir)
finally:
sys.path.remove(tmpdir)
shutil.rmtree(tmpdir)
def setUp(self):
self.fixtures = ExitStack()
self.addCleanup(self.fixtures.close)
self.site_dir = self.fixtures.enter_context(self.site_dir())
def test_dashes_in_dist_name_found_as_underscores(self):
"""
For a package with a dash in the name, the dist-info metadata
uses underscores in the name. Ensure the metadata loads.
"""
pkg_name = self.pkg_with_dashes(self.site_dir)
assert importlib_metadata.version(pkg_name) == '1.0'
@staticmethod
def pkg_with_mixed_case(site_dir):
"""
Create minimal metadata for a package with mixed case
in the name.
"""
metadata_dir = site_dir / 'CherryPy.dist-info'
metadata_dir.mkdir()
metadata = metadata_dir / 'METADATA'
with metadata.open('w') as strm:
strm.write('Version: 1.0\n')
return 'CherryPy'
def test_dist_name_found_as_any_case(self):
"""
Ensure the metadata loads when queried with any case.
"""
pkg_name = self.pkg_with_mixed_case(self.site_dir)
assert importlib_metadata.version(pkg_name) == '1.0'
assert importlib_metadata.version(pkg_name.lower()) == '1.0'
assert importlib_metadata.version(pkg_name.upper()) == '1.0'

View file

@ -0,0 +1,42 @@
import sys
import unittest
import importlib_metadata
try:
from contextlib import ExitStack
except ImportError:
from contextlib2 import ExitStack
from importlib_resources import path
class BespokeLoader:
archive = 'bespoke'
class TestZip(unittest.TestCase):
def setUp(self):
# Find the path to the example.*.whl so we can add it to the front of
# sys.path, where we'll then try to find the metadata thereof.
self.resources = ExitStack()
self.addCleanup(self.resources.close)
wheel = self.resources.enter_context(
path('importlib_metadata.tests.data',
'example-21.12-py3-none-any.whl'))
sys.path.insert(0, str(wheel))
self.resources.callback(sys.path.pop, 0)
def test_zip_version(self):
self.assertEqual(importlib_metadata.version('example'), '21.12')
def test_zip_entry_points(self):
parser = importlib_metadata.entry_points('example')
entry_point = parser.get('console_scripts', 'example')
self.assertEqual(entry_point, 'example:main')
def test_missing_metadata(self):
distribution = importlib_metadata.distribution('example')
self.assertIsNone(distribution.read_text('does not exist'))
def test_case_insensitive(self):
self.assertEqual(importlib_metadata.version('Example'), '21.12')

View file

@ -0,0 +1 @@
0.7

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -1 +0,0 @@
import sys, types, os;p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));ie = os.path.exists(os.path.join(p,'__init__.py'));m = not ie and sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -1 +0,0 @@
import sys, types, os;p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));ie = os.path.exists(os.path.join(p,'__init__.py'));m = not ie and sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -1 +0,0 @@
import sys, types, os;p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));ie = os.path.exists(os.path.join(p,'__init__.py'));m = not ie and sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -1 +0,0 @@
import sys, types, os;p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));ie = os.path.exists(os.path.join(p,'__init__.py'));m = not ie and sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -1 +0,0 @@
__import__("pkg_resources").declare_namespace(__name__)

View file

@ -5,6 +5,7 @@ of an object and its parent classes.
from __future__ import unicode_literals
def all_bases(c):
"""
return a tuple of all base classes the class c has as a parent.
@ -13,6 +14,7 @@ def all_bases(c):
"""
return c.mro()[1:]
def all_classes(c):
"""
return a tuple of all classes to which c belongs
@ -21,7 +23,10 @@ def all_classes(c):
"""
return c.mro()
# borrowed from http://code.activestate.com/recipes/576949-find-all-subclasses-of-a-given-class/
# borrowed from
# http://code.activestate.com/recipes/576949-find-all-subclasses-of-a-given-class/
def iter_subclasses(cls, _seen=None):
"""
Generator over all subclasses of a given class, in depth-first order.
@ -51,9 +56,12 @@ def iter_subclasses(cls, _seen=None):
"""
if not isinstance(cls, type):
raise TypeError('iter_subclasses must be called with '
'new-style classes, not %.100r' % cls)
if _seen is None: _seen = set()
raise TypeError(
'iter_subclasses must be called with '
'new-style classes, not %.100r' % cls
)
if _seen is None:
_seen = set()
try:
subs = cls.__subclasses__()
except TypeError: # fails only when cls is type

View file

@ -6,6 +6,7 @@ Some useful metaclasses.
from __future__ import unicode_literals
class LeafClassesMeta(type):
"""
A metaclass for classes that keeps track of all of them that

View file

@ -2,8 +2,10 @@ from __future__ import unicode_literals
import six
__metaclass__ = type
class NonDataProperty(object):
class NonDataProperty:
"""Much like the property builtin, but only implements __get__,
making it a non-data property, and can be subsequently reset.
@ -34,7 +36,7 @@ class NonDataProperty(object):
# from http://stackoverflow.com/a/5191224
class ClassPropertyDescriptor(object):
class ClassPropertyDescriptor:
def __init__(self, fget, fset=None):
self.fget = fget

View file

@ -7,12 +7,63 @@ import operator
import collections
import itertools
import copy
import functools
try:
import collections.abc
except ImportError:
# Python 2.7
collections.abc = collections
import six
from jaraco.classes.properties import NonDataProperty
import jaraco.text
class Projection(collections.abc.Mapping):
"""
Project a set of keys over a mapping
>>> sample = {'a': 1, 'b': 2, 'c': 3}
>>> prj = Projection(['a', 'c', 'd'], sample)
>>> prj == {'a': 1, 'c': 3}
True
Keys should only appear if they were specified and exist in the space.
>>> sorted(list(prj.keys()))
['a', 'c']
Use the projection to update another dict.
>>> target = {'a': 2, 'b': 2}
>>> target.update(prj)
>>> target == {'a': 1, 'b': 2, 'c': 3}
True
Also note that Projection keeps a reference to the original dict, so
if you modify the original dict, that could modify the Projection.
>>> del sample['a']
>>> dict(prj)
{'c': 3}
"""
def __init__(self, keys, space):
self._keys = tuple(keys)
self._space = space
def __getitem__(self, key):
if key not in self._keys:
raise KeyError(key)
return self._space[key]
def __iter__(self):
return iter(set(self._keys).intersection(self._space))
def __len__(self):
return len(tuple(iter(self)))
class DictFilter(object):
"""
Takes a dict, and simulates a sub-dict based on the keys.
@ -52,7 +103,6 @@ class DictFilter(object):
self.pattern_keys = set()
def get_pattern_keys(self):
#key_matches = lambda k, v: self.include_pattern.match(k)
keys = filter(self.include_pattern.match, self.dict.keys())
return set(keys)
pattern_keys = NonDataProperty(get_pattern_keys)
@ -70,7 +120,7 @@ class DictFilter(object):
return values
def __getitem__(self, i):
if not i in self.include_keys:
if i not in self.include_keys:
return KeyError, i
return self.dict[i]
@ -162,7 +212,7 @@ class RangeMap(dict):
>>> r.get(7, 'not found')
'not found'
"""
def __init__(self, source, sort_params = {}, key_match_comparator = operator.le):
def __init__(self, source, sort_params={}, key_match_comparator=operator.le):
dict.__init__(self, source)
self.sort_params = sort_params
self.match = key_match_comparator
@ -190,7 +240,7 @@ class RangeMap(dict):
return default
def _find_first_match_(self, keys, item):
is_match = lambda k: self.match(item, k)
is_match = functools.partial(self.match, item)
matches = list(filter(is_match, keys))
if matches:
return matches[0]
@ -205,12 +255,15 @@ class RangeMap(dict):
# some special values for the RangeMap
undefined_value = type(str('RangeValueUndefined'), (object,), {})()
class Item(int): pass
class Item(int):
"RangeMap Item"
first_item = Item(0)
last_item = Item(-1)
__identity = lambda x: x
def __identity(x):
return x
def sorted_items(d, key=__identity, reverse=False):
@ -229,7 +282,8 @@ def sorted_items(d, key=__identity, reverse=False):
(('foo', 20), ('baz', 10), ('bar', 42))
"""
# wrap the key func so it operates on the first element of each item
pairkey_key = lambda item: key(item[0])
def pairkey_key(item):
return key(item[0])
return sorted(d.items(), key=pairkey_key, reverse=reverse)
@ -414,7 +468,11 @@ class ItemsAsAttributes(object):
It also works on dicts that customize __getitem__
>>> missing_func = lambda self, key: 'missing item'
>>> C = type(str('C'), (dict, ItemsAsAttributes), dict(__missing__ = missing_func))
>>> C = type(
... str('C'),
... (dict, ItemsAsAttributes),
... dict(__missing__ = missing_func),
... )
>>> i = C()
>>> i.missing
'missing item'
@ -428,6 +486,7 @@ class ItemsAsAttributes(object):
# attempt to get the value from the mapping (return self[key])
# but be careful not to lose the original exception context.
noval = object()
def _safe_getitem(cont, key, missing_result):
try:
return cont[key]
@ -460,7 +519,7 @@ def invert_map(map):
...
ValueError: Key conflict in inverted mapping
"""
res = dict((v,k) for k, v in map.items())
res = dict((v, k) for k, v in map.items())
if not len(res) == len(map):
raise ValueError('Key conflict in inverted mapping')
return res
@ -483,7 +542,7 @@ class IdentityOverrideMap(dict):
return key
class DictStack(list, collections.Mapping):
class DictStack(list, collections.abc.Mapping):
"""
A stack of dictionaries that behaves as a view on those dictionaries,
giving preference to the last.
@ -506,6 +565,7 @@ class DictStack(list, collections.Mapping):
>>> d = stack.pop()
>>> stack['a']
1
>>> stack.get('b', None)
"""
def keys(self):
@ -513,7 +573,8 @@ class DictStack(list, collections.Mapping):
def __getitem__(self, key):
for scope in reversed(self):
if key in scope: return scope[key]
if key in scope:
return scope[key]
raise KeyError(key)
push = list.append
@ -553,6 +614,10 @@ class BijectiveMap(dict):
Traceback (most recent call last):
ValueError: Key/Value pairs may not overlap
>>> m['e'] = 'd'
Traceback (most recent call last):
ValueError: Key/Value pairs may not overlap
>>> print(m.pop('d'))
c
@ -583,7 +648,12 @@ class BijectiveMap(dict):
def __setitem__(self, item, value):
if item == value:
raise ValueError("Key cannot map to itself")
if (value in self or item in self) and self[item] != value:
overlap = (
item in self and self[item] != value
or
value in self and self[value] != item
)
if overlap:
raise ValueError("Key/Value pairs may not overlap")
super(BijectiveMap, self).__setitem__(item, value)
super(BijectiveMap, self).__setitem__(value, item)
@ -607,7 +677,7 @@ class BijectiveMap(dict):
self.__setitem__(*item)
class FrozenDict(collections.Mapping, collections.Hashable):
class FrozenDict(collections.abc.Mapping, collections.abc.Hashable):
"""
An immutable mapping.
@ -641,8 +711,8 @@ class FrozenDict(collections.Mapping, collections.Hashable):
>>> isinstance(copy.copy(a), FrozenDict)
True
FrozenDict supplies .copy(), even though collections.Mapping doesn't
demand it.
FrozenDict supplies .copy(), even though
collections.abc.Mapping doesn't demand it.
>>> a.copy() == a
True
@ -747,6 +817,9 @@ class Everything(object):
>>> import random
>>> random.randint(1, 999) in Everything()
True
>>> random.choice([None, 'foo', 42, ('a', 'b', 'c')]) in Everything()
True
"""
def __contains__(self, other):
return True
@ -771,3 +844,63 @@ class InstrumentedDict(six.moves.UserDict):
def __init__(self, data):
six.moves.UserDict.__init__(self)
self.data = data
class Least(object):
"""
A value that is always lesser than any other
>>> least = Least()
>>> 3 < least
False
>>> 3 > least
True
>>> least < 3
True
>>> least <= 3
True
>>> least > 3
False
>>> 'x' > least
True
>>> None > least
True
"""
def __le__(self, other):
return True
__lt__ = __le__
def __ge__(self, other):
return False
__gt__ = __ge__
class Greatest(object):
"""
A value that is always greater than any other
>>> greatest = Greatest()
>>> 3 < greatest
True
>>> 3 > greatest
False
>>> greatest < 3
False
>>> greatest > 3
True
>>> greatest >= 3
True
>>> 'x' > greatest
False
>>> None > greatest
False
"""
def __ge__(self, other):
return True
__gt__ = __ge__
def __le__(self, other):
return False
__lt__ = __le__

View file

@ -1,8 +1,16 @@
from __future__ import absolute_import, unicode_literals, print_function, division
from __future__ import (
absolute_import, unicode_literals, print_function, division,
)
import functools
import time
import warnings
import inspect
import collections
from itertools import count
__metaclass__ = type
try:
from functools import lru_cache
@ -16,13 +24,17 @@ except ImportError:
warnings.warn("No lru_cache available")
import more_itertools.recipes
def compose(*funcs):
"""
Compose any number of unary functions into a single unary function.
>>> import textwrap
>>> from six import text_type
>>> text_type.strip(textwrap.dedent(compose.__doc__)) == compose(text_type.strip, textwrap.dedent)(compose.__doc__)
>>> stripped = text_type.strip(textwrap.dedent(compose.__doc__))
>>> compose(text_type.strip, textwrap.dedent)(compose.__doc__) == stripped
True
Compose also allows the innermost function to take arbitrary arguments.
@ -33,7 +45,8 @@ def compose(*funcs):
[1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
"""
compose_two = lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs))
def compose_two(f1, f2):
return lambda *args, **kwargs: f1(f2(*args, **kwargs))
return functools.reduce(compose_two, funcs)
@ -60,19 +73,36 @@ def once(func):
This decorator can ensure that an expensive or non-idempotent function
will not be expensive on subsequent calls and is idempotent.
>>> func = once(lambda a: a+3)
>>> func(3)
>>> add_three = once(lambda a: a+3)
>>> add_three(3)
6
>>> func(9)
>>> add_three(9)
6
>>> func('12')
>>> add_three('12')
6
To reset the stored value, simply clear the property ``saved_result``.
>>> del add_three.saved_result
>>> add_three(9)
12
>>> add_three(8)
12
Or invoke 'reset()' on it.
>>> add_three.reset()
>>> add_three(-3)
0
>>> add_three(0)
0
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not hasattr(func, 'always_returns'):
func.always_returns = func(*args, **kwargs)
return func.always_returns
if not hasattr(wrapper, 'saved_result'):
wrapper.saved_result = func(*args, **kwargs)
return wrapper.saved_result
wrapper.reset = lambda: vars(wrapper).__delitem__('saved_result')
return wrapper
@ -131,17 +161,22 @@ def method_cache(method, cache_wrapper=None):
>>> a.method2()
3
Caution - do not subsequently wrap the method with another decorator, such
as ``@property``, which changes the semantics of the function.
See also
http://code.activestate.com/recipes/577452-a-memoize-decorator-for-instance-methods/
for another implementation and additional justification.
"""
cache_wrapper = cache_wrapper or lru_cache()
def wrapper(self, *args, **kwargs):
# it's the first call, replace the method with a cached, bound method
bound_method = functools.partial(method, self)
cached_method = cache_wrapper(bound_method)
setattr(self, method.__name__, cached_method)
return cached_method(*args, **kwargs)
return _special_method_cache(method, cache_wrapper) or wrapper
@ -191,6 +226,29 @@ def apply(transform):
return wrap
def result_invoke(action):
r"""
Decorate a function with an action function that is
invoked on the results returned from the decorated
function (for its side-effect), then return the original
result.
>>> @result_invoke(print)
... def add_two(a, b):
... return a + b
>>> x = add_two(2, 3)
5
"""
def wrap(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
action(result)
return result
return wrapper
return wrap
def call_aside(f, *args, **kwargs):
"""
Call a function for its side effect after initialization.
@ -211,7 +269,7 @@ def call_aside(f, *args, **kwargs):
return f
class Throttler(object):
class Throttler:
"""
Rate-limit a function (or other callable)
"""
@ -259,10 +317,143 @@ def retry_call(func, cleanup=lambda: None, retries=0, trap=()):
exception. On the final attempt, allow any exceptions
to propagate.
"""
for attempt in range(retries):
attempts = count() if retries == float('inf') else range(retries)
for attempt in attempts:
try:
return func()
except trap:
cleanup()
return func()
def retry(*r_args, **r_kwargs):
"""
Decorator wrapper for retry_call. Accepts arguments to retry_call
except func and then returns a decorator for the decorated function.
Ex:
>>> @retry(retries=3)
... def my_func(a, b):
... "this is my funk"
... print(a, b)
>>> my_func.__doc__
'this is my funk'
"""
def decorate(func):
@functools.wraps(func)
def wrapper(*f_args, **f_kwargs):
bound = functools.partial(func, *f_args, **f_kwargs)
return retry_call(bound, *r_args, **r_kwargs)
return wrapper
return decorate
def print_yielded(func):
"""
Convert a generator into a function that prints all yielded elements
>>> @print_yielded
... def x():
... yield 3; yield None
>>> x()
3
None
"""
print_all = functools.partial(map, print)
print_results = compose(more_itertools.recipes.consume, print_all, func)
return functools.wraps(func)(print_results)
def pass_none(func):
"""
Wrap func so it's not called if its first param is None
>>> print_text = pass_none(print)
>>> print_text('text')
text
>>> print_text(None)
"""
@functools.wraps(func)
def wrapper(param, *args, **kwargs):
if param is not None:
return func(param, *args, **kwargs)
return wrapper
def assign_params(func, namespace):
"""
Assign parameters from namespace where func solicits.
>>> def func(x, y=3):
... print(x, y)
>>> assigned = assign_params(func, dict(x=2, z=4))
>>> assigned()
2 3
The usual errors are raised if a function doesn't receive
its required parameters:
>>> assigned = assign_params(func, dict(y=3, z=4))
>>> assigned()
Traceback (most recent call last):
TypeError: func() ...argument...
"""
try:
sig = inspect.signature(func)
params = sig.parameters.keys()
except AttributeError:
spec = inspect.getargspec(func)
params = spec.args
call_ns = {
k: namespace[k]
for k in params
if k in namespace
}
return functools.partial(func, **call_ns)
def save_method_args(method):
"""
Wrap a method such that when it is called, the args and kwargs are
saved on the method.
>>> class MyClass:
... @save_method_args
... def method(self, a, b):
... print(a, b)
>>> my_ob = MyClass()
>>> my_ob.method(1, 2)
1 2
>>> my_ob._saved_method.args
(1, 2)
>>> my_ob._saved_method.kwargs
{}
>>> my_ob.method(a=3, b='foo')
3 foo
>>> my_ob._saved_method.args
()
>>> my_ob._saved_method.kwargs == dict(a=3, b='foo')
True
The arguments are stored on the instance, allowing for
different instance to save different args.
>>> your_ob = MyClass()
>>> your_ob.method({str('x'): 3}, b=[4])
{'x': 3} [4]
>>> your_ob._saved_method.args
({'x': 3},)
>>> my_ob._saved_method.args
()
"""
args_and_kwargs = collections.namedtuple('args_and_kwargs', 'args kwargs')
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
attr_name = '_saved_' + method.__name__
attr = args_and_kwargs(args, kwargs)
setattr(self, attr_name, attr)
return method(self, *args, **kwargs)
return wrapper

View file

@ -1,5 +1,6 @@
from __future__ import absolute_import, unicode_literals
import numbers
from functools import reduce
@ -25,6 +26,7 @@ def get_bit_values(number, size=32):
number += 2**size
return list(map(int, bin(number)[-size:]))
def gen_bit_values(number):
"""
Return a zero or one for each bit of a numeric value up to the most
@ -36,6 +38,7 @@ def gen_bit_values(number):
digits = bin(number)[2:]
return map(int, reversed(digits))
def coalesce(bits):
"""
Take a sequence of bits, most significant first, and
@ -47,6 +50,7 @@ def coalesce(bits):
operation = lambda a, b: (a << 1 | b)
return reduce(operation, bits)
class Flags(object):
"""
Subclasses should define _names, a list of flag names beginning
@ -96,6 +100,7 @@ class Flags(object):
index = self._names.index(key)
return self._values[index]
class BitMask(type):
"""
A metaclass to create a bitmask with attributes. Subclass an int and
@ -119,12 +124,28 @@ class BitMask(type):
>>> b2 = MyBits(8)
>>> any([b2.a, b2.b, b2.c])
False
If the instance defines methods, they won't be wrapped in
properties.
>>> ns['get_value'] = classmethod(lambda cls: 'some value')
>>> ns['prop'] = property(lambda self: 'a property')
>>> MyBits = BitMask(str('MyBits'), (int,), ns)
>>> MyBits(3).get_value()
'some value'
>>> MyBits(3).prop
'a property'
"""
def __new__(cls, name, bases, attrs):
def make_property(name, value):
if name.startswith('_') or not isinstance(value, numbers.Number):
return value
return property(lambda self, value=value: bool(self & value))
newattrs = dict(
(attr, property(lambda self, value=value: bool(self & value)))
for attr, value in attrs.items()
if not attr.startswith('_')
(name, make_property(name, value))
for name, value in attrs.items()
)
return type.__new__(cls, name, bases, newattrs)

View file

@ -39,6 +39,7 @@ class FoldedCase(six.text_type):
"""
A case insensitive string class; behaves just like str
except compares equal when the only variation is case.
>>> s = FoldedCase('hello world')
>>> s == 'Hello World'
@ -47,6 +48,9 @@ class FoldedCase(six.text_type):
>>> 'Hello World' == s
True
>>> s != 'Hello World'
False
>>> s.index('O')
4
@ -55,6 +59,38 @@ class FoldedCase(six.text_type):
>>> sorted(map(FoldedCase, ['GAMMA', 'alpha', 'Beta']))
['alpha', 'Beta', 'GAMMA']
Sequence membership is straightforward.
>>> "Hello World" in [s]
True
>>> s in ["Hello World"]
True
You may test for set inclusion, but candidate and elements
must both be folded.
>>> FoldedCase("Hello World") in {s}
True
>>> s in {FoldedCase("Hello World")}
True
String inclusion works as long as the FoldedCase object
is on the right.
>>> "hello" in FoldedCase("Hello World")
True
But not if the FoldedCase object is on the left:
>>> FoldedCase('hello') in 'Hello World'
False
In that case, use in_:
>>> FoldedCase('hello').in_('Hello World')
True
"""
def __lt__(self, other):
return self.lower() < other.lower()
@ -65,14 +101,23 @@ class FoldedCase(six.text_type):
def __eq__(self, other):
return self.lower() == other.lower()
def __ne__(self, other):
return self.lower() != other.lower()
def __hash__(self):
return hash(self.lower())
def __contains__(self, other):
return super(FoldedCase, self).lower().__contains__(other.lower())
def in_(self, other):
"Does self appear in other?"
return self in FoldedCase(other)
# cache lower since it's likely to be called frequently.
@jaraco.functools.method_cache
def lower(self):
self._lower = super(FoldedCase, self).lower()
self.lower = lambda: self._lower
return self._lower
return super(FoldedCase, self).lower()
def index(self, sub):
return self.lower().index(sub.lower())
@ -147,6 +192,7 @@ def is_decodable(value):
return False
return True
def is_binary(value):
"""
Return True if the value appears to be binary (that is, it's a byte
@ -154,6 +200,7 @@ def is_binary(value):
"""
return isinstance(value, bytes) and not is_decodable(value)
def trim(s):
r"""
Trim something like a docstring to remove the whitespace that
@ -164,8 +211,10 @@ def trim(s):
"""
return textwrap.dedent(s).strip()
class Splitter(object):
"""object that will split a string with the given arguments for each call
>>> s = Splitter(',')
>>> s('hello, world, this is your, master calling')
['hello', ' world', ' this is your', ' master calling']
@ -176,9 +225,11 @@ class Splitter(object):
def __call__(self, s):
return s.split(*self.args)
def indent(string, prefix=' ' * 4):
return prefix + string
class WordSet(tuple):
"""
Given a Python identifier, return the words that identifier represents,
@ -269,6 +320,7 @@ class WordSet(tuple):
def from_class_name(cls, subject):
return cls.parse(subject.__class__.__name__)
# for backward compatibility
words = WordSet.parse
@ -318,6 +370,7 @@ class SeparatedValues(six.text_type):
parts = self.split(self.separator)
return six.moves.filter(None, (part.strip() for part in parts))
class Stripper:
r"""
Given a series of lines, find the common prefix and strip it from them.
@ -369,3 +422,31 @@ class Stripper:
while s1[:index] != s2[:index]:
index -= 1
return s1[:index]
def remove_prefix(text, prefix):
"""
Remove the prefix from the text if it exists.
>>> remove_prefix('underwhelming performance', 'underwhelming ')
'performance'
>>> remove_prefix('something special', 'sample')
'something special'
"""
null, prefix, rest = text.rpartition(prefix)
return rest
def remove_suffix(text, suffix):
"""
Remove the suffix from the text if it exists.
>>> remove_suffix('name.git', '.git')
'name'
>>> remove_suffix('something special', 'sample')
'something special'
"""
rest, suffix, null = text.partition(suffix)
return rest

View file

@ -60,3 +60,18 @@ class Command(object):
cls.add_subparsers(parser)
args = parser.parse_args()
args.action.run(args)
class Extend(argparse.Action):
"""
Argparse action to take an nargs=* argument
and add any values to the existing value.
>>> parser = argparse.ArgumentParser()
>>> _ = parser.add_argument('--foo', nargs='*', default=[], action=Extend)
>>> args = parser.parse_args(['--foo', 'a=1', '--foo', 'b=2', 'c=3'])
>>> args.foo
['a=1', 'b=2', 'c=3']
"""
def __call__(self, parser, namespace, values, option_string=None):
getattr(namespace, self.dest).extend(values)

View file

@ -1,3 +1,5 @@
# deprecated -- use TQDM
from __future__ import (print_function, absolute_import, unicode_literals,
division)

View file

@ -45,3 +45,9 @@ GetClipboardData.restype = ctypes.wintypes.HANDLE
SetClipboardData = ctypes.windll.user32.SetClipboardData
SetClipboardData.argtypes = ctypes.wintypes.UINT, ctypes.wintypes.HANDLE
SetClipboardData.restype = ctypes.wintypes.HANDLE
OpenClipboard = ctypes.windll.user32.OpenClipboard
OpenClipboard.argtypes = ctypes.wintypes.HANDLE,
OpenClipboard.restype = ctypes.wintypes.BOOL
ctypes.windll.user32.CloseClipboard.restype = ctypes.wintypes.BOOL

View file

@ -10,9 +10,11 @@ try:
except ImportError:
LPBYTE = ctypes.POINTER(ctypes.wintypes.BYTE)
class CredentialAttribute(ctypes.Structure):
_fields_ = []
class Credential(ctypes.Structure):
_fields_ = [
('flags', DWORD),
@ -32,28 +34,29 @@ class Credential(ctypes.Structure):
def __del__(self):
ctypes.windll.advapi32.CredFree(ctypes.byref(self))
PCREDENTIAL = ctypes.POINTER(Credential)
CredRead = ctypes.windll.advapi32.CredReadW
CredRead.argtypes = (
LPCWSTR, # TargetName
DWORD, # Type
DWORD, # Flags
ctypes.POINTER(PCREDENTIAL), # Credential
LPCWSTR, # TargetName
DWORD, # Type
DWORD, # Flags
ctypes.POINTER(PCREDENTIAL), # Credential
)
CredRead.restype = BOOL
CredWrite = ctypes.windll.advapi32.CredWriteW
CredWrite.argtypes = (
PCREDENTIAL, # Credential
DWORD, # Flags
PCREDENTIAL, # Credential
DWORD, # Flags
)
CredWrite.restype = BOOL
CredDelete = ctypes.windll.advapi32.CredDeleteW
CredDelete.argtypes = (
LPCWSTR, # TargetName
DWORD, # Type
DWORD, # Flags
LPCWSTR, # TargetName
DWORD, # Type
DWORD, # Flags
)
CredDelete.restype = BOOL

View file

@ -2,7 +2,7 @@ import ctypes.wintypes
SetEnvironmentVariable = ctypes.windll.kernel32.SetEnvironmentVariableW
SetEnvironmentVariable.restype = ctypes.wintypes.BOOL
SetEnvironmentVariable.argtypes = [ctypes.wintypes.LPCWSTR]*2
SetEnvironmentVariable.argtypes = [ctypes.wintypes.LPCWSTR] * 2
GetEnvironmentVariable = ctypes.windll.kernel32.GetEnvironmentVariableW
GetEnvironmentVariable.restype = ctypes.wintypes.BOOL

View file

@ -1,16 +1,11 @@
from ctypes import (
Structure, windll, POINTER, byref, cast, create_unicode_buffer,
c_size_t, c_int, create_string_buffer, c_uint64, c_ushort, c_short,
c_uint,
)
from ctypes import windll, POINTER
from ctypes.wintypes import (
BOOLEAN, LPWSTR, DWORD, LPVOID, HANDLE, FILETIME,
WCHAR, BOOL, HWND, WORD, UINT,
)
LPWSTR, DWORD, LPVOID, HANDLE, BOOL,
)
CreateEvent = windll.kernel32.CreateEventW
CreateEvent.argtypes = (
LPVOID, # LPSECURITY_ATTRIBUTES
LPVOID, # LPSECURITY_ATTRIBUTES
BOOL,
BOOL,
LPWSTR,
@ -29,13 +24,15 @@ _WaitForMultipleObjects = windll.kernel32.WaitForMultipleObjects
_WaitForMultipleObjects.argtypes = DWORD, POINTER(HANDLE), BOOL, DWORD
_WaitForMultipleObjects.restype = DWORD
def WaitForMultipleObjects(handles, wait_all=False, timeout=0):
n_handles = len(handles)
handle_array = (HANDLE*n_handles)()
handle_array = (HANDLE * n_handles)()
for index, handle in enumerate(handles):
handle_array[index] = handle
return _WaitForMultipleObjects(n_handles, handle_array, wait_all, timeout)
WAIT_OBJECT_0 = 0
INFINITE = -1
WAIT_TIMEOUT = 0x102

View file

@ -5,15 +5,15 @@ CreateSymbolicLink.argtypes = (
ctypes.wintypes.LPWSTR,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.DWORD,
)
)
CreateSymbolicLink.restype = ctypes.wintypes.BOOLEAN
CreateHardLink = ctypes.windll.kernel32.CreateHardLinkW
CreateHardLink.argtypes = (
ctypes.wintypes.LPWSTR,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.LPVOID, # reserved for LPSECURITY_ATTRIBUTES
)
ctypes.wintypes.LPVOID, # reserved for LPSECURITY_ATTRIBUTES
)
CreateHardLink.restype = ctypes.wintypes.BOOLEAN
GetFileAttributes = ctypes.windll.kernel32.GetFileAttributesW
@ -28,16 +28,20 @@ MAX_PATH = 260
GetFinalPathNameByHandle = ctypes.windll.kernel32.GetFinalPathNameByHandleW
GetFinalPathNameByHandle.argtypes = (
ctypes.wintypes.HANDLE, ctypes.wintypes.LPWSTR, ctypes.wintypes.DWORD, ctypes.wintypes.DWORD,
)
ctypes.wintypes.HANDLE, ctypes.wintypes.LPWSTR, ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD,
)
GetFinalPathNameByHandle.restype = ctypes.wintypes.DWORD
class SECURITY_ATTRIBUTES(ctypes.Structure):
_fields_ = (
('length', ctypes.wintypes.DWORD),
('p_security_descriptor', ctypes.wintypes.LPVOID),
('inherit_handle', ctypes.wintypes.BOOLEAN),
)
)
LPSECURITY_ATTRIBUTES = ctypes.POINTER(SECURITY_ATTRIBUTES)
CreateFile = ctypes.windll.kernel32.CreateFileW
@ -49,7 +53,7 @@ CreateFile.argtypes = (
ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD,
ctypes.wintypes.HANDLE,
)
)
CreateFile.restype = ctypes.wintypes.HANDLE
FILE_SHARE_READ = 1
FILE_SHARE_WRITE = 2
@ -77,23 +81,61 @@ CloseHandle = ctypes.windll.kernel32.CloseHandle
CloseHandle.argtypes = (ctypes.wintypes.HANDLE,)
CloseHandle.restype = ctypes.wintypes.BOOLEAN
class WIN32_FIND_DATA(ctypes.Structure):
class WIN32_FIND_DATA(ctypes.wintypes.WIN32_FIND_DATAW):
"""
_fields_ = [
('file_attributes', ctypes.wintypes.DWORD),
('creation_time', ctypes.wintypes.FILETIME),
('last_access_time', ctypes.wintypes.FILETIME),
('last_write_time', ctypes.wintypes.FILETIME),
('file_size_words', ctypes.wintypes.DWORD*2),
('reserved', ctypes.wintypes.DWORD*2),
('filename', ctypes.wintypes.WCHAR*MAX_PATH),
('alternate_filename', ctypes.wintypes.WCHAR*14),
("dwFileAttributes", DWORD),
("ftCreationTime", FILETIME),
("ftLastAccessTime", FILETIME),
("ftLastWriteTime", FILETIME),
("nFileSizeHigh", DWORD),
("nFileSizeLow", DWORD),
("dwReserved0", DWORD),
("dwReserved1", DWORD),
("cFileName", WCHAR * MAX_PATH),
("cAlternateFileName", WCHAR * 14)]
]
"""
@property
def file_attributes(self):
return self.dwFileAttributes
@property
def creation_time(self):
return self.ftCreationTime
@property
def last_access_time(self):
return self.ftLastAccessTime
@property
def last_write_time(self):
return self.ftLastWriteTime
@property
def file_size_words(self):
return [self.nFileSizeHigh, self.nFileSizeLow]
@property
def reserved(self):
return [self.dwReserved0, self.dwReserved1]
@property
def filename(self):
return self.cFileName
@property
def alternate_filename(self):
return self.cAlternateFileName
@property
def file_size(self):
return ctypes.cast(self.file_size_words, ctypes.POINTER(ctypes.c_uint64)).contents
return self.nFileSizeHigh << 32 + self.nFileSizeLow
LPWIN32_FIND_DATA = ctypes.POINTER(WIN32_FIND_DATA)
LPWIN32_FIND_DATA = ctypes.POINTER(ctypes.wintypes.WIN32_FIND_DATAW)
FindFirstFile = ctypes.windll.kernel32.FindFirstFileW
FindFirstFile.argtypes = (ctypes.wintypes.LPWSTR, LPWIN32_FIND_DATA)
@ -102,19 +144,56 @@ FindNextFile = ctypes.windll.kernel32.FindNextFileW
FindNextFile.argtypes = (ctypes.wintypes.HANDLE, LPWIN32_FIND_DATA)
FindNextFile.restype = ctypes.wintypes.BOOLEAN
SCS_32BIT_BINARY = 0 # A 32-bit Windows-based application
SCS_64BIT_BINARY = 6 # A 64-bit Windows-based application
SCS_DOS_BINARY = 1 # An MS-DOS-based application
SCS_OS216_BINARY = 5 # A 16-bit OS/2-based application
SCS_PIF_BINARY = 3 # A PIF file that executes an MS-DOS-based application
SCS_POSIX_BINARY = 4 # A POSIX-based application
SCS_WOW_BINARY = 2 # A 16-bit Windows-based application
ctypes.windll.kernel32.FindClose.argtypes = ctypes.wintypes.HANDLE,
SCS_32BIT_BINARY = 0 # A 32-bit Windows-based application
SCS_64BIT_BINARY = 6 # A 64-bit Windows-based application
SCS_DOS_BINARY = 1 # An MS-DOS-based application
SCS_OS216_BINARY = 5 # A 16-bit OS/2-based application
SCS_PIF_BINARY = 3 # A PIF file that executes an MS-DOS-based application
SCS_POSIX_BINARY = 4 # A POSIX-based application
SCS_WOW_BINARY = 2 # A 16-bit Windows-based application
_GetBinaryType = ctypes.windll.kernel32.GetBinaryTypeW
_GetBinaryType.argtypes = (ctypes.wintypes.LPWSTR, ctypes.POINTER(ctypes.wintypes.DWORD))
_GetBinaryType.argtypes = (
ctypes.wintypes.LPWSTR, ctypes.POINTER(ctypes.wintypes.DWORD),
)
_GetBinaryType.restype = ctypes.wintypes.BOOL
FILEOP_FLAGS = ctypes.wintypes.WORD
class BY_HANDLE_FILE_INFORMATION(ctypes.Structure):
_fields_ = [
('file_attributes', ctypes.wintypes.DWORD),
('creation_time', ctypes.wintypes.FILETIME),
('last_access_time', ctypes.wintypes.FILETIME),
('last_write_time', ctypes.wintypes.FILETIME),
('volume_serial_number', ctypes.wintypes.DWORD),
('file_size_high', ctypes.wintypes.DWORD),
('file_size_low', ctypes.wintypes.DWORD),
('number_of_links', ctypes.wintypes.DWORD),
('file_index_high', ctypes.wintypes.DWORD),
('file_index_low', ctypes.wintypes.DWORD),
]
@property
def file_size(self):
return (self.file_size_high << 32) + self.file_size_low
@property
def file_index(self):
return (self.file_index_high << 32) + self.file_index_low
GetFileInformationByHandle = ctypes.windll.kernel32.GetFileInformationByHandle
GetFileInformationByHandle.restype = ctypes.wintypes.BOOL
GetFileInformationByHandle.argtypes = (
ctypes.wintypes.HANDLE,
ctypes.POINTER(BY_HANDLE_FILE_INFORMATION),
)
class SHFILEOPSTRUCT(ctypes.Structure):
_fields_ = [
('status_dialog', ctypes.wintypes.HWND),
@ -126,6 +205,8 @@ class SHFILEOPSTRUCT(ctypes.Structure):
('name_mapping_handles', ctypes.wintypes.LPVOID),
('progress_title', ctypes.wintypes.LPWSTR),
]
_SHFileOperation = ctypes.windll.shell32.SHFileOperationW
_SHFileOperation.argtypes = [ctypes.POINTER(SHFILEOPSTRUCT)]
_SHFileOperation.restype = ctypes.c_int
@ -143,12 +224,13 @@ ReplaceFile.argtypes = [
ctypes.wintypes.DWORD,
ctypes.wintypes.LPVOID,
ctypes.wintypes.LPVOID,
]
]
REPLACEFILE_WRITE_THROUGH = 0x1
REPLACEFILE_IGNORE_MERGE_ERRORS = 0x2
REPLACEFILE_IGNORE_ACL_ERRORS = 0x4
class STAT_STRUCT(ctypes.Structure):
_fields_ = [
('dev', ctypes.c_uint),
@ -165,17 +247,22 @@ class STAT_STRUCT(ctypes.Structure):
('ctime', ctypes.c_uint),
]
_wstat = ctypes.windll.msvcrt._wstat
_wstat.argtypes = [ctypes.wintypes.LPWSTR, ctypes.POINTER(STAT_STRUCT)]
_wstat.restype = ctypes.c_int
FILE_NOTIFY_CHANGE_LAST_WRITE = 0x10
FindFirstChangeNotification = ctypes.windll.kernel32.FindFirstChangeNotificationW
FindFirstChangeNotification.argtypes = ctypes.wintypes.LPWSTR, ctypes.wintypes.BOOL, ctypes.wintypes.DWORD
FindFirstChangeNotification = (
ctypes.windll.kernel32.FindFirstChangeNotificationW)
FindFirstChangeNotification.argtypes = (
ctypes.wintypes.LPWSTR, ctypes.wintypes.BOOL, ctypes.wintypes.DWORD,
)
FindFirstChangeNotification.restype = ctypes.wintypes.HANDLE
FindCloseChangeNotification = ctypes.windll.kernel32.FindCloseChangeNotification
FindCloseChangeNotification = (
ctypes.windll.kernel32.FindCloseChangeNotification)
FindCloseChangeNotification.argtypes = ctypes.wintypes.HANDLE,
FindCloseChangeNotification.restype = ctypes.wintypes.BOOL
@ -200,9 +287,10 @@ DeviceIoControl.argtypes = [
ctypes.wintypes.DWORD,
LPDWORD,
LPOVERLAPPED,
]
]
DeviceIoControl.restype = ctypes.wintypes.BOOL
class REPARSE_DATA_BUFFER(ctypes.Structure):
_fields_ = [
('tag', ctypes.c_ulong),
@ -213,16 +301,17 @@ class REPARSE_DATA_BUFFER(ctypes.Structure):
('print_name_offset', ctypes.c_ushort),
('print_name_length', ctypes.c_ushort),
('flags', ctypes.c_ulong),
('path_buffer', ctypes.c_byte*1),
('path_buffer', ctypes.c_byte * 1),
]
def get_print_name(self):
wchar_size = ctypes.sizeof(ctypes.wintypes.WCHAR)
arr_typ = ctypes.wintypes.WCHAR*(self.print_name_length//wchar_size)
arr_typ = ctypes.wintypes.WCHAR * (self.print_name_length // wchar_size)
data = ctypes.byref(self.path_buffer, self.print_name_offset)
return ctypes.cast(data, ctypes.POINTER(arr_typ)).contents.value
def get_substitute_name(self):
wchar_size = ctypes.sizeof(ctypes.wintypes.WCHAR)
arr_typ = ctypes.wintypes.WCHAR*(self.substitute_name_length//wchar_size)
arr_typ = ctypes.wintypes.WCHAR * (self.substitute_name_length // wchar_size)
data = ctypes.byref(self.path_buffer, self.substitute_name_offset)
return ctypes.cast(data, ctypes.POINTER(arr_typ)).contents.value

View file

@ -2,6 +2,7 @@ import struct
import ctypes.wintypes
from ctypes.wintypes import DWORD, WCHAR, BYTE, BOOL
# from mprapi.h
MAX_INTERFACE_NAME_LEN = 2**8
@ -13,15 +14,16 @@ MAXLEN_IFDESCR = 2**8
MAX_ADAPTER_ADDRESS_LENGTH = 8
MAX_DHCPV6_DUID_LENGTH = 130
class MIB_IFROW(ctypes.Structure):
_fields_ = (
('name', WCHAR*MAX_INTERFACE_NAME_LEN),
('name', WCHAR * MAX_INTERFACE_NAME_LEN),
('index', DWORD),
('type', DWORD),
('MTU', DWORD),
('speed', DWORD),
('physical_address_length', DWORD),
('physical_address_raw', BYTE*MAXLEN_PHYSADDR),
('physical_address_raw', BYTE * MAXLEN_PHYSADDR),
('admin_status', DWORD),
('operational_status', DWORD),
('last_change', DWORD),
@ -38,7 +40,7 @@ class MIB_IFROW(ctypes.Structure):
('outgoing_errors', DWORD),
('outgoing_queue_length', DWORD),
('description_length', DWORD),
('description_raw', ctypes.c_char*MAXLEN_IFDESCR),
('description_raw', ctypes.c_char * MAXLEN_IFDESCR),
)
def _get_binary_property(self, name):
@ -46,7 +48,7 @@ class MIB_IFROW(ctypes.Structure):
val = getattr(self, val_prop)
len_prop = '{0}_length'.format(name)
length = getattr(self, len_prop)
return str(buffer(val))[:length]
return str(memoryview(val))[:length]
@property
def physical_address(self):
@ -56,12 +58,14 @@ class MIB_IFROW(ctypes.Structure):
def description(self):
return self._get_binary_property('description')
class MIB_IFTABLE(ctypes.Structure):
_fields_ = (
('num_entries', DWORD), # dwNumEntries
('entries', MIB_IFROW*0), # table
('entries', MIB_IFROW * 0), # table
)
class MIB_IPADDRROW(ctypes.Structure):
_fields_ = (
('address_num', DWORD),
@ -79,40 +83,49 @@ class MIB_IPADDRROW(ctypes.Structure):
_ = struct.pack('L', self.address_num)
return struct.unpack('!L', _)[0]
class MIB_IPADDRTABLE(ctypes.Structure):
_fields_ = (
('num_entries', DWORD),
('entries', MIB_IPADDRROW*0),
('entries', MIB_IPADDRROW * 0),
)
class SOCKADDR(ctypes.Structure):
_fields_ = (
('family', ctypes.c_ushort),
('data', ctypes.c_byte*14),
)
('data', ctypes.c_byte * 14),
)
LPSOCKADDR = ctypes.POINTER(SOCKADDR)
class SOCKET_ADDRESS(ctypes.Structure):
_fields_ = [
('address', LPSOCKADDR),
('length', ctypes.c_int),
]
]
class _IP_ADAPTER_ADDRESSES_METRIC(ctypes.Structure):
_fields_ = [
('length', ctypes.c_ulong),
('interface_index', DWORD),
]
]
class _IP_ADAPTER_ADDRESSES_U1(ctypes.Union):
_fields_ = [
('alignment', ctypes.c_ulonglong),
('metric', _IP_ADAPTER_ADDRESSES_METRIC),
]
]
class IP_ADAPTER_ADDRESSES(ctypes.Structure):
pass
LP_IP_ADAPTER_ADDRESSES = ctypes.POINTER(IP_ADAPTER_ADDRESSES)
# for now, just use void * for pointers to unused structures
@ -125,19 +138,20 @@ PIP_ADAPTER_WINS_SERVER_ADDRESS_LH = ctypes.c_void_p
PIP_ADAPTER_GATEWAY_ADDRESS_LH = ctypes.c_void_p
PIP_ADAPTER_DNS_SUFFIX = ctypes.c_void_p
IF_OPER_STATUS = ctypes.c_uint # this is an enum, consider http://code.activestate.com/recipes/576415/
IF_OPER_STATUS = ctypes.c_uint # this is an enum, consider
# http://code.activestate.com/recipes/576415/
IF_LUID = ctypes.c_uint64
NET_IF_COMPARTMENT_ID = ctypes.c_uint32
GUID = ctypes.c_byte*16
GUID = ctypes.c_byte * 16
NET_IF_NETWORK_GUID = GUID
NET_IF_CONNECTION_TYPE = ctypes.c_uint # enum
TUNNEL_TYPE = ctypes.c_uint # enum
NET_IF_CONNECTION_TYPE = ctypes.c_uint # enum
TUNNEL_TYPE = ctypes.c_uint # enum
IP_ADAPTER_ADDRESSES._fields_ = [
#('u', _IP_ADAPTER_ADDRESSES_U1),
('length', ctypes.c_ulong),
('interface_index', DWORD),
# ('u', _IP_ADAPTER_ADDRESSES_U1),
('length', ctypes.c_ulong),
('interface_index', DWORD),
('next', LP_IP_ADAPTER_ADDRESSES),
('adapter_name', ctypes.c_char_p),
('first_unicast_address', PIP_ADAPTER_UNICAST_ADDRESS),
@ -147,7 +161,7 @@ IP_ADAPTER_ADDRESSES._fields_ = [
('dns_suffix', ctypes.c_wchar_p),
('description', ctypes.c_wchar_p),
('friendly_name', ctypes.c_wchar_p),
('byte', BYTE*MAX_ADAPTER_ADDRESS_LENGTH),
('byte', BYTE * MAX_ADAPTER_ADDRESS_LENGTH),
('physical_address_length', DWORD),
('flags', DWORD),
('mtu', DWORD),
@ -169,11 +183,11 @@ IP_ADAPTER_ADDRESSES._fields_ = [
('connection_type', NET_IF_CONNECTION_TYPE),
('tunnel_type', TUNNEL_TYPE),
('dhcpv6_server', SOCKET_ADDRESS),
('dhcpv6_client_duid', ctypes.c_byte*MAX_DHCPV6_DUID_LENGTH),
('dhcpv6_client_duid', ctypes.c_byte * MAX_DHCPV6_DUID_LENGTH),
('dhcpv6_client_duid_length', ctypes.c_ulong),
('dhcpv6_iaid', ctypes.c_ulong),
('first_dns_suffix', PIP_ADAPTER_DNS_SUFFIX),
]
]
# define some parameters to the API Functions
GetIfTable = ctypes.windll.iphlpapi.GetIfTable
@ -181,7 +195,7 @@ GetIfTable.argtypes = [
ctypes.POINTER(MIB_IFTABLE),
ctypes.POINTER(ctypes.c_ulong),
BOOL,
]
]
GetIfTable.restype = DWORD
GetIpAddrTable = ctypes.windll.iphlpapi.GetIpAddrTable
@ -189,7 +203,7 @@ GetIpAddrTable.argtypes = [
ctypes.POINTER(MIB_IPADDRTABLE),
ctypes.POINTER(ctypes.c_ulong),
BOOL,
]
]
GetIpAddrTable.restype = DWORD
GetAdaptersAddresses = ctypes.windll.iphlpapi.GetAdaptersAddresses
@ -199,5 +213,5 @@ GetAdaptersAddresses.argtypes = [
ctypes.c_void_p,
ctypes.POINTER(IP_ADAPTER_ADDRESSES),
ctypes.POINTER(ctypes.c_ulong),
]
]
GetAdaptersAddresses.restype = ctypes.c_ulong

View file

@ -3,7 +3,7 @@ import ctypes.wintypes
GMEM_MOVEABLE = 0x2
GlobalAlloc = ctypes.windll.kernel32.GlobalAlloc
GlobalAlloc.argtypes = ctypes.wintypes.UINT, ctypes.c_ssize_t
GlobalAlloc.argtypes = ctypes.wintypes.UINT, ctypes.c_size_t
GlobalAlloc.restype = ctypes.wintypes.HANDLE
GlobalLock = ctypes.windll.kernel32.GlobalLock
@ -31,3 +31,15 @@ CreateFileMapping.restype = ctypes.wintypes.HANDLE
MapViewOfFile = ctypes.windll.kernel32.MapViewOfFile
MapViewOfFile.restype = ctypes.wintypes.HANDLE
UnmapViewOfFile = ctypes.windll.kernel32.UnmapViewOfFile
UnmapViewOfFile.argtypes = ctypes.wintypes.HANDLE,
RtlMoveMemory = ctypes.windll.kernel32.RtlMoveMemory
RtlMoveMemory.argtypes = (
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_size_t,
)
ctypes.windll.kernel32.LocalFree.argtypes = ctypes.wintypes.HLOCAL,

View file

@ -13,6 +13,7 @@ import six
LRESULT = LPARAM
class LPARAM_wstr(LPARAM):
"""
A special instance of LPARAM that can be constructed from a string
@ -25,14 +26,16 @@ class LPARAM_wstr(LPARAM):
return LPVOID.from_param(six.text_type(param))
return LPARAM.from_param(param)
SendMessage = ctypes.windll.user32.SendMessageW
SendMessage.argtypes = (HWND, UINT, WPARAM, LPARAM_wstr)
SendMessage.restype = LRESULT
HWND_BROADCAST=0xFFFF
WM_SETTINGCHANGE=0x1A
HWND_BROADCAST = 0xFFFF
WM_SETTINGCHANGE = 0x1A
# constants from http://msdn.microsoft.com/en-us/library/ms644952%28v=vs.85%29.aspx
# constants from http://msdn.microsoft.com
# /en-us/library/ms644952%28v=vs.85%29.aspx
SMTO_ABORTIFHUNG = 0x02
SMTO_BLOCK = 0x01
SMTO_NORMAL = 0x00
@ -45,6 +48,7 @@ SendMessageTimeout.argtypes = SendMessage.argtypes + (
)
SendMessageTimeout.restype = LRESULT
def unicode_as_lparam(source):
pointer = ctypes.cast(ctypes.c_wchar_p(source), ctypes.c_void_p)
return LPARAM(pointer.value)

View file

@ -5,6 +5,7 @@ mpr = ctypes.windll.mpr
RESOURCETYPE_ANY = 0
class NETRESOURCE(ctypes.Structure):
_fields_ = [
('scope', ctypes.wintypes.DWORD),
@ -16,6 +17,8 @@ class NETRESOURCE(ctypes.Structure):
('comment', ctypes.wintypes.LPWSTR),
('provider', ctypes.wintypes.LPWSTR),
]
LPNETRESOURCE = ctypes.POINTER(NETRESOURCE)
WNetAddConnection2 = mpr.WNetAddConnection2W

View file

@ -1,5 +1,6 @@
import ctypes.wintypes
class SYSTEM_POWER_STATUS(ctypes.Structure):
_fields_ = (
('ac_line_status', ctypes.wintypes.BYTE),
@ -8,11 +9,13 @@ class SYSTEM_POWER_STATUS(ctypes.Structure):
('reserved', ctypes.wintypes.BYTE),
('battery_life_time', ctypes.wintypes.DWORD),
('battery_full_life_time', ctypes.wintypes.DWORD),
)
)
@property
def ac_line_status_string(self):
return {0:'offline', 1: 'online', 255: 'unknown'}[self.ac_line_status]
return {
0: 'offline', 1: 'online', 255: 'unknown'}[self.ac_line_status]
LPSYSTEM_POWER_STATUS = ctypes.POINTER(SYSTEM_POWER_STATUS)
GetSystemPowerStatus = ctypes.windll.kernel32.GetSystemPowerStatus
@ -23,6 +26,7 @@ SetThreadExecutionState = ctypes.windll.kernel32.SetThreadExecutionState
SetThreadExecutionState.argtypes = [ctypes.c_uint]
SetThreadExecutionState.restype = ctypes.c_uint
class ES:
"""
Execution state constants

View file

@ -1,5 +1,6 @@
import ctypes.wintypes
class LUID(ctypes.Structure):
_fields_ = [
('low_part', ctypes.wintypes.DWORD),
@ -10,35 +11,39 @@ class LUID(ctypes.Structure):
return (
self.high_part == other.high_part and
self.low_part == other.low_part
)
)
def __ne__(self, other):
return not (self==other)
return not (self == other)
LookupPrivilegeValue = ctypes.windll.advapi32.LookupPrivilegeValueW
LookupPrivilegeValue.argtypes = (
ctypes.wintypes.LPWSTR, # system name
ctypes.wintypes.LPWSTR, # name
ctypes.wintypes.LPWSTR, # system name
ctypes.wintypes.LPWSTR, # name
ctypes.POINTER(LUID),
)
)
LookupPrivilegeValue.restype = ctypes.wintypes.BOOL
class TOKEN_INFORMATION_CLASS:
TokenUser = 1
TokenGroups = 2
TokenPrivileges = 3
# ... see http://msdn.microsoft.com/en-us/library/aa379626%28VS.85%29.aspx
SE_PRIVILEGE_ENABLED_BY_DEFAULT = 0x00000001
SE_PRIVILEGE_ENABLED = 0x00000002
SE_PRIVILEGE_REMOVED = 0x00000004
SE_PRIVILEGE_USED_FOR_ACCESS = 0x80000000
class LUID_AND_ATTRIBUTES(ctypes.Structure):
_fields_ = [
('LUID', LUID),
('attributes', ctypes.wintypes.DWORD),
]
]
def is_enabled(self):
return bool(self.attributes & SE_PRIVILEGE_ENABLED)
@ -50,47 +55,53 @@ class LUID_AND_ATTRIBUTES(ctypes.Structure):
size = ctypes.wintypes.DWORD(10240)
buf = ctypes.create_unicode_buffer(size.value)
res = LookupPrivilegeName(None, self.LUID, buf, size)
if res == 0: raise RuntimeError
if res == 0:
raise RuntimeError
return buf[:size.value]
def __str__(self):
res = self.get_name()
if self.is_enabled(): res += ' (enabled)'
if self.is_enabled():
res += ' (enabled)'
return res
LookupPrivilegeName = ctypes.windll.advapi32.LookupPrivilegeNameW
LookupPrivilegeName.argtypes = (
ctypes.wintypes.LPWSTR, # lpSystemName
ctypes.POINTER(LUID), # lpLuid
ctypes.wintypes.LPWSTR, # lpName
ctypes.POINTER(ctypes.wintypes.DWORD), # cchName
)
ctypes.wintypes.LPWSTR, # lpSystemName
ctypes.POINTER(LUID), # lpLuid
ctypes.wintypes.LPWSTR, # lpName
ctypes.POINTER(ctypes.wintypes.DWORD), # cchName
)
LookupPrivilegeName.restype = ctypes.wintypes.BOOL
class TOKEN_PRIVILEGES(ctypes.Structure):
_fields_ = [
('count', ctypes.wintypes.DWORD),
('privileges', LUID_AND_ATTRIBUTES*0),
]
('privileges', LUID_AND_ATTRIBUTES * 0),
]
def get_array(self):
array_type = LUID_AND_ATTRIBUTES*self.count
privileges = ctypes.cast(self.privileges, ctypes.POINTER(array_type)).contents
array_type = LUID_AND_ATTRIBUTES * self.count
privileges = ctypes.cast(
self.privileges, ctypes.POINTER(array_type)).contents
return privileges
def __iter__(self):
return iter(self.get_array())
PTOKEN_PRIVILEGES = ctypes.POINTER(TOKEN_PRIVILEGES)
GetTokenInformation = ctypes.windll.advapi32.GetTokenInformation
GetTokenInformation.argtypes = [
ctypes.wintypes.HANDLE, # TokenHandle
ctypes.c_uint, # TOKEN_INFORMATION_CLASS value
ctypes.c_void_p, # TokenInformation
ctypes.wintypes.DWORD, # TokenInformationLength
ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength
]
ctypes.wintypes.HANDLE, # TokenHandle
ctypes.c_uint, # TOKEN_INFORMATION_CLASS value
ctypes.c_void_p, # TokenInformation
ctypes.wintypes.DWORD, # TokenInformationLength
ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength
]
GetTokenInformation.restype = ctypes.wintypes.BOOL
# http://msdn.microsoft.com/en-us/library/aa375202%28VS.85%29.aspx
@ -102,5 +113,5 @@ AdjustTokenPrivileges.argtypes = [
PTOKEN_PRIVILEGES, # NewState (optional)
ctypes.wintypes.DWORD, # BufferLength of PreviousState
PTOKEN_PRIVILEGES, # PreviousState (out, optional)
ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength
]
ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength
]

View file

@ -5,5 +5,7 @@ TOKEN_ALL_ACCESS = 0xf01ff
GetCurrentProcess = ctypes.windll.kernel32.GetCurrentProcess
GetCurrentProcess.restype = ctypes.wintypes.HANDLE
OpenProcessToken = ctypes.windll.advapi32.OpenProcessToken
OpenProcessToken.argtypes = (ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD, ctypes.POINTER(ctypes.wintypes.HANDLE))
OpenProcessToken.argtypes = (
ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD,
ctypes.POINTER(ctypes.wintypes.HANDLE))
OpenProcessToken.restype = ctypes.wintypes.BOOL

View file

@ -60,12 +60,15 @@ POLICY_EXECUTE = (
POLICY_VIEW_LOCAL_INFORMATION |
POLICY_LOOKUP_NAMES)
class TokenAccess:
TOKEN_QUERY = 0x8
class TokenInformationClass:
TokenUser = 1
class TOKEN_USER(ctypes.Structure):
num = 1
_fields_ = [
@ -100,6 +103,7 @@ class SECURITY_DESCRIPTOR(ctypes.Structure):
('Dacl', ctypes.c_void_p),
]
class SECURITY_ATTRIBUTES(ctypes.Structure):
"""
typedef struct _SECURITY_ATTRIBUTES {
@ -126,3 +130,10 @@ class SECURITY_ATTRIBUTES(ctypes.Structure):
def descriptor(self, value):
self._descriptor = value
self.lpSecurityDescriptor = ctypes.addressof(value)
ctypes.windll.advapi32.SetSecurityDescriptorOwner.argtypes = (
ctypes.POINTER(SECURITY_DESCRIPTOR),
ctypes.c_void_p,
ctypes.wintypes.BOOL,
)

View file

@ -1,6 +1,7 @@
import ctypes.wintypes
BOOL = ctypes.wintypes.BOOL
class SHELLSTATE(ctypes.Structure):
_fields_ = [
('show_all_objects', BOOL, 1),
@ -34,6 +35,7 @@ class SHELLSTATE(ctypes.Structure):
('spare_flags', ctypes.wintypes.UINT, 13),
]
SSF_SHOWALLOBJECTS = 0x00000001
"The fShowAllObjects member is being requested."
@ -62,7 +64,13 @@ SSF_SHOWATTRIBCOL = 0x00000100
"The fShowAttribCol member is being requested. (Windows Vista: Not used.)"
SSF_DESKTOPHTML = 0x00000200
"The fDesktopHTML member is being requested. Set is not available. Instead, for versions of Microsoft Windows prior to Windows XP, enable Desktop HTML by IActiveDesktop. The use of IActiveDesktop for this purpose, however, is not recommended for Windows XP and later versions of Windows, and is deprecated in Windows Vista."
"""
The fDesktopHTML member is being requested. Set is not available.
Instead, for versions of Microsoft Windows prior to Windows XP,
enable Desktop HTML by IActiveDesktop. The use of IActiveDesktop
for this purpose, however, is not recommended for Windows XP and
later versions of Windows, and is deprecated in Windows Vista.
"""
SSF_WIN95CLASSIC = 0x00000400
"The fWin95Classic member is being requested."
@ -117,6 +125,6 @@ SHGetSetSettings = ctypes.windll.shell32.SHGetSetSettings
SHGetSetSettings.argtypes = [
ctypes.POINTER(SHELLSTATE),
ctypes.wintypes.DWORD,
ctypes.wintypes.BOOL, # get or set (True: set)
]
ctypes.wintypes.BOOL, # get or set (True: set)
]
SHGetSetSettings.restype = None

View file

@ -6,7 +6,7 @@ SystemParametersInfo.argtypes = (
ctypes.wintypes.UINT,
ctypes.c_void_p,
ctypes.wintypes.UINT,
)
)
SPI_GETACTIVEWINDOWTRACKING = 0x1000
SPI_SETACTIVEWINDOWTRACKING = 0x1001

View file

@ -5,20 +5,22 @@ import re
import itertools
from contextlib import contextmanager
import io
import six
import ctypes
from ctypes import windll
import six
from six.moves import map
from jaraco.windows.api import clipboard, memory
from jaraco.windows.error import handle_nonzero_success, WindowsError
from jaraco.windows.memory import LockedMemory
__all__ = (
'CF_TEXT', 'GetClipboardData', 'CloseClipboard',
'GetClipboardData', 'CloseClipboard',
'SetClipboardData', 'OpenClipboard',
)
def OpenClipboard(owner=None):
"""
Open the clipboard.
@ -30,9 +32,14 @@ def OpenClipboard(owner=None):
"""
handle_nonzero_success(windll.user32.OpenClipboard(owner))
CloseClipboard = lambda: handle_nonzero_success(windll.user32.CloseClipboard())
def CloseClipboard():
handle_nonzero_success(windll.user32.CloseClipboard())
data_handlers = dict()
def handles(*formats):
def register(func):
for format in formats:
@ -40,36 +47,43 @@ def handles(*formats):
return func
return register
def nts(s):
def nts(buffer):
"""
Null Terminated String
Get the portion of s up to a null character.
Get the portion of bytestring buffer up to a null character.
"""
result, null, rest = s.partition('\x00')
result, null, rest = buffer.partition('\x00')
return result
@handles(clipboard.CF_DIBV5, clipboard.CF_DIB)
def raw_data(handle):
return LockedMemory(handle).data
@handles(clipboard.CF_TEXT)
def text_string(handle):
return nts(raw_data(handle))
@handles(clipboard.CF_UNICODETEXT)
def unicode_string(handle):
return nts(raw_data(handle).decode('utf-16'))
@handles(clipboard.CF_BITMAP)
def as_bitmap(handle):
# handle is HBITMAP
raise NotImplementedError("Can't convert to DIB")
# todo: use GetDIBits http://msdn.microsoft.com/en-us/library/dd144879%28v=VS.85%29.aspx
# todo: use GetDIBits http://msdn.microsoft.com
# /en-us/library/dd144879%28v=VS.85%29.aspx
@handles(clipboard.CF_HTML)
class HTMLSnippet(object):
def __init__(self, handle):
self.data = text_string(handle)
self.data = nts(raw_data(handle).decode('utf-8'))
self.headers = self.parse_headers(self.data)
@property
@ -79,11 +93,13 @@ class HTMLSnippet(object):
@staticmethod
def parse_headers(data):
d = io.StringIO(data)
def header_line(line):
return re.match('(\w+):(.*)', line)
headers = itertools.imap(header_line, d)
headers = map(header_line, d)
# grab headers until they no longer match
headers = itertools.takewhile(bool, headers)
def best_type(value):
try:
return int(value)
@ -101,26 +117,34 @@ class HTMLSnippet(object):
)
return dict(pairs)
def GetClipboardData(type=clipboard.CF_UNICODETEXT):
if not type in data_handlers:
if type not in data_handlers:
raise NotImplementedError("No support for data of type %d" % type)
handle = clipboard.GetClipboardData(type)
if handle is None:
raise TypeError("No clipboard data of type %d" % type)
return data_handlers[type](handle)
EmptyClipboard = lambda: handle_nonzero_success(windll.user32.EmptyClipboard())
def EmptyClipboard():
handle_nonzero_success(windll.user32.EmptyClipboard())
def SetClipboardData(type, content):
"""
Modeled after http://msdn.microsoft.com/en-us/library/ms649016%28VS.85%29.aspx#_win32_Copying_Information_to_the_Clipboard
Modeled after http://msdn.microsoft.com
/en-us/library/ms649016%28VS.85%29.aspx
#_win32_Copying_Information_to_the_Clipboard
"""
allocators = {
clipboard.CF_TEXT: ctypes.create_string_buffer,
clipboard.CF_UNICODETEXT: ctypes.create_unicode_buffer,
clipboard.CF_HTML: ctypes.create_string_buffer,
}
if not type in allocators:
raise NotImplementedError("Only text types are supported at this time")
if type not in allocators:
raise NotImplementedError(
"Only text and HTML types are supported at this time")
# allocate the memory for the data
content = allocators[type](content)
flags = memory.GMEM_MOVEABLE
@ -132,47 +156,57 @@ def SetClipboardData(type, content):
if result is None:
raise WindowsError()
def set_text(source):
with context():
EmptyClipboard()
SetClipboardData(clipboard.CF_TEXT, source)
def get_text():
with context():
result = GetClipboardData(clipboard.CF_TEXT)
return result
def set_unicode_text(source):
with context():
EmptyClipboard()
SetClipboardData(clipboard.CF_UNICODETEXT, source)
def get_unicode_text():
with context():
return GetClipboardData()
def get_html():
with context():
result = GetClipboardData(clipboard.CF_HTML)
return result
def set_html(source):
with context():
EmptyClipboard()
SetClipboardData(clipboard.CF_UNICODETEXT, source)
def get_image():
with context():
return GetClipboardData(clipboard.CF_DIB)
def paste_stdout():
getter = get_unicode_text if six.PY3 else get_text
sys.stdout.write(getter())
def stdin_copy():
setter = set_unicode_text if six.PY3 else set_text
setter(sys.stdin.read())
@contextmanager
def context():
OpenClipboard()
@ -181,10 +215,12 @@ def context():
finally:
CloseClipboard()
def get_formats():
with context():
format_index = 0
while True:
format_index = clipboard.EnumClipboardFormats(format_index)
if format_index == 0: break
if format_index == 0:
break
yield format_index

View file

@ -3,17 +3,20 @@ import ctypes
import jaraco.windows.api.credential as api
from . import error
CRED_TYPE_GENERIC=1
CRED_TYPE_GENERIC = 1
def CredDelete(TargetName, Type, Flags=0):
error.handle_nonzero_success(api.CredDelete(TargetName, Type, Flags))
def CredRead(TargetName, Type, Flags=0):
cred_pointer = api.PCREDENTIAL()
res = api.CredRead(TargetName, Type, Flags, ctypes.byref(cred_pointer))
error.handle_nonzero_success(res)
return cred_pointer.contents
def CredWrite(Credential, Flags=0):
res = api.CredWrite(Credential, Flags)
error.handle_nonzero_success(res)

View file

@ -14,6 +14,11 @@ import ctypes
from ctypes import wintypes
from jaraco.windows.error import handle_nonzero_success
# for type declarations
__import__('jaraco.windows.api.memory')
class DATA_BLOB(ctypes.Structure):
r"""
A data blob structure for use with MS DPAPI functions.
@ -29,7 +34,7 @@ class DATA_BLOB(ctypes.Structure):
_fields_ = [
('data_size', wintypes.DWORD),
('data', ctypes.c_void_p),
]
]
def __init__(self, data=None):
super(DATA_BLOB, self).__init__()
@ -48,7 +53,7 @@ class DATA_BLOB(ctypes.Structure):
def get_data(self):
"Get the data for this blob"
array = ctypes.POINTER(ctypes.c_char*len(self))
array = ctypes.POINTER(ctypes.c_char * len(self))
return ctypes.cast(self.data, array).contents.raw
def __len__(self):
@ -65,38 +70,40 @@ class DATA_BLOB(ctypes.Structure):
"""
ctypes.windll.kernel32.LocalFree(self.data)
p_DATA_BLOB = ctypes.POINTER(DATA_BLOB)
_CryptProtectData = ctypes.windll.crypt32.CryptProtectData
_CryptProtectData.argtypes = [
p_DATA_BLOB, # data in
wintypes.LPCWSTR, # data description
wintypes.LPCWSTR, # data description
p_DATA_BLOB, # optional entropy
ctypes.c_void_p, # reserved
ctypes.c_void_p, # POINTER(CRYPTPROTECT_PROMPTSTRUCT), # prompt struct
wintypes.DWORD, # flags
p_DATA_BLOB, # data out
]
]
_CryptProtectData.restype = wintypes.BOOL
_CryptUnprotectData = ctypes.windll.crypt32.CryptUnprotectData
_CryptUnprotectData.argtypes = [
p_DATA_BLOB, # data in
ctypes.POINTER(wintypes.LPWSTR), # data description
ctypes.POINTER(wintypes.LPWSTR), # data description
p_DATA_BLOB, # optional entropy
ctypes.c_void_p, # reserved
ctypes.c_void_p, # POINTER(CRYPTPROTECT_PROMPTSTRUCT), # prompt struct
wintypes.DWORD, # flags
p_DATA_BLOB, # data out
]
]
_CryptUnprotectData.restype = wintypes.BOOL
CRYPTPROTECT_UI_FORBIDDEN = 0x01
def CryptProtectData(
data, description=None, optional_entropy=None,
prompt_struct=None, flags=0,
):
):
"""
Encrypt data
"""
@ -108,17 +115,19 @@ def CryptProtectData(
data_in,
description,
entropy,
None, # reserved
None, # reserved
prompt_struct,
flags,
data_out,
)
)
handle_nonzero_success(res)
res = data_out.get_data()
data_out.free()
return res
def CryptUnprotectData(data, optional_entropy=None, prompt_struct=None, flags=0):
def CryptUnprotectData(
data, optional_entropy=None, prompt_struct=None, flags=0):
"""
Returns a tuple of (description, data) where description is the
the description that was passed to the CryptProtectData call and
@ -132,11 +141,11 @@ def CryptUnprotectData(data, optional_entropy=None, prompt_struct=None, flags=0)
data_in,
ctypes.byref(ptr_description),
entropy,
None, # reserved
None, # reserved
prompt_struct,
flags | CRYPTPROTECT_UI_FORBIDDEN,
data_out,
)
)
handle_nonzero_success(res)
description = ptr_description.value
if ptr_description.value is not None:

View file

@ -7,7 +7,7 @@ import ctypes
import ctypes.wintypes
import six
winreg = six.moves.winreg
from six.moves import winreg
from jaraco.ui.editor import EditableFile
@ -19,17 +19,21 @@ from .registry import key_values as registry_key_values
def SetEnvironmentVariable(name, value):
error.handle_nonzero_success(environ.SetEnvironmentVariable(name, value))
def ClearEnvironmentVariable(name):
error.handle_nonzero_success(environ.SetEnvironmentVariable(name, None))
def GetEnvironmentVariable(name):
max_size = 2**15-1
max_size = 2**15 - 1
buffer = ctypes.create_unicode_buffer(max_size)
error.handle_nonzero_success(environ.GetEnvironmentVariable(name, buffer, max_size))
error.handle_nonzero_success(
environ.GetEnvironmentVariable(name, buffer, max_size))
return buffer.value
###
class RegisteredEnvironment(object):
"""
Manages the environment variables as set in the Windows Registry.
@ -68,7 +72,7 @@ class RegisteredEnvironment(object):
return class_.delete(name)
do_append = options.append or (
name.upper() in ('PATH', 'PATHEXT') and not options.replace
)
)
if do_append:
sep = ';'
values = class_.get_values_list(name, sep) + [value]
@ -86,8 +90,8 @@ class RegisteredEnvironment(object):
if value in values:
return
new_value = sep.join(values + [value])
winreg.SetValueEx(class_.key, name, 0, winreg.REG_EXPAND_SZ,
new_value)
winreg.SetValueEx(
class_.key, name, 0, winreg.REG_EXPAND_SZ, new_value)
class_.notify()
@classmethod
@ -98,7 +102,7 @@ class RegisteredEnvironment(object):
value
for value in values
if value_substring.lower() not in value.lower()
]
]
values = sep.join(new_values)
winreg.SetValueEx(class_.key, name, 0, winreg.REG_EXPAND_SZ, values)
class_.notify()
@ -133,32 +137,38 @@ class RegisteredEnvironment(object):
res = message.SendMessageTimeout(
message.HWND_BROADCAST,
message.WM_SETTINGCHANGE,
0, # wparam must be null
0, # wparam must be null
'Environment',
message.SMTO_ABORTIFHUNG,
5000, # timeout in ms
5000, # timeout in ms
return_val,
)
error.handle_nonzero_success(res)
class MachineRegisteredEnvironment(RegisteredEnvironment):
path = r'SYSTEM\CurrentControlSet\Control\Session Manager\Environment'
hklm = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE)
try:
key = winreg.OpenKey(hklm, path, 0,
key = winreg.OpenKey(
hklm, path, 0,
winreg.KEY_READ | winreg.KEY_WRITE)
except WindowsError:
key = winreg.OpenKey(hklm, path, 0, winreg.KEY_READ)
class UserRegisteredEnvironment(RegisteredEnvironment):
hkcu = winreg.ConnectRegistry(None, winreg.HKEY_CURRENT_USER)
key = winreg.OpenKey(hkcu, 'Environment', 0,
key = winreg.OpenKey(
hkcu, 'Environment', 0,
winreg.KEY_READ | winreg.KEY_WRITE)
def trim(s):
from textwrap import dedent
return dedent(s).strip()
def enver(*args):
"""
%prog [<name>=[value]]
@ -200,23 +210,27 @@ def enver(*args):
default=MachineRegisteredEnvironment,
dest='class_',
help="Use the current user's environment",
)
parser.add_option('-a', '--append',
)
parser.add_option(
'-a', '--append',
action='store_true', default=False,
help="Append the value to any existing value (default for PATH and PATHEXT)",)
help="Append the value to any existing value (default for PATH and PATHEXT)",
)
parser.add_option(
'-r', '--replace',
action='store_true', default=False,
help="Replace any existing value (used to override default append for PATH and PATHEXT)",
)
help="Replace any existing value (used to override default append "
"for PATH and PATHEXT)",
)
parser.add_option(
'--remove-value', action='store_true', default=False,
help="Remove any matching values from a semicolon-separated multi-value variable",
)
help="Remove any matching values from a semicolon-separated "
"multi-value variable",
)
parser.add_option(
'-e', '--edit', action='store_true', default=False,
help="Edit the value in a local editor",
)
)
options, args = parser.parse_args(*args)
try:
@ -224,7 +238,7 @@ def enver(*args):
if args:
parser.error("Too many parameters specified")
raise SystemExit(1)
if not '=' in param and not options.edit:
if '=' not in param and not options.edit:
parser.error("Expected <name>= or <name>=<value>")
raise SystemExit(2)
name, sep, value = param.partition('=')
@ -238,5 +252,6 @@ def enver(*args):
except IndexError:
options.class_.show()
if __name__ == '__main__':
enver()

View file

@ -6,9 +6,12 @@ import ctypes
import ctypes.wintypes
import six
builtins = six.moves.builtins
__import__('jaraco.windows.api.memory')
def format_system_message(errno):
"""
Call FormatMessage with a system error number to retrieve
@ -35,7 +38,7 @@ def format_system_message(errno):
ctypes.byref(result_buffer),
buffer_size,
arguments,
)
)
# note the following will cause an infinite loop if GetLastError
# repeatedly returns an error that cannot be formatted, although
# this should not happen.
@ -46,13 +49,16 @@ def format_system_message(errno):
class WindowsError(builtins.WindowsError):
"more info about errors at http://msdn.microsoft.com/en-us/library/ms681381(VS.85).aspx"
"""
More info about errors at
http://msdn.microsoft.com/en-us/library/ms681381(VS.85).aspx
"""
def __init__(self, value=None):
if value is None:
value = ctypes.windll.kernel32.GetLastError()
strerror = format_system_message(value)
if sys.version_info > (3,3):
if sys.version_info > (3, 3):
args = 0, strerror, None, value
else:
args = value, strerror
@ -72,6 +78,7 @@ class WindowsError(builtins.WindowsError):
def __repr__(self):
return '{self.__class__.__name__}({self.winerror})'.format(**vars())
def handle_nonzero_success(result):
if result == 0:
raise WindowsError()

View file

@ -6,7 +6,8 @@ import win32api
import win32evtlog
import win32evtlogutil
error = win32api.error # The error the evtlog module raises.
error = win32api.error # The error the evtlog module raises.
class EventLog(object):
def __init__(self, name="Application", machine_name=None):
@ -29,6 +30,7 @@ class EventLog(object):
win32evtlog.EVENTLOG_BACKWARDS_READ
| win32evtlog.EVENTLOG_SEQUENTIAL_READ
)
def get_records(self, flags=_default_flags):
with self:
while True:

View file

@ -7,19 +7,24 @@ import sys
import operator
import collections
import functools
from ctypes import (POINTER, byref, cast, create_unicode_buffer,
import stat
from ctypes import (
POINTER, byref, cast, create_unicode_buffer,
create_string_buffer, windll)
from ctypes.wintypes import LPWSTR
import nt
import posixpath
import six
from six.moves import builtins, filter, map
from jaraco.structures import binary
from jaraco.text import local_format as lf
from jaraco.windows.error import WindowsError, handle_nonzero_success
import jaraco.windows.api.filesystem as api
from jaraco.windows import reparse
def mklink():
"""
Like cmd.exe's mklink except it will infer directory status of the
@ -27,7 +32,8 @@ def mklink():
"""
from optparse import OptionParser
parser = OptionParser(usage="usage: %prog [options] link target")
parser.add_option('-d', '--directory',
parser.add_option(
'-d', '--directory',
help="Target is a directory (only necessary if not present)",
action="store_true")
options, args = parser.parse_args()
@ -38,6 +44,7 @@ def mklink():
symlink(target, link, options.directory)
sys.stdout.write("Symbolic link created: %(link)s --> %(target)s\n" % vars())
def _is_target_a_directory(link, rel_target):
"""
If creating a symlink from link to a target, determine if target
@ -46,15 +53,20 @@ def _is_target_a_directory(link, rel_target):
target = os.path.join(os.path.dirname(link), rel_target)
return os.path.isdir(target)
def symlink(target, link, target_is_directory = False):
def symlink(target, link, target_is_directory=False):
"""
An implementation of os.symlink for Windows (Vista and greater)
"""
target_is_directory = (target_is_directory or
_is_target_a_directory(link, target))
target_is_directory = (
target_is_directory or
_is_target_a_directory(link, target)
)
# normalize the target (MS symlinks don't respect forward slashes)
target = os.path.normpath(target)
handle_nonzero_success(api.CreateSymbolicLink(link, target, target_is_directory))
handle_nonzero_success(
api.CreateSymbolicLink(link, target, target_is_directory))
def link(target, link):
"""
@ -62,6 +74,7 @@ def link(target, link):
"""
handle_nonzero_success(api.CreateHardLink(link, target, None))
def is_reparse_point(path):
"""
Determine if the given path is a reparse point.
@ -74,10 +87,12 @@ def is_reparse_point(path):
and bool(res & api.FILE_ATTRIBUTE_REPARSE_POINT)
)
def islink(path):
"Determine if the given path is a symlink"
return is_reparse_point(path) and is_symlink(path)
def _patch_path(path):
"""
Paths have a max length of api.MAX_PATH characters (260). If a target path
@ -86,13 +101,15 @@ def _patch_path(path):
See http://msdn.microsoft.com/en-us/library/aa365247%28v=vs.85%29.aspx for
details.
"""
if path.startswith('\\\\?\\'): return path
if path.startswith('\\\\?\\'):
return path
abs_path = os.path.abspath(path)
if not abs_path[1] == ':':
# python doesn't include the drive letter, but \\?\ requires it
abs_path = os.getcwd()[:2] + abs_path
return '\\\\?\\' + abs_path
def is_symlink(path):
"""
Assuming path is a reparse point, determine if it's a symlink.
@ -102,11 +119,13 @@ def is_symlink(path):
return _is_symlink(next(find_files(path)))
except WindowsError as orig_error:
tmpl = "Error accessing {path}: {orig_error.message}"
raise builtins.WindowsError(lf(tmpl))
raise builtins.WindowsError(tmpl.format(**locals()))
def _is_symlink(find_data):
return find_data.reserved[0] == api.IO_REPARSE_TAG_SYMLINK
def find_files(spec):
"""
A pythonic wrapper around the FindFirstFile/FindNextFile win32 api.
@ -133,11 +152,13 @@ def find_files(spec):
error = WindowsError()
if error.code == api.ERROR_NO_MORE_FILES:
break
else: raise error
else:
raise error
# todo: how to close handle when generator is destroyed?
# hint: catch GeneratorExit
windll.kernel32.FindClose(handle)
def get_final_path(path):
"""
For a given path, determine the ultimate location of that path.
@ -150,7 +171,9 @@ def get_final_path(path):
trace_symlink_target instead.
"""
desired_access = api.NULL
share_mode = api.FILE_SHARE_READ | api.FILE_SHARE_WRITE | api.FILE_SHARE_DELETE
share_mode = (
api.FILE_SHARE_READ | api.FILE_SHARE_WRITE | api.FILE_SHARE_DELETE
)
security_attributes = api.LPSECURITY_ATTRIBUTES() # NULL pointer
hFile = api.CreateFile(
path,
@ -160,15 +183,17 @@ def get_final_path(path):
api.OPEN_EXISTING,
api.FILE_FLAG_BACKUP_SEMANTICS,
api.NULL,
)
)
if hFile == api.INVALID_HANDLE_VALUE:
raise WindowsError()
buf_size = api.GetFinalPathNameByHandle(hFile, api.LPWSTR(), 0, api.VOLUME_NAME_DOS)
buf_size = api.GetFinalPathNameByHandle(
hFile, LPWSTR(), 0, api.VOLUME_NAME_DOS)
handle_nonzero_success(buf_size)
buf = create_unicode_buffer(buf_size)
result_length = api.GetFinalPathNameByHandle(hFile, buf, len(buf), api.VOLUME_NAME_DOS)
result_length = api.GetFinalPathNameByHandle(
hFile, buf, len(buf), api.VOLUME_NAME_DOS)
assert result_length < len(buf)
handle_nonzero_success(result_length)
@ -176,22 +201,83 @@ def get_final_path(path):
return buf[:result_length]
def compat_stat(path):
"""
Generate stat as found on Python 3.2 and later.
"""
stat = os.stat(path)
info = get_file_info(path)
# rewrite st_ino, st_dev, and st_nlink based on file info
return nt.stat_result(
(stat.st_mode,) +
(info.file_index, info.volume_serial_number, info.number_of_links) +
stat[4:]
)
def samefile(f1, f2):
"""
Backport of samefile from Python 3.2 with support for Windows.
"""
return posixpath.samestat(compat_stat(f1), compat_stat(f2))
def get_file_info(path):
# open the file the same way CPython does in posixmodule.c
desired_access = api.FILE_READ_ATTRIBUTES
share_mode = 0
security_attributes = None
creation_disposition = api.OPEN_EXISTING
flags_and_attributes = (
api.FILE_ATTRIBUTE_NORMAL |
api.FILE_FLAG_BACKUP_SEMANTICS |
api.FILE_FLAG_OPEN_REPARSE_POINT
)
template_file = None
handle = api.CreateFile(
path,
desired_access,
share_mode,
security_attributes,
creation_disposition,
flags_and_attributes,
template_file,
)
if handle == api.INVALID_HANDLE_VALUE:
raise WindowsError()
info = api.BY_HANDLE_FILE_INFORMATION()
res = api.GetFileInformationByHandle(handle, info)
handle_nonzero_success(res)
handle_nonzero_success(api.CloseHandle(handle))
return info
def GetBinaryType(filepath):
res = api.DWORD()
handle_nonzero_success(api._GetBinaryType(filepath, res))
return res
def _make_null_terminated_list(obs):
obs = _makelist(obs)
if obs is None: return
if obs is None:
return
return u'\x00'.join(obs) + u'\x00\x00'
def _makelist(ob):
if ob is None: return
if ob is None:
return
if not isinstance(ob, (list, tuple, set)):
return [ob]
return ob
def SHFileOperation(operation, from_, to=None, flags=[]):
flags = functools.reduce(operator.or_, flags, 0)
from_ = _make_null_terminated_list(from_)
@ -201,6 +287,7 @@ def SHFileOperation(operation, from_, to=None, flags=[]):
if res != 0:
raise RuntimeError("SHFileOperation returned %d" % res)
def join(*paths):
r"""
Wrapper around os.path.join that works with Windows drive letters.
@ -214,17 +301,17 @@ def join(*paths):
drive = next(filter(None, reversed(drives)), '')
return os.path.join(drive, os.path.join(*paths))
def resolve_path(target, start=os.path.curdir):
r"""
Find a path from start to target where target is relative to start.
>>> orig_wd = os.getcwd()
>>> os.chdir('c:\\windows') # so we know what the working directory is
>>> tmp = str(getfixture('tmpdir_as_cwd'))
>>> findpath('d:\\')
'd:\\'
>>> findpath('d:\\', 'c:\\windows')
>>> findpath('d:\\', tmp)
'd:\\'
>>> findpath('\\bar', 'd:\\')
@ -239,11 +326,11 @@ def resolve_path(target, start=os.path.curdir):
>>> findpath('\\baz', 'd:\\foo\\bar') # fails with '\\baz'
'd:\\baz'
>>> os.path.abspath(findpath('\\bar'))
>>> os.path.abspath(findpath('\\bar')).lower()
'c:\\bar'
>>> os.path.abspath(findpath('bar'))
'c:\\windows\\bar'
'...\\bar'
>>> findpath('..', 'd:\\foo\\bar')
'd:\\foo'
@ -254,8 +341,10 @@ def resolve_path(target, start=os.path.curdir):
"""
return os.path.normpath(join(start, target))
findpath = resolve_path
def trace_symlink_target(link):
"""
Given a file that is known to be a symlink, trace it to its ultimate
@ -273,6 +362,7 @@ def trace_symlink_target(link):
link = resolve_path(link, orig)
return link
def readlink(link):
"""
readlink(link) -> target
@ -286,12 +376,13 @@ def readlink(link):
api.OPEN_EXISTING,
api.FILE_FLAG_OPEN_REPARSE_POINT | api.FILE_FLAG_BACKUP_SEMANTICS,
None,
)
)
if handle == api.INVALID_HANDLE_VALUE:
raise WindowsError()
res = reparse.DeviceIoControl(handle, api.FSCTL_GET_REPARSE_POINT, None, 10240)
res = reparse.DeviceIoControl(
handle, api.FSCTL_GET_REPARSE_POINT, None, 10240)
bytes = create_string_buffer(res)
p_rdb = cast(bytes, POINTER(api.REPARSE_DATA_BUFFER))
@ -302,6 +393,7 @@ def readlink(link):
handle_nonzero_success(api.CloseHandle(handle))
return rdb.get_substitute_name()
def patch_os_module():
"""
jaraco.windows provides the os.symlink and os.readlink functions.
@ -313,6 +405,7 @@ def patch_os_module():
if not hasattr(os, 'readlink'):
os.readlink = readlink
def find_symlinks(root):
for dirpath, dirnames, filenames in os.walk(root):
for name in dirnames + filenames:
@ -323,6 +416,7 @@ def find_symlinks(root):
if name in dirnames:
dirnames.remove(name)
def find_symlinks_cmd():
"""
%prog [start-path]
@ -333,7 +427,8 @@ def find_symlinks_cmd():
from textwrap import dedent
parser = OptionParser(usage=dedent(find_symlinks_cmd.__doc__).strip())
options, args = parser.parse_args()
if not args: args = ['.']
if not args:
args = ['.']
root = args.pop()
if args:
parser.error("unexpected argument(s)")
@ -346,8 +441,19 @@ def find_symlinks_cmd():
except KeyboardInterrupt:
pass
@six.add_metaclass(binary.BitMask)
class FileAttributes(int):
# extract the values from the stat module on Python 3.5
# and later.
locals().update(
(name.split('FILE_ATTRIBUTES_')[1].lower(), value)
for name, value in vars(stat).items()
if name.startswith('FILE_ATTRIBUTES_')
)
# For Python 3.4 and earlier, define the constants here
archive = 0x20
compressed = 0x800
hidden = 0x2
@ -364,11 +470,16 @@ class FileAttributes(int):
temporary = 0x100
virtual = 0x10000
def GetFileAttributes(filepath):
attrs = api.GetFileAttributes(filepath)
if attrs == api.INVALID_FILE_ATTRIBUTES:
raise WindowsError()
return FileAttributes(attrs)
@classmethod
def get(cls, filepath):
attrs = api.GetFileAttributes(filepath)
if attrs == api.INVALID_FILE_ATTRIBUTES:
raise WindowsError()
return cls(attrs)
GetFileAttributes = FileAttributes.get
def SetFileAttributes(filepath, *attrs):
"""
@ -382,8 +493,8 @@ def SetFileAttributes(filepath, *attrs):
"""
nice_names = collections.defaultdict(
lambda key: key,
hidden = 'FILE_ATTRIBUTE_HIDDEN',
read_only = 'FILE_ATTRIBUTE_READONLY',
hidden='FILE_ATTRIBUTE_HIDDEN',
read_only='FILE_ATTRIBUTE_READONLY',
)
flags = (getattr(api, nice_names[attr], attr) for attr in attrs)
flags = functools.reduce(operator.or_, flags)

View file

@ -0,0 +1,109 @@
from __future__ import unicode_literals
import os.path
# realpath taken from https://bugs.python.org/file38057/issue9949-v4.patch
def realpath(path):
if isinstance(path, str):
prefix = '\\\\?\\'
unc_prefix = prefix + 'UNC'
new_unc_prefix = '\\'
cwd = os.getcwd()
else:
prefix = b'\\\\?\\'
unc_prefix = prefix + b'UNC'
new_unc_prefix = b'\\'
cwd = os.getcwdb()
had_prefix = path.startswith(prefix)
path, ok = _resolve_path(cwd, path, {})
# The path returned by _getfinalpathname will always start with \\?\ -
# strip off that prefix unless it was already provided on the original
# path.
if not had_prefix:
# For UNC paths, the prefix will actually be \\?\UNC - handle that
# case as well.
if path.startswith(unc_prefix):
path = new_unc_prefix + path[len(unc_prefix):]
elif path.startswith(prefix):
path = path[len(prefix):]
return path
def _resolve_path(path, rest, seen):
# Windows normalizes the path before resolving symlinks; be sure to
# follow the same behavior.
rest = os.path.normpath(rest)
if isinstance(rest, str):
sep = '\\'
else:
sep = b'\\'
if os.path.isabs(rest):
drive, rest = os.path.splitdrive(rest)
path = drive + sep
rest = rest[1:]
while rest:
name, _, rest = rest.partition(sep)
new_path = os.path.join(path, name) if path else name
if os.path.exists(new_path):
if not rest:
# The whole path exists. Resolve it using the OS.
path = os.path._getfinalpathname(new_path)
else:
# The OS can resolve `new_path`; keep traversing the path.
path = new_path
elif not os.path.lexists(new_path):
# `new_path` does not exist on the filesystem at all. Use the
# OS to resolve `path`, if it exists, and then append the
# remainder.
if os.path.exists(path):
path = os.path._getfinalpathname(path)
rest = os.path.join(name, rest) if rest else name
return os.path.join(path, rest), True
else:
# We have a symbolic link that the OS cannot resolve. Try to
# resolve it ourselves.
# On Windows, symbolic link resolution can be partially or
# fully disabled [1]. The end result of a disabled symlink
# appears the same as a broken symlink (lexists() returns True
# but exists() returns False). And in both cases, the link can
# still be read using readlink(). Call stat() and check the
# resulting error code to ensure we don't circumvent the
# Windows symbolic link restrictions.
# [1] https://technet.microsoft.com/en-us/library/cc754077.aspx
try:
os.stat(new_path)
except OSError as e:
# WinError 1463: The symbolic link cannot be followed
# because its type is disabled.
if e.winerror == 1463:
raise
key = os.path.normcase(new_path)
if key in seen:
# This link has already been seen; try to use the
# previously resolved value.
path = seen[key]
if path is None:
# It has not yet been resolved, which means we must
# have a symbolic link loop. Return what we have
# resolved so far plus the remainder of the path (who
# cares about the Zen of Python?).
path = os.path.join(new_path, rest) if rest else new_path
return path, False
else:
# Mark this link as in the process of being resolved.
seen[key] = None
# Try to resolve it.
path, ok = _resolve_path(path, os.readlink(new_path), seen)
if ok:
# Resolution succeded; store the resolved value.
seen[key] = path
else:
# Resolution failed; punt.
return (os.path.join(path, rest) if rest else path), False
return path, True

View file

@ -17,6 +17,8 @@ from threading import Thread
import itertools
import logging
import six
from more_itertools.recipes import consume
import jaraco.text
@ -25,9 +27,11 @@ from jaraco.windows.api import event
log = logging.getLogger(__name__)
class NotifierException(Exception):
pass
class FileFilter(object):
def set_root(self, root):
self.root = root
@ -35,9 +39,11 @@ class FileFilter(object):
def _get_file_path(self, filename):
try:
filename = os.path.join(self.root, filename)
except AttributeError: pass
except AttributeError:
pass
return filename
class ModifiedTimeFilter(FileFilter):
"""
Returns true for each call where the modified time of the file is after
@ -53,6 +59,7 @@ class ModifiedTimeFilter(FileFilter):
log.debug('{filepath} last modified at {last_mod}.'.format(**vars()))
return last_mod > self.cutoff
class PatternFilter(FileFilter):
"""
Filter that returns True for files that match pattern (a regular
@ -60,13 +67,14 @@ class PatternFilter(FileFilter):
"""
def __init__(self, pattern):
self.pattern = (
re.compile(pattern) if isinstance(pattern, basestring)
re.compile(pattern) if isinstance(pattern, six.string_types)
else pattern
)
def __call__(self, file):
return bool(self.pattern.match(file, re.I))
class GlobFilter(PatternFilter):
"""
Filter that returns True for files that match the pattern (a glob
@ -102,6 +110,7 @@ class AggregateFilter(FileFilter):
def __call__(self, file):
return all(fil(file) for fil in self.filters)
class OncePerModFilter(FileFilter):
def __init__(self):
self.history = list()
@ -115,15 +124,18 @@ class OncePerModFilter(FileFilter):
del self.history[-50:]
return result
def files_with_path(files, path):
return (os.path.join(path, file) for file in files)
def get_file_paths(walk_result):
root, dirs, files = walk_result
return files_with_path(files, root)
class Notifier(object):
def __init__(self, root = '.', filters = []):
def __init__(self, root='.', filters=[]):
# assign the root, verify it exists
self.root = root
if not os.path.isdir(self.root):
@ -138,7 +150,8 @@ class Notifier(object):
def __del__(self):
try:
fs.FindCloseChangeNotification(self.hChange)
except: pass
except Exception:
pass
def _get_change_handle(self):
# set up to monitor the directory tree specified
@ -151,8 +164,8 @@ class Notifier(object):
# make sure it worked; if not, bail
INVALID_HANDLE_VALUE = fs.INVALID_HANDLE_VALUE
if self.hChange == INVALID_HANDLE_VALUE:
raise NotifierException('Could not set up directory change '
'notification')
raise NotifierException(
'Could not set up directory change notification')
@staticmethod
def _filtered_walk(path, file_filter):
@ -171,6 +184,7 @@ class Notifier(object):
def quit(self):
event.SetEvent(self.quit_event)
class BlockingNotifier(Notifier):
@staticmethod
@ -215,17 +229,18 @@ class BlockingNotifier(Notifier):
result = next(results)
return result
class ThreadedNotifier(BlockingNotifier, Thread):
r"""
ThreadedNotifier provides a simple interface that calls the handler
for each file rooted in root that passes the filters. It runs as its own
thread, so must be started as such.
thread, so must be started as such::
>>> notifier = ThreadedNotifier('c:\\', handler = StreamHandler()) # doctest: +SKIP
>>> notifier.start() # doctest: +SKIP
C:\Autoexec.bat changed.
notifier = ThreadedNotifier('c:\\', handler = StreamHandler())
notifier.start()
C:\Autoexec.bat changed.
"""
def __init__(self, root = '.', filters = [], handler = lambda file: None):
def __init__(self, root='.', filters=[], handler=lambda file: None):
# init notifier stuff
BlockingNotifier.__init__(self, root, filters)
# init thread stuff
@ -242,13 +257,14 @@ class ThreadedNotifier(BlockingNotifier, Thread):
for file in self.get_changed_files():
self.handle(file)
class StreamHandler(object):
"""
StreamHandler: a sample handler object for use with the threaded
notifier that will announce by writing to the supplied stream
(stdout by default) the name of the file.
"""
def __init__(self, output = sys.stdout):
def __init__(self, output=sys.stdout):
self.output = output
def __call__(self, filename):

View file

@ -14,20 +14,21 @@ from jaraco.windows.api import errors, inet
def GetAdaptersAddresses():
size = ctypes.c_ulong()
res = inet.GetAdaptersAddresses(0,0,None, None,size)
res = inet.GetAdaptersAddresses(0, 0, None, None, size)
if res != errors.ERROR_BUFFER_OVERFLOW:
raise RuntimeError("Error getting structure length (%d)" % res)
print(size.value)
pointer_type = ctypes.POINTER(inet.IP_ADAPTER_ADDRESSES)
buffer = ctypes.create_string_buffer(size.value)
struct_p = ctypes.cast(buffer, pointer_type)
res = inet.GetAdaptersAddresses(0,0,None, struct_p, size)
res = inet.GetAdaptersAddresses(0, 0, None, struct_p, size)
if res != errors.NO_ERROR:
raise RuntimeError("Error retrieving table (%d)" % res)
while struct_p:
yield struct_p.contents
struct_p = struct_p.contents.next
class AllocatedTable(object):
"""
Both the interface table and the ip address table use the same
@ -79,20 +80,23 @@ class AllocatedTable(object):
on the table size.
"""
table = self.get_table()
entries_array = self.row_structure*table.num_entries
entries_array = self.row_structure * table.num_entries
pointer_type = ctypes.POINTER(entries_array)
return ctypes.cast(table.entries, pointer_type).contents
class InterfaceTable(AllocatedTable):
method = inet.GetIfTable
structure = inet.MIB_IFTABLE
row_structure = inet.MIB_IFROW
class AddressTable(AllocatedTable):
method = inet.GetIpAddrTable
structure = inet.MIB_IPADDRTABLE
row_structure = inet.MIB_IPADDRROW
class AddressManager(object):
@staticmethod
def hardware_address_to_string(addr):
@ -100,7 +104,8 @@ class AddressManager(object):
return ':'.join(hex_bytes)
def get_host_mac_address_strings(self):
return (self.hardware_address_to_string(addr)
return (
self.hardware_address_to_string(addr)
for addr in self.get_host_mac_addresses())
def get_host_ip_address_strings(self):
@ -110,10 +115,10 @@ class AddressManager(object):
return (
entry.physical_address
for entry in InterfaceTable().entries
)
)
def get_host_ip_addresses(self):
return (
entry.address
for entry in AddressTable().entries
)
)

View file

@ -2,6 +2,7 @@ import ctypes
from .api import library
def find_lib(lib):
r"""
Find the DLL for a given library.

View file

@ -3,6 +3,7 @@ from ctypes import WinError
from .api import memory
class LockedMemory(object):
def __init__(self, handle):
self.handle = handle

View file

@ -5,6 +5,7 @@ import six
from .error import handle_nonzero_success
from .api import memory
class MemoryMap(object):
"""
A memory map object which can have security attributes overridden.

View file

@ -10,13 +10,15 @@ import itertools
import six
class CookieMonster(object):
"Read cookies out of a user's IE cookies file"
@property
def cookie_dir(self):
import _winreg as winreg
key = winreg.OpenKeyEx(winreg.HKEY_CURRENT_USER, 'Software'
key = winreg.OpenKeyEx(
winreg.HKEY_CURRENT_USER, 'Software'
'\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders')
cookie_dir, type = winreg.QueryValueEx(key, 'Cookies')
return cookie_dir
@ -24,10 +26,12 @@ class CookieMonster(object):
def entries(self, filename):
with open(os.path.join(self.cookie_dir, filename)) as cookie_file:
while True:
entry = itertools.takewhile(self.is_not_cookie_delimiter,
entry = itertools.takewhile(
self.is_not_cookie_delimiter,
cookie_file)
entry = list(map(six.text_type.rstrip, entry))
if not entry: break
if not entry:
break
cookie = self.make_cookie(*entry)
yield cookie
@ -36,8 +40,9 @@ class CookieMonster(object):
return s != '*\n'
@staticmethod
def make_cookie(key, value, domain, flags, ExpireLow, ExpireHigh,
CreateLow, CreateHigh):
def make_cookie(
key, value, domain, flags, ExpireLow, ExpireHigh,
CreateLow, CreateHigh):
expires = (int(ExpireHigh) << 32) | int(ExpireLow)
created = (int(CreateHigh) << 32) | int(CreateLow)
flags = int(flags)
@ -51,4 +56,4 @@ class CookieMonster(object):
expires=expires,
created=created,
path=path,
)
)

View file

@ -7,8 +7,10 @@ __all__ = ('AddConnection')
from jaraco.windows.error import WindowsError
from .api import net
def AddConnection(remote_name, type=net.RESOURCETYPE_ANY, local_name=None,
provider_name=None, user=None, password=None, flags=0):
def AddConnection(
remote_name, type=net.RESOURCETYPE_ANY, local_name=None,
provider_name=None, user=None, password=None, flags=0):
resource = net.NETRESOURCE(
type=type,
remote_name=remote_name,

View file

@ -1,10 +1,9 @@
#-*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
from __future__ import print_function
import itertools
import contextlib
import ctypes
from more_itertools.recipes import consume, unique_justseen
try:
@ -15,11 +14,13 @@ except ImportError:
from jaraco.windows.error import handle_nonzero_success
from .api import power
def GetSystemPowerStatus():
stat = power.SYSTEM_POWER_STATUS()
handle_nonzero_success(GetSystemPowerStatus(stat))
return stat
def _init_power_watcher():
global power_watcher
if 'power_watcher' not in globals():
@ -27,18 +28,22 @@ def _init_power_watcher():
query = 'SELECT * from Win32_PowerManagementEvent'
power_watcher = wmi.ExecNotificationQuery(query)
def get_power_management_events():
_init_power_watcher()
while True:
yield power_watcher.NextEvent()
def wait_for_power_status_change():
EVT_POWER_STATUS_CHANGE = 10
def not_power_status_change(evt):
return evt.EventType != EVT_POWER_STATUS_CHANGE
events = get_power_management_events()
consume(itertools.takewhile(not_power_status_change, events))
def get_unique_power_states():
"""
Just like get_power_states, but ensures values are returned only
@ -46,6 +51,7 @@ def get_unique_power_states():
"""
return unique_justseen(get_power_states())
def get_power_states():
"""
Continuously return the power state of the system when it changes.
@ -57,6 +63,7 @@ def get_power_states():
yield state.ac_line_status_string
wait_for_power_status_change()
@contextlib.contextmanager
def no_sleep():
"""

View file

@ -7,24 +7,31 @@ from .api import security
from .api import privilege
from .api import process
def get_process_token():
"""
Get the current process token
"""
token = wintypes.HANDLE()
res = process.OpenProcessToken(process.GetCurrentProcess(), process.TOKEN_ALL_ACCESS, token)
if not res > 0: raise RuntimeError("Couldn't get process token")
res = process.OpenProcessToken(
process.GetCurrentProcess(), process.TOKEN_ALL_ACCESS, token)
if not res > 0:
raise RuntimeError("Couldn't get process token")
return token
def get_symlink_luid():
"""
Get the LUID for the SeCreateSymbolicLinkPrivilege
"""
symlink_luid = privilege.LUID()
res = privilege.LookupPrivilegeValue(None, "SeCreateSymbolicLinkPrivilege", symlink_luid)
if not res > 0: raise RuntimeError("Couldn't lookup privilege value")
res = privilege.LookupPrivilegeValue(
None, "SeCreateSymbolicLinkPrivilege", symlink_luid)
if not res > 0:
raise RuntimeError("Couldn't lookup privilege value")
return symlink_luid
def get_privilege_information():
"""
Get all privileges associated with the current process.
@ -38,7 +45,7 @@ def get_privilege_information():
None,
0,
return_length,
]
]
res = privilege.GetTokenInformation(*params)
@ -51,9 +58,11 @@ def get_privilege_information():
res = privilege.GetTokenInformation(*params)
assert res > 0, "Error in second GetTokenInformation (%d)" % res
privileges = ctypes.cast(buffer, ctypes.POINTER(privilege.TOKEN_PRIVILEGES)).contents
privileges = ctypes.cast(
buffer, ctypes.POINTER(privilege.TOKEN_PRIVILEGES)).contents
return privileges
def report_privilege_information():
"""
Report all privilege information assigned to the current process.
@ -62,6 +71,7 @@ def report_privilege_information():
print("found {0} privileges".format(privileges.count))
tuple(map(print, privileges))
def enable_symlink_privilege():
"""
Try to assign the symlink privilege to the current process token.
@ -84,9 +94,11 @@ def enable_symlink_privilege():
ERROR_NOT_ALL_ASSIGNED = 1300
return ctypes.windll.kernel32.GetLastError() != ERROR_NOT_ALL_ASSIGNED
class PolicyHandle(wintypes.HANDLE):
pass
class LSA_UNICODE_STRING(ctypes.Structure):
_fields_ = [
('length', ctypes.c_ushort),
@ -94,15 +106,20 @@ class LSA_UNICODE_STRING(ctypes.Structure):
('buffer', ctypes.wintypes.LPWSTR),
]
def OpenPolicy(system_name, object_attributes, access_mask):
policy = PolicyHandle()
raise NotImplementedError("Need to construct structures for parameters "
"(see http://msdn.microsoft.com/en-us/library/windows/desktop/aa378299%28v=vs.85%29.aspx)")
res = ctypes.windll.advapi32.LsaOpenPolicy(system_name, object_attributes,
raise NotImplementedError(
"Need to construct structures for parameters "
"(see http://msdn.microsoft.com/en-us/library/windows"
"/desktop/aa378299%28v=vs.85%29.aspx)")
res = ctypes.windll.advapi32.LsaOpenPolicy(
system_name, object_attributes,
access_mask, ctypes.byref(policy))
assert res == 0, "Error status {res}".format(**vars())
return policy
def grant_symlink_privilege(who, machine=''):
"""
Grant the 'create symlink' privilege to who.
@ -113,10 +130,13 @@ def grant_symlink_privilege(who, machine=''):
policy = OpenPolicy(machine, flags)
return policy
def main():
assigned = enable_symlink_privilege()
msg = ['failure', 'success'][assigned]
print("Symlink privilege assignment completed with {0}".format(msg))
if __name__ == '__main__': main()
if __name__ == '__main__':
main()

View file

@ -3,6 +3,7 @@ from itertools import count
import six
winreg = six.moves.winreg
def key_values(key):
for index in count():
try:
@ -10,6 +11,7 @@ def key_values(key):
except WindowsError:
break
def key_subkeys(key):
for index in count():
try:

View file

@ -5,7 +5,9 @@ import ctypes.wintypes
from .error import handle_nonzero_success
from .api import filesystem
def DeviceIoControl(device, io_control_code, in_buffer, out_buffer, overlapped=None):
def DeviceIoControl(
device, io_control_code, in_buffer, out_buffer, overlapped=None):
if overlapped is not None:
raise NotImplementedError("overlapped handles not yet supported")
@ -25,7 +27,7 @@ def DeviceIoControl(device, io_control_code, in_buffer, out_buffer, overlapped=N
out_buffer, out_buffer_size,
returned_bytes,
overlapped,
)
)
handle_nonzero_success(res)
handle_nonzero_success(returned_bytes)

View file

@ -3,20 +3,24 @@ import ctypes.wintypes
from jaraco.windows.error import handle_nonzero_success
from .api import security
def GetTokenInformation(token, information_class):
"""
Given a token, get the token information for it.
"""
data_size = ctypes.wintypes.DWORD()
ctypes.windll.advapi32.GetTokenInformation(token, information_class.num,
ctypes.windll.advapi32.GetTokenInformation(
token, information_class.num,
0, 0, ctypes.byref(data_size))
data = ctypes.create_string_buffer(data_size.value)
handle_nonzero_success(ctypes.windll.advapi32.GetTokenInformation(token,
handle_nonzero_success(ctypes.windll.advapi32.GetTokenInformation(
token,
information_class.num,
ctypes.byref(data), ctypes.sizeof(data),
ctypes.byref(data_size)))
return ctypes.cast(data, ctypes.POINTER(security.TOKEN_USER)).contents
def OpenProcessToken(proc_handle, access):
result = ctypes.wintypes.HANDLE()
proc_handle = ctypes.wintypes.HANDLE(proc_handle)
@ -24,6 +28,7 @@ def OpenProcessToken(proc_handle, access):
proc_handle, access, ctypes.byref(result)))
return result
def get_current_user():
"""
Return a TOKEN_USER for the owner of this process.
@ -34,6 +39,7 @@ def get_current_user():
)
return GetTokenInformation(process, security.TOKEN_USER)
def get_security_attributes_for_user(user=None):
"""
Return a SECURITY_ATTRIBUTES structure with the SID set to the
@ -42,7 +48,8 @@ def get_security_attributes_for_user(user=None):
if user is None:
user = get_current_user()
assert isinstance(user, security.TOKEN_USER), "user must be TOKEN_USER instance"
assert isinstance(user, security.TOKEN_USER), (
"user must be TOKEN_USER instance")
SD = security.SECURITY_DESCRIPTOR()
SA = security.SECURITY_ATTRIBUTES()
@ -51,8 +58,10 @@ def get_security_attributes_for_user(user=None):
SA.descriptor = SD
SA.bInheritHandle = 1
ctypes.windll.advapi32.InitializeSecurityDescriptor(ctypes.byref(SD),
ctypes.windll.advapi32.InitializeSecurityDescriptor(
ctypes.byref(SD),
security.SECURITY_DESCRIPTOR.REVISION)
ctypes.windll.advapi32.SetSecurityDescriptorOwner(ctypes.byref(SD),
ctypes.windll.advapi32.SetSecurityDescriptorOwner(
ctypes.byref(SD),
user.SID, 0)
return SA

View file

@ -1,7 +1,8 @@
"""
Windows Services support for controlling Windows Services.
Based on http://code.activestate.com/recipes/115875-controlling-windows-services/
Based on http://code.activestate.com
/recipes/115875-controlling-windows-services/
"""
from __future__ import print_function
@ -13,6 +14,7 @@ import win32api
import win32con
import win32service
class Service(object):
"""
The Service Class is used for controlling Windows
@ -47,7 +49,8 @@ class Service(object):
pause: Pauses service (Only if service supports feature).
resume: Resumes service that has been paused.
status: Queries current status of service.
fetchstatus: Continually queries service until requested status(STARTING, RUNNING,
fetchstatus: Continually queries service until requested
status(STARTING, RUNNING,
STOPPING & STOPPED) is met or timeout value(in seconds) reached.
Default timeout value is infinite.
infotype: Queries service for process type. (Single, shared and/or
@ -64,18 +67,21 @@ class Service(object):
def __init__(self, service, machinename=None, dbname=None):
self.userv = service
self.scmhandle = win32service.OpenSCManager(machinename, dbname, win32service.SC_MANAGER_ALL_ACCESS)
self.scmhandle = win32service.OpenSCManager(
machinename, dbname, win32service.SC_MANAGER_ALL_ACCESS)
self.sserv, self.lserv = self.getname()
if (self.sserv or self.lserv) is None:
sys.exit()
self.handle = win32service.OpenService(self.scmhandle, self.sserv, win32service.SERVICE_ALL_ACCESS)
self.handle = win32service.OpenService(
self.scmhandle, self.sserv, win32service.SERVICE_ALL_ACCESS)
self.sccss = "SYSTEM\\CurrentControlSet\\Services\\"
def start(self):
win32service.StartService(self.handle, None)
def stop(self):
self.stat = win32service.ControlService(self.handle, win32service.SERVICE_CONTROL_STOP)
self.stat = win32service.ControlService(
self.handle, win32service.SERVICE_CONTROL_STOP)
def restart(self):
self.stop()
@ -83,29 +89,31 @@ class Service(object):
self.start()
def pause(self):
self.stat = win32service.ControlService(self.handle, win32service.SERVICE_CONTROL_PAUSE)
self.stat = win32service.ControlService(
self.handle, win32service.SERVICE_CONTROL_PAUSE)
def resume(self):
self.stat = win32service.ControlService(self.handle, win32service.SERVICE_CONTROL_CONTINUE)
self.stat = win32service.ControlService(
self.handle, win32service.SERVICE_CONTROL_CONTINUE)
def status(self, prn = 0):
def status(self, prn=0):
self.stat = win32service.QueryServiceStatus(self.handle)
if self.stat[1]==win32service.SERVICE_STOPPED:
if self.stat[1] == win32service.SERVICE_STOPPED:
if prn == 1:
print("The %s service is stopped." % self.lserv)
else:
return "STOPPED"
elif self.stat[1]==win32service.SERVICE_START_PENDING:
elif self.stat[1] == win32service.SERVICE_START_PENDING:
if prn == 1:
print("The %s service is starting." % self.lserv)
else:
return "STARTING"
elif self.stat[1]==win32service.SERVICE_STOP_PENDING:
elif self.stat[1] == win32service.SERVICE_STOP_PENDING:
if prn == 1:
print("The %s service is stopping." % self.lserv)
else:
return "STOPPING"
elif self.stat[1]==win32service.SERVICE_RUNNING:
elif self.stat[1] == win32service.SERVICE_RUNNING:
if prn == 1:
print("The %s service is running." % self.lserv)
else:
@ -116,6 +124,7 @@ class Service(object):
if timeout is not None:
timeout = int(timeout)
timeout *= 2
def to(timeout):
time.sleep(.5)
if timeout is not None:
@ -127,11 +136,11 @@ class Service(object):
if self.fstatus == "STOPPED":
while 1:
self.stat = win32service.QueryServiceStatus(self.handle)
if self.stat[1]==win32service.SERVICE_STOPPED:
if self.stat[1] == win32service.SERVICE_STOPPED:
self.fstate = "STOPPED"
break
else:
timeout=to(timeout)
timeout = to(timeout)
if timeout == "TO":
return "TIMEDOUT"
break

View file

@ -1,10 +1,12 @@
from .api import shell
def get_recycle_bin_confirm():
settings = shell.SHELLSTATE()
shell.SHGetSetSettings(settings, shell.SSF_NOCONFIRMRECYCLE, False)
return not settings.no_confirm_recycle
def set_recycle_bin_confirm(confirm=False):
settings = shell.SHELLSTATE()
settings.no_confirm_recycle = not confirm

View file

@ -14,6 +14,7 @@ from jaraco.windows.api import event as win32event
__author__ = 'Jason R. Coombs <jaraco@jaraco.com>'
class WaitableTimer:
"""
t = WaitableTimer()
@ -32,12 +33,12 @@ class WaitableTimer:
def stop(self):
win32event.SetEvent(self.stop_event)
def wait_for_signal(self, timeout = None):
def wait_for_signal(self, timeout=None):
"""
wait for the signal; return after the signal has occurred or the
timeout in seconds elapses.
"""
timeout_ms = int(timeout*1000) if timeout else win32event.INFINITE
timeout_ms = int(timeout * 1000) if timeout else win32event.INFINITE
win32event.WaitForSingleObject(self.signal_event, timeout_ms)
def _signal_loop(self, due_time, period):
@ -54,14 +55,14 @@ class WaitableTimer:
except Exception:
pass
#we're done here, just quit
def _wait(self, seconds):
milliseconds = int(seconds*1000)
milliseconds = int(seconds * 1000)
if milliseconds > 0:
res = win32event.WaitForSingleObject(self.stop_event, milliseconds)
if res == win32event.WAIT_OBJECT_0: raise Exception
if res == win32event.WAIT_TIMEOUT: pass
if res == win32event.WAIT_OBJECT_0:
raise Exception
if res == win32event.WAIT_TIMEOUT:
pass
win32event.SetEvent(self.signal_event)
@staticmethod

View file

@ -8,6 +8,7 @@ from ctypes.wintypes import WORD, WCHAR, BOOL, LONG
from jaraco.windows.util import Extended
from jaraco.collections import RangeMap
class AnyDict(object):
"A dictionary that returns the same value regardless of key"
@ -17,6 +18,7 @@ class AnyDict(object):
def __getitem__(self, key):
return self.value
class SYSTEMTIME(Extended, ctypes.Structure):
_fields_ = [
('year', WORD),
@ -29,6 +31,7 @@ class SYSTEMTIME(Extended, ctypes.Structure):
('millisecond', WORD),
]
class REG_TZI_FORMAT(Extended, ctypes.Structure):
_fields_ = [
('bias', LONG),
@ -38,17 +41,19 @@ class REG_TZI_FORMAT(Extended, ctypes.Structure):
('daylight_start', SYSTEMTIME),
]
class TIME_ZONE_INFORMATION(Extended, ctypes.Structure):
_fields_ = [
('bias', LONG),
('standard_name', WCHAR*32),
('standard_name', WCHAR * 32),
('standard_start', SYSTEMTIME),
('standard_bias', LONG),
('daylight_name', WCHAR*32),
('daylight_name', WCHAR * 32),
('daylight_start', SYSTEMTIME),
('daylight_bias', LONG),
]
class DYNAMIC_TIME_ZONE_INFORMATION(TIME_ZONE_INFORMATION):
"""
Because the structure of the DYNAMIC_TIME_ZONE_INFORMATION extends
@ -70,7 +75,7 @@ class DYNAMIC_TIME_ZONE_INFORMATION(TIME_ZONE_INFORMATION):
"""
_fields_ = [
# ctypes automatically includes the fields from the parent
('key_name', WCHAR*128),
('key_name', WCHAR * 128),
('dynamic_daylight_time_disabled', BOOL),
]
@ -89,6 +94,7 @@ class DYNAMIC_TIME_ZONE_INFORMATION(TIME_ZONE_INFORMATION):
kwargs[field_name] = arg
super(DYNAMIC_TIME_ZONE_INFORMATION, self).__init__(*self_args, **kwargs)
class Info(DYNAMIC_TIME_ZONE_INFORMATION):
"""
A time zone definition class based on the win32
@ -114,7 +120,7 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION):
self.__init_from_other,
self.__init_from_reg_tzi,
self.__init_from_bytes,
)
)
for func in funcs:
try:
func(*args, **kwargs)
@ -126,7 +132,7 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION):
def __init_from_bytes(self, bytes, **kwargs):
reg_tzi = REG_TZI_FORMAT()
# todo: use buffer API in Python 3
buffer = buffer(bytes)
buffer = memoryview(bytes)
ctypes.memmove(ctypes.addressof(reg_tzi), buffer, len(buffer))
self.__init_from_reg_tzi(self, reg_tzi, **kwargs)
@ -146,12 +152,14 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION):
value = super(Info, other).__getattribute__(other, name)
setattr(self, name, value)
# consider instead of the loop above just copying the memory directly
#size = max(ctypes.sizeof(DYNAMIC_TIME_ZONE_INFO), ctypes.sizeof(other))
#ctypes.memmove(ctypes.addressof(self), other, size)
# size = max(ctypes.sizeof(DYNAMIC_TIME_ZONE_INFO), ctypes.sizeof(other))
# ctypes.memmove(ctypes.addressof(self), other, size)
def __getattribute__(self, attr):
value = super(Info, self).__getattribute__(attr)
make_minute_timedelta = lambda m: datetime.timedelta(minutes = m)
def make_minute_timedelta(m):
datetime.timedelta(minutes=m)
if 'bias' in attr:
value = make_minute_timedelta(value)
return value
@ -205,10 +213,12 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION):
def _locate_day(year, cutoff):
"""
Takes a SYSTEMTIME object, such as retrieved from a TIME_ZONE_INFORMATION
structure or call to GetTimeZoneInformation and interprets it based on the given
structure or call to GetTimeZoneInformation and interprets
it based on the given
year to identify the actual day.
This method is necessary because the SYSTEMTIME structure refers to a day by its
This method is necessary because the SYSTEMTIME structure
refers to a day by its
day of the week and week of the month (e.g. 4th saturday in March).
>>> SATURDAY = 6
@ -227,9 +237,11 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION):
week_of_month = cutoff.day
# so the following is the first day of that week
day = (week_of_month - 1) * 7 + 1
result = datetime.datetime(year, cutoff.month, day,
result = datetime.datetime(
year, cutoff.month, day,
cutoff.hour, cutoff.minute, cutoff.second, cutoff.millisecond)
# now the result is the correct week, but not necessarily the correct day of the week
# now the result is the correct week, but not necessarily
# the correct day of the week
days_to_go = (target_weekday - result.weekday()) % 7
result += datetime.timedelta(days_to_go)
# if we selected a day in the month following the target month,
@ -238,5 +250,5 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION):
# to be the last week in a month and adding the time delta might have
# pushed the result into the next month.
while result.month == cutoff.month + 1:
result -= datetime.timedelta(weeks = 1)
result -= datetime.timedelta(weeks=1)
return result

View file

@ -3,6 +3,7 @@
import ctypes
from jaraco.windows.util import ensure_unicode
def MessageBox(text, caption=None, handle=None, type=None):
text, caption = map(ensure_unicode, (text, caption))
ctypes.windll.user32.MessageBoxW(handle, text, caption, type)

View file

@ -3,6 +3,7 @@ from .api import errors
from .api.user import GetUserName
from .error import WindowsError, handle_nonzero_success
def get_user_name():
size = ctypes.wintypes.DWORD()
try:

View file

@ -2,17 +2,19 @@
import ctypes
def ensure_unicode(param):
try:
param = ctypes.create_unicode_buffer(param)
except TypeError:
pass # just return the param as is
pass # just return the param as is
return param
class Extended(object):
"Used to add extended capability to structures"
def __eq__(self, other):
return buffer(self) == buffer(other)
return memoryview(self) == memoryview(other)
def __ne__(self, other):
return buffer(self) != buffer(other)
return memoryview(self) != memoryview(other)

View file

@ -1,11 +1,14 @@
import os
from path import path
from path import Path
def install_pptp(name, param_lines):
"""
"""
# or consider using the API: http://msdn.microsoft.com/en-us/library/aa446739%28v=VS.85%29.aspx
pbk_path = (path(os.environ['PROGRAMDATA'])
# or consider using the API:
# http://msdn.microsoft.com/en-us/library/aa446739%28v=VS.85%29.aspx
pbk_path = (
Path(os.environ['PROGRAMDATA'])
/ 'Microsoft' / 'Network' / 'Connections' / 'pbk' / 'rasphone.pbk')
pbk_path.dirname().makedirs_p()
with open(pbk_path, 'a') as pbk:

View file

@ -17,6 +17,7 @@ def set(value):
)
handle_nonzero_success(result)
def get():
value = ctypes.wintypes.BOOL()
result = system.SystemParametersInfo(
@ -28,6 +29,7 @@ def get():
handle_nonzero_success(result)
return bool(value)
def set_delay(milliseconds):
result = system.SystemParametersInfo(
system.SPI_SETACTIVEWNDTRKTIMEOUT,
@ -37,6 +39,7 @@ def set_delay(milliseconds):
)
handle_nonzero_success(result)
def get_delay():
value = ctypes.wintypes.DWORD()
result = system.SystemParametersInfo(

View file

@ -1,2 +1,2 @@
from more_itertools.more import *
from more_itertools.recipes import *
from more_itertools.more import * # noqa
from more_itertools.recipes import * # noqa

File diff suppressed because it is too large Load diff

View file

@ -8,21 +8,79 @@ Some backward-compatible usability improvements have been made.
"""
from collections import deque
from itertools import chain, combinations, count, cycle, groupby, ifilterfalse, imap, islice, izip, izip_longest, repeat, starmap, tee # Wrapping breaks 2to3.
from itertools import (
chain, combinations, count, cycle, groupby, islice, repeat, starmap, tee
)
import operator
from random import randrange, sample, choice
from six import PY2
from six.moves import filter, filterfalse, map, range, zip, zip_longest
__all__ = ['take', 'tabulate', 'consume', 'nth', 'quantify', 'padnone',
'ncycles', 'dotproduct', 'flatten', 'repeatfunc', 'pairwise',
'grouper', 'roundrobin', 'powerset', 'unique_everseen',
'unique_justseen', 'iter_except', 'random_product',
'random_permutation', 'random_combination',
'random_combination_with_replacement']
__all__ = [
'accumulate',
'all_equal',
'consume',
'dotproduct',
'first_true',
'flatten',
'grouper',
'iter_except',
'ncycles',
'nth',
'nth_combination',
'padnone',
'pairwise',
'partition',
'powerset',
'prepend',
'quantify',
'random_combination_with_replacement',
'random_combination',
'random_permutation',
'random_product',
'repeatfunc',
'roundrobin',
'tabulate',
'tail',
'take',
'unique_everseen',
'unique_justseen',
]
def accumulate(iterable, func=operator.add):
"""
Return an iterator whose items are the accumulated results of a function
(specified by the optional *func* argument) that takes two arguments.
By default, returns accumulated sums with :func:`operator.add`.
>>> list(accumulate([1, 2, 3, 4, 5])) # Running sum
[1, 3, 6, 10, 15]
>>> list(accumulate([1, 2, 3], func=operator.mul)) # Running product
[1, 2, 6]
>>> list(accumulate([0, 1, -1, 2, 3, 2], func=max)) # Running maximum
[0, 1, 1, 2, 3, 3]
This function is available in the ``itertools`` module for Python 3.2 and
greater.
"""
it = iter(iterable)
try:
total = next(it)
except StopIteration:
return
else:
yield total
for element in it:
total = func(total, element)
yield total
def take(n, iterable):
"""Return first n items of the iterable as a list
"""Return first *n* items of the iterable as a list.
>>> take(3, range(10))
[0, 1, 2]
@ -37,21 +95,37 @@ def take(n, iterable):
def tabulate(function, start=0):
"""Return an iterator mapping the function over linear input.
"""Return an iterator over the results of ``func(start)``,
``func(start + 1)``, ``func(start + 2)``...
The start argument will be increased by 1 each time the iterator is called
and fed into the function.
*func* should be a function that accepts one integer argument.
>>> t = tabulate(lambda x: x**2, -3)
>>> take(3, t)
[9, 4, 1]
If *start* is not specified it defaults to 0. It will be incremented each
time the iterator is advanced.
>>> square = lambda x: x ** 2
>>> iterator = tabulate(square, -3)
>>> take(4, iterator)
[9, 4, 1, 0]
"""
return imap(function, count(start))
return map(function, count(start))
def tail(n, iterable):
"""Return an iterator over the last *n* items of *iterable*.
>>> t = tail(3, 'ABCDEFG')
>>> list(t)
['E', 'F', 'G']
"""
return iter(deque(iterable, maxlen=n))
def consume(iterator, n=None):
"""Advance the iterator n-steps ahead. If n is none, consume entirely.
"""Advance *iterable* by *n* steps. If *n* is ``None``, consume it
entirely.
Efficiently exhausts an iterator without returning values. Defaults to
consuming the whole iterator, but an optional second argument may be
@ -90,7 +164,7 @@ def consume(iterator, n=None):
def nth(iterable, n, default=None):
"""Returns the nth item or a default value
"""Returns the nth item or a default value.
>>> l = range(10)
>>> nth(l, 3)
@ -102,30 +176,46 @@ def nth(iterable, n, default=None):
return next(islice(iterable, n, None), default)
def all_equal(iterable):
"""
Returns ``True`` if all the elements are equal to each other.
>>> all_equal('aaaa')
True
>>> all_equal('aaab')
False
"""
g = groupby(iterable)
return next(g, True) and not next(g, False)
def quantify(iterable, pred=bool):
"""Return the how many times the predicate is true
"""Return the how many times the predicate is true.
>>> quantify([True, False, True])
2
"""
return sum(imap(pred, iterable))
return sum(map(pred, iterable))
def padnone(iterable):
"""Returns the sequence of elements and then returns None indefinitely.
"""Returns the sequence of elements and then returns ``None`` indefinitely.
>>> take(5, padnone(range(3)))
[0, 1, 2, None, None]
Useful for emulating the behavior of the built-in map() function.
Useful for emulating the behavior of the built-in :func:`map` function.
See also :func:`padded`.
"""
return chain(iterable, repeat(None))
def ncycles(iterable, n):
"""Returns the sequence elements n times
"""Returns the sequence elements *n* times
>>> list(ncycles(["a", "b"], 3))
['a', 'b', 'a', 'b', 'a', 'b']
@ -135,32 +225,47 @@ def ncycles(iterable, n):
def dotproduct(vec1, vec2):
"""Returns the dot product of the two iterables
"""Returns the dot product of the two iterables.
>>> dotproduct([10, 10], [20, 20])
400
"""
return sum(imap(operator.mul, vec1, vec2))
return sum(map(operator.mul, vec1, vec2))
def flatten(listOfLists):
"""Return an iterator flattening one level of nesting in a list of lists
"""Return an iterator flattening one level of nesting in a list of lists.
>>> list(flatten([[0, 1], [2, 3]]))
[0, 1, 2, 3]
See also :func:`collapse`, which can flatten multiple levels of nesting.
"""
return chain.from_iterable(listOfLists)
def repeatfunc(func, times=None, *args):
"""Repeat calls to func with specified arguments.
"""Call *func* with *args* repeatedly, returning an iterable over the
results.
>>> list(repeatfunc(lambda: 5, 3))
[5, 5, 5]
>>> list(repeatfunc(lambda x: x ** 2, 3, 3))
[9, 9, 9]
If *times* is specified, the iterable will terminate after that many
repetitions:
>>> from operator import add
>>> times = 4
>>> args = 3, 5
>>> list(repeatfunc(add, times, *args))
[8, 8, 8, 8]
If *times* is ``None`` the iterable will not terminate:
>>> from random import randrange
>>> times = None
>>> args = 1, 11
>>> take(6, repeatfunc(randrange, times, *args)) # doctest:+SKIP
[2, 4, 8, 1, 8, 4]
"""
if times is None:
@ -177,30 +282,37 @@ def pairwise(iterable):
"""
a, b = tee(iterable)
next(b, None)
return izip(a, b)
return zip(a, b)
def grouper(n, iterable, fillvalue=None):
"""Collect data into fixed-length chunks or blocks
"""Collect data into fixed-length chunks or blocks.
>>> list(grouper(3, 'ABCDEFG', 'x'))
[('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
"""
args = [iter(iterable)] * n
return izip_longest(fillvalue=fillvalue, *args)
return zip_longest(fillvalue=fillvalue, *args)
def roundrobin(*iterables):
"""Yields an item from each iterable, alternating between them
"""Yields an item from each iterable, alternating between them.
>>> list(roundrobin('ABC', 'D', 'EF'))
['A', 'D', 'E', 'B', 'F', 'C']
This function produces the same output as :func:`interleave_longest`, but
may perform better for some inputs (in particular when the number of
iterables is small).
"""
# Recipe credited to George Sakkis
pending = len(iterables)
nexts = cycle(iter(it).next for it in iterables)
if PY2:
nexts = cycle(iter(it).next for it in iterables)
else:
nexts = cycle(iter(it).__next__ for it in iterables)
while pending:
try:
for next in nexts:
@ -210,38 +322,73 @@ def roundrobin(*iterables):
nexts = cycle(islice(nexts, pending))
def partition(pred, iterable):
"""
Returns a 2-tuple of iterables derived from the input iterable.
The first yields the items that have ``pred(item) == False``.
The second yields the items that have ``pred(item) == True``.
>>> is_odd = lambda x: x % 2 != 0
>>> iterable = range(10)
>>> even_items, odd_items = partition(is_odd, iterable)
>>> list(even_items), list(odd_items)
([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])
"""
# partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9
t1, t2 = tee(iterable)
return filterfalse(pred, t1), filter(pred, t2)
def powerset(iterable):
"""Yields all possible subsets of the iterable
"""Yields all possible subsets of the iterable.
>>> list(powerset([1,2,3]))
[(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
"""
s = list(iterable)
return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
def unique_everseen(iterable, key=None):
"""Yield unique elements, preserving order.
"""
Yield unique elements, preserving order.
>>> list(unique_everseen('AAAABBBCCDAABBB'))
['A', 'B', 'C', 'D']
>>> list(unique_everseen('ABBCcAD', str.lower))
['A', 'B', 'C', 'D']
Sequences with a mix of hashable and unhashable items can be used.
The function will be slower (i.e., `O(n^2)`) for unhashable items.
"""
seen = set()
seen_add = seen.add
seenset = set()
seenset_add = seenset.add
seenlist = []
seenlist_add = seenlist.append
if key is None:
for element in ifilterfalse(seen.__contains__, iterable):
seen_add(element)
yield element
for element in iterable:
try:
if element not in seenset:
seenset_add(element)
yield element
except TypeError:
if element not in seenlist:
seenlist_add(element)
yield element
else:
for element in iterable:
k = key(element)
if k not in seen:
seen_add(k)
yield element
try:
if k not in seenset:
seenset_add(k)
yield element
except TypeError:
if k not in seenlist:
seenlist_add(k)
yield element
def unique_justseen(iterable, key=None):
@ -253,17 +400,17 @@ def unique_justseen(iterable, key=None):
['A', 'B', 'C', 'A', 'D']
"""
return imap(next, imap(operator.itemgetter(1), groupby(iterable, key)))
return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
def iter_except(func, exception, first=None):
"""Yields results from a function repeatedly until an exception is raised.
Converts a call-until-exception interface to an iterator interface.
Like __builtin__.iter(func, sentinel) but uses an exception instead
of a sentinel to end the loop.
Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel
to end the loop.
>>> l = range(3)
>>> l = [0, 1, 2]
>>> list(iter_except(l.pop, IndexError))
[2, 1, 0]
@ -277,28 +424,58 @@ def iter_except(func, exception, first=None):
pass
def random_product(*args, **kwds):
"""Returns a random pairing of items from each iterable argument
def first_true(iterable, default=False, pred=None):
"""
Returns the first true value in the iterable.
If `repeat` is provided as a kwarg, it's value will be used to indicate
how many pairings should be chosen.
If no true value is found, returns *default*
>>> random_product(['a', 'b', 'c'], [1, 2], repeat=2) # doctest:+SKIP
('b', '2', 'c', '2')
If *pred* is not None, returns the first item for which
``pred(item) == True`` .
>>> first_true(range(10))
1
>>> first_true(range(10), pred=lambda x: x > 5)
6
>>> first_true(range(10), default='missing', pred=lambda x: x > 9)
'missing'
"""
pools = map(tuple, args) * kwds.get('repeat', 1)
return next(filter(pred, iterable), default)
def random_product(*args, **kwds):
"""Draw an item at random from each of the input iterables.
>>> random_product('abc', range(4), 'XYZ') # doctest:+SKIP
('c', 3, 'Z')
If *repeat* is provided as a keyword argument, that many items will be
drawn from each iterable.
>>> random_product('abcd', range(4), repeat=2) # doctest:+SKIP
('a', 2, 'd', 3)
This equivalent to taking a random selection from
``itertools.product(*args, **kwarg)``.
"""
pools = [tuple(pool) for pool in args] * kwds.get('repeat', 1)
return tuple(choice(pool) for pool in pools)
def random_permutation(iterable, r=None):
"""Returns a random permutation.
"""Return a random *r* length permutation of the elements in *iterable*.
If r is provided, the permutation is truncated to length r.
If *r* is not specified or is ``None``, then *r* defaults to the length of
*iterable*.
>>> random_permutation(range(5)) # doctest:+SKIP
>>> random_permutation(range(5)) # doctest:+SKIP
(3, 4, 0, 1, 2)
This equivalent to taking a random selection from
``itertools.permutations(iterable, r)``.
"""
pool = tuple(iterable)
r = len(pool) if r is None else r
@ -306,26 +483,83 @@ def random_permutation(iterable, r=None):
def random_combination(iterable, r):
"""Returns a random combination of length r, chosen without replacement.
"""Return a random *r* length subsequence of the elements in *iterable*.
>>> random_combination(range(5), 3) # doctest:+SKIP
>>> random_combination(range(5), 3) # doctest:+SKIP
(2, 3, 4)
This equivalent to taking a random selection from
``itertools.combinations(iterable, r)``.
"""
pool = tuple(iterable)
n = len(pool)
indices = sorted(sample(xrange(n), r))
indices = sorted(sample(range(n), r))
return tuple(pool[i] for i in indices)
def random_combination_with_replacement(iterable, r):
"""Returns a random combination of length r, chosen with replacement.
"""Return a random *r* length subsequence of elements in *iterable*,
allowing individual elements to be repeated.
>>> random_combination_with_replacement(range(3), 5) # # doctest:+SKIP
>>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP
(0, 0, 1, 2, 2)
This equivalent to taking a random selection from
``itertools.combinations_with_replacement(iterable, r)``.
"""
pool = tuple(iterable)
n = len(pool)
indices = sorted(randrange(n) for i in xrange(r))
indices = sorted(randrange(n) for i in range(r))
return tuple(pool[i] for i in indices)
def nth_combination(iterable, r, index):
"""Equivalent to ``list(combinations(iterable, r))[index]``.
The subsequences of *iterable* that are of length *r* can be ordered
lexicographically. :func:`nth_combination` computes the subsequence at
sort position *index* directly, without computing the previous
subsequences.
"""
pool = tuple(iterable)
n = len(pool)
if (r < 0) or (r > n):
raise ValueError
c = 1
k = min(r, n - r)
for i in range(1, k + 1):
c = c * (n - k + i) // i
if index < 0:
index += c
if (index < 0) or (index >= c):
raise IndexError
result = []
while r:
c, n, r = c * r // n, n - 1, r - 1
while index >= c:
index -= c
c, n = c * (n - r) // n, n - 1
result.append(pool[-1 - n])
return tuple(result)
def prepend(value, iterator):
"""Yield *value*, followed by the elements in *iterator*.
>>> value = '0'
>>> iterator = ['1', '2', '3']
>>> list(prepend(value, iterator))
['0', '1', '2', '3']
To prepend multiple values, see :func:`itertools.chain`.
"""
return chain([value], iterator)

File diff suppressed because it is too large Load diff

View file

@ -1,13 +1,39 @@
from random import seed
from doctest import DocTestSuite
from unittest import TestCase
from nose.tools import eq_, assert_raises, ok_
from itertools import combinations
from six.moves import range
from more_itertools import *
import more_itertools as mi
def setup_module():
seed(1337)
def load_tests(loader, tests, ignore):
# Add the doctests
tests.addTests(DocTestSuite('more_itertools.recipes'))
return tests
class AccumulateTests(TestCase):
"""Tests for ``accumulate()``"""
def test_empty(self):
"""Test that an empty input returns an empty output"""
self.assertEqual(list(mi.accumulate([])), [])
def test_default(self):
"""Test accumulate with the default function (addition)"""
self.assertEqual(list(mi.accumulate([1, 2, 3])), [1, 3, 6])
def test_bogus_function(self):
"""Test accumulate with an invalid function"""
with self.assertRaises(TypeError):
list(mi.accumulate([1, 2, 3], func=lambda x: x))
def test_custom_function(self):
"""Test accumulate with a custom function"""
self.assertEqual(
list(mi.accumulate((1, 2, 3, 2, 1), func=max)), [1, 2, 3, 3, 3]
)
class TakeTests(TestCase):
@ -15,25 +41,25 @@ class TakeTests(TestCase):
def test_simple_take(self):
"""Test basic usage"""
t = take(5, xrange(10))
eq_(t, [0, 1, 2, 3, 4])
t = mi.take(5, range(10))
self.assertEqual(t, [0, 1, 2, 3, 4])
def test_null_take(self):
"""Check the null case"""
t = take(0, xrange(10))
eq_(t, [])
t = mi.take(0, range(10))
self.assertEqual(t, [])
def test_negative_take(self):
"""Make sure taking negative items results in a ValueError"""
assert_raises(ValueError, take, -3, xrange(10))
self.assertRaises(ValueError, lambda: mi.take(-3, range(10)))
def test_take_too_much(self):
"""Taking more than an iterator has remaining should return what the
iterator has remaining.
"""
t = take(10, xrange(5))
eq_(t, [0, 1, 2, 3, 4])
t = mi.take(10, range(5))
self.assertEqual(t, [0, 1, 2, 3, 4])
class TabulateTests(TestCase):
@ -41,15 +67,35 @@ class TabulateTests(TestCase):
def test_simple_tabulate(self):
"""Test the happy path"""
t = tabulate(lambda x: x)
t = mi.tabulate(lambda x: x)
f = tuple([next(t) for _ in range(3)])
eq_(f, (0, 1, 2))
self.assertEqual(f, (0, 1, 2))
def test_count(self):
"""Ensure tabulate accepts specific count"""
t = tabulate(lambda x: 2 * x, -1)
t = mi.tabulate(lambda x: 2 * x, -1)
f = (next(t), next(t), next(t))
eq_(f, (-2, 0, 2))
self.assertEqual(f, (-2, 0, 2))
class TailTests(TestCase):
"""Tests for ``tail()``"""
def test_greater(self):
"""Length of iterable is greather than requested tail"""
self.assertEqual(list(mi.tail(3, 'ABCDEFG')), ['E', 'F', 'G'])
def test_equal(self):
"""Length of iterable is equal to the requested tail"""
self.assertEqual(
list(mi.tail(7, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G']
)
def test_less(self):
"""Length of iterable is less than requested tail"""
self.assertEqual(
list(mi.tail(8, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G']
)
class ConsumeTests(TestCase):
@ -58,25 +104,25 @@ class ConsumeTests(TestCase):
def test_sanity(self):
"""Test basic functionality"""
r = (x for x in range(10))
consume(r, 3)
eq_(3, next(r))
mi.consume(r, 3)
self.assertEqual(3, next(r))
def test_null_consume(self):
"""Check the null case"""
r = (x for x in range(10))
consume(r, 0)
eq_(0, next(r))
mi.consume(r, 0)
self.assertEqual(0, next(r))
def test_negative_consume(self):
"""Check that negative consumsion throws an error"""
r = (x for x in range(10))
assert_raises(ValueError, consume, r, -1)
self.assertRaises(ValueError, lambda: mi.consume(r, -1))
def test_total_consume(self):
"""Check that iterator is totally consumed by default"""
r = (x for x in range(10))
consume(r)
assert_raises(StopIteration, next, r)
mi.consume(r)
self.assertRaises(StopIteration, lambda: next(r))
class NthTests(TestCase):
@ -86,16 +132,45 @@ class NthTests(TestCase):
"""Make sure the nth item is returned"""
l = range(10)
for i, v in enumerate(l):
eq_(nth(l, i), v)
self.assertEqual(mi.nth(l, i), v)
def test_default(self):
"""Ensure a default value is returned when nth item not found"""
l = range(3)
eq_(nth(l, 100, "zebra"), "zebra")
self.assertEqual(mi.nth(l, 100, "zebra"), "zebra")
def test_negative_item_raises(self):
"""Ensure asking for a negative item raises an exception"""
assert_raises(ValueError, nth, range(10), -3)
self.assertRaises(ValueError, lambda: mi.nth(range(10), -3))
class AllEqualTests(TestCase):
"""Tests for ``all_equal()``"""
def test_true(self):
"""Everything is equal"""
self.assertTrue(mi.all_equal('aaaaaa'))
self.assertTrue(mi.all_equal([0, 0, 0, 0]))
def test_false(self):
"""Not everything is equal"""
self.assertFalse(mi.all_equal('aaaaab'))
self.assertFalse(mi.all_equal([0, 0, 0, 1]))
def test_tricky(self):
"""Not everything is identical, but everything is equal"""
items = [1, complex(1, 0), 1.0]
self.assertTrue(mi.all_equal(items))
def test_empty(self):
"""Return True if the iterable is empty"""
self.assertTrue(mi.all_equal(''))
self.assertTrue(mi.all_equal([]))
def test_one(self):
"""Return True if the iterable is singular"""
self.assertTrue(mi.all_equal('0'))
self.assertTrue(mi.all_equal([0]))
class QuantifyTests(TestCase):
@ -104,12 +179,12 @@ class QuantifyTests(TestCase):
def test_happy_path(self):
"""Make sure True count is returned"""
q = [True, False, True]
eq_(quantify(q), 2)
self.assertEqual(mi.quantify(q), 2)
def test_custom_predicate(self):
"""Ensure non-default predicates return as expected"""
q = range(10)
eq_(quantify(q, lambda x: x % 2 == 0), 5)
self.assertEqual(mi.quantify(q, lambda x: x % 2 == 0), 5)
class PadnoneTests(TestCase):
@ -118,8 +193,8 @@ class PadnoneTests(TestCase):
def test_happy_path(self):
"""wrapper iterator should return None indefinitely"""
r = range(2)
p = padnone(r)
eq_([0, 1, None, None], [next(p) for _ in range(4)])
p = mi.padnone(r)
self.assertEqual([0, 1, None, None], [next(p) for _ in range(4)])
class NcyclesTests(TestCase):
@ -128,19 +203,21 @@ class NcyclesTests(TestCase):
def test_happy_path(self):
"""cycle a sequence three times"""
r = ["a", "b", "c"]
n = ncycles(r, 3)
eq_(["a", "b", "c", "a", "b", "c", "a", "b", "c"],
list(n))
n = mi.ncycles(r, 3)
self.assertEqual(
["a", "b", "c", "a", "b", "c", "a", "b", "c"],
list(n)
)
def test_null_case(self):
"""asking for 0 cycles should return an empty iterator"""
n = ncycles(range(100), 0)
assert_raises(StopIteration, next, n)
n = mi.ncycles(range(100), 0)
self.assertRaises(StopIteration, lambda: next(n))
def test_pathalogical_case(self):
"""asking for negative cycles should return an empty iterator"""
n = ncycles(range(100), -10)
assert_raises(StopIteration, next, n)
n = mi.ncycles(range(100), -10)
self.assertRaises(StopIteration, lambda: next(n))
class DotproductTests(TestCase):
@ -148,7 +225,7 @@ class DotproductTests(TestCase):
def test_happy_path(self):
"""simple dotproduct example"""
eq_(400, dotproduct([10, 10], [20, 20]))
self.assertEqual(400, mi.dotproduct([10, 10], [20, 20]))
class FlattenTests(TestCase):
@ -157,12 +234,12 @@ class FlattenTests(TestCase):
def test_basic_usage(self):
"""ensure list of lists is flattened one level"""
f = [[0, 1, 2], [3, 4, 5]]
eq_(range(6), list(flatten(f)))
self.assertEqual(list(range(6)), list(mi.flatten(f)))
def test_single_level(self):
"""ensure list of lists is flattened only one level"""
f = [[0, [1, 2]], [[3, 4], 5]]
eq_([0, [1, 2], [3, 4], 5], list(flatten(f)))
self.assertEqual([0, [1, 2], [3, 4], 5], list(mi.flatten(f)))
class RepeatfuncTests(TestCase):
@ -170,23 +247,23 @@ class RepeatfuncTests(TestCase):
def test_simple_repeat(self):
"""test simple repeated functions"""
r = repeatfunc(lambda: 5)
eq_([5, 5, 5, 5, 5], [next(r) for _ in range(5)])
r = mi.repeatfunc(lambda: 5)
self.assertEqual([5, 5, 5, 5, 5], [next(r) for _ in range(5)])
def test_finite_repeat(self):
"""ensure limited repeat when times is provided"""
r = repeatfunc(lambda: 5, times=5)
eq_([5, 5, 5, 5, 5], list(r))
r = mi.repeatfunc(lambda: 5, times=5)
self.assertEqual([5, 5, 5, 5, 5], list(r))
def test_added_arguments(self):
"""ensure arguments are applied to the function"""
r = repeatfunc(lambda x: x, 2, 3)
eq_([3, 3], list(r))
r = mi.repeatfunc(lambda x: x, 2, 3)
self.assertEqual([3, 3], list(r))
def test_null_times(self):
"""repeat 0 should return an empty iterator"""
r = repeatfunc(range, 0, 3)
assert_raises(StopIteration, next, r)
r = mi.repeatfunc(range, 0, 3)
self.assertRaises(StopIteration, lambda: next(r))
class PairwiseTests(TestCase):
@ -194,13 +271,13 @@ class PairwiseTests(TestCase):
def test_base_case(self):
"""ensure an iterable will return pairwise"""
p = pairwise([1, 2, 3])
eq_([(1, 2), (2, 3)], list(p))
p = mi.pairwise([1, 2, 3])
self.assertEqual([(1, 2), (2, 3)], list(p))
def test_short_case(self):
"""ensure an empty iterator if there's not enough values to pair"""
p = pairwise("a")
assert_raises(StopIteration, next, p)
p = mi.pairwise("a")
self.assertRaises(StopIteration, lambda: next(p))
class GrouperTests(TestCase):
@ -211,18 +288,25 @@ class GrouperTests(TestCase):
the iterable.
"""
eq_(list(grouper(3, 'ABCDEF')), [('A', 'B', 'C'), ('D', 'E', 'F')])
self.assertEqual(
list(mi.grouper(3, 'ABCDEF')), [('A', 'B', 'C'), ('D', 'E', 'F')]
)
def test_odd(self):
"""Test when group size does not divide evenly into the length of the
iterable.
"""
eq_(list(grouper(3, 'ABCDE')), [('A', 'B', 'C'), ('D', 'E', None)])
self.assertEqual(
list(mi.grouper(3, 'ABCDE')), [('A', 'B', 'C'), ('D', 'E', None)]
)
def test_fill_value(self):
"""Test that the fill value is used to pad the final group"""
eq_(list(grouper(3, 'ABCDE', 'x')), [('A', 'B', 'C'), ('D', 'E', 'x')])
self.assertEqual(
list(mi.grouper(3, 'ABCDE', 'x')),
[('A', 'B', 'C'), ('D', 'E', 'x')]
)
class RoundrobinTests(TestCase):
@ -230,13 +314,33 @@ class RoundrobinTests(TestCase):
def test_even_groups(self):
"""Ensure ordered output from evenly populated iterables"""
eq_(list(roundrobin('ABC', [1, 2, 3], range(3))),
['A', 1, 0, 'B', 2, 1, 'C', 3, 2])
self.assertEqual(
list(mi.roundrobin('ABC', [1, 2, 3], range(3))),
['A', 1, 0, 'B', 2, 1, 'C', 3, 2]
)
def test_uneven_groups(self):
"""Ensure ordered output from unevenly populated iterables"""
eq_(list(roundrobin('ABCD', [1, 2], range(0))),
['A', 1, 'B', 2, 'C', 'D'])
self.assertEqual(
list(mi.roundrobin('ABCD', [1, 2], range(0))),
['A', 1, 'B', 2, 'C', 'D']
)
class PartitionTests(TestCase):
"""Tests for ``partition()``"""
def test_bool(self):
"""Test when pred() returns a boolean"""
lesser, greater = mi.partition(lambda x: x > 5, range(10))
self.assertEqual(list(lesser), [0, 1, 2, 3, 4, 5])
self.assertEqual(list(greater), [6, 7, 8, 9])
def test_arbitrary(self):
"""Test when pred() returns an integer"""
divisibles, remainders = mi.partition(lambda x: x % 3, range(10))
self.assertEqual(list(divisibles), [0, 3, 6, 9])
self.assertEqual(list(remainders), [1, 2, 4, 5, 7, 8])
class PowersetTests(TestCase):
@ -244,9 +348,11 @@ class PowersetTests(TestCase):
def test_combinatorics(self):
"""Ensure a proper enumeration"""
p = powerset([1, 2, 3])
eq_(list(p),
[(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)])
p = mi.powerset([1, 2, 3])
self.assertEqual(
list(p),
[(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
)
class UniqueEverseenTests(TestCase):
@ -254,14 +360,28 @@ class UniqueEverseenTests(TestCase):
def test_everseen(self):
"""ensure duplicate elements are ignored"""
u = unique_everseen('AAAABBBBCCDAABBB')
eq_(['A', 'B', 'C', 'D'],
list(u))
u = mi.unique_everseen('AAAABBBBCCDAABBB')
self.assertEqual(
['A', 'B', 'C', 'D'],
list(u)
)
def test_custom_key(self):
"""ensure the custom key comparison works"""
u = unique_everseen('aAbACCc', key=str.lower)
eq_(list('abC'), list(u))
u = mi.unique_everseen('aAbACCc', key=str.lower)
self.assertEqual(list('abC'), list(u))
def test_unhashable(self):
"""ensure things work for unhashable items"""
iterable = ['a', [1, 2, 3], [1, 2, 3], 'a']
u = mi.unique_everseen(iterable)
self.assertEqual(list(u), ['a', [1, 2, 3]])
def test_unhashable_key(self):
"""ensure things work for unhashable items with a custom key"""
iterable = ['a', [1, 2, 3], [1, 2, 3], 'a']
u = mi.unique_everseen(iterable, key=lambda x: x)
self.assertEqual(list(u), ['a', [1, 2, 3]])
class UniqueJustseenTests(TestCase):
@ -269,13 +389,13 @@ class UniqueJustseenTests(TestCase):
def test_justseen(self):
"""ensure only last item is remembered"""
u = unique_justseen('AAAABBBCCDABB')
eq_(list('ABCDAB'), list(u))
u = mi.unique_justseen('AAAABBBCCDABB')
self.assertEqual(list('ABCDAB'), list(u))
def test_custom_key(self):
"""ensure the custom key comparison works"""
u = unique_justseen('AABCcAD', str.lower)
eq_(list('ABCAD'), list(u))
u = mi.unique_justseen('AABCcAD', str.lower)
self.assertEqual(list('ABCAD'), list(u))
class IterExceptTests(TestCase):
@ -284,27 +404,49 @@ class IterExceptTests(TestCase):
def test_exact_exception(self):
"""ensure the exact specified exception is caught"""
l = [1, 2, 3]
i = iter_except(l.pop, IndexError)
eq_(list(i), [3, 2, 1])
i = mi.iter_except(l.pop, IndexError)
self.assertEqual(list(i), [3, 2, 1])
def test_generic_exception(self):
"""ensure the generic exception can be caught"""
l = [1, 2]
i = iter_except(l.pop, Exception)
eq_(list(i), [2, 1])
i = mi.iter_except(l.pop, Exception)
self.assertEqual(list(i), [2, 1])
def test_uncaught_exception_is_raised(self):
"""ensure a non-specified exception is raised"""
l = [1, 2, 3]
i = iter_except(l.pop, KeyError)
assert_raises(IndexError, list, i)
i = mi.iter_except(l.pop, KeyError)
self.assertRaises(IndexError, lambda: list(i))
def test_first(self):
"""ensure first is run before the function"""
l = [1, 2, 3]
f = lambda: 25
i = iter_except(l.pop, IndexError, f)
eq_(list(i), [25, 3, 2, 1])
i = mi.iter_except(l.pop, IndexError, f)
self.assertEqual(list(i), [25, 3, 2, 1])
class FirstTrueTests(TestCase):
"""Tests for ``first_true()``"""
def test_something_true(self):
"""Test with no keywords"""
self.assertEqual(mi.first_true(range(10)), 1)
def test_nothing_true(self):
"""Test default return value."""
self.assertEqual(mi.first_true([0, 0, 0]), False)
def test_default(self):
"""Test with a default keyword"""
self.assertEqual(mi.first_true([0, 0, 0], default='!'), '!')
def test_pred(self):
"""Test with a custom predicate"""
self.assertEqual(
mi.first_true([2, 4, 6], pred=lambda x: x % 3 == 0), 6
)
class RandomProductTests(TestCase):
@ -327,12 +469,12 @@ class RandomProductTests(TestCase):
"""
nums = [1, 2, 3]
lets = ['a', 'b', 'c']
n, m = zip(*[random_product(nums, lets) for _ in range(100)])
n, m = zip(*[mi.random_product(nums, lets) for _ in range(100)])
n, m = set(n), set(m)
eq_(n, set(nums))
eq_(m, set(lets))
eq_(len(n), len(nums))
eq_(len(m), len(lets))
self.assertEqual(n, set(nums))
self.assertEqual(m, set(lets))
self.assertEqual(len(n), len(nums))
self.assertEqual(len(m), len(lets))
def test_list_with_repeat(self):
"""ensure multiple items are chosen, and that they appear to be chosen
@ -341,13 +483,13 @@ class RandomProductTests(TestCase):
"""
nums = [1, 2, 3]
lets = ['a', 'b', 'c']
r = list(random_product(nums, lets, repeat=100))
eq_(2 * 100, len(r))
r = list(mi.random_product(nums, lets, repeat=100))
self.assertEqual(2 * 100, len(r))
n, m = set(r[::2]), set(r[1::2])
eq_(n, set(nums))
eq_(m, set(lets))
eq_(len(n), len(nums))
eq_(len(m), len(lets))
self.assertEqual(n, set(nums))
self.assertEqual(m, set(lets))
self.assertEqual(len(n), len(nums))
self.assertEqual(len(m), len(lets))
class RandomPermutationTests(TestCase):
@ -361,8 +503,8 @@ class RandomPermutationTests(TestCase):
"""
i = range(15)
r = random_permutation(i)
eq_(set(i), set(r))
r = mi.random_permutation(i)
self.assertEqual(set(i), set(r))
if i == r:
raise AssertionError("Values were not permuted")
@ -380,13 +522,13 @@ class RandomPermutationTests(TestCase):
items = range(15)
item_set = set(items)
all_items = set()
for _ in xrange(100):
permutation = random_permutation(items, 5)
eq_(len(permutation), 5)
for _ in range(100):
permutation = mi.random_permutation(items, 5)
self.assertEqual(len(permutation), 5)
permutation_set = set(permutation)
ok_(permutation_set <= item_set)
self.assertLessEqual(permutation_set, item_set)
all_items |= permutation_set
eq_(all_items, item_set)
self.assertEqual(all_items, item_set)
class RandomCombinationTests(TestCase):
@ -397,18 +539,20 @@ class RandomCombinationTests(TestCase):
samplings of random combinations"""
items = range(15)
all_items = set()
for _ in xrange(50):
combination = random_combination(items, 5)
for _ in range(50):
combination = mi.random_combination(items, 5)
all_items |= set(combination)
eq_(all_items, set(items))
self.assertEqual(all_items, set(items))
def test_no_replacement(self):
"""ensure that elements are sampled without replacement"""
items = range(15)
for _ in xrange(50):
combination = random_combination(items, len(items))
eq_(len(combination), len(set(combination)))
assert_raises(ValueError, random_combination, items, len(items) + 1)
for _ in range(50):
combination = mi.random_combination(items, len(items))
self.assertEqual(len(combination), len(set(combination)))
self.assertRaises(
ValueError, lambda: mi.random_combination(items, len(items) + 1)
)
class RandomCombinationWithReplacementTests(TestCase):
@ -417,17 +561,56 @@ class RandomCombinationWithReplacementTests(TestCase):
def test_replacement(self):
"""ensure that elements are sampled with replacement"""
items = range(5)
combo = random_combination_with_replacement(items, len(items) * 2)
eq_(2 * len(items), len(combo))
combo = mi.random_combination_with_replacement(items, len(items) * 2)
self.assertEqual(2 * len(items), len(combo))
if len(set(combo)) == len(combo):
raise AssertionError("Combination contained no duplicates")
def test_psuedorandomness(self):
def test_pseudorandomness(self):
"""ensure different subsets of the iterable get returned over many
samplings of random combinations"""
items = range(15)
all_items = set()
for _ in xrange(50):
combination = random_combination_with_replacement(items, 5)
for _ in range(50):
combination = mi.random_combination_with_replacement(items, 5)
all_items |= set(combination)
eq_(all_items, set(items))
self.assertEqual(all_items, set(items))
class NthCombinationTests(TestCase):
def test_basic(self):
iterable = 'abcdefg'
r = 4
for index, expected in enumerate(combinations(iterable, r)):
actual = mi.nth_combination(iterable, r, index)
self.assertEqual(actual, expected)
def test_long(self):
actual = mi.nth_combination(range(180), 4, 2000000)
expected = (2, 12, 35, 126)
self.assertEqual(actual, expected)
def test_invalid_r(self):
for r in (-1, 3):
with self.assertRaises(ValueError):
mi.nth_combination([], r, 0)
def test_invalid_index(self):
with self.assertRaises(IndexError):
mi.nth_combination('abcdefg', 3, -36)
class PrependTests(TestCase):
def test_basic(self):
value = 'a'
iterator = iter('bcdefg')
actual = list(mi.prepend(value, iterator))
expected = list('abcdefg')
self.assertEqual(actual, expected)
def test_multiple(self):
value = 'ab'
iterator = iter('cdefg')
actual = tuple(mi.prepend(value, iterator))
expected = ('ab',) + tuple('cdefg')
self.assertEqual(actual, expected)

View file

@ -1,25 +1,3 @@
#
# Copyright (c) 2010 Mikhail Gusarov
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""
path.py - An object representing a path to a file or directory.
@ -29,8 +7,18 @@ Example::
from path import Path
d = Path('/home/guido/bin')
# Globbing
for f in d.files('*.py'):
f.chmod(0o755)
# Changing the working directory:
with Path("somewhere"):
# cwd in now `somewhere`
...
# Concatenate paths with /
foo_txt = Path("bar") / "foo.txt"
"""
from __future__ import unicode_literals
@ -41,7 +29,6 @@ import os
import fnmatch
import glob
import shutil
import codecs
import hashlib
import errno
import tempfile
@ -50,8 +37,10 @@ import operator
import re
import contextlib
import io
from distutils import dir_util
import importlib
import itertools
import platform
import ntpath
try:
import win32security
@ -77,22 +66,17 @@ string_types = str,
text_type = str
getcwdu = os.getcwd
def surrogate_escape(error):
"""
Simulate the Python 3 ``surrogateescape`` handler, but for Python 2 only.
"""
chars = error.object[error.start:error.end]
assert len(chars) == 1
val = ord(chars)
val += 0xdc00
return __builtin__.unichr(val), error.end
if PY2:
import __builtin__
string_types = __builtin__.basestring,
text_type = __builtin__.unicode
getcwdu = os.getcwdu
codecs.register_error('surrogateescape', surrogate_escape)
map = itertools.imap
filter = itertools.ifilter
FileNotFoundError = OSError
itertools.filterfalse = itertools.ifilterfalse
@contextlib.contextmanager
def io_error_compat():
@ -107,7 +91,8 @@ def io_error_compat():
##############################################################################
__all__ = ['Path', 'CaseInsensitivePattern']
__all__ = ['Path', 'TempDir', 'CaseInsensitivePattern']
LINESEPS = ['\r\n', '\r', '\n']
@ -119,8 +104,8 @@ U_NL_END = re.compile(r'(?:{0})$'.format(U_NEWLINE.pattern))
try:
import pkg_resources
__version__ = pkg_resources.require('path.py')[0].version
import importlib_metadata
__version__ = importlib_metadata.version('path.py')
except Exception:
__version__ = 'unknown'
@ -131,7 +116,7 @@ class TreeWalkWarning(Warning):
# from jaraco.functools
def compose(*funcs):
compose_two = lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs))
compose_two = lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs)) # noqa
return functools.reduce(compose_two, funcs)
@ -170,6 +155,60 @@ class multimethod(object):
)
class matchers(object):
# TODO: make this class a module
@staticmethod
def load(param):
"""
If the supplied parameter is a string, assum it's a simple
pattern.
"""
return (
matchers.Pattern(param) if isinstance(param, string_types)
else param if param is not None
else matchers.Null()
)
class Base(object):
pass
class Null(Base):
def __call__(self, path):
return True
class Pattern(Base):
def __init__(self, pattern):
self.pattern = pattern
def get_pattern(self, normcase):
try:
return self._pattern
except AttributeError:
pass
self._pattern = normcase(self.pattern)
return self._pattern
def __call__(self, path):
normcase = getattr(self, 'normcase', path.module.normcase)
pattern = self.get_pattern(normcase)
return fnmatch.fnmatchcase(normcase(path.name), pattern)
class CaseInsensitive(Pattern):
"""
A Pattern with a ``'normcase'`` property, suitable for passing to
:meth:`listdir`, :meth:`dirs`, :meth:`files`, :meth:`walk`,
:meth:`walkdirs`, or :meth:`walkfiles` to match case-insensitive.
For example, to get all files ending in .py, .Py, .pY, or .PY in the
current directory::
from path import Path, matchers
Path('.').files(matchers.CaseInsensitive('*.py'))
"""
normcase = staticmethod(ntpath.normcase)
class Path(text_type):
"""
Represents a filesystem path.
@ -214,16 +253,6 @@ class Path(text_type):
"""
return cls
@classmethod
def _always_unicode(cls, path):
"""
Ensure the path as retrieved from a Python API, such as :func:`os.listdir`,
is a proper Unicode string.
"""
if PY3 or isinstance(path, text_type):
return path
return path.decode(sys.getfilesystemencoding(), 'surrogateescape')
# --- Special Python methods.
def __repr__(self):
@ -277,6 +306,9 @@ class Path(text_type):
def __exit__(self, *_):
os.chdir(self._old_dir)
def __fspath__(self):
return self
@classmethod
def getcwd(cls):
""" Return the current working directory as a path object.
@ -330,23 +362,45 @@ class Path(text_type):
return self.expandvars().expanduser().normpath()
@property
def namebase(self):
def stem(self):
""" The same as :meth:`name`, but with one file extension stripped off.
For example,
``Path('/home/guido/python.tar.gz').name == 'python.tar.gz'``,
but
``Path('/home/guido/python.tar.gz').namebase == 'python.tar'``.
>>> Path('/home/guido/python.tar.gz').stem
'python.tar'
"""
base, ext = self.module.splitext(self.name)
return base
@property
def namebase(self):
warnings.warn("Use .stem instead of .namebase", DeprecationWarning)
return self.stem
@property
def ext(self):
""" The file extension, for example ``'.py'``. """
f, ext = self.module.splitext(self)
return ext
def with_suffix(self, suffix):
""" Return a new path with the file suffix changed (or added, if none)
>>> Path('/home/guido/python.tar.gz').with_suffix(".foo")
Path('/home/guido/python.tar.foo')
>>> Path('python').with_suffix('.zip')
Path('python.zip')
>>> Path('filename.ext').with_suffix('zip')
Traceback (most recent call last):
...
ValueError: Invalid suffix 'zip'
"""
if not suffix.startswith('.'):
raise ValueError("Invalid suffix {suffix!r}".format(**locals()))
return self.stripext() + suffix
@property
def drive(self):
""" The drive specifier, for example ``'C:'``.
@ -437,8 +491,9 @@ class Path(text_type):
@multimethod
def joinpath(cls, first, *others):
"""
Join first to zero or more :class:`Path` components, adding a separator
character (:samp:`{first}.module.sep`) if needed. Returns a new instance of
Join first to zero or more :class:`Path` components,
adding a separator character (:samp:`{first}.module.sep`)
if needed. Returns a new instance of
:samp:`{first}._next_class`.
.. seealso:: :func:`os.path.join`
@ -516,7 +571,7 @@ class Path(text_type):
# --- Listing, searching, walking, and matching
def listdir(self, pattern=None):
def listdir(self, match=None):
""" D.listdir() -> List of items in this directory.
Use :meth:`files` or :meth:`dirs` instead if you want a listing
@ -524,46 +579,39 @@ class Path(text_type):
The elements of the list are Path objects.
With the optional `pattern` argument, this only lists
items whose names match the given pattern.
With the optional `match` argument, a callable,
only return items whose names match the given pattern.
.. seealso:: :meth:`files`, :meth:`dirs`
"""
if pattern is None:
pattern = '*'
return [
self / child
for child in map(self._always_unicode, os.listdir(self))
if self._next_class(child).fnmatch(pattern)
]
match = matchers.load(match)
return list(filter(match, (
self / child for child in os.listdir(self)
)))
def dirs(self, pattern=None):
def dirs(self, *args, **kwargs):
""" D.dirs() -> List of this directory's subdirectories.
The elements of the list are Path objects.
This does not walk recursively into subdirectories
(but see :meth:`walkdirs`).
With the optional `pattern` argument, this only lists
directories whose names match the given pattern. For
example, ``d.dirs('build-*')``.
Accepts parameters to :meth:`listdir`.
"""
return [p for p in self.listdir(pattern) if p.isdir()]
return [p for p in self.listdir(*args, **kwargs) if p.isdir()]
def files(self, pattern=None):
def files(self, *args, **kwargs):
""" D.files() -> List of the files in this directory.
The elements of the list are Path objects.
This does not walk into subdirectories (see :meth:`walkfiles`).
With the optional `pattern` argument, this only lists files
whose names match the given pattern. For example,
``d.files('*.pyc')``.
Accepts parameters to :meth:`listdir`.
"""
return [p for p in self.listdir(pattern) if p.isfile()]
return [p for p in self.listdir(*args, **kwargs) if p.isfile()]
def walk(self, pattern=None, errors='strict'):
def walk(self, match=None, errors='strict'):
""" D.walk() -> iterator over files and subdirs, recursively.
The iterator yields Path objects naming each child item of
@ -593,6 +641,8 @@ class Path(text_type):
raise ValueError("invalid errors parameter")
errors = vars(Handlers).get(errors, errors)
match = matchers.load(match)
try:
childList = self.listdir()
except Exception:
@ -603,7 +653,7 @@ class Path(text_type):
return
for child in childList:
if pattern is None or child.fnmatch(pattern):
if match(child):
yield child
try:
isdir = child.isdir()
@ -615,92 +665,26 @@ class Path(text_type):
isdir = False
if isdir:
for item in child.walk(pattern, errors):
for item in child.walk(errors=errors, match=match):
yield item
def walkdirs(self, pattern=None, errors='strict'):
def walkdirs(self, *args, **kwargs):
""" D.walkdirs() -> iterator over subdirs, recursively.
With the optional `pattern` argument, this yields only
directories whose names match the given pattern. For
example, ``mydir.walkdirs('*test')`` yields only directories
with names ending in ``'test'``.
The `errors=` keyword argument controls behavior when an
error occurs. The default is ``'strict'``, which causes an
exception. The other allowed values are ``'warn'`` (which
reports the error via :func:`warnings.warn()`), and ``'ignore'``.
"""
if errors not in ('strict', 'warn', 'ignore'):
raise ValueError("invalid errors parameter")
return (
item
for item in self.walk(*args, **kwargs)
if item.isdir()
)
try:
dirs = self.dirs()
except Exception:
if errors == 'ignore':
return
elif errors == 'warn':
warnings.warn(
"Unable to list directory '%s': %s"
% (self, sys.exc_info()[1]),
TreeWalkWarning)
return
else:
raise
for child in dirs:
if pattern is None or child.fnmatch(pattern):
yield child
for subsubdir in child.walkdirs(pattern, errors):
yield subsubdir
def walkfiles(self, pattern=None, errors='strict'):
def walkfiles(self, *args, **kwargs):
""" D.walkfiles() -> iterator over files in D, recursively.
The optional argument `pattern` limits the results to files
with names that match the pattern. For example,
``mydir.walkfiles('*.tmp')`` yields only files with the ``.tmp``
extension.
"""
if errors not in ('strict', 'warn', 'ignore'):
raise ValueError("invalid errors parameter")
try:
childList = self.listdir()
except Exception:
if errors == 'ignore':
return
elif errors == 'warn':
warnings.warn(
"Unable to list directory '%s': %s"
% (self, sys.exc_info()[1]),
TreeWalkWarning)
return
else:
raise
for child in childList:
try:
isfile = child.isfile()
isdir = not isfile and child.isdir()
except:
if errors == 'ignore':
continue
elif errors == 'warn':
warnings.warn(
"Unable to access '%s': %s"
% (self, sys.exc_info()[1]),
TreeWalkWarning)
continue
else:
raise
if isfile:
if pattern is None or child.fnmatch(pattern):
yield child
elif isdir:
for f in child.walkfiles(pattern, errors):
yield f
return (
item
for item in self.walk(*args, **kwargs)
if item.isfile()
)
def fnmatch(self, pattern, normcase=None):
""" Return ``True`` if `self.name` matches the given `pattern`.
@ -710,8 +694,8 @@ class Path(text_type):
attribute, it is applied to the name and path prior to comparison.
`normcase` - (optional) A function used to normalize the pattern and
filename before matching. Defaults to :meth:`self.module`, which defaults
to :meth:`os.path.normcase`.
filename before matching. Defaults to :meth:`self.module`, which
defaults to :meth:`os.path.normcase`.
.. seealso:: :func:`fnmatch.fnmatch`
"""
@ -730,10 +714,32 @@ class Path(text_type):
of all the files users have in their :file:`bin` directories.
.. seealso:: :func:`glob.glob`
.. note:: Glob is **not** recursive, even when using ``**``.
To do recursive globbing see :func:`walk`,
:func:`walkdirs` or :func:`walkfiles`.
"""
cls = self._next_class
return [cls(s) for s in glob.glob(self / pattern)]
def iglob(self, pattern):
""" Return an iterator of Path objects that match the pattern.
`pattern` - a path relative to this directory, with wildcards.
For example, ``Path('/users').iglob('*/bin/*')`` returns an
iterator of all the files users have in their :file:`bin`
directories.
.. seealso:: :func:`glob.iglob`
.. note:: Glob is **not** recursive, even when using ``**``.
To do recursive globbing see :func:`walk`,
:func:`walkdirs` or :func:`walkfiles`.
"""
cls = self._next_class
return (cls(s) for s in glob.iglob(self / pattern))
#
# --- Reading or writing an entire file at once.
@ -882,15 +888,9 @@ class Path(text_type):
translated to ``'\n'``. If ``False``, newline characters are
stripped off. Default is ``True``.
This uses ``'U'`` mode.
.. seealso:: :meth:`text`
"""
if encoding is None and retain:
with self.open('U') as f:
return f.readlines()
else:
return self.text(encoding, errors).splitlines(retain)
return self.text(encoding, errors).splitlines(retain)
def write_lines(self, lines, encoding=None, errors='strict',
linesep=os.linesep, append=False):
@ -931,14 +931,15 @@ class Path(text_type):
to read the file later.
"""
with self.open('ab' if append else 'wb') as f:
for l in lines:
isUnicode = isinstance(l, text_type)
for line in lines:
isUnicode = isinstance(line, text_type)
if linesep is not None:
pattern = U_NL_END if isUnicode else NL_END
l = pattern.sub('', l) + linesep
line = pattern.sub('', line) + linesep
if isUnicode:
l = l.encode(encoding or sys.getdefaultencoding(), errors)
f.write(l)
line = line.encode(
encoding or sys.getdefaultencoding(), errors)
f.write(line)
def read_md5(self):
""" Calculate the md5 hash for this file.
@ -952,8 +953,8 @@ class Path(text_type):
def _hash(self, hash_name):
""" Returns a hash object for the file at the current path.
`hash_name` should be a hash algo name (such as ``'md5'`` or ``'sha1'``)
that's available in the :mod:`hashlib` module.
`hash_name` should be a hash algo name (such as ``'md5'``
or ``'sha1'``) that's available in the :mod:`hashlib` module.
"""
m = hashlib.new(hash_name)
for chunk in self.chunks(8192, mode="rb"):
@ -1176,7 +1177,8 @@ class Path(text_type):
gid = grp.getgrnam(gid).gr_gid
os.chown(self, uid, gid)
else:
raise NotImplementedError("Ownership not available on this platform.")
msg = "Ownership not available on this platform."
raise NotImplementedError(msg)
return self
def rename(self, new):
@ -1236,7 +1238,8 @@ class Path(text_type):
self.rmdir()
except OSError:
_, e, _ = sys.exc_info()
if e.errno != errno.ENOTEMPTY and e.errno != errno.EEXIST:
bypass_codes = errno.ENOTEMPTY, errno.EEXIST, errno.ENOENT
if e.errno not in bypass_codes:
raise
return self
@ -1277,9 +1280,8 @@ class Path(text_type):
file does not exist. """
try:
self.unlink()
except OSError:
_, e, _ = sys.exc_info()
if e.errno != errno.ENOENT:
except FileNotFoundError as exc:
if PY2 and exc.errno != errno.ENOENT:
raise
return self
@ -1306,11 +1308,16 @@ class Path(text_type):
return self._next_class(newpath)
if hasattr(os, 'symlink'):
def symlink(self, newlink):
def symlink(self, newlink=None):
""" Create a symbolic link at `newlink`, pointing here.
If newlink is not supplied, the symbolic link will assume
the name self.basename(), creating the link in the cwd.
.. seealso:: :func:`os.symlink`
"""
if newlink is None:
newlink = self.basename()
os.symlink(self, newlink)
return self._next_class(newlink)
@ -1368,30 +1375,60 @@ class Path(text_type):
cd = chdir
def merge_tree(self, dst, symlinks=False, *args, **kwargs):
def merge_tree(
self, dst, symlinks=False,
# *
update=False,
copy_function=shutil.copy2,
ignore=lambda dir, contents: []):
"""
Copy entire contents of self to dst, overwriting existing
contents in dst with those in self.
If the additional keyword `update` is True, each
`src` will only be copied if `dst` does not exist,
or `src` is newer than `dst`.
Pass ``symlinks=True`` to copy symbolic links as links.
Note that the technique employed stages the files in a temporary
directory first, so this function is not suitable for merging
trees with large files, especially if the temporary directory
is not capable of storing a copy of the entire source tree.
Accepts a ``copy_function``, similar to copytree.
To avoid overwriting newer files, supply a copy function
wrapped in ``only_newer``. For example::
src.merge_tree(dst, copy_function=only_newer(shutil.copy2))
"""
update = kwargs.pop('update', False)
with tempdir() as _temp_dir:
# first copy the tree to a stage directory to support
# the parameters and behavior of copytree.
stage = _temp_dir / str(hash(self))
self.copytree(stage, symlinks, *args, **kwargs)
# now copy everything from the stage directory using
# the semantics of dir_util.copy_tree
dir_util.copy_tree(stage, dst, preserve_symlinks=symlinks,
update=update)
dst = self._next_class(dst)
dst.makedirs_p()
if update:
warnings.warn(
"Update is deprecated; "
"use copy_function=only_newer(shutil.copy2)",
DeprecationWarning,
stacklevel=2,
)
copy_function = only_newer(copy_function)
sources = self.listdir()
_ignored = ignore(self, [item.name for item in sources])
def ignored(item):
return item.name in _ignored
for source in itertools.filterfalse(ignored, sources):
dest = dst / source.name
if symlinks and source.islink():
target = source.readlink()
target.symlink(dest)
elif source.isdir():
source.merge_tree(
dest,
symlinks=symlinks,
update=update,
copy_function=copy_function,
ignore=ignore,
)
else:
copy_function(source, dest)
self.copystat(dst)
#
# --- Special stuff from os
@ -1410,19 +1447,23 @@ class Path(text_type):
# in-place re-writing, courtesy of Martijn Pieters
# http://www.zopatista.com/python/2013/11/26/inplace-file-rewriting/
@contextlib.contextmanager
def in_place(self, mode='r', buffering=-1, encoding=None, errors=None,
newline=None, backup_extension=None):
def in_place(
self, mode='r', buffering=-1, encoding=None, errors=None,
newline=None, backup_extension=None,
):
"""
A context in which a file may be re-written in-place with new content.
A context in which a file may be re-written in-place with
new content.
Yields a tuple of :samp:`({readable}, {writable})` file objects, where `writable`
replaces `readable`.
Yields a tuple of :samp:`({readable}, {writable})` file
objects, where `writable` replaces `readable`.
If an exception occurs, the old file is restored, removing the
written data.
Mode *must not* use ``'w'``, ``'a'``, or ``'+'``; only read-only-modes are
allowed. A :exc:`ValueError` is raised on invalid modes.
Mode *must not* use ``'w'``, ``'a'``, or ``'+'``; only
read-only-modes are allowed. A :exc:`ValueError` is raised
on invalid modes.
For example, to add line numbers to a file::
@ -1448,22 +1489,28 @@ class Path(text_type):
except os.error:
pass
os.rename(self, backup_fn)
readable = io.open(backup_fn, mode, buffering=buffering,
encoding=encoding, errors=errors, newline=newline)
readable = io.open(
backup_fn, mode, buffering=buffering,
encoding=encoding, errors=errors, newline=newline,
)
try:
perm = os.fstat(readable.fileno()).st_mode
except OSError:
writable = open(self, 'w' + mode.replace('r', ''),
writable = open(
self, 'w' + mode.replace('r', ''),
buffering=buffering, encoding=encoding, errors=errors,
newline=newline)
newline=newline,
)
else:
os_mode = os.O_CREAT | os.O_WRONLY | os.O_TRUNC
if hasattr(os, 'O_BINARY'):
os_mode |= os.O_BINARY
fd = os.open(self, os_mode, perm)
writable = io.open(fd, "w" + mode.replace('r', ''),
writable = io.open(
fd, "w" + mode.replace('r', ''),
buffering=buffering, encoding=encoding, errors=errors,
newline=newline)
newline=newline,
)
try:
if hasattr(os, 'chmod'):
os.chmod(self, perm)
@ -1516,6 +1563,23 @@ class Path(text_type):
return functools.partial(SpecialResolver, cls)
def only_newer(copy_func):
"""
Wrap a copy function (like shutil.copy2) to return
the dst if it's newer than the source.
"""
@functools.wraps(copy_func)
def wrapper(src, dst, *args, **kwargs):
is_newer_dst = (
dst.exists()
and dst.getmtime() >= src.getmtime()
)
if is_newer_dst:
return dst
return copy_func(src, dst, *args, **kwargs)
return wrapper
class SpecialResolver(object):
class ResolverScope:
def __init__(self, paths, scope):
@ -1584,14 +1648,15 @@ class Multi:
)
class tempdir(Path):
class TempDir(Path):
"""
A temporary directory via :func:`tempfile.mkdtemp`, and constructed with the
same parameters that you can use as a context manager.
A temporary directory via :func:`tempfile.mkdtemp`, and
constructed with the same parameters that you can use
as a context manager.
Example:
Example::
with tempdir() as d:
with TempDir() as d:
# do stuff with the Path object "d"
# here the directory is deleted automatically
@ -1606,19 +1671,27 @@ class tempdir(Path):
def __new__(cls, *args, **kwargs):
dirname = tempfile.mkdtemp(*args, **kwargs)
return super(tempdir, cls).__new__(cls, dirname)
return super(TempDir, cls).__new__(cls, dirname)
def __init__(self, *args, **kwargs):
pass
def __enter__(self):
return self
# TempDir should return a Path version of itself and not itself
# so that a second context manager does not create a second
# temporary directory, but rather changes CWD to the location
# of the temporary directory.
return self._next_class(self)
def __exit__(self, exc_type, exc_value, traceback):
if not exc_value:
self.rmtree()
# For backwards compatibility.
tempdir = TempDir
def _multi_permission_mask(mode):
"""
Support multiple, comma-separated Unix chmod symbolic modes.
@ -1626,7 +1699,8 @@ def _multi_permission_mask(mode):
>>> _multi_permission_mask('a=r,u+w')(0) == 0o644
True
"""
compose = lambda f, g: lambda *args, **kwargs: g(f(*args, **kwargs))
def compose(f, g):
return lambda *args, **kwargs: g(f(*args, **kwargs))
return functools.reduce(compose, map(_permission_mask, mode.split(',')))
@ -1692,31 +1766,56 @@ def _permission_mask(mode):
return functools.partial(op_map[op], mask)
class CaseInsensitivePattern(text_type):
class CaseInsensitivePattern(matchers.CaseInsensitive):
def __init__(self, value):
warnings.warn(
"Use matchers.CaseInsensitive instead",
DeprecationWarning,
stacklevel=2,
)
super(CaseInsensitivePattern, self).__init__(value)
class FastPath(Path):
def __init__(self, *args, **kwargs):
warnings.warn(
"Use Path, as FastPath no longer holds any advantage",
DeprecationWarning,
stacklevel=2,
)
super(FastPath, self).__init__(*args, **kwargs)
def patch_for_linux_python2():
"""
A string with a ``'normcase'`` property, suitable for passing to
:meth:`listdir`, :meth:`dirs`, :meth:`files`, :meth:`walk`,
:meth:`walkdirs`, or :meth:`walkfiles` to match case-insensitive.
For example, to get all files ending in .py, .Py, .pY, or .PY in the
current directory::
from path import Path, CaseInsensitivePattern as ci
Path('.').files(ci('*.py'))
As reported in #130, when Linux users create filenames
not in the file system encoding, it creates problems on
Python 2. This function attempts to patch the os module
to make it behave more like that on Python 3.
"""
if not PY2 or platform.system() != 'Linux':
return
@property
def normcase(self):
return __import__('ntpath').normcase
try:
import backports.os
except ImportError:
return
########################
# Backward-compatibility
class path(Path):
def __new__(cls, *args, **kwargs):
msg = "path is deprecated. Use Path instead."
warnings.warn(msg, DeprecationWarning)
return Path.__new__(cls, *args, **kwargs)
class OS:
"""
The proxy to the os module
"""
def __init__(self, wrapped):
self._orig = wrapped
def __getattr__(self, name):
return getattr(self._orig, name)
def listdir(self, *args, **kwargs):
items = self._orig.listdir(*args, **kwargs)
return list(map(backports.os.fsdecode, items))
globals().update(os=OS(os))
__all__ += ['path']
########################
patch_for_linux_python2()

View file

@ -22,25 +22,54 @@ import os
import sys
import shutil
import time
import types
import ntpath
import posixpath
import textwrap
import platform
import importlib
import operator
import datetime
import subprocess
import re
import pytest
import packaging.version
from path import Path, tempdir
from path import CaseInsensitivePattern as ci
import path
from path import TempDir
from path import matchers
from path import SpecialResolver
from path import Multi
Path = None
def p(**choices):
""" Choose a value from several possible values, based on os.name """
return choices[os.name]
@pytest.fixture(autouse=True, params=[path.Path])
def path_class(request, monkeypatch):
"""
Invoke tests on any number of Path classes.
"""
monkeypatch.setitem(globals(), 'Path', request.param)
def mac_version(target, comparator=operator.ge):
"""
Return True if on a Mac whose version passes the comparator.
"""
current_ver = packaging.version.parse(platform.mac_ver()[0])
target_ver = packaging.version.parse(target)
return (
platform.system() == 'Darwin'
and comparator(current_ver, target_ver)
)
class TestBasics:
def test_relpath(self):
root = Path(p(nt='C:\\', posix='/'))
@ -51,14 +80,14 @@ class TestBasics:
up = Path(os.pardir)
# basics
assert root.relpathto(boz) == Path('foo')/'bar'/'Baz'/'Boz'
assert bar.relpathto(boz) == Path('Baz')/'Boz'
assert quux.relpathto(boz) == up/'bar'/'Baz'/'Boz'
assert boz.relpathto(quux) == up/up/up/'quux'
assert boz.relpathto(bar) == up/up
assert root.relpathto(boz) == Path('foo') / 'bar' / 'Baz' / 'Boz'
assert bar.relpathto(boz) == Path('Baz') / 'Boz'
assert quux.relpathto(boz) == up / 'bar' / 'Baz' / 'Boz'
assert boz.relpathto(quux) == up / up / up / 'quux'
assert boz.relpathto(bar) == up / up
# Path is not the first element in concatenation
assert root.relpathto(boz) == 'foo'/Path('bar')/'Baz'/'Boz'
assert root.relpathto(boz) == 'foo' / Path('bar') / 'Baz' / 'Boz'
# x.relpathto(x) == curdir
assert root.relpathto(root) == os.curdir
@ -112,7 +141,7 @@ class TestBasics:
# Test p1/p1.
p1 = Path("foo")
p2 = Path("bar")
assert p1/p2 == p(nt='foo\\bar', posix='foo/bar')
assert p1 / p2 == p(nt='foo\\bar', posix='foo/bar')
def test_properties(self):
# Create sample path object.
@ -207,6 +236,30 @@ class TestBasics:
assert res2 == 'foo/bar'
class TestPerformance:
@pytest.mark.skipif(
path.PY2,
reason="Tests fail frequently on Python 2; see #153")
def test_import_time(self, monkeypatch):
"""
Import of path.py should take less than 100ms.
Run tests in a subprocess to isolate from test suite overhead.
"""
cmd = [
sys.executable,
'-m', 'timeit',
'-n', '1',
'-r', '1',
'import path',
]
res = subprocess.check_output(cmd, universal_newlines=True)
dur = re.search(r'(\d+) msec per loop', res).group(1)
limit = datetime.timedelta(milliseconds=100)
duration = datetime.timedelta(milliseconds=int(dur))
assert duration < limit
class TestSelfReturn:
"""
Some methods don't necessarily return any value (e.g. makedirs,
@ -246,7 +299,7 @@ class TestSelfReturn:
class TestScratchDir:
"""
Tests that run in a temporary directory (does not test tempdir class)
Tests that run in a temporary directory (does not test TempDir class)
"""
def test_context_manager(self, tmpdir):
"""Can be used as context manager for chdir."""
@ -282,12 +335,12 @@ class TestScratchDir:
ct = f.ctime
assert t0 <= ct <= t1
time.sleep(threshold*2)
time.sleep(threshold * 2)
fobj = open(f, 'ab')
fobj.write('some bytes'.encode('utf-8'))
fobj.close()
time.sleep(threshold*2)
time.sleep(threshold * 2)
t2 = time.time() - threshold
f.touch()
t3 = time.time() + threshold
@ -305,9 +358,12 @@ class TestScratchDir:
assert ct == ct2
assert ct2 < t2
else:
# On other systems, it might be the CHANGE time
# (especially on Unix, time of inode changes)
assert ct == ct2 or ct2 == f.mtime
assert (
# ctime is unchanged
ct == ct2 or
# ctime is approximately the mtime
ct2 == pytest.approx(f.mtime, 0.001)
)
def test_listing(self, tmpdir):
d = Path(tmpdir)
@ -330,6 +386,11 @@ class TestScratchDir:
assert d.glob('*') == [af]
assert d.glob('*.html') == []
assert d.glob('testfile') == []
# .iglob matches .glob but as an iterator.
assert list(d.iglob('*')) == d.glob('*')
assert isinstance(d.iglob('*'), types.GeneratorType)
finally:
af.remove()
@ -348,9 +409,17 @@ class TestScratchDir:
for f in files:
try:
f.remove()
except:
except Exception:
pass
@pytest.mark.xfail(
mac_version('10.13'),
reason="macOS disallows invalid encodings",
)
@pytest.mark.xfail(
platform.system() == 'Windows' and path.PY3,
reason="Can't write latin characters. See #133",
)
def test_listdir_other_encoding(self, tmpdir):
"""
Some filesystems allow non-character sequences in path names.
@ -498,28 +567,28 @@ class TestScratchDir:
def test_patterns(self, tmpdir):
d = Path(tmpdir)
names = ['x.tmp', 'x.xtmp', 'x2g', 'x22', 'x.txt']
dirs = [d, d/'xdir', d/'xdir.tmp', d/'xdir.tmp'/'xsubdir']
dirs = [d, d / 'xdir', d / 'xdir.tmp', d / 'xdir.tmp' / 'xsubdir']
for e in dirs:
if not e.isdir():
e.makedirs()
for name in names:
(e/name).touch()
self.assertList(d.listdir('*.tmp'), [d/'x.tmp', d/'xdir.tmp'])
self.assertList(d.files('*.tmp'), [d/'x.tmp'])
self.assertList(d.dirs('*.tmp'), [d/'xdir.tmp'])
(e / name).touch()
self.assertList(d.listdir('*.tmp'), [d / 'x.tmp', d / 'xdir.tmp'])
self.assertList(d.files('*.tmp'), [d / 'x.tmp'])
self.assertList(d.dirs('*.tmp'), [d / 'xdir.tmp'])
self.assertList(d.walk(), [e for e in dirs
if e != d] + [e/n for e in dirs
if e != d] + [e / n for e in dirs
for n in names])
self.assertList(d.walk('*.tmp'),
[e/'x.tmp' for e in dirs] + [d/'xdir.tmp'])
self.assertList(d.walkfiles('*.tmp'), [e/'x.tmp' for e in dirs])
self.assertList(d.walkdirs('*.tmp'), [d/'xdir.tmp'])
[e / 'x.tmp' for e in dirs] + [d / 'xdir.tmp'])
self.assertList(d.walkfiles('*.tmp'), [e / 'x.tmp' for e in dirs])
self.assertList(d.walkdirs('*.tmp'), [d / 'xdir.tmp'])
def test_unicode(self, tmpdir):
d = Path(tmpdir)
p = d/'unicode.txt'
p = d / 'unicode.txt'
def test(enc):
""" Test that path works with the specified encoding,
@ -527,18 +596,22 @@ class TestScratchDir:
Unicode codepoints.
"""
given = ('Hello world\n'
'\u0d0a\u0a0d\u0d15\u0a15\r\n'
'\u0d0a\u0a0d\u0d15\u0a15\x85'
'\u0d0a\u0a0d\u0d15\u0a15\u2028'
'\r'
'hanging')
clean = ('Hello world\n'
'\u0d0a\u0a0d\u0d15\u0a15\n'
'\u0d0a\u0a0d\u0d15\u0a15\n'
'\u0d0a\u0a0d\u0d15\u0a15\n'
'\n'
'hanging')
given = (
'Hello world\n'
'\u0d0a\u0a0d\u0d15\u0a15\r\n'
'\u0d0a\u0a0d\u0d15\u0a15\x85'
'\u0d0a\u0a0d\u0d15\u0a15\u2028'
'\r'
'hanging'
)
clean = (
'Hello world\n'
'\u0d0a\u0a0d\u0d15\u0a15\n'
'\u0d0a\u0a0d\u0d15\u0a15\n'
'\u0d0a\u0a0d\u0d15\u0a15\n'
'\n'
'hanging'
)
givenLines = [
('Hello world\n'),
('\u0d0a\u0a0d\u0d15\u0a15\r\n'),
@ -581,8 +654,9 @@ class TestScratchDir:
return
# Write Unicode to file using path.write_text().
cleanNoHanging = clean + '\n' # This test doesn't work with a
# hanging line.
# This test doesn't work with a hanging line.
cleanNoHanging = clean + '\n'
p.write_text(cleanNoHanging, enc)
p.write_text(cleanNoHanging, enc, append=True)
# Check the result.
@ -641,7 +715,7 @@ class TestScratchDir:
test('UTF-16')
def test_chunks(self, tmpdir):
p = (tempdir() / 'test.txt').touch()
p = (TempDir() / 'test.txt').touch()
txt = "0123456789"
size = 5
p.write_text(txt)
@ -650,16 +724,18 @@ class TestScratchDir:
assert i == len(txt) / size - 1
@pytest.mark.skipif(not hasattr(os.path, 'samefile'),
reason="samefile not present")
@pytest.mark.skipif(
not hasattr(os.path, 'samefile'),
reason="samefile not present",
)
def test_samefile(self, tmpdir):
f1 = (tempdir() / '1.txt').touch()
f1 = (TempDir() / '1.txt').touch()
f1.write_text('foo')
f2 = (tempdir() / '2.txt').touch()
f2 = (TempDir() / '2.txt').touch()
f1.write_text('foo')
f3 = (tempdir() / '3.txt').touch()
f3 = (TempDir() / '3.txt').touch()
f1.write_text('bar')
f4 = (tempdir() / '4.txt')
f4 = (TempDir() / '4.txt')
f1.copyfile(f4)
assert os.path.samefile(f1, f2) == f1.samefile(f2)
@ -680,6 +756,26 @@ class TestScratchDir:
self.fail("Calling `rmtree_p` on non-existent directory "
"should not raise an exception.")
def test_rmdir_p_exists(self, tmpdir):
"""
Invocation of rmdir_p on an existant directory should
remove the directory.
"""
d = Path(tmpdir)
sub = d / 'subfolder'
sub.mkdir()
sub.rmdir_p()
assert not sub.exists()
def test_rmdir_p_nonexistent(self, tmpdir):
"""
A non-existent file should not raise an exception.
"""
d = Path(tmpdir)
sub = d / 'subfolder'
assert not sub.exists()
sub.rmdir_p()
class TestMergeTree:
@pytest.fixture(autouse=True)
@ -701,6 +797,11 @@ class TestMergeTree:
else:
self.test_file.copy(self.test_link)
def check_link(self):
target = Path(self.subdir_b / self.test_link.name)
check = target.islink if hasattr(os, 'symlink') else target.isfile
assert check()
def test_with_nonexisting_dst_kwargs(self):
self.subdir_a.merge_tree(self.subdir_b, symlinks=True)
assert self.subdir_b.isdir()
@ -709,7 +810,7 @@ class TestMergeTree:
self.subdir_b / self.test_link.name,
))
assert set(self.subdir_b.listdir()) == expected
assert Path(self.subdir_b / self.test_link.name).islink()
self.check_link()
def test_with_nonexisting_dst_args(self):
self.subdir_a.merge_tree(self.subdir_b, True)
@ -719,7 +820,7 @@ class TestMergeTree:
self.subdir_b / self.test_link.name,
))
assert set(self.subdir_b.listdir()) == expected
assert Path(self.subdir_b / self.test_link.name).islink()
self.check_link()
def test_with_existing_dst(self):
self.subdir_b.rmtree()
@ -740,7 +841,7 @@ class TestMergeTree:
self.subdir_b / test_new.name,
))
assert set(self.subdir_b.listdir()) == expected
assert Path(self.subdir_b / self.test_link.name).islink()
self.check_link()
assert len(Path(self.subdir_b / self.test_file.name).bytes()) == 5000
def test_copytree_parameters(self):
@ -753,6 +854,20 @@ class TestMergeTree:
assert self.subdir_b.isdir()
assert self.subdir_b.listdir() == [self.subdir_b / self.test_file.name]
def test_only_newer(self):
"""
merge_tree should accept a copy_function in which only
newer files are copied and older files do not overwrite
newer copies in the dest.
"""
target = self.subdir_b / 'testfile.txt'
target.write_text('this is newer')
self.subdir_a.merge_tree(
self.subdir_b,
copy_function=path.only_newer(shutil.copy2),
)
assert target.text() == 'this is newer'
class TestChdir:
def test_chdir_or_cd(self, tmpdir):
@ -781,17 +896,17 @@ class TestChdir:
class TestSubclass:
class PathSubclass(Path):
pass
def test_subclass_produces_same_class(self):
"""
When operations are invoked on a subclass, they should produce another
instance of that subclass.
"""
p = self.PathSubclass('/foo')
class PathSubclass(Path):
pass
p = PathSubclass('/foo')
subdir = p / 'bar'
assert isinstance(subdir, self.PathSubclass)
assert isinstance(subdir, PathSubclass)
class TestTempDir:
@ -800,8 +915,8 @@ class TestTempDir:
"""
One should be able to readily construct a temporary directory
"""
d = tempdir()
assert isinstance(d, Path)
d = TempDir()
assert isinstance(d, path.Path)
assert d.exists()
assert d.isdir()
d.rmdir()
@ -809,24 +924,24 @@ class TestTempDir:
def test_next_class(self):
"""
It should be possible to invoke operations on a tempdir and get
It should be possible to invoke operations on a TempDir and get
Path classes.
"""
d = tempdir()
d = TempDir()
sub = d / 'subdir'
assert isinstance(sub, Path)
assert isinstance(sub, path.Path)
d.rmdir()
def test_context_manager(self):
"""
One should be able to use a tempdir object as a context, which will
One should be able to use a TempDir object as a context, which will
clean up the contents after.
"""
d = tempdir()
d = TempDir()
res = d.__enter__()
assert res is d
assert res == path.Path(d)
(d / 'somefile.txt').touch()
assert not isinstance(d / 'somefile.txt', tempdir)
assert not isinstance(d / 'somefile.txt', TempDir)
d.__exit__(None, None, None)
assert not d.exists()
@ -834,10 +949,10 @@ class TestTempDir:
"""
The context manager will not clean up if an exception occurs.
"""
d = tempdir()
d = TempDir()
d.__enter__()
(d / 'somefile.txt').touch()
assert not isinstance(d / 'somefile.txt', tempdir)
assert not isinstance(d / 'somefile.txt', TempDir)
d.__exit__(TypeError, TypeError('foo'), None)
assert d.exists()
@ -847,7 +962,7 @@ class TestTempDir:
provide a temporry directory that will be deleted after that.
"""
with tempdir() as d:
with TempDir() as d:
assert d.isdir()
assert not d.isdir()
@ -876,7 +991,8 @@ class TestPatternMatching:
assert p.fnmatch('FOO[ABC]AR')
def test_fnmatch_custom_normcase(self):
normcase = lambda path: path.upper()
def normcase(path):
return path.upper()
p = Path('FooBar')
assert p.fnmatch('foobar', normcase=normcase)
assert p.fnmatch('FOO[ABC]AR', normcase=normcase)
@ -891,8 +1007,8 @@ class TestPatternMatching:
def test_listdir_patterns(self, tmpdir):
p = Path(tmpdir)
(p/'sub').mkdir()
(p/'File').touch()
(p / 'sub').mkdir()
(p / 'File').touch()
assert p.listdir('s*') == [p / 'sub']
assert len(p.listdir('*')) == 2
@ -903,14 +1019,14 @@ class TestPatternMatching:
"""
always_unix = Path.using_module(posixpath)
p = always_unix(tmpdir)
(p/'sub').mkdir()
(p/'File').touch()
(p / 'sub').mkdir()
(p / 'File').touch()
assert p.listdir('S*') == []
always_win = Path.using_module(ntpath)
p = always_win(tmpdir)
assert p.listdir('S*') == [p/'sub']
assert p.listdir('f*') == [p/'File']
assert p.listdir('S*') == [p / 'sub']
assert p.listdir('f*') == [p / 'File']
def test_listdir_case_insensitive(self, tmpdir):
"""
@ -918,27 +1034,30 @@ class TestPatternMatching:
used by that Path class.
"""
p = Path(tmpdir)
(p/'sub').mkdir()
(p/'File').touch()
assert p.listdir(ci('S*')) == [p/'sub']
assert p.listdir(ci('f*')) == [p/'File']
assert p.files(ci('S*')) == []
assert p.dirs(ci('f*')) == []
(p / 'sub').mkdir()
(p / 'File').touch()
assert p.listdir(matchers.CaseInsensitive('S*')) == [p / 'sub']
assert p.listdir(matchers.CaseInsensitive('f*')) == [p / 'File']
assert p.files(matchers.CaseInsensitive('S*')) == []
assert p.dirs(matchers.CaseInsensitive('f*')) == []
def test_walk_case_insensitive(self, tmpdir):
p = Path(tmpdir)
(p/'sub1'/'foo').makedirs_p()
(p/'sub2'/'foo').makedirs_p()
(p/'sub1'/'foo'/'bar.Txt').touch()
(p/'sub2'/'foo'/'bar.TXT').touch()
(p/'sub2'/'foo'/'bar.txt.bz2').touch()
files = list(p.walkfiles(ci('*.txt')))
(p / 'sub1' / 'foo').makedirs_p()
(p / 'sub2' / 'foo').makedirs_p()
(p / 'sub1' / 'foo' / 'bar.Txt').touch()
(p / 'sub2' / 'foo' / 'bar.TXT').touch()
(p / 'sub2' / 'foo' / 'bar.txt.bz2').touch()
files = list(p.walkfiles(matchers.CaseInsensitive('*.txt')))
assert len(files) == 2
assert p/'sub2'/'foo'/'bar.TXT' in files
assert p/'sub1'/'foo'/'bar.Txt' in files
assert p / 'sub2' / 'foo' / 'bar.TXT' in files
assert p / 'sub1' / 'foo' / 'bar.Txt' in files
@pytest.mark.skipif(sys.version_info < (2, 6),
reason="in_place requires io module in Python 2.6")
@pytest.mark.skipif(
sys.version_info < (2, 6),
reason="in_place requires io module in Python 2.6",
)
class TestInPlace:
reference_content = textwrap.dedent("""
The quick brown fox jumped over the lazy dog.
@ -959,7 +1078,7 @@ class TestInPlace:
@classmethod
def create_reference(cls, tmpdir):
p = Path(tmpdir)/'document'
p = Path(tmpdir) / 'document'
with p.open('w') as stream:
stream.write(cls.reference_content)
return p
@ -984,7 +1103,7 @@ class TestInPlace:
assert "some error" in str(exc)
with doc.open() as stream:
data = stream.read()
assert not 'Lorem' in data
assert 'Lorem' not in data
assert 'lazy dog' in data
@ -1023,8 +1142,9 @@ class TestSpecialPaths:
def test_unix_paths_fallback(self, tmpdir, monkeypatch, feign_linux):
"Without XDG_CONFIG_HOME set, ~/.config should be used."
fake_home = tmpdir / '_home'
monkeypatch.delitem(os.environ, 'XDG_CONFIG_HOME', raising=False)
monkeypatch.setitem(os.environ, 'HOME', str(fake_home))
expected = str(tmpdir / '_home' / '.config')
expected = Path('~/.config').expanduser()
assert SpecialResolver(Path).user.config == expected
def test_property(self):
@ -1075,7 +1195,8 @@ class TestMultiPath:
cls = Multi.for_class(Path)
assert issubclass(cls, Path)
assert issubclass(cls, Multi)
assert cls.__name__ == 'MultiPath'
expected_name = 'Multi' + Path.__name__
assert cls.__name__ == expected_name
def test_detect_no_pathsep(self):
"""
@ -1115,5 +1236,23 @@ class TestMultiPath:
assert path == input
if __name__ == '__main__':
pytest.main()
@pytest.mark.xfail('path.PY2', reason="Python 2 has no __future__")
def test_no_dependencies():
"""
Path.py guarantees that the path module can be
transplanted into an environment without any dependencies.
"""
cmd = [
sys.executable,
'-S',
'-c', 'import path',
]
subprocess.check_call(cmd)
def test_version():
"""
Under normal circumstances, path should present a
__version__.
"""
assert re.match(r'\d+\.\d+.*', path.__version__)