Update jaraco-windows to 3.9.2

Also updates:
- importlib-metadata-0.7
- jaraco-windows
- jaraco.classes-1.5
- jaraco.collections-1.6.0
- jaraco.functools-1.20
- jaraco.structures-1.1.2
- jaraco.text-1.10.1
- jaraco.ui-1.6
- more-itertools-4.3.0
- path.py-11.5.0
- six-1.12.0
Labrys of Knossos 2018-12-15 01:17:06 -05:00
commit 8d43b8ea39
92 changed files with 7515 additions and 996 deletions

BIN
libs/bin/enver.exe Normal file

Binary file not shown.

BIN
libs/bin/find-symlinks.exe Normal file

Binary file not shown.

BIN
libs/bin/gclip.exe Normal file

Binary file not shown.

BIN
libs/bin/mklink.exe Normal file

Binary file not shown.

BIN
libs/bin/pclip.exe Normal file

Binary file not shown.

BIN
libs/bin/xmouse.exe Normal file

Binary file not shown.

View file

@ -0,0 +1,17 @@
from .api import distribution, Distribution, PackageNotFoundError  # noqa: F401
from .api import metadata, entry_points, resolve, version, read_text

# Import for installation side-effects.
from . import _hooks  # noqa: F401


__all__ = [
    'metadata',
    'entry_points',
    'resolve',
    'version',
    'read_text',
    ]


__version__ = version(__name__)

View file

@ -0,0 +1,148 @@
from __future__ import unicode_literals, absolute_import

import re
import sys
import itertools

from .api import Distribution
from zipfile import ZipFile

if sys.version_info >= (3,):  # pragma: nocover
    from contextlib import suppress
    from pathlib import Path
else:  # pragma: nocover
    from contextlib2 import suppress  # noqa
    from itertools import imap as map  # type: ignore
    from pathlib2 import Path

    FileNotFoundError = IOError, OSError

__metaclass__ = type


def install(cls):
    """Class decorator for installation on sys.meta_path."""
    sys.meta_path.append(cls)
    return cls


class NullFinder:
    @staticmethod
    def find_spec(*args, **kwargs):
        return None

    # In Python 2, the import system requires finders
    # to have a find_module() method, but this usage
    # is deprecated in Python 3 in favor of find_spec().
    # For the purposes of this finder (i.e. being present
    # on sys.meta_path but having no other import
    # system functionality), the two methods are identical.
    find_module = find_spec


@install
class MetadataPathFinder(NullFinder):
    """A degenerate finder for distribution packages on the file system.

    This finder supplies only a find_distribution() method for versions
    of Python that do not have a PathFinder find_distribution().
    """
    search_template = r'{name}(-.*)?\.(dist|egg)-info'

    @classmethod
    def find_distribution(cls, name):
        paths = cls._search_paths(name)
        dists = map(PathDistribution, paths)
        return next(dists, None)

    @classmethod
    def _search_paths(cls, name):
        """
        Find metadata directories in sys.path heuristically.
        """
        return itertools.chain.from_iterable(
            cls._search_path(path, name)
            for path in map(Path, sys.path)
            )

    @classmethod
    def _search_path(cls, root, name):
        if not root.is_dir():
            return ()
        normalized = name.replace('-', '_')
        return (
            item
            for item in root.iterdir()
            if item.is_dir()
            and re.match(
                cls.search_template.format(name=normalized),
                str(item.name),
                flags=re.IGNORECASE,
                )
            )


class PathDistribution(Distribution):
    def __init__(self, path):
        """Construct a distribution from a path to the metadata directory."""
        self._path = path

    def read_text(self, filename):
        with suppress(FileNotFoundError):
            with self._path.joinpath(filename).open(encoding='utf-8') as fp:
                return fp.read()
        return None
    read_text.__doc__ = Distribution.read_text.__doc__


@install
class WheelMetadataFinder(NullFinder):
    """A degenerate finder for distribution packages in wheels.

    This finder supplies only a find_distribution() method for versions
    of Python that do not have a PathFinder find_distribution().
    """
    search_template = r'{name}(-.*)?\.whl'

    @classmethod
    def find_distribution(cls, name):
        paths = cls._search_paths(name)
        dists = map(WheelDistribution, paths)
        return next(dists, None)

    @classmethod
    def _search_paths(cls, name):
        return (
            item
            for item in map(Path, sys.path)
            if re.match(
                cls.search_template.format(name=name),
                str(item.name),
                flags=re.IGNORECASE,
                )
            )


class WheelDistribution(Distribution):
    def __init__(self, archive):
        self._archive = archive
        name, version = archive.name.split('-')[0:2]
        self._dist_info = '{}-{}.dist-info'.format(name, version)

    def read_text(self, filename):
        with ZipFile(_path_to_filename(self._archive)) as zf:
            with suppress(KeyError):
                as_bytes = zf.read('{}/{}'.format(self._dist_info, filename))
                return as_bytes.decode('utf-8')
        return None
    read_text.__doc__ = Distribution.read_text.__doc__


def _path_to_filename(path):  # pragma: nocover
    """
    On non-compliant systems, ensure a path-like object is
    a string.
    """
    try:
        return path.__fspath__()
    except AttributeError:
        return str(path)

View file

@ -0,0 +1,146 @@
import io
import abc
import sys
import email

from importlib import import_module

if sys.version_info > (3,):  # pragma: nocover
    from configparser import ConfigParser
else:  # pragma: nocover
    from ConfigParser import SafeConfigParser as ConfigParser

try:
    BaseClass = ModuleNotFoundError
except NameError:  # pragma: nocover
    BaseClass = ImportError  # type: ignore


__metaclass__ = type


class PackageNotFoundError(BaseClass):
    """The package was not found."""


class Distribution:
    """A Python distribution package."""

    @abc.abstractmethod
    def read_text(self, filename):
        """Attempt to load metadata file given by the name.

        :param filename: The name of the file in the distribution info.
        :return: The text if found, otherwise None.
        """

    @classmethod
    def from_name(cls, name):
        """Return the Distribution for the given package name.

        :param name: The name of the distribution package to search for.
        :return: The Distribution instance (or subclass thereof) for the named
            package, if found.
        :raises PackageNotFoundError: When the named package's distribution
            metadata cannot be found.
        """
        for resolver in cls._discover_resolvers():
            resolved = resolver(name)
            if resolved is not None:
                return resolved
        else:
            raise PackageNotFoundError(name)

    @staticmethod
    def _discover_resolvers():
        """Search the meta_path for resolvers."""
        declared = (
            getattr(finder, 'find_distribution', None)
            for finder in sys.meta_path
            )
        return filter(None, declared)

    @property
    def metadata(self):
        """Return the parsed metadata for this Distribution.

        The returned object will have keys that name the various bits of
        metadata.  See PEP 566 for details.
        """
        return email.message_from_string(
            self.read_text('METADATA') or self.read_text('PKG-INFO')
            )

    @property
    def version(self):
        """Return the 'Version' metadata for the distribution package."""
        return self.metadata['Version']


def distribution(package):
    """Get the ``Distribution`` instance for the given package.

    :param package: The name of the package as a string.
    :return: A ``Distribution`` instance (or subclass thereof).
    """
    return Distribution.from_name(package)


def metadata(package):
    """Get the metadata for the package.

    :param package: The name of the distribution package to query.
    :return: An email.Message containing the parsed metadata.
    """
    return Distribution.from_name(package).metadata


def version(package):
    """Get the version string for the named package.

    :param package: The name of the distribution package to query.
    :return: The version string for the package as defined in the package's
        "Version" metadata key.
    """
    return distribution(package).version


def entry_points(name):
    """Return the entry points for the named distribution package.

    :param name: The name of the distribution package to query.
    :return: A ConfigParser instance where the sections and keys are taken
        from the entry_points.txt ini-style contents.
    """
    as_string = read_text(name, 'entry_points.txt')
    # 2018-09-10(barry): Should we provide any options here, or let the caller
    # send options to the underlying ConfigParser?  For now, YAGNI.
    config = ConfigParser()
    try:
        config.read_string(as_string)
    except AttributeError:  # pragma: nocover
        # Python 2 has no read_string
        config.readfp(io.StringIO(as_string))
    return config


def resolve(entry_point):
    """Resolve an entry point string into the named callable.

    :param entry_point: An entry point string of the form
        `path.to.module:callable`.
    :return: The actual callable object `path.to.module.callable`
    :raises ValueError: When `entry_point` doesn't have the proper format.
    """
    path, colon, name = entry_point.rpartition(':')
    if colon != ':':
        raise ValueError('Not an entry point: {}'.format(entry_point))
    module = import_module(path)
    return getattr(module, name)


def read_text(package, filename):
    """
    Read the text of the file in the distribution info directory.
    """
    return distribution(package).read_text(filename)

View file

View file

@ -0,0 +1,57 @@
=========================
importlib_metadata NEWS
=========================

0.7 (2018-11-27)
================
* Fixed issue where packages with dashes in their names would
  not be discovered. Closes #21.
* Distribution lookup is now case-insensitive. Closes #20.
* Wheel distributions can no longer be discovered by their module
  name. Like Path distributions, they must be indicated by their
  distribution package name.

0.6 (2018-10-07)
================
* Removed ``importlib_metadata.distribution`` function. Now
  the public interface is primarily the utility functions exposed
  in ``importlib_metadata.__all__``. Closes #14.
* Added two new utility functions ``read_text`` and
  ``metadata``.

0.5 (2018-09-18)
================
* Updated README and removed details about Distribution
  class, now considered private. Closes #15.
* Added test suite support for Python 3.4+.
* Fixed SyntaxErrors on Python 3.4 and 3.5. !12
* Fixed errors on Windows joining Path elements. !15

0.4 (2018-09-14)
================
* Housekeeping.

0.3 (2018-09-14)
================
* Added usage documentation. Closes #8
* Add support for getting metadata from wheels on ``sys.path``. Closes #9

0.2 (2018-09-11)
================
* Added ``importlib_metadata.entry_points()``. Closes #1
* Added ``importlib_metadata.resolve()``. Closes #12
* Add support for Python 2.7. Closes #4

0.1 (2018-09-10)
================
* Initial release.

..
   Local Variables:
   mode: change-log-mode
   indent-tabs-mode: nil
   sentence-end-double-space: t
   fill-column: 78
   coding: utf-8
   End:

View file

@ -0,0 +1,180 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# flake8: noqa
#
# importlib_metadata documentation build configuration file, created by
# sphinx-quickstart on Thu Nov 30 10:21:00 2017.
#
# This file is execfile()d with the current directory set to its
# containing dir.
#
# Note that not all possible configuration values are present in this
# autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ['sphinx.ext.autodoc',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.coverage',
'sphinx.ext.viewcode']
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
# The master toctree document.
master_doc = 'index'
# General information about the project.
project = 'importlib_metadata'
copyright = '2017-2018, Jason Coombs, Barry Warsaw'
author = 'Jason Coombs, Barry Warsaw'
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
version = '0.1'
# The full version, including alpha/beta/rc tags.
release = '0.1'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False
# -- Options for HTML output ----------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'default'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# This is required for the alabaster theme
# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
html_sidebars = {
'**': [
'relations.html', # needs 'show_related': True theme option to display
'searchbox.html',
]
}
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'importlib_metadatadoc'
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'importlib_metadata.tex', 'importlib\\_metadata Documentation',
'Brett Cannon, Barry Warsaw', 'manual'),
]
# -- Options for manual page output ---------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'importlib_metadata', 'importlib_metadata Documentation',
[author], 1)
]
# -- Options for Texinfo output -------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'importlib_metadata', 'importlib_metadata Documentation',
author, 'importlib_metadata', 'One line description of project.',
'Miscellaneous'),
]
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
'python': ('https://docs.python.org/3', None),
}

View file

@ -0,0 +1,53 @@
===============================
Welcome to importlib_metadata
===============================

``importlib_metadata`` is a library which provides an API for accessing an
installed package's `metadata`_, such as its entry points or its top-level
name. This functionality intends to replace most uses of ``pkg_resources``
`entry point API`_ and `metadata API`_. Along with ``importlib.resources`` in
`Python 3.7 and newer`_ (backported as `importlib_resources`_ for older
versions of Python), this can eliminate the need to use the older and less
efficient ``pkg_resources`` package.

``importlib_metadata`` is a backport of Python 3.8's standard library
`importlib.metadata`_ module for Python 2.7, and 3.4 through 3.7. Users of
Python 3.8 and beyond are encouraged to use the standard library module, and
in fact for these versions, ``importlib_metadata`` just shadows that module.
Developers looking for detailed API descriptions should refer to the Python
3.8 standard library documentation.

The documentation here includes a general :ref:`usage <using>` guide.

.. toctree::
   :maxdepth: 2
   :caption: Contents:

   using.rst
   changelog.rst

Project details
===============
* Project home: https://gitlab.com/python-devs/importlib_metadata
* Report bugs at: https://gitlab.com/python-devs/importlib_metadata/issues
* Code hosting: https://gitlab.com/python-devs/importlib_metadata.git
* Documentation: http://importlib_metadata.readthedocs.io/
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
.. _`metadata`: https://www.python.org/dev/peps/pep-0566/
.. _`entry point API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#entry-points
.. _`metadata API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#metadata-api
.. _`Python 3.7 and newer`: https://docs.python.org/3/library/importlib.html#module-importlib.resources
.. _`importlib_resources`: https://importlib-resources.readthedocs.io/en/latest/index.html
.. _`importlib.metadata`: TBD

View file

@ -0,0 +1,133 @@
.. _using:
==========================
Using importlib_metadata
==========================

``importlib_metadata`` is a library that provides for access to installed
package metadata. Built in part on Python's import system, this library
intends to replace similar functionality in ``pkg_resources`` `entry point
API`_ and `metadata API`_. Along with ``importlib.resources`` in `Python 3.7
and newer`_ (backported as `importlib_resources`_ for older versions of
Python), this can eliminate the need to use the older and less efficient
``pkg_resources`` package.

By "installed package" we generally mean a third party package installed into
Python's ``site-packages`` directory via tools such as ``pip``. Specifically,
it means a package with either a discoverable ``dist-info`` or ``egg-info``
directory, and metadata defined by `PEP 566`_ or its older specifications.

By default, package metadata can live on the file system or in wheels on
``sys.path``. Through an extension mechanism, the metadata can live almost
anywhere.

Overview
========

Let's say you wanted to get the version string for a package you've installed
using ``pip``. We start by creating a virtual environment and installing
something into it::

    $ python3 -m venv example
    $ source example/bin/activate
    (example) $ pip install importlib_metadata
    (example) $ pip install wheel

You can get the version string for ``wheel`` by running the following::

    (example) $ python
    >>> from importlib_metadata import version
    >>> version('wheel')
    '0.31.1'

You can also get the set of entry points for the ``wheel`` package. Since the
``entry_points.txt`` file is an ``.ini``-style file, the ``entry_points()``
function returns a `ConfigParser instance`_. To get the list of command line
entry points, extract the ``console_scripts`` section::

    >>> cp = entry_points('wheel')
    >>> cp.options('console_scripts')
    ['wheel']

You can also get the callable that the entry point is mapped to::

    >>> cp.get('console_scripts', 'wheel')
    'wheel.tool:main'

Even more conveniently, you can resolve this entry point to the actual
callable::

    >>> from importlib_metadata import resolve
    >>> ep = cp.get('console_scripts', 'wheel')
    >>> resolve(ep)
    <function main at 0x111b91bf8>

Distributions
=============

While the above API is the most common and convenient usage, you can get all
of that information from the ``Distribution`` class. A ``Distribution`` is an
abstract object that represents the metadata for a Python package. You can
get the ``Distribution`` instance::

    >>> from importlib_metadata import distribution
    >>> dist = distribution('wheel')

Thus, an alternative way to get the version number is through the
``Distribution`` instance::

    >>> dist.version
    '0.31.1'

There are all kinds of additional metadata available on the ``Distribution``
instance::

    >>> dist.metadata['Requires-Python']
    '>=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*'
    >>> dist.metadata['License']
    'MIT'

The full set of available metadata is not described here. See PEP 566 for
additional details.
Extending the search algorithm
==============================
Because package metadata is not available through ``sys.path`` searches, or
package loaders directly, the metadata for a package is found through import
system `finders`_. To find a distribution package's metadata,
``importlib_metadata`` queries the list of `meta path finders`_ on
`sys.meta_path`_.

By default ``importlib_metadata`` installs a finder for packages found on the
file system. This finder doesn't actually find any *packages*, but it can
find the package's metadata.
The abstract class :py:class:`importlib.abc.MetaPathFinder` defines the
interface expected of finders by Python's import system.
``importlib_metadata`` extends this protocol by looking for an optional
``find_distribution()`` ``@classmethod`` on the finders from
``sys.meta_path``. If the finder has this method, it takes a single argument
which is the name of the distribution package to find. The method returns
``None`` if it cannot find the distribution package, otherwise it returns an
instance of the ``Distribution`` abstract class.

What this means in practice is that to support finding distribution package
metadata in locations other than the file system, you should derive from
``Distribution`` and implement the abstract ``read_text()`` method. Then your
finder's ``find_distribution()`` method, which takes a single argument naming
the package whose metadata is being sought, should return an instance of that
``Distribution`` subclass.

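For illustration only, here is a minimal sketch of this protocol: a finder and
``Distribution`` subclass that serve metadata from an in-memory mapping
instead of the file system. The names ``InMemoryFinder``,
``InMemoryDistribution``, and ``demo-pkg`` are hypothetical and not part of
the library::

    import sys

    import importlib_metadata


    class InMemoryDistribution(importlib_metadata.Distribution):
        """Serve metadata files from a plain dict of {filename: text}."""

        def __init__(self, files):
            self._files = files

        def read_text(self, filename):
            # Per the Distribution contract, return None for missing files.
            return self._files.get(filename)


    class InMemoryFinder:
        # One fake installed package, keyed by distribution name.
        _dists = {
            'demo-pkg': InMemoryDistribution({'METADATA': 'Version: 1.0\n'}),
        }

        @staticmethod
        def find_spec(*args, **kwargs):
            # Degenerate finder: never interferes with normal imports.
            return None

        find_module = find_spec  # Python 2 compatibility, as in _hooks.py

        @classmethod
        def find_distribution(cls, name):
            return cls._dists.get(name)


    sys.meta_path.append(InMemoryFinder)
    print(importlib_metadata.version('demo-pkg'))  # prints: 1.0

Registration is nothing more than appending the finder to ``sys.meta_path``;
``Distribution.from_name()`` simply queries every meta path finder that
exposes a ``find_distribution()`` attribute.
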
.. _`entry point API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#entry-points
.. _`metadata API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#metadata-api
.. _`Python 3.7 and newer`: https://docs.python.org/3/library/importlib.html#module-importlib.resources
.. _`importlib_resources`: https://importlib-resources.readthedocs.io/en/latest/index.html
.. _`PEP 566`: https://www.python.org/dev/peps/pep-0566/
.. _`ConfigParser instance`: https://docs.python.org/3/library/configparser.html#configparser.ConfigParser
.. _`finders`: https://docs.python.org/3/reference/import.html#finders-and-loaders
.. _`meta path finders`: https://docs.python.org/3/glossary.html#term-meta-path-finder
.. _`sys.meta_path`: https://docs.python.org/3/library/sys.html#sys.meta_path

View file

@ -0,0 +1,44 @@
import re
import unittest

import importlib_metadata


class APITests(unittest.TestCase):
    version_pattern = r'\d+\.\d+(\.\d)?'

    def test_retrieves_version_of_self(self):
        version = importlib_metadata.version('importlib_metadata')
        assert isinstance(version, str)
        assert re.match(self.version_pattern, version)

    def test_retrieves_version_of_pip(self):
        # Assume pip is installed and retrieve the version of pip.
        version = importlib_metadata.version('pip')
        assert isinstance(version, str)
        assert re.match(self.version_pattern, version)

    def test_for_name_does_not_exist(self):
        with self.assertRaises(importlib_metadata.PackageNotFoundError):
            importlib_metadata.distribution('does-not-exist')

    def test_for_top_level(self):
        distribution = importlib_metadata.distribution('importlib_metadata')
        self.assertEqual(
            distribution.read_text('top_level.txt').strip(),
            'importlib_metadata')

    def test_entry_points(self):
        parser = importlib_metadata.entry_points('pip')
        # We should probably not be dependent on a third party package's
        # internal API staying stable.
        entry_point = parser.get('console_scripts', 'pip')
        self.assertEqual(entry_point, 'pip._internal:main')

    def test_metadata_for_this_package(self):
        md = importlib_metadata.metadata('importlib_metadata')
        assert md['author'] == 'Barry Warsaw'
        assert md['LICENSE'] == 'Apache Software License'
        assert md['Name'] == 'importlib-metadata'
        classifiers = md.get_all('Classifier')
        assert 'Topic :: Software Development :: Libraries' in classifiers

View file

@ -0,0 +1,121 @@
from __future__ import unicode_literals

import re
import sys
import shutil
import tempfile
import unittest
import importlib
import contextlib

import importlib_metadata

try:
    from contextlib import ExitStack
except ImportError:
    from contextlib2 import ExitStack

try:
    import pathlib
except ImportError:
    import pathlib2 as pathlib

from importlib_metadata import _hooks


class BasicTests(unittest.TestCase):
    version_pattern = r'\d+\.\d+(\.\d)?'

    def test_retrieves_version_of_pip(self):
        # Assume pip is installed and retrieve the version of pip.
        dist = importlib_metadata.Distribution.from_name('pip')
        assert isinstance(dist.version, str)
        assert re.match(self.version_pattern, dist.version)

    def test_for_name_does_not_exist(self):
        with self.assertRaises(importlib_metadata.PackageNotFoundError):
            importlib_metadata.Distribution.from_name('does-not-exist')

    def test_new_style_classes(self):
        self.assertIsInstance(importlib_metadata.Distribution, type)
        self.assertIsInstance(_hooks.MetadataPathFinder, type)
        self.assertIsInstance(_hooks.WheelMetadataFinder, type)
        self.assertIsInstance(_hooks.WheelDistribution, type)


class ImportTests(unittest.TestCase):
    def test_import_nonexistent_module(self):
        # Ensure that the MetadataPathFinder does not crash an import of a
        # non-existent module.
        with self.assertRaises(ImportError):
            importlib.import_module('does_not_exist')

    def test_resolve(self):
        entry_points = importlib_metadata.entry_points('pip')
        main = importlib_metadata.resolve(
            entry_points.get('console_scripts', 'pip'))
        import pip._internal
        self.assertEqual(main, pip._internal.main)

    def test_resolve_invalid(self):
        self.assertRaises(ValueError, importlib_metadata.resolve, 'bogus.ep')


class NameNormalizationTests(unittest.TestCase):
    @staticmethod
    def pkg_with_dashes(site_dir):
        """
        Create minimal metadata for a package with dashes
        in the name (and thus underscores in the filename).
        """
        metadata_dir = site_dir / 'my_pkg.dist-info'
        metadata_dir.mkdir()
        metadata = metadata_dir / 'METADATA'
        with metadata.open('w') as strm:
            strm.write('Version: 1.0\n')
        return 'my-pkg'

    @staticmethod
    @contextlib.contextmanager
    def site_dir():
        tmpdir = tempfile.mkdtemp()
        sys.path[:0] = [tmpdir]
        try:
            yield pathlib.Path(tmpdir)
        finally:
            sys.path.remove(tmpdir)
            shutil.rmtree(tmpdir)

    def setUp(self):
        self.fixtures = ExitStack()
        self.addCleanup(self.fixtures.close)
        self.site_dir = self.fixtures.enter_context(self.site_dir())

    def test_dashes_in_dist_name_found_as_underscores(self):
        """
        For a package with a dash in the name, the dist-info metadata
        uses underscores in the name. Ensure the metadata loads.
        """
        pkg_name = self.pkg_with_dashes(self.site_dir)
        assert importlib_metadata.version(pkg_name) == '1.0'

    @staticmethod
    def pkg_with_mixed_case(site_dir):
        """
        Create minimal metadata for a package with mixed case
        in the name.
        """
        metadata_dir = site_dir / 'CherryPy.dist-info'
        metadata_dir.mkdir()
        metadata = metadata_dir / 'METADATA'
        with metadata.open('w') as strm:
            strm.write('Version: 1.0\n')
        return 'CherryPy'

    def test_dist_name_found_as_any_case(self):
        """
        Ensure the metadata loads when queried with any case.
        """
        pkg_name = self.pkg_with_mixed_case(self.site_dir)
        assert importlib_metadata.version(pkg_name) == '1.0'
        assert importlib_metadata.version(pkg_name.lower()) == '1.0'
        assert importlib_metadata.version(pkg_name.upper()) == '1.0'

View file

@ -0,0 +1,42 @@
import sys
import unittest

import importlib_metadata

try:
    from contextlib import ExitStack
except ImportError:
    from contextlib2 import ExitStack

from importlib_resources import path


class BespokeLoader:
    archive = 'bespoke'


class TestZip(unittest.TestCase):
    def setUp(self):
        # Find the path to the example.*.whl so we can add it to the front of
        # sys.path, where we'll then try to find the metadata thereof.
        self.resources = ExitStack()
        self.addCleanup(self.resources.close)
        wheel = self.resources.enter_context(
            path('importlib_metadata.tests.data',
                 'example-21.12-py3-none-any.whl'))
        sys.path.insert(0, str(wheel))
        self.resources.callback(sys.path.pop, 0)

    def test_zip_version(self):
        self.assertEqual(importlib_metadata.version('example'), '21.12')

    def test_zip_entry_points(self):
        parser = importlib_metadata.entry_points('example')
        entry_point = parser.get('console_scripts', 'example')
        self.assertEqual(entry_point, 'example:main')

    def test_missing_metadata(self):
        distribution = importlib_metadata.distribution('example')
        self.assertIsNone(distribution.read_text('does not exist'))

    def test_case_insensitive(self):
        self.assertEqual(importlib_metadata.version('Example'), '21.12')

View file

@ -0,0 +1 @@
0.7

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)
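
For reading purposes only, the dense ``.pth`` one-liner above does roughly the
following when unrolled (``.pth`` code must stay on a single line to be
executed by ``site.py``; ``sitedir`` is supplied by ``site.py`` while the file
is processed):

import sys, types, os

has_mfs = sys.version_info > (3, 5)
# Path of the 'jaraco' directory inside the site dir being processed.
p = os.path.join(sys._getframe(1).f_locals['sitedir'], 'jaraco')

if has_mfs:
    # Python 3.5+: build the 'jaraco' namespace package from an import spec.
    import importlib.util
    import importlib.machinery
    spec = importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])
    m = sys.modules.setdefault('jaraco', importlib.util.module_from_spec(spec))
else:
    m = None

# Fallback for older Pythons: a bare module object acts as the package.
m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'))

# Ensure this site dir's 'jaraco' directory is on the package __path__.
mp = m.__dict__.setdefault('__path__', [])
if p not in mp:
    mp.append(p)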

View file

@ -1 +0,0 @@
import sys, types, os;p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));ie = os.path.exists(os.path.join(p,'__init__.py'));m = not ie and sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -1 +0,0 @@
import sys, types, os;p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));ie = os.path.exists(os.path.join(p,'__init__.py'));m = not ie and sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -1 +0,0 @@
import sys, types, os;p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));ie = os.path.exists(os.path.join(p,'__init__.py'));m = not ie and sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -1 +0,0 @@
import sys, types, os;p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));ie = os.path.exists(os.path.join(p,'__init__.py'));m = not ie and sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -1 +0,0 @@
__import__("pkg_resources").declare_namespace(__name__)

View file

@ -5,6 +5,7 @@ of an object and its parent classes.
 
 from __future__ import unicode_literals
 
+
 def all_bases(c):
     """
     return a tuple of all base classes the class c has as a parent.
@ -13,6 +14,7 @@ def all_bases(c):
     """
     return c.mro()[1:]
 
+
 def all_classes(c):
     """
     return a tuple of all classes to which c belongs
@ -21,7 +23,10 @@ def all_classes(c):
     """
     return c.mro()
 
-# borrowed from http://code.activestate.com/recipes/576949-find-all-subclasses-of-a-given-class/
+
+# borrowed from
+# http://code.activestate.com/recipes/576949-find-all-subclasses-of-a-given-class/
 def iter_subclasses(cls, _seen=None):
     """
     Generator over all subclasses of a given class, in depth-first order.
@ -51,9 +56,12 @@ def iter_subclasses(cls, _seen=None):
     """
     if not isinstance(cls, type):
-        raise TypeError('iter_subclasses must be called with '
-                        'new-style classes, not %.100r' % cls)
-    if _seen is None: _seen = set()
+        raise TypeError(
+            'iter_subclasses must be called with '
+            'new-style classes, not %.100r' % cls
+        )
+    if _seen is None:
+        _seen = set()
     try:
         subs = cls.__subclasses__()
     except TypeError:  # fails only when cls is type
View file

@ -6,6 +6,7 @@ Some useful metaclasses.
 
 from __future__ import unicode_literals
 
+
 class LeafClassesMeta(type):
     """
     A metaclass for classes that keeps track of all of them that

View file

@ -2,8 +2,10 @@ from __future__ import unicode_literals
 
 import six
 
+__metaclass__ = type
+
 
-class NonDataProperty(object):
+class NonDataProperty:
     """Much like the property builtin, but only implements __get__,
     making it a non-data property, and can be subsequently reset.
@ -34,7 +36,7 @@ class NonDataProperty(object):
 
 # from http://stackoverflow.com/a/5191224
-class ClassPropertyDescriptor(object):
+class ClassPropertyDescriptor:
     def __init__(self, fget, fset=None):
         self.fget = fget

View file

@ -7,12 +7,63 @@ import operator
 import collections
 import itertools
 import copy
+import functools
+
+try:
+    import collections.abc
+except ImportError:
+    # Python 2.7
+    collections.abc = collections
 
 import six
 
 from jaraco.classes.properties import NonDataProperty
 import jaraco.text
 
 
+class Projection(collections.abc.Mapping):
+    """
+    Project a set of keys over a mapping
+
+    >>> sample = {'a': 1, 'b': 2, 'c': 3}
+    >>> prj = Projection(['a', 'c', 'd'], sample)
+    >>> prj == {'a': 1, 'c': 3}
+    True
+
+    Keys should only appear if they were specified and exist in the space.
+
+    >>> sorted(list(prj.keys()))
+    ['a', 'c']
+
+    Use the projection to update another dict.
+
+    >>> target = {'a': 2, 'b': 2}
+    >>> target.update(prj)
+    >>> target == {'a': 1, 'b': 2, 'c': 3}
+    True
+
+    Also note that Projection keeps a reference to the original dict, so
+    if you modify the original dict, that could modify the Projection.
+
+    >>> del sample['a']
+    >>> dict(prj)
+    {'c': 3}
+    """
+    def __init__(self, keys, space):
+        self._keys = tuple(keys)
+        self._space = space
+
+    def __getitem__(self, key):
+        if key not in self._keys:
+            raise KeyError(key)
+        return self._space[key]
+
+    def __iter__(self):
+        return iter(set(self._keys).intersection(self._space))
+
+    def __len__(self):
+        return len(tuple(iter(self)))
+
+
 class DictFilter(object):
     """
     Takes a dict, and simulates a sub-dict based on the keys.
@ -52,7 +103,6 @@ class DictFilter(object):
         self.pattern_keys = set()
 
     def get_pattern_keys(self):
-        #key_matches = lambda k, v: self.include_pattern.match(k)
         keys = filter(self.include_pattern.match, self.dict.keys())
         return set(keys)
     pattern_keys = NonDataProperty(get_pattern_keys)
@ -70,7 +120,7 @@ class DictFilter(object):
         return values
 
     def __getitem__(self, i):
-        if not i in self.include_keys:
+        if i not in self.include_keys:
            return KeyError, i
        return self.dict[i]
@ -162,7 +212,7 @@ class RangeMap(dict):
     >>> r.get(7, 'not found')
     'not found'
     """
-    def __init__(self, source, sort_params = {}, key_match_comparator = operator.le):
+    def __init__(self, source, sort_params={}, key_match_comparator=operator.le):
         dict.__init__(self, source)
         self.sort_params = sort_params
         self.match = key_match_comparator
@ -190,7 +240,7 @@ class RangeMap(dict):
             return default
 
     def _find_first_match_(self, keys, item):
-        is_match = lambda k: self.match(item, k)
+        is_match = functools.partial(self.match, item)
         matches = list(filter(is_match, keys))
         if matches:
             return matches[0]
@ -205,12 +255,15 @@ class RangeMap(dict):
     # some special values for the RangeMap
     undefined_value = type(str('RangeValueUndefined'), (object,), {})()
-    class Item(int): pass
+
+    class Item(int):
+        "RangeMap Item"
     first_item = Item(0)
     last_item = Item(-1)
 
-__identity = lambda x: x
+
+def __identity(x):
+    return x
 
 
 def sorted_items(d, key=__identity, reverse=False):
@ -229,7 +282,8 @@ def sorted_items(d, key=__identity, reverse=False):
     (('foo', 20), ('baz', 10), ('bar', 42))
     """
     # wrap the key func so it operates on the first element of each item
-    pairkey_key = lambda item: key(item[0])
+    def pairkey_key(item):
+        return key(item[0])
     return sorted(d.items(), key=pairkey_key, reverse=reverse)
@ -414,7 +468,11 @@ class ItemsAsAttributes(object):
     It also works on dicts that customize __getitem__
 
     >>> missing_func = lambda self, key: 'missing item'
-    >>> C = type(str('C'), (dict, ItemsAsAttributes), dict(__missing__ = missing_func))
+    >>> C = type(
+    ...     str('C'),
+    ...     (dict, ItemsAsAttributes),
+    ...     dict(__missing__ = missing_func),
+    ... )
     >>> i = C()
     >>> i.missing
     'missing item'
@ -428,6 +486,7 @@ class ItemsAsAttributes(object):
         # attempt to get the value from the mapping (return self[key])
         # but be careful not to lose the original exception context.
         noval = object()
+
         def _safe_getitem(cont, key, missing_result):
             try:
                 return cont[key]
@ -460,7 +519,7 @@ def invert_map(map):
     ...
     ValueError: Key conflict in inverted mapping
     """
-    res = dict((v,k) for k, v in map.items())
+    res = dict((v, k) for k, v in map.items())
     if not len(res) == len(map):
         raise ValueError('Key conflict in inverted mapping')
     return res
@ -483,7 +542,7 @@ class IdentityOverrideMap(dict):
         return key
 
-class DictStack(list, collections.Mapping):
+class DictStack(list, collections.abc.Mapping):
     """
     A stack of dictionaries that behaves as a view on those dictionaries,
     giving preference to the last.
@ -506,6 +565,7 @@ class DictStack(list, collections.Mapping):
     >>> d = stack.pop()
     >>> stack['a']
     1
+    >>> stack.get('b', None)
     """
 
     def keys(self):
@ -513,7 +573,8 @@ class DictStack(list, collections.Mapping):
 
     def __getitem__(self, key):
         for scope in reversed(self):
-            if key in scope: return scope[key]
+            if key in scope:
+                return scope[key]
         raise KeyError(key)
 
     push = list.append
@ -553,6 +614,10 @@ class BijectiveMap(dict):
     Traceback (most recent call last):
     ValueError: Key/Value pairs may not overlap
 
+    >>> m['e'] = 'd'
+    Traceback (most recent call last):
+    ValueError: Key/Value pairs may not overlap
+
     >>> print(m.pop('d'))
     c
@ -583,7 +648,12 @@ class BijectiveMap(dict):
     def __setitem__(self, item, value):
         if item == value:
             raise ValueError("Key cannot map to itself")
-        if (value in self or item in self) and self[item] != value:
+        overlap = (
+            item in self and self[item] != value
+            or
+            value in self and self[value] != item
+        )
+        if overlap:
             raise ValueError("Key/Value pairs may not overlap")
         super(BijectiveMap, self).__setitem__(item, value)
         super(BijectiveMap, self).__setitem__(value, item)
@ -607,7 +677,7 @@ class BijectiveMap(dict):
             self.__setitem__(*item)
 
-class FrozenDict(collections.Mapping, collections.Hashable):
+class FrozenDict(collections.abc.Mapping, collections.abc.Hashable):
     """
     An immutable mapping.
@ -641,8 +711,8 @@ class FrozenDict(collections.Mapping, collections.Hashable):
     >>> isinstance(copy.copy(a), FrozenDict)
     True
 
-    FrozenDict supplies .copy(), even though collections.Mapping doesn't
-    demand it.
+    FrozenDict supplies .copy(), even though
+    collections.abc.Mapping doesn't demand it.
 
     >>> a.copy() == a
     True
@ -747,6 +817,9 @@ class Everything(object):
     >>> import random
     >>> random.randint(1, 999) in Everything()
     True
+
+    >>> random.choice([None, 'foo', 42, ('a', 'b', 'c')]) in Everything()
+    True
     """
     def __contains__(self, other):
         return True
@ -771,3 +844,63 @@ class InstrumentedDict(six.moves.UserDict):
     def __init__(self, data):
         six.moves.UserDict.__init__(self)
         self.data = data
+
+
+class Least(object):
+    """
+    A value that is always lesser than any other
+
+    >>> least = Least()
+    >>> 3 < least
+    False
+    >>> 3 > least
+    True
+    >>> least < 3
+    True
+    >>> least <= 3
+    True
+    >>> least > 3
+    False
+    >>> 'x' > least
+    True
+    >>> None > least
+    True
+    """
+    def __le__(self, other):
+        return True
+    __lt__ = __le__
+
+    def __ge__(self, other):
+        return False
+    __gt__ = __ge__
+
+
+class Greatest(object):
+    """
+    A value that is always greater than any other
+
+    >>> greatest = Greatest()
+    >>> 3 < greatest
+    True
+    >>> 3 > greatest
+    False
+    >>> greatest < 3
+    False
+    >>> greatest > 3
+    True
+    >>> greatest >= 3
+    True
+    >>> 'x' > greatest
+    False
+    >>> None > greatest
+    False
+    """
+    def __ge__(self, other):
+        return True
+    __gt__ = __ge__
+
+    def __le__(self, other):
+        return False
+    __lt__ = __le__

View file

@ -1,8 +1,16 @@
-from __future__ import absolute_import, unicode_literals, print_function, division
+from __future__ import (
+    absolute_import, unicode_literals, print_function, division,
+)
 
 import functools
 import time
 import warnings
+import inspect
+import collections
+from itertools import count
+
+__metaclass__ = type
 
 try:
     from functools import lru_cache
@ -16,13 +24,17 @@ except ImportError:
         warnings.warn("No lru_cache available")
 
+import more_itertools.recipes
+
 
 def compose(*funcs):
     """
     Compose any number of unary functions into a single unary function.
 
     >>> import textwrap
     >>> from six import text_type
-    >>> text_type.strip(textwrap.dedent(compose.__doc__)) == compose(text_type.strip, textwrap.dedent)(compose.__doc__)
+    >>> stripped = text_type.strip(textwrap.dedent(compose.__doc__))
+    >>> compose(text_type.strip, textwrap.dedent)(compose.__doc__) == stripped
     True
 
     Compose also allows the innermost function to take arbitrary arguments.
@ -33,7 +45,8 @@ def compose(*funcs):
     [1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
     """
-    compose_two = lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs))
+    def compose_two(f1, f2):
+        return lambda *args, **kwargs: f1(f2(*args, **kwargs))
     return functools.reduce(compose_two, funcs)
@ -60,19 +73,36 @@ def once(func):
     This decorator can ensure that an expensive or non-idempotent function
     will not be expensive on subsequent calls and is idempotent.
 
-    >>> func = once(lambda a: a+3)
-    >>> func(3)
+    >>> add_three = once(lambda a: a+3)
+    >>> add_three(3)
     6
-    >>> func(9)
+    >>> add_three(9)
     6
-    >>> func('12')
+    >>> add_three('12')
     6
+
+    To reset the stored value, simply clear the property ``saved_result``.
+
+    >>> del add_three.saved_result
+    >>> add_three(9)
+    12
+    >>> add_three(8)
+    12
+
+    Or invoke 'reset()' on it.
+
+    >>> add_three.reset()
+    >>> add_three(-3)
+    0
+    >>> add_three(0)
+    0
     """
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
-        if not hasattr(func, 'always_returns'):
-            func.always_returns = func(*args, **kwargs)
-        return func.always_returns
+        if not hasattr(wrapper, 'saved_result'):
+            wrapper.saved_result = func(*args, **kwargs)
+        return wrapper.saved_result
+    wrapper.reset = lambda: vars(wrapper).__delitem__('saved_result')
     return wrapper
@ -131,17 +161,22 @@ def method_cache(method, cache_wrapper=None):
     >>> a.method2()
     3
 
+    Caution - do not subsequently wrap the method with another decorator, such
+    as ``@property``, which changes the semantics of the function.
+
     See also
     http://code.activestate.com/recipes/577452-a-memoize-decorator-for-instance-methods/
     for another implementation and additional justification.
     """
     cache_wrapper = cache_wrapper or lru_cache()
+
     def wrapper(self, *args, **kwargs):
         # it's the first call, replace the method with a cached, bound method
         bound_method = functools.partial(method, self)
         cached_method = cache_wrapper(bound_method)
         setattr(self, method.__name__, cached_method)
         return cached_method(*args, **kwargs)
+
     return _special_method_cache(method, cache_wrapper) or wrapper
@ -191,6 +226,29 @@ def apply(transform):
     return wrap
 
 
+def result_invoke(action):
+    r"""
+    Decorate a function with an action function that is
+    invoked on the results returned from the decorated
+    function (for its side-effect), then return the original
+    result.
+
+    >>> @result_invoke(print)
+    ... def add_two(a, b):
+    ...     return a + b
+
+    >>> x = add_two(2, 3)
+    5
+    """
+    def wrap(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            result = func(*args, **kwargs)
+            action(result)
+            return result
+        return wrapper
+    return wrap
+
+
 def call_aside(f, *args, **kwargs):
     """
     Call a function for its side effect after initialization.
@ -211,7 +269,7 @@ def call_aside(f, *args, **kwargs):
     return f
 
 
-class Throttler(object):
+class Throttler:
     """
     Rate-limit a function (or other callable)
     """
@ -259,10 +317,143 @@ def retry_call(func, cleanup=lambda: None, retries=0, trap=()):
     exception. On the final attempt, allow any exceptions
     to propagate.
     """
-    for attempt in range(retries):
+    attempts = count() if retries == float('inf') else range(retries)
+    for attempt in attempts:
         try:
             return func()
         except trap:
             cleanup()
 
     return func()
+
+
+def retry(*r_args, **r_kwargs):
+    """
+    Decorator wrapper for retry_call. Accepts arguments to retry_call
+    except func and then returns a decorator for the decorated function.
+
+    Ex:
+
+    >>> @retry(retries=3)
+    ... def my_func(a, b):
+    ...     "this is my funk"
+    ...     print(a, b)
+    >>> my_func.__doc__
+    'this is my funk'
+    """
+    def decorate(func):
+        @functools.wraps(func)
+        def wrapper(*f_args, **f_kwargs):
+            bound = functools.partial(func, *f_args, **f_kwargs)
+            return retry_call(bound, *r_args, **r_kwargs)
+        return wrapper
+    return decorate
+
+
+def print_yielded(func):
+    """
+    Convert a generator into a function that prints all yielded elements
+
+    >>> @print_yielded
+    ... def x():
+    ...     yield 3; yield None
+    >>> x()
+    3
+    None
+    """
+    print_all = functools.partial(map, print)
+    print_results = compose(more_itertools.recipes.consume, print_all, func)
+    return functools.wraps(func)(print_results)
+
+
+def pass_none(func):
+    """
+    Wrap func so it's not called if its first param is None
+
+    >>> print_text = pass_none(print)
+    >>> print_text('text')
+    text
+    >>> print_text(None)
+    """
+    @functools.wraps(func)
+    def wrapper(param, *args, **kwargs):
+        if param is not None:
+            return func(param, *args, **kwargs)
+    return wrapper
+
+
+def assign_params(func, namespace):
+    """
+    Assign parameters from namespace where func solicits.
+
+    >>> def func(x, y=3):
+    ...     print(x, y)
+    >>> assigned = assign_params(func, dict(x=2, z=4))
+    >>> assigned()
+    2 3
+
+    The usual errors are raised if a function doesn't receive
+    its required parameters:
+
+    >>> assigned = assign_params(func, dict(y=3, z=4))
+    >>> assigned()
+    Traceback (most recent call last):
+    TypeError: func() ...argument...
+    """
+    try:
+        sig = inspect.signature(func)
+        params = sig.parameters.keys()
+    except AttributeError:
+        spec = inspect.getargspec(func)
+        params = spec.args
+    call_ns = {
+        k: namespace[k]
+        for k in params
+        if k in namespace
+    }
+    return functools.partial(func, **call_ns)
+
+
+def save_method_args(method):
+    """
+    Wrap a method such that when it is called, the args and kwargs are
+    saved on the method.
+
+    >>> class MyClass:
+    ...     @save_method_args
+    ...     def method(self, a, b):
+    ...         print(a, b)
+    >>> my_ob = MyClass()
+    >>> my_ob.method(1, 2)
+    1 2
+    >>> my_ob._saved_method.args
+    (1, 2)
+    >>> my_ob._saved_method.kwargs
+    {}
+    >>> my_ob.method(a=3, b='foo')
+    3 foo
+    >>> my_ob._saved_method.args
+    ()
+    >>> my_ob._saved_method.kwargs == dict(a=3, b='foo')
+    True
+
+    The arguments are stored on the instance, allowing for
+    different instance to save different args.
+
+    >>> your_ob = MyClass()
+    >>> your_ob.method({str('x'): 3}, b=[4])
+    {'x': 3} [4]
+    >>> your_ob._saved_method.args
+    ({'x': 3},)
+    >>> my_ob._saved_method.args
+    ()
+    """
+    args_and_kwargs = collections.namedtuple('args_and_kwargs', 'args kwargs')
+
+    @functools.wraps(method)
+    def wrapper(self, *args, **kwargs):
+        attr_name = '_saved_' + method.__name__
+        attr = args_and_kwargs(args, kwargs)
+        setattr(self, attr_name, attr)
+        return method(self, *args, **kwargs)
+
+    return wrapper

View file

@ -1,5 +1,6 @@
from __future__ import absolute_import, unicode_literals

+import numbers
from functools import reduce

@ -25,6 +26,7 @@ def get_bit_values(number, size=32):
    number += 2**size
    return list(map(int, bin(number)[-size:]))


def gen_bit_values(number):
    """
    Return a zero or one for each bit of a numeric value up to the most

@ -36,6 +38,7 @@ def gen_bit_values(number):
    digits = bin(number)[2:]
    return map(int, reversed(digits))


def coalesce(bits):
    """
    Take a sequence of bits, most significant first, and

@ -47,6 +50,7 @@ def coalesce(bits):
    operation = lambda a, b: (a << 1 | b)
    return reduce(operation, bits)


class Flags(object):
    """
    Subclasses should define _names, a list of flag names beginning

@ -96,6 +100,7 @@ class Flags(object):
        index = self._names.index(key)
        return self._values[index]


class BitMask(type):
    """
    A metaclass to create a bitmask with attributes. Subclass an int and

@ -119,12 +124,28 @@ class BitMask(type):
    >>> b2 = MyBits(8)
    >>> any([b2.a, b2.b, b2.c])
    False

    If the instance defines methods, they won't be wrapped in
    properties.

    >>> ns['get_value'] = classmethod(lambda cls: 'some value')
    >>> ns['prop'] = property(lambda self: 'a property')
    >>> MyBits = BitMask(str('MyBits'), (int,), ns)
    >>> MyBits(3).get_value()
    'some value'
    >>> MyBits(3).prop
    'a property'
    """

    def __new__(cls, name, bases, attrs):
+        def make_property(name, value):
+            if name.startswith('_') or not isinstance(value, numbers.Number):
+                return value
+            return property(lambda self, value=value: bool(self & value))
+
        newattrs = dict(
-            (attr, property(lambda self, value=value: bool(self & value)))
-            for attr, value in attrs.items()
-            if not attr.startswith('_')
+            (name, make_property(name, value))
+            for name, value in attrs.items()
        )
        return type.__new__(cls, name, bases, newattrs)
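
An illustrative sketch of using the BitMask metaclass above (a guess at typical use, not taken from the diff). The commented import is a placeholder; the real path depends on where this module sits in the vendored tree.

# from <this module> import BitMask  # adjust to the vendored location

ns = dict(readable=1, writable=2, executable=4)
Perms = BitMask(str('Perms'), (int,), ns)

p = Perms(5)  # bits 0 and 2 set
assert (p.readable, p.writable, p.executable) == (True, False, True)
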

View file

@ -39,6 +39,7 @@ class FoldedCase(six.text_type):
    """
    A case insensitive string class; behaves just like str
    except compares equal when the only variation is case.

    >>> s = FoldedCase('hello world')
    >>> s == 'Hello World'
@ -47,6 +48,9 @@ class FoldedCase(six.text_type):
    >>> 'Hello World' == s
    True

    >>> s != 'Hello World'
    False

    >>> s.index('O')
    4
@ -55,6 +59,38 @@ class FoldedCase(six.text_type):
    >>> sorted(map(FoldedCase, ['GAMMA', 'alpha', 'Beta']))
    ['alpha', 'Beta', 'GAMMA']

    Sequence membership is straightforward.

    >>> "Hello World" in [s]
    True
    >>> s in ["Hello World"]
    True

    You may test for set inclusion, but candidate and elements
    must both be folded.

    >>> FoldedCase("Hello World") in {s}
    True
    >>> s in {FoldedCase("Hello World")}
    True

    String inclusion works as long as the FoldedCase object
    is on the right.

    >>> "hello" in FoldedCase("Hello World")
    True

    But not if the FoldedCase object is on the left:

    >>> FoldedCase('hello') in 'Hello World'
    False

    In that case, use in_:

    >>> FoldedCase('hello').in_('Hello World')
    True
    """

    def __lt__(self, other):
        return self.lower() < other.lower()
@ -65,14 +101,23 @@ class FoldedCase(six.text_type):
    def __eq__(self, other):
        return self.lower() == other.lower()

    def __ne__(self, other):
        return self.lower() != other.lower()

    def __hash__(self):
        return hash(self.lower())

    def __contains__(self, other):
        return super(FoldedCase, self).lower().__contains__(other.lower())

    def in_(self, other):
        "Does self appear in other?"
        return self in FoldedCase(other)

    # cache lower since it's likely to be called frequently.
+    @jaraco.functools.method_cache
    def lower(self):
-        self._lower = super(FoldedCase, self).lower()
-        self.lower = lambda: self._lower
-        return self._lower
+        return super(FoldedCase, self).lower()

    def index(self, sub):
        return self.lower().index(sub.lower())
@ -147,6 +192,7 @@ def is_decodable(value):
        return False
    return True


def is_binary(value):
    """
    Return True if the value appears to be binary (that is, it's a byte
@ -154,6 +200,7 @@ def is_binary(value):
    """
    return isinstance(value, bytes) and not is_decodable(value)


def trim(s):
    r"""
    Trim something like a docstring to remove the whitespace that
@ -164,8 +211,10 @@ def trim(s):
    """
    return textwrap.dedent(s).strip()


class Splitter(object):
    """object that will split a string with the given arguments for each call

    >>> s = Splitter(',')
    >>> s('hello, world, this is your, master calling')
    ['hello', ' world', ' this is your', ' master calling']
@ -176,9 +225,11 @@ class Splitter(object):
    def __call__(self, s):
        return s.split(*self.args)


def indent(string, prefix=' ' * 4):
    return prefix + string


class WordSet(tuple):
    """
    Given a Python identifier, return the words that identifier represents,
@ -269,6 +320,7 @@ class WordSet(tuple):
    def from_class_name(cls, subject):
        return cls.parse(subject.__class__.__name__)


# for backward compatibility
words = WordSet.parse
@ -318,6 +370,7 @@ class SeparatedValues(six.text_type):
        parts = self.split(self.separator)
        return six.moves.filter(None, (part.strip() for part in parts))


class Stripper:
    r"""
    Given a series of lines, find the common prefix and strip it from them.
@ -369,3 +422,31 @@ class Stripper:
        while s1[:index] != s2[:index]:
            index -= 1
        return s1[:index]
def remove_prefix(text, prefix):
"""
Remove the prefix from the text if it exists.
>>> remove_prefix('underwhelming performance', 'underwhelming ')
'performance'
>>> remove_prefix('something special', 'sample')
'something special'
"""
null, prefix, rest = text.rpartition(prefix)
return rest
def remove_suffix(text, suffix):
"""
Remove the suffix from the text if it exists.
>>> remove_suffix('name.git', '.git')
'name'
>>> remove_suffix('something special', 'sample')
'something special'
"""
rest, suffix, null = text.partition(suffix)
return rest
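
A small sketch exercising the additions above, assuming the jaraco.text vendored by this commit is importable; the sample strings are illustrative only.

from jaraco.text import FoldedCase, remove_prefix, remove_suffix

headers = {FoldedCase('Content-Type'): 'text/plain'}
assert headers[FoldedCase('content-type')] == 'text/plain'
assert FoldedCase('hello').in_('Hello World')

assert remove_prefix('libs/bin/enver.exe', 'libs/') == 'bin/enver.exe'
assert remove_suffix('enver.exe', '.exe') == 'enver'
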

View file

@ -60,3 +60,18 @@ class Command(object):
cls.add_subparsers(parser) cls.add_subparsers(parser)
args = parser.parse_args() args = parser.parse_args()
args.action.run(args) args.action.run(args)
class Extend(argparse.Action):
"""
Argparse action to take an nargs=* argument
and add any values to the existing value.
>>> parser = argparse.ArgumentParser()
>>> _ = parser.add_argument('--foo', nargs='*', default=[], action=Extend)
>>> args = parser.parse_args(['--foo', 'a=1', '--foo', 'b=2', 'c=3'])
>>> args.foo
['a=1', 'b=2', 'c=3']
"""
def __call__(self, parser, namespace, values, option_string=None):
getattr(namespace, self.dest).extend(values)
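
A usage sketch for the Extend action, assuming the module is importable as jaraco.ui.cmdline (the path this diff suggests); the --include flag is made up for illustration.

import argparse
from jaraco.ui.cmdline import Extend

parser = argparse.ArgumentParser()
parser.add_argument('--include', nargs='*', default=[], action=Extend)
args = parser.parse_args(['--include', 'a', '--include', 'b', 'c'])
assert args.include == ['a', 'b', 'c']
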

View file

@ -1,3 +1,5 @@
# deprecated -- use TQDM
from __future__ import (print_function, absolute_import, unicode_literals, from __future__ import (print_function, absolute_import, unicode_literals,
division) division)

View file

@ -45,3 +45,9 @@ GetClipboardData.restype = ctypes.wintypes.HANDLE
SetClipboardData = ctypes.windll.user32.SetClipboardData SetClipboardData = ctypes.windll.user32.SetClipboardData
SetClipboardData.argtypes = ctypes.wintypes.UINT, ctypes.wintypes.HANDLE SetClipboardData.argtypes = ctypes.wintypes.UINT, ctypes.wintypes.HANDLE
SetClipboardData.restype = ctypes.wintypes.HANDLE SetClipboardData.restype = ctypes.wintypes.HANDLE
OpenClipboard = ctypes.windll.user32.OpenClipboard
OpenClipboard.argtypes = ctypes.wintypes.HANDLE,
OpenClipboard.restype = ctypes.wintypes.BOOL
ctypes.windll.user32.CloseClipboard.restype = ctypes.wintypes.BOOL
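
These prototypes support the usual open/close discipline. A minimal Windows-only sketch of that calling pattern, using ctypes directly rather than this module's wrappers:

import ctypes

user32 = ctypes.windll.user32
if not user32.OpenClipboard(None):  # NULL owner window
    raise ctypes.WinError()
try:
    pass  # read or write clipboard data here
finally:
    if not user32.CloseClipboard():
        raise ctypes.WinError()
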

View file

@ -10,9 +10,11 @@ try:
except ImportError: except ImportError:
LPBYTE = ctypes.POINTER(ctypes.wintypes.BYTE) LPBYTE = ctypes.POINTER(ctypes.wintypes.BYTE)
class CredentialAttribute(ctypes.Structure): class CredentialAttribute(ctypes.Structure):
_fields_ = [] _fields_ = []
class Credential(ctypes.Structure): class Credential(ctypes.Structure):
_fields_ = [ _fields_ = [
('flags', DWORD), ('flags', DWORD),
@ -32,6 +34,7 @@ class Credential(ctypes.Structure):
def __del__(self): def __del__(self):
ctypes.windll.advapi32.CredFree(ctypes.byref(self)) ctypes.windll.advapi32.CredFree(ctypes.byref(self))
PCREDENTIAL = ctypes.POINTER(Credential) PCREDENTIAL = ctypes.POINTER(Credential)
CredRead = ctypes.windll.advapi32.CredReadW CredRead = ctypes.windll.advapi32.CredReadW

View file

@ -2,7 +2,7 @@ import ctypes.wintypes
SetEnvironmentVariable = ctypes.windll.kernel32.SetEnvironmentVariableW SetEnvironmentVariable = ctypes.windll.kernel32.SetEnvironmentVariableW
SetEnvironmentVariable.restype = ctypes.wintypes.BOOL SetEnvironmentVariable.restype = ctypes.wintypes.BOOL
-SetEnvironmentVariable.argtypes = [ctypes.wintypes.LPCWSTR]*2
+SetEnvironmentVariable.argtypes = [ctypes.wintypes.LPCWSTR] * 2
GetEnvironmentVariable = ctypes.windll.kernel32.GetEnvironmentVariableW GetEnvironmentVariable = ctypes.windll.kernel32.GetEnvironmentVariableW
GetEnvironmentVariable.restype = ctypes.wintypes.BOOL GetEnvironmentVariable.restype = ctypes.wintypes.BOOL

View file

@ -1,12 +1,7 @@
-from ctypes import (
-    Structure, windll, POINTER, byref, cast, create_unicode_buffer,
-    c_size_t, c_int, create_string_buffer, c_uint64, c_ushort, c_short,
-    c_uint,
-)
+from ctypes import windll, POINTER
from ctypes.wintypes import (
-    BOOLEAN, LPWSTR, DWORD, LPVOID, HANDLE, FILETIME,
-    WCHAR, BOOL, HWND, WORD, UINT,
+    LPWSTR, DWORD, LPVOID, HANDLE, BOOL,
)

CreateEvent = windll.kernel32.CreateEventW
CreateEvent.argtypes = (
@ -29,13 +24,15 @@ _WaitForMultipleObjects = windll.kernel32.WaitForMultipleObjects
_WaitForMultipleObjects.argtypes = DWORD, POINTER(HANDLE), BOOL, DWORD
_WaitForMultipleObjects.restype = DWORD


def WaitForMultipleObjects(handles, wait_all=False, timeout=0):
    n_handles = len(handles)
-    handle_array = (HANDLE*n_handles)()
+    handle_array = (HANDLE * n_handles)()
    for index, handle in enumerate(handles):
        handle_array[index] = handle
    return _WaitForMultipleObjects(n_handles, handle_array, wait_all, timeout)


WAIT_OBJECT_0 = 0
INFINITE = -1
WAIT_TIMEOUT = 0x102
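
A Windows-only sketch of the wait pattern these declarations support, using a single event via ctypes; WaitForSingleObject keeps the example short, and the argtypes/restype lines mirror what this module does for its own prototypes.

import ctypes
from ctypes import wintypes

kernel32 = ctypes.windll.kernel32
kernel32.CreateEventW.restype = wintypes.HANDLE
kernel32.SetEvent.argtypes = (wintypes.HANDLE,)
kernel32.WaitForSingleObject.argtypes = (wintypes.HANDLE, wintypes.DWORD)
kernel32.CloseHandle.argtypes = (wintypes.HANDLE,)

WAIT_OBJECT_0 = 0
# CreateEventW(security, manual_reset, initial_state, name)
event = kernel32.CreateEventW(None, True, False, None)
kernel32.SetEvent(event)
assert kernel32.WaitForSingleObject(event, 0) == WAIT_OBJECT_0
kernel32.CloseHandle(event)
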

View file

@ -5,7 +5,7 @@ CreateSymbolicLink.argtypes = (
ctypes.wintypes.LPWSTR, ctypes.wintypes.LPWSTR,
ctypes.wintypes.LPWSTR, ctypes.wintypes.LPWSTR,
ctypes.wintypes.DWORD, ctypes.wintypes.DWORD,
) )
CreateSymbolicLink.restype = ctypes.wintypes.BOOLEAN CreateSymbolicLink.restype = ctypes.wintypes.BOOLEAN
CreateHardLink = ctypes.windll.kernel32.CreateHardLinkW CreateHardLink = ctypes.windll.kernel32.CreateHardLinkW
@ -13,7 +13,7 @@ CreateHardLink.argtypes = (
ctypes.wintypes.LPWSTR, ctypes.wintypes.LPWSTR,
ctypes.wintypes.LPWSTR, ctypes.wintypes.LPWSTR,
ctypes.wintypes.LPVOID, # reserved for LPSECURITY_ATTRIBUTES ctypes.wintypes.LPVOID, # reserved for LPSECURITY_ATTRIBUTES
) )
CreateHardLink.restype = ctypes.wintypes.BOOLEAN CreateHardLink.restype = ctypes.wintypes.BOOLEAN
GetFileAttributes = ctypes.windll.kernel32.GetFileAttributesW GetFileAttributes = ctypes.windll.kernel32.GetFileAttributesW
@ -28,16 +28,20 @@ MAX_PATH = 260
GetFinalPathNameByHandle = ctypes.windll.kernel32.GetFinalPathNameByHandleW GetFinalPathNameByHandle = ctypes.windll.kernel32.GetFinalPathNameByHandleW
GetFinalPathNameByHandle.argtypes = ( GetFinalPathNameByHandle.argtypes = (
-    ctypes.wintypes.HANDLE, ctypes.wintypes.LPWSTR, ctypes.wintypes.DWORD, ctypes.wintypes.DWORD,
-)
+    ctypes.wintypes.HANDLE, ctypes.wintypes.LPWSTR, ctypes.wintypes.DWORD,
+    ctypes.wintypes.DWORD,
+)
GetFinalPathNameByHandle.restype = ctypes.wintypes.DWORD GetFinalPathNameByHandle.restype = ctypes.wintypes.DWORD
class SECURITY_ATTRIBUTES(ctypes.Structure): class SECURITY_ATTRIBUTES(ctypes.Structure):
_fields_ = ( _fields_ = (
('length', ctypes.wintypes.DWORD), ('length', ctypes.wintypes.DWORD),
('p_security_descriptor', ctypes.wintypes.LPVOID), ('p_security_descriptor', ctypes.wintypes.LPVOID),
('inherit_handle', ctypes.wintypes.BOOLEAN), ('inherit_handle', ctypes.wintypes.BOOLEAN),
) )
LPSECURITY_ATTRIBUTES = ctypes.POINTER(SECURITY_ATTRIBUTES) LPSECURITY_ATTRIBUTES = ctypes.POINTER(SECURITY_ATTRIBUTES)
CreateFile = ctypes.windll.kernel32.CreateFileW CreateFile = ctypes.windll.kernel32.CreateFileW
@ -49,7 +53,7 @@ CreateFile.argtypes = (
ctypes.wintypes.DWORD, ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD, ctypes.wintypes.DWORD,
ctypes.wintypes.HANDLE, ctypes.wintypes.HANDLE,
) )
CreateFile.restype = ctypes.wintypes.HANDLE CreateFile.restype = ctypes.wintypes.HANDLE
FILE_SHARE_READ = 1 FILE_SHARE_READ = 1
FILE_SHARE_WRITE = 2 FILE_SHARE_WRITE = 2
@ -77,23 +81,61 @@ CloseHandle = ctypes.windll.kernel32.CloseHandle
CloseHandle.argtypes = (ctypes.wintypes.HANDLE,) CloseHandle.argtypes = (ctypes.wintypes.HANDLE,)
CloseHandle.restype = ctypes.wintypes.BOOLEAN CloseHandle.restype = ctypes.wintypes.BOOLEAN
-class WIN32_FIND_DATA(ctypes.Structure):
-    _fields_ = [
-        ('file_attributes', ctypes.wintypes.DWORD),
-        ('creation_time', ctypes.wintypes.FILETIME),
-        ('last_access_time', ctypes.wintypes.FILETIME),
-        ('last_write_time', ctypes.wintypes.FILETIME),
-        ('file_size_words', ctypes.wintypes.DWORD*2),
-        ('reserved', ctypes.wintypes.DWORD*2),
-        ('filename', ctypes.wintypes.WCHAR*MAX_PATH),
-        ('alternate_filename', ctypes.wintypes.WCHAR*14),
-    ]
-    @property
-    def file_size(self):
-        return ctypes.cast(self.file_size_words, ctypes.POINTER(ctypes.c_uint64)).contents
-LPWIN32_FIND_DATA = ctypes.POINTER(WIN32_FIND_DATA)
+class WIN32_FIND_DATA(ctypes.wintypes.WIN32_FIND_DATAW):
+    """
+    _fields_ = [
+        ("dwFileAttributes", DWORD),
+        ("ftCreationTime", FILETIME),
+        ("ftLastAccessTime", FILETIME),
+        ("ftLastWriteTime", FILETIME),
+        ("nFileSizeHigh", DWORD),
+        ("nFileSizeLow", DWORD),
+        ("dwReserved0", DWORD),
+        ("dwReserved1", DWORD),
+        ("cFileName", WCHAR * MAX_PATH),
+        ("cAlternateFileName", WCHAR * 14)]
+    ]
+    """
+    @property
+    def file_attributes(self):
+        return self.dwFileAttributes
+    @property
+    def creation_time(self):
+        return self.ftCreationTime
+    @property
+    def last_access_time(self):
+        return self.ftLastAccessTime
+    @property
+    def last_write_time(self):
+        return self.ftLastWriteTime
+    @property
+    def file_size_words(self):
+        return [self.nFileSizeHigh, self.nFileSizeLow]
+    @property
+    def reserved(self):
+        return [self.dwReserved0, self.dwReserved1]
+    @property
+    def filename(self):
+        return self.cFileName
+    @property
+    def alternate_filename(self):
+        return self.cAlternateFileName
+    @property
+    def file_size(self):
+        return self.nFileSizeHigh << 32 + self.nFileSizeLow
+LPWIN32_FIND_DATA = ctypes.POINTER(ctypes.wintypes.WIN32_FIND_DATAW)
FindFirstFile = ctypes.windll.kernel32.FindFirstFileW FindFirstFile = ctypes.windll.kernel32.FindFirstFileW
FindFirstFile.argtypes = (ctypes.wintypes.LPWSTR, LPWIN32_FIND_DATA) FindFirstFile.argtypes = (ctypes.wintypes.LPWSTR, LPWIN32_FIND_DATA)
@ -102,6 +144,8 @@ FindNextFile = ctypes.windll.kernel32.FindNextFileW
FindNextFile.argtypes = (ctypes.wintypes.HANDLE, LPWIN32_FIND_DATA) FindNextFile.argtypes = (ctypes.wintypes.HANDLE, LPWIN32_FIND_DATA)
FindNextFile.restype = ctypes.wintypes.BOOLEAN FindNextFile.restype = ctypes.wintypes.BOOLEAN
ctypes.windll.kernel32.FindClose.argtypes = ctypes.wintypes.HANDLE,
SCS_32BIT_BINARY = 0 # A 32-bit Windows-based application SCS_32BIT_BINARY = 0 # A 32-bit Windows-based application
SCS_64BIT_BINARY = 6 # A 64-bit Windows-based application SCS_64BIT_BINARY = 6 # A 64-bit Windows-based application
SCS_DOS_BINARY = 1 # An MS-DOS-based application SCS_DOS_BINARY = 1 # An MS-DOS-based application
@ -111,10 +155,45 @@ SCS_POSIX_BINARY = 4 # A POSIX-based application
SCS_WOW_BINARY = 2 # A 16-bit Windows-based application SCS_WOW_BINARY = 2 # A 16-bit Windows-based application
_GetBinaryType = ctypes.windll.kernel32.GetBinaryTypeW
-_GetBinaryType.argtypes = (ctypes.wintypes.LPWSTR, ctypes.POINTER(ctypes.wintypes.DWORD))
+_GetBinaryType.argtypes = (
+    ctypes.wintypes.LPWSTR, ctypes.POINTER(ctypes.wintypes.DWORD),
+)
_GetBinaryType.restype = ctypes.wintypes.BOOL _GetBinaryType.restype = ctypes.wintypes.BOOL
FILEOP_FLAGS = ctypes.wintypes.WORD FILEOP_FLAGS = ctypes.wintypes.WORD
class BY_HANDLE_FILE_INFORMATION(ctypes.Structure):
_fields_ = [
('file_attributes', ctypes.wintypes.DWORD),
('creation_time', ctypes.wintypes.FILETIME),
('last_access_time', ctypes.wintypes.FILETIME),
('last_write_time', ctypes.wintypes.FILETIME),
('volume_serial_number', ctypes.wintypes.DWORD),
('file_size_high', ctypes.wintypes.DWORD),
('file_size_low', ctypes.wintypes.DWORD),
('number_of_links', ctypes.wintypes.DWORD),
('file_index_high', ctypes.wintypes.DWORD),
('file_index_low', ctypes.wintypes.DWORD),
]
@property
def file_size(self):
return (self.file_size_high << 32) + self.file_size_low
@property
def file_index(self):
return (self.file_index_high << 32) + self.file_index_low
GetFileInformationByHandle = ctypes.windll.kernel32.GetFileInformationByHandle
GetFileInformationByHandle.restype = ctypes.wintypes.BOOL
GetFileInformationByHandle.argtypes = (
ctypes.wintypes.HANDLE,
ctypes.POINTER(BY_HANDLE_FILE_INFORMATION),
)
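
A hedged sketch: BY_HANDLE_FILE_INFORMATION yields a (volume serial, file index) pair, which is the basis for a samefile()-style check. The import path is assumed from this diff's layout; Windows-only.

import ctypes
from jaraco.windows.api import filesystem as api

def file_identity(handle):
    """Return (volume serial, file index) for an open file HANDLE."""
    info = api.BY_HANDLE_FILE_INFORMATION()
    if not api.GetFileInformationByHandle(handle, ctypes.byref(info)):
        raise ctypes.WinError()
    return info.volume_serial_number, info.file_index
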
class SHFILEOPSTRUCT(ctypes.Structure): class SHFILEOPSTRUCT(ctypes.Structure):
_fields_ = [ _fields_ = [
('status_dialog', ctypes.wintypes.HWND), ('status_dialog', ctypes.wintypes.HWND),
@ -126,6 +205,8 @@ class SHFILEOPSTRUCT(ctypes.Structure):
('name_mapping_handles', ctypes.wintypes.LPVOID), ('name_mapping_handles', ctypes.wintypes.LPVOID),
('progress_title', ctypes.wintypes.LPWSTR), ('progress_title', ctypes.wintypes.LPWSTR),
] ]
_SHFileOperation = ctypes.windll.shell32.SHFileOperationW _SHFileOperation = ctypes.windll.shell32.SHFileOperationW
_SHFileOperation.argtypes = [ctypes.POINTER(SHFILEOPSTRUCT)] _SHFileOperation.argtypes = [ctypes.POINTER(SHFILEOPSTRUCT)]
_SHFileOperation.restype = ctypes.c_int _SHFileOperation.restype = ctypes.c_int
@ -143,12 +224,13 @@ ReplaceFile.argtypes = [
ctypes.wintypes.DWORD, ctypes.wintypes.DWORD,
ctypes.wintypes.LPVOID, ctypes.wintypes.LPVOID,
ctypes.wintypes.LPVOID, ctypes.wintypes.LPVOID,
] ]
REPLACEFILE_WRITE_THROUGH = 0x1 REPLACEFILE_WRITE_THROUGH = 0x1
REPLACEFILE_IGNORE_MERGE_ERRORS = 0x2 REPLACEFILE_IGNORE_MERGE_ERRORS = 0x2
REPLACEFILE_IGNORE_ACL_ERRORS = 0x4 REPLACEFILE_IGNORE_ACL_ERRORS = 0x4
class STAT_STRUCT(ctypes.Structure): class STAT_STRUCT(ctypes.Structure):
_fields_ = [ _fields_ = [
('dev', ctypes.c_uint), ('dev', ctypes.c_uint),
@ -165,17 +247,22 @@ class STAT_STRUCT(ctypes.Structure):
('ctime', ctypes.c_uint), ('ctime', ctypes.c_uint),
] ]
_wstat = ctypes.windll.msvcrt._wstat _wstat = ctypes.windll.msvcrt._wstat
_wstat.argtypes = [ctypes.wintypes.LPWSTR, ctypes.POINTER(STAT_STRUCT)] _wstat.argtypes = [ctypes.wintypes.LPWSTR, ctypes.POINTER(STAT_STRUCT)]
_wstat.restype = ctypes.c_int _wstat.restype = ctypes.c_int
FILE_NOTIFY_CHANGE_LAST_WRITE = 0x10 FILE_NOTIFY_CHANGE_LAST_WRITE = 0x10
-FindFirstChangeNotification = ctypes.windll.kernel32.FindFirstChangeNotificationW
-FindFirstChangeNotification.argtypes = ctypes.wintypes.LPWSTR, ctypes.wintypes.BOOL, ctypes.wintypes.DWORD
+FindFirstChangeNotification = (
+    ctypes.windll.kernel32.FindFirstChangeNotificationW)
+FindFirstChangeNotification.argtypes = (
+    ctypes.wintypes.LPWSTR, ctypes.wintypes.BOOL, ctypes.wintypes.DWORD,
+)
FindFirstChangeNotification.restype = ctypes.wintypes.HANDLE

-FindCloseChangeNotification = ctypes.windll.kernel32.FindCloseChangeNotification
+FindCloseChangeNotification = (
+    ctypes.windll.kernel32.FindCloseChangeNotification)
FindCloseChangeNotification.argtypes = ctypes.wintypes.HANDLE, FindCloseChangeNotification.argtypes = ctypes.wintypes.HANDLE,
FindCloseChangeNotification.restype = ctypes.wintypes.BOOL FindCloseChangeNotification.restype = ctypes.wintypes.BOOL
@ -200,9 +287,10 @@ DeviceIoControl.argtypes = [
ctypes.wintypes.DWORD, ctypes.wintypes.DWORD,
LPDWORD, LPDWORD,
LPOVERLAPPED, LPOVERLAPPED,
] ]
DeviceIoControl.restype = ctypes.wintypes.BOOL DeviceIoControl.restype = ctypes.wintypes.BOOL
class REPARSE_DATA_BUFFER(ctypes.Structure): class REPARSE_DATA_BUFFER(ctypes.Structure):
_fields_ = [ _fields_ = [
('tag', ctypes.c_ulong), ('tag', ctypes.c_ulong),
@ -213,16 +301,17 @@ class REPARSE_DATA_BUFFER(ctypes.Structure):
('print_name_offset', ctypes.c_ushort), ('print_name_offset', ctypes.c_ushort),
('print_name_length', ctypes.c_ushort), ('print_name_length', ctypes.c_ushort),
('flags', ctypes.c_ulong), ('flags', ctypes.c_ulong),
('path_buffer', ctypes.c_byte*1), ('path_buffer', ctypes.c_byte * 1),
] ]
def get_print_name(self): def get_print_name(self):
wchar_size = ctypes.sizeof(ctypes.wintypes.WCHAR) wchar_size = ctypes.sizeof(ctypes.wintypes.WCHAR)
arr_typ = ctypes.wintypes.WCHAR*(self.print_name_length//wchar_size) arr_typ = ctypes.wintypes.WCHAR * (self.print_name_length // wchar_size)
data = ctypes.byref(self.path_buffer, self.print_name_offset) data = ctypes.byref(self.path_buffer, self.print_name_offset)
return ctypes.cast(data, ctypes.POINTER(arr_typ)).contents.value return ctypes.cast(data, ctypes.POINTER(arr_typ)).contents.value
def get_substitute_name(self): def get_substitute_name(self):
wchar_size = ctypes.sizeof(ctypes.wintypes.WCHAR) wchar_size = ctypes.sizeof(ctypes.wintypes.WCHAR)
arr_typ = ctypes.wintypes.WCHAR*(self.substitute_name_length//wchar_size) arr_typ = ctypes.wintypes.WCHAR * (self.substitute_name_length // wchar_size)
data = ctypes.byref(self.path_buffer, self.substitute_name_offset) data = ctypes.byref(self.path_buffer, self.substitute_name_offset)
return ctypes.cast(data, ctypes.POINTER(arr_typ)).contents.value return ctypes.cast(data, ctypes.POINTER(arr_typ)).contents.value

View file

@ -2,6 +2,7 @@ import struct
import ctypes.wintypes import ctypes.wintypes
from ctypes.wintypes import DWORD, WCHAR, BYTE, BOOL from ctypes.wintypes import DWORD, WCHAR, BYTE, BOOL
# from mprapi.h # from mprapi.h
MAX_INTERFACE_NAME_LEN = 2**8 MAX_INTERFACE_NAME_LEN = 2**8
@ -13,15 +14,16 @@ MAXLEN_IFDESCR = 2**8
MAX_ADAPTER_ADDRESS_LENGTH = 8 MAX_ADAPTER_ADDRESS_LENGTH = 8
MAX_DHCPV6_DUID_LENGTH = 130 MAX_DHCPV6_DUID_LENGTH = 130
class MIB_IFROW(ctypes.Structure): class MIB_IFROW(ctypes.Structure):
_fields_ = ( _fields_ = (
('name', WCHAR*MAX_INTERFACE_NAME_LEN), ('name', WCHAR * MAX_INTERFACE_NAME_LEN),
('index', DWORD), ('index', DWORD),
('type', DWORD), ('type', DWORD),
('MTU', DWORD), ('MTU', DWORD),
('speed', DWORD), ('speed', DWORD),
('physical_address_length', DWORD), ('physical_address_length', DWORD),
('physical_address_raw', BYTE*MAXLEN_PHYSADDR), ('physical_address_raw', BYTE * MAXLEN_PHYSADDR),
('admin_status', DWORD), ('admin_status', DWORD),
('operational_status', DWORD), ('operational_status', DWORD),
('last_change', DWORD), ('last_change', DWORD),
@ -38,7 +40,7 @@ class MIB_IFROW(ctypes.Structure):
('outgoing_errors', DWORD), ('outgoing_errors', DWORD),
('outgoing_queue_length', DWORD), ('outgoing_queue_length', DWORD),
('description_length', DWORD), ('description_length', DWORD),
('description_raw', ctypes.c_char*MAXLEN_IFDESCR), ('description_raw', ctypes.c_char * MAXLEN_IFDESCR),
) )
def _get_binary_property(self, name): def _get_binary_property(self, name):
@ -46,7 +48,7 @@ class MIB_IFROW(ctypes.Structure):
val = getattr(self, val_prop) val = getattr(self, val_prop)
len_prop = '{0}_length'.format(name) len_prop = '{0}_length'.format(name)
length = getattr(self, len_prop) length = getattr(self, len_prop)
-        return str(buffer(val))[:length]
+        return str(memoryview(val))[:length]
@property @property
def physical_address(self): def physical_address(self):
@ -56,12 +58,14 @@ class MIB_IFROW(ctypes.Structure):
def description(self): def description(self):
return self._get_binary_property('description') return self._get_binary_property('description')
class MIB_IFTABLE(ctypes.Structure): class MIB_IFTABLE(ctypes.Structure):
_fields_ = ( _fields_ = (
('num_entries', DWORD), # dwNumEntries ('num_entries', DWORD), # dwNumEntries
('entries', MIB_IFROW*0), # table ('entries', MIB_IFROW * 0), # table
) )
class MIB_IPADDRROW(ctypes.Structure): class MIB_IPADDRROW(ctypes.Structure):
_fields_ = ( _fields_ = (
('address_num', DWORD), ('address_num', DWORD),
@ -79,40 +83,49 @@ class MIB_IPADDRROW(ctypes.Structure):
_ = struct.pack('L', self.address_num) _ = struct.pack('L', self.address_num)
return struct.unpack('!L', _)[0] return struct.unpack('!L', _)[0]
class MIB_IPADDRTABLE(ctypes.Structure): class MIB_IPADDRTABLE(ctypes.Structure):
_fields_ = ( _fields_ = (
('num_entries', DWORD), ('num_entries', DWORD),
('entries', MIB_IPADDRROW*0), ('entries', MIB_IPADDRROW * 0),
) )
class SOCKADDR(ctypes.Structure): class SOCKADDR(ctypes.Structure):
_fields_ = ( _fields_ = (
('family', ctypes.c_ushort), ('family', ctypes.c_ushort),
('data', ctypes.c_byte*14), ('data', ctypes.c_byte * 14),
) )
LPSOCKADDR = ctypes.POINTER(SOCKADDR) LPSOCKADDR = ctypes.POINTER(SOCKADDR)
class SOCKET_ADDRESS(ctypes.Structure): class SOCKET_ADDRESS(ctypes.Structure):
_fields_ = [ _fields_ = [
('address', LPSOCKADDR), ('address', LPSOCKADDR),
('length', ctypes.c_int), ('length', ctypes.c_int),
] ]
class _IP_ADAPTER_ADDRESSES_METRIC(ctypes.Structure): class _IP_ADAPTER_ADDRESSES_METRIC(ctypes.Structure):
_fields_ = [ _fields_ = [
('length', ctypes.c_ulong), ('length', ctypes.c_ulong),
('interface_index', DWORD), ('interface_index', DWORD),
] ]
class _IP_ADAPTER_ADDRESSES_U1(ctypes.Union): class _IP_ADAPTER_ADDRESSES_U1(ctypes.Union):
_fields_ = [ _fields_ = [
('alignment', ctypes.c_ulonglong), ('alignment', ctypes.c_ulonglong),
('metric', _IP_ADAPTER_ADDRESSES_METRIC), ('metric', _IP_ADAPTER_ADDRESSES_METRIC),
] ]
class IP_ADAPTER_ADDRESSES(ctypes.Structure): class IP_ADAPTER_ADDRESSES(ctypes.Structure):
pass pass
LP_IP_ADAPTER_ADDRESSES = ctypes.POINTER(IP_ADAPTER_ADDRESSES) LP_IP_ADAPTER_ADDRESSES = ctypes.POINTER(IP_ADAPTER_ADDRESSES)
# for now, just use void * for pointers to unused structures # for now, just use void * for pointers to unused structures
@ -125,17 +138,18 @@ PIP_ADAPTER_WINS_SERVER_ADDRESS_LH = ctypes.c_void_p
PIP_ADAPTER_GATEWAY_ADDRESS_LH = ctypes.c_void_p PIP_ADAPTER_GATEWAY_ADDRESS_LH = ctypes.c_void_p
PIP_ADAPTER_DNS_SUFFIX = ctypes.c_void_p PIP_ADAPTER_DNS_SUFFIX = ctypes.c_void_p
IF_OPER_STATUS = ctypes.c_uint # this is an enum, consider http://code.activestate.com/recipes/576415/ IF_OPER_STATUS = ctypes.c_uint # this is an enum, consider
# http://code.activestate.com/recipes/576415/
IF_LUID = ctypes.c_uint64 IF_LUID = ctypes.c_uint64
NET_IF_COMPARTMENT_ID = ctypes.c_uint32 NET_IF_COMPARTMENT_ID = ctypes.c_uint32
GUID = ctypes.c_byte*16 GUID = ctypes.c_byte * 16
NET_IF_NETWORK_GUID = GUID NET_IF_NETWORK_GUID = GUID
NET_IF_CONNECTION_TYPE = ctypes.c_uint # enum NET_IF_CONNECTION_TYPE = ctypes.c_uint # enum
TUNNEL_TYPE = ctypes.c_uint # enum TUNNEL_TYPE = ctypes.c_uint # enum
IP_ADAPTER_ADDRESSES._fields_ = [ IP_ADAPTER_ADDRESSES._fields_ = [
#('u', _IP_ADAPTER_ADDRESSES_U1), # ('u', _IP_ADAPTER_ADDRESSES_U1),
('length', ctypes.c_ulong), ('length', ctypes.c_ulong),
('interface_index', DWORD), ('interface_index', DWORD),
('next', LP_IP_ADAPTER_ADDRESSES), ('next', LP_IP_ADAPTER_ADDRESSES),
@ -147,7 +161,7 @@ IP_ADAPTER_ADDRESSES._fields_ = [
('dns_suffix', ctypes.c_wchar_p), ('dns_suffix', ctypes.c_wchar_p),
('description', ctypes.c_wchar_p), ('description', ctypes.c_wchar_p),
('friendly_name', ctypes.c_wchar_p), ('friendly_name', ctypes.c_wchar_p),
('byte', BYTE*MAX_ADAPTER_ADDRESS_LENGTH), ('byte', BYTE * MAX_ADAPTER_ADDRESS_LENGTH),
('physical_address_length', DWORD), ('physical_address_length', DWORD),
('flags', DWORD), ('flags', DWORD),
('mtu', DWORD), ('mtu', DWORD),
@ -169,11 +183,11 @@ IP_ADAPTER_ADDRESSES._fields_ = [
('connection_type', NET_IF_CONNECTION_TYPE), ('connection_type', NET_IF_CONNECTION_TYPE),
('tunnel_type', TUNNEL_TYPE), ('tunnel_type', TUNNEL_TYPE),
('dhcpv6_server', SOCKET_ADDRESS), ('dhcpv6_server', SOCKET_ADDRESS),
('dhcpv6_client_duid', ctypes.c_byte*MAX_DHCPV6_DUID_LENGTH), ('dhcpv6_client_duid', ctypes.c_byte * MAX_DHCPV6_DUID_LENGTH),
('dhcpv6_client_duid_length', ctypes.c_ulong), ('dhcpv6_client_duid_length', ctypes.c_ulong),
('dhcpv6_iaid', ctypes.c_ulong), ('dhcpv6_iaid', ctypes.c_ulong),
('first_dns_suffix', PIP_ADAPTER_DNS_SUFFIX), ('first_dns_suffix', PIP_ADAPTER_DNS_SUFFIX),
] ]
# define some parameters to the API Functions # define some parameters to the API Functions
GetIfTable = ctypes.windll.iphlpapi.GetIfTable GetIfTable = ctypes.windll.iphlpapi.GetIfTable
@ -181,7 +195,7 @@ GetIfTable.argtypes = [
ctypes.POINTER(MIB_IFTABLE), ctypes.POINTER(MIB_IFTABLE),
ctypes.POINTER(ctypes.c_ulong), ctypes.POINTER(ctypes.c_ulong),
BOOL, BOOL,
] ]
GetIfTable.restype = DWORD GetIfTable.restype = DWORD
GetIpAddrTable = ctypes.windll.iphlpapi.GetIpAddrTable GetIpAddrTable = ctypes.windll.iphlpapi.GetIpAddrTable
@ -189,7 +203,7 @@ GetIpAddrTable.argtypes = [
ctypes.POINTER(MIB_IPADDRTABLE), ctypes.POINTER(MIB_IPADDRTABLE),
ctypes.POINTER(ctypes.c_ulong), ctypes.POINTER(ctypes.c_ulong),
BOOL, BOOL,
] ]
GetIpAddrTable.restype = DWORD GetIpAddrTable.restype = DWORD
GetAdaptersAddresses = ctypes.windll.iphlpapi.GetAdaptersAddresses GetAdaptersAddresses = ctypes.windll.iphlpapi.GetAdaptersAddresses
@ -199,5 +213,5 @@ GetAdaptersAddresses.argtypes = [
ctypes.c_void_p, ctypes.c_void_p,
ctypes.POINTER(IP_ADAPTER_ADDRESSES), ctypes.POINTER(IP_ADAPTER_ADDRESSES),
ctypes.POINTER(ctypes.c_ulong), ctypes.POINTER(ctypes.c_ulong),
] ]
GetAdaptersAddresses.restype = ctypes.c_ulong GetAdaptersAddresses.restype = ctypes.c_ulong

View file

@ -3,7 +3,7 @@ import ctypes.wintypes
GMEM_MOVEABLE = 0x2

GlobalAlloc = ctypes.windll.kernel32.GlobalAlloc
-GlobalAlloc.argtypes = ctypes.wintypes.UINT, ctypes.c_ssize_t
+GlobalAlloc.argtypes = ctypes.wintypes.UINT, ctypes.c_size_t
GlobalAlloc.restype = ctypes.wintypes.HANDLE

GlobalLock = ctypes.windll.kernel32.GlobalLock
@ -31,3 +31,15 @@ CreateFileMapping.restype = ctypes.wintypes.HANDLE
MapViewOfFile = ctypes.windll.kernel32.MapViewOfFile
MapViewOfFile.restype = ctypes.wintypes.HANDLE
UnmapViewOfFile = ctypes.windll.kernel32.UnmapViewOfFile
UnmapViewOfFile.argtypes = ctypes.wintypes.HANDLE,
RtlMoveMemory = ctypes.windll.kernel32.RtlMoveMemory
RtlMoveMemory.argtypes = (
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_size_t,
)
ctypes.windll.kernel32.LocalFree.argtypes = ctypes.wintypes.HLOCAL,
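
A Windows-only sketch of the copy pattern these declarations enable (allocate moveable global memory, lock it, copy bytes in, unlock), using ctypes directly; the payload is illustrative.

import ctypes
from ctypes import wintypes

kernel32 = ctypes.windll.kernel32
kernel32.GlobalAlloc.restype = wintypes.HANDLE
kernel32.GlobalLock.restype = wintypes.LPVOID
kernel32.GlobalLock.argtypes = (wintypes.HANDLE,)
kernel32.RtlMoveMemory.argtypes = (
    ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t)

GMEM_MOVEABLE = 0x2
payload = ctypes.create_string_buffer(b'hello')

handle = kernel32.GlobalAlloc(GMEM_MOVEABLE, ctypes.sizeof(payload))
dest = kernel32.GlobalLock(handle)
kernel32.RtlMoveMemory(dest, ctypes.addressof(payload), ctypes.sizeof(payload))
kernel32.GlobalUnlock(handle)
kernel32.GlobalFree(handle)
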

View file

@ -13,6 +13,7 @@ import six
LRESULT = LPARAM LRESULT = LPARAM
class LPARAM_wstr(LPARAM): class LPARAM_wstr(LPARAM):
""" """
A special instance of LPARAM that can be constructed from a string A special instance of LPARAM that can be constructed from a string
@ -25,14 +26,16 @@ class LPARAM_wstr(LPARAM):
return LPVOID.from_param(six.text_type(param)) return LPVOID.from_param(six.text_type(param))
return LPARAM.from_param(param) return LPARAM.from_param(param)
SendMessage = ctypes.windll.user32.SendMessageW SendMessage = ctypes.windll.user32.SendMessageW
SendMessage.argtypes = (HWND, UINT, WPARAM, LPARAM_wstr) SendMessage.argtypes = (HWND, UINT, WPARAM, LPARAM_wstr)
SendMessage.restype = LRESULT SendMessage.restype = LRESULT
-HWND_BROADCAST=0xFFFF
-WM_SETTINGCHANGE=0x1A
+HWND_BROADCAST = 0xFFFF
+WM_SETTINGCHANGE = 0x1A

-# constants from http://msdn.microsoft.com/en-us/library/ms644952%28v=vs.85%29.aspx
+# constants from http://msdn.microsoft.com
+# /en-us/library/ms644952%28v=vs.85%29.aspx
SMTO_ABORTIFHUNG = 0x02 SMTO_ABORTIFHUNG = 0x02
SMTO_BLOCK = 0x01 SMTO_BLOCK = 0x01
SMTO_NORMAL = 0x00 SMTO_NORMAL = 0x00
@ -45,6 +48,7 @@ SendMessageTimeout.argtypes = SendMessage.argtypes + (
) )
SendMessageTimeout.restype = LRESULT SendMessageTimeout.restype = LRESULT
def unicode_as_lparam(source): def unicode_as_lparam(source):
pointer = ctypes.cast(ctypes.c_wchar_p(source), ctypes.c_void_p) pointer = ctypes.cast(ctypes.c_wchar_p(source), ctypes.c_void_p)
return LPARAM(pointer.value) return LPARAM(pointer.value)

View file

@ -5,6 +5,7 @@ mpr = ctypes.windll.mpr
RESOURCETYPE_ANY = 0 RESOURCETYPE_ANY = 0
class NETRESOURCE(ctypes.Structure): class NETRESOURCE(ctypes.Structure):
_fields_ = [ _fields_ = [
('scope', ctypes.wintypes.DWORD), ('scope', ctypes.wintypes.DWORD),
@ -16,6 +17,8 @@ class NETRESOURCE(ctypes.Structure):
('comment', ctypes.wintypes.LPWSTR), ('comment', ctypes.wintypes.LPWSTR),
('provider', ctypes.wintypes.LPWSTR), ('provider', ctypes.wintypes.LPWSTR),
] ]
LPNETRESOURCE = ctypes.POINTER(NETRESOURCE) LPNETRESOURCE = ctypes.POINTER(NETRESOURCE)
WNetAddConnection2 = mpr.WNetAddConnection2W WNetAddConnection2 = mpr.WNetAddConnection2W

View file

@ -1,5 +1,6 @@
import ctypes.wintypes import ctypes.wintypes
class SYSTEM_POWER_STATUS(ctypes.Structure): class SYSTEM_POWER_STATUS(ctypes.Structure):
_fields_ = ( _fields_ = (
('ac_line_status', ctypes.wintypes.BYTE), ('ac_line_status', ctypes.wintypes.BYTE),
@ -12,7 +13,9 @@ class SYSTEM_POWER_STATUS(ctypes.Structure):
@property @property
def ac_line_status_string(self): def ac_line_status_string(self):
-        return {0:'offline', 1: 'online', 255: 'unknown'}[self.ac_line_status]
+        return {
+            0: 'offline', 1: 'online', 255: 'unknown'}[self.ac_line_status]
LPSYSTEM_POWER_STATUS = ctypes.POINTER(SYSTEM_POWER_STATUS) LPSYSTEM_POWER_STATUS = ctypes.POINTER(SYSTEM_POWER_STATUS)
GetSystemPowerStatus = ctypes.windll.kernel32.GetSystemPowerStatus GetSystemPowerStatus = ctypes.windll.kernel32.GetSystemPowerStatus
@ -23,6 +26,7 @@ SetThreadExecutionState = ctypes.windll.kernel32.SetThreadExecutionState
SetThreadExecutionState.argtypes = [ctypes.c_uint] SetThreadExecutionState.argtypes = [ctypes.c_uint]
SetThreadExecutionState.restype = ctypes.c_uint SetThreadExecutionState.restype = ctypes.c_uint
class ES: class ES:
""" """
Execution state constants Execution state constants

View file

@ -1,5 +1,6 @@
import ctypes.wintypes import ctypes.wintypes
class LUID(ctypes.Structure): class LUID(ctypes.Structure):
_fields_ = [ _fields_ = [
('low_part', ctypes.wintypes.DWORD), ('low_part', ctypes.wintypes.DWORD),
@ -13,27 +14,31 @@ class LUID(ctypes.Structure):
) )
def __ne__(self, other): def __ne__(self, other):
return not (self==other) return not (self == other)
LookupPrivilegeValue = ctypes.windll.advapi32.LookupPrivilegeValueW LookupPrivilegeValue = ctypes.windll.advapi32.LookupPrivilegeValueW
LookupPrivilegeValue.argtypes = ( LookupPrivilegeValue.argtypes = (
ctypes.wintypes.LPWSTR, # system name ctypes.wintypes.LPWSTR, # system name
ctypes.wintypes.LPWSTR, # name ctypes.wintypes.LPWSTR, # name
ctypes.POINTER(LUID), ctypes.POINTER(LUID),
) )
LookupPrivilegeValue.restype = ctypes.wintypes.BOOL LookupPrivilegeValue.restype = ctypes.wintypes.BOOL
class TOKEN_INFORMATION_CLASS: class TOKEN_INFORMATION_CLASS:
TokenUser = 1 TokenUser = 1
TokenGroups = 2 TokenGroups = 2
TokenPrivileges = 3 TokenPrivileges = 3
# ... see http://msdn.microsoft.com/en-us/library/aa379626%28VS.85%29.aspx # ... see http://msdn.microsoft.com/en-us/library/aa379626%28VS.85%29.aspx
SE_PRIVILEGE_ENABLED_BY_DEFAULT = 0x00000001 SE_PRIVILEGE_ENABLED_BY_DEFAULT = 0x00000001
SE_PRIVILEGE_ENABLED = 0x00000002 SE_PRIVILEGE_ENABLED = 0x00000002
SE_PRIVILEGE_REMOVED = 0x00000004 SE_PRIVILEGE_REMOVED = 0x00000004
SE_PRIVILEGE_USED_FOR_ACCESS = 0x80000000 SE_PRIVILEGE_USED_FOR_ACCESS = 0x80000000
class LUID_AND_ATTRIBUTES(ctypes.Structure): class LUID_AND_ATTRIBUTES(ctypes.Structure):
_fields_ = [ _fields_ = [
('LUID', LUID), ('LUID', LUID),
@ -50,37 +55,43 @@ class LUID_AND_ATTRIBUTES(ctypes.Structure):
size = ctypes.wintypes.DWORD(10240) size = ctypes.wintypes.DWORD(10240)
buf = ctypes.create_unicode_buffer(size.value) buf = ctypes.create_unicode_buffer(size.value)
res = LookupPrivilegeName(None, self.LUID, buf, size) res = LookupPrivilegeName(None, self.LUID, buf, size)
-        if res == 0: raise RuntimeError
+        if res == 0:
+            raise RuntimeError
        return buf[:size.value]

    def __str__(self):
        res = self.get_name()
-        if self.is_enabled(): res += ' (enabled)'
+        if self.is_enabled():
+            res += ' (enabled)'
        return res
LookupPrivilegeName = ctypes.windll.advapi32.LookupPrivilegeNameW LookupPrivilegeName = ctypes.windll.advapi32.LookupPrivilegeNameW
LookupPrivilegeName.argtypes = ( LookupPrivilegeName.argtypes = (
ctypes.wintypes.LPWSTR, # lpSystemName ctypes.wintypes.LPWSTR, # lpSystemName
ctypes.POINTER(LUID), # lpLuid ctypes.POINTER(LUID), # lpLuid
ctypes.wintypes.LPWSTR, # lpName ctypes.wintypes.LPWSTR, # lpName
ctypes.POINTER(ctypes.wintypes.DWORD), # cchName ctypes.POINTER(ctypes.wintypes.DWORD), # cchName
) )
LookupPrivilegeName.restype = ctypes.wintypes.BOOL LookupPrivilegeName.restype = ctypes.wintypes.BOOL
class TOKEN_PRIVILEGES(ctypes.Structure): class TOKEN_PRIVILEGES(ctypes.Structure):
_fields_ = [ _fields_ = [
('count', ctypes.wintypes.DWORD), ('count', ctypes.wintypes.DWORD),
-        ('privileges', LUID_AND_ATTRIBUTES*0),
+        ('privileges', LUID_AND_ATTRIBUTES * 0),
    ]

    def get_array(self):
-        array_type = LUID_AND_ATTRIBUTES*self.count
-        privileges = ctypes.cast(self.privileges, ctypes.POINTER(array_type)).contents
+        array_type = LUID_AND_ATTRIBUTES * self.count
+        privileges = ctypes.cast(
+            self.privileges, ctypes.POINTER(array_type)).contents
        return privileges
def __iter__(self): def __iter__(self):
return iter(self.get_array()) return iter(self.get_array())
PTOKEN_PRIVILEGES = ctypes.POINTER(TOKEN_PRIVILEGES) PTOKEN_PRIVILEGES = ctypes.POINTER(TOKEN_PRIVILEGES)
GetTokenInformation = ctypes.windll.advapi32.GetTokenInformation GetTokenInformation = ctypes.windll.advapi32.GetTokenInformation
@ -90,7 +101,7 @@ GetTokenInformation.argtypes = [
ctypes.c_void_p, # TokenInformation ctypes.c_void_p, # TokenInformation
ctypes.wintypes.DWORD, # TokenInformationLength ctypes.wintypes.DWORD, # TokenInformationLength
ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength
] ]
GetTokenInformation.restype = ctypes.wintypes.BOOL GetTokenInformation.restype = ctypes.wintypes.BOOL
# http://msdn.microsoft.com/en-us/library/aa375202%28VS.85%29.aspx # http://msdn.microsoft.com/en-us/library/aa375202%28VS.85%29.aspx
@ -103,4 +114,4 @@ AdjustTokenPrivileges.argtypes = [
ctypes.wintypes.DWORD, # BufferLength of PreviousState ctypes.wintypes.DWORD, # BufferLength of PreviousState
PTOKEN_PRIVILEGES, # PreviousState (out, optional) PTOKEN_PRIVILEGES, # PreviousState (out, optional)
ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength
] ]

View file

@ -5,5 +5,7 @@ TOKEN_ALL_ACCESS = 0xf01ff
GetCurrentProcess = ctypes.windll.kernel32.GetCurrentProcess GetCurrentProcess = ctypes.windll.kernel32.GetCurrentProcess
GetCurrentProcess.restype = ctypes.wintypes.HANDLE GetCurrentProcess.restype = ctypes.wintypes.HANDLE
OpenProcessToken = ctypes.windll.advapi32.OpenProcessToken OpenProcessToken = ctypes.windll.advapi32.OpenProcessToken
-OpenProcessToken.argtypes = (ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD, ctypes.POINTER(ctypes.wintypes.HANDLE))
+OpenProcessToken.argtypes = (
+    ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD,
+    ctypes.POINTER(ctypes.wintypes.HANDLE))
OpenProcessToken.restype = ctypes.wintypes.BOOL OpenProcessToken.restype = ctypes.wintypes.BOOL
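
A hedged sketch combining the two prototypes above to obtain the current process token. The module name is assumed from this diff's layout; Windows-only.

import ctypes
from ctypes import wintypes
from jaraco.windows.api import process

token = wintypes.HANDLE()
res = process.OpenProcessToken(
    process.GetCurrentProcess(), process.TOKEN_ALL_ACCESS, ctypes.byref(token))
if not res:
    raise ctypes.WinError()
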

View file

@ -60,12 +60,15 @@ POLICY_EXECUTE = (
POLICY_VIEW_LOCAL_INFORMATION | POLICY_VIEW_LOCAL_INFORMATION |
POLICY_LOOKUP_NAMES) POLICY_LOOKUP_NAMES)
class TokenAccess: class TokenAccess:
TOKEN_QUERY = 0x8 TOKEN_QUERY = 0x8
class TokenInformationClass: class TokenInformationClass:
TokenUser = 1 TokenUser = 1
class TOKEN_USER(ctypes.Structure): class TOKEN_USER(ctypes.Structure):
num = 1 num = 1
_fields_ = [ _fields_ = [
@ -100,6 +103,7 @@ class SECURITY_DESCRIPTOR(ctypes.Structure):
('Dacl', ctypes.c_void_p), ('Dacl', ctypes.c_void_p),
] ]
class SECURITY_ATTRIBUTES(ctypes.Structure): class SECURITY_ATTRIBUTES(ctypes.Structure):
""" """
typedef struct _SECURITY_ATTRIBUTES { typedef struct _SECURITY_ATTRIBUTES {
@ -126,3 +130,10 @@ class SECURITY_ATTRIBUTES(ctypes.Structure):
def descriptor(self, value): def descriptor(self, value):
self._descriptor = value self._descriptor = value
self.lpSecurityDescriptor = ctypes.addressof(value) self.lpSecurityDescriptor = ctypes.addressof(value)
ctypes.windll.advapi32.SetSecurityDescriptorOwner.argtypes = (
ctypes.POINTER(SECURITY_DESCRIPTOR),
ctypes.c_void_p,
ctypes.wintypes.BOOL,
)

View file

@ -1,6 +1,7 @@
import ctypes.wintypes import ctypes.wintypes
BOOL = ctypes.wintypes.BOOL BOOL = ctypes.wintypes.BOOL
class SHELLSTATE(ctypes.Structure): class SHELLSTATE(ctypes.Structure):
_fields_ = [ _fields_ = [
('show_all_objects', BOOL, 1), ('show_all_objects', BOOL, 1),
@ -34,6 +35,7 @@ class SHELLSTATE(ctypes.Structure):
('spare_flags', ctypes.wintypes.UINT, 13), ('spare_flags', ctypes.wintypes.UINT, 13),
] ]
SSF_SHOWALLOBJECTS = 0x00000001 SSF_SHOWALLOBJECTS = 0x00000001
"The fShowAllObjects member is being requested." "The fShowAllObjects member is being requested."
@ -62,7 +64,13 @@ SSF_SHOWATTRIBCOL = 0x00000100
"The fShowAttribCol member is being requested. (Windows Vista: Not used.)" "The fShowAttribCol member is being requested. (Windows Vista: Not used.)"
SSF_DESKTOPHTML = 0x00000200 SSF_DESKTOPHTML = 0x00000200
"The fDesktopHTML member is being requested. Set is not available. Instead, for versions of Microsoft Windows prior to Windows XP, enable Desktop HTML by IActiveDesktop. The use of IActiveDesktop for this purpose, however, is not recommended for Windows XP and later versions of Windows, and is deprecated in Windows Vista." """
The fDesktopHTML member is being requested. Set is not available.
Instead, for versions of Microsoft Windows prior to Windows XP,
enable Desktop HTML by IActiveDesktop. The use of IActiveDesktop
for this purpose, however, is not recommended for Windows XP and
later versions of Windows, and is deprecated in Windows Vista.
"""
SSF_WIN95CLASSIC = 0x00000400 SSF_WIN95CLASSIC = 0x00000400
"The fWin95Classic member is being requested." "The fWin95Classic member is being requested."
@ -118,5 +126,5 @@ SHGetSetSettings.argtypes = [
ctypes.POINTER(SHELLSTATE), ctypes.POINTER(SHELLSTATE),
ctypes.wintypes.DWORD, ctypes.wintypes.DWORD,
ctypes.wintypes.BOOL, # get or set (True: set) ctypes.wintypes.BOOL, # get or set (True: set)
] ]
SHGetSetSettings.restype = None SHGetSetSettings.restype = None

View file

@ -6,7 +6,7 @@ SystemParametersInfo.argtypes = (
ctypes.wintypes.UINT, ctypes.wintypes.UINT,
ctypes.c_void_p, ctypes.c_void_p,
ctypes.wintypes.UINT, ctypes.wintypes.UINT,
) )
SPI_GETACTIVEWINDOWTRACKING = 0x1000 SPI_GETACTIVEWINDOWTRACKING = 0x1000
SPI_SETACTIVEWINDOWTRACKING = 0x1001 SPI_SETACTIVEWINDOWTRACKING = 0x1001

View file

@ -5,20 +5,22 @@ import re
import itertools
from contextlib import contextmanager
import io
-import six
import ctypes
from ctypes import windll

+import six
+from six.moves import map
+
from jaraco.windows.api import clipboard, memory
from jaraco.windows.error import handle_nonzero_success, WindowsError
from jaraco.windows.memory import LockedMemory

__all__ = (
-    'CF_TEXT', 'GetClipboardData', 'CloseClipboard',
+    'GetClipboardData', 'CloseClipboard',
    'SetClipboardData', 'OpenClipboard',
)
def OpenClipboard(owner=None): def OpenClipboard(owner=None):
""" """
Open the clipboard. Open the clipboard.
@ -30,9 +32,14 @@ def OpenClipboard(owner=None):
""" """
handle_nonzero_success(windll.user32.OpenClipboard(owner)) handle_nonzero_success(windll.user32.OpenClipboard(owner))
-CloseClipboard = lambda: handle_nonzero_success(windll.user32.CloseClipboard())
+def CloseClipboard():
+    handle_nonzero_success(windll.user32.CloseClipboard())


data_handlers = dict()
def handles(*formats): def handles(*formats):
def register(func): def register(func):
for format in formats: for format in formats:
@ -40,36 +47,43 @@ def handles(*formats):
return func return func
return register return register
-def nts(s):
+def nts(buffer):
    """
    Null Terminated String
-    Get the portion of s up to a null character.
+    Get the portion of bytestring buffer up to a null character.
    """
-    result, null, rest = s.partition('\x00')
+    result, null, rest = buffer.partition('\x00')
    return result
@handles(clipboard.CF_DIBV5, clipboard.CF_DIB) @handles(clipboard.CF_DIBV5, clipboard.CF_DIB)
def raw_data(handle): def raw_data(handle):
return LockedMemory(handle).data return LockedMemory(handle).data
@handles(clipboard.CF_TEXT) @handles(clipboard.CF_TEXT)
def text_string(handle): def text_string(handle):
return nts(raw_data(handle)) return nts(raw_data(handle))
@handles(clipboard.CF_UNICODETEXT) @handles(clipboard.CF_UNICODETEXT)
def unicode_string(handle): def unicode_string(handle):
return nts(raw_data(handle).decode('utf-16')) return nts(raw_data(handle).decode('utf-16'))
@handles(clipboard.CF_BITMAP) @handles(clipboard.CF_BITMAP)
def as_bitmap(handle): def as_bitmap(handle):
# handle is HBITMAP # handle is HBITMAP
raise NotImplementedError("Can't convert to DIB") raise NotImplementedError("Can't convert to DIB")
# todo: use GetDIBits http://msdn.microsoft.com/en-us/library/dd144879%28v=VS.85%29.aspx # todo: use GetDIBits http://msdn.microsoft.com
# /en-us/library/dd144879%28v=VS.85%29.aspx
@handles(clipboard.CF_HTML) @handles(clipboard.CF_HTML)
class HTMLSnippet(object): class HTMLSnippet(object):
def __init__(self, handle): def __init__(self, handle):
self.data = text_string(handle) self.data = nts(raw_data(handle).decode('utf-8'))
self.headers = self.parse_headers(self.data) self.headers = self.parse_headers(self.data)
@property @property
@ -79,11 +93,13 @@ class HTMLSnippet(object):
@staticmethod @staticmethod
def parse_headers(data): def parse_headers(data):
d = io.StringIO(data) d = io.StringIO(data)
def header_line(line): def header_line(line):
return re.match('(\w+):(.*)', line) return re.match('(\w+):(.*)', line)
headers = itertools.imap(header_line, d) headers = map(header_line, d)
# grab headers until they no longer match # grab headers until they no longer match
headers = itertools.takewhile(bool, headers) headers = itertools.takewhile(bool, headers)
def best_type(value): def best_type(value):
try: try:
return int(value) return int(value)
@ -101,26 +117,34 @@ class HTMLSnippet(object):
) )
return dict(pairs) return dict(pairs)
def GetClipboardData(type=clipboard.CF_UNICODETEXT):
-    if not type in data_handlers:
+    if type not in data_handlers:
        raise NotImplementedError("No support for data of type %d" % type)
    handle = clipboard.GetClipboardData(type)
    if handle is None:
        raise TypeError("No clipboard data of type %d" % type)
    return data_handlers[type](handle)


-EmptyClipboard = lambda: handle_nonzero_success(windll.user32.EmptyClipboard())
+def EmptyClipboard():
+    handle_nonzero_success(windll.user32.EmptyClipboard())


def SetClipboardData(type, content):
    """
-    Modeled after http://msdn.microsoft.com/en-us/library/ms649016%28VS.85%29.aspx#_win32_Copying_Information_to_the_Clipboard
+    Modeled after http://msdn.microsoft.com
+    /en-us/library/ms649016%28VS.85%29.aspx
+    #_win32_Copying_Information_to_the_Clipboard
    """
    allocators = {
        clipboard.CF_TEXT: ctypes.create_string_buffer,
        clipboard.CF_UNICODETEXT: ctypes.create_unicode_buffer,
+        clipboard.CF_HTML: ctypes.create_string_buffer,
    }
-    if not type in allocators:
-        raise NotImplementedError("Only text types are supported at this time")
+    if type not in allocators:
+        raise NotImplementedError(
+            "Only text and HTML types are supported at this time")
# allocate the memory for the data # allocate the memory for the data
content = allocators[type](content) content = allocators[type](content)
flags = memory.GMEM_MOVEABLE flags = memory.GMEM_MOVEABLE
@ -132,47 +156,57 @@ def SetClipboardData(type, content):
if result is None: if result is None:
raise WindowsError() raise WindowsError()
def set_text(source): def set_text(source):
with context(): with context():
EmptyClipboard() EmptyClipboard()
SetClipboardData(clipboard.CF_TEXT, source) SetClipboardData(clipboard.CF_TEXT, source)
def get_text(): def get_text():
with context(): with context():
result = GetClipboardData(clipboard.CF_TEXT) result = GetClipboardData(clipboard.CF_TEXT)
return result return result
def set_unicode_text(source): def set_unicode_text(source):
with context(): with context():
EmptyClipboard() EmptyClipboard()
SetClipboardData(clipboard.CF_UNICODETEXT, source) SetClipboardData(clipboard.CF_UNICODETEXT, source)
def get_unicode_text(): def get_unicode_text():
with context(): with context():
return GetClipboardData() return GetClipboardData()
def get_html(): def get_html():
with context(): with context():
result = GetClipboardData(clipboard.CF_HTML) result = GetClipboardData(clipboard.CF_HTML)
return result return result
def set_html(source): def set_html(source):
with context(): with context():
EmptyClipboard() EmptyClipboard()
SetClipboardData(clipboard.CF_UNICODETEXT, source) SetClipboardData(clipboard.CF_UNICODETEXT, source)
def get_image(): def get_image():
with context(): with context():
return GetClipboardData(clipboard.CF_DIB) return GetClipboardData(clipboard.CF_DIB)
def paste_stdout(): def paste_stdout():
getter = get_unicode_text if six.PY3 else get_text getter = get_unicode_text if six.PY3 else get_text
sys.stdout.write(getter()) sys.stdout.write(getter())
def stdin_copy(): def stdin_copy():
setter = set_unicode_text if six.PY3 else set_text setter = set_unicode_text if six.PY3 else set_text
setter(sys.stdin.read()) setter(sys.stdin.read())
@contextmanager @contextmanager
def context(): def context():
OpenClipboard() OpenClipboard()
@ -181,10 +215,12 @@ def context():
finally: finally:
CloseClipboard() CloseClipboard()
def get_formats(): def get_formats():
with context(): with context():
format_index = 0 format_index = 0
while True: while True:
format_index = clipboard.EnumClipboardFormats(format_index) format_index = clipboard.EnumClipboardFormats(format_index)
if format_index == 0: break if format_index == 0:
break
yield format_index yield format_index
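
A round-trip sketch using the helpers defined above (Windows-only); the module path is taken to be jaraco.windows.clipboard, as this diff suggests.

from jaraco.windows import clipboard

clipboard.set_unicode_text('hello from the clipboard')
assert clipboard.get_unicode_text() == 'hello from the clipboard'
print(list(clipboard.get_formats()))  # formats currently on the clipboard
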

View file

@ -3,17 +3,20 @@ import ctypes
import jaraco.windows.api.credential as api import jaraco.windows.api.credential as api
from . import error from . import error
-CRED_TYPE_GENERIC=1
+CRED_TYPE_GENERIC = 1
def CredDelete(TargetName, Type, Flags=0): def CredDelete(TargetName, Type, Flags=0):
error.handle_nonzero_success(api.CredDelete(TargetName, Type, Flags)) error.handle_nonzero_success(api.CredDelete(TargetName, Type, Flags))
def CredRead(TargetName, Type, Flags=0): def CredRead(TargetName, Type, Flags=0):
cred_pointer = api.PCREDENTIAL() cred_pointer = api.PCREDENTIAL()
res = api.CredRead(TargetName, Type, Flags, ctypes.byref(cred_pointer)) res = api.CredRead(TargetName, Type, Flags, ctypes.byref(cred_pointer))
error.handle_nonzero_success(res) error.handle_nonzero_success(res)
return cred_pointer.contents return cred_pointer.contents
def CredWrite(Credential, Flags=0): def CredWrite(Credential, Flags=0):
res = api.CredWrite(Credential, Flags) res = api.CredWrite(Credential, Flags)
error.handle_nonzero_success(res) error.handle_nonzero_success(res)

View file

@ -14,6 +14,11 @@ import ctypes
from ctypes import wintypes from ctypes import wintypes
from jaraco.windows.error import handle_nonzero_success from jaraco.windows.error import handle_nonzero_success
# for type declarations
__import__('jaraco.windows.api.memory')
class DATA_BLOB(ctypes.Structure): class DATA_BLOB(ctypes.Structure):
r""" r"""
A data blob structure for use with MS DPAPI functions. A data blob structure for use with MS DPAPI functions.
@ -48,7 +53,7 @@ class DATA_BLOB(ctypes.Structure):
def get_data(self): def get_data(self):
"Get the data for this blob" "Get the data for this blob"
array = ctypes.POINTER(ctypes.c_char*len(self)) array = ctypes.POINTER(ctypes.c_char * len(self))
return ctypes.cast(self.data, array).contents.raw return ctypes.cast(self.data, array).contents.raw
def __len__(self): def __len__(self):
@ -65,6 +70,7 @@ class DATA_BLOB(ctypes.Structure):
""" """
ctypes.windll.kernel32.LocalFree(self.data) ctypes.windll.kernel32.LocalFree(self.data)
p_DATA_BLOB = ctypes.POINTER(DATA_BLOB) p_DATA_BLOB = ctypes.POINTER(DATA_BLOB)
_CryptProtectData = ctypes.windll.crypt32.CryptProtectData _CryptProtectData = ctypes.windll.crypt32.CryptProtectData
@ -76,7 +82,7 @@ _CryptProtectData.argtypes = [
ctypes.c_void_p, # POINTER(CRYPTPROTECT_PROMPTSTRUCT), # prompt struct ctypes.c_void_p, # POINTER(CRYPTPROTECT_PROMPTSTRUCT), # prompt struct
wintypes.DWORD, # flags wintypes.DWORD, # flags
p_DATA_BLOB, # data out p_DATA_BLOB, # data out
] ]
_CryptProtectData.restype = wintypes.BOOL _CryptProtectData.restype = wintypes.BOOL
_CryptUnprotectData = ctypes.windll.crypt32.CryptUnprotectData _CryptUnprotectData = ctypes.windll.crypt32.CryptUnprotectData
@ -88,15 +94,16 @@ _CryptUnprotectData.argtypes = [
ctypes.c_void_p, # POINTER(CRYPTPROTECT_PROMPTSTRUCT), # prompt struct ctypes.c_void_p, # POINTER(CRYPTPROTECT_PROMPTSTRUCT), # prompt struct
wintypes.DWORD, # flags wintypes.DWORD, # flags
p_DATA_BLOB, # data out p_DATA_BLOB, # data out
] ]
_CryptUnprotectData.restype = wintypes.BOOL _CryptUnprotectData.restype = wintypes.BOOL
CRYPTPROTECT_UI_FORBIDDEN = 0x01 CRYPTPROTECT_UI_FORBIDDEN = 0x01
def CryptProtectData( def CryptProtectData(
data, description=None, optional_entropy=None, data, description=None, optional_entropy=None,
prompt_struct=None, flags=0, prompt_struct=None, flags=0,
): ):
""" """
Encrypt data Encrypt data
""" """
@ -118,7 +125,9 @@ def CryptProtectData(
data_out.free() data_out.free()
return res return res
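
A hedged round-trip sketch using the wrapper signatures shown in this diff (module path assumed to be jaraco.windows.dpapi; Windows-only, and the secret is illustrative).

from jaraco.windows import dpapi

blob = dpapi.CryptProtectData(b'attack at dawn', description='example')
description, data = dpapi.CryptUnprotectData(blob)
assert data == b'attack at dawn'
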
-def CryptUnprotectData(data, optional_entropy=None, prompt_struct=None, flags=0):
+def CryptUnprotectData(
+        data, optional_entropy=None, prompt_struct=None, flags=0):
""" """
Returns a tuple of (description, data) where description is the Returns a tuple of (description, data) where description is the
the description that was passed to the CryptProtectData call and the description that was passed to the CryptProtectData call and

View file

@ -7,7 +7,7 @@ import ctypes
import ctypes.wintypes import ctypes.wintypes
import six import six
-winreg = six.moves.winreg
+from six.moves import winreg
from jaraco.ui.editor import EditableFile from jaraco.ui.editor import EditableFile
@ -19,17 +19,21 @@ from .registry import key_values as registry_key_values
def SetEnvironmentVariable(name, value): def SetEnvironmentVariable(name, value):
error.handle_nonzero_success(environ.SetEnvironmentVariable(name, value)) error.handle_nonzero_success(environ.SetEnvironmentVariable(name, value))
def ClearEnvironmentVariable(name): def ClearEnvironmentVariable(name):
error.handle_nonzero_success(environ.SetEnvironmentVariable(name, None)) error.handle_nonzero_success(environ.SetEnvironmentVariable(name, None))
def GetEnvironmentVariable(name): def GetEnvironmentVariable(name):
-    max_size = 2**15-1
+    max_size = 2**15 - 1
    buffer = ctypes.create_unicode_buffer(max_size)
-    error.handle_nonzero_success(environ.GetEnvironmentVariable(name, buffer, max_size))
+    error.handle_nonzero_success(
+        environ.GetEnvironmentVariable(name, buffer, max_size))
    return buffer.value
### ###
class RegisteredEnvironment(object): class RegisteredEnvironment(object):
""" """
Manages the environment variables as set in the Windows Registry. Manages the environment variables as set in the Windows Registry.
@ -86,8 +90,8 @@ class RegisteredEnvironment(object):
if value in values: if value in values:
return return
new_value = sep.join(values + [value]) new_value = sep.join(values + [value])
winreg.SetValueEx(class_.key, name, 0, winreg.REG_EXPAND_SZ, winreg.SetValueEx(
new_value) class_.key, name, 0, winreg.REG_EXPAND_SZ, new_value)
class_.notify() class_.notify()
@classmethod @classmethod
@ -141,24 +145,30 @@ class RegisteredEnvironment(object):
) )
error.handle_nonzero_success(res) error.handle_nonzero_success(res)
class MachineRegisteredEnvironment(RegisteredEnvironment): class MachineRegisteredEnvironment(RegisteredEnvironment):
path = r'SYSTEM\CurrentControlSet\Control\Session Manager\Environment' path = r'SYSTEM\CurrentControlSet\Control\Session Manager\Environment'
hklm = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) hklm = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE)
try: try:
key = winreg.OpenKey(hklm, path, 0, key = winreg.OpenKey(
hklm, path, 0,
winreg.KEY_READ | winreg.KEY_WRITE) winreg.KEY_READ | winreg.KEY_WRITE)
except WindowsError: except WindowsError:
key = winreg.OpenKey(hklm, path, 0, winreg.KEY_READ) key = winreg.OpenKey(hklm, path, 0, winreg.KEY_READ)
class UserRegisteredEnvironment(RegisteredEnvironment): class UserRegisteredEnvironment(RegisteredEnvironment):
hkcu = winreg.ConnectRegistry(None, winreg.HKEY_CURRENT_USER) hkcu = winreg.ConnectRegistry(None, winreg.HKEY_CURRENT_USER)
key = winreg.OpenKey(hkcu, 'Environment', 0, key = winreg.OpenKey(
hkcu, 'Environment', 0,
winreg.KEY_READ | winreg.KEY_WRITE) winreg.KEY_READ | winreg.KEY_WRITE)
def trim(s): def trim(s):
from textwrap import dedent from textwrap import dedent
return dedent(s).strip() return dedent(s).strip()
def enver(*args): def enver(*args):
""" """
%prog [<name>=[value]] %prog [<name>=[value]]
@ -201,17 +211,21 @@ def enver(*args):
dest='class_', dest='class_',
help="Use the current user's environment", help="Use the current user's environment",
) )
parser.add_option('-a', '--append', parser.add_option(
'-a', '--append',
action='store_true', default=False, action='store_true', default=False,
help="Append the value to any existing value (default for PATH and PATHEXT)",) help="Append the value to any existing value (default for PATH and PATHEXT)",
)
parser.add_option( parser.add_option(
'-r', '--replace', '-r', '--replace',
action='store_true', default=False, action='store_true', default=False,
help="Replace any existing value (used to override default append for PATH and PATHEXT)", help="Replace any existing value (used to override default append "
"for PATH and PATHEXT)",
) )
parser.add_option( parser.add_option(
'--remove-value', action='store_true', default=False, '--remove-value', action='store_true', default=False,
help="Remove any matching values from a semicolon-separated multi-value variable", help="Remove any matching values from a semicolon-separated "
"multi-value variable",
) )
parser.add_option( parser.add_option(
'-e', '--edit', action='store_true', default=False, '-e', '--edit', action='store_true', default=False,
@ -224,7 +238,7 @@ def enver(*args):
if args: if args:
parser.error("Too many parameters specified") parser.error("Too many parameters specified")
raise SystemExit(1) raise SystemExit(1)
if not '=' in param and not options.edit: if '=' not in param and not options.edit:
parser.error("Expected <name>= or <name>=<value>") parser.error("Expected <name>= or <name>=<value>")
raise SystemExit(2) raise SystemExit(2)
name, sep, value = param.partition('=') name, sep, value = param.partition('=')
@ -238,5 +252,6 @@ def enver(*args):
except IndexError: except IndexError:
options.class_.show() options.class_.show()
if __name__ == '__main__': if __name__ == '__main__':
enver() enver()
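The process-level helpers above wrap kernel32's SetEnvironmentVariable/GetEnvironmentVariable, while the RegisteredEnvironment classes edit the persistent values in the registry. A short sketch of the process-level calls (it assumes this module imports as jaraco.windows.environ):

# illustrative sketch; the module path is an assumption
from jaraco.windows import environ

environ.SetEnvironmentVariable('JW_DEMO', 'hello')
print(environ.GetEnvironmentVariable('JW_DEMO'))  # -> 'hello'
environ.ClearEnvironmentVariable('JW_DEMO')       # passes None to unset the variable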

@@ -6,9 +6,12 @@ import ctypes
import ctypes.wintypes

import six

builtins = six.moves.builtins

__import__('jaraco.windows.api.memory')


def format_system_message(errno):
    """
    Call FormatMessage with a system error number to retrieve

@@ -46,13 +49,16 @@ def format_system_message(errno):
class WindowsError(builtins.WindowsError):
    """
    More info about errors at
    http://msdn.microsoft.com/en-us/library/ms681381(VS.85).aspx
    """

    def __init__(self, value=None):
        if value is None:
            value = ctypes.windll.kernel32.GetLastError()
        strerror = format_system_message(value)
        if sys.version_info > (3, 3):
            args = 0, strerror, None, value
        else:
            args = value, strerror

@@ -72,6 +78,7 @@ class WindowsError(builtins.WindowsError):
    def __repr__(self):
        return '{self.__class__.__name__}({self.winerror})'.format(**vars())


def handle_nonzero_success(result):
    if result == 0:
        raise WindowsError()
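handle_nonzero_success is the error idiom used throughout the package: pass it the return value of a Win32 call that signals failure with zero and it raises the WindowsError subclass above, with the message resolved through FormatMessage. A hedged sketch against a raw kernel32 call:

# illustrative sketch using a real kernel32 entry point that returns BOOL
import ctypes
from jaraco.windows.error import handle_nonzero_success

res = ctypes.windll.kernel32.SetEnvironmentVariableW('JW_DEMO', 'hello')
handle_nonzero_success(res)  # raises jaraco.windows.error.WindowsError on failure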

@@ -8,6 +8,7 @@ import win32evtlogutil
error = win32api.error  # The error the evtlog module raises.


class EventLog(object):
    def __init__(self, name="Application", machine_name=None):
        self.machine_name = machine_name

@@ -29,6 +30,7 @@ class EventLog(object):
        win32evtlog.EVENTLOG_BACKWARDS_READ
        | win32evtlog.EVENTLOG_SEQUENTIAL_READ
    )

    def get_records(self, flags=_default_flags):
        with self:
            while True:

@@ -7,19 +7,24 @@ import sys
import operator
import collections
import functools
import stat
from ctypes import (
    POINTER, byref, cast, create_unicode_buffer,
    create_string_buffer, windll)
from ctypes.wintypes import LPWSTR

import nt
import posixpath

import six
from six.moves import builtins, filter, map

from jaraco.structures import binary

from jaraco.windows.error import WindowsError, handle_nonzero_success
import jaraco.windows.api.filesystem as api
from jaraco.windows import reparse


def mklink():
    """
    Like cmd.exe's mklink except it will infer directory status of the

@@ -27,7 +32,8 @@ def mklink():
    """
    from optparse import OptionParser
    parser = OptionParser(usage="usage: %prog [options] link target")
    parser.add_option(
        '-d', '--directory',
        help="Target is a directory (only necessary if not present)",
        action="store_true")
    options, args = parser.parse_args()

@@ -38,6 +44,7 @@ def mklink():
    symlink(target, link, options.directory)
    sys.stdout.write("Symbolic link created: %(link)s --> %(target)s\n" % vars())


def _is_target_a_directory(link, rel_target):
    """
    If creating a symlink from link to a target, determine if target

@@ -46,15 +53,20 @@ def _is_target_a_directory(link, rel_target):
    target = os.path.join(os.path.dirname(link), rel_target)
    return os.path.isdir(target)


def symlink(target, link, target_is_directory=False):
    """
    An implementation of os.symlink for Windows (Vista and greater)
    """
    target_is_directory = (
        target_is_directory or
        _is_target_a_directory(link, target)
    )
    # normalize the target (MS symlinks don't respect forward slashes)
    target = os.path.normpath(target)
    handle_nonzero_success(
        api.CreateSymbolicLink(link, target, target_is_directory))
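A short usage sketch of the pieces above (Windows only; creating symlinks may require elevation or the symlink privilege). It assumes the module imports as jaraco.windows.filesystem; islink() and readlink() are defined later in the same file.

# illustrative sketch; the import alias is an assumption
import os
from jaraco.windows import filesystem as fs

os.makedirs('demo_target', exist_ok=True)
fs.symlink('demo_target', 'demo_link')  # infers directory status via _is_target_a_directory
print(fs.islink('demo_link'))           # True
print(fs.readlink('demo_link'))         # substitute name stored in the reparse point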
def link(target, link):
    """

@@ -62,6 +74,7 @@ def link(target, link):
    """
    handle_nonzero_success(api.CreateHardLink(link, target, None))


def is_reparse_point(path):
    """
    Determine if the given path is a reparse point.

@@ -74,10 +87,12 @@ def is_reparse_point(path):
        and bool(res & api.FILE_ATTRIBUTE_REPARSE_POINT)
    )


def islink(path):
    "Determine if the given path is a symlink"
    return is_reparse_point(path) and is_symlink(path)


def _patch_path(path):
    """
    Paths have a max length of api.MAX_PATH characters (260). If a target path

@@ -86,13 +101,15 @@ def _patch_path(path):
    See http://msdn.microsoft.com/en-us/library/aa365247%28v=vs.85%29.aspx for
    details.
    """
    if path.startswith('\\\\?\\'):
        return path
    abs_path = os.path.abspath(path)
    if not abs_path[1] == ':':
        # python doesn't include the drive letter, but \\?\ requires it
        abs_path = os.getcwd()[:2] + abs_path
    return '\\\\?\\' + abs_path


def is_symlink(path):
    """
    Assuming path is a reparse point, determine if it's a symlink.

@@ -102,11 +119,13 @@ def is_symlink(path):
        return _is_symlink(next(find_files(path)))
    except WindowsError as orig_error:
        tmpl = "Error accessing {path}: {orig_error.message}"
        raise builtins.WindowsError(tmpl.format(**locals()))


def _is_symlink(find_data):
    return find_data.reserved[0] == api.IO_REPARSE_TAG_SYMLINK


def find_files(spec):
    """
    A pythonic wrapper around the FindFirstFile/FindNextFile win32 api.

@@ -133,11 +152,13 @@ def find_files(spec):
            error = WindowsError()
            if error.code == api.ERROR_NO_MORE_FILES:
                break
            else:
                raise error
    # todo: how to close handle when generator is destroyed?
    # hint: catch GeneratorExit
    windll.kernel32.FindClose(handle)


def get_final_path(path):
    """
    For a given path, determine the ultimate location of that path.

@@ -150,7 +171,9 @@ def get_final_path(path):
    trace_symlink_target instead.
    """
    desired_access = api.NULL
    share_mode = (
        api.FILE_SHARE_READ | api.FILE_SHARE_WRITE | api.FILE_SHARE_DELETE
    )
    security_attributes = api.LPSECURITY_ATTRIBUTES()  # NULL pointer
    hFile = api.CreateFile(
        path,

@@ -165,10 +188,12 @@ def get_final_path(path):
    if hFile == api.INVALID_HANDLE_VALUE:
        raise WindowsError()
    buf_size = api.GetFinalPathNameByHandle(
        hFile, LPWSTR(), 0, api.VOLUME_NAME_DOS)
    handle_nonzero_success(buf_size)
    buf = create_unicode_buffer(buf_size)
    result_length = api.GetFinalPathNameByHandle(
        hFile, buf, len(buf), api.VOLUME_NAME_DOS)
    assert result_length < len(buf)
    handle_nonzero_success(result_length)

@@ -176,22 +201,83 @@ def get_final_path(path):
    return buf[:result_length]
def compat_stat(path):
    """
    Generate stat as found on Python 3.2 and later.
    """
    stat = os.stat(path)
    info = get_file_info(path)
    # rewrite st_ino, st_dev, and st_nlink based on file info
    return nt.stat_result(
        (stat.st_mode,) +
        (info.file_index, info.volume_serial_number, info.number_of_links) +
        stat[4:]
    )


def samefile(f1, f2):
    """
    Backport of samefile from Python 3.2 with support for Windows.
    """
    return posixpath.samestat(compat_stat(f1), compat_stat(f2))


def get_file_info(path):
    # open the file the same way CPython does in posixmodule.c
    desired_access = api.FILE_READ_ATTRIBUTES
    share_mode = 0
    security_attributes = None
    creation_disposition = api.OPEN_EXISTING
    flags_and_attributes = (
        api.FILE_ATTRIBUTE_NORMAL |
        api.FILE_FLAG_BACKUP_SEMANTICS |
        api.FILE_FLAG_OPEN_REPARSE_POINT
    )
    template_file = None

    handle = api.CreateFile(
        path,
        desired_access,
        share_mode,
        security_attributes,
        creation_disposition,
        flags_and_attributes,
        template_file,
    )

    if handle == api.INVALID_HANDLE_VALUE:
        raise WindowsError()

    info = api.BY_HANDLE_FILE_INFORMATION()
    res = api.GetFileInformationByHandle(handle, info)
    handle_nonzero_success(res)
    handle_nonzero_success(api.CloseHandle(handle))
    return info
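compat_stat() folds the volume serial number and file index from BY_HANDLE_FILE_INFORMATION into st_dev/st_ino, so samefile() recognizes two hard links to the same file even where the platform stat does not. An illustrative sketch, under the same assumed jaraco.windows.filesystem import as above:

# illustrative sketch; the import alias is an assumption
from jaraco.windows import filesystem as fs

with open('original.txt', 'w') as strm:
    strm.write('data')
fs.link('original.txt', 'hardlink.txt')             # CreateHardLink wrapper above
print(fs.samefile('original.txt', 'hardlink.txt'))  # True
print(fs.compat_stat('original.txt').st_nlink)      # 2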
def GetBinaryType(filepath):
    res = api.DWORD()
    handle_nonzero_success(api._GetBinaryType(filepath, res))
    return res


def _make_null_terminated_list(obs):
    obs = _makelist(obs)
    if obs is None:
        return
    return u'\x00'.join(obs) + u'\x00\x00'


def _makelist(ob):
    if ob is None:
        return
    if not isinstance(ob, (list, tuple, set)):
        return [ob]
    return ob


def SHFileOperation(operation, from_, to=None, flags=[]):
    flags = functools.reduce(operator.or_, flags, 0)
    from_ = _make_null_terminated_list(from_)

@@ -201,6 +287,7 @@ def SHFileOperation(operation, from_, to=None, flags=[]):
    if res != 0:
        raise RuntimeError("SHFileOperation returned %d" % res)


def join(*paths):
    r"""
    Wrapper around os.path.join that works with Windows drive letters.

@@ -214,17 +301,17 @@ def join(*paths):
    drive = next(filter(None, reversed(drives)), '')
    return os.path.join(drive, os.path.join(*paths))


def resolve_path(target, start=os.path.curdir):
    r"""
    Find a path from start to target where target is relative to start.

    >>> tmp = str(getfixture('tmpdir_as_cwd'))
    >>> findpath('d:\\')
    'd:\\'
    >>> findpath('d:\\', tmp)
    'd:\\'
    >>> findpath('\\bar', 'd:\\')

@@ -239,11 +326,11 @@ def resolve_path(target, start=os.path.curdir):
    >>> findpath('\\baz', 'd:\\foo\\bar')  # fails with '\\baz'
    'd:\\baz'
    >>> os.path.abspath(findpath('\\bar')).lower()
    'c:\\bar'
    >>> os.path.abspath(findpath('bar'))
    '...\\bar'
    >>> findpath('..', 'd:\\foo\\bar')
    'd:\\foo'

@@ -254,8 +341,10 @@ def resolve_path(target, start=os.path.curdir):
    """
    return os.path.normpath(join(start, target))


findpath = resolve_path
def trace_symlink_target(link):
    """
    Given a file that is known to be a symlink, trace it to its ultimate

@@ -273,6 +362,7 @@ def trace_symlink_target(link):
        link = resolve_path(link, orig)
    return link


def readlink(link):
    """
    readlink(link) -> target

@@ -291,7 +381,8 @@ def readlink(link):
    if handle == api.INVALID_HANDLE_VALUE:
        raise WindowsError()

    res = reparse.DeviceIoControl(
        handle, api.FSCTL_GET_REPARSE_POINT, None, 10240)

    bytes = create_string_buffer(res)
    p_rdb = cast(bytes, POINTER(api.REPARSE_DATA_BUFFER))

@@ -302,6 +393,7 @@ def readlink(link):
    handle_nonzero_success(api.CloseHandle(handle))
    return rdb.get_substitute_name()


def patch_os_module():
    """
    jaraco.windows provides the os.symlink and os.readlink functions.

@@ -313,6 +405,7 @@ def patch_os_module():
    if not hasattr(os, 'readlink'):
        os.readlink = readlink


def find_symlinks(root):
    for dirpath, dirnames, filenames in os.walk(root):
        for name in dirnames + filenames:

@@ -323,6 +416,7 @@ def find_symlinks(root):
            if name in dirnames:
                dirnames.remove(name)


def find_symlinks_cmd():
    """
    %prog [start-path]

@@ -333,7 +427,8 @@ def find_symlinks_cmd():
    from textwrap import dedent
    parser = OptionParser(usage=dedent(find_symlinks_cmd.__doc__).strip())
    options, args = parser.parse_args()
    if not args:
        args = ['.']
    root = args.pop()
    if args:
        parser.error("unexpected argument(s)")

@@ -346,8 +441,19 @@ def find_symlinks_cmd():
    except KeyboardInterrupt:
        pass


@six.add_metaclass(binary.BitMask)
class FileAttributes(int):
    # extract the values from the stat module on Python 3.5
    # and later.
    locals().update(
        (name.split('FILE_ATTRIBUTES_')[1].lower(), value)
        for name, value in vars(stat).items()
        if name.startswith('FILE_ATTRIBUTES_')
    )

    # For Python 3.4 and earlier, define the constants here
    archive = 0x20
    compressed = 0x800
    hidden = 0x2

@@ -364,11 +470,16 @@ class FileAttributes(int):
    temporary = 0x100
    virtual = 0x10000

    @classmethod
    def get(cls, filepath):
        attrs = api.GetFileAttributes(filepath)
        if attrs == api.INVALID_FILE_ATTRIBUTES:
            raise WindowsError()
        return cls(attrs)


GetFileAttributes = FileAttributes.get


def SetFileAttributes(filepath, *attrs):
    """

@@ -382,8 +493,8 @@ def SetFileAttributes(filepath, *attrs):
    """
    nice_names = collections.defaultdict(
        lambda key: key,
        hidden='FILE_ATTRIBUTE_HIDDEN',
        read_only='FILE_ATTRIBUTE_READONLY',
    )
    flags = (getattr(api, nice_names[attr], attr) for attr in attrs)
    flags = functools.reduce(operator.or_, flags)
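GetFileAttributes now returns a FileAttributes bitmask (an int subclass whose flags come from the stat module on Python 3.5+ or from the fallback constants above), and SetFileAttributes accepts the friendly names mapped through nice_names. A brief sketch, reusing the file from the previous sketch and the same assumed import:

# illustrative sketch; the import alias is an assumption
from jaraco.windows import filesystem as fs

attrs = fs.GetFileAttributes('original.txt')
print(bool(attrs & fs.FileAttributes.hidden))           # is the hidden bit set?
fs.SetFileAttributes('original.txt', 'hidden', 'read_only')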

@@ -0,0 +1,109 @@
from __future__ import unicode_literals

import os.path


# realpath taken from https://bugs.python.org/file38057/issue9949-v4.patch
def realpath(path):
    if isinstance(path, str):
        prefix = '\\\\?\\'
        unc_prefix = prefix + 'UNC'
        new_unc_prefix = '\\'
        cwd = os.getcwd()
    else:
        prefix = b'\\\\?\\'
        unc_prefix = prefix + b'UNC'
        new_unc_prefix = b'\\'
        cwd = os.getcwdb()
    had_prefix = path.startswith(prefix)
    path, ok = _resolve_path(cwd, path, {})
    # The path returned by _getfinalpathname will always start with \\?\ -
    # strip off that prefix unless it was already provided on the original
    # path.
    if not had_prefix:
        # For UNC paths, the prefix will actually be \\?\UNC - handle that
        # case as well.
        if path.startswith(unc_prefix):
            path = new_unc_prefix + path[len(unc_prefix):]
        elif path.startswith(prefix):
            path = path[len(prefix):]
    return path


def _resolve_path(path, rest, seen):
    # Windows normalizes the path before resolving symlinks; be sure to
    # follow the same behavior.
    rest = os.path.normpath(rest)
    if isinstance(rest, str):
        sep = '\\'
    else:
        sep = b'\\'
    if os.path.isabs(rest):
        drive, rest = os.path.splitdrive(rest)
        path = drive + sep
        rest = rest[1:]
    while rest:
        name, _, rest = rest.partition(sep)
        new_path = os.path.join(path, name) if path else name
        if os.path.exists(new_path):
            if not rest:
                # The whole path exists. Resolve it using the OS.
                path = os.path._getfinalpathname(new_path)
            else:
                # The OS can resolve `new_path`; keep traversing the path.
                path = new_path
        elif not os.path.lexists(new_path):
            # `new_path` does not exist on the filesystem at all. Use the
            # OS to resolve `path`, if it exists, and then append the
            # remainder.
            if os.path.exists(path):
                path = os.path._getfinalpathname(path)
            rest = os.path.join(name, rest) if rest else name
            return os.path.join(path, rest), True
        else:
            # We have a symbolic link that the OS cannot resolve. Try to
            # resolve it ourselves.
            # On Windows, symbolic link resolution can be partially or
            # fully disabled [1]. The end result of a disabled symlink
            # appears the same as a broken symlink (lexists() returns True
            # but exists() returns False). And in both cases, the link can
            # still be read using readlink(). Call stat() and check the
            # resulting error code to ensure we don't circumvent the
            # Windows symbolic link restrictions.
            # [1] https://technet.microsoft.com/en-us/library/cc754077.aspx
            try:
                os.stat(new_path)
            except OSError as e:
                # WinError 1463: The symbolic link cannot be followed
                # because its type is disabled.
                if e.winerror == 1463:
                    raise

            key = os.path.normcase(new_path)
            if key in seen:
                # This link has already been seen; try to use the
                # previously resolved value.
                path = seen[key]
                if path is None:
                    # It has not yet been resolved, which means we must
                    # have a symbolic link loop. Return what we have
                    # resolved so far plus the remainder of the path (who
                    # cares about the Zen of Python?).
                    path = os.path.join(new_path, rest) if rest else new_path
                    return path, False
            else:
                # Mark this link as in the process of being resolved.
                seen[key] = None
                # Try to resolve it.
                path, ok = _resolve_path(path, os.readlink(new_path), seen)
                if ok:
                    # Resolution succeeded; store the resolved value.
                    seen[key] = path
                else:
                    # Resolution failed; punt.
                    return (os.path.join(path, rest) if rest else path), False
    return path, True
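A behavioural sketch of the backport (Windows only): realpath() resolves chains of symlinks through os.path._getfinalpathname where the OS can help, and falls back to manual resolution for broken or disabled links. The module path below is an assumption based on the file layout.

# illustrative sketch; module path assumed
import os
from jaraco.windows.filesystem.backports import realpath

os.makedirs('real_dir', exist_ok=True)
os.symlink('real_dir', 'link_dir', target_is_directory=True)
print(realpath('link_dir'))  # absolute path ending in ...\real_dir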

@@ -17,6 +17,8 @@ from threading import Thread
import itertools
import logging

import six
from more_itertools.recipes import consume

import jaraco.text

@@ -25,9 +27,11 @@ from jaraco.windows.api import event
log = logging.getLogger(__name__)


class NotifierException(Exception):
    pass


class FileFilter(object):
    def set_root(self, root):
        self.root = root

@@ -35,9 +39,11 @@ class FileFilter(object):
    def _get_file_path(self, filename):
        try:
            filename = os.path.join(self.root, filename)
        except AttributeError:
            pass
        return filename


class ModifiedTimeFilter(FileFilter):
    """
    Returns true for each call where the modified time of the file is after

@@ -53,6 +59,7 @@ class ModifiedTimeFilter(FileFilter):
        log.debug('{filepath} last modified at {last_mod}.'.format(**vars()))
        return last_mod > self.cutoff


class PatternFilter(FileFilter):
    """
    Filter that returns True for files that match pattern (a regular

@@ -60,13 +67,14 @@ class PatternFilter(FileFilter):
    """
    def __init__(self, pattern):
        self.pattern = (
            re.compile(pattern) if isinstance(pattern, six.string_types)
            else pattern
        )

    def __call__(self, file):
        return bool(self.pattern.match(file, re.I))


class GlobFilter(PatternFilter):
    """
    Filter that returns True for files that match the pattern (a glob

@@ -102,6 +110,7 @@ class AggregateFilter(FileFilter):
    def __call__(self, file):
        return all(fil(file) for fil in self.filters)


class OncePerModFilter(FileFilter):
    def __init__(self):
        self.history = list()

@@ -115,15 +124,18 @@ class OncePerModFilter(FileFilter):
        del self.history[-50:]
        return result


def files_with_path(files, path):
    return (os.path.join(path, file) for file in files)


def get_file_paths(walk_result):
    root, dirs, files = walk_result
    return files_with_path(files, root)


class Notifier(object):
    def __init__(self, root='.', filters=[]):
        # assign the root, verify it exists
        self.root = root
        if not os.path.isdir(self.root):

@@ -138,7 +150,8 @@ class Notifier(object):
    def __del__(self):
        try:
            fs.FindCloseChangeNotification(self.hChange)
        except Exception:
            pass

    def _get_change_handle(self):
        # set up to monitor the directory tree specified

@@ -151,8 +164,8 @@ class Notifier(object):
        # make sure it worked; if not, bail
        INVALID_HANDLE_VALUE = fs.INVALID_HANDLE_VALUE
        if self.hChange == INVALID_HANDLE_VALUE:
            raise NotifierException(
                'Could not set up directory change notification')

    @staticmethod
    def _filtered_walk(path, file_filter):

@@ -171,6 +184,7 @@ class Notifier(object):
    def quit(self):
        event.SetEvent(self.quit_event)


class BlockingNotifier(Notifier):
    @staticmethod

@@ -215,17 +229,18 @@ class BlockingNotifier(Notifier):
        result = next(results)
        return result


class ThreadedNotifier(BlockingNotifier, Thread):
    r"""
    ThreadedNotifier provides a simple interface that calls the handler
    for each file rooted in root that passes the filters. It runs as its own
    thread, so must be started as such::

        notifier = ThreadedNotifier('c:\\', handler = StreamHandler())
        notifier.start()
        C:\Autoexec.bat changed.
    """

    def __init__(self, root='.', filters=[], handler=lambda file: None):
        # init notifier stuff
        BlockingNotifier.__init__(self, root, filters)
        # init thread stuff

@@ -242,13 +257,14 @@ class ThreadedNotifier(BlockingNotifier, Thread):
        for file in self.get_changed_files():
            self.handle(file)


class StreamHandler(object):
    """
    StreamHandler: a sample handler object for use with the threaded
    notifier that will announce by writing to the supplied stream
    (stdout by default) the name of the file.
    """
    def __init__(self, output=sys.stdout):
        self.output = output

    def __call__(self, filename):
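Tying the watcher pieces together: a ThreadedNotifier walks the tree for files that pass its filters and hands each changed path to the handler. A sketch (module path assumed; runs until quit() is called):

# illustrative sketch; module path assumed from the file shown above
from jaraco.windows.filesystem.change import (
    GlobFilter, StreamHandler, ThreadedNotifier)

notifier = ThreadedNotifier(
    '.', filters=[GlobFilter('*.py')], handler=StreamHandler())
notifier.start()  # reports changed *.py files to stdout until notifier.quit()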

@@ -14,20 +14,21 @@ from jaraco.windows.api import errors, inet
def GetAdaptersAddresses():
    size = ctypes.c_ulong()
    res = inet.GetAdaptersAddresses(0, 0, None, None, size)
    if res != errors.ERROR_BUFFER_OVERFLOW:
        raise RuntimeError("Error getting structure length (%d)" % res)
    print(size.value)
    pointer_type = ctypes.POINTER(inet.IP_ADAPTER_ADDRESSES)
    buffer = ctypes.create_string_buffer(size.value)
    struct_p = ctypes.cast(buffer, pointer_type)
    res = inet.GetAdaptersAddresses(0, 0, None, struct_p, size)
    if res != errors.NO_ERROR:
        raise RuntimeError("Error retrieving table (%d)" % res)
    while struct_p:
        yield struct_p.contents
        struct_p = struct_p.contents.next


class AllocatedTable(object):
    """
    Both the interface table and the ip address table use the same

@@ -79,20 +80,23 @@ class AllocatedTable(object):
        on the table size.
        """
        table = self.get_table()
        entries_array = self.row_structure * table.num_entries
        pointer_type = ctypes.POINTER(entries_array)
        return ctypes.cast(table.entries, pointer_type).contents


class InterfaceTable(AllocatedTable):
    method = inet.GetIfTable
    structure = inet.MIB_IFTABLE
    row_structure = inet.MIB_IFROW


class AddressTable(AllocatedTable):
    method = inet.GetIpAddrTable
    structure = inet.MIB_IPADDRTABLE
    row_structure = inet.MIB_IPADDRROW


class AddressManager(object):
    @staticmethod
    def hardware_address_to_string(addr):

@@ -100,7 +104,8 @@ class AddressManager(object):
        return ':'.join(hex_bytes)

    def get_host_mac_address_strings(self):
        return (
            self.hardware_address_to_string(addr)
            for addr in self.get_host_mac_addresses())

    def get_host_ip_address_strings(self):
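The AddressManager methods shown here build on get_host_mac_addresses(), defined elsewhere in the same module. A minimal sketch (module path assumed):

# illustrative sketch; module path assumed
from jaraco.windows import inet

manager = inet.AddressManager()
for mac in manager.get_host_mac_address_strings():
    print(mac)  # e.g. '00:1a:2b:3c:4d:5e'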

@@ -2,6 +2,7 @@ import ctypes
from .api import library


def find_lib(lib):
    r"""
    Find the DLL for a given library.

@@ -3,6 +3,7 @@ from ctypes import WinError
from .api import memory


class LockedMemory(object):
    def __init__(self, handle):
        self.handle = handle

@@ -5,6 +5,7 @@ import six
from .error import handle_nonzero_success
from .api import memory


class MemoryMap(object):
    """
    A memory map object which can have security attributes overridden.

@@ -10,13 +10,15 @@ import itertools
import six


class CookieMonster(object):
    "Read cookies out of a user's IE cookies file"

    @property
    def cookie_dir(self):
        import _winreg as winreg
        key = winreg.OpenKeyEx(
            winreg.HKEY_CURRENT_USER, 'Software'
            '\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders')
        cookie_dir, type = winreg.QueryValueEx(key, 'Cookies')
        return cookie_dir

@@ -24,10 +26,12 @@ class CookieMonster(object):
    def entries(self, filename):
        with open(os.path.join(self.cookie_dir, filename)) as cookie_file:
            while True:
                entry = itertools.takewhile(
                    self.is_not_cookie_delimiter,
                    cookie_file)
                entry = list(map(six.text_type.rstrip, entry))
                if not entry:
                    break
                cookie = self.make_cookie(*entry)
                yield cookie

@@ -36,7 +40,8 @@ class CookieMonster(object):
        return s != '*\n'

    @staticmethod
    def make_cookie(
            key, value, domain, flags, ExpireLow, ExpireHigh,
            CreateLow, CreateHigh):
        expires = (int(ExpireHigh) << 32) | int(ExpireLow)
        created = (int(CreateHigh) << 32) | int(CreateLow)
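The ExpireLow/ExpireHigh and CreateLow/CreateHigh fields are the two 32-bit halves of a Windows FILETIME (100-nanosecond intervals since 1601-01-01 UTC); the shift-and-or above simply recombines them into one 64-bit value. A standalone helper, not part of the module, for turning that value into a datetime:

import datetime

def filetime_to_datetime(filetime):
    # FILETIME counts 100 ns ticks from the Windows epoch, 1601-01-01 UTC
    epoch = datetime.datetime(1601, 1, 1)
    return epoch + datetime.timedelta(microseconds=filetime // 10)

print(filetime_to_datetime(116444736000000000))  # 1970-01-01 00:00:00 (the Unix epoch)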

@@ -7,7 +7,9 @@ __all__ = ('AddConnection')
from jaraco.windows.error import WindowsError
from .api import net


def AddConnection(
        remote_name, type=net.RESOURCETYPE_ANY, local_name=None,
        provider_name=None, user=None, password=None, flags=0):
    resource = net.NETRESOURCE(
        type=type,

@@ -1,10 +1,9 @@
# -*- coding: utf-8 -*-
from __future__ import print_function

import itertools
import contextlib

from more_itertools.recipes import consume, unique_justseen

try:

@@ -15,11 +14,13 @@ except ImportError:
from jaraco.windows.error import handle_nonzero_success
from .api import power


def GetSystemPowerStatus():
    stat = power.SYSTEM_POWER_STATUS()
    handle_nonzero_success(GetSystemPowerStatus(stat))
    return stat


def _init_power_watcher():
    global power_watcher
    if 'power_watcher' not in globals():

@@ -27,18 +28,22 @@ def _init_power_watcher():
        query = 'SELECT * from Win32_PowerManagementEvent'
        power_watcher = wmi.ExecNotificationQuery(query)


def get_power_management_events():
    _init_power_watcher()
    while True:
        yield power_watcher.NextEvent()


def wait_for_power_status_change():
    EVT_POWER_STATUS_CHANGE = 10

    def not_power_status_change(evt):
        return evt.EventType != EVT_POWER_STATUS_CHANGE

    events = get_power_management_events()
    consume(itertools.takewhile(not_power_status_change, events))


def get_unique_power_states():
    """
    Just like get_power_states, but ensures values are returned only

@@ -46,6 +51,7 @@ def get_unique_power_states():
    """
    return unique_justseen(get_power_states())


def get_power_states():
    """
    Continuously return the power state of the system when it changes.

@@ -57,6 +63,7 @@ def get_power_states():
        yield state.ac_line_status_string
        wait_for_power_status_change()


@contextlib.contextmanager
def no_sleep():
    """

@@ -7,24 +7,31 @@ from .api import security
from .api import privilege
from .api import process


def get_process_token():
    """
    Get the current process token
    """
    token = wintypes.HANDLE()
    res = process.OpenProcessToken(
        process.GetCurrentProcess(), process.TOKEN_ALL_ACCESS, token)
    if not res > 0:
        raise RuntimeError("Couldn't get process token")
    return token


def get_symlink_luid():
    """
    Get the LUID for the SeCreateSymbolicLinkPrivilege
    """
    symlink_luid = privilege.LUID()
    res = privilege.LookupPrivilegeValue(
        None, "SeCreateSymbolicLinkPrivilege", symlink_luid)
    if not res > 0:
        raise RuntimeError("Couldn't lookup privilege value")
    return symlink_luid


def get_privilege_information():
    """
    Get all privileges associated with the current process.

@@ -51,9 +58,11 @@ def get_privilege_information():
    res = privilege.GetTokenInformation(*params)
    assert res > 0, "Error in second GetTokenInformation (%d)" % res

    privileges = ctypes.cast(
        buffer, ctypes.POINTER(privilege.TOKEN_PRIVILEGES)).contents
    return privileges


def report_privilege_information():
    """
    Report all privilege information assigned to the current process.

@@ -62,6 +71,7 @@ def report_privilege_information():
    print("found {0} privileges".format(privileges.count))
    tuple(map(print, privileges))


def enable_symlink_privilege():
    """
    Try to assign the symlink privilege to the current process token.

@@ -84,9 +94,11 @@ def enable_symlink_privilege():
    ERROR_NOT_ALL_ASSIGNED = 1300
    return ctypes.windll.kernel32.GetLastError() != ERROR_NOT_ALL_ASSIGNED


class PolicyHandle(wintypes.HANDLE):
    pass


class LSA_UNICODE_STRING(ctypes.Structure):
    _fields_ = [
        ('length', ctypes.c_ushort),

@@ -94,15 +106,20 @@ class LSA_UNICODE_STRING(ctypes.Structure):
        ('buffer', ctypes.wintypes.LPWSTR),
    ]


def OpenPolicy(system_name, object_attributes, access_mask):
    policy = PolicyHandle()
    raise NotImplementedError(
        "Need to construct structures for parameters "
        "(see http://msdn.microsoft.com/en-us/library/windows"
        "/desktop/aa378299%28v=vs.85%29.aspx)")
    res = ctypes.windll.advapi32.LsaOpenPolicy(
        system_name, object_attributes,
        access_mask, ctypes.byref(policy))
    assert res == 0, "Error status {res}".format(**vars())
    return policy


def grant_symlink_privilege(who, machine=''):
    """
    Grant the 'create symlink' privilege to who.

@@ -113,10 +130,13 @@ def grant_symlink_privilege(who, machine=''):
    policy = OpenPolicy(machine, flags)
    return policy


def main():
    assigned = enable_symlink_privilege()
    msg = ['failure', 'success'][assigned]
    print("Symlink privilege assignment completed with {0}".format(msg))


if __name__ == '__main__':
    main()
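The typical flow is exactly what main() does: try to enable the privilege on the current process token and then inspect what the token actually holds. A sketch (module path assumed):

# illustrative sketch; module path assumed
from jaraco.windows import privilege

if privilege.enable_symlink_privilege():
    print("SeCreateSymbolicLinkPrivilege enabled for this process")
privilege.report_privilege_information()  # prints each privilege on the token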

@@ -3,6 +3,7 @@ from itertools import count
import six
winreg = six.moves.winreg


def key_values(key):
    for index in count():
        try:

@@ -10,6 +11,7 @@ def key_values(key):
        except WindowsError:
            break


def key_subkeys(key):
    for index in count():
        try:

@@ -5,7 +5,9 @@ import ctypes.wintypes
from .error import handle_nonzero_success
from .api import filesystem


def DeviceIoControl(
        device, io_control_code, in_buffer, out_buffer, overlapped=None):
    if overlapped is not None:
        raise NotImplementedError("overlapped handles not yet supported")

@@ -3,20 +3,24 @@ import ctypes.wintypes
from jaraco.windows.error import handle_nonzero_success
from .api import security


def GetTokenInformation(token, information_class):
    """
    Given a token, get the token information for it.
    """
    data_size = ctypes.wintypes.DWORD()
    ctypes.windll.advapi32.GetTokenInformation(
        token, information_class.num,
        0, 0, ctypes.byref(data_size))
    data = ctypes.create_string_buffer(data_size.value)
    handle_nonzero_success(ctypes.windll.advapi32.GetTokenInformation(
        token,
        information_class.num,
        ctypes.byref(data), ctypes.sizeof(data),
        ctypes.byref(data_size)))
    return ctypes.cast(data, ctypes.POINTER(security.TOKEN_USER)).contents


def OpenProcessToken(proc_handle, access):
    result = ctypes.wintypes.HANDLE()
    proc_handle = ctypes.wintypes.HANDLE(proc_handle)

@@ -24,6 +28,7 @@ def OpenProcessToken(proc_handle, access):
        proc_handle, access, ctypes.byref(result)))
    return result


def get_current_user():
    """
    Return a TOKEN_USER for the owner of this process.

@@ -34,6 +39,7 @@ def get_current_user():
    )
    return GetTokenInformation(process, security.TOKEN_USER)


def get_security_attributes_for_user(user=None):
    """
    Return a SECURITY_ATTRIBUTES structure with the SID set to the

@@ -42,7 +48,8 @@ def get_security_attributes_for_user(user=None):
    if user is None:
        user = get_current_user()

    assert isinstance(user, security.TOKEN_USER), (
        "user must be TOKEN_USER instance")

    SD = security.SECURITY_DESCRIPTOR()
    SA = security.SECURITY_ATTRIBUTES()

@@ -51,8 +58,10 @@ def get_security_attributes_for_user(user=None):
    SA.descriptor = SD
    SA.bInheritHandle = 1

    ctypes.windll.advapi32.InitializeSecurityDescriptor(
        ctypes.byref(SD),
        security.SECURITY_DESCRIPTOR.REVISION)
    ctypes.windll.advapi32.SetSecurityDescriptorOwner(
        ctypes.byref(SD),
        user.SID, 0)
    return SA

@@ -1,7 +1,8 @@
"""
Windows Services support for controlling Windows Services.
Based on http://code.activestate.com
/recipes/115875-controlling-windows-services/
"""

from __future__ import print_function

@@ -13,6 +14,7 @@ import win32api
import win32con
import win32service


class Service(object):
    """
    The Service Class is used for controlling Windows

@@ -47,7 +49,8 @@ class Service(object):
    pause: Pauses service (Only if service supports feature).
    resume: Resumes service that has been paused.
    status: Queries current status of service.
    fetchstatus: Continually queries service until requested
        status(STARTING, RUNNING,
        STOPPING & STOPPED) is met or timeout value(in seconds) reached.
        Default timeout value is infinite.
    infotype: Queries service for process type. (Single, shared and/or

@@ -64,18 +67,21 @@ class Service(object):
    def __init__(self, service, machinename=None, dbname=None):
        self.userv = service
        self.scmhandle = win32service.OpenSCManager(
            machinename, dbname, win32service.SC_MANAGER_ALL_ACCESS)
        self.sserv, self.lserv = self.getname()
        if (self.sserv or self.lserv) is None:
            sys.exit()
        self.handle = win32service.OpenService(
            self.scmhandle, self.sserv, win32service.SERVICE_ALL_ACCESS)
        self.sccss = "SYSTEM\\CurrentControlSet\\Services\\"

    def start(self):
        win32service.StartService(self.handle, None)

    def stop(self):
        self.stat = win32service.ControlService(
            self.handle, win32service.SERVICE_CONTROL_STOP)

    def restart(self):
        self.stop()

@@ -83,29 +89,31 @@ class Service(object):
        self.start()

    def pause(self):
        self.stat = win32service.ControlService(
            self.handle, win32service.SERVICE_CONTROL_PAUSE)

    def resume(self):
        self.stat = win32service.ControlService(
            self.handle, win32service.SERVICE_CONTROL_CONTINUE)

    def status(self, prn=0):
        self.stat = win32service.QueryServiceStatus(self.handle)
        if self.stat[1] == win32service.SERVICE_STOPPED:
            if prn == 1:
                print("The %s service is stopped." % self.lserv)
            else:
                return "STOPPED"
        elif self.stat[1] == win32service.SERVICE_START_PENDING:
            if prn == 1:
                print("The %s service is starting." % self.lserv)
            else:
                return "STARTING"
        elif self.stat[1] == win32service.SERVICE_STOP_PENDING:
            if prn == 1:
                print("The %s service is stopping." % self.lserv)
            else:
                return "STOPPING"
        elif self.stat[1] == win32service.SERVICE_RUNNING:
            if prn == 1:
                print("The %s service is running." % self.lserv)
            else:

@@ -116,6 +124,7 @@ class Service(object):
        if timeout is not None:
            timeout = int(timeout)
            timeout *= 2

        def to(timeout):
            time.sleep(.5)
            if timeout is not None:

@@ -127,11 +136,11 @@ class Service(object):
        if self.fstatus == "STOPPED":
            while 1:
                self.stat = win32service.QueryServiceStatus(self.handle)
                if self.stat[1] == win32service.SERVICE_STOPPED:
                    self.fstate = "STOPPED"
                    break
                else:
                    timeout = to(timeout)
                    if timeout == "TO":
                        return "TIMEDOUT"
                        break
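A sketch of driving the wrapper above (requires pywin32 and sufficient rights; 'Spooler' is just an example service name and the module path is assumed):

# illustrative sketch; module path and service name are assumptions
from jaraco.windows.services import Service

svc = Service('Spooler')     # short or display name; resolved via getname()
print(svc.status())          # 'RUNNING', 'STOPPED', 'STARTING' or 'STOPPING'
svc.stop()
svc.start()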

@@ -1,10 +1,12 @@
from .api import shell


def get_recycle_bin_confirm():
    settings = shell.SHELLSTATE()
    shell.SHGetSetSettings(settings, shell.SSF_NOCONFIRMRECYCLE, False)
    return not settings.no_confirm_recycle


def set_recycle_bin_confirm(confirm=False):
    settings = shell.SHELLSTATE()
    settings.no_confirm_recycle = not confirm

@@ -14,6 +14,7 @@ from jaraco.windows.api import event as win32event
__author__ = 'Jason R. Coombs <jaraco@jaraco.com>'


class WaitableTimer:
    """
    t = WaitableTimer()

@@ -32,12 +33,12 @@ class WaitableTimer:
    def stop(self):
        win32event.SetEvent(self.stop_event)

    def wait_for_signal(self, timeout=None):
        """
        wait for the signal; return after the signal has occurred or the
        timeout in seconds elapses.
        """
        timeout_ms = int(timeout * 1000) if timeout else win32event.INFINITE
        win32event.WaitForSingleObject(self.signal_event, timeout_ms)

    def _signal_loop(self, due_time, period):

@@ -54,14 +55,14 @@ class WaitableTimer:
        except Exception:
            pass

    def _wait(self, seconds):
        milliseconds = int(seconds * 1000)
        if milliseconds > 0:
            res = win32event.WaitForSingleObject(self.stop_event, milliseconds)
            if res == win32event.WAIT_OBJECT_0:
                raise Exception
            if res == win32event.WAIT_TIMEOUT:
                pass
        win32event.SetEvent(self.signal_event)

    @staticmethod

View file

@ -8,6 +8,7 @@ from ctypes.wintypes import WORD, WCHAR, BOOL, LONG
from jaraco.windows.util import Extended from jaraco.windows.util import Extended
from jaraco.collections import RangeMap from jaraco.collections import RangeMap
class AnyDict(object): class AnyDict(object):
"A dictionary that returns the same value regardless of key" "A dictionary that returns the same value regardless of key"
@ -17,6 +18,7 @@ class AnyDict(object):
def __getitem__(self, key): def __getitem__(self, key):
return self.value return self.value
class SYSTEMTIME(Extended, ctypes.Structure): class SYSTEMTIME(Extended, ctypes.Structure):
_fields_ = [ _fields_ = [
('year', WORD), ('year', WORD),
@ -29,6 +31,7 @@ class SYSTEMTIME(Extended, ctypes.Structure):
('millisecond', WORD), ('millisecond', WORD),
] ]
class REG_TZI_FORMAT(Extended, ctypes.Structure): class REG_TZI_FORMAT(Extended, ctypes.Structure):
_fields_ = [ _fields_ = [
('bias', LONG), ('bias', LONG),
@ -38,17 +41,19 @@ class REG_TZI_FORMAT(Extended, ctypes.Structure):
('daylight_start', SYSTEMTIME), ('daylight_start', SYSTEMTIME),
] ]
class TIME_ZONE_INFORMATION(Extended, ctypes.Structure): class TIME_ZONE_INFORMATION(Extended, ctypes.Structure):
_fields_ = [ _fields_ = [
('bias', LONG), ('bias', LONG),
('standard_name', WCHAR*32), ('standard_name', WCHAR * 32),
('standard_start', SYSTEMTIME), ('standard_start', SYSTEMTIME),
('standard_bias', LONG), ('standard_bias', LONG),
('daylight_name', WCHAR*32), ('daylight_name', WCHAR * 32),
('daylight_start', SYSTEMTIME), ('daylight_start', SYSTEMTIME),
('daylight_bias', LONG), ('daylight_bias', LONG),
] ]
class DYNAMIC_TIME_ZONE_INFORMATION(TIME_ZONE_INFORMATION): class DYNAMIC_TIME_ZONE_INFORMATION(TIME_ZONE_INFORMATION):
""" """
Because the structure of the DYNAMIC_TIME_ZONE_INFORMATION extends Because the structure of the DYNAMIC_TIME_ZONE_INFORMATION extends
@ -70,7 +75,7 @@ class DYNAMIC_TIME_ZONE_INFORMATION(TIME_ZONE_INFORMATION):
""" """
_fields_ = [ _fields_ = [
# ctypes automatically includes the fields from the parent # ctypes automatically includes the fields from the parent
('key_name', WCHAR*128), ('key_name', WCHAR * 128),
('dynamic_daylight_time_disabled', BOOL), ('dynamic_daylight_time_disabled', BOOL),
] ]
@ -89,6 +94,7 @@ class DYNAMIC_TIME_ZONE_INFORMATION(TIME_ZONE_INFORMATION):
kwargs[field_name] = arg kwargs[field_name] = arg
super(DYNAMIC_TIME_ZONE_INFORMATION, self).__init__(*self_args, **kwargs) super(DYNAMIC_TIME_ZONE_INFORMATION, self).__init__(*self_args, **kwargs)
class Info(DYNAMIC_TIME_ZONE_INFORMATION): class Info(DYNAMIC_TIME_ZONE_INFORMATION):
""" """
A time zone definition class based on the win32 A time zone definition class based on the win32
@ -126,7 +132,7 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION):
def __init_from_bytes(self, bytes, **kwargs): def __init_from_bytes(self, bytes, **kwargs):
reg_tzi = REG_TZI_FORMAT() reg_tzi = REG_TZI_FORMAT()
# todo: use buffer API in Python 3 # todo: use buffer API in Python 3
buffer = buffer(bytes) buffer = memoryview(bytes)
ctypes.memmove(ctypes.addressof(reg_tzi), buffer, len(buffer)) ctypes.memmove(ctypes.addressof(reg_tzi), buffer, len(buffer))
self.__init_from_reg_tzi(self, reg_tzi, **kwargs) self.__init_from_reg_tzi(self, reg_tzi, **kwargs)
@ -146,12 +152,14 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION):
value = super(Info, other).__getattribute__(other, name) value = super(Info, other).__getattribute__(other, name)
setattr(self, name, value) setattr(self, name, value)
# consider instead of the loop above just copying the memory directly # consider instead of the loop above just copying the memory directly
#size = max(ctypes.sizeof(DYNAMIC_TIME_ZONE_INFO), ctypes.sizeof(other)) # size = max(ctypes.sizeof(DYNAMIC_TIME_ZONE_INFO), ctypes.sizeof(other))
#ctypes.memmove(ctypes.addressof(self), other, size) # ctypes.memmove(ctypes.addressof(self), other, size)
def __getattribute__(self, attr): def __getattribute__(self, attr):
value = super(Info, self).__getattribute__(attr) value = super(Info, self).__getattribute__(attr)
make_minute_timedelta = lambda m: datetime.timedelta(minutes = m)
def make_minute_timedelta(m):
return datetime.timedelta(minutes=m)
if 'bias' in attr: if 'bias' in attr:
value = make_minute_timedelta(value) value = make_minute_timedelta(value)
return value return value
@ -205,10 +213,12 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION):
def _locate_day(year, cutoff): def _locate_day(year, cutoff):
""" """
Takes a SYSTEMTIME object, such as retrieved from a TIME_ZONE_INFORMATION Takes a SYSTEMTIME object, such as retrieved from a TIME_ZONE_INFORMATION
structure or call to GetTimeZoneInformation and interprets it based on the given structure or call to GetTimeZoneInformation and interprets
it based on the given
year to identify the actual day. year to identify the actual day.
This method is necessary because the SYSTEMTIME structure refers to a day by its This method is necessary because the SYSTEMTIME structure
refers to a day by its
day of the week and week of the month (e.g. 4th saturday in March). day of the week and week of the month (e.g. 4th saturday in March).
>>> SATURDAY = 6 >>> SATURDAY = 6
@ -227,9 +237,11 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION):
week_of_month = cutoff.day week_of_month = cutoff.day
# so the following is the first day of that week # so the following is the first day of that week
day = (week_of_month - 1) * 7 + 1 day = (week_of_month - 1) * 7 + 1
result = datetime.datetime(year, cutoff.month, day, result = datetime.datetime(
year, cutoff.month, day,
cutoff.hour, cutoff.minute, cutoff.second, cutoff.millisecond) cutoff.hour, cutoff.minute, cutoff.second, cutoff.millisecond)
# now the result is the correct week, but not necessarily the correct day of the week # now the result is the correct week, but not necessarily
# the correct day of the week
days_to_go = (target_weekday - result.weekday()) % 7 days_to_go = (target_weekday - result.weekday()) % 7
result += datetime.timedelta(days_to_go) result += datetime.timedelta(days_to_go)
# if we selected a day in the month following the target month, # if we selected a day in the month following the target month,
@ -238,5 +250,5 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION):
# to be the last week in a month and adding the time delta might have # to be the last week in a month and adding the time delta might have
# pushed the result into the next month. # pushed the result into the next month.
while result.month == cutoff.month + 1: while result.month == cutoff.month + 1:
result -= datetime.timedelta(weeks = 1) result -= datetime.timedelta(weeks=1)
return result return result
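# Illustrative sketch, not part of the diff: a SYSTEMTIME-style rule such as
# 'first Saturday in March' is encoded as month=3, day_of_week=6 (0=Sunday in
# the SYSTEMTIME convention) and day=1 for the first such week. The call below
# is hypothetical and only shows the intent:
#
# rule = SYSTEMTIME(month=3, day_of_week=6, day=1, hour=2)
# _locate_day(2018, rule)  # -> datetime of the first Saturday of March 2018, at 02:00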

View file

@ -3,6 +3,7 @@
import ctypes import ctypes
from jaraco.windows.util import ensure_unicode from jaraco.windows.util import ensure_unicode
def MessageBox(text, caption=None, handle=None, type=None): def MessageBox(text, caption=None, handle=None, type=None):
text, caption = map(ensure_unicode, (text, caption)) text, caption = map(ensure_unicode, (text, caption))
ctypes.windll.user32.MessageBoxW(handle, text, caption, type) ctypes.windll.user32.MessageBoxW(handle, text, caption, type)

View file

@ -3,6 +3,7 @@ from .api import errors
from .api.user import GetUserName from .api.user import GetUserName
from .error import WindowsError, handle_nonzero_success from .error import WindowsError, handle_nonzero_success
def get_user_name(): def get_user_name():
size = ctypes.wintypes.DWORD() size = ctypes.wintypes.DWORD()
try: try:

View file

@ -2,6 +2,7 @@
import ctypes import ctypes
def ensure_unicode(param): def ensure_unicode(param):
try: try:
param = ctypes.create_unicode_buffer(param) param = ctypes.create_unicode_buffer(param)
@ -9,10 +10,11 @@ def ensure_unicode(param):
pass # just return the param as is pass # just return the param as is
return param return param
class Extended(object): class Extended(object):
"Used to add extended capability to structures" "Used to add extended capability to structures"
def __eq__(self, other): def __eq__(self, other):
return buffer(self) == buffer(other) return memoryview(self) == memoryview(other)
def __ne__(self, other): def __ne__(self, other):
return buffer(self) != buffer(other) return memoryview(self) != memoryview(other)

View file

@ -1,11 +1,14 @@
import os import os
from path import path from path import Path
def install_pptp(name, param_lines): def install_pptp(name, param_lines):
""" """
""" """
# or consider using the API: http://msdn.microsoft.com/en-us/library/aa446739%28v=VS.85%29.aspx # or consider using the API:
pbk_path = (path(os.environ['PROGRAMDATA']) # http://msdn.microsoft.com/en-us/library/aa446739%28v=VS.85%29.aspx
pbk_path = (
Path(os.environ['PROGRAMDATA'])
/ 'Microsoft' / 'Network' / 'Connections' / 'pbk' / 'rasphone.pbk') / 'Microsoft' / 'Network' / 'Connections' / 'pbk' / 'rasphone.pbk')
pbk_path.dirname().makedirs_p() pbk_path.dirname().makedirs_p()
with open(pbk_path, 'a') as pbk: with open(pbk_path, 'a') as pbk:

View file

@ -17,6 +17,7 @@ def set(value):
) )
handle_nonzero_success(result) handle_nonzero_success(result)
def get(): def get():
value = ctypes.wintypes.BOOL() value = ctypes.wintypes.BOOL()
result = system.SystemParametersInfo( result = system.SystemParametersInfo(
@ -28,6 +29,7 @@ def get():
handle_nonzero_success(result) handle_nonzero_success(result)
return bool(value) return bool(value)
def set_delay(milliseconds): def set_delay(milliseconds):
result = system.SystemParametersInfo( result = system.SystemParametersInfo(
system.SPI_SETACTIVEWNDTRKTIMEOUT, system.SPI_SETACTIVEWNDTRKTIMEOUT,
@ -37,6 +39,7 @@ def set_delay(milliseconds):
) )
handle_nonzero_success(result) handle_nonzero_success(result)
def get_delay(): def get_delay():
value = ctypes.wintypes.DWORD() value = ctypes.wintypes.DWORD()
result = system.SystemParametersInfo( result = system.SystemParametersInfo(

View file

@ -1,2 +1,2 @@
from more_itertools.more import * from more_itertools.more import * # noqa
from more_itertools.recipes import * from more_itertools.recipes import * # noqa

File diff suppressed because it is too large

View file

@ -8,21 +8,79 @@ Some backward-compatible usability improvements have been made.
""" """
from collections import deque from collections import deque
from itertools import chain, combinations, count, cycle, groupby, ifilterfalse, imap, islice, izip, izip_longest, repeat, starmap, tee # Wrapping breaks 2to3. from itertools import (
chain, combinations, count, cycle, groupby, islice, repeat, starmap, tee
)
import operator import operator
from random import randrange, sample, choice from random import randrange, sample, choice
from six import PY2
from six.moves import filter, filterfalse, map, range, zip, zip_longest
__all__ = ['take', 'tabulate', 'consume', 'nth', 'quantify', 'padnone', __all__ = [
'ncycles', 'dotproduct', 'flatten', 'repeatfunc', 'pairwise', 'accumulate',
'grouper', 'roundrobin', 'powerset', 'unique_everseen', 'all_equal',
'unique_justseen', 'iter_except', 'random_product', 'consume',
'random_permutation', 'random_combination', 'dotproduct',
'random_combination_with_replacement'] 'first_true',
'flatten',
'grouper',
'iter_except',
'ncycles',
'nth',
'nth_combination',
'padnone',
'pairwise',
'partition',
'powerset',
'prepend',
'quantify',
'random_combination_with_replacement',
'random_combination',
'random_permutation',
'random_product',
'repeatfunc',
'roundrobin',
'tabulate',
'tail',
'take',
'unique_everseen',
'unique_justseen',
]
def accumulate(iterable, func=operator.add):
"""
Return an iterator whose items are the accumulated results of a function
(specified by the optional *func* argument) that takes two arguments.
By default, returns accumulated sums with :func:`operator.add`.
>>> list(accumulate([1, 2, 3, 4, 5])) # Running sum
[1, 3, 6, 10, 15]
>>> list(accumulate([1, 2, 3], func=operator.mul)) # Running product
[1, 2, 6]
>>> list(accumulate([0, 1, -1, 2, 3, 2], func=max)) # Running maximum
[0, 1, 1, 2, 3, 3]
This function is available in the ``itertools`` module for Python 3.2 and
greater.
"""
it = iter(iterable)
try:
total = next(it)
except StopIteration:
return
else:
yield total
for element in it:
total = func(total, element)
yield total
def take(n, iterable): def take(n, iterable):
"""Return first n items of the iterable as a list """Return first *n* items of the iterable as a list.
>>> take(3, range(10)) >>> take(3, range(10))
[0, 1, 2] [0, 1, 2]
@ -37,21 +95,37 @@ def take(n, iterable):
def tabulate(function, start=0): def tabulate(function, start=0):
"""Return an iterator mapping the function over linear input. """Return an iterator over the results of ``func(start)``,
``func(start + 1)``, ``func(start + 2)``...
The start argument will be increased by 1 each time the iterator is called *func* should be a function that accepts one integer argument.
and fed into the function.
>>> t = tabulate(lambda x: x**2, -3) If *start* is not specified it defaults to 0. It will be incremented each
>>> take(3, t) time the iterator is advanced.
[9, 4, 1]
>>> square = lambda x: x ** 2
>>> iterator = tabulate(square, -3)
>>> take(4, iterator)
[9, 4, 1, 0]
""" """
return imap(function, count(start)) return map(function, count(start))
def tail(n, iterable):
"""Return an iterator over the last *n* items of *iterable*.
>>> t = tail(3, 'ABCDEFG')
>>> list(t)
['E', 'F', 'G']
"""
return iter(deque(iterable, maxlen=n))
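# deque(iterable, maxlen=n) discards all but the last n items as it consumes
# the input, so tail() takes O(len(iterable)) time but only O(n) memory.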
def consume(iterator, n=None): def consume(iterator, n=None):
"""Advance the iterator n-steps ahead. If n is none, consume entirely. """Advance *iterable* by *n* steps. If *n* is ``None``, consume it
entirely.
Efficiently exhausts an iterator without returning values. Defaults to Efficiently exhausts an iterator without returning values. Defaults to
consuming the whole iterator, but an optional second argument may be consuming the whole iterator, but an optional second argument may be
@ -90,7 +164,7 @@ def consume(iterator, n=None):
def nth(iterable, n, default=None): def nth(iterable, n, default=None):
"""Returns the nth item or a default value """Returns the nth item or a default value.
>>> l = range(10) >>> l = range(10)
>>> nth(l, 3) >>> nth(l, 3)
@ -102,30 +176,46 @@ def nth(iterable, n, default=None):
return next(islice(iterable, n, None), default) return next(islice(iterable, n, None), default)
def all_equal(iterable):
"""
Returns ``True`` if all the elements are equal to each other.
>>> all_equal('aaaa')
True
>>> all_equal('aaab')
False
"""
g = groupby(iterable)
return next(g, True) and not next(g, False)
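# groupby() collapses consecutive equal items into a single group, so an
# all-equal (or empty) iterable yields at most one group; the second next()
# then falls back to False and the whole expression evaluates to True.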
def quantify(iterable, pred=bool): def quantify(iterable, pred=bool):
"""Return the how many times the predicate is true """Return the how many times the predicate is true.
>>> quantify([True, False, True]) >>> quantify([True, False, True])
2 2
""" """
return sum(imap(pred, iterable)) return sum(map(pred, iterable))
def padnone(iterable): def padnone(iterable):
"""Returns the sequence of elements and then returns None indefinitely. """Returns the sequence of elements and then returns ``None`` indefinitely.
>>> take(5, padnone(range(3))) >>> take(5, padnone(range(3)))
[0, 1, 2, None, None] [0, 1, 2, None, None]
Useful for emulating the behavior of the built-in map() function. Useful for emulating the behavior of the built-in :func:`map` function.
See also :func:`padded`.
""" """
return chain(iterable, repeat(None)) return chain(iterable, repeat(None))
def ncycles(iterable, n): def ncycles(iterable, n):
"""Returns the sequence elements n times """Returns the sequence elements *n* times
>>> list(ncycles(["a", "b"], 3)) >>> list(ncycles(["a", "b"], 3))
['a', 'b', 'a', 'b', 'a', 'b'] ['a', 'b', 'a', 'b', 'a', 'b']
@ -135,32 +225,47 @@ def ncycles(iterable, n):
def dotproduct(vec1, vec2): def dotproduct(vec1, vec2):
"""Returns the dot product of the two iterables """Returns the dot product of the two iterables.
>>> dotproduct([10, 10], [20, 20]) >>> dotproduct([10, 10], [20, 20])
400 400
""" """
return sum(imap(operator.mul, vec1, vec2)) return sum(map(operator.mul, vec1, vec2))
def flatten(listOfLists): def flatten(listOfLists):
"""Return an iterator flattening one level of nesting in a list of lists """Return an iterator flattening one level of nesting in a list of lists.
>>> list(flatten([[0, 1], [2, 3]])) >>> list(flatten([[0, 1], [2, 3]]))
[0, 1, 2, 3] [0, 1, 2, 3]
See also :func:`collapse`, which can flatten multiple levels of nesting.
""" """
return chain.from_iterable(listOfLists) return chain.from_iterable(listOfLists)
def repeatfunc(func, times=None, *args): def repeatfunc(func, times=None, *args):
"""Repeat calls to func with specified arguments. """Call *func* with *args* repeatedly, returning an iterable over the
results.
>>> list(repeatfunc(lambda: 5, 3)) If *times* is specified, the iterable will terminate after that many
[5, 5, 5] repetitions:
>>> list(repeatfunc(lambda x: x ** 2, 3, 3))
[9, 9, 9] >>> from operator import add
>>> times = 4
>>> args = 3, 5
>>> list(repeatfunc(add, times, *args))
[8, 8, 8, 8]
If *times* is ``None`` the iterable will not terminate:
>>> from random import randrange
>>> times = None
>>> args = 1, 11
>>> take(6, repeatfunc(randrange, times, *args)) # doctest:+SKIP
[2, 4, 8, 1, 8, 4]
""" """
if times is None: if times is None:
@ -177,30 +282,37 @@ def pairwise(iterable):
""" """
a, b = tee(iterable) a, b = tee(iterable)
next(b, None) next(b, None)
return izip(a, b) return zip(a, b)
def grouper(n, iterable, fillvalue=None): def grouper(n, iterable, fillvalue=None):
"""Collect data into fixed-length chunks or blocks """Collect data into fixed-length chunks or blocks.
>>> list(grouper(3, 'ABCDEFG', 'x')) >>> list(grouper(3, 'ABCDEFG', 'x'))
[('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')] [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
""" """
args = [iter(iterable)] * n args = [iter(iterable)] * n
return izip_longest(fillvalue=fillvalue, *args) return zip_longest(fillvalue=fillvalue, *args)
def roundrobin(*iterables): def roundrobin(*iterables):
"""Yields an item from each iterable, alternating between them """Yields an item from each iterable, alternating between them.
>>> list(roundrobin('ABC', 'D', 'EF')) >>> list(roundrobin('ABC', 'D', 'EF'))
['A', 'D', 'E', 'B', 'F', 'C'] ['A', 'D', 'E', 'B', 'F', 'C']
This function produces the same output as :func:`interleave_longest`, but
may perform better for some inputs (in particular when the number of
iterables is small).
""" """
# Recipe credited to George Sakkis # Recipe credited to George Sakkis
pending = len(iterables) pending = len(iterables)
if PY2:
nexts = cycle(iter(it).next for it in iterables) nexts = cycle(iter(it).next for it in iterables)
else:
nexts = cycle(iter(it).__next__ for it in iterables)
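# the bound 'advance' method is fetched up front because its name differs
# between Python 2 (it.next) and Python 3 (it.__next__).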
while pending: while pending:
try: try:
for next in nexts: for next in nexts:
@ -210,37 +322,72 @@ def roundrobin(*iterables):
nexts = cycle(islice(nexts, pending)) nexts = cycle(islice(nexts, pending))
def partition(pred, iterable):
"""
Returns a 2-tuple of iterables derived from the input iterable.
The first yields the items that have ``pred(item) == False``.
The second yields the items that have ``pred(item) == True``.
>>> is_odd = lambda x: x % 2 != 0
>>> iterable = range(10)
>>> even_items, odd_items = partition(is_odd, iterable)
>>> list(even_items), list(odd_items)
([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])
"""
# partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9
t1, t2 = tee(iterable)
return filterfalse(pred, t1), filter(pred, t2)
def powerset(iterable): def powerset(iterable):
"""Yields all possible subsets of the iterable """Yields all possible subsets of the iterable.
>>> list(powerset([1,2,3])) >>> list(powerset([1,2,3]))
[(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)] [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
""" """
s = list(iterable) s = list(iterable)
return chain.from_iterable(combinations(s, r) for r in range(len(s)+1)) return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
def unique_everseen(iterable, key=None): def unique_everseen(iterable, key=None):
"""Yield unique elements, preserving order. """
Yield unique elements, preserving order.
>>> list(unique_everseen('AAAABBBCCDAABBB')) >>> list(unique_everseen('AAAABBBCCDAABBB'))
['A', 'B', 'C', 'D'] ['A', 'B', 'C', 'D']
>>> list(unique_everseen('ABBCcAD', str.lower)) >>> list(unique_everseen('ABBCcAD', str.lower))
['A', 'B', 'C', 'D'] ['A', 'B', 'C', 'D']
Sequences with a mix of hashable and unhashable items can be used.
The function will be slower (i.e., `O(n^2)`) for unhashable items.
""" """
seen = set() seenset = set()
seen_add = seen.add seenset_add = seenset.add
seenlist = []
seenlist_add = seenlist.append
if key is None: if key is None:
for element in ifilterfalse(seen.__contains__, iterable): for element in iterable:
seen_add(element) try:
if element not in seenset:
seenset_add(element)
yield element
except TypeError:
if element not in seenlist:
seenlist_add(element)
yield element yield element
else: else:
for element in iterable: for element in iterable:
k = key(element) k = key(element)
if k not in seen: try:
seen_add(k) if k not in seenset:
seenset_add(k)
yield element
except TypeError:
if k not in seenlist:
seenlist_add(k)
yield element yield element
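# Illustrative behaviour of the unhashable fall-back (not part of the diff):
# hashable items keep using the O(1) set, while unhashable ones such as lists
# drop into a linear scan of seenlist.
#
# list(unique_everseen(['a', [1, 2], [1, 2], 'a']))  # -> ['a', [1, 2]]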
@ -253,17 +400,17 @@ def unique_justseen(iterable, key=None):
['A', 'B', 'C', 'A', 'D'] ['A', 'B', 'C', 'A', 'D']
""" """
return imap(next, imap(operator.itemgetter(1), groupby(iterable, key))) return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
def iter_except(func, exception, first=None): def iter_except(func, exception, first=None):
"""Yields results from a function repeatedly until an exception is raised. """Yields results from a function repeatedly until an exception is raised.
Converts a call-until-exception interface to an iterator interface. Converts a call-until-exception interface to an iterator interface.
Like __builtin__.iter(func, sentinel) but uses an exception instead Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel
of a sentinel to end the loop. to end the loop.
>>> l = range(3) >>> l = [0, 1, 2]
>>> list(iter_except(l.pop, IndexError)) >>> list(iter_except(l.pop, IndexError))
[2, 1, 0] [2, 1, 0]
@ -277,28 +424,58 @@ def iter_except(func, exception, first=None):
pass pass
def random_product(*args, **kwds): def first_true(iterable, default=False, pred=None):
"""Returns a random pairing of items from each iterable argument """
Returns the first true value in the iterable.
If `repeat` is provided as a kwarg, it's value will be used to indicate If no true value is found, returns *default*
how many pairings should be chosen.
>>> random_product(['a', 'b', 'c'], [1, 2], repeat=2) # doctest:+SKIP If *pred* is not None, returns the first item for which
('b', '2', 'c', '2') ``pred(item) == True`` .
>>> first_true(range(10))
1
>>> first_true(range(10), pred=lambda x: x > 5)
6
>>> first_true(range(10), default='missing', pred=lambda x: x > 9)
'missing'
""" """
pools = map(tuple, args) * kwds.get('repeat', 1) return next(filter(pred, iterable), default)
def random_product(*args, **kwds):
"""Draw an item at random from each of the input iterables.
>>> random_product('abc', range(4), 'XYZ') # doctest:+SKIP
('c', 3, 'Z')
If *repeat* is provided as a keyword argument, that many items will be
drawn from each iterable.
>>> random_product('abcd', range(4), repeat=2) # doctest:+SKIP
('a', 2, 'd', 3)
This is equivalent to taking a random selection from
``itertools.product(*args, **kwds)``.
"""
pools = [tuple(pool) for pool in args] * kwds.get('repeat', 1)
return tuple(choice(pool) for pool in pools) return tuple(choice(pool) for pool in pools)
def random_permutation(iterable, r=None): def random_permutation(iterable, r=None):
"""Returns a random permutation. """Return a random *r* length permutation of the elements in *iterable*.
If r is provided, the permutation is truncated to length r. If *r* is not specified or is ``None``, then *r* defaults to the length of
*iterable*.
>>> random_permutation(range(5)) # doctest:+SKIP >>> random_permutation(range(5)) # doctest:+SKIP
(3, 4, 0, 1, 2) (3, 4, 0, 1, 2)
This is equivalent to taking a random selection from
``itertools.permutations(iterable, r)``.
""" """
pool = tuple(iterable) pool = tuple(iterable)
r = len(pool) if r is None else r r = len(pool) if r is None else r
@ -306,26 +483,83 @@ def random_permutation(iterable, r=None):
def random_combination(iterable, r): def random_combination(iterable, r):
"""Returns a random combination of length r, chosen without replacement. """Return a random *r* length subsequence of the elements in *iterable*.
>>> random_combination(range(5), 3) # doctest:+SKIP >>> random_combination(range(5), 3) # doctest:+SKIP
(2, 3, 4) (2, 3, 4)
This is equivalent to taking a random selection from
``itertools.combinations(iterable, r)``.
""" """
pool = tuple(iterable) pool = tuple(iterable)
n = len(pool) n = len(pool)
indices = sorted(sample(xrange(n), r)) indices = sorted(sample(range(n), r))
return tuple(pool[i] for i in indices) return tuple(pool[i] for i in indices)
def random_combination_with_replacement(iterable, r): def random_combination_with_replacement(iterable, r):
"""Returns a random combination of length r, chosen with replacement. """Return a random *r* length subsequence of elements in *iterable*,
allowing individual elements to be repeated.
>>> random_combination_with_replacement(range(3), 5) # # doctest:+SKIP >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP
(0, 0, 1, 2, 2) (0, 0, 1, 2, 2)
This is equivalent to taking a random selection from
``itertools.combinations_with_replacement(iterable, r)``.
""" """
pool = tuple(iterable) pool = tuple(iterable)
n = len(pool) n = len(pool)
indices = sorted(randrange(n) for i in xrange(r)) indices = sorted(randrange(n) for i in range(r))
return tuple(pool[i] for i in indices) return tuple(pool[i] for i in indices)
def nth_combination(iterable, r, index):
"""Equivalent to ``list(combinations(iterable, r))[index]``.
The subsequences of *iterable* that are of length *r* can be ordered
lexicographically. :func:`nth_combination` computes the subsequence at
sort position *index* directly, without computing the previous
subsequences.
"""
pool = tuple(iterable)
n = len(pool)
if (r < 0) or (r > n):
raise ValueError
c = 1
k = min(r, n - r)
for i in range(1, k + 1):
c = c * (n - k + i) // i
if index < 0:
index += c
if (index < 0) or (index >= c):
raise IndexError
result = []
while r:
c, n, r = c * r // n, n - 1, r - 1
while index >= c:
index -= c
c, n = c * (n - r) // n, n - 1
result.append(pool[-1 - n])
return tuple(result)
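# Quick sanity sketch, not part of the diff: nth_combination() indexes
# directly into the lexicographic order that itertools.combinations() would
# produce, without materialising the earlier subsequences.
#
# nth_combination('abcde', 3, 4)          # -> ('a', 'c', 'e')
# list(combinations('abcde', 3))[4]       # -> ('a', 'c', 'e') as well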
def prepend(value, iterator):
"""Yield *value*, followed by the elements in *iterator*.
>>> value = '0'
>>> iterator = ['1', '2', '3']
>>> list(prepend(value, iterator))
['0', '1', '2', '3']
To prepend multiple values, see :func:`itertools.chain`.
"""
return chain([value], iterator)

File diff suppressed because it is too large

View file

@ -1,13 +1,39 @@
from random import seed from doctest import DocTestSuite
from unittest import TestCase from unittest import TestCase
from nose.tools import eq_, assert_raises, ok_ from itertools import combinations
from six.moves import range
from more_itertools import * import more_itertools as mi
def setup_module(): def load_tests(loader, tests, ignore):
seed(1337) # Add the doctests
tests.addTests(DocTestSuite('more_itertools.recipes'))
return tests
class AccumulateTests(TestCase):
"""Tests for ``accumulate()``"""
def test_empty(self):
"""Test that an empty input returns an empty output"""
self.assertEqual(list(mi.accumulate([])), [])
def test_default(self):
"""Test accumulate with the default function (addition)"""
self.assertEqual(list(mi.accumulate([1, 2, 3])), [1, 3, 6])
def test_bogus_function(self):
"""Test accumulate with an invalid function"""
with self.assertRaises(TypeError):
list(mi.accumulate([1, 2, 3], func=lambda x: x))
def test_custom_function(self):
"""Test accumulate with a custom function"""
self.assertEqual(
list(mi.accumulate((1, 2, 3, 2, 1), func=max)), [1, 2, 3, 3, 3]
)
class TakeTests(TestCase): class TakeTests(TestCase):
@ -15,25 +41,25 @@ class TakeTests(TestCase):
def test_simple_take(self): def test_simple_take(self):
"""Test basic usage""" """Test basic usage"""
t = take(5, xrange(10)) t = mi.take(5, range(10))
eq_(t, [0, 1, 2, 3, 4]) self.assertEqual(t, [0, 1, 2, 3, 4])
def test_null_take(self): def test_null_take(self):
"""Check the null case""" """Check the null case"""
t = take(0, xrange(10)) t = mi.take(0, range(10))
eq_(t, []) self.assertEqual(t, [])
def test_negative_take(self): def test_negative_take(self):
"""Make sure taking negative items results in a ValueError""" """Make sure taking negative items results in a ValueError"""
assert_raises(ValueError, take, -3, xrange(10)) self.assertRaises(ValueError, lambda: mi.take(-3, range(10)))
def test_take_too_much(self): def test_take_too_much(self):
"""Taking more than an iterator has remaining should return what the """Taking more than an iterator has remaining should return what the
iterator has remaining. iterator has remaining.
""" """
t = take(10, xrange(5)) t = mi.take(10, range(5))
eq_(t, [0, 1, 2, 3, 4]) self.assertEqual(t, [0, 1, 2, 3, 4])
class TabulateTests(TestCase): class TabulateTests(TestCase):
@ -41,15 +67,35 @@ class TabulateTests(TestCase):
def test_simple_tabulate(self): def test_simple_tabulate(self):
"""Test the happy path""" """Test the happy path"""
t = tabulate(lambda x: x) t = mi.tabulate(lambda x: x)
f = tuple([next(t) for _ in range(3)]) f = tuple([next(t) for _ in range(3)])
eq_(f, (0, 1, 2)) self.assertEqual(f, (0, 1, 2))
def test_count(self): def test_count(self):
"""Ensure tabulate accepts specific count""" """Ensure tabulate accepts specific count"""
t = tabulate(lambda x: 2 * x, -1) t = mi.tabulate(lambda x: 2 * x, -1)
f = (next(t), next(t), next(t)) f = (next(t), next(t), next(t))
eq_(f, (-2, 0, 2)) self.assertEqual(f, (-2, 0, 2))
class TailTests(TestCase):
"""Tests for ``tail()``"""
def test_greater(self):
"""Length of iterable is greather than requested tail"""
self.assertEqual(list(mi.tail(3, 'ABCDEFG')), ['E', 'F', 'G'])
def test_equal(self):
"""Length of iterable is equal to the requested tail"""
self.assertEqual(
list(mi.tail(7, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G']
)
def test_less(self):
"""Length of iterable is less than requested tail"""
self.assertEqual(
list(mi.tail(8, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G']
)
class ConsumeTests(TestCase): class ConsumeTests(TestCase):
@ -58,25 +104,25 @@ class ConsumeTests(TestCase):
def test_sanity(self): def test_sanity(self):
"""Test basic functionality""" """Test basic functionality"""
r = (x for x in range(10)) r = (x for x in range(10))
consume(r, 3) mi.consume(r, 3)
eq_(3, next(r)) self.assertEqual(3, next(r))
def test_null_consume(self): def test_null_consume(self):
"""Check the null case""" """Check the null case"""
r = (x for x in range(10)) r = (x for x in range(10))
consume(r, 0) mi.consume(r, 0)
eq_(0, next(r)) self.assertEqual(0, next(r))
def test_negative_consume(self): def test_negative_consume(self):
"""Check that negative consumsion throws an error""" """Check that negative consumsion throws an error"""
r = (x for x in range(10)) r = (x for x in range(10))
assert_raises(ValueError, consume, r, -1) self.assertRaises(ValueError, lambda: mi.consume(r, -1))
def test_total_consume(self): def test_total_consume(self):
"""Check that iterator is totally consumed by default""" """Check that iterator is totally consumed by default"""
r = (x for x in range(10)) r = (x for x in range(10))
consume(r) mi.consume(r)
assert_raises(StopIteration, next, r) self.assertRaises(StopIteration, lambda: next(r))
class NthTests(TestCase): class NthTests(TestCase):
@ -86,16 +132,45 @@ class NthTests(TestCase):
"""Make sure the nth item is returned""" """Make sure the nth item is returned"""
l = range(10) l = range(10)
for i, v in enumerate(l): for i, v in enumerate(l):
eq_(nth(l, i), v) self.assertEqual(mi.nth(l, i), v)
def test_default(self): def test_default(self):
"""Ensure a default value is returned when nth item not found""" """Ensure a default value is returned when nth item not found"""
l = range(3) l = range(3)
eq_(nth(l, 100, "zebra"), "zebra") self.assertEqual(mi.nth(l, 100, "zebra"), "zebra")
def test_negative_item_raises(self): def test_negative_item_raises(self):
"""Ensure asking for a negative item raises an exception""" """Ensure asking for a negative item raises an exception"""
assert_raises(ValueError, nth, range(10), -3) self.assertRaises(ValueError, lambda: mi.nth(range(10), -3))
class AllEqualTests(TestCase):
"""Tests for ``all_equal()``"""
def test_true(self):
"""Everything is equal"""
self.assertTrue(mi.all_equal('aaaaaa'))
self.assertTrue(mi.all_equal([0, 0, 0, 0]))
def test_false(self):
"""Not everything is equal"""
self.assertFalse(mi.all_equal('aaaaab'))
self.assertFalse(mi.all_equal([0, 0, 0, 1]))
def test_tricky(self):
"""Not everything is identical, but everything is equal"""
items = [1, complex(1, 0), 1.0]
self.assertTrue(mi.all_equal(items))
def test_empty(self):
"""Return True if the iterable is empty"""
self.assertTrue(mi.all_equal(''))
self.assertTrue(mi.all_equal([]))
def test_one(self):
"""Return True if the iterable is singular"""
self.assertTrue(mi.all_equal('0'))
self.assertTrue(mi.all_equal([0]))
class QuantifyTests(TestCase): class QuantifyTests(TestCase):
@ -104,12 +179,12 @@ class QuantifyTests(TestCase):
def test_happy_path(self): def test_happy_path(self):
"""Make sure True count is returned""" """Make sure True count is returned"""
q = [True, False, True] q = [True, False, True]
eq_(quantify(q), 2) self.assertEqual(mi.quantify(q), 2)
def test_custom_predicate(self): def test_custom_predicate(self):
"""Ensure non-default predicates return as expected""" """Ensure non-default predicates return as expected"""
q = range(10) q = range(10)
eq_(quantify(q, lambda x: x % 2 == 0), 5) self.assertEqual(mi.quantify(q, lambda x: x % 2 == 0), 5)
class PadnoneTests(TestCase): class PadnoneTests(TestCase):
@ -118,8 +193,8 @@ class PadnoneTests(TestCase):
def test_happy_path(self): def test_happy_path(self):
"""wrapper iterator should return None indefinitely""" """wrapper iterator should return None indefinitely"""
r = range(2) r = range(2)
p = padnone(r) p = mi.padnone(r)
eq_([0, 1, None, None], [next(p) for _ in range(4)]) self.assertEqual([0, 1, None, None], [next(p) for _ in range(4)])
class NcyclesTests(TestCase): class NcyclesTests(TestCase):
@ -128,19 +203,21 @@ class NcyclesTests(TestCase):
def test_happy_path(self): def test_happy_path(self):
"""cycle a sequence three times""" """cycle a sequence three times"""
r = ["a", "b", "c"] r = ["a", "b", "c"]
n = ncycles(r, 3) n = mi.ncycles(r, 3)
eq_(["a", "b", "c", "a", "b", "c", "a", "b", "c"], self.assertEqual(
list(n)) ["a", "b", "c", "a", "b", "c", "a", "b", "c"],
list(n)
)
def test_null_case(self): def test_null_case(self):
"""asking for 0 cycles should return an empty iterator""" """asking for 0 cycles should return an empty iterator"""
n = ncycles(range(100), 0) n = mi.ncycles(range(100), 0)
assert_raises(StopIteration, next, n) self.assertRaises(StopIteration, lambda: next(n))
def test_pathalogical_case(self): def test_pathalogical_case(self):
"""asking for negative cycles should return an empty iterator""" """asking for negative cycles should return an empty iterator"""
n = ncycles(range(100), -10) n = mi.ncycles(range(100), -10)
assert_raises(StopIteration, next, n) self.assertRaises(StopIteration, lambda: next(n))
class DotproductTests(TestCase): class DotproductTests(TestCase):
@ -148,7 +225,7 @@ class DotproductTests(TestCase):
def test_happy_path(self): def test_happy_path(self):
"""simple dotproduct example""" """simple dotproduct example"""
eq_(400, dotproduct([10, 10], [20, 20])) self.assertEqual(400, mi.dotproduct([10, 10], [20, 20]))
class FlattenTests(TestCase): class FlattenTests(TestCase):
@ -157,12 +234,12 @@ class FlattenTests(TestCase):
def test_basic_usage(self): def test_basic_usage(self):
"""ensure list of lists is flattened one level""" """ensure list of lists is flattened one level"""
f = [[0, 1, 2], [3, 4, 5]] f = [[0, 1, 2], [3, 4, 5]]
eq_(range(6), list(flatten(f))) self.assertEqual(list(range(6)), list(mi.flatten(f)))
def test_single_level(self): def test_single_level(self):
"""ensure list of lists is flattened only one level""" """ensure list of lists is flattened only one level"""
f = [[0, [1, 2]], [[3, 4], 5]] f = [[0, [1, 2]], [[3, 4], 5]]
eq_([0, [1, 2], [3, 4], 5], list(flatten(f))) self.assertEqual([0, [1, 2], [3, 4], 5], list(mi.flatten(f)))
class RepeatfuncTests(TestCase): class RepeatfuncTests(TestCase):
@ -170,23 +247,23 @@ class RepeatfuncTests(TestCase):
def test_simple_repeat(self): def test_simple_repeat(self):
"""test simple repeated functions""" """test simple repeated functions"""
r = repeatfunc(lambda: 5) r = mi.repeatfunc(lambda: 5)
eq_([5, 5, 5, 5, 5], [next(r) for _ in range(5)]) self.assertEqual([5, 5, 5, 5, 5], [next(r) for _ in range(5)])
def test_finite_repeat(self): def test_finite_repeat(self):
"""ensure limited repeat when times is provided""" """ensure limited repeat when times is provided"""
r = repeatfunc(lambda: 5, times=5) r = mi.repeatfunc(lambda: 5, times=5)
eq_([5, 5, 5, 5, 5], list(r)) self.assertEqual([5, 5, 5, 5, 5], list(r))
def test_added_arguments(self): def test_added_arguments(self):
"""ensure arguments are applied to the function""" """ensure arguments are applied to the function"""
r = repeatfunc(lambda x: x, 2, 3) r = mi.repeatfunc(lambda x: x, 2, 3)
eq_([3, 3], list(r)) self.assertEqual([3, 3], list(r))
def test_null_times(self): def test_null_times(self):
"""repeat 0 should return an empty iterator""" """repeat 0 should return an empty iterator"""
r = repeatfunc(range, 0, 3) r = mi.repeatfunc(range, 0, 3)
assert_raises(StopIteration, next, r) self.assertRaises(StopIteration, lambda: next(r))
class PairwiseTests(TestCase): class PairwiseTests(TestCase):
@ -194,13 +271,13 @@ class PairwiseTests(TestCase):
def test_base_case(self): def test_base_case(self):
"""ensure an iterable will return pairwise""" """ensure an iterable will return pairwise"""
p = pairwise([1, 2, 3]) p = mi.pairwise([1, 2, 3])
eq_([(1, 2), (2, 3)], list(p)) self.assertEqual([(1, 2), (2, 3)], list(p))
def test_short_case(self): def test_short_case(self):
"""ensure an empty iterator if there's not enough values to pair""" """ensure an empty iterator if there's not enough values to pair"""
p = pairwise("a") p = mi.pairwise("a")
assert_raises(StopIteration, next, p) self.assertRaises(StopIteration, lambda: next(p))
class GrouperTests(TestCase): class GrouperTests(TestCase):
@ -211,18 +288,25 @@ class GrouperTests(TestCase):
the iterable. the iterable.
""" """
eq_(list(grouper(3, 'ABCDEF')), [('A', 'B', 'C'), ('D', 'E', 'F')]) self.assertEqual(
list(mi.grouper(3, 'ABCDEF')), [('A', 'B', 'C'), ('D', 'E', 'F')]
)
def test_odd(self): def test_odd(self):
"""Test when group size does not divide evenly into the length of the """Test when group size does not divide evenly into the length of the
iterable. iterable.
""" """
eq_(list(grouper(3, 'ABCDE')), [('A', 'B', 'C'), ('D', 'E', None)]) self.assertEqual(
list(mi.grouper(3, 'ABCDE')), [('A', 'B', 'C'), ('D', 'E', None)]
)
def test_fill_value(self): def test_fill_value(self):
"""Test that the fill value is used to pad the final group""" """Test that the fill value is used to pad the final group"""
eq_(list(grouper(3, 'ABCDE', 'x')), [('A', 'B', 'C'), ('D', 'E', 'x')]) self.assertEqual(
list(mi.grouper(3, 'ABCDE', 'x')),
[('A', 'B', 'C'), ('D', 'E', 'x')]
)
class RoundrobinTests(TestCase): class RoundrobinTests(TestCase):
@ -230,13 +314,33 @@ class RoundrobinTests(TestCase):
def test_even_groups(self): def test_even_groups(self):
"""Ensure ordered output from evenly populated iterables""" """Ensure ordered output from evenly populated iterables"""
eq_(list(roundrobin('ABC', [1, 2, 3], range(3))), self.assertEqual(
['A', 1, 0, 'B', 2, 1, 'C', 3, 2]) list(mi.roundrobin('ABC', [1, 2, 3], range(3))),
['A', 1, 0, 'B', 2, 1, 'C', 3, 2]
)
def test_uneven_groups(self): def test_uneven_groups(self):
"""Ensure ordered output from unevenly populated iterables""" """Ensure ordered output from unevenly populated iterables"""
eq_(list(roundrobin('ABCD', [1, 2], range(0))), self.assertEqual(
['A', 1, 'B', 2, 'C', 'D']) list(mi.roundrobin('ABCD', [1, 2], range(0))),
['A', 1, 'B', 2, 'C', 'D']
)
class PartitionTests(TestCase):
"""Tests for ``partition()``"""
def test_bool(self):
"""Test when pred() returns a boolean"""
lesser, greater = mi.partition(lambda x: x > 5, range(10))
self.assertEqual(list(lesser), [0, 1, 2, 3, 4, 5])
self.assertEqual(list(greater), [6, 7, 8, 9])
def test_arbitrary(self):
"""Test when pred() returns an integer"""
divisibles, remainders = mi.partition(lambda x: x % 3, range(10))
self.assertEqual(list(divisibles), [0, 3, 6, 9])
self.assertEqual(list(remainders), [1, 2, 4, 5, 7, 8])
class PowersetTests(TestCase): class PowersetTests(TestCase):
@ -244,9 +348,11 @@ class PowersetTests(TestCase):
def test_combinatorics(self): def test_combinatorics(self):
"""Ensure a proper enumeration""" """Ensure a proper enumeration"""
p = powerset([1, 2, 3]) p = mi.powerset([1, 2, 3])
eq_(list(p), self.assertEqual(
[(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]) list(p),
[(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
)
class UniqueEverseenTests(TestCase): class UniqueEverseenTests(TestCase):
@ -254,14 +360,28 @@ class UniqueEverseenTests(TestCase):
def test_everseen(self): def test_everseen(self):
"""ensure duplicate elements are ignored""" """ensure duplicate elements are ignored"""
u = unique_everseen('AAAABBBBCCDAABBB') u = mi.unique_everseen('AAAABBBBCCDAABBB')
eq_(['A', 'B', 'C', 'D'], self.assertEqual(
list(u)) ['A', 'B', 'C', 'D'],
list(u)
)
def test_custom_key(self): def test_custom_key(self):
"""ensure the custom key comparison works""" """ensure the custom key comparison works"""
u = unique_everseen('aAbACCc', key=str.lower) u = mi.unique_everseen('aAbACCc', key=str.lower)
eq_(list('abC'), list(u)) self.assertEqual(list('abC'), list(u))
def test_unhashable(self):
"""ensure things work for unhashable items"""
iterable = ['a', [1, 2, 3], [1, 2, 3], 'a']
u = mi.unique_everseen(iterable)
self.assertEqual(list(u), ['a', [1, 2, 3]])
def test_unhashable_key(self):
"""ensure things work for unhashable items with a custom key"""
iterable = ['a', [1, 2, 3], [1, 2, 3], 'a']
u = mi.unique_everseen(iterable, key=lambda x: x)
self.assertEqual(list(u), ['a', [1, 2, 3]])
class UniqueJustseenTests(TestCase): class UniqueJustseenTests(TestCase):
@ -269,13 +389,13 @@ class UniqueJustseenTests(TestCase):
def test_justseen(self): def test_justseen(self):
"""ensure only last item is remembered""" """ensure only last item is remembered"""
u = unique_justseen('AAAABBBCCDABB') u = mi.unique_justseen('AAAABBBCCDABB')
eq_(list('ABCDAB'), list(u)) self.assertEqual(list('ABCDAB'), list(u))
def test_custom_key(self): def test_custom_key(self):
"""ensure the custom key comparison works""" """ensure the custom key comparison works"""
u = unique_justseen('AABCcAD', str.lower) u = mi.unique_justseen('AABCcAD', str.lower)
eq_(list('ABCAD'), list(u)) self.assertEqual(list('ABCAD'), list(u))
class IterExceptTests(TestCase): class IterExceptTests(TestCase):
@ -284,27 +404,49 @@ class IterExceptTests(TestCase):
def test_exact_exception(self): def test_exact_exception(self):
"""ensure the exact specified exception is caught""" """ensure the exact specified exception is caught"""
l = [1, 2, 3] l = [1, 2, 3]
i = iter_except(l.pop, IndexError) i = mi.iter_except(l.pop, IndexError)
eq_(list(i), [3, 2, 1]) self.assertEqual(list(i), [3, 2, 1])
def test_generic_exception(self): def test_generic_exception(self):
"""ensure the generic exception can be caught""" """ensure the generic exception can be caught"""
l = [1, 2] l = [1, 2]
i = iter_except(l.pop, Exception) i = mi.iter_except(l.pop, Exception)
eq_(list(i), [2, 1]) self.assertEqual(list(i), [2, 1])
def test_uncaught_exception_is_raised(self): def test_uncaught_exception_is_raised(self):
"""ensure a non-specified exception is raised""" """ensure a non-specified exception is raised"""
l = [1, 2, 3] l = [1, 2, 3]
i = iter_except(l.pop, KeyError) i = mi.iter_except(l.pop, KeyError)
assert_raises(IndexError, list, i) self.assertRaises(IndexError, lambda: list(i))
def test_first(self): def test_first(self):
"""ensure first is run before the function""" """ensure first is run before the function"""
l = [1, 2, 3] l = [1, 2, 3]
f = lambda: 25 f = lambda: 25
i = iter_except(l.pop, IndexError, f) i = mi.iter_except(l.pop, IndexError, f)
eq_(list(i), [25, 3, 2, 1]) self.assertEqual(list(i), [25, 3, 2, 1])
class FirstTrueTests(TestCase):
"""Tests for ``first_true()``"""
def test_something_true(self):
"""Test with no keywords"""
self.assertEqual(mi.first_true(range(10)), 1)
def test_nothing_true(self):
"""Test default return value."""
self.assertEqual(mi.first_true([0, 0, 0]), False)
def test_default(self):
"""Test with a default keyword"""
self.assertEqual(mi.first_true([0, 0, 0], default='!'), '!')
def test_pred(self):
"""Test with a custom predicate"""
self.assertEqual(
mi.first_true([2, 4, 6], pred=lambda x: x % 3 == 0), 6
)
class RandomProductTests(TestCase): class RandomProductTests(TestCase):
@ -327,12 +469,12 @@ class RandomProductTests(TestCase):
""" """
nums = [1, 2, 3] nums = [1, 2, 3]
lets = ['a', 'b', 'c'] lets = ['a', 'b', 'c']
n, m = zip(*[random_product(nums, lets) for _ in range(100)]) n, m = zip(*[mi.random_product(nums, lets) for _ in range(100)])
n, m = set(n), set(m) n, m = set(n), set(m)
eq_(n, set(nums)) self.assertEqual(n, set(nums))
eq_(m, set(lets)) self.assertEqual(m, set(lets))
eq_(len(n), len(nums)) self.assertEqual(len(n), len(nums))
eq_(len(m), len(lets)) self.assertEqual(len(m), len(lets))
def test_list_with_repeat(self): def test_list_with_repeat(self):
"""ensure multiple items are chosen, and that they appear to be chosen """ensure multiple items are chosen, and that they appear to be chosen
@ -341,13 +483,13 @@ class RandomProductTests(TestCase):
""" """
nums = [1, 2, 3] nums = [1, 2, 3]
lets = ['a', 'b', 'c'] lets = ['a', 'b', 'c']
r = list(random_product(nums, lets, repeat=100)) r = list(mi.random_product(nums, lets, repeat=100))
eq_(2 * 100, len(r)) self.assertEqual(2 * 100, len(r))
n, m = set(r[::2]), set(r[1::2]) n, m = set(r[::2]), set(r[1::2])
eq_(n, set(nums)) self.assertEqual(n, set(nums))
eq_(m, set(lets)) self.assertEqual(m, set(lets))
eq_(len(n), len(nums)) self.assertEqual(len(n), len(nums))
eq_(len(m), len(lets)) self.assertEqual(len(m), len(lets))
class RandomPermutationTests(TestCase): class RandomPermutationTests(TestCase):
@ -361,8 +503,8 @@ class RandomPermutationTests(TestCase):
""" """
i = range(15) i = range(15)
r = random_permutation(i) r = mi.random_permutation(i)
eq_(set(i), set(r)) self.assertEqual(set(i), set(r))
if i == r: if i == r:
raise AssertionError("Values were not permuted") raise AssertionError("Values were not permuted")
@ -380,13 +522,13 @@ class RandomPermutationTests(TestCase):
items = range(15) items = range(15)
item_set = set(items) item_set = set(items)
all_items = set() all_items = set()
for _ in xrange(100): for _ in range(100):
permutation = random_permutation(items, 5) permutation = mi.random_permutation(items, 5)
eq_(len(permutation), 5) self.assertEqual(len(permutation), 5)
permutation_set = set(permutation) permutation_set = set(permutation)
ok_(permutation_set <= item_set) self.assertLessEqual(permutation_set, item_set)
all_items |= permutation_set all_items |= permutation_set
eq_(all_items, item_set) self.assertEqual(all_items, item_set)
class RandomCombinationTests(TestCase): class RandomCombinationTests(TestCase):
@ -397,18 +539,20 @@ class RandomCombinationTests(TestCase):
samplings of random combinations""" samplings of random combinations"""
items = range(15) items = range(15)
all_items = set() all_items = set()
for _ in xrange(50): for _ in range(50):
combination = random_combination(items, 5) combination = mi.random_combination(items, 5)
all_items |= set(combination) all_items |= set(combination)
eq_(all_items, set(items)) self.assertEqual(all_items, set(items))
def test_no_replacement(self): def test_no_replacement(self):
"""ensure that elements are sampled without replacement""" """ensure that elements are sampled without replacement"""
items = range(15) items = range(15)
for _ in xrange(50): for _ in range(50):
combination = random_combination(items, len(items)) combination = mi.random_combination(items, len(items))
eq_(len(combination), len(set(combination))) self.assertEqual(len(combination), len(set(combination)))
assert_raises(ValueError, random_combination, items, len(items) + 1) self.assertRaises(
ValueError, lambda: mi.random_combination(items, len(items) + 1)
)
class RandomCombinationWithReplacementTests(TestCase): class RandomCombinationWithReplacementTests(TestCase):
@ -417,17 +561,56 @@ class RandomCombinationWithReplacementTests(TestCase):
def test_replacement(self): def test_replacement(self):
"""ensure that elements are sampled with replacement""" """ensure that elements are sampled with replacement"""
items = range(5) items = range(5)
combo = random_combination_with_replacement(items, len(items) * 2) combo = mi.random_combination_with_replacement(items, len(items) * 2)
eq_(2 * len(items), len(combo)) self.assertEqual(2 * len(items), len(combo))
if len(set(combo)) == len(combo): if len(set(combo)) == len(combo):
raise AssertionError("Combination contained no duplicates") raise AssertionError("Combination contained no duplicates")
def test_psuedorandomness(self): def test_pseudorandomness(self):
"""ensure different subsets of the iterable get returned over many """ensure different subsets of the iterable get returned over many
samplings of random combinations""" samplings of random combinations"""
items = range(15) items = range(15)
all_items = set() all_items = set()
for _ in xrange(50): for _ in range(50):
combination = random_combination_with_replacement(items, 5) combination = mi.random_combination_with_replacement(items, 5)
all_items |= set(combination) all_items |= set(combination)
eq_(all_items, set(items)) self.assertEqual(all_items, set(items))
class NthCombinationTests(TestCase):
def test_basic(self):
iterable = 'abcdefg'
r = 4
for index, expected in enumerate(combinations(iterable, r)):
actual = mi.nth_combination(iterable, r, index)
self.assertEqual(actual, expected)
def test_long(self):
actual = mi.nth_combination(range(180), 4, 2000000)
expected = (2, 12, 35, 126)
self.assertEqual(actual, expected)
def test_invalid_r(self):
for r in (-1, 3):
with self.assertRaises(ValueError):
mi.nth_combination([], r, 0)
def test_invalid_index(self):
with self.assertRaises(IndexError):
mi.nth_combination('abcdefg', 3, -36)
class PrependTests(TestCase):
def test_basic(self):
value = 'a'
iterator = iter('bcdefg')
actual = list(mi.prepend(value, iterator))
expected = list('abcdefg')
self.assertEqual(actual, expected)
def test_multiple(self):
value = 'ab'
iterator = iter('cdefg')
actual = tuple(mi.prepend(value, iterator))
expected = ('ab',) + tuple('cdefg')
self.assertEqual(actual, expected)

View file

@ -1,25 +1,3 @@
#
# Copyright (c) 2010 Mikhail Gusarov
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
""" """
path.py - An object representing a path to a file or directory. path.py - An object representing a path to a file or directory.
@ -29,8 +7,18 @@ Example::
from path import Path from path import Path
d = Path('/home/guido/bin') d = Path('/home/guido/bin')
# Globbing
for f in d.files('*.py'): for f in d.files('*.py'):
f.chmod(0o755) f.chmod(0o755)
# Changing the working directory:
with Path("somewhere"):
# cwd in now `somewhere`
...
# Concatenate paths with /
foo_txt = Path("bar") / "foo.txt"
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
@ -41,7 +29,6 @@ import os
import fnmatch import fnmatch
import glob import glob
import shutil import shutil
import codecs
import hashlib import hashlib
import errno import errno
import tempfile import tempfile
@ -50,8 +37,10 @@ import operator
import re import re
import contextlib import contextlib
import io import io
from distutils import dir_util
import importlib import importlib
import itertools
import platform
import ntpath
try: try:
import win32security import win32security
@ -77,22 +66,17 @@ string_types = str,
text_type = str text_type = str
getcwdu = os.getcwd getcwdu = os.getcwd
def surrogate_escape(error):
"""
Simulate the Python 3 ``surrogateescape`` handler, but for Python 2 only.
"""
chars = error.object[error.start:error.end]
assert len(chars) == 1
val = ord(chars)
val += 0xdc00
return __builtin__.unichr(val), error.end
if PY2: if PY2:
import __builtin__ import __builtin__
string_types = __builtin__.basestring, string_types = __builtin__.basestring,
text_type = __builtin__.unicode text_type = __builtin__.unicode
getcwdu = os.getcwdu getcwdu = os.getcwdu
codecs.register_error('surrogateescape', surrogate_escape) map = itertools.imap
filter = itertools.ifilter
FileNotFoundError = OSError
itertools.filterfalse = itertools.ifilterfalse
@contextlib.contextmanager @contextlib.contextmanager
def io_error_compat(): def io_error_compat():
@ -107,7 +91,8 @@ def io_error_compat():
############################################################################## ##############################################################################
__all__ = ['Path', 'CaseInsensitivePattern']
__all__ = ['Path', 'TempDir', 'CaseInsensitivePattern']
LINESEPS = ['\r\n', '\r', '\n'] LINESEPS = ['\r\n', '\r', '\n']
@ -119,8 +104,8 @@ U_NL_END = re.compile(r'(?:{0})$'.format(U_NEWLINE.pattern))
try: try:
import pkg_resources import importlib_metadata
__version__ = pkg_resources.require('path.py')[0].version __version__ = importlib_metadata.version('path.py')
except Exception: except Exception:
__version__ = 'unknown' __version__ = 'unknown'
@ -131,7 +116,7 @@ class TreeWalkWarning(Warning):
# from jaraco.functools # from jaraco.functools
def compose(*funcs): def compose(*funcs):
compose_two = lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs)) compose_two = lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs)) # noqa
return functools.reduce(compose_two, funcs) return functools.reduce(compose_two, funcs)
@ -170,6 +155,60 @@ class multimethod(object):
) )
class matchers(object):
# TODO: make this class a module
@staticmethod
def load(param):
"""
If the supplied parameter is a string, assume it's a simple
pattern.
"""
return (
matchers.Pattern(param) if isinstance(param, string_types)
else param if param is not None
else matchers.Null()
)
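# Rough illustration of the dispatch above (not part of the diff):
#   matchers.load('*.py')           -> a matchers.Pattern for that glob
#   matchers.load(None)             -> matchers.Null(), which matches everything
#   matchers.load(matchers.Null())  -> returned unchanged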
class Base(object):
pass
class Null(Base):
def __call__(self, path):
return True
class Pattern(Base):
def __init__(self, pattern):
self.pattern = pattern
def get_pattern(self, normcase):
try:
return self._pattern
except AttributeError:
pass
self._pattern = normcase(self.pattern)
return self._pattern
def __call__(self, path):
normcase = getattr(self, 'normcase', path.module.normcase)
pattern = self.get_pattern(normcase)
return fnmatch.fnmatchcase(normcase(path.name), pattern)
class CaseInsensitive(Pattern):
"""
A Pattern with a ``'normcase'`` property, suitable for passing to
:meth:`listdir`, :meth:`dirs`, :meth:`files`, :meth:`walk`,
:meth:`walkdirs`, or :meth:`walkfiles` to match case-insensitively.
For example, to get all files ending in .py, .Py, .pY, or .PY in the
current directory::
from path import Path, matchers
Path('.').files(matchers.CaseInsensitive('*.py'))
"""
normcase = staticmethod(ntpath.normcase)
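A quick illustration of the three interchangeable forms now accepted by listdir(), files(), dirs(), and the walk family, based on the matchers API introduced in this hunk (the directory name is illustrative)::

    from path import Path, matchers

    d = Path('.')
    d.files('*.py')                            # plain string, wrapped by matchers.load()
    d.files(matchers.Pattern('*.py'))          # explicit case-sensitive pattern
    d.files(matchers.CaseInsensitive('*.py'))  # matches *.py, *.PY, *.Py, ...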
class Path(text_type): class Path(text_type):
""" """
Represents a filesystem path. Represents a filesystem path.
@ -214,16 +253,6 @@ class Path(text_type):
""" """
return cls return cls
@classmethod
def _always_unicode(cls, path):
"""
Ensure the path as retrieved from a Python API, such as :func:`os.listdir`,
is a proper Unicode string.
"""
if PY3 or isinstance(path, text_type):
return path
return path.decode(sys.getfilesystemencoding(), 'surrogateescape')
# --- Special Python methods. # --- Special Python methods.
def __repr__(self): def __repr__(self):
@ -277,6 +306,9 @@ class Path(text_type):
def __exit__(self, *_): def __exit__(self, *_):
os.chdir(self._old_dir) os.chdir(self._old_dir)
def __fspath__(self):
return self
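The new __fspath__ hook lets Path participate in the os.fspath protocol (PEP 519); on Python 3.6+ it is recognized as an os.PathLike. A brief sketch::

    import os
    from path import Path

    p = Path('setup.py')
    os.fspath(p)                 # 'setup.py'
    isinstance(p, os.PathLike)   # True on Python 3.6+, now that __fspath__ is defined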
@classmethod @classmethod
def getcwd(cls): def getcwd(cls):
""" Return the current working directory as a path object. """ Return the current working directory as a path object.
@ -330,23 +362,45 @@ class Path(text_type):
return self.expandvars().expanduser().normpath() return self.expandvars().expanduser().normpath()
@property @property
def namebase(self): def stem(self):
""" The same as :meth:`name`, but with one file extension stripped off. """ The same as :meth:`name`, but with one file extension stripped off.
For example, >>> Path('/home/guido/python.tar.gz').stem
``Path('/home/guido/python.tar.gz').name == 'python.tar.gz'``, 'python.tar'
but
``Path('/home/guido/python.tar.gz').namebase == 'python.tar'``.
""" """
base, ext = self.module.splitext(self.name) base, ext = self.module.splitext(self.name)
return base return base
@property
def namebase(self):
warnings.warn("Use .stem instead of .namebase", DeprecationWarning)
return self.stem
@property @property
def ext(self): def ext(self):
""" The file extension, for example ``'.py'``. """ """ The file extension, for example ``'.py'``. """
f, ext = self.module.splitext(self) f, ext = self.module.splitext(self)
return ext return ext
def with_suffix(self, suffix):
""" Return a new path with the file suffix changed (or added, if none)
>>> Path('/home/guido/python.tar.gz').with_suffix(".foo")
Path('/home/guido/python.tar.foo')
>>> Path('python').with_suffix('.zip')
Path('python.zip')
>>> Path('filename.ext').with_suffix('zip')
Traceback (most recent call last):
...
ValueError: Invalid suffix 'zip'
"""
if not suffix.startswith('.'):
raise ValueError("Invalid suffix {suffix!r}".format(**locals()))
return self.stripext() + suffix
@property @property
def drive(self): def drive(self):
""" The drive specifier, for example ``'C:'``. """ The drive specifier, for example ``'C:'``.
@ -437,8 +491,9 @@ class Path(text_type):
@multimethod @multimethod
def joinpath(cls, first, *others): def joinpath(cls, first, *others):
""" """
Join first to zero or more :class:`Path` components, adding a separator Join first to zero or more :class:`Path` components,
character (:samp:`{first}.module.sep`) if needed. Returns a new instance of adding a separator character (:samp:`{first}.module.sep`)
if needed. Returns a new instance of
:samp:`{first}._next_class`. :samp:`{first}._next_class`.
.. seealso:: :func:`os.path.join` .. seealso:: :func:`os.path.join`
@ -516,7 +571,7 @@ class Path(text_type):
# --- Listing, searching, walking, and matching # --- Listing, searching, walking, and matching
def listdir(self, pattern=None): def listdir(self, match=None):
""" D.listdir() -> List of items in this directory. """ D.listdir() -> List of items in this directory.
Use :meth:`files` or :meth:`dirs` instead if you want a listing Use :meth:`files` or :meth:`dirs` instead if you want a listing
@ -524,46 +579,39 @@ class Path(text_type):
The elements of the list are Path objects. The elements of the list are Path objects.
With the optional `pattern` argument, this only lists With the optional `match` argument, a callable,
items whose names match the given pattern. only return items whose names match the given pattern.
.. seealso:: :meth:`files`, :meth:`dirs` .. seealso:: :meth:`files`, :meth:`dirs`
""" """
if pattern is None: match = matchers.load(match)
pattern = '*' return list(filter(match, (
return [ self / child for child in os.listdir(self)
self / child )))
for child in map(self._always_unicode, os.listdir(self))
if self._next_class(child).fnmatch(pattern)
]
def dirs(self, pattern=None): def dirs(self, *args, **kwargs):
""" D.dirs() -> List of this directory's subdirectories. """ D.dirs() -> List of this directory's subdirectories.
The elements of the list are Path objects. The elements of the list are Path objects.
This does not walk recursively into subdirectories This does not walk recursively into subdirectories
(but see :meth:`walkdirs`). (but see :meth:`walkdirs`).
With the optional `pattern` argument, this only lists Accepts parameters to :meth:`listdir`.
directories whose names match the given pattern. For
example, ``d.dirs('build-*')``.
""" """
return [p for p in self.listdir(pattern) if p.isdir()] return [p for p in self.listdir(*args, **kwargs) if p.isdir()]
def files(self, pattern=None): def files(self, *args, **kwargs):
""" D.files() -> List of the files in this directory. """ D.files() -> List of the files in this directory.
The elements of the list are Path objects. The elements of the list are Path objects.
This does not walk into subdirectories (see :meth:`walkfiles`). This does not walk into subdirectories (see :meth:`walkfiles`).
With the optional `pattern` argument, this only lists files Accepts parameters to :meth:`listdir`.
whose names match the given pattern. For example,
``d.files('*.pyc')``.
""" """
return [p for p in self.listdir(pattern) if p.isfile()] return [p for p in self.listdir(*args, **kwargs) if p.isfile()]
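Because matchers.load() passes non-string arguments through untouched, `match` may also be any callable taking a Path and returning a truth value; a hedged sketch (the predicate is made up)::

    from path import Path

    def is_backup(p):
        return p.name.endswith('~')

    stale = Path('.').listdir(match=is_backup)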
def walk(self, pattern=None, errors='strict'): def walk(self, match=None, errors='strict'):
""" D.walk() -> iterator over files and subdirs, recursively. """ D.walk() -> iterator over files and subdirs, recursively.
The iterator yields Path objects naming each child item of The iterator yields Path objects naming each child item of
@ -593,6 +641,8 @@ class Path(text_type):
raise ValueError("invalid errors parameter") raise ValueError("invalid errors parameter")
errors = vars(Handlers).get(errors, errors) errors = vars(Handlers).get(errors, errors)
match = matchers.load(match)
try: try:
childList = self.listdir() childList = self.listdir()
except Exception: except Exception:
@ -603,7 +653,7 @@ class Path(text_type):
return return
for child in childList: for child in childList:
if pattern is None or child.fnmatch(pattern): if match(child):
yield child yield child
try: try:
isdir = child.isdir() isdir = child.isdir()
@ -615,92 +665,26 @@ class Path(text_type):
isdir = False isdir = False
if isdir: if isdir:
for item in child.walk(pattern, errors): for item in child.walk(errors=errors, match=match):
yield item yield item
def walkdirs(self, pattern=None, errors='strict'): def walkdirs(self, *args, **kwargs):
""" D.walkdirs() -> iterator over subdirs, recursively. """ D.walkdirs() -> iterator over subdirs, recursively.
With the optional `pattern` argument, this yields only
directories whose names match the given pattern. For
example, ``mydir.walkdirs('*test')`` yields only directories
with names ending in ``'test'``.
The `errors=` keyword argument controls behavior when an
error occurs. The default is ``'strict'``, which causes an
exception. The other allowed values are ``'warn'`` (which
reports the error via :func:`warnings.warn()`), and ``'ignore'``.
""" """
if errors not in ('strict', 'warn', 'ignore'): return (
raise ValueError("invalid errors parameter") item
for item in self.walk(*args, **kwargs)
if item.isdir()
)
try: def walkfiles(self, *args, **kwargs):
dirs = self.dirs()
except Exception:
if errors == 'ignore':
return
elif errors == 'warn':
warnings.warn(
"Unable to list directory '%s': %s"
% (self, sys.exc_info()[1]),
TreeWalkWarning)
return
else:
raise
for child in dirs:
if pattern is None or child.fnmatch(pattern):
yield child
for subsubdir in child.walkdirs(pattern, errors):
yield subsubdir
def walkfiles(self, pattern=None, errors='strict'):
""" D.walkfiles() -> iterator over files in D, recursively. """ D.walkfiles() -> iterator over files in D, recursively.
The optional argument `pattern` limits the results to files
with names that match the pattern. For example,
``mydir.walkfiles('*.tmp')`` yields only files with the ``.tmp``
extension.
""" """
if errors not in ('strict', 'warn', 'ignore'): return (
raise ValueError("invalid errors parameter") item
for item in self.walk(*args, **kwargs)
try: if item.isfile()
childList = self.listdir() )
except Exception:
if errors == 'ignore':
return
elif errors == 'warn':
warnings.warn(
"Unable to list directory '%s': %s"
% (self, sys.exc_info()[1]),
TreeWalkWarning)
return
else:
raise
for child in childList:
try:
isfile = child.isfile()
isdir = not isfile and child.isdir()
except:
if errors == 'ignore':
continue
elif errors == 'warn':
warnings.warn(
"Unable to access '%s': %s"
% (self, sys.exc_info()[1]),
TreeWalkWarning)
continue
else:
raise
if isfile:
if pattern is None or child.fnmatch(pattern):
yield child
elif isdir:
for f in child.walkfiles(pattern, errors):
yield f
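walkdirs() and walkfiles() are now thin filters over walk(), so pattern and error handling live in one place; equivalent call sites, sketched with an illustrative directory name::

    from path import Path, matchers

    tree = Path('src')
    # every *.py file anywhere under tree, warning instead of raising on unreadable dirs
    for f in tree.walkfiles('*.py', errors='warn'):
        print(f)

    # the same traversal, case-insensitively
    for f in tree.walkfiles(matchers.CaseInsensitive('*.py')):
        print(f)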
def fnmatch(self, pattern, normcase=None): def fnmatch(self, pattern, normcase=None):
""" Return ``True`` if `self.name` matches the given `pattern`. """ Return ``True`` if `self.name` matches the given `pattern`.
@ -710,8 +694,8 @@ class Path(text_type):
attribute, it is applied to the name and path prior to comparison. attribute, it is applied to the name and path prior to comparison.
`normcase` - (optional) A function used to normalize the pattern and `normcase` - (optional) A function used to normalize the pattern and
filename before matching. Defaults to :meth:`self.module`, which defaults filename before matching. Defaults to :meth:`self.module`, which
to :meth:`os.path.normcase`. defaults to :meth:`os.path.normcase`.
.. seealso:: :func:`fnmatch.fnmatch` .. seealso:: :func:`fnmatch.fnmatch`
""" """
@ -730,10 +714,32 @@ class Path(text_type):
of all the files users have in their :file:`bin` directories. of all the files users have in their :file:`bin` directories.
.. seealso:: :func:`glob.glob` .. seealso:: :func:`glob.glob`
.. note:: Glob is **not** recursive, even when using ``**``.
To do recursive globbing see :func:`walk`,
:func:`walkdirs` or :func:`walkfiles`.
""" """
cls = self._next_class cls = self._next_class
return [cls(s) for s in glob.glob(self / pattern)] return [cls(s) for s in glob.glob(self / pattern)]
def iglob(self, pattern):
""" Return an iterator of Path objects that match the pattern.
`pattern` - a path relative to this directory, with wildcards.
For example, ``Path('/users').iglob('*/bin/*')`` returns an
iterator of all the files users have in their :file:`bin`
directories.
.. seealso:: :func:`glob.iglob`
.. note:: Glob is **not** recursive, even when using ``**``.
To do recursive globbing see :func:`walk`,
:func:`walkdirs` or :func:`walkfiles`.
"""
cls = self._next_class
return (cls(s) for s in glob.iglob(self / pattern))
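iglob() mirrors glob() but yields results lazily, which helps when only the first few matches are needed; a brief sketch (the directory and pattern are illustrative)::

    from path import Path

    # first match, if any, without building the full list that glob() would return
    first_log = next(Path('/var/log').iglob('*.log'), None)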
# #
# --- Reading or writing an entire file at once. # --- Reading or writing an entire file at once.
@ -882,14 +888,8 @@ class Path(text_type):
translated to ``'\n'``. If ``False``, newline characters are translated to ``'\n'``. If ``False``, newline characters are
stripped off. Default is ``True``. stripped off. Default is ``True``.
This uses ``'U'`` mode.
.. seealso:: :meth:`text` .. seealso:: :meth:`text`
""" """
if encoding is None and retain:
with self.open('U') as f:
return f.readlines()
else:
return self.text(encoding, errors).splitlines(retain) return self.text(encoding, errors).splitlines(retain)
def write_lines(self, lines, encoding=None, errors='strict', def write_lines(self, lines, encoding=None, errors='strict',
@ -931,14 +931,15 @@ class Path(text_type):
to read the file later. to read the file later.
""" """
with self.open('ab' if append else 'wb') as f: with self.open('ab' if append else 'wb') as f:
for l in lines: for line in lines:
isUnicode = isinstance(l, text_type) isUnicode = isinstance(line, text_type)
if linesep is not None: if linesep is not None:
pattern = U_NL_END if isUnicode else NL_END pattern = U_NL_END if isUnicode else NL_END
l = pattern.sub('', l) + linesep line = pattern.sub('', line) + linesep
if isUnicode: if isUnicode:
l = l.encode(encoding or sys.getdefaultencoding(), errors) line = line.encode(
f.write(l) encoding or sys.getdefaultencoding(), errors)
f.write(line)
def read_md5(self): def read_md5(self):
""" Calculate the md5 hash for this file. """ Calculate the md5 hash for this file.
@ -952,8 +953,8 @@ class Path(text_type):
def _hash(self, hash_name): def _hash(self, hash_name):
""" Returns a hash object for the file at the current path. """ Returns a hash object for the file at the current path.
`hash_name` should be a hash algo name (such as ``'md5'`` or ``'sha1'``) `hash_name` should be a hash algo name (such as ``'md5'``
that's available in the :mod:`hashlib` module. or ``'sha1'``) that's available in the :mod:`hashlib` module.
""" """
m = hashlib.new(hash_name) m = hashlib.new(hash_name)
for chunk in self.chunks(8192, mode="rb"): for chunk in self.chunks(8192, mode="rb"):
@ -1176,7 +1177,8 @@ class Path(text_type):
gid = grp.getgrnam(gid).gr_gid gid = grp.getgrnam(gid).gr_gid
os.chown(self, uid, gid) os.chown(self, uid, gid)
else: else:
raise NotImplementedError("Ownership not available on this platform.") msg = "Ownership not available on this platform."
raise NotImplementedError(msg)
return self return self
def rename(self, new): def rename(self, new):
@ -1236,7 +1238,8 @@ class Path(text_type):
self.rmdir() self.rmdir()
except OSError: except OSError:
_, e, _ = sys.exc_info() _, e, _ = sys.exc_info()
if e.errno != errno.ENOTEMPTY and e.errno != errno.EEXIST: bypass_codes = errno.ENOTEMPTY, errno.EEXIST, errno.ENOENT
if e.errno not in bypass_codes:
raise raise
return self return self
@ -1277,9 +1280,8 @@ class Path(text_type):
file does not exist. """ file does not exist. """
try: try:
self.unlink() self.unlink()
except OSError: except FileNotFoundError as exc:
_, e, _ = sys.exc_info() if PY2 and exc.errno != errno.ENOENT:
if e.errno != errno.ENOENT:
raise raise
return self return self
@ -1306,11 +1308,16 @@ class Path(text_type):
return self._next_class(newpath) return self._next_class(newpath)
if hasattr(os, 'symlink'): if hasattr(os, 'symlink'):
def symlink(self, newlink): def symlink(self, newlink=None):
""" Create a symbolic link at `newlink`, pointing here. """ Create a symbolic link at `newlink`, pointing here.
If newlink is not supplied, the symbolic link will assume
the name self.basename(), creating the link in the cwd.
.. seealso:: :func:`os.symlink` .. seealso:: :func:`os.symlink`
""" """
if newlink is None:
newlink = self.basename()
os.symlink(self, newlink) os.symlink(self, newlink)
return self._next_class(newlink) return self._next_class(newlink)
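With the new default argument, symlink() called without `newlink` creates a link named after the target's basename in the current directory; a small sketch with hypothetical paths::

    from path import Path

    target = Path('/opt/tool/bin/tool')
    target.symlink()                        # creates ./tool -> /opt/tool/bin/tool
    target.symlink('/usr/local/bin/tool')   # explicit link location, as before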
@ -1368,30 +1375,60 @@ class Path(text_type):
cd = chdir cd = chdir
def merge_tree(self, dst, symlinks=False, *args, **kwargs): def merge_tree(
self, dst, symlinks=False,
# *
update=False,
copy_function=shutil.copy2,
ignore=lambda dir, contents: []):
""" """
Copy entire contents of self to dst, overwriting existing Copy entire contents of self to dst, overwriting existing
contents in dst with those in self. contents in dst with those in self.
If the additional keyword `update` is True, each Pass ``symlinks=True`` to copy symbolic links as links.
`src` will only be copied if `dst` does not exist,
or `src` is newer than `dst`.
Note that the technique employed stages the files in a temporary Accepts a ``copy_function``, similar to copytree.
directory first, so this function is not suitable for merging
trees with large files, especially if the temporary directory To avoid overwriting newer files, supply a copy function
is not capable of storing a copy of the entire source tree. wrapped in ``only_newer``. For example::
src.merge_tree(dst, copy_function=only_newer(shutil.copy2))
""" """
update = kwargs.pop('update', False) dst = self._next_class(dst)
with tempdir() as _temp_dir: dst.makedirs_p()
# first copy the tree to a stage directory to support
# the parameters and behavior of copytree. if update:
stage = _temp_dir / str(hash(self)) warnings.warn(
self.copytree(stage, symlinks, *args, **kwargs) "Update is deprecated; "
# now copy everything from the stage directory using "use copy_function=only_newer(shutil.copy2)",
# the semantics of dir_util.copy_tree DeprecationWarning,
dir_util.copy_tree(stage, dst, preserve_symlinks=symlinks, stacklevel=2,
update=update) )
copy_function = only_newer(copy_function)
sources = self.listdir()
_ignored = ignore(self, [item.name for item in sources])
def ignored(item):
return item.name in _ignored
for source in itertools.filterfalse(ignored, sources):
dest = dst / source.name
if symlinks and source.islink():
target = source.readlink()
target.symlink(dest)
elif source.isdir():
source.merge_tree(
dest,
symlinks=symlinks,
update=update,
copy_function=copy_function,
ignore=ignore,
)
else:
copy_function(source, dest)
self.copystat(dst)
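merge_tree() no longer stages through a temporary directory; it walks the source itself, accepts a copytree-style ``ignore`` callable, and pairs naturally with only_newer for incremental syncs. A hedged usage sketch with hypothetical directory names::

    import shutil
    from path import Path, only_newer

    src = Path('build/site')
    dst = Path('/srv/www')

    src.merge_tree(
        dst,
        symlinks=True,
        copy_function=only_newer(shutil.copy2),
        ignore=shutil.ignore_patterns('*.pyc', '.git'),
    )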
# #
# --- Special stuff from os # --- Special stuff from os
@ -1410,19 +1447,23 @@ class Path(text_type):
# in-place re-writing, courtesy of Martijn Pieters # in-place re-writing, courtesy of Martijn Pieters
# http://www.zopatista.com/python/2013/11/26/inplace-file-rewriting/ # http://www.zopatista.com/python/2013/11/26/inplace-file-rewriting/
@contextlib.contextmanager @contextlib.contextmanager
def in_place(self, mode='r', buffering=-1, encoding=None, errors=None, def in_place(
newline=None, backup_extension=None): self, mode='r', buffering=-1, encoding=None, errors=None,
newline=None, backup_extension=None,
):
""" """
A context in which a file may be re-written in-place with new content. A context in which a file may be re-written in-place with
new content.
Yields a tuple of :samp:`({readable}, {writable})` file objects, where `writable` Yields a tuple of :samp:`({readable}, {writable})` file
replaces `readable`. objects, where `writable` replaces `readable`.
If an exception occurs, the old file is restored, removing the If an exception occurs, the old file is restored, removing the
written data. written data.
Mode *must not* use ``'w'``, ``'a'``, or ``'+'``; only read-only-modes are Mode *must not* use ``'w'``, ``'a'``, or ``'+'``; only
allowed. A :exc:`ValueError` is raised on invalid modes. read-only-modes are allowed. A :exc:`ValueError` is raised
on invalid modes.
For example, to add line numbers to a file:: For example, to add line numbers to a file::
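The example body referenced here falls outside this hunk; a minimal sketch of the pattern the docstring describes, with a hypothetical file name::

    from path import Path

    doc = Path('notes.txt')
    with doc.in_place() as (reader, writer):
        for number, line in enumerate(reader, 1):
            writer.write('{:>4} {}'.format(number, line))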
@ -1448,22 +1489,28 @@ class Path(text_type):
except os.error: except os.error:
pass pass
os.rename(self, backup_fn) os.rename(self, backup_fn)
readable = io.open(backup_fn, mode, buffering=buffering, readable = io.open(
encoding=encoding, errors=errors, newline=newline) backup_fn, mode, buffering=buffering,
encoding=encoding, errors=errors, newline=newline,
)
try: try:
perm = os.fstat(readable.fileno()).st_mode perm = os.fstat(readable.fileno()).st_mode
except OSError: except OSError:
writable = open(self, 'w' + mode.replace('r', ''), writable = open(
self, 'w' + mode.replace('r', ''),
buffering=buffering, encoding=encoding, errors=errors, buffering=buffering, encoding=encoding, errors=errors,
newline=newline) newline=newline,
)
else: else:
os_mode = os.O_CREAT | os.O_WRONLY | os.O_TRUNC os_mode = os.O_CREAT | os.O_WRONLY | os.O_TRUNC
if hasattr(os, 'O_BINARY'): if hasattr(os, 'O_BINARY'):
os_mode |= os.O_BINARY os_mode |= os.O_BINARY
fd = os.open(self, os_mode, perm) fd = os.open(self, os_mode, perm)
writable = io.open(fd, "w" + mode.replace('r', ''), writable = io.open(
fd, "w" + mode.replace('r', ''),
buffering=buffering, encoding=encoding, errors=errors, buffering=buffering, encoding=encoding, errors=errors,
newline=newline) newline=newline,
)
try: try:
if hasattr(os, 'chmod'): if hasattr(os, 'chmod'):
os.chmod(self, perm) os.chmod(self, perm)
@ -1516,6 +1563,23 @@ class Path(text_type):
return functools.partial(SpecialResolver, cls) return functools.partial(SpecialResolver, cls)
def only_newer(copy_func):
"""
Wrap a copy function (like shutil.copy2) to return
the dst if it's newer than the source.
"""
@functools.wraps(copy_func)
def wrapper(src, dst, *args, **kwargs):
is_newer_dst = (
dst.exists()
and dst.getmtime() >= src.getmtime()
)
if is_newer_dst:
return dst
return copy_func(src, dst, *args, **kwargs)
return wrapper
class SpecialResolver(object): class SpecialResolver(object):
class ResolverScope: class ResolverScope:
def __init__(self, paths, scope): def __init__(self, paths, scope):
@ -1584,14 +1648,15 @@ class Multi:
) )
class tempdir(Path): class TempDir(Path):
""" """
A temporary directory via :func:`tempfile.mkdtemp`, and constructed with the A temporary directory via :func:`tempfile.mkdtemp`, and
same parameters that you can use as a context manager. constructed with the same parameters that you can use
as a context manager.
Example: Example::
with tempdir() as d: with TempDir() as d:
# do stuff with the Path object "d" # do stuff with the Path object "d"
# here the directory is deleted automatically # here the directory is deleted automatically
@ -1606,19 +1671,27 @@ class tempdir(Path):
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
dirname = tempfile.mkdtemp(*args, **kwargs) dirname = tempfile.mkdtemp(*args, **kwargs)
return super(tempdir, cls).__new__(cls, dirname) return super(TempDir, cls).__new__(cls, dirname)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
pass pass
def __enter__(self): def __enter__(self):
return self # TempDir should return a Path version of itself and not itself
# so that a second context manager does not create a second
# temporary directory, but rather changes CWD to the location
# of the temporary directory.
return self._next_class(self)
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
if not exc_value: if not exc_value:
self.rmtree() self.rmtree()
# For backwards compatibility.
tempdir = TempDir
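Because __enter__ now returns a plain Path, entering the object a second time falls through to Path's chdir-style context manager instead of minting another temporary directory; a sketch::

    from path import TempDir

    with TempDir() as d:            # d is a Path; the directory already exists
        with d:                     # nested `with` changes cwd into d...
            (d / 'scratch.txt').touch()
        # ...and restores the previous cwd here
    # on a clean exit the directory and its contents are removed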
def _multi_permission_mask(mode): def _multi_permission_mask(mode):
""" """
Support multiple, comma-separated Unix chmod symbolic modes. Support multiple, comma-separated Unix chmod symbolic modes.
@ -1626,7 +1699,8 @@ def _multi_permission_mask(mode):
>>> _multi_permission_mask('a=r,u+w')(0) == 0o644 >>> _multi_permission_mask('a=r,u+w')(0) == 0o644
True True
""" """
compose = lambda f, g: lambda *args, **kwargs: g(f(*args, **kwargs)) def compose(f, g):
return lambda *args, **kwargs: g(f(*args, **kwargs))
return functools.reduce(compose, map(_permission_mask, mode.split(','))) return functools.reduce(compose, map(_permission_mask, mode.split(',')))
@ -1692,31 +1766,56 @@ def _permission_mask(mode):
return functools.partial(op_map[op], mask) return functools.partial(op_map[op], mask)
class CaseInsensitivePattern(text_type): class CaseInsensitivePattern(matchers.CaseInsensitive):
def __init__(self, value):
warnings.warn(
"Use matchers.CaseInsensitive instead",
DeprecationWarning,
stacklevel=2,
)
super(CaseInsensitivePattern, self).__init__(value)
class FastPath(Path):
def __init__(self, *args, **kwargs):
warnings.warn(
"Use Path, as FastPath no longer holds any advantage",
DeprecationWarning,
stacklevel=2,
)
super(FastPath, self).__init__(*args, **kwargs)
def patch_for_linux_python2():
""" """
A string with a ``'normcase'`` property, suitable for passing to As reported in #130, when Linux users create filenames
:meth:`listdir`, :meth:`dirs`, :meth:`files`, :meth:`walk`, not in the file system encoding, it creates problems on
:meth:`walkdirs`, or :meth:`walkfiles` to match case-insensitive. Python 2. This function attempts to patch the os module
to make it behave more like that on Python 3.
For example, to get all files ending in .py, .Py, .pY, or .PY in the
current directory::
from path import Path, CaseInsensitivePattern as ci
Path('.').files(ci('*.py'))
""" """
if not PY2 or platform.system() != 'Linux':
return
@property try:
def normcase(self): import backports.os
return __import__('ntpath').normcase except ImportError:
return
######################## class OS:
# Backward-compatibility """
class path(Path): The proxy to the os module
def __new__(cls, *args, **kwargs): """
msg = "path is deprecated. Use Path instead." def __init__(self, wrapped):
warnings.warn(msg, DeprecationWarning) self._orig = wrapped
return Path.__new__(cls, *args, **kwargs)
def __getattr__(self, name):
return getattr(self._orig, name)
def listdir(self, *args, **kwargs):
items = self._orig.listdir(*args, **kwargs)
return list(map(backports.os.fsdecode, items))
globals().update(os=OS(os))
__all__ += ['path'] patch_for_linux_python2()
########################

View file

@ -22,25 +22,54 @@ import os
import sys import sys
import shutil import shutil
import time import time
import types
import ntpath import ntpath
import posixpath import posixpath
import textwrap import textwrap
import platform import platform
import importlib import importlib
import operator
import datetime
import subprocess
import re
import pytest import pytest
import packaging.version
from path import Path, tempdir import path
from path import CaseInsensitivePattern as ci from path import TempDir
from path import matchers
from path import SpecialResolver from path import SpecialResolver
from path import Multi from path import Multi
Path = None
def p(**choices): def p(**choices):
""" Choose a value from several possible values, based on os.name """ """ Choose a value from several possible values, based on os.name """
return choices[os.name] return choices[os.name]
@pytest.fixture(autouse=True, params=[path.Path])
def path_class(request, monkeypatch):
"""
Invoke tests on any number of Path classes.
"""
monkeypatch.setitem(globals(), 'Path', request.param)
def mac_version(target, comparator=operator.ge):
"""
Return True if on a Mac whose version passes the comparator.
"""
current_ver = packaging.version.parse(platform.mac_ver()[0])
target_ver = packaging.version.parse(target)
return (
platform.system() == 'Darwin'
and comparator(current_ver, target_ver)
)
class TestBasics: class TestBasics:
def test_relpath(self): def test_relpath(self):
root = Path(p(nt='C:\\', posix='/')) root = Path(p(nt='C:\\', posix='/'))
@ -51,14 +80,14 @@ class TestBasics:
up = Path(os.pardir) up = Path(os.pardir)
# basics # basics
assert root.relpathto(boz) == Path('foo')/'bar'/'Baz'/'Boz' assert root.relpathto(boz) == Path('foo') / 'bar' / 'Baz' / 'Boz'
assert bar.relpathto(boz) == Path('Baz')/'Boz' assert bar.relpathto(boz) == Path('Baz') / 'Boz'
assert quux.relpathto(boz) == up/'bar'/'Baz'/'Boz' assert quux.relpathto(boz) == up / 'bar' / 'Baz' / 'Boz'
assert boz.relpathto(quux) == up/up/up/'quux' assert boz.relpathto(quux) == up / up / up / 'quux'
assert boz.relpathto(bar) == up/up assert boz.relpathto(bar) == up / up
# Path is not the first element in concatenation # Path is not the first element in concatenation
assert root.relpathto(boz) == 'foo'/Path('bar')/'Baz'/'Boz' assert root.relpathto(boz) == 'foo' / Path('bar') / 'Baz' / 'Boz'
# x.relpathto(x) == curdir # x.relpathto(x) == curdir
assert root.relpathto(root) == os.curdir assert root.relpathto(root) == os.curdir
@ -112,7 +141,7 @@ class TestBasics:
# Test p1/p1. # Test p1/p1.
p1 = Path("foo") p1 = Path("foo")
p2 = Path("bar") p2 = Path("bar")
assert p1/p2 == p(nt='foo\\bar', posix='foo/bar') assert p1 / p2 == p(nt='foo\\bar', posix='foo/bar')
def test_properties(self): def test_properties(self):
# Create sample path object. # Create sample path object.
@ -207,6 +236,30 @@ class TestBasics:
assert res2 == 'foo/bar' assert res2 == 'foo/bar'
class TestPerformance:
@pytest.mark.skipif(
path.PY2,
reason="Tests fail frequently on Python 2; see #153")
def test_import_time(self, monkeypatch):
"""
Import of path.py should take less than 100ms.
Run tests in a subprocess to isolate from test suite overhead.
"""
cmd = [
sys.executable,
'-m', 'timeit',
'-n', '1',
'-r', '1',
'import path',
]
res = subprocess.check_output(cmd, universal_newlines=True)
dur = re.search(r'(\d+) msec per loop', res).group(1)
limit = datetime.timedelta(milliseconds=100)
duration = datetime.timedelta(milliseconds=int(dur))
assert duration < limit
class TestSelfReturn: class TestSelfReturn:
""" """
Some methods don't necessarily return any value (e.g. makedirs, Some methods don't necessarily return any value (e.g. makedirs,
@ -246,7 +299,7 @@ class TestSelfReturn:
class TestScratchDir: class TestScratchDir:
""" """
Tests that run in a temporary directory (does not test tempdir class) Tests that run in a temporary directory (does not test TempDir class)
""" """
def test_context_manager(self, tmpdir): def test_context_manager(self, tmpdir):
"""Can be used as context manager for chdir.""" """Can be used as context manager for chdir."""
@ -282,12 +335,12 @@ class TestScratchDir:
ct = f.ctime ct = f.ctime
assert t0 <= ct <= t1 assert t0 <= ct <= t1
time.sleep(threshold*2) time.sleep(threshold * 2)
fobj = open(f, 'ab') fobj = open(f, 'ab')
fobj.write('some bytes'.encode('utf-8')) fobj.write('some bytes'.encode('utf-8'))
fobj.close() fobj.close()
time.sleep(threshold*2) time.sleep(threshold * 2)
t2 = time.time() - threshold t2 = time.time() - threshold
f.touch() f.touch()
t3 = time.time() + threshold t3 = time.time() + threshold
@ -305,9 +358,12 @@ class TestScratchDir:
assert ct == ct2 assert ct == ct2
assert ct2 < t2 assert ct2 < t2
else: else:
# On other systems, it might be the CHANGE time assert (
# (especially on Unix, time of inode changes) # ctime is unchanged
assert ct == ct2 or ct2 == f.mtime ct == ct2 or
# ctime is approximately the mtime
ct2 == pytest.approx(f.mtime, 0.001)
)
def test_listing(self, tmpdir): def test_listing(self, tmpdir):
d = Path(tmpdir) d = Path(tmpdir)
@ -330,6 +386,11 @@ class TestScratchDir:
assert d.glob('*') == [af] assert d.glob('*') == [af]
assert d.glob('*.html') == [] assert d.glob('*.html') == []
assert d.glob('testfile') == [] assert d.glob('testfile') == []
# .iglob matches .glob but as an iterator.
assert list(d.iglob('*')) == d.glob('*')
assert isinstance(d.iglob('*'), types.GeneratorType)
finally: finally:
af.remove() af.remove()
@ -348,9 +409,17 @@ class TestScratchDir:
for f in files: for f in files:
try: try:
f.remove() f.remove()
except: except Exception:
pass pass
@pytest.mark.xfail(
mac_version('10.13'),
reason="macOS disallows invalid encodings",
)
@pytest.mark.xfail(
platform.system() == 'Windows' and path.PY3,
reason="Can't write latin characters. See #133",
)
def test_listdir_other_encoding(self, tmpdir): def test_listdir_other_encoding(self, tmpdir):
""" """
Some filesystems allow non-character sequences in path names. Some filesystems allow non-character sequences in path names.
@ -498,28 +567,28 @@ class TestScratchDir:
def test_patterns(self, tmpdir): def test_patterns(self, tmpdir):
d = Path(tmpdir) d = Path(tmpdir)
names = ['x.tmp', 'x.xtmp', 'x2g', 'x22', 'x.txt'] names = ['x.tmp', 'x.xtmp', 'x2g', 'x22', 'x.txt']
dirs = [d, d/'xdir', d/'xdir.tmp', d/'xdir.tmp'/'xsubdir'] dirs = [d, d / 'xdir', d / 'xdir.tmp', d / 'xdir.tmp' / 'xsubdir']
for e in dirs: for e in dirs:
if not e.isdir(): if not e.isdir():
e.makedirs() e.makedirs()
for name in names: for name in names:
(e/name).touch() (e / name).touch()
self.assertList(d.listdir('*.tmp'), [d/'x.tmp', d/'xdir.tmp']) self.assertList(d.listdir('*.tmp'), [d / 'x.tmp', d / 'xdir.tmp'])
self.assertList(d.files('*.tmp'), [d/'x.tmp']) self.assertList(d.files('*.tmp'), [d / 'x.tmp'])
self.assertList(d.dirs('*.tmp'), [d/'xdir.tmp']) self.assertList(d.dirs('*.tmp'), [d / 'xdir.tmp'])
self.assertList(d.walk(), [e for e in dirs self.assertList(d.walk(), [e for e in dirs
if e != d] + [e/n for e in dirs if e != d] + [e / n for e in dirs
for n in names]) for n in names])
self.assertList(d.walk('*.tmp'), self.assertList(d.walk('*.tmp'),
[e/'x.tmp' for e in dirs] + [d/'xdir.tmp']) [e / 'x.tmp' for e in dirs] + [d / 'xdir.tmp'])
self.assertList(d.walkfiles('*.tmp'), [e/'x.tmp' for e in dirs]) self.assertList(d.walkfiles('*.tmp'), [e / 'x.tmp' for e in dirs])
self.assertList(d.walkdirs('*.tmp'), [d/'xdir.tmp']) self.assertList(d.walkdirs('*.tmp'), [d / 'xdir.tmp'])
def test_unicode(self, tmpdir): def test_unicode(self, tmpdir):
d = Path(tmpdir) d = Path(tmpdir)
p = d/'unicode.txt' p = d / 'unicode.txt'
def test(enc): def test(enc):
""" Test that path works with the specified encoding, """ Test that path works with the specified encoding,
@ -527,18 +596,22 @@ class TestScratchDir:
Unicode codepoints. Unicode codepoints.
""" """
given = ('Hello world\n' given = (
'Hello world\n'
'\u0d0a\u0a0d\u0d15\u0a15\r\n' '\u0d0a\u0a0d\u0d15\u0a15\r\n'
'\u0d0a\u0a0d\u0d15\u0a15\x85' '\u0d0a\u0a0d\u0d15\u0a15\x85'
'\u0d0a\u0a0d\u0d15\u0a15\u2028' '\u0d0a\u0a0d\u0d15\u0a15\u2028'
'\r' '\r'
'hanging') 'hanging'
clean = ('Hello world\n' )
clean = (
'Hello world\n'
'\u0d0a\u0a0d\u0d15\u0a15\n' '\u0d0a\u0a0d\u0d15\u0a15\n'
'\u0d0a\u0a0d\u0d15\u0a15\n' '\u0d0a\u0a0d\u0d15\u0a15\n'
'\u0d0a\u0a0d\u0d15\u0a15\n' '\u0d0a\u0a0d\u0d15\u0a15\n'
'\n' '\n'
'hanging') 'hanging'
)
givenLines = [ givenLines = [
('Hello world\n'), ('Hello world\n'),
('\u0d0a\u0a0d\u0d15\u0a15\r\n'), ('\u0d0a\u0a0d\u0d15\u0a15\r\n'),
@ -581,8 +654,9 @@ class TestScratchDir:
return return
# Write Unicode to file using path.write_text(). # Write Unicode to file using path.write_text().
cleanNoHanging = clean + '\n' # This test doesn't work with a # This test doesn't work with a hanging line.
# hanging line. cleanNoHanging = clean + '\n'
p.write_text(cleanNoHanging, enc) p.write_text(cleanNoHanging, enc)
p.write_text(cleanNoHanging, enc, append=True) p.write_text(cleanNoHanging, enc, append=True)
# Check the result. # Check the result.
@ -641,7 +715,7 @@ class TestScratchDir:
test('UTF-16') test('UTF-16')
def test_chunks(self, tmpdir): def test_chunks(self, tmpdir):
p = (tempdir() / 'test.txt').touch() p = (TempDir() / 'test.txt').touch()
txt = "0123456789" txt = "0123456789"
size = 5 size = 5
p.write_text(txt) p.write_text(txt)
@ -650,16 +724,18 @@ class TestScratchDir:
assert i == len(txt) / size - 1 assert i == len(txt) / size - 1
@pytest.mark.skipif(not hasattr(os.path, 'samefile'), @pytest.mark.skipif(
reason="samefile not present") not hasattr(os.path, 'samefile'),
reason="samefile not present",
)
def test_samefile(self, tmpdir): def test_samefile(self, tmpdir):
f1 = (tempdir() / '1.txt').touch() f1 = (TempDir() / '1.txt').touch()
f1.write_text('foo') f1.write_text('foo')
f2 = (tempdir() / '2.txt').touch() f2 = (TempDir() / '2.txt').touch()
f1.write_text('foo') f1.write_text('foo')
f3 = (tempdir() / '3.txt').touch() f3 = (TempDir() / '3.txt').touch()
f1.write_text('bar') f1.write_text('bar')
f4 = (tempdir() / '4.txt') f4 = (TempDir() / '4.txt')
f1.copyfile(f4) f1.copyfile(f4)
assert os.path.samefile(f1, f2) == f1.samefile(f2) assert os.path.samefile(f1, f2) == f1.samefile(f2)
@ -680,6 +756,26 @@ class TestScratchDir:
self.fail("Calling `rmtree_p` on non-existent directory " self.fail("Calling `rmtree_p` on non-existent directory "
"should not raise an exception.") "should not raise an exception.")
def test_rmdir_p_exists(self, tmpdir):
"""
Invocation of rmdir_p on an existing directory should
remove the directory.
"""
d = Path(tmpdir)
sub = d / 'subfolder'
sub.mkdir()
sub.rmdir_p()
assert not sub.exists()
def test_rmdir_p_nonexistent(self, tmpdir):
"""
Calling rmdir_p on a non-existent directory should not raise an exception.
"""
d = Path(tmpdir)
sub = d / 'subfolder'
assert not sub.exists()
sub.rmdir_p()
class TestMergeTree: class TestMergeTree:
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@ -701,6 +797,11 @@ class TestMergeTree:
else: else:
self.test_file.copy(self.test_link) self.test_file.copy(self.test_link)
def check_link(self):
target = Path(self.subdir_b / self.test_link.name)
check = target.islink if hasattr(os, 'symlink') else target.isfile
assert check()
def test_with_nonexisting_dst_kwargs(self): def test_with_nonexisting_dst_kwargs(self):
self.subdir_a.merge_tree(self.subdir_b, symlinks=True) self.subdir_a.merge_tree(self.subdir_b, symlinks=True)
assert self.subdir_b.isdir() assert self.subdir_b.isdir()
@ -709,7 +810,7 @@ class TestMergeTree:
self.subdir_b / self.test_link.name, self.subdir_b / self.test_link.name,
)) ))
assert set(self.subdir_b.listdir()) == expected assert set(self.subdir_b.listdir()) == expected
assert Path(self.subdir_b / self.test_link.name).islink() self.check_link()
def test_with_nonexisting_dst_args(self): def test_with_nonexisting_dst_args(self):
self.subdir_a.merge_tree(self.subdir_b, True) self.subdir_a.merge_tree(self.subdir_b, True)
@ -719,7 +820,7 @@ class TestMergeTree:
self.subdir_b / self.test_link.name, self.subdir_b / self.test_link.name,
)) ))
assert set(self.subdir_b.listdir()) == expected assert set(self.subdir_b.listdir()) == expected
assert Path(self.subdir_b / self.test_link.name).islink() self.check_link()
def test_with_existing_dst(self): def test_with_existing_dst(self):
self.subdir_b.rmtree() self.subdir_b.rmtree()
@ -740,7 +841,7 @@ class TestMergeTree:
self.subdir_b / test_new.name, self.subdir_b / test_new.name,
)) ))
assert set(self.subdir_b.listdir()) == expected assert set(self.subdir_b.listdir()) == expected
assert Path(self.subdir_b / self.test_link.name).islink() self.check_link()
assert len(Path(self.subdir_b / self.test_file.name).bytes()) == 5000 assert len(Path(self.subdir_b / self.test_file.name).bytes()) == 5000
def test_copytree_parameters(self): def test_copytree_parameters(self):
@ -753,6 +854,20 @@ class TestMergeTree:
assert self.subdir_b.isdir() assert self.subdir_b.isdir()
assert self.subdir_b.listdir() == [self.subdir_b / self.test_file.name] assert self.subdir_b.listdir() == [self.subdir_b / self.test_file.name]
def test_only_newer(self):
"""
merge_tree should accept a copy_function in which only
newer files are copied and older files do not overwrite
newer copies in the dest.
"""
target = self.subdir_b / 'testfile.txt'
target.write_text('this is newer')
self.subdir_a.merge_tree(
self.subdir_b,
copy_function=path.only_newer(shutil.copy2),
)
assert target.text() == 'this is newer'
class TestChdir: class TestChdir:
def test_chdir_or_cd(self, tmpdir): def test_chdir_or_cd(self, tmpdir):
@ -781,17 +896,17 @@ class TestChdir:
class TestSubclass: class TestSubclass:
class PathSubclass(Path):
pass
def test_subclass_produces_same_class(self): def test_subclass_produces_same_class(self):
""" """
When operations are invoked on a subclass, they should produce another When operations are invoked on a subclass, they should produce another
instance of that subclass. instance of that subclass.
""" """
p = self.PathSubclass('/foo') class PathSubclass(Path):
pass
p = PathSubclass('/foo')
subdir = p / 'bar' subdir = p / 'bar'
assert isinstance(subdir, self.PathSubclass) assert isinstance(subdir, PathSubclass)
class TestTempDir: class TestTempDir:
@ -800,8 +915,8 @@ class TestTempDir:
""" """
One should be able to readily construct a temporary directory One should be able to readily construct a temporary directory
""" """
d = tempdir() d = TempDir()
assert isinstance(d, Path) assert isinstance(d, path.Path)
assert d.exists() assert d.exists()
assert d.isdir() assert d.isdir()
d.rmdir() d.rmdir()
@ -809,24 +924,24 @@ class TestTempDir:
def test_next_class(self): def test_next_class(self):
""" """
It should be possible to invoke operations on a tempdir and get It should be possible to invoke operations on a TempDir and get
Path classes. Path classes.
""" """
d = tempdir() d = TempDir()
sub = d / 'subdir' sub = d / 'subdir'
assert isinstance(sub, Path) assert isinstance(sub, path.Path)
d.rmdir() d.rmdir()
def test_context_manager(self): def test_context_manager(self):
""" """
One should be able to use a tempdir object as a context, which will One should be able to use a TempDir object as a context, which will
clean up the contents after. clean up the contents after.
""" """
d = tempdir() d = TempDir()
res = d.__enter__() res = d.__enter__()
assert res is d assert res == path.Path(d)
(d / 'somefile.txt').touch() (d / 'somefile.txt').touch()
assert not isinstance(d / 'somefile.txt', tempdir) assert not isinstance(d / 'somefile.txt', TempDir)
d.__exit__(None, None, None) d.__exit__(None, None, None)
assert not d.exists() assert not d.exists()
@ -834,10 +949,10 @@ class TestTempDir:
""" """
The context manager will not clean up if an exception occurs. The context manager will not clean up if an exception occurs.
""" """
d = tempdir() d = TempDir()
d.__enter__() d.__enter__()
(d / 'somefile.txt').touch() (d / 'somefile.txt').touch()
assert not isinstance(d / 'somefile.txt', tempdir) assert not isinstance(d / 'somefile.txt', TempDir)
d.__exit__(TypeError, TypeError('foo'), None) d.__exit__(TypeError, TypeError('foo'), None)
assert d.exists() assert d.exists()
@ -847,7 +962,7 @@ class TestTempDir:
provide a temporary directory that will be deleted afterward.
""" """
with tempdir() as d: with TempDir() as d:
assert d.isdir() assert d.isdir()
assert not d.isdir() assert not d.isdir()
@ -876,7 +991,8 @@ class TestPatternMatching:
assert p.fnmatch('FOO[ABC]AR') assert p.fnmatch('FOO[ABC]AR')
def test_fnmatch_custom_normcase(self): def test_fnmatch_custom_normcase(self):
normcase = lambda path: path.upper() def normcase(path):
return path.upper()
p = Path('FooBar') p = Path('FooBar')
assert p.fnmatch('foobar', normcase=normcase) assert p.fnmatch('foobar', normcase=normcase)
assert p.fnmatch('FOO[ABC]AR', normcase=normcase) assert p.fnmatch('FOO[ABC]AR', normcase=normcase)
@ -891,8 +1007,8 @@ class TestPatternMatching:
def test_listdir_patterns(self, tmpdir): def test_listdir_patterns(self, tmpdir):
p = Path(tmpdir) p = Path(tmpdir)
(p/'sub').mkdir() (p / 'sub').mkdir()
(p/'File').touch() (p / 'File').touch()
assert p.listdir('s*') == [p / 'sub'] assert p.listdir('s*') == [p / 'sub']
assert len(p.listdir('*')) == 2 assert len(p.listdir('*')) == 2
@ -903,14 +1019,14 @@ class TestPatternMatching:
""" """
always_unix = Path.using_module(posixpath) always_unix = Path.using_module(posixpath)
p = always_unix(tmpdir) p = always_unix(tmpdir)
(p/'sub').mkdir() (p / 'sub').mkdir()
(p/'File').touch() (p / 'File').touch()
assert p.listdir('S*') == [] assert p.listdir('S*') == []
always_win = Path.using_module(ntpath) always_win = Path.using_module(ntpath)
p = always_win(tmpdir) p = always_win(tmpdir)
assert p.listdir('S*') == [p/'sub'] assert p.listdir('S*') == [p / 'sub']
assert p.listdir('f*') == [p/'File'] assert p.listdir('f*') == [p / 'File']
def test_listdir_case_insensitive(self, tmpdir): def test_listdir_case_insensitive(self, tmpdir):
""" """
@ -918,27 +1034,30 @@ class TestPatternMatching:
used by that Path class. used by that Path class.
""" """
p = Path(tmpdir) p = Path(tmpdir)
(p/'sub').mkdir() (p / 'sub').mkdir()
(p/'File').touch() (p / 'File').touch()
assert p.listdir(ci('S*')) == [p/'sub'] assert p.listdir(matchers.CaseInsensitive('S*')) == [p / 'sub']
assert p.listdir(ci('f*')) == [p/'File'] assert p.listdir(matchers.CaseInsensitive('f*')) == [p / 'File']
assert p.files(ci('S*')) == [] assert p.files(matchers.CaseInsensitive('S*')) == []
assert p.dirs(ci('f*')) == [] assert p.dirs(matchers.CaseInsensitive('f*')) == []
def test_walk_case_insensitive(self, tmpdir): def test_walk_case_insensitive(self, tmpdir):
p = Path(tmpdir) p = Path(tmpdir)
(p/'sub1'/'foo').makedirs_p() (p / 'sub1' / 'foo').makedirs_p()
(p/'sub2'/'foo').makedirs_p() (p / 'sub2' / 'foo').makedirs_p()
(p/'sub1'/'foo'/'bar.Txt').touch() (p / 'sub1' / 'foo' / 'bar.Txt').touch()
(p/'sub2'/'foo'/'bar.TXT').touch() (p / 'sub2' / 'foo' / 'bar.TXT').touch()
(p/'sub2'/'foo'/'bar.txt.bz2').touch() (p / 'sub2' / 'foo' / 'bar.txt.bz2').touch()
files = list(p.walkfiles(ci('*.txt'))) files = list(p.walkfiles(matchers.CaseInsensitive('*.txt')))
assert len(files) == 2 assert len(files) == 2
assert p/'sub2'/'foo'/'bar.TXT' in files assert p / 'sub2' / 'foo' / 'bar.TXT' in files
assert p/'sub1'/'foo'/'bar.Txt' in files assert p / 'sub1' / 'foo' / 'bar.Txt' in files
@pytest.mark.skipif(sys.version_info < (2, 6),
reason="in_place requires io module in Python 2.6") @pytest.mark.skipif(
sys.version_info < (2, 6),
reason="in_place requires io module in Python 2.6",
)
class TestInPlace: class TestInPlace:
reference_content = textwrap.dedent(""" reference_content = textwrap.dedent("""
The quick brown fox jumped over the lazy dog. The quick brown fox jumped over the lazy dog.
@ -959,7 +1078,7 @@ class TestInPlace:
@classmethod @classmethod
def create_reference(cls, tmpdir): def create_reference(cls, tmpdir):
p = Path(tmpdir)/'document' p = Path(tmpdir) / 'document'
with p.open('w') as stream: with p.open('w') as stream:
stream.write(cls.reference_content) stream.write(cls.reference_content)
return p return p
@ -984,7 +1103,7 @@ class TestInPlace:
assert "some error" in str(exc) assert "some error" in str(exc)
with doc.open() as stream: with doc.open() as stream:
data = stream.read() data = stream.read()
assert not 'Lorem' in data assert 'Lorem' not in data
assert 'lazy dog' in data assert 'lazy dog' in data
@ -1023,8 +1142,9 @@ class TestSpecialPaths:
def test_unix_paths_fallback(self, tmpdir, monkeypatch, feign_linux): def test_unix_paths_fallback(self, tmpdir, monkeypatch, feign_linux):
"Without XDG_CONFIG_HOME set, ~/.config should be used." "Without XDG_CONFIG_HOME set, ~/.config should be used."
fake_home = tmpdir / '_home' fake_home = tmpdir / '_home'
monkeypatch.delitem(os.environ, 'XDG_CONFIG_HOME', raising=False)
monkeypatch.setitem(os.environ, 'HOME', str(fake_home)) monkeypatch.setitem(os.environ, 'HOME', str(fake_home))
expected = str(tmpdir / '_home' / '.config') expected = Path('~/.config').expanduser()
assert SpecialResolver(Path).user.config == expected assert SpecialResolver(Path).user.config == expected
def test_property(self): def test_property(self):
@ -1075,7 +1195,8 @@ class TestMultiPath:
cls = Multi.for_class(Path) cls = Multi.for_class(Path)
assert issubclass(cls, Path) assert issubclass(cls, Path)
assert issubclass(cls, Multi) assert issubclass(cls, Multi)
assert cls.__name__ == 'MultiPath' expected_name = 'Multi' + Path.__name__
assert cls.__name__ == expected_name
def test_detect_no_pathsep(self): def test_detect_no_pathsep(self):
""" """
@ -1115,5 +1236,23 @@ class TestMultiPath:
assert path == input assert path == input
if __name__ == '__main__': @pytest.mark.xfail('path.PY2', reason="Python 2 has no __future__")
pytest.main() def test_no_dependencies():
"""
Path.py guarantees that the path module can be
transplanted into an environment without any dependencies.
"""
cmd = [
sys.executable,
'-S',
'-c', 'import path',
]
subprocess.check_call(cmd)
def test_version():
"""
Under normal circumstances, path should present a
__version__.
"""
assert re.match(r'\d+\.\d+.*', path.__version__)