Bump cherrypy from 18.6.1 to 18.8.0 (#1796)

* Bump cherrypy from 18.6.1 to 18.8.0

Bumps [cherrypy](https://github.com/cherrypy/cherrypy) from 18.6.1 to 18.8.0.
- [Release notes](https://github.com/cherrypy/cherrypy/releases)
- [Changelog](https://github.com/cherrypy/cherrypy/blob/main/CHANGES.rst)
- [Commits](https://github.com/cherrypy/cherrypy/compare/v18.6.1...v18.8.0)

---
updated-dependencies:
- dependency-name: cherrypy
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* Update cherrypy==18.8.0

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com>
dependabot[bot] 2022-11-12 17:53:03 -08:00 committed by GitHub
parent e79da07973
commit 76cc56a215
75 changed files with 19150 additions and 1339 deletions


@@ -0,0 +1,27 @@
# Copyright 2014-2016 Nathan West
#
# This file is part of autocommand.
#
# autocommand is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# autocommand is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with autocommand. If not, see <http://www.gnu.org/licenses/>.
# flake8 flags all these imports as unused, hence the NOQAs everywhere.
from .automain import automain # NOQA
from .autoparse import autoparse, smart_open # NOQA
from .autocommand import autocommand # NOQA
try:
from .autoasync import autoasync # NOQA
except ImportError: # pragma: no cover
pass


@@ -0,0 +1,140 @@
# Copyright 2014-2015 Nathan West
#
# This file is part of autocommand.
#
# autocommand is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# autocommand is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with autocommand. If not, see <http://www.gnu.org/licenses/>.
from asyncio import get_event_loop, iscoroutine
from functools import wraps
from inspect import signature
def _launch_forever_coro(coro, args, kwargs, loop):
'''
This helper function launches an async main function that was tagged with
forever=True. There are two possibilities:
- The function is a normal function, which handles initializing the event
loop, which is then run forever
- The function is a coroutine, which needs to be scheduled in the event
loop, which is then run forever
- There is also the possibility that the function is a normal function
wrapping a coroutine function
The function is therefore called unconditionally and scheduled in the event
loop if the return value is a coroutine object.
The reason this is a separate function is to make absolutely sure that all
the objects created are garbage collected after all is said and done; we
do this to ensure that any exceptions raised in the tasks are collected
ASAP.
'''
# Personal note: I consider this an antipattern, as it relies on the use of
# unowned resources. The setup function dumps some stuff into the event
# loop where it just whirls in the ether without a well defined owner or
# lifetime. For this reason, there's a good chance I'll remove the
# forever=True feature from autoasync at some point in the future.
thing = coro(*args, **kwargs)
if iscoroutine(thing):
loop.create_task(thing)
def autoasync(coro=None, *, loop=None, forever=False, pass_loop=False):
'''
Convert an asyncio coroutine into a function which, when called, is
evaluated in an event loop, and the return value returned. This is intended
to make it easy to write entry points into asyncio coroutines, which
otherwise need to be explicitly evaluated with an event loop's
run_until_complete.
If `loop` is given, it is used as the event loop to run the coro in. If it
is None (the default), the loop is retrieved using asyncio.get_event_loop.
This call is deferred until the decorated function is called, so that
callers can install custom event loops or event loop policies after
@autoasync is applied.
If `forever` is True, the loop is run forever after the decorated coroutine
is finished. Use this for servers created with asyncio.start_server and the
like.
If `pass_loop` is True, the event loop object is passed into the coroutine
as the `loop` kwarg when the wrapper function is called. In this case, the
wrapper function's __signature__ is updated to remove this parameter, so
that autoparse can still be used on it without generating a parameter for
`loop`.
This decorator can be applied with ( @autoasync(...) ) or without
( @autoasync ) arguments.
Examples:
@autoasync
def get_file(host, port):
reader, writer = yield from asyncio.open_connection(host, port)
data = reader.read()
sys.stdout.write(data.decode())
get_file(host, port)
@autoasync(forever=True, pass_loop=True)
def server(host, port, loop):
yield from loop.create_server(Proto, host, port)
server('localhost', 8899)
'''
if coro is None:
return lambda c: autoasync(
c, loop=loop,
forever=forever,
pass_loop=pass_loop)
# The old and new signatures are required to correctly bind the loop
# parameter in 100% of cases, even if it's a positional parameter.
# NOTE: A future release will probably require the loop parameter to be
# a kwonly parameter.
if pass_loop:
old_sig = signature(coro)
new_sig = old_sig.replace(parameters=(
param for name, param in old_sig.parameters.items()
if name != "loop"))
@wraps(coro)
def autoasync_wrapper(*args, **kwargs):
# Defer the call to get_event_loop so that, if a custom policy is
# installed after the autoasync decorator, it is respected at call time
local_loop = get_event_loop() if loop is None else loop
# Inject the 'loop' argument. We have to use this signature binding to
# ensure it's injected in the correct place (positional, keyword, etc)
if pass_loop:
bound_args = old_sig.bind_partial()
bound_args.arguments.update(
loop=local_loop,
**new_sig.bind(*args, **kwargs).arguments)
args, kwargs = bound_args.args, bound_args.kwargs
if forever:
_launch_forever_coro(coro, args, kwargs, local_loop)
local_loop.run_forever()
else:
return local_loop.run_until_complete(coro(*args, **kwargs))
# Attach the updated signature. This allows 'pass_loop' to be used with
# autoparse
if pass_loop:
autoasync_wrapper.__signature__ = new_sig
return autoasync_wrapper
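The docstring examples above use the pre-3.5 generator-coroutine style. For reference only (not part of the upstream file), a minimal sketch of applying autoasync to a modern async def entry point, assuming the vendored module is importable as autocommand.autoasync:

import asyncio
from autocommand.autoasync import autoasync

@autoasync
async def main(delay: float = 0.1):
    # autoasync runs this coroutine to completion in an event loop
    # when main() is called like an ordinary function.
    await asyncio.sleep(delay)
    return 'done'

if __name__ == '__main__':
    print(main())  # evaluated via run_until_complete under the hood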


@@ -0,0 +1,70 @@
# Copyright 2014-2015 Nathan West
#
# This file is part of autocommand.
#
# autocommand is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# autocommand is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with autocommand. If not, see <http://www.gnu.org/licenses/>.
from .autoparse import autoparse
from .automain import automain
try:
from .autoasync import autoasync
except ImportError: # pragma: no cover
pass
def autocommand(
module, *,
description=None,
epilog=None,
add_nos=False,
parser=None,
loop=None,
forever=False,
pass_loop=False):
if callable(module):
raise TypeError('autocommand requires a module name argument')
def autocommand_decorator(func):
# Step 1: if requested, run it all in an asyncio event loop. autoasync
# patches the __signature__ of the decorated function, so that in the
# event that pass_loop is True, the `loop` parameter of the original
# function will *not* be interpreted as a command-line argument by
# autoparse
if loop is not None or forever or pass_loop:
func = autoasync(
func,
loop=None if loop is True else loop,
pass_loop=pass_loop,
forever=forever)
# Step 2: create parser. We do this second so that the arguments are
# parsed and passed *before* entering the asyncio event loop, if it
# exists. This simplifies the stack trace and ensures errors are
# reported earlier. It also ensures that errors raised during parsing &
# passing are still raised if `forever` is True.
func = autoparse(
func,
description=description,
epilog=epilog,
add_nos=add_nos,
parser=parser)
# Step 3: call the function automatically if __name__ == '__main__' (or
# if True was provided)
func = automain(module)(func)
return func
return autocommand_decorator
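A brief usage sketch of the composed decorator (illustrative only; the command name and parameters here are invented):

import sys
from autocommand import autocommand

@autocommand(__name__)
def cat(filename, encoding='utf8'):
    '''Print a file to stdout.'''
    # autoparse turns `filename` into a positional argument and `encoding`
    # into an --encoding option; automain then runs the command when this
    # module is executed as __main__.
    with open(filename, encoding=encoding) as f:
        sys.stdout.write(f.read())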


@@ -0,0 +1,59 @@
# Copyright 2014-2015 Nathan West
#
# This file is part of autocommand.
#
# autocommand is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# autocommand is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with autocommand. If not, see <http://www.gnu.org/licenses/>.
import sys
from .errors import AutocommandError
class AutomainRequiresModuleError(AutocommandError, TypeError):
pass
def automain(module, *, args=(), kwargs=None):
'''
This decorator automatically invokes a function if the module is being run
as the "__main__" module. Optionally, provide args or kwargs with which to
call the function. If `module` is "__main__", the function is called, and
the program is `sys.exit`ed with the return value. You can also pass `True`
to cause the function to be called unconditionally. If the function is not
called, it is returned unchanged by the decorator.
Usage:
@automain(__name__) # Pass __name__ to check __name__=="__main__"
def main():
...
If __name__ is "__main__" here, the main function is called, and then
sys.exit called with the return value.
'''
# Check that @automain(...) was called, rather than @automain
if callable(module):
raise AutomainRequiresModuleError(module)
if module == '__main__' or module is True:
if kwargs is None:
kwargs = {}
# Use a function definition instead of a lambda for a neater traceback
def automain_decorator(main):
sys.exit(main(*args, **kwargs))
return automain_decorator
else:
return lambda main: main


@@ -0,0 +1,333 @@
# Copyright 2014-2015 Nathan West
#
# This file is part of autocommand.
#
# autocommand is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# autocommand is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with autocommand. If not, see <http://www.gnu.org/licenses/>.
import sys
from re import compile as compile_regex
from inspect import signature, getdoc, Parameter
from argparse import ArgumentParser
from contextlib import contextmanager
from functools import wraps
from io import IOBase
from autocommand.errors import AutocommandError
_empty = Parameter.empty
class AnnotationError(AutocommandError):
'''Annotation error: annotation must be a string, type, or tuple of both'''
class PositionalArgError(AutocommandError):
'''
Positional Arg Error: autocommand can't handle positional-only parameters
'''
class KWArgError(AutocommandError):
'''kwarg Error: autocommand can't handle a **kwargs parameter'''
class DocstringError(AutocommandError):
'''Docstring error'''
class TooManySplitsError(DocstringError):
'''
The docstring had too many ---- section splits. Currently we only support
using up to a single split, to split the docstring into description and
epilog parts.
'''
def _get_type_description(annotation):
'''
Given an annotation, return the (type, description) for the parameter.
If you provide an annotation that is somehow both a string and a callable,
the behavior is undefined.
'''
if annotation is _empty:
return None, None
elif callable(annotation):
return annotation, None
elif isinstance(annotation, str):
return None, annotation
elif isinstance(annotation, tuple):
try:
arg1, arg2 = annotation
except ValueError as e:
raise AnnotationError(annotation) from e
else:
if callable(arg1) and isinstance(arg2, str):
return arg1, arg2
elif isinstance(arg1, str) and callable(arg2):
return arg2, arg1
raise AnnotationError(annotation)
def _add_arguments(param, parser, used_char_args, add_nos):
'''
Add the argument(s) to an ArgumentParser (using add_argument) for a given
parameter. used_char_args is the set of -short options currently already in
use, and is updated (if necessary) by this function. If add_nos is True,
this will also add an inverse switch for all boolean options. For
instance, for the boolean parameter "verbose", this will create --verbose
and --no-verbose.
'''
# Impl note: This function is kept separate from make_parser because it's
# already very long and I wanted to separate out as much as possible into
# its own call scope, to prevent even the possibility of subtle mutation
# bugs.
if param.kind is param.POSITIONAL_ONLY:
raise PositionalArgError(param)
elif param.kind is param.VAR_KEYWORD:
raise KWArgError(param)
# These are the kwargs for the add_argument function.
arg_spec = {}
is_option = False
# Get the type and default from the annotation.
arg_type, description = _get_type_description(param.annotation)
# Get the default value
default = param.default
# If there is no explicit type, and the default is present and not None,
# infer the type from the default.
if arg_type is None and default not in {_empty, None}:
arg_type = type(default)
# Add default. The presence of a default means this is an option, not an
# argument.
if default is not _empty:
arg_spec['default'] = default
is_option = True
# Add the type
if arg_type is not None:
# Special case for bool: make it just a --switch
if arg_type is bool:
if not default or default is _empty:
arg_spec['action'] = 'store_true'
else:
arg_spec['action'] = 'store_false'
# Switches are always options
is_option = True
# Special case for file types: make it a string type, for filename
elif isinstance(default, IOBase):
arg_spec['type'] = str
# TODO: special case for list type.
# - How to specify the type of list members?
# - param: [int]
# - param: int =[]
# - action='append' vs nargs='*'
else:
arg_spec['type'] = arg_type
# nargs: if the signature includes *args, collect them as trailing CLI
# arguments in a list. *args can't have a default value, so it can never be
# an option.
if param.kind is param.VAR_POSITIONAL:
# TODO: consider depluralizing metavar/name here.
arg_spec['nargs'] = '*'
# Add description.
if description is not None:
arg_spec['help'] = description
# Get the --flags
flags = []
name = param.name
if is_option:
# Add the first letter as a -short option.
for letter in name[0], name[0].swapcase():
if letter not in used_char_args:
used_char_args.add(letter)
flags.append('-{}'.format(letter))
break
# If the parameter is a --long option, or is a -short option that
# somehow failed to get a flag, add it.
if len(name) > 1 or not flags:
flags.append('--{}'.format(name))
arg_spec['dest'] = name
else:
flags.append(name)
parser.add_argument(*flags, **arg_spec)
# Create the --no- version for boolean switches
if add_nos and arg_type is bool:
parser.add_argument(
'--no-{}'.format(name),
action='store_const',
dest=name,
const=default if default is not _empty else False)
def make_parser(func_sig, description, epilog, add_nos):
'''
Given the signature of a function, create an ArgumentParser
'''
parser = ArgumentParser(description=description, epilog=epilog)
used_char_args = {'h'}
# Arrange the params so that single-character arguments are first. This
# ensures they don't have to get --long versions. sorted is stable, so the
# parameters will otherwise still be in relative order.
params = sorted(
func_sig.parameters.values(),
key=lambda param: len(param.name) > 1)
for param in params:
_add_arguments(param, parser, used_char_args, add_nos)
return parser
_DOCSTRING_SPLIT = compile_regex(r'\n\s*-{4,}\s*\n')
def parse_docstring(docstring):
'''
Given a docstring, parse it into a description and epilog part
'''
if docstring is None:
return '', ''
parts = _DOCSTRING_SPLIT.split(docstring)
if len(parts) == 1:
return docstring, ''
elif len(parts) == 2:
return parts[0], parts[1]
else:
raise TooManySplitsError()
def autoparse(
func=None, *,
description=None,
epilog=None,
add_nos=False,
parser=None):
'''
This decorator converts a function that takes normal arguments into a
function which takes a single optional argument, argv, parses it using an
argparse.ArgumentParser, and calls the underlying function with the parsed
arguments. If argv is not given, sys.argv[1:] is used. This is so that the
function can be used as a setuptools entry point, as well as a normal main
function. sys.argv[1:] is not evaluated until the function is called, to
allow injecting different arguments for testing.
It uses the argument signature of the function to create an
ArgumentParser. Parameters without defaults become positional parameters,
while parameters *with* defaults become --options. Use annotations to set
the type of the parameter.
The `description` and `epilog` parameters correspond to the same respective
argparse parameters. If no description is given, it defaults to the
decorated function's docstring, if present.
If add_nos is True, every boolean option (that is, every parameter with a
default of True/False or a type of bool) will have a --no- version created
as well, which inverts the option. For instance, the --verbose option will
have a --no-verbose counterpart. These are not mutually exclusive;
whichever one appears last in the argument list will have precedence.
If a parser is given, it is used instead of one generated from the function
signature. In this case, no parser is created; instead, the given parser is
used to parse the argv argument. The parser's results' argument names must
match up with the parameter names of the decorated function.
The decorated function is attached to the result as the `func` attribute,
and the parser is attached as the `parser` attribute.
'''
# If @autoparse(...) is used instead of @autoparse
if func is None:
return lambda f: autoparse(
f, description=description,
epilog=epilog,
add_nos=add_nos,
parser=parser)
func_sig = signature(func)
docstr_description, docstr_epilog = parse_docstring(getdoc(func))
if parser is None:
parser = make_parser(
func_sig,
description or docstr_description,
epilog or docstr_epilog,
add_nos)
@wraps(func)
def autoparse_wrapper(argv=None):
if argv is None:
argv = sys.argv[1:]
# Get empty argument binding, to fill with parsed arguments. This
# object does all the heavy lifting of turning named arguments into
# into correctly bound *args and **kwargs.
parsed_args = func_sig.bind_partial()
parsed_args.arguments.update(vars(parser.parse_args(argv)))
return func(*parsed_args.args, **parsed_args.kwargs)
# TODO: attach an updated __signature__ to autoparse_wrapper, just in case.
# Attach the wrapped function and parser, and return the wrapper.
autoparse_wrapper.func = func
autoparse_wrapper.parser = parser
return autoparse_wrapper
@contextmanager
def smart_open(filename_or_file, *args, **kwargs):
'''
This context manager allows you to open a filename, if you want to default
to some already-existing file object, like sys.stdout, which shouldn't be
closed at the end of the context. If the filename argument is a str, bytes,
or int, the file object is created via a call to open with the given *args
and **kwargs, sent to the context, and closed at the end of the context,
just like "with open(filename) as f:". If it isn't one of the openable
types, the object is simply sent to the context unchanged, and left unclosed
at the end of the context. Example:
def work_with_file(name=sys.stdout):
with smart_open(name) as f:
# Works correctly if name is a str filename or sys.stdout
print("Some stuff", file=f)
# If it was a filename, f is closed at the end here.
'''
if isinstance(filename_or_file, (str, bytes, int)):
with open(filename_or_file, *args, **kwargs) as file:
yield file
else:
yield filename_or_file
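To make the signature-to-parser mapping described above concrete, a hypothetical function decorated with autoparse (names chosen purely for illustration):

from autocommand.autoparse import autoparse

@autoparse
def compress(src, dest='out.gz', level: int = 6, verbose=False):
    '''Compress src into dest.'''

# src has no default, so it becomes a positional argument; dest, level and
# verbose have defaults, so they become -d/--dest, -l/--level (typed int via
# its annotation) and -v/--verbose (a store_true switch).
# compress(['input.txt', '--level', '9']) parses that argv and calls the
# function; compress() with no argument falls back to sys.argv[1:].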

lib/autocommand/errors.py (new file, 23 lines)

@@ -0,0 +1,23 @@
# Copyright 2014-2016 Nathan West
#
# This file is part of autocommand.
#
# autocommand is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# autocommand is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with autocommand. If not, see <http://www.gnu.org/licenses/>.
class AutocommandError(Exception):
'''Base class for autocommand exceptions'''
pass
# Individual modules will define errors specific to that module.


@@ -206,12 +206,8 @@ except ImportError:
     def test_callable_spec(callable, args, kwargs):  # noqa: F811
         return None
 else:
-    getargspec = inspect.getargspec
-    # Python 3 requires using getfullargspec if
-    # keyword-only arguments are present
-    if hasattr(inspect, 'getfullargspec'):
-        def getargspec(callable):
-            return inspect.getfullargspec(callable)[:4]
+    def getargspec(callable):
+        return inspect.getfullargspec(callable)[:4]


 class LateParamPageHandler(PageHandler):
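For orientation (not part of the diff), the first four fields of getfullargspec are what the removed Python 2 style getargspec used to return, which is why the [:4] slice keeps existing callers working:

import inspect

def demo(a, b=1, *args, **kwargs):
    pass

args, varargs, varkw, defaults = inspect.getfullargspec(demo)[:4]
# args == ['a', 'b'], varargs == 'args', varkw == 'kwargs', defaults == (1,)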


@@ -466,7 +466,7 @@ _HTTPErrorTemplate = '''<!DOCTYPE html PUBLIC
     <pre id="traceback">%(traceback)s</pre>
     <div id="powered_by">
       <span>
-        Powered by <a href="http://www.cherrypy.org">CherryPy %(version)s</a>
+        Powered by <a href="http://www.cherrypy.dev">CherryPy %(version)s</a>
       </span>
     </div>
   </body>
@@ -532,7 +532,8 @@ def get_error_page(status, **kwargs):
                 return result
             else:
                 # Load the template from this path.
-                template = io.open(error_page, newline='').read()
+                with io.open(error_page, newline='') as f:
+                    template = f.read()
     except Exception:
         e = _format_exception(*_exc_info())[-1]
         m = kwargs['message']


@@ -339,11 +339,8 @@ LoadModule python_module modules/mod_python.so
 }

         mpconf = os.path.join(os.path.dirname(__file__), 'cpmodpy.conf')
-        f = open(mpconf, 'wb')
-        try:
-            f.write(conf_data)
-        finally:
-            f.close()
+        with open(mpconf, 'wb') as f:
+            f.write(conf_data)

         response = read_process(self.apache_path, '-k start -f %s' % mpconf)
         self.ready = True


@@ -169,7 +169,7 @@ def request_namespace(k, v):
 def response_namespace(k, v):
     """Attach response attributes declared in config."""
     # Provides config entries to set default response headers
-    # http://cherrypy.org/ticket/889
+    # http://cherrypy.dev/ticket/889
     if k[:8] == 'headers.':
         cherrypy.serving.response.headers[k.split('.', 1)[1]] = v
     else:
@@ -252,7 +252,7 @@ class Request(object):
     The query component of the Request-URI, a string of information to be
     interpreted by the resource. The query portion of a URI follows the
     path component, and is separated by a '?'. For example, the URI
-    'http://www.cherrypy.org/wiki?a=3&b=4' has the query component,
+    'http://www.cherrypy.dev/wiki?a=3&b=4' has the query component,
     'a=3&b=4'."""

     query_string_encoding = 'utf8'
@@ -742,6 +742,9 @@ class Request(object):
             if self.protocol >= (1, 1):
                 msg = "HTTP/1.1 requires a 'Host' request header."
                 raise cherrypy.HTTPError(400, msg)
+        else:
+            headers['Host'] = httputil.SanitizedHost(dict.get(headers, 'Host'))
+
         host = dict.get(headers, 'Host')
         if not host:
             host = self.local.name or self.local.ip


@@ -101,13 +101,12 @@ def get_ha1_file_htdigest(filename):
     """
     def get_ha1(realm, username):
         result = None
-        f = open(filename, 'r')
-        for line in f:
-            u, r, ha1 = line.rstrip().split(':')
-            if u == username and r == realm:
-                result = ha1
-                break
-        f.close()
+        with open(filename, 'r') as f:
+            for line in f:
+                u, r, ha1 = line.rstrip().split(':')
+                if u == username and r == realm:
+                    result = ha1
+                    break
         return result

     return get_ha1


@@ -334,9 +334,10 @@ class CoverStats(object):
         yield '</body></html>'

     def annotated_file(self, filename, statements, excluded, missing):
-        source = open(filename, 'r')
+        with open(filename, 'r') as source:
+            lines = source.readlines()
         buffer = []
-        for lineno, line in enumerate(source.readlines()):
+        for lineno, line in enumerate(lines):
             lineno += 1
             line = line.strip('\n\r')
             empty_the_buffer = True


@@ -516,3 +516,33 @@ class Host(object):
     def __repr__(self):
         return 'httputil.Host(%r, %r, %r)' % (self.ip, self.port, self.name)
+
+
+class SanitizedHost(str):
+    r"""
+    Wraps a raw host header received from the network in
+    a sanitized version that elides dangerous characters.
+
+    >>> SanitizedHost('foo\nbar')
+    'foobar'
+
+    >>> SanitizedHost('foo\nbar').raw
+    'foo\nbar'
+
+    A SanitizedInstance is only returned if sanitization was performed.
+
+    >>> isinstance(SanitizedHost('foobar'), SanitizedHost)
+    False
+    """
+    dangerous = re.compile(r'[\n\r]')
+
+    def __new__(cls, raw):
+        sanitized = cls._sanitize(raw)
+        if sanitized == raw:
+            return raw
+        instance = super().__new__(cls, sanitized)
+        instance.raw = raw
+        return instance
+
+    @classmethod
+    def _sanitize(cls, raw):
+        return cls.dangerous.sub('', raw)
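A short usage sketch of the new class (assuming cherrypy.lib.httputil from 18.8.0 is importable; the header value is invented):

from cherrypy.lib import httputil

host = httputil.SanitizedHost('example.com\r\nX-Injected: 1')
assert host == 'example.comX-Injected: 1'          # CR/LF elided from the value
assert host.raw == 'example.com\r\nX-Injected: 1'  # original preserved on .raw
# A value that needed no sanitization comes back as a plain str:
assert not isinstance(httputil.SanitizedHost('example.com'), httputil.SanitizedHost)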


@@ -163,11 +163,8 @@ class Parser(configparser.ConfigParser):
             #     fp = open(filename)
             # except IOError:
             #     continue
-            fp = open(filename)
-            try:
-                self._read(fp, filename)
-            finally:
-                fp.close()
+            with open(filename) as fp:
+                self._read(fp, filename)

     def as_dict(self, raw=False, vars=None):
         """Convert an INI file to a dictionary"""


@@ -516,11 +516,8 @@ class FileSession(Session):
         if path is None:
             path = self._get_file_path()
         try:
-            f = open(path, 'rb')
-            try:
-                return pickle.load(f)
-            finally:
-                f.close()
+            with open(path, 'rb') as f:
+                return pickle.load(f)
         except (IOError, EOFError):
             e = sys.exc_info()[1]
             if self.debug:
@@ -531,11 +528,8 @@ class FileSession(Session):
     def _save(self, expiration_time):
         assert self.locked, ('The session was saved without being locked. '
                              "Check your tools' priority levels.")
-        f = open(self._get_file_path(), 'wb')
-        try:
-            pickle.dump((self._data, expiration_time), f, self.pickle_protocol)
-        finally:
-            f.close()
+        with open(self._get_file_path(), 'wb') as f:
+            pickle.dump((self._data, expiration_time), f, self.pickle_protocol)

     def _delete(self):
         assert self.locked, ('The session deletion without being locked. '


@@ -436,7 +436,8 @@ class PIDFile(SimplePlugin):
         if self.finalized:
             self.bus.log('PID %r already written to %r.' % (pid, self.pidfile))
         else:
-            open(self.pidfile, 'wb').write(ntob('%s\n' % pid, 'utf8'))
+            with open(self.pidfile, 'wb') as f:
+                f.write(ntob('%s\n' % pid, 'utf8'))
             self.bus.log('PID %r written to %r.' % (pid, self.pidfile))
             self.finalized = True
     start.priority = 70


@@ -505,7 +505,8 @@ server.ssl_private_key: r'%s'
     def get_pid(self):
         if self.daemonize:
-            return int(open(self.pid_file, 'rb').read())
+            with open(self.pid_file, 'rb') as f:
+                return int(f.read())
         return self._proc.pid

     def join(self):


@@ -97,7 +97,8 @@ class LogCase(object):
     def emptyLog(self):
         """Overwrite self.logfile with 0 bytes."""
-        open(self.logfile, 'wb').write('')
+        with open(self.logfile, 'wb') as f:
+            f.write('')

     def markLog(self, key=None):
         """Insert a marker line into the log and set self.lastmarker."""
@@ -105,10 +106,11 @@ class LogCase(object):
             key = str(time.time())
         self.lastmarker = key

-        open(self.logfile, 'ab+').write(
-            b'%s%s\n'
-            % (self.markerPrefix, key.encode('utf-8'))
-        )
+        with open(self.logfile, 'ab+') as f:
+            f.write(
+                b'%s%s\n'
+                % (self.markerPrefix, key.encode('utf-8'))
+            )

     def _read_marked_region(self, marker=None):
         """Return lines from self.logfile in the marked region.
@@ -122,20 +124,23 @@ class LogCase(object):
         logfile = self.logfile
         marker = marker or self.lastmarker
         if marker is None:
-            return open(logfile, 'rb').readlines()
+            with open(logfile, 'rb') as f:
+                return f.readlines()

         if isinstance(marker, str):
             marker = marker.encode('utf-8')

         data = []
         in_region = False
-        for line in open(logfile, 'rb'):
-            if in_region:
-                if line.startswith(self.markerPrefix) and marker not in line:
-                    break
-                else:
-                    data.append(line)
-            elif marker in line:
-                in_region = True
+        with open(logfile, 'rb') as f:
+            for line in f:
+                if in_region:
+                    if (line.startswith(self.markerPrefix)
+                            and marker not in line):
+                        break
+                    else:
+                        data.append(line)
+                elif marker in line:
+                    in_region = True
         return data

     def assertInLog(self, line, marker=None):


@@ -14,7 +14,7 @@ KNOWN BUGS
 1. Apache processes Range headers automatically; CherryPy's truncated
    output is then truncated again by Apache. See test_core.testRanges.
-   This was worked around in http://www.cherrypy.org/changeset/1319.
+   This was worked around in http://www.cherrypy.dev/changeset/1319.
 2. Apache does not allow custom HTTP methods like CONNECT as per the spec.
    See test_core.testHTTPMethods.
 3. Max request header and body settings do not work with Apache.
@@ -112,15 +112,12 @@ class ModFCGISupervisor(helper.LocalWSGISupervisor):
             fcgiconf = os.path.join(curdir, fcgiconf)

         # Write the Apache conf file.
-        f = open(fcgiconf, 'wb')
-        try:
-            server = repr(os.path.join(curdir, 'fastcgi.pyc'))[1:-1]
-            output = self.template % {'port': self.port, 'root': curdir,
-                                      'server': server}
-            output = output.replace('\r\n', '\n')
-            f.write(output)
-        finally:
-            f.close()
+        with open(fcgiconf, 'wb') as f:
+            server = repr(os.path.join(curdir, 'fastcgi.pyc'))[1:-1]
+            output = self.template % {'port': self.port, 'root': curdir,
+                                      'server': server}
+            output = output.replace('\r\n', '\n')
+            f.write(output)

         result = read_process(APACHE_PATH, '-k start -f %s' % fcgiconf)
         if result:


@@ -14,7 +14,7 @@ KNOWN BUGS
 1. Apache processes Range headers automatically; CherryPy's truncated
    output is then truncated again by Apache. See test_core.testRanges.
-   This was worked around in http://www.cherrypy.org/changeset/1319.
+   This was worked around in http://www.cherrypy.dev/changeset/1319.
 2. Apache does not allow custom HTTP methods like CONNECT as per the spec.
    See test_core.testHTTPMethods.
 3. Max request header and body settings do not work with Apache.
@@ -101,15 +101,12 @@ class ModFCGISupervisor(helper.LocalSupervisor):
             fcgiconf = os.path.join(curdir, fcgiconf)

         # Write the Apache conf file.
-        f = open(fcgiconf, 'wb')
-        try:
-            server = repr(os.path.join(curdir, 'fastcgi.pyc'))[1:-1]
-            output = self.template % {'port': self.port, 'root': curdir,
-                                      'server': server}
-            output = ntob(output.replace('\r\n', '\n'))
-            f.write(output)
-        finally:
-            f.close()
+        with open(fcgiconf, 'wb') as f:
+            server = repr(os.path.join(curdir, 'fastcgi.pyc'))[1:-1]
+            output = self.template % {'port': self.port, 'root': curdir,
+                                      'server': server}
+            output = ntob(output.replace('\r\n', '\n'))
+            f.write(output)

         result = read_process(APACHE_PATH, '-k start -f %s' % fcgiconf)
         if result:


@@ -15,7 +15,7 @@ KNOWN BUGS
 1. Apache processes Range headers automatically; CherryPy's truncated
    output is then truncated again by Apache. See test_core.testRanges.
-   This was worked around in http://www.cherrypy.org/changeset/1319.
+   This was worked around in http://www.cherrypy.dev/changeset/1319.
 2. Apache does not allow custom HTTP methods like CONNECT as per the spec.
    See test_core.testHTTPMethods.
 3. Max request header and body settings do not work with Apache.
@@ -107,13 +107,10 @@ class ModPythonSupervisor(helper.Supervisor):
         if not os.path.isabs(mpconf):
             mpconf = os.path.join(curdir, mpconf)

-        f = open(mpconf, 'wb')
-        try:
-            f.write(self.template %
-                    {'port': self.port, 'modulename': modulename,
-                     'host': self.host})
-        finally:
-            f.close()
+        with open(mpconf, 'wb') as f:
+            f.write(self.template %
+                    {'port': self.port, 'modulename': modulename,
+                     'host': self.host})

         result = read_process(APACHE_PATH, '-k start -f %s' % mpconf)
         if result:


@@ -11,7 +11,7 @@ KNOWN BUGS
 1. Apache processes Range headers automatically; CherryPy's truncated
    output is then truncated again by Apache. See test_core.testRanges.
-   This was worked around in http://www.cherrypy.org/changeset/1319.
+   This was worked around in http://www.cherrypy.dev/changeset/1319.
 2. Apache does not allow custom HTTP methods like CONNECT as per the spec.
    See test_core.testHTTPMethods.
 3. Max request header and body settings do not work with Apache.
@@ -109,14 +109,11 @@ class ModWSGISupervisor(helper.Supervisor):
         if not os.path.isabs(mpconf):
             mpconf = os.path.join(curdir, mpconf)

-        f = open(mpconf, 'wb')
-        try:
-            output = (self.template %
-                      {'port': self.port, 'testmod': modulename,
-                       'curdir': curdir})
-            f.write(output)
-        finally:
-            f.close()
+        with open(mpconf, 'wb') as f:
+            output = (self.template %
+                      {'port': self.port, 'testmod': modulename,
+                       'curdir': curdir})
+            f.write(output)

         result = read_process(APACHE_PATH, '-k start -f %s' % mpconf)
         if result:


@@ -1,4 +1,4 @@
-# This file is part of CherryPy <http://www.cherrypy.org/>
+# This file is part of CherryPy <http://www.cherrypy.dev/>
 # -*- coding: utf-8 -*-
 # vim:ts=4:sw=4:expandtab:fileencoding=utf-8


@@ -1,4 +1,4 @@
-# This file is part of CherryPy <http://www.cherrypy.org/>
+# This file is part of CherryPy <http://www.cherrypy.dev/>
 # -*- coding: utf-8 -*-
 # vim:ts=4:sw=4:expandtab:fileencoding=utf-8


@@ -586,9 +586,8 @@ class CoreRequestHandlingTest(helper.CPWebCase):
     def testFavicon(self):
         # favicon.ico is served by staticfile.
         icofilename = os.path.join(localDir, '../favicon.ico')
-        icofile = open(icofilename, 'rb')
-        data = icofile.read()
-        icofile.close()
+        with open(icofilename, 'rb') as icofile:
+            data = icofile.read()

         self.getPage('/favicon.ico')
         self.assertBody(data)


@@ -46,7 +46,7 @@ class EncodingTests(helper.CPWebCase):
                 # any part which is unicode (even ascii), the response
                 # should not fail.
                 cherrypy.response.cookie['candy'] = 'bar'
-                cherrypy.response.cookie['candy']['domain'] = 'cherrypy.org'
+                cherrypy.response.cookie['candy']['domain'] = 'cherrypy.dev'
                 cherrypy.response.headers[
                     'Some-Header'] = 'My d\xc3\xb6g has fleas'
                 cherrypy.response.headers[


@@ -113,7 +113,7 @@ def test_normal_return(log_tracker, server):
     resp = requests.get(
         'http://%s:%s/as_string' % (host, port),
         headers={
-            'Referer': 'http://www.cherrypy.org/',
+            'Referer': 'http://www.cherrypy.dev/',
             'User-Agent': 'Mozilla/5.0',
         },
     )
@@ -135,7 +135,7 @@ def test_normal_return(log_tracker, server):
     log_tracker.assertLog(
         -1,
         '] "GET /as_string HTTP/1.1" 200 %s '
-        '"http://www.cherrypy.org/" "Mozilla/5.0"'
+        '"http://www.cherrypy.dev/" "Mozilla/5.0"'
         % content_length,
     )


@@ -342,7 +342,7 @@ class RequestObjectTests(helper.CPWebCase):
         self.assertBody('/pathinfo/foo/bar')

     def testAbsoluteURIPathInfo(self):
-        # http://cherrypy.org/ticket/1061
+        # http://cherrypy.dev/ticket/1061
         self.getPage('http://localhost/pathinfo/foo/bar')
         self.assertBody('/pathinfo/foo/bar')
@@ -375,10 +375,10 @@ class RequestObjectTests(helper.CPWebCase):
         # Make sure that encoded = and & get parsed correctly
         self.getPage(
-            '/params/code?url=http%3A//cherrypy.org/index%3Fa%3D1%26b%3D2')
+            '/params/code?url=http%3A//cherrypy.dev/index%3Fa%3D1%26b%3D2')
         self.assertBody('args: %s kwargs: %s' %
                         (('code',),
-                         [('url', ntou('http://cherrypy.org/index?a=1&b=2'))]))
+                         [('url', ntou('http://cherrypy.dev/index?a=1&b=2'))]))

         # Test coordinates sent by <img ismap>
         self.getPage('/params/ismap?223,114')
@@ -756,6 +756,16 @@ class RequestObjectTests(helper.CPWebCase):
                      headers=[('Content-type', 'application/json')])
         self.assertBody('application/json')

+    def test_dangerous_host(self):
+        """
+        Dangerous characters like newlines should be elided.
+        Ref #1974.
+        """
+        # foo\nbar
+        encoded = '=?iso-8859-1?q?foo=0Abar?='
+        self.getPage('/headers/Host', headers=[('Host', encoded)])
+        self.assertBody('foobar')
+
     def test_basic_HTTPMethods(self):
         helper.webtest.methods_with_bodies = ('POST', 'PUT', 'PROPFIND',
                                               'PATCH')


@@ -424,11 +424,12 @@ test_case_name: "test_signal_handler_unsubscribe"
     p.join()

     # Assert the old handler ran.
-    log_lines = list(open(p.error_log, 'rb'))
+    with open(p.error_log, 'rb') as f:
+        log_lines = list(f)
     assert any(
         line.endswith(b'I am an old SIGTERM handler.\n')
         for line in log_lines
     )


 def test_safe_wait_INADDR_ANY():  # pylint: disable=invalid-name


@@ -78,7 +78,7 @@ class TutorialTest(helper.CPWebCase):

             <ul>
                 <li><a href="http://del.icio.us">del.icio.us</a></li>
-                <li><a href="http://www.cherrypy.org">CherryPy</a></li>
+                <li><a href="http://www.cherrypy.dev">CherryPy</a></li>
             </ul>

             <p>[<a href="../">Return to links page</a>]</p>'''
@@ -166,7 +166,7 @@ class TutorialTest(helper.CPWebCase):
         self.assertHeader('Content-Disposition',
                           # Make sure the filename is quoted.
                           'attachment; filename="pdf_file.pdf"')
-        self.assertEqual(len(self.body), 85698)
+        self.assertEqual(len(self.body), 11961)

     def test10HTTPErrors(self):
         self.setup_tutorial('tut10_http_errors', 'HTTPErrorDemo')

Binary file not shown.


@@ -53,7 +53,7 @@ class LinksPage:
             <ul>
                 <li>
-                    <a href="http://www.cherrypy.org">The CherryPy Homepage</a>
+                    <a href="http://www.cherrypy.dev">The CherryPy Homepage</a>
                 </li>
                 <li>
                     <a href="http://www.python.org">The Python Homepage</a>
@@ -77,7 +77,7 @@ class ExtraLinksPage:

             <ul>
                 <li><a href="http://del.icio.us">del.icio.us</a></li>
-                <li><a href="http://www.cherrypy.org">CherryPy</a></li>
+                <li><a href="http://www.cherrypy.dev">CherryPy</a></li>
             </ul>

             <p>[<a href="../">Return to links page</a>]</p>'''

lib/inflect/__init__.py (new file, 3991 lines)

File diff suppressed because it is too large.

lib/inflect/py.typed (new file, 0 lines)


@@ -143,7 +143,7 @@ class classproperty:
             return super().__setattr__(key, value)

     def __init__(self, fget, fset=None):
-        self.fget = self._fix_function(fget)
+        self.fget = self._ensure_method(fget)
         self.fset = fset
         fset and self.setter(fset)
@@ -158,14 +158,13 @@ class classproperty:
         return self.fset.__get__(None, owner)(value)

     def setter(self, fset):
-        self.fset = self._fix_function(fset)
+        self.fset = self._ensure_method(fset)
         return self

     @classmethod
-    def _fix_function(cls, fn):
+    def _ensure_method(cls, fn):
         """
         Ensure fn is a classmethod or staticmethod.
         """
-        if not isinstance(fn, (classmethod, staticmethod)):
-            return classmethod(fn)
-        return fn
+        needs_method = not isinstance(fn, (classmethod, staticmethod))
+        return classmethod(fn) if needs_method else fn
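For orientation, a sketch of what classproperty enables, independent of the rename above (assuming the vendored module path jaraco.classes.properties):

from jaraco.classes.properties import classproperty

class Config:
    _defaults = {'timeout': 30}

    @classproperty
    def defaults(cls):
        # the plain function is wrapped into a classmethod by _ensure_method
        return dict(cls._defaults)

print(Config.defaults)    # accessible on the class itself -> {'timeout': 30}
print(Config().defaults)  # and on instances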


@@ -63,7 +63,7 @@ class Projection(collections.abc.Mapping):
         return len(tuple(iter(self)))


-class DictFilter(object):
+class DictFilter(collections.abc.Mapping):
     """
     Takes a dict, and simulates a sub-dict based on the keys.
@@ -92,15 +92,21 @@ class DictFilter(object):
     ...
     KeyError: 'e'

+    >>> 'e' in filtered
+    False
+
+    Pattern is useful for excluding keys with a prefix.
+
+    >>> filtered = DictFilter(sample, include_pattern=r'(?![ace])')
+    >>> dict(filtered)
+    {'b': 2, 'd': 4}
+
     Also note that DictFilter keeps a reference to the original dict, so
     if you modify the original dict, that could modify the filtered dict.

     >>> del sample['d']
-    >>> del sample['a']
-    >>> filtered == {'b': 2, 'c': 3}
-    True
-    >>> filtered != {'b': 2, 'c': 3}
-    False
+    >>> dict(filtered)
+    {'b': 2}
     """

     def __init__(self, dict, include_keys=[], include_pattern=None):
@@ -120,29 +126,18 @@ class DictFilter(object):
     @property
     def include_keys(self):
-        return self.specified_keys.union(self.pattern_keys)
-
-    def keys(self):
-        return self.include_keys.intersection(self.dict.keys())
-
-    def values(self):
-        return map(self.dict.get, self.keys())
+        return self.specified_keys | self.pattern_keys

     def __getitem__(self, i):
         if i not in self.include_keys:
             raise KeyError(i)
         return self.dict[i]

-    def items(self):
-        keys = self.keys()
-        values = map(self.dict.get, keys)
-        return zip(keys, values)
+    def __iter__(self):
+        return filter(self.include_keys.__contains__, self.dict.keys())

-    def __eq__(self, other):
-        return dict(self) == other
-
-    def __ne__(self, other):
-        return dict(self) != other
+    def __len__(self):
+        return len(list(self))


 def dict_map(function, dictionary):
@@ -167,7 +162,7 @@ class RangeMap(dict):
     the sorted list of keys.

     One may supply keyword parameters to be passed to the sort function used
-    to sort keys (i.e. cmp [python 2 only], keys, reverse) as sort_params.
+    to sort keys (i.e. key, reverse) as sort_params.

     Let's create a map that maps 1-3 -> 'a', 4-6 -> 'b'
@@ -220,6 +215,23 @@ class RangeMap(dict):
     >>> r.get(7, 'not found')
     'not found'

+    One often wishes to define the ranges by their left-most values,
+    which requires use of sort params and a key_match_comparator.
+
+    >>> r = RangeMap({1: 'a', 4: 'b'},
+    ...     sort_params=dict(reverse=True),
+    ...     key_match_comparator=operator.ge)
+    >>> r[1], r[2], r[3], r[4], r[5], r[6]
+    ('a', 'a', 'a', 'b', 'b', 'b')
+
+    That wasn't nearly as easy as before, so an alternate constructor
+    is provided:
+
+    >>> r = RangeMap.left({1: 'a', 4: 'b', 7: RangeMap.undefined_value})
+    >>> r[1], r[2], r[3], r[4], r[5], r[6]
+    ('a', 'a', 'a', 'b', 'b', 'b')
+
     """

     def __init__(self, source, sort_params={}, key_match_comparator=operator.le):
@@ -227,6 +239,12 @@ class RangeMap(dict):
         self.sort_params = sort_params
         self.match = key_match_comparator

+    @classmethod
+    def left(cls, source):
+        return cls(
+            source, sort_params=dict(reverse=True), key_match_comparator=operator.ge
+        )
+
     def __getitem__(self, item):
         sorted_keys = sorted(self.keys(), **self.sort_params)
         if isinstance(item, RangeMap.Item):
@@ -261,7 +279,7 @@ class RangeMap(dict):
         return (sorted_keys[RangeMap.first_item], sorted_keys[RangeMap.last_item])

     # some special values for the RangeMap
-    undefined_value = type(str('RangeValueUndefined'), (object,), {})()
+    undefined_value = type(str('RangeValueUndefined'), (), {})()

     class Item(int):
         "RangeMap Item"
@@ -370,7 +388,7 @@ class FoldedCaseKeyedDict(KeyTransformingDict):
     True
     >>> 'HELLO' in d
     True
-    >>> print(repr(FoldedCaseKeyedDict({'heLlo': 'world'})).replace("u'", "'"))
+    >>> print(repr(FoldedCaseKeyedDict({'heLlo': 'world'})))
     {'heLlo': 'world'}
     >>> d = FoldedCaseKeyedDict({'heLlo': 'world'})
     >>> print(d['hello'])
@@ -433,7 +451,7 @@ class FoldedCaseKeyedDict(KeyTransformingDict):
         return jaraco.text.FoldedCase(key)


-class DictAdapter(object):
+class DictAdapter:
     """
     Provide a getitem interface for attributes of an object.
@@ -452,7 +470,7 @@ class DictAdapter(object):
         return getattr(self.object, name)


-class ItemsAsAttributes(object):
+class ItemsAsAttributes:
     """
     Mix-in class to enable a mapping object to provide items as
     attributes.
@@ -561,7 +579,7 @@ class IdentityOverrideMap(dict):
         return key


-class DictStack(list, collections.abc.Mapping):
+class DictStack(list, collections.abc.MutableMapping):
     """
     A stack of dictionaries that behaves as a view on those dictionaries,
     giving preference to the last.
@@ -578,11 +596,12 @@ class DictStack(list, collections.abc.Mapping):
     >>> stack.push(dict(a=3))
     >>> stack['a']
     3
+    >>> stack['a'] = 4
     >>> set(stack.keys()) == set(['a', 'b', 'c'])
     True
-    >>> set(stack.items()) == set([('a', 3), ('b', 2), ('c', 2)])
+    >>> set(stack.items()) == set([('a', 4), ('b', 2), ('c', 2)])
     True
-    >>> dict(**stack) == dict(stack) == dict(a=3, c=2, b=2)
+    >>> dict(**stack) == dict(stack) == dict(a=4, c=2, b=2)
     True
     >>> d = stack.pop()
     >>> stack['a']
@@ -593,6 +612,9 @@ class DictStack(list, collections.abc.Mapping):
     >>> stack.get('b', None)
     >>> 'c' in stack
     True
+    >>> del stack['c']
+    >>> dict(stack)
+    {'a': 1}
     """

     def __iter__(self):
@@ -613,6 +635,18 @@ class DictStack(list, collections.abc.Mapping):
     def __len__(self):
         return len(list(iter(self)))

+    def __setitem__(self, key, item):
+        last = list.__getitem__(self, -1)
+        return last.__setitem__(key, item)
+
+    def __delitem__(self, key):
+        last = list.__getitem__(self, -1)
+        return last.__delitem__(key)
+
+    # workaround for mypy confusion
+    def pop(self, *args, **kwargs):
+        return list.pop(self, *args, **kwargs)
+

 class BijectiveMap(dict):
     """
@@ -855,7 +889,7 @@ class Enumeration(ItemsAsAttributes, BijectiveMap):
         return (self[name] for name in self.names)


-class Everything(object):
+class Everything:
     """
     A collection "containing" every possible thing.
@@ -896,7 +930,7 @@ class InstrumentedDict(collections.UserDict):  # type: ignore # buggy mypy
         self.data = data


-class Least(object):
+class Least:
     """
     A value that is always lesser than any other
@@ -928,7 +962,7 @@ class Least(object):
     __gt__ = __ge__


-class Greatest(object):
+class Greatest:
     """
     A value that is always greater than any other


@ -66,7 +66,7 @@ class FoldedCase(str):
>>> s in ["Hello World"] >>> s in ["Hello World"]
True True
You may test for set inclusion, but candidate and elements Allows testing for set inclusion, but candidate and elements
must both be folded. must both be folded.
>>> FoldedCase("Hello World") in {s} >>> FoldedCase("Hello World") in {s}
@ -92,37 +92,40 @@ class FoldedCase(str):
>>> FoldedCase('hello') > FoldedCase('Hello') >>> FoldedCase('hello') > FoldedCase('Hello')
False False
>>> FoldedCase('ß') == FoldedCase('ss')
True
""" """
def __lt__(self, other): def __lt__(self, other):
return self.lower() < other.lower() return self.casefold() < other.casefold()
def __gt__(self, other): def __gt__(self, other):
return self.lower() > other.lower() return self.casefold() > other.casefold()
def __eq__(self, other): def __eq__(self, other):
return self.lower() == other.lower() return self.casefold() == other.casefold()
def __ne__(self, other): def __ne__(self, other):
return self.lower() != other.lower() return self.casefold() != other.casefold()
def __hash__(self): def __hash__(self):
return hash(self.lower()) return hash(self.casefold())
def __contains__(self, other): def __contains__(self, other):
return super().lower().__contains__(other.lower()) return super().casefold().__contains__(other.casefold())
def in_(self, other): def in_(self, other):
"Does self appear in other?" "Does self appear in other?"
return self in FoldedCase(other) return self in FoldedCase(other)
# cache lower since it's likely to be called frequently. # cache casefold since it's likely to be called frequently.
@method_cache @method_cache
def lower(self): def casefold(self):
return super().lower() return super().casefold()
def index(self, sub): def index(self, sub):
return self.lower().index(sub.lower()) return self.casefold().index(sub.casefold())
def split(self, splitter=' ', maxsplit=0): def split(self, splitter=' ', maxsplit=0):
pattern = re.compile(re.escape(splitter), re.I) pattern = re.compile(re.escape(splitter), re.I)
@ -277,7 +280,7 @@ class WordSet(tuple):
>>> WordSet.parse("myABCClass") >>> WordSet.parse("myABCClass")
('my', 'ABC', 'Class') ('my', 'ABC', 'Class')
The result is a WordSet, so you can get the form you need. The result is a WordSet, providing access to other forms.
>>> WordSet.parse("myABCClass").underscore_separated() >>> WordSet.parse("myABCClass").underscore_separated()
'my_ABC_Class' 'my_ABC_Class'
@ -598,3 +601,22 @@ def join_continuation(lines):
except StopIteration: except StopIteration:
return return
yield item yield item
def read_newlines(filename, limit=1024):
r"""
>>> tmp_path = getfixture('tmp_path')
>>> filename = tmp_path / 'out.txt'
>>> _ = filename.write_text('foo\n', newline='')
>>> read_newlines(filename)
'\n'
>>> _ = filename.write_text('foo\r\n', newline='')
>>> read_newlines(filename)
'\r\n'
>>> _ = filename.write_text('foo\r\nbar\nbing\r', newline='')
>>> read_newlines(filename)
('\r', '\n', '\r\n')
"""
with open(filename) as fp:
fp.read(limit)
return fp.newlines
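For context, a hedged sketch of the io behaviour read_newlines relies on: a text stream opened with universal-newline translation (newline=None, the default for open) records the line endings it has decoded in its .newlines attribute.
import io

buf = io.StringIO('foo\r\nbar\n', newline=None)  # newline=None enables universal newlines
buf.read()
print(buf.newlines)  # ('\n', '\r\n') on CPython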

@ -0,0 +1,25 @@
qwerty = "-=qwertyuiop[]asdfghjkl;'zxcvbnm,./_+QWERTYUIOP{}ASDFGHJKL:\"ZXCVBNM<>?"
dvorak = "[]',.pyfgcrl/=aoeuidhtns-;qjkxbmwvz{}\"<>PYFGCRL?+AOEUIDHTNS_:QJKXBMWVZ"
to_dvorak = str.maketrans(qwerty, dvorak)
to_qwerty = str.maketrans(dvorak, qwerty)
def translate(input, translation):
"""
>>> translate('dvorak', to_dvorak)
'ekrpat'
>>> translate('qwerty', to_qwerty)
'x,dokt'
"""
return input.translate(translation)
def _translate_stream(stream, translation):
"""
>>> import io
>>> _translate_stream(io.StringIO('foo'), to_dvorak)
urr
"""
print(translate(stream.read(), translation))

@ -0,0 +1,33 @@
import autocommand
import inflect
from more_itertools import always_iterable
import jaraco.text
def report_newlines(filename):
r"""
Report the newlines in the indicated file.
>>> tmp_path = getfixture('tmp_path')
>>> filename = tmp_path / 'out.txt'
>>> _ = filename.write_text('foo\nbar\n', newline='')
>>> report_newlines(filename)
newline is '\n'
>>> filename = tmp_path / 'out.txt'
>>> _ = filename.write_text('foo\nbar\r\n', newline='')
>>> report_newlines(filename)
newlines are ('\n', '\r\n')
"""
newlines = jaraco.text.read_newlines(filename)
count = len(tuple(always_iterable(newlines)))
engine = inflect.engine()
print(
engine.plural_noun("newline", count),
engine.plural_verb("is", count),
repr(newlines),
)
autocommand.autocommand(__name__)(report_newlines)
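The autocommand.autocommand(__name__) call above is what turns report_newlines into a command-line entry point; a hedged sketch of the general pattern (illustrative names only):
from autocommand import autocommand

@autocommand(__name__)
def greet(name, shout: bool = False):
    """Greet NAME on stdout; pass --shout to uppercase the greeting."""
    message = f'Hello, {name}'
    print(message.upper() if shout else message)
When the module is run as a script, autocommand builds an argparse parser from the function signature (a positional argument for name, a flag for the boolean default) and calls the function with the parsed values.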

@ -0,0 +1,6 @@
import sys
from . import layouts
__name__ == '__main__' and layouts._translate_stream(sys.stdin, layouts.to_dvorak)

@ -0,0 +1,6 @@
import sys
from . import layouts
__name__ == '__main__' and layouts._translate_stream(sys.stdin, layouts.to_qwerty)

@ -1,4 +1,6 @@
"""More routines for operating on iterables, beyond itertools"""
from .more import * # noqa from .more import * # noqa
from .recipes import * # noqa from .recipes import * # noqa
__version__ = '8.12.0' __version__ = '9.0.0'

@ -2,9 +2,8 @@ import warnings
from collections import Counter, defaultdict, deque, abc from collections import Counter, defaultdict, deque, abc
from collections.abc import Sequence from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from functools import partial, reduce, wraps from functools import partial, reduce, wraps
from heapq import merge, heapify, heapreplace, heappop from heapq import heapify, heapreplace, heappop
from itertools import ( from itertools import (
chain, chain,
compress, compress,
@ -27,12 +26,16 @@ from sys import hexversion, maxsize
from time import monotonic from time import monotonic
from .recipes import ( from .recipes import (
_marker,
_zip_equal,
UnequalIterablesError,
consume, consume,
flatten, flatten,
pairwise, pairwise,
powerset, powerset,
take, take,
unique_everseen, unique_everseen,
all_equal,
) )
__all__ = [ __all__ = [
@ -49,9 +52,9 @@ __all__ = [
'chunked_even', 'chunked_even',
'circular_shifts', 'circular_shifts',
'collapse', 'collapse',
'collate',
'combination_index', 'combination_index',
'consecutive_groups', 'consecutive_groups',
'constrained_batches',
'consumer', 'consumer',
'count_cycle', 'count_cycle',
'countable', 'countable',
@ -67,6 +70,7 @@ __all__ = [
'first', 'first',
'groupby_transform', 'groupby_transform',
'ichunked', 'ichunked',
'iequals',
'ilen', 'ilen',
'interleave', 'interleave',
'interleave_evenly', 'interleave_evenly',
@ -77,6 +81,7 @@ __all__ = [
'iterate', 'iterate',
'last', 'last',
'locate', 'locate',
'longest_common_prefix',
'lstrip', 'lstrip',
'make_decorator', 'make_decorator',
'map_except', 'map_except',
@ -133,9 +138,6 @@ __all__ = [
] ]
_marker = object()
def chunked(iterable, n, strict=False): def chunked(iterable, n, strict=False):
"""Break *iterable* into lists of length *n*: """Break *iterable* into lists of length *n*:
@ -410,44 +412,6 @@ class peekable:
return self._cache[index] return self._cache[index]
def collate(*iterables, **kwargs):
"""Return a sorted merge of the items from each of several already-sorted
*iterables*.
>>> list(collate('ACDZ', 'AZ', 'JKL'))
['A', 'A', 'C', 'D', 'J', 'K', 'L', 'Z', 'Z']
Works lazily, keeping only the next value from each iterable in memory. Use
:func:`collate` to, for example, perform a n-way mergesort of items that
don't fit in memory.
If a *key* function is specified, the iterables will be sorted according
to its result:
>>> key = lambda s: int(s) # Sort by numeric value, not by string
>>> list(collate(['1', '10'], ['2', '11'], key=key))
['1', '2', '10', '11']
If the *iterables* are sorted in descending order, set *reverse* to
``True``:
>>> list(collate([5, 3, 1], [4, 2, 0], reverse=True))
[5, 4, 3, 2, 1, 0]
If the elements of the passed-in iterables are out of order, you might get
unexpected results.
On Python 3.5+, this function is an alias for :func:`heapq.merge`.
"""
warnings.warn(
"collate is no longer part of more_itertools, use heapq.merge",
DeprecationWarning,
)
return merge(*iterables, **kwargs)
def consumer(func): def consumer(func):
"""Decorator that automatically advances a PEP-342-style "reverse iterator" """Decorator that automatically advances a PEP-342-style "reverse iterator"
to its first yield point so you don't have to call ``next()`` on it to its first yield point so you don't have to call ``next()`` on it
@ -873,7 +837,9 @@ def windowed(seq, n, fillvalue=None, step=1):
yield tuple(window) yield tuple(window)
size = len(window) size = len(window)
if size < n: if size == 0:
return
elif size < n:
yield tuple(chain(window, repeat(fillvalue, n - size))) yield tuple(chain(window, repeat(fillvalue, n - size)))
elif 0 < i < min(step, n): elif 0 < i < min(step, n):
window += (fillvalue,) * i window += (fillvalue,) * i
@ -1646,45 +1612,6 @@ def stagger(iterable, offsets=(-1, 0, 1), longest=False, fillvalue=None):
) )
class UnequalIterablesError(ValueError):
def __init__(self, details=None):
msg = 'Iterables have different lengths'
if details is not None:
msg += (': index 0 has length {}; index {} has length {}').format(
*details
)
super().__init__(msg)
def _zip_equal_generator(iterables):
for combo in zip_longest(*iterables, fillvalue=_marker):
for val in combo:
if val is _marker:
raise UnequalIterablesError()
yield combo
def _zip_equal(*iterables):
# Check whether the iterables are all the same size.
try:
first_size = len(iterables[0])
for i, it in enumerate(iterables[1:], 1):
size = len(it)
if size != first_size:
break
else:
# If we didn't break out, we can use the built-in zip.
return zip(*iterables)
# If we did break out, there was a mismatch.
raise UnequalIterablesError(details=(first_size, i, size))
# If any one of the iterables didn't have a length, start reading
# them until one runs out.
except TypeError:
return _zip_equal_generator(iterables)
def zip_equal(*iterables): def zip_equal(*iterables):
"""``zip`` the input *iterables* together, but raise """``zip`` the input *iterables* together, but raise
``UnequalIterablesError`` if they aren't all the same length. ``UnequalIterablesError`` if they aren't all the same length.
@ -1826,7 +1753,7 @@ def unzip(iterable):
of the zipped *iterable*. of the zipped *iterable*.
The ``i``-th iterable contains the ``i``-th element from each element The ``i``-th iterable contains the ``i``-th element from each element
of the zipped iterable. The first element is used to to determine the of the zipped iterable. The first element is used to determine the
length of the remaining elements. length of the remaining elements.
>>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)] >>> iterable = [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
@ -2376,6 +2303,16 @@ def locate(iterable, pred=bool, window_size=None):
return compress(count(), starmap(pred, it)) return compress(count(), starmap(pred, it))
def longest_common_prefix(iterables):
"""Yield elements of the longest common prefix amongst given *iterables*.
>>> ''.join(longest_common_prefix(['abcd', 'abc', 'abf']))
'ab'
"""
return (c[0] for c in takewhile(all_equal, zip(*iterables)))
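The recipe is not string-specific; a small hedged sketch assuming more_itertools >= 9.0.0 is installed:
from more_itertools import longest_common_prefix

paths = [[1, 2, 3, 9], [1, 2, 4], [1, 2]]
print(list(longest_common_prefix(paths)))  # [1, 2]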
def lstrip(iterable, pred): def lstrip(iterable, pred):
"""Yield the items from *iterable*, but strip any from the beginning """Yield the items from *iterable*, but strip any from the beginning
for which *pred* returns ``True``. for which *pred* returns ``True``.
@ -2684,7 +2621,7 @@ def difference(iterable, func=sub, *, initial=None):
if initial is not None: if initial is not None:
first = [] first = []
return chain(first, starmap(func, zip(b, a))) return chain(first, map(func, b, a))
class SequenceView(Sequence): class SequenceView(Sequence):
@ -3327,6 +3264,27 @@ def only(iterable, default=None, too_long=None):
return first_value return first_value
class _IChunk:
def __init__(self, iterable, n):
self._it = islice(iterable, n)
self._cache = deque()
def fill_cache(self):
self._cache.extend(self._it)
def __iter__(self):
return self
def __next__(self):
try:
return next(self._it)
except StopIteration:
if self._cache:
return self._cache.popleft()
else:
raise
def ichunked(iterable, n): def ichunked(iterable, n):
"""Break *iterable* into sub-iterables with *n* elements each. """Break *iterable* into sub-iterables with *n* elements each.
:func:`ichunked` is like :func:`chunked`, but it yields iterables :func:`ichunked` is like :func:`chunked`, but it yields iterables
@ -3348,20 +3306,39 @@ def ichunked(iterable, n):
[8, 9, 10, 11] [8, 9, 10, 11]
""" """
source = iter(iterable) source = peekable(iter(iterable))
ichunk_marker = object()
while True: while True:
# Check to see whether we're at the end of the source iterable # Check to see whether we're at the end of the source iterable
item = next(source, _marker) item = source.peek(ichunk_marker)
if item is _marker: if item is ichunk_marker:
return return
# Clone the source and yield an n-length slice chunk = _IChunk(source, n)
source, it = tee(chain([item], source)) yield chunk
yield islice(it, n)
# Advance the source iterable # Advance the source iterable and fill previous chunk's cache
consume(source, n) chunk.fill_cache()
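A hedged sketch (assuming more_itertools >= 9.0.0) of how the cache-backed chunks behave: a chunk handed out earlier remains consumable after the next chunk is requested, because its remaining items are moved into its cache at that point.
from more_itertools import ichunked

chunks = ichunked(range(6), 3)
first = next(chunks)
second = next(chunks)   # requesting this fills first's cache
print(list(second))     # [3, 4, 5]
print(list(first))      # [0, 1, 2], served from the cache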
def iequals(*iterables):
"""Return ``True`` if all given *iterables* are equal to each other,
which means that they contain the same elements in the same order.
The function is useful for comparing iterables of different data types
or iterables that do not support equality checks.
>>> iequals("abc", ['a', 'b', 'c'], ('a', 'b', 'c'), iter("abc"))
True
>>> iequals("abc", "acb")
False
Not to be confused with :func:`all_equal`, which checks whether all
elements of iterable are equal to each other.
"""
return all(map(all_equal, zip_longest(*iterables, fillvalue=object())))
def distinct_combinations(iterable, r): def distinct_combinations(iterable, r):
@ -3656,7 +3633,10 @@ class callback_iter:
self._aborted = False self._aborted = False
self._future = None self._future = None
self._wait_seconds = wait_seconds self._wait_seconds = wait_seconds
self._executor = ThreadPoolExecutor(max_workers=1) # Lazily import concurrent.futures
self._executor = __import__(
'concurrent.futures'
).futures.ThreadPoolExecutor(max_workers=1)
self._iterator = self._reader() self._iterator = self._reader()
def __enter__(self): def __enter__(self):
@ -3961,7 +3941,7 @@ def combination_index(element, iterable):
n, _ = last(pool, default=(n, None)) n, _ = last(pool, default=(n, None))
# Python versiosn below 3.8 don't have math.comb # Python versions below 3.8 don't have math.comb
index = 1 index = 1
for i, j in enumerate(reversed(indexes), start=1): for i, j in enumerate(reversed(indexes), start=1):
j = n - j j = n - j
@ -4114,7 +4094,7 @@ def zip_broadcast(*objects, scalar_types=(str, bytes), strict=False):
If the *strict* keyword argument is ``True``, then If the *strict* keyword argument is ``True``, then
``UnequalIterablesError`` will be raised if any of the iterables have ``UnequalIterablesError`` will be raised if any of the iterables have
different lengthss. different lengths.
""" """
def is_scalar(obj): def is_scalar(obj):
@ -4315,3 +4295,53 @@ def minmax(iterable_or_value, *others, key=None, default=_marker):
hi, hi_key = y, y_key hi, hi_key = y, y_key
return lo, hi return lo, hi
def constrained_batches(
iterable, max_size, max_count=None, get_len=len, strict=True
):
"""Yield batches of items from *iterable* with a combined size limited by
*max_size*.
>>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
>>> list(constrained_batches(iterable, 10))
[(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')]
If a *max_count* is supplied, the number of items per batch is also
limited:
>>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
>>> list(constrained_batches(iterable, 10, max_count = 2))
[(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)]
If a *get_len* function is supplied, use that instead of :func:`len` to
determine item size.
If *strict* is ``True``, raise ``ValueError`` if any single item is bigger
than *max_size*. Otherwise, allow single items to exceed *max_size*.
"""
if max_size <= 0:
raise ValueError('maximum size must be greater than zero')
batch = []
batch_size = 0
batch_count = 0
for item in iterable:
item_len = get_len(item)
if strict and item_len > max_size:
raise ValueError('item size exceeds maximum size')
reached_count = batch_count == max_count
reached_size = item_len + batch_size > max_size
if batch_count and (reached_size or reached_count):
yield tuple(batch)
batch.clear()
batch_size = 0
batch_count = 0
batch.append(item)
batch_size += item_len
batch_count += 1
if batch:
yield tuple(batch)
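A hedged usage sketch showing the *get_len* hook with non-bytes items (assuming more_itertools >= 9.0.0):
from more_itertools import constrained_batches

rows = [{'payload': 'xxxx'}, {'payload': 'xx'}, {'payload': 'xxxxxx'}]
batches = constrained_batches(rows, 8, get_len=lambda row: len(row['payload']))
print([tuple(r['payload'] for r in batch) for batch in batches])
# [('xxxx', 'xx'), ('xxxxxx',)]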

@ -72,7 +72,6 @@ class peekable(Generic[_T], Iterator[_T]):
@overload @overload
def __getitem__(self, index: slice) -> List[_T]: ... def __getitem__(self, index: slice) -> List[_T]: ...
def collate(*iterables: Iterable[_T], **kwargs: Any) -> Iterable[_T]: ...
def consumer(func: _GenFn) -> _GenFn: ... def consumer(func: _GenFn) -> _GenFn: ...
def ilen(iterable: Iterable[object]) -> int: ... def ilen(iterable: Iterable[object]) -> int: ...
def iterate(func: Callable[[_T], _T], start: _T) -> Iterator[_T]: ... def iterate(func: Callable[[_T], _T], start: _T) -> Iterator[_T]: ...
@ -179,7 +178,7 @@ def padded(
iterable: Iterable[_T], iterable: Iterable[_T],
*, *,
n: Optional[int] = ..., n: Optional[int] = ...,
next_multiple: bool = ... next_multiple: bool = ...,
) -> Iterator[Optional[_T]]: ... ) -> Iterator[Optional[_T]]: ...
@overload @overload
def padded( def padded(
@ -225,7 +224,7 @@ def zip_equal(
__iter1: Iterable[_T], __iter1: Iterable[_T],
__iter2: Iterable[_T], __iter2: Iterable[_T],
__iter3: Iterable[_T], __iter3: Iterable[_T],
*iterables: Iterable[_T] *iterables: Iterable[_T],
) -> Iterator[Tuple[_T, ...]]: ... ) -> Iterator[Tuple[_T, ...]]: ...
@overload @overload
def zip_offset( def zip_offset(
@ -233,7 +232,7 @@ def zip_offset(
*, *,
offsets: _SizedIterable[int], offsets: _SizedIterable[int],
longest: bool = ..., longest: bool = ...,
fillvalue: None = None fillvalue: None = None,
) -> Iterator[Tuple[Optional[_T1]]]: ... ) -> Iterator[Tuple[Optional[_T1]]]: ...
@overload @overload
def zip_offset( def zip_offset(
@ -242,7 +241,7 @@ def zip_offset(
*, *,
offsets: _SizedIterable[int], offsets: _SizedIterable[int],
longest: bool = ..., longest: bool = ...,
fillvalue: None = None fillvalue: None = None,
) -> Iterator[Tuple[Optional[_T1], Optional[_T2]]]: ... ) -> Iterator[Tuple[Optional[_T1], Optional[_T2]]]: ...
@overload @overload
def zip_offset( def zip_offset(
@ -252,7 +251,7 @@ def zip_offset(
*iterables: Iterable[_T], *iterables: Iterable[_T],
offsets: _SizedIterable[int], offsets: _SizedIterable[int],
longest: bool = ..., longest: bool = ...,
fillvalue: None = None fillvalue: None = None,
) -> Iterator[Tuple[Optional[_T], ...]]: ... ) -> Iterator[Tuple[Optional[_T], ...]]: ...
@overload @overload
def zip_offset( def zip_offset(
@ -420,7 +419,7 @@ def difference(
iterable: Iterable[_T], iterable: Iterable[_T],
func: Callable[[_T, _T], _U] = ..., func: Callable[[_T, _T], _U] = ...,
*, *,
initial: None = ... initial: None = ...,
) -> Iterator[Union[_T, _U]]: ... ) -> Iterator[Union[_T, _U]]: ...
@overload @overload
def difference( def difference(
@ -529,12 +528,12 @@ def distinct_combinations(
def filter_except( def filter_except(
validator: Callable[[Any], object], validator: Callable[[Any], object],
iterable: Iterable[_T], iterable: Iterable[_T],
*exceptions: Type[BaseException] *exceptions: Type[BaseException],
) -> Iterator[_T]: ... ) -> Iterator[_T]: ...
def map_except( def map_except(
function: Callable[[Any], _U], function: Callable[[Any], _U],
iterable: Iterable[_T], iterable: Iterable[_T],
*exceptions: Type[BaseException] *exceptions: Type[BaseException],
) -> Iterator[_U]: ... ) -> Iterator[_U]: ...
def map_if( def map_if(
iterable: Iterable[Any], iterable: Iterable[Any],
@ -610,7 +609,7 @@ def zip_broadcast(
scalar_types: Union[ scalar_types: Union[
type, Tuple[Union[type, Tuple[Any, ...]], ...], None type, Tuple[Union[type, Tuple[Any, ...]], ...], None
] = ..., ] = ...,
strict: bool = ... strict: bool = ...,
) -> Iterable[Tuple[_T, ...]]: ... ) -> Iterable[Tuple[_T, ...]]: ...
def unique_in_window( def unique_in_window(
iterable: Iterable[_T], n: int, key: Optional[Callable[[_T], _U]] = ... iterable: Iterable[_T], n: int, key: Optional[Callable[[_T], _U]] = ...
@ -640,7 +639,7 @@ def minmax(
iterable_or_value: Iterable[_SupportsLessThanT], iterable_or_value: Iterable[_SupportsLessThanT],
*, *,
key: None = None, key: None = None,
default: _U default: _U,
) -> Union[_U, Tuple[_SupportsLessThanT, _SupportsLessThanT]]: ... ) -> Union[_U, Tuple[_SupportsLessThanT, _SupportsLessThanT]]: ...
@overload @overload
def minmax( def minmax(
@ -653,12 +652,23 @@ def minmax(
def minmax( def minmax(
iterable_or_value: _SupportsLessThanT, iterable_or_value: _SupportsLessThanT,
__other: _SupportsLessThanT, __other: _SupportsLessThanT,
*others: _SupportsLessThanT *others: _SupportsLessThanT,
) -> Tuple[_SupportsLessThanT, _SupportsLessThanT]: ... ) -> Tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
@overload @overload
def minmax( def minmax(
iterable_or_value: _T, iterable_or_value: _T,
__other: _T, __other: _T,
*others: _T, *others: _T,
key: Callable[[_T], _SupportsLessThan] key: Callable[[_T], _SupportsLessThan],
) -> Tuple[_T, _T]: ... ) -> Tuple[_T, _T]: ...
def longest_common_prefix(
iterables: Iterable[Iterable[_T]],
) -> Iterator[_T]: ...
def iequals(*iterables: Iterable[object]) -> bool: ...
def constrained_batches(
iterable: Iterable[object],
max_size: int,
max_count: Optional[int] = ...,
get_len: Callable[[_T], object] = ...,
strict: bool = ...,
) -> Iterator[Tuple[_T]]: ...

@ -7,11 +7,16 @@ Some backward-compatible usability improvements have been made.
.. [1] http://docs.python.org/library/itertools.html#recipes .. [1] http://docs.python.org/library/itertools.html#recipes
""" """
import warnings import math
import operator
from collections import deque from collections import deque
from collections.abc import Sized
from functools import reduce
from itertools import ( from itertools import (
chain, chain,
combinations, combinations,
compress,
count, count,
cycle, cycle,
groupby, groupby,
@ -21,11 +26,11 @@ from itertools import (
tee, tee,
zip_longest, zip_longest,
) )
import operator
from random import randrange, sample, choice from random import randrange, sample, choice
__all__ = [ __all__ = [
'all_equal', 'all_equal',
'batched',
'before_and_after', 'before_and_after',
'consume', 'consume',
'convolve', 'convolve',
@ -41,6 +46,7 @@ __all__ = [
'pad_none', 'pad_none',
'pairwise', 'pairwise',
'partition', 'partition',
'polynomial_from_roots',
'powerset', 'powerset',
'prepend', 'prepend',
'quantify', 'quantify',
@ -50,7 +56,9 @@ __all__ = [
'random_product', 'random_product',
'repeatfunc', 'repeatfunc',
'roundrobin', 'roundrobin',
'sieve',
'sliding_window', 'sliding_window',
'subslices',
'tabulate', 'tabulate',
'tail', 'tail',
'take', 'take',
@ -59,6 +67,8 @@ __all__ = [
'unique_justseen', 'unique_justseen',
] ]
_marker = object()
def take(n, iterable): def take(n, iterable):
"""Return first *n* items of the iterable as a list. """Return first *n* items of the iterable as a list.
@ -102,7 +112,14 @@ def tail(n, iterable):
['E', 'F', 'G'] ['E', 'F', 'G']
""" """
return iter(deque(iterable, maxlen=n)) # If the given iterable has a length, then we can use islice to get its
# final elements. Note that if the iterable is not actually Iterable,
# either islice or deque will throw a TypeError. This is why we don't
# check if it is Iterable.
if isinstance(iterable, Sized):
yield from islice(iterable, max(0, len(iterable) - n), None)
else:
yield from iter(deque(iterable, maxlen=n))
def consume(iterator, n=None): def consume(iterator, n=None):
@ -284,20 +301,83 @@ else:
pairwise.__doc__ = _pairwise.__doc__ pairwise.__doc__ = _pairwise.__doc__
def grouper(iterable, n, fillvalue=None): class UnequalIterablesError(ValueError):
"""Collect data into fixed-length chunks or blocks. def __init__(self, details=None):
msg = 'Iterables have different lengths'
if details is not None:
msg += (': index 0 has length {}; index {} has length {}').format(
*details
)
>>> list(grouper('ABCDEFG', 3, 'x')) super().__init__(msg)
def _zip_equal_generator(iterables):
for combo in zip_longest(*iterables, fillvalue=_marker):
for val in combo:
if val is _marker:
raise UnequalIterablesError()
yield combo
def _zip_equal(*iterables):
# Check whether the iterables are all the same size.
try:
first_size = len(iterables[0])
for i, it in enumerate(iterables[1:], 1):
size = len(it)
if size != first_size:
break
else:
# If we didn't break out, we can use the built-in zip.
return zip(*iterables)
# If we did break out, there was a mismatch.
raise UnequalIterablesError(details=(first_size, i, size))
# If any one of the iterables didn't have a length, start reading
# them until one runs out.
except TypeError:
return _zip_equal_generator(iterables)
def grouper(iterable, n, incomplete='fill', fillvalue=None):
"""Group elements from *iterable* into fixed-length groups of length *n*.
>>> list(grouper('ABCDEF', 3))
[('A', 'B', 'C'), ('D', 'E', 'F')]
The keyword arguments *incomplete* and *fillvalue* control what happens for
iterables whose length is not a multiple of *n*.
When *incomplete* is `'fill'`, the last group will contain instances of
*fillvalue*.
>>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
[('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')] [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
When *incomplete* is `'ignore'`, the last group will not be emitted.
>>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
[('A', 'B', 'C'), ('D', 'E', 'F')]
When *incomplete* is `'strict'`, a subclass of `ValueError` will be raised.
>>> it = grouper('ABCDEFG', 3, incomplete='strict')
>>> list(it) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
UnequalIterablesError
""" """
if isinstance(iterable, int):
warnings.warn(
"grouper expects iterable as first parameter", DeprecationWarning
)
n, iterable = iterable, n
args = [iter(iterable)] * n args = [iter(iterable)] * n
return zip_longest(fillvalue=fillvalue, *args) if incomplete == 'fill':
return zip_longest(*args, fillvalue=fillvalue)
if incomplete == 'strict':
return _zip_equal(*args)
if incomplete == 'ignore':
return zip(*args)
else:
raise ValueError('Expected fill, strict, or ignore')
def roundrobin(*iterables): def roundrobin(*iterables):
@ -658,11 +738,12 @@ def before_and_after(predicate, it):
transition.append(elem) transition.append(elem)
return return
def remainder_iterator(): # Note: this is different from itertools recipes to allow nesting
yield from transition # before_and_after remainders into before_and_after again. See tests
yield from it # for an example.
remainder_iterator = chain(transition, it)
return true_iterator(), remainder_iterator() return true_iterator(), remainder_iterator
def triplewise(iterable): def triplewise(iterable):
@ -696,3 +777,65 @@ def sliding_window(iterable, n):
for x in it: for x in it:
window.append(x) window.append(x)
yield tuple(window) yield tuple(window)
def subslices(iterable):
"""Return all contiguous non-empty subslices of *iterable*.
>>> list(subslices('ABC'))
[['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]
This is similar to :func:`substrings`, but emits items in a different
order.
"""
seq = list(iterable)
slices = starmap(slice, combinations(range(len(seq) + 1), 2))
return map(operator.getitem, repeat(seq), slices)
def polynomial_from_roots(roots):
"""Compute a polynomial's coefficients from its roots.
>>> roots = [5, -4, 3] # (x - 5) * (x + 4) * (x - 3)
>>> polynomial_from_roots(roots) # x^3 - 4 * x^2 - 17 * x + 60
[1, -4, -17, 60]
"""
# Use math.prod for Python 3.8+,
prod = getattr(math, 'prod', lambda x: reduce(operator.mul, x, 1))
roots = list(map(operator.neg, roots))
return [
sum(map(prod, combinations(roots, k))) for k in range(len(roots) + 1)
]
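A quick sanity check of the example above: the coefficients describe x^3 - 4x^2 - 17x + 60, which should evaluate to zero at each root.
coeffs = [1, -4, -17, 60]  # polynomial_from_roots([5, -4, 3])
for x in (5, -4, 3):
    value = sum(c * x ** (len(coeffs) - 1 - i) for i, c in enumerate(coeffs))
    print(x, value)  # prints 0 for every root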
def sieve(n):
"""Yield the primes less than n.
>>> list(sieve(30))
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
"""
isqrt = getattr(math, 'isqrt', lambda x: int(math.sqrt(x)))
limit = isqrt(n) + 1
data = bytearray([1]) * n
data[:2] = 0, 0
for p in compress(range(limit), data):
data[p + p : n : p] = bytearray(len(range(p + p, n, p)))
return compress(count(), data)
def batched(iterable, n):
"""Batch data into lists of length *n*. The last batch may be shorter.
>>> list(batched('ABCDEFG', 3))
[['A', 'B', 'C'], ['D', 'E', 'F'], ['G']]
This recipe is from the ``itertools`` docs. This library also provides
:func:`chunked`, which has a different implementation.
"""
it = iter(iterable)
while True:
batch = list(islice(it, n))
if not batch:
break
yield batch

@ -6,6 +6,7 @@ from typing import (
Iterator, Iterator,
List, List,
Optional, Optional,
Sequence,
Tuple, Tuple,
TypeVar, TypeVar,
Union, Union,
@ -39,21 +40,11 @@ def repeatfunc(
func: Callable[..., _U], times: Optional[int] = ..., *args: Any func: Callable[..., _U], times: Optional[int] = ..., *args: Any
) -> Iterator[_U]: ... ) -> Iterator[_U]: ...
def pairwise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T]]: ... def pairwise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T]]: ...
@overload
def grouper( def grouper(
iterable: Iterable[_T], n: int iterable: Iterable[_T],
) -> Iterator[Tuple[Optional[_T], ...]]: ... n: int,
@overload incomplete: str = ...,
def grouper( fillvalue: _U = ...,
iterable: Iterable[_T], n: int, fillvalue: _U
) -> Iterator[Tuple[Union[_T, _U], ...]]: ...
@overload
def grouper( # Deprecated interface
iterable: int, n: Iterable[_T]
) -> Iterator[Tuple[Optional[_T], ...]]: ...
@overload
def grouper( # Deprecated interface
iterable: int, n: Iterable[_T], fillvalue: _U
) -> Iterator[Tuple[Union[_T, _U], ...]]: ... ) -> Iterator[Tuple[Union[_T, _U], ...]]: ...
def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]: ... def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]: ...
def partition( def partition(
@ -110,3 +101,10 @@ def triplewise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T, _T]]: ...
def sliding_window( def sliding_window(
iterable: Iterable[_T], n: int iterable: Iterable[_T], n: int
) -> Iterator[Tuple[_T, ...]]: ... ) -> Iterator[Tuple[_T, ...]]: ...
def subslices(iterable: Iterable[_T]) -> Iterator[List[_T]]: ...
def polynomial_from_roots(roots: Sequence[int]) -> List[int]: ...
def sieve(n: int) -> Iterator[int]: ...
def batched(
iterable: Iterable[_T],
n: int,
) -> Iterator[List[_T]]: ...

lib/pydantic/__init__.py
@ -0,0 +1,131 @@
# flake8: noqa
from . import dataclasses
from .annotated_types import create_model_from_namedtuple, create_model_from_typeddict
from .class_validators import root_validator, validator
from .config import BaseConfig, ConfigDict, Extra
from .decorator import validate_arguments
from .env_settings import BaseSettings
from .error_wrappers import ValidationError
from .errors import *
from .fields import Field, PrivateAttr, Required
from .main import *
from .networks import *
from .parse import Protocol
from .tools import *
from .types import *
from .version import VERSION, compiled
__version__ = VERSION
# WARNING __all__ from .errors is not included here, it will be removed as an export here in v2
# please use "from pydantic.errors import ..." instead
__all__ = [
# annotated types utils
'create_model_from_namedtuple',
'create_model_from_typeddict',
# dataclasses
'dataclasses',
# class_validators
'root_validator',
'validator',
# config
'BaseConfig',
'ConfigDict',
'Extra',
# decorator
'validate_arguments',
# env_settings
'BaseSettings',
# error_wrappers
'ValidationError',
# fields
'Field',
'Required',
# main
'BaseModel',
'create_model',
'validate_model',
# network
'AnyUrl',
'AnyHttpUrl',
'FileUrl',
'HttpUrl',
'stricturl',
'EmailStr',
'NameEmail',
'IPvAnyAddress',
'IPvAnyInterface',
'IPvAnyNetwork',
'PostgresDsn',
'CockroachDsn',
'AmqpDsn',
'RedisDsn',
'MongoDsn',
'KafkaDsn',
'validate_email',
# parse
'Protocol',
# tools
'parse_file_as',
'parse_obj_as',
'parse_raw_as',
'schema_of',
'schema_json_of',
# types
'NoneStr',
'NoneBytes',
'StrBytes',
'NoneStrBytes',
'StrictStr',
'ConstrainedBytes',
'conbytes',
'ConstrainedList',
'conlist',
'ConstrainedSet',
'conset',
'ConstrainedFrozenSet',
'confrozenset',
'ConstrainedStr',
'constr',
'PyObject',
'ConstrainedInt',
'conint',
'PositiveInt',
'NegativeInt',
'NonNegativeInt',
'NonPositiveInt',
'ConstrainedFloat',
'confloat',
'PositiveFloat',
'NegativeFloat',
'NonNegativeFloat',
'NonPositiveFloat',
'FiniteFloat',
'ConstrainedDecimal',
'condecimal',
'ConstrainedDate',
'condate',
'UUID1',
'UUID3',
'UUID4',
'UUID5',
'FilePath',
'DirectoryPath',
'Json',
'JsonWrapper',
'SecretField',
'SecretStr',
'SecretBytes',
'StrictBool',
'StrictBytes',
'StrictInt',
'StrictFloat',
'PaymentCardNumber',
'PrivateAttr',
'ByteSize',
'PastDate',
'FutureDate',
# version
'compiled',
'VERSION',
]

@ -0,0 +1,386 @@
"""
Register Hypothesis strategies for Pydantic custom types.
This enables fully-automatic generation of test data for most Pydantic classes.
Note that this module has *no* runtime impact on Pydantic itself; instead it
is registered as a setuptools entry point and Hypothesis will import it if
Pydantic is installed. See also:
https://hypothesis.readthedocs.io/en/latest/strategies.html#registering-strategies-via-setuptools-entry-points
https://hypothesis.readthedocs.io/en/latest/data.html#hypothesis.strategies.register_type_strategy
https://hypothesis.readthedocs.io/en/latest/strategies.html#interaction-with-pytest-cov
https://pydantic-docs.helpmanual.io/usage/types/#pydantic-types
Note that because our motivation is to *improve user experience*, the strategies
are always sound (never generate invalid data) but sacrifice completeness for
maintainability (ie may be unable to generate some tricky but valid data).
Finally, this module makes liberal use of `# type: ignore[<code>]` pragmas.
This is because Hypothesis annotates `register_type_strategy()` with
`(T, SearchStrategy[T])`, but in most cases we register e.g. `ConstrainedInt`
to generate instances of the builtin `int` type which match the constraints.
"""
import contextlib
import datetime
import ipaddress
import json
import math
from fractions import Fraction
from typing import Callable, Dict, Type, Union, cast, overload
import hypothesis.strategies as st
import pydantic
import pydantic.color
import pydantic.types
from pydantic.utils import lenient_issubclass
# FilePath and DirectoryPath are explicitly unsupported, as we'd have to create
# them on-disk, and that's unsafe in general without being told *where* to do so.
#
# URLs are unsupported because it's easy for users to define their own strategy for
# "normal" URLs, and hard for us to define a general strategy which includes "weird"
# URLs but doesn't also have unpredictable performance problems.
#
# conlist() and conset() are unsupported for now, because the workarounds for
# Cython and Hypothesis to handle parametrized generic types are incompatible.
# Once Cython can support 'normal' generics we'll revisit this.
# Emails
try:
import email_validator
except ImportError: # pragma: no cover
pass
else:
def is_valid_email(s: str) -> bool:
# Hypothesis' st.emails() occasionally generates emails like 0@A0--0.ac
# that are invalid according to email-validator, so we filter those out.
try:
email_validator.validate_email(s, check_deliverability=False)
return True
except email_validator.EmailNotValidError: # pragma: no cover
return False
# Note that these strategies deliberately stay away from any tricky Unicode
# or other encoding issues; we're just trying to generate *something* valid.
st.register_type_strategy(pydantic.EmailStr, st.emails().filter(is_valid_email)) # type: ignore[arg-type]
st.register_type_strategy(
pydantic.NameEmail,
st.builds(
'{} <{}>'.format, # type: ignore[arg-type]
st.from_regex('[A-Za-z0-9_]+( [A-Za-z0-9_]+){0,5}', fullmatch=True),
st.emails().filter(is_valid_email),
),
)
# PyObject - dotted names, in this case taken from the math module.
st.register_type_strategy(
pydantic.PyObject, # type: ignore[arg-type]
st.sampled_from(
[cast(pydantic.PyObject, f'math.{name}') for name in sorted(vars(math)) if not name.startswith('_')]
),
)
# CSS3 Colors; as name, hex, rgb(a) tuples or strings, or hsl strings
_color_regexes = (
'|'.join(
(
pydantic.color.r_hex_short,
pydantic.color.r_hex_long,
pydantic.color.r_rgb,
pydantic.color.r_rgba,
pydantic.color.r_hsl,
pydantic.color.r_hsla,
)
)
# Use more precise regex patterns to avoid value-out-of-range errors
.replace(pydantic.color._r_sl, r'(?:(\d\d?(?:\.\d+)?|100(?:\.0+)?)%)')
.replace(pydantic.color._r_alpha, r'(?:(0(?:\.\d+)?|1(?:\.0+)?|\.\d+|\d{1,2}%))')
.replace(pydantic.color._r_255, r'(?:((?:\d|\d\d|[01]\d\d|2[0-4]\d|25[0-4])(?:\.\d+)?|255(?:\.0+)?))')
)
st.register_type_strategy(
pydantic.color.Color,
st.one_of(
st.sampled_from(sorted(pydantic.color.COLORS_BY_NAME)),
st.tuples(
st.integers(0, 255),
st.integers(0, 255),
st.integers(0, 255),
st.none() | st.floats(0, 1) | st.floats(0, 100).map('{}%'.format),
),
st.from_regex(_color_regexes, fullmatch=True),
),
)
# Card numbers, valid according to the Luhn algorithm
def add_luhn_digit(card_number: str) -> str:
# See https://en.wikipedia.org/wiki/Luhn_algorithm
for digit in '0123456789':
with contextlib.suppress(Exception):
pydantic.PaymentCardNumber.validate_luhn_check_digit(card_number + digit)
return card_number + digit
raise AssertionError('Unreachable') # pragma: no cover
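For reference, a standalone sketch of the Luhn check-digit computation that the brute-force helper above searches for (purely illustrative, not part of pydantic):
def luhn_check_digit(payload: str) -> str:
    total = 0
    for i, ch in enumerate(reversed(payload)):
        d = int(ch)
        if i % 2 == 0:      # double every second digit, starting from the right
            d *= 2
            if d > 9:
                d -= 9
        total += d
    return str((10 - total % 10) % 10)

print(luhn_check_digit('7992739871'))  # '3', the classic worked example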
card_patterns = (
# Note that these patterns omit the Luhn check digit; that's added by the function above
'4[0-9]{14}', # Visa
'5[12345][0-9]{13}', # Mastercard
'3[47][0-9]{12}', # American Express
'[0-26-9][0-9]{10,17}', # other (incomplete to avoid overlap)
)
st.register_type_strategy(
pydantic.PaymentCardNumber,
st.from_regex('|'.join(card_patterns), fullmatch=True).map(add_luhn_digit), # type: ignore[arg-type]
)
# UUIDs
st.register_type_strategy(pydantic.UUID1, st.uuids(version=1))
st.register_type_strategy(pydantic.UUID3, st.uuids(version=3))
st.register_type_strategy(pydantic.UUID4, st.uuids(version=4))
st.register_type_strategy(pydantic.UUID5, st.uuids(version=5))
# Secrets
st.register_type_strategy(pydantic.SecretBytes, st.binary().map(pydantic.SecretBytes))
st.register_type_strategy(pydantic.SecretStr, st.text().map(pydantic.SecretStr))
# IP addresses, networks, and interfaces
st.register_type_strategy(pydantic.IPvAnyAddress, st.ip_addresses()) # type: ignore[arg-type]
st.register_type_strategy(
pydantic.IPvAnyInterface,
st.from_type(ipaddress.IPv4Interface) | st.from_type(ipaddress.IPv6Interface), # type: ignore[arg-type]
)
st.register_type_strategy(
pydantic.IPvAnyNetwork,
st.from_type(ipaddress.IPv4Network) | st.from_type(ipaddress.IPv6Network), # type: ignore[arg-type]
)
# We hook into the con***() functions and the ConstrainedNumberMeta metaclass,
# so here we only have to register subclasses for other constrained types which
# don't go via those mechanisms. Then there are the registration hooks below.
st.register_type_strategy(pydantic.StrictBool, st.booleans())
st.register_type_strategy(pydantic.StrictStr, st.text())
# Constrained-type resolver functions
#
# For these ones, we actually want to inspect the type in order to work out a
# satisfying strategy. First up, the machinery for tracking resolver functions:
RESOLVERS: Dict[type, Callable[[type], st.SearchStrategy]] = {} # type: ignore[type-arg]
@overload
def _registered(typ: Type[pydantic.types.T]) -> Type[pydantic.types.T]:
pass
@overload
def _registered(typ: pydantic.types.ConstrainedNumberMeta) -> pydantic.types.ConstrainedNumberMeta:
pass
def _registered(
typ: Union[Type[pydantic.types.T], pydantic.types.ConstrainedNumberMeta]
) -> Union[Type[pydantic.types.T], pydantic.types.ConstrainedNumberMeta]:
# This function replaces the version in `pydantic.types`, in order to
# effect the registration of new constrained types so that Hypothesis
# can generate valid examples.
pydantic.types._DEFINED_TYPES.add(typ)
for supertype, resolver in RESOLVERS.items():
if issubclass(typ, supertype):
st.register_type_strategy(typ, resolver(typ)) # type: ignore
return typ
raise NotImplementedError(f'Unknown type {typ!r} has no resolver to register') # pragma: no cover
def resolves(
typ: Union[type, pydantic.types.ConstrainedNumberMeta]
) -> Callable[[Callable[..., st.SearchStrategy]], Callable[..., st.SearchStrategy]]: # type: ignore[type-arg]
def inner(f): # type: ignore
assert f not in RESOLVERS
RESOLVERS[typ] = f
return f
return inner
# Type-to-strategy resolver functions
@resolves(pydantic.JsonWrapper)
def resolve_json(cls): # type: ignore[no-untyped-def]
try:
inner = st.none() if cls.inner_type is None else st.from_type(cls.inner_type)
except Exception: # pragma: no cover
finite = st.floats(allow_infinity=False, allow_nan=False)
inner = st.recursive(
base=st.one_of(st.none(), st.booleans(), st.integers(), finite, st.text()),
extend=lambda x: st.lists(x) | st.dictionaries(st.text(), x), # type: ignore
)
inner_type = getattr(cls, 'inner_type', None)
return st.builds(
cls.inner_type.json if lenient_issubclass(inner_type, pydantic.BaseModel) else json.dumps,
inner,
ensure_ascii=st.booleans(),
indent=st.none() | st.integers(0, 16),
sort_keys=st.booleans(),
)
@resolves(pydantic.ConstrainedBytes)
def resolve_conbytes(cls): # type: ignore[no-untyped-def] # pragma: no cover
min_size = cls.min_length or 0
max_size = cls.max_length
if not cls.strip_whitespace:
return st.binary(min_size=min_size, max_size=max_size)
# Fun with regex to ensure we neither start nor end with whitespace
repeats = '{{{},{}}}'.format(
min_size - 2 if min_size > 2 else 0,
max_size - 2 if (max_size or 0) > 2 else '',
)
if min_size >= 2:
pattern = rf'\W.{repeats}\W'
elif min_size == 1:
pattern = rf'\W(.{repeats}\W)?'
else:
assert min_size == 0
pattern = rf'(\W(.{repeats}\W)?)?'
return st.from_regex(pattern.encode(), fullmatch=True)
@resolves(pydantic.ConstrainedDecimal)
def resolve_condecimal(cls): # type: ignore[no-untyped-def]
min_value = cls.ge
max_value = cls.le
if cls.gt is not None:
assert min_value is None, 'Set `gt` or `ge`, but not both'
min_value = cls.gt
if cls.lt is not None:
assert max_value is None, 'Set `lt` or `le`, but not both'
max_value = cls.lt
s = st.decimals(min_value, max_value, allow_nan=False, places=cls.decimal_places)
if cls.lt is not None:
s = s.filter(lambda d: d < cls.lt)
if cls.gt is not None:
s = s.filter(lambda d: cls.gt < d)
return s
@resolves(pydantic.ConstrainedFloat)
def resolve_confloat(cls): # type: ignore[no-untyped-def]
min_value = cls.ge
max_value = cls.le
exclude_min = False
exclude_max = False
if cls.gt is not None:
assert min_value is None, 'Set `gt` or `ge`, but not both'
min_value = cls.gt
exclude_min = True
if cls.lt is not None:
assert max_value is None, 'Set `lt` or `le`, but not both'
max_value = cls.lt
exclude_max = True
if cls.multiple_of is None:
return st.floats(min_value, max_value, exclude_min=exclude_min, exclude_max=exclude_max, allow_nan=False)
if min_value is not None:
min_value = math.ceil(min_value / cls.multiple_of)
if exclude_min:
min_value = min_value + 1
if max_value is not None:
assert max_value >= cls.multiple_of, 'Cannot build model with max value smaller than multiple of'
max_value = math.floor(max_value / cls.multiple_of)
if exclude_max:
max_value = max_value - 1
return st.integers(min_value, max_value).map(lambda x: x * cls.multiple_of)
@resolves(pydantic.ConstrainedInt)
def resolve_conint(cls): # type: ignore[no-untyped-def]
min_value = cls.ge
max_value = cls.le
if cls.gt is not None:
assert min_value is None, 'Set `gt` or `ge`, but not both'
min_value = cls.gt + 1
if cls.lt is not None:
assert max_value is None, 'Set `lt` or `le`, but not both'
max_value = cls.lt - 1
if cls.multiple_of is None or cls.multiple_of == 1:
return st.integers(min_value, max_value)
# These adjustments and the .map handle integer-valued multiples, while the
# .filter handles trickier cases as for confloat.
if min_value is not None:
min_value = math.ceil(Fraction(min_value) / Fraction(cls.multiple_of))
if max_value is not None:
max_value = math.floor(Fraction(max_value) / Fraction(cls.multiple_of))
return st.integers(min_value, max_value).map(lambda x: x * cls.multiple_of)
@resolves(pydantic.ConstrainedDate)
def resolve_condate(cls): # type: ignore[no-untyped-def]
if cls.ge is not None:
assert cls.gt is None, 'Set `gt` or `ge`, but not both'
min_value = cls.ge
elif cls.gt is not None:
min_value = cls.gt + datetime.timedelta(days=1)
else:
min_value = datetime.date.min
if cls.le is not None:
assert cls.lt is None, 'Set `lt` or `le`, but not both'
max_value = cls.le
elif cls.lt is not None:
max_value = cls.lt - datetime.timedelta(days=1)
else:
max_value = datetime.date.max
return st.dates(min_value, max_value)
@resolves(pydantic.ConstrainedStr)
def resolve_constr(cls): # type: ignore[no-untyped-def] # pragma: no cover
min_size = cls.min_length or 0
max_size = cls.max_length
if cls.regex is None and not cls.strip_whitespace:
return st.text(min_size=min_size, max_size=max_size)
if cls.regex is not None:
strategy = st.from_regex(cls.regex)
if cls.strip_whitespace:
strategy = strategy.filter(lambda s: s == s.strip())
elif cls.strip_whitespace:
repeats = '{{{},{}}}'.format(
min_size - 2 if min_size > 2 else 0,
max_size - 2 if (max_size or 0) > 2 else '',
)
if min_size >= 2:
strategy = st.from_regex(rf'\W.{repeats}\W')
elif min_size == 1:
strategy = st.from_regex(rf'\W(.{repeats}\W)?')
else:
assert min_size == 0
strategy = st.from_regex(rf'(\W(.{repeats}\W)?)?')
if min_size == 0 and max_size is None:
return strategy
elif max_size is None:
return strategy.filter(lambda s: min_size <= len(s))
return strategy.filter(lambda s: min_size <= len(s) <= max_size)
# Finally, register all previously-defined types, and patch in our new function
for typ in list(pydantic.types._DEFINED_TYPES):
_registered(typ)
pydantic.types._registered = _registered
st.register_type_strategy(pydantic.Json, resolve_json)

View file

@ -0,0 +1,72 @@
import sys
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, NamedTuple, Type
from .fields import Required
from .main import BaseModel, create_model
from .typing import is_typeddict, is_typeddict_special
if TYPE_CHECKING:
from typing_extensions import TypedDict
if sys.version_info < (3, 11):
def is_legacy_typeddict(typeddict_cls: Type['TypedDict']) -> bool: # type: ignore[valid-type]
return is_typeddict(typeddict_cls) and type(typeddict_cls).__module__ == 'typing'
else:
def is_legacy_typeddict(_: Any) -> Any:
return False
def create_model_from_typeddict(
# Mypy bug: `Type[TypedDict]` is resolved as `Any` https://github.com/python/mypy/issues/11030
typeddict_cls: Type['TypedDict'], # type: ignore[valid-type]
**kwargs: Any,
) -> Type['BaseModel']:
"""
Create a `BaseModel` based on the fields of a `TypedDict`.
Since `typing.TypedDict` in Python 3.8 does not store runtime information about optional keys,
we raise an error if this happens (see https://bugs.python.org/issue38834).
"""
field_definitions: Dict[str, Any]
# Best case scenario: with python 3.9+ or when `TypedDict` is imported from `typing_extensions`
if not hasattr(typeddict_cls, '__required_keys__'):
raise TypeError(
'You should use `typing_extensions.TypedDict` instead of `typing.TypedDict` with Python < 3.9.2. '
'Without it, there is no way to differentiate required and optional fields when subclassed.'
)
if is_legacy_typeddict(typeddict_cls) and any(
is_typeddict_special(t) for t in typeddict_cls.__annotations__.values()
):
raise TypeError(
'You should use `typing_extensions.TypedDict` instead of `typing.TypedDict` with Python < 3.11. '
'Without it, there is no way to reflect Required/NotRequired keys.'
)
required_keys: FrozenSet[str] = typeddict_cls.__required_keys__ # type: ignore[attr-defined]
field_definitions = {
field_name: (field_type, Required if field_name in required_keys else None)
for field_name, field_type in typeddict_cls.__annotations__.items()
}
return create_model(typeddict_cls.__name__, **kwargs, **field_definitions)
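A hedged usage sketch (using typing_extensions.TypedDict so it also works on older Pythons, per the checks above):
from typing_extensions import TypedDict
from pydantic import create_model_from_typeddict

class Movie(TypedDict):
    title: str
    year: int

MovieModel = create_model_from_typeddict(Movie)
movie = MovieModel(title='Blade Runner', year='1982')
print(movie.year)  # 1982, coerced to int by the generated model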
def create_model_from_namedtuple(namedtuple_cls: Type['NamedTuple'], **kwargs: Any) -> Type['BaseModel']:
"""
Create a `BaseModel` based on the fields of a named tuple.
A named tuple can be created with `typing.NamedTuple` and declared annotations
but also with `collections.namedtuple`, in this case we consider all fields
to have type `Any`.
"""
# With python 3.10+, `__annotations__` always exists but can be empty hence the `getattr... or...` logic
namedtuple_annotations: Dict[str, Type[Any]] = getattr(namedtuple_cls, '__annotations__', None) or {
k: Any for k in namedtuple_cls._fields
}
field_definitions: Dict[str, Any] = {
field_name: (field_type, Required) for field_name, field_type in namedtuple_annotations.items()
}
return create_model(namedtuple_cls.__name__, **kwargs, **field_definitions)

View file

@ -0,0 +1,342 @@
import warnings
from collections import ChainMap
from functools import wraps
from itertools import chain
from types import FunctionType
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, overload
from .errors import ConfigError
from .typing import AnyCallable
from .utils import ROOT_KEY, in_ipython
if TYPE_CHECKING:
from .typing import AnyClassMethod
class Validator:
__slots__ = 'func', 'pre', 'each_item', 'always', 'check_fields', 'skip_on_failure'
def __init__(
self,
func: AnyCallable,
pre: bool = False,
each_item: bool = False,
always: bool = False,
check_fields: bool = False,
skip_on_failure: bool = False,
):
self.func = func
self.pre = pre
self.each_item = each_item
self.always = always
self.check_fields = check_fields
self.skip_on_failure = skip_on_failure
if TYPE_CHECKING:
from inspect import Signature
from .config import BaseConfig
from .fields import ModelField
from .types import ModelOrDc
ValidatorCallable = Callable[[Optional[ModelOrDc], Any, Dict[str, Any], ModelField, Type[BaseConfig]], Any]
ValidatorsList = List[ValidatorCallable]
ValidatorListDict = Dict[str, List[Validator]]
_FUNCS: Set[str] = set()
VALIDATOR_CONFIG_KEY = '__validator_config__'
ROOT_VALIDATOR_CONFIG_KEY = '__root_validator_config__'
def validator(
*fields: str,
pre: bool = False,
each_item: bool = False,
always: bool = False,
check_fields: bool = True,
whole: bool = None,
allow_reuse: bool = False,
) -> Callable[[AnyCallable], 'AnyClassMethod']:
"""
Decorate methods on the class indicating that they should be used to validate fields
:param fields: which field(s) the method should be called on
:param pre: whether or not this validator should be called before the standard validators (else after)
:param each_item: for complex objects (sets, lists etc.) whether to validate individual elements rather than the
whole object
:param always: whether this method and other validators should be called even if the value is missing
:param check_fields: whether to check that the fields actually exist on the model
:param allow_reuse: whether to track and raise an error if another validator refers to the decorated function
"""
if not fields:
raise ConfigError('validator with no fields specified')
elif isinstance(fields[0], FunctionType):
raise ConfigError(
"validators should be used with fields and keyword arguments, not bare. " # noqa: Q000
"E.g. usage should be `@validator('<field_name>', ...)`"
)
elif not all(isinstance(field, str) for field in fields):
raise ConfigError(
"validator fields should be passed as separate string args. " # noqa: Q000
"E.g. usage should be `@validator('<field_name_1>', '<field_name_2>', ...)`"
)
if whole is not None:
warnings.warn(
'The "whole" keyword argument is deprecated, use "each_item" (inverse meaning, default False) instead',
DeprecationWarning,
)
assert each_item is False, '"each_item" and "whole" conflict, remove "whole"'
each_item = not whole
def dec(f: AnyCallable) -> 'AnyClassMethod':
f_cls = _prepare_validator(f, allow_reuse)
setattr(
f_cls,
VALIDATOR_CONFIG_KEY,
(
fields,
Validator(func=f_cls.__func__, pre=pre, each_item=each_item, always=always, check_fields=check_fields),
),
)
return f_cls
return dec
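A short hedged sketch of the decorator in typical use:
from pydantic import BaseModel, validator

class User(BaseModel):
    name: str

    @validator('name')
    def name_must_not_be_blank(cls, value):
        if not value.strip():
            raise ValueError('name must not be blank')
        return value.title()

print(User(name='ada lovelace').name)  # 'Ada Lovelace'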
@overload
def root_validator(_func: AnyCallable) -> 'AnyClassMethod':
...
@overload
def root_validator(
*, pre: bool = False, allow_reuse: bool = False, skip_on_failure: bool = False
) -> Callable[[AnyCallable], 'AnyClassMethod']:
...
def root_validator(
_func: Optional[AnyCallable] = None, *, pre: bool = False, allow_reuse: bool = False, skip_on_failure: bool = False
) -> Union['AnyClassMethod', Callable[[AnyCallable], 'AnyClassMethod']]:
"""
Decorate methods on a model indicating that they should be used to validate (and perhaps modify) data either
before or after standard model parsing/validation is performed.
"""
if _func:
f_cls = _prepare_validator(_func, allow_reuse)
setattr(
f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure)
)
return f_cls
def dec(f: AnyCallable) -> 'AnyClassMethod':
f_cls = _prepare_validator(f, allow_reuse)
setattr(
f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure)
)
return f_cls
return dec
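And a matching hedged sketch for root validators, which see all field values at once:
from pydantic import BaseModel, root_validator

class Window(BaseModel):
    start: int
    end: int

    @root_validator
    def check_order(cls, values):
        # In a post root validator, fields that failed their own validation
        # may be missing from `values`.
        if values.get('start', 0) > values.get('end', 0):
            raise ValueError('start must not exceed end')
        return values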
def _prepare_validator(function: AnyCallable, allow_reuse: bool) -> 'AnyClassMethod':
"""
Avoid validators with duplicated names: without this check, validators could silently overwrite one another,
which generally isn't the intended behaviour. The check is skipped inside IPython (see #312) or when allow_reuse is True.
"""
f_cls = function if isinstance(function, classmethod) else classmethod(function)
if not in_ipython() and not allow_reuse:
ref = f_cls.__func__.__module__ + '.' + f_cls.__func__.__qualname__
if ref in _FUNCS:
raise ConfigError(f'duplicate validator function "{ref}"; if this is intended, set `allow_reuse=True`')
_FUNCS.add(ref)
return f_cls
class ValidatorGroup:
def __init__(self, validators: 'ValidatorListDict') -> None:
self.validators = validators
self.used_validators = {'*'}
def get_validators(self, name: str) -> Optional[Dict[str, Validator]]:
self.used_validators.add(name)
validators = self.validators.get(name, [])
if name != ROOT_KEY:
validators += self.validators.get('*', [])
if validators:
return {v.func.__name__: v for v in validators}
else:
return None
def check_for_unused(self) -> None:
unused_validators = set(
chain.from_iterable(
(v.func.__name__ for v in self.validators[f] if v.check_fields)
for f in (self.validators.keys() - self.used_validators)
)
)
if unused_validators:
fn = ', '.join(unused_validators)
raise ConfigError(
f"Validators defined with incorrect fields: {fn} " # noqa: Q000
f"(use check_fields=False if you're inheriting from the model and intended this)"
)
def extract_validators(namespace: Dict[str, Any]) -> Dict[str, List[Validator]]:
validators: Dict[str, List[Validator]] = {}
for var_name, value in namespace.items():
validator_config = getattr(value, VALIDATOR_CONFIG_KEY, None)
if validator_config:
fields, v = validator_config
for field in fields:
if field in validators:
validators[field].append(v)
else:
validators[field] = [v]
return validators
def extract_root_validators(namespace: Dict[str, Any]) -> Tuple[List[AnyCallable], List[Tuple[bool, AnyCallable]]]:
from inspect import signature
pre_validators: List[AnyCallable] = []
post_validators: List[Tuple[bool, AnyCallable]] = []
for name, value in namespace.items():
validator_config: Optional[Validator] = getattr(value, ROOT_VALIDATOR_CONFIG_KEY, None)
if validator_config:
sig = signature(validator_config.func)
args = list(sig.parameters.keys())
if args[0] == 'self':
raise ConfigError(
f'Invalid signature for root validator {name}: {sig}, "self" not permitted as first argument, '
f'should be: (cls, values).'
)
if len(args) != 2:
raise ConfigError(f'Invalid signature for root validator {name}: {sig}, should be: (cls, values).')
# check function signature
if validator_config.pre:
pre_validators.append(validator_config.func)
else:
post_validators.append((validator_config.skip_on_failure, validator_config.func))
return pre_validators, post_validators
def inherit_validators(base_validators: 'ValidatorListDict', validators: 'ValidatorListDict') -> 'ValidatorListDict':
for field, field_validators in base_validators.items():
if field not in validators:
validators[field] = []
validators[field] += field_validators
return validators
def make_generic_validator(validator: AnyCallable) -> 'ValidatorCallable':
"""
Make a generic function which calls a validator with the right arguments.
Unfortunately other approaches (eg. returning a partial of a function that builds the arguments) are slow,
hence this laborious way of doing things.
It's done like this so validators don't all need **kwargs in their signature, eg. any combination of
the arguments "values", "field" and/or "config" are permitted.
"""
from inspect import signature
sig = signature(validator)
args = list(sig.parameters.keys())
first_arg = args.pop(0)
if first_arg == 'self':
raise ConfigError(
f'Invalid signature for validator {validator}: {sig}, "self" not permitted as first argument, '
f'should be: (cls, value, values, config, field), "values", "config" and "field" are all optional.'
)
elif first_arg == 'cls':
# assume the second argument is value
return wraps(validator)(_generic_validator_cls(validator, sig, set(args[1:])))
else:
# assume the first argument was value which has already been removed
return wraps(validator)(_generic_validator_basic(validator, sig, set(args)))
def prep_validators(v_funcs: Iterable[AnyCallable]) -> 'ValidatorsList':
return [make_generic_validator(f) for f in v_funcs if f]
all_kwargs = {'values', 'field', 'config'}
def _generic_validator_cls(validator: AnyCallable, sig: 'Signature', args: Set[str]) -> 'ValidatorCallable':
# assume the first argument is value
has_kwargs = False
if 'kwargs' in args:
has_kwargs = True
args -= {'kwargs'}
if not args.issubset(all_kwargs):
raise ConfigError(
f'Invalid signature for validator {validator}: {sig}, should be: '
f'(cls, value, values, config, field), "values", "config" and "field" are all optional.'
)
if has_kwargs:
return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field, config=config)
elif args == set():
return lambda cls, v, values, field, config: validator(cls, v)
elif args == {'values'}:
return lambda cls, v, values, field, config: validator(cls, v, values=values)
elif args == {'field'}:
return lambda cls, v, values, field, config: validator(cls, v, field=field)
elif args == {'config'}:
return lambda cls, v, values, field, config: validator(cls, v, config=config)
elif args == {'values', 'field'}:
return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field)
elif args == {'values', 'config'}:
return lambda cls, v, values, field, config: validator(cls, v, values=values, config=config)
elif args == {'field', 'config'}:
return lambda cls, v, values, field, config: validator(cls, v, field=field, config=config)
else:
# args == {'values', 'field', 'config'}
return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field, config=config)
def _generic_validator_basic(validator: AnyCallable, sig: 'Signature', args: Set[str]) -> 'ValidatorCallable':
has_kwargs = False
if 'kwargs' in args:
has_kwargs = True
args -= {'kwargs'}
if not args.issubset(all_kwargs):
raise ConfigError(
f'Invalid signature for validator {validator}: {sig}, should be: '
f'(value, values, config, field), "values", "config" and "field" are all optional.'
)
if has_kwargs:
return lambda cls, v, values, field, config: validator(v, values=values, field=field, config=config)
elif args == set():
return lambda cls, v, values, field, config: validator(v)
elif args == {'values'}:
return lambda cls, v, values, field, config: validator(v, values=values)
elif args == {'field'}:
return lambda cls, v, values, field, config: validator(v, field=field)
elif args == {'config'}:
return lambda cls, v, values, field, config: validator(v, config=config)
elif args == {'values', 'field'}:
return lambda cls, v, values, field, config: validator(v, values=values, field=field)
elif args == {'values', 'config'}:
return lambda cls, v, values, field, config: validator(v, values=values, config=config)
elif args == {'field', 'config'}:
return lambda cls, v, values, field, config: validator(v, field=field, config=config)
else:
# args == {'values', 'field', 'config'}
return lambda cls, v, values, field, config: validator(v, values=values, field=field, config=config)
def gather_all_validators(type_: 'ModelOrDc') -> Dict[str, 'AnyClassMethod']:
all_attributes = ChainMap(*[cls.__dict__ for cls in type_.__mro__]) # type: ignore[arg-type,var-annotated]
return {
k: v
for k, v in all_attributes.items()
if hasattr(v, VALIDATOR_CONFIG_KEY) or hasattr(v, ROOT_VALIDATOR_CONFIG_KEY)
}
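# Illustrative sketch (not part of the pydantic source; the validator functions below are
# hypothetical): make_generic_validator normalises validators with different signatures onto
# a single (cls, value, values, field, config) calling convention.
if __name__ == '__main__':  # pragma: no cover
    def strip(cls, v):
        return v.strip()

    def check_not_duplicate(cls, v, values):
        assert v not in values.values()
        return v

    gen_strip = make_generic_validator(strip)
    gen_check = make_generic_validator(check_not_duplicate)
    # both wrapped validators accept the same five arguments
    assert gen_strip(None, '  x  ', {}, None, None) == 'x'
    assert gen_check(None, 'y', {'other': 'z'}, None, None) == 'y'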

lib/pydantic/color.py
@ -0,0 +1,494 @@
"""
Color definitions are used as per CSS3 specification:
http://www.w3.org/TR/css3-color/#svg-color
A few colors have multiple names referring to the same color, e.g. `grey` and `gray` or `aqua` and `cyan`.
In these cases the LAST name when sorted alphabetically takes precedence,
e.g. Color((0, 255, 255)).as_named() == 'cyan' because "cyan" comes after "aqua".
(A short usage sketch appears at the end of this module.)
"""
import math
import re
from colorsys import hls_to_rgb, rgb_to_hls
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast
from .errors import ColorError
from .utils import Representation, almost_equal_floats
if TYPE_CHECKING:
from .typing import CallableGenerator, ReprArgs
ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, float]]
ColorType = Union[ColorTuple, str]
HslColorTuple = Union[Tuple[float, float, float], Tuple[float, float, float, float]]
class RGBA:
"""
Internal use only as a representation of a color.
"""
__slots__ = 'r', 'g', 'b', 'alpha', '_tuple'
def __init__(self, r: float, g: float, b: float, alpha: Optional[float]):
self.r = r
self.g = g
self.b = b
self.alpha = alpha
self._tuple: Tuple[float, float, float, Optional[float]] = (r, g, b, alpha)
def __getitem__(self, item: Any) -> Any:
return self._tuple[item]
# these are not compiled here to avoid import slowdown, they'll be compiled the first time they're used, then cached
r_hex_short = r'\s*(?:#|0x)?([0-9a-f])([0-9a-f])([0-9a-f])([0-9a-f])?\s*'
r_hex_long = r'\s*(?:#|0x)?([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})?\s*'
_r_255 = r'(\d{1,3}(?:\.\d+)?)'
_r_comma = r'\s*,\s*'
r_rgb = fr'\s*rgb\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}\)\s*'
_r_alpha = r'(\d(?:\.\d+)?|\.\d+|\d{1,2}%)'
r_rgba = fr'\s*rgba\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_alpha}\s*\)\s*'
_r_h = r'(-?\d+(?:\.\d+)?|-?\.\d+)(deg|rad|turn)?'
_r_sl = r'(\d{1,3}(?:\.\d+)?)%'
r_hsl = fr'\s*hsl\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}\s*\)\s*'
r_hsla = fr'\s*hsl\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}{_r_comma}{_r_alpha}\s*\)\s*'
# colors where the two hex characters are the same, if all colors match this the short version of hex colors can be used
repeat_colors = {int(c * 2, 16) for c in '0123456789abcdef'}
rads = 2 * math.pi
class Color(Representation):
__slots__ = '_original', '_rgba'
def __init__(self, value: ColorType) -> None:
self._rgba: RGBA
self._original: ColorType
if isinstance(value, (tuple, list)):
self._rgba = parse_tuple(value)
elif isinstance(value, str):
self._rgba = parse_str(value)
elif isinstance(value, Color):
self._rgba = value._rgba
value = value._original
else:
raise ColorError(reason='value must be a tuple, list or string')
# if we've got here value must be a valid color
self._original = value
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
field_schema.update(type='string', format='color')
def original(self) -> ColorType:
"""
Original value passed to Color
"""
return self._original
def as_named(self, *, fallback: bool = False) -> str:
if self._rgba.alpha is None:
rgb = cast(Tuple[int, int, int], self.as_rgb_tuple())
try:
return COLORS_BY_VALUE[rgb]
except KeyError as e:
if fallback:
return self.as_hex()
else:
raise ValueError('no named color found, use fallback=True, as_hex() or as_rgb()') from e
else:
return self.as_hex()
def as_hex(self) -> str:
"""
Hex string representing the color; it can be 3, 4, 6 or 8 characters long, depending on whether
a "short" representation of the color is possible and whether there's an alpha channel.
"""
values = [float_to_255(c) for c in self._rgba[:3]]
if self._rgba.alpha is not None:
values.append(float_to_255(self._rgba.alpha))
as_hex = ''.join(f'{v:02x}' for v in values)
if all(c in repeat_colors for c in values):
as_hex = ''.join(as_hex[c] for c in range(0, len(as_hex), 2))
return '#' + as_hex
def as_rgb(self) -> str:
"""
Color as an rgb(<r>, <g>, <b>) or rgba(<r>, <g>, <b>, <a>) string.
"""
if self._rgba.alpha is None:
return f'rgb({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)})'
else:
return (
f'rgba({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)}, '
f'{round(self._alpha_float(), 2)})'
)
def as_rgb_tuple(self, *, alpha: Optional[bool] = None) -> ColorTuple:
"""
Color as an RGB or RGBA tuple; red, green and blue are in the range 0 to 255, alpha if included is
in the range 0 to 1.
:param alpha: whether to include the alpha channel, options are
None - (default) include alpha only if it's set (e.g. not None)
True - always include alpha,
False - always omit alpha,
"""
r, g, b = (float_to_255(c) for c in self._rgba[:3])
if alpha is None:
if self._rgba.alpha is None:
return r, g, b
else:
return r, g, b, self._alpha_float()
elif alpha:
return r, g, b, self._alpha_float()
else:
# alpha is False
return r, g, b
def as_hsl(self) -> str:
"""
Color as an hsl(<h>, <s>, <l>) or hsl(<h>, <s>, <l>, <a>) string.
"""
if self._rgba.alpha is None:
h, s, li = self.as_hsl_tuple(alpha=False) # type: ignore
return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%})'
else:
h, s, li, a = self.as_hsl_tuple(alpha=True) # type: ignore
return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%}, {round(a, 2)})'
def as_hsl_tuple(self, *, alpha: Optional[bool] = None) -> HslColorTuple:
"""
Color as an HSL or HSLA tuple, e.g. hue, saturation, lightness and optionally alpha; all elements are in
the range 0 to 1.
NOTE: this is HSL as used in HTML and most other places, not HLS as used in python's colorsys.
:param alpha: whether to include the alpha channel, options are
None - (default) include alpha only if it's set (e.g. not None)
True - always include alpha,
False - always omit alpha,
"""
h, l, s = rgb_to_hls(self._rgba.r, self._rgba.g, self._rgba.b)
if alpha is None:
if self._rgba.alpha is None:
return h, s, l
else:
return h, s, l, self._alpha_float()
if alpha:
return h, s, l, self._alpha_float()
else:
# alpha is False
return h, s, l
def _alpha_float(self) -> float:
return 1 if self._rgba.alpha is None else self._rgba.alpha
@classmethod
def __get_validators__(cls) -> 'CallableGenerator':
yield cls
def __str__(self) -> str:
return self.as_named(fallback=True)
def __repr_args__(self) -> 'ReprArgs':
return [(None, self.as_named(fallback=True))] + [('rgb', self.as_rgb_tuple())] # type: ignore
def __eq__(self, other: Any) -> bool:
return isinstance(other, Color) and self.as_rgb_tuple() == other.as_rgb_tuple()
def __hash__(self) -> int:
return hash(self.as_rgb_tuple())
def parse_tuple(value: Tuple[Any, ...]) -> RGBA:
"""
Parse a tuple or list as a color.
"""
if len(value) == 3:
r, g, b = (parse_color_value(v) for v in value)
return RGBA(r, g, b, None)
elif len(value) == 4:
r, g, b = (parse_color_value(v) for v in value[:3])
return RGBA(r, g, b, parse_float_alpha(value[3]))
else:
raise ColorError(reason='tuples must have length 3 or 4')
def parse_str(value: str) -> RGBA:
"""
Parse a string to an RGBA tuple, trying the following formats (in this order):
* named color, see COLORS_BY_NAME below
* hex short, e.g. `<prefix>fff` (prefix can be `#`, `0x` or nothing)
* hex long, e.g. `<prefix>ffffff` (prefix can be `#`, `0x` or nothing)
* `rgb(<r>, <g>, <b>)`
* `rgba(<r>, <g>, <b>, <a>)`
* `hsl(<h>, <s>, <l>)` or `hsl(<h>, <s>, <l>, <a>)`
"""
value_lower = value.lower()
try:
r, g, b = COLORS_BY_NAME[value_lower]
except KeyError:
pass
else:
return ints_to_rgba(r, g, b, None)
m = re.fullmatch(r_hex_short, value_lower)
if m:
*rgb, a = m.groups()
r, g, b = (int(v * 2, 16) for v in rgb)
if a:
alpha: Optional[float] = int(a * 2, 16) / 255
else:
alpha = None
return ints_to_rgba(r, g, b, alpha)
m = re.fullmatch(r_hex_long, value_lower)
if m:
*rgb, a = m.groups()
r, g, b = (int(v, 16) for v in rgb)
if a:
alpha = int(a, 16) / 255
else:
alpha = None
return ints_to_rgba(r, g, b, alpha)
m = re.fullmatch(r_rgb, value_lower)
if m:
return ints_to_rgba(*m.groups(), None) # type: ignore
m = re.fullmatch(r_rgba, value_lower)
if m:
return ints_to_rgba(*m.groups()) # type: ignore
m = re.fullmatch(r_hsl, value_lower)
if m:
h, h_units, s, l_ = m.groups()
return parse_hsl(h, h_units, s, l_)
m = re.fullmatch(r_hsla, value_lower)
if m:
h, h_units, s, l_, a = m.groups()
return parse_hsl(h, h_units, s, l_, parse_float_alpha(a))
raise ColorError(reason='string not recognised as a valid color')
def ints_to_rgba(r: Union[int, str], g: Union[int, str], b: Union[int, str], alpha: Optional[float]) -> RGBA:
return RGBA(parse_color_value(r), parse_color_value(g), parse_color_value(b), parse_float_alpha(alpha))
def parse_color_value(value: Union[int, str], max_val: int = 255) -> float:
"""
Parse a value checking it's a valid number in the range 0 to max_val and divide by max_val to give a number
in the range 0 to 1.
"""
try:
color = float(value)
except ValueError:
raise ColorError(reason='color values must be a valid number')
if 0 <= color <= max_val:
return color / max_val
else:
raise ColorError(reason=f'color values must be in the range 0 to {max_val}')
def parse_float_alpha(value: Union[None, str, float, int]) -> Optional[float]:
"""
Parse a value checking it's a valid float in the range 0 to 1
"""
if value is None:
return None
try:
if isinstance(value, str) and value.endswith('%'):
alpha = float(value[:-1]) / 100
else:
alpha = float(value)
except ValueError:
raise ColorError(reason='alpha values must be a valid float')
if almost_equal_floats(alpha, 1):
return None
elif 0 <= alpha <= 1:
return alpha
else:
raise ColorError(reason='alpha values must be in the range 0 to 1')
def parse_hsl(h: str, h_units: str, sat: str, light: str, alpha: Optional[float] = None) -> RGBA:
"""
Parse raw hue, saturation, lightness and alpha values and convert to RGBA.
"""
s_value, l_value = parse_color_value(sat, 100), parse_color_value(light, 100)
h_value = float(h)
if h_units in {None, 'deg'}:
h_value = h_value % 360 / 360
elif h_units == 'rad':
h_value = h_value % rads / rads
else:
# turns
h_value = h_value % 1
r, g, b = hls_to_rgb(h_value, l_value, s_value)
return RGBA(r, g, b, alpha)
def float_to_255(c: float) -> int:
return int(round(c * 255))
COLORS_BY_NAME = {
'aliceblue': (240, 248, 255),
'antiquewhite': (250, 235, 215),
'aqua': (0, 255, 255),
'aquamarine': (127, 255, 212),
'azure': (240, 255, 255),
'beige': (245, 245, 220),
'bisque': (255, 228, 196),
'black': (0, 0, 0),
'blanchedalmond': (255, 235, 205),
'blue': (0, 0, 255),
'blueviolet': (138, 43, 226),
'brown': (165, 42, 42),
'burlywood': (222, 184, 135),
'cadetblue': (95, 158, 160),
'chartreuse': (127, 255, 0),
'chocolate': (210, 105, 30),
'coral': (255, 127, 80),
'cornflowerblue': (100, 149, 237),
'cornsilk': (255, 248, 220),
'crimson': (220, 20, 60),
'cyan': (0, 255, 255),
'darkblue': (0, 0, 139),
'darkcyan': (0, 139, 139),
'darkgoldenrod': (184, 134, 11),
'darkgray': (169, 169, 169),
'darkgreen': (0, 100, 0),
'darkgrey': (169, 169, 169),
'darkkhaki': (189, 183, 107),
'darkmagenta': (139, 0, 139),
'darkolivegreen': (85, 107, 47),
'darkorange': (255, 140, 0),
'darkorchid': (153, 50, 204),
'darkred': (139, 0, 0),
'darksalmon': (233, 150, 122),
'darkseagreen': (143, 188, 143),
'darkslateblue': (72, 61, 139),
'darkslategray': (47, 79, 79),
'darkslategrey': (47, 79, 79),
'darkturquoise': (0, 206, 209),
'darkviolet': (148, 0, 211),
'deeppink': (255, 20, 147),
'deepskyblue': (0, 191, 255),
'dimgray': (105, 105, 105),
'dimgrey': (105, 105, 105),
'dodgerblue': (30, 144, 255),
'firebrick': (178, 34, 34),
'floralwhite': (255, 250, 240),
'forestgreen': (34, 139, 34),
'fuchsia': (255, 0, 255),
'gainsboro': (220, 220, 220),
'ghostwhite': (248, 248, 255),
'gold': (255, 215, 0),
'goldenrod': (218, 165, 32),
'gray': (128, 128, 128),
'green': (0, 128, 0),
'greenyellow': (173, 255, 47),
'grey': (128, 128, 128),
'honeydew': (240, 255, 240),
'hotpink': (255, 105, 180),
'indianred': (205, 92, 92),
'indigo': (75, 0, 130),
'ivory': (255, 255, 240),
'khaki': (240, 230, 140),
'lavender': (230, 230, 250),
'lavenderblush': (255, 240, 245),
'lawngreen': (124, 252, 0),
'lemonchiffon': (255, 250, 205),
'lightblue': (173, 216, 230),
'lightcoral': (240, 128, 128),
'lightcyan': (224, 255, 255),
'lightgoldenrodyellow': (250, 250, 210),
'lightgray': (211, 211, 211),
'lightgreen': (144, 238, 144),
'lightgrey': (211, 211, 211),
'lightpink': (255, 182, 193),
'lightsalmon': (255, 160, 122),
'lightseagreen': (32, 178, 170),
'lightskyblue': (135, 206, 250),
'lightslategray': (119, 136, 153),
'lightslategrey': (119, 136, 153),
'lightsteelblue': (176, 196, 222),
'lightyellow': (255, 255, 224),
'lime': (0, 255, 0),
'limegreen': (50, 205, 50),
'linen': (250, 240, 230),
'magenta': (255, 0, 255),
'maroon': (128, 0, 0),
'mediumaquamarine': (102, 205, 170),
'mediumblue': (0, 0, 205),
'mediumorchid': (186, 85, 211),
'mediumpurple': (147, 112, 219),
'mediumseagreen': (60, 179, 113),
'mediumslateblue': (123, 104, 238),
'mediumspringgreen': (0, 250, 154),
'mediumturquoise': (72, 209, 204),
'mediumvioletred': (199, 21, 133),
'midnightblue': (25, 25, 112),
'mintcream': (245, 255, 250),
'mistyrose': (255, 228, 225),
'moccasin': (255, 228, 181),
'navajowhite': (255, 222, 173),
'navy': (0, 0, 128),
'oldlace': (253, 245, 230),
'olive': (128, 128, 0),
'olivedrab': (107, 142, 35),
'orange': (255, 165, 0),
'orangered': (255, 69, 0),
'orchid': (218, 112, 214),
'palegoldenrod': (238, 232, 170),
'palegreen': (152, 251, 152),
'paleturquoise': (175, 238, 238),
'palevioletred': (219, 112, 147),
'papayawhip': (255, 239, 213),
'peachpuff': (255, 218, 185),
'peru': (205, 133, 63),
'pink': (255, 192, 203),
'plum': (221, 160, 221),
'powderblue': (176, 224, 230),
'purple': (128, 0, 128),
'red': (255, 0, 0),
'rosybrown': (188, 143, 143),
'royalblue': (65, 105, 225),
'saddlebrown': (139, 69, 19),
'salmon': (250, 128, 114),
'sandybrown': (244, 164, 96),
'seagreen': (46, 139, 87),
'seashell': (255, 245, 238),
'sienna': (160, 82, 45),
'silver': (192, 192, 192),
'skyblue': (135, 206, 235),
'slateblue': (106, 90, 205),
'slategray': (112, 128, 144),
'slategrey': (112, 128, 144),
'snow': (255, 250, 250),
'springgreen': (0, 255, 127),
'steelblue': (70, 130, 180),
'tan': (210, 180, 140),
'teal': (0, 128, 128),
'thistle': (216, 191, 216),
'tomato': (255, 99, 71),
'turquoise': (64, 224, 208),
'violet': (238, 130, 238),
'wheat': (245, 222, 179),
'white': (255, 255, 255),
'whitesmoke': (245, 245, 245),
'yellow': (255, 255, 0),
'yellowgreen': (154, 205, 50),
}
COLORS_BY_VALUE = {v: k for k, v in COLORS_BY_NAME.items()}
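# Illustrative sketch (not part of the pydantic source; input values chosen for demonstration
# only, assuming standard pydantic v1 Color behaviour): a few parsing and conversion examples.
if __name__ == '__main__':  # pragma: no cover
    c = Color('hsl(120, 100%, 50%)')                  # hue in degrees, s/l in percent
    assert c.as_rgb_tuple() == (0, 255, 0)
    assert c.as_named() == 'lime'
    assert Color((0, 255, 255)).as_named() == 'cyan'  # "cyan" sorts after "aqua"
    assert Color('red').as_hex() == '#f00'            # short form: every hex pair repeats
    assert Color('#7f7f7f80').as_rgb() == 'rgba(127, 127, 127, 0.5)'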

lib/pydantic/config.py
@ -0,0 +1,192 @@
import json
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, ForwardRef, Optional, Tuple, Type, Union
from typing_extensions import Literal, Protocol
from .typing import AnyArgTCallable, AnyCallable
from .utils import GetterDict
from .version import compiled
if TYPE_CHECKING:
from typing import overload
from .fields import ModelField
from .main import BaseModel
ConfigType = Type['BaseConfig']
class SchemaExtraCallable(Protocol):
@overload
def __call__(self, schema: Dict[str, Any]) -> None:
pass
@overload
def __call__(self, schema: Dict[str, Any], model_class: Type[BaseModel]) -> None:
pass
else:
SchemaExtraCallable = Callable[..., None]
__all__ = 'BaseConfig', 'ConfigDict', 'get_config', 'Extra', 'inherit_config', 'prepare_config'
class Extra(str, Enum):
allow = 'allow'
ignore = 'ignore'
forbid = 'forbid'
# https://github.com/cython/cython/issues/4003
# Will be fixed with Cython 3 but still in alpha right now
if not compiled:
from typing_extensions import TypedDict
class ConfigDict(TypedDict, total=False):
title: Optional[str]
anystr_lower: bool
anystr_strip_whitespace: bool
min_anystr_length: int
max_anystr_length: Optional[int]
validate_all: bool
extra: Extra
allow_mutation: bool
frozen: bool
allow_population_by_field_name: bool
use_enum_values: bool
fields: Dict[str, Union[str, Dict[str, str]]]
validate_assignment: bool
error_msg_templates: Dict[str, str]
arbitrary_types_allowed: bool
orm_mode: bool
getter_dict: Type[GetterDict]
alias_generator: Optional[Callable[[str], str]]
keep_untouched: Tuple[type, ...]
schema_extra: Union[Dict[str, object], 'SchemaExtraCallable']
json_loads: Callable[[str], object]
json_dumps: AnyArgTCallable[str]
json_encoders: Dict[Type[object], AnyCallable]
underscore_attrs_are_private: bool
allow_inf_nan: bool
# whether or not inherited models as fields should be reconstructed as base model
copy_on_model_validation: bool
# whether dataclass `__post_init__` should be run after validation
post_init_call: Literal['before_validation', 'after_validation']
else:
ConfigDict = dict # type: ignore
class BaseConfig:
title: Optional[str] = None
anystr_lower: bool = False
anystr_upper: bool = False
anystr_strip_whitespace: bool = False
min_anystr_length: int = 0
max_anystr_length: Optional[int] = None
validate_all: bool = False
extra: Extra = Extra.ignore
allow_mutation: bool = True
frozen: bool = False
allow_population_by_field_name: bool = False
use_enum_values: bool = False
fields: Dict[str, Union[str, Dict[str, str]]] = {}
validate_assignment: bool = False
error_msg_templates: Dict[str, str] = {}
arbitrary_types_allowed: bool = False
orm_mode: bool = False
getter_dict: Type[GetterDict] = GetterDict
alias_generator: Optional[Callable[[str], str]] = None
keep_untouched: Tuple[type, ...] = ()
schema_extra: Union[Dict[str, Any], 'SchemaExtraCallable'] = {}
json_loads: Callable[[str], Any] = json.loads
json_dumps: Callable[..., str] = json.dumps
json_encoders: Dict[Union[Type[Any], str, ForwardRef], AnyCallable] = {}
underscore_attrs_are_private: bool = False
allow_inf_nan: bool = True
# whether inherited models as fields should be reconstructed as base model,
# and whether such a copy should be shallow or deep
copy_on_model_validation: Literal['none', 'deep', 'shallow'] = 'shallow'
# whether `Union` should check all allowed types before even trying to coerce
smart_union: bool = False
# whether dataclass `__post_init__` should be run before or after validation
post_init_call: Literal['before_validation', 'after_validation'] = 'before_validation'
@classmethod
def get_field_info(cls, name: str) -> Dict[str, Any]:
"""
Get properties of FieldInfo from the `fields` property of the config class.
(A short sketch of how this interacts with `alias_generator` appears at the end of this module.)
"""
fields_value = cls.fields.get(name)
if isinstance(fields_value, str):
field_info: Dict[str, Any] = {'alias': fields_value}
elif isinstance(fields_value, dict):
field_info = fields_value
else:
field_info = {}
if 'alias' in field_info:
field_info.setdefault('alias_priority', 2)
if field_info.get('alias_priority', 0) <= 1 and cls.alias_generator:
alias = cls.alias_generator(name)
if not isinstance(alias, str):
raise TypeError(f'Config.alias_generator must return str, not {alias.__class__}')
field_info.update(alias=alias, alias_priority=1)
return field_info
@classmethod
def prepare_field(cls, field: 'ModelField') -> None:
"""
Optional hook to check or modify fields during model creation.
"""
pass
def get_config(config: Union[ConfigDict, Type[object], None]) -> Type[BaseConfig]:
if config is None:
return BaseConfig
else:
config_dict = (
config
if isinstance(config, dict)
else {k: getattr(config, k) for k in dir(config) if not k.startswith('__')}
)
class Config(BaseConfig):
...
for k, v in config_dict.items():
setattr(Config, k, v)
return Config
def inherit_config(self_config: 'ConfigType', parent_config: 'ConfigType', **namespace: Any) -> 'ConfigType':
if not self_config:
base_classes: Tuple['ConfigType', ...] = (parent_config,)
elif self_config == parent_config:
base_classes = (self_config,)
else:
base_classes = self_config, parent_config
namespace['json_encoders'] = {
**getattr(parent_config, 'json_encoders', {}),
**getattr(self_config, 'json_encoders', {}),
**namespace.get('json_encoders', {}),
}
return type('Config', base_classes, namespace)
def prepare_config(config: Type[BaseConfig], cls_name: str) -> None:
if not isinstance(config.extra, Extra):
try:
config.extra = Extra(config.extra)
except ValueError:
raise ValueError(f'"{cls_name}": {config.extra} is not a valid value for "extra"')
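# Illustrative sketch (not part of the pydantic source; the config values below are
# hypothetical): how an explicit alias in `fields` interacts with `alias_generator`
# inside get_field_info.
if __name__ == '__main__':  # pragma: no cover
    def to_upper(name: str) -> str:
        return name.upper()

    class MyConfig(BaseConfig):
        fields = {'user_id': {'alias': 'uid'}}
        alias_generator = to_upper

    # an explicit alias wins over the generated one (alias_priority 2 beats 1)
    assert MyConfig.get_field_info('user_id') == {'alias': 'uid', 'alias_priority': 2}
    # fields without an explicit entry fall back to the alias generator
    assert MyConfig.get_field_info('email') == {'alias': 'EMAIL', 'alias_priority': 1}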

lib/pydantic/dataclasses.py
@ -0,0 +1,479 @@
"""
The main purpose is to enhance stdlib dataclasses by adding validation
A pydantic dataclass can be generated from scratch or from a stdlib one.
Behind the scenes, a pydantic dataclass is just like a regular one, onto which we attach
a `BaseModel` and magic methods to trigger the validation of the data.
`__init__` and `__post_init__` are hence overridden and have extra logic to be
able to validate input data.
When a pydantic dataclass is generated from scratch, it's just a plain dataclass
with validation triggered at initialization.
The tricky part is stdlib dataclasses that are converted into pydantic ones afterwards, e.g.
```py
@dataclasses.dataclass
class M:
x: int
ValidatedM = pydantic.dataclasses.dataclass(M)
```
We indeed still want to support equality, hashing, repr, ... as if it was the stdlib one!
```py
assert isinstance(ValidatedM(x=1), M)
assert ValidatedM(x=1) == M(x=1)
```
This means we **don't want to create a new dataclass that inherits from it**.
The trick is to create a wrapper around `M` that will act as a proxy to trigger
validation without altering default `M` behaviour.
"""
import sys
from contextlib import contextmanager
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
Generator,
Optional,
Set,
Type,
TypeVar,
Union,
overload,
)
from typing_extensions import dataclass_transform
from .class_validators import gather_all_validators
from .config import BaseConfig, ConfigDict, Extra, get_config
from .error_wrappers import ValidationError
from .errors import DataclassTypeError
from .fields import Field, FieldInfo, Required, Undefined
from .main import create_model, validate_model
from .utils import ClassAttribute
if TYPE_CHECKING:
from .main import BaseModel
from .typing import CallableGenerator, NoArgAnyCallable
DataclassT = TypeVar('DataclassT', bound='Dataclass')
DataclassClassOrWrapper = Union[Type['Dataclass'], 'DataclassProxy']
class Dataclass:
# stdlib attributes
__dataclass_fields__: ClassVar[Dict[str, Any]]
__dataclass_params__: ClassVar[Any] # in reality `dataclasses._DataclassParams`
__post_init__: ClassVar[Callable[..., None]]
# Added by pydantic
__pydantic_run_validation__: ClassVar[bool]
__post_init_post_parse__: ClassVar[Callable[..., None]]
__pydantic_initialised__: ClassVar[bool]
__pydantic_model__: ClassVar[Type[BaseModel]]
__pydantic_validate_values__: ClassVar[Callable[['Dataclass'], None]]
__pydantic_has_field_info_default__: ClassVar[bool] # whether a `pydantic.Field` is used as default value
def __init__(self, *args: object, **kwargs: object) -> None:
pass
@classmethod
def __get_validators__(cls: Type['Dataclass']) -> 'CallableGenerator':
pass
@classmethod
def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
pass
__all__ = [
'dataclass',
'set_validation',
'create_pydantic_model_from_dataclass',
'is_builtin_dataclass',
'make_dataclass_validator',
]
_T = TypeVar('_T')
if sys.version_info >= (3, 10):
@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo))
@overload
def dataclass(
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
config: Union[ConfigDict, Type[object], None] = None,
validate_on_init: Optional[bool] = None,
kw_only: bool = ...,
) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
...
@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo))
@overload
def dataclass(
_cls: Type[_T],
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
config: Union[ConfigDict, Type[object], None] = None,
validate_on_init: Optional[bool] = None,
kw_only: bool = ...,
) -> 'DataclassClassOrWrapper':
...
else:
@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo))
@overload
def dataclass(
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
config: Union[ConfigDict, Type[object], None] = None,
validate_on_init: Optional[bool] = None,
) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']:
...
@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo))
@overload
def dataclass(
_cls: Type[_T],
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
config: Union[ConfigDict, Type[object], None] = None,
validate_on_init: Optional[bool] = None,
) -> 'DataclassClassOrWrapper':
...
@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo))
def dataclass(
_cls: Optional[Type[_T]] = None,
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
config: Union[ConfigDict, Type[object], None] = None,
validate_on_init: Optional[bool] = None,
kw_only: bool = False,
) -> Union[Callable[[Type[_T]], 'DataclassClassOrWrapper'], 'DataclassClassOrWrapper']:
"""
Like the Python standard library's dataclasses, but with type validation.
The result is either a pydantic dataclass that will validate input data,
or a wrapper that will trigger validation around a stdlib dataclass
to avoid modifying it directly. (A usage sketch appears at the end of this module.)
"""
the_config = get_config(config)
def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper':
import dataclasses
if is_builtin_dataclass(cls) and _extra_dc_args(_cls) == _extra_dc_args(_cls.__bases__[0]): # type: ignore
dc_cls_doc = ''
dc_cls = DataclassProxy(cls)
default_validate_on_init = False
else:
dc_cls_doc = cls.__doc__ or '' # needs to be done before generating dataclass
if sys.version_info >= (3, 10):
dc_cls = dataclasses.dataclass(
cls,
init=init,
repr=repr,
eq=eq,
order=order,
unsafe_hash=unsafe_hash,
frozen=frozen,
kw_only=kw_only,
)
else:
dc_cls = dataclasses.dataclass( # type: ignore
cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
)
default_validate_on_init = True
should_validate_on_init = default_validate_on_init if validate_on_init is None else validate_on_init
_add_pydantic_validation_attributes(cls, the_config, should_validate_on_init, dc_cls_doc)
dc_cls.__pydantic_model__.__try_update_forward_refs__(**{cls.__name__: cls})
return dc_cls
if _cls is None:
return wrap
return wrap(_cls)
@contextmanager
def set_validation(cls: Type['DataclassT'], value: bool) -> Generator[Type['DataclassT'], None, None]:
original_run_validation = cls.__pydantic_run_validation__
try:
cls.__pydantic_run_validation__ = value
yield cls
finally:
cls.__pydantic_run_validation__ = original_run_validation
class DataclassProxy:
__slots__ = '__dataclass__'
def __init__(self, dc_cls: Type['Dataclass']) -> None:
object.__setattr__(self, '__dataclass__', dc_cls)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
with set_validation(self.__dataclass__, True):
return self.__dataclass__(*args, **kwargs)
def __getattr__(self, name: str) -> Any:
return getattr(self.__dataclass__, name)
def __instancecheck__(self, instance: Any) -> bool:
return isinstance(instance, self.__dataclass__)
def _add_pydantic_validation_attributes( # noqa: C901 (ignore complexity)
dc_cls: Type['Dataclass'],
config: Type[BaseConfig],
validate_on_init: bool,
dc_cls_doc: str,
) -> None:
"""
We need to replace the right method. If no `__post_init__` has been set in the stdlib dataclass
it won't even exist (code is generated on the fly by `dataclasses`).
By default, we run validation after `__init__` or after `__post_init__` if one is defined.
"""
init = dc_cls.__init__
@wraps(init)
def handle_extra_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
if config.extra == Extra.ignore:
init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})
elif config.extra == Extra.allow:
for k, v in kwargs.items():
self.__dict__.setdefault(k, v)
init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})
else:
init(self, *args, **kwargs)
if hasattr(dc_cls, '__post_init__'):
post_init = dc_cls.__post_init__
@wraps(post_init)
def new_post_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
if config.post_init_call == 'before_validation':
post_init(self, *args, **kwargs)
if self.__class__.__pydantic_run_validation__:
self.__pydantic_validate_values__()
if hasattr(self, '__post_init_post_parse__'):
self.__post_init_post_parse__(*args, **kwargs)
if config.post_init_call == 'after_validation':
post_init(self, *args, **kwargs)
setattr(dc_cls, '__init__', handle_extra_init)
setattr(dc_cls, '__post_init__', new_post_init)
else:
@wraps(init)
def new_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
handle_extra_init(self, *args, **kwargs)
if self.__class__.__pydantic_run_validation__:
self.__pydantic_validate_values__()
if hasattr(self, '__post_init_post_parse__'):
# We need to find the initvars again. To do that we use `__dataclass_fields__` instead of
# the public method `dataclasses.fields`
import dataclasses
# get all initvars and their default values
initvars_and_values: Dict[str, Any] = {}
for i, f in enumerate(self.__class__.__dataclass_fields__.values()):
if f._field_type is dataclasses._FIELD_INITVAR: # type: ignore[attr-defined]
try:
# set arg value by default
initvars_and_values[f.name] = args[i]
except IndexError:
initvars_and_values[f.name] = kwargs.get(f.name, f.default)
self.__post_init_post_parse__(**initvars_and_values)
setattr(dc_cls, '__init__', new_init)
setattr(dc_cls, '__pydantic_run_validation__', ClassAttribute('__pydantic_run_validation__', validate_on_init))
setattr(dc_cls, '__pydantic_initialised__', False)
setattr(dc_cls, '__pydantic_model__', create_pydantic_model_from_dataclass(dc_cls, config, dc_cls_doc))
setattr(dc_cls, '__pydantic_validate_values__', _dataclass_validate_values)
setattr(dc_cls, '__validate__', classmethod(_validate_dataclass))
setattr(dc_cls, '__get_validators__', classmethod(_get_validators))
if dc_cls.__pydantic_model__.__config__.validate_assignment and not dc_cls.__dataclass_params__.frozen:
setattr(dc_cls, '__setattr__', _dataclass_validate_assignment_setattr)
def _get_validators(cls: 'DataclassClassOrWrapper') -> 'CallableGenerator':
yield cls.__validate__
def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
with set_validation(cls, True):
if isinstance(v, cls):
v.__pydantic_validate_values__()
return v
elif isinstance(v, (list, tuple)):
return cls(*v)
elif isinstance(v, dict):
return cls(**v)
else:
raise DataclassTypeError(class_name=cls.__name__)
def create_pydantic_model_from_dataclass(
dc_cls: Type['Dataclass'],
config: Type[Any] = BaseConfig,
dc_cls_doc: Optional[str] = None,
) -> Type['BaseModel']:
import dataclasses
field_definitions: Dict[str, Any] = {}
for field in dataclasses.fields(dc_cls):
default: Any = Undefined
default_factory: Optional['NoArgAnyCallable'] = None
field_info: FieldInfo
if field.default is not dataclasses.MISSING:
default = field.default
elif field.default_factory is not dataclasses.MISSING:
default_factory = field.default_factory
else:
default = Required
if isinstance(default, FieldInfo):
field_info = default
dc_cls.__pydantic_has_field_info_default__ = True
else:
field_info = Field(default=default, default_factory=default_factory, **field.metadata)
field_definitions[field.name] = (field.type, field_info)
validators = gather_all_validators(dc_cls)
model: Type['BaseModel'] = create_model(
dc_cls.__name__,
__config__=config,
__module__=dc_cls.__module__,
__validators__=validators,
__cls_kwargs__={'__resolve_forward_refs__': False},
**field_definitions,
)
model.__doc__ = dc_cls_doc if dc_cls_doc is not None else dc_cls.__doc__ or ''
return model
def _dataclass_validate_values(self: 'Dataclass') -> None:
# validation errors can occur if this function is called twice on an already initialised dataclass.
# for example if Extra.forbid is enabled, it would consider __pydantic_initialised__ an invalid extra property
if getattr(self, '__pydantic_initialised__'):
return
if getattr(self, '__pydantic_has_field_info_default__', False):
# We need to remove `FieldInfo` values since they are not valid as input
# It's ok to do that because they are obviously the default values!
input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)}
else:
input_data = self.__dict__
d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
if validation_error:
raise validation_error
self.__dict__.update(d)
object.__setattr__(self, '__pydantic_initialised__', True)
def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value: Any) -> None:
if self.__pydantic_initialised__:
d = dict(self.__dict__)
d.pop(name, None)
known_field = self.__pydantic_model__.__fields__.get(name, None)
if known_field:
value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__)
if error_:
raise ValidationError([error_], self.__class__)
object.__setattr__(self, name, value)
def _extra_dc_args(cls: Type[Any]) -> Set[str]:
return {
x
for x in dir(cls)
if x not in getattr(cls, '__dataclass_fields__', {}) and not (x.startswith('__') and x.endswith('__'))
}
def is_builtin_dataclass(_cls: Type[Any]) -> bool:
"""
Whether a class is a stdlib dataclass
(useful to distinguish a pydantic dataclass that is actually a wrapper around a stdlib dataclass).
We check that:
- `_cls` is a dataclass
- `_cls` is not a processed pydantic dataclass (with a basemodel attached)
- `_cls` is not a pydantic dataclass inheriting directly from a stdlib dataclass
e.g.
```
@dataclasses.dataclass
class A:
x: int
@pydantic.dataclasses.dataclass
class B(A):
y: int
```
In this case, when we first check `B`, we make an extra check and look at the annotations ('y'),
which won't be a superset of all the dataclass fields (only the stdlib fields i.e. 'x')
"""
import dataclasses
return (
dataclasses.is_dataclass(_cls)
and not hasattr(_cls, '__pydantic_model__')
and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {})))
)
def make_dataclass_validator(dc_cls: Type['Dataclass'], config: Type[BaseConfig]) -> 'CallableGenerator':
"""
Create a pydantic.dataclass from a builtin dataclass to add type validation
and yield the validators.
It retrieves the parameters of the dataclass and forwards them to the newly created dataclass.
"""
yield from _get_validators(dataclass(dc_cls, config=config, validate_on_init=False))
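# Illustrative sketch (not part of the pydantic source; the model below is hypothetical,
# assuming standard pydantic v1 behaviour; run with `python -m pydantic.dataclasses`):
# a pydantic dataclass validates and coerces its input at initialisation.
if __name__ == '__main__':  # pragma: no cover
    @dataclass
    class User:
        id: int
        name: str = 'John Doe'

    user = User(id='42')                   # '42' is coerced to int
    assert user.id == 42 and user.name == 'John Doe'
    try:
        User(id='not a number')
    except ValidationError as exc:
        print(exc)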

@ -0,0 +1,248 @@
"""
Functions to parse datetime objects.
We're using regular expressions rather than time.strptime because:
- They provide both validation and parsing.
- They're more flexible for datetimes.
- The date/datetime/time constructors produce friendlier error messages.
Stolen from https://raw.githubusercontent.com/django/django/main/django/utils/dateparse.py at
9718fa2e8abe430c3526a9278dd976443d4ae3c6
Changed to:
* use standard python datetime types not django.utils.timezone
* raise ValueError when regex doesn't match rather than returning None
* support parsing unix timestamps for dates and datetimes
"""
import re
from datetime import date, datetime, time, timedelta, timezone
from typing import Dict, Optional, Type, Union
from . import errors
date_expr = r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})'
time_expr = (
r'(?P<hour>\d{1,2}):(?P<minute>\d{1,2})'
r'(?::(?P<second>\d{1,2})(?:\.(?P<microsecond>\d{1,6})\d{0,6})?)?'
r'(?P<tzinfo>Z|[+-]\d{2}(?::?\d{2})?)?$'
)
date_re = re.compile(f'{date_expr}$')
time_re = re.compile(time_expr)
datetime_re = re.compile(f'{date_expr}[T ]{time_expr}')
standard_duration_re = re.compile(
r'^'
r'(?:(?P<days>-?\d+) (days?, )?)?'
r'((?:(?P<hours>-?\d+):)(?=\d+:\d+))?'
r'(?:(?P<minutes>-?\d+):)?'
r'(?P<seconds>-?\d+)'
r'(?:\.(?P<microseconds>\d{1,6})\d{0,6})?'
r'$'
)
# Support the sections of ISO 8601 date representation that are accepted by timedelta
iso8601_duration_re = re.compile(
r'^(?P<sign>[-+]?)'
r'P'
r'(?:(?P<days>\d+(.\d+)?)D)?'
r'(?:T'
r'(?:(?P<hours>\d+(.\d+)?)H)?'
r'(?:(?P<minutes>\d+(.\d+)?)M)?'
r'(?:(?P<seconds>\d+(.\d+)?)S)?'
r')?'
r'$'
)
EPOCH = datetime(1970, 1, 1)
# if greater than this, the number is in ms, if less than or equal it's in seconds
# (in seconds this is 11th October 2603, in ms it's 20th August 1970)
MS_WATERSHED = int(2e10)
# slightly more than datetime.max in ns - (datetime.max - EPOCH).total_seconds() * 1e9
MAX_NUMBER = int(3e20)
StrBytesIntFloat = Union[str, bytes, int, float]
def get_numeric(value: StrBytesIntFloat, native_expected_type: str) -> Union[None, int, float]:
if isinstance(value, (int, float)):
return value
try:
return float(value)
except ValueError:
return None
except TypeError:
raise TypeError(f'invalid type; expected {native_expected_type}, string, bytes, int or float')
def from_unix_seconds(seconds: Union[int, float]) -> datetime:
if seconds > MAX_NUMBER:
return datetime.max
elif seconds < -MAX_NUMBER:
return datetime.min
while abs(seconds) > MS_WATERSHED:
seconds /= 1000
dt = EPOCH + timedelta(seconds=seconds)
return dt.replace(tzinfo=timezone.utc)
def _parse_timezone(value: Optional[str], error: Type[Exception]) -> Union[None, int, timezone]:
if value == 'Z':
return timezone.utc
elif value is not None:
offset_mins = int(value[-2:]) if len(value) > 3 else 0
offset = 60 * int(value[1:3]) + offset_mins
if value[0] == '-':
offset = -offset
try:
return timezone(timedelta(minutes=offset))
except ValueError:
raise error()
else:
return None
def parse_date(value: Union[date, StrBytesIntFloat]) -> date:
"""
Parse a date/int/float/string and return a datetime.date.
Raise ValueError if the input is well formatted but not a valid date.
Raise ValueError if the input isn't well formatted.
"""
if isinstance(value, date):
if isinstance(value, datetime):
return value.date()
else:
return value
number = get_numeric(value, 'date')
if number is not None:
return from_unix_seconds(number).date()
if isinstance(value, bytes):
value = value.decode()
match = date_re.match(value) # type: ignore
if match is None:
raise errors.DateError()
kw = {k: int(v) for k, v in match.groupdict().items()}
try:
return date(**kw)
except ValueError:
raise errors.DateError()
def parse_time(value: Union[time, StrBytesIntFloat]) -> time:
"""
Parse a time/string and return a datetime.time.
Raise ValueError if the input is well formatted but not a valid time.
Raise ValueError if the input isn't well formatted, in particular if it contains an offset.
"""
if isinstance(value, time):
return value
number = get_numeric(value, 'time')
if number is not None:
if number >= 86400:
# doesn't make sense since the time would loop back around to 0
raise errors.TimeError()
return (datetime.min + timedelta(seconds=number)).time()
if isinstance(value, bytes):
value = value.decode()
match = time_re.match(value) # type: ignore
if match is None:
raise errors.TimeError()
kw = match.groupdict()
if kw['microsecond']:
kw['microsecond'] = kw['microsecond'].ljust(6, '0')
tzinfo = _parse_timezone(kw.pop('tzinfo'), errors.TimeError)
kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None}
kw_['tzinfo'] = tzinfo
try:
return time(**kw_) # type: ignore
except ValueError:
raise errors.TimeError()
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime:
"""
Parse a datetime/int/float/string and return a datetime.datetime.
This function supports time zone offsets. When the input contains one,
the output uses a timezone with a fixed offset from UTC.
Raise ValueError if the input is well formatted but not a valid datetime.
Raise ValueError if the input isn't well formatted.
"""
if isinstance(value, datetime):
return value
number = get_numeric(value, 'datetime')
if number is not None:
return from_unix_seconds(number)
if isinstance(value, bytes):
value = value.decode()
match = datetime_re.match(value) # type: ignore
if match is None:
raise errors.DateTimeError()
kw = match.groupdict()
if kw['microsecond']:
kw['microsecond'] = kw['microsecond'].ljust(6, '0')
tzinfo = _parse_timezone(kw.pop('tzinfo'), errors.DateTimeError)
kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None}
kw_['tzinfo'] = tzinfo
try:
return datetime(**kw_) # type: ignore
except ValueError:
raise errors.DateTimeError()
def parse_duration(value: StrBytesIntFloat) -> timedelta:
"""
Parse a duration int/float/string and return a datetime.timedelta.
The preferred format for durations in Django is '%d %H:%M:%S.%f'.
Also supports ISO 8601 representation.
"""
if isinstance(value, timedelta):
return value
if isinstance(value, (int, float)):
# below code requires a string
value = f'{value:f}'
elif isinstance(value, bytes):
value = value.decode()
try:
match = standard_duration_re.match(value) or iso8601_duration_re.match(value)
except TypeError:
raise TypeError('invalid type; expected timedelta, string, bytes, int or float')
if not match:
raise errors.DurationError()
kw = match.groupdict()
sign = -1 if kw.pop('sign', '+') == '-' else 1
if kw.get('microseconds'):
kw['microseconds'] = kw['microseconds'].ljust(6, '0')
if kw.get('seconds') and kw.get('microseconds') and kw['seconds'].startswith('-'):
kw['microseconds'] = '-' + kw['microseconds']
kw_ = {k: float(v) for k, v in kw.items() if v is not None}
return sign * timedelta(**kw_)
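# Illustrative sketch (not part of the pydantic source; example inputs chosen for
# demonstration only): the parsing helpers above in action.
if __name__ == '__main__':  # pragma: no cover
    assert parse_date('2022-11-12') == date(2022, 11, 12)
    expected = datetime(2022, 11, 12, 17, 53, 3, tzinfo=timezone.utc)
    assert parse_datetime('2022-11-12T17:53:03Z') == expected
    assert parse_datetime(0) == datetime(1970, 1, 1, tzinfo=timezone.utc)           # unix seconds
    assert parse_duration('3 10:30:00') == timedelta(days=3, hours=10, minutes=30)  # Django style
    assert parse_duration('P3DT10H30M') == timedelta(days=3, hours=10, minutes=30)  # ISO 8601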

lib/pydantic/decorator.py
@ -0,0 +1,264 @@
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload
from . import validator
from .config import Extra
from .errors import ConfigError
from .main import BaseModel, create_model
from .typing import get_all_type_hints
from .utils import to_camel
__all__ = ('validate_arguments',)
if TYPE_CHECKING:
from .typing import AnyCallable
AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable)
ConfigType = Union[None, Type[Any], Dict[str, Any]]
@overload
def validate_arguments(func: None = None, *, config: 'ConfigType' = None) -> Callable[['AnyCallableT'], 'AnyCallableT']:
...
@overload
def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT':
...
def validate_arguments(func: Optional['AnyCallableT'] = None, *, config: 'ConfigType' = None) -> Any:
"""
Decorator to validate the arguments passed to a function. (A usage sketch appears at the end of this module.)
"""
def validate(_func: 'AnyCallable') -> 'AnyCallable':
vd = ValidatedFunction(_func, config)
@wraps(_func)
def wrapper_function(*args: Any, **kwargs: Any) -> Any:
return vd.call(*args, **kwargs)
wrapper_function.vd = vd # type: ignore
wrapper_function.validate = vd.init_model_instance # type: ignore
wrapper_function.raw_function = vd.raw_function # type: ignore
wrapper_function.model = vd.model # type: ignore
return wrapper_function
if func:
return validate(func)
else:
return validate
ALT_V_ARGS = 'v__args'
ALT_V_KWARGS = 'v__kwargs'
V_POSITIONAL_ONLY_NAME = 'v__positional_only'
V_DUPLICATE_KWARGS = 'v__duplicate_kwargs'
class ValidatedFunction:
def __init__(self, function: 'AnyCallableT', config: 'ConfigType'): # noqa C901
from inspect import Parameter, signature
parameters: Mapping[str, Parameter] = signature(function).parameters
if parameters.keys() & {ALT_V_ARGS, ALT_V_KWARGS, V_POSITIONAL_ONLY_NAME, V_DUPLICATE_KWARGS}:
raise ConfigError(
f'"{ALT_V_ARGS}", "{ALT_V_KWARGS}", "{V_POSITIONAL_ONLY_NAME}" and "{V_DUPLICATE_KWARGS}" '
f'are not permitted as argument names when using the "{validate_arguments.__name__}" decorator'
)
self.raw_function = function
self.arg_mapping: Dict[int, str] = {}
self.positional_only_args = set()
self.v_args_name = 'args'
self.v_kwargs_name = 'kwargs'
type_hints = get_all_type_hints(function)
takes_args = False
takes_kwargs = False
fields: Dict[str, Tuple[Any, Any]] = {}
for i, (name, p) in enumerate(parameters.items()):
if p.annotation is p.empty:
annotation = Any
else:
annotation = type_hints[name]
default = ... if p.default is p.empty else p.default
if p.kind == Parameter.POSITIONAL_ONLY:
self.arg_mapping[i] = name
fields[name] = annotation, default
fields[V_POSITIONAL_ONLY_NAME] = List[str], None
self.positional_only_args.add(name)
elif p.kind == Parameter.POSITIONAL_OR_KEYWORD:
self.arg_mapping[i] = name
fields[name] = annotation, default
fields[V_DUPLICATE_KWARGS] = List[str], None
elif p.kind == Parameter.KEYWORD_ONLY:
fields[name] = annotation, default
elif p.kind == Parameter.VAR_POSITIONAL:
self.v_args_name = name
fields[name] = Tuple[annotation, ...], None
takes_args = True
else:
assert p.kind == Parameter.VAR_KEYWORD, p.kind
self.v_kwargs_name = name
fields[name] = Dict[str, annotation], None # type: ignore
takes_kwargs = True
# these checks avoid a clash between "args" and a field with that name
if not takes_args and self.v_args_name in fields:
self.v_args_name = ALT_V_ARGS
# same with "kwargs"
if not takes_kwargs and self.v_kwargs_name in fields:
self.v_kwargs_name = ALT_V_KWARGS
if not takes_args:
# we add the field so validation below can raise the correct exception
fields[self.v_args_name] = List[Any], None
if not takes_kwargs:
# same with kwargs
fields[self.v_kwargs_name] = Dict[Any, Any], None
self.create_model(fields, takes_args, takes_kwargs, config)
def init_model_instance(self, *args: Any, **kwargs: Any) -> BaseModel:
values = self.build_values(args, kwargs)
return self.model(**values)
def call(self, *args: Any, **kwargs: Any) -> Any:
m = self.init_model_instance(*args, **kwargs)
return self.execute(m)
def build_values(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
values: Dict[str, Any] = {}
if args:
arg_iter = enumerate(args)
while True:
try:
i, a = next(arg_iter)
except StopIteration:
break
arg_name = self.arg_mapping.get(i)
if arg_name is not None:
values[arg_name] = a
else:
values[self.v_args_name] = [a] + [a for _, a in arg_iter]
break
var_kwargs: Dict[str, Any] = {}
wrong_positional_args = []
duplicate_kwargs = []
fields_alias = [
field.alias
for name, field in self.model.__fields__.items()
if name not in (self.v_args_name, self.v_kwargs_name)
]
non_var_fields = set(self.model.__fields__) - {self.v_args_name, self.v_kwargs_name}
for k, v in kwargs.items():
if k in non_var_fields or k in fields_alias:
if k in self.positional_only_args:
wrong_positional_args.append(k)
if k in values:
duplicate_kwargs.append(k)
values[k] = v
else:
var_kwargs[k] = v
if var_kwargs:
values[self.v_kwargs_name] = var_kwargs
if wrong_positional_args:
values[V_POSITIONAL_ONLY_NAME] = wrong_positional_args
if duplicate_kwargs:
values[V_DUPLICATE_KWARGS] = duplicate_kwargs
return values
def execute(self, m: BaseModel) -> Any:
d = {k: v for k, v in m._iter() if k in m.__fields_set__ or m.__fields__[k].default_factory}
var_kwargs = d.pop(self.v_kwargs_name, {})
if self.v_args_name in d:
args_: List[Any] = []
in_kwargs = False
kwargs = {}
for name, value in d.items():
if in_kwargs:
kwargs[name] = value
elif name == self.v_args_name:
args_ += value
in_kwargs = True
else:
args_.append(value)
return self.raw_function(*args_, **kwargs, **var_kwargs)
elif self.positional_only_args:
args_ = []
kwargs = {}
for name, value in d.items():
if name in self.positional_only_args:
args_.append(value)
else:
kwargs[name] = value
return self.raw_function(*args_, **kwargs, **var_kwargs)
else:
return self.raw_function(**d, **var_kwargs)
def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None:
pos_args = len(self.arg_mapping)
class CustomConfig:
pass
if not TYPE_CHECKING: # pragma: no branch
if isinstance(config, dict):
CustomConfig = type('Config', (), config) # noqa: F811
elif config is not None:
CustomConfig = config # noqa: F811
if hasattr(CustomConfig, 'fields') or hasattr(CustomConfig, 'alias_generator'):
raise ConfigError(
'Setting the "fields" and "alias_generator" property on custom Config for '
'@validate_arguments is not yet supported, please remove.'
)
class DecoratorBaseModel(BaseModel):
@validator(self.v_args_name, check_fields=False, allow_reuse=True)
def check_args(cls, v: Optional[List[Any]]) -> Optional[List[Any]]:
if takes_args or v is None:
return v
raise TypeError(f'{pos_args} positional arguments expected but {pos_args + len(v)} given')
@validator(self.v_kwargs_name, check_fields=False, allow_reuse=True)
def check_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
if takes_kwargs or v is None:
return v
plural = '' if len(v) == 1 else 's'
keys = ', '.join(map(repr, v.keys()))
raise TypeError(f'unexpected keyword argument{plural}: {keys}')
@validator(V_POSITIONAL_ONLY_NAME, check_fields=False, allow_reuse=True)
def check_positional_only(cls, v: Optional[List[str]]) -> None:
if v is None:
return
plural = '' if len(v) == 1 else 's'
keys = ', '.join(map(repr, v))
raise TypeError(f'positional-only argument{plural} passed as keyword argument{plural}: {keys}')
@validator(V_DUPLICATE_KWARGS, check_fields=False, allow_reuse=True)
def check_duplicate_kwargs(cls, v: Optional[List[str]]) -> None:
if v is None:
return
plural = '' if len(v) == 1 else 's'
keys = ', '.join(map(repr, v))
raise TypeError(f'multiple values for argument{plural}: {keys}')
class Config(CustomConfig):
extra = getattr(CustomConfig, 'extra', Extra.forbid)
self.model = create_model(to_camel(self.raw_function.__name__), __base__=DecoratorBaseModel, **fields)
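# Illustrative sketch (not part of the pydantic source; the decorated function is
# hypothetical, assuming standard pydantic v1 behaviour; run with
# `python -m pydantic.decorator`): arguments are validated and coerced before the call.
if __name__ == '__main__':  # pragma: no cover
    from pydantic import ValidationError

    @validate_arguments
    def repeat(text: str, count: int) -> str:
        return text * count

    assert repeat('ab', '3') == 'ababab'   # '3' is coerced to int
    try:
        repeat('ab', count='not a number')
    except ValidationError as exc:
        print(exc)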

@ -0,0 +1,346 @@
import os
import warnings
from pathlib import Path
from typing import AbstractSet, Any, Callable, ClassVar, Dict, List, Mapping, Optional, Tuple, Type, Union
from .config import BaseConfig, Extra
from .fields import ModelField
from .main import BaseModel
from .typing import StrPath, display_as_type, get_origin, is_union
from .utils import deep_update, path_type, sequence_like
env_file_sentinel = str(object())
SettingsSourceCallable = Callable[['BaseSettings'], Dict[str, Any]]
DotenvType = Union[StrPath, List[StrPath], Tuple[StrPath, ...]]
class SettingsError(ValueError):
pass
class BaseSettings(BaseModel):
"""
Base class for settings, allowing values to be overridden by environment variables.
This is useful in production for secrets you do not wish to save in code; it plays nicely with docker(-compose),
Heroku and any 12-factor app design. (A short usage sketch follows this class.)
"""
def __init__(
__pydantic_self__,
_env_file: Optional[DotenvType] = env_file_sentinel,
_env_file_encoding: Optional[str] = None,
_env_nested_delimiter: Optional[str] = None,
_secrets_dir: Optional[StrPath] = None,
**values: Any,
) -> None:
# Uses something other than `self` as the first arg to allow "self" as a settable attribute
super().__init__(
**__pydantic_self__._build_values(
values,
_env_file=_env_file,
_env_file_encoding=_env_file_encoding,
_env_nested_delimiter=_env_nested_delimiter,
_secrets_dir=_secrets_dir,
)
)
def _build_values(
self,
init_kwargs: Dict[str, Any],
_env_file: Optional[DotenvType] = None,
_env_file_encoding: Optional[str] = None,
_env_nested_delimiter: Optional[str] = None,
_secrets_dir: Optional[StrPath] = None,
) -> Dict[str, Any]:
# Configure built-in sources
init_settings = InitSettingsSource(init_kwargs=init_kwargs)
env_settings = EnvSettingsSource(
env_file=(_env_file if _env_file != env_file_sentinel else self.__config__.env_file),
env_file_encoding=(
_env_file_encoding if _env_file_encoding is not None else self.__config__.env_file_encoding
),
env_nested_delimiter=(
_env_nested_delimiter if _env_nested_delimiter is not None else self.__config__.env_nested_delimiter
),
env_prefix_len=len(self.__config__.env_prefix),
)
file_secret_settings = SecretsSettingsSource(secrets_dir=_secrets_dir or self.__config__.secrets_dir)
# Provide a hook to set built-in sources priority and add / remove sources
sources = self.__config__.customise_sources(
init_settings=init_settings, env_settings=env_settings, file_secret_settings=file_secret_settings
)
if sources:
return deep_update(*reversed([source(self) for source in sources]))
else:
# no one should mean to do this, but I think returning an empty dict is marginally preferable
# to an informative error and much better than a confusing error
return {}
class Config(BaseConfig):
env_prefix: str = ''
env_file: Optional[DotenvType] = None
env_file_encoding: Optional[str] = None
env_nested_delimiter: Optional[str] = None
secrets_dir: Optional[StrPath] = None
validate_all: bool = True
extra: Extra = Extra.forbid
arbitrary_types_allowed: bool = True
case_sensitive: bool = False
@classmethod
def prepare_field(cls, field: ModelField) -> None:
env_names: Union[List[str], AbstractSet[str]]
field_info_from_config = cls.get_field_info(field.name)
env = field_info_from_config.get('env') or field.field_info.extra.get('env')
if env is None:
if field.has_alias:
warnings.warn(
'aliases are no longer used by BaseSettings to define which environment variables to read. '
'Instead use the "env" field setting. '
'See https://pydantic-docs.helpmanual.io/usage/settings/#environment-variable-names',
FutureWarning,
)
env_names = {cls.env_prefix + field.name}
elif isinstance(env, str):
env_names = {env}
elif isinstance(env, (set, frozenset)):
env_names = env
elif sequence_like(env):
env_names = list(env)
else:
raise TypeError(f'invalid field env: {env!r} ({display_as_type(env)}); should be string, list or set')
if not cls.case_sensitive:
env_names = env_names.__class__(n.lower() for n in env_names)
field.field_info.extra['env_names'] = env_names
@classmethod
def customise_sources(
cls,
init_settings: SettingsSourceCallable,
env_settings: SettingsSourceCallable,
file_secret_settings: SettingsSourceCallable,
) -> Tuple[SettingsSourceCallable, ...]:
return init_settings, env_settings, file_secret_settings
@classmethod
def parse_env_var(cls, field_name: str, raw_val: str) -> Any:
return cls.json_loads(raw_val)
# populated by the metaclass using the Config class defined above, annotated here to help IDEs only
__config__: ClassVar[Type[Config]]
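# Illustrative sketch (not part of the pydantic source) of the env-var override described
# in the class docstring; names are hypothetical and standard pydantic v1 behaviour is assumed:
#
#     class AppSettings(BaseSettings):
#         debug: bool = False
#         api_key: str = ''
#
#         class Config:
#             env_prefix = 'MYAPP_'
#
#     # with MYAPP_DEBUG=true and MYAPP_API_KEY=secret in the environment:
#     # AppSettings() == AppSettings(debug=True, api_key='secret')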
class InitSettingsSource:
__slots__ = ('init_kwargs',)
def __init__(self, init_kwargs: Dict[str, Any]):
self.init_kwargs = init_kwargs
def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
return self.init_kwargs
def __repr__(self) -> str:
return f'InitSettingsSource(init_kwargs={self.init_kwargs!r})'
class EnvSettingsSource:
__slots__ = ('env_file', 'env_file_encoding', 'env_nested_delimiter', 'env_prefix_len')
def __init__(
self,
env_file: Optional[DotenvType],
env_file_encoding: Optional[str],
env_nested_delimiter: Optional[str] = None,
env_prefix_len: int = 0,
):
self.env_file: Optional[DotenvType] = env_file
self.env_file_encoding: Optional[str] = env_file_encoding
self.env_nested_delimiter: Optional[str] = env_nested_delimiter
self.env_prefix_len: int = env_prefix_len
def __call__(self, settings: BaseSettings) -> Dict[str, Any]: # noqa C901
"""
Build environment variables suitable for passing to the Model.
"""
d: Dict[str, Any] = {}
if settings.__config__.case_sensitive:
env_vars: Mapping[str, Optional[str]] = os.environ
else:
env_vars = {k.lower(): v for k, v in os.environ.items()}
dotenv_vars = self._read_env_files(settings.__config__.case_sensitive)
if dotenv_vars:
env_vars = {**dotenv_vars, **env_vars}
for field in settings.__fields__.values():
env_val: Optional[str] = None
for env_name in field.field_info.extra['env_names']:
env_val = env_vars.get(env_name)
if env_val is not None:
break
is_complex, allow_parse_failure = self.field_is_complex(field)
if is_complex:
if env_val is None:
# field is complex but no value found so far, try explode_env_vars
env_val_built = self.explode_env_vars(field, env_vars)
if env_val_built:
d[field.alias] = env_val_built
else:
# field is complex and there's a value, decode that as JSON, then add explode_env_vars
try:
env_val = settings.__config__.parse_env_var(field.name, env_val)
except ValueError as e:
if not allow_parse_failure:
raise SettingsError(f'error parsing env var "{env_name}"') from e
if isinstance(env_val, dict):
d[field.alias] = deep_update(env_val, self.explode_env_vars(field, env_vars))
else:
d[field.alias] = env_val
elif env_val is not None:
# simplest case, field is not complex, we only need to add the value if it was found
d[field.alias] = env_val
return d
def _read_env_files(self, case_sensitive: bool) -> Dict[str, Optional[str]]:
env_files = self.env_file
if env_files is None:
return {}
if isinstance(env_files, (str, os.PathLike)):
env_files = [env_files]
dotenv_vars = {}
for env_file in env_files:
env_path = Path(env_file).expanduser()
if env_path.is_file():
dotenv_vars.update(
read_env_file(env_path, encoding=self.env_file_encoding, case_sensitive=case_sensitive)
)
return dotenv_vars
def field_is_complex(self, field: ModelField) -> Tuple[bool, bool]:
"""
Find out if a field is complex, and if so whether JSON errors should be ignored
"""
if field.is_complex():
allow_parse_failure = False
elif is_union(get_origin(field.type_)) and field.sub_fields and any(f.is_complex() for f in field.sub_fields):
allow_parse_failure = True
else:
return False, False
return True, allow_parse_failure
def explode_env_vars(self, field: ModelField, env_vars: Mapping[str, Optional[str]]) -> Dict[str, Any]:
"""
Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.
This is applied to a single field, hence filtering by env_var prefix.
"""
prefixes = [f'{env_name}{self.env_nested_delimiter}' for env_name in field.field_info.extra['env_names']]
result: Dict[str, Any] = {}
for env_name, env_val in env_vars.items():
if not any(env_name.startswith(prefix) for prefix in prefixes):
continue
# we remove the prefix before splitting in case the prefix has characters in common with the delimiter
env_name_without_prefix = env_name[self.env_prefix_len :]
_, *keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter)
env_var = result
for key in keys:
env_var = env_var.setdefault(key, {})
env_var[last_key] = env_val
return result
def __repr__(self) -> str:
return (
f'EnvSettingsSource(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, '
f'env_nested_delimiter={self.env_nested_delimiter!r})'
)
class SecretsSettingsSource:
__slots__ = ('secrets_dir',)
def __init__(self, secrets_dir: Optional[StrPath]):
self.secrets_dir: Optional[StrPath] = secrets_dir
def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
"""
Build fields from "secrets" files.
"""
secrets: Dict[str, Optional[str]] = {}
if self.secrets_dir is None:
return secrets
secrets_path = Path(self.secrets_dir).expanduser()
if not secrets_path.exists():
warnings.warn(f'directory "{secrets_path}" does not exist')
return secrets
if not secrets_path.is_dir():
raise SettingsError(f'secrets_dir must reference a directory, not a {path_type(secrets_path)}')
for field in settings.__fields__.values():
for env_name in field.field_info.extra['env_names']:
path = find_case_path(secrets_path, env_name, settings.__config__.case_sensitive)
if not path:
                    # path does not exist, we currently don't return a warning for this
continue
if path.is_file():
secret_value = path.read_text().strip()
if field.is_complex():
try:
secret_value = settings.__config__.parse_env_var(field.name, secret_value)
except ValueError as e:
raise SettingsError(f'error parsing env var "{env_name}"') from e
secrets[field.alias] = secret_value
else:
warnings.warn(
f'attempted to load secret file "{path}" but found a {path_type(path)} instead.',
stacklevel=4,
)
return secrets
def __repr__(self) -> str:
return f'SecretsSettingsSource(secrets_dir={self.secrets_dir!r})'
def read_env_file(
    file_path: StrPath, *, encoding: Optional[str] = None, case_sensitive: bool = False
) -> Dict[str, Optional[str]]:
try:
from dotenv import dotenv_values
except ImportError as e:
raise ImportError('python-dotenv is not installed, run `pip install pydantic[dotenv]`') from e
file_vars: Dict[str, Optional[str]] = dotenv_values(file_path, encoding=encoding or 'utf8')
if not case_sensitive:
return {k.lower(): v for k, v in file_vars.items()}
else:
return file_vars
def find_case_path(dir_path: Path, file_name: str, case_sensitive: bool) -> Optional[Path]:
"""
Find a file within path's directory matching filename, optionally ignoring case.
"""
for f in dir_path.iterdir():
if f.name == file_name:
return f
elif not case_sensitive and f.name.lower() == file_name.lower():
return f
return None
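The sources above are easiest to see working together. The following is a minimal, hypothetical sketch (the Settings/Redis models and variable names are illustrative, not part of the vendored code): EnvSettingsSource reads the flat DEBUG variable directly, while explode_env_vars assembles the nested redis value from variables containing the env_nested_delimiter.

import os
from pydantic import BaseModel, BaseSettings

class Redis(BaseModel):
    host: str = 'localhost'
    port: int = 6379

class Settings(BaseSettings):
    debug: bool = False
    redis: Redis = Redis()

    class Config:
        env_nested_delimiter = '__'  # REDIS__HOST -> {'redis': {'host': ...}}

os.environ['DEBUG'] = 'true'
os.environ['REDIS__HOST'] = 'cache.internal'
os.environ['REDIS__PORT'] = '6380'

settings = Settings()
assert settings.debug is True
assert settings.redis.host == 'cache.internal' and settings.redis.port == 6380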


@ -0,0 +1,162 @@
import json
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple, Type, Union
from .json import pydantic_encoder
from .utils import Representation
if TYPE_CHECKING:
from typing_extensions import TypedDict
from .config import BaseConfig
from .types import ModelOrDc
from .typing import ReprArgs
Loc = Tuple[Union[int, str], ...]
class _ErrorDictRequired(TypedDict):
loc: Loc
msg: str
type: str
class ErrorDict(_ErrorDictRequired, total=False):
ctx: Dict[str, Any]
__all__ = 'ErrorWrapper', 'ValidationError'
class ErrorWrapper(Representation):
__slots__ = 'exc', '_loc'
def __init__(self, exc: Exception, loc: Union[str, 'Loc']) -> None:
self.exc = exc
self._loc = loc
def loc_tuple(self) -> 'Loc':
if isinstance(self._loc, tuple):
return self._loc
else:
return (self._loc,)
def __repr_args__(self) -> 'ReprArgs':
return [('exc', self.exc), ('loc', self.loc_tuple())]
# ErrorList is something like Union[List[Union[List[ErrorWrapper], ErrorWrapper]], ErrorWrapper]
# but recursive, therefore just use:
ErrorList = Union[Sequence[Any], ErrorWrapper]
class ValidationError(Representation, ValueError):
__slots__ = 'raw_errors', 'model', '_error_cache'
def __init__(self, errors: Sequence[ErrorList], model: 'ModelOrDc') -> None:
self.raw_errors = errors
self.model = model
self._error_cache: Optional[List['ErrorDict']] = None
def errors(self) -> List['ErrorDict']:
if self._error_cache is None:
try:
config = self.model.__config__ # type: ignore
except AttributeError:
config = self.model.__pydantic_model__.__config__ # type: ignore
self._error_cache = list(flatten_errors(self.raw_errors, config))
return self._error_cache
def json(self, *, indent: Union[None, int, str] = 2) -> str:
return json.dumps(self.errors(), indent=indent, default=pydantic_encoder)
def __str__(self) -> str:
errors = self.errors()
no_errors = len(errors)
return (
f'{no_errors} validation error{"" if no_errors == 1 else "s"} for {self.model.__name__}\n'
f'{display_errors(errors)}'
)
def __repr_args__(self) -> 'ReprArgs':
return [('model', self.model.__name__), ('errors', self.errors())]
def display_errors(errors: List['ErrorDict']) -> str:
return '\n'.join(f'{_display_error_loc(e)}\n {e["msg"]} ({_display_error_type_and_ctx(e)})' for e in errors)
def _display_error_loc(error: 'ErrorDict') -> str:
return ' -> '.join(str(e) for e in error['loc'])
def _display_error_type_and_ctx(error: 'ErrorDict') -> str:
t = 'type=' + error['type']
ctx = error.get('ctx')
if ctx:
return t + ''.join(f'; {k}={v}' for k, v in ctx.items())
else:
return t
def flatten_errors(
errors: Sequence[Any], config: Type['BaseConfig'], loc: Optional['Loc'] = None
) -> Generator['ErrorDict', None, None]:
for error in errors:
if isinstance(error, ErrorWrapper):
if loc:
error_loc = loc + error.loc_tuple()
else:
error_loc = error.loc_tuple()
if isinstance(error.exc, ValidationError):
yield from flatten_errors(error.exc.raw_errors, config, error_loc)
else:
yield error_dict(error.exc, config, error_loc)
elif isinstance(error, list):
yield from flatten_errors(error, config, loc=loc)
else:
raise RuntimeError(f'Unknown error object: {error}')
def error_dict(exc: Exception, config: Type['BaseConfig'], loc: 'Loc') -> 'ErrorDict':
type_ = get_exc_type(exc.__class__)
msg_template = config.error_msg_templates.get(type_) or getattr(exc, 'msg_template', None)
ctx = exc.__dict__
if msg_template:
msg = msg_template.format(**ctx)
else:
msg = str(exc)
d: 'ErrorDict' = {'loc': loc, 'msg': msg, 'type': type_}
if ctx:
d['ctx'] = ctx
return d
_EXC_TYPE_CACHE: Dict[Type[Exception], str] = {}
def get_exc_type(cls: Type[Exception]) -> str:
# slightly more efficient than using lru_cache since we don't need to worry about the cache filling up
try:
return _EXC_TYPE_CACHE[cls]
except KeyError:
r = _get_exc_type(cls)
_EXC_TYPE_CACHE[cls] = r
return r
def _get_exc_type(cls: Type[Exception]) -> str:
if issubclass(cls, AssertionError):
return 'assertion_error'
base_name = 'type_error' if issubclass(cls, TypeError) else 'value_error'
if cls in (TypeError, ValueError):
# just TypeError or ValueError, no extra code
return base_name
# if it's not a TypeError or ValueError, we just take the lowercase of the exception name
# no chaining or snake case logic, use "code" for more complex error types.
code = getattr(cls, 'code', None) or cls.__name__.replace('Error', '').lower()
return base_name + '.' + code
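To illustrate how the pieces above fit together (a hedged sketch; the User model is hypothetical, not part of the diff): flatten_errors turns nested ErrorWrappers into flat dicts, and get_exc_type derives the 'type' string from the exception class.

from pydantic import BaseModel, ValidationError

class User(BaseModel):
    id: int
    name: str

try:
    User(id='not-a-number')
except ValidationError as exc:
    errors = exc.errors()
    # [{'loc': ('id',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'},
    #  {'loc': ('name',), 'msg': 'field required', 'type': 'value_error.missing'}]
    assert errors[0]['type'] == 'type_error.integer'   # IntegerError -> 'type_error' + '.integer'
    assert errors[1]['type'] == 'value_error.missing'  # MissingError -> 'value_error' + '.missing'
    print(exc)  # "2 validation errors for User" followed by display_errors() output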

lib/pydantic/errors.py Normal file

@ -0,0 +1,646 @@
from decimal import Decimal
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Sequence, Set, Tuple, Type, Union
from .typing import display_as_type
if TYPE_CHECKING:
from .typing import DictStrAny
# explicitly state exports to avoid "from .errors import *" also importing Decimal, Path etc.
__all__ = (
'PydanticTypeError',
'PydanticValueError',
'ConfigError',
'MissingError',
'ExtraError',
'NoneIsNotAllowedError',
'NoneIsAllowedError',
'WrongConstantError',
'NotNoneError',
'BoolError',
'BytesError',
'DictError',
'EmailError',
'UrlError',
'UrlSchemeError',
'UrlSchemePermittedError',
'UrlUserInfoError',
'UrlHostError',
'UrlHostTldError',
'UrlPortError',
'UrlExtraError',
'EnumError',
'IntEnumError',
'EnumMemberError',
'IntegerError',
'FloatError',
'PathError',
'PathNotExistsError',
'PathNotAFileError',
'PathNotADirectoryError',
'PyObjectError',
'SequenceError',
'ListError',
'SetError',
'FrozenSetError',
'TupleError',
'TupleLengthError',
'ListMinLengthError',
'ListMaxLengthError',
'ListUniqueItemsError',
'SetMinLengthError',
'SetMaxLengthError',
'FrozenSetMinLengthError',
'FrozenSetMaxLengthError',
'AnyStrMinLengthError',
'AnyStrMaxLengthError',
'StrError',
'StrRegexError',
'NumberNotGtError',
'NumberNotGeError',
'NumberNotLtError',
'NumberNotLeError',
'NumberNotMultipleError',
'DecimalError',
'DecimalIsNotFiniteError',
'DecimalMaxDigitsError',
'DecimalMaxPlacesError',
'DecimalWholeDigitsError',
'DateTimeError',
'DateError',
'DateNotInThePastError',
'DateNotInTheFutureError',
'TimeError',
'DurationError',
'HashableError',
'UUIDError',
'UUIDVersionError',
'ArbitraryTypeError',
'ClassError',
'SubclassError',
'JsonError',
'JsonTypeError',
'PatternError',
'DataclassTypeError',
'CallableError',
'IPvAnyAddressError',
'IPvAnyInterfaceError',
'IPvAnyNetworkError',
'IPv4AddressError',
'IPv6AddressError',
'IPv4NetworkError',
'IPv6NetworkError',
'IPv4InterfaceError',
'IPv6InterfaceError',
'ColorError',
'StrictBoolError',
'NotDigitError',
'LuhnValidationError',
'InvalidLengthForBrand',
'InvalidByteSize',
'InvalidByteSizeUnit',
'MissingDiscriminator',
'InvalidDiscriminator',
)
def cls_kwargs(cls: Type['PydanticErrorMixin'], ctx: 'DictStrAny') -> 'PydanticErrorMixin':
"""
For built-in exceptions like ValueError or TypeError, we need to implement
__reduce__ to override the default behaviour (instead of __getstate__/__setstate__)
By default pickle protocol 2 calls `cls.__new__(cls, *args)`.
Since we only use kwargs, we need a little constructor to change that.
Note: the callable can't be a lambda as pickle looks in the namespace to find it
"""
return cls(**ctx)
class PydanticErrorMixin:
code: str
msg_template: str
def __init__(self, **ctx: Any) -> None:
self.__dict__ = ctx
def __str__(self) -> str:
return self.msg_template.format(**self.__dict__)
def __reduce__(self) -> Tuple[Callable[..., 'PydanticErrorMixin'], Tuple[Type['PydanticErrorMixin'], 'DictStrAny']]:
return cls_kwargs, (self.__class__, self.__dict__)
class PydanticTypeError(PydanticErrorMixin, TypeError):
pass
class PydanticValueError(PydanticErrorMixin, ValueError):
pass
class ConfigError(RuntimeError):
pass
class MissingError(PydanticValueError):
msg_template = 'field required'
class ExtraError(PydanticValueError):
msg_template = 'extra fields not permitted'
class NoneIsNotAllowedError(PydanticTypeError):
code = 'none.not_allowed'
msg_template = 'none is not an allowed value'
class NoneIsAllowedError(PydanticTypeError):
code = 'none.allowed'
msg_template = 'value is not none'
class WrongConstantError(PydanticValueError):
code = 'const'
def __str__(self) -> str:
permitted = ', '.join(repr(v) for v in self.permitted) # type: ignore
return f'unexpected value; permitted: {permitted}'
class NotNoneError(PydanticTypeError):
code = 'not_none'
msg_template = 'value is not None'
class BoolError(PydanticTypeError):
msg_template = 'value could not be parsed to a boolean'
class BytesError(PydanticTypeError):
msg_template = 'byte type expected'
class DictError(PydanticTypeError):
msg_template = 'value is not a valid dict'
class EmailError(PydanticValueError):
msg_template = 'value is not a valid email address'
class UrlError(PydanticValueError):
code = 'url'
class UrlSchemeError(UrlError):
code = 'url.scheme'
msg_template = 'invalid or missing URL scheme'
class UrlSchemePermittedError(UrlError):
code = 'url.scheme'
msg_template = 'URL scheme not permitted'
def __init__(self, allowed_schemes: Set[str]):
super().__init__(allowed_schemes=allowed_schemes)
class UrlUserInfoError(UrlError):
code = 'url.userinfo'
msg_template = 'userinfo required in URL but missing'
class UrlHostError(UrlError):
code = 'url.host'
msg_template = 'URL host invalid'
class UrlHostTldError(UrlError):
code = 'url.host'
msg_template = 'URL host invalid, top level domain required'
class UrlPortError(UrlError):
code = 'url.port'
msg_template = 'URL port invalid, port cannot exceed 65535'
class UrlExtraError(UrlError):
code = 'url.extra'
msg_template = 'URL invalid, extra characters found after valid URL: {extra!r}'
class EnumMemberError(PydanticTypeError):
code = 'enum'
def __str__(self) -> str:
permitted = ', '.join(repr(v.value) for v in self.enum_values) # type: ignore
return f'value is not a valid enumeration member; permitted: {permitted}'
class IntegerError(PydanticTypeError):
msg_template = 'value is not a valid integer'
class FloatError(PydanticTypeError):
msg_template = 'value is not a valid float'
class PathError(PydanticTypeError):
msg_template = 'value is not a valid path'
class _PathValueError(PydanticValueError):
def __init__(self, *, path: Path) -> None:
super().__init__(path=str(path))
class PathNotExistsError(_PathValueError):
code = 'path.not_exists'
msg_template = 'file or directory at path "{path}" does not exist'
class PathNotAFileError(_PathValueError):
code = 'path.not_a_file'
msg_template = 'path "{path}" does not point to a file'
class PathNotADirectoryError(_PathValueError):
code = 'path.not_a_directory'
msg_template = 'path "{path}" does not point to a directory'
class PyObjectError(PydanticTypeError):
msg_template = 'ensure this value contains valid import path or valid callable: {error_message}'
class SequenceError(PydanticTypeError):
msg_template = 'value is not a valid sequence'
class IterableError(PydanticTypeError):
msg_template = 'value is not a valid iterable'
class ListError(PydanticTypeError):
msg_template = 'value is not a valid list'
class SetError(PydanticTypeError):
msg_template = 'value is not a valid set'
class FrozenSetError(PydanticTypeError):
msg_template = 'value is not a valid frozenset'
class DequeError(PydanticTypeError):
msg_template = 'value is not a valid deque'
class TupleError(PydanticTypeError):
msg_template = 'value is not a valid tuple'
class TupleLengthError(PydanticValueError):
code = 'tuple.length'
msg_template = 'wrong tuple length {actual_length}, expected {expected_length}'
def __init__(self, *, actual_length: int, expected_length: int) -> None:
super().__init__(actual_length=actual_length, expected_length=expected_length)
class ListMinLengthError(PydanticValueError):
code = 'list.min_items'
msg_template = 'ensure this value has at least {limit_value} items'
def __init__(self, *, limit_value: int) -> None:
super().__init__(limit_value=limit_value)
class ListMaxLengthError(PydanticValueError):
code = 'list.max_items'
msg_template = 'ensure this value has at most {limit_value} items'
def __init__(self, *, limit_value: int) -> None:
super().__init__(limit_value=limit_value)
class ListUniqueItemsError(PydanticValueError):
code = 'list.unique_items'
msg_template = 'the list has duplicated items'
class SetMinLengthError(PydanticValueError):
code = 'set.min_items'
msg_template = 'ensure this value has at least {limit_value} items'
def __init__(self, *, limit_value: int) -> None:
super().__init__(limit_value=limit_value)
class SetMaxLengthError(PydanticValueError):
code = 'set.max_items'
msg_template = 'ensure this value has at most {limit_value} items'
def __init__(self, *, limit_value: int) -> None:
super().__init__(limit_value=limit_value)
class FrozenSetMinLengthError(PydanticValueError):
code = 'frozenset.min_items'
msg_template = 'ensure this value has at least {limit_value} items'
def __init__(self, *, limit_value: int) -> None:
super().__init__(limit_value=limit_value)
class FrozenSetMaxLengthError(PydanticValueError):
code = 'frozenset.max_items'
msg_template = 'ensure this value has at most {limit_value} items'
def __init__(self, *, limit_value: int) -> None:
super().__init__(limit_value=limit_value)
class AnyStrMinLengthError(PydanticValueError):
code = 'any_str.min_length'
msg_template = 'ensure this value has at least {limit_value} characters'
def __init__(self, *, limit_value: int) -> None:
super().__init__(limit_value=limit_value)
class AnyStrMaxLengthError(PydanticValueError):
code = 'any_str.max_length'
msg_template = 'ensure this value has at most {limit_value} characters'
def __init__(self, *, limit_value: int) -> None:
super().__init__(limit_value=limit_value)
class StrError(PydanticTypeError):
msg_template = 'str type expected'
class StrRegexError(PydanticValueError):
code = 'str.regex'
msg_template = 'string does not match regex "{pattern}"'
def __init__(self, *, pattern: str) -> None:
super().__init__(pattern=pattern)
class _NumberBoundError(PydanticValueError):
def __init__(self, *, limit_value: Union[int, float, Decimal]) -> None:
super().__init__(limit_value=limit_value)
class NumberNotGtError(_NumberBoundError):
code = 'number.not_gt'
msg_template = 'ensure this value is greater than {limit_value}'
class NumberNotGeError(_NumberBoundError):
code = 'number.not_ge'
msg_template = 'ensure this value is greater than or equal to {limit_value}'
class NumberNotLtError(_NumberBoundError):
code = 'number.not_lt'
msg_template = 'ensure this value is less than {limit_value}'
class NumberNotLeError(_NumberBoundError):
code = 'number.not_le'
msg_template = 'ensure this value is less than or equal to {limit_value}'
class NumberNotFiniteError(PydanticValueError):
code = 'number.not_finite_number'
msg_template = 'ensure this value is a finite number'
class NumberNotMultipleError(PydanticValueError):
code = 'number.not_multiple'
msg_template = 'ensure this value is a multiple of {multiple_of}'
def __init__(self, *, multiple_of: Union[int, float, Decimal]) -> None:
super().__init__(multiple_of=multiple_of)
class DecimalError(PydanticTypeError):
msg_template = 'value is not a valid decimal'
class DecimalIsNotFiniteError(PydanticValueError):
code = 'decimal.not_finite'
msg_template = 'value is not a valid decimal'
class DecimalMaxDigitsError(PydanticValueError):
code = 'decimal.max_digits'
msg_template = 'ensure that there are no more than {max_digits} digits in total'
def __init__(self, *, max_digits: int) -> None:
super().__init__(max_digits=max_digits)
class DecimalMaxPlacesError(PydanticValueError):
code = 'decimal.max_places'
msg_template = 'ensure that there are no more than {decimal_places} decimal places'
def __init__(self, *, decimal_places: int) -> None:
super().__init__(decimal_places=decimal_places)
class DecimalWholeDigitsError(PydanticValueError):
code = 'decimal.whole_digits'
msg_template = 'ensure that there are no more than {whole_digits} digits before the decimal point'
def __init__(self, *, whole_digits: int) -> None:
super().__init__(whole_digits=whole_digits)
class DateTimeError(PydanticValueError):
msg_template = 'invalid datetime format'
class DateError(PydanticValueError):
msg_template = 'invalid date format'
class DateNotInThePastError(PydanticValueError):
code = 'date.not_in_the_past'
msg_template = 'date is not in the past'
class DateNotInTheFutureError(PydanticValueError):
code = 'date.not_in_the_future'
msg_template = 'date is not in the future'
class TimeError(PydanticValueError):
msg_template = 'invalid time format'
class DurationError(PydanticValueError):
msg_template = 'invalid duration format'
class HashableError(PydanticTypeError):
msg_template = 'value is not a valid hashable'
class UUIDError(PydanticTypeError):
msg_template = 'value is not a valid uuid'
class UUIDVersionError(PydanticValueError):
code = 'uuid.version'
msg_template = 'uuid version {required_version} expected'
def __init__(self, *, required_version: int) -> None:
super().__init__(required_version=required_version)
class ArbitraryTypeError(PydanticTypeError):
code = 'arbitrary_type'
msg_template = 'instance of {expected_arbitrary_type} expected'
def __init__(self, *, expected_arbitrary_type: Type[Any]) -> None:
super().__init__(expected_arbitrary_type=display_as_type(expected_arbitrary_type))
class ClassError(PydanticTypeError):
code = 'class'
msg_template = 'a class is expected'
class SubclassError(PydanticTypeError):
code = 'subclass'
msg_template = 'subclass of {expected_class} expected'
def __init__(self, *, expected_class: Type[Any]) -> None:
super().__init__(expected_class=display_as_type(expected_class))
class JsonError(PydanticValueError):
msg_template = 'Invalid JSON'
class JsonTypeError(PydanticTypeError):
code = 'json'
msg_template = 'JSON object must be str, bytes or bytearray'
class PatternError(PydanticValueError):
code = 'regex_pattern'
msg_template = 'Invalid regular expression'
class DataclassTypeError(PydanticTypeError):
code = 'dataclass'
msg_template = 'instance of {class_name}, tuple or dict expected'
class CallableError(PydanticTypeError):
msg_template = '{value} is not callable'
class EnumError(PydanticTypeError):
code = 'enum_instance'
msg_template = '{value} is not a valid Enum instance'
class IntEnumError(PydanticTypeError):
code = 'int_enum_instance'
msg_template = '{value} is not a valid IntEnum instance'
class IPvAnyAddressError(PydanticValueError):
msg_template = 'value is not a valid IPv4 or IPv6 address'
class IPvAnyInterfaceError(PydanticValueError):
msg_template = 'value is not a valid IPv4 or IPv6 interface'
class IPvAnyNetworkError(PydanticValueError):
msg_template = 'value is not a valid IPv4 or IPv6 network'
class IPv4AddressError(PydanticValueError):
msg_template = 'value is not a valid IPv4 address'
class IPv6AddressError(PydanticValueError):
msg_template = 'value is not a valid IPv6 address'
class IPv4NetworkError(PydanticValueError):
msg_template = 'value is not a valid IPv4 network'
class IPv6NetworkError(PydanticValueError):
msg_template = 'value is not a valid IPv6 network'
class IPv4InterfaceError(PydanticValueError):
msg_template = 'value is not a valid IPv4 interface'
class IPv6InterfaceError(PydanticValueError):
msg_template = 'value is not a valid IPv6 interface'
class ColorError(PydanticValueError):
msg_template = 'value is not a valid color: {reason}'
class StrictBoolError(PydanticValueError):
msg_template = 'value is not a valid boolean'
class NotDigitError(PydanticValueError):
code = 'payment_card_number.digits'
msg_template = 'card number is not all digits'
class LuhnValidationError(PydanticValueError):
code = 'payment_card_number.luhn_check'
msg_template = 'card number is not luhn valid'
class InvalidLengthForBrand(PydanticValueError):
code = 'payment_card_number.invalid_length_for_brand'
msg_template = 'Length for a {brand} card must be {required_length}'
class InvalidByteSize(PydanticValueError):
msg_template = 'could not parse value and unit from byte string'
class InvalidByteSizeUnit(PydanticValueError):
msg_template = 'could not interpret byte unit: {unit}'
class MissingDiscriminator(PydanticValueError):
code = 'discriminated_union.missing_discriminator'
msg_template = 'Discriminator {discriminator_key!r} is missing in value'
class InvalidDiscriminator(PydanticValueError):
code = 'discriminated_union.invalid_discriminator'
msg_template = (
'No match for discriminator {discriminator_key!r} and value {discriminator_value!r} '
'(allowed values: {allowed_values})'
)
def __init__(self, *, discriminator_key: str, discriminator_value: Any, allowed_values: Sequence[Any]) -> None:
super().__init__(
discriminator_key=discriminator_key,
discriminator_value=discriminator_value,
allowed_values=', '.join(map(repr, allowed_values)),
)
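Since PydanticErrorMixin stores its keyword arguments in __dict__ and formats msg_template with them, a new error only needs a code and a msg_template. A small sketch (NotAWeekdayError and Meeting are made-up names, not part of pydantic):

from pydantic import BaseModel, PydanticValueError, validator

class NotAWeekdayError(PydanticValueError):
    code = 'not_a_weekday'                        # surfaces as type 'value_error.not_a_weekday'
    msg_template = '"{wrong_value}" is not a weekday'

class Meeting(BaseModel):
    day: str

    @validator('day')
    def day_must_be_weekday(cls, v):
        if v.lower() in ('saturday', 'sunday'):
            raise NotAWeekdayError(wrong_value=v)  # wrong_value becomes part of the error ctx
        return v

Meeting(day='Monday')    # ok
# Meeting(day='Sunday')  # ValidationError: '"Sunday" is not a weekday'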

lib/pydantic/fields.py Normal file

File diff suppressed because it is too large

lib/pydantic/generics.py Normal file

@ -0,0 +1,364 @@
import sys
import typing
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Generic,
Iterator,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from typing_extensions import Annotated
from .class_validators import gather_all_validators
from .fields import DeferredType
from .main import BaseModel, create_model
from .types import JsonWrapper
from .typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base
from .utils import LimitedDict, all_identical, lenient_issubclass
GenericModelT = TypeVar('GenericModelT', bound='GenericModel')
TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type
Parametrization = Mapping[TypeVarType, Type[Any]]
_generic_types_cache: LimitedDict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel]] = LimitedDict()
# _assigned_parameters is a Mapping from parametrized version of generic models to assigned types of parametrizations
# as captured during construction of the class (not instances).
# E.g., for generic model `Model[A, B]`, when parametrized model `Model[int, str]` is created,
# `Model[int, str]: {A: int, B: str}` will be stored in `_assigned_parameters`.
# (This information is only otherwise available after creation from the class name string).
_assigned_parameters: LimitedDict[Type[Any], Parametrization] = LimitedDict()
class GenericModel(BaseModel):
__slots__ = ()
__concrete__: ClassVar[bool] = False
if TYPE_CHECKING:
# Putting this in a TYPE_CHECKING block allows us to replace `if Generic not in cls.__bases__` with
# `not hasattr(cls, "__parameters__")`. This means we don't need to force non-concrete subclasses of
# `GenericModel` to also inherit from `Generic`, which would require changes to the use of `create_model` below.
__parameters__: ClassVar[Tuple[TypeVarType, ...]]
# Setting the return type as Type[Any] instead of Type[BaseModel] prevents PyCharm warnings
def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[Type[Any], ...]]) -> Type[Any]:
"""Instantiates a new class from a generic class `cls` and type variables `params`.
        :param params: Tuple of types the class is parameterized with. Given a generic class
`Model` with 2 type variables and a concrete model `Model[str, int]`,
the value `(str, int)` would be passed to `params`.
:return: New model class inheriting from `cls` with instantiated
types described by `params`. If no parameters are given, `cls` is
returned as is.
"""
def _cache_key(_params: Any) -> Tuple[Type[GenericModelT], Any, Tuple[Any, ...]]:
return cls, _params, get_args(_params)
cached = _generic_types_cache.get(_cache_key(params))
if cached is not None:
return cached
if cls.__concrete__ and Generic not in cls.__bases__:
raise TypeError('Cannot parameterize a concrete instantiation of a generic model')
if not isinstance(params, tuple):
params = (params,)
if cls is GenericModel and any(isinstance(param, TypeVar) for param in params):
raise TypeError('Type parameters should be placed on typing.Generic, not GenericModel')
if not hasattr(cls, '__parameters__'):
raise TypeError(f'Type {cls.__name__} must inherit from typing.Generic before being parameterized')
check_parameters_count(cls, params)
# Build map from generic typevars to passed params
typevars_map: Dict[TypeVarType, Type[Any]] = dict(zip(cls.__parameters__, params))
if all_identical(typevars_map.keys(), typevars_map.values()) and typevars_map:
return cls # if arguments are equal to parameters it's the same object
# Create new model with original model as parent inserting fields with DeferredType.
model_name = cls.__concrete_name__(params)
validators = gather_all_validators(cls)
type_hints = get_all_type_hints(cls).items()
instance_type_hints = {k: v for k, v in type_hints if get_origin(v) is not ClassVar}
fields = {k: (DeferredType(), cls.__fields__[k].field_info) for k in instance_type_hints if k in cls.__fields__}
model_module, called_globally = get_caller_frame_info()
created_model = cast(
Type[GenericModel], # casting ensures mypy is aware of the __concrete__ and __parameters__ attributes
create_model(
model_name,
__module__=model_module or cls.__module__,
__base__=(cls,) + tuple(cls.__parameterized_bases__(typevars_map)),
__config__=None,
__validators__=validators,
__cls_kwargs__=None,
**fields,
),
)
_assigned_parameters[created_model] = typevars_map
if called_globally: # create global reference and therefore allow pickling
object_by_reference = None
reference_name = model_name
reference_module_globals = sys.modules[created_model.__module__].__dict__
while object_by_reference is not created_model:
object_by_reference = reference_module_globals.setdefault(reference_name, created_model)
reference_name += '_'
created_model.Config = cls.Config
# Find any typevars that are still present in the model.
# If none are left, the model is fully "concrete", otherwise the new
# class is a generic class as well taking the found typevars as
# parameters.
new_params = tuple(
{param: None for param in iter_contained_typevars(typevars_map.values())}
) # use dict as ordered set
created_model.__concrete__ = not new_params
if new_params:
created_model.__parameters__ = new_params
# Save created model in cache so we don't end up creating duplicate
# models that should be identical.
_generic_types_cache[_cache_key(params)] = created_model
if len(params) == 1:
_generic_types_cache[_cache_key(params[0])] = created_model
# Recursively walk class type hints and replace generic typevars
# with concrete types that were passed.
_prepare_model_fields(created_model, fields, instance_type_hints, typevars_map)
return created_model
@classmethod
def __concrete_name__(cls: Type[Any], params: Tuple[Type[Any], ...]) -> str:
"""Compute class name for child classes.
        :param params: Tuple of types the class is parameterized with. Given a generic class
`Model` with 2 type variables and a concrete model `Model[str, int]`,
the value `(str, int)` would be passed to `params`.
        :return: String representing the new class where `params` are
passed to `cls` as type variables.
This method can be overridden to achieve a custom naming scheme for GenericModels.
"""
param_names = [display_as_type(param) for param in params]
params_component = ', '.join(param_names)
return f'{cls.__name__}[{params_component}]'
@classmethod
def __parameterized_bases__(cls, typevars_map: Parametrization) -> Iterator[Type[Any]]:
"""
Returns unbound bases of cls parameterised to given type variables
:param typevars_map: Dictionary of type applications for binding subclasses.
Given a generic class `Model` with 2 type variables [S, T]
and a concrete model `Model[str, int]`,
the value `{S: str, T: int}` would be passed to `typevars_map`.
        :return: an iterator of generic subclasses, parameterised by `typevars_map`
and other assigned parameters of `cls`
e.g.:
```
class A(GenericModel, Generic[T]):
...
class B(A[V], Generic[V]):
...
assert A[int] in B.__parameterized_bases__({V: int})
```
"""
def build_base_model(
base_model: Type[GenericModel], mapped_types: Parametrization
) -> Iterator[Type[GenericModel]]:
base_parameters = tuple(mapped_types[param] for param in base_model.__parameters__)
parameterized_base = base_model.__class_getitem__(base_parameters)
if parameterized_base is base_model or parameterized_base is cls:
# Avoid duplication in MRO
return
yield parameterized_base
for base_model in cls.__bases__:
if not issubclass(base_model, GenericModel):
# not a class that can be meaningfully parameterized
continue
elif not getattr(base_model, '__parameters__', None):
# base_model is "GenericModel" (and has no __parameters__)
# or
# base_model is already concrete, and will be included transitively via cls.
continue
elif cls in _assigned_parameters:
if base_model in _assigned_parameters:
# cls is partially parameterised but not from base_model
# e.g. cls = B[S], base_model = A[S]
# B[S][int] should subclass A[int], (and will be transitively via B[int])
# but it's not viable to consistently subclass types with arbitrary construction
# So don't attempt to include A[S][int]
continue
else: # base_model not in _assigned_parameters:
# cls is partially parameterized, base_model is original generic
# e.g. cls = B[str, T], base_model = B[S, T]
# Need to determine the mapping for the base_model parameters
mapped_types: Parametrization = {
key: typevars_map.get(value, value) for key, value in _assigned_parameters[cls].items()
}
yield from build_base_model(base_model, mapped_types)
else:
# cls is base generic, so base_class has a distinct base
# can construct the Parameterised base model using typevars_map directly
yield from build_base_model(base_model, typevars_map)
def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any:
"""Return type with all occurrences of `type_map` keys recursively replaced with their values.
:param type_: Any type, class or generic alias
:param type_map: Mapping from `TypeVar` instance to concrete types.
:return: New type representing the basic structure of `type_` with all
    `type_map` keys recursively replaced.
>>> replace_types(Tuple[str, Union[List[str], float]], {str: int})
Tuple[int, Union[List[int], float]]
"""
if not type_map:
return type_
type_args = get_args(type_)
origin_type = get_origin(type_)
if origin_type is Annotated:
annotated_type, *annotations = type_args
return Annotated[replace_types(annotated_type, type_map), tuple(annotations)]
# Having type args is a good indicator that this is a typing module
# class instantiation or a generic alias of some sort.
if type_args:
resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
if all_identical(type_args, resolved_type_args):
# If all arguments are the same, there is no need to modify the
# type or create a new object at all
return type_
if (
origin_type is not None
and isinstance(type_, typing_base)
and not isinstance(origin_type, typing_base)
and getattr(type_, '_name', None) is not None
):
# In python < 3.9 generic aliases don't exist so any of these like `list`,
# `type` or `collections.abc.Callable` need to be translated.
# See: https://www.python.org/dev/peps/pep-0585
origin_type = getattr(typing, type_._name)
assert origin_type is not None
return origin_type[resolved_type_args]
# We handle pydantic generic models separately as they don't have the same
# semantics as "typing" classes or generic aliases
if not origin_type and lenient_issubclass(type_, GenericModel) and not type_.__concrete__:
type_args = type_.__parameters__
resolved_type_args = tuple(replace_types(t, type_map) for t in type_args)
if all_identical(type_args, resolved_type_args):
return type_
return type_[resolved_type_args]
# Handle special case for typehints that can have lists as arguments.
# `typing.Callable[[int, str], int]` is an example for this.
if isinstance(type_, (List, list)):
resolved_list = list(replace_types(element, type_map) for element in type_)
if all_identical(type_, resolved_list):
return type_
return resolved_list
    # For JsonWrapper, we need to handle its inner type to allow correct parsing
# of generic Json arguments like Json[T]
if not origin_type and lenient_issubclass(type_, JsonWrapper):
type_.inner_type = replace_types(type_.inner_type, type_map)
return type_
# If all else fails, we try to resolve the type directly and otherwise just
# return the input with no modifications.
return type_map.get(type_, type_)
def check_parameters_count(cls: Type[GenericModel], parameters: Tuple[Any, ...]) -> None:
actual = len(parameters)
expected = len(cls.__parameters__)
if actual != expected:
description = 'many' if actual > expected else 'few'
raise TypeError(f'Too {description} parameters for {cls.__name__}; actual {actual}, expected {expected}')
DictValues: Type[Any] = {}.values().__class__
def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]:
"""Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found."""
if isinstance(v, TypeVar):
yield v
elif hasattr(v, '__parameters__') and not get_origin(v) and lenient_issubclass(v, GenericModel):
yield from v.__parameters__
elif isinstance(v, (DictValues, list)):
for var in v:
yield from iter_contained_typevars(var)
else:
args = get_args(v)
for arg in args:
yield from iter_contained_typevars(arg)
def get_caller_frame_info() -> Tuple[Optional[str], bool]:
"""
Used inside a function to check whether it was called globally
Will only work against non-compiled code, therefore used only in pydantic.generics
:returns Tuple[module_name, called_globally]
"""
try:
previous_caller_frame = sys._getframe(2)
except ValueError as e:
raise RuntimeError('This function must be used inside another function') from e
except AttributeError: # sys module does not have _getframe function, so there's nothing we can do about it
return None, False
frame_globals = previous_caller_frame.f_globals
return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals
def _prepare_model_fields(
created_model: Type[GenericModel],
fields: Mapping[str, Any],
instance_type_hints: Mapping[str, type],
typevars_map: Mapping[Any, type],
) -> None:
"""
Replace DeferredType fields with concrete type hints and prepare them.
"""
for key, field in created_model.__fields__.items():
if key not in fields:
assert field.type_.__class__ is not DeferredType
# https://github.com/nedbat/coveragepy/issues/198
continue # pragma: no cover
assert field.type_.__class__ is DeferredType, field.type_.__class__
field_type_hint = instance_type_hints[key]
concrete_type = replace_types(field_type_hint, typevars_map)
field.type_ = concrete_type
field.outer_type_ = concrete_type
field.prepare()
created_model.__annotations__[key] = concrete_type
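A short, hypothetical usage sketch of GenericModel (Response and T are illustrative names): parametrizing builds a concrete model via create_model, names it with __concrete_name__, and caches it so repeated parametrizations return the same class.

from typing import Generic, TypeVar
from pydantic.generics import GenericModel

T = TypeVar('T')

class Response(GenericModel, Generic[T]):
    data: T

IntResponse = Response[int]                      # concrete model created and cached
assert IntResponse.__name__ == 'Response[int]'   # built by __concrete_name__
assert Response[int] is IntResponse              # second lookup hits _generic_types_cache
assert IntResponse(data='1').data == 1           # '1' coerced by the now-concrete int field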

lib/pydantic/json.py Normal file

@ -0,0 +1,112 @@
import datetime
from collections import deque
from decimal import Decimal
from enum import Enum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from pathlib import Path
from re import Pattern
from types import GeneratorType
from typing import Any, Callable, Dict, Type, Union
from uuid import UUID
from .color import Color
from .networks import NameEmail
from .types import SecretBytes, SecretStr
__all__ = 'pydantic_encoder', 'custom_pydantic_encoder', 'timedelta_isoformat'
def isoformat(o: Union[datetime.date, datetime.time]) -> str:
return o.isoformat()
def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
"""
    Encodes a Decimal as int if there's no exponent, otherwise float
    This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
    where an integer (but not int typed) is used. Encoding this as a float
results in failed round-tripping between encode and parse.
Our Id type is a prime example of this.
>>> decimal_encoder(Decimal("1.0"))
1.0
>>> decimal_encoder(Decimal("1"))
1
"""
if dec_value.as_tuple().exponent >= 0:
return int(dec_value)
else:
return float(dec_value)
ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
bytes: lambda o: o.decode(),
Color: str,
datetime.date: isoformat,
datetime.datetime: isoformat,
datetime.time: isoformat,
datetime.timedelta: lambda td: td.total_seconds(),
Decimal: decimal_encoder,
Enum: lambda o: o.value,
frozenset: list,
deque: list,
GeneratorType: list,
IPv4Address: str,
IPv4Interface: str,
IPv4Network: str,
IPv6Address: str,
IPv6Interface: str,
IPv6Network: str,
NameEmail: str,
Path: str,
Pattern: lambda o: o.pattern,
SecretBytes: str,
SecretStr: str,
set: list,
UUID: str,
}
def pydantic_encoder(obj: Any) -> Any:
from dataclasses import asdict, is_dataclass
from .main import BaseModel
if isinstance(obj, BaseModel):
return obj.dict()
elif is_dataclass(obj):
return asdict(obj)
# Check the class type and its superclasses for a matching encoder
for base in obj.__class__.__mro__[:-1]:
try:
encoder = ENCODERS_BY_TYPE[base]
except KeyError:
continue
return encoder(obj)
else: # We have exited the for loop without finding a suitable encoder
raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable")
def custom_pydantic_encoder(type_encoders: Dict[Any, Callable[[Type[Any]], Any]], obj: Any) -> Any:
# Check the class type and its superclasses for a matching encoder
for base in obj.__class__.__mro__[:-1]:
try:
encoder = type_encoders[base]
except KeyError:
continue
return encoder(obj)
else: # We have exited the for loop without finding a suitable encoder
return pydantic_encoder(obj)
def timedelta_isoformat(td: datetime.timedelta) -> str:
"""
ISO 8601 encoding for Python timedelta object.
"""
minutes, seconds = divmod(td.seconds, 60)
hours, minutes = divmod(minutes, 60)
return f'{"-" if td.days < 0 else ""}P{abs(td.days)}DT{hours:d}H{minutes:d}M{seconds:d}.{td.microseconds:06d}S'

lib/pydantic/main.py Normal file

File diff suppressed because it is too large

lib/pydantic/mypy.py Normal file

@ -0,0 +1,850 @@
import sys
from configparser import ConfigParser
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type as TypingType, Union
from mypy.errorcodes import ErrorCode
from mypy.nodes import (
ARG_NAMED,
ARG_NAMED_OPT,
ARG_OPT,
ARG_POS,
ARG_STAR2,
MDEF,
Argument,
AssignmentStmt,
Block,
CallExpr,
ClassDef,
Context,
Decorator,
EllipsisExpr,
FuncBase,
FuncDef,
JsonDict,
MemberExpr,
NameExpr,
PassStmt,
PlaceholderNode,
RefExpr,
StrExpr,
SymbolNode,
SymbolTableNode,
TempNode,
TypeInfo,
TypeVarExpr,
Var,
)
from mypy.options import Options
from mypy.plugin import (
CheckerPluginInterface,
ClassDefContext,
FunctionContext,
MethodContext,
Plugin,
SemanticAnalyzerPluginInterface,
)
from mypy.plugins import dataclasses
from mypy.semanal import set_callable_name # type: ignore
from mypy.server.trigger import make_wildcard_trigger
from mypy.types import (
AnyType,
CallableType,
Instance,
NoneType,
Overloaded,
Type,
TypeOfAny,
TypeType,
TypeVarType,
UnionType,
get_proper_type,
)
from mypy.typevars import fill_typevars
from mypy.util import get_unique_redefinition_name
from mypy.version import __version__ as mypy_version
from pydantic.utils import is_valid_field
try:
from mypy.types import TypeVarDef # type: ignore[attr-defined]
except ImportError: # pragma: no cover
# Backward-compatible with TypeVarDef from Mypy 0.910.
from mypy.types import TypeVarType as TypeVarDef
CONFIGFILE_KEY = 'pydantic-mypy'
METADATA_KEY = 'pydantic-mypy-metadata'
BASEMODEL_FULLNAME = 'pydantic.main.BaseModel'
BASESETTINGS_FULLNAME = 'pydantic.env_settings.BaseSettings'
FIELD_FULLNAME = 'pydantic.fields.Field'
DATACLASS_FULLNAME = 'pydantic.dataclasses.dataclass'
def parse_mypy_version(version: str) -> Tuple[int, ...]:
return tuple(int(part) for part in version.split('+', 1)[0].split('.'))
MYPY_VERSION_TUPLE = parse_mypy_version(mypy_version)
BUILTINS_NAME = 'builtins' if MYPY_VERSION_TUPLE >= (0, 930) else '__builtins__'
def plugin(version: str) -> 'TypingType[Plugin]':
"""
`version` is the mypy version string
We might want to use this to print a warning if the mypy version being used is
newer, or especially older, than we expect (or need).
"""
return PydanticPlugin
class PydanticPlugin(Plugin):
def __init__(self, options: Options) -> None:
self.plugin_config = PydanticPluginConfig(options)
super().__init__(options)
def get_base_class_hook(self, fullname: str) -> 'Optional[Callable[[ClassDefContext], None]]':
sym = self.lookup_fully_qualified(fullname)
if sym and isinstance(sym.node, TypeInfo): # pragma: no branch
# No branching may occur if the mypy cache has not been cleared
if any(get_fullname(base) == BASEMODEL_FULLNAME for base in sym.node.mro):
return self._pydantic_model_class_maker_callback
return None
def get_function_hook(self, fullname: str) -> 'Optional[Callable[[FunctionContext], Type]]':
sym = self.lookup_fully_qualified(fullname)
if sym and sym.fullname == FIELD_FULLNAME:
return self._pydantic_field_callback
return None
def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], Type]]:
if fullname.endswith('.from_orm'):
return from_orm_callback
return None
def get_class_decorator_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
if fullname == DATACLASS_FULLNAME:
return dataclasses.dataclass_class_maker_callback # type: ignore[return-value]
return None
def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> None:
transformer = PydanticModelTransformer(ctx, self.plugin_config)
transformer.transform()
def _pydantic_field_callback(self, ctx: FunctionContext) -> 'Type':
"""
Extract the type of the `default` argument from the Field function, and use it as the return type.
In particular:
* Check whether the default and default_factory argument is specified.
* Output an error if both are specified.
* Retrieve the type of the argument which is specified, and use it as return type for the function.
"""
default_any_type = ctx.default_return_type
assert ctx.callee_arg_names[0] == 'default', '"default" is no longer first argument in Field()'
assert ctx.callee_arg_names[1] == 'default_factory', '"default_factory" is no longer second argument in Field()'
default_args = ctx.args[0]
default_factory_args = ctx.args[1]
if default_args and default_factory_args:
error_default_and_default_factory_specified(ctx.api, ctx.context)
return default_any_type
if default_args:
default_type = ctx.arg_types[0][0]
default_arg = default_args[0]
# Fallback to default Any type if the field is required
if not isinstance(default_arg, EllipsisExpr):
return default_type
elif default_factory_args:
default_factory_type = ctx.arg_types[1][0]
# Functions which use `ParamSpec` can be overloaded, exposing the callable's types as a parameter
# Pydantic calls the default factory without any argument, so we retrieve the first item
if isinstance(default_factory_type, Overloaded):
if MYPY_VERSION_TUPLE > (0, 910):
default_factory_type = default_factory_type.items[0]
else:
                    # Mypy 0.910 exposes the items of overloaded types in a function
default_factory_type = default_factory_type.items()[0] # type: ignore[operator]
if isinstance(default_factory_type, CallableType):
ret_type = default_factory_type.ret_type
# mypy doesn't think `ret_type` has `args`, you'd think mypy should know,
# add this check in case it varies by version
args = getattr(ret_type, 'args', None)
if args:
if all(isinstance(arg, TypeVarType) for arg in args):
# Looks like the default factory is a type like `list` or `dict`, replace all args with `Any`
ret_type.args = tuple(default_any_type for _ in args) # type: ignore[attr-defined]
return ret_type
return default_any_type
class PydanticPluginConfig:
__slots__ = ('init_forbid_extra', 'init_typed', 'warn_required_dynamic_aliases', 'warn_untyped_fields')
init_forbid_extra: bool
init_typed: bool
warn_required_dynamic_aliases: bool
warn_untyped_fields: bool
def __init__(self, options: Options) -> None:
if options.config_file is None: # pragma: no cover
return
toml_config = parse_toml(options.config_file)
if toml_config is not None:
config = toml_config.get('tool', {}).get('pydantic-mypy', {})
for key in self.__slots__:
setting = config.get(key, False)
if not isinstance(setting, bool):
raise ValueError(f'Configuration value must be a boolean for key: {key}')
setattr(self, key, setting)
else:
plugin_config = ConfigParser()
plugin_config.read(options.config_file)
for key in self.__slots__:
setting = plugin_config.getboolean(CONFIGFILE_KEY, key, fallback=False)
setattr(self, key, setting)
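# Illustrative sketch, not part of this file: PydanticPluginConfig above reads the four
# boolean flags listed in __slots__ from the '[pydantic-mypy]' section of the mypy config
# (or '[tool.pydantic-mypy]' in pyproject.toml), each defaulting to False, i.e. roughly:
#
#   cfg = ConfigParser()
#   cfg.read_string('[pydantic-mypy]\ninit_typed = True\nwarn_untyped_fields = True\n')
#   cfg.getboolean('pydantic-mypy', 'init_typed', fallback=False)        # -> True
#   cfg.getboolean('pydantic-mypy', 'init_forbid_extra', fallback=False) # -> False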
def from_orm_callback(ctx: MethodContext) -> Type:
"""
Raise an error if orm_mode is not enabled
"""
model_type: Instance
if isinstance(ctx.type, CallableType) and isinstance(ctx.type.ret_type, Instance):
model_type = ctx.type.ret_type # called on the class
elif isinstance(ctx.type, Instance):
model_type = ctx.type # called on an instance (unusual, but still valid)
else: # pragma: no cover
detail = f'ctx.type: {ctx.type} (of type {ctx.type.__class__.__name__})'
error_unexpected_behavior(detail, ctx.api, ctx.context)
return ctx.default_return_type
pydantic_metadata = model_type.type.metadata.get(METADATA_KEY)
if pydantic_metadata is None:
return ctx.default_return_type
orm_mode = pydantic_metadata.get('config', {}).get('orm_mode')
if orm_mode is not True:
error_from_orm(get_name(model_type.type), ctx.api, ctx.context)
return ctx.default_return_type
class PydanticModelTransformer:
tracked_config_fields: Set[str] = {
'extra',
'allow_mutation',
'frozen',
'orm_mode',
'allow_population_by_field_name',
'alias_generator',
}
def __init__(self, ctx: ClassDefContext, plugin_config: PydanticPluginConfig) -> None:
self._ctx = ctx
self.plugin_config = plugin_config
def transform(self) -> None:
"""
Configures the BaseModel subclass according to the plugin settings.
In particular:
* determines the model config and fields,
* adds a fields-aware signature for the initializer and construct methods
* freezes the class if allow_mutation = False or frozen = True
* stores the fields, config, and if the class is settings in the mypy metadata for access by subclasses
"""
ctx = self._ctx
info = self._ctx.cls.info
self.adjust_validator_signatures()
config = self.collect_config()
fields = self.collect_fields(config)
for field in fields:
if info[field.name].type is None:
if not ctx.api.final_iteration:
ctx.api.defer()
is_settings = any(get_fullname(base) == BASESETTINGS_FULLNAME for base in info.mro[:-1])
self.add_initializer(fields, config, is_settings)
self.add_construct_method(fields)
self.set_frozen(fields, frozen=config.allow_mutation is False or config.frozen is True)
info.metadata[METADATA_KEY] = {
'fields': {field.name: field.serialize() for field in fields},
'config': config.set_values_dict(),
}
def adjust_validator_signatures(self) -> None:
"""When we decorate a function `f` with `pydantic.validator(...), mypy sees
`f` as a regular method taking a `self` instance, even though pydantic
internally wraps `f` with `classmethod` if necessary.
Teach mypy this by marking any function whose outermost decorator is a
`validator()` call as a classmethod.
"""
for name, sym in self._ctx.cls.info.names.items():
if isinstance(sym.node, Decorator):
first_dec = sym.node.original_decorators[0]
if (
isinstance(first_dec, CallExpr)
and isinstance(first_dec.callee, NameExpr)
and first_dec.callee.fullname == 'pydantic.class_validators.validator'
):
sym.node.func.is_class = True
def collect_config(self) -> 'ModelConfigData':
"""
Collects the values of the config attributes that are used by the plugin, accounting for parent classes.
"""
ctx = self._ctx
cls = ctx.cls
config = ModelConfigData()
for stmt in cls.defs.body:
if not isinstance(stmt, ClassDef):
continue
if stmt.name == 'Config':
for substmt in stmt.defs.body:
if not isinstance(substmt, AssignmentStmt):
continue
config.update(self.get_config_update(substmt))
if (
config.has_alias_generator
and not config.allow_population_by_field_name
and self.plugin_config.warn_required_dynamic_aliases
):
error_required_dynamic_aliases(ctx.api, stmt)
for info in cls.info.mro[1:]: # 0 is the current class
if METADATA_KEY not in info.metadata:
continue
# Each class depends on the set of fields in its ancestors
ctx.api.add_plugin_dependency(make_wildcard_trigger(get_fullname(info)))
for name, value in info.metadata[METADATA_KEY]['config'].items():
config.setdefault(name, value)
return config
def collect_fields(self, model_config: 'ModelConfigData') -> List['PydanticModelField']:
"""
Collects the fields for the model, accounting for parent classes
"""
# First, collect fields belonging to the current class.
ctx = self._ctx
cls = self._ctx.cls
fields = [] # type: List[PydanticModelField]
known_fields = set() # type: Set[str]
for stmt in cls.defs.body:
if not isinstance(stmt, AssignmentStmt): # `and stmt.new_syntax` to require annotation
continue
lhs = stmt.lvalues[0]
if not isinstance(lhs, NameExpr) or not is_valid_field(lhs.name):
continue
if not stmt.new_syntax and self.plugin_config.warn_untyped_fields:
error_untyped_fields(ctx.api, stmt)
# if lhs.name == '__config__': # BaseConfig not well handled; I'm not sure why yet
# continue
sym = cls.info.names.get(lhs.name)
if sym is None: # pragma: no cover
# This is likely due to a star import (see the dataclasses plugin for a more detailed explanation)
# This is the same logic used in the dataclasses plugin
continue
node = sym.node
if isinstance(node, PlaceholderNode): # pragma: no cover
# See the PlaceholderNode docstring for more detail about how this can occur
# Basically, it is an edge case when dealing with complex import logic
# This is the same logic used in the dataclasses plugin
continue
if not isinstance(node, Var): # pragma: no cover
# Don't know if this edge case still happens with the `is_valid_field` check above
# but better safe than sorry
continue
# x: ClassVar[int] is ignored by dataclasses.
if node.is_classvar:
continue
is_required = self.get_is_required(cls, stmt, lhs)
alias, has_dynamic_alias = self.get_alias_info(stmt)
if (
has_dynamic_alias
and not model_config.allow_population_by_field_name
and self.plugin_config.warn_required_dynamic_aliases
):
error_required_dynamic_aliases(ctx.api, stmt)
fields.append(
PydanticModelField(
name=lhs.name,
is_required=is_required,
alias=alias,
has_dynamic_alias=has_dynamic_alias,
line=stmt.line,
column=stmt.column,
)
)
known_fields.add(lhs.name)
all_fields = fields.copy()
for info in cls.info.mro[1:]: # 0 is the current class, -2 is BaseModel, -1 is object
if METADATA_KEY not in info.metadata:
continue
superclass_fields = []
# Each class depends on the set of fields in its ancestors
ctx.api.add_plugin_dependency(make_wildcard_trigger(get_fullname(info)))
for name, data in info.metadata[METADATA_KEY]['fields'].items():
if name not in known_fields:
field = PydanticModelField.deserialize(info, data)
known_fields.add(name)
superclass_fields.append(field)
else:
(field,) = (a for a in all_fields if a.name == name)
all_fields.remove(field)
superclass_fields.append(field)
all_fields = superclass_fields + all_fields
return all_fields
def add_initializer(self, fields: List['PydanticModelField'], config: 'ModelConfigData', is_settings: bool) -> None:
"""
Adds a fields-aware `__init__` method to the class.
The added `__init__` will be annotated with types vs. all `Any` depending on the plugin settings.
"""
ctx = self._ctx
typed = self.plugin_config.init_typed
use_alias = config.allow_population_by_field_name is not True
force_all_optional = is_settings or bool(
config.has_alias_generator and not config.allow_population_by_field_name
)
init_arguments = self.get_field_arguments(
fields, typed=typed, force_all_optional=force_all_optional, use_alias=use_alias
)
if not self.should_init_forbid_extra(fields, config):
var = Var('kwargs')
init_arguments.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2))
if '__init__' not in ctx.cls.info.names:
add_method(ctx, '__init__', init_arguments, NoneType())
def add_construct_method(self, fields: List['PydanticModelField']) -> None:
"""
Adds a fully typed `construct` classmethod to the class.
Similar to the fields-aware __init__ method, but always uses the field names (not aliases),
and does not treat settings fields as optional.
"""
ctx = self._ctx
set_str = ctx.api.named_type(f'{BUILTINS_NAME}.set', [ctx.api.named_type(f'{BUILTINS_NAME}.str')])
optional_set_str = UnionType([set_str, NoneType()])
fields_set_argument = Argument(Var('_fields_set', optional_set_str), optional_set_str, None, ARG_OPT)
construct_arguments = self.get_field_arguments(fields, typed=True, force_all_optional=False, use_alias=False)
construct_arguments = [fields_set_argument] + construct_arguments
obj_type = ctx.api.named_type(f'{BUILTINS_NAME}.object')
self_tvar_name = '_PydanticBaseModel' # Make sure it does not conflict with other names in the class
tvar_fullname = ctx.cls.fullname + '.' + self_tvar_name
tvd = TypeVarDef(self_tvar_name, tvar_fullname, -1, [], obj_type)
self_tvar_expr = TypeVarExpr(self_tvar_name, tvar_fullname, [], obj_type)
ctx.cls.info.names[self_tvar_name] = SymbolTableNode(MDEF, self_tvar_expr)
# Backward-compatible with TypeVarDef from Mypy 0.910.
if isinstance(tvd, TypeVarType):
self_type = tvd
else:
self_type = TypeVarType(tvd) # type: ignore[call-arg]
add_method(
ctx,
'construct',
construct_arguments,
return_type=self_type,
self_type=self_type,
tvar_def=tvd,
is_classmethod=True,
)
def set_frozen(self, fields: List['PydanticModelField'], frozen: bool) -> None:
"""
Marks all fields as properties so that attempts to set them trigger mypy errors.
This is the same approach used by the attrs and dataclasses plugins.
"""
info = self._ctx.cls.info
for field in fields:
sym_node = info.names.get(field.name)
if sym_node is not None:
var = sym_node.node
assert isinstance(var, Var)
var.is_property = frozen
else:
var = field.to_var(info, use_alias=False)
var.info = info
var.is_property = frozen
var._fullname = get_fullname(info) + '.' + get_name(var)
info.names[get_name(var)] = SymbolTableNode(MDEF, var)
def get_config_update(self, substmt: AssignmentStmt) -> Optional['ModelConfigData']:
"""
Determines the config update due to a single statement in the Config class definition.
Warns if a tracked config attribute is set to a value the plugin doesn't know how to interpret (e.g., an int)
"""
lhs = substmt.lvalues[0]
if not (isinstance(lhs, NameExpr) and lhs.name in self.tracked_config_fields):
return None
if lhs.name == 'extra':
if isinstance(substmt.rvalue, StrExpr):
forbid_extra = substmt.rvalue.value == 'forbid'
elif isinstance(substmt.rvalue, MemberExpr):
forbid_extra = substmt.rvalue.name == 'forbid'
else:
error_invalid_config_value(lhs.name, self._ctx.api, substmt)
return None
return ModelConfigData(forbid_extra=forbid_extra)
if lhs.name == 'alias_generator':
has_alias_generator = True
if isinstance(substmt.rvalue, NameExpr) and substmt.rvalue.fullname == 'builtins.None':
has_alias_generator = False
return ModelConfigData(has_alias_generator=has_alias_generator)
if isinstance(substmt.rvalue, NameExpr) and substmt.rvalue.fullname in ('builtins.True', 'builtins.False'):
return ModelConfigData(**{lhs.name: substmt.rvalue.fullname == 'builtins.True'})
error_invalid_config_value(lhs.name, self._ctx.api, substmt)
return None
@staticmethod
def get_is_required(cls: ClassDef, stmt: AssignmentStmt, lhs: NameExpr) -> bool:
"""
Returns a boolean indicating whether the field defined in `stmt` is a required field.
"""
expr = stmt.rvalue
if isinstance(expr, TempNode):
# TempNode means annotation-only, so only non-required if Optional
value_type = get_proper_type(cls.info[lhs.name].type)
if isinstance(value_type, UnionType) and any(isinstance(item, NoneType) for item in value_type.items):
# Annotated as Optional, or otherwise having NoneType in the union
return False
return True
if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME:
# The "default value" is a call to `Field`; at this point, the field is
# only required if default is Ellipsis (i.e., `field_name: Annotation = Field(...)`) or if default_factory
# is specified.
for arg, name in zip(expr.args, expr.arg_names):
# If name is None, then this arg is the default because it is the only positional argument.
if name is None or name == 'default':
return arg.__class__ is EllipsisExpr
if name == 'default_factory':
return False
return True
# Only required if the "default value" is Ellipsis (i.e., `field_name: Annotation = ...`)
return isinstance(expr, EllipsisExpr)
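# Illustrative summary (not part of the original source) of how the rules above classify
# typical declarations on a pydantic model:
#   a: int                               -> required (annotation only, not Optional)
#   b: Optional[int]                     -> not required (NoneType in the union)
#   c: int = Field(...)                  -> required (default is Ellipsis)
#   d: int = Field(default_factory=int)  -> not required (default_factory given)
#   e: int = 3                           -> not required (default is not Ellipsis)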
@staticmethod
def get_alias_info(stmt: AssignmentStmt) -> Tuple[Optional[str], bool]:
"""
Returns a pair (alias, has_dynamic_alias), extracted from the declaration of the field defined in `stmt`.
`has_dynamic_alias` is True if and only if an alias is provided, but not as a string literal.
If `has_dynamic_alias` is True, `alias` will be None.
"""
expr = stmt.rvalue
if isinstance(expr, TempNode):
# TempNode means annotation-only
return None, False
if not (
isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME
):
# Assigned value is not a call to pydantic.fields.Field
return None, False
for i, arg_name in enumerate(expr.arg_names):
if arg_name != 'alias':
continue
arg = expr.args[i]
if isinstance(arg, StrExpr):
return arg.value, False
else:
return None, True
return None, False
def get_field_arguments(
self, fields: List['PydanticModelField'], typed: bool, force_all_optional: bool, use_alias: bool
) -> List[Argument]:
"""
Helper function used during the construction of the `__init__` and `construct` method signatures.
Returns a list of mypy Argument instances for use in the generated signatures.
"""
info = self._ctx.cls.info
arguments = [
field.to_argument(info, typed=typed, force_optional=force_all_optional, use_alias=use_alias)
for field in fields
if not (use_alias and field.has_dynamic_alias)
]
return arguments
def should_init_forbid_extra(self, fields: List['PydanticModelField'], config: 'ModelConfigData') -> bool:
"""
Indicates whether the generated `__init__` should forbid extra kwargs (i.e., omit the trailing `**kwargs` from its signature).
We disallow arbitrary kwargs if the extra config setting is "forbid", or if the plugin config says to,
*unless* a required dynamic alias is present (since then we can't determine a valid signature).
"""
if not config.allow_population_by_field_name:
if self.is_dynamic_alias_present(fields, bool(config.has_alias_generator)):
return False
if config.forbid_extra:
return True
return self.plugin_config.init_forbid_extra
@staticmethod
def is_dynamic_alias_present(fields: List['PydanticModelField'], has_alias_generator: bool) -> bool:
"""
Returns whether any fields on the model have a "dynamic alias", i.e., an alias that cannot be
determined during static analysis.
"""
for field in fields:
if field.has_dynamic_alias:
return True
if has_alias_generator:
for field in fields:
if field.alias is None:
return True
return False
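# Illustrative (not part of the original source): an alias is "dynamic" whenever it is
# not given as a string literal (`make_alias` below is a hypothetical helper):
#   name: str = Field(..., alias='fullName')          # static alias, visible to the plugin
#   name: str = Field(..., alias=make_alias('name'))  # dynamic alias, signature cannot be checked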
class PydanticModelField:
def __init__(
self, name: str, is_required: bool, alias: Optional[str], has_dynamic_alias: bool, line: int, column: int
):
self.name = name
self.is_required = is_required
self.alias = alias
self.has_dynamic_alias = has_dynamic_alias
self.line = line
self.column = column
def to_var(self, info: TypeInfo, use_alias: bool) -> Var:
name = self.name
if use_alias and self.alias is not None:
name = self.alias
return Var(name, info[self.name].type)
def to_argument(self, info: TypeInfo, typed: bool, force_optional: bool, use_alias: bool) -> Argument:
if typed and info[self.name].type is not None:
type_annotation = info[self.name].type
else:
type_annotation = AnyType(TypeOfAny.explicit)
return Argument(
variable=self.to_var(info, use_alias),
type_annotation=type_annotation,
initializer=None,
kind=ARG_NAMED_OPT if force_optional or not self.is_required else ARG_NAMED,
)
def serialize(self) -> JsonDict:
return self.__dict__
@classmethod
def deserialize(cls, info: TypeInfo, data: JsonDict) -> 'PydanticModelField':
return cls(**data)
class ModelConfigData:
def __init__(
self,
forbid_extra: Optional[bool] = None,
allow_mutation: Optional[bool] = None,
frozen: Optional[bool] = None,
orm_mode: Optional[bool] = None,
allow_population_by_field_name: Optional[bool] = None,
has_alias_generator: Optional[bool] = None,
):
self.forbid_extra = forbid_extra
self.allow_mutation = allow_mutation
self.frozen = frozen
self.orm_mode = orm_mode
self.allow_population_by_field_name = allow_population_by_field_name
self.has_alias_generator = has_alias_generator
def set_values_dict(self) -> Dict[str, Any]:
return {k: v for k, v in self.__dict__.items() if v is not None}
def update(self, config: Optional['ModelConfigData']) -> None:
if config is None:
return
for k, v in config.set_values_dict().items():
setattr(self, k, v)
def setdefault(self, key: str, value: Any) -> None:
if getattr(self, key) is None:
setattr(self, key, value)
ERROR_ORM = ErrorCode('pydantic-orm', 'Invalid from_orm call', 'Pydantic')
ERROR_CONFIG = ErrorCode('pydantic-config', 'Invalid config value', 'Pydantic')
ERROR_ALIAS = ErrorCode('pydantic-alias', 'Dynamic alias disallowed', 'Pydantic')
ERROR_UNEXPECTED = ErrorCode('pydantic-unexpected', 'Unexpected behavior', 'Pydantic')
ERROR_UNTYPED = ErrorCode('pydantic-field', 'Untyped field disallowed', 'Pydantic')
ERROR_FIELD_DEFAULTS = ErrorCode('pydantic-field', 'Invalid Field defaults', 'Pydantic')
def error_from_orm(model_name: str, api: CheckerPluginInterface, context: Context) -> None:
api.fail(f'"{model_name}" does not have orm_mode=True', context, code=ERROR_ORM)
def error_invalid_config_value(name: str, api: SemanticAnalyzerPluginInterface, context: Context) -> None:
api.fail(f'Invalid value for "Config.{name}"', context, code=ERROR_CONFIG)
def error_required_dynamic_aliases(api: SemanticAnalyzerPluginInterface, context: Context) -> None:
api.fail('Required dynamic aliases disallowed', context, code=ERROR_ALIAS)
def error_unexpected_behavior(detail: str, api: CheckerPluginInterface, context: Context) -> None: # pragma: no cover
# Can't think of a good way to test this, but I confirmed it renders as desired by adding to a non-error path
link = 'https://github.com/pydantic/pydantic/issues/new/choose'
full_message = f'The pydantic mypy plugin ran into unexpected behavior: {detail}\n'
full_message += f'Please consider reporting this bug at {link} so we can try to fix it!'
api.fail(full_message, context, code=ERROR_UNEXPECTED)
def error_untyped_fields(api: SemanticAnalyzerPluginInterface, context: Context) -> None:
api.fail('Untyped fields disallowed', context, code=ERROR_UNTYPED)
def error_default_and_default_factory_specified(api: CheckerPluginInterface, context: Context) -> None:
api.fail('Field default and default_factory cannot be specified together', context, code=ERROR_FIELD_DEFAULTS)
def add_method(
ctx: ClassDefContext,
name: str,
args: List[Argument],
return_type: Type,
self_type: Optional[Type] = None,
tvar_def: Optional[TypeVarDef] = None,
is_classmethod: bool = False,
is_new: bool = False,
# is_staticmethod: bool = False,
) -> None:
"""
Adds a new method to a class.
This can be dropped if/when https://github.com/python/mypy/issues/7301 is merged
"""
info = ctx.cls.info
# First remove any previously generated methods with the same name
# to avoid clashes and problems in the semantic analyzer.
if name in info.names:
sym = info.names[name]
if sym.plugin_generated and isinstance(sym.node, FuncDef):
ctx.cls.defs.body.remove(sym.node) # pragma: no cover
self_type = self_type or fill_typevars(info)
if is_classmethod or is_new:
first = [Argument(Var('_cls'), TypeType.make_normalized(self_type), None, ARG_POS)]
# elif is_staticmethod:
# first = []
else:
self_type = self_type or fill_typevars(info)
first = [Argument(Var('__pydantic_self__'), self_type, None, ARG_POS)]
args = first + args
arg_types, arg_names, arg_kinds = [], [], []
for arg in args:
assert arg.type_annotation, 'All arguments must be fully typed.'
arg_types.append(arg.type_annotation)
arg_names.append(get_name(arg.variable))
arg_kinds.append(arg.kind)
function_type = ctx.api.named_type(f'{BUILTINS_NAME}.function')
signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type)
if tvar_def:
signature.variables = [tvar_def]
func = FuncDef(name, args, Block([PassStmt()]))
func.info = info
func.type = set_callable_name(signature, func)
func.is_class = is_classmethod
# func.is_static = is_staticmethod
func._fullname = get_fullname(info) + '.' + name
func.line = info.line
# NOTE: we would like the plugin generated node to dominate, but we still
# need to keep any existing definitions so they get semantically analyzed.
if name in info.names:
# Get a nice unique name instead.
r_name = get_unique_redefinition_name(name, info.names)
info.names[r_name] = info.names[name]
if is_classmethod: # or is_staticmethod:
func.is_decorated = True
v = Var(name, func.type)
v.info = info
v._fullname = func._fullname
# if is_classmethod:
v.is_classmethod = True
dec = Decorator(func, [NameExpr('classmethod')], v)
# else:
# v.is_staticmethod = True
# dec = Decorator(func, [NameExpr('staticmethod')], v)
dec.line = info.line
sym = SymbolTableNode(MDEF, dec)
else:
sym = SymbolTableNode(MDEF, func)
sym.plugin_generated = True
info.names[name] = sym
info.defn.defs.body.append(func)
def get_fullname(x: Union[FuncBase, SymbolNode]) -> str:
"""
Used for compatibility with mypy 0.740; can be dropped once support for 0.740 is dropped.
"""
fn = x.fullname
if callable(fn): # pragma: no cover
return fn()
return fn
def get_name(x: Union[FuncBase, SymbolNode]) -> str:
"""
Used for compatibility with mypy 0.740; can be dropped once support for 0.740 is dropped.
"""
fn = x.name
if callable(fn): # pragma: no cover
return fn()
return fn
def parse_toml(config_file: str) -> Optional[Dict[str, Any]]:
if not config_file.endswith('.toml'):
return None
read_mode = 'rb'
if sys.version_info >= (3, 11):
import tomllib as toml_
else:
try:
import tomli as toml_
except ImportError:
# older versions of mypy have toml as a dependency, not tomli
read_mode = 'r'
try:
import toml as toml_ # type: ignore[no-redef]
except ImportError: # pragma: no cover
import warnings
warnings.warn('No TOML parser installed, cannot read configuration from `pyproject.toml`.')
return None
with open(config_file, read_mode) as rf:
return toml_.load(rf) # type: ignore[arg-type]
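For orientation (not part of the diff): `parse_toml` above is how the plugin reads its settings from `pyproject.toml`. A minimal sketch of such a configuration and of reading it back, assuming the bundled package is importable as `pydantic` and using the option names referenced earlier in this file (`init_typed`, `init_forbid_extra`, `warn_required_dynamic_aliases`):
# pyproject.toml (illustrative):
#
#   [tool.mypy]
#   plugins = ["pydantic.mypy"]
#
#   [tool.pydantic-mypy]
#   init_typed = true
#   init_forbid_extra = true
#   warn_required_dynamic_aliases = true
from pydantic.mypy import parse_toml
# parse_toml returns the parsed document (or None for non-.toml paths); the plugin settings
# live under the [tool.pydantic-mypy] table
settings = (parse_toml('pyproject.toml') or {}).get('tool', {}).get('pydantic-mypy', {})
print(settings)  # {'init_typed': True, 'init_forbid_extra': True, 'warn_required_dynamic_aliases': True}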

736
lib/pydantic/networks.py Normal file
View file

@ -0,0 +1,736 @@
import re
from ipaddress import (
IPv4Address,
IPv4Interface,
IPv4Network,
IPv6Address,
IPv6Interface,
IPv6Network,
_BaseAddress,
_BaseNetwork,
)
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Generator,
List,
Match,
Optional,
Pattern,
Set,
Tuple,
Type,
Union,
cast,
no_type_check,
)
from . import errors
from .utils import Representation, update_not_none
from .validators import constr_length_validator, str_validator
if TYPE_CHECKING:
import email_validator
from typing_extensions import TypedDict
from .config import BaseConfig
from .fields import ModelField
from .typing import AnyCallable
CallableGenerator = Generator[AnyCallable, None, None]
class Parts(TypedDict, total=False):
scheme: str
user: Optional[str]
password: Optional[str]
ipv4: Optional[str]
ipv6: Optional[str]
domain: Optional[str]
port: Optional[str]
path: Optional[str]
query: Optional[str]
fragment: Optional[str]
class HostParts(TypedDict, total=False):
host: str
tld: Optional[str]
host_type: Optional[str]
port: Optional[str]
rebuild: bool
else:
email_validator = None
class Parts(dict):
pass
NetworkType = Union[str, bytes, int, Tuple[Union[str, bytes, int], Union[str, int]]]
__all__ = [
'AnyUrl',
'AnyHttpUrl',
'FileUrl',
'HttpUrl',
'stricturl',
'EmailStr',
'NameEmail',
'IPvAnyAddress',
'IPvAnyInterface',
'IPvAnyNetwork',
'PostgresDsn',
'CockroachDsn',
'AmqpDsn',
'RedisDsn',
'MongoDsn',
'KafkaDsn',
'validate_email',
]
_url_regex_cache = None
_multi_host_url_regex_cache = None
_ascii_domain_regex_cache = None
_int_domain_regex_cache = None
_host_regex_cache = None
_host_regex = (
r'(?:'
r'(?P<ipv4>(?:\d{1,3}\.){3}\d{1,3})(?=$|[/:#?])|' # ipv4
r'(?P<ipv6>\[[A-F0-9]*:[A-F0-9:]+\])(?=$|[/:#?])|' # ipv6
r'(?P<domain>[^\s/:?#]+)' # domain, validation occurs later
r')?'
r'(?::(?P<port>\d+))?' # port
)
_scheme_regex = r'(?:(?P<scheme>[a-z][a-z0-9+\-.]+)://)?' # scheme https://tools.ietf.org/html/rfc3986#appendix-A
_user_info_regex = r'(?:(?P<user>[^\s:/]*)(?::(?P<password>[^\s/]*))?@)?'
_path_regex = r'(?P<path>/[^\s?#]*)?'
_query_regex = r'(?:\?(?P<query>[^\s#]*))?'
_fragment_regex = r'(?:#(?P<fragment>[^\s#]*))?'
def url_regex() -> Pattern[str]:
global _url_regex_cache
if _url_regex_cache is None:
_url_regex_cache = re.compile(
rf'{_scheme_regex}{_user_info_regex}{_host_regex}{_path_regex}{_query_regex}{_fragment_regex}',
re.IGNORECASE,
)
return _url_regex_cache
def multi_host_url_regex() -> Pattern[str]:
"""
Compiled multi host url regex.
In addition to what `url_regex` matches, it allows multiple comma-separated hosts,
e.g. host1.db.net,host2.db.net
"""
global _multi_host_url_regex_cache
if _multi_host_url_regex_cache is None:
_multi_host_url_regex_cache = re.compile(
rf'{_scheme_regex}{_user_info_regex}'
r'(?P<hosts>([^/]*))' # validation occurs later
rf'{_path_regex}{_query_regex}{_fragment_regex}',
re.IGNORECASE,
)
return _multi_host_url_regex_cache
def ascii_domain_regex() -> Pattern[str]:
global _ascii_domain_regex_cache
if _ascii_domain_regex_cache is None:
ascii_chunk = r'[_0-9a-z](?:[-_0-9a-z]{0,61}[_0-9a-z])?'
ascii_domain_ending = r'(?P<tld>\.[a-z]{2,63})?\.?'
_ascii_domain_regex_cache = re.compile(
fr'(?:{ascii_chunk}\.)*?{ascii_chunk}{ascii_domain_ending}', re.IGNORECASE
)
return _ascii_domain_regex_cache
def int_domain_regex() -> Pattern[str]:
global _int_domain_regex_cache
if _int_domain_regex_cache is None:
int_chunk = r'[_0-9a-\U00040000](?:[-_0-9a-\U00040000]{0,61}[_0-9a-\U00040000])?'
int_domain_ending = r'(?P<tld>(\.[^\W\d_]{2,63})|(\.(?:xn--)[_0-9a-z-]{2,63}))?\.?'
_int_domain_regex_cache = re.compile(fr'(?:{int_chunk}\.)*?{int_chunk}{int_domain_ending}', re.IGNORECASE)
return _int_domain_regex_cache
def host_regex() -> Pattern[str]:
global _host_regex_cache
if _host_regex_cache is None:
_host_regex_cache = re.compile(
_host_regex,
re.IGNORECASE,
)
return _host_regex_cache
class AnyUrl(str):
strip_whitespace = True
min_length = 1
max_length = 2**16
allowed_schemes: Optional[Collection[str]] = None
tld_required: bool = False
user_required: bool = False
host_required: bool = True
hidden_parts: Set[str] = set()
__slots__ = ('scheme', 'user', 'password', 'host', 'tld', 'host_type', 'port', 'path', 'query', 'fragment')
@no_type_check
def __new__(cls, url: Optional[str], **kwargs) -> object:
return str.__new__(cls, cls.build(**kwargs) if url is None else url)
def __init__(
self,
url: str,
*,
scheme: str,
user: Optional[str] = None,
password: Optional[str] = None,
host: Optional[str] = None,
tld: Optional[str] = None,
host_type: str = 'domain',
port: Optional[str] = None,
path: Optional[str] = None,
query: Optional[str] = None,
fragment: Optional[str] = None,
) -> None:
str.__init__(url)
self.scheme = scheme
self.user = user
self.password = password
self.host = host
self.tld = tld
self.host_type = host_type
self.port = port
self.path = path
self.query = query
self.fragment = fragment
@classmethod
def build(
cls,
*,
scheme: str,
user: Optional[str] = None,
password: Optional[str] = None,
host: str,
port: Optional[str] = None,
path: Optional[str] = None,
query: Optional[str] = None,
fragment: Optional[str] = None,
**_kwargs: str,
) -> str:
parts = Parts(
scheme=scheme,
user=user,
password=password,
host=host,
port=port,
path=path,
query=query,
fragment=fragment,
**_kwargs, # type: ignore[misc]
)
url = scheme + '://'
if user:
url += user
if password:
url += ':' + password
if user or password:
url += '@'
url += host
if port and ('port' not in cls.hidden_parts or cls.get_default_parts(parts).get('port') != port):
url += ':' + port
if path:
url += path
if query:
url += '?' + query
if fragment:
url += '#' + fragment
return url
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
update_not_none(field_schema, minLength=cls.min_length, maxLength=cls.max_length, format='uri')
@classmethod
def __get_validators__(cls) -> 'CallableGenerator':
yield cls.validate
@classmethod
def validate(cls, value: Any, field: 'ModelField', config: 'BaseConfig') -> 'AnyUrl':
if value.__class__ == cls:
return value
value = str_validator(value)
if cls.strip_whitespace:
value = value.strip()
url: str = cast(str, constr_length_validator(value, field, config))
m = cls._match_url(url)
# the regex should always match, if it doesn't please report with details of the URL tried
assert m, 'URL regex failed unexpectedly'
original_parts = cast('Parts', m.groupdict())
parts = cls.apply_default_parts(original_parts)
parts = cls.validate_parts(parts)
if m.end() != len(url):
raise errors.UrlExtraError(extra=url[m.end() :])
return cls._build_url(m, url, parts)
@classmethod
def _build_url(cls, m: Match[str], url: str, parts: 'Parts') -> 'AnyUrl':
"""
Validate hosts and build the AnyUrl object. Split from `validate` so this method
can be altered in `MultiHostDsn`.
"""
host, tld, host_type, rebuild = cls.validate_host(parts)
return cls(
None if rebuild else url,
scheme=parts['scheme'],
user=parts['user'],
password=parts['password'],
host=host,
tld=tld,
host_type=host_type,
port=parts['port'],
path=parts['path'],
query=parts['query'],
fragment=parts['fragment'],
)
@staticmethod
def _match_url(url: str) -> Optional[Match[str]]:
return url_regex().match(url)
@staticmethod
def _validate_port(port: Optional[str]) -> None:
if port is not None and int(port) > 65_535:
raise errors.UrlPortError()
@classmethod
def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts':
"""
A method used to validate parts of a URL.
Could be overridden to set default values for parts if missing
"""
scheme = parts['scheme']
if scheme is None:
raise errors.UrlSchemeError()
if cls.allowed_schemes and scheme.lower() not in cls.allowed_schemes:
raise errors.UrlSchemePermittedError(set(cls.allowed_schemes))
if validate_port:
cls._validate_port(parts['port'])
user = parts['user']
if cls.user_required and user is None:
raise errors.UrlUserInfoError()
return parts
@classmethod
def validate_host(cls, parts: 'Parts') -> Tuple[str, Optional[str], str, bool]:
tld, host_type, rebuild = None, None, False
for f in ('domain', 'ipv4', 'ipv6'):
host = parts[f] # type: ignore[literal-required]
if host:
host_type = f
break
if host is None:
if cls.host_required:
raise errors.UrlHostError()
elif host_type == 'domain':
is_international = False
d = ascii_domain_regex().fullmatch(host)
if d is None:
d = int_domain_regex().fullmatch(host)
if d is None:
raise errors.UrlHostError()
is_international = True
tld = d.group('tld')
if tld is None and not is_international:
d = int_domain_regex().fullmatch(host)
assert d is not None
tld = d.group('tld')
is_international = True
if tld is not None:
tld = tld[1:]
elif cls.tld_required:
raise errors.UrlHostTldError()
if is_international:
host_type = 'int_domain'
rebuild = True
host = host.encode('idna').decode('ascii')
if tld is not None:
tld = tld.encode('idna').decode('ascii')
return host, tld, host_type, rebuild # type: ignore
@staticmethod
def get_default_parts(parts: 'Parts') -> 'Parts':
return {}
@classmethod
def apply_default_parts(cls, parts: 'Parts') -> 'Parts':
for key, value in cls.get_default_parts(parts).items():
if not parts[key]: # type: ignore[literal-required]
parts[key] = value # type: ignore[literal-required]
return parts
def __repr__(self) -> str:
extra = ', '.join(f'{n}={getattr(self, n)!r}' for n in self.__slots__ if getattr(self, n) is not None)
return f'{self.__class__.__name__}({super().__repr__()}, {extra})'
class AnyHttpUrl(AnyUrl):
allowed_schemes = {'http', 'https'}
__slots__ = ()
class HttpUrl(AnyHttpUrl):
tld_required = True
# https://stackoverflow.com/questions/417142/what-is-the-maximum-length-of-a-url-in-different-browsers
max_length = 2083
hidden_parts = {'port'}
@staticmethod
def get_default_parts(parts: 'Parts') -> 'Parts':
return {'port': '80' if parts['scheme'] == 'http' else '443'}
class FileUrl(AnyUrl):
allowed_schemes = {'file'}
host_required = False
__slots__ = ()
class MultiHostDsn(AnyUrl):
__slots__ = AnyUrl.__slots__ + ('hosts',)
def __init__(self, *args: Any, hosts: Optional[List['HostParts']] = None, **kwargs: Any):
super().__init__(*args, **kwargs)
self.hosts = hosts
@staticmethod
def _match_url(url: str) -> Optional[Match[str]]:
return multi_host_url_regex().match(url)
@classmethod
def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts':
return super().validate_parts(parts, validate_port=False)
@classmethod
def _build_url(cls, m: Match[str], url: str, parts: 'Parts') -> 'MultiHostDsn':
hosts_parts: List['HostParts'] = []
host_re = host_regex()
for host in m.groupdict()['hosts'].split(','):
d: Parts = host_re.match(host).groupdict() # type: ignore
host, tld, host_type, rebuild = cls.validate_host(d)
port = d.get('port')
cls._validate_port(port)
hosts_parts.append(
{
'host': host,
'host_type': host_type,
'tld': tld,
'rebuild': rebuild,
'port': port,
}
)
if len(hosts_parts) > 1:
return cls(
None if any([hp['rebuild'] for hp in hosts_parts]) else url,
scheme=parts['scheme'],
user=parts['user'],
password=parts['password'],
path=parts['path'],
query=parts['query'],
fragment=parts['fragment'],
host_type=None,
hosts=hosts_parts,
)
else:
# backwards compatibility with single host
host_part = hosts_parts[0]
return cls(
None if host_part['rebuild'] else url,
scheme=parts['scheme'],
user=parts['user'],
password=parts['password'],
host=host_part['host'],
tld=host_part['tld'],
host_type=host_part['host_type'],
port=host_part.get('port'),
path=parts['path'],
query=parts['query'],
fragment=parts['fragment'],
)
class PostgresDsn(MultiHostDsn):
allowed_schemes = {
'postgres',
'postgresql',
'postgresql+asyncpg',
'postgresql+pg8000',
'postgresql+psycopg2',
'postgresql+psycopg2cffi',
'postgresql+py-postgresql',
'postgresql+pygresql',
}
user_required = True
__slots__ = ()
class CockroachDsn(AnyUrl):
allowed_schemes = {
'cockroachdb',
'cockroachdb+psycopg2',
'cockroachdb+asyncpg',
}
user_required = True
class AmqpDsn(AnyUrl):
allowed_schemes = {'amqp', 'amqps'}
host_required = False
class RedisDsn(AnyUrl):
__slots__ = ()
allowed_schemes = {'redis', 'rediss'}
host_required = False
@staticmethod
def get_default_parts(parts: 'Parts') -> 'Parts':
return {
'domain': 'localhost' if not (parts['ipv4'] or parts['ipv6']) else '',
'port': '6379',
'path': '/0',
}
class MongoDsn(AnyUrl):
allowed_schemes = {'mongodb'}
# TODO: "Parts" needs to be generalized for "Replica Set", "Sharded Cluster", and other MongoDB deployment modes
@staticmethod
def get_default_parts(parts: 'Parts') -> 'Parts':
return {
'port': '27017',
}
class KafkaDsn(AnyUrl):
allowed_schemes = {'kafka'}
@staticmethod
def get_default_parts(parts: 'Parts') -> 'Parts':
return {
'domain': 'localhost',
'port': '9092',
}
def stricturl(
*,
strip_whitespace: bool = True,
min_length: int = 1,
max_length: int = 2**16,
tld_required: bool = True,
host_required: bool = True,
allowed_schemes: Optional[Collection[str]] = None,
) -> Type[AnyUrl]:
# use kwargs then define conf in a dict to aid with IDE type hinting
namespace = dict(
strip_whitespace=strip_whitespace,
min_length=min_length,
max_length=max_length,
tld_required=tld_required,
host_required=host_required,
allowed_schemes=allowed_schemes,
)
return type('UrlValue', (AnyUrl,), namespace)
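# Illustrative usage (not part of the original source): stricturl() builds a constrained
# AnyUrl subclass on the fly, e.g. for an S3-style field annotation:
#   S3Url = stricturl(allowed_schemes={'s3'}, tld_required=False)
#   class Asset(BaseModel):
#       location: S3Url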
def import_email_validator() -> None:
global email_validator
try:
import email_validator
except ImportError as e:
raise ImportError('email-validator is not installed, run `pip install pydantic[email]`') from e
class EmailStr(str):
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
field_schema.update(type='string', format='email')
@classmethod
def __get_validators__(cls) -> 'CallableGenerator':
# included here and below so the error happens straight away
import_email_validator()
yield str_validator
yield cls.validate
@classmethod
def validate(cls, value: Union[str]) -> str:
return validate_email(value)[1]
class NameEmail(Representation):
__slots__ = 'name', 'email'
def __init__(self, name: str, email: str):
self.name = name
self.email = email
def __eq__(self, other: Any) -> bool:
return isinstance(other, NameEmail) and (self.name, self.email) == (other.name, other.email)
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
field_schema.update(type='string', format='name-email')
@classmethod
def __get_validators__(cls) -> 'CallableGenerator':
import_email_validator()
yield cls.validate
@classmethod
def validate(cls, value: Any) -> 'NameEmail':
if value.__class__ == cls:
return value
value = str_validator(value)
return cls(*validate_email(value))
def __str__(self) -> str:
return f'{self.name} <{self.email}>'
class IPvAnyAddress(_BaseAddress):
__slots__ = ()
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
field_schema.update(type='string', format='ipvanyaddress')
@classmethod
def __get_validators__(cls) -> 'CallableGenerator':
yield cls.validate
@classmethod
def validate(cls, value: Union[str, bytes, int]) -> Union[IPv4Address, IPv6Address]:
try:
return IPv4Address(value)
except ValueError:
pass
try:
return IPv6Address(value)
except ValueError:
raise errors.IPvAnyAddressError()
class IPvAnyInterface(_BaseAddress):
__slots__ = ()
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
field_schema.update(type='string', format='ipvanyinterface')
@classmethod
def __get_validators__(cls) -> 'CallableGenerator':
yield cls.validate
@classmethod
def validate(cls, value: NetworkType) -> Union[IPv4Interface, IPv6Interface]:
try:
return IPv4Interface(value)
except ValueError:
pass
try:
return IPv6Interface(value)
except ValueError:
raise errors.IPvAnyInterfaceError()
class IPvAnyNetwork(_BaseNetwork): # type: ignore
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
field_schema.update(type='string', format='ipvanynetwork')
@classmethod
def __get_validators__(cls) -> 'CallableGenerator':
yield cls.validate
@classmethod
def validate(cls, value: NetworkType) -> Union[IPv4Network, IPv6Network]:
# Assume IP Network is defined with a default value for ``strict`` argument.
# Define your own class if you want to specify network address check strictness.
try:
return IPv4Network(value)
except ValueError:
pass
try:
return IPv6Network(value)
except ValueError:
raise errors.IPvAnyNetworkError()
pretty_email_regex = re.compile(r'([\w ]*?) *<(.*)> *')
def validate_email(value: Union[str]) -> Tuple[str, str]:
"""
Brutally simple email address validation. Note that, unlike most email address validation:
* raw IP address (literal) domain parts are not allowed.
* "John Doe <local_part@domain.com>" style "pretty" email addresses are processed.
* the local part check is extremely basic. This raises the possibility of unicode spoofing, but no better
solution is really possible.
* spaces are stripped from the beginning and end of addresses, but no error is raised.
See RFC 5322, but treat it with suspicion; there seems to be no universally acknowledged test for a valid email!
"""
if email_validator is None:
import_email_validator()
m = pretty_email_regex.fullmatch(value)
name: Optional[str] = None
if m:
name, value = m.groups()
email = value.strip()
try:
email_validator.validate_email(email, check_deliverability=False)
except email_validator.EmailNotValidError as e:
raise errors.EmailError() from e
at_index = email.index('@')
local_part = email[:at_index] # RFC 5321, local part must be case-sensitive.
global_part = email[at_index:].lower()
return name or local_part, local_part + global_part
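A brief usage sketch (illustrative, not part of the diff) for two of the helpers defined above; validate_email additionally assumes the optional email-validator dependency is installed:
from pydantic import AnyUrl
from pydantic.networks import validate_email
# AnyUrl.build() composes a URL string from its parts (see build() above)
url = AnyUrl.build(scheme='https', user='guest', host='example.com', port='8080', path='/docs')
print(url)  # https://guest@example.com:8080/docs
# validate_email() returns (display name, normalized address)
print(validate_email('John Doe <JohnDoe@example.com>'))  # ('John Doe', 'JohnDoe@example.com')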

66
lib/pydantic/parse.py Normal file
View file

@ -0,0 +1,66 @@
import json
import pickle
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Union
from .types import StrBytes
class Protocol(str, Enum):
json = 'json'
pickle = 'pickle'
def load_str_bytes(
b: StrBytes,
*,
content_type: str = None,
encoding: str = 'utf8',
proto: Protocol = None,
allow_pickle: bool = False,
json_loads: Callable[[str], Any] = json.loads,
) -> Any:
if proto is None and content_type:
if content_type.endswith(('json', 'javascript')):
pass
elif allow_pickle and content_type.endswith('pickle'):
proto = Protocol.pickle
else:
raise TypeError(f'Unknown content-type: {content_type}')
proto = proto or Protocol.json
if proto == Protocol.json:
if isinstance(b, bytes):
b = b.decode(encoding)
return json_loads(b)
elif proto == Protocol.pickle:
if not allow_pickle:
raise RuntimeError('Trying to decode with pickle with allow_pickle=False')
bb = b if isinstance(b, bytes) else b.encode()
return pickle.loads(bb)
else:
raise TypeError(f'Unknown protocol: {proto}')
def load_file(
path: Union[str, Path],
*,
content_type: str = None,
encoding: str = 'utf8',
proto: Protocol = None,
allow_pickle: bool = False,
json_loads: Callable[[str], Any] = json.loads,
) -> Any:
path = Path(path)
b = path.read_bytes()
if content_type is None:
if path.suffix in ('.js', '.json'):
proto = Protocol.json
elif path.suffix == '.pkl':
proto = Protocol.pickle
return load_str_bytes(
b, proto=proto, content_type=content_type, encoding=encoding, allow_pickle=allow_pickle, json_loads=json_loads
)
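A short sketch (illustrative, not part of the diff) of load_str_bytes, showing the JSON default and the explicit opt-in required for pickle:
import pickle
from pydantic.parse import Protocol, load_str_bytes
print(load_str_bytes(b'{"answer": 42}'))                                  # {'answer': 42} -- JSON is the default protocol
print(load_str_bytes('{"answer": 42}', content_type='application/json'))  # {'answer': 42}
# pickle is refused unless allow_pickle=True is passed explicitly
print(load_str_bytes(pickle.dumps({'answer': 42}), proto=Protocol.pickle, allow_pickle=True))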

0
lib/pydantic/py.typed Normal file
View file

1153
lib/pydantic/schema.py Normal file

File diff suppressed because it is too large Load diff

92
lib/pydantic/tools.py Normal file
View file

@ -0,0 +1,92 @@
import json
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar, Union
from .parse import Protocol, load_file, load_str_bytes
from .types import StrBytes
from .typing import display_as_type
__all__ = ('parse_file_as', 'parse_obj_as', 'parse_raw_as', 'schema_of', 'schema_json_of')
NameFactory = Union[str, Callable[[Type[Any]], str]]
if TYPE_CHECKING:
from .typing import DictStrAny
def _generate_parsing_type_name(type_: Any) -> str:
return f'ParsingModel[{display_as_type(type_)}]'
@lru_cache(maxsize=2048)
def _get_parsing_type(type_: Any, *, type_name: Optional[NameFactory] = None) -> Any:
from pydantic.main import create_model
if type_name is None:
type_name = _generate_parsing_type_name
if not isinstance(type_name, str):
type_name = type_name(type_)
return create_model(type_name, __root__=(type_, ...))
T = TypeVar('T')
def parse_obj_as(type_: Type[T], obj: Any, *, type_name: Optional[NameFactory] = None) -> T:
model_type = _get_parsing_type(type_, type_name=type_name) # type: ignore[arg-type]
return model_type(__root__=obj).__root__
def parse_file_as(
type_: Type[T],
path: Union[str, Path],
*,
content_type: str = None,
encoding: str = 'utf8',
proto: Protocol = None,
allow_pickle: bool = False,
json_loads: Callable[[str], Any] = json.loads,
type_name: Optional[NameFactory] = None,
) -> T:
obj = load_file(
path,
proto=proto,
content_type=content_type,
encoding=encoding,
allow_pickle=allow_pickle,
json_loads=json_loads,
)
return parse_obj_as(type_, obj, type_name=type_name)
def parse_raw_as(
type_: Type[T],
b: StrBytes,
*,
content_type: str = None,
encoding: str = 'utf8',
proto: Protocol = None,
allow_pickle: bool = False,
json_loads: Callable[[str], Any] = json.loads,
type_name: Optional[NameFactory] = None,
) -> T:
obj = load_str_bytes(
b,
proto=proto,
content_type=content_type,
encoding=encoding,
allow_pickle=allow_pickle,
json_loads=json_loads,
)
return parse_obj_as(type_, obj, type_name=type_name)
def schema_of(type_: Any, *, title: Optional[NameFactory] = None, **schema_kwargs: Any) -> 'DictStrAny':
"""Generate a JSON schema (as dict) for the passed model or dynamically generated one"""
return _get_parsing_type(type_, type_name=title).schema(**schema_kwargs)
def schema_json_of(type_: Any, *, title: Optional[NameFactory] = None, **schema_json_kwargs: Any) -> str:
"""Generate a JSON schema (as JSON) for the passed model or dynamically generated one"""
return _get_parsing_type(type_, type_name=title).schema_json(**schema_json_kwargs)
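A quick sketch (illustrative, not part of the diff) of the helpers above; values are coerced through a dynamically generated wrapper model:
from typing import List
from pydantic.tools import parse_obj_as, parse_raw_as, schema_json_of
print(parse_obj_as(List[int], ['1', 2, '3']))      # [1, 2, 3]
print(parse_raw_as(List[int], b'[1, "2", 3]'))     # [1, 2, 3] -- parsed from JSON first
print(schema_json_of(List[int], title='IntList'))  # JSON schema of the generated wrapper model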

1187
lib/pydantic/types.py Normal file

File diff suppressed because it is too large Load diff

602
lib/pydantic/typing.py Normal file
View file

@ -0,0 +1,602 @@
import sys
from collections.abc import Callable
from os import PathLike
from typing import ( # type: ignore
TYPE_CHECKING,
AbstractSet,
Any,
Callable as TypingCallable,
ClassVar,
Dict,
ForwardRef,
Generator,
Iterable,
List,
Mapping,
NewType,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
_eval_type,
cast,
get_type_hints,
)
from typing_extensions import (
Annotated,
Final,
Literal,
NotRequired as TypedDictNotRequired,
Required as TypedDictRequired,
)
try:
from typing import _TypingBase as typing_base # type: ignore
except ImportError:
from typing import _Final as typing_base # type: ignore
try:
from typing import GenericAlias as TypingGenericAlias # type: ignore
except ImportError:
# python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
TypingGenericAlias = ()
try:
from types import UnionType as TypesUnionType # type: ignore
except ImportError:
# python < 3.10 does not have UnionType (str | int, bytes | bool and so on)
TypesUnionType = ()
if sys.version_info < (3, 9):
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
return type_._evaluate(globalns, localns)
else:
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
# Even though it is the right signature for python 3.9, mypy complains with
# `error: Too many arguments for "_evaluate" of "ForwardRef"` hence the cast...
return cast(Any, type_)._evaluate(globalns, localns, set())
if sys.version_info < (3, 9):
# Ensure we always get the whole `Annotated` hint, not just the annotated type.
# For 3.7 to 3.8, `get_type_hints` doesn't recognize `typing_extensions.Annotated`,
# so it already returns the full annotation
get_all_type_hints = get_type_hints
else:
def get_all_type_hints(obj: Any, globalns: Any = None, localns: Any = None) -> Any:
return get_type_hints(obj, globalns, localns, include_extras=True)
_T = TypeVar('_T')
AnyCallable = TypingCallable[..., Any]
NoArgAnyCallable = TypingCallable[[], Any]
# workaround for https://github.com/python/mypy/issues/9496
AnyArgTCallable = TypingCallable[..., _T]
# Annotated[...] is implemented by returning an instance of one of these classes, depending on
# python/typing_extensions version.
AnnotatedTypeNames = {'AnnotatedMeta', '_AnnotatedAlias'}
if sys.version_info < (3, 8):
def get_origin(t: Type[Any]) -> Optional[Type[Any]]:
if type(t).__name__ in AnnotatedTypeNames:
# weirdly this is a runtime requirement, as well as for mypy
return cast(Type[Any], Annotated)
return getattr(t, '__origin__', None)
else:
from typing import get_origin as _typing_get_origin
def get_origin(tp: Type[Any]) -> Optional[Type[Any]]:
"""
We can't directly use `typing.get_origin` since we need a fallback to support
custom generic classes like `ConstrainedList`
It should be useless once https://github.com/cython/cython/issues/3537 is
solved and https://github.com/pydantic/pydantic/pull/1753 is merged.
"""
if type(tp).__name__ in AnnotatedTypeNames:
return cast(Type[Any], Annotated) # mypy complains about _SpecialForm
return _typing_get_origin(tp) or getattr(tp, '__origin__', None)
if sys.version_info < (3, 8):
from typing import _GenericAlias
def get_args(t: Type[Any]) -> Tuple[Any, ...]:
"""Compatibility version of get_args for python 3.7.
Mostly compatible with the python 3.8 `typing` module version
and able to handle almost all use cases.
"""
if type(t).__name__ in AnnotatedTypeNames:
return t.__args__ + t.__metadata__
if isinstance(t, _GenericAlias):
res = t.__args__
if t.__origin__ is Callable and res and res[0] is not Ellipsis:
res = (list(res[:-1]), res[-1])
return res
return getattr(t, '__args__', ())
else:
from typing import get_args as _typing_get_args
def _generic_get_args(tp: Type[Any]) -> Tuple[Any, ...]:
"""
In python 3.9, `typing.Dict`, `typing.List`, ...
do have an empty `__args__` by default (instead of the generic ~T for example).
In order to still support `Dict` for example and consider it as `Dict[Any, Any]`,
we retrieve the `_nparams` value that tells us how many parameters it needs.
"""
if hasattr(tp, '_nparams'):
return (Any,) * tp._nparams
# Special case for `tuple[()]`, which used to return ((),) with `typing.Tuple`
# in python 3.10- but now returns () for `tuple` and `Tuple`.
# This will probably be clarified in pydantic v2
try:
if tp == Tuple[()] or sys.version_info >= (3, 9) and tp == tuple[()]: # type: ignore[misc]
return ((),)
# there is a TypeError when compiled with cython
except TypeError: # pragma: no cover
pass
return ()
def get_args(tp: Type[Any]) -> Tuple[Any, ...]:
"""Get type arguments with all substitutions performed.
For unions, basic simplifications used by Union constructor are performed.
Examples::
get_args(Dict[str, int]) == (str, int)
get_args(int) == ()
get_args(Union[int, Union[T, int], str][int]) == (int, str)
get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])
get_args(Callable[[], T][int]) == ([], int)
"""
if type(tp).__name__ in AnnotatedTypeNames:
return tp.__args__ + tp.__metadata__
# the fallback is needed for the same reasons as `get_origin` (see above)
return _typing_get_args(tp) or getattr(tp, '__args__', ()) or _generic_get_args(tp)
if sys.version_info < (3, 9):
def convert_generics(tp: Type[Any]) -> Type[Any]:
"""Python 3.9 and older only supports generics from `typing` module.
They convert strings to ForwardRef automatically.
Examples::
typing.List['Hero'] == typing.List[ForwardRef('Hero')]
"""
return tp
else:
from typing import _UnionGenericAlias # type: ignore
from typing_extensions import _AnnotatedAlias
def convert_generics(tp: Type[Any]) -> Type[Any]:
"""
Recursively searches for `str` type hints and replaces them with ForwardRef.
Examples::
convert_generics(list['Hero']) == list[ForwardRef('Hero')]
convert_generics(dict['Hero', 'Team']) == dict[ForwardRef('Hero'), ForwardRef('Team')]
convert_generics(typing.Dict['Hero', 'Team']) == typing.Dict[ForwardRef('Hero'), ForwardRef('Team')]
convert_generics(list[str | 'Hero'] | int) == list[str | ForwardRef('Hero')] | int
"""
origin = get_origin(tp)
if not origin or not hasattr(tp, '__args__'):
return tp
args = get_args(tp)
# typing.Annotated needs special treatment
if origin is Annotated:
return _AnnotatedAlias(convert_generics(args[0]), args[1:])
# recursively replace `str` instances inside of `GenericAlias` with `ForwardRef(arg)`
converted = tuple(
ForwardRef(arg) if isinstance(arg, str) and isinstance(tp, TypingGenericAlias) else convert_generics(arg)
for arg in args
)
if converted == args:
return tp
elif isinstance(tp, TypingGenericAlias):
return TypingGenericAlias(origin, converted)
elif isinstance(tp, TypesUnionType):
# recreate types.UnionType (PEP604, Python >= 3.10)
return _UnionGenericAlias(origin, converted)
else:
try:
setattr(tp, '__args__', converted)
except AttributeError:
pass
return tp
if sys.version_info < (3, 10):
def is_union(tp: Optional[Type[Any]]) -> bool:
return tp is Union
WithArgsTypes = (TypingGenericAlias,)
else:
import types
import typing
def is_union(tp: Optional[Type[Any]]) -> bool:
return tp is Union or tp is types.UnionType # noqa: E721
WithArgsTypes = (typing._GenericAlias, types.GenericAlias, types.UnionType)
if sys.version_info < (3, 9):
StrPath = Union[str, PathLike]
else:
StrPath = Union[str, PathLike]
# TODO: Once we switch to Cython 3 to handle generics properly
# (https://github.com/cython/cython/issues/2753), use following lines instead
# of the one above
# # os.PathLike only becomes subscriptable from Python 3.9 onwards
# StrPath = Union[str, PathLike[str]]
if TYPE_CHECKING:
from .fields import ModelField
TupleGenerator = Generator[Tuple[str, Any], None, None]
DictStrAny = Dict[str, Any]
DictAny = Dict[Any, Any]
SetStr = Set[str]
ListStr = List[str]
IntStr = Union[int, str]
AbstractSetIntStr = AbstractSet[IntStr]
DictIntStrAny = Dict[IntStr, Any]
MappingIntStrAny = Mapping[IntStr, Any]
CallableGenerator = Generator[AnyCallable, None, None]
ReprArgs = Sequence[Tuple[Optional[str], Any]]
AnyClassMethod = classmethod[Any]
__all__ = (
'AnyCallable',
'NoArgAnyCallable',
'NoneType',
'is_none_type',
'display_as_type',
'resolve_annotations',
'is_callable_type',
'is_literal_type',
'all_literal_values',
'is_namedtuple',
'is_typeddict',
'is_typeddict_special',
'is_new_type',
'new_type_supertype',
'is_classvar',
'is_finalvar',
'update_field_forward_refs',
'update_model_forward_refs',
'TupleGenerator',
'DictStrAny',
'DictAny',
'SetStr',
'ListStr',
'IntStr',
'AbstractSetIntStr',
'DictIntStrAny',
'CallableGenerator',
'ReprArgs',
'AnyClassMethod',
'CallableGenerator',
'WithArgsTypes',
'get_args',
'get_origin',
'get_sub_types',
'typing_base',
'get_all_type_hints',
'is_union',
'StrPath',
'MappingIntStrAny',
)
NoneType = None.__class__
NONE_TYPES: Tuple[Any, Any, Any] = (None, NoneType, Literal[None])
if sys.version_info < (3, 8):
# Even though this implementation is slower, we need it for python 3.7:
# In python 3.7 "Literal" is not a builtin type and uses a different
# mechanism.
# for this reason `Literal[None] is Literal[None]` evaluates to `False`,
# breaking the faster implementation used for the other python versions.
def is_none_type(type_: Any) -> bool:
return type_ in NONE_TYPES
elif sys.version_info[:2] == (3, 8):
def is_none_type(type_: Any) -> bool:
for none_type in NONE_TYPES:
if type_ is none_type:
return True
# With python 3.8, specifically 3.8.10, Literal "is" checks are very flaky:
# they can change based on very subtle things, like the use of types in other modules.
# Hopefully this check avoids that issue.
if is_literal_type(type_): # pragma: no cover
return all_literal_values(type_) == (None,)
return False
else:
def is_none_type(type_: Any) -> bool:
for none_type in NONE_TYPES:
if type_ is none_type:
return True
return False
def display_as_type(v: Type[Any]) -> str:
if not isinstance(v, typing_base) and not isinstance(v, WithArgsTypes) and not isinstance(v, type):
v = v.__class__
if is_union(get_origin(v)):
return f'Union[{", ".join(map(display_as_type, get_args(v)))}]'
if isinstance(v, WithArgsTypes):
# Generic alias are constructs like `list[int]`
return str(v).replace('typing.', '')
try:
return v.__name__
except AttributeError:
# happens with typing objects
return str(v).replace('typing.', '')
def resolve_annotations(raw_annotations: Dict[str, Type[Any]], module_name: Optional[str]) -> Dict[str, Type[Any]]:
"""
Partially taken from typing.get_type_hints.
Resolve string or ForwardRef annotations into type objects if possible.
"""
base_globals: Optional[Dict[str, Any]] = None
if module_name:
try:
module = sys.modules[module_name]
except KeyError:
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
pass
else:
base_globals = module.__dict__
annotations = {}
for name, value in raw_annotations.items():
if isinstance(value, str):
if (3, 10) > sys.version_info >= (3, 9, 8) or sys.version_info >= (3, 10, 1):
value = ForwardRef(value, is_argument=False, is_class=True)
else:
value = ForwardRef(value, is_argument=False)
try:
value = _eval_type(value, base_globals, None)
except NameError:
# this is ok, it can be fixed with update_forward_refs
pass
annotations[name] = value
return annotations
def is_callable_type(type_: Type[Any]) -> bool:
return type_ is Callable or get_origin(type_) is Callable
def is_literal_type(type_: Type[Any]) -> bool:
return Literal is not None and get_origin(type_) is Literal
def literal_values(type_: Type[Any]) -> Tuple[Any, ...]:
return get_args(type_)
def all_literal_values(type_: Type[Any]) -> Tuple[Any, ...]:
"""
This method is used to retrieve all Literal values as
Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586)
e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]`
"""
if not is_literal_type(type_):
return (type_,)
values = literal_values(type_)
return tuple(x for value in values for x in all_literal_values(value))
def is_namedtuple(type_: Type[Any]) -> bool:
"""
Check if a given class is a named tuple.
It can be either a `typing.NamedTuple` or `collections.namedtuple`
"""
from .utils import lenient_issubclass
return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields')
def is_typeddict(type_: Type[Any]) -> bool:
"""
Check if a given class is a typed dict (from `typing` or `typing_extensions`)
In 3.10, there will be a public method (https://docs.python.org/3.10/library/typing.html#typing.is_typeddict)
"""
from .utils import lenient_issubclass
return lenient_issubclass(type_, dict) and hasattr(type_, '__total__')
def _check_typeddict_special(type_: Any) -> bool:
return type_ is TypedDictRequired or type_ is TypedDictNotRequired
def is_typeddict_special(type_: Any) -> bool:
"""
Check if type is a TypedDict special form (Required or NotRequired).
"""
return _check_typeddict_special(type_) or _check_typeddict_special(get_origin(type_))
test_type = NewType('test_type', str)
def is_new_type(type_: Type[Any]) -> bool:
"""
Check whether type_ was created using typing.NewType
"""
return isinstance(type_, test_type.__class__) and hasattr(type_, '__supertype__') # type: ignore
def new_type_supertype(type_: Type[Any]) -> Type[Any]:
while hasattr(type_, '__supertype__'):
type_ = type_.__supertype__
return type_
def _check_classvar(v: Optional[Type[Any]]) -> bool:
if v is None:
return False
return v.__class__ == ClassVar.__class__ and getattr(v, '_name', None) == 'ClassVar'
def _check_finalvar(v: Optional[Type[Any]]) -> bool:
"""
Check if a given type is a `typing.Final` type.
"""
if v is None:
return False
return v.__class__ == Final.__class__ and (sys.version_info < (3, 8) or getattr(v, '_name', None) == 'Final')
def is_classvar(ann_type: Type[Any]) -> bool:
if _check_classvar(ann_type) or _check_classvar(get_origin(ann_type)):
return True
# this is an ugly workaround for class vars that contain forward references and are therefore themselves
# forward references, see #3679
if ann_type.__class__ == ForwardRef and ann_type.__forward_arg__.startswith('ClassVar['):
return True
return False
def is_finalvar(ann_type: Type[Any]) -> bool:
return _check_finalvar(ann_type) or _check_finalvar(get_origin(ann_type))
def update_field_forward_refs(field: 'ModelField', globalns: Any, localns: Any) -> None:
"""
Try to update ForwardRefs on fields based on this ModelField, globalns and localns.
"""
prepare = False
if field.type_.__class__ == ForwardRef:
prepare = True
field.type_ = evaluate_forwardref(field.type_, globalns, localns or None)
if field.outer_type_.__class__ == ForwardRef:
prepare = True
field.outer_type_ = evaluate_forwardref(field.outer_type_, globalns, localns or None)
if prepare:
field.prepare()
if field.sub_fields:
for sub_f in field.sub_fields:
update_field_forward_refs(sub_f, globalns=globalns, localns=localns)
if field.discriminator_key is not None:
field.prepare_discriminated_union_sub_fields()
def update_model_forward_refs(
model: Type[Any],
fields: Iterable['ModelField'],
json_encoders: Dict[Union[Type[Any], str, ForwardRef], AnyCallable],
localns: 'DictStrAny',
exc_to_suppress: Tuple[Type[BaseException], ...] = (),
) -> None:
"""
Try to update model fields ForwardRefs based on model and localns.
"""
if model.__module__ in sys.modules:
globalns = sys.modules[model.__module__].__dict__.copy()
else:
globalns = {}
globalns.setdefault(model.__name__, model)
for f in fields:
try:
update_field_forward_refs(f, globalns=globalns, localns=localns)
except exc_to_suppress:
pass
for key in set(json_encoders.keys()):
if isinstance(key, str):
fr: ForwardRef = ForwardRef(key)
elif isinstance(key, ForwardRef):
fr = key
else:
continue
try:
new_key = evaluate_forwardref(fr, globalns, localns or None)
except exc_to_suppress: # pragma: no cover
continue
json_encoders[new_key] = json_encoders.pop(key)
def get_class(type_: Type[Any]) -> Union[None, bool, Type[Any]]:
"""
Tries to get the class of a Type[T] annotation. Returns True if Type is used
without brackets. Otherwise returns None.
"""
if type_ is type:
return True
if get_origin(type_) is None:
return None
args = get_args(type_)
if not args or not isinstance(args[0], type):
return True
else:
return args[0]
def get_sub_types(tp: Any) -> List[Any]:
"""
Return all the types that are allowed by type `tp`
`tp` can be a `Union` of allowed types or an `Annotated` type
"""
origin = get_origin(tp)
if origin is Annotated:
return get_sub_types(get_args(tp)[0])
elif is_union(origin):
return [x for t in get_args(tp) for x in get_sub_types(t)]
else:
return [tp]
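A small sketch (illustrative, not part of the diff) of the typing helpers defined above, matching the behaviour described in their docstrings:
from typing import Dict, Literal, Union
from pydantic.typing import all_literal_values, get_args, get_origin, is_none_type, is_union
print(get_args(Dict[str, int]))                         # (str, int)
print(is_union(get_origin(Union[int, str])))            # True
print(is_none_type(type(None)))                         # True
print(all_literal_values(Literal[Literal[1, 2], 'x']))  # (1, 2, 'x')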

841
lib/pydantic/utils.py Normal file
View file

@ -0,0 +1,841 @@
import keyword
import warnings
import weakref
from collections import OrderedDict, defaultdict, deque
from copy import deepcopy
from itertools import islice, zip_longest
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
Collection,
Dict,
Generator,
Iterable,
Iterator,
List,
Mapping,
MutableMapping,
NoReturn,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
)
from typing_extensions import Annotated
from .errors import ConfigError
from .typing import (
NoneType,
WithArgsTypes,
all_literal_values,
display_as_type,
get_args,
get_origin,
is_literal_type,
is_union,
)
from .version import version_info
if TYPE_CHECKING:
from inspect import Signature
from pathlib import Path
from .config import BaseConfig
from .dataclasses import Dataclass
from .fields import ModelField
from .main import BaseModel
from .typing import AbstractSetIntStr, DictIntStrAny, IntStr, MappingIntStrAny, ReprArgs
RichReprResult = Iterable[Union[Any, Tuple[Any], Tuple[str, Any], Tuple[str, Any, Any]]]
__all__ = (
'import_string',
'sequence_like',
'validate_field_name',
'lenient_isinstance',
'lenient_issubclass',
'in_ipython',
'is_valid_identifier',
'deep_update',
'update_not_none',
'almost_equal_floats',
'get_model',
'to_camel',
'is_valid_field',
'smart_deepcopy',
'PyObjectStr',
'Representation',
'GetterDict',
'ValueItems',
'version_info', # required here to match behaviour in v1.3
'ClassAttribute',
'path_type',
'ROOT_KEY',
'get_unique_discriminator_alias',
'get_discriminator_alias_and_values',
'DUNDER_ATTRIBUTES',
'LimitedDict',
)
ROOT_KEY = '__root__'
# these are types that are returned unchanged by deepcopy
IMMUTABLE_NON_COLLECTIONS_TYPES: Set[Type[Any]] = {
int,
float,
complex,
str,
bool,
bytes,
type,
NoneType,
FunctionType,
BuiltinFunctionType,
LambdaType,
weakref.ref,
CodeType,
# note: including ModuleType differs from the behaviour of deepcopy, which would raise an error.
# That might not be a good idea in general, but since this function is only used internally
# against default values of fields, it allows a field to actually have a module as its default value
ModuleType,
NotImplemented.__class__,
Ellipsis.__class__,
}
# these are types that if empty, might be copied with simple copy() instead of deepcopy()
BUILTIN_COLLECTIONS: Set[Type[Any]] = {
list,
set,
tuple,
frozenset,
dict,
OrderedDict,
defaultdict,
deque,
}
def import_string(dotted_path: str) -> Any:
"""
Stolen approximately from django. Import a dotted module path and return the attribute/class designated by the
last name in the path. Raise ImportError if the import fails.
"""
from importlib import import_module
try:
module_path, class_name = dotted_path.strip(' ').rsplit('.', 1)
except ValueError as e:
raise ImportError(f'"{dotted_path}" doesn\'t look like a module path') from e
module = import_module(module_path)
try:
return getattr(module, class_name)
except AttributeError as e:
raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute') from e
def truncate(v: Union[str], *, max_len: int = 80) -> str:
"""
Truncate a value and add a unicode ellipsis (three dots) to the end if it was too long
"""
warnings.warn('`truncate` is no-longer used by pydantic and is deprecated', DeprecationWarning)
if isinstance(v, str) and len(v) > (max_len - 2):
# -3 so quote + string + … + quote has correct length
return (v[: (max_len - 3)] + '…').__repr__()
try:
v = v.__repr__()
except TypeError:
v = v.__class__.__repr__(v) # in case v is a type
if len(v) > max_len:
v = v[: max_len - 1] + '…'
return v
def sequence_like(v: Any) -> bool:
return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque))
def validate_field_name(bases: List[Type['BaseModel']], field_name: str) -> None:
"""
Ensure that the field's name does not shadow an existing attribute of the model.
"""
for base in bases:
if getattr(base, field_name, None):
raise NameError(
f'Field name "{field_name}" shadows a BaseModel attribute; '
f'use a different field name with "alias=\'{field_name}\'".'
)
def lenient_isinstance(o: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
try:
return isinstance(o, class_or_tuple) # type: ignore[arg-type]
except TypeError:
return False
def lenient_issubclass(cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
try:
return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
except TypeError:
if isinstance(cls, WithArgsTypes):
return False
raise # pragma: no cover
def in_ipython() -> bool:
"""
Check whether we're in an ipython environment, including jupyter notebooks.
"""
try:
eval('__IPYTHON__')
except NameError:
return False
else: # pragma: no cover
return True
def is_valid_identifier(identifier: str) -> bool:
"""
Checks that a string is a valid identifier and not a Python keyword.
:param identifier: The identifier to test.
:return: True if the identifier is valid.
"""
return identifier.isidentifier() and not keyword.iskeyword(identifier)
KeyType = TypeVar('KeyType')
def deep_update(mapping: Dict[KeyType, Any], *updating_mappings: Dict[KeyType, Any]) -> Dict[KeyType, Any]:
updated_mapping = mapping.copy()
for updating_mapping in updating_mappings:
for k, v in updating_mapping.items():
if k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict):
updated_mapping[k] = deep_update(updated_mapping[k], v)
else:
updated_mapping[k] = v
return updated_mapping
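A short illustrative sketch of deep_update: nested dicts are merged recursively, everything else is overwritten, and the input mapping is left untouched (import path assumed to be pydantic.utils):

from pydantic.utils import deep_update

base = {'db': {'host': 'localhost', 'port': 5432}, 'debug': False}
override = {'db': {'port': 5433}, 'debug': True}

assert deep_update(base, override) == {'db': {'host': 'localhost', 'port': 5433}, 'debug': True}
assert base['db']['port'] == 5432  # the original mapping is not modified in place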
def update_not_none(mapping: Dict[Any, Any], **update: Any) -> None:
mapping.update({k: v for k, v in update.items() if v is not None})
def almost_equal_floats(value_1: float, value_2: float, *, delta: float = 1e-8) -> bool:
"""
Return True if two floats are almost equal
"""
return abs(value_1 - value_2) <= delta
def generate_model_signature(
init: Callable[..., None], fields: Dict[str, 'ModelField'], config: Type['BaseConfig']
) -> 'Signature':
"""
Generate signature for model based on its fields
"""
from inspect import Parameter, Signature, signature
from .config import Extra
present_params = signature(init).parameters.values()
merged_params: Dict[str, Parameter] = {}
var_kw = None
use_var_kw = False
for param in islice(present_params, 1, None): # skip self arg
if param.kind is param.VAR_KEYWORD:
var_kw = param
continue
merged_params[param.name] = param
if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
allow_names = config.allow_population_by_field_name
for field_name, field in fields.items():
param_name = field.alias
if field_name in merged_params or param_name in merged_params:
continue
elif not is_valid_identifier(param_name):
if allow_names and is_valid_identifier(field_name):
param_name = field_name
else:
use_var_kw = True
continue
# TODO: replace annotation with actual expected types once #1055 solved
kwargs = {'default': field.default} if not field.required else {}
merged_params[param_name] = Parameter(
param_name, Parameter.KEYWORD_ONLY, annotation=field.annotation, **kwargs
)
if config.extra is Extra.allow:
use_var_kw = True
if var_kw and use_var_kw:
# Make sure the parameter for extra kwargs
# does not have the same name as a field
default_model_signature = [
('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD),
('data', Parameter.VAR_KEYWORD),
]
if [(p.name, p.kind) for p in present_params] == default_model_signature:
# if this is the standard model signature, use extra_data as the extra args name
var_kw_name = 'extra_data'
else:
# else start from var_kw
var_kw_name = var_kw.name
# generate a name that's definitely unique
while var_kw_name in fields:
var_kw_name += '_'
merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)
return Signature(parameters=list(merged_params.values()), return_annotation=None)
def get_model(obj: Union[Type['BaseModel'], Type['Dataclass']]) -> Type['BaseModel']:
from .main import BaseModel
try:
model_cls = obj.__pydantic_model__ # type: ignore
except AttributeError:
model_cls = obj
if not issubclass(model_cls, BaseModel):
raise TypeError('Unsupported type, must be either BaseModel or dataclass')
return model_cls
def to_camel(string: str) -> str:
return ''.join(word.capitalize() for word in string.split('_'))
def to_lower_camel(string: str) -> str:
if len(string) >= 1:
pascal_string = to_camel(string)
return pascal_string[0].lower() + pascal_string[1:]
return string.lower()
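A quick sketch of the two casing helpers above, often handed to a model's alias_generator in pydantic 1.x (import path assumed to be pydantic.utils):

from pydantic.utils import to_camel, to_lower_camel

assert to_camel('http_request_id') == 'HttpRequestId'
assert to_lower_camel('http_request_id') == 'httpRequestId'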
T = TypeVar('T')
def unique_list(
input_list: Union[List[T], Tuple[T, ...]],
*,
name_factory: Callable[[T], str] = str,
) -> List[T]:
"""
Make a list unique while maintaining order.
    If another element with the same name is encountered, it replaces the existing entry
    (e.g. a root validator overridden in a subclass).
"""
result: List[T] = []
result_names: List[str] = []
for v in input_list:
v_name = name_factory(v)
if v_name not in result_names:
result_names.append(v_name)
result.append(v)
else:
result[result_names.index(v_name)] = v
return result
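An illustrative sketch of unique_list: order is preserved and a later item with the same name replaces the earlier one in place (import path assumed to be pydantic.utils):

from pydantic.utils import unique_list

assert unique_list([3, 1, 3, 2]) == [3, 1, 2]
# with a name factory, the later duplicate replaces the earlier entry in place
assert unique_list(['a', 'B', 'A'], name_factory=str.lower) == ['A', 'B']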
class PyObjectStr(str):
"""
    String class where repr doesn't include quotes. Useful with Representation when you want to return a string
    representation of something that is valid (or pseudo-valid) python.
"""
def __repr__(self) -> str:
return str(self)
class Representation:
"""
Mixin to provide __str__, __repr__, and __pretty__ methods. See #884 for more details.
__pretty__ is used by [devtools](https://python-devtools.helpmanual.io/) to provide human readable representations
of objects.
"""
__slots__: Tuple[str, ...] = tuple()
def __repr_args__(self) -> 'ReprArgs':
"""
        Returns the attributes to show in __str__, __repr__, and __pretty__; this is generally overridden.
Can either return:
* name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]`
* or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]`
"""
attrs = ((s, getattr(self, s)) for s in self.__slots__)
return [(a, v) for a, v in attrs if v is not None]
def __repr_name__(self) -> str:
"""
Name of the instance's class, used in __repr__.
"""
return self.__class__.__name__
def __repr_str__(self, join_str: str) -> str:
return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())
def __pretty__(self, fmt: Callable[[Any], Any], **kwargs: Any) -> Generator[Any, None, None]:
"""
        Used by devtools (https://python-devtools.helpmanual.io/) to provide a human-readable representation of objects
"""
yield self.__repr_name__() + '('
yield 1
for name, value in self.__repr_args__():
if name is not None:
yield name + '='
yield fmt(value)
yield ','
yield 0
yield -1
yield ')'
def __str__(self) -> str:
return self.__repr_str__(' ')
def __repr__(self) -> str:
return f'{self.__repr_name__()}({self.__repr_str__(", ")})'
def __rich_repr__(self) -> 'RichReprResult':
"""Get fields for Rich library"""
for name, field_repr in self.__repr_args__():
if name is None:
yield field_repr
else:
yield name, field_repr
class GetterDict(Representation):
"""
    Hack to make objects smell just enough like dicts for validate_model.
    We can't inherit from Mapping[str, Any] because it upsets cython, so we have to implement all methods ourselves.
"""
__slots__ = ('_obj',)
def __init__(self, obj: Any):
self._obj = obj
def __getitem__(self, key: str) -> Any:
try:
return getattr(self._obj, key)
except AttributeError as e:
raise KeyError(key) from e
def get(self, key: Any, default: Any = None) -> Any:
return getattr(self._obj, key, default)
def extra_keys(self) -> Set[Any]:
"""
We don't want to get any other attributes of obj if the model didn't explicitly ask for them
"""
return set()
def keys(self) -> List[Any]:
"""
        Keys of the pseudo-dictionary; uses a list rather than a set so that ordering is maintained, as with python
        dictionaries.
"""
return list(self)
def values(self) -> List[Any]:
return [self[k] for k in self]
def items(self) -> Iterator[Tuple[str, Any]]:
for k in self:
yield k, self.get(k)
def __iter__(self) -> Iterator[str]:
for name in dir(self._obj):
if not name.startswith('_'):
yield name
def __len__(self) -> int:
return sum(1 for _ in self)
def __contains__(self, item: Any) -> bool:
return item in self.keys()
def __eq__(self, other: Any) -> bool:
return dict(self) == dict(other.items())
def __repr_args__(self) -> 'ReprArgs':
return [(None, dict(self))]
def __repr_name__(self) -> str:
return f'GetterDict[{display_as_type(self._obj)}]'
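A small sketch of GetterDict wrapping a plain object so it can be read like a mapping (this backs orm_mode / from_orm in pydantic 1.x; import path assumed to be pydantic.utils):

from pydantic.utils import GetterDict

class User:
    def __init__(self) -> None:
        self.id = 1
        self.name = 'alice'

wrapped = GetterDict(User())
assert wrapped['name'] == 'alice'
assert wrapped.get('missing', 'default') == 'default'
assert set(wrapped.keys()) >= {'id', 'name'}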
class ValueItems(Representation):
"""
Class for more convenient calculation of excluded or included fields on values.
"""
__slots__ = ('_items', '_type')
def __init__(self, value: Any, items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> None:
items = self._coerce_items(items)
if isinstance(value, (list, tuple)):
items = self._normalize_indexes(items, len(value))
self._items: 'MappingIntStrAny' = items
def is_excluded(self, item: Any) -> bool:
"""
Check if item is fully excluded.
:param item: key or index of a value
"""
return self.is_true(self._items.get(item))
def is_included(self, item: Any) -> bool:
"""
Check if value is contained in self._items
:param item: key or index of value
"""
return item in self._items
def for_element(self, e: 'IntStr') -> Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']]:
"""
:param e: key or index of element on value
        :return: raw values for the element if self._items is a dict and contains the needed element
"""
item = self._items.get(e)
return item if not self.is_true(item) else None
def _normalize_indexes(self, items: 'MappingIntStrAny', v_length: int) -> 'DictIntStrAny':
"""
        :param items: dict or set of indexes to normalize
        :param v_length: length of the sequence whose indexes are being normalized
>>> self._normalize_indexes({0: True, -2: True, -1: True}, 4)
{0: True, 2: True, 3: True}
>>> self._normalize_indexes({'__all__': True}, 4)
{0: True, 1: True, 2: True, 3: True}
"""
normalized_items: 'DictIntStrAny' = {}
all_items = None
for i, v in items.items():
if not (isinstance(v, Mapping) or isinstance(v, AbstractSet) or self.is_true(v)):
raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
if i == '__all__':
all_items = self._coerce_value(v)
continue
if not isinstance(i, int):
raise TypeError(
'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
'expected integer keys or keyword "__all__"'
)
normalized_i = v_length + i if i < 0 else i
normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i))
if not all_items:
return normalized_items
if self.is_true(all_items):
for i in range(v_length):
normalized_items.setdefault(i, ...)
return normalized_items
for i in range(v_length):
normalized_item = normalized_items.setdefault(i, {})
if not self.is_true(normalized_item):
normalized_items[i] = self.merge(all_items, normalized_item)
return normalized_items
@classmethod
def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any:
"""
Merge a ``base`` item with an ``override`` item.
Both ``base`` and ``override`` are converted to dictionaries if possible.
        Sets are converted to dictionaries with the set's entries as keys and
Ellipsis as values.
Each key-value pair existing in ``base`` is merged with ``override``,
while the rest of the key-value pairs are updated recursively with this function.
Merging takes place based on the "union" of keys if ``intersect`` is
set to ``False`` (default) and on the intersection of keys if
``intersect`` is set to ``True``.
"""
override = cls._coerce_value(override)
base = cls._coerce_value(base)
if override is None:
return base
if cls.is_true(base) or base is None:
return override
if cls.is_true(override):
return base if intersect else override
# intersection or union of keys while preserving ordering:
if intersect:
merge_keys = [k for k in base if k in override] + [k for k in override if k in base]
else:
merge_keys = list(base) + [k for k in override if k not in base]
merged: 'DictIntStrAny' = {}
for k in merge_keys:
merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect)
if merged_item is not None:
merged[k] = merged_item
return merged
@staticmethod
def _coerce_items(items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> 'MappingIntStrAny':
if isinstance(items, Mapping):
pass
elif isinstance(items, AbstractSet):
items = dict.fromkeys(items, ...)
else:
class_name = getattr(items, '__class__', '???')
assert_never(
items,
f'Unexpected type of exclude value {class_name}',
)
return items
@classmethod
def _coerce_value(cls, value: Any) -> Any:
if value is None or cls.is_true(value):
return value
return cls._coerce_items(value)
@staticmethod
def is_true(v: Any) -> bool:
return v is True or v is ...
def __repr_args__(self) -> 'ReprArgs':
return [(None, self._items)]
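A sketch of ValueItems.merge showing the set-to-dict coercion and the union/intersection modes described in its docstring (import path assumed to be pydantic.utils):

from pydantic.utils import ValueItems

# sets are coerced to {key: Ellipsis} mappings before merging
assert ValueItems.merge({'a': ...}, {'b': ...}) == {'a': ..., 'b': ...}
assert ValueItems.merge({'a': {'x'}}, {'a': {'y'}}) == {'a': {'x': ..., 'y': ...}}
# with intersect=True, only keys present on both sides survive
assert ValueItems.merge({'a': ..., 'b': ...}, {'a': ...}, intersect=True) == {'a': ...}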
class ClassAttribute:
"""
Hide class attribute from its instances
"""
__slots__ = (
'name',
'value',
)
def __init__(self, name: str, value: Any) -> None:
self.name = name
self.value = value
def __get__(self, instance: Any, owner: Type[Any]) -> None:
if instance is None:
return self.value
raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only')
path_types = {
'is_dir': 'directory',
'is_file': 'file',
'is_mount': 'mount point',
'is_symlink': 'symlink',
'is_block_device': 'block device',
'is_char_device': 'char device',
'is_fifo': 'FIFO',
'is_socket': 'socket',
}
def path_type(p: 'Path') -> str:
"""
Find out what sort of thing a path is.
"""
assert p.exists(), 'path does not exist'
for method, name in path_types.items():
if getattr(p, method)():
return name
return 'unknown'
Obj = TypeVar('Obj')
def smart_deepcopy(obj: Obj) -> Obj:
"""
Return type as is for immutable built-in types
Use obj.copy() for built-in empty collections
Use copy.deepcopy() for non-empty collections and unknown objects
"""
obj_type = obj.__class__
if obj_type in IMMUTABLE_NON_COLLECTIONS_TYPES:
        return obj  # fastest case: obj is immutable and not a collection, so it will not be copied anyway
try:
if not obj and obj_type in BUILTIN_COLLECTIONS:
# faster way for empty collections, no need to copy its members
return obj if obj_type is tuple else obj.copy() # type: ignore # tuple doesn't have copy method
except (TypeError, ValueError, RuntimeError):
# do we really dare to catch ALL errors? Seems a bit risky
pass
return deepcopy(obj) # slowest way when we actually might need a deepcopy
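A sketch of the three smart_deepcopy paths described in the docstring above, assuming the vendored module is importable as pydantic.utils:

from pydantic.utils import smart_deepcopy

s = 'immutable'
assert smart_deepcopy(s) is s  # immutable non-collection: returned unchanged

empty = []
assert smart_deepcopy(empty) is not empty  # empty builtin collection: cheap .copy()

nested = [{'a': 1}]
copied = smart_deepcopy(nested)
assert copied == nested and copied[0] is not nested[0]  # non-empty: full deepcopy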
def is_valid_field(name: str) -> bool:
if not name.startswith('_'):
return True
return ROOT_KEY == name
DUNDER_ATTRIBUTES = {
'__annotations__',
'__classcell__',
'__doc__',
'__module__',
'__orig_bases__',
'__orig_class__',
'__qualname__',
}
def is_valid_private_name(name: str) -> bool:
return not is_valid_field(name) and name not in DUNDER_ATTRIBUTES
_EMPTY = object()
def all_identical(left: Iterable[Any], right: Iterable[Any]) -> bool:
"""
Check that the items of `left` are the same objects as those in `right`.
>>> a, b = object(), object()
>>> all_identical([a, b, a], [a, b, a])
True
>>> all_identical([a, b, [a]], [a, b, [a]]) # new list object, while "equal" is not "identical"
False
"""
for left_item, right_item in zip_longest(left, right, fillvalue=_EMPTY):
if left_item is not right_item:
return False
return True
def assert_never(obj: NoReturn, msg: str) -> NoReturn:
"""
Helper to make sure that we have covered all possible types.
This is mostly useful for ``mypy``, docs:
https://mypy.readthedocs.io/en/latest/literal_types.html#exhaustive-checks
"""
raise TypeError(msg)
def get_unique_discriminator_alias(all_aliases: Collection[str], discriminator_key: str) -> str:
"""Validate that all aliases are the same and if that's the case return the alias"""
unique_aliases = set(all_aliases)
if len(unique_aliases) > 1:
raise ConfigError(
f'Aliases for discriminator {discriminator_key!r} must be the same (got {", ".join(sorted(all_aliases))})'
)
return unique_aliases.pop()
def get_discriminator_alias_and_values(tp: Any, discriminator_key: str) -> Tuple[str, Tuple[str, ...]]:
"""
Get alias and all valid values in the `Literal` type of the discriminator field
`tp` can be a `BaseModel` class or directly an `Annotated` `Union` of many.
"""
is_root_model = getattr(tp, '__custom_root_type__', False)
if get_origin(tp) is Annotated:
tp = get_args(tp)[0]
if hasattr(tp, '__pydantic_model__'):
tp = tp.__pydantic_model__
if is_union(get_origin(tp)):
alias, all_values = _get_union_alias_and_all_values(tp, discriminator_key)
return alias, tuple(v for values in all_values for v in values)
elif is_root_model:
union_type = tp.__fields__[ROOT_KEY].type_
alias, all_values = _get_union_alias_and_all_values(union_type, discriminator_key)
if len(set(all_values)) > 1:
raise ConfigError(
f'Field {discriminator_key!r} is not the same for all submodels of {display_as_type(tp)!r}'
)
return alias, all_values[0]
else:
try:
t_discriminator_type = tp.__fields__[discriminator_key].type_
except AttributeError as e:
raise TypeError(f'Type {tp.__name__!r} is not a valid `BaseModel` or `dataclass`') from e
except KeyError as e:
raise ConfigError(f'Model {tp.__name__!r} needs a discriminator field for key {discriminator_key!r}') from e
if not is_literal_type(t_discriminator_type):
raise ConfigError(f'Field {discriminator_key!r} of model {tp.__name__!r} needs to be a `Literal`')
return tp.__fields__[discriminator_key].alias, all_literal_values(t_discriminator_type)
def _get_union_alias_and_all_values(
union_type: Type[Any], discriminator_key: str
) -> Tuple[str, Tuple[Tuple[str, ...], ...]]:
zipped_aliases_values = [get_discriminator_alias_and_values(t, discriminator_key) for t in get_args(union_type)]
# unzip: [('alias_a',('v1', 'v2)), ('alias_b', ('v3',))] => [('alias_a', 'alias_b'), (('v1', 'v2'), ('v3',))]
all_aliases, all_values = zip(*zipped_aliases_values)
return get_unique_discriminator_alias(all_aliases, discriminator_key), all_values
KT = TypeVar('KT')
VT = TypeVar('VT')
if TYPE_CHECKING:
    # Annoyingly, inheriting from `MutableMapping` and `dict` breaks cython, hence this workaround
class LimitedDict(dict, MutableMapping[KT, VT]): # type: ignore[type-arg]
def __init__(self, size_limit: int = 1000):
...
else:
class LimitedDict(dict):
"""
Limit the size/length of a dict used for caching to avoid unlimited increase in memory usage.
Since the dict is ordered, and we always remove elements from the beginning, this is effectively a FIFO cache.
        Annoyingly, inheriting from `MutableMapping` breaks cython.
"""
def __init__(self, size_limit: int = 1000):
self.size_limit = size_limit
super().__init__()
def __setitem__(self, __key: Any, __value: Any) -> None:
super().__setitem__(__key, __value)
if len(self) > self.size_limit:
excess = len(self) - self.size_limit + self.size_limit // 10
to_remove = list(self.keys())[:excess]
for key in to_remove:
del self[key]
def __class_getitem__(cls, *args: Any) -> Any:
# to avoid errors with 3.7
pass
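A sketch of LimitedDict's FIFO eviction: once the size limit is exceeded, the oldest keys (the overshoot plus a tenth of the limit) are dropped. The import path assumes the vendored pydantic.utils module.

from pydantic.utils import LimitedDict

cache = LimitedDict(size_limit=10)
for i in range(11):
    cache[i] = i

# 11 > 10, so excess = 11 - 10 + 10 // 10 = 2 oldest keys are removed
assert len(cache) == 9
assert 0 not in cache and 1 not in cache and 10 in cache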
lib/pydantic/validators.py Normal file
@@ -0,0 +1,765 @@
import math
import re
from collections import OrderedDict, deque
from collections.abc import Hashable as CollectionsHashable
from datetime import date, datetime, time, timedelta
from decimal import Decimal, DecimalException
from enum import Enum, IntEnum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Deque,
Dict,
ForwardRef,
FrozenSet,
Generator,
Hashable,
List,
NamedTuple,
Pattern,
Set,
Tuple,
Type,
TypeVar,
Union,
)
from uuid import UUID
from . import errors
from .datetime_parse import parse_date, parse_datetime, parse_duration, parse_time
from .typing import (
AnyCallable,
all_literal_values,
display_as_type,
get_class,
is_callable_type,
is_literal_type,
is_namedtuple,
is_none_type,
is_typeddict,
)
from .utils import almost_equal_floats, lenient_issubclass, sequence_like
if TYPE_CHECKING:
from typing_extensions import Literal, TypedDict
from .config import BaseConfig
from .fields import ModelField
from .types import ConstrainedDecimal, ConstrainedFloat, ConstrainedInt
ConstrainedNumber = Union[ConstrainedDecimal, ConstrainedFloat, ConstrainedInt]
AnyOrderedDict = OrderedDict[Any, Any]
Number = Union[int, float, Decimal]
StrBytes = Union[str, bytes]
def str_validator(v: Any) -> Union[str]:
if isinstance(v, str):
if isinstance(v, Enum):
return v.value
else:
return v
elif isinstance(v, (float, int, Decimal)):
# is there anything else we want to add here? If you think so, create an issue.
return str(v)
elif isinstance(v, (bytes, bytearray)):
return v.decode()
else:
raise errors.StrError()
def strict_str_validator(v: Any) -> Union[str]:
if isinstance(v, str) and not isinstance(v, Enum):
return v
raise errors.StrError()
def bytes_validator(v: Any) -> Union[bytes]:
if isinstance(v, bytes):
return v
elif isinstance(v, bytearray):
return bytes(v)
elif isinstance(v, str):
return v.encode()
elif isinstance(v, (float, int, Decimal)):
return str(v).encode()
else:
raise errors.BytesError()
def strict_bytes_validator(v: Any) -> Union[bytes]:
if isinstance(v, bytes):
return v
elif isinstance(v, bytearray):
return bytes(v)
else:
raise errors.BytesError()
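A quick sketch of the loose str/bytes coercions above (import path assumed to be pydantic.validators):

from pydantic.validators import bytes_validator, str_validator

assert str_validator(3.5) == '3.5'
assert str_validator(b'abc') == 'abc'
assert bytes_validator('abc') == b'abc'
assert bytes_validator(bytearray(b'xy')) == b'xy'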
BOOL_FALSE = {0, '0', 'off', 'f', 'false', 'n', 'no'}
BOOL_TRUE = {1, '1', 'on', 't', 'true', 'y', 'yes'}
def bool_validator(v: Any) -> bool:
if v is True or v is False:
return v
if isinstance(v, bytes):
v = v.decode()
if isinstance(v, str):
v = v.lower()
try:
if v in BOOL_TRUE:
return True
if v in BOOL_FALSE:
return False
except TypeError:
raise errors.BoolError()
raise errors.BoolError()
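A sketch of bool_validator and the BOOL_TRUE / BOOL_FALSE sets above (import paths assumed; pydantic.errors is part of the same vendored package):

from pydantic import errors
from pydantic.validators import bool_validator

assert bool_validator('YES') is True    # strings are lower-cased before matching
assert bool_validator(b'off') is False  # bytes are decoded first
assert bool_validator(0) is False
try:
    bool_validator('definitely')
except errors.BoolError:
    pass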
# matches the default limit in cpython, see https://github.com/python/cpython/pull/96500
max_str_int = 4_300
def int_validator(v: Any) -> int:
if isinstance(v, int) and not (v is True or v is False):
return v
# see https://github.com/pydantic/pydantic/issues/1477 and in turn, https://github.com/python/cpython/issues/95778
# this check should be unnecessary once patch releases are out for 3.7, 3.8, 3.9 and 3.10
# but better to check here until then.
    # NOTICE: this does not fully protect the user from the DoS risk since the standard library JSON implementation
    # (and other std lib modules like xml) use `int()` and are likely called before this; the best workarounds are to
    # 1. update to the latest patch release of python once released, 2. use a different JSON library like ujson
if isinstance(v, (str, bytes, bytearray)) and len(v) > max_str_int:
raise errors.IntegerError()
try:
return int(v)
except (TypeError, ValueError, OverflowError):
raise errors.IntegerError()
def strict_int_validator(v: Any) -> int:
if isinstance(v, int) and not (v is True or v is False):
return v
raise errors.IntegerError()
def float_validator(v: Any) -> float:
if isinstance(v, float):
return v
try:
return float(v)
except (TypeError, ValueError):
raise errors.FloatError()
def strict_float_validator(v: Any) -> float:
if isinstance(v, float):
return v
raise errors.FloatError()
def float_finite_validator(v: 'Number', field: 'ModelField', config: 'BaseConfig') -> 'Number':
allow_inf_nan = getattr(field.type_, 'allow_inf_nan', None)
if allow_inf_nan is None:
allow_inf_nan = config.allow_inf_nan
if allow_inf_nan is False and (math.isnan(v) or math.isinf(v)):
raise errors.NumberNotFiniteError()
return v
def number_multiple_validator(v: 'Number', field: 'ModelField') -> 'Number':
field_type: ConstrainedNumber = field.type_
if field_type.multiple_of is not None:
mod = float(v) / float(field_type.multiple_of) % 1
if not almost_equal_floats(mod, 0.0) and not almost_equal_floats(mod, 1.0):
raise errors.NumberNotMultipleError(multiple_of=field_type.multiple_of)
return v
def number_size_validator(v: 'Number', field: 'ModelField') -> 'Number':
field_type: ConstrainedNumber = field.type_
if field_type.gt is not None and not v > field_type.gt:
raise errors.NumberNotGtError(limit_value=field_type.gt)
elif field_type.ge is not None and not v >= field_type.ge:
raise errors.NumberNotGeError(limit_value=field_type.ge)
if field_type.lt is not None and not v < field_type.lt:
raise errors.NumberNotLtError(limit_value=field_type.lt)
if field_type.le is not None and not v <= field_type.le:
raise errors.NumberNotLeError(limit_value=field_type.le)
return v
def constant_validator(v: 'Any', field: 'ModelField') -> 'Any':
"""Validate ``const`` fields.
The value provided for a ``const`` field must be equal to the default value
of the field. This is to support the keyword of the same name in JSON
Schema.
"""
if v != field.default:
raise errors.WrongConstantError(given=v, permitted=[field.default])
return v
def anystr_length_validator(v: 'StrBytes', config: 'BaseConfig') -> 'StrBytes':
v_len = len(v)
min_length = config.min_anystr_length
if v_len < min_length:
raise errors.AnyStrMinLengthError(limit_value=min_length)
max_length = config.max_anystr_length
if max_length is not None and v_len > max_length:
raise errors.AnyStrMaxLengthError(limit_value=max_length)
return v
def anystr_strip_whitespace(v: 'StrBytes') -> 'StrBytes':
return v.strip()
def anystr_upper(v: 'StrBytes') -> 'StrBytes':
return v.upper()
def anystr_lower(v: 'StrBytes') -> 'StrBytes':
return v.lower()
def ordered_dict_validator(v: Any) -> 'AnyOrderedDict':
if isinstance(v, OrderedDict):
return v
try:
return OrderedDict(v)
except (TypeError, ValueError):
raise errors.DictError()
def dict_validator(v: Any) -> Dict[Any, Any]:
if isinstance(v, dict):
return v
try:
return dict(v)
except (TypeError, ValueError):
raise errors.DictError()
def list_validator(v: Any) -> List[Any]:
if isinstance(v, list):
return v
elif sequence_like(v):
return list(v)
else:
raise errors.ListError()
def tuple_validator(v: Any) -> Tuple[Any, ...]:
if isinstance(v, tuple):
return v
elif sequence_like(v):
return tuple(v)
else:
raise errors.TupleError()
def set_validator(v: Any) -> Set[Any]:
if isinstance(v, set):
return v
elif sequence_like(v):
return set(v)
else:
raise errors.SetError()
def frozenset_validator(v: Any) -> FrozenSet[Any]:
if isinstance(v, frozenset):
return v
elif sequence_like(v):
return frozenset(v)
else:
raise errors.FrozenSetError()
def deque_validator(v: Any) -> Deque[Any]:
if isinstance(v, deque):
return v
elif sequence_like(v):
return deque(v)
else:
raise errors.DequeError()
def enum_member_validator(v: Any, field: 'ModelField', config: 'BaseConfig') -> Enum:
try:
enum_v = field.type_(v)
except ValueError:
# field.type_ should be an enum, so will be iterable
raise errors.EnumMemberError(enum_values=list(field.type_))
return enum_v.value if config.use_enum_values else enum_v
def uuid_validator(v: Any, field: 'ModelField') -> UUID:
try:
if isinstance(v, str):
v = UUID(v)
elif isinstance(v, (bytes, bytearray)):
try:
v = UUID(v.decode())
except ValueError:
# 16 bytes in big-endian order as the bytes argument fail
# the above check
v = UUID(bytes=v)
except ValueError:
raise errors.UUIDError()
if not isinstance(v, UUID):
raise errors.UUIDError()
required_version = getattr(field.type_, '_required_version', None)
if required_version and v.version != required_version:
raise errors.UUIDVersionError(required_version=required_version)
return v
def decimal_validator(v: Any) -> Decimal:
if isinstance(v, Decimal):
return v
elif isinstance(v, (bytes, bytearray)):
v = v.decode()
v = str(v).strip()
try:
v = Decimal(v)
except DecimalException:
raise errors.DecimalError()
if not v.is_finite():
raise errors.DecimalIsNotFiniteError()
return v
def hashable_validator(v: Any) -> Hashable:
if isinstance(v, Hashable):
return v
raise errors.HashableError()
def ip_v4_address_validator(v: Any) -> IPv4Address:
if isinstance(v, IPv4Address):
return v
try:
return IPv4Address(v)
except ValueError:
raise errors.IPv4AddressError()
def ip_v6_address_validator(v: Any) -> IPv6Address:
if isinstance(v, IPv6Address):
return v
try:
return IPv6Address(v)
except ValueError:
raise errors.IPv6AddressError()
def ip_v4_network_validator(v: Any) -> IPv4Network:
"""
    Assume IPv4Network is initialised with the default ``strict`` argument
See more:
https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network
"""
if isinstance(v, IPv4Network):
return v
try:
return IPv4Network(v)
except ValueError:
raise errors.IPv4NetworkError()
def ip_v6_network_validator(v: Any) -> IPv6Network:
"""
    Assume IPv6Network is initialised with the default ``strict`` argument
See more:
https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network
"""
if isinstance(v, IPv6Network):
return v
try:
return IPv6Network(v)
except ValueError:
raise errors.IPv6NetworkError()
def ip_v4_interface_validator(v: Any) -> IPv4Interface:
if isinstance(v, IPv4Interface):
return v
try:
return IPv4Interface(v)
except ValueError:
raise errors.IPv4InterfaceError()
def ip_v6_interface_validator(v: Any) -> IPv6Interface:
if isinstance(v, IPv6Interface):
return v
try:
return IPv6Interface(v)
except ValueError:
raise errors.IPv6InterfaceError()
def path_validator(v: Any) -> Path:
if isinstance(v, Path):
return v
try:
return Path(v)
except TypeError:
raise errors.PathError()
def path_exists_validator(v: Any) -> Path:
if not v.exists():
raise errors.PathNotExistsError(path=v)
return v
def callable_validator(v: Any) -> AnyCallable:
"""
Perform a simple check if the value is callable.
Note: complete matching of argument type hints and return types is not performed
"""
if callable(v):
return v
raise errors.CallableError(value=v)
def enum_validator(v: Any) -> Enum:
if isinstance(v, Enum):
return v
raise errors.EnumError(value=v)
def int_enum_validator(v: Any) -> IntEnum:
if isinstance(v, IntEnum):
return v
raise errors.IntEnumError(value=v)
def make_literal_validator(type_: Any) -> Callable[[Any], Any]:
permitted_choices = all_literal_values(type_)
    # To have an O(1) complexity and still return one of the values set inside the `Literal`,
# we create a dict with the set values (a set causes some problems with the way intersection works).
# In some cases the set value and checked value can indeed be different (see `test_literal_validator_str_enum`)
allowed_choices = {v: v for v in permitted_choices}
def literal_validator(v: Any) -> Any:
try:
return allowed_choices[v]
except KeyError:
raise errors.WrongConstantError(given=v, permitted=permitted_choices)
return literal_validator
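A sketch of make_literal_validator building an O(1) lookup validator for a Literal type (typing.Literal requires Python 3.8+; import paths assumed):

from typing import Literal

from pydantic import errors
from pydantic.validators import make_literal_validator

validator = make_literal_validator(Literal['red', 'green'])
assert validator('red') == 'red'
try:
    validator('blue')
except errors.WrongConstantError:
    pass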
def constr_length_validator(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes':
v_len = len(v)
min_length = field.type_.min_length if field.type_.min_length is not None else config.min_anystr_length
if v_len < min_length:
raise errors.AnyStrMinLengthError(limit_value=min_length)
max_length = field.type_.max_length if field.type_.max_length is not None else config.max_anystr_length
if max_length is not None and v_len > max_length:
raise errors.AnyStrMaxLengthError(limit_value=max_length)
return v
def constr_strip_whitespace(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes':
strip_whitespace = field.type_.strip_whitespace or config.anystr_strip_whitespace
if strip_whitespace:
v = v.strip()
return v
def constr_upper(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes':
upper = field.type_.to_upper or config.anystr_upper
if upper:
v = v.upper()
return v
def constr_lower(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes':
lower = field.type_.to_lower or config.anystr_lower
if lower:
v = v.lower()
return v
def validate_json(v: Any, config: 'BaseConfig') -> Any:
if v is None:
# pass None through to other validators
return v
try:
return config.json_loads(v) # type: ignore
except ValueError:
raise errors.JsonError()
except TypeError:
raise errors.JsonTypeError()
T = TypeVar('T')
def make_arbitrary_type_validator(type_: Type[T]) -> Callable[[T], T]:
def arbitrary_type_validator(v: Any) -> T:
if isinstance(v, type_):
return v
raise errors.ArbitraryTypeError(expected_arbitrary_type=type_)
return arbitrary_type_validator
def make_class_validator(type_: Type[T]) -> Callable[[Any], Type[T]]:
def class_validator(v: Any) -> Type[T]:
if lenient_issubclass(v, type_):
return v
raise errors.SubclassError(expected_class=type_)
return class_validator
def any_class_validator(v: Any) -> Type[T]:
if isinstance(v, type):
return v
raise errors.ClassError()
def none_validator(v: Any) -> 'Literal[None]':
if v is None:
return v
raise errors.NotNoneError()
def pattern_validator(v: Any) -> Pattern[str]:
if isinstance(v, Pattern):
return v
str_value = str_validator(v)
try:
return re.compile(str_value)
except re.error:
raise errors.PatternError()
NamedTupleT = TypeVar('NamedTupleT', bound=NamedTuple)
def make_namedtuple_validator(
namedtuple_cls: Type[NamedTupleT], config: Type['BaseConfig']
) -> Callable[[Tuple[Any, ...]], NamedTupleT]:
from .annotated_types import create_model_from_namedtuple
NamedTupleModel = create_model_from_namedtuple(
namedtuple_cls,
__config__=config,
__module__=namedtuple_cls.__module__,
)
namedtuple_cls.__pydantic_model__ = NamedTupleModel # type: ignore[attr-defined]
def namedtuple_validator(values: Tuple[Any, ...]) -> NamedTupleT:
annotations = NamedTupleModel.__annotations__
if len(values) > len(annotations):
raise errors.ListMaxLengthError(limit_value=len(annotations))
dict_values: Dict[str, Any] = dict(zip(annotations, values))
validated_dict_values: Dict[str, Any] = dict(NamedTupleModel(**dict_values))
return namedtuple_cls(**validated_dict_values)
return namedtuple_validator
def make_typeddict_validator(
typeddict_cls: Type['TypedDict'], config: Type['BaseConfig'] # type: ignore[valid-type]
) -> Callable[[Any], Dict[str, Any]]:
from .annotated_types import create_model_from_typeddict
TypedDictModel = create_model_from_typeddict(
typeddict_cls,
__config__=config,
__module__=typeddict_cls.__module__,
)
typeddict_cls.__pydantic_model__ = TypedDictModel # type: ignore[attr-defined]
def typeddict_validator(values: 'TypedDict') -> Dict[str, Any]: # type: ignore[valid-type]
return TypedDictModel.parse_obj(values).dict(exclude_unset=True)
return typeddict_validator
class IfConfig:
def __init__(self, validator: AnyCallable, *config_attr_names: str, ignored_value: Any = False) -> None:
self.validator = validator
self.config_attr_names = config_attr_names
self.ignored_value = ignored_value
def check(self, config: Type['BaseConfig']) -> bool:
return any(getattr(config, name) not in {None, self.ignored_value} for name in self.config_attr_names)
# order is important here, for example: bool is a subclass of int so it has to come first; the same applies to
# datetime before date, IPv4Interface before IPv4Address, etc
_VALIDATORS: List[Tuple[Type[Any], List[Any]]] = [
(IntEnum, [int_validator, enum_member_validator]),
(Enum, [enum_member_validator]),
(
str,
[
str_validator,
IfConfig(anystr_strip_whitespace, 'anystr_strip_whitespace'),
IfConfig(anystr_upper, 'anystr_upper'),
IfConfig(anystr_lower, 'anystr_lower'),
IfConfig(anystr_length_validator, 'min_anystr_length', 'max_anystr_length'),
],
),
(
bytes,
[
bytes_validator,
IfConfig(anystr_strip_whitespace, 'anystr_strip_whitespace'),
IfConfig(anystr_upper, 'anystr_upper'),
IfConfig(anystr_lower, 'anystr_lower'),
IfConfig(anystr_length_validator, 'min_anystr_length', 'max_anystr_length'),
],
),
(bool, [bool_validator]),
(int, [int_validator]),
(float, [float_validator, IfConfig(float_finite_validator, 'allow_inf_nan', ignored_value=True)]),
(Path, [path_validator]),
(datetime, [parse_datetime]),
(date, [parse_date]),
(time, [parse_time]),
(timedelta, [parse_duration]),
(OrderedDict, [ordered_dict_validator]),
(dict, [dict_validator]),
(list, [list_validator]),
(tuple, [tuple_validator]),
(set, [set_validator]),
(frozenset, [frozenset_validator]),
(deque, [deque_validator]),
(UUID, [uuid_validator]),
(Decimal, [decimal_validator]),
(IPv4Interface, [ip_v4_interface_validator]),
(IPv6Interface, [ip_v6_interface_validator]),
(IPv4Address, [ip_v4_address_validator]),
(IPv6Address, [ip_v6_address_validator]),
(IPv4Network, [ip_v4_network_validator]),
(IPv6Network, [ip_v6_network_validator]),
]
def find_validators( # noqa: C901 (ignore complexity)
type_: Type[Any], config: Type['BaseConfig']
) -> Generator[AnyCallable, None, None]:
from .dataclasses import is_builtin_dataclass, make_dataclass_validator
if type_ is Any or type_ is object:
return
type_type = type_.__class__
if type_type == ForwardRef or type_type == TypeVar:
return
if is_none_type(type_):
yield none_validator
return
if type_ is Pattern or type_ is re.Pattern:
yield pattern_validator
return
if type_ is Hashable or type_ is CollectionsHashable:
yield hashable_validator
return
if is_callable_type(type_):
yield callable_validator
return
if is_literal_type(type_):
yield make_literal_validator(type_)
return
if is_builtin_dataclass(type_):
yield from make_dataclass_validator(type_, config)
return
if type_ is Enum:
yield enum_validator
return
if type_ is IntEnum:
yield int_enum_validator
return
if is_namedtuple(type_):
yield tuple_validator
yield make_namedtuple_validator(type_, config)
return
if is_typeddict(type_):
yield make_typeddict_validator(type_, config)
return
class_ = get_class(type_)
if class_ is not None:
if class_ is not Any and isinstance(class_, type):
yield make_class_validator(class_)
else:
yield any_class_validator
return
for val_type, validators in _VALIDATORS:
try:
if issubclass(type_, val_type):
for v in validators:
if isinstance(v, IfConfig):
if v.check(config):
yield v.validator
else:
yield v
return
except TypeError:
raise RuntimeError(f'error checking inheritance of {type_!r} (type: {display_as_type(type_)})')
if config.arbitrary_types_allowed:
yield make_arbitrary_type_validator(type_)
else:
raise RuntimeError(f'no validator found for {type_}, see `arbitrary_types_allowed` in Config')
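A sketch of how find_validators resolves the validator chain for a plain type, using the default BaseConfig (import paths assumed for the vendored pydantic 1.10 package):

from pydantic import BaseConfig
from pydantic.validators import find_validators, int_validator

assert list(find_validators(int, BaseConfig)) == [int_validator]
assert int_validator('42') == 42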
lib/pydantic/version.py Normal file
@@ -0,0 +1,38 @@
__all__ = 'compiled', 'VERSION', 'version_info'
VERSION = '1.10.2'
try:
import cython # type: ignore
except ImportError:
compiled: bool = False
else: # pragma: no cover
try:
compiled = cython.compiled
except AttributeError:
compiled = False
def version_info() -> str:
import platform
import sys
from importlib import import_module
from pathlib import Path
optional_deps = []
for p in ('devtools', 'dotenv', 'email-validator', 'typing-extensions'):
try:
import_module(p.replace('-', '_'))
except ImportError:
continue
optional_deps.append(p)
info = {
'pydantic version': VERSION,
'pydantic compiled': compiled,
'install path': Path(__file__).resolve().parent,
'python version': sys.version,
'platform': platform.platform(),
'optional deps. installed': optional_deps,
}
return '\n'.join('{:>30} {}'.format(k + ':', str(v).replace('\n', ' ')) for k, v in info.items())
File diff suppressed because it is too large
@@ -8,7 +8,7 @@ beautifulsoup4==4.11.1
 bleach==5.0.1
 certifi==2022.9.24
 cheroot==8.6.0
-cherrypy==18.6.1
+cherrypy==18.8.0
 cloudinary==1.29.0
 distro==1.7.0
 dnspython==2.2.1