Update vendored windows libs

This commit is contained in:
Labrys of Knossos 2022-11-28 05:59:32 -05:00
commit b1cefa94e5
226 changed files with 33472 additions and 11882 deletions

View file

@ -3,73 +3,66 @@ Routines for obtaining the class names
of an object and its parent classes.
"""
from __future__ import unicode_literals
from more_itertools import unique_everseen
def all_bases(c):
"""
return a tuple of all base classes the class c has as a parent.
>>> object in all_bases(list)
True
"""
return c.mro()[1:]
"""
return a tuple of all base classes the class c has as a parent.
>>> object in all_bases(list)
True
"""
return c.mro()[1:]
def all_classes(c):
"""
return a tuple of all classes to which c belongs
>>> list in all_classes(list)
True
"""
return c.mro()
"""
return a tuple of all classes to which c belongs
>>> list in all_classes(list)
True
"""
return c.mro()
# borrowed from
# http://code.activestate.com/recipes/576949-find-all-subclasses-of-a-given-class/
def iter_subclasses(cls, _seen=None):
"""
Generator over all subclasses of a given class, in depth-first order.
def iter_subclasses(cls):
"""
Generator over all subclasses of a given class, in depth-first order.
>>> bool in list(iter_subclasses(int))
True
>>> class A(object): pass
>>> class B(A): pass
>>> class C(A): pass
>>> class D(B,C): pass
>>> class E(D): pass
>>>
>>> for cls in iter_subclasses(A):
... print(cls.__name__)
B
D
E
C
>>> # get ALL (new-style) classes currently defined
>>> res = [cls.__name__ for cls in iter_subclasses(object)]
>>> 'type' in res
True
>>> 'tuple' in res
True
>>> len(res) > 100
True
"""
>>> bool in list(iter_subclasses(int))
True
>>> class A(object): pass
>>> class B(A): pass
>>> class C(A): pass
>>> class D(B,C): pass
>>> class E(D): pass
>>>
>>> for cls in iter_subclasses(A):
... print(cls.__name__)
B
D
E
C
>>> # get ALL classes currently defined
>>> res = [cls.__name__ for cls in iter_subclasses(object)]
>>> 'type' in res
True
>>> 'tuple' in res
True
>>> len(res) > 100
True
"""
return unique_everseen(_iter_all_subclasses(cls))
if not isinstance(cls, type):
raise TypeError(
'iter_subclasses must be called with '
'new-style classes, not %.100r' % cls
)
if _seen is None:
_seen = set()
try:
subs = cls.__subclasses__()
except TypeError: # fails only when cls is type
subs = cls.__subclasses__(cls)
for sub in subs:
if sub in _seen:
continue
_seen.add(sub)
yield sub
for sub in iter_subclasses(sub, _seen):
yield sub
def _iter_all_subclasses(cls):
try:
subs = cls.__subclasses__()
except TypeError: # fails only when cls is type
subs = cls.__subclasses__(cls)
for sub in subs:
yield sub
yield from iter_subclasses(sub)

View file

@ -4,38 +4,63 @@ meta.py
Some useful metaclasses.
"""
from __future__ import unicode_literals
class LeafClassesMeta(type):
"""
A metaclass for classes that keeps track of all of them that
aren't base classes.
"""
"""
A metaclass for classes that keeps track of all of them that
aren't base classes.
_leaf_classes = set()
>>> Parent = LeafClassesMeta('MyParentClass', (), {})
>>> Parent in Parent._leaf_classes
True
>>> Child = LeafClassesMeta('MyChildClass', (Parent,), {})
>>> Child in Parent._leaf_classes
True
>>> Parent in Parent._leaf_classes
False
def __init__(cls, name, bases, attrs):
if not hasattr(cls, '_leaf_classes'):
cls._leaf_classes = set()
leaf_classes = getattr(cls, '_leaf_classes')
leaf_classes.add(cls)
# remove any base classes
leaf_classes -= set(bases)
>>> Other = LeafClassesMeta('OtherClass', (), {})
>>> Parent in Other._leaf_classes
False
>>> len(Other._leaf_classes)
1
"""
def __init__(cls, name, bases, attrs):
if not hasattr(cls, '_leaf_classes'):
cls._leaf_classes = set()
leaf_classes = getattr(cls, '_leaf_classes')
leaf_classes.add(cls)
# remove any base classes
leaf_classes -= set(bases)
class TagRegistered(type):
"""
As classes of this metaclass are created, they keep a registry in the
base class of all classes by a class attribute, indicated by attr_name.
"""
attr_name = 'tag'
"""
As classes of this metaclass are created, they keep a registry in the
base class of all classes by a class attribute, indicated by attr_name.
def __init__(cls, name, bases, namespace):
super(TagRegistered, cls).__init__(name, bases, namespace)
if not hasattr(cls, '_registry'):
cls._registry = {}
meta = cls.__class__
attr = getattr(cls, meta.attr_name, None)
if attr:
cls._registry[attr] = cls
>>> FooObject = TagRegistered('FooObject', (), dict(tag='foo'))
>>> FooObject._registry['foo'] is FooObject
True
>>> BarObject = TagRegistered('Barobject', (FooObject,), dict(tag='bar'))
>>> FooObject._registry is BarObject._registry
True
>>> len(FooObject._registry)
2
'...' below should be 'jaraco.classes' but for pytest-dev/pytest#3396
>>> FooObject._registry['bar']
<class '....meta.Barobject'>
"""
attr_name = 'tag'
def __init__(cls, name, bases, namespace):
super(TagRegistered, cls).__init__(name, bases, namespace)
if not hasattr(cls, '_registry'):
cls._registry = {}
meta = cls.__class__
attr = getattr(cls, meta.attr_name, None)
if attr:
cls._registry[attr] = cls

View file

@ -1,67 +1,170 @@
from __future__ import unicode_literals
import six
__metaclass__ = type
class NonDataProperty:
"""Much like the property builtin, but only implements __get__,
making it a non-data property, and can be subsequently reset.
"""Much like the property builtin, but only implements __get__,
making it a non-data property, and can be subsequently reset.
See http://users.rcn.com/python/download/Descriptor.htm for more
information.
See http://users.rcn.com/python/download/Descriptor.htm for more
information.
>>> class X(object):
... @NonDataProperty
... def foo(self):
... return 3
>>> x = X()
>>> x.foo
3
>>> x.foo = 4
>>> x.foo
4
"""
>>> class X(object):
... @NonDataProperty
... def foo(self):
... return 3
>>> x = X()
>>> x.foo
3
>>> x.foo = 4
>>> x.foo
4
def __init__(self, fget):
assert fget is not None, "fget cannot be none"
assert six.callable(fget), "fget must be callable"
self.fget = fget
'...' below should be 'jaraco.classes' but for pytest-dev/pytest#3396
>>> X.foo
<....properties.NonDataProperty object at ...>
"""
def __get__(self, obj, objtype=None):
if obj is None:
return self
return self.fget(obj)
def __init__(self, fget):
assert fget is not None, "fget cannot be none"
assert callable(fget), "fget must be callable"
self.fget = fget
def __get__(self, obj, objtype=None):
if obj is None:
return self
return self.fget(obj)
# from http://stackoverflow.com/a/5191224
class ClassPropertyDescriptor:
def __init__(self, fget, fset=None):
self.fget = fget
self.fset = fset
def __get__(self, obj, klass=None):
if klass is None:
klass = type(obj)
return self.fget.__get__(obj, klass)()
def __set__(self, obj, value):
if not self.fset:
raise AttributeError("can't set attribute")
type_ = type(obj)
return self.fset.__get__(obj, type_)(value)
def setter(self, func):
if not isinstance(func, (classmethod, staticmethod)):
func = classmethod(func)
self.fset = func
return self
class classproperty:
"""
Like @property but applies at the class level.
def classproperty(func):
if not isinstance(func, (classmethod, staticmethod)):
func = classmethod(func)
>>> class X(metaclass=classproperty.Meta):
... val = None
... @classproperty
... def foo(cls):
... return cls.val
... @foo.setter
... def foo(cls, val):
... cls.val = val
>>> X.foo
>>> X.foo = 3
>>> X.foo
3
>>> x = X()
>>> x.foo
3
>>> X.foo = 4
>>> x.foo
4
return ClassPropertyDescriptor(func)
Setting the property on an instance affects the class.
>>> x.foo = 5
>>> x.foo
5
>>> X.foo
5
>>> vars(x)
{}
>>> X().foo
5
Attempting to set an attribute where no setter was defined
results in an AttributeError:
>>> class GetOnly(metaclass=classproperty.Meta):
... @classproperty
... def foo(cls):
... return 'bar'
>>> GetOnly.foo = 3
Traceback (most recent call last):
...
AttributeError: can't set attribute
It is also possible to wrap a classmethod or staticmethod in
a classproperty.
>>> class Static(metaclass=classproperty.Meta):
... @classproperty
... @classmethod
... def foo(cls):
... return 'foo'
... @classproperty
... @staticmethod
... def bar():
... return 'bar'
>>> Static.foo
'foo'
>>> Static.bar
'bar'
*Legacy*
For compatibility, if the metaclass isn't specified, the
legacy behavior will be invoked.
>>> class X:
... val = None
... @classproperty
... def foo(cls):
... return cls.val
... @foo.setter
... def foo(cls, val):
... cls.val = val
>>> X.foo
>>> X.foo = 3
>>> X.foo
3
>>> x = X()
>>> x.foo
3
>>> X.foo = 4
>>> x.foo
4
Note, because the metaclass was not specified, setting
a value on an instance does not have the intended effect.
>>> x.foo = 5
>>> x.foo
5
>>> X.foo # should be 5
4
>>> vars(x) # should be empty
{'foo': 5}
>>> X().foo # should be 5
4
"""
class Meta(type):
def __setattr__(self, key, value):
obj = self.__dict__.get(key, None)
if type(obj) is classproperty:
return obj.__set__(self, value)
return super().__setattr__(key, value)
def __init__(self, fget, fset=None):
self.fget = self._ensure_method(fget)
self.fset = fset
fset and self.setter(fset)
def __get__(self, instance, owner=None):
return self.fget.__get__(None, owner)()
def __set__(self, owner, value):
if not self.fset:
raise AttributeError("can't set attribute")
if type(owner) is not classproperty.Meta:
owner = type(owner)
return self.fset.__get__(None, owner)(value)
def setter(self, fset):
self.fset = self._ensure_method(fset)
return self
@classmethod
def _ensure_method(cls, fn):
"""
Ensure fn is a classmethod or staticmethod.
"""
needs_method = not isinstance(fn, (classmethod, staticmethod))
return classmethod(fn) if needs_method else fn

File diff suppressed because it is too large Load diff

253
libs/win/jaraco/context.py Normal file
View file

