mirror of
https://github.com/Tautulli/Tautulli.git
synced 2025-07-07 13:41:15 -07:00
Bump cherrypy from 18.6.1 to 18.8.0 (#1796)
* Bump cherrypy from 18.6.1 to 18.8.0 Bumps [cherrypy](https://github.com/cherrypy/cherrypy) from 18.6.1 to 18.8.0. - [Release notes](https://github.com/cherrypy/cherrypy/releases) - [Changelog](https://github.com/cherrypy/cherrypy/blob/main/CHANGES.rst) - [Commits](https://github.com/cherrypy/cherrypy/compare/v18.6.1...v18.8.0) --- updated-dependencies: - dependency-name: cherrypy dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> * Update cherrypy==18.8.0 Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com>
This commit is contained in:
parent
e79da07973
commit
76cc56a215
75 changed files with 19150 additions and 1339 deletions
27
lib/autocommand/__init__.py
Normal file
27
lib/autocommand/__init__.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2014-2016 Nathan West
|
||||
#
|
||||
# This file is part of autocommand.
|
||||
#
|
||||
# autocommand is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# autocommand is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with autocommand. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
# flake8 flags all these imports as unused, hence the NOQAs everywhere.
|
||||
|
||||
from .automain import automain # NOQA
|
||||
from .autoparse import autoparse, smart_open # NOQA
|
||||
from .autocommand import autocommand # NOQA
|
||||
|
||||
try:
|
||||
from .autoasync import autoasync # NOQA
|
||||
except ImportError: # pragma: no cover
|
||||
pass
|
140
lib/autocommand/autoasync.py
Normal file
140
lib/autocommand/autoasync.py
Normal file
|
@ -0,0 +1,140 @@
|
|||
# Copyright 2014-2015 Nathan West
|
||||
#
|
||||
# This file is part of autocommand.
|
||||
#
|
||||
# autocommand is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# autocommand is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with autocommand. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from asyncio import get_event_loop, iscoroutine
|
||||
from functools import wraps
|
||||
from inspect import signature
|
||||
|
||||
|
||||
def _launch_forever_coro(coro, args, kwargs, loop):
|
||||
'''
|
||||
This helper function launches an async main function that was tagged with
|
||||
forever=True. There are two possibilities:
|
||||
|
||||
- The function is a normal function, which handles initializing the event
|
||||
loop, which is then run forever
|
||||
- The function is a coroutine, which needs to be scheduled in the event
|
||||
loop, which is then run forever
|
||||
- There is also the possibility that the function is a normal function
|
||||
wrapping a coroutine function
|
||||
|
||||
The function is therefore called unconditionally and scheduled in the event
|
||||
loop if the return value is a coroutine object.
|
||||
|
||||
The reason this is a separate function is to make absolutely sure that all
|
||||
the objects created are garbage collected after all is said and done; we
|
||||
do this to ensure that any exceptions raised in the tasks are collected
|
||||
ASAP.
|
||||
'''
|
||||
|
||||
# Personal note: I consider this an antipattern, as it relies on the use of
|
||||
# unowned resources. The setup function dumps some stuff into the event
|
||||
# loop where it just whirls in the ether without a well defined owner or
|
||||
# lifetime. For this reason, there's a good chance I'll remove the
|
||||
# forever=True feature from autoasync at some point in the future.
|
||||
thing = coro(*args, **kwargs)
|
||||
if iscoroutine(thing):
|
||||
loop.create_task(thing)
|
||||
|
||||
|
||||
def autoasync(coro=None, *, loop=None, forever=False, pass_loop=False):
|
||||
'''
|
||||
Convert an asyncio coroutine into a function which, when called, is
|
||||
evaluted in an event loop, and the return value returned. This is intented
|
||||
to make it easy to write entry points into asyncio coroutines, which
|
||||
otherwise need to be explictly evaluted with an event loop's
|
||||
run_until_complete.
|
||||
|
||||
If `loop` is given, it is used as the event loop to run the coro in. If it
|
||||
is None (the default), the loop is retreived using asyncio.get_event_loop.
|
||||
This call is defered until the decorated function is called, so that
|
||||
callers can install custom event loops or event loop policies after
|
||||
@autoasync is applied.
|
||||
|
||||
If `forever` is True, the loop is run forever after the decorated coroutine
|
||||
is finished. Use this for servers created with asyncio.start_server and the
|
||||
like.
|
||||
|
||||
If `pass_loop` is True, the event loop object is passed into the coroutine
|
||||
as the `loop` kwarg when the wrapper function is called. In this case, the
|
||||
wrapper function's __signature__ is updated to remove this parameter, so
|
||||
that autoparse can still be used on it without generating a parameter for
|
||||
`loop`.
|
||||
|
||||
This coroutine can be called with ( @autoasync(...) ) or without
|
||||
( @autoasync ) arguments.
|
||||
|
||||
Examples:
|
||||
|
||||
@autoasync
|
||||
def get_file(host, port):
|
||||
reader, writer = yield from asyncio.open_connection(host, port)
|
||||
data = reader.read()
|
||||
sys.stdout.write(data.decode())
|
||||
|
||||
get_file(host, port)
|
||||
|
||||
@autoasync(forever=True, pass_loop=True)
|
||||
def server(host, port, loop):
|
||||
yield_from loop.create_server(Proto, host, port)
|
||||
|
||||
server('localhost', 8899)
|
||||
|
||||
'''
|
||||
if coro is None:
|
||||
return lambda c: autoasync(
|
||||
c, loop=loop,
|
||||
forever=forever,
|
||||
pass_loop=pass_loop)
|
||||
|
||||
# The old and new signatures are required to correctly bind the loop
|
||||
# parameter in 100% of cases, even if it's a positional parameter.
|
||||
# NOTE: A future release will probably require the loop parameter to be
|
||||
# a kwonly parameter.
|
||||
if pass_loop:
|
||||
old_sig = signature(coro)
|
||||
new_sig = old_sig.replace(parameters=(
|
||||
param for name, param in old_sig.parameters.items()
|
||||
if name != "loop"))
|
||||
|
||||
@wraps(coro)
|
||||
def autoasync_wrapper(*args, **kwargs):
|
||||
# Defer the call to get_event_loop so that, if a custom policy is
|
||||
# installed after the autoasync decorator, it is respected at call time
|
||||
local_loop = get_event_loop() if loop is None else loop
|
||||
|
||||
# Inject the 'loop' argument. We have to use this signature binding to
|
||||
# ensure it's injected in the correct place (positional, keyword, etc)
|
||||
if pass_loop:
|
||||
bound_args = old_sig.bind_partial()
|
||||
bound_args.arguments.update(
|
||||
loop=local_loop,
|
||||
**new_sig.bind(*args, **kwargs).arguments)
|
||||
args, kwargs = bound_args.args, bound_args.kwargs
|
||||
|
||||
if forever:
|
||||
_launch_forever_coro(coro, args, kwargs, local_loop)
|
||||
local_loop.run_forever()
|
||||
else:
|
||||
return local_loop.run_until_complete(coro(*args, **kwargs))
|
||||
|
||||
# Attach the updated signature. This allows 'pass_loop' to be used with
|
||||
# autoparse
|
||||
if pass_loop:
|
||||
autoasync_wrapper.__signature__ = new_sig
|
||||
|
||||
return autoasync_wrapper
|
70
lib/autocommand/autocommand.py
Normal file
70
lib/autocommand/autocommand.py
Normal file
|
@ -0,0 +1,70 @@
|
|||
# Copyright 2014-2015 Nathan West
|
||||
#
|
||||
# This file is part of autocommand.
|
||||
#
|
||||
# autocommand is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# autocommand is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with autocommand. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from .autoparse import autoparse
|
||||
from .automain import automain
|
||||
try:
|
||||
from .autoasync import autoasync
|
||||
except ImportError: # pragma: no cover
|
||||
pass
|
||||
|
||||
|
||||
def autocommand(
|
||||
module, *,
|
||||
description=None,
|
||||
epilog=None,
|
||||
add_nos=False,
|
||||
parser=None,
|
||||
loop=None,
|
||||
forever=False,
|
||||
pass_loop=False):
|
||||
|
||||
if callable(module):
|
||||
raise TypeError('autocommand requires a module name argument')
|
||||
|
||||
def autocommand_decorator(func):
|
||||
# Step 1: if requested, run it all in an asyncio event loop. autoasync
|
||||
# patches the __signature__ of the decorated function, so that in the
|
||||
# event that pass_loop is True, the `loop` parameter of the original
|
||||
# function will *not* be interpreted as a command-line argument by
|
||||
# autoparse
|
||||
if loop is not None or forever or pass_loop:
|
||||
func = autoasync(
|
||||
func,
|
||||
loop=None if loop is True else loop,
|
||||
pass_loop=pass_loop,
|
||||
forever=forever)
|
||||
|
||||
# Step 2: create parser. We do this second so that the arguments are
|
||||
# parsed and passed *before* entering the asyncio event loop, if it
|
||||
# exists. This simplifies the stack trace and ensures errors are
|
||||
# reported earlier. It also ensures that errors raised during parsing &
|
||||
# passing are still raised if `forever` is True.
|
||||
func = autoparse(
|
||||
func,
|
||||
description=description,
|
||||
epilog=epilog,
|
||||
add_nos=add_nos,
|
||||
parser=parser)
|
||||
|
||||
# Step 3: call the function automatically if __name__ == '__main__' (or
|
||||
# if True was provided)
|
||||
func = automain(module)(func)
|
||||
|
||||
return func
|
||||
|
||||
return autocommand_decorator
|
59
lib/autocommand/automain.py
Normal file
59
lib/autocommand/automain.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
# Copyright 2014-2015 Nathan West
|
||||
#
|
||||
# This file is part of autocommand.
|
||||
#
|
||||
# autocommand is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# autocommand is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with autocommand. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import sys
|
||||
from .errors import AutocommandError
|
||||
|
||||
|
||||
class AutomainRequiresModuleError(AutocommandError, TypeError):
|
||||
pass
|
||||
|
||||
|
||||
def automain(module, *, args=(), kwargs=None):
|
||||
'''
|
||||
This decorator automatically invokes a function if the module is being run
|
||||
as the "__main__" module. Optionally, provide args or kwargs with which to
|
||||
call the function. If `module` is "__main__", the function is called, and
|
||||
the program is `sys.exit`ed with the return value. You can also pass `True`
|
||||
to cause the function to be called unconditionally. If the function is not
|
||||
called, it is returned unchanged by the decorator.
|
||||
|
||||
Usage:
|
||||
|
||||
@automain(__name__) # Pass __name__ to check __name__=="__main__"
|
||||
def main():
|
||||
...
|
||||
|
||||
If __name__ is "__main__" here, the main function is called, and then
|
||||
sys.exit called with the return value.
|
||||
'''
|
||||
|
||||
# Check that @automain(...) was called, rather than @automain
|
||||
if callable(module):
|
||||
raise AutomainRequiresModuleError(module)
|
||||
|
||||
if module == '__main__' or module is True:
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
# Use a function definition instead of a lambda for a neater traceback
|
||||
def automain_decorator(main):
|
||||
sys.exit(main(*args, **kwargs))
|
||||
|
||||
return automain_decorator
|
||||
else:
|
||||
return lambda main: main
|
333
lib/autocommand/autoparse.py
Normal file
333
lib/autocommand/autoparse.py
Normal file
|
@ -0,0 +1,333 @@
|
|||
# Copyright 2014-2015 Nathan West
|
||||
#
|
||||
# This file is part of autocommand.
|
||||
#
|
||||
# autocommand is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# autocommand is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with autocommand. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import sys
|
||||
from re import compile as compile_regex
|
||||
from inspect import signature, getdoc, Parameter
|
||||
from argparse import ArgumentParser
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from io import IOBase
|
||||
from autocommand.errors import AutocommandError
|
||||
|
||||
|
||||
_empty = Parameter.empty
|
||||
|
||||
|
||||
class AnnotationError(AutocommandError):
|
||||
'''Annotation error: annotation must be a string, type, or tuple of both'''
|
||||
|
||||
|
||||
class PositionalArgError(AutocommandError):
|
||||
'''
|
||||
Postional Arg Error: autocommand can't handle postional-only parameters
|
||||
'''
|
||||
|
||||
|
||||
class KWArgError(AutocommandError):
|
||||
'''kwarg Error: autocommand can't handle a **kwargs parameter'''
|
||||
|
||||
|
||||
class DocstringError(AutocommandError):
|
||||
'''Docstring error'''
|
||||
|
||||
|
||||
class TooManySplitsError(DocstringError):
|
||||
'''
|
||||
The docstring had too many ---- section splits. Currently we only support
|
||||
using up to a single split, to split the docstring into description and
|
||||
epilog parts.
|
||||
'''
|
||||
|
||||
|
||||
def _get_type_description(annotation):
|
||||
'''
|
||||
Given an annotation, return the (type, description) for the parameter.
|
||||
If you provide an annotation that is somehow both a string and a callable,
|
||||
the behavior is undefined.
|
||||
'''
|
||||
if annotation is _empty:
|
||||
return None, None
|
||||
elif callable(annotation):
|
||||
return annotation, None
|
||||
elif isinstance(annotation, str):
|
||||
return None, annotation
|
||||
elif isinstance(annotation, tuple):
|
||||
try:
|
||||
arg1, arg2 = annotation
|
||||
except ValueError as e:
|
||||
raise AnnotationError(annotation) from e
|
||||
else:
|
||||
if callable(arg1) and isinstance(arg2, str):
|
||||
return arg1, arg2
|
||||
elif isinstance(arg1, str) and callable(arg2):
|
||||
return arg2, arg1
|
||||
|
||||
raise AnnotationError(annotation)
|
||||
|
||||
|
||||
def _add_arguments(param, parser, used_char_args, add_nos):
|
||||
'''
|
||||
Add the argument(s) to an ArgumentParser (using add_argument) for a given
|
||||
parameter. used_char_args is the set of -short options currently already in
|
||||
use, and is updated (if necessary) by this function. If add_nos is True,
|
||||
this will also add an inverse switch for all boolean options. For
|
||||
instance, for the boolean parameter "verbose", this will create --verbose
|
||||
and --no-verbose.
|
||||
'''
|
||||
|
||||
# Impl note: This function is kept separate from make_parser because it's
|
||||
# already very long and I wanted to separate out as much as possible into
|
||||
# its own call scope, to prevent even the possibility of suble mutation
|
||||
# bugs.
|
||||
if param.kind is param.POSITIONAL_ONLY:
|
||||
raise PositionalArgError(param)
|
||||
elif param.kind is param.VAR_KEYWORD:
|
||||
raise KWArgError(param)
|
||||
|
||||
# These are the kwargs for the add_argument function.
|
||||
arg_spec = {}
|
||||
is_option = False
|
||||
|
||||
# Get the type and default from the annotation.
|
||||
arg_type, description = _get_type_description(param.annotation)
|
||||
|
||||
# Get the default value
|
||||
default = param.default
|
||||
|
||||
# If there is no explicit type, and the default is present and not None,
|
||||
# infer the type from the default.
|
||||
if arg_type is None and default not in {_empty, None}:
|
||||
arg_type = type(default)
|
||||
|
||||
# Add default. The presence of a default means this is an option, not an
|
||||
# argument.
|
||||
if default is not _empty:
|
||||
arg_spec['default'] = default
|
||||
is_option = True
|
||||
|
||||
# Add the type
|
||||
if arg_type is not None:
|
||||
# Special case for bool: make it just a --switch
|
||||
if arg_type is bool:
|
||||
if not default or default is _empty:
|
||||
arg_spec['action'] = 'store_true'
|
||||
else:
|
||||
arg_spec['action'] = 'store_false'
|
||||
|
||||
# Switches are always options
|
||||
is_option = True
|
||||
|
||||
# Special case for file types: make it a string type, for filename
|
||||
elif isinstance(default, IOBase):
|
||||
arg_spec['type'] = str
|
||||
|
||||
# TODO: special case for list type.
|
||||
# - How to specificy type of list members?
|
||||
# - param: [int]
|
||||
# - param: int =[]
|
||||
# - action='append' vs nargs='*'
|
||||
|
||||
else:
|
||||
arg_spec['type'] = arg_type
|
||||
|
||||
# nargs: if the signature includes *args, collect them as trailing CLI
|
||||
# arguments in a list. *args can't have a default value, so it can never be
|
||||
# an option.
|
||||
if param.kind is param.VAR_POSITIONAL:
|
||||
# TODO: consider depluralizing metavar/name here.
|
||||
arg_spec['nargs'] = '*'
|
||||
|
||||
# Add description.
|
||||
if description is not None:
|
||||
arg_spec['help'] = description
|
||||
|
||||
# Get the --flags
|
||||
flags = []
|
||||
name = param.name
|
||||
|
||||
if is_option:
|
||||
# Add the first letter as a -short option.
|
||||
for letter in name[0], name[0].swapcase():
|
||||
if letter not in used_char_args:
|
||||
used_char_args.add(letter)
|
||||
flags.append('-{}'.format(letter))
|
||||
break
|
||||
|
||||
# If the parameter is a --long option, or is a -short option that
|
||||
# somehow failed to get a flag, add it.
|
||||
if len(name) > 1 or not flags:
|
||||
flags.append('--{}'.format(name))
|
||||
|
||||
arg_spec['dest'] = name
|
||||
else:
|
||||
flags.append(name)
|
||||
|
||||
parser.add_argument(*flags, **arg_spec)
|
||||
|
||||
# Create the --no- version for boolean switches
|
||||
if add_nos and arg_type is bool:
|
||||
parser.add_argument(
|
||||
'--no-{}'.format(name),
|
||||
action='store_const',
|
||||
dest=name,
|
||||
const=default if default is not _empty else False)
|
||||
|
||||
|
||||
def make_parser(func_sig, description, epilog, add_nos):
|
||||
'''
|
||||
Given the signature of a function, create an ArgumentParser
|
||||
'''
|
||||
parser = ArgumentParser(description=description, epilog=epilog)
|
||||
|
||||
used_char_args = {'h'}
|
||||
|
||||
# Arange the params so that single-character arguments are first. This
|
||||
# esnures they don't have to get --long versions. sorted is stable, so the
|
||||
# parameters will otherwise still be in relative order.
|
||||
params = sorted(
|
||||
func_sig.parameters.values(),
|
||||
key=lambda param: len(param.name) > 1)
|
||||
|
||||
for param in params:
|
||||
_add_arguments(param, parser, used_char_args, add_nos)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
_DOCSTRING_SPLIT = compile_regex(r'\n\s*-{4,}\s*\n')
|
||||
|
||||
|
||||
def parse_docstring(docstring):
|
||||
'''
|
||||
Given a docstring, parse it into a description and epilog part
|
||||
'''
|
||||
if docstring is None:
|
||||
return '', ''
|
||||
|
||||
parts = _DOCSTRING_SPLIT.split(docstring)
|
||||
|
||||
if len(parts) == 1:
|
||||
return docstring, ''
|
||||
elif len(parts) == 2:
|
||||
return parts[0], parts[1]
|
||||
else:
|
||||
raise TooManySplitsError()
|
||||
|
||||
|
||||
def autoparse(
|
||||
func=None, *,
|
||||
description=None,
|
||||
epilog=None,
|
||||
add_nos=False,
|
||||
parser=None):
|
||||
'''
|
||||
This decorator converts a function that takes normal arguments into a
|
||||
function which takes a single optional argument, argv, parses it using an
|
||||
argparse.ArgumentParser, and calls the underlying function with the parsed
|
||||
arguments. If it is not given, sys.argv[1:] is used. This is so that the
|
||||
function can be used as a setuptools entry point, as well as a normal main
|
||||
function. sys.argv[1:] is not evaluated until the function is called, to
|
||||
allow injecting different arguments for testing.
|
||||
|
||||
It uses the argument signature of the function to create an
|
||||
ArgumentParser. Parameters without defaults become positional parameters,
|
||||
while parameters *with* defaults become --options. Use annotations to set
|
||||
the type of the parameter.
|
||||
|
||||
The `desctiption` and `epilog` parameters corrospond to the same respective
|
||||
argparse parameters. If no description is given, it defaults to the
|
||||
decorated functions's docstring, if present.
|
||||
|
||||
If add_nos is True, every boolean option (that is, every parameter with a
|
||||
default of True/False or a type of bool) will have a --no- version created
|
||||
as well, which inverts the option. For instance, the --verbose option will
|
||||
have a --no-verbose counterpart. These are not mutually exclusive-
|
||||
whichever one appears last in the argument list will have precedence.
|
||||
|
||||
If a parser is given, it is used instead of one generated from the function
|
||||
signature. In this case, no parser is created; instead, the given parser is
|
||||
used to parse the argv argument. The parser's results' argument names must
|
||||
match up with the parameter names of the decorated function.
|
||||
|
||||
The decorated function is attached to the result as the `func` attribute,
|
||||
and the parser is attached as the `parser` attribute.
|
||||
'''
|
||||
|
||||
# If @autoparse(...) is used instead of @autoparse
|
||||
if func is None:
|
||||
return lambda f: autoparse(
|
||||
f, description=description,
|
||||
epilog=epilog,
|
||||
add_nos=add_nos,
|
||||
parser=parser)
|
||||
|
||||
func_sig = signature(func)
|
||||
|
||||
docstr_description, docstr_epilog = parse_docstring(getdoc(func))
|
||||
|
||||
if parser is None:
|
||||
parser = make_parser(
|
||||
func_sig,
|
||||
description or docstr_description,
|
||||
epilog or docstr_epilog,
|
||||
add_nos)
|
||||
|
||||
@wraps(func)
|
||||
def autoparse_wrapper(argv=None):
|
||||
if argv is None:
|
||||
argv = sys.argv[1:]
|
||||
|
||||
# Get empty argument binding, to fill with parsed arguments. This
|
||||
# object does all the heavy lifting of turning named arguments into
|
||||
# into correctly bound *args and **kwargs.
|
||||
parsed_args = func_sig.bind_partial()
|
||||
parsed_args.arguments.update(vars(parser.parse_args(argv)))
|
||||
|
||||
return func(*parsed_args.args, **parsed_args.kwargs)
|
||||
|
||||
# TODO: attach an updated __signature__ to autoparse_wrapper, just in case.
|
||||
|
||||
# Attach the wrapped function and parser, and return the wrapper.
|
||||
autoparse_wrapper.func = func
|
||||
autoparse_wrapper.parser = parser
|
||||
return autoparse_wrapper
|
||||
|
||||
|
||||
@contextmanager
|
||||
def smart_open(filename_or_file, *args, **kwargs):
|
||||
'''
|
||||
This context manager allows you to open a filename, if you want to default
|
||||
some already-existing file object, like sys.stdout, which shouldn't be
|
||||
closed at the end of the context. If the filename argument is a str, bytes,
|
||||
or int, the file object is created via a call to open with the given *args
|
||||
and **kwargs, sent to the context, and closed at the end of the context,
|
||||
just like "with open(filename) as f:". If it isn't one of the openable
|
||||
types, the object simply sent to the context unchanged, and left unclosed
|
||||
at the end of the context. Example:
|
||||
|
||||
def work_with_file(name=sys.stdout):
|
||||
with smart_open(name) as f:
|
||||
# Works correctly if name is a str filename or sys.stdout
|
||||
print("Some stuff", file=f)
|
||||
# If it was a filename, f is closed at the end here.
|
||||
'''
|
||||
if isinstance(filename_or_file, (str, bytes, int)):
|
||||
with open(filename_or_file, *args, **kwargs) as file:
|
||||
yield file
|
||||
else:
|
||||
yield filename_or_file
|
23
lib/autocommand/errors.py
Normal file
23
lib/autocommand/errors.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
# Copyright 2014-2016 Nathan West
|
||||
#
|
||||
# This file is part of autocommand.
|
||||
#
|
||||
# autocommand is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# autocommand is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with autocommand. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
|
||||
class AutocommandError(Exception):
|
||||
'''Base class for autocommand exceptions'''
|
||||
pass
|
||||
|
||||
# Individual modules will define errors specific to that module.
|
|
@ -206,10 +206,6 @@ except ImportError:
|
|||
def test_callable_spec(callable, args, kwargs): # noqa: F811
|
||||
return None
|
||||
else:
|
||||
getargspec = inspect.getargspec
|
||||
# Python 3 requires using getfullargspec if
|
||||
# keyword-only arguments are present
|
||||
if hasattr(inspect, 'getfullargspec'):
|
||||
def getargspec(callable):
|
||||
return inspect.getfullargspec(callable)[:4]
|
||||
|
||||
|
|
|
@ -466,7 +466,7 @@ _HTTPErrorTemplate = '''<!DOCTYPE html PUBLIC
|
|||
<pre id="traceback">%(traceback)s</pre>
|
||||
<div id="powered_by">
|
||||
<span>
|
||||
Powered by <a href="http://www.cherrypy.org">CherryPy %(version)s</a>
|
||||
Powered by <a href="http://www.cherrypy.dev">CherryPy %(version)s</a>
|
||||
</span>
|
||||
</div>
|
||||
</body>
|
||||
|
@ -532,7 +532,8 @@ def get_error_page(status, **kwargs):
|
|||
return result
|
||||
else:
|
||||
# Load the template from this path.
|
||||
template = io.open(error_page, newline='').read()
|
||||
with io.open(error_page, newline='') as f:
|
||||
template = f.read()
|
||||
except Exception:
|
||||
e = _format_exception(*_exc_info())[-1]
|
||||
m = kwargs['message']
|
||||
|
|
|
@ -339,11 +339,8 @@ LoadModule python_module modules/mod_python.so
|
|||
}
|
||||
|
||||
mpconf = os.path.join(os.path.dirname(__file__), 'cpmodpy.conf')
|
||||
f = open(mpconf, 'wb')
|
||||
try:
|
||||
with open(mpconf, 'wb') as f:
|
||||
f.write(conf_data)
|
||||
finally:
|
||||
f.close()
|
||||
|
||||
response = read_process(self.apache_path, '-k start -f %s' % mpconf)
|
||||
self.ready = True
|
||||
|
|
|
@ -169,7 +169,7 @@ def request_namespace(k, v):
|
|||
def response_namespace(k, v):
|
||||
"""Attach response attributes declared in config."""
|
||||
# Provides config entries to set default response headers
|
||||
# http://cherrypy.org/ticket/889
|
||||
# http://cherrypy.dev/ticket/889
|
||||
if k[:8] == 'headers.':
|
||||
cherrypy.serving.response.headers[k.split('.', 1)[1]] = v
|
||||
else:
|
||||
|
@ -252,7 +252,7 @@ class Request(object):
|
|||
The query component of the Request-URI, a string of information to be
|
||||
interpreted by the resource. The query portion of a URI follows the
|
||||
path component, and is separated by a '?'. For example, the URI
|
||||
'http://www.cherrypy.org/wiki?a=3&b=4' has the query component,
|
||||
'http://www.cherrypy.dev/wiki?a=3&b=4' has the query component,
|
||||
'a=3&b=4'."""
|
||||
|
||||
query_string_encoding = 'utf8'
|
||||
|
@ -742,6 +742,9 @@ class Request(object):
|
|||
if self.protocol >= (1, 1):
|
||||
msg = "HTTP/1.1 requires a 'Host' request header."
|
||||
raise cherrypy.HTTPError(400, msg)
|
||||
else:
|
||||
headers['Host'] = httputil.SanitizedHost(dict.get(headers, 'Host'))
|
||||
|
||||
host = dict.get(headers, 'Host')
|
||||
if not host:
|
||||
host = self.local.name or self.local.ip
|
||||
|
|
|
@ -101,13 +101,12 @@ def get_ha1_file_htdigest(filename):
|
|||
"""
|
||||
def get_ha1(realm, username):
|
||||
result = None
|
||||
f = open(filename, 'r')
|
||||
with open(filename, 'r') as f:
|
||||
for line in f:
|
||||
u, r, ha1 = line.rstrip().split(':')
|
||||
if u == username and r == realm:
|
||||
result = ha1
|
||||
break
|
||||
f.close()
|
||||
return result
|
||||
|
||||
return get_ha1
|
||||
|
|
|
@ -334,9 +334,10 @@ class CoverStats(object):
|
|||
yield '</body></html>'
|
||||
|
||||
def annotated_file(self, filename, statements, excluded, missing):
|
||||
source = open(filename, 'r')
|
||||
with open(filename, 'r') as source:
|
||||
lines = source.readlines()
|
||||
buffer = []
|
||||
for lineno, line in enumerate(source.readlines()):
|
||||
for lineno, line in enumerate(lines):
|
||||
lineno += 1
|
||||
line = line.strip('\n\r')
|
||||
empty_the_buffer = True
|
||||
|
|
|
@ -516,3 +516,33 @@ class Host(object):
|
|||
|
||||
def __repr__(self):
|
||||
return 'httputil.Host(%r, %r, %r)' % (self.ip, self.port, self.name)
|
||||
|
||||
|
||||
class SanitizedHost(str):
|
||||
r"""
|
||||
Wraps a raw host header received from the network in
|
||||
a sanitized version that elides dangerous characters.
|
||||
|
||||
>>> SanitizedHost('foo\nbar')
|
||||
'foobar'
|
||||
>>> SanitizedHost('foo\nbar').raw
|
||||
'foo\nbar'
|
||||
|
||||
A SanitizedInstance is only returned if sanitization was performed.
|
||||
|
||||
>>> isinstance(SanitizedHost('foobar'), SanitizedHost)
|
||||
False
|
||||
"""
|
||||
dangerous = re.compile(r'[\n\r]')
|
||||
|
||||
def __new__(cls, raw):
|
||||
sanitized = cls._sanitize(raw)
|
||||
if sanitized == raw:
|
||||
return raw
|
||||
instance = super().__new__(cls, sanitized)
|
||||
instance.raw = raw
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def _sanitize(cls, raw):
|
||||
return cls.dangerous.sub('', raw)
|
||||
|
|
|
@ -163,11 +163,8 @@ class Parser(configparser.ConfigParser):
|
|||
# fp = open(filename)
|
||||
# except IOError:
|
||||
# continue
|
||||
fp = open(filename)
|
||||
try:
|
||||
with open(filename) as fp:
|
||||
self._read(fp, filename)
|
||||
finally:
|
||||
fp.close()
|
||||
|
||||
def as_dict(self, raw=False, vars=None):
|
||||
"""Convert an INI file to a dictionary"""
|
||||
|
|
|
@ -516,11 +516,8 @@ class FileSession(Session):
|
|||
if path is None:
|
||||
path = self._get_file_path()
|
||||
try:
|
||||
f = open(path, 'rb')
|
||||
try:
|
||||
with open(path, 'rb') as f:
|
||||
return pickle.load(f)
|
||||
finally:
|
||||
f.close()
|
||||
except (IOError, EOFError):
|
||||
e = sys.exc_info()[1]
|
||||
if self.debug:
|
||||
|
@ -531,11 +528,8 @@ class FileSession(Session):
|
|||
def _save(self, expiration_time):
|
||||
assert self.locked, ('The session was saved without being locked. '
|
||||
"Check your tools' priority levels.")
|
||||
f = open(self._get_file_path(), 'wb')
|
||||
try:
|
||||
with open(self._get_file_path(), 'wb') as f:
|
||||
pickle.dump((self._data, expiration_time), f, self.pickle_protocol)
|
||||
finally:
|
||||
f.close()
|
||||
|
||||
def _delete(self):
|
||||
assert self.locked, ('The session deletion without being locked. '
|
||||
|
|
|
@ -436,7 +436,8 @@ class PIDFile(SimplePlugin):
|
|||
if self.finalized:
|
||||
self.bus.log('PID %r already written to %r.' % (pid, self.pidfile))
|
||||
else:
|
||||
open(self.pidfile, 'wb').write(ntob('%s\n' % pid, 'utf8'))
|
||||
with open(self.pidfile, 'wb') as f:
|
||||
f.write(ntob('%s\n' % pid, 'utf8'))
|
||||
self.bus.log('PID %r written to %r.' % (pid, self.pidfile))
|
||||
self.finalized = True
|
||||
start.priority = 70
|
||||
|
|
|
@ -505,7 +505,8 @@ server.ssl_private_key: r'%s'
|
|||
|
||||
def get_pid(self):
|
||||
if self.daemonize:
|
||||
return int(open(self.pid_file, 'rb').read())
|
||||
with open(self.pid_file, 'rb') as f:
|
||||
return int(f.read())
|
||||
return self._proc.pid
|
||||
|
||||
def join(self):
|
||||
|
|
|
@ -97,7 +97,8 @@ class LogCase(object):
|
|||
|
||||
def emptyLog(self):
|
||||
"""Overwrite self.logfile with 0 bytes."""
|
||||
open(self.logfile, 'wb').write('')
|
||||
with open(self.logfile, 'wb') as f:
|
||||
f.write('')
|
||||
|
||||
def markLog(self, key=None):
|
||||
"""Insert a marker line into the log and set self.lastmarker."""
|
||||
|
@ -105,7 +106,8 @@ class LogCase(object):
|
|||
key = str(time.time())
|
||||
self.lastmarker = key
|
||||
|
||||
open(self.logfile, 'ab+').write(
|
||||
with open(self.logfile, 'ab+') as f:
|
||||
f.write(
|
||||
b'%s%s\n'
|
||||
% (self.markerPrefix, key.encode('utf-8'))
|
||||
)
|
||||
|
@ -122,15 +124,18 @@ class LogCase(object):
|
|||
logfile = self.logfile
|
||||
marker = marker or self.lastmarker
|
||||
if marker is None:
|
||||
return open(logfile, 'rb').readlines()
|
||||
with open(logfile, 'rb') as f:
|
||||
return f.readlines()
|
||||
|
||||
if isinstance(marker, str):
|
||||
marker = marker.encode('utf-8')
|
||||
data = []
|
||||
in_region = False
|
||||
for line in open(logfile, 'rb'):
|
||||
with open(logfile, 'rb') as f:
|
||||
for line in f:
|
||||
if in_region:
|
||||
if line.startswith(self.markerPrefix) and marker not in line:
|
||||
if (line.startswith(self.markerPrefix)
|
||||
and marker not in line):
|
||||
break
|
||||
else:
|
||||
data.append(line)
|
||||
|
|
|
@ -14,7 +14,7 @@ KNOWN BUGS
|
|||
|
||||
1. Apache processes Range headers automatically; CherryPy's truncated
|
||||
output is then truncated again by Apache. See test_core.testRanges.
|
||||
This was worked around in http://www.cherrypy.org/changeset/1319.
|
||||
This was worked around in http://www.cherrypy.dev/changeset/1319.
|
||||
2. Apache does not allow custom HTTP methods like CONNECT as per the spec.
|
||||
See test_core.testHTTPMethods.
|
||||
3. Max request header and body settings do not work with Apache.
|
||||
|
@ -112,15 +112,12 @@ class ModFCGISupervisor(helper.LocalWSGISupervisor):
|
|||
fcgiconf = os.path.join(curdir, fcgiconf)
|
||||
|
||||
# Write the Apache conf file.
|
||||
f = open(fcgiconf, 'wb')
|
||||
try:
|
||||
with open(fcgiconf, 'wb') as f:
|
||||
server = repr(os.path.join(curdir, 'fastcgi.pyc'))[1:-1]
|
||||
output = self.template % {'port': self.port, 'root': curdir,
|
||||
'server': server}
|
||||
output = output.replace('\r\n', '\n')
|
||||
f.write(output)
|
||||
finally:
|
||||
f.close()
|
||||
|
||||
result = read_process(APACHE_PATH, '-k start -f %s' % fcgiconf)
|
||||
if result:
|
||||
|
|
|
@ -14,7 +14,7 @@ KNOWN BUGS
|
|||
|
||||
1. Apache processes Range headers automatically; CherryPy's truncated
|
||||
output is then truncated again by Apache. See test_core.testRanges.
|
||||
This was worked around in http://www.cherrypy.org/changeset/1319.
|
||||
This was worked around in http://www.cherrypy.dev/changeset/1319.
|
||||
2. Apache does not allow custom HTTP methods like CONNECT as per the spec.
|
||||
See test_core.testHTTPMethods.
|
||||
3. Max request header and body settings do not work with Apache.
|
||||
|
@ -101,15 +101,12 @@ class ModFCGISupervisor(helper.LocalSupervisor):
|
|||
fcgiconf = os.path.join(curdir, fcgiconf)
|
||||
|
||||
# Write the Apache conf file.
|
||||
f = open(fcgiconf, 'wb')
|
||||
try:
|
||||
with open(fcgiconf, 'wb') as f:
|
||||
server = repr(os.path.join(curdir, 'fastcgi.pyc'))[1:-1]
|
||||
output = self.template % {'port': self.port, 'root': curdir,
|
||||
'server': server}
|
||||
output = ntob(output.replace('\r\n', '\n'))
|
||||
f.write(output)
|
||||
finally:
|
||||
f.close()
|
||||
|
||||
result = read_process(APACHE_PATH, '-k start -f %s' % fcgiconf)
|
||||
if result:
|
||||
|
|
|
@ -15,7 +15,7 @@ KNOWN BUGS
|
|||
|
||||
1. Apache processes Range headers automatically; CherryPy's truncated
|
||||
output is then truncated again by Apache. See test_core.testRanges.
|
||||
This was worked around in http://www.cherrypy.org/changeset/1319.
|
||||
This was worked around in http://www.cherrypy.dev/changeset/1319.
|
||||
2. Apache does not allow custom HTTP methods like CONNECT as per the spec.
|
||||
See test_core.testHTTPMethods.
|
||||
3. Max request header and body settings do not work with Apache.
|
||||
|
@ -107,13 +107,10 @@ class ModPythonSupervisor(helper.Supervisor):
|
|||
if not os.path.isabs(mpconf):
|
||||
mpconf = os.path.join(curdir, mpconf)
|
||||
|
||||
f = open(mpconf, 'wb')
|
||||
try:
|
||||
with open(mpconf, 'wb') as f:
|
||||
f.write(self.template %
|
||||
{'port': self.port, 'modulename': modulename,
|
||||
'host': self.host})
|
||||
finally:
|
||||
f.close()
|
||||
|
||||
result = read_process(APACHE_PATH, '-k start -f %s' % mpconf)
|
||||
if result:
|
||||
|
|
|
@ -11,7 +11,7 @@ KNOWN BUGS
|
|||
|
||||
1. Apache processes Range headers automatically; CherryPy's truncated
|
||||
output is then truncated again by Apache. See test_core.testRanges.
|
||||
This was worked around in http://www.cherrypy.org/changeset/1319.
|
||||
This was worked around in http://www.cherrypy.dev/changeset/1319.
|
||||
2. Apache does not allow custom HTTP methods like CONNECT as per the spec.
|
||||
See test_core.testHTTPMethods.
|
||||
3. Max request header and body settings do not work with Apache.
|
||||
|
@ -109,14 +109,11 @@ class ModWSGISupervisor(helper.Supervisor):
|
|||
if not os.path.isabs(mpconf):
|
||||
mpconf = os.path.join(curdir, mpconf)
|
||||
|
||||
f = open(mpconf, 'wb')
|
||||
try:
|
||||
with open(mpconf, 'wb') as f:
|
||||
output = (self.template %
|
||||
{'port': self.port, 'testmod': modulename,
|
||||
'curdir': curdir})
|
||||
f.write(output)
|
||||
finally:
|
||||
f.close()
|
||||
|
||||
result = read_process(APACHE_PATH, '-k start -f %s' % mpconf)
|
||||
if result:
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# This file is part of CherryPy <http://www.cherrypy.org/>
|
||||
# This file is part of CherryPy <http://www.cherrypy.dev/>
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim:ts=4:sw=4:expandtab:fileencoding=utf-8
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# This file is part of CherryPy <http://www.cherrypy.org/>
|
||||
# This file is part of CherryPy <http://www.cherrypy.dev/>
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim:ts=4:sw=4:expandtab:fileencoding=utf-8
|
||||
|
||||
|
|
|
@ -586,9 +586,8 @@ class CoreRequestHandlingTest(helper.CPWebCase):
|
|||
def testFavicon(self):
|
||||
# favicon.ico is served by staticfile.
|
||||
icofilename = os.path.join(localDir, '../favicon.ico')
|
||||
icofile = open(icofilename, 'rb')
|
||||
with open(icofilename, 'rb') as icofile:
|
||||
data = icofile.read()
|
||||
icofile.close()
|
||||
|
||||
self.getPage('/favicon.ico')
|
||||
self.assertBody(data)
|
||||
|
|
|
@ -46,7 +46,7 @@ class EncodingTests(helper.CPWebCase):
|
|||
# any part which is unicode (even ascii), the response
|
||||
# should not fail.
|
||||
cherrypy.response.cookie['candy'] = 'bar'
|
||||
cherrypy.response.cookie['candy']['domain'] = 'cherrypy.org'
|
||||
cherrypy.response.cookie['candy']['domain'] = 'cherrypy.dev'
|
||||
cherrypy.response.headers[
|
||||
'Some-Header'] = 'My d\xc3\xb6g has fleas'
|
||||
cherrypy.response.headers[
|
||||
|
|
|
@ -113,7 +113,7 @@ def test_normal_return(log_tracker, server):
|
|||
resp = requests.get(
|
||||
'http://%s:%s/as_string' % (host, port),
|
||||
headers={
|
||||
'Referer': 'http://www.cherrypy.org/',
|
||||
'Referer': 'http://www.cherrypy.dev/',
|
||||
'User-Agent': 'Mozilla/5.0',
|
||||
},
|
||||
)
|
||||
|
@ -135,7 +135,7 @@ def test_normal_return(log_tracker, server):
|
|||
log_tracker.assertLog(
|
||||
-1,
|
||||
'] "GET /as_string HTTP/1.1" 200 %s '
|
||||
'"http://www.cherrypy.org/" "Mozilla/5.0"'
|
||||
'"http://www.cherrypy.dev/" "Mozilla/5.0"'
|
||||
% content_length,
|
||||
)
|
||||
|
||||
|
|
|
@ -342,7 +342,7 @@ class RequestObjectTests(helper.CPWebCase):
|
|||
self.assertBody('/pathinfo/foo/bar')
|
||||
|
||||
def testAbsoluteURIPathInfo(self):
|
||||
# http://cherrypy.org/ticket/1061
|
||||
# http://cherrypy.dev/ticket/1061
|
||||
self.getPage('http://localhost/pathinfo/foo/bar')
|
||||
self.assertBody('/pathinfo/foo/bar')
|
||||
|
||||
|
@ -375,10 +375,10 @@ class RequestObjectTests(helper.CPWebCase):
|
|||
|
||||
# Make sure that encoded = and & get parsed correctly
|
||||
self.getPage(
|
||||
'/params/code?url=http%3A//cherrypy.org/index%3Fa%3D1%26b%3D2')
|
||||
'/params/code?url=http%3A//cherrypy.dev/index%3Fa%3D1%26b%3D2')
|
||||
self.assertBody('args: %s kwargs: %s' %
|
||||
(('code',),
|
||||
[('url', ntou('http://cherrypy.org/index?a=1&b=2'))]))
|
||||
[('url', ntou('http://cherrypy.dev/index?a=1&b=2'))]))
|
||||
|
||||
# Test coordinates sent by <img ismap>
|
||||
self.getPage('/params/ismap?223,114')
|
||||
|
@ -756,6 +756,16 @@ class RequestObjectTests(helper.CPWebCase):
|
|||
headers=[('Content-type', 'application/json')])
|
||||
self.assertBody('application/json')
|
||||
|
||||
def test_dangerous_host(self):
|
||||
"""
|
||||
Dangerous characters like newlines should be elided.
|
||||
Ref #1974.
|
||||
"""
|
||||
# foo\nbar
|
||||
encoded = '=?iso-8859-1?q?foo=0Abar?='
|
||||
self.getPage('/headers/Host', headers=[('Host', encoded)])
|
||||
self.assertBody('foobar')
|
||||
|
||||
def test_basic_HTTPMethods(self):
|
||||
helper.webtest.methods_with_bodies = ('POST', 'PUT', 'PROPFIND',
|
||||
'PATCH')
|
||||
|
|
|
@ -424,7 +424,8 @@ test_case_name: "test_signal_handler_unsubscribe"
|
|||
p.join()
|
||||
|
||||
# Assert the old handler ran.
|
||||
log_lines = list(open(p.error_log, 'rb'))
|
||||
with open(p.error_log, 'rb') as f:
|
||||
log_lines = list(f)
|
||||
assert any(
|
||||
line.endswith(b'I am an old SIGTERM handler.\n')
|
||||
for line in log_lines
|
||||
|
|
|
@ -78,7 +78,7 @@ class TutorialTest(helper.CPWebCase):
|
|||
|
||||
<ul>
|
||||
<li><a href="http://del.icio.us">del.icio.us</a></li>
|
||||
<li><a href="http://www.cherrypy.org">CherryPy</a></li>
|
||||
<li><a href="http://www.cherrypy.dev">CherryPy</a></li>
|
||||
</ul>
|
||||
|
||||
<p>[<a href="../">Return to links page</a>]</p>'''
|
||||
|
@ -166,7 +166,7 @@ class TutorialTest(helper.CPWebCase):
|
|||
self.assertHeader('Content-Disposition',
|
||||
# Make sure the filename is quoted.
|
||||
'attachment; filename="pdf_file.pdf"')
|
||||
self.assertEqual(len(self.body), 85698)
|
||||
self.assertEqual(len(self.body), 11961)
|
||||
|
||||
def test10HTTPErrors(self):
|
||||
self.setup_tutorial('tut10_http_errors', 'HTTPErrorDemo')
|
||||
|
|
Binary file not shown.
|
@ -53,7 +53,7 @@ class LinksPage:
|
|||
|
||||
<ul>
|
||||
<li>
|
||||
<a href="http://www.cherrypy.org">The CherryPy Homepage</a>
|
||||
<a href="http://www.cherrypy.dev">The CherryPy Homepage</a>
|
||||
</li>
|
||||
<li>
|
||||
<a href="http://www.python.org">The Python Homepage</a>
|
||||
|
@ -77,7 +77,7 @@ class ExtraLinksPage:
|
|||
|
||||
<ul>
|
||||
<li><a href="http://del.icio.us">del.icio.us</a></li>
|
||||
<li><a href="http://www.cherrypy.org">CherryPy</a></li>
|
||||
<li><a href="http://www.cherrypy.dev">CherryPy</a></li>
|
||||
</ul>
|
||||
|
||||
<p>[<a href="../">Return to links page</a>]</p>'''
|
||||
|
|
3991
lib/inflect/__init__.py
Normal file
3991
lib/inflect/__init__.py
Normal file
File diff suppressed because it is too large
Load diff
0
lib/inflect/py.typed
Normal file
0
lib/inflect/py.typed
Normal file
|
@ -143,7 +143,7 @@ class classproperty:
|
|||
return super().__setattr__(key, value)
|
||||
|
||||
def __init__(self, fget, fset=None):
|
||||
self.fget = self._fix_function(fget)
|
||||
self.fget = self._ensure_method(fget)
|
||||
self.fset = fset
|
||||
fset and self.setter(fset)
|
||||
|
||||
|
@ -158,14 +158,13 @@ class classproperty:
|
|||
return self.fset.__get__(None, owner)(value)
|
||||
|
||||
def setter(self, fset):
|
||||
self.fset = self._fix_function(fset)
|
||||
self.fset = self._ensure_method(fset)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def _fix_function(cls, fn):
|
||||
def _ensure_method(cls, fn):
|
||||
"""
|
||||
Ensure fn is a classmethod or staticmethod.
|
||||
"""
|
||||
if not isinstance(fn, (classmethod, staticmethod)):
|
||||
return classmethod(fn)
|
||||
return fn
|
||||
needs_method = not isinstance(fn, (classmethod, staticmethod))
|
||||
return classmethod(fn) if needs_method else fn
|
||||
|
|
|
@ -63,7 +63,7 @@ class Projection(collections.abc.Mapping):
|
|||
return len(tuple(iter(self)))
|
||||
|
||||
|
||||
class DictFilter(object):
|
||||
class DictFilter(collections.abc.Mapping):
|
||||
"""
|
||||
Takes a dict, and simulates a sub-dict based on the keys.
|
||||
|
||||
|
@ -92,15 +92,21 @@ class DictFilter(object):
|
|||
...
|
||||
KeyError: 'e'
|
||||
|
||||
>>> 'e' in filtered
|
||||
False
|
||||
|
||||
Pattern is useful for excluding keys with a prefix.
|
||||
|
||||
>>> filtered = DictFilter(sample, include_pattern=r'(?![ace])')
|
||||
>>> dict(filtered)
|
||||
{'b': 2, 'd': 4}
|
||||
|
||||
Also note that DictFilter keeps a reference to the original dict, so
|
||||
if you modify the original dict, that could modify the filtered dict.
|
||||
|
||||
>>> del sample['d']
|
||||
>>> del sample['a']
|
||||
>>> filtered == {'b': 2, 'c': 3}
|
||||
True
|
||||
>>> filtered != {'b': 2, 'c': 3}
|
||||
False
|
||||
>>> dict(filtered)
|
||||
{'b': 2}
|
||||
"""
|
||||
|
||||
def __init__(self, dict, include_keys=[], include_pattern=None):
|
||||
|
@ -120,29 +126,18 @@ class DictFilter(object):
|
|||
|
||||
@property
|
||||
def include_keys(self):
|
||||
return self.specified_keys.union(self.pattern_keys)
|
||||
|
||||
def keys(self):
|
||||
return self.include_keys.intersection(self.dict.keys())
|
||||
|
||||
def values(self):
|
||||
return map(self.dict.get, self.keys())
|
||||
return self.specified_keys | self.pattern_keys
|
||||
|
||||
def __getitem__(self, i):
|
||||
if i not in self.include_keys:
|
||||
raise KeyError(i)
|
||||
return self.dict[i]
|
||||
|
||||
def items(self):
|
||||
keys = self.keys()
|
||||
values = map(self.dict.get, keys)
|
||||
return zip(keys, values)
|
||||
def __iter__(self):
|
||||
return filter(self.include_keys.__contains__, self.dict.keys())
|
||||
|
||||
def __eq__(self, other):
|
||||
return dict(self) == other
|
||||
|
||||
def __ne__(self, other):
|
||||
return dict(self) != other
|
||||
def __len__(self):
|
||||
return len(list(self))
|
||||
|
||||
|
||||
def dict_map(function, dictionary):
|
||||
|
@ -167,7 +162,7 @@ class RangeMap(dict):
|
|||
the sorted list of keys.
|
||||
|
||||
One may supply keyword parameters to be passed to the sort function used
|
||||
to sort keys (i.e. cmp [python 2 only], keys, reverse) as sort_params.
|
||||
to sort keys (i.e. key, reverse) as sort_params.
|
||||
|
||||
Let's create a map that maps 1-3 -> 'a', 4-6 -> 'b'
|
||||
|
||||
|
@ -220,6 +215,23 @@ class RangeMap(dict):
|
|||
|
||||
>>> r.get(7, 'not found')
|
||||
'not found'
|
||||
|
||||
One often wishes to define the ranges by their left-most values,
|
||||
which requires use of sort params and a key_match_comparator.
|
||||
|
||||
>>> r = RangeMap({1: 'a', 4: 'b'},
|
||||
... sort_params=dict(reverse=True),
|
||||
... key_match_comparator=operator.ge)
|
||||
>>> r[1], r[2], r[3], r[4], r[5], r[6]
|
||||
('a', 'a', 'a', 'b', 'b', 'b')
|
||||
|
||||
That wasn't nearly as easy as before, so an alternate constructor
|
||||
is provided:
|
||||
|
||||
>>> r = RangeMap.left({1: 'a', 4: 'b', 7: RangeMap.undefined_value})
|
||||
>>> r[1], r[2], r[3], r[4], r[5], r[6]
|
||||
('a', 'a', 'a', 'b', 'b', 'b')
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, source, sort_params={}, key_match_comparator=operator.le):
|
||||
|
@ -227,6 +239,12 @@ class RangeMap(dict):
|
|||
self.sort_params = sort_params
|
||||
self.match = key_match_comparator
|
||||
|
||||
@classmethod
|
||||
def left(cls, source):
|
||||
return cls(
|
||||
source, sort_params=dict(reverse=True), key_match_comparator=operator.ge
|
||||
)
|
||||
|
||||
def __getitem__(self, item):
|
||||
sorted_keys = sorted(self.keys(), **self.sort_params)
|
||||
if isinstance(item, RangeMap.Item):
|
||||
|
@ -261,7 +279,7 @@ class RangeMap(dict):
|
|||
return (sorted_keys[RangeMap.first_item], sorted_keys[RangeMap.last_item])
|
||||
|
||||
# some special values for the RangeMap
|
||||
undefined_value = type(str('RangeValueUndefined'), (object,), {})()
|
||||
undefined_value = type(str('RangeValueUndefined'), (), {})()
|
||||
|
||||
class Item(int):
|
||||
"RangeMap Item"
|
||||
|
@ -370,7 +388,7 @@ class FoldedCaseKeyedDict(KeyTransformingDict):
|
|||
True
|
||||
>>> 'HELLO' in d
|
||||
True
|
||||
>>> print(repr(FoldedCaseKeyedDict({'heLlo': 'world'})).replace("u'", "'"))
|
||||
>>> print(repr(FoldedCaseKeyedDict({'heLlo': 'world'})))
|
||||
{'heLlo': 'world'}
|
||||
>>> d = FoldedCaseKeyedDict({'heLlo': 'world'})
|
||||
>>> print(d['hello'])
|
||||
|
@ -433,7 +451,7 @@ class FoldedCaseKeyedDict(KeyTransformingDict):
|
|||
return jaraco.text.FoldedCase(key)
|
||||
|
||||
|
||||
class DictAdapter(object):
|
||||
class DictAdapter:
|
||||
"""
|
||||
Provide a getitem interface for attributes of an object.
|
||||
|
||||
|
@ -452,7 +470,7 @@ class DictAdapter(object):
|
|||
return getattr(self.object, name)
|
||||
|
||||
|
||||
class ItemsAsAttributes(object):
|
||||
class ItemsAsAttributes:
|
||||
"""
|
||||
Mix-in class to enable a mapping object to provide items as
|
||||
attributes.
|
||||
|
@ -561,7 +579,7 @@ class IdentityOverrideMap(dict):
|
|||
return key
|
||||
|
||||
|
||||
class DictStack(list, collections.abc.Mapping):
|
||||
class DictStack(list, collections.abc.MutableMapping):
|
||||
"""
|
||||
A stack of dictionaries that behaves as a view on those dictionaries,
|
||||
giving preference to the last.
|
||||
|
@ -578,11 +596,12 @@ class DictStack(list, collections.abc.Mapping):
|
|||
>>> stack.push(dict(a=3))
|
||||
>>> stack['a']
|
||||
3
|
||||
>>> stack['a'] = 4
|
||||
>>> set(stack.keys()) == set(['a', 'b', 'c'])
|
||||
True
|
||||
>>> set(stack.items()) == set([('a', 3), ('b', 2), ('c', 2)])
|
||||
>>> set(stack.items()) == set([('a', 4), ('b', 2), ('c', 2)])
|
||||
True
|
||||
>>> dict(**stack) == dict(stack) == dict(a=3, c=2, b=2)
|
||||
>>> dict(**stack) == dict(stack) == dict(a=4, c=2, b=2)
|
||||
True
|
||||
>>> d = stack.pop()
|
||||
>>> stack['a']
|
||||
|
@ -593,6 +612,9 @@ class DictStack(list, collections.abc.Mapping):
|
|||
>>> stack.get('b', None)
|
||||
>>> 'c' in stack
|
||||
True
|
||||
>>> del stack['c']
|
||||
>>> dict(stack)
|
||||
{'a': 1}
|
||||
"""
|
||||
|
||||
def __iter__(self):
|
||||
|
@ -613,6 +635,18 @@ class DictStack(list, collections.abc.Mapping):
|
|||
def __len__(self):
|
||||
return len(list(iter(self)))
|
||||
|
||||
def __setitem__(self, key, item):
|
||||
last = list.__getitem__(self, -1)
|
||||
return last.__setitem__(key, item)
|
||||
|
||||
def __delitem__(self, key):
|
||||
last = list.__getitem__(self, -1)
|
||||
return last.__delitem__(key)
|
||||
|
||||
# workaround for mypy confusion
|
||||
def pop(self, *args, **kwargs):
|
||||
return list.pop(self, *args, **kwargs)
|
||||
|
||||
|
||||
class BijectiveMap(dict):
|
||||
"""
|
||||
|
@ -855,7 +889,7 @@ class Enumeration(ItemsAsAttributes, BijectiveMap):
|
|||
return (self[name] for name in self.names)
|
||||
|
||||
|
||||
class Everything(object):
|
||||
class Everything:
|
||||
"""
|
||||
A collection "containing" every possible thing.
|
||||
|
||||
|
@ -896,7 +930,7 @@ class InstrumentedDict(collections.UserDict): # type: ignore # buggy mypy
|
|||
self.data = data
|
||||
|
||||
|
||||
class Least(object):
|
||||
class Least:
|
||||
"""
|
||||
A value that is always lesser than any other
|
||||
|
||||
|
@ -928,7 +962,7 @@ class Least(object):
|
|||
__gt__ = __ge__
|
||||
|
||||
|
||||
class Greatest(object):
|
||||
class Greatest:
|
||||
"""
|
||||
A value that is always greater than any other
|
||||
|
||||
|
|
|
@ -66,7 +66,7 @@ class FoldedCase(str):
|
|||
>>> s in ["Hello World"]
|
||||
True
|
||||
|
||||
You may test for set inclusion, but candidate and elements
|
||||
Allows testing for set inclusion, but candidate and elements
|
||||
must both be folded.
|
||||
|
||||
>>> FoldedCase("Hello World") in {s}
|
||||
|
@ -92,37 +92,40 @@ class FoldedCase(str):
|
|||
|
||||
>>> FoldedCase('hello') > FoldedCase('Hello')
|
||||
False
|
||||
|
||||
>>> FoldedCase('ß') == FoldedCase('ss')
|
||||
True
|
||||
"""
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.lower() < other.lower()
|
||||
return self.casefold() < other.casefold()
|
||||
|
||||
def __gt__(self, other):
|
||||
return self.lower() > other.lower()
|
||||
return self.casefold() > other.casefold()
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.lower() == other.lower()
|
||||
return self.casefold() == other.casefold()
|
||||
|
||||
def __ne__(self, other):
|
||||
return self.lower() != other.lower()
|
||||
return self.casefold() != other.casefold()
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.lower())
|
||||
return hash(self.casefold())
|
||||
|
||||
def __contains__(self, other):
|
||||
return super().lower().__contains__(other.lower())
|
||||
return super().casefold().__contains__(other.casefold())
|
||||
|
||||
def in_(self, other):
|
||||
"Does self appear in other?"
|
||||
return self in FoldedCase(other)
|
||||
|
||||
# cache lower since it's likely to be called frequently.
|
||||
# cache casefold since it's likely to be called frequently.
|
||||
@method_cache
|
||||
def lower(self):
|
||||
return super().lower()
|
||||
def casefold(self):
|
||||
return super().casefold()
|
||||
|
||||
def index(self, sub):
|
||||
return self.lower().index(sub.lower())
|
||||
return self.casefold().index(sub.casefold())
|
||||
|
||||
def split(self, splitter=' ', maxsplit=0):
|
||||
pattern = re.compile(re.escape(splitter), re.I)
|
||||
|
@ -277,7 +280,7 @@ class WordSet(tuple):
|
|||
>>> WordSet.parse("myABCClass")
|
||||
('my', 'ABC', 'Class')
|
||||
|
||||
The result is a WordSet, so you can get the form you need.
|
||||
The result is a WordSet, providing access to other forms.
|
||||
|
||||
>>> WordSet.parse("myABCClass").underscore_separated()
|
||||
'my_ABC_Class'
|
||||
|
@ -598,3 +601,22 @@ def join_continuation(lines):
|
|||
except StopIteration:
|
||||
return
|
||||
yield item
|
||||
|
||||
|
||||
def read_newlines(filename, limit=1024):
|
||||
r"""
|
||||
>>> tmp_path = getfixture('tmp_path')
|
||||
>>> filename = tmp_path / 'out.txt'
|
||||
>>> _ = filename.write_text('foo\n', newline='')
|
||||
>>> read_newlines(filename)
|
||||
'\n'
|
||||
>>> _ = filename.write_text('foo\r\n', newline='')
|
||||
>>> read_newlines(filename)
|
||||
'\r\n'
|
||||
>>> _ = filename.write_text('foo\r\nbar\nbing\r', newline='')
|
||||
>>> read_newlines(filename)
|
||||
('\r', '\n', '\r\n')
|
||||
"""
|
||||
with open(filename) as fp:
|
||||
fp.read(limit)
|
||||
return fp.newlines
|
||||
|
|
25
lib/jaraco/text/layouts.py
Normal file
25
lib/jaraco/text/layouts.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
qwerty = "-=qwertyuiop[]asdfghjkl;'zxcvbnm,./_+QWERTYUIOP{}ASDFGHJKL:\"ZXCVBNM<>?"
|
||||
dvorak = "[]',.pyfgcrl/=aoeuidhtns-;qjkxbmwvz{}\"<>PYFGCRL?+AOEUIDHTNS_:QJKXBMWVZ"
|
||||
|
||||
|
||||
to_dvorak = str.maketrans(qwerty, dvorak)
|
||||
to_qwerty = str.maketrans(dvorak, qwerty)
|
||||
|
||||
|
||||
def translate(input, translation):
|
||||
"""
|
||||
>>> translate('dvorak', to_dvorak)
|
||||
'ekrpat'
|
||||
>>> translate('qwerty', to_qwerty)
|
||||
'x,dokt'
|
||||
"""
|
||||
return input.translate(translation)
|
||||
|
||||
|
||||
def _translate_stream(stream, translation):
|
||||
"""
|
||||
>>> import io
|
||||
>>> _translate_stream(io.StringIO('foo'), to_dvorak)
|
||||
urr
|
||||
"""
|
||||
print(translate(stream.read(), translation))
|
33
lib/jaraco/text/show-newlines.py
Normal file
33
lib/jaraco/text/show-newlines.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
import autocommand
|
||||
import inflect
|
||||
|
||||
from more_itertools import always_iterable
|
||||
|
||||
import jaraco.text
|
||||
|
||||
|
||||
def report_newlines(filename):
|
||||
r"""
|
||||
Report the newlines in the indicated file.
|
||||
|
||||
>>> tmp_path = getfixture('tmp_path')
|
||||
>>> filename = tmp_path / 'out.txt'
|
||||
>>> _ = filename.write_text('foo\nbar\n', newline='')
|
||||
>>> report_newlines(filename)
|
||||
newline is '\n'
|
||||
>>> filename = tmp_path / 'out.txt'
|
||||
>>> _ = filename.write_text('foo\nbar\r\n', newline='')
|
||||
>>> report_newlines(filename)
|
||||
newlines are ('\n', '\r\n')
|
||||
"""
|
||||
newlines = jaraco.text.read_newlines(filename)
|
||||
count = len(tuple(always_iterable(newlines)))
|
||||
engine = inflect.engine()
|
||||
print(
|
||||
engine.plural_noun("newline", count),
|
||||
engine.plural_verb("is", count),
|
||||
repr(newlines),
|
||||
)
|
||||
|
||||
|
||||
autocommand.autocommand(__name__)(report_newlines)
|
6
lib/jaraco/text/to-dvorak.py
Normal file
6
lib/jaraco/text/to-dvorak.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
import sys
|
||||
|
||||
from . import layouts
|
||||
|
||||
|
||||
__name__ == '__main__' and layouts._translate_stream(sys.stdin, layouts.to_dvorak)
|
6
lib/jaraco/text/to-qwerty.py
Normal file
6
lib/jaraco/text/to-qwerty.py
Normal file
|
@ -0,0 +1,6 @@
|
|||
import sys
|
||||
|
||||
from . import layouts
|
||||
|
||||
|
||||
__name__ == '__main__' and layouts._translate_stream(sys.stdin, layouts.to_qwerty)
|
|
@ -1,4 +1,6 @@
|
|||
"""More routines for operating on iterables, beyond itertools"""
|
||||
|
||||
from .more import * # noqa
|
||||
from .recipes import * # noqa
|
||||
|
||||
__version__ = '8.12.0'
|
||||
__version__ = '9.0.0'
|
||||
|
|
|
@ -2,9 +2,8 @@ import warnings
|
|||
|
||||
from collections import Counter, defaultdict, deque, abc
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial, reduce, wraps
|
||||
from heapq import merge, heapify, heapreplace, heappop
|
||||
from heapq import heapify, heapreplace, heappop
|
||||
from itertools import (
|
||||
chain,
|
||||
compress,
|
||||
|
@ -27,12 +26,16 @@ from sys import hexversion, maxsize
|
|||
from time import monotonic
|
||||
|
||||
from .recipes import (
|
||||
_marker,
|
||||
_zip_equal,
|
||||
UnequalIterablesError,
|
||||
consume,
|
||||
flatten,
|
||||
pairwise,
|
||||
powerset,
|
||||
take,
|
||||
unique_everseen,
|
||||
all_equal,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
@ -49,9 +52,9 @@ __all__ = [
|
|||
'chunked_even',
|
||||
'circular_shifts',
|
||||
'collapse',
|
||||
'collate',
|
||||
'combination_index',
|
||||
'consecutive_groups',
|
||||
'constrained_batches',
|
||||
'consumer',
|
||||
'count_cycle',
|
||||
'countable',
|
||||
|
@ -67,6 +70,7 @@ __all__ = [
|
|||
'first',
|
||||
'groupby_transform',
|
||||
'ichunked',
|
||||
'iequals',
|
||||
'ilen',
|
||||
'interleave',
|
||||
'interleave_evenly',
|
||||
|
@ -77,6 +81,7 @@ __all__ = [
|
|||
'iterate',
|
||||
'last',
|
||||
'locate',
|
||||
'longest_common_prefix',
|
||||
'lstrip',
|
||||
'make_decorator',
|
||||
'map_except',
|
||||
|
@ -133,9 +138,6 @@ __all__ = [
|
|||
]
|
||||
|
||||
|
||||
_marker = object()
|
||||
|
||||
|
||||
def chunked(iterable, n, strict=False):
|
||||
"""Break *iterable* into lists of length *n*:
|
||||
|
||||
|
@ -410,44 +412,6 @@ class peekable:
|
|||
return self._cache[index]
|
||||
|
||||
|
||||
def collate(*iterables, **kwargs):
|
||||
"""Return a sorted merge of the items from each of several already-sorted
|
||||
*iterables*.
|
||||
|
||||
>>> list(collate('ACDZ', 'AZ', 'JKL'))
|
||||
['A', 'A', 'C', 'D', 'J', 'K', 'L', 'Z', 'Z']
|
||||
|
||||
Works lazily, keeping only the next value from each iterable in memory. Use
|
||||
:func:`collate` to, for example, perform a n-way mergesort of items that
|
||||
don't fit in memory.
|
||||
|
||||
If a *key* function is specified, the iterables will be sorted according
|
||||
to its result:
|
||||
|
||||
>>> key = lambda s: int(s) # Sort by numeric value, not by string
|
||||
>>> list(collate(['1', '10'], ['2', '11'], key=key))
|
||||
['1', '2', '10', '11']
|
||||
|
||||
|
||||
If the *iterables* are sorted in descending order, set *reverse* to
|
||||
``True``:
|
||||
|
||||
>>> list(collate([5, 3, 1], [4, 2, 0], reverse=True))
|
||||
[5, 4, 3, 2, 1, 0]
|
||||
|
||||
If the elements of the passed-in iterables are out of order, you might get
|
||||
unexpected results.
|
||||
|
||||
On Python 3.5+, this function is an alias for :func:`heapq.merge`.
|
||||
|
||||
"""
|
||||
warnings.warn(
|
||||
"collate is no longer part of more_itertools, use heapq.merge",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return merge(*iterables, **kwargs)
|
||||
|
||||
|
||||
def consumer(func):
|
||||
"""Decorator that automatically advances a PEP-342-style "reverse iterator"
|
||||
to its first yield point so you don't have to call ``next()`` on it
|
||||
|
@ -873,7 +837,9 @@ def windowed(seq, n, fillvalue=None, step=1):
|
|||
yield tuple(window)
|
||||
|
||||
size = len(window)
|
||||
if size < n:
|
||||
if size == 0:
|
||||
return
|
||||
elif size < n:
|
||||
yield tuple(chain(window, repeat(fillvalue, n - size)))
|
||||
elif 0 < i < min(step, n):
|
||||
window += (fillvalue,) * i
|
||||
|
@ -1646,45 +1612,6 @@ def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
|
|||
)
|
||||
|
||||
|
||||
class UnequalIterablesError(ValueError):
|
||||
def __init__(self, details=None):
|
||||
msg = 'Iterables have different lengths'
|
||||
if details is not None:
|
||||
msg += (': index 0 has length {}; index {} has length {}').format(
|
||||
*details
|
||||
)
|
||||
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
def _zip_equal_generator(iterables):
|
||||
for combo in zip_longest(*iterables, fillvalue=_marker):
|
||||
for val in combo:
|
||||
if val is _marker:
|
||||
raise UnequalIterablesError()
|
||||
yield combo
|
||||
|
||||
|
||||
def _zip_equal(*iterables):
|
||||
# Check whether the iterables are all the same size.
|
||||
try:
|
||||
first_size = len(iterables[0])
|
||||
for i, it in enumerate(iterables[1:], 1):
|
||||
size = len(it)
|
||||
if size != first_size:
|
||||
break
|
||||
else:
|
||||
# If we didn't break out, we can use the built-in zip.
|
||||
return zip(*iterables)
|
||||
|
||||
# If we did break out, there was a mismatch.
|
||||
raise UnequalIterablesError(details=(first_size, i, size))
|
||||
# If any one of the iterables didn't have a length, start reading
|
||||
# them until one runs out.
|
||||
except TypeError:
|
||||
return _zip_equal_generator(iterables)
|
||||
|
||||
|
||||
def zip_equal(*iterables):
|
||||
"""``zip`` the input *iterables* together, but raise
|
||||
``UnequalIterablesError`` if they aren't all the same length.
|
||||
|
@ -1826,7 +1753,7 @@ def unzip(iterable):
|
|||
of the zipped *iterable*.
|
||||
|
||||
The ``i``-th iterable contains the ``i``-th element from each element
|
||||
of the zipped iterable. The first element is used to to determine the
|
||||
of the zipped iterable. The first element is used to determine the
|
||||
length of the remaining elements.
|
||||
|
||||
>>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
|
||||
|
@ -2376,6 +2303,16 @@ def locate(iterable, pred=bool, window_size=None):
|
|||
return compress(count(), starmap(pred, it))
|
||||
|
||||
|
||||
def longest_common_prefix(iterables):
|
||||
"""Yield elements of the longest common prefix amongst given *iterables*.
|
||||
|
||||
>>> ''.join(longest_common_prefix(['abcd', 'abc', 'abf']))
|
||||
'ab'
|
||||
|
||||
"""
|
||||
return (c[0] for c in takewhile(all_equal, zip(*iterables)))
|
||||
|
||||
|
||||
def lstrip(iterable, pred):
|
||||
"""Yield the items from *iterable*, but strip any from the beginning
|
||||
for which *pred* returns ``True``.
|
||||
|
@ -2684,7 +2621,7 @@ def difference(iterable, func=sub, *, initial=None):
|
|||
if initial is not None:
|
||||
first = []
|
||||
|
||||
return chain(first, starmap(func, zip(b, a)))
|
||||
return chain(first, map(func, b, a))
|
||||
|
||||
|
||||
class SequenceView(Sequence):
|
||||
|
@ -3327,6 +3264,27 @@ def only(iterable, default=None, too_long=None):
|
|||
return first_value
|
||||
|
||||
|
||||
class _IChunk:
|
||||
def __init__(self, iterable, n):
|
||||
self._it = islice(iterable, n)
|
||||
self._cache = deque()
|
||||
|
||||
def fill_cache(self):
|
||||
self._cache.extend(self._it)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
return next(self._it)
|
||||
except StopIteration:
|
||||
if self._cache:
|
||||
return self._cache.popleft()
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def ichunked(iterable, n):
|
||||
"""Break *iterable* into sub-iterables with *n* elements each.
|
||||
:func:`ichunked` is like :func:`chunked`, but it yields iterables
|
||||
|
@ -3348,20 +3306,39 @@ def ichunked(iterable, n):
|
|||
[8, 9, 10, 11]
|
||||
|
||||
"""
|
||||
source = iter(iterable)
|
||||
|
||||
source = peekable(iter(iterable))
|
||||
ichunk_marker = object()
|
||||
while True:
|
||||
# Check to see whether we're at the end of the source iterable
|
||||
item = next(source, _marker)
|
||||
if item is _marker:
|
||||
item = source.peek(ichunk_marker)
|
||||
if item is ichunk_marker:
|
||||
return
|
||||
|
||||
# Clone the source and yield an n-length slice
|
||||
source, it = tee(chain([item], source))
|
||||
yield islice(it, n)
|
||||
chunk = _IChunk(source, n)
|
||||
yield chunk
|
||||
|
||||
# Advance the source iterable
|
||||
consume(source, n)
|
||||
# Advance the source iterable and fill previous chunk's cache
|
||||
chunk.fill_cache()
|
||||
|
||||
|
||||
def iequals(*iterables):
|
||||
"""Return ``True`` if all given *iterables* are equal to each other,
|
||||
which means that they contain the same elements in the same order.
|
||||
|
||||
The function is useful for comparing iterables of different data types
|
||||
or iterables that do not support equality checks.
|
||||
|
||||
>>> iequals("abc", ['a', 'b', 'c'], ('a', 'b', 'c'), iter("abc"))
|
||||
True
|
||||
|
||||
>>> iequals("abc", "acb")
|
||||
False
|
||||
|
||||
Not to be confused with :func:`all_equals`, which checks whether all
|
||||
elements of iterable are equal to each other.
|
||||
|
||||
"""
|
||||
return all(map(all_equal, zip_longest(*iterables, fillvalue=object())))
|
||||
|
||||
|
||||
def distinct_combinations(iterable, r):
|
||||
|
@ -3656,7 +3633,10 @@ class callback_iter:
|
|||
self._aborted = False
|
||||
self._future = None
|
||||
self._wait_seconds = wait_seconds
|
||||
self._executor = ThreadPoolExecutor(max_workers=1)
|
||||
# Lazily import concurrent.future
|
||||
self._executor = __import__(
|
||||
'concurrent.futures'
|
||||
).futures.ThreadPoolExecutor(max_workers=1)
|
||||
self._iterator = self._reader()
|
||||
|
||||
def __enter__(self):
|
||||
|
@ -3961,7 +3941,7 @@ def combination_index(element, iterable):
|
|||
|
||||
n, _ = last(pool, default=(n, None))
|
||||
|
||||
# Python versiosn below 3.8 don't have math.comb
|
||||
# Python versions below 3.8 don't have math.comb
|
||||
index = 1
|
||||
for i, j in enumerate(reversed(indexes), start=1):
|
||||
j = n - j
|
||||
|
@ -4114,7 +4094,7 @@ def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
|
|||
|
||||
If the *strict* keyword argument is ``True``, then
|
||||
``UnequalIterablesError`` will be raised if any of the iterables have
|
||||
different lengthss.
|
||||
different lengths.
|
||||
"""
|
||||
|
||||
def is_scalar(obj):
|
||||
|
@ -4315,3 +4295,53 @@ def minmax(iterable_or_value, *others, key=None, default=_marker):
|
|||
hi, hi_key = y, y_key
|
||||
|
||||
return lo, hi
|
||||
|
||||
|
||||
def constrained_batches(
|
||||
iterable, max_size, max_count=None, get_len=len, strict=True
|
||||
):
|
||||
"""Yield batches of items from *iterable* with a combined size limited by
|
||||
*max_size*.
|
||||
|
||||
>>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
|
||||
>>> list(constrained_batches(iterable, 10))
|
||||
[(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')]
|
||||
|
||||
If a *max_count* is supplied, the number of items per batch is also
|
||||
limited:
|
||||
|
||||
>>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
|
||||
>>> list(constrained_batches(iterable, 10, max_count = 2))
|
||||
[(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)]
|
||||
|
||||
If a *get_len* function is supplied, use that instead of :func:`len` to
|
||||
determine item size.
|
||||
|
||||
If *strict* is ``True``, raise ``ValueError`` if any single item is bigger
|
||||
than *max_size*. Otherwise, allow single items to exceed *max_size*.
|
||||
"""
|
||||
if max_size <= 0:
|
||||
raise ValueError('maximum size must be greater than zero')
|
||||
|
||||
batch = []
|
||||
batch_size = 0
|
||||
batch_count = 0
|
||||
for item in iterable:
|
||||
item_len = get_len(item)
|
||||
if strict and item_len > max_size:
|
||||
raise ValueError('item size exceeds maximum size')
|
||||
|
||||
reached_count = batch_count == max_count
|
||||
reached_size = item_len + batch_size > max_size
|
||||
if batch_count and (reached_size or reached_count):
|
||||
yield tuple(batch)
|
||||
batch.clear()
|
||||
batch_size = 0
|
||||
batch_count = 0
|
||||
|
||||
batch.append(item)
|
||||
batch_size += item_len
|
||||
batch_count += 1
|
||||
|
||||
if batch:
|
||||
yield tuple(batch)
|
||||
|
|
|
@ -72,7 +72,6 @@ class peekable(Generic[_T], Iterator[_T]):
|
|||
@overload
|
||||
def __getitem__(self, index: slice) -> List[_T]: ...
|
||||
|
||||
def collate(*iterables: Iterable[_T], **kwargs: Any) -> Iterable[_T]: ...
|
||||
def consumer(func: _GenFn) -> _GenFn: ...
|
||||
def ilen(iterable: Iterable[object]) -> int: ...
|
||||
def iterate(func: Callable[[_T], _T], start: _T) -> Iterator[_T]: ...
|
||||
|
@ -179,7 +178,7 @@ def padded(
|
|||
iterable: Iterable[_T],
|
||||
*,
|
||||
n: Optional[int] = ...,
|
||||
next_multiple: bool = ...
|
||||
next_multiple: bool = ...,
|
||||
) -> Iterator[Optional[_T]]: ...
|
||||
@overload
|
||||
def padded(
|
||||
|
@ -225,7 +224,7 @@ def zip_equal(
|
|||
__iter1: Iterable[_T],
|
||||
__iter2: Iterable[_T],
|
||||
__iter3: Iterable[_T],
|
||||
*iterables: Iterable[_T]
|
||||
*iterables: Iterable[_T],
|
||||
) -> Iterator[Tuple[_T, ...]]: ...
|
||||
@overload
|
||||
def zip_offset(
|
||||
|
@ -233,7 +232,7 @@ def zip_offset(
|
|||
*,
|
||||
offsets: _SizedIterable[int],
|
||||
longest: bool = ...,
|
||||
fillvalue: None = None
|
||||
fillvalue: None = None,
|
||||
) -> Iterator[Tuple[Optional[_T1]]]: ...
|
||||
@overload
|
||||
def zip_offset(
|
||||
|
@ -242,7 +241,7 @@ def zip_offset(
|
|||
*,
|
||||
offsets: _SizedIterable[int],
|
||||
longest: bool = ...,
|
||||
fillvalue: None = None
|
||||
fillvalue: None = None,
|
||||
) -> Iterator[Tuple[Optional[_T1], Optional[_T2]]]: ...
|
||||
@overload
|
||||
def zip_offset(
|
||||
|
@ -252,7 +251,7 @@ def zip_offset(
|
|||
*iterables: Iterable[_T],
|
||||
offsets: _SizedIterable[int],
|
||||
longest: bool = ...,
|
||||
fillvalue: None = None
|
||||
fillvalue: None = None,
|
||||
) -> Iterator[Tuple[Optional[_T], ...]]: ...
|
||||
@overload
|
||||
def zip_offset(
|
||||
|
@ -420,7 +419,7 @@ def difference(
|
|||
iterable: Iterable[_T],
|
||||
func: Callable[[_T, _T], _U] = ...,
|
||||
*,
|
||||
initial: None = ...
|
||||
initial: None = ...,
|
||||
) -> Iterator[Union[_T, _U]]: ...
|
||||
@overload
|
||||
def difference(
|
||||
|
@ -529,12 +528,12 @@ def distinct_combinations(
|
|||
def filter_except(
|
||||
validator: Callable[[Any], object],
|
||||
iterable: Iterable[_T],
|
||||
*exceptions: Type[BaseException]
|
||||
*exceptions: Type[BaseException],
|
||||
) -> Iterator[_T]: ...
|
||||
def map_except(
|
||||
function: Callable[[Any], _U],
|
||||
iterable: Iterable[_T],
|
||||
*exceptions: Type[BaseException]
|
||||
*exceptions: Type[BaseException],
|
||||
) -> Iterator[_U]: ...
|
||||
def map_if(
|
||||
iterable: Iterable[Any],
|
||||
|
@ -610,7 +609,7 @@ def zip_broadcast(
|
|||
scalar_types: Union[
|
||||
type, Tuple[Union[type, Tuple[Any, ...]], ...], None
|
||||
] = ...,
|
||||
strict: bool = ...
|
||||
strict: bool = ...,
|
||||
) -> Iterable[Tuple[_T, ...]]: ...
|
||||
def unique_in_window(
|
||||
iterable: Iterable[_T], n: int, key: Optional[Callable[[_T], _U]] = ...
|
||||
|
@ -640,7 +639,7 @@ def minmax(
|
|||
iterable_or_value: Iterable[_SupportsLessThanT],
|
||||
*,
|
||||
key: None = None,
|
||||
default: _U
|
||||
default: _U,
|
||||
) -> Union[_U, Tuple[_SupportsLessThanT, _SupportsLessThanT]]: ...
|
||||
@overload
|
||||
def minmax(
|
||||
|
@ -653,12 +652,23 @@ def minmax(
|
|||
def minmax(
|
||||
iterable_or_value: _SupportsLessThanT,
|
||||
__other: _SupportsLessThanT,
|
||||
*others: _SupportsLessThanT
|
||||
*others: _SupportsLessThanT,
|
||||
) -> Tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
|
||||
@overload
|
||||
def minmax(
|
||||
iterable_or_value: _T,
|
||||
__other: _T,
|
||||
*others: _T,
|
||||
key: Callable[[_T], _SupportsLessThan]
|
||||
key: Callable[[_T], _SupportsLessThan],
|
||||
) -> Tuple[_T, _T]: ...
|
||||
def longest_common_prefix(
|
||||
iterables: Iterable[Iterable[_T]],
|
||||
) -> Iterator[_T]: ...
|
||||
def iequals(*iterables: Iterable[object]) -> bool: ...
|
||||
def constrained_batches(
|
||||
iterable: Iterable[object],
|
||||
max_size: int,
|
||||
max_count: Optional[int] = ...,
|
||||
get_len: Callable[[_T], object] = ...,
|
||||
strict: bool = ...,
|
||||
) -> Iterator[Tuple[_T]]: ...
|
||||
|
|
|
@ -7,11 +7,16 @@ Some backward-compatible usability improvements have been made.
|
|||
.. [1] http://docs.python.org/library/itertools.html#recipes
|
||||
|
||||
"""
|
||||
import warnings
|
||||
import math
|
||||
import operator
|
||||
|
||||
from collections import deque
|
||||
from collections.abc import Sized
|
||||
from functools import reduce
|
||||
from itertools import (
|
||||
chain,
|
||||
combinations,
|
||||
compress,
|
||||
count,
|
||||
cycle,
|
||||
groupby,
|
||||
|
@ -21,11 +26,11 @@ from itertools import (
|
|||
tee,
|
||||
zip_longest,
|
||||
)
|
||||
import operator
|
||||
from random import randrange, sample, choice
|
||||
|
||||
__all__ = [
|
||||
'all_equal',
|
||||
'batched',
|
||||
'before_and_after',
|
||||
'consume',
|
||||
'convolve',
|
||||
|
@ -41,6 +46,7 @@ __all__ = [
|
|||
'pad_none',
|
||||
'pairwise',
|
||||
'partition',
|
||||
'polynomial_from_roots',
|
||||
'powerset',
|
||||
'prepend',
|
||||
'quantify',
|
||||
|
@ -50,7 +56,9 @@ __all__ = [
|
|||
'random_product',
|
||||
'repeatfunc',
|
||||
'roundrobin',
|
||||
'sieve',
|
||||
'sliding_window',
|
||||
'subslices',
|
||||
'tabulate',
|
||||
'tail',
|
||||
'take',
|
||||
|
@ -59,6 +67,8 @@ __all__ = [
|
|||
'unique_justseen',
|
||||
]
|
||||
|
||||
_marker = object()
|
||||
|
||||
|
||||
def take(n, iterable):
|
||||
"""Return first *n* items of the iterable as a list.
|
||||
|
@ -102,7 +112,14 @@ def tail(n, iterable):
|
|||
['E', 'F', 'G']
|
||||
|
||||
"""
|
||||
return iter(deque(iterable, maxlen=n))
|
||||
# If the given iterable has a length, then we can use islice to get its
|
||||
# final elements. Note that if the iterable is not actually Iterable,
|
||||
# either islice or deque will throw a TypeError. This is why we don't
|
||||
# check if it is Iterable.
|
||||
if isinstance(iterable, Sized):
|
||||
yield from islice(iterable, max(0, len(iterable) - n), None)
|
||||
else:
|
||||
yield from iter(deque(iterable, maxlen=n))
|
||||
|
||||
|
||||
def consume(iterator, n=None):
|
||||
|
@ -284,20 +301,83 @@ else:
|
|||
pairwise.__doc__ = _pairwise.__doc__
|
||||
|
||||
|
||||
def grouper(iterable, n, fillvalue=None):
|
||||
"""Collect data into fixed-length chunks or blocks.
|
||||
class UnequalIterablesError(ValueError):
|
||||
def __init__(self, details=None):
|
||||
msg = 'Iterables have different lengths'
|
||||
if details is not None:
|
||||
msg += (': index 0 has length {}; index {} has length {}').format(
|
||||
*details
|
||||
)
|
||||
|
||||
>>> list(grouper('ABCDEFG', 3, 'x'))
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
def _zip_equal_generator(iterables):
|
||||
for combo in zip_longest(*iterables, fillvalue=_marker):
|
||||
for val in combo:
|
||||
if val is _marker:
|
||||
raise UnequalIterablesError()
|
||||
yield combo
|
||||
|
||||
|
||||
def _zip_equal(*iterables):
|
||||
# Check whether the iterables are all the same size.
|
||||
try:
|
||||
first_size = len(iterables[0])
|
||||
for i, it in enumerate(iterables[1:], 1):
|
||||
size = len(it)
|
||||
if size != first_size:
|
||||
break
|
||||
else:
|
||||
# If we didn't break out, we can use the built-in zip.
|
||||
return zip(*iterables)
|
||||
|
||||
# If we did break out, there was a mismatch.
|
||||
raise UnequalIterablesError(details=(first_size, i, size))
|
||||
# If any one of the iterables didn't have a length, start reading
|
||||
# them until one runs out.
|
||||
except TypeError:
|
||||
return _zip_equal_generator(iterables)
|
||||
|
||||
|
||||
def grouper(iterable, n, incomplete='fill', fillvalue=None):
|
||||
"""Group elements from *iterable* into fixed-length groups of length *n*.
|
||||
|
||||
>>> list(grouper('ABCDEF', 3))
|
||||
[('A', 'B', 'C'), ('D', 'E', 'F')]
|
||||
|
||||
The keyword arguments *incomplete* and *fillvalue* control what happens for
|
||||
iterables whose length is not a multiple of *n*.
|
||||
|
||||
When *incomplete* is `'fill'`, the last group will contain instances of
|
||||
*fillvalue*.
|
||||
|
||||
>>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
|
||||
[('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
|
||||
|
||||
When *incomplete* is `'ignore'`, the last group will not be emitted.
|
||||
|
||||
>>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
|
||||
[('A', 'B', 'C'), ('D', 'E', 'F')]
|
||||
|
||||
When *incomplete* is `'strict'`, a subclass of `ValueError` will be raised.
|
||||
|
||||
>>> it = grouper('ABCDEFG', 3, incomplete='strict')
|
||||
>>> list(it) # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
UnequalIterablesError
|
||||
|
||||
"""
|
||||
if isinstance(iterable, int):
|
||||
warnings.warn(
|
||||
"grouper expects iterable as first parameter", DeprecationWarning
|
||||
)
|
||||
n, iterable = iterable, n
|
||||
args = [iter(iterable)] * n
|
||||
return zip_longest(fillvalue=fillvalue, *args)
|
||||
if incomplete == 'fill':
|
||||
return zip_longest(*args, fillvalue=fillvalue)
|
||||
if incomplete == 'strict':
|
||||
return _zip_equal(*args)
|
||||
if incomplete == 'ignore':
|
||||
return zip(*args)
|
||||
else:
|
||||
raise ValueError('Expected fill, strict, or ignore')
|
||||
|
||||
|
||||
def roundrobin(*iterables):
|
||||
|
@ -658,11 +738,12 @@ def before_and_after(predicate, it):
|
|||
transition.append(elem)
|
||||
return
|
||||
|
||||
def remainder_iterator():
|
||||
yield from transition
|
||||
yield from it
|
||||
# Note: this is different from itertools recipes to allow nesting
|
||||
# before_and_after remainders into before_and_after again. See tests
|
||||
# for an example.
|
||||
remainder_iterator = chain(transition, it)
|
||||
|
||||
return true_iterator(), remainder_iterator()
|
||||
return true_iterator(), remainder_iterator
|
||||
|
||||
|
||||
def triplewise(iterable):
|
||||
|
@ -696,3 +777,65 @@ def sliding_window(iterable, n):
|
|||
for x in it:
|
||||
window.append(x)
|
||||
yield tuple(window)
|
||||
|
||||
|
||||
def subslices(iterable):
|
||||
"""Return all contiguous non-empty subslices of *iterable*.
|
||||
|
||||
>>> list(subslices('ABC'))
|
||||
[['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]
|
||||
|
||||
This is similar to :func:`substrings`, but emits items in a different
|
||||
order.
|
||||
"""
|
||||
seq = list(iterable)
|
||||
slices = starmap(slice, combinations(range(len(seq) + 1), 2))
|
||||
return map(operator.getitem, repeat(seq), slices)
|
||||
|
||||
|
||||
def polynomial_from_roots(roots):
|
||||
"""Compute a polynomial's coefficients from its roots.
|
||||
|
||||
>>> roots = [5, -4, 3] # (x - 5) * (x + 4) * (x - 3)
|
||||
>>> polynomial_from_roots(roots) # x^3 - 4 * x^2 - 17 * x + 60
|
||||
[1, -4, -17, 60]
|
||||
"""
|
||||
# Use math.prod for Python 3.8+,
|
||||
prod = getattr(math, 'prod', lambda x: reduce(operator.mul, x, 1))
|
||||
roots = list(map(operator.neg, roots))
|
||||
return [
|
||||
sum(map(prod, combinations(roots, k))) for k in range(len(roots) + 1)
|
||||
]
|
||||
|
||||
|
||||
def sieve(n):
|
||||
"""Yield the primes less than n.
|
||||
|
||||
>>> list(sieve(30))
|
||||
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
|
||||
"""
|
||||
isqrt = getattr(math, 'isqrt', lambda x: int(math.sqrt(x)))
|
||||
limit = isqrt(n) + 1
|
||||
data = bytearray([1]) * n
|
||||
data[:2] = 0, 0
|
||||
for p in compress(range(limit), data):
|
||||
data[p + p : n : p] = bytearray(len(range(p + p, n, p)))
|
||||
|
||||
return compress(count(), data)
|
||||
|
||||
|
||||
def batched(iterable, n):
|
||||
"""Batch data into lists of length *n*. The last batch may be shorter.
|
||||
|
||||
>>> list(batched('ABCDEFG', 3))
|
||||
[['A', 'B', 'C'], ['D', 'E', 'F'], ['G']]
|
||||
|
||||
This recipe is from the ``itertools`` docs. This library also provides
|
||||
:func:`chunked`, which has a different implementation.
|
||||
"""
|
||||
it = iter(iterable)
|
||||
while True:
|
||||
batch = list(islice(it, n))
|
||||
if not batch:
|
||||
break
|
||||
yield batch
|
||||
|
|
|
@ -6,6 +6,7 @@ from typing import (
|
|||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
|
@ -39,21 +40,11 @@ def repeatfunc(
|
|||
func: Callable[..., _U], times: Optional[int] = ..., *args: Any
|
||||
) -> Iterator[_U]: ...
|
||||
def pairwise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T]]: ...
|
||||
@overload
|
||||
def grouper(
|
||||
iterable: Iterable[_T], n: int
|
||||
) -> Iterator[Tuple[Optional[_T], ...]]: ...
|
||||
@overload
|
||||
def grouper(
|
||||
iterable: Iterable[_T], n: int, fillvalue: _U
|
||||
) -> Iterator[Tuple[Union[_T, _U], ...]]: ...
|
||||
@overload
|
||||
def grouper( # Deprecated interface
|
||||
iterable: int, n: Iterable[_T]
|
||||
) -> Iterator[Tuple[Optional[_T], ...]]: ...
|
||||
@overload
|
||||
def grouper( # Deprecated interface
|
||||
iterable: int, n: Iterable[_T], fillvalue: _U
|
||||
iterable: Iterable[_T],
|
||||
n: int,
|
||||
incomplete: str = ...,
|
||||
fillvalue: _U = ...,
|
||||
) -> Iterator[Tuple[Union[_T, _U], ...]]: ...
|
||||
def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]: ...
|
||||
def partition(
|
||||
|
@ -110,3 +101,10 @@ def triplewise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T, _T]]: ...
|
|||
def sliding_window(
|
||||
iterable: Iterable[_T], n: int
|
||||
) -> Iterator[Tuple[_T, ...]]: ...
|
||||
def subslices(iterable: Iterable[_T]) -> Iterator[List[_T]]: ...
|
||||
def polynomial_from_roots(roots: Sequence[int]) -> List[int]: ...
|
||||
def sieve(n: int) -> Iterator[int]: ...
|
||||
def batched(
|
||||
iterable: Iterable[_T],
|
||||
n: int,
|
||||
) -> Iterator[List[_T]]: ...
|
||||
|
|
131
lib/pydantic/__init__.py
Normal file
131
lib/pydantic/__init__.py
Normal file
|
@ -0,0 +1,131 @@
|
|||
# flake8: noqa
|
||||
from . import dataclasses
|
||||
from .annotated_types import create_model_from_namedtuple, create_model_from_typeddict
|
||||
from .class_validators import root_validator, validator
|
||||
from .config import BaseConfig, ConfigDict, Extra
|
||||
from .decorator import validate_arguments
|
||||
from .env_settings import BaseSettings
|
||||
from .error_wrappers import ValidationError
|
||||
from .errors import *
|
||||
from .fields import Field, PrivateAttr, Required
|
||||
from .main import *
|
||||
from .networks import *
|
||||
from .parse import Protocol
|
||||
from .tools import *
|
||||
from .types import *
|
||||
from .version import VERSION, compiled
|
||||
|
||||
__version__ = VERSION
|
||||
|
||||
# WARNING __all__ from .errors is not included here, it will be removed as an export here in v2
|
||||
# please use "from pydantic.errors import ..." instead
|
||||
__all__ = [
|
||||
# annotated types utils
|
||||
'create_model_from_namedtuple',
|
||||
'create_model_from_typeddict',
|
||||
# dataclasses
|
||||
'dataclasses',
|
||||
# class_validators
|
||||
'root_validator',
|
||||
'validator',
|
||||
# config
|
||||
'BaseConfig',
|
||||
'ConfigDict',
|
||||
'Extra',
|
||||
# decorator
|
||||
'validate_arguments',
|
||||
# env_settings
|
||||
'BaseSettings',
|
||||
# error_wrappers
|
||||
'ValidationError',
|
||||
# fields
|
||||
'Field',
|
||||
'Required',
|
||||
# main
|
||||
'BaseModel',
|
||||
'create_model',
|
||||
'validate_model',
|
||||
# network
|
||||
'AnyUrl',
|
||||
'AnyHttpUrl',
|
||||
'FileUrl',
|
||||
'HttpUrl',
|
||||
'stricturl',
|
||||
'EmailStr',
|
||||
'NameEmail',
|
||||
'IPvAnyAddress',
|
||||
'IPvAnyInterface',
|
||||
'IPvAnyNetwork',
|
||||
'PostgresDsn',
|
||||
'CockroachDsn',
|
||||
'AmqpDsn',
|
||||
'RedisDsn',
|
||||
'MongoDsn',
|
||||
'KafkaDsn',
|
||||
'validate_email',
|
||||
# parse
|
||||
'Protocol',
|
||||
# tools
|
||||
'parse_file_as',
|
||||
'parse_obj_as',
|
||||
'parse_raw_as',
|
||||
'schema_of',
|
||||
'schema_json_of',
|
||||
# types
|
||||
'NoneStr',
|
||||
'NoneBytes',
|
||||
'StrBytes',
|
||||
'NoneStrBytes',
|
||||
'StrictStr',
|
||||
'ConstrainedBytes',
|
||||
'conbytes',
|
||||
'ConstrainedList',
|
||||
'conlist',
|
||||
'ConstrainedSet',
|
||||
'conset',
|
||||
'ConstrainedFrozenSet',
|
||||
'confrozenset',
|
||||
'ConstrainedStr',
|
||||
'constr',
|
||||
'PyObject',
|
||||
'ConstrainedInt',
|
||||
'conint',
|
||||
'PositiveInt',
|
||||
'NegativeInt',
|
||||
'NonNegativeInt',
|
||||
'NonPositiveInt',
|
||||
'ConstrainedFloat',
|
||||
'confloat',
|
||||
'PositiveFloat',
|
||||
'NegativeFloat',
|
||||
'NonNegativeFloat',
|
||||
'NonPositiveFloat',
|
||||
'FiniteFloat',
|
||||
'ConstrainedDecimal',
|
||||
'condecimal',
|
||||
'ConstrainedDate',
|
||||
'condate',
|
||||
'UUID1',
|
||||
'UUID3',
|
||||
'UUID4',
|
||||
'UUID5',
|
||||
'FilePath',
|
||||
'DirectoryPath',
|
||||
'Json',
|
||||
'JsonWrapper',
|
||||
'SecretField',
|
||||
'SecretStr',
|
||||
'SecretBytes',
|
||||
'StrictBool',
|
||||
'StrictBytes',
|
||||
'StrictInt',
|
||||
'StrictFloat',
|
||||
'PaymentCardNumber',
|
||||
'PrivateAttr',
|
||||
'ByteSize',
|
||||
'PastDate',
|
||||
'FutureDate',
|
||||
# version
|
||||
'compiled',
|
||||
'VERSION',
|
||||
]
|
386
lib/pydantic/_hypothesis_plugin.py
Normal file
386
lib/pydantic/_hypothesis_plugin.py
Normal file
|
@ -0,0 +1,386 @@
|
|||
"""
|
||||
Register Hypothesis strategies for Pydantic custom types.
|
||||
|
||||
This enables fully-automatic generation of test data for most Pydantic classes.
|
||||
|
||||
Note that this module has *no* runtime impact on Pydantic itself; instead it
|
||||
is registered as a setuptools entry point and Hypothesis will import it if
|
||||
Pydantic is installed. See also:
|
||||
|
||||
https://hypothesis.readthedocs.io/en/latest/strategies.html#registering-strategies-via-setuptools-entry-points
|
||||
https://hypothesis.readthedocs.io/en/latest/data.html#hypothesis.strategies.register_type_strategy
|
||||
https://hypothesis.readthedocs.io/en/latest/strategies.html#interaction-with-pytest-cov
|
||||
https://pydantic-docs.helpmanual.io/usage/types/#pydantic-types
|
||||
|
||||
Note that because our motivation is to *improve user experience*, the strategies
|
||||
are always sound (never generate invalid data) but sacrifice completeness for
|
||||
maintainability (ie may be unable to generate some tricky but valid data).
|
||||
|
||||
Finally, this module makes liberal use of `# type: ignore[<code>]` pragmas.
|
||||
This is because Hypothesis annotates `register_type_strategy()` with
|
||||
`(T, SearchStrategy[T])`, but in most cases we register e.g. `ConstrainedInt`
|
||||
to generate instances of the builtin `int` type which match the constraints.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import datetime
|
||||
import ipaddress
|
||||
import json
|
||||
import math
|
||||
from fractions import Fraction
|
||||
from typing import Callable, Dict, Type, Union, cast, overload
|
||||
|
||||
import hypothesis.strategies as st
|
||||
|
||||
import pydantic
|
||||
import pydantic.color
|
||||
import pydantic.types
|
||||
from pydantic.utils import lenient_issubclass
|
||||
|
||||
# FilePath and DirectoryPath are explicitly unsupported, as we'd have to create
|
||||
# them on-disk, and that's unsafe in general without being told *where* to do so.
|
||||
#
|
||||
# URLs are unsupported because it's easy for users to define their own strategy for
|
||||
# "normal" URLs, and hard for us to define a general strategy which includes "weird"
|
||||
# URLs but doesn't also have unpredictable performance problems.
|
||||
#
|
||||
# conlist() and conset() are unsupported for now, because the workarounds for
|
||||
# Cython and Hypothesis to handle parametrized generic types are incompatible.
|
||||
# Once Cython can support 'normal' generics we'll revisit this.
|
||||
|
||||
# Emails
|
||||
try:
|
||||
import email_validator
|
||||
except ImportError: # pragma: no cover
|
||||
pass
|
||||
else:
|
||||
|
||||
def is_valid_email(s: str) -> bool:
|
||||
# Hypothesis' st.emails() occasionally generates emails like 0@A0--0.ac
|
||||
# that are invalid according to email-validator, so we filter those out.
|
||||
try:
|
||||
email_validator.validate_email(s, check_deliverability=False)
|
||||
return True
|
||||
except email_validator.EmailNotValidError: # pragma: no cover
|
||||
return False
|
||||
|
||||
# Note that these strategies deliberately stay away from any tricky Unicode
|
||||
# or other encoding issues; we're just trying to generate *something* valid.
|
||||
st.register_type_strategy(pydantic.EmailStr, st.emails().filter(is_valid_email)) # type: ignore[arg-type]
|
||||
st.register_type_strategy(
|
||||
pydantic.NameEmail,
|
||||
st.builds(
|
||||
'{} <{}>'.format, # type: ignore[arg-type]
|
||||
st.from_regex('[A-Za-z0-9_]+( [A-Za-z0-9_]+){0,5}', fullmatch=True),
|
||||
st.emails().filter(is_valid_email),
|
||||
),
|
||||
)
|
||||
|
||||
# PyObject - dotted names, in this case taken from the math module.
|
||||
st.register_type_strategy(
|
||||
pydantic.PyObject, # type: ignore[arg-type]
|
||||
st.sampled_from(
|
||||
[cast(pydantic.PyObject, f'math.{name}') for name in sorted(vars(math)) if not name.startswith('_')]
|
||||
),
|
||||
)
|
||||
|
||||
# CSS3 Colors; as name, hex, rgb(a) tuples or strings, or hsl strings
|
||||
_color_regexes = (
|
||||
'|'.join(
|
||||
(
|
||||
pydantic.color.r_hex_short,
|
||||
pydantic.color.r_hex_long,
|
||||
pydantic.color.r_rgb,
|
||||
pydantic.color.r_rgba,
|
||||
pydantic.color.r_hsl,
|
||||
pydantic.color.r_hsla,
|
||||
)
|
||||
)
|
||||
# Use more precise regex patterns to avoid value-out-of-range errors
|
||||
.replace(pydantic.color._r_sl, r'(?:(\d\d?(?:\.\d+)?|100(?:\.0+)?)%)')
|
||||
.replace(pydantic.color._r_alpha, r'(?:(0(?:\.\d+)?|1(?:\.0+)?|\.\d+|\d{1,2}%))')
|
||||
.replace(pydantic.color._r_255, r'(?:((?:\d|\d\d|[01]\d\d|2[0-4]\d|25[0-4])(?:\.\d+)?|255(?:\.0+)?))')
|
||||
)
|
||||
st.register_type_strategy(
|
||||
pydantic.color.Color,
|
||||
st.one_of(
|
||||
st.sampled_from(sorted(pydantic.color.COLORS_BY_NAME)),
|
||||
st.tuples(
|
||||
st.integers(0, 255),
|
||||
st.integers(0, 255),
|
||||
st.integers(0, 255),
|
||||
st.none() | st.floats(0, 1) | st.floats(0, 100).map('{}%'.format),
|
||||
),
|
||||
st.from_regex(_color_regexes, fullmatch=True),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Card numbers, valid according to the Luhn algorithm
|
||||
|
||||
|
||||
def add_luhn_digit(card_number: str) -> str:
|
||||
# See https://en.wikipedia.org/wiki/Luhn_algorithm
|
||||
for digit in '0123456789':
|
||||
with contextlib.suppress(Exception):
|
||||
pydantic.PaymentCardNumber.validate_luhn_check_digit(card_number + digit)
|
||||
return card_number + digit
|
||||
raise AssertionError('Unreachable') # pragma: no cover
|
||||
|
||||
|
||||
card_patterns = (
|
||||
# Note that these patterns omit the Luhn check digit; that's added by the function above
|
||||
'4[0-9]{14}', # Visa
|
||||
'5[12345][0-9]{13}', # Mastercard
|
||||
'3[47][0-9]{12}', # American Express
|
||||
'[0-26-9][0-9]{10,17}', # other (incomplete to avoid overlap)
|
||||
)
|
||||
st.register_type_strategy(
|
||||
pydantic.PaymentCardNumber,
|
||||
st.from_regex('|'.join(card_patterns), fullmatch=True).map(add_luhn_digit), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# UUIDs
|
||||
st.register_type_strategy(pydantic.UUID1, st.uuids(version=1))
|
||||
st.register_type_strategy(pydantic.UUID3, st.uuids(version=3))
|
||||
st.register_type_strategy(pydantic.UUID4, st.uuids(version=4))
|
||||
st.register_type_strategy(pydantic.UUID5, st.uuids(version=5))
|
||||
|
||||
# Secrets
|
||||
st.register_type_strategy(pydantic.SecretBytes, st.binary().map(pydantic.SecretBytes))
|
||||
st.register_type_strategy(pydantic.SecretStr, st.text().map(pydantic.SecretStr))
|
||||
|
||||
# IP addresses, networks, and interfaces
|
||||
st.register_type_strategy(pydantic.IPvAnyAddress, st.ip_addresses()) # type: ignore[arg-type]
|
||||
st.register_type_strategy(
|
||||
pydantic.IPvAnyInterface,
|
||||
st.from_type(ipaddress.IPv4Interface) | st.from_type(ipaddress.IPv6Interface), # type: ignore[arg-type]
|
||||
)
|
||||
st.register_type_strategy(
|
||||
pydantic.IPvAnyNetwork,
|
||||
st.from_type(ipaddress.IPv4Network) | st.from_type(ipaddress.IPv6Network), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# We hook into the con***() functions and the ConstrainedNumberMeta metaclass,
|
||||
# so here we only have to register subclasses for other constrained types which
|
||||
# don't go via those mechanisms. Then there are the registration hooks below.
|
||||
st.register_type_strategy(pydantic.StrictBool, st.booleans())
|
||||
st.register_type_strategy(pydantic.StrictStr, st.text())
|
||||
|
||||
|
||||
# Constrained-type resolver functions
|
||||
#
|
||||
# For these ones, we actually want to inspect the type in order to work out a
|
||||
# satisfying strategy. First up, the machinery for tracking resolver functions:
|
||||
|
||||
RESOLVERS: Dict[type, Callable[[type], st.SearchStrategy]] = {} # type: ignore[type-arg]
|
||||
|
||||
|
||||
@overload
|
||||
def _registered(typ: Type[pydantic.types.T]) -> Type[pydantic.types.T]:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def _registered(typ: pydantic.types.ConstrainedNumberMeta) -> pydantic.types.ConstrainedNumberMeta:
|
||||
pass
|
||||
|
||||
|
||||
def _registered(
|
||||
typ: Union[Type[pydantic.types.T], pydantic.types.ConstrainedNumberMeta]
|
||||
) -> Union[Type[pydantic.types.T], pydantic.types.ConstrainedNumberMeta]:
|
||||
# This function replaces the version in `pydantic.types`, in order to
|
||||
# effect the registration of new constrained types so that Hypothesis
|
||||
# can generate valid examples.
|
||||
pydantic.types._DEFINED_TYPES.add(typ)
|
||||
for supertype, resolver in RESOLVERS.items():
|
||||
if issubclass(typ, supertype):
|
||||
st.register_type_strategy(typ, resolver(typ)) # type: ignore
|
||||
return typ
|
||||
raise NotImplementedError(f'Unknown type {typ!r} has no resolver to register') # pragma: no cover
|
||||
|
||||
|
||||
def resolves(
|
||||
typ: Union[type, pydantic.types.ConstrainedNumberMeta]
|
||||
) -> Callable[[Callable[..., st.SearchStrategy]], Callable[..., st.SearchStrategy]]: # type: ignore[type-arg]
|
||||
def inner(f): # type: ignore
|
||||
assert f not in RESOLVERS
|
||||
RESOLVERS[typ] = f
|
||||
return f
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
# Type-to-strategy resolver functions
|
||||
|
||||
|
||||
@resolves(pydantic.JsonWrapper)
|
||||
def resolve_json(cls): # type: ignore[no-untyped-def]
|
||||
try:
|
||||
inner = st.none() if cls.inner_type is None else st.from_type(cls.inner_type)
|
||||
except Exception: # pragma: no cover
|
||||
finite = st.floats(allow_infinity=False, allow_nan=False)
|
||||
inner = st.recursive(
|
||||
base=st.one_of(st.none(), st.booleans(), st.integers(), finite, st.text()),
|
||||
extend=lambda x: st.lists(x) | st.dictionaries(st.text(), x), # type: ignore
|
||||
)
|
||||
inner_type = getattr(cls, 'inner_type', None)
|
||||
return st.builds(
|
||||
cls.inner_type.json if lenient_issubclass(inner_type, pydantic.BaseModel) else json.dumps,
|
||||
inner,
|
||||
ensure_ascii=st.booleans(),
|
||||
indent=st.none() | st.integers(0, 16),
|
||||
sort_keys=st.booleans(),
|
||||
)
|
||||
|
||||
|
||||
@resolves(pydantic.ConstrainedBytes)
|
||||
def resolve_conbytes(cls): # type: ignore[no-untyped-def] # pragma: no cover
|
||||
min_size = cls.min_length or 0
|
||||
max_size = cls.max_length
|
||||
if not cls.strip_whitespace:
|
||||
return st.binary(min_size=min_size, max_size=max_size)
|
||||
# Fun with regex to ensure we neither start nor end with whitespace
|
||||
repeats = '{{{},{}}}'.format(
|
||||
min_size - 2 if min_size > 2 else 0,
|
||||
max_size - 2 if (max_size or 0) > 2 else '',
|
||||
)
|
||||
if min_size >= 2:
|
||||
pattern = rf'\W.{repeats}\W'
|
||||
elif min_size == 1:
|
||||
pattern = rf'\W(.{repeats}\W)?'
|
||||
else:
|
||||
assert min_size == 0
|
||||
pattern = rf'(\W(.{repeats}\W)?)?'
|
||||
return st.from_regex(pattern.encode(), fullmatch=True)
|
||||
|
||||
|
||||
@resolves(pydantic.ConstrainedDecimal)
|
||||
def resolve_condecimal(cls): # type: ignore[no-untyped-def]
|
||||
min_value = cls.ge
|
||||
max_value = cls.le
|
||||
if cls.gt is not None:
|
||||
assert min_value is None, 'Set `gt` or `ge`, but not both'
|
||||
min_value = cls.gt
|
||||
if cls.lt is not None:
|
||||
assert max_value is None, 'Set `lt` or `le`, but not both'
|
||||
max_value = cls.lt
|
||||
s = st.decimals(min_value, max_value, allow_nan=False, places=cls.decimal_places)
|
||||
if cls.lt is not None:
|
||||
s = s.filter(lambda d: d < cls.lt)
|
||||
if cls.gt is not None:
|
||||
s = s.filter(lambda d: cls.gt < d)
|
||||
return s
|
||||
|
||||
|
||||
@resolves(pydantic.ConstrainedFloat)
|
||||
def resolve_confloat(cls): # type: ignore[no-untyped-def]
|
||||
min_value = cls.ge
|
||||
max_value = cls.le
|
||||
exclude_min = False
|
||||
exclude_max = False
|
||||
|
||||
if cls.gt is not None:
|
||||
assert min_value is None, 'Set `gt` or `ge`, but not both'
|
||||
min_value = cls.gt
|
||||
exclude_min = True
|
||||
if cls.lt is not None:
|
||||
assert max_value is None, 'Set `lt` or `le`, but not both'
|
||||
max_value = cls.lt
|
||||
exclude_max = True
|
||||
|
||||
if cls.multiple_of is None:
|
||||
return st.floats(min_value, max_value, exclude_min=exclude_min, exclude_max=exclude_max, allow_nan=False)
|
||||
|
||||
if min_value is not None:
|
||||
min_value = math.ceil(min_value / cls.multiple_of)
|
||||
if exclude_min:
|
||||
min_value = min_value + 1
|
||||
if max_value is not None:
|
||||
assert max_value >= cls.multiple_of, 'Cannot build model with max value smaller than multiple of'
|
||||
max_value = math.floor(max_value / cls.multiple_of)
|
||||
if exclude_max:
|
||||
max_value = max_value - 1
|
||||
|
||||
return st.integers(min_value, max_value).map(lambda x: x * cls.multiple_of)
|
||||
|
||||
|
||||
@resolves(pydantic.ConstrainedInt)
|
||||
def resolve_conint(cls): # type: ignore[no-untyped-def]
|
||||
min_value = cls.ge
|
||||
max_value = cls.le
|
||||
if cls.gt is not None:
|
||||
assert min_value is None, 'Set `gt` or `ge`, but not both'
|
||||
min_value = cls.gt + 1
|
||||
if cls.lt is not None:
|
||||
assert max_value is None, 'Set `lt` or `le`, but not both'
|
||||
max_value = cls.lt - 1
|
||||
|
||||
if cls.multiple_of is None or cls.multiple_of == 1:
|
||||
return st.integers(min_value, max_value)
|
||||
|
||||
# These adjustments and the .map handle integer-valued multiples, while the
|
||||
# .filter handles trickier cases as for confloat.
|
||||
if min_value is not None:
|
||||
min_value = math.ceil(Fraction(min_value) / Fraction(cls.multiple_of))
|
||||
if max_value is not None:
|
||||
max_value = math.floor(Fraction(max_value) / Fraction(cls.multiple_of))
|
||||
return st.integers(min_value, max_value).map(lambda x: x * cls.multiple_of)
|
||||
|
||||
|
||||
@resolves(pydantic.ConstrainedDate)
|
||||
def resolve_condate(cls): # type: ignore[no-untyped-def]
|
||||
if cls.ge is not None:
|
||||
assert cls.gt is None, 'Set `gt` or `ge`, but not both'
|
||||
min_value = cls.ge
|
||||
elif cls.gt is not None:
|
||||
min_value = cls.gt + datetime.timedelta(days=1)
|
||||
else:
|
||||
min_value = datetime.date.min
|
||||
if cls.le is not None:
|
||||
assert cls.lt is None, 'Set `lt` or `le`, but not both'
|
||||
max_value = cls.le
|
||||
elif cls.lt is not None:
|
||||
max_value = cls.lt - datetime.timedelta(days=1)
|
||||
else:
|
||||
max_value = datetime.date.max
|
||||
return st.dates(min_value, max_value)
|
||||
|
||||
|
||||
@resolves(pydantic.ConstrainedStr)
|
||||
def resolve_constr(cls): # type: ignore[no-untyped-def] # pragma: no cover
|
||||
min_size = cls.min_length or 0
|
||||
max_size = cls.max_length
|
||||
|
||||
if cls.regex is None and not cls.strip_whitespace:
|
||||
return st.text(min_size=min_size, max_size=max_size)
|
||||
|
||||
if cls.regex is not None:
|
||||
strategy = st.from_regex(cls.regex)
|
||||
if cls.strip_whitespace:
|
||||
strategy = strategy.filter(lambda s: s == s.strip())
|
||||
elif cls.strip_whitespace:
|
||||
repeats = '{{{},{}}}'.format(
|
||||
min_size - 2 if min_size > 2 else 0,
|
||||
max_size - 2 if (max_size or 0) > 2 else '',
|
||||
)
|
||||
if min_size >= 2:
|
||||
strategy = st.from_regex(rf'\W.{repeats}\W')
|
||||
elif min_size == 1:
|
||||
strategy = st.from_regex(rf'\W(.{repeats}\W)?')
|
||||
else:
|
||||
assert min_size == 0
|
||||
strategy = st.from_regex(rf'(\W(.{repeats}\W)?)?')
|
||||
|
||||
if min_size == 0 and max_size is None:
|
||||
return strategy
|
||||
elif max_size is None:
|
||||
return strategy.filter(lambda s: min_size <= len(s))
|
||||
return strategy.filter(lambda s: min_size <= len(s) <= max_size)
|
||||
|
||||
|
||||
# Finally, register all previously-defined types, and patch in our new function
|
||||
for typ in list(pydantic.types._DEFINED_TYPES):
|
||||
_registered(typ)
|
||||
pydantic.types._registered = _registered
|
||||
st.register_type_strategy(pydantic.Json, resolve_json)
|
72
lib/pydantic/annotated_types.py
Normal file
72
lib/pydantic/annotated_types.py
Normal file
|
@ -0,0 +1,72 @@
|
|||
import sys
|
||||
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, NamedTuple, Type
|
||||
|
||||
from .fields import Required
|
||||
from .main import BaseModel, create_model
|
||||
from .typing import is_typeddict, is_typeddict_special
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
|
||||
def is_legacy_typeddict(typeddict_cls: Type['TypedDict']) -> bool: # type: ignore[valid-type]
|
||||
return is_typeddict(typeddict_cls) and type(typeddict_cls).__module__ == 'typing'
|
||||
|
||||
else:
|
||||
|
||||
def is_legacy_typeddict(_: Any) -> Any:
|
||||
return False
|
||||
|
||||
|
||||
def create_model_from_typeddict(
|
||||
# Mypy bug: `Type[TypedDict]` is resolved as `Any` https://github.com/python/mypy/issues/11030
|
||||
typeddict_cls: Type['TypedDict'], # type: ignore[valid-type]
|
||||
**kwargs: Any,
|
||||
) -> Type['BaseModel']:
|
||||
"""
|
||||
Create a `BaseModel` based on the fields of a `TypedDict`.
|
||||
Since `typing.TypedDict` in Python 3.8 does not store runtime information about optional keys,
|
||||
we raise an error if this happens (see https://bugs.python.org/issue38834).
|
||||
"""
|
||||
field_definitions: Dict[str, Any]
|
||||
|
||||
# Best case scenario: with python 3.9+ or when `TypedDict` is imported from `typing_extensions`
|
||||
if not hasattr(typeddict_cls, '__required_keys__'):
|
||||
raise TypeError(
|
||||
'You should use `typing_extensions.TypedDict` instead of `typing.TypedDict` with Python < 3.9.2. '
|
||||
'Without it, there is no way to differentiate required and optional fields when subclassed.'
|
||||
)
|
||||
|
||||
if is_legacy_typeddict(typeddict_cls) and any(
|
||||
is_typeddict_special(t) for t in typeddict_cls.__annotations__.values()
|
||||
):
|
||||
raise TypeError(
|
||||
'You should use `typing_extensions.TypedDict` instead of `typing.TypedDict` with Python < 3.11. '
|
||||
'Without it, there is no way to reflect Required/NotRequired keys.'
|
||||
)
|
||||
|
||||
required_keys: FrozenSet[str] = typeddict_cls.__required_keys__ # type: ignore[attr-defined]
|
||||
field_definitions = {
|
||||
field_name: (field_type, Required if field_name in required_keys else None)
|
||||
for field_name, field_type in typeddict_cls.__annotations__.items()
|
||||
}
|
||||
|
||||
return create_model(typeddict_cls.__name__, **kwargs, **field_definitions)
|
||||
|
||||
|
||||
def create_model_from_namedtuple(namedtuple_cls: Type['NamedTuple'], **kwargs: Any) -> Type['BaseModel']:
|
||||
"""
|
||||
Create a `BaseModel` based on the fields of a named tuple.
|
||||
A named tuple can be created with `typing.NamedTuple` and declared annotations
|
||||
but also with `collections.namedtuple`, in this case we consider all fields
|
||||
to have type `Any`.
|
||||
"""
|
||||
# With python 3.10+, `__annotations__` always exists but can be empty hence the `getattr... or...` logic
|
||||
namedtuple_annotations: Dict[str, Type[Any]] = getattr(namedtuple_cls, '__annotations__', None) or {
|
||||
k: Any for k in namedtuple_cls._fields
|
||||
}
|
||||
field_definitions: Dict[str, Any] = {
|
||||
field_name: (field_type, Required) for field_name, field_type in namedtuple_annotations.items()
|
||||
}
|
||||
return create_model(namedtuple_cls.__name__, **kwargs, **field_definitions)
|
342
lib/pydantic/class_validators.py
Normal file
342
lib/pydantic/class_validators.py
Normal file
|
@ -0,0 +1,342 @@
|
|||
import warnings
|
||||
from collections import ChainMap
|
||||
from functools import wraps
|
||||
from itertools import chain
|
||||
from types import FunctionType
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, overload
|
||||
|
||||
from .errors import ConfigError
|
||||
from .typing import AnyCallable
|
||||
from .utils import ROOT_KEY, in_ipython
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .typing import AnyClassMethod
|
||||
|
||||
|
||||
class Validator:
|
||||
__slots__ = 'func', 'pre', 'each_item', 'always', 'check_fields', 'skip_on_failure'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: AnyCallable,
|
||||
pre: bool = False,
|
||||
each_item: bool = False,
|
||||
always: bool = False,
|
||||
check_fields: bool = False,
|
||||
skip_on_failure: bool = False,
|
||||
):
|
||||
self.func = func
|
||||
self.pre = pre
|
||||
self.each_item = each_item
|
||||
self.always = always
|
||||
self.check_fields = check_fields
|
||||
self.skip_on_failure = skip_on_failure
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from inspect import Signature
|
||||
|
||||
from .config import BaseConfig
|
||||
from .fields import ModelField
|
||||
from .types import ModelOrDc
|
||||
|
||||
ValidatorCallable = Callable[[Optional[ModelOrDc], Any, Dict[str, Any], ModelField, Type[BaseConfig]], Any]
|
||||
ValidatorsList = List[ValidatorCallable]
|
||||
ValidatorListDict = Dict[str, List[Validator]]
|
||||
|
||||
_FUNCS: Set[str] = set()
|
||||
VALIDATOR_CONFIG_KEY = '__validator_config__'
|
||||
ROOT_VALIDATOR_CONFIG_KEY = '__root_validator_config__'
|
||||
|
||||
|
||||
def validator(
|
||||
*fields: str,
|
||||
pre: bool = False,
|
||||
each_item: bool = False,
|
||||
always: bool = False,
|
||||
check_fields: bool = True,
|
||||
whole: bool = None,
|
||||
allow_reuse: bool = False,
|
||||
) -> Callable[[AnyCallable], 'AnyClassMethod']:
|
||||
"""
|
||||
Decorate methods on the class indicating that they should be used to validate fields
|
||||
:param fields: which field(s) the method should be called on
|
||||
:param pre: whether or not this validator should be called before the standard validators (else after)
|
||||
:param each_item: for complex objects (sets, lists etc.) whether to validate individual elements rather than the
|
||||
whole object
|
||||
:param always: whether this method and other validators should be called even if the value is missing
|
||||
:param check_fields: whether to check that the fields actually exist on the model
|
||||
:param allow_reuse: whether to track and raise an error if another validator refers to the decorated function
|
||||
"""
|
||||
if not fields:
|
||||
raise ConfigError('validator with no fields specified')
|
||||
elif isinstance(fields[0], FunctionType):
|
||||
raise ConfigError(
|
||||
"validators should be used with fields and keyword arguments, not bare. " # noqa: Q000
|
||||
"E.g. usage should be `@validator('<field_name>', ...)`"
|
||||
)
|
||||
elif not all(isinstance(field, str) for field in fields):
|
||||
raise ConfigError(
|
||||
"validator fields should be passed as separate string args. " # noqa: Q000
|
||||
"E.g. usage should be `@validator('<field_name_1>', '<field_name_2>', ...)`"
|
||||
)
|
||||
|
||||
if whole is not None:
|
||||
warnings.warn(
|
||||
'The "whole" keyword argument is deprecated, use "each_item" (inverse meaning, default False) instead',
|
||||
DeprecationWarning,
|
||||
)
|
||||
assert each_item is False, '"each_item" and "whole" conflict, remove "whole"'
|
||||
each_item = not whole
|
||||
|
||||
def dec(f: AnyCallable) -> 'AnyClassMethod':
|
||||
f_cls = _prepare_validator(f, allow_reuse)
|
||||
setattr(
|
||||
f_cls,
|
||||
VALIDATOR_CONFIG_KEY,
|
||||
(
|
||||
fields,
|
||||
Validator(func=f_cls.__func__, pre=pre, each_item=each_item, always=always, check_fields=check_fields),
|
||||
),
|
||||
)
|
||||
return f_cls
|
||||
|
||||
return dec
|
||||
|
||||
|
||||
@overload
|
||||
def root_validator(_func: AnyCallable) -> 'AnyClassMethod':
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def root_validator(
|
||||
*, pre: bool = False, allow_reuse: bool = False, skip_on_failure: bool = False
|
||||
) -> Callable[[AnyCallable], 'AnyClassMethod']:
|
||||
...
|
||||
|
||||
|
||||
def root_validator(
|
||||
_func: Optional[AnyCallable] = None, *, pre: bool = False, allow_reuse: bool = False, skip_on_failure: bool = False
|
||||
) -> Union['AnyClassMethod', Callable[[AnyCallable], 'AnyClassMethod']]:
|
||||
"""
|
||||
Decorate methods on a model indicating that they should be used to validate (and perhaps modify) data either
|
||||
before or after standard model parsing/validation is performed.
|
||||
"""
|
||||
if _func:
|
||||
f_cls = _prepare_validator(_func, allow_reuse)
|
||||
setattr(
|
||||
f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure)
|
||||
)
|
||||
return f_cls
|
||||
|
||||
def dec(f: AnyCallable) -> 'AnyClassMethod':
|
||||
f_cls = _prepare_validator(f, allow_reuse)
|
||||
setattr(
|
||||
f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure)
|
||||
)
|
||||
return f_cls
|
||||
|
||||
return dec
|
||||
|
||||
|
||||
def _prepare_validator(function: AnyCallable, allow_reuse: bool) -> 'AnyClassMethod':
|
||||
"""
|
||||
Avoid validators with duplicated names since without this, validators can be overwritten silently
|
||||
which generally isn't the intended behaviour, don't run in ipython (see #312) or if allow_reuse is False.
|
||||
"""
|
||||
f_cls = function if isinstance(function, classmethod) else classmethod(function)
|
||||
if not in_ipython() and not allow_reuse:
|
||||
ref = f_cls.__func__.__module__ + '.' + f_cls.__func__.__qualname__
|
||||
if ref in _FUNCS:
|
||||
raise ConfigError(f'duplicate validator function "{ref}"; if this is intended, set `allow_reuse=True`')
|
||||
_FUNCS.add(ref)
|
||||
return f_cls
|
||||
|
||||
|
||||
class ValidatorGroup:
|
||||
def __init__(self, validators: 'ValidatorListDict') -> None:
|
||||
self.validators = validators
|
||||
self.used_validators = {'*'}
|
||||
|
||||
def get_validators(self, name: str) -> Optional[Dict[str, Validator]]:
|
||||
self.used_validators.add(name)
|
||||
validators = self.validators.get(name, [])
|
||||
if name != ROOT_KEY:
|
||||
validators += self.validators.get('*', [])
|
||||
if validators:
|
||||
return {v.func.__name__: v for v in validators}
|
||||
else:
|
||||
return None
|
||||
|
||||
def check_for_unused(self) -> None:
|
||||
unused_validators = set(
|
||||
chain.from_iterable(
|
||||
(v.func.__name__ for v in self.validators[f] if v.check_fields)
|
||||
for f in (self.validators.keys() - self.used_validators)
|
||||
)
|
||||
)
|
||||
if unused_validators:
|
||||
fn = ', '.join(unused_validators)
|
||||
raise ConfigError(
|
||||
f"Validators defined with incorrect fields: {fn} " # noqa: Q000
|
||||
f"(use check_fields=False if you're inheriting from the model and intended this)"
|
||||
)
|
||||
|
||||
|
||||
def extract_validators(namespace: Dict[str, Any]) -> Dict[str, List[Validator]]:
|
||||
validators: Dict[str, List[Validator]] = {}
|
||||
for var_name, value in namespace.items():
|
||||
validator_config = getattr(value, VALIDATOR_CONFIG_KEY, None)
|
||||
if validator_config:
|
||||
fields, v = validator_config
|
||||
for field in fields:
|
||||
if field in validators:
|
||||
validators[field].append(v)
|
||||
else:
|
||||
validators[field] = [v]
|
||||
return validators
|
||||
|
||||
|
||||
def extract_root_validators(namespace: Dict[str, Any]) -> Tuple[List[AnyCallable], List[Tuple[bool, AnyCallable]]]:
|
||||
from inspect import signature
|
||||
|
||||
pre_validators: List[AnyCallable] = []
|
||||
post_validators: List[Tuple[bool, AnyCallable]] = []
|
||||
for name, value in namespace.items():
|
||||
validator_config: Optional[Validator] = getattr(value, ROOT_VALIDATOR_CONFIG_KEY, None)
|
||||
if validator_config:
|
||||
sig = signature(validator_config.func)
|
||||
args = list(sig.parameters.keys())
|
||||
if args[0] == 'self':
|
||||
raise ConfigError(
|
||||
f'Invalid signature for root validator {name}: {sig}, "self" not permitted as first argument, '
|
||||
f'should be: (cls, values).'
|
||||
)
|
||||
if len(args) != 2:
|
||||
raise ConfigError(f'Invalid signature for root validator {name}: {sig}, should be: (cls, values).')
|
||||
# check function signature
|
||||
if validator_config.pre:
|
||||
pre_validators.append(validator_config.func)
|
||||
else:
|
||||
post_validators.append((validator_config.skip_on_failure, validator_config.func))
|
||||
return pre_validators, post_validators
|
||||
|
||||
|
||||
def inherit_validators(base_validators: 'ValidatorListDict', validators: 'ValidatorListDict') -> 'ValidatorListDict':
|
||||
for field, field_validators in base_validators.items():
|
||||
if field not in validators:
|
||||
validators[field] = []
|
||||
validators[field] += field_validators
|
||||
return validators
|
||||
|
||||
|
||||
def make_generic_validator(validator: AnyCallable) -> 'ValidatorCallable':
|
||||
"""
|
||||
Make a generic function which calls a validator with the right arguments.
|
||||
|
||||
Unfortunately other approaches (eg. return a partial of a function that builds the arguments) is slow,
|
||||
hence this laborious way of doing things.
|
||||
|
||||
It's done like this so validators don't all need **kwargs in their signature, eg. any combination of
|
||||
the arguments "values", "fields" and/or "config" are permitted.
|
||||
"""
|
||||
from inspect import signature
|
||||
|
||||
sig = signature(validator)
|
||||
args = list(sig.parameters.keys())
|
||||
first_arg = args.pop(0)
|
||||
if first_arg == 'self':
|
||||
raise ConfigError(
|
||||
f'Invalid signature for validator {validator}: {sig}, "self" not permitted as first argument, '
|
||||
f'should be: (cls, value, values, config, field), "values", "config" and "field" are all optional.'
|
||||
)
|
||||
elif first_arg == 'cls':
|
||||
# assume the second argument is value
|
||||
return wraps(validator)(_generic_validator_cls(validator, sig, set(args[1:])))
|
||||
else:
|
||||
# assume the first argument was value which has already been removed
|
||||
return wraps(validator)(_generic_validator_basic(validator, sig, set(args)))
|
||||
|
||||
|
||||
def prep_validators(v_funcs: Iterable[AnyCallable]) -> 'ValidatorsList':
|
||||
return [make_generic_validator(f) for f in v_funcs if f]
|
||||
|
||||
|
||||
all_kwargs = {'values', 'field', 'config'}
|
||||
|
||||
|
||||
def _generic_validator_cls(validator: AnyCallable, sig: 'Signature', args: Set[str]) -> 'ValidatorCallable':
|
||||
# assume the first argument is value
|
||||
has_kwargs = False
|
||||
if 'kwargs' in args:
|
||||
has_kwargs = True
|
||||
args -= {'kwargs'}
|
||||
|
||||
if not args.issubset(all_kwargs):
|
||||
raise ConfigError(
|
||||
f'Invalid signature for validator {validator}: {sig}, should be: '
|
||||
f'(cls, value, values, config, field), "values", "config" and "field" are all optional.'
|
||||
)
|
||||
|
||||
if has_kwargs:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field, config=config)
|
||||
elif args == set():
|
||||
return lambda cls, v, values, field, config: validator(cls, v)
|
||||
elif args == {'values'}:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, values=values)
|
||||
elif args == {'field'}:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, field=field)
|
||||
elif args == {'config'}:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, config=config)
|
||||
elif args == {'values', 'field'}:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field)
|
||||
elif args == {'values', 'config'}:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, values=values, config=config)
|
||||
elif args == {'field', 'config'}:
|
||||
return lambda cls, v, values, field, config: validator(cls, v, field=field, config=config)
|
||||
else:
|
||||
# args == {'values', 'field', 'config'}
|
||||
return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field, config=config)
|
||||
|
||||
|
||||
def _generic_validator_basic(validator: AnyCallable, sig: 'Signature', args: Set[str]) -> 'ValidatorCallable':
|
||||
has_kwargs = False
|
||||
if 'kwargs' in args:
|
||||
has_kwargs = True
|
||||
args -= {'kwargs'}
|
||||
|
||||
if not args.issubset(all_kwargs):
|
||||
raise ConfigError(
|
||||
f'Invalid signature for validator {validator}: {sig}, should be: '
|
||||
f'(value, values, config, field), "values", "config" and "field" are all optional.'
|
||||
)
|
||||
|
||||
if has_kwargs:
|
||||
return lambda cls, v, values, field, config: validator(v, values=values, field=field, config=config)
|
||||
elif args == set():
|
||||
return lambda cls, v, values, field, config: validator(v)
|
||||
elif args == {'values'}:
|
||||
return lambda cls, v, values, field, config: validator(v, values=values)
|
||||
elif args == {'field'}:
|
||||
return lambda cls, v, values, field, config: validator(v, field=field)
|
||||
elif args == {'config'}:
|
||||
return lambda cls, v, values, field, config: validator(v, config=config)
|
||||
elif args == {'values', 'field'}:
|
||||
return lambda cls, v, values, field, config: validator(v, values=values, field=field)
|
||||
elif args == {'values', 'config'}:
|
||||
return lambda cls, v, values, field, config: validator(v, values=values, config=config)
|
||||
elif args == {'field', 'config'}:
|
||||
return lambda cls, v, values, field, config: validator(v, field=field, config=config)
|
||||
else:
|
||||
# args == {'values', 'field', 'config'}
|
||||
return lambda cls, v, values, field, config: validator(v, values=values, field=field, config=config)
|
||||
|
||||
|
||||
def gather_all_validators(type_: 'ModelOrDc') -> Dict[str, 'AnyClassMethod']:
|
||||
all_attributes = ChainMap(*[cls.__dict__ for cls in type_.__mro__]) # type: ignore[arg-type,var-annotated]
|
||||
return {
|
||||
k: v
|
||||
for k, v in all_attributes.items()
|
||||
if hasattr(v, VALIDATOR_CONFIG_KEY) or hasattr(v, ROOT_VALIDATOR_CONFIG_KEY)
|
||||
}
|
494
lib/pydantic/color.py
Normal file
494
lib/pydantic/color.py
Normal file
|
@ -0,0 +1,494 @@
|
|||
"""
|
||||
Color definitions are used as per CSS3 specification:
|
||||
http://www.w3.org/TR/css3-color/#svg-color
|
||||
|
||||
A few colors have multiple names referring to the sames colors, eg. `grey` and `gray` or `aqua` and `cyan`.
|
||||
|
||||
In these cases the LAST color when sorted alphabetically takes preferences,
|
||||
eg. Color((0, 255, 255)).as_named() == 'cyan' because "cyan" comes after "aqua".
|
||||
"""
|
||||
import math
|
||||
import re
|
||||
from colorsys import hls_to_rgb, rgb_to_hls
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast
|
||||
|
||||
from .errors import ColorError
|
||||
from .utils import Representation, almost_equal_floats
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .typing import CallableGenerator, ReprArgs
|
||||
|
||||
ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, float]]
|
||||
ColorType = Union[ColorTuple, str]
|
||||
HslColorTuple = Union[Tuple[float, float, float], Tuple[float, float, float, float]]
|
||||
|
||||
|
||||
class RGBA:
|
||||
"""
|
||||
Internal use only as a representation of a color.
|
||||
"""
|
||||
|
||||
__slots__ = 'r', 'g', 'b', 'alpha', '_tuple'
|
||||
|
||||
def __init__(self, r: float, g: float, b: float, alpha: Optional[float]):
|
||||
self.r = r
|
||||
self.g = g
|
||||
self.b = b
|
||||
self.alpha = alpha
|
||||
|
||||
self._tuple: Tuple[float, float, float, Optional[float]] = (r, g, b, alpha)
|
||||
|
||||
def __getitem__(self, item: Any) -> Any:
|
||||
return self._tuple[item]
|
||||
|
||||
|
||||
# these are not compiled here to avoid import slowdown, they'll be compiled the first time they're used, then cached
|
||||
r_hex_short = r'\s*(?:#|0x)?([0-9a-f])([0-9a-f])([0-9a-f])([0-9a-f])?\s*'
|
||||
r_hex_long = r'\s*(?:#|0x)?([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})?\s*'
|
||||
_r_255 = r'(\d{1,3}(?:\.\d+)?)'
|
||||
_r_comma = r'\s*,\s*'
|
||||
r_rgb = fr'\s*rgb\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}\)\s*'
|
||||
_r_alpha = r'(\d(?:\.\d+)?|\.\d+|\d{1,2}%)'
|
||||
r_rgba = fr'\s*rgba\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_alpha}\s*\)\s*'
|
||||
_r_h = r'(-?\d+(?:\.\d+)?|-?\.\d+)(deg|rad|turn)?'
|
||||
_r_sl = r'(\d{1,3}(?:\.\d+)?)%'
|
||||
r_hsl = fr'\s*hsl\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}\s*\)\s*'
|
||||
r_hsla = fr'\s*hsl\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}{_r_comma}{_r_alpha}\s*\)\s*'
|
||||
|
||||
# colors where the two hex characters are the same, if all colors match this the short version of hex colors can be used
|
||||
repeat_colors = {int(c * 2, 16) for c in '0123456789abcdef'}
|
||||
rads = 2 * math.pi
|
||||
|
||||
|
||||
class Color(Representation):
|
||||
__slots__ = '_original', '_rgba'
|
||||
|
||||
def __init__(self, value: ColorType) -> None:
|
||||
self._rgba: RGBA
|
||||
self._original: ColorType
|
||||
if isinstance(value, (tuple, list)):
|
||||
self._rgba = parse_tuple(value)
|
||||
elif isinstance(value, str):
|
||||
self._rgba = parse_str(value)
|
||||
elif isinstance(value, Color):
|
||||
self._rgba = value._rgba
|
||||
value = value._original
|
||||
else:
|
||||
raise ColorError(reason='value must be a tuple, list or string')
|
||||
|
||||
# if we've got here value must be a valid color
|
||||
self._original = value
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update(type='string', format='color')
|
||||
|
||||
def original(self) -> ColorType:
|
||||
"""
|
||||
Original value passed to Color
|
||||
"""
|
||||
return self._original
|
||||
|
||||
def as_named(self, *, fallback: bool = False) -> str:
|
||||
if self._rgba.alpha is None:
|
||||
rgb = cast(Tuple[int, int, int], self.as_rgb_tuple())
|
||||
try:
|
||||
return COLORS_BY_VALUE[rgb]
|
||||
except KeyError as e:
|
||||
if fallback:
|
||||
return self.as_hex()
|
||||
else:
|
||||
raise ValueError('no named color found, use fallback=True, as_hex() or as_rgb()') from e
|
||||
else:
|
||||
return self.as_hex()
|
||||
|
||||
def as_hex(self) -> str:
|
||||
"""
|
||||
Hex string representing the color can be 3, 4, 6 or 8 characters depending on whether the string
|
||||
a "short" representation of the color is possible and whether there's an alpha channel.
|
||||
"""
|
||||
values = [float_to_255(c) for c in self._rgba[:3]]
|
||||
if self._rgba.alpha is not None:
|
||||
values.append(float_to_255(self._rgba.alpha))
|
||||
|
||||
as_hex = ''.join(f'{v:02x}' for v in values)
|
||||
if all(c in repeat_colors for c in values):
|
||||
as_hex = ''.join(as_hex[c] for c in range(0, len(as_hex), 2))
|
||||
return '#' + as_hex
|
||||
|
||||
def as_rgb(self) -> str:
|
||||
"""
|
||||
Color as an rgb(<r>, <g>, <b>) or rgba(<r>, <g>, <b>, <a>) string.
|
||||
"""
|
||||
if self._rgba.alpha is None:
|
||||
return f'rgb({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)})'
|
||||
else:
|
||||
return (
|
||||
f'rgba({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)}, '
|
||||
f'{round(self._alpha_float(), 2)})'
|
||||
)
|
||||
|
||||
def as_rgb_tuple(self, *, alpha: Optional[bool] = None) -> ColorTuple:
|
||||
"""
|
||||
Color as an RGB or RGBA tuple; red, green and blue are in the range 0 to 255, alpha if included is
|
||||
in the range 0 to 1.
|
||||
|
||||
:param alpha: whether to include the alpha channel, options are
|
||||
None - (default) include alpha only if it's set (e.g. not None)
|
||||
True - always include alpha,
|
||||
False - always omit alpha,
|
||||
"""
|
||||
r, g, b = (float_to_255(c) for c in self._rgba[:3])
|
||||
if alpha is None:
|
||||
if self._rgba.alpha is None:
|
||||
return r, g, b
|
||||
else:
|
||||
return r, g, b, self._alpha_float()
|
||||
elif alpha:
|
||||
return r, g, b, self._alpha_float()
|
||||
else:
|
||||
# alpha is False
|
||||
return r, g, b
|
||||
|
||||
def as_hsl(self) -> str:
|
||||
"""
|
||||
Color as an hsl(<h>, <s>, <l>) or hsl(<h>, <s>, <l>, <a>) string.
|
||||
"""
|
||||
if self._rgba.alpha is None:
|
||||
h, s, li = self.as_hsl_tuple(alpha=False) # type: ignore
|
||||
return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%})'
|
||||
else:
|
||||
h, s, li, a = self.as_hsl_tuple(alpha=True) # type: ignore
|
||||
return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%}, {round(a, 2)})'
|
||||
|
||||
def as_hsl_tuple(self, *, alpha: Optional[bool] = None) -> HslColorTuple:
|
||||
"""
|
||||
Color as an HSL or HSLA tuple, e.g. hue, saturation, lightness and optionally alpha; all elements are in
|
||||
the range 0 to 1.
|
||||
|
||||
NOTE: this is HSL as used in HTML and most other places, not HLS as used in python's colorsys.
|
||||
|
||||
:param alpha: whether to include the alpha channel, options are
|
||||
None - (default) include alpha only if it's set (e.g. not None)
|
||||
True - always include alpha,
|
||||
False - always omit alpha,
|
||||
"""
|
||||
h, l, s = rgb_to_hls(self._rgba.r, self._rgba.g, self._rgba.b)
|
||||
if alpha is None:
|
||||
if self._rgba.alpha is None:
|
||||
return h, s, l
|
||||
else:
|
||||
return h, s, l, self._alpha_float()
|
||||
if alpha:
|
||||
return h, s, l, self._alpha_float()
|
||||
else:
|
||||
# alpha is False
|
||||
return h, s, l
|
||||
|
||||
def _alpha_float(self) -> float:
|
||||
return 1 if self._rgba.alpha is None else self._rgba.alpha
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> 'CallableGenerator':
|
||||
yield cls
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.as_named(fallback=True)
|
||||
|
||||
def __repr_args__(self) -> 'ReprArgs':
|
||||
return [(None, self.as_named(fallback=True))] + [('rgb', self.as_rgb_tuple())] # type: ignore
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, Color) and self.as_rgb_tuple() == other.as_rgb_tuple()
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.as_rgb_tuple())
|
||||
|
||||
|
||||
def parse_tuple(value: Tuple[Any, ...]) -> RGBA:
|
||||
"""
|
||||
Parse a tuple or list as a color.
|
||||
"""
|
||||
if len(value) == 3:
|
||||
r, g, b = (parse_color_value(v) for v in value)
|
||||
return RGBA(r, g, b, None)
|
||||
elif len(value) == 4:
|
||||
r, g, b = (parse_color_value(v) for v in value[:3])
|
||||
return RGBA(r, g, b, parse_float_alpha(value[3]))
|
||||
else:
|
||||
raise ColorError(reason='tuples must have length 3 or 4')
|
||||
|
||||
|
||||
def parse_str(value: str) -> RGBA:
|
||||
"""
|
||||
Parse a string to an RGBA tuple, trying the following formats (in this order):
|
||||
* named color, see COLORS_BY_NAME below
|
||||
* hex short eg. `<prefix>fff` (prefix can be `#`, `0x` or nothing)
|
||||
* hex long eg. `<prefix>ffffff` (prefix can be `#`, `0x` or nothing)
|
||||
* `rgb(<r>, <g>, <b>) `
|
||||
* `rgba(<r>, <g>, <b>, <a>)`
|
||||
"""
|
||||
value_lower = value.lower()
|
||||
try:
|
||||
r, g, b = COLORS_BY_NAME[value_lower]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
return ints_to_rgba(r, g, b, None)
|
||||
|
||||
m = re.fullmatch(r_hex_short, value_lower)
|
||||
if m:
|
||||
*rgb, a = m.groups()
|
||||
r, g, b = (int(v * 2, 16) for v in rgb)
|
||||
if a:
|
||||
alpha: Optional[float] = int(a * 2, 16) / 255
|
||||
else:
|
||||
alpha = None
|
||||
return ints_to_rgba(r, g, b, alpha)
|
||||
|
||||
m = re.fullmatch(r_hex_long, value_lower)
|
||||
if m:
|
||||
*rgb, a = m.groups()
|
||||
r, g, b = (int(v, 16) for v in rgb)
|
||||
if a:
|
||||
alpha = int(a, 16) / 255
|
||||
else:
|
||||
alpha = None
|
||||
return ints_to_rgba(r, g, b, alpha)
|
||||
|
||||
m = re.fullmatch(r_rgb, value_lower)
|
||||
if m:
|
||||
return ints_to_rgba(*m.groups(), None) # type: ignore
|
||||
|
||||
m = re.fullmatch(r_rgba, value_lower)
|
||||
if m:
|
||||
return ints_to_rgba(*m.groups()) # type: ignore
|
||||
|
||||
m = re.fullmatch(r_hsl, value_lower)
|
||||
if m:
|
||||
h, h_units, s, l_ = m.groups()
|
||||
return parse_hsl(h, h_units, s, l_)
|
||||
|
||||
m = re.fullmatch(r_hsla, value_lower)
|
||||
if m:
|
||||
h, h_units, s, l_, a = m.groups()
|
||||
return parse_hsl(h, h_units, s, l_, parse_float_alpha(a))
|
||||
|
||||
raise ColorError(reason='string not recognised as a valid color')
|
||||
|
||||
|
||||
def ints_to_rgba(r: Union[int, str], g: Union[int, str], b: Union[int, str], alpha: Optional[float]) -> RGBA:
|
||||
return RGBA(parse_color_value(r), parse_color_value(g), parse_color_value(b), parse_float_alpha(alpha))
|
||||
|
||||
|
||||
def parse_color_value(value: Union[int, str], max_val: int = 255) -> float:
|
||||
"""
|
||||
Parse a value checking it's a valid int in the range 0 to max_val and divide by max_val to give a number
|
||||
in the range 0 to 1
|
||||
"""
|
||||
try:
|
||||
color = float(value)
|
||||
except ValueError:
|
||||
raise ColorError(reason='color values must be a valid number')
|
||||
if 0 <= color <= max_val:
|
||||
return color / max_val
|
||||
else:
|
||||
raise ColorError(reason=f'color values must be in the range 0 to {max_val}')
|
||||
|
||||
|
||||
def parse_float_alpha(value: Union[None, str, float, int]) -> Optional[float]:
|
||||
"""
|
||||
Parse a value checking it's a valid float in the range 0 to 1
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
if isinstance(value, str) and value.endswith('%'):
|
||||
alpha = float(value[:-1]) / 100
|
||||
else:
|
||||
alpha = float(value)
|
||||
except ValueError:
|
||||
raise ColorError(reason='alpha values must be a valid float')
|
||||
|
||||
if almost_equal_floats(alpha, 1):
|
||||
return None
|
||||
elif 0 <= alpha <= 1:
|
||||
return alpha
|
||||
else:
|
||||
raise ColorError(reason='alpha values must be in the range 0 to 1')
|
||||
|
||||
|
||||
def parse_hsl(h: str, h_units: str, sat: str, light: str, alpha: Optional[float] = None) -> RGBA:
|
||||
"""
|
||||
Parse raw hue, saturation, lightness and alpha values and convert to RGBA.
|
||||
"""
|
||||
s_value, l_value = parse_color_value(sat, 100), parse_color_value(light, 100)
|
||||
|
||||
h_value = float(h)
|
||||
if h_units in {None, 'deg'}:
|
||||
h_value = h_value % 360 / 360
|
||||
elif h_units == 'rad':
|
||||
h_value = h_value % rads / rads
|
||||
else:
|
||||
# turns
|
||||
h_value = h_value % 1
|
||||
|
||||
r, g, b = hls_to_rgb(h_value, l_value, s_value)
|
||||
return RGBA(r, g, b, alpha)
|
||||
|
||||
|
||||
def float_to_255(c: float) -> int:
|
||||
return int(round(c * 255))
|
||||
|
||||
|
||||
COLORS_BY_NAME = {
|
||||
'aliceblue': (240, 248, 255),
|
||||
'antiquewhite': (250, 235, 215),
|
||||
'aqua': (0, 255, 255),
|
||||
'aquamarine': (127, 255, 212),
|
||||
'azure': (240, 255, 255),
|
||||
'beige': (245, 245, 220),
|
||||
'bisque': (255, 228, 196),
|
||||
'black': (0, 0, 0),
|
||||
'blanchedalmond': (255, 235, 205),
|
||||
'blue': (0, 0, 255),
|
||||
'blueviolet': (138, 43, 226),
|
||||
'brown': (165, 42, 42),
|
||||
'burlywood': (222, 184, 135),
|
||||
'cadetblue': (95, 158, 160),
|
||||
'chartreuse': (127, 255, 0),
|
||||
'chocolate': (210, 105, 30),
|
||||
'coral': (255, 127, 80),
|
||||
'cornflowerblue': (100, 149, 237),
|
||||
'cornsilk': (255, 248, 220),
|
||||
'crimson': (220, 20, 60),
|
||||
'cyan': (0, 255, 255),
|
||||
'darkblue': (0, 0, 139),
|
||||
'darkcyan': (0, 139, 139),
|
||||
'darkgoldenrod': (184, 134, 11),
|
||||
'darkgray': (169, 169, 169),
|
||||
'darkgreen': (0, 100, 0),
|
||||
'darkgrey': (169, 169, 169),
|
||||
'darkkhaki': (189, 183, 107),
|
||||
'darkmagenta': (139, 0, 139),
|
||||
'darkolivegreen': (85, 107, 47),
|
||||
'darkorange': (255, 140, 0),
|
||||
'darkorchid': (153, 50, 204),
|
||||
'darkred': (139, 0, 0),
|
||||
'darksalmon': (233, 150, 122),
|
||||
'darkseagreen': (143, 188, 143),
|
||||
'darkslateblue': (72, 61, 139),
|
||||
'darkslategray': (47, 79, 79),
|
||||
'darkslategrey': (47, 79, 79),
|
||||
'darkturquoise': (0, 206, 209),
|
||||
'darkviolet': (148, 0, 211),
|
||||
'deeppink': (255, 20, 147),
|
||||
'deepskyblue': (0, 191, 255),
|
||||
'dimgray': (105, 105, 105),
|
||||
'dimgrey': (105, 105, 105),
|
||||
'dodgerblue': (30, 144, 255),
|
||||
'firebrick': (178, 34, 34),
|
||||
'floralwhite': (255, 250, 240),
|
||||
'forestgreen': (34, 139, 34),
|
||||
'fuchsia': (255, 0, 255),
|
||||
'gainsboro': (220, 220, 220),
|
||||
'ghostwhite': (248, 248, 255),
|
||||
'gold': (255, 215, 0),
|
||||
'goldenrod': (218, 165, 32),
|
||||
'gray': (128, 128, 128),
|
||||
'green': (0, 128, 0),
|
||||
'greenyellow': (173, 255, 47),
|
||||
'grey': (128, 128, 128),
|
||||
'honeydew': (240, 255, 240),
|
||||
'hotpink': (255, 105, 180),
|
||||
'indianred': (205, 92, 92),
|
||||
'indigo': (75, 0, 130),
|
||||
'ivory': (255, 255, 240),
|
||||
'khaki': (240, 230, 140),
|
||||
'lavender': (230, 230, 250),
|
||||
'lavenderblush': (255, 240, 245),
|
||||
'lawngreen': (124, 252, 0),
|
||||
'lemonchiffon': (255, 250, 205),
|
||||
'lightblue': (173, 216, 230),
|
||||
'lightcoral': (240, 128, 128),
|
||||
'lightcyan': (224, 255, 255),
|
||||
'lightgoldenrodyellow': (250, 250, 210),
|
||||
'lightgray': (211, 211, 211),
|
||||
'lightgreen': (144, 238, 144),
|
||||
'lightgrey': (211, 211, 211),
|
||||
'lightpink': (255, 182, 193),
|
||||
'lightsalmon': (255, 160, 122),
|
||||
'lightseagreen': (32, 178, 170),
|
||||
'lightskyblue': (135, 206, 250),
|
||||
'lightslategray': (119, 136, 153),
|
||||
'lightslategrey': (119, 136, 153),
|
||||
'lightsteelblue': (176, 196, 222),
|
||||
'lightyellow': (255, 255, 224),
|
||||
'lime': (0, 255, 0),
|
||||
'limegreen': (50, 205, 50),
|
||||
'linen': (250, 240, 230),
|
||||
'magenta': (255, 0, 255),
|
||||
'maroon': (128, 0, 0),
|
||||
'mediumaquamarine': (102, 205, 170),
|
||||
'mediumblue': (0, 0, 205),
|
||||
'mediumorchid': (186, 85, 211),
|
||||
'mediumpurple': (147, 112, 219),
|
||||
'mediumseagreen': (60, 179, 113),
|
||||
'mediumslateblue': (123, 104, 238),
|
||||
'mediumspringgreen': (0, 250, 154),
|
||||
'mediumturquoise': (72, 209, 204),
|
||||
'mediumvioletred': (199, 21, 133),
|
||||
'midnightblue': (25, 25, 112),
|
||||
'mintcream': (245, 255, 250),
|
||||
'mistyrose': (255, 228, 225),
|
||||
'moccasin': (255, 228, 181),
|
||||
'navajowhite': (255, 222, 173),
|
||||
'navy': (0, 0, 128),
|
||||
'oldlace': (253, 245, 230),
|
||||
'olive': (128, 128, 0),
|
||||
'olivedrab': (107, 142, 35),
|
||||
'orange': (255, 165, 0),
|
||||
'orangered': (255, 69, 0),
|
||||
'orchid': (218, 112, 214),
|
||||
'palegoldenrod': (238, 232, 170),
|
||||
'palegreen': (152, 251, 152),
|
||||
'paleturquoise': (175, 238, 238),
|
||||
'palevioletred': (219, 112, 147),
|
||||
'papayawhip': (255, 239, 213),
|
||||
'peachpuff': (255, 218, 185),
|
||||
'peru': (205, 133, 63),
|
||||
'pink': (255, 192, 203),
|
||||
'plum': (221, 160, 221),
|
||||
'powderblue': (176, 224, 230),
|
||||
'purple': (128, 0, 128),
|
||||
'red': (255, 0, 0),
|
||||
'rosybrown': (188, 143, 143),
|
||||
'royalblue': (65, 105, 225),
|
||||
'saddlebrown': (139, 69, 19),
|
||||
'salmon': (250, 128, 114),
|
||||
'sandybrown': (244, 164, 96),
|
||||
'seagreen': (46, 139, 87),
|
||||
'seashell': (255, 245, 238),
|
||||
'sienna': (160, 82, 45),
|
||||
'silver': (192, 192, 192),
|
||||
'skyblue': (135, 206, 235),
|
||||
'slateblue': (106, 90, 205),
|
||||
'slategray': (112, 128, 144),
|
||||
'slategrey': (112, 128, 144),
|
||||
'snow': (255, 250, 250),
|
||||
'springgreen': (0, 255, 127),
|
||||
'steelblue': (70, 130, 180),
|
||||
'tan': (210, 180, 140),
|
||||
'teal': (0, 128, 128),
|
||||
'thistle': (216, 191, 216),
|
||||
'tomato': (255, 99, 71),
|
||||
'turquoise': (64, 224, 208),
|
||||
'violet': (238, 130, 238),
|
||||
'wheat': (245, 222, 179),
|
||||
'white': (255, 255, 255),
|
||||
'whitesmoke': (245, 245, 245),
|
||||
'yellow': (255, 255, 0),
|
||||
'yellowgreen': (154, 205, 50),
|
||||
}
|
||||
|
||||
COLORS_BY_VALUE = {v: k for k, v in COLORS_BY_NAME.items()}
|
192
lib/pydantic/config.py
Normal file
192
lib/pydantic/config.py
Normal file
|
@ -0,0 +1,192 @@
|
|||
import json
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, ForwardRef, Optional, Tuple, Type, Union
|
||||
|
||||
from typing_extensions import Literal, Protocol
|
||||
|
||||
from .typing import AnyArgTCallable, AnyCallable
|
||||
from .utils import GetterDict
|
||||
from .version import compiled
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import overload
|
||||
|
||||
from .fields import ModelField
|
||||
from .main import BaseModel
|
||||
|
||||
ConfigType = Type['BaseConfig']
|
||||
|
||||
class SchemaExtraCallable(Protocol):
|
||||
@overload
|
||||
def __call__(self, schema: Dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
@overload
|
||||
def __call__(self, schema: Dict[str, Any], model_class: Type[BaseModel]) -> None:
|
||||
pass
|
||||
|
||||
else:
|
||||
SchemaExtraCallable = Callable[..., None]
|
||||
|
||||
__all__ = 'BaseConfig', 'ConfigDict', 'get_config', 'Extra', 'inherit_config', 'prepare_config'
|
||||
|
||||
|
||||
class Extra(str, Enum):
|
||||
allow = 'allow'
|
||||
ignore = 'ignore'
|
||||
forbid = 'forbid'
|
||||
|
||||
|
||||
# https://github.com/cython/cython/issues/4003
|
||||
# Will be fixed with Cython 3 but still in alpha right now
|
||||
if not compiled:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
class ConfigDict(TypedDict, total=False):
|
||||
title: Optional[str]
|
||||
anystr_lower: bool
|
||||
anystr_strip_whitespace: bool
|
||||
min_anystr_length: int
|
||||
max_anystr_length: Optional[int]
|
||||
validate_all: bool
|
||||
extra: Extra
|
||||
allow_mutation: bool
|
||||
frozen: bool
|
||||
allow_population_by_field_name: bool
|
||||
use_enum_values: bool
|
||||
fields: Dict[str, Union[str, Dict[str, str]]]
|
||||
validate_assignment: bool
|
||||
error_msg_templates: Dict[str, str]
|
||||
arbitrary_types_allowed: bool
|
||||
orm_mode: bool
|
||||
getter_dict: Type[GetterDict]
|
||||
alias_generator: Optional[Callable[[str], str]]
|
||||
keep_untouched: Tuple[type, ...]
|
||||
schema_extra: Union[Dict[str, object], 'SchemaExtraCallable']
|
||||
json_loads: Callable[[str], object]
|
||||
json_dumps: AnyArgTCallable[str]
|
||||
json_encoders: Dict[Type[object], AnyCallable]
|
||||
underscore_attrs_are_private: bool
|
||||
allow_inf_nan: bool
|
||||
|
||||
# whether or not inherited models as fields should be reconstructed as base model
|
||||
copy_on_model_validation: bool
|
||||
# whether dataclass `__post_init__` should be run after validation
|
||||
post_init_call: Literal['before_validation', 'after_validation']
|
||||
|
||||
else:
|
||||
ConfigDict = dict # type: ignore
|
||||
|
||||
|
||||
class BaseConfig:
|
||||
title: Optional[str] = None
|
||||
anystr_lower: bool = False
|
||||
anystr_upper: bool = False
|
||||
anystr_strip_whitespace: bool = False
|
||||
min_anystr_length: int = 0
|
||||
max_anystr_length: Optional[int] = None
|
||||
validate_all: bool = False
|
||||
extra: Extra = Extra.ignore
|
||||
allow_mutation: bool = True
|
||||
frozen: bool = False
|
||||
allow_population_by_field_name: bool = False
|
||||
use_enum_values: bool = False
|
||||
fields: Dict[str, Union[str, Dict[str, str]]] = {}
|
||||
validate_assignment: bool = False
|
||||
error_msg_templates: Dict[str, str] = {}
|
||||
arbitrary_types_allowed: bool = False
|
||||
orm_mode: bool = False
|
||||
getter_dict: Type[GetterDict] = GetterDict
|
||||
alias_generator: Optional[Callable[[str], str]] = None
|
||||
keep_untouched: Tuple[type, ...] = ()
|
||||
schema_extra: Union[Dict[str, Any], 'SchemaExtraCallable'] = {}
|
||||
json_loads: Callable[[str], Any] = json.loads
|
||||
json_dumps: Callable[..., str] = json.dumps
|
||||
json_encoders: Dict[Union[Type[Any], str, ForwardRef], AnyCallable] = {}
|
||||
underscore_attrs_are_private: bool = False
|
||||
allow_inf_nan: bool = True
|
||||
|
||||
# whether inherited models as fields should be reconstructed as base model,
|
||||
# and whether such a copy should be shallow or deep
|
||||
copy_on_model_validation: Literal['none', 'deep', 'shallow'] = 'shallow'
|
||||
|
||||
# whether `Union` should check all allowed types before even trying to coerce
|
||||
smart_union: bool = False
|
||||
# whether dataclass `__post_init__` should be run before or after validation
|
||||
post_init_call: Literal['before_validation', 'after_validation'] = 'before_validation'
|
||||
|
||||
@classmethod
|
||||
def get_field_info(cls, name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get properties of FieldInfo from the `fields` property of the config class.
|
||||
"""
|
||||
|
||||
fields_value = cls.fields.get(name)
|
||||
|
||||
if isinstance(fields_value, str):
|
||||
field_info: Dict[str, Any] = {'alias': fields_value}
|
||||
elif isinstance(fields_value, dict):
|
||||
field_info = fields_value
|
||||
else:
|
||||
field_info = {}
|
||||
|
||||
if 'alias' in field_info:
|
||||
field_info.setdefault('alias_priority', 2)
|
||||
|
||||
if field_info.get('alias_priority', 0) <= 1 and cls.alias_generator:
|
||||
alias = cls.alias_generator(name)
|
||||
if not isinstance(alias, str):
|
||||
raise TypeError(f'Config.alias_generator must return str, not {alias.__class__}')
|
||||
field_info.update(alias=alias, alias_priority=1)
|
||||
return field_info
|
||||
|
||||
@classmethod
|
||||
def prepare_field(cls, field: 'ModelField') -> None:
|
||||
"""
|
||||
Optional hook to check or modify fields during model creation.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def get_config(config: Union[ConfigDict, Type[object], None]) -> Type[BaseConfig]:
|
||||
if config is None:
|
||||
return BaseConfig
|
||||
|
||||
else:
|
||||
config_dict = (
|
||||
config
|
||||
if isinstance(config, dict)
|
||||
else {k: getattr(config, k) for k in dir(config) if not k.startswith('__')}
|
||||
)
|
||||
|
||||
class Config(BaseConfig):
|
||||
...
|
||||
|
||||
for k, v in config_dict.items():
|
||||
setattr(Config, k, v)
|
||||
return Config
|
||||
|
||||
|
||||
def inherit_config(self_config: 'ConfigType', parent_config: 'ConfigType', **namespace: Any) -> 'ConfigType':
|
||||
if not self_config:
|
||||
base_classes: Tuple['ConfigType', ...] = (parent_config,)
|
||||
elif self_config == parent_config:
|
||||
base_classes = (self_config,)
|
||||
else:
|
||||
base_classes = self_config, parent_config
|
||||
|
||||
namespace['json_encoders'] = {
|
||||
**getattr(parent_config, 'json_encoders', {}),
|
||||
**getattr(self_config, 'json_encoders', {}),
|
||||
**namespace.get('json_encoders', {}),
|
||||
}
|
||||
|
||||
return type('Config', base_classes, namespace)
|
||||
|
||||
|
||||
def prepare_config(config: Type[BaseConfig], cls_name: str) -> None:
|
||||
if not isinstance(config.extra, Extra):
|
||||
try:
|
||||
config.extra = Extra(config.extra)
|
||||
except ValueError:
|
||||
raise ValueError(f'"{cls_name}": {config.extra} is not a valid value for "extra"')
|
479
lib/pydantic/dataclasses.py
Normal file
479
lib/pydantic/dataclasses.py
Normal file
|
@ -0,0 +1,479 @@
|
|||
"""
|
||||
The main purpose is to enhance stdlib dataclasses by adding validation
|
||||
A pydantic dataclass can be generated from scratch or from a stdlib one.
|
||||
|
||||
Behind the scene, a pydantic dataclass is just like a regular one on which we attach
|
||||
a `BaseModel` and magic methods to trigger the validation of the data.
|
||||
`__init__` and `__post_init__` are hence overridden and have extra logic to be
|
||||
able to validate input data.
|
||||
|
||||
When a pydantic dataclass is generated from scratch, it's just a plain dataclass
|
||||
with validation triggered at initialization
|
||||
|
||||
The tricky part if for stdlib dataclasses that are converted after into pydantic ones e.g.
|
||||
|
||||
```py
|
||||
@dataclasses.dataclass
|
||||
class M:
|
||||
x: int
|
||||
|
||||
ValidatedM = pydantic.dataclasses.dataclass(M)
|
||||
```
|
||||
|
||||
We indeed still want to support equality, hashing, repr, ... as if it was the stdlib one!
|
||||
|
||||
```py
|
||||
assert isinstance(ValidatedM(x=1), M)
|
||||
assert ValidatedM(x=1) == M(x=1)
|
||||
```
|
||||
|
||||
This means we **don't want to create a new dataclass that inherits from it**
|
||||
The trick is to create a wrapper around `M` that will act as a proxy to trigger
|
||||
validation without altering default `M` behaviour.
|
||||
"""
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Generator,
|
||||
Optional,
|
||||
Set,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
from typing_extensions import dataclass_transform
|
||||
|
||||
from .class_validators import gather_all_validators
|
||||
from .config import BaseConfig, ConfigDict, Extra, get_config
|
||||
from .error_wrappers import ValidationError
|
||||
from .errors import DataclassTypeError
|
||||
from .fields import Field, FieldInfo, Required, Undefined
|
||||
from .main import create_model, validate_model
|
||||
from .utils import ClassAttribute
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .main import BaseModel
|
||||
from .typing import CallableGenerator, NoArgAnyCallable
|
||||
|
||||
DataclassT = TypeVar('DataclassT', bound='Dataclass')
|
||||
|
||||
DataclassClassOrWrapper = Union[Type['Dataclass'], 'DataclassProxy']
|
||||
|
||||
class Dataclass:
|
||||
# stdlib attributes
|
||||
__dataclass_fields__: ClassVar[Dict[str, Any]]
|
||||
__dataclass_params__: ClassVar[Any] # in reality `dataclasses._DataclassParams`
|
||||
__post_init__: ClassVar[Callable[..., None]]
|
||||
|
||||
# Added by pydantic
|
||||
__pydantic_run_validation__: ClassVar[bool]
|
||||
__post_init_post_parse__: ClassVar[Callable[..., None]]
|
||||
__pydantic_initialised__: ClassVar[bool]
|
||||
__pydantic_model__: ClassVar[Type[BaseModel]]
|
||||
__pydantic_validate_values__: ClassVar[Callable[['Dataclass'], None]]
|
||||
__pydantic_has_field_info_default__: ClassVar[bool] # whether a `pydantic.Field` is used as default value
|
||||
|
||||
def __init__(self, *args: object, **kwargs: object) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls: Type['Dataclass']) -> 'CallableGenerator':
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
|
||||
pass
|
||||
|
||||
|
||||
__all__ = [
|
||||
'dataclass',
|
||||
'set_validation',
|
||||
'create_pydantic_model_from_dataclass',
|
||||
'is_builtin_dataclass',
|
||||
'make_dataclass_validator',
|
||||
]
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
|
||||
@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo))
|
||||
@overload
|
||||
def dataclass(
|
||||
*,
|
||||
init: bool = True,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool = False,
|
||||
config: Union[ConfigDict, Type[object], None] = None,
|
||||
validate_on_init: Optional[bool] = None,
|
||||
kw_only: bool = ...,
|
||||
) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
|
||||
...
|
||||
|
||||
@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo))
|
||||
@overload
|
||||
def dataclass(
|
||||
_cls: Type[_T],
|
||||
*,
|
||||
init: bool = True,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool = False,
|
||||
config: Union[ConfigDict, Type[object], None] = None,
|
||||
validate_on_init: Optional[bool] = None,
|
||||
kw_only: bool = ...,
|
||||
) -> 'DataclassClassOrWrapper':
|
||||
...
|
||||
|
||||
else:
|
||||
|
||||
@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo))
|
||||
@overload
|
||||
def dataclass(
|
||||
*,
|
||||
init: bool = True,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool = False,
|
||||
config: Union[ConfigDict, Type[object], None] = None,
|
||||
validate_on_init: Optional[bool] = None,
|
||||
) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
|
||||
...
|
||||
|
||||
@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo))
|
||||
@overload
|
||||
def dataclass(
|
||||
_cls: Type[_T],
|
||||
*,
|
||||
init: bool = True,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool = False,
|
||||
config: Union[ConfigDict, Type[object], None] = None,
|
||||
validate_on_init: Optional[bool] = None,
|
||||
) -> 'DataclassClassOrWrapper':
|
||||
...
|
||||
|
||||
|
||||
@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo))
|
||||
def dataclass(
|
||||
_cls: Optional[Type[_T]] = None,
|
||||
*,
|
||||
init: bool = True,
|
||||
repr: bool = True,
|
||||
eq: bool = True,
|
||||
order: bool = False,
|
||||
unsafe_hash: bool = False,
|
||||
frozen: bool = False,
|
||||
config: Union[ConfigDict, Type[object], None] = None,
|
||||
validate_on_init: Optional[bool] = None,
|
||||
kw_only: bool = False,
|
||||
) -> Union[Callable[[Type[_T]], 'DataclassClassOrWrapper'], 'DataclassClassOrWrapper']:
|
||||
"""
|
||||
Like the python standard lib dataclasses but with type validation.
|
||||
The result is either a pydantic dataclass that will validate input data
|
||||
or a wrapper that will trigger validation around a stdlib dataclass
|
||||
to avoid modifying it directly
|
||||
"""
|
||||
the_config = get_config(config)
|
||||
|
||||
def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper':
|
||||
import dataclasses
|
||||
|
||||
if is_builtin_dataclass(cls) and _extra_dc_args(_cls) == _extra_dc_args(_cls.__bases__[0]): # type: ignore
|
||||
dc_cls_doc = ''
|
||||
dc_cls = DataclassProxy(cls)
|
||||
default_validate_on_init = False
|
||||
else:
|
||||
dc_cls_doc = cls.__doc__ or '' # needs to be done before generating dataclass
|
||||
if sys.version_info >= (3, 10):
|
||||
dc_cls = dataclasses.dataclass(
|
||||
cls,
|
||||
init=init,
|
||||
repr=repr,
|
||||
eq=eq,
|
||||
order=order,
|
||||
unsafe_hash=unsafe_hash,
|
||||
frozen=frozen,
|
||||
kw_only=kw_only,
|
||||
)
|
||||
else:
|
||||
dc_cls = dataclasses.dataclass( # type: ignore
|
||||
cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
|
||||
)
|
||||
default_validate_on_init = True
|
||||
|
||||
should_validate_on_init = default_validate_on_init if validate_on_init is None else validate_on_init
|
||||
_add_pydantic_validation_attributes(cls, the_config, should_validate_on_init, dc_cls_doc)
|
||||
dc_cls.__pydantic_model__.__try_update_forward_refs__(**{cls.__name__: cls})
|
||||
return dc_cls
|
||||
|
||||
if _cls is None:
|
||||
return wrap
|
||||
|
||||
return wrap(_cls)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_validation(cls: Type['DataclassT'], value: bool) -> Generator[Type['DataclassT'], None, None]:
|
||||
original_run_validation = cls.__pydantic_run_validation__
|
||||
try:
|
||||
cls.__pydantic_run_validation__ = value
|
||||
yield cls
|
||||
finally:
|
||||
cls.__pydantic_run_validation__ = original_run_validation
|
||||
|
||||
|
||||
class DataclassProxy:
|
||||
__slots__ = '__dataclass__'
|
||||
|
||||
def __init__(self, dc_cls: Type['Dataclass']) -> None:
|
||||
object.__setattr__(self, '__dataclass__', dc_cls)
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
with set_validation(self.__dataclass__, True):
|
||||
return self.__dataclass__(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
return getattr(self.__dataclass__, name)
|
||||
|
||||
def __instancecheck__(self, instance: Any) -> bool:
|
||||
return isinstance(instance, self.__dataclass__)
|
||||
|
||||
|
||||
def _add_pydantic_validation_attributes( # noqa: C901 (ignore complexity)
|
||||
dc_cls: Type['Dataclass'],
|
||||
config: Type[BaseConfig],
|
||||
validate_on_init: bool,
|
||||
dc_cls_doc: str,
|
||||
) -> None:
|
||||
"""
|
||||
We need to replace the right method. If no `__post_init__` has been set in the stdlib dataclass
|
||||
it won't even exist (code is generated on the fly by `dataclasses`)
|
||||
By default, we run validation after `__init__` or `__post_init__` if defined
|
||||
"""
|
||||
init = dc_cls.__init__
|
||||
|
||||
@wraps(init)
|
||||
def handle_extra_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
|
||||
if config.extra == Extra.ignore:
|
||||
init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})
|
||||
|
||||
elif config.extra == Extra.allow:
|
||||
for k, v in kwargs.items():
|
||||
self.__dict__.setdefault(k, v)
|
||||
init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})
|
||||
|
||||
else:
|
||||
init(self, *args, **kwargs)
|
||||
|
||||
if hasattr(dc_cls, '__post_init__'):
|
||||
post_init = dc_cls.__post_init__
|
||||
|
||||
@wraps(post_init)
|
||||
def new_post_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
|
||||
if config.post_init_call == 'before_validation':
|
||||
post_init(self, *args, **kwargs)
|
||||
|
||||
if self.__class__.__pydantic_run_validation__:
|
||||
self.__pydantic_validate_values__()
|
||||
if hasattr(self, '__post_init_post_parse__'):
|
||||
self.__post_init_post_parse__(*args, **kwargs)
|
||||
|
||||
if config.post_init_call == 'after_validation':
|
||||
post_init(self, *args, **kwargs)
|
||||
|
||||
setattr(dc_cls, '__init__', handle_extra_init)
|
||||
setattr(dc_cls, '__post_init__', new_post_init)
|
||||
|
||||
else:
|
||||
|
||||
@wraps(init)
|
||||
def new_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
|
||||
handle_extra_init(self, *args, **kwargs)
|
||||
|
||||
if self.__class__.__pydantic_run_validation__:
|
||||
self.__pydantic_validate_values__()
|
||||
|
||||
if hasattr(self, '__post_init_post_parse__'):
|
||||
# We need to find again the initvars. To do that we use `__dataclass_fields__` instead of
|
||||
# public method `dataclasses.fields`
|
||||
import dataclasses
|
||||
|
||||
# get all initvars and their default values
|
||||
initvars_and_values: Dict[str, Any] = {}
|
||||
for i, f in enumerate(self.__class__.__dataclass_fields__.values()):
|
||||
if f._field_type is dataclasses._FIELD_INITVAR: # type: ignore[attr-defined]
|
||||
try:
|
||||
# set arg value by default
|
||||
initvars_and_values[f.name] = args[i]
|
||||
except IndexError:
|
||||
initvars_and_values[f.name] = kwargs.get(f.name, f.default)
|
||||
|
||||
self.__post_init_post_parse__(**initvars_and_values)
|
||||
|
||||
setattr(dc_cls, '__init__', new_init)
|
||||
|
||||
setattr(dc_cls, '__pydantic_run_validation__', ClassAttribute('__pydantic_run_validation__', validate_on_init))
|
||||
setattr(dc_cls, '__pydantic_initialised__', False)
|
||||
setattr(dc_cls, '__pydantic_model__', create_pydantic_model_from_dataclass(dc_cls, config, dc_cls_doc))
|
||||
setattr(dc_cls, '__pydantic_validate_values__', _dataclass_validate_values)
|
||||
setattr(dc_cls, '__validate__', classmethod(_validate_dataclass))
|
||||
setattr(dc_cls, '__get_validators__', classmethod(_get_validators))
|
||||
|
||||
if dc_cls.__pydantic_model__.__config__.validate_assignment and not dc_cls.__dataclass_params__.frozen:
|
||||
setattr(dc_cls, '__setattr__', _dataclass_validate_assignment_setattr)
|
||||
|
||||
|
||||
def _get_validators(cls: 'DataclassClassOrWrapper') -> 'CallableGenerator':
|
||||
yield cls.__validate__
|
||||
|
||||
|
||||
def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
|
||||
with set_validation(cls, True):
|
||||
if isinstance(v, cls):
|
||||
v.__pydantic_validate_values__()
|
||||
return v
|
||||
elif isinstance(v, (list, tuple)):
|
||||
return cls(*v)
|
||||
elif isinstance(v, dict):
|
||||
return cls(**v)
|
||||
else:
|
||||
raise DataclassTypeError(class_name=cls.__name__)
|
||||
|
||||
|
||||
def create_pydantic_model_from_dataclass(
|
||||
dc_cls: Type['Dataclass'],
|
||||
config: Type[Any] = BaseConfig,
|
||||
dc_cls_doc: Optional[str] = None,
|
||||
) -> Type['BaseModel']:
|
||||
import dataclasses
|
||||
|
||||
field_definitions: Dict[str, Any] = {}
|
||||
for field in dataclasses.fields(dc_cls):
|
||||
default: Any = Undefined
|
||||
default_factory: Optional['NoArgAnyCallable'] = None
|
||||
field_info: FieldInfo
|
||||
|
||||
if field.default is not dataclasses.MISSING:
|
||||
default = field.default
|
||||
elif field.default_factory is not dataclasses.MISSING:
|
||||
default_factory = field.default_factory
|
||||
else:
|
||||
default = Required
|
||||
|
||||
if isinstance(default, FieldInfo):
|
||||
field_info = default
|
||||
dc_cls.__pydantic_has_field_info_default__ = True
|
||||
else:
|
||||
field_info = Field(default=default, default_factory=default_factory, **field.metadata)
|
||||
|
||||
field_definitions[field.name] = (field.type, field_info)
|
||||
|
||||
validators = gather_all_validators(dc_cls)
|
||||
model: Type['BaseModel'] = create_model(
|
||||
dc_cls.__name__,
|
||||
__config__=config,
|
||||
__module__=dc_cls.__module__,
|
||||
__validators__=validators,
|
||||
__cls_kwargs__={'__resolve_forward_refs__': False},
|
||||
**field_definitions,
|
||||
)
|
||||
model.__doc__ = dc_cls_doc if dc_cls_doc is not None else dc_cls.__doc__ or ''
|
||||
return model
|
||||
|
||||
|
||||
def _dataclass_validate_values(self: 'Dataclass') -> None:
|
||||
# validation errors can occur if this function is called twice on an already initialised dataclass.
|
||||
# for example if Extra.forbid is enabled, it would consider __pydantic_initialised__ an invalid extra property
|
||||
if getattr(self, '__pydantic_initialised__'):
|
||||
return
|
||||
if getattr(self, '__pydantic_has_field_info_default__', False):
|
||||
# We need to remove `FieldInfo` values since they are not valid as input
|
||||
# It's ok to do that because they are obviously the default values!
|
||||
input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)}
|
||||
else:
|
||||
input_data = self.__dict__
|
||||
d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
|
||||
if validation_error:
|
||||
raise validation_error
|
||||
self.__dict__.update(d)
|
||||
object.__setattr__(self, '__pydantic_initialised__', True)
|
||||
|
||||
|
||||
def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value: Any) -> None:
|
||||
if self.__pydantic_initialised__:
|
||||
d = dict(self.__dict__)
|
||||
d.pop(name, None)
|
||||
known_field = self.__pydantic_model__.__fields__.get(name, None)
|
||||
if known_field:
|
||||
value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__)
|
||||
if error_:
|
||||
raise ValidationError([error_], self.__class__)
|
||||
|
||||
object.__setattr__(self, name, value)
|
||||
|
||||
|
||||
def _extra_dc_args(cls: Type[Any]) -> Set[str]:
|
||||
return {
|
||||
x
|
||||
for x in dir(cls)
|
||||
if x not in getattr(cls, '__dataclass_fields__', {}) and not (x.startswith('__') and x.endswith('__'))
|
||||
}
|
||||
|
||||
|
||||
def is_builtin_dataclass(_cls: Type[Any]) -> bool:
|
||||
"""
|
||||
Whether a class is a stdlib dataclass
|
||||
(useful to discriminated a pydantic dataclass that is actually a wrapper around a stdlib dataclass)
|
||||
|
||||
we check that
|
||||
- `_cls` is a dataclass
|
||||
- `_cls` is not a processed pydantic dataclass (with a basemodel attached)
|
||||
- `_cls` is not a pydantic dataclass inheriting directly from a stdlib dataclass
|
||||
e.g.
|
||||
```
|
||||
@dataclasses.dataclass
|
||||
class A:
|
||||
x: int
|
||||
|
||||
@pydantic.dataclasses.dataclass
|
||||
class B(A):
|
||||
y: int
|
||||
```
|
||||
In this case, when we first check `B`, we make an extra check and look at the annotations ('y'),
|
||||
which won't be a superset of all the dataclass fields (only the stdlib fields i.e. 'x')
|
||||
"""
|
||||
import dataclasses
|
||||
|
||||
return (
|
||||
dataclasses.is_dataclass(_cls)
|
||||
and not hasattr(_cls, '__pydantic_model__')
|
||||
and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {})))
|
||||
)
|
||||
|
||||
|
||||
def make_dataclass_validator(dc_cls: Type['Dataclass'], config: Type[BaseConfig]) -> 'CallableGenerator':
|
||||
"""
|
||||
Create a pydantic.dataclass from a builtin dataclass to add type validation
|
||||
and yield the validators
|
||||
It retrieves the parameters of the dataclass and forwards them to the newly created dataclass
|
||||
"""
|
||||
yield from _get_validators(dataclass(dc_cls, config=config, validate_on_init=False))
|
248
lib/pydantic/datetime_parse.py
Normal file
248
lib/pydantic/datetime_parse.py
Normal file
|
@ -0,0 +1,248 @@
|
|||
"""
|
||||
Functions to parse datetime objects.
|
||||
|
||||
We're using regular expressions rather than time.strptime because:
|
||||
- They provide both validation and parsing.
|
||||
- They're more flexible for datetimes.
|
||||
- The date/datetime/time constructors produce friendlier error messages.
|
||||
|
||||
Stolen from https://raw.githubusercontent.com/django/django/main/django/utils/dateparse.py at
|
||||
9718fa2e8abe430c3526a9278dd976443d4ae3c6
|
||||
|
||||
Changed to:
|
||||
* use standard python datetime types not django.utils.timezone
|
||||
* raise ValueError when regex doesn't match rather than returning None
|
||||
* support parsing unix timestamps for dates and datetimes
|
||||
"""
|
||||
import re
|
||||
from datetime import date, datetime, time, timedelta, timezone
|
||||
from typing import Dict, Optional, Type, Union
|
||||
|
||||
from . import errors
|
||||
|
||||
date_expr = r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})'
|
||||
time_expr = (
|
||||
r'(?P<hour>\d{1,2}):(?P<minute>\d{1,2})'
|
||||
r'(?::(?P<second>\d{1,2})(?:\.(?P<microsecond>\d{1,6})\d{0,6})?)?'
|
||||
r'(?P<tzinfo>Z|[+-]\d{2}(?::?\d{2})?)?$'
|
||||
)
|
||||
|
||||
date_re = re.compile(f'{date_expr}$')
|
||||
time_re = re.compile(time_expr)
|
||||
datetime_re = re.compile(f'{date_expr}[T ]{time_expr}')
|
||||
|
||||
standard_duration_re = re.compile(
|
||||
r'^'
|
||||
r'(?:(?P<days>-?\d+) (days?, )?)?'
|
||||
r'((?:(?P<hours>-?\d+):)(?=\d+:\d+))?'
|
||||
r'(?:(?P<minutes>-?\d+):)?'
|
||||
r'(?P<seconds>-?\d+)'
|
||||
r'(?:\.(?P<microseconds>\d{1,6})\d{0,6})?'
|
||||
r'$'
|
||||
)
|
||||
|
||||
# Support the sections of ISO 8601 date representation that are accepted by timedelta
|
||||
iso8601_duration_re = re.compile(
|
||||
r'^(?P<sign>[-+]?)'
|
||||
r'P'
|
||||
r'(?:(?P<days>\d+(.\d+)?)D)?'
|
||||
r'(?:T'
|
||||
r'(?:(?P<hours>\d+(.\d+)?)H)?'
|
||||
r'(?:(?P<minutes>\d+(.\d+)?)M)?'
|
||||
r'(?:(?P<seconds>\d+(.\d+)?)S)?'
|
||||
r')?'
|
||||
r'$'
|
||||
)
|
||||
|
||||
EPOCH = datetime(1970, 1, 1)
|
||||
# if greater than this, the number is in ms, if less than or equal it's in seconds
|
||||
# (in seconds this is 11th October 2603, in ms it's 20th August 1970)
|
||||
MS_WATERSHED = int(2e10)
|
||||
# slightly more than datetime.max in ns - (datetime.max - EPOCH).total_seconds() * 1e9
|
||||
MAX_NUMBER = int(3e20)
|
||||
StrBytesIntFloat = Union[str, bytes, int, float]
|
||||
|
||||
|
||||
def get_numeric(value: StrBytesIntFloat, native_expected_type: str) -> Union[None, int, float]:
|
||||
if isinstance(value, (int, float)):
|
||||
return value
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return None
|
||||
except TypeError:
|
||||
raise TypeError(f'invalid type; expected {native_expected_type}, string, bytes, int or float')
|
||||
|
||||
|
||||
def from_unix_seconds(seconds: Union[int, float]) -> datetime:
|
||||
if seconds > MAX_NUMBER:
|
||||
return datetime.max
|
||||
elif seconds < -MAX_NUMBER:
|
||||
return datetime.min
|
||||
|
||||
while abs(seconds) > MS_WATERSHED:
|
||||
seconds /= 1000
|
||||
dt = EPOCH + timedelta(seconds=seconds)
|
||||
return dt.replace(tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _parse_timezone(value: Optional[str], error: Type[Exception]) -> Union[None, int, timezone]:
|
||||
if value == 'Z':
|
||||
return timezone.utc
|
||||
elif value is not None:
|
||||
offset_mins = int(value[-2:]) if len(value) > 3 else 0
|
||||
offset = 60 * int(value[1:3]) + offset_mins
|
||||
if value[0] == '-':
|
||||
offset = -offset
|
||||
try:
|
||||
return timezone(timedelta(minutes=offset))
|
||||
except ValueError:
|
||||
raise error()
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def parse_date(value: Union[date, StrBytesIntFloat]) -> date:
|
||||
"""
|
||||
Parse a date/int/float/string and return a datetime.date.
|
||||
|
||||
Raise ValueError if the input is well formatted but not a valid date.
|
||||
Raise ValueError if the input isn't well formatted.
|
||||
"""
|
||||
if isinstance(value, date):
|
||||
if isinstance(value, datetime):
|
||||
return value.date()
|
||||
else:
|
||||
return value
|
||||
|
||||
number = get_numeric(value, 'date')
|
||||
if number is not None:
|
||||
return from_unix_seconds(number).date()
|
||||
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode()
|
||||
|
||||
match = date_re.match(value) # type: ignore
|
||||
if match is None:
|
||||
raise errors.DateError()
|
||||
|
||||
kw = {k: int(v) for k, v in match.groupdict().items()}
|
||||
|
||||
try:
|
||||
return date(**kw)
|
||||
except ValueError:
|
||||
raise errors.DateError()
|
||||
|
||||
|
||||
def parse_time(value: Union[time, StrBytesIntFloat]) -> time:
|
||||
"""
|
||||
Parse a time/string and return a datetime.time.
|
||||
|
||||
Raise ValueError if the input is well formatted but not a valid time.
|
||||
Raise ValueError if the input isn't well formatted, in particular if it contains an offset.
|
||||
"""
|
||||
if isinstance(value, time):
|
||||
return value
|
||||
|
||||
number = get_numeric(value, 'time')
|
||||
if number is not None:
|
||||
if number >= 86400:
|
||||
# doesn't make sense since the time time loop back around to 0
|
||||
raise errors.TimeError()
|
||||
return (datetime.min + timedelta(seconds=number)).time()
|
||||
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode()
|
||||
|
||||
match = time_re.match(value) # type: ignore
|
||||
if match is None:
|
||||
raise errors.TimeError()
|
||||
|
||||
kw = match.groupdict()
|
||||
if kw['microsecond']:
|
||||
kw['microsecond'] = kw['microsecond'].ljust(6, '0')
|
||||
|
||||
tzinfo = _parse_timezone(kw.pop('tzinfo'), errors.TimeError)
|
||||
kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None}
|
||||
kw_['tzinfo'] = tzinfo
|
||||
|
||||
try:
|
||||
return time(**kw_) # type: ignore
|
||||
except ValueError:
|
||||
raise errors.TimeError()
|
||||
|
||||
|
||||
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime:
|
||||
"""
|
||||
Parse a datetime/int/float/string and return a datetime.datetime.
|
||||
|
||||
This function supports time zone offsets. When the input contains one,
|
||||
the output uses a timezone with a fixed offset from UTC.
|
||||
|
||||
Raise ValueError if the input is well formatted but not a valid datetime.
|
||||
Raise ValueError if the input isn't well formatted.
|
||||
"""
|
||||
if isinstance(value, datetime):
|
||||
return value
|
||||
|
||||
number = get_numeric(value, 'datetime')
|
||||
if number is not None:
|
||||
return from_unix_seconds(number)
|
||||
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode()
|
||||
|
||||
match = datetime_re.match(value) # type: ignore
|
||||
if match is None:
|
||||
raise errors.DateTimeError()
|
||||
|
||||
kw = match.groupdict()
|
||||
if kw['microsecond']:
|
||||
kw['microsecond'] = kw['microsecond'].ljust(6, '0')
|
||||
|
||||
tzinfo = _parse_timezone(kw.pop('tzinfo'), errors.DateTimeError)
|
||||
kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None}
|
||||
kw_['tzinfo'] = tzinfo
|
||||
|
||||
try:
|
||||
return datetime(**kw_) # type: ignore
|
||||
except ValueError:
|
||||
raise errors.DateTimeError()
|
||||
|
||||
|
||||
def parse_duration(value: StrBytesIntFloat) -> timedelta:
|
||||
"""
|
||||
Parse a duration int/float/string and return a datetime.timedelta.
|
||||
|
||||
The preferred format for durations in Django is '%d %H:%M:%S.%f'.
|
||||
|
||||
Also supports ISO 8601 representation.
|
||||
"""
|
||||
if isinstance(value, timedelta):
|
||||
return value
|
||||
|
||||
if isinstance(value, (int, float)):
|
||||
# below code requires a string
|
||||
value = f'{value:f}'
|
||||
elif isinstance(value, bytes):
|
||||
value = value.decode()
|
||||
|
||||
try:
|
||||
match = standard_duration_re.match(value) or iso8601_duration_re.match(value)
|
||||
except TypeError:
|
||||
raise TypeError('invalid type; expected timedelta, string, bytes, int or float')
|
||||
|
||||
if not match:
|
||||
raise errors.DurationError()
|
||||
|
||||
kw = match.groupdict()
|
||||
sign = -1 if kw.pop('sign', '+') == '-' else 1
|
||||
if kw.get('microseconds'):
|
||||
kw['microseconds'] = kw['microseconds'].ljust(6, '0')
|
||||
|
||||
if kw.get('seconds') and kw.get('microseconds') and kw['seconds'].startswith('-'):
|
||||
kw['microseconds'] = '-' + kw['microseconds']
|
||||
|
||||
kw_ = {k: float(v) for k, v in kw.items() if v is not None}
|
||||
|
||||
return sign * timedelta(**kw_)
|
264
lib/pydantic/decorator.py
Normal file
264
lib/pydantic/decorator.py
Normal file
|
@ -0,0 +1,264 @@
|
|||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload
|
||||
|
||||
from . import validator
|
||||
from .config import Extra
|
||||
from .errors import ConfigError
|
||||
from .main import BaseModel, create_model
|
||||
from .typing import get_all_type_hints
|
||||
from .utils import to_camel
|
||||
|
||||
__all__ = ('validate_arguments',)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .typing import AnyCallable
|
||||
|
||||
AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable)
|
||||
ConfigType = Union[None, Type[Any], Dict[str, Any]]
|
||||
|
||||
|
||||
@overload
|
||||
def validate_arguments(func: None = None, *, config: 'ConfigType' = None) -> Callable[['AnyCallableT'], 'AnyCallableT']:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT':
|
||||
...
|
||||
|
||||
|
||||
def validate_arguments(func: Optional['AnyCallableT'] = None, *, config: 'ConfigType' = None) -> Any:
|
||||
"""
|
||||
Decorator to validate the arguments passed to a function.
|
||||
"""
|
||||
|
||||
def validate(_func: 'AnyCallable') -> 'AnyCallable':
|
||||
vd = ValidatedFunction(_func, config)
|
||||
|
||||
@wraps(_func)
|
||||
def wrapper_function(*args: Any, **kwargs: Any) -> Any:
|
||||
return vd.call(*args, **kwargs)
|
||||
|
||||
wrapper_function.vd = vd # type: ignore
|
||||
wrapper_function.validate = vd.init_model_instance # type: ignore
|
||||
wrapper_function.raw_function = vd.raw_function # type: ignore
|
||||
wrapper_function.model = vd.model # type: ignore
|
||||
return wrapper_function
|
||||
|
||||
if func:
|
||||
return validate(func)
|
||||
else:
|
||||
return validate
|
||||
|
||||
|
||||
ALT_V_ARGS = 'v__args'
|
||||
ALT_V_KWARGS = 'v__kwargs'
|
||||
V_POSITIONAL_ONLY_NAME = 'v__positional_only'
|
||||
V_DUPLICATE_KWARGS = 'v__duplicate_kwargs'
|
||||
|
||||
|
||||
class ValidatedFunction:
|
||||
def __init__(self, function: 'AnyCallableT', config: 'ConfigType'): # noqa C901
|
||||
from inspect import Parameter, signature
|
||||
|
||||
parameters: Mapping[str, Parameter] = signature(function).parameters
|
||||
|
||||
if parameters.keys() & {ALT_V_ARGS, ALT_V_KWARGS, V_POSITIONAL_ONLY_NAME, V_DUPLICATE_KWARGS}:
|
||||
raise ConfigError(
|
||||
f'"{ALT_V_ARGS}", "{ALT_V_KWARGS}", "{V_POSITIONAL_ONLY_NAME}" and "{V_DUPLICATE_KWARGS}" '
|
||||
f'are not permitted as argument names when using the "{validate_arguments.__name__}" decorator'
|
||||
)
|
||||
|
||||
self.raw_function = function
|
||||
self.arg_mapping: Dict[int, str] = {}
|
||||
self.positional_only_args = set()
|
||||
self.v_args_name = 'args'
|
||||
self.v_kwargs_name = 'kwargs'
|
||||
|
||||
type_hints = get_all_type_hints(function)
|
||||
takes_args = False
|
||||
takes_kwargs = False
|
||||
fields: Dict[str, Tuple[Any, Any]] = {}
|
||||
for i, (name, p) in enumerate(parameters.items()):
|
||||
if p.annotation is p.empty:
|
||||
annotation = Any
|
||||
else:
|
||||
annotation = type_hints[name]
|
||||
|
||||
default = ... if p.default is p.empty else p.default
|
||||
if p.kind == Parameter.POSITIONAL_ONLY:
|
||||
self.arg_mapping[i] = name
|
||||
fields[name] = annotation, default
|
||||
fields[V_POSITIONAL_ONLY_NAME] = List[str], None
|
||||
self.positional_only_args.add(name)
|
||||
elif p.kind == Parameter.POSITIONAL_OR_KEYWORD:
|
||||
self.arg_mapping[i] = name
|
||||
fields[name] = annotation, default
|
||||
fields[V_DUPLICATE_KWARGS] = List[str], None
|
||||
elif p.kind == Parameter.KEYWORD_ONLY:
|
||||
fields[name] = annotation, default
|
||||
elif p.kind == Parameter.VAR_POSITIONAL:
|
||||
self.v_args_name = name
|
||||
fields[name] = Tuple[annotation, ...], None
|
||||
takes_args = True
|
||||
else:
|
||||
assert p.kind == Parameter.VAR_KEYWORD, p.kind
|
||||
self.v_kwargs_name = name
|
||||
fields[name] = Dict[str, annotation], None # type: ignore
|
||||
takes_kwargs = True
|
||||
|
||||
# these checks avoid a clash between "args" and a field with that name
|
||||
if not takes_args and self.v_args_name in fields:
|
||||
self.v_args_name = ALT_V_ARGS
|
||||
|
||||
# same with "kwargs"
|
||||
if not takes_kwargs and self.v_kwargs_name in fields:
|
||||
self.v_kwargs_name = ALT_V_KWARGS
|
||||
|
||||
if not takes_args:
|
||||
# we add the field so validation below can raise the correct exception
|
||||
fields[self.v_args_name] = List[Any], None
|
||||
|
||||
if not takes_kwargs:
|
||||
# same with kwargs
|
||||
fields[self.v_kwargs_name] = Dict[Any, Any], None
|
||||
|
||||
self.create_model(fields, takes_args, takes_kwargs, config)
|
||||
|
||||
def init_model_instance(self, *args: Any, **kwargs: Any) -> BaseModel:
|
||||
values = self.build_values(args, kwargs)
|
||||
return self.model(**values)
|
||||
|
||||
def call(self, *args: Any, **kwargs: Any) -> Any:
|
||||
m = self.init_model_instance(*args, **kwargs)
|
||||
return self.execute(m)
|
||||
|
||||
def build_values(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
values: Dict[str, Any] = {}
|
||||
if args:
|
||||
arg_iter = enumerate(args)
|
||||
while True:
|
||||
try:
|
||||
i, a = next(arg_iter)
|
||||
except StopIteration:
|
||||
break
|
||||
arg_name = self.arg_mapping.get(i)
|
||||
if arg_name is not None:
|
||||
values[arg_name] = a
|
||||
else:
|
||||
values[self.v_args_name] = [a] + [a for _, a in arg_iter]
|
||||
break
|
||||
|
||||
var_kwargs: Dict[str, Any] = {}
|
||||
wrong_positional_args = []
|
||||
duplicate_kwargs = []
|
||||
fields_alias = [
|
||||
field.alias
|
||||
for name, field in self.model.__fields__.items()
|
||||
if name not in (self.v_args_name, self.v_kwargs_name)
|
||||
]
|
||||
non_var_fields = set(self.model.__fields__) - {self.v_args_name, self.v_kwargs_name}
|
||||
for k, v in kwargs.items():
|
||||
if k in non_var_fields or k in fields_alias:
|
||||
if k in self.positional_only_args:
|
||||
wrong_positional_args.append(k)
|
||||
if k in values:
|
||||
duplicate_kwargs.append(k)
|
||||
values[k] = v
|
||||
else:
|
||||
var_kwargs[k] = v
|
||||
|
||||
if var_kwargs:
|
||||
values[self.v_kwargs_name] = var_kwargs
|
||||
if wrong_positional_args:
|
||||
values[V_POSITIONAL_ONLY_NAME] = wrong_positional_args
|
||||
if duplicate_kwargs:
|
||||
values[V_DUPLICATE_KWARGS] = duplicate_kwargs
|
||||
return values
|
||||
|
||||
def execute(self, m: BaseModel) -> Any:
|
||||
d = {k: v for k, v in m._iter() if k in m.__fields_set__ or m.__fields__[k].default_factory}
|
||||
var_kwargs = d.pop(self.v_kwargs_name, {})
|
||||
|
||||
if self.v_args_name in d:
|
||||
args_: List[Any] = []
|
||||
in_kwargs = False
|
||||
kwargs = {}
|
||||
for name, value in d.items():
|
||||
if in_kwargs:
|
||||
kwargs[name] = value
|
||||
elif name == self.v_args_name:
|
||||
args_ += value
|
||||
in_kwargs = True
|
||||
else:
|
||||
args_.append(value)
|
||||
return self.raw_function(*args_, **kwargs, **var_kwargs)
|
||||
elif self.positional_only_args:
|
||||
args_ = []
|
||||
kwargs = {}
|
||||
for name, value in d.items():
|
||||
if name in self.positional_only_args:
|
||||
args_.append(value)
|
||||
else:
|
||||
kwargs[name] = value
|
||||
return self.raw_function(*args_, **kwargs, **var_kwargs)
|
||||
else:
|
||||
return self.raw_function(**d, **var_kwargs)
|
||||
|
||||
def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None:
|
||||
pos_args = len(self.arg_mapping)
|
||||
|
||||
class CustomConfig:
|
||||
pass
|
||||
|
||||
if not TYPE_CHECKING: # pragma: no branch
|
||||
if isinstance(config, dict):
|
||||
CustomConfig = type('Config', (), config) # noqa: F811
|
||||
elif config is not None:
|
||||
CustomConfig = config # noqa: F811
|
||||
|
||||
if hasattr(CustomConfig, 'fields') or hasattr(CustomConfig, 'alias_generator'):
|
||||
raise ConfigError(
|
||||
'Setting the "fields" and "alias_generator" property on custom Config for '
|
||||
'@validate_arguments is not yet supported, please remove.'
|
||||
)
|
||||
|
||||
class DecoratorBaseModel(BaseModel):
|
||||
@validator(self.v_args_name, check_fields=False, allow_reuse=True)
|
||||
def check_args(cls, v: Optional[List[Any]]) -> Optional[List[Any]]:
|
||||
if takes_args or v is None:
|
||||
return v
|
||||
|
||||
raise TypeError(f'{pos_args} positional arguments expected but {pos_args + len(v)} given')
|
||||
|
||||
@validator(self.v_kwargs_name, check_fields=False, allow_reuse=True)
|
||||
def check_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
if takes_kwargs or v is None:
|
||||
return v
|
||||
|
||||
plural = '' if len(v) == 1 else 's'
|
||||
keys = ', '.join(map(repr, v.keys()))
|
||||
raise TypeError(f'unexpected keyword argument{plural}: {keys}')
|
||||
|
||||
@validator(V_POSITIONAL_ONLY_NAME, check_fields=False, allow_reuse=True)
|
||||
def check_positional_only(cls, v: Optional[List[str]]) -> None:
|
||||
if v is None:
|
||||
return
|
||||
|
||||
plural = '' if len(v) == 1 else 's'
|
||||
keys = ', '.join(map(repr, v))
|
||||
raise TypeError(f'positional-only argument{plural} passed as keyword argument{plural}: {keys}')
|
||||
|
||||
@validator(V_DUPLICATE_KWARGS, check_fields=False, allow_reuse=True)
|
||||
def check_duplicate_kwargs(cls, v: Optional[List[str]]) -> None:
|
||||
if v is None:
|
||||
return
|
||||
|
||||
plural = '' if len(v) == 1 else 's'
|
||||
keys = ', '.join(map(repr, v))
|
||||
raise TypeError(f'multiple values for argument{plural}: {keys}')
|
||||
|
||||
class Config(CustomConfig):
|
||||
extra = getattr(CustomConfig, 'extra', Extra.forbid)
|
||||
|
||||
self.model = create_model(to_camel(self.raw_function.__name__), __base__=DecoratorBaseModel, **fields)
|
346
lib/pydantic/env_settings.py
Normal file
346
lib/pydantic/env_settings.py
Normal file
|
@ -0,0 +1,346 @@
|
|||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import AbstractSet, Any, Callable, ClassVar, Dict, List, Mapping, Optional, Tuple, Type, Union
|
||||
|
||||
from .config import BaseConfig, Extra
|
||||
from .fields import ModelField
|
||||
from .main import BaseModel
|
||||
from .typing import StrPath, display_as_type, get_origin, is_union
|
||||
from .utils import deep_update, path_type, sequence_like
|
||||
|
||||
env_file_sentinel = str(object())
|
||||
|
||||
SettingsSourceCallable = Callable[['BaseSettings'], Dict[str, Any]]
|
||||
DotenvType = Union[StrPath, List[StrPath], Tuple[StrPath, ...]]
|
||||
|
||||
|
||||
class SettingsError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class BaseSettings(BaseModel):
|
||||
"""
|
||||
Base class for settings, allowing values to be overridden by environment variables.
|
||||
|
||||
This is useful in production for secrets you do not wish to save in code, it plays nicely with docker(-compose),
|
||||
Heroku and any 12 factor app design.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
__pydantic_self__,
|
||||
_env_file: Optional[DotenvType] = env_file_sentinel,
|
||||
_env_file_encoding: Optional[str] = None,
|
||||
_env_nested_delimiter: Optional[str] = None,
|
||||
_secrets_dir: Optional[StrPath] = None,
|
||||
**values: Any,
|
||||
) -> None:
|
||||
# Uses something other than `self` the first arg to allow "self" as a settable attribute
|
||||
super().__init__(
|
||||
**__pydantic_self__._build_values(
|
||||
values,
|
||||
_env_file=_env_file,
|
||||
_env_file_encoding=_env_file_encoding,
|
||||
_env_nested_delimiter=_env_nested_delimiter,
|
||||
_secrets_dir=_secrets_dir,
|
||||
)
|
||||
)
|
||||
|
||||
def _build_values(
|
||||
self,
|
||||
init_kwargs: Dict[str, Any],
|
||||
_env_file: Optional[DotenvType] = None,
|
||||
_env_file_encoding: Optional[str] = None,
|
||||
_env_nested_delimiter: Optional[str] = None,
|
||||
_secrets_dir: Optional[StrPath] = None,
|
||||
) -> Dict[str, Any]:
|
||||
# Configure built-in sources
|
||||
init_settings = InitSettingsSource(init_kwargs=init_kwargs)
|
||||
env_settings = EnvSettingsSource(
|
||||
env_file=(_env_file if _env_file != env_file_sentinel else self.__config__.env_file),
|
||||
env_file_encoding=(
|
||||
_env_file_encoding if _env_file_encoding is not None else self.__config__.env_file_encoding
|
||||
),
|
||||
env_nested_delimiter=(
|
||||
_env_nested_delimiter if _env_nested_delimiter is not None else self.__config__.env_nested_delimiter
|
||||
),
|
||||
env_prefix_len=len(self.__config__.env_prefix),
|
||||
)
|
||||
file_secret_settings = SecretsSettingsSource(secrets_dir=_secrets_dir or self.__config__.secrets_dir)
|
||||
# Provide a hook to set built-in sources priority and add / remove sources
|
||||
sources = self.__config__.customise_sources(
|
||||
init_settings=init_settings, env_settings=env_settings, file_secret_settings=file_secret_settings
|
||||
)
|
||||
if sources:
|
||||
return deep_update(*reversed([source(self) for source in sources]))
|
||||
else:
|
||||
# no one should mean to do this, but I think returning an empty dict is marginally preferable
|
||||
# to an informative error and much better than a confusing error
|
||||
return {}
|
||||
|
||||
class Config(BaseConfig):
|
||||
env_prefix: str = ''
|
||||
env_file: Optional[DotenvType] = None
|
||||
env_file_encoding: Optional[str] = None
|
||||
env_nested_delimiter: Optional[str] = None
|
||||
secrets_dir: Optional[StrPath] = None
|
||||
validate_all: bool = True
|
||||
extra: Extra = Extra.forbid
|
||||
arbitrary_types_allowed: bool = True
|
||||
case_sensitive: bool = False
|
||||
|
||||
@classmethod
|
||||
def prepare_field(cls, field: ModelField) -> None:
|
||||
env_names: Union[List[str], AbstractSet[str]]
|
||||
field_info_from_config = cls.get_field_info(field.name)
|
||||
|
||||
env = field_info_from_config.get('env') or field.field_info.extra.get('env')
|
||||
if env is None:
|
||||
if field.has_alias:
|
||||
warnings.warn(
|
||||
'aliases are no longer used by BaseSettings to define which environment variables to read. '
|
||||
'Instead use the "env" field setting. '
|
||||
'See https://pydantic-docs.helpmanual.io/usage/settings/#environment-variable-names',
|
||||
FutureWarning,
|
||||
)
|
||||
env_names = {cls.env_prefix + field.name}
|
||||
elif isinstance(env, str):
|
||||
env_names = {env}
|
||||
elif isinstance(env, (set, frozenset)):
|
||||
env_names = env
|
||||
elif sequence_like(env):
|
||||
env_names = list(env)
|
||||
else:
|
||||
raise TypeError(f'invalid field env: {env!r} ({display_as_type(env)}); should be string, list or set')
|
||||
|
||||
if not cls.case_sensitive:
|
||||
env_names = env_names.__class__(n.lower() for n in env_names)
|
||||
field.field_info.extra['env_names'] = env_names
|
||||
|
||||
@classmethod
|
||||
def customise_sources(
|
||||
cls,
|
||||
init_settings: SettingsSourceCallable,
|
||||
env_settings: SettingsSourceCallable,
|
||||
file_secret_settings: SettingsSourceCallable,
|
||||
) -> Tuple[SettingsSourceCallable, ...]:
|
||||
return init_settings, env_settings, file_secret_settings
|
||||
|
||||
@classmethod
|
||||
def parse_env_var(cls, field_name: str, raw_val: str) -> Any:
|
||||
return cls.json_loads(raw_val)
|
||||
|
||||
# populated by the metaclass using the Config class defined above, annotated here to help IDEs only
|
||||
__config__: ClassVar[Type[Config]]
|
||||
|
||||
|
||||
class InitSettingsSource:
|
||||
__slots__ = ('init_kwargs',)
|
||||
|
||||
def __init__(self, init_kwargs: Dict[str, Any]):
|
||||
self.init_kwargs = init_kwargs
|
||||
|
||||
def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
|
||||
return self.init_kwargs
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'InitSettingsSource(init_kwargs={self.init_kwargs!r})'
|
||||
|
||||
|
||||
class EnvSettingsSource:
|
||||
__slots__ = ('env_file', 'env_file_encoding', 'env_nested_delimiter', 'env_prefix_len')
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env_file: Optional[DotenvType],
|
||||
env_file_encoding: Optional[str],
|
||||
env_nested_delimiter: Optional[str] = None,
|
||||
env_prefix_len: int = 0,
|
||||
):
|
||||
self.env_file: Optional[DotenvType] = env_file
|
||||
self.env_file_encoding: Optional[str] = env_file_encoding
|
||||
self.env_nested_delimiter: Optional[str] = env_nested_delimiter
|
||||
self.env_prefix_len: int = env_prefix_len
|
||||
|
||||
def __call__(self, settings: BaseSettings) -> Dict[str, Any]: # noqa C901
|
||||
"""
|
||||
Build environment variables suitable for passing to the Model.
|
||||
"""
|
||||
d: Dict[str, Any] = {}
|
||||
|
||||
if settings.__config__.case_sensitive:
|
||||
env_vars: Mapping[str, Optional[str]] = os.environ
|
||||
else:
|
||||
env_vars = {k.lower(): v for k, v in os.environ.items()}
|
||||
|
||||
dotenv_vars = self._read_env_files(settings.__config__.case_sensitive)
|
||||
if dotenv_vars:
|
||||
env_vars = {**dotenv_vars, **env_vars}
|
||||
|
||||
for field in settings.__fields__.values():
|
||||
env_val: Optional[str] = None
|
||||
for env_name in field.field_info.extra['env_names']:
|
||||
env_val = env_vars.get(env_name)
|
||||
if env_val is not None:
|
||||
break
|
||||
|
||||
is_complex, allow_parse_failure = self.field_is_complex(field)
|
||||
if is_complex:
|
||||
if env_val is None:
|
||||
# field is complex but no value found so far, try explode_env_vars
|
||||
env_val_built = self.explode_env_vars(field, env_vars)
|
||||
if env_val_built:
|
||||
d[field.alias] = env_val_built
|
||||
else:
|
||||
# field is complex and there's a value, decode that as JSON, then add explode_env_vars
|
||||
try:
|
||||
env_val = settings.__config__.parse_env_var(field.name, env_val)
|
||||
except ValueError as e:
|
||||
if not allow_parse_failure:
|
||||
raise SettingsError(f'error parsing env var "{env_name}"') from e
|
||||
|
||||
if isinstance(env_val, dict):
|
||||
d[field.alias] = deep_update(env_val, self.explode_env_vars(field, env_vars))
|
||||
else:
|
||||
d[field.alias] = env_val
|
||||
elif env_val is not None:
|
||||
# simplest case, field is not complex, we only need to add the value if it was found
|
||||
d[field.alias] = env_val
|
||||
|
||||
return d
|
||||
|
||||
def _read_env_files(self, case_sensitive: bool) -> Dict[str, Optional[str]]:
|
||||
env_files = self.env_file
|
||||
if env_files is None:
|
||||
return {}
|
||||
|
||||
if isinstance(env_files, (str, os.PathLike)):
|
||||
env_files = [env_files]
|
||||
|
||||
dotenv_vars = {}
|
||||
for env_file in env_files:
|
||||
env_path = Path(env_file).expanduser()
|
||||
if env_path.is_file():
|
||||
dotenv_vars.update(
|
||||
read_env_file(env_path, encoding=self.env_file_encoding, case_sensitive=case_sensitive)
|
||||
)
|
||||
|
||||
return dotenv_vars
|
||||
|
||||
def field_is_complex(self, field: ModelField) -> Tuple[bool, bool]:
|
||||
"""
|
||||
Find out if a field is complex, and if so whether JSON errors should be ignored
|
||||
"""
|
||||
if field.is_complex():
|
||||
allow_parse_failure = False
|
||||
elif is_union(get_origin(field.type_)) and field.sub_fields and any(f.is_complex() for f in field.sub_fields):
|
||||
allow_parse_failure = True
|
||||
else:
|
||||
return False, False
|
||||
|
||||
return True, allow_parse_failure
|
||||
|
||||
def explode_env_vars(self, field: ModelField, env_vars: Mapping[str, Optional[str]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.
|
||||
|
||||
This is applied to a single field, hence filtering by env_var prefix.
|
||||
"""
|
||||
prefixes = [f'{env_name}{self.env_nested_delimiter}' for env_name in field.field_info.extra['env_names']]
|
||||
result: Dict[str, Any] = {}
|
||||
for env_name, env_val in env_vars.items():
|
||||
if not any(env_name.startswith(prefix) for prefix in prefixes):
|
||||
continue
|
||||
# we remove the prefix before splitting in case the prefix has characters in common with the delimiter
|
||||
env_name_without_prefix = env_name[self.env_prefix_len :]
|
||||
_, *keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter)
|
||||
env_var = result
|
||||
for key in keys:
|
||||
env_var = env_var.setdefault(key, {})
|
||||
env_var[last_key] = env_val
|
||||
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'EnvSettingsSource(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, '
|
||||
f'env_nested_delimiter={self.env_nested_delimiter!r})'
|
||||
)
|
||||
|
||||
|
||||
class SecretsSettingsSource:
|
||||
__slots__ = ('secrets_dir',)
|
||||
|
||||
def __init__(self, secrets_dir: Optional[StrPath]):
|
||||
self.secrets_dir: Optional[StrPath] = secrets_dir
|
||||
|
||||
def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
|
||||
"""
|
||||
Build fields from "secrets" files.
|
||||
"""
|
||||
secrets: Dict[str, Optional[str]] = {}
|
||||
|
||||
if self.secrets_dir is None:
|
||||
return secrets
|
||||
|
||||
secrets_path = Path(self.secrets_dir).expanduser()
|
||||
|
||||
if not secrets_path.exists():
|
||||
warnings.warn(f'directory "{secrets_path}" does not exist')
|
||||
return secrets
|
||||
|
||||
if not secrets_path.is_dir():
|
||||
raise SettingsError(f'secrets_dir must reference a directory, not a {path_type(secrets_path)}')
|
||||
|
||||
for field in settings.__fields__.values():
|
||||
for env_name in field.field_info.extra['env_names']:
|
||||
path = find_case_path(secrets_path, env_name, settings.__config__.case_sensitive)
|
||||
if not path:
|
||||
# path does not exist, we curently don't return a warning for this
|
||||
continue
|
||||
|
||||
if path.is_file():
|
||||
secret_value = path.read_text().strip()
|
||||
if field.is_complex():
|
||||
try:
|
||||
secret_value = settings.__config__.parse_env_var(field.name, secret_value)
|
||||
except ValueError as e:
|
||||
raise SettingsError(f'error parsing env var "{env_name}"') from e
|
||||
|
||||
secrets[field.alias] = secret_value
|
||||
else:
|
||||
warnings.warn(
|
||||
f'attempted to load secret file "{path}" but found a {path_type(path)} instead.',
|
||||
stacklevel=4,
|
||||
)
|
||||
return secrets
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'SecretsSettingsSource(secrets_dir={self.secrets_dir!r})'
|
||||
|
||||
|
||||
def read_env_file(
|
||||
file_path: StrPath, *, encoding: str = None, case_sensitive: bool = False
|
||||
) -> Dict[str, Optional[str]]:
|
||||
try:
|
||||
from dotenv import dotenv_values
|
||||
except ImportError as e:
|
||||
raise ImportError('python-dotenv is not installed, run `pip install pydantic[dotenv]`') from e
|
||||
|
||||
file_vars: Dict[str, Optional[str]] = dotenv_values(file_path, encoding=encoding or 'utf8')
|
||||
if not case_sensitive:
|
||||
return {k.lower(): v for k, v in file_vars.items()}
|
||||
else:
|
||||
return file_vars
|
||||
|
||||
|
||||
def find_case_path(dir_path: Path, file_name: str, case_sensitive: bool) -> Optional[Path]:
|
||||
"""
|
||||
Find a file within path's directory matching filename, optionally ignoring case.
|
||||
"""
|
||||
for f in dir_path.iterdir():
|
||||
if f.name == file_name:
|
||||
return f
|
||||
elif not case_sensitive and f.name.lower() == file_name.lower():
|
||||
return f
|
||||
return None
|
162
lib/pydantic/error_wrappers.py
Normal file
162
lib/pydantic/error_wrappers.py
Normal file
|
@ -0,0 +1,162 @@
|
|||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple, Type, Union
|
||||
|
||||
from .json import pydantic_encoder
|
||||
from .utils import Representation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from .config import BaseConfig
|
||||
from .types import ModelOrDc
|
||||
from .typing import ReprArgs
|
||||
|
||||
Loc = Tuple[Union[int, str], ...]
|
||||
|
||||
class _ErrorDictRequired(TypedDict):
|
||||
loc: Loc
|
||||
msg: str
|
||||
type: str
|
||||
|
||||
class ErrorDict(_ErrorDictRequired, total=False):
|
||||
ctx: Dict[str, Any]
|
||||
|
||||
|
||||
__all__ = 'ErrorWrapper', 'ValidationError'
|
||||
|
||||
|
||||
class ErrorWrapper(Representation):
|
||||
__slots__ = 'exc', '_loc'
|
||||
|
||||
def __init__(self, exc: Exception, loc: Union[str, 'Loc']) -> None:
|
||||
self.exc = exc
|
||||
self._loc = loc
|
||||
|
||||
def loc_tuple(self) -> 'Loc':
|
||||
if isinstance(self._loc, tuple):
|
||||
return self._loc
|
||||
else:
|
||||
return (self._loc,)
|
||||
|
||||
def __repr_args__(self) -> 'ReprArgs':
|
||||
return [('exc', self.exc), ('loc', self.loc_tuple())]
|
||||
|
||||
|
||||
# ErrorList is something like Union[List[Union[List[ErrorWrapper], ErrorWrapper]], ErrorWrapper]
|
||||
# but recursive, therefore just use:
|
||||
ErrorList = Union[Sequence[Any], ErrorWrapper]
|
||||
|
||||
|
||||
class ValidationError(Representation, ValueError):
|
||||
__slots__ = 'raw_errors', 'model', '_error_cache'
|
||||
|
||||
def __init__(self, errors: Sequence[ErrorList], model: 'ModelOrDc') -> None:
|
||||
self.raw_errors = errors
|
||||
self.model = model
|
||||
self._error_cache: Optional[List['ErrorDict']] = None
|
||||
|
||||
def errors(self) -> List['ErrorDict']:
|
||||
if self._error_cache is None:
|
||||
try:
|
||||
config = self.model.__config__ # type: ignore
|
||||
except AttributeError:
|
||||
config = self.model.__pydantic_model__.__config__ # type: ignore
|
||||
self._error_cache = list(flatten_errors(self.raw_errors, config))
|
||||
return self._error_cache
|
||||
|
||||
def json(self, *, indent: Union[None, int, str] = 2) -> str:
|
||||
return json.dumps(self.errors(), indent=indent, default=pydantic_encoder)
|
||||
|
||||
def __str__(self) -> str:
|
||||
errors = self.errors()
|
||||
no_errors = len(errors)
|
||||
return (
|
||||
f'{no_errors} validation error{"" if no_errors == 1 else "s"} for {self.model.__name__}\n'
|
||||
f'{display_errors(errors)}'
|
||||
)
|
||||
|
||||
def __repr_args__(self) -> 'ReprArgs':
|
||||
return [('model', self.model.__name__), ('errors', self.errors())]
|
||||
|
||||
|
||||
def display_errors(errors: List['ErrorDict']) -> str:
|
||||
return '\n'.join(f'{_display_error_loc(e)}\n {e["msg"]} ({_display_error_type_and_ctx(e)})' for e in errors)
|
||||
|
||||
|
||||
def _display_error_loc(error: 'ErrorDict') -> str:
|
||||
return ' -> '.join(str(e) for e in error['loc'])
|
||||
|
||||
|
||||
def _display_error_type_and_ctx(error: 'ErrorDict') -> str:
|
||||
t = 'type=' + error['type']
|
||||
ctx = error.get('ctx')
|
||||
if ctx:
|
||||
return t + ''.join(f'; {k}={v}' for k, v in ctx.items())
|
||||
else:
|
||||
return t
|
||||
|
||||
|
||||
def flatten_errors(
|
||||
errors: Sequence[Any], config: Type['BaseConfig'], loc: Optional['Loc'] = None
|
||||
) -> Generator['ErrorDict', None, None]:
|
||||
for error in errors:
|
||||
if isinstance(error, ErrorWrapper):
|
||||
|
||||
if loc:
|
||||
error_loc = loc + error.loc_tuple()
|
||||
else:
|
||||
error_loc = error.loc_tuple()
|
||||
|
||||
if isinstance(error.exc, ValidationError):
|
||||
yield from flatten_errors(error.exc.raw_errors, config, error_loc)
|
||||
else:
|
||||
yield error_dict(error.exc, config, error_loc)
|
||||
elif isinstance(error, list):
|
||||
yield from flatten_errors(error, config, loc=loc)
|
||||
else:
|
||||
raise RuntimeError(f'Unknown error object: {error}')
|
||||
|
||||
|
||||
def error_dict(exc: Exception, config: Type['BaseConfig'], loc: 'Loc') -> 'ErrorDict':
|
||||
type_ = get_exc_type(exc.__class__)
|
||||
msg_template = config.error_msg_templates.get(type_) or getattr(exc, 'msg_template', None)
|
||||
ctx = exc.__dict__
|
||||
if msg_template:
|
||||
msg = msg_template.format(**ctx)
|
||||
else:
|
||||
msg = str(exc)
|
||||
|
||||
d: 'ErrorDict' = {'loc': loc, 'msg': msg, 'type': type_}
|
||||
|
||||
if ctx:
|
||||
d['ctx'] = ctx
|
||||
|
||||
return d
|
||||
|
||||
|
||||
_EXC_TYPE_CACHE: Dict[Type[Exception], str] = {}
|
||||
|
||||
|
||||
def get_exc_type(cls: Type[Exception]) -> str:
|
||||
# slightly more efficient than using lru_cache since we don't need to worry about the cache filling up
|
||||
try:
|
||||
return _EXC_TYPE_CACHE[cls]
|
||||
except KeyError:
|
||||
r = _get_exc_type(cls)
|
||||
_EXC_TYPE_CACHE[cls] = r
|
||||
return r
|
||||
|
||||
|
||||
def _get_exc_type(cls: Type[Exception]) -> str:
|
||||
if issubclass(cls, AssertionError):
|
||||
return 'assertion_error'
|
||||
|
||||
base_name = 'type_error' if issubclass(cls, TypeError) else 'value_error'
|
||||
if cls in (TypeError, ValueError):
|
||||
# just TypeError or ValueError, no extra code
|
||||
return base_name
|
||||
|
||||
# if it's not a TypeError or ValueError, we just take the lowercase of the exception name
|
||||
# no chaining or snake case logic, use "code" for more complex error types.
|
||||
code = getattr(cls, 'code', None) or cls.__name__.replace('Error', '').lower()
|
||||
return base_name + '.' + code
|
646
lib/pydantic/errors.py
Normal file
646
lib/pydantic/errors.py
Normal file
|
@ -0,0 +1,646 @@
|
|||
from decimal import Decimal
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Sequence, Set, Tuple, Type, Union
|
||||
|
||||
from .typing import display_as_type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .typing import DictStrAny
|
||||
|
||||
# explicitly state exports to avoid "from .errors import *" also importing Decimal, Path etc.
|
||||
__all__ = (
|
||||
'PydanticTypeError',
|
||||
'PydanticValueError',
|
||||
'ConfigError',
|
||||
'MissingError',
|
||||
'ExtraError',
|
||||
'NoneIsNotAllowedError',
|
||||
'NoneIsAllowedError',
|
||||
'WrongConstantError',
|
||||
'NotNoneError',
|
||||
'BoolError',
|
||||
'BytesError',
|
||||
'DictError',
|
||||
'EmailError',
|
||||
'UrlError',
|
||||
'UrlSchemeError',
|
||||
'UrlSchemePermittedError',
|
||||
'UrlUserInfoError',
|
||||
'UrlHostError',
|
||||
'UrlHostTldError',
|
||||
'UrlPortError',
|
||||
'UrlExtraError',
|
||||
'EnumError',
|
||||
'IntEnumError',
|
||||
'EnumMemberError',
|
||||
'IntegerError',
|
||||
'FloatError',
|
||||
'PathError',
|
||||
'PathNotExistsError',
|
||||
'PathNotAFileError',
|
||||
'PathNotADirectoryError',
|
||||
'PyObjectError',
|
||||
'SequenceError',
|
||||
'ListError',
|
||||
'SetError',
|
||||
'FrozenSetError',
|
||||
'TupleError',
|
||||
'TupleLengthError',
|
||||
'ListMinLengthError',
|
||||
'ListMaxLengthError',
|
||||
'ListUniqueItemsError',
|
||||
'SetMinLengthError',
|
||||
'SetMaxLengthError',
|
||||
'FrozenSetMinLengthError',
|
||||
'FrozenSetMaxLengthError',
|
||||
'AnyStrMinLengthError',
|
||||
'AnyStrMaxLengthError',
|
||||
'StrError',
|
||||
'StrRegexError',
|
||||
'NumberNotGtError',
|
||||
'NumberNotGeError',
|
||||
'NumberNotLtError',
|
||||
'NumberNotLeError',
|
||||
'NumberNotMultipleError',
|
||||
'DecimalError',
|
||||
'DecimalIsNotFiniteError',
|
||||
'DecimalMaxDigitsError',
|
||||
'DecimalMaxPlacesError',
|
||||
'DecimalWholeDigitsError',
|
||||
'DateTimeError',
|
||||
'DateError',
|
||||
'DateNotInThePastError',
|
||||
'DateNotInTheFutureError',
|
||||
'TimeError',
|
||||
'DurationError',
|
||||
'HashableError',
|
||||
'UUIDError',
|
||||
'UUIDVersionError',
|
||||
'ArbitraryTypeError',
|
||||
'ClassError',
|
||||
'SubclassError',
|
||||
'JsonError',
|
||||
'JsonTypeError',
|
||||
'PatternError',
|
||||
'DataclassTypeError',
|
||||
'CallableError',
|
||||
'IPvAnyAddressError',
|
||||
'IPvAnyInterfaceError',
|
||||
'IPvAnyNetworkError',
|
||||
'IPv4AddressError',
|
||||
'IPv6AddressError',
|
||||
'IPv4NetworkError',
|
||||
'IPv6NetworkError',
|
||||
'IPv4InterfaceError',
|
||||
'IPv6InterfaceError',
|
||||
'ColorError',
|
||||
'StrictBoolError',
|
||||
'NotDigitError',
|
||||
'LuhnValidationError',
|
||||
'InvalidLengthForBrand',
|
||||
'InvalidByteSize',
|
||||
'InvalidByteSizeUnit',
|
||||
'MissingDiscriminator',
|
||||
'InvalidDiscriminator',
|
||||
)
|
||||
|
||||
|
||||
def cls_kwargs(cls: Type['PydanticErrorMixin'], ctx: 'DictStrAny') -> 'PydanticErrorMixin':
|
||||
"""
|
||||
For built-in exceptions like ValueError or TypeError, we need to implement
|
||||
__reduce__ to override the default behaviour (instead of __getstate__/__setstate__)
|
||||
By default pickle protocol 2 calls `cls.__new__(cls, *args)`.
|
||||
Since we only use kwargs, we need a little constructor to change that.
|
||||
Note: the callable can't be a lambda as pickle looks in the namespace to find it
|
||||
"""
|
||||
return cls(**ctx)
|
||||
|
||||
|
||||
class PydanticErrorMixin:
|
||||
code: str
|
||||
msg_template: str
|
||||
|
||||
def __init__(self, **ctx: Any) -> None:
|
||||
self.__dict__ = ctx
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.msg_template.format(**self.__dict__)
|
||||
|
||||
def __reduce__(self) -> Tuple[Callable[..., 'PydanticErrorMixin'], Tuple[Type['PydanticErrorMixin'], 'DictStrAny']]:
|
||||
return cls_kwargs, (self.__class__, self.__dict__)
|
||||
|
||||
|
||||
class PydanticTypeError(PydanticErrorMixin, TypeError):
|
||||
pass
|
||||
|
||||
|
||||
class PydanticValueError(PydanticErrorMixin, ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ConfigError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
class MissingError(PydanticValueError):
|
||||
msg_template = 'field required'
|
||||
|
||||
|
||||
class ExtraError(PydanticValueError):
|
||||
msg_template = 'extra fields not permitted'
|
||||
|
||||
|
||||
class NoneIsNotAllowedError(PydanticTypeError):
|
||||
code = 'none.not_allowed'
|
||||
msg_template = 'none is not an allowed value'
|
||||
|
||||
|
||||
class NoneIsAllowedError(PydanticTypeError):
|
||||
code = 'none.allowed'
|
||||
msg_template = 'value is not none'
|
||||
|
||||
|
||||
class WrongConstantError(PydanticValueError):
|
||||
code = 'const'
|
||||
|
||||
def __str__(self) -> str:
|
||||
permitted = ', '.join(repr(v) for v in self.permitted) # type: ignore
|
||||
return f'unexpected value; permitted: {permitted}'
|
||||
|
||||
|
||||
class NotNoneError(PydanticTypeError):
|
||||
code = 'not_none'
|
||||
msg_template = 'value is not None'
|
||||
|
||||
|
||||
class BoolError(PydanticTypeError):
|
||||
msg_template = 'value could not be parsed to a boolean'
|
||||
|
||||
|
||||
class BytesError(PydanticTypeError):
|
||||
msg_template = 'byte type expected'
|
||||
|
||||
|
||||
class DictError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid dict'
|
||||
|
||||
|
||||
class EmailError(PydanticValueError):
|
||||
msg_template = 'value is not a valid email address'
|
||||
|
||||
|
||||
class UrlError(PydanticValueError):
|
||||
code = 'url'
|
||||
|
||||
|
||||
class UrlSchemeError(UrlError):
|
||||
code = 'url.scheme'
|
||||
msg_template = 'invalid or missing URL scheme'
|
||||
|
||||
|
||||
class UrlSchemePermittedError(UrlError):
|
||||
code = 'url.scheme'
|
||||
msg_template = 'URL scheme not permitted'
|
||||
|
||||
def __init__(self, allowed_schemes: Set[str]):
|
||||
super().__init__(allowed_schemes=allowed_schemes)
|
||||
|
||||
|
||||
class UrlUserInfoError(UrlError):
|
||||
code = 'url.userinfo'
|
||||
msg_template = 'userinfo required in URL but missing'
|
||||
|
||||
|
||||
class UrlHostError(UrlError):
|
||||
code = 'url.host'
|
||||
msg_template = 'URL host invalid'
|
||||
|
||||
|
||||
class UrlHostTldError(UrlError):
|
||||
code = 'url.host'
|
||||
msg_template = 'URL host invalid, top level domain required'
|
||||
|
||||
|
||||
class UrlPortError(UrlError):
|
||||
code = 'url.port'
|
||||
msg_template = 'URL port invalid, port cannot exceed 65535'
|
||||
|
||||
|
||||
class UrlExtraError(UrlError):
|
||||
code = 'url.extra'
|
||||
msg_template = 'URL invalid, extra characters found after valid URL: {extra!r}'
|
||||
|
||||
|
||||
class EnumMemberError(PydanticTypeError):
|
||||
code = 'enum'
|
||||
|
||||
def __str__(self) -> str:
|
||||
permitted = ', '.join(repr(v.value) for v in self.enum_values) # type: ignore
|
||||
return f'value is not a valid enumeration member; permitted: {permitted}'
|
||||
|
||||
|
||||
class IntegerError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid integer'
|
||||
|
||||
|
||||
class FloatError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid float'
|
||||
|
||||
|
||||
class PathError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid path'
|
||||
|
||||
|
||||
class _PathValueError(PydanticValueError):
|
||||
def __init__(self, *, path: Path) -> None:
|
||||
super().__init__(path=str(path))
|
||||
|
||||
|
||||
class PathNotExistsError(_PathValueError):
|
||||
code = 'path.not_exists'
|
||||
msg_template = 'file or directory at path "{path}" does not exist'
|
||||
|
||||
|
||||
class PathNotAFileError(_PathValueError):
|
||||
code = 'path.not_a_file'
|
||||
msg_template = 'path "{path}" does not point to a file'
|
||||
|
||||
|
||||
class PathNotADirectoryError(_PathValueError):
|
||||
code = 'path.not_a_directory'
|
||||
msg_template = 'path "{path}" does not point to a directory'
|
||||
|
||||
|
||||
class PyObjectError(PydanticTypeError):
|
||||
msg_template = 'ensure this value contains valid import path or valid callable: {error_message}'
|
||||
|
||||
|
||||
class SequenceError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid sequence'
|
||||
|
||||
|
||||
class IterableError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid iterable'
|
||||
|
||||
|
||||
class ListError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid list'
|
||||
|
||||
|
||||
class SetError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid set'
|
||||
|
||||
|
||||
class FrozenSetError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid frozenset'
|
||||
|
||||
|
||||
class DequeError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid deque'
|
||||
|
||||
|
||||
class TupleError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid tuple'
|
||||
|
||||
|
||||
class TupleLengthError(PydanticValueError):
|
||||
code = 'tuple.length'
|
||||
msg_template = 'wrong tuple length {actual_length}, expected {expected_length}'
|
||||
|
||||
def __init__(self, *, actual_length: int, expected_length: int) -> None:
|
||||
super().__init__(actual_length=actual_length, expected_length=expected_length)
|
||||
|
||||
|
||||
class ListMinLengthError(PydanticValueError):
|
||||
code = 'list.min_items'
|
||||
msg_template = 'ensure this value has at least {limit_value} items'
|
||||
|
||||
def __init__(self, *, limit_value: int) -> None:
|
||||
super().__init__(limit_value=limit_value)
|
||||
|
||||
|
||||
class ListMaxLengthError(PydanticValueError):
|
||||
code = 'list.max_items'
|
||||
msg_template = 'ensure this value has at most {limit_value} items'
|
||||
|
||||
def __init__(self, *, limit_value: int) -> None:
|
||||
super().__init__(limit_value=limit_value)
|
||||
|
||||
|
||||
class ListUniqueItemsError(PydanticValueError):
|
||||
code = 'list.unique_items'
|
||||
msg_template = 'the list has duplicated items'
|
||||
|
||||
|
||||
class SetMinLengthError(PydanticValueError):
|
||||
code = 'set.min_items'
|
||||
msg_template = 'ensure this value has at least {limit_value} items'
|
||||
|
||||
def __init__(self, *, limit_value: int) -> None:
|
||||
super().__init__(limit_value=limit_value)
|
||||
|
||||
|
||||
class SetMaxLengthError(PydanticValueError):
|
||||
code = 'set.max_items'
|
||||
msg_template = 'ensure this value has at most {limit_value} items'
|
||||
|
||||
def __init__(self, *, limit_value: int) -> None:
|
||||
super().__init__(limit_value=limit_value)
|
||||
|
||||
|
||||
class FrozenSetMinLengthError(PydanticValueError):
|
||||
code = 'frozenset.min_items'
|
||||
msg_template = 'ensure this value has at least {limit_value} items'
|
||||
|
||||
def __init__(self, *, limit_value: int) -> None:
|
||||
super().__init__(limit_value=limit_value)
|
||||
|
||||
|
||||
class FrozenSetMaxLengthError(PydanticValueError):
|
||||
code = 'frozenset.max_items'
|
||||
msg_template = 'ensure this value has at most {limit_value} items'
|
||||
|
||||
def __init__(self, *, limit_value: int) -> None:
|
||||
super().__init__(limit_value=limit_value)
|
||||
|
||||
|
||||
class AnyStrMinLengthError(PydanticValueError):
|
||||
code = 'any_str.min_length'
|
||||
msg_template = 'ensure this value has at least {limit_value} characters'
|
||||
|
||||
def __init__(self, *, limit_value: int) -> None:
|
||||
super().__init__(limit_value=limit_value)
|
||||
|
||||
|
||||
class AnyStrMaxLengthError(PydanticValueError):
|
||||
code = 'any_str.max_length'
|
||||
msg_template = 'ensure this value has at most {limit_value} characters'
|
||||
|
||||
def __init__(self, *, limit_value: int) -> None:
|
||||
super().__init__(limit_value=limit_value)
|
||||
|
||||
|
||||
class StrError(PydanticTypeError):
|
||||
msg_template = 'str type expected'
|
||||
|
||||
|
||||
class StrRegexError(PydanticValueError):
|
||||
code = 'str.regex'
|
||||
msg_template = 'string does not match regex "{pattern}"'
|
||||
|
||||
def __init__(self, *, pattern: str) -> None:
|
||||
super().__init__(pattern=pattern)
|
||||
|
||||
|
||||
class _NumberBoundError(PydanticValueError):
|
||||
def __init__(self, *, limit_value: Union[int, float, Decimal]) -> None:
|
||||
super().__init__(limit_value=limit_value)
|
||||
|
||||
|
||||
class NumberNotGtError(_NumberBoundError):
|
||||
code = 'number.not_gt'
|
||||
msg_template = 'ensure this value is greater than {limit_value}'
|
||||
|
||||
|
||||
class NumberNotGeError(_NumberBoundError):
|
||||
code = 'number.not_ge'
|
||||
msg_template = 'ensure this value is greater than or equal to {limit_value}'
|
||||
|
||||
|
||||
class NumberNotLtError(_NumberBoundError):
|
||||
code = 'number.not_lt'
|
||||
msg_template = 'ensure this value is less than {limit_value}'
|
||||
|
||||
|
||||
class NumberNotLeError(_NumberBoundError):
|
||||
code = 'number.not_le'
|
||||
msg_template = 'ensure this value is less than or equal to {limit_value}'
|
||||
|
||||
|
||||
class NumberNotFiniteError(PydanticValueError):
|
||||
code = 'number.not_finite_number'
|
||||
msg_template = 'ensure this value is a finite number'
|
||||
|
||||
|
||||
class NumberNotMultipleError(PydanticValueError):
|
||||
code = 'number.not_multiple'
|
||||
msg_template = 'ensure this value is a multiple of {multiple_of}'
|
||||
|
||||
def __init__(self, *, multiple_of: Union[int, float, Decimal]) -> None:
|
||||
super().__init__(multiple_of=multiple_of)
|
||||
|
||||
|
||||
class DecimalError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid decimal'
|
||||
|
||||
|
||||
class DecimalIsNotFiniteError(PydanticValueError):
|
||||
code = 'decimal.not_finite'
|
||||
msg_template = 'value is not a valid decimal'
|
||||
|
||||
|
||||
class DecimalMaxDigitsError(PydanticValueError):
|
||||
code = 'decimal.max_digits'
|
||||
msg_template = 'ensure that there are no more than {max_digits} digits in total'
|
||||
|
||||
def __init__(self, *, max_digits: int) -> None:
|
||||
super().__init__(max_digits=max_digits)
|
||||
|
||||
|
||||
class DecimalMaxPlacesError(PydanticValueError):
|
||||
code = 'decimal.max_places'
|
||||
msg_template = 'ensure that there are no more than {decimal_places} decimal places'
|
||||
|
||||
def __init__(self, *, decimal_places: int) -> None:
|
||||
super().__init__(decimal_places=decimal_places)
|
||||
|
||||
|
||||
class DecimalWholeDigitsError(PydanticValueError):
|
||||
code = 'decimal.whole_digits'
|
||||
msg_template = 'ensure that there are no more than {whole_digits} digits before the decimal point'
|
||||
|
||||
def __init__(self, *, whole_digits: int) -> None:
|
||||
super().__init__(whole_digits=whole_digits)
|
||||
|
||||
|
||||
class DateTimeError(PydanticValueError):
|
||||
msg_template = 'invalid datetime format'
|
||||
|
||||
|
||||
class DateError(PydanticValueError):
|
||||
msg_template = 'invalid date format'
|
||||
|
||||
|
||||
class DateNotInThePastError(PydanticValueError):
|
||||
code = 'date.not_in_the_past'
|
||||
msg_template = 'date is not in the past'
|
||||
|
||||
|
||||
class DateNotInTheFutureError(PydanticValueError):
|
||||
code = 'date.not_in_the_future'
|
||||
msg_template = 'date is not in the future'
|
||||
|
||||
|
||||
class TimeError(PydanticValueError):
|
||||
msg_template = 'invalid time format'
|
||||
|
||||
|
||||
class DurationError(PydanticValueError):
|
||||
msg_template = 'invalid duration format'
|
||||
|
||||
|
||||
class HashableError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid hashable'
|
||||
|
||||
|
||||
class UUIDError(PydanticTypeError):
|
||||
msg_template = 'value is not a valid uuid'
|
||||
|
||||
|
||||
class UUIDVersionError(PydanticValueError):
|
||||
code = 'uuid.version'
|
||||
msg_template = 'uuid version {required_version} expected'
|
||||
|
||||
def __init__(self, *, required_version: int) -> None:
|
||||
super().__init__(required_version=required_version)
|
||||
|
||||
|
||||
class ArbitraryTypeError(PydanticTypeError):
|
||||
code = 'arbitrary_type'
|
||||
msg_template = 'instance of {expected_arbitrary_type} expected'
|
||||
|
||||
def __init__(self, *, expected_arbitrary_type: Type[Any]) -> None:
|
||||
super().__init__(expected_arbitrary_type=display_as_type(expected_arbitrary_type))
|
||||
|
||||
|
||||
class ClassError(PydanticTypeError):
|
||||
code = 'class'
|
||||
msg_template = 'a class is expected'
|
||||
|
||||
|
||||
class SubclassError(PydanticTypeError):
|
||||
code = 'subclass'
|
||||
msg_template = 'subclass of {expected_class} expected'
|
||||
|
||||
def __init__(self, *, expected_class: Type[Any]) -> None:
|
||||
super().__init__(expected_class=display_as_type(expected_class))
|
||||
|
||||
|
||||
class JsonError(PydanticValueError):
|
||||
msg_template = 'Invalid JSON'
|
||||
|
||||
|
||||
class JsonTypeError(PydanticTypeError):
|
||||
code = 'json'
|
||||
msg_template = 'JSON object must be str, bytes or bytearray'
|
||||
|
||||
|
||||
class PatternError(PydanticValueError):
|
||||
code = 'regex_pattern'
|
||||
msg_template = 'Invalid regular expression'
|
||||
|
||||
|
||||
class DataclassTypeError(PydanticTypeError):
|
||||
code = 'dataclass'
|
||||
msg_template = 'instance of {class_name}, tuple or dict expected'
|
||||
|
||||
|
||||
class CallableError(PydanticTypeError):
|
||||
msg_template = '{value} is not callable'
|
||||
|
||||
|
||||
class EnumError(PydanticTypeError):
|
||||
code = 'enum_instance'
|
||||
msg_template = '{value} is not a valid Enum instance'
|
||||
|
||||
|
||||
class IntEnumError(PydanticTypeError):
|
||||
code = 'int_enum_instance'
|
||||
msg_template = '{value} is not a valid IntEnum instance'
|
||||
|
||||
|
||||
class IPvAnyAddressError(PydanticValueError):
|
||||
msg_template = 'value is not a valid IPv4 or IPv6 address'
|
||||
|
||||
|
||||
class IPvAnyInterfaceError(PydanticValueError):
|
||||
msg_template = 'value is not a valid IPv4 or IPv6 interface'
|
||||
|
||||
|
||||
class IPvAnyNetworkError(PydanticValueError):
|
||||
msg_template = 'value is not a valid IPv4 or IPv6 network'
|
||||
|
||||
|
||||
class IPv4AddressError(PydanticValueError):
|
||||
msg_template = 'value is not a valid IPv4 address'
|
||||
|
||||
|
||||
class IPv6AddressError(PydanticValueError):
|
||||
msg_template = 'value is not a valid IPv6 address'
|
||||
|
||||
|
||||
class IPv4NetworkError(PydanticValueError):
|
||||
msg_template = 'value is not a valid IPv4 network'
|
||||
|
||||
|
||||
class IPv6NetworkError(PydanticValueError):
|
||||
msg_template = 'value is not a valid IPv6 network'
|
||||
|
||||
|
||||
class IPv4InterfaceError(PydanticValueError):
|
||||
msg_template = 'value is not a valid IPv4 interface'
|
||||
|
||||
|
||||
class IPv6InterfaceError(PydanticValueError):
|
||||
msg_template = 'value is not a valid IPv6 interface'
|
||||
|
||||
|
||||
class ColorError(PydanticValueError):
|
||||
msg_template = 'value is not a valid color: {reason}'
|
||||
|
||||
|
||||
class StrictBoolError(PydanticValueError):
|
||||
msg_template = 'value is not a valid boolean'
|
||||
|
||||
|
||||
class NotDigitError(PydanticValueError):
|
||||
code = 'payment_card_number.digits'
|
||||
msg_template = 'card number is not all digits'
|
||||
|
||||
|
||||
class LuhnValidationError(PydanticValueError):
|
||||
code = 'payment_card_number.luhn_check'
|
||||
msg_template = 'card number is not luhn valid'
|
||||
|
||||
|
||||
class InvalidLengthForBrand(PydanticValueError):
|
||||
code = 'payment_card_number.invalid_length_for_brand'
|
||||
msg_template = 'Length for a {brand} card must be {required_length}'
|
||||
|
||||
|
||||
class InvalidByteSize(PydanticValueError):
|
||||
msg_template = 'could not parse value and unit from byte string'
|
||||
|
||||
|
||||
class InvalidByteSizeUnit(PydanticValueError):
|
||||
msg_template = 'could not interpret byte unit: {unit}'
|
||||
|
||||
|
||||
class MissingDiscriminator(PydanticValueError):
|
||||
code = 'discriminated_union.missing_discriminator'
|
||||
msg_template = 'Discriminator {discriminator_key!r} is missing in value'
|
||||
|
||||
|
||||
class InvalidDiscriminator(PydanticValueError):
|
||||
code = 'discriminated_union.invalid_discriminator'
|
||||
msg_template = (
|
||||
'No match for discriminator {discriminator_key!r} and value {discriminator_value!r} '
|
||||
'(allowed values: {allowed_values})'
|
||||
)
|
||||
|
||||
def __init__(self, *, discriminator_key: str, discriminator_value: Any, allowed_values: Sequence[Any]) -> None:
|
||||
super().__init__(
|
||||
discriminator_key=discriminator_key,
|
||||
discriminator_value=discriminator_value,
|
||||
allowed_values=', '.join(map(repr, allowed_values)),
|
||||
)
|
1247
lib/pydantic/fields.py
Normal file
1247
lib/pydantic/fields.py
Normal file
File diff suppressed because it is too large
Load diff
364
lib/pydantic/generics.py
Normal file
364
lib/pydantic/generics.py
Normal file
|
@ -0,0 +1,364 @@
|
|||
import sys
|
||||
import typing
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from .class_validators import gather_all_validators
|
||||
from .fields import DeferredType
|
||||
from .main import BaseModel, create_model
|
||||
from .types import JsonWrapper
|
||||
from .typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base
|
||||
from .utils import LimitedDict, all_identical, lenient_issubclass
|
||||
|
||||
GenericModelT = TypeVar('GenericModelT', bound='GenericModel')
|
||||
TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type
|
||||
|
||||
Parametrization = Mapping[TypeVarType, Type[Any]]
|
||||
|
||||
_generic_types_cache: LimitedDict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel]] = LimitedDict()
|
||||
# _assigned_parameters is a Mapping from parametrized version of generic models to assigned types of parametrizations
|
||||
# as captured during construction of the class (not instances).
|
||||
# E.g., for generic model `Model[A, B]`, when parametrized model `Model[int, str]` is created,
|
||||
# `Model[int, str]`: {A: int, B: str}` will be stored in `_assigned_parameters`.
|
||||
# (This information is only otherwise available after creation from the class name string).
|
||||
_assigned_parameters: LimitedDict[Type[Any], Parametrization] = LimitedDict()
|
||||
|
||||
|
||||
class GenericModel(BaseModel):
|
||||
__slots__ = ()
|
||||
__concrete__: ClassVar[bool] = False
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Putting this in a TYPE_CHECKING block allows us to replace `if Generic not in cls.__bases__` with
|
||||
# `not hasattr(cls, "__parameters__")`. This means we don't need to force non-concrete subclasses of
|
||||
# `GenericModel` to also inherit from `Generic`, which would require changes to the use of `create_model` below.
|
||||
__parameters__: ClassVar[Tuple[TypeVarType, ...]]
|
||||
|
||||
# Setting the return type as Type[Any] instead of Type[BaseModel] prevents PyCharm warnings
|
||||
def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[Type[Any], ...]]) -> Type[Any]:
|
||||
"""Instantiates a new class from a generic class `cls` and type variables `params`.
|
||||
|
||||
:param params: Tuple of types the class . Given a generic class
|
||||
`Model` with 2 type variables and a concrete model `Model[str, int]`,
|
||||
the value `(str, int)` would be passed to `params`.
|
||||
:return: New model class inheriting from `cls` with instantiated
|
||||
types described by `params`. If no parameters are given, `cls` is
|
||||
returned as is.
|
||||
|
||||
"""
|
||||
|
||||
def _cache_key(_params: Any) -> Tuple[Type[GenericModelT], Any, Tuple[Any, ...]]:
|
||||
return cls, _params, get_args(_params)
|
||||
|
||||
cached = _generic_types_cache.get(_cache_key(params))
|
||||
if cached is not None:
|
||||
return cached
|
||||
if cls.__concrete__ and Generic not in cls.__bases__:
|
||||
raise TypeError('Cannot parameterize a concrete instantiation of a generic model')
|
||||
if not isinstance(params, tuple):
|
||||
params = (params,)
|
||||
if cls is GenericModel and any(isinstance(param, TypeVar) for param in params):
|
||||
raise TypeError('Type parameters should be placed on typing.Generic, not GenericModel')
|
||||
if not hasattr(cls, '__parameters__'):
|
||||
raise TypeError(f'Type {cls.__name__} must inherit from typing.Generic before being parameterized')
|
||||
|
||||
check_parameters_count(cls, params)
|
||||
# Build map from generic typevars to passed params
|
||||
typevars_map: Dict[TypeVarType, Type[Any]] = dict(zip(cls.__parameters__, params))
|
||||
if all_identical(typevars_map.keys(), typevars_map.values()) and typevars_map:
|
||||
return cls # if arguments are equal to parameters it's the same object
|
||||
|
||||
# Create new model with original model as parent inserting fields with DeferredType.
|
||||
model_name = cls.__concrete_name__(params)
|
||||
validators = gather_all_validators(cls)
|
||||
|
||||
type_hints = get_all_type_hints(cls).items()
|
||||
instance_type_hints = {k: v for k, v in type_hints if get_origin(v) is not ClassVar}
|
||||
|
||||
fields = {k: (DeferredType(), cls.__fields__[k].field_info) for k in instance_type_hints if k in cls.__fields__}
|
||||
|
||||
model_module, called_globally = get_caller_frame_info()
|
||||
created_model = cast(
|
||||
Type[GenericModel], # casting ensures mypy is aware of the __concrete__ and __parameters__ attributes
|
||||
create_model(
|
||||
model_name,
|
||||
__module__=model_module or cls.__module__,
|
||||
__base__=(cls,) + tuple(cls.__parameterized_bases__(typevars_map)),
|
||||
__config__=None,
|
||||
__validators__=validators,
|
||||
__cls_kwargs__=None,
|
||||
**fields,
|
||||
),
|
||||
)
|
||||
|
||||
_assigned_parameters[created_model] = typevars_map
|
||||
|
||||
if called_globally: # create global reference and therefore allow pickling
|
||||
object_by_reference = None
|
||||
reference_name = model_name
|
||||
reference_module_globals = sys.modules[created_model.__module__].__dict__
|
||||
while object_by_reference is not created_model:
|
||||
object_by_reference = reference_module_globals.setdefault(reference_name, created_model)
|
||||
reference_name += '_'
|
||||
|
||||
created_model.Config = cls.Config
|
||||
|
||||
# Find any typevars that are still present in the model.
|
||||
# If none are left, the model is fully "concrete", otherwise the new
|
||||
# class is a generic class as well taking the found typevars as
|
||||
# parameters.
|
||||
new_params = tuple(
|
||||
{param: None for param in iter_contained_typevars(typevars_map.values())}
|
||||
) # use dict as ordered set
|
||||
created_model.__concrete__ = not new_params
|
||||
if new_params:
|
||||
created_model.__parameters__ = new_params
|
||||
|
||||
# Save created model in cache so we don't end up creating duplicate
|
||||
# models that should be identical.
|
||||
_generic_types_cache[_cache_key(params)] = created_model
|
||||
if len(params) == 1:
|
||||
_generic_types_cache[_cache_key(params[0])] = created_model
|
||||
|
||||
# Recursively walk class type hints and replace generic typevars
|
||||
# with concrete types that were passed.
|
||||
_prepare_model_fields(created_model, fields, instance_type_hints, typevars_map)
|
||||
|
||||
return created_model
|
||||
|
||||
@classmethod
|
||||
def __concrete_name__(cls: Type[Any], params: Tuple[Type[Any], ...]) -> str:
|
||||
"""Compute class name for child classes.
|
||||
|
||||
:param params: Tuple of types the class . Given a generic class
|
||||
`Model` with 2 type variables and a concrete model `Model[str, int]`,
|
||||
the value `(str, int)` would be passed to `params`.
|
||||
:return: String representing a the new class where `params` are
|
||||
passed to `cls` as type variables.
|
||||
|
||||
This method can be overridden to achieve a custom naming scheme for GenericModels.
|
||||
"""
|
||||
param_names = [display_as_type(param) for param in params]
|
||||
params_component = ', '.join(param_names)
|
||||
return f'{cls.__name__}[{params_component}]'
|
||||
|
||||
@classmethod
|
||||
def __parameterized_bases__(cls, typevars_map: Parametrization) -> Iterator[Type[Any]]:
|
||||
"""
|
||||
Returns unbound bases of cls parameterised to given type variables
|
||||
|
||||
:param typevars_map: Dictionary of type applications for binding subclasses.
|
||||
Given a generic class `Model` with 2 type variables [S, T]
|
||||
and a concrete model `Model[str, int]`,
|
||||
the value `{S: str, T: int}` would be passed to `typevars_map`.
|
||||
:return: an iterator of generic sub classes, parameterised by `typevars_map`
|
||||
and other assigned parameters of `cls`
|
||||
|
||||
e.g.:
|
||||
```
|
||||
class A(GenericModel, Generic[T]):
|
||||
...
|
||||
|
||||
class B(A[V], Generic[V]):
|
||||
...
|
||||
|
||||
assert A[int] in B.__parameterized_bases__({V: int})
|
||||
```
|
||||
"""
|
||||
|
||||
def build_base_model(
|
||||
base_model: Type[GenericModel], mapped_types: Parametrization
|
||||
) -> Iterator[Type[GenericModel]]:
|
||||
base_parameters = tuple(mapped_types[param] for param in base_model.__parameters__)
|
||||
parameterized_base = base_model.__class_getitem__(base_parameters)
|
||||
if parameterized_base is base_model or parameterized_base is cls:
|
||||
# Avoid duplication in MRO
|
||||
return
|
||||
yield parameterized_base
|
||||
|
||||
for base_model in cls.__bases__:
|
||||
if not issubclass(base_model, GenericModel):
|
||||
# not a class that can be meaningfully parameterized
|
||||
continue
|
||||
elif not getattr(base_model, '__parameters__', None):
|
||||
# base_model is "GenericModel" (and has no __parameters__)
|
||||
# or
|
||||
# base_model is already concrete, and will be included transitively via cls.
|
||||
continue
|
||||
elif cls in _assigned_parameters:
|
||||
if base_model in _assigned_parameters:
|
||||
# cls is partially parameterised but not from base_model
|
||||
# e.g. cls = B[S], base_model = A[S]
|
||||
# B[S][int] should subclass A[int], (and will be transitively via B[int])
|
||||
# but it's not viable to consistently subclass types with arbitrary construction
|
||||
# So don't attempt to include A[S][int]
|
||||
continue
|
||||
else: # base_model not in _assigned_parameters:
|
||||
# cls is partially parameterized, base_model is original generic
|
||||
# e.g. cls = B[str, T], base_model = B[S, T]
|
||||
# Need to determine the mapping for the base_model parameters
|
||||
mapped_types: Parametrization = {
|
||||
key: typevars_map.get(value, value) for key, value in _assigned_parameters[cls].items()
|
||||
}
|
||||
yield from build_base_model(base_model, mapped_types)
|
||||
else:
|
||||
# cls is base generic, so base_class has a distinct base
|
||||
# can construct the Parameterised base model using typevars_map directly
|
||||
yield from build_base_model(base_model, typevars_map)
|
||||
|
||||
|
||||
def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any:
|
||||
"""Return type with all occurrences of `type_map` keys recursively replaced with their values.
|
||||
|
||||
:param type_: Any type, class or generic alias
|
||||
:param type_map: Mapping from `TypeVar` instance to concrete types.
|
||||
:return: New type representing the basic structure of `type_` with all
|
||||
`typevar_map` keys recursively replaced.
|
||||
|
||||
>>> replace_types(Tuple[str, Union[List[str], float]], {str: int})
|
||||
Tuple[int, Union[List[int], float]]
|
||||
|
||||
"""
|
||||
if not type_map:
|
||||
return type_
|
||||
|
||||
type_args = get_args(type_)
|
||||
origin_type = get_origin(type_)
|
||||
|
||||
if origin_type is Annotated:
|
||||
annotated_type, *annotations = type_args
|
||||
return Annotated[replace_types(annotated_type, type_map), tuple(annotations)]
|
||||
|
||||
# Having type args is a good indicator that this is a typing module
|
||||
# class instantiation or a generic alias of some sort.
|
||||
if type_args:
|
||||
resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
|
||||
if all_identical(type_args, resolved_type_args):
|
||||
# If all arguments are the same, there is no need to modify the
|
||||
# type or create a new object at all
|
||||
return type_
|
||||
if (
|
||||
origin_type is not None
|
||||
and isinstance(type_, typing_base)
|
||||
and not isinstance(origin_type, typing_base)
|
||||
and getattr(type_, '_name', None) is not None
|
||||
):
|
||||
# In python < 3.9 generic aliases don't exist so any of these like `list`,
|
||||
# `type` or `collections.abc.Callable` need to be translated.
|
||||
# See: https://www.python.org/dev/peps/pep-0585
|
||||
origin_type = getattr(typing, type_._name)
|
||||
assert origin_type is not None
|
||||
return origin_type[resolved_type_args]
|
||||
|
||||
# We handle pydantic generic models separately as they don't have the same
|
||||
# semantics as "typing" classes or generic aliases
|
||||
if not origin_type and lenient_issubclass(type_, GenericModel) and not type_.__concrete__:
|
||||
type_args = type_.__parameters__
|
||||
resolved_type_args = tuple(replace_types(t, type_map) for t in type_args)
|
||||
if all_identical(type_args, resolved_type_args):
|
||||
return type_
|
||||
return type_[resolved_type_args]
|
||||
|
||||
# Handle special case for typehints that can have lists as arguments.
|
||||
# `typing.Callable[[int, str], int]` is an example for this.
|
||||
if isinstance(type_, (List, list)):
|
||||
resolved_list = list(replace_types(element, type_map) for element in type_)
|
||||
if all_identical(type_, resolved_list):
|
||||
return type_
|
||||
return resolved_list
|
||||
|
||||
# For JsonWrapperValue, need to handle its inner type to allow correct parsing
|
||||
# of generic Json arguments like Json[T]
|
||||
if not origin_type and lenient_issubclass(type_, JsonWrapper):
|
||||
type_.inner_type = replace_types(type_.inner_type, type_map)
|
||||
return type_
|
||||
|
||||
# If all else fails, we try to resolve the type directly and otherwise just
|
||||
# return the input with no modifications.
|
||||
return type_map.get(type_, type_)
|
||||
|
||||
|
||||
def check_parameters_count(cls: Type[GenericModel], parameters: Tuple[Any, ...]) -> None:
|
||||
actual = len(parameters)
|
||||
expected = len(cls.__parameters__)
|
||||
if actual != expected:
|
||||
description = 'many' if actual > expected else 'few'
|
||||
raise TypeError(f'Too {description} parameters for {cls.__name__}; actual {actual}, expected {expected}')
|
||||
|
||||
|
||||
DictValues: Type[Any] = {}.values().__class__
|
||||
|
||||
|
||||
def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]:
|
||||
"""Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found."""
|
||||
if isinstance(v, TypeVar):
|
||||
yield v
|
||||
elif hasattr(v, '__parameters__') and not get_origin(v) and lenient_issubclass(v, GenericModel):
|
||||
yield from v.__parameters__
|
||||
elif isinstance(v, (DictValues, list)):
|
||||
for var in v:
|
||||
yield from iter_contained_typevars(var)
|
||||
else:
|
||||
args = get_args(v)
|
||||
for arg in args:
|
||||
yield from iter_contained_typevars(arg)
|
||||
|
||||
|
||||
def get_caller_frame_info() -> Tuple[Optional[str], bool]:
|
||||
"""
|
||||
Used inside a function to check whether it was called globally
|
||||
|
||||
Will only work against non-compiled code, therefore used only in pydantic.generics
|
||||
|
||||
:returns Tuple[module_name, called_globally]
|
||||
"""
|
||||
try:
|
||||
previous_caller_frame = sys._getframe(2)
|
||||
except ValueError as e:
|
||||
raise RuntimeError('This function must be used inside another function') from e
|
||||
except AttributeError: # sys module does not have _getframe function, so there's nothing we can do about it
|
||||
return None, False
|
||||
frame_globals = previous_caller_frame.f_globals
|
||||
return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals
|
||||
|
||||
|
||||
def _prepare_model_fields(
|
||||
created_model: Type[GenericModel],
|
||||
fields: Mapping[str, Any],
|
||||
instance_type_hints: Mapping[str, type],
|
||||
typevars_map: Mapping[Any, type],
|
||||
) -> None:
|
||||
"""
|
||||
Replace DeferredType fields with concrete type hints and prepare them.
|
||||
"""
|
||||
|
||||
for key, field in created_model.__fields__.items():
|
||||
if key not in fields:
|
||||
assert field.type_.__class__ is not DeferredType
|
||||
# https://github.com/nedbat/coveragepy/issues/198
|
||||
continue # pragma: no cover
|
||||
|
||||
assert field.type_.__class__ is DeferredType, field.type_.__class__
|
||||
|
||||
field_type_hint = instance_type_hints[key]
|
||||
concrete_type = replace_types(field_type_hint, typevars_map)
|
||||
field.type_ = concrete_type
|
||||
field.outer_type_ = concrete_type
|
||||
field.prepare()
|
||||
created_model.__annotations__[key] = concrete_type
|
112
lib/pydantic/json.py
Normal file
112
lib/pydantic/json.py
Normal file
|
@ -0,0 +1,112 @@
|
|||
import datetime
|
||||
from collections import deque
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
|
||||
from pathlib import Path
|
||||
from re import Pattern
|
||||
from types import GeneratorType
|
||||
from typing import Any, Callable, Dict, Type, Union
|
||||
from uuid import UUID
|
||||
|
||||
from .color import Color
|
||||
from .networks import NameEmail
|
||||
from .types import SecretBytes, SecretStr
|
||||
|
||||
__all__ = 'pydantic_encoder', 'custom_pydantic_encoder', 'timedelta_isoformat'
|
||||
|
||||
|
||||
def isoformat(o: Union[datetime.date, datetime.time]) -> str:
|
||||
return o.isoformat()
|
||||
|
||||
|
||||
def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
|
||||
"""
|
||||
Encodes a Decimal as int of there's no exponent, otherwise float
|
||||
|
||||
This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
|
||||
where a integer (but not int typed) is used. Encoding this as a float
|
||||
results in failed round-tripping between encode and parse.
|
||||
Our Id type is a prime example of this.
|
||||
|
||||
>>> decimal_encoder(Decimal("1.0"))
|
||||
1.0
|
||||
|
||||
>>> decimal_encoder(Decimal("1"))
|
||||
1
|
||||
"""
|
||||
if dec_value.as_tuple().exponent >= 0:
|
||||
return int(dec_value)
|
||||
else:
|
||||
return float(dec_value)
|
||||
|
||||
|
||||
ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
|
||||
bytes: lambda o: o.decode(),
|
||||
Color: str,
|
||||
datetime.date: isoformat,
|
||||
datetime.datetime: isoformat,
|
||||
datetime.time: isoformat,
|
||||
datetime.timedelta: lambda td: td.total_seconds(),
|
||||
Decimal: decimal_encoder,
|
||||
Enum: lambda o: o.value,
|
||||
frozenset: list,
|
||||
deque: list,
|
||||
GeneratorType: list,
|
||||
IPv4Address: str,
|
||||
IPv4Interface: str,
|
||||
IPv4Network: str,
|
||||
IPv6Address: str,
|
||||
IPv6Interface: str,
|
||||
IPv6Network: str,
|
||||
NameEmail: str,
|
||||
Path: str,
|
||||
Pattern: lambda o: o.pattern,
|
||||
SecretBytes: str,
|
||||
SecretStr: str,
|
||||
set: list,
|
||||
UUID: str,
|
||||
}
|
||||
|
||||
|
||||
def pydantic_encoder(obj: Any) -> Any:
|
||||
from dataclasses import asdict, is_dataclass
|
||||
|
||||
from .main import BaseModel
|
||||
|
||||
if isinstance(obj, BaseModel):
|
||||
return obj.dict()
|
||||
elif is_dataclass(obj):
|
||||
return asdict(obj)
|
||||
|
||||
# Check the class type and its superclasses for a matching encoder
|
||||
for base in obj.__class__.__mro__[:-1]:
|
||||
try:
|
||||
encoder = ENCODERS_BY_TYPE[base]
|
||||
except KeyError:
|
||||
continue
|
||||
return encoder(obj)
|
||||
else: # We have exited the for loop without finding a suitable encoder
|
||||
raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable")
|
||||
|
||||
|
||||
def custom_pydantic_encoder(type_encoders: Dict[Any, Callable[[Type[Any]], Any]], obj: Any) -> Any:
|
||||
# Check the class type and its superclasses for a matching encoder
|
||||
for base in obj.__class__.__mro__[:-1]:
|
||||
try:
|
||||
encoder = type_encoders[base]
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
return encoder(obj)
|
||||
else: # We have exited the for loop without finding a suitable encoder
|
||||
return pydantic_encoder(obj)
|
||||
|
||||
|
||||
def timedelta_isoformat(td: datetime.timedelta) -> str:
|
||||
"""
|
||||
ISO 8601 encoding for Python timedelta object.
|
||||
"""
|
||||
minutes, seconds = divmod(td.seconds, 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
return f'{"-" if td.days < 0 else ""}P{abs(td.days)}DT{hours:d}H{minutes:d}M{seconds:d}.{td.microseconds:06d}S'
|
1109
lib/pydantic/main.py
Normal file
1109
lib/pydantic/main.py
Normal file
File diff suppressed because it is too large
Load diff
850
lib/pydantic/mypy.py
Normal file
850
lib/pydantic/mypy.py
Normal file
|
@ -0,0 +1,850 @@
|
|||
import sys
|
||||
from configparser import ConfigParser
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type as TypingType, Union
|
||||
|
||||
from mypy.errorcodes import ErrorCode
|
||||
from mypy.nodes import (
|
||||
ARG_NAMED,
|
||||
ARG_NAMED_OPT,
|
||||
ARG_OPT,
|
||||
ARG_POS,
|
||||
ARG_STAR2,
|
||||
MDEF,
|
||||
Argument,
|
||||
AssignmentStmt,
|
||||
Block,
|
||||
CallExpr,
|
||||
ClassDef,
|
||||
Context,
|
||||
Decorator,
|
||||
EllipsisExpr,
|
||||
FuncBase,
|
||||
FuncDef,
|
||||
JsonDict,
|
||||
MemberExpr,
|
||||
NameExpr,
|
||||
PassStmt,
|
||||
PlaceholderNode,
|
||||
RefExpr,
|
||||
StrExpr,
|
||||
SymbolNode,
|
||||
SymbolTableNode,
|
||||
TempNode,
|
||||
TypeInfo,
|
||||
TypeVarExpr,
|
||||
Var,
|
||||
)
|
||||
from mypy.options import Options
|
||||
from mypy.plugin import (
|
||||
CheckerPluginInterface,
|
||||
ClassDefContext,
|
||||
FunctionContext,
|
||||
MethodContext,
|
||||
Plugin,
|
||||
SemanticAnalyzerPluginInterface,
|
||||
)
|
||||
from mypy.plugins import dataclasses
|
||||
from mypy.semanal import set_callable_name # type: ignore
|
||||
from mypy.server.trigger import make_wildcard_trigger
|
||||
from mypy.types import (
|
||||
AnyType,
|
||||
CallableType,
|
||||
Instance,
|
||||
NoneType,
|
||||
Overloaded,
|
||||
Type,
|
||||
TypeOfAny,
|
||||
TypeType,
|
||||
TypeVarType,
|
||||
UnionType,
|
||||
get_proper_type,
|
||||
)
|
||||
from mypy.typevars import fill_typevars
|
||||
from mypy.util import get_unique_redefinition_name
|
||||
from mypy.version import __version__ as mypy_version
|
||||
|
||||
from pydantic.utils import is_valid_field
|
||||
|
||||
try:
|
||||
from mypy.types import TypeVarDef # type: ignore[attr-defined]
|
||||
except ImportError: # pragma: no cover
|
||||
# Backward-compatible with TypeVarDef from Mypy 0.910.
|
||||
from mypy.types import TypeVarType as TypeVarDef
|
||||
|
||||
CONFIGFILE_KEY = 'pydantic-mypy'
|
||||
METADATA_KEY = 'pydantic-mypy-metadata'
|
||||
BASEMODEL_FULLNAME = 'pydantic.main.BaseModel'
|
||||
BASESETTINGS_FULLNAME = 'pydantic.env_settings.BaseSettings'
|
||||
FIELD_FULLNAME = 'pydantic.fields.Field'
|
||||
DATACLASS_FULLNAME = 'pydantic.dataclasses.dataclass'
|
||||
|
||||
|
||||
def parse_mypy_version(version: str) -> Tuple[int, ...]:
|
||||
return tuple(int(part) for part in version.split('+', 1)[0].split('.'))
|
||||
|
||||
|
||||
MYPY_VERSION_TUPLE = parse_mypy_version(mypy_version)
|
||||
BUILTINS_NAME = 'builtins' if MYPY_VERSION_TUPLE >= (0, 930) else '__builtins__'
|
||||
|
||||
|
||||
def plugin(version: str) -> 'TypingType[Plugin]':
|
||||
"""
|
||||
`version` is the mypy version string
|
||||
|
||||
We might want to use this to print a warning if the mypy version being used is
|
||||
newer, or especially older, than we expect (or need).
|
||||
"""
|
||||
return PydanticPlugin
|
||||
|
||||
|
||||
class PydanticPlugin(Plugin):
|
||||
def __init__(self, options: Options) -> None:
|
||||
self.plugin_config = PydanticPluginConfig(options)
|
||||
super().__init__(options)
|
||||
|
||||
def get_base_class_hook(self, fullname: str) -> 'Optional[Callable[[ClassDefContext], None]]':
|
||||
sym = self.lookup_fully_qualified(fullname)
|
||||
if sym and isinstance(sym.node, TypeInfo): # pragma: no branch
|
||||
# No branching may occur if the mypy cache has not been cleared
|
||||
if any(get_fullname(base) == BASEMODEL_FULLNAME for base in sym.node.mro):
|
||||
return self._pydantic_model_class_maker_callback
|
||||
return None
|
||||
|
||||
def get_function_hook(self, fullname: str) -> 'Optional[Callable[[FunctionContext], Type]]':
|
||||
sym = self.lookup_fully_qualified(fullname)
|
||||
if sym and sym.fullname == FIELD_FULLNAME:
|
||||
return self._pydantic_field_callback
|
||||
return None
|
||||
|
||||
def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], Type]]:
|
||||
if fullname.endswith('.from_orm'):
|
||||
return from_orm_callback
|
||||
return None
|
||||
|
||||
def get_class_decorator_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
|
||||
if fullname == DATACLASS_FULLNAME:
|
||||
return dataclasses.dataclass_class_maker_callback # type: ignore[return-value]
|
||||
return None
|
||||
|
||||
def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> None:
|
||||
transformer = PydanticModelTransformer(ctx, self.plugin_config)
|
||||
transformer.transform()
|
||||
|
||||
def _pydantic_field_callback(self, ctx: FunctionContext) -> 'Type':
|
||||
"""
|
||||
Extract the type of the `default` argument from the Field function, and use it as the return type.
|
||||
|
||||
In particular:
|
||||
* Check whether the default and default_factory argument is specified.
|
||||
* Output an error if both are specified.
|
||||
* Retrieve the type of the argument which is specified, and use it as return type for the function.
|
||||
"""
|
||||
default_any_type = ctx.default_return_type
|
||||
|
||||
assert ctx.callee_arg_names[0] == 'default', '"default" is no longer first argument in Field()'
|
||||
assert ctx.callee_arg_names[1] == 'default_factory', '"default_factory" is no longer second argument in Field()'
|
||||
default_args = ctx.args[0]
|
||||
default_factory_args = ctx.args[1]
|
||||
|
||||
if default_args and default_factory_args:
|
||||
error_default_and_default_factory_specified(ctx.api, ctx.context)
|
||||
return default_any_type
|
||||
|
||||
if default_args:
|
||||
default_type = ctx.arg_types[0][0]
|
||||
default_arg = default_args[0]
|
||||
|
||||
# Fallback to default Any type if the field is required
|
||||
if not isinstance(default_arg, EllipsisExpr):
|
||||
return default_type
|
||||
|
||||
elif default_factory_args:
|
||||
default_factory_type = ctx.arg_types[1][0]
|
||||
|
||||
# Functions which use `ParamSpec` can be overloaded, exposing the callable's types as a parameter
|
||||
# Pydantic calls the default factory without any argument, so we retrieve the first item
|
||||
if isinstance(default_factory_type, Overloaded):
|
||||
if MYPY_VERSION_TUPLE > (0, 910):
|
||||
default_factory_type = default_factory_type.items[0]
|
||||
else:
|
||||
# Mypy0.910 exposes the items of overloaded types in a function
|
||||
default_factory_type = default_factory_type.items()[0] # type: ignore[operator]
|
||||
|
||||
if isinstance(default_factory_type, CallableType):
|
||||
ret_type = default_factory_type.ret_type
|
||||
# mypy doesn't think `ret_type` has `args`, you'd think mypy should know,
|
||||
# add this check in case it varies by version
|
||||
args = getattr(ret_type, 'args', None)
|
||||
if args:
|
||||
if all(isinstance(arg, TypeVarType) for arg in args):
|
||||
# Looks like the default factory is a type like `list` or `dict`, replace all args with `Any`
|
||||
ret_type.args = tuple(default_any_type for _ in args) # type: ignore[attr-defined]
|
||||
return ret_type
|
||||
|
||||
return default_any_type
|
||||
|
||||
|
||||
class PydanticPluginConfig:
|
||||
__slots__ = ('init_forbid_extra', 'init_typed', 'warn_required_dynamic_aliases', 'warn_untyped_fields')
|
||||
init_forbid_extra: bool
|
||||
init_typed: bool
|
||||
warn_required_dynamic_aliases: bool
|
||||
warn_untyped_fields: bool
|
||||
|
||||
def __init__(self, options: Options) -> None:
|
||||
if options.config_file is None: # pragma: no cover
|
||||
return
|
||||
|
||||
toml_config = parse_toml(options.config_file)
|
||||
if toml_config is not None:
|
||||
config = toml_config.get('tool', {}).get('pydantic-mypy', {})
|
||||
for key in self.__slots__:
|
||||
setting = config.get(key, False)
|
||||
if not isinstance(setting, bool):
|
||||
raise ValueError(f'Configuration value must be a boolean for key: {key}')
|
||||
setattr(self, key, setting)
|
||||
else:
|
||||
plugin_config = ConfigParser()
|
||||
plugin_config.read(options.config_file)
|
||||
for key in self.__slots__:
|
||||
setting = plugin_config.getboolean(CONFIGFILE_KEY, key, fallback=False)
|
||||
setattr(self, key, setting)
|
||||
|
||||
|
||||
def from_orm_callback(ctx: MethodContext) -> Type:
|
||||
"""
|
||||
Raise an error if orm_mode is not enabled
|
||||
"""
|
||||
model_type: Instance
|
||||
if isinstance(ctx.type, CallableType) and isinstance(ctx.type.ret_type, Instance):
|
||||
model_type = ctx.type.ret_type # called on the class
|
||||
elif isinstance(ctx.type, Instance):
|
||||
model_type = ctx.type # called on an instance (unusual, but still valid)
|
||||
else: # pragma: no cover
|
||||
detail = f'ctx.type: {ctx.type} (of type {ctx.type.__class__.__name__})'
|
||||
error_unexpected_behavior(detail, ctx.api, ctx.context)
|
||||
return ctx.default_return_type
|
||||
pydantic_metadata = model_type.type.metadata.get(METADATA_KEY)
|
||||
if pydantic_metadata is None:
|
||||
return ctx.default_return_type
|
||||
orm_mode = pydantic_metadata.get('config', {}).get('orm_mode')
|
||||
if orm_mode is not True:
|
||||
error_from_orm(get_name(model_type.type), ctx.api, ctx.context)
|
||||
return ctx.default_return_type
|
||||
|
||||
|
||||
class PydanticModelTransformer:
|
||||
tracked_config_fields: Set[str] = {
|
||||
'extra',
|
||||
'allow_mutation',
|
||||
'frozen',
|
||||
'orm_mode',
|
||||
'allow_population_by_field_name',
|
||||
'alias_generator',
|
||||
}
|
||||
|
||||
def __init__(self, ctx: ClassDefContext, plugin_config: PydanticPluginConfig) -> None:
|
||||
self._ctx = ctx
|
||||
self.plugin_config = plugin_config
|
||||
|
||||
def transform(self) -> None:
|
||||
"""
|
||||
Configures the BaseModel subclass according to the plugin settings.
|
||||
|
||||
In particular:
|
||||
* determines the model config and fields,
|
||||
* adds a fields-aware signature for the initializer and construct methods
|
||||
* freezes the class if allow_mutation = False or frozen = True
|
||||
* stores the fields, config, and if the class is settings in the mypy metadata for access by subclasses
|
||||
"""
|
||||
ctx = self._ctx
|
||||
info = self._ctx.cls.info
|
||||
|
||||
self.adjust_validator_signatures()
|
||||
config = self.collect_config()
|
||||
fields = self.collect_fields(config)
|
||||
for field in fields:
|
||||
if info[field.name].type is None:
|
||||
if not ctx.api.final_iteration:
|
||||
ctx.api.defer()
|
||||
is_settings = any(get_fullname(base) == BASESETTINGS_FULLNAME for base in info.mro[:-1])
|
||||
self.add_initializer(fields, config, is_settings)
|
||||
self.add_construct_method(fields)
|
||||
self.set_frozen(fields, frozen=config.allow_mutation is False or config.frozen is True)
|
||||
info.metadata[METADATA_KEY] = {
|
||||
'fields': {field.name: field.serialize() for field in fields},
|
||||
'config': config.set_values_dict(),
|
||||
}
|
||||
|
||||
def adjust_validator_signatures(self) -> None:
|
||||
"""When we decorate a function `f` with `pydantic.validator(...), mypy sees
|
||||
`f` as a regular method taking a `self` instance, even though pydantic
|
||||
internally wraps `f` with `classmethod` if necessary.
|
||||
|
||||
Teach mypy this by marking any function whose outermost decorator is a
|
||||
`validator()` call as a classmethod.
|
||||
"""
|
||||
for name, sym in self._ctx.cls.info.names.items():
|
||||
if isinstance(sym.node, Decorator):
|
||||
first_dec = sym.node.original_decorators[0]
|
||||
if (
|
||||
isinstance(first_dec, CallExpr)
|
||||
and isinstance(first_dec.callee, NameExpr)
|
||||
and first_dec.callee.fullname == 'pydantic.class_validators.validator'
|
||||
):
|
||||
sym.node.func.is_class = True
|
||||
|
||||
def collect_config(self) -> 'ModelConfigData':
|
||||
"""
|
||||
Collects the values of the config attributes that are used by the plugin, accounting for parent classes.
|
||||
"""
|
||||
ctx = self._ctx
|
||||
cls = ctx.cls
|
||||
config = ModelConfigData()
|
||||
for stmt in cls.defs.body:
|
||||
if not isinstance(stmt, ClassDef):
|
||||
continue
|
||||
if stmt.name == 'Config':
|
||||
for substmt in stmt.defs.body:
|
||||
if not isinstance(substmt, AssignmentStmt):
|
||||
continue
|
||||
config.update(self.get_config_update(substmt))
|
||||
if (
|
||||
config.has_alias_generator
|
||||
and not config.allow_population_by_field_name
|
||||
and self.plugin_config.warn_required_dynamic_aliases
|
||||
):
|
||||
error_required_dynamic_aliases(ctx.api, stmt)
|
||||
for info in cls.info.mro[1:]: # 0 is the current class
|
||||
if METADATA_KEY not in info.metadata:
|
||||
continue
|
||||
|
||||
# Each class depends on the set of fields in its ancestors
|
||||
ctx.api.add_plugin_dependency(make_wildcard_trigger(get_fullname(info)))
|
||||
for name, value in info.metadata[METADATA_KEY]['config'].items():
|
||||
config.setdefault(name, value)
|
||||
return config
|
||||
|
||||
def collect_fields(self, model_config: 'ModelConfigData') -> List['PydanticModelField']:
|
||||
"""
|
||||
Collects the fields for the model, accounting for parent classes
|
||||
"""
|
||||
# First, collect fields belonging to the current class.
|
||||
ctx = self._ctx
|
||||
cls = self._ctx.cls
|
||||
fields = [] # type: List[PydanticModelField]
|
||||
known_fields = set() # type: Set[str]
|
||||
for stmt in cls.defs.body:
|
||||
if not isinstance(stmt, AssignmentStmt): # `and stmt.new_syntax` to require annotation
|
||||
continue
|
||||
|
||||
lhs = stmt.lvalues[0]
|
||||
if not isinstance(lhs, NameExpr) or not is_valid_field(lhs.name):
|
||||
continue
|
||||
|
||||
if not stmt.new_syntax and self.plugin_config.warn_untyped_fields:
|
||||
error_untyped_fields(ctx.api, stmt)
|
||||
|
||||
# if lhs.name == '__config__': # BaseConfig not well handled; I'm not sure why yet
|
||||
# continue
|
||||
|
||||
sym = cls.info.names.get(lhs.name)
|
||||
if sym is None: # pragma: no cover
|
||||
# This is likely due to a star import (see the dataclasses plugin for a more detailed explanation)
|
||||
# This is the same logic used in the dataclasses plugin
|
||||
continue
|
||||
|
||||
node = sym.node
|
||||
if isinstance(node, PlaceholderNode): # pragma: no cover
|
||||
# See the PlaceholderNode docstring for more detail about how this can occur
|
||||
# Basically, it is an edge case when dealing with complex import logic
|
||||
# This is the same logic used in the dataclasses plugin
|
||||
continue
|
||||
if not isinstance(node, Var): # pragma: no cover
|
||||
# Don't know if this edge case still happens with the `is_valid_field` check above
|
||||
# but better safe than sorry
|
||||
continue
|
||||
|
||||
# x: ClassVar[int] is ignored by dataclasses.
|
||||
if node.is_classvar:
|
||||
continue
|
||||
|
||||
is_required = self.get_is_required(cls, stmt, lhs)
|
||||
alias, has_dynamic_alias = self.get_alias_info(stmt)
|
||||
if (
|
||||
has_dynamic_alias
|
||||
and not model_config.allow_population_by_field_name
|
||||
and self.plugin_config.warn_required_dynamic_aliases
|
||||
):
|
||||
error_required_dynamic_aliases(ctx.api, stmt)
|
||||
fields.append(
|
||||
PydanticModelField(
|
||||
name=lhs.name,
|
||||
is_required=is_required,
|
||||
alias=alias,
|
||||
has_dynamic_alias=has_dynamic_alias,
|
||||
line=stmt.line,
|
||||
column=stmt.column,
|
||||
)
|
||||
)
|
||||
known_fields.add(lhs.name)
|
||||
all_fields = fields.copy()
|
||||
for info in cls.info.mro[1:]: # 0 is the current class, -2 is BaseModel, -1 is object
|
||||
if METADATA_KEY not in info.metadata:
|
||||
continue
|
||||
|
||||
superclass_fields = []
|
||||
# Each class depends on the set of fields in its ancestors
|
||||
ctx.api.add_plugin_dependency(make_wildcard_trigger(get_fullname(info)))
|
||||
|
||||
for name, data in info.metadata[METADATA_KEY]['fields'].items():
|
||||
if name not in known_fields:
|
||||
field = PydanticModelField.deserialize(info, data)
|
||||
known_fields.add(name)
|
||||
superclass_fields.append(field)
|
||||
else:
|
||||
(field,) = (a for a in all_fields if a.name == name)
|
||||
all_fields.remove(field)
|
||||
superclass_fields.append(field)
|
||||
all_fields = superclass_fields + all_fields
|
||||
return all_fields
|
||||
|
||||
def add_initializer(self, fields: List['PydanticModelField'], config: 'ModelConfigData', is_settings: bool) -> None:
|
||||
"""
|
||||
Adds a fields-aware `__init__` method to the class.
|
||||
|
||||
The added `__init__` will be annotated with types vs. all `Any` depending on the plugin settings.
|
||||
"""
|
||||
ctx = self._ctx
|
||||
typed = self.plugin_config.init_typed
|
||||
use_alias = config.allow_population_by_field_name is not True
|
||||
force_all_optional = is_settings or bool(
|
||||
config.has_alias_generator and not config.allow_population_by_field_name
|
||||
)
|
||||
init_arguments = self.get_field_arguments(
|
||||
fields, typed=typed, force_all_optional=force_all_optional, use_alias=use_alias
|
||||
)
|
||||
if not self.should_init_forbid_extra(fields, config):
|
||||
var = Var('kwargs')
|
||||
init_arguments.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2))
|
||||
|
||||
if '__init__' not in ctx.cls.info.names:
|
||||
add_method(ctx, '__init__', init_arguments, NoneType())
|
||||
|
||||
def add_construct_method(self, fields: List['PydanticModelField']) -> None:
|
||||
"""
|
||||
Adds a fully typed `construct` classmethod to the class.
|
||||
|
||||
Similar to the fields-aware __init__ method, but always uses the field names (not aliases),
|
||||
and does not treat settings fields as optional.
|
||||
"""
|
||||
ctx = self._ctx
|
||||
set_str = ctx.api.named_type(f'{BUILTINS_NAME}.set', [ctx.api.named_type(f'{BUILTINS_NAME}.str')])
|
||||
optional_set_str = UnionType([set_str, NoneType()])
|
||||
fields_set_argument = Argument(Var('_fields_set', optional_set_str), optional_set_str, None, ARG_OPT)
|
||||
construct_arguments = self.get_field_arguments(fields, typed=True, force_all_optional=False, use_alias=False)
|
||||
construct_arguments = [fields_set_argument] + construct_arguments
|
||||
|
||||
obj_type = ctx.api.named_type(f'{BUILTINS_NAME}.object')
|
||||
self_tvar_name = '_PydanticBaseModel' # Make sure it does not conflict with other names in the class
|
||||
tvar_fullname = ctx.cls.fullname + '.' + self_tvar_name
|
||||
tvd = TypeVarDef(self_tvar_name, tvar_fullname, -1, [], obj_type)
|
||||
self_tvar_expr = TypeVarExpr(self_tvar_name, tvar_fullname, [], obj_type)
|
||||
ctx.cls.info.names[self_tvar_name] = SymbolTableNode(MDEF, self_tvar_expr)
|
||||
|
||||
# Backward-compatible with TypeVarDef from Mypy 0.910.
|
||||
if isinstance(tvd, TypeVarType):
|
||||
self_type = tvd
|
||||
else:
|
||||
self_type = TypeVarType(tvd) # type: ignore[call-arg]
|
||||
|
||||
add_method(
|
||||
ctx,
|
||||
'construct',
|
||||
construct_arguments,
|
||||
return_type=self_type,
|
||||
self_type=self_type,
|
||||
tvar_def=tvd,
|
||||
is_classmethod=True,
|
||||
)
|
||||
|
||||
def set_frozen(self, fields: List['PydanticModelField'], frozen: bool) -> None:
|
||||
"""
|
||||
Marks all fields as properties so that attempts to set them trigger mypy errors.
|
||||
|
||||
This is the same approach used by the attrs and dataclasses plugins.
|
||||
"""
|
||||
info = self._ctx.cls.info
|
||||
for field in fields:
|
||||
sym_node = info.names.get(field.name)
|
||||
if sym_node is not None:
|
||||
var = sym_node.node
|
||||
assert isinstance(var, Var)
|
||||
var.is_property = frozen
|
||||
else:
|
||||
var = field.to_var(info, use_alias=False)
|
||||
var.info = info
|
||||
var.is_property = frozen
|
||||
var._fullname = get_fullname(info) + '.' + get_name(var)
|
||||
info.names[get_name(var)] = SymbolTableNode(MDEF, var)
|
||||
|
||||
def get_config_update(self, substmt: AssignmentStmt) -> Optional['ModelConfigData']:
|
||||
"""
|
||||
Determines the config update due to a single statement in the Config class definition.
|
||||
|
||||
Warns if a tracked config attribute is set to a value the plugin doesn't know how to interpret (e.g., an int)
|
||||
"""
|
||||
lhs = substmt.lvalues[0]
|
||||
if not (isinstance(lhs, NameExpr) and lhs.name in self.tracked_config_fields):
|
||||
return None
|
||||
if lhs.name == 'extra':
|
||||
if isinstance(substmt.rvalue, StrExpr):
|
||||
forbid_extra = substmt.rvalue.value == 'forbid'
|
||||
elif isinstance(substmt.rvalue, MemberExpr):
|
||||
forbid_extra = substmt.rvalue.name == 'forbid'
|
||||
else:
|
||||
error_invalid_config_value(lhs.name, self._ctx.api, substmt)
|
||||
return None
|
||||
return ModelConfigData(forbid_extra=forbid_extra)
|
||||
if lhs.name == 'alias_generator':
|
||||
has_alias_generator = True
|
||||
if isinstance(substmt.rvalue, NameExpr) and substmt.rvalue.fullname == 'builtins.None':
|
||||
has_alias_generator = False
|
||||
return ModelConfigData(has_alias_generator=has_alias_generator)
|
||||
if isinstance(substmt.rvalue, NameExpr) and substmt.rvalue.fullname in ('builtins.True', 'builtins.False'):
|
||||
return ModelConfigData(**{lhs.name: substmt.rvalue.fullname == 'builtins.True'})
|
||||
error_invalid_config_value(lhs.name, self._ctx.api, substmt)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_is_required(cls: ClassDef, stmt: AssignmentStmt, lhs: NameExpr) -> bool:
|
||||
"""
|
||||
Returns a boolean indicating whether the field defined in `stmt` is a required field.
|
||||
"""
|
||||
expr = stmt.rvalue
|
||||
if isinstance(expr, TempNode):
|
||||
# TempNode means annotation-only, so only non-required if Optional
|
||||
value_type = get_proper_type(cls.info[lhs.name].type)
|
||||
if isinstance(value_type, UnionType) and any(isinstance(item, NoneType) for item in value_type.items):
|
||||
# Annotated as Optional, or otherwise having NoneType in the union
|
||||
return False
|
||||
return True
|
||||
if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME:
|
||||
# The "default value" is a call to `Field`; at this point, the field is
|
||||
# only required if default is Ellipsis (i.e., `field_name: Annotation = Field(...)`) or if default_factory
|
||||
# is specified.
|
||||
for arg, name in zip(expr.args, expr.arg_names):
|
||||
# If name is None, then this arg is the default because it is the only positonal argument.
|
||||
if name is None or name == 'default':
|
||||
return arg.__class__ is EllipsisExpr
|
||||
if name == 'default_factory':
|
||||
return False
|
||||
return True
|
||||
# Only required if the "default value" is Ellipsis (i.e., `field_name: Annotation = ...`)
|
||||
return isinstance(expr, EllipsisExpr)
|
||||
|
||||
@staticmethod
|
||||
def get_alias_info(stmt: AssignmentStmt) -> Tuple[Optional[str], bool]:
|
||||
"""
|
||||
Returns a pair (alias, has_dynamic_alias), extracted from the declaration of the field defined in `stmt`.
|
||||
|
||||
`has_dynamic_alias` is True if and only if an alias is provided, but not as a string literal.
|
||||
If `has_dynamic_alias` is True, `alias` will be None.
|
||||
"""
|
||||
expr = stmt.rvalue
|
||||
if isinstance(expr, TempNode):
|
||||
# TempNode means annotation-only
|
||||
return None, False
|
||||
|
||||
if not (
|
||||
isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME
|
||||
):
|
||||
# Assigned value is not a call to pydantic.fields.Field
|
||||
return None, False
|
||||
|
||||
for i, arg_name in enumerate(expr.arg_names):
|
||||
if arg_name != 'alias':
|
||||
continue
|
||||
arg = expr.args[i]
|
||||
if isinstance(arg, StrExpr):
|
||||
return arg.value, False
|
||||
else:
|
||||
return None, True
|
||||
return None, False
|
||||
|
||||
def get_field_arguments(
|
||||
self, fields: List['PydanticModelField'], typed: bool, force_all_optional: bool, use_alias: bool
|
||||
) -> List[Argument]:
|
||||
"""
|
||||
Helper function used during the construction of the `__init__` and `construct` method signatures.
|
||||
|
||||
Returns a list of mypy Argument instances for use in the generated signatures.
|
||||
"""
|
||||
info = self._ctx.cls.info
|
||||
arguments = [
|
||||
field.to_argument(info, typed=typed, force_optional=force_all_optional, use_alias=use_alias)
|
||||
for field in fields
|
||||
if not (use_alias and field.has_dynamic_alias)
|
||||
]
|
||||
return arguments
|
||||
|
||||
def should_init_forbid_extra(self, fields: List['PydanticModelField'], config: 'ModelConfigData') -> bool:
|
||||
"""
|
||||
Indicates whether the generated `__init__` should get a `**kwargs` at the end of its signature
|
||||
|
||||
We disallow arbitrary kwargs if the extra config setting is "forbid", or if the plugin config says to,
|
||||
*unless* a required dynamic alias is present (since then we can't determine a valid signature).
|
||||
"""
|
||||
if not config.allow_population_by_field_name:
|
||||
if self.is_dynamic_alias_present(fields, bool(config.has_alias_generator)):
|
||||
return False
|
||||
if config.forbid_extra:
|
||||
return True
|
||||
return self.plugin_config.init_forbid_extra
|
||||
|
||||
@staticmethod
|
||||
def is_dynamic_alias_present(fields: List['PydanticModelField'], has_alias_generator: bool) -> bool:
|
||||
"""
|
||||
Returns whether any fields on the model have a "dynamic alias", i.e., an alias that cannot be
|
||||
determined during static analysis.
|
||||
"""
|
||||
for field in fields:
|
||||
if field.has_dynamic_alias:
|
||||
return True
|
||||
if has_alias_generator:
|
||||
for field in fields:
|
||||
if field.alias is None:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class PydanticModelField:
|
||||
def __init__(
|
||||
self, name: str, is_required: bool, alias: Optional[str], has_dynamic_alias: bool, line: int, column: int
|
||||
):
|
||||
self.name = name
|
||||
self.is_required = is_required
|
||||
self.alias = alias
|
||||
self.has_dynamic_alias = has_dynamic_alias
|
||||
self.line = line
|
||||
self.column = column
|
||||
|
||||
def to_var(self, info: TypeInfo, use_alias: bool) -> Var:
|
||||
name = self.name
|
||||
if use_alias and self.alias is not None:
|
||||
name = self.alias
|
||||
return Var(name, info[self.name].type)
|
||||
|
||||
def to_argument(self, info: TypeInfo, typed: bool, force_optional: bool, use_alias: bool) -> Argument:
|
||||
if typed and info[self.name].type is not None:
|
||||
type_annotation = info[self.name].type
|
||||
else:
|
||||
type_annotation = AnyType(TypeOfAny.explicit)
|
||||
return Argument(
|
||||
variable=self.to_var(info, use_alias),
|
||||
type_annotation=type_annotation,
|
||||
initializer=None,
|
||||
kind=ARG_NAMED_OPT if force_optional or not self.is_required else ARG_NAMED,
|
||||
)
|
||||
|
||||
def serialize(self) -> JsonDict:
|
||||
return self.__dict__
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, info: TypeInfo, data: JsonDict) -> 'PydanticModelField':
|
||||
return cls(**data)
|
||||
|
||||
|
||||
class ModelConfigData:
|
||||
def __init__(
|
||||
self,
|
||||
forbid_extra: Optional[bool] = None,
|
||||
allow_mutation: Optional[bool] = None,
|
||||
frozen: Optional[bool] = None,
|
||||
orm_mode: Optional[bool] = None,
|
||||
allow_population_by_field_name: Optional[bool] = None,
|
||||
has_alias_generator: Optional[bool] = None,
|
||||
):
|
||||
self.forbid_extra = forbid_extra
|
||||
self.allow_mutation = allow_mutation
|
||||
self.frozen = frozen
|
||||
self.orm_mode = orm_mode
|
||||
self.allow_population_by_field_name = allow_population_by_field_name
|
||||
self.has_alias_generator = has_alias_generator
|
||||
|
||||
def set_values_dict(self) -> Dict[str, Any]:
|
||||
return {k: v for k, v in self.__dict__.items() if v is not None}
|
||||
|
||||
def update(self, config: Optional['ModelConfigData']) -> None:
|
||||
if config is None:
|
||||
return
|
||||
for k, v in config.set_values_dict().items():
|
||||
setattr(self, k, v)
|
||||
|
||||
def setdefault(self, key: str, value: Any) -> None:
|
||||
if getattr(self, key) is None:
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
ERROR_ORM = ErrorCode('pydantic-orm', 'Invalid from_orm call', 'Pydantic')
|
||||
ERROR_CONFIG = ErrorCode('pydantic-config', 'Invalid config value', 'Pydantic')
|
||||
ERROR_ALIAS = ErrorCode('pydantic-alias', 'Dynamic alias disallowed', 'Pydantic')
|
||||
ERROR_UNEXPECTED = ErrorCode('pydantic-unexpected', 'Unexpected behavior', 'Pydantic')
|
||||
ERROR_UNTYPED = ErrorCode('pydantic-field', 'Untyped field disallowed', 'Pydantic')
|
||||
ERROR_FIELD_DEFAULTS = ErrorCode('pydantic-field', 'Invalid Field defaults', 'Pydantic')
|
||||
|
||||
|
||||
def error_from_orm(model_name: str, api: CheckerPluginInterface, context: Context) -> None:
|
||||
api.fail(f'"{model_name}" does not have orm_mode=True', context, code=ERROR_ORM)
|
||||
|
||||
|
||||
def error_invalid_config_value(name: str, api: SemanticAnalyzerPluginInterface, context: Context) -> None:
|
||||
api.fail(f'Invalid value for "Config.{name}"', context, code=ERROR_CONFIG)
|
||||
|
||||
|
||||
def error_required_dynamic_aliases(api: SemanticAnalyzerPluginInterface, context: Context) -> None:
|
||||
api.fail('Required dynamic aliases disallowed', context, code=ERROR_ALIAS)
|
||||
|
||||
|
||||
def error_unexpected_behavior(detail: str, api: CheckerPluginInterface, context: Context) -> None: # pragma: no cover
|
||||
# Can't think of a good way to test this, but I confirmed it renders as desired by adding to a non-error path
|
||||
link = 'https://github.com/pydantic/pydantic/issues/new/choose'
|
||||
full_message = f'The pydantic mypy plugin ran into unexpected behavior: {detail}\n'
|
||||
full_message += f'Please consider reporting this bug at {link} so we can try to fix it!'
|
||||
api.fail(full_message, context, code=ERROR_UNEXPECTED)
|
||||
|
||||
|
||||
def error_untyped_fields(api: SemanticAnalyzerPluginInterface, context: Context) -> None:
|
||||
api.fail('Untyped fields disallowed', context, code=ERROR_UNTYPED)
|
||||
|
||||
|
||||
def error_default_and_default_factory_specified(api: CheckerPluginInterface, context: Context) -> None:
|
||||
api.fail('Field default and default_factory cannot be specified together', context, code=ERROR_FIELD_DEFAULTS)
|
||||
|
||||
|
||||
def add_method(
|
||||
ctx: ClassDefContext,
|
||||
name: str,
|
||||
args: List[Argument],
|
||||
return_type: Type,
|
||||
self_type: Optional[Type] = None,
|
||||
tvar_def: Optional[TypeVarDef] = None,
|
||||
is_classmethod: bool = False,
|
||||
is_new: bool = False,
|
||||
# is_staticmethod: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Adds a new method to a class.
|
||||
|
||||
This can be dropped if/when https://github.com/python/mypy/issues/7301 is merged
|
||||
"""
|
||||
info = ctx.cls.info
|
||||
|
||||
# First remove any previously generated methods with the same name
|
||||
# to avoid clashes and problems in the semantic analyzer.
|
||||
if name in info.names:
|
||||
sym = info.names[name]
|
||||
if sym.plugin_generated and isinstance(sym.node, FuncDef):
|
||||
ctx.cls.defs.body.remove(sym.node) # pragma: no cover
|
||||
|
||||
self_type = self_type or fill_typevars(info)
|
||||
if is_classmethod or is_new:
|
||||
first = [Argument(Var('_cls'), TypeType.make_normalized(self_type), None, ARG_POS)]
|
||||
# elif is_staticmethod:
|
||||
# first = []
|
||||
else:
|
||||
self_type = self_type or fill_typevars(info)
|
||||
first = [Argument(Var('__pydantic_self__'), self_type, None, ARG_POS)]
|
||||
args = first + args
|
||||
arg_types, arg_names, arg_kinds = [], [], []
|
||||
for arg in args:
|
||||
assert arg.type_annotation, 'All arguments must be fully typed.'
|
||||
arg_types.append(arg.type_annotation)
|
||||
arg_names.append(get_name(arg.variable))
|
||||
arg_kinds.append(arg.kind)
|
||||
|
||||
function_type = ctx.api.named_type(f'{BUILTINS_NAME}.function')
|
||||
signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type)
|
||||
if tvar_def:
|
||||
signature.variables = [tvar_def]
|
||||
|
||||
func = FuncDef(name, args, Block([PassStmt()]))
|
||||
func.info = info
|
||||
func.type = set_callable_name(signature, func)
|
||||
func.is_class = is_classmethod
|
||||
# func.is_static = is_staticmethod
|
||||
func._fullname = get_fullname(info) + '.' + name
|
||||
func.line = info.line
|
||||
|
||||
# NOTE: we would like the plugin generated node to dominate, but we still
|
||||
# need to keep any existing definitions so they get semantically analyzed.
|
||||
if name in info.names:
|
||||
# Get a nice unique name instead.
|
||||
r_name = get_unique_redefinition_name(name, info.names)
|
||||
info.names[r_name] = info.names[name]
|
||||
|
||||
if is_classmethod: # or is_staticmethod:
|
||||
func.is_decorated = True
|
||||
v = Var(name, func.type)
|
||||
v.info = info
|
||||
v._fullname = func._fullname
|
||||
# if is_classmethod:
|
||||
v.is_classmethod = True
|
||||
dec = Decorator(func, [NameExpr('classmethod')], v)
|
||||
# else:
|
||||
# v.is_staticmethod = True
|
||||
# dec = Decorator(func, [NameExpr('staticmethod')], v)
|
||||
|
||||
dec.line = info.line
|
||||
sym = SymbolTableNode(MDEF, dec)
|
||||
else:
|
||||
sym = SymbolTableNode(MDEF, func)
|
||||
sym.plugin_generated = True
|
||||
|
||||
info.names[name] = sym
|
||||
info.defn.defs.body.append(func)
|
||||
|
||||
|
||||
def get_fullname(x: Union[FuncBase, SymbolNode]) -> str:
|
||||
"""
|
||||
Used for compatibility with mypy 0.740; can be dropped once support for 0.740 is dropped.
|
||||
"""
|
||||
fn = x.fullname
|
||||
if callable(fn): # pragma: no cover
|
||||
return fn()
|
||||
return fn
|
||||
|
||||
|
||||
def get_name(x: Union[FuncBase, SymbolNode]) -> str:
|
||||
"""
|
||||
Used for compatibility with mypy 0.740; can be dropped once support for 0.740 is dropped.
|
||||
"""
|
||||
fn = x.name
|
||||
if callable(fn): # pragma: no cover
|
||||
return fn()
|
||||
return fn
|
||||
|
||||
|
||||
def parse_toml(config_file: str) -> Optional[Dict[str, Any]]:
|
||||
if not config_file.endswith('.toml'):
|
||||
return None
|
||||
|
||||
read_mode = 'rb'
|
||||
if sys.version_info >= (3, 11):
|
||||
import tomllib as toml_
|
||||
else:
|
||||
try:
|
||||
import tomli as toml_
|
||||
except ImportError:
|
||||
# older versions of mypy have toml as a dependency, not tomli
|
||||
read_mode = 'r'
|
||||
try:
|
||||
import toml as toml_ # type: ignore[no-redef]
|
||||
except ImportError: # pragma: no cover
|
||||
import warnings
|
||||
|
||||
warnings.warn('No TOML parser installed, cannot read configuration from `pyproject.toml`.')
|
||||
return None
|
||||
|
||||
with open(config_file, read_mode) as rf:
|
||||
return toml_.load(rf) # type: ignore[arg-type]
|
736
lib/pydantic/networks.py
Normal file
736
lib/pydantic/networks.py
Normal file
|
@ -0,0 +1,736 @@
|
|||
import re
|
||||
from ipaddress import (
|
||||
IPv4Address,
|
||||
IPv4Interface,
|
||||
IPv4Network,
|
||||
IPv6Address,
|
||||
IPv6Interface,
|
||||
IPv6Network,
|
||||
_BaseAddress,
|
||||
_BaseNetwork,
|
||||
)
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Collection,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Match,
|
||||
Optional,
|
||||
Pattern,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
no_type_check,
|
||||
)
|
||||
|
||||
from . import errors
|
||||
from .utils import Representation, update_not_none
|
||||
from .validators import constr_length_validator, str_validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import email_validator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from .config import BaseConfig
|
||||
from .fields import ModelField
|
||||
from .typing import AnyCallable
|
||||
|
||||
CallableGenerator = Generator[AnyCallable, None, None]
|
||||
|
||||
class Parts(TypedDict, total=False):
|
||||
scheme: str
|
||||
user: Optional[str]
|
||||
password: Optional[str]
|
||||
ipv4: Optional[str]
|
||||
ipv6: Optional[str]
|
||||
domain: Optional[str]
|
||||
port: Optional[str]
|
||||
path: Optional[str]
|
||||
query: Optional[str]
|
||||
fragment: Optional[str]
|
||||
|
||||
class HostParts(TypedDict, total=False):
|
||||
host: str
|
||||
tld: Optional[str]
|
||||
host_type: Optional[str]
|
||||
port: Optional[str]
|
||||
rebuild: bool
|
||||
|
||||
else:
|
||||
email_validator = None
|
||||
|
||||
class Parts(dict):
|
||||
pass
|
||||
|
||||
|
||||
NetworkType = Union[str, bytes, int, Tuple[Union[str, bytes, int], Union[str, int]]]
|
||||
|
||||
__all__ = [
|
||||
'AnyUrl',
|
||||
'AnyHttpUrl',
|
||||
'FileUrl',
|
||||
'HttpUrl',
|
||||
'stricturl',
|
||||
'EmailStr',
|
||||
'NameEmail',
|
||||
'IPvAnyAddress',
|
||||
'IPvAnyInterface',
|
||||
'IPvAnyNetwork',
|
||||
'PostgresDsn',
|
||||
'CockroachDsn',
|
||||
'AmqpDsn',
|
||||
'RedisDsn',
|
||||
'MongoDsn',
|
||||
'KafkaDsn',
|
||||
'validate_email',
|
||||
]
|
||||
|
||||
_url_regex_cache = None
|
||||
_multi_host_url_regex_cache = None
|
||||
_ascii_domain_regex_cache = None
|
||||
_int_domain_regex_cache = None
|
||||
_host_regex_cache = None
|
||||
|
||||
_host_regex = (
|
||||
r'(?:'
|
||||
r'(?P<ipv4>(?:\d{1,3}\.){3}\d{1,3})(?=$|[/:#?])|' # ipv4
|
||||
r'(?P<ipv6>\[[A-F0-9]*:[A-F0-9:]+\])(?=$|[/:#?])|' # ipv6
|
||||
r'(?P<domain>[^\s/:?#]+)' # domain, validation occurs later
|
||||
r')?'
|
||||
r'(?::(?P<port>\d+))?' # port
|
||||
)
|
||||
_scheme_regex = r'(?:(?P<scheme>[a-z][a-z0-9+\-.]+)://)?' # scheme https://tools.ietf.org/html/rfc3986#appendix-A
|
||||
_user_info_regex = r'(?:(?P<user>[^\s:/]*)(?::(?P<password>[^\s/]*))?@)?'
|
||||
_path_regex = r'(?P<path>/[^\s?#]*)?'
|
||||
_query_regex = r'(?:\?(?P<query>[^\s#]*))?'
|
||||
_fragment_regex = r'(?:#(?P<fragment>[^\s#]*))?'
|
||||
|
||||
|
||||
def url_regex() -> Pattern[str]:
|
||||
global _url_regex_cache
|
||||
if _url_regex_cache is None:
|
||||
_url_regex_cache = re.compile(
|
||||
rf'{_scheme_regex}{_user_info_regex}{_host_regex}{_path_regex}{_query_regex}{_fragment_regex}',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
return _url_regex_cache
|
||||
|
||||
|
||||
def multi_host_url_regex() -> Pattern[str]:
|
||||
"""
|
||||
Compiled multi host url regex.
|
||||
|
||||
Additionally to `url_regex` it allows to match multiple hosts.
|
||||
E.g. host1.db.net,host2.db.net
|
||||
"""
|
||||
global _multi_host_url_regex_cache
|
||||
if _multi_host_url_regex_cache is None:
|
||||
_multi_host_url_regex_cache = re.compile(
|
||||
rf'{_scheme_regex}{_user_info_regex}'
|
||||
r'(?P<hosts>([^/]*))' # validation occurs later
|
||||
rf'{_path_regex}{_query_regex}{_fragment_regex}',
|
||||
re.IGNORECASE,
|
||||
)
|
||||
return _multi_host_url_regex_cache
|
||||
|
||||
|
||||
def ascii_domain_regex() -> Pattern[str]:
|
||||
global _ascii_domain_regex_cache
|
||||
if _ascii_domain_regex_cache is None:
|
||||
ascii_chunk = r'[_0-9a-z](?:[-_0-9a-z]{0,61}[_0-9a-z])?'
|
||||
ascii_domain_ending = r'(?P<tld>\.[a-z]{2,63})?\.?'
|
||||
_ascii_domain_regex_cache = re.compile(
|
||||
fr'(?:{ascii_chunk}\.)*?{ascii_chunk}{ascii_domain_ending}', re.IGNORECASE
|
||||
)
|
||||
return _ascii_domain_regex_cache
|
||||
|
||||
|
||||
def int_domain_regex() -> Pattern[str]:
|
||||
global _int_domain_regex_cache
|
||||
if _int_domain_regex_cache is None:
|
||||
int_chunk = r'[_0-9a-\U00040000](?:[-_0-9a-\U00040000]{0,61}[_0-9a-\U00040000])?'
|
||||
int_domain_ending = r'(?P<tld>(\.[^\W\d_]{2,63})|(\.(?:xn--)[_0-9a-z-]{2,63}))?\.?'
|
||||
_int_domain_regex_cache = re.compile(fr'(?:{int_chunk}\.)*?{int_chunk}{int_domain_ending}', re.IGNORECASE)
|
||||
return _int_domain_regex_cache
|
||||
|
||||
|
||||
def host_regex() -> Pattern[str]:
|
||||
global _host_regex_cache
|
||||
if _host_regex_cache is None:
|
||||
_host_regex_cache = re.compile(
|
||||
_host_regex,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
return _host_regex_cache
|
||||
|
||||
|
||||
class AnyUrl(str):
|
||||
strip_whitespace = True
|
||||
min_length = 1
|
||||
max_length = 2**16
|
||||
allowed_schemes: Optional[Collection[str]] = None
|
||||
tld_required: bool = False
|
||||
user_required: bool = False
|
||||
host_required: bool = True
|
||||
hidden_parts: Set[str] = set()
|
||||
|
||||
__slots__ = ('scheme', 'user', 'password', 'host', 'tld', 'host_type', 'port', 'path', 'query', 'fragment')
|
||||
|
||||
@no_type_check
|
||||
def __new__(cls, url: Optional[str], **kwargs) -> object:
|
||||
return str.__new__(cls, cls.build(**kwargs) if url is None else url)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
scheme: str,
|
||||
user: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
tld: Optional[str] = None,
|
||||
host_type: str = 'domain',
|
||||
port: Optional[str] = None,
|
||||
path: Optional[str] = None,
|
||||
query: Optional[str] = None,
|
||||
fragment: Optional[str] = None,
|
||||
) -> None:
|
||||
str.__init__(url)
|
||||
self.scheme = scheme
|
||||
self.user = user
|
||||
self.password = password
|
||||
self.host = host
|
||||
self.tld = tld
|
||||
self.host_type = host_type
|
||||
self.port = port
|
||||
self.path = path
|
||||
self.query = query
|
||||
self.fragment = fragment
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
*,
|
||||
scheme: str,
|
||||
user: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
host: str,
|
||||
port: Optional[str] = None,
|
||||
path: Optional[str] = None,
|
||||
query: Optional[str] = None,
|
||||
fragment: Optional[str] = None,
|
||||
**_kwargs: str,
|
||||
) -> str:
|
||||
parts = Parts(
|
||||
scheme=scheme,
|
||||
user=user,
|
||||
password=password,
|
||||
host=host,
|
||||
port=port,
|
||||
path=path,
|
||||
query=query,
|
||||
fragment=fragment,
|
||||
**_kwargs, # type: ignore[misc]
|
||||
)
|
||||
|
||||
url = scheme + '://'
|
||||
if user:
|
||||
url += user
|
||||
if password:
|
||||
url += ':' + password
|
||||
if user or password:
|
||||
url += '@'
|
||||
url += host
|
||||
if port and ('port' not in cls.hidden_parts or cls.get_default_parts(parts).get('port') != port):
|
||||
url += ':' + port
|
||||
if path:
|
||||
url += path
|
||||
if query:
|
||||
url += '?' + query
|
||||
if fragment:
|
||||
url += '#' + fragment
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
update_not_none(field_schema, minLength=cls.min_length, maxLength=cls.max_length, format='uri')
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> 'CallableGenerator':
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: Any, field: 'ModelField', config: 'BaseConfig') -> 'AnyUrl':
|
||||
if value.__class__ == cls:
|
||||
return value
|
||||
value = str_validator(value)
|
||||
if cls.strip_whitespace:
|
||||
value = value.strip()
|
||||
url: str = cast(str, constr_length_validator(value, field, config))
|
||||
|
||||
m = cls._match_url(url)
|
||||
# the regex should always match, if it doesn't please report with details of the URL tried
|
||||
assert m, 'URL regex failed unexpectedly'
|
||||
|
||||
original_parts = cast('Parts', m.groupdict())
|
||||
parts = cls.apply_default_parts(original_parts)
|
||||
parts = cls.validate_parts(parts)
|
||||
|
||||
if m.end() != len(url):
|
||||
raise errors.UrlExtraError(extra=url[m.end() :])
|
||||
|
||||
return cls._build_url(m, url, parts)
|
||||
|
||||
@classmethod
|
||||
def _build_url(cls, m: Match[str], url: str, parts: 'Parts') -> 'AnyUrl':
|
||||
"""
|
||||
Validate hosts and build the AnyUrl object. Split from `validate` so this method
|
||||
can be altered in `MultiHostDsn`.
|
||||
"""
|
||||
host, tld, host_type, rebuild = cls.validate_host(parts)
|
||||
|
||||
return cls(
|
||||
None if rebuild else url,
|
||||
scheme=parts['scheme'],
|
||||
user=parts['user'],
|
||||
password=parts['password'],
|
||||
host=host,
|
||||
tld=tld,
|
||||
host_type=host_type,
|
||||
port=parts['port'],
|
||||
path=parts['path'],
|
||||
query=parts['query'],
|
||||
fragment=parts['fragment'],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _match_url(url: str) -> Optional[Match[str]]:
|
||||
return url_regex().match(url)
|
||||
|
||||
@staticmethod
|
||||
def _validate_port(port: Optional[str]) -> None:
|
||||
if port is not None and int(port) > 65_535:
|
||||
raise errors.UrlPortError()
|
||||
|
||||
@classmethod
|
||||
def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts':
|
||||
"""
|
||||
A method used to validate parts of a URL.
|
||||
Could be overridden to set default values for parts if missing
|
||||
"""
|
||||
scheme = parts['scheme']
|
||||
if scheme is None:
|
||||
raise errors.UrlSchemeError()
|
||||
|
||||
if cls.allowed_schemes and scheme.lower() not in cls.allowed_schemes:
|
||||
raise errors.UrlSchemePermittedError(set(cls.allowed_schemes))
|
||||
|
||||
if validate_port:
|
||||
cls._validate_port(parts['port'])
|
||||
|
||||
user = parts['user']
|
||||
if cls.user_required and user is None:
|
||||
raise errors.UrlUserInfoError()
|
||||
|
||||
return parts
|
||||
|
||||
@classmethod
|
||||
def validate_host(cls, parts: 'Parts') -> Tuple[str, Optional[str], str, bool]:
|
||||
tld, host_type, rebuild = None, None, False
|
||||
for f in ('domain', 'ipv4', 'ipv6'):
|
||||
host = parts[f] # type: ignore[literal-required]
|
||||
if host:
|
||||
host_type = f
|
||||
break
|
||||
|
||||
if host is None:
|
||||
if cls.host_required:
|
||||
raise errors.UrlHostError()
|
||||
elif host_type == 'domain':
|
||||
is_international = False
|
||||
d = ascii_domain_regex().fullmatch(host)
|
||||
if d is None:
|
||||
d = int_domain_regex().fullmatch(host)
|
||||
if d is None:
|
||||
raise errors.UrlHostError()
|
||||
is_international = True
|
||||
|
||||
tld = d.group('tld')
|
||||
if tld is None and not is_international:
|
||||
d = int_domain_regex().fullmatch(host)
|
||||
assert d is not None
|
||||
tld = d.group('tld')
|
||||
is_international = True
|
||||
|
||||
if tld is not None:
|
||||
tld = tld[1:]
|
||||
elif cls.tld_required:
|
||||
raise errors.UrlHostTldError()
|
||||
|
||||
if is_international:
|
||||
host_type = 'int_domain'
|
||||
rebuild = True
|
||||
host = host.encode('idna').decode('ascii')
|
||||
if tld is not None:
|
||||
tld = tld.encode('idna').decode('ascii')
|
||||
|
||||
return host, tld, host_type, rebuild # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def get_default_parts(parts: 'Parts') -> 'Parts':
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def apply_default_parts(cls, parts: 'Parts') -> 'Parts':
|
||||
for key, value in cls.get_default_parts(parts).items():
|
||||
if not parts[key]: # type: ignore[literal-required]
|
||||
parts[key] = value # type: ignore[literal-required]
|
||||
return parts
|
||||
|
||||
def __repr__(self) -> str:
|
||||
extra = ', '.join(f'{n}={getattr(self, n)!r}' for n in self.__slots__ if getattr(self, n) is not None)
|
||||
return f'{self.__class__.__name__}({super().__repr__()}, {extra})'
|
||||
|
||||
|
||||
class AnyHttpUrl(AnyUrl):
|
||||
allowed_schemes = {'http', 'https'}
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class HttpUrl(AnyHttpUrl):
|
||||
tld_required = True
|
||||
# https://stackoverflow.com/questions/417142/what-is-the-maximum-length-of-a-url-in-different-browsers
|
||||
max_length = 2083
|
||||
hidden_parts = {'port'}
|
||||
|
||||
@staticmethod
|
||||
def get_default_parts(parts: 'Parts') -> 'Parts':
|
||||
return {'port': '80' if parts['scheme'] == 'http' else '443'}
|
||||
|
||||
|
||||
class FileUrl(AnyUrl):
|
||||
allowed_schemes = {'file'}
|
||||
host_required = False
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class MultiHostDsn(AnyUrl):
|
||||
__slots__ = AnyUrl.__slots__ + ('hosts',)
|
||||
|
||||
def __init__(self, *args: Any, hosts: Optional[List['HostParts']] = None, **kwargs: Any):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.hosts = hosts
|
||||
|
||||
@staticmethod
|
||||
def _match_url(url: str) -> Optional[Match[str]]:
|
||||
return multi_host_url_regex().match(url)
|
||||
|
||||
@classmethod
|
||||
def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts':
|
||||
return super().validate_parts(parts, validate_port=False)
|
||||
|
||||
@classmethod
|
||||
def _build_url(cls, m: Match[str], url: str, parts: 'Parts') -> 'MultiHostDsn':
|
||||
hosts_parts: List['HostParts'] = []
|
||||
host_re = host_regex()
|
||||
for host in m.groupdict()['hosts'].split(','):
|
||||
d: Parts = host_re.match(host).groupdict() # type: ignore
|
||||
host, tld, host_type, rebuild = cls.validate_host(d)
|
||||
port = d.get('port')
|
||||
cls._validate_port(port)
|
||||
hosts_parts.append(
|
||||
{
|
||||
'host': host,
|
||||
'host_type': host_type,
|
||||
'tld': tld,
|
||||
'rebuild': rebuild,
|
||||
'port': port,
|
||||
}
|
||||
)
|
||||
|
||||
if len(hosts_parts) > 1:
|
||||
return cls(
|
||||
None if any([hp['rebuild'] for hp in hosts_parts]) else url,
|
||||
scheme=parts['scheme'],
|
||||
user=parts['user'],
|
||||
password=parts['password'],
|
||||
path=parts['path'],
|
||||
query=parts['query'],
|
||||
fragment=parts['fragment'],
|
||||
host_type=None,
|
||||
hosts=hosts_parts,
|
||||
)
|
||||
else:
|
||||
# backwards compatibility with single host
|
||||
host_part = hosts_parts[0]
|
||||
return cls(
|
||||
None if host_part['rebuild'] else url,
|
||||
scheme=parts['scheme'],
|
||||
user=parts['user'],
|
||||
password=parts['password'],
|
||||
host=host_part['host'],
|
||||
tld=host_part['tld'],
|
||||
host_type=host_part['host_type'],
|
||||
port=host_part.get('port'),
|
||||
path=parts['path'],
|
||||
query=parts['query'],
|
||||
fragment=parts['fragment'],
|
||||
)
|
||||
|
||||
|
||||
class PostgresDsn(MultiHostDsn):
|
||||
allowed_schemes = {
|
||||
'postgres',
|
||||
'postgresql',
|
||||
'postgresql+asyncpg',
|
||||
'postgresql+pg8000',
|
||||
'postgresql+psycopg2',
|
||||
'postgresql+psycopg2cffi',
|
||||
'postgresql+py-postgresql',
|
||||
'postgresql+pygresql',
|
||||
}
|
||||
user_required = True
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class CockroachDsn(AnyUrl):
|
||||
allowed_schemes = {
|
||||
'cockroachdb',
|
||||
'cockroachdb+psycopg2',
|
||||
'cockroachdb+asyncpg',
|
||||
}
|
||||
user_required = True
|
||||
|
||||
|
||||
class AmqpDsn(AnyUrl):
|
||||
allowed_schemes = {'amqp', 'amqps'}
|
||||
host_required = False
|
||||
|
||||
|
||||
class RedisDsn(AnyUrl):
|
||||
__slots__ = ()
|
||||
allowed_schemes = {'redis', 'rediss'}
|
||||
host_required = False
|
||||
|
||||
@staticmethod
|
||||
def get_default_parts(parts: 'Parts') -> 'Parts':
|
||||
return {
|
||||
'domain': 'localhost' if not (parts['ipv4'] or parts['ipv6']) else '',
|
||||
'port': '6379',
|
||||
'path': '/0',
|
||||
}
|
||||
|
||||
|
||||
class MongoDsn(AnyUrl):
|
||||
allowed_schemes = {'mongodb'}
|
||||
|
||||
# TODO: Needed to generic "Parts" for "Replica Set", "Sharded Cluster", and other mongodb deployment modes
|
||||
@staticmethod
|
||||
def get_default_parts(parts: 'Parts') -> 'Parts':
|
||||
return {
|
||||
'port': '27017',
|
||||
}
|
||||
|
||||
|
||||
class KafkaDsn(AnyUrl):
|
||||
allowed_schemes = {'kafka'}
|
||||
|
||||
@staticmethod
|
||||
def get_default_parts(parts: 'Parts') -> 'Parts':
|
||||
return {
|
||||
'domain': 'localhost',
|
||||
'port': '9092',
|
||||
}
|
||||
|
||||
|
||||
def stricturl(
|
||||
*,
|
||||
strip_whitespace: bool = True,
|
||||
min_length: int = 1,
|
||||
max_length: int = 2**16,
|
||||
tld_required: bool = True,
|
||||
host_required: bool = True,
|
||||
allowed_schemes: Optional[Collection[str]] = None,
|
||||
) -> Type[AnyUrl]:
|
||||
# use kwargs then define conf in a dict to aid with IDE type hinting
|
||||
namespace = dict(
|
||||
strip_whitespace=strip_whitespace,
|
||||
min_length=min_length,
|
||||
max_length=max_length,
|
||||
tld_required=tld_required,
|
||||
host_required=host_required,
|
||||
allowed_schemes=allowed_schemes,
|
||||
)
|
||||
return type('UrlValue', (AnyUrl,), namespace)
|
||||
|
||||
|
||||
def import_email_validator() -> None:
|
||||
global email_validator
|
||||
try:
|
||||
import email_validator
|
||||
except ImportError as e:
|
||||
raise ImportError('email-validator is not installed, run `pip install pydantic[email]`') from e
|
||||
|
||||
|
||||
class EmailStr(str):
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update(type='string', format='email')
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> 'CallableGenerator':
|
||||
# included here and below so the error happens straight away
|
||||
import_email_validator()
|
||||
|
||||
yield str_validator
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: Union[str]) -> str:
|
||||
return validate_email(value)[1]
|
||||
|
||||
|
||||
class NameEmail(Representation):
|
||||
__slots__ = 'name', 'email'
|
||||
|
||||
def __init__(self, name: str, email: str):
|
||||
self.name = name
|
||||
self.email = email
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return isinstance(other, NameEmail) and (self.name, self.email) == (other.name, other.email)
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update(type='string', format='name-email')
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> 'CallableGenerator':
|
||||
import_email_validator()
|
||||
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: Any) -> 'NameEmail':
|
||||
if value.__class__ == cls:
|
||||
return value
|
||||
value = str_validator(value)
|
||||
return cls(*validate_email(value))
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'{self.name} <{self.email}>'
|
||||
|
||||
|
||||
class IPvAnyAddress(_BaseAddress):
|
||||
__slots__ = ()
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update(type='string', format='ipvanyaddress')
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> 'CallableGenerator':
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: Union[str, bytes, int]) -> Union[IPv4Address, IPv6Address]:
|
||||
try:
|
||||
return IPv4Address(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return IPv6Address(value)
|
||||
except ValueError:
|
||||
raise errors.IPvAnyAddressError()
|
||||
|
||||
|
||||
class IPvAnyInterface(_BaseAddress):
|
||||
__slots__ = ()
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update(type='string', format='ipvanyinterface')
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> 'CallableGenerator':
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: NetworkType) -> Union[IPv4Interface, IPv6Interface]:
|
||||
try:
|
||||
return IPv4Interface(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return IPv6Interface(value)
|
||||
except ValueError:
|
||||
raise errors.IPvAnyInterfaceError()
|
||||
|
||||
|
||||
class IPvAnyNetwork(_BaseNetwork): # type: ignore
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
|
||||
field_schema.update(type='string', format='ipvanynetwork')
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> 'CallableGenerator':
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: NetworkType) -> Union[IPv4Network, IPv6Network]:
|
||||
# Assume IP Network is defined with a default value for ``strict`` argument.
|
||||
# Define your own class if you want to specify network address check strictness.
|
||||
try:
|
||||
return IPv4Network(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return IPv6Network(value)
|
||||
except ValueError:
|
||||
raise errors.IPvAnyNetworkError()
|
||||
|
||||
|
||||
pretty_email_regex = re.compile(r'([\w ]*?) *<(.*)> *')
|
||||
|
||||
|
||||
def validate_email(value: Union[str]) -> Tuple[str, str]:
|
||||
"""
|
||||
Brutally simple email address validation. Note unlike most email address validation
|
||||
* raw ip address (literal) domain parts are not allowed.
|
||||
* "John Doe <local_part@domain.com>" style "pretty" email addresses are processed
|
||||
* the local part check is extremely basic. This raises the possibility of unicode spoofing, but no better
|
||||
solution is really possible.
|
||||
* spaces are striped from the beginning and end of addresses but no error is raised
|
||||
|
||||
See RFC 5322 but treat it with suspicion, there seems to exist no universally acknowledged test for a valid email!
|
||||
"""
|
||||
if email_validator is None:
|
||||
import_email_validator()
|
||||
|
||||
m = pretty_email_regex.fullmatch(value)
|
||||
name: Optional[str] = None
|
||||
if m:
|
||||
name, value = m.groups()
|
||||
|
||||
email = value.strip()
|
||||
|
||||
try:
|
||||
email_validator.validate_email(email, check_deliverability=False)
|
||||
except email_validator.EmailNotValidError as e:
|
||||
raise errors.EmailError() from e
|
||||
|
||||
at_index = email.index('@')
|
||||
local_part = email[:at_index] # RFC 5321, local part must be case-sensitive.
|
||||
global_part = email[at_index:].lower()
|
||||
|
||||
return name or local_part, local_part + global_part
|
66
lib/pydantic/parse.py
Normal file
66
lib/pydantic/parse.py
Normal file
|
@ -0,0 +1,66 @@
|
|||
import json
|
||||
import pickle
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
from .types import StrBytes
|
||||
|
||||
|
||||
class Protocol(str, Enum):
|
||||
json = 'json'
|
||||
pickle = 'pickle'
|
||||
|
||||
|
||||
def load_str_bytes(
|
||||
b: StrBytes,
|
||||
*,
|
||||
content_type: str = None,
|
||||
encoding: str = 'utf8',
|
||||
proto: Protocol = None,
|
||||
allow_pickle: bool = False,
|
||||
json_loads: Callable[[str], Any] = json.loads,
|
||||
) -> Any:
|
||||
if proto is None and content_type:
|
||||
if content_type.endswith(('json', 'javascript')):
|
||||
pass
|
||||
elif allow_pickle and content_type.endswith('pickle'):
|
||||
proto = Protocol.pickle
|
||||
else:
|
||||
raise TypeError(f'Unknown content-type: {content_type}')
|
||||
|
||||
proto = proto or Protocol.json
|
||||
|
||||
if proto == Protocol.json:
|
||||
if isinstance(b, bytes):
|
||||
b = b.decode(encoding)
|
||||
return json_loads(b)
|
||||
elif proto == Protocol.pickle:
|
||||
if not allow_pickle:
|
||||
raise RuntimeError('Trying to decode with pickle with allow_pickle=False')
|
||||
bb = b if isinstance(b, bytes) else b.encode()
|
||||
return pickle.loads(bb)
|
||||
else:
|
||||
raise TypeError(f'Unknown protocol: {proto}')
|
||||
|
||||
|
||||
def load_file(
|
||||
path: Union[str, Path],
|
||||
*,
|
||||
content_type: str = None,
|
||||
encoding: str = 'utf8',
|
||||
proto: Protocol = None,
|
||||
allow_pickle: bool = False,
|
||||
json_loads: Callable[[str], Any] = json.loads,
|
||||
) -> Any:
|
||||
path = Path(path)
|
||||
b = path.read_bytes()
|
||||
if content_type is None:
|
||||
if path.suffix in ('.js', '.json'):
|
||||
proto = Protocol.json
|
||||
elif path.suffix == '.pkl':
|
||||
proto = Protocol.pickle
|
||||
|
||||
return load_str_bytes(
|
||||
b, proto=proto, content_type=content_type, encoding=encoding, allow_pickle=allow_pickle, json_loads=json_loads
|
||||
)
|
0
lib/pydantic/py.typed
Normal file
0
lib/pydantic/py.typed
Normal file
1153
lib/pydantic/schema.py
Normal file
1153
lib/pydantic/schema.py
Normal file
File diff suppressed because it is too large
Load diff
92
lib/pydantic/tools.py
Normal file
92
lib/pydantic/tools.py
Normal file
|
@ -0,0 +1,92 @@
|
|||
import json
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar, Union
|
||||
|
||||
from .parse import Protocol, load_file, load_str_bytes
|
||||
from .types import StrBytes
|
||||
from .typing import display_as_type
|
||||
|
||||
__all__ = ('parse_file_as', 'parse_obj_as', 'parse_raw_as', 'schema_of', 'schema_json_of')
|
||||
|
||||
NameFactory = Union[str, Callable[[Type[Any]], str]]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .typing import DictStrAny
|
||||
|
||||
|
||||
def _generate_parsing_type_name(type_: Any) -> str:
|
||||
return f'ParsingModel[{display_as_type(type_)}]'
|
||||
|
||||
|
||||
@lru_cache(maxsize=2048)
|
||||
def _get_parsing_type(type_: Any, *, type_name: Optional[NameFactory] = None) -> Any:
|
||||
from pydantic.main import create_model
|
||||
|
||||
if type_name is None:
|
||||
type_name = _generate_parsing_type_name
|
||||
if not isinstance(type_name, str):
|
||||
type_name = type_name(type_)
|
||||
return create_model(type_name, __root__=(type_, ...))
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def parse_obj_as(type_: Type[T], obj: Any, *, type_name: Optional[NameFactory] = None) -> T:
|
||||
model_type = _get_parsing_type(type_, type_name=type_name) # type: ignore[arg-type]
|
||||
return model_type(__root__=obj).__root__
|
||||
|
||||
|
||||
def parse_file_as(
|
||||
type_: Type[T],
|
||||
path: Union[str, Path],
|
||||
*,
|
||||
content_type: str = None,
|
||||
encoding: str = 'utf8',
|
||||
proto: Protocol = None,
|
||||
allow_pickle: bool = False,
|
||||
json_loads: Callable[[str], Any] = json.loads,
|
||||
type_name: Optional[NameFactory] = None,
|
||||
) -> T:
|
||||
obj = load_file(
|
||||
path,
|
||||
proto=proto,
|
||||
content_type=content_type,
|
||||
encoding=encoding,
|
||||
allow_pickle=allow_pickle,
|
||||
json_loads=json_loads,
|
||||
)
|
||||
return parse_obj_as(type_, obj, type_name=type_name)
|
||||
|
||||
|
||||
def parse_raw_as(
|
||||
type_: Type[T],
|
||||
b: StrBytes,
|
||||
*,
|
||||
content_type: str = None,
|
||||
encoding: str = 'utf8',
|
||||
proto: Protocol = None,
|
||||
allow_pickle: bool = False,
|
||||
json_loads: Callable[[str], Any] = json.loads,
|
||||
type_name: Optional[NameFactory] = None,
|
||||
) -> T:
|
||||
obj = load_str_bytes(
|
||||
b,
|
||||
proto=proto,
|
||||
content_type=content_type,
|
||||
encoding=encoding,
|
||||
allow_pickle=allow_pickle,
|
||||
json_loads=json_loads,
|
||||
)
|
||||
return parse_obj_as(type_, obj, type_name=type_name)
|
||||
|
||||
|
||||
def schema_of(type_: Any, *, title: Optional[NameFactory] = None, **schema_kwargs: Any) -> 'DictStrAny':
|
||||
"""Generate a JSON schema (as dict) for the passed model or dynamically generated one"""
|
||||
return _get_parsing_type(type_, type_name=title).schema(**schema_kwargs)
|
||||
|
||||
|
||||
def schema_json_of(type_: Any, *, title: Optional[NameFactory] = None, **schema_json_kwargs: Any) -> str:
|
||||
"""Generate a JSON schema (as JSON) for the passed model or dynamically generated one"""
|
||||
return _get_parsing_type(type_, type_name=title).schema_json(**schema_json_kwargs)
|
1187
lib/pydantic/types.py
Normal file
1187
lib/pydantic/types.py
Normal file
File diff suppressed because it is too large
Load diff
602
lib/pydantic/typing.py
Normal file
602
lib/pydantic/typing.py
Normal file
|
@ -0,0 +1,602 @@
|
|||
import sys
|
||||
from collections.abc import Callable
|
||||
from os import PathLike
|
||||
from typing import ( # type: ignore
|
||||
TYPE_CHECKING,
|
||||
AbstractSet,
|
||||
Any,
|
||||
Callable as TypingCallable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
ForwardRef,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
NewType,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
_eval_type,
|
||||
cast,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
from typing_extensions import (
|
||||
Annotated,
|
||||
Final,
|
||||
Literal,
|
||||
NotRequired as TypedDictNotRequired,
|
||||
Required as TypedDictRequired,
|
||||
)
|
||||
|
||||
try:
|
||||
from typing import _TypingBase as typing_base # type: ignore
|
||||
except ImportError:
|
||||
from typing import _Final as typing_base # type: ignore
|
||||
|
||||
try:
|
||||
from typing import GenericAlias as TypingGenericAlias # type: ignore
|
||||
except ImportError:
|
||||
# python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
|
||||
TypingGenericAlias = ()
|
||||
|
||||
try:
|
||||
from types import UnionType as TypesUnionType # type: ignore
|
||||
except ImportError:
|
||||
# python < 3.10 does not have UnionType (str | int, byte | bool and so on)
|
||||
TypesUnionType = ()
|
||||
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
|
||||
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
|
||||
return type_._evaluate(globalns, localns)
|
||||
|
||||
else:
|
||||
|
||||
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
|
||||
# Even though it is the right signature for python 3.9, mypy complains with
|
||||
# `error: Too many arguments for "_evaluate" of "ForwardRef"` hence the cast...
|
||||
return cast(Any, type_)._evaluate(globalns, localns, set())
|
||||
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
# Ensure we always get all the whole `Annotated` hint, not just the annotated type.
|
||||
# For 3.7 to 3.8, `get_type_hints` doesn't recognize `typing_extensions.Annotated`,
|
||||
# so it already returns the full annotation
|
||||
get_all_type_hints = get_type_hints
|
||||
|
||||
else:
|
||||
|
||||
def get_all_type_hints(obj: Any, globalns: Any = None, localns: Any = None) -> Any:
|
||||
return get_type_hints(obj, globalns, localns, include_extras=True)
|
||||
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
AnyCallable = TypingCallable[..., Any]
|
||||
NoArgAnyCallable = TypingCallable[[], Any]
|
||||
|
||||
# workaround for https://github.com/python/mypy/issues/9496
|
||||
AnyArgTCallable = TypingCallable[..., _T]
|
||||
|
||||
|
||||
# Annotated[...] is implemented by returning an instance of one of these classes, depending on
|
||||
# python/typing_extensions version.
|
||||
AnnotatedTypeNames = {'AnnotatedMeta', '_AnnotatedAlias'}
|
||||
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
|
||||
def get_origin(t: Type[Any]) -> Optional[Type[Any]]:
|
||||
if type(t).__name__ in AnnotatedTypeNames:
|
||||
# weirdly this is a runtime requirement, as well as for mypy
|
||||
return cast(Type[Any], Annotated)
|
||||
return getattr(t, '__origin__', None)
|
||||
|
||||
else:
|
||||
from typing import get_origin as _typing_get_origin
|
||||
|
||||
def get_origin(tp: Type[Any]) -> Optional[Type[Any]]:
|
||||
"""
|
||||
We can't directly use `typing.get_origin` since we need a fallback to support
|
||||
custom generic classes like `ConstrainedList`
|
||||
It should be useless once https://github.com/cython/cython/issues/3537 is
|
||||
solved and https://github.com/pydantic/pydantic/pull/1753 is merged.
|
||||
"""
|
||||
if type(tp).__name__ in AnnotatedTypeNames:
|
||||
return cast(Type[Any], Annotated) # mypy complains about _SpecialForm
|
||||
return _typing_get_origin(tp) or getattr(tp, '__origin__', None)
|
||||
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
from typing import _GenericAlias
|
||||
|
||||
def get_args(t: Type[Any]) -> Tuple[Any, ...]:
|
||||
"""Compatibility version of get_args for python 3.7.
|
||||
|
||||
Mostly compatible with the python 3.8 `typing` module version
|
||||
and able to handle almost all use cases.
|
||||
"""
|
||||
if type(t).__name__ in AnnotatedTypeNames:
|
||||
return t.__args__ + t.__metadata__
|
||||
if isinstance(t, _GenericAlias):
|
||||
res = t.__args__
|
||||
if t.__origin__ is Callable and res and res[0] is not Ellipsis:
|
||||
res = (list(res[:-1]), res[-1])
|
||||
return res
|
||||
return getattr(t, '__args__', ())
|
||||
|
||||
else:
|
||||
from typing import get_args as _typing_get_args
|
||||
|
||||
def _generic_get_args(tp: Type[Any]) -> Tuple[Any, ...]:
|
||||
"""
|
||||
In python 3.9, `typing.Dict`, `typing.List`, ...
|
||||
do have an empty `__args__` by default (instead of the generic ~T for example).
|
||||
In order to still support `Dict` for example and consider it as `Dict[Any, Any]`,
|
||||
we retrieve the `_nparams` value that tells us how many parameters it needs.
|
||||
"""
|
||||
if hasattr(tp, '_nparams'):
|
||||
return (Any,) * tp._nparams
|
||||
# Special case for `tuple[()]`, which used to return ((),) with `typing.Tuple`
|
||||
# in python 3.10- but now returns () for `tuple` and `Tuple`.
|
||||
# This will probably be clarified in pydantic v2
|
||||
try:
|
||||
if tp == Tuple[()] or sys.version_info >= (3, 9) and tp == tuple[()]: # type: ignore[misc]
|
||||
return ((),)
|
||||
# there is a TypeError when compiled with cython
|
||||
except TypeError: # pragma: no cover
|
||||
pass
|
||||
return ()
|
||||
|
||||
def get_args(tp: Type[Any]) -> Tuple[Any, ...]:
|
||||
"""Get type arguments with all substitutions performed.
|
||||
|
||||
For unions, basic simplifications used by Union constructor are performed.
|
||||
Examples::
|
||||
get_args(Dict[str, int]) == (str, int)
|
||||
get_args(int) == ()
|
||||
get_args(Union[int, Union[T, int], str][int]) == (int, str)
|
||||
get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])
|
||||
get_args(Callable[[], T][int]) == ([], int)
|
||||
"""
|
||||
if type(tp).__name__ in AnnotatedTypeNames:
|
||||
return tp.__args__ + tp.__metadata__
|
||||
# the fallback is needed for the same reasons as `get_origin` (see above)
|
||||
return _typing_get_args(tp) or getattr(tp, '__args__', ()) or _generic_get_args(tp)
|
||||
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
|
||||
def convert_generics(tp: Type[Any]) -> Type[Any]:
|
||||
"""Python 3.9 and older only supports generics from `typing` module.
|
||||
They convert strings to ForwardRef automatically.
|
||||
|
||||
Examples::
|
||||
typing.List['Hero'] == typing.List[ForwardRef('Hero')]
|
||||
"""
|
||||
return tp
|
||||
|
||||
else:
|
||||
from typing import _UnionGenericAlias # type: ignore
|
||||
|
||||
from typing_extensions import _AnnotatedAlias
|
||||
|
||||
def convert_generics(tp: Type[Any]) -> Type[Any]:
|
||||
"""
|
||||
Recursively searches for `str` type hints and replaces them with ForwardRef.
|
||||
|
||||
Examples::
|
||||
convert_generics(list['Hero']) == list[ForwardRef('Hero')]
|
||||
convert_generics(dict['Hero', 'Team']) == dict[ForwardRef('Hero'), ForwardRef('Team')]
|
||||
convert_generics(typing.Dict['Hero', 'Team']) == typing.Dict[ForwardRef('Hero'), ForwardRef('Team')]
|
||||
convert_generics(list[str | 'Hero'] | int) == list[str | ForwardRef('Hero')] | int
|
||||
"""
|
||||
origin = get_origin(tp)
|
||||
if not origin or not hasattr(tp, '__args__'):
|
||||
return tp
|
||||
|
||||
args = get_args(tp)
|
||||
|
||||
# typing.Annotated needs special treatment
|
||||
if origin is Annotated:
|
||||
return _AnnotatedAlias(convert_generics(args[0]), args[1:])
|
||||
|
||||
# recursively replace `str` instances inside of `GenericAlias` with `ForwardRef(arg)`
|
||||
converted = tuple(
|
||||
ForwardRef(arg) if isinstance(arg, str) and isinstance(tp, TypingGenericAlias) else convert_generics(arg)
|
||||
for arg in args
|
||||
)
|
||||
|
||||
if converted == args:
|
||||
return tp
|
||||
elif isinstance(tp, TypingGenericAlias):
|
||||
return TypingGenericAlias(origin, converted)
|
||||
elif isinstance(tp, TypesUnionType):
|
||||
# recreate types.UnionType (PEP604, Python >= 3.10)
|
||||
return _UnionGenericAlias(origin, converted)
|
||||
else:
|
||||
try:
|
||||
setattr(tp, '__args__', converted)
|
||||
except AttributeError:
|
||||
pass
|
||||
return tp
|
||||
|
||||
|
||||
if sys.version_info < (3, 10):
|
||||
|
||||
def is_union(tp: Optional[Type[Any]]) -> bool:
|
||||
return tp is Union
|
||||
|
||||
WithArgsTypes = (TypingGenericAlias,)
|
||||
|
||||
else:
|
||||
import types
|
||||
import typing
|
||||
|
||||
def is_union(tp: Optional[Type[Any]]) -> bool:
|
||||
return tp is Union or tp is types.UnionType # noqa: E721
|
||||
|
||||
WithArgsTypes = (typing._GenericAlias, types.GenericAlias, types.UnionType)
|
||||
|
||||
|
||||
if sys.version_info < (3, 9):
|
||||
StrPath = Union[str, PathLike]
|
||||
else:
|
||||
StrPath = Union[str, PathLike]
|
||||
# TODO: Once we switch to Cython 3 to handle generics properly
|
||||
# (https://github.com/cython/cython/issues/2753), use following lines instead
|
||||
# of the one above
|
||||
# # os.PathLike only becomes subscriptable from Python 3.9 onwards
|
||||
# StrPath = Union[str, PathLike[str]]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .fields import ModelField
|
||||
|
||||
TupleGenerator = Generator[Tuple[str, Any], None, None]
|
||||
DictStrAny = Dict[str, Any]
|
||||
DictAny = Dict[Any, Any]
|
||||
SetStr = Set[str]
|
||||
ListStr = List[str]
|
||||
IntStr = Union[int, str]
|
||||
AbstractSetIntStr = AbstractSet[IntStr]
|
||||
DictIntStrAny = Dict[IntStr, Any]
|
||||
MappingIntStrAny = Mapping[IntStr, Any]
|
||||
CallableGenerator = Generator[AnyCallable, None, None]
|
||||
ReprArgs = Sequence[Tuple[Optional[str], Any]]
|
||||
AnyClassMethod = classmethod[Any]
|
||||
|
||||
__all__ = (
|
||||
'AnyCallable',
|
||||
'NoArgAnyCallable',
|
||||
'NoneType',
|
||||
'is_none_type',
|
||||
'display_as_type',
|
||||
'resolve_annotations',
|
||||
'is_callable_type',
|
||||
'is_literal_type',
|
||||
'all_literal_values',
|
||||
'is_namedtuple',
|
||||
'is_typeddict',
|
||||
'is_typeddict_special',
|
||||
'is_new_type',
|
||||
'new_type_supertype',
|
||||
'is_classvar',
|
||||
'is_finalvar',
|
||||
'update_field_forward_refs',
|
||||
'update_model_forward_refs',
|
||||
'TupleGenerator',
|
||||
'DictStrAny',
|
||||
'DictAny',
|
||||
'SetStr',
|
||||
'ListStr',
|
||||
'IntStr',
|
||||
'AbstractSetIntStr',
|
||||
'DictIntStrAny',
|
||||
'CallableGenerator',
|
||||
'ReprArgs',
|
||||
'AnyClassMethod',
|
||||
'CallableGenerator',
|
||||
'WithArgsTypes',
|
||||
'get_args',
|
||||
'get_origin',
|
||||
'get_sub_types',
|
||||
'typing_base',
|
||||
'get_all_type_hints',
|
||||
'is_union',
|
||||
'StrPath',
|
||||
'MappingIntStrAny',
|
||||
)
|
||||
|
||||
|
||||
NoneType = None.__class__
|
||||
|
||||
|
||||
NONE_TYPES: Tuple[Any, Any, Any] = (None, NoneType, Literal[None])
|
||||
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
# Even though this implementation is slower, we need it for python 3.7:
|
||||
# In python 3.7 "Literal" is not a builtin type and uses a different
|
||||
# mechanism.
|
||||
# for this reason `Literal[None] is Literal[None]` evaluates to `False`,
|
||||
# breaking the faster implementation used for the other python versions.
|
||||
|
||||
def is_none_type(type_: Any) -> bool:
|
||||
return type_ in NONE_TYPES
|
||||
|
||||
elif sys.version_info[:2] == (3, 8):
|
||||
|
||||
def is_none_type(type_: Any) -> bool:
|
||||
for none_type in NONE_TYPES:
|
||||
if type_ is none_type:
|
||||
return True
|
||||
# With python 3.8, specifically 3.8.10, Literal "is" check sare very flakey
|
||||
# can change on very subtle changes like use of types in other modules,
|
||||
# hopefully this check avoids that issue.
|
||||
if is_literal_type(type_): # pragma: no cover
|
||||
return all_literal_values(type_) == (None,)
|
||||
return False
|
||||
|
||||
else:
|
||||
|
||||
def is_none_type(type_: Any) -> bool:
|
||||
for none_type in NONE_TYPES:
|
||||
if type_ is none_type:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def display_as_type(v: Type[Any]) -> str:
|
||||
if not isinstance(v, typing_base) and not isinstance(v, WithArgsTypes) and not isinstance(v, type):
|
||||
v = v.__class__
|
||||
|
||||
if is_union(get_origin(v)):
|
||||
return f'Union[{", ".join(map(display_as_type, get_args(v)))}]'
|
||||
|
||||
if isinstance(v, WithArgsTypes):
|
||||
# Generic alias are constructs like `list[int]`
|
||||
return str(v).replace('typing.', '')
|
||||
|
||||
try:
|
||||
return v.__name__
|
||||
except AttributeError:
|
||||
# happens with typing objects
|
||||
return str(v).replace('typing.', '')
|
||||
|
||||
|
||||
def resolve_annotations(raw_annotations: Dict[str, Type[Any]], module_name: Optional[str]) -> Dict[str, Type[Any]]:
|
||||
"""
|
||||
Partially taken from typing.get_type_hints.
|
||||
|
||||
Resolve string or ForwardRef annotations into type objects if possible.
|
||||
"""
|
||||
base_globals: Optional[Dict[str, Any]] = None
|
||||
if module_name:
|
||||
try:
|
||||
module = sys.modules[module_name]
|
||||
except KeyError:
|
||||
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
|
||||
pass
|
||||
else:
|
||||
base_globals = module.__dict__
|
||||
|
||||
annotations = {}
|
||||
for name, value in raw_annotations.items():
|
||||
if isinstance(value, str):
|
||||
if (3, 10) > sys.version_info >= (3, 9, 8) or sys.version_info >= (3, 10, 1):
|
||||
value = ForwardRef(value, is_argument=False, is_class=True)
|
||||
else:
|
||||
value = ForwardRef(value, is_argument=False)
|
||||
try:
|
||||
value = _eval_type(value, base_globals, None)
|
||||
except NameError:
|
||||
# this is ok, it can be fixed with update_forward_refs
|
||||
pass
|
||||
annotations[name] = value
|
||||
return annotations
|
||||
|
||||
|
||||
def is_callable_type(type_: Type[Any]) -> bool:
|
||||
return type_ is Callable or get_origin(type_) is Callable
|
||||
|
||||
|
||||
def is_literal_type(type_: Type[Any]) -> bool:
|
||||
return Literal is not None and get_origin(type_) is Literal
|
||||
|
||||
|
||||
def literal_values(type_: Type[Any]) -> Tuple[Any, ...]:
|
||||
return get_args(type_)
|
||||
|
||||
|
||||
def all_literal_values(type_: Type[Any]) -> Tuple[Any, ...]:
|
||||
"""
|
||||
This method is used to retrieve all Literal values as
|
||||
Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586)
|
||||
e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]`
|
||||
"""
|
||||
if not is_literal_type(type_):
|
||||
return (type_,)
|
||||
|
||||
values = literal_values(type_)
|
||||
return tuple(x for value in values for x in all_literal_values(value))
|
||||
|
||||
|
||||
def is_namedtuple(type_: Type[Any]) -> bool:
|
||||
"""
|
||||
Check if a given class is a named tuple.
|
||||
It can be either a `typing.NamedTuple` or `collections.namedtuple`
|
||||
"""
|
||||
from .utils import lenient_issubclass
|
||||
|
||||
return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields')
|
||||
|
||||
|
||||
def is_typeddict(type_: Type[Any]) -> bool:
|
||||
"""
|
||||
Check if a given class is a typed dict (from `typing` or `typing_extensions`)
|
||||
In 3.10, there will be a public method (https://docs.python.org/3.10/library/typing.html#typing.is_typeddict)
|
||||
"""
|
||||
from .utils import lenient_issubclass
|
||||
|
||||
return lenient_issubclass(type_, dict) and hasattr(type_, '__total__')
|
||||
|
||||
|
||||
def _check_typeddict_special(type_: Any) -> bool:
|
||||
return type_ is TypedDictRequired or type_ is TypedDictNotRequired
|
||||
|
||||
|
||||
def is_typeddict_special(type_: Any) -> bool:
|
||||
"""
|
||||
Check if type is a TypedDict special form (Required or NotRequired).
|
||||
"""
|
||||
return _check_typeddict_special(type_) or _check_typeddict_special(get_origin(type_))
|
||||
|
||||
|
||||
test_type = NewType('test_type', str)
|
||||
|
||||
|
||||
def is_new_type(type_: Type[Any]) -> bool:
|
||||
"""
|
||||
Check whether type_ was created using typing.NewType
|
||||
"""
|
||||
return isinstance(type_, test_type.__class__) and hasattr(type_, '__supertype__') # type: ignore
|
||||
|
||||
|
||||
def new_type_supertype(type_: Type[Any]) -> Type[Any]:
|
||||
while hasattr(type_, '__supertype__'):
|
||||
type_ = type_.__supertype__
|
||||
return type_
|
||||
|
||||
|
||||
def _check_classvar(v: Optional[Type[Any]]) -> bool:
|
||||
if v is None:
|
||||
return False
|
||||
|
||||
return v.__class__ == ClassVar.__class__ and getattr(v, '_name', None) == 'ClassVar'
|
||||
|
||||
|
||||
def _check_finalvar(v: Optional[Type[Any]]) -> bool:
|
||||
"""
|
||||
Check if a given type is a `typing.Final` type.
|
||||
"""
|
||||
if v is None:
|
||||
return False
|
||||
|
||||
return v.__class__ == Final.__class__ and (sys.version_info < (3, 8) or getattr(v, '_name', None) == 'Final')
|
||||
|
||||
|
||||
def is_classvar(ann_type: Type[Any]) -> bool:
|
||||
if _check_classvar(ann_type) or _check_classvar(get_origin(ann_type)):
|
||||
return True
|
||||
|
||||
# this is an ugly workaround for class vars that contain forward references and are therefore themselves
|
||||
# forward references, see #3679
|
||||
if ann_type.__class__ == ForwardRef and ann_type.__forward_arg__.startswith('ClassVar['):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_finalvar(ann_type: Type[Any]) -> bool:
|
||||
return _check_finalvar(ann_type) or _check_finalvar(get_origin(ann_type))
|
||||
|
||||
|
||||
def update_field_forward_refs(field: 'ModelField', globalns: Any, localns: Any) -> None:
|
||||
"""
|
||||
Try to update ForwardRefs on fields based on this ModelField, globalns and localns.
|
||||
"""
|
||||
prepare = False
|
||||
if field.type_.__class__ == ForwardRef:
|
||||
prepare = True
|
||||
field.type_ = evaluate_forwardref(field.type_, globalns, localns or None)
|
||||
if field.outer_type_.__class__ == ForwardRef:
|
||||
prepare = True
|
||||
field.outer_type_ = evaluate_forwardref(field.outer_type_, globalns, localns or None)
|
||||
if prepare:
|
||||
field.prepare()
|
||||
|
||||
if field.sub_fields:
|
||||
for sub_f in field.sub_fields:
|
||||
update_field_forward_refs(sub_f, globalns=globalns, localns=localns)
|
||||
|
||||
if field.discriminator_key is not None:
|
||||
field.prepare_discriminated_union_sub_fields()
|
||||
|
||||
|
||||
def update_model_forward_refs(
|
||||
model: Type[Any],
|
||||
fields: Iterable['ModelField'],
|
||||
json_encoders: Dict[Union[Type[Any], str, ForwardRef], AnyCallable],
|
||||
localns: 'DictStrAny',
|
||||
exc_to_suppress: Tuple[Type[BaseException], ...] = (),
|
||||
) -> None:
|
||||
"""
|
||||
Try to update model fields ForwardRefs based on model and localns.
|
||||
"""
|
||||
if model.__module__ in sys.modules:
|
||||
globalns = sys.modules[model.__module__].__dict__.copy()
|
||||
else:
|
||||
globalns = {}
|
||||
|
||||
globalns.setdefault(model.__name__, model)
|
||||
|
||||
for f in fields:
|
||||
try:
|
||||
update_field_forward_refs(f, globalns=globalns, localns=localns)
|
||||
except exc_to_suppress:
|
||||
pass
|
||||
|
||||
for key in set(json_encoders.keys()):
|
||||
if isinstance(key, str):
|
||||
fr: ForwardRef = ForwardRef(key)
|
||||
elif isinstance(key, ForwardRef):
|
||||
fr = key
|
||||
else:
|
||||
continue
|
||||
|
||||
try:
|
||||
new_key = evaluate_forwardref(fr, globalns, localns or None)
|
||||
except exc_to_suppress: # pragma: no cover
|
||||
continue
|
||||
|
||||
json_encoders[new_key] = json_encoders.pop(key)
|
||||
|
||||
|
||||
def get_class(type_: Type[Any]) -> Union[None, bool, Type[Any]]:
|
||||
"""
|
||||
Tries to get the class of a Type[T] annotation. Returns True if Type is used
|
||||
without brackets. Otherwise returns None.
|
||||
"""
|
||||
if type_ is type:
|
||||
return True
|
||||
|
||||
if get_origin(type_) is None:
|
||||
return None
|
||||
|
||||
args = get_args(type_)
|
||||
if not args or not isinstance(args[0], type):
|
||||
return True
|
||||
else:
|
||||
return args[0]
|
||||
|
||||
|
||||
def get_sub_types(tp: Any) -> List[Any]:
|
||||
"""
|
||||
Return all the types that are allowed by type `tp`
|
||||
`tp` can be a `Union` of allowed types or an `Annotated` type
|
||||
"""
|
||||
origin = get_origin(tp)
|
||||
if origin is Annotated:
|
||||
return get_sub_types(get_args(tp)[0])
|
||||
elif is_union(origin):
|
||||
return [x for t in get_args(tp) for x in get_sub_types(t)]
|
||||
else:
|
||||
return [tp]
|
841
lib/pydantic/utils.py
Normal file
841
lib/pydantic/utils.py
Normal file
|
@ -0,0 +1,841 @@
|
|||
import keyword
|
||||
import warnings
|
||||
import weakref
|
||||
from collections import OrderedDict, defaultdict, deque
|
||||
from copy import deepcopy
|
||||
from itertools import islice, zip_longest
|
||||
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
AbstractSet,
|
||||
Any,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from .errors import ConfigError
|
||||
from .typing import (
|
||||
NoneType,
|
||||
WithArgsTypes,
|
||||
all_literal_values,
|
||||
display_as_type,
|
||||
get_args,
|
||||
get_origin,
|
||||
is_literal_type,
|
||||
is_union,
|
||||
)
|
||||
from .version import version_info
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from inspect import Signature
|
||||
from pathlib import Path
|
||||
|
||||
from .config import BaseConfig
|
||||
from .dataclasses import Dataclass
|
||||
from .fields import ModelField
|
||||
from .main import BaseModel
|
||||
from .typing import AbstractSetIntStr, DictIntStrAny, IntStr, MappingIntStrAny, ReprArgs
|
||||
|
||||
RichReprResult = Iterable[Union[Any, Tuple[Any], Tuple[str, Any], Tuple[str, Any, Any]]]
|
||||
|
||||
__all__ = (
|
||||
'import_string',
|
||||
'sequence_like',
|
||||
'validate_field_name',
|
||||
'lenient_isinstance',
|
||||
'lenient_issubclass',
|
||||
'in_ipython',
|
||||
'is_valid_identifier',
|
||||
'deep_update',
|
||||
'update_not_none',
|
||||
'almost_equal_floats',
|
||||
'get_model',
|
||||
'to_camel',
|
||||
'is_valid_field',
|
||||
'smart_deepcopy',
|
||||
'PyObjectStr',
|
||||
'Representation',
|
||||
'GetterDict',
|
||||
'ValueItems',
|
||||
'version_info', # required here to match behaviour in v1.3
|
||||
'ClassAttribute',
|
||||
'path_type',
|
||||
'ROOT_KEY',
|
||||
'get_unique_discriminator_alias',
|
||||
'get_discriminator_alias_and_values',
|
||||
'DUNDER_ATTRIBUTES',
|
||||
'LimitedDict',
|
||||
)
|
||||
|
||||
ROOT_KEY = '__root__'
|
||||
# these are types that are returned unchanged by deepcopy
|
||||
IMMUTABLE_NON_COLLECTIONS_TYPES: Set[Type[Any]] = {
|
||||
int,
|
||||
float,
|
||||
complex,
|
||||
str,
|
||||
bool,
|
||||
bytes,
|
||||
type,
|
||||
NoneType,
|
||||
FunctionType,
|
||||
BuiltinFunctionType,
|
||||
LambdaType,
|
||||
weakref.ref,
|
||||
CodeType,
|
||||
# note: including ModuleType will differ from behaviour of deepcopy by not producing error.
|
||||
# It might be not a good idea in general, but considering that this function used only internally
|
||||
# against default values of fields, this will allow to actually have a field with module as default value
|
||||
ModuleType,
|
||||
NotImplemented.__class__,
|
||||
Ellipsis.__class__,
|
||||
}
|
||||
|
||||
# these are types that if empty, might be copied with simple copy() instead of deepcopy()
|
||||
BUILTIN_COLLECTIONS: Set[Type[Any]] = {
|
||||
list,
|
||||
set,
|
||||
tuple,
|
||||
frozenset,
|
||||
dict,
|
||||
OrderedDict,
|
||||
defaultdict,
|
||||
deque,
|
||||
}
|
||||
|
||||
|
||||
def import_string(dotted_path: str) -> Any:
|
||||
"""
|
||||
Stolen approximately from django. Import a dotted module path and return the attribute/class designated by the
|
||||
last name in the path. Raise ImportError if the import fails.
|
||||
"""
|
||||
from importlib import import_module
|
||||
|
||||
try:
|
||||
module_path, class_name = dotted_path.strip(' ').rsplit('.', 1)
|
||||
except ValueError as e:
|
||||
raise ImportError(f'"{dotted_path}" doesn\'t look like a module path') from e
|
||||
|
||||
module = import_module(module_path)
|
||||
try:
|
||||
return getattr(module, class_name)
|
||||
except AttributeError as e:
|
||||
raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute') from e
|
||||
|
||||
|
||||
def truncate(v: Union[str], *, max_len: int = 80) -> str:
|
||||
"""
|
||||
Truncate a value and add a unicode ellipsis (three dots) to the end if it was too long
|
||||
"""
|
||||
warnings.warn('`truncate` is no-longer used by pydantic and is deprecated', DeprecationWarning)
|
||||
if isinstance(v, str) and len(v) > (max_len - 2):
|
||||
# -3 so quote + string + … + quote has correct length
|
||||
return (v[: (max_len - 3)] + '…').__repr__()
|
||||
try:
|
||||
v = v.__repr__()
|
||||
except TypeError:
|
||||
v = v.__class__.__repr__(v) # in case v is a type
|
||||
if len(v) > max_len:
|
||||
v = v[: max_len - 1] + '…'
|
||||
return v
|
||||
|
||||
|
||||
def sequence_like(v: Any) -> bool:
|
||||
return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque))
|
||||
|
||||
|
||||
def validate_field_name(bases: List[Type['BaseModel']], field_name: str) -> None:
|
||||
"""
|
||||
Ensure that the field's name does not shadow an existing attribute of the model.
|
||||
"""
|
||||
for base in bases:
|
||||
if getattr(base, field_name, None):
|
||||
raise NameError(
|
||||
f'Field name "{field_name}" shadows a BaseModel attribute; '
|
||||
f'use a different field name with "alias=\'{field_name}\'".'
|
||||
)
|
||||
|
||||
|
||||
def lenient_isinstance(o: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
|
||||
try:
|
||||
return isinstance(o, class_or_tuple) # type: ignore[arg-type]
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
def lenient_issubclass(cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
|
||||
try:
|
||||
return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
|
||||
except TypeError:
|
||||
if isinstance(cls, WithArgsTypes):
|
||||
return False
|
||||
raise # pragma: no cover
|
||||
|
||||
|
||||
def in_ipython() -> bool:
|
||||
"""
|
||||
Check whether we're in an ipython environment, including jupyter notebooks.
|
||||
"""
|
||||
try:
|
||||
eval('__IPYTHON__')
|
||||
except NameError:
|
||||
return False
|
||||
else: # pragma: no cover
|
||||
return True
|
||||
|
||||
|
||||
def is_valid_identifier(identifier: str) -> bool:
|
||||
"""
|
||||
Checks that a string is a valid identifier and not a Python keyword.
|
||||
:param identifier: The identifier to test.
|
||||
:return: True if the identifier is valid.
|
||||
"""
|
||||
return identifier.isidentifier() and not keyword.iskeyword(identifier)
|
||||
|
||||
|
||||
KeyType = TypeVar('KeyType')
|
||||
|
||||
|
||||
def deep_update(mapping: Dict[KeyType, Any], *updating_mappings: Dict[KeyType, Any]) -> Dict[KeyType, Any]:
|
||||
updated_mapping = mapping.copy()
|
||||
for updating_mapping in updating_mappings:
|
||||
for k, v in updating_mapping.items():
|
||||
if k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict):
|
||||
updated_mapping[k] = deep_update(updated_mapping[k], v)
|
||||
else:
|
||||
updated_mapping[k] = v
|
||||
return updated_mapping
|
||||
|
||||
|
||||
def update_not_none(mapping: Dict[Any, Any], **update: Any) -> None:
|
||||
mapping.update({k: v for k, v in update.items() if v is not None})
|
||||
|
||||
|
||||
def almost_equal_floats(value_1: float, value_2: float, *, delta: float = 1e-8) -> bool:
|
||||
"""
|
||||
Return True if two floats are almost equal
|
||||
"""
|
||||
return abs(value_1 - value_2) <= delta
|
||||
|
||||
|
||||
def generate_model_signature(
|
||||
init: Callable[..., None], fields: Dict[str, 'ModelField'], config: Type['BaseConfig']
|
||||
) -> 'Signature':
|
||||
"""
|
||||
Generate signature for model based on its fields
|
||||
"""
|
||||
from inspect import Parameter, Signature, signature
|
||||
|
||||
from .config import Extra
|
||||
|
||||
present_params = signature(init).parameters.values()
|
||||
merged_params: Dict[str, Parameter] = {}
|
||||
var_kw = None
|
||||
use_var_kw = False
|
||||
|
||||
for param in islice(present_params, 1, None): # skip self arg
|
||||
if param.kind is param.VAR_KEYWORD:
|
||||
var_kw = param
|
||||
continue
|
||||
merged_params[param.name] = param
|
||||
|
||||
if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
|
||||
allow_names = config.allow_population_by_field_name
|
||||
for field_name, field in fields.items():
|
||||
param_name = field.alias
|
||||
if field_name in merged_params or param_name in merged_params:
|
||||
continue
|
||||
elif not is_valid_identifier(param_name):
|
||||
if allow_names and is_valid_identifier(field_name):
|
||||
param_name = field_name
|
||||
else:
|
||||
use_var_kw = True
|
||||
continue
|
||||
|
||||
# TODO: replace annotation with actual expected types once #1055 solved
|
||||
kwargs = {'default': field.default} if not field.required else {}
|
||||
merged_params[param_name] = Parameter(
|
||||
param_name, Parameter.KEYWORD_ONLY, annotation=field.annotation, **kwargs
|
||||
)
|
||||
|
||||
if config.extra is Extra.allow:
|
||||
use_var_kw = True
|
||||
|
||||
if var_kw and use_var_kw:
|
||||
# Make sure the parameter for extra kwargs
|
||||
# does not have the same name as a field
|
||||
default_model_signature = [
|
||||
('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD),
|
||||
('data', Parameter.VAR_KEYWORD),
|
||||
]
|
||||
if [(p.name, p.kind) for p in present_params] == default_model_signature:
|
||||
# if this is the standard model signature, use extra_data as the extra args name
|
||||
var_kw_name = 'extra_data'
|
||||
else:
|
||||
# else start from var_kw
|
||||
var_kw_name = var_kw.name
|
||||
|
||||
# generate a name that's definitely unique
|
||||
while var_kw_name in fields:
|
||||
var_kw_name += '_'
|
||||
merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)
|
||||
|
||||
return Signature(parameters=list(merged_params.values()), return_annotation=None)
|
||||
|
||||
|
||||
def get_model(obj: Union[Type['BaseModel'], Type['Dataclass']]) -> Type['BaseModel']:
|
||||
from .main import BaseModel
|
||||
|
||||
try:
|
||||
model_cls = obj.__pydantic_model__ # type: ignore
|
||||
except AttributeError:
|
||||
model_cls = obj
|
||||
|
||||
if not issubclass(model_cls, BaseModel):
|
||||
raise TypeError('Unsupported type, must be either BaseModel or dataclass')
|
||||
return model_cls
|
||||
|
||||
|
||||
def to_camel(string: str) -> str:
|
||||
return ''.join(word.capitalize() for word in string.split('_'))
|
||||
|
||||
|
||||
def to_lower_camel(string: str) -> str:
|
||||
if len(string) >= 1:
|
||||
pascal_string = to_camel(string)
|
||||
return pascal_string[0].lower() + pascal_string[1:]
|
||||
return string.lower()
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def unique_list(
|
||||
input_list: Union[List[T], Tuple[T, ...]],
|
||||
*,
|
||||
name_factory: Callable[[T], str] = str,
|
||||
) -> List[T]:
|
||||
"""
|
||||
Make a list unique while maintaining order.
|
||||
We update the list if another one with the same name is set
|
||||
(e.g. root validator overridden in subclass)
|
||||
"""
|
||||
result: List[T] = []
|
||||
result_names: List[str] = []
|
||||
for v in input_list:
|
||||
v_name = name_factory(v)
|
||||
if v_name not in result_names:
|
||||
result_names.append(v_name)
|
||||
result.append(v)
|
||||
else:
|
||||
result[result_names.index(v_name)] = v
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class PyObjectStr(str):
|
||||
"""
|
||||
String class where repr doesn't include quotes. Useful with Representation when you want to return a string
|
||||
representation of something that valid (or pseudo-valid) python.
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return str(self)
|
||||
|
||||
|
||||
class Representation:
|
||||
"""
|
||||
Mixin to provide __str__, __repr__, and __pretty__ methods. See #884 for more details.
|
||||
|
||||
__pretty__ is used by [devtools](https://python-devtools.helpmanual.io/) to provide human readable representations
|
||||
of objects.
|
||||
"""
|
||||
|
||||
__slots__: Tuple[str, ...] = tuple()
|
||||
|
||||
def __repr_args__(self) -> 'ReprArgs':
|
||||
"""
|
||||
Returns the attributes to show in __str__, __repr__, and __pretty__ this is generally overridden.
|
||||
|
||||
Can either return:
|
||||
* name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]`
|
||||
* or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]`
|
||||
"""
|
||||
attrs = ((s, getattr(self, s)) for s in self.__slots__)
|
||||
return [(a, v) for a, v in attrs if v is not None]
|
||||
|
||||
def __repr_name__(self) -> str:
|
||||
"""
|
||||
Name of the instance's class, used in __repr__.
|
||||
"""
|
||||
return self.__class__.__name__
|
||||
|
||||
def __repr_str__(self, join_str: str) -> str:
|
||||
return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())
|
||||
|
||||
def __pretty__(self, fmt: Callable[[Any], Any], **kwargs: Any) -> Generator[Any, None, None]:
|
||||
"""
|
||||
Used by devtools (https://python-devtools.helpmanual.io/) to provide a human readable representations of objects
|
||||
"""
|
||||
yield self.__repr_name__() + '('
|
||||
yield 1
|
||||
for name, value in self.__repr_args__():
|
||||
if name is not None:
|
||||
yield name + '='
|
||||
yield fmt(value)
|
||||
yield ','
|
||||
yield 0
|
||||
yield -1
|
||||
yield ')'
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__repr_str__(' ')
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__repr_name__()}({self.__repr_str__(", ")})'
|
||||
|
||||
def __rich_repr__(self) -> 'RichReprResult':
|
||||
"""Get fields for Rich library"""
|
||||
for name, field_repr in self.__repr_args__():
|
||||
if name is None:
|
||||
yield field_repr
|
||||
else:
|
||||
yield name, field_repr
|
||||
|
||||
|
||||
class GetterDict(Representation):
|
||||
"""
|
||||
Hack to make object's smell just enough like dicts for validate_model.
|
||||
|
||||
We can't inherit from Mapping[str, Any] because it upsets cython so we have to implement all methods ourselves.
|
||||
"""
|
||||
|
||||
__slots__ = ('_obj',)
|
||||
|
||||
def __init__(self, obj: Any):
|
||||
self._obj = obj
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
try:
|
||||
return getattr(self._obj, key)
|
||||
except AttributeError as e:
|
||||
raise KeyError(key) from e
|
||||
|
||||
def get(self, key: Any, default: Any = None) -> Any:
|
||||
return getattr(self._obj, key, default)
|
||||
|
||||
def extra_keys(self) -> Set[Any]:
|
||||
"""
|
||||
We don't want to get any other attributes of obj if the model didn't explicitly ask for them
|
||||
"""
|
||||
return set()
|
||||
|
||||
def keys(self) -> List[Any]:
|
||||
"""
|
||||
Keys of the pseudo dictionary, uses a list not set so order information can be maintained like python
|
||||
dictionaries.
|
||||
"""
|
||||
return list(self)
|
||||
|
||||
def values(self) -> List[Any]:
|
||||
return [self[k] for k in self]
|
||||
|
||||
def items(self) -> Iterator[Tuple[str, Any]]:
|
||||
for k in self:
|
||||
yield k, self.get(k)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
for name in dir(self._obj):
|
||||
if not name.startswith('_'):
|
||||
yield name
|
||||
|
||||
def __len__(self) -> int:
|
||||
return sum(1 for _ in self)
|
||||
|
||||
def __contains__(self, item: Any) -> bool:
|
||||
return item in self.keys()
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return dict(self) == dict(other.items())
|
||||
|
||||
def __repr_args__(self) -> 'ReprArgs':
|
||||
return [(None, dict(self))]
|
||||
|
||||
def __repr_name__(self) -> str:
|
||||
return f'GetterDict[{display_as_type(self._obj)}]'
|
||||
|
||||
|
||||
class ValueItems(Representation):
|
||||
"""
|
||||
Class for more convenient calculation of excluded or included fields on values.
|
||||
"""
|
||||
|
||||
__slots__ = ('_items', '_type')
|
||||
|
||||
def __init__(self, value: Any, items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> None:
|
||||
items = self._coerce_items(items)
|
||||
|
||||
if isinstance(value, (list, tuple)):
|
||||
items = self._normalize_indexes(items, len(value))
|
||||
|
||||
self._items: 'MappingIntStrAny' = items
|
||||
|
||||
def is_excluded(self, item: Any) -> bool:
|
||||
"""
|
||||
Check if item is fully excluded.
|
||||
|
||||
:param item: key or index of a value
|
||||
"""
|
||||
return self.is_true(self._items.get(item))
|
||||
|
||||
def is_included(self, item: Any) -> bool:
|
||||
"""
|
||||
Check if value is contained in self._items
|
||||
|
||||
:param item: key or index of value
|
||||
"""
|
||||
return item in self._items
|
||||
|
||||
def for_element(self, e: 'IntStr') -> Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']]:
|
||||
"""
|
||||
:param e: key or index of element on value
|
||||
:return: raw values for element if self._items is dict and contain needed element
|
||||
"""
|
||||
|
||||
item = self._items.get(e)
|
||||
return item if not self.is_true(item) else None
|
||||
|
||||
def _normalize_indexes(self, items: 'MappingIntStrAny', v_length: int) -> 'DictIntStrAny':
|
||||
"""
|
||||
:param items: dict or set of indexes which will be normalized
|
||||
:param v_length: length of sequence indexes of which will be
|
||||
|
||||
>>> self._normalize_indexes({0: True, -2: True, -1: True}, 4)
|
||||
{0: True, 2: True, 3: True}
|
||||
>>> self._normalize_indexes({'__all__': True}, 4)
|
||||
{0: True, 1: True, 2: True, 3: True}
|
||||
"""
|
||||
|
||||
normalized_items: 'DictIntStrAny' = {}
|
||||
all_items = None
|
||||
for i, v in items.items():
|
||||
if not (isinstance(v, Mapping) or isinstance(v, AbstractSet) or self.is_true(v)):
|
||||
raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
|
||||
if i == '__all__':
|
||||
all_items = self._coerce_value(v)
|
||||
continue
|
||||
if not isinstance(i, int):
|
||||
raise TypeError(
|
||||
'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
|
||||
'expected integer keys or keyword "__all__"'
|
||||
)
|
||||
normalized_i = v_length + i if i < 0 else i
|
||||
normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i))
|
||||
|
||||
if not all_items:
|
||||
return normalized_items
|
||||
if self.is_true(all_items):
|
||||
for i in range(v_length):
|
||||
normalized_items.setdefault(i, ...)
|
||||
return normalized_items
|
||||
for i in range(v_length):
|
||||
normalized_item = normalized_items.setdefault(i, {})
|
||||
if not self.is_true(normalized_item):
|
||||
normalized_items[i] = self.merge(all_items, normalized_item)
|
||||
return normalized_items
|
||||
|
||||
@classmethod
|
||||
def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any:
|
||||
"""
|
||||
Merge a ``base`` item with an ``override`` item.
|
||||
|
||||
Both ``base`` and ``override`` are converted to dictionaries if possible.
|
||||
Sets are converted to dictionaries with the sets entries as keys and
|
||||
Ellipsis as values.
|
||||
|
||||
Each key-value pair existing in ``base`` is merged with ``override``,
|
||||
while the rest of the key-value pairs are updated recursively with this function.
|
||||
|
||||
Merging takes place based on the "union" of keys if ``intersect`` is
|
||||
set to ``False`` (default) and on the intersection of keys if
|
||||
``intersect`` is set to ``True``.
|
||||
"""
|
||||
override = cls._coerce_value(override)
|
||||
base = cls._coerce_value(base)
|
||||
if override is None:
|
||||
return base
|
||||
if cls.is_true(base) or base is None:
|
||||
return override
|
||||
if cls.is_true(override):
|
||||
return base if intersect else override
|
||||
|
||||
# intersection or union of keys while preserving ordering:
|
||||
if intersect:
|
||||
merge_keys = [k for k in base if k in override] + [k for k in override if k in base]
|
||||
else:
|
||||
merge_keys = list(base) + [k for k in override if k not in base]
|
||||
|
||||
merged: 'DictIntStrAny' = {}
|
||||
for k in merge_keys:
|
||||
merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect)
|
||||
if merged_item is not None:
|
||||
merged[k] = merged_item
|
||||
|
||||
return merged
|
||||
|
||||
@staticmethod
|
||||
def _coerce_items(items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> 'MappingIntStrAny':
|
||||
if isinstance(items, Mapping):
|
||||
pass
|
||||
elif isinstance(items, AbstractSet):
|
||||
items = dict.fromkeys(items, ...)
|
||||
else:
|
||||
class_name = getattr(items, '__class__', '???')
|
||||
assert_never(
|
||||
items,
|
||||
f'Unexpected type of exclude value {class_name}',
|
||||
)
|
||||
return items
|
||||
|
||||
@classmethod
|
||||
def _coerce_value(cls, value: Any) -> Any:
|
||||
if value is None or cls.is_true(value):
|
||||
return value
|
||||
return cls._coerce_items(value)
|
||||
|
||||
@staticmethod
|
||||
def is_true(v: Any) -> bool:
|
||||
return v is True or v is ...
|
||||
|
||||
def __repr_args__(self) -> 'ReprArgs':
|
||||
return [(None, self._items)]
|
||||
|
||||
|
||||
class ClassAttribute:
|
||||
"""
|
||||
Hide class attribute from its instances
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'name',
|
||||
'value',
|
||||
)
|
||||
|
||||
def __init__(self, name: str, value: Any) -> None:
|
||||
self.name = name
|
||||
self.value = value
|
||||
|
||||
def __get__(self, instance: Any, owner: Type[Any]) -> None:
|
||||
if instance is None:
|
||||
return self.value
|
||||
raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only')
|
||||
|
||||
|
||||
path_types = {
|
||||
'is_dir': 'directory',
|
||||
'is_file': 'file',
|
||||
'is_mount': 'mount point',
|
||||
'is_symlink': 'symlink',
|
||||
'is_block_device': 'block device',
|
||||
'is_char_device': 'char device',
|
||||
'is_fifo': 'FIFO',
|
||||
'is_socket': 'socket',
|
||||
}
|
||||
|
||||
|
||||
def path_type(p: 'Path') -> str:
|
||||
"""
|
||||
Find out what sort of thing a path is.
|
||||
"""
|
||||
assert p.exists(), 'path does not exist'
|
||||
for method, name in path_types.items():
|
||||
if getattr(p, method)():
|
||||
return name
|
||||
|
||||
return 'unknown'
|
||||
|
||||
|
||||
Obj = TypeVar('Obj')
|
||||
|
||||
|
||||
def smart_deepcopy(obj: Obj) -> Obj:
|
||||
"""
|
||||
Return type as is for immutable built-in types
|
||||
Use obj.copy() for built-in empty collections
|
||||
Use copy.deepcopy() for non-empty collections and unknown objects
|
||||
"""
|
||||
|
||||
obj_type = obj.__class__
|
||||
if obj_type in IMMUTABLE_NON_COLLECTIONS_TYPES:
|
||||
return obj # fastest case: obj is immutable and not collection therefore will not be copied anyway
|
||||
try:
|
||||
if not obj and obj_type in BUILTIN_COLLECTIONS:
|
||||
# faster way for empty collections, no need to copy its members
|
||||
return obj if obj_type is tuple else obj.copy() # type: ignore # tuple doesn't have copy method
|
||||
except (TypeError, ValueError, RuntimeError):
|
||||
# do we really dare to catch ALL errors? Seems a bit risky
|
||||
pass
|
||||
|
||||
return deepcopy(obj) # slowest way when we actually might need a deepcopy
|
||||
|
||||
|
||||
def is_valid_field(name: str) -> bool:
|
||||
if not name.startswith('_'):
|
||||
return True
|
||||
return ROOT_KEY == name
|
||||
|
||||
|
||||
DUNDER_ATTRIBUTES = {
|
||||
'__annotations__',
|
||||
'__classcell__',
|
||||
'__doc__',
|
||||
'__module__',
|
||||
'__orig_bases__',
|
||||
'__orig_class__',
|
||||
'__qualname__',
|
||||
}
|
||||
|
||||
|
||||
def is_valid_private_name(name: str) -> bool:
|
||||
return not is_valid_field(name) and name not in DUNDER_ATTRIBUTES
|
||||
|
||||
|
||||
_EMPTY = object()
|
||||
|
||||
|
||||
def all_identical(left: Iterable[Any], right: Iterable[Any]) -> bool:
|
||||
"""
|
||||
Check that the items of `left` are the same objects as those in `right`.
|
||||
|
||||
>>> a, b = object(), object()
|
||||
>>> all_identical([a, b, a], [a, b, a])
|
||||
True
|
||||
>>> all_identical([a, b, [a]], [a, b, [a]]) # new list object, while "equal" is not "identical"
|
||||
False
|
||||
"""
|
||||
for left_item, right_item in zip_longest(left, right, fillvalue=_EMPTY):
|
||||
if left_item is not right_item:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def assert_never(obj: NoReturn, msg: str) -> NoReturn:
|
||||
"""
|
||||
Helper to make sure that we have covered all possible types.
|
||||
|
||||
This is mostly useful for ``mypy``, docs:
|
||||
https://mypy.readthedocs.io/en/latest/literal_types.html#exhaustive-checks
|
||||
"""
|
||||
raise TypeError(msg)
|
||||
|
||||
|
||||
def get_unique_discriminator_alias(all_aliases: Collection[str], discriminator_key: str) -> str:
|
||||
"""Validate that all aliases are the same and if that's the case return the alias"""
|
||||
unique_aliases = set(all_aliases)
|
||||
if len(unique_aliases) > 1:
|
||||
raise ConfigError(
|
||||
f'Aliases for discriminator {discriminator_key!r} must be the same (got {", ".join(sorted(all_aliases))})'
|
||||
)
|
||||
return unique_aliases.pop()
|
||||
|
||||
|
||||
def get_discriminator_alias_and_values(tp: Any, discriminator_key: str) -> Tuple[str, Tuple[str, ...]]:
|
||||
"""
|
||||
Get alias and all valid values in the `Literal` type of the discriminator field
|
||||
`tp` can be a `BaseModel` class or directly an `Annotated` `Union` of many.
|
||||
"""
|
||||
is_root_model = getattr(tp, '__custom_root_type__', False)
|
||||
|
||||
if get_origin(tp) is Annotated:
|
||||
tp = get_args(tp)[0]
|
||||
|
||||
if hasattr(tp, '__pydantic_model__'):
|
||||
tp = tp.__pydantic_model__
|
||||
|
||||
if is_union(get_origin(tp)):
|
||||
alias, all_values = _get_union_alias_and_all_values(tp, discriminator_key)
|
||||
return alias, tuple(v for values in all_values for v in values)
|
||||
elif is_root_model:
|
||||
union_type = tp.__fields__[ROOT_KEY].type_
|
||||
alias, all_values = _get_union_alias_and_all_values(union_type, discriminator_key)
|
||||
|
||||
if len(set(all_values)) > 1:
|
||||
raise ConfigError(
|
||||
f'Field {discriminator_key!r} is not the same for all submodels of {display_as_type(tp)!r}'
|
||||
)
|
||||
|
||||
return alias, all_values[0]
|
||||
|
||||
else:
|
||||
try:
|
||||
t_discriminator_type = tp.__fields__[discriminator_key].type_
|
||||
except AttributeError as e:
|
||||
raise TypeError(f'Type {tp.__name__!r} is not a valid `BaseModel` or `dataclass`') from e
|
||||
except KeyError as e:
|
||||
raise ConfigError(f'Model {tp.__name__!r} needs a discriminator field for key {discriminator_key!r}') from e
|
||||
|
||||
if not is_literal_type(t_discriminator_type):
|
||||
raise ConfigError(f'Field {discriminator_key!r} of model {tp.__name__!r} needs to be a `Literal`')
|
||||
|
||||
return tp.__fields__[discriminator_key].alias, all_literal_values(t_discriminator_type)
|
||||
|
||||
|
||||
def _get_union_alias_and_all_values(
|
||||
union_type: Type[Any], discriminator_key: str
|
||||
) -> Tuple[str, Tuple[Tuple[str, ...], ...]]:
|
||||
zipped_aliases_values = [get_discriminator_alias_and_values(t, discriminator_key) for t in get_args(union_type)]
|
||||
# unzip: [('alias_a',('v1', 'v2)), ('alias_b', ('v3',))] => [('alias_a', 'alias_b'), (('v1', 'v2'), ('v3',))]
|
||||
all_aliases, all_values = zip(*zipped_aliases_values)
|
||||
return get_unique_discriminator_alias(all_aliases, discriminator_key), all_values
|
||||
|
||||
|
||||
KT = TypeVar('KT')
|
||||
VT = TypeVar('VT')
|
||||
if TYPE_CHECKING:
|
||||
# Annoying inheriting from `MutableMapping` and `dict` breaks cython, hence this work around
|
||||
class LimitedDict(dict, MutableMapping[KT, VT]): # type: ignore[type-arg]
|
||||
def __init__(self, size_limit: int = 1000):
|
||||
...
|
||||
|
||||
else:
|
||||
|
||||
class LimitedDict(dict):
|
||||
"""
|
||||
Limit the size/length of a dict used for caching to avoid unlimited increase in memory usage.
|
||||
|
||||
Since the dict is ordered, and we always remove elements from the beginning, this is effectively a FIFO cache.
|
||||
|
||||
Annoying inheriting from `MutableMapping` breaks cython.
|
||||
"""
|
||||
|
||||
def __init__(self, size_limit: int = 1000):
|
||||
self.size_limit = size_limit
|
||||
super().__init__()
|
||||
|
||||
def __setitem__(self, __key: Any, __value: Any) -> None:
|
||||
super().__setitem__(__key, __value)
|
||||
if len(self) > self.size_limit:
|
||||
excess = len(self) - self.size_limit + self.size_limit // 10
|
||||
to_remove = list(self.keys())[:excess]
|
||||
for key in to_remove:
|
||||
del self[key]
|
||||
|
||||
def __class_getitem__(cls, *args: Any) -> Any:
|
||||
# to avoid errors with 3.7
|
||||
pass
|
765
lib/pydantic/validators.py
Normal file
765
lib/pydantic/validators.py
Normal file
|
@ -0,0 +1,765 @@
|
|||
import math
|
||||
import re
|
||||
from collections import OrderedDict, deque
|
||||
from collections.abc import Hashable as CollectionsHashable
|
||||
from datetime import date, datetime, time, timedelta
|
||||
from decimal import Decimal, DecimalException
|
||||
from enum import Enum, IntEnum
|
||||
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Deque,
|
||||
Dict,
|
||||
ForwardRef,
|
||||
FrozenSet,
|
||||
Generator,
|
||||
Hashable,
|
||||
List,
|
||||
NamedTuple,
|
||||
Pattern,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from uuid import UUID
|
||||
|
||||
from . import errors
|
||||
from .datetime_parse import parse_date, parse_datetime, parse_duration, parse_time
|
||||
from .typing import (
|
||||
AnyCallable,
|
||||
all_literal_values,
|
||||
display_as_type,
|
||||
get_class,
|
||||
is_callable_type,
|
||||
is_literal_type,
|
||||
is_namedtuple,
|
||||
is_none_type,
|
||||
is_typeddict,
|
||||
)
|
||||
from .utils import almost_equal_floats, lenient_issubclass, sequence_like
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Literal, TypedDict
|
||||
|
||||
from .config import BaseConfig
|
||||
from .fields import ModelField
|
||||
from .types import ConstrainedDecimal, ConstrainedFloat, ConstrainedInt
|
||||
|
||||
ConstrainedNumber = Union[ConstrainedDecimal, ConstrainedFloat, ConstrainedInt]
|
||||
AnyOrderedDict = OrderedDict[Any, Any]
|
||||
Number = Union[int, float, Decimal]
|
||||
StrBytes = Union[str, bytes]
|
||||
|
||||
|
||||
def str_validator(v: Any) -> Union[str]:
|
||||
if isinstance(v, str):
|
||||
if isinstance(v, Enum):
|
||||
return v.value
|
||||
else:
|
||||
return v
|
||||
elif isinstance(v, (float, int, Decimal)):
|
||||
# is there anything else we want to add here? If you think so, create an issue.
|
||||
return str(v)
|
||||
elif isinstance(v, (bytes, bytearray)):
|
||||
return v.decode()
|
||||
else:
|
||||
raise errors.StrError()
|
||||
|
||||
|
||||
def strict_str_validator(v: Any) -> Union[str]:
|
||||
if isinstance(v, str) and not isinstance(v, Enum):
|
||||
return v
|
||||
raise errors.StrError()
|
||||
|
||||
|
||||
def bytes_validator(v: Any) -> Union[bytes]:
|
||||
if isinstance(v, bytes):
|
||||
return v
|
||||
elif isinstance(v, bytearray):
|
||||
return bytes(v)
|
||||
elif isinstance(v, str):
|
||||
return v.encode()
|
||||
elif isinstance(v, (float, int, Decimal)):
|
||||
return str(v).encode()
|
||||
else:
|
||||
raise errors.BytesError()
|
||||
|
||||
|
||||
def strict_bytes_validator(v: Any) -> Union[bytes]:
|
||||
if isinstance(v, bytes):
|
||||
return v
|
||||
elif isinstance(v, bytearray):
|
||||
return bytes(v)
|
||||
else:
|
||||
raise errors.BytesError()
|
||||
|
||||
|
||||
BOOL_FALSE = {0, '0', 'off', 'f', 'false', 'n', 'no'}
|
||||
BOOL_TRUE = {1, '1', 'on', 't', 'true', 'y', 'yes'}
|
||||
|
||||
|
||||
def bool_validator(v: Any) -> bool:
|
||||
if v is True or v is False:
|
||||
return v
|
||||
if isinstance(v, bytes):
|
||||
v = v.decode()
|
||||
if isinstance(v, str):
|
||||
v = v.lower()
|
||||
try:
|
||||
if v in BOOL_TRUE:
|
||||
return True
|
||||
if v in BOOL_FALSE:
|
||||
return False
|
||||
except TypeError:
|
||||
raise errors.BoolError()
|
||||
raise errors.BoolError()
|
||||
|
||||
|
||||
# matches the default limit cpython, see https://github.com/python/cpython/pull/96500
|
||||
max_str_int = 4_300
|
||||
|
||||
|
||||
def int_validator(v: Any) -> int:
|
||||
if isinstance(v, int) and not (v is True or v is False):
|
||||
return v
|
||||
|
||||
# see https://github.com/pydantic/pydantic/issues/1477 and in turn, https://github.com/python/cpython/issues/95778
|
||||
# this check should be unnecessary once patch releases are out for 3.7, 3.8, 3.9 and 3.10
|
||||
# but better to check here until then.
|
||||
# NOTICE: this does not fully protect user from the DOS risk since the standard library JSON implementation
|
||||
# (and other std lib modules like xml) use `int()` and are likely called before this, the best workaround is to
|
||||
# 1. update to the latest patch release of python once released, 2. use a different JSON library like ujson
|
||||
if isinstance(v, (str, bytes, bytearray)) and len(v) > max_str_int:
|
||||
raise errors.IntegerError()
|
||||
|
||||
try:
|
||||
return int(v)
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
raise errors.IntegerError()
|
||||
|
||||
|
||||
def strict_int_validator(v: Any) -> int:
|
||||
if isinstance(v, int) and not (v is True or v is False):
|
||||
return v
|
||||
raise errors.IntegerError()
|
||||
|
||||
|
||||
def float_validator(v: Any) -> float:
|
||||
if isinstance(v, float):
|
||||
return v
|
||||
|
||||
try:
|
||||
return float(v)
|
||||
except (TypeError, ValueError):
|
||||
raise errors.FloatError()
|
||||
|
||||
|
||||
def strict_float_validator(v: Any) -> float:
|
||||
if isinstance(v, float):
|
||||
return v
|
||||
raise errors.FloatError()
|
||||
|
||||
|
||||
def float_finite_validator(v: 'Number', field: 'ModelField', config: 'BaseConfig') -> 'Number':
|
||||
allow_inf_nan = getattr(field.type_, 'allow_inf_nan', None)
|
||||
if allow_inf_nan is None:
|
||||
allow_inf_nan = config.allow_inf_nan
|
||||
|
||||
if allow_inf_nan is False and (math.isnan(v) or math.isinf(v)):
|
||||
raise errors.NumberNotFiniteError()
|
||||
return v
|
||||
|
||||
|
||||
def number_multiple_validator(v: 'Number', field: 'ModelField') -> 'Number':
|
||||
field_type: ConstrainedNumber = field.type_
|
||||
if field_type.multiple_of is not None:
|
||||
mod = float(v) / float(field_type.multiple_of) % 1
|
||||
if not almost_equal_floats(mod, 0.0) and not almost_equal_floats(mod, 1.0):
|
||||
raise errors.NumberNotMultipleError(multiple_of=field_type.multiple_of)
|
||||
return v
|
||||
|
||||
|
||||
def number_size_validator(v: 'Number', field: 'ModelField') -> 'Number':
|
||||
field_type: ConstrainedNumber = field.type_
|
||||
if field_type.gt is not None and not v > field_type.gt:
|
||||
raise errors.NumberNotGtError(limit_value=field_type.gt)
|
||||
elif field_type.ge is not None and not v >= field_type.ge:
|
||||
raise errors.NumberNotGeError(limit_value=field_type.ge)
|
||||
|
||||
if field_type.lt is not None and not v < field_type.lt:
|
||||
raise errors.NumberNotLtError(limit_value=field_type.lt)
|
||||
if field_type.le is not None and not v <= field_type.le:
|
||||
raise errors.NumberNotLeError(limit_value=field_type.le)
|
||||
|
||||
return v
|
||||
|
||||
|
||||
def constant_validator(v: 'Any', field: 'ModelField') -> 'Any':
|
||||
"""Validate ``const`` fields.
|
||||
|
||||
The value provided for a ``const`` field must be equal to the default value
|
||||
of the field. This is to support the keyword of the same name in JSON
|
||||
Schema.
|
||||
"""
|
||||
if v != field.default:
|
||||
raise errors.WrongConstantError(given=v, permitted=[field.default])
|
||||
|
||||
return v
|
||||
|
||||
|
||||
def anystr_length_validator(v: 'StrBytes', config: 'BaseConfig') -> 'StrBytes':
|
||||
v_len = len(v)
|
||||
|
||||
min_length = config.min_anystr_length
|
||||
if v_len < min_length:
|
||||
raise errors.AnyStrMinLengthError(limit_value=min_length)
|
||||
|
||||
max_length = config.max_anystr_length
|
||||
if max_length is not None and v_len > max_length:
|
||||
raise errors.AnyStrMaxLengthError(limit_value=max_length)
|
||||
|
||||
return v
|
||||
|
||||
|
||||
def anystr_strip_whitespace(v: 'StrBytes') -> 'StrBytes':
|
||||
return v.strip()
|
||||
|
||||
|
||||
def anystr_upper(v: 'StrBytes') -> 'StrBytes':
|
||||
return v.upper()
|
||||
|
||||
|
||||
def anystr_lower(v: 'StrBytes') -> 'StrBytes':
|
||||
return v.lower()
|
||||
|
||||
|
||||
def ordered_dict_validator(v: Any) -> 'AnyOrderedDict':
|
||||
if isinstance(v, OrderedDict):
|
||||
return v
|
||||
|
||||
try:
|
||||
return OrderedDict(v)
|
||||
except (TypeError, ValueError):
|
||||
raise errors.DictError()
|
||||
|
||||
|
||||
def dict_validator(v: Any) -> Dict[Any, Any]:
|
||||
if isinstance(v, dict):
|
||||
return v
|
||||
|
||||
try:
|
||||
return dict(v)
|
||||
except (TypeError, ValueError):
|
||||
raise errors.DictError()
|
||||
|
||||
|
||||
def list_validator(v: Any) -> List[Any]:
|
||||
if isinstance(v, list):
|
||||
return v
|
||||
elif sequence_like(v):
|
||||
return list(v)
|
||||
else:
|
||||
raise errors.ListError()
|
||||
|
||||
|
||||
def tuple_validator(v: Any) -> Tuple[Any, ...]:
|
||||
if isinstance(v, tuple):
|
||||
return v
|
||||
elif sequence_like(v):
|
||||
return tuple(v)
|
||||
else:
|
||||
raise errors.TupleError()
|
||||
|
||||
|
||||
def set_validator(v: Any) -> Set[Any]:
|
||||
if isinstance(v, set):
|
||||
return v
|
||||
elif sequence_like(v):
|
||||
return set(v)
|
||||
else:
|
||||
raise errors.SetError()
|
||||
|
||||
|
||||
def frozenset_validator(v: Any) -> FrozenSet[Any]:
|
||||
if isinstance(v, frozenset):
|
||||
return v
|
||||
elif sequence_like(v):
|
||||
return frozenset(v)
|
||||
else:
|
||||
raise errors.FrozenSetError()
|
||||
|
||||
|
||||
def deque_validator(v: Any) -> Deque[Any]:
|
||||
if isinstance(v, deque):
|
||||
return v
|
||||
elif sequence_like(v):
|
||||
return deque(v)
|
||||
else:
|
||||
raise errors.DequeError()
|
||||
|
||||
|
||||
def enum_member_validator(v: Any, field: 'ModelField', config: 'BaseConfig') -> Enum:
|
||||
try:
|
||||
enum_v = field.type_(v)
|
||||
except ValueError:
|
||||
# field.type_ should be an enum, so will be iterable
|
||||
raise errors.EnumMemberError(enum_values=list(field.type_))
|
||||
return enum_v.value if config.use_enum_values else enum_v
|
||||
|
||||
|
||||
def uuid_validator(v: Any, field: 'ModelField') -> UUID:
|
||||
try:
|
||||
if isinstance(v, str):
|
||||
v = UUID(v)
|
||||
elif isinstance(v, (bytes, bytearray)):
|
||||
try:
|
||||
v = UUID(v.decode())
|
||||
except ValueError:
|
||||
# 16 bytes in big-endian order as the bytes argument fail
|
||||
# the above check
|
||||
v = UUID(bytes=v)
|
||||
except ValueError:
|
||||
raise errors.UUIDError()
|
||||
|
||||
if not isinstance(v, UUID):
|
||||
raise errors.UUIDError()
|
||||
|
||||
required_version = getattr(field.type_, '_required_version', None)
|
||||
if required_version and v.version != required_version:
|
||||
raise errors.UUIDVersionError(required_version=required_version)
|
||||
|
||||
return v
|
||||
|
||||
|
||||
def decimal_validator(v: Any) -> Decimal:
|
||||
if isinstance(v, Decimal):
|
||||
return v
|
||||
elif isinstance(v, (bytes, bytearray)):
|
||||
v = v.decode()
|
||||
|
||||
v = str(v).strip()
|
||||
|
||||
try:
|
||||
v = Decimal(v)
|
||||
except DecimalException:
|
||||
raise errors.DecimalError()
|
||||
|
||||
if not v.is_finite():
|
||||
raise errors.DecimalIsNotFiniteError()
|
||||
|
||||
return v
|
||||
|
||||
|
||||
def hashable_validator(v: Any) -> Hashable:
|
||||
if isinstance(v, Hashable):
|
||||
return v
|
||||
|
||||
raise errors.HashableError()
|
||||
|
||||
|
||||
def ip_v4_address_validator(v: Any) -> IPv4Address:
|
||||
if isinstance(v, IPv4Address):
|
||||
return v
|
||||
|
||||
try:
|
||||
return IPv4Address(v)
|
||||
except ValueError:
|
||||
raise errors.IPv4AddressError()
|
||||
|
||||
|
||||
def ip_v6_address_validator(v: Any) -> IPv6Address:
|
||||
if isinstance(v, IPv6Address):
|
||||
return v
|
||||
|
||||
try:
|
||||
return IPv6Address(v)
|
||||
except ValueError:
|
||||
raise errors.IPv6AddressError()
|
||||
|
||||
|
||||
def ip_v4_network_validator(v: Any) -> IPv4Network:
|
||||
"""
|
||||
Assume IPv4Network initialised with a default ``strict`` argument
|
||||
|
||||
See more:
|
||||
https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network
|
||||
"""
|
||||
if isinstance(v, IPv4Network):
|
||||
return v
|
||||
|
||||
try:
|
||||
return IPv4Network(v)
|
||||
except ValueError:
|
||||
raise errors.IPv4NetworkError()
|
||||
|
||||
|
||||
def ip_v6_network_validator(v: Any) -> IPv6Network:
|
||||
"""
|
||||
Assume IPv6Network initialised with a default ``strict`` argument
|
||||
|
||||
See more:
|
||||
https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network
|
||||
"""
|
||||
if isinstance(v, IPv6Network):
|
||||
return v
|
||||
|
||||
try:
|
||||
return IPv6Network(v)
|
||||
except ValueError:
|
||||
raise errors.IPv6NetworkError()
|
||||
|
||||
|
||||
def ip_v4_interface_validator(v: Any) -> IPv4Interface:
|
||||
if isinstance(v, IPv4Interface):
|
||||
return v
|
||||
|
||||
try:
|
||||
return IPv4Interface(v)
|
||||
except ValueError:
|
||||
raise errors.IPv4InterfaceError()
|
||||
|
||||
|
||||
def ip_v6_interface_validator(v: Any) -> IPv6Interface:
|
||||
if isinstance(v, IPv6Interface):
|
||||
return v
|
||||
|
||||
try:
|
||||
return IPv6Interface(v)
|
||||
except ValueError:
|
||||
raise errors.IPv6InterfaceError()
|
||||
|
||||
|
||||
def path_validator(v: Any) -> Path:
|
||||
if isinstance(v, Path):
|
||||
return v
|
||||
|
||||
try:
|
||||
return Path(v)
|
||||
except TypeError:
|
||||
raise errors.PathError()
|
||||
|
||||
|
||||
def path_exists_validator(v: Any) -> Path:
|
||||
if not v.exists():
|
||||
raise errors.PathNotExistsError(path=v)
|
||||
|
||||
return v
|
||||
|
||||
|
||||
def callable_validator(v: Any) -> AnyCallable:
|
||||
"""
|
||||
Perform a simple check if the value is callable.
|
||||
|
||||
Note: complete matching of argument type hints and return types is not performed
|
||||
"""
|
||||
if callable(v):
|
||||
return v
|
||||
|
||||
raise errors.CallableError(value=v)
|
||||
|
||||
|
||||
def enum_validator(v: Any) -> Enum:
|
||||
if isinstance(v, Enum):
|
||||
return v
|
||||
|
||||
raise errors.EnumError(value=v)
|
||||
|
||||
|
||||
def int_enum_validator(v: Any) -> IntEnum:
|
||||
if isinstance(v, IntEnum):
|
||||
return v
|
||||
|
||||
raise errors.IntEnumError(value=v)
|
||||
|
||||
|
||||
def make_literal_validator(type_: Any) -> Callable[[Any], Any]:
|
||||
permitted_choices = all_literal_values(type_)
|
||||
|
||||
# To have a O(1) complexity and still return one of the values set inside the `Literal`,
|
||||
# we create a dict with the set values (a set causes some problems with the way intersection works).
|
||||
# In some cases the set value and checked value can indeed be different (see `test_literal_validator_str_enum`)
|
||||
allowed_choices = {v: v for v in permitted_choices}
|
||||
|
||||
def literal_validator(v: Any) -> Any:
|
||||
try:
|
||||
return allowed_choices[v]
|
||||
except KeyError:
|
||||
raise errors.WrongConstantError(given=v, permitted=permitted_choices)
|
||||
|
||||
return literal_validator
|
||||
|
||||
|
||||
def constr_length_validator(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes':
|
||||
v_len = len(v)
|
||||
|
||||
min_length = field.type_.min_length if field.type_.min_length is not None else config.min_anystr_length
|
||||
if v_len < min_length:
|
||||
raise errors.AnyStrMinLengthError(limit_value=min_length)
|
||||
|
||||
max_length = field.type_.max_length if field.type_.max_length is not None else config.max_anystr_length
|
||||
if max_length is not None and v_len > max_length:
|
||||
raise errors.AnyStrMaxLengthError(limit_value=max_length)
|
||||
|
||||
return v
|
||||
|
||||
|
||||
def constr_strip_whitespace(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes':
|
||||
strip_whitespace = field.type_.strip_whitespace or config.anystr_strip_whitespace
|
||||
if strip_whitespace:
|
||||
v = v.strip()
|
||||
|
||||
return v
|
||||
|
||||
|
||||
def constr_upper(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes':
|
||||
upper = field.type_.to_upper or config.anystr_upper
|
||||
if upper:
|
||||
v = v.upper()
|
||||
|
||||
return v
|
||||
|
||||
|
||||
def constr_lower(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes':
|
||||
lower = field.type_.to_lower or config.anystr_lower
|
||||
if lower:
|
||||
v = v.lower()
|
||||
return v
|
||||
|
||||
|
||||
def validate_json(v: Any, config: 'BaseConfig') -> Any:
|
||||
if v is None:
|
||||
# pass None through to other validators
|
||||
return v
|
||||
try:
|
||||
return config.json_loads(v) # type: ignore
|
||||
except ValueError:
|
||||
raise errors.JsonError()
|
||||
except TypeError:
|
||||
raise errors.JsonTypeError()
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def make_arbitrary_type_validator(type_: Type[T]) -> Callable[[T], T]:
|
||||
def arbitrary_type_validator(v: Any) -> T:
|
||||
if isinstance(v, type_):
|
||||
return v
|
||||
raise errors.ArbitraryTypeError(expected_arbitrary_type=type_)
|
||||
|
||||
return arbitrary_type_validator
|
||||
|
||||
|
||||
def make_class_validator(type_: Type[T]) -> Callable[[Any], Type[T]]:
|
||||
def class_validator(v: Any) -> Type[T]:
|
||||
if lenient_issubclass(v, type_):
|
||||
return v
|
||||
raise errors.SubclassError(expected_class=type_)
|
||||
|
||||
return class_validator
|
||||
|
||||
|
||||
def any_class_validator(v: Any) -> Type[T]:
|
||||
if isinstance(v, type):
|
||||
return v
|
||||
raise errors.ClassError()
|
||||
|
||||
|
||||
def none_validator(v: Any) -> 'Literal[None]':
|
||||
if v is None:
|
||||
return v
|
||||
raise errors.NotNoneError()
|
||||
|
||||
|
||||
def pattern_validator(v: Any) -> Pattern[str]:
|
||||
if isinstance(v, Pattern):
|
||||
return v
|
||||
|
||||
str_value = str_validator(v)
|
||||
|
||||
try:
|
||||
return re.compile(str_value)
|
||||
except re.error:
|
||||
raise errors.PatternError()
|
||||
|
||||
|
||||
NamedTupleT = TypeVar('NamedTupleT', bound=NamedTuple)
|
||||
|
||||
|
||||
def make_namedtuple_validator(
|
||||
namedtuple_cls: Type[NamedTupleT], config: Type['BaseConfig']
|
||||
) -> Callable[[Tuple[Any, ...]], NamedTupleT]:
|
||||
from .annotated_types import create_model_from_namedtuple
|
||||
|
||||
NamedTupleModel = create_model_from_namedtuple(
|
||||
namedtuple_cls,
|
||||
__config__=config,
|
||||
__module__=namedtuple_cls.__module__,
|
||||
)
|
||||
namedtuple_cls.__pydantic_model__ = NamedTupleModel # type: ignore[attr-defined]
|
||||
|
||||
def namedtuple_validator(values: Tuple[Any, ...]) -> NamedTupleT:
|
||||
annotations = NamedTupleModel.__annotations__
|
||||
|
||||
if len(values) > len(annotations):
|
||||
raise errors.ListMaxLengthError(limit_value=len(annotations))
|
||||
|
||||
dict_values: Dict[str, Any] = dict(zip(annotations, values))
|
||||
validated_dict_values: Dict[str, Any] = dict(NamedTupleModel(**dict_values))
|
||||
return namedtuple_cls(**validated_dict_values)
|
||||
|
||||
return namedtuple_validator
|
||||
|
||||
|
||||
def make_typeddict_validator(
|
||||
typeddict_cls: Type['TypedDict'], config: Type['BaseConfig'] # type: ignore[valid-type]
|
||||
) -> Callable[[Any], Dict[str, Any]]:
|
||||
from .annotated_types import create_model_from_typeddict
|
||||
|
||||
TypedDictModel = create_model_from_typeddict(
|
||||
typeddict_cls,
|
||||
__config__=config,
|
||||
__module__=typeddict_cls.__module__,
|
||||
)
|
||||
typeddict_cls.__pydantic_model__ = TypedDictModel # type: ignore[attr-defined]
|
||||
|
||||
def typeddict_validator(values: 'TypedDict') -> Dict[str, Any]: # type: ignore[valid-type]
|
||||
return TypedDictModel.parse_obj(values).dict(exclude_unset=True)
|
||||
|
||||
return typeddict_validator
|
||||
|
||||
|
||||
class IfConfig:
|
||||
def __init__(self, validator: AnyCallable, *config_attr_names: str, ignored_value: Any = False) -> None:
|
||||
self.validator = validator
|
||||
self.config_attr_names = config_attr_names
|
||||
self.ignored_value = ignored_value
|
||||
|
||||
def check(self, config: Type['BaseConfig']) -> bool:
|
||||
return any(getattr(config, name) not in {None, self.ignored_value} for name in self.config_attr_names)
|
||||
|
||||
|
||||
# order is important here, for example: bool is a subclass of int so has to come first, datetime before date same,
|
||||
# IPv4Interface before IPv4Address, etc
|
||||
_VALIDATORS: List[Tuple[Type[Any], List[Any]]] = [
|
||||
(IntEnum, [int_validator, enum_member_validator]),
|
||||
(Enum, [enum_member_validator]),
|
||||
(
|
||||
str,
|
||||
[
|
||||
str_validator,
|
||||
IfConfig(anystr_strip_whitespace, 'anystr_strip_whitespace'),
|
||||
IfConfig(anystr_upper, 'anystr_upper'),
|
||||
IfConfig(anystr_lower, 'anystr_lower'),
|
||||
IfConfig(anystr_length_validator, 'min_anystr_length', 'max_anystr_length'),
|
||||
],
|
||||
),
|
||||
(
|
||||
bytes,
|
||||
[
|
||||
bytes_validator,
|
||||
IfConfig(anystr_strip_whitespace, 'anystr_strip_whitespace'),
|
||||
IfConfig(anystr_upper, 'anystr_upper'),
|
||||
IfConfig(anystr_lower, 'anystr_lower'),
|
||||
IfConfig(anystr_length_validator, 'min_anystr_length', 'max_anystr_length'),
|
||||
],
|
||||
),
|
||||
(bool, [bool_validator]),
|
||||
(int, [int_validator]),
|
||||
(float, [float_validator, IfConfig(float_finite_validator, 'allow_inf_nan', ignored_value=True)]),
|
||||
(Path, [path_validator]),
|
||||
(datetime, [parse_datetime]),
|
||||
(date, [parse_date]),
|
||||
(time, [parse_time]),
|
||||
(timedelta, [parse_duration]),
|
||||
(OrderedDict, [ordered_dict_validator]),
|
||||
(dict, [dict_validator]),
|
||||
(list, [list_validator]),
|
||||
(tuple, [tuple_validator]),
|
||||
(set, [set_validator]),
|
||||
(frozenset, [frozenset_validator]),
|
||||
(deque, [deque_validator]),
|
||||
(UUID, [uuid_validator]),
|
||||
(Decimal, [decimal_validator]),
|
||||
(IPv4Interface, [ip_v4_interface_validator]),
|
||||
(IPv6Interface, [ip_v6_interface_validator]),
|
||||
(IPv4Address, [ip_v4_address_validator]),
|
||||
(IPv6Address, [ip_v6_address_validator]),
|
||||
(IPv4Network, [ip_v4_network_validator]),
|
||||
(IPv6Network, [ip_v6_network_validator]),
|
||||
]
|
||||
|
||||
|
||||
def find_validators( # noqa: C901 (ignore complexity)
|
||||
type_: Type[Any], config: Type['BaseConfig']
|
||||
) -> Generator[AnyCallable, None, None]:
|
||||
from .dataclasses import is_builtin_dataclass, make_dataclass_validator
|
||||
|
||||
if type_ is Any or type_ is object:
|
||||
return
|
||||
type_type = type_.__class__
|
||||
if type_type == ForwardRef or type_type == TypeVar:
|
||||
return
|
||||
|
||||
if is_none_type(type_):
|
||||
yield none_validator
|
||||
return
|
||||
if type_ is Pattern or type_ is re.Pattern:
|
||||
yield pattern_validator
|
||||
return
|
||||
if type_ is Hashable or type_ is CollectionsHashable:
|
||||
yield hashable_validator
|
||||
return
|
||||
if is_callable_type(type_):
|
||||
yield callable_validator
|
||||
return
|
||||
if is_literal_type(type_):
|
||||
yield make_literal_validator(type_)
|
||||
return
|
||||
if is_builtin_dataclass(type_):
|
||||
yield from make_dataclass_validator(type_, config)
|
||||
return
|
||||
if type_ is Enum:
|
||||
yield enum_validator
|
||||
return
|
||||
if type_ is IntEnum:
|
||||
yield int_enum_validator
|
||||
return
|
||||
if is_namedtuple(type_):
|
||||
yield tuple_validator
|
||||
yield make_namedtuple_validator(type_, config)
|
||||
return
|
||||
if is_typeddict(type_):
|
||||
yield make_typeddict_validator(type_, config)
|
||||
return
|
||||
|
||||
class_ = get_class(type_)
|
||||
if class_ is not None:
|
||||
if class_ is not Any and isinstance(class_, type):
|
||||
yield make_class_validator(class_)
|
||||
else:
|
||||
yield any_class_validator
|
||||
return
|
||||
|
||||
for val_type, validators in _VALIDATORS:
|
||||
try:
|
||||
if issubclass(type_, val_type):
|
||||
for v in validators:
|
||||
if isinstance(v, IfConfig):
|
||||
if v.check(config):
|
||||
yield v.validator
|
||||
else:
|
||||
yield v
|
||||
return
|
||||
except TypeError:
|
||||
raise RuntimeError(f'error checking inheritance of {type_!r} (type: {display_as_type(type_)})')
|
||||
|
||||
if config.arbitrary_types_allowed:
|
||||
yield make_arbitrary_type_validator(type_)
|
||||
else:
|
||||
raise RuntimeError(f'no validator found for {type_}, see `arbitrary_types_allowed` in Config')
|
38
lib/pydantic/version.py
Normal file
38
lib/pydantic/version.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
__all__ = 'compiled', 'VERSION', 'version_info'
|
||||
|
||||
VERSION = '1.10.2'
|
||||
|
||||
try:
|
||||
import cython # type: ignore
|
||||
except ImportError:
|
||||
compiled: bool = False
|
||||
else: # pragma: no cover
|
||||
try:
|
||||
compiled = cython.compiled
|
||||
except AttributeError:
|
||||
compiled = False
|
||||
|
||||
|
||||
def version_info() -> str:
|
||||
import platform
|
||||
import sys
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
|
||||
optional_deps = []
|
||||
for p in ('devtools', 'dotenv', 'email-validator', 'typing-extensions'):
|
||||
try:
|
||||
import_module(p.replace('-', '_'))
|
||||
except ImportError:
|
||||
continue
|
||||
optional_deps.append(p)
|
||||
|
||||
info = {
|
||||
'pydantic version': VERSION,
|
||||
'pydantic compiled': compiled,
|
||||
'install path': Path(__file__).resolve().parent,
|
||||
'python version': sys.version,
|
||||
'platform': platform.platform(),
|
||||
'optional deps. installed': optional_deps,
|
||||
}
|
||||
return '\n'.join('{:>30} {}'.format(k + ':', str(v).replace('\n', ' ')) for k, v in info.items())
|
File diff suppressed because it is too large
Load diff
|
@ -8,7 +8,7 @@ beautifulsoup4==4.11.1
|
|||
bleach==5.0.1
|
||||
certifi==2022.9.24
|
||||
cheroot==8.6.0
|
||||
cherrypy==18.6.1
|
||||
cherrypy==18.8.0
|
||||
cloudinary==1.29.0
|
||||
distro==1.7.0
|
||||
dnspython==2.2.1
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue