Merge branch 'nightly' into fix/unvendor

This commit is contained in:
Labrys of Knossos 2018-12-15 15:21:14 -05:00
commit 05c3de0f36
218 changed files with 40366 additions and 2132 deletions

2
.gitignore vendored
View file

@ -9,3 +9,5 @@
/userscripts/ /userscripts/
/logs/ /logs/
/.idea/ /.idea/
/libs*/*.dist-info
/libs*/*.egg-info

View file

@ -3,12 +3,12 @@
import datetime import datetime
import os import os
import sys import sys
import core
from libs.six import text_type import core
from core import logger, nzbToMediaDB from core import logger, nzbToMediaDB
from core.nzbToMediaUtil import convert_to_ascii, CharReplace, plex_update, replace_links
from core.nzbToMediaUserScript import external_script from core.nzbToMediaUserScript import external_script
from core.nzbToMediaUtil import CharReplace, convert_to_ascii, plex_update, replace_links
from libs.six import text_type
def processTorrent(inputDirectory, inputName, inputCategory, inputHash, inputID, clientAgent): def processTorrent(inputDirectory, inputName, inputCategory, inputHash, inputID, clientAgent):

View file

@ -5,10 +5,10 @@ from __future__ import print_function
import itertools import itertools
import locale import locale
import os import os
import platform
import re import re
import subprocess import subprocess
import sys import sys
import platform
import time import time
@ -31,23 +31,25 @@ CONFIG_TV_FILE = os.path.join(PROGRAM_DIR, 'autoProcessTv.cfg')
TEST_FILE = os.path.join(os.path.join(PROGRAM_DIR, 'tests'), 'test.mp4') TEST_FILE = os.path.join(os.path.join(PROGRAM_DIR, 'tests'), 'test.mp4')
MYAPP = None MYAPP = None
import six
from six.moves import reload_module from six.moves import reload_module
from core import logger, nzbToMediaDB, versionCheck
from core.autoProcess.autoProcessComics import autoProcessComics from core.autoProcess.autoProcessComics import autoProcessComics
from core.autoProcess.autoProcessGames import autoProcessGames from core.autoProcess.autoProcessGames import autoProcessGames
from core.autoProcess.autoProcessMovie import autoProcessMovie from core.autoProcess.autoProcessMovie import autoProcessMovie
from core.autoProcess.autoProcessMusic import autoProcessMusic from core.autoProcess.autoProcessMusic import autoProcessMusic
from core.autoProcess.autoProcessTV import autoProcessTV from core.autoProcess.autoProcessTV import autoProcessTV
from core import logger, versionCheck, nzbToMediaDB from core.databases import mainDB
from core.nzbToMediaConfig import config from core.nzbToMediaConfig import config
from core.nzbToMediaUtil import ( from core.nzbToMediaUtil import (
category_search, sanitizeName, copy_link, parse_args, flatten, getDirs, RunningProcess, WakeUp, category_search, cleanDir, cleanDir, copy_link,
rmReadOnly, rmDir, pause_torrent, resume_torrent, remove_torrent, listMediaFiles, create_torrent_class, extractFiles, flatten, getDirs, get_downloadInfo,
extractFiles, cleanDir, update_downloadInfoStatus, get_downloadInfo, WakeUp, makeDir, cleanDir, listMediaFiles, makeDir, parse_args, pause_torrent, remove_torrent,
create_torrent_class, listMediaFiles, RunningProcess, resume_torrent, rmDir, rmReadOnly, sanitizeName, update_downloadInfoStatus,
) )
from core.transcoder import transcoder from core.transcoder import transcoder
from core.databases import mainDB
# Client Agents # Client Agents
NZB_CLIENTS = ['sabnzbd', 'nzbget', 'manual'] NZB_CLIENTS = ['sabnzbd', 'nzbget', 'manual']
@ -269,6 +271,7 @@ def initialize(section=None):
if not SYS_ENCODING or SYS_ENCODING in ('ANSI_X3.4-1968', 'US-ASCII', 'ASCII'): if not SYS_ENCODING or SYS_ENCODING in ('ANSI_X3.4-1968', 'US-ASCII', 'ASCII'):
SYS_ENCODING = 'UTF-8' SYS_ENCODING = 'UTF-8'
if six.PY2:
if not hasattr(sys, "setdefaultencoding"): if not hasattr(sys, "setdefaultencoding"):
reload_module(sys) reload_module(sys)

View file

@ -1,11 +1,12 @@
# coding=utf-8 # coding=utf-8
import os import os
import core
import requests import requests
from core.nzbToMediaUtil import convert_to_ascii, remoteDir, server_responding import core
from core import logger from core import logger
from core.nzbToMediaUtil import convert_to_ascii, remoteDir, server_responding
requests.packages.urllib3.disable_warnings() requests.packages.urllib3.disable_warnings()

View file

@ -1,12 +1,13 @@
# coding=utf-8 # coding=utf-8
import os import os
import core
import requests
import shutil import shutil
from core.nzbToMediaUtil import convert_to_ascii, server_responding import requests
import core
from core import logger from core import logger
from core.nzbToMediaUtil import convert_to_ascii, server_responding
requests.packages.urllib3.disable_warnings() requests.packages.urllib3.disable_warnings()

View file

@ -1,14 +1,15 @@
# coding=utf-8 # coding=utf-8
import json
import os import os
import time import time
import requests
import json
import core
from core.nzbToMediaSceneExceptions import process_all_exceptions import requests
from core.nzbToMediaUtil import convert_to_ascii, rmDir, find_imdbid, find_download, listMediaFiles, remoteDir, import_subs, server_responding, reportNzb
import core
from core import logger from core import logger
from core.nzbToMediaSceneExceptions import process_all_exceptions
from core.nzbToMediaUtil import convert_to_ascii, find_download, find_imdbid, import_subs, listMediaFiles, remoteDir, reportNzb, rmDir, server_responding
from core.transcoder import transcoder from core.transcoder import transcoder
requests.packages.urllib3.disable_warnings() requests.packages.urllib3.disable_warnings()

View file

@ -1,14 +1,15 @@
# coding=utf-8 # coding=utf-8
import json
import os import os
import time import time
import requests
import core
import json
from core.nzbToMediaUtil import convert_to_ascii, rmDir, remoteDir, listMediaFiles, server_responding import requests
from core.nzbToMediaSceneExceptions import process_all_exceptions
import core
from core import logger from core import logger
from core.nzbToMediaSceneExceptions import process_all_exceptions
from core.nzbToMediaUtil import convert_to_ascii, listMediaFiles, remoteDir, rmDir, server_responding
requests.packages.urllib3.disable_warnings() requests.packages.urllib3.disable_warnings()

View file

@ -1,17 +1,18 @@
# coding=utf-8 # coding=utf-8
import copy import copy
import errno
import json
import os import os
import time import time
import errno
import requests
import json
import core
import requests
import core
from core import logger
from core.nzbToMediaAutoFork import autoFork from core.nzbToMediaAutoFork import autoFork
from core.nzbToMediaSceneExceptions import process_all_exceptions from core.nzbToMediaSceneExceptions import process_all_exceptions
from core.nzbToMediaUtil import convert_to_ascii, flatten, rmDir, listMediaFiles, remoteDir, import_subs, server_responding, reportNzb from core.nzbToMediaUtil import convert_to_ascii, flatten, import_subs, listMediaFiles, remoteDir, reportNzb, rmDir, server_responding
from core import logger
from core.transcoder import transcoder from core.transcoder import transcoder
requests.packages.urllib3.disable_warnings() requests.packages.urllib3.disable_warnings()

View file

@ -1,2 +1 @@
# coding=utf-8 # coding=utf-8
__all__ = ["mainDB"]

View file

@ -4,10 +4,11 @@ import os
import platform import platform
import shutil import shutil
import stat import stat
from time import sleep
import core
from subprocess import call, Popen
import subprocess import subprocess
from subprocess import Popen, call
from time import sleep
import core
def extract(filePath, outputDestination): def extract(filePath, outputDestination):

View file

@ -1,10 +1,10 @@
# coding=utf-8 # coding=utf-8
from __future__ import with_statement
import logging
import os import os
import sys import sys
import threading import threading
import logging
import core import core
# number of log files to keep # number of log files to keep

View file

@ -1,7 +1,6 @@
# coding=utf-8 # coding=utf-8
import requests import requests
from six import iteritems from six import iteritems
import core import core

View file

@ -1,15 +1,16 @@
# coding=utf-8 # coding=utf-8
from six import iteritems import copy
import os import os
import shutil import shutil
import copy
import core
from configobj import *
from core import logger
from itertools import chain from itertools import chain
import configobj
from six import iteritems
import core
from core import logger
class Section(configobj.Section, object): class Section(configobj.Section, object):
def isenabled(section): def isenabled(section):

View file

@ -1,6 +1,6 @@
# coding=utf-8 # coding=utf-8
from __future__ import print_function, with_statement from __future__ import print_function
import re import re
import sqlite3 import sqlite3

View file

@ -1,10 +1,12 @@
# coding=utf-8 # coding=utf-8
import os import os
import re
import core
import shlex
import platform import platform
import re
import shlex
import subprocess import subprocess
import core
from core import logger from core import logger
from core.nzbToMediaUtil import listMediaFiles from core.nzbToMediaUtil import listMediaFiles

View file

@ -1,10 +1,12 @@
# coding=utf-8 # coding=utf-8
import os import os
import core
from subprocess import Popen from subprocess import Popen
from core.transcoder import transcoder
from core.nzbToMediaUtil import import_subs, listMediaFiles, rmDir import core
from core import logger from core import logger
from core.nzbToMediaUtil import import_subs, listMediaFiles, rmDir
from core.transcoder import transcoder
def external_script(outputDestination, torrentName, torrentLabel, settings): def external_script(outputDestination, torrentName, torrentLabel, settings):

View file

@ -1314,11 +1314,18 @@ class RunningProcess(object):
class WindowsProcess(object): class WindowsProcess(object):
def __init__(self): def __init__(self):
self.mutex = None
self.mutexname = "nzbtomedia_{pid}".format(pid=core.PID_FILE.replace('\\', '/')) # {D0E858DF-985E-4907-B7FB-8D732C3FC3B9}" self.mutexname = "nzbtomedia_{pid}".format(pid=core.PID_FILE.replace('\\', '/')) # {D0E858DF-985E-4907-B7FB-8D732C3FC3B9}"
if platform.system() == 'Windows': if platform.system() == 'Windows':
try:
from win32.win32event import CreateMutex
from win32.win32api import CloseHandle, GetLastError
from win32.lib.winerror import ERROR_ALREADY_EXISTS
except ImportError:
from win32event import CreateMutex from win32event import CreateMutex
from win32api import CloseHandle, GetLastError from win32api import CloseHandle, GetLastError
from winerror import ERROR_ALREADY_EXISTS from winerror import ERROR_ALREADY_EXISTS
self.CreateMutex = CreateMutex self.CreateMutex = CreateMutex
self.CloseHandle = CloseHandle self.CloseHandle = CloseHandle
self.GetLastError = GetLastError self.GetLastError = GetLastError

View file

@ -1,17 +1,19 @@
# coding=utf-8 # coding=utf-8
from six import iteritems
import errno import errno
import json
import os import os
import platform import platform
import subprocess
import core
import json
import shutil
import re import re
import shutil
import subprocess
from babelfish import Language
from six import iteritems
import core
from core import logger from core import logger
from core.nzbToMediaUtil import makeDir from core.nzbToMediaUtil import makeDir
from babelfish import Language
def isVideoGood(videofile, status): def isVideoGood(videofile, status):

View file

@ -4,17 +4,16 @@
import os import os
import platform import platform
import shutil
import subprocess
import re import re
import urllib import shutil
import tarfile
import stat import stat
import subprocess
import tarfile
import traceback import traceback
import gh_api as github import urllib
import core import core
from core import logger from core import gh_api as github, logger
class CheckVersion(object): class CheckVersion(object):
@ -182,6 +181,7 @@ class GitUpdateManager(UpdateManager):
if output: if output:
output = output.strip() output = output.strip()
output = output.decode('utf-8')
if core.LOG_GIT: if core.LOG_GIT:
logger.log(u"git output: {output}".format(output=output), logger.DEBUG) logger.log(u"git output: {output}".format(output=output), logger.DEBUG)

Binary file not shown.

BIN
libs/bin/easy_install.exe Normal file

Binary file not shown.

BIN
libs/bin/enver.exe Normal file

Binary file not shown.

BIN
libs/bin/find-symlinks.exe Normal file

Binary file not shown.

BIN
libs/bin/gclip.exe Normal file

Binary file not shown.

BIN
libs/bin/mklink.exe Normal file

Binary file not shown.

BIN
libs/bin/pclip.exe Normal file

Binary file not shown.

BIN
libs/bin/xmouse.exe Normal file

Binary file not shown.

5
libs/easy_install.py Normal file
View file

@ -0,0 +1,5 @@
"""Run the EasyInstall command"""
if __name__ == '__main__':
from setuptools.command.easy_install import main
main()

View file

@ -0,0 +1,17 @@
from .api import distribution, Distribution, PackageNotFoundError # noqa: F401
from .api import metadata, entry_points, resolve, version, read_text
# Import for installation side-effects.
from . import _hooks # noqa: F401
__all__ = [
'metadata',
'entry_points',
'resolve',
'version',
'read_text',
]
__version__ = version(__name__)

View file

@ -0,0 +1,148 @@
from __future__ import unicode_literals, absolute_import
import re
import sys
import itertools
from .api import Distribution
from zipfile import ZipFile
if sys.version_info >= (3,): # pragma: nocover
from contextlib import suppress
from pathlib import Path
else: # pragma: nocover
from contextlib2 import suppress # noqa
from itertools import imap as map # type: ignore
from pathlib2 import Path
FileNotFoundError = IOError, OSError
__metaclass__ = type
def install(cls):
"""Class decorator for installation on sys.meta_path."""
sys.meta_path.append(cls)
return cls
class NullFinder:
@staticmethod
def find_spec(*args, **kwargs):
return None
# In Python 2, the import system requires finders
# to have a find_module() method, but this usage
# is deprecated in Python 3 in favor of find_spec().
# For the purposes of this finder (i.e. being present
# on sys.meta_path but having no other import
# system functionality), the two methods are identical.
find_module = find_spec
@install
class MetadataPathFinder(NullFinder):
"""A degenerate finder for distribution packages on the file system.
This finder supplies only a find_distribution() method for versions
of Python that do not have a PathFinder find_distribution().
"""
search_template = r'{name}(-.*)?\.(dist|egg)-info'
@classmethod
def find_distribution(cls, name):
paths = cls._search_paths(name)
dists = map(PathDistribution, paths)
return next(dists, None)
@classmethod
def _search_paths(cls, name):
"""
Find metadata directories in sys.path heuristically.
"""
return itertools.chain.from_iterable(
cls._search_path(path, name)
for path in map(Path, sys.path)
)
@classmethod
def _search_path(cls, root, name):
if not root.is_dir():
return ()
normalized = name.replace('-', '_')
return (
item
for item in root.iterdir()
if item.is_dir()
and re.match(
cls.search_template.format(name=normalized),
str(item.name),
flags=re.IGNORECASE,
)
)
class PathDistribution(Distribution):
def __init__(self, path):
"""Construct a distribution from a path to the metadata directory."""
self._path = path
def read_text(self, filename):
with suppress(FileNotFoundError):
with self._path.joinpath(filename).open(encoding='utf-8') as fp:
return fp.read()
return None
read_text.__doc__ = Distribution.read_text.__doc__
@install
class WheelMetadataFinder(NullFinder):
"""A degenerate finder for distribution packages in wheels.
This finder supplies only a find_distribution() method for versions
of Python that do not have a PathFinder find_distribution().
"""
search_template = r'{name}(-.*)?\.whl'
@classmethod
def find_distribution(cls, name):
paths = cls._search_paths(name)
dists = map(WheelDistribution, paths)
return next(dists, None)
@classmethod
def _search_paths(cls, name):
return (
item
for item in map(Path, sys.path)
if re.match(
cls.search_template.format(name=name),
str(item.name),
flags=re.IGNORECASE,
)
)
class WheelDistribution(Distribution):
def __init__(self, archive):
self._archive = archive
name, version = archive.name.split('-')[0:2]
self._dist_info = '{}-{}.dist-info'.format(name, version)
def read_text(self, filename):
with ZipFile(_path_to_filename(self._archive)) as zf:
with suppress(KeyError):
as_bytes = zf.read('{}/{}'.format(self._dist_info, filename))
return as_bytes.decode('utf-8')
return None
read_text.__doc__ = Distribution.read_text.__doc__
def _path_to_filename(path): # pragma: nocover
"""
On non-compliant systems, ensure a path-like object is
a string.
"""
try:
return path.__fspath__()
except AttributeError:
return str(path)

View file

@ -0,0 +1,146 @@
import io
import abc
import sys
import email
from importlib import import_module
if sys.version_info > (3,): # pragma: nocover
from configparser import ConfigParser
else: # pragma: nocover
from ConfigParser import SafeConfigParser as ConfigParser
try:
BaseClass = ModuleNotFoundError
except NameError: # pragma: nocover
BaseClass = ImportError # type: ignore
__metaclass__ = type
class PackageNotFoundError(BaseClass):
"""The package was not found."""
class Distribution:
"""A Python distribution package."""
@abc.abstractmethod
def read_text(self, filename):
"""Attempt to load metadata file given by the name.
:param filename: The name of the file in the distribution info.
:return: The text if found, otherwise None.
"""
@classmethod
def from_name(cls, name):
"""Return the Distribution for the given package name.
:param name: The name of the distribution package to search for.
:return: The Distribution instance (or subclass thereof) for the named
package, if found.
:raises PackageNotFoundError: When the named package's distribution
metadata cannot be found.
"""
for resolver in cls._discover_resolvers():
resolved = resolver(name)
if resolved is not None:
return resolved
else:
raise PackageNotFoundError(name)
@staticmethod
def _discover_resolvers():
"""Search the meta_path for resolvers."""
declared = (
getattr(finder, 'find_distribution', None)
for finder in sys.meta_path
)
return filter(None, declared)
@property
def metadata(self):
"""Return the parsed metadata for this Distribution.
The returned object will have keys that name the various bits of
metadata. See PEP 566 for details.
"""
return email.message_from_string(
self.read_text('METADATA') or self.read_text('PKG-INFO')
)
@property
def version(self):
"""Return the 'Version' metadata for the distribution package."""
return self.metadata['Version']
def distribution(package):
"""Get the ``Distribution`` instance for the given package.
:param package: The name of the package as a string.
:return: A ``Distribution`` instance (or subclass thereof).
"""
return Distribution.from_name(package)
def metadata(package):
"""Get the metadata for the package.
:param package: The name of the distribution package to query.
:return: An email.Message containing the parsed metadata.
"""
return Distribution.from_name(package).metadata
def version(package):
"""Get the version string for the named package.
:param package: The name of the distribution package to query.
:return: The version string for the package as defined in the package's
"Version" metadata key.
"""
return distribution(package).version
def entry_points(name):
"""Return the entry points for the named distribution package.
:param name: The name of the distribution package to query.
:return: A ConfigParser instance where the sections and keys are taken
from the entry_points.txt ini-style contents.
"""
as_string = read_text(name, 'entry_points.txt')
# 2018-09-10(barry): Should we provide any options here, or let the caller
# send options to the underlying ConfigParser? For now, YAGNI.
config = ConfigParser()
try:
config.read_string(as_string)
except AttributeError: # pragma: nocover
# Python 2 has no read_string
config.readfp(io.StringIO(as_string))
return config
def resolve(entry_point):
"""Resolve an entry point string into the named callable.
:param entry_point: An entry point string of the form
`path.to.module:callable`.
:return: The actual callable object `path.to.module.callable`
:raises ValueError: When `entry_point` doesn't have the proper format.
"""
path, colon, name = entry_point.rpartition(':')
if colon != ':':
raise ValueError('Not an entry point: {}'.format(entry_point))
module = import_module(path)
return getattr(module, name)
def read_text(package, filename):
"""
Read the text of the file in the distribution info directory.
"""
return distribution(package).read_text(filename)

View file

View file

@ -0,0 +1,57 @@
=========================
importlib_metadata NEWS
=========================
0.7 (2018-11-27)
================
* Fixed issue where packages with dashes in their names would
not be discovered. Closes #21.
* Distribution lookup is now case-insensitive. Closes #20.
* Wheel distributions can no longer be discovered by their module
name. Like Path distributions, they must be indicated by their
distribution package name.
0.6 (2018-10-07)
================
* Removed ``importlib_metadata.distribution`` function. Now
the public interface is primarily the utility functions exposed
in ``importlib_metadata.__all__``. Closes #14.
* Added two new utility functions ``read_text`` and
``metadata``.
0.5 (2018-09-18)
================
* Updated README and removed details about Distribution
class, now considered private. Closes #15.
* Added test suite support for Python 3.4+.
* Fixed SyntaxErrors on Python 3.4 and 3.5. !12
* Fixed errors on Windows joining Path elements. !15
0.4 (2018-09-14)
================
* Housekeeping.
0.3 (2018-09-14)
================
* Added usage documentation. Closes #8
* Add support for getting metadata from wheels on ``sys.path``. Closes #9
0.2 (2018-09-11)
================
* Added ``importlib_metadata.entry_points()``. Closes #1
* Added ``importlib_metadata.resolve()``. Closes #12
* Add support for Python 2.7. Closes #4
0.1 (2018-09-10)
================
* Initial release.
..
Local Variables:
mode: change-log-mode
indent-tabs-mode: nil
sentence-end-double-space: t
fill-column: 78
coding: utf-8
End:

View file