@ -0,0 +1,253 @@
import os
import subprocess
import contextlib
import functools
import tempfile
import shutil
import operator
@contextlib.contextmanager
def pushd(dir):
orig = os.getcwd()
os.chdir(dir)
try:
yield dir
finally:
os.chdir(orig)
@contextlib.contextmanager
def tarball_context(url, target_dir=None, runner=None, pushd=pushd):
"""
Get a tarball, extract it, change to that directory, yield, then
clean up.
`runner` is the function to invoke commands.
`pushd` is a context manager for changing the directory.
"""
if target_dir is None:
target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '')
if runner is None:
runner = functools.partial(subprocess.check_call, shell=True)
# In the tar command, use --strip-components=1 to strip the first path and
# then
# use -C to cause the files to be extracted to {target_dir}. This ensures
# that we always know where the files were extracted.
runner('mkdir {target_dir}'.format(**vars()))
try:
getter = 'wget {url} -O -'
extract = 'tar x{compression} --strip-components=1 -C {target_dir}'
cmd = ' | '.join((getter, extract))
runner(cmd.format(compression=infer_compression(url), **vars()))
with pushd(target_dir):
yield target_dir
finally:
runner('rm -Rf {target_dir}'.format(**vars()))
def infer_compression(url):
"""
Given a URL or filename, infer the compression code for tar.
"""
# cheat and just assume it's the last two characters
compression_indicator = url[-2:]
mapping = dict(gz='z', bz='j', xz='J')
# Assume 'z' (gzip) if no match
return mapping.get(compression_indicator, 'z')
@contextlib.contextmanager
def temp_dir(remover=shutil.rmtree):
"""
Create a temporary directory context. Pass a custom remover
to override the removal behavior.
"""
temp_dir = tempfile.mkdtemp()
try:
yield temp_dir
finally:
remover(temp_dir)
@contextlib.contextmanager
def repo_context(url, branch=None, quiet=True, dest_ctx=temp_dir):
"""
Check out the repo indicated by url.
If dest_ctx is supplied, it should be a context manager
to yield the target directory for the check out.
"""
exe = 'git' if 'git' in url else 'hg'
with dest_ctx() as repo_dir:
cmd = [exe, 'clone', url, repo_dir]
if branch:
cmd.extend(['--branch', branch])
devnull = open(os.path.devnull, 'w')
stdout = devnull if quiet else None
subprocess.check_call(cmd, stdout=stdout)
yield repo_dir
@contextlib.contextmanager
def null():
yield
class ExceptionTrap:
"""
A context manager that will catch certain exceptions and provide an
indication they occurred.
>>> with ExceptionTrap() as trap:
... raise Exception()
>>> bool(trap)
True
>>> with ExceptionTrap() as trap:
... pass
>>> bool(trap)
False
>>> with ExceptionTrap(ValueError) as trap:
... raise ValueError("1 + 1 is not 3")
>>> bool(trap)
True
>>> with ExceptionTrap(ValueError) as trap:
... raise Exception()
Traceback (most recent call last):
...
Exception
>>> bool(trap)
False
"""
exc_info = None, None, None
def __init__(self, exceptions=(Exception,)):
self.exceptions = exceptions
def __enter__(self):
return self
@property
def type(self):
return self.exc_info[0]
@property
def value(self):
return self.exc_info[1]
@property
def tb(self):
return self.exc_info[2]
def __exit__(self, *exc_info):
type = exc_info[0]
matches = type and issubclass(type, self.exceptions)
if matches:
self.exc_info = exc_info
return matches
def __bool__(self):
return bool(self.type)
def raises(self, func, *, _test=bool):
"""
Wrap func and replace the result with the truth
value of the trap (True if an exception occurred).
First, give the decorator an alias to support Python 3.8
Syntax.
>>> raises = ExceptionTrap(ValueError).raises
Now decorate a function that always fails.
>>> @raises
... def fail():
... raise ValueError('failed')
>>> fail()
True
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
with ExceptionTrap(self.exceptions) as trap:
func(*args, **kwargs)
return _test(trap)
return wrapper
def passes(self, func):
"""
Wrap func and replace the result with the truth
value of the trap (True if no exception).
First, give the decorator an alias to support Python 3.8
Syntax.
>>> passes = ExceptionTrap(ValueError).passes
Now decorate a function that always fails.
>>> @passes
... def fail():
... raise ValueError('failed')
>>> fail()
False
"""
return self.raises(func, _test=operator.not_)
class suppress(contextlib.suppress, contextlib.ContextDecorator):
"""
A version of contextlib.suppress with decorator support.
>>> @suppress(KeyError)
... def key_error():
... {}['']
>>> key_error()
"""
class on_interrupt(contextlib.ContextDecorator):
"""
Replace a KeyboardInterrupt with SystemExit(1)
>>> def do_interrupt():
... raise KeyboardInterrupt()
>>> on_interrupt('error')(do_interrupt)()
Traceback (most recent call last):
...
SystemExit: 1
>>> on_interrupt('error', code=255)(do_interrupt)()
Traceback (most recent call last):
...
SystemExit: 255
>>> on_interrupt('suppress')(do_interrupt)()
>>> with __import__('pytest').raises(KeyboardInterrupt):
... on_interrupt('ignore')(do_interrupt)()
"""
def __init__(
self,
action='error',
# py3.7 compat
# /,
code=1,
):
self.action = action
self.code = code
def __enter__(self):
return self
def __exit__(self, exctype, excinst, exctb):
if exctype is not KeyboardInterrupt or self.action == 'ignore':
return
elif self.action == 'error':
raise SystemExit(self.code) from excinst
return self.action == 'suppress'

View file

@ -1,459 +1,525 @@
from __future__ import (
absolute_import, unicode_literals, print_function, division,
)
import functools
import time
import warnings
import inspect
import collections
from itertools import count
import types
import itertools
__metaclass__ = type
import more_itertools
from typing import Callable, TypeVar
try:
from functools import lru_cache
except ImportError:
try:
from backports.functools_lru_cache import lru_cache
except ImportError:
try:
from functools32 import lru_cache
except ImportError:
warnings.warn("No lru_cache available")
import more_itertools.recipes
CallableT = TypeVar("CallableT", bound=Callable[..., object])
def compose(*funcs):
"""
Compose any number of unary functions into a single unary function.
"""
Compose any number of unary functions into a single unary function.
>>> import textwrap
>>> from six import text_type
>>> stripped = text_type.strip(textwrap.dedent(compose.__doc__))
>>> compose(text_type.strip, textwrap.dedent)(compose.__doc__) == stripped
True
>>> import textwrap
>>> expected = str.strip(textwrap.dedent(compose.__doc__))
>>> strip_and_dedent = compose(str.strip, textwrap.dedent)
>>> strip_and_dedent(compose.__doc__) == expected
True
Compose also allows the innermost function to take arbitrary arguments.
Compose also allows the innermost function to take arbitrary arguments.
>>> round_three = lambda x: round(x, ndigits=3)
>>> f = compose(round_three, int.__truediv__)
>>> [f(3*x, x+1) for x in range(1,10)]
[1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
"""
>>> round_three = lambda x: round(x, ndigits=3)
>>> f = compose(round_three, int.__truediv__)
>>> [f(3*x, x+1) for x in range(1,10)]
[1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7]
"""
def compose_two(f1, f2):
return lambda *args, **kwargs: f1(f2(*args, **kwargs))
return functools.reduce(compose_two, funcs)
def compose_two(f1, f2):
return lambda *args, **kwargs: f1(f2(*args, **kwargs))
return functools.reduce(compose_two, funcs)
def method_caller(method_name, *args, **kwargs):
"""
Return a function that will call a named method on the
target object with optional positional and keyword
arguments.
"""
Return a function that will call a named method on the
target object with optional positional and keyword
arguments.
>>> lower = method_caller('lower')
>>> lower('MyString')
'mystring'
"""
def call_method(target):
func = getattr(target, method_name)
return func(*args, **kwargs)
return call_method
>>> lower = method_caller('lower')
>>> lower('MyString')
'mystring'
"""
def call_method(target):
func = getattr(target, method_name)
return func(*args, **kwargs)
return call_method
def once(func):
"""
Decorate func so it's only ever called the first time.
"""
Decorate func so it's only ever called the first time.
This decorator can ensure that an expensive or non-idempotent function
will not be expensive on subsequent calls and is idempotent.
This decorator can ensure that an expensive or non-idempotent function
will not be expensive on subsequent calls and is idempotent.
>>> add_three = once(lambda a: a+3)
>>> add_three(3)
6
>>> add_three(9)
6
>>> add_three('12')
6
>>> add_three = once(lambda a: a+3)
>>> add_three(3)
6
>>> add_three(9)
6
>>> add_three('12')
6
To reset the stored value, simply clear the property ``saved_result``.
To reset the stored value, simply clear the property ``saved_result``.
>>> del add_three.saved_result
>>> add_three(9)
12
>>> add_three(8)
12
>>> del add_three.saved_result
>>> add_three(9)
12
>>> add_three(8)
12
Or invoke 'reset()' on it.
Or invoke 'reset()' on it.
>>> add_three.reset()
>>> add_three(-3)
0
>>> add_three(0)
0
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not hasattr(wrapper, 'saved_result'):
wrapper.saved_result = func(*args, **kwargs)
return wrapper.saved_result
wrapper.reset = lambda: vars(wrapper).__delitem__('saved_result')
return wrapper
>>> add_three.reset()
>>> add_three(-3)
0
>>> add_three(0)
0
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not hasattr(wrapper, 'saved_result'):
wrapper.saved_result = func(*args, **kwargs)
return wrapper.saved_result
wrapper.reset = lambda: vars(wrapper).__delitem__('saved_result')
return wrapper
def method_cache(method, cache_wrapper=None):
"""
Wrap lru_cache to support storing the cache data in the object instances.
def method_cache(
method: CallableT,
cache_wrapper: Callable[
[CallableT], CallableT
] = functools.lru_cache(), # type: ignore[assignment]
) -> CallableT:
"""
Wrap lru_cache to support storing the cache data in the object instances.
Abstracts the common paradigm where the method explicitly saves an
underscore-prefixed protected property on first call and returns that
subsequently.
Abstracts the common paradigm where the method explicitly saves an
underscore-prefixed protected property on first call and returns that
subsequently.
>>> class MyClass:
... calls = 0
...
... @method_cache
... def method(self, value):
... self.calls += 1
... return value
>>> class MyClass:
... calls = 0
...
... @method_cache
... def method(self, value):
... self.calls += 1
... return value
>>> a = MyClass()
>>> a.method(3)
3
>>> for x in range(75):
... res = a.method(x)
>>> a.calls
75
>>> a = MyClass()
>>> a.method(3)
3
>>> for x in range(75):
... res = a.method(x)
>>> a.calls
75
Note that the apparent behavior will be exactly like that of lru_cache
except that the cache is stored on each instance, so values in one
instance will not flush values from another, and when an instance is
deleted, so are the cached values for that instance.
Note that the apparent behavior will be exactly like that of lru_cache
except that the cache is stored on each instance, so values in one
instance will not flush values from another, and when an instance is
deleted, so are the cached values for that instance.
>>> b = MyClass()
>>> for x in range(35):
... res = b.method(x)
>>> b.calls
35
>>> a.method(0)
0
>>> a.calls
75
>>> b = MyClass()
>>> for x in range(35):
... res = b.method(x)
>>> b.calls
35
>>> a.method(0)
0
>>> a.calls
75
Note that if method had been decorated with ``functools.lru_cache()``,
a.calls would have been 76 (due to the cached value of 0 having been
flushed by the 'b' instance).
Note that if method had been decorated with ``functools.lru_cache()``,
a.calls would have been 76 (due to the cached value of 0 having been
flushed by the 'b' instance).
Clear the cache with ``.cache_clear()``
Clear the cache with ``.cache_clear()``
>>> a.method.cache_clear()
>>> a.method.cache_clear()
Another cache wrapper may be supplied:
Same for a method that hasn't yet been called.
>>> cache = lru_cache(maxsize=2)
>>> MyClass.method2 = method_cache(lambda self: 3, cache_wrapper=cache)
>>> a = MyClass()
>>> a.method2()
3
>>> c = MyClass()
>>> c.method.cache_clear()
Caution - do not subsequently wrap the method with another decorator, such
as ``@property``, which changes the semantics of the function.
Another cache wrapper may be supplied:
See also
http://code.activestate.com/recipes/577452-a-memoize-decorator-for-instance-methods/
for another implementation and additional justification.
"""
cache_wrapper = cache_wrapper or lru_cache()
>>> cache = functools.lru_cache(maxsize=2)
>>> MyClass.method2 = method_cache(lambda self: 3, cache_wrapper=cache)
>>> a = MyClass()
>>> a.method2()
3
def wrapper(self, *args, **kwargs):
# it's the first call, replace the method with a cached, bound method
bound_method = functools.partial(method, self)
cached_method = cache_wrapper(bound_method)
setattr(self, method.__name__, cached_method)
return cached_method(*args, **kwargs)
Caution - do not subsequently wrap the method with another decorator, such
as ``@property``, which changes the semantics of the function.
return _special_method_cache(method, cache_wrapper) or wrapper
See also
http://code.activestate.com/recipes/577452-a-memoize-decorator-for-instance-methods/
for another implementation and additional justification.
"""
def wrapper(self: object, *args: object, **kwargs: object) -> object:
# it's the first call, replace the method with a cached, bound method
bound_method: CallableT = types.MethodType( # type: ignore[assignment]
method, self
)
cached_method = cache_wrapper(bound_method)
setattr(self, method.__name__, cached_method)
return cached_method(*args, **kwargs)
# Support cache clear even before cache has been created.
wrapper.cache_clear = lambda: None # type: ignore[attr-defined]
return ( # type: ignore[return-value]
_special_method_cache(method, cache_wrapper) or wrapper
)
def _special_method_cache(method, cache_wrapper):
"""
Because Python treats special methods differently, it's not
possible to use instance attributes to implement the cached
methods.
"""
Because Python treats special methods differently, it's not
possible to use instance attributes to implement the cached
methods.
Instead, install the wrapper method under a different name
and return a simple proxy to that wrapper.
Instead, install the wrapper method under a different name
and return a simple proxy to that wrapper.
https://github.com/jaraco/jaraco.functools/issues/5
"""
name = method.__name__
special_names = '__getattr__', '__getitem__'
if name not in special_names:
return
https://github.com/jaraco/jaraco.functools/issues/5
"""
name = method.__name__
special_names = '__getattr__', '__getitem__'
if name not in special_names:
return
wrapper_name = '__cached' + name
wrapper_name = '__cached' + name
def proxy(self, *args, **kwargs):
if wrapper_name not in vars(self):
bound = functools.partial(method, self)
cache = cache_wrapper(bound)
setattr(self, wrapper_name, cache)
else:
cache = getattr(self, wrapper_name)
return cache(*args, **kwargs)
def proxy(self, *args, **kwargs):
if wrapper_name not in vars(self):
bound = types.MethodType(method, self)
cache = cache_wrapper(bound)
setattr(self, wrapper_name, cache)
else:
cache = getattr(self, wrapper_name)
return cache(*args, **kwargs)
return proxy
return proxy
def apply(transform):
"""
Decorate a function with a transform function that is
invoked on results returned from the decorated function.
"""
Decorate a function with a transform function that is
invoked on results returned from the decorated function.
>>> @apply(reversed)
... def get_numbers(start):
... return range(start, start+3)
>>> list(get_numbers(4))
[6, 5, 4]
"""
def wrap(func):
return compose(transform, func)
return wrap
>>> @apply(reversed)
... def get_numbers(start):
... "doc for get_numbers"
... return range(start, start+3)
>>> list(get_numbers(4))
[6, 5, 4]
>>> get_numbers.__doc__
'doc for get_numbers'
"""
def wrap(func):
return functools.wraps(func)(compose(transform, func))
return wrap
def result_invoke(action):
r"""
Decorate a function with an action function that is
invoked on the results returned from the decorated
function (for its side-effect), then return the original
result.
r"""
Decorate a function with an action function that is
invoked on the results returned from the decorated
function (for its side-effect), then return the original
result.
>>> @result_invoke(print)
... def add_two(a, b):
... return a + b
>>> x = add_two(2, 3)
5
"""
def wrap(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
action(result)
return result
return wrapper
return wrap
>>> @result_invoke(print)
... def add_two(a, b):
... return a + b
>>> x = add_two(2, 3)
5
>>> x
5
"""
def wrap(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
action(result)
return result
return wrapper
return wrap
def call_aside(f, *args, **kwargs):
"""
Call a function for its side effect after initialization.
"""
Call a function for its side effect after initialization.
>>> @call_aside
... def func(): print("called")
called
>>> func()
called
>>> @call_aside
... def func(): print("called")
called
>>> func()
called
Use functools.partial to pass parameters to the initial call
Use functools.partial to pass parameters to the initial call
>>> @functools.partial(call_aside, name='bingo')
... def func(name): print("called with", name)
called with bingo
"""
f(*args, **kwargs)
return f
>>> @functools.partial(call_aside, name='bingo')
... def func(name): print("called with", name)
called with bingo
"""
f(*args, **kwargs)
return f
class Throttler:
"""
Rate-limit a function (or other callable)
"""
def __init__(self, func, max_rate=float('Inf')):
if isinstance(func, Throttler):
func = func.func
self.func = func
self.max_rate = max_rate
self.reset()
"""
Rate-limit a function (or other callable)
"""
def reset(self):
self.last_called = 0
def __init__(self, func, max_rate=float('Inf')):
if isinstance(func, Throttler):
func = func.func
self.func = func
self.max_rate = max_rate
self.reset()
def __call__(self, *args, **kwargs):
self._wait()
return self.func(*args, **kwargs)
def reset(self):
self.last_called = 0
def _wait(self):
"ensure at least 1/max_rate seconds from last call"
elapsed = time.time() - self.last_called
must_wait = 1 / self.max_rate - elapsed
time.sleep(max(0, must_wait))
self.last_called = time.time()
def __call__(self, *args, **kwargs):
self._wait()
return self.func(*args, **kwargs)
def __get__(self, obj, type=None):
return first_invoke(self._wait, functools.partial(self.func, obj))
def _wait(self):
"ensure at least 1/max_rate seconds from last call"
elapsed = time.time() - self.last_called
must_wait = 1 / self.max_rate - elapsed
time.sleep(max(0, must_wait))
self.last_called = time.time()
def __get__(self, obj, type=None):
return first_invoke(self._wait, functools.partial(self.func, obj))
def first_invoke(func1, func2):
"""
Return a function that when invoked will invoke func1 without
any parameters (for its side-effect) and then invoke func2
with whatever parameters were passed, returning its result.
"""
def wrapper(*args, **kwargs):
func1()
return func2(*args, **kwargs)
return wrapper
"""
Return a function that when invoked will invoke func1 without
any parameters (for its side-effect) and then invoke func2
with whatever parameters were passed, returning its result.
"""
def wrapper(*args, **kwargs):
func1()
return func2(*args, **kwargs)
return wrapper
def retry_call(func, cleanup=lambda: None, retries=0, trap=()):
"""
Given a callable func, trap the indicated exceptions
for up to 'retries' times, invoking cleanup on the
exception. On the final attempt, allow any exceptions
to propagate.
"""
attempts = count() if retries == float('inf') else range(retries)
for attempt in attempts:
try:
return func()
except trap:
cleanup()
"""
Given a callable func, trap the indicated exceptions
for up to 'retries' times, invoking cleanup on the
exception. On the final attempt, allow any exceptions
to propagate.
"""
attempts = itertools.count() if retries == float('inf') else range(retries)
for attempt in attempts:
try:
return func()
except trap:
cleanup()
return func()
return func()
def retry(*r_args, **r_kwargs):
"""
Decorator wrapper for retry_call. Accepts arguments to retry_call
except func and then returns a decorator for the decorated function.
"""
Decorator wrapper for retry_call. Accepts arguments to retry_call
except func and then returns a decorator for the decorated function.
Ex:
Ex:
>>> @retry(retries=3)
... def my_func(a, b):
... "this is my funk"
... print(a, b)
>>> my_func.__doc__
'this is my funk'
"""
def decorate(func):
@functools.wraps(func)
def wrapper(*f_args, **f_kwargs):
bound = functools.partial(func, *f_args, **f_kwargs)
return retry_call(bound, *r_args, **r_kwargs)
return wrapper
return decorate
>>> @retry(retries=3)
... def my_func(a, b):
... "this is my funk"
... print(a, b)
>>> my_func.__doc__
'this is my funk'
"""
def decorate(func):
@functools.wraps(func)
def wrapper(*f_args, **f_kwargs):
bound = functools.partial(func, *f_args, **f_kwargs)
return retry_call(bound, *r_args, **r_kwargs)
return wrapper
return decorate
def print_yielded(func):
"""
Convert a generator into a function that prints all yielded elements
"""
Convert a generator into a function that prints all yielded elements
>>> @print_yielded
... def x():
... yield 3; yield None
>>> x()
3
None
"""
print_all = functools.partial(map, print)
print_results = compose(more_itertools.recipes.consume, print_all, func)
return functools.wraps(func)(print_results)
>>> @print_yielded
... def x():
... yield 3; yield None
>>> x()
3
None
"""
print_all = functools.partial(map, print)
print_results = compose(more_itertools.consume, print_all, func)
return functools.wraps(func)(print_results)
def pass_none(func):
"""
Wrap func so it's not called if its first param is None
"""
Wrap func so it's not called if its first param is None
>>> print_text = pass_none(print)
>>> print_text('text')
text
>>> print_text(None)
"""
@functools.wraps(func)
def wrapper(param, *args, **kwargs):
if param is not None:
return func(param, *args, **kwargs)
return wrapper
>>> print_text = pass_none(print)
>>> print_text('text')
text
>>> print_text(None)
"""
@functools.wraps(func)
def wrapper(param, *args, **kwargs):
if param is not None:
return func(param, *args, **kwargs)
return wrapper
def assign_params(func, namespace):
"""
Assign parameters from namespace where func solicits.
"""
Assign parameters from namespace where func solicits.
>>> def func(x, y=3):
... print(x, y)
>>> assigned = assign_params(func, dict(x=2, z=4))
>>> assigned()
2 3
>>> def func(x, y=3):
... print(x, y)
>>> assigned = assign_params(func, dict(x=2, z=4))
>>> assigned()
2 3
The usual errors are raised if a function doesn't receive
its required parameters:
The usual errors are raised if a function doesn't receive
its required parameters:
>>> assigned = assign_params(func, dict(y=3, z=4))
>>> assigned()
Traceback (most recent call last):
TypeError: func() ...argument...
"""
try:
sig = inspect.signature(func)
params = sig.parameters.keys()
except AttributeError:
spec = inspect.getargspec(func)
params = spec.args
call_ns = {
k: namespace[k]
for k in params
if k in namespace
}
return functools.partial(func, **call_ns)
>>> assigned = assign_params(func, dict(y=3, z=4))
>>> assigned()
Traceback (most recent call last):
TypeError: func() ...argument...
It even works on methods:
>>> class Handler:
... def meth(self, arg):
... print(arg)
>>> assign_params(Handler().meth, dict(arg='crystal', foo='clear'))()
crystal
"""
sig = inspect.signature(func)
params = sig.parameters.keys()
call_ns = {k: namespace[k] for k in params if k in namespace}
return functools.partial(func, **call_ns)
def save_method_args(method):
"""
Wrap a method such that when it is called, the args and kwargs are
saved on the method.
"""
Wrap a method such that when it is called, the args and kwargs are
saved on the method.
>>> class MyClass:
... @save_method_args
... def method(self, a, b):
... print(a, b)
>>> my_ob = MyClass()
>>> my_ob.method(1, 2)
1 2
>>> my_ob._saved_method.args
(1, 2)
>>> my_ob._saved_method.kwargs
{}
>>> my_ob.method(a=3, b='foo')
3 foo
>>> my_ob._saved_method.args
()
>>> my_ob._saved_method.kwargs == dict(a=3, b='foo')
True
>>> class MyClass:
... @save_method_args
... def method(self, a, b):
... print(a, b)
>>> my_ob = MyClass()
>>> my_ob.method(1, 2)
1 2
>>> my_ob._saved_method.args
(1, 2)
>>> my_ob._saved_method.kwargs
{}
>>> my_ob.method(a=3, b='foo')
3 foo
>>> my_ob._saved_method.args
()
>>> my_ob._saved_method.kwargs == dict(a=3, b='foo')
True
The arguments are stored on the instance, allowing for
different instance to save different args.
The arguments are stored on the instance, allowing for
different instance to save different args.
>>> your_ob = MyClass()
>>> your_ob.method({str('x'): 3}, b=[4])
{'x': 3} [4]
>>> your_ob._saved_method.args
({'x': 3},)
>>> my_ob._saved_method.args
()
"""
args_and_kwargs = collections.namedtuple('args_and_kwargs', 'args kwargs')
>>> your_ob = MyClass()
>>> your_ob.method({str('x'): 3}, b=[4])
{'x': 3} [4]
>>> your_ob._saved_method.args
({'x': 3},)
>>> my_ob._saved_method.args
()
"""
args_and_kwargs = collections.namedtuple('args_and_kwargs', 'args kwargs')
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
attr_name = '_saved_' + method.__name__
attr = args_and_kwargs(args, kwargs)
setattr(self, attr_name, attr)
return method(self, *args, **kwargs)
return wrapper
@functools.wraps(method)
def wrapper(self, *args, **kwargs):
attr_name = '_saved_' + method.__name__
attr = args_and_kwargs(args, kwargs)
setattr(self, attr_name, attr)
return method(self, *args, **kwargs)
return wrapper
def except_(*exceptions, replace=None, use=None):
"""
Replace the indicated exceptions, if raised, with the indicated
literal replacement or evaluated expression (if present).
>>> safe_int = except_(ValueError)(int)
>>> safe_int('five')
>>> safe_int('5')
5
Specify a literal replacement with ``replace``.
>>> safe_int_r = except_(ValueError, replace=0)(int)
>>> safe_int_r('five')
0
Provide an expression to ``use`` to pass through particular parameters.
>>> safe_int_pt = except_(ValueError, use='args[0]')(int)
>>> safe_int_pt('five')
'five'
"""
def decorate(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except exceptions:
try:
return eval(use)
except TypeError:
return replace
return wrapper
return decorate

View file

@ -1,151 +1,156 @@
from __future__ import absolute_import, unicode_literals
import numbers
from functools import reduce
def get_bit_values(number, size=32):
"""
Get bit values as a list for a given number
"""
Get bit values as a list for a given number
>>> get_bit_values(1) == [0]*31 + [1]
True
>>> get_bit_values(1) == [0]*31 + [1]
True
>>> get_bit_values(0xDEADBEEF)
[1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1]
>>> get_bit_values(0xDEADBEEF)
[1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, \
1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1]
You may override the default word size of 32-bits to match your actual
application.
You may override the default word size of 32-bits to match your actual
application.
>>> get_bit_values(0x3, 2)
[1, 1]
>>> get_bit_values(0x3, 2)
[1, 1]
>>> get_bit_values(0x3, 4)
[0, 0, 1, 1]
"""
number += 2**size
return list(map(int, bin(number)[-size:]))
>>> get_bit_values(0x3, 4)
[0, 0, 1, 1]
"""
number += 2 ** size
return list(map(int, bin(number)[-size:]))
def gen_bit_values(number):
"""
Return a zero or one for each bit of a numeric value up to the most
significant 1 bit, beginning with the least significant bit.
"""
Return a zero or one for each bit of a numeric value up to the most
significant 1 bit, beginning with the least significant bit.
>>> list(gen_bit_values(16))
[0, 0, 0, 0, 1]
"""
digits = bin(number)[2:]
return map(int, reversed(digits))
>>> list(gen_bit_values(16))
[0, 0, 0, 0, 1]
"""
digits = bin(number)[2:]
return map(int, reversed(digits))
def coalesce(bits):
"""
Take a sequence of bits, most significant first, and
coalesce them into a number.
"""
Take a sequence of bits, most significant first, and
coalesce them into a number.
>>> coalesce([1,0,1])
5
"""
operation = lambda a, b: (a << 1 | b)
return reduce(operation, bits)
>>> coalesce([1,0,1])
5
"""
def operation(a, b):
return a << 1 | b
return reduce(operation, bits)
class Flags(object):
"""
Subclasses should define _names, a list of flag names beginning
with the least-significant bit.
class Flags:
"""
Subclasses should define _names, a list of flag names beginning
with the least-significant bit.
>>> class MyFlags(Flags):
... _names = 'a', 'b', 'c'
>>> mf = MyFlags.from_number(5)
>>> mf['a']
1
>>> mf['b']
0
>>> mf['c'] == mf[2]
True
>>> mf['b'] = 1
>>> mf['a'] = 0
>>> mf.number
6
"""
def __init__(self, values):
self._values = list(values)
if hasattr(self, '_names'):
n_missing_bits = len(self._names) - len(self._values)
self._values.extend([0] * n_missing_bits)
>>> class MyFlags(Flags):
... _names = 'a', 'b', 'c'
>>> mf = MyFlags.from_number(5)
>>> mf['a']
1
>>> mf['b']
0
>>> mf['c'] == mf[2]
True
>>> mf['b'] = 1
>>> mf['a'] = 0
>>> mf.number
6
"""
@classmethod
def from_number(cls, number):
return cls(gen_bit_values(number))
def __init__(self, values):
self._values = list(values)
if hasattr(self, '_names'):
n_missing_bits = len(self._names) - len(self._values)
self._values.extend([0] * n_missing_bits)
@property
def number(self):
return coalesce(reversed(self._values))
@classmethod
def from_number(cls, number):
return cls(gen_bit_values(number))
def __setitem__(self, key, value):
# first try by index, then by name
try:
self._values[key] = value
except TypeError:
index = self._names.index(key)
self._values[index] = value
@property
def number(self):
return coalesce(reversed(self._values))
def __getitem__(self, key):
# first try by index, then by name
try:
return self._values[key]
except TypeError:
index = self._names.index(key)
return self._values[index]
def __setitem__(self, key, value):
# first try by index, then by name
try:
self._values[key] = value
except TypeError:
index = self._names.index(key)
self._values[index] = value
def __getitem__(self, key):
# first try by index, then by name
try:
return self._values[key]
except TypeError:
index = self._names.index(key)
return self._values[index]
class BitMask(type):
"""
A metaclass to create a bitmask with attributes. Subclass an int and
set this as the metaclass to use.
"""
A metaclass to create a bitmask with attributes. Subclass an int and
set this as the metaclass to use.
Here's how to create such a class on Python 3:
Construct such a class:
class MyBits(int, metaclass=BitMask):
a = 0x1
b = 0x4
c = 0x3
>>> class MyBits(int, metaclass=BitMask):
... a = 0x1
... b = 0x4
... c = 0x3
For testing purposes, construct explicitly to support Python 2
>>> b1 = MyBits(3)
>>> b1.a, b1.b, b1.c
(True, False, True)
>>> b2 = MyBits(8)
>>> any([b2.a, b2.b, b2.c])
False
>>> ns = dict(a=0x1, b=0x4, c=0x3)
>>> MyBits = BitMask(str('MyBits'), (int,), ns)
If the instance defines methods, they won't be wrapped in
properties.
>>> b1 = MyBits(3)
>>> b1.a, b1.b, b1.c
(True, False, True)
>>> b2 = MyBits(8)
>>> any([b2.a, b2.b, b2.c])
False
>>> class MyBits(int, metaclass=BitMask):
... a = 0x1
... b = 0x4
... c = 0x3
...
... @classmethod
... def get_value(cls):
... return 'some value'
...
... @property
... def prop(cls):
... return 'a property'
>>> MyBits(3).get_value()
'some value'
>>> MyBits(3).prop
'a property'
"""
If the instance defines methods, they won't be wrapped in
properties.
def __new__(cls, name, bases, attrs):
def make_property(name, value):
if name.startswith('_') or not isinstance(value, numbers.Number):
return value
return property(lambda self, value=value: bool(self & value))
>>> ns['get_value'] = classmethod(lambda cls: 'some value')
>>> ns['prop'] = property(lambda self: 'a property')
>>> MyBits = BitMask(str('MyBits'), (int,), ns)
>>> MyBits(3).get_value()
'some value'
>>> MyBits(3).prop
'a property'
"""
def __new__(cls, name, bases, attrs):
def make_property(name, value):
if name.startswith('_') or not isinstance(value, numbers.Number):
return value
return property(lambda self, value=value: bool(self & value))
newattrs = dict(
(name, make_property(name, value))
for name, value in attrs.items()
)
return type.__new__(cls, name, bases, newattrs)
newattrs = dict(
(name, make_property(name, value)) for name, value in attrs.items()
)
return type.__new__(cls, name, bases, newattrs)

View file

@ -1,452 +0,0 @@
from __future__ import absolute_import, unicode_literals, print_function
import sys
import re
import inspect
import itertools
import textwrap
import functools
import six
import jaraco.collections
from jaraco.functools import compose
def substitution(old, new):
"""
Return a function that will perform a substitution on a string
"""
return lambda s: s.replace(old, new)
def multi_substitution(*substitutions):
"""
Take a sequence of pairs specifying substitutions, and create
a function that performs those substitutions.
>>> multi_substitution(('foo', 'bar'), ('bar', 'baz'))('foo')
'baz'
"""
substitutions = itertools.starmap(substitution, substitutions)
# compose function applies last function first, so reverse the
# substitutions to get the expected order.
substitutions = reversed(tuple(substitutions))
return compose(*substitutions)
class FoldedCase(six.text_type):
"""
A case insensitive string class; behaves just like str
except compares equal when the only variation is case.
>>> s = FoldedCase('hello world')
>>> s == 'Hello World'
True
>>> 'Hello World' == s
True
>>> s != 'Hello World'
False
>>> s.index('O')
4
>>> s.split('O')
['hell', ' w', 'rld']
>>> sorted(map(FoldedCase, ['GAMMA', 'alpha', 'Beta']))
['alpha', 'Beta', 'GAMMA']
Sequence membership is straightforward.
>>> "Hello World" in [s]
True
>>> s in ["Hello World"]
True
You may test for set inclusion, but candidate and elements
must both be folded.
>>> FoldedCase("Hello World") in {s}
True
>>> s in {FoldedCase("Hello World")}
True
String inclusion works as long as the FoldedCase object
is on the right.
>>> "hello" in FoldedCase("Hello World")
True
But not if the FoldedCase object is on the left:
>>> FoldedCase('hello') in 'Hello World'
False
In that case, use in_:
>>> FoldedCase('hello').in_('Hello World')
True
"""
def __lt__(self, other):
return self.lower() < other.lower()
def __gt__(self, other):
return self.lower() > other.lower()
def __eq__(self, other):
return self.lower() == other.lower()
def __ne__(self, other):
return self.lower() != other.lower()
def __hash__(self):
return hash(self.lower())
def __contains__(self, other):
return super(FoldedCase, self).lower().__contains__(other.lower())
def in_(self, other):
"Does self appear in other?"
return self in FoldedCase(other)
# cache lower since it's likely to be called frequently.
@jaraco.functools.method_cache
def lower(self):
return super(FoldedCase, self).lower()
def index(self, sub):
return self.lower().index(sub.lower())
def split(self, splitter=' ', maxsplit=0):
pattern = re.compile(re.escape(splitter), re.I)
return pattern.split(self, maxsplit)
def local_format(string):
"""
format the string using variables in the caller's local namespace.
>>> a = 3
>>> local_format("{a:5}")
' 3'
"""
context = inspect.currentframe().f_back.f_locals
if sys.version_info < (3, 2):
return string.format(**context)
return string.format_map(context)
def global_format(string):
"""
format the string using variables in the caller's global namespace.
>>> a = 3
>>> fmt = "The func name: {global_format.__name__}"
>>> global_format(fmt)
'The func name: global_format'
"""
context = inspect.currentframe().f_back.f_globals
if sys.version_info < (3, 2):
return string.format(**context)
return string.format_map(context)
def namespace_format(string):
"""
Format the string using variable in the caller's scope (locals + globals).
>>> a = 3
>>> fmt = "A is {a} and this func is {namespace_format.__name__}"
>>> namespace_format(fmt)
'A is 3 and this func is namespace_format'
"""
context = jaraco.collections.DictStack()
context.push(inspect.currentframe().f_back.f_globals)
context.push(inspect.currentframe().f_back.f_locals)
if sys.version_info < (3, 2):
return string.format(**context)
return string.format_map(context)
def is_decodable(value):
r"""
Return True if the supplied value is decodable (using the default
encoding).
>>> is_decodable(b'\xff')
False
>>> is_decodable(b'\x32')
True
"""
# TODO: This code could be expressed more consisely and directly
# with a jaraco.context.ExceptionTrap, but that adds an unfortunate
# long dependency tree, so for now, use boolean literals.
try:
value.decode()
except UnicodeDecodeError:
return False
return True
def is_binary(value):
"""
Return True if the value appears to be binary (that is, it's a byte
string and isn't decodable).
"""
return isinstance(value, bytes) and not is_decodable(value)
def trim(s):
r"""
Trim something like a docstring to remove the whitespace that
is common due to indentation and formatting.
>>> trim("\n\tfoo = bar\n\t\tbar = baz\n")
'foo = bar\n\tbar = baz'
"""
return textwrap.dedent(s).strip()
class Splitter(object):
"""object that will split a string with the given arguments for each call
>>> s = Splitter(',')
>>> s('hello, world, this is your, master calling')
['hello', ' world', ' this is your', ' master calling']
"""
def __init__(self, *args):
self.args = args
def __call__(self, s):
return s.split(*self.args)
def indent(string, prefix=' ' * 4):
return prefix + string
class WordSet(tuple):
"""
Given a Python identifier, return the words that identifier represents,
whether in camel case, underscore-separated, etc.
>>> WordSet.parse("camelCase")
('camel', 'Case')
>>> WordSet.parse("under_sep")
('under', 'sep')
Acronyms should be retained
>>> WordSet.parse("firstSNL")
('first', 'SNL')
>>> WordSet.parse("you_and_I")
('you', 'and', 'I')
>>> WordSet.parse("A simple test")
('A', 'simple', 'test')
Multiple caps should not interfere with the first cap of another word.
>>> WordSet.parse("myABCClass")
('my', 'ABC', 'Class')
The result is a WordSet, so you can get the form you need.
>>> WordSet.parse("myABCClass").underscore_separated()
'my_ABC_Class'
>>> WordSet.parse('a-command').camel_case()
'ACommand'
>>> WordSet.parse('someIdentifier').lowered().space_separated()
'some identifier'
Slices of the result should return another WordSet.
>>> WordSet.parse('taken-out-of-context')[1:].underscore_separated()
'out_of_context'
>>> WordSet.from_class_name(WordSet()).lowered().space_separated()
'word set'
"""
_pattern = re.compile('([A-Z]?[a-z]+)|([A-Z]+(?![a-z]))')
def capitalized(self):
return WordSet(word.capitalize() for word in self)
def lowered(self):
return WordSet(word.lower() for word in self)
def camel_case(self):
return ''.join(self.capitalized())
def headless_camel_case(self):
words = iter(self)
first = next(words).lower()
return itertools.chain((first,), WordSet(words).camel_case())
def underscore_separated(self):
return '_'.join(self)
def dash_separated(self):
return '-'.join(self)
def space_separated(self):
return ' '.join(self)
def __getitem__(self, item):
result = super(WordSet, self).__getitem__(item)
if isinstance(item, slice):
result = WordSet(result)
return result
# for compatibility with Python 2
def __getslice__(self, i, j):
return self.__getitem__(slice(i, j))
@classmethod
def parse(cls, identifier):
matches = cls._pattern.finditer(identifier)
return WordSet(match.group(0) for match in matches)
@classmethod
def from_class_name(cls, subject):
return cls.parse(subject.__class__.__name__)
# for backward compatibility
words = WordSet.parse
def simple_html_strip(s):
r"""
Remove HTML from the string `s`.
>>> str(simple_html_strip(''))
''
>>> print(simple_html_strip('A <bold>stormy</bold> day in paradise'))
A stormy day in paradise
>>> print(simple_html_strip('Somebody <!-- do not --> tell the truth.'))
Somebody tell the truth.
>>> print(simple_html_strip('What about<br/>\nmultiple lines?'))
What about
multiple lines?
"""
html_stripper = re.compile('(<!--.*?-->)|(<[^>]*>)|([^<]+)', re.DOTALL)
texts = (
match.group(3) or ''
for match
in html_stripper.finditer(s)
)
return ''.join(texts)
class SeparatedValues(six.text_type):
"""
A string separated by a separator. Overrides __iter__ for getting
the values.
>>> list(SeparatedValues('a,b,c'))
['a', 'b', 'c']
Whitespace is stripped and empty values are discarded.
>>> list(SeparatedValues(' a, b , c, '))
['a', 'b', 'c']
"""
separator = ','
def __iter__(self):
parts = self.split(self.separator)
return six.moves.filter(None, (part.strip() for part in parts))
class Stripper:
r"""
Given a series of lines, find the common prefix and strip it from them.
>>> lines = [
... 'abcdefg\n',
... 'abc\n',
... 'abcde\n',
... ]
>>> res = Stripper.strip_prefix(lines)
>>> res.prefix
'abc'
>>> list(res.lines)
['defg\n', '\n', 'de\n']
If no prefix is common, nothing should be stripped.
>>> lines = [
... 'abcd\n',
... '1234\n',
... ]
>>> res = Stripper.strip_prefix(lines)
>>> res.prefix = ''
>>> list(res.lines)
['abcd\n', '1234\n']
"""
def __init__(self, prefix, lines):
self.prefix = prefix
self.lines = map(self, lines)
@classmethod
def strip_prefix(cls, lines):
prefix_lines, lines = itertools.tee(lines)
prefix = functools.reduce(cls.common_prefix, prefix_lines)
return cls(prefix, lines)
def __call__(self, line):
if not self.prefix:
return line
null, prefix, rest = line.partition(self.prefix)
return rest
@staticmethod
def common_prefix(s1, s2):
"""
Return the common prefix of two lines.
"""
index = min(len(s1), len(s2))
while s1[:index] != s2[:index]:
index -= 1
return s1[:index]
def remove_prefix(text, prefix):
"""
Remove the prefix from the text if it exists.
>>> remove_prefix('underwhelming performance', 'underwhelming ')
'performance'
>>> remove_prefix('something special', 'sample')
'something special'
"""
null, prefix, rest = text.rpartition(prefix)
return rest
def remove_suffix(text, suffix):
"""
Remove the suffix from the text if it exists.
>>> remove_suffix('name.git', '.git')
'name'
>>> remove_suffix('something special', 'sample')
'something special'
"""
rest, suffix, null = text.partition(suffix)
return rest

View file

@ -0,0 +1,2 @@
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
Curabitur pretium tincidunt lacus. Nulla gravida orci a odio. Nullam varius, turpis et commodo pharetra, est eros bibendum elit, nec luctus magna felis sollicitudin mauris. Integer in mauris eu nibh euismod gravida. Duis ac tellus et risus vulputate vehicula. Donec lobortis risus a elit. Etiam tempor. Ut ullamcorper, ligula eu tempor congue, eros est euismod turpis, id tincidunt sapien risus a quam. Maecenas fermentum consequat mi. Donec fermentum. Pellentesque malesuada nulla a mi. Duis sapien sem, aliquet nec, commodo eget, consequat quis, neque. Aliquam faucibus, elit ut dictum aliquet, felis nisl adipiscing sapien, sed malesuada diam lacus eget erat. Cras mollis scelerisque nunc. Nullam arcu. Aliquam consequat. Curabitur augue lorem, dapibus quis, laoreet et, pretium ac, nisi. Aenean magna nisl, mollis quis, molestie eu, feugiat in, orci. In hac habitasse platea dictumst.

View file

@ -0,0 +1,622 @@
import re
import itertools
import textwrap
import functools
try:
from importlib.resources import files # type: ignore
except ImportError: # pragma: nocover
from importlib_resources import files # type: ignore
from jaraco.functools import compose, method_cache
from jaraco.context import ExceptionTrap
def substitution(old, new):
"""
Return a function that will perform a substitution on a string
"""
return lambda s: s.replace(old, new)
def multi_substitution(*substitutions):
"""
Take a sequence of pairs specifying substitutions, and create
a function that performs those substitutions.
>>> multi_substitution(('foo', 'bar'), ('bar', 'baz'))('foo')
'baz'
"""
substitutions = itertools.starmap(substitution, substitutions)
# compose function applies last function first, so reverse the
# substitutions to get the expected order.
substitutions = reversed(tuple(substitutions))
return compose(*substitutions)
class FoldedCase(str):
"""
A case insensitive string class; behaves just like str
except compares equal when the only variation is case.
>>> s = FoldedCase('hello world')
>>> s == 'Hello World'
True
>>> 'Hello World' == s
True
>>> s != 'Hello World'
False
>>> s.index('O')
4
>>> s.split('O')
['hell', ' w', 'rld']
>>> sorted(map(FoldedCase, ['GAMMA', 'alpha', 'Beta']))
['alpha', 'Beta', 'GAMMA']
Sequence membership is straightforward.
>>> "Hello World" in [s]
True
>>> s in ["Hello World"]
True
Allows testing for set inclusion, but candidate and elements
must both be folded.
>>> FoldedCase("Hello World") in {s}
True
>>> s in {FoldedCase("Hello World")}
True
String inclusion works as long as the FoldedCase object
is on the right.
>>> "hello" in FoldedCase("Hello World")
True
But not if the FoldedCase object is on the left:
>>> FoldedCase('hello') in 'Hello World'
False
In that case, use ``in_``:
>>> FoldedCase('hello').in_('Hello World')
True
>>> FoldedCase('hello') > FoldedCase('Hello')
False
>>> FoldedCase('ß') == FoldedCase('ss')
True
"""
def __lt__(self, other):
return self.casefold() < other.casefold()
def __gt__(self, other):
return self.casefold() > other.casefold()
def __eq__(self, other):
return self.casefold() == other.casefold()
def __ne__(self, other):
return self.casefold() != other.casefold()
def __hash__(self):
return hash(self.casefold())
def __contains__(self, other):
return super().casefold().__contains__(other.casefold())
def in_(self, other):
"Does self appear in other?"
return self in FoldedCase(other)
# cache casefold since it's likely to be called frequently.
@method_cache
def casefold(self):
return super().casefold()
def index(self, sub):
return self.casefold().index(sub.casefold())
def split(self, splitter=' ', maxsplit=0):
pattern = re.compile(re.escape(splitter), re.I)
return pattern.split(self, maxsplit)
# Python 3.8 compatibility
_unicode_trap = ExceptionTrap(UnicodeDecodeError)
@_unicode_trap.passes
def is_decodable(value):
r"""
Return True if the supplied value is decodable (using the default
encoding).
>>> is_decodable(b'\xff')
False
>>> is_decodable(b'\x32')
True
"""
value.decode()
def is_binary(value):
r"""
Return True if the value appears to be binary (that is, it's a byte
string and isn't decodable).
>>> is_binary(b'\xff')
True
>>> is_binary('\xff')
False
"""
return isinstance(value, bytes) and not is_decodable(value)
def trim(s):
r"""
Trim something like a docstring to remove the whitespace that
is common due to indentation and formatting.
>>> trim("\n\tfoo = bar\n\t\tbar = baz\n")
'foo = bar\n\tbar = baz'
"""
return textwrap.dedent(s).strip()
def wrap(s):
"""
Wrap lines of text, retaining existing newlines as
paragraph markers.
>>> print(wrap(lorem_ipsum))
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do
eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad
minim veniam, quis nostrud exercitation ullamco laboris nisi ut
aliquip ex ea commodo consequat. Duis aute irure dolor in
reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla
pariatur. Excepteur sint occaecat cupidatat non proident, sunt in
culpa qui officia deserunt mollit anim id est laborum.
<BLANKLINE>
Curabitur pretium tincidunt lacus. Nulla gravida orci a odio. Nullam
varius, turpis et commodo pharetra, est eros bibendum elit, nec luctus
magna felis sollicitudin mauris. Integer in mauris eu nibh euismod
gravida. Duis ac tellus et risus vulputate vehicula. Donec lobortis
risus a elit. Etiam tempor. Ut ullamcorper, ligula eu tempor congue,
eros est euismod turpis, id tincidunt sapien risus a quam. Maecenas
fermentum consequat mi. Donec fermentum. Pellentesque malesuada nulla
a mi. Duis sapien sem, aliquet nec, commodo eget, consequat quis,
neque. Aliquam faucibus, elit ut dictum aliquet, felis nisl adipiscing
sapien, sed malesuada diam lacus eget erat. Cras mollis scelerisque
nunc. Nullam arcu. Aliquam consequat. Curabitur augue lorem, dapibus
quis, laoreet et, pretium ac, nisi. Aenean magna nisl, mollis quis,
molestie eu, feugiat in, orci. In hac habitasse platea dictumst.
"""
paragraphs = s.splitlines()
wrapped = ('\n'.join(textwrap.wrap(para)) for para in paragraphs)
return '\n\n'.join(wrapped)
def unwrap(s):
r"""
Given a multi-line string, return an unwrapped version.
>>> wrapped = wrap(lorem_ipsum)
>>> wrapped.count('\n')
20
>>> unwrapped = unwrap(wrapped)
>>> unwrapped.count('\n')
1
>>> print(unwrapped)
Lorem ipsum dolor sit amet, consectetur adipiscing ...
Curabitur pretium tincidunt lacus. Nulla gravida orci ...
"""
paragraphs = re.split(r'\n\n+', s)
cleaned = (para.replace('\n', ' ') for para in paragraphs)
return '\n'.join(cleaned)
lorem_ipsum: str = files(__name__).joinpath('Lorem ipsum.txt').read_text()
class Splitter(object):
"""object that will split a string with the given arguments for each call
>>> s = Splitter(',')
>>> s('hello, world, this is your, master calling')
['hello', ' world', ' this is your', ' master calling']
"""
def __init__(self, *args):
self.args = args
def __call__(self, s):
return s.split(*self.args)
def indent(string, prefix=' ' * 4):
"""
>>> indent('foo')
' foo'
"""
return prefix + string
class WordSet(tuple):
"""
Given an identifier, return the words that identifier represents,
whether in camel case, underscore-separated, etc.
>>> WordSet.parse("camelCase")
('camel', 'Case')
>>> WordSet.parse("under_sep")
('under', 'sep')
Acronyms should be retained
>>> WordSet.parse("firstSNL")
('first', 'SNL')
>>> WordSet.parse("you_and_I")
('you', 'and', 'I')
>>> WordSet.parse("A simple test")
('A', 'simple', 'test')
Multiple caps should not interfere with the first cap of another word.
>>> WordSet.parse("myABCClass")
('my', 'ABC', 'Class')
The result is a WordSet, providing access to other forms.
>>> WordSet.parse("myABCClass").underscore_separated()
'my_ABC_Class'
>>> WordSet.parse('a-command').camel_case()
'ACommand'
>>> WordSet.parse('someIdentifier').lowered().space_separated()
'some identifier'
Slices of the result should return another WordSet.
>>> WordSet.parse('taken-out-of-context')[1:].underscore_separated()
'out_of_context'
>>> WordSet.from_class_name(WordSet()).lowered().space_separated()
'word set'
>>> example = WordSet.parse('figured it out')
>>> example.headless_camel_case()
'figuredItOut'
>>> example.dash_separated()
'figured-it-out'
"""
_pattern = re.compile('([A-Z]?[a-z]+)|([A-Z]+(?![a-z]))')
def capitalized(self):
return WordSet(word.capitalize() for word in self)
def lowered(self):
return WordSet(word.lower() for word in self)
def camel_case(self):
return ''.join(self.capitalized())
def headless_camel_case(self):
words = iter(self)
first = next(words).lower()
new_words = itertools.chain((first,), WordSet(words).camel_case())
return ''.join(new_words)
def underscore_separated(self):
return '_'.join(self)
def dash_separated(self):
return '-'.join(self)
def space_separated(self):
return ' '.join(self)
def trim_right(self, item):
"""
Remove the item from the end of the set.
>>> WordSet.parse('foo bar').trim_right('foo')
('foo', 'bar')
>>> WordSet.parse('foo bar').trim_right('bar')
('foo',)
>>> WordSet.parse('').trim_right('bar')
()
"""
return self[:-1] if self and self[-1] == item else self
def trim_left(self, item):
"""
Remove the item from the beginning of the set.
>>> WordSet.parse('foo bar').trim_left('foo')
('bar',)
>>> WordSet.parse('foo bar').trim_left('bar')
('foo', 'bar')
>>> WordSet.parse('').trim_left('bar')
()
"""
return self[1:] if self and self[0] == item else self
def trim(self, item):
"""
>>> WordSet.parse('foo bar').trim('foo')
('bar',)
"""
return self.trim_left(item).trim_right(item)
def __getitem__(self, item):
result = super(WordSet, self).__getitem__(item)
if isinstance(item, slice):
result = WordSet(result)
return result
@classmethod
def parse(cls, identifier):
matches = cls._pattern.finditer(identifier)
return WordSet(match.group(0) for match in matches)
@classmethod
def from_class_name(cls, subject):
return cls.parse(subject.__class__.__name__)
# for backward compatibility
words = WordSet.parse
def simple_html_strip(s):
r"""
Remove HTML from the string `s`.
>>> str(simple_html_strip(''))
''
>>> print(simple_html_strip('A <bold>stormy</bold> day in paradise'))
A stormy day in paradise
>>> print(simple_html_strip('Somebody <!-- do not --> tell the truth.'))
Somebody tell the truth.
>>> print(simple_html_strip('What about<br/>\nmultiple lines?'))
What about
multiple lines?
"""
html_stripper = re.compile('(<!--.*?-->)|(<[^>]*>)|([^<]+)', re.DOTALL)
texts = (match.group(3) or '' for match in html_stripper.finditer(s))
return ''.join(texts)
class SeparatedValues(str):
"""
A string separated by a separator. Overrides __iter__ for getting
the values.
>>> list(SeparatedValues('a,b,c'))
['a', 'b', 'c']
Whitespace is stripped and empty values are discarded.
>>> list(SeparatedValues(' a, b , c, '))
['a', 'b', 'c']
"""
separator = ','
def __iter__(self):
parts = self.split(self.separator)
return filter(None, (part.strip() for part in parts))
class Stripper:
r"""
Given a series of lines, find the common prefix and strip it from them.
>>> lines = [
... 'abcdefg\n',
... 'abc\n',
... 'abcde\n',
... ]
>>> res = Stripper.strip_prefix(lines)
>>> res.prefix
'abc'
>>> list(res.lines)
['defg\n', '\n', 'de\n']
If no prefix is common, nothing should be stripped.
>>> lines = [
... 'abcd\n',
... '1234\n',
... ]
>>> res = Stripper.strip_prefix(lines)
>>> res.prefix = ''
>>> list(res.lines)
['abcd\n', '1234\n']
"""
def __init__(self, prefix, lines):
self.prefix = prefix
self.lines = map(self, lines)
@classmethod
def strip_prefix(cls, lines):
prefix_lines, lines = itertools.tee(lines)
prefix = functools.reduce(cls.common_prefix, prefix_lines)
return cls(prefix, lines)
def __call__(self, line):
if not self.prefix:
return line
null, prefix, rest = line.partition(self.prefix)
return rest
@staticmethod
def common_prefix(s1, s2):
"""
Return the common prefix of two lines.
"""
index = min(len(s1), len(s2))
while s1[:index] != s2[:index]:
index -= 1
return s1[:index]
def remove_prefix(text, prefix):
"""
Remove the prefix from the text if it exists.
>>> remove_prefix('underwhelming performance', 'underwhelming ')
'performance'
>>> remove_prefix('something special', 'sample')
'something special'
"""
null, prefix, rest = text.rpartition(prefix)
return rest
def remove_suffix(text, suffix):
"""
Remove the suffix from the text if it exists.
>>> remove_suffix('name.git', '.git')
'name'
>>> remove_suffix('something special', 'sample')
'something special'
"""
rest, suffix, null = text.partition(suffix)
return rest
def normalize_newlines(text):
r"""
Replace alternate newlines with the canonical newline.
>>> normalize_newlines('Lorem Ipsum\u2029')
'Lorem Ipsum\n'
>>> normalize_newlines('Lorem Ipsum\r\n')
'Lorem Ipsum\n'
>>> normalize_newlines('Lorem Ipsum\x85')
'Lorem Ipsum\n'
"""
newlines = ['\r\n', '\r', '\n', '\u0085', '\u2028', '\u2029']
pattern = '|'.join(newlines)
return re.sub(pattern, '\n', text)
def _nonblank(str):
return str and not str.startswith('#')
@functools.singledispatch
def yield_lines(iterable):
r"""
Yield valid lines of a string or iterable.
>>> list(yield_lines(''))
[]
>>> list(yield_lines(['foo', 'bar']))
['foo', 'bar']
>>> list(yield_lines('foo\nbar'))
['foo', 'bar']
>>> list(yield_lines('\nfoo\n#bar\nbaz #comment'))
['foo', 'baz #comment']
>>> list(yield_lines(['foo\nbar', 'baz', 'bing\n\n\n']))
['foo', 'bar', 'baz', 'bing']
"""
return itertools.chain.from_iterable(map(yield_lines, iterable))
@yield_lines.register(str)
def _(text):
return filter(_nonblank, map(str.strip, text.splitlines()))
def drop_comment(line):
"""
Drop comments.
>>> drop_comment('foo # bar')
'foo'
A hash without a space may be in a URL.
>>> drop_comment('http://example.com/foo#bar')
'http://example.com/foo#bar'
"""
return line.partition(' #')[0]
def join_continuation(lines):
r"""
Join lines continued by a trailing backslash.
>>> list(join_continuation(['foo \\', 'bar', 'baz']))
['foobar', 'baz']
>>> list(join_continuation(['foo \\', 'bar', 'baz']))
['foobar', 'baz']
>>> list(join_continuation(['foo \\', 'bar \\', 'baz']))
['foobarbaz']
Not sure why, but...
The character preceeding the backslash is also elided.
>>> list(join_continuation(['goo\\', 'dly']))
['godly']
A terrible idea, but...
If no line is available to continue, suppress the lines.
>>> list(join_continuation(['foo', 'bar\\', 'baz\\']))
['foo']
"""
lines = iter(lines)
for item in lines:
while item.endswith('\\'):
try:
item = item[:-2].strip() + next(lines)
except StopIteration:
return
yield item
def read_newlines(filename, limit=1024):
r"""
>>> tmp_path = getfixture('tmp_path')
>>> filename = tmp_path / 'out.txt'
>>> _ = filename.write_text('foo\n', newline='')
>>> read_newlines(filename)
'\n'
>>> _ = filename.write_text('foo\r\n', newline='')
>>> read_newlines(filename)
'\r\n'
>>> _ = filename.write_text('foo\r\nbar\nbing\r', newline='')
>>> read_newlines(filename)
('\r', '\n', '\r\n')
"""
with open(filename) as fp:
fp.read(limit)
return fp.newlines

View file

@ -0,0 +1,25 @@
qwerty = "-=qwertyuiop[]asdfghjkl;'zxcvbnm,./_+QWERTYUIOP{}ASDFGHJKL:\"ZXCVBNM<>?"
dvorak = "[]',.pyfgcrl/=aoeuidhtns-;qjkxbmwvz{}\"<>PYFGCRL?+AOEUIDHTNS_:QJKXBMWVZ"
to_dvorak = str.maketrans(qwerty, dvorak)
to_qwerty = str.maketrans(dvorak, qwerty)
def translate(input, translation):
"""
>>> translate('dvorak', to_dvorak)
'ekrpat'
>>> translate('qwerty', to_qwerty)
'x,dokt'
"""
return input.translate(translation)
def _translate_stream(stream, translation):
"""
>>> import io
>>> _translate_stream(io.StringIO('foo'), to_dvorak)
urr
"""
print(translate(stream.read(), translation))

View file

@ -0,0 +1,33 @@
import autocommand
import inflect
from more_itertools import always_iterable
import jaraco.text
def report_newlines(filename):
r"""
Report the newlines in the indicated file.
>>> tmp_path = getfixture('tmp_path')
>>> filename = tmp_path / 'out.txt'
>>> _ = filename.write_text('foo\nbar\n', newline='')
>>> report_newlines(filename)
newline is '\n'
>>> filename = tmp_path / 'out.txt'
>>> _ = filename.write_text('foo\nbar\r\n', newline='')
>>> report_newlines(filename)
newlines are ('\n', '\r\n')
"""
newlines = jaraco.text.read_newlines(filename)
count = len(tuple(always_iterable(newlines)))
engine = inflect.engine()
print(
engine.plural_noun("newline", count),
engine.plural_verb("is", count),
repr(newlines),
)
autocommand.autocommand(__name__)(report_newlines)

View file

@ -0,0 +1,21 @@
import sys
import autocommand
from jaraco.text import Stripper
def strip_prefix():
r"""
Strip any common prefix from stdin.
>>> import io, pytest
>>> getfixture('monkeypatch').setattr('sys.stdin', io.StringIO('abcdef\nabc123'))
>>> strip_prefix()
def
123
"""
sys.stdout.writelines(Stripper.strip_prefix(sys.stdin).lines)
autocommand.autocommand(__name__)(strip_prefix)

View file

@ -0,0 +1,6 @@
import sys
from . import layouts
__name__ == '__main__' and layouts._translate_stream(sys.stdin, layouts.to_dvorak)

View file

@ -0,0 +1,6 @@
import sys
from . import layouts
__name__ == '__main__' and layouts._translate_stream(sys.stdin, layouts.to_qwerty)

View file

@ -1,77 +1,76 @@
import argparse
import six
from jaraco.classes import meta
from jaraco import text
from jaraco import text # type: ignore
@six.add_metaclass(meta.LeafClassesMeta)
class Command(object):
"""
A general-purpose base class for creating commands for a command-line
program using argparse. Each subclass of Command represents a separate
sub-command of a program.
class Command(metaclass=meta.LeafClassesMeta):
"""
A general-purpose base class for creating commands for a command-line
program using argparse. Each subclass of Command represents a separate
sub-command of a program.
For example, one might use Command subclasses to implement the Mercurial
command set::
For example, one might use Command subclasses to implement the Mercurial
command set::
class Commit(Command):
@staticmethod
def add_arguments(cls, parser):
parser.add_argument('-m', '--message')
class Commit(Command):
@staticmethod
def add_arguments(cls, parser):
parser.add_argument('-m', '--message')
@classmethod
def run(cls, args):
"Run the 'commit' command with args (parsed)"
@classmethod
def run(cls, args):
"Run the 'commit' command with args (parsed)"
class Merge(Command): pass
class Pull(Command): pass
...
class Merge(Command): pass
class Pull(Command): pass
...
Then one could create an entry point for Mercurial like so::
Then one could create an entry point for Mercurial like so::
def hg_command():
Command.invoke()
"""
def hg_command():
Command.invoke()
"""
@classmethod
def add_subparsers(cls, parser):
subparsers = parser.add_subparsers()
[cmd_class.add_parser(subparsers) for cmd_class in cls._leaf_classes]
@classmethod
def add_subparsers(cls, parser):
subparsers = parser.add_subparsers()
[cmd_class.add_parser(subparsers) for cmd_class in cls._leaf_classes]
@classmethod
def add_parser(cls, subparsers):
cmd_string = text.words(cls.__name__).lowered().dash_separated()
parser = subparsers.add_parser(cmd_string)
parser.set_defaults(action=cls)
cls.add_arguments(parser)
return parser
@classmethod
def add_parser(cls, subparsers):
cmd_string = text.words(cls.__name__).lowered().dash_separated()
parser = subparsers.add_parser(cmd_string)
parser.set_defaults(action=cls)
cls.add_arguments(parser)
return parser
@classmethod
def add_arguments(cls, parser):
pass
@classmethod
def add_arguments(cls, parser):
pass
@classmethod
def invoke(cls):
"""
Invoke the command using ArgumentParser
"""
parser = argparse.ArgumentParser()
cls.add_subparsers(parser)
args = parser.parse_args()
args.action.run(args)
@classmethod
def invoke(cls):
"""
Invoke the command using ArgumentParser
"""
parser = argparse.ArgumentParser()
cls.add_subparsers(parser)
args = parser.parse_args()
args.action.run(args)
class Extend(argparse.Action):
"""
Argparse action to take an nargs=* argument
and add any values to the existing value.
"""
Argparse action to take an nargs=* argument
and add any values to the existing value.
>>> parser = argparse.ArgumentParser()
>>> _ = parser.add_argument('--foo', nargs='*', default=[], action=Extend)
>>> args = parser.parse_args(['--foo', 'a=1', '--foo', 'b=2', 'c=3'])
>>> args.foo
['a=1', 'b=2', 'c=3']
"""
def __call__(self, parser, namespace, values, option_string=None):
getattr(namespace, self.dest).extend(values)
>>> parser = argparse.ArgumentParser()
>>> _ = parser.add_argument('--foo', nargs='*', default=[], action=Extend)
>>> args = parser.parse_args(['--foo', 'a=1', '--foo', 'b=2', 'c=3'])
>>> args.foo
['a=1', 'b=2', 'c=3']
"""
def __call__(self, parser, namespace, values, option_string=None):
getattr(namespace, self.dest).extend(values)

View file

@ -1,5 +1,3 @@
from __future__ import unicode_literals, absolute_import
import tempfile
import os
import sys
@ -9,100 +7,105 @@ import collections
import io
import difflib
import six
from typing import Mapping
class EditProcessException(RuntimeError):
pass
class EditProcessException(RuntimeError): pass
class EditableFile(object):
"""
EditableFile saves some data to a temporary file, launches a
platform editor for interactive editing, and then reloads the data,
setting .changed to True if the data was edited.
"""
EditableFile saves some data to a temporary file, launches a
platform editor for interactive editing, and then reloads the data,
setting .changed to True if the data was edited.
e.g.::
e.g.::
x = EditableFile('foo')
x.edit()
x = EditableFile('foo')
x.edit()
if x.changed:
print(x.data)
if x.changed:
print(x.data)
The EDITOR environment variable can define which executable to use
(also XML_EDITOR if the content-type to edit includes 'xml'). If no
EDITOR is defined, defaults to 'notepad' on Windows and 'edit' on
other platforms.
"""
platform_default_editors = collections.defaultdict(
lambda: 'edit',
win32 = 'notepad',
linux2 = 'vi',
)
encoding = 'utf-8'
The EDITOR environment variable can define which executable to use
(also XML_EDITOR if the content-type to edit includes 'xml'). If no
EDITOR is defined, defaults to 'notepad' on Windows and 'edit' on
other platforms.
"""
def __init__(self, data='', content_type='text/plain'):
self.data = six.text_type(data)
self.content_type = content_type
platform_default_editors: Mapping[str, str] = collections.defaultdict(
lambda: 'edit',
win32='notepad',
linux2='vi',
)
encoding = 'utf-8'
def __enter__(self):
extension = mimetypes.guess_extension(self.content_type) or ''
fobj, self.name = tempfile.mkstemp(extension)
os.write(fobj, self.data.encode(self.encoding))
os.close(fobj)
return self
def __init__(self, data='', content_type='text/plain'):
self.data = str(data)
self.content_type = content_type
def read(self):
with open(self.name, 'rb') as f:
return f.read().decode(self.encoding)
def __enter__(self):
extension = mimetypes.guess_extension(self.content_type) or ''
fobj, self.name = tempfile.mkstemp(extension)
os.write(fobj, self.data.encode(self.encoding))
os.close(fobj)
return self
def __exit__(self, *tb_info):
os.remove(self.name)
def read(self):
with open(self.name, 'rb') as f:
return f.read().decode(self.encoding)
def edit(self):
"""
Edit the file
"""
self.changed = False
with self:
editor = self.get_editor()
cmd = [editor, self.name]
try:
res = subprocess.call(cmd)
except Exception as e:
print("Error launching editor %(editor)s" % locals())
print(e)
return
if res != 0:
msg = '%(editor)s returned error status %(res)d' % locals()
raise EditProcessException(msg)
new_data = self.read()
if new_data != self.data:
self.changed = self._save_diff(self.data, new_data)
self.data = new_data
def __exit__(self, *tb_info):
os.remove(self.name)
@staticmethod
def _search_env(keys):
"""
Search the environment for the supplied keys, returning the first
one found or None if none was found.
"""
matches = (os.environ[key] for key in keys if key in os.environ)
return next(matches, None)
def edit(self):
"""
Edit the file
"""
self.changed = False
with self:
editor = self.get_editor()
cmd = [editor, self.name]
try:
res = subprocess.call(cmd)
except Exception as e:
print("Error launching editor %(editor)s" % locals())
print(e)
return
if res != 0:
msg = '%(editor)s returned error status %(res)d' % locals()
raise EditProcessException(msg)
new_data = self.read()
if new_data != self.data:
self.changed = self._save_diff(self.data, new_data)
self.data = new_data
def get_editor(self):
"""
Give preference to an XML_EDITOR or EDITOR defined in the
environment. Otherwise use a default editor based on platform.
"""
env_search = ['EDITOR']
if 'xml' in self.content_type:
env_search.insert(0, 'XML_EDITOR')
default_editor = self.platform_default_editors[sys.platform]
return self._search_env(env_search) or default_editor
@staticmethod
def _search_env(keys):
"""
Search the environment for the supplied keys, returning the first
one found or None if none was found.
"""
matches = (os.environ[key] for key in keys if key in os.environ)
return next(matches, None)
@staticmethod
def _save_diff(*versions):
def get_lines(content):
return list(io.StringIO(content))
lines = map(get_lines, versions)
diff = difflib.context_diff(*lines)
return tuple(diff)
def get_editor(self):
"""
Give preference to an XML_EDITOR or EDITOR defined in the
environment. Otherwise use a default editor based on platform.
"""
env_search = ['EDITOR']
if 'xml' in self.content_type:
env_search.insert(0, 'XML_EDITOR')
default_editor = self.platform_default_editors[sys.platform]
return self._search_env(env_search) or default_editor
@staticmethod
def _save_diff(*versions):
def get_lines(content):
return list(io.StringIO(content))
lines = map(get_lines, versions)
diff = difflib.context_diff(*lines)
return tuple(diff)

View file

@ -3,24 +3,28 @@ This module currently provides a cross-platform getch function
"""
try:
# Windows
from msvcrt import getch
# Windows
from msvcrt import getch # type: ignore
getch # workaround for https://github.com/kevinw/pyflakes/issues/13
except ImportError:
pass
pass
try:
# Unix
import sys
import tty
import termios
# Unix
import sys
import tty
import termios
def getch(): # type: ignore
fd = sys.stdin.fileno()
old = termios.tcgetattr(fd)
try:
tty.setraw(fd)
return sys.stdin.read(1)
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, old)
def getch():
fd = sys.stdin.fileno()
old = termios.tcgetattr(fd)
try:
tty.setraw(fd)
return sys.stdin.read(1)
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, old)
except ImportError:
pass
pass

View file

@ -1,34 +1,32 @@
from __future__ import print_function, absolute_import, unicode_literals
import itertools
import six
class Menu(object):
"""
A simple command-line based menu
"""
def __init__(self, choices=None, formatter=str):
self.choices = choices or list()
self.formatter = formatter
"""
A simple command-line based menu
"""
def get_choice(self, prompt="> "):
n = len(self.choices)
number_width = len(str(n)) + 1
menu_fmt = '{number:{number_width}}) {choice}'
formatted_choices = map(self.formatter, self.choices)
for number, choice in zip(itertools.count(1), formatted_choices):
print(menu_fmt.format(**locals()))
print()
try:
answer = int(six.moves.input(prompt))
result = self.choices[answer - 1]
except ValueError:
print('invalid selection')
result = None
except IndexError:
print('invalid selection')
result = None
except KeyboardInterrupt:
result = None
return result
def __init__(self, choices=None, formatter=str):
self.choices = choices or list()
self.formatter = formatter
def get_choice(self, prompt="> "):
n = len(self.choices)
number_width = len(str(n)) + 1
menu_fmt = '{number:{number_width}}) {choice}'
formatted_choices = map(self.formatter, self.choices)
for number, choice in zip(itertools.count(1), formatted_choices):
print(menu_fmt.format(**locals()))
print()
try:
answer = int(input(prompt))
result = self.choices[answer - 1]
except ValueError:
print('invalid selection')
result = None
except IndexError:
print('invalid selection')
result = None
except KeyboardInterrupt:
result = None
return result

View file

@ -1,152 +1,141 @@
# deprecated -- use TQDM
from __future__ import (print_function, absolute_import, unicode_literals,
division)
import time
import sys
import itertools
import abc
import datetime
import six
class AbstractProgressBar(metaclass=abc.ABCMeta):
def __init__(self, unit='', size=70):
"""
Size is the nominal size in characters
"""
self.unit = unit
self.size = size
@six.add_metaclass(abc.ABCMeta)
class AbstractProgressBar(object):
def __init__(self, unit='', size=70):
"""
Size is the nominal size in characters
"""
self.unit = unit
self.size = size
def report(self, amt):
sys.stdout.write('\r%s' % self.get_bar(amt))
sys.stdout.flush()
def report(self, amt):
sys.stdout.write('\r%s' % self.get_bar(amt))
sys.stdout.flush()
@abc.abstractmethod
def get_bar(self, amt):
"Return the string to be printed. Should be size >= self.size"
@abc.abstractmethod
def get_bar(self, amt):
"Return the string to be printed. Should be size >= self.size"
def summary(self, str):
return ' (' + self.unit_str(str) + ')'
def summary(self, str):
return ' (' + self.unit_str(str) + ')'
def unit_str(self, str):
if self.unit:
str += ' ' + self.unit
return str
def unit_str(self, str):
if self.unit:
str += ' ' + self.unit
return str
def finish(self):
print()
def finish(self):
print()
def __enter__(self):
self.report(0)
return self
def __enter__(self):
self.report(0)
return self
def __exit__(self, exc, exc_val, tb):
if exc is None:
self.finish()
else:
print()
def __exit__(self, exc, exc_val, tb):
if exc is None:
self.finish()
else:
print()
def iterate(self, iterable):
"""
Report the status as the iterable is consumed.
"""
with self:
for n, item in enumerate(iterable, 1):
self.report(n)
yield item
def iterate(self, iterable):
"""
Report the status as the iterable is consumed.
"""
with self:
for n, item in enumerate(iterable, 1):
self.report(n)
yield item
class SimpleProgressBar(AbstractProgressBar):
_PROG_DISPGLYPH = itertools.cycle(['|', '/', '-', '\\'])
_PROG_DISPGLYPH = itertools.cycle(['|', '/', '-', '\\'])
def get_bar(self, amt):
bar = next(self._PROG_DISPGLYPH)
template = ' [{bar:^{bar_len}}]'
summary = self.summary('{amt}')
template += summary
empty = template.format(
bar='',
bar_len=0,
amt=amt,
)
bar_len = self.size - len(empty)
return template.format(**locals())
def get_bar(self, amt):
bar = next(self._PROG_DISPGLYPH)
template = ' [{bar:^{bar_len}}]'
summary = self.summary('{amt}')
template += summary
empty = template.format(
bar='',
bar_len=0,
amt=amt,
)
bar_len = self.size - len(empty)
return template.format(**locals())
@classmethod
def demo(cls):
bar3 = cls(unit='cubes', size=30)
with bar3:
for x in six.moves.range(1, 759):
bar3.report(x)
time.sleep(0.01)
@classmethod
def demo(cls):
bar3 = cls(unit='cubes', size=30)
with bar3:
for x in range(1, 759):
bar3.report(x)
time.sleep(0.01)
class TargetProgressBar(AbstractProgressBar):
def __init__(self, total=None, unit='', size=70):
"""
Size is the nominal size in characters
"""
self.total = total
super(TargetProgressBar, self).__init__(unit, size)
def __init__(self, total=None, unit='', size=70):
"""
Size is the nominal size in characters
"""
self.total = total
super(TargetProgressBar, self).__init__(unit, size)
def get_bar(self, amt):
template = ' [{bar:<{bar_len}}]'
completed = amt / self.total
percent = int(completed * 100)
percent_str = ' {percent:3}%'
template += percent_str
summary = self.summary('{amt}/{total}')
template += summary
empty = template.format(
total=self.total,
bar='',
bar_len=0,
**locals()
)
bar_len = self.size - len(empty)
bar = '=' * int(completed * bar_len)
return template.format(total=self.total, **locals())
def get_bar(self, amt):
template = ' [{bar:<{bar_len}}]'
completed = amt / self.total
percent = int(completed * 100)
percent_str = ' {percent:3}%'
template += percent_str
summary = self.summary('{amt}/{total}')
template += summary
empty = template.format(total=self.total, bar='', bar_len=0, **locals())
bar_len = self.size - len(empty)
bar = '=' * int(completed * bar_len)
return template.format(total=self.total, **locals())
@classmethod
def demo(cls):
bar1 = cls(100, 'blocks')
with bar1:
for x in six.moves.range(1, 101):
bar1.report(x)
time.sleep(0.05)
@classmethod
def demo(cls):
bar1 = cls(100, 'blocks')
with bar1:
for x in range(1, 101):
bar1.report(x)
time.sleep(0.05)
bar2 = cls(758, size=50)
with bar2:
for x in six.moves.range(1, 759):
bar2.report(x)
time.sleep(0.01)
bar2 = cls(758, size=50)
with bar2:
for x in range(1, 759):
bar2.report(x)
time.sleep(0.01)
def finish(self):
self.report(self.total)
super(TargetProgressBar, self).finish()
def finish(self):
self.report(self.total)
super(TargetProgressBar, self).finish()
def countdown(template, duration=datetime.timedelta(seconds=5)):
"""
Do a countdown for duration, printing the template (which may accept one
positional argument). Template should be something like
``countdown complete in {} seconds.``
"""
now = datetime.datetime.now()
deadline = now + duration
remaining = deadline - datetime.datetime.now()
while remaining:
remaining = deadline - datetime.datetime.now()
remaining = max(datetime.timedelta(), remaining)
msg = template.format(remaining.total_seconds())
print(msg, end=' '*10)
sys.stdout.flush()
time.sleep(.1)
print('\b'*80, end='')
sys.stdout.flush()
print()
"""
Do a countdown for duration, printing the template (which may accept one
positional argument). Template should be something like
``countdown complete in {} seconds.``
"""
now = datetime.datetime.now()
deadline = now + duration
remaining = deadline - datetime.datetime.now()
while remaining:
remaining = deadline - datetime.datetime.now()
remaining = max(datetime.timedelta(), remaining)
msg = template.format(remaining.total_seconds())
print(msg, end=' ' * 10)
sys.stdout.flush()
time.sleep(0.1)
print('\b' * 80, end='')
sys.stdout.flush()
print()

View file

@ -30,16 +30,16 @@ CF_GDIOBJFIRST = 0x0300
CF_GDIOBJLAST = 0x03FF
RegisterClipboardFormat = ctypes.windll.user32.RegisterClipboardFormatW
RegisterClipboardFormat.argtypes = ctypes.wintypes.LPWSTR,
RegisterClipboardFormat.argtypes = (ctypes.wintypes.LPWSTR,)
RegisterClipboardFormat.restype = ctypes.wintypes.UINT
CF_HTML = RegisterClipboardFormat('HTML Format')
EnumClipboardFormats = ctypes.windll.user32.EnumClipboardFormats
EnumClipboardFormats.argtypes = ctypes.wintypes.UINT,
EnumClipboardFormats.argtypes = (ctypes.wintypes.UINT,)
EnumClipboardFormats.restype = ctypes.wintypes.UINT
GetClipboardData = ctypes.windll.user32.GetClipboardData
GetClipboardData.argtypes = ctypes.wintypes.UINT,
GetClipboardData.argtypes = (ctypes.wintypes.UINT,)
GetClipboardData.restype = ctypes.wintypes.HANDLE
SetClipboardData = ctypes.windll.user32.SetClipboardData
@ -47,7 +47,7 @@ SetClipboardData.argtypes = ctypes.wintypes.UINT, ctypes.wintypes.HANDLE
SetClipboardData.restype = ctypes.wintypes.HANDLE
OpenClipboard = ctypes.windll.user32.OpenClipboard
OpenClipboard.argtypes = ctypes.wintypes.HANDLE,
OpenClipboard.argtypes = (ctypes.wintypes.HANDLE,)
OpenClipboard.restype = ctypes.wintypes.BOOL
ctypes.windll.user32.CloseClipboard.restype = ctypes.wintypes.BOOL

View file

@ -5,58 +5,52 @@ Support for Credential Vault
import ctypes
from ctypes.wintypes import DWORD, LPCWSTR, BOOL, LPWSTR, FILETIME
try:
from ctypes.wintypes import LPBYTE
from ctypes.wintypes import LPBYTE
except ImportError:
LPBYTE = ctypes.POINTER(ctypes.wintypes.BYTE)
LPBYTE = ctypes.POINTER(ctypes.wintypes.BYTE) # type: ignore
class CredentialAttribute(ctypes.Structure):
_fields_ = []
_fields_ = [] # type: ignore
class Credential(ctypes.Structure):
_fields_ = [
('flags', DWORD),
('type', DWORD),
('target_name', LPWSTR),
('comment', LPWSTR),
('last_written', FILETIME),
('credential_blob_size', DWORD),
('credential_blob', LPBYTE),
('persist', DWORD),
('attribute_count', DWORD),
('attributes', ctypes.POINTER(CredentialAttribute)),
('target_alias', LPWSTR),
('user_name', LPWSTR),
]
_fields_ = [
('flags', DWORD),
('type', DWORD),
('target_name', LPWSTR),
('comment', LPWSTR),
('last_written', FILETIME),
('credential_blob_size', DWORD),
('credential_blob', LPBYTE),
('persist', DWORD),
('attribute_count', DWORD),
('attributes', ctypes.POINTER(CredentialAttribute)),
('target_alias', LPWSTR),
('user_name', LPWSTR),
]
def __del__(self):
ctypes.windll.advapi32.CredFree(ctypes.byref(self))
def __del__(self):
ctypes.windll.advapi32.CredFree(ctypes.byref(self))
PCREDENTIAL = ctypes.POINTER(Credential)
CredRead = ctypes.windll.advapi32.CredReadW
CredRead.argtypes = (
LPCWSTR, # TargetName
DWORD, # Type
DWORD, # Flags
ctypes.POINTER(PCREDENTIAL), # Credential
LPCWSTR, # TargetName
DWORD, # Type
DWORD, # Flags
ctypes.POINTER(PCREDENTIAL), # Credential
)
CredRead.restype = BOOL
CredWrite = ctypes.windll.advapi32.CredWriteW
CredWrite.argtypes = (
PCREDENTIAL, # Credential
DWORD, # Flags
)
CredWrite.argtypes = (PCREDENTIAL, DWORD) # Credential # Flags
CredWrite.restype = BOOL
CredDelete = ctypes.windll.advapi32.CredDeleteW
CredDelete.argtypes = (
LPCWSTR, # TargetName
DWORD, # Type
DWORD, # Flags
)
CredDelete.argtypes = (LPCWSTR, DWORD, DWORD) # TargetName # Type # Flags
CredDelete.restype = BOOL

View file

@ -7,7 +7,7 @@ SetEnvironmentVariable.argtypes = [ctypes.wintypes.LPCWSTR] * 2
GetEnvironmentVariable = ctypes.windll.kernel32.GetEnvironmentVariableW
GetEnvironmentVariable.restype = ctypes.wintypes.BOOL
GetEnvironmentVariable.argtypes = [
ctypes.wintypes.LPCWSTR,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.DWORD,
ctypes.wintypes.LPCWSTR,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.DWORD,
]

View file

@ -1,19 +1,12 @@
from ctypes import windll, POINTER
from ctypes.wintypes import (
LPWSTR, DWORD, LPVOID, HANDLE, BOOL,
)
from ctypes.wintypes import LPWSTR, DWORD, LPVOID, HANDLE, BOOL
CreateEvent = windll.kernel32.CreateEventW
CreateEvent.argtypes = (
LPVOID, # LPSECURITY_ATTRIBUTES
BOOL,
BOOL,
LPWSTR,
)
CreateEvent.argtypes = (LPVOID, BOOL, BOOL, LPWSTR) # LPSECURITY_ATTRIBUTES
CreateEvent.restype = HANDLE
SetEvent = windll.kernel32.SetEvent
SetEvent.argtypes = HANDLE,
SetEvent.argtypes = (HANDLE,)
SetEvent.restype = BOOL
WaitForSingleObject = windll.kernel32.WaitForSingleObject
@ -26,11 +19,11 @@ _WaitForMultipleObjects.restype = DWORD
def WaitForMultipleObjects(handles, wait_all=False, timeout=0):
n_handles = len(handles)
handle_array = (HANDLE * n_handles)()
for index, handle in enumerate(handles):
handle_array[index] = handle
return _WaitForMultipleObjects(n_handles, handle_array, wait_all, timeout)
n_handles = len(handles)
handle_array = (HANDLE * n_handles)()
for index, handle in enumerate(handles):
handle_array[index] = handle
return _WaitForMultipleObjects(n_handles, handle_array, wait_all, timeout)
WAIT_OBJECT_0 = 0

View file

@ -2,22 +2,24 @@ import ctypes.wintypes
CreateSymbolicLink = ctypes.windll.kernel32.CreateSymbolicLinkW
CreateSymbolicLink.argtypes = (
ctypes.wintypes.LPWSTR,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.DWORD,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.DWORD,
)
CreateSymbolicLink.restype = ctypes.wintypes.BOOLEAN
SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE = 0x2
CreateHardLink = ctypes.windll.kernel32.CreateHardLinkW
CreateHardLink.argtypes = (
ctypes.wintypes.LPWSTR,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.LPVOID, # reserved for LPSECURITY_ATTRIBUTES
ctypes.wintypes.LPWSTR,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.LPVOID, # reserved for LPSECURITY_ATTRIBUTES
)
CreateHardLink.restype = ctypes.wintypes.BOOLEAN
GetFileAttributes = ctypes.windll.kernel32.GetFileAttributesW
GetFileAttributes.argtypes = ctypes.wintypes.LPWSTR,
GetFileAttributes.argtypes = (ctypes.wintypes.LPWSTR,)
GetFileAttributes.restype = ctypes.wintypes.DWORD
SetFileAttributes = ctypes.windll.kernel32.SetFileAttributesW
@ -28,31 +30,33 @@ MAX_PATH = 260
GetFinalPathNameByHandle = ctypes.windll.kernel32.GetFinalPathNameByHandleW
GetFinalPathNameByHandle.argtypes = (
ctypes.wintypes.HANDLE, ctypes.wintypes.LPWSTR, ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD,
ctypes.wintypes.HANDLE,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD,
)
GetFinalPathNameByHandle.restype = ctypes.wintypes.DWORD
class SECURITY_ATTRIBUTES(ctypes.Structure):
_fields_ = (
('length', ctypes.wintypes.DWORD),
('p_security_descriptor', ctypes.wintypes.LPVOID),
('inherit_handle', ctypes.wintypes.BOOLEAN),
)
_fields_ = (
('length', ctypes.wintypes.DWORD),
('p_security_descriptor', ctypes.wintypes.LPVOID),
('inherit_handle', ctypes.wintypes.BOOLEAN),
)
LPSECURITY_ATTRIBUTES = ctypes.POINTER(SECURITY_ATTRIBUTES)
CreateFile = ctypes.windll.kernel32.CreateFileW
CreateFile.argtypes = (
ctypes.wintypes.LPWSTR,
ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD,
LPSECURITY_ATTRIBUTES,
ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD,
ctypes.wintypes.HANDLE,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD,
LPSECURITY_ATTRIBUTES,
ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD,
ctypes.wintypes.HANDLE,
)
CreateFile.restype = ctypes.wintypes.HANDLE
FILE_SHARE_READ = 1
@ -83,56 +87,56 @@ CloseHandle.restype = ctypes.wintypes.BOOLEAN
class WIN32_FIND_DATA(ctypes.wintypes.WIN32_FIND_DATAW):
"""
_fields_ = [
("dwFileAttributes", DWORD),
("ftCreationTime", FILETIME),
("ftLastAccessTime", FILETIME),
("ftLastWriteTime", FILETIME),
("nFileSizeHigh", DWORD),
("nFileSizeLow", DWORD),
("dwReserved0", DWORD),
("dwReserved1", DWORD),
("cFileName", WCHAR * MAX_PATH),
("cAlternateFileName", WCHAR * 14)]
]
"""
"""
_fields_ = [
("dwFileAttributes", DWORD),
("ftCreationTime", FILETIME),
("ftLastAccessTime", FILETIME),
("ftLastWriteTime", FILETIME),
("nFileSizeHigh", DWORD),
("nFileSizeLow", DWORD),
("dwReserved0", DWORD),
("dwReserved1", DWORD),
("cFileName", WCHAR * MAX_PATH),
("cAlternateFileName", WCHAR * 14)]
]
"""
@property
def file_attributes(self):
return self.dwFileAttributes
@property
def file_attributes(self):
return self.dwFileAttributes
@property
def creation_time(self):
return self.ftCreationTime
@property
def creation_time(self):
return self.ftCreationTime
@property
def last_access_time(self):
return self.ftLastAccessTime
@property
def last_access_time(self):
return self.ftLastAccessTime
@property
def last_write_time(self):
return self.ftLastWriteTime
@property
def last_write_time(self):
return self.ftLastWriteTime
@property
def file_size_words(self):
return [self.nFileSizeHigh, self.nFileSizeLow]
@property
def file_size_words(self):
return [self.nFileSizeHigh, self.nFileSizeLow]
@property
def reserved(self):
return [self.dwReserved0, self.dwReserved1]
@property
def reserved(self):
return [self.dwReserved0, self.dwReserved1]
@property
def filename(self):
return self.cFileName
@property
def filename(self):
return self.cFileName
@property
def alternate_filename(self):
return self.cAlternateFileName
@property
def alternate_filename(self):
return self.cAlternateFileName
@property
def file_size(self):
return self.nFileSizeHigh << 32 + self.nFileSizeLow
@property
def file_size(self):
return self.nFileSizeHigh << 32 + self.nFileSizeLow
LPWIN32_FIND_DATA = ctypes.POINTER(ctypes.wintypes.WIN32_FIND_DATAW)
@ -144,7 +148,7 @@ FindNextFile = ctypes.windll.kernel32.FindNextFileW
FindNextFile.argtypes = (ctypes.wintypes.HANDLE, LPWIN32_FIND_DATA)
FindNextFile.restype = ctypes.wintypes.BOOLEAN
ctypes.windll.kernel32.FindClose.argtypes = ctypes.wintypes.HANDLE,
ctypes.windll.kernel32.FindClose.argtypes = (ctypes.wintypes.HANDLE,)
SCS_32BIT_BINARY = 0 # A 32-bit Windows-based application
SCS_64BIT_BINARY = 6 # A 64-bit Windows-based application
@ -156,7 +160,8 @@ SCS_WOW_BINARY = 2 # A 16-bit Windows-based application
_GetBinaryType = ctypes.windll.kernel32.GetBinaryTypeW
_GetBinaryType.argtypes = (
ctypes.wintypes.LPWSTR, ctypes.POINTER(ctypes.wintypes.DWORD),
ctypes.wintypes.LPWSTR,
ctypes.POINTER(ctypes.wintypes.DWORD),
)
_GetBinaryType.restype = ctypes.wintypes.BOOL
@ -164,47 +169,47 @@ FILEOP_FLAGS = ctypes.wintypes.WORD
class BY_HANDLE_FILE_INFORMATION(ctypes.Structure):
_fields_ = [
('file_attributes', ctypes.wintypes.DWORD),
('creation_time', ctypes.wintypes.FILETIME),
('last_access_time', ctypes.wintypes.FILETIME),
('last_write_time', ctypes.wintypes.FILETIME),
('volume_serial_number', ctypes.wintypes.DWORD),
('file_size_high', ctypes.wintypes.DWORD),
('file_size_low', ctypes.wintypes.DWORD),
('number_of_links', ctypes.wintypes.DWORD),
('file_index_high', ctypes.wintypes.DWORD),
('file_index_low', ctypes.wintypes.DWORD),
]
_fields_ = [
('file_attributes', ctypes.wintypes.DWORD),
('creation_time', ctypes.wintypes.FILETIME),
('last_access_time', ctypes.wintypes.FILETIME),
('last_write_time', ctypes.wintypes.FILETIME),
('volume_serial_number', ctypes.wintypes.DWORD),
('file_size_high', ctypes.wintypes.DWORD),
('file_size_low', ctypes.wintypes.DWORD),
('number_of_links', ctypes.wintypes.DWORD),
('file_index_high', ctypes.wintypes.DWORD),
('file_index_low', ctypes.wintypes.DWORD),
]
@property
def file_size(self):
return (self.file_size_high << 32) + self.file_size_low
@property
def file_size(self):
return (self.file_size_high << 32) + self.file_size_low
@property
def file_index(self):
return (self.file_index_high << 32) + self.file_index_low
@property
def file_index(self):
return (self.file_index_high << 32) + self.file_index_low
GetFileInformationByHandle = ctypes.windll.kernel32.GetFileInformationByHandle
GetFileInformationByHandle.restype = ctypes.wintypes.BOOL
GetFileInformationByHandle.argtypes = (
ctypes.wintypes.HANDLE,
ctypes.POINTER(BY_HANDLE_FILE_INFORMATION),
ctypes.wintypes.HANDLE,
ctypes.POINTER(BY_HANDLE_FILE_INFORMATION),
)
class SHFILEOPSTRUCT(ctypes.Structure):
_fields_ = [
('status_dialog', ctypes.wintypes.HWND),
('operation', ctypes.wintypes.UINT),
('from_', ctypes.wintypes.LPWSTR),
('to', ctypes.wintypes.LPWSTR),
('flags', FILEOP_FLAGS),
('operations_aborted', ctypes.wintypes.BOOL),
('name_mapping_handles', ctypes.wintypes.LPVOID),
('progress_title', ctypes.wintypes.LPWSTR),
]
_fields_ = [
('status_dialog', ctypes.wintypes.HWND),
('operation', ctypes.wintypes.UINT),
('from_', ctypes.wintypes.LPWSTR),
('to', ctypes.wintypes.LPWSTR),
('flags', FILEOP_FLAGS),
('operations_aborted', ctypes.wintypes.BOOL),
('name_mapping_handles', ctypes.wintypes.LPVOID),
('progress_title', ctypes.wintypes.LPWSTR),
]
_SHFileOperation = ctypes.windll.shell32.SHFileOperationW
@ -218,12 +223,12 @@ FO_DELETE = 3
ReplaceFile = ctypes.windll.kernel32.ReplaceFileW
ReplaceFile.restype = ctypes.wintypes.BOOL
ReplaceFile.argtypes = [
ctypes.wintypes.LPWSTR,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.DWORD,
ctypes.wintypes.LPVOID,
ctypes.wintypes.LPVOID,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.DWORD,
ctypes.wintypes.LPVOID,
ctypes.wintypes.LPVOID,
]
REPLACEFILE_WRITE_THROUGH = 0x1
@ -232,20 +237,20 @@ REPLACEFILE_IGNORE_ACL_ERRORS = 0x4
class STAT_STRUCT(ctypes.Structure):
_fields_ = [
('dev', ctypes.c_uint),
('ino', ctypes.c_ushort),
('mode', ctypes.c_ushort),
('nlink', ctypes.c_short),
('uid', ctypes.c_short),
('gid', ctypes.c_short),
('rdev', ctypes.c_uint),
# the following 4 fields are ctypes.c_uint64 for _stat64
('size', ctypes.c_uint),
('atime', ctypes.c_uint),
('mtime', ctypes.c_uint),
('ctime', ctypes.c_uint),
]
_fields_ = [
('dev', ctypes.c_uint),
('ino', ctypes.c_ushort),
('mode', ctypes.c_ushort),
('nlink', ctypes.c_short),
('uid', ctypes.c_short),
('gid', ctypes.c_short),
('rdev', ctypes.c_uint),
# the following 4 fields are ctypes.c_uint64 for _stat64
('size', ctypes.c_uint),
('atime', ctypes.c_uint),
('mtime', ctypes.c_uint),
('ctime', ctypes.c_uint),
]
_wstat = ctypes.windll.msvcrt._wstat
@ -254,64 +259,64 @@ _wstat.restype = ctypes.c_int
FILE_NOTIFY_CHANGE_LAST_WRITE = 0x10
FindFirstChangeNotification = (
ctypes.windll.kernel32.FindFirstChangeNotificationW)
FindFirstChangeNotification = ctypes.windll.kernel32.FindFirstChangeNotificationW
FindFirstChangeNotification.argtypes = (
ctypes.wintypes.LPWSTR, ctypes.wintypes.BOOL, ctypes.wintypes.DWORD,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.BOOL,
ctypes.wintypes.DWORD,
)
FindFirstChangeNotification.restype = ctypes.wintypes.HANDLE
FindCloseChangeNotification = (
ctypes.windll.kernel32.FindCloseChangeNotification)
FindCloseChangeNotification.argtypes = ctypes.wintypes.HANDLE,
FindCloseChangeNotification = ctypes.windll.kernel32.FindCloseChangeNotification
FindCloseChangeNotification.argtypes = (ctypes.wintypes.HANDLE,)
FindCloseChangeNotification.restype = ctypes.wintypes.BOOL
FindNextChangeNotification = ctypes.windll.kernel32.FindNextChangeNotification
FindNextChangeNotification.argtypes = ctypes.wintypes.HANDLE,
FindNextChangeNotification.argtypes = (ctypes.wintypes.HANDLE,)
FindNextChangeNotification.restype = ctypes.wintypes.BOOL
FILE_FLAG_OPEN_REPARSE_POINT = 0x00200000
IO_REPARSE_TAG_SYMLINK = 0xA000000C
FSCTL_GET_REPARSE_POINT = 0x900a8
FSCTL_GET_REPARSE_POINT = 0x900A8
LPDWORD = ctypes.POINTER(ctypes.wintypes.DWORD)
LPOVERLAPPED = ctypes.wintypes.LPVOID
DeviceIoControl = ctypes.windll.kernel32.DeviceIoControl
DeviceIoControl.argtypes = [
ctypes.wintypes.HANDLE,
ctypes.wintypes.DWORD,
ctypes.wintypes.LPVOID,
ctypes.wintypes.DWORD,
ctypes.wintypes.LPVOID,
ctypes.wintypes.DWORD,
LPDWORD,
LPOVERLAPPED,
ctypes.wintypes.HANDLE,
ctypes.wintypes.DWORD,
ctypes.wintypes.LPVOID,
ctypes.wintypes.DWORD,
ctypes.wintypes.LPVOID,
ctypes.wintypes.DWORD,
LPDWORD,
LPOVERLAPPED,
]
DeviceIoControl.restype = ctypes.wintypes.BOOL
class REPARSE_DATA_BUFFER(ctypes.Structure):
_fields_ = [
('tag', ctypes.c_ulong),
('data_length', ctypes.c_ushort),
('reserved', ctypes.c_ushort),
('substitute_name_offset', ctypes.c_ushort),
('substitute_name_length', ctypes.c_ushort),
('print_name_offset', ctypes.c_ushort),
('print_name_length', ctypes.c_ushort),
('flags', ctypes.c_ulong),
('path_buffer', ctypes.c_byte * 1),
]
_fields_ = [
('tag', ctypes.c_ulong),
('data_length', ctypes.c_ushort),
('reserved', ctypes.c_ushort),
('substitute_name_offset', ctypes.c_ushort),
('substitute_name_length', ctypes.c_ushort),
('print_name_offset', ctypes.c_ushort),
('print_name_length', ctypes.c_ushort),
('flags', ctypes.c_ulong),
('path_buffer', ctypes.c_byte * 1),
]
def get_print_name(self):
wchar_size = ctypes.sizeof(ctypes.wintypes.WCHAR)
arr_typ = ctypes.wintypes.WCHAR * (self.print_name_length // wchar_size)
data = ctypes.byref(self.path_buffer, self.print_name_offset)
return ctypes.cast(data, ctypes.POINTER(arr_typ)).contents.value
def get_print_name(self):
wchar_size = ctypes.sizeof(ctypes.wintypes.WCHAR)
arr_typ = ctypes.wintypes.WCHAR * (self.print_name_length // wchar_size)
data = ctypes.byref(self.path_buffer, self.print_name_offset)
return ctypes.cast(data, ctypes.POINTER(arr_typ)).contents.value
def get_substitute_name(self):
wchar_size = ctypes.sizeof(ctypes.wintypes.WCHAR)
arr_typ = ctypes.wintypes.WCHAR * (self.substitute_name_length // wchar_size)
data = ctypes.byref(self.path_buffer, self.substitute_name_offset)
return ctypes.cast(data, ctypes.POINTER(arr_typ)).contents.value
def get_substitute_name(self):
wchar_size = ctypes.sizeof(ctypes.wintypes.WCHAR)
arr_typ = ctypes.wintypes.WCHAR * (self.substitute_name_length // wchar_size)
data = ctypes.byref(self.path_buffer, self.substitute_name_offset)
return ctypes.cast(data, ctypes.POINTER(arr_typ)).contents.value

View file

@ -4,11 +4,11 @@ from ctypes.wintypes import DWORD, WCHAR, BYTE, BOOL
# from mprapi.h
MAX_INTERFACE_NAME_LEN = 2**8
MAX_INTERFACE_NAME_LEN = 2 ** 8
# from iprtrmib.h
MAXLEN_PHYSADDR = 2**3
MAXLEN_IFDESCR = 2**8
MAXLEN_PHYSADDR = 2 ** 3
MAXLEN_IFDESCR = 2 ** 8
# from iptypes.h
MAX_ADAPTER_ADDRESS_LENGTH = 8
@ -16,114 +16,102 @@ MAX_DHCPV6_DUID_LENGTH = 130
class MIB_IFROW(ctypes.Structure):
_fields_ = (
('name', WCHAR * MAX_INTERFACE_NAME_LEN),
('index', DWORD),
('type', DWORD),
('MTU', DWORD),
('speed', DWORD),
('physical_address_length', DWORD),
('physical_address_raw', BYTE * MAXLEN_PHYSADDR),
('admin_status', DWORD),
('operational_status', DWORD),
('last_change', DWORD),
('octets_received', DWORD),
('unicast_packets_received', DWORD),
('non_unicast_packets_received', DWORD),
('incoming_discards', DWORD),
('incoming_errors', DWORD),
('incoming_unknown_protocols', DWORD),
('octets_sent', DWORD),
('unicast_packets_sent', DWORD),
('non_unicast_packets_sent', DWORD),
('outgoing_discards', DWORD),
('outgoing_errors', DWORD),
('outgoing_queue_length', DWORD),
('description_length', DWORD),
('description_raw', ctypes.c_char * MAXLEN_IFDESCR),
)
_fields_ = (
('name', WCHAR * MAX_INTERFACE_NAME_LEN),
('index', DWORD),
('type', DWORD),
('MTU', DWORD),
('speed', DWORD),
('physical_address_length', DWORD),
('physical_address_raw', BYTE * MAXLEN_PHYSADDR),
('admin_status', DWORD),
('operational_status', DWORD),
('last_change', DWORD),
('octets_received', DWORD),
('unicast_packets_received', DWORD),
('non_unicast_packets_received', DWORD),
('incoming_discards', DWORD),
('incoming_errors', DWORD),
('incoming_unknown_protocols', DWORD),
('octets_sent', DWORD),
('unicast_packets_sent', DWORD),
('non_unicast_packets_sent', DWORD),
('outgoing_discards', DWORD),
('outgoing_errors', DWORD),
('outgoing_queue_length', DWORD),
('description_length', DWORD),
('description_raw', ctypes.c_char * MAXLEN_IFDESCR),
)
def _get_binary_property(self, name):
val_prop = '{0}_raw'.format(name)
val = getattr(self, val_prop)
len_prop = '{0}_length'.format(name)
length = getattr(self, len_prop)
return str(memoryview(val))[:length]
def _get_binary_property(self, name):
val_prop = '{0}_raw'.format(name)
val = getattr(self, val_prop)
len_prop = '{0}_length'.format(name)
length = getattr(self, len_prop)
return str(memoryview(val))[:length]
@property
def physical_address(self):
return self._get_binary_property('physical_address')
@property
def physical_address(self):
return self._get_binary_property('physical_address')
@property
def description(self):
return self._get_binary_property('description')
@property
def description(self):
return self._get_binary_property('description')
class MIB_IFTABLE(ctypes.Structure):
_fields_ = (
('num_entries', DWORD), # dwNumEntries
('entries', MIB_IFROW * 0), # table
)
_fields_ = (
('num_entries', DWORD), # dwNumEntries
('entries', MIB_IFROW * 0), # table
)
class MIB_IPADDRROW(ctypes.Structure):
_fields_ = (
('address_num', DWORD),
('index', DWORD),
('mask', DWORD),
('broadcast_address', DWORD),
('reassembly_size', DWORD),
('unused', ctypes.c_ushort),
('type', ctypes.c_ushort),
)
_fields_ = (
('address_num', DWORD),
('index', DWORD),
('mask', DWORD),
('broadcast_address', DWORD),
('reassembly_size', DWORD),
('unused', ctypes.c_ushort),
('type', ctypes.c_ushort),
)
@property
def address(self):
"The address in big-endian"
_ = struct.pack('L', self.address_num)
return struct.unpack('!L', _)[0]
@property
def address(self):
"The address in big-endian"
_ = struct.pack('L', self.address_num)
return struct.unpack('!L', _)[0]
class MIB_IPADDRTABLE(ctypes.Structure):
_fields_ = (
('num_entries', DWORD),
('entries', MIB_IPADDRROW * 0),
)
_fields_ = (('num_entries', DWORD), ('entries', MIB_IPADDRROW * 0))
class SOCKADDR(ctypes.Structure):
_fields_ = (
('family', ctypes.c_ushort),
('data', ctypes.c_byte * 14),
)
_fields_ = (('family', ctypes.c_ushort), ('data', ctypes.c_byte * 14))
LPSOCKADDR = ctypes.POINTER(SOCKADDR)
class SOCKET_ADDRESS(ctypes.Structure):
_fields_ = [
('address', LPSOCKADDR),
('length', ctypes.c_int),
]
_fields_ = [('address', LPSOCKADDR), ('length', ctypes.c_int)]
class _IP_ADAPTER_ADDRESSES_METRIC(ctypes.Structure):
_fields_ = [
('length', ctypes.c_ulong),
('interface_index', DWORD),
]
_fields_ = [('length', ctypes.c_ulong), ('interface_index', DWORD)]
class _IP_ADAPTER_ADDRESSES_U1(ctypes.Union):
_fields_ = [
('alignment', ctypes.c_ulonglong),
('metric', _IP_ADAPTER_ADDRESSES_METRIC),
]
_fields_ = [
('alignment', ctypes.c_ulonglong),
('metric', _IP_ADAPTER_ADDRESSES_METRIC),
]
class IP_ADAPTER_ADDRESSES(ctypes.Structure):
pass
pass
LP_IP_ADAPTER_ADDRESSES = ctypes.POINTER(IP_ADAPTER_ADDRESSES)
@ -149,69 +137,69 @@ NET_IF_CONNECTION_TYPE = ctypes.c_uint # enum
TUNNEL_TYPE = ctypes.c_uint # enum
IP_ADAPTER_ADDRESSES._fields_ = [
# ('u', _IP_ADAPTER_ADDRESSES_U1),
('length', ctypes.c_ulong),
('interface_index', DWORD),
('next', LP_IP_ADAPTER_ADDRESSES),
('adapter_name', ctypes.c_char_p),
('first_unicast_address', PIP_ADAPTER_UNICAST_ADDRESS),
('first_anycast_address', PIP_ADAPTER_ANYCAST_ADDRESS),
('first_multicast_address', PIP_ADAPTER_MULTICAST_ADDRESS),
('first_dns_server_address', PIP_ADAPTER_DNS_SERVER_ADDRESS),
('dns_suffix', ctypes.c_wchar_p),
('description', ctypes.c_wchar_p),
('friendly_name', ctypes.c_wchar_p),
('byte', BYTE * MAX_ADAPTER_ADDRESS_LENGTH),
('physical_address_length', DWORD),
('flags', DWORD),
('mtu', DWORD),
('interface_type', DWORD),
('oper_status', IF_OPER_STATUS),
('ipv6_interface_index', DWORD),
('zone_indices', DWORD),
('first_prefix', PIP_ADAPTER_PREFIX),
('transmit_link_speed', ctypes.c_uint64),
('receive_link_speed', ctypes.c_uint64),
('first_wins_server_address', PIP_ADAPTER_WINS_SERVER_ADDRESS_LH),
('first_gateway_address', PIP_ADAPTER_GATEWAY_ADDRESS_LH),
('ipv4_metric', ctypes.c_ulong),
('ipv6_metric', ctypes.c_ulong),
('luid', IF_LUID),
('dhcpv4_server', SOCKET_ADDRESS),
('compartment_id', NET_IF_COMPARTMENT_ID),
('network_guid', NET_IF_NETWORK_GUID),
('connection_type', NET_IF_CONNECTION_TYPE),
('tunnel_type', TUNNEL_TYPE),
('dhcpv6_server', SOCKET_ADDRESS),
('dhcpv6_client_duid', ctypes.c_byte * MAX_DHCPV6_DUID_LENGTH),
('dhcpv6_client_duid_length', ctypes.c_ulong),
('dhcpv6_iaid', ctypes.c_ulong),
('first_dns_suffix', PIP_ADAPTER_DNS_SUFFIX),
# ('u', _IP_ADAPTER_ADDRESSES_U1),
('length', ctypes.c_ulong),
('interface_index', DWORD),
('next', LP_IP_ADAPTER_ADDRESSES),
('adapter_name', ctypes.c_char_p),
('first_unicast_address', PIP_ADAPTER_UNICAST_ADDRESS),
('first_anycast_address', PIP_ADAPTER_ANYCAST_ADDRESS),
('first_multicast_address', PIP_ADAPTER_MULTICAST_ADDRESS),
('first_dns_server_address', PIP_ADAPTER_DNS_SERVER_ADDRESS),
('dns_suffix', ctypes.c_wchar_p),
('description', ctypes.c_wchar_p),
('friendly_name', ctypes.c_wchar_p),
('byte', BYTE * MAX_ADAPTER_ADDRESS_LENGTH),
('physical_address_length', DWORD),
('flags', DWORD),
('mtu', DWORD),
('interface_type', DWORD),
('oper_status', IF_OPER_STATUS),
('ipv6_interface_index', DWORD),
('zone_indices', DWORD),
('first_prefix', PIP_ADAPTER_PREFIX),
('transmit_link_speed', ctypes.c_uint64),
('receive_link_speed', ctypes.c_uint64),
('first_wins_server_address', PIP_ADAPTER_WINS_SERVER_ADDRESS_LH),
('first_gateway_address', PIP_ADAPTER_GATEWAY_ADDRESS_LH),
('ipv4_metric', ctypes.c_ulong),
('ipv6_metric', ctypes.c_ulong),
('luid', IF_LUID),
('dhcpv4_server', SOCKET_ADDRESS),
('compartment_id', NET_IF_COMPARTMENT_ID),
('network_guid', NET_IF_NETWORK_GUID),
('connection_type', NET_IF_CONNECTION_TYPE),
('tunnel_type', TUNNEL_TYPE),
('dhcpv6_server', SOCKET_ADDRESS),
('dhcpv6_client_duid', ctypes.c_byte * MAX_DHCPV6_DUID_LENGTH),
('dhcpv6_client_duid_length', ctypes.c_ulong),
('dhcpv6_iaid', ctypes.c_ulong),
('first_dns_suffix', PIP_ADAPTER_DNS_SUFFIX),
]
# define some parameters to the API Functions
GetIfTable = ctypes.windll.iphlpapi.GetIfTable
GetIfTable.argtypes = [
ctypes.POINTER(MIB_IFTABLE),
ctypes.POINTER(ctypes.c_ulong),
BOOL,
ctypes.POINTER(MIB_IFTABLE),
ctypes.POINTER(ctypes.c_ulong),
BOOL,
]
GetIfTable.restype = DWORD
GetIpAddrTable = ctypes.windll.iphlpapi.GetIpAddrTable
GetIpAddrTable.argtypes = [
ctypes.POINTER(MIB_IPADDRTABLE),
ctypes.POINTER(ctypes.c_ulong),
BOOL,
ctypes.POINTER(MIB_IPADDRTABLE),
ctypes.POINTER(ctypes.c_ulong),
BOOL,
]
GetIpAddrTable.restype = DWORD
GetAdaptersAddresses = ctypes.windll.iphlpapi.GetAdaptersAddresses
GetAdaptersAddresses.argtypes = [
ctypes.c_ulong,
ctypes.c_ulong,
ctypes.c_void_p,
ctypes.POINTER(IP_ADAPTER_ADDRESSES),
ctypes.POINTER(ctypes.c_ulong),
ctypes.c_ulong,
ctypes.c_ulong,
ctypes.c_void_p,
ctypes.POINTER(IP_ADAPTER_ADDRESSES),
ctypes.POINTER(ctypes.c_ulong),
]
GetAdaptersAddresses.restype = ctypes.c_ulong

View file

@ -2,8 +2,8 @@ import ctypes.wintypes
GetModuleFileName = ctypes.windll.kernel32.GetModuleFileNameW
GetModuleFileName.argtypes = (
ctypes.wintypes.HANDLE,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.DWORD,
ctypes.wintypes.HANDLE,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.DWORD,
)
GetModuleFileName.restype = ctypes.wintypes.DWORD

View file

@ -7,25 +7,25 @@ GlobalAlloc.argtypes = ctypes.wintypes.UINT, ctypes.c_size_t
GlobalAlloc.restype = ctypes.wintypes.HANDLE
GlobalLock = ctypes.windll.kernel32.GlobalLock
GlobalLock.argtypes = ctypes.wintypes.HGLOBAL,
GlobalLock.argtypes = (ctypes.wintypes.HGLOBAL,)
GlobalLock.restype = ctypes.wintypes.LPVOID
GlobalUnlock = ctypes.windll.kernel32.GlobalUnlock
GlobalUnlock.argtypes = ctypes.wintypes.HGLOBAL,
GlobalUnlock.argtypes = (ctypes.wintypes.HGLOBAL,)
GlobalUnlock.restype = ctypes.wintypes.BOOL
GlobalSize = ctypes.windll.kernel32.GlobalSize
GlobalSize.argtypes = ctypes.wintypes.HGLOBAL,
GlobalSize.argtypes = (ctypes.wintypes.HGLOBAL,)
GlobalSize.restype = ctypes.c_size_t
CreateFileMapping = ctypes.windll.kernel32.CreateFileMappingW
CreateFileMapping.argtypes = [
ctypes.wintypes.HANDLE,
ctypes.c_void_p,
ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD,
ctypes.wintypes.LPWSTR,
ctypes.wintypes.HANDLE,
ctypes.c_void_p,
ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD,
ctypes.wintypes.DWORD,
ctypes.wintypes.LPWSTR,
]
CreateFileMapping.restype = ctypes.wintypes.HANDLE
@ -33,13 +33,9 @@ MapViewOfFile = ctypes.windll.kernel32.MapViewOfFile
MapViewOfFile.restype = ctypes.wintypes.HANDLE
UnmapViewOfFile = ctypes.windll.kernel32.UnmapViewOfFile
UnmapViewOfFile.argtypes = ctypes.wintypes.HANDLE,
UnmapViewOfFile.argtypes = (ctypes.wintypes.HANDLE,)
RtlMoveMemory = ctypes.windll.kernel32.RtlMoveMemory
RtlMoveMemory.argtypes = (
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_size_t,
)
RtlMoveMemory.argtypes = (ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t)
ctypes.windll.kernel32.LocalFree.argtypes = ctypes.wintypes.HLOCAL,
ctypes.windll.kernel32.LocalFree.argtypes = (ctypes.wintypes.HLOCAL,)

View file

@ -1,5 +1,3 @@
#!/usr/bin/env python
"""
jaraco.windows.message
@ -9,22 +7,21 @@ Windows Messaging support
import ctypes
from ctypes.wintypes import HWND, UINT, WPARAM, LPARAM, DWORD, LPVOID
import six
LRESULT = LPARAM
class LPARAM_wstr(LPARAM):
"""
A special instance of LPARAM that can be constructed from a string
instance (for functions such as SendMessage, whose LPARAM may point to
a unicode string).
"""
@classmethod
def from_param(cls, param):
if isinstance(param, six.string_types):
return LPVOID.from_param(six.text_type(param))
return LPARAM.from_param(param)
"""
A special instance of LPARAM that can be constructed from a string
instance (for functions such as SendMessage, whose LPARAM may point to
a unicode string).
"""
@classmethod
def from_param(cls, param):
if isinstance(param, str):
return LPVOID.from_param(str(param))
return LPARAM.from_param(param)
SendMessage = ctypes.windll.user32.SendMessageW
@ -43,12 +40,10 @@ SMTO_NOTIMEOUTIFNOTHUNG = 0x08
SMTO_ERRORONEXIT = 0x20
SendMessageTimeout = ctypes.windll.user32.SendMessageTimeoutW
SendMessageTimeout.argtypes = SendMessage.argtypes + (
UINT, UINT, ctypes.POINTER(DWORD)
)
SendMessageTimeout.argtypes = SendMessage.argtypes + (UINT, UINT, ctypes.POINTER(DWORD))
SendMessageTimeout.restype = LRESULT
def unicode_as_lparam(source):
pointer = ctypes.cast(ctypes.c_wchar_p(source), ctypes.c_void_p)
return LPARAM(pointer.value)
pointer = ctypes.cast(ctypes.c_wchar_p(source), ctypes.c_void_p)
return LPARAM(pointer.value)

View file

@ -7,24 +7,24 @@ RESOURCETYPE_ANY = 0
class NETRESOURCE(ctypes.Structure):
_fields_ = [
('scope', ctypes.wintypes.DWORD),
('type', ctypes.wintypes.DWORD),
('display_type', ctypes.wintypes.DWORD),
('usage', ctypes.wintypes.DWORD),
('local_name', ctypes.wintypes.LPWSTR),
('remote_name', ctypes.wintypes.LPWSTR),
('comment', ctypes.wintypes.LPWSTR),
('provider', ctypes.wintypes.LPWSTR),
]
_fields_ = [
('scope', ctypes.wintypes.DWORD),
('type', ctypes.wintypes.DWORD),
('display_type', ctypes.wintypes.DWORD),
('usage', ctypes.wintypes.DWORD),
('local_name', ctypes.wintypes.LPWSTR),
('remote_name', ctypes.wintypes.LPWSTR),
('comment', ctypes.wintypes.LPWSTR),
('provider', ctypes.wintypes.LPWSTR),
]
LPNETRESOURCE = ctypes.POINTER(NETRESOURCE)
WNetAddConnection2 = mpr.WNetAddConnection2W
WNetAddConnection2.argtypes = (
LPNETRESOURCE,
ctypes.wintypes.LPCWSTR,
ctypes.wintypes.LPCWSTR,
ctypes.wintypes.DWORD,
LPNETRESOURCE,
ctypes.wintypes.LPCWSTR,
ctypes.wintypes.LPCWSTR,
ctypes.wintypes.DWORD,
)

View file

@ -2,24 +2,23 @@ import ctypes.wintypes
class SYSTEM_POWER_STATUS(ctypes.Structure):
_fields_ = (
('ac_line_status', ctypes.wintypes.BYTE),
('battery_flag', ctypes.wintypes.BYTE),
('battery_life_percent', ctypes.wintypes.BYTE),
('reserved', ctypes.wintypes.BYTE),
('battery_life_time', ctypes.wintypes.DWORD),
('battery_full_life_time', ctypes.wintypes.DWORD),
)
_fields_ = (
('ac_line_status', ctypes.wintypes.BYTE),
('battery_flag', ctypes.wintypes.BYTE),
('battery_life_percent', ctypes.wintypes.BYTE),
('reserved', ctypes.wintypes.BYTE),
('battery_life_time', ctypes.wintypes.DWORD),
('battery_full_life_time', ctypes.wintypes.DWORD),
)
@property
def ac_line_status_string(self):
return {
0: 'offline', 1: 'online', 255: 'unknown'}[self.ac_line_status]
@property
def ac_line_status_string(self):
return {0: 'offline', 1: 'online', 255: 'unknown'}[self.ac_line_status]
LPSYSTEM_POWER_STATUS = ctypes.POINTER(SYSTEM_POWER_STATUS)
GetSystemPowerStatus = ctypes.windll.kernel32.GetSystemPowerStatus
GetSystemPowerStatus.argtypes = LPSYSTEM_POWER_STATUS,
GetSystemPowerStatus.argtypes = (LPSYSTEM_POWER_STATUS,)
GetSystemPowerStatus.restype = ctypes.wintypes.BOOL
SetThreadExecutionState = ctypes.windll.kernel32.SetThreadExecutionState
@ -28,10 +27,11 @@ SetThreadExecutionState.restype = ctypes.c_uint
class ES:
"""
Execution state constants
"""
continuous = 0x80000000
system_required = 1
display_required = 2
awaymode_required = 0x40
"""
Execution state constants
"""
continuous = 0x80000000
system_required = 1
display_required = 2
awaymode_required = 0x40

View file

@ -2,35 +2,32 @@ import ctypes.wintypes
class LUID(ctypes.Structure):
_fields_ = [
('low_part', ctypes.wintypes.DWORD),
('high_part', ctypes.wintypes.LONG),
]
_fields_ = [
('low_part', ctypes.wintypes.DWORD),
('high_part', ctypes.wintypes.LONG),
]
def __eq__(self, other):
return (
self.high_part == other.high_part and
self.low_part == other.low_part
)
def __eq__(self, other):
return self.high_part == other.high_part and self.low_part == other.low_part
def __ne__(self, other):
return not (self == other)
def __ne__(self, other):
return not (self == other)
LookupPrivilegeValue = ctypes.windll.advapi32.LookupPrivilegeValueW
LookupPrivilegeValue.argtypes = (
ctypes.wintypes.LPWSTR, # system name
ctypes.wintypes.LPWSTR, # name
ctypes.POINTER(LUID),
ctypes.wintypes.LPWSTR, # system name
ctypes.wintypes.LPWSTR, # name
ctypes.POINTER(LUID),
)
LookupPrivilegeValue.restype = ctypes.wintypes.BOOL
class TOKEN_INFORMATION_CLASS:
TokenUser = 1
TokenGroups = 2
TokenPrivileges = 3
# ... see http://msdn.microsoft.com/en-us/library/aa379626%28VS.85%29.aspx
TokenUser = 1
TokenGroups = 2
TokenPrivileges = 3
# ... see http://msdn.microsoft.com/en-us/library/aa379626%28VS.85%29.aspx
SE_PRIVILEGE_ENABLED_BY_DEFAULT = 0x00000001
@ -40,67 +37,63 @@ SE_PRIVILEGE_USED_FOR_ACCESS = 0x80000000
class LUID_AND_ATTRIBUTES(ctypes.Structure):
_fields_ = [
('LUID', LUID),
('attributes', ctypes.wintypes.DWORD),
]
_fields_ = [('LUID', LUID), ('attributes', ctypes.wintypes.DWORD)]
def is_enabled(self):
return bool(self.attributes & SE_PRIVILEGE_ENABLED)
def is_enabled(self):
return bool(self.attributes & SE_PRIVILEGE_ENABLED)
def enable(self):
self.attributes |= SE_PRIVILEGE_ENABLED
def enable(self):
self.attributes |= SE_PRIVILEGE_ENABLED
def get_name(self):
size = ctypes.wintypes.DWORD(10240)
buf = ctypes.create_unicode_buffer(size.value)
res = LookupPrivilegeName(None, self.LUID, buf, size)
if res == 0:
raise RuntimeError
return buf[:size.value]
def get_name(self):
size = ctypes.wintypes.DWORD(10240)
buf = ctypes.create_unicode_buffer(size.value)
res = LookupPrivilegeName(None, self.LUID, buf, size)
if res == 0:
raise RuntimeError
return buf[: size.value]
def __str__(self):
res = self.get_name()
if self.is_enabled():
res += ' (enabled)'
return res
def __str__(self):
res = self.get_name()
if self.is_enabled():
res += ' (enabled)'
return res
LookupPrivilegeName = ctypes.windll.advapi32.LookupPrivilegeNameW
LookupPrivilegeName.argtypes = (
ctypes.wintypes.LPWSTR, # lpSystemName
ctypes.POINTER(LUID), # lpLuid
ctypes.wintypes.LPWSTR, # lpName
ctypes.POINTER(ctypes.wintypes.DWORD), # cchName
ctypes.wintypes.LPWSTR, # lpSystemName
ctypes.POINTER(LUID), # lpLuid
ctypes.wintypes.LPWSTR, # lpName
ctypes.POINTER(ctypes.wintypes.DWORD), # cchName
)
LookupPrivilegeName.restype = ctypes.wintypes.BOOL
class TOKEN_PRIVILEGES(ctypes.Structure):
_fields_ = [
('count', ctypes.wintypes.DWORD),
('privileges', LUID_AND_ATTRIBUTES * 0),
]
_fields_ = [
('count', ctypes.wintypes.DWORD),
('privileges', LUID_AND_ATTRIBUTES * 0),
]
def get_array(self):
array_type = LUID_AND_ATTRIBUTES * self.count
privileges = ctypes.cast(
self.privileges, ctypes.POINTER(array_type)).contents
return privileges
def get_array(self):
array_type = LUID_AND_ATTRIBUTES * self.count
privileges = ctypes.cast(self.privileges, ctypes.POINTER(array_type)).contents
return privileges
def __iter__(self):
return iter(self.get_array())
def __iter__(self):
return iter(self.get_array())
PTOKEN_PRIVILEGES = ctypes.POINTER(TOKEN_PRIVILEGES)
GetTokenInformation = ctypes.windll.advapi32.GetTokenInformation
GetTokenInformation.argtypes = [
ctypes.wintypes.HANDLE, # TokenHandle
ctypes.c_uint, # TOKEN_INFORMATION_CLASS value
ctypes.c_void_p, # TokenInformation
ctypes.wintypes.DWORD, # TokenInformationLength
ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength
ctypes.wintypes.HANDLE, # TokenHandle
ctypes.c_uint, # TOKEN_INFORMATION_CLASS value
ctypes.c_void_p, # TokenInformation
ctypes.wintypes.DWORD, # TokenInformationLength
ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength
]
GetTokenInformation.restype = ctypes.wintypes.BOOL
@ -108,10 +101,10 @@ GetTokenInformation.restype = ctypes.wintypes.BOOL
AdjustTokenPrivileges = ctypes.windll.advapi32.AdjustTokenPrivileges
AdjustTokenPrivileges.restype = ctypes.wintypes.BOOL
AdjustTokenPrivileges.argtypes = [
ctypes.wintypes.HANDLE, # TokenHandle
ctypes.wintypes.BOOL, # DisableAllPrivileges
PTOKEN_PRIVILEGES, # NewState (optional)
ctypes.wintypes.DWORD, # BufferLength of PreviousState
PTOKEN_PRIVILEGES, # PreviousState (out, optional)
ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength
ctypes.wintypes.HANDLE, # TokenHandle
ctypes.wintypes.BOOL, # DisableAllPrivileges
PTOKEN_PRIVILEGES, # NewState (optional)
ctypes.wintypes.DWORD, # BufferLength of PreviousState
PTOKEN_PRIVILEGES, # PreviousState (out, optional)
ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength
]

View file

@ -1,11 +1,13 @@
import ctypes.wintypes
TOKEN_ALL_ACCESS = 0xf01ff
TOKEN_ALL_ACCESS = 0xF01FF
GetCurrentProcess = ctypes.windll.kernel32.GetCurrentProcess
GetCurrentProcess.restype = ctypes.wintypes.HANDLE
OpenProcessToken = ctypes.windll.advapi32.OpenProcessToken
OpenProcessToken.argtypes = (
ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD,
ctypes.POINTER(ctypes.wintypes.HANDLE))
ctypes.wintypes.HANDLE,
ctypes.wintypes.DWORD,
ctypes.POINTER(ctypes.wintypes.HANDLE),
)
OpenProcessToken.restype = ctypes.wintypes.BOOL

View file

@ -24,116 +24,117 @@ POLICY_LOOKUP_NAMES = 0x00000800
POLICY_NOTIFICATION = 0x00001000
POLICY_ALL_ACCESS = (
STANDARD_RIGHTS_REQUIRED |
POLICY_VIEW_LOCAL_INFORMATION |
POLICY_VIEW_AUDIT_INFORMATION |
POLICY_GET_PRIVATE_INFORMATION |
POLICY_TRUST_ADMIN |
POLICY_CREATE_ACCOUNT |
POLICY_CREATE_SECRET |
POLICY_CREATE_PRIVILEGE |
POLICY_SET_DEFAULT_QUOTA_LIMITS |
POLICY_SET_AUDIT_REQUIREMENTS |
POLICY_AUDIT_LOG_ADMIN |
POLICY_SERVER_ADMIN |
POLICY_LOOKUP_NAMES)
STANDARD_RIGHTS_REQUIRED
| POLICY_VIEW_LOCAL_INFORMATION
| POLICY_VIEW_AUDIT_INFORMATION
| POLICY_GET_PRIVATE_INFORMATION
| POLICY_TRUST_ADMIN
| POLICY_CREATE_ACCOUNT
| POLICY_CREATE_SECRET
| POLICY_CREATE_PRIVILEGE
| POLICY_SET_DEFAULT_QUOTA_LIMITS
| POLICY_SET_AUDIT_REQUIREMENTS
| POLICY_AUDIT_LOG_ADMIN
| POLICY_SERVER_ADMIN
| POLICY_LOOKUP_NAMES
)
POLICY_READ = (
STANDARD_RIGHTS_READ |
POLICY_VIEW_AUDIT_INFORMATION |
POLICY_GET_PRIVATE_INFORMATION)
STANDARD_RIGHTS_READ
| POLICY_VIEW_AUDIT_INFORMATION
| POLICY_GET_PRIVATE_INFORMATION
)
POLICY_WRITE = (
STANDARD_RIGHTS_WRITE |
POLICY_TRUST_ADMIN |
POLICY_CREATE_ACCOUNT |
POLICY_CREATE_SECRET |
POLICY_CREATE_PRIVILEGE |
POLICY_SET_DEFAULT_QUOTA_LIMITS |
POLICY_SET_AUDIT_REQUIREMENTS |
POLICY_AUDIT_LOG_ADMIN |
POLICY_SERVER_ADMIN)
STANDARD_RIGHTS_WRITE
| POLICY_TRUST_ADMIN
| POLICY_CREATE_ACCOUNT
| POLICY_CREATE_SECRET
| POLICY_CREATE_PRIVILEGE
| POLICY_SET_DEFAULT_QUOTA_LIMITS
| POLICY_SET_AUDIT_REQUIREMENTS
| POLICY_AUDIT_LOG_ADMIN
| POLICY_SERVER_ADMIN
)
POLICY_EXECUTE = (
STANDARD_RIGHTS_EXECUTE |
POLICY_VIEW_LOCAL_INFORMATION |
POLICY_LOOKUP_NAMES)
STANDARD_RIGHTS_EXECUTE | POLICY_VIEW_LOCAL_INFORMATION | POLICY_LOOKUP_NAMES
)
class TokenAccess:
TOKEN_QUERY = 0x8
TOKEN_QUERY = 0x8
class TokenInformationClass:
TokenUser = 1
TokenUser = 1
class TOKEN_USER(ctypes.Structure):
num = 1
_fields_ = [
('SID', ctypes.c_void_p),
('ATTRIBUTES', ctypes.wintypes.DWORD),
]
num = 1
_fields_ = [('SID', ctypes.c_void_p), ('ATTRIBUTES', ctypes.wintypes.DWORD)]
class SECURITY_DESCRIPTOR(ctypes.Structure):
"""
typedef struct _SECURITY_DESCRIPTOR
{
UCHAR Revision;
UCHAR Sbz1;
SECURITY_DESCRIPTOR_CONTROL Control;
PSID Owner;
PSID Group;
PACL Sacl;
PACL Dacl;
} SECURITY_DESCRIPTOR;
"""
SECURITY_DESCRIPTOR_CONTROL = ctypes.wintypes.USHORT
REVISION = 1
"""
typedef struct _SECURITY_DESCRIPTOR
{
UCHAR Revision;
UCHAR Sbz1;
SECURITY_DESCRIPTOR_CONTROL Control;
PSID Owner;
PSID Group;
PACL Sacl;
PACL Dacl;
} SECURITY_DESCRIPTOR;
"""
_fields_ = [
('Revision', ctypes.c_ubyte),
('Sbz1', ctypes.c_ubyte),
('Control', SECURITY_DESCRIPTOR_CONTROL),
('Owner', ctypes.c_void_p),
('Group', ctypes.c_void_p),
('Sacl', ctypes.c_void_p),
('Dacl', ctypes.c_void_p),
]
SECURITY_DESCRIPTOR_CONTROL = ctypes.wintypes.USHORT
REVISION = 1
_fields_ = [
('Revision', ctypes.c_ubyte),
('Sbz1', ctypes.c_ubyte),
('Control', SECURITY_DESCRIPTOR_CONTROL),
('Owner', ctypes.c_void_p),
('Group', ctypes.c_void_p),
('Sacl', ctypes.c_void_p),
('Dacl', ctypes.c_void_p),
]
class SECURITY_ATTRIBUTES(ctypes.Structure):
"""
typedef struct _SECURITY_ATTRIBUTES {
DWORD nLength;
LPVOID lpSecurityDescriptor;
BOOL bInheritHandle;
} SECURITY_ATTRIBUTES;
"""
_fields_ = [
('nLength', ctypes.wintypes.DWORD),
('lpSecurityDescriptor', ctypes.c_void_p),
('bInheritHandle', ctypes.wintypes.BOOL),
]
"""
typedef struct _SECURITY_ATTRIBUTES {
DWORD nLength;
LPVOID lpSecurityDescriptor;
BOOL bInheritHandle;
} SECURITY_ATTRIBUTES;
"""
def __init__(self, *args, **kwargs):
super(SECURITY_ATTRIBUTES, self).__init__(*args, **kwargs)
self.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES)
_fields_ = [
('nLength', ctypes.wintypes.DWORD),
('lpSecurityDescriptor', ctypes.c_void_p),
('bInheritHandle', ctypes.wintypes.BOOL),
]
@property
def descriptor(self):
return self._descriptor
def __init__(self, *args, **kwargs):
super(SECURITY_ATTRIBUTES, self).__init__(*args, **kwargs)
self.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES)
@descriptor.setter
def descriptor(self, value):
self._descriptor = value
self.lpSecurityDescriptor = ctypes.addressof(value)
@property
def descriptor(self):
return self._descriptor
@descriptor.setter
def descriptor(self, value):
self._descriptor = value
self.lpSecurityDescriptor = ctypes.addressof(value)
ctypes.windll.advapi32.SetSecurityDescriptorOwner.argtypes = (
ctypes.POINTER(SECURITY_DESCRIPTOR),
ctypes.c_void_p,
ctypes.wintypes.BOOL,
ctypes.POINTER(SECURITY_DESCRIPTOR),
ctypes.c_void_p,
ctypes.wintypes.BOOL,
)

View file

@ -1,39 +1,40 @@
import ctypes.wintypes
BOOL = ctypes.wintypes.BOOL
class SHELLSTATE(ctypes.Structure):
_fields_ = [
('show_all_objects', BOOL, 1),
('show_extensions', BOOL, 1),
('no_confirm_recycle', BOOL, 1),
('show_sys_files', BOOL, 1),
('show_comp_color', BOOL, 1),
('double_click_in_web_view', BOOL, 1),
('desktop_HTML', BOOL, 1),
('win95_classic', BOOL, 1),
('dont_pretty_path', BOOL, 1),
('show_attrib_col', BOOL, 1),
('map_network_drive_button', BOOL, 1),
('show_info_tip', BOOL, 1),
('hide_icons', BOOL, 1),
('web_view', BOOL, 1),
('filter', BOOL, 1),
('show_super_hidden', BOOL, 1),
('no_net_crawling', BOOL, 1),
('win95_unused', ctypes.wintypes.DWORD),
('param_sort', ctypes.wintypes.LONG),
('sort_direction', ctypes.c_int),
('version', ctypes.wintypes.UINT),
('not_used', ctypes.wintypes.UINT),
('sep_process', BOOL, 1),
('start_panel_on', BOOL, 1),
('show_start_page', BOOL, 1),
('auto_check_select', BOOL, 1),
('icons_only', BOOL, 1),
('show_type_overlay', BOOL, 1),
('spare_flags', ctypes.wintypes.UINT, 13),
]
_fields_ = [
('show_all_objects', BOOL, 1),
('show_extensions', BOOL, 1),
('no_confirm_recycle', BOOL, 1),
('show_sys_files', BOOL, 1),
('show_comp_color', BOOL, 1),
('double_click_in_web_view', BOOL, 1),
('desktop_HTML', BOOL, 1),
('win95_classic', BOOL, 1),
('dont_pretty_path', BOOL, 1),
('show_attrib_col', BOOL, 1),
('map_network_drive_button', BOOL, 1),
('show_info_tip', BOOL, 1),
('hide_icons', BOOL, 1),
('web_view', BOOL, 1),
('filter', BOOL, 1),
('show_super_hidden', BOOL, 1),
('no_net_crawling', BOOL, 1),
('win95_unused', ctypes.wintypes.DWORD),
('param_sort', ctypes.wintypes.LONG),
('sort_direction', ctypes.c_int),
('version', ctypes.wintypes.UINT),
('not_used', ctypes.wintypes.UINT),
('sep_process', BOOL, 1),
('start_panel_on', BOOL, 1),
('show_start_page', BOOL, 1),
('auto_check_select', BOOL, 1),
('icons_only', BOOL, 1),
('show_type_overlay', BOOL, 1),
('spare_flags', ctypes.wintypes.UINT, 13),
]
SSF_SHOWALLOBJECTS = 0x00000001
@ -123,8 +124,8 @@ SSF_SHOWTYPEOVERLAY = 0x02000000
SHGetSetSettings = ctypes.windll.shell32.SHGetSetSettings
SHGetSetSettings.argtypes = [
ctypes.POINTER(SHELLSTATE),
ctypes.wintypes.DWORD,
ctypes.wintypes.BOOL, # get or set (True: set)
ctypes.POINTER(SHELLSTATE),
ctypes.wintypes.DWORD,
ctypes.wintypes.BOOL, # get or set (True: set)
]
SHGetSetSettings.restype = None

View file

@ -2,10 +2,10 @@ import ctypes.wintypes
SystemParametersInfo = ctypes.windll.user32.SystemParametersInfoW
SystemParametersInfo.argtypes = (
ctypes.wintypes.UINT,
ctypes.wintypes.UINT,
ctypes.c_void_p,
ctypes.wintypes.UINT,
ctypes.wintypes.UINT,
ctypes.wintypes.UINT,
ctypes.c_void_p,
ctypes.wintypes.UINT,
)
SPI_GETACTIVEWINDOWTRACKING = 0x1000

View file

@ -1,9 +1,9 @@
import ctypes.wintypes
try:
from ctypes.wintypes import LPDWORD
from ctypes.wintypes import LPDWORD
except ImportError:
LPDWORD = ctypes.POINTER(ctypes.wintypes.DWORD)
LPDWORD = ctypes.POINTER(ctypes.wintypes.DWORD) # type: ignore
GetUserName = ctypes.windll.advapi32.GetUserNameW
GetUserName.argtypes = ctypes.wintypes.LPWSTR, LPDWORD

View file

@ -0,0 +1,39 @@
import subprocess
import itertools
from more_itertools import consume, always_iterable
def extract_environment(env_cmd, initial=None):
"""
Take a command (either a single command or list of arguments)
and return the environment created after running that command.
Note that if the command must be a batch file or .cmd file, or the
changes to the environment will not be captured.
If initial is supplied, it is used as the initial environment passed
to the child process.
"""
# construct the command that will alter the environment
env_cmd = subprocess.list2cmdline(always_iterable(env_cmd))
# create a tag so we can tell in the output when the proc is done
tag = 'Done running command'
# construct a cmd.exe command to do accomplish this
cmd = 'cmd.exe /s /c "{env_cmd} && echo "{tag}" && set"'.format(**vars())
# launch the process
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=initial)
# parse the output sent to stdout
lines = proc.stdout
# make sure the lines are strings
def make_str(s):
return s.decode()
lines = map(make_str, lines)
# consume whatever output occurs until the tag is reached
consume(itertools.takewhile(lambda l: tag not in l, lines))
# construct a dictionary of the pairs
result = dict(line.rstrip().split('=', 1) for line in lines)
# let the process finish
proc.communicate()
return result

View file

@ -1,5 +1,3 @@
from __future__ import with_statement, print_function
import sys
import re
import itertools
@ -7,220 +5,265 @@ from contextlib import contextmanager
import io
import ctypes
from ctypes import windll
import six
from six.moves import map
import textwrap
import collections
from jaraco.windows.api import clipboard, memory
from jaraco.windows.error import handle_nonzero_success, WindowsError
from jaraco.windows.memory import LockedMemory
__all__ = (
'GetClipboardData', 'CloseClipboard',
'SetClipboardData', 'OpenClipboard',
)
__all__ = ('GetClipboardData', 'CloseClipboard', 'SetClipboardData', 'OpenClipboard')
def OpenClipboard(owner=None):
"""
Open the clipboard.
"""
Open the clipboard.
owner
[in] Handle to the window to be associated with the open clipboard.
If this parameter is None, the open clipboard is associated with the
current task.
"""
handle_nonzero_success(windll.user32.OpenClipboard(owner))
owner
[in] Handle to the window to be associated with the open clipboard.
If this parameter is None, the open clipboard is associated with the
current task.
"""
handle_nonzero_success(windll.user32.OpenClipboard(owner))
def CloseClipboard():
handle_nonzero_success(windll.user32.CloseClipboard())
handle_nonzero_success(windll.user32.CloseClipboard())
data_handlers = dict()
def handles(*formats):
def register(func):
for format in formats:
data_handlers[format] = func
return func
return register
def register(func):
for format in formats:
data_handlers[format] = func
return func
return register
def nts(buffer):
"""
Null Terminated String
Get the portion of bytestring buffer up to a null character.
"""
result, null, rest = buffer.partition('\x00')
return result
"""
Null Terminated String
Get the portion of bytestring buffer up to a null character.
"""
result, null, rest = buffer.partition('\x00')
return result
@handles(clipboard.CF_DIBV5, clipboard.CF_DIB)
def raw_data(handle):
return LockedMemory(handle).data
return LockedMemory(handle).data
@handles(clipboard.CF_TEXT)
def text_string(handle):
return nts(raw_data(handle))
return nts(raw_data(handle))
@handles(clipboard.CF_UNICODETEXT)
def unicode_string(handle):
return nts(raw_data(handle).decode('utf-16'))
return nts(raw_data(handle).decode('utf-16'))
@handles(clipboard.CF_BITMAP)
def as_bitmap(handle):
# handle is HBITMAP
raise NotImplementedError("Can't convert to DIB")
# todo: use GetDIBits http://msdn.microsoft.com
# /en-us/library/dd144879%28v=VS.85%29.aspx
# handle is HBITMAP
raise NotImplementedError("Can't convert to DIB")
# todo: use GetDIBits http://msdn.microsoft.com
# /en-us/library/dd144879%28v=VS.85%29.aspx
@handles(clipboard.CF_HTML)
class HTMLSnippet(object):
def __init__(self, handle):
self.data = nts(raw_data(handle).decode('utf-8'))
self.headers = self.parse_headers(self.data)
"""
HTML Snippet representing the Microsoft `HTML snippet format
<https://docs.microsoft.com/en-us/windows/win32/dataxchg/html-clipboard-format>`_.
"""
@property
def html(self):
return self.data[self.headers['StartHTML']:]
def __init__(self, handle):
self.data = nts(raw_data(handle).decode('utf-8'))
self.headers = self.parse_headers(self.data)
@staticmethod
def parse_headers(data):
d = io.StringIO(data)
@property
def html(self):
return self.data[self.headers['StartHTML'] :]
def header_line(line):
return re.match('(\w+):(.*)', line)
headers = map(header_line, d)
# grab headers until they no longer match
headers = itertools.takewhile(bool, headers)
@property
def fragment(self):
return self.data[self.headers['StartFragment'] : self.headers['EndFragment']]
def best_type(value):
try:
return int(value)
except ValueError:
pass
try:
return float(value)
except ValueError:
pass
return value
pairs = (
(header.group(1), best_type(header.group(2)))
for header
in headers
)
return dict(pairs)
@staticmethod
def parse_headers(data):
d = io.StringIO(data)
def header_line(line):
return re.match(r'(\w+):(.*)', line)
headers = map(header_line, d)
# grab headers until they no longer match
headers = itertools.takewhile(bool, headers)
def best_type(value):
try:
return int(value)
except ValueError:
pass
try:
return float(value)
except ValueError:
pass
return value
pairs = ((header.group(1), best_type(header.group(2))) for header in headers)
return dict(pairs)
@classmethod
def from_string(cls, source):
"""
Construct an HTMLSnippet with all the headers, modeled after
https://docs.microsoft.com/en-us/troubleshoot/cpp/add-html-code-clipboard
"""
tmpl = textwrap.dedent(
"""
Version:0.9
StartHTML:{start_html:08d}
EndHTML:{end_html:08d}
StartFragment:{start_fragment:08d}
EndFragment:{end_fragment:08d}
<html><body>
<!--StartFragment -->
{source}
<!--EndFragment -->
</body></html>
"""
).strip()
zeros = collections.defaultdict(lambda: 0, locals())
pre_value = tmpl.format_map(zeros)
start_html = pre_value.find('<html>')
end_html = len(tmpl)
assert end_html < 100000000
start_fragment = pre_value.find(source)
end_fragment = pre_value.rfind('\n<!--EndFragment')
tmpl_length = len(tmpl) - len('{source}')
snippet = cls.__new__(cls)
snippet.data = tmpl.format_map(locals())
snippet.headers = cls.parse_headers(snippet.data)
return snippet
def GetClipboardData(type=clipboard.CF_UNICODETEXT):
if type not in data_handlers:
raise NotImplementedError("No support for data of type %d" % type)
handle = clipboard.GetClipboardData(type)
if handle is None:
raise TypeError("No clipboard data of type %d" % type)
return data_handlers[type](handle)
if type not in data_handlers:
raise NotImplementedError("No support for data of type %d" % type)
handle = clipboard.GetClipboardData(type)
if handle is None:
raise TypeError("No clipboard data of type %d" % type)
return data_handlers[type](handle)
def EmptyClipboard():
handle_nonzero_success(windll.user32.EmptyClipboard())
handle_nonzero_success(windll.user32.EmptyClipboard())
def SetClipboardData(type, content):
"""
Modeled after http://msdn.microsoft.com
/en-us/library/ms649016%28VS.85%29.aspx
#_win32_Copying_Information_to_the_Clipboard
"""
allocators = {
clipboard.CF_TEXT: ctypes.create_string_buffer,
clipboard.CF_UNICODETEXT: ctypes.create_unicode_buffer,
clipboard.CF_HTML: ctypes.create_string_buffer,
}
if type not in allocators:
raise NotImplementedError(
"Only text and HTML types are supported at this time")
# allocate the memory for the data
content = allocators[type](content)
flags = memory.GMEM_MOVEABLE
size = ctypes.sizeof(content)
handle_to_copy = windll.kernel32.GlobalAlloc(flags, size)
with LockedMemory(handle_to_copy) as lm:
ctypes.memmove(lm.data_ptr, content, size)
result = clipboard.SetClipboardData(type, handle_to_copy)
if result is None:
raise WindowsError()
"""
Modeled after http://msdn.microsoft.com
/en-us/library/ms649016%28VS.85%29.aspx
#_win32_Copying_Information_to_the_Clipboard
"""
allocators = {
clipboard.CF_TEXT: ctypes.create_string_buffer,
clipboard.CF_UNICODETEXT: ctypes.create_unicode_buffer,
clipboard.CF_HTML: ctypes.create_string_buffer,
}
if type not in allocators:
raise NotImplementedError("Only text and HTML types are supported at this time")
# allocate the memory for the data
content = allocators[type](content)
flags = memory.GMEM_MOVEABLE
size = ctypes.sizeof(content)
handle_to_copy = windll.kernel32.GlobalAlloc(flags, size)
with LockedMemory(handle_to_copy) as lm:
ctypes.memmove(lm.data_ptr, content, size)
result = clipboard.SetClipboardData(type, handle_to_copy)
if result is None:
raise WindowsError()
def set_text(source):
with context():
EmptyClipboard()
SetClipboardData(clipboard.CF_TEXT, source)
with context():
EmptyClipboard()
SetClipboardData(clipboard.CF_TEXT, source)
def get_text():
with context():
result = GetClipboardData(clipboard.CF_TEXT)
return result
with context():
result = GetClipboardData(clipboard.CF_TEXT)
return result
def set_unicode_text(source):
with context():
EmptyClipboard()
SetClipboardData(clipboard.CF_UNICODETEXT, source)
with context():
EmptyClipboard()
SetClipboardData(clipboard.CF_UNICODETEXT, source)
def get_unicode_text():
with context():
return GetClipboardData()
with context():
return GetClipboardData()
def get_html():
with context():
result = GetClipboardData(clipboard.CF_HTML)
return result
"""
>>> set_html('<b>foo</b>')
>>> get_html().html
'<html><body>...<b>foo</b>...</body></html>'
>>> get_html().fragment
'<b>foo</b>'
"""
with context():
result = GetClipboardData(clipboard.CF_HTML)
return result
def set_html(source):
with context():
EmptyClipboard()
SetClipboardData(clipboard.CF_UNICODETEXT, source)
"""
>>> set_html('<b>foo</b>')
"""
snippet = HTMLSnippet.from_string(source)
with context():
EmptyClipboard()
SetClipboardData(clipboard.CF_HTML, snippet.data.encode('utf-8'))
def get_image():
with context():
return GetClipboardData(clipboard.CF_DIB)
with context():
return GetClipboardData(clipboard.CF_DIB)
def paste_stdout():
getter = get_unicode_text if six.PY3 else get_text
sys.stdout.write(getter())
sys.stdout.write(get_unicode_text())
def stdin_copy():
setter = set_unicode_text if six.PY3 else set_text
setter(sys.stdin.read())
set_unicode_text(sys.stdin.read())
@contextmanager
def context():
OpenClipboard()
try:
yield
finally:
CloseClipboard()
OpenClipboard()
try:
yield
finally:
CloseClipboard()
def get_formats():
with context():
format_index = 0
while True:
format_index = clipboard.EnumClipboardFormats(format_index)
if format_index == 0:
break
yield format_index
with context():
format_index = 0
while True:
format_index = clipboard.EnumClipboardFormats(format_index)
if format_index == 0:
break
yield format_index

View file

@ -7,16 +7,16 @@ CRED_TYPE_GENERIC = 1
def CredDelete(TargetName, Type, Flags=0):
error.handle_nonzero_success(api.CredDelete(TargetName, Type, Flags))
error.handle_nonzero_success(api.CredDelete(TargetName, Type, Flags))
def CredRead(TargetName, Type, Flags=0):
cred_pointer = api.PCREDENTIAL()
res = api.CredRead(TargetName, Type, Flags, ctypes.byref(cred_pointer))
error.handle_nonzero_success(res)
return cred_pointer.contents
cred_pointer = api.PCREDENTIAL()
res = api.CredRead(TargetName, Type, Flags, ctypes.byref(cred_pointer))
error.handle_nonzero_success(res)
return cred_pointer.contents
def CredWrite(Credential, Flags=0):
res = api.CredWrite(Credential, Flags)
error.handle_nonzero_success(res)
res = api.CredWrite(Credential, Flags)
error.handle_nonzero_success(res)

View file

@ -1,4 +1,3 @@
"""
Python routines to interface with the Microsoft
Data Protection API (DPAPI).
@ -20,80 +19,77 @@ __import__('jaraco.windows.api.memory')
class DATA_BLOB(ctypes.Structure):
r"""
A data blob structure for use with MS DPAPI functions.
r"""
A data blob structure for use with MS DPAPI functions.
Initialize with string of characters
>>> input = b'abc123\x00456'
>>> blob = DATA_BLOB(input)
>>> len(blob)
10
>>> blob.get_data() == input
True
"""
_fields_ = [
('data_size', wintypes.DWORD),
('data', ctypes.c_void_p),
]
Initialize with string of characters
>>> input = b'abc123\x00456'
>>> blob = DATA_BLOB(input)
>>> len(blob)
10
>>> blob.get_data() == input
True
"""
_fields_ = [('data_size', wintypes.DWORD), ('data', ctypes.c_void_p)]
def __init__(self, data=None):
super(DATA_BLOB, self).__init__()
self.set_data(data)
def __init__(self, data=None):
super(DATA_BLOB, self).__init__()
self.set_data(data)
def set_data(self, data):
"Use this method to set the data for this blob"
if data is None:
self.data_size = 0
self.data = None
return
self.data_size = len(data)
# create a string buffer so that null bytes aren't interpreted
# as the end of the string
self.data = ctypes.cast(ctypes.create_string_buffer(data), ctypes.c_void_p)
def set_data(self, data):
"Use this method to set the data for this blob"
if data is None:
self.data_size = 0
self.data = None
return
self.data_size = len(data)
# create a string buffer so that null bytes aren't interpreted
# as the end of the string
self.data = ctypes.cast(ctypes.create_string_buffer(data), ctypes.c_void_p)
def get_data(self):
"Get the data for this blob"
array = ctypes.POINTER(ctypes.c_char * len(self))
return ctypes.cast(self.data, array).contents.raw
def get_data(self):
"Get the data for this blob"
array = ctypes.POINTER(ctypes.c_char * len(self))
return ctypes.cast(self.data, array).contents.raw
def __len__(self):
return self.data_size
def __len__(self):
return self.data_size
def __str__(self):
return self.get_data()
def __str__(self):
return self.get_data()
def free(self):
"""
"data out" blobs have locally-allocated memory.
Call this method to free the memory allocated by CryptProtectData
and CryptUnprotectData.
"""
ctypes.windll.kernel32.LocalFree(self.data)
def free(self):
"""
"data out" blobs have locally-allocated memory.
Call this method to free the memory allocated by CryptProtectData
and CryptUnprotectData.
"""
ctypes.windll.kernel32.LocalFree(self.data)
p_DATA_BLOB = ctypes.POINTER(DATA_BLOB)
_CryptProtectData = ctypes.windll.crypt32.CryptProtectData
_CryptProtectData.argtypes = [
p_DATA_BLOB, # data in
wintypes.LPCWSTR, # data description
p_DATA_BLOB, # optional entropy
ctypes.c_void_p, # reserved
ctypes.c_void_p, # POINTER(CRYPTPROTECT_PROMPTSTRUCT), # prompt struct
wintypes.DWORD, # flags
p_DATA_BLOB, # data out
p_DATA_BLOB, # data in
wintypes.LPCWSTR, # data description
p_DATA_BLOB, # optional entropy
ctypes.c_void_p, # reserved
ctypes.c_void_p, # POINTER(CRYPTPROTECT_PROMPTSTRUCT), # prompt struct
wintypes.DWORD, # flags
p_DATA_BLOB, # data out
]
_CryptProtectData.restype = wintypes.BOOL
_CryptUnprotectData = ctypes.windll.crypt32.CryptUnprotectData
_CryptUnprotectData.argtypes = [
p_DATA_BLOB, # data in
ctypes.POINTER(wintypes.LPWSTR), # data description
p_DATA_BLOB, # optional entropy
ctypes.c_void_p, # reserved
ctypes.c_void_p, # POINTER(CRYPTPROTECT_PROMPTSTRUCT), # prompt struct
wintypes.DWORD, # flags
p_DATA_BLOB, # data out
p_DATA_BLOB, # data in
ctypes.POINTER(wintypes.LPWSTR), # data description
p_DATA_BLOB, # optional entropy
ctypes.c_void_p, # reserved
ctypes.c_void_p, # POINTER(CRYPTPROTECT_PROMPTSTRUCT), # prompt struct
wintypes.DWORD, # flags
p_DATA_BLOB, # data out
]
_CryptUnprotectData.restype = wintypes.BOOL
@ -101,55 +97,47 @@ CRYPTPROTECT_UI_FORBIDDEN = 0x01
def CryptProtectData(
data, description=None, optional_entropy=None,
prompt_struct=None, flags=0,
data, description=None, optional_entropy=None, prompt_struct=None, flags=0
):
"""
Encrypt data
"""
data_in = DATA_BLOB(data)
entropy = DATA_BLOB(optional_entropy) if optional_entropy else None
data_out = DATA_BLOB()
"""
Encrypt data
"""
data_in = DATA_BLOB(data)
entropy = DATA_BLOB(optional_entropy) if optional_entropy else None
data_out = DATA_BLOB()
res = _CryptProtectData(
data_in,
description,
entropy,
None, # reserved
prompt_struct,
flags,
data_out,
)
handle_nonzero_success(res)
res = data_out.get_data()
data_out.free()
return res
res = _CryptProtectData(
data_in, description, entropy, None, prompt_struct, flags, data_out # reserved
)
handle_nonzero_success(res)
res = data_out.get_data()
data_out.free()
return res
def CryptUnprotectData(
data, optional_entropy=None, prompt_struct=None, flags=0):
"""
Returns a tuple of (description, data) where description is the
the description that was passed to the CryptProtectData call and
data is the decrypted result.
"""
data_in = DATA_BLOB(data)
entropy = DATA_BLOB(optional_entropy) if optional_entropy else None
data_out = DATA_BLOB()
ptr_description = wintypes.LPWSTR()
res = _CryptUnprotectData(
data_in,
ctypes.byref(ptr_description),
entropy,
None, # reserved
prompt_struct,
flags | CRYPTPROTECT_UI_FORBIDDEN,
data_out,
)
handle_nonzero_success(res)
description = ptr_description.value
if ptr_description.value is not None:
ctypes.windll.kernel32.LocalFree(ptr_description)
res = data_out.get_data()
data_out.free()
return description, res
def CryptUnprotectData(data, optional_entropy=None, prompt_struct=None, flags=0):
"""
Returns a tuple of (description, data) where description is the
the description that was passed to the CryptProtectData call and
data is the decrypted result.
"""
data_in = DATA_BLOB(data)
entropy = DATA_BLOB(optional_entropy) if optional_entropy else None
data_out = DATA_BLOB()
ptr_description = wintypes.LPWSTR()
res = _CryptUnprotectData(
data_in,
ctypes.byref(ptr_description),
entropy,
None, # reserved
prompt_struct,
flags | CRYPTPROTECT_UI_FORBIDDEN,
data_out,
)
handle_nonzero_success(res)
description = ptr_description.value
if ptr_description.value is not None:
ctypes.windll.kernel32.LocalFree(ptr_description)
res = data_out.get_data()
data_out.free()
return description, res

View file

@ -1,14 +1,8 @@
#!/usr/bin/env python
from __future__ import absolute_import
import sys
import winreg
import ctypes
import ctypes.wintypes
import six
from six.moves import winreg
from jaraco.ui.editor import EditableFile
from jaraco.windows import error
@ -17,241 +11,247 @@ from .registry import key_values as registry_key_values
def SetEnvironmentVariable(name, value):
error.handle_nonzero_success(environ.SetEnvironmentVariable(name, value))
error.handle_nonzero_success(environ.SetEnvironmentVariable(name, value))
def ClearEnvironmentVariable(name):
error.handle_nonzero_success(environ.SetEnvironmentVariable(name, None))
error.handle_nonzero_success(environ.SetEnvironmentVariable(name, None))
def GetEnvironmentVariable(name):
max_size = 2**15 - 1
buffer = ctypes.create_unicode_buffer(max_size)
error.handle_nonzero_success(
environ.GetEnvironmentVariable(name, buffer, max_size))
return buffer.value
max_size = 2 ** 15 - 1
buffer = ctypes.create_unicode_buffer(max_size)
error.handle_nonzero_success(environ.GetEnvironmentVariable(name, buffer, max_size))
return buffer.value
###
class RegisteredEnvironment(object):
"""
Manages the environment variables as set in the Windows Registry.
"""
"""
Manages the environment variables as set in the Windows Registry.
"""
@classmethod
def show(class_):
for name, value, type in registry_key_values(class_.key):
sys.stdout.write('='.join((name, value)) + '\n')
@classmethod
def show(class_):
for name, value, type in registry_key_values(class_.key):
sys.stdout.write('='.join((name, value)) + '\n')
NoDefault = type('NoDefault', (object,), dict())
NoDefault = type('NoDefault', (object,), dict())
@classmethod
def get(class_, name, default=NoDefault):
try:
value, type = winreg.QueryValueEx(class_.key, name)
return value
except WindowsError:
if default is not class_.NoDefault:
return default
raise ValueError("No such key", name)
@classmethod
def get(class_, name, default=NoDefault):
try:
value, type = winreg.QueryValueEx(class_.key, name)
return value
except WindowsError:
if default is not class_.NoDefault:
return default
raise ValueError("No such key", name)
@classmethod
def get_values_list(class_, name, sep):
res = class_.get(name.upper(), [])
if isinstance(res, six.string_types):
res = res.split(sep)
return res
@classmethod
def get_values_list(class_, name, sep):
res = class_.get(name.upper(), [])
if isinstance(res, str):
res = res.split(sep)
return res
@classmethod
def set(class_, name, value, options):
# consider opening the key read-only except for here
# key = winreg.OpenKey(class_.key, None, 0, winreg.KEY_WRITE)
# and follow up by closing it.
if not value:
return class_.delete(name)
do_append = options.append or (
name.upper() in ('PATH', 'PATHEXT') and not options.replace
)
if do_append:
sep = ';'
values = class_.get_values_list(name, sep) + [value]
value = sep.join(values)
winreg.SetValueEx(class_.key, name, 0, winreg.REG_EXPAND_SZ, value)
class_.notify()
@classmethod
def set(class_, name, value, options):
# consider opening the key read-only except for here
# key = winreg.OpenKey(class_.key, None, 0, winreg.KEY_WRITE)
# and follow up by closing it.
if not value:
return class_.delete(name)
do_append = options.append or (
name.upper() in ('PATH', 'PATHEXT') and not options.replace
)
if do_append:
sep = ';'
values = class_.get_values_list(name, sep) + [value]
value = sep.join(values)
winreg.SetValueEx(class_.key, name, 0, winreg.REG_EXPAND_SZ, value)
class_.notify()
@classmethod
def add(class_, name, value, sep=';'):
"""
Add a value to a delimited variable, but only when the value isn't
already present.
"""
values = class_.get_values_list(name, sep)
if value in values:
return
new_value = sep.join(values + [value])
winreg.SetValueEx(
class_.key, name, 0, winreg.REG_EXPAND_SZ, new_value)
class_.notify()
@classmethod
def add(class_, name, value, sep=';'):
"""
Add a value to a delimited variable, but only when the value isn't
already present.
"""
values = class_.get_values_list(name, sep)
if value in values:
return
new_value = sep.join(values + [value])
winreg.SetValueEx(class_.key, name, 0, winreg.REG_EXPAND_SZ, new_value)
class_.notify()
@classmethod
def remove_values(class_, name, value_substring, options):
sep = ';'
values = class_.get_values_list(name, sep)
new_values = [
value
for value in values
if value_substring.lower() not in value.lower()
]
values = sep.join(new_values)
winreg.SetValueEx(class_.key, name, 0, winreg.REG_EXPAND_SZ, values)
class_.notify()
@classmethod
def remove_values(class_, name, value_substring, options):
sep = ';'
values = class_.get_values_list(name, sep)
new_values = [
value for value in values if value_substring.lower() not in value.lower()
]
values = sep.join(new_values)
winreg.SetValueEx(class_.key, name, 0, winreg.REG_EXPAND_SZ, values)
class_.notify()
@classmethod
def edit(class_, name, value='', options=None):
# value, options ignored
sep = ';'
values = class_.get_values_list(name, sep)
e = EditableFile('\n'.join(values))
e.edit()
if e.changed:
values = sep.join(e.data.strip().split('\n'))
winreg.SetValueEx(class_.key, name, 0, winreg.REG_EXPAND_SZ, values)
class_.notify()
@classmethod
def edit(class_, name, value='', options=None):
# value, options ignored
sep = ';'
values = class_.get_values_list(name, sep)
e = EditableFile('\n'.join(values))
e.edit()
if e.changed:
values = sep.join(e.data.strip().split('\n'))
winreg.SetValueEx(class_.key, name, 0, winreg.REG_EXPAND_SZ, values)
class_.notify()
@classmethod
def delete(class_, name):
winreg.DeleteValue(class_.key, name)
class_.notify()
@classmethod
def delete(class_, name):
winreg.DeleteValue(class_.key, name)
class_.notify()
@classmethod
def notify(class_):
"""
Notify other windows that the environment has changed (following
http://support.microsoft.com/kb/104011).
"""
# TODO: Implement Microsoft UIPI (User Interface Privilege Isolation) to
# elevate privilege to system level so the system gets this notification
# for now, this must be run as admin to work as expected
return_val = ctypes.wintypes.DWORD()
res = message.SendMessageTimeout(
message.HWND_BROADCAST,
message.WM_SETTINGCHANGE,
0, # wparam must be null
'Environment',
message.SMTO_ABORTIFHUNG,
5000, # timeout in ms
return_val,
)
error.handle_nonzero_success(res)
@classmethod
def notify(class_):
"""
Notify other windows that the environment has changed (following
http://support.microsoft.com/kb/104011).
"""
# TODO: Implement Microsoft UIPI (User Interface Privilege Isolation) to
# elevate privilege to system level so the system gets this notification
# for now, this must be run as admin to work as expected
return_val = ctypes.wintypes.DWORD()
res = message.SendMessageTimeout(
message.HWND_BROADCAST,
message.WM_SETTINGCHANGE,
0, # wparam must be null
'Environment',
message.SMTO_ABORTIFHUNG,
5000, # timeout in ms
return_val,
)
error.handle_nonzero_success(res)
class MachineRegisteredEnvironment(RegisteredEnvironment):
path = r'SYSTEM\CurrentControlSet\Control\Session Manager\Environment'
hklm = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE)
try:
key = winreg.OpenKey(
hklm, path, 0,
winreg.KEY_READ | winreg.KEY_WRITE)
except WindowsError:
key = winreg.OpenKey(hklm, path, 0, winreg.KEY_READ)
path = r'SYSTEM\CurrentControlSet\Control\Session Manager\Environment'
hklm = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE)
try:
key = winreg.OpenKey(hklm, path, 0, winreg.KEY_READ | winreg.KEY_WRITE)
except WindowsError:
key = winreg.OpenKey(hklm, path, 0, winreg.KEY_READ)
class UserRegisteredEnvironment(RegisteredEnvironment):
hkcu = winreg.ConnectRegistry(None, winreg.HKEY_CURRENT_USER)
key = winreg.OpenKey(
hkcu, 'Environment', 0,
winreg.KEY_READ | winreg.KEY_WRITE)
hkcu = winreg.ConnectRegistry(None, winreg.HKEY_CURRENT_USER)
key = winreg.OpenKey(hkcu, 'Environment', 0, winreg.KEY_READ | winreg.KEY_WRITE)
def trim(s):
from textwrap import dedent
return dedent(s).strip()
from textwrap import dedent
return dedent(s).strip()
def enver(*args):
"""
%prog [<name>=[value]]
"""
%prog [<name>=[value]]
To show all environment variables, call with no parameters:
%prog
To Add/Modify/Delete environment variable:
%prog <name>=[value]
To show all environment variables, call with no parameters:
%prog
To Add/Modify/Delete environment variable:
%prog <name>=[value]
If <name> is PATH or PATHEXT, %prog will by default append the value using
a semicolon as a separator. Use -r to disable this behavior or -a to force
it for variables other than PATH and PATHEXT.
If <name> is PATH or PATHEXT, %prog will by default append the value using
a semicolon as a separator. Use -r to disable this behavior or -a to force
it for variables other than PATH and PATHEXT.
If append is prescribed, but the value doesn't exist, the value will be
created.
If append is prescribed, but the value doesn't exist, the value will be
created.
If there is no value, %prog will delete the <name> environment variable.
i.e. "PATH="
If there is no value, %prog will delete the <name> environment variable.
i.e. "PATH="
To remove a specific value or values from a semicolon-separated
multi-value variable (such as PATH), use --remove-value.
To remove a specific value or values from a semicolon-separated
multi-value variable (such as PATH), use --remove-value.
e.g. enver --remove-value PATH=C:\\Unwanted\\Dir\\In\\Path
e.g. enver --remove-value PATH=C:\\Unwanted\\Dir\\In\\Path
Remove-value matches case-insensitive and also matches any substring
so the following would also be sufficient to remove the aforementioned
undesirable dir.
Remove-value matches case-insensitive and also matches any substring
so the following would also be sufficient to remove the aforementioned
undesirable dir.
enver --remove-value PATH=UNWANTED
enver --remove-value PATH=UNWANTED
Note that %prog does not affect the current running environment, and can
only affect subsequently spawned applications.
"""
from optparse import OptionParser
parser = OptionParser(usage=trim(enver.__doc__))
parser.add_option(
'-U', '--user-environment',
action='store_const', const=UserRegisteredEnvironment,
default=MachineRegisteredEnvironment,
dest='class_',
help="Use the current user's environment",
)
parser.add_option(
'-a', '--append',
action='store_true', default=False,
help="Append the value to any existing value (default for PATH and PATHEXT)",
)
parser.add_option(
'-r', '--replace',
action='store_true', default=False,
help="Replace any existing value (used to override default append "
"for PATH and PATHEXT)",
)
parser.add_option(
'--remove-value', action='store_true', default=False,
help="Remove any matching values from a semicolon-separated "
"multi-value variable",
)
parser.add_option(
'-e', '--edit', action='store_true', default=False,
help="Edit the value in a local editor",
)
options, args = parser.parse_args(*args)
Note that %prog does not affect the current running environment, and can
only affect subsequently spawned applications.
"""
from optparse import OptionParser
try:
param = args.pop()
if args:
parser.error("Too many parameters specified")
raise SystemExit(1)
if '=' not in param and not options.edit:
parser.error("Expected <name>= or <name>=<value>")
raise SystemExit(2)
name, sep, value = param.partition('=')
method_name = 'set'
if options.remove_value:
method_name = 'remove_values'
if options.edit:
method_name = 'edit'
method = getattr(options.class_, method_name)
method(name, value, options)
except IndexError:
options.class_.show()
parser = OptionParser(usage=trim(enver.__doc__))
parser.add_option(
'-U',
'--user-environment',
action='store_const',
const=UserRegisteredEnvironment,
default=MachineRegisteredEnvironment,
dest='class_',
help="Use the current user's environment",
)
parser.add_option(
'-a',
'--append',
action='store_true',
default=False,
help="Append the value to any existing value (default for PATH and PATHEXT)",
)
parser.add_option(
'-r',
'--replace',
action='store_true',
default=False,
help="Replace any existing value (used to override default append "
"for PATH and PATHEXT)",
)
parser.add_option(
'--remove-value',
action='store_true',
default=False,
help="Remove any matching values from a semicolon-separated "
"multi-value variable",
)
parser.add_option(
'-e',
'--edit',
action='store_true',
default=False,
help="Edit the value in a local editor",
)
options, args = parser.parse_args(*args)
try:
param = args.pop()
if args:
parser.error("Too many parameters specified")
raise SystemExit(1)
if '=' not in param and not options.edit:
parser.error("Expected <name>= or <name>=<value>")
raise SystemExit(2)
name, sep, value = param.partition('=')
method_name = 'set'
if options.remove_value:
method_name = 'remove_values'
if options.edit:
method_name = 'edit'
method = getattr(options.class_, method_name)
method(name, value, options)
except IndexError:
options.class_.show()
if __name__ == '__main__':
enver()
enver()

View file

@ -1,84 +1,79 @@
#!/usr/bin/env python
import sys
import builtins
import ctypes
import ctypes.wintypes
import six
builtins = six.moves.builtins
__import__('jaraco.windows.api.memory')
def format_system_message(errno):
"""
Call FormatMessage with a system error number to retrieve
the descriptive error message.
"""
# first some flags used by FormatMessageW
ALLOCATE_BUFFER = 0x100
FROM_SYSTEM = 0x1000
"""
Call FormatMessage with a system error number to retrieve
the descriptive error message.
"""
# first some flags used by FormatMessageW
ALLOCATE_BUFFER = 0x100
FROM_SYSTEM = 0x1000
# Let FormatMessageW allocate the buffer (we'll free it below)
# Also, let it know we want a system error message.
flags = ALLOCATE_BUFFER | FROM_SYSTEM
source = None
message_id = errno
language_id = 0
result_buffer = ctypes.wintypes.LPWSTR()
buffer_size = 0
arguments = None
bytes = ctypes.windll.kernel32.FormatMessageW(
flags,
source,
message_id,
language_id,
ctypes.byref(result_buffer),
buffer_size,
arguments,
)
# note the following will cause an infinite loop if GetLastError
# repeatedly returns an error that cannot be formatted, although
# this should not happen.
handle_nonzero_success(bytes)
message = result_buffer.value
ctypes.windll.kernel32.LocalFree(result_buffer)
return message
# Let FormatMessageW allocate the buffer (we'll free it below)
# Also, let it know we want a system error message.
flags = ALLOCATE_BUFFER | FROM_SYSTEM
source = None
message_id = errno
language_id = 0
result_buffer = ctypes.wintypes.LPWSTR()
buffer_size = 0
arguments = None
bytes = ctypes.windll.kernel32.FormatMessageW(
flags,
source,
message_id,
language_id,
ctypes.byref(result_buffer),
buffer_size,
arguments,
)
# note the following will cause an infinite loop if GetLastError
# repeatedly returns an error that cannot be formatted, although
# this should not happen.
handle_nonzero_success(bytes)
message = result_buffer.value
ctypes.windll.kernel32.LocalFree(result_buffer)
return message
class WindowsError(builtins.WindowsError):
"""
More info about errors at
http://msdn.microsoft.com/en-us/library/ms681381(VS.85).aspx
"""
"""
More info about errors at
http://msdn.microsoft.com/en-us/library/ms681381(VS.85).aspx
"""
def __init__(self, value=None):
if value is None:
value = ctypes.windll.kernel32.GetLastError()
strerror = format_system_message(value)
if sys.version_info > (3, 3):
args = 0, strerror, None, value
else:
args = value, strerror
super(WindowsError, self).__init__(*args)
def __init__(self, value=None):
if value is None:
value = ctypes.windll.kernel32.GetLastError()
strerror = format_system_message(value)
if sys.version_info > (3, 3):
args = 0, strerror, None, value
else:
args = value, strerror
super(WindowsError, self).__init__(*args)
@property
def message(self):
return self.strerror
@property
def message(self):
return self.strerror
@property
def code(self):
return self.winerror
@property
def code(self):
return self.winerror
def __str__(self):
return self.message
def __str__(self):
return self.message
def __repr__(self):
return '{self.__class__.__name__}({self.winerror})'.format(**vars())
def __repr__(self):
return '{self.__class__.__name__}({self.winerror})'.format(**vars())
def handle_nonzero_success(result):
if result == 0:
raise WindowsError()
if result == 0:
raise WindowsError()

View file

@ -1,7 +1,5 @@
import functools
from six.moves import map
import win32api
import win32evtlog
import win32evtlogutil
@ -10,43 +8,43 @@ error = win32api.error # The error the evtlog module raises.
class EventLog(object):
def __init__(self, name="Application", machine_name=None):
self.machine_name = machine_name
self.name = name
self.formatter = functools.partial(
win32evtlogutil.FormatMessage, logType=self.name)
def __init__(self, name="Application", machine_name=None):
self.machine_name = machine_name
self.name = name
self.formatter = functools.partial(
win32evtlogutil.FormatMessage, logType=self.name
)
def __enter__(self):
if hasattr(self, 'handle'):
raise ValueError("Overlapping attempts to use this log context")
self.handle = win32evtlog.OpenEventLog(self.machine_name, self.name)
return self
def __enter__(self):
if hasattr(self, 'handle'):
raise ValueError("Overlapping attempts to use this log context")
self.handle = win32evtlog.OpenEventLog(self.machine_name, self.name)
return self
def __exit__(self, *args):
win32evtlog.CloseEventLog(self.handle)
del self.handle
def __exit__(self, *args):
win32evtlog.CloseEventLog(self.handle)
del self.handle
_default_flags = (
win32evtlog.EVENTLOG_BACKWARDS_READ
| win32evtlog.EVENTLOG_SEQUENTIAL_READ
)
_default_flags = (
win32evtlog.EVENTLOG_BACKWARDS_READ | win32evtlog.EVENTLOG_SEQUENTIAL_READ
)
def get_records(self, flags=_default_flags):
with self:
while True:
objects = win32evtlog.ReadEventLog(self.handle, flags, 0)
if not objects:
break
for item in objects:
yield item
def get_records(self, flags=_default_flags):
with self:
while True:
objects = win32evtlog.ReadEventLog(self.handle, flags, 0)
if not objects:
break
for item in objects:
yield item
def __iter__(self):
return self.get_records()
def __iter__(self):
return self.get_records()
def format_record(self, record):
return self.formatter(record)
def format_record(self, record):
return self.formatter(record)
def format_records(self, records=None):
if records is None:
records = self.get_records()
return map(self.format_record, records)
def format_records(self, records=None):
if records is None:
records = self.get_records()
return map(self.format_record, records)

View file

@ -1,7 +1,3 @@
#!/usr/bin/env python
from __future__ import print_function
import os
import sys
import operator
@ -9,14 +5,17 @@ import collections
import functools
import stat
from ctypes import (
POINTER, byref, cast, create_unicode_buffer,
create_string_buffer, windll)
POINTER,
byref,
cast,
create_unicode_buffer,
create_string_buffer,
windll,
)
from ctypes.wintypes import LPWSTR
import nt
import posixpath
import six
from six.moves import builtins, filter, map
import builtins
from jaraco.structures import binary
@ -26,476 +25,473 @@ from jaraco.windows import reparse
def mklink():
"""
Like cmd.exe's mklink except it will infer directory status of the
target.
"""
from optparse import OptionParser
parser = OptionParser(usage="usage: %prog [options] link target")
parser.add_option(
'-d', '--directory',
help="Target is a directory (only necessary if not present)",
action="store_true")
options, args = parser.parse_args()
try:
link, target = args
except ValueError:
parser.error("incorrect number of arguments")
symlink(target, link, options.directory)
sys.stdout.write("Symbolic link created: %(link)s --> %(target)s\n" % vars())
"""
Like cmd.exe's mklink except it will infer directory status of the
target.
"""
from optparse import OptionParser
parser = OptionParser(usage="usage: %prog [options] link target")
parser.add_option(
'-d',
'--directory',
help="Target is a directory (only necessary if not present)",
action="store_true",
)
options, args = parser.parse_args()
try:
link, target = args
except ValueError:
parser.error("incorrect number of arguments")
symlink(target, link, options.directory)
sys.stdout.write("Symbolic link created: %(link)s --> %(target)s\n" % vars())
def _is_target_a_directory(link, rel_target):
"""
If creating a symlink from link to a target, determine if target
is a directory (relative to dirname(link)).
"""
target = os.path.join(os.path.dirname(link), rel_target)
return os.path.isdir(target)
"""
If creating a symlink from link to a target, determine if target
is a directory (relative to dirname(link)).
"""
target = os.path.join(os.path.dirname(link), rel_target)
return os.path.isdir(target)
def symlink(target, link, target_is_directory=False):
"""
An implementation of os.symlink for Windows (Vista and greater)
"""
target_is_directory = (
target_is_directory or
_is_target_a_directory(link, target)
)
# normalize the target (MS symlinks don't respect forward slashes)
target = os.path.normpath(target)
handle_nonzero_success(
api.CreateSymbolicLink(link, target, target_is_directory))
"""
An implementation of os.symlink for Windows (Vista and greater)
"""
target_is_directory = target_is_directory or _is_target_a_directory(link, target)
# normalize the target (MS symlinks don't respect forward slashes)
target = os.path.normpath(target)
flags = target_is_directory | api.SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE
handle_nonzero_success(api.CreateSymbolicLink(link, target, flags))
def link(target, link):
"""
Establishes a hard link between an existing file and a new file.
"""
handle_nonzero_success(api.CreateHardLink(link, target, None))
"""
Establishes a hard link between an existing file and a new file.
"""
handle_nonzero_success(api.CreateHardLink(link, target, None))
def is_reparse_point(path):
"""
Determine if the given path is a reparse point.
Return False if the file does not exist or the file attributes cannot
be determined.
"""
res = api.GetFileAttributes(path)
return (
res != api.INVALID_FILE_ATTRIBUTES
and bool(res & api.FILE_ATTRIBUTE_REPARSE_POINT)
)
"""
Determine if the given path is a reparse point.
Return False if the file does not exist or the file attributes cannot
be determined.
"""
res = api.GetFileAttributes(path)
return res != api.INVALID_FILE_ATTRIBUTES and bool(
res & api.FILE_ATTRIBUTE_REPARSE_POINT
)
def islink(path):
"Determine if the given path is a symlink"
return is_reparse_point(path) and is_symlink(path)
"Determine if the given path is a symlink"
return is_reparse_point(path) and is_symlink(path)
def _patch_path(path):
"""
Paths have a max length of api.MAX_PATH characters (260). If a target path
is longer than that, it needs to be made absolute and prepended with
\\?\ in order to work with API calls.
See http://msdn.microsoft.com/en-us/library/aa365247%28v=vs.85%29.aspx for
details.
"""
if path.startswith('\\\\?\\'):
return path
abs_path = os.path.abspath(path)
if not abs_path[1] == ':':
# python doesn't include the drive letter, but \\?\ requires it
abs_path = os.getcwd()[:2] + abs_path
return '\\\\?\\' + abs_path
r"""
Paths have a max length of api.MAX_PATH characters (260). If a target path
is longer than that, it needs to be made absolute and prepended with
\\?\ in order to work with API calls.
See http://msdn.microsoft.com/en-us/library/aa365247%28v=vs.85%29.aspx for
details.
"""
if path.startswith('\\\\?\\'):
return path
abs_path = os.path.abspath(path)
if not abs_path[1] == ':':
# python doesn't include the drive letter, but \\?\ requires it
abs_path = os.getcwd()[:2] + abs_path
return '\\\\?\\' + abs_path
def is_symlink(path):
"""
Assuming path is a reparse point, determine if it's a symlink.
"""
path = _patch_path(path)
try:
return _is_symlink(next(find_files(path)))
except WindowsError as orig_error:
tmpl = "Error accessing {path}: {orig_error.message}"
raise builtins.WindowsError(tmpl.format(**locals()))
"""
Assuming path is a reparse point, determine if it's a symlink.
"""
path = _patch_path(path)
try:
return _is_symlink(next(find_files(path)))
# comment below workaround for PyCQA/pyflakes#376
except WindowsError as orig_error: # noqa: F841
tmpl = "Error accessing {path}: {orig_error.message}"
raise builtins.WindowsError(tmpl.format(**locals()))
def _is_symlink(find_data):
return find_data.reserved[0] == api.IO_REPARSE_TAG_SYMLINK
return find_data.reserved[0] == api.IO_REPARSE_TAG_SYMLINK
def find_files(spec):
"""
A pythonic wrapper around the FindFirstFile/FindNextFile win32 api.
r"""
A pythonic wrapper around the FindFirstFile/FindNextFile win32 api.
>>> root_files = tuple(find_files(r'c:\*'))
>>> len(root_files) > 1
True
>>> root_files[0].filename == root_files[1].filename
False
>>> root_files = tuple(find_files(r'c:\*'))
>>> len(root_files) > 1
True
>>> root_files[0].filename == root_files[1].filename
False
This test might fail on a non-standard installation
>>> 'Windows' in (fd.filename for fd in root_files)
True
"""
fd = api.WIN32_FIND_DATA()
handle = api.FindFirstFile(spec, byref(fd))
while True:
if handle == api.INVALID_HANDLE_VALUE:
raise WindowsError()
yield fd
fd = api.WIN32_FIND_DATA()
res = api.FindNextFile(handle, byref(fd))
if res == 0: # error
error = WindowsError()
if error.code == api.ERROR_NO_MORE_FILES:
break
else:
raise error
# todo: how to close handle when generator is destroyed?
# hint: catch GeneratorExit
windll.kernel32.FindClose(handle)
This test might fail on a non-standard installation
>>> 'Windows' in (fd.filename for fd in root_files)
True
"""
fd = api.WIN32_FIND_DATA()
handle = api.FindFirstFile(spec, byref(fd))
while True:
if handle == api.INVALID_HANDLE_VALUE:
raise WindowsError()
yield fd
fd = api.WIN32_FIND_DATA()
res = api.FindNextFile(handle, byref(fd))
if res == 0: # error
error = WindowsError()
if error.code == api.ERROR_NO_MORE_FILES:
break
else:
raise error
# todo: how to close handle when generator is destroyed?
# hint: catch GeneratorExit
windll.kernel32.FindClose(handle)
def get_final_path(path):
"""
For a given path, determine the ultimate location of that path.
Useful for resolving symlink targets.
This functions wraps the GetFinalPathNameByHandle from the Windows
SDK.
r"""
For a given path, determine the ultimate location of that path.
Useful for resolving symlink targets.
This functions wraps the GetFinalPathNameByHandle from the Windows
SDK.
Note, this function fails if a handle cannot be obtained (such as
for C:\Pagefile.sys on a stock windows system). Consider using
trace_symlink_target instead.
"""
desired_access = api.NULL
share_mode = (
api.FILE_SHARE_READ | api.FILE_SHARE_WRITE | api.FILE_SHARE_DELETE
)
security_attributes = api.LPSECURITY_ATTRIBUTES() # NULL pointer
hFile = api.CreateFile(
path,
desired_access,
share_mode,
security_attributes,
api.OPEN_EXISTING,
api.FILE_FLAG_BACKUP_SEMANTICS,
api.NULL,
)
Note, this function fails if a handle cannot be obtained (such as
for C:\Pagefile.sys on a stock windows system). Consider using
trace_symlink_target instead.
"""
desired_access = api.NULL
share_mode = api.FILE_SHARE_READ | api.FILE_SHARE_WRITE | api.FILE_SHARE_DELETE
security_attributes = api.LPSECURITY_ATTRIBUTES() # NULL pointer
hFile = api.CreateFile(
path,
desired_access,
share_mode,
security_attributes,
api.OPEN_EXISTING,
api.FILE_FLAG_BACKUP_SEMANTICS,
api.NULL,
)
if hFile == api.INVALID_HANDLE_VALUE:
raise WindowsError()
if hFile == api.INVALID_HANDLE_VALUE:
raise WindowsError()
buf_size = api.GetFinalPathNameByHandle(
hFile, LPWSTR(), 0, api.VOLUME_NAME_DOS)
handle_nonzero_success(buf_size)
buf = create_unicode_buffer(buf_size)
result_length = api.GetFinalPathNameByHandle(
hFile, buf, len(buf), api.VOLUME_NAME_DOS)
buf_size = api.GetFinalPathNameByHandle(hFile, LPWSTR(), 0, api.VOLUME_NAME_DOS)
handle_nonzero_success(buf_size)
buf = create_unicode_buffer(buf_size)
result_length = api.GetFinalPathNameByHandle(
hFile, buf, len(buf), api.VOLUME_NAME_DOS
)
assert result_length < len(buf)
handle_nonzero_success(result_length)
handle_nonzero_success(api.CloseHandle(hFile))
assert result_length < len(buf)
handle_nonzero_success(result_length)
handle_nonzero_success(api.CloseHandle(hFile))
return buf[:result_length]
return buf[:result_length]
def compat_stat(path):
"""
Generate stat as found on Python 3.2 and later.
"""
stat = os.stat(path)
info = get_file_info(path)
# rewrite st_ino, st_dev, and st_nlink based on file info
return nt.stat_result(
(stat.st_mode,) +
(info.file_index, info.volume_serial_number, info.number_of_links) +
stat[4:]
)
"""
Generate stat as found on Python 3.2 and later.
"""
stat = os.stat(path)
info = get_file_info(path)
# rewrite st_ino, st_dev, and st_nlink based on file info
return nt.stat_result(
(stat.st_mode,)
+ (info.file_index, info.volume_serial_number, info.number_of_links)
+ stat[4:]
)
def samefile(f1, f2):
"""
Backport of samefile from Python 3.2 with support for Windows.
"""
return posixpath.samestat(compat_stat(f1), compat_stat(f2))
"""
Backport of samefile from Python 3.2 with support for Windows.
"""
return posixpath.samestat(compat_stat(f1), compat_stat(f2))
def get_file_info(path):
# open the file the same way CPython does in posixmodule.c
desired_access = api.FILE_READ_ATTRIBUTES
share_mode = 0
security_attributes = None
creation_disposition = api.OPEN_EXISTING
flags_and_attributes = (
api.FILE_ATTRIBUTE_NORMAL |
api.FILE_FLAG_BACKUP_SEMANTICS |
api.FILE_FLAG_OPEN_REPARSE_POINT
)
template_file = None
# open the file the same way CPython does in posixmodule.c
desired_access = api.FILE_READ_ATTRIBUTES
share_mode = 0
security_attributes = None
creation_disposition = api.OPEN_EXISTING
flags_and_attributes = (
api.FILE_ATTRIBUTE_NORMAL
| api.FILE_FLAG_BACKUP_SEMANTICS
| api.FILE_FLAG_OPEN_REPARSE_POINT
)
template_file = None
handle = api.CreateFile(
path,
desired_access,
share_mode,
security_attributes,
creation_disposition,
flags_and_attributes,
template_file,
)
handle = api.CreateFile(
path,
desired_access,
share_mode,
security_attributes,
creation_disposition,
flags_and_attributes,
template_file,
)
if handle == api.INVALID_HANDLE_VALUE:
raise WindowsError()
if handle == api.INVALID_HANDLE_VALUE:
raise WindowsError()
info = api.BY_HANDLE_FILE_INFORMATION()
res = api.GetFileInformationByHandle(handle, info)
handle_nonzero_success(res)
handle_nonzero_success(api.CloseHandle(handle))
info = api.BY_HANDLE_FILE_INFORMATION()
res = api.GetFileInformationByHandle(handle, info)
handle_nonzero_success(res)
handle_nonzero_success(api.CloseHandle(handle))
return info
return info
def GetBinaryType(filepath):
res = api.DWORD()
handle_nonzero_success(api._GetBinaryType(filepath, res))
return res
res = api.DWORD()
handle_nonzero_success(api._GetBinaryType(filepath, res))
return res
def _make_null_terminated_list(obs):
obs = _makelist(obs)
if obs is None:
return
return u'\x00'.join(obs) + u'\x00\x00'
obs = _makelist(obs)
if obs is None:
return
return u'\x00'.join(obs) + u'\x00\x00'
def _makelist(ob):
if ob is None:
return
if not isinstance(ob, (list, tuple, set)):
return [ob]
return ob
if ob is None:
return
if not isinstance(ob, (list, tuple, set)):
return [ob]
return ob
def SHFileOperation(operation, from_, to=None, flags=[]):
flags = functools.reduce(operator.or_, flags, 0)
from_ = _make_null_terminated_list(from_)
to = _make_null_terminated_list(to)
params = api.SHFILEOPSTRUCT(0, operation, from_, to, flags)
res = api._SHFileOperation(params)
if res != 0:
raise RuntimeError("SHFileOperation returned %d" % res)
flags = functools.reduce(operator.or_, flags, 0)
from_ = _make_null_terminated_list(from_)
to = _make_null_terminated_list(to)
params = api.SHFILEOPSTRUCT(0, operation, from_, to, flags)
res = api._SHFileOperation(params)
if res != 0:
raise RuntimeError("SHFileOperation returned %d" % res)
def join(*paths):
r"""
Wrapper around os.path.join that works with Windows drive letters.
r"""
Wrapper around os.path.join that works with Windows drive letters.
>>> join('d:\\foo', '\\bar')
'd:\\bar'
"""
paths_with_drives = map(os.path.splitdrive, paths)
drives, paths = zip(*paths_with_drives)
# the drive we care about is the last one in the list
drive = next(filter(None, reversed(drives)), '')
return os.path.join(drive, os.path.join(*paths))
>>> join('d:\\foo', '\\bar')
'd:\\bar'
"""
paths_with_drives = map(os.path.splitdrive, paths)
drives, paths = zip(*paths_with_drives)
# the drive we care about is the last one in the list
drive = next(filter(None, reversed(drives)), '')
return os.path.join(drive, os.path.join(*paths))
def resolve_path(target, start=os.path.curdir):
r"""
Find a path from start to target where target is relative to start.
r"""
Find a path from start to target where target is relative to start.
>>> tmp = str(getfixture('tmpdir_as_cwd'))
>>> tmp = str(getfixture('tmpdir_as_cwd'))
>>> findpath('d:\\')
'd:\\'
>>> findpath('d:\\')
'd:\\'
>>> findpath('d:\\', tmp)
'd:\\'
>>> findpath('d:\\', tmp)
'd:\\'
>>> findpath('\\bar', 'd:\\')
'd:\\bar'
>>> findpath('\\bar', 'd:\\')
'd:\\bar'
>>> findpath('\\bar', 'd:\\foo') # fails with '\\bar'
'd:\\bar'
>>> findpath('\\bar', 'd:\\foo') # fails with '\\bar'
'd:\\bar'
>>> findpath('bar', 'd:\\foo')
'd:\\foo\\bar'
>>> findpath('bar', 'd:\\foo')
'd:\\foo\\bar'
>>> findpath('\\baz', 'd:\\foo\\bar') # fails with '\\baz'
'd:\\baz'
>>> findpath('\\baz', 'd:\\foo\\bar') # fails with '\\baz'
'd:\\baz'
>>> os.path.abspath(findpath('\\bar')).lower()
'c:\\bar'
>>> os.path.abspath(findpath('\\bar')).lower()
'c:\\bar'
>>> os.path.abspath(findpath('bar'))
'...\\bar'
>>> os.path.abspath(findpath('bar'))
'...\\bar'
>>> findpath('..', 'd:\\foo\\bar')
'd:\\foo'
>>> findpath('..', 'd:\\foo\\bar')
'd:\\foo'
The parent of the root directory is the root directory.
>>> findpath('..', 'd:\\')
'd:\\'
"""
return os.path.normpath(join(start, target))
The parent of the root directory is the root directory.
>>> findpath('..', 'd:\\')
'd:\\'
"""
return os.path.normpath(join(start, target))
findpath = resolve_path
def trace_symlink_target(link):
"""
Given a file that is known to be a symlink, trace it to its ultimate
target.
"""
Given a file that is known to be a symlink, trace it to its ultimate
target.
Raises TargetNotPresent when the target cannot be determined.
Raises ValueError when the specified link is not a symlink.
"""
Raises TargetNotPresent when the target cannot be determined.
Raises ValueError when the specified link is not a symlink.
"""
if not is_symlink(link):
raise ValueError("link must point to a symlink on the system")
while is_symlink(link):
orig = os.path.dirname(link)
link = readlink(link)
link = resolve_path(link, orig)
return link
if not is_symlink(link):
raise ValueError("link must point to a symlink on the system")
while is_symlink(link):
orig = os.path.dirname(link)
link = readlink(link)
link = resolve_path(link, orig)
return link
def readlink(link):
"""
readlink(link) -> target
Return a string representing the path to which the symbolic link points.
"""
handle = api.CreateFile(
link,
0,
0,
None,
api.OPEN_EXISTING,
api.FILE_FLAG_OPEN_REPARSE_POINT | api.FILE_FLAG_BACKUP_SEMANTICS,
None,
)
"""
readlink(link) -> target
Return a string representing the path to which the symbolic link points.
"""
handle = api.CreateFile(
link,
0,
0,
None,
api.OPEN_EXISTING,
api.FILE_FLAG_OPEN_REPARSE_POINT | api.FILE_FLAG_BACKUP_SEMANTICS,
None,
)
if handle == api.INVALID_HANDLE_VALUE:
raise WindowsError()
if handle == api.INVALID_HANDLE_VALUE:
raise WindowsError()
res = reparse.DeviceIoControl(
handle, api.FSCTL_GET_REPARSE_POINT, None, 10240)
res = reparse.DeviceIoControl(handle, api.FSCTL_GET_REPARSE_POINT, None, 10240)
bytes = create_string_buffer(res)
p_rdb = cast(bytes, POINTER(api.REPARSE_DATA_BUFFER))
rdb = p_rdb.contents
if not rdb.tag == api.IO_REPARSE_TAG_SYMLINK:
raise RuntimeError("Expected IO_REPARSE_TAG_SYMLINK, but got %d" % rdb.tag)
bytes = create_string_buffer(res)
p_rdb = cast(bytes, POINTER(api.REPARSE_DATA_BUFFER))
rdb = p_rdb.contents
if not rdb.tag == api.IO_REPARSE_TAG_SYMLINK:
raise RuntimeError("Expected IO_REPARSE_TAG_SYMLINK, but got %d" % rdb.tag)
handle_nonzero_success(api.CloseHandle(handle))
return rdb.get_substitute_name()
handle_nonzero_success(api.CloseHandle(handle))
return rdb.get_substitute_name()
def patch_os_module():
"""
jaraco.windows provides the os.symlink and os.readlink functions.
Monkey-patch the os module to include them if not present.
"""
if not hasattr(os, 'symlink'):
os.symlink = symlink
os.path.islink = islink
if not hasattr(os, 'readlink'):
os.readlink = readlink
"""
jaraco.windows provides the os.symlink and os.readlink functions.
Monkey-patch the os module to include them if not present.
"""
if not hasattr(os, 'symlink'):
os.symlink = symlink
os.path.islink = islink
if not hasattr(os, 'readlink'):
os.readlink = readlink
def find_symlinks(root):
for dirpath, dirnames, filenames in os.walk(root):
for name in dirnames + filenames:
pathname = os.path.join(dirpath, name)
if is_symlink(pathname):
yield pathname
# don't traverse symlinks
if name in dirnames:
dirnames.remove(name)
for dirpath, dirnames, filenames in os.walk(root):
for name in dirnames + filenames:
pathname = os.path.join(dirpath, name)
if is_symlink(pathname):
yield pathname
# don't traverse symlinks
if name in dirnames:
dirnames.remove(name)
def find_symlinks_cmd():
"""
%prog [start-path]
Search the specified path (defaults to the current directory) for symlinks,
printing the source and target on each line.
"""
from optparse import OptionParser
from textwrap import dedent
parser = OptionParser(usage=dedent(find_symlinks_cmd.__doc__).strip())
options, args = parser.parse_args()
if not args:
args = ['.']
root = args.pop()
if args:
parser.error("unexpected argument(s)")
try:
for symlink in find_symlinks(root):
target = readlink(symlink)
dir = ['', 'D'][os.path.isdir(symlink)]
msg = '{dir:2}{symlink} --> {target}'.format(**locals())
print(msg)
except KeyboardInterrupt:
pass
"""
%prog [start-path]
Search the specified path (defaults to the current directory) for symlinks,
printing the source and target on each line.
"""
from optparse import OptionParser
from textwrap import dedent
parser = OptionParser(usage=dedent(find_symlinks_cmd.__doc__).strip())
options, args = parser.parse_args()
if not args:
args = ['.']
root = args.pop()
if args:
parser.error("unexpected argument(s)")
try:
for symlink in find_symlinks(root):
target = readlink(symlink)
dir = ['', 'D'][os.path.isdir(symlink)]
msg = '{dir:2}{symlink} --> {target}'.format(**locals())
print(msg)
except KeyboardInterrupt:
pass
@six.add_metaclass(binary.BitMask)
class FileAttributes(int):
class FileAttributes(int, metaclass=binary.BitMask):
# extract the values from the stat module on Python 3.5
# and later.
locals().update(
(name.split('FILE_ATTRIBUTES_')[1].lower(), value)
for name, value in vars(stat).items()
if name.startswith('FILE_ATTRIBUTES_')
)
# extract the values from the stat module on Python 3.5
# and later.
locals().update(
(name.split('FILE_ATTRIBUTES_')[1].lower(), value)
for name, value in vars(stat).items()
if name.startswith('FILE_ATTRIBUTES_')
)
# For Python 3.4 and earlier, define the constants here
archive = 0x20
compressed = 0x800
hidden = 0x2
device = 0x40
directory = 0x10
encrypted = 0x4000
normal = 0x80
not_content_indexed = 0x2000
offline = 0x1000
read_only = 0x1
reparse_point = 0x400
sparse_file = 0x200
system = 0x4
temporary = 0x100
virtual = 0x10000
# For Python 3.4 and earlier, define the constants here
archive = 0x20
compressed = 0x800
hidden = 0x2
device = 0x40
directory = 0x10
encrypted = 0x4000
normal = 0x80
not_content_indexed = 0x2000
offline = 0x1000
read_only = 0x1
reparse_point = 0x400
sparse_file = 0x200
system = 0x4
temporary = 0x100
virtual = 0x10000
@classmethod
def get(cls, filepath):
attrs = api.GetFileAttributes(filepath)
if attrs == api.INVALID_FILE_ATTRIBUTES:
raise WindowsError()
return cls(attrs)
@classmethod
def get(cls, filepath):
attrs = api.GetFileAttributes(filepath)
if attrs == api.INVALID_FILE_ATTRIBUTES:
raise WindowsError()
return cls(attrs)
GetFileAttributes = FileAttributes.get
def SetFileAttributes(filepath, *attrs):
"""
Set file attributes. e.g.:
"""
Set file attributes. e.g.:
SetFileAttributes('C:\\foo', 'hidden')
SetFileAttributes('C:\\foo', 'hidden')
Each attr must be either a numeric value, a constant defined in
jaraco.windows.filesystem.api, or one of the nice names
defined in this function.
"""
nice_names = collections.defaultdict(
lambda key: key,
hidden='FILE_ATTRIBUTE_HIDDEN',
read_only='FILE_ATTRIBUTE_READONLY',
)
flags = (getattr(api, nice_names[attr], attr) for attr in attrs)
flags = functools.reduce(operator.or_, flags)
handle_nonzero_success(api.SetFileAttributes(filepath, flags))
Each attr must be either a numeric value, a constant defined in
jaraco.windows.filesystem.api, or one of the nice names
defined in this function.
"""
nice_names = collections.defaultdict(
lambda key: key,
hidden='FILE_ATTRIBUTE_HIDDEN',
read_only='FILE_ATTRIBUTE_READONLY',
)
flags = (getattr(api, nice_names[attr], attr) for attr in attrs)
flags = functools.reduce(operator.or_, flags)
handle_nonzero_success(api.SetFileAttributes(filepath, flags))

View file

@ -1,109 +1,107 @@
from __future__ import unicode_literals
import os.path
# realpath taken from https://bugs.python.org/file38057/issue9949-v4.patch
def realpath(path):
if isinstance(path, str):
prefix = '\\\\?\\'
unc_prefix = prefix + 'UNC'
new_unc_prefix = '\\'
cwd = os.getcwd()
else:
prefix = b'\\\\?\\'
unc_prefix = prefix + b'UNC'
new_unc_prefix = b'\\'
cwd = os.getcwdb()
had_prefix = path.startswith(prefix)
path, ok = _resolve_path(cwd, path, {})
# The path returned by _getfinalpathname will always start with \\?\ -
# strip off that prefix unless it was already provided on the original
# path.
if not had_prefix:
# For UNC paths, the prefix will actually be \\?\UNC - handle that
# case as well.
if path.startswith(unc_prefix):
path = new_unc_prefix + path[len(unc_prefix):]
elif path.startswith(prefix):
path = path[len(prefix):]
return path
if isinstance(path, str):
prefix = '\\\\?\\'
unc_prefix = prefix + 'UNC'
new_unc_prefix = '\\'
cwd = os.getcwd()
else:
prefix = b'\\\\?\\'
unc_prefix = prefix + b'UNC'
new_unc_prefix = b'\\'
cwd = os.getcwdb()
had_prefix = path.startswith(prefix)
path, ok = _resolve_path(cwd, path, {})
# The path returned by _getfinalpathname will always start with \\?\ -
# strip off that prefix unless it was already provided on the original
# path.
if not had_prefix:
# For UNC paths, the prefix will actually be \\?\UNC - handle that
# case as well.
if path.startswith(unc_prefix):
path = new_unc_prefix + path[len(unc_prefix) :]
elif path.startswith(prefix):
path = path[len(prefix) :]
return path
def _resolve_path(path, rest, seen):
# Windows normalizes the path before resolving symlinks; be sure to
# follow the same behavior.
rest = os.path.normpath(rest)
def _resolve_path(path, rest, seen): # noqa: C901
# Windows normalizes the path before resolving symlinks; be sure to
# follow the same behavior.
rest = os.path.normpath(rest)
if isinstance(rest, str):
sep = '\\'
else:
sep = b'\\'
if isinstance(rest, str):
sep = '\\'
else:
sep = b'\\'
if os.path.isabs(rest):
drive, rest = os.path.splitdrive(rest)
path = drive + sep
rest = rest[1:]
if os.path.isabs(rest):
drive, rest = os.path.splitdrive(rest)
path = drive + sep
rest = rest[1:]
while rest:
name, _, rest = rest.partition(sep)
new_path = os.path.join(path, name) if path else name
if os.path.exists(new_path):
if not rest:
# The whole path exists. Resolve it using the OS.
path = os.path._getfinalpathname(new_path)
else:
# The OS can resolve `new_path`; keep traversing the path.
path = new_path
elif not os.path.lexists(new_path):
# `new_path` does not exist on the filesystem at all. Use the
# OS to resolve `path`, if it exists, and then append the
# remainder.
if os.path.exists(path):
path = os.path._getfinalpathname(path)
rest = os.path.join(name, rest) if rest else name
return os.path.join(path, rest), True
else:
# We have a symbolic link that the OS cannot resolve. Try to
# resolve it ourselves.
while rest:
name, _, rest = rest.partition(sep)
new_path = os.path.join(path, name) if path else name
if os.path.exists(new_path):
if not rest:
# The whole path exists. Resolve it using the OS.
path = os.path._getfinalpathname(new_path)
else:
# The OS can resolve `new_path`; keep traversing the path.
path = new_path
elif not os.path.lexists(new_path):
# `new_path` does not exist on the filesystem at all. Use the
# OS to resolve `path`, if it exists, and then append the
# remainder.
if os.path.exists(path):
path = os.path._getfinalpathname(path)
rest = os.path.join(name, rest) if rest else name
return os.path.join(path, rest), True
else:
# We have a symbolic link that the OS cannot resolve. Try to
# resolve it ourselves.
# On Windows, symbolic link resolution can be partially or
# fully disabled [1]. The end result of a disabled symlink
# appears the same as a broken symlink (lexists() returns True
# but exists() returns False). And in both cases, the link can
# still be read using readlink(). Call stat() and check the
# resulting error code to ensure we don't circumvent the
# Windows symbolic link restrictions.
# [1] https://technet.microsoft.com/en-us/library/cc754077.aspx
try:
os.stat(new_path)
except OSError as e:
# WinError 1463: The symbolic link cannot be followed
# because its type is disabled.
if e.winerror == 1463:
raise
# On Windows, symbolic link resolution can be partially or
# fully disabled [1]. The end result of a disabled symlink
# appears the same as a broken symlink (lexists() returns True
# but exists() returns False). And in both cases, the link can
# still be read using readlink(). Call stat() and check the
# resulting error code to ensure we don't circumvent the
# Windows symbolic link restrictions.
# [1] https://technet.microsoft.com/en-us/library/cc754077.aspx
try:
os.stat(new_path)
except OSError as e:
# WinError 1463: The symbolic link cannot be followed
# because its type is disabled.
if e.winerror == 1463:
raise
key = os.path.normcase(new_path)
if key in seen:
# This link has already been seen; try to use the
# previously resolved value.
path = seen[key]
if path is None:
# It has not yet been resolved, which means we must
# have a symbolic link loop. Return what we have
# resolved so far plus the remainder of the path (who
# cares about the Zen of Python?).
path = os.path.join(new_path, rest) if rest else new_path
return path, False
else:
# Mark this link as in the process of being resolved.
seen[key] = None
# Try to resolve it.
path, ok = _resolve_path(path, os.readlink(new_path), seen)
if ok:
# Resolution succeded; store the resolved value.
seen[key] = path
else:
# Resolution failed; punt.
return (os.path.join(path, rest) if rest else path), False
return path, True
key = os.path.normcase(new_path)
if key in seen:
# This link has already been seen; try to use the
# previously resolved value.
path = seen[key]
if path is None:
# It has not yet been resolved, which means we must
# have a symbolic link loop. Return what we have
# resolved so far plus the remainder of the path (who
# cares about the Zen of Python?).
path = os.path.join(new_path, rest) if rest else new_path
return path, False
else:
# Mark this link as in the process of being resolved.
seen[key] = None
# Try to resolve it.
path, ok = _resolve_path(path, os.readlink(new_path), seen)
if ok:
# Resolution succeded; store the resolved value.
seen[key] = path
else:
# Resolution failed; punt.
return (os.path.join(path, rest) if rest else path), False
return path, True

View file

@ -1,14 +1,10 @@
# -*- coding: UTF-8 -*-
"""
FileChange
Classes and routines for monitoring the file system for changes.
Classes and routines for monitoring the file system for changes.
Copyright © 2004, 2011, 2013 Jason R. Coombs
"""
from __future__ import print_function
import os
import sys
import datetime
@ -17,8 +13,6 @@ from threading import Thread
import itertools
import logging
import six
from more_itertools.recipes import consume
import jaraco.text
@ -29,243 +23,237 @@ log = logging.getLogger(__name__)
class NotifierException(Exception):
pass
pass
class FileFilter(object):
def set_root(self, root):
self.root = root
def set_root(self, root):
self.root = root
def _get_file_path(self, filename):
try:
filename = os.path.join(self.root, filename)
except AttributeError:
pass
return filename
def _get_file_path(self, filename):
try:
filename = os.path.join(self.root, filename)
except AttributeError:
pass
return filename
class ModifiedTimeFilter(FileFilter):
"""
Returns true for each call where the modified time of the file is after
the cutoff time.
"""
def __init__(self, cutoff):
self.cutoff = cutoff
"""
Returns true for each call where the modified time of the file is after
the cutoff time.
"""
def __call__(self, file):
filepath = self._get_file_path(file)
last_mod = datetime.datetime.utcfromtimestamp(
os.stat(filepath).st_mtime)
log.debug('{filepath} last modified at {last_mod}.'.format(**vars()))
return last_mod > self.cutoff
def __init__(self, cutoff):
self.cutoff = cutoff
def __call__(self, file):
filepath = self._get_file_path(file)
last_mod = datetime.datetime.utcfromtimestamp(os.stat(filepath).st_mtime)
log.debug('{filepath} last modified at {last_mod}.'.format(**vars()))
return last_mod > self.cutoff
class PatternFilter(FileFilter):
"""
Filter that returns True for files that match pattern (a regular
expression).
"""
def __init__(self, pattern):
self.pattern = (
re.compile(pattern) if isinstance(pattern, six.string_types)
else pattern
)
"""
Filter that returns True for files that match pattern (a regular
expression).
"""
def __call__(self, file):
return bool(self.pattern.match(file, re.I))
def __init__(self, pattern):
self.pattern = re.compile(pattern) if isinstance(pattern, str) else pattern
def __call__(self, file):
return bool(self.pattern.match(file, re.I))
class GlobFilter(PatternFilter):
"""
Filter that returns True for files that match the pattern (a glob
expression.
"""
def __init__(self, expression):
super(GlobFilter, self).__init__(
self.convert_file_pattern(expression))
"""
Filter that returns True for files that match the pattern (a glob
expression.
"""
@staticmethod
def convert_file_pattern(p):
r"""
converts a filename specification (such as c:\*.*) to an equivelent
regular expression
>>> GlobFilter.convert_file_pattern('/*')
'/.*'
"""
subs = (('\\', '\\\\'), ('.', '\\.'), ('*', '.*'), ('?', '.'))
return jaraco.text.multi_substitution(*subs)(p)
def __init__(self, expression):
super(GlobFilter, self).__init__(self.convert_file_pattern(expression))
@staticmethod
def convert_file_pattern(p):
r"""
converts a filename specification (such as c:\*.*) to an equivelent
regular expression
>>> GlobFilter.convert_file_pattern('/*')
'/.*'
"""
subs = (('\\', '\\\\'), ('.', '\\.'), ('*', '.*'), ('?', '.'))
return jaraco.text.multi_substitution(*subs)(p)
class AggregateFilter(FileFilter):
"""
This file filter will aggregate the filters passed to it, and when called,
will return the results of each filter ANDed together.
"""
def __init__(self, *filters):
self.filters = filters
"""
This file filter will aggregate the filters passed to it, and when called,
will return the results of each filter ANDed together.
"""
def set_root(self, root):
consume(f.set_root(root) for f in self.filters)
def __init__(self, *filters):
self.filters = filters
def __call__(self, file):
return all(fil(file) for fil in self.filters)
def set_root(self, root):
consume(f.set_root(root) for f in self.filters)
def __call__(self, file):
return all(fil(file) for fil in self.filters)
class OncePerModFilter(FileFilter):
def __init__(self):
self.history = list()
def __init__(self):
self.history = list()
def __call__(self, file):
file = os.path.join(self.root, file)
key = file, os.stat(file).st_mtime
result = key not in self.history
self.history.append(key)
if len(self.history) > 100:
del self.history[-50:]
return result
def __call__(self, file):
file = os.path.join(self.root, file)
key = file, os.stat(file).st_mtime
result = key not in self.history
self.history.append(key)
if len(self.history) > 100:
del self.history[-50:]
return result
def files_with_path(files, path):
return (os.path.join(path, file) for file in files)
return (os.path.join(path, file) for file in files)
def get_file_paths(walk_result):
root, dirs, files = walk_result
return files_with_path(files, root)
root, dirs, files = walk_result
return files_with_path(files, root)
class Notifier(object):
def __init__(self, root='.', filters=[]):
# assign the root, verify it exists
self.root = root
if not os.path.isdir(self.root):
raise NotifierException(
'Root directory "%s" does not exist' % self.root)
self.filters = filters
def __init__(self, root='.', filters=[]):
# assign the root, verify it exists
self.root = root
if not os.path.isdir(self.root):
raise NotifierException('Root directory "%s" does not exist' % self.root)
self.filters = filters
self.watch_subtree = False
self.quit_event = event.CreateEvent(None, 0, 0, None)
self.opm_filter = OncePerModFilter()
self.watch_subtree = False
self.quit_event = event.CreateEvent(None, 0, 0, None)
self.opm_filter = OncePerModFilter()
def __del__(self):
try:
fs.FindCloseChangeNotification(self.hChange)
except Exception:
pass
def __del__(self):
try:
fs.FindCloseChangeNotification(self.hChange)
except Exception:
pass
def _get_change_handle(self):
# set up to monitor the directory tree specified
self.hChange = fs.FindFirstChangeNotification(
self.root,
self.watch_subtree,
fs.FILE_NOTIFY_CHANGE_LAST_WRITE,
)
def _get_change_handle(self):
# set up to monitor the directory tree specified
self.hChange = fs.FindFirstChangeNotification(
self.root, self.watch_subtree, fs.FILE_NOTIFY_CHANGE_LAST_WRITE
)
# make sure it worked; if not, bail
INVALID_HANDLE_VALUE = fs.INVALID_HANDLE_VALUE
if self.hChange == INVALID_HANDLE_VALUE:
raise NotifierException(
'Could not set up directory change notification')
# make sure it worked; if not, bail
INVALID_HANDLE_VALUE = fs.INVALID_HANDLE_VALUE
if self.hChange == INVALID_HANDLE_VALUE:
raise NotifierException('Could not set up directory change notification')
@staticmethod
def _filtered_walk(path, file_filter):
"""
static method that calls os.walk, but filters out
anything that doesn't match the filter
"""
for root, dirs, files in os.walk(path):
log.debug('looking in %s', root)
log.debug('files is %s', files)
file_filter.set_root(root)
files = filter(file_filter, files)
log.debug('filtered files is %s', files)
yield (root, dirs, files)
@staticmethod
def _filtered_walk(path, file_filter):
"""
static method that calls os.walk, but filters out
anything that doesn't match the filter
"""
for root, dirs, files in os.walk(path):
log.debug('looking in %s', root)
log.debug('files is %s', files)
file_filter.set_root(root)
files = filter(file_filter, files)
log.debug('filtered files is %s', files)
yield (root, dirs, files)
def quit(self):
event.SetEvent(self.quit_event)
def quit(self):
event.SetEvent(self.quit_event)
class BlockingNotifier(Notifier):
@staticmethod
def wait_results(*args):
"""calls WaitForMultipleObjects repeatedly with args"""
return itertools.starmap(event.WaitForMultipleObjects, itertools.repeat(args))
@staticmethod
def wait_results(*args):
""" calls WaitForMultipleObjects repeatedly with args """
return itertools.starmap(
event.WaitForMultipleObjects,
itertools.repeat(args))
def get_changed_files(self):
self._get_change_handle()
check_time = datetime.datetime.utcnow()
# block (sleep) until something changes in the
# target directory or a quit is requested.
# timeout so we can catch keyboard interrupts or other exceptions
events = (self.hChange, self.quit_event)
for result in self.wait_results(events, False, 1000):
if result == event.WAIT_TIMEOUT:
continue
index = result - event.WAIT_OBJECT_0
if events[index] is self.quit_event:
# quit was received; stop yielding results
return
def get_changed_files(self):
self._get_change_handle()
check_time = datetime.datetime.utcnow()
# block (sleep) until something changes in the
# target directory or a quit is requested.
# timeout so we can catch keyboard interrupts or other exceptions
events = (self.hChange, self.quit_event)
for result in self.wait_results(events, False, 1000):
if result == event.WAIT_TIMEOUT:
continue
index = result - event.WAIT_OBJECT_0
if events[index] is self.quit_event:
# quit was received; stop yielding results
return
# something has changed.
log.debug('Change notification received')
fs.FindNextChangeNotification(self.hChange)
next_check_time = datetime.datetime.utcnow()
log.debug('Looking for all files changed after %s', check_time)
for file in self.find_files_after(check_time):
yield file
check_time = next_check_time
# something has changed.
log.debug('Change notification received')
fs.FindNextChangeNotification(self.hChange)
next_check_time = datetime.datetime.utcnow()
log.debug('Looking for all files changed after %s', check_time)
for file in self.find_files_after(check_time):
yield file
check_time = next_check_time
def find_files_after(self, cutoff):
mtf = ModifiedTimeFilter(cutoff)
af = AggregateFilter(mtf, self.opm_filter, *self.filters)
results = Notifier._filtered_walk(self.root, af)
results = itertools.imap(get_file_paths, results)
if self.watch_subtree:
result = itertools.chain(*results)
else:
result = next(results)
return result
def find_files_after(self, cutoff):
mtf = ModifiedTimeFilter(cutoff)
af = AggregateFilter(mtf, self.opm_filter, *self.filters)
results = Notifier._filtered_walk(self.root, af)
results = itertools.imap(get_file_paths, results)
if self.watch_subtree:
result = itertools.chain(*results)
else:
result = next(results)
return result
class ThreadedNotifier(BlockingNotifier, Thread):
r"""
ThreadedNotifier provides a simple interface that calls the handler
for each file rooted in root that passes the filters. It runs as its own
thread, so must be started as such::
r"""
ThreadedNotifier provides a simple interface that calls the handler
for each file rooted in root that passes the filters. It runs as its own
thread, so must be started as such::
notifier = ThreadedNotifier('c:\\', handler = StreamHandler())
notifier.start()
C:\Autoexec.bat changed.
"""
def __init__(self, root='.', filters=[], handler=lambda file: None):
# init notifier stuff
BlockingNotifier.__init__(self, root, filters)
# init thread stuff
Thread.__init__(self)
# set it as a daemon thread so that it doesn't block waiting to close.
# I tried setting __del__(self) to .quit(), but unfortunately, there
# are references to this object in the win32api stuff, so __del__
# never gets called.
self.setDaemon(True)
notifier = ThreadedNotifier('c:\\', handler = StreamHandler())
notifier.start()
C:\Autoexec.bat changed.
"""
self.handle = handler
def __init__(self, root='.', filters=[], handler=lambda file: None):
# init notifier stuff
BlockingNotifier.__init__(self, root, filters)
# init thread stuff
Thread.__init__(self)
# set it as a daemon thread so that it doesn't block waiting to close.
# I tried setting __del__(self) to .quit(), but unfortunately, there
# are references to this object in the win32api stuff, so __del__
# never gets called.
self.setDaemon(True)
def run(self):
for file in self.get_changed_files():
self.handle(file)
self.handle = handler
def run(self):
for file in self.get_changed_files():
self.handle(file)
class StreamHandler(object):
"""
StreamHandler: a sample handler object for use with the threaded
notifier that will announce by writing to the supplied stream
(stdout by default) the name of the file.
"""
def __init__(self, output=sys.stdout):
self.output = output
"""
StreamHandler: a sample handler object for use with the threaded
notifier that will announce by writing to the supplied stream
(stdout by default) the name of the file.
"""
def __call__(self, filename):
self.output.write('%s changed.\n' % filename)
def __init__(self, output=sys.stdout):
self.output = output
def __call__(self, filename):
self.output.write('%s changed.\n' % filename)

View file

@ -3,8 +3,6 @@ Some routines for retrieving the addresses from the local
network config.
"""
from __future__ import print_function
import itertools
import ctypes
@ -13,112 +11,108 @@ from jaraco.windows.api import errors, inet
def GetAdaptersAddresses():
size = ctypes.c_ulong()
res = inet.GetAdaptersAddresses(0, 0, None, None, size)
if res != errors.ERROR_BUFFER_OVERFLOW:
raise RuntimeError("Error getting structure length (%d)" % res)
print(size.value)
pointer_type = ctypes.POINTER(inet.IP_ADAPTER_ADDRESSES)
buffer = ctypes.create_string_buffer(size.value)
struct_p = ctypes.cast(buffer, pointer_type)
res = inet.GetAdaptersAddresses(0, 0, None, struct_p, size)
if res != errors.NO_ERROR:
raise RuntimeError("Error retrieving table (%d)" % res)
while struct_p:
yield struct_p.contents
struct_p = struct_p.contents.next
size = ctypes.c_ulong()
res = inet.GetAdaptersAddresses(0, 0, None, None, size)
if res != errors.ERROR_BUFFER_OVERFLOW:
raise RuntimeError("Error getting structure length (%d)" % res)
print(size.value)
pointer_type = ctypes.POINTER(inet.IP_ADAPTER_ADDRESSES)
buffer = ctypes.create_string_buffer(size.value)
struct_p = ctypes.cast(buffer, pointer_type)
res = inet.GetAdaptersAddresses(0, 0, None, struct_p, size)
if res != errors.NO_ERROR:
raise RuntimeError("Error retrieving table (%d)" % res)
while struct_p:
yield struct_p.contents
struct_p = struct_p.contents.next
class AllocatedTable(object):
"""
Both the interface table and the ip address table use the same
technique to store arrays of structures of variable length. This
base class captures the functionality to retrieve and access those
table entries.
"""
Both the interface table and the ip address table use the same
technique to store arrays of structures of variable length. This
base class captures the functionality to retrieve and access those
table entries.
The subclass needs to define three class attributes:
method: a callable that takes three arguments - a pointer to
the structure, the length of the data contained by the
structure, and a boolean of whether the result should
be sorted.
structure: a C structure defininition that describes the table
format.
row_structure: a C structure definition that describes the row
format.
"""
def __get_table_size(self):
"""
Retrieve the size of the buffer needed by calling the method
with a null pointer and length of zero. This should trigger an
insufficient buffer error and return the size needed for the
buffer.
"""
length = ctypes.wintypes.DWORD()
res = self.method(None, length, False)
if res != errors.ERROR_INSUFFICIENT_BUFFER:
raise RuntimeError("Error getting table length (%d)" % res)
return length.value
The subclass needs to define three class attributes:
method: a callable that takes three arguments - a pointer to
the structure, the length of the data contained by the
structure, and a boolean of whether the result should
be sorted.
structure: a C structure defininition that describes the table
format.
row_structure: a C structure definition that describes the row
format.
"""
def get_table(self):
"""
Get the table
"""
buffer_length = self.__get_table_size()
returned_buffer_length = ctypes.wintypes.DWORD(buffer_length)
buffer = ctypes.create_string_buffer(buffer_length)
pointer_type = ctypes.POINTER(self.structure)
table_p = ctypes.cast(buffer, pointer_type)
res = self.method(table_p, returned_buffer_length, False)
if res != errors.NO_ERROR:
raise RuntimeError("Error retrieving table (%d)" % res)
return table_p.contents
def __get_table_size(self):
"""
Retrieve the size of the buffer needed by calling the method
with a null pointer and length of zero. This should trigger an
insufficient buffer error and return the size needed for the
buffer.
"""
length = ctypes.wintypes.DWORD()
res = self.method(None, length, False)
if res != errors.ERROR_INSUFFICIENT_BUFFER:
raise RuntimeError("Error getting table length (%d)" % res)
return length.value
@property
def entries(self):
"""
Using the table structure, return the array of entries based
on the table size.
"""
table = self.get_table()
entries_array = self.row_structure * table.num_entries
pointer_type = ctypes.POINTER(entries_array)
return ctypes.cast(table.entries, pointer_type).contents
def get_table(self):
"""
Get the table
"""
buffer_length = self.__get_table_size()
returned_buffer_length = ctypes.wintypes.DWORD(buffer_length)
buffer = ctypes.create_string_buffer(buffer_length)
pointer_type = ctypes.POINTER(self.structure)
table_p = ctypes.cast(buffer, pointer_type)
res = self.method(table_p, returned_buffer_length, False)
if res != errors.NO_ERROR:
raise RuntimeError("Error retrieving table (%d)" % res)
return table_p.contents
@property
def entries(self):
"""
Using the table structure, return the array of entries based
on the table size.
"""
table = self.get_table()
entries_array = self.row_structure * table.num_entries
pointer_type = ctypes.POINTER(entries_array)
return ctypes.cast(table.entries, pointer_type).contents
class InterfaceTable(AllocatedTable):
method = inet.GetIfTable
structure = inet.MIB_IFTABLE
row_structure = inet.MIB_IFROW
method = inet.GetIfTable
structure = inet.MIB_IFTABLE
row_structure = inet.MIB_IFROW
class AddressTable(AllocatedTable):
method = inet.GetIpAddrTable
structure = inet.MIB_IPADDRTABLE
row_structure = inet.MIB_IPADDRROW
method = inet.GetIpAddrTable
structure = inet.MIB_IPADDRTABLE
row_structure = inet.MIB_IPADDRROW
class AddressManager(object):
@staticmethod
def hardware_address_to_string(addr):
hex_bytes = (byte.encode('hex') for byte in addr)
return ':'.join(hex_bytes)
@staticmethod
def hardware_address_to_string(addr):
hex_bytes = (byte.encode('hex') for byte in addr)
return ':'.join(hex_bytes)
def get_host_mac_address_strings(self):
return (
self.hardware_address_to_string(addr)
for addr in self.get_host_mac_addresses())
def get_host_mac_address_strings(self):
return (
self.hardware_address_to_string(addr)
for addr in self.get_host_mac_addresses()
)
def get_host_ip_address_strings(self):
return itertools.imap(str, self.get_host_ip_addresses())
def get_host_ip_address_strings(self):
return itertools.imap(str, self.get_host_ip_addresses())
def get_host_mac_addresses(self):
return (
entry.physical_address
for entry in InterfaceTable().entries
)
def get_host_mac_addresses(self):
return (entry.physical_address for entry in InterfaceTable().entries)
def get_host_ip_addresses(self):
return (
entry.address
for entry in AddressTable().entries
)
def get_host_ip_addresses(self):
return (entry.address for entry in AddressTable().entries)

View file

@ -4,18 +4,18 @@ from .api import library
def find_lib(lib):
r"""
Find the DLL for a given library.
r"""
Find the DLL for a given library.
Accepts a string or loaded module
Accepts a string or loaded module
>>> print(find_lib('kernel32').lower())
c:\windows\system32\kernel32.dll
"""
if isinstance(lib, str):
lib = getattr(ctypes.windll, lib)
>>> print(find_lib('kernel32').lower())
c:\windows\system32\kernel32.dll
"""
if isinstance(lib, str):
lib = getattr(ctypes.windll, lib)
size = 1024
result = ctypes.create_unicode_buffer(size)
library.GetModuleFileName(lib._handle, result, size)
return result.value
size = 1024
result = ctypes.create_unicode_buffer(size)
library.GetModuleFileName(lib._handle, result, size)
return result.value

View file

@ -5,25 +5,25 @@ from .api import memory
class LockedMemory(object):
def __init__(self, handle):
self.handle = handle
def __init__(self, handle):
self.handle = handle
def __enter__(self):
self.data_ptr = memory.GlobalLock(self.handle)
if not self.data_ptr:
del self.data_ptr
raise WinError()
return self
def __enter__(self):
self.data_ptr = memory.GlobalLock(self.handle)
if not self.data_ptr:
del self.data_ptr
raise WinError()
return self
def __exit__(self, *args):
memory.GlobalUnlock(self.handle)
del self.data_ptr
def __exit__(self, *args):
memory.GlobalUnlock(self.handle)
del self.data_ptr
@property
def data(self):
with self:
return ctypes.string_at(self.data_ptr, self.size)
@property
def data(self):
with self:
return ctypes.string_at(self.data_ptr, self.size)
@property
def size(self):
return memory.GlobalSize(self.data_ptr)
@property
def size(self):
return memory.GlobalSize(self.data_ptr)

View file

@ -1,63 +1,66 @@
import ctypes.wintypes
import six
from .error import handle_nonzero_success
from .api import memory
class MemoryMap(object):
"""
A memory map object which can have security attributes overridden.
"""
def __init__(self, name, length, security_attributes=None):
self.name = name
self.length = length
self.security_attributes = security_attributes
self.pos = 0
"""
A memory map object which can have security attributes overridden.
"""
def __enter__(self):
p_SA = (
ctypes.byref(self.security_attributes)
if self.security_attributes else None
)
INVALID_HANDLE_VALUE = -1
PAGE_READWRITE = 0x4
FILE_MAP_WRITE = 0x2
filemap = ctypes.windll.kernel32.CreateFileMappingW(
INVALID_HANDLE_VALUE, p_SA, PAGE_READWRITE, 0, self.length,
six.text_type(self.name))
handle_nonzero_success(filemap)
if filemap == INVALID_HANDLE_VALUE:
raise Exception("Failed to create file mapping")
self.filemap = filemap
self.view = memory.MapViewOfFile(filemap, FILE_MAP_WRITE, 0, 0, 0)
return self
def __init__(self, name, length, security_attributes=None):
self.name = name
self.length = length
self.security_attributes = security_attributes
self.pos = 0
def seek(self, pos):
self.pos = pos
def __enter__(self):
p_SA = (
ctypes.byref(self.security_attributes) if self.security_attributes else None
)
INVALID_HANDLE_VALUE = -1
PAGE_READWRITE = 0x4
FILE_MAP_WRITE = 0x2
filemap = ctypes.windll.kernel32.CreateFileMappingW(
INVALID_HANDLE_VALUE,
p_SA,
PAGE_READWRITE,
0,
self.length,
str(self.name),
)
handle_nonzero_success(filemap)
if filemap == INVALID_HANDLE_VALUE:
raise Exception("Failed to create file mapping")
self.filemap = filemap
self.view = memory.MapViewOfFile(filemap, FILE_MAP_WRITE, 0, 0, 0)
return self
def write(self, msg):
assert isinstance(msg, bytes)
n = len(msg)
if self.pos + n >= self.length: # A little safety.
raise ValueError("Refusing to write %d bytes" % n)
dest = self.view + self.pos
length = ctypes.c_size_t(n)
ctypes.windll.kernel32.RtlMoveMemory(dest, msg, length)
self.pos += n
def seek(self, pos):
self.pos = pos
def read(self, n):
"""
Read n bytes from mapped view.
"""
out = ctypes.create_string_buffer(n)
source = self.view + self.pos
length = ctypes.c_size_t(n)
ctypes.windll.kernel32.RtlMoveMemory(out, source, length)
self.pos += n
return out.raw
def write(self, msg):
assert isinstance(msg, bytes)
n = len(msg)
if self.pos + n >= self.length: # A little safety.
raise ValueError("Refusing to write %d bytes" % n)
dest = self.view + self.pos
length = ctypes.c_size_t(n)
ctypes.windll.kernel32.RtlMoveMemory(dest, msg, length)
self.pos += n
def __exit__(self, exc_type, exc_val, tb):
ctypes.windll.kernel32.UnmapViewOfFile(self.view)
ctypes.windll.kernel32.CloseHandle(self.filemap)
def read(self, n):
"""
Read n bytes from mapped view.
"""
out = ctypes.create_string_buffer(n)
source = self.view + self.pos
length = ctypes.c_size_t(n)
ctypes.windll.kernel32.RtlMoveMemory(out, source, length)
self.pos += n
return out.raw
def __exit__(self, exc_type, exc_val, tb):
ctypes.windll.kernel32.UnmapViewOfFile(self.view)
ctypes.windll.kernel32.CloseHandle(self.filemap)

View file

@ -1,5 +1,3 @@
# -*- coding: UTF-8 -*-
"""cookies.py
Cookie support utilities
@ -8,52 +6,50 @@ Cookie support utilities
import os
import itertools
import six
class CookieMonster(object):
"Read cookies out of a user's IE cookies file"
"Read cookies out of a user's IE cookies file"
@property
def cookie_dir(self):
import _winreg as winreg
key = winreg.OpenKeyEx(
winreg.HKEY_CURRENT_USER, 'Software'
'\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders')
cookie_dir, type = winreg.QueryValueEx(key, 'Cookies')
return cookie_dir
@property
def cookie_dir(self):
import _winreg as winreg
def entries(self, filename):
with open(os.path.join(self.cookie_dir, filename)) as cookie_file:
while True:
entry = itertools.takewhile(
self.is_not_cookie_delimiter,
cookie_file)
entry = list(map(six.text_type.rstrip, entry))
if not entry:
break
cookie = self.make_cookie(*entry)
yield cookie
key = winreg.OpenKeyEx(
winreg.HKEY_CURRENT_USER,
'Software' r'\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders',
)
cookie_dir, type = winreg.QueryValueEx(key, 'Cookies')
return cookie_dir
@staticmethod
def is_not_cookie_delimiter(s):
return s != '*\n'
def entries(self, filename):
with open(os.path.join(self.cookie_dir, filename)) as cookie_file:
while True:
entry = itertools.takewhile(self.is_not_cookie_delimiter, cookie_file)
entry = [item.rstrip() for item in entry]
if not entry:
break
cookie = self.make_cookie(*entry)
yield cookie
@staticmethod
def make_cookie(
key, value, domain, flags, ExpireLow, ExpireHigh,
CreateLow, CreateHigh):
expires = (int(ExpireHigh) << 32) | int(ExpireLow)
created = (int(CreateHigh) << 32) | int(CreateLow)
flags = int(flags)
domain, sep, path = domain.partition('/')
path = '/' + path
return dict(
key=key,
value=value,
domain=domain,
flags=flags,
expires=expires,
created=created,
path=path,
)
@staticmethod
def is_not_cookie_delimiter(s):
return s != '*\n'
@staticmethod
def make_cookie(
key, value, domain, flags, ExpireLow, ExpireHigh, CreateLow, CreateHigh
):
expires = (int(ExpireHigh) << 32) | int(ExpireLow)
created = (int(CreateHigh) << 32) | int(CreateLow)
flags = int(flags)
domain, sep, path = domain.partition('/')
path = '/' + path
return dict(
key=key,
value=value,
domain=domain,
flags=flags,
expires=expires,
created=created,
path=path,
)

View file

@ -0,0 +1,37 @@
import subprocess
default_components = [
'Microsoft.VisualStudio.Component.CoreEditor',
'Microsoft.VisualStudio.Workload.CoreEditor',
'Microsoft.VisualStudio.Component.Roslyn.Compiler',
'Microsoft.Component.MSBuild',
'Microsoft.VisualStudio.Component.TextTemplating',
'Microsoft.VisualStudio.Component.VC.CoreIde',
'Microsoft.VisualStudio.Component.VC.Tools.x86.x64',
'Microsoft.VisualStudio.Component.VC.Tools.ARM64',
'Microsoft.VisualStudio.Component.Windows10SDK.19041',
'Microsoft.VisualStudio.Component.VC.Redist.14.Latest',
'Microsoft.VisualStudio.ComponentGroup.NativeDesktop.Core',
'Microsoft.VisualStudio.Workload.NativeDesktop',
]
def install(components=default_components):
cmd = [
'vs_buildtools',
'--quiet',
'--wait',
'--norestart',
'--nocache',
'--installPath',
'C:\\BuildTools',
]
for component in components:
cmd += ['--add', component]
res = subprocess.Popen(cmd).wait()
if res != 3010:
raise SystemExit(res)
__name__ == '__main__' and install()

View file

@ -2,29 +2,30 @@
API hooks for network stuff.
"""
__all__ = ('AddConnection')
__all__ = 'AddConnection'
from jaraco.windows.error import WindowsError
from .api import net
def AddConnection(
remote_name, type=net.RESOURCETYPE_ANY, local_name=None,
provider_name=None, user=None, password=None, flags=0):
resource = net.NETRESOURCE(
type=type,
remote_name=remote_name,
local_name=local_name,
provider_name=provider_name,
# WNetAddConnection2 ignores the other members of NETRESOURCE
)
remote_name,
type=net.RESOURCETYPE_ANY,
local_name=None,
provider_name=None,
user=None,
password=None,
flags=0,
):
resource = net.NETRESOURCE(
type=type,
remote_name=remote_name,
local_name=local_name,
provider_name=provider_name,
# WNetAddConnection2 ignores the other members of NETRESOURCE
)
result = net.WNetAddConnection2(
resource,
password,
user,
flags,
)
result = net.WNetAddConnection2(resource, password, user, flags)
if result != 0:
raise WindowsError(result)
if result != 0:
raise WindowsError(result)

View file

@ -1,77 +1,75 @@
# -*- coding: utf-8 -*-
from __future__ import print_function
import itertools
import contextlib
from more_itertools.recipes import consume, unique_justseen
try:
import wmi as wmilib
import wmi as wmilib
except ImportError:
pass
pass
from jaraco.windows.error import handle_nonzero_success
from .api import power
def GetSystemPowerStatus():
stat = power.SYSTEM_POWER_STATUS()
handle_nonzero_success(GetSystemPowerStatus(stat))
return stat
stat = power.SYSTEM_POWER_STATUS()
handle_nonzero_success(GetSystemPowerStatus(stat))
return stat
def _init_power_watcher():
global power_watcher
if 'power_watcher' not in globals():
wmi = wmilib.WMI()
query = 'SELECT * from Win32_PowerManagementEvent'
power_watcher = wmi.ExecNotificationQuery(query)
global power_watcher
if 'power_watcher' not in globals():
wmi = wmilib.WMI()
query = 'SELECT * from Win32_PowerManagementEvent'
power_watcher = wmi.ExecNotificationQuery(query)
def get_power_management_events():
_init_power_watcher()
while True:
yield power_watcher.NextEvent()
_init_power_watcher()
while True:
yield power_watcher.NextEvent()
def wait_for_power_status_change():
EVT_POWER_STATUS_CHANGE = 10
EVT_POWER_STATUS_CHANGE = 10
def not_power_status_change(evt):
return evt.EventType != EVT_POWER_STATUS_CHANGE
events = get_power_management_events()
consume(itertools.takewhile(not_power_status_change, events))
def not_power_status_change(evt):
return evt.EventType != EVT_POWER_STATUS_CHANGE
events = get_power_management_events()
consume(itertools.takewhile(not_power_status_change, events))
def get_unique_power_states():
"""
Just like get_power_states, but ensures values are returned only
when the state changes.
"""
return unique_justseen(get_power_states())
"""
Just like get_power_states, but ensures values are returned only
when the state changes.
"""
return unique_justseen(get_power_states())
def get_power_states():
"""
Continuously return the power state of the system when it changes.
This function will block indefinitely if the power state never
changes.
"""
while True:
state = GetSystemPowerStatus()
yield state.ac_line_status_string
wait_for_power_status_change()
"""
Continuously return the power state of the system when it changes.
This function will block indefinitely if the power state never
changes.
"""
while True:
state = GetSystemPowerStatus()
yield state.ac_line_status_string
wait_for_power_status_change()
@contextlib.contextmanager
def no_sleep():
"""
Context that prevents the computer from going to sleep.
"""
mode = power.ES.continuous | power.ES.system_required
handle_nonzero_success(power.SetThreadExecutionState(mode))
try:
yield
finally:
handle_nonzero_success(power.SetThreadExecutionState(power.ES.continuous))
"""
Context that prevents the computer from going to sleep.
"""
mode = power.ES.continuous | power.ES.system_required
handle_nonzero_success(power.SetThreadExecutionState(mode))
try:
yield
finally:
handle_nonzero_success(power.SetThreadExecutionState(power.ES.continuous))

View file

@ -1,5 +1,3 @@
from __future__ import print_function
import ctypes
from ctypes import wintypes
@ -9,134 +7,138 @@ from .api import process
def get_process_token():
"""
Get the current process token
"""
token = wintypes.HANDLE()
res = process.OpenProcessToken(
process.GetCurrentProcess(), process.TOKEN_ALL_ACCESS, token)
if not res > 0:
raise RuntimeError("Couldn't get process token")
return token
"""
Get the current process token
"""
token = wintypes.HANDLE()
res = process.OpenProcessToken(
process.GetCurrentProcess(), process.TOKEN_ALL_ACCESS, token
)
if not res > 0:
raise RuntimeError("Couldn't get process token")
return token
def get_symlink_luid():
"""
Get the LUID for the SeCreateSymbolicLinkPrivilege
"""
symlink_luid = privilege.LUID()
res = privilege.LookupPrivilegeValue(
None, "SeCreateSymbolicLinkPrivilege", symlink_luid)
if not res > 0:
raise RuntimeError("Couldn't lookup privilege value")
return symlink_luid
"""
Get the LUID for the SeCreateSymbolicLinkPrivilege
"""
symlink_luid = privilege.LUID()
res = privilege.LookupPrivilegeValue(
None, "SeCreateSymbolicLinkPrivilege", symlink_luid
)
if not res > 0:
raise RuntimeError("Couldn't lookup privilege value")
return symlink_luid
def get_privilege_information():
"""
Get all privileges associated with the current process.
"""
# first call with zero length to determine what size buffer we need
"""
Get all privileges associated with the current process.
"""
# first call with zero length to determine what size buffer we need
return_length = wintypes.DWORD()
params = [
get_process_token(),
privilege.TOKEN_INFORMATION_CLASS.TokenPrivileges,
None,
0,
return_length,
]
return_length = wintypes.DWORD()
params = [
get_process_token(),
privilege.TOKEN_INFORMATION_CLASS.TokenPrivileges,
None,
0,
return_length,
]
res = privilege.GetTokenInformation(*params)
res = privilege.GetTokenInformation(*params)
# assume we now have the necessary length in return_length
# assume we now have the necessary length in return_length
buffer = ctypes.create_string_buffer(return_length.value)
params[2] = buffer
params[3] = return_length.value
buffer = ctypes.create_string_buffer(return_length.value)
params[2] = buffer
params[3] = return_length.value
res = privilege.GetTokenInformation(*params)
assert res > 0, "Error in second GetTokenInformation (%d)" % res
res = privilege.GetTokenInformation(*params)
assert res > 0, "Error in second GetTokenInformation (%d)" % res
privileges = ctypes.cast(
buffer, ctypes.POINTER(privilege.TOKEN_PRIVILEGES)).contents
return privileges
privileges = ctypes.cast(
buffer, ctypes.POINTER(privilege.TOKEN_PRIVILEGES)
).contents
return privileges
def report_privilege_information():
"""
Report all privilege information assigned to the current process.
"""
privileges = get_privilege_information()
print("found {0} privileges".format(privileges.count))
tuple(map(print, privileges))
"""
Report all privilege information assigned to the current process.
"""
privileges = get_privilege_information()
print("found {0} privileges".format(privileges.count))
tuple(map(print, privileges))
def enable_symlink_privilege():
"""
Try to assign the symlink privilege to the current process token.
Return True if the assignment is successful.
"""
# create a space in memory for a TOKEN_PRIVILEGES structure
# with one element
size = ctypes.sizeof(privilege.TOKEN_PRIVILEGES)
size += ctypes.sizeof(privilege.LUID_AND_ATTRIBUTES)
buffer = ctypes.create_string_buffer(size)
tp = ctypes.cast(buffer, ctypes.POINTER(privilege.TOKEN_PRIVILEGES)).contents
tp.count = 1
tp.get_array()[0].enable()
tp.get_array()[0].LUID = get_symlink_luid()
token = get_process_token()
res = privilege.AdjustTokenPrivileges(token, False, tp, 0, None, None)
if res == 0:
raise RuntimeError("Error in AdjustTokenPrivileges")
"""
Try to assign the symlink privilege to the current process token.
Return True if the assignment is successful.
"""
# create a space in memory for a TOKEN_PRIVILEGES structure
# with one element
size = ctypes.sizeof(privilege.TOKEN_PRIVILEGES)
size += ctypes.sizeof(privilege.LUID_AND_ATTRIBUTES)
buffer = ctypes.create_string_buffer(size)
tp = ctypes.cast(buffer, ctypes.POINTER(privilege.TOKEN_PRIVILEGES)).contents
tp.count = 1
tp.get_array()[0].enable()
tp.get_array()[0].LUID = get_symlink_luid()
token = get_process_token()
res = privilege.AdjustTokenPrivileges(token, False, tp, 0, None, None)
if res == 0:
raise RuntimeError("Error in AdjustTokenPrivileges")
ERROR_NOT_ALL_ASSIGNED = 1300
return ctypes.windll.kernel32.GetLastError() != ERROR_NOT_ALL_ASSIGNED
ERROR_NOT_ALL_ASSIGNED = 1300
return ctypes.windll.kernel32.GetLastError() != ERROR_NOT_ALL_ASSIGNED
class PolicyHandle(wintypes.HANDLE):
pass
pass
class LSA_UNICODE_STRING(ctypes.Structure):
_fields_ = [
('length', ctypes.c_ushort),
('max_length', ctypes.c_ushort),
('buffer', ctypes.wintypes.LPWSTR),
]
_fields_ = [
('length', ctypes.c_ushort),
('max_length', ctypes.c_ushort),
('buffer', ctypes.wintypes.LPWSTR),
]
def OpenPolicy(system_name, object_attributes, access_mask):
policy = PolicyHandle()
raise NotImplementedError(
"Need to construct structures for parameters "
"(see http://msdn.microsoft.com/en-us/library/windows"
"/desktop/aa378299%28v=vs.85%29.aspx)")
res = ctypes.windll.advapi32.LsaOpenPolicy(
system_name, object_attributes,
access_mask, ctypes.byref(policy))
assert res == 0, "Error status {res}".format(**vars())
return policy
policy = PolicyHandle()
raise NotImplementedError(
"Need to construct structures for parameters "
"(see http://msdn.microsoft.com/en-us/library/windows"
"/desktop/aa378299%28v=vs.85%29.aspx)"
)
res = ctypes.windll.advapi32.LsaOpenPolicy(
system_name, object_attributes, access_mask, ctypes.byref(policy)
)
assert res == 0, "Error status {res}".format(**vars())
return policy
def grant_symlink_privilege(who, machine=''):
"""
Grant the 'create symlink' privilege to who.
"""
Grant the 'create symlink' privilege to who.
Based on http://support.microsoft.com/kb/132958
"""
flags = security.POLICY_CREATE_ACCOUNT | security.POLICY_LOOKUP_NAMES
policy = OpenPolicy(machine, flags)
return policy
Based on http://support.microsoft.com/kb/132958
"""
flags = security.POLICY_CREATE_ACCOUNT | security.POLICY_LOOKUP_NAMES
policy = OpenPolicy(machine, flags)
return policy
def main():
assigned = enable_symlink_privilege()
msg = ['failure', 'success'][assigned]
assigned = enable_symlink_privilege()
msg = ['failure', 'success'][assigned]
print("Symlink privilege assignment completed with {0}".format(msg))
print("Symlink privilege assignment completed with {0}".format(msg))
if __name__ == '__main__':
main()
main()

View file

@ -1,20 +1,18 @@
import winreg
from itertools import count
import six
winreg = six.moves.winreg
def key_values(key):
for index in count():
try:
yield winreg.EnumValue(key, index)
except WindowsError:
break
for index in count():
try:
yield winreg.EnumValue(key, index)
except WindowsError:
break
def key_subkeys(key):
for index in count():
try:
yield winreg.EnumKey(key, index)
except WindowsError:
break
for index in count():
try:
yield winreg.EnumKey(key, index)
except WindowsError:
break

View file

@ -1,35 +1,34 @@
from __future__ import division
import ctypes.wintypes
from .error import handle_nonzero_success
from .api import filesystem
def DeviceIoControl(
device, io_control_code, in_buffer, out_buffer, overlapped=None):
if overlapped is not None:
raise NotImplementedError("overlapped handles not yet supported")
def DeviceIoControl(device, io_control_code, in_buffer, out_buffer, overlapped=None):
if overlapped is not None:
raise NotImplementedError("overlapped handles not yet supported")
if isinstance(out_buffer, int):
out_buffer = ctypes.create_string_buffer(out_buffer)
if isinstance(out_buffer, int):
out_buffer = ctypes.create_string_buffer(out_buffer)
in_buffer_size = len(in_buffer) if in_buffer is not None else 0
out_buffer_size = len(out_buffer)
assert isinstance(out_buffer, ctypes.Array)
in_buffer_size = len(in_buffer) if in_buffer is not None else 0
out_buffer_size = len(out_buffer)
assert isinstance(out_buffer, ctypes.Array)
returned_bytes = ctypes.wintypes.DWORD()
returned_bytes = ctypes.wintypes.DWORD()
res = filesystem.DeviceIoControl(
device,
io_control_code,
in_buffer, in_buffer_size,
out_buffer, out_buffer_size,
returned_bytes,
overlapped,
)
res = filesystem.DeviceIoControl(
device,
io_control_code,
in_buffer,
in_buffer_size,
out_buffer,
out_buffer_size,
returned_bytes,
overlapped,
)
handle_nonzero_success(res)
handle_nonzero_success(returned_bytes)
handle_nonzero_success(res)
handle_nonzero_success(returned_bytes)
return out_buffer[:returned_bytes.value]
return out_buffer[: returned_bytes.value]

View file

@ -5,63 +5,66 @@ from .api import security
def GetTokenInformation(token, information_class):
"""
Given a token, get the token information for it.
"""
data_size = ctypes.wintypes.DWORD()
ctypes.windll.advapi32.GetTokenInformation(
token, information_class.num,
0, 0, ctypes.byref(data_size))
data = ctypes.create_string_buffer(data_size.value)
handle_nonzero_success(ctypes.windll.advapi32.GetTokenInformation(
token,
information_class.num,
ctypes.byref(data), ctypes.sizeof(data),
ctypes.byref(data_size)))
return ctypes.cast(data, ctypes.POINTER(security.TOKEN_USER)).contents
"""
Given a token, get the token information for it.
"""
data_size = ctypes.wintypes.DWORD()
ctypes.windll.advapi32.GetTokenInformation(
token, information_class.num, 0, 0, ctypes.byref(data_size)
)
data = ctypes.create_string_buffer(data_size.value)
handle_nonzero_success(
ctypes.windll.advapi32.GetTokenInformation(
token,
information_class.num,
ctypes.byref(data),
ctypes.sizeof(data),
ctypes.byref(data_size),
)
)
return ctypes.cast(data, ctypes.POINTER(security.TOKEN_USER)).contents
def OpenProcessToken(proc_handle, access):
result = ctypes.wintypes.HANDLE()
proc_handle = ctypes.wintypes.HANDLE(proc_handle)
handle_nonzero_success(ctypes.windll.advapi32.OpenProcessToken(
proc_handle, access, ctypes.byref(result)))
return result
result = ctypes.wintypes.HANDLE()
proc_handle = ctypes.wintypes.HANDLE(proc_handle)
handle_nonzero_success(
ctypes.windll.advapi32.OpenProcessToken(
proc_handle, access, ctypes.byref(result)
)
)
return result
def get_current_user():
"""
Return a TOKEN_USER for the owner of this process.
"""
process = OpenProcessToken(
ctypes.windll.kernel32.GetCurrentProcess(),
security.TokenAccess.TOKEN_QUERY,
)
return GetTokenInformation(process, security.TOKEN_USER)
"""
Return a TOKEN_USER for the owner of this process.
"""
process = OpenProcessToken(
ctypes.windll.kernel32.GetCurrentProcess(), security.TokenAccess.TOKEN_QUERY
)
return GetTokenInformation(process, security.TOKEN_USER)
def get_security_attributes_for_user(user=None):
"""
Return a SECURITY_ATTRIBUTES structure with the SID set to the
specified user (uses current user if none is specified).
"""
if user is None:
user = get_current_user()
"""
Return a SECURITY_ATTRIBUTES structure with the SID set to the
specified user (uses current user if none is specified).
"""
if user is None:
user = get_current_user()
assert isinstance(user, security.TOKEN_USER), (
"user must be TOKEN_USER instance")
assert isinstance(user, security.TOKEN_USER), "user must be TOKEN_USER instance"
SD = security.SECURITY_DESCRIPTOR()
SA = security.SECURITY_ATTRIBUTES()
# by attaching the actual security descriptor, it will be garbage-
# collected with the security attributes
SA.descriptor = SD
SA.bInheritHandle = 1
SD = security.SECURITY_DESCRIPTOR()
SA = security.SECURITY_ATTRIBUTES()
# by attaching the actual security descriptor, it will be garbage-
# collected with the security attributes
SA.descriptor = SD
SA.bInheritHandle = 1
ctypes.windll.advapi32.InitializeSecurityDescriptor(
ctypes.byref(SD),
security.SECURITY_DESCRIPTOR.REVISION)
ctypes.windll.advapi32.SetSecurityDescriptorOwner(
ctypes.byref(SD),
user.SID, 0)
return SA
ctypes.windll.advapi32.InitializeSecurityDescriptor(
ctypes.byref(SD), security.SECURITY_DESCRIPTOR.REVISION
)
ctypes.windll.advapi32.SetSecurityDescriptorOwner(ctypes.byref(SD), user.SID, 0)
return SA

View file

@ -5,8 +5,6 @@ Based on http://code.activestate.com
/recipes/115875-controlling-windows-services/
"""
from __future__ import print_function
import sys
import time
@ -16,221 +14,240 @@ import win32service
class Service(object):
"""
The Service Class is used for controlling Windows
services. Just pass the name of the service you wish to control to the
class instance and go from there. For example, if you want to control
the Workstation service try this:
"""
The Service Class is used for controlling Windows
services. Just pass the name of the service you wish to control to the
class instance and go from there. For example, if you want to control
the Workstation service try this:
from jaraco.windows import services
workstation = services.Service("Workstation")
workstation.start()
workstation.fetchstatus("running", 10)
workstation.stop()
workstation.fetchstatus("stopped")
from jaraco.windows import services
workstation = services.Service("Workstation")
workstation.start()
workstation.fetchstatus("running", 10)
workstation.stop()
workstation.fetchstatus("stopped")
Creating an instance of the Service class is done by passing the name of
the service as it appears in the Management Console or the short name as
it appears in the registry. Mixed case is ok.
cvs = services.Service("CVS NT Service 1.11.1.2 (Build 41)")
or
cvs = services.Service("cvs")
Creating an instance of the Service class is done by passing the name of
the service as it appears in the Management Console or the short name as
it appears in the registry. Mixed case is ok.
cvs = services.Service("CVS NT Service 1.11.1.2 (Build 41)")
or
cvs = services.Service("cvs")
If needing remote service control try this:
cvs = services.Service("cvs", r"\\CVS_SERVER")
or
cvs = services.Service("cvs", "\\\\CVS_SERVER")
If needing remote service control try this:
cvs = services.Service("cvs", r"\\CVS_SERVER")
or
cvs = services.Service("cvs", "\\\\CVS_SERVER")
The Service Class supports these methods:
The Service Class supports these methods:
start: Starts service.
stop: Stops service.
restart: Stops and restarts service.
pause: Pauses service (Only if service supports feature).
resume: Resumes service that has been paused.
status: Queries current status of service.
fetchstatus: Continually queries service until requested
status(STARTING, RUNNING,
STOPPING & STOPPED) is met or timeout value(in seconds) reached.
Default timeout value is infinite.
infotype: Queries service for process type. (Single, shared and/or
interactive process)
infoctrl: Queries control information about a running service.
i.e. Can it be paused, stopped, etc?
infostartup: Queries service Startup type. (Boot, System,
Automatic, Manual, Disabled)
setstartup Changes/sets Startup type. (Boot, System,
Automatic, Manual, Disabled)
getname: Gets the long and short service names used by Windowin32service.
(Generally used for internal purposes)
"""
start: Starts service.
stop: Stops service.
restart: Stops and restarts service.
pause: Pauses service (Only if service supports feature).
resume: Resumes service that has been paused.
status: Queries current status of service.
fetchstatus: Continually queries service until requested
status(STARTING, RUNNING,
STOPPING & STOPPED) is met or timeout value(in seconds) reached.
Default timeout value is infinite.
infotype: Queries service for process type. (Single, shared and/or
interactive process)
infoctrl: Queries control information about a running service.
i.e. Can it be paused, stopped, etc?
infostartup: Queries service Startup type. (Boot, System,
Automatic, Manual, Disabled)
setstartup: Changes/sets Startup type. (Boot, System,
Automatic, Manual, Disabled)
getname: Gets the long and short service names used by Windowin32service.
(Generally used for internal purposes)
"""
def __init__(self, service, machinename=None, dbname=None):
self.userv = service
self.scmhandle = win32service.OpenSCManager(
machinename, dbname, win32service.SC_MANAGER_ALL_ACCESS)
self.sserv, self.lserv = self.getname()
if (self.sserv or self.lserv) is None:
sys.exit()
self.handle = win32service.OpenService(
self.scmhandle, self.sserv, win32service.SERVICE_ALL_ACCESS)
self.sccss = "SYSTEM\\CurrentControlSet\\Services\\"
def __init__(self, service, machinename=None, dbname=None):
self.userv = service
self.scmhandle = win32service.OpenSCManager(
machinename, dbname, win32service.SC_MANAGER_ALL_ACCESS
)
self.sserv, self.lserv = self.getname()
if (self.sserv or self.lserv) is None:
sys.exit()
self.handle = win32service.OpenService(
self.scmhandle, self.sserv, win32service.SERVICE_ALL_ACCESS
)
self.sccss = "SYSTEM\\CurrentControlSet\\Services\\"
def start(self):
win32service.StartService(self.handle, None)
def start(self):
win32service.StartService(self.handle, None)
def stop(self):
self.stat = win32service.ControlService(
self.handle, win32service.SERVICE_CONTROL_STOP)
def stop(self):
self.stat = win32service.ControlService(
self.handle, win32service.SERVICE_CONTROL_STOP
)
def restart(self):
self.stop()
self.fetchstatus("STOPPED")
self.start()
def restart(self):
self.stop()
self.fetchstatus("STOPPED")
self.start()
def pause(self):
self.stat = win32service.ControlService(
self.handle, win32service.SERVICE_CONTROL_PAUSE)
def pause(self):
self.stat = win32service.ControlService(
self.handle, win32service.SERVICE_CONTROL_PAUSE
)
def resume(self):
self.stat = win32service.ControlService(
self.handle, win32service.SERVICE_CONTROL_CONTINUE)
def resume(self):
self.stat = win32service.ControlService(
self.handle, win32service.SERVICE_CONTROL_CONTINUE
)
def status(self, prn=0):
self.stat = win32service.QueryServiceStatus(self.handle)
if self.stat[1] == win32service.SERVICE_STOPPED:
if prn == 1:
print("The %s service is stopped." % self.lserv)
else:
return "STOPPED"
elif self.stat[1] == win32service.SERVICE_START_PENDING:
if prn == 1:
print("The %s service is starting." % self.lserv)
else:
return "STARTING"
elif self.stat[1] == win32service.SERVICE_STOP_PENDING:
if prn == 1:
print("The %s service is stopping." % self.lserv)
else:
return "STOPPING"
elif self.stat[1] == win32service.SERVICE_RUNNING:
if prn == 1:
print("The %s service is running." % self.lserv)
else:
return "RUNNING"
def status(self, prn=0):
self.stat = win32service.QueryServiceStatus(self.handle)
if self.stat[1] == win32service.SERVICE_STOPPED:
if prn == 1:
print("The %s service is stopped." % self.lserv)
else:
return "STOPPED"
elif self.stat[1] == win32service.SERVICE_START_PENDING:
if prn == 1:
print("The %s service is starting." % self.lserv)
else:
return "STARTING"
elif self.stat[1] == win32service.SERVICE_STOP_PENDING:
if prn == 1:
print("The %s service is stopping." % self.lserv)
else:
return "STOPPING"
elif self.stat[1] == win32service.SERVICE_RUNNING:
if prn == 1:
print("The %s service is running." % self.lserv)
else:
return "RUNNING"
def fetchstatus(self, fstatus, timeout=None):
self.fstatus = fstatus.upper()
if timeout is not None:
timeout = int(timeout)
timeout *= 2
def fetchstatus(self, fstatus, timeout=None):
self.fstatus = fstatus.upper()
if timeout is not None:
timeout = int(timeout)
timeout *= 2
def to(timeout):
time.sleep(.5)
if timeout is not None:
if timeout > 1:
timeout -= 1
return timeout
else:
return "TO"
if self.fstatus == "STOPPED":
while 1:
self.stat = win32service.QueryServiceStatus(self.handle)
if self.stat[1] == win32service.SERVICE_STOPPED:
self.fstate = "STOPPED"
break
else:
timeout = to(timeout)
if timeout == "TO":
return "TIMEDOUT"
break
elif self.fstatus == "STOPPING":
while 1:
self.stat = win32service.QueryServiceStatus(self.handle)
if self.stat[1]==win32service.SERVICE_STOP_PENDING:
self.fstate = "STOPPING"
break
else:
timeout=to(timeout)
if timeout == "TO":
return "TIMEDOUT"
break
elif self.fstatus == "RUNNING":
while 1:
self.stat = win32service.QueryServiceStatus(self.handle)
if self.stat[1]==win32service.SERVICE_RUNNING:
self.fstate = "RUNNING"
break
else:
timeout=to(timeout)
if timeout == "TO":
return "TIMEDOUT"
break
elif self.fstatus == "STARTING":
while 1:
self.stat = win32service.QueryServiceStatus(self.handle)
if self.stat[1]==win32service.SERVICE_START_PENDING:
self.fstate = "STARTING"
break
else:
timeout=to(timeout)
if timeout == "TO":
return "TIMEDOUT"
break
def to(timeout):
time.sleep(0.5)
if timeout is not None:
if timeout > 1:
timeout -= 1
return timeout
else:
return "TO"
def infotype(self):
self.stat = win32service.QueryServiceStatus(self.handle)
if self.stat[0] and win32service.SERVICE_WIN32_OWN_PROCESS:
print("The %s service runs in its own process." % self.lserv)
if self.stat[0] and win32service.SERVICE_WIN32_SHARE_PROCESS:
print("The %s service shares a process with other services." % self.lserv)
if self.stat[0] and win32service.SERVICE_INTERACTIVE_PROCESS:
print("The %s service can interact with the desktop." % self.lserv)
if self.fstatus == "STOPPED":
while 1:
self.stat = win32service.QueryServiceStatus(self.handle)
if self.stat[1] == win32service.SERVICE_STOPPED:
self.fstate = "STOPPED"
break
else:
timeout = to(timeout)
if timeout == "TO":
return "TIMEDOUT"
break
elif self.fstatus == "STOPPING":
while 1:
self.stat = win32service.QueryServiceStatus(self.handle)
if self.stat[1] == win32service.SERVICE_STOP_PENDING:
self.fstate = "STOPPING"
break
else:
timeout = to(timeout)
if timeout == "TO":
return "TIMEDOUT"
break
elif self.fstatus == "RUNNING":
while 1:
self.stat = win32service.QueryServiceStatus(self.handle)
if self.stat[1] == win32service.SERVICE_RUNNING:
self.fstate = "RUNNING"
break
else:
timeout = to(timeout)
if timeout == "TO":
return "TIMEDOUT"
break
elif self.fstatus == "STARTING":
while 1:
self.stat = win32service.QueryServiceStatus(self.handle)
if self.stat[1] == win32service.SERVICE_START_PENDING:
self.fstate = "STARTING"
break
else:
timeout = to(timeout)
if timeout == "TO":
return "TIMEDOUT"
break
def infoctrl(self):
self.stat = win32service.QueryServiceStatus(self.handle)
if self.stat[2] and win32service.SERVICE_ACCEPT_PAUSE_CONTINUE:
print("The %s service can be paused." % self.lserv)
if self.stat[2] and win32service.SERVICE_ACCEPT_STOP:
print("The %s service can be stopped." % self.lserv)
if self.stat[2] and win32service.SERVICE_ACCEPT_SHUTDOWN:
print("The %s service can be shutdown." % self.lserv)
def infotype(self):
self.stat = win32service.QueryServiceStatus(self.handle)
if self.stat[0] and win32service.SERVICE_WIN32_OWN_PROCESS:
print("The %s service runs in its own process." % self.lserv)
if self.stat[0] and win32service.SERVICE_WIN32_SHARE_PROCESS:
print("The %s service shares a process with other services." % self.lserv)
if self.stat[0] and win32service.SERVICE_INTERACTIVE_PROCESS:
print("The %s service can interact with the desktop." % self.lserv)
def infostartup(self):
self.isuphandle = win32api.RegOpenKeyEx(win32con.HKEY_LOCAL_MACHINE, self.sccss + self.sserv, 0, win32con.KEY_READ)
self.isuptype = win32api.RegQueryValueEx(self.isuphandle, "Start")[0]
win32api.RegCloseKey(self.isuphandle)
if self.isuptype == 0:
return "boot"
elif self.isuptype == 1:
return "system"
elif self.isuptype == 2:
return "automatic"
elif self.isuptype == 3:
return "manual"
elif self.isuptype == 4:
return "disabled"
def infoctrl(self):
self.stat = win32service.QueryServiceStatus(self.handle)
if self.stat[2] and win32service.SERVICE_ACCEPT_PAUSE_CONTINUE:
print("The %s service can be paused." % self.lserv)
if self.stat[2] and win32service.SERVICE_ACCEPT_STOP:
print("The %s service can be stopped." % self.lserv)
if self.stat[2] and win32service.SERVICE_ACCEPT_SHUTDOWN:
print("The %s service can be shutdown." % self.lserv)
@property
def suptype(self):
types = 'boot', 'system', 'automatic', 'manual', 'disabled'
lookup = dict((name, number) for number, name in enumerate(types))
return lookup[self.startuptype]
def infostartup(self):
self.isuphandle = win32api.RegOpenKeyEx(
win32con.HKEY_LOCAL_MACHINE, self.sccss + self.sserv, 0, win32con.KEY_READ
)
self.isuptype = win32api.RegQueryValueEx(self.isuphandle, "Start")[0]
win32api.RegCloseKey(self.isuphandle)
if self.isuptype == 0:
return "boot"
elif self.isuptype == 1:
return "system"
elif self.isuptype == 2:
return "automatic"
elif self.isuptype == 3:
return "manual"
elif self.isuptype == 4:
return "disabled"
def setstartup(self, startuptype):
self.startuptype = startuptype.lower()
self.snc = win32service.SERVICE_NO_CHANGE
win32service.ChangeServiceConfig(self.handle, self.snc, self.suptype,
self.snc, None, None, 0, None, None, None, self.lserv)
@property
def suptype(self):
types = 'boot', 'system', 'automatic', 'manual', 'disabled'
lookup = dict((name, number) for number, name in enumerate(types))
return lookup[self.startuptype]
def getname(self):
self.snames=win32service.EnumServicesStatus(self.scmhandle)
for i in self.snames:
if i[0].lower() == self.userv.lower():
return i[0], i[1]
break
if i[1].lower() == self.userv.lower():
return i[0], i[1]
break
print("Error: The %s service doesn't seem to exist." % self.userv)
return None, None
def setstartup(self, startuptype):
self.startuptype = startuptype.lower()
self.snc = win32service.SERVICE_NO_CHANGE
win32service.ChangeServiceConfig(
self.handle,
self.snc,
self.suptype,
self.snc,
None,
None,
0,
None,
None,
None,
self.lserv,
)
def getname(self):
self.snames = win32service.EnumServicesStatus(self.scmhandle)
for i in self.snames:
if i[0].lower() == self.userv.lower():
return i[0], i[1]
break
if i[1].lower() == self.userv.lower():
return i[0], i[1]
break
print("Error: The %s service doesn't seem to exist." % self.userv)
return None, None

View file

@ -2,13 +2,13 @@ from .api import shell
def get_recycle_bin_confirm():
settings = shell.SHELLSTATE()
shell.SHGetSetSettings(settings, shell.SSF_NOCONFIRMRECYCLE, False)
return not settings.no_confirm_recycle
settings = shell.SHELLSTATE()
shell.SHGetSetSettings(settings, shell.SSF_NOCONFIRMRECYCLE, False)
return not settings.no_confirm_recycle
def set_recycle_bin_confirm(confirm=False):
settings = shell.SHELLSTATE()
settings.no_confirm_recycle = not confirm
shell.SHGetSetSettings(settings, shell.SSF_NOCONFIRMRECYCLE, True)
# cross fingers and hope it worked
settings = shell.SHELLSTATE()
settings.no_confirm_recycle = not confirm
shell.SHGetSetSettings(settings, shell.SSF_NOCONFIRMRECYCLE, True)
# cross fingers and hope it worked

View file

@ -1,71 +1,66 @@
# -*- coding: UTF-8 -*-
"""
timers
In particular, contains a waitable timer.
In particular, contains a waitable timer.
"""
from __future__ import absolute_import
import time
from six.moves import _thread
import _thread
from jaraco.windows.api import event as win32event
__author__ = 'Jason R. Coombs <jaraco@jaraco.com>'
class WaitableTimer:
"""
t = WaitableTimer()
t.set(None, 10) # every 10 seconds
t.wait_for_signal() # 10 seconds elapses
t.stop()
t.wait_for_signal(20) # 20 seconds elapses (timeout occurred)
"""
def __init__(self):
self.signal_event = win32event.CreateEvent(None, 0, 0, None)
self.stop_event = win32event.CreateEvent(None, 0, 0, None)
"""
t = WaitableTimer()
t.set(None, 10) # every 10 seconds
t.wait_for_signal() # 10 seconds elapses
t.stop()
t.wait_for_signal(20) # 20 seconds elapses (timeout occurred)
"""
def set(self, due_time, period):
_thread.start_new_thread(self._signal_loop, (due_time, period))
def __init__(self):
self.signal_event = win32event.CreateEvent(None, 0, 0, None)
self.stop_event = win32event.CreateEvent(None, 0, 0, None)
def stop(self):
win32event.SetEvent(self.stop_event)
def set(self, due_time, period):
_thread.start_new_thread(self._signal_loop, (due_time, period))
def wait_for_signal(self, timeout=None):
"""
wait for the signal; return after the signal has occurred or the
timeout in seconds elapses.
"""
timeout_ms = int(timeout * 1000) if timeout else win32event.INFINITE
win32event.WaitForSingleObject(self.signal_event, timeout_ms)
def stop(self):
win32event.SetEvent(self.stop_event)
def _signal_loop(self, due_time, period):
if not due_time and not period:
raise ValueError("due_time or period must be non-zero")
try:
if not due_time:
due_time = time.time() + period
if due_time:
self._wait(due_time - time.time())
while period:
due_time += period
self._wait(due_time - time.time())
except Exception:
pass
def wait_for_signal(self, timeout=None):
"""
wait for the signal; return after the signal has occurred or the
timeout in seconds elapses.
"""
timeout_ms = int(timeout * 1000) if timeout else win32event.INFINITE
win32event.WaitForSingleObject(self.signal_event, timeout_ms)
def _wait(self, seconds):
milliseconds = int(seconds * 1000)
if milliseconds > 0:
res = win32event.WaitForSingleObject(self.stop_event, milliseconds)
if res == win32event.WAIT_OBJECT_0:
raise Exception
if res == win32event.WAIT_TIMEOUT:
pass
win32event.SetEvent(self.signal_event)
def _signal_loop(self, due_time, period):
if not due_time and not period:
raise ValueError("due_time or period must be non-zero")
try:
if not due_time:
due_time = time.time() + period
if due_time:
self._wait(due_time - time.time())
while period:
due_time += period
self._wait(due_time - time.time())
except Exception:
pass
@staticmethod
def get_even_due_time(period):
now = time.time()
return now - (now % period)
def _wait(self, seconds):
milliseconds = int(seconds * 1000)
if milliseconds > 0:
res = win32event.WaitForSingleObject(self.stop_event, milliseconds)
if res == win32event.WAIT_OBJECT_0:
raise Exception
if res == win32event.WAIT_TIMEOUT:
pass
win32event.SetEvent(self.signal_event)
@staticmethod
def get_even_due_time(period):
now = time.time()
return now - (now % period)

View file

@ -10,245 +10,253 @@ from jaraco.collections import RangeMap
class AnyDict(object):
"A dictionary that returns the same value regardless of key"
"A dictionary that returns the same value regardless of key"
def __init__(self, value):
self.value = value
def __init__(self, value):
self.value = value
def __getitem__(self, key):
return self.value
def __getitem__(self, key):
return self.value
class SYSTEMTIME(Extended, ctypes.Structure):
_fields_ = [
('year', WORD),
('month', WORD),
('day_of_week', WORD),
('day', WORD),
('hour', WORD),
('minute', WORD),
('second', WORD),
('millisecond', WORD),
]
_fields_ = [
('year', WORD),
('month', WORD),
('day_of_week', WORD),
('day', WORD),
('hour', WORD),
('minute', WORD),
('second', WORD),
('millisecond', WORD),
]
class REG_TZI_FORMAT(Extended, ctypes.Structure):
_fields_ = [
('bias', LONG),
('standard_bias', LONG),
('daylight_bias', LONG),
('standard_start', SYSTEMTIME),
('daylight_start', SYSTEMTIME),
]
_fields_ = [
('bias', LONG),
('standard_bias', LONG),
('daylight_bias', LONG),
('standard_start', SYSTEMTIME),
('daylight_start', SYSTEMTIME),
]
class TIME_ZONE_INFORMATION(Extended, ctypes.Structure):
_fields_ = [
('bias', LONG),
('standard_name', WCHAR * 32),
('standard_start', SYSTEMTIME),
('standard_bias', LONG),
('daylight_name', WCHAR * 32),
('daylight_start', SYSTEMTIME),
('daylight_bias', LONG),
]
_fields_ = [
('bias', LONG),
('standard_name', WCHAR * 32),
('standard_start', SYSTEMTIME),
('standard_bias', LONG),
('daylight_name', WCHAR * 32),
('daylight_start', SYSTEMTIME),
('daylight_bias', LONG),
]
class DYNAMIC_TIME_ZONE_INFORMATION(TIME_ZONE_INFORMATION):
"""
Because the structure of the DYNAMIC_TIME_ZONE_INFORMATION extends
the structure of the TIME_ZONE_INFORMATION, this structure
can be used as a drop-in replacement for calls where the
structure is passed by reference.
"""
Because the structure of the DYNAMIC_TIME_ZONE_INFORMATION extends
the structure of the TIME_ZONE_INFORMATION, this structure
can be used as a drop-in replacement for calls where the
structure is passed by reference.
For example,
dynamic_tzi = DYNAMIC_TIME_ZONE_INFORMATION()
ctypes.windll.kernel32.GetTimeZoneInformation(ctypes.byref(dynamic_tzi))
For example,
dynamic_tzi = DYNAMIC_TIME_ZONE_INFORMATION()
ctypes.windll.kernel32.GetTimeZoneInformation(ctypes.byref(dynamic_tzi))
(although the key_name and dynamic_daylight_time_disabled flags will be
set to the default (null)).
(although the key_name and dynamic_daylight_time_disabled flags will be
set to the default (null)).
>>> isinstance(DYNAMIC_TIME_ZONE_INFORMATION(), TIME_ZONE_INFORMATION)
True
>>> isinstance(DYNAMIC_TIME_ZONE_INFORMATION(), TIME_ZONE_INFORMATION)
True
"""
_fields_ = [
# ctypes automatically includes the fields from the parent
('key_name', WCHAR * 128),
('dynamic_daylight_time_disabled', BOOL),
]
"""
def __init__(self, *args, **kwargs):
"""Allow initialization from args from both this class and
its superclass. Default ctypes implementation seems to
assume that this class is only initialized with its own
_fields_ (for non-keyword-args)."""
super_self = super(DYNAMIC_TIME_ZONE_INFORMATION, self)
super_fields = super_self._fields_
super_args = args[:len(super_fields)]
self_args = args[len(super_fields):]
# convert the super args to keyword args so they're also handled
for field, arg in zip(super_fields, super_args):
field_name, spec = field
kwargs[field_name] = arg
super(DYNAMIC_TIME_ZONE_INFORMATION, self).__init__(*self_args, **kwargs)
_fields_ = [
# ctypes automatically includes the fields from the parent
('key_name', WCHAR * 128),
('dynamic_daylight_time_disabled', BOOL),
]
def __init__(self, *args, **kwargs):
"""Allow initialization from args from both this class and
its superclass. Default ctypes implementation seems to
assume that this class is only initialized with its own
_fields_ (for non-keyword-args)."""
super_self = super(DYNAMIC_TIME_ZONE_INFORMATION, self)
super_fields = super_self._fields_
super_args = args[: len(super_fields)]
self_args = args[len(super_fields) :]
# convert the super args to keyword args so they're also handled
for field, arg in zip(super_fields, super_args):
field_name, spec = field
kwargs[field_name] = arg
super(DYNAMIC_TIME_ZONE_INFORMATION, self).__init__(*self_args, **kwargs)
class Info(DYNAMIC_TIME_ZONE_INFORMATION):
"""
A time zone definition class based on the win32
DYNAMIC_TIME_ZONE_INFORMATION structure.
"""
A time zone definition class based on the win32
DYNAMIC_TIME_ZONE_INFORMATION structure.
Describes a bias against UTC (bias), and two dates at which a separate
additional bias applies (standard_bias and daylight_bias).
"""
Describes a bias against UTC (bias), and two dates at which a separate
additional bias applies (standard_bias and daylight_bias).
"""
def field_names(self):
return map(operator.itemgetter(0), self._fields_)
def field_names(self):
return map(operator.itemgetter(0), self._fields_)
def __init__(self, *args, **kwargs):
"""
Try to construct a timezone.Info from
a) [DYNAMIC_]TIME_ZONE_INFORMATION args
b) another Info
c) a REG_TZI_FORMAT
d) a byte structure
"""
funcs = (
super(Info, self).__init__,
self.__init_from_other,
self.__init_from_reg_tzi,
self.__init_from_bytes,
)
for func in funcs:
try:
func(*args, **kwargs)
return
except TypeError:
pass
raise TypeError("Invalid arguments for %s" % self.__class__)
def __init__(self, *args, **kwargs):
"""
Try to construct a timezone.Info from
a) [DYNAMIC_]TIME_ZONE_INFORMATION args
b) another Info
c) a REG_TZI_FORMAT
d) a byte structure
"""
funcs = (
super(Info, self).__init__,
self.__init_from_other,
self.__init_from_reg_tzi,
self.__init_from_bytes,
)
for func in funcs:
try:
func(*args, **kwargs)
return
except TypeError:
pass
raise TypeError("Invalid arguments for %s" % self.__class__)
def __init_from_bytes(self, bytes, **kwargs):
reg_tzi = REG_TZI_FORMAT()
# todo: use buffer API in Python 3
buffer = memoryview(bytes)
ctypes.memmove(ctypes.addressof(reg_tzi), buffer, len(buffer))
self.__init_from_reg_tzi(self, reg_tzi, **kwargs)
def __init_from_bytes(self, bytes, **kwargs):
reg_tzi = REG_TZI_FORMAT()
# todo: use buffer API in Python 3
buffer = memoryview(bytes)
ctypes.memmove(ctypes.addressof(reg_tzi), buffer, len(buffer))
self.__init_from_reg_tzi(self, reg_tzi, **kwargs)
def __init_from_reg_tzi(self, reg_tzi, **kwargs):
if not isinstance(reg_tzi, REG_TZI_FORMAT):
raise TypeError("Not a REG_TZI_FORMAT")
for field_name, type in reg_tzi._fields_:
setattr(self, field_name, getattr(reg_tzi, field_name))
for name, value in kwargs.items():
setattr(self, name, value)
def __init_from_reg_tzi(self, reg_tzi, **kwargs):
if not isinstance(reg_tzi, REG_TZI_FORMAT):
raise TypeError("Not a REG_TZI_FORMAT")
for field_name, type in reg_tzi._fields_:
setattr(self, field_name, getattr(reg_tzi, field_name))
for name, value in kwargs.items():
setattr(self, name, value)
def __init_from_other(self, other):
if not isinstance(other, TIME_ZONE_INFORMATION):
raise TypeError("Not a TIME_ZONE_INFORMATION")
for name in other.field_names():
# explicitly get the value from the underlying structure
value = super(Info, other).__getattribute__(other, name)
setattr(self, name, value)
# consider instead of the loop above just copying the memory directly
# size = max(ctypes.sizeof(DYNAMIC_TIME_ZONE_INFO), ctypes.sizeof(other))
# ctypes.memmove(ctypes.addressof(self), other, size)
def __init_from_other(self, other):
if not isinstance(other, TIME_ZONE_INFORMATION):
raise TypeError("Not a TIME_ZONE_INFORMATION")
for name in other.field_names():
# explicitly get the value from the underlying structure
value = super(Info, other).__getattribute__(other, name)
setattr(self, name, value)
# consider instead of the loop above just copying the memory directly
# size = max(ctypes.sizeof(DYNAMIC_TIME_ZONE_INFO), ctypes.sizeof(other))
# ctypes.memmove(ctypes.addressof(self), other, size)
def __getattribute__(self, attr):
value = super(Info, self).__getattribute__(attr)
def __getattribute__(self, attr):
value = super(Info, self).__getattribute__(attr)
def make_minute_timedelta(m):
datetime.timedelta(minutes=m)
if 'bias' in attr:
value = make_minute_timedelta(value)
return value
def make_minute_timedelta(m):
datetime.timedelta(minutes=m)
@classmethod
def current(class_):
"Windows Platform SDK GetTimeZoneInformation"
tzi = class_()
kernel32 = ctypes.windll.kernel32
getter = kernel32.GetTimeZoneInformation
getter = getattr(kernel32, 'GetDynamicTimeZoneInformation', getter)
code = getter(ctypes.byref(tzi))
return code, tzi
if 'bias' in attr:
value = make_minute_timedelta(value)
return value
def set(self):
kernel32 = ctypes.windll.kernel32
setter = kernel32.SetTimeZoneInformation
setter = getattr(kernel32, 'SetDynamicTimeZoneInformation', setter)
return setter(ctypes.byref(self))
@classmethod
def current(class_):
"Windows Platform SDK GetTimeZoneInformation"
tzi = class_()
kernel32 = ctypes.windll.kernel32
getter = kernel32.GetTimeZoneInformation
getter = getattr(kernel32, 'GetDynamicTimeZoneInformation', getter)
code = getter(ctypes.byref(tzi))
return code, tzi
def copy(self):
return self.__class__(self)
def set(self):
kernel32 = ctypes.windll.kernel32
setter = kernel32.SetTimeZoneInformation
setter = getattr(kernel32, 'SetDynamicTimeZoneInformation', setter)
return setter(ctypes.byref(self))
def locate_daylight_start(self, year):
info = self.get_info_for_year(year)
return self._locate_day(year, info.daylight_start)
def copy(self):
return self.__class__(self)
def locate_standard_start(self, year):
info = self.get_info_for_year(year)
return self._locate_day(year, info.standard_start)
def locate_daylight_start(self, year):
info = self.get_info_for_year(year)
return self._locate_day(year, info.daylight_start)
def get_info_for_year(self, year):
return self.dynamic_info[year]
def locate_standard_start(self, year):
info = self.get_info_for_year(year)
return self._locate_day(year, info.standard_start)
@property
def dynamic_info(self):
"Return a map that for a given year will return the correct Info"
if self.key_name:
dyn_key = self.get_key().subkey('Dynamic DST')
del dyn_key['FirstEntry']
del dyn_key['LastEntry']
years = map(int, dyn_key.keys())
values = map(Info, dyn_key.values())
# create a range mapping that searches by descending year and matches
# if the target year is greater or equal.
return RangeMap(zip(years, values), RangeMap.descending, operator.ge)
else:
return AnyDict(self)
def get_info_for_year(self, year):
return self.dynamic_info[year]
@staticmethod
def _locate_day(year, cutoff):
"""
Takes a SYSTEMTIME object, such as retrieved from a TIME_ZONE_INFORMATION
structure or call to GetTimeZoneInformation and interprets
it based on the given
year to identify the actual day.
@property
def dynamic_info(self):
"Return a map that for a given year will return the correct Info"
if self.key_name:
dyn_key = self.get_key().subkey('Dynamic DST')
del dyn_key['FirstEntry']
del dyn_key['LastEntry']
years = map(int, dyn_key.keys())
values = map(Info, dyn_key.values())
# create a range mapping that searches by descending year and matches
# if the target year is greater or equal.
return RangeMap(zip(years, values), RangeMap.descending, operator.ge)
else:
return AnyDict(self)
This method is necessary because the SYSTEMTIME structure
refers to a day by its
day of the week and week of the month (e.g. 4th saturday in March).
@staticmethod
def _locate_day(year, cutoff):
"""
Takes a SYSTEMTIME object, such as retrieved from a TIME_ZONE_INFORMATION
structure or call to GetTimeZoneInformation and interprets
it based on the given
year to identify the actual day.
>>> SATURDAY = 6
>>> MARCH = 3
>>> st = SYSTEMTIME(2000, MARCH, SATURDAY, 4, 0, 0, 0, 0)
This method is necessary because the SYSTEMTIME structure
refers to a day by its
day of the week and week of the month (e.g. 4th saturday in March).
# according to my calendar, the 4th Saturday in March in 2009 was the 28th
>>> expected_date = datetime.datetime(2009, 3, 28)
>>> Info._locate_day(2009, st) == expected_date
True
"""
# MS stores Sunday as 0, Python datetime stores Monday as zero
target_weekday = (cutoff.day_of_week + 6) % 7
# For SYSTEMTIMEs relating to time zone inforamtion, cutoff.day
# is the week of the month
week_of_month = cutoff.day
# so the following is the first day of that week
day = (week_of_month - 1) * 7 + 1
result = datetime.datetime(
year, cutoff.month, day,
cutoff.hour, cutoff.minute, cutoff.second, cutoff.millisecond)
# now the result is the correct week, but not necessarily
# the correct day of the week
days_to_go = (target_weekday - result.weekday()) % 7
result += datetime.timedelta(days_to_go)
# if we selected a day in the month following the target month,
# move back a week or two.
# This is necessary because Microsoft defines the fifth week in a month
# to be the last week in a month and adding the time delta might have
# pushed the result into the next month.
while result.month == cutoff.month + 1:
result -= datetime.timedelta(weeks=1)
return result
>>> SATURDAY = 6
>>> MARCH = 3
>>> st = SYSTEMTIME(2000, MARCH, SATURDAY, 4, 0, 0, 0, 0)
# according to my calendar, the 4th Saturday in March in 2009 was the 28th
>>> expected_date = datetime.datetime(2009, 3, 28)
>>> Info._locate_day(2009, st) == expected_date
True
"""
# MS stores Sunday as 0, Python datetime stores Monday as zero
target_weekday = (cutoff.day_of_week + 6) % 7
# For SYSTEMTIMEs relating to time zone inforamtion, cutoff.day
# is the week of the month
week_of_month = cutoff.day
# so the following is the first day of that week
day = (week_of_month - 1) * 7 + 1
result = datetime.datetime(
year,
cutoff.month,
day,
cutoff.hour,
cutoff.minute,
cutoff.second,
cutoff.millisecond,
)
# now the result is the correct week, but not necessarily
# the correct day of the week
days_to_go = (target_weekday - result.weekday()) % 7
result += datetime.timedelta(days_to_go)
# if we selected a day in the month following the target month,
# move back a week or two.
# This is necessary because Microsoft defines the fifth week in a month
# to be the last week in a month and adding the time delta might have
# pushed the result into the next month.
while result.month == cutoff.month + 1:
result -= datetime.timedelta(weeks=1)
return result

View file

@ -5,5 +5,5 @@ from jaraco.windows.util import ensure_unicode
def MessageBox(text, caption=None, handle=None, type=None):
text, caption = map(ensure_unicode, (text, caption))
ctypes.windll.user32.MessageBoxW(handle, text, caption, type)
text, caption = map(ensure_unicode, (text, caption))
ctypes.windll.user32.MessageBoxW(handle, text, caption, type)

View file

@ -5,12 +5,12 @@ from .error import WindowsError, handle_nonzero_success
def get_user_name():
size = ctypes.wintypes.DWORD()
try:
handle_nonzero_success(GetUserName(None, size))
except WindowsError as e:
if e.code != errors.ERROR_INSUFFICIENT_BUFFER:
raise
buffer = ctypes.create_unicode_buffer(size.value)
handle_nonzero_success(GetUserName(buffer, size))
return buffer.value
size = ctypes.wintypes.DWORD()
try:
handle_nonzero_success(GetUserName(None, size))
except WindowsError as e:
if e.code != errors.ERROR_INSUFFICIENT_BUFFER:
raise
buffer = ctypes.create_unicode_buffer(size.value)
handle_nonzero_success(GetUserName(buffer, size))
return buffer.value

View file

@ -4,17 +4,18 @@ import ctypes
def ensure_unicode(param):
try:
param = ctypes.create_unicode_buffer(param)
except TypeError:
pass # just return the param as is
return param
try:
param = ctypes.create_unicode_buffer(param)
except TypeError:
pass # just return the param as is
return param
class Extended(object):
"Used to add extended capability to structures"
def __eq__(self, other):
return memoryview(self) == memoryview(other)
"Used to add extended capability to structures"
def __ne__(self, other):
return memoryview(self) != memoryview(other)
def __eq__(self, other):
return memoryview(self) == memoryview(other)
def __ne__(self, other):
return memoryview(self) != memoryview(other)

View file

@ -3,15 +3,19 @@ from path import Path
def install_pptp(name, param_lines):
"""
"""
# or consider using the API:
# http://msdn.microsoft.com/en-us/library/aa446739%28v=VS.85%29.aspx
pbk_path = (
Path(os.environ['PROGRAMDATA'])
/ 'Microsoft' / 'Network' / 'Connections' / 'pbk' / 'rasphone.pbk')
pbk_path.dirname().makedirs_p()
with open(pbk_path, 'a') as pbk:
pbk.write('[{name}]\n'.format(name=name))
pbk.writelines(param_lines)
pbk.write('\n')
""" """
# or consider using the API:
# http://msdn.microsoft.com/en-us/library/aa446739%28v=VS.85%29.aspx
pbk_path = (
Path(os.environ['PROGRAMDATA'])
/ 'Microsoft'
/ 'Network'
/ 'Connections'
/ 'pbk'
/ 'rasphone.pbk'
)
pbk_path.dirname().makedirs_p()
with open(pbk_path, 'a') as pbk:
pbk.write('[{name}]\n'.format(name=name))
pbk.writelines(param_lines)
pbk.write('\n')

View file

@ -1,7 +1,3 @@
#!python
from __future__ import print_function
import ctypes
from jaraco.windows.error import handle_nonzero_success
from jaraco.windows.api import system
@ -9,92 +5,84 @@ from jaraco.ui.cmdline import Command
def set(value):
result = system.SystemParametersInfo(
system.SPI_SETACTIVEWINDOWTRACKING,
0,
ctypes.cast(value, ctypes.c_void_p),
0,
)
handle_nonzero_success(result)
result = system.SystemParametersInfo(
system.SPI_SETACTIVEWINDOWTRACKING, 0, ctypes.cast(value, ctypes.c_void_p), 0
)
handle_nonzero_success(result)
def get():
value = ctypes.wintypes.BOOL()
result = system.SystemParametersInfo(
system.SPI_GETACTIVEWINDOWTRACKING,
0,
ctypes.byref(value),
0,
)
handle_nonzero_success(result)
return bool(value)
value = ctypes.wintypes.BOOL()
result = system.SystemParametersInfo(
system.SPI_GETACTIVEWINDOWTRACKING, 0, ctypes.byref(value), 0
)
handle_nonzero_success(result)
return bool(value)
def set_delay(milliseconds):
result = system.SystemParametersInfo(
system.SPI_SETACTIVEWNDTRKTIMEOUT,
0,
ctypes.cast(milliseconds, ctypes.c_void_p),
0,
)
handle_nonzero_success(result)
result = system.SystemParametersInfo(
system.SPI_SETACTIVEWNDTRKTIMEOUT,
0,
ctypes.cast(milliseconds, ctypes.c_void_p),
0,
)
handle_nonzero_success(result)
def get_delay():
value = ctypes.wintypes.DWORD()
result = system.SystemParametersInfo(
system.SPI_GETACTIVEWNDTRKTIMEOUT,
0,
ctypes.byref(value),
0,
)
handle_nonzero_success(result)
return int(value.value)
value = ctypes.wintypes.DWORD()
result = system.SystemParametersInfo(
system.SPI_GETACTIVEWNDTRKTIMEOUT, 0, ctypes.byref(value), 0
)
handle_nonzero_success(result)
return int(value.value)
class DelayParam(Command):
@staticmethod
def add_arguments(parser):
parser.add_argument(
'-d', '--delay', type=int,
help="Delay in milliseconds for active window tracking"
)
@staticmethod
def add_arguments(parser):
parser.add_argument(
'-d',
'--delay',
type=int,
help="Delay in milliseconds for active window tracking",
)
class Show(Command):
@classmethod
def run(cls, args):
msg = "xmouse: {enabled} (delay {delay}ms)".format(
enabled=get(),
delay=get_delay(),
)
print(msg)
@classmethod
def run(cls, args):
msg = "xmouse: {enabled} (delay {delay}ms)".format(
enabled=get(), delay=get_delay()
)
print(msg)
class Enable(DelayParam):
@classmethod
def run(cls, args):
print("enabling xmouse")
set(True)
args.delay and set_delay(args.delay)
@classmethod
def run(cls, args):
print("enabling xmouse")
set(True)
args.delay and set_delay(args.delay)
class Disable(DelayParam):
@classmethod
def run(cls, args):
print("disabling xmouse")
set(False)
args.delay and set_delay(args.delay)
@classmethod
def run(cls, args):
print("disabling xmouse")
set(False)
args.delay and set_delay(args.delay)
class Toggle(DelayParam):
@classmethod
def run(cls, args):
value = get()
print("xmouse: %s -> %s" % (value, not value))
set(not value)
args.delay and set_delay(args.delay)
@classmethod
def run(cls, args):
value = get()
print("xmouse: %s -> %s" % (value, not value))
set(not value)
args.delay and set_delay(args.delay)
if __name__ == '__main__':
Command.invoke()
Command.invoke()