diff --git a/libs/win/autocommand/__init__.py b/libs/win/autocommand/__init__.py new file mode 100644 index 00000000..73fbfca6 --- /dev/null +++ b/libs/win/autocommand/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2014-2016 Nathan West +# +# This file is part of autocommand. +# +# autocommand is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# autocommand is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with autocommand. If not, see <http://www.gnu.org/licenses/>. + +# flake8 flags all these imports as unused, hence the NOQAs everywhere. + +from .automain import automain # NOQA +from .autoparse import autoparse, smart_open # NOQA +from .autocommand import autocommand # NOQA + +try: + from .autoasync import autoasync # NOQA +except ImportError: # pragma: no cover + pass diff --git a/libs/win/autocommand/autoasync.py b/libs/win/autocommand/autoasync.py new file mode 100644 index 00000000..688f7e05 --- /dev/null +++ b/libs/win/autocommand/autoasync.py @@ -0,0 +1,142 @@ +# Copyright 2014-2015 Nathan West +# +# This file is part of autocommand. +# +# autocommand is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# autocommand is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with autocommand. If not, see <http://www.gnu.org/licenses/>. + +from asyncio import get_event_loop, iscoroutine +from functools import wraps +from inspect import signature + + +async def _run_forever_coro(coro, args, kwargs, loop): + ''' + This helper function launches an async main function that was tagged with + forever=True. There are three possibilities: + + - The function is a normal function, which handles initializing the event + loop, which is then run forever + - The function is a coroutine, which needs to be scheduled in the event + loop, which is then run forever + - The function is a normal function wrapping a coroutine function + + The function is therefore called unconditionally and scheduled in the event + loop if the return value is a coroutine object. + + The reason this is a separate function is to make absolutely sure that all + the objects created are garbage collected after all is said and done; we + do this to ensure that any exceptions raised in the tasks are collected + ASAP. + ''' + + # Personal note: I consider this an antipattern, as it relies on the use of + # unowned resources. The setup function dumps some stuff into the event + # loop where it just whirls in the ether without a well defined owner or + # lifetime. For this reason, there's a good chance I'll remove the + # forever=True feature from autoasync at some point in the future.
+ thing = coro(*args, **kwargs) + if iscoroutine(thing): + await thing + + +def autoasync(coro=None, *, loop=None, forever=False, pass_loop=False): + ''' + Convert an asyncio coroutine into a function which, when called, is + evaluated in an event loop, and the return value returned. This is intended + to make it easy to write entry points into asyncio coroutines, which + otherwise need to be explicitly evaluated with an event loop's + run_until_complete. + + If `loop` is given, it is used as the event loop to run the coro in. If it + is None (the default), the loop is retrieved using asyncio.get_event_loop. + This call is deferred until the decorated function is called, so that + callers can install custom event loops or event loop policies after + @autoasync is applied. + + If `forever` is True, the loop is run forever after the decorated coroutine + is finished. Use this for servers created with asyncio.start_server and the + like. + + If `pass_loop` is True, the event loop object is passed into the coroutine + as the `loop` kwarg when the wrapper function is called. In this case, the + wrapper function's __signature__ is updated to remove this parameter, so + that autoparse can still be used on it without generating a parameter for + `loop`. + + This decorator can be used with ( @autoasync(...) ) or without + ( @autoasync ) arguments. + + Examples: + + @autoasync + def get_file(host, port): + reader, writer = yield from asyncio.open_connection(host, port) + data = reader.read() + sys.stdout.write(data.decode()) + + get_file(host, port) + + @autoasync(forever=True, pass_loop=True) + def server(host, port, loop): + yield from loop.create_server(Proto, host, port) + + server('localhost', 8899) + + ''' + if coro is None: + return lambda c: autoasync( + c, loop=loop, + forever=forever, + pass_loop=pass_loop) + + # The old and new signatures are required to correctly bind the loop + # parameter in 100% of cases, even if it's a positional parameter. + # NOTE: A future release will probably require the loop parameter to be + # a kwonly parameter. + if pass_loop: + old_sig = signature(coro) + new_sig = old_sig.replace(parameters=( + param for name, param in old_sig.parameters.items() + if name != "loop")) + + @wraps(coro) + def autoasync_wrapper(*args, **kwargs): + # Defer the call to get_event_loop so that, if a custom policy is + # installed after the autoasync decorator, it is respected at call time + local_loop = get_event_loop() if loop is None else loop + + # Inject the 'loop' argument. We have to use this signature binding to + # ensure it's injected in the correct place (positional, keyword, etc) + if pass_loop: + bound_args = old_sig.bind_partial() + bound_args.arguments.update( + loop=local_loop, + **new_sig.bind(*args, **kwargs).arguments) + args, kwargs = bound_args.args, bound_args.kwargs + + if forever: + local_loop.create_task(_run_forever_coro( + coro, args, kwargs, local_loop + )) + local_loop.run_forever() + else: + return local_loop.run_until_complete(coro(*args, **kwargs)) + + # Attach the updated signature. This allows 'pass_loop' to be used with + # autoparse + if pass_loop: + autoasync_wrapper.__signature__ = new_sig + + return autoasync_wrapper diff --git a/libs/win/autocommand/autocommand.py b/libs/win/autocommand/autocommand.py new file mode 100644 index 00000000..097e86de --- /dev/null +++ b/libs/win/autocommand/autocommand.py @@ -0,0 +1,70 @@ +# Copyright 2014-2015 Nathan West +# +# This file is part of autocommand.
+# +# autocommand is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# autocommand is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with autocommand. If not, see <http://www.gnu.org/licenses/>. + +from .autoparse import autoparse +from .automain import automain +try: + from .autoasync import autoasync +except ImportError: # pragma: no cover + pass + + +def autocommand( + module, *, + description=None, + epilog=None, + add_nos=False, + parser=None, + loop=None, + forever=False, + pass_loop=False): + + if callable(module): + raise TypeError('autocommand requires a module name argument') + + def autocommand_decorator(func): + # Step 1: if requested, run it all in an asyncio event loop. autoasync + # patches the __signature__ of the decorated function, so that in the + # event that pass_loop is True, the `loop` parameter of the original + # function will *not* be interpreted as a command-line argument by + # autoparse + if loop is not None or forever or pass_loop: + func = autoasync( + func, + loop=None if loop is True else loop, + pass_loop=pass_loop, + forever=forever) + + # Step 2: create parser. We do this second so that the arguments are + # parsed and passed *before* entering the asyncio event loop, if it + # exists. This simplifies the stack trace and ensures errors are + # reported earlier. It also ensures that errors raised during parsing & + # passing are still raised if `forever` is True. + func = autoparse( + func, + description=description, + epilog=epilog, + add_nos=add_nos, + parser=parser) + + # Step 3: call the function automatically if __name__ == '__main__' (or + # if True was provided) + func = automain(module)(func) + + return func + + return autocommand_decorator diff --git a/libs/win/autocommand/automain.py b/libs/win/autocommand/automain.py new file mode 100644 index 00000000..6cc45db6 --- /dev/null +++ b/libs/win/autocommand/automain.py @@ -0,0 +1,59 @@ +# Copyright 2014-2015 Nathan West +# +# This file is part of autocommand. +# +# autocommand is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# autocommand is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with autocommand. If not, see <http://www.gnu.org/licenses/>. + +import sys +from .errors import AutocommandError + + +class AutomainRequiresModuleError(AutocommandError, TypeError): + pass + + +def automain(module, *, args=(), kwargs=None): + ''' + This decorator automatically invokes a function if the module is being run + as the "__main__" module. Optionally, provide args or kwargs with which to + call the function. If `module` is "__main__", the function is called, and + the program is `sys.exit`ed with the return value.
You can also pass `True` + to cause the function to be called unconditionally. If the function is not + called, it is returned unchanged by the decorator. + + Usage: + + @automain(__name__) # Pass __name__ to check __name__=="__main__" + def main(): + ... + + If __name__ is "__main__" here, the main function is called, and then + sys.exit is called with the return value. + ''' + + # Check that @automain(...) was called, rather than @automain + if callable(module): + raise AutomainRequiresModuleError(module) + + if module == '__main__' or module is True: + if kwargs is None: + kwargs = {} + + # Use a function definition instead of a lambda for a neater traceback + def automain_decorator(main): + sys.exit(main(*args, **kwargs)) + + return automain_decorator + else: + return lambda main: main diff --git a/libs/win/autocommand/autoparse.py b/libs/win/autocommand/autoparse.py new file mode 100644 index 00000000..0276a3fa --- /dev/null +++ b/libs/win/autocommand/autoparse.py @@ -0,0 +1,333 @@ +# Copyright 2014-2015 Nathan West +# +# This file is part of autocommand. +# +# autocommand is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# autocommand is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with autocommand. If not, see <http://www.gnu.org/licenses/>. + +import sys +from re import compile as compile_regex +from inspect import signature, getdoc, Parameter +from argparse import ArgumentParser +from contextlib import contextmanager +from functools import wraps +from io import IOBase +from autocommand.errors import AutocommandError + + +_empty = Parameter.empty + + +class AnnotationError(AutocommandError): + '''Annotation error: annotation must be a string, type, or tuple of both''' + + +class PositionalArgError(AutocommandError): + ''' + Positional arg error: autocommand can't handle positional-only parameters + ''' + + +class KWArgError(AutocommandError): + '''kwarg Error: autocommand can't handle a **kwargs parameter''' + + +class DocstringError(AutocommandError): + '''Docstring error''' + + +class TooManySplitsError(DocstringError): + ''' + The docstring had too many ---- section splits. Currently we only support + using up to a single split, to split the docstring into description and + epilog parts. + ''' + + +def _get_type_description(annotation): + ''' + Given an annotation, return the (type, description) for the parameter. + If you provide an annotation that is somehow both a string and a callable, + the behavior is undefined.
''' + if annotation is _empty: + return None, None + elif callable(annotation): + return annotation, None + elif isinstance(annotation, str): + return None, annotation + elif isinstance(annotation, tuple): + try: + arg1, arg2 = annotation + except ValueError as e: + raise AnnotationError(annotation) from e + else: + if callable(arg1) and isinstance(arg2, str): + return arg1, arg2 + elif isinstance(arg1, str) and callable(arg2): + return arg2, arg1 + + raise AnnotationError(annotation) + + +def _add_arguments(param, parser, used_char_args, add_nos): + ''' + Add the argument(s) to an ArgumentParser (using add_argument) for a given + parameter. used_char_args is the set of -short options currently in + use, and is updated (if necessary) by this function. If add_nos is True, + this will also add an inverse switch for all boolean options. For + instance, for the boolean parameter "verbose", this will create --verbose + and --no-verbose. + ''' + + # Impl note: This function is kept separate from make_parser because it's + # already very long and I wanted to separate out as much as possible into + # its own call scope, to prevent even the possibility of subtle mutation + # bugs. + if param.kind is param.POSITIONAL_ONLY: + raise PositionalArgError(param) + elif param.kind is param.VAR_KEYWORD: + raise KWArgError(param) + + # These are the kwargs for the add_argument function. + arg_spec = {} + is_option = False + + # Get the type and default from the annotation. + arg_type, description = _get_type_description(param.annotation) + + # Get the default value + default = param.default + + # If there is no explicit type, and the default is present and not None, + # infer the type from the default. + if arg_type is None and default not in {_empty, None}: + arg_type = type(default) + + # Add default. The presence of a default means this is an option, not an + # argument. + if default is not _empty: + arg_spec['default'] = default + is_option = True + + # Add the type + if arg_type is not None: + # Special case for bool: make it just a --switch + if arg_type is bool: + if not default or default is _empty: + arg_spec['action'] = 'store_true' + else: + arg_spec['action'] = 'store_false' + + # Switches are always options + is_option = True + + # Special case for file types: make it a string type, for filename + elif isinstance(default, IOBase): + arg_spec['type'] = str + + # TODO: special case for list type. + # - How to specify the type of list members? + # - param: [int] + # - param: int =[] + # - action='append' vs nargs='*' + + else: + arg_spec['type'] = arg_type + + # nargs: if the signature includes *args, collect them as trailing CLI + # arguments in a list. *args can't have a default value, so it can never be + # an option. + if param.kind is param.VAR_POSITIONAL: + # TODO: consider depluralizing metavar/name here. + arg_spec['nargs'] = '*' + + # Add description. + if description is not None: + arg_spec['help'] = description + + # Get the --flags + flags = [] + name = param.name + + if is_option: + # Add the first letter as a -short option. + for letter in name[0], name[0].swapcase(): + if letter not in used_char_args: + used_char_args.add(letter) + flags.append('-{}'.format(letter)) + break + + # If the parameter is a --long option, or is a -short option that + # somehow failed to get a flag, add it.
+ if len(name) > 1 or not flags: + flags.append('--{}'.format(name)) + + arg_spec['dest'] = name + else: + flags.append(name) + + parser.add_argument(*flags, **arg_spec) + + # Create the --no- version for boolean switches + if add_nos and arg_type is bool: + parser.add_argument( + '--no-{}'.format(name), + action='store_const', + dest=name, + const=default if default is not _empty else False) + + +def make_parser(func_sig, description, epilog, add_nos): + ''' + Given the signature of a function, create an ArgumentParser + ''' + parser = ArgumentParser(description=description, epilog=epilog) + + used_char_args = {'h'} + + # Arrange the params so that single-character arguments are first. This + # ensures they don't have to get --long versions. sorted is stable, so the + # parameters will otherwise still be in relative order. + params = sorted( + func_sig.parameters.values(), + key=lambda param: len(param.name) > 1) + + for param in params: + _add_arguments(param, parser, used_char_args, add_nos) + + return parser + + +_DOCSTRING_SPLIT = compile_regex(r'\n\s*-{4,}\s*\n') + + +def parse_docstring(docstring): + ''' + Given a docstring, parse it into description and epilog parts + ''' + if docstring is None: + return '', '' + + parts = _DOCSTRING_SPLIT.split(docstring) + + if len(parts) == 1: + return docstring, '' + elif len(parts) == 2: + return parts[0], parts[1] + else: + raise TooManySplitsError() + + +def autoparse( + func=None, *, + description=None, + epilog=None, + add_nos=False, + parser=None): + ''' + This decorator converts a function that takes normal arguments into a + function which takes a single optional argument, argv, parses it using an + argparse.ArgumentParser, and calls the underlying function with the parsed + arguments. If argv is not given, sys.argv[1:] is used. This is so that the + function can be used as a setuptools entry point, as well as a normal main + function. sys.argv[1:] is not evaluated until the function is called, to + allow injecting different arguments for testing. + + It uses the argument signature of the function to create an + ArgumentParser. Parameters without defaults become positional parameters, + while parameters *with* defaults become --options. Use annotations to set + the type of the parameter. + + The `description` and `epilog` parameters correspond to the same respective + argparse parameters. If no description is given, it defaults to the + decorated function's docstring, if present. + + If add_nos is True, every boolean option (that is, every parameter with a + default of True/False or a type of bool) will have a --no- version created + as well, which inverts the option. For instance, the --verbose option will + have a --no-verbose counterpart. These are not mutually exclusive; + whichever one appears last in the argument list will have precedence. + + If a parser is given, it is used instead of one generated from the function + signature. In this case, no parser is created; instead, the given parser is + used to parse the argv argument. The parser's results' argument names must + match up with the parameter names of the decorated function. + + The decorated function is attached to the result as the `func` attribute, + and the parser is attached as the `parser` attribute. + ''' + + # If @autoparse(...)
is used instead of @autoparse + if func is None: + return lambda f: autoparse( + f, description=description, + epilog=epilog, + add_nos=add_nos, + parser=parser) + + func_sig = signature(func) + + docstr_description, docstr_epilog = parse_docstring(getdoc(func)) + + if parser is None: + parser = make_parser( + func_sig, + description or docstr_description, + epilog or docstr_epilog, + add_nos) + + @wraps(func) + def autoparse_wrapper(argv=None): + if argv is None: + argv = sys.argv[1:] + + # Get empty argument binding, to fill with parsed arguments. This + # object does all the heavy lifting of turning named arguments + # into correctly bound *args and **kwargs. + parsed_args = func_sig.bind_partial() + parsed_args.arguments.update(vars(parser.parse_args(argv))) + + return func(*parsed_args.args, **parsed_args.kwargs) + + # TODO: attach an updated __signature__ to autoparse_wrapper, just in case. + + # Attach the wrapped function and parser, and return the wrapper. + autoparse_wrapper.func = func + autoparse_wrapper.parser = parser + return autoparse_wrapper + + +@contextmanager +def smart_open(filename_or_file, *args, **kwargs): + ''' + This context manager allows you to open a filename, or to pass through an + already-open file object, like sys.stdout, which shouldn't be + closed at the end of the context. If the filename argument is a str, bytes, + or int, the file object is created via a call to open with the given *args + and **kwargs, sent to the context, and closed at the end of the context, + just like "with open(filename) as f:". If it isn't one of the openable + types, the object is simply sent to the context unchanged, and is left + unclosed at the end of the context. Example: + + def work_with_file(name=sys.stdout): + with smart_open(name) as f: + # Works correctly if name is a str filename or sys.stdout + print("Some stuff", file=f) + # If it was a filename, f is closed at the end here. + ''' + if isinstance(filename_or_file, (str, bytes, int)): + with open(filename_or_file, *args, **kwargs) as file: + yield file + else: + yield filename_or_file diff --git a/libs/win/autocommand/errors.py b/libs/win/autocommand/errors.py new file mode 100644 index 00000000..25706073 --- /dev/null +++ b/libs/win/autocommand/errors.py @@ -0,0 +1,23 @@ +# Copyright 2014-2016 Nathan West +# +# This file is part of autocommand. +# +# autocommand is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# autocommand is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with autocommand. If not, see <http://www.gnu.org/licenses/>. + + +class AutocommandError(Exception): + '''Base class for autocommand exceptions''' + pass + +# Individual modules will define errors specific to that module.
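Taken together, the modules above let one plain function serve as a complete command-line program: autoparse builds an ArgumentParser from the signature, automain calls the function and sys.exits when the module runs as "__main__", and autoasync (only when loop/forever/pass_loop are requested) runs the result in an event loop. A minimal usage sketch, assuming this vendored package is importable as `autocommand`; the script name and parameter names below are hypothetical, not part of this change:

    # shout.py -- hypothetical demo of the decorators added above
    from autocommand import autocommand

    @autocommand(__name__, add_nos=True)
    def shout(message: 'text to print',
              repeat: (int, 'number of repetitions') = 1,
              loud=False):
        '''Print a message, optionally repeated and uppercased.'''
        # str annotation -> help text; (type, str) tuple -> typed option;
        # bool default -> a --loud switch, plus --no-loud from add_nos=True
        for _ in range(repeat):
            print(message.upper() if loud else message)

Invoked as `python shout.py -r 2 --loud hello`, `message` is parsed as a positional argument, `repeat` as a typed `-r/--repeat` option, and `loud` as a `--loud/--no-loud` switch pair; automain then passes the function's return value to sys.exit.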
diff --git a/libs/win/bin/enver.exe b/libs/win/bin/enver.exe index 3999307b..238225bc 100644 Binary files a/libs/win/bin/enver.exe and b/libs/win/bin/enver.exe differ diff --git a/libs/win/bin/find-symlinks.exe b/libs/win/bin/find-symlinks.exe index 4c98a64f..a048fb35 100644 Binary files a/libs/win/bin/find-symlinks.exe and b/libs/win/bin/find-symlinks.exe differ diff --git a/libs/win/bin/gclip.exe b/libs/win/bin/gclip.exe index 0c4805e1..b9b6c928 100644 Binary files a/libs/win/bin/gclip.exe and b/libs/win/bin/gclip.exe differ diff --git a/libs/win/bin/mklink.exe b/libs/win/bin/mklink.exe index 5d477820..eca54539 100644 Binary files a/libs/win/bin/mklink.exe and b/libs/win/bin/mklink.exe differ diff --git a/libs/win/bin/pclip.exe b/libs/win/bin/pclip.exe index 3c3e8b48..9e8bb32c 100644 Binary files a/libs/win/bin/pclip.exe and b/libs/win/bin/pclip.exe differ diff --git a/libs/win/bin/xmouse.exe b/libs/win/bin/xmouse.exe index 4b7c018e..bd19cbb2 100644 Binary files a/libs/win/bin/xmouse.exe and b/libs/win/bin/xmouse.exe differ diff --git a/libs/win/bugs/environ-api-wierdness.py b/libs/win/bugs/environ-api-wierdness.py new file mode 100644 index 00000000..a7c8b089 --- /dev/null +++ b/libs/win/bugs/environ-api-wierdness.py @@ -0,0 +1,39 @@ +import ctypes +from jaraco.windows import environ +import os + +getenv = ctypes.cdll.msvcrt.getenv +getenv.restype = ctypes.c_char_p +putenv = ctypes.cdll.msvcrt._putenv + + +def do_putenv(*pair): + return putenv("=".join(pair)) + + +def print_environment_variable(key): + for method in (os.environ.get, os.getenv, environ.GetEnvironmentVariable, getenv): + try: + print(repr(method(key))) + except Exception as e: + print(e, end=' ') + print() + + +def do_test(): + key = 'TEST_PYTHON_ENVIRONMENT' + print_environment_variable(key) + methods = ( + os.environ.__setitem__, + os.putenv, + environ.SetEnvironmentVariable, + do_putenv, + ) + for i, method in enumerate(methods): + print('round', i) + method(key, 'value when using method %d' % i) + print_environment_variable(key) + + +if __name__ == '__main__': + do_test() diff --git a/libs/win/bugs/find_target_path.py b/libs/win/bugs/find_target_path.py new file mode 100644 index 00000000..ac75417e --- /dev/null +++ b/libs/win/bugs/find_target_path.py @@ -0,0 +1,69 @@ +import os + + +def findpath(target, start=os.path.curdir): + r""" + Find a path from start to target where target is relative to start. + + >>> orig_wd = os.getcwd() + >>> os.chdir('c:\\windows') # so we know what the working directory is + + >>> findpath('d:\\') + 'd:\\' + + >>> findpath('d:\\', 'c:\\windows') + 'd:\\' + + >>> findpath('\\bar', 'd:\\') + 'd:\\bar' + + >>> findpath('\\bar', 'd:\\foo') # fails with '\\bar' + 'd:\\bar' + + >>> findpath('bar', 'd:\\foo') + 'd:\\foo\\bar' + + >>> findpath('bar\\baz', 'd:\\foo') + 'd:\\foo\\bar\\baz' + + >>> findpath('\\baz', 'd:\\foo\\bar') # fails with '\\baz' + 'd:\\baz' + + Since we're on the C drive, findpath may be allowed to return + relative paths for targets on the same drive. I use abspath to + confirm that the ultimate target is what we expect. + >>> os.path.abspath(findpath('\\bar')) + 'c:\\bar' + + >>> os.path.abspath(findpath('bar')) + 'c:\\windows\\bar' + + >>> findpath('..', 'd:\\foo\\bar') + 'd:\\foo' + + >>> findpath('..\\bar', 'd:\\foo') + 'd:\\bar' + + The parent of the root directory is the root directory.
+ >>> findpath('..', 'd:\\') + 'd:\\' + + restore the original working directory + >>> os.chdir(orig_wd) + """ + return os.path.normpath(os.path.join(start, target)) + + +def main(): + import sys + + if sys.argv[1:]: + print(findpath(*sys.argv[1:])) + else: + import doctest + + doctest.testmod() + + +if __name__ == '__main__': + main() diff --git a/libs/win/bugs/multi_os_libc.py b/libs/win/bugs/multi_os_libc.py new file mode 100644 index 00000000..b1443122 --- /dev/null +++ b/libs/win/bugs/multi_os_libc.py @@ -0,0 +1,21 @@ +from ctypes import CDLL, c_char_p + + +def get_libc(): + libnames = ('msvcrt', 'libc.so.6') + for libname in libnames: + try: + return CDLL(libname) + except WindowsError: + pass + except OSError: + pass + raise RuntimeError("Unable to find a suitable libc (tried %s)" % libnames) + + +getenv = get_libc().getenv +getenv.restype = c_char_p + +# call into your linked module here + +print('new value is', getenv('FOO')) diff --git a/libs/win/bugs/vista-symlink-islink-bug.py b/libs/win/bugs/vista-symlink-islink-bug.py new file mode 100644 index 00000000..a8e8f010 --- /dev/null +++ b/libs/win/bugs/vista-symlink-islink-bug.py @@ -0,0 +1,29 @@ +import os +import sys + +try: + from jaraco.windows.filesystem import symlink +except ImportError: + # a dirty reimplementation of symlink from jaraco.windows + from ctypes import windll + from ctypes.wintypes import LPWSTR, DWORD, BOOLEAN + + CreateSymbolicLink = windll.kernel32.CreateSymbolicLinkW + CreateSymbolicLink.argtypes = (LPWSTR, LPWSTR, DWORD) + CreateSymbolicLink.restype = BOOLEAN + + def symlink(link, target, target_is_directory=False): + """ + An implementation of os.symlink for Windows (Vista and greater) + """ + target_is_directory = target_is_directory or os.path.isdir(target) + CreateSymbolicLink(link, target, target_is_directory) + + +assert sys.platform in ('win32',) +os.makedirs(r'.\foo') +assert os.path.isdir(r'.\foo') + +symlink(r'.\foo_sym', r'.\foo') +assert os.path.isdir(r'.\foo_sym') +assert os.path.islink(r'.\foo_sym') # fails diff --git a/libs/win/bugs/wnetaddconnection2-error-on-64-bit.py b/libs/win/bugs/wnetaddconnection2-error-on-64-bit.py new file mode 100644 index 00000000..e7c7a756 --- /dev/null +++ b/libs/win/bugs/wnetaddconnection2-error-on-64-bit.py @@ -0,0 +1,20 @@ +# reported at http://social.msdn.microsoft.com/Forums/en-US/wsk/thread/f43c2faf-3df3-4f11-9f5e-1a9101753f93 +from win32wnet import WNetAddConnection2, NETRESOURCE + +resource = NETRESOURCE() +resource.lpRemoteName = r'\\aoshi\users' +username = 'jaraco' +res = WNetAddConnection2(resource, UserName=username) +print('first result is', res) +res = WNetAddConnection2(resource, UserName=username) +print('second result is', res) + +""" +Output is: + +first result is None +Traceback (most recent call last): + File ".\wnetaddconnection2-error-on-64-bit.py", line 7, in + res = WNetAddConnection2(resource, UserName=username) +pywintypes.error: (1219, 'WNetAddConnection2', 'Multiple connections to a server or shared resource by the same user, using more than one user name, are not allowed. 
Disconnect all previous connections to the server or shared resource and try again.') +""" diff --git a/libs/win/importlib_metadata/__init__.py b/libs/win/importlib_metadata/__init__.py deleted file mode 100644 index f594c6f7..00000000 --- a/libs/win/importlib_metadata/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from .api import distribution, Distribution, PackageNotFoundError # noqa: F401 -from .api import metadata, entry_points, resolve, version, read_text - -# Import for installation side-effects. -from . import _hooks # noqa: F401 - - -__all__ = [ - 'metadata', - 'entry_points', - 'resolve', - 'version', - 'read_text', - ] - - -__version__ = version(__name__) diff --git a/libs/win/importlib_metadata/_hooks.py b/libs/win/importlib_metadata/_hooks.py deleted file mode 100644 index 1fd62698..00000000 --- a/libs/win/importlib_metadata/_hooks.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import unicode_literals, absolute_import - -import re -import sys -import itertools - -from .api import Distribution -from zipfile import ZipFile - -if sys.version_info >= (3,): # pragma: nocover - from contextlib import suppress - from pathlib import Path -else: # pragma: nocover - from contextlib2 import suppress # noqa - from itertools import imap as map # type: ignore - from pathlib2 import Path - - FileNotFoundError = IOError, OSError - __metaclass__ = type - - -def install(cls): - """Class decorator for installation on sys.meta_path.""" - sys.meta_path.append(cls) - return cls - - -class NullFinder: - @staticmethod - def find_spec(*args, **kwargs): - return None - - # In Python 2, the import system requires finders - # to have a find_module() method, but this usage - # is deprecated in Python 3 in favor of find_spec(). - # For the purposes of this finder (i.e. being present - # on sys.meta_path but having no other import - # system functionality), the two methods are identical. - find_module = find_spec - - -@install -class MetadataPathFinder(NullFinder): - """A degenerate finder for distribution packages on the file system. - - This finder supplies only a find_distribution() method for versions - of Python that do not have a PathFinder find_distribution(). - """ - search_template = r'{name}(-.*)?\.(dist|egg)-info' - - @classmethod - def find_distribution(cls, name): - paths = cls._search_paths(name) - dists = map(PathDistribution, paths) - return next(dists, None) - - @classmethod - def _search_paths(cls, name): - """ - Find metadata directories in sys.path heuristically. - """ - return itertools.chain.from_iterable( - cls._search_path(path, name) - for path in map(Path, sys.path) - ) - - @classmethod - def _search_path(cls, root, name): - if not root.is_dir(): - return () - normalized = name.replace('-', '_') - return ( - item - for item in root.iterdir() - if item.is_dir() - and re.match( - cls.search_template.format(name=normalized), - str(item.name), - flags=re.IGNORECASE, - ) - ) - - -class PathDistribution(Distribution): - def __init__(self, path): - """Construct a distribution from a path to the metadata directory.""" - self._path = path - - def read_text(self, filename): - with suppress(FileNotFoundError): - with self._path.joinpath(filename).open(encoding='utf-8') as fp: - return fp.read() - return None - read_text.__doc__ = Distribution.read_text.__doc__ - - -@install -class WheelMetadataFinder(NullFinder): - """A degenerate finder for distribution packages in wheels. 
- - This finder supplies only a find_distribution() method for versions - of Python that do not have a PathFinder find_distribution(). - """ - search_template = r'{name}(-.*)?\.whl' - - @classmethod - def find_distribution(cls, name): - paths = cls._search_paths(name) - dists = map(WheelDistribution, paths) - return next(dists, None) - - @classmethod - def _search_paths(cls, name): - return ( - item - for item in map(Path, sys.path) - if re.match( - cls.search_template.format(name=name), - str(item.name), - flags=re.IGNORECASE, - ) - ) - - -class WheelDistribution(Distribution): - def __init__(self, archive): - self._archive = archive - name, version = archive.name.split('-')[0:2] - self._dist_info = '{}-{}.dist-info'.format(name, version) - - def read_text(self, filename): - with ZipFile(_path_to_filename(self._archive)) as zf: - with suppress(KeyError): - as_bytes = zf.read('{}/{}'.format(self._dist_info, filename)) - return as_bytes.decode('utf-8') - return None - read_text.__doc__ = Distribution.read_text.__doc__ - - -def _path_to_filename(path): # pragma: nocover - """ - On non-compliant systems, ensure a path-like object is - a string. - """ - try: - return path.__fspath__() - except AttributeError: - return str(path) diff --git a/libs/win/importlib_metadata/api.py b/libs/win/importlib_metadata/api.py deleted file mode 100644 index 41942a39..00000000 --- a/libs/win/importlib_metadata/api.py +++ /dev/null @@ -1,146 +0,0 @@ -import io -import abc -import sys -import email - -from importlib import import_module - -if sys.version_info > (3,): # pragma: nocover - from configparser import ConfigParser -else: # pragma: nocover - from ConfigParser import SafeConfigParser as ConfigParser - -try: - BaseClass = ModuleNotFoundError -except NameError: # pragma: nocover - BaseClass = ImportError # type: ignore - - -__metaclass__ = type - - -class PackageNotFoundError(BaseClass): - """The package was not found.""" - - -class Distribution: - """A Python distribution package.""" - - @abc.abstractmethod - def read_text(self, filename): - """Attempt to load metadata file given by the name. - - :param filename: The name of the file in the distribution info. - :return: The text if found, otherwise None. - """ - - @classmethod - def from_name(cls, name): - """Return the Distribution for the given package name. - - :param name: The name of the distribution package to search for. - :return: The Distribution instance (or subclass thereof) for the named - package, if found. - :raises PackageNotFoundError: When the named package's distribution - metadata cannot be found. - """ - for resolver in cls._discover_resolvers(): - resolved = resolver(name) - if resolved is not None: - return resolved - else: - raise PackageNotFoundError(name) - - @staticmethod - def _discover_resolvers(): - """Search the meta_path for resolvers.""" - declared = ( - getattr(finder, 'find_distribution', None) - for finder in sys.meta_path - ) - return filter(None, declared) - - @property - def metadata(self): - """Return the parsed metadata for this Distribution. - - The returned object will have keys that name the various bits of - metadata. See PEP 566 for details. - """ - return email.message_from_string( - self.read_text('METADATA') or self.read_text('PKG-INFO') - ) - - @property - def version(self): - """Return the 'Version' metadata for the distribution package.""" - return self.metadata['Version'] - - -def distribution(package): - """Get the ``Distribution`` instance for the given package. 
- - :param package: The name of the package as a string. - :return: A ``Distribution`` instance (or subclass thereof). - """ - return Distribution.from_name(package) - - -def metadata(package): - """Get the metadata for the package. - - :param package: The name of the distribution package to query. - :return: An email.Message containing the parsed metadata. - """ - return Distribution.from_name(package).metadata - - -def version(package): - """Get the version string for the named package. - - :param package: The name of the distribution package to query. - :return: The version string for the package as defined in the package's - "Version" metadata key. - """ - return distribution(package).version - - -def entry_points(name): - """Return the entry points for the named distribution package. - - :param name: The name of the distribution package to query. - :return: A ConfigParser instance where the sections and keys are taken - from the entry_points.txt ini-style contents. - """ - as_string = read_text(name, 'entry_points.txt') - # 2018-09-10(barry): Should we provide any options here, or let the caller - # send options to the underlying ConfigParser? For now, YAGNI. - config = ConfigParser() - try: - config.read_string(as_string) - except AttributeError: # pragma: nocover - # Python 2 has no read_string - config.readfp(io.StringIO(as_string)) - return config - - -def resolve(entry_point): - """Resolve an entry point string into the named callable. - - :param entry_point: An entry point string of the form - `path.to.module:callable`. - :return: The actual callable object `path.to.module.callable` - :raises ValueError: When `entry_point` doesn't have the proper format. - """ - path, colon, name = entry_point.rpartition(':') - if colon != ':': - raise ValueError('Not an entry point: {}'.format(entry_point)) - module = import_module(path) - return getattr(module, name) - - -def read_text(package, filename): - """ - Read the text of the file in the distribution info directory. - """ - return distribution(package).read_text(filename) diff --git a/libs/win/importlib_metadata/docs/changelog.rst b/libs/win/importlib_metadata/docs/changelog.rst deleted file mode 100644 index f8f1fedc..00000000 --- a/libs/win/importlib_metadata/docs/changelog.rst +++ /dev/null @@ -1,57 +0,0 @@ -========================= - importlib_metadata NEWS -========================= - -0.7 (2018-11-27) -================ -* Fixed issue where packages with dashes in their names would - not be discovered. Closes #21. -* Distribution lookup is now case-insensitive. Closes #20. -* Wheel distributions can no longer be discovered by their module - name. Like Path distributions, they must be indicated by their - distribution package name. - -0.6 (2018-10-07) -================ -* Removed ``importlib_metadata.distribution`` function. Now - the public interface is primarily the utility functions exposed - in ``importlib_metadata.__all__``. Closes #14. -* Added two new utility functions ``read_text`` and - ``metadata``. - -0.5 (2018-09-18) -================ -* Updated README and removed details about Distribution - class, now considered private. Closes #15. -* Added test suite support for Python 3.4+. -* Fixed SyntaxErrors on Python 3.4 and 3.5. !12 -* Fixed errors on Windows joining Path elements. !15 - -0.4 (2018-09-14) -================ -* Housekeeping. - -0.3 (2018-09-14) -================ -* Added usage documentation. Closes #8 -* Add support for getting metadata from wheels on ``sys.path``. 
Closes #9 - -0.2 (2018-09-11) -================ -* Added ``importlib_metadata.entry_points()``. Closes #1 -* Added ``importlib_metadata.resolve()``. Closes #12 -* Add support for Python 2.7. Closes #4 - -0.1 (2018-09-10) -================ -* Initial release. - - -.. - Local Variables: - mode: change-log-mode - indent-tabs-mode: nil - sentence-end-double-space: t - fill-column: 78 - coding: utf-8 - End: diff --git a/libs/win/importlib_metadata/docs/conf.py b/libs/win/importlib_metadata/docs/conf.py deleted file mode 100644 index c87fc4f2..00000000 --- a/libs/win/importlib_metadata/docs/conf.py +++ /dev/null @@ -1,180 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# -# flake8: noqa -# -# importlib_metadata documentation build configuration file, created by -# sphinx-quickstart on Thu Nov 30 10:21:00 2017. -# -# This file is execfile()d with the current directory set to its -# containing dir. -# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -# import os -# import sys -# sys.path.insert(0, os.path.abspath('.')) - - -# -- General configuration ------------------------------------------------ - -# If your documentation needs a minimal Sphinx version, state it here. -# -# needs_sphinx = '1.0' - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -extensions = ['sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.coverage', - 'sphinx.ext.viewcode'] - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] - -# The suffix(es) of source filenames. -# You can specify multiple suffix as a list of string: -# -# source_suffix = ['.rst', '.md'] -source_suffix = '.rst' - -# The master toctree document. -master_doc = 'index' - -# General information about the project. -project = 'importlib_metadata' -copyright = '2017-2018, Jason Coombs, Barry Warsaw' -author = 'Jason Coombs, Barry Warsaw' - -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# The short X.Y version. -version = '0.1' -# The full version, including alpha/beta/rc tags. -release = '0.1' - -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -# -# This is also used if you do content translation via gettext catalogs. -# Usually you set "language" from the command line for these cases. -language = None - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' - -# If true, `todo` and `todoList` produce output, else they produce nothing. 
-todo_include_todos = False - - -# -- Options for HTML output ---------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = 'default' - -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. -# -# html_theme_options = {} - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] - -# Custom sidebar templates, must be a dictionary that maps document names -# to template names. -# -# This is required for the alabaster theme -# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars -html_sidebars = { - '**': [ - 'relations.html', # needs 'show_related': True theme option to display - 'searchbox.html', - ] -} - - -# -- Options for HTMLHelp output ------------------------------------------ - -# Output file base name for HTML help builder. -htmlhelp_basename = 'importlib_metadatadoc' - - -# -- Options for LaTeX output --------------------------------------------- - -latex_elements = { - # The paper size ('letterpaper' or 'a4paper'). - # - # 'papersize': 'letterpaper', - - # The font size ('10pt', '11pt' or '12pt'). - # - # 'pointsize': '10pt', - - # Additional stuff for the LaTeX preamble. - # - # 'preamble': '', - - # Latex figure (float) alignment - # - # 'figure_align': 'htbp', -} - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, -# author, documentclass [howto, manual, or own class]). -latex_documents = [ - (master_doc, 'importlib_metadata.tex', 'importlib\\_metadata Documentation', - 'Brett Cannon, Barry Warsaw', 'manual'), -] - - -# -- Options for manual page output --------------------------------------- - -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'importlib_metadata', 'importlib_metadata Documentation', - [author], 1) -] - - -# -- Options for Texinfo output ------------------------------------------- - -# Grouping the document tree into Texinfo files. List of tuples -# (source start file, target name, title, author, -# dir menu entry, description, category) -texinfo_documents = [ - (master_doc, 'importlib_metadata', 'importlib_metadata Documentation', - author, 'importlib_metadata', 'One line description of project.', - 'Miscellaneous'), -] - - - - -# Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = { - 'python': ('https://docs.python.org/3', None), - } diff --git a/libs/win/importlib_metadata/docs/index.rst b/libs/win/importlib_metadata/docs/index.rst deleted file mode 100644 index 21da1ed6..00000000 --- a/libs/win/importlib_metadata/docs/index.rst +++ /dev/null @@ -1,53 +0,0 @@ -=============================== - Welcome to importlib_metadata -=============================== - -``importlib_metadata`` is a library which provides an API for accessing an -installed package's `metadata`_, such as its entry points or its top-level -name. This functionality intends to replace most uses of ``pkg_resources`` -`entry point API`_ and `metadata API`_. 
Along with ``importlib.resources`` in -`Python 3.7 and newer`_ (backported as `importlib_resources`_ for older -versions of Python), this can eliminate the need to use the older and less -efficient ``pkg_resources`` package. - -``importlib_metadata`` is a backport of Python 3.8's standard library -`importlib.metadata`_ module for Python 2.7, and 3.4 through 3.7. Users of -Python 3.8 and beyond are encouraged to use the standard library module, and -in fact for these versions, ``importlib_metadata`` just shadows that module. -Developers looking for detailed API descriptions should refer to the Python -3.8 standard library documentation. - -The documentation here includes a general :ref:`usage ` guide. - - -.. toctree:: - :maxdepth: 2 - :caption: Contents: - - using.rst - changelog.rst - - -Project details -=============== - - * Project home: https://gitlab.com/python-devs/importlib_metadata - * Report bugs at: https://gitlab.com/python-devs/importlib_metadata/issues - * Code hosting: https://gitlab.com/python-devs/importlib_metadata.git - * Documentation: http://importlib_metadata.readthedocs.io/ - - -Indices and tables -================== - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` - - -.. _`metadata`: https://www.python.org/dev/peps/pep-0566/ -.. _`entry point API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#entry-points -.. _`metadata API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#metadata-api -.. _`Python 3.7 and newer`: https://docs.python.org/3/library/importlib.html#module-importlib.resources -.. _`importlib_resources`: https://importlib-resources.readthedocs.io/en/latest/index.html -.. _`importlib.metadata`: TBD diff --git a/libs/win/importlib_metadata/docs/using.rst b/libs/win/importlib_metadata/docs/using.rst deleted file mode 100644 index 2af6c822..00000000 --- a/libs/win/importlib_metadata/docs/using.rst +++ /dev/null @@ -1,133 +0,0 @@ -.. _using: - -========================== - Using importlib_metadata -========================== - -``importlib_metadata`` is a library that provides for access to installed -package metadata. Built in part on Python's import system, this library -intends to replace similar functionality in ``pkg_resources`` `entry point -API`_ and `metadata API`_. Along with ``importlib.resources`` in `Python 3.7 -and newer`_ (backported as `importlib_resources`_ for older versions of -Python), this can eliminate the need to use the older and less efficient -``pkg_resources`` package. - -By "installed package" we generally mean a third party package installed into -Python's ``site-packages`` directory via tools such as ``pip``. Specifically, -it means a package with either a discoverable ``dist-info`` or ``egg-info`` -directory, and metadata defined by `PEP 566`_ or its older specifications. -By default, package metadata can live on the file system or in wheels on -``sys.path``. Through an extension mechanism, the metadata can live almost -anywhere. - - -Overview -======== - -Let's say you wanted to get the version string for a package you've installed -using ``pip``. 
We start by creating a virtual environment and installing -something into it:: - - $ python3 -m venv example - $ source example/bin/activate - (example) $ pip install importlib_metadata - (example) $ pip install wheel - -You can get the version string for ``wheel`` by running the following:: - - (example) $ python - >>> from importlib_metadata import version - >>> version('wheel') - '0.31.1' - -You can also get the set of entry points for the ``wheel`` package. Since the -``entry_points.txt`` file is an ``.ini``-style, the ``entry_points()`` -function returns a `ConfigParser instance`_. To get the list of command line -entry points, extract the ``console_scripts`` section:: - - >>> cp = entry_points('wheel') - >>> cp.options('console_scripts') - ['wheel'] - -You can also get the callable that the entry point is mapped to:: - - >>> cp.get('console_scripts', 'wheel') - 'wheel.tool:main' - -Even more conveniently, you can resolve this entry point to the actual -callable:: - - >>> from importlib_metadata import resolve - >>> ep = cp.get('console_scripts', 'wheel') - >>> resolve(ep) - - - -Distributions -============= - -While the above API is the most common and convenient usage, you can get all -of that information from the ``Distribution`` class. A ``Distribution`` is an -abstract object that represents the metadata for a Python package. You can -get the ``Distribution`` instance:: - - >>> from importlib_metadata import distribution - >>> dist = distribution('wheel') - -Thus, an alternative way to get the version number is through the -``Distribution`` instance:: - - >>> dist.version - '0.31.1' - -There are all kinds of additional metadata available on the ``Distribution`` -instance:: - - >>> d.metadata['Requires-Python'] - '>=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*' - >>> d.metadata['License'] - 'MIT' - -The full set of available metadata is not described here. See PEP 566 for -additional details. - - -Extending the search algorithm -============================== - -Because package metadata is not available through ``sys.path`` searches, or -package loaders directly, the metadata for a package is found through import -system `finders`_. To find a distribution package's metadata, -``importlib_metadata`` queries the list of `meta path finders`_ on -`sys.meta_path`_. - -By default ``importlib_metadata`` installs a finder for packages found on the -file system. This finder doesn't actually find any *packages*, but it cany -find the package's metadata. - -The abstract class :py:class:`importlib.abc.MetaPathFinder` defines the -interface expected of finders by Python's import system. -``importlib_metadata`` extends this protocol by looking for an optional -``find_distribution()`` ``@classmethod`` on the finders from -``sys.meta_path``. If the finder has this method, it takes a single argument -which is the name of the distribution package to find. The method returns -``None`` if it cannot find the distribution package, otherwise it returns an -instance of the ``Distribution`` abstract class. - -What this means in practice is that to support finding distribution package -metadata in locations other than the file system, you should derive from -``Distribution`` and implement the ``load_metadata()`` method. This takes a -single argument which is the name of the package whose metadata is being -found. This instance of the ``Distribution`` base abstract class is what your -finder's ``find_distribution()`` method should return. - - -.. 
_`entry point API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#entry-points -.. _`metadata API`: https://setuptools.readthedocs.io/en/latest/pkg_resources.html#metadata-api -.. _`Python 3.7 and newer`: https://docs.python.org/3/library/importlib.html#module-importlib.resources -.. _`importlib_resources`: https://importlib-resources.readthedocs.io/en/latest/index.html -.. _`PEP 566`: https://www.python.org/dev/peps/pep-0566/ -.. _`ConfigParser instance`: https://docs.python.org/3/library/configparser.html#configparser.ConfigParser -.. _`finders`: https://docs.python.org/3/reference/import.html#finders-and-loaders -.. _`meta path finders`: https://docs.python.org/3/glossary.html#term-meta-path-finder -.. _`sys.meta_path`: https://docs.python.org/3/library/sys.html#sys.meta_path diff --git a/libs/win/importlib_metadata/tests/test_api.py b/libs/win/importlib_metadata/tests/test_api.py deleted file mode 100644 index 82c61f51..00000000 --- a/libs/win/importlib_metadata/tests/test_api.py +++ /dev/null @@ -1,44 +0,0 @@ -import re -import unittest - -import importlib_metadata - - -class APITests(unittest.TestCase): - version_pattern = r'\d+\.\d+(\.\d)?' - - def test_retrieves_version_of_self(self): - version = importlib_metadata.version('importlib_metadata') - assert isinstance(version, str) - assert re.match(self.version_pattern, version) - - def test_retrieves_version_of_pip(self): - # Assume pip is installed and retrieve the version of pip. - version = importlib_metadata.version('pip') - assert isinstance(version, str) - assert re.match(self.version_pattern, version) - - def test_for_name_does_not_exist(self): - with self.assertRaises(importlib_metadata.PackageNotFoundError): - importlib_metadata.distribution('does-not-exist') - - def test_for_top_level(self): - distribution = importlib_metadata.distribution('importlib_metadata') - self.assertEqual( - distribution.read_text('top_level.txt').strip(), - 'importlib_metadata') - - def test_entry_points(self): - parser = importlib_metadata.entry_points('pip') - # We should probably not be dependent on a third party package's - # internal API staying stable. - entry_point = parser.get('console_scripts', 'pip') - self.assertEqual(entry_point, 'pip._internal:main') - - def test_metadata_for_this_package(self): - md = importlib_metadata.metadata('importlib_metadata') - assert md['author'] == 'Barry Warsaw' - assert md['LICENSE'] == 'Apache Software License' - assert md['Name'] == 'importlib-metadata' - classifiers = md.get_all('Classifier') - assert 'Topic :: Software Development :: Libraries' in classifiers diff --git a/libs/win/importlib_metadata/tests/test_main.py b/libs/win/importlib_metadata/tests/test_main.py deleted file mode 100644 index 381e4dae..00000000 --- a/libs/win/importlib_metadata/tests/test_main.py +++ /dev/null @@ -1,121 +0,0 @@ -from __future__ import unicode_literals - -import re -import sys -import shutil -import tempfile -import unittest -import importlib -import contextlib -import importlib_metadata - -try: - from contextlib import ExitStack -except ImportError: - from contextlib2 import ExitStack - -try: - import pathlib -except ImportError: - import pathlib2 as pathlib - -from importlib_metadata import _hooks - - -class BasicTests(unittest.TestCase): - version_pattern = r'\d+\.\d+(\.\d)?' - - def test_retrieves_version_of_pip(self): - # Assume pip is installed and retrieve the version of pip. 
- dist = importlib_metadata.Distribution.from_name('pip') - assert isinstance(dist.version, str) - assert re.match(self.version_pattern, dist.version) - - def test_for_name_does_not_exist(self): - with self.assertRaises(importlib_metadata.PackageNotFoundError): - importlib_metadata.Distribution.from_name('does-not-exist') - - def test_new_style_classes(self): - self.assertIsInstance(importlib_metadata.Distribution, type) - self.assertIsInstance(_hooks.MetadataPathFinder, type) - self.assertIsInstance(_hooks.WheelMetadataFinder, type) - self.assertIsInstance(_hooks.WheelDistribution, type) - - -class ImportTests(unittest.TestCase): - def test_import_nonexistent_module(self): - # Ensure that the MetadataPathFinder does not crash an import of a - # non-existant module. - with self.assertRaises(ImportError): - importlib.import_module('does_not_exist') - - def test_resolve(self): - entry_points = importlib_metadata.entry_points('pip') - main = importlib_metadata.resolve( - entry_points.get('console_scripts', 'pip')) - import pip._internal - self.assertEqual(main, pip._internal.main) - - def test_resolve_invalid(self): - self.assertRaises(ValueError, importlib_metadata.resolve, 'bogus.ep') - - -class NameNormalizationTests(unittest.TestCase): - @staticmethod - def pkg_with_dashes(site_dir): - """ - Create minimal metadata for a package with dashes - in the name (and thus underscores in the filename). - """ - metadata_dir = site_dir / 'my_pkg.dist-info' - metadata_dir.mkdir() - metadata = metadata_dir / 'METADATA' - with metadata.open('w') as strm: - strm.write('Version: 1.0\n') - return 'my-pkg' - - @staticmethod - @contextlib.contextmanager - def site_dir(): - tmpdir = tempfile.mkdtemp() - sys.path[:0] = [tmpdir] - try: - yield pathlib.Path(tmpdir) - finally: - sys.path.remove(tmpdir) - shutil.rmtree(tmpdir) - - def setUp(self): - self.fixtures = ExitStack() - self.addCleanup(self.fixtures.close) - self.site_dir = self.fixtures.enter_context(self.site_dir()) - - def test_dashes_in_dist_name_found_as_underscores(self): - """ - For a package with a dash in the name, the dist-info metadata - uses underscores in the name. Ensure the metadata loads. - """ - pkg_name = self.pkg_with_dashes(self.site_dir) - assert importlib_metadata.version(pkg_name) == '1.0' - - @staticmethod - def pkg_with_mixed_case(site_dir): - """ - Create minimal metadata for a package with mixed case - in the name. - """ - metadata_dir = site_dir / 'CherryPy.dist-info' - metadata_dir.mkdir() - metadata = metadata_dir / 'METADATA' - with metadata.open('w') as strm: - strm.write('Version: 1.0\n') - return 'CherryPy' - - def test_dist_name_found_as_any_case(self): - """ - Ensure the metadata loads when queried with any case. 
- """ - pkg_name = self.pkg_with_mixed_case(self.site_dir) - assert importlib_metadata.version(pkg_name) == '1.0' - assert importlib_metadata.version(pkg_name.lower()) == '1.0' - assert importlib_metadata.version(pkg_name.upper()) == '1.0' diff --git a/libs/win/importlib_metadata/tests/test_zip.py b/libs/win/importlib_metadata/tests/test_zip.py deleted file mode 100644 index 7bdf55a9..00000000 --- a/libs/win/importlib_metadata/tests/test_zip.py +++ /dev/null @@ -1,42 +0,0 @@ -import sys -import unittest -import importlib_metadata - -try: - from contextlib import ExitStack -except ImportError: - from contextlib2 import ExitStack - -from importlib_resources import path - - -class BespokeLoader: - archive = 'bespoke' - - -class TestZip(unittest.TestCase): - def setUp(self): - # Find the path to the example.*.whl so we can add it to the front of - # sys.path, where we'll then try to find the metadata thereof. - self.resources = ExitStack() - self.addCleanup(self.resources.close) - wheel = self.resources.enter_context( - path('importlib_metadata.tests.data', - 'example-21.12-py3-none-any.whl')) - sys.path.insert(0, str(wheel)) - self.resources.callback(sys.path.pop, 0) - - def test_zip_version(self): - self.assertEqual(importlib_metadata.version('example'), '21.12') - - def test_zip_entry_points(self): - parser = importlib_metadata.entry_points('example') - entry_point = parser.get('console_scripts', 'example') - self.assertEqual(entry_point, 'example:main') - - def test_missing_metadata(self): - distribution = importlib_metadata.distribution('example') - self.assertIsNone(distribution.read_text('does not exist')) - - def test_case_insensitive(self): - self.assertEqual(importlib_metadata.version('Example'), '21.12') diff --git a/libs/win/importlib_metadata/version.txt b/libs/win/importlib_metadata/version.txt deleted file mode 100644 index eb49d7c7..00000000 --- a/libs/win/importlib_metadata/version.txt +++ /dev/null @@ -1 +0,0 @@ -0.7 diff --git a/libs/win/importlib_resources/__init__.py b/libs/win/importlib_resources/__init__.py new file mode 100644 index 00000000..34e3a995 --- /dev/null +++ b/libs/win/importlib_resources/__init__.py @@ -0,0 +1,36 @@ +"""Read resources contained within a package.""" + +from ._common import ( + as_file, + files, + Package, +) + +from ._legacy import ( + contents, + open_binary, + read_binary, + open_text, + read_text, + is_resource, + path, + Resource, +) + +from .abc import ResourceReader + + +__all__ = [ + 'Package', + 'Resource', + 'ResourceReader', + 'as_file', + 'contents', + 'files', + 'is_resource', + 'open_binary', + 'open_text', + 'path', + 'read_binary', + 'read_text', +] diff --git a/libs/win/importlib_resources/_adapters.py b/libs/win/importlib_resources/_adapters.py new file mode 100644 index 00000000..ea363d86 --- /dev/null +++ b/libs/win/importlib_resources/_adapters.py @@ -0,0 +1,170 @@ +from contextlib import suppress +from io import TextIOWrapper + +from . import abc + + +class SpecLoaderAdapter: + """ + Adapt a package spec to adapt the underlying loader. + """ + + def __init__(self, spec, adapter=lambda spec: spec.loader): + self.spec = spec + self.loader = adapter(spec) + + def __getattr__(self, name): + return getattr(self.spec, name) + + +class TraversableResourcesLoader: + """ + Adapt a loader to provide TraversableResources. 
+ """ + + def __init__(self, spec): + self.spec = spec + + def get_resource_reader(self, name): + return CompatibilityFiles(self.spec)._native() + + +def _io_wrapper(file, mode='r', *args, **kwargs): + if mode == 'r': + return TextIOWrapper(file, *args, **kwargs) + elif mode == 'rb': + return file + raise ValueError( + "Invalid mode value '{}', only 'r' and 'rb' are supported".format(mode) + ) + + +class CompatibilityFiles: + """ + Adapter for an existing or non-existent resource reader + to provide a compatibility .files(). + """ + + class SpecPath(abc.Traversable): + """ + Path tied to a module spec. + Can be read and exposes the resource reader children. + """ + + def __init__(self, spec, reader): + self._spec = spec + self._reader = reader + + def iterdir(self): + if not self._reader: + return iter(()) + return iter( + CompatibilityFiles.ChildPath(self._reader, path) + for path in self._reader.contents() + ) + + def is_file(self): + return False + + is_dir = is_file + + def joinpath(self, other): + if not self._reader: + return CompatibilityFiles.OrphanPath(other) + return CompatibilityFiles.ChildPath(self._reader, other) + + @property + def name(self): + return self._spec.name + + def open(self, mode='r', *args, **kwargs): + return _io_wrapper(self._reader.open_resource(None), mode, *args, **kwargs) + + class ChildPath(abc.Traversable): + """ + Path tied to a resource reader child. + Can be read but doesn't expose any meaningful children. + """ + + def __init__(self, reader, name): + self._reader = reader + self._name = name + + def iterdir(self): + return iter(()) + + def is_file(self): + return self._reader.is_resource(self.name) + + def is_dir(self): + return not self.is_file() + + def joinpath(self, other): + return CompatibilityFiles.OrphanPath(self.name, other) + + @property + def name(self): + return self._name + + def open(self, mode='r', *args, **kwargs): + return _io_wrapper( + self._reader.open_resource(self.name), mode, *args, **kwargs + ) + + class OrphanPath(abc.Traversable): + """ + Orphan path, not tied to a module spec or resource reader. + Can't be read and doesn't expose any meaningful children. + """ + + def __init__(self, *path_parts): + if len(path_parts) < 1: + raise ValueError('Need at least one path part to construct a path') + self._path = path_parts + + def iterdir(self): + return iter(()) + + def is_file(self): + return False + + is_dir = is_file + + def joinpath(self, other): + return CompatibilityFiles.OrphanPath(*self._path, other) + + @property + def name(self): + return self._path[-1] + + def open(self, mode='r', *args, **kwargs): + raise FileNotFoundError("Can't open orphan path") + + def __init__(self, spec): + self.spec = spec + + @property + def _reader(self): + with suppress(AttributeError): + return self.spec.loader.get_resource_reader(self.spec.name) + + def _native(self): + """ + Return the native reader if it supports files(). + """ + reader = self._reader + return reader if hasattr(reader, 'files') else self + + def __getattr__(self, attr): + return getattr(self._reader, attr) + + def files(self): + return CompatibilityFiles.SpecPath(self.spec, self._reader) + + +def wrap_spec(package): + """ + Construct a package spec with traversable compatibility + on the spec/loader/reader. 
+ """ + return SpecLoaderAdapter(package.__spec__, TraversableResourcesLoader) diff --git a/libs/win/importlib_resources/_common.py b/libs/win/importlib_resources/_common.py new file mode 100644 index 00000000..9f19784d --- /dev/null +++ b/libs/win/importlib_resources/_common.py @@ -0,0 +1,207 @@ +import os +import pathlib +import tempfile +import functools +import contextlib +import types +import importlib +import inspect +import warnings +import itertools + +from typing import Union, Optional, cast +from .abc import ResourceReader, Traversable + +from ._compat import wrap_spec + +Package = Union[types.ModuleType, str] +Anchor = Package + + +def package_to_anchor(func): + """ + Replace 'package' parameter as 'anchor' and warn about the change. + + Other errors should fall through. + + >>> files('a', 'b') + Traceback (most recent call last): + TypeError: files() takes from 0 to 1 positional arguments but 2 were given + """ + undefined = object() + + @functools.wraps(func) + def wrapper(anchor=undefined, package=undefined): + if package is not undefined: + if anchor is not undefined: + return func(anchor, package) + warnings.warn( + "First parameter to files is renamed to 'anchor'", + DeprecationWarning, + stacklevel=2, + ) + return func(package) + elif anchor is undefined: + return func() + return func(anchor) + + return wrapper + + +@package_to_anchor +def files(anchor: Optional[Anchor] = None) -> Traversable: + """ + Get a Traversable resource for an anchor. + """ + return from_package(resolve(anchor)) + + +def get_resource_reader(package: types.ModuleType) -> Optional[ResourceReader]: + """ + Return the package's loader if it's a ResourceReader. + """ + # We can't use + # a issubclass() check here because apparently abc.'s __subclasscheck__() + # hook wants to create a weak reference to the object, but + # zipimport.zipimporter does not support weak references, resulting in a + # TypeError. That seems terrible. + spec = package.__spec__ + reader = getattr(spec.loader, 'get_resource_reader', None) # type: ignore + if reader is None: + return None + return reader(spec.name) # type: ignore + + +@functools.singledispatch +def resolve(cand: Optional[Anchor]) -> types.ModuleType: + return cast(types.ModuleType, cand) + + +@resolve.register +def _(cand: str) -> types.ModuleType: + return importlib.import_module(cand) + + +@resolve.register +def _(cand: None) -> types.ModuleType: + return resolve(_infer_caller().f_globals['__name__']) + + +def _infer_caller(): + """ + Walk the stack and find the frame of the first caller not in this module. + """ + + def is_this_file(frame_info): + return frame_info.filename == __file__ + + def is_wrapper(frame_info): + return frame_info.function == 'wrapper' + + not_this_file = itertools.filterfalse(is_this_file, inspect.stack()) + # also exclude 'wrapper' due to singledispatch in the call stack + callers = itertools.filterfalse(is_wrapper, not_this_file) + return next(callers).frame + + +def from_package(package: types.ModuleType): + """ + Return a Traversable object for the given package. + + """ + spec = wrap_spec(package) + reader = spec.loader.get_resource_reader(spec.name) + return reader.files() + + +@contextlib.contextmanager +def _tempfile( + reader, + suffix='', + # gh-93353: Keep a reference to call os.remove() in late Python + # finalization. + *, + _os_remove=os.remove, +): + # Not using tempfile.NamedTemporaryFile as it leads to deeper 'try' + # blocks due to the need to close the temporary file to work on Windows + # properly. 
+ fd, raw_path = tempfile.mkstemp(suffix=suffix) + try: + try: + os.write(fd, reader()) + finally: + os.close(fd) + del reader + yield pathlib.Path(raw_path) + finally: + try: + _os_remove(raw_path) + except FileNotFoundError: + pass + + +def _temp_file(path): + return _tempfile(path.read_bytes, suffix=path.name) + + +def _is_present_dir(path: Traversable) -> bool: + """ + Some Traversables implement ``is_dir()`` to raise an + exception (i.e. ``FileNotFoundError``) when the + directory doesn't exist. This function wraps that call + to always return a boolean and only return True + if there's a dir and it exists. + """ + with contextlib.suppress(FileNotFoundError): + return path.is_dir() + return False + + +@functools.singledispatch +def as_file(path): + """ + Given a Traversable object, return that object as a + path on the local file system in a context manager. + """ + return _temp_dir(path) if _is_present_dir(path) else _temp_file(path) + + +@as_file.register(pathlib.Path) +@contextlib.contextmanager +def _(path): + """ + Degenerate behavior for pathlib.Path objects. + """ + yield path + + +@contextlib.contextmanager +def _temp_path(dir: tempfile.TemporaryDirectory): + """ + Wrap tempfile.TemporaryDirectory to return a pathlib object. + """ + with dir as result: + yield pathlib.Path(result) + + +@contextlib.contextmanager +def _temp_dir(path): + """ + Given a traversable dir, recursively replicate the whole tree + to the file system in a context manager. + """ + assert path.is_dir() + with _temp_path(tempfile.TemporaryDirectory()) as temp_dir: + yield _write_contents(temp_dir, path) + + +def _write_contents(target, source): + child = target.joinpath(source.name) + if source.is_dir(): + child.mkdir() + for item in source.iterdir(): + _write_contents(child, item) + else: + child.open('wb').write(source.read_bytes()) + return child diff --git a/libs/win/importlib_resources/_compat.py b/libs/win/importlib_resources/_compat.py new file mode 100644 index 00000000..8d7ade08 --- /dev/null +++ b/libs/win/importlib_resources/_compat.py @@ -0,0 +1,108 @@ +# flake8: noqa + +import abc +import os +import sys +import pathlib +from contextlib import suppress +from typing import Union + + +if sys.version_info >= (3, 10): + from zipfile import Path as ZipPath # type: ignore +else: + from zipp import Path as ZipPath # type: ignore + + +try: + from typing import runtime_checkable # type: ignore +except ImportError: + + def runtime_checkable(cls): # type: ignore + return cls + + +try: + from typing import Protocol # type: ignore +except ImportError: + Protocol = abc.ABC # type: ignore + + +class TraversableResourcesLoader: + """ + Adapt loaders to provide TraversableResources and other + compatibility. + + Used primarily for Python 3.9 and earlier where the native + loaders do not yet implement TraversableResources. + """ + + def __init__(self, spec): + self.spec = spec + + @property + def path(self): + return self.spec.origin + + def get_resource_reader(self, name): + from . 
import readers, _adapters + + def _zip_reader(spec): + with suppress(AttributeError): + return readers.ZipReader(spec.loader, spec.name) + + def _namespace_reader(spec): + with suppress(AttributeError, ValueError): + return readers.NamespaceReader(spec.submodule_search_locations) + + def _available_reader(spec): + with suppress(AttributeError): + return spec.loader.get_resource_reader(spec.name) + + def _native_reader(spec): + reader = _available_reader(spec) + return reader if hasattr(reader, 'files') else None + + def _file_reader(spec): + try: + path = pathlib.Path(self.path) + except TypeError: + return None + if path.exists(): + return readers.FileReader(self) + + return ( + # native reader if it supplies 'files' + _native_reader(self.spec) + or + # local ZipReader if a zip module + _zip_reader(self.spec) + or + # local NamespaceReader if a namespace module + _namespace_reader(self.spec) + or + # local FileReader + _file_reader(self.spec) + # fallback - adapt the spec ResourceReader to TraversableReader + or _adapters.CompatibilityFiles(self.spec) + ) + + +def wrap_spec(package): + """ + Construct a package spec with traversable compatibility + on the spec/loader/reader. + + Supersedes _adapters.wrap_spec to use TraversableResourcesLoader + from above for older Python compatibility (<3.10). + """ + from . import _adapters + + return _adapters.SpecLoaderAdapter(package.__spec__, TraversableResourcesLoader) + + +if sys.version_info >= (3, 9): + StrPath = Union[str, os.PathLike[str]] +else: + # PathLike is only subscriptable at runtime in 3.9+ + StrPath = Union[str, "os.PathLike[str]"] diff --git a/libs/win/importlib_resources/_itertools.py b/libs/win/importlib_resources/_itertools.py new file mode 100644 index 00000000..cce05582 --- /dev/null +++ b/libs/win/importlib_resources/_itertools.py @@ -0,0 +1,35 @@ +from itertools import filterfalse + +from typing import ( + Callable, + Iterable, + Iterator, + Optional, + Set, + TypeVar, + Union, +) + +# Type and type variable definitions +_T = TypeVar('_T') +_U = TypeVar('_U') + + +def unique_everseen( + iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = None +) -> Iterator[_T]: + "List unique elements, preserving order. Remember all elements ever seen." + # unique_everseen('AAAABBBCCDAABBB') --> A B C D + # unique_everseen('ABBCcAD', str.lower) --> A B C D + seen: Set[Union[_T, _U]] = set() + seen_add = seen.add + if key is None: + for element in filterfalse(seen.__contains__, iterable): + seen_add(element) + yield element + else: + for element in iterable: + k = key(element) + if k not in seen: + seen_add(k) + yield element diff --git a/libs/win/importlib_resources/_legacy.py b/libs/win/importlib_resources/_legacy.py new file mode 100644 index 00000000..b1ea8105 --- /dev/null +++ b/libs/win/importlib_resources/_legacy.py @@ -0,0 +1,120 @@ +import functools +import os +import pathlib +import types +import warnings + +from typing import Union, Iterable, ContextManager, BinaryIO, TextIO, Any + +from . import _common + +Package = Union[types.ModuleType, str] +Resource = str + + +def deprecated(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + warnings.warn( + f"{func.__name__} is deprecated. Use files() instead. " + "Refer to https://importlib-resources.readthedocs.io" + "/en/latest/using.html#migrating-from-legacy for migration advice.", + DeprecationWarning, + stacklevel=2, + ) + return func(*args, **kwargs) + + return wrapper + + +def normalize_path(path: Any) -> str: + """Normalize a path by ensuring it is a string. 
+ + If the resulting string contains path separators, an exception is raised. + """ + str_path = str(path) + parent, file_name = os.path.split(str_path) + if parent: + raise ValueError(f'{path!r} must be only a file name') + return file_name + + +@deprecated +def open_binary(package: Package, resource: Resource) -> BinaryIO: + """Return a file-like object opened for binary reading of the resource.""" + return (_common.files(package) / normalize_path(resource)).open('rb') + + +@deprecated +def read_binary(package: Package, resource: Resource) -> bytes: + """Return the binary contents of the resource.""" + return (_common.files(package) / normalize_path(resource)).read_bytes() + + +@deprecated +def open_text( + package: Package, + resource: Resource, + encoding: str = 'utf-8', + errors: str = 'strict', +) -> TextIO: + """Return a file-like object opened for text reading of the resource.""" + return (_common.files(package) / normalize_path(resource)).open( + 'r', encoding=encoding, errors=errors + ) + + +@deprecated +def read_text( + package: Package, + resource: Resource, + encoding: str = 'utf-8', + errors: str = 'strict', +) -> str: + """Return the decoded string of the resource. + + The decoding-related arguments have the same semantics as those of + bytes.decode(). + """ + with open_text(package, resource, encoding, errors) as fp: + return fp.read() + + +@deprecated +def contents(package: Package) -> Iterable[str]: + """Return an iterable of entries in `package`. + + Note that not all entries are resources. Specifically, directories are + not considered resources. Use `is_resource()` on each entry returned here + to check if it is a resource or not. + """ + return [path.name for path in _common.files(package).iterdir()] + + +@deprecated +def is_resource(package: Package, name: str) -> bool: + """True if `name` is a resource inside `package`. + + Directories are *not* resources. + """ + resource = normalize_path(name) + return any( + traversable.name == resource and traversable.is_file() + for traversable in _common.files(package).iterdir() + ) + + +@deprecated +def path( + package: Package, + resource: Resource, +) -> ContextManager[pathlib.Path]: + """A context manager providing a file path object to the resource. + + If the resource does not already exist on its own on the file system, + a temporary file will be created. If the file was created, the file + will be deleted upon exiting the context manager (no exception is + raised if the file was deleted prior to the context manager + exiting). + """ + return _common.as_file(_common.files(package) / normalize_path(resource)) diff --git a/libs/win/importlib_resources/abc.py b/libs/win/importlib_resources/abc.py new file mode 100644 index 00000000..23b6aeaf --- /dev/null +++ b/libs/win/importlib_resources/abc.py @@ -0,0 +1,170 @@ +import abc +import io +import itertools +import pathlib +from typing import Any, BinaryIO, Iterable, Iterator, NoReturn, Text, Optional + +from ._compat import runtime_checkable, Protocol, StrPath + + +__all__ = ["ResourceReader", "Traversable", "TraversableResources"] + + +class ResourceReader(metaclass=abc.ABCMeta): + """Abstract base class for loaders to provide resource reading support.""" + + @abc.abstractmethod + def open_resource(self, resource: Text) -> BinaryIO: + """Return an opened, file-like object for binary reading. + + The 'resource' argument is expected to represent only a file name. + If the resource cannot be found, FileNotFoundError is raised. 
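The deprecation message above points at ``files()`` as the replacement; a sketch of the migration, with hypothetical package and resource names::

    import importlib_resources as resources

    # Legacy form (now emits DeprecationWarning):
    #     text = resources.read_text('mypkg', 'data.txt')
    # files()-based equivalent:
    text = resources.files('mypkg').joinpath('data.txt').read_text(encoding='utf-8')

    # Legacy path() similarly becomes as_file():
    with resources.as_file(resources.files('mypkg') / 'data.txt') as p:
        print(p)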
+ """ + # This deliberately raises FileNotFoundError instead of + # NotImplementedError so that if this method is accidentally called, + # it'll still do the right thing. + raise FileNotFoundError + + @abc.abstractmethod + def resource_path(self, resource: Text) -> Text: + """Return the file system path to the specified resource. + + The 'resource' argument is expected to represent only a file name. + If the resource does not exist on the file system, raise + FileNotFoundError. + """ + # This deliberately raises FileNotFoundError instead of + # NotImplementedError so that if this method is accidentally called, + # it'll still do the right thing. + raise FileNotFoundError + + @abc.abstractmethod + def is_resource(self, path: Text) -> bool: + """Return True if the named 'path' is a resource. + + Files are resources, directories are not. + """ + raise FileNotFoundError + + @abc.abstractmethod + def contents(self) -> Iterable[str]: + """Return an iterable of entries in `package`.""" + raise FileNotFoundError + + +class TraversalError(Exception): + pass + + +@runtime_checkable +class Traversable(Protocol): + """ + An object with a subset of pathlib.Path methods suitable for + traversing directories and opening files. + + Any exceptions that occur when accessing the backing resource + may propagate unaltered. + """ + + @abc.abstractmethod + def iterdir(self) -> Iterator["Traversable"]: + """ + Yield Traversable objects in self + """ + + def read_bytes(self) -> bytes: + """ + Read contents of self as bytes + """ + with self.open('rb') as strm: + return strm.read() + + def read_text(self, encoding: Optional[str] = None) -> str: + """ + Read contents of self as text + """ + with self.open(encoding=encoding) as strm: + return strm.read() + + @abc.abstractmethod + def is_dir(self) -> bool: + """ + Return True if self is a directory + """ + + @abc.abstractmethod + def is_file(self) -> bool: + """ + Return True if self is a file + """ + + def joinpath(self, *descendants: StrPath) -> "Traversable": + """ + Return Traversable resolved with any descendants applied. + + Each descendant should be a path segment relative to self + and each may contain multiple levels separated by + ``posixpath.sep`` (``/``). + """ + if not descendants: + return self + names = itertools.chain.from_iterable( + path.parts for path in map(pathlib.PurePosixPath, descendants) + ) + target = next(names) + matches = ( + traversable for traversable in self.iterdir() if traversable.name == target + ) + try: + match = next(matches) + except StopIteration: + raise TraversalError( + "Target not found during traversal.", target, list(names) + ) + return match.joinpath(*names) + + def __truediv__(self, child: StrPath) -> "Traversable": + """ + Return Traversable child in self + """ + return self.joinpath(child) + + @abc.abstractmethod + def open(self, mode='r', *args, **kwargs): + """ + mode may be 'r' or 'rb' to open as text or binary. Return a handle + suitable for reading (same as pathlib.Path.open). + + When opening as text, accepts encoding parameters such as those + accepted by io.TextIOWrapper. + """ + + @property + @abc.abstractmethod + def name(self) -> str: + """ + The base name of this object without any parent references. + """ + + +class TraversableResources(ResourceReader): + """ + The required interface for providing traversable + resources. 
+ """ + + @abc.abstractmethod + def files(self) -> "Traversable": + """Return a Traversable object for the loaded package.""" + + def open_resource(self, resource: StrPath) -> io.BufferedReader: + return self.files().joinpath(resource).open('rb') + + def resource_path(self, resource: Any) -> NoReturn: + raise FileNotFoundError(resource) + + def is_resource(self, path: StrPath) -> bool: + return self.files().joinpath(path).is_file() + + def contents(self) -> Iterator[str]: + return (item.name for item in self.files().iterdir()) diff --git a/libs/win/importlib_metadata/docs/__init__.py b/libs/win/importlib_resources/py.typed similarity index 100% rename from libs/win/importlib_metadata/docs/__init__.py rename to libs/win/importlib_resources/py.typed diff --git a/libs/win/importlib_resources/readers.py b/libs/win/importlib_resources/readers.py new file mode 100644 index 00000000..ab34db74 --- /dev/null +++ b/libs/win/importlib_resources/readers.py @@ -0,0 +1,120 @@ +import collections +import pathlib +import operator + +from . import abc + +from ._itertools import unique_everseen +from ._compat import ZipPath + + +def remove_duplicates(items): + return iter(collections.OrderedDict.fromkeys(items)) + + +class FileReader(abc.TraversableResources): + def __init__(self, loader): + self.path = pathlib.Path(loader.path).parent + + def resource_path(self, resource): + """ + Return the file system path to prevent + `resources.path()` from creating a temporary + copy. + """ + return str(self.path.joinpath(resource)) + + def files(self): + return self.path + + +class ZipReader(abc.TraversableResources): + def __init__(self, loader, module): + _, _, name = module.rpartition('.') + self.prefix = loader.prefix.replace('\\', '/') + name + '/' + self.archive = loader.archive + + def open_resource(self, resource): + try: + return super().open_resource(resource) + except KeyError as exc: + raise FileNotFoundError(exc.args[0]) + + def is_resource(self, path): + # workaround for `zipfile.Path.is_file` returning true + # for non-existent paths. + target = self.files().joinpath(path) + return target.is_file() and target.exists() + + def files(self): + return ZipPath(self.archive, self.prefix) + + +class MultiplexedPath(abc.Traversable): + """ + Given a series of Traversable objects, implement a merged + version of the interface across all objects. Useful for + namespace packages which may be multihomed at a single + name. + """ + + def __init__(self, *paths): + self._paths = list(map(pathlib.Path, remove_duplicates(paths))) + if not self._paths: + message = 'MultiplexedPath must contain at least one path' + raise FileNotFoundError(message) + if not all(path.is_dir() for path in self._paths): + raise NotADirectoryError('MultiplexedPath only supports directories') + + def iterdir(self): + files = (file for path in self._paths for file in path.iterdir()) + return unique_everseen(files, key=operator.attrgetter('name')) + + def read_bytes(self): + raise FileNotFoundError(f'{self} is not a file') + + def read_text(self, *args, **kwargs): + raise FileNotFoundError(f'{self} is not a file') + + def is_dir(self): + return True + + def is_file(self): + return False + + def joinpath(self, *descendants): + try: + return super().joinpath(*descendants) + except abc.TraversalError: + # One of the paths did not resolve (a directory does not exist). + # Just return something that will not exist. 
+ return self._paths[0].joinpath(*descendants) + + def open(self, *args, **kwargs): + raise FileNotFoundError(f'{self} is not a file') + + @property + def name(self): + return self._paths[0].name + + def __repr__(self): + paths = ', '.join(f"'{path}'" for path in self._paths) + return f'MultiplexedPath({paths})' + + +class NamespaceReader(abc.TraversableResources): + def __init__(self, namespace_path): + if 'NamespacePath' not in str(namespace_path): + raise ValueError('Invalid path') + self.path = MultiplexedPath(*list(namespace_path)) + + def resource_path(self, resource): + """ + Return the file system path to prevent + `resources.path()` from creating a temporary + copy. + """ + return str(self.path.joinpath(resource)) + + def files(self): + return self.path diff --git a/libs/win/importlib_resources/simple.py b/libs/win/importlib_resources/simple.py new file mode 100644 index 00000000..7770c922 --- /dev/null +++ b/libs/win/importlib_resources/simple.py @@ -0,0 +1,106 @@ +""" +Interface adapters for low-level readers. +""" + +import abc +import io +import itertools +from typing import BinaryIO, List + +from .abc import Traversable, TraversableResources + + +class SimpleReader(abc.ABC): + """ + The minimum, low-level interface required from a resource + provider. + """ + + @property + @abc.abstractmethod + def package(self) -> str: + """ + The name of the package for which this reader loads resources. + """ + + @abc.abstractmethod + def children(self) -> List['SimpleReader']: + """ + Obtain an iterable of SimpleReader for available + child containers (e.g. directories). + """ + + @abc.abstractmethod + def resources(self) -> List[str]: + """ + Obtain available named resources for this virtual package. + """ + + @abc.abstractmethod + def open_binary(self, resource: str) -> BinaryIO: + """ + Obtain a File-like for a named resource. + """ + + @property + def name(self): + return self.package.split('.')[-1] + + +class ResourceContainer(Traversable): + """ + Traversable container for a package's resources via its reader. + """ + + def __init__(self, reader: SimpleReader): + self.reader = reader + + def is_dir(self): + return True + + def is_file(self): + return False + + def iterdir(self): + # resources() is declared as a method on SimpleReader, so call it. + files = (ResourceHandle(self, name) for name in self.reader.resources()) + dirs = map(ResourceContainer, self.reader.children()) + return itertools.chain(files, dirs) + + def open(self, *args, **kwargs): + raise IsADirectoryError() + + +class ResourceHandle(Traversable): + """ + Handle to a named resource in a ResourceReader. + """ + + def __init__(self, parent: ResourceContainer, name: str): + self.parent = parent + # Store privately; Traversable.name is a read-only property. + self._name = name + + @property + def name(self): + return self._name + + def is_file(self): + return True + + def is_dir(self): + return False + + def open(self, mode='r', *args, **kwargs): + stream = self.parent.reader.open_binary(self.name) + if 'b' not in mode: + stream = io.TextIOWrapper(stream, *args, **kwargs) + return stream + + def joinpath(self, name): + raise RuntimeError("Cannot traverse into a resource") + + +class TraversableReader(TraversableResources, SimpleReader): + """ + A TraversableResources based on SimpleReader. Resource providers + may derive from this class to provide the TraversableResources + interface by supplying the SimpleReader interface. 
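A minimal concrete provider built on the interfaces above; ``DictReader`` is hypothetical and serves a single flat, in-memory package::

    import io

    from importlib_resources.simple import TraversableReader

    class DictReader(TraversableReader):
        def __init__(self, package, mapping):
            self._package = package
            self._mapping = mapping  # resource name -> bytes

        @property
        def package(self):
            return self._package

        def children(self):
            return []  # flat package: no child containers

        def resources(self):
            return list(self._mapping)

        def open_binary(self, resource):
            return io.BytesIO(self._mapping[resource])

    reader = DictReader('demo', {'greeting.txt': b'hello'})
    print([item.name for item in reader.files().iterdir()])  # ['greeting.txt']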
+ """ + + def files(self): + return ResourceContainer(self) diff --git a/libs/win/importlib_metadata/tests/__init__.py b/libs/win/importlib_resources/tests/__init__.py similarity index 100% rename from libs/win/importlib_metadata/tests/__init__.py rename to libs/win/importlib_resources/tests/__init__.py diff --git a/libs/win/importlib_resources/tests/_compat.py b/libs/win/importlib_resources/tests/_compat.py new file mode 100644 index 00000000..e7bf06dd --- /dev/null +++ b/libs/win/importlib_resources/tests/_compat.py @@ -0,0 +1,32 @@ +import os + + +try: + from test.support import import_helper # type: ignore +except ImportError: + # Python 3.9 and earlier + class import_helper: # type: ignore + from test.support import ( + modules_setup, + modules_cleanup, + DirsOnSysPath, + CleanImport, + ) + + +try: + from test.support import os_helper # type: ignore +except ImportError: + # Python 3.9 compat + class os_helper: # type:ignore + from test.support import temp_dir + + +try: + # Python 3.10 + from test.support.os_helper import unlink +except ImportError: + from test.support import unlink as _unlink + + def unlink(target): + return _unlink(os.fspath(target)) diff --git a/libs/win/importlib_resources/tests/_path.py b/libs/win/importlib_resources/tests/_path.py new file mode 100644 index 00000000..c630e4d3 --- /dev/null +++ b/libs/win/importlib_resources/tests/_path.py @@ -0,0 +1,50 @@ +import pathlib +import functools + + +#### +# from jaraco.path 3.4 + + +def build(spec, prefix=pathlib.Path()): + """ + Build a set of files/directories, as described by the spec. + + Each key represents a pathname, and the value represents + the content. Content may be a nested directory. + + >>> spec = { + ... 'README.txt': "A README file", + ... "foo": { + ... "__init__.py": "", + ... "bar": { + ... "__init__.py": "", + ... }, + ... "baz.py": "# Some code", + ... } + ... 
} + >>> tmpdir = getfixture('tmpdir') + >>> build(spec, tmpdir) + """ + for name, contents in spec.items(): + create(contents, pathlib.Path(prefix) / name) + + +@functools.singledispatch +def create(content, path): + path.mkdir(exist_ok=True) + build(content, prefix=path) # type: ignore + + +@create.register +def _(content: bytes, path): + path.write_bytes(content) + + +@create.register +def _(content: str, path): + path.write_text(content) + + +# end from jaraco.path +#### diff --git a/libs/win/importlib_metadata/tests/data/__init__.py b/libs/win/importlib_resources/tests/data01/__init__.py similarity index 100% rename from libs/win/importlib_metadata/tests/data/__init__.py rename to libs/win/importlib_resources/tests/data01/__init__.py diff --git a/libs/win/importlib_resources/tests/data01/binary.file b/libs/win/importlib_resources/tests/data01/binary.file new file mode 100644 index 00000000..eaf36c1d Binary files /dev/null and b/libs/win/importlib_resources/tests/data01/binary.file differ diff --git a/libs/win/more_itertools/tests/__init__.py b/libs/win/importlib_resources/tests/data01/subdirectory/__init__.py similarity index 100% rename from libs/win/more_itertools/tests/__init__.py rename to libs/win/importlib_resources/tests/data01/subdirectory/__init__.py diff --git a/libs/win/importlib_resources/tests/data01/subdirectory/binary.file b/libs/win/importlib_resources/tests/data01/subdirectory/binary.file new file mode 100644 index 00000000..eaf36c1d Binary files /dev/null and b/libs/win/importlib_resources/tests/data01/subdirectory/binary.file differ diff --git a/libs/win/importlib_resources/tests/data01/utf-16.file b/libs/win/importlib_resources/tests/data01/utf-16.file new file mode 100644 index 00000000..2cb77229 Binary files /dev/null and b/libs/win/importlib_resources/tests/data01/utf-16.file differ diff --git a/libs/win/importlib_resources/tests/data01/utf-8.file b/libs/win/importlib_resources/tests/data01/utf-8.file new file mode 100644 index 00000000..1c0132ad --- /dev/null +++ b/libs/win/importlib_resources/tests/data01/utf-8.file @@ -0,0 +1 @@ +Hello, UTF-8 world! 
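The same helper can be exercised outside of pytest (the doctest above relies on the ``getfixture`` hook); the spec below is illustrative::

    import pathlib
    import tempfile

    from importlib_resources.tests._path import build

    spec = {
        'mypkg': {
            '__init__.py': '',
            'data.txt': 'hello',
        },
    }
    with tempfile.TemporaryDirectory() as tmp:
        build(spec, pathlib.Path(tmp))
        print((pathlib.Path(tmp) / 'mypkg' / 'data.txt').read_text())  # hello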
diff --git a/libs/win/importlib_resources/tests/data02/__init__.py b/libs/win/importlib_resources/tests/data02/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/win/importlib_resources/tests/data02/one/__init__.py b/libs/win/importlib_resources/tests/data02/one/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/win/importlib_resources/tests/data02/one/resource1.txt b/libs/win/importlib_resources/tests/data02/one/resource1.txt new file mode 100644 index 00000000..61a813e4 --- /dev/null +++ b/libs/win/importlib_resources/tests/data02/one/resource1.txt @@ -0,0 +1 @@ +one resource diff --git a/libs/win/importlib_resources/tests/data02/two/__init__.py b/libs/win/importlib_resources/tests/data02/two/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/win/importlib_resources/tests/data02/two/resource2.txt b/libs/win/importlib_resources/tests/data02/two/resource2.txt new file mode 100644 index 00000000..a80ce46e --- /dev/null +++ b/libs/win/importlib_resources/tests/data02/two/resource2.txt @@ -0,0 +1 @@ +two resource diff --git a/libs/win/importlib_resources/tests/namespacedata01/binary.file b/libs/win/importlib_resources/tests/namespacedata01/binary.file new file mode 100644 index 00000000..eaf36c1d Binary files /dev/null and b/libs/win/importlib_resources/tests/namespacedata01/binary.file differ diff --git a/libs/win/importlib_resources/tests/namespacedata01/utf-16.file b/libs/win/importlib_resources/tests/namespacedata01/utf-16.file new file mode 100644 index 00000000..2cb77229 Binary files /dev/null and b/libs/win/importlib_resources/tests/namespacedata01/utf-16.file differ diff --git a/libs/win/importlib_resources/tests/namespacedata01/utf-8.file b/libs/win/importlib_resources/tests/namespacedata01/utf-8.file new file mode 100644 index 00000000..1c0132ad --- /dev/null +++ b/libs/win/importlib_resources/tests/namespacedata01/utf-8.file @@ -0,0 +1 @@ +Hello, UTF-8 world! diff --git a/libs/win/importlib_resources/tests/test_compatibilty_files.py b/libs/win/importlib_resources/tests/test_compatibilty_files.py new file mode 100644 index 00000000..d92c7c56 --- /dev/null +++ b/libs/win/importlib_resources/tests/test_compatibilty_files.py @@ -0,0 +1,102 @@ +import io +import unittest + +import importlib_resources as resources + +from importlib_resources._adapters import ( + CompatibilityFiles, + wrap_spec, +) + +from . 
import util + + +class CompatibilityFilesTests(unittest.TestCase): + @property + def package(self): + bytes_data = io.BytesIO(b'Hello, world!') + return util.create_package( + file=bytes_data, + path='some_path', + contents=('a', 'b', 'c'), + ) + + @property + def files(self): + return resources.files(self.package) + + def test_spec_path_iter(self): + self.assertEqual( + sorted(path.name for path in self.files.iterdir()), + ['a', 'b', 'c'], + ) + + def test_child_path_iter(self): + self.assertEqual(list((self.files / 'a').iterdir()), []) + + def test_orphan_path_iter(self): + self.assertEqual(list((self.files / 'a' / 'a').iterdir()), []) + self.assertEqual(list((self.files / 'a' / 'a' / 'a').iterdir()), []) + + def test_spec_path_is(self): + self.assertFalse(self.files.is_file()) + self.assertFalse(self.files.is_dir()) + + def test_child_path_is(self): + self.assertTrue((self.files / 'a').is_file()) + self.assertFalse((self.files / 'a').is_dir()) + + def test_orphan_path_is(self): + self.assertFalse((self.files / 'a' / 'a').is_file()) + self.assertFalse((self.files / 'a' / 'a').is_dir()) + self.assertFalse((self.files / 'a' / 'a' / 'a').is_file()) + self.assertFalse((self.files / 'a' / 'a' / 'a').is_dir()) + + def test_spec_path_name(self): + self.assertEqual(self.files.name, 'testingpackage') + + def test_child_path_name(self): + self.assertEqual((self.files / 'a').name, 'a') + + def test_orphan_path_name(self): + self.assertEqual((self.files / 'a' / 'b').name, 'b') + self.assertEqual((self.files / 'a' / 'b' / 'c').name, 'c') + + def test_spec_path_open(self): + self.assertEqual(self.files.read_bytes(), b'Hello, world!') + self.assertEqual(self.files.read_text(), 'Hello, world!') + + def test_child_path_open(self): + self.assertEqual((self.files / 'a').read_bytes(), b'Hello, world!') + self.assertEqual((self.files / 'a').read_text(), 'Hello, world!') + + def test_orphan_path_open(self): + with self.assertRaises(FileNotFoundError): + (self.files / 'a' / 'b').read_bytes() + with self.assertRaises(FileNotFoundError): + (self.files / 'a' / 'b' / 'c').read_bytes() + + def test_open_invalid_mode(self): + with self.assertRaises(ValueError): + self.files.open('0') + + def test_orphan_path_invalid(self): + with self.assertRaises(ValueError): + CompatibilityFiles.OrphanPath() + + def test_wrap_spec(self): + spec = wrap_spec(self.package) + self.assertIsInstance(spec.loader.get_resource_reader(None), CompatibilityFiles) + + +class CompatibilityFilesNoReaderTests(unittest.TestCase): + @property + def package(self): + return util.create_package_from_loader(None) + + @property + def files(self): + return resources.files(self.package) + + def test_spec_path_joinpath(self): + self.assertIsInstance(self.files / 'a', CompatibilityFiles.OrphanPath) diff --git a/libs/win/importlib_resources/tests/test_contents.py b/libs/win/importlib_resources/tests/test_contents.py new file mode 100644 index 00000000..525568e8 --- /dev/null +++ b/libs/win/importlib_resources/tests/test_contents.py @@ -0,0 +1,43 @@ +import unittest +import importlib_resources as resources + +from . import data01 +from . 
import util + + +class ContentsTests: + expected = { + '__init__.py', + 'binary.file', + 'subdirectory', + 'utf-16.file', + 'utf-8.file', + } + + def test_contents(self): + contents = {path.name for path in resources.files(self.data).iterdir()} + assert self.expected <= contents + + +class ContentsDiskTests(ContentsTests, unittest.TestCase): + def setUp(self): + self.data = data01 + + +class ContentsZipTests(ContentsTests, util.ZipSetup, unittest.TestCase): + pass + + +class ContentsNamespaceTests(ContentsTests, unittest.TestCase): + expected = { + # no __init__ because of namespace design + # no subdirectory as incidental difference in fixture + 'binary.file', + 'utf-16.file', + 'utf-8.file', + } + + def setUp(self): + from . import namespacedata01 + + self.data = namespacedata01 diff --git a/libs/win/importlib_resources/tests/test_files.py b/libs/win/importlib_resources/tests/test_files.py new file mode 100644 index 00000000..d258fb5f --- /dev/null +++ b/libs/win/importlib_resources/tests/test_files.py @@ -0,0 +1,112 @@ +import typing +import textwrap +import unittest +import warnings +import importlib +import contextlib + +import importlib_resources as resources +from ..abc import Traversable +from . import data01 +from . import util +from . import _path +from ._compat import os_helper, import_helper + + +@contextlib.contextmanager +def suppress_known_deprecation(): + with warnings.catch_warnings(record=True) as ctx: + warnings.simplefilter('default', category=DeprecationWarning) + yield ctx + + +class FilesTests: + def test_read_bytes(self): + files = resources.files(self.data) + actual = files.joinpath('utf-8.file').read_bytes() + assert actual == b'Hello, UTF-8 world!\n' + + def test_read_text(self): + files = resources.files(self.data) + actual = files.joinpath('utf-8.file').read_text(encoding='utf-8') + assert actual == 'Hello, UTF-8 world!\n' + + @unittest.skipUnless( + hasattr(typing, 'runtime_checkable'), + "Only suitable when typing supports runtime_checkable", + ) + def test_traversable(self): + assert isinstance(resources.files(self.data), Traversable) + + def test_old_parameter(self): + """ + Files used to take a 'package' parameter. Make sure anyone + passing by name is still supported. + """ + with suppress_known_deprecation(): + resources.files(package=self.data) + + +class OpenDiskTests(FilesTests, unittest.TestCase): + def setUp(self): + self.data = data01 + + +class OpenZipTests(FilesTests, util.ZipSetup, unittest.TestCase): + pass + + +class OpenNamespaceTests(FilesTests, unittest.TestCase): + def setUp(self): + from . import namespacedata01 + + self.data = namespacedata01 + + +class SiteDir: + def setUp(self): + self.fixtures = contextlib.ExitStack() + self.addCleanup(self.fixtures.close) + self.site_dir = self.fixtures.enter_context(os_helper.temp_dir()) + self.fixtures.enter_context(import_helper.DirsOnSysPath(self.site_dir)) + self.fixtures.enter_context(import_helper.CleanImport()) + + +class ModulesFilesTests(SiteDir, unittest.TestCase): + def test_module_resources(self): + """ + A module can have resources found adjacent to the module. + """ + spec = { + 'mod.py': '', + 'res.txt': 'resources are the best', + } + _path.build(spec, self.site_dir) + import mod + + actual = resources.files(mod).joinpath('res.txt').read_text() + assert actual == spec['res.txt'] + + +class ImplicitContextFilesTests(SiteDir, unittest.TestCase): + def test_implicit_files(self): + """ + Without any parameter, files() will infer the location as the caller. 
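Anchors may be a module object, a dotted module name, or omitted entirely, as the test below exercises; ``mypkg`` here is hypothetical::

    import importlib_resources as resources

    import mypkg  # hypothetical installed package

    by_module = resources.files(mypkg)    # module object is used as-is
    by_name = resources.files('mypkg')    # dotted name is imported first

    # Inside mypkg itself the anchor can be omitted; files() then walks
    # the stack (_infer_caller) to locate the calling module:
    #     resources.files().joinpath('res.txt').read_text()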
+ """ + spec = { + 'somepkg': { + '__init__.py': textwrap.dedent( + """ + import importlib_resources as res + val = res.files().joinpath('res.txt').read_text() + """ + ), + 'res.txt': 'resources are the best', + }, + } + _path.build(spec, self.site_dir) + assert importlib.import_module('somepkg').val == 'resources are the best' + + +if __name__ == '__main__': + unittest.main() diff --git a/libs/win/importlib_resources/tests/test_open.py b/libs/win/importlib_resources/tests/test_open.py new file mode 100644 index 00000000..87b42c3d --- /dev/null +++ b/libs/win/importlib_resources/tests/test_open.py @@ -0,0 +1,81 @@ +import unittest + +import importlib_resources as resources +from . import data01 +from . import util + + +class CommonBinaryTests(util.CommonTests, unittest.TestCase): + def execute(self, package, path): + target = resources.files(package).joinpath(path) + with target.open('rb'): + pass + + +class CommonTextTests(util.CommonTests, unittest.TestCase): + def execute(self, package, path): + target = resources.files(package).joinpath(path) + with target.open(): + pass + + +class OpenTests: + def test_open_binary(self): + target = resources.files(self.data) / 'binary.file' + with target.open('rb') as fp: + result = fp.read() + self.assertEqual(result, b'\x00\x01\x02\x03') + + def test_open_text_default_encoding(self): + target = resources.files(self.data) / 'utf-8.file' + with target.open() as fp: + result = fp.read() + self.assertEqual(result, 'Hello, UTF-8 world!\n') + + def test_open_text_given_encoding(self): + target = resources.files(self.data) / 'utf-16.file' + with target.open(encoding='utf-16', errors='strict') as fp: + result = fp.read() + self.assertEqual(result, 'Hello, UTF-16 world!\n') + + def test_open_text_with_errors(self): + # Raises UnicodeError without the 'errors' argument. + target = resources.files(self.data) / 'utf-16.file' + with target.open(encoding='utf-8', errors='strict') as fp: + self.assertRaises(UnicodeError, fp.read) + with target.open(encoding='utf-8', errors='ignore') as fp: + result = fp.read() + self.assertEqual( + result, + 'H\x00e\x00l\x00l\x00o\x00,\x00 ' + '\x00U\x00T\x00F\x00-\x001\x006\x00 ' + '\x00w\x00o\x00r\x00l\x00d\x00!\x00\n\x00', + ) + + def test_open_binary_FileNotFoundError(self): + target = resources.files(self.data) / 'does-not-exist' + self.assertRaises(FileNotFoundError, target.open, 'rb') + + def test_open_text_FileNotFoundError(self): + target = resources.files(self.data) / 'does-not-exist' + self.assertRaises(FileNotFoundError, target.open) + + +class OpenDiskTests(OpenTests, unittest.TestCase): + def setUp(self): + self.data = data01 + + +class OpenDiskNamespaceTests(OpenTests, unittest.TestCase): + def setUp(self): + from . import namespacedata01 + + self.data = namespacedata01 + + +class OpenZipTests(OpenTests, util.ZipSetup, unittest.TestCase): + pass + + +if __name__ == '__main__': + unittest.main() diff --git a/libs/win/importlib_resources/tests/test_path.py b/libs/win/importlib_resources/tests/test_path.py new file mode 100644 index 00000000..4f4d3943 --- /dev/null +++ b/libs/win/importlib_resources/tests/test_path.py @@ -0,0 +1,64 @@ +import io +import unittest + +import importlib_resources as resources +from . import data01 +from . import util + + +class CommonTests(util.CommonTests, unittest.TestCase): + def execute(self, package, path): + with resources.as_file(resources.files(package).joinpath(path)): + pass + + +class PathTests: + def test_reading(self): + # Path should be readable. 
+ # Test also implicitly verifies the returned object is a pathlib.Path + # instance. + target = resources.files(self.data) / 'utf-8.file' + with resources.as_file(target) as path: + self.assertTrue(path.name.endswith("utf-8.file"), repr(path)) + # pathlib.Path.read_text() was introduced in Python 3.5. + with path.open('r', encoding='utf-8') as file: + text = file.read() + self.assertEqual('Hello, UTF-8 world!\n', text) + + +class PathDiskTests(PathTests, unittest.TestCase): + data = data01 + + def test_natural_path(self): + """ + Guarantee the internal implementation detail that + file-system-backed resources do not get the tempdir + treatment. + """ + target = resources.files(self.data) / 'utf-8.file' + with resources.as_file(target) as path: + assert 'data' in str(path) + + +class PathMemoryTests(PathTests, unittest.TestCase): + def setUp(self): + file = io.BytesIO(b'Hello, UTF-8 world!\n') + self.addCleanup(file.close) + self.data = util.create_package( + file=file, path=FileNotFoundError("package exists only in memory") + ) + self.data.__spec__.origin = None + self.data.__spec__.has_location = False + + +class PathZipTests(PathTests, util.ZipSetup, unittest.TestCase): + def test_remove_in_context_manager(self): + # It is not an error if the file that was temporarily stashed on the + # file system is removed inside the `with` stanza. + target = resources.files(self.data) / 'utf-8.file' + with resources.as_file(target) as path: + path.unlink() + + +if __name__ == '__main__': + unittest.main() diff --git a/libs/win/importlib_resources/tests/test_read.py b/libs/win/importlib_resources/tests/test_read.py new file mode 100644 index 00000000..41dd6db5 --- /dev/null +++ b/libs/win/importlib_resources/tests/test_read.py @@ -0,0 +1,76 @@ +import unittest +import importlib_resources as resources + +from . import data01 +from . import util +from importlib import import_module + + +class CommonBinaryTests(util.CommonTests, unittest.TestCase): + def execute(self, package, path): + resources.files(package).joinpath(path).read_bytes() + + +class CommonTextTests(util.CommonTests, unittest.TestCase): + def execute(self, package, path): + resources.files(package).joinpath(path).read_text() + + +class ReadTests: + def test_read_bytes(self): + result = resources.files(self.data).joinpath('binary.file').read_bytes() + self.assertEqual(result, b'\0\1\2\3') + + def test_read_text_default_encoding(self): + result = resources.files(self.data).joinpath('utf-8.file').read_text() + self.assertEqual(result, 'Hello, UTF-8 world!\n') + + def test_read_text_given_encoding(self): + result = ( + resources.files(self.data) + .joinpath('utf-16.file') + .read_text(encoding='utf-16') + ) + self.assertEqual(result, 'Hello, UTF-16 world!\n') + + def test_read_text_with_errors(self): + # Raises UnicodeError without the 'errors' argument. 
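When the traversable is backed by a zip archive there is no real path, so ``as_file()`` stashes a temporary copy, as the zip tests here rely on; a sketch with a hypothetical zip-imported ``zipped_pkg``::

    import importlib_resources as resources

    target = resources.files('zipped_pkg') / 'payload.bin'
    with resources.as_file(target) as path:
        data = path.read_bytes()
    # On exit the temporary file is removed; deleting it inside the block
    # is tolerated because _tempfile() swallows FileNotFoundError on cleanup.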
+ target = resources.files(self.data) / 'utf-16.file' + self.assertRaises(UnicodeError, target.read_text, encoding='utf-8') + result = target.read_text(encoding='utf-8', errors='ignore') + self.assertEqual( + result, + 'H\x00e\x00l\x00l\x00o\x00,\x00 ' + '\x00U\x00T\x00F\x00-\x001\x006\x00 ' + '\x00w\x00o\x00r\x00l\x00d\x00!\x00\n\x00', + ) + + +class ReadDiskTests(ReadTests, unittest.TestCase): + data = data01 + + +class ReadZipTests(ReadTests, util.ZipSetup, unittest.TestCase): + def test_read_submodule_resource(self): + submodule = import_module('ziptestdata.subdirectory') + result = resources.files(submodule).joinpath('binary.file').read_bytes() + self.assertEqual(result, b'\0\1\2\3') + + def test_read_submodule_resource_by_name(self): + result = ( + resources.files('ziptestdata.subdirectory') + .joinpath('binary.file') + .read_bytes() + ) + self.assertEqual(result, b'\0\1\2\3') + + +class ReadNamespaceTests(ReadTests, unittest.TestCase): + def setUp(self): + from . import namespacedata01 + + self.data = namespacedata01 + + +if __name__ == '__main__': + unittest.main() diff --git a/libs/win/importlib_resources/tests/test_reader.py b/libs/win/importlib_resources/tests/test_reader.py new file mode 100644 index 00000000..1c8ebeeb --- /dev/null +++ b/libs/win/importlib_resources/tests/test_reader.py @@ -0,0 +1,133 @@ +import os.path +import sys +import pathlib +import unittest + +from importlib import import_module +from importlib_resources.readers import MultiplexedPath, NamespaceReader + + +class MultiplexedPathTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + path = pathlib.Path(__file__).parent / 'namespacedata01' + cls.folder = str(path) + + def test_init_no_paths(self): + with self.assertRaises(FileNotFoundError): + MultiplexedPath() + + def test_init_file(self): + with self.assertRaises(NotADirectoryError): + MultiplexedPath(os.path.join(self.folder, 'binary.file')) + + def test_iterdir(self): + contents = {path.name for path in MultiplexedPath(self.folder).iterdir()} + try: + contents.remove('__pycache__') + except (KeyError, ValueError): + pass + self.assertEqual(contents, {'binary.file', 'utf-16.file', 'utf-8.file'}) + + def test_iterdir_duplicate(self): + data01 = os.path.abspath(os.path.join(__file__, '..', 'data01')) + contents = { + path.name for path in MultiplexedPath(self.folder, data01).iterdir() + } + for remove in ('__pycache__', '__init__.pyc'): + try: + contents.remove(remove) + except (KeyError, ValueError): + pass + self.assertEqual( + contents, + {'__init__.py', 'binary.file', 'subdirectory', 'utf-16.file', 'utf-8.file'}, + ) + + def test_is_dir(self): + self.assertEqual(MultiplexedPath(self.folder).is_dir(), True) + + def test_is_file(self): + self.assertEqual(MultiplexedPath(self.folder).is_file(), False) + + def test_open_file(self): + path = MultiplexedPath(self.folder) + with self.assertRaises(FileNotFoundError): + path.read_bytes() + with self.assertRaises(FileNotFoundError): + path.read_text() + with self.assertRaises(FileNotFoundError): + path.open() + + def test_join_path(self): + prefix = os.path.abspath(os.path.join(__file__, '..')) + data01 = os.path.join(prefix, 'data01') + path = MultiplexedPath(self.folder, data01) + self.assertEqual( + str(path.joinpath('binary.file'))[len(prefix) + 1 :], + os.path.join('namespacedata01', 'binary.file'), + ) + self.assertEqual( + str(path.joinpath('subdirectory'))[len(prefix) + 1 :], + os.path.join('data01', 'subdirectory'), + ) + self.assertEqual( + str(path.joinpath('imaginary'))[len(prefix) + 1 :], 
+ os.path.join('namespacedata01', 'imaginary'), + ) + self.assertEqual(path.joinpath(), path) + + def test_join_path_compound(self): + path = MultiplexedPath(self.folder) + assert not path.joinpath('imaginary/foo.py').exists() + + def test_repr(self): + self.assertEqual( + repr(MultiplexedPath(self.folder)), + f"MultiplexedPath('{self.folder}')", + ) + + def test_name(self): + self.assertEqual( + MultiplexedPath(self.folder).name, + os.path.basename(self.folder), + ) + + +class NamespaceReaderTest(unittest.TestCase): + site_dir = str(pathlib.Path(__file__).parent) + + @classmethod + def setUpClass(cls): + sys.path.append(cls.site_dir) + + @classmethod + def tearDownClass(cls): + sys.path.remove(cls.site_dir) + + def test_init_error(self): + with self.assertRaises(ValueError): + NamespaceReader(['path1', 'path2']) + + def test_resource_path(self): + namespacedata01 = import_module('namespacedata01') + reader = NamespaceReader(namespacedata01.__spec__.submodule_search_locations) + + root = os.path.abspath(os.path.join(__file__, '..', 'namespacedata01')) + self.assertEqual( + reader.resource_path('binary.file'), os.path.join(root, 'binary.file') + ) + self.assertEqual( + reader.resource_path('imaginary'), os.path.join(root, 'imaginary') + ) + + def test_files(self): + namespacedata01 = import_module('namespacedata01') + reader = NamespaceReader(namespacedata01.__spec__.submodule_search_locations) + root = os.path.abspath(os.path.join(__file__, '..', 'namespacedata01')) + self.assertIsInstance(reader.files(), MultiplexedPath) + self.assertEqual(repr(reader.files()), f"MultiplexedPath('{root}')") + + +if __name__ == '__main__': + unittest.main() diff --git a/libs/win/importlib_resources/tests/test_resource.py b/libs/win/importlib_resources/tests/test_resource.py new file mode 100644 index 00000000..82390271 --- /dev/null +++ b/libs/win/importlib_resources/tests/test_resource.py @@ -0,0 +1,260 @@ +import sys +import unittest +import importlib_resources as resources +import uuid +import pathlib + +from . import data01 +from . import zipdata01, zipdata02 +from . import util +from importlib import import_module +from ._compat import import_helper, unlink + + +class ResourceTests: + # Subclasses are expected to set the `data` attribute. 
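A sketch of ``MultiplexedPath`` from the readers module under test; the two directories are hypothetical and must both exist::

    import pathlib

    from importlib_resources.readers import MultiplexedPath

    merged = MultiplexedPath(pathlib.Path('site-a/pkg'), pathlib.Path('site-b/pkg'))
    # iterdir() yields the union of both trees' children; for duplicate
    # names the entry from the earlier path wins (unique_everseen by name).
    print(sorted(child.name for child in merged.iterdir()))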
+
+    def test_is_file_exists(self):
+        target = resources.files(self.data) / 'binary.file'
+        self.assertTrue(target.is_file())
+
+    def test_is_file_missing(self):
+        target = resources.files(self.data) / 'not-a-file'
+        self.assertFalse(target.is_file())
+
+    def test_is_dir(self):
+        target = resources.files(self.data) / 'subdirectory'
+        self.assertFalse(target.is_file())
+        self.assertTrue(target.is_dir())
+
+
+class ResourceDiskTests(ResourceTests, unittest.TestCase):
+    def setUp(self):
+        self.data = data01
+
+
+class ResourceZipTests(ResourceTests, util.ZipSetup, unittest.TestCase):
+    pass
+
+
+def names(traversable):
+    return {item.name for item in traversable.iterdir()}
+
+
+class ResourceLoaderTests(unittest.TestCase):
+    def test_resource_contents(self):
+        package = util.create_package(
+            file=data01, path=data01.__file__, contents=['A', 'B', 'C']
+        )
+        self.assertEqual(names(resources.files(package)), {'A', 'B', 'C'})
+
+    def test_is_file(self):
+        package = util.create_package(
+            file=data01, path=data01.__file__, contents=['A', 'B', 'C', 'D/E', 'D/F']
+        )
+        self.assertTrue(resources.files(package).joinpath('B').is_file())
+
+    def test_is_dir(self):
+        package = util.create_package(
+            file=data01, path=data01.__file__, contents=['A', 'B', 'C', 'D/E', 'D/F']
+        )
+        self.assertTrue(resources.files(package).joinpath('D').is_dir())
+
+    def test_resource_missing(self):
+        package = util.create_package(
+            file=data01, path=data01.__file__, contents=['A', 'B', 'C', 'D/E', 'D/F']
+        )
+        self.assertFalse(resources.files(package).joinpath('Z').is_file())
+
+
+class ResourceCornerCaseTests(unittest.TestCase):
+    def test_package_has_no_reader_fallback(self):
+        # Test odd ball packages which:
+        # 1. Do not have a ResourceReader as a loader
+        # 2. Are not on the file system
+        # 3. Are not in a zip file
+        module = util.create_package(
+            file=data01, path=data01.__file__, contents=['A', 'B', 'C']
+        )
+        # Give the module a dummy loader.
+        module.__loader__ = object()
+        # Give the module a dummy origin.
+        module.__file__ = '/path/which/shall/not/be/named'
+        module.__spec__.loader = module.__loader__
+        module.__spec__.origin = module.__file__
+        self.assertFalse(resources.files(module).joinpath('A').is_file())
+
+
+class ResourceFromZipsTest01(util.ZipSetupBase, unittest.TestCase):
+    ZIP_MODULE = zipdata01  # type: ignore
+
+    def test_is_submodule_resource(self):
+        submodule = import_module('ziptestdata.subdirectory')
+        self.assertTrue(resources.files(submodule).joinpath('binary.file').is_file())
+
+    def test_read_submodule_resource_by_name(self):
+        self.assertTrue(
+            resources.files('ziptestdata.subdirectory')
+            .joinpath('binary.file')
+            .is_file()
+        )
+
+    def test_submodule_contents(self):
+        submodule = import_module('ziptestdata.subdirectory')
+        self.assertEqual(
+            names(resources.files(submodule)), {'__init__.py', 'binary.file'}
+        )
+
+    def test_submodule_contents_by_name(self):
+        self.assertEqual(
+            names(resources.files('ziptestdata.subdirectory')),
+            {'__init__.py', 'binary.file'},
+        )
+
+    def test_as_file_directory(self):
+        with resources.as_file(resources.files('ziptestdata')) as data:
+            assert data.name == 'ziptestdata'
+            assert data.is_dir()
+            assert data.joinpath('subdirectory').is_dir()
+            assert len(list(data.iterdir()))
+            assert not data.parent.exists()
+
+
+class ResourceFromZipsTest02(util.ZipSetupBase, unittest.TestCase):
+    ZIP_MODULE = zipdata02  # type: ignore
+
+    def test_unrelated_contents(self):
+        """
+        Test that a zip with two unrelated subpackages returns
+        distinct resources.
Ref python/importlib_resources#44. + """ + self.assertEqual( + names(resources.files('ziptestdata.one')), + {'__init__.py', 'resource1.txt'}, + ) + self.assertEqual( + names(resources.files('ziptestdata.two')), + {'__init__.py', 'resource2.txt'}, + ) + + +class DeletingZipsTest(unittest.TestCase): + """Having accessed resources in a zip file should not keep an open + reference to the zip. + """ + + ZIP_MODULE = zipdata01 + + def setUp(self): + modules = import_helper.modules_setup() + self.addCleanup(import_helper.modules_cleanup, *modules) + + data_path = pathlib.Path(self.ZIP_MODULE.__file__) + data_dir = data_path.parent + self.source_zip_path = data_dir / 'ziptestdata.zip' + self.zip_path = pathlib.Path(f'{uuid.uuid4()}.zip').absolute() + self.zip_path.write_bytes(self.source_zip_path.read_bytes()) + sys.path.append(str(self.zip_path)) + self.data = import_module('ziptestdata') + + def tearDown(self): + try: + sys.path.remove(str(self.zip_path)) + except ValueError: + pass + + try: + del sys.path_importer_cache[str(self.zip_path)] + del sys.modules[self.data.__name__] + except KeyError: + pass + + try: + unlink(self.zip_path) + except OSError: + # If the test fails, this will probably fail too + pass + + def test_iterdir_does_not_keep_open(self): + c = [item.name for item in resources.files('ziptestdata').iterdir()] + self.zip_path.unlink() + del c + + def test_is_file_does_not_keep_open(self): + c = resources.files('ziptestdata').joinpath('binary.file').is_file() + self.zip_path.unlink() + del c + + def test_is_file_failure_does_not_keep_open(self): + c = resources.files('ziptestdata').joinpath('not-present').is_file() + self.zip_path.unlink() + del c + + @unittest.skip("Desired but not supported.") + def test_as_file_does_not_keep_open(self): # pragma: no cover + c = resources.as_file(resources.files('ziptestdata') / 'binary.file') + self.zip_path.unlink() + del c + + def test_entered_path_does_not_keep_open(self): + # This is what certifi does on import to make its bundle + # available for the process duration. 
+        c = resources.as_file(
+            resources.files('ziptestdata') / 'binary.file'
+        ).__enter__()
+        self.zip_path.unlink()
+        del c
+
+    def test_read_binary_does_not_keep_open(self):
+        c = resources.files('ziptestdata').joinpath('binary.file').read_bytes()
+        self.zip_path.unlink()
+        del c
+
+    def test_read_text_does_not_keep_open(self):
+        c = resources.files('ziptestdata').joinpath('utf-8.file').read_text()
+        self.zip_path.unlink()
+        del c
+
+
+class ResourceFromNamespaceTest01(unittest.TestCase):
+    site_dir = str(pathlib.Path(__file__).parent)
+
+    @classmethod
+    def setUpClass(cls):
+        sys.path.append(cls.site_dir)
+
+    @classmethod
+    def tearDownClass(cls):
+        sys.path.remove(cls.site_dir)
+
+    def test_is_submodule_resource(self):
+        self.assertTrue(
+            resources.files(import_module('namespacedata01'))
+            .joinpath('binary.file')
+            .is_file()
+        )
+
+    def test_read_submodule_resource_by_name(self):
+        self.assertTrue(
+            resources.files('namespacedata01').joinpath('binary.file').is_file()
+        )
+
+    def test_submodule_contents(self):
+        contents = names(resources.files(import_module('namespacedata01')))
+        try:
+            contents.remove('__pycache__')
+        except KeyError:
+            pass
+        self.assertEqual(contents, {'binary.file', 'utf-8.file', 'utf-16.file'})
+
+    def test_submodule_contents_by_name(self):
+        contents = names(resources.files('namespacedata01'))
+        try:
+            contents.remove('__pycache__')
+        except KeyError:
+            pass
+        self.assertEqual(contents, {'binary.file', 'utf-8.file', 'utf-16.file'})
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/libs/win/importlib_resources/tests/update-zips.py b/libs/win/importlib_resources/tests/update-zips.py
new file mode 100644
index 00000000..231334aa
--- /dev/null
+++ b/libs/win/importlib_resources/tests/update-zips.py
@@ -0,0 +1,53 @@
+"""
+Generate the zip test data files.
+
+Run to build the tests/zipdataNN/ziptestdata.zip files from
+files in tests/dataNN.
+
+Replaces the file with the working copy, but does not commit anything
+to the source repo.
+"""
+
+import contextlib
+import os
+import pathlib
+import zipfile
+
+
+def main():
+    """
+    >>> from unittest import mock
+    >>> monkeypatch = getfixture('monkeypatch')
+    >>> monkeypatch.setattr(zipfile, 'ZipFile', mock.MagicMock())
+    >>> print(); main()  # print workaround for bpo-32509
+
+    ...data01... -> ziptestdata/...
+    ...
+    ...data02... -> ziptestdata/...
+    ...
+    """
+    suffixes = '01', '02'
+    tuple(map(generate, suffixes))
+
+
+def generate(suffix):
+    root = pathlib.Path(__file__).parent.relative_to(os.getcwd())
+    zfpath = root / f'zipdata{suffix}/ziptestdata.zip'
+    with zipfile.ZipFile(zfpath, 'w') as zf:
+        for src, rel in walk(root / f'data{suffix}'):
+            dst = 'ziptestdata' / pathlib.PurePosixPath(rel.as_posix())
+            print(src, '->', dst)
+            zf.write(src, dst)
+
+
+def walk(datapath):
+    for dirpath, dirnames, filenames in os.walk(datapath):
+        with contextlib.suppress(ValueError):
+            dirnames.remove('__pycache__')
+        for filename in filenames:
+            res = pathlib.Path(dirpath) / filename
+            rel = res.relative_to(datapath)
+            yield res, rel
+
+
+__name__ == '__main__' and main()
diff --git a/libs/win/importlib_resources/tests/util.py b/libs/win/importlib_resources/tests/util.py
new file mode 100644
index 00000000..b596c0ce
--- /dev/null
+++ b/libs/win/importlib_resources/tests/util.py
@@ -0,0 +1,167 @@
+import abc
+import importlib
+import io
+import sys
+import types
+import pathlib
+
+from . import data01
+from .
import zipdata01 +from ..abc import ResourceReader +from ._compat import import_helper + + +from importlib.machinery import ModuleSpec + + +class Reader(ResourceReader): + def __init__(self, **kwargs): + vars(self).update(kwargs) + + def get_resource_reader(self, package): + return self + + def open_resource(self, path): + self._path = path + if isinstance(self.file, Exception): + raise self.file + return self.file + + def resource_path(self, path_): + self._path = path_ + if isinstance(self.path, Exception): + raise self.path + return self.path + + def is_resource(self, path_): + self._path = path_ + if isinstance(self.path, Exception): + raise self.path + + def part(entry): + return entry.split('/') + + return any( + len(parts) == 1 and parts[0] == path_ for parts in map(part, self._contents) + ) + + def contents(self): + if isinstance(self.path, Exception): + raise self.path + yield from self._contents + + +def create_package_from_loader(loader, is_package=True): + name = 'testingpackage' + module = types.ModuleType(name) + spec = ModuleSpec(name, loader, origin='does-not-exist', is_package=is_package) + module.__spec__ = spec + module.__loader__ = loader + return module + + +def create_package(file=None, path=None, is_package=True, contents=()): + return create_package_from_loader( + Reader(file=file, path=path, _contents=contents), + is_package, + ) + + +class CommonTests(metaclass=abc.ABCMeta): + """ + Tests shared by test_open, test_path, and test_read. + """ + + @abc.abstractmethod + def execute(self, package, path): + """ + Call the pertinent legacy API function (e.g. open_text, path) + on package and path. + """ + + def test_package_name(self): + # Passing in the package name should succeed. + self.execute(data01.__name__, 'utf-8.file') + + def test_package_object(self): + # Passing in the package itself should succeed. + self.execute(data01, 'utf-8.file') + + def test_string_path(self): + # Passing in a string for the path should succeed. + path = 'utf-8.file' + self.execute(data01, path) + + def test_pathlib_path(self): + # Passing in a pathlib.PurePath object for the path should succeed. + path = pathlib.PurePath('utf-8.file') + self.execute(data01, path) + + def test_importing_module_as_side_effect(self): + # The anchor package can already be imported. + del sys.modules[data01.__name__] + self.execute(data01.__name__, 'utf-8.file') + + def test_missing_path(self): + # Attempting to open or read or request the path for a + # non-existent path should succeed if open_resource + # can return a viable data stream. + bytes_data = io.BytesIO(b'Hello, world!') + package = create_package(file=bytes_data, path=FileNotFoundError()) + self.execute(package, 'utf-8.file') + self.assertEqual(package.__loader__._path, 'utf-8.file') + + def test_extant_path(self): + # Attempting to open or read or request the path when the + # path does exist should still succeed. Does not assert + # anything about the result. 
+ bytes_data = io.BytesIO(b'Hello, world!') + # any path that exists + path = __file__ + package = create_package(file=bytes_data, path=path) + self.execute(package, 'utf-8.file') + self.assertEqual(package.__loader__._path, 'utf-8.file') + + def test_useless_loader(self): + package = create_package(file=FileNotFoundError(), path=FileNotFoundError()) + with self.assertRaises(FileNotFoundError): + self.execute(package, 'utf-8.file') + + +class ZipSetupBase: + ZIP_MODULE = None + + @classmethod + def setUpClass(cls): + data_path = pathlib.Path(cls.ZIP_MODULE.__file__) + data_dir = data_path.parent + cls._zip_path = str(data_dir / 'ziptestdata.zip') + sys.path.append(cls._zip_path) + cls.data = importlib.import_module('ziptestdata') + + @classmethod + def tearDownClass(cls): + try: + sys.path.remove(cls._zip_path) + except ValueError: + pass + + try: + del sys.path_importer_cache[cls._zip_path] + del sys.modules[cls.data.__name__] + except KeyError: + pass + + try: + del cls.data + del cls._zip_path + except AttributeError: + pass + + def setUp(self): + modules = import_helper.modules_setup() + self.addCleanup(import_helper.modules_cleanup, *modules) + + +class ZipSetup(ZipSetupBase): + ZIP_MODULE = zipdata01 # type: ignore diff --git a/libs/win/importlib_resources/tests/zipdata01/__init__.py b/libs/win/importlib_resources/tests/zipdata01/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/win/importlib_resources/tests/zipdata01/ziptestdata.zip b/libs/win/importlib_resources/tests/zipdata01/ziptestdata.zip new file mode 100644 index 00000000..9a3bb073 Binary files /dev/null and b/libs/win/importlib_resources/tests/zipdata01/ziptestdata.zip differ diff --git a/libs/win/importlib_resources/tests/zipdata02/__init__.py b/libs/win/importlib_resources/tests/zipdata02/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/libs/win/importlib_resources/tests/zipdata02/ziptestdata.zip b/libs/win/importlib_resources/tests/zipdata02/ziptestdata.zip new file mode 100644 index 00000000..d63ff512 Binary files /dev/null and b/libs/win/importlib_resources/tests/zipdata02/ziptestdata.zip differ diff --git a/libs/win/incubator/replace-file.py b/libs/win/incubator/replace-file.py new file mode 100644 index 00000000..a4d24bde --- /dev/null +++ b/libs/win/incubator/replace-file.py @@ -0,0 +1,10 @@ +from jaraco.windows.api.filesystem import ReplaceFile + +open('orig-file', 'w').write('some content') +open('replacing-file', 'w').write('new content') +ReplaceFile('orig-file', 'replacing-file', 'orig-backup', 0, 0, 0) +assert open('orig-file').read() == 'new content' +assert open('orig-backup').read() == 'some content' +import os + +assert not os.path.exists('replacing-file') diff --git a/libs/win/incubator/trace-symlink.py b/libs/win/incubator/trace-symlink.py new file mode 100644 index 00000000..1e7716db --- /dev/null +++ b/libs/win/incubator/trace-symlink.py @@ -0,0 +1,22 @@ +from jaraco.windows.filesystem import trace_symlink_target + +from optparse import OptionParser + + +def get_args(): + parser = OptionParser() + options, args = parser.parse_args() + try: + options.filename = args.pop(0) + except IndexError: + parser.error('filename required') + return options + + +def main(): + options = get_args() + print(trace_symlink_target(options.filename)) + + +if __name__ == '__main__': + main() diff --git a/libs/win/inflect/__init__.py b/libs/win/inflect/__init__.py new file mode 100644 index 00000000..78d2e33c --- /dev/null +++ b/libs/win/inflect/__init__.py @@ -0,0 +1,3991 @@ 
+""" +inflect: english language inflection + - correctly generate plurals, ordinals, indefinite articles + - convert numbers to words + +Copyright (C) 2010 Paul Dyson + +Based upon the Perl module +`Lingua::EN::Inflect `_. + +methods: + classical inflect + plural plural_noun plural_verb plural_adj singular_noun no num a an + compare compare_nouns compare_verbs compare_adjs + present_participle + ordinal + number_to_words + join + defnoun defverb defadj defa defan + +INFLECTIONS: + classical inflect + plural plural_noun plural_verb plural_adj singular_noun compare + no num a an present_participle + +PLURALS: + classical inflect + plural plural_noun plural_verb plural_adj singular_noun no num + compare compare_nouns compare_verbs compare_adjs + +COMPARISONS: + classical + compare compare_nouns compare_verbs compare_adjs + +ARTICLES: + classical inflect num a an + +NUMERICAL: + ordinal number_to_words + +USER_DEFINED: + defnoun defverb defadj defa defan + +Exceptions: + UnknownClassicalModeError + BadNumValueError + BadChunkingOptionError + NumOutOfRangeError + BadUserDefinedPatternError + BadRcFileError + BadGenderError + +""" + +import ast +import re +import functools +import collections +import contextlib +from typing import ( + Dict, + Union, + Optional, + Iterable, + List, + Match, + Tuple, + Callable, + Sequence, + cast, + Any, +) +from numbers import Number + + +from pydantic import Field, validate_arguments +from pydantic.typing import Annotated + + +class UnknownClassicalModeError(Exception): + pass + + +class BadNumValueError(Exception): + pass + + +class BadChunkingOptionError(Exception): + pass + + +class NumOutOfRangeError(Exception): + pass + + +class BadUserDefinedPatternError(Exception): + pass + + +class BadRcFileError(Exception): + pass + + +class BadGenderError(Exception): + pass + + +STDOUT_ON = False + + +def print3(txt: str) -> None: + if STDOUT_ON: + print(txt) + + +def enclose(s: str) -> str: + return f"(?:{s})" + + +def joinstem(cutpoint: Optional[int] = 0, words: Optional[Iterable[str]] = None) -> str: + """ + Join stem of each word in words into a string for regex. + + Each word is truncated at cutpoint. + + Cutpoint is usually negative indicating the number of letters to remove + from the end of each word. + + >>> joinstem(-2, ["ephemeris", "iris", ".*itis"]) + '(?:ephemer|ir|.*it)' + + >>> joinstem(None, ["ephemeris"]) + '(?:ephemeris)' + + >>> joinstem(5, None) + '(?:)' + """ + return enclose("|".join(w[:cutpoint] for w in words or [])) + + +def bysize(words: Iterable[str]) -> Dict[int, set]: + """ + From a list of words, return a dict of sets sorted by word length. 
+ + >>> words = ['ant', 'cat', 'dog', 'pig', 'frog', 'goat', 'horse', 'elephant'] + >>> ret = bysize(words) + >>> sorted(ret[3]) + ['ant', 'cat', 'dog', 'pig'] + >>> ret[5] + {'horse'} + """ + res: Dict[int, set] = collections.defaultdict(set) + for w in words: + res[len(w)].add(w) + return res + + +def make_pl_si_lists( + lst: Iterable[str], + plending: str, + siendingsize: Optional[int], + dojoinstem: bool = True, +): + """ + given a list of singular words: lst + + an ending to append to make the plural: plending + + the number of characters to remove from the singular + before appending plending: siendingsize + + a flag whether to create a joinstem: dojoinstem + + return: + a list of pluralised words: si_list (called si because this is what you need to + look for to make the singular) + + the pluralised words as a dict of sets sorted by word length: si_bysize + the singular words as a dict of sets sorted by word length: pl_bysize + if dojoinstem is True: a regular expression that matches any of the stems: stem + """ + if siendingsize is not None: + siendingsize = -siendingsize + si_list = [w[:siendingsize] + plending for w in lst] + pl_bysize = bysize(lst) + si_bysize = bysize(si_list) + if dojoinstem: + stem = joinstem(siendingsize, lst) + return si_list, si_bysize, pl_bysize, stem + else: + return si_list, si_bysize, pl_bysize + + +# 1. PLURALS + +pl_sb_irregular_s = { + "corpus": "corpuses|corpora", + "opus": "opuses|opera", + "genus": "genera", + "mythos": "mythoi", + "penis": "penises|penes", + "testis": "testes", + "atlas": "atlases|atlantes", + "yes": "yeses", +} + +pl_sb_irregular = { + "child": "children", + "chili": "chilis|chilies", + "brother": "brothers|brethren", + "infinity": "infinities|infinity", + "loaf": "loaves", + "lore": "lores|lore", + "hoof": "hoofs|hooves", + "beef": "beefs|beeves", + "thief": "thiefs|thieves", + "money": "monies", + "mongoose": "mongooses", + "ox": "oxen", + "cow": "cows|kine", + "graffito": "graffiti", + "octopus": "octopuses|octopodes", + "genie": "genies|genii", + "ganglion": "ganglions|ganglia", + "trilby": "trilbys", + "turf": "turfs|turves", + "numen": "numina", + "atman": "atmas", + "occiput": "occiputs|occipita", + "sabretooth": "sabretooths", + "sabertooth": "sabertooths", + "lowlife": "lowlifes", + "flatfoot": "flatfoots", + "tenderfoot": "tenderfoots", + "romany": "romanies", + "jerry": "jerries", + "mary": "maries", + "talouse": "talouses", + "rom": "roma", + "carmen": "carmina", +} + +pl_sb_irregular.update(pl_sb_irregular_s) +# pl_sb_irregular_keys = enclose('|'.join(pl_sb_irregular.keys())) + +pl_sb_irregular_caps = { + "Romany": "Romanies", + "Jerry": "Jerrys", + "Mary": "Marys", + "Rom": "Roma", +} + +pl_sb_irregular_compound = {"prima donna": "prima donnas|prime donne"} + +si_sb_irregular = {v: k for (k, v) in pl_sb_irregular.items()} +for k in list(si_sb_irregular): + if "|" in k: + k1, k2 = k.split("|") + si_sb_irregular[k1] = si_sb_irregular[k2] = si_sb_irregular[k] + del si_sb_irregular[k] +si_sb_irregular_caps = {v: k for (k, v) in pl_sb_irregular_caps.items()} +si_sb_irregular_compound = {v: k for (k, v) in pl_sb_irregular_compound.items()} +for k in list(si_sb_irregular_compound): + if "|" in k: + k1, k2 = k.split("|") + si_sb_irregular_compound[k1] = si_sb_irregular_compound[ + k2 + ] = si_sb_irregular_compound[k] + del si_sb_irregular_compound[k] + +# si_sb_irregular_keys = enclose('|'.join(si_sb_irregular.keys())) + +# Z's that don't double + +pl_sb_z_zes_list = ("quartz", "topaz") +pl_sb_z_zes_bysize = 
bysize(pl_sb_z_zes_list) + +pl_sb_ze_zes_list = ("snooze",) +pl_sb_ze_zes_bysize = bysize(pl_sb_ze_zes_list) + + +# CLASSICAL "..is" -> "..ides" + +pl_sb_C_is_ides_complete = [ + # GENERAL WORDS... + "ephemeris", + "iris", + "clitoris", + "chrysalis", + "epididymis", +] + +pl_sb_C_is_ides_endings = [ + # INFLAMATIONS... + "itis" +] + +pl_sb_C_is_ides = joinstem( + -2, pl_sb_C_is_ides_complete + [f".*{w}" for w in pl_sb_C_is_ides_endings] +) + +pl_sb_C_is_ides_list = pl_sb_C_is_ides_complete + pl_sb_C_is_ides_endings + +( + si_sb_C_is_ides_list, + si_sb_C_is_ides_bysize, + pl_sb_C_is_ides_bysize, +) = make_pl_si_lists(pl_sb_C_is_ides_list, "ides", 2, dojoinstem=False) + + +# CLASSICAL "..a" -> "..ata" + +pl_sb_C_a_ata_list = ( + "anathema", + "bema", + "carcinoma", + "charisma", + "diploma", + "dogma", + "drama", + "edema", + "enema", + "enigma", + "lemma", + "lymphoma", + "magma", + "melisma", + "miasma", + "oedema", + "sarcoma", + "schema", + "soma", + "stigma", + "stoma", + "trauma", + "gumma", + "pragma", +) + +( + si_sb_C_a_ata_list, + si_sb_C_a_ata_bysize, + pl_sb_C_a_ata_bysize, + pl_sb_C_a_ata, +) = make_pl_si_lists(pl_sb_C_a_ata_list, "ata", 1) + +# UNCONDITIONAL "..a" -> "..ae" + +pl_sb_U_a_ae_list = ( + "alumna", + "alga", + "vertebra", + "persona", + "vita", +) +( + si_sb_U_a_ae_list, + si_sb_U_a_ae_bysize, + pl_sb_U_a_ae_bysize, + pl_sb_U_a_ae, +) = make_pl_si_lists(pl_sb_U_a_ae_list, "e", None) + +# CLASSICAL "..a" -> "..ae" + +pl_sb_C_a_ae_list = ( + "amoeba", + "antenna", + "formula", + "hyperbola", + "medusa", + "nebula", + "parabola", + "abscissa", + "hydra", + "nova", + "lacuna", + "aurora", + "umbra", + "flora", + "fauna", +) +( + si_sb_C_a_ae_list, + si_sb_C_a_ae_bysize, + pl_sb_C_a_ae_bysize, + pl_sb_C_a_ae, +) = make_pl_si_lists(pl_sb_C_a_ae_list, "e", None) + + +# CLASSICAL "..en" -> "..ina" + +pl_sb_C_en_ina_list = ("stamen", "foramen", "lumen") + +( + si_sb_C_en_ina_list, + si_sb_C_en_ina_bysize, + pl_sb_C_en_ina_bysize, + pl_sb_C_en_ina, +) = make_pl_si_lists(pl_sb_C_en_ina_list, "ina", 2) + + +# UNCONDITIONAL "..um" -> "..a" + +pl_sb_U_um_a_list = ( + "bacterium", + "agendum", + "desideratum", + "erratum", + "stratum", + "datum", + "ovum", + "extremum", + "candelabrum", +) +( + si_sb_U_um_a_list, + si_sb_U_um_a_bysize, + pl_sb_U_um_a_bysize, + pl_sb_U_um_a, +) = make_pl_si_lists(pl_sb_U_um_a_list, "a", 2) + +# CLASSICAL "..um" -> "..a" + +pl_sb_C_um_a_list = ( + "maximum", + "minimum", + "momentum", + "optimum", + "quantum", + "cranium", + "curriculum", + "dictum", + "phylum", + "aquarium", + "compendium", + "emporium", + "encomium", + "gymnasium", + "honorarium", + "interregnum", + "lustrum", + "memorandum", + "millennium", + "rostrum", + "spectrum", + "speculum", + "stadium", + "trapezium", + "ultimatum", + "medium", + "vacuum", + "velum", + "consortium", + "arboretum", +) + +( + si_sb_C_um_a_list, + si_sb_C_um_a_bysize, + pl_sb_C_um_a_bysize, + pl_sb_C_um_a, +) = make_pl_si_lists(pl_sb_C_um_a_list, "a", 2) + + +# UNCONDITIONAL "..us" -> "i" + +pl_sb_U_us_i_list = ( + "alumnus", + "alveolus", + "bacillus", + "bronchus", + "locus", + "nucleus", + "stimulus", + "meniscus", + "sarcophagus", +) +( + si_sb_U_us_i_list, + si_sb_U_us_i_bysize, + pl_sb_U_us_i_bysize, + pl_sb_U_us_i, +) = make_pl_si_lists(pl_sb_U_us_i_list, "i", 2) + +# CLASSICAL "..us" -> "..i" + +pl_sb_C_us_i_list = ( + "focus", + "radius", + "genius", + "incubus", + "succubus", + "nimbus", + "fungus", + "nucleolus", + "stylus", + "torus", + "umbilicus", + "uterus", + "hippopotamus", + "cactus", +) + 
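Each singular word list in this section is handed to make_pl_si_lists (defined earlier), as in the unpacking call just below, to precompute the plural spellings, by-length lookup dicts, and a stem regex. The following minimal sketch is not part of the vendored sources and assumes the vendored inflect package is importable; it only illustrates the shape of one such result:

    # Illustrative only: what make_pl_si_lists returns for a tiny input.
    from inflect import make_pl_si_lists

    si_list, si_bysize, pl_bysize, stem = make_pl_si_lists(
        ["focus", "cactus"], "i", 2
    )
    assert si_list == ["foci", "cacti"]                # plural forms ("si" =
                                                       # matched when singularising)
    assert si_bysize == {4: {"foci"}, 5: {"cacti"}}    # plurals grouped by length
    assert pl_bysize == {5: {"focus"}, 6: {"cactus"}}  # singulars grouped by length
    assert stem == "(?:foc|cact)"                      # alternation of the stems

The by-length dicts let the lookup code test a word ending of each recorded length against a set, instead of scanning whole lists.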
+( + si_sb_C_us_i_list, + si_sb_C_us_i_bysize, + pl_sb_C_us_i_bysize, + pl_sb_C_us_i, +) = make_pl_si_lists(pl_sb_C_us_i_list, "i", 2) + + +# CLASSICAL "..us" -> "..us" (ASSIMILATED 4TH DECLENSION LATIN NOUNS) + +pl_sb_C_us_us = ( + "status", + "apparatus", + "prospectus", + "sinus", + "hiatus", + "impetus", + "plexus", +) +pl_sb_C_us_us_bysize = bysize(pl_sb_C_us_us) + +# UNCONDITIONAL "..on" -> "a" + +pl_sb_U_on_a_list = ( + "criterion", + "perihelion", + "aphelion", + "phenomenon", + "prolegomenon", + "noumenon", + "organon", + "asyndeton", + "hyperbaton", +) +( + si_sb_U_on_a_list, + si_sb_U_on_a_bysize, + pl_sb_U_on_a_bysize, + pl_sb_U_on_a, +) = make_pl_si_lists(pl_sb_U_on_a_list, "a", 2) + +# CLASSICAL "..on" -> "..a" + +pl_sb_C_on_a_list = ("oxymoron",) + +( + si_sb_C_on_a_list, + si_sb_C_on_a_bysize, + pl_sb_C_on_a_bysize, + pl_sb_C_on_a, +) = make_pl_si_lists(pl_sb_C_on_a_list, "a", 2) + + +# CLASSICAL "..o" -> "..i" (BUT NORMALLY -> "..os") + +pl_sb_C_o_i = [ + "solo", + "soprano", + "basso", + "alto", + "contralto", + "tempo", + "piano", + "virtuoso", +] # list not tuple so can concat for pl_sb_U_o_os + +pl_sb_C_o_i_bysize = bysize(pl_sb_C_o_i) +si_sb_C_o_i_bysize = bysize([f"{w[:-1]}i" for w in pl_sb_C_o_i]) + +pl_sb_C_o_i_stems = joinstem(-1, pl_sb_C_o_i) + +# ALWAYS "..o" -> "..os" + +pl_sb_U_o_os_complete = {"ado", "ISO", "NATO", "NCO", "NGO", "oto"} +si_sb_U_o_os_complete = {f"{w}s" for w in pl_sb_U_o_os_complete} + + +pl_sb_U_o_os_endings = [ + "aficionado", + "aggro", + "albino", + "allegro", + "ammo", + "Antananarivo", + "archipelago", + "armadillo", + "auto", + "avocado", + "Bamako", + "Barquisimeto", + "bimbo", + "bingo", + "Biro", + "bolero", + "Bolzano", + "bongo", + "Boto", + "burro", + "Cairo", + "canto", + "cappuccino", + "casino", + "cello", + "Chicago", + "Chimango", + "cilantro", + "cochito", + "coco", + "Colombo", + "Colorado", + "commando", + "concertino", + "contango", + "credo", + "crescendo", + "cyano", + "demo", + "ditto", + "Draco", + "dynamo", + "embryo", + "Esperanto", + "espresso", + "euro", + "falsetto", + "Faro", + "fiasco", + "Filipino", + "flamenco", + "furioso", + "generalissimo", + "Gestapo", + "ghetto", + "gigolo", + "gizmo", + "Greensboro", + "gringo", + "Guaiabero", + "guano", + "gumbo", + "gyro", + "hairdo", + "hippo", + "Idaho", + "impetigo", + "inferno", + "info", + "intermezzo", + "intertrigo", + "Iquico", + "jumbo", + "junto", + "Kakapo", + "kilo", + "Kinkimavo", + "Kokako", + "Kosovo", + "Lesotho", + "libero", + "libido", + "libretto", + "lido", + "Lilo", + "limbo", + "limo", + "lineno", + "lingo", + "lino", + "livedo", + "loco", + "logo", + "lumbago", + "macho", + "macro", + "mafioso", + "magneto", + "magnifico", + "Majuro", + "Malabo", + "manifesto", + "Maputo", + "Maracaibo", + "medico", + "memo", + "metro", + "Mexico", + "micro", + "Milano", + "Monaco", + "mono", + "Montenegro", + "Morocco", + "Muqdisho", + "myo", + "neutrino", + "Ningbo", + "octavo", + "oregano", + "Orinoco", + "Orlando", + "Oslo", + "panto", + "Paramaribo", + "Pardusco", + "pedalo", + "photo", + "pimento", + "pinto", + "pleco", + "Pluto", + "pogo", + "polo", + "poncho", + "Porto-Novo", + "Porto", + "pro", + "psycho", + "pueblo", + "quarto", + "Quito", + "repo", + "rhino", + "risotto", + "rococo", + "rondo", + "Sacramento", + "saddo", + "sago", + "salvo", + "Santiago", + "Sapporo", + "Sarajevo", + "scherzando", + "scherzo", + "silo", + "sirocco", + "sombrero", + "staccato", + "sterno", + "stucco", + "stylo", + "sumo", + "Taiko", + "techno", + "terrazzo", + 
"testudo", + "timpano", + "tiro", + "tobacco", + "Togo", + "Tokyo", + "torero", + "Torino", + "Toronto", + "torso", + "tremolo", + "typo", + "tyro", + "ufo", + "UNESCO", + "vaquero", + "vermicello", + "verso", + "vibrato", + "violoncello", + "Virgo", + "weirdo", + "WHO", + "WTO", + "Yamoussoukro", + "yo-yo", + "zero", + "Zibo", +] + pl_sb_C_o_i + +pl_sb_U_o_os_bysize = bysize(pl_sb_U_o_os_endings) +si_sb_U_o_os_bysize = bysize([f"{w}s" for w in pl_sb_U_o_os_endings]) + + +# UNCONDITIONAL "..ch" -> "..chs" + +pl_sb_U_ch_chs_list = ("czech", "eunuch", "stomach") + +( + si_sb_U_ch_chs_list, + si_sb_U_ch_chs_bysize, + pl_sb_U_ch_chs_bysize, + pl_sb_U_ch_chs, +) = make_pl_si_lists(pl_sb_U_ch_chs_list, "s", None) + + +# UNCONDITIONAL "..[ei]x" -> "..ices" + +pl_sb_U_ex_ices_list = ("codex", "murex", "silex") +( + si_sb_U_ex_ices_list, + si_sb_U_ex_ices_bysize, + pl_sb_U_ex_ices_bysize, + pl_sb_U_ex_ices, +) = make_pl_si_lists(pl_sb_U_ex_ices_list, "ices", 2) + +pl_sb_U_ix_ices_list = ("radix", "helix") +( + si_sb_U_ix_ices_list, + si_sb_U_ix_ices_bysize, + pl_sb_U_ix_ices_bysize, + pl_sb_U_ix_ices, +) = make_pl_si_lists(pl_sb_U_ix_ices_list, "ices", 2) + +# CLASSICAL "..[ei]x" -> "..ices" + +pl_sb_C_ex_ices_list = ( + "vortex", + "vertex", + "cortex", + "latex", + "pontifex", + "apex", + "index", + "simplex", +) + +( + si_sb_C_ex_ices_list, + si_sb_C_ex_ices_bysize, + pl_sb_C_ex_ices_bysize, + pl_sb_C_ex_ices, +) = make_pl_si_lists(pl_sb_C_ex_ices_list, "ices", 2) + + +pl_sb_C_ix_ices_list = ("appendix",) + +( + si_sb_C_ix_ices_list, + si_sb_C_ix_ices_bysize, + pl_sb_C_ix_ices_bysize, + pl_sb_C_ix_ices, +) = make_pl_si_lists(pl_sb_C_ix_ices_list, "ices", 2) + + +# ARABIC: ".." -> "..i" + +pl_sb_C_i_list = ("afrit", "afreet", "efreet") + +(si_sb_C_i_list, si_sb_C_i_bysize, pl_sb_C_i_bysize, pl_sb_C_i) = make_pl_si_lists( + pl_sb_C_i_list, "i", None +) + + +# HEBREW: ".." -> "..im" + +pl_sb_C_im_list = ("goy", "seraph", "cherub") + +(si_sb_C_im_list, si_sb_C_im_bysize, pl_sb_C_im_bysize, pl_sb_C_im) = make_pl_si_lists( + pl_sb_C_im_list, "im", None +) + + +# UNCONDITIONAL "..man" -> "..mans" + +pl_sb_U_man_mans_list = """ + ataman caiman cayman ceriman + desman dolman farman harman hetman + human leman ottoman shaman talisman +""".split() +pl_sb_U_man_mans_caps_list = """ + Alabaman Bahaman Burman German + Hiroshiman Liman Nakayaman Norman Oklahoman + Panaman Roman Selman Sonaman Tacoman Yakiman + Yokohaman Yuman +""".split() + +( + si_sb_U_man_mans_list, + si_sb_U_man_mans_bysize, + pl_sb_U_man_mans_bysize, +) = make_pl_si_lists(pl_sb_U_man_mans_list, "s", None, dojoinstem=False) +( + si_sb_U_man_mans_caps_list, + si_sb_U_man_mans_caps_bysize, + pl_sb_U_man_mans_caps_bysize, +) = make_pl_si_lists(pl_sb_U_man_mans_caps_list, "s", None, dojoinstem=False) + +# UNCONDITIONAL "..louse" -> "..lice" +pl_sb_U_louse_lice_list = ("booklouse", "grapelouse", "louse", "woodlouse") + +( + si_sb_U_louse_lice_list, + si_sb_U_louse_lice_bysize, + pl_sb_U_louse_lice_bysize, +) = make_pl_si_lists(pl_sb_U_louse_lice_list, "lice", 5, dojoinstem=False) + +pl_sb_uninflected_s_complete = [ + # PAIRS OR GROUPS SUBSUMED TO A SINGULAR... + "breeches", + "britches", + "pajamas", + "pyjamas", + "clippers", + "gallows", + "hijinks", + "headquarters", + "pliers", + "scissors", + "testes", + "herpes", + "pincers", + "shears", + "proceedings", + "trousers", + # UNASSIMILATED LATIN 4th DECLENSION + "cantus", + "coitus", + "nexus", + # RECENT IMPORTS... 
+ "contretemps", + "corps", + "debris", + "siemens", + # DISEASES + "mumps", + # MISCELLANEOUS OTHERS... + "diabetes", + "jackanapes", + "series", + "species", + "subspecies", + "rabies", + "chassis", + "innings", + "news", + "mews", + "haggis", +] + +pl_sb_uninflected_s_endings = [ + # RECENT IMPORTS... + "ois", + # DISEASES + "measles", +] + +pl_sb_uninflected_s = pl_sb_uninflected_s_complete + [ + f".*{w}" for w in pl_sb_uninflected_s_endings +] + +pl_sb_uninflected_herd = ( + # DON'T INFLECT IN CLASSICAL MODE, OTHERWISE NORMAL INFLECTION + "wildebeest", + "swine", + "eland", + "bison", + "buffalo", + "cattle", + "elk", + "rhinoceros", + "zucchini", + "caribou", + "dace", + "grouse", + "guinea fowl", + "guinea-fowl", + "haddock", + "hake", + "halibut", + "herring", + "mackerel", + "pickerel", + "pike", + "roe", + "seed", + "shad", + "snipe", + "teal", + "turbot", + "water fowl", + "water-fowl", +) + +pl_sb_uninflected_complete = [ + # SOME FISH AND HERD ANIMALS + "tuna", + "salmon", + "mackerel", + "trout", + "bream", + "sea-bass", + "sea bass", + "carp", + "cod", + "flounder", + "whiting", + "moose", + # OTHER ODDITIES + "graffiti", + "djinn", + "samuri", + "offspring", + "pence", + "quid", + "hertz", +] + pl_sb_uninflected_s_complete +# SOME WORDS ENDING IN ...s (OFTEN PAIRS TAKEN AS A WHOLE) + +pl_sb_uninflected_caps = [ + # ALL NATIONALS ENDING IN -ese + "Portuguese", + "Amoyese", + "Borghese", + "Congoese", + "Faroese", + "Foochowese", + "Genevese", + "Genoese", + "Gilbertese", + "Hottentotese", + "Kiplingese", + "Kongoese", + "Lucchese", + "Maltese", + "Nankingese", + "Niasese", + "Pekingese", + "Piedmontese", + "Pistoiese", + "Sarawakese", + "Shavese", + "Vermontese", + "Wenchowese", + "Yengeese", +] + + +pl_sb_uninflected_endings = [ + # UNCOUNTABLE NOUNS + "butter", + "cash", + "furniture", + "information", + # SOME FISH AND HERD ANIMALS + "fish", + "deer", + "sheep", + # ALL NATIONALS ENDING IN -ese + "nese", + "rese", + "lese", + "mese", + # DISEASES + "pox", + # OTHER ODDITIES + "craft", +] + pl_sb_uninflected_s_endings +# SOME WORDS ENDING IN ...s (OFTEN PAIRS TAKEN AS A WHOLE) + + +pl_sb_uninflected_bysize = bysize(pl_sb_uninflected_endings) + + +# SINGULAR WORDS ENDING IN ...s (ALL INFLECT WITH ...es) + +pl_sb_singular_s_complete = [ + "acropolis", + "aegis", + "alias", + "asbestos", + "bathos", + "bias", + "bronchitis", + "bursitis", + "caddis", + "cannabis", + "canvas", + "chaos", + "cosmos", + "dais", + "digitalis", + "epidermis", + "ethos", + "eyas", + "gas", + "glottis", + "hubris", + "ibis", + "lens", + "mantis", + "marquis", + "metropolis", + "pathos", + "pelvis", + "polis", + "rhinoceros", + "sassafras", + "trellis", +] + pl_sb_C_is_ides_complete + + +pl_sb_singular_s_endings = ["ss", "us"] + pl_sb_C_is_ides_endings + +pl_sb_singular_s_bysize = bysize(pl_sb_singular_s_endings) + +si_sb_singular_s_complete = [f"{w}es" for w in pl_sb_singular_s_complete] +si_sb_singular_s_endings = [f"{w}es" for w in pl_sb_singular_s_endings] +si_sb_singular_s_bysize = bysize(si_sb_singular_s_endings) + +pl_sb_singular_s_es = ["[A-Z].*es"] + +pl_sb_singular_s = enclose( + "|".join( + pl_sb_singular_s_complete + + [f".*{w}" for w in pl_sb_singular_s_endings] + + pl_sb_singular_s_es + ) +) + + +# PLURALS ENDING IN uses -> use + + +si_sb_ois_oi_case = ("Bolshois", "Hanois") + +si_sb_uses_use_case = ("Betelgeuses", "Duses", "Meuses", "Syracuses", "Toulouses") + +si_sb_uses_use = ( + "abuses", + "applauses", + "blouses", + "carouses", + "causes", + "chartreuses", + "clauses", + 
"contuses", + "douses", + "excuses", + "fuses", + "grouses", + "hypotenuses", + "masseuses", + "menopauses", + "misuses", + "muses", + "overuses", + "pauses", + "peruses", + "profuses", + "recluses", + "reuses", + "ruses", + "souses", + "spouses", + "suffuses", + "transfuses", + "uses", +) + +si_sb_ies_ie_case = ( + "Addies", + "Aggies", + "Allies", + "Amies", + "Angies", + "Annies", + "Annmaries", + "Archies", + "Arties", + "Aussies", + "Barbies", + "Barries", + "Basies", + "Bennies", + "Bernies", + "Berties", + "Bessies", + "Betties", + "Billies", + "Blondies", + "Bobbies", + "Bonnies", + "Bowies", + "Brandies", + "Bries", + "Brownies", + "Callies", + "Carnegies", + "Carries", + "Cassies", + "Charlies", + "Cheries", + "Christies", + "Connies", + "Curies", + "Dannies", + "Debbies", + "Dixies", + "Dollies", + "Donnies", + "Drambuies", + "Eddies", + "Effies", + "Ellies", + "Elsies", + "Eries", + "Ernies", + "Essies", + "Eugenies", + "Fannies", + "Flossies", + "Frankies", + "Freddies", + "Gillespies", + "Goldies", + "Gracies", + "Guthries", + "Hallies", + "Hatties", + "Hetties", + "Hollies", + "Jackies", + "Jamies", + "Janies", + "Jannies", + "Jeanies", + "Jeannies", + "Jennies", + "Jessies", + "Jimmies", + "Jodies", + "Johnies", + "Johnnies", + "Josies", + "Julies", + "Kalgoorlies", + "Kathies", + "Katies", + "Kellies", + "Kewpies", + "Kristies", + "Laramies", + "Lassies", + "Lauries", + "Leslies", + "Lessies", + "Lillies", + "Lizzies", + "Lonnies", + "Lories", + "Lorries", + "Lotties", + "Louies", + "Mackenzies", + "Maggies", + "Maisies", + "Mamies", + "Marcies", + "Margies", + "Maries", + "Marjories", + "Matties", + "McKenzies", + "Melanies", + "Mickies", + "Millies", + "Minnies", + "Mollies", + "Mounties", + "Nannies", + "Natalies", + "Nellies", + "Netties", + "Ollies", + "Ozzies", + "Pearlies", + "Pottawatomies", + "Reggies", + "Richies", + "Rickies", + "Robbies", + "Ronnies", + "Rosalies", + "Rosemaries", + "Rosies", + "Roxies", + "Rushdies", + "Ruthies", + "Sadies", + "Sallies", + "Sammies", + "Scotties", + "Selassies", + "Sherries", + "Sophies", + "Stacies", + "Stefanies", + "Stephanies", + "Stevies", + "Susies", + "Sylvies", + "Tammies", + "Terries", + "Tessies", + "Tommies", + "Tracies", + "Trekkies", + "Valaries", + "Valeries", + "Valkyries", + "Vickies", + "Virgies", + "Willies", + "Winnies", + "Wylies", + "Yorkies", +) + +si_sb_ies_ie = ( + "aeries", + "baggies", + "belies", + "biggies", + "birdies", + "bogies", + "bonnies", + "boogies", + "bookies", + "bourgeoisies", + "brownies", + "budgies", + "caddies", + "calories", + "camaraderies", + "cockamamies", + "collies", + "cookies", + "coolies", + "cooties", + "coteries", + "crappies", + "curies", + "cutesies", + "dogies", + "eyries", + "floozies", + "footsies", + "freebies", + "genies", + "goalies", + "groupies", + "hies", + "jalousies", + "junkies", + "kiddies", + "laddies", + "lassies", + "lies", + "lingeries", + "magpies", + "menageries", + "mommies", + "movies", + "neckties", + "newbies", + "nighties", + "oldies", + "organdies", + "overlies", + "pies", + "pinkies", + "pixies", + "potpies", + "prairies", + "quickies", + "reveries", + "rookies", + "rotisseries", + "softies", + "sorties", + "species", + "stymies", + "sweeties", + "ties", + "underlies", + "unties", + "veggies", + "vies", + "yuppies", + "zombies", +) + + +si_sb_oes_oe_case = ( + "Chloes", + "Crusoes", + "Defoes", + "Faeroes", + "Ivanhoes", + "Joes", + "McEnroes", + "Moes", + "Monroes", + "Noes", + "Poes", + "Roscoes", + "Tahoes", + "Tippecanoes", + "Zoes", +) + 
+si_sb_oes_oe = ( + "aloes", + "backhoes", + "canoes", + "does", + "floes", + "foes", + "hoes", + "mistletoes", + "oboes", + "pekoes", + "roes", + "sloes", + "throes", + "tiptoes", + "toes", + "woes", +) + +si_sb_z_zes = ("quartzes", "topazes") + +si_sb_zzes_zz = ("buzzes", "fizzes", "frizzes", "razzes") + +si_sb_ches_che_case = ( + "Andromaches", + "Apaches", + "Blanches", + "Comanches", + "Nietzsches", + "Porsches", + "Roches", +) + +si_sb_ches_che = ( + "aches", + "avalanches", + "backaches", + "bellyaches", + "caches", + "cloches", + "creches", + "douches", + "earaches", + "fiches", + "headaches", + "heartaches", + "microfiches", + "niches", + "pastiches", + "psyches", + "quiches", + "stomachaches", + "toothaches", + "tranches", +) + +si_sb_xes_xe = ("annexes", "axes", "deluxes", "pickaxes") + +si_sb_sses_sse_case = ("Hesses", "Jesses", "Larousses", "Matisses") +si_sb_sses_sse = ( + "bouillabaisses", + "crevasses", + "demitasses", + "impasses", + "mousses", + "posses", +) + +si_sb_ves_ve_case = ( + # *[nwl]ives -> [nwl]live + "Clives", + "Palmolives", +) +si_sb_ves_ve = ( + # *[^d]eaves -> eave + "interweaves", + "weaves", + # *[nwl]ives -> [nwl]live + "olives", + # *[eoa]lves -> [eoa]lve + "bivalves", + "dissolves", + "resolves", + "salves", + "twelves", + "valves", +) + + +plverb_special_s = enclose( + "|".join( + [pl_sb_singular_s] + + pl_sb_uninflected_s + + list(pl_sb_irregular_s) + + ["(.*[csx])is", "(.*)ceps", "[A-Z].*s"] + ) +) + +_pl_sb_postfix_adj_defn = ( + ("general", enclose(r"(?!major|lieutenant|brigadier|adjutant|.*star)\S+")), + ("martial", enclose("court")), + ("force", enclose("pound")), +) + +pl_sb_postfix_adj: Iterable[str] = ( + enclose(val + f"(?=(?:-|\\s+){key})") for key, val in _pl_sb_postfix_adj_defn +) + +pl_sb_postfix_adj_stems = f"({'|'.join(pl_sb_postfix_adj)})(.*)" + + +# PLURAL WORDS ENDING IS es GO TO SINGULAR is + +si_sb_es_is = ( + "amanuenses", + "amniocenteses", + "analyses", + "antitheses", + "apotheoses", + "arterioscleroses", + "atheroscleroses", + "axes", + # 'bases', # bases -> basis + "catalyses", + "catharses", + "chasses", + "cirrhoses", + "cocces", + "crises", + "diagnoses", + "dialyses", + "diereses", + "electrolyses", + "emphases", + "exegeses", + "geneses", + "halitoses", + "hydrolyses", + "hypnoses", + "hypotheses", + "hystereses", + "metamorphoses", + "metastases", + "misdiagnoses", + "mitoses", + "mononucleoses", + "narcoses", + "necroses", + "nemeses", + "neuroses", + "oases", + "osmoses", + "osteoporoses", + "paralyses", + "parentheses", + "parthenogeneses", + "periphrases", + "photosyntheses", + "probosces", + "prognoses", + "prophylaxes", + "prostheses", + "preces", + "psoriases", + "psychoanalyses", + "psychokineses", + "psychoses", + "scleroses", + "scolioses", + "sepses", + "silicoses", + "symbioses", + "synopses", + "syntheses", + "taxes", + "telekineses", + "theses", + "thromboses", + "tuberculoses", + "urinalyses", +) + +pl_prep_list = """ + about above across after among around at athwart before behind + below beneath beside besides between betwixt beyond but by + during except for from in into near of off on onto out over + since till to under until unto upon with""".split() + +pl_prep_list_da = pl_prep_list + ["de", "du", "da"] + +pl_prep_bysize = bysize(pl_prep_list_da) + +pl_prep = enclose("|".join(pl_prep_list_da)) + +pl_sb_prep_dual_compound = fr"(.*?)((?:-|\s+)(?:{pl_prep})(?:-|\s+))a(?:-|\s+)(.*)" + + +singular_pronoun_genders = { + "neuter", + "feminine", + "masculine", + "gender-neutral", + "feminine or 
masculine", + "masculine or feminine", +} + +pl_pron_nom = { + # NOMINATIVE REFLEXIVE + "i": "we", + "myself": "ourselves", + "you": "you", + "yourself": "yourselves", + "she": "they", + "herself": "themselves", + "he": "they", + "himself": "themselves", + "it": "they", + "itself": "themselves", + "they": "they", + "themself": "themselves", + # POSSESSIVE + "mine": "ours", + "yours": "yours", + "hers": "theirs", + "his": "theirs", + "its": "theirs", + "theirs": "theirs", +} + +si_pron: Dict[str, Dict[str, Union[str, Dict[str, str]]]] = { + "nom": {v: k for (k, v) in pl_pron_nom.items()} +} +si_pron["nom"]["we"] = "I" + + +pl_pron_acc = { + # ACCUSATIVE REFLEXIVE + "me": "us", + "myself": "ourselves", + "you": "you", + "yourself": "yourselves", + "her": "them", + "herself": "themselves", + "him": "them", + "himself": "themselves", + "it": "them", + "itself": "themselves", + "them": "them", + "themself": "themselves", +} + +pl_pron_acc_keys = enclose("|".join(pl_pron_acc)) +pl_pron_acc_keys_bysize = bysize(pl_pron_acc) + +si_pron["acc"] = {v: k for (k, v) in pl_pron_acc.items()} + +for _thecase, _plur, _gend, _sing in ( + ("nom", "they", "neuter", "it"), + ("nom", "they", "feminine", "she"), + ("nom", "they", "masculine", "he"), + ("nom", "they", "gender-neutral", "they"), + ("nom", "they", "feminine or masculine", "she or he"), + ("nom", "they", "masculine or feminine", "he or she"), + ("nom", "themselves", "neuter", "itself"), + ("nom", "themselves", "feminine", "herself"), + ("nom", "themselves", "masculine", "himself"), + ("nom", "themselves", "gender-neutral", "themself"), + ("nom", "themselves", "feminine or masculine", "herself or himself"), + ("nom", "themselves", "masculine or feminine", "himself or herself"), + ("nom", "theirs", "neuter", "its"), + ("nom", "theirs", "feminine", "hers"), + ("nom", "theirs", "masculine", "his"), + ("nom", "theirs", "gender-neutral", "theirs"), + ("nom", "theirs", "feminine or masculine", "hers or his"), + ("nom", "theirs", "masculine or feminine", "his or hers"), + ("acc", "them", "neuter", "it"), + ("acc", "them", "feminine", "her"), + ("acc", "them", "masculine", "him"), + ("acc", "them", "gender-neutral", "them"), + ("acc", "them", "feminine or masculine", "her or him"), + ("acc", "them", "masculine or feminine", "him or her"), + ("acc", "themselves", "neuter", "itself"), + ("acc", "themselves", "feminine", "herself"), + ("acc", "themselves", "masculine", "himself"), + ("acc", "themselves", "gender-neutral", "themself"), + ("acc", "themselves", "feminine or masculine", "herself or himself"), + ("acc", "themselves", "masculine or feminine", "himself or herself"), +): + try: + si_pron[_thecase][_plur][_gend] = _sing # type: ignore + except TypeError: + si_pron[_thecase][_plur] = {} + si_pron[_thecase][_plur][_gend] = _sing # type: ignore + + +si_pron_acc_keys = enclose("|".join(si_pron["acc"])) +si_pron_acc_keys_bysize = bysize(si_pron["acc"]) + + +def get_si_pron(thecase, word, gender) -> str: + try: + sing = si_pron[thecase][word] + except KeyError: + raise # not a pronoun + try: + return sing[gender] # has several types due to gender + except TypeError: + return cast(str, sing) # answer independent of gender + + +# These dictionaries group verbs by first, second and third person +# conjugations. 
+ +plverb_irregular_pres = { + "am": "are", + "are": "are", + "is": "are", + "was": "were", + "were": "were", + "was": "were", + "have": "have", + "have": "have", + "has": "have", + "do": "do", + "do": "do", + "does": "do", +} + +plverb_ambiguous_pres = { + "act": "act", + "act": "act", + "acts": "act", + "blame": "blame", + "blame": "blame", + "blames": "blame", + "can": "can", + "can": "can", + "can": "can", + "must": "must", + "must": "must", + "must": "must", + "fly": "fly", + "fly": "fly", + "flies": "fly", + "copy": "copy", + "copy": "copy", + "copies": "copy", + "drink": "drink", + "drink": "drink", + "drinks": "drink", + "fight": "fight", + "fight": "fight", + "fights": "fight", + "fire": "fire", + "fire": "fire", + "fires": "fire", + "like": "like", + "like": "like", + "likes": "like", + "look": "look", + "look": "look", + "looks": "look", + "make": "make", + "make": "make", + "makes": "make", + "reach": "reach", + "reach": "reach", + "reaches": "reach", + "run": "run", + "run": "run", + "runs": "run", + "sink": "sink", + "sink": "sink", + "sinks": "sink", + "sleep": "sleep", + "sleep": "sleep", + "sleeps": "sleep", + "view": "view", + "view": "view", + "views": "view", +} + +plverb_ambiguous_pres_keys = re.compile( + fr"^({enclose('|'.join(plverb_ambiguous_pres))})((\s.*)?)$", re.IGNORECASE +) + + +plverb_irregular_non_pres = ( + "did", + "had", + "ate", + "made", + "put", + "spent", + "fought", + "sank", + "gave", + "sought", + "shall", + "could", + "ought", + "should", +) + +plverb_ambiguous_non_pres = re.compile( + r"^((?:thought|saw|bent|will|might|cut))((\s.*)?)$", re.IGNORECASE +) + +# "..oes" -> "..oe" (the rest are "..oes" -> "o") + +pl_v_oes_oe = ("canoes", "floes", "oboes", "roes", "throes", "woes") +pl_v_oes_oe_endings_size4 = ("hoes", "toes") +pl_v_oes_oe_endings_size5 = ("shoes",) + + +pl_count_zero = ("0", "no", "zero", "nil") + + +pl_count_one = ("1", "a", "an", "one", "each", "every", "this", "that") + +pl_adj_special = {"a": "some", "an": "some", "this": "these", "that": "those"} + +pl_adj_special_keys = re.compile( + fr"^({enclose('|'.join(pl_adj_special))})$", re.IGNORECASE +) + +pl_adj_poss = { + "my": "our", + "your": "your", + "its": "their", + "her": "their", + "his": "their", + "their": "their", +} + +pl_adj_poss_keys = re.compile(fr"^({enclose('|'.join(pl_adj_poss))})$", re.IGNORECASE) + + +# 2. INDEFINITE ARTICLES + +# THIS PATTERN MATCHES STRINGS OF CAPITALS STARTING WITH A "VOWEL-SOUND" +# CONSONANT FOLLOWED BY ANOTHER CONSONANT, AND WHICH ARE NOT LIKELY +# TO BE REAL WORDS (OH, ALL RIGHT THEN, IT'S JUST MAGIC!) + +A_abbrev = re.compile( + r""" +(?! FJO | [HLMNS]Y. | RY[EO] | SQU + | ( F[LR]? | [HL] | MN? | N | RH? | S[CHKLMNPTVW]? | X(YL)?) [AEIOU]) +[FHLMNRSX][A-Z] +""", + re.VERBOSE, +) + +# THIS PATTERN CODES THE BEGINNINGS OF ALL ENGLISH WORDS BEGINING WITH A +# 'y' FOLLOWED BY A CONSONANT. ANY OTHER Y-CONSONANT PREFIX THEREFORE +# IMPLIES AN ABBREVIATION. 
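Just before the y-consonant pattern (A_y_cons, defined next), a few hedged sanity checks of how these article patterns combine at the call sites engine().a() and engine().an(), which choose the article from the sound of the word rather than its spelling. Not part of the vendored sources, and assuming the vendored copy matches upstream inflect's behaviour:

    import inflect

    p = inflect.engine()
    assert p.a("hour") == "an hour"            # A_explicit_an: silent "h"
    assert p.a("European") == "a European"     # "eu" pronounced like "you"
    assert p.a("FBI agent") == "an FBI agent"  # A_abbrev: "F" sounds like "ef"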
+ +A_y_cons = re.compile(r"^(y(b[lor]|cl[ea]|fere|gg|p[ios]|rou|tt))", re.IGNORECASE) + +# EXCEPTIONS TO EXCEPTIONS + +A_explicit_a = re.compile(r"^((?:unabomber|unanimous|US))", re.IGNORECASE) + +A_explicit_an = re.compile( + r"^((?:euler|hour(?!i)|heir|honest|hono[ur]|mpeg))", re.IGNORECASE +) + +A_ordinal_an = re.compile(r"^([aefhilmnorsx]-?th)", re.IGNORECASE) + +A_ordinal_a = re.compile(r"^([bcdgjkpqtuvwyz]-?th)", re.IGNORECASE) + + +# NUMERICAL INFLECTIONS + +nth = { + 0: "th", + 1: "st", + 2: "nd", + 3: "rd", + 4: "th", + 5: "th", + 6: "th", + 7: "th", + 8: "th", + 9: "th", + 11: "th", + 12: "th", + 13: "th", +} +nth_suff = set(nth.values()) + +ordinal = dict( + ty="tieth", + one="first", + two="second", + three="third", + five="fifth", + eight="eighth", + nine="ninth", + twelve="twelfth", +) + +ordinal_suff = re.compile(fr"({'|'.join(ordinal)})\Z") + + +# NUMBERS + +unit = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"] +teen = [ + "ten", + "eleven", + "twelve", + "thirteen", + "fourteen", + "fifteen", + "sixteen", + "seventeen", + "eighteen", + "nineteen", +] +ten = [ + "", + "", + "twenty", + "thirty", + "forty", + "fifty", + "sixty", + "seventy", + "eighty", + "ninety", +] +mill = [ + " ", + " thousand", + " million", + " billion", + " trillion", + " quadrillion", + " quintillion", + " sextillion", + " septillion", + " octillion", + " nonillion", + " decillion", +] + + +# SUPPORT CLASSICAL PLURALIZATIONS + +def_classical = dict( + all=False, zero=False, herd=False, names=True, persons=False, ancient=False +) + +all_classical = {k: True for k in def_classical} +no_classical = {k: False for k in def_classical} + + +# Maps strings to built-in constant types +string_to_constant = {"True": True, "False": False, "None": None} + + +# Pre-compiled regular expression objects +DOLLAR_DIGITS = re.compile(r"\$(\d+)") +FUNCTION_CALL = re.compile(r"((\w+)\([^)]*\)*)", re.IGNORECASE) +PARTITION_WORD = re.compile(r"\A(\s*)(.+?)(\s*)\Z") +PL_SB_POSTFIX_ADJ_STEMS_RE = re.compile( + fr"^(?:{pl_sb_postfix_adj_stems})$", re.IGNORECASE +) +PL_SB_PREP_DUAL_COMPOUND_RE = re.compile( + fr"^(?:{pl_sb_prep_dual_compound})$", re.IGNORECASE +) +DENOMINATOR = re.compile(r"(?P.+)( (per|a) .+)") +PLVERB_SPECIAL_S_RE = re.compile(fr"^({plverb_special_s})$") +WHITESPACE = re.compile(r"\s") +ENDS_WITH_S = re.compile(r"^(.*[^s])s$", re.IGNORECASE) +ENDS_WITH_APOSTROPHE_S = re.compile(r"^(.*)'s?$") +INDEFINITE_ARTICLE_TEST = re.compile(r"\A(\s*)(?:an?\s+)?(.+?)(\s*)\Z", re.IGNORECASE) +SPECIAL_AN = re.compile(r"^[aefhilmnorsx]$", re.IGNORECASE) +SPECIAL_A = re.compile(r"^[bcdgjkpqtuvwyz]$", re.IGNORECASE) +SPECIAL_ABBREV_AN = re.compile(r"^[aefhilmnorsx][.-]", re.IGNORECASE) +SPECIAL_ABBREV_A = re.compile(r"^[a-z][.-]", re.IGNORECASE) +CONSONANTS = re.compile(r"^[^aeiouy]", re.IGNORECASE) +ARTICLE_SPECIAL_EU = re.compile(r"^e[uw]", re.IGNORECASE) +ARTICLE_SPECIAL_ONCE = re.compile(r"^onc?e\b", re.IGNORECASE) +ARTICLE_SPECIAL_ONETIME = re.compile(r"^onetime\b", re.IGNORECASE) +ARTICLE_SPECIAL_UNIT = re.compile(r"^uni([^nmd]|mo)", re.IGNORECASE) +ARTICLE_SPECIAL_UBA = re.compile(r"^u[bcfghjkqrst][aeiou]", re.IGNORECASE) +ARTICLE_SPECIAL_UKR = re.compile(r"^ukr", re.IGNORECASE) +SPECIAL_CAPITALS = re.compile(r"^U[NK][AIEO]?") +VOWELS = re.compile(r"^[aeiou]", re.IGNORECASE) + +DIGIT_GROUP = re.compile(r"(\d)") +TWO_DIGITS = re.compile(r"(\d)(\d)") +THREE_DIGITS = re.compile(r"(\d)(\d)(\d)") +THREE_DIGITS_WORD = re.compile(r"(\d)(\d)(\d)(?=\D*\Z)") +TWO_DIGITS_WORD = 
re.compile(r"(\d)(\d)(?=\D*\Z)") +ONE_DIGIT_WORD = re.compile(r"(\d)(?=\D*\Z)") + +FOUR_DIGIT_COMMA = re.compile(r"(\d)(\d{3}(?:,|\Z))") +NON_DIGIT = re.compile(r"\D") +WHITESPACES_COMMA = re.compile(r"\s+,") +COMMA_WORD = re.compile(r", (\S+)\s+\Z") +WHITESPACES = re.compile(r"\s+") + + +PRESENT_PARTICIPLE_REPLACEMENTS = ( + (re.compile(r"ie$"), r"y"), + ( + re.compile(r"ue$"), + r"u", + ), # TODO: isn't ue$ -> u encompassed in the following rule? + (re.compile(r"([auy])e$"), r"\g<1>"), + (re.compile(r"ski$"), r"ski"), + (re.compile(r"[^b]i$"), r""), + (re.compile(r"^(are|were)$"), r"be"), + (re.compile(r"^(had)$"), r"hav"), + (re.compile(r"^(hoe)$"), r"\g<1>"), + (re.compile(r"([^e])e$"), r"\g<1>"), + (re.compile(r"er$"), r"er"), + (re.compile(r"([^aeiou][aeiouy]([bdgmnprst]))$"), r"\g<1>\g<2>"), +) + +DIGIT = re.compile(r"\d") + + +class Words(str): + lowered: str + split_: List[str] + first: str + last: str + + def __init__(self, orig) -> None: + self.lowered = self.lower() + self.split_ = self.split() + self.first = self.split_[0] + self.last = self.split_[-1] + + +Word = Annotated[str, Field(min_length=1)] +Falsish = Any # ideally, falsish would only validate on bool(value) is False + + +class engine: + def __init__(self) -> None: + + self.classical_dict = def_classical.copy() + self.persistent_count: Optional[int] = None + self.mill_count = 0 + self.pl_sb_user_defined: List[str] = [] + self.pl_v_user_defined: List[str] = [] + self.pl_adj_user_defined: List[str] = [] + self.si_sb_user_defined: List[str] = [] + self.A_a_user_defined: List[str] = [] + self.thegender = "neuter" + self.__number_args: Optional[Dict[str, str]] = None + + @property + def _number_args(self): + return cast(Dict[str, str], self.__number_args) + + @_number_args.setter + def _number_args(self, val): + self.__number_args = val + + deprecated_methods = dict( + pl="plural", + plnoun="plural_noun", + plverb="plural_verb", + pladj="plural_adj", + sinoun="single_noun", + prespart="present_participle", + numwords="number_to_words", + plequal="compare", + plnounequal="compare_nouns", + plverbequal="compare_verbs", + pladjequal="compare_adjs", + wordlist="join", + ) + + def __getattr__(self, meth): + if meth in self.deprecated_methods: + print3(f"{meth}() deprecated, use {self.deprecated_methods[meth]}()") + raise DeprecationWarning + raise AttributeError + + def defnoun(self, singular: str, plural: str) -> int: + """ + Set the noun plural of singular to plural. + + """ + self.checkpat(singular) + self.checkpatplural(plural) + self.pl_sb_user_defined.extend((singular, plural)) + self.si_sb_user_defined.extend((plural, singular)) + return 1 + + def defverb(self, s1: str, p1: str, s2: str, p2: str, s3: str, p3: str) -> int: + """ + Set the verb plurals for s1, s2 and s3 to p1, p2 and p3 respectively. + + Where 1, 2 and 3 represent the 1st, 2nd and 3rd person forms of the verb. + + """ + self.checkpat(s1) + self.checkpat(s2) + self.checkpat(s3) + self.checkpatplural(p1) + self.checkpatplural(p2) + self.checkpatplural(p3) + self.pl_v_user_defined.extend((s1, p1, s2, p2, s3, p3)) + return 1 + + def defadj(self, singular: str, plural: str) -> int: + """ + Set the adjective plural of singular to plural. + + """ + self.checkpat(singular) + self.checkpatplural(plural) + self.pl_adj_user_defined.extend((singular, plural)) + return 1 + + def defa(self, pattern: str) -> int: + """ + Define the indefinite article as 'a' for words matching pattern. 
+ + """ + self.checkpat(pattern) + self.A_a_user_defined.extend((pattern, "a")) + return 1 + + def defan(self, pattern: str) -> int: + """ + Define the indefinite article as 'an' for words matching pattern. + + """ + self.checkpat(pattern) + self.A_a_user_defined.extend((pattern, "an")) + return 1 + + def checkpat(self, pattern: Optional[str]) -> None: + """ + check for errors in a regex pattern + """ + if pattern is None: + return + try: + re.match(pattern, "") + except re.error: + print3(f"\nBad user-defined singular pattern:\n\t{pattern}\n") + raise BadUserDefinedPatternError + + def checkpatplural(self, pattern: str) -> None: + """ + check for errors in a regex replace pattern + """ + return + + @validate_arguments + def ud_match(self, word: Word, wordlist: Sequence[Optional[Word]]) -> Optional[str]: + for i in range(len(wordlist) - 2, -2, -2): # backwards through even elements + mo = re.search(fr"^{wordlist[i]}$", word, re.IGNORECASE) + if mo: + if wordlist[i + 1] is None: + return None + pl = DOLLAR_DIGITS.sub( + r"\\1", cast(Word, wordlist[i + 1]) + ) # change $n to \n for expand + return mo.expand(pl) + return None + + def classical(self, **kwargs) -> None: + """ + turn classical mode on and off for various categories + + turn on all classical modes: + classical() + classical(all=True) + + turn on or off specific claassical modes: + e.g. + classical(herd=True) + classical(names=False) + + By default all classical modes are off except names. + + unknown value in args or key in kwargs raises + exception: UnknownClasicalModeError + + """ + if not kwargs: + self.classical_dict = all_classical.copy() + return + if "all" in kwargs: + if kwargs["all"]: + self.classical_dict = all_classical.copy() + else: + self.classical_dict = no_classical.copy() + + for k, v in kwargs.items(): + if k in def_classical: + self.classical_dict[k] = v + else: + raise UnknownClassicalModeError + + def num( + self, count: Optional[int] = None, show: Optional[int] = None + ) -> str: # (;$count,$show) + """ + Set the number to be used in other method calls. + + Returns count. + + Set show to False to return '' instead. + + """ + if count is not None: + try: + self.persistent_count = int(count) + except ValueError: + raise BadNumValueError + if (show is None) or show: + return str(count) + else: + self.persistent_count = None + return "" + + def gender(self, gender: str) -> None: + """ + set the gender for the singular of plural pronouns + + can be one of: + 'neuter' ('they' -> 'it') + 'feminine' ('they' -> 'she') + 'masculine' ('they' -> 'he') + 'gender-neutral' ('they' -> 'they') + 'feminine or masculine' ('they' -> 'she or he') + 'masculine or feminine' ('they' -> 'he or she') + """ + if gender in singular_pronoun_genders: + self.thegender = gender + else: + raise BadGenderError + + def _get_value_from_ast(self, obj): + """ + Return the value of the ast object. + """ + if isinstance(obj, ast.Num): + return obj.n + elif isinstance(obj, ast.Str): + return obj.s + elif isinstance(obj, ast.List): + return [self._get_value_from_ast(e) for e in obj.elts] + elif isinstance(obj, ast.Tuple): + return tuple([self._get_value_from_ast(e) for e in obj.elts]) + + # None, True and False are NameConstants in Py3.4 and above. + elif isinstance(obj, ast.NameConstant): + return obj.value + + # Probably passed a variable name. 
+ # Or passed a single word without wrapping it in quotes as an argument + # ex: p.inflect("I plural(see)") instead of p.inflect("I plural('see')") + raise NameError(f"name '{obj.id}' is not defined") + + def _string_to_substitute( + self, mo: Match, methods_dict: Dict[str, Callable] + ) -> str: + """ + Return the string to be substituted for the match. + """ + matched_text, f_name = mo.groups() + # matched_text is the complete match string. e.g. plural_noun(cat) + # f_name is the function name. e.g. plural_noun + + # Return matched_text if function name is not in methods_dict + if f_name not in methods_dict: + return matched_text + + # Parse the matched text + a_tree = ast.parse(matched_text) + + # get the args and kwargs from ast objects + args_list = [ + self._get_value_from_ast(a) + for a in a_tree.body[0].value.args # type: ignore[attr-defined] + ] + kwargs_list = { + kw.arg: self._get_value_from_ast(kw.value) + for kw in a_tree.body[0].value.keywords # type: ignore[attr-defined] + } + + # Call the corresponding function + return methods_dict[f_name](*args_list, **kwargs_list) + + # 0. PERFORM GENERAL INFLECTIONS IN A STRING + + @validate_arguments + def inflect(self, text: Word) -> str: + """ + Perform inflections in a string. + + e.g. inflect('The plural of cat is plural(cat)') returns + 'The plural of cat is cats' + + can use plural, plural_noun, plural_verb, plural_adj, + singular_noun, a, an, no, ordinal, number_to_words, + present_participle, and num + + """ + save_persistent_count = self.persistent_count + + # Dictionary of allowed methods + methods_dict: Dict[str, Callable] = { + "plural": self.plural, + "plural_adj": self.plural_adj, + "plural_noun": self.plural_noun, + "plural_verb": self.plural_verb, + "singular_noun": self.singular_noun, + "a": self.a, + "an": self.a, + "no": self.no, + "ordinal": self.ordinal, + "number_to_words": self.number_to_words, + "present_participle": self.present_participle, + "num": self.num, + } + + # Regular expression to find Python's function call syntax + output = FUNCTION_CALL.sub( + lambda mo: self._string_to_substitute(mo, methods_dict), text + ) + self.persistent_count = save_persistent_count + return output + + # ## PLURAL SUBROUTINES + + def postprocess(self, orig: str, inflected) -> str: + inflected = str(inflected) + if "|" in inflected: + word_options = inflected.split("|") + # When two parts of a noun need to be pluralized + if len(word_options[0].split(" ")) == len(word_options[1].split(" ")): + result = inflected.split("|")[self.classical_dict["all"]].split(" ") + # When only the last part of the noun needs to be pluralized + else: + result = inflected.split(" ") + for index, word in enumerate(result): + if "|" in word: + result[index] = word.split("|")[self.classical_dict["all"]] + else: + result = inflected.split(" ") + + # Try to fix word-wise capitalization + for index, word in enumerate(orig.split(" ")): + if word == "I": + # Is this the only word for exceptions like this, + # where the original is fully capitalized + # without 'meaning' capitalization? 
+ # Also this fails to handle capitalization in context + continue + if word.capitalize() == word: + result[index] = result[index].capitalize() + if word == word.upper(): + result[index] = result[index].upper() + return " ".join(result) + + def partition_word(self, text: str) -> Tuple[str, str, str]: + mo = PARTITION_WORD.search(text) + if mo: + return mo.group(1), mo.group(2), mo.group(3) + else: + return "", "", "" + + @validate_arguments + def plural(self, text: Word, count: Optional[Union[str, int, Any]] = None) -> str: + """ + Return the plural of text. + + If count supplied, then return text if count is one of: + 1, a, an, one, each, every, this, that + + otherwise return the plural. + + Whitespace at the start and end is preserved. + + """ + pre, word, post = self.partition_word(text) + if not word: + return text + plural = self.postprocess( + word, + self._pl_special_adjective(word, count) + or self._pl_special_verb(word, count) + or self._plnoun(word, count), + ) + return f"{pre}{plural}{post}" + + @validate_arguments + def plural_noun( + self, text: Word, count: Optional[Union[str, int, Any]] = None + ) -> str: + """ + Return the plural of text, where text is a noun. + + If count supplied, then return text if count is one of: + 1, a, an, one, each, every, this, that + + otherwise return the plural. + + Whitespace at the start and end is preserved. + + """ + pre, word, post = self.partition_word(text) + if not word: + return text + plural = self.postprocess(word, self._plnoun(word, count)) + return f"{pre}{plural}{post}" + + @validate_arguments + def plural_verb( + self, text: Word, count: Optional[Union[str, int, Any]] = None + ) -> str: + """ + Return the plural of text, where text is a verb. + + If count supplied, then return text if count is one of: + 1, a, an, one, each, every, this, that + + otherwise return the plural. + + Whitespace at the start and end is preserved. + + """ + pre, word, post = self.partition_word(text) + if not word: + return text + plural = self.postprocess( + word, + self._pl_special_verb(word, count) or self._pl_general_verb(word, count), + ) + return f"{pre}{plural}{post}" + + @validate_arguments + def plural_adj( + self, text: Word, count: Optional[Union[str, int, Any]] = None + ) -> str: + """ + Return the plural of text, where text is an adjective. + + If count supplied, then return text if count is one of: + 1, a, an, one, each, every, this, that + + otherwise return the plural. + + Whitespace at the start and end is preserved. + + """ + pre, word, post = self.partition_word(text) + if not word: + return text + plural = self.postprocess(word, self._pl_special_adjective(word, count) or word) + return f"{pre}{plural}{post}" + + @validate_arguments + def compare(self, word1: Word, word2: Word) -> Union[str, bool]: + """ + compare word1 and word2 for equality regardless of plurality + + return values: + eq - the strings are equal + p:s - word1 is the plural of word2 + s:p - word2 is the plural of word1 + p:p - word1 and word2 are two different plural forms of the one word + False - otherwise + + >>> compare = engine().compare + >>> compare("egg", "eggs") + 's:p' + >>> compare('egg', 'egg') + 'eq' + + Words should not be empty. + + >>> compare('egg', '') + Traceback (most recent call last): + ... + pydantic.error_wrappers.ValidationError: 1 validation error for Compare + word2 + ensure this value has at least 1 characters... 
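+ + Two different plural forms of the same noun should compare + as plural-to-plural ('brethren' being the classical plural + of 'brother'): + + >>> compare('brothers', 'brethren') + 'p:p' 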
+ """ + norms = self.plural_noun, self.plural_verb, self.plural_adj + results = (self._plequal(word1, word2, norm) for norm in norms) + return next(filter(None, results), False) + + @validate_arguments + def compare_nouns(self, word1: Word, word2: Word) -> Union[str, bool]: + """ + compare word1 and word2 for equality regardless of plurality + word1 and word2 are to be treated as nouns + + return values: + eq - the strings are equal + p:s - word1 is the plural of word2 + s:p - word2 is the plural of word1 + p:p - word1 and word2 are two different plural forms of the one word + False - otherwise + + """ + return self._plequal(word1, word2, self.plural_noun) + + @validate_arguments + def compare_verbs(self, word1: Word, word2: Word) -> Union[str, bool]: + """ + compare word1 and word2 for equality regardless of plurality + word1 and word2 are to be treated as verbs + + return values: + eq - the strings are equal + p:s - word1 is the plural of word2 + s:p - word2 is the plural of word1 + p:p - word1 and word2 are two different plural forms of the one word + False - otherwise + + """ + return self._plequal(word1, word2, self.plural_verb) + + @validate_arguments + def compare_adjs(self, word1: Word, word2: Word) -> Union[str, bool]: + """ + compare word1 and word2 for equality regardless of plurality + word1 and word2 are to be treated as adjectives + + return values: + eq - the strings are equal + p:s - word1 is the plural of word2 + s:p - word2 is the plural of word1 + p:p - word1 and word2 are two different plural forms of the one word + False - otherwise + + """ + return self._plequal(word1, word2, self.plural_adj) + + @validate_arguments + def singular_noun( + self, + text: Word, + count: Optional[Union[int, str, Any]] = None, + gender: Optional[str] = None, + ) -> Union[str, bool]: + """ + Return the singular of text, where text is a plural noun. + + If count supplied, then return the singular if count is one of: + 1, a, an, one, each, every, this, that or if count is None + + otherwise return text unchanged. + + Whitespace at the start and end is preserved. + + >>> p = engine() + >>> p.singular_noun('horses') + 'horse' + >>> p.singular_noun('knights') + 'knight' + + Returns False when a singular noun is passed. 
+ + >>> p.singular_noun('horse') + False + >>> p.singular_noun('knight') + False + >>> p.singular_noun('soldier') + False + + """ + pre, word, post = self.partition_word(text) + if not word: + return text + sing = self._sinoun(word, count=count, gender=gender) + if sing is not False: + plural = self.postprocess(word, sing) + return f"{pre}{plural}{post}" + return False + + def _plequal(self, word1: str, word2: str, pl) -> Union[str, bool]: # noqa: C901 + classval = self.classical_dict.copy() + self.classical_dict = all_classical.copy() + if word1 == word2: + return "eq" + if word1 == pl(word2): + return "p:s" + if pl(word1) == word2: + return "s:p" + self.classical_dict = no_classical.copy() + if word1 == pl(word2): + return "p:s" + if pl(word1) == word2: + return "s:p" + self.classical_dict = classval.copy() + + if pl == self.plural or pl == self.plural_noun: + if self._pl_check_plurals_N(word1, word2): + return "p:p" + if self._pl_check_plurals_N(word2, word1): + return "p:p" + if pl == self.plural or pl == self.plural_adj: + if self._pl_check_plurals_adj(word1, word2): + return "p:p" + return False + + def _pl_reg_plurals(self, pair: str, stems: str, end1: str, end2: str) -> bool: + pattern = fr"({stems})({end1}\|\1{end2}|{end2}\|\1{end1})" + return bool(re.search(pattern, pair)) + + def _pl_check_plurals_N(self, word1: str, word2: str) -> bool: + stem_endings = ( + (pl_sb_C_a_ata, "as", "ata"), + (pl_sb_C_is_ides, "is", "ides"), + (pl_sb_C_a_ae, "s", "e"), + (pl_sb_C_en_ina, "ens", "ina"), + (pl_sb_C_um_a, "ums", "a"), + (pl_sb_C_us_i, "uses", "i"), + (pl_sb_C_on_a, "ons", "a"), + (pl_sb_C_o_i_stems, "os", "i"), + (pl_sb_C_ex_ices, "exes", "ices"), + (pl_sb_C_ix_ices, "ixes", "ices"), + (pl_sb_C_i, "s", "i"), + (pl_sb_C_im, "s", "im"), + (".*eau", "s", "x"), + (".*ieu", "s", "x"), + (".*tri", "xes", "ces"), + (".{2,}[yia]n", "xes", "ges"), + ) + + words = map(Words, (word1, word2)) + pair = "|".join(word.last for word in words) + + return ( + pair in pl_sb_irregular_s.values() + or pair in pl_sb_irregular.values() + or pair in pl_sb_irregular_caps.values() + or any( + self._pl_reg_plurals(pair, stems, end1, end2) + for stems, end1, end2 in stem_endings + ) + ) + + def _pl_check_plurals_adj(self, word1: str, word2: str) -> bool: + word1a = word1[: word1.rfind("'")] if word1.endswith(("'s", "'")) else "" + word2a = word2[: word2.rfind("'")] if word2.endswith(("'s", "'")) else "" + + return ( + bool(word1a) + and bool(word2a) + and ( + self._pl_check_plurals_N(word1a, word2a) + or self._pl_check_plurals_N(word2a, word1a) + ) + ) + + def get_count(self, count: Optional[Union[str, int]] = None) -> Union[str, int]: + if count is None and self.persistent_count is not None: + count = self.persistent_count + + if count is not None: + count = ( + 1 + if ( + (str(count) in pl_count_one) + or ( + self.classical_dict["zero"] + and str(count).lower() in pl_count_zero + ) + ) + else 2 + ) + else: + count = "" + return count + + # @profile + def _plnoun( # noqa: C901 + self, word: str, count: Optional[Union[str, int]] = None + ) -> str: + count = self.get_count(count) + + # DEFAULT TO PLURAL + + if count == 1: + return word + + # HANDLE USER-DEFINED NOUNS + + value = self.ud_match(word, self.pl_sb_user_defined) + if value is not None: + return value + + # HANDLE EMPTY WORD, SINGULAR COUNT AND UNINFLECTED PLURALS + + if word == "": + return word + + word = Words(word) + + if word.last.lower() in pl_sb_uninflected_complete: + return word + + if word in pl_sb_uninflected_caps: + return word + + for k, v 
in pl_sb_uninflected_bysize.items(): + if word.lowered[-k:] in v: + return word + + if self.classical_dict["herd"] and word.last.lower() in pl_sb_uninflected_herd: + return word + + # HANDLE COMPOUNDS ("Governor General", "mother-in-law", "aide-de-camp", ETC.) + + mo = PL_SB_POSTFIX_ADJ_STEMS_RE.search(word) + if mo and mo.group(2) != "": + return f"{self._plnoun(mo.group(1), 2)}{mo.group(2)}" + + if " a " in word.lowered or "-a-" in word.lowered: + mo = PL_SB_PREP_DUAL_COMPOUND_RE.search(word) + if mo and mo.group(2) != "" and mo.group(3) != "": + return ( + f"{self._plnoun(mo.group(1), 2)}" + f"{mo.group(2)}" + f"{self._plnoun(mo.group(3))}" + ) + + if len(word.split_) >= 3: + for numword in range(1, len(word.split_) - 1): + if word.split_[numword] in pl_prep_list_da: + return " ".join( + word.split_[: numword - 1] + + [self._plnoun(word.split_[numword - 1], 2)] + + word.split_[numword:] + ) + + # only pluralize denominators in units + mo = DENOMINATOR.search(word.lowered) + if mo: + index = len(mo.group("denominator")) + return f"{self._plnoun(word[:index])}{word[index:]}" + + # handle units given in degrees (only accept if + # there is no more than one word following) + # degree Celsius => degrees Celsius but degree + # fahrenheit hour => degree fahrenheit hours + if len(word.split_) >= 2 and word.split_[-2] == "degree": + return " ".join([self._plnoun(word.first)] + word.split_[1:]) + + with contextlib.suppress(ValueError): + return self._handle_prepositional_phrase( + word.lowered, + functools.partial(self._plnoun, count=2), + '-', + ) + + # HANDLE PRONOUNS + + for k, v in pl_pron_acc_keys_bysize.items(): + if word.lowered[-k:] in v: # ends with accusative pronoun + for pk, pv in pl_prep_bysize.items(): + if word.lowered[:pk] in pv: # starts with a prep + if word.lowered.split() == [ + word.lowered[:pk], + word.lowered[-k:], + ]: + # only whitespace in between + return word.lowered[:-k] + pl_pron_acc[word.lowered[-k:]] + + try: + return pl_pron_nom[word.lowered] + except KeyError: + pass + + try: + return pl_pron_acc[word.lowered] + except KeyError: + pass + + # HANDLE ISOLATED IRREGULAR PLURALS + + if word.last in pl_sb_irregular_caps: + llen = len(word.last) + return f"{word[:-llen]}{pl_sb_irregular_caps[word.last]}" + + lowered_last = word.last.lower() + if lowered_last in pl_sb_irregular: + llen = len(lowered_last) + return f"{word[:-llen]}{pl_sb_irregular[lowered_last]}" + + dash_split = word.lowered.split('-') + if (" ".join(dash_split[-2:])).lower() in pl_sb_irregular_compound: + llen = len( + " ".join(dash_split[-2:]) + ) # TODO: what if 2 spaces between these words? 
+ return ( + f"{word[:-llen]}" + f"{pl_sb_irregular_compound[(' '.join(dash_split[-2:])).lower()]}" + ) + + if word.lowered[-3:] == "quy": + return f"{word[:-1]}ies" + + if word.lowered[-6:] == "person": + if self.classical_dict["persons"]: + return f"{word}s" + else: + return f"{word[:-4]}ople" + + # HANDLE FAMILIES OF IRREGULAR PLURALS + + if word.lowered[-3:] == "man": + for k, v in pl_sb_U_man_mans_bysize.items(): + if word.lowered[-k:] in v: + return f"{word}s" + for k, v in pl_sb_U_man_mans_caps_bysize.items(): + if word[-k:] in v: + return f"{word}s" + return f"{word[:-3]}men" + if word.lowered[-5:] == "mouse": + return f"{word[:-5]}mice" + if word.lowered[-5:] == "louse": + v = pl_sb_U_louse_lice_bysize.get(len(word)) + if v and word.lowered in v: + return f"{word[:-5]}lice" + return f"{word}s" + if word.lowered[-5:] == "goose": + return f"{word[:-5]}geese" + if word.lowered[-5:] == "tooth": + return f"{word[:-5]}teeth" + if word.lowered[-4:] == "foot": + return f"{word[:-4]}feet" + if word.lowered[-4:] == "taco": + return f"{word[:-5]}tacos" + + if word.lowered == "die": + return "dice" + + # HANDLE UNASSIMILATED IMPORTS + + if word.lowered[-4:] == "ceps": + return word + if word.lowered[-4:] == "zoon": + return f"{word[:-2]}a" + if word.lowered[-3:] in ("cis", "sis", "xis"): + return f"{word[:-2]}es" + + for lastlet, d, numend, post in ( + ("h", pl_sb_U_ch_chs_bysize, None, "s"), + ("x", pl_sb_U_ex_ices_bysize, -2, "ices"), + ("x", pl_sb_U_ix_ices_bysize, -2, "ices"), + ("m", pl_sb_U_um_a_bysize, -2, "a"), + ("s", pl_sb_U_us_i_bysize, -2, "i"), + ("n", pl_sb_U_on_a_bysize, -2, "a"), + ("a", pl_sb_U_a_ae_bysize, None, "e"), + ): + if word.lowered[-1] == lastlet: # this test to add speed + for k, v in d.items(): + if word.lowered[-k:] in v: + return word[:numend] + post + + # HANDLE INCOMPLETELY ASSIMILATED IMPORTS + + if self.classical_dict["ancient"]: + if word.lowered[-4:] == "trix": + return f"{word[:-1]}ces" + if word.lowered[-3:] in ("eau", "ieu"): + return f"{word}x" + if word.lowered[-3:] in ("ynx", "inx", "anx") and len(word) > 4: + return f"{word[:-1]}ges" + + for lastlet, d, numend, post in ( + ("n", pl_sb_C_en_ina_bysize, -2, "ina"), + ("x", pl_sb_C_ex_ices_bysize, -2, "ices"), + ("x", pl_sb_C_ix_ices_bysize, -2, "ices"), + ("m", pl_sb_C_um_a_bysize, -2, "a"), + ("s", pl_sb_C_us_i_bysize, -2, "i"), + ("s", pl_sb_C_us_us_bysize, None, ""), + ("a", pl_sb_C_a_ae_bysize, None, "e"), + ("a", pl_sb_C_a_ata_bysize, None, "ta"), + ("s", pl_sb_C_is_ides_bysize, -1, "des"), + ("o", pl_sb_C_o_i_bysize, -1, "i"), + ("n", pl_sb_C_on_a_bysize, -2, "a"), + ): + if word.lowered[-1] == lastlet: # this test to add speed + for k, v in d.items(): + if word.lowered[-k:] in v: + return word[:numend] + post + + for d, numend, post in ( + (pl_sb_C_i_bysize, None, "i"), + (pl_sb_C_im_bysize, None, "im"), + ): + for k, v in d.items(): + if word.lowered[-k:] in v: + return word[:numend] + post + + # HANDLE SINGULAR NOUNS ENDING IN ...s OR OTHER SILIBANTS + + if lowered_last in pl_sb_singular_s_complete: + return f"{word}es" + + for k, v in pl_sb_singular_s_bysize.items(): + if word.lowered[-k:] in v: + return f"{word}es" + + if word.lowered[-2:] == "es" and word[0] == word[0].upper(): + return f"{word}es" + + if word.lowered[-1] == "z": + for k, v in pl_sb_z_zes_bysize.items(): + if word.lowered[-k:] in v: + return f"{word}es" + + if word.lowered[-2:-1] != "z": + return f"{word}zes" + + if word.lowered[-2:] == "ze": + for k, v in pl_sb_ze_zes_bysize.items(): + if word.lowered[-k:] in v: + return 
f"{word}s" + + if word.lowered[-2:] in ("ch", "sh", "zz", "ss") or word.lowered[-1] == "x": + return f"{word}es" + + # HANDLE ...f -> ...ves + + if word.lowered[-3:] in ("elf", "alf", "olf"): + return f"{word[:-1]}ves" + if word.lowered[-3:] == "eaf" and word.lowered[-4:-3] != "d": + return f"{word[:-1]}ves" + if word.lowered[-4:] in ("nife", "life", "wife"): + return f"{word[:-2]}ves" + if word.lowered[-3:] == "arf": + return f"{word[:-1]}ves" + + # HANDLE ...y + + if word.lowered[-1] == "y": + if word.lowered[-2:-1] in "aeiou" or len(word) == 1: + return f"{word}s" + + if self.classical_dict["names"]: + if word.lowered[-1] == "y" and word[0] == word[0].upper(): + return f"{word}s" + + return f"{word[:-1]}ies" + + # HANDLE ...o + + if lowered_last in pl_sb_U_o_os_complete: + return f"{word}s" + + for k, v in pl_sb_U_o_os_bysize.items(): + if word.lowered[-k:] in v: + return f"{word}s" + + if word.lowered[-2:] in ("ao", "eo", "io", "oo", "uo"): + return f"{word}s" + + if word.lowered[-1] == "o": + return f"{word}es" + + # OTHERWISE JUST ADD ...s + + return f"{word}s" + + @classmethod + def _handle_prepositional_phrase(cls, phrase, transform, sep): + """ + Given a word or phrase possibly separated by sep, parse out + the prepositional phrase and apply the transform to the word + preceding the prepositional phrase. + + Raise ValueError if the pivot is not found or if at least two + separators are not found. + + >>> engine._handle_prepositional_phrase("man-of-war", str.upper, '-') + 'MAN-of-war' + >>> engine._handle_prepositional_phrase("man of war", str.upper, ' ') + 'MAN of war' + """ + parts = phrase.split(sep) + if len(parts) < 3: + raise ValueError("Cannot handle words with fewer than two separators") + + pivot = cls._find_pivot(parts, pl_prep_list_da) + + transformed = transform(parts[pivot - 1]) or parts[pivot - 1] + return " ".join( + parts[: pivot - 1] + [sep.join([transformed, parts[pivot], ''])] + ) + " ".join(parts[(pivot + 1) :]) + + @staticmethod + def _find_pivot(words, candidates): + pivots = ( + index for index in range(1, len(words) - 1) if words[index] in candidates + ) + try: + return next(pivots) + except StopIteration: + raise ValueError("No pivot found") + + def _pl_special_verb( # noqa: C901 + self, word: str, count: Optional[Union[str, int]] = None + ) -> Union[str, bool]: + if self.classical_dict["zero"] and str(count).lower() in pl_count_zero: + return False + count = self.get_count(count) + + if count == 1: + return word + + # HANDLE USER-DEFINED VERBS + + value = self.ud_match(word, self.pl_v_user_defined) + if value is not None: + return value + + # HANDLE IRREGULAR PRESENT TENSE (SIMPLE AND COMPOUND) + + try: + words = Words(word) + except IndexError: + return False # word is '' + + if words.first in plverb_irregular_pres: + return f"{plverb_irregular_pres[words.first]}{words[len(words.first) :]}" + + # HANDLE IRREGULAR FUTURE, PRETERITE AND PERFECT TENSES + + if words.first in plverb_irregular_non_pres: + return word + + # HANDLE PRESENT NEGATIONS (SIMPLE AND COMPOUND) + + if words.first.endswith("n't") and words.first[:-3] in plverb_irregular_pres: + return ( + f"{plverb_irregular_pres[words.first[:-3]]}n't" + f"{words[len(words.first) :]}" + ) + + if words.first.endswith("n't"): + return word + + # HANDLE SPECIAL CASES + + mo = PLVERB_SPECIAL_S_RE.search(word) + if mo: + return False + if WHITESPACE.search(word): + return False + + if words.lowered == "quizzes": + return "quiz" + + # HANDLE STANDARD 3RD PERSON (CHOP THE ...(e)s OFF SINGLE WORDS) + + if ( + 
words.lowered[-4:] in ("ches", "shes", "zzes", "sses") + or words.lowered[-3:] == "xes" + ): + return words[:-2] + + if words.lowered[-3:] == "ies" and len(words) > 3: + return words.lowered[:-3] + "y" + + if ( + words.last.lower() in pl_v_oes_oe + or words.lowered[-4:] in pl_v_oes_oe_endings_size4 + or words.lowered[-5:] in pl_v_oes_oe_endings_size5 + ): + return words[:-1] + + if words.lowered.endswith("oes") and len(words) > 3: + return words.lowered[:-2] + + mo = ENDS_WITH_S.search(words) + if mo: + return mo.group(1) + + # OTHERWISE, A REGULAR VERB (HANDLE ELSEWHERE) + + return False + + def _pl_general_verb( + self, word: str, count: Optional[Union[str, int]] = None + ) -> str: + count = self.get_count(count) + + if count == 1: + return word + + # HANDLE AMBIGUOUS PRESENT TENSES (SIMPLE AND COMPOUND) + + mo = plverb_ambiguous_pres_keys.search(word) + if mo: + return f"{plverb_ambiguous_pres[mo.group(1).lower()]}{mo.group(2)}" + + # HANDLE AMBIGUOUS PRETERITE AND PERFECT TENSES + + mo = plverb_ambiguous_non_pres.search(word) + if mo: + return word + + # OTHERWISE, 1st OR 2ND PERSON IS UNINFLECTED + + return word + + def _pl_special_adjective( + self, word: str, count: Optional[Union[str, int]] = None + ) -> Union[str, bool]: + count = self.get_count(count) + + if count == 1: + return word + + # HANDLE USER-DEFINED ADJECTIVES + + value = self.ud_match(word, self.pl_adj_user_defined) + if value is not None: + return value + + # HANDLE KNOWN CASES + + mo = pl_adj_special_keys.search(word) + if mo: + return pl_adj_special[mo.group(1).lower()] + + # HANDLE POSSESSIVES + + mo = pl_adj_poss_keys.search(word) + if mo: + return pl_adj_poss[mo.group(1).lower()] + + mo = ENDS_WITH_APOSTROPHE_S.search(word) + if mo: + pl = self.plural_noun(mo.group(1)) + trailing_s = "" if pl[-1] == "s" else "s" + return f"{pl}'{trailing_s}" + + # OTHERWISE, NO IDEA + + return False + + # @profile + def _sinoun( # noqa: C901 + self, + word: str, + count: Optional[Union[str, int]] = None, + gender: Optional[str] = None, + ) -> Union[str, bool]: + count = self.get_count(count) + + # DEFAULT TO PLURAL + + if count == 2: + return word + + # SET THE GENDER + + try: + if gender is None: + gender = self.thegender + elif gender not in singular_pronoun_genders: + raise BadGenderError + except (TypeError, IndexError): + raise BadGenderError + + # HANDLE USER-DEFINED NOUNS + + value = self.ud_match(word, self.si_sb_user_defined) + if value is not None: + return value + + # HANDLE EMPTY WORD, SINGULAR COUNT AND UNINFLECTED PLURALS + + if word == "": + return word + + if word in si_sb_ois_oi_case: + return word[:-1] + + words = Words(word) + + if words.last.lower() in pl_sb_uninflected_complete: + return word + + if word in pl_sb_uninflected_caps: + return word + + for k, v in pl_sb_uninflected_bysize.items(): + if words.lowered[-k:] in v: + return word + + if self.classical_dict["herd"] and words.last.lower() in pl_sb_uninflected_herd: + return word + + if words.last.lower() in pl_sb_C_us_us: + return word if self.classical_dict["ancient"] else False + + # HANDLE COMPOUNDS ("Governor General", "mother-in-law", "aide-de-camp", ETC.) 
+ + mo = PL_SB_POSTFIX_ADJ_STEMS_RE.search(word) + if mo and mo.group(2) != "": + return f"{self._sinoun(mo.group(1), 1, gender=gender)}{mo.group(2)}" + + with contextlib.suppress(ValueError): + return self._handle_prepositional_phrase( + words.lowered, + functools.partial(self._sinoun, count=1, gender=gender), + ' ', + ) + + with contextlib.suppress(ValueError): + return self._handle_prepositional_phrase( + words.lowered, + functools.partial(self._sinoun, count=1, gender=gender), + '-', + ) + + # HANDLE PRONOUNS + + for k, v in si_pron_acc_keys_bysize.items(): + if words.lowered[-k:] in v: # ends with accusative pronoun + for pk, pv in pl_prep_bysize.items(): + if words.lowered[:pk] in pv: # starts with a prep + if words.lowered.split() == [ + words.lowered[:pk], + words.lowered[-k:], + ]: + # only whitespace in between + return words.lowered[:-k] + get_si_pron( + "acc", words.lowered[-k:], gender + ) + + try: + return get_si_pron("nom", words.lowered, gender) + except KeyError: + pass + + try: + return get_si_pron("acc", words.lowered, gender) + except KeyError: + pass + + # HANDLE ISOLATED IRREGULAR PLURALS + + if words.last in si_sb_irregular_caps: + llen = len(words.last) + return "{}{}".format(word[:-llen], si_sb_irregular_caps[words.last]) + + if words.last.lower() in si_sb_irregular: + llen = len(words.last.lower()) + return "{}{}".format(word[:-llen], si_sb_irregular[words.last.lower()]) + + dash_split = words.lowered.split("-") + if (" ".join(dash_split[-2:])).lower() in si_sb_irregular_compound: + llen = len( + " ".join(dash_split[-2:]) + ) # TODO: what if 2 spaces between these words? + return "{}{}".format( + word[:-llen], + si_sb_irregular_compound[(" ".join(dash_split[-2:])).lower()], + ) + + if words.lowered[-5:] == "quies": + return word[:-3] + "y" + + if words.lowered[-7:] == "persons": + return word[:-1] + if words.lowered[-6:] == "people": + return word[:-4] + "rson" + + # HANDLE FAMILIES OF IRREGULAR PLURALS + + if words.lowered[-4:] == "mans": + for k, v in si_sb_U_man_mans_bysize.items(): + if words.lowered[-k:] in v: + return word[:-1] + for k, v in si_sb_U_man_mans_caps_bysize.items(): + if word[-k:] in v: + return word[:-1] + if words.lowered[-3:] == "men": + return word[:-3] + "man" + if words.lowered[-4:] == "mice": + return word[:-4] + "mouse" + if words.lowered[-4:] == "lice": + v = si_sb_U_louse_lice_bysize.get(len(word)) + if v and words.lowered in v: + return word[:-4] + "louse" + if words.lowered[-5:] == "geese": + return word[:-5] + "goose" + if words.lowered[-5:] == "teeth": + return word[:-5] + "tooth" + if words.lowered[-4:] == "feet": + return word[:-4] + "foot" + + if words.lowered == "dice": + return "die" + + # HANDLE UNASSIMILATED IMPORTS + + if words.lowered[-4:] == "ceps": + return word + if words.lowered[-3:] == "zoa": + return word[:-1] + "on" + + for lastlet, d, unass_numend, post in ( + ("s", si_sb_U_ch_chs_bysize, -1, ""), + ("s", si_sb_U_ex_ices_bysize, -4, "ex"), + ("s", si_sb_U_ix_ices_bysize, -4, "ix"), + ("a", si_sb_U_um_a_bysize, -1, "um"), + ("i", si_sb_U_us_i_bysize, -1, "us"), + ("a", si_sb_U_on_a_bysize, -1, "on"), + ("e", si_sb_U_a_ae_bysize, -1, ""), + ): + if words.lowered[-1] == lastlet: # this test to add speed + for k, v in d.items(): + if words.lowered[-k:] in v: + return word[:unass_numend] + post + + # HANDLE INCOMPLETELY ASSIMILATED IMPORTS + + if self.classical_dict["ancient"]: + + if words.lowered[-6:] == "trices": + return word[:-3] + "x" + if words.lowered[-4:] in ("eaux", "ieux"): + return word[:-1] + if 
words.lowered[-5:] in ("ynges", "inges", "anges") and len(word) > 6: + return word[:-3] + "x" + + for lastlet, d, class_numend, post in ( + ("a", si_sb_C_en_ina_bysize, -3, "en"), + ("s", si_sb_C_ex_ices_bysize, -4, "ex"), + ("s", si_sb_C_ix_ices_bysize, -4, "ix"), + ("a", si_sb_C_um_a_bysize, -1, "um"), + ("i", si_sb_C_us_i_bysize, -1, "us"), + ("s", pl_sb_C_us_us_bysize, None, ""), + ("e", si_sb_C_a_ae_bysize, -1, ""), + ("a", si_sb_C_a_ata_bysize, -2, ""), + ("s", si_sb_C_is_ides_bysize, -3, "s"), + ("i", si_sb_C_o_i_bysize, -1, "o"), + ("a", si_sb_C_on_a_bysize, -1, "on"), + ("m", si_sb_C_im_bysize, -2, ""), + ("i", si_sb_C_i_bysize, -1, ""), + ): + if words.lowered[-1] == lastlet: # this test to add speed + for k, v in d.items(): + if words.lowered[-k:] in v: + return word[:class_numend] + post + + # HANDLE PLURLS ENDING IN uses -> use + + if ( + words.lowered[-6:] == "houses" + or word in si_sb_uses_use_case + or words.last.lower() in si_sb_uses_use + ): + return word[:-1] + + # HANDLE PLURLS ENDING IN ies -> ie + + if word in si_sb_ies_ie_case or words.last.lower() in si_sb_ies_ie: + return word[:-1] + + # HANDLE PLURLS ENDING IN oes -> oe + + if ( + words.lowered[-5:] == "shoes" + or word in si_sb_oes_oe_case + or words.last.lower() in si_sb_oes_oe + ): + return word[:-1] + + # HANDLE SINGULAR NOUNS ENDING IN ...s OR OTHER SILIBANTS + + if word in si_sb_sses_sse_case or words.last.lower() in si_sb_sses_sse: + return word[:-1] + + if words.last.lower() in si_sb_singular_s_complete: + return word[:-2] + + for k, v in si_sb_singular_s_bysize.items(): + if words.lowered[-k:] in v: + return word[:-2] + + if words.lowered[-4:] == "eses" and word[0] == word[0].upper(): + return word[:-2] + + if words.last.lower() in si_sb_z_zes: + return word[:-2] + + if words.last.lower() in si_sb_zzes_zz: + return word[:-2] + + if words.lowered[-4:] == "zzes": + return word[:-3] + + if word in si_sb_ches_che_case or words.last.lower() in si_sb_ches_che: + return word[:-1] + + if words.lowered[-4:] in ("ches", "shes"): + return word[:-2] + + if words.last.lower() in si_sb_xes_xe: + return word[:-1] + + if words.lowered[-3:] == "xes": + return word[:-2] + + # HANDLE ...f -> ...ves + + if word in si_sb_ves_ve_case or words.last.lower() in si_sb_ves_ve: + return word[:-1] + + if words.lowered[-3:] == "ves": + if words.lowered[-5:-3] in ("el", "al", "ol"): + return word[:-3] + "f" + if words.lowered[-5:-3] == "ea" and word[-6:-5] != "d": + return word[:-3] + "f" + if words.lowered[-5:-3] in ("ni", "li", "wi"): + return word[:-3] + "fe" + if words.lowered[-5:-3] == "ar": + return word[:-3] + "f" + + # HANDLE ...y + + if words.lowered[-2:] == "ys": + if len(words.lowered) > 2 and words.lowered[-3] in "aeiou": + return word[:-1] + + if self.classical_dict["names"]: + if words.lowered[-2:] == "ys" and word[0] == word[0].upper(): + return word[:-1] + + if words.lowered[-3:] == "ies": + return word[:-3] + "y" + + # HANDLE ...o + + if words.lowered[-2:] == "os": + + if words.last.lower() in si_sb_U_o_os_complete: + return word[:-1] + + for k, v in si_sb_U_o_os_bysize.items(): + if words.lowered[-k:] in v: + return word[:-1] + + if words.lowered[-3:] in ("aos", "eos", "ios", "oos", "uos"): + return word[:-1] + + if words.lowered[-3:] == "oes": + return word[:-2] + + # UNASSIMILATED IMPORTS FINAL RULE + + if word in si_sb_es_is: + return word[:-2] + "is" + + # OTHERWISE JUST REMOVE ...s + + if words.lowered[-1] == "s": + return word[:-1] + + # COULD NOT FIND SINGULAR + + return False + + # ADJECTIVES + + 
@validate_arguments + def a(self, text: Word, count: Optional[Union[int, str, Any]] = 1) -> str: + """ + Return the appropriate indefinite article followed by text. + + The indefinite article is either 'a' or 'an'. + + If count is not one, then return count followed by text + instead of 'a' or 'an'. + + Whitespace at the start and end is preserved. + + """ + mo = INDEFINITE_ARTICLE_TEST.search(text) + if mo: + word = mo.group(2) + if not word: + return text + pre = mo.group(1) + post = mo.group(3) + result = self._indef_article(word, count) + return f"{pre}{result}{post}" + return "" + + an = a + + _indef_article_cases = ( + # HANDLE ORDINAL FORMS + (A_ordinal_a, "a"), + (A_ordinal_an, "an"), + # HANDLE SPECIAL CASES + (A_explicit_an, "an"), + (SPECIAL_AN, "an"), + (SPECIAL_A, "a"), + # HANDLE ABBREVIATIONS + (A_abbrev, "an"), + (SPECIAL_ABBREV_AN, "an"), + (SPECIAL_ABBREV_A, "a"), + # HANDLE CONSONANTS + (CONSONANTS, "a"), + # HANDLE SPECIAL VOWEL-FORMS + (ARTICLE_SPECIAL_EU, "a"), + (ARTICLE_SPECIAL_ONCE, "a"), + (ARTICLE_SPECIAL_ONETIME, "a"), + (ARTICLE_SPECIAL_UNIT, "a"), + (ARTICLE_SPECIAL_UBA, "a"), + (ARTICLE_SPECIAL_UKR, "a"), + (A_explicit_a, "a"), + # HANDLE SPECIAL CAPITALS + (SPECIAL_CAPITALS, "a"), + # HANDLE VOWELS + (VOWELS, "an"), + # HANDLE y... + # (BEFORE CERTAIN CONSONANTS IMPLIES (UNNATURALIZED) "i.." SOUND) + (A_y_cons, "an"), + ) + + def _indef_article(self, word: str, count: Union[int, str, Any]) -> str: + mycount = self.get_count(count) + + if mycount != 1: + return f"{count} {word}" + + # HANDLE USER-DEFINED VARIANTS + + value = self.ud_match(word, self.A_a_user_defined) + if value is not None: + return f"{value} {word}" + + matches = ( + f'{article} {word}' + for regexen, article in self._indef_article_cases + if regexen.search(word) + ) + + # OTHERWISE, GUESS "a" + fallback = f'a {word}' + return next(matches, fallback) + + # 2. TRANSLATE ZERO-QUANTIFIED $word TO "no plural($word)" + + @validate_arguments + def no(self, text: Word, count: Optional[Union[int, str]] = None) -> str: + """ + If count is 0, no, zero or nil, return 'no' followed by the plural + of text. + + If count is one of: + 1, a, an, one, each, every, this, that + return count followed by text. + + Otherwise return count followed by the plural of text. + + In the return value count is always followed by a space. + + Whitespace at the start and end is preserved. + + """ + if count is None and self.persistent_count is not None: + count = self.persistent_count + + if count is None: + count = 0 + mo = PARTITION_WORD.search(text) + if mo: + pre = mo.group(1) + word = mo.group(2) + post = mo.group(3) + else: + pre = "" + word = "" + post = "" + + if str(count).lower() in pl_count_zero: + count = 'no' + return f"{pre}{count} {self.plural(word, count)}{post}" + + # PARTICIPLES + + @validate_arguments + def present_participle(self, word: Word) -> str: + """ + Return the present participle for word. + + word is the 3rd person singular verb. + + """ + plv = self.plural_verb(word, 2) + ans = plv + + for regexen, repl in PRESENT_PARTICIPLE_REPLACEMENTS: + ans, num = regexen.subn(repl, plv) + if num: + return f"{ans}ing" + return f"{ans}ing" + + # NUMERICAL INFLECTIONS + + @validate_arguments + def ordinal(self, num: Union[int, Word]) -> str: # noqa: C901 + """ + Return the ordinal of num. + + num can be an integer or text + + e.g. ordinal(1) returns '1st' + ordinal('one') returns 'first' + + """ + if DIGIT.match(str(num)): + if isinstance(num, (int, float)): + n = int(num) + else: + if "." 
in str(num): + try: + # numbers after decimal, + # so only need last one for ordinal + n = int(num[-1]) + + except ValueError: # ends with '.', so need to use whole string + n = int(num[:-1]) + else: + n = int(num) + try: + post = nth[n % 100] + except KeyError: + post = nth[n % 10] + return f"{num}{post}" + else: + # Mad props to Damian Conway (?) whose ordinal() + # algorithm is type-bendy enough to foil MyPy + str_num: str = num # type: ignore[assignment] + mo = ordinal_suff.search(str_num) + if mo: + post = ordinal[mo.group(1)] + rval = ordinal_suff.sub(post, str_num) + else: + rval = f"{str_num}th" + return rval + + def millfn(self, ind: int = 0) -> str: + if ind > len(mill) - 1: + print3("number out of range") + raise NumOutOfRangeError + return mill[ind] + + def unitfn(self, units: int, mindex: int = 0) -> str: + return f"{unit[units]}{self.millfn(mindex)}" + + def tenfn(self, tens, units, mindex=0) -> str: + if tens != 1: + tens_part = ten[tens] + if tens and units: + hyphen = "-" + else: + hyphen = "" + unit_part = unit[units] + mill_part = self.millfn(mindex) + return f"{tens_part}{hyphen}{unit_part}{mill_part}" + return f"{teen[units]}{mill[mindex]}" + + def hundfn(self, hundreds: int, tens: int, units: int, mindex: int) -> str: + if hundreds: + andword = f" {self._number_args['andword']} " if tens or units else "" + # use unit not unitfn as simpler + return ( + f"{unit[hundreds]} hundred{andword}" + f"{self.tenfn(tens, units)}{self.millfn(mindex)}, " + ) + if tens or units: + return f"{self.tenfn(tens, units)}{self.millfn(mindex)}, " + return "" + + def group1sub(self, mo: Match) -> str: + units = int(mo.group(1)) + if units == 1: + return f" {self._number_args['one']}, " + elif units: + return f"{unit[units]}, " + else: + return f" {self._number_args['zero']}, " + + def group1bsub(self, mo: Match) -> str: + units = int(mo.group(1)) + if units: + return f"{unit[units]}, " + else: + return f" {self._number_args['zero']}, " + + def group2sub(self, mo: Match) -> str: + tens = int(mo.group(1)) + units = int(mo.group(2)) + if tens: + return f"{self.tenfn(tens, units)}, " + if units: + return f" {self._number_args['zero']} {unit[units]}, " + return f" {self._number_args['zero']} {self._number_args['zero']}, " + + def group3sub(self, mo: Match) -> str: + hundreds = int(mo.group(1)) + tens = int(mo.group(2)) + units = int(mo.group(3)) + if hundreds == 1: + hunword = f" {self._number_args['one']}" + elif hundreds: + hunword = str(unit[hundreds]) + else: + hunword = f" {self._number_args['zero']}" + if tens: + tenword = self.tenfn(tens, units) + elif units: + tenword = f" {self._number_args['zero']} {unit[units]}" + else: + tenword = f" {self._number_args['zero']} {self._number_args['zero']}" + return f"{hunword} {tenword}, " + + def hundsub(self, mo: Match) -> str: + ret = self.hundfn( + int(mo.group(1)), int(mo.group(2)), int(mo.group(3)), self.mill_count + ) + self.mill_count += 1 + return ret + + def tensub(self, mo: Match) -> str: + return f"{self.tenfn(int(mo.group(1)), int(mo.group(2)), self.mill_count)}, " + + def unitsub(self, mo: Match) -> str: + return f"{self.unitfn(int(mo.group(1)), self.mill_count)}, " + + def enword(self, num: str, group: int) -> str: + # import pdb + # pdb.set_trace() + + if group == 1: + num = DIGIT_GROUP.sub(self.group1sub, num) + elif group == 2: + num = TWO_DIGITS.sub(self.group2sub, num) + num = DIGIT_GROUP.sub(self.group1bsub, num, 1) + elif group == 3: + num = THREE_DIGITS.sub(self.group3sub, num) + num = TWO_DIGITS.sub(self.group2sub, num, 1) + 
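+ # a final unpaired digit, if any, is voiced on its own 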
num = DIGIT_GROUP.sub(self.group1sub, num, 1) + elif int(num) == 0: + num = self._number_args["zero"] + elif int(num) == 1: + num = self._number_args["one"] + else: + num = num.lstrip().lstrip("0") + self.mill_count = 0 + # surely there's a better way to do the next bit + mo = THREE_DIGITS_WORD.search(num) + while mo: + num = THREE_DIGITS_WORD.sub(self.hundsub, num, 1) + mo = THREE_DIGITS_WORD.search(num) + num = TWO_DIGITS_WORD.sub(self.tensub, num, 1) + num = ONE_DIGIT_WORD.sub(self.unitsub, num, 1) + return num + + @validate_arguments(config=dict(arbitrary_types_allowed=True)) # noqa: C901 + def number_to_words( # noqa: C901 + self, + num: Union[Number, Word], + wantlist: bool = False, + group: int = 0, + comma: Union[Falsish, str] = ",", + andword: str = "and", + zero: str = "zero", + one: str = "one", + decimal: Union[Falsish, str] = "point", + threshold: Optional[int] = None, + ) -> Union[str, List[str]]: + """ + Return a number in words. + + group = 1, 2 or 3 to group numbers before turning into words + comma: define comma + + andword: + word for 'and'. Can be set to ''. + e.g. "one hundred and one" vs "one hundred one" + + zero: word for '0' + one: word for '1' + decimal: word for decimal point + threshold: numbers above threshold not turned into words + + parameters not remembered from last call. Departure from Perl version. + """ + self._number_args = {"andword": andword, "zero": zero, "one": one} + num = str(num) + + # Handle "stylistic" conversions (up to a given threshold)... + if threshold is not None and float(num) > threshold: + spnum = num.split(".", 1) + while comma: + (spnum[0], n) = FOUR_DIGIT_COMMA.subn(r"\1,\2", spnum[0]) + if n == 0: + break + try: + return f"{spnum[0]}.{spnum[1]}" + except IndexError: + return str(spnum[0]) + + if group < 0 or group > 3: + raise BadChunkingOptionError + nowhite = num.lstrip() + if nowhite[0] == "+": + sign = "plus" + elif nowhite[0] == "-": + sign = "minus" + else: + sign = "" + + if num in nth_suff: + num = zero + + myord = num[-2:] in nth_suff + if myord: + num = num[:-2] + finalpoint = False + if decimal: + if group != 0: + chunks = num.split(".") + else: + chunks = num.split(".", 1) + if chunks[-1] == "": # remove blank string if nothing after decimal + chunks = chunks[:-1] + finalpoint = True # add 'point' to end of output + else: + chunks = [num] + + first: Union[int, str, bool] = 1 + loopstart = 0 + + if chunks[0] == "": + first = 0 + if len(chunks) > 1: + loopstart = 1 + + for i in range(loopstart, len(chunks)): + chunk = chunks[i] + # remove all non numeric \D + chunk = NON_DIGIT.sub("", chunk) + if chunk == "": + chunk = "0" + + if group == 0 and (first == 0 or first == ""): + chunk = self.enword(chunk, 1) + else: + chunk = self.enword(chunk, group) + + if chunk[-2:] == ", ": + chunk = chunk[:-2] + chunk = WHITESPACES_COMMA.sub(",", chunk) + + if group == 0 and first: + chunk = COMMA_WORD.sub(f" {andword} \\1", chunk) + chunk = WHITESPACES.sub(" ", chunk) + # chunk = re.sub(r"(\A\s|\s\Z)", self.blankfn, chunk) + chunk = chunk.strip() + if first: + first = "" + chunks[i] = chunk + + numchunks = [] + if first != 0: + numchunks = chunks[0].split(f"{comma} ") + + if myord and numchunks: + # TODO: can this be just one re as it is in perl? 
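+ # e.g. number_to_words("21st") should yield "twenty-first": any + # trailing ordinal suffix is re-applied to the last chunk here 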
+ mo = ordinal_suff.search(numchunks[-1]) + if mo: + numchunks[-1] = ordinal_suff.sub(ordinal[mo.group(1)], numchunks[-1]) + else: + numchunks[-1] += "th" + + for chunk in chunks[1:]: + numchunks.append(decimal) + numchunks.extend(chunk.split(f"{comma} ")) + + if finalpoint: + numchunks.append(decimal) + + # wantlist: Perl list context. can explicitly specify in Python + if wantlist: + if sign: + numchunks = [sign] + numchunks + return numchunks + elif group: + signout = f"{sign} " if sign else "" + return f"{signout}{', '.join(numchunks)}" + else: + signout = f"{sign} " if sign else "" + num = f"{signout}{numchunks.pop(0)}" + if decimal is None: + first = True + else: + first = not num.endswith(decimal) + for nc in numchunks: + if nc == decimal: + num += f" {nc}" + first = 0 + elif first: + num += f"{comma} {nc}" + else: + num += f" {nc}" + return num + + # Join words with commas and a trailing 'and' (when appropriate)... + + @validate_arguments + def join( + self, + words: Optional[Sequence[Word]], + sep: Optional[str] = None, + sep_spaced: bool = True, + final_sep: Optional[str] = None, + conj: str = "and", + conj_spaced: bool = True, + ) -> str: + """ + Join words into a list. + + e.g. join(['ant', 'bee', 'fly']) returns 'ant, bee, and fly' + + options: + conj: replacement for 'and' + sep: separator. default ',', unless ',' is in the list then ';' + final_sep: final separator. default ',', unless ',' is in the list then ';' + conj_spaced: boolean. Should conj have spaces around it + + """ + if not words: + return "" + if len(words) == 1: + return words[0] + + if conj_spaced: + if conj == "": + conj = " " + else: + conj = f" {conj} " + + if len(words) == 2: + return f"{words[0]}{conj}{words[1]}" + + if sep is None: + if "," in "".join(words): + sep = ";" + else: + sep = "," + if final_sep is None: + final_sep = sep + + final_sep = f"{final_sep}{conj}" + + if sep_spaced: + sep += " " + + return f"{sep.join(words[0:-1])}{final_sep}{words[-1]}" diff --git a/libs/win/inflect/py.typed b/libs/win/inflect/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/libs/win/jaraco.classes-1.5-py3.6-nspkg.pth b/libs/win/jaraco.classes-1.5-py3.6-nspkg.pth deleted file mode 100644 index 61cb14f9..00000000 --- a/libs/win/jaraco.classes-1.5-py3.6-nspkg.pth +++ /dev/null @@ -1 +0,0 @@ -import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p) diff --git a/libs/win/jaraco.collections-1.6.0-py3.7-nspkg.pth b/libs/win/jaraco.collections-1.6.0-py3.7-nspkg.pth deleted file mode 100644 index 61cb14f9..00000000 --- a/libs/win/jaraco.collections-1.6.0-py3.7-nspkg.pth +++ /dev/null @@ -1 +0,0 @@ -import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', 
types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p) diff --git a/libs/win/jaraco.functools-1.20-py3.6-nspkg.pth b/libs/win/jaraco.functools-1.20-py3.6-nspkg.pth deleted file mode 100644 index 61cb14f9..00000000 --- a/libs/win/jaraco.functools-1.20-py3.6-nspkg.pth +++ /dev/null @@ -1 +0,0 @@ -import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p) diff --git a/libs/win/jaraco.structures-1.1.2-py3.6-nspkg.pth b/libs/win/jaraco.structures-1.1.2-py3.6-nspkg.pth deleted file mode 100644 index 61cb14f9..00000000 --- a/libs/win/jaraco.structures-1.1.2-py3.6-nspkg.pth +++ /dev/null @@ -1 +0,0 @@ -import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p) diff --git a/libs/win/jaraco.text-1.10.1-py3.6-nspkg.pth b/libs/win/jaraco.text-1.10.1-py3.6-nspkg.pth deleted file mode 100644 index 61cb14f9..00000000 --- a/libs/win/jaraco.text-1.10.1-py3.6-nspkg.pth +++ /dev/null @@ -1 +0,0 @@ -import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p) diff --git a/libs/win/jaraco.ui-1.6-py3.6-nspkg.pth b/libs/win/jaraco.ui-1.6-py3.6-nspkg.pth deleted file mode 100644 index 61cb14f9..00000000 --- a/libs/win/jaraco.ui-1.6-py3.6-nspkg.pth +++ /dev/null @@ -1 +0,0 @@ -import sys, types, os;has_mfs = sys.version_info > (3, 5);p = os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p) diff --git a/libs/win/jaraco.windows-3.9.2-py3.7-nspkg.pth b/libs/win/jaraco.windows-3.9.2-py3.7-nspkg.pth deleted file mode 100644 index 61cb14f9..00000000 --- a/libs/win/jaraco.windows-3.9.2-py3.7-nspkg.pth +++ /dev/null @@ -1 +0,0 @@ -import sys, types, os;has_mfs = sys.version_info > (3, 5);p = 
os.path.join(sys._getframe(1).f_locals['sitedir'], *('jaraco',));importlib = has_mfs and __import__('importlib.util');has_mfs and __import__('importlib.machinery');m = has_mfs and sys.modules.setdefault('jaraco', importlib.util.module_from_spec(importlib.machinery.PathFinder.find_spec('jaraco', [os.path.dirname(p)])));m = m or sys.modules.setdefault('jaraco', types.ModuleType('jaraco'));mp = (m or []) and m.__dict__.setdefault('__path__',[]);(p not in mp) and mp.append(p) diff --git a/libs/win/jaraco/classes/ancestry.py b/libs/win/jaraco/classes/ancestry.py index 040ce612..dd9b2e92 100644 --- a/libs/win/jaraco/classes/ancestry.py +++ b/libs/win/jaraco/classes/ancestry.py @@ -3,73 +3,66 @@ Routines for obtaining the class names of an object and its parent classes. """ -from __future__ import unicode_literals +from more_itertools import unique_everseen def all_bases(c): - """ - return a tuple of all base classes the class c has as a parent. - >>> object in all_bases(list) - True - """ - return c.mro()[1:] + """ + return a tuple of all base classes the class c has as a parent. + >>> object in all_bases(list) + True + """ + return c.mro()[1:] def all_classes(c): - """ - return a tuple of all classes to which c belongs - >>> list in all_classes(list) - True - """ - return c.mro() + """ + return a tuple of all classes to which c belongs + >>> list in all_classes(list) + True + """ + return c.mro() + # borrowed from # http://code.activestate.com/recipes/576949-find-all-subclasses-of-a-given-class/ -def iter_subclasses(cls, _seen=None): - """ - Generator over all subclasses of a given class, in depth-first order. +def iter_subclasses(cls): + """ + Generator over all subclasses of a given class, in depth-first order. - >>> bool in list(iter_subclasses(int)) - True - >>> class A(object): pass - >>> class B(A): pass - >>> class C(A): pass - >>> class D(B,C): pass - >>> class E(D): pass - >>> - >>> for cls in iter_subclasses(A): - ... print(cls.__name__) - B - D - E - C - >>> # get ALL (new-style) classes currently defined - >>> res = [cls.__name__ for cls in iter_subclasses(object)] - >>> 'type' in res - True - >>> 'tuple' in res - True - >>> len(res) > 100 - True - """ + >>> bool in list(iter_subclasses(int)) + True + >>> class A(object): pass + >>> class B(A): pass + >>> class C(A): pass + >>> class D(B,C): pass + >>> class E(D): pass + >>> + >>> for cls in iter_subclasses(A): + ... 
print(cls.__name__) + B + D + E + C + >>> # get ALL classes currently defined + >>> res = [cls.__name__ for cls in iter_subclasses(object)] + >>> 'type' in res + True + >>> 'tuple' in res + True + >>> len(res) > 100 + True + """ + return unique_everseen(_iter_all_subclasses(cls)) - if not isinstance(cls, type): - raise TypeError( - 'iter_subclasses must be called with ' - 'new-style classes, not %.100r' % cls - ) - if _seen is None: - _seen = set() - try: - subs = cls.__subclasses__() - except TypeError: # fails only when cls is type - subs = cls.__subclasses__(cls) - for sub in subs: - if sub in _seen: - continue - _seen.add(sub) - yield sub - for sub in iter_subclasses(sub, _seen): - yield sub + +def _iter_all_subclasses(cls): + try: + subs = cls.__subclasses__() + except TypeError: # fails only when cls is type + subs = cls.__subclasses__(cls) + for sub in subs: + yield sub + yield from iter_subclasses(sub) diff --git a/libs/win/jaraco/classes/meta.py b/libs/win/jaraco/classes/meta.py index c26f7dc2..bd41a1d9 100644 --- a/libs/win/jaraco/classes/meta.py +++ b/libs/win/jaraco/classes/meta.py @@ -4,38 +4,63 @@ meta.py Some useful metaclasses. """ -from __future__ import unicode_literals - class LeafClassesMeta(type): - """ - A metaclass for classes that keeps track of all of them that - aren't base classes. - """ + """ + A metaclass for classes that keeps track of all of them that + aren't base classes. - _leaf_classes = set() + >>> Parent = LeafClassesMeta('MyParentClass', (), {}) + >>> Parent in Parent._leaf_classes + True + >>> Child = LeafClassesMeta('MyChildClass', (Parent,), {}) + >>> Child in Parent._leaf_classes + True + >>> Parent in Parent._leaf_classes + False - def __init__(cls, name, bases, attrs): - if not hasattr(cls, '_leaf_classes'): - cls._leaf_classes = set() - leaf_classes = getattr(cls, '_leaf_classes') - leaf_classes.add(cls) - # remove any base classes - leaf_classes -= set(bases) + >>> Other = LeafClassesMeta('OtherClass', (), {}) + >>> Parent in Other._leaf_classes + False + >>> len(Other._leaf_classes) + 1 + """ + + def __init__(cls, name, bases, attrs): + if not hasattr(cls, '_leaf_classes'): + cls._leaf_classes = set() + leaf_classes = getattr(cls, '_leaf_classes') + leaf_classes.add(cls) + # remove any base classes + leaf_classes -= set(bases) class TagRegistered(type): - """ - As classes of this metaclass are created, they keep a registry in the - base class of all classes by a class attribute, indicated by attr_name. - """ - attr_name = 'tag' + """ + As classes of this metaclass are created, they keep a registry in the + base class of all classes by a class attribute, indicated by attr_name. - def __init__(cls, name, bases, namespace): - super(TagRegistered, cls).__init__(name, bases, namespace) - if not hasattr(cls, '_registry'): - cls._registry = {} - meta = cls.__class__ - attr = getattr(cls, meta.attr_name, None) - if attr: - cls._registry[attr] = cls + >>> FooObject = TagRegistered('FooObject', (), dict(tag='foo')) + >>> FooObject._registry['foo'] is FooObject + True + >>> BarObject = TagRegistered('Barobject', (FooObject,), dict(tag='bar')) + >>> FooObject._registry is BarObject._registry + True + >>> len(FooObject._registry) + 2 + + '...' 
below should be 'jaraco.classes' but for pytest-dev/pytest#3396 + >>> FooObject._registry['bar'] + + """ + + attr_name = 'tag' + + def __init__(cls, name, bases, namespace): + super(TagRegistered, cls).__init__(name, bases, namespace) + if not hasattr(cls, '_registry'): + cls._registry = {} + meta = cls.__class__ + attr = getattr(cls, meta.attr_name, None) + if attr: + cls._registry[attr] = cls diff --git a/libs/win/jaraco/classes/properties.py b/libs/win/jaraco/classes/properties.py index 57f9054f..62f9e200 100644 --- a/libs/win/jaraco/classes/properties.py +++ b/libs/win/jaraco/classes/properties.py @@ -1,67 +1,170 @@ -from __future__ import unicode_literals - -import six - -__metaclass__ = type - - class NonDataProperty: - """Much like the property builtin, but only implements __get__, - making it a non-data property, and can be subsequently reset. + """Much like the property builtin, but only implements __get__, + making it a non-data property, and can be subsequently reset. - See http://users.rcn.com/python/download/Descriptor.htm for more - information. + See http://users.rcn.com/python/download/Descriptor.htm for more + information. - >>> class X(object): - ... @NonDataProperty - ... def foo(self): - ... return 3 - >>> x = X() - >>> x.foo - 3 - >>> x.foo = 4 - >>> x.foo - 4 - """ + >>> class X(object): + ... @NonDataProperty + ... def foo(self): + ... return 3 + >>> x = X() + >>> x.foo + 3 + >>> x.foo = 4 + >>> x.foo + 4 - def __init__(self, fget): - assert fget is not None, "fget cannot be none" - assert six.callable(fget), "fget must be callable" - self.fget = fget + '...' below should be 'jaraco.classes' but for pytest-dev/pytest#3396 + >>> X.foo + <....properties.NonDataProperty object at ...> + """ - def __get__(self, obj, objtype=None): - if obj is None: - return self - return self.fget(obj) + def __init__(self, fget): + assert fget is not None, "fget cannot be none" + assert callable(fget), "fget must be callable" + self.fget = fget + + def __get__(self, obj, objtype=None): + if obj is None: + return self + return self.fget(obj) -# from http://stackoverflow.com/a/5191224 -class ClassPropertyDescriptor: - - def __init__(self, fget, fset=None): - self.fget = fget - self.fset = fset - - def __get__(self, obj, klass=None): - if klass is None: - klass = type(obj) - return self.fget.__get__(obj, klass)() - - def __set__(self, obj, value): - if not self.fset: - raise AttributeError("can't set attribute") - type_ = type(obj) - return self.fset.__get__(obj, type_)(value) - - def setter(self, func): - if not isinstance(func, (classmethod, staticmethod)): - func = classmethod(func) - self.fset = func - return self +class classproperty: + """ + Like @property but applies at the class level. -def classproperty(func): - if not isinstance(func, (classmethod, staticmethod)): - func = classmethod(func) + >>> class X(metaclass=classproperty.Meta): + ... val = None + ... @classproperty + ... def foo(cls): + ... return cls.val + ... @foo.setter + ... def foo(cls, val): + ... cls.val = val + >>> X.foo + >>> X.foo = 3 + >>> X.foo + 3 + >>> x = X() + >>> x.foo + 3 + >>> X.foo = 4 + >>> x.foo + 4 - return ClassPropertyDescriptor(func) + Setting the property on an instance affects the class. + + >>> x.foo = 5 + >>> x.foo + 5 + >>> X.foo + 5 + >>> vars(x) + {} + >>> X().foo + 5 + + Attempting to set an attribute where no setter was defined + results in an AttributeError: + + >>> class GetOnly(metaclass=classproperty.Meta): + ... @classproperty + ... def foo(cls): + ... 
return 'bar' + >>> GetOnly.foo = 3 + Traceback (most recent call last): + ... + AttributeError: can't set attribute + + It is also possible to wrap a classmethod or staticmethod in + a classproperty. + + >>> class Static(metaclass=classproperty.Meta): + ... @classproperty + ... @classmethod + ... def foo(cls): + ... return 'foo' + ... @classproperty + ... @staticmethod + ... def bar(): + ... return 'bar' + >>> Static.foo + 'foo' + >>> Static.bar + 'bar' + + *Legacy* + + For compatibility, if the metaclass isn't specified, the + legacy behavior will be invoked. + + >>> class X: + ... val = None + ... @classproperty + ... def foo(cls): + ... return cls.val + ... @foo.setter + ... def foo(cls, val): + ... cls.val = val + >>> X.foo + >>> X.foo = 3 + >>> X.foo + 3 + >>> x = X() + >>> x.foo + 3 + >>> X.foo = 4 + >>> x.foo + 4 + + Note, because the metaclass was not specified, setting + a value on an instance does not have the intended effect. + + >>> x.foo = 5 + >>> x.foo + 5 + >>> X.foo # should be 5 + 4 + >>> vars(x) # should be empty + {'foo': 5} + >>> X().foo # should be 5 + 4 + """ + + class Meta(type): + def __setattr__(self, key, value): + obj = self.__dict__.get(key, None) + if type(obj) is classproperty: + return obj.__set__(self, value) + return super().__setattr__(key, value) + + def __init__(self, fget, fset=None): + self.fget = self._ensure_method(fget) + self.fset = fset + fset and self.setter(fset) + + def __get__(self, instance, owner=None): + return self.fget.__get__(None, owner)() + + def __set__(self, owner, value): + if not self.fset: + raise AttributeError("can't set attribute") + if type(owner) is not classproperty.Meta: + owner = type(owner) + return self.fset.__get__(None, owner)(value) + + def setter(self, fset): + self.fset = self._ensure_method(fset) + return self + + @classmethod + def _ensure_method(cls, fn): + """ + Ensure fn is a classmethod or staticmethod. + """ + needs_method = not isinstance(fn, (classmethod, staticmethod)) + return classmethod(fn) if needs_method else fn diff --git a/libs/win/jaraco/collections.py b/libs/win/jaraco/collections.py index bb463deb..db89b122 100644 --- a/libs/win/jaraco/collections.py +++ b/libs/win/jaraco/collections.py @@ -1,906 +1,1090 @@ -# -*- coding: utf-8 -*- - -from __future__ import absolute_import, unicode_literals, division - import re import operator -import collections +import collections.abc import itertools import copy import functools +import random -try: - import collections.abc -except ImportError: - # Python 2.7 - collections.abc = collections - -import six from jaraco.classes.properties import NonDataProperty import jaraco.text class Projection(collections.abc.Mapping): - """ - Project a set of keys over a mapping + """ + Project a set of keys over a mapping - >>> sample = {'a': 1, 'b': 2, 'c': 3} - >>> prj = Projection(['a', 'c', 'd'], sample) - >>> prj == {'a': 1, 'c': 3} - True + >>> sample = {'a': 1, 'b': 2, 'c': 3} + >>> prj = Projection(['a', 'c', 'd'], sample) + >>> prj == {'a': 1, 'c': 3} + True - Keys should only appear if they were specified and exist in the space. + Keys should only appear if they were specified and exist in the space. - >>> sorted(list(prj.keys())) - ['a', 'c'] + >>> sorted(list(prj.keys())) + ['a', 'c'] - Use the projection to update another dict. + Attempting to access a key not in the projection + results in a KeyError. - >>> target = {'a': 2, 'b': 2} - >>> target.update(prj) - >>> target == {'a': 1, 'b': 2, 'c': 3} - True + >>> prj['b'] + Traceback (most recent call last): + ... 
+ KeyError: 'b' - Also note that Projection keeps a reference to the original dict, so - if you modify the original dict, that could modify the Projection. + Use the projection to update another dict. - >>> del sample['a'] - >>> dict(prj) - {'c': 3} - """ - def __init__(self, keys, space): - self._keys = tuple(keys) - self._space = space + >>> target = {'a': 2, 'b': 2} + >>> target.update(prj) + >>> target == {'a': 1, 'b': 2, 'c': 3} + True - def __getitem__(self, key): - if key not in self._keys: - raise KeyError(key) - return self._space[key] + Also note that Projection keeps a reference to the original dict, so + if you modify the original dict, that could modify the Projection. - def __iter__(self): - return iter(set(self._keys).intersection(self._space)) + >>> del sample['a'] + >>> dict(prj) + {'c': 3} + """ - def __len__(self): - return len(tuple(iter(self))) + def __init__(self, keys, space): + self._keys = tuple(keys) + self._space = space + + def __getitem__(self, key): + if key not in self._keys: + raise KeyError(key) + return self._space[key] + + def __iter__(self): + return iter(set(self._keys).intersection(self._space)) + + def __len__(self): + return len(tuple(iter(self))) -class DictFilter(object): - """ - Takes a dict, and simulates a sub-dict based on the keys. +class DictFilter(collections.abc.Mapping): + """ + Takes a dict, and simulates a sub-dict based on the keys. - >>> sample = {'a': 1, 'b': 2, 'c': 3} - >>> filtered = DictFilter(sample, ['a', 'c']) - >>> filtered == {'a': 1, 'c': 3} - True + >>> sample = {'a': 1, 'b': 2, 'c': 3} + >>> filtered = DictFilter(sample, ['a', 'c']) + >>> filtered == {'a': 1, 'c': 3} + True + >>> set(filtered.values()) == {1, 3} + True + >>> set(filtered.items()) == {('a', 1), ('c', 3)} + True - One can also filter by a regular expression pattern + One can also filter by a regular expression pattern - >>> sample['d'] = 4 - >>> sample['ef'] = 5 + >>> sample['d'] = 4 + >>> sample['ef'] = 5 - Here we filter for only single-character keys + Here we filter for only single-character keys - >>> filtered = DictFilter(sample, include_pattern='.$') - >>> filtered == {'a': 1, 'b': 2, 'c': 3, 'd': 4} - True + >>> filtered = DictFilter(sample, include_pattern='.$') + >>> filtered == {'a': 1, 'b': 2, 'c': 3, 'd': 4} + True - Also note that DictFilter keeps a reference to the original dict, so - if you modify the original dict, that could modify the filtered dict. + >>> filtered['e'] + Traceback (most recent call last): + ... + KeyError: 'e' - >>> del sample['d'] - >>> del sample['a'] - >>> filtered == {'b': 2, 'c': 3} - True + >>> 'e' in filtered + False - """ - def __init__(self, dict, include_keys=[], include_pattern=None): - self.dict = dict - self.specified_keys = set(include_keys) - if include_pattern is not None: - self.include_pattern = re.compile(include_pattern) - else: - # for performance, replace the pattern_keys property - self.pattern_keys = set() + Pattern is useful for excluding keys with a prefix. - def get_pattern_keys(self): - keys = filter(self.include_pattern.match, self.dict.keys()) - return set(keys) - pattern_keys = NonDataProperty(get_pattern_keys) + >>> filtered = DictFilter(sample, include_pattern=r'(?![ace])') + >>> dict(filtered) + {'b': 2, 'd': 4} - @property - def include_keys(self): - return self.specified_keys.union(self.pattern_keys) + Also note that DictFilter keeps a reference to the original dict, so + if you modify the original dict, that could modify the filtered dict. 
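
Explicit keys and a pattern can also be combined in a single filter; a minimal sketch (the header names are illustrative, assuming the module is importable as jaraco.collections):

    from jaraco.collections import DictFilter

    headers = {'Content-Type': 'text/html', 'X-Trace': 'abc', 'X-Span': 'def'}
    # keys matching either the explicit list or the pattern are included
    tracing = DictFilter(headers, include_keys=['Content-Type'], include_pattern='X-')
    assert dict(tracing) == headers
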
- def keys(self): - return self.include_keys.intersection(self.dict.keys()) + >>> del sample['d'] + >>> dict(filtered) + {'b': 2} + """ - def values(self): - keys = self.keys() - values = map(self.dict.get, keys) - return values + def __init__(self, dict, include_keys=[], include_pattern=None): + self.dict = dict + self.specified_keys = set(include_keys) + if include_pattern is not None: + self.include_pattern = re.compile(include_pattern) + else: + # for performance, replace the pattern_keys property + self.pattern_keys = set() - def __getitem__(self, i): - if i not in self.include_keys: - return KeyError, i - return self.dict[i] + def get_pattern_keys(self): + keys = filter(self.include_pattern.match, self.dict.keys()) + return set(keys) - def items(self): - keys = self.keys() - values = map(self.dict.get, keys) - return zip(keys, values) + pattern_keys = NonDataProperty(get_pattern_keys) - def __eq__(self, other): - return dict(self) == other + @property + def include_keys(self): + return self.specified_keys | self.pattern_keys - def __ne__(self, other): - return dict(self) != other + def __getitem__(self, i): + if i not in self.include_keys: + raise KeyError(i) + return self.dict[i] + + def __iter__(self): + return filter(self.include_keys.__contains__, self.dict.keys()) + + def __len__(self): + return len(list(self)) def dict_map(function, dictionary): - """ - dict_map is much like the built-in function map. It takes a dictionary - and applys a function to the values of that dictionary, returning a - new dictionary with the mapped values in the original keys. + """ + dict_map is much like the built-in function map. It takes a dictionary + and applys a function to the values of that dictionary, returning a + new dictionary with the mapped values in the original keys. - >>> d = dict_map(lambda x:x+1, dict(a=1, b=2)) - >>> d == dict(a=2,b=3) - True - """ - return dict((key, function(value)) for key, value in dictionary.items()) + >>> d = dict_map(lambda x:x+1, dict(a=1, b=2)) + >>> d == dict(a=2,b=3) + True + """ + return dict((key, function(value)) for key, value in dictionary.items()) class RangeMap(dict): - """ - A dictionary-like object that uses the keys as bounds for a range. - Inclusion of the value for that range is determined by the - key_match_comparator, which defaults to less-than-or-equal. - A value is returned for a key if it is the first key that matches in - the sorted list of keys. + """ + A dictionary-like object that uses the keys as bounds for a range. + Inclusion of the value for that range is determined by the + key_match_comparator, which defaults to less-than-or-equal. + A value is returned for a key if it is the first key that matches in + the sorted list of keys. - One may supply keyword parameters to be passed to the sort function used - to sort keys (i.e. cmp [python 2 only], keys, reverse) as sort_params. + One may supply keyword parameters to be passed to the sort function used + to sort keys (i.e. key, reverse) as sort_params. - Let's create a map that maps 1-3 -> 'a', 4-6 -> 'b' + Let's create a map that maps 1-3 -> 'a', 4-6 -> 'b' - >>> r = RangeMap({3: 'a', 6: 'b'}) # boy, that was easy - >>> r[1], r[2], r[3], r[4], r[5], r[6] - ('a', 'a', 'a', 'b', 'b', 'b') + >>> r = RangeMap({3: 'a', 6: 'b'}) # boy, that was easy + >>> r[1], r[2], r[3], r[4], r[5], r[6] + ('a', 'a', 'a', 'b', 'b', 'b') - Even float values should work so long as the comparison operator - supports it. + Even float values should work so long as the comparison operator + supports it. 
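
Because the default comparator is less-than-or-equal, each key acts as an inclusive upper bound; a hedged sketch using grade bands (the cut-offs are illustrative):

    from jaraco.collections import RangeMap

    grades = RangeMap({59: 'F', 69: 'D', 79: 'C', 89: 'B', 100: 'A'})
    assert grades[42] == 'F'        # first sorted key with 42 <= key is 59
    assert grades[80] == 'B'
    assert grades.bounds() == (59, 100)
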
- >>> r[4.5] - 'b' + >>> r[4.5] + 'b' - But you'll notice that the way rangemap is defined, it must be open-ended - on one side. + But you'll notice that the way rangemap is defined, it must be open-ended + on one side. - >>> r[0] - 'a' - >>> r[-1] - 'a' + >>> r[0] + 'a' + >>> r[-1] + 'a' - One can close the open-end of the RangeMap by using undefined_value + One can close the open-end of the RangeMap by using undefined_value - >>> r = RangeMap({0: RangeMap.undefined_value, 3: 'a', 6: 'b'}) - >>> r[0] - Traceback (most recent call last): - ... - KeyError: 0 + >>> r = RangeMap({0: RangeMap.undefined_value, 3: 'a', 6: 'b'}) + >>> r[0] + Traceback (most recent call last): + ... + KeyError: 0 - One can get the first or last elements in the range by using RangeMap.Item + One can get the first or last elements in the range by using RangeMap.Item - >>> last_item = RangeMap.Item(-1) - >>> r[last_item] - 'b' + >>> last_item = RangeMap.Item(-1) + >>> r[last_item] + 'b' - .last_item is a shortcut for Item(-1) + .last_item is a shortcut for Item(-1) - >>> r[RangeMap.last_item] - 'b' + >>> r[RangeMap.last_item] + 'b' - Sometimes it's useful to find the bounds for a RangeMap + Sometimes it's useful to find the bounds for a RangeMap - >>> r.bounds() - (0, 6) + >>> r.bounds() + (0, 6) - RangeMap supports .get(key, default) + RangeMap supports .get(key, default) - >>> r.get(0, 'not found') - 'not found' + >>> r.get(0, 'not found') + 'not found' - >>> r.get(7, 'not found') - 'not found' - """ - def __init__(self, source, sort_params={}, key_match_comparator=operator.le): - dict.__init__(self, source) - self.sort_params = sort_params - self.match = key_match_comparator + >>> r.get(7, 'not found') + 'not found' - def __getitem__(self, item): - sorted_keys = sorted(self.keys(), **self.sort_params) - if isinstance(item, RangeMap.Item): - result = self.__getitem__(sorted_keys[item]) - else: - key = self._find_first_match_(sorted_keys, item) - result = dict.__getitem__(self, key) - if result is RangeMap.undefined_value: - raise KeyError(key) - return result + One often wishes to define the ranges by their left-most values, + which requires use of sort params and a key_match_comparator. - def get(self, key, default=None): - """ - Return the value for key if key is in the dictionary, else default. - If default is not given, it defaults to None, so that this method - never raises a KeyError. - """ - try: - return self[key] - except KeyError: - return default + >>> r = RangeMap({1: 'a', 4: 'b'}, + ... sort_params=dict(reverse=True), + ... 
key_match_comparator=operator.ge) + >>> r[1], r[2], r[3], r[4], r[5], r[6] + ('a', 'a', 'a', 'b', 'b', 'b') - def _find_first_match_(self, keys, item): - is_match = functools.partial(self.match, item) - matches = list(filter(is_match, keys)) - if matches: - return matches[0] - raise KeyError(item) + That wasn't nearly as easy as before, so an alternate constructor + is provided: - def bounds(self): - sorted_keys = sorted(self.keys(), **self.sort_params) - return ( - sorted_keys[RangeMap.first_item], - sorted_keys[RangeMap.last_item], - ) + >>> r = RangeMap.left({1: 'a', 4: 'b', 7: RangeMap.undefined_value}) + >>> r[1], r[2], r[3], r[4], r[5], r[6] + ('a', 'a', 'a', 'b', 'b', 'b') - # some special values for the RangeMap - undefined_value = type(str('RangeValueUndefined'), (object,), {})() + """ - class Item(int): - "RangeMap Item" - first_item = Item(0) - last_item = Item(-1) + def __init__(self, source, sort_params={}, key_match_comparator=operator.le): + dict.__init__(self, source) + self.sort_params = sort_params + self.match = key_match_comparator + + @classmethod + def left(cls, source): + return cls( + source, sort_params=dict(reverse=True), key_match_comparator=operator.ge + ) + + def __getitem__(self, item): + sorted_keys = sorted(self.keys(), **self.sort_params) + if isinstance(item, RangeMap.Item): + result = self.__getitem__(sorted_keys[item]) + else: + key = self._find_first_match_(sorted_keys, item) + result = dict.__getitem__(self, key) + if result is RangeMap.undefined_value: + raise KeyError(key) + return result + + def get(self, key, default=None): + """ + Return the value for key if key is in the dictionary, else default. + If default is not given, it defaults to None, so that this method + never raises a KeyError. + """ + try: + return self[key] + except KeyError: + return default + + def _find_first_match_(self, keys, item): + is_match = functools.partial(self.match, item) + matches = list(filter(is_match, keys)) + if matches: + return matches[0] + raise KeyError(item) + + def bounds(self): + sorted_keys = sorted(self.keys(), **self.sort_params) + return (sorted_keys[RangeMap.first_item], sorted_keys[RangeMap.last_item]) + + # some special values for the RangeMap + undefined_value = type(str('RangeValueUndefined'), (), {})() + + class Item(int): + "RangeMap Item" + + first_item = Item(0) + last_item = Item(-1) def __identity(x): - return x + return x def sorted_items(d, key=__identity, reverse=False): - """ - Return the items of the dictionary sorted by the keys + """ + Return the items of the dictionary sorted by the keys - >>> sample = dict(foo=20, bar=42, baz=10) - >>> tuple(sorted_items(sample)) - (('bar', 42), ('baz', 10), ('foo', 20)) + >>> sample = dict(foo=20, bar=42, baz=10) + >>> tuple(sorted_items(sample)) + (('bar', 42), ('baz', 10), ('foo', 20)) - >>> reverse_string = lambda s: ''.join(reversed(s)) - >>> tuple(sorted_items(sample, key=reverse_string)) - (('foo', 20), ('bar', 42), ('baz', 10)) + >>> reverse_string = lambda s: ''.join(reversed(s)) + >>> tuple(sorted_items(sample, key=reverse_string)) + (('foo', 20), ('bar', 42), ('baz', 10)) - >>> tuple(sorted_items(sample, reverse=True)) - (('foo', 20), ('baz', 10), ('bar', 42)) - """ - # wrap the key func so it operates on the first element of each item - def pairkey_key(item): - return key(item[0]) - return sorted(d.items(), key=pairkey_key, reverse=reverse) + >>> tuple(sorted_items(sample, reverse=True)) + (('foo', 20), ('baz', 10), ('bar', 42)) + """ + # wrap the key func so it operates on the first 
element of each item + def pairkey_key(item): + return key(item[0]) + + return sorted(d.items(), key=pairkey_key, reverse=reverse) class KeyTransformingDict(dict): - """ - A dict subclass that transforms the keys before they're used. - Subclasses may override the default transform_key to customize behavior. - """ - @staticmethod - def transform_key(key): - return key + """ + A dict subclass that transforms the keys before they're used. + Subclasses may override the default transform_key to customize behavior. + """ - def __init__(self, *args, **kargs): - super(KeyTransformingDict, self).__init__() - # build a dictionary using the default constructs - d = dict(*args, **kargs) - # build this dictionary using transformed keys. - for item in d.items(): - self.__setitem__(*item) + @staticmethod + def transform_key(key): # pragma: nocover + return key - def __setitem__(self, key, val): - key = self.transform_key(key) - super(KeyTransformingDict, self).__setitem__(key, val) + def __init__(self, *args, **kargs): + super(KeyTransformingDict, self).__init__() + # build a dictionary using the default constructs + d = dict(*args, **kargs) + # build this dictionary using transformed keys. + for item in d.items(): + self.__setitem__(*item) - def __getitem__(self, key): - key = self.transform_key(key) - return super(KeyTransformingDict, self).__getitem__(key) + def __setitem__(self, key, val): + key = self.transform_key(key) + super(KeyTransformingDict, self).__setitem__(key, val) - def __contains__(self, key): - key = self.transform_key(key) - return super(KeyTransformingDict, self).__contains__(key) + def __getitem__(self, key): + key = self.transform_key(key) + return super(KeyTransformingDict, self).__getitem__(key) - def __delitem__(self, key): - key = self.transform_key(key) - return super(KeyTransformingDict, self).__delitem__(key) + def __contains__(self, key): + key = self.transform_key(key) + return super(KeyTransformingDict, self).__contains__(key) - def get(self, key, *args, **kwargs): - key = self.transform_key(key) - return super(KeyTransformingDict, self).get(key, *args, **kwargs) + def __delitem__(self, key): + key = self.transform_key(key) + return super(KeyTransformingDict, self).__delitem__(key) - def setdefault(self, key, *args, **kwargs): - key = self.transform_key(key) - return super(KeyTransformingDict, self).setdefault(key, *args, **kwargs) + def get(self, key, *args, **kwargs): + key = self.transform_key(key) + return super(KeyTransformingDict, self).get(key, *args, **kwargs) - def pop(self, key, *args, **kwargs): - key = self.transform_key(key) - return super(KeyTransformingDict, self).pop(key, *args, **kwargs) + def setdefault(self, key, *args, **kwargs): + key = self.transform_key(key) + return super(KeyTransformingDict, self).setdefault(key, *args, **kwargs) - def matching_key_for(self, key): - """ - Given a key, return the actual key stored in self that matches. - Raise KeyError if the key isn't found. - """ - try: - return next(e_key for e_key in self.keys() if e_key == key) - except StopIteration: - raise KeyError(key) + def pop(self, key, *args, **kwargs): + key = self.transform_key(key) + return super(KeyTransformingDict, self).pop(key, *args, **kwargs) + + def matching_key_for(self, key): + """ + Given a key, return the actual key stored in self that matches. + Raise KeyError if the key isn't found. 
+ """ + try: + return next(e_key for e_key in self.keys() if e_key == key) + except StopIteration: + raise KeyError(key) class FoldedCaseKeyedDict(KeyTransformingDict): - """ - A case-insensitive dictionary (keys are compared as insensitive - if they are strings). + """ + A case-insensitive dictionary (keys are compared as insensitive + if they are strings). - >>> d = FoldedCaseKeyedDict() - >>> d['heLlo'] = 'world' - >>> list(d.keys()) == ['heLlo'] - True - >>> list(d.values()) == ['world'] - True - >>> d['hello'] == 'world' - True - >>> 'hello' in d - True - >>> 'HELLO' in d - True - >>> print(repr(FoldedCaseKeyedDict({'heLlo': 'world'})).replace("u'", "'")) - {'heLlo': 'world'} - >>> d = FoldedCaseKeyedDict({'heLlo': 'world'}) - >>> print(d['hello']) - world - >>> print(d['Hello']) - world - >>> list(d.keys()) - ['heLlo'] - >>> d = FoldedCaseKeyedDict({'heLlo': 'world', 'Hello': 'world'}) - >>> list(d.values()) - ['world'] - >>> key, = d.keys() - >>> key in ['heLlo', 'Hello'] - True - >>> del d['HELLO'] - >>> d - {} + >>> d = FoldedCaseKeyedDict() + >>> d['heLlo'] = 'world' + >>> list(d.keys()) == ['heLlo'] + True + >>> list(d.values()) == ['world'] + True + >>> d['hello'] == 'world' + True + >>> 'hello' in d + True + >>> 'HELLO' in d + True + >>> print(repr(FoldedCaseKeyedDict({'heLlo': 'world'}))) + {'heLlo': 'world'} + >>> d = FoldedCaseKeyedDict({'heLlo': 'world'}) + >>> print(d['hello']) + world + >>> print(d['Hello']) + world + >>> list(d.keys()) + ['heLlo'] + >>> d = FoldedCaseKeyedDict({'heLlo': 'world', 'Hello': 'world'}) + >>> list(d.values()) + ['world'] + >>> key, = d.keys() + >>> key in ['heLlo', 'Hello'] + True + >>> del d['HELLO'] + >>> d + {} - get should work + get should work - >>> d['Sumthin'] = 'else' - >>> d.get('SUMTHIN') - 'else' - >>> d.get('OTHER', 'thing') - 'thing' - >>> del d['sumthin'] + >>> d['Sumthin'] = 'else' + >>> d.get('SUMTHIN') + 'else' + >>> d.get('OTHER', 'thing') + 'thing' + >>> del d['sumthin'] - setdefault should also work + setdefault should also work - >>> d['This'] = 'that' - >>> print(d.setdefault('this', 'other')) - that - >>> len(d) - 1 - >>> print(d['this']) - that - >>> print(d.setdefault('That', 'other')) - other - >>> print(d['THAT']) - other + >>> d['This'] = 'that' + >>> print(d.setdefault('this', 'other')) + that + >>> len(d) + 1 + >>> print(d['this']) + that + >>> print(d.setdefault('That', 'other')) + other + >>> print(d['THAT']) + other - Make it pop! + Make it pop! - >>> print(d.pop('THAT')) - other + >>> print(d.pop('THAT')) + other - To retrieve the key in its originally-supplied form, use matching_key_for + To retrieve the key in its originally-supplied form, use matching_key_for - >>> print(d.matching_key_for('this')) - This - """ - @staticmethod - def transform_key(key): - return jaraco.text.FoldedCase(key) + >>> print(d.matching_key_for('this')) + This + + >>> d.matching_key_for('missing') + Traceback (most recent call last): + ... + KeyError: 'missing' + """ + + @staticmethod + def transform_key(key): + return jaraco.text.FoldedCase(key) -class DictAdapter(object): - """ - Provide a getitem interface for attributes of an object. +class DictAdapter: + """ + Provide a getitem interface for attributes of an object. - Let's say you want to get at the string.lowercase property in a formatted - string. It's easy with DictAdapter. + Let's say you want to get at the string.lowercase property in a formatted + string. It's easy with DictAdapter. 
- >>> import string - >>> print("lowercase is %(ascii_lowercase)s" % DictAdapter(string)) - lowercase is abcdefghijklmnopqrstuvwxyz - """ - def __init__(self, wrapped_ob): - self.object = wrapped_ob + >>> import string + >>> print("lowercase is %(ascii_lowercase)s" % DictAdapter(string)) + lowercase is abcdefghijklmnopqrstuvwxyz + """ - def __getitem__(self, name): - return getattr(self.object, name) + def __init__(self, wrapped_ob): + self.object = wrapped_ob + + def __getitem__(self, name): + return getattr(self.object, name) -class ItemsAsAttributes(object): - """ - Mix-in class to enable a mapping object to provide items as - attributes. +class ItemsAsAttributes: + """ + Mix-in class to enable a mapping object to provide items as + attributes. - >>> C = type(str('C'), (dict, ItemsAsAttributes), dict()) - >>> i = C() - >>> i['foo'] = 'bar' - >>> i.foo - 'bar' + >>> C = type(str('C'), (dict, ItemsAsAttributes), dict()) + >>> i = C() + >>> i['foo'] = 'bar' + >>> i.foo + 'bar' - Natural attribute access takes precedence + Natural attribute access takes precedence - >>> i.foo = 'henry' - >>> i.foo - 'henry' + >>> i.foo = 'henry' + >>> i.foo + 'henry' - But as you might expect, the mapping functionality is preserved. + But as you might expect, the mapping functionality is preserved. - >>> i['foo'] - 'bar' + >>> i['foo'] + 'bar' - A normal attribute error should be raised if an attribute is - requested that doesn't exist. + A normal attribute error should be raised if an attribute is + requested that doesn't exist. - >>> i.missing - Traceback (most recent call last): - ... - AttributeError: 'C' object has no attribute 'missing' + >>> i.missing + Traceback (most recent call last): + ... + AttributeError: 'C' object has no attribute 'missing' - It also works on dicts that customize __getitem__ + It also works on dicts that customize __getitem__ - >>> missing_func = lambda self, key: 'missing item' - >>> C = type( - ... str('C'), - ... (dict, ItemsAsAttributes), - ... dict(__missing__ = missing_func), - ... ) - >>> i = C() - >>> i.missing - 'missing item' - >>> i.foo - 'missing item' - """ - def __getattr__(self, key): - try: - return getattr(super(ItemsAsAttributes, self), key) - except AttributeError as e: - # attempt to get the value from the mapping (return self[key]) - # but be careful not to lose the original exception context. - noval = object() + >>> missing_func = lambda self, key: 'missing item' + >>> C = type( + ... str('C'), + ... (dict, ItemsAsAttributes), + ... dict(__missing__ = missing_func), + ... ) + >>> i = C() + >>> i.missing + 'missing item' + >>> i.foo + 'missing item' + """ - def _safe_getitem(cont, key, missing_result): - try: - return cont[key] - except KeyError: - return missing_result - result = _safe_getitem(self, key, noval) - if result is not noval: - return result - # raise the original exception, but use the original class - # name, not 'super'. - message, = e.args - message = message.replace('super', self.__class__.__name__, 1) - e.args = message, - raise + def __getattr__(self, key): + try: + return getattr(super(ItemsAsAttributes, self), key) + except AttributeError as e: + # attempt to get the value from the mapping (return self[key]) + # but be careful not to lose the original exception context. 
+ noval = object() + + def _safe_getitem(cont, key, missing_result): + try: + return cont[key] + except KeyError: + return missing_result + + result = _safe_getitem(self, key, noval) + if result is not noval: + return result + # raise the original exception, but use the original class + # name, not 'super'. + (message,) = e.args + message = message.replace('super', self.__class__.__name__, 1) + e.args = (message,) + raise def invert_map(map): - """ - Given a dictionary, return another dictionary with keys and values - switched. If any of the values resolve to the same key, raises - a ValueError. + """ + Given a dictionary, return another dictionary with keys and values + switched. If any of the values resolve to the same key, raises + a ValueError. - >>> numbers = dict(a=1, b=2, c=3) - >>> letters = invert_map(numbers) - >>> letters[1] - 'a' - >>> numbers['d'] = 3 - >>> invert_map(numbers) - Traceback (most recent call last): - ... - ValueError: Key conflict in inverted mapping - """ - res = dict((v, k) for k, v in map.items()) - if not len(res) == len(map): - raise ValueError('Key conflict in inverted mapping') - return res + >>> numbers = dict(a=1, b=2, c=3) + >>> letters = invert_map(numbers) + >>> letters[1] + 'a' + >>> numbers['d'] = 3 + >>> invert_map(numbers) + Traceback (most recent call last): + ... + ValueError: Key conflict in inverted mapping + """ + res = dict((v, k) for k, v in map.items()) + if not len(res) == len(map): + raise ValueError('Key conflict in inverted mapping') + return res class IdentityOverrideMap(dict): - """ - A dictionary that by default maps each key to itself, but otherwise - acts like a normal dictionary. + """ + A dictionary that by default maps each key to itself, but otherwise + acts like a normal dictionary. - >>> d = IdentityOverrideMap() - >>> d[42] - 42 - >>> d['speed'] = 'speedo' - >>> print(d['speed']) - speedo - """ + >>> d = IdentityOverrideMap() + >>> d[42] + 42 + >>> d['speed'] = 'speedo' + >>> print(d['speed']) + speedo + """ - def __missing__(self, key): - return key + def __missing__(self, key): + return key -class DictStack(list, collections.abc.Mapping): - """ - A stack of dictionaries that behaves as a view on those dictionaries, - giving preference to the last. +class DictStack(list, collections.abc.MutableMapping): + """ + A stack of dictionaries that behaves as a view on those dictionaries, + giving preference to the last. 
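
The last-wins lookup makes DictStack a natural fit for layered configuration; a minimal sketch (the setting names are illustrative):

    from jaraco.collections import DictStack

    defaults = dict(color='blue', size=10)
    user_prefs = dict(size=14)
    settings = DictStack([defaults, user_prefs])
    assert settings['color'] == 'blue'   # falls through to the defaults layer
    assert settings['size'] == 14        # the later layer wins
    settings.push(dict(size=18))         # a transient override
    assert settings['size'] == 18
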
- >>> stack = DictStack([dict(a=1, c=2), dict(b=2, a=2)]) - >>> stack['a'] - 2 - >>> stack['b'] - 2 - >>> stack['c'] - 2 - >>> stack.push(dict(a=3)) - >>> stack['a'] - 3 - >>> set(stack.keys()) == set(['a', 'b', 'c']) - True - >>> d = stack.pop() - >>> stack['a'] - 2 - >>> d = stack.pop() - >>> stack['a'] - 1 - >>> stack.get('b', None) - """ + >>> stack = DictStack([dict(a=1, c=2), dict(b=2, a=2)]) + >>> stack['a'] + 2 + >>> stack['b'] + 2 + >>> stack['c'] + 2 + >>> len(stack) + 3 + >>> stack.push(dict(a=3)) + >>> stack['a'] + 3 + >>> stack['a'] = 4 + >>> set(stack.keys()) == set(['a', 'b', 'c']) + True + >>> set(stack.items()) == set([('a', 4), ('b', 2), ('c', 2)]) + True + >>> dict(**stack) == dict(stack) == dict(a=4, c=2, b=2) + True + >>> d = stack.pop() + >>> stack['a'] + 2 + >>> d = stack.pop() + >>> stack['a'] + 1 + >>> stack.get('b', None) + >>> 'c' in stack + True + >>> del stack['c'] + >>> dict(stack) + {'a': 1} + """ - def keys(self): - return list(set(itertools.chain.from_iterable(c.keys() for c in self))) + def __iter__(self): + dicts = list.__iter__(self) + return iter(set(itertools.chain.from_iterable(c.keys() for c in dicts))) - def __getitem__(self, key): - for scope in reversed(self): - if key in scope: - return scope[key] - raise KeyError(key) + def __getitem__(self, key): + for scope in reversed(tuple(list.__iter__(self))): + if key in scope: + return scope[key] + raise KeyError(key) - push = list.append + push = list.append + + def __contains__(self, other): + return collections.abc.Mapping.__contains__(self, other) + + def __len__(self): + return len(list(iter(self))) + + def __setitem__(self, key, item): + last = list.__getitem__(self, -1) + return last.__setitem__(key, item) + + def __delitem__(self, key): + last = list.__getitem__(self, -1) + return last.__delitem__(key) + + # workaround for mypy confusion + def pop(self, *args, **kwargs): + return list.pop(self, *args, **kwargs) class BijectiveMap(dict): - """ - A Bijective Map (two-way mapping). + """ + A Bijective Map (two-way mapping). - Implemented as a simple dictionary of 2x the size, mapping values back - to keys. + Implemented as a simple dictionary of 2x the size, mapping values back + to keys. - Note, this implementation may be incomplete. If there's not a test for - your use case below, it's likely to fail, so please test and send pull - requests or patches for additional functionality needed. + Note, this implementation may be incomplete. If there's not a test for + your use case below, it's likely to fail, so please test and send pull + requests or patches for additional functionality needed. - >>> m = BijectiveMap() - >>> m['a'] = 'b' - >>> m == {'a': 'b', 'b': 'a'} - True - >>> print(m['b']) - a + >>> m = BijectiveMap() + >>> m['a'] = 'b' + >>> m == {'a': 'b', 'b': 'a'} + True + >>> print(m['b']) + a - >>> m['c'] = 'd' - >>> len(m) - 2 + >>> m['c'] = 'd' + >>> len(m) + 2 - Some weird things happen if you map an item to itself or overwrite a - single key of a pair, so it's disallowed. + Some weird things happen if you map an item to itself or overwrite a + single key of a pair, so it's disallowed. 
- >>> m['e'] = 'e' - Traceback (most recent call last): - ValueError: Key cannot map to itself + >>> m['e'] = 'e' + Traceback (most recent call last): + ValueError: Key cannot map to itself - >>> m['d'] = 'e' - Traceback (most recent call last): - ValueError: Key/Value pairs may not overlap + >>> m['d'] = 'e' + Traceback (most recent call last): + ValueError: Key/Value pairs may not overlap - >>> m['e'] = 'd' - Traceback (most recent call last): - ValueError: Key/Value pairs may not overlap + >>> m['e'] = 'd' + Traceback (most recent call last): + ValueError: Key/Value pairs may not overlap - >>> print(m.pop('d')) - c + >>> print(m.pop('d')) + c - >>> 'c' in m - False + >>> 'c' in m + False - >>> m = BijectiveMap(dict(a='b')) - >>> len(m) - 1 - >>> print(m['b']) - a + >>> m = BijectiveMap(dict(a='b')) + >>> len(m) + 1 + >>> print(m['b']) + a - >>> m = BijectiveMap() - >>> m.update(a='b') - >>> m['b'] - 'a' + >>> m = BijectiveMap() + >>> m.update(a='b') + >>> m['b'] + 'a' - >>> del m['b'] - >>> len(m) - 0 - >>> 'a' in m - False - """ - def __init__(self, *args, **kwargs): - super(BijectiveMap, self).__init__() - self.update(*args, **kwargs) + >>> del m['b'] + >>> len(m) + 0 + >>> 'a' in m + False + """ - def __setitem__(self, item, value): - if item == value: - raise ValueError("Key cannot map to itself") - overlap = ( - item in self and self[item] != value - or - value in self and self[value] != item - ) - if overlap: - raise ValueError("Key/Value pairs may not overlap") - super(BijectiveMap, self).__setitem__(item, value) - super(BijectiveMap, self).__setitem__(value, item) + def __init__(self, *args, **kwargs): + super(BijectiveMap, self).__init__() + self.update(*args, **kwargs) - def __delitem__(self, item): - self.pop(item) + def __setitem__(self, item, value): + if item == value: + raise ValueError("Key cannot map to itself") + overlap = ( + item in self + and self[item] != value + or value in self + and self[value] != item + ) + if overlap: + raise ValueError("Key/Value pairs may not overlap") + super(BijectiveMap, self).__setitem__(item, value) + super(BijectiveMap, self).__setitem__(value, item) - def __len__(self): - return super(BijectiveMap, self).__len__() // 2 + def __delitem__(self, item): + self.pop(item) - def pop(self, key, *args, **kwargs): - mirror = self[key] - super(BijectiveMap, self).__delitem__(mirror) - return super(BijectiveMap, self).pop(key, *args, **kwargs) + def __len__(self): + return super(BijectiveMap, self).__len__() // 2 - def update(self, *args, **kwargs): - # build a dictionary using the default constructs - d = dict(*args, **kwargs) - # build this dictionary using transformed keys. - for item in d.items(): - self.__setitem__(*item) + def pop(self, key, *args, **kwargs): + mirror = self[key] + super(BijectiveMap, self).__delitem__(mirror) + return super(BijectiveMap, self).pop(key, *args, **kwargs) + + def update(self, *args, **kwargs): + # build a dictionary using the default constructs + d = dict(*args, **kwargs) + # build this dictionary using transformed keys. + for item in d.items(): + self.__setitem__(*item) class FrozenDict(collections.abc.Mapping, collections.abc.Hashable): - """ - An immutable mapping. + """ + An immutable mapping. 
- >>> a = FrozenDict(a=1, b=2) - >>> b = FrozenDict(a=1, b=2) - >>> a == b - True + >>> a = FrozenDict(a=1, b=2) + >>> b = FrozenDict(a=1, b=2) + >>> a == b + True - >>> a == dict(a=1, b=2) - True - >>> dict(a=1, b=2) == a - True + >>> a == dict(a=1, b=2) + True + >>> dict(a=1, b=2) == a + True + >>> 'a' in a + True + >>> type(hash(a)) is type(0) + True + >>> set(iter(a)) == {'a', 'b'} + True + >>> len(a) + 2 + >>> a['a'] == a.get('a') == 1 + True - >>> a['c'] = 3 - Traceback (most recent call last): - ... - TypeError: 'FrozenDict' object does not support item assignment + >>> a['c'] = 3 + Traceback (most recent call last): + ... + TypeError: 'FrozenDict' object does not support item assignment - >>> a.update(y=3) - Traceback (most recent call last): - ... - AttributeError: 'FrozenDict' object has no attribute 'update' + >>> a.update(y=3) + Traceback (most recent call last): + ... + AttributeError: 'FrozenDict' object has no attribute 'update' - Copies should compare equal + Copies should compare equal - >>> copy.copy(a) == a - True + >>> copy.copy(a) == a + True - Copies should be the same type + Copies should be the same type - >>> isinstance(copy.copy(a), FrozenDict) - True + >>> isinstance(copy.copy(a), FrozenDict) + True - FrozenDict supplies .copy(), even though - collections.abc.Mapping doesn't demand it. + FrozenDict supplies .copy(), even though + collections.abc.Mapping doesn't demand it. - >>> a.copy() == a - True - >>> a.copy() is not a - True - """ - __slots__ = ['__data'] + >>> a.copy() == a + True + >>> a.copy() is not a + True + """ - def __new__(cls, *args, **kwargs): - self = super(FrozenDict, cls).__new__(cls) - self.__data = dict(*args, **kwargs) - return self + __slots__ = ['__data'] - # Container - def __contains__(self, key): - return key in self.__data + def __new__(cls, *args, **kwargs): + self = super(FrozenDict, cls).__new__(cls) + self.__data = dict(*args, **kwargs) + return self - # Hashable - def __hash__(self): - return hash(tuple(sorted(self.__data.iteritems()))) + # Container + def __contains__(self, key): + return key in self.__data - # Mapping - def __iter__(self): - return iter(self.__data) + # Hashable + def __hash__(self): + return hash(tuple(sorted(self.__data.items()))) - def __len__(self): - return len(self.__data) + # Mapping + def __iter__(self): + return iter(self.__data) - def __getitem__(self, key): - return self.__data[key] + def __len__(self): + return len(self.__data) - # override get for efficiency provided by dict - def get(self, *args, **kwargs): - return self.__data.get(*args, **kwargs) + def __getitem__(self, key): + return self.__data[key] - # override eq to recognize underlying implementation - def __eq__(self, other): - if isinstance(other, FrozenDict): - other = other.__data - return self.__data.__eq__(other) + # override get for efficiency provided by dict + def get(self, *args, **kwargs): + return self.__data.get(*args, **kwargs) - def copy(self): - "Return a shallow copy of self" - return copy.copy(self) + # override eq to recognize underlying implementation + def __eq__(self, other): + if isinstance(other, FrozenDict): + other = other.__data + return self.__data.__eq__(other) + + def copy(self): + "Return a shallow copy of self" + return copy.copy(self) class Enumeration(ItemsAsAttributes, BijectiveMap): - """ - A convenient way to provide enumerated values + """ + A convenient way to provide enumerated values - >>> e = Enumeration('a b c') - >>> e['a'] - 0 + >>> e = Enumeration('a b c') + >>> e['a'] + 0 - >>> e.a - 0 + >>> e.a 
+ 0 - >>> e[1] - 'b' + >>> e[1] + 'b' - >>> set(e.names) == set('abc') - True + >>> set(e.names) == set('abc') + True - >>> set(e.codes) == set(range(3)) - True + >>> set(e.codes) == set(range(3)) + True - >>> e.get('d') is None - True + >>> e.get('d') is None + True - Codes need not start with 0 + Codes need not start with 0 - >>> e = Enumeration('a b c', range(1, 4)) - >>> e['a'] - 1 + >>> e = Enumeration('a b c', range(1, 4)) + >>> e['a'] + 1 - >>> e[3] - 'c' - """ - def __init__(self, names, codes=None): - if isinstance(names, six.string_types): - names = names.split() - if codes is None: - codes = itertools.count() - super(Enumeration, self).__init__(zip(names, codes)) + >>> e[3] + 'c' + """ - @property - def names(self): - return (key for key in self if isinstance(key, six.string_types)) + def __init__(self, names, codes=None): + if isinstance(names, str): + names = names.split() + if codes is None: + codes = itertools.count() + super(Enumeration, self).__init__(zip(names, codes)) - @property - def codes(self): - return (self[name] for name in self.names) + @property + def names(self): + return (key for key in self if isinstance(key, str)) + + @property + def codes(self): + return (self[name] for name in self.names) -class Everything(object): - """ - A collection "containing" every possible thing. +class Everything: + """ + A collection "containing" every possible thing. - >>> 'foo' in Everything() - True + >>> 'foo' in Everything() + True - >>> import random - >>> random.randint(1, 999) in Everything() - True + >>> import random + >>> random.randint(1, 999) in Everything() + True - >>> random.choice([None, 'foo', 42, ('a', 'b', 'c')]) in Everything() - True - """ - def __contains__(self, other): - return True + >>> random.choice([None, 'foo', 42, ('a', 'b', 'c')]) in Everything() + True + """ + + def __contains__(self, other): + return True -class InstrumentedDict(six.moves.UserDict): - """ - Instrument an existing dictionary with additional - functionality, but always reference and mutate - the original dictionary. +class InstrumentedDict(collections.UserDict): # type: ignore # buggy mypy + """ + Instrument an existing dictionary with additional + functionality, but always reference and mutate + the original dictionary. 
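
A subclass can layer instrumentation on top while still mutating the wrapped dict in place; a hedged sketch (CountingDict is a hypothetical name, not part of the library):

    from jaraco.collections import InstrumentedDict

    class CountingDict(InstrumentedDict):
        writes = 0
        def __setitem__(self, key, value):
            type(self).writes += 1           # illustrative bookkeeping only
            super().__setitem__(key, value)

    orig = {}
    cd = CountingDict(orig)
    cd['a'] = 1
    assert orig == {'a': 1} and CountingDict.writes == 1
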
- >>> orig = {'a': 1, 'b': 2} - >>> inst = InstrumentedDict(orig) - >>> inst['a'] - 1 - >>> inst['c'] = 3 - >>> orig['c'] - 3 - >>> inst.keys() == orig.keys() - True - """ - def __init__(self, data): - six.moves.UserDict.__init__(self) - self.data = data + >>> orig = {'a': 1, 'b': 2} + >>> inst = InstrumentedDict(orig) + >>> inst['a'] + 1 + >>> inst['c'] = 3 + >>> orig['c'] + 3 + >>> inst.keys() == orig.keys() + True + """ + + def __init__(self, data): + super().__init__() + self.data = data -class Least(object): - """ - A value that is always lesser than any other +class Least: + """ + A value that is always lesser than any other - >>> least = Least() - >>> 3 < least - False - >>> 3 > least - True - >>> least < 3 - True - >>> least <= 3 - True - >>> least > 3 - False - >>> 'x' > least - True - >>> None > least - True - """ + >>> least = Least() + >>> 3 < least + False + >>> 3 > least + True + >>> least < 3 + True + >>> least <= 3 + True + >>> least > 3 + False + >>> 'x' > least + True + >>> None > least + True + """ - def __le__(self, other): - return True - __lt__ = __le__ + def __le__(self, other): + return True - def __ge__(self, other): - return False - __gt__ = __ge__ + __lt__ = __le__ + + def __ge__(self, other): + return False + + __gt__ = __ge__ -class Greatest(object): - """ - A value that is always greater than any other +class Greatest: + """ + A value that is always greater than any other - >>> greatest = Greatest() - >>> 3 < greatest - True - >>> 3 > greatest - False - >>> greatest < 3 - False - >>> greatest > 3 - True - >>> greatest >= 3 - True - >>> 'x' > greatest - False - >>> None > greatest - False - """ + >>> greatest = Greatest() + >>> 3 < greatest + True + >>> 3 > greatest + False + >>> greatest < 3 + False + >>> greatest > 3 + True + >>> greatest >= 3 + True + >>> 'x' > greatest + False + >>> None > greatest + False + """ - def __ge__(self, other): - return True - __gt__ = __ge__ + def __ge__(self, other): + return True - def __le__(self, other): - return False - __lt__ = __le__ + __gt__ = __ge__ + + def __le__(self, other): + return False + + __lt__ = __le__ + + +def pop_all(items): + """ + Clear items in place and return a copy of items. + + >>> items = [1, 2, 3] + >>> popped = pop_all(items) + >>> popped is items + False + >>> popped + [1, 2, 3] + >>> items + [] + """ + result, items[:] = items[:], [] + return result + + +# mypy disabled for pytest-dev/pytest#8332 +class FreezableDefaultDict(collections.defaultdict): # type: ignore + """ + Often it is desirable to prevent the mutation of + a default dict after its initial construction, such + as to prevent mutation during iteration. 
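
A typical pattern is to build up groups with defaultdict semantics, then freeze before handing the mapping out, so lookups of absent keys stop inserting them; a brief sketch:

    from jaraco.collections import FreezableDefaultDict

    by_parity = FreezableDefaultDict(list)
    for n in range(5):
        by_parity[n % 2].append(n)
    by_parity.freeze()
    assert by_parity[7] == []            # served by the factory, not stored
    assert sorted(by_parity) == [0, 1]   # no new key was created
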
+ + >>> dd = FreezableDefaultDict(list) + >>> dd[0].append('1') + >>> dd.freeze() + >>> dd[1] + [] + >>> len(dd) + 1 + """ + + def __missing__(self, key): + return getattr(self, '_frozen', super().__missing__)(key) + + def freeze(self): + self._frozen = lambda key: self.default_factory() + + +class Accumulator: + def __init__(self, initial=0): + self.val = initial + + def __call__(self, val): + self.val += val + return self.val + + +class WeightedLookup(RangeMap): + """ + Given parameters suitable for a dict representing keys + and a weighted proportion, return a RangeMap representing + spans of values proportial to the weights: + + >>> even = WeightedLookup(a=1, b=1) + + [0, 1) -> a + [1, 2) -> b + + >>> lk = WeightedLookup(a=1, b=2) + + [0, 1) -> a + [1, 3) -> b + + >>> lk[.5] + 'a' + >>> lk[1.5] + 'b' + + Adds ``.random()`` to select a random weighted value: + + >>> lk.random() in ['a', 'b'] + True + + >>> choices = [lk.random() for x in range(1000)] + + Statistically speaking, choices should be .5 a:b + >>> ratio = choices.count('a') / choices.count('b') + >>> .4 < ratio < .6 + True + """ + + def __init__(self, *args, **kwargs): + raw = dict(*args, **kwargs) + + # allocate keys by weight + indexes = map(Accumulator(), raw.values()) + super().__init__(zip(indexes, raw.keys()), key_match_comparator=operator.lt) + + def random(self): + lower, upper = self.bounds() + selector = random.random() * upper + return self[selector] diff --git a/libs/win/jaraco/context.py b/libs/win/jaraco/context.py new file mode 100644 index 00000000..818f16f3 --- /dev/null +++ b/libs/win/jaraco/context.py @@ -0,0 +1,253 @@ +import os +import subprocess +import contextlib +import functools +import tempfile +import shutil +import operator + + +@contextlib.contextmanager +def pushd(dir): + orig = os.getcwd() + os.chdir(dir) + try: + yield dir + finally: + os.chdir(orig) + + +@contextlib.contextmanager +def tarball_context(url, target_dir=None, runner=None, pushd=pushd): + """ + Get a tarball, extract it, change to that directory, yield, then + clean up. + `runner` is the function to invoke commands. + `pushd` is a context manager for changing the directory. + """ + if target_dir is None: + target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '') + if runner is None: + runner = functools.partial(subprocess.check_call, shell=True) + # In the tar command, use --strip-components=1 to strip the first path and + # then + # use -C to cause the files to be extracted to {target_dir}. This ensures + # that we always know where the files were extracted. + runner('mkdir {target_dir}'.format(**vars())) + try: + getter = 'wget {url} -O -' + extract = 'tar x{compression} --strip-components=1 -C {target_dir}' + cmd = ' | '.join((getter, extract)) + runner(cmd.format(compression=infer_compression(url), **vars())) + with pushd(target_dir): + yield target_dir + finally: + runner('rm -Rf {target_dir}'.format(**vars())) + + +def infer_compression(url): + """ + Given a URL or filename, infer the compression code for tar. + """ + # cheat and just assume it's the last two characters + compression_indicator = url[-2:] + mapping = dict(gz='z', bz='j', xz='J') + # Assume 'z' (gzip) if no match + return mapping.get(compression_indicator, 'z') + + +@contextlib.contextmanager +def temp_dir(remover=shutil.rmtree): + """ + Create a temporary directory context. Pass a custom remover + to override the removal behavior. 
+ """ + temp_dir = tempfile.mkdtemp() + try: + yield temp_dir + finally: + remover(temp_dir) + + +@contextlib.contextmanager +def repo_context(url, branch=None, quiet=True, dest_ctx=temp_dir): + """ + Check out the repo indicated by url. + + If dest_ctx is supplied, it should be a context manager + to yield the target directory for the check out. + """ + exe = 'git' if 'git' in url else 'hg' + with dest_ctx() as repo_dir: + cmd = [exe, 'clone', url, repo_dir] + if branch: + cmd.extend(['--branch', branch]) + devnull = open(os.path.devnull, 'w') + stdout = devnull if quiet else None + subprocess.check_call(cmd, stdout=stdout) + yield repo_dir + + +@contextlib.contextmanager +def null(): + yield + + +class ExceptionTrap: + """ + A context manager that will catch certain exceptions and provide an + indication they occurred. + + >>> with ExceptionTrap() as trap: + ... raise Exception() + >>> bool(trap) + True + + >>> with ExceptionTrap() as trap: + ... pass + >>> bool(trap) + False + + >>> with ExceptionTrap(ValueError) as trap: + ... raise ValueError("1 + 1 is not 3") + >>> bool(trap) + True + + >>> with ExceptionTrap(ValueError) as trap: + ... raise Exception() + Traceback (most recent call last): + ... + Exception + + >>> bool(trap) + False + """ + + exc_info = None, None, None + + def __init__(self, exceptions=(Exception,)): + self.exceptions = exceptions + + def __enter__(self): + return self + + @property + def type(self): + return self.exc_info[0] + + @property + def value(self): + return self.exc_info[1] + + @property + def tb(self): + return self.exc_info[2] + + def __exit__(self, *exc_info): + type = exc_info[0] + matches = type and issubclass(type, self.exceptions) + if matches: + self.exc_info = exc_info + return matches + + def __bool__(self): + return bool(self.type) + + def raises(self, func, *, _test=bool): + """ + Wrap func and replace the result with the truth + value of the trap (True if an exception occurred). + + First, give the decorator an alias to support Python 3.8 + Syntax. + + >>> raises = ExceptionTrap(ValueError).raises + + Now decorate a function that always fails. + + >>> @raises + ... def fail(): + ... raise ValueError('failed') + >>> fail() + True + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + with ExceptionTrap(self.exceptions) as trap: + func(*args, **kwargs) + return _test(trap) + + return wrapper + + def passes(self, func): + """ + Wrap func and replace the result with the truth + value of the trap (True if no exception). + + First, give the decorator an alias to support Python 3.8 + Syntax. + + >>> passes = ExceptionTrap(ValueError).passes + + Now decorate a function that always fails. + + >>> @passes + ... def fail(): + ... raise ValueError('failed') + + >>> fail() + False + """ + return self.raises(func, _test=operator.not_) + + +class suppress(contextlib.suppress, contextlib.ContextDecorator): + """ + A version of contextlib.suppress with decorator support. + + >>> @suppress(KeyError) + ... def key_error(): + ... {}[''] + >>> key_error() + """ + + +class on_interrupt(contextlib.ContextDecorator): + """ + Replace a KeyboardInterrupt with SystemExit(1) + + >>> def do_interrupt(): + ... raise KeyboardInterrupt() + >>> on_interrupt('error')(do_interrupt)() + Traceback (most recent call last): + ... + SystemExit: 1 + >>> on_interrupt('error', code=255)(do_interrupt)() + Traceback (most recent call last): + ... + SystemExit: 255 + >>> on_interrupt('suppress')(do_interrupt)() + >>> with __import__('pytest').raises(KeyboardInterrupt): + ... 
on_interrupt('ignore')(do_interrupt)() + """ + + def __init__( + self, + action='error', + # py3.7 compat + # /, + code=1, + ): + self.action = action + self.code = code + + def __enter__(self): + return self + + def __exit__(self, exctype, excinst, exctb): + if exctype is not KeyboardInterrupt or self.action == 'ignore': + return + elif self.action == 'error': + raise SystemExit(self.code) from excinst + return self.action == 'suppress' diff --git a/libs/win/jaraco/functools.py b/libs/win/jaraco/functools.py index 134102a7..fcdbb4f9 100644 --- a/libs/win/jaraco/functools.py +++ b/libs/win/jaraco/functools.py @@ -1,459 +1,525 @@ -from __future__ import ( - absolute_import, unicode_literals, print_function, division, -) - import functools import time -import warnings import inspect import collections -from itertools import count +import types +import itertools -__metaclass__ = type +import more_itertools + +from typing import Callable, TypeVar -try: - from functools import lru_cache -except ImportError: - try: - from backports.functools_lru_cache import lru_cache - except ImportError: - try: - from functools32 import lru_cache - except ImportError: - warnings.warn("No lru_cache available") - - -import more_itertools.recipes +CallableT = TypeVar("CallableT", bound=Callable[..., object]) def compose(*funcs): - """ - Compose any number of unary functions into a single unary function. + """ + Compose any number of unary functions into a single unary function. - >>> import textwrap - >>> from six import text_type - >>> stripped = text_type.strip(textwrap.dedent(compose.__doc__)) - >>> compose(text_type.strip, textwrap.dedent)(compose.__doc__) == stripped - True + >>> import textwrap + >>> expected = str.strip(textwrap.dedent(compose.__doc__)) + >>> strip_and_dedent = compose(str.strip, textwrap.dedent) + >>> strip_and_dedent(compose.__doc__) == expected + True - Compose also allows the innermost function to take arbitrary arguments. + Compose also allows the innermost function to take arbitrary arguments. - >>> round_three = lambda x: round(x, ndigits=3) - >>> f = compose(round_three, int.__truediv__) - >>> [f(3*x, x+1) for x in range(1,10)] - [1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7] - """ + >>> round_three = lambda x: round(x, ndigits=3) + >>> f = compose(round_three, int.__truediv__) + >>> [f(3*x, x+1) for x in range(1,10)] + [1.5, 2.0, 2.25, 2.4, 2.5, 2.571, 2.625, 2.667, 2.7] + """ - def compose_two(f1, f2): - return lambda *args, **kwargs: f1(f2(*args, **kwargs)) - return functools.reduce(compose_two, funcs) + def compose_two(f1, f2): + return lambda *args, **kwargs: f1(f2(*args, **kwargs)) + + return functools.reduce(compose_two, funcs) def method_caller(method_name, *args, **kwargs): - """ - Return a function that will call a named method on the - target object with optional positional and keyword - arguments. + """ + Return a function that will call a named method on the + target object with optional positional and keyword + arguments. - >>> lower = method_caller('lower') - >>> lower('MyString') - 'mystring' - """ - def call_method(target): - func = getattr(target, method_name) - return func(*args, **kwargs) - return call_method + >>> lower = method_caller('lower') + >>> lower('MyString') + 'mystring' + """ + + def call_method(target): + func = getattr(target, method_name) + return func(*args, **kwargs) + + return call_method def once(func): - """ - Decorate func so it's only ever called the first time. + """ + Decorate func so it's only ever called the first time. 
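
One use is expensive, idempotent initialization, where later calls, whatever their arguments, reuse the first result; a hedged sketch (connect is an illustrative name):

    from jaraco.functools import once

    @once
    def connect(url):
        print('connecting to', url)
        return object()              # stands in for a costly handle

    a = connect('db://primary')      # prints exactly once
    b = connect('db://replica')     # arguments ignored; saved result returned
    assert a is b
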
- This decorator can ensure that an expensive or non-idempotent function - will not be expensive on subsequent calls and is idempotent. + This decorator can ensure that an expensive or non-idempotent function + will not be expensive on subsequent calls and is idempotent. - >>> add_three = once(lambda a: a+3) - >>> add_three(3) - 6 - >>> add_three(9) - 6 - >>> add_three('12') - 6 + >>> add_three = once(lambda a: a+3) + >>> add_three(3) + 6 + >>> add_three(9) + 6 + >>> add_three('12') + 6 - To reset the stored value, simply clear the property ``saved_result``. + To reset the stored value, simply clear the property ``saved_result``. - >>> del add_three.saved_result - >>> add_three(9) - 12 - >>> add_three(8) - 12 + >>> del add_three.saved_result + >>> add_three(9) + 12 + >>> add_three(8) + 12 - Or invoke 'reset()' on it. + Or invoke 'reset()' on it. - >>> add_three.reset() - >>> add_three(-3) - 0 - >>> add_three(0) - 0 - """ - @functools.wraps(func) - def wrapper(*args, **kwargs): - if not hasattr(wrapper, 'saved_result'): - wrapper.saved_result = func(*args, **kwargs) - return wrapper.saved_result - wrapper.reset = lambda: vars(wrapper).__delitem__('saved_result') - return wrapper + >>> add_three.reset() + >>> add_three(-3) + 0 + >>> add_three(0) + 0 + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not hasattr(wrapper, 'saved_result'): + wrapper.saved_result = func(*args, **kwargs) + return wrapper.saved_result + + wrapper.reset = lambda: vars(wrapper).__delitem__('saved_result') + return wrapper -def method_cache(method, cache_wrapper=None): - """ - Wrap lru_cache to support storing the cache data in the object instances. +def method_cache( + method: CallableT, + cache_wrapper: Callable[ + [CallableT], CallableT + ] = functools.lru_cache(), # type: ignore[assignment] +) -> CallableT: + """ + Wrap lru_cache to support storing the cache data in the object instances. - Abstracts the common paradigm where the method explicitly saves an - underscore-prefixed protected property on first call and returns that - subsequently. + Abstracts the common paradigm where the method explicitly saves an + underscore-prefixed protected property on first call and returns that + subsequently. - >>> class MyClass: - ... calls = 0 - ... - ... @method_cache - ... def method(self, value): - ... self.calls += 1 - ... return value + >>> class MyClass: + ... calls = 0 + ... + ... @method_cache + ... def method(self, value): + ... self.calls += 1 + ... return value - >>> a = MyClass() - >>> a.method(3) - 3 - >>> for x in range(75): - ... res = a.method(x) - >>> a.calls - 75 + >>> a = MyClass() + >>> a.method(3) + 3 + >>> for x in range(75): + ... res = a.method(x) + >>> a.calls + 75 - Note that the apparent behavior will be exactly like that of lru_cache - except that the cache is stored on each instance, so values in one - instance will not flush values from another, and when an instance is - deleted, so are the cached values for that instance. + Note that the apparent behavior will be exactly like that of lru_cache + except that the cache is stored on each instance, so values in one + instance will not flush values from another, and when an instance is + deleted, so are the cached values for that instance. - >>> b = MyClass() - >>> for x in range(35): - ... res = b.method(x) - >>> b.calls - 35 - >>> a.method(0) - 0 - >>> a.calls - 75 + >>> b = MyClass() + >>> for x in range(35): + ... 
res = b.method(x) + >>> b.calls + 35 + >>> a.method(0) + 0 + >>> a.calls + 75 - Note that if method had been decorated with ``functools.lru_cache()``, - a.calls would have been 76 (due to the cached value of 0 having been - flushed by the 'b' instance). + Note that if method had been decorated with ``functools.lru_cache()``, + a.calls would have been 76 (due to the cached value of 0 having been + flushed by the 'b' instance). - Clear the cache with ``.cache_clear()`` + Clear the cache with ``.cache_clear()`` - >>> a.method.cache_clear() + >>> a.method.cache_clear() - Another cache wrapper may be supplied: + Same for a method that hasn't yet been called. - >>> cache = lru_cache(maxsize=2) - >>> MyClass.method2 = method_cache(lambda self: 3, cache_wrapper=cache) - >>> a = MyClass() - >>> a.method2() - 3 + >>> c = MyClass() + >>> c.method.cache_clear() - Caution - do not subsequently wrap the method with another decorator, such - as ``@property``, which changes the semantics of the function. + Another cache wrapper may be supplied: - See also - http://code.activestate.com/recipes/577452-a-memoize-decorator-for-instance-methods/ - for another implementation and additional justification. - """ - cache_wrapper = cache_wrapper or lru_cache() + >>> cache = functools.lru_cache(maxsize=2) + >>> MyClass.method2 = method_cache(lambda self: 3, cache_wrapper=cache) + >>> a = MyClass() + >>> a.method2() + 3 - def wrapper(self, *args, **kwargs): - # it's the first call, replace the method with a cached, bound method - bound_method = functools.partial(method, self) - cached_method = cache_wrapper(bound_method) - setattr(self, method.__name__, cached_method) - return cached_method(*args, **kwargs) + Caution - do not subsequently wrap the method with another decorator, such + as ``@property``, which changes the semantics of the function. - return _special_method_cache(method, cache_wrapper) or wrapper + See also + http://code.activestate.com/recipes/577452-a-memoize-decorator-for-instance-methods/ + for another implementation and additional justification. + """ + + def wrapper(self: object, *args: object, **kwargs: object) -> object: + # it's the first call, replace the method with a cached, bound method + bound_method: CallableT = types.MethodType( # type: ignore[assignment] + method, self + ) + cached_method = cache_wrapper(bound_method) + setattr(self, method.__name__, cached_method) + return cached_method(*args, **kwargs) + + # Support cache clear even before cache has been created. + wrapper.cache_clear = lambda: None # type: ignore[attr-defined] + + return ( # type: ignore[return-value] + _special_method_cache(method, cache_wrapper) or wrapper + ) def _special_method_cache(method, cache_wrapper): - """ - Because Python treats special methods differently, it's not - possible to use instance attributes to implement the cached - methods. + """ + Because Python treats special methods differently, it's not + possible to use instance attributes to implement the cached + methods. - Instead, install the wrapper method under a different name - and return a simple proxy to that wrapper. + Instead, install the wrapper method under a different name + and return a simple proxy to that wrapper. 
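A note on the proxy below: Python looks special methods up on the type, not the instance, so the instance-attribute trick used by ``method_cache`` cannot work for dunders. A minimal sketch (not part of the patch) of the underlying behavior:

    class Demo:
        def __getitem__(self, key):
            return 'from class'

    d = Demo()
    # Shadowing the dunder on the instance does not affect d[...]:
    # implicit special-method lookup bypasses the instance namespace.
    d.__getitem__ = lambda key: 'from instance'
    assert d['x'] == 'from class'
    assert d.__getitem__('x') == 'from instance'  # explicit lookup differs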
- https://github.com/jaraco/jaraco.functools/issues/5 - """ - name = method.__name__ - special_names = '__getattr__', '__getitem__' - if name not in special_names: - return + https://github.com/jaraco/jaraco.functools/issues/5 + """ + name = method.__name__ + special_names = '__getattr__', '__getitem__' + if name not in special_names: + return - wrapper_name = '__cached' + name + wrapper_name = '__cached' + name - def proxy(self, *args, **kwargs): - if wrapper_name not in vars(self): - bound = functools.partial(method, self) - cache = cache_wrapper(bound) - setattr(self, wrapper_name, cache) - else: - cache = getattr(self, wrapper_name) - return cache(*args, **kwargs) + def proxy(self, *args, **kwargs): + if wrapper_name not in vars(self): + bound = types.MethodType(method, self) + cache = cache_wrapper(bound) + setattr(self, wrapper_name, cache) + else: + cache = getattr(self, wrapper_name) + return cache(*args, **kwargs) - return proxy + return proxy def apply(transform): - """ - Decorate a function with a transform function that is - invoked on results returned from the decorated function. + """ + Decorate a function with a transform function that is + invoked on results returned from the decorated function. - >>> @apply(reversed) - ... def get_numbers(start): - ... return range(start, start+3) - >>> list(get_numbers(4)) - [6, 5, 4] - """ - def wrap(func): - return compose(transform, func) - return wrap + >>> @apply(reversed) + ... def get_numbers(start): + ... "doc for get_numbers" + ... return range(start, start+3) + >>> list(get_numbers(4)) + [6, 5, 4] + >>> get_numbers.__doc__ + 'doc for get_numbers' + """ + + def wrap(func): + return functools.wraps(func)(compose(transform, func)) + + return wrap def result_invoke(action): - r""" - Decorate a function with an action function that is - invoked on the results returned from the decorated - function (for its side-effect), then return the original - result. + r""" + Decorate a function with an action function that is + invoked on the results returned from the decorated + function (for its side-effect), then return the original + result. - >>> @result_invoke(print) - ... def add_two(a, b): - ... return a + b - >>> x = add_two(2, 3) - 5 - """ - def wrap(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - result = func(*args, **kwargs) - action(result) - return result - return wrapper - return wrap + >>> @result_invoke(print) + ... def add_two(a, b): + ... return a + b + >>> x = add_two(2, 3) + 5 + >>> x + 5 + """ + + def wrap(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + result = func(*args, **kwargs) + action(result) + return result + + return wrapper + + return wrap def call_aside(f, *args, **kwargs): - """ - Call a function for its side effect after initialization. + """ + Call a function for its side effect after initialization. - >>> @call_aside - ... def func(): print("called") - called - >>> func() - called + >>> @call_aside + ... def func(): print("called") + called + >>> func() + called - Use functools.partial to pass parameters to the initial call + Use functools.partial to pass parameters to the initial call - >>> @functools.partial(call_aside, name='bingo') - ... def func(name): print("called with", name) - called with bingo - """ - f(*args, **kwargs) - return f + >>> @functools.partial(call_aside, name='bingo') + ... 
def func(name): print("called with", name) + called with bingo + """ + f(*args, **kwargs) + return f class Throttler: - """ - Rate-limit a function (or other callable) - """ - def __init__(self, func, max_rate=float('Inf')): - if isinstance(func, Throttler): - func = func.func - self.func = func - self.max_rate = max_rate - self.reset() + """ + Rate-limit a function (or other callable) + """ - def reset(self): - self.last_called = 0 + def __init__(self, func, max_rate=float('Inf')): + if isinstance(func, Throttler): + func = func.func + self.func = func + self.max_rate = max_rate + self.reset() - def __call__(self, *args, **kwargs): - self._wait() - return self.func(*args, **kwargs) + def reset(self): + self.last_called = 0 - def _wait(self): - "ensure at least 1/max_rate seconds from last call" - elapsed = time.time() - self.last_called - must_wait = 1 / self.max_rate - elapsed - time.sleep(max(0, must_wait)) - self.last_called = time.time() + def __call__(self, *args, **kwargs): + self._wait() + return self.func(*args, **kwargs) - def __get__(self, obj, type=None): - return first_invoke(self._wait, functools.partial(self.func, obj)) + def _wait(self): + "ensure at least 1/max_rate seconds from last call" + elapsed = time.time() - self.last_called + must_wait = 1 / self.max_rate - elapsed + time.sleep(max(0, must_wait)) + self.last_called = time.time() + + def __get__(self, obj, type=None): + return first_invoke(self._wait, functools.partial(self.func, obj)) def first_invoke(func1, func2): - """ - Return a function that when invoked will invoke func1 without - any parameters (for its side-effect) and then invoke func2 - with whatever parameters were passed, returning its result. - """ - def wrapper(*args, **kwargs): - func1() - return func2(*args, **kwargs) - return wrapper + """ + Return a function that when invoked will invoke func1 without + any parameters (for its side-effect) and then invoke func2 + with whatever parameters were passed, returning its result. + """ + + def wrapper(*args, **kwargs): + func1() + return func2(*args, **kwargs) + + return wrapper def retry_call(func, cleanup=lambda: None, retries=0, trap=()): - """ - Given a callable func, trap the indicated exceptions - for up to 'retries' times, invoking cleanup on the - exception. On the final attempt, allow any exceptions - to propagate. - """ - attempts = count() if retries == float('inf') else range(retries) - for attempt in attempts: - try: - return func() - except trap: - cleanup() + """ + Given a callable func, trap the indicated exceptions + for up to 'retries' times, invoking cleanup on the + exception. On the final attempt, allow any exceptions + to propagate. + """ + attempts = itertools.count() if retries == float('inf') else range(retries) + for attempt in attempts: + try: + return func() + except trap: + cleanup() - return func() + return func() def retry(*r_args, **r_kwargs): - """ - Decorator wrapper for retry_call. Accepts arguments to retry_call - except func and then returns a decorator for the decorated function. + """ + Decorator wrapper for retry_call. Accepts arguments to retry_call + except func and then returns a decorator for the decorated function. - Ex: + Ex: - >>> @retry(retries=3) - ... def my_func(a, b): - ... "this is my funk" - ... 
print(a, b) - >>> my_func.__doc__ - 'this is my funk' - """ - def decorate(func): - @functools.wraps(func) - def wrapper(*f_args, **f_kwargs): - bound = functools.partial(func, *f_args, **f_kwargs) - return retry_call(bound, *r_args, **r_kwargs) - return wrapper - return decorate + >>> @retry(retries=3) + ... def my_func(a, b): + ... "this is my funk" + ... print(a, b) + >>> my_func.__doc__ + 'this is my funk' + """ + + def decorate(func): + @functools.wraps(func) + def wrapper(*f_args, **f_kwargs): + bound = functools.partial(func, *f_args, **f_kwargs) + return retry_call(bound, *r_args, **r_kwargs) + + return wrapper + + return decorate def print_yielded(func): - """ - Convert a generator into a function that prints all yielded elements + """ + Convert a generator into a function that prints all yielded elements - >>> @print_yielded - ... def x(): - ... yield 3; yield None - >>> x() - 3 - None - """ - print_all = functools.partial(map, print) - print_results = compose(more_itertools.recipes.consume, print_all, func) - return functools.wraps(func)(print_results) + >>> @print_yielded + ... def x(): + ... yield 3; yield None + >>> x() + 3 + None + """ + print_all = functools.partial(map, print) + print_results = compose(more_itertools.consume, print_all, func) + return functools.wraps(func)(print_results) def pass_none(func): - """ - Wrap func so it's not called if its first param is None + """ + Wrap func so it's not called if its first param is None - >>> print_text = pass_none(print) - >>> print_text('text') - text - >>> print_text(None) - """ - @functools.wraps(func) - def wrapper(param, *args, **kwargs): - if param is not None: - return func(param, *args, **kwargs) - return wrapper + >>> print_text = pass_none(print) + >>> print_text('text') + text + >>> print_text(None) + """ + + @functools.wraps(func) + def wrapper(param, *args, **kwargs): + if param is not None: + return func(param, *args, **kwargs) + + return wrapper def assign_params(func, namespace): - """ - Assign parameters from namespace where func solicits. + """ + Assign parameters from namespace where func solicits. - >>> def func(x, y=3): - ... print(x, y) - >>> assigned = assign_params(func, dict(x=2, z=4)) - >>> assigned() - 2 3 + >>> def func(x, y=3): + ... print(x, y) + >>> assigned = assign_params(func, dict(x=2, z=4)) + >>> assigned() + 2 3 - The usual errors are raised if a function doesn't receive - its required parameters: + The usual errors are raised if a function doesn't receive + its required parameters: - >>> assigned = assign_params(func, dict(y=3, z=4)) - >>> assigned() - Traceback (most recent call last): - TypeError: func() ...argument... - """ - try: - sig = inspect.signature(func) - params = sig.parameters.keys() - except AttributeError: - spec = inspect.getargspec(func) - params = spec.args - call_ns = { - k: namespace[k] - for k in params - if k in namespace - } - return functools.partial(func, **call_ns) + >>> assigned = assign_params(func, dict(y=3, z=4)) + >>> assigned() + Traceback (most recent call last): + TypeError: func() ...argument... + + It even works on methods: + + >>> class Handler: + ... def meth(self, arg): + ... 
print(arg) + >>> assign_params(Handler().meth, dict(arg='crystal', foo='clear'))() + crystal + """ + sig = inspect.signature(func) + params = sig.parameters.keys() + call_ns = {k: namespace[k] for k in params if k in namespace} + return functools.partial(func, **call_ns) def save_method_args(method): - """ - Wrap a method such that when it is called, the args and kwargs are - saved on the method. + """ + Wrap a method such that when it is called, the args and kwargs are + saved on the method. - >>> class MyClass: - ... @save_method_args - ... def method(self, a, b): - ... print(a, b) - >>> my_ob = MyClass() - >>> my_ob.method(1, 2) - 1 2 - >>> my_ob._saved_method.args - (1, 2) - >>> my_ob._saved_method.kwargs - {} - >>> my_ob.method(a=3, b='foo') - 3 foo - >>> my_ob._saved_method.args - () - >>> my_ob._saved_method.kwargs == dict(a=3, b='foo') - True + >>> class MyClass: + ... @save_method_args + ... def method(self, a, b): + ... print(a, b) + >>> my_ob = MyClass() + >>> my_ob.method(1, 2) + 1 2 + >>> my_ob._saved_method.args + (1, 2) + >>> my_ob._saved_method.kwargs + {} + >>> my_ob.method(a=3, b='foo') + 3 foo + >>> my_ob._saved_method.args + () + >>> my_ob._saved_method.kwargs == dict(a=3, b='foo') + True - The arguments are stored on the instance, allowing for - different instance to save different args. + The arguments are stored on the instance, allowing for + different instances to save different args. - >>> your_ob = MyClass() - >>> your_ob.method({str('x'): 3}, b=[4]) - {'x': 3} [4] - >>> your_ob._saved_method.args - ({'x': 3},) - >>> my_ob._saved_method.args - () - """ - args_and_kwargs = collections.namedtuple('args_and_kwargs', 'args kwargs') + >>> your_ob = MyClass() + >>> your_ob.method({str('x'): 3}, b=[4]) + {'x': 3} [4] + >>> your_ob._saved_method.args + ({'x': 3},) + >>> my_ob._saved_method.args + () + """ + args_and_kwargs = collections.namedtuple('args_and_kwargs', 'args kwargs') - @functools.wraps(method) - def wrapper(self, *args, **kwargs): - attr_name = '_saved_' + method.__name__ - attr = args_and_kwargs(args, kwargs) - setattr(self, attr_name, attr) - return method(self, *args, **kwargs) - return wrapper + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + attr_name = '_saved_' + method.__name__ + attr = args_and_kwargs(args, kwargs) + setattr(self, attr_name, attr) + return method(self, *args, **kwargs) + + return wrapper + + +def except_(*exceptions, replace=None, use=None): + """ + Replace the indicated exceptions, if raised, with the indicated + literal replacement or evaluated expression (if present). + + >>> safe_int = except_(ValueError)(int) + >>> safe_int('five') + >>> safe_int('5') + 5 + + Specify a literal replacement with ``replace``. + + >>> safe_int_r = except_(ValueError, replace=0)(int) + >>> safe_int_r('five') + 0 + + Provide an expression to ``use`` to pass through particular parameters.
+ + >>> safe_int_pt = except_(ValueError, use='args[0]')(int) + >>> safe_int_pt('five') + 'five' + + """ + + def decorate(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except exceptions: + try: + return eval(use) + except TypeError: + return replace + + return wrapper + + return decorate diff --git a/libs/win/jaraco/structures/binary.py b/libs/win/jaraco/structures/binary.py index be57cc76..c4cbbeda 100644 --- a/libs/win/jaraco/structures/binary.py +++ b/libs/win/jaraco/structures/binary.py @@ -1,151 +1,156 @@ -from __future__ import absolute_import, unicode_literals - import numbers from functools import reduce def get_bit_values(number, size=32): - """ - Get bit values as a list for a given number + """ + Get bit values as a list for a given number - >>> get_bit_values(1) == [0]*31 + [1] - True + >>> get_bit_values(1) == [0]*31 + [1] + True - >>> get_bit_values(0xDEADBEEF) - [1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1] + >>> get_bit_values(0xDEADBEEF) + [1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, \ +1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1] - You may override the default word size of 32-bits to match your actual - application. + You may override the default word size of 32-bits to match your actual + application. - >>> get_bit_values(0x3, 2) - [1, 1] + >>> get_bit_values(0x3, 2) + [1, 1] - >>> get_bit_values(0x3, 4) - [0, 0, 1, 1] - """ - number += 2**size - return list(map(int, bin(number)[-size:])) + >>> get_bit_values(0x3, 4) + [0, 0, 1, 1] + """ + number += 2 ** size + return list(map(int, bin(number)[-size:])) def gen_bit_values(number): - """ - Return a zero or one for each bit of a numeric value up to the most - significant 1 bit, beginning with the least significant bit. + """ + Return a zero or one for each bit of a numeric value up to the most + significant 1 bit, beginning with the least significant bit. - >>> list(gen_bit_values(16)) - [0, 0, 0, 0, 1] - """ - digits = bin(number)[2:] - return map(int, reversed(digits)) + >>> list(gen_bit_values(16)) + [0, 0, 0, 0, 1] + """ + digits = bin(number)[2:] + return map(int, reversed(digits)) def coalesce(bits): - """ - Take a sequence of bits, most significant first, and - coalesce them into a number. + """ + Take a sequence of bits, most significant first, and + coalesce them into a number. - >>> coalesce([1,0,1]) - 5 - """ - operation = lambda a, b: (a << 1 | b) - return reduce(operation, bits) + >>> coalesce([1,0,1]) + 5 + """ + + def operation(a, b): + return a << 1 | b + + return reduce(operation, bits) -class Flags(object): - """ - Subclasses should define _names, a list of flag names beginning - with the least-significant bit. +class Flags: + """ + Subclasses should define _names, a list of flag names beginning + with the least-significant bit. - >>> class MyFlags(Flags): - ... _names = 'a', 'b', 'c' - >>> mf = MyFlags.from_number(5) - >>> mf['a'] - 1 - >>> mf['b'] - 0 - >>> mf['c'] == mf[2] - True - >>> mf['b'] = 1 - >>> mf['a'] = 0 - >>> mf.number - 6 - """ - def __init__(self, values): - self._values = list(values) - if hasattr(self, '_names'): - n_missing_bits = len(self._names) - len(self._values) - self._values.extend([0] * n_missing_bits) + >>> class MyFlags(Flags): + ... 
_names = 'a', 'b', 'c' + >>> mf = MyFlags.from_number(5) + >>> mf['a'] + 1 + >>> mf['b'] + 0 + >>> mf['c'] == mf[2] + True + >>> mf['b'] = 1 + >>> mf['a'] = 0 + >>> mf.number + 6 + """ - @classmethod - def from_number(cls, number): - return cls(gen_bit_values(number)) + def __init__(self, values): + self._values = list(values) + if hasattr(self, '_names'): + n_missing_bits = len(self._names) - len(self._values) + self._values.extend([0] * n_missing_bits) - @property - def number(self): - return coalesce(reversed(self._values)) + @classmethod + def from_number(cls, number): + return cls(gen_bit_values(number)) - def __setitem__(self, key, value): - # first try by index, then by name - try: - self._values[key] = value - except TypeError: - index = self._names.index(key) - self._values[index] = value + @property + def number(self): + return coalesce(reversed(self._values)) - def __getitem__(self, key): - # first try by index, then by name - try: - return self._values[key] - except TypeError: - index = self._names.index(key) - return self._values[index] + def __setitem__(self, key, value): + # first try by index, then by name + try: + self._values[key] = value + except TypeError: + index = self._names.index(key) + self._values[index] = value + + def __getitem__(self, key): + # first try by index, then by name + try: + return self._values[key] + except TypeError: + index = self._names.index(key) + return self._values[index] class BitMask(type): - """ - A metaclass to create a bitmask with attributes. Subclass an int and - set this as the metaclass to use. + """ + A metaclass to create a bitmask with attributes. Subclass an int and + set this as the metaclass to use. - Here's how to create such a class on Python 3: + Construct such a class: - class MyBits(int, metaclass=BitMask): - a = 0x1 - b = 0x4 - c = 0x3 + >>> class MyBits(int, metaclass=BitMask): + ... a = 0x1 + ... b = 0x4 + ... c = 0x3 - For testing purposes, construct explicitly to support Python 2 + >>> b1 = MyBits(3) + >>> b1.a, b1.b, b1.c + (True, False, True) + >>> b2 = MyBits(8) + >>> any([b2.a, b2.b, b2.c]) + False - >>> ns = dict(a=0x1, b=0x4, c=0x3) - >>> MyBits = BitMask(str('MyBits'), (int,), ns) + If the instance defines methods, they won't be wrapped in + properties. - >>> b1 = MyBits(3) - >>> b1.a, b1.b, b1.c - (True, False, True) - >>> b2 = MyBits(8) - >>> any([b2.a, b2.b, b2.c]) - False + >>> class MyBits(int, metaclass=BitMask): + ... a = 0x1 + ... b = 0x4 + ... c = 0x3 + ... + ... @classmethod + ... def get_value(cls): + ... return 'some value' + ... + ... @property + ... def prop(cls): + ... return 'a property' + >>> MyBits(3).get_value() + 'some value' + >>> MyBits(3).prop + 'a property' + """ - If the instance defines methods, they won't be wrapped in - properties. 
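Roughly, the metaclass below rewrites each public numeric class attribute into a read-only property that tests the masked bits. A hand-expanded sketch (illustrative only, not part of the patch) of what the ``MyBits`` example amounts to:

    class MyBits(int):
        a = property(lambda self: bool(self & 0x1))
        b = property(lambda self: bool(self & 0x4))
        c = property(lambda self: bool(self & 0x3))

    assert (MyBits(3).a, MyBits(3).b, MyBits(3).c) == (True, False, True)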
+ def __new__(cls, name, bases, attrs): + def make_property(name, value): + if name.startswith('_') or not isinstance(value, numbers.Number): + return value + return property(lambda self, value=value: bool(self & value)) - >>> ns['get_value'] = classmethod(lambda cls: 'some value') - >>> ns['prop'] = property(lambda self: 'a property') - >>> MyBits = BitMask(str('MyBits'), (int,), ns) - - >>> MyBits(3).get_value() - 'some value' - >>> MyBits(3).prop - 'a property' - """ - - def __new__(cls, name, bases, attrs): - def make_property(name, value): - if name.startswith('_') or not isinstance(value, numbers.Number): - return value - return property(lambda self, value=value: bool(self & value)) - - newattrs = dict( - (name, make_property(name, value)) - for name, value in attrs.items() - ) - return type.__new__(cls, name, bases, newattrs) + newattrs = dict( + (name, make_property(name, value)) for name, value in attrs.items() + ) + return type.__new__(cls, name, bases, newattrs) diff --git a/libs/win/jaraco/text.py b/libs/win/jaraco/text.py deleted file mode 100644 index 71b4b0bc..00000000 --- a/libs/win/jaraco/text.py +++ /dev/null @@ -1,452 +0,0 @@ -from __future__ import absolute_import, unicode_literals, print_function - -import sys -import re -import inspect -import itertools -import textwrap -import functools - -import six - -import jaraco.collections -from jaraco.functools import compose - - -def substitution(old, new): - """ - Return a function that will perform a substitution on a string - """ - return lambda s: s.replace(old, new) - - -def multi_substitution(*substitutions): - """ - Take a sequence of pairs specifying substitutions, and create - a function that performs those substitutions. - - >>> multi_substitution(('foo', 'bar'), ('bar', 'baz'))('foo') - 'baz' - """ - substitutions = itertools.starmap(substitution, substitutions) - # compose function applies last function first, so reverse the - # substitutions to get the expected order. - substitutions = reversed(tuple(substitutions)) - return compose(*substitutions) - - -class FoldedCase(six.text_type): - """ - A case insensitive string class; behaves just like str - except compares equal when the only variation is case. - - >>> s = FoldedCase('hello world') - - >>> s == 'Hello World' - True - - >>> 'Hello World' == s - True - - >>> s != 'Hello World' - False - - >>> s.index('O') - 4 - - >>> s.split('O') - ['hell', ' w', 'rld'] - - >>> sorted(map(FoldedCase, ['GAMMA', 'alpha', 'Beta'])) - ['alpha', 'Beta', 'GAMMA'] - - Sequence membership is straightforward. - - >>> "Hello World" in [s] - True - >>> s in ["Hello World"] - True - - You may test for set inclusion, but candidate and elements - must both be folded. - - >>> FoldedCase("Hello World") in {s} - True - >>> s in {FoldedCase("Hello World")} - True - - String inclusion works as long as the FoldedCase object - is on the right. 
- - >>> "hello" in FoldedCase("Hello World") - True - - But not if the FoldedCase object is on the left: - - >>> FoldedCase('hello') in 'Hello World' - False - - In that case, use in_: - - >>> FoldedCase('hello').in_('Hello World') - True - - """ - def __lt__(self, other): - return self.lower() < other.lower() - - def __gt__(self, other): - return self.lower() > other.lower() - - def __eq__(self, other): - return self.lower() == other.lower() - - def __ne__(self, other): - return self.lower() != other.lower() - - def __hash__(self): - return hash(self.lower()) - - def __contains__(self, other): - return super(FoldedCase, self).lower().__contains__(other.lower()) - - def in_(self, other): - "Does self appear in other?" - return self in FoldedCase(other) - - # cache lower since it's likely to be called frequently. - @jaraco.functools.method_cache - def lower(self): - return super(FoldedCase, self).lower() - - def index(self, sub): - return self.lower().index(sub.lower()) - - def split(self, splitter=' ', maxsplit=0): - pattern = re.compile(re.escape(splitter), re.I) - return pattern.split(self, maxsplit) - - -def local_format(string): - """ - format the string using variables in the caller's local namespace. - - >>> a = 3 - >>> local_format("{a:5}") - ' 3' - """ - context = inspect.currentframe().f_back.f_locals - if sys.version_info < (3, 2): - return string.format(**context) - return string.format_map(context) - - -def global_format(string): - """ - format the string using variables in the caller's global namespace. - - >>> a = 3 - >>> fmt = "The func name: {global_format.__name__}" - >>> global_format(fmt) - 'The func name: global_format' - """ - context = inspect.currentframe().f_back.f_globals - if sys.version_info < (3, 2): - return string.format(**context) - return string.format_map(context) - - -def namespace_format(string): - """ - Format the string using variable in the caller's scope (locals + globals). - - >>> a = 3 - >>> fmt = "A is {a} and this func is {namespace_format.__name__}" - >>> namespace_format(fmt) - 'A is 3 and this func is namespace_format' - """ - context = jaraco.collections.DictStack() - context.push(inspect.currentframe().f_back.f_globals) - context.push(inspect.currentframe().f_back.f_locals) - if sys.version_info < (3, 2): - return string.format(**context) - return string.format_map(context) - - -def is_decodable(value): - r""" - Return True if the supplied value is decodable (using the default - encoding). - - >>> is_decodable(b'\xff') - False - >>> is_decodable(b'\x32') - True - """ - # TODO: This code could be expressed more consisely and directly - # with a jaraco.context.ExceptionTrap, but that adds an unfortunate - # long dependency tree, so for now, use boolean literals. - try: - value.decode() - except UnicodeDecodeError: - return False - return True - - -def is_binary(value): - """ - Return True if the value appears to be binary (that is, it's a byte - string and isn't decodable). - """ - return isinstance(value, bytes) and not is_decodable(value) - - -def trim(s): - r""" - Trim something like a docstring to remove the whitespace that - is common due to indentation and formatting. 
- - >>> trim("\n\tfoo = bar\n\t\tbar = baz\n") - 'foo = bar\n\tbar = baz' - """ - return textwrap.dedent(s).strip() - - -class Splitter(object): - """object that will split a string with the given arguments for each call - - >>> s = Splitter(',') - >>> s('hello, world, this is your, master calling') - ['hello', ' world', ' this is your', ' master calling'] - """ - def __init__(self, *args): - self.args = args - - def __call__(self, s): - return s.split(*self.args) - - -def indent(string, prefix=' ' * 4): - return prefix + string - - -class WordSet(tuple): - """ - Given a Python identifier, return the words that identifier represents, - whether in camel case, underscore-separated, etc. - - >>> WordSet.parse("camelCase") - ('camel', 'Case') - - >>> WordSet.parse("under_sep") - ('under', 'sep') - - Acronyms should be retained - - >>> WordSet.parse("firstSNL") - ('first', 'SNL') - - >>> WordSet.parse("you_and_I") - ('you', 'and', 'I') - - >>> WordSet.parse("A simple test") - ('A', 'simple', 'test') - - Multiple caps should not interfere with the first cap of another word. - - >>> WordSet.parse("myABCClass") - ('my', 'ABC', 'Class') - - The result is a WordSet, so you can get the form you need. - - >>> WordSet.parse("myABCClass").underscore_separated() - 'my_ABC_Class' - - >>> WordSet.parse('a-command').camel_case() - 'ACommand' - - >>> WordSet.parse('someIdentifier').lowered().space_separated() - 'some identifier' - - Slices of the result should return another WordSet. - - >>> WordSet.parse('taken-out-of-context')[1:].underscore_separated() - 'out_of_context' - - >>> WordSet.from_class_name(WordSet()).lowered().space_separated() - 'word set' - """ - _pattern = re.compile('([A-Z]?[a-z]+)|([A-Z]+(?![a-z]))') - - def capitalized(self): - return WordSet(word.capitalize() for word in self) - - def lowered(self): - return WordSet(word.lower() for word in self) - - def camel_case(self): - return ''.join(self.capitalized()) - - def headless_camel_case(self): - words = iter(self) - first = next(words).lower() - return itertools.chain((first,), WordSet(words).camel_case()) - - def underscore_separated(self): - return '_'.join(self) - - def dash_separated(self): - return '-'.join(self) - - def space_separated(self): - return ' '.join(self) - - def __getitem__(self, item): - result = super(WordSet, self).__getitem__(item) - if isinstance(item, slice): - result = WordSet(result) - return result - - # for compatibility with Python 2 - def __getslice__(self, i, j): - return self.__getitem__(slice(i, j)) - - @classmethod - def parse(cls, identifier): - matches = cls._pattern.finditer(identifier) - return WordSet(match.group(0) for match in matches) - - @classmethod - def from_class_name(cls, subject): - return cls.parse(subject.__class__.__name__) - - -# for backward compatibility -words = WordSet.parse - - -def simple_html_strip(s): - r""" - Remove HTML from the string `s`. - - >>> str(simple_html_strip('')) - '' - - >>> print(simple_html_strip('A stormy day in paradise')) - A stormy day in paradise - - >>> print(simple_html_strip('Somebody tell the truth.')) - Somebody tell the truth. - - >>> print(simple_html_strip('What about
\nmultiple lines?')) - What about - multiple lines? - """ - html_stripper = re.compile('()|(<[^>]*>)|([^<]+)', re.DOTALL) - texts = ( - match.group(3) or '' - for match - in html_stripper.finditer(s) - ) - return ''.join(texts) - - -class SeparatedValues(six.text_type): - """ - A string separated by a separator. Overrides __iter__ for getting - the values. - - >>> list(SeparatedValues('a,b,c')) - ['a', 'b', 'c'] - - Whitespace is stripped and empty values are discarded. - - >>> list(SeparatedValues(' a, b , c, ')) - ['a', 'b', 'c'] - """ - separator = ',' - - def __iter__(self): - parts = self.split(self.separator) - return six.moves.filter(None, (part.strip() for part in parts)) - - -class Stripper: - r""" - Given a series of lines, find the common prefix and strip it from them. - - >>> lines = [ - ... 'abcdefg\n', - ... 'abc\n', - ... 'abcde\n', - ... ] - >>> res = Stripper.strip_prefix(lines) - >>> res.prefix - 'abc' - >>> list(res.lines) - ['defg\n', '\n', 'de\n'] - - If no prefix is common, nothing should be stripped. - - >>> lines = [ - ... 'abcd\n', - ... '1234\n', - ... ] - >>> res = Stripper.strip_prefix(lines) - >>> res.prefix = '' - >>> list(res.lines) - ['abcd\n', '1234\n'] - """ - def __init__(self, prefix, lines): - self.prefix = prefix - self.lines = map(self, lines) - - @classmethod - def strip_prefix(cls, lines): - prefix_lines, lines = itertools.tee(lines) - prefix = functools.reduce(cls.common_prefix, prefix_lines) - return cls(prefix, lines) - - def __call__(self, line): - if not self.prefix: - return line - null, prefix, rest = line.partition(self.prefix) - return rest - - @staticmethod - def common_prefix(s1, s2): - """ - Return the common prefix of two lines. - """ - index = min(len(s1), len(s2)) - while s1[:index] != s2[:index]: - index -= 1 - return s1[:index] - - -def remove_prefix(text, prefix): - """ - Remove the prefix from the text if it exists. - - >>> remove_prefix('underwhelming performance', 'underwhelming ') - 'performance' - - >>> remove_prefix('something special', 'sample') - 'something special' - """ - null, prefix, rest = text.rpartition(prefix) - return rest - - -def remove_suffix(text, suffix): - """ - Remove the suffix from the text if it exists. - - >>> remove_suffix('name.git', '.git') - 'name' - - >>> remove_suffix('something special', 'sample') - 'something special' - """ - rest, suffix, null = text.partition(suffix) - return rest diff --git a/libs/win/jaraco/text/Lorem ipsum.txt b/libs/win/jaraco/text/Lorem ipsum.txt new file mode 100644 index 00000000..986f944b --- /dev/null +++ b/libs/win/jaraco/text/Lorem ipsum.txt @@ -0,0 +1,2 @@ +Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. +Curabitur pretium tincidunt lacus. Nulla gravida orci a odio. Nullam varius, turpis et commodo pharetra, est eros bibendum elit, nec luctus magna felis sollicitudin mauris. Integer in mauris eu nibh euismod gravida. Duis ac tellus et risus vulputate vehicula. Donec lobortis risus a elit. Etiam tempor. Ut ullamcorper, ligula eu tempor congue, eros est euismod turpis, id tincidunt sapien risus a quam. Maecenas fermentum consequat mi. Donec fermentum. 
Pellentesque malesuada nulla a mi. Duis sapien sem, aliquet nec, commodo eget, consequat quis, neque. Aliquam faucibus, elit ut dictum aliquet, felis nisl adipiscing sapien, sed malesuada diam lacus eget erat. Cras mollis scelerisque nunc. Nullam arcu. Aliquam consequat. Curabitur augue lorem, dapibus quis, laoreet et, pretium ac, nisi. Aenean magna nisl, mollis quis, molestie eu, feugiat in, orci. In hac habitasse platea dictumst. diff --git a/libs/win/jaraco/text/__init__.py b/libs/win/jaraco/text/__init__.py new file mode 100644 index 00000000..e51101c2 --- /dev/null +++ b/libs/win/jaraco/text/__init__.py @@ -0,0 +1,622 @@ +import re +import itertools +import textwrap +import functools + +try: + from importlib.resources import files # type: ignore +except ImportError: # pragma: nocover + from importlib_resources import files # type: ignore + +from jaraco.functools import compose, method_cache +from jaraco.context import ExceptionTrap + + +def substitution(old, new): + """ + Return a function that will perform a substitution on a string + """ + return lambda s: s.replace(old, new) + + +def multi_substitution(*substitutions): + """ + Take a sequence of pairs specifying substitutions, and create + a function that performs those substitutions. + + >>> multi_substitution(('foo', 'bar'), ('bar', 'baz'))('foo') + 'baz' + """ + substitutions = itertools.starmap(substitution, substitutions) + # compose function applies last function first, so reverse the + # substitutions to get the expected order. + substitutions = reversed(tuple(substitutions)) + return compose(*substitutions) + + +class FoldedCase(str): + """ + A case insensitive string class; behaves just like str + except compares equal when the only variation is case. + + >>> s = FoldedCase('hello world') + + >>> s == 'Hello World' + True + + >>> 'Hello World' == s + True + + >>> s != 'Hello World' + False + + >>> s.index('O') + 4 + + >>> s.split('O') + ['hell', ' w', 'rld'] + + >>> sorted(map(FoldedCase, ['GAMMA', 'alpha', 'Beta'])) + ['alpha', 'Beta', 'GAMMA'] + + Sequence membership is straightforward. + + >>> "Hello World" in [s] + True + >>> s in ["Hello World"] + True + + Allows testing for set inclusion, but candidate and elements + must both be folded. + + >>> FoldedCase("Hello World") in {s} + True + >>> s in {FoldedCase("Hello World")} + True + + String inclusion works as long as the FoldedCase object + is on the right. + + >>> "hello" in FoldedCase("Hello World") + True + + But not if the FoldedCase object is on the left: + + >>> FoldedCase('hello') in 'Hello World' + False + + In that case, use ``in_``: + + >>> FoldedCase('hello').in_('Hello World') + True + + >>> FoldedCase('hello') > FoldedCase('Hello') + False + + >>> FoldedCase('ß') == FoldedCase('ss') + True + """ + + def __lt__(self, other): + return self.casefold() < other.casefold() + + def __gt__(self, other): + return self.casefold() > other.casefold() + + def __eq__(self, other): + return self.casefold() == other.casefold() + + def __ne__(self, other): + return self.casefold() != other.casefold() + + def __hash__(self): + return hash(self.casefold()) + + def __contains__(self, other): + return super().casefold().__contains__(other.casefold()) + + def in_(self, other): + "Does self appear in other?" + return self in FoldedCase(other) + + # cache casefold since it's likely to be called frequently. 
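A small sketch (not part of the patch) of what caching ``casefold`` buys here: the fold is computed once per instance, and later calls return the identical cached string.

    s = FoldedCase('Hello World')
    assert s.casefold() is s.casefold()  # second call hits the per-instance cache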
+ @method_cache + def casefold(self): + return super().casefold() + + def index(self, sub): + return self.casefold().index(sub.casefold()) + + def split(self, splitter=' ', maxsplit=0): + pattern = re.compile(re.escape(splitter), re.I) + return pattern.split(self, maxsplit) + + +# Python 3.8 compatibility +_unicode_trap = ExceptionTrap(UnicodeDecodeError) + + +@_unicode_trap.passes +def is_decodable(value): + r""" + Return True if the supplied value is decodable (using the default + encoding). + + >>> is_decodable(b'\xff') + False + >>> is_decodable(b'\x32') + True + """ + value.decode() + + +def is_binary(value): + r""" + Return True if the value appears to be binary (that is, it's a byte + string and isn't decodable). + + >>> is_binary(b'\xff') + True + >>> is_binary('\xff') + False + """ + return isinstance(value, bytes) and not is_decodable(value) + + +def trim(s): + r""" + Trim something like a docstring to remove the whitespace that + is common due to indentation and formatting. + + >>> trim("\n\tfoo = bar\n\t\tbar = baz\n") + 'foo = bar\n\tbar = baz' + """ + return textwrap.dedent(s).strip() + + +def wrap(s): + """ + Wrap lines of text, retaining existing newlines as + paragraph markers. + + >>> print(wrap(lorem_ipsum)) + Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do + eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad + minim veniam, quis nostrud exercitation ullamco laboris nisi ut + aliquip ex ea commodo consequat. Duis aute irure dolor in + reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla + pariatur. Excepteur sint occaecat cupidatat non proident, sunt in + culpa qui officia deserunt mollit anim id est laborum. + + Curabitur pretium tincidunt lacus. Nulla gravida orci a odio. Nullam + varius, turpis et commodo pharetra, est eros bibendum elit, nec luctus + magna felis sollicitudin mauris. Integer in mauris eu nibh euismod + gravida. Duis ac tellus et risus vulputate vehicula. Donec lobortis + risus a elit. Etiam tempor. Ut ullamcorper, ligula eu tempor congue, + eros est euismod turpis, id tincidunt sapien risus a quam. Maecenas + fermentum consequat mi. Donec fermentum. Pellentesque malesuada nulla + a mi. Duis sapien sem, aliquet nec, commodo eget, consequat quis, + neque. Aliquam faucibus, elit ut dictum aliquet, felis nisl adipiscing + sapien, sed malesuada diam lacus eget erat. Cras mollis scelerisque + nunc. Nullam arcu. Aliquam consequat. Curabitur augue lorem, dapibus + quis, laoreet et, pretium ac, nisi. Aenean magna nisl, mollis quis, + molestie eu, feugiat in, orci. In hac habitasse platea dictumst. + """ + paragraphs = s.splitlines() + wrapped = ('\n'.join(textwrap.wrap(para)) for para in paragraphs) + return '\n\n'.join(wrapped) + + +def unwrap(s): + r""" + Given a multi-line string, return an unwrapped version. + + >>> wrapped = wrap(lorem_ipsum) + >>> wrapped.count('\n') + 20 + >>> unwrapped = unwrap(wrapped) + >>> unwrapped.count('\n') + 1 + >>> print(unwrapped) + Lorem ipsum dolor sit amet, consectetur adipiscing ... + Curabitur pretium tincidunt lacus. Nulla gravida orci ... 
+ + """ + paragraphs = re.split(r'\n\n+', s) + cleaned = (para.replace('\n', ' ') for para in paragraphs) + return '\n'.join(cleaned) + + +lorem_ipsum: str = files(__name__).joinpath('Lorem ipsum.txt').read_text() + + +class Splitter(object): + """object that will split a string with the given arguments for each call + + >>> s = Splitter(',') + >>> s('hello, world, this is your, master calling') + ['hello', ' world', ' this is your', ' master calling'] + """ + + def __init__(self, *args): + self.args = args + + def __call__(self, s): + return s.split(*self.args) + + +def indent(string, prefix=' ' * 4): + """ + >>> indent('foo') + ' foo' + """ + return prefix + string + + +class WordSet(tuple): + """ + Given an identifier, return the words that identifier represents, + whether in camel case, underscore-separated, etc. + + >>> WordSet.parse("camelCase") + ('camel', 'Case') + + >>> WordSet.parse("under_sep") + ('under', 'sep') + + Acronyms should be retained + + >>> WordSet.parse("firstSNL") + ('first', 'SNL') + + >>> WordSet.parse("you_and_I") + ('you', 'and', 'I') + + >>> WordSet.parse("A simple test") + ('A', 'simple', 'test') + + Multiple caps should not interfere with the first cap of another word. + + >>> WordSet.parse("myABCClass") + ('my', 'ABC', 'Class') + + The result is a WordSet, providing access to other forms. + + >>> WordSet.parse("myABCClass").underscore_separated() + 'my_ABC_Class' + + >>> WordSet.parse('a-command').camel_case() + 'ACommand' + + >>> WordSet.parse('someIdentifier').lowered().space_separated() + 'some identifier' + + Slices of the result should return another WordSet. + + >>> WordSet.parse('taken-out-of-context')[1:].underscore_separated() + 'out_of_context' + + >>> WordSet.from_class_name(WordSet()).lowered().space_separated() + 'word set' + + >>> example = WordSet.parse('figured it out') + >>> example.headless_camel_case() + 'figuredItOut' + >>> example.dash_separated() + 'figured-it-out' + + """ + + _pattern = re.compile('([A-Z]?[a-z]+)|([A-Z]+(?![a-z]))') + + def capitalized(self): + return WordSet(word.capitalize() for word in self) + + def lowered(self): + return WordSet(word.lower() for word in self) + + def camel_case(self): + return ''.join(self.capitalized()) + + def headless_camel_case(self): + words = iter(self) + first = next(words).lower() + new_words = itertools.chain((first,), WordSet(words).camel_case()) + return ''.join(new_words) + + def underscore_separated(self): + return '_'.join(self) + + def dash_separated(self): + return '-'.join(self) + + def space_separated(self): + return ' '.join(self) + + def trim_right(self, item): + """ + Remove the item from the end of the set. + + >>> WordSet.parse('foo bar').trim_right('foo') + ('foo', 'bar') + >>> WordSet.parse('foo bar').trim_right('bar') + ('foo',) + >>> WordSet.parse('').trim_right('bar') + () + """ + return self[:-1] if self and self[-1] == item else self + + def trim_left(self, item): + """ + Remove the item from the beginning of the set. 
+ + >>> WordSet.parse('foo bar').trim_left('foo') + ('bar',) + >>> WordSet.parse('foo bar').trim_left('bar') + ('foo', 'bar') + >>> WordSet.parse('').trim_left('bar') + () + """ + return self[1:] if self and self[0] == item else self + + def trim(self, item): + """ + >>> WordSet.parse('foo bar').trim('foo') + ('bar',) + """ + return self.trim_left(item).trim_right(item) + + def __getitem__(self, item): + result = super(WordSet, self).__getitem__(item) + if isinstance(item, slice): + result = WordSet(result) + return result + + @classmethod + def parse(cls, identifier): + matches = cls._pattern.finditer(identifier) + return WordSet(match.group(0) for match in matches) + + @classmethod + def from_class_name(cls, subject): + return cls.parse(subject.__class__.__name__) + + +# for backward compatibility +words = WordSet.parse + + +def simple_html_strip(s): + r""" + Remove HTML from the string `s`. + + >>> str(simple_html_strip('')) + '' + + >>> print(simple_html_strip('A <bold>stormy</bold> day in paradise')) + A stormy day in paradise + + >>> print(simple_html_strip('Somebody <!-- do not --> tell the truth.')) + Somebody  tell the truth. + + >>> print(simple_html_strip('What about<br/>\nmultiple lines?')) + What about + multiple lines? + """ + html_stripper = re.compile('(<!--.*?-->)|(<[^>]*>)|([^<]+)', re.DOTALL) + texts = (match.group(3) or '' for match in html_stripper.finditer(s)) + return ''.join(texts) + + +class SeparatedValues(str): + """ + A string separated by a separator. Overrides __iter__ for getting + the values. + + >>> list(SeparatedValues('a,b,c')) + ['a', 'b', 'c'] + + Whitespace is stripped and empty values are discarded. + + >>> list(SeparatedValues(' a, b , c, ')) + ['a', 'b', 'c'] + """ + + separator = ',' + + def __iter__(self): + parts = self.split(self.separator) + return filter(None, (part.strip() for part in parts)) + + +class Stripper: + r""" + Given a series of lines, find the common prefix and strip it from them. + + >>> lines = [ + ... 'abcdefg\n', + ... 'abc\n', + ... 'abcde\n', + ... ] + >>> res = Stripper.strip_prefix(lines) + >>> res.prefix + 'abc' + >>> list(res.lines) + ['defg\n', '\n', 'de\n'] + + If no prefix is common, nothing should be stripped. + + >>> lines = [ + ... 'abcd\n', + ... '1234\n', + ... ] + >>> res = Stripper.strip_prefix(lines) + >>> res.prefix = '' + >>> list(res.lines) + ['abcd\n', '1234\n'] + """ + + def __init__(self, prefix, lines): + self.prefix = prefix + self.lines = map(self, lines) + + @classmethod + def strip_prefix(cls, lines): + prefix_lines, lines = itertools.tee(lines) + prefix = functools.reduce(cls.common_prefix, prefix_lines) + return cls(prefix, lines) + + def __call__(self, line): + if not self.prefix: + return line + null, prefix, rest = line.partition(self.prefix) + return rest + + @staticmethod + def common_prefix(s1, s2): + """ + Return the common prefix of two lines. + """ + index = min(len(s1), len(s2)) + while s1[:index] != s2[:index]: + index -= 1 + return s1[:index] + + +def remove_prefix(text, prefix): + """ + Remove the prefix from the text if it exists. + + >>> remove_prefix('underwhelming performance', 'underwhelming ') + 'performance' + + >>> remove_prefix('something special', 'sample') + 'something special' + """ + null, prefix, rest = text.rpartition(prefix) + return rest + + +def remove_suffix(text, suffix): + """ + Remove the suffix from the text if it exists. + + >>> remove_suffix('name.git', '.git') + 'name' + + >>> remove_suffix('something special', 'sample') + 'something special' + """ + rest, suffix, null = text.partition(suffix) + return rest + + +def normalize_newlines(text): + r""" + Replace alternate newlines with the canonical newline. + + >>> normalize_newlines('Lorem Ipsum\u2029') + 'Lorem Ipsum\n' + >>> normalize_newlines('Lorem Ipsum\r\n') + 'Lorem Ipsum\n' + >>> normalize_newlines('Lorem Ipsum\x85') + 'Lorem Ipsum\n' + """ + newlines = ['\r\n', '\r', '\n', '\u0085', '\u2028', '\u2029'] + pattern = '|'.join(newlines) + return re.sub(pattern, '\n', text) + + +def _nonblank(str): + return str and not str.startswith('#') + + +@functools.singledispatch +def yield_lines(iterable): + r""" + Yield valid lines of a string or iterable. + + >>> list(yield_lines('')) + [] + >>> list(yield_lines(['foo', 'bar'])) + ['foo', 'bar'] + >>> list(yield_lines('foo\nbar')) + ['foo', 'bar'] + >>> list(yield_lines('\nfoo\n#bar\nbaz #comment')) + ['foo', 'baz #comment'] + >>> list(yield_lines(['foo\nbar', 'baz', 'bing\n\n\n'])) + ['foo', 'bar', 'baz', 'bing'] + """ + return itertools.chain.from_iterable(map(yield_lines, iterable)) + + +@yield_lines.register(str) +def _(text): + return filter(_nonblank, map(str.strip, text.splitlines())) + + +def drop_comment(line): + """ + Drop comments.
+ + >>> drop_comment('foo # bar') + 'foo' + + A hash without a space may be in a URL. + + >>> drop_comment('http://example.com/foo#bar') + 'http://example.com/foo#bar' + """ + return line.partition(' #')[0] + + +def join_continuation(lines): + r""" + Join lines continued by a trailing backslash. + + >>> list(join_continuation(['foo \\', 'bar', 'baz'])) + ['foobar', 'baz'] + >>> list(join_continuation(['foo \\', 'bar', 'baz'])) + ['foobar', 'baz'] + >>> list(join_continuation(['foo \\', 'bar \\', 'baz'])) + ['foobarbaz'] + + Not sure why, but... + The character preceding the backslash is also elided. + + >>> list(join_continuation(['goo\\', 'dly'])) + ['godly'] + + A terrible idea, but... + If no line is available to continue, suppress the lines. + + >>> list(join_continuation(['foo', 'bar\\', 'baz\\'])) + ['foo'] + """ + lines = iter(lines) + for item in lines: + while item.endswith('\\'): + try: + item = item[:-2].strip() + next(lines) + except StopIteration: + return + yield item + + +def read_newlines(filename, limit=1024): + r""" + >>> tmp_path = getfixture('tmp_path') + >>> filename = tmp_path / 'out.txt' + >>> _ = filename.write_text('foo\n', newline='') + >>> read_newlines(filename) + '\n' + >>> _ = filename.write_text('foo\r\n', newline='') + >>> read_newlines(filename) + '\r\n' + >>> _ = filename.write_text('foo\r\nbar\nbing\r', newline='') + >>> read_newlines(filename) + ('\r', '\n', '\r\n') + """ + with open(filename) as fp: + fp.read(limit) + return fp.newlines diff --git a/libs/win/jaraco/text/layouts.py b/libs/win/jaraco/text/layouts.py new file mode 100644 index 00000000..9636f0f7 --- /dev/null +++ b/libs/win/jaraco/text/layouts.py @@ -0,0 +1,25 @@ +qwerty = "-=qwertyuiop[]asdfghjkl;'zxcvbnm,./_+QWERTYUIOP{}ASDFGHJKL:\"ZXCVBNM<>?" +dvorak = "[]',.pyfgcrl/=aoeuidhtns-;qjkxbmwvz{}\"<>PYFGCRL?+AOEUIDHTNS_:QJKXBMWVZ" + + +to_dvorak = str.maketrans(qwerty, dvorak) +to_qwerty = str.maketrans(dvorak, qwerty) + + +def translate(input, translation): + """ + >>> translate('dvorak', to_dvorak) + 'ekrpat' + >>> translate('qwerty', to_qwerty) + 'x,dokt' + """ + return input.translate(translation) + + +def _translate_stream(stream, translation): + """ + >>> import io + >>> _translate_stream(io.StringIO('foo'), to_dvorak) + urr + """ + print(translate(stream.read(), translation)) diff --git a/libs/win/jaraco/text/show-newlines.py b/libs/win/jaraco/text/show-newlines.py new file mode 100644 index 00000000..2ba32062 --- /dev/null +++ b/libs/win/jaraco/text/show-newlines.py @@ -0,0 +1,33 @@ +import autocommand +import inflect + +from more_itertools import always_iterable + +import jaraco.text + + +def report_newlines(filename): + r""" + Report the newlines in the indicated file.
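The line helpers above compose naturally. A hedged sketch (input invented for illustration) of reading a requirements-style text with ``yield_lines``, ``drop_comment``, and ``join_continuation``:

    from jaraco.text import drop_comment, join_continuation, yield_lines

    source = 'alpha\nbeta \\\n--option\n# a comment\ngamma # inline\n'
    # strip blanks and comment lines, drop inline comments, join continuations
    lines = join_continuation(map(drop_comment, yield_lines(source)))
    assert list(lines) == ['alpha', 'beta--option', 'gamma']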
+ + >>> tmp_path = getfixture('tmp_path') + >>> filename = tmp_path / 'out.txt' + >>> _ = filename.write_text('foo\nbar\n', newline='') + >>> report_newlines(filename) + newline is '\n' + >>> filename = tmp_path / 'out.txt' + >>> _ = filename.write_text('foo\nbar\r\n', newline='') + >>> report_newlines(filename) + newlines are ('\n', '\r\n') + """ + newlines = jaraco.text.read_newlines(filename) + count = len(tuple(always_iterable(newlines))) + engine = inflect.engine() + print( + engine.plural_noun("newline", count), + engine.plural_verb("is", count), + repr(newlines), + ) + + +autocommand.autocommand(__name__)(report_newlines) diff --git a/libs/win/jaraco/text/strip-prefix.py b/libs/win/jaraco/text/strip-prefix.py new file mode 100644 index 00000000..761717a9 --- /dev/null +++ b/libs/win/jaraco/text/strip-prefix.py @@ -0,0 +1,21 @@ +import sys + +import autocommand + +from jaraco.text import Stripper + + +def strip_prefix(): + r""" + Strip any common prefix from stdin. + + >>> import io, pytest + >>> getfixture('monkeypatch').setattr('sys.stdin', io.StringIO('abcdef\nabc123')) + >>> strip_prefix() + def + 123 + """ + sys.stdout.writelines(Stripper.strip_prefix(sys.stdin).lines) + + +autocommand.autocommand(__name__)(strip_prefix) diff --git a/libs/win/jaraco/text/to-dvorak.py b/libs/win/jaraco/text/to-dvorak.py new file mode 100644 index 00000000..a6d5da80 --- /dev/null +++ b/libs/win/jaraco/text/to-dvorak.py @@ -0,0 +1,6 @@ +import sys + +from . import layouts + + +__name__ == '__main__' and layouts._translate_stream(sys.stdin, layouts.to_dvorak) diff --git a/libs/win/jaraco/text/to-qwerty.py b/libs/win/jaraco/text/to-qwerty.py new file mode 100644 index 00000000..abe27286 --- /dev/null +++ b/libs/win/jaraco/text/to-qwerty.py @@ -0,0 +1,6 @@ +import sys + +from . import layouts + + +__name__ == '__main__' and layouts._translate_stream(sys.stdin, layouts.to_qwerty) diff --git a/libs/win/jaraco/ui/cmdline.py b/libs/win/jaraco/ui/cmdline.py index a7982ddb..8517af7a 100644 --- a/libs/win/jaraco/ui/cmdline.py +++ b/libs/win/jaraco/ui/cmdline.py @@ -1,77 +1,76 @@ import argparse -import six from jaraco.classes import meta -from jaraco import text +from jaraco import text # type: ignore -@six.add_metaclass(meta.LeafClassesMeta) -class Command(object): - """ - A general-purpose base class for creating commands for a command-line - program using argparse. Each subclass of Command represents a separate - sub-command of a program. +class Command(metaclass=meta.LeafClassesMeta): + """ + A general-purpose base class for creating commands for a command-line + program using argparse. Each subclass of Command represents a separate + sub-command of a program. - For example, one might use Command subclasses to implement the Mercurial - command set:: + For example, one might use Command subclasses to implement the Mercurial + command set:: - class Commit(Command): - @staticmethod - def add_arguments(cls, parser): - parser.add_argument('-m', '--message') + class Commit(Command): + @staticmethod + def add_arguments(cls, parser): + parser.add_argument('-m', '--message') - @classmethod - def run(cls, args): - "Run the 'commit' command with args (parsed)" + @classmethod + def run(cls, args): + "Run the 'commit' command with args (parsed)" - class Merge(Command): pass - class Pull(Command): pass - ... + class Merge(Command): pass + class Pull(Command): pass + ... 
- Then one could create an entry point for Mercurial like so:: + Then one could create an entry point for Mercurial like so:: - def hg_command(): - Command.invoke() - """ + def hg_command(): + Command.invoke() + """ - @classmethod - def add_subparsers(cls, parser): - subparsers = parser.add_subparsers() - [cmd_class.add_parser(subparsers) for cmd_class in cls._leaf_classes] + @classmethod + def add_subparsers(cls, parser): + subparsers = parser.add_subparsers() + [cmd_class.add_parser(subparsers) for cmd_class in cls._leaf_classes] - @classmethod - def add_parser(cls, subparsers): - cmd_string = text.words(cls.__name__).lowered().dash_separated() - parser = subparsers.add_parser(cmd_string) - parser.set_defaults(action=cls) - cls.add_arguments(parser) - return parser + @classmethod + def add_parser(cls, subparsers): + cmd_string = text.words(cls.__name__).lowered().dash_separated() + parser = subparsers.add_parser(cmd_string) + parser.set_defaults(action=cls) + cls.add_arguments(parser) + return parser - @classmethod - def add_arguments(cls, parser): - pass + @classmethod + def add_arguments(cls, parser): + pass - @classmethod - def invoke(cls): - """ - Invoke the command using ArgumentParser - """ - parser = argparse.ArgumentParser() - cls.add_subparsers(parser) - args = parser.parse_args() - args.action.run(args) + @classmethod + def invoke(cls): + """ + Invoke the command using ArgumentParser + """ + parser = argparse.ArgumentParser() + cls.add_subparsers(parser) + args = parser.parse_args() + args.action.run(args) class Extend(argparse.Action): - """ - Argparse action to take an nargs=* argument - and add any values to the existing value. + """ + Argparse action to take an nargs=* argument + and add any values to the existing value. - >>> parser = argparse.ArgumentParser() - >>> _ = parser.add_argument('--foo', nargs='*', default=[], action=Extend) - >>> args = parser.parse_args(['--foo', 'a=1', '--foo', 'b=2', 'c=3']) - >>> args.foo - ['a=1', 'b=2', 'c=3'] - """ - def __call__(self, parser, namespace, values, option_string=None): - getattr(namespace, self.dest).extend(values) + >>> parser = argparse.ArgumentParser() + >>> _ = parser.add_argument('--foo', nargs='*', default=[], action=Extend) + >>> args = parser.parse_args(['--foo', 'a=1', '--foo', 'b=2', 'c=3']) + >>> args.foo + ['a=1', 'b=2', 'c=3'] + """ + + def __call__(self, parser, namespace, values, option_string=None): + getattr(namespace, self.dest).extend(values) diff --git a/libs/win/jaraco/ui/editor.py b/libs/win/jaraco/ui/editor.py index b37c759d..c3102b64 100644 --- a/libs/win/jaraco/ui/editor.py +++ b/libs/win/jaraco/ui/editor.py @@ -1,5 +1,3 @@ -from __future__ import unicode_literals, absolute_import - import tempfile import os import sys @@ -9,100 +7,105 @@ import collections import io import difflib -import six +from typing import Mapping + + +class EditProcessException(RuntimeError): + pass -class EditProcessException(RuntimeError): pass class EditableFile(object): - """ - EditableFile saves some data to a temporary file, launches a - platform editor for interactive editing, and then reloads the data, - setting .changed to True if the data was edited. + """ + EditableFile saves some data to a temporary file, launches a + platform editor for interactive editing, and then reloads the data, + setting .changed to True if the data was edited. 
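To make the ``Command`` pattern from cmdline.py concrete, a self-contained sketch (class and argument names are illustrative, not from the patch); note that ``add_arguments`` is a classmethod in the actual base class:

    from jaraco.ui.cmdline import Command

    class Greet(Command):
        @classmethod
        def add_arguments(cls, parser):
            parser.add_argument('--name', default='world')

        @classmethod
        def run(cls, args):
            print('hello,', args.name)

    if __name__ == '__main__':
        # `prog greet --name you` prints 'hello, you'; the sub-command
        # name derives from the class name via words(...).dash_separated().
        Command.invoke()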
- e.g.:: + e.g.:: - x = EditableFile('foo') - x.edit() + x = EditableFile('foo') + x.edit() - if x.changed: - print(x.data) + if x.changed: + print(x.data) - The EDITOR environment variable can define which executable to use - (also XML_EDITOR if the content-type to edit includes 'xml'). If no - EDITOR is defined, defaults to 'notepad' on Windows and 'edit' on - other platforms. - """ - platform_default_editors = collections.defaultdict( - lambda: 'edit', - win32 = 'notepad', - linux2 = 'vi', - ) - encoding = 'utf-8' + The EDITOR environment variable can define which executable to use + (also XML_EDITOR if the content-type to edit includes 'xml'). If no + EDITOR is defined, defaults to 'notepad' on Windows and 'edit' on + other platforms. + """ - def __init__(self, data='', content_type='text/plain'): - self.data = six.text_type(data) - self.content_type = content_type + platform_default_editors: Mapping[str, str] = collections.defaultdict( + lambda: 'edit', + win32='notepad', + linux2='vi', + ) + encoding = 'utf-8' - def __enter__(self): - extension = mimetypes.guess_extension(self.content_type) or '' - fobj, self.name = tempfile.mkstemp(extension) - os.write(fobj, self.data.encode(self.encoding)) - os.close(fobj) - return self + def __init__(self, data='', content_type='text/plain'): + self.data = str(data) + self.content_type = content_type - def read(self): - with open(self.name, 'rb') as f: - return f.read().decode(self.encoding) + def __enter__(self): + extension = mimetypes.guess_extension(self.content_type) or '' + fobj, self.name = tempfile.mkstemp(extension) + os.write(fobj, self.data.encode(self.encoding)) + os.close(fobj) + return self - def __exit__(self, *tb_info): - os.remove(self.name) + def read(self): + with open(self.name, 'rb') as f: + return f.read().decode(self.encoding) - def edit(self): - """ - Edit the file - """ - self.changed = False - with self: - editor = self.get_editor() - cmd = [editor, self.name] - try: - res = subprocess.call(cmd) - except Exception as e: - print("Error launching editor %(editor)s" % locals()) - print(e) - return - if res != 0: - msg = '%(editor)s returned error status %(res)d' % locals() - raise EditProcessException(msg) - new_data = self.read() - if new_data != self.data: - self.changed = self._save_diff(self.data, new_data) - self.data = new_data + def __exit__(self, *tb_info): + os.remove(self.name) - @staticmethod - def _search_env(keys): - """ - Search the environment for the supplied keys, returning the first - one found or None if none was found. - """ - matches = (os.environ[key] for key in keys if key in os.environ) - return next(matches, None) + def edit(self): + """ + Edit the file + """ + self.changed = False + with self: + editor = self.get_editor() + cmd = [editor, self.name] + try: + res = subprocess.call(cmd) + except Exception as e: + print("Error launching editor %(editor)s" % locals()) + print(e) + return + if res != 0: + msg = '%(editor)s returned error status %(res)d' % locals() + raise EditProcessException(msg) + new_data = self.read() + if new_data != self.data: + self.changed = self._save_diff(self.data, new_data) + self.data = new_data - def get_editor(self): - """ - Give preference to an XML_EDITOR or EDITOR defined in the - environment. Otherwise use a default editor based on platform. 
- """ - env_search = ['EDITOR'] - if 'xml' in self.content_type: - env_search.insert(0, 'XML_EDITOR') - default_editor = self.platform_default_editors[sys.platform] - return self._search_env(env_search) or default_editor + @staticmethod + def _search_env(keys): + """ + Search the environment for the supplied keys, returning the first + one found or None if none was found. + """ + matches = (os.environ[key] for key in keys if key in os.environ) + return next(matches, None) - @staticmethod - def _save_diff(*versions): - def get_lines(content): - return list(io.StringIO(content)) - lines = map(get_lines, versions) - diff = difflib.context_diff(*lines) - return tuple(diff) + def get_editor(self): + """ + Give preference to an XML_EDITOR or EDITOR defined in the + environment. Otherwise use a default editor based on platform. + """ + env_search = ['EDITOR'] + if 'xml' in self.content_type: + env_search.insert(0, 'XML_EDITOR') + default_editor = self.platform_default_editors[sys.platform] + return self._search_env(env_search) or default_editor + + @staticmethod + def _save_diff(*versions): + def get_lines(content): + return list(io.StringIO(content)) + + lines = map(get_lines, versions) + diff = difflib.context_diff(*lines) + return tuple(diff) diff --git a/libs/win/jaraco/ui/input.py b/libs/win/jaraco/ui/input.py index 3d108fc0..c7b2c86a 100644 --- a/libs/win/jaraco/ui/input.py +++ b/libs/win/jaraco/ui/input.py @@ -3,24 +3,28 @@ This module currently provides a cross-platform getch function """ try: - # Windows - from msvcrt import getch + # Windows + from msvcrt import getch # type: ignore + + getch # workaround for https://github.com/kevinw/pyflakes/issues/13 except ImportError: - pass + pass try: - # Unix - import sys - import tty - import termios + # Unix + import sys + import tty + import termios + + def getch(): # type: ignore + fd = sys.stdin.fileno() + old = termios.tcgetattr(fd) + try: + tty.setraw(fd) + return sys.stdin.read(1) + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old) + - def getch(): - fd = sys.stdin.fileno() - old = termios.tcgetattr(fd) - try: - tty.setraw(fd) - return sys.stdin.read(1) - finally: - termios.tcsetattr(fd, termios.TCSADRAIN, old) except ImportError: - pass + pass diff --git a/libs/win/jaraco/ui/menu.py b/libs/win/jaraco/ui/menu.py index aede93b3..78687b3f 100644 --- a/libs/win/jaraco/ui/menu.py +++ b/libs/win/jaraco/ui/menu.py @@ -1,34 +1,32 @@ -from __future__ import print_function, absolute_import, unicode_literals - import itertools -import six class Menu(object): - """ - A simple command-line based menu - """ - def __init__(self, choices=None, formatter=str): - self.choices = choices or list() - self.formatter = formatter + """ + A simple command-line based menu + """ - def get_choice(self, prompt="> "): - n = len(self.choices) - number_width = len(str(n)) + 1 - menu_fmt = '{number:{number_width}}) {choice}' - formatted_choices = map(self.formatter, self.choices) - for number, choice in zip(itertools.count(1), formatted_choices): - print(menu_fmt.format(**locals())) - print() - try: - answer = int(six.moves.input(prompt)) - result = self.choices[answer - 1] - except ValueError: - print('invalid selection') - result = None - except IndexError: - print('invalid selection') - result = None - except KeyboardInterrupt: - result = None - return result + def __init__(self, choices=None, formatter=str): + self.choices = choices or list() + self.formatter = formatter + + def get_choice(self, prompt="> "): + n = len(self.choices) + number_width = 
len(str(n)) + 1 + menu_fmt = '{number:{number_width}}) {choice}' + formatted_choices = map(self.formatter, self.choices) + for number, choice in zip(itertools.count(1), formatted_choices): + print(menu_fmt.format(**locals())) + print() + try: + answer = int(input(prompt)) + result = self.choices[answer - 1] + except ValueError: + print('invalid selection') + result = None + except IndexError: + print('invalid selection') + result = None + except KeyboardInterrupt: + result = None + return result diff --git a/libs/win/jaraco/ui/progress.py b/libs/win/jaraco/ui/progress.py index d083310b..b57b1c85 100644 --- a/libs/win/jaraco/ui/progress.py +++ b/libs/win/jaraco/ui/progress.py @@ -1,152 +1,141 @@ # deprecated -- use TQDM -from __future__ import (print_function, absolute_import, unicode_literals, - division) - import time import sys import itertools import abc import datetime -import six +class AbstractProgressBar(metaclass=abc.ABCMeta): + def __init__(self, unit='', size=70): + """ + Size is the nominal size in characters + """ + self.unit = unit + self.size = size -@six.add_metaclass(abc.ABCMeta) -class AbstractProgressBar(object): - def __init__(self, unit='', size=70): - """ - Size is the nominal size in characters - """ - self.unit = unit - self.size = size + def report(self, amt): + sys.stdout.write('\r%s' % self.get_bar(amt)) + sys.stdout.flush() - def report(self, amt): - sys.stdout.write('\r%s' % self.get_bar(amt)) - sys.stdout.flush() + @abc.abstractmethod + def get_bar(self, amt): + "Return the string to be printed. Should be size >= self.size" - @abc.abstractmethod - def get_bar(self, amt): - "Return the string to be printed. Should be size >= self.size" + def summary(self, str): + return ' (' + self.unit_str(str) + ')' - def summary(self, str): - return ' (' + self.unit_str(str) + ')' + def unit_str(self, str): + if self.unit: + str += ' ' + self.unit + return str - def unit_str(self, str): - if self.unit: - str += ' ' + self.unit - return str + def finish(self): + print() - def finish(self): - print() + def __enter__(self): + self.report(0) + return self - def __enter__(self): - self.report(0) - return self + def __exit__(self, exc, exc_val, tb): + if exc is None: + self.finish() + else: + print() - def __exit__(self, exc, exc_val, tb): - if exc is None: - self.finish() - else: - print() - - def iterate(self, iterable): - """ - Report the status as the iterable is consumed. - """ - with self: - for n, item in enumerate(iterable, 1): - self.report(n) - yield item + def iterate(self, iterable): + """ + Report the status as the iterable is consumed. 
+ """ + with self: + for n, item in enumerate(iterable, 1): + self.report(n) + yield item class SimpleProgressBar(AbstractProgressBar): - _PROG_DISPGLYPH = itertools.cycle(['|', '/', '-', '\\']) + _PROG_DISPGLYPH = itertools.cycle(['|', '/', '-', '\\']) - def get_bar(self, amt): - bar = next(self._PROG_DISPGLYPH) - template = ' [{bar:^{bar_len}}]' - summary = self.summary('{amt}') - template += summary - empty = template.format( - bar='', - bar_len=0, - amt=amt, - ) - bar_len = self.size - len(empty) - return template.format(**locals()) + def get_bar(self, amt): + bar = next(self._PROG_DISPGLYPH) + template = ' [{bar:^{bar_len}}]' + summary = self.summary('{amt}') + template += summary + empty = template.format( + bar='', + bar_len=0, + amt=amt, + ) + bar_len = self.size - len(empty) + return template.format(**locals()) - @classmethod - def demo(cls): - bar3 = cls(unit='cubes', size=30) - with bar3: - for x in six.moves.range(1, 759): - bar3.report(x) - time.sleep(0.01) + @classmethod + def demo(cls): + bar3 = cls(unit='cubes', size=30) + with bar3: + for x in range(1, 759): + bar3.report(x) + time.sleep(0.01) class TargetProgressBar(AbstractProgressBar): - def __init__(self, total=None, unit='', size=70): - """ - Size is the nominal size in characters - """ - self.total = total - super(TargetProgressBar, self).__init__(unit, size) + def __init__(self, total=None, unit='', size=70): + """ + Size is the nominal size in characters + """ + self.total = total + super(TargetProgressBar, self).__init__(unit, size) - def get_bar(self, amt): - template = ' [{bar:<{bar_len}}]' - completed = amt / self.total - percent = int(completed * 100) - percent_str = ' {percent:3}%' - template += percent_str - summary = self.summary('{amt}/{total}') - template += summary - empty = template.format( - total=self.total, - bar='', - bar_len=0, - **locals() - ) - bar_len = self.size - len(empty) - bar = '=' * int(completed * bar_len) - return template.format(total=self.total, **locals()) + def get_bar(self, amt): + template = ' [{bar:<{bar_len}}]' + completed = amt / self.total + percent = int(completed * 100) + percent_str = ' {percent:3}%' + template += percent_str + summary = self.summary('{amt}/{total}') + template += summary + empty = template.format(total=self.total, bar='', bar_len=0, **locals()) + bar_len = self.size - len(empty) + bar = '=' * int(completed * bar_len) + return template.format(total=self.total, **locals()) - @classmethod - def demo(cls): - bar1 = cls(100, 'blocks') - with bar1: - for x in six.moves.range(1, 101): - bar1.report(x) - time.sleep(0.05) + @classmethod + def demo(cls): + bar1 = cls(100, 'blocks') + with bar1: + for x in range(1, 101): + bar1.report(x) + time.sleep(0.05) - bar2 = cls(758, size=50) - with bar2: - for x in six.moves.range(1, 759): - bar2.report(x) - time.sleep(0.01) + bar2 = cls(758, size=50) + with bar2: + for x in range(1, 759): + bar2.report(x) + time.sleep(0.01) - def finish(self): - self.report(self.total) - super(TargetProgressBar, self).finish() + def finish(self): + self.report(self.total) + super(TargetProgressBar, self).finish() def countdown(template, duration=datetime.timedelta(seconds=5)): - """ - Do a countdown for duration, printing the template (which may accept one - positional argument). 
Template should be something like - ``countdown complete in {} seconds.`` - """ - now = datetime.datetime.now() - deadline = now + duration - remaining = deadline - datetime.datetime.now() - while remaining: - remaining = deadline - datetime.datetime.now() - remaining = max(datetime.timedelta(), remaining) - msg = template.format(remaining.total_seconds()) - print(msg, end=' '*10) - sys.stdout.flush() - time.sleep(.1) - print('\b'*80, end='') - sys.stdout.flush() - print() + """ + Do a countdown for duration, printing the template (which may accept one + positional argument). Template should be something like + ``countdown complete in {} seconds.`` + """ + now = datetime.datetime.now() + deadline = now + duration + remaining = deadline - datetime.datetime.now() + while remaining: + remaining = deadline - datetime.datetime.now() + remaining = max(datetime.timedelta(), remaining) + msg = template.format(remaining.total_seconds()) + print(msg, end=' ' * 10) + sys.stdout.flush() + time.sleep(0.1) + print('\b' * 80, end='') + sys.stdout.flush() + print() diff --git a/libs/win/jaraco/windows/api/clipboard.py b/libs/win/jaraco/windows/api/clipboard.py index d871aaa9..80a751f8 100644 --- a/libs/win/jaraco/windows/api/clipboard.py +++ b/libs/win/jaraco/windows/api/clipboard.py @@ -30,16 +30,16 @@ CF_GDIOBJFIRST = 0x0300 CF_GDIOBJLAST = 0x03FF RegisterClipboardFormat = ctypes.windll.user32.RegisterClipboardFormatW -RegisterClipboardFormat.argtypes = ctypes.wintypes.LPWSTR, +RegisterClipboardFormat.argtypes = (ctypes.wintypes.LPWSTR,) RegisterClipboardFormat.restype = ctypes.wintypes.UINT CF_HTML = RegisterClipboardFormat('HTML Format') EnumClipboardFormats = ctypes.windll.user32.EnumClipboardFormats -EnumClipboardFormats.argtypes = ctypes.wintypes.UINT, +EnumClipboardFormats.argtypes = (ctypes.wintypes.UINT,) EnumClipboardFormats.restype = ctypes.wintypes.UINT GetClipboardData = ctypes.windll.user32.GetClipboardData -GetClipboardData.argtypes = ctypes.wintypes.UINT, +GetClipboardData.argtypes = (ctypes.wintypes.UINT,) GetClipboardData.restype = ctypes.wintypes.HANDLE SetClipboardData = ctypes.windll.user32.SetClipboardData @@ -47,7 +47,7 @@ SetClipboardData.argtypes = ctypes.wintypes.UINT, ctypes.wintypes.HANDLE SetClipboardData.restype = ctypes.wintypes.HANDLE OpenClipboard = ctypes.windll.user32.OpenClipboard -OpenClipboard.argtypes = ctypes.wintypes.HANDLE, +OpenClipboard.argtypes = (ctypes.wintypes.HANDLE,) OpenClipboard.restype = ctypes.wintypes.BOOL ctypes.windll.user32.CloseClipboard.restype = ctypes.wintypes.BOOL diff --git a/libs/win/jaraco/windows/api/credential.py b/libs/win/jaraco/windows/api/credential.py index 003c3cb3..db1deb9e 100644 --- a/libs/win/jaraco/windows/api/credential.py +++ b/libs/win/jaraco/windows/api/credential.py @@ -5,58 +5,52 @@ Support for Credential Vault import ctypes from ctypes.wintypes import DWORD, LPCWSTR, BOOL, LPWSTR, FILETIME + try: - from ctypes.wintypes import LPBYTE + from ctypes.wintypes import LPBYTE except ImportError: - LPBYTE = ctypes.POINTER(ctypes.wintypes.BYTE) + LPBYTE = ctypes.POINTER(ctypes.wintypes.BYTE) # type: ignore class CredentialAttribute(ctypes.Structure): - _fields_ = [] + _fields_ = [] # type: ignore class Credential(ctypes.Structure): - _fields_ = [ - ('flags', DWORD), - ('type', DWORD), - ('target_name', LPWSTR), - ('comment', LPWSTR), - ('last_written', FILETIME), - ('credential_blob_size', DWORD), - ('credential_blob', LPBYTE), - ('persist', DWORD), - ('attribute_count', DWORD), - ('attributes', 
ctypes.POINTER(CredentialAttribute)), - ('target_alias', LPWSTR), - ('user_name', LPWSTR), - ] + _fields_ = [ + ('flags', DWORD), + ('type', DWORD), + ('target_name', LPWSTR), + ('comment', LPWSTR), + ('last_written', FILETIME), + ('credential_blob_size', DWORD), + ('credential_blob', LPBYTE), + ('persist', DWORD), + ('attribute_count', DWORD), + ('attributes', ctypes.POINTER(CredentialAttribute)), + ('target_alias', LPWSTR), + ('user_name', LPWSTR), + ] - def __del__(self): - ctypes.windll.advapi32.CredFree(ctypes.byref(self)) + def __del__(self): + ctypes.windll.advapi32.CredFree(ctypes.byref(self)) PCREDENTIAL = ctypes.POINTER(Credential) CredRead = ctypes.windll.advapi32.CredReadW CredRead.argtypes = ( - LPCWSTR, # TargetName - DWORD, # Type - DWORD, # Flags - ctypes.POINTER(PCREDENTIAL), # Credential + LPCWSTR, # TargetName + DWORD, # Type + DWORD, # Flags + ctypes.POINTER(PCREDENTIAL), # Credential ) CredRead.restype = BOOL CredWrite = ctypes.windll.advapi32.CredWriteW -CredWrite.argtypes = ( - PCREDENTIAL, # Credential - DWORD, # Flags -) +CredWrite.argtypes = (PCREDENTIAL, DWORD) # Credential # Flags CredWrite.restype = BOOL CredDelete = ctypes.windll.advapi32.CredDeleteW -CredDelete.argtypes = ( - LPCWSTR, # TargetName - DWORD, # Type - DWORD, # Flags -) +CredDelete.argtypes = (LPCWSTR, DWORD, DWORD) # TargetName # Type # Flags CredDelete.restype = BOOL diff --git a/libs/win/jaraco/windows/api/environ.py b/libs/win/jaraco/windows/api/environ.py index f394da02..a7882364 100644 --- a/libs/win/jaraco/windows/api/environ.py +++ b/libs/win/jaraco/windows/api/environ.py @@ -7,7 +7,7 @@ SetEnvironmentVariable.argtypes = [ctypes.wintypes.LPCWSTR] * 2 GetEnvironmentVariable = ctypes.windll.kernel32.GetEnvironmentVariableW GetEnvironmentVariable.restype = ctypes.wintypes.BOOL GetEnvironmentVariable.argtypes = [ - ctypes.wintypes.LPCWSTR, - ctypes.wintypes.LPWSTR, - ctypes.wintypes.DWORD, + ctypes.wintypes.LPCWSTR, + ctypes.wintypes.LPWSTR, + ctypes.wintypes.DWORD, ] diff --git a/libs/win/jaraco/windows/api/event.py b/libs/win/jaraco/windows/api/event.py index 5d2818c6..5b5d9f9d 100644 --- a/libs/win/jaraco/windows/api/event.py +++ b/libs/win/jaraco/windows/api/event.py @@ -1,19 +1,12 @@ from ctypes import windll, POINTER -from ctypes.wintypes import ( - LPWSTR, DWORD, LPVOID, HANDLE, BOOL, -) +from ctypes.wintypes import LPWSTR, DWORD, LPVOID, HANDLE, BOOL CreateEvent = windll.kernel32.CreateEventW -CreateEvent.argtypes = ( - LPVOID, # LPSECURITY_ATTRIBUTES - BOOL, - BOOL, - LPWSTR, -) +CreateEvent.argtypes = (LPVOID, BOOL, BOOL, LPWSTR) # LPSECURITY_ATTRIBUTES CreateEvent.restype = HANDLE SetEvent = windll.kernel32.SetEvent -SetEvent.argtypes = HANDLE, +SetEvent.argtypes = (HANDLE,) SetEvent.restype = BOOL WaitForSingleObject = windll.kernel32.WaitForSingleObject @@ -26,11 +19,11 @@ _WaitForMultipleObjects.restype = DWORD def WaitForMultipleObjects(handles, wait_all=False, timeout=0): - n_handles = len(handles) - handle_array = (HANDLE * n_handles)() - for index, handle in enumerate(handles): - handle_array[index] = handle - return _WaitForMultipleObjects(n_handles, handle_array, wait_all, timeout) + n_handles = len(handles) + handle_array = (HANDLE * n_handles)() + for index, handle in enumerate(handles): + handle_array[index] = handle + return _WaitForMultipleObjects(n_handles, handle_array, wait_all, timeout) WAIT_OBJECT_0 = 0 diff --git a/libs/win/jaraco/windows/api/filesystem.py b/libs/win/jaraco/windows/api/filesystem.py index fbd999de..b06dc6d2 100644 --- 
a/libs/win/jaraco/windows/api/filesystem.py +++ b/libs/win/jaraco/windows/api/filesystem.py @@ -2,22 +2,24 @@ import ctypes.wintypes CreateSymbolicLink = ctypes.windll.kernel32.CreateSymbolicLinkW CreateSymbolicLink.argtypes = ( - ctypes.wintypes.LPWSTR, - ctypes.wintypes.LPWSTR, - ctypes.wintypes.DWORD, + ctypes.wintypes.LPWSTR, + ctypes.wintypes.LPWSTR, + ctypes.wintypes.DWORD, ) CreateSymbolicLink.restype = ctypes.wintypes.BOOLEAN +SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE = 0x2 + CreateHardLink = ctypes.windll.kernel32.CreateHardLinkW CreateHardLink.argtypes = ( - ctypes.wintypes.LPWSTR, - ctypes.wintypes.LPWSTR, - ctypes.wintypes.LPVOID, # reserved for LPSECURITY_ATTRIBUTES + ctypes.wintypes.LPWSTR, + ctypes.wintypes.LPWSTR, + ctypes.wintypes.LPVOID, # reserved for LPSECURITY_ATTRIBUTES ) CreateHardLink.restype = ctypes.wintypes.BOOLEAN GetFileAttributes = ctypes.windll.kernel32.GetFileAttributesW -GetFileAttributes.argtypes = ctypes.wintypes.LPWSTR, +GetFileAttributes.argtypes = (ctypes.wintypes.LPWSTR,) GetFileAttributes.restype = ctypes.wintypes.DWORD SetFileAttributes = ctypes.windll.kernel32.SetFileAttributesW @@ -28,31 +30,33 @@ MAX_PATH = 260 GetFinalPathNameByHandle = ctypes.windll.kernel32.GetFinalPathNameByHandleW GetFinalPathNameByHandle.argtypes = ( - ctypes.wintypes.HANDLE, ctypes.wintypes.LPWSTR, ctypes.wintypes.DWORD, - ctypes.wintypes.DWORD, + ctypes.wintypes.HANDLE, + ctypes.wintypes.LPWSTR, + ctypes.wintypes.DWORD, + ctypes.wintypes.DWORD, ) GetFinalPathNameByHandle.restype = ctypes.wintypes.DWORD class SECURITY_ATTRIBUTES(ctypes.Structure): - _fields_ = ( - ('length', ctypes.wintypes.DWORD), - ('p_security_descriptor', ctypes.wintypes.LPVOID), - ('inherit_handle', ctypes.wintypes.BOOLEAN), - ) + _fields_ = ( + ('length', ctypes.wintypes.DWORD), + ('p_security_descriptor', ctypes.wintypes.LPVOID), + ('inherit_handle', ctypes.wintypes.BOOLEAN), + ) LPSECURITY_ATTRIBUTES = ctypes.POINTER(SECURITY_ATTRIBUTES) CreateFile = ctypes.windll.kernel32.CreateFileW CreateFile.argtypes = ( - ctypes.wintypes.LPWSTR, - ctypes.wintypes.DWORD, - ctypes.wintypes.DWORD, - LPSECURITY_ATTRIBUTES, - ctypes.wintypes.DWORD, - ctypes.wintypes.DWORD, - ctypes.wintypes.HANDLE, + ctypes.wintypes.LPWSTR, + ctypes.wintypes.DWORD, + ctypes.wintypes.DWORD, + LPSECURITY_ATTRIBUTES, + ctypes.wintypes.DWORD, + ctypes.wintypes.DWORD, + ctypes.wintypes.HANDLE, ) CreateFile.restype = ctypes.wintypes.HANDLE FILE_SHARE_READ = 1 @@ -83,56 +87,56 @@ CloseHandle.restype = ctypes.wintypes.BOOLEAN class WIN32_FIND_DATA(ctypes.wintypes.WIN32_FIND_DATAW): - """ - _fields_ = [ - ("dwFileAttributes", DWORD), - ("ftCreationTime", FILETIME), - ("ftLastAccessTime", FILETIME), - ("ftLastWriteTime", FILETIME), - ("nFileSizeHigh", DWORD), - ("nFileSizeLow", DWORD), - ("dwReserved0", DWORD), - ("dwReserved1", DWORD), - ("cFileName", WCHAR * MAX_PATH), - ("cAlternateFileName", WCHAR * 14)] - ] - """ + """ + _fields_ = [ + ("dwFileAttributes", DWORD), + ("ftCreationTime", FILETIME), + ("ftLastAccessTime", FILETIME), + ("ftLastWriteTime", FILETIME), + ("nFileSizeHigh", DWORD), + ("nFileSizeLow", DWORD), + ("dwReserved0", DWORD), + ("dwReserved1", DWORD), + ("cFileName", WCHAR * MAX_PATH), + ("cAlternateFileName", WCHAR * 14)] + ] + """ - @property - def file_attributes(self): - return self.dwFileAttributes + @property + def file_attributes(self): + return self.dwFileAttributes - @property - def creation_time(self): - return self.ftCreationTime + @property + def creation_time(self): + return self.ftCreationTime - @property 
- def last_access_time(self): - return self.ftLastAccessTime + @property + def last_access_time(self): + return self.ftLastAccessTime - @property - def last_write_time(self): - return self.ftLastWriteTime + @property + def last_write_time(self): + return self.ftLastWriteTime - @property - def file_size_words(self): - return [self.nFileSizeHigh, self.nFileSizeLow] + @property + def file_size_words(self): + return [self.nFileSizeHigh, self.nFileSizeLow] - @property - def reserved(self): - return [self.dwReserved0, self.dwReserved1] + @property + def reserved(self): + return [self.dwReserved0, self.dwReserved1] - @property - def filename(self): - return self.cFileName + @property + def filename(self): + return self.cFileName - @property - def alternate_filename(self): - return self.cAlternateFileName + @property + def alternate_filename(self): + return self.cAlternateFileName - @property - def file_size(self): - return self.nFileSizeHigh << 32 + self.nFileSizeLow + @property + def file_size(self): + return self.nFileSizeHigh << 32 + self.nFileSizeLow LPWIN32_FIND_DATA = ctypes.POINTER(ctypes.wintypes.WIN32_FIND_DATAW) @@ -144,7 +148,7 @@ FindNextFile = ctypes.windll.kernel32.FindNextFileW FindNextFile.argtypes = (ctypes.wintypes.HANDLE, LPWIN32_FIND_DATA) FindNextFile.restype = ctypes.wintypes.BOOLEAN -ctypes.windll.kernel32.FindClose.argtypes = ctypes.wintypes.HANDLE, +ctypes.windll.kernel32.FindClose.argtypes = (ctypes.wintypes.HANDLE,) SCS_32BIT_BINARY = 0 # A 32-bit Windows-based application SCS_64BIT_BINARY = 6 # A 64-bit Windows-based application @@ -156,7 +160,8 @@ SCS_WOW_BINARY = 2 # A 16-bit Windows-based application _GetBinaryType = ctypes.windll.kernel32.GetBinaryTypeW _GetBinaryType.argtypes = ( - ctypes.wintypes.LPWSTR, ctypes.POINTER(ctypes.wintypes.DWORD), + ctypes.wintypes.LPWSTR, + ctypes.POINTER(ctypes.wintypes.DWORD), ) _GetBinaryType.restype = ctypes.wintypes.BOOL @@ -164,47 +169,47 @@ FILEOP_FLAGS = ctypes.wintypes.WORD class BY_HANDLE_FILE_INFORMATION(ctypes.Structure): - _fields_ = [ - ('file_attributes', ctypes.wintypes.DWORD), - ('creation_time', ctypes.wintypes.FILETIME), - ('last_access_time', ctypes.wintypes.FILETIME), - ('last_write_time', ctypes.wintypes.FILETIME), - ('volume_serial_number', ctypes.wintypes.DWORD), - ('file_size_high', ctypes.wintypes.DWORD), - ('file_size_low', ctypes.wintypes.DWORD), - ('number_of_links', ctypes.wintypes.DWORD), - ('file_index_high', ctypes.wintypes.DWORD), - ('file_index_low', ctypes.wintypes.DWORD), - ] + _fields_ = [ + ('file_attributes', ctypes.wintypes.DWORD), + ('creation_time', ctypes.wintypes.FILETIME), + ('last_access_time', ctypes.wintypes.FILETIME), + ('last_write_time', ctypes.wintypes.FILETIME), + ('volume_serial_number', ctypes.wintypes.DWORD), + ('file_size_high', ctypes.wintypes.DWORD), + ('file_size_low', ctypes.wintypes.DWORD), + ('number_of_links', ctypes.wintypes.DWORD), + ('file_index_high', ctypes.wintypes.DWORD), + ('file_index_low', ctypes.wintypes.DWORD), + ] - @property - def file_size(self): - return (self.file_size_high << 32) + self.file_size_low + @property + def file_size(self): + return (self.file_size_high << 32) + self.file_size_low - @property - def file_index(self): - return (self.file_index_high << 32) + self.file_index_low + @property + def file_index(self): + return (self.file_index_high << 32) + self.file_index_low GetFileInformationByHandle = ctypes.windll.kernel32.GetFileInformationByHandle GetFileInformationByHandle.restype = ctypes.wintypes.BOOL GetFileInformationByHandle.argtypes = 
( - ctypes.wintypes.HANDLE, - ctypes.POINTER(BY_HANDLE_FILE_INFORMATION), + ctypes.wintypes.HANDLE, + ctypes.POINTER(BY_HANDLE_FILE_INFORMATION), ) class SHFILEOPSTRUCT(ctypes.Structure): - _fields_ = [ - ('status_dialog', ctypes.wintypes.HWND), - ('operation', ctypes.wintypes.UINT), - ('from_', ctypes.wintypes.LPWSTR), - ('to', ctypes.wintypes.LPWSTR), - ('flags', FILEOP_FLAGS), - ('operations_aborted', ctypes.wintypes.BOOL), - ('name_mapping_handles', ctypes.wintypes.LPVOID), - ('progress_title', ctypes.wintypes.LPWSTR), - ] + _fields_ = [ + ('status_dialog', ctypes.wintypes.HWND), + ('operation', ctypes.wintypes.UINT), + ('from_', ctypes.wintypes.LPWSTR), + ('to', ctypes.wintypes.LPWSTR), + ('flags', FILEOP_FLAGS), + ('operations_aborted', ctypes.wintypes.BOOL), + ('name_mapping_handles', ctypes.wintypes.LPVOID), + ('progress_title', ctypes.wintypes.LPWSTR), + ] _SHFileOperation = ctypes.windll.shell32.SHFileOperationW @@ -218,12 +223,12 @@ FO_DELETE = 3 ReplaceFile = ctypes.windll.kernel32.ReplaceFileW ReplaceFile.restype = ctypes.wintypes.BOOL ReplaceFile.argtypes = [ - ctypes.wintypes.LPWSTR, - ctypes.wintypes.LPWSTR, - ctypes.wintypes.LPWSTR, - ctypes.wintypes.DWORD, - ctypes.wintypes.LPVOID, - ctypes.wintypes.LPVOID, + ctypes.wintypes.LPWSTR, + ctypes.wintypes.LPWSTR, + ctypes.wintypes.LPWSTR, + ctypes.wintypes.DWORD, + ctypes.wintypes.LPVOID, + ctypes.wintypes.LPVOID, ] REPLACEFILE_WRITE_THROUGH = 0x1 @@ -232,20 +237,20 @@ REPLACEFILE_IGNORE_ACL_ERRORS = 0x4 class STAT_STRUCT(ctypes.Structure): - _fields_ = [ - ('dev', ctypes.c_uint), - ('ino', ctypes.c_ushort), - ('mode', ctypes.c_ushort), - ('nlink', ctypes.c_short), - ('uid', ctypes.c_short), - ('gid', ctypes.c_short), - ('rdev', ctypes.c_uint), - # the following 4 fields are ctypes.c_uint64 for _stat64 - ('size', ctypes.c_uint), - ('atime', ctypes.c_uint), - ('mtime', ctypes.c_uint), - ('ctime', ctypes.c_uint), - ] + _fields_ = [ + ('dev', ctypes.c_uint), + ('ino', ctypes.c_ushort), + ('mode', ctypes.c_ushort), + ('nlink', ctypes.c_short), + ('uid', ctypes.c_short), + ('gid', ctypes.c_short), + ('rdev', ctypes.c_uint), + # the following 4 fields are ctypes.c_uint64 for _stat64 + ('size', ctypes.c_uint), + ('atime', ctypes.c_uint), + ('mtime', ctypes.c_uint), + ('ctime', ctypes.c_uint), + ] _wstat = ctypes.windll.msvcrt._wstat @@ -254,64 +259,64 @@ _wstat.restype = ctypes.c_int FILE_NOTIFY_CHANGE_LAST_WRITE = 0x10 -FindFirstChangeNotification = ( - ctypes.windll.kernel32.FindFirstChangeNotificationW) +FindFirstChangeNotification = ctypes.windll.kernel32.FindFirstChangeNotificationW FindFirstChangeNotification.argtypes = ( - ctypes.wintypes.LPWSTR, ctypes.wintypes.BOOL, ctypes.wintypes.DWORD, + ctypes.wintypes.LPWSTR, + ctypes.wintypes.BOOL, + ctypes.wintypes.DWORD, ) FindFirstChangeNotification.restype = ctypes.wintypes.HANDLE -FindCloseChangeNotification = ( - ctypes.windll.kernel32.FindCloseChangeNotification) -FindCloseChangeNotification.argtypes = ctypes.wintypes.HANDLE, +FindCloseChangeNotification = ctypes.windll.kernel32.FindCloseChangeNotification +FindCloseChangeNotification.argtypes = (ctypes.wintypes.HANDLE,) FindCloseChangeNotification.restype = ctypes.wintypes.BOOL FindNextChangeNotification = ctypes.windll.kernel32.FindNextChangeNotification -FindNextChangeNotification.argtypes = ctypes.wintypes.HANDLE, +FindNextChangeNotification.argtypes = (ctypes.wintypes.HANDLE,) FindNextChangeNotification.restype = ctypes.wintypes.BOOL FILE_FLAG_OPEN_REPARSE_POINT = 0x00200000 IO_REPARSE_TAG_SYMLINK = 0xA000000C 
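The reparse-point pieces of this hunk (FILE_FLAG_OPEN_REPARSE_POINT above, plus FSCTL_GET_REPARSE_POINT, DeviceIoControl, and REPARSE_DATA_BUFFER just below) are everything needed to resolve a symlink target by hand. A minimal sketch, assuming the standard Win32 values OPEN_EXISTING, FILE_FLAG_BACKUP_SEMANTICS, and MAXIMUM_REPARSE_DATA_BUFFER_SIZE (defined locally because this hunk does not show the module exporting them), with error handling trimmed to the essentials:

import ctypes
from ctypes import wintypes

from jaraco.windows.api import filesystem as fs

OPEN_EXISTING = 3  # standard Win32 value, assumed not exported by this module
FILE_FLAG_BACKUP_SEMANTICS = 0x02000000  # required to open directories and links
MAXIMUM_REPARSE_DATA_BUFFER_SIZE = 0x4000


def read_symlink(path):
    # Open the link itself rather than the file it points to.
    # (Invalid-handle checking is elided for brevity.)
    handle = fs.CreateFile(
        path,
        0,
        0,
        None,
        OPEN_EXISTING,
        fs.FILE_FLAG_OPEN_REPARSE_POINT | FILE_FLAG_BACKUP_SEMANTICS,
        None,
    )
    try:
        buf = ctypes.create_string_buffer(MAXIMUM_REPARSE_DATA_BUFFER_SIZE)
        returned = wintypes.DWORD()
        res = fs.DeviceIoControl(
            handle,
            fs.FSCTL_GET_REPARSE_POINT,
            None,
            0,
            buf,
            len(buf),
            ctypes.byref(returned),
            None,
        )
        if not res:
            raise ctypes.WinError()
        rdb = ctypes.cast(buf, ctypes.POINTER(fs.REPARSE_DATA_BUFFER)).contents
        # The substitute name usually carries an NT-style '\??\' prefix.
        return rdb.get_substitute_name()
    finally:
        fs.CloseHandle(handle)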
-FSCTL_GET_REPARSE_POINT = 0x900a8 +FSCTL_GET_REPARSE_POINT = 0x900A8 LPDWORD = ctypes.POINTER(ctypes.wintypes.DWORD) LPOVERLAPPED = ctypes.wintypes.LPVOID DeviceIoControl = ctypes.windll.kernel32.DeviceIoControl DeviceIoControl.argtypes = [ - ctypes.wintypes.HANDLE, - ctypes.wintypes.DWORD, - ctypes.wintypes.LPVOID, - ctypes.wintypes.DWORD, - ctypes.wintypes.LPVOID, - ctypes.wintypes.DWORD, - LPDWORD, - LPOVERLAPPED, + ctypes.wintypes.HANDLE, + ctypes.wintypes.DWORD, + ctypes.wintypes.LPVOID, + ctypes.wintypes.DWORD, + ctypes.wintypes.LPVOID, + ctypes.wintypes.DWORD, + LPDWORD, + LPOVERLAPPED, ] DeviceIoControl.restype = ctypes.wintypes.BOOL class REPARSE_DATA_BUFFER(ctypes.Structure): - _fields_ = [ - ('tag', ctypes.c_ulong), - ('data_length', ctypes.c_ushort), - ('reserved', ctypes.c_ushort), - ('substitute_name_offset', ctypes.c_ushort), - ('substitute_name_length', ctypes.c_ushort), - ('print_name_offset', ctypes.c_ushort), - ('print_name_length', ctypes.c_ushort), - ('flags', ctypes.c_ulong), - ('path_buffer', ctypes.c_byte * 1), - ] + _fields_ = [ + ('tag', ctypes.c_ulong), + ('data_length', ctypes.c_ushort), + ('reserved', ctypes.c_ushort), + ('substitute_name_offset', ctypes.c_ushort), + ('substitute_name_length', ctypes.c_ushort), + ('print_name_offset', ctypes.c_ushort), + ('print_name_length', ctypes.c_ushort), + ('flags', ctypes.c_ulong), + ('path_buffer', ctypes.c_byte * 1), + ] - def get_print_name(self): - wchar_size = ctypes.sizeof(ctypes.wintypes.WCHAR) - arr_typ = ctypes.wintypes.WCHAR * (self.print_name_length // wchar_size) - data = ctypes.byref(self.path_buffer, self.print_name_offset) - return ctypes.cast(data, ctypes.POINTER(arr_typ)).contents.value + def get_print_name(self): + wchar_size = ctypes.sizeof(ctypes.wintypes.WCHAR) + arr_typ = ctypes.wintypes.WCHAR * (self.print_name_length // wchar_size) + data = ctypes.byref(self.path_buffer, self.print_name_offset) + return ctypes.cast(data, ctypes.POINTER(arr_typ)).contents.value - def get_substitute_name(self): - wchar_size = ctypes.sizeof(ctypes.wintypes.WCHAR) - arr_typ = ctypes.wintypes.WCHAR * (self.substitute_name_length // wchar_size) - data = ctypes.byref(self.path_buffer, self.substitute_name_offset) - return ctypes.cast(data, ctypes.POINTER(arr_typ)).contents.value + def get_substitute_name(self): + wchar_size = ctypes.sizeof(ctypes.wintypes.WCHAR) + arr_typ = ctypes.wintypes.WCHAR * (self.substitute_name_length // wchar_size) + data = ctypes.byref(self.path_buffer, self.substitute_name_offset) + return ctypes.cast(data, ctypes.POINTER(arr_typ)).contents.value diff --git a/libs/win/jaraco/windows/api/inet.py b/libs/win/jaraco/windows/api/inet.py index 36c8e37c..f144b0f3 100644 --- a/libs/win/jaraco/windows/api/inet.py +++ b/libs/win/jaraco/windows/api/inet.py @@ -4,11 +4,11 @@ from ctypes.wintypes import DWORD, WCHAR, BYTE, BOOL # from mprapi.h -MAX_INTERFACE_NAME_LEN = 2**8 +MAX_INTERFACE_NAME_LEN = 2 ** 8 # from iprtrmib.h -MAXLEN_PHYSADDR = 2**3 -MAXLEN_IFDESCR = 2**8 +MAXLEN_PHYSADDR = 2 ** 3 +MAXLEN_IFDESCR = 2 ** 8 # from iptypes.h MAX_ADAPTER_ADDRESS_LENGTH = 8 @@ -16,114 +16,102 @@ MAX_DHCPV6_DUID_LENGTH = 130 class MIB_IFROW(ctypes.Structure): - _fields_ = ( - ('name', WCHAR * MAX_INTERFACE_NAME_LEN), - ('index', DWORD), - ('type', DWORD), - ('MTU', DWORD), - ('speed', DWORD), - ('physical_address_length', DWORD), - ('physical_address_raw', BYTE * MAXLEN_PHYSADDR), - ('admin_status', DWORD), - ('operational_status', DWORD), - ('last_change', DWORD), - ('octets_received', DWORD), - 
('unicast_packets_received', DWORD), - ('non_unicast_packets_received', DWORD), - ('incoming_discards', DWORD), - ('incoming_errors', DWORD), - ('incoming_unknown_protocols', DWORD), - ('octets_sent', DWORD), - ('unicast_packets_sent', DWORD), - ('non_unicast_packets_sent', DWORD), - ('outgoing_discards', DWORD), - ('outgoing_errors', DWORD), - ('outgoing_queue_length', DWORD), - ('description_length', DWORD), - ('description_raw', ctypes.c_char * MAXLEN_IFDESCR), - ) + _fields_ = ( + ('name', WCHAR * MAX_INTERFACE_NAME_LEN), + ('index', DWORD), + ('type', DWORD), + ('MTU', DWORD), + ('speed', DWORD), + ('physical_address_length', DWORD), + ('physical_address_raw', BYTE * MAXLEN_PHYSADDR), + ('admin_status', DWORD), + ('operational_status', DWORD), + ('last_change', DWORD), + ('octets_received', DWORD), + ('unicast_packets_received', DWORD), + ('non_unicast_packets_received', DWORD), + ('incoming_discards', DWORD), + ('incoming_errors', DWORD), + ('incoming_unknown_protocols', DWORD), + ('octets_sent', DWORD), + ('unicast_packets_sent', DWORD), + ('non_unicast_packets_sent', DWORD), + ('outgoing_discards', DWORD), + ('outgoing_errors', DWORD), + ('outgoing_queue_length', DWORD), + ('description_length', DWORD), + ('description_raw', ctypes.c_char * MAXLEN_IFDESCR), + ) - def _get_binary_property(self, name): - val_prop = '{0}_raw'.format(name) - val = getattr(self, val_prop) - len_prop = '{0}_length'.format(name) - length = getattr(self, len_prop) - return str(memoryview(val))[:length] + def _get_binary_property(self, name): + val_prop = '{0}_raw'.format(name) + val = getattr(self, val_prop) + len_prop = '{0}_length'.format(name) + length = getattr(self, len_prop) + return str(memoryview(val))[:length] - @property - def physical_address(self): - return self._get_binary_property('physical_address') + @property + def physical_address(self): + return self._get_binary_property('physical_address') - @property - def description(self): - return self._get_binary_property('description') + @property + def description(self): + return self._get_binary_property('description') class MIB_IFTABLE(ctypes.Structure): - _fields_ = ( - ('num_entries', DWORD), # dwNumEntries - ('entries', MIB_IFROW * 0), # table - ) + _fields_ = ( + ('num_entries', DWORD), # dwNumEntries + ('entries', MIB_IFROW * 0), # table + ) class MIB_IPADDRROW(ctypes.Structure): - _fields_ = ( - ('address_num', DWORD), - ('index', DWORD), - ('mask', DWORD), - ('broadcast_address', DWORD), - ('reassembly_size', DWORD), - ('unused', ctypes.c_ushort), - ('type', ctypes.c_ushort), - ) + _fields_ = ( + ('address_num', DWORD), + ('index', DWORD), + ('mask', DWORD), + ('broadcast_address', DWORD), + ('reassembly_size', DWORD), + ('unused', ctypes.c_ushort), + ('type', ctypes.c_ushort), + ) - @property - def address(self): - "The address in big-endian" - _ = struct.pack('L', self.address_num) - return struct.unpack('!L', _)[0] + @property + def address(self): + "The address in big-endian" + _ = struct.pack('L', self.address_num) + return struct.unpack('!L', _)[0] class MIB_IPADDRTABLE(ctypes.Structure): - _fields_ = ( - ('num_entries', DWORD), - ('entries', MIB_IPADDRROW * 0), - ) + _fields_ = (('num_entries', DWORD), ('entries', MIB_IPADDRROW * 0)) class SOCKADDR(ctypes.Structure): - _fields_ = ( - ('family', ctypes.c_ushort), - ('data', ctypes.c_byte * 14), - ) + _fields_ = (('family', ctypes.c_ushort), ('data', ctypes.c_byte * 14)) LPSOCKADDR = ctypes.POINTER(SOCKADDR) class SOCKET_ADDRESS(ctypes.Structure): - _fields_ = [ - ('address', 
LPSOCKADDR), - ('length', ctypes.c_int), - ] + _fields_ = [('address', LPSOCKADDR), ('length', ctypes.c_int)] class _IP_ADAPTER_ADDRESSES_METRIC(ctypes.Structure): - _fields_ = [ - ('length', ctypes.c_ulong), - ('interface_index', DWORD), - ] + _fields_ = [('length', ctypes.c_ulong), ('interface_index', DWORD)] class _IP_ADAPTER_ADDRESSES_U1(ctypes.Union): - _fields_ = [ - ('alignment', ctypes.c_ulonglong), - ('metric', _IP_ADAPTER_ADDRESSES_METRIC), - ] + _fields_ = [ + ('alignment', ctypes.c_ulonglong), + ('metric', _IP_ADAPTER_ADDRESSES_METRIC), + ] class IP_ADAPTER_ADDRESSES(ctypes.Structure): - pass + pass LP_IP_ADAPTER_ADDRESSES = ctypes.POINTER(IP_ADAPTER_ADDRESSES) @@ -149,69 +137,69 @@ NET_IF_CONNECTION_TYPE = ctypes.c_uint # enum TUNNEL_TYPE = ctypes.c_uint # enum IP_ADAPTER_ADDRESSES._fields_ = [ - # ('u', _IP_ADAPTER_ADDRESSES_U1), - ('length', ctypes.c_ulong), - ('interface_index', DWORD), - ('next', LP_IP_ADAPTER_ADDRESSES), - ('adapter_name', ctypes.c_char_p), - ('first_unicast_address', PIP_ADAPTER_UNICAST_ADDRESS), - ('first_anycast_address', PIP_ADAPTER_ANYCAST_ADDRESS), - ('first_multicast_address', PIP_ADAPTER_MULTICAST_ADDRESS), - ('first_dns_server_address', PIP_ADAPTER_DNS_SERVER_ADDRESS), - ('dns_suffix', ctypes.c_wchar_p), - ('description', ctypes.c_wchar_p), - ('friendly_name', ctypes.c_wchar_p), - ('byte', BYTE * MAX_ADAPTER_ADDRESS_LENGTH), - ('physical_address_length', DWORD), - ('flags', DWORD), - ('mtu', DWORD), - ('interface_type', DWORD), - ('oper_status', IF_OPER_STATUS), - ('ipv6_interface_index', DWORD), - ('zone_indices', DWORD), - ('first_prefix', PIP_ADAPTER_PREFIX), - ('transmit_link_speed', ctypes.c_uint64), - ('receive_link_speed', ctypes.c_uint64), - ('first_wins_server_address', PIP_ADAPTER_WINS_SERVER_ADDRESS_LH), - ('first_gateway_address', PIP_ADAPTER_GATEWAY_ADDRESS_LH), - ('ipv4_metric', ctypes.c_ulong), - ('ipv6_metric', ctypes.c_ulong), - ('luid', IF_LUID), - ('dhcpv4_server', SOCKET_ADDRESS), - ('compartment_id', NET_IF_COMPARTMENT_ID), - ('network_guid', NET_IF_NETWORK_GUID), - ('connection_type', NET_IF_CONNECTION_TYPE), - ('tunnel_type', TUNNEL_TYPE), - ('dhcpv6_server', SOCKET_ADDRESS), - ('dhcpv6_client_duid', ctypes.c_byte * MAX_DHCPV6_DUID_LENGTH), - ('dhcpv6_client_duid_length', ctypes.c_ulong), - ('dhcpv6_iaid', ctypes.c_ulong), - ('first_dns_suffix', PIP_ADAPTER_DNS_SUFFIX), + # ('u', _IP_ADAPTER_ADDRESSES_U1), + ('length', ctypes.c_ulong), + ('interface_index', DWORD), + ('next', LP_IP_ADAPTER_ADDRESSES), + ('adapter_name', ctypes.c_char_p), + ('first_unicast_address', PIP_ADAPTER_UNICAST_ADDRESS), + ('first_anycast_address', PIP_ADAPTER_ANYCAST_ADDRESS), + ('first_multicast_address', PIP_ADAPTER_MULTICAST_ADDRESS), + ('first_dns_server_address', PIP_ADAPTER_DNS_SERVER_ADDRESS), + ('dns_suffix', ctypes.c_wchar_p), + ('description', ctypes.c_wchar_p), + ('friendly_name', ctypes.c_wchar_p), + ('byte', BYTE * MAX_ADAPTER_ADDRESS_LENGTH), + ('physical_address_length', DWORD), + ('flags', DWORD), + ('mtu', DWORD), + ('interface_type', DWORD), + ('oper_status', IF_OPER_STATUS), + ('ipv6_interface_index', DWORD), + ('zone_indices', DWORD), + ('first_prefix', PIP_ADAPTER_PREFIX), + ('transmit_link_speed', ctypes.c_uint64), + ('receive_link_speed', ctypes.c_uint64), + ('first_wins_server_address', PIP_ADAPTER_WINS_SERVER_ADDRESS_LH), + ('first_gateway_address', PIP_ADAPTER_GATEWAY_ADDRESS_LH), + ('ipv4_metric', ctypes.c_ulong), + ('ipv6_metric', ctypes.c_ulong), + ('luid', IF_LUID), + ('dhcpv4_server', SOCKET_ADDRESS), + 
('compartment_id', NET_IF_COMPARTMENT_ID), + ('network_guid', NET_IF_NETWORK_GUID), + ('connection_type', NET_IF_CONNECTION_TYPE), + ('tunnel_type', TUNNEL_TYPE), + ('dhcpv6_server', SOCKET_ADDRESS), + ('dhcpv6_client_duid', ctypes.c_byte * MAX_DHCPV6_DUID_LENGTH), + ('dhcpv6_client_duid_length', ctypes.c_ulong), + ('dhcpv6_iaid', ctypes.c_ulong), + ('first_dns_suffix', PIP_ADAPTER_DNS_SUFFIX), ] # define some parameters to the API Functions GetIfTable = ctypes.windll.iphlpapi.GetIfTable GetIfTable.argtypes = [ - ctypes.POINTER(MIB_IFTABLE), - ctypes.POINTER(ctypes.c_ulong), - BOOL, + ctypes.POINTER(MIB_IFTABLE), + ctypes.POINTER(ctypes.c_ulong), + BOOL, ] GetIfTable.restype = DWORD GetIpAddrTable = ctypes.windll.iphlpapi.GetIpAddrTable GetIpAddrTable.argtypes = [ - ctypes.POINTER(MIB_IPADDRTABLE), - ctypes.POINTER(ctypes.c_ulong), - BOOL, + ctypes.POINTER(MIB_IPADDRTABLE), + ctypes.POINTER(ctypes.c_ulong), + BOOL, ] GetIpAddrTable.restype = DWORD GetAdaptersAddresses = ctypes.windll.iphlpapi.GetAdaptersAddresses GetAdaptersAddresses.argtypes = [ - ctypes.c_ulong, - ctypes.c_ulong, - ctypes.c_void_p, - ctypes.POINTER(IP_ADAPTER_ADDRESSES), - ctypes.POINTER(ctypes.c_ulong), + ctypes.c_ulong, + ctypes.c_ulong, + ctypes.c_void_p, + ctypes.POINTER(IP_ADAPTER_ADDRESSES), + ctypes.POINTER(ctypes.c_ulong), ] GetAdaptersAddresses.restype = ctypes.c_ulong diff --git a/libs/win/jaraco/windows/api/library.py b/libs/win/jaraco/windows/api/library.py index 7f14a58e..736de57c 100644 --- a/libs/win/jaraco/windows/api/library.py +++ b/libs/win/jaraco/windows/api/library.py @@ -2,8 +2,8 @@ import ctypes.wintypes GetModuleFileName = ctypes.windll.kernel32.GetModuleFileNameW GetModuleFileName.argtypes = ( - ctypes.wintypes.HANDLE, - ctypes.wintypes.LPWSTR, - ctypes.wintypes.DWORD, + ctypes.wintypes.HANDLE, + ctypes.wintypes.LPWSTR, + ctypes.wintypes.DWORD, ) GetModuleFileName.restype = ctypes.wintypes.DWORD diff --git a/libs/win/jaraco/windows/api/memory.py b/libs/win/jaraco/windows/api/memory.py index 0371099c..c8b60472 100644 --- a/libs/win/jaraco/windows/api/memory.py +++ b/libs/win/jaraco/windows/api/memory.py @@ -7,25 +7,25 @@ GlobalAlloc.argtypes = ctypes.wintypes.UINT, ctypes.c_size_t GlobalAlloc.restype = ctypes.wintypes.HANDLE GlobalLock = ctypes.windll.kernel32.GlobalLock -GlobalLock.argtypes = ctypes.wintypes.HGLOBAL, +GlobalLock.argtypes = (ctypes.wintypes.HGLOBAL,) GlobalLock.restype = ctypes.wintypes.LPVOID GlobalUnlock = ctypes.windll.kernel32.GlobalUnlock -GlobalUnlock.argtypes = ctypes.wintypes.HGLOBAL, +GlobalUnlock.argtypes = (ctypes.wintypes.HGLOBAL,) GlobalUnlock.restype = ctypes.wintypes.BOOL GlobalSize = ctypes.windll.kernel32.GlobalSize -GlobalSize.argtypes = ctypes.wintypes.HGLOBAL, +GlobalSize.argtypes = (ctypes.wintypes.HGLOBAL,) GlobalSize.restype = ctypes.c_size_t CreateFileMapping = ctypes.windll.kernel32.CreateFileMappingW CreateFileMapping.argtypes = [ - ctypes.wintypes.HANDLE, - ctypes.c_void_p, - ctypes.wintypes.DWORD, - ctypes.wintypes.DWORD, - ctypes.wintypes.DWORD, - ctypes.wintypes.LPWSTR, + ctypes.wintypes.HANDLE, + ctypes.c_void_p, + ctypes.wintypes.DWORD, + ctypes.wintypes.DWORD, + ctypes.wintypes.DWORD, + ctypes.wintypes.LPWSTR, ] CreateFileMapping.restype = ctypes.wintypes.HANDLE @@ -33,13 +33,9 @@ MapViewOfFile = ctypes.windll.kernel32.MapViewOfFile MapViewOfFile.restype = ctypes.wintypes.HANDLE UnmapViewOfFile = ctypes.windll.kernel32.UnmapViewOfFile -UnmapViewOfFile.argtypes = ctypes.wintypes.HANDLE, +UnmapViewOfFile.argtypes = (ctypes.wintypes.HANDLE,) 
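Combined with the clipboard bindings earlier in this diff, the GlobalAlloc/GlobalLock/GlobalUnlock trio above is the canonical route for handing a block of text to the Win32 clipboard. A hedged sketch: GMEM_MOVEABLE and CF_UNICODETEXT are standard Win32 constants assumed here rather than taken from these modules, EmptyClipboard is called straight off windll because only CloseClipboard's restype appears in this diff, and failure handling is trimmed:

import ctypes

from jaraco.windows.api import clipboard, memory

GMEM_MOVEABLE = 0x0002  # standard Win32 value, assumed
CF_UNICODETEXT = 13  # standard clipboard format, assumed


def set_clipboard_text(text):
    # Clipboard payloads must live in moveable global memory that the
    # system can take ownership of.
    payload = (text + '\0').encode('utf-16-le')
    handle = memory.GlobalAlloc(GMEM_MOVEABLE, len(payload))
    locked = memory.GlobalLock(handle)
    try:
        # memory.RtlMoveMemory (bound just below) would serve equally well.
        ctypes.memmove(locked, payload, len(payload))
    finally:
        memory.GlobalUnlock(handle)
    if not clipboard.OpenClipboard(None):
        raise ctypes.WinError()
    try:
        ctypes.windll.user32.EmptyClipboard()
        # On success, the system owns `handle`; do not free it.
        clipboard.SetClipboardData(CF_UNICODETEXT, handle)
    finally:
        ctypes.windll.user32.CloseClipboard()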
RtlMoveMemory = ctypes.windll.kernel32.RtlMoveMemory -RtlMoveMemory.argtypes = ( - ctypes.c_void_p, - ctypes.c_void_p, - ctypes.c_size_t, -) +RtlMoveMemory.argtypes = (ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t) -ctypes.windll.kernel32.LocalFree.argtypes = ctypes.wintypes.HLOCAL, +ctypes.windll.kernel32.LocalFree.argtypes = (ctypes.wintypes.HLOCAL,) diff --git a/libs/win/jaraco/windows/api/message.py b/libs/win/jaraco/windows/api/message.py index 5ce2d808..c4c20bad 100644 --- a/libs/win/jaraco/windows/api/message.py +++ b/libs/win/jaraco/windows/api/message.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - """ jaraco.windows.message @@ -9,22 +7,21 @@ Windows Messaging support import ctypes from ctypes.wintypes import HWND, UINT, WPARAM, LPARAM, DWORD, LPVOID -import six - LRESULT = LPARAM class LPARAM_wstr(LPARAM): - """ - A special instance of LPARAM that can be constructed from a string - instance (for functions such as SendMessage, whose LPARAM may point to - a unicode string). - """ - @classmethod - def from_param(cls, param): - if isinstance(param, six.string_types): - return LPVOID.from_param(six.text_type(param)) - return LPARAM.from_param(param) + """ + A special instance of LPARAM that can be constructed from a string + instance (for functions such as SendMessage, whose LPARAM may point to + a unicode string). + """ + + @classmethod + def from_param(cls, param): + if isinstance(param, str): + return LPVOID.from_param(str(param)) + return LPARAM.from_param(param) SendMessage = ctypes.windll.user32.SendMessageW @@ -43,12 +40,10 @@ SMTO_NOTIMEOUTIFNOTHUNG = 0x08 SMTO_ERRORONEXIT = 0x20 SendMessageTimeout = ctypes.windll.user32.SendMessageTimeoutW -SendMessageTimeout.argtypes = SendMessage.argtypes + ( - UINT, UINT, ctypes.POINTER(DWORD) -) +SendMessageTimeout.argtypes = SendMessage.argtypes + (UINT, UINT, ctypes.POINTER(DWORD)) SendMessageTimeout.restype = LRESULT def unicode_as_lparam(source): - pointer = ctypes.cast(ctypes.c_wchar_p(source), ctypes.c_void_p) - return LPARAM(pointer.value) + pointer = ctypes.cast(ctypes.c_wchar_p(source), ctypes.c_void_p) + return LPARAM(pointer.value) diff --git a/libs/win/jaraco/windows/api/net.py b/libs/win/jaraco/windows/api/net.py index ce693319..14defac2 100644 --- a/libs/win/jaraco/windows/api/net.py +++ b/libs/win/jaraco/windows/api/net.py @@ -7,24 +7,24 @@ RESOURCETYPE_ANY = 0 class NETRESOURCE(ctypes.Structure): - _fields_ = [ - ('scope', ctypes.wintypes.DWORD), - ('type', ctypes.wintypes.DWORD), - ('display_type', ctypes.wintypes.DWORD), - ('usage', ctypes.wintypes.DWORD), - ('local_name', ctypes.wintypes.LPWSTR), - ('remote_name', ctypes.wintypes.LPWSTR), - ('comment', ctypes.wintypes.LPWSTR), - ('provider', ctypes.wintypes.LPWSTR), - ] + _fields_ = [ + ('scope', ctypes.wintypes.DWORD), + ('type', ctypes.wintypes.DWORD), + ('display_type', ctypes.wintypes.DWORD), + ('usage', ctypes.wintypes.DWORD), + ('local_name', ctypes.wintypes.LPWSTR), + ('remote_name', ctypes.wintypes.LPWSTR), + ('comment', ctypes.wintypes.LPWSTR), + ('provider', ctypes.wintypes.LPWSTR), + ] LPNETRESOURCE = ctypes.POINTER(NETRESOURCE) WNetAddConnection2 = mpr.WNetAddConnection2W WNetAddConnection2.argtypes = ( - LPNETRESOURCE, - ctypes.wintypes.LPCWSTR, - ctypes.wintypes.LPCWSTR, - ctypes.wintypes.DWORD, + LPNETRESOURCE, + ctypes.wintypes.LPCWSTR, + ctypes.wintypes.LPCWSTR, + ctypes.wintypes.DWORD, ) diff --git a/libs/win/jaraco/windows/api/power.py b/libs/win/jaraco/windows/api/power.py index 77253a8a..b43de876 100644 --- a/libs/win/jaraco/windows/api/power.py +++ 
b/libs/win/jaraco/windows/api/power.py @@ -2,24 +2,23 @@ import ctypes.wintypes class SYSTEM_POWER_STATUS(ctypes.Structure): - _fields_ = ( - ('ac_line_status', ctypes.wintypes.BYTE), - ('battery_flag', ctypes.wintypes.BYTE), - ('battery_life_percent', ctypes.wintypes.BYTE), - ('reserved', ctypes.wintypes.BYTE), - ('battery_life_time', ctypes.wintypes.DWORD), - ('battery_full_life_time', ctypes.wintypes.DWORD), - ) + _fields_ = ( + ('ac_line_status', ctypes.wintypes.BYTE), + ('battery_flag', ctypes.wintypes.BYTE), + ('battery_life_percent', ctypes.wintypes.BYTE), + ('reserved', ctypes.wintypes.BYTE), + ('battery_life_time', ctypes.wintypes.DWORD), + ('battery_full_life_time', ctypes.wintypes.DWORD), + ) - @property - def ac_line_status_string(self): - return { - 0: 'offline', 1: 'online', 255: 'unknown'}[self.ac_line_status] + @property + def ac_line_status_string(self): + return {0: 'offline', 1: 'online', 255: 'unknown'}[self.ac_line_status] LPSYSTEM_POWER_STATUS = ctypes.POINTER(SYSTEM_POWER_STATUS) GetSystemPowerStatus = ctypes.windll.kernel32.GetSystemPowerStatus -GetSystemPowerStatus.argtypes = LPSYSTEM_POWER_STATUS, +GetSystemPowerStatus.argtypes = (LPSYSTEM_POWER_STATUS,) GetSystemPowerStatus.restype = ctypes.wintypes.BOOL SetThreadExecutionState = ctypes.windll.kernel32.SetThreadExecutionState @@ -28,10 +27,11 @@ SetThreadExecutionState.restype = ctypes.c_uint class ES: - """ - Execution state constants - """ - continuous = 0x80000000 - system_required = 1 - display_required = 2 - awaymode_required = 0x40 + """ + Execution state constants + """ + + continuous = 0x80000000 + system_required = 1 + display_required = 2 + awaymode_required = 0x40 diff --git a/libs/win/jaraco/windows/api/privilege.py b/libs/win/jaraco/windows/api/privilege.py index b841311e..7fb6a497 100644 --- a/libs/win/jaraco/windows/api/privilege.py +++ b/libs/win/jaraco/windows/api/privilege.py @@ -2,35 +2,32 @@ import ctypes.wintypes class LUID(ctypes.Structure): - _fields_ = [ - ('low_part', ctypes.wintypes.DWORD), - ('high_part', ctypes.wintypes.LONG), - ] + _fields_ = [ + ('low_part', ctypes.wintypes.DWORD), + ('high_part', ctypes.wintypes.LONG), + ] - def __eq__(self, other): - return ( - self.high_part == other.high_part and - self.low_part == other.low_part - ) + def __eq__(self, other): + return self.high_part == other.high_part and self.low_part == other.low_part - def __ne__(self, other): - return not (self == other) + def __ne__(self, other): + return not (self == other) LookupPrivilegeValue = ctypes.windll.advapi32.LookupPrivilegeValueW LookupPrivilegeValue.argtypes = ( - ctypes.wintypes.LPWSTR, # system name - ctypes.wintypes.LPWSTR, # name - ctypes.POINTER(LUID), + ctypes.wintypes.LPWSTR, # system name + ctypes.wintypes.LPWSTR, # name + ctypes.POINTER(LUID), ) LookupPrivilegeValue.restype = ctypes.wintypes.BOOL class TOKEN_INFORMATION_CLASS: - TokenUser = 1 - TokenGroups = 2 - TokenPrivileges = 3 - # ... see http://msdn.microsoft.com/en-us/library/aa379626%28VS.85%29.aspx + TokenUser = 1 + TokenGroups = 2 + TokenPrivileges = 3 + # ... 
see http://msdn.microsoft.com/en-us/library/aa379626%28VS.85%29.aspx SE_PRIVILEGE_ENABLED_BY_DEFAULT = 0x00000001 @@ -40,67 +37,63 @@ SE_PRIVILEGE_USED_FOR_ACCESS = 0x80000000 class LUID_AND_ATTRIBUTES(ctypes.Structure): - _fields_ = [ - ('LUID', LUID), - ('attributes', ctypes.wintypes.DWORD), - ] + _fields_ = [('LUID', LUID), ('attributes', ctypes.wintypes.DWORD)] - def is_enabled(self): - return bool(self.attributes & SE_PRIVILEGE_ENABLED) + def is_enabled(self): + return bool(self.attributes & SE_PRIVILEGE_ENABLED) - def enable(self): - self.attributes |= SE_PRIVILEGE_ENABLED + def enable(self): + self.attributes |= SE_PRIVILEGE_ENABLED - def get_name(self): - size = ctypes.wintypes.DWORD(10240) - buf = ctypes.create_unicode_buffer(size.value) - res = LookupPrivilegeName(None, self.LUID, buf, size) - if res == 0: - raise RuntimeError - return buf[:size.value] + def get_name(self): + size = ctypes.wintypes.DWORD(10240) + buf = ctypes.create_unicode_buffer(size.value) + res = LookupPrivilegeName(None, self.LUID, buf, size) + if res == 0: + raise RuntimeError + return buf[: size.value] - def __str__(self): - res = self.get_name() - if self.is_enabled(): - res += ' (enabled)' - return res + def __str__(self): + res = self.get_name() + if self.is_enabled(): + res += ' (enabled)' + return res LookupPrivilegeName = ctypes.windll.advapi32.LookupPrivilegeNameW LookupPrivilegeName.argtypes = ( - ctypes.wintypes.LPWSTR, # lpSystemName - ctypes.POINTER(LUID), # lpLuid - ctypes.wintypes.LPWSTR, # lpName - ctypes.POINTER(ctypes.wintypes.DWORD), # cchName + ctypes.wintypes.LPWSTR, # lpSystemName + ctypes.POINTER(LUID), # lpLuid + ctypes.wintypes.LPWSTR, # lpName + ctypes.POINTER(ctypes.wintypes.DWORD), # cchName ) LookupPrivilegeName.restype = ctypes.wintypes.BOOL class TOKEN_PRIVILEGES(ctypes.Structure): - _fields_ = [ - ('count', ctypes.wintypes.DWORD), - ('privileges', LUID_AND_ATTRIBUTES * 0), - ] + _fields_ = [ + ('count', ctypes.wintypes.DWORD), + ('privileges', LUID_AND_ATTRIBUTES * 0), + ] - def get_array(self): - array_type = LUID_AND_ATTRIBUTES * self.count - privileges = ctypes.cast( - self.privileges, ctypes.POINTER(array_type)).contents - return privileges + def get_array(self): + array_type = LUID_AND_ATTRIBUTES * self.count + privileges = ctypes.cast(self.privileges, ctypes.POINTER(array_type)).contents + return privileges - def __iter__(self): - return iter(self.get_array()) + def __iter__(self): + return iter(self.get_array()) PTOKEN_PRIVILEGES = ctypes.POINTER(TOKEN_PRIVILEGES) GetTokenInformation = ctypes.windll.advapi32.GetTokenInformation GetTokenInformation.argtypes = [ - ctypes.wintypes.HANDLE, # TokenHandle - ctypes.c_uint, # TOKEN_INFORMATION_CLASS value - ctypes.c_void_p, # TokenInformation - ctypes.wintypes.DWORD, # TokenInformationLength - ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength + ctypes.wintypes.HANDLE, # TokenHandle + ctypes.c_uint, # TOKEN_INFORMATION_CLASS value + ctypes.c_void_p, # TokenInformation + ctypes.wintypes.DWORD, # TokenInformationLength + ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength ] GetTokenInformation.restype = ctypes.wintypes.BOOL @@ -108,10 +101,10 @@ GetTokenInformation.restype = ctypes.wintypes.BOOL AdjustTokenPrivileges = ctypes.windll.advapi32.AdjustTokenPrivileges AdjustTokenPrivileges.restype = ctypes.wintypes.BOOL AdjustTokenPrivileges.argtypes = [ - ctypes.wintypes.HANDLE, # TokenHandle - ctypes.wintypes.BOOL, # DisableAllPrivileges - PTOKEN_PRIVILEGES, # NewState (optional) - ctypes.wintypes.DWORD, # BufferLength of 
PreviousState - PTOKEN_PRIVILEGES, # PreviousState (out, optional) - ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength + ctypes.wintypes.HANDLE, # TokenHandle + ctypes.wintypes.BOOL, # DisableAllPrivileges + PTOKEN_PRIVILEGES, # NewState (optional) + ctypes.wintypes.DWORD, # BufferLength of PreviousState + PTOKEN_PRIVILEGES, # PreviousState (out, optional) + ctypes.POINTER(ctypes.wintypes.DWORD), # ReturnLength ] diff --git a/libs/win/jaraco/windows/api/process.py b/libs/win/jaraco/windows/api/process.py index 56ce7852..3337cba5 100644 --- a/libs/win/jaraco/windows/api/process.py +++ b/libs/win/jaraco/windows/api/process.py @@ -1,11 +1,13 @@ import ctypes.wintypes -TOKEN_ALL_ACCESS = 0xf01ff +TOKEN_ALL_ACCESS = 0xF01FF GetCurrentProcess = ctypes.windll.kernel32.GetCurrentProcess GetCurrentProcess.restype = ctypes.wintypes.HANDLE OpenProcessToken = ctypes.windll.advapi32.OpenProcessToken OpenProcessToken.argtypes = ( - ctypes.wintypes.HANDLE, ctypes.wintypes.DWORD, - ctypes.POINTER(ctypes.wintypes.HANDLE)) + ctypes.wintypes.HANDLE, + ctypes.wintypes.DWORD, + ctypes.POINTER(ctypes.wintypes.HANDLE), +) OpenProcessToken.restype = ctypes.wintypes.BOOL diff --git a/libs/win/jaraco/windows/api/security.py b/libs/win/jaraco/windows/api/security.py index c9e7b58e..db5d220c 100644 --- a/libs/win/jaraco/windows/api/security.py +++ b/libs/win/jaraco/windows/api/security.py @@ -24,116 +24,117 @@ POLICY_LOOKUP_NAMES = 0x00000800 POLICY_NOTIFICATION = 0x00001000 POLICY_ALL_ACCESS = ( - STANDARD_RIGHTS_REQUIRED | - POLICY_VIEW_LOCAL_INFORMATION | - POLICY_VIEW_AUDIT_INFORMATION | - POLICY_GET_PRIVATE_INFORMATION | - POLICY_TRUST_ADMIN | - POLICY_CREATE_ACCOUNT | - POLICY_CREATE_SECRET | - POLICY_CREATE_PRIVILEGE | - POLICY_SET_DEFAULT_QUOTA_LIMITS | - POLICY_SET_AUDIT_REQUIREMENTS | - POLICY_AUDIT_LOG_ADMIN | - POLICY_SERVER_ADMIN | - POLICY_LOOKUP_NAMES) + STANDARD_RIGHTS_REQUIRED + | POLICY_VIEW_LOCAL_INFORMATION + | POLICY_VIEW_AUDIT_INFORMATION + | POLICY_GET_PRIVATE_INFORMATION + | POLICY_TRUST_ADMIN + | POLICY_CREATE_ACCOUNT + | POLICY_CREATE_SECRET + | POLICY_CREATE_PRIVILEGE + | POLICY_SET_DEFAULT_QUOTA_LIMITS + | POLICY_SET_AUDIT_REQUIREMENTS + | POLICY_AUDIT_LOG_ADMIN + | POLICY_SERVER_ADMIN + | POLICY_LOOKUP_NAMES +) POLICY_READ = ( - STANDARD_RIGHTS_READ | - POLICY_VIEW_AUDIT_INFORMATION | - POLICY_GET_PRIVATE_INFORMATION) + STANDARD_RIGHTS_READ + | POLICY_VIEW_AUDIT_INFORMATION + | POLICY_GET_PRIVATE_INFORMATION +) POLICY_WRITE = ( - STANDARD_RIGHTS_WRITE | - POLICY_TRUST_ADMIN | - POLICY_CREATE_ACCOUNT | - POLICY_CREATE_SECRET | - POLICY_CREATE_PRIVILEGE | - POLICY_SET_DEFAULT_QUOTA_LIMITS | - POLICY_SET_AUDIT_REQUIREMENTS | - POLICY_AUDIT_LOG_ADMIN | - POLICY_SERVER_ADMIN) + STANDARD_RIGHTS_WRITE + | POLICY_TRUST_ADMIN + | POLICY_CREATE_ACCOUNT + | POLICY_CREATE_SECRET + | POLICY_CREATE_PRIVILEGE + | POLICY_SET_DEFAULT_QUOTA_LIMITS + | POLICY_SET_AUDIT_REQUIREMENTS + | POLICY_AUDIT_LOG_ADMIN + | POLICY_SERVER_ADMIN +) POLICY_EXECUTE = ( - STANDARD_RIGHTS_EXECUTE | - POLICY_VIEW_LOCAL_INFORMATION | - POLICY_LOOKUP_NAMES) + STANDARD_RIGHTS_EXECUTE | POLICY_VIEW_LOCAL_INFORMATION | POLICY_LOOKUP_NAMES +) class TokenAccess: - TOKEN_QUERY = 0x8 + TOKEN_QUERY = 0x8 class TokenInformationClass: - TokenUser = 1 + TokenUser = 1 class TOKEN_USER(ctypes.Structure): - num = 1 - _fields_ = [ - ('SID', ctypes.c_void_p), - ('ATTRIBUTES', ctypes.wintypes.DWORD), - ] + num = 1 + _fields_ = [('SID', ctypes.c_void_p), ('ATTRIBUTES', ctypes.wintypes.DWORD)] class SECURITY_DESCRIPTOR(ctypes.Structure): - 
""" - typedef struct _SECURITY_DESCRIPTOR - { - UCHAR Revision; - UCHAR Sbz1; - SECURITY_DESCRIPTOR_CONTROL Control; - PSID Owner; - PSID Group; - PACL Sacl; - PACL Dacl; - } SECURITY_DESCRIPTOR; - """ - SECURITY_DESCRIPTOR_CONTROL = ctypes.wintypes.USHORT - REVISION = 1 + """ + typedef struct _SECURITY_DESCRIPTOR + { + UCHAR Revision; + UCHAR Sbz1; + SECURITY_DESCRIPTOR_CONTROL Control; + PSID Owner; + PSID Group; + PACL Sacl; + PACL Dacl; + } SECURITY_DESCRIPTOR; + """ - _fields_ = [ - ('Revision', ctypes.c_ubyte), - ('Sbz1', ctypes.c_ubyte), - ('Control', SECURITY_DESCRIPTOR_CONTROL), - ('Owner', ctypes.c_void_p), - ('Group', ctypes.c_void_p), - ('Sacl', ctypes.c_void_p), - ('Dacl', ctypes.c_void_p), - ] + SECURITY_DESCRIPTOR_CONTROL = ctypes.wintypes.USHORT + REVISION = 1 + + _fields_ = [ + ('Revision', ctypes.c_ubyte), + ('Sbz1', ctypes.c_ubyte), + ('Control', SECURITY_DESCRIPTOR_CONTROL), + ('Owner', ctypes.c_void_p), + ('Group', ctypes.c_void_p), + ('Sacl', ctypes.c_void_p), + ('Dacl', ctypes.c_void_p), + ] class SECURITY_ATTRIBUTES(ctypes.Structure): - """ - typedef struct _SECURITY_ATTRIBUTES { - DWORD nLength; - LPVOID lpSecurityDescriptor; - BOOL bInheritHandle; - } SECURITY_ATTRIBUTES; - """ - _fields_ = [ - ('nLength', ctypes.wintypes.DWORD), - ('lpSecurityDescriptor', ctypes.c_void_p), - ('bInheritHandle', ctypes.wintypes.BOOL), - ] + """ + typedef struct _SECURITY_ATTRIBUTES { + DWORD nLength; + LPVOID lpSecurityDescriptor; + BOOL bInheritHandle; + } SECURITY_ATTRIBUTES; + """ - def __init__(self, *args, **kwargs): - super(SECURITY_ATTRIBUTES, self).__init__(*args, **kwargs) - self.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) + _fields_ = [ + ('nLength', ctypes.wintypes.DWORD), + ('lpSecurityDescriptor', ctypes.c_void_p), + ('bInheritHandle', ctypes.wintypes.BOOL), + ] - @property - def descriptor(self): - return self._descriptor + def __init__(self, *args, **kwargs): + super(SECURITY_ATTRIBUTES, self).__init__(*args, **kwargs) + self.nLength = ctypes.sizeof(SECURITY_ATTRIBUTES) - @descriptor.setter - def descriptor(self, value): - self._descriptor = value - self.lpSecurityDescriptor = ctypes.addressof(value) + @property + def descriptor(self): + return self._descriptor + + @descriptor.setter + def descriptor(self, value): + self._descriptor = value + self.lpSecurityDescriptor = ctypes.addressof(value) ctypes.windll.advapi32.SetSecurityDescriptorOwner.argtypes = ( - ctypes.POINTER(SECURITY_DESCRIPTOR), - ctypes.c_void_p, - ctypes.wintypes.BOOL, + ctypes.POINTER(SECURITY_DESCRIPTOR), + ctypes.c_void_p, + ctypes.wintypes.BOOL, ) diff --git a/libs/win/jaraco/windows/api/shell.py b/libs/win/jaraco/windows/api/shell.py index 1d428c87..af7174dc 100644 --- a/libs/win/jaraco/windows/api/shell.py +++ b/libs/win/jaraco/windows/api/shell.py @@ -1,39 +1,40 @@ import ctypes.wintypes + BOOL = ctypes.wintypes.BOOL class SHELLSTATE(ctypes.Structure): - _fields_ = [ - ('show_all_objects', BOOL, 1), - ('show_extensions', BOOL, 1), - ('no_confirm_recycle', BOOL, 1), - ('show_sys_files', BOOL, 1), - ('show_comp_color', BOOL, 1), - ('double_click_in_web_view', BOOL, 1), - ('desktop_HTML', BOOL, 1), - ('win95_classic', BOOL, 1), - ('dont_pretty_path', BOOL, 1), - ('show_attrib_col', BOOL, 1), - ('map_network_drive_button', BOOL, 1), - ('show_info_tip', BOOL, 1), - ('hide_icons', BOOL, 1), - ('web_view', BOOL, 1), - ('filter', BOOL, 1), - ('show_super_hidden', BOOL, 1), - ('no_net_crawling', BOOL, 1), - ('win95_unused', ctypes.wintypes.DWORD), - ('param_sort', ctypes.wintypes.LONG), - 
('sort_direction', ctypes.c_int), - ('version', ctypes.wintypes.UINT), - ('not_used', ctypes.wintypes.UINT), - ('sep_process', BOOL, 1), - ('start_panel_on', BOOL, 1), - ('show_start_page', BOOL, 1), - ('auto_check_select', BOOL, 1), - ('icons_only', BOOL, 1), - ('show_type_overlay', BOOL, 1), - ('spare_flags', ctypes.wintypes.UINT, 13), - ] + _fields_ = [ + ('show_all_objects', BOOL, 1), + ('show_extensions', BOOL, 1), + ('no_confirm_recycle', BOOL, 1), + ('show_sys_files', BOOL, 1), + ('show_comp_color', BOOL, 1), + ('double_click_in_web_view', BOOL, 1), + ('desktop_HTML', BOOL, 1), + ('win95_classic', BOOL, 1), + ('dont_pretty_path', BOOL, 1), + ('show_attrib_col', BOOL, 1), + ('map_network_drive_button', BOOL, 1), + ('show_info_tip', BOOL, 1), + ('hide_icons', BOOL, 1), + ('web_view', BOOL, 1), + ('filter', BOOL, 1), + ('show_super_hidden', BOOL, 1), + ('no_net_crawling', BOOL, 1), + ('win95_unused', ctypes.wintypes.DWORD), + ('param_sort', ctypes.wintypes.LONG), + ('sort_direction', ctypes.c_int), + ('version', ctypes.wintypes.UINT), + ('not_used', ctypes.wintypes.UINT), + ('sep_process', BOOL, 1), + ('start_panel_on', BOOL, 1), + ('show_start_page', BOOL, 1), + ('auto_check_select', BOOL, 1), + ('icons_only', BOOL, 1), + ('show_type_overlay', BOOL, 1), + ('spare_flags', ctypes.wintypes.UINT, 13), + ] SSF_SHOWALLOBJECTS = 0x00000001 @@ -123,8 +124,8 @@ SSF_SHOWTYPEOVERLAY = 0x02000000 SHGetSetSettings = ctypes.windll.shell32.SHGetSetSettings SHGetSetSettings.argtypes = [ - ctypes.POINTER(SHELLSTATE), - ctypes.wintypes.DWORD, - ctypes.wintypes.BOOL, # get or set (True: set) + ctypes.POINTER(SHELLSTATE), + ctypes.wintypes.DWORD, + ctypes.wintypes.BOOL, # get or set (True: set) ] SHGetSetSettings.restype = None diff --git a/libs/win/jaraco/windows/api/system.py b/libs/win/jaraco/windows/api/system.py index 6a09f5ad..ba69affb 100644 --- a/libs/win/jaraco/windows/api/system.py +++ b/libs/win/jaraco/windows/api/system.py @@ -2,10 +2,10 @@ import ctypes.wintypes SystemParametersInfo = ctypes.windll.user32.SystemParametersInfoW SystemParametersInfo.argtypes = ( - ctypes.wintypes.UINT, - ctypes.wintypes.UINT, - ctypes.c_void_p, - ctypes.wintypes.UINT, + ctypes.wintypes.UINT, + ctypes.wintypes.UINT, + ctypes.c_void_p, + ctypes.wintypes.UINT, ) SPI_GETACTIVEWINDOWTRACKING = 0x1000 diff --git a/libs/win/jaraco/windows/api/user.py b/libs/win/jaraco/windows/api/user.py index 9d0b3f8d..936cdd23 100644 --- a/libs/win/jaraco/windows/api/user.py +++ b/libs/win/jaraco/windows/api/user.py @@ -1,9 +1,9 @@ import ctypes.wintypes try: - from ctypes.wintypes import LPDWORD + from ctypes.wintypes import LPDWORD except ImportError: - LPDWORD = ctypes.POINTER(ctypes.wintypes.DWORD) + LPDWORD = ctypes.POINTER(ctypes.wintypes.DWORD) # type: ignore GetUserName = ctypes.windll.advapi32.GetUserNameW GetUserName.argtypes = ctypes.wintypes.LPWSTR, LPDWORD diff --git a/libs/win/jaraco/windows/batch.py b/libs/win/jaraco/windows/batch.py new file mode 100644 index 00000000..b0ac6f0c --- /dev/null +++ b/libs/win/jaraco/windows/batch.py @@ -0,0 +1,39 @@ +import subprocess +import itertools + +from more_itertools import consume, always_iterable + + +def extract_environment(env_cmd, initial=None): + """ + Take a command (either a single command or list of arguments) + and return the environment created after running that command. + Note that the command must be a batch file or .cmd file, or the + changes to the environment will not be captured. 
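The SHGetSetSettings binding above is callable as soon as the structure is declared. A minimal sketch of a read, assuming only SHELLSTATE and SSF_SHOWALLOBJECTS from this module; per the argtypes comment, the final flag selects set (True) or get (False):

import ctypes

state = SHELLSTATE()
# False requests a read of the masked settings into `state`
SHGetSetSettings(ctypes.byref(state), SSF_SHOWALLOBJECTS, False)
print(bool(state.show_all_objects))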
+ + If initial is supplied, it is used as the initial environment passed + to the child process. + """ + # construct the command that will alter the environment + env_cmd = subprocess.list2cmdline(always_iterable(env_cmd)) + # create a tag so we can tell in the output when the proc is done + tag = 'Done running command' + # construct a cmd.exe command to accomplish this + cmd = 'cmd.exe /s /c "{env_cmd} && echo "{tag}" && set"'.format(**vars()) + # launch the process + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=initial) + # parse the output sent to stdout + lines = proc.stdout + # make sure the lines are strings + + def make_str(s): + return s.decode() + + lines = map(make_str, lines) + # consume whatever output occurs until the tag is reached + consume(itertools.takewhile(lambda l: tag not in l, lines)) + # construct a dictionary of the pairs + result = dict(line.rstrip().split('=', 1) for line in lines) + # let the process finish + proc.communicate() + return result diff --git a/libs/win/jaraco/windows/clipboard.py b/libs/win/jaraco/windows/clipboard.py index 2f4bbc3a..cf7dc6bc 100644 --- a/libs/win/jaraco/windows/clipboard.py +++ b/libs/win/jaraco/windows/clipboard.py @@ -1,5 +1,3 @@ -from __future__ import with_statement, print_function - import sys import re import itertools @@ -7,220 +5,265 @@ from contextlib import contextmanager import io import ctypes from ctypes import windll - -import six -from six.moves import map +import textwrap +import collections from jaraco.windows.api import clipboard, memory from jaraco.windows.error import handle_nonzero_success, WindowsError from jaraco.windows.memory import LockedMemory -__all__ = ( - 'GetClipboardData', 'CloseClipboard', - 'SetClipboardData', 'OpenClipboard', -) +__all__ = ('GetClipboardData', 'CloseClipboard', 'SetClipboardData', 'OpenClipboard') def OpenClipboard(owner=None): - """ - Open the clipboard. + """ + Open the clipboard. - owner - [in] Handle to the window to be associated with the open clipboard. - If this parameter is None, the open clipboard is associated with the - current task. - """ - handle_nonzero_success(windll.user32.OpenClipboard(owner)) + owner + [in] Handle to the window to be associated with the open clipboard. + If this parameter is None, the open clipboard is associated with the + current task. + """ + handle_nonzero_success(windll.user32.OpenClipboard(owner)) def CloseClipboard(): - handle_nonzero_success(windll.user32.CloseClipboard()) + handle_nonzero_success(windll.user32.CloseClipboard()) data_handlers = dict() def handles(*formats): - def register(func): - for format in formats: - data_handlers[format] = func - return func - return register + def register(func): + for format in formats: + data_handlers[format] = func + return func + + return register def nts(buffer): - """ - Null Terminated String - Get the portion of bytestring buffer up to a null character. 
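A usage sketch for extract_environment above; the batch file name and argument are illustrative, not part of the library:

import os

# capture the environment a setup script establishes, then adopt it
env = extract_environment(['setup_env.cmd', 'x64'], initial=os.environ)
os.environ.update(env)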
+ """ + result, null, rest = buffer.partition('\x00') + return result @handles(clipboard.CF_DIBV5, clipboard.CF_DIB) def raw_data(handle): - return LockedMemory(handle).data + return LockedMemory(handle).data @handles(clipboard.CF_TEXT) def text_string(handle): - return nts(raw_data(handle)) + return nts(raw_data(handle)) @handles(clipboard.CF_UNICODETEXT) def unicode_string(handle): - return nts(raw_data(handle).decode('utf-16')) + return nts(raw_data(handle).decode('utf-16')) @handles(clipboard.CF_BITMAP) def as_bitmap(handle): - # handle is HBITMAP - raise NotImplementedError("Can't convert to DIB") - # todo: use GetDIBits http://msdn.microsoft.com - # /en-us/library/dd144879%28v=VS.85%29.aspx + # handle is HBITMAP + raise NotImplementedError("Can't convert to DIB") + # todo: use GetDIBits http://msdn.microsoft.com + # /en-us/library/dd144879%28v=VS.85%29.aspx @handles(clipboard.CF_HTML) class HTMLSnippet(object): - def __init__(self, handle): - self.data = nts(raw_data(handle).decode('utf-8')) - self.headers = self.parse_headers(self.data) + """ + HTML Snippet representing the Microsoft `HTML snippet format + `_. + """ - @property - def html(self): - return self.data[self.headers['StartHTML']:] + def __init__(self, handle): + self.data = nts(raw_data(handle).decode('utf-8')) + self.headers = self.parse_headers(self.data) - @staticmethod - def parse_headers(data): - d = io.StringIO(data) + @property + def html(self): + return self.data[self.headers['StartHTML'] :] - def header_line(line): - return re.match('(\w+):(.*)', line) - headers = map(header_line, d) - # grab headers until they no longer match - headers = itertools.takewhile(bool, headers) + @property + def fragment(self): + return self.data[self.headers['StartFragment'] : self.headers['EndFragment']] - def best_type(value): - try: - return int(value) - except ValueError: - pass - try: - return float(value) - except ValueError: - pass - return value - pairs = ( - (header.group(1), best_type(header.group(2))) - for header - in headers - ) - return dict(pairs) + @staticmethod + def parse_headers(data): + d = io.StringIO(data) + + def header_line(line): + return re.match(r'(\w+):(.*)', line) + + headers = map(header_line, d) + # grab headers until they no longer match + headers = itertools.takewhile(bool, headers) + + def best_type(value): + try: + return int(value) + except ValueError: + pass + try: + return float(value) + except ValueError: + pass + return value + + pairs = ((header.group(1), best_type(header.group(2))) for header in headers) + return dict(pairs) + + @classmethod + def from_string(cls, source): + """ + Construct an HTMLSnippet with all the headers, modeled after + https://docs.microsoft.com/en-us/troubleshoot/cpp/add-html-code-clipboard + """ + tmpl = textwrap.dedent( + """ + Version:0.9 + StartHTML:{start_html:08d} + EndHTML:{end_html:08d} + StartFragment:{start_fragment:08d} + EndFragment:{end_fragment:08d} + + + {source} + + + """ + ).strip() + zeros = collections.defaultdict(lambda: 0, locals()) + pre_value = tmpl.format_map(zeros) + start_html = pre_value.find('') + end_html = len(tmpl) + assert end_html < 100000000 + start_fragment = pre_value.find(source) + end_fragment = pre_value.rfind('\n %(target)s\n" % vars()) + """ + Like cmd.exe's mklink except it will infer directory status of the + target. 
+ """ + from optparse import OptionParser + + parser = OptionParser(usage="usage: %prog [options] link target") + parser.add_option( + '-d', + '--directory', + help="Target is a directory (only necessary if not present)", + action="store_true", + ) + options, args = parser.parse_args() + try: + link, target = args + except ValueError: + parser.error("incorrect number of arguments") + symlink(target, link, options.directory) + sys.stdout.write("Symbolic link created: %(link)s --> %(target)s\n" % vars()) def _is_target_a_directory(link, rel_target): - """ - If creating a symlink from link to a target, determine if target - is a directory (relative to dirname(link)). - """ - target = os.path.join(os.path.dirname(link), rel_target) - return os.path.isdir(target) + """ + If creating a symlink from link to a target, determine if target + is a directory (relative to dirname(link)). + """ + target = os.path.join(os.path.dirname(link), rel_target) + return os.path.isdir(target) def symlink(target, link, target_is_directory=False): - """ - An implementation of os.symlink for Windows (Vista and greater) - """ - target_is_directory = ( - target_is_directory or - _is_target_a_directory(link, target) - ) - # normalize the target (MS symlinks don't respect forward slashes) - target = os.path.normpath(target) - handle_nonzero_success( - api.CreateSymbolicLink(link, target, target_is_directory)) + """ + An implementation of os.symlink for Windows (Vista and greater) + """ + target_is_directory = target_is_directory or _is_target_a_directory(link, target) + # normalize the target (MS symlinks don't respect forward slashes) + target = os.path.normpath(target) + flags = target_is_directory | api.SYMBOLIC_LINK_FLAG_ALLOW_UNPRIVILEGED_CREATE + handle_nonzero_success(api.CreateSymbolicLink(link, target, flags)) def link(target, link): - """ - Establishes a hard link between an existing file and a new file. - """ - handle_nonzero_success(api.CreateHardLink(link, target, None)) + """ + Establishes a hard link between an existing file and a new file. + """ + handle_nonzero_success(api.CreateHardLink(link, target, None)) def is_reparse_point(path): - """ - Determine if the given path is a reparse point. - Return False if the file does not exist or the file attributes cannot - be determined. - """ - res = api.GetFileAttributes(path) - return ( - res != api.INVALID_FILE_ATTRIBUTES - and bool(res & api.FILE_ATTRIBUTE_REPARSE_POINT) - ) + """ + Determine if the given path is a reparse point. + Return False if the file does not exist or the file attributes cannot + be determined. + """ + res = api.GetFileAttributes(path) + return res != api.INVALID_FILE_ATTRIBUTES and bool( + res & api.FILE_ATTRIBUTE_REPARSE_POINT + ) def islink(path): - "Determine if the given path is a symlink" - return is_reparse_point(path) and is_symlink(path) + "Determine if the given path is a symlink" + return is_reparse_point(path) and is_symlink(path) def _patch_path(path): - """ - Paths have a max length of api.MAX_PATH characters (260). If a target path - is longer than that, it needs to be made absolute and prepended with - \\?\ in order to work with API calls. - See http://msdn.microsoft.com/en-us/library/aa365247%28v=vs.85%29.aspx for - details. 
- """ - if path.startswith('\\\\?\\'): - return path - abs_path = os.path.abspath(path) - if not abs_path[1] == ':': - # python doesn't include the drive letter, but \\?\ requires it - abs_path = os.getcwd()[:2] + abs_path - return '\\\\?\\' + abs_path + r""" + Paths have a max length of api.MAX_PATH characters (260). If a target path + is longer than that, it needs to be made absolute and prepended with + \\?\ in order to work with API calls. + See http://msdn.microsoft.com/en-us/library/aa365247%28v=vs.85%29.aspx for + details. + """ + if path.startswith('\\\\?\\'): + return path + abs_path = os.path.abspath(path) + if not abs_path[1] == ':': + # python doesn't include the drive letter, but \\?\ requires it + abs_path = os.getcwd()[:2] + abs_path + return '\\\\?\\' + abs_path def is_symlink(path): - """ - Assuming path is a reparse point, determine if it's a symlink. - """ - path = _patch_path(path) - try: - return _is_symlink(next(find_files(path))) - except WindowsError as orig_error: - tmpl = "Error accessing {path}: {orig_error.message}" - raise builtins.WindowsError(tmpl.format(**locals())) + """ + Assuming path is a reparse point, determine if it's a symlink. + """ + path = _patch_path(path) + try: + return _is_symlink(next(find_files(path))) + # comment below workaround for PyCQA/pyflakes#376 + except WindowsError as orig_error: # noqa: F841 + tmpl = "Error accessing {path}: {orig_error.message}" + raise builtins.WindowsError(tmpl.format(**locals())) def _is_symlink(find_data): - return find_data.reserved[0] == api.IO_REPARSE_TAG_SYMLINK + return find_data.reserved[0] == api.IO_REPARSE_TAG_SYMLINK def find_files(spec): - """ - A pythonic wrapper around the FindFirstFile/FindNextFile win32 api. + r""" + A pythonic wrapper around the FindFirstFile/FindNextFile win32 api. - >>> root_files = tuple(find_files(r'c:\*')) - >>> len(root_files) > 1 - True - >>> root_files[0].filename == root_files[1].filename - False + >>> root_files = tuple(find_files(r'c:\*')) + >>> len(root_files) > 1 + True + >>> root_files[0].filename == root_files[1].filename + False - This test might fail on a non-standard installation - >>> 'Windows' in (fd.filename for fd in root_files) - True - """ - fd = api.WIN32_FIND_DATA() - handle = api.FindFirstFile(spec, byref(fd)) - while True: - if handle == api.INVALID_HANDLE_VALUE: - raise WindowsError() - yield fd - fd = api.WIN32_FIND_DATA() - res = api.FindNextFile(handle, byref(fd)) - if res == 0: # error - error = WindowsError() - if error.code == api.ERROR_NO_MORE_FILES: - break - else: - raise error - # todo: how to close handle when generator is destroyed? - # hint: catch GeneratorExit - windll.kernel32.FindClose(handle) + This test might fail on a non-standard installation + >>> 'Windows' in (fd.filename for fd in root_files) + True + """ + fd = api.WIN32_FIND_DATA() + handle = api.FindFirstFile(spec, byref(fd)) + while True: + if handle == api.INVALID_HANDLE_VALUE: + raise WindowsError() + yield fd + fd = api.WIN32_FIND_DATA() + res = api.FindNextFile(handle, byref(fd)) + if res == 0: # error + error = WindowsError() + if error.code == api.ERROR_NO_MORE_FILES: + break + else: + raise error + # todo: how to close handle when generator is destroyed? + # hint: catch GeneratorExit + windll.kernel32.FindClose(handle) def get_final_path(path): - """ - For a given path, determine the ultimate location of that path. - Useful for resolving symlink targets. - This functions wraps the GetFinalPathNameByHandle from the Windows - SDK. 
+ r""" + For a given path, determine the ultimate location of that path. + Useful for resolving symlink targets. + This functions wraps the GetFinalPathNameByHandle from the Windows + SDK. - Note, this function fails if a handle cannot be obtained (such as - for C:\Pagefile.sys on a stock windows system). Consider using - trace_symlink_target instead. - """ - desired_access = api.NULL - share_mode = ( - api.FILE_SHARE_READ | api.FILE_SHARE_WRITE | api.FILE_SHARE_DELETE - ) - security_attributes = api.LPSECURITY_ATTRIBUTES() # NULL pointer - hFile = api.CreateFile( - path, - desired_access, - share_mode, - security_attributes, - api.OPEN_EXISTING, - api.FILE_FLAG_BACKUP_SEMANTICS, - api.NULL, - ) + Note, this function fails if a handle cannot be obtained (such as + for C:\Pagefile.sys on a stock windows system). Consider using + trace_symlink_target instead. + """ + desired_access = api.NULL + share_mode = api.FILE_SHARE_READ | api.FILE_SHARE_WRITE | api.FILE_SHARE_DELETE + security_attributes = api.LPSECURITY_ATTRIBUTES() # NULL pointer + hFile = api.CreateFile( + path, + desired_access, + share_mode, + security_attributes, + api.OPEN_EXISTING, + api.FILE_FLAG_BACKUP_SEMANTICS, + api.NULL, + ) - if hFile == api.INVALID_HANDLE_VALUE: - raise WindowsError() + if hFile == api.INVALID_HANDLE_VALUE: + raise WindowsError() - buf_size = api.GetFinalPathNameByHandle( - hFile, LPWSTR(), 0, api.VOLUME_NAME_DOS) - handle_nonzero_success(buf_size) - buf = create_unicode_buffer(buf_size) - result_length = api.GetFinalPathNameByHandle( - hFile, buf, len(buf), api.VOLUME_NAME_DOS) + buf_size = api.GetFinalPathNameByHandle(hFile, LPWSTR(), 0, api.VOLUME_NAME_DOS) + handle_nonzero_success(buf_size) + buf = create_unicode_buffer(buf_size) + result_length = api.GetFinalPathNameByHandle( + hFile, buf, len(buf), api.VOLUME_NAME_DOS + ) - assert result_length < len(buf) - handle_nonzero_success(result_length) - handle_nonzero_success(api.CloseHandle(hFile)) + assert result_length < len(buf) + handle_nonzero_success(result_length) + handle_nonzero_success(api.CloseHandle(hFile)) - return buf[:result_length] + return buf[:result_length] def compat_stat(path): - """ - Generate stat as found on Python 3.2 and later. - """ - stat = os.stat(path) - info = get_file_info(path) - # rewrite st_ino, st_dev, and st_nlink based on file info - return nt.stat_result( - (stat.st_mode,) + - (info.file_index, info.volume_serial_number, info.number_of_links) + - stat[4:] - ) + """ + Generate stat as found on Python 3.2 and later. + """ + stat = os.stat(path) + info = get_file_info(path) + # rewrite st_ino, st_dev, and st_nlink based on file info + return nt.stat_result( + (stat.st_mode,) + + (info.file_index, info.volume_serial_number, info.number_of_links) + + stat[4:] + ) def samefile(f1, f2): - """ - Backport of samefile from Python 3.2 with support for Windows. - """ - return posixpath.samestat(compat_stat(f1), compat_stat(f2)) + """ + Backport of samefile from Python 3.2 with support for Windows. 
+ """ + return posixpath.samestat(compat_stat(f1), compat_stat(f2)) def get_file_info(path): - # open the file the same way CPython does in posixmodule.c - desired_access = api.FILE_READ_ATTRIBUTES - share_mode = 0 - security_attributes = None - creation_disposition = api.OPEN_EXISTING - flags_and_attributes = ( - api.FILE_ATTRIBUTE_NORMAL | - api.FILE_FLAG_BACKUP_SEMANTICS | - api.FILE_FLAG_OPEN_REPARSE_POINT - ) - template_file = None + # open the file the same way CPython does in posixmodule.c + desired_access = api.FILE_READ_ATTRIBUTES + share_mode = 0 + security_attributes = None + creation_disposition = api.OPEN_EXISTING + flags_and_attributes = ( + api.FILE_ATTRIBUTE_NORMAL + | api.FILE_FLAG_BACKUP_SEMANTICS + | api.FILE_FLAG_OPEN_REPARSE_POINT + ) + template_file = None - handle = api.CreateFile( - path, - desired_access, - share_mode, - security_attributes, - creation_disposition, - flags_and_attributes, - template_file, - ) + handle = api.CreateFile( + path, + desired_access, + share_mode, + security_attributes, + creation_disposition, + flags_and_attributes, + template_file, + ) - if handle == api.INVALID_HANDLE_VALUE: - raise WindowsError() + if handle == api.INVALID_HANDLE_VALUE: + raise WindowsError() - info = api.BY_HANDLE_FILE_INFORMATION() - res = api.GetFileInformationByHandle(handle, info) - handle_nonzero_success(res) - handle_nonzero_success(api.CloseHandle(handle)) + info = api.BY_HANDLE_FILE_INFORMATION() + res = api.GetFileInformationByHandle(handle, info) + handle_nonzero_success(res) + handle_nonzero_success(api.CloseHandle(handle)) - return info + return info def GetBinaryType(filepath): - res = api.DWORD() - handle_nonzero_success(api._GetBinaryType(filepath, res)) - return res + res = api.DWORD() + handle_nonzero_success(api._GetBinaryType(filepath, res)) + return res def _make_null_terminated_list(obs): - obs = _makelist(obs) - if obs is None: - return - return u'\x00'.join(obs) + u'\x00\x00' + obs = _makelist(obs) + if obs is None: + return + return u'\x00'.join(obs) + u'\x00\x00' def _makelist(ob): - if ob is None: - return - if not isinstance(ob, (list, tuple, set)): - return [ob] - return ob + if ob is None: + return + if not isinstance(ob, (list, tuple, set)): + return [ob] + return ob def SHFileOperation(operation, from_, to=None, flags=[]): - flags = functools.reduce(operator.or_, flags, 0) - from_ = _make_null_terminated_list(from_) - to = _make_null_terminated_list(to) - params = api.SHFILEOPSTRUCT(0, operation, from_, to, flags) - res = api._SHFileOperation(params) - if res != 0: - raise RuntimeError("SHFileOperation returned %d" % res) + flags = functools.reduce(operator.or_, flags, 0) + from_ = _make_null_terminated_list(from_) + to = _make_null_terminated_list(to) + params = api.SHFILEOPSTRUCT(0, operation, from_, to, flags) + res = api._SHFileOperation(params) + if res != 0: + raise RuntimeError("SHFileOperation returned %d" % res) def join(*paths): - r""" - Wrapper around os.path.join that works with Windows drive letters. + r""" + Wrapper around os.path.join that works with Windows drive letters. 
- >>> join('d:\\foo', '\\bar') - 'd:\\bar' - """ - paths_with_drives = map(os.path.splitdrive, paths) - drives, paths = zip(*paths_with_drives) - # the drive we care about is the last one in the list - drive = next(filter(None, reversed(drives)), '') - return os.path.join(drive, os.path.join(*paths)) + >>> join('d:\\foo', '\\bar') + 'd:\\bar' + """ + paths_with_drives = map(os.path.splitdrive, paths) + drives, paths = zip(*paths_with_drives) + # the drive we care about is the last one in the list + drive = next(filter(None, reversed(drives)), '') + return os.path.join(drive, os.path.join(*paths)) def resolve_path(target, start=os.path.curdir): - r""" - Find a path from start to target where target is relative to start. + r""" + Find a path from start to target where target is relative to start. - >>> tmp = str(getfixture('tmpdir_as_cwd')) + >>> tmp = str(getfixture('tmpdir_as_cwd')) - >>> findpath('d:\\') - 'd:\\' + >>> findpath('d:\\') + 'd:\\' - >>> findpath('d:\\', tmp) - 'd:\\' + >>> findpath('d:\\', tmp) + 'd:\\' - >>> findpath('\\bar', 'd:\\') - 'd:\\bar' + >>> findpath('\\bar', 'd:\\') + 'd:\\bar' - >>> findpath('\\bar', 'd:\\foo') # fails with '\\bar' - 'd:\\bar' + >>> findpath('\\bar', 'd:\\foo') # fails with '\\bar' + 'd:\\bar' - >>> findpath('bar', 'd:\\foo') - 'd:\\foo\\bar' + >>> findpath('bar', 'd:\\foo') + 'd:\\foo\\bar' - >>> findpath('\\baz', 'd:\\foo\\bar') # fails with '\\baz' - 'd:\\baz' + >>> findpath('\\baz', 'd:\\foo\\bar') # fails with '\\baz' + 'd:\\baz' - >>> os.path.abspath(findpath('\\bar')).lower() - 'c:\\bar' + >>> os.path.abspath(findpath('\\bar')).lower() + 'c:\\bar' - >>> os.path.abspath(findpath('bar')) - '...\\bar' + >>> os.path.abspath(findpath('bar')) + '...\\bar' - >>> findpath('..', 'd:\\foo\\bar') - 'd:\\foo' + >>> findpath('..', 'd:\\foo\\bar') + 'd:\\foo' - The parent of the root directory is the root directory. - >>> findpath('..', 'd:\\') - 'd:\\' - """ - return os.path.normpath(join(start, target)) + The parent of the root directory is the root directory. + >>> findpath('..', 'd:\\') + 'd:\\' + """ + return os.path.normpath(join(start, target)) findpath = resolve_path def trace_symlink_target(link): - """ - Given a file that is known to be a symlink, trace it to its ultimate - target. + """ + Given a file that is known to be a symlink, trace it to its ultimate + target. - Raises TargetNotPresent when the target cannot be determined. - Raises ValueError when the specified link is not a symlink. - """ + Raises TargetNotPresent when the target cannot be determined. + Raises ValueError when the specified link is not a symlink. + """ - if not is_symlink(link): - raise ValueError("link must point to a symlink on the system") - while is_symlink(link): - orig = os.path.dirname(link) - link = readlink(link) - link = resolve_path(link, orig) - return link + if not is_symlink(link): + raise ValueError("link must point to a symlink on the system") + while is_symlink(link): + orig = os.path.dirname(link) + link = readlink(link) + link = resolve_path(link, orig) + return link def readlink(link): - """ - readlink(link) -> target - Return a string representing the path to which the symbolic link points. - """ - handle = api.CreateFile( - link, - 0, - 0, - None, - api.OPEN_EXISTING, - api.FILE_FLAG_OPEN_REPARSE_POINT | api.FILE_FLAG_BACKUP_SEMANTICS, - None, - ) + """ + readlink(link) -> target + Return a string representing the path to which the symbolic link points. 
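A usage sketch for trace_symlink_target above. Each hop is resolved with resolve_path against the directory of the link just read, so relative link targets are interpreted the way the shell would interpret them:

# follow a chain of symlinks to the real file; path illustrative
final = trace_symlink_target('C:\\links\\app')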
+ """ + handle = api.CreateFile( + link, + 0, + 0, + None, + api.OPEN_EXISTING, + api.FILE_FLAG_OPEN_REPARSE_POINT | api.FILE_FLAG_BACKUP_SEMANTICS, + None, + ) - if handle == api.INVALID_HANDLE_VALUE: - raise WindowsError() + if handle == api.INVALID_HANDLE_VALUE: + raise WindowsError() - res = reparse.DeviceIoControl( - handle, api.FSCTL_GET_REPARSE_POINT, None, 10240) + res = reparse.DeviceIoControl(handle, api.FSCTL_GET_REPARSE_POINT, None, 10240) - bytes = create_string_buffer(res) - p_rdb = cast(bytes, POINTER(api.REPARSE_DATA_BUFFER)) - rdb = p_rdb.contents - if not rdb.tag == api.IO_REPARSE_TAG_SYMLINK: - raise RuntimeError("Expected IO_REPARSE_TAG_SYMLINK, but got %d" % rdb.tag) + bytes = create_string_buffer(res) + p_rdb = cast(bytes, POINTER(api.REPARSE_DATA_BUFFER)) + rdb = p_rdb.contents + if not rdb.tag == api.IO_REPARSE_TAG_SYMLINK: + raise RuntimeError("Expected IO_REPARSE_TAG_SYMLINK, but got %d" % rdb.tag) - handle_nonzero_success(api.CloseHandle(handle)) - return rdb.get_substitute_name() + handle_nonzero_success(api.CloseHandle(handle)) + return rdb.get_substitute_name() def patch_os_module(): - """ - jaraco.windows provides the os.symlink and os.readlink functions. - Monkey-patch the os module to include them if not present. - """ - if not hasattr(os, 'symlink'): - os.symlink = symlink - os.path.islink = islink - if not hasattr(os, 'readlink'): - os.readlink = readlink + """ + jaraco.windows provides the os.symlink and os.readlink functions. + Monkey-patch the os module to include them if not present. + """ + if not hasattr(os, 'symlink'): + os.symlink = symlink + os.path.islink = islink + if not hasattr(os, 'readlink'): + os.readlink = readlink def find_symlinks(root): - for dirpath, dirnames, filenames in os.walk(root): - for name in dirnames + filenames: - pathname = os.path.join(dirpath, name) - if is_symlink(pathname): - yield pathname - # don't traverse symlinks - if name in dirnames: - dirnames.remove(name) + for dirpath, dirnames, filenames in os.walk(root): + for name in dirnames + filenames: + pathname = os.path.join(dirpath, name) + if is_symlink(pathname): + yield pathname + # don't traverse symlinks + if name in dirnames: + dirnames.remove(name) def find_symlinks_cmd(): - """ - %prog [start-path] - Search the specified path (defaults to the current directory) for symlinks, - printing the source and target on each line. - """ - from optparse import OptionParser - from textwrap import dedent - parser = OptionParser(usage=dedent(find_symlinks_cmd.__doc__).strip()) - options, args = parser.parse_args() - if not args: - args = ['.'] - root = args.pop() - if args: - parser.error("unexpected argument(s)") - try: - for symlink in find_symlinks(root): - target = readlink(symlink) - dir = ['', 'D'][os.path.isdir(symlink)] - msg = '{dir:2}{symlink} --> {target}'.format(**locals()) - print(msg) - except KeyboardInterrupt: - pass + """ + %prog [start-path] + Search the specified path (defaults to the current directory) for symlinks, + printing the source and target on each line. 
+ """ + from optparse import OptionParser + from textwrap import dedent + + parser = OptionParser(usage=dedent(find_symlinks_cmd.__doc__).strip()) + options, args = parser.parse_args() + if not args: + args = ['.'] + root = args.pop() + if args: + parser.error("unexpected argument(s)") + try: + for symlink in find_symlinks(root): + target = readlink(symlink) + dir = ['', 'D'][os.path.isdir(symlink)] + msg = '{dir:2}{symlink} --> {target}'.format(**locals()) + print(msg) + except KeyboardInterrupt: + pass -@six.add_metaclass(binary.BitMask) -class FileAttributes(int): +class FileAttributes(int, metaclass=binary.BitMask): - # extract the values from the stat module on Python 3.5 - # and later. - locals().update( - (name.split('FILE_ATTRIBUTES_')[1].lower(), value) - for name, value in vars(stat).items() - if name.startswith('FILE_ATTRIBUTES_') - ) + # extract the values from the stat module on Python 3.5 + # and later. + locals().update( + (name.split('FILE_ATTRIBUTES_')[1].lower(), value) + for name, value in vars(stat).items() + if name.startswith('FILE_ATTRIBUTES_') + ) - # For Python 3.4 and earlier, define the constants here - archive = 0x20 - compressed = 0x800 - hidden = 0x2 - device = 0x40 - directory = 0x10 - encrypted = 0x4000 - normal = 0x80 - not_content_indexed = 0x2000 - offline = 0x1000 - read_only = 0x1 - reparse_point = 0x400 - sparse_file = 0x200 - system = 0x4 - temporary = 0x100 - virtual = 0x10000 + # For Python 3.4 and earlier, define the constants here + archive = 0x20 + compressed = 0x800 + hidden = 0x2 + device = 0x40 + directory = 0x10 + encrypted = 0x4000 + normal = 0x80 + not_content_indexed = 0x2000 + offline = 0x1000 + read_only = 0x1 + reparse_point = 0x400 + sparse_file = 0x200 + system = 0x4 + temporary = 0x100 + virtual = 0x10000 - @classmethod - def get(cls, filepath): - attrs = api.GetFileAttributes(filepath) - if attrs == api.INVALID_FILE_ATTRIBUTES: - raise WindowsError() - return cls(attrs) + @classmethod + def get(cls, filepath): + attrs = api.GetFileAttributes(filepath) + if attrs == api.INVALID_FILE_ATTRIBUTES: + raise WindowsError() + return cls(attrs) GetFileAttributes = FileAttributes.get def SetFileAttributes(filepath, *attrs): - """ - Set file attributes. e.g.: + """ + Set file attributes. e.g.: - SetFileAttributes('C:\\foo', 'hidden') + SetFileAttributes('C:\\foo', 'hidden') - Each attr must be either a numeric value, a constant defined in - jaraco.windows.filesystem.api, or one of the nice names - defined in this function. - """ - nice_names = collections.defaultdict( - lambda key: key, - hidden='FILE_ATTRIBUTE_HIDDEN', - read_only='FILE_ATTRIBUTE_READONLY', - ) - flags = (getattr(api, nice_names[attr], attr) for attr in attrs) - flags = functools.reduce(operator.or_, flags) - handle_nonzero_success(api.SetFileAttributes(filepath, flags)) + Each attr must be either a numeric value, a constant defined in + jaraco.windows.filesystem.api, or one of the nice names + defined in this function. 
+ """ + nice_names = collections.defaultdict( + lambda key: key, + hidden='FILE_ATTRIBUTE_HIDDEN', + read_only='FILE_ATTRIBUTE_READONLY', + ) + flags = (getattr(api, nice_names[attr], attr) for attr in attrs) + flags = functools.reduce(operator.or_, flags) + handle_nonzero_success(api.SetFileAttributes(filepath, flags)) diff --git a/libs/win/jaraco/windows/filesystem/backports.py b/libs/win/jaraco/windows/filesystem/backports.py index abb45d07..d26c92b8 100644 --- a/libs/win/jaraco/windows/filesystem/backports.py +++ b/libs/win/jaraco/windows/filesystem/backports.py @@ -1,109 +1,107 @@ -from __future__ import unicode_literals - import os.path # realpath taken from https://bugs.python.org/file38057/issue9949-v4.patch def realpath(path): - if isinstance(path, str): - prefix = '\\\\?\\' - unc_prefix = prefix + 'UNC' - new_unc_prefix = '\\' - cwd = os.getcwd() - else: - prefix = b'\\\\?\\' - unc_prefix = prefix + b'UNC' - new_unc_prefix = b'\\' - cwd = os.getcwdb() - had_prefix = path.startswith(prefix) - path, ok = _resolve_path(cwd, path, {}) - # The path returned by _getfinalpathname will always start with \\?\ - - # strip off that prefix unless it was already provided on the original - # path. - if not had_prefix: - # For UNC paths, the prefix will actually be \\?\UNC - handle that - # case as well. - if path.startswith(unc_prefix): - path = new_unc_prefix + path[len(unc_prefix):] - elif path.startswith(prefix): - path = path[len(prefix):] - return path + if isinstance(path, str): + prefix = '\\\\?\\' + unc_prefix = prefix + 'UNC' + new_unc_prefix = '\\' + cwd = os.getcwd() + else: + prefix = b'\\\\?\\' + unc_prefix = prefix + b'UNC' + new_unc_prefix = b'\\' + cwd = os.getcwdb() + had_prefix = path.startswith(prefix) + path, ok = _resolve_path(cwd, path, {}) + # The path returned by _getfinalpathname will always start with \\?\ - + # strip off that prefix unless it was already provided on the original + # path. + if not had_prefix: + # For UNC paths, the prefix will actually be \\?\UNC - handle that + # case as well. + if path.startswith(unc_prefix): + path = new_unc_prefix + path[len(unc_prefix) :] + elif path.startswith(prefix): + path = path[len(prefix) :] + return path -def _resolve_path(path, rest, seen): - # Windows normalizes the path before resolving symlinks; be sure to - # follow the same behavior. - rest = os.path.normpath(rest) +def _resolve_path(path, rest, seen): # noqa: C901 + # Windows normalizes the path before resolving symlinks; be sure to + # follow the same behavior. + rest = os.path.normpath(rest) - if isinstance(rest, str): - sep = '\\' - else: - sep = b'\\' + if isinstance(rest, str): + sep = '\\' + else: + sep = b'\\' - if os.path.isabs(rest): - drive, rest = os.path.splitdrive(rest) - path = drive + sep - rest = rest[1:] + if os.path.isabs(rest): + drive, rest = os.path.splitdrive(rest) + path = drive + sep + rest = rest[1:] - while rest: - name, _, rest = rest.partition(sep) - new_path = os.path.join(path, name) if path else name - if os.path.exists(new_path): - if not rest: - # The whole path exists. Resolve it using the OS. - path = os.path._getfinalpathname(new_path) - else: - # The OS can resolve `new_path`; keep traversing the path. - path = new_path - elif not os.path.lexists(new_path): - # `new_path` does not exist on the filesystem at all. Use the - # OS to resolve `path`, if it exists, and then append the - # remainder. 
- if os.path.exists(path): - path = os.path._getfinalpathname(path) - rest = os.path.join(name, rest) if rest else name - return os.path.join(path, rest), True - else: - # We have a symbolic link that the OS cannot resolve. Try to - # resolve it ourselves. + while rest: + name, _, rest = rest.partition(sep) + new_path = os.path.join(path, name) if path else name + if os.path.exists(new_path): + if not rest: + # The whole path exists. Resolve it using the OS. + path = os.path._getfinalpathname(new_path) + else: + # The OS can resolve `new_path`; keep traversing the path. + path = new_path + elif not os.path.lexists(new_path): + # `new_path` does not exist on the filesystem at all. Use the + # OS to resolve `path`, if it exists, and then append the + # remainder. + if os.path.exists(path): + path = os.path._getfinalpathname(path) + rest = os.path.join(name, rest) if rest else name + return os.path.join(path, rest), True + else: + # We have a symbolic link that the OS cannot resolve. Try to + # resolve it ourselves. - # On Windows, symbolic link resolution can be partially or - # fully disabled [1]. The end result of a disabled symlink - # appears the same as a broken symlink (lexists() returns True - # but exists() returns False). And in both cases, the link can - # still be read using readlink(). Call stat() and check the - # resulting error code to ensure we don't circumvent the - # Windows symbolic link restrictions. - # [1] https://technet.microsoft.com/en-us/library/cc754077.aspx - try: - os.stat(new_path) - except OSError as e: - # WinError 1463: The symbolic link cannot be followed - # because its type is disabled. - if e.winerror == 1463: - raise + # On Windows, symbolic link resolution can be partially or + # fully disabled [1]. The end result of a disabled symlink + # appears the same as a broken symlink (lexists() returns True + # but exists() returns False). And in both cases, the link can + # still be read using readlink(). Call stat() and check the + # resulting error code to ensure we don't circumvent the + # Windows symbolic link restrictions. + # [1] https://technet.microsoft.com/en-us/library/cc754077.aspx + try: + os.stat(new_path) + except OSError as e: + # WinError 1463: The symbolic link cannot be followed + # because its type is disabled. + if e.winerror == 1463: + raise - key = os.path.normcase(new_path) - if key in seen: - # This link has already been seen; try to use the - # previously resolved value. - path = seen[key] - if path is None: - # It has not yet been resolved, which means we must - # have a symbolic link loop. Return what we have - # resolved so far plus the remainder of the path (who - # cares about the Zen of Python?). - path = os.path.join(new_path, rest) if rest else new_path - return path, False - else: - # Mark this link as in the process of being resolved. - seen[key] = None - # Try to resolve it. - path, ok = _resolve_path(path, os.readlink(new_path), seen) - if ok: - # Resolution succeded; store the resolved value. - seen[key] = path - else: - # Resolution failed; punt. - return (os.path.join(path, rest) if rest else path), False - return path, True + key = os.path.normcase(new_path) + if key in seen: + # This link has already been seen; try to use the + # previously resolved value. + path = seen[key] + if path is None: + # It has not yet been resolved, which means we must + # have a symbolic link loop. Return what we have + # resolved so far plus the remainder of the path (who + # cares about the Zen of Python?). 
+ path = os.path.join(new_path, rest) if rest else new_path + return path, False + else: + # Mark this link as in the process of being resolved. + seen[key] = None + # Try to resolve it. + path, ok = _resolve_path(path, os.readlink(new_path), seen) + if ok: + # Resolution succeded; store the resolved value. + seen[key] = path + else: + # Resolution failed; punt. + return (os.path.join(path, rest) if rest else path), False + return path, True diff --git a/libs/win/jaraco/windows/filesystem/change.py b/libs/win/jaraco/windows/filesystem/change.py index 620d9272..7a3c508f 100644 --- a/libs/win/jaraco/windows/filesystem/change.py +++ b/libs/win/jaraco/windows/filesystem/change.py @@ -1,14 +1,10 @@ -# -*- coding: UTF-8 -*- - """ FileChange - Classes and routines for monitoring the file system for changes. + Classes and routines for monitoring the file system for changes. Copyright © 2004, 2011, 2013 Jason R. Coombs """ -from __future__ import print_function - import os import sys import datetime @@ -17,8 +13,6 @@ from threading import Thread import itertools import logging -import six - from more_itertools.recipes import consume import jaraco.text @@ -29,243 +23,237 @@ log = logging.getLogger(__name__) class NotifierException(Exception): - pass + pass class FileFilter(object): - def set_root(self, root): - self.root = root + def set_root(self, root): + self.root = root - def _get_file_path(self, filename): - try: - filename = os.path.join(self.root, filename) - except AttributeError: - pass - return filename + def _get_file_path(self, filename): + try: + filename = os.path.join(self.root, filename) + except AttributeError: + pass + return filename class ModifiedTimeFilter(FileFilter): - """ - Returns true for each call where the modified time of the file is after - the cutoff time. - """ - def __init__(self, cutoff): - self.cutoff = cutoff + """ + Returns true for each call where the modified time of the file is after + the cutoff time. + """ - def __call__(self, file): - filepath = self._get_file_path(file) - last_mod = datetime.datetime.utcfromtimestamp( - os.stat(filepath).st_mtime) - log.debug('{filepath} last modified at {last_mod}.'.format(**vars())) - return last_mod > self.cutoff + def __init__(self, cutoff): + self.cutoff = cutoff + + def __call__(self, file): + filepath = self._get_file_path(file) + last_mod = datetime.datetime.utcfromtimestamp(os.stat(filepath).st_mtime) + log.debug('{filepath} last modified at {last_mod}.'.format(**vars())) + return last_mod > self.cutoff class PatternFilter(FileFilter): - """ - Filter that returns True for files that match pattern (a regular - expression). - """ - def __init__(self, pattern): - self.pattern = ( - re.compile(pattern) if isinstance(pattern, six.string_types) - else pattern - ) + """ + Filter that returns True for files that match pattern (a regular + expression). + """ - def __call__(self, file): - return bool(self.pattern.match(file, re.I)) + def __init__(self, pattern): + self.pattern = re.compile(pattern) if isinstance(pattern, str) else pattern + + def __call__(self, file): + return bool(self.pattern.match(file, re.I)) class GlobFilter(PatternFilter): - """ - Filter that returns True for files that match the pattern (a glob - expression. - """ - def __init__(self, expression): - super(GlobFilter, self).__init__( - self.convert_file_pattern(expression)) + """ + Filter that returns True for files that match the pattern (a glob + expression. 
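A usage sketch for ModifiedTimeFilter above; filters receive bare file names and resolve them against the root given to set_root. (PatternFilter, as written, passes re.I to match's pos argument rather than as a flag, so compile case-insensitivity into the pattern itself if it is wanted.)

import datetime
import os

cutoff = datetime.datetime.utcnow() - datetime.timedelta(hours=1)
recent = ModifiedTimeFilter(cutoff)
recent.set_root('C:\\logs')
fresh = [name for name in os.listdir('C:\\logs') if recent(name)]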
+ """ - @staticmethod - def convert_file_pattern(p): - r""" - converts a filename specification (such as c:\*.*) to an equivelent - regular expression - >>> GlobFilter.convert_file_pattern('/*') - '/.*' - """ - subs = (('\\', '\\\\'), ('.', '\\.'), ('*', '.*'), ('?', '.')) - return jaraco.text.multi_substitution(*subs)(p) + def __init__(self, expression): + super(GlobFilter, self).__init__(self.convert_file_pattern(expression)) + + @staticmethod + def convert_file_pattern(p): + r""" + converts a filename specification (such as c:\*.*) to an equivelent + regular expression + >>> GlobFilter.convert_file_pattern('/*') + '/.*' + """ + subs = (('\\', '\\\\'), ('.', '\\.'), ('*', '.*'), ('?', '.')) + return jaraco.text.multi_substitution(*subs)(p) class AggregateFilter(FileFilter): - """ - This file filter will aggregate the filters passed to it, and when called, - will return the results of each filter ANDed together. - """ - def __init__(self, *filters): - self.filters = filters + """ + This file filter will aggregate the filters passed to it, and when called, + will return the results of each filter ANDed together. + """ - def set_root(self, root): - consume(f.set_root(root) for f in self.filters) + def __init__(self, *filters): + self.filters = filters - def __call__(self, file): - return all(fil(file) for fil in self.filters) + def set_root(self, root): + consume(f.set_root(root) for f in self.filters) + + def __call__(self, file): + return all(fil(file) for fil in self.filters) class OncePerModFilter(FileFilter): - def __init__(self): - self.history = list() + def __init__(self): + self.history = list() - def __call__(self, file): - file = os.path.join(self.root, file) - key = file, os.stat(file).st_mtime - result = key not in self.history - self.history.append(key) - if len(self.history) > 100: - del self.history[-50:] - return result + def __call__(self, file): + file = os.path.join(self.root, file) + key = file, os.stat(file).st_mtime + result = key not in self.history + self.history.append(key) + if len(self.history) > 100: + del self.history[-50:] + return result def files_with_path(files, path): - return (os.path.join(path, file) for file in files) + return (os.path.join(path, file) for file in files) def get_file_paths(walk_result): - root, dirs, files = walk_result - return files_with_path(files, root) + root, dirs, files = walk_result + return files_with_path(files, root) class Notifier(object): - def __init__(self, root='.', filters=[]): - # assign the root, verify it exists - self.root = root - if not os.path.isdir(self.root): - raise NotifierException( - 'Root directory "%s" does not exist' % self.root) - self.filters = filters + def __init__(self, root='.', filters=[]): + # assign the root, verify it exists + self.root = root + if not os.path.isdir(self.root): + raise NotifierException('Root directory "%s" does not exist' % self.root) + self.filters = filters - self.watch_subtree = False - self.quit_event = event.CreateEvent(None, 0, 0, None) - self.opm_filter = OncePerModFilter() + self.watch_subtree = False + self.quit_event = event.CreateEvent(None, 0, 0, None) + self.opm_filter = OncePerModFilter() - def __del__(self): - try: - fs.FindCloseChangeNotification(self.hChange) - except Exception: - pass + def __del__(self): + try: + fs.FindCloseChangeNotification(self.hChange) + except Exception: + pass - def _get_change_handle(self): - # set up to monitor the directory tree specified - self.hChange = fs.FindFirstChangeNotification( - self.root, - self.watch_subtree, - 
fs.FILE_NOTIFY_CHANGE_LAST_WRITE, - ) + def _get_change_handle(self): + # set up to monitor the directory tree specified + self.hChange = fs.FindFirstChangeNotification( + self.root, self.watch_subtree, fs.FILE_NOTIFY_CHANGE_LAST_WRITE + ) - # make sure it worked; if not, bail - INVALID_HANDLE_VALUE = fs.INVALID_HANDLE_VALUE - if self.hChange == INVALID_HANDLE_VALUE: - raise NotifierException( - 'Could not set up directory change notification') + # make sure it worked; if not, bail + INVALID_HANDLE_VALUE = fs.INVALID_HANDLE_VALUE + if self.hChange == INVALID_HANDLE_VALUE: + raise NotifierException('Could not set up directory change notification') - @staticmethod - def _filtered_walk(path, file_filter): - """ - static method that calls os.walk, but filters out - anything that doesn't match the filter - """ - for root, dirs, files in os.walk(path): - log.debug('looking in %s', root) - log.debug('files is %s', files) - file_filter.set_root(root) - files = filter(file_filter, files) - log.debug('filtered files is %s', files) - yield (root, dirs, files) + @staticmethod + def _filtered_walk(path, file_filter): + """ + static method that calls os.walk, but filters out + anything that doesn't match the filter + """ + for root, dirs, files in os.walk(path): + log.debug('looking in %s', root) + log.debug('files is %s', files) + file_filter.set_root(root) + files = filter(file_filter, files) + log.debug('filtered files is %s', files) + yield (root, dirs, files) - def quit(self): - event.SetEvent(self.quit_event) + def quit(self): + event.SetEvent(self.quit_event) class BlockingNotifier(Notifier): + @staticmethod + def wait_results(*args): + """calls WaitForMultipleObjects repeatedly with args""" + return itertools.starmap(event.WaitForMultipleObjects, itertools.repeat(args)) - @staticmethod - def wait_results(*args): - """ calls WaitForMultipleObjects repeatedly with args """ - return itertools.starmap( - event.WaitForMultipleObjects, - itertools.repeat(args)) + def get_changed_files(self): + self._get_change_handle() + check_time = datetime.datetime.utcnow() + # block (sleep) until something changes in the + # target directory or a quit is requested. + # timeout so we can catch keyboard interrupts or other exceptions + events = (self.hChange, self.quit_event) + for result in self.wait_results(events, False, 1000): + if result == event.WAIT_TIMEOUT: + continue + index = result - event.WAIT_OBJECT_0 + if events[index] is self.quit_event: + # quit was received; stop yielding results + return - def get_changed_files(self): - self._get_change_handle() - check_time = datetime.datetime.utcnow() - # block (sleep) until something changes in the - # target directory or a quit is requested. - # timeout so we can catch keyboard interrupts or other exceptions - events = (self.hChange, self.quit_event) - for result in self.wait_results(events, False, 1000): - if result == event.WAIT_TIMEOUT: - continue - index = result - event.WAIT_OBJECT_0 - if events[index] is self.quit_event: - # quit was received; stop yielding results - return + # something has changed. + log.debug('Change notification received') + fs.FindNextChangeNotification(self.hChange) + next_check_time = datetime.datetime.utcnow() + log.debug('Looking for all files changed after %s', check_time) + for file in self.find_files_after(check_time): + yield file + check_time = next_check_time - # something has changed. 
- log.debug('Change notification received') - fs.FindNextChangeNotification(self.hChange) - next_check_time = datetime.datetime.utcnow() - log.debug('Looking for all files changed after %s', check_time) - for file in self.find_files_after(check_time): - yield file - check_time = next_check_time - - def find_files_after(self, cutoff): - mtf = ModifiedTimeFilter(cutoff) - af = AggregateFilter(mtf, self.opm_filter, *self.filters) - results = Notifier._filtered_walk(self.root, af) - results = itertools.imap(get_file_paths, results) - if self.watch_subtree: - result = itertools.chain(*results) - else: - result = next(results) - return result + def find_files_after(self, cutoff): + mtf = ModifiedTimeFilter(cutoff) + af = AggregateFilter(mtf, self.opm_filter, *self.filters) + results = Notifier._filtered_walk(self.root, af) + results = itertools.imap(get_file_paths, results) + if self.watch_subtree: + result = itertools.chain(*results) + else: + result = next(results) + return result class ThreadedNotifier(BlockingNotifier, Thread): - r""" - ThreadedNotifier provides a simple interface that calls the handler - for each file rooted in root that passes the filters. It runs as its own - thread, so must be started as such:: + r""" + ThreadedNotifier provides a simple interface that calls the handler + for each file rooted in root that passes the filters. It runs as its own + thread, so must be started as such:: - notifier = ThreadedNotifier('c:\\', handler = StreamHandler()) - notifier.start() - C:\Autoexec.bat changed. - """ - def __init__(self, root='.', filters=[], handler=lambda file: None): - # init notifier stuff - BlockingNotifier.__init__(self, root, filters) - # init thread stuff - Thread.__init__(self) - # set it as a daemon thread so that it doesn't block waiting to close. - # I tried setting __del__(self) to .quit(), but unfortunately, there - # are references to this object in the win32api stuff, so __del__ - # never gets called. - self.setDaemon(True) + notifier = ThreadedNotifier('c:\\', handler = StreamHandler()) + notifier.start() + C:\Autoexec.bat changed. + """ - self.handle = handler + def __init__(self, root='.', filters=[], handler=lambda file: None): + # init notifier stuff + BlockingNotifier.__init__(self, root, filters) + # init thread stuff + Thread.__init__(self) + # set it as a daemon thread so that it doesn't block waiting to close. + # I tried setting __del__(self) to .quit(), but unfortunately, there + # are references to this object in the win32api stuff, so __del__ + # never gets called. + self.setDaemon(True) - def run(self): - for file in self.get_changed_files(): - self.handle(file) + self.handle = handler + + def run(self): + for file in self.get_changed_files(): + self.handle(file) class StreamHandler(object): - """ - StreamHandler: a sample handler object for use with the threaded - notifier that will announce by writing to the supplied stream - (stdout by default) the name of the file. - """ - def __init__(self, output=sys.stdout): - self.output = output + """ + StreamHandler: a sample handler object for use with the threaded + notifier that will announce by writing to the supplied stream + (stdout by default) the name of the file. 
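A usage sketch for BlockingNotifier above. One caveat: find_files_after still calls itertools.imap, which no longer exists on Python 3 (plain map is the modern spelling), so the loop below assumes that call has been updated:

notifier = BlockingNotifier('C:\\logs', filters=[GlobFilter('*.log')])
for changed in notifier.get_changed_files():
    print(changed, 'changed')  # blocks between change notifications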
+ """ - def __call__(self, filename): - self.output.write('%s changed.\n' % filename) + def __init__(self, output=sys.stdout): + self.output = output + + def __call__(self, filename): + self.output.write('%s changed.\n' % filename) diff --git a/libs/win/jaraco/windows/inet.py b/libs/win/jaraco/windows/inet.py index 37c40cda..4c98dc54 100644 --- a/libs/win/jaraco/windows/inet.py +++ b/libs/win/jaraco/windows/inet.py @@ -3,8 +3,6 @@ Some routines for retrieving the addresses from the local network config. """ -from __future__ import print_function - import itertools import ctypes @@ -13,112 +11,108 @@ from jaraco.windows.api import errors, inet def GetAdaptersAddresses(): - size = ctypes.c_ulong() - res = inet.GetAdaptersAddresses(0, 0, None, None, size) - if res != errors.ERROR_BUFFER_OVERFLOW: - raise RuntimeError("Error getting structure length (%d)" % res) - print(size.value) - pointer_type = ctypes.POINTER(inet.IP_ADAPTER_ADDRESSES) - buffer = ctypes.create_string_buffer(size.value) - struct_p = ctypes.cast(buffer, pointer_type) - res = inet.GetAdaptersAddresses(0, 0, None, struct_p, size) - if res != errors.NO_ERROR: - raise RuntimeError("Error retrieving table (%d)" % res) - while struct_p: - yield struct_p.contents - struct_p = struct_p.contents.next + size = ctypes.c_ulong() + res = inet.GetAdaptersAddresses(0, 0, None, None, size) + if res != errors.ERROR_BUFFER_OVERFLOW: + raise RuntimeError("Error getting structure length (%d)" % res) + print(size.value) + pointer_type = ctypes.POINTER(inet.IP_ADAPTER_ADDRESSES) + buffer = ctypes.create_string_buffer(size.value) + struct_p = ctypes.cast(buffer, pointer_type) + res = inet.GetAdaptersAddresses(0, 0, None, struct_p, size) + if res != errors.NO_ERROR: + raise RuntimeError("Error retrieving table (%d)" % res) + while struct_p: + yield struct_p.contents + struct_p = struct_p.contents.next class AllocatedTable(object): - """ - Both the interface table and the ip address table use the same - technique to store arrays of structures of variable length. This - base class captures the functionality to retrieve and access those - table entries. + """ + Both the interface table and the ip address table use the same + technique to store arrays of structures of variable length. This + base class captures the functionality to retrieve and access those + table entries. - The subclass needs to define three class attributes: - method: a callable that takes three arguments - a pointer to - the structure, the length of the data contained by the - structure, and a boolean of whether the result should - be sorted. - structure: a C structure defininition that describes the table - format. - row_structure: a C structure definition that describes the row - format. - """ - def __get_table_size(self): - """ - Retrieve the size of the buffer needed by calling the method - with a null pointer and length of zero. This should trigger an - insufficient buffer error and return the size needed for the - buffer. - """ - length = ctypes.wintypes.DWORD() - res = self.method(None, length, False) - if res != errors.ERROR_INSUFFICIENT_BUFFER: - raise RuntimeError("Error getting table length (%d)" % res) - return length.value + The subclass needs to define three class attributes: + method: a callable that takes three arguments - a pointer to + the structure, the length of the data contained by the + structure, and a boolean of whether the result should + be sorted. + structure: a C structure defininition that describes the table + format. 
+ row_structure: a C structure definition that describes the row + format. + """ - def get_table(self): - """ - Get the table - """ - buffer_length = self.__get_table_size() - returned_buffer_length = ctypes.wintypes.DWORD(buffer_length) - buffer = ctypes.create_string_buffer(buffer_length) - pointer_type = ctypes.POINTER(self.structure) - table_p = ctypes.cast(buffer, pointer_type) - res = self.method(table_p, returned_buffer_length, False) - if res != errors.NO_ERROR: - raise RuntimeError("Error retrieving table (%d)" % res) - return table_p.contents + def __get_table_size(self): + """ + Retrieve the size of the buffer needed by calling the method + with a null pointer and length of zero. This should trigger an + insufficient buffer error and return the size needed for the + buffer. + """ + length = ctypes.wintypes.DWORD() + res = self.method(None, length, False) + if res != errors.ERROR_INSUFFICIENT_BUFFER: + raise RuntimeError("Error getting table length (%d)" % res) + return length.value - @property - def entries(self): - """ - Using the table structure, return the array of entries based - on the table size. - """ - table = self.get_table() - entries_array = self.row_structure * table.num_entries - pointer_type = ctypes.POINTER(entries_array) - return ctypes.cast(table.entries, pointer_type).contents + def get_table(self): + """ + Get the table + """ + buffer_length = self.__get_table_size() + returned_buffer_length = ctypes.wintypes.DWORD(buffer_length) + buffer = ctypes.create_string_buffer(buffer_length) + pointer_type = ctypes.POINTER(self.structure) + table_p = ctypes.cast(buffer, pointer_type) + res = self.method(table_p, returned_buffer_length, False) + if res != errors.NO_ERROR: + raise RuntimeError("Error retrieving table (%d)" % res) + return table_p.contents + + @property + def entries(self): + """ + Using the table structure, return the array of entries based + on the table size. 
+ """ + table = self.get_table() + entries_array = self.row_structure * table.num_entries + pointer_type = ctypes.POINTER(entries_array) + return ctypes.cast(table.entries, pointer_type).contents class InterfaceTable(AllocatedTable): - method = inet.GetIfTable - structure = inet.MIB_IFTABLE - row_structure = inet.MIB_IFROW + method = inet.GetIfTable + structure = inet.MIB_IFTABLE + row_structure = inet.MIB_IFROW class AddressTable(AllocatedTable): - method = inet.GetIpAddrTable - structure = inet.MIB_IPADDRTABLE - row_structure = inet.MIB_IPADDRROW + method = inet.GetIpAddrTable + structure = inet.MIB_IPADDRTABLE + row_structure = inet.MIB_IPADDRROW class AddressManager(object): - @staticmethod - def hardware_address_to_string(addr): - hex_bytes = (byte.encode('hex') for byte in addr) - return ':'.join(hex_bytes) + @staticmethod + def hardware_address_to_string(addr): + hex_bytes = (byte.encode('hex') for byte in addr) + return ':'.join(hex_bytes) - def get_host_mac_address_strings(self): - return ( - self.hardware_address_to_string(addr) - for addr in self.get_host_mac_addresses()) + def get_host_mac_address_strings(self): + return ( + self.hardware_address_to_string(addr) + for addr in self.get_host_mac_addresses() + ) - def get_host_ip_address_strings(self): - return itertools.imap(str, self.get_host_ip_addresses()) + def get_host_ip_address_strings(self): + return itertools.imap(str, self.get_host_ip_addresses()) - def get_host_mac_addresses(self): - return ( - entry.physical_address - for entry in InterfaceTable().entries - ) + def get_host_mac_addresses(self): + return (entry.physical_address for entry in InterfaceTable().entries) - def get_host_ip_addresses(self): - return ( - entry.address - for entry in AddressTable().entries - ) + def get_host_ip_addresses(self): + return (entry.address for entry in AddressTable().entries) diff --git a/libs/win/jaraco/windows/lib.py b/libs/win/jaraco/windows/lib.py index 0602c8e0..64ebfffb 100644 --- a/libs/win/jaraco/windows/lib.py +++ b/libs/win/jaraco/windows/lib.py @@ -4,18 +4,18 @@ from .api import library def find_lib(lib): - r""" - Find the DLL for a given library. + r""" + Find the DLL for a given library. 
- Accepts a string or loaded module + Accepts a string or loaded module - >>> print(find_lib('kernel32').lower()) - c:\windows\system32\kernel32.dll - """ - if isinstance(lib, str): - lib = getattr(ctypes.windll, lib) + >>> print(find_lib('kernel32').lower()) + c:\windows\system32\kernel32.dll + """ + if isinstance(lib, str): + lib = getattr(ctypes.windll, lib) - size = 1024 - result = ctypes.create_unicode_buffer(size) - library.GetModuleFileName(lib._handle, result, size) - return result.value + size = 1024 + result = ctypes.create_unicode_buffer(size) + library.GetModuleFileName(lib._handle, result, size) + return result.value diff --git a/libs/win/jaraco/windows/memory.py b/libs/win/jaraco/windows/memory.py index d4bcb83c..1e989376 100644 --- a/libs/win/jaraco/windows/memory.py +++ b/libs/win/jaraco/windows/memory.py @@ -5,25 +5,25 @@ from .api import memory class LockedMemory(object): - def __init__(self, handle): - self.handle = handle + def __init__(self, handle): + self.handle = handle - def __enter__(self): - self.data_ptr = memory.GlobalLock(self.handle) - if not self.data_ptr: - del self.data_ptr - raise WinError() - return self + def __enter__(self): + self.data_ptr = memory.GlobalLock(self.handle) + if not self.data_ptr: + del self.data_ptr + raise WinError() + return self - def __exit__(self, *args): - memory.GlobalUnlock(self.handle) - del self.data_ptr + def __exit__(self, *args): + memory.GlobalUnlock(self.handle) + del self.data_ptr - @property - def data(self): - with self: - return ctypes.string_at(self.data_ptr, self.size) + @property + def data(self): + with self: + return ctypes.string_at(self.data_ptr, self.size) - @property - def size(self): - return memory.GlobalSize(self.data_ptr) + @property + def size(self): + return memory.GlobalSize(self.data_ptr) diff --git a/libs/win/jaraco/windows/mmap.py b/libs/win/jaraco/windows/mmap.py index 11460894..c64c2548 100644 --- a/libs/win/jaraco/windows/mmap.py +++ b/libs/win/jaraco/windows/mmap.py @@ -1,63 +1,66 @@ import ctypes.wintypes -import six - from .error import handle_nonzero_success from .api import memory class MemoryMap(object): - """ - A memory map object which can have security attributes overridden. - """ - def __init__(self, name, length, security_attributes=None): - self.name = name - self.length = length - self.security_attributes = security_attributes - self.pos = 0 + """ + A memory map object which can have security attributes overridden. 
+ """ - def __enter__(self): - p_SA = ( - ctypes.byref(self.security_attributes) - if self.security_attributes else None - ) - INVALID_HANDLE_VALUE = -1 - PAGE_READWRITE = 0x4 - FILE_MAP_WRITE = 0x2 - filemap = ctypes.windll.kernel32.CreateFileMappingW( - INVALID_HANDLE_VALUE, p_SA, PAGE_READWRITE, 0, self.length, - six.text_type(self.name)) - handle_nonzero_success(filemap) - if filemap == INVALID_HANDLE_VALUE: - raise Exception("Failed to create file mapping") - self.filemap = filemap - self.view = memory.MapViewOfFile(filemap, FILE_MAP_WRITE, 0, 0, 0) - return self + def __init__(self, name, length, security_attributes=None): + self.name = name + self.length = length + self.security_attributes = security_attributes + self.pos = 0 - def seek(self, pos): - self.pos = pos + def __enter__(self): + p_SA = ( + ctypes.byref(self.security_attributes) if self.security_attributes else None + ) + INVALID_HANDLE_VALUE = -1 + PAGE_READWRITE = 0x4 + FILE_MAP_WRITE = 0x2 + filemap = ctypes.windll.kernel32.CreateFileMappingW( + INVALID_HANDLE_VALUE, + p_SA, + PAGE_READWRITE, + 0, + self.length, + str(self.name), + ) + handle_nonzero_success(filemap) + if filemap == INVALID_HANDLE_VALUE: + raise Exception("Failed to create file mapping") + self.filemap = filemap + self.view = memory.MapViewOfFile(filemap, FILE_MAP_WRITE, 0, 0, 0) + return self - def write(self, msg): - assert isinstance(msg, bytes) - n = len(msg) - if self.pos + n >= self.length: # A little safety. - raise ValueError("Refusing to write %d bytes" % n) - dest = self.view + self.pos - length = ctypes.c_size_t(n) - ctypes.windll.kernel32.RtlMoveMemory(dest, msg, length) - self.pos += n + def seek(self, pos): + self.pos = pos - def read(self, n): - """ - Read n bytes from mapped view. - """ - out = ctypes.create_string_buffer(n) - source = self.view + self.pos - length = ctypes.c_size_t(n) - ctypes.windll.kernel32.RtlMoveMemory(out, source, length) - self.pos += n - return out.raw + def write(self, msg): + assert isinstance(msg, bytes) + n = len(msg) + if self.pos + n >= self.length: # A little safety. + raise ValueError("Refusing to write %d bytes" % n) + dest = self.view + self.pos + length = ctypes.c_size_t(n) + ctypes.windll.kernel32.RtlMoveMemory(dest, msg, length) + self.pos += n - def __exit__(self, exc_type, exc_val, tb): - ctypes.windll.kernel32.UnmapViewOfFile(self.view) - ctypes.windll.kernel32.CloseHandle(self.filemap) + def read(self, n): + """ + Read n bytes from mapped view. 
+ """ + out = ctypes.create_string_buffer(n) + source = self.view + self.pos + length = ctypes.c_size_t(n) + ctypes.windll.kernel32.RtlMoveMemory(out, source, length) + self.pos += n + return out.raw + + def __exit__(self, exc_type, exc_val, tb): + ctypes.windll.kernel32.UnmapViewOfFile(self.view) + ctypes.windll.kernel32.CloseHandle(self.filemap) diff --git a/libs/win/jaraco/windows/msie.py b/libs/win/jaraco/windows/msie.py index c4b5793c..d4136182 100644 --- a/libs/win/jaraco/windows/msie.py +++ b/libs/win/jaraco/windows/msie.py @@ -1,5 +1,3 @@ -# -*- coding: UTF-8 -*- - """cookies.py Cookie support utilities @@ -8,52 +6,50 @@ Cookie support utilities import os import itertools -import six - class CookieMonster(object): - "Read cookies out of a user's IE cookies file" + "Read cookies out of a user's IE cookies file" - @property - def cookie_dir(self): - import _winreg as winreg - key = winreg.OpenKeyEx( - winreg.HKEY_CURRENT_USER, 'Software' - '\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders') - cookie_dir, type = winreg.QueryValueEx(key, 'Cookies') - return cookie_dir + @property + def cookie_dir(self): + import _winreg as winreg - def entries(self, filename): - with open(os.path.join(self.cookie_dir, filename)) as cookie_file: - while True: - entry = itertools.takewhile( - self.is_not_cookie_delimiter, - cookie_file) - entry = list(map(six.text_type.rstrip, entry)) - if not entry: - break - cookie = self.make_cookie(*entry) - yield cookie + key = winreg.OpenKeyEx( + winreg.HKEY_CURRENT_USER, + 'Software' r'\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders', + ) + cookie_dir, type = winreg.QueryValueEx(key, 'Cookies') + return cookie_dir - @staticmethod - def is_not_cookie_delimiter(s): - return s != '*\n' + def entries(self, filename): + with open(os.path.join(self.cookie_dir, filename)) as cookie_file: + while True: + entry = itertools.takewhile(self.is_not_cookie_delimiter, cookie_file) + entry = [item.rstrip() for item in entry] + if not entry: + break + cookie = self.make_cookie(*entry) + yield cookie - @staticmethod - def make_cookie( - key, value, domain, flags, ExpireLow, ExpireHigh, - CreateLow, CreateHigh): - expires = (int(ExpireHigh) << 32) | int(ExpireLow) - created = (int(CreateHigh) << 32) | int(CreateLow) - flags = int(flags) - domain, sep, path = domain.partition('/') - path = '/' + path - return dict( - key=key, - value=value, - domain=domain, - flags=flags, - expires=expires, - created=created, - path=path, - ) + @staticmethod + def is_not_cookie_delimiter(s): + return s != '*\n' + + @staticmethod + def make_cookie( + key, value, domain, flags, ExpireLow, ExpireHigh, CreateLow, CreateHigh + ): + expires = (int(ExpireHigh) << 32) | int(ExpireLow) + created = (int(CreateHigh) << 32) | int(CreateLow) + flags = int(flags) + domain, sep, path = domain.partition('/') + path = '/' + path + return dict( + key=key, + value=value, + domain=domain, + flags=flags, + expires=expires, + created=created, + path=path, + ) diff --git a/libs/win/jaraco/windows/msvc.py b/libs/win/jaraco/windows/msvc.py new file mode 100644 index 00000000..b060c7b0 --- /dev/null +++ b/libs/win/jaraco/windows/msvc.py @@ -0,0 +1,37 @@ +import subprocess + + +default_components = [ + 'Microsoft.VisualStudio.Component.CoreEditor', + 'Microsoft.VisualStudio.Workload.CoreEditor', + 'Microsoft.VisualStudio.Component.Roslyn.Compiler', + 'Microsoft.Component.MSBuild', + 'Microsoft.VisualStudio.Component.TextTemplating', + 'Microsoft.VisualStudio.Component.VC.CoreIde', + 
'Microsoft.VisualStudio.Component.VC.Tools.x86.x64', + 'Microsoft.VisualStudio.Component.VC.Tools.ARM64', + 'Microsoft.VisualStudio.Component.Windows10SDK.19041', + 'Microsoft.VisualStudio.Component.VC.Redist.14.Latest', + 'Microsoft.VisualStudio.ComponentGroup.NativeDesktop.Core', + 'Microsoft.VisualStudio.Workload.NativeDesktop', +] + + +def install(components=default_components): + cmd = [ + 'vs_buildtools', + '--quiet', + '--wait', + '--norestart', + '--nocache', + '--installPath', + 'C:\\BuildTools', + ] + for component in components: + cmd += ['--add', component] + res = subprocess.Popen(cmd).wait() + if res != 3010: + raise SystemExit(res) + + +__name__ == '__main__' and install() diff --git a/libs/win/jaraco/windows/net.py b/libs/win/jaraco/windows/net.py index 709f0dbf..4057a5f3 100644 --- a/libs/win/jaraco/windows/net.py +++ b/libs/win/jaraco/windows/net.py @@ -2,29 +2,30 @@ API hooks for network stuff. """ -__all__ = ('AddConnection') +__all__ = ('AddConnection',) from jaraco.windows.error import WindowsError from .api import net def AddConnection( - remote_name, type=net.RESOURCETYPE_ANY, local_name=None, - provider_name=None, user=None, password=None, flags=0): - resource = net.NETRESOURCE( - type=type, - remote_name=remote_name, - local_name=local_name, - provider_name=provider_name, - # WNetAddConnection2 ignores the other members of NETRESOURCE - ) + remote_name, + type=net.RESOURCETYPE_ANY, + local_name=None, + provider_name=None, + user=None, + password=None, + flags=0, +): + resource = net.NETRESOURCE( + type=type, + remote_name=remote_name, + local_name=local_name, + provider_name=provider_name, + # WNetAddConnection2 ignores the other members of NETRESOURCE + ) - result = net.WNetAddConnection2( - resource, - password, - user, - flags, - ) + result = net.WNetAddConnection2(resource, password, user, flags) - if result != 0: - raise WindowsError(result) diff --git a/libs/win/jaraco/windows/power.py b/libs/win/jaraco/windows/power.py index 8d8276fa..f88f1308 100644 --- a/libs/win/jaraco/windows/power.py +++ b/libs/win/jaraco/windows/power.py @@ -1,77 +1,75 @@ -# -*- coding: utf-8 -*- - -from __future__ import print_function - import itertools import contextlib from more_itertools.recipes import consume, unique_justseen + try: - import wmi as wmilib + import wmi as wmilib except ImportError: - pass + pass from jaraco.windows.error import handle_nonzero_success from .api import power def GetSystemPowerStatus(): - stat = power.SYSTEM_POWER_STATUS() - handle_nonzero_success(GetSystemPowerStatus(stat)) - return stat + stat = power.SYSTEM_POWER_STATUS() + handle_nonzero_success(power.GetSystemPowerStatus(stat)) + return stat def _init_power_watcher(): - global power_watcher - if 'power_watcher' not in globals(): - wmi = wmilib.WMI() - query = 'SELECT * from Win32_PowerManagementEvent' - power_watcher = wmi.ExecNotificationQuery(query) + global power_watcher + if 'power_watcher' not in globals(): + wmi = wmilib.WMI() + query = 'SELECT * from Win32_PowerManagementEvent' + power_watcher = wmi.ExecNotificationQuery(query) def get_power_management_events(): - _init_power_watcher() - while True: - yield power_watcher.NextEvent() + _init_power_watcher() + while True: + yield power_watcher.NextEvent() def wait_for_power_status_change(): - EVT_POWER_STATUS_CHANGE = 10 + EVT_POWER_STATUS_CHANGE = 10 - def not_power_status_change(evt): - return evt.EventType != EVT_POWER_STATUS_CHANGE - events = get_power_management_events() -
consume(itertools.takewhile(not_power_status_change, events)) + def not_power_status_change(evt): + return evt.EventType != EVT_POWER_STATUS_CHANGE + + events = get_power_management_events() + consume(itertools.takewhile(not_power_status_change, events)) def get_unique_power_states(): - """ - Just like get_power_states, but ensures values are returned only - when the state changes. - """ - return unique_justseen(get_power_states()) + """ + Just like get_power_states, but ensures values are returned only + when the state changes. + """ + return unique_justseen(get_power_states()) def get_power_states(): - """ - Continuously return the power state of the system when it changes. - This function will block indefinitely if the power state never - changes. - """ - while True: - state = GetSystemPowerStatus() - yield state.ac_line_status_string - wait_for_power_status_change() + """ + Continuously return the power state of the system when it changes. + This function will block indefinitely if the power state never + changes. + """ + while True: + state = GetSystemPowerStatus() + yield state.ac_line_status_string + wait_for_power_status_change() @contextlib.contextmanager def no_sleep(): - """ - Context that prevents the computer from going to sleep. - """ - mode = power.ES.continuous | power.ES.system_required - handle_nonzero_success(power.SetThreadExecutionState(mode)) - try: - yield - finally: - handle_nonzero_success(power.SetThreadExecutionState(power.ES.continuous)) + """ + Context that prevents the computer from going to sleep. + """ + mode = power.ES.continuous | power.ES.system_required + handle_nonzero_success(power.SetThreadExecutionState(mode)) + try: + yield + finally: + handle_nonzero_success(power.SetThreadExecutionState(power.ES.continuous)) diff --git a/libs/win/jaraco/windows/privilege.py b/libs/win/jaraco/windows/privilege.py index 848a526d..7a75bcfa 100644 --- a/libs/win/jaraco/windows/privilege.py +++ b/libs/win/jaraco/windows/privilege.py @@ -1,5 +1,3 @@ -from __future__ import print_function - import ctypes from ctypes import wintypes @@ -9,134 +7,138 @@ from .api import process def get_process_token(): - """ - Get the current process token - """ - token = wintypes.HANDLE() - res = process.OpenProcessToken( - process.GetCurrentProcess(), process.TOKEN_ALL_ACCESS, token) - if not res > 0: - raise RuntimeError("Couldn't get process token") - return token + """ + Get the current process token + """ + token = wintypes.HANDLE() + res = process.OpenProcessToken( + process.GetCurrentProcess(), process.TOKEN_ALL_ACCESS, token + ) + if not res > 0: + raise RuntimeError("Couldn't get process token") + return token def get_symlink_luid(): - """ - Get the LUID for the SeCreateSymbolicLinkPrivilege - """ - symlink_luid = privilege.LUID() - res = privilege.LookupPrivilegeValue( - None, "SeCreateSymbolicLinkPrivilege", symlink_luid) - if not res > 0: - raise RuntimeError("Couldn't lookup privilege value") - return symlink_luid + """ + Get the LUID for the SeCreateSymbolicLinkPrivilege + """ + symlink_luid = privilege.LUID() + res = privilege.LookupPrivilegeValue( + None, "SeCreateSymbolicLinkPrivilege", symlink_luid + ) + if not res > 0: + raise RuntimeError("Couldn't lookup privilege value") + return symlink_luid def get_privilege_information(): - """ - Get all privileges associated with the current process. - """ - # first call with zero length to determine what size buffer we need + """ + Get all privileges associated with the current process. 
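Several wrappers in this patch (GetAdaptersAddresses, AllocatedTable.__get_table_size, and get_privilege_information below) share the classic Win32 two-call idiom: call once with a null buffer to learn the required size, then call again with a buffer of that size. A distilled sketch, with a hypothetical query callable standing in for any such API:

    import ctypes
    import ctypes.wintypes

    def sized_call(query, sizing_error, ok):
        # query(buffer, length) is a stand-in, not a real API in this patch
        length = ctypes.wintypes.DWORD()
        if query(None, length) != sizing_error:   # first call: size only
            raise RuntimeError('could not determine buffer size')
        buffer = ctypes.create_string_buffer(length.value)
        if query(buffer, length) != ok:           # second call: real data
            raise RuntimeError('could not retrieve data')
        return buffer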
+ """ + # first call with zero length to determine what size buffer we need - return_length = wintypes.DWORD() - params = [ - get_process_token(), - privilege.TOKEN_INFORMATION_CLASS.TokenPrivileges, - None, - 0, - return_length, - ] + return_length = wintypes.DWORD() + params = [ + get_process_token(), + privilege.TOKEN_INFORMATION_CLASS.TokenPrivileges, + None, + 0, + return_length, + ] - res = privilege.GetTokenInformation(*params) + res = privilege.GetTokenInformation(*params) - # assume we now have the necessary length in return_length + # assume we now have the necessary length in return_length - buffer = ctypes.create_string_buffer(return_length.value) - params[2] = buffer - params[3] = return_length.value + buffer = ctypes.create_string_buffer(return_length.value) + params[2] = buffer + params[3] = return_length.value - res = privilege.GetTokenInformation(*params) - assert res > 0, "Error in second GetTokenInformation (%d)" % res + res = privilege.GetTokenInformation(*params) + assert res > 0, "Error in second GetTokenInformation (%d)" % res - privileges = ctypes.cast( - buffer, ctypes.POINTER(privilege.TOKEN_PRIVILEGES)).contents - return privileges + privileges = ctypes.cast( + buffer, ctypes.POINTER(privilege.TOKEN_PRIVILEGES) + ).contents + return privileges def report_privilege_information(): - """ - Report all privilege information assigned to the current process. - """ - privileges = get_privilege_information() - print("found {0} privileges".format(privileges.count)) - tuple(map(print, privileges)) + """ + Report all privilege information assigned to the current process. + """ + privileges = get_privilege_information() + print("found {0} privileges".format(privileges.count)) + tuple(map(print, privileges)) def enable_symlink_privilege(): - """ - Try to assign the symlink privilege to the current process token. - Return True if the assignment is successful. - """ - # create a space in memory for a TOKEN_PRIVILEGES structure - # with one element - size = ctypes.sizeof(privilege.TOKEN_PRIVILEGES) - size += ctypes.sizeof(privilege.LUID_AND_ATTRIBUTES) - buffer = ctypes.create_string_buffer(size) - tp = ctypes.cast(buffer, ctypes.POINTER(privilege.TOKEN_PRIVILEGES)).contents - tp.count = 1 - tp.get_array()[0].enable() - tp.get_array()[0].LUID = get_symlink_luid() - token = get_process_token() - res = privilege.AdjustTokenPrivileges(token, False, tp, 0, None, None) - if res == 0: - raise RuntimeError("Error in AdjustTokenPrivileges") + """ + Try to assign the symlink privilege to the current process token. + Return True if the assignment is successful. 
+ """ + # create a space in memory for a TOKEN_PRIVILEGES structure + # with one element + size = ctypes.sizeof(privilege.TOKEN_PRIVILEGES) + size += ctypes.sizeof(privilege.LUID_AND_ATTRIBUTES) + buffer = ctypes.create_string_buffer(size) + tp = ctypes.cast(buffer, ctypes.POINTER(privilege.TOKEN_PRIVILEGES)).contents + tp.count = 1 + tp.get_array()[0].enable() + tp.get_array()[0].LUID = get_symlink_luid() + token = get_process_token() + res = privilege.AdjustTokenPrivileges(token, False, tp, 0, None, None) + if res == 0: + raise RuntimeError("Error in AdjustTokenPrivileges") - ERROR_NOT_ALL_ASSIGNED = 1300 - return ctypes.windll.kernel32.GetLastError() != ERROR_NOT_ALL_ASSIGNED + ERROR_NOT_ALL_ASSIGNED = 1300 + return ctypes.windll.kernel32.GetLastError() != ERROR_NOT_ALL_ASSIGNED class PolicyHandle(wintypes.HANDLE): - pass + pass class LSA_UNICODE_STRING(ctypes.Structure): - _fields_ = [ - ('length', ctypes.c_ushort), - ('max_length', ctypes.c_ushort), - ('buffer', ctypes.wintypes.LPWSTR), - ] + _fields_ = [ + ('length', ctypes.c_ushort), + ('max_length', ctypes.c_ushort), + ('buffer', ctypes.wintypes.LPWSTR), + ] def OpenPolicy(system_name, object_attributes, access_mask): - policy = PolicyHandle() - raise NotImplementedError( - "Need to construct structures for parameters " - "(see http://msdn.microsoft.com/en-us/library/windows" - "/desktop/aa378299%28v=vs.85%29.aspx)") - res = ctypes.windll.advapi32.LsaOpenPolicy( - system_name, object_attributes, - access_mask, ctypes.byref(policy)) - assert res == 0, "Error status {res}".format(**vars()) - return policy + policy = PolicyHandle() + raise NotImplementedError( + "Need to construct structures for parameters " + "(see http://msdn.microsoft.com/en-us/library/windows" + "/desktop/aa378299%28v=vs.85%29.aspx)" + ) + res = ctypes.windll.advapi32.LsaOpenPolicy( + system_name, object_attributes, access_mask, ctypes.byref(policy) + ) + assert res == 0, "Error status {res}".format(**vars()) + return policy def grant_symlink_privilege(who, machine=''): - """ - Grant the 'create symlink' privilege to who. + """ + Grant the 'create symlink' privilege to who. 
- Based on http://support.microsoft.com/kb/132958 - """ - flags = security.POLICY_CREATE_ACCOUNT | security.POLICY_LOOKUP_NAMES - policy = OpenPolicy(machine, flags) - return policy + Based on http://support.microsoft.com/kb/132958 + """ + flags = security.POLICY_CREATE_ACCOUNT | security.POLICY_LOOKUP_NAMES + policy = OpenPolicy(machine, flags) + return policy def main(): - assigned = enable_symlink_privilege() - msg = ['failure', 'success'][assigned] + assigned = enable_symlink_privilege() + msg = ['failure', 'success'][assigned] - print("Symlink privilege assignment completed with {0}".format(msg)) + print("Symlink privilege assignment completed with {0}".format(msg)) if __name__ == '__main__': - main() + main() diff --git a/libs/win/jaraco/windows/registry.py b/libs/win/jaraco/windows/registry.py index b6f3b239..dd4b6848 100644 --- a/libs/win/jaraco/windows/registry.py +++ b/libs/win/jaraco/windows/registry.py @@ -1,20 +1,18 @@ +import winreg from itertools import count -import six -winreg = six.moves.winreg - def key_values(key): - for index in count(): - try: - yield winreg.EnumValue(key, index) - except WindowsError: - break + for index in count(): + try: + yield winreg.EnumValue(key, index) + except WindowsError: + break def key_subkeys(key): - for index in count(): - try: - yield winreg.EnumKey(key, index) - except WindowsError: - break + for index in count(): + try: + yield winreg.EnumKey(key, index) + except WindowsError: + break diff --git a/libs/win/jaraco/windows/reparse.py b/libs/win/jaraco/windows/reparse.py index 2751e967..f9159381 100644 --- a/libs/win/jaraco/windows/reparse.py +++ b/libs/win/jaraco/windows/reparse.py @@ -1,35 +1,34 @@ -from __future__ import division - import ctypes.wintypes from .error import handle_nonzero_success from .api import filesystem -def DeviceIoControl( - device, io_control_code, in_buffer, out_buffer, overlapped=None): - if overlapped is not None: - raise NotImplementedError("overlapped handles not yet supported") +def DeviceIoControl(device, io_control_code, in_buffer, out_buffer, overlapped=None): + if overlapped is not None: + raise NotImplementedError("overlapped handles not yet supported") - if isinstance(out_buffer, int): - out_buffer = ctypes.create_string_buffer(out_buffer) + if isinstance(out_buffer, int): + out_buffer = ctypes.create_string_buffer(out_buffer) - in_buffer_size = len(in_buffer) if in_buffer is not None else 0 - out_buffer_size = len(out_buffer) - assert isinstance(out_buffer, ctypes.Array) + in_buffer_size = len(in_buffer) if in_buffer is not None else 0 + out_buffer_size = len(out_buffer) + assert isinstance(out_buffer, ctypes.Array) - returned_bytes = ctypes.wintypes.DWORD() + returned_bytes = ctypes.wintypes.DWORD() - res = filesystem.DeviceIoControl( - device, - io_control_code, - in_buffer, in_buffer_size, - out_buffer, out_buffer_size, - returned_bytes, - overlapped, - ) + res = filesystem.DeviceIoControl( + device, + io_control_code, + in_buffer, + in_buffer_size, + out_buffer, + out_buffer_size, + returned_bytes, + overlapped, + ) - handle_nonzero_success(res) - handle_nonzero_success(returned_bytes) + handle_nonzero_success(res) + handle_nonzero_success(returned_bytes) - return out_buffer[:returned_bytes.value] + return out_buffer[: returned_bytes.value] diff --git a/libs/win/jaraco/windows/security.py b/libs/win/jaraco/windows/security.py index 7c481ed6..43582e04 100644 --- a/libs/win/jaraco/windows/security.py +++ b/libs/win/jaraco/windows/security.py @@ -5,63 +5,66 @@ from .api import security def 
GetTokenInformation(token, information_class): - """ - Given a token, get the token information for it. - """ - data_size = ctypes.wintypes.DWORD() - ctypes.windll.advapi32.GetTokenInformation( - token, information_class.num, - 0, 0, ctypes.byref(data_size)) - data = ctypes.create_string_buffer(data_size.value) - handle_nonzero_success(ctypes.windll.advapi32.GetTokenInformation( - token, - information_class.num, - ctypes.byref(data), ctypes.sizeof(data), - ctypes.byref(data_size))) - return ctypes.cast(data, ctypes.POINTER(security.TOKEN_USER)).contents + """ + Given a token, get the token information for it. + """ + data_size = ctypes.wintypes.DWORD() + ctypes.windll.advapi32.GetTokenInformation( + token, information_class.num, 0, 0, ctypes.byref(data_size) + ) + data = ctypes.create_string_buffer(data_size.value) + handle_nonzero_success( + ctypes.windll.advapi32.GetTokenInformation( + token, + information_class.num, + ctypes.byref(data), + ctypes.sizeof(data), + ctypes.byref(data_size), + ) + ) + return ctypes.cast(data, ctypes.POINTER(security.TOKEN_USER)).contents def OpenProcessToken(proc_handle, access): - result = ctypes.wintypes.HANDLE() - proc_handle = ctypes.wintypes.HANDLE(proc_handle) - handle_nonzero_success(ctypes.windll.advapi32.OpenProcessToken( - proc_handle, access, ctypes.byref(result))) - return result + result = ctypes.wintypes.HANDLE() + proc_handle = ctypes.wintypes.HANDLE(proc_handle) + handle_nonzero_success( + ctypes.windll.advapi32.OpenProcessToken( + proc_handle, access, ctypes.byref(result) + ) + ) + return result def get_current_user(): - """ - Return a TOKEN_USER for the owner of this process. - """ - process = OpenProcessToken( - ctypes.windll.kernel32.GetCurrentProcess(), - security.TokenAccess.TOKEN_QUERY, - ) - return GetTokenInformation(process, security.TOKEN_USER) + """ + Return a TOKEN_USER for the owner of this process. + """ + process = OpenProcessToken( + ctypes.windll.kernel32.GetCurrentProcess(), security.TokenAccess.TOKEN_QUERY + ) + return GetTokenInformation(process, security.TOKEN_USER) def get_security_attributes_for_user(user=None): - """ - Return a SECURITY_ATTRIBUTES structure with the SID set to the - specified user (uses current user if none is specified). - """ - if user is None: - user = get_current_user() + """ + Return a SECURITY_ATTRIBUTES structure with the SID set to the + specified user (uses current user if none is specified). 
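One plausible use of get_security_attributes_for_user, tying it to the MemoryMap class earlier in this patch, which accepts a security_attributes argument; the mapping name and size are illustrative:

    from jaraco.windows.mmap import MemoryMap

    sa = get_security_attributes_for_user()   # SA owned by the current user
    with MemoryMap('example-map', 32, security_attributes=sa) as mm:
        mm.write(b'ok')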
+ """ + if user is None: + user = get_current_user() - assert isinstance(user, security.TOKEN_USER), ( - "user must be TOKEN_USER instance") + assert isinstance(user, security.TOKEN_USER), "user must be TOKEN_USER instance" - SD = security.SECURITY_DESCRIPTOR() - SA = security.SECURITY_ATTRIBUTES() - # by attaching the actual security descriptor, it will be garbage- - # collected with the security attributes - SA.descriptor = SD - SA.bInheritHandle = 1 + SD = security.SECURITY_DESCRIPTOR() + SA = security.SECURITY_ATTRIBUTES() + # by attaching the actual security descriptor, it will be garbage- + # collected with the security attributes + SA.descriptor = SD + SA.bInheritHandle = 1 - ctypes.windll.advapi32.InitializeSecurityDescriptor( - ctypes.byref(SD), - security.SECURITY_DESCRIPTOR.REVISION) - ctypes.windll.advapi32.SetSecurityDescriptorOwner( - ctypes.byref(SD), - user.SID, 0) - return SA + ctypes.windll.advapi32.InitializeSecurityDescriptor( + ctypes.byref(SD), security.SECURITY_DESCRIPTOR.REVISION + ) + ctypes.windll.advapi32.SetSecurityDescriptorOwner(ctypes.byref(SD), user.SID, 0) + return SA diff --git a/libs/win/jaraco/windows/services.py b/libs/win/jaraco/windows/services.py index 97cea7ab..b2064dcd 100644 --- a/libs/win/jaraco/windows/services.py +++ b/libs/win/jaraco/windows/services.py @@ -5,8 +5,6 @@ Based on http://code.activestate.com /recipes/115875-controlling-windows-services/ """ -from __future__ import print_function - import sys import time @@ -16,221 +14,240 @@ import win32service class Service(object): - """ - The Service Class is used for controlling Windows - services. Just pass the name of the service you wish to control to the - class instance and go from there. For example, if you want to control - the Workstation service try this: + """ + The Service Class is used for controlling Windows + services. Just pass the name of the service you wish to control to the + class instance and go from there. For example, if you want to control + the Workstation service try this: - from jaraco.windows import services - workstation = services.Service("Workstation") - workstation.start() - workstation.fetchstatus("running", 10) - workstation.stop() - workstation.fetchstatus("stopped") + from jaraco.windows import services + workstation = services.Service("Workstation") + workstation.start() + workstation.fetchstatus("running", 10) + workstation.stop() + workstation.fetchstatus("stopped") - Creating an instance of the Service class is done by passing the name of - the service as it appears in the Management Console or the short name as - it appears in the registry. Mixed case is ok. - cvs = services.Service("CVS NT Service 1.11.1.2 (Build 41)") - or - cvs = services.Service("cvs") + Creating an instance of the Service class is done by passing the name of + the service as it appears in the Management Console or the short name as + it appears in the registry. Mixed case is ok. + cvs = services.Service("CVS NT Service 1.11.1.2 (Build 41)") + or + cvs = services.Service("cvs") - If needing remote service control try this: - cvs = services.Service("cvs", r"\\CVS_SERVER") - or - cvs = services.Service("cvs", "\\\\CVS_SERVER") + If needing remote service control try this: + cvs = services.Service("cvs", r"\\CVS_SERVER") + or + cvs = services.Service("cvs", "\\\\CVS_SERVER") - The Service Class supports these methods: + The Service Class supports these methods: - start: Starts service. - stop: Stops service. - restart: Stops and restarts service. 
- pause: Pauses service (Only if service supports feature). - resume: Resumes service that has been paused. - status: Queries current status of service. - fetchstatus: Continually queries service until requested - status(STARTING, RUNNING, - STOPPING & STOPPED) is met or timeout value(in seconds) reached. - Default timeout value is infinite. - infotype: Queries service for process type. (Single, shared and/or - interactive process) - infoctrl: Queries control information about a running service. - i.e. Can it be paused, stopped, etc? - infostartup: Queries service Startup type. (Boot, System, - Automatic, Manual, Disabled) - setstartup Changes/sets Startup type. (Boot, System, - Automatic, Manual, Disabled) - getname: Gets the long and short service names used by Windowin32service. - (Generally used for internal purposes) - """ + pause: Pauses service (Only if service supports feature). + resume: Resumes service that has been paused. + status: Queries current status of service. + fetchstatus: Continually queries service until requested + status(STARTING, RUNNING, + STOPPING & STOPPED) is met or timeout value(in seconds) reached. + Default timeout value is infinite. + infotype: Queries service for process type. (Single, shared and/or + interactive process) + infoctrl: Queries control information about a running service. + i.e. Can it be paused, stopped, etc? + infostartup: Queries service Startup type. (Boot, System, + Automatic, Manual, Disabled) + setstartup: Changes/sets Startup type. (Boot, System, + Automatic, Manual, Disabled) + getname: Gets the long and short service names used by Windows. + (Generally used for internal purposes) + """ - def __init__(self, service, machinename=None, dbname=None): - self.userv = service - self.scmhandle = win32service.OpenSCManager( - machinename, dbname, win32service.SC_MANAGER_ALL_ACCESS) - self.sserv, self.lserv = self.getname() - if (self.sserv or self.lserv) is None: - sys.exit() - self.handle = win32service.OpenService( - self.scmhandle, self.sserv, win32service.SERVICE_ALL_ACCESS) - self.sccss = "SYSTEM\\CurrentControlSet\\Services\\" + def __init__(self, service, machinename=None, dbname=None): + self.userv = service + self.scmhandle = win32service.OpenSCManager( + machinename, dbname, win32service.SC_MANAGER_ALL_ACCESS + ) + self.sserv, self.lserv = self.getname() + if (self.sserv or self.lserv) is None: + sys.exit() + self.handle = win32service.OpenService( + self.scmhandle, self.sserv, win32service.SERVICE_ALL_ACCESS + ) + self.sccss = "SYSTEM\\CurrentControlSet\\Services\\" - def start(self): - win32service.StartService(self.handle, None) + def start(self): + win32service.StartService(self.handle, None) - def stop(self): - self.stat = win32service.ControlService( - self.handle, win32service.SERVICE_CONTROL_STOP) + def stop(self): + self.stat = win32service.ControlService( + self.handle, win32service.SERVICE_CONTROL_STOP + ) - def restart(self): - self.stop() - self.fetchstatus("STOPPED") - self.start() + def restart(self): + self.stop() + self.fetchstatus("STOPPED") + self.start() - def pause(self): - self.stat = win32service.ControlService( - self.handle, win32service.SERVICE_CONTROL_PAUSE) + def pause(self): + self.stat = win32service.ControlService( + self.handle, win32service.SERVICE_CONTROL_PAUSE + ) - def resume(self): - self.stat = win32service.ControlService( - self.handle, win32service.SERVICE_CONTROL_CONTINUE) + def resume(self): +
self.stat = win32service.ControlService( + self.handle, win32service.SERVICE_CONTROL_CONTINUE + ) - def status(self, prn=0): - self.stat = win32service.QueryServiceStatus(self.handle) - if self.stat[1] == win32service.SERVICE_STOPPED: - if prn == 1: - print("The %s service is stopped." % self.lserv) - else: - return "STOPPED" - elif self.stat[1] == win32service.SERVICE_START_PENDING: - if prn == 1: - print("The %s service is starting." % self.lserv) - else: - return "STARTING" - elif self.stat[1] == win32service.SERVICE_STOP_PENDING: - if prn == 1: - print("The %s service is stopping." % self.lserv) - else: - return "STOPPING" - elif self.stat[1] == win32service.SERVICE_RUNNING: - if prn == 1: - print("The %s service is running." % self.lserv) - else: - return "RUNNING" + def status(self, prn=0): + self.stat = win32service.QueryServiceStatus(self.handle) + if self.stat[1] == win32service.SERVICE_STOPPED: + if prn == 1: + print("The %s service is stopped." % self.lserv) + else: + return "STOPPED" + elif self.stat[1] == win32service.SERVICE_START_PENDING: + if prn == 1: + print("The %s service is starting." % self.lserv) + else: + return "STARTING" + elif self.stat[1] == win32service.SERVICE_STOP_PENDING: + if prn == 1: + print("The %s service is stopping." % self.lserv) + else: + return "STOPPING" + elif self.stat[1] == win32service.SERVICE_RUNNING: + if prn == 1: + print("The %s service is running." % self.lserv) + else: + return "RUNNING" - def fetchstatus(self, fstatus, timeout=None): - self.fstatus = fstatus.upper() - if timeout is not None: - timeout = int(timeout) - timeout *= 2 + def fetchstatus(self, fstatus, timeout=None): + self.fstatus = fstatus.upper() + if timeout is not None: + timeout = int(timeout) + timeout *= 2 - def to(timeout): - time.sleep(.5) - if timeout is not None: - if timeout > 1: - timeout -= 1 - return timeout - else: - return "TO" - if self.fstatus == "STOPPED": - while 1: - self.stat = win32service.QueryServiceStatus(self.handle) - if self.stat[1] == win32service.SERVICE_STOPPED: - self.fstate = "STOPPED" - break - else: - timeout = to(timeout) - if timeout == "TO": - return "TIMEDOUT" - break - elif self.fstatus == "STOPPING": - while 1: - self.stat = win32service.QueryServiceStatus(self.handle) - if self.stat[1]==win32service.SERVICE_STOP_PENDING: - self.fstate = "STOPPING" - break - else: - timeout=to(timeout) - if timeout == "TO": - return "TIMEDOUT" - break - elif self.fstatus == "RUNNING": - while 1: - self.stat = win32service.QueryServiceStatus(self.handle) - if self.stat[1]==win32service.SERVICE_RUNNING: - self.fstate = "RUNNING" - break - else: - timeout=to(timeout) - if timeout == "TO": - return "TIMEDOUT" - break - elif self.fstatus == "STARTING": - while 1: - self.stat = win32service.QueryServiceStatus(self.handle) - if self.stat[1]==win32service.SERVICE_START_PENDING: - self.fstate = "STARTING" - break - else: - timeout=to(timeout) - if timeout == "TO": - return "TIMEDOUT" - break + def to(timeout): + time.sleep(0.5) + if timeout is not None: + if timeout > 1: + timeout -= 1 + return timeout + else: + return "TO" - def infotype(self): - self.stat = win32service.QueryServiceStatus(self.handle) - if self.stat[0] and win32service.SERVICE_WIN32_OWN_PROCESS: - print("The %s service runs in its own process." % self.lserv) - if self.stat[0] and win32service.SERVICE_WIN32_SHARE_PROCESS: - print("The %s service shares a process with other services." 
% self.lserv) - if self.stat[0] and win32service.SERVICE_INTERACTIVE_PROCESS: - print("The %s service can interact with the desktop." % self.lserv) + if self.fstatus == "STOPPED": + while 1: + self.stat = win32service.QueryServiceStatus(self.handle) + if self.stat[1] == win32service.SERVICE_STOPPED: + self.fstate = "STOPPED" + break + else: + timeout = to(timeout) + if timeout == "TO": + return "TIMEDOUT" + break + elif self.fstatus == "STOPPING": + while 1: + self.stat = win32service.QueryServiceStatus(self.handle) + if self.stat[1] == win32service.SERVICE_STOP_PENDING: + self.fstate = "STOPPING" + break + else: + timeout = to(timeout) + if timeout == "TO": + return "TIMEDOUT" + break + elif self.fstatus == "RUNNING": + while 1: + self.stat = win32service.QueryServiceStatus(self.handle) + if self.stat[1] == win32service.SERVICE_RUNNING: + self.fstate = "RUNNING" + break + else: + timeout = to(timeout) + if timeout == "TO": + return "TIMEDOUT" + break + elif self.fstatus == "STARTING": + while 1: + self.stat = win32service.QueryServiceStatus(self.handle) + if self.stat[1] == win32service.SERVICE_START_PENDING: + self.fstate = "STARTING" + break + else: + timeout = to(timeout) + if timeout == "TO": + return "TIMEDOUT" + break - def infoctrl(self): - self.stat = win32service.QueryServiceStatus(self.handle) - if self.stat[2] and win32service.SERVICE_ACCEPT_PAUSE_CONTINUE: - print("The %s service can be paused." % self.lserv) - if self.stat[2] and win32service.SERVICE_ACCEPT_STOP: - print("The %s service can be stopped." % self.lserv) - if self.stat[2] and win32service.SERVICE_ACCEPT_SHUTDOWN: - print("The %s service can be shutdown." % self.lserv) + def infotype(self): + self.stat = win32service.QueryServiceStatus(self.handle) + if self.stat[0] and win32service.SERVICE_WIN32_OWN_PROCESS: + print("The %s service runs in its own process." % self.lserv) + if self.stat[0] and win32service.SERVICE_WIN32_SHARE_PROCESS: + print("The %s service shares a process with other services." % self.lserv) + if self.stat[0] and win32service.SERVICE_INTERACTIVE_PROCESS: + print("The %s service can interact with the desktop." % self.lserv) - def infostartup(self): - self.isuphandle = win32api.RegOpenKeyEx(win32con.HKEY_LOCAL_MACHINE, self.sccss + self.sserv, 0, win32con.KEY_READ) - self.isuptype = win32api.RegQueryValueEx(self.isuphandle, "Start")[0] - win32api.RegCloseKey(self.isuphandle) - if self.isuptype == 0: - return "boot" - elif self.isuptype == 1: - return "system" - elif self.isuptype == 2: - return "automatic" - elif self.isuptype == 3: - return "manual" - elif self.isuptype == 4: - return "disabled" + def infoctrl(self): + self.stat = win32service.QueryServiceStatus(self.handle) + if self.stat[2] and win32service.SERVICE_ACCEPT_PAUSE_CONTINUE: + print("The %s service can be paused." % self.lserv) + if self.stat[2] and win32service.SERVICE_ACCEPT_STOP: + print("The %s service can be stopped." % self.lserv) + if self.stat[2] and win32service.SERVICE_ACCEPT_SHUTDOWN: + print("The %s service can be shutdown." 
% self.lserv) - @property - def suptype(self): - types = 'boot', 'system', 'automatic', 'manual', 'disabled' - lookup = dict((name, number) for number, name in enumerate(types)) - return lookup[self.startuptype] + def infostartup(self): + self.isuphandle = win32api.RegOpenKeyEx( + win32con.HKEY_LOCAL_MACHINE, self.sccss + self.sserv, 0, win32con.KEY_READ + ) + self.isuptype = win32api.RegQueryValueEx(self.isuphandle, "Start")[0] + win32api.RegCloseKey(self.isuphandle) + if self.isuptype == 0: + return "boot" + elif self.isuptype == 1: + return "system" + elif self.isuptype == 2: + return "automatic" + elif self.isuptype == 3: + return "manual" + elif self.isuptype == 4: + return "disabled" - def setstartup(self, startuptype): - self.startuptype = startuptype.lower() - self.snc = win32service.SERVICE_NO_CHANGE - win32service.ChangeServiceConfig(self.handle, self.snc, self.suptype, - self.snc, None, None, 0, None, None, None, self.lserv) + @property + def suptype(self): + types = 'boot', 'system', 'automatic', 'manual', 'disabled' + lookup = dict((name, number) for number, name in enumerate(types)) + return lookup[self.startuptype] - def getname(self): - self.snames=win32service.EnumServicesStatus(self.scmhandle) - for i in self.snames: - if i[0].lower() == self.userv.lower(): - return i[0], i[1] - break - if i[1].lower() == self.userv.lower(): - return i[0], i[1] - break - print("Error: The %s service doesn't seem to exist." % self.userv) - return None, None + def setstartup(self, startuptype): + self.startuptype = startuptype.lower() + self.snc = win32service.SERVICE_NO_CHANGE + win32service.ChangeServiceConfig( + self.handle, + self.snc, + self.suptype, + self.snc, + None, + None, + 0, + None, + None, + None, + self.lserv, + ) + + def getname(self): + self.snames = win32service.EnumServicesStatus(self.scmhandle) + for i in self.snames: + if i[0].lower() == self.userv.lower(): + return i[0], i[1] + break + if i[1].lower() == self.userv.lower(): + return i[0], i[1] + break + print("Error: The %s service doesn't seem to exist." % self.userv) + return None, None diff --git a/libs/win/jaraco/windows/shell.py b/libs/win/jaraco/windows/shell.py index 58333359..2805cbe4 100644 --- a/libs/win/jaraco/windows/shell.py +++ b/libs/win/jaraco/windows/shell.py @@ -2,13 +2,13 @@ from .api import shell def get_recycle_bin_confirm(): - settings = shell.SHELLSTATE() - shell.SHGetSetSettings(settings, shell.SSF_NOCONFIRMRECYCLE, False) - return not settings.no_confirm_recycle + settings = shell.SHELLSTATE() + shell.SHGetSetSettings(settings, shell.SSF_NOCONFIRMRECYCLE, False) + return not settings.no_confirm_recycle def set_recycle_bin_confirm(confirm=False): - settings = shell.SHELLSTATE() - settings.no_confirm_recycle = not confirm - shell.SHGetSetSettings(settings, shell.SSF_NOCONFIRMRECYCLE, True) - # cross fingers and hope it worked + settings = shell.SHELLSTATE() + settings.no_confirm_recycle = not confirm + shell.SHGetSetSettings(settings, shell.SSF_NOCONFIRMRECYCLE, True) + # cross fingers and hope it worked diff --git a/libs/win/jaraco/windows/timers.py b/libs/win/jaraco/windows/timers.py index 626118a9..caf4a58e 100644 --- a/libs/win/jaraco/windows/timers.py +++ b/libs/win/jaraco/windows/timers.py @@ -1,71 +1,66 @@ -# -*- coding: UTF-8 -*- - """ timers - In particular, contains a waitable timer. + In particular, contains a waitable timer. 
""" -from __future__ import absolute_import - import time -from six.moves import _thread +import _thread from jaraco.windows.api import event as win32event -__author__ = 'Jason R. Coombs ' - class WaitableTimer: - """ - t = WaitableTimer() - t.set(None, 10) # every 10 seconds - t.wait_for_signal() # 10 seconds elapses - t.stop() - t.wait_for_signal(20) # 20 seconds elapses (timeout occurred) - """ - def __init__(self): - self.signal_event = win32event.CreateEvent(None, 0, 0, None) - self.stop_event = win32event.CreateEvent(None, 0, 0, None) + """ + t = WaitableTimer() + t.set(None, 10) # every 10 seconds + t.wait_for_signal() # 10 seconds elapses + t.stop() + t.wait_for_signal(20) # 20 seconds elapses (timeout occurred) + """ - def set(self, due_time, period): - _thread.start_new_thread(self._signal_loop, (due_time, period)) + def __init__(self): + self.signal_event = win32event.CreateEvent(None, 0, 0, None) + self.stop_event = win32event.CreateEvent(None, 0, 0, None) - def stop(self): - win32event.SetEvent(self.stop_event) + def set(self, due_time, period): + _thread.start_new_thread(self._signal_loop, (due_time, period)) - def wait_for_signal(self, timeout=None): - """ - wait for the signal; return after the signal has occurred or the - timeout in seconds elapses. - """ - timeout_ms = int(timeout * 1000) if timeout else win32event.INFINITE - win32event.WaitForSingleObject(self.signal_event, timeout_ms) + def stop(self): + win32event.SetEvent(self.stop_event) - def _signal_loop(self, due_time, period): - if not due_time and not period: - raise ValueError("due_time or period must be non-zero") - try: - if not due_time: - due_time = time.time() + period - if due_time: - self._wait(due_time - time.time()) - while period: - due_time += period - self._wait(due_time - time.time()) - except Exception: - pass + def wait_for_signal(self, timeout=None): + """ + wait for the signal; return after the signal has occurred or the + timeout in seconds elapses. 
+ """ + timeout_ms = int(timeout * 1000) if timeout else win32event.INFINITE + win32event.WaitForSingleObject(self.signal_event, timeout_ms) - def _wait(self, seconds): - milliseconds = int(seconds * 1000) - if milliseconds > 0: - res = win32event.WaitForSingleObject(self.stop_event, milliseconds) - if res == win32event.WAIT_OBJECT_0: - raise Exception - if res == win32event.WAIT_TIMEOUT: - pass - win32event.SetEvent(self.signal_event) + def _signal_loop(self, due_time, period): + if not due_time and not period: + raise ValueError("due_time or period must be non-zero") + try: + if not due_time: + due_time = time.time() + period + if due_time: + self._wait(due_time - time.time()) + while period: + due_time += period + self._wait(due_time - time.time()) + except Exception: + pass - @staticmethod - def get_even_due_time(period): - now = time.time() - return now - (now % period) + def _wait(self, seconds): + milliseconds = int(seconds * 1000) + if milliseconds > 0: + res = win32event.WaitForSingleObject(self.stop_event, milliseconds) + if res == win32event.WAIT_OBJECT_0: + raise Exception + if res == win32event.WAIT_TIMEOUT: + pass + win32event.SetEvent(self.signal_event) + + @staticmethod + def get_even_due_time(period): + now = time.time() + return now - (now % period) diff --git a/libs/win/jaraco/windows/timezone.py b/libs/win/jaraco/windows/timezone.py index 7eedcf0b..fdefc931 100644 --- a/libs/win/jaraco/windows/timezone.py +++ b/libs/win/jaraco/windows/timezone.py @@ -10,245 +10,253 @@ from jaraco.collections import RangeMap class AnyDict(object): - "A dictionary that returns the same value regardless of key" + "A dictionary that returns the same value regardless of key" - def __init__(self, value): - self.value = value + def __init__(self, value): + self.value = value - def __getitem__(self, key): - return self.value + def __getitem__(self, key): + return self.value class SYSTEMTIME(Extended, ctypes.Structure): - _fields_ = [ - ('year', WORD), - ('month', WORD), - ('day_of_week', WORD), - ('day', WORD), - ('hour', WORD), - ('minute', WORD), - ('second', WORD), - ('millisecond', WORD), - ] + _fields_ = [ + ('year', WORD), + ('month', WORD), + ('day_of_week', WORD), + ('day', WORD), + ('hour', WORD), + ('minute', WORD), + ('second', WORD), + ('millisecond', WORD), + ] class REG_TZI_FORMAT(Extended, ctypes.Structure): - _fields_ = [ - ('bias', LONG), - ('standard_bias', LONG), - ('daylight_bias', LONG), - ('standard_start', SYSTEMTIME), - ('daylight_start', SYSTEMTIME), - ] + _fields_ = [ + ('bias', LONG), + ('standard_bias', LONG), + ('daylight_bias', LONG), + ('standard_start', SYSTEMTIME), + ('daylight_start', SYSTEMTIME), + ] class TIME_ZONE_INFORMATION(Extended, ctypes.Structure): - _fields_ = [ - ('bias', LONG), - ('standard_name', WCHAR * 32), - ('standard_start', SYSTEMTIME), - ('standard_bias', LONG), - ('daylight_name', WCHAR * 32), - ('daylight_start', SYSTEMTIME), - ('daylight_bias', LONG), - ] + _fields_ = [ + ('bias', LONG), + ('standard_name', WCHAR * 32), + ('standard_start', SYSTEMTIME), + ('standard_bias', LONG), + ('daylight_name', WCHAR * 32), + ('daylight_start', SYSTEMTIME), + ('daylight_bias', LONG), + ] class DYNAMIC_TIME_ZONE_INFORMATION(TIME_ZONE_INFORMATION): - """ - Because the structure of the DYNAMIC_TIME_ZONE_INFORMATION extends - the structure of the TIME_ZONE_INFORMATION, this structure - can be used as a drop-in replacement for calls where the - structure is passed by reference. 
+ """ + Because the structure of the DYNAMIC_TIME_ZONE_INFORMATION extends + the structure of the TIME_ZONE_INFORMATION, this structure + can be used as a drop-in replacement for calls where the + structure is passed by reference. - For example, - dynamic_tzi = DYNAMIC_TIME_ZONE_INFORMATION() - ctypes.windll.kernel32.GetTimeZoneInformation(ctypes.byref(dynamic_tzi)) + For example, + dynamic_tzi = DYNAMIC_TIME_ZONE_INFORMATION() + ctypes.windll.kernel32.GetTimeZoneInformation(ctypes.byref(dynamic_tzi)) - (although the key_name and dynamic_daylight_time_disabled flags will be - set to the default (null)). + (although the key_name and dynamic_daylight_time_disabled flags will be + set to the default (null)). - >>> isinstance(DYNAMIC_TIME_ZONE_INFORMATION(), TIME_ZONE_INFORMATION) - True + >>> isinstance(DYNAMIC_TIME_ZONE_INFORMATION(), TIME_ZONE_INFORMATION) + True - """ - _fields_ = [ - # ctypes automatically includes the fields from the parent - ('key_name', WCHAR * 128), - ('dynamic_daylight_time_disabled', BOOL), - ] + """ - def __init__(self, *args, **kwargs): - """Allow initialization from args from both this class and - its superclass. Default ctypes implementation seems to - assume that this class is only initialized with its own - _fields_ (for non-keyword-args).""" - super_self = super(DYNAMIC_TIME_ZONE_INFORMATION, self) - super_fields = super_self._fields_ - super_args = args[:len(super_fields)] - self_args = args[len(super_fields):] - # convert the super args to keyword args so they're also handled - for field, arg in zip(super_fields, super_args): - field_name, spec = field - kwargs[field_name] = arg - super(DYNAMIC_TIME_ZONE_INFORMATION, self).__init__(*self_args, **kwargs) + _fields_ = [ + # ctypes automatically includes the fields from the parent + ('key_name', WCHAR * 128), + ('dynamic_daylight_time_disabled', BOOL), + ] + + def __init__(self, *args, **kwargs): + """Allow initialization from args from both this class and + its superclass. Default ctypes implementation seems to + assume that this class is only initialized with its own + _fields_ (for non-keyword-args).""" + super_self = super(DYNAMIC_TIME_ZONE_INFORMATION, self) + super_fields = super_self._fields_ + super_args = args[: len(super_fields)] + self_args = args[len(super_fields) :] + # convert the super args to keyword args so they're also handled + for field, arg in zip(super_fields, super_args): + field_name, spec = field + kwargs[field_name] = arg + super(DYNAMIC_TIME_ZONE_INFORMATION, self).__init__(*self_args, **kwargs) class Info(DYNAMIC_TIME_ZONE_INFORMATION): - """ - A time zone definition class based on the win32 - DYNAMIC_TIME_ZONE_INFORMATION structure. + """ + A time zone definition class based on the win32 + DYNAMIC_TIME_ZONE_INFORMATION structure. - Describes a bias against UTC (bias), and two dates at which a separate - additional bias applies (standard_bias and daylight_bias). - """ + Describes a bias against UTC (bias), and two dates at which a separate + additional bias applies (standard_bias and daylight_bias). 
+ """ - def field_names(self): - return map(operator.itemgetter(0), self._fields_) + def field_names(self): + return map(operator.itemgetter(0), self._fields_) - def __init__(self, *args, **kwargs): - """ - Try to construct a timezone.Info from - a) [DYNAMIC_]TIME_ZONE_INFORMATION args - b) another Info - c) a REG_TZI_FORMAT - d) a byte structure - """ - funcs = ( - super(Info, self).__init__, - self.__init_from_other, - self.__init_from_reg_tzi, - self.__init_from_bytes, - ) - for func in funcs: - try: - func(*args, **kwargs) - return - except TypeError: - pass - raise TypeError("Invalid arguments for %s" % self.__class__) + def __init__(self, *args, **kwargs): + """ + Try to construct a timezone.Info from + a) [DYNAMIC_]TIME_ZONE_INFORMATION args + b) another Info + c) a REG_TZI_FORMAT + d) a byte structure + """ + funcs = ( + super(Info, self).__init__, + self.__init_from_other, + self.__init_from_reg_tzi, + self.__init_from_bytes, + ) + for func in funcs: + try: + func(*args, **kwargs) + return + except TypeError: + pass + raise TypeError("Invalid arguments for %s" % self.__class__) - def __init_from_bytes(self, bytes, **kwargs): - reg_tzi = REG_TZI_FORMAT() - # todo: use buffer API in Python 3 - buffer = memoryview(bytes) - ctypes.memmove(ctypes.addressof(reg_tzi), buffer, len(buffer)) - self.__init_from_reg_tzi(self, reg_tzi, **kwargs) + def __init_from_bytes(self, bytes, **kwargs): + reg_tzi = REG_TZI_FORMAT() + # todo: use buffer API in Python 3 + buffer = memoryview(bytes) + ctypes.memmove(ctypes.addressof(reg_tzi), buffer, len(buffer)) + self.__init_from_reg_tzi(self, reg_tzi, **kwargs) - def __init_from_reg_tzi(self, reg_tzi, **kwargs): - if not isinstance(reg_tzi, REG_TZI_FORMAT): - raise TypeError("Not a REG_TZI_FORMAT") - for field_name, type in reg_tzi._fields_: - setattr(self, field_name, getattr(reg_tzi, field_name)) - for name, value in kwargs.items(): - setattr(self, name, value) + def __init_from_reg_tzi(self, reg_tzi, **kwargs): + if not isinstance(reg_tzi, REG_TZI_FORMAT): + raise TypeError("Not a REG_TZI_FORMAT") + for field_name, type in reg_tzi._fields_: + setattr(self, field_name, getattr(reg_tzi, field_name)) + for name, value in kwargs.items(): + setattr(self, name, value) - def __init_from_other(self, other): - if not isinstance(other, TIME_ZONE_INFORMATION): - raise TypeError("Not a TIME_ZONE_INFORMATION") - for name in other.field_names(): - # explicitly get the value from the underlying structure - value = super(Info, other).__getattribute__(other, name) - setattr(self, name, value) - # consider instead of the loop above just copying the memory directly - # size = max(ctypes.sizeof(DYNAMIC_TIME_ZONE_INFO), ctypes.sizeof(other)) - # ctypes.memmove(ctypes.addressof(self), other, size) + def __init_from_other(self, other): + if not isinstance(other, TIME_ZONE_INFORMATION): + raise TypeError("Not a TIME_ZONE_INFORMATION") + for name in other.field_names(): + # explicitly get the value from the underlying structure + value = super(Info, other).__getattribute__(other, name) + setattr(self, name, value) + # consider instead of the loop above just copying the memory directly + # size = max(ctypes.sizeof(DYNAMIC_TIME_ZONE_INFO), ctypes.sizeof(other)) + # ctypes.memmove(ctypes.addressof(self), other, size) - def __getattribute__(self, attr): - value = super(Info, self).__getattribute__(attr) + def __getattribute__(self, attr): + value = super(Info, self).__getattribute__(attr) - def make_minute_timedelta(m): - datetime.timedelta(minutes=m) - if 'bias' in attr: - value 
= make_minute_timedelta(value) - return value + def make_minute_timedelta(m): + datetime.timedelta(minutes=m) - @classmethod - def current(class_): - "Windows Platform SDK GetTimeZoneInformation" - tzi = class_() - kernel32 = ctypes.windll.kernel32 - getter = kernel32.GetTimeZoneInformation - getter = getattr(kernel32, 'GetDynamicTimeZoneInformation', getter) - code = getter(ctypes.byref(tzi)) - return code, tzi + if 'bias' in attr: + value = make_minute_timedelta(value) + return value - def set(self): - kernel32 = ctypes.windll.kernel32 - setter = kernel32.SetTimeZoneInformation - setter = getattr(kernel32, 'SetDynamicTimeZoneInformation', setter) - return setter(ctypes.byref(self)) + @classmethod + def current(class_): + "Windows Platform SDK GetTimeZoneInformation" + tzi = class_() + kernel32 = ctypes.windll.kernel32 + getter = kernel32.GetTimeZoneInformation + getter = getattr(kernel32, 'GetDynamicTimeZoneInformation', getter) + code = getter(ctypes.byref(tzi)) + return code, tzi - def copy(self): - return self.__class__(self) + def set(self): + kernel32 = ctypes.windll.kernel32 + setter = kernel32.SetTimeZoneInformation + setter = getattr(kernel32, 'SetDynamicTimeZoneInformation', setter) + return setter(ctypes.byref(self)) - def locate_daylight_start(self, year): - info = self.get_info_for_year(year) - return self._locate_day(year, info.daylight_start) + def copy(self): + return self.__class__(self) - def locate_standard_start(self, year): - info = self.get_info_for_year(year) - return self._locate_day(year, info.standard_start) + def locate_daylight_start(self, year): + info = self.get_info_for_year(year) + return self._locate_day(year, info.daylight_start) - def get_info_for_year(self, year): - return self.dynamic_info[year] + def locate_standard_start(self, year): + info = self.get_info_for_year(year) + return self._locate_day(year, info.standard_start) - @property - def dynamic_info(self): - "Return a map that for a given year will return the correct Info" - if self.key_name: - dyn_key = self.get_key().subkey('Dynamic DST') - del dyn_key['FirstEntry'] - del dyn_key['LastEntry'] - years = map(int, dyn_key.keys()) - values = map(Info, dyn_key.values()) - # create a range mapping that searches by descending year and matches - # if the target year is greater or equal. - return RangeMap(zip(years, values), RangeMap.descending, operator.ge) - else: - return AnyDict(self) + def get_info_for_year(self, year): + return self.dynamic_info[year] - @staticmethod - def _locate_day(year, cutoff): - """ - Takes a SYSTEMTIME object, such as retrieved from a TIME_ZONE_INFORMATION - structure or call to GetTimeZoneInformation and interprets - it based on the given - year to identify the actual day. + @property + def dynamic_info(self): + "Return a map that for a given year will return the correct Info" + if self.key_name: + dyn_key = self.get_key().subkey('Dynamic DST') + del dyn_key['FirstEntry'] + del dyn_key['LastEntry'] + years = map(int, dyn_key.keys()) + values = map(Info, dyn_key.values()) + # create a range mapping that searches by descending year and matches + # if the target year is greater or equal. + return RangeMap(zip(years, values), RangeMap.descending, operator.ge) + else: + return AnyDict(self) - This method is necessary because the SYSTEMTIME structure - refers to a day by its - day of the week and week of the month (e.g. 4th saturday in March). 
+ @staticmethod + def _locate_day(year, cutoff): + """ + Takes a SYSTEMTIME object, such as retrieved from a TIME_ZONE_INFORMATION + structure or call to GetTimeZoneInformation and interprets + it based on the given + year to identify the actual day. - >>> SATURDAY = 6 - >>> MARCH = 3 - >>> st = SYSTEMTIME(2000, MARCH, SATURDAY, 4, 0, 0, 0, 0) + This method is necessary because the SYSTEMTIME structure + refers to a day by its + day of the week and week of the month (e.g. 4th Saturday in March). - # according to my calendar, the 4th Saturday in March in 2009 was the 28th - >>> expected_date = datetime.datetime(2009, 3, 28) - >>> Info._locate_day(2009, st) == expected_date - True - """ - # MS stores Sunday as 0, Python datetime stores Monday as zero - target_weekday = (cutoff.day_of_week + 6) % 7 - # For SYSTEMTIMEs relating to time zone information, cutoff.day - # is the week of the month - week_of_month = cutoff.day - # so the following is the first day of that week - day = (week_of_month - 1) * 7 + 1 - result = datetime.datetime( - year, cutoff.month, day, - cutoff.hour, cutoff.minute, cutoff.second, cutoff.millisecond) - # now the result is the correct week, but not necessarily - # the correct day of the week - days_to_go = (target_weekday - result.weekday()) % 7 - result += datetime.timedelta(days_to_go) - # if we selected a day in the month following the target month, - # move back a week or two. - # This is necessary because Microsoft defines the fifth week in a month - # to be the last week in a month and adding the time delta might have - # pushed the result into the next month. - while result.month == cutoff.month + 1: - result -= datetime.timedelta(weeks=1) - return result + >>> SATURDAY = 6 + >>> MARCH = 3 + >>> st = SYSTEMTIME(2000, MARCH, SATURDAY, 4, 0, 0, 0, 0) + + # according to my calendar, the 4th Saturday in March in 2009 was the 28th + >>> expected_date = datetime.datetime(2009, 3, 28) + >>> Info._locate_day(2009, st) == expected_date + True + """ + # MS stores Sunday as 0, Python datetime stores Monday as zero + target_weekday = (cutoff.day_of_week + 6) % 7 + # For SYSTEMTIMEs relating to time zone information, cutoff.day + # is the week of the month + week_of_month = cutoff.day + # so the following is the first day of that week + day = (week_of_month - 1) * 7 + 1 + result = datetime.datetime( + year, + cutoff.month, + day, + cutoff.hour, + cutoff.minute, + cutoff.second, + cutoff.millisecond, + ) + # now the result is the correct week, but not necessarily + # the correct day of the week + days_to_go = (target_weekday - result.weekday()) % 7 + result += datetime.timedelta(days_to_go) + # if we selected a day in the month following the target month, + # move back a week or two. + # This is necessary because Microsoft defines the fifth week in a month + # to be the last week in a month and adding the time delta might have + # pushed the result into the next month.
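+ # Illustrative example (consistent with the doctest above): the "5th + # Saturday of March" for 2009 first resolves to April 4 (March 29, a + # Sunday, plus six days), which overshoots March, so the loop below + # steps back one week to March 28, the last Saturday of the month.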
+ while result.month == cutoff.month + 1: + result -= datetime.timedelta(weeks=1) + return result diff --git a/libs/win/jaraco/windows/ui.py b/libs/win/jaraco/windows/ui.py index 20f948f3..7249331e 100644 --- a/libs/win/jaraco/windows/ui.py +++ b/libs/win/jaraco/windows/ui.py @@ -5,5 +5,5 @@ from jaraco.windows.util import ensure_unicode def MessageBox(text, caption=None, handle=None, type=None): - text, caption = map(ensure_unicode, (text, caption)) - ctypes.windll.user32.MessageBoxW(handle, text, caption, type) + text, caption = map(ensure_unicode, (text, caption)) + ctypes.windll.user32.MessageBoxW(handle, text, caption, type) diff --git a/libs/win/jaraco/windows/user.py b/libs/win/jaraco/windows/user.py index 9b574777..503233b7 100644 --- a/libs/win/jaraco/windows/user.py +++ b/libs/win/jaraco/windows/user.py @@ -5,12 +5,12 @@ from .error import WindowsError, handle_nonzero_success def get_user_name(): - size = ctypes.wintypes.DWORD() - try: - handle_nonzero_success(GetUserName(None, size)) - except WindowsError as e: - if e.code != errors.ERROR_INSUFFICIENT_BUFFER: - raise - buffer = ctypes.create_unicode_buffer(size.value) - handle_nonzero_success(GetUserName(buffer, size)) - return buffer.value + size = ctypes.wintypes.DWORD() + try: + handle_nonzero_success(GetUserName(None, size)) + except WindowsError as e: + if e.code != errors.ERROR_INSUFFICIENT_BUFFER: + raise + buffer = ctypes.create_unicode_buffer(size.value) + handle_nonzero_success(GetUserName(buffer, size)) + return buffer.value diff --git a/libs/win/jaraco/windows/util.py b/libs/win/jaraco/windows/util.py index 5524df85..c51ff997 100644 --- a/libs/win/jaraco/windows/util.py +++ b/libs/win/jaraco/windows/util.py @@ -4,17 +4,18 @@ import ctypes def ensure_unicode(param): - try: - param = ctypes.create_unicode_buffer(param) - except TypeError: - pass # just return the param as is - return param + try: + param = ctypes.create_unicode_buffer(param) + except TypeError: + pass # just return the param as is + return param class Extended(object): - "Used to add extended capability to structures" - def __eq__(self, other): - return memoryview(self) == memoryview(other) + "Used to add extended capability to structures" - def __ne__(self, other): - return memoryview(self) != memoryview(other) + def __eq__(self, other): + return memoryview(self) == memoryview(other) + + def __ne__(self, other): + return memoryview(self) != memoryview(other) diff --git a/libs/win/jaraco/windows/vpn.py b/libs/win/jaraco/windows/vpn.py index 9cf31dc1..df9f1503 100644 --- a/libs/win/jaraco/windows/vpn.py +++ b/libs/win/jaraco/windows/vpn.py @@ -3,15 +3,19 @@ from path import Path def install_pptp(name, param_lines): - """ - """ - # or consider using the API: - # http://msdn.microsoft.com/en-us/library/aa446739%28v=VS.85%29.aspx - pbk_path = ( - Path(os.environ['PROGRAMDATA']) - / 'Microsoft' / 'Network' / 'Connections' / 'pbk' / 'rasphone.pbk') - pbk_path.dirname().makedirs_p() - with open(pbk_path, 'a') as pbk: - pbk.write('[{name}]\n'.format(name=name)) - pbk.writelines(param_lines) - pbk.write('\n') + """ """ + # or consider using the API: + # http://msdn.microsoft.com/en-us/library/aa446739%28v=VS.85%29.aspx + pbk_path = ( + Path(os.environ['PROGRAMDATA']) + / 'Microsoft' + / 'Network' + / 'Connections' + / 'pbk' + / 'rasphone.pbk' + ) + pbk_path.dirname().makedirs_p() + with open(pbk_path, 'a') as pbk: + pbk.write('[{name}]\n'.format(name=name)) + pbk.writelines(param_lines) + pbk.write('\n') diff --git a/libs/win/jaraco/windows/xmouse.py 
b/libs/win/jaraco/windows/xmouse.py index 20b19435..8f3485f3 100644 --- a/libs/win/jaraco/windows/xmouse.py +++ b/libs/win/jaraco/windows/xmouse.py @@ -1,7 +1,3 @@ -#!python - -from __future__ import print_function - import ctypes from jaraco.windows.error import handle_nonzero_success from jaraco.windows.api import system @@ -9,92 +5,84 @@ from jaraco.ui.cmdline import Command def set(value): - result = system.SystemParametersInfo( - system.SPI_SETACTIVEWINDOWTRACKING, - 0, - ctypes.cast(value, ctypes.c_void_p), - 0, - ) - handle_nonzero_success(result) + result = system.SystemParametersInfo( + system.SPI_SETACTIVEWINDOWTRACKING, 0, ctypes.cast(value, ctypes.c_void_p), 0 + ) + handle_nonzero_success(result) def get(): - value = ctypes.wintypes.BOOL() - result = system.SystemParametersInfo( - system.SPI_GETACTIVEWINDOWTRACKING, - 0, - ctypes.byref(value), - 0, - ) - handle_nonzero_success(result) - return bool(value) + value = ctypes.wintypes.BOOL() + result = system.SystemParametersInfo( + system.SPI_GETACTIVEWINDOWTRACKING, 0, ctypes.byref(value), 0 + ) + handle_nonzero_success(result) + return bool(value) def set_delay(milliseconds): - result = system.SystemParametersInfo( - system.SPI_SETACTIVEWNDTRKTIMEOUT, - 0, - ctypes.cast(milliseconds, ctypes.c_void_p), - 0, - ) - handle_nonzero_success(result) + result = system.SystemParametersInfo( + system.SPI_SETACTIVEWNDTRKTIMEOUT, + 0, + ctypes.cast(milliseconds, ctypes.c_void_p), + 0, + ) + handle_nonzero_success(result) def get_delay(): - value = ctypes.wintypes.DWORD() - result = system.SystemParametersInfo( - system.SPI_GETACTIVEWNDTRKTIMEOUT, - 0, - ctypes.byref(value), - 0, - ) - handle_nonzero_success(result) - return int(value.value) + value = ctypes.wintypes.DWORD() + result = system.SystemParametersInfo( + system.SPI_GETACTIVEWNDTRKTIMEOUT, 0, ctypes.byref(value), 0 + ) + handle_nonzero_success(result) + return int(value.value) class DelayParam(Command): - @staticmethod - def add_arguments(parser): - parser.add_argument( - '-d', '--delay', type=int, - help="Delay in milliseconds for active window tracking" - ) + @staticmethod + def add_arguments(parser): + parser.add_argument( + '-d', + '--delay', + type=int, + help="Delay in milliseconds for active window tracking", + ) class Show(Command): - @classmethod - def run(cls, args): - msg = "xmouse: {enabled} (delay {delay}ms)".format( - enabled=get(), - delay=get_delay(), - ) - print(msg) + @classmethod + def run(cls, args): + msg = "xmouse: {enabled} (delay {delay}ms)".format( + enabled=get(), delay=get_delay() + ) + print(msg) class Enable(DelayParam): - @classmethod - def run(cls, args): - print("enabling xmouse") - set(True) - args.delay and set_delay(args.delay) + @classmethod + def run(cls, args): + print("enabling xmouse") + set(True) + args.delay and set_delay(args.delay) class Disable(DelayParam): - @classmethod - def run(cls, args): - print("disabling xmouse") - set(False) - args.delay and set_delay(args.delay) + @classmethod + def run(cls, args): + print("disabling xmouse") + set(False) + args.delay and set_delay(args.delay) class Toggle(DelayParam): - @classmethod - def run(cls, args): - value = get() - print("xmouse: %s -> %s" % (value, not value)) - set(not value) - args.delay and set_delay(args.delay) + @classmethod + def run(cls, args): + value = get() + print("xmouse: %s -> %s" % (value, not value)) + set(not value) + args.delay and set_delay(args.delay) if __name__ == '__main__': - Command.invoke() + Command.invoke() diff --git a/libs/win/more_itertools/__init__.py 
b/libs/win/more_itertools/__init__.py index bba462c3..557bfc20 100644 --- a/libs/win/more_itertools/__init__.py +++ b/libs/win/more_itertools/__init__.py @@ -1,2 +1,6 @@ -from more_itertools.more import * # noqa -from more_itertools.recipes import * # noqa +"""More routines for operating on iterables, beyond itertools""" + +from .more import * # noqa +from .recipes import * # noqa + +__version__ = '9.0.0' diff --git a/libs/win/more_itertools/__init__.pyi b/libs/win/more_itertools/__init__.pyi new file mode 100644 index 00000000..96f6e36c --- /dev/null +++ b/libs/win/more_itertools/__init__.pyi @@ -0,0 +1,2 @@ +from .more import * +from .recipes import * diff --git a/libs/win/more_itertools/more.py b/libs/win/more_itertools/more.py index 05e851ee..7f73d6ed 100644 --- a/libs/win/more_itertools/more.py +++ b/libs/win/more_itertools/more.py @@ -1,8 +1,9 @@ -from __future__ import print_function +import warnings -from collections import Counter, defaultdict, deque -from functools import partial, wraps -from heapq import merge +from collections import Counter, defaultdict, deque, abc +from collections.abc import Sequence +from functools import partial, reduce, wraps +from heapq import heapify, heapreplace, heappop from itertools import ( chain, compress, @@ -14,100 +15,162 @@ from itertools import ( repeat, starmap, takewhile, - tee + tee, + zip_longest, ) -from operator import itemgetter, lt, gt, sub -from sys import maxsize, version_info -try: - from collections.abc import Sequence -except ImportError: - from collections import Sequence +from math import exp, factorial, floor, log +from queue import Empty, Queue +from random import random, randrange, uniform +from operator import itemgetter, mul, sub, gt, lt, ge, le +from sys import hexversion, maxsize +from time import monotonic -from six import binary_type, string_types, text_type -from six.moves import filter, map, range, zip, zip_longest - -from .recipes import consume, flatten, take +from .recipes import ( + _marker, + _zip_equal, + UnequalIterablesError, + consume, + flatten, + pairwise, + powerset, + take, + unique_everseen, + all_equal, +) __all__ = [ + 'AbortThread', + 'SequenceView', + 'UnequalIterablesError', 'adjacent', + 'all_unique', 'always_iterable', 'always_reversible', 'bucket', + 'callback_iter', 'chunked', + 'chunked_even', 'circular_shifts', 'collapse', - 'collate', + 'combination_index', 'consecutive_groups', + 'constrained_batches', 'consumer', 'count_cycle', + 'countable', 'difference', + 'distinct_combinations', 'distinct_permutations', 'distribute', 'divide', + 'duplicates_everseen', + 'duplicates_justseen', 'exactly_n', + 'filter_except', 'first', 'groupby_transform', + 'ichunked', + 'iequals', 'ilen', - 'interleave_longest', 'interleave', + 'interleave_evenly', + 'interleave_longest', 'intersperse', + 'is_sorted', 'islice_extended', 'iterate', 'last', 'locate', + 'longest_common_prefix', 'lstrip', 'make_decorator', + 'map_except', + 'map_if', 'map_reduce', + 'mark_ends', + 'minmax', + 'nth_or_last', + 'nth_permutation', + 'nth_product', 'numeric_range', 'one', + 'only', 'padded', + 'partitions', 'peekable', + 'permutation_index', + 'product_index', + 'raise_', + 'repeat_each', + 'repeat_last', 'replace', 'rlocate', 'rstrip', 'run_length', + 'sample', 'seekable', - 'SequenceView', + 'set_partitions', 'side_effect', 'sliced', 'sort_together', - 'split_at', 'split_after', + 'split_at', 'split_before', + 'split_into', + 'split_when', 'spy', 'stagger', 'strip', + 'strictly_n', + 'substrings', + 'substrings_indexes', + 
'time_limited', + 'unique_in_window', 'unique_to_each', + 'unzip', + 'value_chain', 'windowed', + 'windowed_complete', 'with_iter', + 'zip_broadcast', + 'zip_equal', 'zip_offset', ] -_marker = object() - -def chunked(iterable, n): +def chunked(iterable, n, strict=False): """Break *iterable* into lists of length *n*: >>> list(chunked([1, 2, 3, 4, 5, 6], 3)) [[1, 2, 3], [4, 5, 6]] - If the length of *iterable* is not evenly divisible by *n*, the last - returned list will be shorter: + By default, the last yielded list will have fewer than *n* elements + if the length of *iterable* is not divisible by *n*: >>> list(chunked([1, 2, 3, 4, 5, 6, 7, 8], 3)) [[1, 2, 3], [4, 5, 6], [7, 8]] To use a fill-in value instead, see the :func:`grouper` recipe. - :func:`chunked` is useful for splitting up a computation on a large number - of keys into batches, to be pickled and sent off to worker processes. One - example is operations on rows in MySQL, which does not implement - server-side cursors properly and would otherwise load the entire dataset - into RAM on the client. + If the length of *iterable* is not divisible by *n* and *strict* is + ``True``, then ``ValueError`` will be raised before the last + list is yielded. """ - return iter(partial(take, n, iter(iterable)), []) + iterator = iter(partial(take, n, iter(iterable)), []) + if strict: + if n is None: + raise ValueError('n must not be None when using strict mode.') + + def ret(): + for chunk in iterator: + if len(chunk) != n: + raise ValueError('iterable is not divisible by n.') + yield chunk + + return iter(ret()) + else: + return iterator def first(iterable, default=_marker): @@ -129,14 +192,12 @@ """ try: return next(iter(iterable)) - except StopIteration: - # I'm on the edge about raising ValueError instead of StopIteration. At - # the moment, ValueError wins, because the caller could conceivably - # want to do something different with flow control when I raise the - # exception, and it's weird to explicitly catch StopIteration. + except StopIteration as e: if default is _marker: - raise ValueError('first() was called on an empty iterable, and no ' - 'default value was provided.') + raise ValueError( + 'first() was called on an empty iterable, and no ' + 'default value was provided.' + ) from e return default @@ -153,20 +214,40 @@ raise ``ValueError``. """ try: - try: - # Try to access the last item directly + if isinstance(iterable, Sequence): return iterable[-1] - except (TypeError, AttributeError, KeyError): - # If not slice-able, iterate entirely using length-1 deque - return deque(iterable, maxlen=1)[0] - except IndexError: # If the iterable was empty + # Work around https://bugs.python.org/issue38525 + elif hasattr(iterable, '__reversed__') and (hexversion != 0x030800F0): + return next(reversed(iterable)) + else: + return deque(iterable, maxlen=1)[-1] + except (IndexError, TypeError, StopIteration): if default is _marker: - raise ValueError('last() was called on an empty iterable, and no ' - 'default value was provided.') + raise ValueError( + 'last() was called on an empty iterable, and no default was ' + 'provided.' + ) return default -class peekable(object): +def nth_or_last(iterable, n, default=_marker): + """Return the nth or the last item of *iterable*, + or *default* if *iterable* is empty.
+ + >>> nth_or_last([0, 1, 2, 3], 2) + 2 + >>> nth_or_last([0, 1], 2) + 1 + >>> nth_or_last([], 0, 'some default') + 'some default' + + If *default* is not provided and there are no items in the iterable, + raise ``ValueError``. + """ + return last(islice(iterable, n + 1), default=default) + + +class peekable: """Wrap an iterator to allow lookahead and prepending elements. Call :meth:`peek` on the result to get the value that will be returned @@ -219,11 +300,12 @@ class peekable(object): >>> if p: # peekable has items ... list(p) ['a', 'b'] - >>> if not p: # peekable is exhaused + >>> if not p: # peekable is exhausted ... list(p) [] """ + def __init__(self, iterable): self._it = iter(iterable) self._cache = deque() @@ -238,10 +320,6 @@ class peekable(object): return False return True - def __nonzero__(self): - # For Python 2 compatibility - return self.__bool__() - def peek(self, default=_marker): """Return the item that will be next returned from ``next()``. @@ -295,8 +373,6 @@ class peekable(object): return next(self._it) - next = __next__ # For Python 2 compatibility - def _get_slice(self, index): # Normalize the slice's arguments step = 1 if (index.step is None) else index.step @@ -336,70 +412,6 @@ class peekable(object): return self._cache[index] -def _collate(*iterables, **kwargs): - """Helper for ``collate()``, called when the user is using the ``reverse`` - or ``key`` keyword arguments on Python versions below 3.5. - - """ - key = kwargs.pop('key', lambda a: a) - reverse = kwargs.pop('reverse', False) - - min_or_max = partial(max if reverse else min, key=itemgetter(0)) - peekables = [peekable(it) for it in iterables] - peekables = [p for p in peekables if p] # Kill empties. - while peekables: - _, p = min_or_max((key(p.peek()), p) for p in peekables) - yield next(p) - peekables = [x for x in peekables if x] - - -def collate(*iterables, **kwargs): - """Return a sorted merge of the items from each of several already-sorted - *iterables*. - - >>> list(collate('ACDZ', 'AZ', 'JKL')) - ['A', 'A', 'C', 'D', 'J', 'K', 'L', 'Z', 'Z'] - - Works lazily, keeping only the next value from each iterable in memory. Use - :func:`collate` to, for example, perform a n-way mergesort of items that - don't fit in memory. - - If a *key* function is specified, the iterables will be sorted according - to its result: - - >>> key = lambda s: int(s) # Sort by numeric value, not by string - >>> list(collate(['1', '10'], ['2', '11'], key=key)) - ['1', '2', '10', '11'] - - - If the *iterables* are sorted in descending order, set *reverse* to - ``True``: - - >>> list(collate([5, 3, 1], [4, 2, 0], reverse=True)) - [5, 4, 3, 2, 1, 0] - - If the elements of the passed-in iterables are out of order, you might get - unexpected results. - - On Python 2.7, this function delegates to :func:`heapq.merge` if neither - of the keyword arguments are specified. On Python 3.5+, this function - is an alias for :func:`heapq.merge`. - - """ - if not kwargs: - return merge(*iterables) - - return _collate(*iterables, **kwargs) - - -# If using Python version 3.5 or greater, heapq.merge() will be faster than -# collate - use that instead. -if version_info >= (3, 5, 0): - _collate_docstring = collate.__doc__ - collate = partial(merge) - collate.__doc__ = _collate_docstring - - def consumer(func): """Decorator that automatically advances a PEP-342-style "reverse iterator" to its first yield point so you don't have to call ``next()`` on it @@ -422,11 +434,13 @@ def consumer(func): ``t.send()`` could be used. 
""" + @wraps(func) def wrapper(*args, **kwargs): gen = func(*args, **kwargs) next(gen) return gen + return wrapper @@ -439,20 +453,20 @@ def ilen(iterable): This consumes the iterable, so handle with care. """ - # maxlen=1 only stores the last item in the deque - d = deque(enumerate(iterable, 1), maxlen=1) - # since we started enumerate at 1, - # the first item of the last pair will be the length of the iterable - # (assuming there were items) - return d[0][0] if d else 0 + # This approach was selected because benchmarks showed it's likely the + # fastest of the known implementations at the time of writing. + # See GitHub tracker: #236, #230. + counter = count() + deque(zip(iterable, counter), maxlen=0) + return next(counter) def iterate(func, start): """Return ``start``, ``func(start)``, ``func(func(start))``, ... - >>> from itertools import islice - >>> list(islice(iterate(lambda x: 2*x, 1), 10)) - [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + >>> from itertools import islice + >>> list(islice(iterate(lambda x: 2*x, 1), 10)) + [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] """ while True: @@ -472,8 +486,7 @@ def with_iter(context_manager): """ with context_manager as iterable: - for item in iterable: - yield item + yield from iterable def one(iterable, too_short=None, too_long=None): @@ -507,7 +520,8 @@ def one(iterable, too_short=None, too_long=None): >>> one(it) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... - ValueError: too many items in iterable (expected 1)' + ValueError: Expected exactly one item in iterable, but got 'too', + 'many', and perhaps more. >>> too_long = RuntimeError >>> one(it, too_long=too_long) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): @@ -515,29 +529,112 @@ def one(iterable, too_short=None, too_long=None): RuntimeError Note that :func:`one` attempts to advance *iterable* twice to ensure there - is only one item. If there is more than one, both items will be discarded. - See :func:`spy` or :func:`peekable` to check iterable contents less - destructively. + is only one item. See :func:`spy` or :func:`peekable` to check iterable + contents less destructively. """ it = iter(iterable) try: - value = next(it) + first_value = next(it) + except StopIteration as e: + raise ( + too_short or ValueError('too few items in iterable (expected 1)') + ) from e + + try: + second_value = next(it) except StopIteration: - raise too_short or ValueError('too few items in iterable (expected 1)') + pass + else: + msg = ( + 'Expected exactly one item in iterable, but got {!r}, {!r}, ' + 'and perhaps more.'.format(first_value, second_value) + ) + raise too_long or ValueError(msg) + + return first_value + + +def raise_(exception, *args): + raise exception(*args) + + +def strictly_n(iterable, n, too_short=None, too_long=None): + """Validate that *iterable* has exactly *n* items and return them if + it does. If it has fewer than *n* items, call function *too_short* + with those items. If it has more than *n* items, call function + *too_long* with the first ``n + 1`` items. + + >>> iterable = ['a', 'b', 'c', 'd'] + >>> n = 4 + >>> list(strictly_n(iterable, n)) + ['a', 'b', 'c', 'd'] + + By default, *too_short* and *too_long* are functions that raise + ``ValueError``. + + >>> list(strictly_n('ab', 3)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: too few items in iterable (got 2) + + >>> list(strictly_n('abc', 2)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... 
+ ValueError: too many items in iterable (got at least 3) + + You can instead supply functions that do something else. + *too_short* will be called with the number of items in *iterable*. + *too_long* will be called with `n + 1`. + + >>> def too_short(item_count): + ... raise RuntimeError + >>> it = strictly_n('abcd', 6, too_short=too_short) + >>> list(it) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + RuntimeError + + >>> def too_long(item_count): + ... print('The boss is going to hear about this') + >>> it = strictly_n('abcdef', 4, too_long=too_long) + >>> list(it) + The boss is going to hear about this + ['a', 'b', 'c', 'd'] + + """ + if too_short is None: + too_short = lambda item_count: raise_( + ValueError, + 'Too few items in iterable (got {})'.format(item_count), + ) + + if too_long is None: + too_long = lambda item_count: raise_( + ValueError, + 'Too many items in iterable (got at least {})'.format(item_count), + ) + + it = iter(iterable) + for i in range(n): + try: + item = next(it) + except StopIteration: + too_short(i) + return + else: + yield item try: next(it) except StopIteration: pass else: - raise too_long or ValueError('too many items in iterable (expected 1)') - - return value + too_long(n + 1) -def distinct_permutations(iterable): +def distinct_permutations(iterable, r=None): """Yield successive distinct permutations of the elements in *iterable*. >>> sorted(distinct_permutations([1, 0, 1])) @@ -553,34 +650,88 @@ def distinct_permutations(iterable): items input, and each `x_i` is the count of a distinct item in the input sequence. + If *r* is given, only the *r*-length permutations are yielded. + + >>> sorted(distinct_permutations([1, 0, 1], r=2)) + [(0, 1), (1, 0), (1, 1)] + >>> sorted(distinct_permutations(range(3), r=2)) + [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)] + """ - def perm_unique_helper(item_counts, perm, i): - """Internal helper function + # Algorithm: https://w.wiki/Qai + def _full(A): + while True: + # Yield the permutation we have + yield tuple(A) - :arg item_counts: Stores the unique items in ``iterable`` and how many - times they are repeated - :arg perm: The permutation that is being built for output - :arg i: The index of the permutation being modified + # Find the largest index i such that A[i] < A[i + 1] + for i in range(size - 2, -1, -1): + if A[i] < A[i + 1]: + break + # If no such index exists, this permutation is the last one + else: + return - The output permutations are built up recursively; the distinct items - are placed until their repetitions are exhausted. 
- """ - if i < 0: - yield tuple(perm) - else: - for item in item_counts: - if item_counts[item] <= 0: - continue - perm[i] = item - item_counts[item] -= 1 - for x in perm_unique_helper(item_counts, perm, i - 1): - yield x - item_counts[item] += 1 + # Find the largest index j greater than j such that A[i] < A[j] + for j in range(size - 1, i, -1): + if A[i] < A[j]: + break - item_counts = Counter(iterable) - length = sum(item_counts.values()) + # Swap the value of A[i] with that of A[j], then reverse the + # sequence from A[i + 1] to form the new permutation + A[i], A[j] = A[j], A[i] + A[i + 1 :] = A[: i - size : -1] # A[i + 1:][::-1] - return perm_unique_helper(item_counts, [None] * length, length - 1) + # Algorithm: modified from the above + def _partial(A, r): + # Split A into the first r items and the last r items + head, tail = A[:r], A[r:] + right_head_indexes = range(r - 1, -1, -1) + left_tail_indexes = range(len(tail)) + + while True: + # Yield the permutation we have + yield tuple(head) + + # Starting from the right, find the first index of the head with + # value smaller than the maximum value of the tail - call it i. + pivot = tail[-1] + for i in right_head_indexes: + if head[i] < pivot: + break + pivot = head[i] + else: + return + + # Starting from the left, find the first value of the tail + # with a value greater than head[i] and swap. + for j in left_tail_indexes: + if tail[j] > head[i]: + head[i], tail[j] = tail[j], head[i] + break + # If we didn't find one, start from the right and find the first + # index of the head with a value greater than head[i] and swap. + else: + for j in right_head_indexes: + if head[j] > head[i]: + head[i], head[j] = head[j], head[i] + break + + # Reverse head[i + 1:] and swap it with tail[:r - (i + 1)] + tail += head[: i - r : -1] # head[i + 1:][::-1] + i += 1 + head[i:], tail[:] = tail[: r - i], tail[r - i :] + + items = sorted(iterable) + + size = len(items) + if r is None: + r = size + + if 0 < r <= size: + return _full(items) if (r == size) else _partial(items, r) + + return iter(() if r else ((),)) def intersperse(e, iterable, n=1): @@ -597,8 +748,8 @@ def intersperse(e, iterable, n=1): if n == 0: raise ValueError('n must be > 0') elif n == 1: - # interleave(repeat(e), iterable) -> e, x_0, e, e, x_1, e, x_2... - # islice(..., 1, None) -> x_0, e, e, x_1, e, x_2... + # interleave(repeat(e), iterable) -> e, x_0, e, x_1, e, x_2... + # islice(..., 1, None) -> x_0, e, x_1, e, x_2... return islice(interleave(repeat(e), iterable), 1, None) else: # interleave(filler, chunks) -> [e], [x_0, x_1], [e], [x_2, x_3]... 
@@ -650,7 +801,7 @@ def windowed(seq, n, fillvalue=None, step=1): [(1, 2, 3), (2, 3, 4), (3, 4, 5)] When the window is larger than the iterable, *fillvalue* is used in place - of missing values:: + of missing values: >>> list(windowed([1, 2, 3], 4)) [(1, 2, 3, None)] @@ -660,6 +811,14 @@ >>> list(windowed([1, 2, 3, 4, 5, 6], 3, fillvalue='!', step=2)) [(1, 2, 3), (3, 4, 5), (5, 6, '!')] + To slide into the iterable's items, use :func:`chain` to add filler items + to the left: + + >>> iterable = [1, 2, 3, 4] + >>> n = 3 + >>> padding = [None] * (n - 1) + >>> list(windowed(chain(padding, iterable), 3)) + [(None, None, 1), (None, 1, 2), (1, 2, 3), (2, 3, 4)] """ if n < 0: raise ValueError('n must be >= 0') @@ -669,37 +828,92 @@ if step < 1: raise ValueError('step must be >= 1') - it = iter(seq) - window = deque([], n) - append = window.append - - # Initial deque fill - for _ in range(n): - append(next(it, fillvalue)) - yield tuple(window) - - # Appending new items to the right causes old items to fall off the left - i = 0 - for item in it: - append(item) - i = (i + 1) % step - if i % step == 0: + window = deque(maxlen=n) + i = n + for _ in map(window.append, seq): + i -= 1 + if not i: + i = step yield tuple(window) - # If there are items from the iterable in the window, pad with the given - # value and emit them. - if (i % step) and (step - i < n): - for _ in range(step - i): - append(fillvalue) + size = len(window) + if size == 0: + return + elif size < n: + yield tuple(chain(window, repeat(fillvalue, n - size))) + elif 0 < i < min(step, n): + window += (fillvalue,) * i yield tuple(window) -class bucket(object): +def substrings(iterable): + """Yield all of the substrings of *iterable*. + + >>> [''.join(s) for s in substrings('more')] + ['m', 'o', 'r', 'e', 'mo', 'or', 're', 'mor', 'ore', 'more'] + + Note that non-string iterables can also be subdivided. + + >>> list(substrings([0, 1, 2])) + [(0,), (1,), (2,), (0, 1), (1, 2), (0, 1, 2)] + + """ + # The length-1 substrings + seq = [] + for item in iter(iterable): + seq.append(item) + yield (item,) + seq = tuple(seq) + item_count = len(seq) + + # And the rest + for n in range(2, item_count + 1): + for i in range(item_count - n + 1): + yield seq[i : i + n] + + +def substrings_indexes(seq, reverse=False): + """Yield all substrings and their positions in *seq* + + The items yielded will be a tuple of the form ``(substr, i, j)``, where + ``substr == seq[i:j]``. + + This function only works for iterables that support slicing, such as + ``str`` objects. + + >>> for item in substrings_indexes('more'): + ... print(item) + ('m', 0, 1) + ('o', 1, 2) + ('r', 2, 3) + ('e', 3, 4) + ('mo', 0, 2) + ('or', 1, 3) + ('re', 2, 4) + ('mor', 0, 3) + ('ore', 1, 4) + ('more', 0, 4) + + Set *reverse* to ``True`` to yield the same items in the opposite order. + + + """ + r = range(1, len(seq) + 1) + if reverse: + r = reversed(r) + return ( + (seq[i : i + L], i, i + L) for L in r for i in range(len(seq) - L + 1) + ) + + +class bucket: """Wrap *iterable* and return an object that buckets it into child iterables based on a *key* function.
>>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3'] - >>> s = bucket(iterable, key=lambda x: x[0]) + >>> s = bucket(iterable, key=lambda x: x[0]) # Bucket by 1st character + >>> sorted(list(s)) # Get the keys + ['a', 'b', 'c'] >>> a_iterable = s['a'] >>> next(a_iterable) 'a1' @@ -727,6 +941,7 @@ class bucket(object): [] """ + def __init__(self, iterable, key, validator=None): self._it = iter(iterable) self._key = key @@ -772,6 +987,14 @@ class bucket(object): elif self._validator(item_value): self._cache[item_value].append(item) + def __iter__(self): + for item in self._it: + item_value = self._key(item) + if self._validator(item_value): + self._cache[item_value].append(item) + + yield from self._cache.keys() + def __getitem__(self, value): if not self._validator(value): return iter(()) @@ -819,7 +1042,7 @@ def spy(iterable, n=1): it = iter(iterable) head = take(n, it) - return head, chain(head, it) + return head.copy(), chain(head, it) def interleave(*iterables): @@ -852,6 +1075,72 @@ def interleave_longest(*iterables): return (x for x in i if x is not _marker) +def interleave_evenly(iterables, lengths=None): + """ + Interleave multiple iterables so that their elements are evenly distributed + throughout the output sequence. + + >>> iterables = [1, 2, 3, 4, 5], ['a', 'b'] + >>> list(interleave_evenly(iterables)) + [1, 2, 'a', 3, 4, 'b', 5] + + >>> iterables = [[1, 2, 3], [4, 5], [6, 7, 8]] + >>> list(interleave_evenly(iterables)) + [1, 6, 4, 2, 7, 3, 8, 5] + + This function requires iterables of known length. Iterables without + ``__len__()`` can be used by manually specifying lengths with *lengths*: + + >>> from itertools import combinations, repeat + >>> iterables = [combinations(range(4), 2), ['a', 'b', 'c']] + >>> lengths = [4 * (4 - 1) // 2, 3] + >>> list(interleave_evenly(iterables, lengths=lengths)) + [(0, 1), (0, 2), 'a', (0, 3), (1, 2), 'b', (1, 3), (2, 3), 'c'] + + Based on Bresenham's algorithm. + """ + if lengths is None: + try: + lengths = [len(it) for it in iterables] + except TypeError: + raise ValueError( + 'Iterable lengths could not be determined automatically. ' + 'Specify them with the lengths keyword.' + ) + elif len(iterables) != len(lengths): + raise ValueError('Mismatching number of iterables and lengths.') + + dims = len(lengths) + + # sort iterables by length, descending + lengths_permute = sorted( + range(dims), key=lambda i: lengths[i], reverse=True + ) + lengths_desc = [lengths[i] for i in lengths_permute] + iters_desc = [iter(iterables[i]) for i in lengths_permute] + + # the longest iterable is the primary one (Bresenham: the longest + # distance along an axis) + delta_primary, deltas_secondary = lengths_desc[0], lengths_desc[1:] + iter_primary, iters_secondary = iters_desc[0], iters_desc[1:] + errors = [delta_primary // dims] * len(deltas_secondary) + + to_yield = sum(lengths) + while to_yield: + yield next(iter_primary) + to_yield -= 1 + # update errors for each secondary iterable + errors = [e - delta for e, delta in zip(errors, deltas_secondary)] + + # those iterables for which the error is negative are yielded + # ("diagonal step" in Bresenham) + for i, e in enumerate(errors): + if e < 0: + yield next(iters_secondary[i]) + to_yield -= 1 + errors[i] += delta_primary + + def collapse(iterable, base_type=None, levels=None): """Flatten an iterable with multiple levels of nesting (e.g., a list of lists of tuples) into non-iterable types. 
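A note on the new interleave_evenly above: it is Bresenham's line-drawing idea applied to iterators. The longest iterable acts as the major axis, and each shorter iterable carries an error term that loses that iterable's length on every major step and gains the major length whenever it emits an item. A rough two-iterable sketch of the same bookkeeping (illustrative only; interleave_two is a hypothetical helper written against plain lists, not part of the library):

    def interleave_two(long, short):
        delta_long, delta_short = len(long), len(short)
        # error starts at delta_primary // dims, mirroring the library's init
        error = delta_long // 2
        short_iter = iter(short)
        for item in long:
            yield item
            error -= delta_short
            if error < 0:  # "diagonal step": emit from the short iterable
                yield next(short_iter)
                error += delta_long

    list(interleave_two([1, 2, 3, 4, 5], ['a', 'b']))
    # -> [1, 2, 'a', 3, 4, 'b', 5], matching the docstring example above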
@@ -860,7 +1149,9 @@ >>> list(collapse(iterable)) [1, 2, 3, 4, 5, 6] - String types are not considered iterable and will not be collapsed. + Binary and text strings are not considered iterable and + will not be collapsed. + To avoid collapsing other types, specify *base_type*: >>> iterable = ['ab', ('cd', 'ef'), ['gh', 'ij']] @@ -876,11 +1167,12 @@ ['a', ['b'], 'c', ['d']] """ + def walk(node, level): if ( - ((levels is not None) and (level > levels)) or - isinstance(node, string_types) or - ((base_type is not None) and isinstance(node, base_type)) + ((levels is not None) and (level > levels)) + or isinstance(node, (str, bytes)) + or ((base_type is not None) and isinstance(node, base_type)) ): yield node return @@ -892,11 +1184,9 @@ return else: for child in tree: - for x in walk(child, level + 1): - yield x + yield from walk(child, level + 1) - for x in walk(iterable, 0): - yield x + yield from walk(iterable, 0) def side_effect(func, iterable, chunk_size=None, before=None, after=None): @@ -954,56 +1244,93 @@ else: for chunk in chunked(iterable, chunk_size): func(chunk) - for item in chunk: - yield item + yield from chunk finally: if after is not None: after() -def sliced(seq, n): +def sliced(seq, n, strict=False): """Yield slices of length *n* from the sequence *seq*. - >>> list(sliced((1, 2, 3, 4, 5, 6), 3)) - [(1, 2, 3), (4, 5, 6)] + >>> list(sliced((1, 2, 3, 4, 5, 6), 3)) + [(1, 2, 3), (4, 5, 6)] - If the length of the sequence is not divisible by the requested slice - length, the last slice will be shorter. + By default, the last yielded slice will have fewer than *n* elements + if the length of *seq* is not divisible by *n*: - >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3)) - [(1, 2, 3), (4, 5, 6), (7, 8)] + >>> list(sliced((1, 2, 3, 4, 5, 6, 7, 8), 3)) + [(1, 2, 3), (4, 5, 6), (7, 8)] + + If the length of *seq* is not divisible by *n* and *strict* is + ``True``, then ``ValueError`` will be raised before the last + slice is yielded. This function will only work for iterables that support slicing. For non-sliceable iterables, see :func:`chunked`. """ - return takewhile(bool, (seq[i: i + n] for i in count(0, n))) + iterator = takewhile(len, (seq[i : i + n] for i in count(0, n))) + if strict: + + def ret(): + for _slice in iterator: + if len(_slice) != n: + raise ValueError("seq is not divisible by n.") + yield _slice + + return iter(ret()) + else: + return iterator -def split_at(iterable, pred): +def split_at(iterable, pred, maxsplit=-1, keep_separator=False): """Yield lists of items from *iterable*, where each list is delimited by - an item where callable *pred* returns ``True``. The lists do not include - the delimiting items. + an item where callable *pred* returns ``True``. >>> list(split_at('abcdcba', lambda x: x == 'b')) [['a'], ['c', 'd', 'c'], ['a']] >>> list(split_at(range(10), lambda n: n % 2 == 1)) [[0], [2], [4], [6], [8], []] + + At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_at(range(10), lambda n: n % 2 == 1, maxsplit=2)) + [[0], [2], [4, 5, 6, 7, 8, 9]] + + By default, the delimiting items are not included in the output. + To include them, set *keep_separator* to ``True``.
+ + >>> list(split_at('abcdcba', lambda x: x == 'b', keep_separator=True)) + [['a'], ['b'], ['c', 'd', 'c'], ['b'], ['a']] + """ + if maxsplit == 0: + yield list(iterable) + return + buf = [] - for item in iterable: + it = iter(iterable) + for item in it: if pred(item): yield buf + if keep_separator: + yield [item] + if maxsplit == 1: + yield list(it) + return buf = [] + maxsplit -= 1 else: buf.append(item) yield buf -def split_before(iterable, pred): - """Yield lists of items from *iterable*, where each list starts with an - item where callable *pred* returns ``True``: +def split_before(iterable, pred, maxsplit=-1): + """Yield lists of items from *iterable*, where each list ends just before + an item for which callable *pred* returns ``True``: >>> list(split_before('OneTwo', lambda s: s.isupper())) [['O', 'n', 'e'], ['T', 'w', 'o']] @@ -1011,17 +1338,32 @@ def split_before(iterable, pred): >>> list(split_before(range(10), lambda n: n % 3 == 0)) [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] + At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_before(range(10), lambda n: n % 3 == 0, maxsplit=2)) + [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]] """ + if maxsplit == 0: + yield list(iterable) + return + buf = [] - for item in iterable: + it = iter(iterable) + for item in it: if pred(item) and buf: yield buf + if maxsplit == 1: + yield [item] + list(it) + return buf = [] + maxsplit -= 1 buf.append(item) - yield buf + if buf: + yield buf -def split_after(iterable, pred): +def split_after(iterable, pred, maxsplit=-1): """Yield lists of items from *iterable*, where each list ends with an item where callable *pred* returns ``True``: @@ -1031,17 +1373,122 @@ def split_after(iterable, pred): >>> list(split_after(range(10), lambda n: n % 3 == 0)) [[0], [1, 2, 3], [4, 5, 6], [7, 8, 9]] + At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2)) + [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]] + """ + if maxsplit == 0: + yield list(iterable) + return + buf = [] - for item in iterable: + it = iter(iterable) + for item in it: buf.append(item) if pred(item) and buf: yield buf + if maxsplit == 1: + yield list(it) + return buf = [] + maxsplit -= 1 if buf: yield buf +def split_when(iterable, pred, maxsplit=-1): + """Split *iterable* into pieces based on the output of *pred*. + *pred* should be a function that takes successive pairs of items and + returns ``True`` if the iterable should be split in between them. + + For example, to find runs of increasing numbers, split the iterable when + element ``i`` is larger than element ``i + 1``: + + >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y)) + [[1, 2, 3, 3], [2, 5], [2, 4], [2]] + + At most *maxsplit* splits are done. If *maxsplit* is not specified or -1, + then there is no limit on the number of splits: + + >>> list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], + ... 
lambda x, y: x > y, maxsplit=2)) + [[1, 2, 3, 3], [2, 5], [2, 4, 2]] + + """ + if maxsplit == 0: + yield list(iterable) + return + + it = iter(iterable) + try: + cur_item = next(it) + except StopIteration: + return + + buf = [cur_item] + for next_item in it: + if pred(cur_item, next_item): + yield buf + if maxsplit == 1: + yield [next_item] + list(it) + return + buf = [] + maxsplit -= 1 + + buf.append(next_item) + cur_item = next_item + + yield buf + + +def split_into(iterable, sizes): + """Yield a list of sequential items from *iterable* of length 'n' for each + integer 'n' in *sizes*. + + >>> list(split_into([1,2,3,4,5,6], [1,2,3])) + [[1], [2, 3], [4, 5, 6]] + + If the sum of *sizes* is smaller than the length of *iterable*, then the + remaining items of *iterable* will not be returned. + + >>> list(split_into([1,2,3,4,5,6], [2,3])) + [[1, 2], [3, 4, 5]] + + If the sum of *sizes* is larger than the length of *iterable*, fewer items + will be returned in the iteration that overruns *iterable* and further + lists will be empty: + + >>> list(split_into([1,2,3,4], [1,2,3,4])) + [[1], [2, 3], [4], []] + + When a ``None`` object is encountered in *sizes*, the returned list will + contain items up to the end of *iterable* the same way that itertools.islice + does: + + >>> list(split_into([1,2,3,4,5,6,7,8,9,0], [2,3,None])) + [[1, 2], [3, 4, 5], [6, 7, 8, 9, 0]] + + :func:`split_into` can be useful for grouping a series of items where the + sizes of the groups are not uniform. An example would be where in a row + from a table, multiple columns represent elements of the same feature + (e.g. a point represented by x,y,z) but the format is not the same for + all columns. + """ + # convert the iterable argument into an iterator so its contents can + # be consumed by islice in case it is a generator + it = iter(iterable) + + for size in sizes: + if size is None: + yield list(it) + return + else: + yield list(islice(it, size)) + + def padded(iterable, fillvalue=None, n=None, next_multiple=False): """Yield the elements from *iterable*, followed by *fillvalue*, such that at least *n* items are emitted. @@ -1060,8 +1507,7 @@ """ it = iter(iterable) if n is None: - for item in chain(it, repeat(fillvalue)): - yield item + yield from chain(it, repeat(fillvalue)) elif n < 1: raise ValueError('n must be at least 1') else: @@ -1075,6 +1521,34 @@ yield fillvalue +def repeat_each(iterable, n=2): + """Repeat each element in *iterable* *n* times. + + >>> list(repeat_each('ABC', 3)) + ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C'] + """ + return chain.from_iterable(map(repeat, iterable, repeat(n))) + + +def repeat_last(iterable, default=None): + """After the *iterable* is exhausted, keep yielding its last element. + + >>> list(islice(repeat_last(range(3)), 5)) + [0, 1, 2, 2, 2] + + If the iterable is empty, yield *default* forever:: + + >>> list(islice(repeat_last(range(0), 42), 5)) + [42, 42, 42, 42, 42] + + """ + item = _marker + for item in iterable: + yield item + final = default if item is _marker else item + yield from repeat(final) + + def distribute(n, iterable): """Distribute the items from *iterable* among *n* smaller iterables.
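A note on the split_* additions above: split_at, split_before, split_after, and split_when all follow the str.split convention for *maxsplit*: -1 (the default) means unlimited splits, 0 yields the whole input as a single list, and once the remaining budget reaches 1 the rest of the iterator is flushed as one final piece. Usage, taken directly from the docstrings above (assuming this vendored copy is importable as more_itertools):

    from more_itertools import split_after, split_when

    list(split_after(range(10), lambda n: n % 3 == 0, maxsplit=2))
    # -> [[0], [1, 2, 3], [4, 5, 6, 7, 8, 9]]  (two splits, remainder flushed)

    list(split_when([1, 2, 3, 3, 2, 5, 2, 4, 2], lambda x, y: x > y))
    # -> [[1, 2, 3, 3], [2, 5], [2, 4], [2]]  (split between decreasing pairs)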
@@ -1138,7 +1612,38 @@ def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None): ) -def zip_offset(*iterables, **kwargs): +def zip_equal(*iterables): + """``zip`` the input *iterables* together, but raise + ``UnequalIterablesError`` if they aren't all the same length. + + >>> it_1 = range(3) + >>> it_2 = iter('abc') + >>> list(zip_equal(it_1, it_2)) + [(0, 'a'), (1, 'b'), (2, 'c')] + + >>> it_1 = range(3) + >>> it_2 = iter('abcd') + >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + more_itertools.more.UnequalIterablesError: Iterables have different + lengths + + """ + if hexversion >= 0x30A00A6: + warnings.warn( + ( + 'zip_equal will be removed in a future version of ' + 'more-itertools. Use the builtin zip function with ' + 'strict=True instead.' + ), + DeprecationWarning, + ) + + return _zip_equal(*iterables) + + +def zip_offset(*iterables, offsets, longest=False, fillvalue=None): """``zip`` the input *iterables* together, but offset the `i`-th iterable by the `i`-th item in *offsets*. @@ -1146,7 +1651,7 @@ def zip_offset(*iterables, **kwargs): [('0', 'b'), ('1', 'c'), ('2', 'd'), ('3', 'e')] This can be used as a lightweight alternative to SciPy or pandas to analyze - data sets in which somes series have a lead or lag relationship. + data sets in which some series have a lead or lag relationship. By default, the sequence will end when the shortest iterable is exhausted. To continue until the longest iterable is exhausted, set *longest* to @@ -1159,10 +1664,6 @@ def zip_offset(*iterables, **kwargs): sequence. Specify *fillvalue* to use some other value. """ - offsets = kwargs['offsets'] - longest = kwargs.get('longest', False) - fillvalue = kwargs.get('fillvalue', None) - if len(iterables) != len(offsets): raise ValueError("Number of iterables and offsets didn't match") @@ -1181,7 +1682,7 @@ def zip_offset(*iterables, **kwargs): return zip(*staggered) -def sort_together(iterables, key_list=(0,), reverse=False): +def sort_together(iterables, key_list=(0,), key=None, reverse=False): """Return the input iterables sorted together, with *key_list* as the priority for sorting. All iterables are trimmed to the length of the shortest one. @@ -1197,21 +1698,103 @@ def sort_together(iterables, key_list=(0,), reverse=False): [(1, 2, 3, 4), ('d', 'c', 'b', 'a')] Set a different key list to sort according to another iterable. - Specifying mutliple keys dictates how ties are broken:: + Specifying multiple keys dictates how ties are broken:: >>> iterables = [(3, 1, 2), (0, 1, 0), ('c', 'b', 'a')] >>> sort_together(iterables, key_list=(1, 2)) [(2, 3, 1), (0, 0, 1), ('a', 'c', 'b')] + To sort by a function of the elements of the iterable, pass a *key* + function. Its arguments are the elements of the iterables corresponding to + the key list:: + + >>> names = ('a', 'b', 'c') + >>> lengths = (1, 2, 3) + >>> widths = (5, 2, 1) + >>> def area(length, width): + ... return length * width + >>> sort_together([names, lengths, widths], key_list=(1, 2), key=area) + [('c', 'b', 'a'), (3, 2, 1), (1, 2, 5)] + Set *reverse* to ``True`` to sort in descending order. 
>>> sort_together([(1, 2, 3), ('c', 'b', 'a')], reverse=True) [(3, 2, 1), ('a', 'b', 'c')] """ - return list(zip(*sorted(zip(*iterables), - key=itemgetter(*key_list), - reverse=reverse))) + if key is None: + # if there is no key function, the key argument to sorted is an + # itemgetter + key_argument = itemgetter(*key_list) + else: + # if there is a key function, call it with the items at the offsets + # specified by the key function as arguments + key_list = list(key_list) + if len(key_list) == 1: + # if key_list contains a single item, pass the item at that offset + # as the only argument to the key function + key_offset = key_list[0] + key_argument = lambda zipped_items: key(zipped_items[key_offset]) + else: + # if key_list contains multiple items, use itemgetter to return a + # tuple of items, which we pass as *args to the key function + get_key_items = itemgetter(*key_list) + key_argument = lambda zipped_items: key( + *get_key_items(zipped_items) + ) + + return list( + zip(*sorted(zip(*iterables), key=key_argument, reverse=reverse)) + ) + + +def unzip(iterable): + """The inverse of :func:`zip`, this function disaggregates the elements + of the zipped *iterable*. + + The ``i``-th iterable contains the ``i``-th element from each element + of the zipped iterable. The first element is used to determine the + length of the remaining elements. + + >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)] + >>> letters, numbers = unzip(iterable) + >>> list(letters) + ['a', 'b', 'c', 'd'] + >>> list(numbers) + [1, 2, 3, 4] + + This is similar to using ``zip(*iterable)``, but it avoids reading + *iterable* into memory. Note, however, that this function uses + :func:`itertools.tee` and thus may require significant storage. + + """ + head, iterable = spy(iter(iterable)) + if not head: + # empty iterable, e.g. 
zip([], [], []) + return () + # spy returns a one-length iterable as head + head = head[0] + iterables = tee(iterable, len(head)) + + def itemgetter(i): + def getter(obj): + try: + return obj[i] + except IndexError: + # basically if we have an iterable like + # iter([(1, 2, 3), (4, 5), (6,)]) + # the second unzipped iterable would fail at the third tuple + # since it would try to access tup[1] + # same with the third unzipped iterable and the second tuple + # to support these "improperly zipped" iterables, + # we create a custom itemgetter + # which just stops the unzipped iterables + # at first length mismatch + raise StopIteration + + return getter + + return tuple(map(itemgetter(i), it) for i, it in enumerate(iterables)) def divide(n, iterable): @@ -1246,19 +1829,26 @@ def divide(n, iterable): if n < 1: raise ValueError('n must be at least 1') - seq = tuple(iterable) + try: + iterable[:0] + except TypeError: + seq = tuple(iterable) + else: + seq = iterable + q, r = divmod(len(seq), n) ret = [] - for i in range(n): - start = (i * q) + (i if i < r else r) - stop = ((i + 1) * q) + (i + 1 if i + 1 < r else r) + stop = 0 + for i in range(1, n + 1): + start = stop + stop += q + 1 if i <= r else q ret.append(iter(seq[start:stop])) return ret -def always_iterable(obj, base_type=(text_type, binary_type)): +def always_iterable(obj, base_type=(str, bytes)): """If *obj* is iterable, return an iterator over its items:: >>> obj = (1, 2, 3) @@ -1350,21 +1940,23 @@ def adjacent(predicate, iterable, distance=1): return zip(adjacent_to_selected, i2) -def groupby_transform(iterable, keyfunc=None, valuefunc=None): - """An extension of :func:`itertools.groupby` that transforms the values of - *iterable* after grouping them. - *keyfunc* is a function used to compute a grouping key for each item. - *valuefunc* is a function for transforming the items after grouping. +def groupby_transform(iterable, keyfunc=None, valuefunc=None, reducefunc=None): + """An extension of :func:`itertools.groupby` that can apply transformations + to the grouped data. - >>> iterable = 'AaaABbBCcA' - >>> keyfunc = lambda x: x.upper() - >>> valuefunc = lambda x: x.lower() - >>> grouper = groupby_transform(iterable, keyfunc, valuefunc) - >>> [(k, ''.join(g)) for k, g in grouper] - [('A', 'aaaa'), ('B', 'bbb'), ('C', 'cc'), ('A', 'a')] + * *keyfunc* is a function computing a key value for each item in *iterable* + * *valuefunc* is a function that transforms the individual items from + *iterable* after grouping + * *reducefunc* is a function that transforms each group of items - *keyfunc* and *valuefunc* default to identity functions if they are not - specified. + >>> iterable = 'aAAbBBcCC' + >>> keyfunc = lambda k: k.upper() + >>> valuefunc = lambda v: v.lower() + >>> reducefunc = lambda g: ''.join(g) + >>> list(groupby_transform(iterable, keyfunc, valuefunc, reducefunc)) + [('A', 'aaa'), ('B', 'bbb'), ('C', 'ccc')] + + Each optional argument defaults to an identity function if not specified. :func:`groupby_transform` is useful when grouping elements of an iterable using a separate iterable as the key. To do this, :func:`zip` the iterables @@ -1384,11 +1976,16 @@ def groupby_transform(iterable, keyfunc=None, valuefunc=None): duplicate groups, you should sort the iterable by the key function. 
""" - valuefunc = (lambda x: x) if valuefunc is None else valuefunc - return ((k, map(valuefunc, g)) for k, g in groupby(iterable, keyfunc)) + ret = groupby(iterable, keyfunc) + if valuefunc: + ret = ((k, map(valuefunc, g)) for k, g in ret) + if reducefunc: + ret = ((k, reducefunc(g)) for k, g in ret) + + return ret -def numeric_range(*args): +class numeric_range(abc.Sequence, abc.Hashable): """An extension of the built-in ``range()`` function whose arguments can be any orderable numeric type. @@ -1425,28 +2022,184 @@ def numeric_range(*args): Be aware of the limitations of floating point numbers; the representation of the yielded numbers may be surprising. - """ - argc = len(args) - if argc == 1: - stop, = args - start = type(stop)(0) - step = 1 - elif argc == 2: - start, stop = args - step = 1 - elif argc == 3: - start, stop, step = args - else: - err_msg = 'numeric_range takes at most 3 arguments, got {}' - raise TypeError(err_msg.format(argc)) + ``datetime.datetime`` objects can be used for *start* and *stop*, if *step* + is a ``datetime.timedelta`` object: - values = (start + (step * n) for n in count()) - if step > 0: - return takewhile(partial(gt, stop), values) - elif step < 0: - return takewhile(partial(lt, stop), values) - else: - raise ValueError('numeric_range arg 3 must not be zero') + >>> import datetime + >>> start = datetime.datetime(2019, 1, 1) + >>> stop = datetime.datetime(2019, 1, 3) + >>> step = datetime.timedelta(days=1) + >>> items = iter(numeric_range(start, stop, step)) + >>> next(items) + datetime.datetime(2019, 1, 1, 0, 0) + >>> next(items) + datetime.datetime(2019, 1, 2, 0, 0) + + """ + + _EMPTY_HASH = hash(range(0, 0)) + + def __init__(self, *args): + argc = len(args) + if argc == 1: + (self._stop,) = args + self._start = type(self._stop)(0) + self._step = type(self._stop - self._start)(1) + elif argc == 2: + self._start, self._stop = args + self._step = type(self._stop - self._start)(1) + elif argc == 3: + self._start, self._stop, self._step = args + elif argc == 0: + raise TypeError( + 'numeric_range expected at least ' + '1 argument, got {}'.format(argc) + ) + else: + raise TypeError( + 'numeric_range expected at most ' + '3 arguments, got {}'.format(argc) + ) + + self._zero = type(self._step)(0) + if self._step == self._zero: + raise ValueError('numeric_range() arg 3 must not be zero') + self._growing = self._step > self._zero + self._init_len() + + def __bool__(self): + if self._growing: + return self._start < self._stop + else: + return self._start > self._stop + + def __contains__(self, elem): + if self._growing: + if self._start <= elem < self._stop: + return (elem - self._start) % self._step == self._zero + else: + if self._start >= elem > self._stop: + return (self._start - elem) % (-self._step) == self._zero + + return False + + def __eq__(self, other): + if isinstance(other, numeric_range): + empty_self = not bool(self) + empty_other = not bool(other) + if empty_self or empty_other: + return empty_self and empty_other # True if both empty + else: + return ( + self._start == other._start + and self._step == other._step + and self._get_by_index(-1) == other._get_by_index(-1) + ) + else: + return False + + def __getitem__(self, key): + if isinstance(key, int): + return self._get_by_index(key) + elif isinstance(key, slice): + step = self._step if key.step is None else key.step * self._step + + if key.start is None or key.start <= -self._len: + start = self._start + elif key.start >= self._len: + start = self._stop + else: # -self._len < key.start < 
self._len + start = self._get_by_index(key.start) + + if key.stop is None or key.stop >= self._len: + stop = self._stop + elif key.stop <= -self._len: + stop = self._start + else: # -self._len < key.stop < self._len + stop = self._get_by_index(key.stop) + + return numeric_range(start, stop, step) + else: + raise TypeError( + 'numeric range indices must be ' + 'integers or slices, not {}'.format(type(key).__name__) + ) + + def __hash__(self): + if self: + return hash((self._start, self._get_by_index(-1), self._step)) + else: + return self._EMPTY_HASH + + def __iter__(self): + values = (self._start + (n * self._step) for n in count()) + if self._growing: + return takewhile(partial(gt, self._stop), values) + else: + return takewhile(partial(lt, self._stop), values) + + def __len__(self): + return self._len + + def _init_len(self): + if self._growing: + start = self._start + stop = self._stop + step = self._step + else: + start = self._stop + stop = self._start + step = -self._step + distance = stop - start + if distance <= self._zero: + self._len = 0 + else: # distance > 0 and step > 0: regular euclidean division + q, r = divmod(distance, step) + self._len = int(q) + int(r != self._zero) + + def __reduce__(self): + return numeric_range, (self._start, self._stop, self._step) + + def __repr__(self): + if self._step == 1: + return "numeric_range({}, {})".format( + repr(self._start), repr(self._stop) + ) + else: + return "numeric_range({}, {}, {})".format( + repr(self._start), repr(self._stop), repr(self._step) + ) + + def __reversed__(self): + return iter( + numeric_range( + self._get_by_index(-1), self._start - self._step, -self._step + ) + ) + + def count(self, value): + return int(value in self) + + def index(self, value): + if self._growing: + if self._start <= value < self._stop: + q, r = divmod(value - self._start, self._step) + if r == self._zero: + return int(q) + else: + if self._start >= value > self._stop: + q, r = divmod(self._start - value, -self._step) + if r == self._zero: + return int(q) + + raise ValueError("{} is not in numeric range".format(value)) + + def _get_by_index(self, i): + if i < 0: + i += self._len + if i < 0 or i >= self._len: + raise IndexError("numeric range object index out of range") + return self._start + i * self._step def count_cycle(iterable, n=None): @@ -1465,6 +2218,43 @@ def count_cycle(iterable, n=None): return ((i, item) for i in counter for item in iterable) +def mark_ends(iterable): + """Yield 3-tuples of the form ``(is_first, is_last, item)``. + + >>> list(mark_ends('ABC')) + [(True, False, 'A'), (False, False, 'B'), (False, True, 'C')] + + Use this when looping over an iterable to take special action on its first + and/or last items: + + >>> iterable = ['Header', 100, 200, 'Footer'] + >>> total = 0 + >>> for is_first, is_last, item in mark_ends(iterable): + ... if is_first: + ... continue # Skip the header + ... if is_last: + ... continue # Skip the footer + ... total += item + >>> print(total) + 300 + """ + it = iter(iterable) + + try: + b = next(it) + except StopIteration: + return + + try: + for i in count(): + a = b + b = next(it) + yield i == 0, False, a + + except StopIteration: + yield i == 0, True, a + + def locate(iterable, pred=bool, window_size=None): """Yield the index of each item in *iterable* for which *pred* returns ``True``. 
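A note on ``_init_len`` above: the length is the ceiling of ``(stop - start) / step``, computed with ``divmod`` so that it works for any orderable numeric type, not just ``int``. A minimal sketch of the identity (assuming the package is importable as ``more_itertools``):

    from decimal import Decimal
    from more_itertools import numeric_range

    nr = numeric_range(Decimal('1.0'), Decimal('2.0'), Decimal('0.3'))
    # ceil((2.0 - 1.0) / 0.3) == 4 elements: 1.0, 1.3, 1.6, 1.9
    assert len(nr) == 4 and nr[-1] == Decimal('1.9')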
@@ -1513,6 +2303,16 @@ def locate(iterable, pred=bool, window_size=None): return compress(count(), starmap(pred, it)) +def longest_common_prefix(iterables): + """Yield elements of the longest common prefix amongst given *iterables*. + + >>> ''.join(longest_common_prefix(['abcd', 'abc', 'abf'])) + 'ab' + + """ + return (c[0] for c in takewhile(all_equal, zip(*iterables))) + + def lstrip(iterable, pred): """Yield the items from *iterable*, but strip any from the beginning for which *pred* returns ``True``. @@ -1547,13 +2347,13 @@ def rstrip(iterable, pred): """ cache = [] cache_append = cache.append + cache_clear = cache.clear for x in iterable: if pred(x): cache_append(x) else: - for y in cache: - yield y - del cache[:] + yield from cache + cache_clear() yield x @@ -1574,7 +2374,7 @@ def strip(iterable, pred): return rstrip(lstrip(iterable, pred), pred) -def islice_extended(iterable, *args): +class islice_extended: """An extension of :func:`itertools.islice` that supports negative values for *stop*, *start*, and *step*. @@ -1591,20 +2391,46 @@ >>> list(islice_extended(count(), 110, 99, -2)) [110, 108, 106, 104, 102, 100] + You can also use slice notation directly: + + >>> iterable = map(str, count()) + >>> it = islice_extended(iterable)[10:20:2] + >>> list(it) + ['10', '12', '14', '16', '18'] + """ - s = slice(*args) + + def __init__(self, iterable, *args): + it = iter(iterable) + if args: + self._iterable = _islice_helper(it, slice(*args)) + else: + self._iterable = it + + def __iter__(self): + return self + + def __next__(self): + return next(self._iterable) + + def __getitem__(self, key): + if isinstance(key, slice): + return islice_extended(_islice_helper(self._iterable, key)) + + raise TypeError('islice_extended.__getitem__ argument must be a slice') + + +def _islice_helper(it, s): start = s.start stop = s.stop if s.step == 0: raise ValueError('step argument must be a non-zero integer or None.') step = s.step or 1 - it = iter(iterable) - if step > 0: start = 0 if (start is None) else start - if (start < 0): + if start < 0: # Consume all but the last -start items cache = deque(enumerate(it, 1), maxlen=-start) len_iter = cache[-1][0] if cache else 0 @@ -1642,8 +2468,7 @@ cache.append(item) else: # When both start and stop are positive we have the normal case - for item in islice(it, start, stop, step): - yield item + yield from islice(it, start, stop, step) else: start = -1 if (start is None) else start @@ -1688,8 +2513,7 @@ cache = list(islice(it, n)) - for item in cache[i::step]: - yield item + yield from cache[i::step] def always_reversible(iterable): @@ -1740,6 +2564,17 @@ def consecutive_groups(iterable, ordering=lambda x: x): ['i'] ['l', 'm', 'n', 'o', 'p'] + Each group of consecutive items is an iterator that shares its source with + *iterable*. When an output group is advanced, the previous group is + no longer available unless its elements are copied (e.g., into a ``list``). + + >>> iterable = [1, 2, 11, 12, 21, 22] + >>> saved_groups = [] + >>> for group in consecutive_groups(iterable): + ...
saved_groups.append(list(group)) # Copy group elements + >>> saved_groups + [[1, 2], [11, 12], [21, 22]] + """ for k, g in groupby( enumerate(iterable), key=lambda x: x[0] - ordering(x[1]) @@ -1747,48 +2582,52 @@ def consecutive_groups(iterable, ordering=lambda x: x): yield map(itemgetter(1), g) -def difference(iterable, func=sub): - """By default, compute the first difference of *iterable* using - :func:`operator.sub`. +def difference(iterable, func=sub, *, initial=None): + """This function is the inverse of :func:`itertools.accumulate`. By default + it will compute the first difference of *iterable* using + :func:`operator.sub`: - >>> iterable = [0, 1, 3, 6, 10] + >>> from itertools import accumulate + >>> iterable = accumulate([0, 1, 2, 3, 4]) # produces 0, 1, 3, 6, 10 >>> list(difference(iterable)) [0, 1, 2, 3, 4] - This is the opposite of :func:`accumulate`'s default behavior: - - >>> from more_itertools import accumulate - >>> iterable = [0, 1, 2, 3, 4] - >>> list(accumulate(iterable)) - [0, 1, 3, 6, 10] - >>> list(difference(accumulate(iterable))) - [0, 1, 2, 3, 4] - - By default *func* is :func:`operator.sub`, but other functions can be + *func* defaults to :func:`operator.sub`, but other functions can be specified. They will be applied as follows:: A, B, C, D, ... --> A, func(B, A), func(C, B), func(D, C), ... For example, to do progressive division: - >>> iterable = [1, 2, 6, 24, 120] # Factorial sequence + >>> iterable = [1, 2, 6, 24, 120] >>> func = lambda x, y: x // y >>> list(difference(iterable, func)) [1, 2, 3, 4, 5] + If the *initial* keyword is set, the first element will be skipped when + computing successive differences. + + >>> it = [10, 11, 13, 16] # from accumulate([1, 2, 3], initial=10) + >>> list(difference(it, initial=10)) + [1, 2, 3] + """ a, b = tee(iterable) try: - item = next(b) + first = [next(b)] except StopIteration: return iter([]) - return chain([item], map(lambda x: func(x[1], x[0]), zip(a, b))) + + if initial is not None: + first = [] + + return chain(first, map(func, b, a)) class SequenceView(Sequence): """Return a read-only view of the sequence object *target*. - :class:`SequenceView` objects are analagous to Python's built-in + :class:`SequenceView` objects are analogous to Python's built-in "dictionary view" types. They provide a dynamic view of a sequence's items, meaning that when the sequence updates, so does the view. @@ -1814,6 +2653,7 @@ class SequenceView(Sequence): require (much) extra storage. """ + def __init__(self, target): if not isinstance(target, Sequence): raise TypeError @@ -1829,7 +2669,7 @@ class SequenceView(Sequence): return '{}({})'.format(self.__class__.__name__, repr(self._target)) -class seekable(object): +class seekable: """Wrap an iterator to allow for seeking backward and forward. This progressively caches the items in the source iterable so they can be re-visited. @@ -1862,8 +2702,26 @@ class seekable(object): >>> next(it), next(it), next(it) ('0', '1', '2') - The cache grows as the source iterable progresses, so beware of wrapping - very large or infinite iterables. + Call :meth:`peek` to look ahead one item without advancing the iterator: + + >>> it = seekable('1234') + >>> it.peek() + '1' + >>> list(it) + ['1', '2', '3', '4'] + >>> it.peek(default='empty') + 'empty' + + Before the iterator is at its end, calling :func:`bool` on it will return + ``True``. 
After it will return ``False``: + + >>> it = seekable('5678') + >>> bool(it) + True + >>> list(it) + ['5', '6', '7', '8'] + >>> bool(it) + False You may view the contents of the cache with the :meth:`elements` method. That returns a :class:`SequenceView`, a view that updates automatically: @@ -1879,11 +2737,30 @@ class seekable(object): >>> elements SequenceView(['0', '1', '2', '3']) + By default, the cache grows as the source iterable progresses, so beware of + wrapping very large or infinite iterables. Supply *maxlen* to limit the + size of the cache (this of course limits how far back you can seek). + + >>> from itertools import count + >>> it = seekable((str(n) for n in count()), maxlen=2) + >>> next(it), next(it), next(it), next(it) + ('0', '1', '2', '3') + >>> list(it.elements()) + ['2', '3'] + >>> it.seek(0) + >>> next(it), next(it), next(it), next(it) + ('2', '3', '4', '5') + >>> next(it) + '6' + """ - def __init__(self, iterable): + def __init__(self, iterable, maxlen=None): self._source = iter(iterable) - self._cache = [] + if maxlen is None: + self._cache = [] + else: + self._cache = deque([], maxlen) self._index = None def __iter__(self): @@ -1903,7 +2780,24 @@ class seekable(object): self._cache.append(item) return item - next = __next__ + def __bool__(self): + try: + self.peek() + except StopIteration: + return False + return True + + def peek(self, default=_marker): + try: + peeked = next(self) + except StopIteration: + if default is _marker: + raise + return default + if self._index is None: + self._index = len(self._cache) + self._index -= 1 + return peeked def elements(self): return SequenceView(self._cache) @@ -1915,7 +2809,7 @@ class seekable(object): consume(self, remainder) -class run_length(object): +class run_length: """ :func:`run_length.encode` compresses an iterable with run-length encoding. It yields groups of repeated items with the count of how many times they @@ -1965,8 +2859,8 @@ def exactly_n(iterable, n, predicate=bool): def circular_shifts(iterable): """Return a list of circular shifts of *iterable*. - >>> circular_shifts(range(4)) - [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)] + >>> circular_shifts(range(4)) + [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)] """ lst = list(iterable) return take(len(lst), windowed(cycle(lst), len(lst))) @@ -2140,9 +3034,7 @@ def rlocate(iterable, pred=bool, window_size=None): if window_size is None: try: len_iter = len(iterable) - return ( - len_iter - i - 1 for i in locate(reversed(iterable), pred) - ) + return (len_iter - i - 1 for i in locate(reversed(iterable), pred)) except TypeError: pass @@ -2200,8 +3092,7 @@ def replace(iterable, pred, substitutes, count=None, window_size=1): if pred(*w): if (count is None) or (n < count): n += 1 - for s in substitutes: - yield s + yield from substitutes consume(windows, window_size - 1) continue @@ -2209,3 +3100,1248 @@ def replace(iterable, pred, substitutes, count=None, window_size=1): # yield the first item from the window. if w and (w[0] is not _marker): yield w[0] + + +def partitions(iterable): + """Yield all possible order-preserving partitions of *iterable*. + + >>> iterable = 'abc' + >>> for part in partitions(iterable): + ... print([''.join(p) for p in part]) + ['abc'] + ['a', 'bc'] + ['ab', 'c'] + ['a', 'b', 'c'] + + This is unrelated to :func:`partition`. 
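+ + A sequence of length ``n`` has ``2 ** (n - 1)`` order-preserving partitions, + one for each subset of the ``n - 1`` split points between adjacent items: + + >>> len(list(partitions('abcd'))) + 8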
+ + """ + sequence = list(iterable) + n = len(sequence) + for i in powerset(range(1, n)): + yield [sequence[i:j] for i, j in zip((0,) + i, i + (n,))] + + +def set_partitions(iterable, k=None): + """ + Yield the set partitions of *iterable* into *k* parts. Set partitions are + not order-preserving. + + >>> iterable = 'abc' + >>> for part in set_partitions(iterable, 2): + ... print([''.join(p) for p in part]) + ['a', 'bc'] + ['ab', 'c'] + ['b', 'ac'] + + + If *k* is not given, every set partition is generated. + + >>> iterable = 'abc' + >>> for part in set_partitions(iterable): + ... print([''.join(p) for p in part]) + ['abc'] + ['a', 'bc'] + ['ab', 'c'] + ['b', 'ac'] + ['a', 'b', 'c'] + + """ + L = list(iterable) + n = len(L) + if k is not None: + if k < 1: + raise ValueError( + "Can't partition in a negative or zero number of groups" + ) + elif k > n: + return + + def set_partitions_helper(L, k): + n = len(L) + if k == 1: + yield [L] + elif n == k: + yield [[s] for s in L] + else: + e, *M = L + for p in set_partitions_helper(M, k - 1): + yield [[e], *p] + for p in set_partitions_helper(M, k): + for i in range(len(p)): + yield p[:i] + [[e] + p[i]] + p[i + 1 :] + + if k is None: + for k in range(1, n + 1): + yield from set_partitions_helper(L, k) + else: + yield from set_partitions_helper(L, k) + + +class time_limited: + """ + Yield items from *iterable* until *limit_seconds* have passed. + If the time limit expires before all items have been yielded, the + ``timed_out`` parameter will be set to ``True``. + + >>> from time import sleep + >>> def generator(): + ... yield 1 + ... yield 2 + ... sleep(0.2) + ... yield 3 + >>> iterable = time_limited(0.1, generator()) + >>> list(iterable) + [1, 2] + >>> iterable.timed_out + True + + Note that the time is checked before each item is yielded, and iteration + stops if the time elapsed is greater than *limit_seconds*. If your time + limit is 1 second, but it takes 2 seconds to generate the first item from + the iterable, the function will run for 2 seconds and not yield anything. + + """ + + def __init__(self, limit_seconds, iterable): + if limit_seconds < 0: + raise ValueError('limit_seconds must be positive') + self.limit_seconds = limit_seconds + self._iterable = iter(iterable) + self._start_time = monotonic() + self.timed_out = False + + def __iter__(self): + return self + + def __next__(self): + item = next(self._iterable) + if monotonic() - self._start_time > self.limit_seconds: + self.timed_out = True + raise StopIteration + + return item + + +def only(iterable, default=None, too_long=None): + """If *iterable* has only one item, return it. + If it has zero items, return *default*. + If it has more than one item, raise the exception given by *too_long*, + which is ``ValueError`` by default. + + >>> only([], default='missing') + 'missing' + >>> only([1]) + 1 + >>> only([1, 2]) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: Expected exactly one item in iterable, but got 1, 2, + and perhaps more.' + >>> only([1, 2], too_long=TypeError) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + TypeError + + Note that :func:`only` attempts to advance *iterable* twice to ensure there + is only one item. See :func:`spy` or :func:`peekable` to check + iterable contents less destructively. 
+ """ + it = iter(iterable) + first_value = next(it, default) + + try: + second_value = next(it) + except StopIteration: + pass + else: + msg = ( + 'Expected exactly one item in iterable, but got {!r}, {!r}, ' + 'and perhaps more.'.format(first_value, second_value) + ) + raise too_long or ValueError(msg) + + return first_value + + +class _IChunk: + def __init__(self, iterable, n): + self._it = islice(iterable, n) + self._cache = deque() + + def fill_cache(self): + self._cache.extend(self._it) + + def __iter__(self): + return self + + def __next__(self): + try: + return next(self._it) + except StopIteration: + if self._cache: + return self._cache.popleft() + else: + raise + + +def ichunked(iterable, n): + """Break *iterable* into sub-iterables with *n* elements each. + :func:`ichunked` is like :func:`chunked`, but it yields iterables + instead of lists. + + If the sub-iterables are read in order, the elements of *iterable* + won't be stored in memory. + If they are read out of order, :func:`itertools.tee` is used to cache + elements as necessary. + + >>> from itertools import count + >>> all_chunks = ichunked(count(), 4) + >>> c_1, c_2, c_3 = next(all_chunks), next(all_chunks), next(all_chunks) + >>> list(c_2) # c_1's elements have been cached; c_3's haven't been + [4, 5, 6, 7] + >>> list(c_1) + [0, 1, 2, 3] + >>> list(c_3) + [8, 9, 10, 11] + + """ + source = peekable(iter(iterable)) + ichunk_marker = object() + while True: + # Check to see whether we're at the end of the source iterable + item = source.peek(ichunk_marker) + if item is ichunk_marker: + return + + chunk = _IChunk(source, n) + yield chunk + + # Advance the source iterable and fill previous chunk's cache + chunk.fill_cache() + + +def iequals(*iterables): + """Return ``True`` if all given *iterables* are equal to each other, + which means that they contain the same elements in the same order. + + The function is useful for comparing iterables of different data types + or iterables that do not support equality checks. + + >>> iequals("abc", ['a', 'b', 'c'], ('a', 'b', 'c'), iter("abc")) + True + + >>> iequals("abc", "acb") + False + + Not to be confused with :func:`all_equals`, which checks whether all + elements of iterable are equal to each other. + + """ + return all(map(all_equal, zip_longest(*iterables, fillvalue=object()))) + + +def distinct_combinations(iterable, r): + """Yield the distinct combinations of *r* items taken from *iterable*. + + >>> list(distinct_combinations([0, 0, 1], 2)) + [(0, 0), (0, 1)] + + Equivalent to ``set(combinations(iterable))``, except duplicates are not + generated and thrown away. For larger input sequences this is much more + efficient. + + """ + if r < 0: + raise ValueError('r must be non-negative') + elif r == 0: + yield () + return + pool = tuple(iterable) + generators = [unique_everseen(enumerate(pool), key=itemgetter(1))] + current_combo = [None] * r + level = 0 + while generators: + try: + cur_idx, p = next(generators[-1]) + except StopIteration: + generators.pop() + level -= 1 + continue + current_combo[level] = p + if level + 1 == r: + yield tuple(current_combo) + else: + generators.append( + unique_everseen( + enumerate(pool[cur_idx + 1 :], cur_idx + 1), + key=itemgetter(1), + ) + ) + level += 1 + + +def filter_except(validator, iterable, *exceptions): + """Yield the items from *iterable* for which the *validator* function does + not raise one of the specified *exceptions*. + + *validator* is called for each item in *iterable*. 
+ It should be a function that accepts one argument and raises an exception + if that item is not valid. + + >>> iterable = ['1', '2', 'three', '4', None] + >>> list(filter_except(int, iterable, ValueError, TypeError)) + ['1', '2', '4'] + + If an exception other than one given by *exceptions* is raised by + *validator*, it is raised like normal. + """ + for item in iterable: + try: + validator(item) + except exceptions: + pass + else: + yield item + + +def map_except(function, iterable, *exceptions): + """Transform each item from *iterable* with *function* and yield the + result, unless *function* raises one of the specified *exceptions*. + + *function* is called to transform each item in *iterable*. + It should accept one argument. + + >>> iterable = ['1', '2', 'three', '4', None] + >>> list(map_except(int, iterable, ValueError, TypeError)) + [1, 2, 4] + + If an exception other than one given by *exceptions* is raised by + *function*, it is raised like normal. + """ + for item in iterable: + try: + yield function(item) + except exceptions: + pass + + +def map_if(iterable, pred, func, func_else=lambda x: x): + """Evaluate each item from *iterable* using *pred*. If the result is + equivalent to ``True``, transform the item with *func* and yield it. + Otherwise, transform the item with *func_else* and yield it. + + *pred*, *func*, and *func_else* should each be functions that accept + one argument. By default, *func_else* is the identity function. + + >>> from math import sqrt + >>> iterable = list(range(-5, 5)) + >>> iterable + [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4] + >>> list(map_if(iterable, lambda x: x > 3, lambda x: 'toobig')) + [-5, -4, -3, -2, -1, 0, 1, 2, 3, 'toobig'] + >>> list(map_if(iterable, lambda x: x >= 0, + ... lambda x: f'{sqrt(x):.2f}', lambda x: None)) + [None, None, None, None, None, '0.00', '1.00', '1.41', '1.73', '2.00'] + """ + for item in iterable: + yield func(item) if pred(item) else func_else(item) + + +def _sample_unweighted(iterable, k): + # Implementation of "Algorithm L" from the 1994 paper by Kim-Hung Li: + # "Reservoir-Sampling Algorithms of Time Complexity O(n(1+log(N/n)))". + + # Fill up the reservoir (collection of samples) with the first `k` samples + reservoir = take(k, iterable) + + # Generate random number that's the largest in a sample of k U(0,1) numbers + # Largest order statistic: https://en.wikipedia.org/wiki/Order_statistic + W = exp(log(random()) / k) + + # The number of elements to skip before changing the reservoir is a random + # number with a geometric distribution. Sample it using random() and logs. + next_index = k + floor(log(random()) / log(1 - W)) + + for index, element in enumerate(iterable, k): + + if index == next_index: + reservoir[randrange(k)] = element + # The new W is the largest in a sample of k U(0, `old_W`) numbers + W *= exp(log(random()) / k) + next_index += floor(log(random()) / log(1 - W)) + 1 + + return reservoir + + +def _sample_weighted(iterable, k, weights): + # Implementation of "A-ExpJ" from the 2006 paper by Efraimidis et al. : + # "Weighted random sampling with a reservoir". + + # Log-transform for numerical stability for weights that are small/large + weight_keys = (log(random()) / weight for weight in weights) + + # Fill up the reservoir (collection of samples) with the first `k` + # weight-keys and elements, then heapify the list. + reservoir = take(k, zip(weight_keys, iterable)) + heapify(reservoir) + + # The number of jumps before changing the reservoir is a random variable + # with an exponential distribution. 
Sample it using random() and logs. + smallest_weight_key, _ = reservoir[0] + weights_to_skip = log(random()) / smallest_weight_key + + for weight, element in zip(weights, iterable): + if weight >= weights_to_skip: + # The notation here is consistent with the paper, but we store + # the weight-keys in log-space for better numerical stability. + smallest_weight_key, _ = reservoir[0] + t_w = exp(weight * smallest_weight_key) + r_2 = uniform(t_w, 1) # generate U(t_w, 1) + weight_key = log(r_2) / weight + heapreplace(reservoir, (weight_key, element)) + smallest_weight_key, _ = reservoir[0] + weights_to_skip = log(random()) / smallest_weight_key + else: + weights_to_skip -= weight + + # Equivalent to [element for weight_key, element in sorted(reservoir)] + return [heappop(reservoir)[1] for _ in range(k)] + + +def sample(iterable, k, weights=None): + """Return a *k*-length list of elements chosen (without replacement) + from the *iterable*. Like :func:`random.sample`, but works on iterables + of unknown length. + + >>> iterable = range(100) + >>> sample(iterable, 5) # doctest: +SKIP + [81, 60, 96, 16, 4] + + An iterable with *weights* may also be given: + + >>> iterable = range(100) + >>> weights = (i * i + 1 for i in range(100)) + >>> sampled = sample(iterable, 5, weights=weights) # doctest: +SKIP + [79, 67, 74, 66, 78] + + The algorithm can also be used to generate weighted random permutations. + The relative weight of each item determines the probability that it + appears late in the permutation. + + >>> data = "abcdefgh" + >>> weights = range(1, len(data) + 1) + >>> sample(data, k=len(data), weights=weights) # doctest: +SKIP + ['c', 'a', 'b', 'e', 'g', 'd', 'h', 'f'] + """ + if k == 0: + return [] + + iterable = iter(iterable) + if weights is None: + return _sample_unweighted(iterable, k) + else: + weights = iter(weights) + return _sample_weighted(iterable, k, weights) + + +def is_sorted(iterable, key=None, reverse=False, strict=False): + """Returns ``True`` if the items of iterable are in sorted order, and + ``False`` otherwise. *key* and *reverse* have the same meaning that they do + in the built-in :func:`sorted` function. + + >>> is_sorted(['1', '2', '3', '4', '5'], key=int) + True + >>> is_sorted([5, 4, 3, 1, 2], reverse=True) + False + + If *strict*, tests for strict sorting, that is, returns ``False`` if equal + elements are found: + + >>> is_sorted([1, 2, 2]) + True + >>> is_sorted([1, 2, 2], strict=True) + False + + The function returns ``False`` after encountering the first out-of-order + item. If there are no out-of-order items, the iterable is exhausted. + """ + + compare = (le if reverse else ge) if strict else (lt if reverse else gt) + it = iterable if key is None else map(key, iterable) + return not any(starmap(compare, pairwise(it))) + + +class AbortThread(BaseException): + pass + + +class callback_iter: + """Convert a function that uses callbacks to an iterator. + + Let *func* be a function that takes a `callback` keyword argument. + For example: + + >>> def func(callback=None): + ... for i, c in [(1, 'a'), (2, 'b'), (3, 'c')]: + ... if callback: + ... callback(i, c) + ... return 4 + + + Use ``with callback_iter(func)`` to get an iterator over the parameters + that are delivered to the callback. + + >>> with callback_iter(func) as it: + ... for args, kwargs in it: + ... print(args) + (1, 'a') + (2, 'b') + (3, 'c') + + The function will be called in a background thread. The ``done`` property + indicates whether it has completed execution. 
+ + >>> it.done + True + + If it completes successfully, its return value will be available + in the ``result`` property. + + >>> it.result + 4 + + Notes: + + * If the function uses some keyword argument besides ``callback``, supply + *callback_kwd*. + * If it finished executing, but raised an exception, accessing the + ``result`` property will raise the same exception. + * If it hasn't finished executing, accessing the ``result`` + property from within the ``with`` block will raise ``RuntimeError``. + * If it hasn't finished executing, accessing the ``result`` property from + outside the ``with`` block will raise a + ``more_itertools.AbortThread`` exception. + * Provide *wait_seconds* to adjust how frequently it is polled for + output. + + """ + + def __init__(self, func, callback_kwd='callback', wait_seconds=0.1): + self._func = func + self._callback_kwd = callback_kwd + self._aborted = False + self._future = None + self._wait_seconds = wait_seconds + # Lazily import concurrent.futures + self._executor = __import__( + 'concurrent.futures' + ).futures.ThreadPoolExecutor(max_workers=1) + self._iterator = self._reader() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._aborted = True + self._executor.shutdown() + + def __iter__(self): + return self + + def __next__(self): + return next(self._iterator) + + @property + def done(self): + if self._future is None: + return False + return self._future.done() + + @property + def result(self): + if not self.done: + raise RuntimeError('Function has not yet completed') + + return self._future.result() + + def _reader(self): + q = Queue() + + def callback(*args, **kwargs): + if self._aborted: + raise AbortThread('canceled by user') + + q.put((args, kwargs)) + + self._future = self._executor.submit( + self._func, **{self._callback_kwd: callback} + ) + + while True: + try: + item = q.get(timeout=self._wait_seconds) + except Empty: + pass + else: + q.task_done() + yield item + + if self._future.done(): + break + + remaining = [] + while True: + try: + item = q.get_nowait() + except Empty: + break + else: + q.task_done() + remaining.append(item) + q.join() + yield from remaining + + +def windowed_complete(iterable, n): + """ + Yield ``(beginning, middle, end)`` tuples, where: + + * Each ``middle`` has *n* items from *iterable* + * Each ``beginning`` has the items before the ones in ``middle`` + * Each ``end`` has the items after the ones in ``middle`` + + >>> iterable = range(7) + >>> n = 3 + >>> for beginning, middle, end in windowed_complete(iterable, n): + ... print(beginning, middle, end) + () (0, 1, 2) (3, 4, 5, 6) + (0,) (1, 2, 3) (4, 5, 6) + (0, 1) (2, 3, 4) (5, 6) + (0, 1, 2) (3, 4, 5) (6,) + (0, 1, 2, 3) (4, 5, 6) () + + Note that *n* must be at least 0 and at most equal to the length of + *iterable*. + + This function will exhaust the iterable and may require significant + storage. + """ + if n < 0: + raise ValueError('n must be >= 0') + + seq = tuple(iterable) + size = len(seq) + + if n > size: + raise ValueError('n must be <= len(seq)') + + for i in range(size - n + 1): + beginning = seq[:i] + middle = seq[i : i + n] + end = seq[i + n :] + yield beginning, middle, end + + +def all_unique(iterable, key=None): + """ + Returns ``True`` if all the elements of *iterable* are unique (no two + elements are equal). + + >>> all_unique('ABCB') + False + + If a *key* function is specified, it will be used to make comparisons.
+ + >>> all_unique('ABCb') + True + >>> all_unique('ABCb', str.lower) + False + + The function returns as soon as the first non-unique element is + encountered. Iterables with a mix of hashable and unhashable items can + be used, but the function will be slower for unhashable items. + """ + seenset = set() + seenset_add = seenset.add + seenlist = [] + seenlist_add = seenlist.append + for element in map(key, iterable) if key else iterable: + try: + if element in seenset: + return False + seenset_add(element) + except TypeError: + if element in seenlist: + return False + seenlist_add(element) + return True + + +def nth_product(index, *args): + """Equivalent to ``list(product(*args))[index]``. + + The products of *args* can be ordered lexicographically. + :func:`nth_product` computes the product at sort position *index* without + computing the previous products. + + >>> nth_product(8, range(2), range(2), range(2), range(2)) + (1, 0, 0, 0) + + ``IndexError`` will be raised if the given *index* is invalid. + """ + pools = list(map(tuple, reversed(args))) + ns = list(map(len, pools)) + + c = reduce(mul, ns) + + if index < 0: + index += c + + if not 0 <= index < c: + raise IndexError + + result = [] + for pool, n in zip(pools, ns): + result.append(pool[index % n]) + index //= n + + return tuple(reversed(result)) + + +def nth_permutation(iterable, r, index): + """Equivalent to ``list(permutations(iterable, r))[index]`` + + The subsequences of *iterable* that are of length *r* where order is + important can be ordered lexicographically. :func:`nth_permutation` + computes the subsequence at sort position *index* directly, without + computing the previous subsequences. + + >>> nth_permutation('ghijk', 2, 5) + ('h', 'i') + + ``ValueError`` will be raised if *r* is negative or greater than the length + of *iterable*. + ``IndexError`` will be raised if the given *index* is invalid. + """ + pool = list(iterable) + n = len(pool) + + if r is None or r == n: + r, c = n, factorial(n) + elif not 0 <= r < n: + raise ValueError + else: + c = factorial(n) // factorial(n - r) + + if index < 0: + index += c + + if not 0 <= index < c: + raise IndexError + + if c == 0: + return tuple() + + result = [0] * r + q = index * factorial(n) // c if r < n else index + for d in range(1, n + 1): + q, i = divmod(q, d) + if 0 <= n - d < r: + result[n - d] = i + if q == 0: + break + + return tuple(map(pool.pop, result)) + + +def value_chain(*args): + """Yield all arguments passed to the function in the same order in which + they were passed. If an argument itself is iterable then iterate over its + values. + + >>> list(value_chain(1, 2, 3, [4, 5, 6])) + [1, 2, 3, 4, 5, 6] + + Binary and text strings are not considered iterable and are emitted + as-is: + + >>> list(value_chain('12', '34', ['56', '78'])) + ['12', '34', '56', '78'] + + + Multiple levels of nesting are not flattened. + + """ + for value in args: + if isinstance(value, (str, bytes)): + yield value + continue + try: + yield from value + except TypeError: + yield value + + +def product_index(element, *args): + """Equivalent to ``list(product(*args)).index(element)`` + + The products of *args* can be ordered lexicographically. + :func:`product_index` computes the first index of *element* without + computing the previous products. + + >>> product_index([8, 2], range(10), range(5)) + 42 + + ``ValueError`` will be raised if the given *element* isn't in the product + of *args*.
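+ + The index is computed in mixed radix: each coordinate is weighted by the + sizes of all the pools after it, so in the example above the result is + ``8 * 5 + 2 == 42``: + + >>> product_index([1, 1], range(3), range(3)) + 4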
+ """ + index = 0 + + for x, pool in zip_longest(element, args, fillvalue=_marker): + if x is _marker or pool is _marker: + raise ValueError('element is not a product of args') + + pool = tuple(pool) + index = index * len(pool) + pool.index(x) + + return index + + +def combination_index(element, iterable): + """Equivalent to ``list(combinations(iterable, r)).index(element)`` + + The subsequences of *iterable* that are of length *r* can be ordered + lexicographically. :func:`combination_index` computes the index of the + first *element*, without computing the previous combinations. + + >>> combination_index('adf', 'abcdefg') + 10 + + ``ValueError`` will be raised if the given *element* isn't one of the + combinations of *iterable*. + """ + element = enumerate(element) + k, y = next(element, (None, None)) + if k is None: + return 0 + + indexes = [] + pool = enumerate(iterable) + for n, x in pool: + if x == y: + indexes.append(n) + tmp, y = next(element, (None, None)) + if tmp is None: + break + else: + k = tmp + else: + raise ValueError('element is not a combination of iterable') + + n, _ = last(pool, default=(n, None)) + + # Python versions below 3.8 don't have math.comb + index = 1 + for i, j in enumerate(reversed(indexes), start=1): + j = n - j + if i <= j: + index += factorial(j) // (factorial(i) * factorial(j - i)) + + return factorial(n + 1) // (factorial(k + 1) * factorial(n - k)) - index + + +def permutation_index(element, iterable): + """Equivalent to ``list(permutations(iterable, r)).index(element)``` + + The subsequences of *iterable* that are of length *r* where order is + important can be ordered lexicographically. :func:`permutation_index` + computes the index of the first *element* directly, without computing + the previous permutations. + + >>> permutation_index([1, 3, 2], range(5)) + 19 + + ``ValueError`` will be raised if the given *element* isn't one of the + permutations of *iterable*. + """ + index = 0 + pool = list(iterable) + for i, x in zip(range(len(pool), -1, -1), element): + r = pool.index(x) + index = index * i + r + del pool[r] + + return index + + +class countable: + """Wrap *iterable* and keep a count of how many items have been consumed. + + The ``items_seen`` attribute starts at ``0`` and increments as the iterable + is consumed: + + >>> iterable = map(str, range(10)) + >>> it = countable(iterable) + >>> it.items_seen + 0 + >>> next(it), next(it) + ('0', '1') + >>> list(it) + ['2', '3', '4', '5', '6', '7', '8', '9'] + >>> it.items_seen + 10 + """ + + def __init__(self, iterable): + self._it = iter(iterable) + self.items_seen = 0 + + def __iter__(self): + return self + + def __next__(self): + item = next(self._it) + self.items_seen += 1 + + return item + + +def chunked_even(iterable, n): + """Break *iterable* into lists of approximately length *n*. + Items are distributed such the lengths of the lists differ by at most + 1 item. 
+ + >>> iterable = [1, 2, 3, 4, 5, 6, 7] + >>> n = 3 + >>> list(chunked_even(iterable, n)) # List lengths: 3, 2, 2 + [[1, 2, 3], [4, 5], [6, 7]] + >>> list(chunked(iterable, n)) # List lengths: 3, 3, 1 + [[1, 2, 3], [4, 5, 6], [7]] + + """ + + len_method = getattr(iterable, '__len__', None) + + if len_method is None: + return _chunked_even_online(iterable, n) + else: + return _chunked_even_finite(iterable, len_method(), n) + + +def _chunked_even_online(iterable, n): + buffer = [] + maxbuf = n + (n - 2) * (n - 1) + for x in iterable: + buffer.append(x) + if len(buffer) == maxbuf: + yield buffer[:n] + buffer = buffer[n:] + yield from _chunked_even_finite(buffer, len(buffer), n) + + +def _chunked_even_finite(iterable, N, n): + if N < 1: + return + + # Lists are either size `full_size <= n` or `partial_size = full_size - 1` + q, r = divmod(N, n) + num_lists = q + (1 if r > 0 else 0) + q, r = divmod(N, num_lists) + full_size = q + (1 if r > 0 else 0) + partial_size = full_size - 1 + num_full = N - partial_size * num_lists + num_partial = num_lists - num_full + + buffer = [] + iterator = iter(iterable) + + # Yield num_full lists of full_size + for x in iterator: + buffer.append(x) + if len(buffer) == full_size: + yield buffer + buffer = [] + num_full -= 1 + if num_full <= 0: + break + + # Yield num_partial lists of partial_size + for x in iterator: + buffer.append(x) + if len(buffer) == partial_size: + yield buffer + buffer = [] + num_partial -= 1 + + +def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False): + """A version of :func:`zip` that "broadcasts" any scalar + (i.e., non-iterable) items into output tuples. + + >>> iterable_1 = [1, 2, 3] + >>> iterable_2 = ['a', 'b', 'c'] + >>> scalar = '_' + >>> list(zip_broadcast(iterable_1, iterable_2, scalar)) + [(1, 'a', '_'), (2, 'b', '_'), (3, 'c', '_')] + + The *scalar_types* keyword argument determines what types are considered + scalar. It is set to ``(str, bytes)`` by default. Set it to ``None`` to + treat strings and byte strings as iterable: + + >>> list(zip_broadcast('abc', 0, 'xyz', scalar_types=None)) + [('a', 0, 'x'), ('b', 0, 'y'), ('c', 0, 'z')] + + If the *strict* keyword argument is ``True``, then + ``UnequalIterablesError`` will be raised if any of the iterables have + different lengths. + """ + + def is_scalar(obj): + if scalar_types and isinstance(obj, scalar_types): + return True + try: + iter(obj) + except TypeError: + return True + else: + return False + + size = len(objects) + if not size: + return + + iterables, iterable_positions = [], [] + scalars, scalar_positions = [], [] + for i, obj in enumerate(objects): + if is_scalar(obj): + scalars.append(obj) + scalar_positions.append(i) + else: + iterables.append(iter(obj)) + iterable_positions.append(i) + + if len(scalars) == size: + yield tuple(objects) + return + + zipper = _zip_equal if strict else zip + for item in zipper(*iterables): + new_item = [None] * size + + for i, elem in zip(iterable_positions, item): + new_item[i] = elem + + for i, elem in zip(scalar_positions, scalars): + new_item[i] = elem + + yield tuple(new_item) + + +def unique_in_window(iterable, n, key=None): + """Yield the items from *iterable* that haven't been seen recently. + *n* is the size of the lookback window. 
+ + >>> iterable = [0, 1, 0, 2, 3, 0] + >>> n = 3 + >>> list(unique_in_window(iterable, n)) + [0, 1, 2, 3, 0] + + The *key* function, if provided, will be used to determine uniqueness: + + >>> list(unique_in_window('abAcda', 3, key=lambda x: x.lower())) + ['a', 'b', 'c', 'd', 'a'] + + The items in *iterable* must be hashable. + + """ + if n <= 0: + raise ValueError('n must be greater than 0') + + window = deque(maxlen=n) + uniques = set() + use_key = key is not None + + for item in iterable: + k = key(item) if use_key else item + if k in uniques: + continue + + if len(uniques) == n: + uniques.discard(window[0]) + + uniques.add(k) + window.append(k) + + yield item + + +def duplicates_everseen(iterable, key=None): + """Yield duplicate elements after their first appearance. + + >>> list(duplicates_everseen('mississippi')) + ['s', 'i', 's', 's', 'i', 'p', 'i'] + >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower)) + ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a'] + + This function is analogous to :func:`unique_everseen` and is subject to + the same performance considerations. + + """ + seen_set = set() + seen_list = [] + use_key = key is not None + + for element in iterable: + k = key(element) if use_key else element + try: + if k not in seen_set: + seen_set.add(k) + else: + yield element + except TypeError: + if k not in seen_list: + seen_list.append(k) + else: + yield element + + +def duplicates_justseen(iterable, key=None): + """Yields serially-duplicate elements after their first appearance. + + >>> list(duplicates_justseen('mississippi')) + ['s', 's', 'p'] + >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower)) + ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a'] + + This function is analogous to :func:`unique_justseen`. + + """ + return flatten( + map( + lambda group_tuple: islice_extended(group_tuple[1])[1:], + groupby(iterable, key), + ) + ) + + +def minmax(iterable_or_value, *others, key=None, default=_marker): + """Returns both the smallest and largest items in an iterable + or the smallest and largest of two or more arguments. + + >>> minmax([3, 1, 5]) + (1, 5) + + >>> minmax(4, 2, 6) + (2, 6) + + If a *key* function is provided, it will be used to transform the input + items for comparison. + + >>> minmax([5, 30], key=str) # '30' sorts before '5' + (30, 5) + + If a *default* value is provided, it will be returned if there are no + input items. + + >>> minmax([], default=(0, 0)) + (0, 0) + + Otherwise ``ValueError`` is raised. + + This function is based on the + `recipe `__ by + Raymond Hettinger and takes care to minimize the number of comparisons + performed. + """ + iterable = (iterable_or_value, *others) if others else iterable_or_value + + it = iter(iterable) + + try: + lo = hi = next(it) + except StopIteration as e: + if default is _marker: + raise ValueError( + '`minmax()` argument is an empty iterable. ' + 'Provide a `default` value to suppress this error.' + ) from e + return default + + # Different branches depending on the presence of key. This saves a lot + # of unimportant copies which would significantly slow down the + # "key=None" branch.
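+ # Either way, items are consumed in pairs: ordering each pair first (one + # comparison), then comparing only the smaller against `lo` and the larger + # against `hi`, costs about 3 comparisons per 2 items instead of 4.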
+ if key is None: + for x, y in zip_longest(it, it, fillvalue=lo): + if y < x: + x, y = y, x + if x < lo: + lo = x + if hi < y: + hi = y + + else: + lo_key = hi_key = key(lo) + + for x, y in zip_longest(it, it, fillvalue=lo): + + x_key, y_key = key(x), key(y) + + if y_key < x_key: + x, y, x_key, y_key = y, x, y_key, x_key + if x_key < lo_key: + lo, lo_key = x, x_key + if hi_key < y_key: + hi, hi_key = y, y_key + + return lo, hi + + +def constrained_batches( + iterable, max_size, max_count=None, get_len=len, strict=True +): + """Yield batches of items from *iterable* with a combined size limited by + *max_size*. + + >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1'] + >>> list(constrained_batches(iterable, 10)) + [(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')] + + If a *max_count* is supplied, the number of items per batch is also + limited: + + >>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1'] + >>> list(constrained_batches(iterable, 10, max_count = 2)) + [(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)] + + If a *get_len* function is supplied, use that instead of :func:`len` to + determine item size. + + If *strict* is ``True``, raise ``ValueError`` if any single item is bigger + than *max_size*. Otherwise, allow single items to exceed *max_size*. + """ + if max_size <= 0: + raise ValueError('maximum size must be greater than zero') + + batch = [] + batch_size = 0 + batch_count = 0 + for item in iterable: + item_len = get_len(item) + if strict and item_len > max_size: + raise ValueError('item size exceeds maximum size') + + reached_count = batch_count == max_count + reached_size = item_len + batch_size > max_size + if batch_count and (reached_size or reached_count): + yield tuple(batch) + batch.clear() + batch_size = 0 + batch_count = 0 + + batch.append(item) + batch_size += item_len + batch_count += 1 + + if batch: + yield tuple(batch) diff --git a/libs/win/more_itertools/more.pyi b/libs/win/more_itertools/more.pyi new file mode 100644 index 00000000..1413fae7 --- /dev/null +++ b/libs/win/more_itertools/more.pyi @@ -0,0 +1,674 @@ +"""Stubs for more_itertools.more""" + +from typing import ( + Any, + Callable, + Container, + Dict, + Generic, + Hashable, + Iterable, + Iterator, + List, + Optional, + Reversible, + Sequence, + Sized, + Tuple, + Union, + TypeVar, + type_check_only, +) +from types import TracebackType +from typing_extensions import ContextManager, Protocol, Type, overload + +# Type and type variable definitions +_T = TypeVar('_T') +_T1 = TypeVar('_T1') +_T2 = TypeVar('_T2') +_U = TypeVar('_U') +_V = TypeVar('_V') +_W = TypeVar('_W') +_T_co = TypeVar('_T_co', covariant=True) +_GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[object]]) +_Raisable = Union[BaseException, 'Type[BaseException]'] + +@type_check_only +class _SizedIterable(Protocol[_T_co], Sized, Iterable[_T_co]): ... + +@type_check_only +class _SizedReversible(Protocol[_T_co], Sized, Reversible[_T_co]): ... + +def chunked( + iterable: Iterable[_T], n: Optional[int], strict: bool = ... +) -> Iterator[List[_T]]: ... +@overload +def first(iterable: Iterable[_T]) -> _T: ... +@overload +def first(iterable: Iterable[_T], default: _U) -> Union[_T, _U]: ... +@overload +def last(iterable: Iterable[_T]) -> _T: ... +@overload +def last(iterable: Iterable[_T], default: _U) -> Union[_T, _U]: ... +@overload +def nth_or_last(iterable: Iterable[_T], n: int) -> _T: ... 
+@overload +def nth_or_last( + iterable: Iterable[_T], n: int, default: _U +) -> Union[_T, _U]: ... + +class peekable(Generic[_T], Iterator[_T]): + def __init__(self, iterable: Iterable[_T]) -> None: ... + def __iter__(self) -> peekable[_T]: ... + def __bool__(self) -> bool: ... + @overload + def peek(self) -> _T: ... + @overload + def peek(self, default: _U) -> Union[_T, _U]: ... + def prepend(self, *items: _T) -> None: ... + def __next__(self) -> _T: ... + @overload + def __getitem__(self, index: int) -> _T: ... + @overload + def __getitem__(self, index: slice) -> List[_T]: ... + +def consumer(func: _GenFn) -> _GenFn: ... +def ilen(iterable: Iterable[object]) -> int: ... +def iterate(func: Callable[[_T], _T], start: _T) -> Iterator[_T]: ... +def with_iter( + context_manager: ContextManager[Iterable[_T]], +) -> Iterator[_T]: ... +def one( + iterable: Iterable[_T], + too_short: Optional[_Raisable] = ..., + too_long: Optional[_Raisable] = ..., +) -> _T: ... +def raise_(exception: _Raisable, *args: Any) -> None: ... +def strictly_n( + iterable: Iterable[_T], + n: int, + too_short: Optional[_GenFn] = ..., + too_long: Optional[_GenFn] = ..., +) -> List[_T]: ... +def distinct_permutations( + iterable: Iterable[_T], r: Optional[int] = ... +) -> Iterator[Tuple[_T, ...]]: ... +def intersperse( + e: _U, iterable: Iterable[_T], n: int = ... +) -> Iterator[Union[_T, _U]]: ... +def unique_to_each(*iterables: Iterable[_T]) -> List[List[_T]]: ... +@overload +def windowed( + seq: Iterable[_T], n: int, *, step: int = ... +) -> Iterator[Tuple[Optional[_T], ...]]: ... +@overload +def windowed( + seq: Iterable[_T], n: int, fillvalue: _U, step: int = ... +) -> Iterator[Tuple[Union[_T, _U], ...]]: ... +def substrings(iterable: Iterable[_T]) -> Iterator[Tuple[_T, ...]]: ... +def substrings_indexes( + seq: Sequence[_T], reverse: bool = ... +) -> Iterator[Tuple[Sequence[_T], int, int]]: ... + +class bucket(Generic[_T, _U], Container[_U]): + def __init__( + self, + iterable: Iterable[_T], + key: Callable[[_T], _U], + validator: Optional[Callable[[object], object]] = ..., + ) -> None: ... + def __contains__(self, value: object) -> bool: ... + def __iter__(self) -> Iterator[_U]: ... + def __getitem__(self, value: object) -> Iterator[_T]: ... + +def spy( + iterable: Iterable[_T], n: int = ... +) -> Tuple[List[_T], Iterator[_T]]: ... +def interleave(*iterables: Iterable[_T]) -> Iterator[_T]: ... +def interleave_longest(*iterables: Iterable[_T]) -> Iterator[_T]: ... +def interleave_evenly( + iterables: List[Iterable[_T]], lengths: Optional[List[int]] = ... +) -> Iterator[_T]: ... +def collapse( + iterable: Iterable[Any], + base_type: Optional[type] = ..., + levels: Optional[int] = ..., +) -> Iterator[Any]: ... +@overload +def side_effect( + func: Callable[[_T], object], + iterable: Iterable[_T], + chunk_size: None = ..., + before: Optional[Callable[[], object]] = ..., + after: Optional[Callable[[], object]] = ..., +) -> Iterator[_T]: ... +@overload +def side_effect( + func: Callable[[List[_T]], object], + iterable: Iterable[_T], + chunk_size: int, + before: Optional[Callable[[], object]] = ..., + after: Optional[Callable[[], object]] = ..., +) -> Iterator[_T]: ... +def sliced( + seq: Sequence[_T], n: int, strict: bool = ... +) -> Iterator[Sequence[_T]]: ... +def split_at( + iterable: Iterable[_T], + pred: Callable[[_T], object], + maxsplit: int = ..., + keep_separator: bool = ..., +) -> Iterator[List[_T]]: ... +def split_before( + iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ... 
+) -> Iterator[List[_T]]: ... +def split_after( + iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ... +) -> Iterator[List[_T]]: ... +def split_when( + iterable: Iterable[_T], + pred: Callable[[_T, _T], object], + maxsplit: int = ..., +) -> Iterator[List[_T]]: ... +def split_into( + iterable: Iterable[_T], sizes: Iterable[Optional[int]] +) -> Iterator[List[_T]]: ... +@overload +def padded( + iterable: Iterable[_T], + *, + n: Optional[int] = ..., + next_multiple: bool = ..., +) -> Iterator[Optional[_T]]: ... +@overload +def padded( + iterable: Iterable[_T], + fillvalue: _U, + n: Optional[int] = ..., + next_multiple: bool = ..., +) -> Iterator[Union[_T, _U]]: ... +@overload +def repeat_last(iterable: Iterable[_T]) -> Iterator[_T]: ... +@overload +def repeat_last( + iterable: Iterable[_T], default: _U +) -> Iterator[Union[_T, _U]]: ... +def distribute(n: int, iterable: Iterable[_T]) -> List[Iterator[_T]]: ... +@overload +def stagger( + iterable: Iterable[_T], + offsets: _SizedIterable[int] = ..., + longest: bool = ..., +) -> Iterator[Tuple[Optional[_T], ...]]: ... +@overload +def stagger( + iterable: Iterable[_T], + offsets: _SizedIterable[int] = ..., + longest: bool = ..., + fillvalue: _U = ..., +) -> Iterator[Tuple[Union[_T, _U], ...]]: ... + +class UnequalIterablesError(ValueError): + def __init__( + self, details: Optional[Tuple[int, int, int]] = ... + ) -> None: ... + +@overload +def zip_equal(__iter1: Iterable[_T1]) -> Iterator[Tuple[_T1]]: ... +@overload +def zip_equal( + __iter1: Iterable[_T1], __iter2: Iterable[_T2] +) -> Iterator[Tuple[_T1, _T2]]: ... +@overload +def zip_equal( + __iter1: Iterable[_T], + __iter2: Iterable[_T], + __iter3: Iterable[_T], + *iterables: Iterable[_T], +) -> Iterator[Tuple[_T, ...]]: ... +@overload +def zip_offset( + __iter1: Iterable[_T1], + *, + offsets: _SizedIterable[int], + longest: bool = ..., + fillvalue: None = None, +) -> Iterator[Tuple[Optional[_T1]]]: ... +@overload +def zip_offset( + __iter1: Iterable[_T1], + __iter2: Iterable[_T2], + *, + offsets: _SizedIterable[int], + longest: bool = ..., + fillvalue: None = None, +) -> Iterator[Tuple[Optional[_T1], Optional[_T2]]]: ... +@overload +def zip_offset( + __iter1: Iterable[_T], + __iter2: Iterable[_T], + __iter3: Iterable[_T], + *iterables: Iterable[_T], + offsets: _SizedIterable[int], + longest: bool = ..., + fillvalue: None = None, +) -> Iterator[Tuple[Optional[_T], ...]]: ... +@overload +def zip_offset( + __iter1: Iterable[_T1], + *, + offsets: _SizedIterable[int], + longest: bool = ..., + fillvalue: _U, +) -> Iterator[Tuple[Union[_T1, _U]]]: ... +@overload +def zip_offset( + __iter1: Iterable[_T1], + __iter2: Iterable[_T2], + *, + offsets: _SizedIterable[int], + longest: bool = ..., + fillvalue: _U, +) -> Iterator[Tuple[Union[_T1, _U], Union[_T2, _U]]]: ... +@overload +def zip_offset( + __iter1: Iterable[_T], + __iter2: Iterable[_T], + __iter3: Iterable[_T], + *iterables: Iterable[_T], + offsets: _SizedIterable[int], + longest: bool = ..., + fillvalue: _U, +) -> Iterator[Tuple[Union[_T, _U], ...]]: ... +def sort_together( + iterables: Iterable[Iterable[_T]], + key_list: Iterable[int] = ..., + key: Optional[Callable[..., Any]] = ..., + reverse: bool = ..., +) -> List[Tuple[_T, ...]]: ... +def unzip(iterable: Iterable[Sequence[_T]]) -> Tuple[Iterator[_T], ...]: ... +def divide(n: int, iterable: Iterable[_T]) -> List[Iterator[_T]]: ... 
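+# Note: unlike most functions in this module, ``divide`` returns a realized +# list of iterators, matching List[Iterator[_T]] above. An illustrative +# runtime check: [list(g) for g in divide(3, range(10))] evaluates to +# [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]].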
+def always_iterable( + obj: object, + base_type: Union[ + type, Tuple[Union[type, Tuple[Any, ...]], ...], None + ] = ..., +) -> Iterator[Any]: ... +def adjacent( + predicate: Callable[[_T], bool], + iterable: Iterable[_T], + distance: int = ..., +) -> Iterator[Tuple[bool, _T]]: ... +@overload +def groupby_transform( + iterable: Iterable[_T], + keyfunc: None = None, + valuefunc: None = None, + reducefunc: None = None, +) -> Iterator[Tuple[_T, Iterator[_T]]]: ... +@overload +def groupby_transform( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: None, + reducefunc: None, +) -> Iterator[Tuple[_U, Iterator[_T]]]: ... +@overload +def groupby_transform( + iterable: Iterable[_T], + keyfunc: None, + valuefunc: Callable[[_T], _V], + reducefunc: None, +) -> Iterable[Tuple[_T, Iterable[_V]]]: ... +@overload +def groupby_transform( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: Callable[[_T], _V], + reducefunc: None, +) -> Iterable[Tuple[_U, Iterator[_V]]]: ... +@overload +def groupby_transform( + iterable: Iterable[_T], + keyfunc: None, + valuefunc: None, + reducefunc: Callable[[Iterator[_T]], _W], +) -> Iterable[Tuple[_T, _W]]: ... +@overload +def groupby_transform( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: None, + reducefunc: Callable[[Iterator[_T]], _W], +) -> Iterable[Tuple[_U, _W]]: ... +@overload +def groupby_transform( + iterable: Iterable[_T], + keyfunc: None, + valuefunc: Callable[[_T], _V], + reducefunc: Callable[[Iterable[_V]], _W], +) -> Iterable[Tuple[_T, _W]]: ... +@overload +def groupby_transform( + iterable: Iterable[_T], + keyfunc: Callable[[_T], _U], + valuefunc: Callable[[_T], _V], + reducefunc: Callable[[Iterable[_V]], _W], +) -> Iterable[Tuple[_U, _W]]: ... + +class numeric_range(Generic[_T, _U], Sequence[_T], Hashable, Reversible[_T]): + @overload + def __init__(self, __stop: _T) -> None: ... + @overload + def __init__(self, __start: _T, __stop: _T) -> None: ... + @overload + def __init__(self, __start: _T, __stop: _T, __step: _U) -> None: ... + def __bool__(self) -> bool: ... + def __contains__(self, elem: object) -> bool: ... + def __eq__(self, other: object) -> bool: ... + @overload + def __getitem__(self, key: int) -> _T: ... + @overload + def __getitem__(self, key: slice) -> numeric_range[_T, _U]: ... + def __hash__(self) -> int: ... + def __iter__(self) -> Iterator[_T]: ... + def __len__(self) -> int: ... + def __reduce__( + self, + ) -> Tuple[Type[numeric_range[_T, _U]], Tuple[_T, _T, _U]]: ... + def __repr__(self) -> str: ... + def __reversed__(self) -> Iterator[_T]: ... + def count(self, value: _T) -> int: ... + def index(self, value: _T) -> int: ... # type: ignore + +def count_cycle( + iterable: Iterable[_T], n: Optional[int] = ... +) -> Iterable[Tuple[int, _T]]: ... +def mark_ends( + iterable: Iterable[_T], +) -> Iterable[Tuple[bool, bool, _T]]: ... +def locate( + iterable: Iterable[object], + pred: Callable[..., Any] = ..., + window_size: Optional[int] = ..., +) -> Iterator[int]: ... +def lstrip( + iterable: Iterable[_T], pred: Callable[[_T], object] +) -> Iterator[_T]: ... +def rstrip( + iterable: Iterable[_T], pred: Callable[[_T], object] +) -> Iterator[_T]: ... +def strip( + iterable: Iterable[_T], pred: Callable[[_T], object] +) -> Iterator[_T]: ... + +class islice_extended(Generic[_T], Iterator[_T]): + def __init__( + self, iterable: Iterable[_T], *args: Optional[int] + ) -> None: ... + def __iter__(self) -> islice_extended[_T]: ... + def __next__(self) -> _T: ... 
+    def __getitem__(self, index: slice) -> islice_extended[_T]: ...
+
+def always_reversible(iterable: Iterable[_T]) -> Iterator[_T]: ...
+def consecutive_groups(
+    iterable: Iterable[_T], ordering: Callable[[_T], int] = ...
+) -> Iterator[Iterator[_T]]: ...
+@overload
+def difference(
+    iterable: Iterable[_T],
+    func: Callable[[_T, _T], _U] = ...,
+    *,
+    initial: None = ...,
+) -> Iterator[Union[_T, _U]]: ...
+@overload
+def difference(
+    iterable: Iterable[_T], func: Callable[[_T, _T], _U] = ..., *, initial: _U
+) -> Iterator[_U]: ...
+
+class SequenceView(Generic[_T], Sequence[_T]):
+    def __init__(self, target: Sequence[_T]) -> None: ...
+    @overload
+    def __getitem__(self, index: int) -> _T: ...
+    @overload
+    def __getitem__(self, index: slice) -> Sequence[_T]: ...
+    def __len__(self) -> int: ...
+
+class seekable(Generic[_T], Iterator[_T]):
+    def __init__(
+        self, iterable: Iterable[_T], maxlen: Optional[int] = ...
+    ) -> None: ...
+    def __iter__(self) -> seekable[_T]: ...
+    def __next__(self) -> _T: ...
+    def __bool__(self) -> bool: ...
+    @overload
+    def peek(self) -> _T: ...
+    @overload
+    def peek(self, default: _U) -> Union[_T, _U]: ...
+    def elements(self) -> SequenceView[_T]: ...
+    def seek(self, index: int) -> None: ...
+
+class run_length:
+    @staticmethod
+    def encode(iterable: Iterable[_T]) -> Iterator[Tuple[_T, int]]: ...
+    @staticmethod
+    def decode(iterable: Iterable[Tuple[_T, int]]) -> Iterator[_T]: ...
+
+def exactly_n(
+    iterable: Iterable[_T], n: int, predicate: Callable[[_T], object] = ...
+) -> bool: ...
+def circular_shifts(iterable: Iterable[_T]) -> List[Tuple[_T, ...]]: ...
+def make_decorator(
+    wrapping_func: Callable[..., _U], result_index: int = ...
+) -> Callable[..., Callable[[Callable[..., Any]], Callable[..., _U]]]: ...
+@overload
+def map_reduce(
+    iterable: Iterable[_T],
+    keyfunc: Callable[[_T], _U],
+    valuefunc: None = ...,
+    reducefunc: None = ...,
+) -> Dict[_U, List[_T]]: ...
+@overload
+def map_reduce(
+    iterable: Iterable[_T],
+    keyfunc: Callable[[_T], _U],
+    valuefunc: Callable[[_T], _V],
+    reducefunc: None = ...,
+) -> Dict[_U, List[_V]]: ...
+@overload
+def map_reduce(
+    iterable: Iterable[_T],
+    keyfunc: Callable[[_T], _U],
+    valuefunc: None = ...,
+    reducefunc: Callable[[List[_T]], _W] = ...,
+) -> Dict[_U, _W]: ...
+@overload
+def map_reduce(
+    iterable: Iterable[_T],
+    keyfunc: Callable[[_T], _U],
+    valuefunc: Callable[[_T], _V],
+    reducefunc: Callable[[List[_V]], _W],
+) -> Dict[_U, _W]: ...
+def rlocate(
+    iterable: Iterable[_T],
+    pred: Callable[..., object] = ...,
+    window_size: Optional[int] = ...,
+) -> Iterator[int]: ...
+def replace(
+    iterable: Iterable[_T],
+    pred: Callable[..., object],
+    substitutes: Iterable[_U],
+    count: Optional[int] = ...,
+    window_size: int = ...,
+) -> Iterator[Union[_T, _U]]: ...
+def partitions(iterable: Iterable[_T]) -> Iterator[List[List[_T]]]: ...
+def set_partitions(
+    iterable: Iterable[_T], k: Optional[int] = ...
+) -> Iterator[List[List[_T]]]: ...
+
+class time_limited(Generic[_T], Iterator[_T]):
+    def __init__(
+        self, limit_seconds: float, iterable: Iterable[_T]
+    ) -> None: ...
+    def __iter__(self) -> time_limited[_T]: ...
+    def __next__(self) -> _T: ...
+
+@overload
+def only(
+    iterable: Iterable[_T], *, too_long: Optional[_Raisable] = ...
+) -> Optional[_T]: ...
+@overload
+def only(
+    iterable: Iterable[_T], default: _U, too_long: Optional[_Raisable] = ...
+) -> Union[_T, _U]: ...
+def ichunked(iterable: Iterable[_T], n: int) -> Iterator[Iterator[_T]]: ...
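A quick sketch of how the ``map_reduce`` overloads above play out at runtime (illustrative only; per the library's documentation the result is actually a ``collections.defaultdict``). Supplying *valuefunc* rewrites the grouped values and supplying *reducefunc* collapses each group, which is why the value type narrows from ``List[_T]`` to ``List[_V]`` to ``_W``.

    >>> from more_itertools import map_reduce
    >>> keyfunc = lambda x: x.upper()
    >>> sorted(map_reduce('abbccc', keyfunc).items())
    [('A', ['a']), ('B', ['b', 'b']), ('C', ['c', 'c', 'c'])]
    >>> sorted(map_reduce('abbccc', keyfunc, valuefunc=lambda x: 1).items())
    [('A', [1]), ('B', [1, 1]), ('C', [1, 1, 1])]
    >>> sorted(map_reduce('abbccc', keyfunc, lambda x: 1, sum).items())
    [('A', 1), ('B', 2), ('C', 3)]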
+def distinct_combinations( + iterable: Iterable[_T], r: int +) -> Iterator[Tuple[_T, ...]]: ... +def filter_except( + validator: Callable[[Any], object], + iterable: Iterable[_T], + *exceptions: Type[BaseException], +) -> Iterator[_T]: ... +def map_except( + function: Callable[[Any], _U], + iterable: Iterable[_T], + *exceptions: Type[BaseException], +) -> Iterator[_U]: ... +def map_if( + iterable: Iterable[Any], + pred: Callable[[Any], bool], + func: Callable[[Any], Any], + func_else: Optional[Callable[[Any], Any]] = ..., +) -> Iterator[Any]: ... +def sample( + iterable: Iterable[_T], + k: int, + weights: Optional[Iterable[float]] = ..., +) -> List[_T]: ... +def is_sorted( + iterable: Iterable[_T], + key: Optional[Callable[[_T], _U]] = ..., + reverse: bool = False, + strict: bool = False, +) -> bool: ... + +class AbortThread(BaseException): + pass + +class callback_iter(Generic[_T], Iterator[_T]): + def __init__( + self, + func: Callable[..., Any], + callback_kwd: str = ..., + wait_seconds: float = ..., + ) -> None: ... + def __enter__(self) -> callback_iter[_T]: ... + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> Optional[bool]: ... + def __iter__(self) -> callback_iter[_T]: ... + def __next__(self) -> _T: ... + def _reader(self) -> Iterator[_T]: ... + @property + def done(self) -> bool: ... + @property + def result(self) -> Any: ... + +def windowed_complete( + iterable: Iterable[_T], n: int +) -> Iterator[Tuple[_T, ...]]: ... +def all_unique( + iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = ... +) -> bool: ... +def nth_product(index: int, *args: Iterable[_T]) -> Tuple[_T, ...]: ... +def nth_permutation( + iterable: Iterable[_T], r: int, index: int +) -> Tuple[_T, ...]: ... +def value_chain(*args: Union[_T, Iterable[_T]]) -> Iterable[_T]: ... +def product_index(element: Iterable[_T], *args: Iterable[_T]) -> int: ... +def combination_index( + element: Iterable[_T], iterable: Iterable[_T] +) -> int: ... +def permutation_index( + element: Iterable[_T], iterable: Iterable[_T] +) -> int: ... +def repeat_each(iterable: Iterable[_T], n: int = ...) -> Iterator[_T]: ... + +class countable(Generic[_T], Iterator[_T]): + def __init__(self, iterable: Iterable[_T]) -> None: ... + def __iter__(self) -> countable[_T]: ... + def __next__(self) -> _T: ... + +def chunked_even(iterable: Iterable[_T], n: int) -> Iterator[List[_T]]: ... +def zip_broadcast( + *objects: Union[_T, Iterable[_T]], + scalar_types: Union[ + type, Tuple[Union[type, Tuple[Any, ...]], ...], None + ] = ..., + strict: bool = ..., +) -> Iterable[Tuple[_T, ...]]: ... +def unique_in_window( + iterable: Iterable[_T], n: int, key: Optional[Callable[[_T], _U]] = ... +) -> Iterator[_T]: ... +def duplicates_everseen( + iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = ... +) -> Iterator[_T]: ... +def duplicates_justseen( + iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = ... +) -> Iterator[_T]: ... + +class _SupportsLessThan(Protocol): + def __lt__(self, __other: Any) -> bool: ... + +_SupportsLessThanT = TypeVar("_SupportsLessThanT", bound=_SupportsLessThan) + +@overload +def minmax( + iterable_or_value: Iterable[_SupportsLessThanT], *, key: None = None +) -> Tuple[_SupportsLessThanT, _SupportsLessThanT]: ... +@overload +def minmax( + iterable_or_value: Iterable[_T], *, key: Callable[[_T], _SupportsLessThan] +) -> Tuple[_T, _T]: ... 
+@overload
+def minmax(
+    iterable_or_value: Iterable[_SupportsLessThanT],
+    *,
+    key: None = None,
+    default: _U,
+) -> Union[_U, Tuple[_SupportsLessThanT, _SupportsLessThanT]]: ...
+@overload
+def minmax(
+    iterable_or_value: Iterable[_T],
+    *,
+    key: Callable[[_T], _SupportsLessThan],
+    default: _U,
+) -> Union[_U, Tuple[_T, _T]]: ...
+@overload
+def minmax(
+    iterable_or_value: _SupportsLessThanT,
+    __other: _SupportsLessThanT,
+    *others: _SupportsLessThanT,
+) -> Tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
+@overload
+def minmax(
+    iterable_or_value: _T,
+    __other: _T,
+    *others: _T,
+    key: Callable[[_T], _SupportsLessThan],
+) -> Tuple[_T, _T]: ...
+def longest_common_prefix(
+    iterables: Iterable[Iterable[_T]],
+) -> Iterator[_T]: ...
+def iequals(*iterables: Iterable[object]) -> bool: ...
+def constrained_batches(
+    iterable: Iterable[object],
+    max_size: int,
+    max_count: Optional[int] = ...,
+    get_len: Callable[[_T], object] = ...,
+    strict: bool = ...,
+) -> Iterator[Tuple[_T, ...]]: ...
diff --git a/libs/win/more_itertools/py.typed b/libs/win/more_itertools/py.typed
new file mode 100644
index 00000000..e69de29b
diff --git a/libs/win/more_itertools/recipes.py b/libs/win/more_itertools/recipes.py
index 3a7706cb..85796207 100644
--- a/libs/win/more_itertools/recipes.py
+++ b/libs/win/more_itertools/recipes.py
@@ -7,20 +7,33 @@ Some backward-compatible usability improvements have been made.
 .. [1] http://docs.python.org/library/itertools.html#recipes
 
 """
-from collections import deque
-from itertools import (
-    chain, combinations, count, cycle, groupby, islice, repeat, starmap, tee
-)
+import math
 import operator
+
+from collections import deque
+from collections.abc import Sized
+from functools import reduce
+from itertools import (
+    chain,
+    combinations,
+    compress,
+    count,
+    cycle,
+    groupby,
+    islice,
+    repeat,
+    starmap,
+    tee,
+    zip_longest,
+)
 from random import randrange, sample, choice
 
-from six import PY2
-from six.moves import filter, filterfalse, map, range, zip, zip_longest
-
 __all__ = [
-    'accumulate',
     'all_equal',
+    'batched',
+    'before_and_after',
     'consume',
+    'convolve',
     'dotproduct',
     'first_true',
     'flatten',
@@ -30,8 +43,10 @@ __all__ = [
     'nth',
     'nth_combination',
     'padnone',
+    'pad_none',
     'pairwise',
     'partition',
+    'polynomial_from_roots',
     'powerset',
     'prepend',
     'quantify',
@@ -41,42 +56,18 @@ __all__ = [
     'random_product',
     'repeatfunc',
     'roundrobin',
+    'sieve',
+    'sliding_window',
+    'subslices',
     'tabulate',
     'tail',
     'take',
+    'triplewise',
     'unique_everseen',
     'unique_justseen',
 ]
 
-
-def accumulate(iterable, func=operator.add):
-    """
-    Return an iterator whose items are the accumulated results of a function
-    (specified by the optional *func* argument) that takes two arguments.
-    By default, returns accumulated sums with :func:`operator.add`.
-
-    >>> list(accumulate([1, 2, 3, 4, 5]))  # Running sum
-    [1, 3, 6, 10, 15]
-    >>> list(accumulate([1, 2, 3], func=operator.mul))  # Running product
-    [1, 2, 6]
-    >>> list(accumulate([0, 1, -1, 2, 3, 2], func=max))  # Running maximum
-    [0, 1, 1, 2, 3, 3]
-
-    This function is available in the ``itertools`` module for Python 3.2 and
-    greater.
-
-    """
-    it = iter(iterable)
-    try:
-        total = next(it)
-    except StopIteration:
-        return
-    else:
-        yield total
-
-    for element in it:
-        total = func(total, element)
-        yield total
+_marker = object()
 
 
 def take(n, iterable):
@@ -84,11 +75,12 @@ def take(n, iterable):
     >>> take(3, range(10))
     [0, 1, 2]
-    >>> take(5, range(3))
-    [0, 1, 2]
 
-    Effectively a short replacement for ``next`` based iterator consumption
-    when you want more than one item, but less than the whole iterator.
+    If there are fewer than *n* items in the iterable, all of them are
+    returned.
+
+    >>> take(10, range(3))
+    [0, 1, 2]
 
     """
     return list(islice(iterable, n))
@@ -115,12 +107,19 @@ def tabulate(function, start=0):
 def tail(n, iterable):
     """Return an iterator over the last *n* items of *iterable*.
 
-    >>> t = tail(3, 'ABCDEFG')
-    >>> list(t)
-    ['E', 'F', 'G']
+    >>> t = tail(3, 'ABCDEFG')
+    >>> list(t)
+    ['E', 'F', 'G']
 
     """
-    return iter(deque(iterable, maxlen=n))
+    # If the given iterable has a length, then we can use islice to get its
+    # final elements. Note that if the iterable is not actually Iterable,
+    # either islice or deque will throw a TypeError. This is why we don't
+    # check if it is Iterable.
+    if isinstance(iterable, Sized):
+        yield from islice(iterable, max(0, len(iterable) - n), None)
+    else:
+        yield from iter(deque(iterable, maxlen=n))
 
 
 def consume(iterator, n=None):
@@ -166,11 +165,11 @@ def consume(iterator, n=None):
 def nth(iterable, n, default=None):
     """Returns the nth item or a default value.
 
-    >>> l = range(10)
-    >>> nth(l, 3)
-    3
-    >>> nth(l, 20, "zebra")
-    'zebra'
+    >>> l = range(10)
+    >>> nth(l, 3)
+    3
+    >>> nth(l, 20, "zebra")
+    'zebra'
 
     """
     return next(islice(iterable, n, None), default)
@@ -193,17 +192,17 @@ def all_equal(iterable):
 def quantify(iterable, pred=bool):
     """Return how many times the predicate is true.
 
-    >>> quantify([True, False, True])
-    2
+    >>> quantify([True, False, True])
+    2
 
     """
     return sum(map(pred, iterable))
 
 
-def padnone(iterable):
+def pad_none(iterable):
     """Returns the sequence of elements and then returns ``None`` indefinitely.
 
-    >>> take(5, padnone(range(3)))
+    >>> take(5, pad_none(range(3)))
     [0, 1, 2, None, None]
 
     Useful for emulating the behavior of the built-in :func:`map` function.
@@ -214,11 +213,14 @@ def padnone(iterable):
     return chain(iterable, repeat(None))
 
 
+padnone = pad_none
+
+
 def ncycles(iterable, n):
     """Returns the sequence elements *n* times
 
-    >>> list(ncycles(["a", "b"], 3))
-    ['a', 'b', 'a', 'b', 'a', 'b']
+    >>> list(ncycles(["a", "b"], 3))
+    ['a', 'b', 'a', 'b', 'a', 'b']
 
     """
     return chain.from_iterable(repeat(tuple(iterable), n))
@@ -227,8 +229,8 @@ def ncycles(iterable, n):
 def dotproduct(vec1, vec2):
     """Returns the dot product of the two iterables.
 
-    >>> dotproduct([10, 10], [20, 20])
-    400
+    >>> dotproduct([10, 10], [20, 20])
+    400
 
     """
     return sum(map(operator.mul, vec1, vec2))
@@ -273,27 +275,109 @@ def repeatfunc(func, times=None, *args):
     return starmap(func, repeat(args, times))
 
 
-def pairwise(iterable):
+def _pairwise(iterable):
     """Returns an iterator of paired items, overlapping, from the original
 
-    >>> take(4, pairwise(count()))
-    [(0, 1), (1, 2), (2, 3), (3, 4)]
+    >>> take(4, pairwise(count()))
+    [(0, 1), (1, 2), (2, 3), (3, 4)]
+
+    On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`.
 
     """
     a, b = tee(iterable)
    next(b, None)
-    return zip(a, b)
+    yield from zip(a, b)
 
 
-def grouper(n, iterable, fillvalue=None):
-    """Collect data into fixed-length chunks or blocks.
+try: + from itertools import pairwise as itertools_pairwise +except ImportError: + pairwise = _pairwise +else: - >>> list(grouper(3, 'ABCDEFG', 'x')) - [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')] + def pairwise(iterable): + yield from itertools_pairwise(iterable) + + pairwise.__doc__ = _pairwise.__doc__ + + +class UnequalIterablesError(ValueError): + def __init__(self, details=None): + msg = 'Iterables have different lengths' + if details is not None: + msg += (': index 0 has length {}; index {} has length {}').format( + *details + ) + + super().__init__(msg) + + +def _zip_equal_generator(iterables): + for combo in zip_longest(*iterables, fillvalue=_marker): + for val in combo: + if val is _marker: + raise UnequalIterablesError() + yield combo + + +def _zip_equal(*iterables): + # Check whether the iterables are all the same size. + try: + first_size = len(iterables[0]) + for i, it in enumerate(iterables[1:], 1): + size = len(it) + if size != first_size: + break + else: + # If we didn't break out, we can use the built-in zip. + return zip(*iterables) + + # If we did break out, there was a mismatch. + raise UnequalIterablesError(details=(first_size, i, size)) + # If any one of the iterables didn't have a length, start reading + # them until one runs out. + except TypeError: + return _zip_equal_generator(iterables) + + +def grouper(iterable, n, incomplete='fill', fillvalue=None): + """Group elements from *iterable* into fixed-length groups of length *n*. + + >>> list(grouper('ABCDEF', 3)) + [('A', 'B', 'C'), ('D', 'E', 'F')] + + The keyword arguments *incomplete* and *fillvalue* control what happens for + iterables whose length is not a multiple of *n*. + + When *incomplete* is `'fill'`, the last group will contain instances of + *fillvalue*. + + >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x')) + [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')] + + When *incomplete* is `'ignore'`, the last group will not be emitted. + + >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x')) + [('A', 'B', 'C'), ('D', 'E', 'F')] + + When *incomplete* is `'strict'`, a subclass of `ValueError` will be raised. + + >>> it = grouper('ABCDEFG', 3, incomplete='strict') + >>> list(it) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + UnequalIterablesError """ args = [iter(iterable)] * n - return zip_longest(fillvalue=fillvalue, *args) + if incomplete == 'fill': + return zip_longest(*args, fillvalue=fillvalue) + if incomplete == 'strict': + return _zip_equal(*args) + if incomplete == 'ignore': + return zip(*args) + else: + raise ValueError('Expected fill, strict, or ignore') def roundrobin(*iterables): @@ -309,10 +393,7 @@ def roundrobin(*iterables): """ # Recipe credited to George Sakkis pending = len(iterables) - if PY2: - nexts = cycle(iter(it).next for it in iterables) - else: - nexts = cycle(iter(it).__next__ for it in iterables) + nexts = cycle(iter(it).__next__ for it in iterables) while pending: try: for next in nexts: @@ -334,18 +415,43 @@ def partition(pred, iterable): >>> list(even_items), list(odd_items) ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9]) + If *pred* is None, :func:`bool` is used. 
+
+    >>> iterable = [0, 1, False, True, '', ' ']
+    >>> false_items, true_items = partition(None, iterable)
+    >>> list(false_items), list(true_items)
+    ([0, False, ''], [1, True, ' '])
+
     """
-    # partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9
-    t1, t2 = tee(iterable)
-    return filterfalse(pred, t1), filter(pred, t2)
+    if pred is None:
+        pred = bool
+
+    evaluations = ((pred(x), x) for x in iterable)
+    t1, t2 = tee(evaluations)
+    return (
+        (x for (cond, x) in t1 if not cond),
+        (x for (cond, x) in t2 if cond),
+    )
 
 
 def powerset(iterable):
     """Yields all possible subsets of the iterable.
 
-    >>> list(powerset([1,2,3]))
+    >>> list(powerset([1, 2, 3]))
     [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
 
+    :func:`powerset` will operate on iterables that aren't :class:`set`
+    instances, so repeated elements in the input will produce repeated elements
+    in the output. Use :func:`unique_everseen` on the input to avoid generating
+    duplicates:
+
+    >>> seq = [1, 1, 0]
+    >>> list(powerset(seq))
+    [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]
+    >>> from more_itertools import unique_everseen
+    >>> list(powerset(unique_everseen(seq)))
+    [(), (1,), (0,), (1, 0)]
+
     """
     s = list(iterable)
     return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
@@ -363,41 +469,46 @@ def unique_everseen(iterable, key=None):
 
     Sequences with a mix of hashable and unhashable items can be used.
     The function will be slower (i.e., `O(n^2)`) for unhashable items.
 
+    Remember that ``list`` objects are unhashable; you can use the *key*
+    parameter to transform the list into a tuple (which is hashable) to
+    avoid a slowdown.
+
+    >>> iterable = ([1, 2], [2, 3], [1, 2])
+    >>> list(unique_everseen(iterable))  # Slow
+    [[1, 2], [2, 3]]
+    >>> list(unique_everseen(iterable, key=tuple))  # Faster
+    [[1, 2], [2, 3]]
+
+    Similarly, you may want to convert unhashable ``set`` objects with
+    ``key=frozenset``. For ``dict`` objects,
+    ``key=lambda x: frozenset(x.items())`` can be used.
+
+    """
     seenset = set()
     seenset_add = seenset.add
     seenlist = []
     seenlist_add = seenlist.append
-    if key is None:
-        for element in iterable:
-            try:
-                if element not in seenset:
-                    seenset_add(element)
-                    yield element
-            except TypeError:
-                if element not in seenlist:
-                    seenlist_add(element)
-                    yield element
-    else:
-        for element in iterable:
-            k = key(element)
-            try:
-                if k not in seenset:
-                    seenset_add(k)
-                    yield element
-            except TypeError:
-                if k not in seenlist:
-                    seenlist_add(k)
-                    yield element
+    use_key = key is not None
+
+    for element in iterable:
+        k = key(element) if use_key else element
+        try:
+            if k not in seenset:
+                seenset_add(k)
+                yield element
+        except TypeError:
+            if k not in seenlist:
+                seenlist_add(k)
+                yield element
 
 
 def unique_justseen(iterable, key=None):
     """Yields elements in order, ignoring serial duplicates
 
-    >>> list(unique_justseen('AAAABBBCCDAABBB'))
-    ['A', 'B', 'C', 'D', 'A', 'B']
-    >>> list(unique_justseen('ABBCcAD', str.lower))
-    ['A', 'B', 'C', 'A', 'D']
+    >>> list(unique_justseen('AAAABBBCCDAABBB'))
+    ['A', 'B', 'C', 'D', 'A', 'B']
+    >>> list(unique_justseen('ABBCcAD', str.lower))
+    ['A', 'B', 'C', 'A', 'D']
 
     """
     return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
@@ -414,6 +525,16 @@ def iter_except(func, exception, first=None):
     >>> list(iter_except(l.pop, IndexError))
     [2, 1, 0]
 
+    Multiple exceptions can be specified as a stopping condition:
+
+    >>> l = [1, 2, 3, '...', 4, 5, 6]
+    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
+    [7, 6, 5]
+    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
+    [4, 3, 2]
+    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
+    []
+
     """
     try:
         if first is not None:
@@ -424,7 +545,7 @@
         pass
 
 
-def first_true(iterable, default=False, pred=None):
+def first_true(iterable, default=None, pred=None):
     """
     Returns the first true value in the iterable.
@@ -444,7 +565,7 @@
     return next(filter(pred, iterable), default)
 
 
-def random_product(*args, **kwds):
+def random_product(*args, repeat=1):
     """Draw an item at random from each of the input iterables.
 
     >>> random_product('abc', range(4), 'XYZ')  # doctest:+SKIP
@@ -460,7 +581,7 @@
     ``itertools.product(*args, **kwargs)``.
 
     """
-    pools = [tuple(pool) for pool in args] * kwds.get('repeat', 1)
+    pools = [tuple(pool) for pool in args] * repeat
     return tuple(choice(pool) for pool in pools)
@@ -523,6 +644,12 @@ def nth_combination(iterable, r, index):
     sort position *index* directly, without computing the previous
     subsequences.
 
+    >>> nth_combination(range(5), 3, 5)
+    (0, 3, 4)
+
+    ``ValueError`` will be raised if *r* is negative or greater than the length
+    of *iterable*.
+    ``IndexError`` will be raised if the given *index* is invalid.
     """
     pool = tuple(iterable)
     n = len(pool)
@@ -559,7 +686,156 @@ def prepend(value, iterator):
     >>> list(prepend(value, iterator))
     ['0', '1', '2', '3']
 
-    To prepend multiple values, see :func:`itertools.chain`.
+    To prepend multiple values, see :func:`itertools.chain`
+    or :func:`value_chain`.
 
     """
     return chain([value], iterator)
 
 
+def convolve(signal, kernel):
+    """Convolve the iterable *signal* with the iterable *kernel*.
+
+    >>> signal = (1, 2, 3, 4, 5)
+    >>> kernel = [3, 2, 1]
+    >>> list(convolve(signal, kernel))
+    [3, 8, 14, 20, 26, 14, 5]
+
+    Note: the input arguments are not interchangeable, as the *kernel*
+    is immediately consumed and stored.
+ + """ + kernel = tuple(kernel)[::-1] + n = len(kernel) + window = deque([0], maxlen=n) * n + for x in chain(signal, repeat(0, n - 1)): + window.append(x) + yield sum(map(operator.mul, kernel, window)) + + +def before_and_after(predicate, it): + """A variant of :func:`takewhile` that allows complete access to the + remainder of the iterator. + + >>> it = iter('ABCdEfGhI') + >>> all_upper, remainder = before_and_after(str.isupper, it) + >>> ''.join(all_upper) + 'ABC' + >>> ''.join(remainder) # takewhile() would lose the 'd' + 'dEfGhI' + + Note that the first iterator must be fully consumed before the second + iterator can generate valid results. + """ + it = iter(it) + transition = [] + + def true_iterator(): + for elem in it: + if predicate(elem): + yield elem + else: + transition.append(elem) + return + + # Note: this is different from itertools recipes to allow nesting + # before_and_after remainders into before_and_after again. See tests + # for an example. + remainder_iterator = chain(transition, it) + + return true_iterator(), remainder_iterator + + +def triplewise(iterable): + """Return overlapping triplets from *iterable*. + + >>> list(triplewise('ABCDE')) + [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')] + + """ + for (a, _), (b, c) in pairwise(pairwise(iterable)): + yield a, b, c + + +def sliding_window(iterable, n): + """Return a sliding window of width *n* over *iterable*. + + >>> list(sliding_window(range(6), 4)) + [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)] + + If *iterable* has fewer than *n* items, then nothing is yielded: + + >>> list(sliding_window(range(3), 4)) + [] + + For a variant with more features, see :func:`windowed`. + """ + it = iter(iterable) + window = deque(islice(it, n), maxlen=n) + if len(window) == n: + yield tuple(window) + for x in it: + window.append(x) + yield tuple(window) + + +def subslices(iterable): + """Return all contiguous non-empty subslices of *iterable*. + + >>> list(subslices('ABC')) + [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']] + + This is similar to :func:`substrings`, but emits items in a different + order. + """ + seq = list(iterable) + slices = starmap(slice, combinations(range(len(seq) + 1), 2)) + return map(operator.getitem, repeat(seq), slices) + + +def polynomial_from_roots(roots): + """Compute a polynomial's coefficients from its roots. + + >>> roots = [5, -4, 3] # (x - 5) * (x + 4) * (x - 3) + >>> polynomial_from_roots(roots) # x^3 - 4 * x^2 - 17 * x + 60 + [1, -4, -17, 60] + """ + # Use math.prod for Python 3.8+, + prod = getattr(math, 'prod', lambda x: reduce(operator.mul, x, 1)) + roots = list(map(operator.neg, roots)) + return [ + sum(map(prod, combinations(roots, k))) for k in range(len(roots) + 1) + ] + + +def sieve(n): + """Yield the primes less than n. + + >>> list(sieve(30)) + [2, 3, 5, 7, 11, 13, 17, 19, 23, 29] + """ + isqrt = getattr(math, 'isqrt', lambda x: int(math.sqrt(x))) + limit = isqrt(n) + 1 + data = bytearray([1]) * n + data[:2] = 0, 0 + for p in compress(range(limit), data): + data[p + p : n : p] = bytearray(len(range(p + p, n, p))) + + return compress(count(), data) + + +def batched(iterable, n): + """Batch data into lists of length *n*. The last batch may be shorter. + + >>> list(batched('ABCDEFG', 3)) + [['A', 'B', 'C'], ['D', 'E', 'F'], ['G']] + + This recipe is from the ``itertools`` docs. This library also provides + :func:`chunked`, which has a different implementation. 
+ """ + it = iter(iterable) + while True: + batch = list(islice(it, n)) + if not batch: + break + yield batch diff --git a/libs/win/more_itertools/recipes.pyi b/libs/win/more_itertools/recipes.pyi new file mode 100644 index 00000000..29415c5a --- /dev/null +++ b/libs/win/more_itertools/recipes.pyi @@ -0,0 +1,110 @@ +"""Stubs for more_itertools.recipes""" +from typing import ( + Any, + Callable, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) +from typing_extensions import overload, Type + +# Type and type variable definitions +_T = TypeVar('_T') +_U = TypeVar('_U') + +def take(n: int, iterable: Iterable[_T]) -> List[_T]: ... +def tabulate( + function: Callable[[int], _T], start: int = ... +) -> Iterator[_T]: ... +def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]: ... +def consume(iterator: Iterable[object], n: Optional[int] = ...) -> None: ... +@overload +def nth(iterable: Iterable[_T], n: int) -> Optional[_T]: ... +@overload +def nth(iterable: Iterable[_T], n: int, default: _U) -> Union[_T, _U]: ... +def all_equal(iterable: Iterable[object]) -> bool: ... +def quantify( + iterable: Iterable[_T], pred: Callable[[_T], bool] = ... +) -> int: ... +def pad_none(iterable: Iterable[_T]) -> Iterator[Optional[_T]]: ... +def padnone(iterable: Iterable[_T]) -> Iterator[Optional[_T]]: ... +def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]: ... +def dotproduct(vec1: Iterable[object], vec2: Iterable[object]) -> object: ... +def flatten(listOfLists: Iterable[Iterable[_T]]) -> Iterator[_T]: ... +def repeatfunc( + func: Callable[..., _U], times: Optional[int] = ..., *args: Any +) -> Iterator[_U]: ... +def pairwise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T]]: ... +def grouper( + iterable: Iterable[_T], + n: int, + incomplete: str = ..., + fillvalue: _U = ..., +) -> Iterator[Tuple[Union[_T, _U], ...]]: ... +def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]: ... +def partition( + pred: Optional[Callable[[_T], object]], iterable: Iterable[_T] +) -> Tuple[Iterator[_T], Iterator[_T]]: ... +def powerset(iterable: Iterable[_T]) -> Iterator[Tuple[_T, ...]]: ... +def unique_everseen( + iterable: Iterable[_T], key: Optional[Callable[[_T], _U]] = ... +) -> Iterator[_T]: ... +def unique_justseen( + iterable: Iterable[_T], key: Optional[Callable[[_T], object]] = ... +) -> Iterator[_T]: ... +@overload +def iter_except( + func: Callable[[], _T], + exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]], + first: None = ..., +) -> Iterator[_T]: ... +@overload +def iter_except( + func: Callable[[], _T], + exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]], + first: Callable[[], _U], +) -> Iterator[Union[_T, _U]]: ... +@overload +def first_true( + iterable: Iterable[_T], *, pred: Optional[Callable[[_T], object]] = ... +) -> Optional[_T]: ... +@overload +def first_true( + iterable: Iterable[_T], + default: _U, + pred: Optional[Callable[[_T], object]] = ..., +) -> Union[_T, _U]: ... +def random_product( + *args: Iterable[_T], repeat: int = ... +) -> Tuple[_T, ...]: ... +def random_permutation( + iterable: Iterable[_T], r: Optional[int] = ... +) -> Tuple[_T, ...]: ... +def random_combination(iterable: Iterable[_T], r: int) -> Tuple[_T, ...]: ... +def random_combination_with_replacement( + iterable: Iterable[_T], r: int +) -> Tuple[_T, ...]: ... +def nth_combination( + iterable: Iterable[_T], r: int, index: int +) -> Tuple[_T, ...]: ... +def prepend(value: _T, iterator: Iterable[_U]) -> Iterator[Union[_T, _U]]: ... 
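The ``grouper`` stub above condenses the three *incomplete* modes implemented earlier in this diff into a single signature. A doctest-style sketch of the behavioral difference (illustrative only, mirroring the docstring added to recipes.py):

    >>> from more_itertools import grouper
    >>> list(grouper('ABCDEFG', 3))
    [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', None, None)]
    >>> list(grouper('ABCDEFG', 3, incomplete='ignore'))
    [('A', 'B', 'C'), ('D', 'E', 'F')]
    >>> list(grouper('ABCDEFG', 3, incomplete='strict'))  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    UnequalIterablesError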
+def convolve(signal: Iterable[_T], kernel: Iterable[_T]) -> Iterator[_T]: ... +def before_and_after( + predicate: Callable[[_T], bool], it: Iterable[_T] +) -> Tuple[Iterator[_T], Iterator[_T]]: ... +def triplewise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T, _T]]: ... +def sliding_window( + iterable: Iterable[_T], n: int +) -> Iterator[Tuple[_T, ...]]: ... +def subslices(iterable: Iterable[_T]) -> Iterator[List[_T]]: ... +def polynomial_from_roots(roots: Sequence[int]) -> List[int]: ... +def sieve(n: int) -> Iterator[int]: ... +def batched( + iterable: Iterable[_T], + n: int, +) -> Iterator[List[_T]]: ... diff --git a/libs/win/more_itertools/tests/test_more.py b/libs/win/more_itertools/tests/test_more.py deleted file mode 100644 index a1b1e431..00000000 --- a/libs/win/more_itertools/tests/test_more.py +++ /dev/null @@ -1,2074 +0,0 @@ -from __future__ import division, print_function, unicode_literals - -from collections import OrderedDict -from decimal import Decimal -from doctest import DocTestSuite -from fractions import Fraction -from functools import partial, reduce -from heapq import merge -from io import StringIO -from itertools import ( - chain, - count, - groupby, - islice, - permutations, - product, - repeat, -) -from operator import add, mul, itemgetter -from unittest import TestCase - -from six.moves import filter, map, range, zip - -import more_itertools as mi - - -def load_tests(loader, tests, ignore): - # Add the doctests - tests.addTests(DocTestSuite('more_itertools.more')) - return tests - - -class CollateTests(TestCase): - """Unit tests for ``collate()``""" - # Also accidentally tests peekable, though that could use its own tests - - def test_default(self): - """Test with the default `key` function.""" - iterables = [range(4), range(7), range(3, 6)] - self.assertEqual( - sorted(reduce(list.__add__, [list(it) for it in iterables])), - list(mi.collate(*iterables)) - ) - - def test_key(self): - """Test using a custom `key` function.""" - iterables = [range(5, 0, -1), range(4, 0, -1)] - actual = sorted( - reduce(list.__add__, [list(it) for it in iterables]), reverse=True - ) - expected = list(mi.collate(*iterables, key=lambda x: -x)) - self.assertEqual(actual, expected) - - def test_empty(self): - """Be nice if passed an empty list of iterables.""" - self.assertEqual([], list(mi.collate())) - - def test_one(self): - """Work when only 1 iterable is passed.""" - self.assertEqual([0, 1], list(mi.collate(range(2)))) - - def test_reverse(self): - """Test the `reverse` kwarg.""" - iterables = [range(4, 0, -1), range(7, 0, -1), range(3, 6, -1)] - - actual = sorted( - reduce(list.__add__, [list(it) for it in iterables]), reverse=True - ) - expected = list(mi.collate(*iterables, reverse=True)) - self.assertEqual(actual, expected) - - def test_alias(self): - self.assertNotEqual(merge.__doc__, mi.collate.__doc__) - self.assertNotEqual(partial.__doc__, mi.collate.__doc__) - - -class ChunkedTests(TestCase): - """Tests for ``chunked()``""" - - def test_even(self): - """Test when ``n`` divides evenly into the length of the iterable.""" - self.assertEqual( - list(mi.chunked('ABCDEF', 3)), [['A', 'B', 'C'], ['D', 'E', 'F']] - ) - - def test_odd(self): - """Test when ``n`` does not divide evenly into the length of the - iterable. 
- - """ - self.assertEqual( - list(mi.chunked('ABCDE', 3)), [['A', 'B', 'C'], ['D', 'E']] - ) - - -class FirstTests(TestCase): - """Tests for ``first()``""" - - def test_many(self): - """Test that it works on many-item iterables.""" - # Also try it on a generator expression to make sure it works on - # whatever those return, across Python versions. - self.assertEqual(mi.first(x for x in range(4)), 0) - - def test_one(self): - """Test that it doesn't raise StopIteration prematurely.""" - self.assertEqual(mi.first([3]), 3) - - def test_empty_stop_iteration(self): - """It should raise StopIteration for empty iterables.""" - self.assertRaises(ValueError, lambda: mi.first([])) - - def test_default(self): - """It should return the provided default arg for empty iterables.""" - self.assertEqual(mi.first([], 'boo'), 'boo') - - -class IterOnlyRange: - """User-defined iterable class which only support __iter__. - - It is not specified to inherit ``object``, so indexing on a instance will - raise an ``AttributeError`` rather than ``TypeError`` in Python 2. - - >>> r = IterOnlyRange(5) - >>> r[0] - AttributeError: IterOnlyRange instance has no attribute '__getitem__' - - Note: In Python 3, ``TypeError`` will be raised because ``object`` is - inherited implicitly by default. - - >>> r[0] - TypeError: 'IterOnlyRange' object does not support indexing - """ - def __init__(self, n): - """Set the length of the range.""" - self.n = n - - def __iter__(self): - """Works same as range().""" - return iter(range(self.n)) - - -class LastTests(TestCase): - """Tests for ``last()``""" - - def test_many_nonsliceable(self): - """Test that it works on many-item non-slice-able iterables.""" - # Also try it on a generator expression to make sure it works on - # whatever those return, across Python versions. - self.assertEqual(mi.last(x for x in range(4)), 3) - - def test_one_nonsliceable(self): - """Test that it doesn't raise StopIteration prematurely.""" - self.assertEqual(mi.last(x for x in range(1)), 0) - - def test_empty_stop_iteration_nonsliceable(self): - """It should raise ValueError for empty non-slice-able iterables.""" - self.assertRaises(ValueError, lambda: mi.last(x for x in range(0))) - - def test_default_nonsliceable(self): - """It should return the provided default arg for empty non-slice-able - iterables. - """ - self.assertEqual(mi.last((x for x in range(0)), 'boo'), 'boo') - - def test_many_sliceable(self): - """Test that it works on many-item slice-able iterables.""" - self.assertEqual(mi.last([0, 1, 2, 3]), 3) - - def test_one_sliceable(self): - """Test that it doesn't raise StopIteration prematurely.""" - self.assertEqual(mi.last([3]), 3) - - def test_empty_stop_iteration_sliceable(self): - """It should raise ValueError for empty slice-able iterables.""" - self.assertRaises(ValueError, lambda: mi.last([])) - - def test_default_sliceable(self): - """It should return the provided default arg for empty slice-able - iterables. 
- """ - self.assertEqual(mi.last([], 'boo'), 'boo') - - def test_dict(self): - """last(dic) and last(dic.keys()) should return same result.""" - dic = {'a': 1, 'b': 2, 'c': 3} - self.assertEqual(mi.last(dic), mi.last(dic.keys())) - - def test_ordereddict(self): - """last(dic) should return the last key.""" - od = OrderedDict() - od['a'] = 1 - od['b'] = 2 - od['c'] = 3 - self.assertEqual(mi.last(od), 'c') - - def test_customrange(self): - """It should work on custom class where [] raises AttributeError.""" - self.assertEqual(mi.last(IterOnlyRange(5)), 4) - - -class PeekableTests(TestCase): - """Tests for ``peekable()`` behavor not incidentally covered by testing - ``collate()`` - - """ - def test_peek_default(self): - """Make sure passing a default into ``peek()`` works.""" - p = mi.peekable([]) - self.assertEqual(p.peek(7), 7) - - def test_truthiness(self): - """Make sure a ``peekable`` tests true iff there are items remaining in - the iterable. - - """ - p = mi.peekable([]) - self.assertFalse(p) - - p = mi.peekable(range(3)) - self.assertTrue(p) - - def test_simple_peeking(self): - """Make sure ``next`` and ``peek`` advance and don't advance the - iterator, respectively. - - """ - p = mi.peekable(range(10)) - self.assertEqual(next(p), 0) - self.assertEqual(p.peek(), 1) - self.assertEqual(next(p), 1) - - def test_indexing(self): - """ - Indexing into the peekable shouldn't advance the iterator. - """ - p = mi.peekable('abcdefghijkl') - - # The 0th index is what ``next()`` will return - self.assertEqual(p[0], 'a') - self.assertEqual(next(p), 'a') - - # Indexing further into the peekable shouldn't advance the itertor - self.assertEqual(p[2], 'd') - self.assertEqual(next(p), 'b') - - # The 0th index moves up with the iterator; the last index follows - self.assertEqual(p[0], 'c') - self.assertEqual(p[9], 'l') - - self.assertEqual(next(p), 'c') - self.assertEqual(p[8], 'l') - - # Negative indexing should work too - self.assertEqual(p[-2], 'k') - self.assertEqual(p[-9], 'd') - self.assertRaises(IndexError, lambda: p[-10]) - - def test_slicing(self): - """Slicing the peekable shouldn't advance the iterator.""" - seq = list('abcdefghijkl') - p = mi.peekable(seq) - - # Slicing the peekable should just be like slicing a re-iterable - self.assertEqual(p[1:4], seq[1:4]) - - # Advancing the iterator moves the slices up also - self.assertEqual(next(p), 'a') - self.assertEqual(p[1:4], seq[1:][1:4]) - - # Implicit starts and stop should work - self.assertEqual(p[:5], seq[1:][:5]) - self.assertEqual(p[:], seq[1:][:]) - - # Indexing past the end should work - self.assertEqual(p[:100], seq[1:][:100]) - - # Steps should work, including negative - self.assertEqual(p[::2], seq[1:][::2]) - self.assertEqual(p[::-1], seq[1:][::-1]) - - def test_slicing_reset(self): - """Test slicing on a fresh iterable each time""" - iterable = ['0', '1', '2', '3', '4', '5'] - indexes = list(range(-4, len(iterable) + 4)) + [None] - steps = [1, 2, 3, 4, -1, -2, -3, 4] - for slice_args in product(indexes, indexes, steps): - it = iter(iterable) - p = mi.peekable(it) - next(p) - index = slice(*slice_args) - actual = p[index] - expected = iterable[1:][index] - self.assertEqual(actual, expected, slice_args) - - def test_slicing_error(self): - iterable = '01234567' - p = mi.peekable(iter(iterable)) - - # Prime the cache - p.peek() - old_cache = list(p._cache) - - # Illegal slice - with self.assertRaises(ValueError): - p[1:-1:0] - - # Neither the cache nor the iteration should be affected - self.assertEqual(old_cache, list(p._cache)) - 
self.assertEqual(list(p), list(iterable)) - - def test_passthrough(self): - """Iterating a peekable without using ``peek()`` or ``prepend()`` - should just give the underlying iterable's elements (a trivial test but - useful to set a baseline in case something goes wrong)""" - expected = [1, 2, 3, 4, 5] - actual = list(mi.peekable(expected)) - self.assertEqual(actual, expected) - - # prepend() behavior tests - - def test_prepend(self): - """Tests intersperesed ``prepend()`` and ``next()`` calls""" - it = mi.peekable(range(2)) - actual = [] - - # Test prepend() before next() - it.prepend(10) - actual += [next(it), next(it)] - - # Test prepend() between next()s - it.prepend(11) - actual += [next(it), next(it)] - - # Test prepend() after source iterable is consumed - it.prepend(12) - actual += [next(it)] - - expected = [10, 0, 11, 1, 12] - self.assertEqual(actual, expected) - - def test_multi_prepend(self): - """Tests prepending multiple items and getting them in proper order""" - it = mi.peekable(range(5)) - actual = [next(it), next(it)] - it.prepend(10, 11, 12) - it.prepend(20, 21) - actual += list(it) - expected = [0, 1, 20, 21, 10, 11, 12, 2, 3, 4] - self.assertEqual(actual, expected) - - def test_empty(self): - """Tests prepending in front of an empty iterable""" - it = mi.peekable([]) - it.prepend(10) - actual = list(it) - expected = [10] - self.assertEqual(actual, expected) - - def test_prepend_truthiness(self): - """Tests that ``__bool__()`` or ``__nonzero__()`` works properly - with ``prepend()``""" - it = mi.peekable(range(5)) - self.assertTrue(it) - actual = list(it) - self.assertFalse(it) - it.prepend(10) - self.assertTrue(it) - actual += [next(it)] - self.assertFalse(it) - expected = [0, 1, 2, 3, 4, 10] - self.assertEqual(actual, expected) - - def test_multi_prepend_peek(self): - """Tests prepending multiple elements and getting them in reverse order - while peeking""" - it = mi.peekable(range(5)) - actual = [next(it), next(it)] - self.assertEqual(it.peek(), 2) - it.prepend(10, 11, 12) - self.assertEqual(it.peek(), 10) - it.prepend(20, 21) - self.assertEqual(it.peek(), 20) - actual += list(it) - self.assertFalse(it) - expected = [0, 1, 20, 21, 10, 11, 12, 2, 3, 4] - self.assertEqual(actual, expected) - - def test_prepend_after_stop(self): - """Test resuming iteration after a previous exhaustion""" - it = mi.peekable(range(3)) - self.assertEqual(list(it), [0, 1, 2]) - self.assertRaises(StopIteration, lambda: next(it)) - it.prepend(10) - self.assertEqual(next(it), 10) - self.assertRaises(StopIteration, lambda: next(it)) - - def test_prepend_slicing(self): - """Tests interaction between prepending and slicing""" - seq = list(range(20)) - p = mi.peekable(seq) - - p.prepend(30, 40, 50) - pseq = [30, 40, 50] + seq # pseq for prepended_seq - - # adapt the specific tests from test_slicing - self.assertEqual(p[0], 30) - self.assertEqual(p[1:8], pseq[1:8]) - self.assertEqual(p[1:], pseq[1:]) - self.assertEqual(p[:5], pseq[:5]) - self.assertEqual(p[:], pseq[:]) - self.assertEqual(p[:100], pseq[:100]) - self.assertEqual(p[::2], pseq[::2]) - self.assertEqual(p[::-1], pseq[::-1]) - - def test_prepend_indexing(self): - """Tests interaction between prepending and indexing""" - seq = list(range(20)) - p = mi.peekable(seq) - - p.prepend(30, 40, 50) - - self.assertEqual(p[0], 30) - self.assertEqual(next(p), 30) - self.assertEqual(p[2], 0) - self.assertEqual(next(p), 40) - self.assertEqual(p[0], 50) - self.assertEqual(p[9], 8) - self.assertEqual(next(p), 50) - self.assertEqual(p[8], 8) - 
self.assertEqual(p[-2], 18) - self.assertEqual(p[-9], 11) - self.assertRaises(IndexError, lambda: p[-21]) - - def test_prepend_iterable(self): - """Tests prepending from an iterable""" - it = mi.peekable(range(5)) - # Don't directly use the range() object to avoid any range-specific - # optimizations - it.prepend(*(x for x in range(5))) - actual = list(it) - expected = list(chain(range(5), range(5))) - self.assertEqual(actual, expected) - - def test_prepend_many(self): - """Tests that prepending a huge number of elements works""" - it = mi.peekable(range(5)) - # Don't directly use the range() object to avoid any range-specific - # optimizations - it.prepend(*(x for x in range(20000))) - actual = list(it) - expected = list(chain(range(20000), range(5))) - self.assertEqual(actual, expected) - - def test_prepend_reversed(self): - """Tests prepending from a reversed iterable""" - it = mi.peekable(range(3)) - it.prepend(*reversed((10, 11, 12))) - actual = list(it) - expected = [12, 11, 10, 0, 1, 2] - self.assertEqual(actual, expected) - - -class ConsumerTests(TestCase): - """Tests for ``consumer()``""" - - def test_consumer(self): - @mi.consumer - def eater(): - while True: - x = yield # noqa - - e = eater() - e.send('hi') # without @consumer, would raise TypeError - - -class DistinctPermutationsTests(TestCase): - def test_distinct_permutations(self): - """Make sure the output for ``distinct_permutations()`` is the same as - set(permutations(it)). - - """ - iterable = ['z', 'a', 'a', 'q', 'q', 'q', 'y'] - test_output = sorted(mi.distinct_permutations(iterable)) - ref_output = sorted(set(permutations(iterable))) - self.assertEqual(test_output, ref_output) - - def test_other_iterables(self): - """Make sure ``distinct_permutations()`` accepts a different type of - iterables. 
- - """ - # a generator - iterable = (c for c in ['z', 'a', 'a', 'q', 'q', 'q', 'y']) - test_output = sorted(mi.distinct_permutations(iterable)) - # "reload" it - iterable = (c for c in ['z', 'a', 'a', 'q', 'q', 'q', 'y']) - ref_output = sorted(set(permutations(iterable))) - self.assertEqual(test_output, ref_output) - - # an iterator - iterable = iter(['z', 'a', 'a', 'q', 'q', 'q', 'y']) - test_output = sorted(mi.distinct_permutations(iterable)) - # "reload" it - iterable = iter(['z', 'a', 'a', 'q', 'q', 'q', 'y']) - ref_output = sorted(set(permutations(iterable))) - self.assertEqual(test_output, ref_output) - - -class IlenTests(TestCase): - def test_ilen(self): - """Sanity-checks for ``ilen()``.""" - # Non-empty - self.assertEqual( - mi.ilen(filter(lambda x: x % 10 == 0, range(101))), 11 - ) - - # Empty - self.assertEqual(mi.ilen((x for x in range(0))), 0) - - # Iterable with __len__ - self.assertEqual(mi.ilen(list(range(6))), 6) - - -class WithIterTests(TestCase): - def test_with_iter(self): - s = StringIO('One fish\nTwo fish') - initial_words = [line.split()[0] for line in mi.with_iter(s)] - - # Iterable's items should be faithfully represented - self.assertEqual(initial_words, ['One', 'Two']) - # The file object should be closed - self.assertEqual(s.closed, True) - - -class OneTests(TestCase): - def test_basic(self): - it = iter(['item']) - self.assertEqual(mi.one(it), 'item') - - def test_too_short(self): - it = iter([]) - self.assertRaises(ValueError, lambda: mi.one(it)) - self.assertRaises(IndexError, lambda: mi.one(it, too_short=IndexError)) - - def test_too_long(self): - it = count() - self.assertRaises(ValueError, lambda: mi.one(it)) # burn 0 and 1 - self.assertEqual(next(it), 2) - self.assertRaises( - OverflowError, lambda: mi.one(it, too_long=OverflowError) - ) - - -class IntersperseTest(TestCase): - """ Tests for intersperse() """ - - def test_even(self): - iterable = (x for x in '01') - self.assertEqual( - list(mi.intersperse(None, iterable)), ['0', None, '1'] - ) - - def test_odd(self): - iterable = (x for x in '012') - self.assertEqual( - list(mi.intersperse(None, iterable)), ['0', None, '1', None, '2'] - ) - - def test_nested(self): - element = ('a', 'b') - iterable = (x for x in '012') - actual = list(mi.intersperse(element, iterable)) - expected = ['0', ('a', 'b'), '1', ('a', 'b'), '2'] - self.assertEqual(actual, expected) - - def test_not_iterable(self): - self.assertRaises(TypeError, lambda: mi.intersperse('x', 1)) - - def test_n(self): - for n, element, expected in [ - (1, '_', ['0', '_', '1', '_', '2', '_', '3', '_', '4', '_', '5']), - (2, '_', ['0', '1', '_', '2', '3', '_', '4', '5']), - (3, '_', ['0', '1', '2', '_', '3', '4', '5']), - (4, '_', ['0', '1', '2', '3', '_', '4', '5']), - (5, '_', ['0', '1', '2', '3', '4', '_', '5']), - (6, '_', ['0', '1', '2', '3', '4', '5']), - (7, '_', ['0', '1', '2', '3', '4', '5']), - (3, ['a', 'b'], ['0', '1', '2', ['a', 'b'], '3', '4', '5']), - ]: - iterable = (x for x in '012345') - actual = list(mi.intersperse(element, iterable, n=n)) - self.assertEqual(actual, expected) - - def test_n_zero(self): - self.assertRaises( - ValueError, lambda: list(mi.intersperse('x', '012', n=0)) - ) - - -class UniqueToEachTests(TestCase): - """Tests for ``unique_to_each()``""" - - def test_all_unique(self): - """When all the input iterables are unique the output should match - the input.""" - iterables = [[1, 2], [3, 4, 5], [6, 7, 8]] - self.assertEqual(mi.unique_to_each(*iterables), iterables) - - def test_duplicates(self): - """When there are 
duplicates in any of the input iterables that aren't - in the rest, those duplicates should be emitted.""" - iterables = ["mississippi", "missouri"] - self.assertEqual( - mi.unique_to_each(*iterables), [['p', 'p'], ['o', 'u', 'r']] - ) - - def test_mixed(self): - """When the input iterables contain different types the function should - still behave properly""" - iterables = ['x', (i for i in range(3)), [1, 2, 3], tuple()] - self.assertEqual(mi.unique_to_each(*iterables), [['x'], [0], [3], []]) - - -class WindowedTests(TestCase): - """Tests for ``windowed()``""" - - def test_basic(self): - actual = list(mi.windowed([1, 2, 3, 4, 5], 3)) - expected = [(1, 2, 3), (2, 3, 4), (3, 4, 5)] - self.assertEqual(actual, expected) - - def test_large_size(self): - """ - When the window size is larger than the iterable, and no fill value is - given,``None`` should be filled in. - """ - actual = list(mi.windowed([1, 2, 3, 4, 5], 6)) - expected = [(1, 2, 3, 4, 5, None)] - self.assertEqual(actual, expected) - - def test_fillvalue(self): - """ - When sizes don't match evenly, the given fill value should be used. - """ - iterable = [1, 2, 3, 4, 5] - - for n, kwargs, expected in [ - (6, {}, [(1, 2, 3, 4, 5, '!')]), # n > len(iterable) - (3, {'step': 3}, [(1, 2, 3), (4, 5, '!')]), # using ``step`` - ]: - actual = list(mi.windowed(iterable, n, fillvalue='!', **kwargs)) - self.assertEqual(actual, expected) - - def test_zero(self): - """When the window size is zero, an empty tuple should be emitted.""" - actual = list(mi.windowed([1, 2, 3, 4, 5], 0)) - expected = [tuple()] - self.assertEqual(actual, expected) - - def test_negative(self): - """When the window size is negative, ValueError should be raised.""" - with self.assertRaises(ValueError): - list(mi.windowed([1, 2, 3, 4, 5], -1)) - - def test_step(self): - """The window should advance by the number of steps provided""" - iterable = [1, 2, 3, 4, 5, 6, 7] - for n, step, expected in [ - (3, 2, [(1, 2, 3), (3, 4, 5), (5, 6, 7)]), # n > step - (3, 3, [(1, 2, 3), (4, 5, 6), (7, None, None)]), # n == step - (3, 4, [(1, 2, 3), (5, 6, 7)]), # line up nicely - (3, 5, [(1, 2, 3), (6, 7, None)]), # off by one - (3, 6, [(1, 2, 3), (7, None, None)]), # off by two - (3, 7, [(1, 2, 3)]), # step past the end - (7, 8, [(1, 2, 3, 4, 5, 6, 7)]), # step > len(iterable) - ]: - actual = list(mi.windowed(iterable, n, step=step)) - self.assertEqual(actual, expected) - - # Step must be greater than or equal to 1 - with self.assertRaises(ValueError): - list(mi.windowed(iterable, 3, step=0)) - - -class BucketTests(TestCase): - """Tests for ``bucket()``""" - - def test_basic(self): - iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33] - D = mi.bucket(iterable, key=lambda x: 10 * (x // 10)) - - # In-order access - self.assertEqual(list(D[10]), [10, 11, 12]) - - # Out of order access - self.assertEqual(list(D[30]), [30, 31, 33]) - self.assertEqual(list(D[20]), [20, 21, 22, 23]) - - self.assertEqual(list(D[40]), []) # Nothing in here! 
- - def test_in(self): - iterable = [10, 20, 30, 11, 21, 31, 12, 22, 23, 33] - D = mi.bucket(iterable, key=lambda x: 10 * (x // 10)) - - self.assertTrue(10 in D) - self.assertFalse(40 in D) - self.assertTrue(20 in D) - self.assertFalse(21 in D) - - # Checking in-ness shouldn't advance the iterator - self.assertEqual(next(D[10]), 10) - - def test_validator(self): - iterable = count(0) - key = lambda x: int(str(x)[0]) # First digit of each number - validator = lambda x: 0 < x < 10 # No leading zeros - D = mi.bucket(iterable, key, validator=validator) - self.assertEqual(mi.take(3, D[1]), [1, 10, 11]) - self.assertNotIn(0, D) # Non-valid entries don't return True - self.assertNotIn(0, D._cache) # Don't store non-valid entries - self.assertEqual(list(D[0]), []) - - -class SpyTests(TestCase): - """Tests for ``spy()``""" - - def test_basic(self): - original_iterable = iter('abcdefg') - head, new_iterable = mi.spy(original_iterable) - self.assertEqual(head, ['a']) - self.assertEqual( - list(new_iterable), ['a', 'b', 'c', 'd', 'e', 'f', 'g'] - ) - - def test_unpacking(self): - original_iterable = iter('abcdefg') - (first, second, third), new_iterable = mi.spy(original_iterable, 3) - self.assertEqual(first, 'a') - self.assertEqual(second, 'b') - self.assertEqual(third, 'c') - self.assertEqual( - list(new_iterable), ['a', 'b', 'c', 'd', 'e', 'f', 'g'] - ) - - def test_too_many(self): - original_iterable = iter('abc') - head, new_iterable = mi.spy(original_iterable, 4) - self.assertEqual(head, ['a', 'b', 'c']) - self.assertEqual(list(new_iterable), ['a', 'b', 'c']) - - def test_zero(self): - original_iterable = iter('abc') - head, new_iterable = mi.spy(original_iterable, 0) - self.assertEqual(head, []) - self.assertEqual(list(new_iterable), ['a', 'b', 'c']) - - -class InterleaveTests(TestCase): - def test_even(self): - actual = list(mi.interleave([1, 4, 7], [2, 5, 8], [3, 6, 9])) - expected = [1, 2, 3, 4, 5, 6, 7, 8, 9] - self.assertEqual(actual, expected) - - def test_short(self): - actual = list(mi.interleave([1, 4], [2, 5, 7], [3, 6, 8])) - expected = [1, 2, 3, 4, 5, 6] - self.assertEqual(actual, expected) - - def test_mixed_types(self): - it_list = ['a', 'b', 'c', 'd'] - it_str = '12345' - it_inf = count() - actual = list(mi.interleave(it_list, it_str, it_inf)) - expected = ['a', '1', 0, 'b', '2', 1, 'c', '3', 2, 'd', '4', 3] - self.assertEqual(actual, expected) - - -class InterleaveLongestTests(TestCase): - def test_even(self): - actual = list(mi.interleave_longest([1, 4, 7], [2, 5, 8], [3, 6, 9])) - expected = [1, 2, 3, 4, 5, 6, 7, 8, 9] - self.assertEqual(actual, expected) - - def test_short(self): - actual = list(mi.interleave_longest([1, 4], [2, 5, 7], [3, 6, 8])) - expected = [1, 2, 3, 4, 5, 6, 7, 8] - self.assertEqual(actual, expected) - - def test_mixed_types(self): - it_list = ['a', 'b', 'c', 'd'] - it_str = '12345' - it_gen = (x for x in range(3)) - actual = list(mi.interleave_longest(it_list, it_str, it_gen)) - expected = ['a', '1', 0, 'b', '2', 1, 'c', '3', 2, 'd', '4', '5'] - self.assertEqual(actual, expected) - - -class TestCollapse(TestCase): - """Tests for ``collapse()``""" - - def test_collapse(self): - l = [[1], 2, [[3], 4], [[[5]]]] - self.assertEqual(list(mi.collapse(l)), [1, 2, 3, 4, 5]) - - def test_collapse_to_string(self): - l = [["s1"], "s2", [["s3"], "s4"], [[["s5"]]]] - self.assertEqual(list(mi.collapse(l)), ["s1", "s2", "s3", "s4", "s5"]) - - def test_collapse_flatten(self): - l = [[1], [2], [[3], 4], [[[5]]]] - self.assertEqual(list(mi.collapse(l, levels=1)), 
list(mi.flatten(l))) - - def test_collapse_to_level(self): - l = [[1], 2, [[3], 4], [[[5]]]] - self.assertEqual(list(mi.collapse(l, levels=2)), [1, 2, 3, 4, [5]]) - self.assertEqual( - list(mi.collapse(mi.collapse(l, levels=1), levels=1)), - list(mi.collapse(l, levels=2)) - ) - - def test_collapse_to_list(self): - l = (1, [2], (3, [4, (5,)], 'ab')) - actual = list(mi.collapse(l, base_type=list)) - expected = [1, [2], 3, [4, (5,)], 'ab'] - self.assertEqual(actual, expected) - - -class SideEffectTests(TestCase): - """Tests for ``side_effect()``""" - - def test_individual(self): - # The function increments the counter for each call - counter = [0] - - def func(arg): - counter[0] += 1 - - result = list(mi.side_effect(func, range(10))) - self.assertEqual(result, list(range(10))) - self.assertEqual(counter[0], 10) - - def test_chunked(self): - # The function increments the counter for each call - counter = [0] - - def func(arg): - counter[0] += 1 - - result = list(mi.side_effect(func, range(10), 2)) - self.assertEqual(result, list(range(10))) - self.assertEqual(counter[0], 5) - - def test_before_after(self): - f = StringIO() - collector = [] - - def func(item): - print(item, file=f) - collector.append(f.getvalue()) - - def it(): - yield u'a' - yield u'b' - raise RuntimeError('kaboom') - - before = lambda: print('HEADER', file=f) - after = f.close - - try: - mi.consume(mi.side_effect(func, it(), before=before, after=after)) - except RuntimeError: - pass - - # The iterable should have been written to the file - self.assertEqual(collector, [u'HEADER\na\n', u'HEADER\na\nb\n']) - - # The file should be closed even though something bad happened - self.assertTrue(f.closed) - - def test_before_fails(self): - f = StringIO() - func = lambda x: print(x, file=f) - - def before(): - raise RuntimeError('ouch') - - try: - mi.consume( - mi.side_effect(func, u'abc', before=before, after=f.close) - ) - except RuntimeError: - pass - - # The file should be closed even though something bad happened in the - # before function - self.assertTrue(f.closed) - - -class SlicedTests(TestCase): - """Tests for ``sliced()``""" - - def test_even(self): - """Test when the length of the sequence is divisible by *n*""" - seq = 'ABCDEFGHI' - self.assertEqual(list(mi.sliced(seq, 3)), ['ABC', 'DEF', 'GHI']) - - def test_odd(self): - """Test when the length of the sequence is not divisible by *n*""" - seq = 'ABCDEFGHI' - self.assertEqual(list(mi.sliced(seq, 4)), ['ABCD', 'EFGH', 'I']) - - def test_not_sliceable(self): - seq = (x for x in 'ABCDEFGHI') - - with self.assertRaises(TypeError): - list(mi.sliced(seq, 3)) - - -class SplitAtTests(TestCase): - """Tests for ``split_at()``""" - - def comp_with_str_split(self, str_to_split, delim): - pred = lambda c: c == delim - actual = list(map(''.join, mi.split_at(str_to_split, pred))) - expected = str_to_split.split(delim) - self.assertEqual(actual, expected) - - def test_separators(self): - test_strs = ['', 'abcba', 'aaabbbcccddd', 'e'] - for s, delim in product(test_strs, 'abcd'): - self.comp_with_str_split(s, delim) - - -class SplitBeforeTest(TestCase): - """Tests for ``split_before()``""" - - def test_starts_with_sep(self): - actual = list(mi.split_before('xooxoo', lambda c: c == 'x')) - expected = [['x', 'o', 'o'], ['x', 'o', 'o']] - self.assertEqual(actual, expected) - - def test_ends_with_sep(self): - actual = list(mi.split_before('ooxoox', lambda c: c == 'x')) - expected = [['o', 'o'], ['x', 'o', 'o'], ['x']] - self.assertEqual(actual, expected) - - def test_no_sep(self): - actual =
list(mi.split_before('ooo', lambda c: c == 'x')) - expected = [['o', 'o', 'o']] - self.assertEqual(actual, expected) - - -class SplitAfterTest(TestCase): - """Tests for ``split_after()``""" - - def test_starts_with_sep(self): - actual = list(mi.split_after('xooxoo', lambda c: c == 'x')) - expected = [['x'], ['o', 'o', 'x'], ['o', 'o']] - self.assertEqual(actual, expected) - - def test_ends_with_sep(self): - actual = list(mi.split_after('ooxoox', lambda c: c == 'x')) - expected = [['o', 'o', 'x'], ['o', 'o', 'x']] - self.assertEqual(actual, expected) - - def test_no_sep(self): - actual = list(mi.split_after('ooo', lambda c: c == 'x')) - expected = [['o', 'o', 'o']] - self.assertEqual(actual, expected) - - -class PaddedTest(TestCase): - """Tests for ``padded()``""" - - def test_no_n(self): - seq = [1, 2, 3] - - # No fillvalue - self.assertEqual(mi.take(5, mi.padded(seq)), [1, 2, 3, None, None]) - - # With fillvalue - self.assertEqual( - mi.take(5, mi.padded(seq, fillvalue='')), [1, 2, 3, '', ''] - ) - - def test_invalid_n(self): - self.assertRaises(ValueError, lambda: list(mi.padded([1, 2, 3], n=-1))) - self.assertRaises(ValueError, lambda: list(mi.padded([1, 2, 3], n=0))) - - def test_valid_n(self): - seq = [1, 2, 3, 4, 5] - - # No need for padding: n <= len(seq) - self.assertEqual(list(mi.padded(seq, n=4)), [1, 2, 3, 4, 5]) - self.assertEqual(list(mi.padded(seq, n=5)), [1, 2, 3, 4, 5]) - - # No fillvalue - self.assertEqual( - list(mi.padded(seq, n=7)), [1, 2, 3, 4, 5, None, None] - ) - - # With fillvalue - self.assertEqual( - list(mi.padded(seq, fillvalue='', n=7)), [1, 2, 3, 4, 5, '', ''] - ) - - def test_next_multiple(self): - seq = [1, 2, 3, 4, 5, 6] - - # No need for padding: len(seq) % n == 0 - self.assertEqual( - list(mi.padded(seq, n=3, next_multiple=True)), [1, 2, 3, 4, 5, 6] - ) - - # Padding needed: len(seq) < n - self.assertEqual( - list(mi.padded(seq, n=8, next_multiple=True)), - [1, 2, 3, 4, 5, 6, None, None] - ) - - # No padding needed: len(seq) == n - self.assertEqual( - list(mi.padded(seq, n=6, next_multiple=True)), [1, 2, 3, 4, 5, 6] - ) - - # Padding needed: len(seq) > n - self.assertEqual( - list(mi.padded(seq, n=4, next_multiple=True)), - [1, 2, 3, 4, 5, 6, None, None] - ) - - # With fillvalue - self.assertEqual( - list(mi.padded(seq, fillvalue='', n=4, next_multiple=True)), - [1, 2, 3, 4, 5, 6, '', ''] - ) - - -class DistributeTest(TestCase): - """Tests for distribute()""" - - def test_invalid_n(self): - self.assertRaises(ValueError, lambda: mi.distribute(-1, [1, 2, 3])) - self.assertRaises(ValueError, lambda: mi.distribute(0, [1, 2, 3])) - - def test_basic(self): - iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - - for n, expected in [ - (1, [iterable]), - (2, [[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]), - (3, [[1, 4, 7, 10], [2, 5, 8], [3, 6, 9]]), - (10, [[n] for n in range(1, 10 + 1)]), - ]: - self.assertEqual( - [list(x) for x in mi.distribute(n, iterable)], expected - ) - - def test_large_n(self): - iterable = [1, 2, 3, 4] - self.assertEqual( - [list(x) for x in mi.distribute(6, iterable)], - [[1], [2], [3], [4], [], []] - ) - - -class StaggerTest(TestCase): - """Tests for ``stagger()``""" - - def test_default(self): - iterable = [0, 1, 2, 3] - actual = list(mi.stagger(iterable)) - expected = [(None, 0, 1), (0, 1, 2), (1, 2, 3)] - self.assertEqual(actual, expected) - - def test_offsets(self): - iterable = [0, 1, 2, 3] - for offsets, expected in [ - ((-2, 0, 2), [('', 0, 2), ('', 1, 3)]), - ((-2, -1), [('', ''), ('', 0), (0, 1), (1, 2), (2, 3)]), - ((1, 2), [(1, 2), (2,
3)]), - ]: - all_groups = mi.stagger(iterable, offsets=offsets, fillvalue='') - self.assertEqual(list(all_groups), expected) - - def test_longest(self): - iterable = [0, 1, 2, 3] - for offsets, expected in [ - ( - (-1, 0, 1), - [('', 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, ''), (3, '', '')] - ), - ((-2, -1), [('', ''), ('', 0), (0, 1), (1, 2), (2, 3), (3, '')]), - ((1, 2), [(1, 2), (2, 3), (3, '')]), - ]: - all_groups = mi.stagger( - iterable, offsets=offsets, fillvalue='', longest=True - ) - self.assertEqual(list(all_groups), expected) - - -class ZipOffsetTest(TestCase): - """Tests for ``zip_offset()``""" - - def test_shortest(self): - a_1 = [0, 1, 2, 3] - a_2 = [0, 1, 2, 3, 4, 5] - a_3 = [0, 1, 2, 3, 4, 5, 6, 7] - actual = list( - mi.zip_offset(a_1, a_2, a_3, offsets=(-1, 0, 1), fillvalue='') - ) - expected = [('', 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, 4), (3, 4, 5)] - self.assertEqual(actual, expected) - - def test_longest(self): - a_1 = [0, 1, 2, 3] - a_2 = [0, 1, 2, 3, 4, 5] - a_3 = [0, 1, 2, 3, 4, 5, 6, 7] - actual = list( - mi.zip_offset(a_1, a_2, a_3, offsets=(-1, 0, 1), longest=True) - ) - expected = [ - (None, 0, 1), - (0, 1, 2), - (1, 2, 3), - (2, 3, 4), - (3, 4, 5), - (None, 5, 6), - (None, None, 7), - ] - self.assertEqual(actual, expected) - - def test_mismatch(self): - iterables = [0, 1, 2], [2, 3, 4] - offsets = (-1, 0, 1) - self.assertRaises( - ValueError, - lambda: list(mi.zip_offset(*iterables, offsets=offsets)) - ) - - -class SortTogetherTest(TestCase): - """Tests for sort_together()""" - - def test_key_list(self): - """tests `key_list` including default, iterables include duplicates""" - iterables = [ - ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'], - ['May', 'Aug.', 'May', 'June', 'July', 'July'], - [97, 20, 100, 70, 100, 20] - ] - - self.assertEqual( - mi.sort_together(iterables), - [ - ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), - ('June', 'July', 'July', 'May', 'Aug.', 'May'), - (70, 100, 20, 97, 20, 100) - ] - ) - - self.assertEqual( - mi.sort_together(iterables, key_list=(0, 1)), - [ - ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), - ('July', 'July', 'June', 'Aug.', 'May', 'May'), - (100, 20, 70, 20, 97, 100) - ] - ) - - self.assertEqual( - mi.sort_together(iterables, key_list=(0, 1, 2)), - [ - ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), - ('July', 'July', 'June', 'Aug.', 'May', 'May'), - (20, 100, 70, 20, 97, 100) - ] - ) - - self.assertEqual( - mi.sort_together(iterables, key_list=(2,)), - [ - ('GA', 'CT', 'CT', 'GA', 'GA', 'CT'), - ('Aug.', 'July', 'June', 'May', 'May', 'July'), - (20, 20, 70, 97, 100, 100) - ] - ) - - def test_invalid_key_list(self): - """tests `key_list` for indexes not available in `iterables`""" - iterables = [ - ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'], - ['May', 'Aug.', 'May', 'June', 'July', 'July'], - [97, 20, 100, 70, 100, 20] - ] - - self.assertRaises( - IndexError, lambda: mi.sort_together(iterables, key_list=(5,)) - ) - - def test_reverse(self): - """tests `reverse` to ensure a reverse sort for `key_list` iterables""" - iterables = [ - ['GA', 'GA', 'GA', 'CT', 'CT', 'CT'], - ['May', 'Aug.', 'May', 'June', 'July', 'July'], - [97, 20, 100, 70, 100, 20] - ] - - self.assertEqual( - mi.sort_together(iterables, key_list=(0, 1, 2), reverse=True), - [('GA', 'GA', 'GA', 'CT', 'CT', 'CT'), - ('May', 'May', 'Aug.', 'June', 'July', 'July'), - (100, 97, 20, 70, 100, 20)] - ) - - def test_uneven_iterables(self): - """tests trimming of iterables to the shortest length before sorting""" - iterables = [['GA', 'GA', 'GA', 'CT', 'CT', 'CT', 'MA'], - ['May', 'Aug.', 'May', 'June', 'July', 
'July'], - [97, 20, 100, 70, 100, 20, 0]] - - self.assertEqual( - mi.sort_together(iterables), - [ - ('CT', 'CT', 'CT', 'GA', 'GA', 'GA'), - ('June', 'July', 'July', 'May', 'Aug.', 'May'), - (70, 100, 20, 97, 20, 100) - ] - ) - - -class DivideTest(TestCase): - """Tests for divide()""" - - def test_invalid_n(self): - self.assertRaises(ValueError, lambda: mi.divide(-1, [1, 2, 3])) - self.assertRaises(ValueError, lambda: mi.divide(0, [1, 2, 3])) - - def test_basic(self): - iterable = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - - for n, expected in [ - (1, [iterable]), - (2, [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]), - (3, [[1, 2, 3, 4], [5, 6, 7], [8, 9, 10]]), - (10, [[n] for n in range(1, 10 + 1)]), - ]: - self.assertEqual( - [list(x) for x in mi.divide(n, iterable)], expected - ) - - def test_large_n(self): - iterable = [1, 2, 3, 4] - self.assertEqual( - [list(x) for x in mi.divide(6, iterable)], - [[1], [2], [3], [4], [], []] - ) - - -class TestAlwaysIterable(TestCase): - """Tests for always_iterable()""" - def test_single(self): - self.assertEqual(list(mi.always_iterable(1)), [1]) - - def test_strings(self): - for obj in ['foo', b'bar', u'baz']: - actual = list(mi.always_iterable(obj)) - expected = [obj] - self.assertEqual(actual, expected) - - def test_base_type(self): - dict_obj = {'a': 1, 'b': 2} - str_obj = '123' - - # Default: dicts are iterable like they normally are - default_actual = list(mi.always_iterable(dict_obj)) - default_expected = list(dict_obj) - self.assertEqual(default_actual, default_expected) - - # Unitary types set: dicts are not iterable - custom_actual = list(mi.always_iterable(dict_obj, base_type=dict)) - custom_expected = [dict_obj] - self.assertEqual(custom_actual, custom_expected) - - # With unitary types set, strings are iterable - str_actual = list(mi.always_iterable(str_obj, base_type=None)) - str_expected = list(str_obj) - self.assertEqual(str_actual, str_expected) - - def test_iterables(self): - self.assertEqual(list(mi.always_iterable([0, 1])), [0, 1]) - self.assertEqual( - list(mi.always_iterable([0, 1], base_type=list)), [[0, 1]] - ) - self.assertEqual( - list(mi.always_iterable(iter('foo'))), ['f', 'o', 'o'] - ) - self.assertEqual(list(mi.always_iterable([])), []) - - def test_none(self): - self.assertEqual(list(mi.always_iterable(None)), []) - - def test_generator(self): - def _gen(): - yield 0 - yield 1 - - self.assertEqual(list(mi.always_iterable(_gen())), [0, 1]) - - -class AdjacentTests(TestCase): - def test_typical(self): - actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10))) - expected = [(True, 0), (True, 1), (False, 2), (False, 3), (True, 4), - (True, 5), (True, 6), (False, 7), (False, 8), (False, 9)] - self.assertEqual(actual, expected) - - def test_empty_iterable(self): - actual = list(mi.adjacent(lambda x: x % 5 == 0, [])) - expected = [] - self.assertEqual(actual, expected) - - def test_length_one(self): - actual = list(mi.adjacent(lambda x: x % 5 == 0, [0])) - expected = [(True, 0)] - self.assertEqual(actual, expected) - - actual = list(mi.adjacent(lambda x: x % 5 == 0, [1])) - expected = [(False, 1)] - self.assertEqual(actual, expected) - - def test_consecutive_true(self): - """Test that when the predicate matches multiple consecutive elements - it doesn't repeat elements in the output""" - actual = list(mi.adjacent(lambda x: x % 5 < 2, range(10))) - expected = [(True, 0), (True, 1), (True, 2), (False, 3), (True, 4), - (True, 5), (True, 6), (True, 7), (False, 8), (False, 9)] - self.assertEqual(actual, expected) - - def test_distance(self): - 
actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10), distance=2)) - expected = [(True, 0), (True, 1), (True, 2), (True, 3), (True, 4), - (True, 5), (True, 6), (True, 7), (False, 8), (False, 9)] - self.assertEqual(actual, expected) - - actual = list(mi.adjacent(lambda x: x % 5 == 0, range(10), distance=3)) - expected = [(True, 0), (True, 1), (True, 2), (True, 3), (True, 4), - (True, 5), (True, 6), (True, 7), (True, 8), (False, 9)] - self.assertEqual(actual, expected) - - def test_large_distance(self): - """Test distance larger than the length of the iterable""" - iterable = range(10) - actual = list(mi.adjacent(lambda x: x % 5 == 4, iterable, distance=20)) - expected = list(zip(repeat(True), iterable)) - self.assertEqual(actual, expected) - - actual = list(mi.adjacent(lambda x: False, iterable, distance=20)) - expected = list(zip(repeat(False), iterable)) - self.assertEqual(actual, expected) - - def test_zero_distance(self): - """Test that adjacent() reduces to zip+map when distance is 0""" - iterable = range(1000) - predicate = lambda x: x % 4 == 2 - actual = mi.adjacent(predicate, iterable, 0) - expected = zip(map(predicate, iterable), iterable) - self.assertTrue(all(a == e for a, e in zip(actual, expected))) - - def test_negative_distance(self): - """Test that adjacent() raises an error with negative distance""" - pred = lambda x: x - self.assertRaises( - ValueError, lambda: mi.adjacent(pred, range(1000), -1) - ) - self.assertRaises( - ValueError, lambda: mi.adjacent(pred, range(10), -10) - ) - - def test_grouping(self): - """Test interaction of adjacent() with groupby_transform()""" - iterable = mi.adjacent(lambda x: x % 5 == 0, range(10)) - grouper = mi.groupby_transform(iterable, itemgetter(0), itemgetter(1)) - actual = [(k, list(g)) for k, g in grouper] - expected = [ - (True, [0, 1]), - (False, [2, 3]), - (True, [4, 5, 6]), - (False, [7, 8, 9]), - ] - self.assertEqual(actual, expected) - - def test_call_once(self): - """Test that the predicate is only called once per item.""" - already_seen = set() - iterable = range(10) - - def predicate(item): - self.assertNotIn(item, already_seen) - already_seen.add(item) - return True - - actual = list(mi.adjacent(predicate, iterable)) - expected = [(True, x) for x in iterable] - self.assertEqual(actual, expected) - - -class GroupByTransformTests(TestCase): - def assertAllGroupsEqual(self, groupby1, groupby2): - """Compare two groupby objects for equality, both keys and groups.""" - for a, b in zip(groupby1, groupby2): - key1, group1 = a - key2, group2 = b - self.assertEqual(key1, key2) - self.assertListEqual(list(group1), list(group2)) - self.assertRaises(StopIteration, lambda: next(groupby1)) - self.assertRaises(StopIteration, lambda: next(groupby2)) - - def test_default_funcs(self): - """Test that groupby_transform() with default args mimics groupby()""" - iterable = [(x // 5, x) for x in range(1000)] - actual = mi.groupby_transform(iterable) - expected = groupby(iterable) - self.assertAllGroupsEqual(actual, expected) - - def test_valuefunc(self): - iterable = [(int(x / 5), int(x / 3), x) for x in range(10)] - - # Test the standard usage of grouping one iterable using another's keys - grouper = mi.groupby_transform( - iterable, keyfunc=itemgetter(0), valuefunc=itemgetter(-1) - ) - actual = [(k, list(g)) for k, g in grouper] - expected = [(0, [0, 1, 2, 3, 4]), (1, [5, 6, 7, 8, 9])] - self.assertEqual(actual, expected) - - grouper = mi.groupby_transform( - iterable, keyfunc=itemgetter(1), valuefunc=itemgetter(-1) - ) - actual = [(k, 
list(g)) for k, g in grouper] - expected = [(0, [0, 1, 2]), (1, [3, 4, 5]), (2, [6, 7, 8]), (3, [9])] - self.assertEqual(actual, expected) - - # and now for something a little different - d = dict(zip(range(10), 'abcdefghij')) - grouper = mi.groupby_transform( - range(10), keyfunc=lambda x: x // 5, valuefunc=d.get - ) - actual = [(k, ''.join(g)) for k, g in grouper] - expected = [(0, 'abcde'), (1, 'fghij')] - self.assertEqual(actual, expected) - - def test_no_valuefunc(self): - iterable = range(1000) - - def key(x): - return x // 5 - - actual = mi.groupby_transform(iterable, key, valuefunc=None) - expected = groupby(iterable, key) - self.assertAllGroupsEqual(actual, expected) - - actual = mi.groupby_transform(iterable, key) # default valuefunc - expected = groupby(iterable, key) - self.assertAllGroupsEqual(actual, expected) - - -class NumericRangeTests(TestCase): - def test_basic(self): - for args, expected in [ - ((4,), [0, 1, 2, 3]), - ((4.0,), [0.0, 1.0, 2.0, 3.0]), - ((1.0, 4), [1.0, 2.0, 3.0]), - ((1, 4.0), [1, 2, 3]), - ((1.0, 5), [1.0, 2.0, 3.0, 4.0]), - ((0, 20, 5), [0, 5, 10, 15]), - ((0, 20, 5.0), [0.0, 5.0, 10.0, 15.0]), - ((0, 10, 3), [0, 3, 6, 9]), - ((0, 10, 3.0), [0.0, 3.0, 6.0, 9.0]), - ((0, -5, -1), [0, -1, -2, -3, -4]), - ((0.0, -5, -1), [0.0, -1.0, -2.0, -3.0, -4.0]), - ((1, 2, Fraction(1, 2)), [Fraction(1, 1), Fraction(3, 2)]), - ((0,), []), - ((0.0,), []), - ((1, 0), []), - ((1.0, 0.0), []), - ((Fraction(2, 1),), [Fraction(0, 1), Fraction(1, 1)]), - ((Decimal('2.0'),), [Decimal('0.0'), Decimal('1.0')]), - ]: - actual = list(mi.numeric_range(*args)) - self.assertEqual(actual, expected) - self.assertTrue( - all(type(a) == type(e) for a, e in zip(actual, expected)) - ) - - def test_arg_count(self): - self.assertRaises(TypeError, lambda: list(mi.numeric_range())) - self.assertRaises( - TypeError, lambda: list(mi.numeric_range(0, 1, 2, 3)) - ) - - def test_zero_step(self): - self.assertRaises( - ValueError, lambda: list(mi.numeric_range(1, 2, 0)) - ) - - -class CountCycleTests(TestCase): - def test_basic(self): - expected = [ - (0, 'a'), (0, 'b'), (0, 'c'), - (1, 'a'), (1, 'b'), (1, 'c'), - (2, 'a'), (2, 'b'), (2, 'c'), - ] - for actual in [ - mi.take(9, mi.count_cycle('abc')), # n=None - list(mi.count_cycle('abc', 3)), # n=3 - ]: - self.assertEqual(actual, expected) - - def test_empty(self): - self.assertEqual(list(mi.count_cycle('')), []) - self.assertEqual(list(mi.count_cycle('', 2)), []) - - def test_negative(self): - self.assertEqual(list(mi.count_cycle('abc', -3)), []) - - -class LocateTests(TestCase): - def test_default_pred(self): - iterable = [0, 1, 1, 0, 1, 0, 0] - actual = list(mi.locate(iterable)) - expected = [1, 2, 4] - self.assertEqual(actual, expected) - - def test_no_matches(self): - iterable = [0, 0, 0] - actual = list(mi.locate(iterable)) - expected = [] - self.assertEqual(actual, expected) - - def test_custom_pred(self): - iterable = ['0', 1, 1, '0', 1, '0', '0'] - pred = lambda x: x == '0' - actual = list(mi.locate(iterable, pred)) - expected = [0, 3, 5, 6] - self.assertEqual(actual, expected) - - def test_window_size(self): - iterable = ['0', 1, 1, '0', 1, '0', '0'] - pred = lambda *args: args == ('0', 1) - actual = list(mi.locate(iterable, pred, window_size=2)) - expected = [0, 3] - self.assertEqual(actual, expected) - - def test_window_size_large(self): - iterable = [1, 2, 3, 4] - pred = lambda a, b, c, d, e: True - actual = list(mi.locate(iterable, pred, window_size=5)) - expected = [0] - self.assertEqual(actual, expected) - - def 
test_window_size_zero(self): - iterable = [1, 2, 3, 4] - pred = lambda: True - with self.assertRaises(ValueError): - list(mi.locate(iterable, pred, window_size=0)) - - -class StripFunctionTests(TestCase): - def test_hashable(self): - iterable = list('www.example.com') - pred = lambda x: x in set('cmowz.') - - self.assertEqual(list(mi.lstrip(iterable, pred)), list('example.com')) - self.assertEqual(list(mi.rstrip(iterable, pred)), list('www.example')) - self.assertEqual(list(mi.strip(iterable, pred)), list('example')) - - def test_not_hashable(self): - iterable = [ - list('http://'), list('www'), list('.example'), list('.com') - ] - pred = lambda x: x in [list('http://'), list('www'), list('.com')] - - self.assertEqual(list(mi.lstrip(iterable, pred)), iterable[2:]) - self.assertEqual(list(mi.rstrip(iterable, pred)), iterable[:3]) - self.assertEqual(list(mi.strip(iterable, pred)), iterable[2: 3]) - - def test_math(self): - iterable = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2] - pred = lambda x: x <= 2 - - self.assertEqual(list(mi.lstrip(iterable, pred)), iterable[3:]) - self.assertEqual(list(mi.rstrip(iterable, pred)), iterable[:-3]) - self.assertEqual(list(mi.strip(iterable, pred)), iterable[3:-3]) - - -class IsliceExtendedTests(TestCase): - def test_all(self): - iterable = ['0', '1', '2', '3', '4', '5'] - indexes = list(range(-4, len(iterable) + 4)) + [None] - steps = [1, 2, 3, 4, -1, -2, -3, -4] - for slice_args in product(indexes, indexes, steps): - try: - actual = list(mi.islice_extended(iterable, *slice_args)) - except Exception as e: - self.fail((slice_args, e)) - - expected = iterable[slice(*slice_args)] - self.assertEqual(actual, expected, slice_args) - - def test_zero_step(self): - with self.assertRaises(ValueError): - list(mi.islice_extended([1, 2, 3], 0, 1, 0)) - - -class ConsecutiveGroupsTest(TestCase): - def test_numbers(self): - iterable = [-10, -8, -7, -6, 1, 2, 4, 5, -1, 7] - actual = [list(g) for g in mi.consecutive_groups(iterable)] - expected = [[-10], [-8, -7, -6], [1, 2], [4, 5], [-1], [7]] - self.assertEqual(actual, expected) - - def test_custom_ordering(self): - iterable = ['1', '10', '11', '20', '21', '22', '30', '31'] - ordering = lambda x: int(x) - actual = [list(g) for g in mi.consecutive_groups(iterable, ordering)] - expected = [['1'], ['10', '11'], ['20', '21', '22'], ['30', '31']] - self.assertEqual(actual, expected) - - def test_exotic_ordering(self): - iterable = [ - ('a', 'b', 'c', 'd'), - ('a', 'c', 'b', 'd'), - ('a', 'c', 'd', 'b'), - ('a', 'd', 'b', 'c'), - ('d', 'b', 'c', 'a'), - ('d', 'c', 'a', 'b'), - ] - ordering = list(permutations('abcd')).index - actual = [list(g) for g in mi.consecutive_groups(iterable, ordering)] - expected = [ - [('a', 'b', 'c', 'd')], - [('a', 'c', 'b', 'd'), ('a', 'c', 'd', 'b'), ('a', 'd', 'b', 'c')], - [('d', 'b', 'c', 'a'), ('d', 'c', 'a', 'b')], - ] - self.assertEqual(actual, expected) - - -class DifferenceTest(TestCase): - def test_normal(self): - iterable = [10, 20, 30, 40, 50] - actual = list(mi.difference(iterable)) - expected = [10, 10, 10, 10, 10] - self.assertEqual(actual, expected) - - def test_custom(self): - iterable = [10, 20, 30, 40, 50] - actual = list(mi.difference(iterable, add)) - expected = [10, 30, 50, 70, 90] - self.assertEqual(actual, expected) - - def test_roundtrip(self): - original = list(range(100)) - accumulated = mi.accumulate(original) - actual = list(mi.difference(accumulated)) - self.assertEqual(actual, original) - - def test_one(self): - self.assertEqual(list(mi.difference([0])), [0]) - - def
test_empty(self): - self.assertEqual(list(mi.difference([])), []) - - -class SeekableTest(TestCase): - def test_exhaustion_reset(self): - iterable = [str(n) for n in range(10)] - - s = mi.seekable(iterable) - self.assertEqual(list(s), iterable) # Normal iteration - self.assertEqual(list(s), []) # Iterable is exhausted - - s.seek(0) - self.assertEqual(list(s), iterable) # Back in action - - def test_partial_reset(self): - iterable = [str(n) for n in range(10)] - - s = mi.seekable(iterable) - self.assertEqual(mi.take(5, s), iterable[:5]) # Normal iteration - - s.seek(1) - self.assertEqual(list(s), iterable[1:]) # Get the rest of the iterable - - def test_forward(self): - iterable = [str(n) for n in range(10)] - - s = mi.seekable(iterable) - self.assertEqual(mi.take(1, s), iterable[:1]) # Normal iteration - - s.seek(3) # Skip over index 2 - self.assertEqual(list(s), iterable[3:]) # Result is similar to slicing - - s.seek(0) # Back to 0 - self.assertEqual(list(s), iterable) # No difference in result - - def test_past_end(self): - iterable = [str(n) for n in range(10)] - - s = mi.seekable(iterable) - self.assertEqual(mi.take(1, s), iterable[:1]) # Normal iteration - - s.seek(20) - self.assertEqual(list(s), []) # Iterable is exhausted - - s.seek(0) # Back to 0 - self.assertEqual(list(s), iterable) # No difference in result - - def test_elements(self): - iterable = map(str, count()) - - s = mi.seekable(iterable) - mi.take(10, s) - - elements = s.elements() - self.assertEqual( - [elements[i] for i in range(10)], [str(n) for n in range(10)] - ) - self.assertEqual(len(elements), 10) - - mi.take(10, s) - self.assertEqual(list(elements), [str(n) for n in range(20)]) - - -class SequenceViewTests(TestCase): - def test_init(self): - view = mi.SequenceView((1, 2, 3)) - self.assertEqual(repr(view), "SequenceView((1, 2, 3))") - self.assertRaises(TypeError, lambda: mi.SequenceView({})) - - def test_update(self): - seq = [1, 2, 3] - view = mi.SequenceView(seq) - self.assertEqual(len(view), 3) - self.assertEqual(repr(view), "SequenceView([1, 2, 3])") - - seq.pop() - self.assertEqual(len(view), 2) - self.assertEqual(repr(view), "SequenceView([1, 2])") - - def test_indexing(self): - seq = ('a', 'b', 'c', 'd', 'e', 'f') - view = mi.SequenceView(seq) - for i in range(-len(seq), len(seq)): - self.assertEqual(view[i], seq[i]) - - def test_slicing(self): - seq = ('a', 'b', 'c', 'd', 'e', 'f') - view = mi.SequenceView(seq) - n = len(seq) - indexes = list(range(-n - 1, n + 1)) + [None] - steps = list(range(-n, n + 1)) - steps.remove(0) - for slice_args in product(indexes, indexes, steps): - i = slice(*slice_args) - self.assertEqual(view[i], seq[i]) - - def test_abc_methods(self): - # collections.abc.Sequence should provide all of this functionality - seq = ('a', 'b', 'c', 'd', 'e', 'f', 'f') - view = mi.SequenceView(seq) - - # __contains__ - self.assertIn('b', view) - self.assertNotIn('g', view) - - # __iter__ - self.assertEqual(list(iter(view)), list(seq)) - - # __reversed__ - self.assertEqual(list(reversed(view)), list(reversed(seq))) - - # index - self.assertEqual(view.index('b'), 1) - - # count - self.assertEqual(view.count('f'), 2) - - -class RunLengthTest(TestCase): - def test_encode(self): - iterable = (int(str(n)[0]) for n in count(800)) - actual = mi.take(4, mi.run_length.encode(iterable)) - expected = [(8, 100), (9, 100), (1, 1000), (2, 1000)] - self.assertEqual(actual, expected) - - def test_decode(self): - iterable = [('d', 4), ('c', 3), ('b', 2), ('a', 1)] - actual = ''.join(mi.run_length.decode(iterable)) -
expected = 'ddddcccbba' - self.assertEqual(actual, expected) - - -class ExactlyNTests(TestCase): - """Tests for ``exactly_n()``""" - - def test_true(self): - """Iterable has ``n`` ``True`` elements""" - self.assertTrue(mi.exactly_n([True, False, True], 2)) - self.assertTrue(mi.exactly_n([1, 1, 1, 0], 3)) - self.assertTrue(mi.exactly_n([False, False], 0)) - self.assertTrue(mi.exactly_n(range(100), 10, lambda x: x < 10)) - - def test_false(self): - """Iterable does not have ``n`` ``True`` elements""" - self.assertFalse(mi.exactly_n([True, False, False], 2)) - self.assertFalse(mi.exactly_n([True, True, False], 1)) - self.assertFalse(mi.exactly_n([False], 1)) - self.assertFalse(mi.exactly_n([True], -1)) - self.assertFalse(mi.exactly_n(repeat(True), 100)) - - def test_empty(self): - """Return ``True`` if the iterable is empty and ``n`` is 0""" - self.assertTrue(mi.exactly_n([], 0)) - self.assertFalse(mi.exactly_n([], 1)) - - -class AlwaysReversibleTests(TestCase): - """Tests for ``always_reversible()``""" - - def test_regular_reversed(self): - self.assertEqual(list(reversed(range(10))), - list(mi.always_reversible(range(10)))) - self.assertEqual(list(reversed([1, 2, 3])), - list(mi.always_reversible([1, 2, 3]))) - self.assertEqual(reversed([1, 2, 3]).__class__, - mi.always_reversible([1, 2, 3]).__class__) - - def test_nonseq_reversed(self): - # Create a non-reversible generator from a sequence - with self.assertRaises(TypeError): - reversed(x for x in range(10)) - - self.assertEqual(list(reversed(range(10))), - list(mi.always_reversible(x for x in range(10)))) - self.assertEqual(list(reversed([1, 2, 3])), - list(mi.always_reversible(x for x in [1, 2, 3]))) - self.assertNotEqual(reversed((1, 2)).__class__, - mi.always_reversible(x for x in (1, 2)).__class__) - - -class CircularShiftsTests(TestCase): - def test_empty(self): - # empty iterable -> empty list - self.assertEqual(list(mi.circular_shifts([])), []) - - def test_simple_circular_shifts(self): - # test a simple iterator case - self.assertEqual( - mi.circular_shifts(range(4)), - [(0, 1, 2, 3), (1, 2, 3, 0), (2, 3, 0, 1), (3, 0, 1, 2)] - ) - - def test_duplicates(self): - # test non-distinct entries - self.assertEqual( - mi.circular_shifts([0, 1, 0, 1]), - [(0, 1, 0, 1), (1, 0, 1, 0), (0, 1, 0, 1), (1, 0, 1, 0)] - ) - - -class MakeDecoratorTests(TestCase): - def test_basic(self): - slicer = mi.make_decorator(islice) - - @slicer(1, 10, 2) - def user_function(arg_1, arg_2, kwarg_1=None): - self.assertEqual(arg_1, 'arg_1') - self.assertEqual(arg_2, 'arg_2') - self.assertEqual(kwarg_1, 'kwarg_1') - return map(str, count()) - - it = user_function('arg_1', 'arg_2', kwarg_1='kwarg_1') - actual = list(it) - expected = ['1', '3', '5', '7', '9'] - self.assertEqual(actual, expected) - - def test_result_index(self): - def stringify(*args, **kwargs): - self.assertEqual(args[0], 'arg_0') - iterable = args[1] - self.assertEqual(args[2], 'arg_2') - self.assertEqual(kwargs['kwarg_1'], 'kwarg_1') - return map(str, iterable) - - stringifier = mi.make_decorator(stringify, result_index=1) - - @stringifier('arg_0', 'arg_2', kwarg_1='kwarg_1') - def user_function(n): - return count(n) - - it = user_function(1) - actual = mi.take(5, it) - expected = ['1', '2', '3', '4', '5'] - self.assertEqual(actual, expected) - - def test_wrap_class(self): - seeker = mi.make_decorator(mi.seekable) - - @seeker() - def user_function(n): - return map(str, range(n)) - - it = user_function(5) - self.assertEqual(list(it), ['0', '1', '2', '3', '4']) - - it.seek(0) -
self.assertEqual(list(it), ['0', '1', '2', '3', '4']) - - -class MapReduceTests(TestCase): - def test_default(self): - iterable = (str(x) for x in range(5)) - keyfunc = lambda x: int(x) // 2 - actual = sorted(mi.map_reduce(iterable, keyfunc).items()) - expected = [(0, ['0', '1']), (1, ['2', '3']), (2, ['4'])] - self.assertEqual(actual, expected) - - def test_valuefunc(self): - iterable = (str(x) for x in range(5)) - keyfunc = lambda x: int(x) // 2 - valuefunc = int - actual = sorted(mi.map_reduce(iterable, keyfunc, valuefunc).items()) - expected = [(0, [0, 1]), (1, [2, 3]), (2, [4])] - self.assertEqual(actual, expected) - - def test_reducefunc(self): - iterable = (str(x) for x in range(5)) - keyfunc = lambda x: int(x) // 2 - valuefunc = int - reducefunc = lambda value_list: reduce(mul, value_list, 1) - actual = sorted( - mi.map_reduce(iterable, keyfunc, valuefunc, reducefunc).items() - ) - expected = [(0, 0), (1, 6), (2, 4)] - self.assertEqual(actual, expected) - - def test_ret(self): - d = mi.map_reduce([1, 0, 2, 0, 1, 0], bool) - self.assertEqual(d, {False: [0, 0, 0], True: [1, 2, 1]}) - self.assertRaises(KeyError, lambda: d[None].append(1)) - - -class RlocateTests(TestCase): - def test_default_pred(self): - iterable = [0, 1, 1, 0, 1, 0, 0] - for it in (iterable[:], iter(iterable)): - actual = list(mi.rlocate(it)) - expected = [4, 2, 1] - self.assertEqual(actual, expected) - - def test_no_matches(self): - iterable = [0, 0, 0] - for it in (iterable[:], iter(iterable)): - actual = list(mi.rlocate(it)) - expected = [] - self.assertEqual(actual, expected) - - def test_custom_pred(self): - iterable = ['0', 1, 1, '0', 1, '0', '0'] - pred = lambda x: x == '0' - for it in (iterable[:], iter(iterable)): - actual = list(mi.rlocate(it, pred)) - expected = [6, 5, 3, 0] - self.assertEqual(actual, expected) - - def test_efficient_reversal(self): - iterable = range(10 ** 10) # Is efficiently reversible - target = 10 ** 10 - 2 - pred = lambda x: x == target # Find-able from the right - actual = next(mi.rlocate(iterable, pred)) - self.assertEqual(actual, target) - - def test_window_size(self): - iterable = ['0', 1, 1, '0', 1, '0', '0'] - pred = lambda *args: args == ('0', 1) - for it in (iterable, iter(iterable)): - actual = list(mi.rlocate(it, pred, window_size=2)) - expected = [3, 0] - self.assertEqual(actual, expected) - - def test_window_size_large(self): - iterable = [1, 2, 3, 4] - pred = lambda a, b, c, d, e: True - for it in (iterable, iter(iterable)): - actual = list(mi.rlocate(it, pred, window_size=5)) - expected = [0] - self.assertEqual(actual, expected) - - def test_window_size_zero(self): - iterable = [1, 2, 3, 4] - pred = lambda: True - for it in (iterable, iter(iterable)): - with self.assertRaises(ValueError): - list(mi.rlocate(it, pred, window_size=0)) - - -class ReplaceTests(TestCase): - def test_basic(self): - iterable = range(10) - pred = lambda x: x % 2 == 0 - substitutes = [] - actual = list(mi.replace(iterable, pred, substitutes)) - expected = [1, 3, 5, 7, 9] - self.assertEqual(actual, expected) - - def test_count(self): - iterable = range(10) - pred = lambda x: x % 2 == 0 - substitutes = [] - actual = list(mi.replace(iterable, pred, substitutes, count=4)) - expected = [1, 3, 5, 7, 8, 9] - self.assertEqual(actual, expected) - - def test_window_size(self): - iterable = range(10) - pred = lambda *args: args == (0, 1, 2) - substitutes = [] - actual = list(mi.replace(iterable, pred, substitutes, window_size=3)) - expected = [3, 4, 5, 6, 7, 8, 9] - self.assertEqual(actual,
expected) - - def test_window_size_end(self): - iterable = range(10) - pred = lambda *args: args == (7, 8, 9) - substitutes = [] - actual = list(mi.replace(iterable, pred, substitutes, window_size=3)) - expected = [0, 1, 2, 3, 4, 5, 6] - self.assertEqual(actual, expected) - - def test_window_size_count(self): - iterable = range(10) - pred = lambda *args: (args == (0, 1, 2)) or (args == (7, 8, 9)) - substitutes = [] - actual = list( - mi.replace(iterable, pred, substitutes, count=1, window_size=3) - ) - expected = [3, 4, 5, 6, 7, 8, 9] - self.assertEqual(actual, expected) - - def test_window_size_large(self): - iterable = range(4) - pred = lambda a, b, c, d, e: True - substitutes = [5, 6, 7] - actual = list(mi.replace(iterable, pred, substitutes, window_size=5)) - expected = [5, 6, 7] - self.assertEqual(actual, expected) - - def test_window_size_zero(self): - iterable = range(10) - pred = lambda *args: True - substitutes = [] - with self.assertRaises(ValueError): - list(mi.replace(iterable, pred, substitutes, window_size=0)) - - def test_iterable_substitutes(self): - iterable = range(5) - pred = lambda x: x % 2 == 0 - substitutes = iter('__') - actual = list(mi.replace(iterable, pred, substitutes)) - expected = ['_', '_', 1, '_', '_', 3, '_', '_'] - self.assertEqual(actual, expected) diff --git a/libs/win/more_itertools/tests/test_recipes.py b/libs/win/more_itertools/tests/test_recipes.py deleted file mode 100644 index 98981fe8..00000000 --- a/libs/win/more_itertools/tests/test_recipes.py +++ /dev/null @@ -1,616 +0,0 @@ -from doctest import DocTestSuite -from unittest import TestCase - -from itertools import combinations -from six.moves import range - -import more_itertools as mi - - -def load_tests(loader, tests, ignore): - # Add the doctests - tests.addTests(DocTestSuite('more_itertools.recipes')) - return tests - - -class AccumulateTests(TestCase): - """Tests for ``accumulate()``""" - - def test_empty(self): - """Test that an empty input returns an empty output""" - self.assertEqual(list(mi.accumulate([])), []) - - def test_default(self): - """Test accumulate with the default function (addition)""" - self.assertEqual(list(mi.accumulate([1, 2, 3])), [1, 3, 6]) - - def test_bogus_function(self): - """Test accumulate with an invalid function""" - with self.assertRaises(TypeError): - list(mi.accumulate([1, 2, 3], func=lambda x: x)) - - def test_custom_function(self): - """Test accumulate with a custom function""" - self.assertEqual( - list(mi.accumulate((1, 2, 3, 2, 1), func=max)), [1, 2, 3, 3, 3] - ) - - -class TakeTests(TestCase): - """Tests for ``take()``""" - - def test_simple_take(self): - """Test basic usage""" - t = mi.take(5, range(10)) - self.assertEqual(t, [0, 1, 2, 3, 4]) - - def test_null_take(self): - """Check the null case""" - t = mi.take(0, range(10)) - self.assertEqual(t, []) - - def test_negative_take(self): - """Make sure taking negative items results in a ValueError""" - self.assertRaises(ValueError, lambda: mi.take(-3, range(10))) - - def test_take_too_much(self): - """Taking more than an iterator has remaining should return what the - iterator has remaining. 
- - """ - t = mi.take(10, range(5)) - self.assertEqual(t, [0, 1, 2, 3, 4]) - - -class TabulateTests(TestCase): - """Tests for ``tabulate()``""" - - def test_simple_tabulate(self): - """Test the happy path""" - t = mi.tabulate(lambda x: x) - f = tuple([next(t) for _ in range(3)]) - self.assertEqual(f, (0, 1, 2)) - - def test_count(self): - """Ensure tabulate accepts specific count""" - t = mi.tabulate(lambda x: 2 * x, -1) - f = (next(t), next(t), next(t)) - self.assertEqual(f, (-2, 0, 2)) - - -class TailTests(TestCase): - """Tests for ``tail()``""" - - def test_greater(self): - """Length of iterable is greather than requested tail""" - self.assertEqual(list(mi.tail(3, 'ABCDEFG')), ['E', 'F', 'G']) - - def test_equal(self): - """Length of iterable is equal to the requested tail""" - self.assertEqual( - list(mi.tail(7, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G'] - ) - - def test_less(self): - """Length of iterable is less than requested tail""" - self.assertEqual( - list(mi.tail(8, 'ABCDEFG')), ['A', 'B', 'C', 'D', 'E', 'F', 'G'] - ) - - -class ConsumeTests(TestCase): - """Tests for ``consume()``""" - - def test_sanity(self): - """Test basic functionality""" - r = (x for x in range(10)) - mi.consume(r, 3) - self.assertEqual(3, next(r)) - - def test_null_consume(self): - """Check the null case""" - r = (x for x in range(10)) - mi.consume(r, 0) - self.assertEqual(0, next(r)) - - def test_negative_consume(self): - """Check that negative consumsion throws an error""" - r = (x for x in range(10)) - self.assertRaises(ValueError, lambda: mi.consume(r, -1)) - - def test_total_consume(self): - """Check that iterator is totally consumed by default""" - r = (x for x in range(10)) - mi.consume(r) - self.assertRaises(StopIteration, lambda: next(r)) - - -class NthTests(TestCase): - """Tests for ``nth()``""" - - def test_basic(self): - """Make sure the nth item is returned""" - l = range(10) - for i, v in enumerate(l): - self.assertEqual(mi.nth(l, i), v) - - def test_default(self): - """Ensure a default value is returned when nth item not found""" - l = range(3) - self.assertEqual(mi.nth(l, 100, "zebra"), "zebra") - - def test_negative_item_raises(self): - """Ensure asking for a negative item raises an exception""" - self.assertRaises(ValueError, lambda: mi.nth(range(10), -3)) - - -class AllEqualTests(TestCase): - """Tests for ``all_equal()``""" - - def test_true(self): - """Everything is equal""" - self.assertTrue(mi.all_equal('aaaaaa')) - self.assertTrue(mi.all_equal([0, 0, 0, 0])) - - def test_false(self): - """Not everything is equal""" - self.assertFalse(mi.all_equal('aaaaab')) - self.assertFalse(mi.all_equal([0, 0, 0, 1])) - - def test_tricky(self): - """Not everything is identical, but everything is equal""" - items = [1, complex(1, 0), 1.0] - self.assertTrue(mi.all_equal(items)) - - def test_empty(self): - """Return True if the iterable is empty""" - self.assertTrue(mi.all_equal('')) - self.assertTrue(mi.all_equal([])) - - def test_one(self): - """Return True if the iterable is singular""" - self.assertTrue(mi.all_equal('0')) - self.assertTrue(mi.all_equal([0])) - - -class QuantifyTests(TestCase): - """Tests for ``quantify()``""" - - def test_happy_path(self): - """Make sure True count is returned""" - q = [True, False, True] - self.assertEqual(mi.quantify(q), 2) - - def test_custom_predicate(self): - """Ensure non-default predicates return as expected""" - q = range(10) - self.assertEqual(mi.quantify(q, lambda x: x % 2 == 0), 5) - - -class PadnoneTests(TestCase): - """Tests for 
``padnone()``""" - - def test_happy_path(self): - """wrapper iterator should return None indefinitely""" - r = range(2) - p = mi.padnone(r) - self.assertEqual([0, 1, None, None], [next(p) for _ in range(4)]) - - -class NcyclesTests(TestCase): - """Tests for ``ncycles()``""" - - def test_happy_path(self): - """cycle a sequence three times""" - r = ["a", "b", "c"] - n = mi.ncycles(r, 3) - self.assertEqual( - ["a", "b", "c", "a", "b", "c", "a", "b", "c"], - list(n) - ) - - def test_null_case(self): - """asking for 0 cycles should return an empty iterator""" - n = mi.ncycles(range(100), 0) - self.assertRaises(StopIteration, lambda: next(n)) - - def test_pathological_case(self): - """asking for negative cycles should return an empty iterator""" - n = mi.ncycles(range(100), -10) - self.assertRaises(StopIteration, lambda: next(n)) - - -class DotproductTests(TestCase): - """Tests for ``dotproduct()``""" - - def test_happy_path(self): - """simple dotproduct example""" - self.assertEqual(400, mi.dotproduct([10, 10], [20, 20])) - - -class FlattenTests(TestCase): - """Tests for ``flatten()``""" - - def test_basic_usage(self): - """ensure list of lists is flattened one level""" - f = [[0, 1, 2], [3, 4, 5]] - self.assertEqual(list(range(6)), list(mi.flatten(f))) - - def test_single_level(self): - """ensure list of lists is flattened only one level""" - f = [[0, [1, 2]], [[3, 4], 5]] - self.assertEqual([0, [1, 2], [3, 4], 5], list(mi.flatten(f))) - - -class RepeatfuncTests(TestCase): - """Tests for ``repeatfunc()``""" - - def test_simple_repeat(self): - """test simple repeated functions""" - r = mi.repeatfunc(lambda: 5) - self.assertEqual([5, 5, 5, 5, 5], [next(r) for _ in range(5)]) - - def test_finite_repeat(self): - """ensure limited repeat when times is provided""" - r = mi.repeatfunc(lambda: 5, times=5) - self.assertEqual([5, 5, 5, 5, 5], list(r)) - - def test_added_arguments(self): - """ensure arguments are applied to the function""" - r = mi.repeatfunc(lambda x: x, 2, 3) - self.assertEqual([3, 3], list(r)) - - def test_null_times(self): - """repeat 0 should return an empty iterator""" - r = mi.repeatfunc(range, 0, 3) - self.assertRaises(StopIteration, lambda: next(r)) - - -class PairwiseTests(TestCase): - """Tests for ``pairwise()``""" - - def test_base_case(self): - """ensure an iterable will return pairwise""" - p = mi.pairwise([1, 2, 3]) - self.assertEqual([(1, 2), (2, 3)], list(p)) - - def test_short_case(self): - """ensure an empty iterator if there's not enough values to pair""" - p = mi.pairwise("a") - self.assertRaises(StopIteration, lambda: next(p)) - - -class GrouperTests(TestCase): - """Tests for ``grouper()``""" - - def test_even(self): - """Test when group size divides evenly into the length of - the iterable. - - """ - self.assertEqual( - list(mi.grouper(3, 'ABCDEF')), [('A', 'B', 'C'), ('D', 'E', 'F')] - ) - - def test_odd(self): - """Test when group size does not divide evenly into the length of the - iterable.
- - """ - self.assertEqual( - list(mi.grouper(3, 'ABCDE')), [('A', 'B', 'C'), ('D', 'E', None)] - ) - - def test_fill_value(self): - """Test that the fill value is used to pad the final group""" - self.assertEqual( - list(mi.grouper(3, 'ABCDE', 'x')), - [('A', 'B', 'C'), ('D', 'E', 'x')] - ) - - -class RoundrobinTests(TestCase): - """Tests for ``roundrobin()``""" - - def test_even_groups(self): - """Ensure ordered output from evenly populated iterables""" - self.assertEqual( - list(mi.roundrobin('ABC', [1, 2, 3], range(3))), - ['A', 1, 0, 'B', 2, 1, 'C', 3, 2] - ) - - def test_uneven_groups(self): - """Ensure ordered output from unevenly populated iterables""" - self.assertEqual( - list(mi.roundrobin('ABCD', [1, 2], range(0))), - ['A', 1, 'B', 2, 'C', 'D'] - ) - - -class PartitionTests(TestCase): - """Tests for ``partition()``""" - - def test_bool(self): - """Test when pred() returns a boolean""" - lesser, greater = mi.partition(lambda x: x > 5, range(10)) - self.assertEqual(list(lesser), [0, 1, 2, 3, 4, 5]) - self.assertEqual(list(greater), [6, 7, 8, 9]) - - def test_arbitrary(self): - """Test when pred() returns an integer""" - divisibles, remainders = mi.partition(lambda x: x % 3, range(10)) - self.assertEqual(list(divisibles), [0, 3, 6, 9]) - self.assertEqual(list(remainders), [1, 2, 4, 5, 7, 8]) - - -class PowersetTests(TestCase): - """Tests for ``powerset()``""" - - def test_combinatorics(self): - """Ensure a proper enumeration""" - p = mi.powerset([1, 2, 3]) - self.assertEqual( - list(p), - [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)] - ) - - -class UniqueEverseenTests(TestCase): - """Tests for ``unique_everseen()``""" - - def test_everseen(self): - """ensure duplicate elements are ignored""" - u = mi.unique_everseen('AAAABBBBCCDAABBB') - self.assertEqual( - ['A', 'B', 'C', 'D'], - list(u) - ) - - def test_custom_key(self): - """ensure the custom key comparison works""" - u = mi.unique_everseen('aAbACCc', key=str.lower) - self.assertEqual(list('abC'), list(u)) - - def test_unhashable(self): - """ensure things work for unhashable items""" - iterable = ['a', [1, 2, 3], [1, 2, 3], 'a'] - u = mi.unique_everseen(iterable) - self.assertEqual(list(u), ['a', [1, 2, 3]]) - - def test_unhashable_key(self): - """ensure things work for unhashable items with a custom key""" - iterable = ['a', [1, 2, 3], [1, 2, 3], 'a'] - u = mi.unique_everseen(iterable, key=lambda x: x) - self.assertEqual(list(u), ['a', [1, 2, 3]]) - - -class UniqueJustseenTests(TestCase): - """Tests for ``unique_justseen()``""" - - def test_justseen(self): - """ensure only last item is remembered""" - u = mi.unique_justseen('AAAABBBCCDABB') - self.assertEqual(list('ABCDAB'), list(u)) - - def test_custom_key(self): - """ensure the custom key comparison works""" - u = mi.unique_justseen('AABCcAD', str.lower) - self.assertEqual(list('ABCAD'), list(u)) - - -class IterExceptTests(TestCase): - """Tests for ``iter_except()``""" - - def test_exact_exception(self): - """ensure the exact specified exception is caught""" - l = [1, 2, 3] - i = mi.iter_except(l.pop, IndexError) - self.assertEqual(list(i), [3, 2, 1]) - - def test_generic_exception(self): - """ensure the generic exception can be caught""" - l = [1, 2] - i = mi.iter_except(l.pop, Exception) - self.assertEqual(list(i), [2, 1]) - - def test_uncaught_exception_is_raised(self): - """ensure a non-specified exception is raised""" - l = [1, 2, 3] - i = mi.iter_except(l.pop, KeyError) - self.assertRaises(IndexError, lambda: list(i)) - - def test_first(self): - """ensure 
first is run before the function""" - l = [1, 2, 3] - f = lambda: 25 - i = mi.iter_except(l.pop, IndexError, f) - self.assertEqual(list(i), [25, 3, 2, 1]) - - -class FirstTrueTests(TestCase): - """Tests for ``first_true()``""" - - def test_something_true(self): - """Test with no keywords""" - self.assertEqual(mi.first_true(range(10)), 1) - - def test_nothing_true(self): - """Test default return value.""" - self.assertEqual(mi.first_true([0, 0, 0]), False) - - def test_default(self): - """Test with a default keyword""" - self.assertEqual(mi.first_true([0, 0, 0], default='!'), '!') - - def test_pred(self): - """Test with a custom predicate""" - self.assertEqual( - mi.first_true([2, 4, 6], pred=lambda x: x % 3 == 0), 6 - ) - - -class RandomProductTests(TestCase): - """Tests for ``random_product()`` - - Since random.choice() has different results with the same seed across - python versions 2.x and 3.x, these tests use highly probable events to - create predictable outcomes across platforms. - """ - - def test_simple_lists(self): - """Ensure that one item is chosen from each list in each pair. - Also ensure that each item from each list eventually appears in - the chosen combinations. - - Odds are roughly 1 in 7.1 * 10^16 that one item from either list will - not be chosen after 100 samplings of one item from each list. Just to - be safe, it's better to use a known random seed, too. - - """ - nums = [1, 2, 3] - lets = ['a', 'b', 'c'] - n, m = zip(*[mi.random_product(nums, lets) for _ in range(100)]) - n, m = set(n), set(m) - self.assertEqual(n, set(nums)) - self.assertEqual(m, set(lets)) - self.assertEqual(len(n), len(nums)) - self.assertEqual(len(m), len(lets)) - - def test_list_with_repeat(self): - """ensure multiple items are chosen, and that they appear to be chosen - from one list then the next, in proper order. - - """ - nums = [1, 2, 3] - lets = ['a', 'b', 'c'] - r = list(mi.random_product(nums, lets, repeat=100)) - self.assertEqual(2 * 100, len(r)) - n, m = set(r[::2]), set(r[1::2]) - self.assertEqual(n, set(nums)) - self.assertEqual(m, set(lets)) - self.assertEqual(len(n), len(nums)) - self.assertEqual(len(m), len(lets)) - - -class RandomPermutationTests(TestCase): - """Tests for ``random_permutation()``""" - - def test_full_permutation(self): - """ensure every item from the iterable is returned in a new ordering - - 15 elements have a 1 in 1.3 * 10^12 chance of appearing in sorted - order, so we fix a seed value just to be sure. - - """ - i = range(15) - r = mi.random_permutation(i) - self.assertEqual(set(i), set(r)) - if list(i) == list(r): - raise AssertionError("Values were not permuted") - - def test_partial_permutation(self): - """ensure all returned items are from the iterable, that the returned - permutation is of the desired length, and that all items eventually - get returned. - - Sampling 100 permutations of length 5 from a set of 15 leaves a - (2/3)^100 chance that an item will not be chosen. Multiplied by 15 - items, there is a 1 in 2.6e16 chance that at least 1 item will not - show up in the resulting output. Using a random seed will fix that.
- - """ - items = range(15) - item_set = set(items) - all_items = set() - for _ in range(100): - permutation = mi.random_permutation(items, 5) - self.assertEqual(len(permutation), 5) - permutation_set = set(permutation) - self.assertLessEqual(permutation_set, item_set) - all_items |= permutation_set - self.assertEqual(all_items, item_set) - - -class RandomCombinationTests(TestCase): - """Tests for ``random_combination()``""" - - def test_psuedorandomness(self): - """ensure different subsets of the iterable get returned over many - samplings of random combinations""" - items = range(15) - all_items = set() - for _ in range(50): - combination = mi.random_combination(items, 5) - all_items |= set(combination) - self.assertEqual(all_items, set(items)) - - def test_no_replacement(self): - """ensure that elements are sampled without replacement""" - items = range(15) - for _ in range(50): - combination = mi.random_combination(items, len(items)) - self.assertEqual(len(combination), len(set(combination))) - self.assertRaises( - ValueError, lambda: mi.random_combination(items, len(items) + 1) - ) - - -class RandomCombinationWithReplacementTests(TestCase): - """Tests for ``random_combination_with_replacement()``""" - - def test_replacement(self): - """ensure that elements are sampled with replacement""" - items = range(5) - combo = mi.random_combination_with_replacement(items, len(items) * 2) - self.assertEqual(2 * len(items), len(combo)) - if len(set(combo)) == len(combo): - raise AssertionError("Combination contained no duplicates") - - def test_pseudorandomness(self): - """ensure different subsets of the iterable get returned over many - samplings of random combinations""" - items = range(15) - all_items = set() - for _ in range(50): - combination = mi.random_combination_with_replacement(items, 5) - all_items |= set(combination) - self.assertEqual(all_items, set(items)) - - -class NthCombinationTests(TestCase): - def test_basic(self): - iterable = 'abcdefg' - r = 4 - for index, expected in enumerate(combinations(iterable, r)): - actual = mi.nth_combination(iterable, r, index) - self.assertEqual(actual, expected) - - def test_long(self): - actual = mi.nth_combination(range(180), 4, 2000000) - expected = (2, 12, 35, 126) - self.assertEqual(actual, expected) - - def test_invalid_r(self): - for r in (-1, 3): - with self.assertRaises(ValueError): - mi.nth_combination([], r, 0) - - def test_invalid_index(self): - with self.assertRaises(IndexError): - mi.nth_combination('abcdefg', 3, -36) - - -class PrependTests(TestCase): - def test_basic(self): - value = 'a' - iterator = iter('bcdefg') - actual = list(mi.prepend(value, iterator)) - expected = list('abcdefg') - self.assertEqual(actual, expected) - - def test_multiple(self): - value = 'ab' - iterator = iter('cdefg') - actual = tuple(mi.prepend(value, iterator)) - expected = ('ab',) + tuple('cdefg') - self.assertEqual(actual, expected) diff --git a/libs/win/path.py b/libs/win/path/__init__.py similarity index 53% rename from libs/win/path.py rename to libs/win/path/__init__.py index 69ac5c13..f16857f8 100644 --- a/libs/win/path.py +++ b/libs/win/path/__init__.py @@ -1,7 +1,8 @@ """ -path.py - An object representing a path to a file or directory. +Path Pie -https://github.com/jaraco/path.py +Implements ``path.Path`` - An object representing a +path to a file or directory. 
Example:: @@ -21,8 +22,6 @@ Example:: foo_txt = Path("bar") / "foo.txt" """ -from __future__ import unicode_literals - import sys import warnings import os @@ -33,183 +32,89 @@ import hashlib import errno import tempfile import functools -import operator import re import contextlib import io import importlib import itertools -import platform -import ntpath -try: +with contextlib.suppress(ImportError): import win32security -except ImportError: - pass -try: +with contextlib.suppress(ImportError): import pwd -except ImportError: - pass -try: +with contextlib.suppress(ImportError): import grp -except ImportError: - pass -############################################################################## -# Python 2/3 support -PY3 = sys.version_info >= (3,) -PY2 = not PY3 - -string_types = str, -text_type = str -getcwdu = os.getcwd +from . import matchers +from . import masks +from . import classes +from .py37compat import best_realpath, lru_cache -if PY2: - import __builtin__ - string_types = __builtin__.basestring, - text_type = __builtin__.unicode - getcwdu = os.getcwdu - map = itertools.imap - filter = itertools.ifilter - FileNotFoundError = OSError - itertools.filterfalse = itertools.ifilterfalse - - -@contextlib.contextmanager -def io_error_compat(): - try: - yield - except IOError as io_err: - # On Python 2, io.open raises IOError; transform to OSError for - # future compatibility. - os_err = OSError(*io_err.args) - os_err.filename = getattr(io_err, 'filename', None) - raise os_err - -############################################################################## - - -__all__ = ['Path', 'TempDir', 'CaseInsensitivePattern'] +__all__ = ['Path', 'TempDir'] LINESEPS = ['\r\n', '\r', '\n'] U_LINESEPS = LINESEPS + ['\u0085', '\u2028', '\u2029'] -NEWLINE = re.compile('|'.join(LINESEPS)) +B_NEWLINE = re.compile('|'.join(LINESEPS).encode()) U_NEWLINE = re.compile('|'.join(U_LINESEPS)) -NL_END = re.compile(r'(?:{0})$'.format(NEWLINE.pattern)) -U_NL_END = re.compile(r'(?:{0})$'.format(U_NEWLINE.pattern)) +B_NL_END = re.compile(B_NEWLINE.pattern + b'$') +U_NL_END = re.compile(U_NEWLINE.pattern + '$') - -try: - import importlib_metadata - __version__ = importlib_metadata.version('path.py') -except Exception: - __version__ = 'unknown' +_default_linesep = object() class TreeWalkWarning(Warning): pass -# from jaraco.functools -def compose(*funcs): - compose_two = lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs)) # noqa - return functools.reduce(compose_two, funcs) - - -def simple_cache(func): +class Traversal: """ - Save results for the :meth:'path.using_module' classmethod. - When Python 3.2 is available, use functools.lru_cache instead. + Wrap a walk result to customize the traversal. + + `follow` is a function that takes an item and returns + True if that item should be followed and False otherwise. + + For example, to avoid traversing into directories that + begin with `.`: + + >>> traverse = Traversal(lambda dir: not dir.startswith('.')) + >>> items = list(traverse(Path('.').walk())) + + Directories beginning with `.` will appear in the results, but + their children will not. 
+ + >>> dot_dir = next(item for item in items if item.isdir() and item.startswith('.')) + >>> any(item.parent == dot_dir for item in items) + False """ - saved_results = {} - def wrapper(cls, module): - if module in saved_results: - return saved_results[module] - saved_results[module] = func(cls, module) - return saved_results[module] - return wrapper + def __init__(self, follow): + self.follow = follow - -class ClassProperty(property): - def __get__(self, cls, owner): - return self.fget.__get__(None, owner)() - - -class multimethod(object): - """ - Acts like a classmethod when invoked from the class and like an - instancemethod when invoked from the instance. - """ - def __init__(self, func): - self.func = func - - def __get__(self, instance, owner): - return ( - functools.partial(self.func, owner) if instance is None - else functools.partial(self.func, owner, instance) - ) - - -class matchers(object): - # TODO: make this class a module - - @staticmethod - def load(param): - """ - If the supplied parameter is a string, assum it's a simple - pattern. - """ - return ( - matchers.Pattern(param) if isinstance(param, string_types) - else param if param is not None - else matchers.Null() - ) - - class Base(object): - pass - - class Null(Base): - def __call__(self, path): - return True - - class Pattern(Base): - def __init__(self, pattern): - self.pattern = pattern - - def get_pattern(self, normcase): + def __call__(self, walker): + traverse = None + while True: try: - return self._pattern - except AttributeError: - pass - self._pattern = normcase(self.pattern) - return self._pattern + item = walker.send(traverse) + except StopIteration: + return + yield item - def __call__(self, path): - normcase = getattr(self, 'normcase', path.module.normcase) - pattern = self.get_pattern(normcase) - return fnmatch.fnmatchcase(normcase(path.name), pattern) - - class CaseInsensitive(Pattern): - """ - A Pattern with a ``'normcase'`` property, suitable for passing to - :meth:`listdir`, :meth:`dirs`, :meth:`files`, :meth:`walk`, - :meth:`walkdirs`, or :meth:`walkfiles` to match case-insensitive. - - For example, to get all files ending in .py, .Py, .pY, or .PY in the - current directory:: - - from path import Path, matchers - Path('.').files(matchers.CaseInsensitive('*.py')) - """ - normcase = staticmethod(ntpath.normcase) + traverse = functools.partial(self.follow, item) -class Path(text_type): +def _strip_newlines(lines): + r""" + >>> list(_strip_newlines(['Hello World\r\n', 'foo'])) + ['Hello World', 'foo'] + """ + return (U_NL_END.sub('', line) for line in lines) + + +class Path(str): """ Represents a filesystem path. @@ -234,18 +139,18 @@ class Path(text_type): def __init__(self, other=''): if other is None: raise TypeError("Invalid initial value for path: None") + with contextlib.suppress(AttributeError): + self._validate() @classmethod - @simple_cache + @lru_cache def using_module(cls, module): subclass_name = cls.__name__ + '_' + module.__name__ - if PY2: - subclass_name = str(subclass_name) bases = (cls,) ns = {'module': module} return type(subclass_name, bases, ns) - @ClassProperty + @classes.ClassProperty @classmethod def _next_class(cls): """ @@ -260,19 +165,14 @@ class Path(text_type): # Adding a Path and a string yields a Path. 
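Editor's note on the new ``Traversal`` class above: it drives ``walk()``'s generator ``send()`` protocol, so the value sent back after each yield decides whether the walker descends into that item. Because ``follow`` receives every yielded item, files included, a robust predicate should only return ``True`` for directories. A minimal sketch, assuming the package is importable as ``path``::

    from path import Path, Traversal

    # Follow only visible directories; files and dot-directories
    # are still yielded, but never descended into.
    prune_hidden = Traversal(
        lambda item: item.isdir() and not item.name.startswith('.')
    )

    for item in prune_hidden(Path('.').walk()):
        print(item)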
def __add__(self, more): - try: - return self._next_class(super(Path, self).__add__(more)) - except TypeError: # Python bug - return NotImplemented + return self._next_class(super(Path, self).__add__(more)) def __radd__(self, other): - if not isinstance(other, string_types): - return NotImplemented return self._next_class(other.__add__(self)) # The / operator joins Paths. def __div__(self, rel): - """ fp.__div__(rel) == fp / rel == fp.joinpath(rel) + """fp.__div__(rel) == fp / rel == fp.joinpath(rel) Join two path components, adding a separator character if needed. @@ -286,7 +186,7 @@ class Path(text_type): # The / operator joins Paths the other way around def __rdiv__(self, rel): - """ fp.__rdiv__(rel) == rel / fp + """fp.__rdiv__(rel) == rel / fp Join two path components, adding a separator character if needed. @@ -306,54 +206,52 @@ class Path(text_type): def __exit__(self, *_): os.chdir(self._old_dir) - def __fspath__(self): - return self - @classmethod def getcwd(cls): - """ Return the current working directory as a path object. + """Return the current working directory as a path object. - .. seealso:: :func:`os.getcwdu` + .. seealso:: :func:`os.getcwd` """ - return cls(getcwdu()) + return cls(os.getcwd()) # # --- Operations on Path strings. def abspath(self): - """ .. seealso:: :func:`os.path.abspath` """ + """.. seealso:: :func:`os.path.abspath`""" return self._next_class(self.module.abspath(self)) def normcase(self): - """ .. seealso:: :func:`os.path.normcase` """ + """.. seealso:: :func:`os.path.normcase`""" return self._next_class(self.module.normcase(self)) def normpath(self): - """ .. seealso:: :func:`os.path.normpath` """ + """.. seealso:: :func:`os.path.normpath`""" return self._next_class(self.module.normpath(self)) def realpath(self): - """ .. seealso:: :func:`os.path.realpath` """ - return self._next_class(self.module.realpath(self)) + """.. seealso:: :func:`os.path.realpath`""" + realpath = best_realpath(self.module) + return self._next_class(realpath(self)) def expanduser(self): - """ .. seealso:: :func:`os.path.expanduser` """ + """.. seealso:: :func:`os.path.expanduser`""" return self._next_class(self.module.expanduser(self)) def expandvars(self): - """ .. seealso:: :func:`os.path.expandvars` """ + """.. seealso:: :func:`os.path.expandvars`""" return self._next_class(self.module.expandvars(self)) def dirname(self): - """ .. seealso:: :attr:`parent`, :func:`os.path.dirname` """ + """.. seealso:: :attr:`parent`, :func:`os.path.dirname`""" return self._next_class(self.module.dirname(self)) def basename(self): - """ .. seealso:: :attr:`name`, :func:`os.path.basename` """ + """.. seealso:: :attr:`name`, :func:`os.path.basename`""" return self._next_class(self.module.basename(self)) def expand(self): - """ Clean up a filename by calling :meth:`expandvars()`, + """Clean up a filename by calling :meth:`expandvars()`, :meth:`expanduser()`, and :meth:`normpath()` on it. This is commonly everything needed to clean up a filename @@ -363,7 +261,7 @@ class Path(text_type): @property def stem(self): - """ The same as :meth:`name`, but with one file extension stripped off. + """The same as :meth:`name`, but with one file extension stripped off. 
>>> Path('/home/guido/python.tar.gz').stem 'python.tar' @@ -371,19 +269,14 @@ class Path(text_type): base, ext = self.module.splitext(self.name) return base - @property - def namebase(self): - warnings.warn("Use .stem instead of .namebase", DeprecationWarning) - return self.stem - @property def ext(self): - """ The file extension, for example ``'.py'``. """ + """The file extension, for example ``'.py'``.""" f, ext = self.module.splitext(self) return ext def with_suffix(self, suffix): - """ Return a new path with the file suffix changed (or added, if none) + """Return a new path with the file suffix changed (or added, if none) >>> Path('/home/guido/python.tar.gz').with_suffix(".foo") Path('/home/guido/python.tar.foo') @@ -403,7 +296,7 @@ class Path(text_type): @property def drive(self): - """ The drive specifier, for example ``'C:'``. + """The drive specifier, for example ``'C:'``. This is always empty on systems that don't use drive specifiers. """ @@ -411,7 +304,9 @@ class Path(text_type): return self._next_class(drive) parent = property( - dirname, None, None, + dirname, + None, + None, """ This path's parent directory, as a new Path object. For example, @@ -419,20 +314,24 @@ class Path(text_type): Path('/usr/local/lib')`` .. seealso:: :meth:`dirname`, :func:`os.path.dirname` - """) + """, + ) name = property( - basename, None, None, + basename, + None, + None, """ The name of this file or directory without the full path. For example, ``Path('/usr/local/lib/libpython.so').name == 'libpython.so'`` .. seealso:: :meth:`basename`, :func:`os.path.basename` - """) + """, + ) def splitpath(self): - """ p.splitpath() -> Return ``(p.parent, p.name)``. + """Return two-tuple of ``.parent``, ``.name``. .. seealso:: :attr:`parent`, :attr:`name`, :func:`os.path.split` """ @@ -440,7 +339,7 @@ class Path(text_type): return self._next_class(parent), child def splitdrive(self): - """ p.splitdrive() -> Return ``(p.drive, )``. + """Return two-tuple of ``.drive`` and rest without drive. Split the drive specifier from this path. If there is no drive specifier, :samp:`{p.drive}` is empty, so the return value @@ -449,10 +348,10 @@ class Path(text_type): .. seealso:: :func:`os.path.splitdrive` """ drive, rel = self.module.splitdrive(self) - return self._next_class(drive), rel + return self._next_class(drive), self._next_class(rel) def splitext(self): - """ p.splitext() -> Return ``(p.stripext(), p.ext)``. + """Return two-tuple of ``.stripext()`` and ``.ext``. Split the filename extension from this path and return the two parts. Either part may be empty. @@ -467,28 +366,14 @@ class Path(text_type): return self._next_class(filename), ext def stripext(self): - """ p.stripext() -> Remove one file extension from the path. + """Remove one file extension from the path. For example, ``Path('/home/guido/python.tar.gz').stripext()`` returns ``Path('/home/guido/python.tar')``. """ return self.splitext()[0] - def splitunc(self): - """ .. seealso:: :func:`os.path.splitunc` """ - unc, rest = self.module.splitunc(self) - return self._next_class(unc), rest - - @property - def uncshare(self): - """ - The UNC mount point for this path. - This is empty for paths on local drives. - """ - unc, r = self.module.splitunc(self) - return self._next_class(unc) - - @multimethod + @classes.multimethod def joinpath(cls, first, *others): """ Join first to zero or more :class:`Path` components, @@ -498,41 +383,52 @@ class Path(text_type): .. 
seealso:: :func:`os.path.join` """ - if not isinstance(first, cls): - first = cls(first) - return first._next_class(first.module.join(first, *others)) + return cls._next_class(cls.module.join(first, *others)) def splitall(self): - r""" Return a list of the path components in this path. + r"""Return a list of the path components in this path. The first item in the list will be a Path. Its value will be either :data:`os.curdir`, :data:`os.pardir`, empty, or the root directory of this path (for example, ``'/'`` or ``'C:\\'``). The other items in the list will be strings. - ``path.Path.joinpath(*result)`` will yield the original path. + ``Path.joinpath(*result)`` will yield the original path. + + >>> Path('/foo/bar/baz').splitall() + [Path('/'), 'foo', 'bar', 'baz'] """ - parts = [] + return list(self._parts()) + + def parts(self): + """ + >>> Path('/foo/bar/baz').parts() + (Path('/'), 'foo', 'bar', 'baz') + """ + return tuple(self._parts()) + + def _parts(self): + return reversed(tuple(self._parts_iter())) + + def _parts_iter(self): loc = self while loc != os.curdir and loc != os.pardir: prev = loc loc, child = prev.splitpath() if loc == prev: break - parts.append(child) - parts.append(loc) - parts.reverse() - return parts + yield child + yield loc def relpath(self, start='.'): - """ Return this path as a relative path, + """Return this path as a relative path, based from `start`, which defaults to the current working directory. """ cwd = self._next_class(start) return cwd.relpathto(self) def relpathto(self, dest): - """ Return a relative path from `self` to `dest`. + """Return a relative path from `self` to `dest`. If there is no relative path from `self` to `dest`, for example if they reside on different drives in Windows, then this returns @@ -572,7 +468,7 @@ class Path(text_type): # --- Listing, searching, walking, and matching def listdir(self, match=None): - """ D.listdir() -> List of items in this directory. + """List of items in this directory. Use :meth:`files` or :meth:`dirs` instead if you want a listing of just files or just subdirectories. @@ -585,12 +481,10 @@ class Path(text_type): .. seealso:: :meth:`files`, :meth:`dirs` """ match = matchers.load(match) - return list(filter(match, ( - self / child for child in os.listdir(self) - ))) + return list(filter(match, (self / child for child in os.listdir(self)))) def dirs(self, *args, **kwargs): - """ D.dirs() -> List of this directory's subdirectories. + """List of this directory's subdirectories. The elements of the list are Path objects. This does not walk recursively into subdirectories @@ -601,7 +495,7 @@ class Path(text_type): return [p for p in self.listdir(*args, **kwargs) if p.isdir()] def files(self, *args, **kwargs): - """ D.files() -> List of the files in this directory. + """List of the files in self. The elements of the list are Path objects. This does not walk into subdirectories (see :meth:`walkfiles`). @@ -612,7 +506,7 @@ class Path(text_type): return [p for p in self.listdir(*args, **kwargs) if p.isfile()] def walk(self, match=None, errors='strict'): - """ D.walk() -> iterator over files and subdirs, recursively. + """Iterator over files and subdirs, recursively. The iterator yields Path objects naming each child item of this directory and its descendants. This requires that @@ -627,75 +521,49 @@ class Path(text_type): reports the error via :func:`warnings.warn()`), and ``'ignore'``. `errors` may also be an arbitrary callable taking a msg parameter. 
""" - class Handlers: - def strict(msg): - raise - - def warn(msg): - warnings.warn(msg, TreeWalkWarning) - - def ignore(msg): - pass - - if not callable(errors) and errors not in vars(Handlers): - raise ValueError("invalid errors parameter") - errors = vars(Handlers).get(errors, errors) + errors = Handlers._resolve(errors) match = matchers.load(match) try: childList = self.listdir() - except Exception: - exc = sys.exc_info()[1] - tmpl = "Unable to list directory '%(self)s': %(exc)s" - msg = tmpl % locals() - errors(msg) + except Exception as exc: + errors(f"Unable to list directory '{self}': {exc}") return for child in childList: + traverse = None if match(child): - yield child + traverse = yield child + traverse = traverse or child.isdir try: - isdir = child.isdir() - except Exception: - exc = sys.exc_info()[1] - tmpl = "Unable to access '%(child)s': %(exc)s" - msg = tmpl % locals() - errors(msg) - isdir = False + do_traverse = traverse() + except Exception as exc: + errors(f"Unable to access '{child}': {exc}") + continue - if isdir: + if do_traverse: for item in child.walk(errors=errors, match=match): yield item def walkdirs(self, *args, **kwargs): - """ D.walkdirs() -> iterator over subdirs, recursively. - """ - return ( - item - for item in self.walk(*args, **kwargs) - if item.isdir() - ) + """Iterator over subdirs, recursively.""" + return (item for item in self.walk(*args, **kwargs) if item.isdir()) def walkfiles(self, *args, **kwargs): - """ D.walkfiles() -> iterator over files in D, recursively. - """ - return ( - item - for item in self.walk(*args, **kwargs) - if item.isfile() - ) + """Iterator over files, recursively.""" + return (item for item in self.walk(*args, **kwargs) if item.isfile()) def fnmatch(self, pattern, normcase=None): - """ Return ``True`` if `self.name` matches the given `pattern`. + """Return ``True`` if `self.name` matches the given `pattern`. `pattern` - A filename pattern with wildcards, for example ``'*.py'``. If the pattern contains a `normcase` attribute, it is applied to the name and path prior to comparison. `normcase` - (optional) A function used to normalize the pattern and - filename before matching. Defaults to :meth:`self.module`, which - defaults to :meth:`os.path.normcase`. + filename before matching. Defaults to normcase from + ``self.module``, :func:`os.path.normcase`. .. seealso:: :func:`fnmatch.fnmatch` """ @@ -706,7 +574,7 @@ class Path(text_type): return fnmatch.fnmatchcase(name, pattern) def glob(self, pattern): - """ Return a list of Path objects that match the pattern. + """Return a list of Path objects that match the pattern. `pattern` - a path relative to this directory, with wildcards. @@ -723,7 +591,7 @@ class Path(text_type): return [cls(s) for s in glob.glob(self / pattern)] def iglob(self, pattern): - """ Return an iterator of Path objects that match the pattern. + """Return an iterator of Path objects that match the pattern. `pattern` - a path relative to this directory, with wildcards. @@ -744,64 +612,72 @@ class Path(text_type): # --- Reading or writing an entire file at once. def open(self, *args, **kwargs): - """ Open this file and return a corresponding :class:`file` object. + """Open this file and return a corresponding file object. Keyword arguments work as in :func:`io.open`. If the file cannot be - opened, an :class:`~exceptions.OSError` is raised. + opened, an :class:`OSError` is raised. 
""" - with io_error_compat(): - return io.open(self, *args, **kwargs) + return io.open(self, *args, **kwargs) def bytes(self): - """ Open this file, read all bytes, return them as a string. """ + """Open this file, read all bytes, return them as a string.""" with self.open('rb') as f: return f.read() def chunks(self, size, *args, **kwargs): - """ Returns a generator yielding chunks of the file, so it can - be read piece by piece with a simple for loop. + """Returns a generator yielding chunks of the file, so it can + be read piece by piece with a simple for loop. - Any argument you pass after `size` will be passed to :meth:`open`. + Any argument you pass after `size` will be passed to :meth:`open`. - :example: + :example: - >>> hash = hashlib.md5() - >>> for chunk in Path("path.py").chunks(8192, mode='rb'): - ... hash.update(chunk) + >>> hash = hashlib.md5() + >>> for chunk in Path("CHANGES.rst").chunks(8192, mode='rb'): + ... hash.update(chunk) - This will read the file by chunks of 8192 bytes. + This will read the file by chunks of 8192 bytes. """ with self.open(*args, **kwargs) as f: for chunk in iter(lambda: f.read(size) or None, None): yield chunk def write_bytes(self, bytes, append=False): - """ Open this file and write the given bytes to it. + """Open this file and write the given bytes to it. Default behavior is to overwrite any existing file. Call ``p.write_bytes(bytes, append=True)`` to append instead. """ - if append: - mode = 'ab' - else: - mode = 'wb' - with self.open(mode) as f: + with self.open('ab' if append else 'wb') as f: f.write(bytes) - def text(self, encoding=None, errors='strict'): - r""" Open this file, read it in, return the content as a string. + def read_text(self, encoding=None, errors=None): + r"""Open this file, read it in, return the content as a string. - All newline sequences are converted to ``'\n'``. Keyword arguments - will be passed to :meth:`open`. + Optional parameters are passed to :meth:`open`. .. seealso:: :meth:`lines` """ - with self.open(mode='r', encoding=encoding, errors=errors) as f: - return U_NEWLINE.sub('\n', f.read()) + with self.open(encoding=encoding, errors=errors) as f: + return f.read() - def write_text(self, text, encoding=None, errors='strict', - linesep=os.linesep, append=False): - r""" Write the given text to this file. + def read_bytes(self): + r"""Return the contents of this file as bytes.""" + with self.open(mode='rb') as f: + return f.read() + + def text(self, encoding=None, errors='strict'): + r"""Legacy function to read text. + + Converts all newline sequences to ``\n``. + """ + warnings.warn(".text is deprecated; use read_text", DeprecationWarning) + return U_NEWLINE.sub('\n', self.read_text(encoding, errors)) + + def write_text( + self, text, encoding=None, errors='strict', linesep=os.linesep, append=False + ): + r"""Write the given text to this file. The default behavior is to overwrite any existing file; to append instead, use the `append=True` keyword argument. @@ -812,24 +688,22 @@ class Path(text_type): Parameters: - `text` - str/unicode - The text to be written. + `text` - str/bytes - The text to be written. - `encoding` - str - The Unicode encoding that will be used. - This is ignored if `text` isn't a Unicode string. + `encoding` - str - The text encoding used. `errors` - str - How to handle Unicode encoding errors. Default is ``'strict'``. See ``help(unicode.encode)`` for the - options. This is ignored if `text` isn't a Unicode - string. + options. Ignored if `text` isn't a Unicode string. 
`linesep` - keyword argument - str/unicode - The sequence of characters to be used to mark end-of-line. The default is - :data:`os.linesep`. You can also specify ``None`` to - leave all newlines as they are in `text`. + :data:`os.linesep`. Specify ``None`` to + use newlines unmodified. `append` - keyword argument - bool - Specifies what to do if the file already exists (``True``: append to the end of it; - ``False``: overwrite it.) The default is ``False``. + ``False``: overwrite it). The default is ``False``. --- Newline handling. @@ -839,18 +713,13 @@ class Path(text_type): end-of-line sequence (see :data:`os.linesep`; on Windows, for example, the end-of-line marker is ``'\r\n'``). - If you don't like your platform's default, you can override it - using the `linesep=` keyword argument. If you specifically want - ``write_text()`` to preserve the newlines as-is, use ``linesep=None``. - - This applies to Unicode text the same as to 8-bit text, except - there are three additional standard Unicode end-of-line sequences: - ``u'\x85'``, ``u'\r\x85'``, and ``u'\u2028'``. - - (This is slightly different from when you open a file for - writing with ``fopen(filename, "w")`` in C or ``open(filename, 'w')`` - in Python.) + To override the platform's default, pass the `linesep=` + keyword argument. To preserve the newlines as-is, pass + ``linesep=None``. + This handling applies to Unicode text and bytes, except + with Unicode, additional non-ASCII newlines are recognized: + ``\x85``, ``\r\x85``, and ``\u2028``. --- Unicode @@ -862,39 +731,52 @@ class Path(text_type): specified `encoding` (or the default encoding if `encoding` isn't specified). The `errors` argument applies only to this conversion. - """ - if isinstance(text, text_type): + if isinstance(text, str): if linesep is not None: text = U_NEWLINE.sub(linesep, text) - text = text.encode(encoding or sys.getdefaultencoding(), errors) + bytes = text.encode(encoding or sys.getdefaultencoding(), errors) else: + warnings.warn( + "Writing bytes in write_text is deprecated", + DeprecationWarning, + stacklevel=1, + ) assert encoding is None - text = NEWLINE.sub(linesep, text) - self.write_bytes(text, append=append) + if linesep is not None: + text = B_NEWLINE.sub(linesep.encode(), text) + bytes = text + self.write_bytes(bytes, append=append) - def lines(self, encoding=None, errors='strict', retain=True): - r""" Open this file, read all lines, return them in a list. + def lines(self, encoding=None, errors=None, retain=True): + r"""Open this file, read all lines, return them in a list. Optional arguments: `encoding` - The Unicode encoding (or character set) of - the file. The default is ``None``, meaning the content - of the file is read as 8-bit characters and returned - as a list of (non-Unicode) str objects. - `errors` - How to handle Unicode errors; see help(str.decode) - for the options. Default is ``'strict'``. - `retain` - If ``True``, retain newline characters; but all newline - character combinations (``'\r'``, ``'\n'``, ``'\r\n'``) are - translated to ``'\n'``. If ``False``, newline characters are - stripped off. Default is ``True``. + the file. The default is ``None``, meaning use + ``locale.getpreferredencoding()``. + `errors` - How to handle Unicode errors; see + `open `_ + for the options. Default is ``None`` meaning "strict". + `retain` - If ``True`` (default), retain newline characters, + but translate all newline + characters to ``\n``. If ``False``, newline characters are + omitted. .. 
seealso:: :meth:`text` """ - return self.text(encoding, errors).splitlines(retain) + text = U_NEWLINE.sub('\n', self.read_text(encoding, errors)) + return text.splitlines(retain) - def write_lines(self, lines, encoding=None, errors='strict', - linesep=os.linesep, append=False): - r""" Write the given lines of text to this file. + def write_lines( + self, + lines, + encoding=None, + errors='strict', + linesep=_default_linesep, + append=False, + ): + r"""Write the given lines of text to this file. By default this overwrites any existing file at this path. @@ -909,7 +791,7 @@ class Path(text_type): `errors` - How to handle errors in Unicode encoding. This also applies only to Unicode strings. - linesep - The desired line-ending. This line-ending is + linesep - (deprecated) The desired line-ending. This line-ending is applied to every line. If a line already has any standard line ending (``'\r'``, ``'\n'``, ``'\r\n'``, ``u'\x85'``, ``u'\r\x85'``, ``u'\u2028'``), that will @@ -917,32 +799,28 @@ class Path(text_type): default is os.linesep, which is platform-dependent (``'\r\n'`` on Windows, ``'\n'`` on Unix, etc.). Specify ``None`` to write the lines as-is, like - :meth:`file.writelines`. + ``.writelines`` on a file object. Use the keyword argument ``append=True`` to append lines to the file. The default is to overwrite the file. - - .. warning :: - - When you use this with Unicode data, if the encoding of the - existing data in the file is different from the encoding - you specify with the `encoding=` parameter, the result is - mixed-encoding data, which can really confuse someone trying - to read the file later. """ - with self.open('ab' if append else 'wb') as f: - for line in lines: - isUnicode = isinstance(line, text_type) - if linesep is not None: - pattern = U_NL_END if isUnicode else NL_END - line = pattern.sub('', line) + linesep - if isUnicode: - line = line.encode( - encoding or sys.getdefaultencoding(), errors) - f.write(line) + mode = 'a' if append else 'w' + with self.open(mode, encoding=encoding, errors=errors, newline='') as f: + f.writelines(self._replace_linesep(lines, linesep)) + + @staticmethod + def _replace_linesep(lines, linesep): + if linesep != _default_linesep: + warnings.warn("linesep is deprecated", DeprecationWarning, stacklevel=3) + else: + linesep = os.linesep + if linesep is None: + return lines + + return (line + linesep for line in _strip_newlines(lines)) def read_md5(self): - """ Calculate the md5 hash for this file. + """Calculate the md5 hash for this file. This reads through the entire file. @@ -951,7 +829,7 @@ class Path(text_type): return self.read_hash('md5') def _hash(self, hash_name): - """ Returns a hash object for the file at the current path. + """Returns a hash object for the file at the current path. `hash_name` should be a hash algo name (such as ``'md5'`` or ``'sha1'``) that's available in the :mod:`hashlib` module. @@ -962,7 +840,7 @@ class Path(text_type): return m def read_hash(self, hash_name): - """ Calculate given hash for this file. + """Calculate given hash for this file. List of supported hashes can be obtained from :mod:`hashlib` package. This reads the entire file. @@ -972,7 +850,7 @@ class Path(text_type): return self._hash(hash_name).digest() def read_hexhash(self, hash_name): - """ Calculate given hash for this file, returning hexdigest. + """Calculate given hash for this file, returning hexdigest. List of supported hashes can be obtained from :mod:`hashlib` package. This reads the entire file. 
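Taken together, the rewritten text I/O and hashing helpers in the hunks above compose naturally; a short round-trip sketch (the file name is illustrative)::

    from path import Path

    doc = Path('demo.txt')
    doc.write_text('alpha\nbeta\n', encoding='utf-8')

    # lines() normalizes all newline flavors before splitting.
    assert doc.lines(retain=False) == ['alpha', 'beta']

    # Any algorithm name known to hashlib works here.
    print(doc.read_hexhash('sha256'))

    doc.remove()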
@@ -987,121 +865,154 @@ class Path(text_type): # bound. Playing it safe and wrapping them all in method calls. def isabs(self): - """ .. seealso:: :func:`os.path.isabs` """ + """ + >>> Path('.').isabs() + False + + .. seealso:: :func:`os.path.isabs` + """ return self.module.isabs(self) def exists(self): - """ .. seealso:: :func:`os.path.exists` """ + """.. seealso:: :func:`os.path.exists`""" return self.module.exists(self) def isdir(self): - """ .. seealso:: :func:`os.path.isdir` """ + """.. seealso:: :func:`os.path.isdir`""" return self.module.isdir(self) def isfile(self): - """ .. seealso:: :func:`os.path.isfile` """ + """.. seealso:: :func:`os.path.isfile`""" return self.module.isfile(self) def islink(self): - """ .. seealso:: :func:`os.path.islink` """ + """.. seealso:: :func:`os.path.islink`""" return self.module.islink(self) def ismount(self): - """ .. seealso:: :func:`os.path.ismount` """ + """ + >>> Path('.').ismount() + False + + .. seealso:: :func:`os.path.ismount` + """ return self.module.ismount(self) def samefile(self, other): - """ .. seealso:: :func:`os.path.samefile` """ - if not hasattr(self.module, 'samefile'): - other = Path(other).realpath().normpath().normcase() - return self.realpath().normpath().normcase() == other + """.. seealso:: :func:`os.path.samefile`""" return self.module.samefile(self, other) def getatime(self): - """ .. seealso:: :attr:`atime`, :func:`os.path.getatime` """ + """.. seealso:: :attr:`atime`, :func:`os.path.getatime`""" return self.module.getatime(self) atime = property( - getatime, None, None, - """ Last access time of the file. + getatime, + None, + None, + """ + Last access time of the file. + + >>> Path('.').atime > 0 + True .. seealso:: :meth:`getatime`, :func:`os.path.getatime` - """) + """, + ) def getmtime(self): - """ .. seealso:: :attr:`mtime`, :func:`os.path.getmtime` """ + """.. seealso:: :attr:`mtime`, :func:`os.path.getmtime`""" return self.module.getmtime(self) mtime = property( - getmtime, None, None, - """ Last-modified time of the file. + getmtime, + None, + None, + """ + Last modified time of the file. .. seealso:: :meth:`getmtime`, :func:`os.path.getmtime` - """) + """, + ) def getctime(self): - """ .. seealso:: :attr:`ctime`, :func:`os.path.getctime` """ + """.. seealso:: :attr:`ctime`, :func:`os.path.getctime`""" return self.module.getctime(self) ctime = property( - getctime, None, None, + getctime, + None, + None, """ Creation time of the file. .. seealso:: :meth:`getctime`, :func:`os.path.getctime` - """) + """, + ) def getsize(self): - """ .. seealso:: :attr:`size`, :func:`os.path.getsize` """ + """.. seealso:: :attr:`size`, :func:`os.path.getsize`""" return self.module.getsize(self) size = property( - getsize, None, None, + getsize, + None, + None, """ Size of the file, in bytes. .. seealso:: :meth:`getsize`, :func:`os.path.getsize` - """) + """, + ) - if hasattr(os, 'access'): - def access(self, mode): - """ Return ``True`` if current user has access to this path. + def access(self, *args, **kwargs): + """ + Return does the real user have access to this path. - mode - One of the constants :data:`os.F_OK`, :data:`os.R_OK`, - :data:`os.W_OK`, :data:`os.X_OK` + >>> Path('.').access(os.F_OK) + True - .. seealso:: :func:`os.access` - """ - return os.access(self, mode) + .. seealso:: :func:`os.access` + """ + return os.access(self, *args, **kwargs) def stat(self): - """ Perform a ``stat()`` system call on this path. + """ + Perform a ``stat()`` system call on this path. + + >>> Path('.').stat() + os.stat_result(...) .. 
seealso:: :meth:`lstat`, :func:`os.stat` """ return os.stat(self) def lstat(self): - """ Like :meth:`stat`, but do not follow symbolic links. + """ + Like :meth:`stat`, but do not follow symbolic links. + + >>> Path('.').lstat() == Path('.').stat() + True .. seealso:: :meth:`stat`, :func:`os.lstat` """ return os.lstat(self) - def __get_owner_windows(self): - """ + def __get_owner_windows(self): # pragma: nocover + r""" Return the name of the owner of this file or directory. Follow symbolic links. - Return a name of the form ``r'DOMAIN\\User Name'``; may be a group. + Return a name of the form ``DOMAIN\User Name``; may be a group. .. seealso:: :attr:`owner` """ desc = win32security.GetFileSecurity( - self, win32security.OWNER_SECURITY_INFORMATION) + self, win32security.OWNER_SECURITY_INFORMATION + ) sid = desc.GetSecurityDescriptorOwner() account, domain, typecode = win32security.LookupAccountSid(None, sid) return domain + '\\' + account - def __get_owner_unix(self): + def __get_owner_unix(self): # pragma: nocover """ Return the name of the owner of this file or directory. Follow symbolic links. @@ -1111,44 +1022,50 @@ class Path(text_type): st = self.stat() return pwd.getpwuid(st.st_uid).pw_name - def __get_owner_not_implemented(self): + def __get_owner_not_implemented(self): # pragma: nocover raise NotImplementedError("Ownership not available on this platform.") - if 'win32security' in globals(): - get_owner = __get_owner_windows - elif 'pwd' in globals(): - get_owner = __get_owner_unix - else: - get_owner = __get_owner_not_implemented + get_owner = ( + __get_owner_windows + if 'win32security' in globals() + else __get_owner_unix + if 'pwd' in globals() + else __get_owner_not_implemented + ) owner = property( - get_owner, None, None, + get_owner, + None, + None, """ Name of the owner of this file or directory. - .. seealso:: :meth:`get_owner`""") + .. seealso:: :meth:`get_owner`""", + ) if hasattr(os, 'statvfs'): + def statvfs(self): - """ Perform a ``statvfs()`` system call on this path. + """Perform a ``statvfs()`` system call on this path. .. seealso:: :func:`os.statvfs` """ return os.statvfs(self) if hasattr(os, 'pathconf'): + def pathconf(self, name): - """ .. seealso:: :func:`os.pathconf` """ + """.. seealso:: :func:`os.pathconf`""" return os.pathconf(self, name) # # --- Modifying operations on files and directories - def utime(self, times): - """ Set the access and modified times of this file. + def utime(self, *args, **kwargs): + """Set the access and modified times of this file. .. seealso:: :func:`os.utime` """ - os.utime(self, times) + os.utime(self, *args, **kwargs) return self def chmod(self, mode): @@ -1158,36 +1075,37 @@ class Path(text_type): .. seealso:: :func:`os.chmod` """ - if isinstance(mode, string_types): - mask = _multi_permission_mask(mode) + if isinstance(mode, str): + mask = masks.compound(mode) mode = mask(self.stat().st_mode) os.chmod(self, mode) return self - def chown(self, uid=-1, gid=-1): - """ - Change the owner and group by names rather than the uid or gid numbers. + if hasattr(os, 'chown'): - .. seealso:: :func:`os.chown` - """ - if hasattr(os, 'chown'): - if 'pwd' in globals() and isinstance(uid, string_types): - uid = pwd.getpwnam(uid).pw_uid - if 'grp' in globals() and isinstance(gid, string_types): - gid = grp.getgrnam(gid).gr_gid - os.chown(self, uid, gid) - else: - msg = "Ownership not available on this platform." 
- raise NotImplementedError(msg) - return self + def chown(self, uid=-1, gid=-1): + """ + Change the owner and group by names or numbers. + + .. seealso:: :func:`os.chown` + """ + + def resolve_uid(uid): + return uid if isinstance(uid, int) else pwd.getpwnam(uid).pw_uid + + def resolve_gid(gid): + return gid if isinstance(gid, int) else grp.getgrnam(gid).gr_gid + + os.chown(self, resolve_uid(uid), resolve_gid(gid)) + return self def rename(self, new): - """ .. seealso:: :func:`os.rename` """ + """.. seealso:: :func:`os.rename`""" os.rename(self, new) return self._next_class(new) def renames(self, new): - """ .. seealso:: :func:`os.renames` """ + """.. seealso:: :func:`os.renames`""" os.renames(self, new) return self._next_class(new) @@ -1195,74 +1113,60 @@ class Path(text_type): # --- Create/delete operations on directories def mkdir(self, mode=0o777): - """ .. seealso:: :func:`os.mkdir` """ + """.. seealso:: :func:`os.mkdir`""" os.mkdir(self, mode) return self def mkdir_p(self, mode=0o777): - """ Like :meth:`mkdir`, but does not raise an exception if the - directory already exists. """ - try: + """Like :meth:`mkdir`, but does not raise an exception if the + directory already exists.""" + with contextlib.suppress(FileExistsError): self.mkdir(mode) - except OSError: - _, e, _ = sys.exc_info() - if e.errno != errno.EEXIST: - raise return self def makedirs(self, mode=0o777): - """ .. seealso:: :func:`os.makedirs` """ + """.. seealso:: :func:`os.makedirs`""" os.makedirs(self, mode) return self def makedirs_p(self, mode=0o777): - """ Like :meth:`makedirs`, but does not raise an exception if the - directory already exists. """ - try: + """Like :meth:`makedirs`, but does not raise an exception if the + directory already exists.""" + with contextlib.suppress(FileExistsError): self.makedirs(mode) - except OSError: - _, e, _ = sys.exc_info() - if e.errno != errno.EEXIST: - raise return self def rmdir(self): - """ .. seealso:: :func:`os.rmdir` """ + """.. seealso:: :func:`os.rmdir`""" os.rmdir(self) return self def rmdir_p(self): - """ Like :meth:`rmdir`, but does not raise an exception if the - directory is not empty or does not exist. """ - try: - self.rmdir() - except OSError: - _, e, _ = sys.exc_info() - bypass_codes = errno.ENOTEMPTY, errno.EEXIST, errno.ENOENT - if e.errno not in bypass_codes: - raise + """Like :meth:`rmdir`, but does not raise an exception if the + directory is not empty or does not exist.""" + suppressed = FileNotFoundError, FileExistsError, DirectoryNotEmpty + with contextlib.suppress(suppressed): + with DirectoryNotEmpty.translate(): + self.rmdir() return self def removedirs(self): - """ .. seealso:: :func:`os.removedirs` """ + """.. seealso:: :func:`os.removedirs`""" os.removedirs(self) return self def removedirs_p(self): - """ Like :meth:`removedirs`, but does not raise an exception if the - directory is not empty or does not exist. """ - try: - self.removedirs() - except OSError: - _, e, _ = sys.exc_info() - if e.errno != errno.ENOTEMPTY and e.errno != errno.EEXIST: - raise + """Like :meth:`removedirs`, but does not raise an exception if the + directory is not empty or does not exist.""" + with contextlib.suppress(FileExistsError, DirectoryNotEmpty): + with DirectoryNotEmpty.translate(): + self.removedirs() return self # --- Modifying operations on files def touch(self): - """ Set the access/modified times of this file to the current time. + """Set the access/modified times of this file to the current time. Create the file if it does not exist. 
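The rewritten ``chown`` above accepts names or numeric ids, resolving names through :mod:`pwd`/:mod:`grp` (POSIX only), and the ``_p`` directory helpers now lean on :func:`contextlib.suppress` instead of inspecting ``errno``. A sketch; the paths and account names are illustrative::

    from path import Path

    logdir = Path('myapp-logs').makedirs_p()    # idempotent
    logfile = (logdir / 'app.log').touch()

    # Names are resolved via pwd/grp; ints pass straight through.
    logfile.chown(uid='www-data', gid='www-data')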
""" fd = os.open(self, os.O_WRONLY | os.O_CREAT, 0o666) @@ -1271,78 +1175,61 @@ class Path(text_type): return self def remove(self): - """ .. seealso:: :func:`os.remove` """ + """.. seealso:: :func:`os.remove`""" os.remove(self) return self def remove_p(self): - """ Like :meth:`remove`, but does not raise an exception if the - file does not exist. """ - try: + """Like :meth:`remove`, but does not raise an exception if the + file does not exist.""" + with contextlib.suppress(FileNotFoundError): self.unlink() - except FileNotFoundError as exc: - if PY2 and exc.errno != errno.ENOENT: - raise return self - def unlink(self): - """ .. seealso:: :func:`os.unlink` """ - os.unlink(self) - return self - - def unlink_p(self): - """ Like :meth:`unlink`, but does not raise an exception if the - file does not exist. """ - self.remove_p() - return self + unlink = remove + unlink_p = remove_p # --- Links - if hasattr(os, 'link'): - def link(self, newpath): - """ Create a hard link at `newpath`, pointing to this file. + def link(self, newpath): + """Create a hard link at `newpath`, pointing to this file. - .. seealso:: :func:`os.link` - """ - os.link(self, newpath) - return self._next_class(newpath) + .. seealso:: :func:`os.link` + """ + os.link(self, newpath) + return self._next_class(newpath) - if hasattr(os, 'symlink'): - def symlink(self, newlink=None): - """ Create a symbolic link at `newlink`, pointing here. + def symlink(self, newlink=None): + """Create a symbolic link at `newlink`, pointing here. - If newlink is not supplied, the symbolic link will assume - the name self.basename(), creating the link in the cwd. + If newlink is not supplied, the symbolic link will assume + the name self.basename(), creating the link in the cwd. - .. seealso:: :func:`os.symlink` - """ - if newlink is None: - newlink = self.basename() - os.symlink(self, newlink) - return self._next_class(newlink) + .. seealso:: :func:`os.symlink` + """ + if newlink is None: + newlink = self.basename() + os.symlink(self, newlink) + return self._next_class(newlink) - if hasattr(os, 'readlink'): - def readlink(self): - """ Return the path to which this symbolic link points. + def readlink(self): + """Return the path to which this symbolic link points. - The result may be an absolute or a relative path. + The result may be an absolute or a relative path. - .. seealso:: :meth:`readlinkabs`, :func:`os.readlink` - """ - return self._next_class(os.readlink(self)) + .. seealso:: :meth:`readlinkabs`, :func:`os.readlink` + """ + return self._next_class(os.readlink(self)) - def readlinkabs(self): - """ Return the path to which this symbolic link points. + def readlinkabs(self): + """Return the path to which this symbolic link points. - The result is always an absolute path. + The result is always an absolute path. - .. seealso:: :meth:`readlink`, :func:`os.readlink` - """ - p = self.readlink() - if p.isabs(): - return p - else: - return (self.parent / p).abspath() + .. seealso:: :meth:`readlink`, :func:`os.readlink` + """ + p = self.readlink() + return p if p.isabs() else (self.parent / p).abspath() # High-level functions from shutil # These functions will be bound to the instance such that @@ -1359,28 +1246,26 @@ class Path(text_type): rmtree = shutil.rmtree def rmtree_p(self): - """ Like :meth:`rmtree`, but does not raise an exception if the - directory does not exist. 
""" - try: + """Like :meth:`rmtree`, but does not raise an exception if the + directory does not exist.""" + with contextlib.suppress(FileNotFoundError): self.rmtree() - except OSError: - _, e, _ = sys.exc_info() - if e.errno != errno.ENOENT: - raise return self def chdir(self): - """ .. seealso:: :func:`os.chdir` """ + """.. seealso:: :func:`os.chdir`""" os.chdir(self) cd = chdir def merge_tree( - self, dst, symlinks=False, - # * - update=False, - copy_function=shutil.copy2, - ignore=lambda dir, contents: []): + self, + dst, + symlinks=False, + *, + copy_function=shutil.copy2, + ignore=lambda dir, contents: [], + ): """ Copy entire contents of self to dst, overwriting existing contents in dst with those in self. @@ -1397,15 +1282,6 @@ class Path(text_type): dst = self._next_class(dst) dst.makedirs_p() - if update: - warnings.warn( - "Update is deprecated; " - "use copy_function=only_newer(shutil.copy2)", - DeprecationWarning, - stacklevel=2, - ) - copy_function = only_newer(copy_function) - sources = self.listdir() _ignored = ignore(self, [item.name for item in sources]) @@ -1421,7 +1297,6 @@ class Path(text_type): source.merge_tree( dest, symlinks=symlinks, - update=update, copy_function=copy_function, ignore=ignore, ) @@ -1434,22 +1309,29 @@ class Path(text_type): # --- Special stuff from os if hasattr(os, 'chroot'): - def chroot(self): - """ .. seealso:: :func:`os.chroot` """ + + def chroot(self): # pragma: nocover + """.. seealso:: :func:`os.chroot`""" os.chroot(self) if hasattr(os, 'startfile'): - def startfile(self): - """ .. seealso:: :func:`os.startfile` """ - os.startfile(self) + + def startfile(self, *args, **kwargs): # pragma: nocover + """.. seealso:: :func:`os.startfile`""" + os.startfile(self, *args, **kwargs) return self # in-place re-writing, courtesy of Martijn Pieters # http://www.zopatista.com/python/2013/11/26/inplace-file-rewriting/ @contextlib.contextmanager def in_place( - self, mode='r', buffering=-1, encoding=None, errors=None, - newline=None, backup_extension=None, + self, + mode='r', + buffering=-1, + encoding=None, + errors=None, + newline=None, + backup_extension=None, ): """ A context in which a file may be re-written in-place with @@ -1476,68 +1358,62 @@ class Path(text_type): Thereafter, the file at `filename` will have line numbers in it. 
""" - import io - if set(mode).intersection('wa+'): raise ValueError('Only read-only file modes can be used') # move existing file to backup, create new file with same permissions # borrowed extensively from the fileinput module backup_fn = self + (backup_extension or os.extsep + 'bak') - try: - os.unlink(backup_fn) - except os.error: - pass - os.rename(self, backup_fn) + backup_fn.remove_p() + self.rename(backup_fn) readable = io.open( - backup_fn, mode, buffering=buffering, - encoding=encoding, errors=errors, newline=newline, + backup_fn, + mode, + buffering=buffering, + encoding=encoding, + errors=errors, + newline=newline, ) try: perm = os.fstat(readable.fileno()).st_mode except OSError: - writable = open( - self, 'w' + mode.replace('r', ''), - buffering=buffering, encoding=encoding, errors=errors, + writable = self.open( + 'w' + mode.replace('r', ''), + buffering=buffering, + encoding=encoding, + errors=errors, newline=newline, ) else: os_mode = os.O_CREAT | os.O_WRONLY | os.O_TRUNC - if hasattr(os, 'O_BINARY'): - os_mode |= os.O_BINARY + os_mode |= getattr(os, 'O_BINARY', 0) fd = os.open(self, os_mode, perm) writable = io.open( - fd, "w" + mode.replace('r', ''), - buffering=buffering, encoding=encoding, errors=errors, + fd, + "w" + mode.replace('r', ''), + buffering=buffering, + encoding=encoding, + errors=errors, newline=newline, ) - try: - if hasattr(os, 'chmod'): - os.chmod(self, perm) - except OSError: - pass + with contextlib.suppress(OSError, AttributeError): + self.chmod(perm) try: yield readable, writable except Exception: # move backup back readable.close() writable.close() - try: - os.unlink(self) - except os.error: - pass - os.rename(backup_fn, self) + self.remove_p() + backup_fn.rename(self) raise else: readable.close() writable.close() finally: - try: - os.unlink(backup_fn) - except os.error: - pass + backup_fn.remove_p() - @ClassProperty + @classes.ClassProperty @classmethod def special(cls): """ @@ -1563,24 +1439,64 @@ class Path(text_type): return functools.partial(SpecialResolver, cls) +class DirectoryNotEmpty(OSError): + @staticmethod + @contextlib.contextmanager + def translate(): + try: + yield + except OSError as exc: + if exc.errno == errno.ENOTEMPTY: + raise DirectoryNotEmpty(*exc.args) from exc + raise + + def only_newer(copy_func): """ Wrap a copy function (like shutil.copy2) to return the dst if it's newer than the source. """ + @functools.wraps(copy_func) def wrapper(src, dst, *args, **kwargs): - is_newer_dst = ( - dst.exists() - and dst.getmtime() >= src.getmtime() - ) + is_newer_dst = dst.exists() and dst.getmtime() >= src.getmtime() if is_newer_dst: return dst return copy_func(src, dst, *args, **kwargs) + return wrapper -class SpecialResolver(object): +class ExtantPath(Path): + """ + >>> ExtantPath('.') + ExtantPath('.') + >>> ExtantPath('does-not-exist') + Traceback (most recent call last): + OSError: does-not-exist does not exist. + """ + + def _validate(self): + if not self.exists(): + raise OSError(f"{self} does not exist.") + + +class ExtantFile(Path): + """ + >>> ExtantFile('.') + Traceback (most recent call last): + FileNotFoundError: . does not exist as a file. + >>> ExtantFile('does-not-exist') + Traceback (most recent call last): + FileNotFoundError: does-not-exist does not exist as a file. 
+ """ + + def _validate(self): + if not self.isfile(): + raise FileNotFoundError(f"{self} does not exist as a file.") + + +class SpecialResolver: class ResolverScope: def __init__(self, paths, scope): self.paths = paths @@ -1592,13 +1508,8 @@ class SpecialResolver(object): def __init__(self, path_class, *args, **kwargs): appdirs = importlib.import_module('appdirs') - # let appname default to None until - # https://github.com/ActiveState/appdirs/issues/55 is solved. - not args and kwargs.setdefault('appname', None) - vars(self).update( - path_class=path_class, - wrapper=appdirs.AppDirs(*args, **kwargs), + path_class=path_class, wrapper=appdirs.AppDirs(*args, **kwargs) ) def __getattr__(self, scope): @@ -1619,11 +1530,10 @@ class Multi: """ A mix-in for a Path which may contain multiple Path separated by pathsep. """ + @classmethod def for_class(cls, path_cls): name = 'Multi' + path_cls.__name__ - if PY2: - name = str(name) return type(name, (cls, path_cls), {}) @classmethod @@ -1635,17 +1545,13 @@ class Multi: def __iter__(self): return iter(map(self._next_class, self.split(os.pathsep))) - @ClassProperty + @classes.ClassProperty @classmethod def _next_class(cls): """ Multi-subclasses should use the parent class """ - return next( - class_ - for class_ in cls.__mro__ - if not issubclass(class_, Multi) - ) + return next(class_ for class_ in cls.__mro__ if not issubclass(class_, Multi)) class TempDir(Path): @@ -1654,17 +1560,21 @@ class TempDir(Path): constructed with the same parameters that you can use as a context manager. - Example:: + For example: - with TempDir() as d: - # do stuff with the Path object "d" + >>> with TempDir() as d: + ... d.isdir() and isinstance(d, Path) + True - # here the directory is deleted automatically + The directory is deleted automatically. + + >>> d.isdir() + False .. seealso:: :func:`tempfile.mkdtemp` """ - @ClassProperty + @classes.ClassProperty @classmethod def _next_class(cls): return Path @@ -1684,138 +1594,21 @@ class TempDir(Path): return self._next_class(self) def __exit__(self, exc_type, exc_value, traceback): - if not exc_value: - self.rmtree() + self.rmtree() -# For backwards compatibility. -tempdir = TempDir +class Handlers: + def strict(msg): + raise + def warn(msg): + warnings.warn(msg, TreeWalkWarning) -def _multi_permission_mask(mode): - """ - Support multiple, comma-separated Unix chmod symbolic modes. + def ignore(msg): + pass - >>> _multi_permission_mask('a=r,u+w')(0) == 0o644 - True - """ - def compose(f, g): - return lambda *args, **kwargs: g(f(*args, **kwargs)) - return functools.reduce(compose, map(_permission_mask, mode.split(','))) - - -def _permission_mask(mode): - """ - Convert a Unix chmod symbolic mode like ``'ugo+rwx'`` to a function - suitable for applying to a mask to affect that change. 
- - >>> mask = _permission_mask('ugo+rwx') - >>> mask(0o554) == 0o777 - True - - >>> _permission_mask('go-x')(0o777) == 0o766 - True - - >>> _permission_mask('o-x')(0o445) == 0o444 - True - - >>> _permission_mask('a+x')(0) == 0o111 - True - - >>> _permission_mask('a=rw')(0o057) == 0o666 - True - - >>> _permission_mask('u=x')(0o666) == 0o166 - True - - >>> _permission_mask('g=')(0o157) == 0o107 - True - """ - # parse the symbolic mode - parsed = re.match('(?P[ugoa]+)(?P[-+=])(?P[rwx]*)$', mode) - if not parsed: - raise ValueError("Unrecognized symbolic mode", mode) - - # generate a mask representing the specified permission - spec_map = dict(r=4, w=2, x=1) - specs = (spec_map[perm] for perm in parsed.group('what')) - spec = functools.reduce(operator.or_, specs, 0) - - # now apply spec to each subject in who - shift_map = dict(u=6, g=3, o=0) - who = parsed.group('who').replace('a', 'ugo') - masks = (spec << shift_map[subj] for subj in who) - mask = functools.reduce(operator.or_, masks) - - op = parsed.group('op') - - # if op is -, invert the mask - if op == '-': - mask ^= 0o777 - - # if op is =, retain extant values for unreferenced subjects - if op == '=': - masks = (0o7 << shift_map[subj] for subj in who) - retain = functools.reduce(operator.or_, masks) ^ 0o777 - - op_map = { - '+': operator.or_, - '-': operator.and_, - '=': lambda mask, target: target & retain ^ mask, - } - return functools.partial(op_map[op], mask) - - -class CaseInsensitivePattern(matchers.CaseInsensitive): - def __init__(self, value): - warnings.warn( - "Use matchers.CaseInsensitive instead", - DeprecationWarning, - stacklevel=2, - ) - super(CaseInsensitivePattern, self).__init__(value) - - -class FastPath(Path): - def __init__(self, *args, **kwargs): - warnings.warn( - "Use Path, as FastPath no longer holds any advantage", - DeprecationWarning, - stacklevel=2, - ) - super(FastPath, self).__init__(*args, **kwargs) - - -def patch_for_linux_python2(): - """ - As reported in #130, when Linux users create filenames - not in the file system encoding, it creates problems on - Python 2. This function attempts to patch the os module - to make it behave more like that on Python 3. 
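``CaseInsensitivePattern`` was already a thin deprecated alias, and its warning names the replacement; the equivalent call with the new top-level ``matchers`` module looks like this sketch::

    from path import Path, matchers

    # Matches .py, .PY, .Py, .pY regardless of platform normcase rules.
    python_files = Path('.').files(matchers.CaseInsensitive('*.py'))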
- """ - if not PY2 or platform.system() != 'Linux': - return - - try: - import backports.os - except ImportError: - return - - class OS: - """ - The proxy to the os module - """ - def __init__(self, wrapped): - self._orig = wrapped - - def __getattr__(self, name): - return getattr(self._orig, name) - - def listdir(self, *args, **kwargs): - items = self._orig.listdir(*args, **kwargs) - return list(map(backports.os.fsdecode, items)) - - globals().update(os=OS(os)) - - -patch_for_linux_python2() + @classmethod + def _resolve(cls, param): + if not callable(param) and param not in vars(Handlers): + raise ValueError("invalid errors parameter") + return vars(cls).get(param, param) diff --git a/libs/win/path/__init__.pyi b/libs/win/path/__init__.pyi new file mode 100644 index 00000000..a0b8f561 --- /dev/null +++ b/libs/win/path/__init__.pyi @@ -0,0 +1,483 @@ +from __future__ import annotations + +import builtins +import contextlib +import os +import shutil +import sys +from io import ( + BufferedRandom, + BufferedReader, + BufferedWriter, + FileIO, + TextIOWrapper, +) +from types import ModuleType, TracebackType +from typing import ( + Any, + AnyStr, + BinaryIO, + Callable, + Generator, + Iterable, + Iterator, + IO, + List, + Optional, + Set, + Tuple, + Type, + Union, + overload, +) + +from _typeshed import ( + OpenBinaryMode, + OpenBinaryModeUpdating, + OpenBinaryModeReading, + OpenBinaryModeWriting, + OpenTextMode, + Self, +) +from typing_extensions import Literal + +from . import classes + +# Type for the match argument for several methods +_Match = Optional[Union[str, Callable[[str], bool], Callable[[Path], bool]]] + +class TreeWalkWarning(Warning): + pass + +class Traversal: + follow: Callable[[Path], bool] + def __init__(self, follow: Callable[[Path], bool]): ... + def __call__( + self, + walker: Generator[Path, Optional[Callable[[], bool]], None], + ) -> Iterator[Path]: ... + +class Path(str): + module: Any + def __init__(self, other: Any = ...) -> None: ... + @classmethod + def using_module(cls, module: ModuleType) -> Type[Path]: ... + @classes.ClassProperty + @classmethod + def _next_class(cls: Type[Self]) -> Type[Self]: ... + def __repr__(self) -> str: ... + def __add__(self: Self, more: str) -> Self: ... + def __radd__(self: Self, other: str) -> Self: ... + def __div__(self: Self, rel: str) -> Self: ... + def __truediv__(self: Self, rel: str) -> Self: ... + def __rdiv__(self: Self, rel: str) -> Self: ... + def __rtruediv__(self: Self, rel: str) -> Self: ... + def __enter__(self: Self) -> Self: ... + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: ... + @classmethod + def getcwd(cls: Type[Self]) -> Self: ... + def abspath(self: Self) -> Self: ... + def normcase(self: Self) -> Self: ... + def normpath(self: Self) -> Self: ... + def realpath(self: Self) -> Self: ... + def expanduser(self: Self) -> Self: ... + def expandvars(self: Self) -> Self: ... + def dirname(self: Self) -> Self: ... + def basename(self: Self) -> Self: ... + def expand(self: Self) -> Self: ... + @property + def stem(self) -> str: ... + @property + def ext(self) -> str: ... + def with_suffix(self: Self, suffix: str) -> Self: ... + @property + def drive(self: Self) -> Self: ... + @property + def parent(self: Self) -> Self: ... + @property + def name(self: Self) -> Self: ... + def splitpath(self: Self) -> Tuple[Self, str]: ... + def splitdrive(self: Self) -> Tuple[Self, Self]: ... 
+ def splitext(self: Self) -> Tuple[Self, str]: ... + def stripext(self: Self) -> Self: ... + @classes.multimethod + def joinpath(cls: Self, first: str, *others: str) -> Self: ... + def splitall(self: Self) -> List[Union[Self, str]]: ... + def parts(self: Self) -> Tuple[Union[Self, str], ...]: ... + def _parts(self: Self) -> Iterator[Union[Self, str]]: ... + def _parts_iter(self: Self) -> Iterator[Union[Self, str]]: ... + def relpath(self: Self, start: str = ...) -> Self: ... + def relpathto(self: Self, dest: str) -> Self: ... + # --- Listing, searching, walking, and matching + def listdir(self: Self, match: _Match = ...) -> List[Self]: ... + def dirs(self: Self, match: _Match = ...) -> List[Self]: ... + def files(self: Self, match: _Match = ...) -> List[Self]: ... + def walk( + self: Self, + match: _Match = ..., + errors: str = ..., + ) -> Generator[Self, Optional[Callable[[], bool]], None]: ... + def walkdirs( + self: Self, + match: _Match = ..., + errors: str = ..., + ) -> Iterator[Self]: ... + def walkfiles( + self: Self, + match: _Match = ..., + errors: str = ..., + ) -> Iterator[Self]: ... + def fnmatch( + self, + pattern: Union[Path, str], + normcase: Optional[Callable[[str], str]] = ..., + ) -> bool: ... + def glob(self: Self, pattern: str) -> List[Self]: ... + def iglob(self: Self, pattern: str) -> Iterator[Self]: ... + @overload + def open( + self, + mode: OpenTextMode = ..., + buffering: int = ..., + encoding: Optional[str] = ..., + errors: Optional[str] = ..., + newline: Optional[str] = ..., + closefd: bool = ..., + opener: Optional[Callable[[str, int], int]] = ..., + ) -> TextIOWrapper: ... + @overload + def open( + self, + mode: OpenBinaryMode, + buffering: Literal[0], + encoding: Optional[str] = ..., + errors: Optional[str] = ..., + newline: Optional[str] = ..., + closefd: bool = ..., + opener: Callable[[str, int], int] = ..., + ) -> FileIO: ... + @overload + def open( + self, + mode: OpenBinaryModeUpdating, + buffering: Literal[-1, 1] = ..., + encoding: Optional[str] = ..., + errors: Optional[str] = ..., + newline: Optional[str] = ..., + closefd: bool = ..., + opener: Callable[[str, int], int] = ..., + ) -> BufferedRandom: ... + @overload + def open( + self, + mode: OpenBinaryModeReading, + buffering: Literal[-1, 1] = ..., + encoding: Optional[str] = ..., + errors: Optional[str] = ..., + newline: Optional[str] = ..., + closefd: bool = ..., + opener: Callable[[str, int], int] = ..., + ) -> BufferedReader: ... + @overload + def open( + self, + mode: OpenBinaryModeWriting, + buffering: Literal[-1, 1] = ..., + encoding: Optional[str] = ..., + errors: Optional[str] = ..., + newline: Optional[str] = ..., + closefd: bool = ..., + opener: Callable[[str, int], int] = ..., + ) -> BufferedWriter: ... + @overload + def open( + self, + mode: OpenBinaryMode, + buffering: int, + encoding: Optional[str] = ..., + errors: Optional[str] = ..., + newline: Optional[str] = ..., + closefd: bool = ..., + opener: Callable[[str, int], int] = ..., + ) -> BinaryIO: ... + @overload + def open( + self, + mode: str, + buffering: int = ..., + encoding: Optional[str] = ..., + errors: Optional[str] = ..., + newline: Optional[str] = ..., + closefd: bool = ..., + opener: Callable[[str, int], int] = ..., + ) -> IO[Any]: ... + def bytes(self) -> builtins.bytes: ... 
+ @overload + def chunks( + self, + size: int, + mode: OpenTextMode = ..., + buffering: int = ..., + encoding: Optional[str] = ..., + errors: Optional[str] = ..., + newline: Optional[str] = ..., + closefd: bool = ..., + opener: Optional[Callable[[str, int], int]] = ..., + ) -> Iterator[str]: ... + @overload + def chunks( + self, + size: int, + mode: OpenBinaryMode, + buffering: int = ..., + encoding: Optional[str] = ..., + errors: Optional[str] = ..., + newline: Optional[str] = ..., + closefd: bool = ..., + opener: Optional[Callable[[str, int], int]] = ..., + ) -> Iterator[builtins.bytes]: ... + @overload + def chunks( + self, + size: int, + mode: str, + buffering: int = ..., + encoding: Optional[str] = ..., + errors: Optional[str] = ..., + newline: Optional[str] = ..., + closefd: bool = ..., + opener: Optional[Callable[[str, int], int]] = ..., + ) -> Iterator[Union[str, builtins.bytes]]: ... + def write_bytes(self, bytes: builtins.bytes, append: bool = ...) -> None: ... + def read_text( + self, encoding: Optional[str] = ..., errors: Optional[str] = ... + ) -> str: ... + def read_bytes(self) -> builtins.bytes: ... + def text(self, encoding: Optional[str] = ..., errors: str = ...) -> str: ... + @overload + def write_text( + self, + text: str, + encoding: Optional[str] = ..., + errors: str = ..., + linesep: Optional[str] = ..., + append: bool = ..., + ) -> None: ... + @overload + def write_text( + self, + text: builtins.bytes, + encoding: None = ..., + errors: str = ..., + linesep: Optional[str] = ..., + append: bool = ..., + ) -> None: ... + def lines( + self, + encoding: Optional[str] = ..., + errors: Optional[str] = ..., + retain: bool = ..., + ) -> List[str]: ... + def write_lines( + self, + lines: List[str], + encoding: Optional[str] = ..., + errors: str = ..., + linesep: Optional[str] = ..., + append: bool = ..., + ) -> None: ... + def read_md5(self) -> builtins.bytes: ... + def read_hash(self, hash_name: str) -> builtins.bytes: ... + def read_hexhash(self, hash_name: str) -> str: ... + def isabs(self) -> bool: ... + def exists(self) -> bool: ... + def isdir(self) -> bool: ... + def isfile(self) -> bool: ... + def islink(self) -> bool: ... + def ismount(self) -> bool: ... + def samefile(self, other: str) -> bool: ... + def getatime(self) -> float: ... + @property + def atime(self) -> float: ... + def getmtime(self) -> float: ... + @property + def mtime(self) -> float: ... + def getctime(self) -> float: ... + @property + def ctime(self) -> float: ... + def getsize(self) -> int: ... + @property + def size(self) -> int: ... + def access( + self, + mode: int, + *, + dir_fd: Optional[int] = ..., + effective_ids: bool = ..., + follow_symlinks: bool = ..., + ) -> bool: ... + def stat(self) -> os.stat_result: ... + def lstat(self) -> os.stat_result: ... + def get_owner(self) -> str: ... + @property + def owner(self) -> str: ... + if sys.platform != 'win32': + def statvfs(self) -> os.statvfs_result: ... + def pathconf(self, name: Union[str, int]) -> int: ... + + def utime( + self, + times: Union[Tuple[int, int], Tuple[float, float], None] = ..., + *, + ns: Tuple[int, int] = ..., + dir_fd: Optional[int] = ..., + follow_symlinks: bool = ..., + ) -> Path: ... + def chmod(self: Self, mode: Union[str, int]) -> Self: ... + if sys.platform != 'win32': + def chown( + self: Self, uid: Union[int, str] = ..., gid: Union[int, str] = ... + ) -> Self: ... + + def rename(self: Self, new: str) -> Self: ... + def renames(self: Self, new: str) -> Self: ... + def mkdir(self: Self, mode: int = ...) -> Self: ... 
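+    # The *_p variants that follow are the "forgiving" forms: like `mkdir -p`,
+    # they are expected not to raise when the target is already in the desired
+    # state. A minimal sketch, assuming a writable working directory:
+    #   Path('build/out').makedirs_p()  # ok even if it already exists
+    #   Path('build/out').rmtree_p()    # ok even if already gone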
+ def mkdir_p(self: Self, mode: int = ...) -> Self: ... + def makedirs(self: Self, mode: int = ...) -> Self: ... + def makedirs_p(self: Self, mode: int = ...) -> Self: ... + def rmdir(self: Self) -> Self: ... + def rmdir_p(self: Self) -> Self: ... + def removedirs(self: Self) -> Self: ... + def removedirs_p(self: Self) -> Self: ... + def touch(self: Self) -> Self: ... + def remove(self: Self) -> Self: ... + def remove_p(self: Self) -> Self: ... + def unlink(self: Self) -> Self: ... + def unlink_p(self: Self) -> Self: ... + def link(self: Self, newpath: str) -> Self: ... + def symlink(self: Self, newlink: Optional[str] = ...) -> Self: ... + def readlink(self: Self) -> Self: ... + def readlinkabs(self: Self) -> Self: ... + def copyfile(self, dst: str, *, follow_symlinks: bool = ...) -> str: ... + def copymode(self, dst: str, *, follow_symlinks: bool = ...) -> None: ... + def copystat(self, dst: str, *, follow_symlinks: bool = ...) -> None: ... + def copy(self, dst: str, *, follow_symlinks: bool = ...) -> Any: ... + def copy2(self, dst: str, *, follow_symlinks: bool = ...) -> Any: ... + def copytree( + self, + dst: str, + symlinks: bool = ..., + ignore: Optional[Callable[[str, list[str]], Iterable[str]]] = ..., + copy_function: Callable[[str, str], None] = ..., + ignore_dangling_symlinks: bool = ..., + dirs_exist_ok: bool = ..., + ) -> Any: ... + def move( + self, dst: str, copy_function: Callable[[str, str], None] = ... + ) -> Any: ... + def rmtree( + self, + ignore_errors: bool = ..., + onerror: Optional[Callable[[Any, Any, Any], Any]] = ..., + ) -> None: ... + def rmtree_p(self: Self) -> Self: ... + def chdir(self) -> None: ... + def cd(self) -> None: ... + def merge_tree( + self, + dst: str, + symlinks: bool = ..., + *, + copy_function: Callable[[str, str], None] = ..., + ignore: Callable[[Any, List[str]], Union[List[str], Set[str]]] = ..., + ) -> None: ... + if sys.platform != 'win32': + def chroot(self) -> None: ... + if sys.platform == 'win32': + def startfile(self: Self, operation: Optional[str] = ...) -> Self: ... + + @contextlib.contextmanager + def in_place( + self, + mode: str = ..., + buffering: int = ..., + encoding: Optional[str] = ..., + errors: Optional[str] = ..., + newline: Optional[str] = ..., + backup_extension: Optional[str] = ..., + ) -> Iterator[Tuple[IO[Any], IO[Any]]]: ... + @classes.ClassProperty + @classmethod + def special(cls) -> Callable[[Optional[str]], SpecialResolver]: ... + +class DirectoryNotEmpty(OSError): + @staticmethod + def translate() -> Iterator[None]: ... + +def only_newer(copy_func: Callable[[str, str], None]) -> Callable[[str, str], None]: ... + +class ExtantPath(Path): + def _validate(self) -> None: ... + +class ExtantFile(Path): + def _validate(self) -> None: ... + +class SpecialResolver: + class ResolverScope: + def __init__(self, paths: SpecialResolver, scope: str) -> None: ... + def __getattr__(self, class_: str) -> MultiPathType: ... + + def __init__( + self, + path_class: type, + appname: Optional[str] = ..., + appauthor: Optional[str] = ..., + version: Optional[str] = ..., + roaming: bool = ..., + multipath: bool = ..., + ): ... + def __getattr__(self, scope: str) -> ResolverScope: ... + def get_dir(self, scope: str, class_: str) -> MultiPathType: ... + +class Multi: + @classmethod + def for_class(cls, path_cls: type) -> Type[MultiPathType]: ... + @classmethod + def detect(cls, input: str) -> MultiPathType: ... + def __iter__(self) -> Iterator[Path]: ... + @classes.ClassProperty + @classmethod + def _next_class(cls) -> Type[Path]: ... 
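+
+# Illustrative note (an assumption about the runtime behaviour, not part of
+# the stub itself): Multi models an os.pathsep-joined list of directories,
+# so Multi.detect('/usr/share:/usr/local/share') should iterate as two Paths.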
+ +class MultiPathType(Multi, Path): + pass + +class TempDir(Path): + @classes.ClassProperty + @classmethod + def _next_class(cls) -> Type[Path]: ... + def __new__( + cls: Type[Self], + suffix: Optional[AnyStr] = ..., + prefix: Optional[AnyStr] = ..., + dir: Optional[Union[AnyStr, os.PathLike[AnyStr]]] = ..., + ) -> Self: ... + def __init__(self) -> None: ... + def __enter__(self) -> Path: ... # type: ignore + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: ... + +class Handlers: + @classmethod + def _resolve( + cls, param: Union[str, Callable[[str], None]] + ) -> Callable[[str], None]: ... diff --git a/libs/win/path/classes.py b/libs/win/path/classes.py new file mode 100644 index 00000000..b6101d0a --- /dev/null +++ b/libs/win/path/classes.py @@ -0,0 +1,27 @@ +import functools + + +class ClassProperty(property): + def __get__(self, cls, owner): + return self.fget.__get__(None, owner)() + + +class multimethod: + """ + Acts like a classmethod when invoked from the class and like an + instancemethod when invoked from the instance. + """ + + def __init__(self, func): + self.func = func + + def __get__(self, instance, owner): + """ + If called on an instance, pass the instance as the first + argument. + """ + return ( + functools.partial(self.func, owner) + if instance is None + else functools.partial(self.func, owner, instance) + ) diff --git a/libs/win/path/classes.pyi b/libs/win/path/classes.pyi new file mode 100644 index 00000000..2878c48b --- /dev/null +++ b/libs/win/path/classes.pyi @@ -0,0 +1,8 @@ +from typing import Any, Callable, Optional + +class ClassProperty(property): + def __get__(self, cls: Any, owner: Optional[type] = ...) -> Any: ... + +class multimethod: + def __init__(self, func: Callable[..., Any]): ... + def __get__(self, instance: Any, owner: Optional[type]) -> Any: ... diff --git a/libs/win/path/masks.py b/libs/win/path/masks.py new file mode 100644 index 00000000..761e51f8 --- /dev/null +++ b/libs/win/path/masks.py @@ -0,0 +1,85 @@ +import re +import functools +import operator + + +# from jaraco.functools +def compose(*funcs): + compose_two = lambda f1, f2: lambda *args, **kwargs: f1(f2(*args, **kwargs)) # noqa + return functools.reduce(compose_two, funcs) + + +def compound(mode): + """ + Support multiple, comma-separated Unix chmod symbolic modes. + + >>> oct(compound('a=r,u+w')(0)) + '0o644' + """ + return compose(*map(simple, reversed(mode.split(',')))) + + +def simple(mode): + """ + Convert a Unix chmod symbolic mode like ``'ugo+rwx'`` to a function + suitable for applying to a mask to affect that change. 
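+
+    For instance, granting the owner write access:
+
+    >>> oct(simple('u+w')(0o444))
+    '0o644'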
+
+    >>> mask = simple('ugo+rwx')
+    >>> mask(0o554) == 0o777
+    True
+
+    >>> simple('go-x')(0o777) == 0o766
+    True
+
+    >>> simple('o-x')(0o445) == 0o444
+    True
+
+    >>> simple('a+x')(0) == 0o111
+    True
+
+    >>> simple('a=rw')(0o057) == 0o666
+    True
+
+    >>> simple('u=x')(0o666) == 0o166
+    True
+
+    >>> simple('g=')(0o157) == 0o107
+    True
+
+    >>> simple('gobbledeegook')
+    Traceback (most recent call last):
+    ValueError: ('Unrecognized symbolic mode', 'gobbledeegook')
+    """
+    # parse the symbolic mode
+    parsed = re.match('(?P<who>[ugoa]+)(?P<op>[-+=])(?P<what>[rwx]*)$', mode)
+    if not parsed:
+        raise ValueError("Unrecognized symbolic mode", mode)
+
+    # generate a mask representing the specified permission
+    spec_map = dict(r=4, w=2, x=1)
+    specs = (spec_map[perm] for perm in parsed.group('what'))
+    spec = functools.reduce(operator.or_, specs, 0)
+
+    # now apply spec to each subject in who
+    shift_map = dict(u=6, g=3, o=0)
+    who = parsed.group('who').replace('a', 'ugo')
+    masks = (spec << shift_map[subj] for subj in who)
+    mask = functools.reduce(operator.or_, masks)
+
+    op = parsed.group('op')
+
+    # if op is -, invert the mask
+    if op == '-':
+        mask ^= 0o777
+
+    # if op is =, retain extant values for unreferenced subjects
+    if op == '=':
+        masks = (0o7 << shift_map[subj] for subj in who)
+        retain = functools.reduce(operator.or_, masks) ^ 0o777
+
+    op_map = {
+        '+': operator.or_,
+        '-': operator.and_,
+        '=': lambda mask, target: target & retain ^ mask,
+    }
+    return functools.partial(op_map[op], mask)
diff --git a/libs/win/path/masks.pyi b/libs/win/path/masks.pyi
new file mode 100644
index 00000000..d69bf202
--- /dev/null
+++ b/libs/win/path/masks.pyi
@@ -0,0 +1,5 @@
+from typing import Any, Callable
+
+def compose(*funcs: Callable[..., Any]) -> Callable[..., Any]: ...
+def compound(mode: str) -> Callable[[int], int]: ...
+def simple(mode: str) -> Callable[[int], int]: ...
diff --git a/libs/win/path/matchers.py b/libs/win/path/matchers.py
new file mode 100644
index 00000000..63ca218a
--- /dev/null
+++ b/libs/win/path/matchers.py
@@ -0,0 +1,59 @@
+import ntpath
+import fnmatch
+
+
+def load(param):
+    """
+    If the supplied parameter is a string, assume it's a simple
+    pattern.
+    """
+    return (
+        Pattern(param)
+        if isinstance(param, str)
+        else param
+        if param is not None
+        else Null()
+    )
+
+
+class Base:
+    pass
+
+
+class Null(Base):
+    def __call__(self, path):
+        return True
+
+
+class Pattern(Base):
+    def __init__(self, pattern):
+        self.pattern = pattern
+
+    def get_pattern(self, normcase):
+        try:
+            return self._pattern
+        except AttributeError:
+            pass
+        self._pattern = normcase(self.pattern)
+        return self._pattern
+
+    def __call__(self, path):
+        normcase = getattr(self, 'normcase', path.module.normcase)
+        pattern = self.get_pattern(normcase)
+        return fnmatch.fnmatchcase(normcase(path.name), pattern)
+
+
+class CaseInsensitive(Pattern):
+    """
+    A Pattern with a ``'normcase'`` property, suitable for passing to
+    :meth:`listdir`, :meth:`dirs`, :meth:`files`, :meth:`walk`,
+    :meth:`walkdirs`, or :meth:`walkfiles` to match case-insensitive.
+ + For example, to get all files ending in .py, .Py, .pY, or .PY in the + current directory:: + + from path import Path, matchers + Path('.').files(matchers.CaseInsensitive('*.py')) + """ + + normcase = staticmethod(ntpath.normcase) diff --git a/libs/win/path/matchers.pyi b/libs/win/path/matchers.pyi new file mode 100644 index 00000000..80acd0b1 --- /dev/null +++ b/libs/win/path/matchers.pyi @@ -0,0 +1,28 @@ +from __future__ import annotations + +from typing import Any, Callable, overload + +from typing_extensions import Literal + +from path import Path + +@overload +def load(param: None) -> Null: ... +@overload +def load(param: str) -> Pattern: ... +@overload +def load(param: Any) -> Any: ... + +class Base: + pass + +class Null(Base): + def __call__(self, path: str) -> Literal[True]: ... + +class Pattern(Base): + def __init__(self, pattern: str) -> None: ... + def get_pattern(self, normcase: Callable[[str], str]) -> str: ... + def __call__(self, path: Path) -> bool: ... + +class CaseInsensitive(Pattern): + normcase: Callable[[str], str] diff --git a/libs/win/path/py.typed b/libs/win/path/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/libs/win/path/py37compat.py b/libs/win/path/py37compat.py new file mode 100644 index 00000000..f2a9e8b4 --- /dev/null +++ b/libs/win/path/py37compat.py @@ -0,0 +1,125 @@ +import functools +import os + + +def best_realpath(module): + """ + Given a path module (i.e. ntpath, posixpath), + determine the best 'realpath' function to use + for best future compatibility. + """ + needs_backport = module.realpath is module.abspath + return realpath_backport if needs_backport else module.realpath + + +# backport taken from jaraco.windows 5 +def realpath_backport(path): + if isinstance(path, str): + prefix = '\\\\?\\' + unc_prefix = prefix + 'UNC' + new_unc_prefix = '\\' + cwd = os.getcwd() + else: + prefix = b'\\\\?\\' + unc_prefix = prefix + b'UNC' + new_unc_prefix = b'\\' + cwd = os.getcwdb() + had_prefix = path.startswith(prefix) + path, ok = _resolve_path(cwd, path, {}) + # The path returned by _getfinalpathname will always start with \\?\ - + # strip off that prefix unless it was already provided on the original + # path. + if not had_prefix: + # For UNC paths, the prefix will actually be \\?\UNC - handle that + # case as well. + if path.startswith(unc_prefix): + path = new_unc_prefix + path[len(unc_prefix) :] + elif path.startswith(prefix): + path = path[len(prefix) :] + return path + + +def _resolve_path(path, rest, seen): # noqa: C901 + # Windows normalizes the path before resolving symlinks; be sure to + # follow the same behavior. + rest = os.path.normpath(rest) + + if isinstance(rest, str): + sep = '\\' + else: + sep = b'\\' + + if os.path.isabs(rest): + drive, rest = os.path.splitdrive(rest) + path = drive + sep + rest = rest[1:] + + while rest: + name, _, rest = rest.partition(sep) + new_path = os.path.join(path, name) if path else name + if os.path.exists(new_path): + if not rest: + # The whole path exists. Resolve it using the OS. + path = os.path._getfinalpathname(new_path) + else: + # The OS can resolve `new_path`; keep traversing the path. + path = new_path + elif not os.path.lexists(new_path): + # `new_path` does not exist on the filesystem at all. Use the + # OS to resolve `path`, if it exists, and then append the + # remainder. 
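+            # (For example, resolving r'C:\link\missing\leaf' where 'missing'
+            # does not exist: resolve the existing prefix via
+            # _getfinalpathname, then re-append r'missing\leaf' untouched.)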
+            if os.path.exists(path):
+                path = os.path._getfinalpathname(path)
+            rest = os.path.join(name, rest) if rest else name
+            return os.path.join(path, rest), True
+        else:
+            # We have a symbolic link that the OS cannot resolve. Try to
+            # resolve it ourselves.
+
+            # On Windows, symbolic link resolution can be partially or
+            # fully disabled [1]. The end result of a disabled symlink
+            # appears the same as a broken symlink (lexists() returns True
+            # but exists() returns False). And in both cases, the link can
+            # still be read using readlink(). Call stat() and check the
+            # resulting error code to ensure we don't circumvent the
+            # Windows symbolic link restrictions.
+            # [1] https://technet.microsoft.com/en-us/library/cc754077.aspx
+            try:
+                os.stat(new_path)
+            except OSError as e:
+                # WinError 1463: The symbolic link cannot be followed
+                # because its type is disabled.
+                if e.winerror == 1463:
+                    raise
+
+            key = os.path.normcase(new_path)
+            if key in seen:
+                # This link has already been seen; try to use the
+                # previously resolved value.
+                path = seen[key]
+                if path is None:
+                    # It has not yet been resolved, which means we must
+                    # have a symbolic link loop. Return what we have
+                    # resolved so far plus the remainder of the path (who
+                    # cares about the Zen of Python?).
+                    path = os.path.join(new_path, rest) if rest else new_path
+                    return path, False
+            else:
+                # Mark this link as in the process of being resolved.
+                seen[key] = None
+                # Try to resolve it.
+                path, ok = _resolve_path(path, os.readlink(new_path), seen)
+                if ok:
+                    # Resolution succeeded; store the resolved value.
+                    seen[key] = path
+                else:
+                    # Resolution failed; punt.
+                    return (os.path.join(path, rest) if rest else path), False
+    return path, True
+
+
+def lru_cache(user_function):
+    """
+    Support for lru_cache(user_function)
+    """
+    return functools.lru_cache()(user_function)
diff --git a/libs/win/path/py37compat.pyi b/libs/win/path/py37compat.pyi
new file mode 100644
index 00000000..ea62fa06
--- /dev/null
+++ b/libs/win/path/py37compat.pyi
@@ -0,0 +1,17 @@
+import os
+
+from types import ModuleType
+from typing import Any, AnyStr, Callable, Dict, Tuple, overload
+
+def best_realpath(module: ModuleType) -> Callable[[AnyStr], AnyStr]: ...
+@overload
+def realpath_backport(path: str) -> str: ...
+@overload
+def realpath_backport(path: bytes) -> bytes: ...
+@overload
+def _resolve_path(path: str, rest: str, seen: Dict[Any, Any]) -> Tuple[str, bool]: ...
+@overload
+def _resolve_path(
+    path: bytes, rest: bytes, seen: Dict[Any, Any]
+) -> Tuple[bytes, bool]: ...
+def lru_cache(user_function: Callable[..., Any]) -> Callable[..., Any]: ...
diff --git a/libs/win/pydantic/__init__.cp37-win_amd64.pyd b/libs/win/pydantic/__init__.cp37-win_amd64.pyd
new file mode 100644
index 00000000..e6293456
Binary files /dev/null and b/libs/win/pydantic/__init__.cp37-win_amd64.pyd differ
diff --git a/libs/win/pydantic/__init__.py b/libs/win/pydantic/__init__.py
new file mode 100644
index 00000000..3bf1418f
--- /dev/null
+++ b/libs/win/pydantic/__init__.py
@@ -0,0 +1,131 @@
+# flake8: noqa
+from .
import dataclasses +from .annotated_types import create_model_from_namedtuple, create_model_from_typeddict +from .class_validators import root_validator, validator +from .config import BaseConfig, ConfigDict, Extra +from .decorator import validate_arguments +from .env_settings import BaseSettings +from .error_wrappers import ValidationError +from .errors import * +from .fields import Field, PrivateAttr, Required +from .main import * +from .networks import * +from .parse import Protocol +from .tools import * +from .types import * +from .version import VERSION, compiled + +__version__ = VERSION + +# WARNING __all__ from .errors is not included here, it will be removed as an export here in v2 +# please use "from pydantic.errors import ..." instead +__all__ = [ + # annotated types utils + 'create_model_from_namedtuple', + 'create_model_from_typeddict', + # dataclasses + 'dataclasses', + # class_validators + 'root_validator', + 'validator', + # config + 'BaseConfig', + 'ConfigDict', + 'Extra', + # decorator + 'validate_arguments', + # env_settings + 'BaseSettings', + # error_wrappers + 'ValidationError', + # fields + 'Field', + 'Required', + # main + 'BaseModel', + 'create_model', + 'validate_model', + # network + 'AnyUrl', + 'AnyHttpUrl', + 'FileUrl', + 'HttpUrl', + 'stricturl', + 'EmailStr', + 'NameEmail', + 'IPvAnyAddress', + 'IPvAnyInterface', + 'IPvAnyNetwork', + 'PostgresDsn', + 'CockroachDsn', + 'AmqpDsn', + 'RedisDsn', + 'MongoDsn', + 'KafkaDsn', + 'validate_email', + # parse + 'Protocol', + # tools + 'parse_file_as', + 'parse_obj_as', + 'parse_raw_as', + 'schema_of', + 'schema_json_of', + # types + 'NoneStr', + 'NoneBytes', + 'StrBytes', + 'NoneStrBytes', + 'StrictStr', + 'ConstrainedBytes', + 'conbytes', + 'ConstrainedList', + 'conlist', + 'ConstrainedSet', + 'conset', + 'ConstrainedFrozenSet', + 'confrozenset', + 'ConstrainedStr', + 'constr', + 'PyObject', + 'ConstrainedInt', + 'conint', + 'PositiveInt', + 'NegativeInt', + 'NonNegativeInt', + 'NonPositiveInt', + 'ConstrainedFloat', + 'confloat', + 'PositiveFloat', + 'NegativeFloat', + 'NonNegativeFloat', + 'NonPositiveFloat', + 'FiniteFloat', + 'ConstrainedDecimal', + 'condecimal', + 'ConstrainedDate', + 'condate', + 'UUID1', + 'UUID3', + 'UUID4', + 'UUID5', + 'FilePath', + 'DirectoryPath', + 'Json', + 'JsonWrapper', + 'SecretField', + 'SecretStr', + 'SecretBytes', + 'StrictBool', + 'StrictBytes', + 'StrictInt', + 'StrictFloat', + 'PaymentCardNumber', + 'PrivateAttr', + 'ByteSize', + 'PastDate', + 'FutureDate', + # version + 'compiled', + 'VERSION', +] diff --git a/libs/win/pydantic/_hypothesis_plugin.cp37-win_amd64.pyd b/libs/win/pydantic/_hypothesis_plugin.cp37-win_amd64.pyd new file mode 100644 index 00000000..480ba0ae Binary files /dev/null and b/libs/win/pydantic/_hypothesis_plugin.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/_hypothesis_plugin.py b/libs/win/pydantic/_hypothesis_plugin.py new file mode 100644 index 00000000..a56d2b98 --- /dev/null +++ b/libs/win/pydantic/_hypothesis_plugin.py @@ -0,0 +1,386 @@ +""" +Register Hypothesis strategies for Pydantic custom types. + +This enables fully-automatic generation of test data for most Pydantic classes. + +Note that this module has *no* runtime impact on Pydantic itself; instead it +is registered as a setuptools entry point and Hypothesis will import it if +Pydantic is installed. 
See also: + +https://hypothesis.readthedocs.io/en/latest/strategies.html#registering-strategies-via-setuptools-entry-points +https://hypothesis.readthedocs.io/en/latest/data.html#hypothesis.strategies.register_type_strategy +https://hypothesis.readthedocs.io/en/latest/strategies.html#interaction-with-pytest-cov +https://pydantic-docs.helpmanual.io/usage/types/#pydantic-types + +Note that because our motivation is to *improve user experience*, the strategies +are always sound (never generate invalid data) but sacrifice completeness for +maintainability (ie may be unable to generate some tricky but valid data). + +Finally, this module makes liberal use of `# type: ignore[]` pragmas. +This is because Hypothesis annotates `register_type_strategy()` with +`(T, SearchStrategy[T])`, but in most cases we register e.g. `ConstrainedInt` +to generate instances of the builtin `int` type which match the constraints. +""" + +import contextlib +import datetime +import ipaddress +import json +import math +from fractions import Fraction +from typing import Callable, Dict, Type, Union, cast, overload + +import hypothesis.strategies as st + +import pydantic +import pydantic.color +import pydantic.types +from pydantic.utils import lenient_issubclass + +# FilePath and DirectoryPath are explicitly unsupported, as we'd have to create +# them on-disk, and that's unsafe in general without being told *where* to do so. +# +# URLs are unsupported because it's easy for users to define their own strategy for +# "normal" URLs, and hard for us to define a general strategy which includes "weird" +# URLs but doesn't also have unpredictable performance problems. +# +# conlist() and conset() are unsupported for now, because the workarounds for +# Cython and Hypothesis to handle parametrized generic types are incompatible. +# Once Cython can support 'normal' generics we'll revisit this. + +# Emails +try: + import email_validator +except ImportError: # pragma: no cover + pass +else: + + def is_valid_email(s: str) -> bool: + # Hypothesis' st.emails() occasionally generates emails like 0@A0--0.ac + # that are invalid according to email-validator, so we filter those out. + try: + email_validator.validate_email(s, check_deliverability=False) + return True + except email_validator.EmailNotValidError: # pragma: no cover + return False + + # Note that these strategies deliberately stay away from any tricky Unicode + # or other encoding issues; we're just trying to generate *something* valid. + st.register_type_strategy(pydantic.EmailStr, st.emails().filter(is_valid_email)) # type: ignore[arg-type] + st.register_type_strategy( + pydantic.NameEmail, + st.builds( + '{} <{}>'.format, # type: ignore[arg-type] + st.from_regex('[A-Za-z0-9_]+( [A-Za-z0-9_]+){0,5}', fullmatch=True), + st.emails().filter(is_valid_email), + ), + ) + +# PyObject - dotted names, in this case taken from the math module. 
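+# (For instance, the strategy below may draw 'math.pi' or 'math.sqrt';
+# pydantic then imports the dotted name during validation.)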
+st.register_type_strategy( + pydantic.PyObject, # type: ignore[arg-type] + st.sampled_from( + [cast(pydantic.PyObject, f'math.{name}') for name in sorted(vars(math)) if not name.startswith('_')] + ), +) + +# CSS3 Colors; as name, hex, rgb(a) tuples or strings, or hsl strings +_color_regexes = ( + '|'.join( + ( + pydantic.color.r_hex_short, + pydantic.color.r_hex_long, + pydantic.color.r_rgb, + pydantic.color.r_rgba, + pydantic.color.r_hsl, + pydantic.color.r_hsla, + ) + ) + # Use more precise regex patterns to avoid value-out-of-range errors + .replace(pydantic.color._r_sl, r'(?:(\d\d?(?:\.\d+)?|100(?:\.0+)?)%)') + .replace(pydantic.color._r_alpha, r'(?:(0(?:\.\d+)?|1(?:\.0+)?|\.\d+|\d{1,2}%))') + .replace(pydantic.color._r_255, r'(?:((?:\d|\d\d|[01]\d\d|2[0-4]\d|25[0-4])(?:\.\d+)?|255(?:\.0+)?))') +) +st.register_type_strategy( + pydantic.color.Color, + st.one_of( + st.sampled_from(sorted(pydantic.color.COLORS_BY_NAME)), + st.tuples( + st.integers(0, 255), + st.integers(0, 255), + st.integers(0, 255), + st.none() | st.floats(0, 1) | st.floats(0, 100).map('{}%'.format), + ), + st.from_regex(_color_regexes, fullmatch=True), + ), +) + + +# Card numbers, valid according to the Luhn algorithm + + +def add_luhn_digit(card_number: str) -> str: + # See https://en.wikipedia.org/wiki/Luhn_algorithm + for digit in '0123456789': + with contextlib.suppress(Exception): + pydantic.PaymentCardNumber.validate_luhn_check_digit(card_number + digit) + return card_number + digit + raise AssertionError('Unreachable') # pragma: no cover + + +card_patterns = ( + # Note that these patterns omit the Luhn check digit; that's added by the function above + '4[0-9]{14}', # Visa + '5[12345][0-9]{13}', # Mastercard + '3[47][0-9]{12}', # American Express + '[0-26-9][0-9]{10,17}', # other (incomplete to avoid overlap) +) +st.register_type_strategy( + pydantic.PaymentCardNumber, + st.from_regex('|'.join(card_patterns), fullmatch=True).map(add_luhn_digit), # type: ignore[arg-type] +) + +# UUIDs +st.register_type_strategy(pydantic.UUID1, st.uuids(version=1)) +st.register_type_strategy(pydantic.UUID3, st.uuids(version=3)) +st.register_type_strategy(pydantic.UUID4, st.uuids(version=4)) +st.register_type_strategy(pydantic.UUID5, st.uuids(version=5)) + +# Secrets +st.register_type_strategy(pydantic.SecretBytes, st.binary().map(pydantic.SecretBytes)) +st.register_type_strategy(pydantic.SecretStr, st.text().map(pydantic.SecretStr)) + +# IP addresses, networks, and interfaces +st.register_type_strategy(pydantic.IPvAnyAddress, st.ip_addresses()) # type: ignore[arg-type] +st.register_type_strategy( + pydantic.IPvAnyInterface, + st.from_type(ipaddress.IPv4Interface) | st.from_type(ipaddress.IPv6Interface), # type: ignore[arg-type] +) +st.register_type_strategy( + pydantic.IPvAnyNetwork, + st.from_type(ipaddress.IPv4Network) | st.from_type(ipaddress.IPv6Network), # type: ignore[arg-type] +) + +# We hook into the con***() functions and the ConstrainedNumberMeta metaclass, +# so here we only have to register subclasses for other constrained types which +# don't go via those mechanisms. Then there are the registration hooks below. +st.register_type_strategy(pydantic.StrictBool, st.booleans()) +st.register_type_strategy(pydantic.StrictStr, st.text()) + + +# Constrained-type resolver functions +# +# For these ones, we actually want to inspect the type in order to work out a +# satisfying strategy. 
First up, the machinery for tracking resolver functions: + +RESOLVERS: Dict[type, Callable[[type], st.SearchStrategy]] = {} # type: ignore[type-arg] + + +@overload +def _registered(typ: Type[pydantic.types.T]) -> Type[pydantic.types.T]: + pass + + +@overload +def _registered(typ: pydantic.types.ConstrainedNumberMeta) -> pydantic.types.ConstrainedNumberMeta: + pass + + +def _registered( + typ: Union[Type[pydantic.types.T], pydantic.types.ConstrainedNumberMeta] +) -> Union[Type[pydantic.types.T], pydantic.types.ConstrainedNumberMeta]: + # This function replaces the version in `pydantic.types`, in order to + # effect the registration of new constrained types so that Hypothesis + # can generate valid examples. + pydantic.types._DEFINED_TYPES.add(typ) + for supertype, resolver in RESOLVERS.items(): + if issubclass(typ, supertype): + st.register_type_strategy(typ, resolver(typ)) # type: ignore + return typ + raise NotImplementedError(f'Unknown type {typ!r} has no resolver to register') # pragma: no cover + + +def resolves( + typ: Union[type, pydantic.types.ConstrainedNumberMeta] +) -> Callable[[Callable[..., st.SearchStrategy]], Callable[..., st.SearchStrategy]]: # type: ignore[type-arg] + def inner(f): # type: ignore + assert f not in RESOLVERS + RESOLVERS[typ] = f + return f + + return inner + + +# Type-to-strategy resolver functions + + +@resolves(pydantic.JsonWrapper) +def resolve_json(cls): # type: ignore[no-untyped-def] + try: + inner = st.none() if cls.inner_type is None else st.from_type(cls.inner_type) + except Exception: # pragma: no cover + finite = st.floats(allow_infinity=False, allow_nan=False) + inner = st.recursive( + base=st.one_of(st.none(), st.booleans(), st.integers(), finite, st.text()), + extend=lambda x: st.lists(x) | st.dictionaries(st.text(), x), # type: ignore + ) + inner_type = getattr(cls, 'inner_type', None) + return st.builds( + cls.inner_type.json if lenient_issubclass(inner_type, pydantic.BaseModel) else json.dumps, + inner, + ensure_ascii=st.booleans(), + indent=st.none() | st.integers(0, 16), + sort_keys=st.booleans(), + ) + + +@resolves(pydantic.ConstrainedBytes) +def resolve_conbytes(cls): # type: ignore[no-untyped-def] # pragma: no cover + min_size = cls.min_length or 0 + max_size = cls.max_length + if not cls.strip_whitespace: + return st.binary(min_size=min_size, max_size=max_size) + # Fun with regex to ensure we neither start nor end with whitespace + repeats = '{{{},{}}}'.format( + min_size - 2 if min_size > 2 else 0, + max_size - 2 if (max_size or 0) > 2 else '', + ) + if min_size >= 2: + pattern = rf'\W.{repeats}\W' + elif min_size == 1: + pattern = rf'\W(.{repeats}\W)?' + else: + assert min_size == 0 + pattern = rf'(\W(.{repeats}\W)?)?' 
+ return st.from_regex(pattern.encode(), fullmatch=True) + + +@resolves(pydantic.ConstrainedDecimal) +def resolve_condecimal(cls): # type: ignore[no-untyped-def] + min_value = cls.ge + max_value = cls.le + if cls.gt is not None: + assert min_value is None, 'Set `gt` or `ge`, but not both' + min_value = cls.gt + if cls.lt is not None: + assert max_value is None, 'Set `lt` or `le`, but not both' + max_value = cls.lt + s = st.decimals(min_value, max_value, allow_nan=False, places=cls.decimal_places) + if cls.lt is not None: + s = s.filter(lambda d: d < cls.lt) + if cls.gt is not None: + s = s.filter(lambda d: cls.gt < d) + return s + + +@resolves(pydantic.ConstrainedFloat) +def resolve_confloat(cls): # type: ignore[no-untyped-def] + min_value = cls.ge + max_value = cls.le + exclude_min = False + exclude_max = False + + if cls.gt is not None: + assert min_value is None, 'Set `gt` or `ge`, but not both' + min_value = cls.gt + exclude_min = True + if cls.lt is not None: + assert max_value is None, 'Set `lt` or `le`, but not both' + max_value = cls.lt + exclude_max = True + + if cls.multiple_of is None: + return st.floats(min_value, max_value, exclude_min=exclude_min, exclude_max=exclude_max, allow_nan=False) + + if min_value is not None: + min_value = math.ceil(min_value / cls.multiple_of) + if exclude_min: + min_value = min_value + 1 + if max_value is not None: + assert max_value >= cls.multiple_of, 'Cannot build model with max value smaller than multiple of' + max_value = math.floor(max_value / cls.multiple_of) + if exclude_max: + max_value = max_value - 1 + + return st.integers(min_value, max_value).map(lambda x: x * cls.multiple_of) + + +@resolves(pydantic.ConstrainedInt) +def resolve_conint(cls): # type: ignore[no-untyped-def] + min_value = cls.ge + max_value = cls.le + if cls.gt is not None: + assert min_value is None, 'Set `gt` or `ge`, but not both' + min_value = cls.gt + 1 + if cls.lt is not None: + assert max_value is None, 'Set `lt` or `le`, but not both' + max_value = cls.lt - 1 + + if cls.multiple_of is None or cls.multiple_of == 1: + return st.integers(min_value, max_value) + + # These adjustments and the .map handle integer-valued multiples, while the + # .filter handles trickier cases as for confloat. 
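+    # Worked example (illustrative): for conint(ge=5, le=20, multiple_of=6),
+    # min_value becomes ceil(5/6) = 1 and max_value floor(20/6) = 3, so the
+    # strategy draws from {1, 2, 3} and maps them to {6, 12, 18}.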
+ if min_value is not None: + min_value = math.ceil(Fraction(min_value) / Fraction(cls.multiple_of)) + if max_value is not None: + max_value = math.floor(Fraction(max_value) / Fraction(cls.multiple_of)) + return st.integers(min_value, max_value).map(lambda x: x * cls.multiple_of) + + +@resolves(pydantic.ConstrainedDate) +def resolve_condate(cls): # type: ignore[no-untyped-def] + if cls.ge is not None: + assert cls.gt is None, 'Set `gt` or `ge`, but not both' + min_value = cls.ge + elif cls.gt is not None: + min_value = cls.gt + datetime.timedelta(days=1) + else: + min_value = datetime.date.min + if cls.le is not None: + assert cls.lt is None, 'Set `lt` or `le`, but not both' + max_value = cls.le + elif cls.lt is not None: + max_value = cls.lt - datetime.timedelta(days=1) + else: + max_value = datetime.date.max + return st.dates(min_value, max_value) + + +@resolves(pydantic.ConstrainedStr) +def resolve_constr(cls): # type: ignore[no-untyped-def] # pragma: no cover + min_size = cls.min_length or 0 + max_size = cls.max_length + + if cls.regex is None and not cls.strip_whitespace: + return st.text(min_size=min_size, max_size=max_size) + + if cls.regex is not None: + strategy = st.from_regex(cls.regex) + if cls.strip_whitespace: + strategy = strategy.filter(lambda s: s == s.strip()) + elif cls.strip_whitespace: + repeats = '{{{},{}}}'.format( + min_size - 2 if min_size > 2 else 0, + max_size - 2 if (max_size or 0) > 2 else '', + ) + if min_size >= 2: + strategy = st.from_regex(rf'\W.{repeats}\W') + elif min_size == 1: + strategy = st.from_regex(rf'\W(.{repeats}\W)?') + else: + assert min_size == 0 + strategy = st.from_regex(rf'(\W(.{repeats}\W)?)?') + + if min_size == 0 and max_size is None: + return strategy + elif max_size is None: + return strategy.filter(lambda s: min_size <= len(s)) + return strategy.filter(lambda s: min_size <= len(s) <= max_size) + + +# Finally, register all previously-defined types, and patch in our new function +for typ in list(pydantic.types._DEFINED_TYPES): + _registered(typ) +pydantic.types._registered = _registered +st.register_type_strategy(pydantic.Json, resolve_json) diff --git a/libs/win/pydantic/annotated_types.cp37-win_amd64.pyd b/libs/win/pydantic/annotated_types.cp37-win_amd64.pyd new file mode 100644 index 00000000..6c9f0f37 Binary files /dev/null and b/libs/win/pydantic/annotated_types.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/annotated_types.py b/libs/win/pydantic/annotated_types.py new file mode 100644 index 00000000..d333457f --- /dev/null +++ b/libs/win/pydantic/annotated_types.py @@ -0,0 +1,72 @@ +import sys +from typing import TYPE_CHECKING, Any, Dict, FrozenSet, NamedTuple, Type + +from .fields import Required +from .main import BaseModel, create_model +from .typing import is_typeddict, is_typeddict_special + +if TYPE_CHECKING: + from typing_extensions import TypedDict + +if sys.version_info < (3, 11): + + def is_legacy_typeddict(typeddict_cls: Type['TypedDict']) -> bool: # type: ignore[valid-type] + return is_typeddict(typeddict_cls) and type(typeddict_cls).__module__ == 'typing' + +else: + + def is_legacy_typeddict(_: Any) -> Any: + return False + + +def create_model_from_typeddict( + # Mypy bug: `Type[TypedDict]` is resolved as `Any` https://github.com/python/mypy/issues/11030 + typeddict_cls: Type['TypedDict'], # type: ignore[valid-type] + **kwargs: Any, +) -> Type['BaseModel']: + """ + Create a `BaseModel` based on the fields of a `TypedDict`. 
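+    For example (an illustrative sketch):
+
+        class User(TypedDict):
+            id: int
+            name: str
+
+        UserModel = create_model_from_typeddict(User)
+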
+ Since `typing.TypedDict` in Python 3.8 does not store runtime information about optional keys, + we raise an error if this happens (see https://bugs.python.org/issue38834). + """ + field_definitions: Dict[str, Any] + + # Best case scenario: with python 3.9+ or when `TypedDict` is imported from `typing_extensions` + if not hasattr(typeddict_cls, '__required_keys__'): + raise TypeError( + 'You should use `typing_extensions.TypedDict` instead of `typing.TypedDict` with Python < 3.9.2. ' + 'Without it, there is no way to differentiate required and optional fields when subclassed.' + ) + + if is_legacy_typeddict(typeddict_cls) and any( + is_typeddict_special(t) for t in typeddict_cls.__annotations__.values() + ): + raise TypeError( + 'You should use `typing_extensions.TypedDict` instead of `typing.TypedDict` with Python < 3.11. ' + 'Without it, there is no way to reflect Required/NotRequired keys.' + ) + + required_keys: FrozenSet[str] = typeddict_cls.__required_keys__ # type: ignore[attr-defined] + field_definitions = { + field_name: (field_type, Required if field_name in required_keys else None) + for field_name, field_type in typeddict_cls.__annotations__.items() + } + + return create_model(typeddict_cls.__name__, **kwargs, **field_definitions) + + +def create_model_from_namedtuple(namedtuple_cls: Type['NamedTuple'], **kwargs: Any) -> Type['BaseModel']: + """ + Create a `BaseModel` based on the fields of a named tuple. + A named tuple can be created with `typing.NamedTuple` and declared annotations + but also with `collections.namedtuple`, in this case we consider all fields + to have type `Any`. + """ + # With python 3.10+, `__annotations__` always exists but can be empty hence the `getattr... or...` logic + namedtuple_annotations: Dict[str, Type[Any]] = getattr(namedtuple_cls, '__annotations__', None) or { + k: Any for k in namedtuple_cls._fields + } + field_definitions: Dict[str, Any] = { + field_name: (field_type, Required) for field_name, field_type in namedtuple_annotations.items() + } + return create_model(namedtuple_cls.__name__, **kwargs, **field_definitions) diff --git a/libs/win/pydantic/class_validators.cp37-win_amd64.pyd b/libs/win/pydantic/class_validators.cp37-win_amd64.pyd new file mode 100644 index 00000000..ff6f576c Binary files /dev/null and b/libs/win/pydantic/class_validators.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/class_validators.py b/libs/win/pydantic/class_validators.py new file mode 100644 index 00000000..87190610 --- /dev/null +++ b/libs/win/pydantic/class_validators.py @@ -0,0 +1,342 @@ +import warnings +from collections import ChainMap +from functools import wraps +from itertools import chain +from types import FunctionType +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, overload + +from .errors import ConfigError +from .typing import AnyCallable +from .utils import ROOT_KEY, in_ipython + +if TYPE_CHECKING: + from .typing import AnyClassMethod + + +class Validator: + __slots__ = 'func', 'pre', 'each_item', 'always', 'check_fields', 'skip_on_failure' + + def __init__( + self, + func: AnyCallable, + pre: bool = False, + each_item: bool = False, + always: bool = False, + check_fields: bool = False, + skip_on_failure: bool = False, + ): + self.func = func + self.pre = pre + self.each_item = each_item + self.always = always + self.check_fields = check_fields + self.skip_on_failure = skip_on_failure + + +if TYPE_CHECKING: + from inspect import Signature + + from .config import BaseConfig + 
from .fields import ModelField
+    from .types import ModelOrDc
+
+    ValidatorCallable = Callable[[Optional[ModelOrDc], Any, Dict[str, Any], ModelField, Type[BaseConfig]], Any]
+    ValidatorsList = List[ValidatorCallable]
+    ValidatorListDict = Dict[str, List[Validator]]
+
+_FUNCS: Set[str] = set()
+VALIDATOR_CONFIG_KEY = '__validator_config__'
+ROOT_VALIDATOR_CONFIG_KEY = '__root_validator_config__'
+
+
+def validator(
+    *fields: str,
+    pre: bool = False,
+    each_item: bool = False,
+    always: bool = False,
+    check_fields: bool = True,
+    whole: bool = None,
+    allow_reuse: bool = False,
+) -> Callable[[AnyCallable], 'AnyClassMethod']:
+    """
+    Decorate methods on the class indicating that they should be used to validate fields
+    :param fields: which field(s) the method should be called on
+    :param pre: whether or not this validator should be called before the standard validators (else after)
+    :param each_item: for complex objects (sets, lists etc.) whether to validate individual elements rather than the
+        whole object
+    :param always: whether this method and other validators should be called even if the value is missing
+    :param check_fields: whether to check that the fields actually exist on the model
+    :param allow_reuse: whether to track and raise an error if another validator refers to the decorated function
+    """
+    if not fields:
+        raise ConfigError('validator with no fields specified')
+    elif isinstance(fields[0], FunctionType):
+        raise ConfigError(
+            "validators should be used with fields and keyword arguments, not bare. "  # noqa: Q000
+            "E.g. usage should be `@validator('<field_name>', ...)`"
+        )
+    elif not all(isinstance(field, str) for field in fields):
+        raise ConfigError(
+            "validator fields should be passed as separate string args. "  # noqa: Q000
+            "E.g. usage should be `@validator('<field_name_1>', '<field_name_2>', ...)`"
+        )
+
+    if whole is not None:
+        warnings.warn(
+            'The "whole" keyword argument is deprecated, use "each_item" (inverse meaning, default False) instead',
+            DeprecationWarning,
+        )
+        assert each_item is False, '"each_item" and "whole" conflict, remove "whole"'
+        each_item = not whole
+
+    def dec(f: AnyCallable) -> 'AnyClassMethod':
+        f_cls = _prepare_validator(f, allow_reuse)
+        setattr(
+            f_cls,
+            VALIDATOR_CONFIG_KEY,
+            (
+                fields,
+                Validator(func=f_cls.__func__, pre=pre, each_item=each_item, always=always, check_fields=check_fields),
+            ),
+        )
+        return f_cls
+
+    return dec
+
+
+@overload
+def root_validator(_func: AnyCallable) -> 'AnyClassMethod':
+    ...
+
+
+@overload
+def root_validator(
+    *, pre: bool = False, allow_reuse: bool = False, skip_on_failure: bool = False
+) -> Callable[[AnyCallable], 'AnyClassMethod']:
+    ...
+
+
+def root_validator(
+    _func: Optional[AnyCallable] = None, *, pre: bool = False, allow_reuse: bool = False, skip_on_failure: bool = False
+) -> Union['AnyClassMethod', Callable[[AnyCallable], 'AnyClassMethod']]:
+    """
+    Decorate methods on a model indicating that they should be used to validate (and perhaps modify) data either
+    before or after standard model parsing/validation is performed.
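+
+    A minimal illustrative sketch:
+
+        @root_validator
+        def check_consistent(cls, values):
+            # `values` holds the already-validated fields; return it
+            # (possibly modified) or raise to signal a validation error.
+            return values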
+ """ + if _func: + f_cls = _prepare_validator(_func, allow_reuse) + setattr( + f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure) + ) + return f_cls + + def dec(f: AnyCallable) -> 'AnyClassMethod': + f_cls = _prepare_validator(f, allow_reuse) + setattr( + f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure) + ) + return f_cls + + return dec + + +def _prepare_validator(function: AnyCallable, allow_reuse: bool) -> 'AnyClassMethod': + """ + Avoid validators with duplicated names since without this, validators can be overwritten silently + which generally isn't the intended behaviour, don't run in ipython (see #312) or if allow_reuse is False. + """ + f_cls = function if isinstance(function, classmethod) else classmethod(function) + if not in_ipython() and not allow_reuse: + ref = f_cls.__func__.__module__ + '.' + f_cls.__func__.__qualname__ + if ref in _FUNCS: + raise ConfigError(f'duplicate validator function "{ref}"; if this is intended, set `allow_reuse=True`') + _FUNCS.add(ref) + return f_cls + + +class ValidatorGroup: + def __init__(self, validators: 'ValidatorListDict') -> None: + self.validators = validators + self.used_validators = {'*'} + + def get_validators(self, name: str) -> Optional[Dict[str, Validator]]: + self.used_validators.add(name) + validators = self.validators.get(name, []) + if name != ROOT_KEY: + validators += self.validators.get('*', []) + if validators: + return {v.func.__name__: v for v in validators} + else: + return None + + def check_for_unused(self) -> None: + unused_validators = set( + chain.from_iterable( + (v.func.__name__ for v in self.validators[f] if v.check_fields) + for f in (self.validators.keys() - self.used_validators) + ) + ) + if unused_validators: + fn = ', '.join(unused_validators) + raise ConfigError( + f"Validators defined with incorrect fields: {fn} " # noqa: Q000 + f"(use check_fields=False if you're inheriting from the model and intended this)" + ) + + +def extract_validators(namespace: Dict[str, Any]) -> Dict[str, List[Validator]]: + validators: Dict[str, List[Validator]] = {} + for var_name, value in namespace.items(): + validator_config = getattr(value, VALIDATOR_CONFIG_KEY, None) + if validator_config: + fields, v = validator_config + for field in fields: + if field in validators: + validators[field].append(v) + else: + validators[field] = [v] + return validators + + +def extract_root_validators(namespace: Dict[str, Any]) -> Tuple[List[AnyCallable], List[Tuple[bool, AnyCallable]]]: + from inspect import signature + + pre_validators: List[AnyCallable] = [] + post_validators: List[Tuple[bool, AnyCallable]] = [] + for name, value in namespace.items(): + validator_config: Optional[Validator] = getattr(value, ROOT_VALIDATOR_CONFIG_KEY, None) + if validator_config: + sig = signature(validator_config.func) + args = list(sig.parameters.keys()) + if args[0] == 'self': + raise ConfigError( + f'Invalid signature for root validator {name}: {sig}, "self" not permitted as first argument, ' + f'should be: (cls, values).' 
+ ) + if len(args) != 2: + raise ConfigError(f'Invalid signature for root validator {name}: {sig}, should be: (cls, values).') + # check function signature + if validator_config.pre: + pre_validators.append(validator_config.func) + else: + post_validators.append((validator_config.skip_on_failure, validator_config.func)) + return pre_validators, post_validators + + +def inherit_validators(base_validators: 'ValidatorListDict', validators: 'ValidatorListDict') -> 'ValidatorListDict': + for field, field_validators in base_validators.items(): + if field not in validators: + validators[field] = [] + validators[field] += field_validators + return validators + + +def make_generic_validator(validator: AnyCallable) -> 'ValidatorCallable': + """ + Make a generic function which calls a validator with the right arguments. + + Unfortunately other approaches (eg. return a partial of a function that builds the arguments) is slow, + hence this laborious way of doing things. + + It's done like this so validators don't all need **kwargs in their signature, eg. any combination of + the arguments "values", "fields" and/or "config" are permitted. + """ + from inspect import signature + + sig = signature(validator) + args = list(sig.parameters.keys()) + first_arg = args.pop(0) + if first_arg == 'self': + raise ConfigError( + f'Invalid signature for validator {validator}: {sig}, "self" not permitted as first argument, ' + f'should be: (cls, value, values, config, field), "values", "config" and "field" are all optional.' + ) + elif first_arg == 'cls': + # assume the second argument is value + return wraps(validator)(_generic_validator_cls(validator, sig, set(args[1:]))) + else: + # assume the first argument was value which has already been removed + return wraps(validator)(_generic_validator_basic(validator, sig, set(args))) + + +def prep_validators(v_funcs: Iterable[AnyCallable]) -> 'ValidatorsList': + return [make_generic_validator(f) for f in v_funcs if f] + + +all_kwargs = {'values', 'field', 'config'} + + +def _generic_validator_cls(validator: AnyCallable, sig: 'Signature', args: Set[str]) -> 'ValidatorCallable': + # assume the first argument is value + has_kwargs = False + if 'kwargs' in args: + has_kwargs = True + args -= {'kwargs'} + + if not args.issubset(all_kwargs): + raise ConfigError( + f'Invalid signature for validator {validator}: {sig}, should be: ' + f'(cls, value, values, config, field), "values", "config" and "field" are all optional.' 
+        )
+
+    if has_kwargs:
+        return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field, config=config)
+    elif args == set():
+        return lambda cls, v, values, field, config: validator(cls, v)
+    elif args == {'values'}:
+        return lambda cls, v, values, field, config: validator(cls, v, values=values)
+    elif args == {'field'}:
+        return lambda cls, v, values, field, config: validator(cls, v, field=field)
+    elif args == {'config'}:
+        return lambda cls, v, values, field, config: validator(cls, v, config=config)
+    elif args == {'values', 'field'}:
+        return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field)
+    elif args == {'values', 'config'}:
+        return lambda cls, v, values, field, config: validator(cls, v, values=values, config=config)
+    elif args == {'field', 'config'}:
+        return lambda cls, v, values, field, config: validator(cls, v, field=field, config=config)
+    else:
+        # args == {'values', 'field', 'config'}
+        return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field, config=config)
+
+
+def _generic_validator_basic(validator: AnyCallable, sig: 'Signature', args: Set[str]) -> 'ValidatorCallable':
+    has_kwargs = False
+    if 'kwargs' in args:
+        has_kwargs = True
+        args -= {'kwargs'}
+
+    if not args.issubset(all_kwargs):
+        raise ConfigError(
+            f'Invalid signature for validator {validator}: {sig}, should be: '
+            f'(value, values, config, field), "values", "config" and "field" are all optional.'
+        )
+
+    if has_kwargs:
+        return lambda cls, v, values, field, config: validator(v, values=values, field=field, config=config)
+    elif args == set():
+        return lambda cls, v, values, field, config: validator(v)
+    elif args == {'values'}:
+        return lambda cls, v, values, field, config: validator(v, values=values)
+    elif args == {'field'}:
+        return lambda cls, v, values, field, config: validator(v, field=field)
+    elif args == {'config'}:
+        return lambda cls, v, values, field, config: validator(v, config=config)
+    elif args == {'values', 'field'}:
+        return lambda cls, v, values, field, config: validator(v, values=values, field=field)
+    elif args == {'values', 'config'}:
+        return lambda cls, v, values, field, config: validator(v, values=values, config=config)
+    elif args == {'field', 'config'}:
+        return lambda cls, v, values, field, config: validator(v, field=field, config=config)
+    else:
+        # args == {'values', 'field', 'config'}
+        return lambda cls, v, values, field, config: validator(v, values=values, field=field, config=config)
+
+
+def gather_all_validators(type_: 'ModelOrDc') -> Dict[str, 'AnyClassMethod']:
+    all_attributes = ChainMap(*[cls.__dict__ for cls in type_.__mro__])  # type: ignore[arg-type,var-annotated]
+    return {
+        k: v
+        for k, v in all_attributes.items()
+        if hasattr(v, VALIDATOR_CONFIG_KEY) or hasattr(v, ROOT_VALIDATOR_CONFIG_KEY)
+    }
diff --git a/libs/win/pydantic/color.cp37-win_amd64.pyd b/libs/win/pydantic/color.cp37-win_amd64.pyd
new file mode 100644
index 00000000..eaea390c
Binary files /dev/null and b/libs/win/pydantic/color.cp37-win_amd64.pyd differ
diff --git a/libs/win/pydantic/color.py b/libs/win/pydantic/color.py
new file mode 100644
index 00000000..6fdc9fb1
--- /dev/null
+++ b/libs/win/pydantic/color.py
@@ -0,0 +1,494 @@
+"""
+Color definitions are used as per CSS3 specification:
+http://www.w3.org/TR/css3-color/#svg-color
+
+A few colors have multiple names referring to the same colors, eg. `grey` and `gray` or `aqua` and `cyan`.
+
+In these cases the LAST color when sorted alphabetically takes preference,
+eg. Color((0, 255, 255)).as_named() == 'cyan' because "cyan" comes after "aqua".
+"""
+import math
+import re
+from colorsys import hls_to_rgb, rgb_to_hls
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast
+
+from .errors import ColorError
+from .utils import Representation, almost_equal_floats
+
+if TYPE_CHECKING:
+    from .typing import CallableGenerator, ReprArgs
+
+ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, float]]
+ColorType = Union[ColorTuple, str]
+HslColorTuple = Union[Tuple[float, float, float], Tuple[float, float, float, float]]
+
+
+class RGBA:
+    """
+    Internal use only as a representation of a color.
+    """
+
+    __slots__ = 'r', 'g', 'b', 'alpha', '_tuple'
+
+    def __init__(self, r: float, g: float, b: float, alpha: Optional[float]):
+        self.r = r
+        self.g = g
+        self.b = b
+        self.alpha = alpha
+
+        self._tuple: Tuple[float, float, float, Optional[float]] = (r, g, b, alpha)
+
+    def __getitem__(self, item: Any) -> Any:
+        return self._tuple[item]
+
+
+# these are not compiled here to avoid import slowdown, they'll be compiled the first time they're used, then cached
+r_hex_short = r'\s*(?:#|0x)?([0-9a-f])([0-9a-f])([0-9a-f])([0-9a-f])?\s*'
+r_hex_long = r'\s*(?:#|0x)?([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})?\s*'
+_r_255 = r'(\d{1,3}(?:\.\d+)?)'
+_r_comma = r'\s*,\s*'
+r_rgb = fr'\s*rgb\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}\)\s*'
+_r_alpha = r'(\d(?:\.\d+)?|\.\d+|\d{1,2}%)'
+r_rgba = fr'\s*rgba\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_alpha}\s*\)\s*'
+_r_h = r'(-?\d+(?:\.\d+)?|-?\.\d+)(deg|rad|turn)?'
+_r_sl = r'(\d{1,3}(?:\.\d+)?)%'
+r_hsl = fr'\s*hsl\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}\s*\)\s*'
+r_hsla = fr'\s*hsl\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}{_r_comma}{_r_alpha}\s*\)\s*'
+
+# colors where the two hex characters are the same, if all colors match this the short version of hex colors can be used
+repeat_colors = {int(c * 2, 16) for c in '0123456789abcdef'}
+rads = 2 * math.pi
+
+
+class Color(Representation):
+    __slots__ = '_original', '_rgba'
+
+    def __init__(self, value: ColorType) -> None:
+        self._rgba: RGBA
+        self._original: ColorType
+        if isinstance(value, (tuple, list)):
+            self._rgba = parse_tuple(value)
+        elif isinstance(value, str):
+            self._rgba = parse_str(value)
+        elif isinstance(value, Color):
+            self._rgba = value._rgba
+            value = value._original
+        else:
+            raise ColorError(reason='value must be a tuple, list or string')
+
+        # if we've got here value must be a valid color
+        self._original = value
+
+    @classmethod
+    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
+        field_schema.update(type='string', format='color')
+
+    def original(self) -> ColorType:
+        """
+        Original value passed to Color
+        """
+        return self._original
+
+    def as_named(self, *, fallback: bool = False) -> str:
+        if self._rgba.alpha is None:
+            rgb = cast(Tuple[int, int, int], self.as_rgb_tuple())
+            try:
+                return COLORS_BY_VALUE[rgb]
+            except KeyError as e:
+                if fallback:
+                    return self.as_hex()
+                else:
+                    raise ValueError('no named color found, use fallback=True, as_hex() or as_rgb()') from e
+        else:
+            return self.as_hex()
+
+    def as_hex(self) -> str:
+        """
+        Hex string representing the color can be 3, 4, 6 or 8 characters depending on whether
+        a "short" representation of the color is possible and whether there's an alpha channel.
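+
+        For example, Color((255, 255, 255)).as_hex() == '#fff', while
+        Color((18, 52, 86)).as_hex() == '#123456'.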
+        """
+        values = [float_to_255(c) for c in self._rgba[:3]]
+        if self._rgba.alpha is not None:
+            values.append(float_to_255(self._rgba.alpha))
+
+        as_hex = ''.join(f'{v:02x}' for v in values)
+        if all(c in repeat_colors for c in values):
+            as_hex = ''.join(as_hex[c] for c in range(0, len(as_hex), 2))
+        return '#' + as_hex
+
+    def as_rgb(self) -> str:
+        """
+        Color as an rgb(<r>, <g>, <b>) or rgba(<r>, <g>, <b>, <a>) string.
+        """
+        if self._rgba.alpha is None:
+            return f'rgb({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)})'
+        else:
+            return (
+                f'rgba({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)}, '
+                f'{round(self._alpha_float(), 2)})'
+            )
+
+    def as_rgb_tuple(self, *, alpha: Optional[bool] = None) -> ColorTuple:
+        """
+        Color as an RGB or RGBA tuple; red, green and blue are in the range 0 to 255, alpha if included is
+        in the range 0 to 1.
+
+        :param alpha: whether to include the alpha channel, options are
+          None - (default) include alpha only if it's set (e.g. not None)
+          True - always include alpha,
+          False - always omit alpha,
+        """
+        r, g, b = (float_to_255(c) for c in self._rgba[:3])
+        if alpha is None:
+            if self._rgba.alpha is None:
+                return r, g, b
+            else:
+                return r, g, b, self._alpha_float()
+        elif alpha:
+            return r, g, b, self._alpha_float()
+        else:
+            # alpha is False
+            return r, g, b
+
+    def as_hsl(self) -> str:
+        """
+        Color as an hsl(<h>, <s>, <l>) or hsl(<h>, <s>, <l>, <a>) string.
+        """
+        if self._rgba.alpha is None:
+            h, s, li = self.as_hsl_tuple(alpha=False)  # type: ignore
+            return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%})'
+        else:
+            h, s, li, a = self.as_hsl_tuple(alpha=True)  # type: ignore
+            return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%}, {round(a, 2)})'
+
+    def as_hsl_tuple(self, *, alpha: Optional[bool] = None) -> HslColorTuple:
+        """
+        Color as an HSL or HSLA tuple, e.g. hue, saturation, lightness and optionally alpha; all elements are in
+        the range 0 to 1.
+
+        NOTE: this is HSL as used in HTML and most other places, not HLS as used in python's colorsys.
+
+        :param alpha: whether to include the alpha channel, options are
+          None - (default) include alpha only if it's set (e.g. not None)
+          True - always include alpha,
+          False - always omit alpha,
+        """
+        h, l, s = rgb_to_hls(self._rgba.r, self._rgba.g, self._rgba.b)
+        if alpha is None:
+            if self._rgba.alpha is None:
+                return h, s, l
+            else:
+                return h, s, l, self._alpha_float()
+        if alpha:
+            return h, s, l, self._alpha_float()
+        else:
+            # alpha is False
+            return h, s, l
+
+    def _alpha_float(self) -> float:
+        return 1 if self._rgba.alpha is None else self._rgba.alpha
+
+    @classmethod
+    def __get_validators__(cls) -> 'CallableGenerator':
+        yield cls
+
+    def __str__(self) -> str:
+        return self.as_named(fallback=True)
+
+    def __repr_args__(self) -> 'ReprArgs':
+        return [(None, self.as_named(fallback=True))] + [('rgb', self.as_rgb_tuple())]  # type: ignore
+
+    def __eq__(self, other: Any) -> bool:
+        return isinstance(other, Color) and self.as_rgb_tuple() == other.as_rgb_tuple()
+
+    def __hash__(self) -> int:
+        return hash(self.as_rgb_tuple())
+
+
+def parse_tuple(value: Tuple[Any, ...]) -> RGBA:
+    """
+    Parse a tuple or list as a color.
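+
+    For example, parse_tuple((0, 0, 255)) is full blue with no alpha, while
+    parse_tuple((0, 0, 255, 0.5)) carries 50% alpha.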
+    """
+    if len(value) == 3:
+        r, g, b = (parse_color_value(v) for v in value)
+        return RGBA(r, g, b, None)
+    elif len(value) == 4:
+        r, g, b = (parse_color_value(v) for v in value[:3])
+        return RGBA(r, g, b, parse_float_alpha(value[3]))
+    else:
+        raise ColorError(reason='tuples must have length 3 or 4')
+
+
+def parse_str(value: str) -> RGBA:
+    """
+    Parse a string to an RGBA tuple, trying the following formats (in this order):
+    * named color, see COLORS_BY_NAME below
+    * hex short eg. `fff` (prefix can be `#`, `0x` or nothing)
+    * hex long eg. `ffffff` (prefix can be `#`, `0x` or nothing)
+    * `rgb(<r>, <g>, <b>) `
+    * `rgba(<r>, <g>, <b>, <a>)`
+    """
+    value_lower = value.lower()
+    try:
+        r, g, b = COLORS_BY_NAME[value_lower]
+    except KeyError:
+        pass
+    else:
+        return ints_to_rgba(r, g, b, None)
+
+    m = re.fullmatch(r_hex_short, value_lower)
+    if m:
+        *rgb, a = m.groups()
+        r, g, b = (int(v * 2, 16) for v in rgb)
+        if a:
+            alpha: Optional[float] = int(a * 2, 16) / 255
+        else:
+            alpha = None
+        return ints_to_rgba(r, g, b, alpha)
+
+    m = re.fullmatch(r_hex_long, value_lower)
+    if m:
+        *rgb, a = m.groups()
+        r, g, b = (int(v, 16) for v in rgb)
+        if a:
+            alpha = int(a, 16) / 255
+        else:
+            alpha = None
+        return ints_to_rgba(r, g, b, alpha)
+
+    m = re.fullmatch(r_rgb, value_lower)
+    if m:
+        return ints_to_rgba(*m.groups(), None)  # type: ignore
+
+    m = re.fullmatch(r_rgba, value_lower)
+    if m:
+        return ints_to_rgba(*m.groups())  # type: ignore
+
+    m = re.fullmatch(r_hsl, value_lower)
+    if m:
+        h, h_units, s, l_ = m.groups()
+        return parse_hsl(h, h_units, s, l_)
+
+    m = re.fullmatch(r_hsla, value_lower)
+    if m:
+        h, h_units, s, l_, a = m.groups()
+        return parse_hsl(h, h_units, s, l_, parse_float_alpha(a))
+
+    raise ColorError(reason='string not recognised as a valid color')
+
+
+def ints_to_rgba(r: Union[int, str], g: Union[int, str], b: Union[int, str], alpha: Optional[float]) -> RGBA:
+    return RGBA(parse_color_value(r), parse_color_value(g), parse_color_value(b), parse_float_alpha(alpha))
+
+
+def parse_color_value(value: Union[int, str], max_val: int = 255) -> float:
+    """
+    Parse a value checking it's a valid int in the range 0 to max_val and divide by max_val to give a number
+    in the range 0 to 1
+    """
+    try:
+        color = float(value)
+    except ValueError:
+        raise ColorError(reason='color values must be a valid number')
+    if 0 <= color <= max_val:
+        return color / max_val
+    else:
+        raise ColorError(reason=f'color values must be in the range 0 to {max_val}')
+
+
+def parse_float_alpha(value: Union[None, str, float, int]) -> Optional[float]:
+    """
+    Parse a value checking it's a valid float in the range 0 to 1
+    """
+    if value is None:
+        return None
+    try:
+        if isinstance(value, str) and value.endswith('%'):
+            alpha = float(value[:-1]) / 100
+        else:
+            alpha = float(value)
+    except ValueError:
+        raise ColorError(reason='alpha values must be a valid float')
+
+    if almost_equal_floats(alpha, 1):
+        return None
+    elif 0 <= alpha <= 1:
+        return alpha
+    else:
+        raise ColorError(reason='alpha values must be in the range 0 to 1')
+
+
+def parse_hsl(h: str, h_units: str, sat: str, light: str, alpha: Optional[float] = None) -> RGBA:
+    """
+    Parse raw hue, saturation, lightness and alpha values and convert to RGBA.
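+
+    For example, the components of 'hsl(270, 60%, 70%)' arrive here as
+    h='270', h_units=None, sat='60', light='70' and are normalised to
+    h=0.75, s=0.6, l=0.7 before conversion via colorsys.hls_to_rgb.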
+ """ + s_value, l_value = parse_color_value(sat, 100), parse_color_value(light, 100) + + h_value = float(h) + if h_units in {None, 'deg'}: + h_value = h_value % 360 / 360 + elif h_units == 'rad': + h_value = h_value % rads / rads + else: + # turns + h_value = h_value % 1 + + r, g, b = hls_to_rgb(h_value, l_value, s_value) + return RGBA(r, g, b, alpha) + + +def float_to_255(c: float) -> int: + return int(round(c * 255)) + + +COLORS_BY_NAME = { + 'aliceblue': (240, 248, 255), + 'antiquewhite': (250, 235, 215), + 'aqua': (0, 255, 255), + 'aquamarine': (127, 255, 212), + 'azure': (240, 255, 255), + 'beige': (245, 245, 220), + 'bisque': (255, 228, 196), + 'black': (0, 0, 0), + 'blanchedalmond': (255, 235, 205), + 'blue': (0, 0, 255), + 'blueviolet': (138, 43, 226), + 'brown': (165, 42, 42), + 'burlywood': (222, 184, 135), + 'cadetblue': (95, 158, 160), + 'chartreuse': (127, 255, 0), + 'chocolate': (210, 105, 30), + 'coral': (255, 127, 80), + 'cornflowerblue': (100, 149, 237), + 'cornsilk': (255, 248, 220), + 'crimson': (220, 20, 60), + 'cyan': (0, 255, 255), + 'darkblue': (0, 0, 139), + 'darkcyan': (0, 139, 139), + 'darkgoldenrod': (184, 134, 11), + 'darkgray': (169, 169, 169), + 'darkgreen': (0, 100, 0), + 'darkgrey': (169, 169, 169), + 'darkkhaki': (189, 183, 107), + 'darkmagenta': (139, 0, 139), + 'darkolivegreen': (85, 107, 47), + 'darkorange': (255, 140, 0), + 'darkorchid': (153, 50, 204), + 'darkred': (139, 0, 0), + 'darksalmon': (233, 150, 122), + 'darkseagreen': (143, 188, 143), + 'darkslateblue': (72, 61, 139), + 'darkslategray': (47, 79, 79), + 'darkslategrey': (47, 79, 79), + 'darkturquoise': (0, 206, 209), + 'darkviolet': (148, 0, 211), + 'deeppink': (255, 20, 147), + 'deepskyblue': (0, 191, 255), + 'dimgray': (105, 105, 105), + 'dimgrey': (105, 105, 105), + 'dodgerblue': (30, 144, 255), + 'firebrick': (178, 34, 34), + 'floralwhite': (255, 250, 240), + 'forestgreen': (34, 139, 34), + 'fuchsia': (255, 0, 255), + 'gainsboro': (220, 220, 220), + 'ghostwhite': (248, 248, 255), + 'gold': (255, 215, 0), + 'goldenrod': (218, 165, 32), + 'gray': (128, 128, 128), + 'green': (0, 128, 0), + 'greenyellow': (173, 255, 47), + 'grey': (128, 128, 128), + 'honeydew': (240, 255, 240), + 'hotpink': (255, 105, 180), + 'indianred': (205, 92, 92), + 'indigo': (75, 0, 130), + 'ivory': (255, 255, 240), + 'khaki': (240, 230, 140), + 'lavender': (230, 230, 250), + 'lavenderblush': (255, 240, 245), + 'lawngreen': (124, 252, 0), + 'lemonchiffon': (255, 250, 205), + 'lightblue': (173, 216, 230), + 'lightcoral': (240, 128, 128), + 'lightcyan': (224, 255, 255), + 'lightgoldenrodyellow': (250, 250, 210), + 'lightgray': (211, 211, 211), + 'lightgreen': (144, 238, 144), + 'lightgrey': (211, 211, 211), + 'lightpink': (255, 182, 193), + 'lightsalmon': (255, 160, 122), + 'lightseagreen': (32, 178, 170), + 'lightskyblue': (135, 206, 250), + 'lightslategray': (119, 136, 153), + 'lightslategrey': (119, 136, 153), + 'lightsteelblue': (176, 196, 222), + 'lightyellow': (255, 255, 224), + 'lime': (0, 255, 0), + 'limegreen': (50, 205, 50), + 'linen': (250, 240, 230), + 'magenta': (255, 0, 255), + 'maroon': (128, 0, 0), + 'mediumaquamarine': (102, 205, 170), + 'mediumblue': (0, 0, 205), + 'mediumorchid': (186, 85, 211), + 'mediumpurple': (147, 112, 219), + 'mediumseagreen': (60, 179, 113), + 'mediumslateblue': (123, 104, 238), + 'mediumspringgreen': (0, 250, 154), + 'mediumturquoise': (72, 209, 204), + 'mediumvioletred': (199, 21, 133), + 'midnightblue': (25, 25, 112), + 'mintcream': (245, 255, 250), + 'mistyrose': (255, 228, 
225), + 'moccasin': (255, 228, 181), + 'navajowhite': (255, 222, 173), + 'navy': (0, 0, 128), + 'oldlace': (253, 245, 230), + 'olive': (128, 128, 0), + 'olivedrab': (107, 142, 35), + 'orange': (255, 165, 0), + 'orangered': (255, 69, 0), + 'orchid': (218, 112, 214), + 'palegoldenrod': (238, 232, 170), + 'palegreen': (152, 251, 152), + 'paleturquoise': (175, 238, 238), + 'palevioletred': (219, 112, 147), + 'papayawhip': (255, 239, 213), + 'peachpuff': (255, 218, 185), + 'peru': (205, 133, 63), + 'pink': (255, 192, 203), + 'plum': (221, 160, 221), + 'powderblue': (176, 224, 230), + 'purple': (128, 0, 128), + 'red': (255, 0, 0), + 'rosybrown': (188, 143, 143), + 'royalblue': (65, 105, 225), + 'saddlebrown': (139, 69, 19), + 'salmon': (250, 128, 114), + 'sandybrown': (244, 164, 96), + 'seagreen': (46, 139, 87), + 'seashell': (255, 245, 238), + 'sienna': (160, 82, 45), + 'silver': (192, 192, 192), + 'skyblue': (135, 206, 235), + 'slateblue': (106, 90, 205), + 'slategray': (112, 128, 144), + 'slategrey': (112, 128, 144), + 'snow': (255, 250, 250), + 'springgreen': (0, 255, 127), + 'steelblue': (70, 130, 180), + 'tan': (210, 180, 140), + 'teal': (0, 128, 128), + 'thistle': (216, 191, 216), + 'tomato': (255, 99, 71), + 'turquoise': (64, 224, 208), + 'violet': (238, 130, 238), + 'wheat': (245, 222, 179), + 'white': (255, 255, 255), + 'whitesmoke': (245, 245, 245), + 'yellow': (255, 255, 0), + 'yellowgreen': (154, 205, 50), +} + +COLORS_BY_VALUE = {v: k for k, v in COLORS_BY_NAME.items()} diff --git a/libs/win/pydantic/config.cp37-win_amd64.pyd b/libs/win/pydantic/config.cp37-win_amd64.pyd new file mode 100644 index 00000000..689773e4 Binary files /dev/null and b/libs/win/pydantic/config.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/config.py b/libs/win/pydantic/config.py new file mode 100644 index 00000000..74687ca0 --- /dev/null +++ b/libs/win/pydantic/config.py @@ -0,0 +1,192 @@ +import json +from enum import Enum +from typing import TYPE_CHECKING, Any, Callable, Dict, ForwardRef, Optional, Tuple, Type, Union + +from typing_extensions import Literal, Protocol + +from .typing import AnyArgTCallable, AnyCallable +from .utils import GetterDict +from .version import compiled + +if TYPE_CHECKING: + from typing import overload + + from .fields import ModelField + from .main import BaseModel + + ConfigType = Type['BaseConfig'] + + class SchemaExtraCallable(Protocol): + @overload + def __call__(self, schema: Dict[str, Any]) -> None: + pass + + @overload + def __call__(self, schema: Dict[str, Any], model_class: Type[BaseModel]) -> None: + pass + +else: + SchemaExtraCallable = Callable[..., None] + +__all__ = 'BaseConfig', 'ConfigDict', 'get_config', 'Extra', 'inherit_config', 'prepare_config' + + +class Extra(str, Enum): + allow = 'allow' + ignore = 'ignore' + forbid = 'forbid' + + +# https://github.com/cython/cython/issues/4003 +# Will be fixed with Cython 3 but still in alpha right now +if not compiled: + from typing_extensions import TypedDict + + class ConfigDict(TypedDict, total=False): + title: Optional[str] + anystr_lower: bool + anystr_strip_whitespace: bool + min_anystr_length: int + max_anystr_length: Optional[int] + validate_all: bool + extra: Extra + allow_mutation: bool + frozen: bool + allow_population_by_field_name: bool + use_enum_values: bool + fields: Dict[str, Union[str, Dict[str, str]]] + validate_assignment: bool + error_msg_templates: Dict[str, str] + arbitrary_types_allowed: bool + orm_mode: bool + getter_dict: Type[GetterDict] + alias_generator: Optional[Callable[[str], 
str]] + keep_untouched: Tuple[type, ...] + schema_extra: Union[Dict[str, object], 'SchemaExtraCallable'] + json_loads: Callable[[str], object] + json_dumps: AnyArgTCallable[str] + json_encoders: Dict[Type[object], AnyCallable] + underscore_attrs_are_private: bool + allow_inf_nan: bool + + # whether or not inherited models as fields should be reconstructed as base model + copy_on_model_validation: bool + # whether dataclass `__post_init__` should be run after validation + post_init_call: Literal['before_validation', 'after_validation'] + +else: + ConfigDict = dict # type: ignore + + +class BaseConfig: + title: Optional[str] = None + anystr_lower: bool = False + anystr_upper: bool = False + anystr_strip_whitespace: bool = False + min_anystr_length: int = 0 + max_anystr_length: Optional[int] = None + validate_all: bool = False + extra: Extra = Extra.ignore + allow_mutation: bool = True + frozen: bool = False + allow_population_by_field_name: bool = False + use_enum_values: bool = False + fields: Dict[str, Union[str, Dict[str, str]]] = {} + validate_assignment: bool = False + error_msg_templates: Dict[str, str] = {} + arbitrary_types_allowed: bool = False + orm_mode: bool = False + getter_dict: Type[GetterDict] = GetterDict + alias_generator: Optional[Callable[[str], str]] = None + keep_untouched: Tuple[type, ...] = () + schema_extra: Union[Dict[str, Any], 'SchemaExtraCallable'] = {} + json_loads: Callable[[str], Any] = json.loads + json_dumps: Callable[..., str] = json.dumps + json_encoders: Dict[Union[Type[Any], str, ForwardRef], AnyCallable] = {} + underscore_attrs_are_private: bool = False + allow_inf_nan: bool = True + + # whether inherited models as fields should be reconstructed as base model, + # and whether such a copy should be shallow or deep + copy_on_model_validation: Literal['none', 'deep', 'shallow'] = 'shallow' + + # whether `Union` should check all allowed types before even trying to coerce + smart_union: bool = False + # whether dataclass `__post_init__` should be run before or after validation + post_init_call: Literal['before_validation', 'after_validation'] = 'before_validation' + + @classmethod + def get_field_info(cls, name: str) -> Dict[str, Any]: + """ + Get properties of FieldInfo from the `fields` property of the config class. + """ + + fields_value = cls.fields.get(name) + + if isinstance(fields_value, str): + field_info: Dict[str, Any] = {'alias': fields_value} + elif isinstance(fields_value, dict): + field_info = fields_value + else: + field_info = {} + + if 'alias' in field_info: + field_info.setdefault('alias_priority', 2) + + if field_info.get('alias_priority', 0) <= 1 and cls.alias_generator: + alias = cls.alias_generator(name) + if not isinstance(alias, str): + raise TypeError(f'Config.alias_generator must return str, not {alias.__class__}') + field_info.update(alias=alias, alias_priority=1) + return field_info + + @classmethod + def prepare_field(cls, field: 'ModelField') -> None: + """ + Optional hook to check or modify fields during model creation. + """ + pass + + +def get_config(config: Union[ConfigDict, Type[object], None]) -> Type[BaseConfig]: + if config is None: + return BaseConfig + + else: + config_dict = ( + config + if isinstance(config, dict) + else {k: getattr(config, k) for k in dir(config) if not k.startswith('__')} + ) + + class Config(BaseConfig): + ... 
+
+        for k, v in config_dict.items():
+            setattr(Config, k, v)
+        return Config
+
+
+def inherit_config(self_config: 'ConfigType', parent_config: 'ConfigType', **namespace: Any) -> 'ConfigType':
+    if not self_config:
+        base_classes: Tuple['ConfigType', ...] = (parent_config,)
+    elif self_config == parent_config:
+        base_classes = (self_config,)
+    else:
+        base_classes = self_config, parent_config
+
+    namespace['json_encoders'] = {
+        **getattr(parent_config, 'json_encoders', {}),
+        **getattr(self_config, 'json_encoders', {}),
+        **namespace.get('json_encoders', {}),
+    }
+
+    return type('Config', base_classes, namespace)
+
+
+def prepare_config(config: Type[BaseConfig], cls_name: str) -> None:
+    if not isinstance(config.extra, Extra):
+        try:
+            config.extra = Extra(config.extra)
+        except ValueError:
+            raise ValueError(f'"{cls_name}": {config.extra} is not a valid value for "extra"')
diff --git a/libs/win/pydantic/dataclasses.cp37-win_amd64.pyd b/libs/win/pydantic/dataclasses.cp37-win_amd64.pyd
new file mode 100644
index 00000000..c672f8d6
Binary files /dev/null and b/libs/win/pydantic/dataclasses.cp37-win_amd64.pyd differ
diff --git a/libs/win/pydantic/dataclasses.py b/libs/win/pydantic/dataclasses.py
new file mode 100644
index 00000000..68331127
--- /dev/null
+++ b/libs/win/pydantic/dataclasses.py
@@ -0,0 +1,479 @@
+"""
+The main purpose is to enhance stdlib dataclasses by adding validation.
+A pydantic dataclass can be generated from scratch or from a stdlib one.
+
+Behind the scenes, a pydantic dataclass is just like a regular one on which we attach
+a `BaseModel` and magic methods to trigger the validation of the data.
+`__init__` and `__post_init__` are hence overridden and have extra logic to be
+able to validate input data.
+
+When a pydantic dataclass is generated from scratch, it's just a plain dataclass
+with validation triggered at initialization.
+
+The tricky part is for stdlib dataclasses that are converted into pydantic ones afterwards, e.g.
+
+```py
+@dataclasses.dataclass
+class M:
+    x: int
+
+ValidatedM = pydantic.dataclasses.dataclass(M)
+```
+
+We indeed still want to support equality, hashing, repr, ... as if it was the stdlib one!
+
+```py
+assert isinstance(ValidatedM(x=1), M)
+assert ValidatedM(x=1) == M(x=1)
+```
+
+This means we **don't want to create a new dataclass that inherits from it**.
+The trick is to create a wrapper around `M` that will act as a proxy to trigger
+validation without altering default `M` behaviour.
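+
+To illustrate the proxy behaviour (a sketch, not an exhaustive spec):
+
+```py
+v = ValidatedM(x='1')  # validation runs: x is coerced to int(1)
+assert v.x == 1
+m = M(x='1')           # the stdlib class itself stays unvalidated
+assert m.x == '1'
+```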
+""" +import sys +from contextlib import contextmanager +from functools import wraps +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + Generator, + Optional, + Set, + Type, + TypeVar, + Union, + overload, +) + +from typing_extensions import dataclass_transform + +from .class_validators import gather_all_validators +from .config import BaseConfig, ConfigDict, Extra, get_config +from .error_wrappers import ValidationError +from .errors import DataclassTypeError +from .fields import Field, FieldInfo, Required, Undefined +from .main import create_model, validate_model +from .utils import ClassAttribute + +if TYPE_CHECKING: + from .main import BaseModel + from .typing import CallableGenerator, NoArgAnyCallable + + DataclassT = TypeVar('DataclassT', bound='Dataclass') + + DataclassClassOrWrapper = Union[Type['Dataclass'], 'DataclassProxy'] + + class Dataclass: + # stdlib attributes + __dataclass_fields__: ClassVar[Dict[str, Any]] + __dataclass_params__: ClassVar[Any] # in reality `dataclasses._DataclassParams` + __post_init__: ClassVar[Callable[..., None]] + + # Added by pydantic + __pydantic_run_validation__: ClassVar[bool] + __post_init_post_parse__: ClassVar[Callable[..., None]] + __pydantic_initialised__: ClassVar[bool] + __pydantic_model__: ClassVar[Type[BaseModel]] + __pydantic_validate_values__: ClassVar[Callable[['Dataclass'], None]] + __pydantic_has_field_info_default__: ClassVar[bool] # whether a `pydantic.Field` is used as default value + + def __init__(self, *args: object, **kwargs: object) -> None: + pass + + @classmethod + def __get_validators__(cls: Type['Dataclass']) -> 'CallableGenerator': + pass + + @classmethod + def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT': + pass + + +__all__ = [ + 'dataclass', + 'set_validation', + 'create_pydantic_model_from_dataclass', + 'is_builtin_dataclass', + 'make_dataclass_validator', +] + +_T = TypeVar('_T') + +if sys.version_info >= (3, 10): + + @dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo)) + @overload + def dataclass( + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + config: Union[ConfigDict, Type[object], None] = None, + validate_on_init: Optional[bool] = None, + kw_only: bool = ..., + ) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']: + ... + + @dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo)) + @overload + def dataclass( + _cls: Type[_T], + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + config: Union[ConfigDict, Type[object], None] = None, + validate_on_init: Optional[bool] = None, + kw_only: bool = ..., + ) -> 'DataclassClassOrWrapper': + ... + +else: + + @dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo)) + @overload + def dataclass( + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + config: Union[ConfigDict, Type[object], None] = None, + validate_on_init: Optional[bool] = None, + ) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']: + ... 
+ + @dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo)) + @overload + def dataclass( + _cls: Type[_T], + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + config: Union[ConfigDict, Type[object], None] = None, + validate_on_init: Optional[bool] = None, + ) -> 'DataclassClassOrWrapper': + ... + + +@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo)) +def dataclass( + _cls: Optional[Type[_T]] = None, + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + config: Union[ConfigDict, Type[object], None] = None, + validate_on_init: Optional[bool] = None, + kw_only: bool = False, +) -> Union[Callable[[Type[_T]], 'DataclassClassOrWrapper'], 'DataclassClassOrWrapper']: + """ + Like the python standard lib dataclasses but with type validation. + The result is either a pydantic dataclass that will validate input data + or a wrapper that will trigger validation around a stdlib dataclass + to avoid modifying it directly + """ + the_config = get_config(config) + + def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper': + import dataclasses + + if is_builtin_dataclass(cls) and _extra_dc_args(_cls) == _extra_dc_args(_cls.__bases__[0]): # type: ignore + dc_cls_doc = '' + dc_cls = DataclassProxy(cls) + default_validate_on_init = False + else: + dc_cls_doc = cls.__doc__ or '' # needs to be done before generating dataclass + if sys.version_info >= (3, 10): + dc_cls = dataclasses.dataclass( + cls, + init=init, + repr=repr, + eq=eq, + order=order, + unsafe_hash=unsafe_hash, + frozen=frozen, + kw_only=kw_only, + ) + else: + dc_cls = dataclasses.dataclass( # type: ignore + cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen + ) + default_validate_on_init = True + + should_validate_on_init = default_validate_on_init if validate_on_init is None else validate_on_init + _add_pydantic_validation_attributes(cls, the_config, should_validate_on_init, dc_cls_doc) + dc_cls.__pydantic_model__.__try_update_forward_refs__(**{cls.__name__: cls}) + return dc_cls + + if _cls is None: + return wrap + + return wrap(_cls) + + +@contextmanager +def set_validation(cls: Type['DataclassT'], value: bool) -> Generator[Type['DataclassT'], None, None]: + original_run_validation = cls.__pydantic_run_validation__ + try: + cls.__pydantic_run_validation__ = value + yield cls + finally: + cls.__pydantic_run_validation__ = original_run_validation + + +class DataclassProxy: + __slots__ = '__dataclass__' + + def __init__(self, dc_cls: Type['Dataclass']) -> None: + object.__setattr__(self, '__dataclass__', dc_cls) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + with set_validation(self.__dataclass__, True): + return self.__dataclass__(*args, **kwargs) + + def __getattr__(self, name: str) -> Any: + return getattr(self.__dataclass__, name) + + def __instancecheck__(self, instance: Any) -> bool: + return isinstance(instance, self.__dataclass__) + + +def _add_pydantic_validation_attributes( # noqa: C901 (ignore complexity) + dc_cls: Type['Dataclass'], + config: Type[BaseConfig], + validate_on_init: bool, + dc_cls_doc: str, +) -> None: + """ + We need to replace the right method. 
If no `__post_init__` has been set in the stdlib dataclass + it won't even exist (code is generated on the fly by `dataclasses`) + By default, we run validation after `__init__` or `__post_init__` if defined + """ + init = dc_cls.__init__ + + @wraps(init) + def handle_extra_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None: + if config.extra == Extra.ignore: + init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__}) + + elif config.extra == Extra.allow: + for k, v in kwargs.items(): + self.__dict__.setdefault(k, v) + init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__}) + + else: + init(self, *args, **kwargs) + + if hasattr(dc_cls, '__post_init__'): + post_init = dc_cls.__post_init__ + + @wraps(post_init) + def new_post_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None: + if config.post_init_call == 'before_validation': + post_init(self, *args, **kwargs) + + if self.__class__.__pydantic_run_validation__: + self.__pydantic_validate_values__() + if hasattr(self, '__post_init_post_parse__'): + self.__post_init_post_parse__(*args, **kwargs) + + if config.post_init_call == 'after_validation': + post_init(self, *args, **kwargs) + + setattr(dc_cls, '__init__', handle_extra_init) + setattr(dc_cls, '__post_init__', new_post_init) + + else: + + @wraps(init) + def new_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None: + handle_extra_init(self, *args, **kwargs) + + if self.__class__.__pydantic_run_validation__: + self.__pydantic_validate_values__() + + if hasattr(self, '__post_init_post_parse__'): + # We need to find again the initvars. To do that we use `__dataclass_fields__` instead of + # public method `dataclasses.fields` + import dataclasses + + # get all initvars and their default values + initvars_and_values: Dict[str, Any] = {} + for i, f in enumerate(self.__class__.__dataclass_fields__.values()): + if f._field_type is dataclasses._FIELD_INITVAR: # type: ignore[attr-defined] + try: + # set arg value by default + initvars_and_values[f.name] = args[i] + except IndexError: + initvars_and_values[f.name] = kwargs.get(f.name, f.default) + + self.__post_init_post_parse__(**initvars_and_values) + + setattr(dc_cls, '__init__', new_init) + + setattr(dc_cls, '__pydantic_run_validation__', ClassAttribute('__pydantic_run_validation__', validate_on_init)) + setattr(dc_cls, '__pydantic_initialised__', False) + setattr(dc_cls, '__pydantic_model__', create_pydantic_model_from_dataclass(dc_cls, config, dc_cls_doc)) + setattr(dc_cls, '__pydantic_validate_values__', _dataclass_validate_values) + setattr(dc_cls, '__validate__', classmethod(_validate_dataclass)) + setattr(dc_cls, '__get_validators__', classmethod(_get_validators)) + + if dc_cls.__pydantic_model__.__config__.validate_assignment and not dc_cls.__dataclass_params__.frozen: + setattr(dc_cls, '__setattr__', _dataclass_validate_assignment_setattr) + + +def _get_validators(cls: 'DataclassClassOrWrapper') -> 'CallableGenerator': + yield cls.__validate__ + + +def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT': + with set_validation(cls, True): + if isinstance(v, cls): + v.__pydantic_validate_values__() + return v + elif isinstance(v, (list, tuple)): + return cls(*v) + elif isinstance(v, dict): + return cls(**v) + else: + raise DataclassTypeError(class_name=cls.__name__) + + +def create_pydantic_model_from_dataclass( + dc_cls: Type['Dataclass'], + config: Type[Any] = BaseConfig, + dc_cls_doc: Optional[str] = None, +) -> Type['BaseModel']: + 
import dataclasses + + field_definitions: Dict[str, Any] = {} + for field in dataclasses.fields(dc_cls): + default: Any = Undefined + default_factory: Optional['NoArgAnyCallable'] = None + field_info: FieldInfo + + if field.default is not dataclasses.MISSING: + default = field.default + elif field.default_factory is not dataclasses.MISSING: + default_factory = field.default_factory + else: + default = Required + + if isinstance(default, FieldInfo): + field_info = default + dc_cls.__pydantic_has_field_info_default__ = True + else: + field_info = Field(default=default, default_factory=default_factory, **field.metadata) + + field_definitions[field.name] = (field.type, field_info) + + validators = gather_all_validators(dc_cls) + model: Type['BaseModel'] = create_model( + dc_cls.__name__, + __config__=config, + __module__=dc_cls.__module__, + __validators__=validators, + __cls_kwargs__={'__resolve_forward_refs__': False}, + **field_definitions, + ) + model.__doc__ = dc_cls_doc if dc_cls_doc is not None else dc_cls.__doc__ or '' + return model + + +def _dataclass_validate_values(self: 'Dataclass') -> None: + # validation errors can occur if this function is called twice on an already initialised dataclass. + # for example if Extra.forbid is enabled, it would consider __pydantic_initialised__ an invalid extra property + if getattr(self, '__pydantic_initialised__'): + return + if getattr(self, '__pydantic_has_field_info_default__', False): + # We need to remove `FieldInfo` values since they are not valid as input + # It's ok to do that because they are obviously the default values! + input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)} + else: + input_data = self.__dict__ + d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__) + if validation_error: + raise validation_error + self.__dict__.update(d) + object.__setattr__(self, '__pydantic_initialised__', True) + + +def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value: Any) -> None: + if self.__pydantic_initialised__: + d = dict(self.__dict__) + d.pop(name, None) + known_field = self.__pydantic_model__.__fields__.get(name, None) + if known_field: + value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__) + if error_: + raise ValidationError([error_], self.__class__) + + object.__setattr__(self, name, value) + + +def _extra_dc_args(cls: Type[Any]) -> Set[str]: + return { + x + for x in dir(cls) + if x not in getattr(cls, '__dataclass_fields__', {}) and not (x.startswith('__') and x.endswith('__')) + } + + +def is_builtin_dataclass(_cls: Type[Any]) -> bool: + """ + Whether a class is a stdlib dataclass + (useful to discriminated a pydantic dataclass that is actually a wrapper around a stdlib dataclass) + + we check that + - `_cls` is a dataclass + - `_cls` is not a processed pydantic dataclass (with a basemodel attached) + - `_cls` is not a pydantic dataclass inheriting directly from a stdlib dataclass + e.g. + ``` + @dataclasses.dataclass + class A: + x: int + + @pydantic.dataclasses.dataclass + class B(A): + y: int + ``` + In this case, when we first check `B`, we make an extra check and look at the annotations ('y'), + which won't be a superset of all the dataclass fields (only the stdlib fields i.e. 
'x')
+    """
+    import dataclasses
+
+    return (
+        dataclasses.is_dataclass(_cls)
+        and not hasattr(_cls, '__pydantic_model__')
+        and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {})))
+    )
+
+
+def make_dataclass_validator(dc_cls: Type['Dataclass'], config: Type[BaseConfig]) -> 'CallableGenerator':
+    """
+    Create a pydantic.dataclass from a builtin dataclass to add type validation
+    and yield the validators
+    It retrieves the parameters of the dataclass and forwards them to the newly created dataclass
+    """
+    yield from _get_validators(dataclass(dc_cls, config=config, validate_on_init=False))
diff --git a/libs/win/pydantic/datetime_parse.cp37-win_amd64.pyd b/libs/win/pydantic/datetime_parse.cp37-win_amd64.pyd
new file mode 100644
index 00000000..6f98d735
Binary files /dev/null and b/libs/win/pydantic/datetime_parse.cp37-win_amd64.pyd differ
diff --git a/libs/win/pydantic/datetime_parse.py b/libs/win/pydantic/datetime_parse.py
new file mode 100644
index 00000000..cfd54593
--- /dev/null
+++ b/libs/win/pydantic/datetime_parse.py
@@ -0,0 +1,248 @@
+"""
+Functions to parse datetime objects.
+
+We're using regular expressions rather than time.strptime because:
+- They provide both validation and parsing.
+- They're more flexible for datetimes.
+- The date/datetime/time constructors produce friendlier error messages.
+
+Stolen from https://raw.githubusercontent.com/django/django/main/django/utils/dateparse.py at
+9718fa2e8abe430c3526a9278dd976443d4ae3c6
+
+Changed to:
+* use standard python datetime types not django.utils.timezone
+* raise ValueError when regex doesn't match rather than returning None
+* support parsing unix timestamps for dates and datetimes
+"""
+import re
+from datetime import date, datetime, time, timedelta, timezone
+from typing import Dict, Optional, Type, Union
+
+from . import errors
+
+date_expr = r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})'
+time_expr = (
+    r'(?P<hour>\d{1,2}):(?P<minute>\d{1,2})'
+    r'(?::(?P<second>\d{1,2})(?:\.(?P<microsecond>\d{1,6})\d{0,6})?)?'
+    r'(?P<tzinfo>Z|[+-]\d{2}(?::?\d{2})?)?$'
+)
+
+date_re = re.compile(f'{date_expr}$')
+time_re = re.compile(time_expr)
+datetime_re = re.compile(f'{date_expr}[T ]{time_expr}')
+
+standard_duration_re = re.compile(
+    r'^'
+    r'(?:(?P<days>-?\d+) (days?, )?)?'
+    r'((?:(?P<hours>-?\d+):)(?=\d+:\d+))?'
+    r'(?:(?P<minutes>-?\d+):)?'
+    r'(?P<seconds>-?\d+)'
+    r'(?:\.(?P<microseconds>\d{1,6})\d{0,6})?'
+    r'$'
+)
+
+# Support the sections of ISO 8601 date representation that are accepted by timedelta
+iso8601_duration_re = re.compile(
+    r'^(?P<sign>[-+]?)'
+    r'P'
+    r'(?:(?P<days>\d+(.\d+)?)D)?'
+    r'(?:T'
+    r'(?:(?P<hours>\d+(.\d+)?)H)?'
+    r'(?:(?P<minutes>\d+(.\d+)?)M)?'
+    r'(?:(?P<seconds>\d+(.\d+)?)S)?'
+    r')?'
+ r'$' +) + +EPOCH = datetime(1970, 1, 1) +# if greater than this, the number is in ms, if less than or equal it's in seconds +# (in seconds this is 11th October 2603, in ms it's 20th August 1970) +MS_WATERSHED = int(2e10) +# slightly more than datetime.max in ns - (datetime.max - EPOCH).total_seconds() * 1e9 +MAX_NUMBER = int(3e20) +StrBytesIntFloat = Union[str, bytes, int, float] + + +def get_numeric(value: StrBytesIntFloat, native_expected_type: str) -> Union[None, int, float]: + if isinstance(value, (int, float)): + return value + try: + return float(value) + except ValueError: + return None + except TypeError: + raise TypeError(f'invalid type; expected {native_expected_type}, string, bytes, int or float') + + +def from_unix_seconds(seconds: Union[int, float]) -> datetime: + if seconds > MAX_NUMBER: + return datetime.max + elif seconds < -MAX_NUMBER: + return datetime.min + + while abs(seconds) > MS_WATERSHED: + seconds /= 1000 + dt = EPOCH + timedelta(seconds=seconds) + return dt.replace(tzinfo=timezone.utc) + + +def _parse_timezone(value: Optional[str], error: Type[Exception]) -> Union[None, int, timezone]: + if value == 'Z': + return timezone.utc + elif value is not None: + offset_mins = int(value[-2:]) if len(value) > 3 else 0 + offset = 60 * int(value[1:3]) + offset_mins + if value[0] == '-': + offset = -offset + try: + return timezone(timedelta(minutes=offset)) + except ValueError: + raise error() + else: + return None + + +def parse_date(value: Union[date, StrBytesIntFloat]) -> date: + """ + Parse a date/int/float/string and return a datetime.date. + + Raise ValueError if the input is well formatted but not a valid date. + Raise ValueError if the input isn't well formatted. + """ + if isinstance(value, date): + if isinstance(value, datetime): + return value.date() + else: + return value + + number = get_numeric(value, 'date') + if number is not None: + return from_unix_seconds(number).date() + + if isinstance(value, bytes): + value = value.decode() + + match = date_re.match(value) # type: ignore + if match is None: + raise errors.DateError() + + kw = {k: int(v) for k, v in match.groupdict().items()} + + try: + return date(**kw) + except ValueError: + raise errors.DateError() + + +def parse_time(value: Union[time, StrBytesIntFloat]) -> time: + """ + Parse a time/string and return a datetime.time. + + Raise ValueError if the input is well formatted but not a valid time. + Raise ValueError if the input isn't well formatted, in particular if it contains an offset. + """ + if isinstance(value, time): + return value + + number = get_numeric(value, 'time') + if number is not None: + if number >= 86400: + # doesn't make sense since the time time loop back around to 0 + raise errors.TimeError() + return (datetime.min + timedelta(seconds=number)).time() + + if isinstance(value, bytes): + value = value.decode() + + match = time_re.match(value) # type: ignore + if match is None: + raise errors.TimeError() + + kw = match.groupdict() + if kw['microsecond']: + kw['microsecond'] = kw['microsecond'].ljust(6, '0') + + tzinfo = _parse_timezone(kw.pop('tzinfo'), errors.TimeError) + kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None} + kw_['tzinfo'] = tzinfo + + try: + return time(**kw_) # type: ignore + except ValueError: + raise errors.TimeError() + + +def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: + """ + Parse a datetime/int/float/string and return a datetime.datetime. + + This function supports time zone offsets. 
When the input contains one, + the output uses a timezone with a fixed offset from UTC. + + Raise ValueError if the input is well formatted but not a valid datetime. + Raise ValueError if the input isn't well formatted. + """ + if isinstance(value, datetime): + return value + + number = get_numeric(value, 'datetime') + if number is not None: + return from_unix_seconds(number) + + if isinstance(value, bytes): + value = value.decode() + + match = datetime_re.match(value) # type: ignore + if match is None: + raise errors.DateTimeError() + + kw = match.groupdict() + if kw['microsecond']: + kw['microsecond'] = kw['microsecond'].ljust(6, '0') + + tzinfo = _parse_timezone(kw.pop('tzinfo'), errors.DateTimeError) + kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None} + kw_['tzinfo'] = tzinfo + + try: + return datetime(**kw_) # type: ignore + except ValueError: + raise errors.DateTimeError() + + +def parse_duration(value: StrBytesIntFloat) -> timedelta: + """ + Parse a duration int/float/string and return a datetime.timedelta. + + The preferred format for durations in Django is '%d %H:%M:%S.%f'. + + Also supports ISO 8601 representation. + """ + if isinstance(value, timedelta): + return value + + if isinstance(value, (int, float)): + # below code requires a string + value = f'{value:f}' + elif isinstance(value, bytes): + value = value.decode() + + try: + match = standard_duration_re.match(value) or iso8601_duration_re.match(value) + except TypeError: + raise TypeError('invalid type; expected timedelta, string, bytes, int or float') + + if not match: + raise errors.DurationError() + + kw = match.groupdict() + sign = -1 if kw.pop('sign', '+') == '-' else 1 + if kw.get('microseconds'): + kw['microseconds'] = kw['microseconds'].ljust(6, '0') + + if kw.get('seconds') and kw.get('microseconds') and kw['seconds'].startswith('-'): + kw['microseconds'] = '-' + kw['microseconds'] + + kw_ = {k: float(v) for k, v in kw.items() if v is not None} + + return sign * timedelta(**kw_) diff --git a/libs/win/pydantic/decorator.cp37-win_amd64.pyd b/libs/win/pydantic/decorator.cp37-win_amd64.pyd new file mode 100644 index 00000000..ed58a134 Binary files /dev/null and b/libs/win/pydantic/decorator.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/decorator.py b/libs/win/pydantic/decorator.py new file mode 100644 index 00000000..089aab65 --- /dev/null +++ b/libs/win/pydantic/decorator.py @@ -0,0 +1,264 @@ +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload + +from . import validator +from .config import Extra +from .errors import ConfigError +from .main import BaseModel, create_model +from .typing import get_all_type_hints +from .utils import to_camel + +__all__ = ('validate_arguments',) + +if TYPE_CHECKING: + from .typing import AnyCallable + + AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable) + ConfigType = Union[None, Type[Any], Dict[str, Any]] + + +@overload +def validate_arguments(func: None = None, *, config: 'ConfigType' = None) -> Callable[['AnyCallableT'], 'AnyCallableT']: + ... + + +@overload +def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT': + ... + + +def validate_arguments(func: Optional['AnyCallableT'] = None, *, config: 'ConfigType' = None) -> Any: + """ + Decorator to validate the arguments passed to a function. 
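+
+    A minimal usage sketch (names are illustrative):
+
+        @validate_arguments
+        def repeat(s: str, count: int) -> str:
+            return s * count
+
+        repeat('x', '3')  # '3' is coerced to int -> 'xxx'
+        repeat('x', 'a')  # raises a ValidationError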
+ """ + + def validate(_func: 'AnyCallable') -> 'AnyCallable': + vd = ValidatedFunction(_func, config) + + @wraps(_func) + def wrapper_function(*args: Any, **kwargs: Any) -> Any: + return vd.call(*args, **kwargs) + + wrapper_function.vd = vd # type: ignore + wrapper_function.validate = vd.init_model_instance # type: ignore + wrapper_function.raw_function = vd.raw_function # type: ignore + wrapper_function.model = vd.model # type: ignore + return wrapper_function + + if func: + return validate(func) + else: + return validate + + +ALT_V_ARGS = 'v__args' +ALT_V_KWARGS = 'v__kwargs' +V_POSITIONAL_ONLY_NAME = 'v__positional_only' +V_DUPLICATE_KWARGS = 'v__duplicate_kwargs' + + +class ValidatedFunction: + def __init__(self, function: 'AnyCallableT', config: 'ConfigType'): # noqa C901 + from inspect import Parameter, signature + + parameters: Mapping[str, Parameter] = signature(function).parameters + + if parameters.keys() & {ALT_V_ARGS, ALT_V_KWARGS, V_POSITIONAL_ONLY_NAME, V_DUPLICATE_KWARGS}: + raise ConfigError( + f'"{ALT_V_ARGS}", "{ALT_V_KWARGS}", "{V_POSITIONAL_ONLY_NAME}" and "{V_DUPLICATE_KWARGS}" ' + f'are not permitted as argument names when using the "{validate_arguments.__name__}" decorator' + ) + + self.raw_function = function + self.arg_mapping: Dict[int, str] = {} + self.positional_only_args = set() + self.v_args_name = 'args' + self.v_kwargs_name = 'kwargs' + + type_hints = get_all_type_hints(function) + takes_args = False + takes_kwargs = False + fields: Dict[str, Tuple[Any, Any]] = {} + for i, (name, p) in enumerate(parameters.items()): + if p.annotation is p.empty: + annotation = Any + else: + annotation = type_hints[name] + + default = ... if p.default is p.empty else p.default + if p.kind == Parameter.POSITIONAL_ONLY: + self.arg_mapping[i] = name + fields[name] = annotation, default + fields[V_POSITIONAL_ONLY_NAME] = List[str], None + self.positional_only_args.add(name) + elif p.kind == Parameter.POSITIONAL_OR_KEYWORD: + self.arg_mapping[i] = name + fields[name] = annotation, default + fields[V_DUPLICATE_KWARGS] = List[str], None + elif p.kind == Parameter.KEYWORD_ONLY: + fields[name] = annotation, default + elif p.kind == Parameter.VAR_POSITIONAL: + self.v_args_name = name + fields[name] = Tuple[annotation, ...], None + takes_args = True + else: + assert p.kind == Parameter.VAR_KEYWORD, p.kind + self.v_kwargs_name = name + fields[name] = Dict[str, annotation], None # type: ignore + takes_kwargs = True + + # these checks avoid a clash between "args" and a field with that name + if not takes_args and self.v_args_name in fields: + self.v_args_name = ALT_V_ARGS + + # same with "kwargs" + if not takes_kwargs and self.v_kwargs_name in fields: + self.v_kwargs_name = ALT_V_KWARGS + + if not takes_args: + # we add the field so validation below can raise the correct exception + fields[self.v_args_name] = List[Any], None + + if not takes_kwargs: + # same with kwargs + fields[self.v_kwargs_name] = Dict[Any, Any], None + + self.create_model(fields, takes_args, takes_kwargs, config) + + def init_model_instance(self, *args: Any, **kwargs: Any) -> BaseModel: + values = self.build_values(args, kwargs) + return self.model(**values) + + def call(self, *args: Any, **kwargs: Any) -> Any: + m = self.init_model_instance(*args, **kwargs) + return self.execute(m) + + def build_values(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]: + values: Dict[str, Any] = {} + if args: + arg_iter = enumerate(args) + while True: + try: + i, a = next(arg_iter) + except StopIteration: + break 
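+                # Positional args are matched to their declared names via
+                # arg_mapping; anything beyond the declared parameters falls
+                # through to the *args field handled below.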
+ arg_name = self.arg_mapping.get(i) + if arg_name is not None: + values[arg_name] = a + else: + values[self.v_args_name] = [a] + [a for _, a in arg_iter] + break + + var_kwargs: Dict[str, Any] = {} + wrong_positional_args = [] + duplicate_kwargs = [] + fields_alias = [ + field.alias + for name, field in self.model.__fields__.items() + if name not in (self.v_args_name, self.v_kwargs_name) + ] + non_var_fields = set(self.model.__fields__) - {self.v_args_name, self.v_kwargs_name} + for k, v in kwargs.items(): + if k in non_var_fields or k in fields_alias: + if k in self.positional_only_args: + wrong_positional_args.append(k) + if k in values: + duplicate_kwargs.append(k) + values[k] = v + else: + var_kwargs[k] = v + + if var_kwargs: + values[self.v_kwargs_name] = var_kwargs + if wrong_positional_args: + values[V_POSITIONAL_ONLY_NAME] = wrong_positional_args + if duplicate_kwargs: + values[V_DUPLICATE_KWARGS] = duplicate_kwargs + return values + + def execute(self, m: BaseModel) -> Any: + d = {k: v for k, v in m._iter() if k in m.__fields_set__ or m.__fields__[k].default_factory} + var_kwargs = d.pop(self.v_kwargs_name, {}) + + if self.v_args_name in d: + args_: List[Any] = [] + in_kwargs = False + kwargs = {} + for name, value in d.items(): + if in_kwargs: + kwargs[name] = value + elif name == self.v_args_name: + args_ += value + in_kwargs = True + else: + args_.append(value) + return self.raw_function(*args_, **kwargs, **var_kwargs) + elif self.positional_only_args: + args_ = [] + kwargs = {} + for name, value in d.items(): + if name in self.positional_only_args: + args_.append(value) + else: + kwargs[name] = value + return self.raw_function(*args_, **kwargs, **var_kwargs) + else: + return self.raw_function(**d, **var_kwargs) + + def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None: + pos_args = len(self.arg_mapping) + + class CustomConfig: + pass + + if not TYPE_CHECKING: # pragma: no branch + if isinstance(config, dict): + CustomConfig = type('Config', (), config) # noqa: F811 + elif config is not None: + CustomConfig = config # noqa: F811 + + if hasattr(CustomConfig, 'fields') or hasattr(CustomConfig, 'alias_generator'): + raise ConfigError( + 'Setting the "fields" and "alias_generator" property on custom Config for ' + '@validate_arguments is not yet supported, please remove.' 
+ ) + + class DecoratorBaseModel(BaseModel): + @validator(self.v_args_name, check_fields=False, allow_reuse=True) + def check_args(cls, v: Optional[List[Any]]) -> Optional[List[Any]]: + if takes_args or v is None: + return v + + raise TypeError(f'{pos_args} positional arguments expected but {pos_args + len(v)} given') + + @validator(self.v_kwargs_name, check_fields=False, allow_reuse=True) + def check_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + if takes_kwargs or v is None: + return v + + plural = '' if len(v) == 1 else 's' + keys = ', '.join(map(repr, v.keys())) + raise TypeError(f'unexpected keyword argument{plural}: {keys}') + + @validator(V_POSITIONAL_ONLY_NAME, check_fields=False, allow_reuse=True) + def check_positional_only(cls, v: Optional[List[str]]) -> None: + if v is None: + return + + plural = '' if len(v) == 1 else 's' + keys = ', '.join(map(repr, v)) + raise TypeError(f'positional-only argument{plural} passed as keyword argument{plural}: {keys}') + + @validator(V_DUPLICATE_KWARGS, check_fields=False, allow_reuse=True) + def check_duplicate_kwargs(cls, v: Optional[List[str]]) -> None: + if v is None: + return + + plural = '' if len(v) == 1 else 's' + keys = ', '.join(map(repr, v)) + raise TypeError(f'multiple values for argument{plural}: {keys}') + + class Config(CustomConfig): + extra = getattr(CustomConfig, 'extra', Extra.forbid) + + self.model = create_model(to_camel(self.raw_function.__name__), __base__=DecoratorBaseModel, **fields) diff --git a/libs/win/pydantic/env_settings.cp37-win_amd64.pyd b/libs/win/pydantic/env_settings.cp37-win_amd64.pyd new file mode 100644 index 00000000..1003ccc3 Binary files /dev/null and b/libs/win/pydantic/env_settings.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/env_settings.py b/libs/win/pydantic/env_settings.py new file mode 100644 index 00000000..e9988c01 --- /dev/null +++ b/libs/win/pydantic/env_settings.py @@ -0,0 +1,346 @@ +import os +import warnings +from pathlib import Path +from typing import AbstractSet, Any, Callable, ClassVar, Dict, List, Mapping, Optional, Tuple, Type, Union + +from .config import BaseConfig, Extra +from .fields import ModelField +from .main import BaseModel +from .typing import StrPath, display_as_type, get_origin, is_union +from .utils import deep_update, path_type, sequence_like + +env_file_sentinel = str(object()) + +SettingsSourceCallable = Callable[['BaseSettings'], Dict[str, Any]] +DotenvType = Union[StrPath, List[StrPath], Tuple[StrPath, ...]] + + +class SettingsError(ValueError): + pass + + +class BaseSettings(BaseModel): + """ + Base class for settings, allowing values to be overridden by environment variables. + + This is useful in production for secrets you do not wish to save in code, it plays nicely with docker(-compose), + Heroku and any 12 factor app design. 
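+
+    A small sketch (field and variable names are illustrative):
+
+        class Settings(BaseSettings):
+            api_key: str = ''
+
+        # with API_KEY=secret exported, Settings().api_key == 'secret'
+        # (matching is case-insensitive unless Config.case_sensitive is set)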
+ """ + + def __init__( + __pydantic_self__, + _env_file: Optional[DotenvType] = env_file_sentinel, + _env_file_encoding: Optional[str] = None, + _env_nested_delimiter: Optional[str] = None, + _secrets_dir: Optional[StrPath] = None, + **values: Any, + ) -> None: + # Uses something other than `self` the first arg to allow "self" as a settable attribute + super().__init__( + **__pydantic_self__._build_values( + values, + _env_file=_env_file, + _env_file_encoding=_env_file_encoding, + _env_nested_delimiter=_env_nested_delimiter, + _secrets_dir=_secrets_dir, + ) + ) + + def _build_values( + self, + init_kwargs: Dict[str, Any], + _env_file: Optional[DotenvType] = None, + _env_file_encoding: Optional[str] = None, + _env_nested_delimiter: Optional[str] = None, + _secrets_dir: Optional[StrPath] = None, + ) -> Dict[str, Any]: + # Configure built-in sources + init_settings = InitSettingsSource(init_kwargs=init_kwargs) + env_settings = EnvSettingsSource( + env_file=(_env_file if _env_file != env_file_sentinel else self.__config__.env_file), + env_file_encoding=( + _env_file_encoding if _env_file_encoding is not None else self.__config__.env_file_encoding + ), + env_nested_delimiter=( + _env_nested_delimiter if _env_nested_delimiter is not None else self.__config__.env_nested_delimiter + ), + env_prefix_len=len(self.__config__.env_prefix), + ) + file_secret_settings = SecretsSettingsSource(secrets_dir=_secrets_dir or self.__config__.secrets_dir) + # Provide a hook to set built-in sources priority and add / remove sources + sources = self.__config__.customise_sources( + init_settings=init_settings, env_settings=env_settings, file_secret_settings=file_secret_settings + ) + if sources: + return deep_update(*reversed([source(self) for source in sources])) + else: + # no one should mean to do this, but I think returning an empty dict is marginally preferable + # to an informative error and much better than a confusing error + return {} + + class Config(BaseConfig): + env_prefix: str = '' + env_file: Optional[DotenvType] = None + env_file_encoding: Optional[str] = None + env_nested_delimiter: Optional[str] = None + secrets_dir: Optional[StrPath] = None + validate_all: bool = True + extra: Extra = Extra.forbid + arbitrary_types_allowed: bool = True + case_sensitive: bool = False + + @classmethod + def prepare_field(cls, field: ModelField) -> None: + env_names: Union[List[str], AbstractSet[str]] + field_info_from_config = cls.get_field_info(field.name) + + env = field_info_from_config.get('env') or field.field_info.extra.get('env') + if env is None: + if field.has_alias: + warnings.warn( + 'aliases are no longer used by BaseSettings to define which environment variables to read. ' + 'Instead use the "env" field setting. 
' + 'See https://pydantic-docs.helpmanual.io/usage/settings/#environment-variable-names', + FutureWarning, + ) + env_names = {cls.env_prefix + field.name} + elif isinstance(env, str): + env_names = {env} + elif isinstance(env, (set, frozenset)): + env_names = env + elif sequence_like(env): + env_names = list(env) + else: + raise TypeError(f'invalid field env: {env!r} ({display_as_type(env)}); should be string, list or set') + + if not cls.case_sensitive: + env_names = env_names.__class__(n.lower() for n in env_names) + field.field_info.extra['env_names'] = env_names + + @classmethod + def customise_sources( + cls, + init_settings: SettingsSourceCallable, + env_settings: SettingsSourceCallable, + file_secret_settings: SettingsSourceCallable, + ) -> Tuple[SettingsSourceCallable, ...]: + return init_settings, env_settings, file_secret_settings + + @classmethod + def parse_env_var(cls, field_name: str, raw_val: str) -> Any: + return cls.json_loads(raw_val) + + # populated by the metaclass using the Config class defined above, annotated here to help IDEs only + __config__: ClassVar[Type[Config]] + + +class InitSettingsSource: + __slots__ = ('init_kwargs',) + + def __init__(self, init_kwargs: Dict[str, Any]): + self.init_kwargs = init_kwargs + + def __call__(self, settings: BaseSettings) -> Dict[str, Any]: + return self.init_kwargs + + def __repr__(self) -> str: + return f'InitSettingsSource(init_kwargs={self.init_kwargs!r})' + + +class EnvSettingsSource: + __slots__ = ('env_file', 'env_file_encoding', 'env_nested_delimiter', 'env_prefix_len') + + def __init__( + self, + env_file: Optional[DotenvType], + env_file_encoding: Optional[str], + env_nested_delimiter: Optional[str] = None, + env_prefix_len: int = 0, + ): + self.env_file: Optional[DotenvType] = env_file + self.env_file_encoding: Optional[str] = env_file_encoding + self.env_nested_delimiter: Optional[str] = env_nested_delimiter + self.env_prefix_len: int = env_prefix_len + + def __call__(self, settings: BaseSettings) -> Dict[str, Any]: # noqa C901 + """ + Build environment variables suitable for passing to the Model. 
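+
+        Roughly, the precedence implemented below is: real environment
+        variables override dotenv values; complex fields are JSON-decoded
+        and/or assembled from delimited variables via explode_env_vars.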
+ """ + d: Dict[str, Any] = {} + + if settings.__config__.case_sensitive: + env_vars: Mapping[str, Optional[str]] = os.environ + else: + env_vars = {k.lower(): v for k, v in os.environ.items()} + + dotenv_vars = self._read_env_files(settings.__config__.case_sensitive) + if dotenv_vars: + env_vars = {**dotenv_vars, **env_vars} + + for field in settings.__fields__.values(): + env_val: Optional[str] = None + for env_name in field.field_info.extra['env_names']: + env_val = env_vars.get(env_name) + if env_val is not None: + break + + is_complex, allow_parse_failure = self.field_is_complex(field) + if is_complex: + if env_val is None: + # field is complex but no value found so far, try explode_env_vars + env_val_built = self.explode_env_vars(field, env_vars) + if env_val_built: + d[field.alias] = env_val_built + else: + # field is complex and there's a value, decode that as JSON, then add explode_env_vars + try: + env_val = settings.__config__.parse_env_var(field.name, env_val) + except ValueError as e: + if not allow_parse_failure: + raise SettingsError(f'error parsing env var "{env_name}"') from e + + if isinstance(env_val, dict): + d[field.alias] = deep_update(env_val, self.explode_env_vars(field, env_vars)) + else: + d[field.alias] = env_val + elif env_val is not None: + # simplest case, field is not complex, we only need to add the value if it was found + d[field.alias] = env_val + + return d + + def _read_env_files(self, case_sensitive: bool) -> Dict[str, Optional[str]]: + env_files = self.env_file + if env_files is None: + return {} + + if isinstance(env_files, (str, os.PathLike)): + env_files = [env_files] + + dotenv_vars = {} + for env_file in env_files: + env_path = Path(env_file).expanduser() + if env_path.is_file(): + dotenv_vars.update( + read_env_file(env_path, encoding=self.env_file_encoding, case_sensitive=case_sensitive) + ) + + return dotenv_vars + + def field_is_complex(self, field: ModelField) -> Tuple[bool, bool]: + """ + Find out if a field is complex, and if so whether JSON errors should be ignored + """ + if field.is_complex(): + allow_parse_failure = False + elif is_union(get_origin(field.type_)) and field.sub_fields and any(f.is_complex() for f in field.sub_fields): + allow_parse_failure = True + else: + return False, False + + return True, allow_parse_failure + + def explode_env_vars(self, field: ModelField, env_vars: Mapping[str, Optional[str]]) -> Dict[str, Any]: + """ + Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries. + + This is applied to a single field, hence filtering by env_var prefix. 
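+
+        For example, with env_nested_delimiter='__' and a field whose env name
+        is 'db', the variable db__user__name=scott contributes
+        {'user': {'name': 'scott'}} to the result.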
+ """ + prefixes = [f'{env_name}{self.env_nested_delimiter}' for env_name in field.field_info.extra['env_names']] + result: Dict[str, Any] = {} + for env_name, env_val in env_vars.items(): + if not any(env_name.startswith(prefix) for prefix in prefixes): + continue + # we remove the prefix before splitting in case the prefix has characters in common with the delimiter + env_name_without_prefix = env_name[self.env_prefix_len :] + _, *keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter) + env_var = result + for key in keys: + env_var = env_var.setdefault(key, {}) + env_var[last_key] = env_val + + return result + + def __repr__(self) -> str: + return ( + f'EnvSettingsSource(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, ' + f'env_nested_delimiter={self.env_nested_delimiter!r})' + ) + + +class SecretsSettingsSource: + __slots__ = ('secrets_dir',) + + def __init__(self, secrets_dir: Optional[StrPath]): + self.secrets_dir: Optional[StrPath] = secrets_dir + + def __call__(self, settings: BaseSettings) -> Dict[str, Any]: + """ + Build fields from "secrets" files. + """ + secrets: Dict[str, Optional[str]] = {} + + if self.secrets_dir is None: + return secrets + + secrets_path = Path(self.secrets_dir).expanduser() + + if not secrets_path.exists(): + warnings.warn(f'directory "{secrets_path}" does not exist') + return secrets + + if not secrets_path.is_dir(): + raise SettingsError(f'secrets_dir must reference a directory, not a {path_type(secrets_path)}') + + for field in settings.__fields__.values(): + for env_name in field.field_info.extra['env_names']: + path = find_case_path(secrets_path, env_name, settings.__config__.case_sensitive) + if not path: + # path does not exist, we curently don't return a warning for this + continue + + if path.is_file(): + secret_value = path.read_text().strip() + if field.is_complex(): + try: + secret_value = settings.__config__.parse_env_var(field.name, secret_value) + except ValueError as e: + raise SettingsError(f'error parsing env var "{env_name}"') from e + + secrets[field.alias] = secret_value + else: + warnings.warn( + f'attempted to load secret file "{path}" but found a {path_type(path)} instead.', + stacklevel=4, + ) + return secrets + + def __repr__(self) -> str: + return f'SecretsSettingsSource(secrets_dir={self.secrets_dir!r})' + + +def read_env_file( + file_path: StrPath, *, encoding: str = None, case_sensitive: bool = False +) -> Dict[str, Optional[str]]: + try: + from dotenv import dotenv_values + except ImportError as e: + raise ImportError('python-dotenv is not installed, run `pip install pydantic[dotenv]`') from e + + file_vars: Dict[str, Optional[str]] = dotenv_values(file_path, encoding=encoding or 'utf8') + if not case_sensitive: + return {k.lower(): v for k, v in file_vars.items()} + else: + return file_vars + + +def find_case_path(dir_path: Path, file_name: str, case_sensitive: bool) -> Optional[Path]: + """ + Find a file within path's directory matching filename, optionally ignoring case. 
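+
+    E.g. a lookup for 'api_key' with case_sensitive=False matches files named
+    'api_key', 'API_KEY' or 'Api_Key'; with case_sensitive=True only the exact
+    name matches.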
+ """ + for f in dir_path.iterdir(): + if f.name == file_name: + return f + elif not case_sensitive and f.name.lower() == file_name.lower(): + return f + return None diff --git a/libs/win/pydantic/error_wrappers.cp37-win_amd64.pyd b/libs/win/pydantic/error_wrappers.cp37-win_amd64.pyd new file mode 100644 index 00000000..a9af44c7 Binary files /dev/null and b/libs/win/pydantic/error_wrappers.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/error_wrappers.py b/libs/win/pydantic/error_wrappers.py new file mode 100644 index 00000000..5d3204f4 --- /dev/null +++ b/libs/win/pydantic/error_wrappers.py @@ -0,0 +1,162 @@ +import json +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple, Type, Union + +from .json import pydantic_encoder +from .utils import Representation + +if TYPE_CHECKING: + from typing_extensions import TypedDict + + from .config import BaseConfig + from .types import ModelOrDc + from .typing import ReprArgs + + Loc = Tuple[Union[int, str], ...] + + class _ErrorDictRequired(TypedDict): + loc: Loc + msg: str + type: str + + class ErrorDict(_ErrorDictRequired, total=False): + ctx: Dict[str, Any] + + +__all__ = 'ErrorWrapper', 'ValidationError' + + +class ErrorWrapper(Representation): + __slots__ = 'exc', '_loc' + + def __init__(self, exc: Exception, loc: Union[str, 'Loc']) -> None: + self.exc = exc + self._loc = loc + + def loc_tuple(self) -> 'Loc': + if isinstance(self._loc, tuple): + return self._loc + else: + return (self._loc,) + + def __repr_args__(self) -> 'ReprArgs': + return [('exc', self.exc), ('loc', self.loc_tuple())] + + +# ErrorList is something like Union[List[Union[List[ErrorWrapper], ErrorWrapper]], ErrorWrapper] +# but recursive, therefore just use: +ErrorList = Union[Sequence[Any], ErrorWrapper] + + +class ValidationError(Representation, ValueError): + __slots__ = 'raw_errors', 'model', '_error_cache' + + def __init__(self, errors: Sequence[ErrorList], model: 'ModelOrDc') -> None: + self.raw_errors = errors + self.model = model + self._error_cache: Optional[List['ErrorDict']] = None + + def errors(self) -> List['ErrorDict']: + if self._error_cache is None: + try: + config = self.model.__config__ # type: ignore + except AttributeError: + config = self.model.__pydantic_model__.__config__ # type: ignore + self._error_cache = list(flatten_errors(self.raw_errors, config)) + return self._error_cache + + def json(self, *, indent: Union[None, int, str] = 2) -> str: + return json.dumps(self.errors(), indent=indent, default=pydantic_encoder) + + def __str__(self) -> str: + errors = self.errors() + no_errors = len(errors) + return ( + f'{no_errors} validation error{"" if no_errors == 1 else "s"} for {self.model.__name__}\n' + f'{display_errors(errors)}' + ) + + def __repr_args__(self) -> 'ReprArgs': + return [('model', self.model.__name__), ('errors', self.errors())] + + +def display_errors(errors: List['ErrorDict']) -> str: + return '\n'.join(f'{_display_error_loc(e)}\n {e["msg"]} ({_display_error_type_and_ctx(e)})' for e in errors) + + +def _display_error_loc(error: 'ErrorDict') -> str: + return ' -> '.join(str(e) for e in error['loc']) + + +def _display_error_type_and_ctx(error: 'ErrorDict') -> str: + t = 'type=' + error['type'] + ctx = error.get('ctx') + if ctx: + return t + ''.join(f'; {k}={v}' for k, v in ctx.items()) + else: + return t + + +def flatten_errors( + errors: Sequence[Any], config: Type['BaseConfig'], loc: Optional['Loc'] = None +) -> Generator['ErrorDict', None, None]: + for error in errors: + if isinstance(error, 
ErrorWrapper): + + if loc: + error_loc = loc + error.loc_tuple() + else: + error_loc = error.loc_tuple() + + if isinstance(error.exc, ValidationError): + yield from flatten_errors(error.exc.raw_errors, config, error_loc) + else: + yield error_dict(error.exc, config, error_loc) + elif isinstance(error, list): + yield from flatten_errors(error, config, loc=loc) + else: + raise RuntimeError(f'Unknown error object: {error}') + + +def error_dict(exc: Exception, config: Type['BaseConfig'], loc: 'Loc') -> 'ErrorDict': + type_ = get_exc_type(exc.__class__) + msg_template = config.error_msg_templates.get(type_) or getattr(exc, 'msg_template', None) + ctx = exc.__dict__ + if msg_template: + msg = msg_template.format(**ctx) + else: + msg = str(exc) + + d: 'ErrorDict' = {'loc': loc, 'msg': msg, 'type': type_} + + if ctx: + d['ctx'] = ctx + + return d + + +_EXC_TYPE_CACHE: Dict[Type[Exception], str] = {} + + +def get_exc_type(cls: Type[Exception]) -> str: + # slightly more efficient than using lru_cache since we don't need to worry about the cache filling up + try: + return _EXC_TYPE_CACHE[cls] + except KeyError: + r = _get_exc_type(cls) + _EXC_TYPE_CACHE[cls] = r + return r + + +def _get_exc_type(cls: Type[Exception]) -> str: + if issubclass(cls, AssertionError): + return 'assertion_error' + + base_name = 'type_error' if issubclass(cls, TypeError) else 'value_error' + if cls in (TypeError, ValueError): + # just TypeError or ValueError, no extra code + return base_name + + # if it's not a TypeError or ValueError, we just take the lowercase of the exception name + # no chaining or snake case logic, use "code" for more complex error types. + code = getattr(cls, 'code', None) or cls.__name__.replace('Error', '').lower() + return base_name + '.' + code diff --git a/libs/win/pydantic/errors.cp37-win_amd64.pyd b/libs/win/pydantic/errors.cp37-win_amd64.pyd new file mode 100644 index 00000000..50d8e049 Binary files /dev/null and b/libs/win/pydantic/errors.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/errors.py b/libs/win/pydantic/errors.py new file mode 100644 index 00000000..7bdafdd1 --- /dev/null +++ b/libs/win/pydantic/errors.py @@ -0,0 +1,646 @@ +from decimal import Decimal +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Sequence, Set, Tuple, Type, Union + +from .typing import display_as_type + +if TYPE_CHECKING: + from .typing import DictStrAny + +# explicitly state exports to avoid "from .errors import *" also importing Decimal, Path etc. 
+__all__ = ( + 'PydanticTypeError', + 'PydanticValueError', + 'ConfigError', + 'MissingError', + 'ExtraError', + 'NoneIsNotAllowedError', + 'NoneIsAllowedError', + 'WrongConstantError', + 'NotNoneError', + 'BoolError', + 'BytesError', + 'DictError', + 'EmailError', + 'UrlError', + 'UrlSchemeError', + 'UrlSchemePermittedError', + 'UrlUserInfoError', + 'UrlHostError', + 'UrlHostTldError', + 'UrlPortError', + 'UrlExtraError', + 'EnumError', + 'IntEnumError', + 'EnumMemberError', + 'IntegerError', + 'FloatError', + 'PathError', + 'PathNotExistsError', + 'PathNotAFileError', + 'PathNotADirectoryError', + 'PyObjectError', + 'SequenceError', + 'ListError', + 'SetError', + 'FrozenSetError', + 'TupleError', + 'TupleLengthError', + 'ListMinLengthError', + 'ListMaxLengthError', + 'ListUniqueItemsError', + 'SetMinLengthError', + 'SetMaxLengthError', + 'FrozenSetMinLengthError', + 'FrozenSetMaxLengthError', + 'AnyStrMinLengthError', + 'AnyStrMaxLengthError', + 'StrError', + 'StrRegexError', + 'NumberNotGtError', + 'NumberNotGeError', + 'NumberNotLtError', + 'NumberNotLeError', + 'NumberNotMultipleError', + 'DecimalError', + 'DecimalIsNotFiniteError', + 'DecimalMaxDigitsError', + 'DecimalMaxPlacesError', + 'DecimalWholeDigitsError', + 'DateTimeError', + 'DateError', + 'DateNotInThePastError', + 'DateNotInTheFutureError', + 'TimeError', + 'DurationError', + 'HashableError', + 'UUIDError', + 'UUIDVersionError', + 'ArbitraryTypeError', + 'ClassError', + 'SubclassError', + 'JsonError', + 'JsonTypeError', + 'PatternError', + 'DataclassTypeError', + 'CallableError', + 'IPvAnyAddressError', + 'IPvAnyInterfaceError', + 'IPvAnyNetworkError', + 'IPv4AddressError', + 'IPv6AddressError', + 'IPv4NetworkError', + 'IPv6NetworkError', + 'IPv4InterfaceError', + 'IPv6InterfaceError', + 'ColorError', + 'StrictBoolError', + 'NotDigitError', + 'LuhnValidationError', + 'InvalidLengthForBrand', + 'InvalidByteSize', + 'InvalidByteSizeUnit', + 'MissingDiscriminator', + 'InvalidDiscriminator', +) + + +def cls_kwargs(cls: Type['PydanticErrorMixin'], ctx: 'DictStrAny') -> 'PydanticErrorMixin': + """ + For built-in exceptions like ValueError or TypeError, we need to implement + __reduce__ to override the default behaviour (instead of __getstate__/__setstate__) + By default pickle protocol 2 calls `cls.__new__(cls, *args)`. + Since we only use kwargs, we need a little constructor to change that. 
+ Note: the callable can't be a lambda as pickle looks in the namespace to find it + """ + return cls(**ctx) + + +class PydanticErrorMixin: + code: str + msg_template: str + + def __init__(self, **ctx: Any) -> None: + self.__dict__ = ctx + + def __str__(self) -> str: + return self.msg_template.format(**self.__dict__) + + def __reduce__(self) -> Tuple[Callable[..., 'PydanticErrorMixin'], Tuple[Type['PydanticErrorMixin'], 'DictStrAny']]: + return cls_kwargs, (self.__class__, self.__dict__) + + +class PydanticTypeError(PydanticErrorMixin, TypeError): + pass + + +class PydanticValueError(PydanticErrorMixin, ValueError): + pass + + +class ConfigError(RuntimeError): + pass + + +class MissingError(PydanticValueError): + msg_template = 'field required' + + +class ExtraError(PydanticValueError): + msg_template = 'extra fields not permitted' + + +class NoneIsNotAllowedError(PydanticTypeError): + code = 'none.not_allowed' + msg_template = 'none is not an allowed value' + + +class NoneIsAllowedError(PydanticTypeError): + code = 'none.allowed' + msg_template = 'value is not none' + + +class WrongConstantError(PydanticValueError): + code = 'const' + + def __str__(self) -> str: + permitted = ', '.join(repr(v) for v in self.permitted) # type: ignore + return f'unexpected value; permitted: {permitted}' + + +class NotNoneError(PydanticTypeError): + code = 'not_none' + msg_template = 'value is not None' + + +class BoolError(PydanticTypeError): + msg_template = 'value could not be parsed to a boolean' + + +class BytesError(PydanticTypeError): + msg_template = 'byte type expected' + + +class DictError(PydanticTypeError): + msg_template = 'value is not a valid dict' + + +class EmailError(PydanticValueError): + msg_template = 'value is not a valid email address' + + +class UrlError(PydanticValueError): + code = 'url' + + +class UrlSchemeError(UrlError): + code = 'url.scheme' + msg_template = 'invalid or missing URL scheme' + + +class UrlSchemePermittedError(UrlError): + code = 'url.scheme' + msg_template = 'URL scheme not permitted' + + def __init__(self, allowed_schemes: Set[str]): + super().__init__(allowed_schemes=allowed_schemes) + + +class UrlUserInfoError(UrlError): + code = 'url.userinfo' + msg_template = 'userinfo required in URL but missing' + + +class UrlHostError(UrlError): + code = 'url.host' + msg_template = 'URL host invalid' + + +class UrlHostTldError(UrlError): + code = 'url.host' + msg_template = 'URL host invalid, top level domain required' + + +class UrlPortError(UrlError): + code = 'url.port' + msg_template = 'URL port invalid, port cannot exceed 65535' + + +class UrlExtraError(UrlError): + code = 'url.extra' + msg_template = 'URL invalid, extra characters found after valid URL: {extra!r}' + + +class EnumMemberError(PydanticTypeError): + code = 'enum' + + def __str__(self) -> str: + permitted = ', '.join(repr(v.value) for v in self.enum_values) # type: ignore + return f'value is not a valid enumeration member; permitted: {permitted}' + + +class IntegerError(PydanticTypeError): + msg_template = 'value is not a valid integer' + + +class FloatError(PydanticTypeError): + msg_template = 'value is not a valid float' + + +class PathError(PydanticTypeError): + msg_template = 'value is not a valid path' + + +class _PathValueError(PydanticValueError): + def __init__(self, *, path: Path) -> None: + super().__init__(path=str(path)) + + +class PathNotExistsError(_PathValueError): + code = 'path.not_exists' + msg_template = 'file or directory at path "{path}" does not exist' + + +class 
PathNotAFileError(_PathValueError): + code = 'path.not_a_file' + msg_template = 'path "{path}" does not point to a file' + + +class PathNotADirectoryError(_PathValueError): + code = 'path.not_a_directory' + msg_template = 'path "{path}" does not point to a directory' + + +class PyObjectError(PydanticTypeError): + msg_template = 'ensure this value contains valid import path or valid callable: {error_message}' + + +class SequenceError(PydanticTypeError): + msg_template = 'value is not a valid sequence' + + +class IterableError(PydanticTypeError): + msg_template = 'value is not a valid iterable' + + +class ListError(PydanticTypeError): + msg_template = 'value is not a valid list' + + +class SetError(PydanticTypeError): + msg_template = 'value is not a valid set' + + +class FrozenSetError(PydanticTypeError): + msg_template = 'value is not a valid frozenset' + + +class DequeError(PydanticTypeError): + msg_template = 'value is not a valid deque' + + +class TupleError(PydanticTypeError): + msg_template = 'value is not a valid tuple' + + +class TupleLengthError(PydanticValueError): + code = 'tuple.length' + msg_template = 'wrong tuple length {actual_length}, expected {expected_length}' + + def __init__(self, *, actual_length: int, expected_length: int) -> None: + super().__init__(actual_length=actual_length, expected_length=expected_length) + + +class ListMinLengthError(PydanticValueError): + code = 'list.min_items' + msg_template = 'ensure this value has at least {limit_value} items' + + def __init__(self, *, limit_value: int) -> None: + super().__init__(limit_value=limit_value) + + +class ListMaxLengthError(PydanticValueError): + code = 'list.max_items' + msg_template = 'ensure this value has at most {limit_value} items' + + def __init__(self, *, limit_value: int) -> None: + super().__init__(limit_value=limit_value) + + +class ListUniqueItemsError(PydanticValueError): + code = 'list.unique_items' + msg_template = 'the list has duplicated items' + + +class SetMinLengthError(PydanticValueError): + code = 'set.min_items' + msg_template = 'ensure this value has at least {limit_value} items' + + def __init__(self, *, limit_value: int) -> None: + super().__init__(limit_value=limit_value) + + +class SetMaxLengthError(PydanticValueError): + code = 'set.max_items' + msg_template = 'ensure this value has at most {limit_value} items' + + def __init__(self, *, limit_value: int) -> None: + super().__init__(limit_value=limit_value) + + +class FrozenSetMinLengthError(PydanticValueError): + code = 'frozenset.min_items' + msg_template = 'ensure this value has at least {limit_value} items' + + def __init__(self, *, limit_value: int) -> None: + super().__init__(limit_value=limit_value) + + +class FrozenSetMaxLengthError(PydanticValueError): + code = 'frozenset.max_items' + msg_template = 'ensure this value has at most {limit_value} items' + + def __init__(self, *, limit_value: int) -> None: + super().__init__(limit_value=limit_value) + + +class AnyStrMinLengthError(PydanticValueError): + code = 'any_str.min_length' + msg_template = 'ensure this value has at least {limit_value} characters' + + def __init__(self, *, limit_value: int) -> None: + super().__init__(limit_value=limit_value) + + +class AnyStrMaxLengthError(PydanticValueError): + code = 'any_str.max_length' + msg_template = 'ensure this value has at most {limit_value} characters' + + def __init__(self, *, limit_value: int) -> None: + super().__init__(limit_value=limit_value) + + +class StrError(PydanticTypeError): + msg_template = 'str type expected' + + 
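+# A minimal sketch of how the classes in this module are meant to be used (the
+# class and kwarg names below are hypothetical, not part of pydantic): subclass
+# PydanticValueError or PydanticTypeError, set `code` and `msg_template`, and
+# pass the template variables as keyword arguments when raising.
+# PydanticErrorMixin stores the kwargs in __dict__ and formats msg_template
+# with them, and `code` becomes the "value_error.<code>" type string reported
+# by ValidationError.errors():
+#
+#     class FruitError(PydanticValueError):
+#         code = 'fruit'
+#         msg_template = 'value is not a valid fruit (got {wrong_value!r})'
+#
+#     raise FruitError(wrong_value='rock')
+#     # str(exc) == "value is not a valid fruit (got 'rock')"
+
+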
+class StrRegexError(PydanticValueError): + code = 'str.regex' + msg_template = 'string does not match regex "{pattern}"' + + def __init__(self, *, pattern: str) -> None: + super().__init__(pattern=pattern) + + +class _NumberBoundError(PydanticValueError): + def __init__(self, *, limit_value: Union[int, float, Decimal]) -> None: + super().__init__(limit_value=limit_value) + + +class NumberNotGtError(_NumberBoundError): + code = 'number.not_gt' + msg_template = 'ensure this value is greater than {limit_value}' + + +class NumberNotGeError(_NumberBoundError): + code = 'number.not_ge' + msg_template = 'ensure this value is greater than or equal to {limit_value}' + + +class NumberNotLtError(_NumberBoundError): + code = 'number.not_lt' + msg_template = 'ensure this value is less than {limit_value}' + + +class NumberNotLeError(_NumberBoundError): + code = 'number.not_le' + msg_template = 'ensure this value is less than or equal to {limit_value}' + + +class NumberNotFiniteError(PydanticValueError): + code = 'number.not_finite_number' + msg_template = 'ensure this value is a finite number' + + +class NumberNotMultipleError(PydanticValueError): + code = 'number.not_multiple' + msg_template = 'ensure this value is a multiple of {multiple_of}' + + def __init__(self, *, multiple_of: Union[int, float, Decimal]) -> None: + super().__init__(multiple_of=multiple_of) + + +class DecimalError(PydanticTypeError): + msg_template = 'value is not a valid decimal' + + +class DecimalIsNotFiniteError(PydanticValueError): + code = 'decimal.not_finite' + msg_template = 'value is not a valid decimal' + + +class DecimalMaxDigitsError(PydanticValueError): + code = 'decimal.max_digits' + msg_template = 'ensure that there are no more than {max_digits} digits in total' + + def __init__(self, *, max_digits: int) -> None: + super().__init__(max_digits=max_digits) + + +class DecimalMaxPlacesError(PydanticValueError): + code = 'decimal.max_places' + msg_template = 'ensure that there are no more than {decimal_places} decimal places' + + def __init__(self, *, decimal_places: int) -> None: + super().__init__(decimal_places=decimal_places) + + +class DecimalWholeDigitsError(PydanticValueError): + code = 'decimal.whole_digits' + msg_template = 'ensure that there are no more than {whole_digits} digits before the decimal point' + + def __init__(self, *, whole_digits: int) -> None: + super().__init__(whole_digits=whole_digits) + + +class DateTimeError(PydanticValueError): + msg_template = 'invalid datetime format' + + +class DateError(PydanticValueError): + msg_template = 'invalid date format' + + +class DateNotInThePastError(PydanticValueError): + code = 'date.not_in_the_past' + msg_template = 'date is not in the past' + + +class DateNotInTheFutureError(PydanticValueError): + code = 'date.not_in_the_future' + msg_template = 'date is not in the future' + + +class TimeError(PydanticValueError): + msg_template = 'invalid time format' + + +class DurationError(PydanticValueError): + msg_template = 'invalid duration format' + + +class HashableError(PydanticTypeError): + msg_template = 'value is not a valid hashable' + + +class UUIDError(PydanticTypeError): + msg_template = 'value is not a valid uuid' + + +class UUIDVersionError(PydanticValueError): + code = 'uuid.version' + msg_template = 'uuid version {required_version} expected' + + def __init__(self, *, required_version: int) -> None: + super().__init__(required_version=required_version) + + +class ArbitraryTypeError(PydanticTypeError): + code = 'arbitrary_type' + msg_template = 
'instance of {expected_arbitrary_type} expected' + + def __init__(self, *, expected_arbitrary_type: Type[Any]) -> None: + super().__init__(expected_arbitrary_type=display_as_type(expected_arbitrary_type)) + + +class ClassError(PydanticTypeError): + code = 'class' + msg_template = 'a class is expected' + + +class SubclassError(PydanticTypeError): + code = 'subclass' + msg_template = 'subclass of {expected_class} expected' + + def __init__(self, *, expected_class: Type[Any]) -> None: + super().__init__(expected_class=display_as_type(expected_class)) + + +class JsonError(PydanticValueError): + msg_template = 'Invalid JSON' + + +class JsonTypeError(PydanticTypeError): + code = 'json' + msg_template = 'JSON object must be str, bytes or bytearray' + + +class PatternError(PydanticValueError): + code = 'regex_pattern' + msg_template = 'Invalid regular expression' + + +class DataclassTypeError(PydanticTypeError): + code = 'dataclass' + msg_template = 'instance of {class_name}, tuple or dict expected' + + +class CallableError(PydanticTypeError): + msg_template = '{value} is not callable' + + +class EnumError(PydanticTypeError): + code = 'enum_instance' + msg_template = '{value} is not a valid Enum instance' + + +class IntEnumError(PydanticTypeError): + code = 'int_enum_instance' + msg_template = '{value} is not a valid IntEnum instance' + + +class IPvAnyAddressError(PydanticValueError): + msg_template = 'value is not a valid IPv4 or IPv6 address' + + +class IPvAnyInterfaceError(PydanticValueError): + msg_template = 'value is not a valid IPv4 or IPv6 interface' + + +class IPvAnyNetworkError(PydanticValueError): + msg_template = 'value is not a valid IPv4 or IPv6 network' + + +class IPv4AddressError(PydanticValueError): + msg_template = 'value is not a valid IPv4 address' + + +class IPv6AddressError(PydanticValueError): + msg_template = 'value is not a valid IPv6 address' + + +class IPv4NetworkError(PydanticValueError): + msg_template = 'value is not a valid IPv4 network' + + +class IPv6NetworkError(PydanticValueError): + msg_template = 'value is not a valid IPv6 network' + + +class IPv4InterfaceError(PydanticValueError): + msg_template = 'value is not a valid IPv4 interface' + + +class IPv6InterfaceError(PydanticValueError): + msg_template = 'value is not a valid IPv6 interface' + + +class ColorError(PydanticValueError): + msg_template = 'value is not a valid color: {reason}' + + +class StrictBoolError(PydanticValueError): + msg_template = 'value is not a valid boolean' + + +class NotDigitError(PydanticValueError): + code = 'payment_card_number.digits' + msg_template = 'card number is not all digits' + + +class LuhnValidationError(PydanticValueError): + code = 'payment_card_number.luhn_check' + msg_template = 'card number is not luhn valid' + + +class InvalidLengthForBrand(PydanticValueError): + code = 'payment_card_number.invalid_length_for_brand' + msg_template = 'Length for a {brand} card must be {required_length}' + + +class InvalidByteSize(PydanticValueError): + msg_template = 'could not parse value and unit from byte string' + + +class InvalidByteSizeUnit(PydanticValueError): + msg_template = 'could not interpret byte unit: {unit}' + + +class MissingDiscriminator(PydanticValueError): + code = 'discriminated_union.missing_discriminator' + msg_template = 'Discriminator {discriminator_key!r} is missing in value' + + +class InvalidDiscriminator(PydanticValueError): + code = 'discriminated_union.invalid_discriminator' + msg_template = ( + 'No match for discriminator {discriminator_key!r} and value 
{discriminator_value!r} ' + '(allowed values: {allowed_values})' + ) + + def __init__(self, *, discriminator_key: str, discriminator_value: Any, allowed_values: Sequence[Any]) -> None: + super().__init__( + discriminator_key=discriminator_key, + discriminator_value=discriminator_value, + allowed_values=', '.join(map(repr, allowed_values)), + ) diff --git a/libs/win/pydantic/fields.cp37-win_amd64.pyd b/libs/win/pydantic/fields.cp37-win_amd64.pyd new file mode 100644 index 00000000..2d31a5e9 Binary files /dev/null and b/libs/win/pydantic/fields.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/fields.py b/libs/win/pydantic/fields.py new file mode 100644 index 00000000..cecd3d20 --- /dev/null +++ b/libs/win/pydantic/fields.py @@ -0,0 +1,1247 @@ +import copy +import re +from collections import Counter as CollectionCounter, defaultdict, deque +from collections.abc import Callable, Hashable as CollectionsHashable, Iterable as CollectionsIterable +from typing import ( + TYPE_CHECKING, + Any, + Counter, + DefaultDict, + Deque, + Dict, + ForwardRef, + FrozenSet, + Generator, + Iterable, + Iterator, + List, + Mapping, + Optional, + Pattern, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +from typing_extensions import Annotated, Final + +from . import errors as errors_ +from .class_validators import Validator, make_generic_validator, prep_validators +from .error_wrappers import ErrorWrapper +from .errors import ConfigError, InvalidDiscriminator, MissingDiscriminator, NoneIsNotAllowedError +from .types import Json, JsonWrapper +from .typing import ( + NoArgAnyCallable, + convert_generics, + display_as_type, + get_args, + get_origin, + is_finalvar, + is_literal_type, + is_new_type, + is_none_type, + is_typeddict, + is_typeddict_special, + is_union, + new_type_supertype, +) +from .utils import ( + PyObjectStr, + Representation, + ValueItems, + get_discriminator_alias_and_values, + get_unique_discriminator_alias, + lenient_isinstance, + lenient_issubclass, + sequence_like, + smart_deepcopy, +) +from .validators import constant_validator, dict_validator, find_validators, validate_json + +Required: Any = Ellipsis + +T = TypeVar('T') + + +class UndefinedType: + def __repr__(self) -> str: + return 'PydanticUndefined' + + def __copy__(self: T) -> T: + return self + + def __reduce__(self) -> str: + return 'Undefined' + + def __deepcopy__(self: T, _: Any) -> T: + return self + + +Undefined = UndefinedType() + +if TYPE_CHECKING: + from .class_validators import ValidatorsList + from .config import BaseConfig + from .error_wrappers import ErrorList + from .types import ModelOrDc + from .typing import AbstractSetIntStr, MappingIntStrAny, ReprArgs + + ValidateReturn = Tuple[Optional[Any], Optional[ErrorList]] + LocStr = Union[Tuple[Union[int, str], ...], str] + BoolUndefined = Union[bool, UndefinedType] + + +class FieldInfo(Representation): + """ + Captures extra information about a field. 
+    """
+
+    __slots__ = (
+        'default',
+        'default_factory',
+        'alias',
+        'alias_priority',
+        'title',
+        'description',
+        'exclude',
+        'include',
+        'const',
+        'gt',
+        'ge',
+        'lt',
+        'le',
+        'multiple_of',
+        'allow_inf_nan',
+        'max_digits',
+        'decimal_places',
+        'min_items',
+        'max_items',
+        'unique_items',
+        'min_length',
+        'max_length',
+        'allow_mutation',
+        'repr',
+        'regex',
+        'discriminator',
+        'extra',
+    )
+
+    # field constraints with the default value, it's also used in update_from_config below
+    __field_constraints__ = {
+        'min_length': None,
+        'max_length': None,
+        'regex': None,
+        'gt': None,
+        'lt': None,
+        'ge': None,
+        'le': None,
+        'multiple_of': None,
+        'allow_inf_nan': None,
+        'max_digits': None,
+        'decimal_places': None,
+        'min_items': None,
+        'max_items': None,
+        'unique_items': None,
+        'allow_mutation': True,
+    }
+
+    def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
+        self.default = default
+        self.default_factory = kwargs.pop('default_factory', None)
+        self.alias = kwargs.pop('alias', None)
+        self.alias_priority = kwargs.pop('alias_priority', 2 if self.alias is not None else None)
+        self.title = kwargs.pop('title', None)
+        self.description = kwargs.pop('description', None)
+        self.exclude = kwargs.pop('exclude', None)
+        self.include = kwargs.pop('include', None)
+        self.const = kwargs.pop('const', None)
+        self.gt = kwargs.pop('gt', None)
+        self.ge = kwargs.pop('ge', None)
+        self.lt = kwargs.pop('lt', None)
+        self.le = kwargs.pop('le', None)
+        self.multiple_of = kwargs.pop('multiple_of', None)
+        self.allow_inf_nan = kwargs.pop('allow_inf_nan', None)
+        self.max_digits = kwargs.pop('max_digits', None)
+        self.decimal_places = kwargs.pop('decimal_places', None)
+        self.min_items = kwargs.pop('min_items', None)
+        self.max_items = kwargs.pop('max_items', None)
+        self.unique_items = kwargs.pop('unique_items', None)
+        self.min_length = kwargs.pop('min_length', None)
+        self.max_length = kwargs.pop('max_length', None)
+        self.allow_mutation = kwargs.pop('allow_mutation', True)
+        self.regex = kwargs.pop('regex', None)
+        self.discriminator = kwargs.pop('discriminator', None)
+        self.repr = kwargs.pop('repr', True)
+        self.extra = kwargs
+
+    def __repr_args__(self) -> 'ReprArgs':
+
+        field_defaults_to_hide: Dict[str, Any] = {
+            'repr': True,
+            **self.__field_constraints__,
+        }
+
+        attrs = ((s, getattr(self, s)) for s in self.__slots__)
+        return [(a, v) for a, v in attrs if v != field_defaults_to_hide.get(a, None)]
+
+    def get_constraints(self) -> Set[str]:
+        """
+        Gets the constraints set on the field by comparing the constraint value with its default value
+
+        :return: the constraints set on field_info
+        """
+        return {attr for attr, default in self.__field_constraints__.items() if getattr(self, attr) != default}
+
+    def update_from_config(self, from_config: Dict[str, Any]) -> None:
+        """
+        Update this FieldInfo based on a dict from get_field_info, only fields which have not been set are updated.
+        """
+        for attr_name, value in from_config.items():
+            try:
+                current_value = getattr(self, attr_name)
+            except AttributeError:
+                # attr_name is not an attribute of FieldInfo, it should therefore be added to extra
+                # (except if extra already has this value!)
+                self.extra.setdefault(attr_name, value)
+            else:
+                if current_value is self.__field_constraints__.get(attr_name, None):
+                    setattr(self, attr_name, value)
+                elif attr_name == 'exclude':
+                    self.exclude = ValueItems.merge(value, current_value)
+                elif attr_name == 'include':
+                    self.include = ValueItems.merge(value, current_value, intersect=True)
+
+    def _validate(self) -> None:
+        if self.default is not Undefined and self.default_factory is not None:
+            raise ValueError('cannot specify both default and default_factory')
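+
+
+# A small illustration of FieldInfo.get_constraints() above (a hedged sketch,
+# not part of pydantic itself): only constraints that differ from the defaults
+# in __field_constraints__ are reported.
+#
+#     assert FieldInfo(gt=0, max_length=10).get_constraints() == {'gt', 'max_length'}
+#     assert FieldInfo().get_constraints() == set()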
+
+
+def Field(
+    default: Any = Undefined,
+    *,
+    default_factory: Optional[NoArgAnyCallable] = None,
+    alias: str = None,
+    title: str = None,
+    description: str = None,
+    exclude: Union['AbstractSetIntStr', 'MappingIntStrAny', Any] = None,
+    include: Union['AbstractSetIntStr', 'MappingIntStrAny', Any] = None,
+    const: bool = None,
+    gt: float = None,
+    ge: float = None,
+    lt: float = None,
+    le: float = None,
+    multiple_of: float = None,
+    allow_inf_nan: bool = None,
+    max_digits: int = None,
+    decimal_places: int = None,
+    min_items: int = None,
+    max_items: int = None,
+    unique_items: bool = None,
+    min_length: int = None,
+    max_length: int = None,
+    allow_mutation: bool = True,
+    regex: str = None,
+    discriminator: str = None,
+    repr: bool = True,
+    **extra: Any,
+) -> Any:
+    """
+    Used to provide extra information about a field, either for the model schema or complex validation. Some arguments
+    apply only to number fields (``int``, ``float``, ``Decimal``) and some apply only to ``str``.
+
+    :param default: since this is replacing the field’s default, its first argument is used
+      to set the default, use ellipsis (``...``) to indicate the field is required
+    :param default_factory: callable that will be called when a default value is needed for this field
+      If both `default` and `default_factory` are set, an error is raised.
+    :param alias: the public name of the field
+    :param title: can be any string, used in the schema
+    :param description: can be any string, used in the schema
+    :param exclude: exclude this field while dumping.
+      Takes same values as the ``include`` and ``exclude`` arguments on the ``.dict`` method.
+    :param include: include this field while dumping.
+      Takes same values as the ``include`` and ``exclude`` arguments on the ``.dict`` method.
+    :param const: this field is required and *must* take its default value
+    :param gt: only applies to numbers, requires the field to be "greater than". The schema
+      will have an ``exclusiveMinimum`` validation keyword
+    :param ge: only applies to numbers, requires the field to be "greater than or equal to". The
+      schema will have a ``minimum`` validation keyword
+    :param lt: only applies to numbers, requires the field to be "less than". The schema
+      will have an ``exclusiveMaximum`` validation keyword
+    :param le: only applies to numbers, requires the field to be "less than or equal to". The
+      schema will have a ``maximum`` validation keyword
+    :param multiple_of: only applies to numbers, requires the field to be "a multiple of". The
+      schema will have a ``multipleOf`` validation keyword
+    :param allow_inf_nan: only applies to numbers, allows the field to be NaN or infinity (+inf or -inf),
+      which is a valid Python float. Default True, set to False for compatibility with JSON.
+    :param max_digits: only applies to Decimals, requires the field to have a maximum number
+      of digits within the decimal. It does not include a zero before the decimal point or trailing decimal zeroes.
+    :param decimal_places: only applies to Decimals, requires the field to have at most the given number of
+      decimal places. It does not include trailing decimal zeroes.
+    :param min_items: only applies to lists, requires the field to have a minimum number of
+      elements. The schema will have a ``minItems`` validation keyword
+    :param max_items: only applies to lists, requires the field to have a maximum number of
+      elements. The schema will have a ``maxItems`` validation keyword
+    :param unique_items: only applies to lists, requires the field not to have duplicated
+      elements. The schema will have a ``uniqueItems`` validation keyword
+    :param min_length: only applies to strings, requires the field to have a minimum length. The
+      schema will have a ``minLength`` validation keyword
+    :param max_length: only applies to strings, requires the field to have a maximum length. The
+      schema will have a ``maxLength`` validation keyword
+    :param allow_mutation: a boolean which defaults to True. When False, a TypeError is raised if the field is
+      assigned on an instance. The BaseModel Config must set validate_assignment to True
+    :param regex: only applies to strings, requires the field to match a regular expression
+      pattern string. The schema will have a ``pattern`` validation keyword
+    :param discriminator: only useful with a (discriminated a.k.a. tagged) `Union` of sub models with a common field.
+      The `discriminator` is the name of this common field to shorten validation and improve generated schema
+    :param repr: show this field in the representation
+    :param **extra: any additional keyword arguments will be added as is to the schema
+    """
+    field_info = FieldInfo(
+        default,
+        default_factory=default_factory,
+        alias=alias,
+        title=title,
+        description=description,
+        exclude=exclude,
+        include=include,
+        const=const,
+        gt=gt,
+        ge=ge,
+        lt=lt,
+        le=le,
+        multiple_of=multiple_of,
+        allow_inf_nan=allow_inf_nan,
+        max_digits=max_digits,
+        decimal_places=decimal_places,
+        min_items=min_items,
+        max_items=max_items,
+        unique_items=unique_items,
+        min_length=min_length,
+        max_length=max_length,
+        allow_mutation=allow_mutation,
+        regex=regex,
+        discriminator=discriminator,
+        repr=repr,
+        **extra,
+    )
+    field_info._validate()
+    return field_info
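+
+
+# A short usage sketch for Field() as documented above (the model below is
+# hypothetical): constraints map onto JSON schema keywords, and `default` and
+# `default_factory` are mutually exclusive.
+#
+#     from typing import List
+#     from pydantic import BaseModel, Field
+#
+#     class Item(BaseModel):
+#         name: str = Field(..., min_length=1, max_length=50)  # minLength/maxLength in the schema
+#         price: float = Field(0.0, ge=0)                      # "minimum": 0 in the schema
+#         tags: List[str] = Field(default_factory=list)        # fresh list per instance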
+
+
+# used to be an enum but changed to ints for a small performance improvement (less access overhead)
+SHAPE_SINGLETON = 1
+SHAPE_LIST = 2
+SHAPE_SET = 3
+SHAPE_MAPPING = 4
+SHAPE_TUPLE = 5
+SHAPE_TUPLE_ELLIPSIS = 6
+SHAPE_SEQUENCE = 7
+SHAPE_FROZENSET = 8
+SHAPE_ITERABLE = 9
+SHAPE_GENERIC = 10
+SHAPE_DEQUE = 11
+SHAPE_DICT = 12
+SHAPE_DEFAULTDICT = 13
+SHAPE_COUNTER = 14
+SHAPE_NAME_LOOKUP = {
+    SHAPE_LIST: 'List[{}]',
+    SHAPE_SET: 'Set[{}]',
+    SHAPE_TUPLE_ELLIPSIS: 'Tuple[{}, ...]',
+    SHAPE_SEQUENCE: 'Sequence[{}]',
+    SHAPE_FROZENSET: 'FrozenSet[{}]',
+    SHAPE_ITERABLE: 'Iterable[{}]',
+    SHAPE_DEQUE: 'Deque[{}]',
+    SHAPE_DICT: 'Dict[{}]',
+    SHAPE_DEFAULTDICT: 'DefaultDict[{}]',
+    SHAPE_COUNTER: 'Counter[{}]',
+}
+
+MAPPING_LIKE_SHAPES: Set[int] = {SHAPE_DEFAULTDICT, SHAPE_DICT, SHAPE_MAPPING, SHAPE_COUNTER}
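+
+
+# How the SHAPE_* constants above relate to annotations (illustrative, based on
+# the _type_analysis logic further down): the container part of the annotation
+# becomes `shape`, the item type becomes `type_`.
+#
+#     List[int]      -> shape == SHAPE_LIST, type_ == int
+#     Dict[str, int] -> shape == SHAPE_DICT, type_ == int (key_field validates str)
+#     int            -> shape == SHAPE_SINGLETON, type_ == int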
+
+
+class ModelField(Representation):
+    __slots__ = (
+        'type_',
+        'outer_type_',
+        'annotation',
+        'sub_fields',
+        'sub_fields_mapping',
+        'key_field',
+        'validators',
+        'pre_validators',
+        'post_validators',
+        'default',
+        'default_factory',
+        'required',
+        'final',
+        'model_config',
+        'name',
+        'alias',
+        'has_alias',
+        'field_info',
+        'discriminator_key',
+        'discriminator_alias',
+        'validate_always',
+        'allow_none',
+        'shape',
+        'class_validators',
+        'parse_json',
+    )
+
+    def __init__(
+        self,
+        *,
+        name: str,
+        type_: Type[Any],
+        class_validators: Optional[Dict[str, Validator]],
+        model_config: Type['BaseConfig'],
+        default: Any = None,
+        default_factory: Optional[NoArgAnyCallable] = None,
+        required: 'BoolUndefined' = Undefined,
+        final: bool = False,
+        alias: str = None,
+        field_info: Optional[FieldInfo] = None,
+    ) -> None:
+
+        self.name: str = name
+        self.has_alias: bool = alias is not None
+        self.alias: str = alias if alias is not None else name
+        self.annotation = type_
+        self.type_: Any = convert_generics(type_)
+        self.outer_type_: Any = type_
+        self.class_validators = class_validators or {}
+        self.default: Any = default
+        self.default_factory: Optional[NoArgAnyCallable] = default_factory
+        self.required: 'BoolUndefined' = required
+        self.final: bool = final
+        self.model_config = model_config
+        self.field_info: FieldInfo = field_info or FieldInfo(default)
+        self.discriminator_key: Optional[str] = self.field_info.discriminator
+        self.discriminator_alias: Optional[str] = self.discriminator_key
+
+        self.allow_none: bool = False
+        self.validate_always: bool = False
+        self.sub_fields: Optional[List[ModelField]] = None
+        self.sub_fields_mapping: Optional[Dict[str, 'ModelField']] = None  # used for discriminated union
+        self.key_field: Optional[ModelField] = None
+        self.validators: 'ValidatorsList' = []
+        self.pre_validators: Optional['ValidatorsList'] = None
+        self.post_validators: Optional['ValidatorsList'] = None
+        self.parse_json: bool = False
+        self.shape: int = SHAPE_SINGLETON
+        self.model_config.prepare_field(self)
+        self.prepare()
+
+    def get_default(self) -> Any:
+        return smart_deepcopy(self.default) if self.default_factory is None else self.default_factory()
+
+    @staticmethod
+    def _get_field_info(
+        field_name: str, annotation: Any, value: Any, config: Type['BaseConfig']
+    ) -> Tuple[FieldInfo, Any]:
+        """
+        Get a FieldInfo from a root typing.Annotated annotation, value, or config default.
+
+        The FieldInfo may be set in typing.Annotated or the value, but not both. If neither contains
+        a FieldInfo, a new one will be created using the config.
+
+        :param field_name: name of the field for use in error messages
+        :param annotation: a type hint such as `str` or `Annotated[str, Field(..., min_length=5)]`
+        :param value: the field's assigned value
+        :param config: the model's config object
+        :return: the FieldInfo contained in the `annotation`, the value, or a new one from the config.
+        """
+        field_info_from_config = config.get_field_info(field_name)
+
+        field_info = None
+        if get_origin(annotation) is Annotated:
+            field_infos = [arg for arg in get_args(annotation)[1:] if isinstance(arg, FieldInfo)]
+            if len(field_infos) > 1:
+                raise ValueError(f'cannot specify multiple `Annotated` `Field`s for {field_name!r}')
+            field_info = next(iter(field_infos), None)
+            if field_info is not None:
+                field_info = copy.copy(field_info)
+                field_info.update_from_config(field_info_from_config)
+                if field_info.default not in (Undefined, Required):
+                    raise ValueError(f'`Field` default cannot be set in `Annotated` for {field_name!r}')
+                if value is not Undefined and value is not Required:
+                    # check also `Required` because of `validate_arguments` that sets `...` as default value
+                    field_info.default = value
+
+        if isinstance(value, FieldInfo):
+            if field_info is not None:
+                raise ValueError(f'cannot specify `Annotated` and value `Field`s together for {field_name!r}')
+            field_info = value
+            field_info.update_from_config(field_info_from_config)
+        elif field_info is None:
+            field_info = FieldInfo(value, **field_info_from_config)
+        value = None if field_info.default_factory is not None else field_info.default
+        field_info._validate()
+        return field_info, value
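+
+    # A hedged example of what _get_field_info above accepts (the field names
+    # are hypothetical): the FieldInfo may come from an Annotated type hint or
+    # from the assigned value, but not from both at once.
+    #
+    #     class Model(BaseModel):
+    #         a: Annotated[str, Field(min_length=5)]            # ok: FieldInfo in the annotation
+    #         b: str = Field(min_length=5)                      # ok: FieldInfo as the value
+    #         c: Annotated[str, Field(min_length=5)] = Field()  # ValueError: both at once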
+
+    @classmethod
+    def infer(
+        cls,
+        *,
+        name: str,
+        value: Any,
+        annotation: Any,
+        class_validators: Optional[Dict[str, Validator]],
+        config: Type['BaseConfig'],
+    ) -> 'ModelField':
+        from .schema import get_annotation_from_field_info
+
+        field_info, value = cls._get_field_info(name, annotation, value, config)
+        required: 'BoolUndefined' = Undefined
+        if value is Required:
+            required = True
+            value = None
+        elif value is not Undefined:
+            required = False
+        annotation = get_annotation_from_field_info(annotation, field_info, name, config.validate_assignment)
+
+        return cls(
+            name=name,
+            type_=annotation,
+            alias=field_info.alias,
+            class_validators=class_validators,
+            default=value,
+            default_factory=field_info.default_factory,
+            required=required,
+            model_config=config,
+            field_info=field_info,
+        )
+
+    def set_config(self, config: Type['BaseConfig']) -> None:
+        self.model_config = config
+        info_from_config = config.get_field_info(self.name)
+        config.prepare_field(self)
+        new_alias = info_from_config.get('alias')
+        new_alias_priority = info_from_config.get('alias_priority') or 0
+        if new_alias and new_alias_priority >= (self.field_info.alias_priority or 0):
+            self.field_info.alias = new_alias
+            self.field_info.alias_priority = new_alias_priority
+            self.alias = new_alias
+        new_exclude = info_from_config.get('exclude')
+        if new_exclude is not None:
+            self.field_info.exclude = ValueItems.merge(self.field_info.exclude, new_exclude)
+        new_include = info_from_config.get('include')
+        if new_include is not None:
+            self.field_info.include = ValueItems.merge(self.field_info.include, new_include, intersect=True)
+
+    @property
+    def alt_alias(self) -> bool:
+        return self.name != self.alias
+
+    def prepare(self) -> None:
+        """
+        Prepare the field by inspecting self.default, self.type_ etc.
+
+        Note: this method is **not** idempotent (because _type_analysis is not idempotent),
+        e.g. calling it multiple times may modify the field and configure it incorrectly.
+ """ + self._set_default_and_type() + if self.type_.__class__ is ForwardRef or self.type_.__class__ is DeferredType: + # self.type_ is currently a ForwardRef and there's nothing we can do now, + # user will need to call model.update_forward_refs() + return + + self._type_analysis() + if self.required is Undefined: + self.required = True + if self.default is Undefined and self.default_factory is None: + self.default = None + self.populate_validators() + + def _set_default_and_type(self) -> None: + """ + Set the default value, infer the type if needed and check if `None` value is valid. + """ + if self.default_factory is not None: + if self.type_ is Undefined: + raise errors_.ConfigError( + f'you need to set the type of field {self.name!r} when using `default_factory`' + ) + return + + default_value = self.get_default() + + if default_value is not None and self.type_ is Undefined: + self.type_ = default_value.__class__ + self.outer_type_ = self.type_ + self.annotation = self.type_ + + if self.type_ is Undefined: + raise errors_.ConfigError(f'unable to infer type for attribute "{self.name}"') + + if self.required is False and default_value is None: + self.allow_none = True + + def _type_analysis(self) -> None: # noqa: C901 (ignore complexity) + # typing interface is horrible, we have to do some ugly checks + if lenient_issubclass(self.type_, JsonWrapper): + self.type_ = self.type_.inner_type + self.parse_json = True + elif lenient_issubclass(self.type_, Json): + self.type_ = Any + self.parse_json = True + elif isinstance(self.type_, TypeVar): + if self.type_.__bound__: + self.type_ = self.type_.__bound__ + elif self.type_.__constraints__: + self.type_ = Union[self.type_.__constraints__] + else: + self.type_ = Any + elif is_new_type(self.type_): + self.type_ = new_type_supertype(self.type_) + + if self.type_ is Any or self.type_ is object: + if self.required is Undefined: + self.required = False + self.allow_none = True + return + elif self.type_ is Pattern or self.type_ is re.Pattern: + # python 3.7 only, Pattern is a typing object but without sub fields + return + elif is_literal_type(self.type_): + return + elif is_typeddict(self.type_): + return + + if is_finalvar(self.type_): + self.final = True + + if self.type_ is Final: + self.type_ = Any + else: + self.type_ = get_args(self.type_)[0] + + self._type_analysis() + return + + origin = get_origin(self.type_) + + if origin is Annotated or is_typeddict_special(origin): + self.type_ = get_args(self.type_)[0] + self._type_analysis() + return + + if self.discriminator_key is not None and not is_union(origin): + raise TypeError('`discriminator` can only be used with `Union` type with more than one variant') + + # add extra check for `collections.abc.Hashable` for python 3.10+ where origin is not `None` + if origin is None or origin is CollectionsHashable: + # field is not "typing" object eg. Union, Dict, List etc. + # allow None for virtual superclasses of NoneType, e.g. 
Hashable + if isinstance(self.type_, type) and isinstance(None, self.type_): + self.allow_none = True + return + elif origin is Callable: + return + elif is_union(origin): + types_ = [] + for type_ in get_args(self.type_): + if is_none_type(type_) or type_ is Any or type_ is object: + if self.required is Undefined: + self.required = False + self.allow_none = True + if is_none_type(type_): + continue + types_.append(type_) + + if len(types_) == 1: + # Optional[] + self.type_ = types_[0] + # this is the one case where the "outer type" isn't just the original type + self.outer_type_ = self.type_ + # re-run to correctly interpret the new self.type_ + self._type_analysis() + else: + self.sub_fields = [self._create_sub_type(t, f'{self.name}_{display_as_type(t)}') for t in types_] + + if self.discriminator_key is not None: + self.prepare_discriminated_union_sub_fields() + return + elif issubclass(origin, Tuple): # type: ignore + # origin == Tuple without item type + args = get_args(self.type_) + if not args: # plain tuple + self.type_ = Any + self.shape = SHAPE_TUPLE_ELLIPSIS + elif len(args) == 2 and args[1] is Ellipsis: # e.g. Tuple[int, ...] + self.type_ = args[0] + self.shape = SHAPE_TUPLE_ELLIPSIS + self.sub_fields = [self._create_sub_type(args[0], f'{self.name}_0')] + elif args == ((),): # Tuple[()] means empty tuple + self.shape = SHAPE_TUPLE + self.type_ = Any + self.sub_fields = [] + else: + self.shape = SHAPE_TUPLE + self.sub_fields = [self._create_sub_type(t, f'{self.name}_{i}') for i, t in enumerate(args)] + return + elif issubclass(origin, List): + # Create self validators + get_validators = getattr(self.type_, '__get_validators__', None) + if get_validators: + self.class_validators.update( + {f'list_{i}': Validator(validator, pre=True) for i, validator in enumerate(get_validators())} + ) + + self.type_ = get_args(self.type_)[0] + self.shape = SHAPE_LIST + elif issubclass(origin, Set): + # Create self validators + get_validators = getattr(self.type_, '__get_validators__', None) + if get_validators: + self.class_validators.update( + {f'set_{i}': Validator(validator, pre=True) for i, validator in enumerate(get_validators())} + ) + + self.type_ = get_args(self.type_)[0] + self.shape = SHAPE_SET + elif issubclass(origin, FrozenSet): + # Create self validators + get_validators = getattr(self.type_, '__get_validators__', None) + if get_validators: + self.class_validators.update( + {f'frozenset_{i}': Validator(validator, pre=True) for i, validator in enumerate(get_validators())} + ) + + self.type_ = get_args(self.type_)[0] + self.shape = SHAPE_FROZENSET + elif issubclass(origin, Deque): + self.type_ = get_args(self.type_)[0] + self.shape = SHAPE_DEQUE + elif issubclass(origin, Sequence): + self.type_ = get_args(self.type_)[0] + self.shape = SHAPE_SEQUENCE + # priority to most common mapping: dict + elif origin is dict or origin is Dict: + self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True) + self.type_ = get_args(self.type_)[1] + self.shape = SHAPE_DICT + elif issubclass(origin, DefaultDict): + self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True) + self.type_ = get_args(self.type_)[1] + self.shape = SHAPE_DEFAULTDICT + elif issubclass(origin, Counter): + self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True) + self.type_ = int + self.shape = SHAPE_COUNTER + elif issubclass(origin, Mapping): + self.key_field = self._create_sub_type(get_args(self.type_)[0], 
'key_' + self.name, for_keys=True)
+            self.type_ = get_args(self.type_)[1]
+            self.shape = SHAPE_MAPPING
+        # Equality check as almost everything inherits from Iterable, including str
+        # check for Iterable and CollectionsIterable, as it could receive one even when declared with the other
+        elif origin in {Iterable, CollectionsIterable}:
+            self.type_ = get_args(self.type_)[0]
+            self.shape = SHAPE_ITERABLE
+            self.sub_fields = [self._create_sub_type(self.type_, f'{self.name}_type')]
+        elif issubclass(origin, Type):  # type: ignore
+            return
+        elif hasattr(origin, '__get_validators__') or self.model_config.arbitrary_types_allowed:
+            # Is a Pydantic-compatible generic that handles itself
+            # or we have arbitrary_types_allowed = True
+            self.shape = SHAPE_GENERIC
+            self.sub_fields = [self._create_sub_type(t, f'{self.name}_{i}') for i, t in enumerate(get_args(self.type_))]
+            self.type_ = origin
+            return
+        else:
+            raise TypeError(f'Fields of type "{origin}" are not supported.')
+
+        # type_ has been refined eg. as the type of a List and sub_fields needs to be populated
+        self.sub_fields = [self._create_sub_type(self.type_, '_' + self.name)]
+
+    def prepare_discriminated_union_sub_fields(self) -> None:
+        """
+        Prepare the mapping of discriminator value -> ModelField and update `sub_fields`
+        Note that this process can be aborted if a `ForwardRef` is encountered
+        """
+        assert self.discriminator_key is not None
+
+        if self.type_.__class__ is DeferredType:
+            return
+
+        assert self.sub_fields is not None
+        sub_fields_mapping: Dict[str, 'ModelField'] = {}
+        all_aliases: Set[str] = set()
+
+        for sub_field in self.sub_fields:
+            t = sub_field.type_
+            if t.__class__ is ForwardRef:
+                # Stopping everything...will need to call `update_forward_refs`
+                return
+
+            alias, discriminator_values = get_discriminator_alias_and_values(t, self.discriminator_key)
+            all_aliases.add(alias)
+            for discriminator_value in discriminator_values:
+                sub_fields_mapping[discriminator_value] = sub_field
+
+        self.sub_fields_mapping = sub_fields_mapping
+        self.discriminator_alias = get_unique_discriminator_alias(all_aliases, self.discriminator_key)
+
+    def _create_sub_type(self, type_: Type[Any], name: str, *, for_keys: bool = False) -> 'ModelField':
+        if for_keys:
+            class_validators = None
+        else:
+            # validators for sub items should not have `each_item` as we want to check only the first sublevel
+            class_validators = {
+                k: Validator(
+                    func=v.func,
+                    pre=v.pre,
+                    each_item=False,
+                    always=v.always,
+                    check_fields=v.check_fields,
+                    skip_on_failure=v.skip_on_failure,
+                )
+                for k, v in self.class_validators.items()
+                if v.each_item
+            }
+
+        field_info, _ = self._get_field_info(name, type_, None, self.model_config)
+
+        return self.__class__(
+            type_=type_,
+            name=name,
+            class_validators=class_validators,
+            model_config=self.model_config,
+            field_info=field_info,
+        )
+
+    def populate_validators(self) -> None:
+        """
+        Prepare self.pre_validators, self.validators, and self.post_validators based on self.type_'s __get_validators__
+        and class validators. This method should be idempotent, e.g. it should be safe to call multiple times
+        without mis-configuring the field.
+ """ + self.validate_always = getattr(self.type_, 'validate_always', False) or any( + v.always for v in self.class_validators.values() + ) + + class_validators_ = self.class_validators.values() + if not self.sub_fields or self.shape == SHAPE_GENERIC: + get_validators = getattr(self.type_, '__get_validators__', None) + v_funcs = ( + *[v.func for v in class_validators_ if v.each_item and v.pre], + *(get_validators() if get_validators else list(find_validators(self.type_, self.model_config))), + *[v.func for v in class_validators_ if v.each_item and not v.pre], + ) + self.validators = prep_validators(v_funcs) + + self.pre_validators = [] + self.post_validators = [] + + if self.field_info and self.field_info.const: + self.post_validators.append(make_generic_validator(constant_validator)) + + if class_validators_: + self.pre_validators += prep_validators(v.func for v in class_validators_ if not v.each_item and v.pre) + self.post_validators += prep_validators(v.func for v in class_validators_ if not v.each_item and not v.pre) + + if self.parse_json: + self.pre_validators.append(make_generic_validator(validate_json)) + + self.pre_validators = self.pre_validators or None + self.post_validators = self.post_validators or None + + def validate( + self, v: Any, values: Dict[str, Any], *, loc: 'LocStr', cls: Optional['ModelOrDc'] = None + ) -> 'ValidateReturn': + + assert self.type_.__class__ is not DeferredType + + if self.type_.__class__ is ForwardRef: + assert cls is not None + raise ConfigError( + f'field "{self.name}" not yet prepared so type is still a ForwardRef, ' + f'you might need to call {cls.__name__}.update_forward_refs().' + ) + + errors: Optional['ErrorList'] + if self.pre_validators: + v, errors = self._apply_validators(v, values, loc, cls, self.pre_validators) + if errors: + return v, errors + + if v is None: + if is_none_type(self.type_): + # keep validating + pass + elif self.allow_none: + if self.post_validators: + return self._apply_validators(v, values, loc, cls, self.post_validators) + else: + return None, None + else: + return v, ErrorWrapper(NoneIsNotAllowedError(), loc) + + if self.shape == SHAPE_SINGLETON: + v, errors = self._validate_singleton(v, values, loc, cls) + elif self.shape in MAPPING_LIKE_SHAPES: + v, errors = self._validate_mapping_like(v, values, loc, cls) + elif self.shape == SHAPE_TUPLE: + v, errors = self._validate_tuple(v, values, loc, cls) + elif self.shape == SHAPE_ITERABLE: + v, errors = self._validate_iterable(v, values, loc, cls) + elif self.shape == SHAPE_GENERIC: + v, errors = self._apply_validators(v, values, loc, cls, self.validators) + else: + # sequence, list, set, generator, tuple with ellipsis, frozen set + v, errors = self._validate_sequence_like(v, values, loc, cls) + + if not errors and self.post_validators: + v, errors = self._apply_validators(v, values, loc, cls, self.post_validators) + return v, errors + + def _validate_sequence_like( # noqa: C901 (ignore complexity) + self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] + ) -> 'ValidateReturn': + """ + Validate sequence-like containers: lists, tuples, sets and generators + Note that large if-else blocks are necessary to enable Cython + optimization, which is why we disable the complexity check above. 
+        """
+        if not sequence_like(v):
+            e: errors_.PydanticTypeError
+            if self.shape == SHAPE_LIST:
+                e = errors_.ListError()
+            elif self.shape in (SHAPE_TUPLE, SHAPE_TUPLE_ELLIPSIS):
+                e = errors_.TupleError()
+            elif self.shape == SHAPE_SET:
+                e = errors_.SetError()
+            elif self.shape == SHAPE_FROZENSET:
+                e = errors_.FrozenSetError()
+            else:
+                e = errors_.SequenceError()
+            return v, ErrorWrapper(e, loc)
+
+        loc = loc if isinstance(loc, tuple) else (loc,)
+        result = []
+        errors: List[ErrorList] = []
+        for i, v_ in enumerate(v):
+            v_loc = *loc, i
+            r, ee = self._validate_singleton(v_, values, v_loc, cls)
+            if ee:
+                errors.append(ee)
+            else:
+                result.append(r)
+
+        if errors:
+            return v, errors
+
+        converted: Union[List[Any], Set[Any], FrozenSet[Any], Tuple[Any, ...], Iterator[Any], Deque[Any]] = result
+
+        if self.shape == SHAPE_SET:
+            converted = set(result)
+        elif self.shape == SHAPE_FROZENSET:
+            converted = frozenset(result)
+        elif self.shape == SHAPE_TUPLE_ELLIPSIS:
+            converted = tuple(result)
+        elif self.shape == SHAPE_DEQUE:
+            converted = deque(result)
+        elif self.shape == SHAPE_SEQUENCE:
+            if isinstance(v, tuple):
+                converted = tuple(result)
+            elif isinstance(v, set):
+                converted = set(result)
+            elif isinstance(v, Generator):
+                converted = iter(result)
+            elif isinstance(v, deque):
+                converted = deque(result)
+        return converted, None
+
+    def _validate_iterable(
+        self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc']
+    ) -> 'ValidateReturn':
+        """
+        Validate Iterables.
+
+        This intentionally doesn't validate values to allow infinite generators.
+        """
+
+        try:
+            iterable = iter(v)
+        except TypeError:
+            return v, ErrorWrapper(errors_.IterableError(), loc)
+        return iterable, None
+
+    def _validate_tuple(
+        self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc']
+    ) -> 'ValidateReturn':
+        e: Optional[Exception] = None
+        if not sequence_like(v):
+            e = errors_.TupleError()
+        else:
+            actual_length, expected_length = len(v), len(self.sub_fields)  # type: ignore
+            if actual_length != expected_length:
+                e = errors_.TupleLengthError(actual_length=actual_length, expected_length=expected_length)
+
+        if e:
+            return v, ErrorWrapper(e, loc)
+
+        loc = loc if isinstance(loc, tuple) else (loc,)
+        result = []
+        errors: List[ErrorList] = []
+        for i, (v_, field) in enumerate(zip(v, self.sub_fields)):  # type: ignore
+            v_loc = *loc, i
+            r, ee = field.validate(v_, values, loc=v_loc, cls=cls)
+            if ee:
+                errors.append(ee)
+            else:
+                result.append(r)
+
+        if errors:
+            return v, errors
+        else:
+            return tuple(result), None
+
+    def _validate_mapping_like(
+        self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc']
+    ) -> 'ValidateReturn':
+        try:
+            v_iter = dict_validator(v)
+        except TypeError as exc:
+            return v, ErrorWrapper(exc, loc)
+
+        loc = loc if isinstance(loc, tuple) else (loc,)
+        result, errors = {}, []
+        for k, v_ in v_iter.items():
+            v_loc = *loc, '__key__'
+            key_result, key_errors = self.key_field.validate(k, values, loc=v_loc, cls=cls)  # type: ignore
+            if key_errors:
+                errors.append(key_errors)
+                continue
+
+            v_loc = *loc, k
+            value_result, value_errors = self._validate_singleton(v_, values, v_loc, cls)
+            if value_errors:
+                errors.append(value_errors)
+                continue
+
+            result[key_result] = value_result
+        if errors:
+            return v, errors
+        elif self.shape == SHAPE_DICT:
+            return result, None
+        elif self.shape == SHAPE_DEFAULTDICT:
+            return defaultdict(self.type_, result), None
+        elif self.shape == SHAPE_COUNTER:
+            return CollectionCounter(result), None
+        else:
+            return self._get_mapping_value(v, result), None
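+
+    # Illustration of the `loc` handling in the validators above (the model
+    # below is hypothetical): each container level appends to the error
+    # location, so a bad item surfaces with its index or key in
+    # ValidationError.errors().
+    #
+    #     class Model(BaseModel):
+    #         values: List[int]
+    #
+    #     Model(values=[1, 'x'])  # error loc: ('values', 1)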
None + else: + return self._get_mapping_value(v, result), None + + def _get_mapping_value(self, original: T, converted: Dict[Any, Any]) -> Union[T, Dict[Any, Any]]: + """ + When type is `Mapping[KT, KV]` (or another unsupported mapping), we try to avoid + coercing to `dict` unwillingly. + """ + original_cls = original.__class__ + + if original_cls == dict or original_cls == Dict: + return converted + elif original_cls in {defaultdict, DefaultDict}: + return defaultdict(self.type_, converted) + else: + try: + # Counter, OrderedDict, UserDict, ... + return original_cls(converted) # type: ignore + except TypeError: + raise RuntimeError(f'Could not convert dictionary to {original_cls.__name__!r}') from None + + def _validate_singleton( + self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] + ) -> 'ValidateReturn': + if self.sub_fields: + if self.discriminator_key is not None: + return self._validate_discriminated_union(v, values, loc, cls) + + errors = [] + + if self.model_config.smart_union and is_union(get_origin(self.type_)): + # 1st pass: check if the value is an exact instance of one of the Union types + # (e.g. to avoid coercing a bool into an int) + for field in self.sub_fields: + if v.__class__ is field.outer_type_: + return v, None + + # 2nd pass: check if the value is an instance of any subclass of the Union types + for field in self.sub_fields: + # This whole logic will be improved later on to support more complex `isinstance` checks + # It will probably be done once a strict mode is added and be something like: + # ``` + # value, error = field.validate(v, values, strict=True) + # if error is None: + # return value, None + # ``` + try: + if isinstance(v, field.outer_type_): + return v, None + except TypeError: + # compound type + if lenient_isinstance(v, get_origin(field.outer_type_)): + value, error = field.validate(v, values, loc=loc, cls=cls) + if not error: + return value, None + + # 1st pass by default or 3rd pass with `smart_union` enabled: + # check if the value can be coerced into one of the Union types + for field in self.sub_fields: + value, error = field.validate(v, values, loc=loc, cls=cls) + if error: + errors.append(error) + else: + return value, None + return v, errors + else: + return self._apply_validators(v, values, loc, cls, self.validators) + + def _validate_discriminated_union( + self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] + ) -> 'ValidateReturn': + assert self.discriminator_key is not None + assert self.discriminator_alias is not None + + try: + discriminator_value = v[self.discriminator_alias] + except KeyError: + return v, ErrorWrapper(MissingDiscriminator(discriminator_key=self.discriminator_key), loc) + except TypeError: + try: + # BaseModel or dataclass + discriminator_value = getattr(v, self.discriminator_key) + except (AttributeError, TypeError): + return v, ErrorWrapper(MissingDiscriminator(discriminator_key=self.discriminator_key), loc) + + try: + sub_field = self.sub_fields_mapping[discriminator_value] # type: ignore[index] + except TypeError: + assert cls is not None + raise ConfigError( + f'field "{self.name}" not yet prepared so type is still a ForwardRef, ' + f'you might need to call {cls.__name__}.update_forward_refs().' 
+            )
+        except KeyError:
+            assert self.sub_fields_mapping is not None
+            return v, ErrorWrapper(
+                InvalidDiscriminator(
+                    discriminator_key=self.discriminator_key,
+                    discriminator_value=discriminator_value,
+                    allowed_values=list(self.sub_fields_mapping),
+                ),
+                loc,
+            )
+        else:
+            if not isinstance(loc, tuple):
+                loc = (loc,)
+            return sub_field.validate(v, values, loc=(*loc, display_as_type(sub_field.type_)), cls=cls)
+
+    def _apply_validators(
+        self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'], validators: 'ValidatorsList'
+    ) -> 'ValidateReturn':
+        for validator in validators:
+            try:
+                v = validator(cls, v, values, self, self.model_config)
+            except (ValueError, TypeError, AssertionError) as exc:
+                return v, ErrorWrapper(exc, loc)
+        return v, None
+
+    def is_complex(self) -> bool:
+        """
+        Whether the field is "complex", e.g. env variables should be parsed as JSON.
+        """
+        from .main import BaseModel
+
+        return (
+            self.shape != SHAPE_SINGLETON
+            or hasattr(self.type_, '__pydantic_model__')
+            or lenient_issubclass(self.type_, (BaseModel, list, set, frozenset, dict))
+        )
+
+    def _type_display(self) -> PyObjectStr:
+        t = display_as_type(self.type_)
+
+        if self.shape in MAPPING_LIKE_SHAPES:
+            t = f'Mapping[{display_as_type(self.key_field.type_)}, {t}]'  # type: ignore
+        elif self.shape == SHAPE_TUPLE:
+            t = 'Tuple[{}]'.format(', '.join(display_as_type(f.type_) for f in self.sub_fields))  # type: ignore
+        elif self.shape == SHAPE_GENERIC:
+            assert self.sub_fields
+            t = '{}[{}]'.format(
+                display_as_type(self.type_), ', '.join(display_as_type(f.type_) for f in self.sub_fields)
+            )
+        elif self.shape != SHAPE_SINGLETON:
+            t = SHAPE_NAME_LOOKUP[self.shape].format(t)
+
+        if self.allow_none and (self.shape != SHAPE_SINGLETON or not self.sub_fields):
+            t = f'Optional[{t}]'
+        return PyObjectStr(t)
+
+    def __repr_args__(self) -> 'ReprArgs':
+        args = [('name', self.name), ('type', self._type_display()), ('required', self.required)]
+
+        if not self.required:
+            if self.default_factory is not None:
+                args.append(('default_factory', f'<function {self.default_factory.__name__}>'))
+            else:
+                args.append(('default', self.default))
+
+        if self.alt_alias:
+            args.append(('alias', self.alias))
+        return args
+
+
+class ModelPrivateAttr(Representation):
+    __slots__ = ('default', 'default_factory')
+
+    def __init__(self, default: Any = Undefined, *, default_factory: Optional[NoArgAnyCallable] = None) -> None:
+        self.default = default
+        self.default_factory = default_factory
+
+    def get_default(self) -> Any:
+        return smart_deepcopy(self.default) if self.default_factory is None else self.default_factory()
+
+    def __eq__(self, other: Any) -> bool:
+        return isinstance(other, self.__class__) and (self.default, self.default_factory) == (
+            other.default,
+            other.default_factory,
+        )
+
+
+def PrivateAttr(
+    default: Any = Undefined,
+    *,
+    default_factory: Optional[NoArgAnyCallable] = None,
+) -> Any:
+    """
+    Indicates that an attribute is only used internally and never mixed with regular fields.
+
+    Types or values of private attrs are not checked by pydantic and it's up to you to keep them relevant.
+
+    Private attrs are stored in the model's __slots__.
+
+    :param default: the attribute's default value
+    :param default_factory: callable that will be called when a default value is needed for this attribute.
+        If both `default` and `default_factory` are set, an error is raised.
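+
+    Illustrative usage (an assumed example, not part of the original source;
+    the model and attribute names are hypothetical):
+
+    >>> from pydantic import BaseModel, PrivateAttr
+    >>> class M(BaseModel):
+    ...     _secret: str = PrivateAttr(default='xyz')
+    >>> M()._secret
+    'xyz'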
+ """ + if default is not Undefined and default_factory is not None: + raise ValueError('cannot specify both default and default_factory') + + return ModelPrivateAttr( + default, + default_factory=default_factory, + ) + + +class DeferredType: + """ + Used to postpone field preparation, while creating recursive generic models. + """ + + +def is_finalvar_with_default_val(type_: Type[Any], val: Any) -> bool: + return is_finalvar(type_) and val is not Undefined and not isinstance(val, FieldInfo) diff --git a/libs/win/pydantic/generics.py b/libs/win/pydantic/generics.py new file mode 100644 index 00000000..a3f52bfe --- /dev/null +++ b/libs/win/pydantic/generics.py @@ -0,0 +1,364 @@ +import sys +import typing +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + Generic, + Iterator, + List, + Mapping, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +from typing_extensions import Annotated + +from .class_validators import gather_all_validators +from .fields import DeferredType +from .main import BaseModel, create_model +from .types import JsonWrapper +from .typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base +from .utils import LimitedDict, all_identical, lenient_issubclass + +GenericModelT = TypeVar('GenericModelT', bound='GenericModel') +TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type + +Parametrization = Mapping[TypeVarType, Type[Any]] + +_generic_types_cache: LimitedDict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel]] = LimitedDict() +# _assigned_parameters is a Mapping from parametrized version of generic models to assigned types of parametrizations +# as captured during construction of the class (not instances). +# E.g., for generic model `Model[A, B]`, when parametrized model `Model[int, str]` is created, +# `Model[int, str]`: {A: int, B: str}` will be stored in `_assigned_parameters`. +# (This information is only otherwise available after creation from the class name string). +_assigned_parameters: LimitedDict[Type[Any], Parametrization] = LimitedDict() + + +class GenericModel(BaseModel): + __slots__ = () + __concrete__: ClassVar[bool] = False + + if TYPE_CHECKING: + # Putting this in a TYPE_CHECKING block allows us to replace `if Generic not in cls.__bases__` with + # `not hasattr(cls, "__parameters__")`. This means we don't need to force non-concrete subclasses of + # `GenericModel` to also inherit from `Generic`, which would require changes to the use of `create_model` below. + __parameters__: ClassVar[Tuple[TypeVarType, ...]] + + # Setting the return type as Type[Any] instead of Type[BaseModel] prevents PyCharm warnings + def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[Type[Any], ...]]) -> Type[Any]: + """Instantiates a new class from a generic class `cls` and type variables `params`. + + :param params: Tuple of types the class . Given a generic class + `Model` with 2 type variables and a concrete model `Model[str, int]`, + the value `(str, int)` would be passed to `params`. + :return: New model class inheriting from `cls` with instantiated + types described by `params`. If no parameters are given, `cls` is + returned as is. 
+ + """ + + def _cache_key(_params: Any) -> Tuple[Type[GenericModelT], Any, Tuple[Any, ...]]: + return cls, _params, get_args(_params) + + cached = _generic_types_cache.get(_cache_key(params)) + if cached is not None: + return cached + if cls.__concrete__ and Generic not in cls.__bases__: + raise TypeError('Cannot parameterize a concrete instantiation of a generic model') + if not isinstance(params, tuple): + params = (params,) + if cls is GenericModel and any(isinstance(param, TypeVar) for param in params): + raise TypeError('Type parameters should be placed on typing.Generic, not GenericModel') + if not hasattr(cls, '__parameters__'): + raise TypeError(f'Type {cls.__name__} must inherit from typing.Generic before being parameterized') + + check_parameters_count(cls, params) + # Build map from generic typevars to passed params + typevars_map: Dict[TypeVarType, Type[Any]] = dict(zip(cls.__parameters__, params)) + if all_identical(typevars_map.keys(), typevars_map.values()) and typevars_map: + return cls # if arguments are equal to parameters it's the same object + + # Create new model with original model as parent inserting fields with DeferredType. + model_name = cls.__concrete_name__(params) + validators = gather_all_validators(cls) + + type_hints = get_all_type_hints(cls).items() + instance_type_hints = {k: v for k, v in type_hints if get_origin(v) is not ClassVar} + + fields = {k: (DeferredType(), cls.__fields__[k].field_info) for k in instance_type_hints if k in cls.__fields__} + + model_module, called_globally = get_caller_frame_info() + created_model = cast( + Type[GenericModel], # casting ensures mypy is aware of the __concrete__ and __parameters__ attributes + create_model( + model_name, + __module__=model_module or cls.__module__, + __base__=(cls,) + tuple(cls.__parameterized_bases__(typevars_map)), + __config__=None, + __validators__=validators, + __cls_kwargs__=None, + **fields, + ), + ) + + _assigned_parameters[created_model] = typevars_map + + if called_globally: # create global reference and therefore allow pickling + object_by_reference = None + reference_name = model_name + reference_module_globals = sys.modules[created_model.__module__].__dict__ + while object_by_reference is not created_model: + object_by_reference = reference_module_globals.setdefault(reference_name, created_model) + reference_name += '_' + + created_model.Config = cls.Config + + # Find any typevars that are still present in the model. + # If none are left, the model is fully "concrete", otherwise the new + # class is a generic class as well taking the found typevars as + # parameters. + new_params = tuple( + {param: None for param in iter_contained_typevars(typevars_map.values())} + ) # use dict as ordered set + created_model.__concrete__ = not new_params + if new_params: + created_model.__parameters__ = new_params + + # Save created model in cache so we don't end up creating duplicate + # models that should be identical. + _generic_types_cache[_cache_key(params)] = created_model + if len(params) == 1: + _generic_types_cache[_cache_key(params[0])] = created_model + + # Recursively walk class type hints and replace generic typevars + # with concrete types that were passed. + _prepare_model_fields(created_model, fields, instance_type_hints, typevars_map) + + return created_model + + @classmethod + def __concrete_name__(cls: Type[Any], params: Tuple[Type[Any], ...]) -> str: + """Compute class name for child classes. + + :param params: Tuple of types the class . 
Given a generic class + `Model` with 2 type variables and a concrete model `Model[str, int]`, + the value `(str, int)` would be passed to `params`. + :return: String representing a the new class where `params` are + passed to `cls` as type variables. + + This method can be overridden to achieve a custom naming scheme for GenericModels. + """ + param_names = [display_as_type(param) for param in params] + params_component = ', '.join(param_names) + return f'{cls.__name__}[{params_component}]' + + @classmethod + def __parameterized_bases__(cls, typevars_map: Parametrization) -> Iterator[Type[Any]]: + """ + Returns unbound bases of cls parameterised to given type variables + + :param typevars_map: Dictionary of type applications for binding subclasses. + Given a generic class `Model` with 2 type variables [S, T] + and a concrete model `Model[str, int]`, + the value `{S: str, T: int}` would be passed to `typevars_map`. + :return: an iterator of generic sub classes, parameterised by `typevars_map` + and other assigned parameters of `cls` + + e.g.: + ``` + class A(GenericModel, Generic[T]): + ... + + class B(A[V], Generic[V]): + ... + + assert A[int] in B.__parameterized_bases__({V: int}) + ``` + """ + + def build_base_model( + base_model: Type[GenericModel], mapped_types: Parametrization + ) -> Iterator[Type[GenericModel]]: + base_parameters = tuple(mapped_types[param] for param in base_model.__parameters__) + parameterized_base = base_model.__class_getitem__(base_parameters) + if parameterized_base is base_model or parameterized_base is cls: + # Avoid duplication in MRO + return + yield parameterized_base + + for base_model in cls.__bases__: + if not issubclass(base_model, GenericModel): + # not a class that can be meaningfully parameterized + continue + elif not getattr(base_model, '__parameters__', None): + # base_model is "GenericModel" (and has no __parameters__) + # or + # base_model is already concrete, and will be included transitively via cls. + continue + elif cls in _assigned_parameters: + if base_model in _assigned_parameters: + # cls is partially parameterised but not from base_model + # e.g. cls = B[S], base_model = A[S] + # B[S][int] should subclass A[int], (and will be transitively via B[int]) + # but it's not viable to consistently subclass types with arbitrary construction + # So don't attempt to include A[S][int] + continue + else: # base_model not in _assigned_parameters: + # cls is partially parameterized, base_model is original generic + # e.g. cls = B[str, T], base_model = B[S, T] + # Need to determine the mapping for the base_model parameters + mapped_types: Parametrization = { + key: typevars_map.get(value, value) for key, value in _assigned_parameters[cls].items() + } + yield from build_base_model(base_model, mapped_types) + else: + # cls is base generic, so base_class has a distinct base + # can construct the Parameterised base model using typevars_map directly + yield from build_base_model(base_model, typevars_map) + + +def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any: + """Return type with all occurrences of `type_map` keys recursively replaced with their values. + + :param type_: Any type, class or generic alias + :param type_map: Mapping from `TypeVar` instance to concrete types. + :return: New type representing the basic structure of `type_` with all + `typevar_map` keys recursively replaced. 
+ + >>> replace_types(Tuple[str, Union[List[str], float]], {str: int}) + Tuple[int, Union[List[int], float]] + + """ + if not type_map: + return type_ + + type_args = get_args(type_) + origin_type = get_origin(type_) + + if origin_type is Annotated: + annotated_type, *annotations = type_args + return Annotated[replace_types(annotated_type, type_map), tuple(annotations)] + + # Having type args is a good indicator that this is a typing module + # class instantiation or a generic alias of some sort. + if type_args: + resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args) + if all_identical(type_args, resolved_type_args): + # If all arguments are the same, there is no need to modify the + # type or create a new object at all + return type_ + if ( + origin_type is not None + and isinstance(type_, typing_base) + and not isinstance(origin_type, typing_base) + and getattr(type_, '_name', None) is not None + ): + # In python < 3.9 generic aliases don't exist so any of these like `list`, + # `type` or `collections.abc.Callable` need to be translated. + # See: https://www.python.org/dev/peps/pep-0585 + origin_type = getattr(typing, type_._name) + assert origin_type is not None + return origin_type[resolved_type_args] + + # We handle pydantic generic models separately as they don't have the same + # semantics as "typing" classes or generic aliases + if not origin_type and lenient_issubclass(type_, GenericModel) and not type_.__concrete__: + type_args = type_.__parameters__ + resolved_type_args = tuple(replace_types(t, type_map) for t in type_args) + if all_identical(type_args, resolved_type_args): + return type_ + return type_[resolved_type_args] + + # Handle special case for typehints that can have lists as arguments. + # `typing.Callable[[int, str], int]` is an example for this. + if isinstance(type_, (List, list)): + resolved_list = list(replace_types(element, type_map) for element in type_) + if all_identical(type_, resolved_list): + return type_ + return resolved_list + + # For JsonWrapperValue, need to handle its inner type to allow correct parsing + # of generic Json arguments like Json[T] + if not origin_type and lenient_issubclass(type_, JsonWrapper): + type_.inner_type = replace_types(type_.inner_type, type_map) + return type_ + + # If all else fails, we try to resolve the type directly and otherwise just + # return the input with no modifications. 
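+    # By this point `type_` is typically a bare TypeVar, a plain class or
+    # another hashable annotation, so a direct dict lookup (with the original
+    # type as the fallback) is sufficient.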
+ return type_map.get(type_, type_) + + +def check_parameters_count(cls: Type[GenericModel], parameters: Tuple[Any, ...]) -> None: + actual = len(parameters) + expected = len(cls.__parameters__) + if actual != expected: + description = 'many' if actual > expected else 'few' + raise TypeError(f'Too {description} parameters for {cls.__name__}; actual {actual}, expected {expected}') + + +DictValues: Type[Any] = {}.values().__class__ + + +def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]: + """Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found.""" + if isinstance(v, TypeVar): + yield v + elif hasattr(v, '__parameters__') and not get_origin(v) and lenient_issubclass(v, GenericModel): + yield from v.__parameters__ + elif isinstance(v, (DictValues, list)): + for var in v: + yield from iter_contained_typevars(var) + else: + args = get_args(v) + for arg in args: + yield from iter_contained_typevars(arg) + + +def get_caller_frame_info() -> Tuple[Optional[str], bool]: + """ + Used inside a function to check whether it was called globally + + Will only work against non-compiled code, therefore used only in pydantic.generics + + :returns Tuple[module_name, called_globally] + """ + try: + previous_caller_frame = sys._getframe(2) + except ValueError as e: + raise RuntimeError('This function must be used inside another function') from e + except AttributeError: # sys module does not have _getframe function, so there's nothing we can do about it + return None, False + frame_globals = previous_caller_frame.f_globals + return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals + + +def _prepare_model_fields( + created_model: Type[GenericModel], + fields: Mapping[str, Any], + instance_type_hints: Mapping[str, type], + typevars_map: Mapping[Any, type], +) -> None: + """ + Replace DeferredType fields with concrete type hints and prepare them. 
+ """ + + for key, field in created_model.__fields__.items(): + if key not in fields: + assert field.type_.__class__ is not DeferredType + # https://github.com/nedbat/coveragepy/issues/198 + continue # pragma: no cover + + assert field.type_.__class__ is DeferredType, field.type_.__class__ + + field_type_hint = instance_type_hints[key] + concrete_type = replace_types(field_type_hint, typevars_map) + field.type_ = concrete_type + field.outer_type_ = concrete_type + field.prepare() + created_model.__annotations__[key] = concrete_type diff --git a/libs/win/pydantic/json.cp37-win_amd64.pyd b/libs/win/pydantic/json.cp37-win_amd64.pyd new file mode 100644 index 00000000..b01a561a Binary files /dev/null and b/libs/win/pydantic/json.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/json.py b/libs/win/pydantic/json.py new file mode 100644 index 00000000..b358b850 --- /dev/null +++ b/libs/win/pydantic/json.py @@ -0,0 +1,112 @@ +import datetime +from collections import deque +from decimal import Decimal +from enum import Enum +from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network +from pathlib import Path +from re import Pattern +from types import GeneratorType +from typing import Any, Callable, Dict, Type, Union +from uuid import UUID + +from .color import Color +from .networks import NameEmail +from .types import SecretBytes, SecretStr + +__all__ = 'pydantic_encoder', 'custom_pydantic_encoder', 'timedelta_isoformat' + + +def isoformat(o: Union[datetime.date, datetime.time]) -> str: + return o.isoformat() + + +def decimal_encoder(dec_value: Decimal) -> Union[int, float]: + """ + Encodes a Decimal as int of there's no exponent, otherwise float + + This is useful when we use ConstrainedDecimal to represent Numeric(x,0) + where a integer (but not int typed) is used. Encoding this as a float + results in failed round-tripping between encode and parse. + Our Id type is a prime example of this. 
+ + >>> decimal_encoder(Decimal("1.0")) + 1.0 + + >>> decimal_encoder(Decimal("1")) + 1 + """ + if dec_value.as_tuple().exponent >= 0: + return int(dec_value) + else: + return float(dec_value) + + +ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { + bytes: lambda o: o.decode(), + Color: str, + datetime.date: isoformat, + datetime.datetime: isoformat, + datetime.time: isoformat, + datetime.timedelta: lambda td: td.total_seconds(), + Decimal: decimal_encoder, + Enum: lambda o: o.value, + frozenset: list, + deque: list, + GeneratorType: list, + IPv4Address: str, + IPv4Interface: str, + IPv4Network: str, + IPv6Address: str, + IPv6Interface: str, + IPv6Network: str, + NameEmail: str, + Path: str, + Pattern: lambda o: o.pattern, + SecretBytes: str, + SecretStr: str, + set: list, + UUID: str, +} + + +def pydantic_encoder(obj: Any) -> Any: + from dataclasses import asdict, is_dataclass + + from .main import BaseModel + + if isinstance(obj, BaseModel): + return obj.dict() + elif is_dataclass(obj): + return asdict(obj) + + # Check the class type and its superclasses for a matching encoder + for base in obj.__class__.__mro__[:-1]: + try: + encoder = ENCODERS_BY_TYPE[base] + except KeyError: + continue + return encoder(obj) + else: # We have exited the for loop without finding a suitable encoder + raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable") + + +def custom_pydantic_encoder(type_encoders: Dict[Any, Callable[[Type[Any]], Any]], obj: Any) -> Any: + # Check the class type and its superclasses for a matching encoder + for base in obj.__class__.__mro__[:-1]: + try: + encoder = type_encoders[base] + except KeyError: + continue + + return encoder(obj) + else: # We have exited the for loop without finding a suitable encoder + return pydantic_encoder(obj) + + +def timedelta_isoformat(td: datetime.timedelta) -> str: + """ + ISO 8601 encoding for Python timedelta object. 
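+
+    Illustrative usage (an assumed example, not part of the original source):
+
+    >>> import datetime
+    >>> timedelta_isoformat(datetime.timedelta(hours=1, minutes=30))
+    'P0DT1H30M0.000000S'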
+ """ + minutes, seconds = divmod(td.seconds, 60) + hours, minutes = divmod(minutes, 60) + return f'{"-" if td.days < 0 else ""}P{abs(td.days)}DT{hours:d}H{minutes:d}M{seconds:d}.{td.microseconds:06d}S' diff --git a/libs/win/pydantic/main.cp37-win_amd64.pyd b/libs/win/pydantic/main.cp37-win_amd64.pyd new file mode 100644 index 00000000..a100804c Binary files /dev/null and b/libs/win/pydantic/main.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/main.py b/libs/win/pydantic/main.py new file mode 100644 index 00000000..69f3b751 --- /dev/null +++ b/libs/win/pydantic/main.py @@ -0,0 +1,1109 @@ +import warnings +from abc import ABCMeta +from copy import deepcopy +from enum import Enum +from functools import partial +from pathlib import Path +from types import FunctionType, prepare_class, resolve_bases +from typing import ( + TYPE_CHECKING, + AbstractSet, + Any, + Callable, + ClassVar, + Dict, + List, + Mapping, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, + no_type_check, + overload, +) + +from typing_extensions import dataclass_transform + +from .class_validators import ValidatorGroup, extract_root_validators, extract_validators, inherit_validators +from .config import BaseConfig, Extra, inherit_config, prepare_config +from .error_wrappers import ErrorWrapper, ValidationError +from .errors import ConfigError, DictError, ExtraError, MissingError +from .fields import ( + MAPPING_LIKE_SHAPES, + Field, + FieldInfo, + ModelField, + ModelPrivateAttr, + PrivateAttr, + Undefined, + is_finalvar_with_default_val, +) +from .json import custom_pydantic_encoder, pydantic_encoder +from .parse import Protocol, load_file, load_str_bytes +from .schema import default_ref_template, model_schema +from .types import PyObject, StrBytes +from .typing import ( + AnyCallable, + get_args, + get_origin, + is_classvar, + is_namedtuple, + is_union, + resolve_annotations, + update_model_forward_refs, +) +from .utils import ( + DUNDER_ATTRIBUTES, + ROOT_KEY, + ClassAttribute, + GetterDict, + Representation, + ValueItems, + generate_model_signature, + is_valid_field, + is_valid_private_name, + lenient_issubclass, + sequence_like, + smart_deepcopy, + unique_list, + validate_field_name, +) + +if TYPE_CHECKING: + from inspect import Signature + + from .class_validators import ValidatorListDict + from .types import ModelOrDc + from .typing import ( + AbstractSetIntStr, + AnyClassMethod, + CallableGenerator, + DictAny, + DictStrAny, + MappingIntStrAny, + ReprArgs, + SetStr, + TupleGenerator, + ) + + Model = TypeVar('Model', bound='BaseModel') + +__all__ = 'BaseModel', 'create_model', 'validate_model' + +_T = TypeVar('_T') + + +def validate_custom_root_type(fields: Dict[str, ModelField]) -> None: + if len(fields) > 1: + raise ValueError(f'{ROOT_KEY} cannot be mixed with other fields') + + +def generate_hash_function(frozen: bool) -> Optional[Callable[[Any], int]]: + def hash_function(self_: Any) -> int: + return hash(self_.__class__) + hash(tuple(self_.__dict__.values())) + + return hash_function if frozen else None + + +# If a field is of type `Callable`, its default value should be a function and cannot to ignored. +ANNOTATED_FIELD_UNTOUCHED_TYPES: Tuple[Any, ...] = (property, type, classmethod, staticmethod) +# When creating a `BaseModel` instance, we bypass all the methods, properties... added to the model +UNTOUCHED_TYPES: Tuple[Any, ...] 
= (FunctionType,) + ANNOTATED_FIELD_UNTOUCHED_TYPES +# Note `ModelMetaclass` refers to `BaseModel`, but is also used to *create* `BaseModel`, so we need to add this extra +# (somewhat hacky) boolean to keep track of whether we've created the `BaseModel` class yet, and therefore whether it's +# safe to refer to it. If it *hasn't* been created, we assume that the `__new__` call we're in the middle of is for +# the `BaseModel` class, since that's defined immediately after the metaclass. +_is_base_model_class_defined = False + + +@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo)) +class ModelMetaclass(ABCMeta): + @no_type_check # noqa C901 + def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901 + fields: Dict[str, ModelField] = {} + config = BaseConfig + validators: 'ValidatorListDict' = {} + + pre_root_validators, post_root_validators = [], [] + private_attributes: Dict[str, ModelPrivateAttr] = {} + base_private_attributes: Dict[str, ModelPrivateAttr] = {} + slots: SetStr = namespace.get('__slots__', ()) + slots = {slots} if isinstance(slots, str) else set(slots) + class_vars: SetStr = set() + hash_func: Optional[Callable[[Any], int]] = None + + for base in reversed(bases): + if _is_base_model_class_defined and issubclass(base, BaseModel) and base != BaseModel: + fields.update(smart_deepcopy(base.__fields__)) + config = inherit_config(base.__config__, config) + validators = inherit_validators(base.__validators__, validators) + pre_root_validators += base.__pre_root_validators__ + post_root_validators += base.__post_root_validators__ + base_private_attributes.update(base.__private_attributes__) + class_vars.update(base.__class_vars__) + hash_func = base.__hash__ + + resolve_forward_refs = kwargs.pop('__resolve_forward_refs__', True) + allowed_config_kwargs: SetStr = { + key + for key in dir(config) + if not (key.startswith('__') and key.endswith('__')) # skip dunder methods and attributes + } + config_kwargs = {key: kwargs.pop(key) for key in kwargs.keys() & allowed_config_kwargs} + config_from_namespace = namespace.get('Config') + if config_kwargs and config_from_namespace: + raise TypeError('Specifying config in two places is ambiguous, use either Config attribute or class kwargs') + config = inherit_config(config_from_namespace, config, **config_kwargs) + + validators = inherit_validators(extract_validators(namespace), validators) + vg = ValidatorGroup(validators) + + for f in fields.values(): + f.set_config(config) + extra_validators = vg.get_validators(f.name) + if extra_validators: + f.class_validators.update(extra_validators) + # re-run prepare to add extra validators + f.populate_validators() + + prepare_config(config, name) + + untouched_types = ANNOTATED_FIELD_UNTOUCHED_TYPES + + def is_untouched(v: Any) -> bool: + return isinstance(v, untouched_types) or v.__class__.__name__ == 'cython_function_or_method' + + if (namespace.get('__module__'), namespace.get('__qualname__')) != ('pydantic.main', 'BaseModel'): + annotations = resolve_annotations(namespace.get('__annotations__', {}), namespace.get('__module__', None)) + # annotation only fields need to come first in fields + for ann_name, ann_type in annotations.items(): + if is_classvar(ann_type): + class_vars.add(ann_name) + elif is_finalvar_with_default_val(ann_type, namespace.get(ann_name, Undefined)): + class_vars.add(ann_name) + elif is_valid_field(ann_name): + validate_field_name(bases, ann_name) + value = namespace.get(ann_name, Undefined) + allowed_types = get_args(ann_type) if 
is_union(get_origin(ann_type)) else (ann_type,) + if ( + is_untouched(value) + and ann_type != PyObject + and not any( + lenient_issubclass(get_origin(allowed_type), Type) for allowed_type in allowed_types + ) + ): + continue + fields[ann_name] = ModelField.infer( + name=ann_name, + value=value, + annotation=ann_type, + class_validators=vg.get_validators(ann_name), + config=config, + ) + elif ann_name not in namespace and config.underscore_attrs_are_private: + private_attributes[ann_name] = PrivateAttr() + + untouched_types = UNTOUCHED_TYPES + config.keep_untouched + for var_name, value in namespace.items(): + can_be_changed = var_name not in class_vars and not is_untouched(value) + if isinstance(value, ModelPrivateAttr): + if not is_valid_private_name(var_name): + raise NameError( + f'Private attributes "{var_name}" must not be a valid field name; ' + f'Use sunder or dunder names, e. g. "_{var_name}" or "__{var_name}__"' + ) + private_attributes[var_name] = value + elif config.underscore_attrs_are_private and is_valid_private_name(var_name) and can_be_changed: + private_attributes[var_name] = PrivateAttr(default=value) + elif is_valid_field(var_name) and var_name not in annotations and can_be_changed: + validate_field_name(bases, var_name) + inferred = ModelField.infer( + name=var_name, + value=value, + annotation=annotations.get(var_name, Undefined), + class_validators=vg.get_validators(var_name), + config=config, + ) + if var_name in fields: + if lenient_issubclass(inferred.type_, fields[var_name].type_): + inferred.type_ = fields[var_name].type_ + else: + raise TypeError( + f'The type of {name}.{var_name} differs from the new default value; ' + f'if you wish to change the type of this field, please use a type annotation' + ) + fields[var_name] = inferred + + _custom_root_type = ROOT_KEY in fields + if _custom_root_type: + validate_custom_root_type(fields) + vg.check_for_unused() + if config.json_encoders: + json_encoder = partial(custom_pydantic_encoder, config.json_encoders) + else: + json_encoder = pydantic_encoder + pre_rv_new, post_rv_new = extract_root_validators(namespace) + + if hash_func is None: + hash_func = generate_hash_function(config.frozen) + + exclude_from_namespace = fields | private_attributes.keys() | {'__slots__'} + new_namespace = { + '__config__': config, + '__fields__': fields, + '__exclude_fields__': { + name: field.field_info.exclude for name, field in fields.items() if field.field_info.exclude is not None + } + or None, + '__include_fields__': { + name: field.field_info.include for name, field in fields.items() if field.field_info.include is not None + } + or None, + '__validators__': vg.validators, + '__pre_root_validators__': unique_list( + pre_root_validators + pre_rv_new, + name_factory=lambda v: v.__name__, + ), + '__post_root_validators__': unique_list( + post_root_validators + post_rv_new, + name_factory=lambda skip_on_failure_and_v: skip_on_failure_and_v[1].__name__, + ), + '__schema_cache__': {}, + '__json_encoder__': staticmethod(json_encoder), + '__custom_root_type__': _custom_root_type, + '__private_attributes__': {**base_private_attributes, **private_attributes}, + '__slots__': slots | private_attributes.keys(), + '__hash__': hash_func, + '__class_vars__': class_vars, + **{n: v for n, v in namespace.items() if n not in exclude_from_namespace}, + } + + cls = super().__new__(mcs, name, bases, new_namespace, **kwargs) + # set __signature__ attr only for model class, but not for its instances + cls.__signature__ = ClassAttribute('__signature__', 
generate_model_signature(cls.__init__, fields, config)) + if resolve_forward_refs: + cls.__try_update_forward_refs__() + + # preserve `__set_name__` protocol defined in https://peps.python.org/pep-0487 + # for attributes not in `new_namespace` (e.g. private attributes) + for name, obj in namespace.items(): + if name not in new_namespace: + set_name = getattr(obj, '__set_name__', None) + if callable(set_name): + set_name(cls, name) + + return cls + + def __instancecheck__(self, instance: Any) -> bool: + """ + Avoid calling ABC _abc_subclasscheck unless we're pretty sure. + + See #3829 and python/cpython#92810 + """ + return hasattr(instance, '__fields__') and super().__instancecheck__(instance) + + +object_setattr = object.__setattr__ + + +class BaseModel(Representation, metaclass=ModelMetaclass): + if TYPE_CHECKING: + # populated by the metaclass, defined here to help IDEs only + __fields__: ClassVar[Dict[str, ModelField]] = {} + __include_fields__: ClassVar[Optional[Mapping[str, Any]]] = None + __exclude_fields__: ClassVar[Optional[Mapping[str, Any]]] = None + __validators__: ClassVar[Dict[str, AnyCallable]] = {} + __pre_root_validators__: ClassVar[List[AnyCallable]] + __post_root_validators__: ClassVar[List[Tuple[bool, AnyCallable]]] + __config__: ClassVar[Type[BaseConfig]] = BaseConfig + __json_encoder__: ClassVar[Callable[[Any], Any]] = lambda x: x + __schema_cache__: ClassVar['DictAny'] = {} + __custom_root_type__: ClassVar[bool] = False + __signature__: ClassVar['Signature'] + __private_attributes__: ClassVar[Dict[str, ModelPrivateAttr]] + __class_vars__: ClassVar[SetStr] + __fields_set__: ClassVar[SetStr] = set() + + Config = BaseConfig + __slots__ = ('__dict__', '__fields_set__') + __doc__ = '' # Null out the Representation docstring + + def __init__(__pydantic_self__, **data: Any) -> None: + """ + Create a new model by parsing and validating input data from keyword arguments. + + Raises ValidationError if the input data cannot be parsed to form a valid model. 
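+
+        Illustrative usage (an assumed example, not part of the original
+        source; `User` is a hypothetical model):
+
+        >>> from pydantic import BaseModel
+        >>> class User(BaseModel):
+        ...     id: int
+        ...     name: str = 'Jane Doe'
+        >>> User(id='42').id
+        42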
+ """ + # Uses something other than `self` the first arg to allow "self" as a settable attribute + values, fields_set, validation_error = validate_model(__pydantic_self__.__class__, data) + if validation_error: + raise validation_error + try: + object_setattr(__pydantic_self__, '__dict__', values) + except TypeError as e: + raise TypeError( + 'Model values must be a dict; you may not have returned a dictionary from a root validator' + ) from e + object_setattr(__pydantic_self__, '__fields_set__', fields_set) + __pydantic_self__._init_private_attributes() + + @no_type_check + def __setattr__(self, name, value): # noqa: C901 (ignore complexity) + if name in self.__private_attributes__ or name in DUNDER_ATTRIBUTES: + return object_setattr(self, name, value) + + if self.__config__.extra is not Extra.allow and name not in self.__fields__: + raise ValueError(f'"{self.__class__.__name__}" object has no field "{name}"') + elif not self.__config__.allow_mutation or self.__config__.frozen: + raise TypeError(f'"{self.__class__.__name__}" is immutable and does not support item assignment') + elif name in self.__fields__ and self.__fields__[name].final: + raise TypeError( + f'"{self.__class__.__name__}" object "{name}" field is final and does not support reassignment' + ) + elif self.__config__.validate_assignment: + new_values = {**self.__dict__, name: value} + + for validator in self.__pre_root_validators__: + try: + new_values = validator(self.__class__, new_values) + except (ValueError, TypeError, AssertionError) as exc: + raise ValidationError([ErrorWrapper(exc, loc=ROOT_KEY)], self.__class__) + + known_field = self.__fields__.get(name, None) + if known_field: + # We want to + # - make sure validators are called without the current value for this field inside `values` + # - keep other values (e.g. submodels) untouched (using `BaseModel.dict()` will change them into dicts) + # - keep the order of the fields + if not known_field.field_info.allow_mutation: + raise TypeError(f'"{known_field.name}" has allow_mutation set to False and cannot be assigned') + dict_without_original_value = {k: v for k, v in self.__dict__.items() if k != name} + value, error_ = known_field.validate(value, dict_without_original_value, loc=name, cls=self.__class__) + if error_: + raise ValidationError([error_], self.__class__) + else: + new_values[name] = value + + errors = [] + for skip_on_failure, validator in self.__post_root_validators__: + if skip_on_failure and errors: + continue + try: + new_values = validator(self.__class__, new_values) + except (ValueError, TypeError, AssertionError) as exc: + errors.append(ErrorWrapper(exc, loc=ROOT_KEY)) + if errors: + raise ValidationError(errors, self.__class__) + + # update the whole __dict__ as other values than just `value` + # may be changed (e.g. 
with `root_validator`) + object_setattr(self, '__dict__', new_values) + else: + self.__dict__[name] = value + + self.__fields_set__.add(name) + + def __getstate__(self) -> 'DictAny': + private_attrs = ((k, getattr(self, k, Undefined)) for k in self.__private_attributes__) + return { + '__dict__': self.__dict__, + '__fields_set__': self.__fields_set__, + '__private_attribute_values__': {k: v for k, v in private_attrs if v is not Undefined}, + } + + def __setstate__(self, state: 'DictAny') -> None: + object_setattr(self, '__dict__', state['__dict__']) + object_setattr(self, '__fields_set__', state['__fields_set__']) + for name, value in state.get('__private_attribute_values__', {}).items(): + object_setattr(self, name, value) + + def _init_private_attributes(self) -> None: + for name, private_attr in self.__private_attributes__.items(): + default = private_attr.get_default() + if default is not Undefined: + object_setattr(self, name, default) + + def dict( + self, + *, + include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> 'DictStrAny': + """ + Generate a dictionary representation of the model, optionally specifying which fields to include or exclude. + + """ + if skip_defaults is not None: + warnings.warn( + f'{self.__class__.__name__}.dict(): "skip_defaults" is deprecated and replaced by "exclude_unset"', + DeprecationWarning, + ) + exclude_unset = skip_defaults + + return dict( + self._iter( + to_dict=True, + by_alias=by_alias, + include=include, + exclude=exclude, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + ) + + def json( + self, + *, + include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + encoder: Optional[Callable[[Any], Any]] = None, + models_as_dict: bool = True, + **dumps_kwargs: Any, + ) -> str: + """ + Generate a JSON representation of the model, `include` and `exclude` arguments as per `dict()`. + + `encoder` is an optional function to supply as `default` to json.dumps(), other arguments as per `json.dumps()`. + """ + if skip_defaults is not None: + warnings.warn( + f'{self.__class__.__name__}.json(): "skip_defaults" is deprecated and replaced by "exclude_unset"', + DeprecationWarning, + ) + exclude_unset = skip_defaults + encoder = cast(Callable[[Any], Any], encoder or self.__json_encoder__) + + # We don't directly call `self.dict()`, which does exactly this with `to_dict=True` + # because we want to be able to keep raw `BaseModel` instances and not as `dict`. + # This allows users to write custom JSON encoders for given `BaseModel` classes. 
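+        # (For example, an assumed `Config.json_encoders = {SubModel: str}`
+        # would then receive the raw `SubModel` instance rather than a dict
+        # when `models_as_dict=False` is passed.)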
+ data = dict( + self._iter( + to_dict=models_as_dict, + by_alias=by_alias, + include=include, + exclude=exclude, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + ) + if self.__custom_root_type__: + data = data[ROOT_KEY] + return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs) + + @classmethod + def _enforce_dict_if_root(cls, obj: Any) -> Any: + if cls.__custom_root_type__ and ( + not (isinstance(obj, dict) and obj.keys() == {ROOT_KEY}) + or cls.__fields__[ROOT_KEY].shape in MAPPING_LIKE_SHAPES + ): + return {ROOT_KEY: obj} + else: + return obj + + @classmethod + def parse_obj(cls: Type['Model'], obj: Any) -> 'Model': + obj = cls._enforce_dict_if_root(obj) + if not isinstance(obj, dict): + try: + obj = dict(obj) + except (TypeError, ValueError) as e: + exc = TypeError(f'{cls.__name__} expected dict not {obj.__class__.__name__}') + raise ValidationError([ErrorWrapper(exc, loc=ROOT_KEY)], cls) from e + return cls(**obj) + + @classmethod + def parse_raw( + cls: Type['Model'], + b: StrBytes, + *, + content_type: str = None, + encoding: str = 'utf8', + proto: Protocol = None, + allow_pickle: bool = False, + ) -> 'Model': + try: + obj = load_str_bytes( + b, + proto=proto, + content_type=content_type, + encoding=encoding, + allow_pickle=allow_pickle, + json_loads=cls.__config__.json_loads, + ) + except (ValueError, TypeError, UnicodeDecodeError) as e: + raise ValidationError([ErrorWrapper(e, loc=ROOT_KEY)], cls) + return cls.parse_obj(obj) + + @classmethod + def parse_file( + cls: Type['Model'], + path: Union[str, Path], + *, + content_type: str = None, + encoding: str = 'utf8', + proto: Protocol = None, + allow_pickle: bool = False, + ) -> 'Model': + obj = load_file( + path, + proto=proto, + content_type=content_type, + encoding=encoding, + allow_pickle=allow_pickle, + json_loads=cls.__config__.json_loads, + ) + return cls.parse_obj(obj) + + @classmethod + def from_orm(cls: Type['Model'], obj: Any) -> 'Model': + if not cls.__config__.orm_mode: + raise ConfigError('You must have the config attribute orm_mode=True to use from_orm') + obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj) + m = cls.__new__(cls) + values, fields_set, validation_error = validate_model(cls, obj) + if validation_error: + raise validation_error + object_setattr(m, '__dict__', values) + object_setattr(m, '__fields_set__', fields_set) + m._init_private_attributes() + return m + + @classmethod + def construct(cls: Type['Model'], _fields_set: Optional['SetStr'] = None, **values: Any) -> 'Model': + """ + Creates a new model setting __dict__ and __fields_set__ from trusted or pre-validated data. + Default values are respected, but no other validation is performed. 
+ Behaves as if `Config.extra = 'allow'` was set since it adds all passed values + """ + m = cls.__new__(cls) + fields_values: Dict[str, Any] = {} + for name, field in cls.__fields__.items(): + if field.alt_alias and field.alias in values: + fields_values[name] = values[field.alias] + elif name in values: + fields_values[name] = values[name] + elif not field.required: + fields_values[name] = field.get_default() + fields_values.update(values) + object_setattr(m, '__dict__', fields_values) + if _fields_set is None: + _fields_set = set(values.keys()) + object_setattr(m, '__fields_set__', _fields_set) + m._init_private_attributes() + return m + + def _copy_and_set_values(self: 'Model', values: 'DictStrAny', fields_set: 'SetStr', *, deep: bool) -> 'Model': + if deep: + # chances of having empty dict here are quite low for using smart_deepcopy + values = deepcopy(values) + + cls = self.__class__ + m = cls.__new__(cls) + object_setattr(m, '__dict__', values) + object_setattr(m, '__fields_set__', fields_set) + for name in self.__private_attributes__: + value = getattr(self, name, Undefined) + if value is not Undefined: + if deep: + value = deepcopy(value) + object_setattr(m, name, value) + + return m + + def copy( + self: 'Model', + *, + include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + update: Optional['DictStrAny'] = None, + deep: bool = False, + ) -> 'Model': + """ + Duplicate a model, optionally choose which fields to include, exclude and change. + + :param include: fields to include in new model + :param exclude: fields to exclude from new model, as with values this takes precedence over include + :param update: values to change/add in the new model. Note: the data is not validated before creating + the new model: you should trust this data + :param deep: set to `True` to make a deep copy of the model + :return: new model instance + """ + + values = dict( + self._iter(to_dict=False, by_alias=False, include=include, exclude=exclude, exclude_unset=False), + **(update or {}), + ) + + # new `__fields_set__` can have unset optional fields with a set value in `update` kwarg + if update: + fields_set = self.__fields_set__ | update.keys() + else: + fields_set = set(self.__fields_set__) + + return self._copy_and_set_values(values, fields_set, deep=deep) + + @classmethod + def schema(cls, by_alias: bool = True, ref_template: str = default_ref_template) -> 'DictStrAny': + cached = cls.__schema_cache__.get((by_alias, ref_template)) + if cached is not None: + return cached + s = model_schema(cls, by_alias=by_alias, ref_template=ref_template) + cls.__schema_cache__[(by_alias, ref_template)] = s + return s + + @classmethod + def schema_json( + cls, *, by_alias: bool = True, ref_template: str = default_ref_template, **dumps_kwargs: Any + ) -> str: + from .json import pydantic_encoder + + return cls.__config__.json_dumps( + cls.schema(by_alias=by_alias, ref_template=ref_template), default=pydantic_encoder, **dumps_kwargs + ) + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.validate + + @classmethod + def validate(cls: Type['Model'], value: Any) -> 'Model': + if isinstance(value, cls): + copy_on_model_validation = cls.__config__.copy_on_model_validation + # whether to deep or shallow copy the model on validation, None means do not copy + deep_copy: Optional[bool] = None + if copy_on_model_validation not in {'deep', 'shallow', 'none'}: + # Warn about deprecated behavior + 
warnings.warn( + "`copy_on_model_validation` should be a string: 'deep', 'shallow' or 'none'", DeprecationWarning + ) + if copy_on_model_validation: + deep_copy = False + + if copy_on_model_validation == 'shallow': + # shallow copy + deep_copy = False + elif copy_on_model_validation == 'deep': + # deep copy + deep_copy = True + + if deep_copy is None: + return value + else: + return value._copy_and_set_values(value.__dict__, value.__fields_set__, deep=deep_copy) + + value = cls._enforce_dict_if_root(value) + + if isinstance(value, dict): + return cls(**value) + elif cls.__config__.orm_mode: + return cls.from_orm(value) + else: + try: + value_as_dict = dict(value) + except (TypeError, ValueError) as e: + raise DictError() from e + return cls(**value_as_dict) + + @classmethod + def _decompose_class(cls: Type['Model'], obj: Any) -> GetterDict: + if isinstance(obj, GetterDict): + return obj + return cls.__config__.getter_dict(obj) + + @classmethod + @no_type_check + def _get_value( + cls, + v: Any, + to_dict: bool, + by_alias: bool, + include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']], + exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']], + exclude_unset: bool, + exclude_defaults: bool, + exclude_none: bool, + ) -> Any: + + if isinstance(v, BaseModel): + if to_dict: + v_dict = v.dict( + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + include=include, + exclude=exclude, + exclude_none=exclude_none, + ) + if ROOT_KEY in v_dict: + return v_dict[ROOT_KEY] + return v_dict + else: + return v.copy(include=include, exclude=exclude) + + value_exclude = ValueItems(v, exclude) if exclude else None + value_include = ValueItems(v, include) if include else None + + if isinstance(v, dict): + return { + k_: cls._get_value( + v_, + to_dict=to_dict, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + include=value_include and value_include.for_element(k_), + exclude=value_exclude and value_exclude.for_element(k_), + exclude_none=exclude_none, + ) + for k_, v_ in v.items() + if (not value_exclude or not value_exclude.is_excluded(k_)) + and (not value_include or value_include.is_included(k_)) + } + + elif sequence_like(v): + seq_args = ( + cls._get_value( + v_, + to_dict=to_dict, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + include=value_include and value_include.for_element(i), + exclude=value_exclude and value_exclude.for_element(i), + exclude_none=exclude_none, + ) + for i, v_ in enumerate(v) + if (not value_exclude or not value_exclude.is_excluded(i)) + and (not value_include or value_include.is_included(i)) + ) + + return v.__class__(*seq_args) if is_namedtuple(v.__class__) else v.__class__(seq_args) + + elif isinstance(v, Enum) and getattr(cls.Config, 'use_enum_values', False): + return v.value + + else: + return v + + @classmethod + def __try_update_forward_refs__(cls, **localns: Any) -> None: + """ + Same as update_forward_refs but will not raise exception + when forward references are not defined. + """ + update_model_forward_refs(cls, cls.__fields__.values(), cls.__config__.json_encoders, localns, (NameError,)) + + @classmethod + def update_forward_refs(cls, **localns: Any) -> None: + """ + Try to update ForwardRefs on fields based on this Model, globalns and localns. 
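+
+        Illustrative usage (an assumed example, not part of the original
+        source; `Node` is a hypothetical model):
+
+        >>> from typing import Optional
+        >>> from pydantic import BaseModel
+        >>> class Node(BaseModel):
+        ...     child: Optional['Node'] = None
+        >>> Node.update_forward_refs(Node=Node)
+        >>> Node(child={}).child
+        Node(child=None)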
+ """ + update_model_forward_refs(cls, cls.__fields__.values(), cls.__config__.json_encoders, localns) + + def __iter__(self) -> 'TupleGenerator': + """ + so `dict(model)` works + """ + yield from self.__dict__.items() + + def _iter( + self, + to_dict: bool = False, + by_alias: bool = False, + include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> 'TupleGenerator': + + # Merge field set excludes with explicit exclude parameter with explicit overriding field set options. + # The extra "is not None" guards are not logically necessary but optimizes performance for the simple case. + if exclude is not None or self.__exclude_fields__ is not None: + exclude = ValueItems.merge(self.__exclude_fields__, exclude) + + if include is not None or self.__include_fields__ is not None: + include = ValueItems.merge(self.__include_fields__, include, intersect=True) + + allowed_keys = self._calculate_keys( + include=include, exclude=exclude, exclude_unset=exclude_unset # type: ignore + ) + if allowed_keys is None and not (to_dict or by_alias or exclude_unset or exclude_defaults or exclude_none): + # huge boost for plain _iter() + yield from self.__dict__.items() + return + + value_exclude = ValueItems(self, exclude) if exclude is not None else None + value_include = ValueItems(self, include) if include is not None else None + + for field_key, v in self.__dict__.items(): + if (allowed_keys is not None and field_key not in allowed_keys) or (exclude_none and v is None): + continue + + if exclude_defaults: + model_field = self.__fields__.get(field_key) + if not getattr(model_field, 'required', True) and getattr(model_field, 'default', _missing) == v: + continue + + if by_alias and field_key in self.__fields__: + dict_key = self.__fields__[field_key].alias + else: + dict_key = field_key + + if to_dict or value_include or value_exclude: + v = self._get_value( + v, + to_dict=to_dict, + by_alias=by_alias, + include=value_include and value_include.for_element(field_key), + exclude=value_exclude and value_exclude.for_element(field_key), + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + yield dict_key, v + + def _calculate_keys( + self, + include: Optional['MappingIntStrAny'], + exclude: Optional['MappingIntStrAny'], + exclude_unset: bool, + update: Optional['DictStrAny'] = None, + ) -> Optional[AbstractSet[str]]: + if include is None and exclude is None and exclude_unset is False: + return None + + keys: AbstractSet[str] + if exclude_unset: + keys = self.__fields_set__.copy() + else: + keys = self.__dict__.keys() + + if include is not None: + keys &= include.keys() + + if update: + keys -= update.keys() + + if exclude: + keys -= {k for k, v in exclude.items() if ValueItems.is_true(v)} + + return keys + + def __eq__(self, other: Any) -> bool: + if isinstance(other, BaseModel): + return self.dict() == other.dict() + else: + return self.dict() == other + + def __repr_args__(self) -> 'ReprArgs': + return [ + (k, v) + for k, v in self.__dict__.items() + if k not in DUNDER_ATTRIBUTES and (k not in self.__fields__ or self.__fields__[k].field_info.repr) + ] + + +_is_base_model_class_defined = True + + +@overload +def create_model( + __model_name: str, + *, + __config__: Optional[Type[BaseConfig]] = None, + __base__: None = None, + __module__: str = __name__, + __validators__: 
Dict[str, 'AnyClassMethod'] = None,
+    __cls_kwargs__: Dict[str, Any] = None,
+    **field_definitions: Any,
+) -> Type['BaseModel']:
+    ...
+
+
+@overload
+def create_model(
+    __model_name: str,
+    *,
+    __config__: Optional[Type[BaseConfig]] = None,
+    __base__: Union[Type['Model'], Tuple[Type['Model'], ...]],
+    __module__: str = __name__,
+    __validators__: Dict[str, 'AnyClassMethod'] = None,
+    __cls_kwargs__: Dict[str, Any] = None,
+    **field_definitions: Any,
+) -> Type['Model']:
+    ...
+
+
+def create_model(
+    __model_name: str,
+    *,
+    __config__: Optional[Type[BaseConfig]] = None,
+    __base__: Union[None, Type['Model'], Tuple[Type['Model'], ...]] = None,
+    __module__: str = __name__,
+    __validators__: Dict[str, 'AnyClassMethod'] = None,
+    __cls_kwargs__: Dict[str, Any] = None,
+    __slots__: Optional[Tuple[str, ...]] = None,
+    **field_definitions: Any,
+) -> Type['Model']:
+    """
+    Dynamically create a model.
+    :param __model_name: name of the created model
+    :param __config__: config class to use for the new model
+    :param __base__: base class for the new model to inherit from
+    :param __module__: module of the created model
+    :param __validators__: a dict of method names and @validator class methods
+    :param __cls_kwargs__: a dict for class creation
+    :param __slots__: deprecated, `__slots__` should not be passed to `create_model`
+    :param field_definitions: fields of the model (or extra fields if a base is supplied)
+        in the format `<name>=(<type>, <default value>)` or `<name>=<default value>`, e.g.
+        `foobar=(str, ...)` or `foobar=123`, or, for complex use-cases, in the format
+        `<name>=<Field>` or `<name>=(<type>, <FieldInfo>)`, e.g.
+        `foo=Field(datetime, default_factory=datetime.utcnow, alias='bar')` or
+        `foo=(str, FieldInfo(title='Foo'))`
+    """
+    if __slots__ is not None:
+        # __slots__ will be ignored from here on
+        warnings.warn('__slots__ should not be passed to create_model', RuntimeWarning)
+
+    if __base__ is not None:
+        if __config__ is not None:
+            raise ConfigError('to avoid confusion __config__ and __base__ cannot be used together')
+        if not isinstance(__base__, tuple):
+            __base__ = (__base__,)
+    else:
+        __base__ = (cast(Type['Model'], BaseModel),)
+
+    __cls_kwargs__ = __cls_kwargs__ or {}
+
+    fields = {}
+    annotations = {}
+
+    for f_name, f_def in field_definitions.items():
+        if not is_valid_field(f_name):
+            warnings.warn(f'fields may not start with an underscore, ignoring "{f_name}"', RuntimeWarning)
+        if isinstance(f_def, tuple):
+            try:
+                f_annotation, f_value = f_def
+            except ValueError as e:
+                raise ConfigError(
+                    'field definitions should either be a tuple of (<type>, <default>) or just a '
+                    'default value, unfortunately this means tuples as '
+                    'default values are not allowed'
+                ) from e
+        else:
+            f_annotation, f_value = None, f_def
+
+        if f_annotation:
+            annotations[f_name] = f_annotation
+        fields[f_name] = f_value
+
+    namespace: 'DictStrAny' = {'__annotations__': annotations, '__module__': __module__}
+    if __validators__:
+        namespace.update(__validators__)
+    namespace.update(fields)
+    if __config__:
+        namespace['Config'] = inherit_config(__config__, BaseConfig)
+    resolved_bases = resolve_bases(__base__)
+    meta, ns, kwds = prepare_class(__model_name, resolved_bases, kwds=__cls_kwargs__)
+    if resolved_bases is not __base__:
+        ns['__orig_bases__'] = __base__
+    namespace.update(ns)
+    return meta(__model_name, resolved_bases, namespace, **kwds)
+
+
+_missing = object()
+
+
+def validate_model(  # noqa: C901 (ignore complexity)
+    model: Type[BaseModel], input_data: 'DictStrAny', cls: 'ModelOrDc' = None
+) -> Tuple['DictStrAny', 'SetStr', Optional[ValidationError]]:
+    """
validate data against a model. + """ + values = {} + errors = [] + # input_data names, possibly alias + names_used = set() + # field names, never aliases + fields_set = set() + config = model.__config__ + check_extra = config.extra is not Extra.ignore + cls_ = cls or model + + for validator in model.__pre_root_validators__: + try: + input_data = validator(cls_, input_data) + except (ValueError, TypeError, AssertionError) as exc: + return {}, set(), ValidationError([ErrorWrapper(exc, loc=ROOT_KEY)], cls_) + + for name, field in model.__fields__.items(): + value = input_data.get(field.alias, _missing) + using_name = False + if value is _missing and config.allow_population_by_field_name and field.alt_alias: + value = input_data.get(field.name, _missing) + using_name = True + + if value is _missing: + if field.required: + errors.append(ErrorWrapper(MissingError(), loc=field.alias)) + continue + + value = field.get_default() + + if not config.validate_all and not field.validate_always: + values[name] = value + continue + else: + fields_set.add(name) + if check_extra: + names_used.add(field.name if using_name else field.alias) + + v_, errors_ = field.validate(value, values, loc=field.alias, cls=cls_) + if isinstance(errors_, ErrorWrapper): + errors.append(errors_) + elif isinstance(errors_, list): + errors.extend(errors_) + else: + values[name] = v_ + + if check_extra: + if isinstance(input_data, GetterDict): + extra = input_data.extra_keys() - names_used + else: + extra = input_data.keys() - names_used + if extra: + fields_set |= extra + if config.extra is Extra.allow: + for f in extra: + values[f] = input_data[f] + else: + for f in sorted(extra): + errors.append(ErrorWrapper(ExtraError(), loc=f)) + + for skip_on_failure, validator in model.__post_root_validators__: + if skip_on_failure and errors: + continue + try: + values = validator(cls_, values) + except (ValueError, TypeError, AssertionError) as exc: + errors.append(ErrorWrapper(exc, loc=ROOT_KEY)) + + if errors: + return values, fields_set, ValidationError(errors, cls_) + else: + return values, fields_set, None diff --git a/libs/win/pydantic/mypy.cp37-win_amd64.pyd b/libs/win/pydantic/mypy.cp37-win_amd64.pyd new file mode 100644 index 00000000..3611c95b Binary files /dev/null and b/libs/win/pydantic/mypy.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/mypy.py b/libs/win/pydantic/mypy.py new file mode 100644 index 00000000..6bd9db18 --- /dev/null +++ b/libs/win/pydantic/mypy.py @@ -0,0 +1,850 @@ +import sys +from configparser import ConfigParser +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type as TypingType, Union + +from mypy.errorcodes import ErrorCode +from mypy.nodes import ( + ARG_NAMED, + ARG_NAMED_OPT, + ARG_OPT, + ARG_POS, + ARG_STAR2, + MDEF, + Argument, + AssignmentStmt, + Block, + CallExpr, + ClassDef, + Context, + Decorator, + EllipsisExpr, + FuncBase, + FuncDef, + JsonDict, + MemberExpr, + NameExpr, + PassStmt, + PlaceholderNode, + RefExpr, + StrExpr, + SymbolNode, + SymbolTableNode, + TempNode, + TypeInfo, + TypeVarExpr, + Var, +) +from mypy.options import Options +from mypy.plugin import ( + CheckerPluginInterface, + ClassDefContext, + FunctionContext, + MethodContext, + Plugin, + SemanticAnalyzerPluginInterface, +) +from mypy.plugins import dataclasses +from mypy.semanal import set_callable_name # type: ignore +from mypy.server.trigger import make_wildcard_trigger +from mypy.types import ( + AnyType, + CallableType, + Instance, + NoneType, + Overloaded, + Type, + TypeOfAny, + TypeType, + 
TypeVarType, + UnionType, + get_proper_type, +) +from mypy.typevars import fill_typevars +from mypy.util import get_unique_redefinition_name +from mypy.version import __version__ as mypy_version + +from pydantic.utils import is_valid_field + +try: + from mypy.types import TypeVarDef # type: ignore[attr-defined] +except ImportError: # pragma: no cover + # Backward-compatible with TypeVarDef from Mypy 0.910. + from mypy.types import TypeVarType as TypeVarDef + +CONFIGFILE_KEY = 'pydantic-mypy' +METADATA_KEY = 'pydantic-mypy-metadata' +BASEMODEL_FULLNAME = 'pydantic.main.BaseModel' +BASESETTINGS_FULLNAME = 'pydantic.env_settings.BaseSettings' +FIELD_FULLNAME = 'pydantic.fields.Field' +DATACLASS_FULLNAME = 'pydantic.dataclasses.dataclass' + + +def parse_mypy_version(version: str) -> Tuple[int, ...]: + return tuple(int(part) for part in version.split('+', 1)[0].split('.')) + + +MYPY_VERSION_TUPLE = parse_mypy_version(mypy_version) +BUILTINS_NAME = 'builtins' if MYPY_VERSION_TUPLE >= (0, 930) else '__builtins__' + + +def plugin(version: str) -> 'TypingType[Plugin]': + """ + `version` is the mypy version string + + We might want to use this to print a warning if the mypy version being used is + newer, or especially older, than we expect (or need). + """ + return PydanticPlugin + + +class PydanticPlugin(Plugin): + def __init__(self, options: Options) -> None: + self.plugin_config = PydanticPluginConfig(options) + super().__init__(options) + + def get_base_class_hook(self, fullname: str) -> 'Optional[Callable[[ClassDefContext], None]]': + sym = self.lookup_fully_qualified(fullname) + if sym and isinstance(sym.node, TypeInfo): # pragma: no branch + # No branching may occur if the mypy cache has not been cleared + if any(get_fullname(base) == BASEMODEL_FULLNAME for base in sym.node.mro): + return self._pydantic_model_class_maker_callback + return None + + def get_function_hook(self, fullname: str) -> 'Optional[Callable[[FunctionContext], Type]]': + sym = self.lookup_fully_qualified(fullname) + if sym and sym.fullname == FIELD_FULLNAME: + return self._pydantic_field_callback + return None + + def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], Type]]: + if fullname.endswith('.from_orm'): + return from_orm_callback + return None + + def get_class_decorator_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]: + if fullname == DATACLASS_FULLNAME: + return dataclasses.dataclass_class_maker_callback # type: ignore[return-value] + return None + + def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> None: + transformer = PydanticModelTransformer(ctx, self.plugin_config) + transformer.transform() + + def _pydantic_field_callback(self, ctx: FunctionContext) -> 'Type': + """ + Extract the type of the `default` argument from the Field function, and use it as the return type. + + In particular: + * Check whether the default and default_factory argument is specified. + * Output an error if both are specified. + * Retrieve the type of the argument which is specified, and use it as return type for the function. 
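+
+        As a rough illustration (hypothetical model, not part of the hook itself):
+
+            class M(BaseModel):
+                a: int = Field(default=1)                   # return type taken from `default`
+                b: List[int] = Field(default_factory=list)  # return type taken from the factory
+                c: int = Field(1, default_factory=int)      # error: both specified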
+ """ + default_any_type = ctx.default_return_type + + assert ctx.callee_arg_names[0] == 'default', '"default" is no longer first argument in Field()' + assert ctx.callee_arg_names[1] == 'default_factory', '"default_factory" is no longer second argument in Field()' + default_args = ctx.args[0] + default_factory_args = ctx.args[1] + + if default_args and default_factory_args: + error_default_and_default_factory_specified(ctx.api, ctx.context) + return default_any_type + + if default_args: + default_type = ctx.arg_types[0][0] + default_arg = default_args[0] + + # Fallback to default Any type if the field is required + if not isinstance(default_arg, EllipsisExpr): + return default_type + + elif default_factory_args: + default_factory_type = ctx.arg_types[1][0] + + # Functions which use `ParamSpec` can be overloaded, exposing the callable's types as a parameter + # Pydantic calls the default factory without any argument, so we retrieve the first item + if isinstance(default_factory_type, Overloaded): + if MYPY_VERSION_TUPLE > (0, 910): + default_factory_type = default_factory_type.items[0] + else: + # Mypy0.910 exposes the items of overloaded types in a function + default_factory_type = default_factory_type.items()[0] # type: ignore[operator] + + if isinstance(default_factory_type, CallableType): + ret_type = default_factory_type.ret_type + # mypy doesn't think `ret_type` has `args`, you'd think mypy should know, + # add this check in case it varies by version + args = getattr(ret_type, 'args', None) + if args: + if all(isinstance(arg, TypeVarType) for arg in args): + # Looks like the default factory is a type like `list` or `dict`, replace all args with `Any` + ret_type.args = tuple(default_any_type for _ in args) # type: ignore[attr-defined] + return ret_type + + return default_any_type + + +class PydanticPluginConfig: + __slots__ = ('init_forbid_extra', 'init_typed', 'warn_required_dynamic_aliases', 'warn_untyped_fields') + init_forbid_extra: bool + init_typed: bool + warn_required_dynamic_aliases: bool + warn_untyped_fields: bool + + def __init__(self, options: Options) -> None: + if options.config_file is None: # pragma: no cover + return + + toml_config = parse_toml(options.config_file) + if toml_config is not None: + config = toml_config.get('tool', {}).get('pydantic-mypy', {}) + for key in self.__slots__: + setting = config.get(key, False) + if not isinstance(setting, bool): + raise ValueError(f'Configuration value must be a boolean for key: {key}') + setattr(self, key, setting) + else: + plugin_config = ConfigParser() + plugin_config.read(options.config_file) + for key in self.__slots__: + setting = plugin_config.getboolean(CONFIGFILE_KEY, key, fallback=False) + setattr(self, key, setting) + + +def from_orm_callback(ctx: MethodContext) -> Type: + """ + Raise an error if orm_mode is not enabled + """ + model_type: Instance + if isinstance(ctx.type, CallableType) and isinstance(ctx.type.ret_type, Instance): + model_type = ctx.type.ret_type # called on the class + elif isinstance(ctx.type, Instance): + model_type = ctx.type # called on an instance (unusual, but still valid) + else: # pragma: no cover + detail = f'ctx.type: {ctx.type} (of type {ctx.type.__class__.__name__})' + error_unexpected_behavior(detail, ctx.api, ctx.context) + return ctx.default_return_type + pydantic_metadata = model_type.type.metadata.get(METADATA_KEY) + if pydantic_metadata is None: + return ctx.default_return_type + orm_mode = pydantic_metadata.get('config', {}).get('orm_mode') + if orm_mode is not True: + 
error_from_orm(get_name(model_type.type), ctx.api, ctx.context) + return ctx.default_return_type + + +class PydanticModelTransformer: + tracked_config_fields: Set[str] = { + 'extra', + 'allow_mutation', + 'frozen', + 'orm_mode', + 'allow_population_by_field_name', + 'alias_generator', + } + + def __init__(self, ctx: ClassDefContext, plugin_config: PydanticPluginConfig) -> None: + self._ctx = ctx + self.plugin_config = plugin_config + + def transform(self) -> None: + """ + Configures the BaseModel subclass according to the plugin settings. + + In particular: + * determines the model config and fields, + * adds a fields-aware signature for the initializer and construct methods + * freezes the class if allow_mutation = False or frozen = True + * stores the fields, config, and if the class is settings in the mypy metadata for access by subclasses + """ + ctx = self._ctx + info = self._ctx.cls.info + + self.adjust_validator_signatures() + config = self.collect_config() + fields = self.collect_fields(config) + for field in fields: + if info[field.name].type is None: + if not ctx.api.final_iteration: + ctx.api.defer() + is_settings = any(get_fullname(base) == BASESETTINGS_FULLNAME for base in info.mro[:-1]) + self.add_initializer(fields, config, is_settings) + self.add_construct_method(fields) + self.set_frozen(fields, frozen=config.allow_mutation is False or config.frozen is True) + info.metadata[METADATA_KEY] = { + 'fields': {field.name: field.serialize() for field in fields}, + 'config': config.set_values_dict(), + } + + def adjust_validator_signatures(self) -> None: + """When we decorate a function `f` with `pydantic.validator(...), mypy sees + `f` as a regular method taking a `self` instance, even though pydantic + internally wraps `f` with `classmethod` if necessary. + + Teach mypy this by marking any function whose outermost decorator is a + `validator()` call as a classmethod. + """ + for name, sym in self._ctx.cls.info.names.items(): + if isinstance(sym.node, Decorator): + first_dec = sym.node.original_decorators[0] + if ( + isinstance(first_dec, CallExpr) + and isinstance(first_dec.callee, NameExpr) + and first_dec.callee.fullname == 'pydantic.class_validators.validator' + ): + sym.node.func.is_class = True + + def collect_config(self) -> 'ModelConfigData': + """ + Collects the values of the config attributes that are used by the plugin, accounting for parent classes. + """ + ctx = self._ctx + cls = ctx.cls + config = ModelConfigData() + for stmt in cls.defs.body: + if not isinstance(stmt, ClassDef): + continue + if stmt.name == 'Config': + for substmt in stmt.defs.body: + if not isinstance(substmt, AssignmentStmt): + continue + config.update(self.get_config_update(substmt)) + if ( + config.has_alias_generator + and not config.allow_population_by_field_name + and self.plugin_config.warn_required_dynamic_aliases + ): + error_required_dynamic_aliases(ctx.api, stmt) + for info in cls.info.mro[1:]: # 0 is the current class + if METADATA_KEY not in info.metadata: + continue + + # Each class depends on the set of fields in its ancestors + ctx.api.add_plugin_dependency(make_wildcard_trigger(get_fullname(info))) + for name, value in info.metadata[METADATA_KEY]['config'].items(): + config.setdefault(name, value) + return config + + def collect_fields(self, model_config: 'ModelConfigData') -> List['PydanticModelField']: + """ + Collects the fields for the model, accounting for parent classes + """ + # First, collect fields belonging to the current class. 
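+        # A "field" here is any assignment statement (including annotation-only ones)
+        # in the class body; fields declared on ancestors are merged in from their
+        # metadata further below.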
+ ctx = self._ctx + cls = self._ctx.cls + fields = [] # type: List[PydanticModelField] + known_fields = set() # type: Set[str] + for stmt in cls.defs.body: + if not isinstance(stmt, AssignmentStmt): # `and stmt.new_syntax` to require annotation + continue + + lhs = stmt.lvalues[0] + if not isinstance(lhs, NameExpr) or not is_valid_field(lhs.name): + continue + + if not stmt.new_syntax and self.plugin_config.warn_untyped_fields: + error_untyped_fields(ctx.api, stmt) + + # if lhs.name == '__config__': # BaseConfig not well handled; I'm not sure why yet + # continue + + sym = cls.info.names.get(lhs.name) + if sym is None: # pragma: no cover + # This is likely due to a star import (see the dataclasses plugin for a more detailed explanation) + # This is the same logic used in the dataclasses plugin + continue + + node = sym.node + if isinstance(node, PlaceholderNode): # pragma: no cover + # See the PlaceholderNode docstring for more detail about how this can occur + # Basically, it is an edge case when dealing with complex import logic + # This is the same logic used in the dataclasses plugin + continue + if not isinstance(node, Var): # pragma: no cover + # Don't know if this edge case still happens with the `is_valid_field` check above + # but better safe than sorry + continue + + # x: ClassVar[int] is ignored by dataclasses. + if node.is_classvar: + continue + + is_required = self.get_is_required(cls, stmt, lhs) + alias, has_dynamic_alias = self.get_alias_info(stmt) + if ( + has_dynamic_alias + and not model_config.allow_population_by_field_name + and self.plugin_config.warn_required_dynamic_aliases + ): + error_required_dynamic_aliases(ctx.api, stmt) + fields.append( + PydanticModelField( + name=lhs.name, + is_required=is_required, + alias=alias, + has_dynamic_alias=has_dynamic_alias, + line=stmt.line, + column=stmt.column, + ) + ) + known_fields.add(lhs.name) + all_fields = fields.copy() + for info in cls.info.mro[1:]: # 0 is the current class, -2 is BaseModel, -1 is object + if METADATA_KEY not in info.metadata: + continue + + superclass_fields = [] + # Each class depends on the set of fields in its ancestors + ctx.api.add_plugin_dependency(make_wildcard_trigger(get_fullname(info))) + + for name, data in info.metadata[METADATA_KEY]['fields'].items(): + if name not in known_fields: + field = PydanticModelField.deserialize(info, data) + known_fields.add(name) + superclass_fields.append(field) + else: + (field,) = (a for a in all_fields if a.name == name) + all_fields.remove(field) + superclass_fields.append(field) + all_fields = superclass_fields + all_fields + return all_fields + + def add_initializer(self, fields: List['PydanticModelField'], config: 'ModelConfigData', is_settings: bool) -> None: + """ + Adds a fields-aware `__init__` method to the class. + + The added `__init__` will be annotated with types vs. all `Any` depending on the plugin settings. 
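+
+        As a hedged sketch (hypothetical model): given `class M(BaseModel): x: int`,
+        `init_typed = True` produces roughly
+        `def __init__(__pydantic_self__, *, x: int, **kwargs: Any) -> None`,
+        while `init_typed = False` types `x` as `Any` instead.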
+ """ + ctx = self._ctx + typed = self.plugin_config.init_typed + use_alias = config.allow_population_by_field_name is not True + force_all_optional = is_settings or bool( + config.has_alias_generator and not config.allow_population_by_field_name + ) + init_arguments = self.get_field_arguments( + fields, typed=typed, force_all_optional=force_all_optional, use_alias=use_alias + ) + if not self.should_init_forbid_extra(fields, config): + var = Var('kwargs') + init_arguments.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2)) + + if '__init__' not in ctx.cls.info.names: + add_method(ctx, '__init__', init_arguments, NoneType()) + + def add_construct_method(self, fields: List['PydanticModelField']) -> None: + """ + Adds a fully typed `construct` classmethod to the class. + + Similar to the fields-aware __init__ method, but always uses the field names (not aliases), + and does not treat settings fields as optional. + """ + ctx = self._ctx + set_str = ctx.api.named_type(f'{BUILTINS_NAME}.set', [ctx.api.named_type(f'{BUILTINS_NAME}.str')]) + optional_set_str = UnionType([set_str, NoneType()]) + fields_set_argument = Argument(Var('_fields_set', optional_set_str), optional_set_str, None, ARG_OPT) + construct_arguments = self.get_field_arguments(fields, typed=True, force_all_optional=False, use_alias=False) + construct_arguments = [fields_set_argument] + construct_arguments + + obj_type = ctx.api.named_type(f'{BUILTINS_NAME}.object') + self_tvar_name = '_PydanticBaseModel' # Make sure it does not conflict with other names in the class + tvar_fullname = ctx.cls.fullname + '.' + self_tvar_name + tvd = TypeVarDef(self_tvar_name, tvar_fullname, -1, [], obj_type) + self_tvar_expr = TypeVarExpr(self_tvar_name, tvar_fullname, [], obj_type) + ctx.cls.info.names[self_tvar_name] = SymbolTableNode(MDEF, self_tvar_expr) + + # Backward-compatible with TypeVarDef from Mypy 0.910. + if isinstance(tvd, TypeVarType): + self_type = tvd + else: + self_type = TypeVarType(tvd) # type: ignore[call-arg] + + add_method( + ctx, + 'construct', + construct_arguments, + return_type=self_type, + self_type=self_type, + tvar_def=tvd, + is_classmethod=True, + ) + + def set_frozen(self, fields: List['PydanticModelField'], frozen: bool) -> None: + """ + Marks all fields as properties so that attempts to set them trigger mypy errors. + + This is the same approach used by the attrs and dataclasses plugins. + """ + info = self._ctx.cls.info + for field in fields: + sym_node = info.names.get(field.name) + if sym_node is not None: + var = sym_node.node + assert isinstance(var, Var) + var.is_property = frozen + else: + var = field.to_var(info, use_alias=False) + var.info = info + var.is_property = frozen + var._fullname = get_fullname(info) + '.' + get_name(var) + info.names[get_name(var)] = SymbolTableNode(MDEF, var) + + def get_config_update(self, substmt: AssignmentStmt) -> Optional['ModelConfigData']: + """ + Determines the config update due to a single statement in the Config class definition. 
+ + Warns if a tracked config attribute is set to a value the plugin doesn't know how to interpret (e.g., an int) + """ + lhs = substmt.lvalues[0] + if not (isinstance(lhs, NameExpr) and lhs.name in self.tracked_config_fields): + return None + if lhs.name == 'extra': + if isinstance(substmt.rvalue, StrExpr): + forbid_extra = substmt.rvalue.value == 'forbid' + elif isinstance(substmt.rvalue, MemberExpr): + forbid_extra = substmt.rvalue.name == 'forbid' + else: + error_invalid_config_value(lhs.name, self._ctx.api, substmt) + return None + return ModelConfigData(forbid_extra=forbid_extra) + if lhs.name == 'alias_generator': + has_alias_generator = True + if isinstance(substmt.rvalue, NameExpr) and substmt.rvalue.fullname == 'builtins.None': + has_alias_generator = False + return ModelConfigData(has_alias_generator=has_alias_generator) + if isinstance(substmt.rvalue, NameExpr) and substmt.rvalue.fullname in ('builtins.True', 'builtins.False'): + return ModelConfigData(**{lhs.name: substmt.rvalue.fullname == 'builtins.True'}) + error_invalid_config_value(lhs.name, self._ctx.api, substmt) + return None + + @staticmethod + def get_is_required(cls: ClassDef, stmt: AssignmentStmt, lhs: NameExpr) -> bool: + """ + Returns a boolean indicating whether the field defined in `stmt` is a required field. + """ + expr = stmt.rvalue + if isinstance(expr, TempNode): + # TempNode means annotation-only, so only non-required if Optional + value_type = get_proper_type(cls.info[lhs.name].type) + if isinstance(value_type, UnionType) and any(isinstance(item, NoneType) for item in value_type.items): + # Annotated as Optional, or otherwise having NoneType in the union + return False + return True + if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME: + # The "default value" is a call to `Field`; at this point, the field is + # only required if default is Ellipsis (i.e., `field_name: Annotation = Field(...)`) or if default_factory + # is specified. + for arg, name in zip(expr.args, expr.arg_names): + # If name is None, then this arg is the default because it is the only positonal argument. + if name is None or name == 'default': + return arg.__class__ is EllipsisExpr + if name == 'default_factory': + return False + return True + # Only required if the "default value" is Ellipsis (i.e., `field_name: Annotation = ...`) + return isinstance(expr, EllipsisExpr) + + @staticmethod + def get_alias_info(stmt: AssignmentStmt) -> Tuple[Optional[str], bool]: + """ + Returns a pair (alias, has_dynamic_alias), extracted from the declaration of the field defined in `stmt`. + + `has_dynamic_alias` is True if and only if an alias is provided, but not as a string literal. + If `has_dynamic_alias` is True, `alias` will be None. 
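+
+        For example (hypothetical fields): `x: int = Field(..., alias='y')` yields
+        `('y', False)`; `x: int = Field(..., alias=make_alias())` yields `(None, True)`;
+        and a bare annotation `x: int` yields `(None, False)`.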
+ """ + expr = stmt.rvalue + if isinstance(expr, TempNode): + # TempNode means annotation-only + return None, False + + if not ( + isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME + ): + # Assigned value is not a call to pydantic.fields.Field + return None, False + + for i, arg_name in enumerate(expr.arg_names): + if arg_name != 'alias': + continue + arg = expr.args[i] + if isinstance(arg, StrExpr): + return arg.value, False + else: + return None, True + return None, False + + def get_field_arguments( + self, fields: List['PydanticModelField'], typed: bool, force_all_optional: bool, use_alias: bool + ) -> List[Argument]: + """ + Helper function used during the construction of the `__init__` and `construct` method signatures. + + Returns a list of mypy Argument instances for use in the generated signatures. + """ + info = self._ctx.cls.info + arguments = [ + field.to_argument(info, typed=typed, force_optional=force_all_optional, use_alias=use_alias) + for field in fields + if not (use_alias and field.has_dynamic_alias) + ] + return arguments + + def should_init_forbid_extra(self, fields: List['PydanticModelField'], config: 'ModelConfigData') -> bool: + """ + Indicates whether the generated `__init__` should get a `**kwargs` at the end of its signature + + We disallow arbitrary kwargs if the extra config setting is "forbid", or if the plugin config says to, + *unless* a required dynamic alias is present (since then we can't determine a valid signature). + """ + if not config.allow_population_by_field_name: + if self.is_dynamic_alias_present(fields, bool(config.has_alias_generator)): + return False + if config.forbid_extra: + return True + return self.plugin_config.init_forbid_extra + + @staticmethod + def is_dynamic_alias_present(fields: List['PydanticModelField'], has_alias_generator: bool) -> bool: + """ + Returns whether any fields on the model have a "dynamic alias", i.e., an alias that cannot be + determined during static analysis. 
+ """ + for field in fields: + if field.has_dynamic_alias: + return True + if has_alias_generator: + for field in fields: + if field.alias is None: + return True + return False + + +class PydanticModelField: + def __init__( + self, name: str, is_required: bool, alias: Optional[str], has_dynamic_alias: bool, line: int, column: int + ): + self.name = name + self.is_required = is_required + self.alias = alias + self.has_dynamic_alias = has_dynamic_alias + self.line = line + self.column = column + + def to_var(self, info: TypeInfo, use_alias: bool) -> Var: + name = self.name + if use_alias and self.alias is not None: + name = self.alias + return Var(name, info[self.name].type) + + def to_argument(self, info: TypeInfo, typed: bool, force_optional: bool, use_alias: bool) -> Argument: + if typed and info[self.name].type is not None: + type_annotation = info[self.name].type + else: + type_annotation = AnyType(TypeOfAny.explicit) + return Argument( + variable=self.to_var(info, use_alias), + type_annotation=type_annotation, + initializer=None, + kind=ARG_NAMED_OPT if force_optional or not self.is_required else ARG_NAMED, + ) + + def serialize(self) -> JsonDict: + return self.__dict__ + + @classmethod + def deserialize(cls, info: TypeInfo, data: JsonDict) -> 'PydanticModelField': + return cls(**data) + + +class ModelConfigData: + def __init__( + self, + forbid_extra: Optional[bool] = None, + allow_mutation: Optional[bool] = None, + frozen: Optional[bool] = None, + orm_mode: Optional[bool] = None, + allow_population_by_field_name: Optional[bool] = None, + has_alias_generator: Optional[bool] = None, + ): + self.forbid_extra = forbid_extra + self.allow_mutation = allow_mutation + self.frozen = frozen + self.orm_mode = orm_mode + self.allow_population_by_field_name = allow_population_by_field_name + self.has_alias_generator = has_alias_generator + + def set_values_dict(self) -> Dict[str, Any]: + return {k: v for k, v in self.__dict__.items() if v is not None} + + def update(self, config: Optional['ModelConfigData']) -> None: + if config is None: + return + for k, v in config.set_values_dict().items(): + setattr(self, k, v) + + def setdefault(self, key: str, value: Any) -> None: + if getattr(self, key) is None: + setattr(self, key, value) + + +ERROR_ORM = ErrorCode('pydantic-orm', 'Invalid from_orm call', 'Pydantic') +ERROR_CONFIG = ErrorCode('pydantic-config', 'Invalid config value', 'Pydantic') +ERROR_ALIAS = ErrorCode('pydantic-alias', 'Dynamic alias disallowed', 'Pydantic') +ERROR_UNEXPECTED = ErrorCode('pydantic-unexpected', 'Unexpected behavior', 'Pydantic') +ERROR_UNTYPED = ErrorCode('pydantic-field', 'Untyped field disallowed', 'Pydantic') +ERROR_FIELD_DEFAULTS = ErrorCode('pydantic-field', 'Invalid Field defaults', 'Pydantic') + + +def error_from_orm(model_name: str, api: CheckerPluginInterface, context: Context) -> None: + api.fail(f'"{model_name}" does not have orm_mode=True', context, code=ERROR_ORM) + + +def error_invalid_config_value(name: str, api: SemanticAnalyzerPluginInterface, context: Context) -> None: + api.fail(f'Invalid value for "Config.{name}"', context, code=ERROR_CONFIG) + + +def error_required_dynamic_aliases(api: SemanticAnalyzerPluginInterface, context: Context) -> None: + api.fail('Required dynamic aliases disallowed', context, code=ERROR_ALIAS) + + +def error_unexpected_behavior(detail: str, api: CheckerPluginInterface, context: Context) -> None: # pragma: no cover + # Can't think of a good way to test this, but I confirmed it renders as desired by adding to a non-error 
path + link = 'https://github.com/pydantic/pydantic/issues/new/choose' + full_message = f'The pydantic mypy plugin ran into unexpected behavior: {detail}\n' + full_message += f'Please consider reporting this bug at {link} so we can try to fix it!' + api.fail(full_message, context, code=ERROR_UNEXPECTED) + + +def error_untyped_fields(api: SemanticAnalyzerPluginInterface, context: Context) -> None: + api.fail('Untyped fields disallowed', context, code=ERROR_UNTYPED) + + +def error_default_and_default_factory_specified(api: CheckerPluginInterface, context: Context) -> None: + api.fail('Field default and default_factory cannot be specified together', context, code=ERROR_FIELD_DEFAULTS) + + +def add_method( + ctx: ClassDefContext, + name: str, + args: List[Argument], + return_type: Type, + self_type: Optional[Type] = None, + tvar_def: Optional[TypeVarDef] = None, + is_classmethod: bool = False, + is_new: bool = False, + # is_staticmethod: bool = False, +) -> None: + """ + Adds a new method to a class. + + This can be dropped if/when https://github.com/python/mypy/issues/7301 is merged + """ + info = ctx.cls.info + + # First remove any previously generated methods with the same name + # to avoid clashes and problems in the semantic analyzer. + if name in info.names: + sym = info.names[name] + if sym.plugin_generated and isinstance(sym.node, FuncDef): + ctx.cls.defs.body.remove(sym.node) # pragma: no cover + + self_type = self_type or fill_typevars(info) + if is_classmethod or is_new: + first = [Argument(Var('_cls'), TypeType.make_normalized(self_type), None, ARG_POS)] + # elif is_staticmethod: + # first = [] + else: + self_type = self_type or fill_typevars(info) + first = [Argument(Var('__pydantic_self__'), self_type, None, ARG_POS)] + args = first + args + arg_types, arg_names, arg_kinds = [], [], [] + for arg in args: + assert arg.type_annotation, 'All arguments must be fully typed.' + arg_types.append(arg.type_annotation) + arg_names.append(get_name(arg.variable)) + arg_kinds.append(arg.kind) + + function_type = ctx.api.named_type(f'{BUILTINS_NAME}.function') + signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type) + if tvar_def: + signature.variables = [tvar_def] + + func = FuncDef(name, args, Block([PassStmt()])) + func.info = info + func.type = set_callable_name(signature, func) + func.is_class = is_classmethod + # func.is_static = is_staticmethod + func._fullname = get_fullname(info) + '.' + name + func.line = info.line + + # NOTE: we would like the plugin generated node to dominate, but we still + # need to keep any existing definitions so they get semantically analyzed. + if name in info.names: + # Get a nice unique name instead. + r_name = get_unique_redefinition_name(name, info.names) + info.names[r_name] = info.names[name] + + if is_classmethod: # or is_staticmethod: + func.is_decorated = True + v = Var(name, func.type) + v.info = info + v._fullname = func._fullname + # if is_classmethod: + v.is_classmethod = True + dec = Decorator(func, [NameExpr('classmethod')], v) + # else: + # v.is_staticmethod = True + # dec = Decorator(func, [NameExpr('staticmethod')], v) + + dec.line = info.line + sym = SymbolTableNode(MDEF, dec) + else: + sym = SymbolTableNode(MDEF, func) + sym.plugin_generated = True + + info.names[name] = sym + info.defn.defs.body.append(func) + + +def get_fullname(x: Union[FuncBase, SymbolNode]) -> str: + """ + Used for compatibility with mypy 0.740; can be dropped once support for 0.740 is dropped. 
+ """ + fn = x.fullname + if callable(fn): # pragma: no cover + return fn() + return fn + + +def get_name(x: Union[FuncBase, SymbolNode]) -> str: + """ + Used for compatibility with mypy 0.740; can be dropped once support for 0.740 is dropped. + """ + fn = x.name + if callable(fn): # pragma: no cover + return fn() + return fn + + +def parse_toml(config_file: str) -> Optional[Dict[str, Any]]: + if not config_file.endswith('.toml'): + return None + + read_mode = 'rb' + if sys.version_info >= (3, 11): + import tomllib as toml_ + else: + try: + import tomli as toml_ + except ImportError: + # older versions of mypy have toml as a dependency, not tomli + read_mode = 'r' + try: + import toml as toml_ # type: ignore[no-redef] + except ImportError: # pragma: no cover + import warnings + + warnings.warn('No TOML parser installed, cannot read configuration from `pyproject.toml`.') + return None + + with open(config_file, read_mode) as rf: + return toml_.load(rf) # type: ignore[arg-type] diff --git a/libs/win/pydantic/networks.cp37-win_amd64.pyd b/libs/win/pydantic/networks.cp37-win_amd64.pyd new file mode 100644 index 00000000..08ef06cd Binary files /dev/null and b/libs/win/pydantic/networks.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/networks.py b/libs/win/pydantic/networks.py new file mode 100644 index 00000000..c7d97186 --- /dev/null +++ b/libs/win/pydantic/networks.py @@ -0,0 +1,736 @@ +import re +from ipaddress import ( + IPv4Address, + IPv4Interface, + IPv4Network, + IPv6Address, + IPv6Interface, + IPv6Network, + _BaseAddress, + _BaseNetwork, +) +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Generator, + List, + Match, + Optional, + Pattern, + Set, + Tuple, + Type, + Union, + cast, + no_type_check, +) + +from . import errors +from .utils import Representation, update_not_none +from .validators import constr_length_validator, str_validator + +if TYPE_CHECKING: + import email_validator + from typing_extensions import TypedDict + + from .config import BaseConfig + from .fields import ModelField + from .typing import AnyCallable + + CallableGenerator = Generator[AnyCallable, None, None] + + class Parts(TypedDict, total=False): + scheme: str + user: Optional[str] + password: Optional[str] + ipv4: Optional[str] + ipv6: Optional[str] + domain: Optional[str] + port: Optional[str] + path: Optional[str] + query: Optional[str] + fragment: Optional[str] + + class HostParts(TypedDict, total=False): + host: str + tld: Optional[str] + host_type: Optional[str] + port: Optional[str] + rebuild: bool + +else: + email_validator = None + + class Parts(dict): + pass + + +NetworkType = Union[str, bytes, int, Tuple[Union[str, bytes, int], Union[str, int]]] + +__all__ = [ + 'AnyUrl', + 'AnyHttpUrl', + 'FileUrl', + 'HttpUrl', + 'stricturl', + 'EmailStr', + 'NameEmail', + 'IPvAnyAddress', + 'IPvAnyInterface', + 'IPvAnyNetwork', + 'PostgresDsn', + 'CockroachDsn', + 'AmqpDsn', + 'RedisDsn', + 'MongoDsn', + 'KafkaDsn', + 'validate_email', +] + +_url_regex_cache = None +_multi_host_url_regex_cache = None +_ascii_domain_regex_cache = None +_int_domain_regex_cache = None +_host_regex_cache = None + +_host_regex = ( + r'(?:' + r'(?P(?:\d{1,3}\.){3}\d{1,3})(?=$|[/:#?])|' # ipv4 + r'(?P\[[A-F0-9]*:[A-F0-9:]+\])(?=$|[/:#?])|' # ipv6 + r'(?P[^\s/:?#]+)' # domain, validation occurs later + r')?' + r'(?::(?P\d+))?' # port +) +_scheme_regex = r'(?:(?P[a-z][a-z0-9+\-.]+)://)?' # scheme https://tools.ietf.org/html/rfc3986#appendix-A +_user_info_regex = r'(?:(?P[^\s:/]*)(?::(?P[^\s/]*))?@)?' 
+_path_regex = r'(?P<path>/[^\s?#]*)?' +_query_regex = r'(?:\?(?P<query>[^\s#]*))?' +_fragment_regex = r'(?:#(?P<fragment>[^\s#]*))?' + + +def url_regex() -> Pattern[str]: + global _url_regex_cache + if _url_regex_cache is None: + _url_regex_cache = re.compile( + rf'{_scheme_regex}{_user_info_regex}{_host_regex}{_path_regex}{_query_regex}{_fragment_regex}', + re.IGNORECASE, + ) + return _url_regex_cache + + +def multi_host_url_regex() -> Pattern[str]: + """ + Compiled multi host url regex. + + In addition to `url_regex`, it allows matching multiple hosts, + e.g. host1.db.net,host2.db.net + """ + global _multi_host_url_regex_cache + if _multi_host_url_regex_cache is None: + _multi_host_url_regex_cache = re.compile( + rf'{_scheme_regex}{_user_info_regex}' + r'(?P<hosts>([^/]*))' # validation occurs later + rf'{_path_regex}{_query_regex}{_fragment_regex}', + re.IGNORECASE, + ) + return _multi_host_url_regex_cache + + +def ascii_domain_regex() -> Pattern[str]: + global _ascii_domain_regex_cache + if _ascii_domain_regex_cache is None: + ascii_chunk = r'[_0-9a-z](?:[-_0-9a-z]{0,61}[_0-9a-z])?' + ascii_domain_ending = r'(?P<tld>\.[a-z]{2,63})?\.?' + _ascii_domain_regex_cache = re.compile( + fr'(?:{ascii_chunk}\.)*?{ascii_chunk}{ascii_domain_ending}', re.IGNORECASE + ) + return _ascii_domain_regex_cache + + +def int_domain_regex() -> Pattern[str]: + global _int_domain_regex_cache + if _int_domain_regex_cache is None: + int_chunk = r'[_0-9a-\U00040000](?:[-_0-9a-\U00040000]{0,61}[_0-9a-\U00040000])?' + int_domain_ending = r'(?P<tld>(\.[^\W\d_]{2,63})|(\.(?:xn--)[_0-9a-z-]{2,63}))?\.?' + _int_domain_regex_cache = re.compile(fr'(?:{int_chunk}\.)*?{int_chunk}{int_domain_ending}', re.IGNORECASE) + return _int_domain_regex_cache + + +def host_regex() -> Pattern[str]: + global _host_regex_cache + if _host_regex_cache is None: + _host_regex_cache = re.compile( + _host_regex, + re.IGNORECASE, + ) + return _host_regex_cache + + +class AnyUrl(str): + strip_whitespace = True + min_length = 1 + max_length = 2**16 + allowed_schemes: Optional[Collection[str]] = None + tld_required: bool = False + user_required: bool = False + host_required: bool = True + hidden_parts: Set[str] = set() + + __slots__ = ('scheme', 'user', 'password', 'host', 'tld', 'host_type', 'port', 'path', 'query', 'fragment') + + @no_type_check + def __new__(cls, url: Optional[str], **kwargs) -> object: + return str.__new__(cls, cls.build(**kwargs) if url is None else url) + + def __init__( + self, + url: str, + *, + scheme: str, + user: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + tld: Optional[str] = None, + host_type: str = 'domain', + port: Optional[str] = None, + path: Optional[str] = None, + query: Optional[str] = None, + fragment: Optional[str] = None, + ) -> None: + str.__init__(url) + self.scheme = scheme + self.user = user + self.password = password + self.host = host + self.tld = tld + self.host_type = host_type + self.port = port + self.path = path + self.query = query + self.fragment = fragment + + @classmethod + def build( + cls, + *, + scheme: str, + user: Optional[str] = None, + password: Optional[str] = None, + host: str, + port: Optional[str] = None, + path: Optional[str] = None, + query: Optional[str] = None, + fragment: Optional[str] = None, + **_kwargs: str, + ) -> str: + parts = Parts( + scheme=scheme, + user=user, + password=password, + host=host, + port=port, + path=path, + query=query, + fragment=fragment, + **_kwargs, # type: ignore[misc] + ) + + url = scheme + '://' + if user: + url += user + if password: + url += 
':' + password + if user or password: + url += '@' + url += host + if port and ('port' not in cls.hidden_parts or cls.get_default_parts(parts).get('port') != port): + url += ':' + port + if path: + url += path + if query: + url += '?' + query + if fragment: + url += '#' + fragment + return url + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none(field_schema, minLength=cls.min_length, maxLength=cls.max_length, format='uri') + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.validate + + @classmethod + def validate(cls, value: Any, field: 'ModelField', config: 'BaseConfig') -> 'AnyUrl': + if value.__class__ == cls: + return value + value = str_validator(value) + if cls.strip_whitespace: + value = value.strip() + url: str = cast(str, constr_length_validator(value, field, config)) + + m = cls._match_url(url) + # the regex should always match, if it doesn't please report with details of the URL tried + assert m, 'URL regex failed unexpectedly' + + original_parts = cast('Parts', m.groupdict()) + parts = cls.apply_default_parts(original_parts) + parts = cls.validate_parts(parts) + + if m.end() != len(url): + raise errors.UrlExtraError(extra=url[m.end() :]) + + return cls._build_url(m, url, parts) + + @classmethod + def _build_url(cls, m: Match[str], url: str, parts: 'Parts') -> 'AnyUrl': + """ + Validate hosts and build the AnyUrl object. Split from `validate` so this method + can be altered in `MultiHostDsn`. + """ + host, tld, host_type, rebuild = cls.validate_host(parts) + + return cls( + None if rebuild else url, + scheme=parts['scheme'], + user=parts['user'], + password=parts['password'], + host=host, + tld=tld, + host_type=host_type, + port=parts['port'], + path=parts['path'], + query=parts['query'], + fragment=parts['fragment'], + ) + + @staticmethod + def _match_url(url: str) -> Optional[Match[str]]: + return url_regex().match(url) + + @staticmethod + def _validate_port(port: Optional[str]) -> None: + if port is not None and int(port) > 65_535: + raise errors.UrlPortError() + + @classmethod + def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts': + """ + A method used to validate parts of a URL. 
+ Could be overridden to set default values for parts if missing + """ + scheme = parts['scheme'] + if scheme is None: + raise errors.UrlSchemeError() + + if cls.allowed_schemes and scheme.lower() not in cls.allowed_schemes: + raise errors.UrlSchemePermittedError(set(cls.allowed_schemes)) + + if validate_port: + cls._validate_port(parts['port']) + + user = parts['user'] + if cls.user_required and user is None: + raise errors.UrlUserInfoError() + + return parts + + @classmethod + def validate_host(cls, parts: 'Parts') -> Tuple[str, Optional[str], str, bool]: + tld, host_type, rebuild = None, None, False + for f in ('domain', 'ipv4', 'ipv6'): + host = parts[f] # type: ignore[literal-required] + if host: + host_type = f + break + + if host is None: + if cls.host_required: + raise errors.UrlHostError() + elif host_type == 'domain': + is_international = False + d = ascii_domain_regex().fullmatch(host) + if d is None: + d = int_domain_regex().fullmatch(host) + if d is None: + raise errors.UrlHostError() + is_international = True + + tld = d.group('tld') + if tld is None and not is_international: + d = int_domain_regex().fullmatch(host) + assert d is not None + tld = d.group('tld') + is_international = True + + if tld is not None: + tld = tld[1:] + elif cls.tld_required: + raise errors.UrlHostTldError() + + if is_international: + host_type = 'int_domain' + rebuild = True + host = host.encode('idna').decode('ascii') + if tld is not None: + tld = tld.encode('idna').decode('ascii') + + return host, tld, host_type, rebuild # type: ignore + + @staticmethod + def get_default_parts(parts: 'Parts') -> 'Parts': + return {} + + @classmethod + def apply_default_parts(cls, parts: 'Parts') -> 'Parts': + for key, value in cls.get_default_parts(parts).items(): + if not parts[key]: # type: ignore[literal-required] + parts[key] = value # type: ignore[literal-required] + return parts + + def __repr__(self) -> str: + extra = ', '.join(f'{n}={getattr(self, n)!r}' for n in self.__slots__ if getattr(self, n) is not None) + return f'{self.__class__.__name__}({super().__repr__()}, {extra})' + + +class AnyHttpUrl(AnyUrl): + allowed_schemes = {'http', 'https'} + + __slots__ = () + + +class HttpUrl(AnyHttpUrl): + tld_required = True + # https://stackoverflow.com/questions/417142/what-is-the-maximum-length-of-a-url-in-different-browsers + max_length = 2083 + hidden_parts = {'port'} + + @staticmethod + def get_default_parts(parts: 'Parts') -> 'Parts': + return {'port': '80' if parts['scheme'] == 'http' else '443'} + + +class FileUrl(AnyUrl): + allowed_schemes = {'file'} + host_required = False + + __slots__ = () + + +class MultiHostDsn(AnyUrl): + __slots__ = AnyUrl.__slots__ + ('hosts',) + + def __init__(self, *args: Any, hosts: Optional[List['HostParts']] = None, **kwargs: Any): + super().__init__(*args, **kwargs) + self.hosts = hosts + + @staticmethod + def _match_url(url: str) -> Optional[Match[str]]: + return multi_host_url_regex().match(url) + + @classmethod + def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts': + return super().validate_parts(parts, validate_port=False) + + @classmethod + def _build_url(cls, m: Match[str], url: str, parts: 'Parts') -> 'MultiHostDsn': + hosts_parts: List['HostParts'] = [] + host_re = host_regex() + for host in m.groupdict()['hosts'].split(','): + d: Parts = host_re.match(host).groupdict() # type: ignore + host, tld, host_type, rebuild = cls.validate_host(d) + port = d.get('port') + cls._validate_port(port) + hosts_parts.append( + { + 'host': host, + 
'host_type': host_type, + 'tld': tld, + 'rebuild': rebuild, + 'port': port, + } + ) + + if len(hosts_parts) > 1: + return cls( + None if any([hp['rebuild'] for hp in hosts_parts]) else url, + scheme=parts['scheme'], + user=parts['user'], + password=parts['password'], + path=parts['path'], + query=parts['query'], + fragment=parts['fragment'], + host_type=None, + hosts=hosts_parts, + ) + else: + # backwards compatibility with single host + host_part = hosts_parts[0] + return cls( + None if host_part['rebuild'] else url, + scheme=parts['scheme'], + user=parts['user'], + password=parts['password'], + host=host_part['host'], + tld=host_part['tld'], + host_type=host_part['host_type'], + port=host_part.get('port'), + path=parts['path'], + query=parts['query'], + fragment=parts['fragment'], + ) + + +class PostgresDsn(MultiHostDsn): + allowed_schemes = { + 'postgres', + 'postgresql', + 'postgresql+asyncpg', + 'postgresql+pg8000', + 'postgresql+psycopg2', + 'postgresql+psycopg2cffi', + 'postgresql+py-postgresql', + 'postgresql+pygresql', + } + user_required = True + + __slots__ = () + + +class CockroachDsn(AnyUrl): + allowed_schemes = { + 'cockroachdb', + 'cockroachdb+psycopg2', + 'cockroachdb+asyncpg', + } + user_required = True + + +class AmqpDsn(AnyUrl): + allowed_schemes = {'amqp', 'amqps'} + host_required = False + + +class RedisDsn(AnyUrl): + __slots__ = () + allowed_schemes = {'redis', 'rediss'} + host_required = False + + @staticmethod + def get_default_parts(parts: 'Parts') -> 'Parts': + return { + 'domain': 'localhost' if not (parts['ipv4'] or parts['ipv6']) else '', + 'port': '6379', + 'path': '/0', + } + + +class MongoDsn(AnyUrl): + allowed_schemes = {'mongodb'} + + # TODO: Needed to generic "Parts" for "Replica Set", "Sharded Cluster", and other mongodb deployment modes + @staticmethod + def get_default_parts(parts: 'Parts') -> 'Parts': + return { + 'port': '27017', + } + + +class KafkaDsn(AnyUrl): + allowed_schemes = {'kafka'} + + @staticmethod + def get_default_parts(parts: 'Parts') -> 'Parts': + return { + 'domain': 'localhost', + 'port': '9092', + } + + +def stricturl( + *, + strip_whitespace: bool = True, + min_length: int = 1, + max_length: int = 2**16, + tld_required: bool = True, + host_required: bool = True, + allowed_schemes: Optional[Collection[str]] = None, +) -> Type[AnyUrl]: + # use kwargs then define conf in a dict to aid with IDE type hinting + namespace = dict( + strip_whitespace=strip_whitespace, + min_length=min_length, + max_length=max_length, + tld_required=tld_required, + host_required=host_required, + allowed_schemes=allowed_schemes, + ) + return type('UrlValue', (AnyUrl,), namespace) + + +def import_email_validator() -> None: + global email_validator + try: + import email_validator + except ImportError as e: + raise ImportError('email-validator is not installed, run `pip install pydantic[email]`') from e + + +class EmailStr(str): + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(type='string', format='email') + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + # included here and below so the error happens straight away + import_email_validator() + + yield str_validator + yield cls.validate + + @classmethod + def validate(cls, value: Union[str]) -> str: + return validate_email(value)[1] + + +class NameEmail(Representation): + __slots__ = 'name', 'email' + + def __init__(self, name: str, email: str): + self.name = name + self.email = email + + def __eq__(self, other: Any) -> bool: + return 
isinstance(other, NameEmail) and (self.name, self.email) == (other.name, other.email) + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(type='string', format='name-email') + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + import_email_validator() + + yield cls.validate + + @classmethod + def validate(cls, value: Any) -> 'NameEmail': + if value.__class__ == cls: + return value + value = str_validator(value) + return cls(*validate_email(value)) + + def __str__(self) -> str: + return f'{self.name} <{self.email}>' + + +class IPvAnyAddress(_BaseAddress): + __slots__ = () + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(type='string', format='ipvanyaddress') + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.validate + + @classmethod + def validate(cls, value: Union[str, bytes, int]) -> Union[IPv4Address, IPv6Address]: + try: + return IPv4Address(value) + except ValueError: + pass + + try: + return IPv6Address(value) + except ValueError: + raise errors.IPvAnyAddressError() + + +class IPvAnyInterface(_BaseAddress): + __slots__ = () + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(type='string', format='ipvanyinterface') + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.validate + + @classmethod + def validate(cls, value: NetworkType) -> Union[IPv4Interface, IPv6Interface]: + try: + return IPv4Interface(value) + except ValueError: + pass + + try: + return IPv6Interface(value) + except ValueError: + raise errors.IPvAnyInterfaceError() + + +class IPvAnyNetwork(_BaseNetwork): # type: ignore + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(type='string', format='ipvanynetwork') + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.validate + + @classmethod + def validate(cls, value: NetworkType) -> Union[IPv4Network, IPv6Network]: + # Assume IP Network is defined with a default value for ``strict`` argument. + # Define your own class if you want to specify network address check strictness. + try: + return IPv4Network(value) + except ValueError: + pass + + try: + return IPv6Network(value) + except ValueError: + raise errors.IPvAnyNetworkError() + + +pretty_email_regex = re.compile(r'([\w ]*?) *<(.*)> *') + + +def validate_email(value: Union[str]) -> Tuple[str, str]: + """ + Brutally simple email address validation. Note that unlike most email address validation: + * raw ip address (literal) domain parts are not allowed. + * "John Doe <local_part@domain.com>" style "pretty" email addresses are processed + * the local part check is extremely basic. This raises the possibility of unicode spoofing, but no better + solution is really possible. + * spaces are stripped from the beginning and end of addresses but no error is raised + + See RFC 5322 but treat it with suspicion; there seems to exist no universally acknowledged test for a valid email! 
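+
+    An illustrative sketch of the returned `(name, email)` pair (assuming the
+    email-validator package is installed):
+
+        validate_email('John Doe <local@Example.COM>')  # -> ('John Doe', 'local@example.com')
+        validate_email('plain@example.com')             # -> ('plain', 'plain@example.com')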
+ """ + if email_validator is None: + import_email_validator() + + m = pretty_email_regex.fullmatch(value) + name: Optional[str] = None + if m: + name, value = m.groups() + + email = value.strip() + + try: + email_validator.validate_email(email, check_deliverability=False) + except email_validator.EmailNotValidError as e: + raise errors.EmailError() from e + + at_index = email.index('@') + local_part = email[:at_index] # RFC 5321, local part must be case-sensitive. + global_part = email[at_index:].lower() + + return name or local_part, local_part + global_part diff --git a/libs/win/pydantic/parse.cp37-win_amd64.pyd b/libs/win/pydantic/parse.cp37-win_amd64.pyd new file mode 100644 index 00000000..80090237 Binary files /dev/null and b/libs/win/pydantic/parse.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/parse.py b/libs/win/pydantic/parse.py new file mode 100644 index 00000000..7ac330ca --- /dev/null +++ b/libs/win/pydantic/parse.py @@ -0,0 +1,66 @@ +import json +import pickle +from enum import Enum +from pathlib import Path +from typing import Any, Callable, Union + +from .types import StrBytes + + +class Protocol(str, Enum): + json = 'json' + pickle = 'pickle' + + +def load_str_bytes( + b: StrBytes, + *, + content_type: str = None, + encoding: str = 'utf8', + proto: Protocol = None, + allow_pickle: bool = False, + json_loads: Callable[[str], Any] = json.loads, +) -> Any: + if proto is None and content_type: + if content_type.endswith(('json', 'javascript')): + pass + elif allow_pickle and content_type.endswith('pickle'): + proto = Protocol.pickle + else: + raise TypeError(f'Unknown content-type: {content_type}') + + proto = proto or Protocol.json + + if proto == Protocol.json: + if isinstance(b, bytes): + b = b.decode(encoding) + return json_loads(b) + elif proto == Protocol.pickle: + if not allow_pickle: + raise RuntimeError('Trying to decode with pickle with allow_pickle=False') + bb = b if isinstance(b, bytes) else b.encode() + return pickle.loads(bb) + else: + raise TypeError(f'Unknown protocol: {proto}') + + +def load_file( + path: Union[str, Path], + *, + content_type: str = None, + encoding: str = 'utf8', + proto: Protocol = None, + allow_pickle: bool = False, + json_loads: Callable[[str], Any] = json.loads, +) -> Any: + path = Path(path) + b = path.read_bytes() + if content_type is None: + if path.suffix in ('.js', '.json'): + proto = Protocol.json + elif path.suffix == '.pkl': + proto = Protocol.pickle + + return load_str_bytes( + b, proto=proto, content_type=content_type, encoding=encoding, allow_pickle=allow_pickle, json_loads=json_loads + ) diff --git a/libs/win/pydantic/py.typed b/libs/win/pydantic/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/libs/win/pydantic/schema.cp37-win_amd64.pyd b/libs/win/pydantic/schema.cp37-win_amd64.pyd new file mode 100644 index 00000000..c142e02d Binary files /dev/null and b/libs/win/pydantic/schema.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/schema.py b/libs/win/pydantic/schema.py new file mode 100644 index 00000000..e7af56f1 --- /dev/null +++ b/libs/win/pydantic/schema.py @@ -0,0 +1,1153 @@ +import re +import warnings +from collections import defaultdict +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from enum import Enum +from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + ForwardRef, + FrozenSet, + Generic, + Iterable, 
+ List, + Optional, + Pattern, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) +from uuid import UUID + +from typing_extensions import Annotated, Literal + +from .fields import ( + MAPPING_LIKE_SHAPES, + SHAPE_DEQUE, + SHAPE_FROZENSET, + SHAPE_GENERIC, + SHAPE_ITERABLE, + SHAPE_LIST, + SHAPE_SEQUENCE, + SHAPE_SET, + SHAPE_SINGLETON, + SHAPE_TUPLE, + SHAPE_TUPLE_ELLIPSIS, + FieldInfo, + ModelField, +) +from .json import pydantic_encoder +from .networks import AnyUrl, EmailStr +from .types import ( + ConstrainedDecimal, + ConstrainedFloat, + ConstrainedFrozenSet, + ConstrainedInt, + ConstrainedList, + ConstrainedSet, + SecretBytes, + SecretStr, + StrictBytes, + StrictStr, + conbytes, + condecimal, + confloat, + confrozenset, + conint, + conlist, + conset, + constr, +) +from .typing import ( + all_literal_values, + get_args, + get_origin, + get_sub_types, + is_callable_type, + is_literal_type, + is_namedtuple, + is_none_type, + is_union, +) +from .utils import ROOT_KEY, get_model, lenient_issubclass + +if TYPE_CHECKING: + from .dataclasses import Dataclass + from .main import BaseModel + +default_prefix = '#/definitions/' +default_ref_template = '#/definitions/{model}' + +TypeModelOrEnum = Union[Type['BaseModel'], Type[Enum]] +TypeModelSet = Set[TypeModelOrEnum] + + +def _apply_modify_schema( + modify_schema: Callable[..., None], field: Optional[ModelField], field_schema: Dict[str, Any] +) -> None: + from inspect import signature + + sig = signature(modify_schema) + args = set(sig.parameters.keys()) + if 'field' in args or 'kwargs' in args: + modify_schema(field_schema, field=field) + else: + modify_schema(field_schema) + + +def schema( + models: Sequence[Union[Type['BaseModel'], Type['Dataclass']]], + *, + by_alias: bool = True, + title: Optional[str] = None, + description: Optional[str] = None, + ref_prefix: Optional[str] = None, + ref_template: str = default_ref_template, +) -> Dict[str, Any]: + """ + Process a list of models and generate a single JSON Schema with all of them defined in the ``definitions`` + top-level JSON key, including their sub-models. + + :param models: a list of models to include in the generated JSON Schema + :param by_alias: generate the schemas using the aliases defined, if any + :param title: title for the generated schema that includes the definitions + :param description: description for the generated schema + :param ref_prefix: the JSON Pointer prefix for schema references with ``$ref``, if None, will be set to the + default of ``#/definitions/``. Update it if you want the schemas to reference the definitions somewhere + else, e.g. for OpenAPI use ``#/components/schemas/``. The resulting generated schemas will still be at the + top-level key ``definitions``, so you can extract them from there. But all the references will have the set + prefix. + :param ref_template: Use a ``string.format()`` template for ``$ref`` instead of a prefix. This can be useful + for references that cannot be represented by ``ref_prefix`` such as a definition stored in another file. For + a sibling json file in a ``/schemas`` directory use ``"/schemas/${model}.json#"``. + :return: dict with the JSON Schema with a ``definitions`` top-level key including the schema definitions for + the models and sub-models passed in ``models``. 
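+
+    Usage sketch (hypothetical model, for illustration only):
+
+        from pydantic import BaseModel
+        from pydantic.schema import schema
+
+        class Foo(BaseModel):
+            a: int
+
+        top_level = schema([Foo], title='My Schema')
+        # -> {'title': 'My Schema', 'definitions': {'Foo': {...}}}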
+ """ + clean_models = [get_model(model) for model in models] + flat_models = get_flat_models_from_models(clean_models) + model_name_map = get_model_name_map(flat_models) + definitions = {} + output_schema: Dict[str, Any] = {} + if title: + output_schema['title'] = title + if description: + output_schema['description'] = description + for model in clean_models: + m_schema, m_definitions, m_nested_models = model_process_schema( + model, + by_alias=by_alias, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ref_template=ref_template, + ) + definitions.update(m_definitions) + model_name = model_name_map[model] + definitions[model_name] = m_schema + if definitions: + output_schema['definitions'] = definitions + return output_schema + + +def model_schema( + model: Union[Type['BaseModel'], Type['Dataclass']], + by_alias: bool = True, + ref_prefix: Optional[str] = None, + ref_template: str = default_ref_template, +) -> Dict[str, Any]: + """ + Generate a JSON Schema for one model. With all the sub-models defined in the ``definitions`` top-level + JSON key. + + :param model: a Pydantic model (a class that inherits from BaseModel) + :param by_alias: generate the schemas using the aliases defined, if any + :param ref_prefix: the JSON Pointer prefix for schema references with ``$ref``, if None, will be set to the + default of ``#/definitions/``. Update it if you want the schemas to reference the definitions somewhere + else, e.g. for OpenAPI use ``#/components/schemas/``. The resulting generated schemas will still be at the + top-level key ``definitions``, so you can extract them from there. But all the references will have the set + prefix. + :param ref_template: Use a ``string.format()`` template for ``$ref`` instead of a prefix. This can be useful for + references that cannot be represented by ``ref_prefix`` such as a definition stored in another file. For a + sibling json file in a ``/schemas`` directory use ``"/schemas/${model}.json#"``. + :return: dict with the JSON Schema for the passed ``model`` + """ + model = get_model(model) + flat_models = get_flat_models_from_model(model) + model_name_map = get_model_name_map(flat_models) + model_name = model_name_map[model] + m_schema, m_definitions, nested_models = model_process_schema( + model, by_alias=by_alias, model_name_map=model_name_map, ref_prefix=ref_prefix, ref_template=ref_template + ) + if model_name in nested_models: + # model_name is in Nested models, it has circular references + m_definitions[model_name] = m_schema + m_schema = get_schema_ref(model_name, ref_prefix, ref_template, False) + if m_definitions: + m_schema.update({'definitions': m_definitions}) + return m_schema + + +def get_field_info_schema(field: ModelField, schema_overrides: bool = False) -> Tuple[Dict[str, Any], bool]: + + # If no title is explicitly set, we don't set title in the schema for enums. + # The behaviour is the same as `BaseModel` reference, where the default title + # is in the definitions part of the schema. 
+ schema_: Dict[str, Any] = {} + if field.field_info.title or not lenient_issubclass(field.type_, Enum): + schema_['title'] = field.field_info.title or field.alias.title().replace('_', ' ') + + if field.field_info.title: + schema_overrides = True + + if field.field_info.description: + schema_['description'] = field.field_info.description + schema_overrides = True + + if not field.required and field.default is not None and not is_callable_type(field.outer_type_): + schema_['default'] = encode_default(field.default) + schema_overrides = True + + return schema_, schema_overrides + + +def field_schema( + field: ModelField, + *, + by_alias: bool = True, + model_name_map: Dict[TypeModelOrEnum, str], + ref_prefix: Optional[str] = None, + ref_template: str = default_ref_template, + known_models: TypeModelSet = None, +) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: + """ + Process a Pydantic field and return a tuple with a JSON Schema for it as the first item. + Also return a dictionary of definitions with models as keys and their schemas as values. If the passed field + is a model and has sub-models, and those sub-models don't have overrides (as ``title``, ``default``, etc), they + will be included in the definitions and referenced in the schema instead of included recursively. + + :param field: a Pydantic ``ModelField`` + :param by_alias: use the defined alias (if any) in the returned schema + :param model_name_map: used to generate the JSON Schema references to other models included in the definitions + :param ref_prefix: the JSON Pointer prefix to use for references to other schemas, if None, the default of + #/definitions/ will be used + :param ref_template: Use a ``string.format()`` template for ``$ref`` instead of a prefix. This can be useful for + references that cannot be represented by ``ref_prefix`` such as a definition stored in another file. For a + sibling json file in a ``/schemas`` directory use ``"/schemas/${model}.json#"``. + :param known_models: used to solve circular references + :return: tuple of the schema for this field and additional definitions + """ + s, schema_overrides = get_field_info_schema(field) + + validation_schema = get_field_schema_validations(field) + if validation_schema: + s.update(validation_schema) + schema_overrides = True + + f_schema, f_definitions, f_nested_models = field_type_schema( + field, + by_alias=by_alias, + model_name_map=model_name_map, + schema_overrides=schema_overrides, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models or set(), + ) + + # $ref will only be returned when there are no schema_overrides + if '$ref' in f_schema: + return f_schema, f_definitions, f_nested_models + else: + s.update(f_schema) + return s, f_definitions, f_nested_models + + +numeric_types = (int, float, Decimal) +_str_types_attrs: Tuple[Tuple[str, Union[type, Tuple[type, ...]], str], ...] = ( + ('max_length', numeric_types, 'maxLength'), + ('min_length', numeric_types, 'minLength'), + ('regex', str, 'pattern'), +) + +_numeric_types_attrs: Tuple[Tuple[str, Union[type, Tuple[type, ...]], str], ...] = ( + ('gt', numeric_types, 'exclusiveMinimum'), + ('lt', numeric_types, 'exclusiveMaximum'), + ('ge', numeric_types, 'minimum'), + ('le', numeric_types, 'maximum'), + ('multiple_of', numeric_types, 'multipleOf'), +) + + +def get_field_schema_validations(field: ModelField) -> Dict[str, Any]: + """ + Get the JSON Schema validation keywords for a ``field`` with an annotation of + a Pydantic ``FieldInfo`` with validation arguments. 
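+
+    Editor's example (illustrative): for a field declared as
+    ``name: str = Field(min_length=2, regex='^[a-z]+$')`` this returns
+    ``{'minLength': 2, 'pattern': '^[a-z]+$'}``.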
+ """ + f_schema: Dict[str, Any] = {} + + if lenient_issubclass(field.type_, Enum): + # schema is already updated by `enum_process_schema`; just update with field extra + if field.field_info.extra: + f_schema.update(field.field_info.extra) + return f_schema + + if lenient_issubclass(field.type_, (str, bytes)): + for attr_name, t, keyword in _str_types_attrs: + attr = getattr(field.field_info, attr_name, None) + if isinstance(attr, t): + f_schema[keyword] = attr + if lenient_issubclass(field.type_, numeric_types) and not issubclass(field.type_, bool): + for attr_name, t, keyword in _numeric_types_attrs: + attr = getattr(field.field_info, attr_name, None) + if isinstance(attr, t): + f_schema[keyword] = attr + if field.field_info is not None and field.field_info.const: + f_schema['const'] = field.default + if field.field_info.extra: + f_schema.update(field.field_info.extra) + modify_schema = getattr(field.outer_type_, '__modify_schema__', None) + if modify_schema: + _apply_modify_schema(modify_schema, field, f_schema) + return f_schema + + +def get_model_name_map(unique_models: TypeModelSet) -> Dict[TypeModelOrEnum, str]: + """ + Process a set of models and generate unique names for them to be used as keys in the JSON Schema + definitions. By default the names are the same as the class name. But if two models in different Python + modules have the same name (e.g. "users.Model" and "items.Model"), the generated names will be + based on the Python module path for those conflicting models to prevent name collisions. + + :param unique_models: a Python set of models + :return: dict mapping models to names + """ + name_model_map = {} + conflicting_names: Set[str] = set() + for model in unique_models: + model_name = normalize_name(model.__name__) + if model_name in conflicting_names: + model_name = get_long_model_name(model) + name_model_map[model_name] = model + elif model_name in name_model_map: + conflicting_names.add(model_name) + conflicting_model = name_model_map.pop(model_name) + name_model_map[get_long_model_name(conflicting_model)] = conflicting_model + name_model_map[get_long_model_name(model)] = model + else: + name_model_map[model_name] = model + return {v: k for k, v in name_model_map.items()} + + +def get_flat_models_from_model(model: Type['BaseModel'], known_models: TypeModelSet = None) -> TypeModelSet: + """ + Take a single ``model`` and generate a set with itself and all the sub-models in the tree. I.e. if you pass + model ``Foo`` (subclass of Pydantic ``BaseModel``) as ``model``, and it has a field of type ``Bar`` (also + subclass of ``BaseModel``) and that model ``Bar`` has a field of type ``Baz`` (also subclass of ``BaseModel``), + the return value will be ``set([Foo, Bar, Baz])``. + + :param model: a Pydantic ``BaseModel`` subclass + :param known_models: used to solve circular references + :return: a set with the initial model and all its sub-models + """ + known_models = known_models or set() + flat_models: TypeModelSet = set() + flat_models.add(model) + known_models |= flat_models + fields = cast(Sequence[ModelField], model.__fields__.values()) + flat_models |= get_flat_models_from_fields(fields, known_models=known_models) + return flat_models + + +def get_flat_models_from_field(field: ModelField, known_models: TypeModelSet) -> TypeModelSet: + """ + Take a single Pydantic ``ModelField`` (from a model) that could have been declared as a sublcass of BaseModel + (so, it could be a submodel), and generate a set with its model and all the sub-models in the tree. + I.e. 
if you pass a field that was declared to be of type ``Foo`` (subclass of BaseModel) as ``field``, and that + model ``Foo`` has a field of type ``Bar`` (also subclass of ``BaseModel``) and that model ``Bar`` has a field of + type ``Baz`` (also subclass of ``BaseModel``), the return value will be ``set([Foo, Bar, Baz])``. + + :param field: a Pydantic ``ModelField`` + :param known_models: used to solve circular references + :return: a set with the model used in the declaration for this field, if any, and all its sub-models + """ + from .main import BaseModel + + flat_models: TypeModelSet = set() + + field_type = field.type_ + if lenient_issubclass(getattr(field_type, '__pydantic_model__', None), BaseModel): + field_type = field_type.__pydantic_model__ + + if field.sub_fields and not lenient_issubclass(field_type, BaseModel): + flat_models |= get_flat_models_from_fields(field.sub_fields, known_models=known_models) + elif lenient_issubclass(field_type, BaseModel) and field_type not in known_models: + flat_models |= get_flat_models_from_model(field_type, known_models=known_models) + elif lenient_issubclass(field_type, Enum): + flat_models.add(field_type) + return flat_models + + +def get_flat_models_from_fields(fields: Sequence[ModelField], known_models: TypeModelSet) -> TypeModelSet: + """ + Take a list of Pydantic ``ModelField``s (from a model) that could have been declared as subclasses of ``BaseModel`` + (so, any of them could be a submodel), and generate a set with their models and all the sub-models in the tree. + I.e. if you pass a the fields of a model ``Foo`` (subclass of ``BaseModel``) as ``fields``, and on of them has a + field of type ``Bar`` (also subclass of ``BaseModel``) and that model ``Bar`` has a field of type ``Baz`` (also + subclass of ``BaseModel``), the return value will be ``set([Foo, Bar, Baz])``. + + :param fields: a list of Pydantic ``ModelField``s + :param known_models: used to solve circular references + :return: a set with any model declared in the fields, and all their sub-models + """ + flat_models: TypeModelSet = set() + for field in fields: + flat_models |= get_flat_models_from_field(field, known_models=known_models) + return flat_models + + +def get_flat_models_from_models(models: Sequence[Type['BaseModel']]) -> TypeModelSet: + """ + Take a list of ``models`` and generate a set with them and all their sub-models in their trees. I.e. if you pass + a list of two models, ``Foo`` and ``Bar``, both subclasses of Pydantic ``BaseModel`` as models, and ``Bar`` has + a field of type ``Baz`` (also subclass of ``BaseModel``), the return value will be ``set([Foo, Bar, Baz])``. + """ + flat_models: TypeModelSet = set() + for model in models: + flat_models |= get_flat_models_from_model(model) + return flat_models + + +def get_long_model_name(model: TypeModelOrEnum) -> str: + return f'{model.__module__}__{model.__qualname__}'.replace('.', '__') + + +def field_type_schema( + field: ModelField, + *, + by_alias: bool, + model_name_map: Dict[TypeModelOrEnum, str], + ref_template: str, + schema_overrides: bool = False, + ref_prefix: Optional[str] = None, + known_models: TypeModelSet, +) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: + """ + Used by ``field_schema()``, you probably should be using that function. + + Take a single ``field`` and generate the schema for its type only, not including additional + information as title, etc. Also return additional schema definitions, from sub-models. 
+ """ + from .main import BaseModel # noqa: F811 + + definitions = {} + nested_models: Set[str] = set() + f_schema: Dict[str, Any] + if field.shape in { + SHAPE_LIST, + SHAPE_TUPLE_ELLIPSIS, + SHAPE_SEQUENCE, + SHAPE_SET, + SHAPE_FROZENSET, + SHAPE_ITERABLE, + SHAPE_DEQUE, + }: + items_schema, f_definitions, f_nested_models = field_singleton_schema( + field, + by_alias=by_alias, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + definitions.update(f_definitions) + nested_models.update(f_nested_models) + f_schema = {'type': 'array', 'items': items_schema} + if field.shape in {SHAPE_SET, SHAPE_FROZENSET}: + f_schema['uniqueItems'] = True + + elif field.shape in MAPPING_LIKE_SHAPES: + f_schema = {'type': 'object'} + key_field = cast(ModelField, field.key_field) + regex = getattr(key_field.type_, 'regex', None) + items_schema, f_definitions, f_nested_models = field_singleton_schema( + field, + by_alias=by_alias, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + definitions.update(f_definitions) + nested_models.update(f_nested_models) + if regex: + # Dict keys have a regex pattern + # items_schema might be a schema or empty dict, add it either way + f_schema['patternProperties'] = {regex.pattern: items_schema} + elif items_schema: + # The dict values are not simply Any, so they need a schema + f_schema['additionalProperties'] = items_schema + elif field.shape == SHAPE_TUPLE or (field.shape == SHAPE_GENERIC and not issubclass(field.type_, BaseModel)): + sub_schema = [] + sub_fields = cast(List[ModelField], field.sub_fields) + for sf in sub_fields: + sf_schema, sf_definitions, sf_nested_models = field_type_schema( + sf, + by_alias=by_alias, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + definitions.update(sf_definitions) + nested_models.update(sf_nested_models) + sub_schema.append(sf_schema) + + sub_fields_len = len(sub_fields) + if field.shape == SHAPE_GENERIC: + all_of_schemas = sub_schema[0] if sub_fields_len == 1 else {'type': 'array', 'items': sub_schema} + f_schema = {'allOf': [all_of_schemas]} + else: + f_schema = { + 'type': 'array', + 'minItems': sub_fields_len, + 'maxItems': sub_fields_len, + } + if sub_fields_len >= 1: + f_schema['items'] = sub_schema + else: + assert field.shape in {SHAPE_SINGLETON, SHAPE_GENERIC}, field.shape + f_schema, f_definitions, f_nested_models = field_singleton_schema( + field, + by_alias=by_alias, + model_name_map=model_name_map, + schema_overrides=schema_overrides, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + definitions.update(f_definitions) + nested_models.update(f_nested_models) + + # check field type to avoid repeated calls to the same __modify_schema__ method + if field.type_ != field.outer_type_: + if field.shape == SHAPE_GENERIC: + field_type = field.type_ + else: + field_type = field.outer_type_ + modify_schema = getattr(field_type, '__modify_schema__', None) + if modify_schema: + _apply_modify_schema(modify_schema, field, f_schema) + return f_schema, definitions, nested_models + + +def model_process_schema( + model: TypeModelOrEnum, + *, + by_alias: bool = True, + model_name_map: Dict[TypeModelOrEnum, str], + ref_prefix: Optional[str] = None, + ref_template: str = default_ref_template, + known_models: TypeModelSet = None, + field: Optional[ModelField] = None, +) -> Tuple[Dict[str, Any], 
Dict[str, Any], Set[str]]: + """ + Used by ``model_schema()``, you probably should be using that function. + + Take a single ``model`` and generate its schema. Also return additional schema definitions, from sub-models. The + sub-models of the returned schema will be referenced, but their definitions will not be included in the schema. All + the definitions are returned as the second value. + """ + from inspect import getdoc, signature + + known_models = known_models or set() + if lenient_issubclass(model, Enum): + model = cast(Type[Enum], model) + s = enum_process_schema(model, field=field) + return s, {}, set() + model = cast(Type['BaseModel'], model) + s = {'title': model.__config__.title or model.__name__} + doc = getdoc(model) + if doc: + s['description'] = doc + known_models.add(model) + m_schema, m_definitions, nested_models = model_type_schema( + model, + by_alias=by_alias, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + s.update(m_schema) + schema_extra = model.__config__.schema_extra + if callable(schema_extra): + if len(signature(schema_extra).parameters) == 1: + schema_extra(s) + else: + schema_extra(s, model) + else: + s.update(schema_extra) + return s, m_definitions, nested_models + + +def model_type_schema( + model: Type['BaseModel'], + *, + by_alias: bool, + model_name_map: Dict[TypeModelOrEnum, str], + ref_template: str, + ref_prefix: Optional[str] = None, + known_models: TypeModelSet, +) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: + """ + You probably should be using ``model_schema()``, this function is indirectly used by that function. + + Take a single ``model`` and generate the schema for its type only, not including additional + information as title, etc. Also return additional schema definitions, from sub-models. + """ + properties = {} + required = [] + definitions: Dict[str, Any] = {} + nested_models: Set[str] = set() + for k, f in model.__fields__.items(): + try: + f_schema, f_definitions, f_nested_models = field_schema( + f, + by_alias=by_alias, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + except SkipField as skip: + warnings.warn(skip.message, UserWarning) + continue + definitions.update(f_definitions) + nested_models.update(f_nested_models) + if by_alias: + properties[f.alias] = f_schema + if f.required: + required.append(f.alias) + else: + properties[k] = f_schema + if f.required: + required.append(k) + if ROOT_KEY in properties: + out_schema = properties[ROOT_KEY] + out_schema['title'] = model.__config__.title or model.__name__ + else: + out_schema = {'type': 'object', 'properties': properties} + if required: + out_schema['required'] = required + if model.__config__.extra == 'forbid': + out_schema['additionalProperties'] = False + return out_schema, definitions, nested_models + + +def enum_process_schema(enum: Type[Enum], *, field: Optional[ModelField] = None) -> Dict[str, Any]: + """ + Take a single `enum` and generate its schema. + + This is similar to the `model_process_schema` function, but applies to ``Enum`` objects. + """ + schema_: Dict[str, Any] = { + 'title': enum.__name__, + # Python assigns all enums a default docstring value of 'An enumeration', so + # all enums will have a description field even if not explicitly provided. + 'description': enum.__doc__ or 'An enumeration.', + # Add enum values and the enum field type to the schema. 
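+        # Editor's note (illustrative): for ``class Color(str, Enum): red = 'red'``
+        # this function yields {'title': 'Color', 'description': 'An enumeration.',
+        # 'enum': ['red'], 'type': 'string'}.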
+ 'enum': [item.value for item in cast(Iterable[Enum], enum)], + } + + add_field_type_to_schema(enum, schema_) + + modify_schema = getattr(enum, '__modify_schema__', None) + if modify_schema: + _apply_modify_schema(modify_schema, field, schema_) + + return schema_ + + +def field_singleton_sub_fields_schema( + field: ModelField, + *, + by_alias: bool, + model_name_map: Dict[TypeModelOrEnum, str], + ref_template: str, + schema_overrides: bool = False, + ref_prefix: Optional[str] = None, + known_models: TypeModelSet, +) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: + """ + This function is indirectly used by ``field_schema()``, you probably should be using that function. + + Take a list of Pydantic ``ModelField`` from the declaration of a type with parameters, and generate their + schema. I.e., fields used as "type parameters", like ``str`` and ``int`` in ``Tuple[str, int]``. + """ + sub_fields = cast(List[ModelField], field.sub_fields) + definitions = {} + nested_models: Set[str] = set() + if len(sub_fields) == 1: + return field_type_schema( + sub_fields[0], + by_alias=by_alias, + model_name_map=model_name_map, + schema_overrides=schema_overrides, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + else: + s: Dict[str, Any] = {} + # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#discriminator-object + field_has_discriminator: bool = field.discriminator_key is not None + if field_has_discriminator: + assert field.sub_fields_mapping is not None + + discriminator_models_refs: Dict[str, Union[str, Dict[str, Any]]] = {} + + for discriminator_value, sub_field in field.sub_fields_mapping.items(): + # sub_field is either a `BaseModel` or directly an `Annotated` `Union` of many + if is_union(get_origin(sub_field.type_)): + sub_models = get_sub_types(sub_field.type_) + discriminator_models_refs[discriminator_value] = { + model_name_map[sub_model]: get_schema_ref( + model_name_map[sub_model], ref_prefix, ref_template, False + ) + for sub_model in sub_models + } + else: + sub_field_type = sub_field.type_ + if hasattr(sub_field_type, '__pydantic_model__'): + sub_field_type = sub_field_type.__pydantic_model__ + + discriminator_model_name = model_name_map[sub_field_type] + discriminator_model_ref = get_schema_ref(discriminator_model_name, ref_prefix, ref_template, False) + discriminator_models_refs[discriminator_value] = discriminator_model_ref['$ref'] + + s['discriminator'] = { + 'propertyName': field.discriminator_alias, + 'mapping': discriminator_models_refs, + } + + sub_field_schemas = [] + for sf in sub_fields: + sub_schema, sub_definitions, sub_nested_models = field_type_schema( + sf, + by_alias=by_alias, + model_name_map=model_name_map, + schema_overrides=schema_overrides, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + definitions.update(sub_definitions) + if schema_overrides and 'allOf' in sub_schema: + # if the sub_field is a referenced schema we only need the referenced + # object. Otherwise we will end up with several allOf inside anyOf/oneOf. 
+ # See https://github.com/pydantic/pydantic/issues/1209 + sub_schema = sub_schema['allOf'][0] + + if sub_schema.keys() == {'discriminator', 'oneOf'}: + # we don't want discriminator information inside oneOf choices, this is dealt with elsewhere + sub_schema.pop('discriminator') + sub_field_schemas.append(sub_schema) + nested_models.update(sub_nested_models) + s['oneOf' if field_has_discriminator else 'anyOf'] = sub_field_schemas + return s, definitions, nested_models + + +# Order is important, e.g. subclasses of str must go before str +# this is used only for standard library types, custom types should use __modify_schema__ instead +field_class_to_schema: Tuple[Tuple[Any, Dict[str, Any]], ...] = ( + (Path, {'type': 'string', 'format': 'path'}), + (datetime, {'type': 'string', 'format': 'date-time'}), + (date, {'type': 'string', 'format': 'date'}), + (time, {'type': 'string', 'format': 'time'}), + (timedelta, {'type': 'number', 'format': 'time-delta'}), + (IPv4Network, {'type': 'string', 'format': 'ipv4network'}), + (IPv6Network, {'type': 'string', 'format': 'ipv6network'}), + (IPv4Interface, {'type': 'string', 'format': 'ipv4interface'}), + (IPv6Interface, {'type': 'string', 'format': 'ipv6interface'}), + (IPv4Address, {'type': 'string', 'format': 'ipv4'}), + (IPv6Address, {'type': 'string', 'format': 'ipv6'}), + (Pattern, {'type': 'string', 'format': 'regex'}), + (str, {'type': 'string'}), + (bytes, {'type': 'string', 'format': 'binary'}), + (bool, {'type': 'boolean'}), + (int, {'type': 'integer'}), + (float, {'type': 'number'}), + (Decimal, {'type': 'number'}), + (UUID, {'type': 'string', 'format': 'uuid'}), + (dict, {'type': 'object'}), + (list, {'type': 'array', 'items': {}}), + (tuple, {'type': 'array', 'items': {}}), + (set, {'type': 'array', 'items': {}, 'uniqueItems': True}), + (frozenset, {'type': 'array', 'items': {}, 'uniqueItems': True}), +) + +json_scheme = {'type': 'string', 'format': 'json-string'} + + +def add_field_type_to_schema(field_type: Any, schema_: Dict[str, Any]) -> None: + """ + Update the given `schema` with the type-specific metadata for the given `field_type`. + + This function looks through `field_class_to_schema` for a class that matches the given `field_type`, + and then modifies the given `schema` with the information from that type. + """ + for type_, t_schema in field_class_to_schema: + # Fallback for `typing.Pattern` and `re.Pattern` as they are not a valid class + if lenient_issubclass(field_type, type_) or field_type is type_ is Pattern: + schema_.update(t_schema) + break + + +def get_schema_ref(name: str, ref_prefix: Optional[str], ref_template: str, schema_overrides: bool) -> Dict[str, Any]: + if ref_prefix: + schema_ref = {'$ref': ref_prefix + name} + else: + schema_ref = {'$ref': ref_template.format(model=name)} + return {'allOf': [schema_ref]} if schema_overrides else schema_ref + + +def field_singleton_schema( # noqa: C901 (ignore complexity) + field: ModelField, + *, + by_alias: bool, + model_name_map: Dict[TypeModelOrEnum, str], + ref_template: str, + schema_overrides: bool = False, + ref_prefix: Optional[str] = None, + known_models: TypeModelSet, +) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: + """ + This function is indirectly used by ``field_schema()``, you should probably be using that function. + + Take a single Pydantic ``ModelField``, and return its schema and any additional definitions from sub-models. 
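+
+    Editor's note (illustrative): for a plain ``int`` field this returns
+    ``({'type': 'integer'}, {}, set())`` via the ``field_class_to_schema``
+    mapping above; for a sub-model field it returns a ``$ref`` schema and adds
+    the sub-model's definition to the second element of the tuple.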
+ """ + from .main import BaseModel + + definitions: Dict[str, Any] = {} + nested_models: Set[str] = set() + field_type = field.type_ + + # Recurse into this field if it contains sub_fields and is NOT a + # BaseModel OR that BaseModel is a const + if field.sub_fields and ( + (field.field_info and field.field_info.const) or not lenient_issubclass(field_type, BaseModel) + ): + return field_singleton_sub_fields_schema( + field, + by_alias=by_alias, + model_name_map=model_name_map, + schema_overrides=schema_overrides, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + if field_type is Any or field_type is object or field_type.__class__ == TypeVar or get_origin(field_type) is type: + return {}, definitions, nested_models # no restrictions + if is_none_type(field_type): + return {'type': 'null'}, definitions, nested_models + if is_callable_type(field_type): + raise SkipField(f'Callable {field.name} was excluded from schema since JSON schema has no equivalent type.') + f_schema: Dict[str, Any] = {} + if field.field_info is not None and field.field_info.const: + f_schema['const'] = field.default + + if is_literal_type(field_type): + values = all_literal_values(field_type) + + if len({v.__class__ for v in values}) > 1: + return field_schema( + multitypes_literal_field_for_schema(values, field), + by_alias=by_alias, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + + # All values have the same type + field_type = values[0].__class__ + f_schema['enum'] = list(values) + add_field_type_to_schema(field_type, f_schema) + elif lenient_issubclass(field_type, Enum): + enum_name = model_name_map[field_type] + f_schema, schema_overrides = get_field_info_schema(field, schema_overrides) + f_schema.update(get_schema_ref(enum_name, ref_prefix, ref_template, schema_overrides)) + definitions[enum_name] = enum_process_schema(field_type, field=field) + elif is_namedtuple(field_type): + sub_schema, *_ = model_process_schema( + field_type.__pydantic_model__, + by_alias=by_alias, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + field=field, + ) + items_schemas = list(sub_schema['properties'].values()) + f_schema.update( + { + 'type': 'array', + 'items': items_schemas, + 'minItems': len(items_schemas), + 'maxItems': len(items_schemas), + } + ) + elif not hasattr(field_type, '__pydantic_model__'): + add_field_type_to_schema(field_type, f_schema) + + modify_schema = getattr(field_type, '__modify_schema__', None) + if modify_schema: + _apply_modify_schema(modify_schema, field, f_schema) + + if f_schema: + return f_schema, definitions, nested_models + + # Handle dataclass-based models + if lenient_issubclass(getattr(field_type, '__pydantic_model__', None), BaseModel): + field_type = field_type.__pydantic_model__ + + if issubclass(field_type, BaseModel): + model_name = model_name_map[field_type] + if field_type not in known_models: + sub_schema, sub_definitions, sub_nested_models = model_process_schema( + field_type, + by_alias=by_alias, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + field=field, + ) + definitions.update(sub_definitions) + definitions[model_name] = sub_schema + nested_models.update(sub_nested_models) + else: + nested_models.add(model_name) + schema_ref = get_schema_ref(model_name, ref_prefix, ref_template, schema_overrides) + return schema_ref, definitions, 
nested_models + + # For generics with no args + args = get_args(field_type) + if args is not None and not args and Generic in field_type.__bases__: + return f_schema, definitions, nested_models + + raise ValueError(f'Value not declarable with JSON Schema, field: {field}') + + +def multitypes_literal_field_for_schema(values: Tuple[Any, ...], field: ModelField) -> ModelField: + """ + To support `Literal` with values of different types, we split it into multiple `Literal` with same type + e.g. `Literal['qwe', 'asd', 1, 2]` becomes `Union[Literal['qwe', 'asd'], Literal[1, 2]]` + """ + literal_distinct_types = defaultdict(list) + for v in values: + literal_distinct_types[v.__class__].append(v) + distinct_literals = (Literal[tuple(same_type_values)] for same_type_values in literal_distinct_types.values()) + + return ModelField( + name=field.name, + type_=Union[tuple(distinct_literals)], # type: ignore + class_validators=field.class_validators, + model_config=field.model_config, + default=field.default, + required=field.required, + alias=field.alias, + field_info=field.field_info, + ) + + +def encode_default(dft: Any) -> Any: + if isinstance(dft, Enum): + return dft.value + elif isinstance(dft, (int, float, str)): + return dft + elif isinstance(dft, (list, tuple)): + t = dft.__class__ + seq_args = (encode_default(v) for v in dft) + return t(*seq_args) if is_namedtuple(t) else t(seq_args) + elif isinstance(dft, dict): + return {encode_default(k): encode_default(v) for k, v in dft.items()} + elif dft is None: + return None + else: + return pydantic_encoder(dft) + + +_map_types_constraint: Dict[Any, Callable[..., type]] = {int: conint, float: confloat, Decimal: condecimal} + + +def get_annotation_from_field_info( + annotation: Any, field_info: FieldInfo, field_name: str, validate_assignment: bool = False +) -> Type[Any]: + """ + Get an annotation with validation implemented for numbers and strings based on the field_info. + :param annotation: an annotation from a field specification, as ``str``, ``ConstrainedStr`` + :param field_info: an instance of FieldInfo, possibly with declarations for validations and JSON Schema + :param field_name: name of the field for use in error messages + :param validate_assignment: default False, flag for BaseModel Config value of validate_assignment + :return: the same ``annotation`` if unmodified or a new annotation with validation in place + """ + constraints = field_info.get_constraints() + used_constraints: Set[str] = set() + if constraints: + annotation, used_constraints = get_annotation_with_constraints(annotation, field_info) + if validate_assignment: + used_constraints.add('allow_mutation') + + unused_constraints = constraints - used_constraints + if unused_constraints: + raise ValueError( + f'On field "{field_name}" the following field constraints are set but not enforced: ' + f'{", ".join(unused_constraints)}. ' + f'\nFor more details see https://pydantic-docs.helpmanual.io/usage/schema/#unenforced-field-constraints' + ) + + return annotation + + +def get_annotation_with_constraints(annotation: Any, field_info: FieldInfo) -> Tuple[Type[Any], Set[str]]: # noqa: C901 + """ + Get an annotation with used constraints implemented for numbers and strings based on the field_info. 
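+    For example (editor's sketch), ``get_annotation_with_constraints(int, Field(gt=0))``
+    returns a ``conint``-style subclass and reports ``{'gt', 'lt', 'ge', 'le',
+    'multiple_of'}`` as the used constraints.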
+ + :param annotation: an annotation from a field specification, as ``str``, ``ConstrainedStr`` + :param field_info: an instance of FieldInfo, possibly with declarations for validations and JSON Schema + :return: the same ``annotation`` if unmodified or a new annotation along with the used constraints. + """ + used_constraints: Set[str] = set() + + def go(type_: Any) -> Type[Any]: + if ( + is_literal_type(type_) + or isinstance(type_, ForwardRef) + or lenient_issubclass(type_, (ConstrainedList, ConstrainedSet, ConstrainedFrozenSet)) + ): + return type_ + origin = get_origin(type_) + if origin is not None: + args: Tuple[Any, ...] = get_args(type_) + if any(isinstance(a, ForwardRef) for a in args): + # forward refs cause infinite recursion below + return type_ + + if origin is Annotated: + return go(args[0]) + if is_union(origin): + return Union[tuple(go(a) for a in args)] # type: ignore + + if issubclass(origin, List) and ( + field_info.min_items is not None + or field_info.max_items is not None + or field_info.unique_items is not None + ): + used_constraints.update({'min_items', 'max_items', 'unique_items'}) + return conlist( + go(args[0]), + min_items=field_info.min_items, + max_items=field_info.max_items, + unique_items=field_info.unique_items, + ) + + if issubclass(origin, Set) and (field_info.min_items is not None or field_info.max_items is not None): + used_constraints.update({'min_items', 'max_items'}) + return conset(go(args[0]), min_items=field_info.min_items, max_items=field_info.max_items) + + if issubclass(origin, FrozenSet) and (field_info.min_items is not None or field_info.max_items is not None): + used_constraints.update({'min_items', 'max_items'}) + return confrozenset(go(args[0]), min_items=field_info.min_items, max_items=field_info.max_items) + + for t in (Tuple, List, Set, FrozenSet, Sequence): + if issubclass(origin, t): # type: ignore + return t[tuple(go(a) for a in args)] # type: ignore + + if issubclass(origin, Dict): + return Dict[args[0], go(args[1])] # type: ignore + + attrs: Optional[Tuple[str, ...]] = None + constraint_func: Optional[Callable[..., type]] = None + if isinstance(type_, type): + if issubclass(type_, (SecretStr, SecretBytes)): + attrs = ('max_length', 'min_length') + + def constraint_func(**kw: Any) -> Type[Any]: + return type(type_.__name__, (type_,), kw) + + elif issubclass(type_, str) and not issubclass(type_, (EmailStr, AnyUrl)): + attrs = ('max_length', 'min_length', 'regex') + if issubclass(type_, StrictStr): + + def constraint_func(**kw: Any) -> Type[Any]: + return type(type_.__name__, (type_,), kw) + + else: + constraint_func = constr + elif issubclass(type_, bytes): + attrs = ('max_length', 'min_length', 'regex') + if issubclass(type_, StrictBytes): + + def constraint_func(**kw: Any) -> Type[Any]: + return type(type_.__name__, (type_,), kw) + + else: + constraint_func = conbytes + elif issubclass(type_, numeric_types) and not issubclass( + type_, + ( + ConstrainedInt, + ConstrainedFloat, + ConstrainedDecimal, + ConstrainedList, + ConstrainedSet, + ConstrainedFrozenSet, + bool, + ), + ): + # Is numeric type + attrs = ('gt', 'lt', 'ge', 'le', 'multiple_of') + if issubclass(type_, float): + attrs += ('allow_inf_nan',) + if issubclass(type_, Decimal): + attrs += ('max_digits', 'decimal_places') + numeric_type = next(t for t in numeric_types if issubclass(type_, t)) # pragma: no branch + constraint_func = _map_types_constraint[numeric_type] + + if attrs: + used_constraints.update(set(attrs)) + kwargs = { + attr_name: attr + for attr_name, attr 
in ((attr_name, getattr(field_info, attr_name)) for attr_name in attrs) + if attr is not None + } + if kwargs: + constraint_func = cast(Callable[..., type], constraint_func) + return constraint_func(**kwargs) + return type_ + + return go(annotation), used_constraints + + +def normalize_name(name: str) -> str: + """ + Normalizes the given name. This can be applied to either a model *or* enum. + """ + return re.sub(r'[^a-zA-Z0-9.\-_]', '_', name) + + +class SkipField(Exception): + """ + Utility exception used to exclude fields from schema. + """ + + def __init__(self, message: str) -> None: + self.message = message diff --git a/libs/win/pydantic/tools.cp37-win_amd64.pyd b/libs/win/pydantic/tools.cp37-win_amd64.pyd new file mode 100644 index 00000000..cd8cf09e Binary files /dev/null and b/libs/win/pydantic/tools.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/tools.py b/libs/win/pydantic/tools.py new file mode 100644 index 00000000..9cdb4538 --- /dev/null +++ b/libs/win/pydantic/tools.py @@ -0,0 +1,92 @@ +import json +from functools import lru_cache +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar, Union + +from .parse import Protocol, load_file, load_str_bytes +from .types import StrBytes +from .typing import display_as_type + +__all__ = ('parse_file_as', 'parse_obj_as', 'parse_raw_as', 'schema_of', 'schema_json_of') + +NameFactory = Union[str, Callable[[Type[Any]], str]] + +if TYPE_CHECKING: + from .typing import DictStrAny + + +def _generate_parsing_type_name(type_: Any) -> str: + return f'ParsingModel[{display_as_type(type_)}]' + + +@lru_cache(maxsize=2048) +def _get_parsing_type(type_: Any, *, type_name: Optional[NameFactory] = None) -> Any: + from pydantic.main import create_model + + if type_name is None: + type_name = _generate_parsing_type_name + if not isinstance(type_name, str): + type_name = type_name(type_) + return create_model(type_name, __root__=(type_, ...)) + + +T = TypeVar('T') + + +def parse_obj_as(type_: Type[T], obj: Any, *, type_name: Optional[NameFactory] = None) -> T: + model_type = _get_parsing_type(type_, type_name=type_name) # type: ignore[arg-type] + return model_type(__root__=obj).__root__ + + +def parse_file_as( + type_: Type[T], + path: Union[str, Path], + *, + content_type: str = None, + encoding: str = 'utf8', + proto: Protocol = None, + allow_pickle: bool = False, + json_loads: Callable[[str], Any] = json.loads, + type_name: Optional[NameFactory] = None, +) -> T: + obj = load_file( + path, + proto=proto, + content_type=content_type, + encoding=encoding, + allow_pickle=allow_pickle, + json_loads=json_loads, + ) + return parse_obj_as(type_, obj, type_name=type_name) + + +def parse_raw_as( + type_: Type[T], + b: StrBytes, + *, + content_type: str = None, + encoding: str = 'utf8', + proto: Protocol = None, + allow_pickle: bool = False, + json_loads: Callable[[str], Any] = json.loads, + type_name: Optional[NameFactory] = None, +) -> T: + obj = load_str_bytes( + b, + proto=proto, + content_type=content_type, + encoding=encoding, + allow_pickle=allow_pickle, + json_loads=json_loads, + ) + return parse_obj_as(type_, obj, type_name=type_name) + + +def schema_of(type_: Any, *, title: Optional[NameFactory] = None, **schema_kwargs: Any) -> 'DictStrAny': + """Generate a JSON schema (as dict) for the passed model or dynamically generated one""" + return _get_parsing_type(type_, type_name=title).schema(**schema_kwargs) + + +def schema_json_of(type_: Any, *, title: Optional[NameFactory] = None, 
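+    # Editor's note: usage sketch for the tools.py helpers above (illustrative):
+    #
+    #   from typing import List
+    #   from pydantic.tools import parse_obj_as, parse_raw_as, schema_of
+    #   parse_obj_as(List[int], ['1', 2])      # -> [1, 2]
+    #   parse_raw_as(List[int], b'[1, 2]')     # -> [1, 2]
+    #   schema_of(List[int], title='IntList')  # JSON schema dict for List[int]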
**schema_json_kwargs: Any) -> str: + """Generate a JSON schema (as JSON) for the passed model or dynamically generated one""" + return _get_parsing_type(type_, type_name=title).schema_json(**schema_json_kwargs) diff --git a/libs/win/pydantic/types.cp37-win_amd64.pyd b/libs/win/pydantic/types.cp37-win_amd64.pyd new file mode 100644 index 00000000..da1767b2 Binary files /dev/null and b/libs/win/pydantic/types.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/types.py b/libs/win/pydantic/types.py new file mode 100644 index 00000000..f98dba3d --- /dev/null +++ b/libs/win/pydantic/types.py @@ -0,0 +1,1187 @@ +import abc +import math +import re +import warnings +from datetime import date +from decimal import Decimal +from enum import Enum +from pathlib import Path +from types import new_class +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + FrozenSet, + List, + Optional, + Pattern, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) +from uuid import UUID +from weakref import WeakSet + +from . import errors +from .datetime_parse import parse_date +from .utils import import_string, update_not_none +from .validators import ( + bytes_validator, + constr_length_validator, + constr_lower, + constr_strip_whitespace, + constr_upper, + decimal_validator, + float_finite_validator, + float_validator, + frozenset_validator, + int_validator, + list_validator, + number_multiple_validator, + number_size_validator, + path_exists_validator, + path_validator, + set_validator, + str_validator, + strict_bytes_validator, + strict_float_validator, + strict_int_validator, + strict_str_validator, +) + +__all__ = [ + 'NoneStr', + 'NoneBytes', + 'StrBytes', + 'NoneStrBytes', + 'StrictStr', + 'ConstrainedBytes', + 'conbytes', + 'ConstrainedList', + 'conlist', + 'ConstrainedSet', + 'conset', + 'ConstrainedFrozenSet', + 'confrozenset', + 'ConstrainedStr', + 'constr', + 'PyObject', + 'ConstrainedInt', + 'conint', + 'PositiveInt', + 'NegativeInt', + 'NonNegativeInt', + 'NonPositiveInt', + 'ConstrainedFloat', + 'confloat', + 'PositiveFloat', + 'NegativeFloat', + 'NonNegativeFloat', + 'NonPositiveFloat', + 'FiniteFloat', + 'ConstrainedDecimal', + 'condecimal', + 'UUID1', + 'UUID3', + 'UUID4', + 'UUID5', + 'FilePath', + 'DirectoryPath', + 'Json', + 'JsonWrapper', + 'SecretField', + 'SecretStr', + 'SecretBytes', + 'StrictBool', + 'StrictBytes', + 'StrictInt', + 'StrictFloat', + 'PaymentCardNumber', + 'ByteSize', + 'PastDate', + 'FutureDate', + 'ConstrainedDate', + 'condate', +] + +NoneStr = Optional[str] +NoneBytes = Optional[bytes] +StrBytes = Union[str, bytes] +NoneStrBytes = Optional[StrBytes] +OptionalInt = Optional[int] +OptionalIntFloat = Union[OptionalInt, float] +OptionalIntFloatDecimal = Union[OptionalIntFloat, Decimal] +OptionalDate = Optional[date] +StrIntFloat = Union[str, int, float] + +if TYPE_CHECKING: + from typing_extensions import Annotated + + from .dataclasses import Dataclass + from .main import BaseModel + from .typing import CallableGenerator + + ModelOrDc = Type[Union[BaseModel, Dataclass]] + +T = TypeVar('T') +_DEFINED_TYPES: 'WeakSet[type]' = WeakSet() + + +@overload +def _registered(typ: Type[T]) -> Type[T]: + pass + + +@overload +def _registered(typ: 'ConstrainedNumberMeta') -> 'ConstrainedNumberMeta': + pass + + +def _registered(typ: Union[Type[T], 'ConstrainedNumberMeta']) -> Union[Type[T], 'ConstrainedNumberMeta']: + # In order to generate valid examples of constrained types, Hypothesis needs + # to inspect the type object - so we keep a weakref 
to each contype object + # until it can be registered. When (or if) our Hypothesis plugin is loaded, + # it monkeypatches this function. + # If Hypothesis is never used, the total effect is to keep a weak reference + # which has minimal memory usage and doesn't even affect garbage collection. + _DEFINED_TYPES.add(typ) + return typ + + +class ConstrainedNumberMeta(type): + def __new__(cls, name: str, bases: Any, dct: Dict[str, Any]) -> 'ConstrainedInt': # type: ignore + new_cls = cast('ConstrainedInt', type.__new__(cls, name, bases, dct)) + + if new_cls.gt is not None and new_cls.ge is not None: + raise errors.ConfigError('bounds gt and ge cannot be specified at the same time') + if new_cls.lt is not None and new_cls.le is not None: + raise errors.ConfigError('bounds lt and le cannot be specified at the same time') + + return _registered(new_cls) # type: ignore + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BOOLEAN TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +if TYPE_CHECKING: + StrictBool = bool +else: + + class StrictBool(int): + """ + StrictBool to allow for bools which are not type-coerced. + """ + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(type='boolean') + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.validate + + @classmethod + def validate(cls, value: Any) -> bool: + """ + Ensure that we only allow bools. + """ + if isinstance(value, bool): + return value + + raise errors.StrictBoolError() + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTEGER TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +class ConstrainedInt(int, metaclass=ConstrainedNumberMeta): + strict: bool = False + gt: OptionalInt = None + ge: OptionalInt = None + lt: OptionalInt = None + le: OptionalInt = None + multiple_of: OptionalInt = None + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none( + field_schema, + exclusiveMinimum=cls.gt, + exclusiveMaximum=cls.lt, + minimum=cls.ge, + maximum=cls.le, + multipleOf=cls.multiple_of, + ) + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield strict_int_validator if cls.strict else int_validator + yield number_size_validator + yield number_multiple_validator + + +def conint( + *, strict: bool = False, gt: int = None, ge: int = None, lt: int = None, le: int = None, multiple_of: int = None +) -> Type[int]: + # use kwargs then define conf in a dict to aid with IDE type hinting + namespace = dict(strict=strict, gt=gt, ge=ge, lt=lt, le=le, multiple_of=multiple_of) + return type('ConstrainedIntValue', (ConstrainedInt,), namespace) + + +if TYPE_CHECKING: + PositiveInt = int + NegativeInt = int + NonPositiveInt = int + NonNegativeInt = int + StrictInt = int +else: + + class PositiveInt(ConstrainedInt): + gt = 0 + + class NegativeInt(ConstrainedInt): + lt = 0 + + class NonPositiveInt(ConstrainedInt): + le = 0 + + class NonNegativeInt(ConstrainedInt): + ge = 0 + + class StrictInt(ConstrainedInt): + strict = True + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLOAT TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +class ConstrainedFloat(float, metaclass=ConstrainedNumberMeta): + strict: bool = False + gt: OptionalIntFloat = None + ge: OptionalIntFloat = None + lt: OptionalIntFloat = None + le: OptionalIntFloat = None + multiple_of: OptionalIntFloat = None + allow_inf_nan: Optional[bool] = None + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none( + field_schema, + exclusiveMinimum=cls.gt, + 
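+            # Editor's note (illustrative): the conint() factory above works the
+            # same way, e.g. conint(gt=0, multiple_of=5) produces a type whose
+            # JSON schema includes {'exclusiveMinimum': 0, 'multipleOf': 5}.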
exclusiveMaximum=cls.lt, + minimum=cls.ge, + maximum=cls.le, + multipleOf=cls.multiple_of, + ) + # Modify constraints to account for differences between IEEE floats and JSON + if field_schema.get('exclusiveMinimum') == -math.inf: + del field_schema['exclusiveMinimum'] + if field_schema.get('minimum') == -math.inf: + del field_schema['minimum'] + if field_schema.get('exclusiveMaximum') == math.inf: + del field_schema['exclusiveMaximum'] + if field_schema.get('maximum') == math.inf: + del field_schema['maximum'] + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield strict_float_validator if cls.strict else float_validator + yield number_size_validator + yield number_multiple_validator + yield float_finite_validator + + +def confloat( + *, + strict: bool = False, + gt: float = None, + ge: float = None, + lt: float = None, + le: float = None, + multiple_of: float = None, + allow_inf_nan: Optional[bool] = None, +) -> Type[float]: + # use kwargs then define conf in a dict to aid with IDE type hinting + namespace = dict(strict=strict, gt=gt, ge=ge, lt=lt, le=le, multiple_of=multiple_of, allow_inf_nan=allow_inf_nan) + return type('ConstrainedFloatValue', (ConstrainedFloat,), namespace) + + +if TYPE_CHECKING: + PositiveFloat = float + NegativeFloat = float + NonPositiveFloat = float + NonNegativeFloat = float + StrictFloat = float + FiniteFloat = float +else: + + class PositiveFloat(ConstrainedFloat): + gt = 0 + + class NegativeFloat(ConstrainedFloat): + lt = 0 + + class NonPositiveFloat(ConstrainedFloat): + le = 0 + + class NonNegativeFloat(ConstrainedFloat): + ge = 0 + + class StrictFloat(ConstrainedFloat): + strict = True + + class FiniteFloat(ConstrainedFloat): + allow_inf_nan = False + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BYTES TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +class ConstrainedBytes(bytes): + strip_whitespace = False + to_upper = False + to_lower = False + min_length: OptionalInt = None + max_length: OptionalInt = None + strict: bool = False + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none(field_schema, minLength=cls.min_length, maxLength=cls.max_length) + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield strict_bytes_validator if cls.strict else bytes_validator + yield constr_strip_whitespace + yield constr_upper + yield constr_lower + yield constr_length_validator + + +def conbytes( + *, + strip_whitespace: bool = False, + to_upper: bool = False, + to_lower: bool = False, + min_length: int = None, + max_length: int = None, + strict: bool = False, +) -> Type[bytes]: + # use kwargs then define conf in a dict to aid with IDE type hinting + namespace = dict( + strip_whitespace=strip_whitespace, + to_upper=to_upper, + to_lower=to_lower, + min_length=min_length, + max_length=max_length, + strict=strict, + ) + return _registered(type('ConstrainedBytesValue', (ConstrainedBytes,), namespace)) + + +if TYPE_CHECKING: + StrictBytes = bytes +else: + + class StrictBytes(ConstrainedBytes): + strict = True + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ STRING TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +class ConstrainedStr(str): + strip_whitespace = False + to_upper = False + to_lower = False + min_length: OptionalInt = None + max_length: OptionalInt = None + curtail_length: OptionalInt = None + regex: Optional[Pattern[str]] = None + strict = False + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none( + field_schema, + minLength=cls.min_length, + 
maxLength=cls.max_length, + pattern=cls.regex and cls.regex.pattern, + ) + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield strict_str_validator if cls.strict else str_validator + yield constr_strip_whitespace + yield constr_upper + yield constr_lower + yield constr_length_validator + yield cls.validate + + @classmethod + def validate(cls, value: Union[str]) -> Union[str]: + if cls.curtail_length and len(value) > cls.curtail_length: + value = value[: cls.curtail_length] + + if cls.regex: + if not cls.regex.match(value): + raise errors.StrRegexError(pattern=cls.regex.pattern) + + return value + + +def constr( + *, + strip_whitespace: bool = False, + to_upper: bool = False, + to_lower: bool = False, + strict: bool = False, + min_length: int = None, + max_length: int = None, + curtail_length: int = None, + regex: str = None, +) -> Type[str]: + # use kwargs then define conf in a dict to aid with IDE type hinting + namespace = dict( + strip_whitespace=strip_whitespace, + to_upper=to_upper, + to_lower=to_lower, + strict=strict, + min_length=min_length, + max_length=max_length, + curtail_length=curtail_length, + regex=regex and re.compile(regex), + ) + return _registered(type('ConstrainedStrValue', (ConstrainedStr,), namespace)) + + +if TYPE_CHECKING: + StrictStr = str +else: + + class StrictStr(ConstrainedStr): + strict = True + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SET TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +# This types superclass should be Set[T], but cython chokes on that... +class ConstrainedSet(set): # type: ignore + # Needed for pydantic to detect that this is a set + __origin__ = set + __args__: Set[Type[T]] # type: ignore + + min_items: Optional[int] = None + max_items: Optional[int] = None + item_type: Type[T] # type: ignore + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.set_length_validator + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none(field_schema, minItems=cls.min_items, maxItems=cls.max_items) + + @classmethod + def set_length_validator(cls, v: 'Optional[Set[T]]') -> 'Optional[Set[T]]': + if v is None: + return None + + v = set_validator(v) + v_len = len(v) + + if cls.min_items is not None and v_len < cls.min_items: + raise errors.SetMinLengthError(limit_value=cls.min_items) + + if cls.max_items is not None and v_len > cls.max_items: + raise errors.SetMaxLengthError(limit_value=cls.max_items) + + return v + + +def conset(item_type: Type[T], *, min_items: int = None, max_items: int = None) -> Type[Set[T]]: + # __args__ is needed to conform to typing generics api + namespace = {'min_items': min_items, 'max_items': max_items, 'item_type': item_type, '__args__': [item_type]} + # We use new_class to be able to deal with Generic types + return new_class('ConstrainedSetValue', (ConstrainedSet,), {}, lambda ns: ns.update(namespace)) + + +# This types superclass should be FrozenSet[T], but cython chokes on that... 
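+# Editor's note (illustrative sketch): the constr() and conset() factories above
+# compose these checks, e.g. constr(to_lower=True, regex='^[a-z]+$') lowercases
+# input before the regex check, and conset(int, min_items=1) rejects set().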
+class ConstrainedFrozenSet(frozenset): # type: ignore + # Needed for pydantic to detect that this is a set + __origin__ = frozenset + __args__: FrozenSet[Type[T]] # type: ignore + + min_items: Optional[int] = None + max_items: Optional[int] = None + item_type: Type[T] # type: ignore + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.frozenset_length_validator + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none(field_schema, minItems=cls.min_items, maxItems=cls.max_items) + + @classmethod + def frozenset_length_validator(cls, v: 'Optional[FrozenSet[T]]') -> 'Optional[FrozenSet[T]]': + if v is None: + return None + + v = frozenset_validator(v) + v_len = len(v) + + if cls.min_items is not None and v_len < cls.min_items: + raise errors.FrozenSetMinLengthError(limit_value=cls.min_items) + + if cls.max_items is not None and v_len > cls.max_items: + raise errors.FrozenSetMaxLengthError(limit_value=cls.max_items) + + return v + + +def confrozenset(item_type: Type[T], *, min_items: int = None, max_items: int = None) -> Type[FrozenSet[T]]: + # __args__ is needed to conform to typing generics api + namespace = {'min_items': min_items, 'max_items': max_items, 'item_type': item_type, '__args__': [item_type]} + # We use new_class to be able to deal with Generic types + return new_class('ConstrainedFrozenSetValue', (ConstrainedFrozenSet,), {}, lambda ns: ns.update(namespace)) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LIST TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +# This types superclass should be List[T], but cython chokes on that... +class ConstrainedList(list): # type: ignore + # Needed for pydantic to detect that this is a list + __origin__ = list + __args__: Tuple[Type[T], ...] # type: ignore + + min_items: Optional[int] = None + max_items: Optional[int] = None + unique_items: Optional[bool] = None + item_type: Type[T] # type: ignore + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.list_length_validator + if cls.unique_items: + yield cls.unique_items_validator + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none(field_schema, minItems=cls.min_items, maxItems=cls.max_items, uniqueItems=cls.unique_items) + + @classmethod + def list_length_validator(cls, v: 'Optional[List[T]]') -> 'Optional[List[T]]': + if v is None: + return None + + v = list_validator(v) + v_len = len(v) + + if cls.min_items is not None and v_len < cls.min_items: + raise errors.ListMinLengthError(limit_value=cls.min_items) + + if cls.max_items is not None and v_len > cls.max_items: + raise errors.ListMaxLengthError(limit_value=cls.max_items) + + return v + + @classmethod + def unique_items_validator(cls, v: 'List[T]') -> 'List[T]': + for i, value in enumerate(v, start=1): + if value in v[i:]: + raise errors.ListUniqueItemsError() + + return v + + +def conlist( + item_type: Type[T], *, min_items: int = None, max_items: int = None, unique_items: bool = None +) -> Type[List[T]]: + # __args__ is needed to conform to typing generics api + namespace = dict( + min_items=min_items, max_items=max_items, unique_items=unique_items, item_type=item_type, __args__=(item_type,) + ) + # We use new_class to be able to deal with Generic types + return new_class('ConstrainedListValue', (ConstrainedList,), {}, lambda ns: ns.update(namespace)) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PYOBJECT TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +if TYPE_CHECKING: + PyObject = Callable[..., Any] +else: + + 
+    class PyObject:
+        validate_always = True
+
+        @classmethod
+        def __get_validators__(cls) -> 'CallableGenerator':
+            yield cls.validate
+
+        @classmethod
+        def validate(cls, value: Any) -> Any:
+            if isinstance(value, Callable):
+                return value
+
+            try:
+                value = str_validator(value)
+            except errors.StrError:
+                raise errors.PyObjectError(error_message='value is neither a valid import path nor a valid callable')
+
+            try:
+                return import_string(value)
+            except ImportError as e:
+                raise errors.PyObjectError(error_message=str(e))
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DECIMAL TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+
+class ConstrainedDecimal(Decimal, metaclass=ConstrainedNumberMeta):
+    gt: OptionalIntFloatDecimal = None
+    ge: OptionalIntFloatDecimal = None
+    lt: OptionalIntFloatDecimal = None
+    le: OptionalIntFloatDecimal = None
+    max_digits: OptionalInt = None
+    decimal_places: OptionalInt = None
+    multiple_of: OptionalIntFloatDecimal = None
+
+    @classmethod
+    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
+        update_not_none(
+            field_schema,
+            exclusiveMinimum=cls.gt,
+            exclusiveMaximum=cls.lt,
+            minimum=cls.ge,
+            maximum=cls.le,
+            multipleOf=cls.multiple_of,
+        )
+
+    @classmethod
+    def __get_validators__(cls) -> 'CallableGenerator':
+        yield decimal_validator
+        yield number_size_validator
+        yield number_multiple_validator
+        yield cls.validate
+
+    @classmethod
+    def validate(cls, value: Decimal) -> Decimal:
+        digit_tuple, exponent = value.as_tuple()[1:]
+        if exponent in {'F', 'n', 'N'}:
+            raise errors.DecimalIsNotFiniteError()
+
+        if exponent >= 0:
+            # A positive exponent adds that many trailing zeros.
+            digits = len(digit_tuple) + exponent
+            decimals = 0
+        else:
+            # If the absolute value of the negative exponent is larger than the
+            # number of digits, then it's the same as the number of digits,
+            # because it'll consume all of the digits in digit_tuple and then
+            # add abs(exponent) - len(digit_tuple) leading zeros after the
+            # decimal point.
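+            # Illustrative examples: Decimal('0.30').as_tuple() gives
+            # digit_tuple=(3, 0), exponent=-2, so digits = 2 and decimals = 2;
+            # Decimal('0.001').as_tuple() gives digit_tuple=(1,), exponent=-3,
+            # where abs(-3) > len((1,)), so digits = decimals = 3.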
+            if abs(exponent) > len(digit_tuple):
+                digits = decimals = abs(exponent)
+            else:
+                digits = len(digit_tuple)
+                decimals = abs(exponent)
+        whole_digits = digits - decimals
+
+        if cls.max_digits is not None and digits > cls.max_digits:
+            raise errors.DecimalMaxDigitsError(max_digits=cls.max_digits)
+
+        if cls.decimal_places is not None and decimals > cls.decimal_places:
+            raise errors.DecimalMaxPlacesError(decimal_places=cls.decimal_places)
+
+        if cls.max_digits is not None and cls.decimal_places is not None:
+            expected = cls.max_digits - cls.decimal_places
+            if whole_digits > expected:
+                raise errors.DecimalWholeDigitsError(whole_digits=expected)
+
+        return value
+
+
+def condecimal(
+    *,
+    gt: Decimal = None,
+    ge: Decimal = None,
+    lt: Decimal = None,
+    le: Decimal = None,
+    max_digits: int = None,
+    decimal_places: int = None,
+    multiple_of: Decimal = None,
+) -> Type[Decimal]:
+    # use kwargs then define conf in a dict to aid with IDE type hinting
+    namespace = dict(
+        gt=gt, ge=ge, lt=lt, le=le, max_digits=max_digits, decimal_places=decimal_places, multiple_of=multiple_of
+    )
+    return type('ConstrainedDecimalValue', (ConstrainedDecimal,), namespace)
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ UUID TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+if TYPE_CHECKING:
+    UUID1 = UUID
+    UUID3 = UUID
+    UUID4 = UUID
+    UUID5 = UUID
+else:
+
+    class UUID1(UUID):
+        _required_version = 1
+
+        @classmethod
+        def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
+            field_schema.update(type='string', format=f'uuid{cls._required_version}')
+
+    class UUID3(UUID1):
+        _required_version = 3
+
+    class UUID4(UUID1):
+        _required_version = 4
+
+    class UUID5(UUID1):
+        _required_version = 5
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PATH TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+if TYPE_CHECKING:
+    FilePath = Path
+    DirectoryPath = Path
+else:
+
+    class FilePath(Path):
+        @classmethod
+        def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
+            field_schema.update(format='file-path')
+
+        @classmethod
+        def __get_validators__(cls) -> 'CallableGenerator':
+            yield path_validator
+            yield path_exists_validator
+            yield cls.validate
+
+        @classmethod
+        def validate(cls, value: Path) -> Path:
+            if not value.is_file():
+                raise errors.PathNotAFileError(path=value)
+
+            return value
+
+    class DirectoryPath(Path):
+        @classmethod
+        def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
+            field_schema.update(format='directory-path')
+
+        @classmethod
+        def __get_validators__(cls) -> 'CallableGenerator':
+            yield path_validator
+            yield path_exists_validator
+            yield cls.validate
+
+        @classmethod
+        def validate(cls, value: Path) -> Path:
+            if not value.is_dir():
+                raise errors.PathNotADirectoryError(path=value)
+
+            return value
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ JSON TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+
+class JsonWrapper:
+    pass
+
+
+class JsonMeta(type):
+    def __getitem__(self, t: Type[Any]) -> Type[JsonWrapper]:
+        if t is Any:
+            return Json  # allow Json[Any] to replicate plain Json
+        return _registered(type('JsonWrapperValue', (JsonWrapper,), {'inner_type': t}))
+
+
+if TYPE_CHECKING:
+    Json = Annotated[T, ...]
# Json[list[str]] will be recognized by type checkers as list[str] + +else: + + class Json(metaclass=JsonMeta): + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(type='string', format='json-string') + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SECRET TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +class SecretField(abc.ABC): + """ + Note: this should be implemented as a generic like `SecretField(ABC, Generic[T])`, + the `__init__()` should be part of the abstract class and the + `get_secret_value()` method should use the generic `T` type. + + However Cython doesn't support very well generics at the moment and + the generated code fails to be imported (see + https://github.com/cython/cython/issues/2753). + """ + + def __eq__(self, other: Any) -> bool: + return isinstance(other, self.__class__) and self.get_secret_value() == other.get_secret_value() + + def __str__(self) -> str: + return '**********' if self.get_secret_value() else '' + + def __hash__(self) -> int: + return hash(self.get_secret_value()) + + @abc.abstractmethod + def get_secret_value(self) -> Any: # pragma: no cover + ... + + +class SecretStr(SecretField): + min_length: OptionalInt = None + max_length: OptionalInt = None + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none( + field_schema, + type='string', + writeOnly=True, + format='password', + minLength=cls.min_length, + maxLength=cls.max_length, + ) + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.validate + yield constr_length_validator + + @classmethod + def validate(cls, value: Any) -> 'SecretStr': + if isinstance(value, cls): + return value + value = str_validator(value) + return cls(value) + + def __init__(self, value: str): + self._secret_value = value + + def __repr__(self) -> str: + return f"SecretStr('{self}')" + + def __len__(self) -> int: + return len(self._secret_value) + + def display(self) -> str: + warnings.warn('`secret_str.display()` is deprecated, use `str(secret_str)` instead', DeprecationWarning) + return str(self) + + def get_secret_value(self) -> str: + return self._secret_value + + +class SecretBytes(SecretField): + min_length: OptionalInt = None + max_length: OptionalInt = None + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none( + field_schema, + type='string', + writeOnly=True, + format='password', + minLength=cls.min_length, + maxLength=cls.max_length, + ) + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.validate + yield constr_length_validator + + @classmethod + def validate(cls, value: Any) -> 'SecretBytes': + if isinstance(value, cls): + return value + value = bytes_validator(value) + return cls(value) + + def __init__(self, value: bytes): + self._secret_value = value + + def __repr__(self) -> str: + return f"SecretBytes(b'{self}')" + + def __len__(self) -> int: + return len(self._secret_value) + + def display(self) -> str: + warnings.warn('`secret_bytes.display()` is deprecated, use `str(secret_bytes)` instead', DeprecationWarning) + return str(self) + + def get_secret_value(self) -> bytes: + return self._secret_value + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PAYMENT CARD TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +class PaymentCardBrand(str, Enum): + # If you add another card type, please also add it to the + # Hypothesis strategy in `pydantic._hypothesis_plugin`. 
+ amex = 'American Express' + mastercard = 'Mastercard' + visa = 'Visa' + other = 'other' + + def __str__(self) -> str: + return self.value + + +class PaymentCardNumber(str): + """ + Based on: https://en.wikipedia.org/wiki/Payment_card_number + """ + + strip_whitespace: ClassVar[bool] = True + min_length: ClassVar[int] = 12 + max_length: ClassVar[int] = 19 + bin: str + last4: str + brand: PaymentCardBrand + + def __init__(self, card_number: str): + self.bin = card_number[:6] + self.last4 = card_number[-4:] + self.brand = self._get_brand(card_number) + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield str_validator + yield constr_strip_whitespace + yield constr_length_validator + yield cls.validate_digits + yield cls.validate_luhn_check_digit + yield cls + yield cls.validate_length_for_brand + + @property + def masked(self) -> str: + num_masked = len(self) - 10 # len(bin) + len(last4) == 10 + return f'{self.bin}{"*" * num_masked}{self.last4}' + + @classmethod + def validate_digits(cls, card_number: str) -> str: + if not card_number.isdigit(): + raise errors.NotDigitError + return card_number + + @classmethod + def validate_luhn_check_digit(cls, card_number: str) -> str: + """ + Based on: https://en.wikipedia.org/wiki/Luhn_algorithm + """ + sum_ = int(card_number[-1]) + length = len(card_number) + parity = length % 2 + for i in range(length - 1): + digit = int(card_number[i]) + if i % 2 == parity: + digit *= 2 + if digit > 9: + digit -= 9 + sum_ += digit + valid = sum_ % 10 == 0 + if not valid: + raise errors.LuhnValidationError + return card_number + + @classmethod + def validate_length_for_brand(cls, card_number: 'PaymentCardNumber') -> 'PaymentCardNumber': + """ + Validate length based on BIN for major brands: + https://en.wikipedia.org/wiki/Payment_card_number#Issuer_identification_number_(IIN) + """ + required_length: Union[None, int, str] = None + if card_number.brand in PaymentCardBrand.mastercard: + required_length = 16 + valid = len(card_number) == required_length + elif card_number.brand == PaymentCardBrand.visa: + required_length = '13, 16 or 19' + valid = len(card_number) in {13, 16, 19} + elif card_number.brand == PaymentCardBrand.amex: + required_length = 15 + valid = len(card_number) == required_length + else: + valid = True + if not valid: + raise errors.InvalidLengthForBrand(brand=card_number.brand, required_length=required_length) + return card_number + + @staticmethod + def _get_brand(card_number: str) -> PaymentCardBrand: + if card_number[0] == '4': + brand = PaymentCardBrand.visa + elif 51 <= int(card_number[:2]) <= 55: + brand = PaymentCardBrand.mastercard + elif card_number[:2] in {'34', '37'}: + brand = PaymentCardBrand.amex + else: + brand = PaymentCardBrand.other + return brand + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BYTE SIZE TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +BYTE_SIZES = { + 'b': 1, + 'kb': 10**3, + 'mb': 10**6, + 'gb': 10**9, + 'tb': 10**12, + 'pb': 10**15, + 'eb': 10**18, + 'kib': 2**10, + 'mib': 2**20, + 'gib': 2**30, + 'tib': 2**40, + 'pib': 2**50, + 'eib': 2**60, +} +BYTE_SIZES.update({k.lower()[0]: v for k, v in BYTE_SIZES.items() if 'i' not in k}) +byte_string_re = re.compile(r'^\s*(\d*\.?\d+)\s*(\w+)?', re.IGNORECASE) + + +class ByteSize(int): + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.validate + + @classmethod + def validate(cls, v: StrIntFloat) -> 'ByteSize': + + try: + return cls(int(v)) + except ValueError: + pass + + str_match = byte_string_re.match(str(v)) + if str_match is None: + 
raise errors.InvalidByteSize() + + scalar, unit = str_match.groups() + if unit is None: + unit = 'b' + + try: + unit_mult = BYTE_SIZES[unit.lower()] + except KeyError: + raise errors.InvalidByteSizeUnit(unit=unit) + + return cls(int(float(scalar) * unit_mult)) + + def human_readable(self, decimal: bool = False) -> str: + + if decimal: + divisor = 1000 + units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB'] + final_unit = 'EB' + else: + divisor = 1024 + units = ['B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB'] + final_unit = 'EiB' + + num = float(self) + for unit in units: + if abs(num) < divisor: + return f'{num:0.1f}{unit}' + num /= divisor + + return f'{num:0.1f}{final_unit}' + + def to(self, unit: str) -> float: + + try: + unit_div = BYTE_SIZES[unit.lower()] + except KeyError: + raise errors.InvalidByteSizeUnit(unit=unit) + + return self / unit_div + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DATE TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +if TYPE_CHECKING: + PastDate = date + FutureDate = date +else: + + class PastDate(date): + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield parse_date + yield cls.validate + + @classmethod + def validate(cls, value: date) -> date: + if value >= date.today(): + raise errors.DateNotInThePastError() + + return value + + class FutureDate(date): + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield parse_date + yield cls.validate + + @classmethod + def validate(cls, value: date) -> date: + if value <= date.today(): + raise errors.DateNotInTheFutureError() + + return value + + +class ConstrainedDate(date, metaclass=ConstrainedNumberMeta): + gt: OptionalDate = None + ge: OptionalDate = None + lt: OptionalDate = None + le: OptionalDate = None + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none(field_schema, exclusiveMinimum=cls.gt, exclusiveMaximum=cls.lt, minimum=cls.ge, maximum=cls.le) + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield parse_date + yield number_size_validator + + +def condate( + *, + gt: date = None, + ge: date = None, + lt: date = None, + le: date = None, +) -> Type[date]: + # use kwargs then define conf in a dict to aid with IDE type hinting + namespace = dict(gt=gt, ge=ge, lt=lt, le=le) + return type('ConstrainedDateValue', (ConstrainedDate,), namespace) diff --git a/libs/win/pydantic/typing.cp37-win_amd64.pyd b/libs/win/pydantic/typing.cp37-win_amd64.pyd new file mode 100644 index 00000000..d825e1a5 Binary files /dev/null and b/libs/win/pydantic/typing.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/typing.py b/libs/win/pydantic/typing.py new file mode 100644 index 00000000..5ccf266c --- /dev/null +++ b/libs/win/pydantic/typing.py @@ -0,0 +1,602 @@ +import sys +from collections.abc import Callable +from os import PathLike +from typing import ( # type: ignore + TYPE_CHECKING, + AbstractSet, + Any, + Callable as TypingCallable, + ClassVar, + Dict, + ForwardRef, + Generator, + Iterable, + List, + Mapping, + NewType, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + _eval_type, + cast, + get_type_hints, +) + +from typing_extensions import ( + Annotated, + Final, + Literal, + NotRequired as TypedDictNotRequired, + Required as TypedDictRequired, +) + +try: + from typing import _TypingBase as typing_base # type: ignore +except ImportError: + from typing import _Final as typing_base # type: ignore + +try: + from typing import GenericAlias as TypingGenericAlias # type: ignore +except ImportError: + # python < 3.9 
does not have GenericAlias (list[int], tuple[str, ...] and so on) + TypingGenericAlias = () + +try: + from types import UnionType as TypesUnionType # type: ignore +except ImportError: + # python < 3.10 does not have UnionType (str | int, byte | bool and so on) + TypesUnionType = () + + +if sys.version_info < (3, 9): + + def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any: + return type_._evaluate(globalns, localns) + +else: + + def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any: + # Even though it is the right signature for python 3.9, mypy complains with + # `error: Too many arguments for "_evaluate" of "ForwardRef"` hence the cast... + return cast(Any, type_)._evaluate(globalns, localns, set()) + + +if sys.version_info < (3, 9): + # Ensure we always get all the whole `Annotated` hint, not just the annotated type. + # For 3.7 to 3.8, `get_type_hints` doesn't recognize `typing_extensions.Annotated`, + # so it already returns the full annotation + get_all_type_hints = get_type_hints + +else: + + def get_all_type_hints(obj: Any, globalns: Any = None, localns: Any = None) -> Any: + return get_type_hints(obj, globalns, localns, include_extras=True) + + +_T = TypeVar('_T') + +AnyCallable = TypingCallable[..., Any] +NoArgAnyCallable = TypingCallable[[], Any] + +# workaround for https://github.com/python/mypy/issues/9496 +AnyArgTCallable = TypingCallable[..., _T] + + +# Annotated[...] is implemented by returning an instance of one of these classes, depending on +# python/typing_extensions version. +AnnotatedTypeNames = {'AnnotatedMeta', '_AnnotatedAlias'} + + +if sys.version_info < (3, 8): + + def get_origin(t: Type[Any]) -> Optional[Type[Any]]: + if type(t).__name__ in AnnotatedTypeNames: + # weirdly this is a runtime requirement, as well as for mypy + return cast(Type[Any], Annotated) + return getattr(t, '__origin__', None) + +else: + from typing import get_origin as _typing_get_origin + + def get_origin(tp: Type[Any]) -> Optional[Type[Any]]: + """ + We can't directly use `typing.get_origin` since we need a fallback to support + custom generic classes like `ConstrainedList` + It should be useless once https://github.com/cython/cython/issues/3537 is + solved and https://github.com/pydantic/pydantic/pull/1753 is merged. + """ + if type(tp).__name__ in AnnotatedTypeNames: + return cast(Type[Any], Annotated) # mypy complains about _SpecialForm + return _typing_get_origin(tp) or getattr(tp, '__origin__', None) + + +if sys.version_info < (3, 8): + from typing import _GenericAlias + + def get_args(t: Type[Any]) -> Tuple[Any, ...]: + """Compatibility version of get_args for python 3.7. + + Mostly compatible with the python 3.8 `typing` module version + and able to handle almost all use cases. + """ + if type(t).__name__ in AnnotatedTypeNames: + return t.__args__ + t.__metadata__ + if isinstance(t, _GenericAlias): + res = t.__args__ + if t.__origin__ is Callable and res and res[0] is not Ellipsis: + res = (list(res[:-1]), res[-1]) + return res + return getattr(t, '__args__', ()) + +else: + from typing import get_args as _typing_get_args + + def _generic_get_args(tp: Type[Any]) -> Tuple[Any, ...]: + """ + In python 3.9, `typing.Dict`, `typing.List`, ... + do have an empty `__args__` by default (instead of the generic ~T for example). + In order to still support `Dict` for example and consider it as `Dict[Any, Any]`, + we retrieve the `_nparams` value that tells us how many parameters it needs. 
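+
+        Illustrative examples::
+            _generic_get_args(Dict) == (Any, Any)
+            _generic_get_args(List) == (Any,)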
+ """ + if hasattr(tp, '_nparams'): + return (Any,) * tp._nparams + # Special case for `tuple[()]`, which used to return ((),) with `typing.Tuple` + # in python 3.10- but now returns () for `tuple` and `Tuple`. + # This will probably be clarified in pydantic v2 + try: + if tp == Tuple[()] or sys.version_info >= (3, 9) and tp == tuple[()]: # type: ignore[misc] + return ((),) + # there is a TypeError when compiled with cython + except TypeError: # pragma: no cover + pass + return () + + def get_args(tp: Type[Any]) -> Tuple[Any, ...]: + """Get type arguments with all substitutions performed. + + For unions, basic simplifications used by Union constructor are performed. + Examples:: + get_args(Dict[str, int]) == (str, int) + get_args(int) == () + get_args(Union[int, Union[T, int], str][int]) == (int, str) + get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int]) + get_args(Callable[[], T][int]) == ([], int) + """ + if type(tp).__name__ in AnnotatedTypeNames: + return tp.__args__ + tp.__metadata__ + # the fallback is needed for the same reasons as `get_origin` (see above) + return _typing_get_args(tp) or getattr(tp, '__args__', ()) or _generic_get_args(tp) + + +if sys.version_info < (3, 9): + + def convert_generics(tp: Type[Any]) -> Type[Any]: + """Python 3.9 and older only supports generics from `typing` module. + They convert strings to ForwardRef automatically. + + Examples:: + typing.List['Hero'] == typing.List[ForwardRef('Hero')] + """ + return tp + +else: + from typing import _UnionGenericAlias # type: ignore + + from typing_extensions import _AnnotatedAlias + + def convert_generics(tp: Type[Any]) -> Type[Any]: + """ + Recursively searches for `str` type hints and replaces them with ForwardRef. + + Examples:: + convert_generics(list['Hero']) == list[ForwardRef('Hero')] + convert_generics(dict['Hero', 'Team']) == dict[ForwardRef('Hero'), ForwardRef('Team')] + convert_generics(typing.Dict['Hero', 'Team']) == typing.Dict[ForwardRef('Hero'), ForwardRef('Team')] + convert_generics(list[str | 'Hero'] | int) == list[str | ForwardRef('Hero')] | int + """ + origin = get_origin(tp) + if not origin or not hasattr(tp, '__args__'): + return tp + + args = get_args(tp) + + # typing.Annotated needs special treatment + if origin is Annotated: + return _AnnotatedAlias(convert_generics(args[0]), args[1:]) + + # recursively replace `str` instances inside of `GenericAlias` with `ForwardRef(arg)` + converted = tuple( + ForwardRef(arg) if isinstance(arg, str) and isinstance(tp, TypingGenericAlias) else convert_generics(arg) + for arg in args + ) + + if converted == args: + return tp + elif isinstance(tp, TypingGenericAlias): + return TypingGenericAlias(origin, converted) + elif isinstance(tp, TypesUnionType): + # recreate types.UnionType (PEP604, Python >= 3.10) + return _UnionGenericAlias(origin, converted) + else: + try: + setattr(tp, '__args__', converted) + except AttributeError: + pass + return tp + + +if sys.version_info < (3, 10): + + def is_union(tp: Optional[Type[Any]]) -> bool: + return tp is Union + + WithArgsTypes = (TypingGenericAlias,) + +else: + import types + import typing + + def is_union(tp: Optional[Type[Any]]) -> bool: + return tp is Union or tp is types.UnionType # noqa: E721 + + WithArgsTypes = (typing._GenericAlias, types.GenericAlias, types.UnionType) + + +if sys.version_info < (3, 9): + StrPath = Union[str, PathLike] +else: + StrPath = Union[str, PathLike] + # TODO: Once we switch to Cython 3 to handle generics properly + # (https://github.com/cython/cython/issues/2753), use 
following lines instead
+    #  of the one above
+    # # os.PathLike only becomes subscriptable from Python 3.9 onwards
+    # StrPath = Union[str, PathLike[str]]
+
+
+if TYPE_CHECKING:
+    from .fields import ModelField
+
+    TupleGenerator = Generator[Tuple[str, Any], None, None]
+    DictStrAny = Dict[str, Any]
+    DictAny = Dict[Any, Any]
+    SetStr = Set[str]
+    ListStr = List[str]
+    IntStr = Union[int, str]
+    AbstractSetIntStr = AbstractSet[IntStr]
+    DictIntStrAny = Dict[IntStr, Any]
+    MappingIntStrAny = Mapping[IntStr, Any]
+    CallableGenerator = Generator[AnyCallable, None, None]
+    ReprArgs = Sequence[Tuple[Optional[str], Any]]
+    AnyClassMethod = classmethod[Any]
+
+__all__ = (
+    'AnyCallable',
+    'NoArgAnyCallable',
+    'NoneType',
+    'is_none_type',
+    'display_as_type',
+    'resolve_annotations',
+    'is_callable_type',
+    'is_literal_type',
+    'all_literal_values',
+    'is_namedtuple',
+    'is_typeddict',
+    'is_typeddict_special',
+    'is_new_type',
+    'new_type_supertype',
+    'is_classvar',
+    'is_finalvar',
+    'update_field_forward_refs',
+    'update_model_forward_refs',
+    'TupleGenerator',
+    'DictStrAny',
+    'DictAny',
+    'SetStr',
+    'ListStr',
+    'IntStr',
+    'AbstractSetIntStr',
+    'DictIntStrAny',
+    'CallableGenerator',
+    'ReprArgs',
+    'AnyClassMethod',
+    'WithArgsTypes',
+    'get_args',
+    'get_origin',
+    'get_sub_types',
+    'typing_base',
+    'get_all_type_hints',
+    'is_union',
+    'StrPath',
+    'MappingIntStrAny',
+)
+
+
+NoneType = None.__class__
+
+
+NONE_TYPES: Tuple[Any, Any, Any] = (None, NoneType, Literal[None])
+
+
+if sys.version_info < (3, 8):
+    # Even though this implementation is slower, we need it for python 3.7:
+    # In python 3.7 "Literal" is not a builtin type and uses a different
+    # mechanism.
+    # For this reason `Literal[None] is Literal[None]` evaluates to `False`,
+    # breaking the faster implementation used for the other python versions.
+
+    def is_none_type(type_: Any) -> bool:
+        return type_ in NONE_TYPES
+
+elif sys.version_info[:2] == (3, 8):
+
+    def is_none_type(type_: Any) -> bool:
+        for none_type in NONE_TYPES:
+            if type_ is none_type:
+                return True
+        # With python 3.8, specifically 3.8.10, Literal "is" checks are very flaky
+        # and can change on very subtle changes like the use of types in other modules;
+        # hopefully this check avoids that issue.
+        if is_literal_type(type_):  # pragma: no cover
+            return all_literal_values(type_) == (None,)
+        return False
+
+else:
+
+    def is_none_type(type_: Any) -> bool:
+        for none_type in NONE_TYPES:
+            if type_ is none_type:
+                return True
+        return False
+
+
+def display_as_type(v: Type[Any]) -> str:
+    if not isinstance(v, typing_base) and not isinstance(v, WithArgsTypes) and not isinstance(v, type):
+        v = v.__class__
+
+    if is_union(get_origin(v)):
+        return f'Union[{", ".join(map(display_as_type, get_args(v)))}]'
+
+    if isinstance(v, WithArgsTypes):
+        # Generic aliases are constructs like `list[int]`
+        return str(v).replace('typing.', '')
+
+    try:
+        return v.__name__
+    except AttributeError:
+        # happens with typing objects
+        return str(v).replace('typing.', '')
+
+
+def resolve_annotations(raw_annotations: Dict[str, Type[Any]], module_name: Optional[str]) -> Dict[str, Type[Any]]:
+    """
+    Partially taken from typing.get_type_hints.
+
+    Resolve string or ForwardRef annotations into type objects if possible.
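+
+    Illustrative example (hypothetical input): with ``raw_annotations={'x': 'List[int]'}``
+    and a module whose globals define ``List``, the result maps ``'x'`` to
+    ``typing.List[int]``; names that cannot be resolved yet are left as ``ForwardRef``.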
+ """ + base_globals: Optional[Dict[str, Any]] = None + if module_name: + try: + module = sys.modules[module_name] + except KeyError: + # happens occasionally, see https://github.com/pydantic/pydantic/issues/2363 + pass + else: + base_globals = module.__dict__ + + annotations = {} + for name, value in raw_annotations.items(): + if isinstance(value, str): + if (3, 10) > sys.version_info >= (3, 9, 8) or sys.version_info >= (3, 10, 1): + value = ForwardRef(value, is_argument=False, is_class=True) + else: + value = ForwardRef(value, is_argument=False) + try: + value = _eval_type(value, base_globals, None) + except NameError: + # this is ok, it can be fixed with update_forward_refs + pass + annotations[name] = value + return annotations + + +def is_callable_type(type_: Type[Any]) -> bool: + return type_ is Callable or get_origin(type_) is Callable + + +def is_literal_type(type_: Type[Any]) -> bool: + return Literal is not None and get_origin(type_) is Literal + + +def literal_values(type_: Type[Any]) -> Tuple[Any, ...]: + return get_args(type_) + + +def all_literal_values(type_: Type[Any]) -> Tuple[Any, ...]: + """ + This method is used to retrieve all Literal values as + Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586) + e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]` + """ + if not is_literal_type(type_): + return (type_,) + + values = literal_values(type_) + return tuple(x for value in values for x in all_literal_values(value)) + + +def is_namedtuple(type_: Type[Any]) -> bool: + """ + Check if a given class is a named tuple. + It can be either a `typing.NamedTuple` or `collections.namedtuple` + """ + from .utils import lenient_issubclass + + return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields') + + +def is_typeddict(type_: Type[Any]) -> bool: + """ + Check if a given class is a typed dict (from `typing` or `typing_extensions`) + In 3.10, there will be a public method (https://docs.python.org/3.10/library/typing.html#typing.is_typeddict) + """ + from .utils import lenient_issubclass + + return lenient_issubclass(type_, dict) and hasattr(type_, '__total__') + + +def _check_typeddict_special(type_: Any) -> bool: + return type_ is TypedDictRequired or type_ is TypedDictNotRequired + + +def is_typeddict_special(type_: Any) -> bool: + """ + Check if type is a TypedDict special form (Required or NotRequired). + """ + return _check_typeddict_special(type_) or _check_typeddict_special(get_origin(type_)) + + +test_type = NewType('test_type', str) + + +def is_new_type(type_: Type[Any]) -> bool: + """ + Check whether type_ was created using typing.NewType + """ + return isinstance(type_, test_type.__class__) and hasattr(type_, '__supertype__') # type: ignore + + +def new_type_supertype(type_: Type[Any]) -> Type[Any]: + while hasattr(type_, '__supertype__'): + type_ = type_.__supertype__ + return type_ + + +def _check_classvar(v: Optional[Type[Any]]) -> bool: + if v is None: + return False + + return v.__class__ == ClassVar.__class__ and getattr(v, '_name', None) == 'ClassVar' + + +def _check_finalvar(v: Optional[Type[Any]]) -> bool: + """ + Check if a given type is a `typing.Final` type. 
+ """ + if v is None: + return False + + return v.__class__ == Final.__class__ and (sys.version_info < (3, 8) or getattr(v, '_name', None) == 'Final') + + +def is_classvar(ann_type: Type[Any]) -> bool: + if _check_classvar(ann_type) or _check_classvar(get_origin(ann_type)): + return True + + # this is an ugly workaround for class vars that contain forward references and are therefore themselves + # forward references, see #3679 + if ann_type.__class__ == ForwardRef and ann_type.__forward_arg__.startswith('ClassVar['): + return True + + return False + + +def is_finalvar(ann_type: Type[Any]) -> bool: + return _check_finalvar(ann_type) or _check_finalvar(get_origin(ann_type)) + + +def update_field_forward_refs(field: 'ModelField', globalns: Any, localns: Any) -> None: + """ + Try to update ForwardRefs on fields based on this ModelField, globalns and localns. + """ + prepare = False + if field.type_.__class__ == ForwardRef: + prepare = True + field.type_ = evaluate_forwardref(field.type_, globalns, localns or None) + if field.outer_type_.__class__ == ForwardRef: + prepare = True + field.outer_type_ = evaluate_forwardref(field.outer_type_, globalns, localns or None) + if prepare: + field.prepare() + + if field.sub_fields: + for sub_f in field.sub_fields: + update_field_forward_refs(sub_f, globalns=globalns, localns=localns) + + if field.discriminator_key is not None: + field.prepare_discriminated_union_sub_fields() + + +def update_model_forward_refs( + model: Type[Any], + fields: Iterable['ModelField'], + json_encoders: Dict[Union[Type[Any], str, ForwardRef], AnyCallable], + localns: 'DictStrAny', + exc_to_suppress: Tuple[Type[BaseException], ...] = (), +) -> None: + """ + Try to update model fields ForwardRefs based on model and localns. + """ + if model.__module__ in sys.modules: + globalns = sys.modules[model.__module__].__dict__.copy() + else: + globalns = {} + + globalns.setdefault(model.__name__, model) + + for f in fields: + try: + update_field_forward_refs(f, globalns=globalns, localns=localns) + except exc_to_suppress: + pass + + for key in set(json_encoders.keys()): + if isinstance(key, str): + fr: ForwardRef = ForwardRef(key) + elif isinstance(key, ForwardRef): + fr = key + else: + continue + + try: + new_key = evaluate_forwardref(fr, globalns, localns or None) + except exc_to_suppress: # pragma: no cover + continue + + json_encoders[new_key] = json_encoders.pop(key) + + +def get_class(type_: Type[Any]) -> Union[None, bool, Type[Any]]: + """ + Tries to get the class of a Type[T] annotation. Returns True if Type is used + without brackets. Otherwise returns None. 
+ """ + if type_ is type: + return True + + if get_origin(type_) is None: + return None + + args = get_args(type_) + if not args or not isinstance(args[0], type): + return True + else: + return args[0] + + +def get_sub_types(tp: Any) -> List[Any]: + """ + Return all the types that are allowed by type `tp` + `tp` can be a `Union` of allowed types or an `Annotated` type + """ + origin = get_origin(tp) + if origin is Annotated: + return get_sub_types(get_args(tp)[0]) + elif is_union(origin): + return [x for t in get_args(tp) for x in get_sub_types(t)] + else: + return [tp] diff --git a/libs/win/pydantic/utils.cp37-win_amd64.pyd b/libs/win/pydantic/utils.cp37-win_amd64.pyd new file mode 100644 index 00000000..e3bfb22c Binary files /dev/null and b/libs/win/pydantic/utils.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/utils.py b/libs/win/pydantic/utils.py new file mode 100644 index 00000000..1d016c0e --- /dev/null +++ b/libs/win/pydantic/utils.py @@ -0,0 +1,841 @@ +import keyword +import warnings +import weakref +from collections import OrderedDict, defaultdict, deque +from copy import deepcopy +from itertools import islice, zip_longest +from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType +from typing import ( + TYPE_CHECKING, + AbstractSet, + Any, + Callable, + Collection, + Dict, + Generator, + Iterable, + Iterator, + List, + Mapping, + MutableMapping, + NoReturn, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +from typing_extensions import Annotated + +from .errors import ConfigError +from .typing import ( + NoneType, + WithArgsTypes, + all_literal_values, + display_as_type, + get_args, + get_origin, + is_literal_type, + is_union, +) +from .version import version_info + +if TYPE_CHECKING: + from inspect import Signature + from pathlib import Path + + from .config import BaseConfig + from .dataclasses import Dataclass + from .fields import ModelField + from .main import BaseModel + from .typing import AbstractSetIntStr, DictIntStrAny, IntStr, MappingIntStrAny, ReprArgs + + RichReprResult = Iterable[Union[Any, Tuple[Any], Tuple[str, Any], Tuple[str, Any, Any]]] + +__all__ = ( + 'import_string', + 'sequence_like', + 'validate_field_name', + 'lenient_isinstance', + 'lenient_issubclass', + 'in_ipython', + 'is_valid_identifier', + 'deep_update', + 'update_not_none', + 'almost_equal_floats', + 'get_model', + 'to_camel', + 'is_valid_field', + 'smart_deepcopy', + 'PyObjectStr', + 'Representation', + 'GetterDict', + 'ValueItems', + 'version_info', # required here to match behaviour in v1.3 + 'ClassAttribute', + 'path_type', + 'ROOT_KEY', + 'get_unique_discriminator_alias', + 'get_discriminator_alias_and_values', + 'DUNDER_ATTRIBUTES', + 'LimitedDict', +) + +ROOT_KEY = '__root__' +# these are types that are returned unchanged by deepcopy +IMMUTABLE_NON_COLLECTIONS_TYPES: Set[Type[Any]] = { + int, + float, + complex, + str, + bool, + bytes, + type, + NoneType, + FunctionType, + BuiltinFunctionType, + LambdaType, + weakref.ref, + CodeType, + # note: including ModuleType will differ from behaviour of deepcopy by not producing error. 
+ # It might be not a good idea in general, but considering that this function used only internally + # against default values of fields, this will allow to actually have a field with module as default value + ModuleType, + NotImplemented.__class__, + Ellipsis.__class__, +} + +# these are types that if empty, might be copied with simple copy() instead of deepcopy() +BUILTIN_COLLECTIONS: Set[Type[Any]] = { + list, + set, + tuple, + frozenset, + dict, + OrderedDict, + defaultdict, + deque, +} + + +def import_string(dotted_path: str) -> Any: + """ + Stolen approximately from django. Import a dotted module path and return the attribute/class designated by the + last name in the path. Raise ImportError if the import fails. + """ + from importlib import import_module + + try: + module_path, class_name = dotted_path.strip(' ').rsplit('.', 1) + except ValueError as e: + raise ImportError(f'"{dotted_path}" doesn\'t look like a module path') from e + + module = import_module(module_path) + try: + return getattr(module, class_name) + except AttributeError as e: + raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute') from e + + +def truncate(v: Union[str], *, max_len: int = 80) -> str: + """ + Truncate a value and add a unicode ellipsis (three dots) to the end if it was too long + """ + warnings.warn('`truncate` is no-longer used by pydantic and is deprecated', DeprecationWarning) + if isinstance(v, str) and len(v) > (max_len - 2): + # -3 so quote + string + … + quote has correct length + return (v[: (max_len - 3)] + '…').__repr__() + try: + v = v.__repr__() + except TypeError: + v = v.__class__.__repr__(v) # in case v is a type + if len(v) > max_len: + v = v[: max_len - 1] + '…' + return v + + +def sequence_like(v: Any) -> bool: + return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque)) + + +def validate_field_name(bases: List[Type['BaseModel']], field_name: str) -> None: + """ + Ensure that the field's name does not shadow an existing attribute of the model. + """ + for base in bases: + if getattr(base, field_name, None): + raise NameError( + f'Field name "{field_name}" shadows a BaseModel attribute; ' + f'use a different field name with "alias=\'{field_name}\'".' + ) + + +def lenient_isinstance(o: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool: + try: + return isinstance(o, class_or_tuple) # type: ignore[arg-type] + except TypeError: + return False + + +def lenient_issubclass(cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool: + try: + return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type] + except TypeError: + if isinstance(cls, WithArgsTypes): + return False + raise # pragma: no cover + + +def in_ipython() -> bool: + """ + Check whether we're in an ipython environment, including jupyter notebooks. + """ + try: + eval('__IPYTHON__') + except NameError: + return False + else: # pragma: no cover + return True + + +def is_valid_identifier(identifier: str) -> bool: + """ + Checks that a string is a valid identifier and not a Python keyword. + :param identifier: The identifier to test. + :return: True if the identifier is valid. 
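+
+    Illustrative examples: ``is_valid_identifier('foo')`` is ``True``, while
+    ``is_valid_identifier('class')`` (a keyword) and ``is_valid_identifier('foo-bar')``
+    are ``False``.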
+ """ + return identifier.isidentifier() and not keyword.iskeyword(identifier) + + +KeyType = TypeVar('KeyType') + + +def deep_update(mapping: Dict[KeyType, Any], *updating_mappings: Dict[KeyType, Any]) -> Dict[KeyType, Any]: + updated_mapping = mapping.copy() + for updating_mapping in updating_mappings: + for k, v in updating_mapping.items(): + if k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict): + updated_mapping[k] = deep_update(updated_mapping[k], v) + else: + updated_mapping[k] = v + return updated_mapping + + +def update_not_none(mapping: Dict[Any, Any], **update: Any) -> None: + mapping.update({k: v for k, v in update.items() if v is not None}) + + +def almost_equal_floats(value_1: float, value_2: float, *, delta: float = 1e-8) -> bool: + """ + Return True if two floats are almost equal + """ + return abs(value_1 - value_2) <= delta + + +def generate_model_signature( + init: Callable[..., None], fields: Dict[str, 'ModelField'], config: Type['BaseConfig'] +) -> 'Signature': + """ + Generate signature for model based on its fields + """ + from inspect import Parameter, Signature, signature + + from .config import Extra + + present_params = signature(init).parameters.values() + merged_params: Dict[str, Parameter] = {} + var_kw = None + use_var_kw = False + + for param in islice(present_params, 1, None): # skip self arg + if param.kind is param.VAR_KEYWORD: + var_kw = param + continue + merged_params[param.name] = param + + if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through + allow_names = config.allow_population_by_field_name + for field_name, field in fields.items(): + param_name = field.alias + if field_name in merged_params or param_name in merged_params: + continue + elif not is_valid_identifier(param_name): + if allow_names and is_valid_identifier(field_name): + param_name = field_name + else: + use_var_kw = True + continue + + # TODO: replace annotation with actual expected types once #1055 solved + kwargs = {'default': field.default} if not field.required else {} + merged_params[param_name] = Parameter( + param_name, Parameter.KEYWORD_ONLY, annotation=field.annotation, **kwargs + ) + + if config.extra is Extra.allow: + use_var_kw = True + + if var_kw and use_var_kw: + # Make sure the parameter for extra kwargs + # does not have the same name as a field + default_model_signature = [ + ('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD), + ('data', Parameter.VAR_KEYWORD), + ] + if [(p.name, p.kind) for p in present_params] == default_model_signature: + # if this is the standard model signature, use extra_data as the extra args name + var_kw_name = 'extra_data' + else: + # else start from var_kw + var_kw_name = var_kw.name + + # generate a name that's definitely unique + while var_kw_name in fields: + var_kw_name += '_' + merged_params[var_kw_name] = var_kw.replace(name=var_kw_name) + + return Signature(parameters=list(merged_params.values()), return_annotation=None) + + +def get_model(obj: Union[Type['BaseModel'], Type['Dataclass']]) -> Type['BaseModel']: + from .main import BaseModel + + try: + model_cls = obj.__pydantic_model__ # type: ignore + except AttributeError: + model_cls = obj + + if not issubclass(model_cls, BaseModel): + raise TypeError('Unsupported type, must be either BaseModel or dataclass') + return model_cls + + +def to_camel(string: str) -> str: + return ''.join(word.capitalize() for word in string.split('_')) + + +def to_lower_camel(string: str) -> str: + if len(string) 
>= 1:
+        pascal_string = to_camel(string)
+        return pascal_string[0].lower() + pascal_string[1:]
+    return string.lower()
+
+
+T = TypeVar('T')
+
+
+def unique_list(
+    input_list: Union[List[T], Tuple[T, ...]],
+    *,
+    name_factory: Callable[[T], str] = str,
+) -> List[T]:
+    """
+    Make a list unique while maintaining order.
+    We update the list if another one with the same name is set
+    (e.g. root validator overridden in subclass)
+    """
+    result: List[T] = []
+    result_names: List[str] = []
+    for v in input_list:
+        v_name = name_factory(v)
+        if v_name not in result_names:
+            result_names.append(v_name)
+            result.append(v)
+        else:
+            result[result_names.index(v_name)] = v
+
+    return result
+
+
+class PyObjectStr(str):
+    """
+    String class where repr doesn't include quotes. Useful with Representation when you want to return a string
+    representation of something that is valid (or pseudo-valid) python.
+    """
+
+    def __repr__(self) -> str:
+        return str(self)
+
+
+class Representation:
+    """
+    Mixin to provide __str__, __repr__, and __pretty__ methods. See #884 for more details.
+
+    __pretty__ is used by [devtools](https://python-devtools.helpmanual.io/) to provide human readable representations
+    of objects.
+    """
+
+    __slots__: Tuple[str, ...] = tuple()
+
+    def __repr_args__(self) -> 'ReprArgs':
+        """
+        Returns the attributes to show in __str__, __repr__, and __pretty__; this is generally overridden.
+
+        Can either return:
+        * name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]`
+        * or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]`
+        """
+        attrs = ((s, getattr(self, s)) for s in self.__slots__)
+        return [(a, v) for a, v in attrs if v is not None]
+
+    def __repr_name__(self) -> str:
+        """
+        Name of the instance's class, used in __repr__.
+        """
+        return self.__class__.__name__
+
+    def __repr_str__(self, join_str: str) -> str:
+        return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())
+
+    def __pretty__(self, fmt: Callable[[Any], Any], **kwargs: Any) -> Generator[Any, None, None]:
+        """
+        Used by devtools (https://python-devtools.helpmanual.io/) to provide a human readable representation of objects
+        """
+        yield self.__repr_name__() + '('
+        yield 1
+        for name, value in self.__repr_args__():
+            if name is not None:
+                yield name + '='
+            yield fmt(value)
+            yield ','
+            yield 0
+        yield -1
+        yield ')'
+
+    def __str__(self) -> str:
+        return self.__repr_str__(' ')
+
+    def __repr__(self) -> str:
+        return f'{self.__repr_name__()}({self.__repr_str__(", ")})'
+
+    def __rich_repr__(self) -> 'RichReprResult':
+        """Get fields for Rich library"""
+        for name, field_repr in self.__repr_args__():
+            if name is None:
+                yield field_repr
+            else:
+                yield name, field_repr
+
+
+class GetterDict(Representation):
+    """
+    Hack to make objects smell just enough like dicts for validate_model.
+
+    We can't inherit from Mapping[str, Any] because it upsets cython so we have to implement all methods ourselves.
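+
+    Illustrative example (hypothetical ``obj``): ``GetterDict(obj)['name']`` returns
+    ``obj.name``, raising ``KeyError`` rather than ``AttributeError`` when the
+    attribute is missing.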
+ """ + + __slots__ = ('_obj',) + + def __init__(self, obj: Any): + self._obj = obj + + def __getitem__(self, key: str) -> Any: + try: + return getattr(self._obj, key) + except AttributeError as e: + raise KeyError(key) from e + + def get(self, key: Any, default: Any = None) -> Any: + return getattr(self._obj, key, default) + + def extra_keys(self) -> Set[Any]: + """ + We don't want to get any other attributes of obj if the model didn't explicitly ask for them + """ + return set() + + def keys(self) -> List[Any]: + """ + Keys of the pseudo dictionary, uses a list not set so order information can be maintained like python + dictionaries. + """ + return list(self) + + def values(self) -> List[Any]: + return [self[k] for k in self] + + def items(self) -> Iterator[Tuple[str, Any]]: + for k in self: + yield k, self.get(k) + + def __iter__(self) -> Iterator[str]: + for name in dir(self._obj): + if not name.startswith('_'): + yield name + + def __len__(self) -> int: + return sum(1 for _ in self) + + def __contains__(self, item: Any) -> bool: + return item in self.keys() + + def __eq__(self, other: Any) -> bool: + return dict(self) == dict(other.items()) + + def __repr_args__(self) -> 'ReprArgs': + return [(None, dict(self))] + + def __repr_name__(self) -> str: + return f'GetterDict[{display_as_type(self._obj)}]' + + +class ValueItems(Representation): + """ + Class for more convenient calculation of excluded or included fields on values. + """ + + __slots__ = ('_items', '_type') + + def __init__(self, value: Any, items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> None: + items = self._coerce_items(items) + + if isinstance(value, (list, tuple)): + items = self._normalize_indexes(items, len(value)) + + self._items: 'MappingIntStrAny' = items + + def is_excluded(self, item: Any) -> bool: + """ + Check if item is fully excluded. 
+
+        :param item: key or index of a value
+        """
+        return self.is_true(self._items.get(item))
+
+    def is_included(self, item: Any) -> bool:
+        """
+        Check if value is contained in self._items
+
+        :param item: key or index of value
+        """
+        return item in self._items
+
+    def for_element(self, e: 'IntStr') -> Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']]:
+        """
+        :param e: key or index of element on value
+        :return: raw values for element if self._items is a dict and contains the needed element
+        """
+
+        item = self._items.get(e)
+        return item if not self.is_true(item) else None
+
+    def _normalize_indexes(self, items: 'MappingIntStrAny', v_length: int) -> 'DictIntStrAny':
+        """
+        :param items: dict or set of indexes which will be normalized
+        :param v_length: length of the sequence whose indexes are being normalized
+
+        >>> self._normalize_indexes({0: True, -2: True, -1: True}, 4)
+        {0: True, 2: True, 3: True}
+        >>> self._normalize_indexes({'__all__': True}, 4)
+        {0: True, 1: True, 2: True, 3: True}
+        """
+
+        normalized_items: 'DictIntStrAny' = {}
+        all_items = None
+        for i, v in items.items():
+            if not (isinstance(v, Mapping) or isinstance(v, AbstractSet) or self.is_true(v)):
+                raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
+            if i == '__all__':
+                all_items = self._coerce_value(v)
+                continue
+            if not isinstance(i, int):
+                raise TypeError(
+                    'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
+                    'expected integer keys or keyword "__all__"'
+                )
+            normalized_i = v_length + i if i < 0 else i
+            normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i))
+
+        if not all_items:
+            return normalized_items
+        if self.is_true(all_items):
+            for i in range(v_length):
+                normalized_items.setdefault(i, ...)
+            return normalized_items
+        for i in range(v_length):
+            normalized_item = normalized_items.setdefault(i, {})
+            if not self.is_true(normalized_item):
+                normalized_items[i] = self.merge(all_items, normalized_item)
+        return normalized_items
+
+    @classmethod
+    def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any:
+        """
+        Merge a ``base`` item with an ``override`` item.
+
+        Both ``base`` and ``override`` are converted to dictionaries if possible.
+        Sets are converted to dictionaries with the set's entries as keys and
+        Ellipsis as values.
+
+        Each key-value pair existing in ``base`` is merged with ``override``,
+        while the rest of the key-value pairs are updated recursively with this function.
+
+        Merging takes place based on the "union" of keys if ``intersect`` is
+        set to ``False`` (default) and on the intersection of keys if
+        ``intersect`` is set to ``True``.
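+
+        Illustrative examples:
+
+        >>> ValueItems.merge({'a'}, {'b'})
+        {'a': Ellipsis, 'b': Ellipsis}
+        >>> ValueItems.merge({'a', 'b'}, {'b', 'c'}, intersect=True)
+        {'b': Ellipsis}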
+ """ + override = cls._coerce_value(override) + base = cls._coerce_value(base) + if override is None: + return base + if cls.is_true(base) or base is None: + return override + if cls.is_true(override): + return base if intersect else override + + # intersection or union of keys while preserving ordering: + if intersect: + merge_keys = [k for k in base if k in override] + [k for k in override if k in base] + else: + merge_keys = list(base) + [k for k in override if k not in base] + + merged: 'DictIntStrAny' = {} + for k in merge_keys: + merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect) + if merged_item is not None: + merged[k] = merged_item + + return merged + + @staticmethod + def _coerce_items(items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> 'MappingIntStrAny': + if isinstance(items, Mapping): + pass + elif isinstance(items, AbstractSet): + items = dict.fromkeys(items, ...) + else: + class_name = getattr(items, '__class__', '???') + assert_never( + items, + f'Unexpected type of exclude value {class_name}', + ) + return items + + @classmethod + def _coerce_value(cls, value: Any) -> Any: + if value is None or cls.is_true(value): + return value + return cls._coerce_items(value) + + @staticmethod + def is_true(v: Any) -> bool: + return v is True or v is ... + + def __repr_args__(self) -> 'ReprArgs': + return [(None, self._items)] + + +class ClassAttribute: + """ + Hide class attribute from its instances + """ + + __slots__ = ( + 'name', + 'value', + ) + + def __init__(self, name: str, value: Any) -> None: + self.name = name + self.value = value + + def __get__(self, instance: Any, owner: Type[Any]) -> None: + if instance is None: + return self.value + raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only') + + +path_types = { + 'is_dir': 'directory', + 'is_file': 'file', + 'is_mount': 'mount point', + 'is_symlink': 'symlink', + 'is_block_device': 'block device', + 'is_char_device': 'char device', + 'is_fifo': 'FIFO', + 'is_socket': 'socket', +} + + +def path_type(p: 'Path') -> str: + """ + Find out what sort of thing a path is. + """ + assert p.exists(), 'path does not exist' + for method, name in path_types.items(): + if getattr(p, method)(): + return name + + return 'unknown' + + +Obj = TypeVar('Obj') + + +def smart_deepcopy(obj: Obj) -> Obj: + """ + Return type as is for immutable built-in types + Use obj.copy() for built-in empty collections + Use copy.deepcopy() for non-empty collections and unknown objects + """ + + obj_type = obj.__class__ + if obj_type in IMMUTABLE_NON_COLLECTIONS_TYPES: + return obj # fastest case: obj is immutable and not collection therefore will not be copied anyway + try: + if not obj and obj_type in BUILTIN_COLLECTIONS: + # faster way for empty collections, no need to copy its members + return obj if obj_type is tuple else obj.copy() # type: ignore # tuple doesn't have copy method + except (TypeError, ValueError, RuntimeError): + # do we really dare to catch ALL errors? 
Seems a bit risky + pass + + return deepcopy(obj) # slowest way when we actually might need a deepcopy + + +def is_valid_field(name: str) -> bool: + if not name.startswith('_'): + return True + return ROOT_KEY == name + + +DUNDER_ATTRIBUTES = { + '__annotations__', + '__classcell__', + '__doc__', + '__module__', + '__orig_bases__', + '__orig_class__', + '__qualname__', +} + + +def is_valid_private_name(name: str) -> bool: + return not is_valid_field(name) and name not in DUNDER_ATTRIBUTES + + +_EMPTY = object() + + +def all_identical(left: Iterable[Any], right: Iterable[Any]) -> bool: + """ + Check that the items of `left` are the same objects as those in `right`. + + >>> a, b = object(), object() + >>> all_identical([a, b, a], [a, b, a]) + True + >>> all_identical([a, b, [a]], [a, b, [a]]) # new list object, while "equal" is not "identical" + False + """ + for left_item, right_item in zip_longest(left, right, fillvalue=_EMPTY): + if left_item is not right_item: + return False + return True + + +def assert_never(obj: NoReturn, msg: str) -> NoReturn: + """ + Helper to make sure that we have covered all possible types. + + This is mostly useful for ``mypy``, docs: + https://mypy.readthedocs.io/en/latest/literal_types.html#exhaustive-checks + """ + raise TypeError(msg) + + +def get_unique_discriminator_alias(all_aliases: Collection[str], discriminator_key: str) -> str: + """Validate that all aliases are the same and if that's the case return the alias""" + unique_aliases = set(all_aliases) + if len(unique_aliases) > 1: + raise ConfigError( + f'Aliases for discriminator {discriminator_key!r} must be the same (got {", ".join(sorted(all_aliases))})' + ) + return unique_aliases.pop() + + +def get_discriminator_alias_and_values(tp: Any, discriminator_key: str) -> Tuple[str, Tuple[str, ...]]: + """ + Get alias and all valid values in the `Literal` type of the discriminator field + `tp` can be a `BaseModel` class or directly an `Annotated` `Union` of many. 
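+
+    Illustrative example (hypothetical model): for
+    ``class Cat(BaseModel): pet_type: Literal['cat']`` with ``discriminator_key='pet_type'``,
+    this returns ``('pet_type', ('cat',))``.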
+ """ + is_root_model = getattr(tp, '__custom_root_type__', False) + + if get_origin(tp) is Annotated: + tp = get_args(tp)[0] + + if hasattr(tp, '__pydantic_model__'): + tp = tp.__pydantic_model__ + + if is_union(get_origin(tp)): + alias, all_values = _get_union_alias_and_all_values(tp, discriminator_key) + return alias, tuple(v for values in all_values for v in values) + elif is_root_model: + union_type = tp.__fields__[ROOT_KEY].type_ + alias, all_values = _get_union_alias_and_all_values(union_type, discriminator_key) + + if len(set(all_values)) > 1: + raise ConfigError( + f'Field {discriminator_key!r} is not the same for all submodels of {display_as_type(tp)!r}' + ) + + return alias, all_values[0] + + else: + try: + t_discriminator_type = tp.__fields__[discriminator_key].type_ + except AttributeError as e: + raise TypeError(f'Type {tp.__name__!r} is not a valid `BaseModel` or `dataclass`') from e + except KeyError as e: + raise ConfigError(f'Model {tp.__name__!r} needs a discriminator field for key {discriminator_key!r}') from e + + if not is_literal_type(t_discriminator_type): + raise ConfigError(f'Field {discriminator_key!r} of model {tp.__name__!r} needs to be a `Literal`') + + return tp.__fields__[discriminator_key].alias, all_literal_values(t_discriminator_type) + + +def _get_union_alias_and_all_values( + union_type: Type[Any], discriminator_key: str +) -> Tuple[str, Tuple[Tuple[str, ...], ...]]: + zipped_aliases_values = [get_discriminator_alias_and_values(t, discriminator_key) for t in get_args(union_type)] + # unzip: [('alias_a',('v1', 'v2)), ('alias_b', ('v3',))] => [('alias_a', 'alias_b'), (('v1', 'v2'), ('v3',))] + all_aliases, all_values = zip(*zipped_aliases_values) + return get_unique_discriminator_alias(all_aliases, discriminator_key), all_values + + +KT = TypeVar('KT') +VT = TypeVar('VT') +if TYPE_CHECKING: + # Annoying inheriting from `MutableMapping` and `dict` breaks cython, hence this work around + class LimitedDict(dict, MutableMapping[KT, VT]): # type: ignore[type-arg] + def __init__(self, size_limit: int = 1000): + ... + +else: + + class LimitedDict(dict): + """ + Limit the size/length of a dict used for caching to avoid unlimited increase in memory usage. + + Since the dict is ordered, and we always remove elements from the beginning, this is effectively a FIFO cache. + + Annoying inheriting from `MutableMapping` breaks cython. 
+ """ + + def __init__(self, size_limit: int = 1000): + self.size_limit = size_limit + super().__init__() + + def __setitem__(self, __key: Any, __value: Any) -> None: + super().__setitem__(__key, __value) + if len(self) > self.size_limit: + excess = len(self) - self.size_limit + self.size_limit // 10 + to_remove = list(self.keys())[:excess] + for key in to_remove: + del self[key] + + def __class_getitem__(cls, *args: Any) -> Any: + # to avoid errors with 3.7 + pass diff --git a/libs/win/pydantic/validators.cp37-win_amd64.pyd b/libs/win/pydantic/validators.cp37-win_amd64.pyd new file mode 100644 index 00000000..a59f4be2 Binary files /dev/null and b/libs/win/pydantic/validators.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/validators.py b/libs/win/pydantic/validators.py new file mode 100644 index 00000000..fb6d0418 --- /dev/null +++ b/libs/win/pydantic/validators.py @@ -0,0 +1,765 @@ +import math +import re +from collections import OrderedDict, deque +from collections.abc import Hashable as CollectionsHashable +from datetime import date, datetime, time, timedelta +from decimal import Decimal, DecimalException +from enum import Enum, IntEnum +from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Deque, + Dict, + ForwardRef, + FrozenSet, + Generator, + Hashable, + List, + NamedTuple, + Pattern, + Set, + Tuple, + Type, + TypeVar, + Union, +) +from uuid import UUID + +from . import errors +from .datetime_parse import parse_date, parse_datetime, parse_duration, parse_time +from .typing import ( + AnyCallable, + all_literal_values, + display_as_type, + get_class, + is_callable_type, + is_literal_type, + is_namedtuple, + is_none_type, + is_typeddict, +) +from .utils import almost_equal_floats, lenient_issubclass, sequence_like + +if TYPE_CHECKING: + from typing_extensions import Literal, TypedDict + + from .config import BaseConfig + from .fields import ModelField + from .types import ConstrainedDecimal, ConstrainedFloat, ConstrainedInt + + ConstrainedNumber = Union[ConstrainedDecimal, ConstrainedFloat, ConstrainedInt] + AnyOrderedDict = OrderedDict[Any, Any] + Number = Union[int, float, Decimal] + StrBytes = Union[str, bytes] + + +def str_validator(v: Any) -> Union[str]: + if isinstance(v, str): + if isinstance(v, Enum): + return v.value + else: + return v + elif isinstance(v, (float, int, Decimal)): + # is there anything else we want to add here? If you think so, create an issue. 
+        return str(v)
+    elif isinstance(v, (bytes, bytearray)):
+        return v.decode()
+    else:
+        raise errors.StrError()
+
+
+def strict_str_validator(v: Any) -> Union[str]:
+    if isinstance(v, str) and not isinstance(v, Enum):
+        return v
+    raise errors.StrError()
+
+
+def bytes_validator(v: Any) -> Union[bytes]:
+    if isinstance(v, bytes):
+        return v
+    elif isinstance(v, bytearray):
+        return bytes(v)
+    elif isinstance(v, str):
+        return v.encode()
+    elif isinstance(v, (float, int, Decimal)):
+        return str(v).encode()
+    else:
+        raise errors.BytesError()
+
+
+def strict_bytes_validator(v: Any) -> Union[bytes]:
+    if isinstance(v, bytes):
+        return v
+    elif isinstance(v, bytearray):
+        return bytes(v)
+    else:
+        raise errors.BytesError()
+
+
+BOOL_FALSE = {0, '0', 'off', 'f', 'false', 'n', 'no'}
+BOOL_TRUE = {1, '1', 'on', 't', 'true', 'y', 'yes'}
+
+
+def bool_validator(v: Any) -> bool:
+    if v is True or v is False:
+        return v
+    if isinstance(v, bytes):
+        v = v.decode()
+    if isinstance(v, str):
+        v = v.lower()
+    try:
+        if v in BOOL_TRUE:
+            return True
+        if v in BOOL_FALSE:
+            return False
+    except TypeError:
+        raise errors.BoolError()
+    raise errors.BoolError()
+
+
+# matches the default limit in cpython, see https://github.com/python/cpython/pull/96500
+max_str_int = 4_300
+
+
+def int_validator(v: Any) -> int:
+    if isinstance(v, int) and not (v is True or v is False):
+        return v
+
+    # see https://github.com/pydantic/pydantic/issues/1477 and in turn, https://github.com/python/cpython/issues/95778
+    # this check should be unnecessary once patch releases are out for 3.7, 3.8, 3.9 and 3.10
+    # but better to check here until then.
+    # NOTICE: this does not fully protect the user from the DOS risk since the standard library JSON implementation
+    # (and other std lib modules like xml) use `int()` and are likely called before this; the best workaround is to
+    # 1. update to the latest patch release of python once released, 2.
use a different JSON library like ujson + if isinstance(v, (str, bytes, bytearray)) and len(v) > max_str_int: + raise errors.IntegerError() + + try: + return int(v) + except (TypeError, ValueError, OverflowError): + raise errors.IntegerError() + + +def strict_int_validator(v: Any) -> int: + if isinstance(v, int) and not (v is True or v is False): + return v + raise errors.IntegerError() + + +def float_validator(v: Any) -> float: + if isinstance(v, float): + return v + + try: + return float(v) + except (TypeError, ValueError): + raise errors.FloatError() + + +def strict_float_validator(v: Any) -> float: + if isinstance(v, float): + return v + raise errors.FloatError() + + +def float_finite_validator(v: 'Number', field: 'ModelField', config: 'BaseConfig') -> 'Number': + allow_inf_nan = getattr(field.type_, 'allow_inf_nan', None) + if allow_inf_nan is None: + allow_inf_nan = config.allow_inf_nan + + if allow_inf_nan is False and (math.isnan(v) or math.isinf(v)): + raise errors.NumberNotFiniteError() + return v + + +def number_multiple_validator(v: 'Number', field: 'ModelField') -> 'Number': + field_type: ConstrainedNumber = field.type_ + if field_type.multiple_of is not None: + mod = float(v) / float(field_type.multiple_of) % 1 + if not almost_equal_floats(mod, 0.0) and not almost_equal_floats(mod, 1.0): + raise errors.NumberNotMultipleError(multiple_of=field_type.multiple_of) + return v + + +def number_size_validator(v: 'Number', field: 'ModelField') -> 'Number': + field_type: ConstrainedNumber = field.type_ + if field_type.gt is not None and not v > field_type.gt: + raise errors.NumberNotGtError(limit_value=field_type.gt) + elif field_type.ge is not None and not v >= field_type.ge: + raise errors.NumberNotGeError(limit_value=field_type.ge) + + if field_type.lt is not None and not v < field_type.lt: + raise errors.NumberNotLtError(limit_value=field_type.lt) + if field_type.le is not None and not v <= field_type.le: + raise errors.NumberNotLeError(limit_value=field_type.le) + + return v + + +def constant_validator(v: 'Any', field: 'ModelField') -> 'Any': + """Validate ``const`` fields. + + The value provided for a ``const`` field must be equal to the default value + of the field. This is to support the keyword of the same name in JSON + Schema. 
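+
+    A hypothetical illustration (not from the library source): for a field
+    declared as ``x: int = Field(42, const=True)``, validation accepts only
+    ``x=42`` and raises ``WrongConstantError`` for any other value.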
+ """ + if v != field.default: + raise errors.WrongConstantError(given=v, permitted=[field.default]) + + return v + + +def anystr_length_validator(v: 'StrBytes', config: 'BaseConfig') -> 'StrBytes': + v_len = len(v) + + min_length = config.min_anystr_length + if v_len < min_length: + raise errors.AnyStrMinLengthError(limit_value=min_length) + + max_length = config.max_anystr_length + if max_length is not None and v_len > max_length: + raise errors.AnyStrMaxLengthError(limit_value=max_length) + + return v + + +def anystr_strip_whitespace(v: 'StrBytes') -> 'StrBytes': + return v.strip() + + +def anystr_upper(v: 'StrBytes') -> 'StrBytes': + return v.upper() + + +def anystr_lower(v: 'StrBytes') -> 'StrBytes': + return v.lower() + + +def ordered_dict_validator(v: Any) -> 'AnyOrderedDict': + if isinstance(v, OrderedDict): + return v + + try: + return OrderedDict(v) + except (TypeError, ValueError): + raise errors.DictError() + + +def dict_validator(v: Any) -> Dict[Any, Any]: + if isinstance(v, dict): + return v + + try: + return dict(v) + except (TypeError, ValueError): + raise errors.DictError() + + +def list_validator(v: Any) -> List[Any]: + if isinstance(v, list): + return v + elif sequence_like(v): + return list(v) + else: + raise errors.ListError() + + +def tuple_validator(v: Any) -> Tuple[Any, ...]: + if isinstance(v, tuple): + return v + elif sequence_like(v): + return tuple(v) + else: + raise errors.TupleError() + + +def set_validator(v: Any) -> Set[Any]: + if isinstance(v, set): + return v + elif sequence_like(v): + return set(v) + else: + raise errors.SetError() + + +def frozenset_validator(v: Any) -> FrozenSet[Any]: + if isinstance(v, frozenset): + return v + elif sequence_like(v): + return frozenset(v) + else: + raise errors.FrozenSetError() + + +def deque_validator(v: Any) -> Deque[Any]: + if isinstance(v, deque): + return v + elif sequence_like(v): + return deque(v) + else: + raise errors.DequeError() + + +def enum_member_validator(v: Any, field: 'ModelField', config: 'BaseConfig') -> Enum: + try: + enum_v = field.type_(v) + except ValueError: + # field.type_ should be an enum, so will be iterable + raise errors.EnumMemberError(enum_values=list(field.type_)) + return enum_v.value if config.use_enum_values else enum_v + + +def uuid_validator(v: Any, field: 'ModelField') -> UUID: + try: + if isinstance(v, str): + v = UUID(v) + elif isinstance(v, (bytes, bytearray)): + try: + v = UUID(v.decode()) + except ValueError: + # 16 bytes in big-endian order as the bytes argument fail + # the above check + v = UUID(bytes=v) + except ValueError: + raise errors.UUIDError() + + if not isinstance(v, UUID): + raise errors.UUIDError() + + required_version = getattr(field.type_, '_required_version', None) + if required_version and v.version != required_version: + raise errors.UUIDVersionError(required_version=required_version) + + return v + + +def decimal_validator(v: Any) -> Decimal: + if isinstance(v, Decimal): + return v + elif isinstance(v, (bytes, bytearray)): + v = v.decode() + + v = str(v).strip() + + try: + v = Decimal(v) + except DecimalException: + raise errors.DecimalError() + + if not v.is_finite(): + raise errors.DecimalIsNotFiniteError() + + return v + + +def hashable_validator(v: Any) -> Hashable: + if isinstance(v, Hashable): + return v + + raise errors.HashableError() + + +def ip_v4_address_validator(v: Any) -> IPv4Address: + if isinstance(v, IPv4Address): + return v + + try: + return IPv4Address(v) + except ValueError: + raise errors.IPv4AddressError() + + +def 
ip_v6_address_validator(v: Any) -> IPv6Address: + if isinstance(v, IPv6Address): + return v + + try: + return IPv6Address(v) + except ValueError: + raise errors.IPv6AddressError() + + +def ip_v4_network_validator(v: Any) -> IPv4Network: + """ + Assume IPv4Network initialised with a default ``strict`` argument + + See more: + https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network + """ + if isinstance(v, IPv4Network): + return v + + try: + return IPv4Network(v) + except ValueError: + raise errors.IPv4NetworkError() + + +def ip_v6_network_validator(v: Any) -> IPv6Network: + """ + Assume IPv6Network initialised with a default ``strict`` argument + + See more: + https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network + """ + if isinstance(v, IPv6Network): + return v + + try: + return IPv6Network(v) + except ValueError: + raise errors.IPv6NetworkError() + + +def ip_v4_interface_validator(v: Any) -> IPv4Interface: + if isinstance(v, IPv4Interface): + return v + + try: + return IPv4Interface(v) + except ValueError: + raise errors.IPv4InterfaceError() + + +def ip_v6_interface_validator(v: Any) -> IPv6Interface: + if isinstance(v, IPv6Interface): + return v + + try: + return IPv6Interface(v) + except ValueError: + raise errors.IPv6InterfaceError() + + +def path_validator(v: Any) -> Path: + if isinstance(v, Path): + return v + + try: + return Path(v) + except TypeError: + raise errors.PathError() + + +def path_exists_validator(v: Any) -> Path: + if not v.exists(): + raise errors.PathNotExistsError(path=v) + + return v + + +def callable_validator(v: Any) -> AnyCallable: + """ + Perform a simple check if the value is callable. + + Note: complete matching of argument type hints and return types is not performed + """ + if callable(v): + return v + + raise errors.CallableError(value=v) + + +def enum_validator(v: Any) -> Enum: + if isinstance(v, Enum): + return v + + raise errors.EnumError(value=v) + + +def int_enum_validator(v: Any) -> IntEnum: + if isinstance(v, IntEnum): + return v + + raise errors.IntEnumError(value=v) + + +def make_literal_validator(type_: Any) -> Callable[[Any], Any]: + permitted_choices = all_literal_values(type_) + + # To have a O(1) complexity and still return one of the values set inside the `Literal`, + # we create a dict with the set values (a set causes some problems with the way intersection works). 
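+    # (Hypothetical sketch: for Literal['red', 'blue'] the dict is
+    # {'red': 'red', 'blue': 'blue'}, so the lookup is a single hash probe and
+    # always hands back the canonical value stored in the Literal itself.)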
+ # In some cases the set value and checked value can indeed be different (see `test_literal_validator_str_enum`) + allowed_choices = {v: v for v in permitted_choices} + + def literal_validator(v: Any) -> Any: + try: + return allowed_choices[v] + except KeyError: + raise errors.WrongConstantError(given=v, permitted=permitted_choices) + + return literal_validator + + +def constr_length_validator(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes': + v_len = len(v) + + min_length = field.type_.min_length if field.type_.min_length is not None else config.min_anystr_length + if v_len < min_length: + raise errors.AnyStrMinLengthError(limit_value=min_length) + + max_length = field.type_.max_length if field.type_.max_length is not None else config.max_anystr_length + if max_length is not None and v_len > max_length: + raise errors.AnyStrMaxLengthError(limit_value=max_length) + + return v + + +def constr_strip_whitespace(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes': + strip_whitespace = field.type_.strip_whitespace or config.anystr_strip_whitespace + if strip_whitespace: + v = v.strip() + + return v + + +def constr_upper(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes': + upper = field.type_.to_upper or config.anystr_upper + if upper: + v = v.upper() + + return v + + +def constr_lower(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes': + lower = field.type_.to_lower or config.anystr_lower + if lower: + v = v.lower() + return v + + +def validate_json(v: Any, config: 'BaseConfig') -> Any: + if v is None: + # pass None through to other validators + return v + try: + return config.json_loads(v) # type: ignore + except ValueError: + raise errors.JsonError() + except TypeError: + raise errors.JsonTypeError() + + +T = TypeVar('T') + + +def make_arbitrary_type_validator(type_: Type[T]) -> Callable[[T], T]: + def arbitrary_type_validator(v: Any) -> T: + if isinstance(v, type_): + return v + raise errors.ArbitraryTypeError(expected_arbitrary_type=type_) + + return arbitrary_type_validator + + +def make_class_validator(type_: Type[T]) -> Callable[[Any], Type[T]]: + def class_validator(v: Any) -> Type[T]: + if lenient_issubclass(v, type_): + return v + raise errors.SubclassError(expected_class=type_) + + return class_validator + + +def any_class_validator(v: Any) -> Type[T]: + if isinstance(v, type): + return v + raise errors.ClassError() + + +def none_validator(v: Any) -> 'Literal[None]': + if v is None: + return v + raise errors.NotNoneError() + + +def pattern_validator(v: Any) -> Pattern[str]: + if isinstance(v, Pattern): + return v + + str_value = str_validator(v) + + try: + return re.compile(str_value) + except re.error: + raise errors.PatternError() + + +NamedTupleT = TypeVar('NamedTupleT', bound=NamedTuple) + + +def make_namedtuple_validator( + namedtuple_cls: Type[NamedTupleT], config: Type['BaseConfig'] +) -> Callable[[Tuple[Any, ...]], NamedTupleT]: + from .annotated_types import create_model_from_namedtuple + + NamedTupleModel = create_model_from_namedtuple( + namedtuple_cls, + __config__=config, + __module__=namedtuple_cls.__module__, + ) + namedtuple_cls.__pydantic_model__ = NamedTupleModel # type: ignore[attr-defined] + + def namedtuple_validator(values: Tuple[Any, ...]) -> NamedTupleT: + annotations = NamedTupleModel.__annotations__ + + if len(values) > len(annotations): + raise errors.ListMaxLengthError(limit_value=len(annotations)) + + dict_values: Dict[str, Any] = dict(zip(annotations, values)) + 
validated_dict_values: Dict[str, Any] = dict(NamedTupleModel(**dict_values)) + return namedtuple_cls(**validated_dict_values) + + return namedtuple_validator + + +def make_typeddict_validator( + typeddict_cls: Type['TypedDict'], config: Type['BaseConfig'] # type: ignore[valid-type] +) -> Callable[[Any], Dict[str, Any]]: + from .annotated_types import create_model_from_typeddict + + TypedDictModel = create_model_from_typeddict( + typeddict_cls, + __config__=config, + __module__=typeddict_cls.__module__, + ) + typeddict_cls.__pydantic_model__ = TypedDictModel # type: ignore[attr-defined] + + def typeddict_validator(values: 'TypedDict') -> Dict[str, Any]: # type: ignore[valid-type] + return TypedDictModel.parse_obj(values).dict(exclude_unset=True) + + return typeddict_validator + + +class IfConfig: + def __init__(self, validator: AnyCallable, *config_attr_names: str, ignored_value: Any = False) -> None: + self.validator = validator + self.config_attr_names = config_attr_names + self.ignored_value = ignored_value + + def check(self, config: Type['BaseConfig']) -> bool: + return any(getattr(config, name) not in {None, self.ignored_value} for name in self.config_attr_names) + + +# order is important here, for example: bool is a subclass of int so has to come first, datetime before date same, +# IPv4Interface before IPv4Address, etc +_VALIDATORS: List[Tuple[Type[Any], List[Any]]] = [ + (IntEnum, [int_validator, enum_member_validator]), + (Enum, [enum_member_validator]), + ( + str, + [ + str_validator, + IfConfig(anystr_strip_whitespace, 'anystr_strip_whitespace'), + IfConfig(anystr_upper, 'anystr_upper'), + IfConfig(anystr_lower, 'anystr_lower'), + IfConfig(anystr_length_validator, 'min_anystr_length', 'max_anystr_length'), + ], + ), + ( + bytes, + [ + bytes_validator, + IfConfig(anystr_strip_whitespace, 'anystr_strip_whitespace'), + IfConfig(anystr_upper, 'anystr_upper'), + IfConfig(anystr_lower, 'anystr_lower'), + IfConfig(anystr_length_validator, 'min_anystr_length', 'max_anystr_length'), + ], + ), + (bool, [bool_validator]), + (int, [int_validator]), + (float, [float_validator, IfConfig(float_finite_validator, 'allow_inf_nan', ignored_value=True)]), + (Path, [path_validator]), + (datetime, [parse_datetime]), + (date, [parse_date]), + (time, [parse_time]), + (timedelta, [parse_duration]), + (OrderedDict, [ordered_dict_validator]), + (dict, [dict_validator]), + (list, [list_validator]), + (tuple, [tuple_validator]), + (set, [set_validator]), + (frozenset, [frozenset_validator]), + (deque, [deque_validator]), + (UUID, [uuid_validator]), + (Decimal, [decimal_validator]), + (IPv4Interface, [ip_v4_interface_validator]), + (IPv6Interface, [ip_v6_interface_validator]), + (IPv4Address, [ip_v4_address_validator]), + (IPv6Address, [ip_v6_address_validator]), + (IPv4Network, [ip_v4_network_validator]), + (IPv6Network, [ip_v6_network_validator]), +] + + +def find_validators( # noqa: C901 (ignore complexity) + type_: Type[Any], config: Type['BaseConfig'] +) -> Generator[AnyCallable, None, None]: + from .dataclasses import is_builtin_dataclass, make_dataclass_validator + + if type_ is Any or type_ is object: + return + type_type = type_.__class__ + if type_type == ForwardRef or type_type == TypeVar: + return + + if is_none_type(type_): + yield none_validator + return + if type_ is Pattern or type_ is re.Pattern: + yield pattern_validator + return + if type_ is Hashable or type_ is CollectionsHashable: + yield hashable_validator + return + if is_callable_type(type_): + yield callable_validator + return + if 
is_literal_type(type_): + yield make_literal_validator(type_) + return + if is_builtin_dataclass(type_): + yield from make_dataclass_validator(type_, config) + return + if type_ is Enum: + yield enum_validator + return + if type_ is IntEnum: + yield int_enum_validator + return + if is_namedtuple(type_): + yield tuple_validator + yield make_namedtuple_validator(type_, config) + return + if is_typeddict(type_): + yield make_typeddict_validator(type_, config) + return + + class_ = get_class(type_) + if class_ is not None: + if class_ is not Any and isinstance(class_, type): + yield make_class_validator(class_) + else: + yield any_class_validator + return + + for val_type, validators in _VALIDATORS: + try: + if issubclass(type_, val_type): + for v in validators: + if isinstance(v, IfConfig): + if v.check(config): + yield v.validator + else: + yield v + return + except TypeError: + raise RuntimeError(f'error checking inheritance of {type_!r} (type: {display_as_type(type_)})') + + if config.arbitrary_types_allowed: + yield make_arbitrary_type_validator(type_) + else: + raise RuntimeError(f'no validator found for {type_}, see `arbitrary_types_allowed` in Config') diff --git a/libs/win/pydantic/version.cp37-win_amd64.pyd b/libs/win/pydantic/version.cp37-win_amd64.pyd new file mode 100644 index 00000000..91395862 Binary files /dev/null and b/libs/win/pydantic/version.cp37-win_amd64.pyd differ diff --git a/libs/win/pydantic/version.py b/libs/win/pydantic/version.py new file mode 100644 index 00000000..32c61633 --- /dev/null +++ b/libs/win/pydantic/version.py @@ -0,0 +1,38 @@ +__all__ = 'compiled', 'VERSION', 'version_info' + +VERSION = '1.10.2' + +try: + import cython # type: ignore +except ImportError: + compiled: bool = False +else: # pragma: no cover + try: + compiled = cython.compiled + except AttributeError: + compiled = False + + +def version_info() -> str: + import platform + import sys + from importlib import import_module + from pathlib import Path + + optional_deps = [] + for p in ('devtools', 'dotenv', 'email-validator', 'typing-extensions'): + try: + import_module(p.replace('-', '_')) + except ImportError: + continue + optional_deps.append(p) + + info = { + 'pydantic version': VERSION, + 'pydantic compiled': compiled, + 'install path': Path(__file__).resolve().parent, + 'python version': sys.version, + 'platform': platform.platform(), + 'optional deps. 
installed': optional_deps, + } + return '\n'.join('{:>30} {}'.format(k + ':', str(v).replace('\n', ' ')) for k, v in info.items()) diff --git a/libs/win/scripts/watch-changes.py b/libs/win/scripts/watch-changes.py new file mode 100644 index 00000000..72b48686 --- /dev/null +++ b/libs/win/scripts/watch-changes.py @@ -0,0 +1,35 @@ +import traceback +import sys +import logging + +from jaraco.windows.filesystem import change + +logging.basicConfig(level=logging.INFO) + + +def long_handler(file): + try: + with open(file, 'rb') as f: + data = f.read() + print("read", len(data), "bytes from", file) + except Exception: + traceback.print_exc() + + +def main(): + try: + watch() + except KeyboardInterrupt: + pass + + +def watch(): + notifier = change.BlockingNotifier(sys.argv[1]) + notifier.watch_subtree = True + + for ch in notifier.get_changed_files(): + long_handler(ch) + + +if __name__ == '__main__': + main() diff --git a/libs/win/six.py b/libs/win/six.py deleted file mode 100644 index 89b2188f..00000000 --- a/libs/win/six.py +++ /dev/null @@ -1,952 +0,0 @@ -# Copyright (c) 2010-2018 Benjamin Peterson -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -"""Utilities for writing code that runs on Python 2 and 3""" - -from __future__ import absolute_import - -import functools -import itertools -import operator -import sys -import types - -__author__ = "Benjamin Peterson " -__version__ = "1.12.0" - - -# Useful for very coarse version differentiation. -PY2 = sys.version_info[0] == 2 -PY3 = sys.version_info[0] == 3 -PY34 = sys.version_info[0:2] >= (3, 4) - -if PY3: - string_types = str, - integer_types = int, - class_types = type, - text_type = str - binary_type = bytes - - MAXSIZE = sys.maxsize -else: - string_types = basestring, - integer_types = (int, long) - class_types = (type, types.ClassType) - text_type = unicode - binary_type = str - - if sys.platform.startswith("java"): - # Jython always uses 32 bits. - MAXSIZE = int((1 << 31) - 1) - else: - # It's possible to have sizeof(long) != sizeof(Py_ssize_t). 
- class X(object): - - def __len__(self): - return 1 << 31 - try: - len(X()) - except OverflowError: - # 32-bit - MAXSIZE = int((1 << 31) - 1) - else: - # 64-bit - MAXSIZE = int((1 << 63) - 1) - del X - - -def _add_doc(func, doc): - """Add documentation to a function.""" - func.__doc__ = doc - - -def _import_module(name): - """Import module, returning the module after the last dot.""" - __import__(name) - return sys.modules[name] - - -class _LazyDescr(object): - - def __init__(self, name): - self.name = name - - def __get__(self, obj, tp): - result = self._resolve() - setattr(obj, self.name, result) # Invokes __set__. - try: - # This is a bit ugly, but it avoids running this again by - # removing this descriptor. - delattr(obj.__class__, self.name) - except AttributeError: - pass - return result - - -class MovedModule(_LazyDescr): - - def __init__(self, name, old, new=None): - super(MovedModule, self).__init__(name) - if PY3: - if new is None: - new = name - self.mod = new - else: - self.mod = old - - def _resolve(self): - return _import_module(self.mod) - - def __getattr__(self, attr): - _module = self._resolve() - value = getattr(_module, attr) - setattr(self, attr, value) - return value - - -class _LazyModule(types.ModuleType): - - def __init__(self, name): - super(_LazyModule, self).__init__(name) - self.__doc__ = self.__class__.__doc__ - - def __dir__(self): - attrs = ["__doc__", "__name__"] - attrs += [attr.name for attr in self._moved_attributes] - return attrs - - # Subclasses should override this - _moved_attributes = [] - - -class MovedAttribute(_LazyDescr): - - def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None): - super(MovedAttribute, self).__init__(name) - if PY3: - if new_mod is None: - new_mod = name - self.mod = new_mod - if new_attr is None: - if old_attr is None: - new_attr = name - else: - new_attr = old_attr - self.attr = new_attr - else: - self.mod = old_mod - if old_attr is None: - old_attr = name - self.attr = old_attr - - def _resolve(self): - module = _import_module(self.mod) - return getattr(module, self.attr) - - -class _SixMetaPathImporter(object): - - """ - A meta path importer to import six.moves and its submodules. - - This class implements a PEP302 finder and loader. It should be compatible - with Python 2.5 and all existing versions of Python3 - """ - - def __init__(self, six_module_name): - self.name = six_module_name - self.known_modules = {} - - def _add_module(self, mod, *fullnames): - for fullname in fullnames: - self.known_modules[self.name + "." + fullname] = mod - - def _get_module(self, fullname): - return self.known_modules[self.name + "." + fullname] - - def find_module(self, fullname, path=None): - if fullname in self.known_modules: - return self - return None - - def __get_module(self, fullname): - try: - return self.known_modules[fullname] - except KeyError: - raise ImportError("This loader does not know module " + fullname) - - def load_module(self, fullname): - try: - # in case of a reload - return sys.modules[fullname] - except KeyError: - pass - mod = self.__get_module(fullname) - if isinstance(mod, MovedModule): - mod = mod._resolve() - else: - mod.__loader__ = self - sys.modules[fullname] = mod - return mod - - def is_package(self, fullname): - """ - Return true, if the named module is a package. 
- - We need this method to get correct spec objects with - Python 3.4 (see PEP451) - """ - return hasattr(self.__get_module(fullname), "__path__") - - def get_code(self, fullname): - """Return None - - Required, if is_package is implemented""" - self.__get_module(fullname) # eventually raises ImportError - return None - get_source = get_code # same as get_code - -_importer = _SixMetaPathImporter(__name__) - - -class _MovedItems(_LazyModule): - - """Lazy loading of moved objects""" - __path__ = [] # mark as package - - -_moved_attributes = [ - MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"), - MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"), - MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"), - MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"), - MovedAttribute("intern", "__builtin__", "sys"), - MovedAttribute("map", "itertools", "builtins", "imap", "map"), - MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"), - MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"), - MovedAttribute("getoutput", "commands", "subprocess"), - MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"), - MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"), - MovedAttribute("reduce", "__builtin__", "functools"), - MovedAttribute("shlex_quote", "pipes", "shlex", "quote"), - MovedAttribute("StringIO", "StringIO", "io"), - MovedAttribute("UserDict", "UserDict", "collections"), - MovedAttribute("UserList", "UserList", "collections"), - MovedAttribute("UserString", "UserString", "collections"), - MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"), - MovedAttribute("zip", "itertools", "builtins", "izip", "zip"), - MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"), - MovedModule("builtins", "__builtin__"), - MovedModule("configparser", "ConfigParser"), - MovedModule("copyreg", "copy_reg"), - MovedModule("dbm_gnu", "gdbm", "dbm.gnu"), - MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread"), - MovedModule("http_cookiejar", "cookielib", "http.cookiejar"), - MovedModule("http_cookies", "Cookie", "http.cookies"), - MovedModule("html_entities", "htmlentitydefs", "html.entities"), - MovedModule("html_parser", "HTMLParser", "html.parser"), - MovedModule("http_client", "httplib", "http.client"), - MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"), - MovedModule("email_mime_image", "email.MIMEImage", "email.mime.image"), - MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"), - MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"), - MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"), - MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"), - MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"), - MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"), - MovedModule("cPickle", "cPickle", "pickle"), - MovedModule("queue", "Queue"), - MovedModule("reprlib", "repr"), - MovedModule("socketserver", "SocketServer"), - MovedModule("_thread", "thread", "_thread"), - MovedModule("tkinter", "Tkinter"), - MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"), - MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"), - MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"), - MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"), - 
MovedModule("tkinter_tix", "Tix", "tkinter.tix"), - MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"), - MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"), - MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"), - MovedModule("tkinter_colorchooser", "tkColorChooser", - "tkinter.colorchooser"), - MovedModule("tkinter_commondialog", "tkCommonDialog", - "tkinter.commondialog"), - MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"), - MovedModule("tkinter_font", "tkFont", "tkinter.font"), - MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"), - MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", - "tkinter.simpledialog"), - MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"), - MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"), - MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"), - MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"), - MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"), - MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"), -] -# Add windows specific modules. -if sys.platform == "win32": - _moved_attributes += [ - MovedModule("winreg", "_winreg"), - ] - -for attr in _moved_attributes: - setattr(_MovedItems, attr.name, attr) - if isinstance(attr, MovedModule): - _importer._add_module(attr, "moves." + attr.name) -del attr - -_MovedItems._moved_attributes = _moved_attributes - -moves = _MovedItems(__name__ + ".moves") -_importer._add_module(moves, "moves") - - -class Module_six_moves_urllib_parse(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_parse""" - - -_urllib_parse_moved_attributes = [ - MovedAttribute("ParseResult", "urlparse", "urllib.parse"), - MovedAttribute("SplitResult", "urlparse", "urllib.parse"), - MovedAttribute("parse_qs", "urlparse", "urllib.parse"), - MovedAttribute("parse_qsl", "urlparse", "urllib.parse"), - MovedAttribute("urldefrag", "urlparse", "urllib.parse"), - MovedAttribute("urljoin", "urlparse", "urllib.parse"), - MovedAttribute("urlparse", "urlparse", "urllib.parse"), - MovedAttribute("urlsplit", "urlparse", "urllib.parse"), - MovedAttribute("urlunparse", "urlparse", "urllib.parse"), - MovedAttribute("urlunsplit", "urlparse", "urllib.parse"), - MovedAttribute("quote", "urllib", "urllib.parse"), - MovedAttribute("quote_plus", "urllib", "urllib.parse"), - MovedAttribute("unquote", "urllib", "urllib.parse"), - MovedAttribute("unquote_plus", "urllib", "urllib.parse"), - MovedAttribute("unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes"), - MovedAttribute("urlencode", "urllib", "urllib.parse"), - MovedAttribute("splitquery", "urllib", "urllib.parse"), - MovedAttribute("splittag", "urllib", "urllib.parse"), - MovedAttribute("splituser", "urllib", "urllib.parse"), - MovedAttribute("splitvalue", "urllib", "urllib.parse"), - MovedAttribute("uses_fragment", "urlparse", "urllib.parse"), - MovedAttribute("uses_netloc", "urlparse", "urllib.parse"), - MovedAttribute("uses_params", "urlparse", "urllib.parse"), - MovedAttribute("uses_query", "urlparse", "urllib.parse"), - MovedAttribute("uses_relative", "urlparse", "urllib.parse"), -] -for attr in _urllib_parse_moved_attributes: - setattr(Module_six_moves_urllib_parse, attr.name, attr) -del attr - -Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes - -_importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"), - 
"moves.urllib_parse", "moves.urllib.parse") - - -class Module_six_moves_urllib_error(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_error""" - - -_urllib_error_moved_attributes = [ - MovedAttribute("URLError", "urllib2", "urllib.error"), - MovedAttribute("HTTPError", "urllib2", "urllib.error"), - MovedAttribute("ContentTooShortError", "urllib", "urllib.error"), -] -for attr in _urllib_error_moved_attributes: - setattr(Module_six_moves_urllib_error, attr.name, attr) -del attr - -Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes - -_importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"), - "moves.urllib_error", "moves.urllib.error") - - -class Module_six_moves_urllib_request(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_request""" - - -_urllib_request_moved_attributes = [ - MovedAttribute("urlopen", "urllib2", "urllib.request"), - MovedAttribute("install_opener", "urllib2", "urllib.request"), - MovedAttribute("build_opener", "urllib2", "urllib.request"), - MovedAttribute("pathname2url", "urllib", "urllib.request"), - MovedAttribute("url2pathname", "urllib", "urllib.request"), - MovedAttribute("getproxies", "urllib", "urllib.request"), - MovedAttribute("Request", "urllib2", "urllib.request"), - MovedAttribute("OpenerDirector", "urllib2", "urllib.request"), - MovedAttribute("HTTPDefaultErrorHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPRedirectHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPCookieProcessor", "urllib2", "urllib.request"), - MovedAttribute("ProxyHandler", "urllib2", "urllib.request"), - MovedAttribute("BaseHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"), - MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"), - MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("AbstractDigestAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPDigestAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("ProxyDigestAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPSHandler", "urllib2", "urllib.request"), - MovedAttribute("FileHandler", "urllib2", "urllib.request"), - MovedAttribute("FTPHandler", "urllib2", "urllib.request"), - MovedAttribute("CacheFTPHandler", "urllib2", "urllib.request"), - MovedAttribute("UnknownHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPErrorProcessor", "urllib2", "urllib.request"), - MovedAttribute("urlretrieve", "urllib", "urllib.request"), - MovedAttribute("urlcleanup", "urllib", "urllib.request"), - MovedAttribute("URLopener", "urllib", "urllib.request"), - MovedAttribute("FancyURLopener", "urllib", "urllib.request"), - MovedAttribute("proxy_bypass", "urllib", "urllib.request"), - MovedAttribute("parse_http_list", "urllib2", "urllib.request"), - MovedAttribute("parse_keqv_list", "urllib2", "urllib.request"), -] -for attr in _urllib_request_moved_attributes: - setattr(Module_six_moves_urllib_request, attr.name, attr) -del attr - -Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes - -_importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"), - "moves.urllib_request", "moves.urllib.request") - - 
-class Module_six_moves_urllib_response(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_response""" - - -_urllib_response_moved_attributes = [ - MovedAttribute("addbase", "urllib", "urllib.response"), - MovedAttribute("addclosehook", "urllib", "urllib.response"), - MovedAttribute("addinfo", "urllib", "urllib.response"), - MovedAttribute("addinfourl", "urllib", "urllib.response"), -] -for attr in _urllib_response_moved_attributes: - setattr(Module_six_moves_urllib_response, attr.name, attr) -del attr - -Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes - -_importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"), - "moves.urllib_response", "moves.urllib.response") - - -class Module_six_moves_urllib_robotparser(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_robotparser""" - - -_urllib_robotparser_moved_attributes = [ - MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser"), -] -for attr in _urllib_robotparser_moved_attributes: - setattr(Module_six_moves_urllib_robotparser, attr.name, attr) -del attr - -Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes - -_importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"), - "moves.urllib_robotparser", "moves.urllib.robotparser") - - -class Module_six_moves_urllib(types.ModuleType): - - """Create a six.moves.urllib namespace that resembles the Python 3 namespace""" - __path__ = [] # mark as package - parse = _importer._get_module("moves.urllib_parse") - error = _importer._get_module("moves.urllib_error") - request = _importer._get_module("moves.urllib_request") - response = _importer._get_module("moves.urllib_response") - robotparser = _importer._get_module("moves.urllib_robotparser") - - def __dir__(self): - return ['parse', 'error', 'request', 'response', 'robotparser'] - -_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"), - "moves.urllib") - - -def add_move(move): - """Add an item to six.moves.""" - setattr(_MovedItems, move.name, move) - - -def remove_move(name): - """Remove item from six.moves.""" - try: - delattr(_MovedItems, name) - except AttributeError: - try: - del moves.__dict__[name] - except KeyError: - raise AttributeError("no such move, %r" % (name,)) - - -if PY3: - _meth_func = "__func__" - _meth_self = "__self__" - - _func_closure = "__closure__" - _func_code = "__code__" - _func_defaults = "__defaults__" - _func_globals = "__globals__" -else: - _meth_func = "im_func" - _meth_self = "im_self" - - _func_closure = "func_closure" - _func_code = "func_code" - _func_defaults = "func_defaults" - _func_globals = "func_globals" - - -try: - advance_iterator = next -except NameError: - def advance_iterator(it): - return it.next() -next = advance_iterator - - -try: - callable = callable -except NameError: - def callable(obj): - return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) - - -if PY3: - def get_unbound_function(unbound): - return unbound - - create_bound_method = types.MethodType - - def create_unbound_method(func, cls): - return func - - Iterator = object -else: - def get_unbound_function(unbound): - return unbound.im_func - - def create_bound_method(func, obj): - return types.MethodType(func, obj, obj.__class__) - - def create_unbound_method(func, cls): - return types.MethodType(func, None, cls) - - class Iterator(object): - - def next(self): - return type(self).__next__(self) - - 
callable = callable -_add_doc(get_unbound_function, - """Get the function out of a possibly unbound function""") - - -get_method_function = operator.attrgetter(_meth_func) -get_method_self = operator.attrgetter(_meth_self) -get_function_closure = operator.attrgetter(_func_closure) -get_function_code = operator.attrgetter(_func_code) -get_function_defaults = operator.attrgetter(_func_defaults) -get_function_globals = operator.attrgetter(_func_globals) - - -if PY3: - def iterkeys(d, **kw): - return iter(d.keys(**kw)) - - def itervalues(d, **kw): - return iter(d.values(**kw)) - - def iteritems(d, **kw): - return iter(d.items(**kw)) - - def iterlists(d, **kw): - return iter(d.lists(**kw)) - - viewkeys = operator.methodcaller("keys") - - viewvalues = operator.methodcaller("values") - - viewitems = operator.methodcaller("items") -else: - def iterkeys(d, **kw): - return d.iterkeys(**kw) - - def itervalues(d, **kw): - return d.itervalues(**kw) - - def iteritems(d, **kw): - return d.iteritems(**kw) - - def iterlists(d, **kw): - return d.iterlists(**kw) - - viewkeys = operator.methodcaller("viewkeys") - - viewvalues = operator.methodcaller("viewvalues") - - viewitems = operator.methodcaller("viewitems") - -_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.") -_add_doc(itervalues, "Return an iterator over the values of a dictionary.") -_add_doc(iteritems, - "Return an iterator over the (key, value) pairs of a dictionary.") -_add_doc(iterlists, - "Return an iterator over the (key, [values]) pairs of a dictionary.") - - -if PY3: - def b(s): - return s.encode("latin-1") - - def u(s): - return s - unichr = chr - import struct - int2byte = struct.Struct(">B").pack - del struct - byte2int = operator.itemgetter(0) - indexbytes = operator.getitem - iterbytes = iter - import io - StringIO = io.StringIO - BytesIO = io.BytesIO - _assertCountEqual = "assertCountEqual" - if sys.version_info[1] <= 1: - _assertRaisesRegex = "assertRaisesRegexp" - _assertRegex = "assertRegexpMatches" - else: - _assertRaisesRegex = "assertRaisesRegex" - _assertRegex = "assertRegex" -else: - def b(s): - return s - # Workaround for standalone backslash - - def u(s): - return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape") - unichr = unichr - int2byte = chr - - def byte2int(bs): - return ord(bs[0]) - - def indexbytes(buf, i): - return ord(buf[i]) - iterbytes = functools.partial(itertools.imap, ord) - import StringIO - StringIO = BytesIO = StringIO.StringIO - _assertCountEqual = "assertItemsEqual" - _assertRaisesRegex = "assertRaisesRegexp" - _assertRegex = "assertRegexpMatches" -_add_doc(b, """Byte literal""") -_add_doc(u, """Text literal""") - - -def assertCountEqual(self, *args, **kwargs): - return getattr(self, _assertCountEqual)(*args, **kwargs) - - -def assertRaisesRegex(self, *args, **kwargs): - return getattr(self, _assertRaisesRegex)(*args, **kwargs) - - -def assertRegex(self, *args, **kwargs): - return getattr(self, _assertRegex)(*args, **kwargs) - - -if PY3: - exec_ = getattr(moves.builtins, "exec") - - def reraise(tp, value, tb=None): - try: - if value is None: - value = tp() - if value.__traceback__ is not tb: - raise value.with_traceback(tb) - raise value - finally: - value = None - tb = None - -else: - def exec_(_code_, _globs_=None, _locs_=None): - """Execute code in a namespace.""" - if _globs_ is None: - frame = sys._getframe(1) - _globs_ = frame.f_globals - if _locs_ is None: - _locs_ = frame.f_locals - del frame - elif _locs_ is None: - _locs_ = _globs_ - exec("""exec _code_ in _globs_, 
_locs_""") - - exec_("""def reraise(tp, value, tb=None): - try: - raise tp, value, tb - finally: - tb = None -""") - - -if sys.version_info[:2] == (3, 2): - exec_("""def raise_from(value, from_value): - try: - if from_value is None: - raise value - raise value from from_value - finally: - value = None -""") -elif sys.version_info[:2] > (3, 2): - exec_("""def raise_from(value, from_value): - try: - raise value from from_value - finally: - value = None -""") -else: - def raise_from(value, from_value): - raise value - - -print_ = getattr(moves.builtins, "print", None) -if print_ is None: - def print_(*args, **kwargs): - """The new-style print function for Python 2.4 and 2.5.""" - fp = kwargs.pop("file", sys.stdout) - if fp is None: - return - - def write(data): - if not isinstance(data, basestring): - data = str(data) - # If the file has an encoding, encode unicode with it. - if (isinstance(fp, file) and - isinstance(data, unicode) and - fp.encoding is not None): - errors = getattr(fp, "errors", None) - if errors is None: - errors = "strict" - data = data.encode(fp.encoding, errors) - fp.write(data) - want_unicode = False - sep = kwargs.pop("sep", None) - if sep is not None: - if isinstance(sep, unicode): - want_unicode = True - elif not isinstance(sep, str): - raise TypeError("sep must be None or a string") - end = kwargs.pop("end", None) - if end is not None: - if isinstance(end, unicode): - want_unicode = True - elif not isinstance(end, str): - raise TypeError("end must be None or a string") - if kwargs: - raise TypeError("invalid keyword arguments to print()") - if not want_unicode: - for arg in args: - if isinstance(arg, unicode): - want_unicode = True - break - if want_unicode: - newline = unicode("\n") - space = unicode(" ") - else: - newline = "\n" - space = " " - if sep is None: - sep = space - if end is None: - end = newline - for i, arg in enumerate(args): - if i: - write(sep) - write(arg) - write(end) -if sys.version_info[:2] < (3, 3): - _print = print_ - - def print_(*args, **kwargs): - fp = kwargs.get("file", sys.stdout) - flush = kwargs.pop("flush", False) - _print(*args, **kwargs) - if flush and fp is not None: - fp.flush() - -_add_doc(reraise, """Reraise an exception.""") - -if sys.version_info[0:2] < (3, 4): - def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS, - updated=functools.WRAPPER_UPDATES): - def wrapper(f): - f = functools.wraps(wrapped, assigned, updated)(f) - f.__wrapped__ = wrapped - return f - return wrapper -else: - wraps = functools.wraps - - -def with_metaclass(meta, *bases): - """Create a base class with a metaclass.""" - # This requires a bit of explanation: the basic idea is to make a dummy - # metaclass for one level of class instantiation that replaces itself with - # the actual metaclass. 
- class metaclass(type): - - def __new__(cls, name, this_bases, d): - return meta(name, bases, d) - - @classmethod - def __prepare__(cls, name, this_bases): - return meta.__prepare__(name, bases) - return type.__new__(metaclass, 'temporary_class', (), {}) - - -def add_metaclass(metaclass): - """Class decorator for creating a class with a metaclass.""" - def wrapper(cls): - orig_vars = cls.__dict__.copy() - slots = orig_vars.get('__slots__') - if slots is not None: - if isinstance(slots, str): - slots = [slots] - for slots_var in slots: - orig_vars.pop(slots_var) - orig_vars.pop('__dict__', None) - orig_vars.pop('__weakref__', None) - if hasattr(cls, '__qualname__'): - orig_vars['__qualname__'] = cls.__qualname__ - return metaclass(cls.__name__, cls.__bases__, orig_vars) - return wrapper - - -def ensure_binary(s, encoding='utf-8', errors='strict'): - """Coerce **s** to six.binary_type. - - For Python 2: - - `unicode` -> encoded to `str` - - `str` -> `str` - - For Python 3: - - `str` -> encoded to `bytes` - - `bytes` -> `bytes` - """ - if isinstance(s, text_type): - return s.encode(encoding, errors) - elif isinstance(s, binary_type): - return s - else: - raise TypeError("not expecting type '%s'" % type(s)) - - -def ensure_str(s, encoding='utf-8', errors='strict'): - """Coerce *s* to `str`. - - For Python 2: - - `unicode` -> encoded to `str` - - `str` -> `str` - - For Python 3: - - `str` -> `str` - - `bytes` -> decoded to `str` - """ - if not isinstance(s, (text_type, binary_type)): - raise TypeError("not expecting type '%s'" % type(s)) - if PY2 and isinstance(s, text_type): - s = s.encode(encoding, errors) - elif PY3 and isinstance(s, binary_type): - s = s.decode(encoding, errors) - return s - - -def ensure_text(s, encoding='utf-8', errors='strict'): - """Coerce *s* to six.text_type. - - For Python 2: - - `unicode` -> `unicode` - - `str` -> `unicode` - - For Python 3: - - `str` -> `str` - - `bytes` -> decoded to `str` - """ - if isinstance(s, binary_type): - return s.decode(encoding, errors) - elif isinstance(s, text_type): - return s - else: - raise TypeError("not expecting type '%s'" % type(s)) - - - -def python_2_unicode_compatible(klass): - """ - A decorator that defines __unicode__ and __str__ methods under Python 2. - Under Python 3 it does nothing. - - To support Python 2 and 3 with a single code base, define a __str__ method - returning text and apply this decorator to the class. - """ - if PY2: - if '__str__' not in klass.__dict__: - raise ValueError("@python_2_unicode_compatible cannot be applied " - "to %s because it doesn't define __str__()." % - klass.__name__) - klass.__unicode__ = klass.__str__ - klass.__str__ = lambda self: self.__unicode__().encode('utf-8') - return klass - - -# Complete the moves implementation. -# This code is at the end of this module to speed up module loading. -# Turn this module into a package. -__path__ = [] # required for PEP 302 and PEP 451 -__package__ = __name__ # see PEP 366 @ReservedAssignment -if globals().get("__spec__") is not None: - __spec__.submodule_search_locations = [] # PEP 451 @UndefinedVariable -# Remove other six meta path importers, since they cause problems. This can -# happen if six is removed from sys.modules and then reloaded. (Setuptools does -# this for some reason.) -if sys.meta_path: - for i, importer in enumerate(sys.meta_path): - # Here's some real nastiness: Another "instance" of the six module might - # be floating around. 
Therefore, we can't use isinstance() to check for - # the six meta path importer, since the other six instance will have - # inserted an importer with different class. - if (type(importer).__name__ == "_SixMetaPathImporter" and - importer.name == __name__): - del sys.meta_path[i] - break - del i, importer -# Finally, add the importer to the meta path import hook. -sys.meta_path.append(_importer) diff --git a/libs/win/test_path.py b/libs/win/test_path.py deleted file mode 100644 index 2a7ddb8f..00000000 --- a/libs/win/test_path.py +++ /dev/null @@ -1,1258 +0,0 @@ -# -*- coding: utf-8 -*- - -""" -Tests for the path module. - -This suite runs on Linux, OS X, and Windows right now. To extend the -platform support, just add appropriate pathnames for your -platform (os.name) in each place where the p() function is called. -Then report the result. If you can't get the test to run at all on -your platform, there's probably a bug in path.py -- please report the issue -in the issue tracker at https://github.com/jaraco/path.py. - -TestScratchDir.test_touch() takes a while to run. It sleeps a few -seconds to allow some time to pass between calls to check the modify -time on files. -""" - -from __future__ import unicode_literals, absolute_import, print_function - -import codecs -import os -import sys -import shutil -import time -import types -import ntpath -import posixpath -import textwrap -import platform -import importlib -import operator -import datetime -import subprocess -import re - -import pytest -import packaging.version - -import path -from path import TempDir -from path import matchers -from path import SpecialResolver -from path import Multi - -Path = None - - -def p(**choices): - """ Choose a value from several possible values, based on os.name """ - return choices[os.name] - - -@pytest.fixture(autouse=True, params=[path.Path]) -def path_class(request, monkeypatch): - """ - Invoke tests on any number of Path classes. - """ - monkeypatch.setitem(globals(), 'Path', request.param) - - -def mac_version(target, comparator=operator.ge): - """ - Return True if on a Mac whose version passes the comparator. - """ - current_ver = packaging.version.parse(platform.mac_ver()[0]) - target_ver = packaging.version.parse(target) - return ( - platform.system() == 'Darwin' - and comparator(current_ver, target_ver) - ) - - -class TestBasics: - def test_relpath(self): - root = Path(p(nt='C:\\', posix='/')) - foo = root / 'foo' - quux = foo / 'quux' - bar = foo / 'bar' - boz = bar / 'Baz' / 'Boz' - up = Path(os.pardir) - - # basics - assert root.relpathto(boz) == Path('foo') / 'bar' / 'Baz' / 'Boz' - assert bar.relpathto(boz) == Path('Baz') / 'Boz' - assert quux.relpathto(boz) == up / 'bar' / 'Baz' / 'Boz' - assert boz.relpathto(quux) == up / up / up / 'quux' - assert boz.relpathto(bar) == up / up - - # Path is not the first element in concatenation - assert root.relpathto(boz) == 'foo' / Path('bar') / 'Baz' / 'Boz' - - # x.relpathto(x) == curdir - assert root.relpathto(root) == os.curdir - assert boz.relpathto(boz) == os.curdir - # Make sure case is properly noted (or ignored) - assert boz.relpathto(boz.normcase()) == os.curdir - - # relpath() - cwd = Path(os.getcwd()) - assert boz.relpath() == cwd.relpathto(boz) - - if os.name == 'nt': - # Check relpath across drives. 
- d = Path('D:\\') - assert d.relpathto(boz) == boz - - def test_construction_from_none(self): - """ - - """ - try: - Path(None) - except TypeError: - pass - else: - raise Exception("DID NOT RAISE") - - def test_construction_from_int(self): - """ - Path class will construct a path as a string of the number - """ - assert Path(1) == '1' - - def test_string_compatibility(self): - """ Test compatibility with ordinary strings. """ - x = Path('xyzzy') - assert x == 'xyzzy' - assert x == str('xyzzy') - - # sorting - items = [Path('fhj'), - Path('fgh'), - 'E', - Path('d'), - 'A', - Path('B'), - 'c'] - items.sort() - assert items == ['A', 'B', 'E', 'c', 'd', 'fgh', 'fhj'] - - # Test p1/p1. - p1 = Path("foo") - p2 = Path("bar") - assert p1 / p2 == p(nt='foo\\bar', posix='foo/bar') - - def test_properties(self): - # Create sample path object. - f = p(nt='C:\\Program Files\\Python\\Lib\\xyzzy.py', - posix='/usr/local/python/lib/xyzzy.py') - f = Path(f) - - # .parent - nt_lib = 'C:\\Program Files\\Python\\Lib' - posix_lib = '/usr/local/python/lib' - expected = p(nt=nt_lib, posix=posix_lib) - assert f.parent == expected - - # .name - assert f.name == 'xyzzy.py' - assert f.parent.name == p(nt='Lib', posix='lib') - - # .ext - assert f.ext == '.py' - assert f.parent.ext == '' - - # .drive - assert f.drive == p(nt='C:', posix='') - - def test_methods(self): - # .abspath() - assert Path(os.curdir).abspath() == os.getcwd() - - # .getcwd() - cwd = Path.getcwd() - assert isinstance(cwd, Path) - assert cwd == os.getcwd() - - def test_UNC(self): - if hasattr(os.path, 'splitunc'): - p = Path(r'\\python1\share1\dir1\file1.txt') - assert p.uncshare == r'\\python1\share1' - assert p.splitunc() == os.path.splitunc(str(p)) - - def test_explicit_module(self): - """ - The user may specify an explicit path module to use. - """ - nt_ok = Path.using_module(ntpath)(r'foo\bar\baz') - posix_ok = Path.using_module(posixpath)(r'foo/bar/baz') - posix_wrong = Path.using_module(posixpath)(r'foo\bar\baz') - - assert nt_ok.dirname() == r'foo\bar' - assert posix_ok.dirname() == r'foo/bar' - assert posix_wrong.dirname() == '' - - assert nt_ok / 'quux' == r'foo\bar\baz\quux' - assert posix_ok / 'quux' == r'foo/bar/baz/quux' - - def test_explicit_module_classes(self): - """ - Multiple calls to path.using_module should produce the same class. - """ - nt_path = Path.using_module(ntpath) - assert nt_path is Path.using_module(ntpath) - assert nt_path.__name__ == 'Path_ntpath' - - def test_joinpath_on_instance(self): - res = Path('foo') - foo_bar = res.joinpath('bar') - assert foo_bar == p(nt='foo\\bar', posix='foo/bar') - - def test_joinpath_to_nothing(self): - res = Path('foo') - assert res.joinpath() == res - - def test_joinpath_on_class(self): - "Construct a path from a series of strings" - foo_bar = Path.joinpath('foo', 'bar') - assert foo_bar == p(nt='foo\\bar', posix='foo/bar') - - def test_joinpath_fails_on_empty(self): - "It doesn't make sense to join nothing at all" - try: - Path.joinpath() - except TypeError: - pass - else: - raise Exception("did not raise") - - def test_joinpath_returns_same_type(self): - path_posix = Path.using_module(posixpath) - res = path_posix.joinpath('foo') - assert isinstance(res, path_posix) - res2 = res.joinpath('bar') - assert isinstance(res2, path_posix) - assert res2 == 'foo/bar' - - -class TestPerformance: - @pytest.mark.skipif( - path.PY2, - reason="Tests fail frequently on Python 2; see #153") - def test_import_time(self, monkeypatch): - """ - Import of path.py should take less than 100ms. 
- - Run tests in a subprocess to isolate from test suite overhead. - """ - cmd = [ - sys.executable, - '-m', 'timeit', - '-n', '1', - '-r', '1', - 'import path', - ] - res = subprocess.check_output(cmd, universal_newlines=True) - dur = re.search(r'(\d+) msec per loop', res).group(1) - limit = datetime.timedelta(milliseconds=100) - duration = datetime.timedelta(milliseconds=int(dur)) - assert duration < limit - - -class TestSelfReturn: - """ - Some methods don't necessarily return any value (e.g. makedirs, - makedirs_p, rename, mkdir, touch, chroot). These methods should return - self anyhow to allow methods to be chained. - """ - def test_makedirs_p(self, tmpdir): - """ - Path('foo').makedirs_p() == Path('foo') - """ - p = Path(tmpdir) / "newpath" - ret = p.makedirs_p() - assert p == ret - - def test_makedirs_p_extant(self, tmpdir): - p = Path(tmpdir) - ret = p.makedirs_p() - assert p == ret - - def test_rename(self, tmpdir): - p = Path(tmpdir) / "somefile" - p.touch() - target = Path(tmpdir) / "otherfile" - ret = p.rename(target) - assert target == ret - - def test_mkdir(self, tmpdir): - p = Path(tmpdir) / "newdir" - ret = p.mkdir() - assert p == ret - - def test_touch(self, tmpdir): - p = Path(tmpdir) / "empty file" - ret = p.touch() - assert p == ret - - -class TestScratchDir: - """ - Tests that run in a temporary directory (does not test TempDir class) - """ - def test_context_manager(self, tmpdir): - """Can be used as context manager for chdir.""" - d = Path(tmpdir) - subdir = d / 'subdir' - subdir.makedirs() - old_dir = os.getcwd() - with subdir: - assert os.getcwd() == os.path.realpath(subdir) - assert os.getcwd() == old_dir - - def test_touch(self, tmpdir): - # NOTE: This test takes a long time to run (~10 seconds). - # It sleeps several seconds because on Windows, the resolution - # of a file's mtime and ctime is about 2 seconds. - # - # atime isn't tested because on Windows the resolution of atime - # is something like 24 hours. - - threshold = 1 - - d = Path(tmpdir) - f = d / 'test.txt' - t0 = time.time() - threshold - f.touch() - t1 = time.time() + threshold - - assert f.exists() - assert f.isfile() - assert f.size == 0 - assert t0 <= f.mtime <= t1 - if hasattr(os.path, 'getctime'): - ct = f.ctime - assert t0 <= ct <= t1 - - time.sleep(threshold * 2) - fobj = open(f, 'ab') - fobj.write('some bytes'.encode('utf-8')) - fobj.close() - - time.sleep(threshold * 2) - t2 = time.time() - threshold - f.touch() - t3 = time.time() + threshold - - assert t0 <= t1 < t2 <= t3 # sanity check - - assert f.exists() - assert f.isfile() - assert f.size == 10 - assert t2 <= f.mtime <= t3 - if hasattr(os.path, 'getctime'): - ct2 = f.ctime - if os.name == 'nt': - # On Windows, "ctime" is CREATION time - assert ct == ct2 - assert ct2 < t2 - else: - assert ( - # ctime is unchanged - ct == ct2 or - # ctime is approximately the mtime - ct2 == pytest.approx(f.mtime, 0.001) - ) - - def test_listing(self, tmpdir): - d = Path(tmpdir) - assert d.listdir() == [] - - f = 'testfile.txt' - af = d / f - assert af == os.path.join(d, f) - af.touch() - try: - assert af.exists() - - assert d.listdir() == [af] - - # .glob() - assert d.glob('testfile.txt') == [af] - assert d.glob('test*.txt') == [af] - assert d.glob('*.txt') == [af] - assert d.glob('*txt') == [af] - assert d.glob('*') == [af] - assert d.glob('*.html') == [] - assert d.glob('testfile') == [] - - # .iglob matches .glob but as an iterator. 
- assert list(d.iglob('*')) == d.glob('*') - assert isinstance(d.iglob('*'), types.GeneratorType) - - finally: - af.remove() - - # Try a test with 20 files - files = [d / ('%d.txt' % i) for i in range(20)] - for f in files: - fobj = open(f, 'w') - fobj.write('some text\n') - fobj.close() - try: - files2 = d.listdir() - files.sort() - files2.sort() - assert files == files2 - finally: - for f in files: - try: - f.remove() - except Exception: - pass - - @pytest.mark.xfail( - mac_version('10.13'), - reason="macOS disallows invalid encodings", - ) - @pytest.mark.xfail( - platform.system() == 'Windows' and path.PY3, - reason="Can't write latin characters. See #133", - ) - def test_listdir_other_encoding(self, tmpdir): - """ - Some filesystems allow non-character sequences in path names. - ``.listdir`` should still function in this case. - See issue #61 for details. - """ - assert Path(tmpdir).listdir() == [] - tmpdir_bytes = str(tmpdir).encode('ascii') - - filename = 'r\xe9\xf1emi'.encode('latin-1') - pathname = os.path.join(tmpdir_bytes, filename) - with open(pathname, 'wb'): - pass - # first demonstrate that os.listdir works - assert os.listdir(tmpdir_bytes) - - # now try with path.py - results = Path(tmpdir).listdir() - assert len(results) == 1 - res, = results - assert isinstance(res, Path) - # OS X seems to encode the bytes in the filename as %XX characters. - if platform.system() == 'Darwin': - assert res.basename() == 'r%E9%F1emi' - return - assert len(res.basename()) == len(filename) - - def test_makedirs(self, tmpdir): - d = Path(tmpdir) - - # Placeholder file so that when removedirs() is called, - # it doesn't remove the temporary directory itself. - tempf = d / 'temp.txt' - tempf.touch() - try: - foo = d / 'foo' - boz = foo / 'bar' / 'baz' / 'boz' - boz.makedirs() - try: - assert boz.isdir() - finally: - boz.removedirs() - assert not foo.exists() - assert d.exists() - - foo.mkdir(0o750) - boz.makedirs(0o700) - try: - assert boz.isdir() - finally: - boz.removedirs() - assert not foo.exists() - assert d.exists() - finally: - os.remove(tempf) - - def assertSetsEqual(self, a, b): - ad = {} - - for i in a: - ad[i] = None - - bd = {} - - for i in b: - bd[i] = None - - assert ad == bd - - def test_shutil(self, tmpdir): - # Note: This only tests the methods exist and do roughly what - # they should, neglecting the details as they are shutil's - # responsibility. - - d = Path(tmpdir) - testDir = d / 'testdir' - testFile = testDir / 'testfile.txt' - testA = testDir / 'A' - testCopy = testA / 'testcopy.txt' - testLink = testA / 'testlink.txt' - testB = testDir / 'B' - testC = testB / 'C' - testCopyOfLink = testC / testA.relpathto(testLink) - - # Create test dirs and a file - testDir.mkdir() - testA.mkdir() - testB.mkdir() - - f = open(testFile, 'w') - f.write('x' * 10000) - f.close() - - # Test simple file copying. - testFile.copyfile(testCopy) - assert testCopy.isfile() - assert testFile.bytes() == testCopy.bytes() - - # Test copying into a directory. - testCopy2 = testA / testFile.name - testFile.copy(testA) - assert testCopy2.isfile() - assert testFile.bytes() == testCopy2.bytes() - - # Make a link for the next test to use. - if hasattr(os, 'symlink'): - testFile.symlink(testLink) - else: - testFile.copy(testLink) # fallback - - # Test copying directory tree. - testA.copytree(testC) - assert testC.isdir() - self.assertSetsEqual( - testC.listdir(), - [testC / testCopy.name, - testC / testFile.name, - testCopyOfLink]) - assert not testCopyOfLink.islink() - - # Clean up for another try. 
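-        # Editor's note: the second copytree() below passes symlinks=True,
-        # which should reproduce testLink as a link; the plain copytree()
-        # above dereferenced it, hence the `not islink()` assertion there.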
- testC.rmtree() - assert not testC.exists() - - # Copy again, preserving symlinks. - testA.copytree(testC, True) - assert testC.isdir() - self.assertSetsEqual( - testC.listdir(), - [testC / testCopy.name, - testC / testFile.name, - testCopyOfLink]) - if hasattr(os, 'symlink'): - assert testCopyOfLink.islink() - assert testCopyOfLink.readlink() == testFile - - # Clean up. - testDir.rmtree() - assert not testDir.exists() - self.assertList(d.listdir(), []) - - def assertList(self, listing, expected): - assert sorted(listing) == sorted(expected) - - def test_patterns(self, tmpdir): - d = Path(tmpdir) - names = ['x.tmp', 'x.xtmp', 'x2g', 'x22', 'x.txt'] - dirs = [d, d / 'xdir', d / 'xdir.tmp', d / 'xdir.tmp' / 'xsubdir'] - - for e in dirs: - if not e.isdir(): - e.makedirs() - - for name in names: - (e / name).touch() - self.assertList(d.listdir('*.tmp'), [d / 'x.tmp', d / 'xdir.tmp']) - self.assertList(d.files('*.tmp'), [d / 'x.tmp']) - self.assertList(d.dirs('*.tmp'), [d / 'xdir.tmp']) - self.assertList(d.walk(), [e for e in dirs - if e != d] + [e / n for e in dirs - for n in names]) - self.assertList(d.walk('*.tmp'), - [e / 'x.tmp' for e in dirs] + [d / 'xdir.tmp']) - self.assertList(d.walkfiles('*.tmp'), [e / 'x.tmp' for e in dirs]) - self.assertList(d.walkdirs('*.tmp'), [d / 'xdir.tmp']) - - def test_unicode(self, tmpdir): - d = Path(tmpdir) - p = d / 'unicode.txt' - - def test(enc): - """ Test that path works with the specified encoding, - which must be capable of representing the entire range of - Unicode codepoints. - """ - - given = ( - 'Hello world\n' - '\u0d0a\u0a0d\u0d15\u0a15\r\n' - '\u0d0a\u0a0d\u0d15\u0a15\x85' - '\u0d0a\u0a0d\u0d15\u0a15\u2028' - '\r' - 'hanging' - ) - clean = ( - 'Hello world\n' - '\u0d0a\u0a0d\u0d15\u0a15\n' - '\u0d0a\u0a0d\u0d15\u0a15\n' - '\u0d0a\u0a0d\u0d15\u0a15\n' - '\n' - 'hanging' - ) - givenLines = [ - ('Hello world\n'), - ('\u0d0a\u0a0d\u0d15\u0a15\r\n'), - ('\u0d0a\u0a0d\u0d15\u0a15\x85'), - ('\u0d0a\u0a0d\u0d15\u0a15\u2028'), - ('\r'), - ('hanging')] - expectedLines = [ - ('Hello world\n'), - ('\u0d0a\u0a0d\u0d15\u0a15\n'), - ('\u0d0a\u0a0d\u0d15\u0a15\n'), - ('\u0d0a\u0a0d\u0d15\u0a15\n'), - ('\n'), - ('hanging')] - expectedLines2 = [ - ('Hello world'), - ('\u0d0a\u0a0d\u0d15\u0a15'), - ('\u0d0a\u0a0d\u0d15\u0a15'), - ('\u0d0a\u0a0d\u0d15\u0a15'), - (''), - ('hanging')] - - # write bytes manually to file - f = codecs.open(p, 'w', enc) - f.write(given) - f.close() - - # test all 3 path read-fully functions, including - # path.lines() in unicode mode. - assert p.bytes() == given.encode(enc) - assert p.text(enc) == clean - assert p.lines(enc) == expectedLines - assert p.lines(enc, retain=False) == expectedLines2 - - # If this is UTF-16, that's enough. - # The rest of these will unfortunately fail because append=True - # mode causes an extra BOM to be written in the middle of the file. - # UTF-16 is the only encoding that has this problem. - if enc == 'UTF-16': - return - - # Write Unicode to file using path.write_text(). - # This test doesn't work with a hanging line. - cleanNoHanging = clean + '\n' - - p.write_text(cleanNoHanging, enc) - p.write_text(cleanNoHanging, enc, append=True) - # Check the result. 
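-            # Editor's note: the file was written twice (once fresh, once
-            # with append=True), hence the factor of two below; write_text
-            # converts '\n' to the default linesep (os.linesep), hence the
-            # replace().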
-            expectedBytes = 2 * cleanNoHanging.replace('\n',
-                                                       os.linesep).encode(enc)
-            expectedLinesNoHanging = expectedLines[:]
-            expectedLinesNoHanging[-1] += '\n'
-            assert p.bytes() == expectedBytes
-            assert p.text(enc) == 2 * cleanNoHanging
-            assert p.lines(enc) == 2 * expectedLinesNoHanging
-            assert p.lines(enc, retain=False) == 2 * expectedLines2
-
-            # Write Unicode to file using path.write_lines().
-            # The output in the file should be exactly the same as last time.
-            p.write_lines(expectedLines, enc)
-            p.write_lines(expectedLines2, enc, append=True)
-            # Check the result.
-            assert p.bytes() == expectedBytes
-
-            # Now: same test, but using various newline sequences.
-            # If linesep is being properly applied, these will be converted
-            # to the platform standard newline sequence.
-            p.write_lines(givenLines, enc)
-            p.write_lines(givenLines, enc, append=True)
-            # Check the result.
-            assert p.bytes() == expectedBytes
-
-            # Same test, using newline sequences that are different
-            # from the platform default.
-            def testLinesep(eol):
-                p.write_lines(givenLines, enc, linesep=eol)
-                p.write_lines(givenLines, enc, linesep=eol, append=True)
-                expected = 2 * cleanNoHanging.replace('\n', eol).encode(enc)
-                assert p.bytes() == expected
-
-            testLinesep('\n')
-            testLinesep('\r')
-            testLinesep('\r\n')
-            testLinesep('\x0d\x85')
-
-            # Again, but with linesep=None.
-            p.write_lines(givenLines, enc, linesep=None)
-            p.write_lines(givenLines, enc, linesep=None, append=True)
-            # Check the result.
-            expectedBytes = 2 * given.encode(enc)
-            assert p.bytes() == expectedBytes
-            assert p.text(enc) == 2 * clean
-            expectedResultLines = expectedLines[:]
-            expectedResultLines[-1] += expectedLines[0]
-            expectedResultLines += expectedLines[1:]
-            assert p.lines(enc) == expectedResultLines
-
-        test('UTF-8')
-        test('UTF-16BE')
-        test('UTF-16LE')
-        test('UTF-16')
-
-    def test_chunks(self, tmpdir):
-        p = (TempDir() / 'test.txt').touch()
-        txt = "0123456789"
-        size = 5
-        p.write_text(txt)
-        for i, chunk in enumerate(p.chunks(size)):
-            assert chunk == txt[i * size:i * size + size]
-
-        # Integer division avoids a fragile float comparison here.
-        assert i == len(txt) // size - 1
-
-    @pytest.mark.skipif(
-        not hasattr(os.path, 'samefile'),
-        reason="samefile not present",
-    )
-    def test_samefile(self, tmpdir):
-        f1 = (TempDir() / '1.txt').touch()
-        f1.write_text('foo')
-        f2 = (TempDir() / '2.txt').touch()
-        f2.write_text('foo')
-        f3 = (TempDir() / '3.txt').touch()
-        f3.write_text('bar')
-        f4 = (TempDir() / '4.txt')
-        f1.copyfile(f4)
-
-        assert os.path.samefile(f1, f2) == f1.samefile(f2)
-        assert os.path.samefile(f1, f3) == f1.samefile(f3)
-        assert os.path.samefile(f1, f4) == f1.samefile(f4)
-        assert os.path.samefile(f1, f1) == f1.samefile(f1)
-
-    def test_rmtree_p(self, tmpdir):
-        d = Path(tmpdir)
-        sub = d / 'subfolder'
-        sub.mkdir()
-        (sub / 'afile').write_text('something')
-        sub.rmtree_p()
-        assert not sub.exists()
-        try:
-            sub.rmtree_p()
-        except OSError:
-            # pytest-style classes have no self.fail; use pytest.fail.
-            pytest.fail("Calling `rmtree_p` on non-existent directory "
-                        "should not raise an exception.")
-
-    def test_rmdir_p_exists(self, tmpdir):
-        """
-        Invocation of rmdir_p on an existent directory should
-        remove the directory.
-        """
-        d = Path(tmpdir)
-        sub = d / 'subfolder'
-        sub.mkdir()
-        sub.rmdir_p()
-        assert not sub.exists()
-
-    def test_rmdir_p_nonexistent(self, tmpdir):
-        """
-        rmdir_p on a non-existent directory should not raise an exception.
- """ - d = Path(tmpdir) - sub = d / 'subfolder' - assert not sub.exists() - sub.rmdir_p() - - -class TestMergeTree: - @pytest.fixture(autouse=True) - def testing_structure(self, tmpdir): - self.test_dir = Path(tmpdir) - self.subdir_a = self.test_dir / 'A' - self.test_file = self.subdir_a / 'testfile.txt' - self.test_link = self.subdir_a / 'testlink.txt' - self.subdir_b = self.test_dir / 'B' - - self.subdir_a.mkdir() - self.subdir_b.mkdir() - - with open(self.test_file, 'w') as f: - f.write('x' * 10000) - - if hasattr(os, 'symlink'): - self.test_file.symlink(self.test_link) - else: - self.test_file.copy(self.test_link) - - def check_link(self): - target = Path(self.subdir_b / self.test_link.name) - check = target.islink if hasattr(os, 'symlink') else target.isfile - assert check() - - def test_with_nonexisting_dst_kwargs(self): - self.subdir_a.merge_tree(self.subdir_b, symlinks=True) - assert self.subdir_b.isdir() - expected = set(( - self.subdir_b / self.test_file.name, - self.subdir_b / self.test_link.name, - )) - assert set(self.subdir_b.listdir()) == expected - self.check_link() - - def test_with_nonexisting_dst_args(self): - self.subdir_a.merge_tree(self.subdir_b, True) - assert self.subdir_b.isdir() - expected = set(( - self.subdir_b / self.test_file.name, - self.subdir_b / self.test_link.name, - )) - assert set(self.subdir_b.listdir()) == expected - self.check_link() - - def test_with_existing_dst(self): - self.subdir_b.rmtree() - self.subdir_a.copytree(self.subdir_b, True) - - self.test_link.remove() - test_new = self.subdir_a / 'newfile.txt' - test_new.touch() - with open(self.test_file, 'w') as f: - f.write('x' * 5000) - - self.subdir_a.merge_tree(self.subdir_b, True) - - assert self.subdir_b.isdir() - expected = set(( - self.subdir_b / self.test_file.name, - self.subdir_b / self.test_link.name, - self.subdir_b / test_new.name, - )) - assert set(self.subdir_b.listdir()) == expected - self.check_link() - assert len(Path(self.subdir_b / self.test_file.name).bytes()) == 5000 - - def test_copytree_parameters(self): - """ - merge_tree should accept parameters to copytree, such as 'ignore' - """ - ignore = shutil.ignore_patterns('testlink*') - self.subdir_a.merge_tree(self.subdir_b, ignore=ignore) - - assert self.subdir_b.isdir() - assert self.subdir_b.listdir() == [self.subdir_b / self.test_file.name] - - def test_only_newer(self): - """ - merge_tree should accept a copy_function in which only - newer files are copied and older files do not overwrite - newer copies in the dest. 
- """ - target = self.subdir_b / 'testfile.txt' - target.write_text('this is newer') - self.subdir_a.merge_tree( - self.subdir_b, - copy_function=path.only_newer(shutil.copy2), - ) - assert target.text() == 'this is newer' - - -class TestChdir: - def test_chdir_or_cd(self, tmpdir): - """ tests the chdir or cd method """ - d = Path(str(tmpdir)) - cwd = d.getcwd() - - # ensure the cwd isn't our tempdir - assert str(d) != str(cwd) - # now, we're going to chdir to tempdir - d.chdir() - - # we now ensure that our cwd is the tempdir - assert str(d.getcwd()) == str(tmpdir) - # we're resetting our path - d = Path(cwd) - - # we ensure that our cwd is still set to tempdir - assert str(d.getcwd()) == str(tmpdir) - - # we're calling the alias cd method - d.cd() - # now, we ensure cwd isn'r tempdir - assert str(d.getcwd()) == str(cwd) - assert str(d.getcwd()) != str(tmpdir) - - -class TestSubclass: - - def test_subclass_produces_same_class(self): - """ - When operations are invoked on a subclass, they should produce another - instance of that subclass. - """ - class PathSubclass(Path): - pass - p = PathSubclass('/foo') - subdir = p / 'bar' - assert isinstance(subdir, PathSubclass) - - -class TestTempDir: - - def test_constructor(self): - """ - One should be able to readily construct a temporary directory - """ - d = TempDir() - assert isinstance(d, path.Path) - assert d.exists() - assert d.isdir() - d.rmdir() - assert not d.exists() - - def test_next_class(self): - """ - It should be possible to invoke operations on a TempDir and get - Path classes. - """ - d = TempDir() - sub = d / 'subdir' - assert isinstance(sub, path.Path) - d.rmdir() - - def test_context_manager(self): - """ - One should be able to use a TempDir object as a context, which will - clean up the contents after. - """ - d = TempDir() - res = d.__enter__() - assert res == path.Path(d) - (d / 'somefile.txt').touch() - assert not isinstance(d / 'somefile.txt', TempDir) - d.__exit__(None, None, None) - assert not d.exists() - - def test_context_manager_exception(self): - """ - The context manager will not clean up if an exception occurs. - """ - d = TempDir() - d.__enter__() - (d / 'somefile.txt').touch() - assert not isinstance(d / 'somefile.txt', TempDir) - d.__exit__(TypeError, TypeError('foo'), None) - assert d.exists() - - def test_context_manager_using_with(self): - """ - The context manager will allow using the with keyword and - provide a temporry directory that will be deleted after that. 
- """ - - with TempDir() as d: - assert d.isdir() - assert not d.isdir() - - -class TestUnicode: - @pytest.fixture(autouse=True) - def unicode_name_in_tmpdir(self, tmpdir): - # build a snowman (dir) in the temporary directory - Path(tmpdir).joinpath('☃').mkdir() - - def test_walkdirs_with_unicode_name(self, tmpdir): - for res in Path(tmpdir).walkdirs(): - pass - - -class TestPatternMatching: - def test_fnmatch_simple(self): - p = Path('FooBar') - assert p.fnmatch('Foo*') - assert p.fnmatch('Foo[ABC]ar') - - def test_fnmatch_custom_mod(self): - p = Path('FooBar') - p.module = ntpath - assert p.fnmatch('foobar') - assert p.fnmatch('FOO[ABC]AR') - - def test_fnmatch_custom_normcase(self): - def normcase(path): - return path.upper() - p = Path('FooBar') - assert p.fnmatch('foobar', normcase=normcase) - assert p.fnmatch('FOO[ABC]AR', normcase=normcase) - - def test_listdir_simple(self): - p = Path('.') - assert len(p.listdir()) == len(os.listdir('.')) - - def test_listdir_empty_pattern(self): - p = Path('.') - assert p.listdir('') == [] - - def test_listdir_patterns(self, tmpdir): - p = Path(tmpdir) - (p / 'sub').mkdir() - (p / 'File').touch() - assert p.listdir('s*') == [p / 'sub'] - assert len(p.listdir('*')) == 2 - - def test_listdir_custom_module(self, tmpdir): - """ - Listdir patterns should honor the case sensitivity of the path module - used by that Path class. - """ - always_unix = Path.using_module(posixpath) - p = always_unix(tmpdir) - (p / 'sub').mkdir() - (p / 'File').touch() - assert p.listdir('S*') == [] - - always_win = Path.using_module(ntpath) - p = always_win(tmpdir) - assert p.listdir('S*') == [p / 'sub'] - assert p.listdir('f*') == [p / 'File'] - - def test_listdir_case_insensitive(self, tmpdir): - """ - Listdir patterns should honor the case sensitivity of the path module - used by that Path class. - """ - p = Path(tmpdir) - (p / 'sub').mkdir() - (p / 'File').touch() - assert p.listdir(matchers.CaseInsensitive('S*')) == [p / 'sub'] - assert p.listdir(matchers.CaseInsensitive('f*')) == [p / 'File'] - assert p.files(matchers.CaseInsensitive('S*')) == [] - assert p.dirs(matchers.CaseInsensitive('f*')) == [] - - def test_walk_case_insensitive(self, tmpdir): - p = Path(tmpdir) - (p / 'sub1' / 'foo').makedirs_p() - (p / 'sub2' / 'foo').makedirs_p() - (p / 'sub1' / 'foo' / 'bar.Txt').touch() - (p / 'sub2' / 'foo' / 'bar.TXT').touch() - (p / 'sub2' / 'foo' / 'bar.txt.bz2').touch() - files = list(p.walkfiles(matchers.CaseInsensitive('*.txt'))) - assert len(files) == 2 - assert p / 'sub2' / 'foo' / 'bar.TXT' in files - assert p / 'sub1' / 'foo' / 'bar.Txt' in files - - -@pytest.mark.skipif( - sys.version_info < (2, 6), - reason="in_place requires io module in Python 2.6", -) -class TestInPlace: - reference_content = textwrap.dedent(""" - The quick brown fox jumped over the lazy dog. - """.lstrip()) - reversed_content = textwrap.dedent(""" - .god yzal eht revo depmuj xof nworb kciuq ehT - """.lstrip()) - alternate_content = textwrap.dedent(""" - Lorem ipsum dolor sit amet, consectetur adipisicing elit, - sed do eiusmod tempor incididunt ut labore et dolore magna - aliqua. Ut enim ad minim veniam, quis nostrud exercitation - ullamco laboris nisi ut aliquip ex ea commodo consequat. - Duis aute irure dolor in reprehenderit in voluptate velit - esse cillum dolore eu fugiat nulla pariatur. Excepteur - sint occaecat cupidatat non proident, sunt in culpa qui - officia deserunt mollit anim id est laborum. 
- """.lstrip()) - - @classmethod - def create_reference(cls, tmpdir): - p = Path(tmpdir) / 'document' - with p.open('w') as stream: - stream.write(cls.reference_content) - return p - - def test_line_by_line_rewrite(self, tmpdir): - doc = self.create_reference(tmpdir) - # reverse all the text in the document, line by line - with doc.in_place() as (reader, writer): - for line in reader: - r_line = ''.join(reversed(line.strip())) + '\n' - writer.write(r_line) - with doc.open() as stream: - data = stream.read() - assert data == self.reversed_content - - def test_exception_in_context(self, tmpdir): - doc = self.create_reference(tmpdir) - with pytest.raises(RuntimeError) as exc: - with doc.in_place() as (reader, writer): - writer.write(self.alternate_content) - raise RuntimeError("some error") - assert "some error" in str(exc) - with doc.open() as stream: - data = stream.read() - assert 'Lorem' not in data - assert 'lazy dog' in data - - -class TestSpecialPaths: - @pytest.fixture(autouse=True, scope='class') - def appdirs_installed(cls): - pytest.importorskip('appdirs') - - @pytest.fixture - def feign_linux(self, monkeypatch): - monkeypatch.setattr("platform.system", lambda: "Linux") - monkeypatch.setattr("sys.platform", "linux") - monkeypatch.setattr("os.pathsep", ":") - # remove any existing import of appdirs, as it sets up some - # state during import. - sys.modules.pop('appdirs') - - def test_basic_paths(self): - appdirs = importlib.import_module('appdirs') - - expected = appdirs.user_config_dir() - assert SpecialResolver(Path).user.config == expected - - expected = appdirs.site_config_dir() - assert SpecialResolver(Path).site.config == expected - - expected = appdirs.user_config_dir('My App', 'Me') - assert SpecialResolver(Path, 'My App', 'Me').user.config == expected - - def test_unix_paths(self, tmpdir, monkeypatch, feign_linux): - fake_config = tmpdir / '_config' - monkeypatch.setitem(os.environ, 'XDG_CONFIG_HOME', str(fake_config)) - expected = str(tmpdir / '_config') - assert SpecialResolver(Path).user.config == expected - - def test_unix_paths_fallback(self, tmpdir, monkeypatch, feign_linux): - "Without XDG_CONFIG_HOME set, ~/.config should be used." - fake_home = tmpdir / '_home' - monkeypatch.delitem(os.environ, 'XDG_CONFIG_HOME', raising=False) - monkeypatch.setitem(os.environ, 'HOME', str(fake_home)) - expected = Path('~/.config').expanduser() - assert SpecialResolver(Path).user.config == expected - - def test_property(self): - assert isinstance(Path.special().user.config, Path) - assert isinstance(Path.special().user.data, Path) - assert isinstance(Path.special().user.cache, Path) - - def test_other_parameters(self): - """ - Other parameters should be passed through to appdirs function. - """ - res = Path.special(version="1.0", multipath=True).site.config - assert isinstance(res, Path) - - def test_multipath(self, feign_linux, monkeypatch, tmpdir): - """ - If multipath is provided, on Linux return the XDG_CONFIG_DIRS - """ - fake_config_1 = str(tmpdir / '_config1') - fake_config_2 = str(tmpdir / '_config2') - config_dirs = os.pathsep.join([fake_config_1, fake_config_2]) - monkeypatch.setitem(os.environ, 'XDG_CONFIG_DIRS', config_dirs) - res = Path.special(multipath=True).site.config - assert isinstance(res, Multi) - assert fake_config_1 in res - assert fake_config_2 in res - assert '_config1' in str(res) - - def test_reused_SpecialResolver(self): - """ - Passing additional args and kwargs to SpecialResolver should be - passed through to each invocation of the function in appdirs. 
- """ - appdirs = importlib.import_module('appdirs') - - adp = SpecialResolver(Path, version="1.0") - res = adp.user.config - - expected = appdirs.user_config_dir(version="1.0") - assert res == expected - - -class TestMultiPath: - def test_for_class(self): - """ - Multi.for_class should return a subclass of the Path class provided. - """ - cls = Multi.for_class(Path) - assert issubclass(cls, Path) - assert issubclass(cls, Multi) - expected_name = 'Multi' + Path.__name__ - assert cls.__name__ == expected_name - - def test_detect_no_pathsep(self): - """ - If no pathsep is provided, multipath detect should return an instance - of the parent class with no Multi mix-in. - """ - path = Multi.for_class(Path).detect('/foo/bar') - assert isinstance(path, Path) - assert not isinstance(path, Multi) - - def test_detect_with_pathsep(self): - """ - If a pathsep appears in the input, detect should return an instance - of a Path with the Multi mix-in. - """ - inputs = '/foo/bar', '/baz/bing' - input = os.pathsep.join(inputs) - path = Multi.for_class(Path).detect(input) - - assert isinstance(path, Multi) - - def test_iteration(self): - """ - Iterating over a MultiPath should yield instances of the - parent class. - """ - inputs = '/foo/bar', '/baz/bing' - input = os.pathsep.join(inputs) - path = Multi.for_class(Path).detect(input) - - items = iter(path) - first = next(items) - assert first == '/foo/bar' - assert isinstance(first, Path) - assert not isinstance(first, Multi) - assert next(items) == '/baz/bing' - assert path == input - - -@pytest.mark.xfail('path.PY2', reason="Python 2 has no __future__") -def test_no_dependencies(): - """ - Path.py guarantees that the path module can be - transplanted into an environment without any dependencies. - """ - cmd = [ - sys.executable, - '-S', - '-c', 'import path', - ] - subprocess.check_call(cmd) - - -def test_version(): - """ - Under normal circumstances, path should present a - __version__. - """ - assert re.match(r'\d+\.\d+.*', path.__version__) diff --git a/libs/win/typing_extensions.py b/libs/win/typing_extensions.py new file mode 100644 index 00000000..ef42417c --- /dev/null +++ b/libs/win/typing_extensions.py @@ -0,0 +1,2209 @@ +import abc +import collections +import collections.abc +import functools +import operator +import sys +import types as _types +import typing + + +__all__ = [ + # Super-special typing primitives. + 'Any', + 'ClassVar', + 'Concatenate', + 'Final', + 'LiteralString', + 'ParamSpec', + 'ParamSpecArgs', + 'ParamSpecKwargs', + 'Self', + 'Type', + 'TypeVar', + 'TypeVarTuple', + 'Unpack', + + # ABCs (from collections.abc). + 'Awaitable', + 'AsyncIterator', + 'AsyncIterable', + 'Coroutine', + 'AsyncGenerator', + 'AsyncContextManager', + 'ChainMap', + + # Concrete collection types. + 'ContextManager', + 'Counter', + 'Deque', + 'DefaultDict', + 'NamedTuple', + 'OrderedDict', + 'TypedDict', + + # Structural checks, a.k.a. protocols. + 'SupportsIndex', + + # One-off things. 
+ 'Annotated', + 'assert_never', + 'assert_type', + 'clear_overloads', + 'dataclass_transform', + 'get_overloads', + 'final', + 'get_args', + 'get_origin', + 'get_type_hints', + 'IntVar', + 'is_typeddict', + 'Literal', + 'NewType', + 'overload', + 'override', + 'Protocol', + 'reveal_type', + 'runtime', + 'runtime_checkable', + 'Text', + 'TypeAlias', + 'TypeGuard', + 'TYPE_CHECKING', + 'Never', + 'NoReturn', + 'Required', + 'NotRequired', +] + +# for backward compatibility +PEP_560 = True +GenericMeta = type + +# The functions below are modified copies of typing internal helpers. +# They are needed by _ProtocolMeta and they provide support for PEP 646. + +_marker = object() + + +def _check_generic(cls, parameters, elen=_marker): + """Check correct count for parameters of a generic cls (internal helper). + This gives a nice error message in case of count mismatch. + """ + if not elen: + raise TypeError(f"{cls} is not a generic class") + if elen is _marker: + if not hasattr(cls, "__parameters__") or not cls.__parameters__: + raise TypeError(f"{cls} is not a generic class") + elen = len(cls.__parameters__) + alen = len(parameters) + if alen != elen: + if hasattr(cls, "__parameters__"): + parameters = [p for p in cls.__parameters__ if not _is_unpack(p)] + num_tv_tuples = sum(isinstance(p, TypeVarTuple) for p in parameters) + if (num_tv_tuples > 0) and (alen >= elen - num_tv_tuples): + return + raise TypeError(f"Too {'many' if alen > elen else 'few'} parameters for {cls};" + f" actual {alen}, expected {elen}") + + +if sys.version_info >= (3, 10): + def _should_collect_from_parameters(t): + return isinstance( + t, (typing._GenericAlias, _types.GenericAlias, _types.UnionType) + ) +elif sys.version_info >= (3, 9): + def _should_collect_from_parameters(t): + return isinstance(t, (typing._GenericAlias, _types.GenericAlias)) +else: + def _should_collect_from_parameters(t): + return isinstance(t, typing._GenericAlias) and not t._special + + +def _collect_type_vars(types, typevar_types=None): + """Collect all type variable contained in types in order of + first appearance (lexicographic order). For example:: + + _collect_type_vars((T, List[S, T])) == (T, S) + """ + if typevar_types is None: + typevar_types = typing.TypeVar + tvars = [] + for t in types: + if ( + isinstance(t, typevar_types) and + t not in tvars and + not _is_unpack(t) + ): + tvars.append(t) + if _should_collect_from_parameters(t): + tvars.extend([t for t in t.__parameters__ if t not in tvars]) + return tuple(tvars) + + +NoReturn = typing.NoReturn + +# Some unconstrained type variables. These are used by the container types. +# (These are not for export.) +T = typing.TypeVar('T') # Any type. +KT = typing.TypeVar('KT') # Key type. +VT = typing.TypeVar('VT') # Value type. +T_co = typing.TypeVar('T_co', covariant=True) # Any type covariant containers. +T_contra = typing.TypeVar('T_contra', contravariant=True) # Ditto contravariant. + + +if sys.version_info >= (3, 11): + from typing import Any +else: + + class _AnyMeta(type): + def __instancecheck__(self, obj): + if self is Any: + raise TypeError("typing_extensions.Any cannot be used with isinstance()") + return super().__instancecheck__(obj) + + def __repr__(self): + if self is Any: + return "typing_extensions.Any" + return super().__repr__() + + class Any(metaclass=_AnyMeta): + """Special type indicating an unconstrained type. + - Any is compatible with every type. + - Any assumed to have all methods. + - All values assumed to be instances of Any. 
+ Note that all the above statements are true from the point of view of + static type checkers. At runtime, Any should not be used with instance + checks. + """ + def __new__(cls, *args, **kwargs): + if cls is Any: + raise TypeError("Any cannot be instantiated") + return super().__new__(cls, *args, **kwargs) + + +ClassVar = typing.ClassVar + +# On older versions of typing there is an internal class named "Final". +# 3.8+ +if hasattr(typing, 'Final') and sys.version_info[:2] >= (3, 7): + Final = typing.Final +# 3.7 +else: + class _FinalForm(typing._SpecialForm, _root=True): + + def __repr__(self): + return 'typing_extensions.' + self._name + + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type.') + return typing._GenericAlias(self, (item,)) + + Final = _FinalForm('Final', + doc="""A special typing construct to indicate that a name + cannot be re-assigned or overridden in a subclass. + For example: + + MAX_SIZE: Final = 9000 + MAX_SIZE += 1 # Error reported by type checker + + class Connection: + TIMEOUT: Final[int] = 10 + class FastConnector(Connection): + TIMEOUT = 1 # Error reported by type checker + + There is no runtime checking of these properties.""") + +if sys.version_info >= (3, 11): + final = typing.final +else: + # @final exists in 3.8+, but we backport it for all versions + # before 3.11 to keep support for the __final__ attribute. + # See https://bugs.python.org/issue46342 + def final(f): + """This decorator can be used to indicate to type checkers that + the decorated method cannot be overridden, and decorated class + cannot be subclassed. For example: + + class Base: + @final + def done(self) -> None: + ... + class Sub(Base): + def done(self) -> None: # Error reported by type checker + ... + @final + class Leaf: + ... + class Other(Leaf): # Error reported by type checker + ... + + There is no runtime checking of these properties. The decorator + sets the ``__final__`` attribute to ``True`` on the decorated object + to allow runtime introspection. + """ + try: + f.__final__ = True + except (AttributeError, TypeError): + # Skip the attribute silently if it is not writable. + # AttributeError happens if the object has __slots__ or a + # read-only property, TypeError if it's a builtin class. + pass + return f + + +def IntVar(name): + return typing.TypeVar(name) + + +# 3.8+: +if hasattr(typing, 'Literal'): + Literal = typing.Literal +# 3.7: +else: + class _LiteralForm(typing._SpecialForm, _root=True): + + def __repr__(self): + return 'typing_extensions.' + self._name + + def __getitem__(self, parameters): + return typing._GenericAlias(self, parameters) + + Literal = _LiteralForm('Literal', + doc="""A type that can be used to indicate to type checkers + that the corresponding value has a value literally equivalent + to the provided parameter. For example: + + var: Literal[4] = 4 + + The type checker understands that 'var' is literally equal to + the value 4 and no other value. + + Literal[...] cannot be subclassed. 
There is no runtime + checking verifying that the parameter is actually a value + instead of a type.""") + + +_overload_dummy = typing._overload_dummy # noqa + + +if hasattr(typing, "get_overloads"): # 3.11+ + overload = typing.overload + get_overloads = typing.get_overloads + clear_overloads = typing.clear_overloads +else: + # {module: {qualname: {firstlineno: func}}} + _overload_registry = collections.defaultdict( + functools.partial(collections.defaultdict, dict) + ) + + def overload(func): + """Decorator for overloaded functions/methods. + + In a stub file, place two or more stub definitions for the same + function in a row, each decorated with @overload. For example: + + @overload + def utf8(value: None) -> None: ... + @overload + def utf8(value: bytes) -> bytes: ... + @overload + def utf8(value: str) -> bytes: ... + + In a non-stub file (i.e. a regular .py file), do the same but + follow it with an implementation. The implementation should *not* + be decorated with @overload. For example: + + @overload + def utf8(value: None) -> None: ... + @overload + def utf8(value: bytes) -> bytes: ... + @overload + def utf8(value: str) -> bytes: ... + def utf8(value): + # implementation goes here + + The overloads for a function can be retrieved at runtime using the + get_overloads() function. + """ + # classmethod and staticmethod + f = getattr(func, "__func__", func) + try: + _overload_registry[f.__module__][f.__qualname__][ + f.__code__.co_firstlineno + ] = func + except AttributeError: + # Not a normal function; ignore. + pass + return _overload_dummy + + def get_overloads(func): + """Return all defined overloads for *func* as a sequence.""" + # classmethod and staticmethod + f = getattr(func, "__func__", func) + if f.__module__ not in _overload_registry: + return [] + mod_dict = _overload_registry[f.__module__] + if f.__qualname__ not in mod_dict: + return [] + return list(mod_dict[f.__qualname__].values()) + + def clear_overloads(): + """Clear all overloads in the registry.""" + _overload_registry.clear() + + +# This is not a real generic class. Don't use outside annotations. +Type = typing.Type + +# Various ABCs mimicking those in collections.abc. +# A few are simply re-exported for completeness. 
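+# Editor's sketch (illustrative names, not part of the module): the
+# re-exports below subscript exactly like their typing counterparts, e.g.
+#
+#     def most_common(q: Deque[str]) -> Counter[str]:
+#         return collections.Counter(q)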
+ + +Awaitable = typing.Awaitable +Coroutine = typing.Coroutine +AsyncIterable = typing.AsyncIterable +AsyncIterator = typing.AsyncIterator +Deque = typing.Deque +ContextManager = typing.ContextManager +AsyncContextManager = typing.AsyncContextManager +DefaultDict = typing.DefaultDict + +# 3.7.2+ +if hasattr(typing, 'OrderedDict'): + OrderedDict = typing.OrderedDict +# 3.7.0-3.7.2 +else: + OrderedDict = typing._alias(collections.OrderedDict, (KT, VT)) + +Counter = typing.Counter +ChainMap = typing.ChainMap +AsyncGenerator = typing.AsyncGenerator +NewType = typing.NewType +Text = typing.Text +TYPE_CHECKING = typing.TYPE_CHECKING + + +_PROTO_WHITELIST = ['Callable', 'Awaitable', + 'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator', + 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', + 'ContextManager', 'AsyncContextManager'] + + +def _get_protocol_attrs(cls): + attrs = set() + for base in cls.__mro__[:-1]: # without object + if base.__name__ in ('Protocol', 'Generic'): + continue + annotations = getattr(base, '__annotations__', {}) + for attr in list(base.__dict__.keys()) + list(annotations.keys()): + if (not attr.startswith('_abc_') and attr not in ( + '__abstractmethods__', '__annotations__', '__weakref__', + '_is_protocol', '_is_runtime_protocol', '__dict__', + '__args__', '__slots__', + '__next_in_mro__', '__parameters__', '__origin__', + '__orig_bases__', '__extra__', '__tree_hash__', + '__doc__', '__subclasshook__', '__init__', '__new__', + '__module__', '_MutableMapping__marker', '_gorg')): + attrs.add(attr) + return attrs + + +def _is_callable_members_only(cls): + return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls)) + + +def _maybe_adjust_parameters(cls): + """Helper function used in Protocol.__init_subclass__ and _TypedDictMeta.__new__. + + The contents of this function are very similar + to logic found in typing.Generic.__init_subclass__ + on the CPython main branch. + """ + tvars = [] + if '__orig_bases__' in cls.__dict__: + tvars = typing._collect_type_vars(cls.__orig_bases__) + # Look for Generic[T1, ..., Tn] or Protocol[T1, ..., Tn]. + # If found, tvars must be a subset of it. + # If not found, tvars is it. + # Also check for and reject plain Generic, + # and reject multiple Generic[...] and/or Protocol[...]. + gvars = None + for base in cls.__orig_bases__: + if (isinstance(base, typing._GenericAlias) and + base.__origin__ in (typing.Generic, Protocol)): + # for error messages + the_base = base.__origin__.__name__ + if gvars is not None: + raise TypeError( + "Cannot inherit from Generic[...]" + " and/or Protocol[...] multiple types.") + gvars = base.__parameters__ + if gvars is None: + gvars = tvars + else: + tvarset = set(tvars) + gvarset = set(gvars) + if not tvarset <= gvarset: + s_vars = ', '.join(str(t) for t in tvars if t not in gvarset) + s_args = ', '.join(str(g) for g in gvars) + raise TypeError(f"Some type variables ({s_vars}) are" + f" not listed in {the_base}[{s_args}]") + tvars = gvars + cls.__parameters__ = tuple(tvars) + + +# 3.8+ +if hasattr(typing, 'Protocol'): + Protocol = typing.Protocol +# 3.7 +else: + + def _no_init(self, *args, **kwargs): + if type(self)._is_protocol: + raise TypeError('Protocols cannot be instantiated') + + class _ProtocolMeta(abc.ABCMeta): # noqa: B024 + # This metaclass is a bit unfortunate and exists only because of the lack + # of __instancehook__. + def __instancecheck__(cls, instance): + # We need this method for situations where attributes are + # assigned in __init__. 
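+            # Editor's note: two paths can succeed here: a genuine
+            # subclass relationship (checked first), or, for runtime
+            # protocols, a purely structural hasattr() scan over the
+            # protocol's members (checked second).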
+ if ((not getattr(cls, '_is_protocol', False) or + _is_callable_members_only(cls)) and + issubclass(instance.__class__, cls)): + return True + if cls._is_protocol: + if all(hasattr(instance, attr) and + (not callable(getattr(cls, attr, None)) or + getattr(instance, attr) is not None) + for attr in _get_protocol_attrs(cls)): + return True + return super().__instancecheck__(instance) + + class Protocol(metaclass=_ProtocolMeta): + # There is quite a lot of overlapping code with typing.Generic. + # Unfortunately it is hard to avoid this while these live in two different + # modules. The duplicated code will be removed when Protocol is moved to typing. + """Base class for protocol classes. Protocol classes are defined as:: + + class Proto(Protocol): + def meth(self) -> int: + ... + + Such classes are primarily used with static type checkers that recognize + structural subtyping (static duck-typing), for example:: + + class C: + def meth(self) -> int: + return 0 + + def func(x: Proto) -> int: + return x.meth() + + func(C()) # Passes static type check + + See PEP 544 for details. Protocol classes decorated with + @typing_extensions.runtime act as simple-minded runtime protocol that checks + only the presence of given attributes, ignoring their type signatures. + + Protocol classes can be generic, they are defined as:: + + class GenProto(Protocol[T]): + def meth(self) -> T: + ... + """ + __slots__ = () + _is_protocol = True + + def __new__(cls, *args, **kwds): + if cls is Protocol: + raise TypeError("Type Protocol cannot be instantiated; " + "it can only be used as a base class") + return super().__new__(cls) + + @typing._tp_cache + def __class_getitem__(cls, params): + if not isinstance(params, tuple): + params = (params,) + if not params and cls is not typing.Tuple: + raise TypeError( + f"Parameter list to {cls.__qualname__}[...] cannot be empty") + msg = "Parameters to generic types must be types." + params = tuple(typing._type_check(p, msg) for p in params) # noqa + if cls is Protocol: + # Generic can only be subscripted with unique type variables. + if not all(isinstance(p, typing.TypeVar) for p in params): + i = 0 + while isinstance(params[i], typing.TypeVar): + i += 1 + raise TypeError( + "Parameters to Protocol[...] must all be type variables." + f" Parameter {i + 1} is {params[i]}") + if len(set(params)) != len(params): + raise TypeError( + "Parameters to Protocol[...] must all be unique") + else: + # Subscripting a regular Generic subclass. + _check_generic(cls, params, len(cls.__parameters__)) + return typing._GenericAlias(cls, params) + + def __init_subclass__(cls, *args, **kwargs): + if '__orig_bases__' in cls.__dict__: + error = typing.Generic in cls.__orig_bases__ + else: + error = typing.Generic in cls.__bases__ + if error: + raise TypeError("Cannot inherit from plain Generic") + _maybe_adjust_parameters(cls) + + # Determine if this is a protocol or a concrete subclass. + if not cls.__dict__.get('_is_protocol', None): + cls._is_protocol = any(b is Protocol for b in cls.__bases__) + + # Set (or override) the protocol subclass hook. 
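+            # Editor's note: _proto_hook backs issubclass() for protocols:
+            # it requires @runtime_checkable, rejects protocols with
+            # non-method members, and otherwise scans other.__mro__ for
+            # every protocol attribute, returning NotImplemented when a
+            # member cannot be found.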
+ def _proto_hook(other): + if not cls.__dict__.get('_is_protocol', None): + return NotImplemented + if not getattr(cls, '_is_runtime_protocol', False): + if sys._getframe(2).f_globals['__name__'] in ['abc', 'functools']: + return NotImplemented + raise TypeError("Instance and class checks can only be used with" + " @runtime protocols") + if not _is_callable_members_only(cls): + if sys._getframe(2).f_globals['__name__'] in ['abc', 'functools']: + return NotImplemented + raise TypeError("Protocols with non-method members" + " don't support issubclass()") + if not isinstance(other, type): + # Same error as for issubclass(1, int) + raise TypeError('issubclass() arg 1 must be a class') + for attr in _get_protocol_attrs(cls): + for base in other.__mro__: + if attr in base.__dict__: + if base.__dict__[attr] is None: + return NotImplemented + break + annotations = getattr(base, '__annotations__', {}) + if (isinstance(annotations, typing.Mapping) and + attr in annotations and + isinstance(other, _ProtocolMeta) and + other._is_protocol): + break + else: + return NotImplemented + return True + if '__subclasshook__' not in cls.__dict__: + cls.__subclasshook__ = _proto_hook + + # We have nothing more to do for non-protocols. + if not cls._is_protocol: + return + + # Check consistency of bases. + for base in cls.__bases__: + if not (base in (object, typing.Generic) or + base.__module__ == 'collections.abc' and + base.__name__ in _PROTO_WHITELIST or + isinstance(base, _ProtocolMeta) and base._is_protocol): + raise TypeError('Protocols can only inherit from other' + f' protocols, got {repr(base)}') + cls.__init__ = _no_init + + +# 3.8+ +if hasattr(typing, 'runtime_checkable'): + runtime_checkable = typing.runtime_checkable +# 3.7 +else: + def runtime_checkable(cls): + """Mark a protocol class as a runtime protocol, so that it + can be used with isinstance() and issubclass(). Raise TypeError + if applied to a non-protocol class. + + This allows a simple-minded structural check very similar to the + one-offs in collections.abc such as Hashable. + """ + if not isinstance(cls, _ProtocolMeta) or not cls._is_protocol: + raise TypeError('@runtime_checkable can be only applied to protocol classes,' + f' got {cls!r}') + cls._is_runtime_protocol = True + return cls + + +# Exists for backwards compatibility. +runtime = runtime_checkable + + +# 3.8+ +if hasattr(typing, 'SupportsIndex'): + SupportsIndex = typing.SupportsIndex +# 3.7 +else: + @runtime_checkable + class SupportsIndex(Protocol): + __slots__ = () + + @abc.abstractmethod + def __index__(self) -> int: + pass + + +if hasattr(typing, "Required"): + # The standard library TypedDict in Python 3.8 does not store runtime information + # about which (if any) keys are optional. See https://bugs.python.org/issue38834 + # The standard library TypedDict in Python 3.9.0/1 does not honour the "total" + # keyword with old-style TypedDict(). See https://bugs.python.org/issue42059 + # The standard library TypedDict below Python 3.11 does not store runtime + # information about optional and required keys when using Required or NotRequired. + # Generic TypedDicts are also impossible using typing.TypedDict on Python <3.11. + TypedDict = typing.TypedDict + _TypedDictMeta = typing._TypedDictMeta + is_typeddict = typing.is_typeddict +else: + def _check_fails(cls, other): + try: + if sys._getframe(1).f_globals['__name__'] not in ['abc', + 'functools', + 'typing']: + # Typed dicts are only for static structural subtyping. 
+ raise TypeError('TypedDict does not support instance and class checks') + except (AttributeError, ValueError): + pass + return False + + def _dict_new(*args, **kwargs): + if not args: + raise TypeError('TypedDict.__new__(): not enough arguments') + _, args = args[0], args[1:] # allow the "cls" keyword be passed + return dict(*args, **kwargs) + + _dict_new.__text_signature__ = '($cls, _typename, _fields=None, /, **kwargs)' + + def _typeddict_new(*args, total=True, **kwargs): + if not args: + raise TypeError('TypedDict.__new__(): not enough arguments') + _, args = args[0], args[1:] # allow the "cls" keyword be passed + if args: + typename, args = args[0], args[1:] # allow the "_typename" keyword be passed + elif '_typename' in kwargs: + typename = kwargs.pop('_typename') + import warnings + warnings.warn("Passing '_typename' as keyword argument is deprecated", + DeprecationWarning, stacklevel=2) + else: + raise TypeError("TypedDict.__new__() missing 1 required positional " + "argument: '_typename'") + if args: + try: + fields, = args # allow the "_fields" keyword be passed + except ValueError: + raise TypeError('TypedDict.__new__() takes from 2 to 3 ' + f'positional arguments but {len(args) + 2} ' + 'were given') + elif '_fields' in kwargs and len(kwargs) == 1: + fields = kwargs.pop('_fields') + import warnings + warnings.warn("Passing '_fields' as keyword argument is deprecated", + DeprecationWarning, stacklevel=2) + else: + fields = None + + if fields is None: + fields = kwargs + elif kwargs: + raise TypeError("TypedDict takes either a dict or keyword arguments," + " but not both") + + ns = {'__annotations__': dict(fields)} + try: + # Setting correct module is necessary to make typed dict classes pickleable. + ns['__module__'] = sys._getframe(1).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + pass + + return _TypedDictMeta(typename, (), ns, total=total) + + _typeddict_new.__text_signature__ = ('($cls, _typename, _fields=None,' + ' /, *, total=True, **kwargs)') + + class _TypedDictMeta(type): + def __init__(cls, name, bases, ns, total=True): + super().__init__(name, bases, ns) + + def __new__(cls, name, bases, ns, total=True): + # Create new typed dict class object. + # This method is called directly when TypedDict is subclassed, + # or via _typeddict_new when TypedDict is instantiated. This way + # TypedDict supports all three syntaxes described in its docstring. + # Subclasses and instances of TypedDict return actual dictionaries + # via _dict_new. + ns['__new__'] = _typeddict_new if name == 'TypedDict' else _dict_new + # Don't insert typing.Generic into __bases__ here, + # or Generic.__init_subclass__ will raise TypeError + # in the super().__new__() call. + # Instead, monkey-patch __bases__ onto the class after it's been created. 
+ tp_dict = super().__new__(cls, name, (dict,), ns) + + if any(issubclass(base, typing.Generic) for base in bases): + tp_dict.__bases__ = (typing.Generic, dict) + _maybe_adjust_parameters(tp_dict) + + annotations = {} + own_annotations = ns.get('__annotations__', {}) + msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type" + own_annotations = { + n: typing._type_check(tp, msg) for n, tp in own_annotations.items() + } + required_keys = set() + optional_keys = set() + + for base in bases: + annotations.update(base.__dict__.get('__annotations__', {})) + required_keys.update(base.__dict__.get('__required_keys__', ())) + optional_keys.update(base.__dict__.get('__optional_keys__', ())) + + annotations.update(own_annotations) + for annotation_key, annotation_type in own_annotations.items(): + annotation_origin = get_origin(annotation_type) + if annotation_origin is Annotated: + annotation_args = get_args(annotation_type) + if annotation_args: + annotation_type = annotation_args[0] + annotation_origin = get_origin(annotation_type) + + if annotation_origin is Required: + required_keys.add(annotation_key) + elif annotation_origin is NotRequired: + optional_keys.add(annotation_key) + elif total: + required_keys.add(annotation_key) + else: + optional_keys.add(annotation_key) + + tp_dict.__annotations__ = annotations + tp_dict.__required_keys__ = frozenset(required_keys) + tp_dict.__optional_keys__ = frozenset(optional_keys) + if not hasattr(tp_dict, '__total__'): + tp_dict.__total__ = total + return tp_dict + + __instancecheck__ = __subclasscheck__ = _check_fails + + TypedDict = _TypedDictMeta('TypedDict', (dict,), {}) + TypedDict.__module__ = __name__ + TypedDict.__doc__ = \ + """A simple typed name space. At runtime it is equivalent to a plain dict. + + TypedDict creates a dictionary type that expects all of its + instances to have a certain set of keys, with each key + associated with a value of a consistent type. This expectation + is not checked at runtime but is only enforced by type checkers. + Usage:: + + class Point2D(TypedDict): + x: int + y: int + label: str + + a: Point2D = {'x': 1, 'y': 2, 'label': 'good'} # OK + b: Point2D = {'z': 3, 'label': 'bad'} # Fails type check + + assert Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first') + + The type info can be accessed via the Point2D.__annotations__ dict, and + the Point2D.__required_keys__ and Point2D.__optional_keys__ frozensets. + TypedDict supports two additional equivalent forms:: + + Point2D = TypedDict('Point2D', x=int, y=int, label=str) + Point2D = TypedDict('Point2D', {'x': int, 'y': int, 'label': str}) + + The class syntax is only supported in Python 3.6+, while two other + syntax forms work for Python 2.7 and 3.2+ + """ + + if hasattr(typing, "_TypedDictMeta"): + _TYPEDDICT_TYPES = (typing._TypedDictMeta, _TypedDictMeta) + else: + _TYPEDDICT_TYPES = (_TypedDictMeta,) + + def is_typeddict(tp): + """Check if an annotation is a TypedDict class + + For example:: + class Film(TypedDict): + title: str + year: int + + is_typeddict(Film) # => True + is_typeddict(Union[list, str]) # => False + """ + return isinstance(tp, tuple(_TYPEDDICT_TYPES)) + + +if hasattr(typing, "assert_type"): + assert_type = typing.assert_type + +else: + def assert_type(__val, __typ): + """Assert (to the type checker) that the value is of the given type. 
+ + When the type checker encounters a call to assert_type(), it + emits an error if the value is not of the specified type:: + + def greet(name: str) -> None: + assert_type(name, str) # ok + assert_type(name, int) # type checker error + + At runtime this returns the first argument unchanged and otherwise + does nothing. + """ + return __val + + +if hasattr(typing, "Required"): + get_type_hints = typing.get_type_hints +else: + import functools + import types + + # replaces _strip_annotations() + def _strip_extras(t): + """Strips Annotated, Required and NotRequired from a given type.""" + if isinstance(t, _AnnotatedAlias): + return _strip_extras(t.__origin__) + if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired): + return _strip_extras(t.__args__[0]) + if isinstance(t, typing._GenericAlias): + stripped_args = tuple(_strip_extras(a) for a in t.__args__) + if stripped_args == t.__args__: + return t + return t.copy_with(stripped_args) + if hasattr(types, "GenericAlias") and isinstance(t, types.GenericAlias): + stripped_args = tuple(_strip_extras(a) for a in t.__args__) + if stripped_args == t.__args__: + return t + return types.GenericAlias(t.__origin__, stripped_args) + if hasattr(types, "UnionType") and isinstance(t, types.UnionType): + stripped_args = tuple(_strip_extras(a) for a in t.__args__) + if stripped_args == t.__args__: + return t + return functools.reduce(operator.or_, stripped_args) + + return t + + def get_type_hints(obj, globalns=None, localns=None, include_extras=False): + """Return type hints for an object. + + This is often the same as obj.__annotations__, but it handles + forward references encoded as string literals, adds Optional[t] if a + default value equal to None is set and recursively replaces all + 'Annotated[T, ...]', 'Required[T]' or 'NotRequired[T]' with 'T' + (unless 'include_extras=True'). + + The argument may be a module, class, method, or function. The annotations + are returned as a dictionary. For classes, annotations include also + inherited members. + + TypeError is raised if the argument is not of a type that can contain + annotations, and an empty dictionary is returned if no annotations are + present. + + BEWARE -- the behavior of globalns and localns is counterintuitive + (unless you are familiar with how eval() and exec() work). The + search order is locals first, then globals. + + - If no dict arguments are passed, an attempt is made to use the + globals from obj (or the respective module's globals for classes), + and these are also used as the locals. If the object does not appear + to have globals, an empty dictionary is used. + + - If one dict argument is passed, it is used for both globals and + locals. + + - If two dict arguments are passed, they specify globals and + locals, respectively. + """ + if hasattr(typing, "Annotated"): + hint = typing.get_type_hints( + obj, globalns=globalns, localns=localns, include_extras=True + ) + else: + hint = typing.get_type_hints(obj, globalns=globalns, localns=localns) + if include_extras: + return hint + return {k: _strip_extras(t) for k, t in hint.items()} + + +# Python 3.9+ has PEP 593 (Annotated) +if hasattr(typing, 'Annotated'): + Annotated = typing.Annotated + # Not exported and not a public API, but needed for get_origin() and get_args() + # to work. + _AnnotatedAlias = typing._AnnotatedAlias +# 3.7-3.8 +else: + class _AnnotatedAlias(typing._GenericAlias, _root=True): + """Runtime representation of an annotated type. 
+ + At its core 'Annotated[t, dec1, dec2, ...]' is an alias for the type 't' + with extra annotations. The alias behaves like a normal typing alias, + instantiating is the same as instantiating the underlying type, binding + it to types is also the same. + """ + def __init__(self, origin, metadata): + if isinstance(origin, _AnnotatedAlias): + metadata = origin.__metadata__ + metadata + origin = origin.__origin__ + super().__init__(origin, origin) + self.__metadata__ = metadata + + def copy_with(self, params): + assert len(params) == 1 + new_type = params[0] + return _AnnotatedAlias(new_type, self.__metadata__) + + def __repr__(self): + return (f"typing_extensions.Annotated[{typing._type_repr(self.__origin__)}, " + f"{', '.join(repr(a) for a in self.__metadata__)}]") + + def __reduce__(self): + return operator.getitem, ( + Annotated, (self.__origin__,) + self.__metadata__ + ) + + def __eq__(self, other): + if not isinstance(other, _AnnotatedAlias): + return NotImplemented + if self.__origin__ != other.__origin__: + return False + return self.__metadata__ == other.__metadata__ + + def __hash__(self): + return hash((self.__origin__, self.__metadata__)) + + class Annotated: + """Add context specific metadata to a type. + + Example: Annotated[int, runtime_check.Unsigned] indicates to the + hypothetical runtime_check module that this type is an unsigned int. + Every other consumer of this type can ignore this metadata and treat + this type as int. + + The first argument to Annotated must be a valid type (and will be in + the __origin__ field), the remaining arguments are kept as a tuple in + the __extra__ field. + + Details: + + - It's an error to call `Annotated` with less than two arguments. + - Nested Annotated are flattened:: + + Annotated[Annotated[T, Ann1, Ann2], Ann3] == Annotated[T, Ann1, Ann2, Ann3] + + - Instantiating an annotated type is equivalent to instantiating the + underlying type:: + + Annotated[C, Ann1](5) == C(5) + + - Annotated can be used as a generic type alias:: + + Optimized = Annotated[T, runtime.Optimize()] + Optimized[int] == Annotated[int, runtime.Optimize()] + + OptimizedList = Annotated[List[T], runtime.Optimize()] + OptimizedList[int] == Annotated[List[int], runtime.Optimize()] + """ + + __slots__ = () + + def __new__(cls, *args, **kwargs): + raise TypeError("Type Annotated cannot be instantiated.") + + @typing._tp_cache + def __class_getitem__(cls, params): + if not isinstance(params, tuple) or len(params) < 2: + raise TypeError("Annotated[...] should be used " + "with at least two arguments (a type and an " + "annotation).") + allowed_special_forms = (ClassVar, Final) + if get_origin(params[0]) in allowed_special_forms: + origin = params[0] + else: + msg = "Annotated[t, ...]: t must be a type." + origin = typing._type_check(params[0], msg) + metadata = tuple(params[1:]) + return _AnnotatedAlias(origin, metadata) + + def __init_subclass__(cls, *args, **kwargs): + raise TypeError( + f"Cannot subclass {cls.__module__}.Annotated" + ) + +# Python 3.8 has get_origin() and get_args() but those implementations aren't +# Annotated-aware, so we can't use those. Python 3.9's versions don't support +# ParamSpecArgs and ParamSpecKwargs, so only Python 3.10's versions will do. 
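+# Editor's sketch of the behaviour both branches below must provide
+# (illustrative, not executed):
+#
+#     get_origin(Annotated[int, "meta"]) is Annotated
+#     get_args(Annotated[int, "meta"]) == (int, "meta")
+#     get_origin(typing.List[int]) is list
+#     get_args(typing.Dict[str, int]) == (str, int)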
+if sys.version_info[:2] >= (3, 10): + get_origin = typing.get_origin + get_args = typing.get_args +# 3.7-3.9 +else: + try: + # 3.9+ + from typing import _BaseGenericAlias + except ImportError: + _BaseGenericAlias = typing._GenericAlias + try: + # 3.9+ + from typing import GenericAlias as _typing_GenericAlias + except ImportError: + _typing_GenericAlias = typing._GenericAlias + + def get_origin(tp): + """Get the unsubscripted version of a type. + + This supports generic types, Callable, Tuple, Union, Literal, Final, ClassVar + and Annotated. Return None for unsupported types. Examples:: + + get_origin(Literal[42]) is Literal + get_origin(int) is None + get_origin(ClassVar[int]) is ClassVar + get_origin(Generic) is Generic + get_origin(Generic[T]) is Generic + get_origin(Union[T, int]) is Union + get_origin(List[Tuple[T, T]][int]) == list + get_origin(P.args) is P + """ + if isinstance(tp, _AnnotatedAlias): + return Annotated + if isinstance(tp, (typing._GenericAlias, _typing_GenericAlias, _BaseGenericAlias, + ParamSpecArgs, ParamSpecKwargs)): + return tp.__origin__ + if tp is typing.Generic: + return typing.Generic + return None + + def get_args(tp): + """Get type arguments with all substitutions performed. + + For unions, basic simplifications used by Union constructor are performed. + Examples:: + get_args(Dict[str, int]) == (str, int) + get_args(int) == () + get_args(Union[int, Union[T, int], str][int]) == (int, str) + get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int]) + get_args(Callable[[], T][int]) == ([], int) + """ + if isinstance(tp, _AnnotatedAlias): + return (tp.__origin__,) + tp.__metadata__ + if isinstance(tp, (typing._GenericAlias, _typing_GenericAlias)): + if getattr(tp, "_special", False): + return () + res = tp.__args__ + if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis: + res = (list(res[:-1]), res[-1]) + return res + return () + + +# 3.10+ +if hasattr(typing, 'TypeAlias'): + TypeAlias = typing.TypeAlias +# 3.9 +elif sys.version_info[:2] >= (3, 9): + class _TypeAliasForm(typing._SpecialForm, _root=True): + def __repr__(self): + return 'typing_extensions.' + self._name + + @_TypeAliasForm + def TypeAlias(self, parameters): + """Special marker indicating that an assignment should + be recognized as a proper type alias definition by type + checkers. + + For example:: + + Predicate: TypeAlias = Callable[..., bool] + + It's invalid when used anywhere except as in the example above. + """ + raise TypeError(f"{self} is not subscriptable") +# 3.7-3.8 +else: + class _TypeAliasForm(typing._SpecialForm, _root=True): + def __repr__(self): + return 'typing_extensions.' + self._name + + TypeAlias = _TypeAliasForm('TypeAlias', + doc="""Special marker indicating that an assignment should + be recognized as a proper type alias definition by type + checkers. 
+ + For example:: + + Predicate: TypeAlias = Callable[..., bool] + + It's invalid when used anywhere except as in the example + above.""") + + +class _DefaultMixin: + """Mixin for TypeVarLike defaults.""" + + __slots__ = () + + def __init__(self, default): + if isinstance(default, (tuple, list)): + self.__default__ = tuple((typing._type_check(d, "Default must be a type") + for d in default)) + elif default: + self.__default__ = typing._type_check(default, "Default must be a type") + else: + self.__default__ = None + + +# Add default and infer_variance parameters from PEP 696 and 695 +class TypeVar(typing.TypeVar, _DefaultMixin, _root=True): + """Type variable.""" + + __module__ = 'typing' + + def __init__(self, name, *constraints, bound=None, + covariant=False, contravariant=False, + default=None, infer_variance=False): + super().__init__(name, *constraints, bound=bound, covariant=covariant, + contravariant=contravariant) + _DefaultMixin.__init__(self, default) + self.__infer_variance__ = infer_variance + + # for pickling: + try: + def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + def_mod = None + if def_mod != 'typing_extensions': + self.__module__ = def_mod + + +# Python 3.10+ has PEP 612 +if hasattr(typing, 'ParamSpecArgs'): + ParamSpecArgs = typing.ParamSpecArgs + ParamSpecKwargs = typing.ParamSpecKwargs +# 3.7-3.9 +else: + class _Immutable: + """Mixin to indicate that object should not be copied.""" + __slots__ = () + + def __copy__(self): + return self + + def __deepcopy__(self, memo): + return self + + class ParamSpecArgs(_Immutable): + """The args for a ParamSpec object. + + Given a ParamSpec object P, P.args is an instance of ParamSpecArgs. + + ParamSpecArgs objects have a reference back to their ParamSpec: + + P.args.__origin__ is P + + This type is meant for runtime introspection and has no special meaning to + static type checkers. + """ + def __init__(self, origin): + self.__origin__ = origin + + def __repr__(self): + return f"{self.__origin__.__name__}.args" + + def __eq__(self, other): + if not isinstance(other, ParamSpecArgs): + return NotImplemented + return self.__origin__ == other.__origin__ + + class ParamSpecKwargs(_Immutable): + """The kwargs for a ParamSpec object. + + Given a ParamSpec object P, P.kwargs is an instance of ParamSpecKwargs. + + ParamSpecKwargs objects have a reference back to their ParamSpec: + + P.kwargs.__origin__ is P + + This type is meant for runtime introspection and has no special meaning to + static type checkers. 
+        """
+        def __init__(self, origin):
+            self.__origin__ = origin
+
+        def __repr__(self):
+            return f"{self.__origin__.__name__}.kwargs"
+
+        def __eq__(self, other):
+            if not isinstance(other, ParamSpecKwargs):
+                return NotImplemented
+            return self.__origin__ == other.__origin__
+
+# 3.10+
+if hasattr(typing, 'ParamSpec'):
+
+    # Add default Parameter - PEP 696
+    class ParamSpec(typing.ParamSpec, _DefaultMixin, _root=True):
+        """Parameter specification variable."""
+
+        __module__ = 'typing'
+
+        def __init__(self, name, *, bound=None, covariant=False, contravariant=False,
+                     default=None):
+            super().__init__(name, bound=bound, covariant=covariant,
+                             contravariant=contravariant)
+            _DefaultMixin.__init__(self, default)
+
+            # for pickling:
+            try:
+                def_mod = sys._getframe(1).f_globals.get('__name__', '__main__')
+            except (AttributeError, ValueError):
+                def_mod = None
+            if def_mod != 'typing_extensions':
+                self.__module__ = def_mod
+
+# 3.7-3.9
+else:
+
+    # Inherits from list as a workaround for Callable checks in Python < 3.9.2.
+    class ParamSpec(list, _DefaultMixin):
+        """Parameter specification variable.
+
+        Usage::
+
+           P = ParamSpec('P')
+
+        Parameter specification variables exist primarily for the benefit of static
+        type checkers. They are used to forward the parameter types of one
+        callable to another callable, a pattern commonly found in higher order
+        functions and decorators. They are only valid when used in ``Concatenate``,
+        or as the first argument to ``Callable``. In Python 3.10 and higher,
+        they are also supported in user-defined Generics at runtime.
+        See class Generic for more information on generic types. An
+        example for annotating a decorator::
+
+           T = TypeVar('T')
+           P = ParamSpec('P')
+
+           def add_logging(f: Callable[P, T]) -> Callable[P, T]:
+               '''A type-safe decorator to add logging to a function.'''
+               def inner(*args: P.args, **kwargs: P.kwargs) -> T:
+                   logging.info(f'{f.__name__} was called')
+                   return f(*args, **kwargs)
+               return inner
+
+           @add_logging
+           def add_two(x: float, y: float) -> float:
+               '''Add two numbers together.'''
+               return x + y
+
+        Parameter specification variables defined with covariant=True or
+        contravariant=True can be used to declare covariant or contravariant
+        generic types. These keyword arguments are valid, but their actual semantics
+        are yet to be decided. See PEP 612 for details.
+
+        Parameter specification variables can be introspected. e.g.:
+
+           P.__name__ == 'P'
+           P.__bound__ == None
+           P.__covariant__ == False
+           P.__contravariant__ == False
+
+        Note that only parameter specification variables defined in global scope can
+        be pickled.
+        """
+
+        # Trick Generic __parameters__.
+ __class__ = typing.TypeVar + + @property + def args(self): + return ParamSpecArgs(self) + + @property + def kwargs(self): + return ParamSpecKwargs(self) + + def __init__(self, name, *, bound=None, covariant=False, contravariant=False, + default=None): + super().__init__([self]) + self.__name__ = name + self.__covariant__ = bool(covariant) + self.__contravariant__ = bool(contravariant) + if bound: + self.__bound__ = typing._type_check(bound, 'Bound must be a type.') + else: + self.__bound__ = None + _DefaultMixin.__init__(self, default) + + # for pickling: + try: + def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + def_mod = None + if def_mod != 'typing_extensions': + self.__module__ = def_mod + + def __repr__(self): + if self.__covariant__: + prefix = '+' + elif self.__contravariant__: + prefix = '-' + else: + prefix = '~' + return prefix + self.__name__ + + def __hash__(self): + return object.__hash__(self) + + def __eq__(self, other): + return self is other + + def __reduce__(self): + return self.__name__ + + # Hack to get typing._type_check to pass. + def __call__(self, *args, **kwargs): + pass + + +# 3.7-3.9 +if not hasattr(typing, 'Concatenate'): + # Inherits from list as a workaround for Callable checks in Python < 3.9.2. + class _ConcatenateGenericAlias(list): + + # Trick Generic into looking into this for __parameters__. + __class__ = typing._GenericAlias + + # Flag in 3.8. + _special = False + + def __init__(self, origin, args): + super().__init__(args) + self.__origin__ = origin + self.__args__ = args + + def __repr__(self): + _type_repr = typing._type_repr + return (f'{_type_repr(self.__origin__)}' + f'[{", ".join(_type_repr(arg) for arg in self.__args__)}]') + + def __hash__(self): + return hash((self.__origin__, self.__args__)) + + # Hack to get typing._type_check to pass in Generic. + def __call__(self, *args, **kwargs): + pass + + @property + def __parameters__(self): + return tuple( + tp for tp in self.__args__ if isinstance(tp, (typing.TypeVar, ParamSpec)) + ) + + +# 3.7-3.9 +@typing._tp_cache +def _concatenate_getitem(self, parameters): + if parameters == (): + raise TypeError("Cannot take a Concatenate of no types.") + if not isinstance(parameters, tuple): + parameters = (parameters,) + if not isinstance(parameters[-1], ParamSpec): + raise TypeError("The last parameter to Concatenate should be a " + "ParamSpec variable.") + msg = "Concatenate[arg, ...]: each arg must be a type." + parameters = tuple(typing._type_check(p, msg) for p in parameters) + return _ConcatenateGenericAlias(self, parameters) + + +# 3.10+ +if hasattr(typing, 'Concatenate'): + Concatenate = typing.Concatenate + _ConcatenateGenericAlias = typing._ConcatenateGenericAlias # noqa +# 3.9 +elif sys.version_info[:2] >= (3, 9): + @_TypeAliasForm + def Concatenate(self, parameters): + """Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a + higher order function which adds, removes or transforms parameters of a + callable. + + For example:: + + Callable[Concatenate[int, P], int] + + See PEP 612 for detailed information. + """ + return _concatenate_getitem(self, parameters) +# 3.7-8 +else: + class _ConcatenateForm(typing._SpecialForm, _root=True): + def __repr__(self): + return 'typing_extensions.' 
+ self._name + + def __getitem__(self, parameters): + return _concatenate_getitem(self, parameters) + + Concatenate = _ConcatenateForm( + 'Concatenate', + doc="""Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a + higher order function which adds, removes or transforms parameters of a + callable. + + For example:: + + Callable[Concatenate[int, P], int] + + See PEP 612 for detailed information. + """) + +# 3.10+ +if hasattr(typing, 'TypeGuard'): + TypeGuard = typing.TypeGuard +# 3.9 +elif sys.version_info[:2] >= (3, 9): + class _TypeGuardForm(typing._SpecialForm, _root=True): + def __repr__(self): + return 'typing_extensions.' + self._name + + @_TypeGuardForm + def TypeGuard(self, parameters): + """Special typing form used to annotate the return type of a user-defined + type guard function. ``TypeGuard`` only accepts a single type argument. + At runtime, functions marked this way should return a boolean. + + ``TypeGuard`` aims to benefit *type narrowing* -- a technique used by static + type checkers to determine a more precise type of an expression within a + program's code flow. Usually type narrowing is done by analyzing + conditional code flow and applying the narrowing to a block of code. The + conditional expression here is sometimes referred to as a "type guard". + + Sometimes it would be convenient to use a user-defined boolean function + as a type guard. Such a function should use ``TypeGuard[...]`` as its + return type to alert static type checkers to this intention. + + Using ``-> TypeGuard`` tells the static type checker that for a given + function: + + 1. The return value is a boolean. + 2. If the return value is ``True``, the type of its argument + is the type inside ``TypeGuard``. + + For example:: + + def is_str(val: Union[str, float]): + # "isinstance" type guard + if isinstance(val, str): + # Type of ``val`` is narrowed to ``str`` + ... + else: + # Else, type of ``val`` is narrowed to ``float``. + ... + + Strict type narrowing is not enforced -- ``TypeB`` need not be a narrower + form of ``TypeA`` (it can even be a wider form) and this may lead to + type-unsafe results. The main reason is to allow for things like + narrowing ``List[object]`` to ``List[str]`` even though the latter is not + a subtype of the former, since ``List`` is invariant. The responsibility of + writing type-safe type guards is left to the user. + + ``TypeGuard`` also works with type variables. For more information, see + PEP 647 (User-Defined Type Guards). + """ + item = typing._type_check(parameters, f'{self} accepts only a single type.') + return typing._GenericAlias(self, (item,)) +# 3.7-3.8 +else: + class _TypeGuardForm(typing._SpecialForm, _root=True): + + def __repr__(self): + return 'typing_extensions.' + self._name + + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type') + return typing._GenericAlias(self, (item,)) + + TypeGuard = _TypeGuardForm( + 'TypeGuard', + doc="""Special typing form used to annotate the return type of a user-defined + type guard function. ``TypeGuard`` only accepts a single type argument. + At runtime, functions marked this way should return a boolean. + + ``TypeGuard`` aims to benefit *type narrowing* -- a technique used by static + type checkers to determine a more precise type of an expression within a + program's code flow. Usually type narrowing is done by analyzing + conditional code flow and applying the narrowing to a block of code. 
The
+        conditional expression here is sometimes referred to as a "type guard".
+
+        Sometimes it would be convenient to use a user-defined boolean function
+        as a type guard. Such a function should use ``TypeGuard[...]`` as its
+        return type to alert static type checkers to this intention.
+
+        Using ``-> TypeGuard`` tells the static type checker that for a given
+        function:
+
+        1. The return value is a boolean.
+        2. If the return value is ``True``, the type of its argument
+           is the type inside ``TypeGuard``.
+
+        For example::
+
+            def is_str(val: Union[str, float]):
+                # "isinstance" type guard
+                if isinstance(val, str):
+                    # Type of ``val`` is narrowed to ``str``
+                    ...
+                else:
+                    # Else, type of ``val`` is narrowed to ``float``.
+                    ...
+
+        Strict type narrowing is not enforced -- ``TypeB`` need not be a narrower
+        form of ``TypeA`` (it can even be a wider form) and this may lead to
+        type-unsafe results. The main reason is to allow for things like
+        narrowing ``List[object]`` to ``List[str]`` even though the latter is not
+        a subtype of the former, since ``List`` is invariant. The responsibility of
+        writing type-safe type guards is left to the user.
+
+        ``TypeGuard`` also works with type variables. For more information, see
+        PEP 647 (User-Defined Type Guards).
+        """)
+
+
+# Vendored from cpython typing._SpecialForm
+class _SpecialForm(typing._Final, _root=True):
+    __slots__ = ('_name', '__doc__', '_getitem')
+
+    def __init__(self, getitem):
+        self._getitem = getitem
+        self._name = getitem.__name__
+        self.__doc__ = getitem.__doc__
+
+    def __getattr__(self, item):
+        if item in {'__name__', '__qualname__'}:
+            return self._name
+
+        raise AttributeError(item)
+
+    def __mro_entries__(self, bases):
+        raise TypeError(f"Cannot subclass {self!r}")
+
+    def __repr__(self):
+        return f'typing_extensions.{self._name}'
+
+    def __reduce__(self):
+        return self._name
+
+    def __call__(self, *args, **kwds):
+        raise TypeError(f"Cannot instantiate {self!r}")
+
+    def __or__(self, other):
+        return typing.Union[self, other]
+
+    def __ror__(self, other):
+        return typing.Union[other, self]
+
+    def __instancecheck__(self, obj):
+        raise TypeError(f"{self} cannot be used with isinstance()")
+
+    def __subclasscheck__(self, cls):
+        raise TypeError(f"{self} cannot be used with issubclass()")
+
+    @typing._tp_cache
+    def __getitem__(self, parameters):
+        return self._getitem(self, parameters)
+
+
+if hasattr(typing, "LiteralString"):
+    LiteralString = typing.LiteralString
+else:
+    @_SpecialForm
+    def LiteralString(self, params):
+        """Represents an arbitrary literal string.
+
+        Example::
+
+          from typing_extensions import LiteralString
+
+          def query(sql: LiteralString) -> ...:
+              ...
+
+          query("SELECT * FROM table")  # ok
+          query(f"SELECT * FROM {input()}")  # not ok
+
+        See PEP 675 for details.
+
+        """
+        raise TypeError(f"{self} is not subscriptable")
+
+
+if hasattr(typing, "Self"):
+    Self = typing.Self
+else:
+    @_SpecialForm
+    def Self(self, params):
+        """Used to spell the type of "self" in classes.
+
+        Example::
+
+          from typing_extensions import Self
+
+          class ReturnsSelf:
+              def parse(self, data: bytes) -> Self:
+                  ...
+                  return self
+
+        """
+
+        raise TypeError(f"{self} is not subscriptable")
+
+
+if hasattr(typing, "Never"):
+    Never = typing.Never
+else:
+    @_SpecialForm
+    def Never(self, params):
+        """The bottom type, a type that has no members.
+ + This can be used to define a function that should never be + called, or a function that never returns:: + + from typing_extensions import Never + + def never_call_me(arg: Never) -> None: + pass + + def int_or_str(arg: int | str) -> None: + never_call_me(arg) # type checker error + match arg: + case int(): + print("It's an int") + case str(): + print("It's a str") + case _: + never_call_me(arg) # ok, arg is of type Never + + """ + + raise TypeError(f"{self} is not subscriptable") + + +if hasattr(typing, 'Required'): + Required = typing.Required + NotRequired = typing.NotRequired +elif sys.version_info[:2] >= (3, 9): + class _ExtensionsSpecialForm(typing._SpecialForm, _root=True): + def __repr__(self): + return 'typing_extensions.' + self._name + + @_ExtensionsSpecialForm + def Required(self, parameters): + """A special typing construct to mark a key of a total=False TypedDict + as required. For example: + + class Movie(TypedDict, total=False): + title: Required[str] + year: int + + m = Movie( + title='The Matrix', # typechecker error if key is omitted + year=1999, + ) + + There is no runtime checking that a required key is actually provided + when instantiating a related TypedDict. + """ + item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + return typing._GenericAlias(self, (item,)) + + @_ExtensionsSpecialForm + def NotRequired(self, parameters): + """A special typing construct to mark a key of a TypedDict as + potentially missing. For example: + + class Movie(TypedDict): + title: str + year: NotRequired[int] + + m = Movie( + title='The Matrix', # typechecker error if key is omitted + year=1999, + ) + """ + item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + return typing._GenericAlias(self, (item,)) + +else: + class _RequiredForm(typing._SpecialForm, _root=True): + def __repr__(self): + return 'typing_extensions.' + self._name + + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type.') + return typing._GenericAlias(self, (item,)) + + Required = _RequiredForm( + 'Required', + doc="""A special typing construct to mark a key of a total=False TypedDict + as required. For example: + + class Movie(TypedDict, total=False): + title: Required[str] + year: int + + m = Movie( + title='The Matrix', # typechecker error if key is omitted + year=1999, + ) + + There is no runtime checking that a required key is actually provided + when instantiating a related TypedDict. + """) + NotRequired = _RequiredForm( + 'NotRequired', + doc="""A special typing construct to mark a key of a TypedDict as + potentially missing. For example: + + class Movie(TypedDict): + title: str + year: NotRequired[int] + + m = Movie( + title='The Matrix', # typechecker error if key is omitted + year=1999, + ) + """) + + +if hasattr(typing, "Unpack"): # 3.11+ + Unpack = typing.Unpack +elif sys.version_info[:2] >= (3, 9): + class _UnpackSpecialForm(typing._SpecialForm, _root=True): + def __repr__(self): + return 'typing_extensions.' + self._name + + class _UnpackAlias(typing._GenericAlias, _root=True): + __class__ = typing.TypeVar + + @_UnpackSpecialForm + def Unpack(self, parameters): + """A special typing construct to unpack a variadic type. For example: + + Shape = TypeVarTuple('Shape') + Batch = NewType('Batch', int) + + def add_batch_axis( + x: Array[Unpack[Shape]] + ) -> Array[Batch, Unpack[Shape]]: ... 
+ + """ + item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + return _UnpackAlias(self, (item,)) + + def _is_unpack(obj): + return isinstance(obj, _UnpackAlias) + +else: + class _UnpackAlias(typing._GenericAlias, _root=True): + __class__ = typing.TypeVar + + class _UnpackForm(typing._SpecialForm, _root=True): + def __repr__(self): + return 'typing_extensions.' + self._name + + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type.') + return _UnpackAlias(self, (item,)) + + Unpack = _UnpackForm( + 'Unpack', + doc="""A special typing construct to unpack a variadic type. For example: + + Shape = TypeVarTuple('Shape') + Batch = NewType('Batch', int) + + def add_batch_axis( + x: Array[Unpack[Shape]] + ) -> Array[Batch, Unpack[Shape]]: ... + + """) + + def _is_unpack(obj): + return isinstance(obj, _UnpackAlias) + + +if hasattr(typing, "TypeVarTuple"): # 3.11+ + + # Add default Parameter - PEP 696 + class TypeVarTuple(typing.TypeVarTuple, _DefaultMixin, _root=True): + """Type variable tuple.""" + + def __init__(self, name, *, default=None): + super().__init__(name) + _DefaultMixin.__init__(self, default) + + # for pickling: + try: + def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + def_mod = None + if def_mod != 'typing_extensions': + self.__module__ = def_mod + +else: + class TypeVarTuple(_DefaultMixin): + """Type variable tuple. + + Usage:: + + Ts = TypeVarTuple('Ts') + + In the same way that a normal type variable is a stand-in for a single + type such as ``int``, a type variable *tuple* is a stand-in for a *tuple* + type such as ``Tuple[int, str]``. + + Type variable tuples can be used in ``Generic`` declarations. + Consider the following example:: + + class Array(Generic[*Ts]): ... + + The ``Ts`` type variable tuple here behaves like ``tuple[T1, T2]``, + where ``T1`` and ``T2`` are type variables. To use these type variables + as type parameters of ``Array``, we must *unpack* the type variable tuple using + the star operator: ``*Ts``. The signature of ``Array`` then behaves + as if we had simply written ``class Array(Generic[T1, T2]): ...``. + In contrast to ``Generic[T1, T2]``, however, ``Generic[*Shape]`` allows + us to parameterise the class with an *arbitrary* number of type parameters. + + Type variable tuples can be used anywhere a normal ``TypeVar`` can. + This includes class definitions, as shown above, as well as function + signatures and variable annotations:: + + class Array(Generic[*Ts]): + + def __init__(self, shape: Tuple[*Ts]): + self._shape: Tuple[*Ts] = shape + + def get_shape(self) -> Tuple[*Ts]: + return self._shape + + shape = (Height(480), Width(640)) + x: Array[Height, Width] = Array(shape) + y = abs(x) # Inferred type is Array[Height, Width] + z = x + x # ... is Array[Height, Width] + x.get_shape() # ... is tuple[Height, Width] + + """ + + # Trick Generic __parameters__. 
+ __class__ = typing.TypeVar + + def __iter__(self): + yield self.__unpacked__ + + def __init__(self, name, *, default=None): + self.__name__ = name + _DefaultMixin.__init__(self, default) + + # for pickling: + try: + def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + def_mod = None + if def_mod != 'typing_extensions': + self.__module__ = def_mod + + self.__unpacked__ = Unpack[self] + + def __repr__(self): + return self.__name__ + + def __hash__(self): + return object.__hash__(self) + + def __eq__(self, other): + return self is other + + def __reduce__(self): + return self.__name__ + + def __init_subclass__(self, *args, **kwds): + if '_root' not in kwds: + raise TypeError("Cannot subclass special typing classes") + + +if hasattr(typing, "reveal_type"): + reveal_type = typing.reveal_type +else: + def reveal_type(__obj: T) -> T: + """Reveal the inferred type of a variable. + + When a static type checker encounters a call to ``reveal_type()``, + it will emit the inferred type of the argument:: + + x: int = 1 + reveal_type(x) + + Running a static type checker (e.g., ``mypy``) on this example + will produce output similar to 'Revealed type is "builtins.int"'. + + At runtime, the function prints the runtime type of the + argument and returns it unchanged. + + """ + print(f"Runtime type is {type(__obj).__name__!r}", file=sys.stderr) + return __obj + + +if hasattr(typing, "assert_never"): + assert_never = typing.assert_never +else: + def assert_never(__arg: Never) -> Never: + """Assert to the type checker that a line of code is unreachable. + + Example:: + + def int_or_str(arg: int | str) -> None: + match arg: + case int(): + print("It's an int") + case str(): + print("It's a str") + case _: + assert_never(arg) + + If a type checker finds that a call to assert_never() is + reachable, it will emit an error. + + At runtime, this throws an exception when called. + + """ + raise AssertionError("Expected code to be unreachable") + + +if hasattr(typing, 'dataclass_transform'): + dataclass_transform = typing.dataclass_transform +else: + def dataclass_transform( + *, + eq_default: bool = True, + order_default: bool = False, + kw_only_default: bool = False, + field_specifiers: typing.Tuple[ + typing.Union[typing.Type[typing.Any], typing.Callable[..., typing.Any]], + ... + ] = (), + **kwargs: typing.Any, + ) -> typing.Callable[[T], T]: + """Decorator that marks a function, class, or metaclass as providing + dataclass-like behavior. + + Example: + + from typing_extensions import dataclass_transform + + _T = TypeVar("_T") + + # Used on a decorator function + @dataclass_transform() + def create_model(cls: type[_T]) -> type[_T]: + ... + return cls + + @create_model + class CustomerModel: + id: int + name: str + + # Used on a base class + @dataclass_transform() + class ModelBase: ... + + class CustomerModel(ModelBase): + id: int + name: str + + # Used on a metaclass + @dataclass_transform() + class ModelMeta(type): ... + + class ModelBase(metaclass=ModelMeta): ... + + class CustomerModel(ModelBase): + id: int + name: str + + Each of the ``CustomerModel`` classes defined in this example will now + behave similarly to a dataclass created with the ``@dataclasses.dataclass`` + decorator. For example, the type checker will synthesize an ``__init__`` + method. + + The arguments to this decorator can be used to customize this behavior: + - ``eq_default`` indicates whether the ``eq`` parameter is assumed to be + True or False if it is omitted by the caller. 
+        - ``order_default`` indicates whether the ``order`` parameter is
+          assumed to be True or False if it is omitted by the caller.
+        - ``kw_only_default`` indicates whether the ``kw_only`` parameter is
+          assumed to be True or False if it is omitted by the caller.
+        - ``field_specifiers`` specifies a static list of supported classes
+          or functions that describe fields, similar to ``dataclasses.field()``.
+
+        At runtime, this decorator records its arguments in the
+        ``__dataclass_transform__`` attribute on the decorated object.
+
+        See PEP 681 for details.
+
+        """
+        def decorator(cls_or_fn):
+            cls_or_fn.__dataclass_transform__ = {
+                "eq_default": eq_default,
+                "order_default": order_default,
+                "kw_only_default": kw_only_default,
+                "field_specifiers": field_specifiers,
+                "kwargs": kwargs,
+            }
+            return cls_or_fn
+        return decorator
+
+
+if hasattr(typing, "override"):
+    override = typing.override
+else:
+    _F = typing.TypeVar("_F", bound=typing.Callable[..., typing.Any])
+
+    def override(__arg: _F) -> _F:
+        """Indicate that a method is intended to override a method in a base class.
+
+        Usage:
+
+            class Base:
+                def method(self) -> None:
+                    pass
+
+            class Child(Base):
+                @override
+                def method(self) -> None:
+                    super().method()
+
+        When this decorator is applied to a method, the type checker will
+        validate that it overrides a method with the same name on a base class.
+        This helps prevent bugs that may occur when a base class is changed
+        without an equivalent change to a child class.
+
+        See PEP 698 for details.
+
+        """
+        return __arg
+
+
+# We have to do some monkey patching to deal with the dual nature of
+# Unpack/TypeVarTuple:
+# - We want Unpack to be a kind of TypeVar so it gets accepted in
+#   Generic[Unpack[Ts]]
+# - We want it to *not* be treated as a TypeVar for the purposes of
+#   counting generic parameters, so that when we subscript a generic,
+#   the runtime doesn't try to substitute the Unpack with the subscripted type.
+if not hasattr(typing, "TypeVarTuple"):
+    typing._collect_type_vars = _collect_type_vars
+    typing._check_generic = _check_generic
+
+
+# Backport typing.NamedTuple as it exists in Python 3.11.
+# In 3.11, the ability to define generic `NamedTuple`s was supported.
+# This was explicitly disallowed in 3.9-3.10, and only half-worked in <=3.8.
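+# For example, with the backport below a generic NamedTuple is intended to
+# behave on the older versions this file supports as it does natively on
+# 3.11 (a hedged sketch; ``Pair`` and ``T`` are illustrative names, not part
+# of this diff):
+#
+#     from typing import Generic, TypeVar
+#     from typing_extensions import NamedTuple
+#
+#     T = TypeVar("T")
+#
+#     class Pair(NamedTuple, Generic[T]):
+#         first: T
+#         second: T
+#
+#     assert Pair(1, 2).first == 1
+#     Pair[int]  # subscription works via the attached __class_getitem__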
+if sys.version_info >= (3, 11): + NamedTuple = typing.NamedTuple +else: + def _caller(): + try: + return sys._getframe(2).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): # For platforms without _getframe() + return None + + def _make_nmtuple(name, types, module, defaults=()): + fields = [n for n, t in types] + annotations = {n: typing._type_check(t, f"field {n} annotation must be a type") + for n, t in types} + nm_tpl = collections.namedtuple(name, fields, + defaults=defaults, module=module) + nm_tpl.__annotations__ = nm_tpl.__new__.__annotations__ = annotations + # The `_field_types` attribute was removed in 3.9; + # in earlier versions, it is the same as the `__annotations__` attribute + if sys.version_info < (3, 9): + nm_tpl._field_types = annotations + return nm_tpl + + _prohibited_namedtuple_fields = typing._prohibited + _special_namedtuple_fields = frozenset({'__module__', '__name__', '__annotations__'}) + + class _NamedTupleMeta(type): + def __new__(cls, typename, bases, ns): + assert _NamedTuple in bases + for base in bases: + if base is not _NamedTuple and base is not typing.Generic: + raise TypeError( + 'can only inherit from a NamedTuple type and Generic') + bases = tuple(tuple if base is _NamedTuple else base for base in bases) + types = ns.get('__annotations__', {}) + default_names = [] + for field_name in types: + if field_name in ns: + default_names.append(field_name) + elif default_names: + raise TypeError(f"Non-default namedtuple field {field_name} " + f"cannot follow default field" + f"{'s' if len(default_names) > 1 else ''} " + f"{', '.join(default_names)}") + nm_tpl = _make_nmtuple( + typename, types.items(), + defaults=[ns[n] for n in default_names], + module=ns['__module__'] + ) + nm_tpl.__bases__ = bases + if typing.Generic in bases: + class_getitem = typing.Generic.__class_getitem__.__func__ + nm_tpl.__class_getitem__ = classmethod(class_getitem) + # update from user namespace without overriding special namedtuple attributes + for key in ns: + if key in _prohibited_namedtuple_fields: + raise AttributeError("Cannot overwrite NamedTuple attribute " + key) + elif key not in _special_namedtuple_fields and key not in nm_tpl._fields: + setattr(nm_tpl, key, ns[key]) + if typing.Generic in bases: + nm_tpl.__init_subclass__() + return nm_tpl + + def NamedTuple(__typename, __fields=None, **kwargs): + if __fields is None: + __fields = kwargs.items() + elif kwargs: + raise TypeError("Either list of fields or keywords" + " can be provided to NamedTuple, not both") + return _make_nmtuple(__typename, __fields, module=_caller()) + + NamedTuple.__doc__ = typing.NamedTuple.__doc__ + _NamedTuple = type.__new__(_NamedTupleMeta, 'NamedTuple', (), {}) + + # On 3.8+, alter the signature so that it matches typing.NamedTuple. + # The signature of typing.NamedTuple on >=3.8 is invalid syntax in Python 3.7, + # so just leave the signature as it is on 3.7. 
+ if sys.version_info >= (3, 8): + NamedTuple.__text_signature__ = '(typename, fields=None, /, **kwargs)' + + def _namedtuple_mro_entries(bases): + assert NamedTuple in bases + return (_NamedTuple,) + + NamedTuple.__mro_entries__ = _namedtuple_mro_entries diff --git a/libs/win/zipp/__init__.py b/libs/win/zipp/__init__.py new file mode 100644 index 00000000..ad01e27e --- /dev/null +++ b/libs/win/zipp/__init__.py @@ -0,0 +1,381 @@ +import io +import posixpath +import zipfile +import itertools +import contextlib +import pathlib +import re +import fnmatch + +from .py310compat import text_encoding + + +__all__ = ['Path'] + + +def _parents(path): + """ + Given a path with elements separated by + posixpath.sep, generate all parents of that path. + + >>> list(_parents('b/d')) + ['b'] + >>> list(_parents('/b/d/')) + ['/b'] + >>> list(_parents('b/d/f/')) + ['b/d', 'b'] + >>> list(_parents('b')) + [] + >>> list(_parents('')) + [] + """ + return itertools.islice(_ancestry(path), 1, None) + + +def _ancestry(path): + """ + Given a path with elements separated by + posixpath.sep, generate all elements of that path + + >>> list(_ancestry('b/d')) + ['b/d', 'b'] + >>> list(_ancestry('/b/d/')) + ['/b/d', '/b'] + >>> list(_ancestry('b/d/f/')) + ['b/d/f', 'b/d', 'b'] + >>> list(_ancestry('b')) + ['b'] + >>> list(_ancestry('')) + [] + """ + path = path.rstrip(posixpath.sep) + while path and path != posixpath.sep: + yield path + path, tail = posixpath.split(path) + + +_dedupe = dict.fromkeys +"""Deduplicate an iterable in original order""" + + +def _difference(minuend, subtrahend): + """ + Return items in minuend not in subtrahend, retaining order + with O(1) lookup. + """ + return itertools.filterfalse(set(subtrahend).__contains__, minuend) + + +class InitializedState: + """ + Mix-in to save the initialization state for pickling. + """ + + def __init__(self, *args, **kwargs): + self.__args = args + self.__kwargs = kwargs + super().__init__(*args, **kwargs) + + def __getstate__(self): + return self.__args, self.__kwargs + + def __setstate__(self, state): + args, kwargs = state + super().__init__(*args, **kwargs) + + +class CompleteDirs(InitializedState, zipfile.ZipFile): + """ + A ZipFile subclass that ensures that implied directories + are always included in the namelist. + """ + + @staticmethod + def _implied_dirs(names): + parents = itertools.chain.from_iterable(map(_parents, names)) + as_dirs = (p + posixpath.sep for p in parents) + return _dedupe(_difference(as_dirs, names)) + + def namelist(self): + names = super(CompleteDirs, self).namelist() + return names + list(self._implied_dirs(names)) + + def _name_set(self): + return set(self.namelist()) + + def resolve_dir(self, name): + """ + If the name represents a directory, return that name + as a directory (with the trailing slash). + """ + names = self._name_set() + dirname = name + '/' + dir_match = name not in names and dirname in names + return dirname if dir_match else name + + @classmethod + def make(cls, source): + """ + Given a source (filename or zipfile), return an + appropriate CompleteDirs subclass. + """ + if isinstance(source, CompleteDirs): + return source + + if not isinstance(source, zipfile.ZipFile): + return cls(source) + + # Only allow for FastLookup when supplied zipfile is read-only + if 'r' not in source.mode: + cls = CompleteDirs + + source.__class__ = cls + return source + + +class FastLookup(CompleteDirs): + """ + ZipFile subclass to ensure implicit + dirs exist and are resolved rapidly. 
+ """ + + def namelist(self): + with contextlib.suppress(AttributeError): + return self.__names + self.__names = super(FastLookup, self).namelist() + return self.__names + + def _name_set(self): + with contextlib.suppress(AttributeError): + return self.__lookup + self.__lookup = super(FastLookup, self)._name_set() + return self.__lookup + + +class Path: + """ + A pathlib-compatible interface for zip files. + + Consider a zip file with this structure:: + + . + ├── a.txt + └── b + ├── c.txt + └── d + └── e.txt + + >>> data = io.BytesIO() + >>> zf = zipfile.ZipFile(data, 'w') + >>> zf.writestr('a.txt', 'content of a') + >>> zf.writestr('b/c.txt', 'content of c') + >>> zf.writestr('b/d/e.txt', 'content of e') + >>> zf.filename = 'mem/abcde.zip' + + Path accepts the zipfile object itself or a filename + + >>> root = Path(zf) + + From there, several path operations are available. + + Directory iteration (including the zip file itself): + + >>> a, b = root.iterdir() + >>> a + Path('mem/abcde.zip', 'a.txt') + >>> b + Path('mem/abcde.zip', 'b/') + + name property: + + >>> b.name + 'b' + + join with divide operator: + + >>> c = b / 'c.txt' + >>> c + Path('mem/abcde.zip', 'b/c.txt') + >>> c.name + 'c.txt' + + Read text: + + >>> c.read_text() + 'content of c' + + existence: + + >>> c.exists() + True + >>> (b / 'missing.txt').exists() + False + + Coercion to string: + + >>> import os + >>> str(c).replace(os.sep, posixpath.sep) + 'mem/abcde.zip/b/c.txt' + + At the root, ``name``, ``filename``, and ``parent`` + resolve to the zipfile. Note these attributes are not + valid and will raise a ``ValueError`` if the zipfile + has no filename. + + >>> root.name + 'abcde.zip' + >>> str(root.filename).replace(os.sep, posixpath.sep) + 'mem/abcde.zip' + >>> str(root.parent) + 'mem' + """ + + __repr = "{self.__class__.__name__}({self.root.filename!r}, {self.at!r})" + + def __init__(self, root, at=""): + """ + Construct a Path from a ZipFile or filename. + + Note: When the source is an existing ZipFile object, + its type (__class__) will be mutated to a + specialized type. If the caller wishes to retain the + original type, the caller should either create a + separate ZipFile object or pass a filename. + """ + self.root = FastLookup.make(root) + self.at = at + + def __eq__(self, other): + """ + >>> Path(zipfile.ZipFile(io.BytesIO(), 'w')) == 'foo' + False + """ + if self.__class__ is not other.__class__: + return NotImplemented + return (self.root, self.at) == (other.root, other.at) + + def __hash__(self): + return hash((self.root, self.at)) + + def open(self, mode='r', *args, pwd=None, **kwargs): + """ + Open this entry as text or binary following the semantics + of ``pathlib.Path.open()`` by passing arguments through + to io.TextIOWrapper(). 
+ """ + if self.is_dir(): + raise IsADirectoryError(self) + zip_mode = mode[0] + if not self.exists() and zip_mode == 'r': + raise FileNotFoundError(self) + stream = self.root.open(self.at, zip_mode, pwd=pwd) + if 'b' in mode: + if args or kwargs: + raise ValueError("encoding args invalid for binary operation") + return stream + else: + kwargs["encoding"] = text_encoding(kwargs.get("encoding")) + return io.TextIOWrapper(stream, *args, **kwargs) + + @property + def name(self): + return pathlib.Path(self.at).name or self.filename.name + + @property + def suffix(self): + return pathlib.Path(self.at).suffix or self.filename.suffix + + @property + def suffixes(self): + return pathlib.Path(self.at).suffixes or self.filename.suffixes + + @property + def stem(self): + return pathlib.Path(self.at).stem or self.filename.stem + + @property + def filename(self): + return pathlib.Path(self.root.filename).joinpath(self.at) + + def read_text(self, *args, **kwargs): + kwargs["encoding"] = text_encoding(kwargs.get("encoding")) + with self.open('r', *args, **kwargs) as strm: + return strm.read() + + def read_bytes(self): + with self.open('rb') as strm: + return strm.read() + + def _is_child(self, path): + return posixpath.dirname(path.at.rstrip("/")) == self.at.rstrip("/") + + def _next(self, at): + return self.__class__(self.root, at) + + def is_dir(self): + return not self.at or self.at.endswith("/") + + def is_file(self): + return self.exists() and not self.is_dir() + + def exists(self): + return self.at in self.root._name_set() + + def iterdir(self): + if not self.is_dir(): + raise ValueError("Can't listdir a file") + subs = map(self._next, self.root.namelist()) + return filter(self._is_child, subs) + + def match(self, path_pattern): + return pathlib.Path(self.at).match(path_pattern) + + def is_symlink(self): + """ + Return whether this path is a symlink. Always false (python/cpython#82102). + """ + return False + + def _descendants(self): + for child in self.iterdir(): + yield child + if child.is_dir(): + yield from child._descendants() + + def glob(self, pattern): + if not pattern: + raise ValueError("Unacceptable pattern: {!r}".format(pattern)) + + matches = re.compile(fnmatch.translate(pattern)).fullmatch + return ( + child + for child in self._descendants() + if matches(str(child.relative_to(self))) + ) + + def rglob(self, pattern): + return self.glob(f'**/{pattern}') + + def relative_to(self, other, *extra): + return posixpath.relpath(str(self), str(other.joinpath(*extra))) + + def __str__(self): + return posixpath.join(self.root.filename, self.at) + + def __repr__(self): + return self.__repr.format(self=self) + + def joinpath(self, *other): + next = posixpath.join(self.at, *other) + return self._next(self.root.resolve_dir(next)) + + __truediv__ = joinpath + + @property + def parent(self): + if not self.at: + return self.filename.parent + parent_at = posixpath.dirname(self.at.rstrip('/')) + if parent_at: + parent_at += '/' + return self._next(parent_at) diff --git a/libs/win/zipp/py310compat.py b/libs/win/zipp/py310compat.py new file mode 100644 index 00000000..8244124c --- /dev/null +++ b/libs/win/zipp/py310compat.py @@ -0,0 +1,12 @@ +import sys +import io + + +te_impl = 'lambda encoding, stacklevel=2, /: encoding' +te_impl_37 = te_impl.replace(', /', '') +_text_encoding = eval(te_impl) if sys.version_info > (3, 8) else eval(te_impl_37) + + +text_encoding = ( + io.text_encoding if sys.version_info > (3, 10) else _text_encoding # type: ignore +)
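+# The lambdas above are built with eval() because the positional-only marker
+# (", /") is a syntax error before Python 3.8, so it cannot appear literally
+# in this file. A rough sketch of the resulting behavior (hedged; the 3.10+
+# result comes from io.text_encoding and depends on UTF-8 mode):
+#
+#     text_encoding("utf-8")  # "utf-8" on every supported version
+#     text_encoding(None)     # None before 3.10; usually "locale" on 3.10+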