@ -0,0 +1,180 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# flake8: noqa
#
# importlib_metadata documentation build configuration file, created by
# sphinx-quickstart on Thu Nov 30 10:21:00 2017.
#
# This file is execfile()d with the current directory set to its
# containing dir.
#
# Note that not all possible configuration values are present in this
# autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))
# -- General configuration ------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ['sphinx.ext.autodoc',
'sphinx.ext.doctest',
'sphinx.ext.intersphinx',
'sphinx.ext.coverage',
'sphinx.ext.viewcode']
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
# The master toctree document.
master_doc = 'index'
# General information about the project.
project = 'importlib_metadata'
copyright = '2017-2018, Jason Coombs, Barry Warsaw'
author = 'Jason Coombs, Barry Warsaw'
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
version = '0.1'
# The full version, including alpha/beta/rc tags.
release = '0.1'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False
# -- Options for HTML output ----------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'default'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# This is required for the alabaster theme
# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
html_sidebars = {
'**': [
'relations.html', # needs 'show_related': True theme option to display
'searchbox.html',
]
}
# -- Options for HTMLHelp output ------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'importlib_metadatadoc'
# -- Options for LaTeX output ---------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'importlib_metadata.tex', 'importlib\\_metadata Documentation',
'Brett Cannon, Barry Warsaw', 'manual'),
]
# -- Options for manual page output ---------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'importlib_metadata', 'importlib_metadata Documentation',
[author], 1)
]
# -- Options for Texinfo output -------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'importlib_metadata', 'importlib_metadata Documentation',
author, 'importlib_metadata', 'One line description of project.',
'Miscellaneous'),
]
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {
'python': ('https://docs.python.org/3', None),
}

View file

@ -0,0 +1,53 @@
===============================
Welcome to importlib_metadata
===============================
``importlib_metadata`` is a library which provides an API for accessing an
installed package's `metadata`_, such as its entry points or its top-level
name. This functionality intends to replace most uses of ``pkg_resources``
`entry point API`_ and `metadata API`_. Along with ``importlib.resources`` in
`Python 3.7 and newer`_ (backported as `importlib_resources`_ for older
versions of Python), this can eliminate the need to use the older and less
efficient ``pkg_resources`` package.
``importlib_metadata`` is a backport of Python 3.8's standard library
`importlib.metadata`_ module for Python 2.7, and 3.4 through 3.7. Users of
Python 3.8 and beyond are encouraged to use the standard library module, and
in fact for these versions, ``importlib_metadata`` just shadows that module.
Developers looking for detailed API descriptions should refer to the Python
3.8 standard library documentation.
The documentation here includes a general :ref:`usage <using>` guide.
.. toctree::
:maxdepth: 2
:caption: Contents:
using.rst
changelog.rst
Project details
===============
* Project home: https://gitlab.com/python-devs/importlib_metadata
* Report bugs at: https://gitlab.com/python-devs/importlib_metadata/issues
* Code hosting: https://gitlab.com/python-devs/importlib_metadata.git
* Documentation: http://importlib_metadata.readthedocs.io/
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
.. _`metadata`: https://www.python.org/dev/peps/pep-0566/
.. _`entry point API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#entry-points
.. _`metadata API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#metadata-api
.. _`Python 3.7 and newer`: https://docs.python.org/3/library/importlib.html#module-importlib.resources
.. _`importlib_resources`: https://importlib-resources.readthedocs.io/en/latest/index.html
.. _`importlib.metadata`: TBD

View file

@ -0,0 +1,133 @@
.. _using:
==========================
Using importlib_metadata
==========================
``importlib_metadata`` is a library that provides for access to installed
package metadata. Built in part on Python's import system, this library
intends to replace similar functionality in ``pkg_resources`` `entry point
API`_ and `metadata API`_. Along with ``importlib.resources`` in `Python 3.7
and newer`_ (backported as `importlib_resources`_ for older versions of
Python), this can eliminate the need to use the older and less efficient
``pkg_resources`` package.
By "installed package" we generally mean a third party package installed into
Python's ``site-packages`` directory via tools such as ``pip``. Specifically,
it means a package with either a discoverable ``dist-info`` or ``egg-info``
directory, and metadata defined by `PEP 566`_ or its older specifications.
By default, package metadata can live on the file system or in wheels on
``sys.path``. Through an extension mechanism, the metadata can live almost
anywhere.
Overview
========
Let's say you wanted to get the version string for a package you've installed
using ``pip``. We start by creating a virtual environment and installing
something into it::
$ python3 -m venv example
$ source example/bin/activate
(example) $ pip install importlib_metadata
(example) $ pip install wheel
You can get the version string for ``wheel`` by running the following::
(example) $ python
>>> from importlib_metadata import version
>>> version('wheel')
'0.31.1'
You can also get the set of entry points for the ``wheel`` package. Since the
``entry_points.txt`` file is an ``.ini``-style, the ``entry_points()``
function returns a `ConfigParser instance`_. To get the list of command line
entry points, extract the ``console_scripts`` section::
>>> cp = entry_points('wheel')
>>> cp.options('console_scripts')
['wheel']
You can also get the callable that the entry point is mapped to::
>>> cp.get('console_scripts', 'wheel')
'wheel.tool:main'
Even more conveniently, you can resolve this entry point to the actual
callable::
>>> from importlib_metadata import resolve
>>> ep = cp.get('console_scripts', 'wheel')
>>> resolve(ep)
<function main at 0x111b91bf8>
Distributions
=============
While the above API is the most common and convenient usage, you can get all
of that information from the ``Distribution`` class. A ``Distribution`` is an
abstract object that represents the metadata for a Python package. You can
get the ``Distribution`` instance::
>>> from importlib_metadata import distribution
>>> dist = distribution('wheel')
Thus, an alternative way to get the version number is through the
``Distribution`` instance::
>>> dist.version
'0.31.1'
There are all kinds of additional metadata available on the ``Distribution``
instance::
>>> d.metadata['Requires-Python']
'>=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*'
>>> d.metadata['License']
'MIT'
The full set of available metadata is not described here. See PEP 566 for
additional details.
Extending the search algorithm
==============================
Because package metadata is not available through ``sys.path`` searches, or
package loaders directly, the metadata for a package is found through import
system `finders`_. To find a distribution package's metadata,
``importlib_metadata`` queries the list of `meta path finders`_ on
`sys.meta_path`_.
By default ``importlib_metadata`` installs a finder for packages found on the
file system. This finder doesn't actually find any *packages*, but it cany
find the package's metadata.
The abstract class :py:class:`importlib.abc.MetaPathFinder` defines the
interface expected of finders by Python's import system.
``importlib_metadata`` extends this protocol by looking for an optional
``find_distribution()`` ``@classmethod`` on the finders from
``sys.meta_path``. If the finder has this method, it takes a single argument
which is the name of the distribution package to find. The method returns
``None`` if it cannot find the distribution package, otherwise it returns an
instance of the ``Distribution`` abstract class.
What this means in practice is that to support finding distribution package
metadata in locations other than the file system, you should derive from
``Distribution`` and implement the ``load_metadata()`` method. This takes a
single argument which is the name of the package whose metadata is being
found. This instance of the ``Distribution`` base abstract class is what your
finder's ``find_distribution()`` method should return.
.. _`entry point API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#entry-points
.. _`metadata API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#metadata-api
.. _`Python 3.7 and newer`: https://docs.python.org/3/library/importlib.html#module-importlib.resources
.. _`importlib_resources`: https://importlib-resources.readthedocs.io/en/latest/index.html
.. _`PEP 566`: https://www.python.org/dev/peps/pep-0566/
.. _`ConfigParser instance`: https://docs.python.org/3/library/configparser.html#configparser.ConfigParser
.. _`finders`: https://docs.python.org/3/reference/import.html#finders-and-loaders
.. _`meta path finders`: https://docs.python.org/3/glossary.html#term-meta-path-finder
.. _`sys.meta_path`: https://docs.python.org/3/library/sys.html#sys.meta_path

View file

@ -0,0 +1,44 @@
import re
import unittest
import importlib_metadata
class APITests(unittest.TestCase):
version_pattern = r'\d+\.\d+(\.\d)?'
def test_retrieves_version_of_self(self):
version = importlib_metadata.version('importlib_metadata')
assert isinstance(version, str)
assert re.match(self.version_pattern, version)
def test_retrieves_version_of_pip(self):
# Assume pip is installed and retrieve the version of pip.
version = importlib_metadata.version('pip')
assert isinstance(version, str)
assert re.match(self.version_pattern, version)
def test_for_name_does_not_exist(self):
with self.assertRaises(importlib_metadata.PackageNotFoundError):
importlib_metadata.distribution('does-not-exist')
def test_for_top_level(self):
distribution = importlib_metadata.distribution('importlib_metadata')
self.assertEqual(
distribution.read_text('top_level.txt').strip(),
'importlib_metadata')
def test_entry_points(self):
parser = importlib_metadata.entry_points('pip')
# We should probably not be dependent on a third party package's
# internal API staying stable.
entry_point = parser.get('console_scripts', 'pip')
self.assertEqual(entry_point, 'pip._internal:main')
def test_metadata_for_this_package(self):
md = importlib_metadata.metadata('importlib_metadata')
assert md['author'] == 'Barry Warsaw'
assert md['LICENSE'] == 'Apache Software License'
assert md['Name'] == 'importlib-metadata'
classifiers = md.get_all('Classifier')
assert 'Topic :: Software Development :: Libraries' in classifiers

View file

@ -0,0 +1,121 @@
from __future__ import unicode_literals
import re
import sys
import shutil
import tempfile
import unittest
import importlib
import contextlib
import importlib_metadata
try:
from contextlib import ExitStack
except ImportError:
from contextlib2 import ExitStack
try:
import pathlib
except ImportError:
import pathlib2 as pathlib
from importlib_metadata import _hooks
class BasicTests(unittest.TestCase):
version_pattern = r'\d+\.\d+(\.\d)?'
def test_retrieves_version_of_pip(self):
# Assume pip is installed and retrieve the version of pip.
dist = importlib_metadata.Distribution.from_name('pip')
assert isinstance(dist.version, str)
assert re.match(self.version_pattern, dist.version)
def test_for_name_does_not_exist(self):
with self.assertRaises(importlib_metadata.PackageNotFoundError):
importlib_metadata.Distribution.from_name('does-not-exist')
def test_new_style_classes(self):
self.assertIsInstance(importlib_metadata.Distribution, type)
self.assertIsInstance(_hooks.MetadataPathFinder, type)
self.assertIsInstance(_hooks.WheelMetadataFinder, type)
self.assertIsInstance(_hooks.WheelDistribution, type)
class ImportTests(unittest.TestCase):
def test_import_nonexistent_module(self):
# Ensure that the MetadataPathFinder does not crash an import of a
# non-existant module.
with self.assertRaises(ImportError):
importlib.import_module('does_not_exist')
def test_resolve(self):
entry_points = importlib_metadata.entry_points('pip')
main = importlib_metadata.resolve(
entry_points.get('console_scripts', 'pip'))
import pip._internal
self.assertEqual(main, pip._internal.main)
def test_resolve_invalid(self):
self.assertRaises(ValueError, importlib_metadata.resolve, 'bogus.ep')
class NameNormalizationTests(unittest.TestCase):
@staticmethod
def pkg_with_dashes(site_dir):
"""
Create minimal metadata for a package with dashes
in the name (and thus underscores in the filename).
"""
metadata_dir = site_dir / 'my_pkg.dist-info'
metadata_dir.mkdir()
metadata = metadata_dir / 'METADATA'
with metadata.open('w') as strm:
strm.write('Version: 1.0\n')
return 'my-pkg'
@staticmethod
@contextlib.contextmanager
def site_dir():
tmpdir = tempfile.mkdtemp()
sys.path[:0] = [tmpdir]
try:
yield pathlib.Path(tmpdir)
finally:
sys.path.remove(tmpdir)
shutil.rmtree(tmpdir)
def setUp(self):
self.fixtures = ExitStack()
self.addCleanup(self.fixtures.close)
self.site_dir = self.fixtures.enter_context(self.site_dir())
def test_dashes_in_dist_name_found_as_underscores(self):
"""
For a package with a dash in the name, the dist-info metadata
uses underscores in the name. Ensure the metadata loads.
"""
pkg_name = self.pkg_with_dashes(self.site_dir)
assert importlib_metadata.version(pkg_name) == '1.0'
@staticmethod
def pkg_with_mixed_case(site_dir):
"""
Create minimal metadata for a package with mixed case
in the name.
"""
metadata_dir = site_dir / 'CherryPy.dist-info'
metadata_dir.mkdir()
metadata = metadata_dir / 'METADATA'
with metadata.open('w') as strm:
strm.write('Version: 1.0\n')
return 'CherryPy'
def test_dist_name_found_as_any_case(self):
"""
Ensure the metadata loads when queried with any case.
"""
pkg_name = self.pkg_with_mixed_case(self.site_dir)
assert importlib_metadata.version(pkg_name) == '1.0'
assert importlib_metadata.version(pkg_name.lower()) == '1.0'
assert importlib_metadata.version(pkg_name.upper()) == '1.0'

View file

@ -0,0 +1,42 @@
import sys
import unittest
import importlib_metadata
try:
from contextlib import ExitStack
except ImportError:
from contextlib2 import ExitStack
from importlib_resources import path
class BespokeLoader:
archive = 'bespoke'
class TestZip(unittest.TestCase):
def setUp(self):
# Find the path to the example.*.whl so we can add it to the front of
# sys.path, where we'll then try to find the metadata thereof.
self.resources = ExitStack()
self.addCleanup(self.resources.close)
wheel = self.resources.enter_context(
path('importlib_metadata.tests.data',
'example-21.12-py3-none-any.whl'))
sys.path.insert(0, str(wheel))
self.resources.callback(sys.path.pop, 0)
def test_zip_version(self):
self.assertEqual(importlib_metadata.version('example'), '21.12')
def test_zip_entry_points(self):
parser = importlib_metadata.entry_points('example')
entry_point = parser.get('console_scripts', 'example')
self.assertEqual(entry_point, 'example:main')
def test_missing_metadata(self):
distribution = importlib_metadata.distribution('example')
self.assertIsNone(distribution.read_text('does not exist'))
def test_case_insensitive(self):
self.assertEqual(importlib_metadata.version('Example'), '21.12')

View file

@ -0,0 +1 @@
0.7

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -1 +0,0 @@
import sys, types, os;p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));ie = os.path.exists(os.path.join(p,'__init__.py'));m = not ie and sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -1 +0,0 @@
import sys, types, os;p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));ie = os.path.exists(os.path.join(p,'__init__.py'));m = not ie and sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -1 +0,0 @@
import sys, types, os;p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));ie = os.path.exists(os.path.join(p,'__init__.py'));m = not ie and sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -1 +0,0 @@
import sys, types, os;p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));ie = os.path.exists(os.path.join(p,'__init__.py'));m = not ie and sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -0,0 +1 @@
import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p)

View file

@ -1 +0,0 @@
__import__("pkg_resources").declare_namespace(__name__)

View file

@ -5,6 +5,7 @@ of an object and its parent classes.
from __future__ import unicode_literals from __future__ import unicode_literals
def all_bases(c): def all_bases(c):
""" """
return a tuple of all base classes the class c has as a parent. return a tuple of all base classes the class c has as a parent.
@ -13,6 +14,7 @@ def all_bases(c):
""" """
return c.mro()[1:] return c.mro()[1:]
def all_classes(c): def all_classes(c):
""" """
return a tuple of all classes to which c belongs return a tuple of all classes to which c belongs
@ -21,7 +23,10 @@ def all_classes(c):
""" """
return c.mro() return c.mro()
# borrowed from http://code.activestate.com/recipes/576949-find-all-subclasses-of-a-given-class/ # borrowed from
# http://code.activestate.com/recipes/576949-find-all-subclasses-of-a-given-class/
def iter_subclasses(cls, _seen=None): def iter_subclasses(cls, _seen=None):
""" """
Generator over all subclasses of a given class, in depth-first order. Generator over all subclasses of a given class, in depth-first order.
@ -51,9 +56,12 @@ def iter_subclasses(cls, _seen=None):
""" """
if not isinstance(cls, type): if not isinstance(cls, type):
raise TypeError('iter_subclasses must be called with ' raise TypeError(
'new-style classes, not %.100r' % cls) 'iter_subclasses must be called with '
if _seen is None: _seen = set() 'new-style classes, not %.100r' % cls
)
if _seen is None:
_seen = set()
try: try:
subs = cls.__subclasses__() subs = cls.__subclasses__()
except TypeError: # fails only when cls is type except TypeError: # fails only when cls is type

View file

@ -6,6 +6,7 @@ Some useful metaclasses.
from __future__ import unicode_literals from __future__ import unicode_literals
class LeafClassesMeta(type): class LeafClassesMeta(type):
""" """
A metaclass for classes that keeps track of all of them that A metaclass for classes that keeps track of all of them that

View file

@ -2,8 +2,10 @@ from __future__ import unicode_literals
import six import six
__metaclass__ = type
class NonDataProperty(object):
class NonDataProperty:
"""Much like the property builtin, but only implements __get__, """Much like the property builtin, but only implements __get__,
making it a non-data property, and can be subsequently reset. making it a non-data property, and can be subsequently reset.
@ -34,7 +36,7 @@ class NonDataProperty(object):
# from http://stackoverflow.com/a/5191224 # from http://stackoverflow.com/a/5191224
class ClassPropertyDescriptor(object): class ClassPropertyDescriptor:
def __init__(self, fget, fset=None): def __init__(self, fget, fset=None):
self.fget = fget self.fget = fget

View file

@ -7,12 +7,63 @@ import operator
import collections import collections
import itertools import itertools
import copy import copy
import functools
try:
import collections.abc
except ImportError:
# Python 2.7
collections.abc = collections
import six import six
from jaraco.classes.properties import NonDataProperty from jaraco.classes.properties import NonDataProperty
import jaraco.text import jaraco.text
class Projection(collections.abc.Mapping):
"""
Project a set of keys over a mapping
>>> sample = {'a': 1, 'b': 2, 'c': 3}
>>> prj = Projection(['a', 'c', 'd'], sample)
>>> prj == {'a': 1, 'c': 3}
True
Keys should only appear if they were specified and exist in the space.
>>> sorted(list(prj.keys()))
['a', 'c']
Use the projection to update another dict.
>>> target = {'a': 2, 'b': 2}
>>> target.update(prj)
>>> target == {'a': 1, 'b': 2, 'c': 3}
True
Also note that Projection keeps a reference to the original dict, so
if you modify the original dict, that could modify the Projection.
>>> del sample['a']
>>> dict(prj)
{'c': 3}
"""
def __init__(self, keys, space):
self._keys = tuple(keys)
self._space = space
def __getitem__(self, key):
if key not in self._keys:
raise KeyError(key)
return self._space[key]
def __iter__(self):
return iter(set(self._keys).intersection(self._space))
def __len__(self):
return len(tuple(iter(self)))
class DictFilter(object): class DictFilter(object):
""" """
Takes a dict, and simulates a sub-dict based on the keys. Takes a dict, and simulates a sub-dict based on the keys.
@ -52,7 +103,6 @@ class DictFilter(object):
self.pattern_keys = set() self.pattern_keys = set()
def get_pattern_keys(self): def get_pattern_keys(self):
#key_matches = lambda k, v: self.include_pattern.match(k)
keys = filter(self.include_pattern.match, self.dict.keys()) keys = filter(self.include_pattern.match, self.dict.keys())
return set(keys) return set(keys)
pattern_keys = NonDataProperty(get_pattern_keys) pattern_keys = NonDataProperty(get_pattern_keys)
@ -70,7 +120,7 @@ class DictFilter(object):
return values return values
def __getitem__(self, i): def __getitem__(self, i):
if not i in self.include_keys: if i not in self.include_keys:
return KeyError, i return KeyError, i
return self.dict[i] return self.dict[i]
@ -190,7 +240,7 @@ class RangeMap(dict):
return default return default
def _find_first_match_(self, keys, item): def _find_first_match_(self, keys, item):
is_match = lambda k: self.match(item, k) is_match = functools.partial(self.match, item)
matches = list(filter(is_match, keys)) matches = list(filter(is_match, keys))
if matches: if matches:
return matches[0] return matches[0]
@ -205,12 +255,15 @@ class RangeMap(dict):
# some special values for the RangeMap # some special values for the RangeMap
undefined_value = type(str('RangeValueUndefined'), (object,), {})() undefined_value = type(str('RangeValueUndefined'), (object,), {})()
class Item(int): pass
class Item(int):
"RangeMap Item"
first_item = Item(0) first_item = Item(0)
last_item = Item(-1) last_item = Item(-1)
__identity = lambda x: x def __identity(x):
return x
def sorted_items(d, key=__identity, reverse=False): def sorted_items(d, key=__identity, reverse=False):
@ -229,7 +282,8 @@ def sorted_items(d, key=__identity, reverse=False):
(('foo', 20), ('baz', 10), ('bar', 42)) (('foo', 20), ('baz', 10), ('bar', 42))
""" """
# wrap the key func so it operates on the first element of each item # wrap the key func so it operates on the first element of each item
pairkey_key = lambda item: key(item[0]) def pairkey_key(item):
return key(item[0])
return sorted(d.items(), key=pairkey_key, reverse=reverse) return sorted(d.items(), key=pairkey_key, reverse=reverse)
@ -414,7 +468,11 @@ class ItemsAsAttributes(object):
It also works on dicts that customize __getitem__ It also works on dicts that customize __getitem__
>>> missing_func = lambda self, key: 'missing item' >>> missing_func = lambda self, key: 'missing item'
>>> C = type(str('C'), (dict, ItemsAsAttributes), dict(__missing__ = missing_func)) >>> C = type(
... str('C'),
... (dict, ItemsAsAttributes),
... dict(__missing__ = missing_func),
... )
>>> i = C() >>> i = C()
>>> i.missing >>> i.missing
'missing item' 'missing item'
@ -428,6 +486,7 @@ class ItemsAsAttributes(object):
# attempt to get the value from the mapping (return self[key]) # attempt to get the value from the mapping (return self[key])
# but be careful not to lose the original exception context. # but be careful not to lose the original exception context.
noval = object() noval = object()
def _safe_getitem(cont, key, missing_result): def _safe_getitem(cont, key, missing_result):
try: try:
return cont[key] return cont[key]
@ -483,7 +542,7 @@ class IdentityOverrideMap(dict):
return key return key
class DictStack(list, collections.Mapping): class DictStack(list, collections.abc.Mapping):
""" """
A stack of dictionaries that behaves as a view on those dictionaries, A stack of dictionaries that behaves as a view on those dictionaries,
giving preference to the last. giving preference to the last.
@ -506,6 +565,7 @@ class DictStack(list, collections.Mapping):
>>> d = stack.pop() >>> d = stack.pop()
>>> stack['a'] >>> stack['a']
1 1
>>> stack.get('b', None)
""" """
def keys(self): def keys(self):
@ -513,7 +573,8 @@ class DictStack(list, collections.Mapping):
def __getitem__(self, key): def __getitem__(self, key):
for scope in reversed(self): for scope in reversed(self):
if key in scope: return scope[key] if key in scope:
return scope[key]
raise KeyError(key) raise KeyError(key)
push = list.append push = list.append
@ -553,6 +614,10 @@ class BijectiveMap(dict):
Traceback (most recent call last): Traceback (most recent call last):
ValueError: Key/Value pairs may not overlap ValueError: Key/Value pairs may not overlap
>>> m['e'] = 'd'
Traceback (most recent call last):
ValueError: Key/Value pairs may not overlap
>>> print(m.pop('d')) >>> print(m.pop('d'))
c c
@ -583,7 +648,12 @@ class BijectiveMap(dict):
def __setitem__(self, item, value): def __setitem__(self, item, value):
if item == value: if item == value:
raise ValueError("Key cannot map to itself") raise ValueError("Key cannot map to itself")
if (value in self or item in self) and self[item] != value: overlap = (
item in self and self[item] != value
or
value in self and self[value] != item
)
if overlap:
raise ValueError("Key/Value pairs may not overlap") raise ValueError("Key/Value pairs may not overlap")
super(BijectiveMap, self).__setitem__(item, value) super(BijectiveMap, self).__setitem__(item, value)
super(BijectiveMap, self).__setitem__(value, item) super(BijectiveMap, self).__setitem__(value, item)
@ -607,7 +677,7 @@ class BijectiveMap(dict):
self.__setitem__(*item) self.__setitem__(*item)
class FrozenDict(collections.Mapping, collections.Hashable): class FrozenDict(collections.abc.Mapping, collections.abc.Hashable):
""" """
An immutable mapping. An immutable mapping.
@ -641,8 +711,8 @@ class FrozenDict(collections.Mapping, collections.Hashable):
>>> isinstance(copy.copy(a), FrozenDict) >>> isinstance(copy.copy(a), FrozenDict)
True True
FrozenDict supplies .copy(), even though collections.Mapping doesn't FrozenDict supplies .copy(), even though
demand it. collections.abc.Mapping doesn't demand it.
>>> a.copy() == a >>> a.copy() == a
True True
@ -747,6 +817,9 @@ class Everything(object):
>>> import random >>> import random
>>> random.randint(1, 999) in Everything() >>> random.randint(1, 999) in Everything()
True True
>>> random.choice([None, 'foo', 42, ('a', 'b', 'c')]) in Everything()
True
""" """
def __contains__(self, other): def __contains__(self, other):
return True return True
@ -771,3 +844,63 @@ class InstrumentedDict(six.moves.UserDict):
def __init__(self, data): def __init__(self, data):
six.moves.UserDict.__init__(self) six.moves.UserDict.__init__(self)
self.data = data self.data = data
class Least(object):
"""
A value that is always lesser than any other
>>> least = Least()
>>> 3 < least
False
>>> 3 > least
True
>>> least < 3
True
>>> least <= 3
True
>>> least > 3
False
>>> 'x' > least
True
>>> None > least
True
"""
def __le__(self, other):
return True
__lt__ = __le__
def __ge__(self, other):
return False
__gt__ = __ge__
class Greatest(object):
"""
A value that is always greater than any other
>>> greatest = Greatest()
>>> 3 < greatest
True
>>> 3 > greatest
False
>>> greatest < 3
False
>>> greatest > 3
True
>>> greatest >= 3
True
>>> 'x' > greatest
False
>>> None > greatest
False
"""
def __ge__(self, other):
return True
__gt__ = __ge__
def __le__(self, other):
return False
__lt__ = __le__

View file

@ -1,8 +1,16 @@
from __future__ import absolute_import, unicode_literals, print_function, division from __future__ import (
absolute_import, unicode_literals, print_function, division,
)
import functools import functools
import time import time
import warnings import warnings
import inspect
import collections
from itertools import count
__metaclass__ = type
try: try:
from functools import lru_cache from functools import lru_cache
@ -16,13 +24,17 @@ except ImportError:
warnings.warn("No lru_cache available") warnings.warn("No lru_cache available")
import more_itertools.recipes
def compose(*funcs): def compose(*funcs):
""" """
Compose any number of unary functions into a single unary function. Compose any number of unary functions into a single unary function.
>>> import textwrap >>> import textwrap
>>> from six import text_type >>> from six import text_type
>>> text_type.strip(textwrap.dedent(compose.__doc__)) == compose(text_type.strip, textwrap.dedent)(compose.__doc__) >>> stripped = text_type.strip(textwrap.dedent(compose.__doc__))
>>> compose(text_type.strip, textwrap.dedent)(compose.__doc__) == stripped
True True
Compose also allows the innermost function to take arbitrary arguments. Compose also allows the innermost function to take arbitrary arguments.
@ -33,7 +45,8 @@ def compose(*funcs):
[1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7] [1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
""" """
compose_two = lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs)) def compose_two(f1, f2):
return lambda *args, **kwargs: f1(f2(*args, **kwargs))
return functools.reduce(compose_two, funcs) return functools.reduce(compose_two, funcs)
@ -60,19 +73,36 @@ def once(func):
This decorator can ensure that an expensive or non-idempotent function This decorator can ensure that an expensive or non-idempotent function
will not be expensive on subsequent calls and is idempotent. will not be expensive on subsequent calls and is idempotent.
>>> func = once(lambda a: a+3) >>> add_three = once(lambda a: a+3)
>>> func(3) >>> add_three(3)
6 6
>>> func(9) >>> add_three(9)
6 6
>>> func('12') >>> add_three('12')
6 6
To reset the stored value, simply clear the property ``saved_result``.
>>> del add_three.saved_result
>>> add_three(9)
12
>>> add_three(8)
12
Or invoke 'reset()' on it.
>>> add_three.reset()
>>> add_three(-3)
0
>>> add_three(0)
0
""" """
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if not hasattr(func, 'always_returns'): if not hasattr(wrapper, 'saved_result'):
func.always_returns = func(*args, **kwargs) wrapper.saved_result = func(*args, **kwargs)
return func.always_returns return wrapper.saved_result
wrapper.reset = lambda: vars(wrapper).__delitem__('saved_result')
return wrapper return wrapper
@ -131,17 +161,22 @@ def method_cache(method, cache_wrapper=None):
>>> a.method2() >>> a.method2()
3 3
Caution - do not subsequently wrap the method with another decorator, such
as ``@property``, which changes the semantics of the function.
See also See also
http://code.activestate.com/recipes/577452-a-memoize-decorator-for-instance-methods/ http://code.activestate.com/recipes/577452-a-memoize-decorator-for-instance-methods/
for another implementation and additional justification. for another implementation and additional justification.
""" """
cache_wrapper = cache_wrapper or lru_cache() cache_wrapper = cache_wrapper or lru_cache()
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
# it's the first call, replace the method with a cached, bound method # it's the first call, replace the method with a cached, bound method
bound_method = functools.partial(method, self) bound_method = functools.partial(method, self)
cached_method = cache_wrapper(bound_method) cached_method = cache_wrapper(bound_method)
setattr(self, method.__name__, cached_method) setattr(self, method.__name__, cached_method)
return cached_method(*args, **kwargs) return cached_method(*args, **kwargs)
return _special_method_cache(method, cache_wrapper) or wrapper return _special_method_cache(method, cache_wrapper) or wrapper
@ -191,6 +226,29 @@ def apply(transform):
return wrap return wrap
def result_invoke(action):
r"""
Decorate a function with an action function that is
invoked on the results returned from the decorated
function (for its side-effect), then return the original
result.
>>> @result_invoke(print)
... def add_two(a, b):
... return a + b
>>> x = add_two(2, 3)
5
"""
def wrap(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
action(result)
return result
return wrapper
return wrap
def call_aside(f, *args, **kwargs): def call_aside(f, *args, **kwargs):
""" """
Call a function for its side effect after initialization. Call a function for its side effect after initialization.
@ -211,7 +269,7 @@ def call_aside(f, *args, **kwargs):
return f return f
class Throttler(object): class Throttler:
""" """
Rate-limit a function (or other callable) Rate-limit a function (or other callable)
""" """
@ -259,10 +317,143 @@ def retry_call(func, cleanup=lambda: None, retries=0, trap=()):
exception. On the final attempt, allow any exceptions exception. On the final attempt, allow any exceptions
to propagate. to propagate.
""" """
for attempt in range(retries): attempts = count() if retries == float('inf') else range(retries)
for attempt in attempts:
try: try:
return func() return func()
except trap: except trap:
cleanup() cleanup()
return func() return func()
def retry(*r_args, **r_kwargs):
"""
Decorator wrapper for retry_call. Accepts arguments to retry_call
except func and then returns a decorator for the decorated function.
Ex:
>>> @retry(retries=3)
... def my_func(a, b):
... "this is my funk"
... print(a, b)
>>> my_func.__doc__
'this is my funk'
"""
def decorate(func):
@functools.wraps(func)
def wrapper(*f_args, **f_kwargs):
bound = functools.partial(func, *f_args, **f_kwargs)
return retry_call(bound, *r_args, **r_kwargs)
return wrapper
return decorate
def print_yielded(func):
"""
Convert a generator into a function that prints all yielded elements
>>> @print_yielded
... def x():
... yield 3; yield None
>>> x()
3
None
"""
print_all = functools.partial(map, print)
print_results = compose(more_itertools.recipes.consume, print_all, func)
return functools.wraps(func)(print_results)
def pass_none(func):
"""
Wrap func so it's not called if its first param is None
>>> print_text = pass_none(print)
>>> print_text('text')
text
>>> print_text(None)
"""
@functools.wraps(func)
def wrapper(param, *args, **kwargs):
if param is not None:
return func(param, *args, **kwargs)
return wrapper
def assign_params(func, namespace):
"""
Assign parameters from namespace where func solicits.
>>> def func(x, y=3):
... print(x, y)
>>> assigned = assign_params(func, dict(x=2, z=4))
>>> assigned()
2 3
The usual errors are raised if a function doesn't receive
its required parameters:
>>> assigned = assign_params(func, dict(y=3, z=4))
>>> assigned()
Traceback (most recent call last):
TypeError: func() ...argument...
"""
try:
sig = inspect.signature(func)
params = sig.parameters.keys()
except AttributeError:
spec = inspect.getargspec(func)
params = spec.args
call_ns = {
k: namespace[k]
for k in params
if k in namespace
}
return functools.partial(func, **call_ns)
def save_method_args(method):
"""
Wrap a method such that when it is called, the args and kwargs are
saved on the method.
>>> class MyClass:
... @save_method_args
... def method(self, a, b):
... print(a, b)
>>> my_ob = MyClass()
>>> my_ob.method(1, 2)
1 2
>>> my_ob._saved_method.args
(1, 2)
>>> my_ob._saved_method.kwargs
{}
>>> my_ob.method(a=3, b='foo')
3 foo
>>> my_ob._saved_method.args
()
>>> my_ob._saved_method.kwargs == dict(a=3, b='foo')
True
The arguments are stored on the instance, allowing for
different instance to save different args.
>>> your_ob = MyClass()
>>> your_ob.method({str('x'): 3}, b=[4])
{'x': 3} [4]
>>> your_ob._saved_method.args
({'x': 3},)
>>> my_ob._saved_method.args
()
"""
args_and_kwargs = collections.namedtuple('args_and_kwargs', 'args kwargs')
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
attr_name = '_saved_' + method.__name__
attr = args_and_kwargs(args, kwargs)
setattr(self, attr_name, attr)
return method(self, *args, **kwargs)
return wrapper

View file

@ -1,5 +1,6 @@
from __future__ import absolute_import, unicode_literals from __future__ import absolute_import, unicode_literals
import numbers
from functools import reduce from functools import reduce
@ -25,6 +26,7 @@ def get_bit_values(number, size=32):
number += 2**size number += 2**size
return list(map(int, bin(number)[-size:])) return list(map(int, bin(number)[-size:]))
def gen_bit_values(number): def gen_bit_values(number):
""" """
Return a zero or one for each bit of a numeric value up to the most Return a zero or one for each bit of a numeric value up to the most
@ -36,6 +38,7 @@ def gen_bit_values(number):
digits = bin(number)[2:] digits = bin(number)[2:]
return map(int, reversed(digits)) return map(int, reversed(digits))
def coalesce(bits): def coalesce(bits):
""" """
Take a sequence of bits, most significant first, and Take a sequence of bits, most significant first, and
@ -47,6 +50,7 @@ def coalesce(bits):
operation = lambda a, b: (a << 1 | b) operation = lambda a, b: (a << 1 | b)
return reduce(operation, bits) return reduce(operation, bits)
class Flags(object): class Flags(object):
""" """
Subclasses should define _names, a list of flag names beginning Subclasses should define _names, a list of flag names beginning
@ -96,6 +100,7 @@ class Flags(object):
index = self._names.index(key) index = self._names.index(key)
return self._values[index] return self._values[index]
class BitMask(type): class BitMask(type):
""" """
A metaclass to create a bitmask with attributes. Subclass an int and A metaclass to create a bitmask with attributes. Subclass an int and
@ -119,12 +124,28 @@ class BitMask(type):
>>> b2 = MyBits(8) >>> b2 = MyBits(8)
>>> any([b2.a, b2.b, b2.c]) >>> any([b2.a, b2.b, b2.c])
False False
If the instance defines methods, they won't be wrapped in
properties.
>>> ns['get_value'] = classmethod(lambda cls: 'some value')
>>> ns['prop'] = property(lambda self: 'a property')
>>> MyBits = BitMask(str('MyBits'), (int,), ns)
>>> MyBits(3).get_value()
'some value'
>>> MyBits(3).prop
'a property'
""" """
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
def make_property(name, value):
if name.startswith('_') or not isinstance(value, numbers.Number):
return value
return property(lambda self, value=value: bool(self & value))
newattrs = dict( newattrs = dict(
(attr, property(lambda self, value=value: bool(self & value))) (name, make_property(name, value))
for attr, value in attrs.items() for name, value in attrs.items()
if not attr.startswith('_')
) )
return type.__new__(cls, name, bases, newattrs) return type.__new__(cls, name, bases, newattrs)

View file

@ -39,6 +39,7 @@ class FoldedCase(six.text_type):
""" """
A case insensitive string class; behaves just like str A case insensitive string class; behaves just like str
except compares equal when the only variation is case. except compares equal when the only variation is case.
>>> s = FoldedCase('hello world') >>> s = FoldedCase('hello world')
>>> s == 'Hello World' >>> s == 'Hello World'
@ -47,6 +48,9 @@ class FoldedCase(six.text_type):
>>> 'Hello World' == s >>> 'Hello World' == s
True True
>>> s != 'Hello World'
False
>>> s.index('O') >>> s.index('O')
4 4
@ -55,6 +59,38 @@ class FoldedCase(six.text_type):
>>> sorted(map(FoldedCase, ['GAMMA', 'alpha', 'Beta'])) >>> sorted(map(FoldedCase, ['GAMMA', 'alpha', 'Beta']))
['alpha', 'Beta', 'GAMMA'] ['alpha', 'Beta', 'GAMMA']
Sequence membership is straightforward.
>>> "Hello World" in [s]
True
>>> s in ["Hello World"]
True
You may test for set inclusion, but candidate and elements
must both be folded.
>>> FoldedCase("Hello World") in {s}
True
>>> s in {FoldedCase("Hello World")}
True
String inclusion works as long as the FoldedCase object
is on the right.
>>> "hello" in FoldedCase("Hello World")
True
But not if the FoldedCase object is on the left:
>>> FoldedCase('hello') in 'Hello World'
False
In that case, use in_:
>>> FoldedCase('hello').in_('Hello World')
True
""" """
def __lt__(self, other): def __lt__(self, other):
return self.lower() < other.lower() return self.lower() < other.lower()
@ -65,14 +101,23 @@ class FoldedCase(six.text_type):
def __eq__(self, other): def __eq__(self, other):
return self.lower() == other.lower() return self.lower() == other.lower()
def __ne__(self, other):
return self.lower() != other.lower()
def __hash__(self): def __hash__(self):
return hash(self.lower()) return hash(self.lower())
def __contains__(self, other):
return super(FoldedCase, self).lower().__contains__(other.lower())
def in_(self, other):
"Does self appear in other?"
return self in FoldedCase(other)
# cache lower since it's likely to be called frequently. # cache lower since it's likely to be called frequently.
@jaraco.functools.method_cache
def lower(self): def lower(self):
self._lower = super(FoldedCase, self).lower() return super(FoldedCase, self).lower()
self.lower = lambda: self._lower
return self._lower
def index(self, sub): def index(self, sub):
return self.lower().index(sub.lower()) return self.lower().index(sub.lower())
@ -147,6 +192,7 @@ def is_decodable(value):
return False return False
return True return True
def is_binary(value): def is_binary(value):
""" """
Return True if the value appears to be binary (that is, it's a byte Return True if the value appears to be binary (that is, it's a byte
@ -154,6 +200,7 @@ def is_binary(value):
""" """
return isinstance(value, bytes) and not is_decodable(value) return isinstance(value, bytes) and not is_decodable(value)
def trim(s): def trim(s):
r""" r"""
Trim something like a docstring to remove the whitespace that Trim something like a docstring to remove the whitespace that
@ -164,8 +211,10 @@ def trim(s):
""" """
return textwrap.dedent(s).strip() return textwrap.dedent(s).strip()
class Splitter(object): class Splitter(object):
"""object that will split a string with the given arguments for each call """object that will split a string with the given arguments for each call
>>> s = Splitter(',') >>> s = Splitter(',')
>>> s('hello, world, this is your, master calling') >>> s('hello, world, this is your, master calling')
['hello', ' world', ' this is your', ' master calling'] ['hello', ' world', ' this is your', ' master calling']
@ -176,9 +225,11 @@ class Splitter(object):
def __call__(self, s): def __call__(self, s):
return s.split(*self.args) return s.split(*self.args)
def indent(string, prefix=' ' * 4): def indent(string, prefix=' ' * 4):
return prefix + string return prefix + string
class WordSet(tuple): class WordSet(tuple):
""" """
Given a Python identifier, return the words that identifier represents, Given a Python identifier, return the words that identifier represents,
@ -269,6 +320,7 @@ class WordSet(tuple):
def from_class_name(cls, subject): def from_class_name(cls, subject):
return cls.parse(subject.__class__.__name__) return cls.parse(subject.__class__.__name__)
# for backward compatibility # for backward compatibility
words = WordSet.parse words = WordSet.parse
@ -318,6 +370,7 @@ class SeparatedValues(six.text_type):
parts = self.split(self.separator) parts = self.split(self.separator)
return six.moves.filter(None, (part.strip() for part in parts)) return six.moves.filter(None, (part.strip() for part in parts))
class Stripper: class Stripper:
r""" r"""
Given a series of lines, find the common prefix and strip it from them. Given a series of lines, find the common prefix and strip it from them.
@ -369,3 +422,31 @@ class Stripper:
while s1[:index] != s2[:index]: while s1[:index] != s2[:index]:
index -= 1 index -= 1
return s1[:index] return s1[:index]
def remove_prefix(text, prefix):
"""
Remove the prefix from the text if it exists.
>>> remove_prefix('underwhelming performance', 'underwhelming ')
'performance'
>>> remove_prefix('something special', 'sample')
'something special'
"""
null, prefix, rest = text.rpartition(prefix)
return rest
def remove_suffix(text, suffix):
"""
Remove the suffix from the text if it exists.
>>> remove_suffix('name.git', '.git')
'name'
>>> remove_suffix('something special', 'sample')
'something special'
"""
rest, suffix, null = text.partition(suffix)
return rest

View file

@ -60,3 +60,18 @@ class Command(object):
cls.add_subparsers(parser) cls.add_subparsers(parser)
args = parser.parse_args() args = parser.parse_args()
args.action.run(args) args.action.run(args)
class Extend(argparse.Action):
"""
Argparse action to take an nargs=* argument
and add any values to the existing value.
>>> parser = argparse.ArgumentParser()
>>> _ = parser.add_argument('--foo', nargs='*', default=[], action=Extend)
>>> args = parser.parse_args(['--foo', 'a=1', '--foo', 'b=2', 'c=3'])
>>> args.foo
['a=1', 'b=2', 'c=3']
"""
def __call__(self, parser, namespace, values, option_string=None):
getattr(namespace, self.dest).extend(values)

View file

@ -1,3 +1,5 @@
# deprecated -- use TQDM
from __future__ import (print_function, absolute_import, unicode_literals, from __future__ import (print_function, absolute_import, unicode_literals,
division) division)

View file

@ -45,3 +45,9 @@ GetClipboardData.restype = ctypes.wintypes.HANDLE
SetClipboardData = ctypes.windll.user32.SetClipboardData SetClipboardData = ctypes.windll.user32.SetClipboardData
SetClipboardData.argtypes = ctypes.wintypes.UINT, ctypes.wintypes.HANDLE SetClipboardData.argtypes = ctypes.wintypes.UINT, ctypes.wintypes.HANDLE
SetClipboardData.restype = ctypes.wintypes.HANDLE SetClipboardData.restype = ctypes.wintypes.HANDLE
OpenClipboard = ctypes.windll.user32.OpenClipboard
OpenClipboard.argtypes = ctypes.wintypes.HANDLE,
OpenClipboard.restype = ctypes.wintypes.BOOL
ctypes.windll.user32.CloseClipboard.restype = ctypes.wintypes.BOOL

View file

@ -10,9 +10,11 @@ try:
except ImportError: except ImportError:
LPBYTE = ctypes.POINTER(ctypes.wintypes.BYTE) LPBYTE = ctypes.POINTER(ctypes.wintypes.BYTE)
class CredentialAttribute(ctypes.Structure): class CredentialAttribute(ctypes.Structure):
_fields_ = [] _fields_ = []
class Credential(ctypes.Structure): class Credential(ctypes.Structure):
_fields_ = [ _fields_ = [
('flags', DWORD), ('flags', DWORD),
@ -32,6 +34,7 @@ class Credential(ctypes.Structure):
def __del__(self): def __del__(self):
ctypes.windll.advapi32.CredFree(ctypes.byref(self)) ctypes.windll.advapi32.CredFree(ctypes.byref(self))
PCREDENTIAL = ctypes.POINTER(Credential) PCREDENTIAL = ctypes.POINTER(Credential)
CredRead = ctypes.windll.advapi32.CredReadW CredRead = ctypes.windll.advapi32.CredReadW

View file

@ -1,11 +1,6 @@
from ctypes import ( from ctypes import windll, POINTER
Structure, windll, POINTER, byref, cast, create_unicode_buffer,
c_size_t, c_int, create_string_buffer, c_uint64, c_ushort, c_short,
c_uint,
)
from ctypes.wintypes import ( from ctypes.wintypes import (
BOOLEAN, LPWSTR, DWORD, LPVOID, HANDLE, FILETIME, LPWSTR, DWORD, LPVOID, HANDLE, BOOL,
WCHAR, BOOL, HWND, WORD, UINT,
) )
CreateEvent = windll.kernel32.CreateEventW CreateEvent = windll.kernel32.CreateEventW
@ -29,6 +24,7 @@ _WaitForMultipleObjects = windll.kernel32.WaitForMultipleObjects
_WaitForMultipleObjects.argtypes = DWORD, POINTER(HANDLE), BOOL, DWORD _WaitForMultipleObjects.argtypes = DWORD, POINTER(HANDLE), BOOL, DWORD
_WaitForMultipleObjects.restype = DWORD _WaitForMultipleObjects.restype = DWORD
def WaitForMultipleObjects(handles, wait_all=False, timeout=0): def WaitForMultipleObjects(handles, wait_all=False, timeout=0):
n_handles = len(handles) n_handles = len(handles)
handle_array = (HANDLE * n_handles)() handle_array = (HANDLE * n_handles)()
@ -36,6 +32,7 @@ def WaitForMultipleObjects(handles, wait_all=False, timeout=0):
handle_array[index] = handle handle_array[index] = handle
return _WaitForMultipleObjects(n_handles, handle_array, wait_all, timeout) return _WaitForMultipleObjects(n_handles, handle_array, wait_all, timeout)
WAIT_OBJECT_0 = 0 WAIT_OBJECT_0 = 0
INFINITE = -1 INFINITE = -1
WAIT_TIMEOUT = 0x102 WAIT_TIMEOUT = 0x102

View file

@ -28,16 +28,20 @@ MAX_PATH = 260
GetFinalPathNameByHandle = ctypes.windll.kernel32.GetFinalPathNameByHandleW GetFinalPathNameByHandle = ctypes.windll.kernel32.GetFinalPathNameByHandleW
GetFinalPathNameByHandle.argtypes = ( GetFinalPathNameByHandle.argtypes = (
ctypes.wintypes.HANDLE, ctypes.wintypes.LPWSTR, ctypes.wintypes.DWORD, ctypes.wintypes.DWORD, ctypes.wintypes.HANDLE, ctypes.wintypes.LPWSTR, ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD,
) )
GetFinalPathNameByHandle.restype = ctypes.wintypes.DWORD GetFinalPathNameByHandle.restype = ctypes.wintypes.DWORD
class SECURITY_ATTRIBUTES(ctypes.Structure): class SECURITY_ATTRIBUTES(ctypes.Structure):
_fields_ = ( _fields_ = (
('length', ctypes.wintypes.DWORD), ('length', ctypes.wintypes.DWORD),
('p_security_descriptor', ctypes.wintypes.LPVOID), ('p_security_descriptor', ctypes.wintypes.LPVOID),
('inherit_handle', ctypes.wintypes.BOOLEAN), ('inherit_handle', ctypes.wintypes.BOOLEAN),
) )
LPSECURITY_ATTRIBUTES = ctypes.POINTER(SECURITY_ATTRIBUTES) LPSECURITY_ATTRIBUTES = ctypes.POINTER(SECURITY_ATTRIBUTES)
CreateFile = ctypes.windll.kernel32.CreateFileW CreateFile = ctypes.windll.kernel32.CreateFileW
@ -77,23 +81,61 @@ CloseHandle = ctypes.windll.kernel32.CloseHandle
CloseHandle.argtypes = (ctypes.wintypes.HANDLE,) CloseHandle.argtypes = (ctypes.wintypes.HANDLE,)
CloseHandle.restype = ctypes.wintypes.BOOLEAN CloseHandle.restype = ctypes.wintypes.BOOLEAN
class WIN32_FIND_DATA(ctypes.Structure):
class WIN32_FIND_DATA(ctypes.wintypes.WIN32_FIND_DATAW):
"""
_fields_ = [ _fields_ = [
('file_attributes', ctypes.wintypes.DWORD), ("dwFileAttributes", DWORD),
('creation_time', ctypes.wintypes.FILETIME), ("ftCreationTime", FILETIME),
('last_access_time', ctypes.wintypes.FILETIME), ("ftLastAccessTime", FILETIME),
('last_write_time', ctypes.wintypes.FILETIME), ("ftLastWriteTime", FILETIME),
('file_size_words', ctypes.wintypes.DWORD*2), ("nFileSizeHigh", DWORD),
('reserved', ctypes.wintypes.DWORD*2), ("nFileSizeLow", DWORD),
('filename', ctypes.wintypes.WCHAR*MAX_PATH), ("dwReserved0", DWORD),
('alternate_filename', ctypes.wintypes.WCHAR*14), ("dwReserved1", DWORD),
("cFileName", WCHAR * MAX_PATH),
("cAlternateFileName", WCHAR * 14)]
] ]
"""
@property
def file_attributes(self):
return self.dwFileAttributes
@property
def creation_time(self):
return self.ftCreationTime
@property
def last_access_time(self):
return self.ftLastAccessTime
@property
def last_write_time(self):
return self.ftLastWriteTime
@property
def file_size_words(self):
return [self.nFileSizeHigh, self.nFileSizeLow]
@property
def reserved(self):
return [self.dwReserved0, self.dwReserved1]
@property
def filename(self):
return self.cFileName
@property
def alternate_filename(self):
return self.cAlternateFileName
@property @property
def file_size(self): def file_size(self):
return ctypes.cast(self.file_size_words, ctypes.POINTER(ctypes.c_uint64)).contents return self.nFileSizeHigh << 32 + self.nFileSizeLow
LPWIN32_FIND_DATA = ctypes.POINTER(WIN32_FIND_DATA)
LPWIN32_FIND_DATA = ctypes.POINTER(ctypes.wintypes.WIN32_FIND_DATAW)
FindFirstFile = ctypes.windll.kernel32.FindFirstFileW FindFirstFile = ctypes.windll.kernel32.FindFirstFileW
FindFirstFile.argtypes = (ctypes.wintypes.LPWSTR, LPWIN32_FIND_DATA) FindFirstFile.argtypes = (ctypes.wintypes.LPWSTR, LPWIN32_FIND_DATA)
@ -102,6 +144,8 @@ FindNextFile = ctypes.windll.kernel32.FindNextFileW
FindNextFile.argtypes = (ctypes.wintypes.HANDLE, LPWIN32_FIND_DATA) FindNextFile.argtypes = (ctypes.wintypes.HANDLE, LPWIN32_FIND_DATA)
FindNextFile.restype = ctypes.wintypes.BOOLEAN FindNextFile.restype = ctypes.wintypes.BOOLEAN
ctypes.windll.kernel32.FindClose.argtypes = ctypes.wintypes.HANDLE,
SCS_32BIT_BINARY = 0 # A 32-bit Windows-based application SCS_32BIT_BINARY = 0 # A 32-bit Windows-based application
SCS_64BIT_BINARY = 6 # A 64-bit Windows-based application SCS_64BIT_BINARY = 6 # A 64-bit Windows-based application
SCS_DOS_BINARY = 1 # An MS-DOS-based application SCS_DOS_BINARY = 1 # An MS-DOS-based application
@ -111,10 +155,45 @@ SCS_POSIX_BINARY = 4 # A POSIX-based application
SCS_WOW_BINARY = 2 # A 16-bit Windows-based application SCS_WOW_BINARY = 2 # A 16-bit Windows-based application
_GetBinaryType = ctypes.windll.kernel32.GetBinaryTypeW _GetBinaryType = ctypes.windll.kernel32.GetBinaryTypeW
_GetBinaryType.argtypes = (ctypes.wintypes.LPWSTR, ctypes.POINTER(ctypes.wintypes.DWORD)) _GetBinaryType.argtypes = (
ctypes.wintypes.LPWSTR, ctypes.POINTER(ctypes.wintypes.DWORD),
)
_GetBinaryType.restype = ctypes.wintypes.BOOL _GetBinaryType.restype = ctypes.wintypes.BOOL
FILEOP_FLAGS = ctypes.wintypes.WORD FILEOP_FLAGS = ctypes.wintypes.WORD
class BY_HANDLE_FILE_INFORMATION(ctypes.Structure):
_fields_ = [
('file_attributes', ctypes.wintypes.DWORD),
('creation_time', ctypes.wintypes.FILETIME),
('last_access_time', ctypes.wintypes.FILETIME),
('last_write_time', ctypes.wintypes.FILETIME),
('volume_serial_number', ctypes.wintypes.DWORD),
('file_size_high', ctypes.wintypes.DWORD),
('file_size_low', ctypes.wintypes.DWORD),
('number_of_links', ctypes.wintypes.DWORD),
('file_index_high', ctypes.wintypes.DWORD),
('file_index_low', ctypes.wintypes.DWORD),
]
@property
def file_size(self):
return (self.file_size_high << 32) + self.file_size_low
@property
def file_index(self):
return (self.file_index_high << 32) + self.file_index_low
GetFileInformationByHandle = ctypes.windll.kernel32.GetFileInformationByHandle
GetFileInformationByHandle.restype = ctypes.wintypes.BOOL
GetFileInformationByHandle.argtypes = (
ctypes.wintypes.HANDLE,
ctypes.POINTER(BY_HANDLE_FILE_INFORMATION),
)
class SHFILEOPSTRUCT(ctypes.Structure): class SHFILEOPSTRUCT(ctypes.Structure):
_fields_ = [ _fields_ = [
('status_dialog', ctypes.wintypes.HWND), ('status_dialog', ctypes.wintypes.HWND),
@ -126,6 +205,8 @@ class SHFILEOPSTRUCT(ctypes.Structure):
('name_mapping_handles', ctypes.wintypes.LPVOID), ('name_mapping_handles', ctypes.wintypes.LPVOID),
('progress_title', ctypes.wintypes.LPWSTR), ('progress_title', ctypes.wintypes.LPWSTR),
] ]
_SHFileOperation = ctypes.windll.shell32.SHFileOperationW _SHFileOperation = ctypes.windll.shell32.SHFileOperationW
_SHFileOperation.argtypes = [ctypes.POINTER(SHFILEOPSTRUCT)] _SHFileOperation.argtypes = [ctypes.POINTER(SHFILEOPSTRUCT)]
_SHFileOperation.restype = ctypes.c_int _SHFileOperation.restype = ctypes.c_int
@ -149,6 +230,7 @@ REPLACEFILE_WRITE_THROUGH = 0x1
REPLACEFILE_IGNORE_MERGE_ERRORS = 0x2 REPLACEFILE_IGNORE_MERGE_ERRORS = 0x2
REPLACEFILE_IGNORE_ACL_ERRORS = 0x4 REPLACEFILE_IGNORE_ACL_ERRORS = 0x4
class STAT_STRUCT(ctypes.Structure): class STAT_STRUCT(ctypes.Structure):
_fields_ = [ _fields_ = [
('dev', ctypes.c_uint), ('dev', ctypes.c_uint),
@ -165,17 +247,22 @@ class STAT_STRUCT(ctypes.Structure):
('ctime', ctypes.c_uint), ('ctime', ctypes.c_uint),
] ]
_wstat = ctypes.windll.msvcrt._wstat _wstat = ctypes.windll.msvcrt._wstat
_wstat.argtypes = [ctypes.wintypes.LPWSTR, ctypes.POINTER(STAT_STRUCT)] _wstat.argtypes = [ctypes.wintypes.LPWSTR, ctypes.POINTER(STAT_STRUCT)]
_wstat.restype = ctypes.c_int _wstat.restype = ctypes.c_int
FILE_NOTIFY_CHANGE_LAST_WRITE = 0x10 FILE_NOTIFY_CHANGE_LAST_WRITE = 0x10
FindFirstChangeNotification = ctypes.windll.kernel32.FindFirstChangeNotificationW FindFirstChangeNotification = (
FindFirstChangeNotification.argtypes = ctypes.wintypes.LPWSTR, ctypes.wintypes.BOOL, ctypes.wintypes.DWORD ctypes.windll.kernel32.FindFirstChangeNotificationW)
FindFirstChangeNotification.argtypes = (
ctypes.wintypes.LPWSTR, ctypes.wintypes.BOOL, ctypes.wintypes.DWORD,
)
FindFirstChangeNotification.restype = ctypes.wintypes.HANDLE FindFirstChangeNotification.restype = ctypes.wintypes.HANDLE
FindCloseChangeNotification = ctypes.windll.kernel32.FindCloseChangeNotification FindCloseChangeNotification = (
ctypes.windll.kernel32.FindCloseChangeNotification)
FindCloseChangeNotification.argtypes = ctypes.wintypes.HANDLE, FindCloseChangeNotification.argtypes = ctypes.wintypes.HANDLE,
FindCloseChangeNotification.restype = ctypes.wintypes.BOOL FindCloseChangeNotification.restype = ctypes.wintypes.BOOL
@ -203,6 +290,7 @@ DeviceIoControl.argtypes = [
] ]
DeviceIoControl.restype = ctypes.wintypes.BOOL DeviceIoControl.restype = ctypes.wintypes.BOOL
class REPARSE_DATA_BUFFER(ctypes.Structure): class REPARSE_DATA_BUFFER(ctypes.Structure):
_fields_ = [ _fields_ = [
('tag', ctypes.c_ulong), ('tag', ctypes.c_ulong),
@ -215,6 +303,7 @@ class REPARSE_DATA_BUFFER(ctypes.Structure):
('flags', ctypes.c_ulong), ('flags', ctypes.c_ulong),
('path_buffer', ctypes.c_byte * 1), ('path_buffer', ctypes.c_byte * 1),
] ]
def get_print_name(self): def get_print_name(self):
wchar_size = ctypes.sizeof(ctypes.wintypes.WCHAR) wchar_size = ctypes.sizeof(ctypes.wintypes.WCHAR)
arr_typ = ctypes.wintypes.WCHAR * (self.print_name_length // wchar_size) arr_typ = ctypes.wintypes.WCHAR * (self.print_name_length // wchar_size)

View file

@ -2,6 +2,7 @@ import struct
import ctypes.wintypes import ctypes.wintypes
from ctypes.wintypes import DWORD, WCHAR, BYTE, BOOL from ctypes.wintypes import DWORD, WCHAR, BYTE, BOOL
# from mprapi.h # from mprapi.h
MAX_INTERFACE_NAME_LEN = 2**8 MAX_INTERFACE_NAME_LEN = 2**8
@ -13,6 +14,7 @@ MAXLEN_IFDESCR = 2**8
MAX_ADAPTER_ADDRESS_LENGTH = 8 MAX_ADAPTER_ADDRESS_LENGTH = 8
MAX_DHCPV6_DUID_LENGTH = 130 MAX_DHCPV6_DUID_LENGTH = 130
class MIB_IFROW(ctypes.Structure): class MIB_IFROW(ctypes.Structure):
_fields_ = ( _fields_ = (
('name', WCHAR * MAX_INTERFACE_NAME_LEN), ('name', WCHAR * MAX_INTERFACE_NAME_LEN),
@ -46,7 +48,7 @@ class MIB_IFROW(ctypes.Structure):
val = getattr(self, val_prop) val = getattr(self, val_prop)
len_prop = '{0}_length'.format(name) len_prop = '{0}_length'.format(name)
length = getattr(self, len_prop) length = getattr(self, len_prop)
return str(buffer(val))[:length] return str(memoryview(val))[:length]
@property @property
def physical_address(self): def physical_address(self):
@ -56,12 +58,14 @@ class MIB_IFROW(ctypes.Structure):
def description(self): def description(self):
return self._get_binary_property('description') return self._get_binary_property('description')
class MIB_IFTABLE(ctypes.Structure): class MIB_IFTABLE(ctypes.Structure):
_fields_ = ( _fields_ = (
('num_entries', DWORD), # dwNumEntries ('num_entries', DWORD), # dwNumEntries
('entries', MIB_IFROW * 0), # table ('entries', MIB_IFROW * 0), # table
) )
class MIB_IPADDRROW(ctypes.Structure): class MIB_IPADDRROW(ctypes.Structure):
_fields_ = ( _fields_ = (
('address_num', DWORD), ('address_num', DWORD),
@ -79,40 +83,49 @@ class MIB_IPADDRROW(ctypes.Structure):
_ = struct.pack('L', self.address_num) _ = struct.pack('L', self.address_num)
return struct.unpack('!L', _)[0] return struct.unpack('!L', _)[0]
class MIB_IPADDRTABLE(ctypes.Structure): class MIB_IPADDRTABLE(ctypes.Structure):
_fields_ = ( _fields_ = (
('num_entries', DWORD), ('num_entries', DWORD),
('entries', MIB_IPADDRROW * 0), ('entries', MIB_IPADDRROW * 0),
) )
class SOCKADDR(ctypes.Structure): class SOCKADDR(ctypes.Structure):
_fields_ = ( _fields_ = (
('family', ctypes.c_ushort), ('family', ctypes.c_ushort),
('data', ctypes.c_byte * 14), ('data', ctypes.c_byte * 14),
) )
LPSOCKADDR = ctypes.POINTER(SOCKADDR) LPSOCKADDR = ctypes.POINTER(SOCKADDR)
class SOCKET_ADDRESS(ctypes.Structure): class SOCKET_ADDRESS(ctypes.Structure):
_fields_ = [ _fields_ = [
('address', LPSOCKADDR), ('address', LPSOCKADDR),
('length', ctypes.c_int), ('length', ctypes.c_int),
] ]
class _IP_ADAPTER_ADDRESSES_METRIC(ctypes.Structure): class _IP_ADAPTER_ADDRESSES_METRIC(ctypes.Structure):
_fields_ = [ _fields_ = [
('length', ctypes.c_ulong), ('length', ctypes.c_ulong),
('interface_index', DWORD), ('interface_index', DWORD),
] ]
class _IP_ADAPTER_ADDRESSES_U1(ctypes.Union): class _IP_ADAPTER_ADDRESSES_U1(ctypes.Union):
_fields_ = [ _fields_ = [
('alignment', ctypes.c_ulonglong), ('alignment', ctypes.c_ulonglong),
('metric', _IP_ADAPTER_ADDRESSES_METRIC), ('metric', _IP_ADAPTER_ADDRESSES_METRIC),
] ]
class IP_ADAPTER_ADDRESSES(ctypes.Structure): class IP_ADAPTER_ADDRESSES(ctypes.Structure):
pass pass
LP_IP_ADAPTER_ADDRESSES = ctypes.POINTER(IP_ADAPTER_ADDRESSES) LP_IP_ADAPTER_ADDRESSES = ctypes.POINTER(IP_ADAPTER_ADDRESSES)
# for now, just use void * for pointers to unused structures # for now, just use void * for pointers to unused structures
@ -125,7 +138,8 @@ PIP_ADAPTER_WINS_SERVER_ADDRESS_LH = ctypes.c_void_p
PIP_ADAPTER_GATEWAY_ADDRESS_LH = ctypes.c_void_p PIP_ADAPTER_GATEWAY_ADDRESS_LH = ctypes.c_void_p
PIP_ADAPTER_DNS_SUFFIX = ctypes.c_void_p PIP_ADAPTER_DNS_SUFFIX = ctypes.c_void_p
IF_OPER_STATUS = ctypes.c_uint # this is an enum, consider http://code.activestate.com/recipes/576415/ IF_OPER_STATUS = ctypes.c_uint # this is an enum, consider
# http://code.activestate.com/recipes/576415/
IF_LUID = ctypes.c_uint64 IF_LUID = ctypes.c_uint64
NET_IF_COMPARTMENT_ID = ctypes.c_uint32 NET_IF_COMPARTMENT_ID = ctypes.c_uint32

View file

@ -3,7 +3,7 @@ import ctypes.wintypes
GMEM_MOVEABLE = 0x2 GMEM_MOVEABLE = 0x2
GlobalAlloc = ctypes.windll.kernel32.GlobalAlloc GlobalAlloc = ctypes.windll.kernel32.GlobalAlloc
GlobalAlloc.argtypes = ctypes.wintypes.UINT, ctypes.c_ssize_t GlobalAlloc.argtypes = ctypes.wintypes.UINT, ctypes.c_size_t
GlobalAlloc.restype = ctypes.wintypes.HANDLE GlobalAlloc.restype = ctypes.wintypes.HANDLE
GlobalLock = ctypes.windll.kernel32.GlobalLock GlobalLock = ctypes.windll.kernel32.GlobalLock
@ -31,3 +31,15 @@ CreateFileMapping.restype = ctypes.wintypes.HANDLE
MapViewOfFile = ctypes.windll.kernel32.MapViewOfFile MapViewOfFile = ctypes.windll.kernel32.MapViewOfFile
MapViewOfFile.restype = ctypes.wintypes.HANDLE MapViewOfFile.restype = ctypes.wintypes.HANDLE
UnmapViewOfFile = ctypes.windll.kernel32.UnmapViewOfFile
UnmapViewOfFile.argtypes = ctypes.wintypes.HANDLE,
RtlMoveMemory = ctypes.windll.kernel32.RtlMoveMemory
RtlMoveMemory.argtypes = (
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_size_t,
)
ctypes.windll.kernel32.LocalFree.argtypes = ctypes.wintypes.HLOCAL,

View file

@ -13,6 +13,7 @@ import six
LRESULT = LPARAM LRESULT = LPARAM
class LPARAM_wstr(LPARAM): class LPARAM_wstr(LPARAM):
""" """
A special instance of LPARAM that can be constructed from a string A special instance of LPARAM that can be constructed from a string
@ -25,6 +26,7 @@ class LPARAM_wstr(LPARAM):
return LPVOID.from_param(six.text_type(param)) return LPVOID.from_param(six.text_type(param))
return LPARAM.from_param(param) return LPARAM.from_param(param)
SendMessage = ctypes.windll.user32.SendMessageW SendMessage = ctypes.windll.user32.SendMessageW
SendMessage.argtypes = (HWND, UINT, WPARAM, LPARAM_wstr) SendMessage.argtypes = (HWND, UINT, WPARAM, LPARAM_wstr)
SendMessage.restype = LRESULT SendMessage.restype = LRESULT
@ -32,7 +34,8 @@ SendMessage.restype = LRESULT
HWND_BROADCAST = 0xFFFF HWND_BROADCAST = 0xFFFF
WM_SETTINGCHANGE = 0x1A WM_SETTINGCHANGE = 0x1A
# constants from http://msdn.microsoft.com/en-us/library/ms644952%28v=vs.85%29.aspx # constants from http://msdn.microsoft.com
# /en-us/library/ms644952%28v=vs.85%29.aspx
SMTO_ABORTIFHUNG = 0x02 SMTO_ABORTIFHUNG = 0x02
SMTO_BLOCK = 0x01 SMTO_BLOCK = 0x01
SMTO_NORMAL = 0x00 SMTO_NORMAL = 0x00
@ -45,6 +48,7 @@ SendMessageTimeout.argtypes = SendMessage.argtypes + (
) )
SendMessageTimeout.restype = LRESULT SendMessageTimeout.restype = LRESULT
def unicode_as_lparam(source): def unicode_as_lparam(source):
pointer = ctypes.cast(ctypes.c_wchar_p(source), ctypes.c_void_p) pointer = ctypes.cast(ctypes.c_wchar_p(source), ctypes.c_void_p)
return LPARAM(pointer.value) return LPARAM(pointer.value)

View file

@ -5,6 +5,7 @@ mpr = ctypes.windll.mpr
RESOURCETYPE_ANY = 0 RESOURCETYPE_ANY = 0
class NETRESOURCE(ctypes.Structure): class NETRESOURCE(ctypes.Structure):
_fields_ = [ _fields_ = [
('scope', ctypes.wintypes.DWORD), ('scope', ctypes.wintypes.DWORD),
@ -16,6 +17,8 @@ class NETRESOURCE(ctypes.Structure):
('comment', ctypes.wintypes.LPWSTR), ('comment', ctypes.wintypes.LPWSTR),
('provider', ctypes.wintypes.LPWSTR), ('provider', ctypes.wintypes.LPWSTR),
] ]
LPNETRESOURCE = ctypes.POINTER(NETRESOURCE) LPNETRESOURCE = ctypes.POINTER(NETRESOURCE)
WNetAddConnection2 = mpr.WNetAddConnection2W WNetAddConnection2 = mpr.WNetAddConnection2W

View file

@ -1,5 +1,6 @@
import ctypes.wintypes import ctypes.wintypes
class SYSTEM_POWER_STATUS(ctypes.Structure): class SYSTEM_POWER_STATUS(ctypes.Structure):
_fields_ = ( _fields_ = (
('ac_line_status', ctypes.wintypes.BYTE), ('ac_line_status', ctypes.wintypes.BYTE),
@ -12,7 +13,9 @@ class SYSTEM_POWER_STATUS(ctypes.Structure):
@property @property
def ac_line_status_string(self): def ac_line_status_string(self):
return {0:'offline', 1: 'online', 255: 'unknown'}[self.ac_line_status] return {
0: 'offline', 1: 'online', 255: 'unknown'}[self.ac_line_status]
LPSYSTEM_POWER_STATUS = ctypes.POINTER(SYSTEM_POWER_STATUS) LPSYSTEM_POWER_STATUS = ctypes.POINTER(SYSTEM_POWER_STATUS)
GetSystemPowerStatus = ctypes.windll.kernel32.GetSystemPowerStatus GetSystemPowerStatus = ctypes.windll.kernel32.GetSystemPowerStatus
@ -23,6 +26,7 @@ SetThreadExecutionState = ctypes.windll.kernel32.SetThreadExecutionState
SetThreadExecutionState.argtypes = [ctypes.c_uint] SetThreadExecutionState.argtypes = [ctypes.c_uint]
SetThreadExecutionState.restype = ctypes.c_uint SetThreadExecutionState.restype = ctypes.c_uint
class ES: class ES:
""" """
Execution state constants Execution state constants

View file

@ -1,5 +1,6 @@
import ctypes.wintypes import ctypes.wintypes
class LUID(ctypes.Structure): class LUID(ctypes.Structure):
_fields_ = [ _fields_ = [
('low_part', ctypes.wintypes.DWORD), ('low_part', ctypes.wintypes.DWORD),
@ -15,6 +16,7 @@ class LUID(ctypes.Structure):
def __ne__(self, other): def __ne__(self, other):
return not (self == other) return not (self == other)
LookupPrivilegeValue = ctypes.windll.advapi32.LookupPrivilegeValueW LookupPrivilegeValue = ctypes.windll.advapi32.LookupPrivilegeValueW
LookupPrivilegeValue.argtypes = ( LookupPrivilegeValue.argtypes = (
ctypes.wintypes.LPWSTR, # system name ctypes.wintypes.LPWSTR, # system name
@ -23,17 +25,20 @@ LookupPrivilegeValue.argtypes = (
) )
LookupPrivilegeValue.restype = ctypes.wintypes.BOOL LookupPrivilegeValue.restype = ctypes.wintypes.BOOL
class TOKEN_INFORMATION_CLASS: class TOKEN_INFORMATION_CLASS:
TokenUser = 1 TokenUser = 1
TokenGroups = 2 TokenGroups = 2
TokenPrivileges = 3 TokenPrivileges = 3
# ... see http://msdn.microsoft.com/en-us/library/aa379626%28VS.85%29.aspx # ... see http://msdn.microsoft.com/en-us/library/aa379626%28VS.85%29.aspx
SE_PRIVILEGE_ENABLED_BY_DEFAULT = 0x00000001 SE_PRIVILEGE_ENABLED_BY_DEFAULT = 0x00000001
SE_PRIVILEGE_ENABLED = 0x00000002 SE_PRIVILEGE_ENABLED = 0x00000002
SE_PRIVILEGE_REMOVED = 0x00000004 SE_PRIVILEGE_REMOVED = 0x00000004
SE_PRIVILEGE_USED_FOR_ACCESS = 0x80000000 SE_PRIVILEGE_USED_FOR_ACCESS = 0x80000000
class LUID_AND_ATTRIBUTES(ctypes.Structure): class LUID_AND_ATTRIBUTES(ctypes.Structure):
_fields_ = [ _fields_ = [
('LUID', LUID), ('LUID', LUID),
@ -50,14 +55,17 @@ class LUID_AND_ATTRIBUTES(ctypes.Structure):
size = ctypes.wintypes.DWORD(10240) size = ctypes.wintypes.DWORD(10240)
buf = ctypes.create_unicode_buffer(size.value) buf = ctypes.create_unicode_buffer(size.value)
res = LookupPrivilegeName(None, self.LUID, buf, size) res = LookupPrivilegeName(None, self.LUID, buf, size)
if res == 0: raise RuntimeError if res == 0:
raise RuntimeError
return buf[:size.value] return buf[:size.value]
def __str__(self): def __str__(self):
res = self.get_name() res = self.get_name()
if self.is_enabled(): res += ' (enabled)' if self.is_enabled():
res += ' (enabled)'
return res return res
LookupPrivilegeName = ctypes.windll.advapi32.LookupPrivilegeNameW LookupPrivilegeName = ctypes.windll.advapi32.LookupPrivilegeNameW
LookupPrivilegeName.argtypes = ( LookupPrivilegeName.argtypes = (
ctypes.wintypes.LPWSTR, # lpSystemName ctypes.wintypes.LPWSTR, # lpSystemName
@ -67,6 +75,7 @@ LookupPrivilegeName.argtypes = (
) )
LookupPrivilegeName.restype = ctypes.wintypes.BOOL LookupPrivilegeName.restype = ctypes.wintypes.BOOL
class TOKEN_PRIVILEGES(ctypes.Structure): class TOKEN_PRIVILEGES(ctypes.Structure):
_fields_ = [ _fields_ = [
('count', ctypes.wintypes.DWORD), ('count', ctypes.wintypes.DWORD),
@ -75,12 +84,14 @@ class TOKEN_PRIVILEGES(ctypes.Structure):
def get_array(self): def get_array(self):
array_type = LUID_AND_ATTRIBUTES * self.count array_type = LUID_AND_ATTRIBUTES * self.count
privileges = ctypes.cast(self.privileges, ctypes.POINTER(array_type)).contents privileges = ctypes.cast(
self.privileges, ctypes.POINTER(array_type)).contents
return privileges return privileges
def __iter__(self): def __iter__(self):
return iter(self.get_array()) return iter(self.get_array())
PTOKEN_PRIVILEGES = ctypes.POINTER(TOKEN_PRIVILEGES) PTOKEN_PRIVILEGES = ctypes.POINTER(TOKEN_PRIVILEGES)
GetTokenInformation = ctypes.windll.advapi32.GetTokenInformation GetTokenInformation = ctypes.windll.advapi32.GetTokenInformation

View file

@ -5,5 +5,7 @@ TOKEN_ALL_ACCESS = 0xf01ff
GetCurrentProcess = ctypes.windll.kernel32.GetCurrentProcess GetCurrentProcess = ctypes.windll.kernel32.GetCurrentProcess
GetCurrentProcess.restype = ctypes.wintypes.HANDLE GetCurrentProcess.restype = ctypes.wintypes.HANDLE
OpenProcessToken = ctypes.windll.advapi32.OpenProcessToken OpenProcessToken = ctypes.windll.advapi32.OpenProcessToken
OpenProcessToken.argtypes = (ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD, ctypes.POINTER(ctypes.wintypes.HANDLE)) OpenProcessToken.argtypes = (
ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD,
ctypes.POINTER(ctypes.wintypes.HANDLE))
OpenProcessToken.restype = ctypes.wintypes.BOOL OpenProcessToken.restype = ctypes.wintypes.BOOL

View file

@ -60,12 +60,15 @@ POLICY_EXECUTE = (
POLICY_VIEW_LOCAL_INFORMATION | POLICY_VIEW_LOCAL_INFORMATION |
POLICY_LOOKUP_NAMES) POLICY_LOOKUP_NAMES)
class TokenAccess: class TokenAccess:
TOKEN_QUERY = 0x8 TOKEN_QUERY = 0x8
class TokenInformationClass: class TokenInformationClass:
TokenUser = 1 TokenUser = 1
class TOKEN_USER(ctypes.Structure): class TOKEN_USER(ctypes.Structure):
num = 1 num = 1
_fields_ = [ _fields_ = [
@ -100,6 +103,7 @@ class SECURITY_DESCRIPTOR(ctypes.Structure):
('Dacl', ctypes.c_void_p), ('Dacl', ctypes.c_void_p),
] ]
class SECURITY_ATTRIBUTES(ctypes.Structure): class SECURITY_ATTRIBUTES(ctypes.Structure):
""" """
typedef struct _SECURITY_ATTRIBUTES { typedef struct _SECURITY_ATTRIBUTES {
@ -126,3 +130,10 @@ class SECURITY_ATTRIBUTES(ctypes.Structure):
def descriptor(self, value): def descriptor(self, value):
self._descriptor = value self._descriptor = value
self.lpSecurityDescriptor = ctypes.addressof(value) self.lpSecurityDescriptor = ctypes.addressof(value)
ctypes.windll.advapi32.SetSecurityDescriptorOwner.argtypes = (
ctypes.POINTER(SECURITY_DESCRIPTOR),
ctypes.c_void_p,
ctypes.wintypes.BOOL,
)

View file

@ -1,6 +1,7 @@
import ctypes.wintypes import ctypes.wintypes
BOOL = ctypes.wintypes.BOOL BOOL = ctypes.wintypes.BOOL
class SHELLSTATE(ctypes.Structure): class SHELLSTATE(ctypes.Structure):
_fields_ = [ _fields_ = [
('show_all_objects', BOOL, 1), ('show_all_objects', BOOL, 1),
@ -34,6 +35,7 @@ class SHELLSTATE(ctypes.Structure):
('spare_flags', ctypes.wintypes.UINT, 13), ('spare_flags', ctypes.wintypes.UINT, 13),
] ]
SSF_SHOWALLOBJECTS = 0x00000001 SSF_SHOWALLOBJECTS = 0x00000001
"The fShowAllObjects member is being requested." "The fShowAllObjects member is being requested."
@ -62,7 +64,13 @@ SSF_SHOWATTRIBCOL = 0x00000100
"The fShowAttribCol member is being requested. (Windows Vista: Not used.)" "The fShowAttribCol member is being requested. (Windows Vista: Not used.)"
SSF_DESKTOPHTML = 0x00000200 SSF_DESKTOPHTML = 0x00000200
"The fDesktopHTML member is being requested. Set is not available. Instead, for versions of Microsoft Windows prior to Windows XP, enable Desktop HTML by IActiveDesktop. The use of IActiveDesktop for this purpose, however, is not recommended for Windows XP and later versions of Windows, and is deprecated in Windows Vista." """
The fDesktopHTML member is being requested. Set is not available.
Instead, for versions of Microsoft Windows prior to Windows XP,
enable Desktop HTML by IActiveDesktop. The use of IActiveDesktop
for this purpose, however, is not recommended for Windows XP and
later versions of Windows, and is deprecated in Windows Vista.
"""
SSF_WIN95CLASSIC = 0x00000400 SSF_WIN95CLASSIC = 0x00000400
"The fWin95Classic member is being requested." "The fWin95Classic member is being requested."

View file

@ -5,20 +5,22 @@ import re
import itertools import itertools
from contextlib import contextmanager from contextlib import contextmanager
import io import io
import six
import ctypes import ctypes
from ctypes import windll from ctypes import windll
import six
from six.moves import map
from jaraco.windows.api import clipboard, memory from jaraco.windows.api import clipboard, memory
from jaraco.windows.error import handle_nonzero_success, WindowsError from jaraco.windows.error import handle_nonzero_success, WindowsError
from jaraco.windows.memory import LockedMemory from jaraco.windows.memory import LockedMemory
__all__ = ( __all__ = (
'CF_TEXT', 'GetClipboardData', 'CloseClipboard', 'GetClipboardData', 'CloseClipboard',
'SetClipboardData', 'OpenClipboard', 'SetClipboardData', 'OpenClipboard',
) )
def OpenClipboard(owner=None): def OpenClipboard(owner=None):
""" """
Open the clipboard. Open the clipboard.
@ -30,9 +32,14 @@ def OpenClipboard(owner=None):
""" """
handle_nonzero_success(windll.user32.OpenClipboard(owner)) handle_nonzero_success(windll.user32.OpenClipboard(owner))
CloseClipboard = lambda: handle_nonzero_success(windll.user32.CloseClipboard())
def CloseClipboard():
handle_nonzero_success(windll.user32.CloseClipboard())
data_handlers = dict() data_handlers = dict()
def handles(*formats): def handles(*formats):
def register(func): def register(func):
for format in formats: for format in formats:
@ -40,36 +47,43 @@ def handles(*formats):
return func return func
return register return register
def nts(s):
def nts(buffer):
""" """
Null Terminated String Null Terminated String
Get the portion of s up to a null character. Get the portion of bytestring buffer up to a null character.
""" """
result, null, rest = s.partition('\x00') result, null, rest = buffer.partition('\x00')
return result return result
@handles(clipboard.CF_DIBV5, clipboard.CF_DIB) @handles(clipboard.CF_DIBV5, clipboard.CF_DIB)
def raw_data(handle): def raw_data(handle):
return LockedMemory(handle).data return LockedMemory(handle).data
@handles(clipboard.CF_TEXT) @handles(clipboard.CF_TEXT)
def text_string(handle): def text_string(handle):
return nts(raw_data(handle)) return nts(raw_data(handle))
@handles(clipboard.CF_UNICODETEXT) @handles(clipboard.CF_UNICODETEXT)
def unicode_string(handle): def unicode_string(handle):
return nts(raw_data(handle).decode('utf-16')) return nts(raw_data(handle).decode('utf-16'))
@handles(clipboard.CF_BITMAP) @handles(clipboard.CF_BITMAP)
def as_bitmap(handle): def as_bitmap(handle):
# handle is HBITMAP # handle is HBITMAP
raise NotImplementedError("Can't convert to DIB") raise NotImplementedError("Can't convert to DIB")
# todo: use GetDIBits http://msdn.microsoft.com/en-us/library/dd144879%28v=VS.85%29.aspx # todo: use GetDIBits http://msdn.microsoft.com
# /en-us/library/dd144879%28v=VS.85%29.aspx
@handles(clipboard.CF_HTML) @handles(clipboard.CF_HTML)
class HTMLSnippet(object): class HTMLSnippet(object):
def __init__(self, handle): def __init__(self, handle):
self.data = text_string(handle) self.data = nts(raw_data(handle).decode('utf-8'))
self.headers = self.parse_headers(self.data) self.headers = self.parse_headers(self.data)
@property @property
@ -79,11 +93,13 @@ class HTMLSnippet(object):
@staticmethod @staticmethod
def parse_headers(data): def parse_headers(data):
d = io.StringIO(data) d = io.StringIO(data)
def header_line(line): def header_line(line):
return re.match('(\w+):(.*)', line) return re.match('(\w+):(.*)', line)
headers = itertools.imap(header_line, d) headers = map(header_line, d)
# grab headers until they no longer match # grab headers until they no longer match
headers = itertools.takewhile(bool, headers) headers = itertools.takewhile(bool, headers)
def best_type(value): def best_type(value):
try: try:
return int(value) return int(value)
@ -101,26 +117,34 @@ class HTMLSnippet(object):
) )
return dict(pairs) return dict(pairs)
def GetClipboardData(type=clipboard.CF_UNICODETEXT): def GetClipboardData(type=clipboard.CF_UNICODETEXT):
if not type in data_handlers: if type not in data_handlers:
raise NotImplementedError("No support for data of type %d" % type) raise NotImplementedError("No support for data of type %d" % type)
handle = clipboard.GetClipboardData(type) handle = clipboard.GetClipboardData(type)
if handle is None: if handle is None:
raise TypeError("No clipboard data of type %d" % type) raise TypeError("No clipboard data of type %d" % type)
return data_handlers[type](handle) return data_handlers[type](handle)
EmptyClipboard = lambda: handle_nonzero_success(windll.user32.EmptyClipboard())
def EmptyClipboard():
handle_nonzero_success(windll.user32.EmptyClipboard())
def SetClipboardData(type, content): def SetClipboardData(type, content):
""" """
Modeled after http://msdn.microsoft.com/en-us/library/ms649016%28VS.85%29.aspx#_win32_Copying_Information_to_the_Clipboard Modeled after http://msdn.microsoft.com
/en-us/library/ms649016%28VS.85%29.aspx
#_win32_Copying_Information_to_the_Clipboard
""" """
allocators = { allocators = {
clipboard.CF_TEXT: ctypes.create_string_buffer, clipboard.CF_TEXT: ctypes.create_string_buffer,
clipboard.CF_UNICODETEXT: ctypes.create_unicode_buffer, clipboard.CF_UNICODETEXT: ctypes.create_unicode_buffer,
clipboard.CF_HTML: ctypes.create_string_buffer,
} }
if not type in allocators: if type not in allocators:
raise NotImplementedError("Only text types are supported at this time") raise NotImplementedError(
"Only text and HTML types are supported at this time")
# allocate the memory for the data # allocate the memory for the data
content = allocators[type](content) content = allocators[type](content)
flags = memory.GMEM_MOVEABLE flags = memory.GMEM_MOVEABLE
@ -132,47 +156,57 @@ def SetClipboardData(type, content):
if result is None: if result is None:
raise WindowsError() raise WindowsError()
def set_text(source): def set_text(source):
with context(): with context():
EmptyClipboard() EmptyClipboard()
SetClipboardData(clipboard.CF_TEXT, source) SetClipboardData(clipboard.CF_TEXT, source)
def get_text(): def get_text():
with context(): with context():
result = GetClipboardData(clipboard.CF_TEXT) result = GetClipboardData(clipboard.CF_TEXT)
return result return result
def set_unicode_text(source): def set_unicode_text(source):
with context(): with context():
EmptyClipboard() EmptyClipboard()
SetClipboardData(clipboard.CF_UNICODETEXT, source) SetClipboardData(clipboard.CF_UNICODETEXT, source)
def get_unicode_text(): def get_unicode_text():
with context(): with context():
return GetClipboardData() return GetClipboardData()
def get_html(): def get_html():
with context(): with context():
result = GetClipboardData(clipboard.CF_HTML) result = GetClipboardData(clipboard.CF_HTML)
return result return result
def set_html(source): def set_html(source):
with context(): with context():
EmptyClipboard() EmptyClipboard()
SetClipboardData(clipboard.CF_UNICODETEXT, source) SetClipboardData(clipboard.CF_UNICODETEXT, source)
def get_image(): def get_image():
with context(): with context():
return GetClipboardData(clipboard.CF_DIB) return GetClipboardData(clipboard.CF_DIB)
def paste_stdout(): def paste_stdout():
getter = get_unicode_text if six.PY3 else get_text getter = get_unicode_text if six.PY3 else get_text
sys.stdout.write(getter()) sys.stdout.write(getter())
def stdin_copy(): def stdin_copy():
setter = set_unicode_text if six.PY3 else set_text setter = set_unicode_text if six.PY3 else set_text
setter(sys.stdin.read()) setter(sys.stdin.read())
@contextmanager @contextmanager
def context(): def context():
OpenClipboard() OpenClipboard()
@ -181,10 +215,12 @@ def context():
finally: finally:
CloseClipboard() CloseClipboard()
def get_formats(): def get_formats():
with context(): with context():
format_index = 0 format_index = 0
while True: while True:
format_index = clipboard.EnumClipboardFormats(format_index) format_index = clipboard.EnumClipboardFormats(format_index)
if format_index == 0: break if format_index == 0:
break
yield format_index yield format_index

View file

@ -5,15 +5,18 @@ from . import error
CRED_TYPE_GENERIC = 1 CRED_TYPE_GENERIC = 1
def CredDelete(TargetName, Type, Flags=0): def CredDelete(TargetName, Type, Flags=0):
error.handle_nonzero_success(api.CredDelete(TargetName, Type, Flags)) error.handle_nonzero_success(api.CredDelete(TargetName, Type, Flags))
def CredRead(TargetName, Type, Flags=0): def CredRead(TargetName, Type, Flags=0):
cred_pointer = api.PCREDENTIAL() cred_pointer = api.PCREDENTIAL()
res = api.CredRead(TargetName, Type, Flags, ctypes.byref(cred_pointer)) res = api.CredRead(TargetName, Type, Flags, ctypes.byref(cred_pointer))
error.handle_nonzero_success(res) error.handle_nonzero_success(res)
return cred_pointer.contents return cred_pointer.contents
def CredWrite(Credential, Flags=0): def CredWrite(Credential, Flags=0):
res = api.CredWrite(Credential, Flags) res = api.CredWrite(Credential, Flags)
error.handle_nonzero_success(res) error.handle_nonzero_success(res)

View file

@ -14,6 +14,11 @@ import ctypes
from ctypes import wintypes from ctypes import wintypes
from jaraco.windows.error import handle_nonzero_success from jaraco.windows.error import handle_nonzero_success
# for type declarations
__import__('jaraco.windows.api.memory')
class DATA_BLOB(ctypes.Structure): class DATA_BLOB(ctypes.Structure):
r""" r"""
A data blob structure for use with MS DPAPI functions. A data blob structure for use with MS DPAPI functions.
@ -65,6 +70,7 @@ class DATA_BLOB(ctypes.Structure):
""" """
ctypes.windll.kernel32.LocalFree(self.data) ctypes.windll.kernel32.LocalFree(self.data)
p_DATA_BLOB = ctypes.POINTER(DATA_BLOB) p_DATA_BLOB = ctypes.POINTER(DATA_BLOB)
_CryptProtectData = ctypes.windll.crypt32.CryptProtectData _CryptProtectData = ctypes.windll.crypt32.CryptProtectData
@ -93,6 +99,7 @@ _CryptUnprotectData.restype = wintypes.BOOL
CRYPTPROTECT_UI_FORBIDDEN = 0x01 CRYPTPROTECT_UI_FORBIDDEN = 0x01
def CryptProtectData( def CryptProtectData(
data, description=None, optional_entropy=None, data, description=None, optional_entropy=None,
prompt_struct=None, flags=0, prompt_struct=None, flags=0,
@ -118,7 +125,9 @@ def CryptProtectData(
data_out.free() data_out.free()
return res return res
def CryptUnprotectData(data, optional_entropy=None, prompt_struct=None, flags=0):
def CryptUnprotectData(
data, optional_entropy=None, prompt_struct=None, flags=0):
""" """
Returns a tuple of (description, data) where description is the Returns a tuple of (description, data) where description is the
the description that was passed to the CryptProtectData call and the description that was passed to the CryptProtectData call and

View file

@ -7,7 +7,7 @@ import ctypes
import ctypes.wintypes import ctypes.wintypes
import six import six
winreg = six.moves.winreg from six.moves import winreg
from jaraco.ui.editor import EditableFile from jaraco.ui.editor import EditableFile
@ -19,17 +19,21 @@ from .registry import key_values as registry_key_values
def SetEnvironmentVariable(name, value): def SetEnvironmentVariable(name, value):
error.handle_nonzero_success(environ.SetEnvironmentVariable(name, value)) error.handle_nonzero_success(environ.SetEnvironmentVariable(name, value))
def ClearEnvironmentVariable(name): def ClearEnvironmentVariable(name):
error.handle_nonzero_success(environ.SetEnvironmentVariable(name, None)) error.handle_nonzero_success(environ.SetEnvironmentVariable(name, None))
def GetEnvironmentVariable(name): def GetEnvironmentVariable(name):
max_size = 2**15 - 1 max_size = 2**15 - 1
buffer = ctypes.create_unicode_buffer(max_size) buffer = ctypes.create_unicode_buffer(max_size)
error.handle_nonzero_success(environ.GetEnvironmentVariable(name, buffer, max_size)) error.handle_nonzero_success(
environ.GetEnvironmentVariable(name, buffer, max_size))
return buffer.value return buffer.value
### ###
class RegisteredEnvironment(object): class RegisteredEnvironment(object):
""" """
Manages the environment variables as set in the Windows Registry. Manages the environment variables as set in the Windows Registry.
@ -86,8 +90,8 @@ class RegisteredEnvironment(object):
if value in values: if value in values:
return return
new_value = sep.join(values + [value]) new_value = sep.join(values + [value])
winreg.SetValueEx(class_.key, name, 0, winreg.REG_EXPAND_SZ, winreg.SetValueEx(
new_value) class_.key, name, 0, winreg.REG_EXPAND_SZ, new_value)
class_.notify() class_.notify()
@classmethod @classmethod
@ -141,24 +145,30 @@ class RegisteredEnvironment(object):
) )
error.handle_nonzero_success(res) error.handle_nonzero_success(res)
class MachineRegisteredEnvironment(RegisteredEnvironment): class MachineRegisteredEnvironment(RegisteredEnvironment):
path = r'SYSTEM\CurrentControlSet\Control\Session Manager\Environment' path = r'SYSTEM\CurrentControlSet\Control\Session Manager\Environment'
hklm = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) hklm = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE)
try: try:
key = winreg.OpenKey(hklm, path, 0, key = winreg.OpenKey(
hklm, path, 0,
winreg.KEY_READ | winreg.KEY_WRITE) winreg.KEY_READ | winreg.KEY_WRITE)
except WindowsError: except WindowsError:
key = winreg.OpenKey(hklm, path, 0, winreg.KEY_READ) key = winreg.OpenKey(hklm, path, 0, winreg.KEY_READ)
class UserRegisteredEnvironment(RegisteredEnvironment): class UserRegisteredEnvironment(RegisteredEnvironment):
hkcu = winreg.ConnectRegistry(None, winreg.HKEY_CURRENT_USER) hkcu = winreg.ConnectRegistry(None, winreg.HKEY_CURRENT_USER)
key = winreg.OpenKey(hkcu, 'Environment', 0, key = winreg.OpenKey(
hkcu, 'Environment', 0,
winreg.KEY_READ | winreg.KEY_WRITE) winreg.KEY_READ | winreg.KEY_WRITE)
def trim(s): def trim(s):
from textwrap import dedent from textwrap import dedent
return dedent(s).strip() return dedent(s).strip()
def enver(*args): def enver(*args):
""" """
%prog [<name>=[value]] %prog [<name>=[value]]
@ -201,17 +211,21 @@ def enver(*args):
dest='class_', dest='class_',
help="Use the current user's environment", help="Use the current user's environment",
) )
parser.add_option('-a', '--append', parser.add_option(
'-a', '--append',
action='store_true', default=False, action='store_true', default=False,
help="Append the value to any existing value (default for PATH and PATHEXT)",) help="Append the value to any existing value (default for PATH and PATHEXT)",
)
parser.add_option( parser.add_option(
'-r', '--replace', '-r', '--replace',
action='store_true', default=False, action='store_true', default=False,
help="Replace any existing value (used to override default append for PATH and PATHEXT)", help="Replace any existing value (used to override default append "
"for PATH and PATHEXT)",
) )
parser.add_option( parser.add_option(
'--remove-value', action='store_true', default=False, '--remove-value', action='store_true', default=False,
help="Remove any matching values from a semicolon-separated multi-value variable", help="Remove any matching values from a semicolon-separated "
"multi-value variable",
) )
parser.add_option( parser.add_option(
'-e', '--edit', action='store_true', default=False, '-e', '--edit', action='store_true', default=False,
@ -224,7 +238,7 @@ def enver(*args):
if args: if args:
parser.error("Too many parameters specified") parser.error("Too many parameters specified")
raise SystemExit(1) raise SystemExit(1)
if not '=' in param and not options.edit: if '=' not in param and not options.edit:
parser.error("Expected <name>= or <name>=<value>") parser.error("Expected <name>= or <name>=<value>")
raise SystemExit(2) raise SystemExit(2)
name, sep, value = param.partition('=') name, sep, value = param.partition('=')
@ -238,5 +252,6 @@ def enver(*args):
except IndexError: except IndexError:
options.class_.show() options.class_.show()
if __name__ == '__main__': if __name__ == '__main__':
enver() enver()

View file

@ -6,9 +6,12 @@ import ctypes
import ctypes.wintypes import ctypes.wintypes
import six import six
builtins = six.moves.builtins builtins = six.moves.builtins
__import__('jaraco.windows.api.memory')
def format_system_message(errno): def format_system_message(errno):
""" """
Call FormatMessage with a system error number to retrieve Call FormatMessage with a system error number to retrieve
@ -46,7 +49,10 @@ def format_system_message(errno):
class WindowsError(builtins.WindowsError): class WindowsError(builtins.WindowsError):
"more info about errors at http://msdn.microsoft.com/en-us/library/ms681381(VS.85).aspx" """
More info about errors at
http://msdn.microsoft.com/en-us/library/ms681381(VS.85).aspx
"""
def __init__(self, value=None): def __init__(self, value=None):
if value is None: if value is None:
@ -72,6 +78,7 @@ class WindowsError(builtins.WindowsError):
def __repr__(self): def __repr__(self):
return '{self.__class__.__name__}({self.winerror})'.format(**vars()) return '{self.__class__.__name__}({self.winerror})'.format(**vars())
def handle_nonzero_success(result): def handle_nonzero_success(result):
if result == 0: if result == 0:
raise WindowsError() raise WindowsError()

View file

@ -8,6 +8,7 @@ import win32evtlogutil
error = win32api.error # The error the evtlog module raises. error = win32api.error # The error the evtlog module raises.
class EventLog(object): class EventLog(object):
def __init__(self, name="Application", machine_name=None): def __init__(self, name="Application", machine_name=None):
self.machine_name = machine_name self.machine_name = machine_name
@ -29,6 +30,7 @@ class EventLog(object):
win32evtlog.EVENTLOG_BACKWARDS_READ win32evtlog.EVENTLOG_BACKWARDS_READ
| win32evtlog.EVENTLOG_SEQUENTIAL_READ | win32evtlog.EVENTLOG_SEQUENTIAL_READ
) )
def get_records(self, flags=_default_flags): def get_records(self, flags=_default_flags):
with self: with self:
while True: while True:

View file

@ -7,19 +7,24 @@ import sys
import operator import operator
import collections import collections
import functools import functools
from ctypes import (POINTER, byref, cast, create_unicode_buffer, import stat
from ctypes import (
POINTER, byref, cast, create_unicode_buffer,
create_string_buffer, windll) create_string_buffer, windll)
from ctypes.wintypes import LPWSTR
import nt
import posixpath
import six import six
from six.moves import builtins, filter, map from six.moves import builtins, filter, map
from jaraco.structures import binary from jaraco.structures import binary
from jaraco.text import local_format as lf
from jaraco.windows.error import WindowsError, handle_nonzero_success from jaraco.windows.error import WindowsError, handle_nonzero_success
import jaraco.windows.api.filesystem as api import jaraco.windows.api.filesystem as api
from jaraco.windows import reparse from jaraco.windows import reparse
def mklink(): def mklink():
""" """
Like cmd.exe's mklink except it will infer directory status of the Like cmd.exe's mklink except it will infer directory status of the
@ -27,7 +32,8 @@ def mklink():
""" """
from optparse import OptionParser from optparse import OptionParser
parser = OptionParser(usage="usage: %prog [options] link target") parser = OptionParser(usage="usage: %prog [options] link target")
parser.add_option('-d', '--directory', parser.add_option(
'-d', '--directory',
help="Target is a directory (only necessary if not present)", help="Target is a directory (only necessary if not present)",
action="store_true") action="store_true")
options, args = parser.parse_args() options, args = parser.parse_args()
@ -38,6 +44,7 @@ def mklink():
symlink(target, link, options.directory) symlink(target, link, options.directory)
sys.stdout.write("Symbolic link created: %(link)s --> %(target)s\n" % vars()) sys.stdout.write("Symbolic link created: %(link)s --> %(target)s\n" % vars())
def _is_target_a_directory(link, rel_target): def _is_target_a_directory(link, rel_target):
""" """
If creating a symlink from link to a target, determine if target If creating a symlink from link to a target, determine if target
@ -46,15 +53,20 @@ def _is_target_a_directory(link, rel_target):
target = os.path.join(os.path.dirname(link), rel_target) target = os.path.join(os.path.dirname(link), rel_target)
return os.path.isdir(target) return os.path.isdir(target)
def symlink(target, link, target_is_directory=False): def symlink(target, link, target_is_directory=False):
""" """
An implementation of os.symlink for Windows (Vista and greater) An implementation of os.symlink for Windows (Vista and greater)
""" """
target_is_directory = (target_is_directory or target_is_directory = (
_is_target_a_directory(link, target)) target_is_directory or
_is_target_a_directory(link, target)
)
# normalize the target (MS symlinks don't respect forward slashes) # normalize the target (MS symlinks don't respect forward slashes)
target = os.path.normpath(target) target = os.path.normpath(target)
handle_nonzero_success(api.CreateSymbolicLink(link, target, target_is_directory)) handle_nonzero_success(
api.CreateSymbolicLink(link, target, target_is_directory))
def link(target, link): def link(target, link):
""" """
@ -62,6 +74,7 @@ def link(target, link):
""" """
handle_nonzero_success(api.CreateHardLink(link, target, None)) handle_nonzero_success(api.CreateHardLink(link, target, None))
def is_reparse_point(path): def is_reparse_point(path):
""" """
Determine if the given path is a reparse point. Determine if the given path is a reparse point.
@ -74,10 +87,12 @@ def is_reparse_point(path):
and bool(res & api.FILE_ATTRIBUTE_REPARSE_POINT) and bool(res & api.FILE_ATTRIBUTE_REPARSE_POINT)
) )
def islink(path): def islink(path):
"Determine if the given path is a symlink" "Determine if the given path is a symlink"
return is_reparse_point(path) and is_symlink(path) return is_reparse_point(path) and is_symlink(path)
def _patch_path(path): def _patch_path(path):
""" """
Paths have a max length of api.MAX_PATH characters (260). If a target path Paths have a max length of api.MAX_PATH characters (260). If a target path
@ -86,13 +101,15 @@ def _patch_path(path):
See http://msdn.microsoft.com/en-us/library/aa365247%28v=vs.85%29.aspx for See http://msdn.microsoft.com/en-us/library/aa365247%28v=vs.85%29.aspx for
details. details.
""" """
if path.startswith('\\\\?\\'): return path if path.startswith('\\\\?\\'):
return path
abs_path = os.path.abspath(path) abs_path = os.path.abspath(path)
if not abs_path[1] == ':': if not abs_path[1] == ':':
# python doesn't include the drive letter, but \\?\ requires it # python doesn't include the drive letter, but \\?\ requires it
abs_path = os.getcwd()[:2] + abs_path abs_path = os.getcwd()[:2] + abs_path
return '\\\\?\\' + abs_path return '\\\\?\\' + abs_path
def is_symlink(path): def is_symlink(path):
""" """
Assuming path is a reparse point, determine if it's a symlink. Assuming path is a reparse point, determine if it's a symlink.
@ -102,11 +119,13 @@ def is_symlink(path):
return _is_symlink(next(find_files(path))) return _is_symlink(next(find_files(path)))
except WindowsError as orig_error: except WindowsError as orig_error:
tmpl = "Error accessing {path}: {orig_error.message}" tmpl = "Error accessing {path}: {orig_error.message}"
raise builtins.WindowsError(lf(tmpl)) raise builtins.WindowsError(tmpl.format(**locals()))
def _is_symlink(find_data): def _is_symlink(find_data):
return find_data.reserved[0] == api.IO_REPARSE_TAG_SYMLINK return find_data.reserved[0] == api.IO_REPARSE_TAG_SYMLINK
def find_files(spec): def find_files(spec):
""" """
A pythonic wrapper around the FindFirstFile/FindNextFile win32 api. A pythonic wrapper around the FindFirstFile/FindNextFile win32 api.
@ -133,11 +152,13 @@ def find_files(spec):
error = WindowsError() error = WindowsError()
if error.code == api.ERROR_NO_MORE_FILES: if error.code == api.ERROR_NO_MORE_FILES:
break break
else: raise error else:
raise error
# todo: how to close handle when generator is destroyed? # todo: how to close handle when generator is destroyed?
# hint: catch GeneratorExit # hint: catch GeneratorExit
windll.kernel32.FindClose(handle) windll.kernel32.FindClose(handle)
def get_final_path(path): def get_final_path(path):
""" """
For a given path, determine the ultimate location of that path. For a given path, determine the ultimate location of that path.
@ -150,7 +171,9 @@ def get_final_path(path):
trace_symlink_target instead. trace_symlink_target instead.
""" """
desired_access = api.NULL desired_access = api.NULL
share_mode = api.FILE_SHARE_READ | api.FILE_SHARE_WRITE | api.FILE_SHARE_DELETE share_mode = (
api.FILE_SHARE_READ | api.FILE_SHARE_WRITE | api.FILE_SHARE_DELETE
)
security_attributes = api.LPSECURITY_ATTRIBUTES() # NULL pointer security_attributes = api.LPSECURITY_ATTRIBUTES() # NULL pointer
hFile = api.CreateFile( hFile = api.CreateFile(
path, path,
@ -165,10 +188,12 @@ def get_final_path(path):
if hFile == api.INVALID_HANDLE_VALUE: if hFile == api.INVALID_HANDLE_VALUE:
raise WindowsError() raise WindowsError()
buf_size = api.GetFinalPathNameByHandle(hFile, api.LPWSTR(), 0, api.VOLUME_NAME_DOS) buf_size = api.GetFinalPathNameByHandle(
hFile, LPWSTR(), 0, api.VOLUME_NAME_DOS)
handle_nonzero_success(buf_size) handle_nonzero_success(buf_size)
buf = create_unicode_buffer(buf_size) buf = create_unicode_buffer(buf_size)
result_length = api.GetFinalPathNameByHandle(hFile, buf, len(buf), api.VOLUME_NAME_DOS) result_length = api.GetFinalPathNameByHandle(
hFile, buf, len(buf), api.VOLUME_NAME_DOS)
assert result_length < len(buf) assert result_length < len(buf)
handle_nonzero_success(result_length) handle_nonzero_success(result_length)
@ -176,22 +201,83 @@ def get_final_path(path):
return buf[:result_length] return buf[:result_length]
def compat_stat(path):
"""
Generate stat as found on Python 3.2 and later.
"""
stat = os.stat(path)
info = get_file_info(path)
# rewrite st_ino, st_dev, and st_nlink based on file info
return nt.stat_result(
(stat.st_mode,) +
(info.file_index, info.volume_serial_number, info.number_of_links) +
stat[4:]
)
def samefile(f1, f2):
"""
Backport of samefile from Python 3.2 with support for Windows.
"""
return posixpath.samestat(compat_stat(f1), compat_stat(f2))
def get_file_info(path):
# open the file the same way CPython does in posixmodule.c
desired_access = api.FILE_READ_ATTRIBUTES
share_mode = 0
security_attributes = None
creation_disposition = api.OPEN_EXISTING
flags_and_attributes = (
api.FILE_ATTRIBUTE_NORMAL |
api.FILE_FLAG_BACKUP_SEMANTICS |
api.FILE_FLAG_OPEN_REPARSE_POINT
)
template_file = None
handle = api.CreateFile(
path,
desired_access,
share_mode,
security_attributes,
creation_disposition,
flags_and_attributes,
template_file,
)
if handle == api.INVALID_HANDLE_VALUE:
raise WindowsError()
info = api.BY_HANDLE_FILE_INFORMATION()
res = api.GetFileInformationByHandle(handle, info)
handle_nonzero_success(res)
handle_nonzero_success(api.CloseHandle(handle))
return info
def GetBinaryType(filepath): def GetBinaryType(filepath):
res = api.DWORD() res = api.DWORD()
handle_nonzero_success(api._GetBinaryType(filepath, res)) handle_nonzero_success(api._GetBinaryType(filepath, res))
return res return res
def _make_null_terminated_list(obs): def _make_null_terminated_list(obs):
obs = _makelist(obs) obs = _makelist(obs)
if obs is None: return if obs is None:
return
return u'\x00'.join(obs) + u'\x00\x00' return u'\x00'.join(obs) + u'\x00\x00'
def _makelist(ob): def _makelist(ob):
if ob is None: return if ob is None:
return
if not isinstance(ob, (list, tuple, set)): if not isinstance(ob, (list, tuple, set)):
return [ob] return [ob]
return ob return ob
def SHFileOperation(operation, from_, to=None, flags=[]): def SHFileOperation(operation, from_, to=None, flags=[]):
flags = functools.reduce(operator.or_, flags, 0) flags = functools.reduce(operator.or_, flags, 0)
from_ = _make_null_terminated_list(from_) from_ = _make_null_terminated_list(from_)
@ -201,6 +287,7 @@ def SHFileOperation(operation, from_, to=None, flags=[]):
if res != 0: if res != 0:
raise RuntimeError("SHFileOperation returned %d" % res) raise RuntimeError("SHFileOperation returned %d" % res)
def join(*paths): def join(*paths):
r""" r"""
Wrapper around os.path.join that works with Windows drive letters. Wrapper around os.path.join that works with Windows drive letters.
@ -214,17 +301,17 @@ def join(*paths):
drive = next(filter(None, reversed(drives)), '') drive = next(filter(None, reversed(drives)), '')
return os.path.join(drive, os.path.join(*paths)) return os.path.join(drive, os.path.join(*paths))
def resolve_path(target, start=os.path.curdir): def resolve_path(target, start=os.path.curdir):
r""" r"""
Find a path from start to target where target is relative to start. Find a path from start to target where target is relative to start.
>>> orig_wd = os.getcwd() >>> tmp = str(getfixture('tmpdir_as_cwd'))
>>> os.chdir('c:\\windows') # so we know what the working directory is
>>> findpath('d:\\') >>> findpath('d:\\')
'd:\\' 'd:\\'
>>> findpath('d:\\', 'c:\\windows') >>> findpath('d:\\', tmp)
'd:\\' 'd:\\'
>>> findpath('\\bar', 'd:\\') >>> findpath('\\bar', 'd:\\')
@ -239,11 +326,11 @@ def resolve_path(target, start=os.path.curdir):
>>> findpath('\\baz', 'd:\\foo\\bar') # fails with '\\baz' >>> findpath('\\baz', 'd:\\foo\\bar') # fails with '\\baz'
'd:\\baz' 'd:\\baz'
>>> os.path.abspath(findpath('\\bar')) >>> os.path.abspath(findpath('\\bar')).lower()
'c:\\bar' 'c:\\bar'
>>> os.path.abspath(findpath('bar')) >>> os.path.abspath(findpath('bar'))
'c:\\windows\\bar' '...\\bar'
>>> findpath('..', 'd:\\foo\\bar') >>> findpath('..', 'd:\\foo\\bar')
'd:\\foo' 'd:\\foo'
@ -254,8 +341,10 @@ def resolve_path(target, start=os.path.curdir):
""" """
return os.path.normpath(join(start, target)) return os.path.normpath(join(start, target))
findpath = resolve_path findpath = resolve_path
def trace_symlink_target(link): def trace_symlink_target(link):
""" """
Given a file that is known to be a symlink, trace it to its ultimate Given a file that is known to be a symlink, trace it to its ultimate
@ -273,6 +362,7 @@ def trace_symlink_target(link):
link = resolve_path(link, orig) link = resolve_path(link, orig)
return link return link
def readlink(link): def readlink(link):
""" """
readlink(link) -> target readlink(link) -> target
@ -291,7 +381,8 @@ def readlink(link):
if handle == api.INVALID_HANDLE_VALUE: if handle == api.INVALID_HANDLE_VALUE:
raise WindowsError() raise WindowsError()
res = reparse.DeviceIoControl(handle, api.FSCTL_GET_REPARSE_POINT, None, 10240) res = reparse.DeviceIoControl(
handle, api.FSCTL_GET_REPARSE_POINT, None, 10240)
bytes = create_string_buffer(res) bytes = create_string_buffer(res)
p_rdb = cast(bytes, POINTER(api.REPARSE_DATA_BUFFER)) p_rdb = cast(bytes, POINTER(api.REPARSE_DATA_BUFFER))
@ -302,6 +393,7 @@ def readlink(link):
handle_nonzero_success(api.CloseHandle(handle)) handle_nonzero_success(api.CloseHandle(handle))
return rdb.get_substitute_name() return rdb.get_substitute_name()
def patch_os_module(): def patch_os_module():
""" """
jaraco.windows provides the os.symlink and os.readlink functions. jaraco.windows provides the os.symlink and os.readlink functions.
@ -313,6 +405,7 @@ def patch_os_module():
if not hasattr(os, 'readlink'): if not hasattr(os, 'readlink'):
os.readlink = readlink os.readlink = readlink
def find_symlinks(root): def find_symlinks(root):
for dirpath, dirnames, filenames in os.walk(root): for dirpath, dirnames, filenames in os.walk(root):
for name in dirnames + filenames: for name in dirnames + filenames:
@ -323,6 +416,7 @@ def find_symlinks(root):
if name in dirnames: if name in dirnames:
dirnames.remove(name) dirnames.remove(name)
def find_symlinks_cmd(): def find_symlinks_cmd():
""" """
%prog [start-path] %prog [start-path]
@ -333,7 +427,8 @@ def find_symlinks_cmd():
from textwrap import dedent from textwrap import dedent
parser = OptionParser(usage=dedent(find_symlinks_cmd.__doc__).strip()) parser = OptionParser(usage=dedent(find_symlinks_cmd.__doc__).strip())
options, args = parser.parse_args() options, args = parser.parse_args()
if not args: args = ['.'] if not args:
args = ['.']
root = args.pop() root = args.pop()
if args: if args:
parser.error("unexpected argument(s)") parser.error("unexpected argument(s)")
@ -346,8 +441,19 @@ def find_symlinks_cmd():
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
@six.add_metaclass(binary.BitMask) @six.add_metaclass(binary.BitMask)
class FileAttributes(int): class FileAttributes(int):
# extract the values from the stat module on Python 3.5
# and later.
locals().update(
(name.split('FILE_ATTRIBUTES_')[1].lower(), value)
for name, value in vars(stat).items()
if name.startswith('FILE_ATTRIBUTES_')
)
# For Python 3.4 and earlier, define the constants here
archive = 0x20 archive = 0x20
compressed = 0x800 compressed = 0x800
hidden = 0x2 hidden = 0x2
@ -364,11 +470,16 @@ class FileAttributes(int):
temporary = 0x100 temporary = 0x100
virtual = 0x10000 virtual = 0x10000
def GetFileAttributes(filepath): @classmethod
def get(cls, filepath):
attrs = api.GetFileAttributes(filepath) attrs = api.GetFileAttributes(filepath)
if attrs == api.INVALID_FILE_ATTRIBUTES: if attrs == api.INVALID_FILE_ATTRIBUTES:
raise WindowsError() raise WindowsError()
return FileAttributes(attrs) return cls(attrs)
GetFileAttributes = FileAttributes.get
def SetFileAttributes(filepath, *attrs): def SetFileAttributes(filepath, *attrs):
""" """

View file

@ -0,0 +1,109 @@
from __future__ import unicode_literals
import os.path
# realpath taken from https://bugs.python.org/file38057/issue9949-v4.patch
def realpath(path):
if isinstance(path, str):
prefix = '\\\\?\\'
unc_prefix = prefix + 'UNC'
new_unc_prefix = '\\'
cwd = os.getcwd()
else:
prefix = b'\\\\?\\'
unc_prefix = prefix + b'UNC'
new_unc_prefix = b'\\'
cwd = os.getcwdb()
had_prefix = path.startswith(prefix)
path, ok = _resolve_path(cwd, path, {})
# The path returned by _getfinalpathname will always start with \\?\ -
# strip off that prefix unless it was already provided on the original
# path.
if not had_prefix:
# For UNC paths, the prefix will actually be \\?\UNC - handle that
# case as well.
if path.startswith(unc_prefix):
path = new_unc_prefix + path[len(unc_prefix):]
elif path.startswith(prefix):
path = path[len(prefix):]
return path
def _resolve_path(path, rest, seen):
# Windows normalizes the path before resolving symlinks; be sure to
# follow the same behavior.
rest = os.path.normpath(rest)
if isinstance(rest, str):
sep = '\\'
else:
sep = b'\\'
if os.path.isabs(rest):
drive, rest = os.path.splitdrive(rest)
path = drive + sep
rest = rest[1:]
while rest:
name, _, rest = rest.partition(sep)
new_path = os.path.join(path, name) if path else name
if os.path.exists(new_path):
if not rest:
# The whole path exists. Resolve it using the OS.
path = os.path._getfinalpathname(new_path)
else:
# The OS can resolve `new_path`; keep traversing the path.
path = new_path
elif not os.path.lexists(new_path):
# `new_path` does not exist on the filesystem at all. Use the
# OS to resolve `path`, if it exists, and then append the
# remainder.
if os.path.exists(path):
path = os.path._getfinalpathname(path)
rest = os.path.join(name, rest) if rest else name
return os.path.join(path, rest), True
else:
# We have a symbolic link that the OS cannot resolve. Try to
# resolve it ourselves.
# On Windows, symbolic link resolution can be partially or
# fully disabled [1]. The end result of a disabled symlink
# appears the same as a broken symlink (lexists() returns True
# but exists() returns False). And in both cases, the link can
# still be read using readlink(). Call stat() and check the
# resulting error code to ensure we don't circumvent the
# Windows symbolic link restrictions.
# [1] https://technet.microsoft.com/en-us/library/cc754077.aspx
try:
os.stat(new_path)
except OSError as e:
# WinError 1463: The symbolic link cannot be followed
# because its type is disabled.
if e.winerror == 1463:
raise
key = os.path.normcase(new_path)
if key in seen:
# This link has already been seen; try to use the
# previously resolved value.
path = seen[key]
if path is None:
# It has not yet been resolved, which means we must
# have a symbolic link loop. Return what we have
# resolved so far plus the remainder of the path (who
# cares about the Zen of Python?).
path = os.path.join(new_path, rest) if rest else new_path
return path, False
else:
# Mark this link as in the process of being resolved.
seen[key] = None
# Try to resolve it.
path, ok = _resolve_path(path, os.readlink(new_path), seen)
if ok:
# Resolution succeded; store the resolved value.
seen[key] = path
else:
# Resolution failed; punt.
return (os.path.join(path, rest) if rest else path), False
return path, True

View file

@ -17,6 +17,8 @@ from threading import Thread
import itertools import itertools
import logging import logging
import six
from more_itertools.recipes import consume from more_itertools.recipes import consume
import jaraco.text import jaraco.text
@ -25,9 +27,11 @@ from jaraco.windows.api import event
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class NotifierException(Exception): class NotifierException(Exception):
pass pass
class FileFilter(object): class FileFilter(object):
def set_root(self, root): def set_root(self, root):
self.root = root self.root = root
@ -35,9 +39,11 @@ class FileFilter(object):
def _get_file_path(self, filename): def _get_file_path(self, filename):
try: try:
filename = os.path.join(self.root, filename) filename = os.path.join(self.root, filename)
except AttributeError: pass except AttributeError:
pass
return filename return filename
class ModifiedTimeFilter(FileFilter): class ModifiedTimeFilter(FileFilter):
""" """
Returns true for each call where the modified time of the file is after Returns true for each call where the modified time of the file is after
@ -53,6 +59,7 @@ class ModifiedTimeFilter(FileFilter):
log.debug('{filepath} last modified at {last_mod}.'.format(**vars())) log.debug('{filepath} last modified at {last_mod}.'.format(**vars()))
return last_mod > self.cutoff return last_mod > self.cutoff
class PatternFilter(FileFilter): class PatternFilter(FileFilter):
""" """
Filter that returns True for files that match pattern (a regular Filter that returns True for files that match pattern (a regular
@ -60,13 +67,14 @@ class PatternFilter(FileFilter):
""" """
def __init__(self, pattern): def __init__(self, pattern):
self.pattern = ( self.pattern = (
re.compile(pattern) if isinstance(pattern, basestring) re.compile(pattern) if isinstance(pattern, six.string_types)
else pattern else pattern
) )
def __call__(self, file): def __call__(self, file):
return bool(self.pattern.match(file, re.I)) return bool(self.pattern.match(file, re.I))
class GlobFilter(PatternFilter): class GlobFilter(PatternFilter):
""" """
Filter that returns True for files that match the pattern (a glob Filter that returns True for files that match the pattern (a glob
@ -102,6 +110,7 @@ class AggregateFilter(FileFilter):
def __call__(self, file): def __call__(self, file):
return all(fil(file) for fil in self.filters) return all(fil(file) for fil in self.filters)
class OncePerModFilter(FileFilter): class OncePerModFilter(FileFilter):
def __init__(self): def __init__(self):
self.history = list() self.history = list()
@ -115,13 +124,16 @@ class OncePerModFilter(FileFilter):
del self.history[-50:] del self.history[-50:]
return result return result
def files_with_path(files, path): def files_with_path(files, path):
return (os.path.join(path, file) for file in files) return (os.path.join(path, file) for file in files)
def get_file_paths(walk_result): def get_file_paths(walk_result):
root, dirs, files = walk_result root, dirs, files = walk_result
return files_with_path(files, root) return files_with_path(files, root)
class Notifier(object): class Notifier(object):
def __init__(self, root='.', filters=[]): def __init__(self, root='.', filters=[]):
# assign the root, verify it exists # assign the root, verify it exists
@ -138,7 +150,8 @@ class Notifier(object):
def __del__(self): def __del__(self):
try: try:
fs.FindCloseChangeNotification(self.hChange) fs.FindCloseChangeNotification(self.hChange)
except: pass except Exception:
pass
def _get_change_handle(self): def _get_change_handle(self):
# set up to monitor the directory tree specified # set up to monitor the directory tree specified
@ -151,8 +164,8 @@ class Notifier(object):
# make sure it worked; if not, bail # make sure it worked; if not, bail
INVALID_HANDLE_VALUE = fs.INVALID_HANDLE_VALUE INVALID_HANDLE_VALUE = fs.INVALID_HANDLE_VALUE
if self.hChange == INVALID_HANDLE_VALUE: if self.hChange == INVALID_HANDLE_VALUE:
raise NotifierException('Could not set up directory change ' raise NotifierException(
'notification') 'Could not set up directory change notification')
@staticmethod @staticmethod
def _filtered_walk(path, file_filter): def _filtered_walk(path, file_filter):
@ -171,6 +184,7 @@ class Notifier(object):
def quit(self): def quit(self):
event.SetEvent(self.quit_event) event.SetEvent(self.quit_event)
class BlockingNotifier(Notifier): class BlockingNotifier(Notifier):
@staticmethod @staticmethod
@ -215,14 +229,15 @@ class BlockingNotifier(Notifier):
result = next(results) result = next(results)
return result return result
class ThreadedNotifier(BlockingNotifier, Thread): class ThreadedNotifier(BlockingNotifier, Thread):
r""" r"""
ThreadedNotifier provides a simple interface that calls the handler ThreadedNotifier provides a simple interface that calls the handler
for each file rooted in root that passes the filters. It runs as its own for each file rooted in root that passes the filters. It runs as its own
thread, so must be started as such. thread, so must be started as such::
>>> notifier = ThreadedNotifier('c:\\', handler = StreamHandler()) # doctest: +SKIP notifier = ThreadedNotifier('c:\\', handler = StreamHandler())
>>> notifier.start() # doctest: +SKIP notifier.start()
C:\Autoexec.bat changed. C:\Autoexec.bat changed.
""" """
def __init__(self, root='.', filters=[], handler=lambda file: None): def __init__(self, root='.', filters=[], handler=lambda file: None):
@ -242,6 +257,7 @@ class ThreadedNotifier(BlockingNotifier, Thread):
for file in self.get_changed_files(): for file in self.get_changed_files():
self.handle(file) self.handle(file)
class StreamHandler(object): class StreamHandler(object):
""" """
StreamHandler: a sample handler object for use with the threaded StreamHandler: a sample handler object for use with the threaded

View file

@ -28,6 +28,7 @@ def GetAdaptersAddresses():
yield struct_p.contents yield struct_p.contents
struct_p = struct_p.contents.next struct_p = struct_p.contents.next
class AllocatedTable(object): class AllocatedTable(object):
""" """
Both the interface table and the ip address table use the same Both the interface table and the ip address table use the same
@ -83,16 +84,19 @@ class AllocatedTable(object):
pointer_type = ctypes.POINTER(entries_array) pointer_type = ctypes.POINTER(entries_array)
return ctypes.cast(table.entries, pointer_type).contents return ctypes.cast(table.entries, pointer_type).contents
class InterfaceTable(AllocatedTable): class InterfaceTable(AllocatedTable):
method = inet.GetIfTable method = inet.GetIfTable
structure = inet.MIB_IFTABLE structure = inet.MIB_IFTABLE
row_structure = inet.MIB_IFROW row_structure = inet.MIB_IFROW
class AddressTable(AllocatedTable): class AddressTable(AllocatedTable):
method = inet.GetIpAddrTable method = inet.GetIpAddrTable
structure = inet.MIB_IPADDRTABLE structure = inet.MIB_IPADDRTABLE
row_structure = inet.MIB_IPADDRROW row_structure = inet.MIB_IPADDRROW
class AddressManager(object): class AddressManager(object):
@staticmethod @staticmethod
def hardware_address_to_string(addr): def hardware_address_to_string(addr):
@ -100,7 +104,8 @@ class AddressManager(object):
return ':'.join(hex_bytes) return ':'.join(hex_bytes)
def get_host_mac_address_strings(self): def get_host_mac_address_strings(self):
return (self.hardware_address_to_string(addr) return (
self.hardware_address_to_string(addr)
for addr in self.get_host_mac_addresses()) for addr in self.get_host_mac_addresses())
def get_host_ip_address_strings(self): def get_host_ip_address_strings(self):

View file

@ -2,6 +2,7 @@ import ctypes
from .api import library from .api import library
def find_lib(lib): def find_lib(lib):
r""" r"""
Find the DLL for a given library. Find the DLL for a given library.

View file

@ -3,6 +3,7 @@ from ctypes import WinError
from .api import memory from .api import memory
class LockedMemory(object): class LockedMemory(object):
def __init__(self, handle): def __init__(self, handle):
self.handle = handle self.handle = handle

View file

@ -5,6 +5,7 @@ import six
from .error import handle_nonzero_success from .error import handle_nonzero_success
from .api import memory from .api import memory
class MemoryMap(object): class MemoryMap(object):
""" """
A memory map object which can have security attributes overridden. A memory map object which can have security attributes overridden.

View file

@ -10,13 +10,15 @@ import itertools
import six import six
class CookieMonster(object): class CookieMonster(object):
"Read cookies out of a user's IE cookies file" "Read cookies out of a user's IE cookies file"
@property @property
def cookie_dir(self): def cookie_dir(self):
import _winreg as winreg import _winreg as winreg
key = winreg.OpenKeyEx(winreg.HKEY_CURRENT_USER, 'Software' key = winreg.OpenKeyEx(
winreg.HKEY_CURRENT_USER, 'Software'
'\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders') '\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders')
cookie_dir, type = winreg.QueryValueEx(key, 'Cookies') cookie_dir, type = winreg.QueryValueEx(key, 'Cookies')
return cookie_dir return cookie_dir
@ -24,10 +26,12 @@ class CookieMonster(object):
def entries(self, filename): def entries(self, filename):
with open(os.path.join(self.cookie_dir, filename)) as cookie_file: with open(os.path.join(self.cookie_dir, filename)) as cookie_file:
while True: while True:
entry = itertools.takewhile(self.is_not_cookie_delimiter, entry = itertools.takewhile(
self.is_not_cookie_delimiter,
cookie_file) cookie_file)
entry = list(map(six.text_type.rstrip, entry)) entry = list(map(six.text_type.rstrip, entry))
if not entry: break if not entry:
break
cookie = self.make_cookie(*entry) cookie = self.make_cookie(*entry)
yield cookie yield cookie
@ -36,7 +40,8 @@ class CookieMonster(object):
return s != '*\n' return s != '*\n'
@staticmethod @staticmethod
def make_cookie(key, value, domain, flags, ExpireLow, ExpireHigh, def make_cookie(
key, value, domain, flags, ExpireLow, ExpireHigh,
CreateLow, CreateHigh): CreateLow, CreateHigh):
expires = (int(ExpireHigh) << 32) | int(ExpireLow) expires = (int(ExpireHigh) << 32) | int(ExpireLow)
created = (int(CreateHigh) << 32) | int(CreateLow) created = (int(CreateHigh) << 32) | int(CreateLow)

View file

@ -7,7 +7,9 @@ __all__ = ('AddConnection')
from jaraco.windows.error import WindowsError from jaraco.windows.error import WindowsError
from .api import net from .api import net
def AddConnection(remote_name, type=net.RESOURCETYPE_ANY, local_name=None,
def AddConnection(
remote_name, type=net.RESOURCETYPE_ANY, local_name=None,
provider_name=None, user=None, password=None, flags=0): provider_name=None, user=None, password=None, flags=0):
resource = net.NETRESOURCE( resource = net.NETRESOURCE(
type=type, type=type,

View file

@ -4,7 +4,6 @@ from __future__ import print_function
import itertools import itertools
import contextlib import contextlib
import ctypes
from more_itertools.recipes import consume, unique_justseen from more_itertools.recipes import consume, unique_justseen
try: try:
@ -15,11 +14,13 @@ except ImportError:
from jaraco.windows.error import handle_nonzero_success from jaraco.windows.error import handle_nonzero_success
from .api import power from .api import power
def GetSystemPowerStatus(): def GetSystemPowerStatus():
stat = power.SYSTEM_POWER_STATUS() stat = power.SYSTEM_POWER_STATUS()
handle_nonzero_success(GetSystemPowerStatus(stat)) handle_nonzero_success(GetSystemPowerStatus(stat))
return stat return stat
def _init_power_watcher(): def _init_power_watcher():
global power_watcher global power_watcher
if 'power_watcher' not in globals(): if 'power_watcher' not in globals():
@ -27,18 +28,22 @@ def _init_power_watcher():
query = 'SELECT * from Win32_PowerManagementEvent' query = 'SELECT * from Win32_PowerManagementEvent'
power_watcher = wmi.ExecNotificationQuery(query) power_watcher = wmi.ExecNotificationQuery(query)
def get_power_management_events(): def get_power_management_events():
_init_power_watcher() _init_power_watcher()
while True: while True:
yield power_watcher.NextEvent() yield power_watcher.NextEvent()
def wait_for_power_status_change(): def wait_for_power_status_change():
EVT_POWER_STATUS_CHANGE = 10 EVT_POWER_STATUS_CHANGE = 10
def not_power_status_change(evt): def not_power_status_change(evt):
return evt.EventType != EVT_POWER_STATUS_CHANGE return evt.EventType != EVT_POWER_STATUS_CHANGE
events = get_power_management_events() events = get_power_management_events()
consume(itertools.takewhile(not_power_status_change, events)) consume(itertools.takewhile(not_power_status_change, events))
def get_unique_power_states(): def get_unique_power_states():
""" """
Just like get_power_states, but ensures values are returned only Just like get_power_states, but ensures values are returned only
@ -46,6 +51,7 @@ def get_unique_power_states():
""" """
return unique_justseen(get_power_states()) return unique_justseen(get_power_states())
def get_power_states(): def get_power_states():
""" """
Continuously return the power state of the system when it changes. Continuously return the power state of the system when it changes.
@ -57,6 +63,7 @@ def get_power_states():
yield state.ac_line_status_string yield state.ac_line_status_string
wait_for_power_status_change() wait_for_power_status_change()
@contextlib.contextmanager @contextlib.contextmanager
def no_sleep(): def no_sleep():
""" """

View file

@ -7,24 +7,31 @@ from .api import security
from .api import privilege from .api import privilege
from .api import process from .api import process
def get_process_token(): def get_process_token():
""" """
Get the current process token Get the current process token
""" """
token = wintypes.HANDLE() token = wintypes.HANDLE()
res = process.OpenProcessToken(process.GetCurrentProcess(), process.TOKEN_ALL_ACCESS, token) res = process.OpenProcessToken(
if not res > 0: raise RuntimeError("Couldn't get process token") process.GetCurrentProcess(), process.TOKEN_ALL_ACCESS, token)
if not res > 0:
raise RuntimeError("Couldn't get process token")
return token return token
def get_symlink_luid(): def get_symlink_luid():
""" """
Get the LUID for the SeCreateSymbolicLinkPrivilege Get the LUID for the SeCreateSymbolicLinkPrivilege
""" """
symlink_luid = privilege.LUID() symlink_luid = privilege.LUID()
res = privilege.LookupPrivilegeValue(None, "SeCreateSymbolicLinkPrivilege", symlink_luid) res = privilege.LookupPrivilegeValue(
if not res > 0: raise RuntimeError("Couldn't lookup privilege value") None, "SeCreateSymbolicLinkPrivilege", symlink_luid)
if not res > 0:
raise RuntimeError("Couldn't lookup privilege value")
return symlink_luid return symlink_luid
def get_privilege_information(): def get_privilege_information():
""" """
Get all privileges associated with the current process. Get all privileges associated with the current process.
@ -51,9 +58,11 @@ def get_privilege_information():
res = privilege.GetTokenInformation(*params) res = privilege.GetTokenInformation(*params)
assert res > 0, "Error in second GetTokenInformation (%d)" % res assert res > 0, "Error in second GetTokenInformation (%d)" % res
privileges = ctypes.cast(buffer, ctypes.POINTER(privilege.TOKEN_PRIVILEGES)).contents privileges = ctypes.cast(
buffer, ctypes.POINTER(privilege.TOKEN_PRIVILEGES)).contents
return privileges return privileges
def report_privilege_information(): def report_privilege_information():
""" """
Report all privilege information assigned to the current process. Report all privilege information assigned to the current process.
@ -62,6 +71,7 @@ def report_privilege_information():
print("found {0} privileges".format(privileges.count)) print("found {0} privileges".format(privileges.count))
tuple(map(print, privileges)) tuple(map(print, privileges))
def enable_symlink_privilege(): def enable_symlink_privilege():
""" """
Try to assign the symlink privilege to the current process token. Try to assign the symlink privilege to the current process token.
@ -84,9 +94,11 @@ def enable_symlink_privilege():
ERROR_NOT_ALL_ASSIGNED = 1300 ERROR_NOT_ALL_ASSIGNED = 1300
return ctypes.windll.kernel32.GetLastError() != ERROR_NOT_ALL_ASSIGNED return ctypes.windll.kernel32.GetLastError() != ERROR_NOT_ALL_ASSIGNED
class PolicyHandle(wintypes.HANDLE): class PolicyHandle(wintypes.HANDLE):
pass pass
class LSA_UNICODE_STRING(ctypes.Structure): class LSA_UNICODE_STRING(ctypes.Structure):
_fields_ = [ _fields_ = [
('length', ctypes.c_ushort), ('length', ctypes.c_ushort),
@ -94,15 +106,20 @@ class LSA_UNICODE_STRING(ctypes.Structure):
('buffer', ctypes.wintypes.LPWSTR), ('buffer', ctypes.wintypes.LPWSTR),
] ]
def OpenPolicy(system_name, object_attributes, access_mask): def OpenPolicy(system_name, object_attributes, access_mask):
policy = PolicyHandle() policy = PolicyHandle()
raise NotImplementedError("Need to construct structures for parameters " raise NotImplementedError(
"(see http://msdn.microsoft.com/en-us/library/windows/desktop/aa378299%28v=vs.85%29.aspx)") "Need to construct structures for parameters "
res = ctypes.windll.advapi32.LsaOpenPolicy(system_name, object_attributes, "(see http://msdn.microsoft.com/en-us/library/windows"
"/desktop/aa378299%28v=vs.85%29.aspx)")
res = ctypes.windll.advapi32.LsaOpenPolicy(
system_name, object_attributes,
access_mask, ctypes.byref(policy)) access_mask, ctypes.byref(policy))
assert res == 0, "Error status {res}".format(**vars()) assert res == 0, "Error status {res}".format(**vars())
return policy return policy
def grant_symlink_privilege(who, machine=''): def grant_symlink_privilege(who, machine=''):
""" """
Grant the 'create symlink' privilege to who. Grant the 'create symlink' privilege to who.
@ -113,10 +130,13 @@ def grant_symlink_privilege(who, machine=''):
policy = OpenPolicy(machine, flags) policy = OpenPolicy(machine, flags)
return policy return policy
def main(): def main():
assigned = enable_symlink_privilege() assigned = enable_symlink_privilege()
msg = ['failure', 'success'][assigned] msg = ['failure', 'success'][assigned]
print("Symlink privilege assignment completed with {0}".format(msg)) print("Symlink privilege assignment completed with {0}".format(msg))
if __name__ == '__main__': main()
if __name__ == '__main__':
main()

View file

@ -3,6 +3,7 @@ from itertools import count
import six import six
winreg = six.moves.winreg winreg = six.moves.winreg
def key_values(key): def key_values(key):
for index in count(): for index in count():
try: try:
@ -10,6 +11,7 @@ def key_values(key):
except WindowsError: except WindowsError:
break break
def key_subkeys(key): def key_subkeys(key):
for index in count(): for index in count():
try: try:

View file

@ -5,7 +5,9 @@ import ctypes.wintypes
from .error import handle_nonzero_success from .error import handle_nonzero_success
from .api import filesystem from .api import filesystem
def DeviceIoControl(device, io_control_code, in_buffer, out_buffer, overlapped=None):
def DeviceIoControl(
device, io_control_code, in_buffer, out_buffer, overlapped=None):
if overlapped is not None: if overlapped is not None:
raise NotImplementedError("overlapped handles not yet supported") raise NotImplementedError("overlapped handles not yet supported")

View file

@ -3,20 +3,24 @@ import ctypes.wintypes
from jaraco.windows.error import handle_nonzero_success from jaraco.windows.error import handle_nonzero_success
from .api import security from .api import security
def GetTokenInformation(token, information_class): def GetTokenInformation(token, information_class):
""" """
Given a token, get the token information for it. Given a token, get the token information for it.
""" """
data_size = ctypes.wintypes.DWORD() data_size = ctypes.wintypes.DWORD()
ctypes.windll.advapi32.GetTokenInformation(token, information_class.num, ctypes.windll.advapi32.GetTokenInformation(
token, information_class.num,
0, 0, ctypes.byref(data_size)) 0, 0, ctypes.byref(data_size))
data = ctypes.create_string_buffer(data_size.value) data = ctypes.create_string_buffer(data_size.value)
handle_nonzero_success(ctypes.windll.advapi32.GetTokenInformation(token, handle_nonzero_success(ctypes.windll.advapi32.GetTokenInformation(
token,
information_class.num, information_class.num,
ctypes.byref(data), ctypes.sizeof(data), ctypes.byref(data), ctypes.sizeof(data),
ctypes.byref(data_size))) ctypes.byref(data_size)))
return ctypes.cast(data, ctypes.POINTER(security.TOKEN_USER)).contents return ctypes.cast(data, ctypes.POINTER(security.TOKEN_USER)).contents
def OpenProcessToken(proc_handle, access): def OpenProcessToken(proc_handle, access):
result = ctypes.wintypes.HANDLE() result = ctypes.wintypes.HANDLE()
proc_handle = ctypes.wintypes.HANDLE(proc_handle) proc_handle = ctypes.wintypes.HANDLE(proc_handle)
@ -24,6 +28,7 @@ def OpenProcessToken(proc_handle, access):
proc_handle, access, ctypes.byref(result))) proc_handle, access, ctypes.byref(result)))
return result return result
def get_current_user(): def get_current_user():
""" """
Return a TOKEN_USER for the owner of this process. Return a TOKEN_USER for the owner of this process.
@ -34,6 +39,7 @@ def get_current_user():
) )
return GetTokenInformation(process, security.TOKEN_USER) return GetTokenInformation(process, security.TOKEN_USER)
def get_security_attributes_for_user(user=None): def get_security_attributes_for_user(user=None):
""" """
Return a SECURITY_ATTRIBUTES structure with the SID set to the Return a SECURITY_ATTRIBUTES structure with the SID set to the
@ -42,7 +48,8 @@ def get_security_attributes_for_user(user=None):
if user is None: if user is None:
user = get_current_user() user = get_current_user()
assert isinstance(user, security.TOKEN_USER), "user must be TOKEN_USER instance" assert isinstance(user, security.TOKEN_USER), (
"user must be TOKEN_USER instance")
SD = security.SECURITY_DESCRIPTOR() SD = security.SECURITY_DESCRIPTOR()
SA = security.SECURITY_ATTRIBUTES() SA = security.SECURITY_ATTRIBUTES()
@ -51,8 +58,10 @@ def get_security_attributes_for_user(user=None):
SA.descriptor = SD SA.descriptor = SD
SA.bInheritHandle = 1 SA.bInheritHandle = 1
ctypes.windll.advapi32.InitializeSecurityDescriptor(ctypes.byref(SD), ctypes.windll.advapi32.InitializeSecurityDescriptor(
ctypes.byref(SD),
security.SECURITY_DESCRIPTOR.REVISION) security.SECURITY_DESCRIPTOR.REVISION)
ctypes.windll.advapi32.SetSecurityDescriptorOwner(ctypes.byref(SD), ctypes.windll.advapi32.SetSecurityDescriptorOwner(
ctypes.byref(SD),
user.SID, 0) user.SID, 0)
return SA return SA

View file

@ -1,7 +1,8 @@
""" """
Windows Services support for controlling Windows Services. Windows Services support for controlling Windows Services.
Based on http://code.activestate.com/recipes/115875-controlling-windows-services/ Based on http://code.activestate.com
/recipes/115875-controlling-windows-services/
""" """
from __future__ import print_function from __future__ import print_function
@ -13,6 +14,7 @@ import win32api
import win32con import win32con
import win32service import win32service
class Service(object): class Service(object):
""" """
The Service Class is used for controlling Windows The Service Class is used for controlling Windows
@ -47,7 +49,8 @@ class Service(object):
pause: Pauses service (Only if service supports feature). pause: Pauses service (Only if service supports feature).
resume: Resumes service that has been paused. resume: Resumes service that has been paused.
status: Queries current status of service. status: Queries current status of service.
fetchstatus: Continually queries service until requested status(STARTING, RUNNING, fetchstatus: Continually queries service until requested
status(STARTING, RUNNING,
STOPPING & STOPPED) is met or timeout value(in seconds) reached. STOPPING & STOPPED) is met or timeout value(in seconds) reached.
Default timeout value is infinite. Default timeout value is infinite.
infotype: Queries service for process type. (Single, shared and/or infotype: Queries service for process type. (Single, shared and/or
@ -64,18 +67,21 @@ class Service(object):
def __init__(self, service, machinename=None, dbname=None): def __init__(self, service, machinename=None, dbname=None):
self.userv = service self.userv = service
self.scmhandle = win32service.OpenSCManager(machinename, dbname, win32service.SC_MANAGER_ALL_ACCESS) self.scmhandle = win32service.OpenSCManager(
machinename, dbname, win32service.SC_MANAGER_ALL_ACCESS)
self.sserv, self.lserv = self.getname() self.sserv, self.lserv = self.getname()
if (self.sserv or self.lserv) is None: if (self.sserv or self.lserv) is None:
sys.exit() sys.exit()
self.handle = win32service.OpenService(self.scmhandle, self.sserv, win32service.SERVICE_ALL_ACCESS) self.handle = win32service.OpenService(
self.scmhandle, self.sserv, win32service.SERVICE_ALL_ACCESS)
self.sccss = "SYSTEM\\CurrentControlSet\\Services\\" self.sccss = "SYSTEM\\CurrentControlSet\\Services\\"
def start(self): def start(self):
win32service.StartService(self.handle, None) win32service.StartService(self.handle, None)
def stop(self): def stop(self):
self.stat = win32service.ControlService(self.handle, win32service.SERVICE_CONTROL_STOP) self.stat = win32service.ControlService(
self.handle, win32service.SERVICE_CONTROL_STOP)
def restart(self): def restart(self):
self.stop() self.stop()
@ -83,10 +89,12 @@ class Service(object):
self.start() self.start()
def pause(self): def pause(self):
self.stat = win32service.ControlService(self.handle, win32service.SERVICE_CONTROL_PAUSE) self.stat = win32service.ControlService(
self.handle, win32service.SERVICE_CONTROL_PAUSE)
def resume(self): def resume(self):
self.stat = win32service.ControlService(self.handle, win32service.SERVICE_CONTROL_CONTINUE) self.stat = win32service.ControlService(
self.handle, win32service.SERVICE_CONTROL_CONTINUE)
def status(self, prn=0): def status(self, prn=0):
self.stat = win32service.QueryServiceStatus(self.handle) self.stat = win32service.QueryServiceStatus(self.handle)
@ -116,6 +124,7 @@ class Service(object):
if timeout is not None: if timeout is not None:
timeout = int(timeout) timeout = int(timeout)
timeout *= 2 timeout *= 2
def to(timeout): def to(timeout):
time.sleep(.5) time.sleep(.5)
if timeout is not None: if timeout is not None:

View file

@ -1,10 +1,12 @@
from .api import shell from .api import shell
def get_recycle_bin_confirm(): def get_recycle_bin_confirm():
settings = shell.SHELLSTATE() settings = shell.SHELLSTATE()
shell.SHGetSetSettings(settings, shell.SSF_NOCONFIRMRECYCLE, False) shell.SHGetSetSettings(settings, shell.SSF_NOCONFIRMRECYCLE, False)
return not settings.no_confirm_recycle return not settings.no_confirm_recycle
def set_recycle_bin_confirm(confirm=False): def set_recycle_bin_confirm(confirm=False):
settings = shell.SHELLSTATE() settings = shell.SHELLSTATE()
settings.no_confirm_recycle = not confirm settings.no_confirm_recycle = not confirm

View file

@ -14,6 +14,7 @@ from jaraco.windows.api import event as win32event
__author__ = 'Jason R. Coombs <jaraco@jaraco.com>' __author__ = 'Jason R. Coombs <jaraco@jaraco.com>'
class WaitableTimer: class WaitableTimer:
""" """
t = WaitableTimer() t = WaitableTimer()
@ -54,14 +55,14 @@ class WaitableTimer:
except Exception: except Exception:
pass pass
#we're done here, just quit
def _wait(self, seconds): def _wait(self, seconds):
milliseconds = int(seconds * 1000) milliseconds = int(seconds * 1000)
if milliseconds > 0: if milliseconds > 0:
res = win32event.WaitForSingleObject(self.stop_event, milliseconds) res = win32event.WaitForSingleObject(self.stop_event, milliseconds)
if res == win32event.WAIT_OBJECT_0: raise Exception if res == win32event.WAIT_OBJECT_0:
if res == win32event.WAIT_TIMEOUT: pass raise Exception
if res == win32event.WAIT_TIMEOUT:
pass
win32event.SetEvent(self.signal_event) win32event.SetEvent(self.signal_event)
@staticmethod @staticmethod

View file

@ -8,6 +8,7 @@ from ctypes.wintypes import WORD, WCHAR, BOOL, LONG
from jaraco.windows.util import Extended from jaraco.windows.util import Extended
from jaraco.collections import RangeMap from jaraco.collections import RangeMap
class AnyDict(object): class AnyDict(object):
"A dictionary that returns the same value regardless of key" "A dictionary that returns the same value regardless of key"
@ -17,6 +18,7 @@ class AnyDict(object):
def __getitem__(self, key): def __getitem__(self, key):
return self.value return self.value
class SYSTEMTIME(Extended, ctypes.Structure): class SYSTEMTIME(Extended, ctypes.Structure):
_fields_ = [ _fields_ = [
('year', WORD), ('year', WORD),
@ -29,6 +31,7 @@ class SYSTEMTIME(Extended, ctypes.Structure):
('millisecond', WORD), ('millisecond', WORD),
] ]
class REG_TZI_FORMAT(Extended, ctypes.Structure): class REG_TZI_FORMAT(Extended, ctypes.Structure):
_fields_ = [ _fields_ = [
('bias', LONG), ('bias', LONG),
@ -38,6 +41,7 @@ class REG_TZI_FORMAT(Extended, ctypes.Structure):
('daylight_start', SYSTEMTIME), ('daylight_start', SYSTEMTIME),
] ]
class TIME_ZONE_INFORMATION(Extended, ctypes.Structure): class TIME_ZONE_INFORMATION(Extended, ctypes.Structure):
_fields_ = [ _fields_ = [
('bias', LONG), ('bias', LONG),
@ -49,6 +53,7 @@ class TIME_ZONE_INFORMATION(Extended, ctypes.Structure):
('daylight_bias', LONG), ('daylight_bias', LONG),
] ]
class DYNAMIC_TIME_ZONE_INFORMATION(TIME_ZONE_INFORMATION): class DYNAMIC_TIME_ZONE_INFORMATION(TIME_ZONE_INFORMATION):
""" """
Because the structure of the DYNAMIC_TIME_ZONE_INFORMATION extends Because the structure of the DYNAMIC_TIME_ZONE_INFORMATION extends
@ -89,6 +94,7 @@ class DYNAMIC_TIME_ZONE_INFORMATION(TIME_ZONE_INFORMATION):
kwargs[field_name] = arg kwargs[field_name] = arg
super(DYNAMIC_TIME_ZONE_INFORMATION, self).__init__(*self_args, **kwargs) super(DYNAMIC_TIME_ZONE_INFORMATION, self).__init__(*self_args, **kwargs)
class Info(DYNAMIC_TIME_ZONE_INFORMATION): class Info(DYNAMIC_TIME_ZONE_INFORMATION):
""" """
A time zone definition class based on the win32 A time zone definition class based on the win32
@ -126,7 +132,7 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION):
def __init_from_bytes(self, bytes, **kwargs): def __init_from_bytes(self, bytes, **kwargs):
reg_tzi = REG_TZI_FORMAT() reg_tzi = REG_TZI_FORMAT()
# todo: use buffer API in Python 3 # todo: use buffer API in Python 3
buffer = buffer(bytes) buffer = memoryview(bytes)
ctypes.memmove(ctypes.addressof(reg_tzi), buffer, len(buffer)) ctypes.memmove(ctypes.addressof(reg_tzi), buffer, len(buffer))
self.__init_from_reg_tzi(self, reg_tzi, **kwargs) self.__init_from_reg_tzi(self, reg_tzi, **kwargs)
@ -151,7 +157,9 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION):
def __getattribute__(self, attr): def __getattribute__(self, attr):
value = super(Info, self).__getattribute__(attr) value = super(Info, self).__getattribute__(attr)
make_minute_timedelta = lambda m: datetime.timedelta(minutes = m)
def make_minute_timedelta(m):
datetime.timedelta(minutes=m)
if 'bias' in attr: if 'bias' in attr:
value = make_minute_timedelta(value) value = make_minute_timedelta(value)
return value return value
@ -205,10 +213,12 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION):
def _locate_day(year, cutoff): def _locate_day(year, cutoff):
""" """
Takes a SYSTEMTIME object, such as retrieved from a TIME_ZONE_INFORMATION Takes a SYSTEMTIME object, such as retrieved from a TIME_ZONE_INFORMATION
structure or call to GetTimeZoneInformation and interprets it based on the given structure or call to GetTimeZoneInformation and interprets
it based on the given
year to identify the actual day. year to identify the actual day.
This method is necessary because the SYSTEMTIME structure refers to a day by its This method is necessary because the SYSTEMTIME structure
refers to a day by its
day of the week and week of the month (e.g. 4th saturday in March). day of the week and week of the month (e.g. 4th saturday in March).
>>> SATURDAY = 6 >>> SATURDAY = 6
@ -227,9 +237,11 @@ class Info(DYNAMIC_TIME_ZONE_INFORMATION):
week_of_month = cutoff.day week_of_month = cutoff.day
# so the following is the first day of that week # so the following is the first day of that week
day = (week_of_month - 1) * 7 + 1 day = (week_of_month - 1) * 7 + 1
result = datetime.datetime(year, cutoff.month, day, result = datetime.datetime(
year, cutoff.month, day,
cutoff.hour, cutoff.minute, cutoff.second, cutoff.millisecond) cutoff.hour, cutoff.minute, cutoff.second, cutoff.millisecond)
# now the result is the correct week, but not necessarily the correct day of the week # now the result is the correct week, but not necessarily
# the correct day of the week
days_to_go = (target_weekday - result.weekday()) % 7 days_to_go = (target_weekday - result.weekday()) % 7
result += datetime.timedelta(days_to_go) result += datetime.timedelta(days_to_go)
# if we selected a day in the month following the target month, # if we selected a day in the month following the target month,

Some files were not shown because too many files have changed in this diff Show more