Update cherrypy==18.9.0

This commit is contained in:
JonnyWong16 2024-03-24 17:55:12 -07:00
parent 2fc618c01f
commit 51196a7fb1
No known key found for this signature in database
GPG key ID: B1F1F9807184697A
137 changed files with 44442 additions and 11582 deletions

View file

@ -0,0 +1,396 @@
import math
import sys
from dataclasses import dataclass
from datetime import timezone
from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, SupportsFloat, SupportsIndex, TypeVar, Union
if sys.version_info < (3, 8):
from typing_extensions import Protocol, runtime_checkable
else:
from typing import Protocol, runtime_checkable
if sys.version_info < (3, 9):
from typing_extensions import Annotated, Literal
else:
from typing import Annotated, Literal
if sys.version_info < (3, 10):
EllipsisType = type(Ellipsis)
KW_ONLY = {}
SLOTS = {}
else:
from types import EllipsisType
KW_ONLY = {"kw_only": True}
SLOTS = {"slots": True}
__all__ = (
'BaseMetadata',
'GroupedMetadata',
'Gt',
'Ge',
'Lt',
'Le',
'Interval',
'MultipleOf',
'MinLen',
'MaxLen',
'Len',
'Timezone',
'Predicate',
'LowerCase',
'UpperCase',
'IsDigits',
'IsFinite',
'IsNotFinite',
'IsNan',
'IsNotNan',
'IsInfinite',
'IsNotInfinite',
'doc',
'DocInfo',
'__version__',
)
__version__ = '0.6.0'
T = TypeVar('T')
# arguments that start with __ are considered
# positional only
# see https://peps.python.org/pep-0484/#positional-only-arguments
class SupportsGt(Protocol):
def __gt__(self: T, __other: T) -> bool:
...
class SupportsGe(Protocol):
def __ge__(self: T, __other: T) -> bool:
...
class SupportsLt(Protocol):
def __lt__(self: T, __other: T) -> bool:
...
class SupportsLe(Protocol):
def __le__(self: T, __other: T) -> bool:
...
class SupportsMod(Protocol):
def __mod__(self: T, __other: T) -> T:
...
class SupportsDiv(Protocol):
def __div__(self: T, __other: T) -> T:
...
class BaseMetadata:
"""Base class for all metadata.
This exists mainly so that implementers
can do `isinstance(..., BaseMetadata)` while traversing field annotations.
"""
__slots__ = ()
@dataclass(frozen=True, **SLOTS)
class Gt(BaseMetadata):
"""Gt(gt=x) implies that the value must be greater than x.
It can be used with any type that supports the ``>`` operator,
including numbers, dates and times, strings, sets, and so on.
"""
gt: SupportsGt
@dataclass(frozen=True, **SLOTS)
class Ge(BaseMetadata):
"""Ge(ge=x) implies that the value must be greater than or equal to x.
It can be used with any type that supports the ``>=`` operator,
including numbers, dates and times, strings, sets, and so on.
"""
ge: SupportsGe
@dataclass(frozen=True, **SLOTS)
class Lt(BaseMetadata):
"""Lt(lt=x) implies that the value must be less than x.
It can be used with any type that supports the ``<`` operator,
including numbers, dates and times, strings, sets, and so on.
"""
lt: SupportsLt
@dataclass(frozen=True, **SLOTS)
class Le(BaseMetadata):
"""Le(le=x) implies that the value must be less than or equal to x.
It can be used with any type that supports the ``<=`` operator,
including numbers, dates and times, strings, sets, and so on.
"""
le: SupportsLe
@runtime_checkable
class GroupedMetadata(Protocol):
"""A grouping of multiple BaseMetadata objects.
`GroupedMetadata` on its own is not metadata and has no meaning.
All of the constraints and metadata it groups should be fully expressible
in terms of the `BaseMetadata`'s returned by `GroupedMetadata.__iter__()`.
Concrete implementations should override `GroupedMetadata.__iter__()`
to add their own metadata.
For example:
>>> @dataclass
>>> class Field(GroupedMetadata):
>>> gt: float | None = None
>>> description: str | None = None
...
>>> def __iter__(self) -> Iterable[BaseMetadata]:
>>> if self.gt is not None:
>>> yield Gt(self.gt)
>>> if self.description is not None:
>>> yield Description(self.description)
Also see the implementation of `Interval` below for an example.
Parsers should recognize this and unpack it so that it can be used
both with and without unpacking:
- `Annotated[int, Field(...)]` (parser must unpack Field)
- `Annotated[int, *Field(...)]` (PEP-646)
""" # noqa: trailing-whitespace
@property
def __is_annotated_types_grouped_metadata__(self) -> Literal[True]:
return True
def __iter__(self) -> Iterator[BaseMetadata]:
...
if not TYPE_CHECKING:
__slots__ = () # allow subclasses to use slots
def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
# Basic ABC like functionality without the complexity of an ABC
super().__init_subclass__(*args, **kwargs)
if cls.__iter__ is GroupedMetadata.__iter__:
raise TypeError("Can't subclass GroupedMetadata without implementing __iter__")
def __iter__(self) -> Iterator[BaseMetadata]: # noqa: F811
raise NotImplementedError # more helpful than "None has no attribute..." type errors
@dataclass(frozen=True, **KW_ONLY, **SLOTS)
class Interval(GroupedMetadata):
"""Interval can express inclusive or exclusive bounds with a single object.
It accepts keyword arguments ``gt``, ``ge``, ``lt``, and/or ``le``, which
are interpreted the same way as the single-bound constraints.
"""
gt: Union[SupportsGt, None] = None
ge: Union[SupportsGe, None] = None
lt: Union[SupportsLt, None] = None
le: Union[SupportsLe, None] = None
def __iter__(self) -> Iterator[BaseMetadata]:
"""Unpack an Interval into zero or more single-bounds."""
if self.gt is not None:
yield Gt(self.gt)
if self.ge is not None:
yield Ge(self.ge)
if self.lt is not None:
yield Lt(self.lt)
if self.le is not None:
yield Le(self.le)
@dataclass(frozen=True, **SLOTS)
class MultipleOf(BaseMetadata):
"""MultipleOf(multiple_of=x) might be interpreted in two ways:
1. Python semantics, implying ``value % multiple_of == 0``, or
2. JSONschema semantics, where ``int(value / multiple_of) == value / multiple_of``
We encourage users to be aware of these two common interpretations,
and libraries to carefully document which they implement.
"""
multiple_of: Union[SupportsDiv, SupportsMod]
@dataclass(frozen=True, **SLOTS)
class MinLen(BaseMetadata):
"""
MinLen() implies minimum inclusive length,
e.g. ``len(value) >= min_length``.
"""
min_length: Annotated[int, Ge(0)]
@dataclass(frozen=True, **SLOTS)
class MaxLen(BaseMetadata):
"""
MaxLen() implies maximum inclusive length,
e.g. ``len(value) <= max_length``.
"""
max_length: Annotated[int, Ge(0)]
@dataclass(frozen=True, **SLOTS)
class Len(GroupedMetadata):
"""
Len() implies that ``min_length <= len(value) <= max_length``.
Upper bound may be omitted or ``None`` to indicate no upper length bound.
"""
min_length: Annotated[int, Ge(0)] = 0
max_length: Optional[Annotated[int, Ge(0)]] = None
def __iter__(self) -> Iterator[BaseMetadata]:
"""Unpack a Len into zone or more single-bounds."""
if self.min_length > 0:
yield MinLen(self.min_length)
if self.max_length is not None:
yield MaxLen(self.max_length)
@dataclass(frozen=True, **SLOTS)
class Timezone(BaseMetadata):
"""Timezone(tz=...) requires a datetime to be aware (or ``tz=None``, naive).
``Annotated[datetime, Timezone(None)]`` must be a naive datetime.
``Timezone(...)`` (the ellipsis literal) expresses that the datetime must be
tz-aware but any timezone is allowed.
You may also pass a specific timezone string or timezone object such as
``Timezone(timezone.utc)`` or ``Timezone("Africa/Abidjan")`` to express that
you only allow a specific timezone, though we note that this is often
a symptom of poor design.
"""
tz: Union[str, timezone, EllipsisType, None]
@dataclass(frozen=True, **SLOTS)
class Predicate(BaseMetadata):
"""``Predicate(func: Callable)`` implies `func(value)` is truthy for valid values.
Users should prefer statically inspectable metadata, but if you need the full
power and flexibility of arbitrary runtime predicates... here it is.
We provide a few predefined predicates for common string constraints:
``IsLower = Predicate(str.islower)``, ``IsUpper = Predicate(str.isupper)``, and
``IsDigit = Predicate(str.isdigit)``. Users are encouraged to use methods which
can be given special handling, and avoid indirection like ``lambda s: s.lower()``.
Some libraries might have special logic to handle certain predicates, e.g. by
checking for `str.isdigit` and using its presence to both call custom logic to
enforce digit-only strings, and customise some generated external schema.
We do not specify what behaviour should be expected for predicates that raise
an exception. For example `Annotated[int, Predicate(str.isdigit)]` might silently
skip invalid constraints, or statically raise an error; or it might try calling it
and then propagate or discard the resulting exception.
"""
func: Callable[[Any], bool]
@dataclass
class Not:
func: Callable[[Any], bool]
def __call__(self, __v: Any) -> bool:
return not self.func(__v)
_StrType = TypeVar("_StrType", bound=str)
LowerCase = Annotated[_StrType, Predicate(str.islower)]
"""
Return True if the string is a lowercase string, False otherwise.
A string is lowercase if all cased characters in the string are lowercase and there is at least one cased character in the string.
""" # noqa: E501
UpperCase = Annotated[_StrType, Predicate(str.isupper)]
"""
Return True if the string is an uppercase string, False otherwise.
A string is uppercase if all cased characters in the string are uppercase and there is at least one cased character in the string.
""" # noqa: E501
IsDigits = Annotated[_StrType, Predicate(str.isdigit)]
"""
Return True if the string is a digit string, False otherwise.
A string is a digit string if all characters in the string are digits and there is at least one character in the string.
""" # noqa: E501
IsAscii = Annotated[_StrType, Predicate(str.isascii)]
"""
Return True if all characters in the string are ASCII, False otherwise.
ASCII characters have code points in the range U+0000-U+007F. Empty string is ASCII too.
"""
_NumericType = TypeVar('_NumericType', bound=Union[SupportsFloat, SupportsIndex])
IsFinite = Annotated[_NumericType, Predicate(math.isfinite)]
"""Return True if x is neither an infinity nor a NaN, and False otherwise."""
IsNotFinite = Annotated[_NumericType, Predicate(Not(math.isfinite))]
"""Return True if x is one of infinity or NaN, and False otherwise"""
IsNan = Annotated[_NumericType, Predicate(math.isnan)]
"""Return True if x is a NaN (not a number), and False otherwise."""
IsNotNan = Annotated[_NumericType, Predicate(Not(math.isnan))]
"""Return True if x is anything but NaN (not a number), and False otherwise."""
IsInfinite = Annotated[_NumericType, Predicate(math.isinf)]
"""Return True if x is a positive or negative infinity, and False otherwise."""
IsNotInfinite = Annotated[_NumericType, Predicate(Not(math.isinf))]
"""Return True if x is neither a positive or negative infinity, and False otherwise."""
try:
from typing_extensions import DocInfo, doc # type: ignore [attr-defined]
except ImportError:
@dataclass(frozen=True, **SLOTS)
class DocInfo: # type: ignore [no-redef]
""" "
The return value of doc(), mainly to be used by tools that want to extract the
Annotated documentation at runtime.
"""
documentation: str
"""The documentation string passed to doc()."""
def doc(
documentation: str,
) -> DocInfo:
"""
Add documentation to a type annotation inside of Annotated.
For example:
>>> def hi(name: Annotated[int, doc("The name of the user")]) -> None: ...
"""
return DocInfo(documentation)
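
The module above only declares constraint metadata; enforcement is left to consuming libraries. As a rough, hypothetical sketch (not part of this file), a consumer could unpack Annotated metadata, expand any GroupedMetadata, and test the single-bound constraints it cares about:

import annotated_types as at
from typing import Annotated, get_args

def satisfies(annotation, value) -> bool:
    # Collect the metadata attached to Annotated[...]; grouped metadata such as
    # Interval or Len is expanded into its single-bound constituents.
    _, *metadata = get_args(annotation)
    for meta in metadata:
        if isinstance(meta, at.GroupedMetadata):
            metadata.extend(meta)  # list iteration picks up appended items
            continue
        if isinstance(meta, at.Gt) and not value > meta.gt:
            return False
        if isinstance(meta, at.Lt) and not value < meta.lt:
            return False
        if isinstance(meta, at.MinLen) and len(value) < meta.min_length:
            return False
        # ...remaining BaseMetadata types would be handled the same way.
    return True

assert satisfies(Annotated[int, at.Gt(4)], 5)
assert not satisfies(Annotated[int, at.Interval(gt=4, lt=10)], 12)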

View file

@ -0,0 +1,147 @@
import math
import sys
from datetime import date, datetime, timedelta, timezone
from decimal import Decimal
from typing import Any, Dict, Iterable, Iterator, List, NamedTuple, Set, Tuple
if sys.version_info < (3, 9):
from typing_extensions import Annotated
else:
from typing import Annotated
import annotated_types as at
class Case(NamedTuple):
"""
A test case for `annotated_types`.
"""
annotation: Any
valid_cases: Iterable[Any]
invalid_cases: Iterable[Any]
def cases() -> Iterable[Case]:
# Gt, Ge, Lt, Le
yield Case(Annotated[int, at.Gt(4)], (5, 6, 1000), (4, 0, -1))
yield Case(Annotated[float, at.Gt(0.5)], (0.6, 0.7, 0.8, 0.9), (0.5, 0.0, -0.1))
yield Case(
Annotated[datetime, at.Gt(datetime(2000, 1, 1))],
[datetime(2000, 1, 2), datetime(2000, 1, 3)],
[datetime(2000, 1, 1), datetime(1999, 12, 31)],
)
yield Case(
Annotated[datetime, at.Gt(date(2000, 1, 1))],
[date(2000, 1, 2), date(2000, 1, 3)],
[date(2000, 1, 1), date(1999, 12, 31)],
)
yield Case(
Annotated[datetime, at.Gt(Decimal('1.123'))],
[Decimal('1.1231'), Decimal('123')],
[Decimal('1.123'), Decimal('0')],
)
yield Case(Annotated[int, at.Ge(4)], (4, 5, 6, 1000, 4), (0, -1))
yield Case(Annotated[float, at.Ge(0.5)], (0.5, 0.6, 0.7, 0.8, 0.9), (0.4, 0.0, -0.1))
yield Case(
Annotated[datetime, at.Ge(datetime(2000, 1, 1))],
[datetime(2000, 1, 2), datetime(2000, 1, 3)],
[datetime(1998, 1, 1), datetime(1999, 12, 31)],
)
yield Case(Annotated[int, at.Lt(4)], (0, -1), (4, 5, 6, 1000, 4))
yield Case(Annotated[float, at.Lt(0.5)], (0.4, 0.0, -0.1), (0.5, 0.6, 0.7, 0.8, 0.9))
yield Case(
Annotated[datetime, at.Lt(datetime(2000, 1, 1))],
[datetime(1999, 12, 31), datetime(1999, 12, 31)],
[datetime(2000, 1, 2), datetime(2000, 1, 3)],
)
yield Case(Annotated[int, at.Le(4)], (4, 0, -1), (5, 6, 1000))
yield Case(Annotated[float, at.Le(0.5)], (0.5, 0.0, -0.1), (0.6, 0.7, 0.8, 0.9))
yield Case(
Annotated[datetime, at.Le(datetime(2000, 1, 1))],
[datetime(2000, 1, 1), datetime(1999, 12, 31)],
[datetime(2000, 1, 2), datetime(2000, 1, 3)],
)
# Interval
yield Case(Annotated[int, at.Interval(gt=4)], (5, 6, 1000), (4, 0, -1))
yield Case(Annotated[int, at.Interval(gt=4, lt=10)], (5, 6), (4, 10, 1000, 0, -1))
yield Case(Annotated[float, at.Interval(ge=0.5, le=1)], (0.5, 0.9, 1), (0.49, 1.1))
yield Case(
Annotated[datetime, at.Interval(gt=datetime(2000, 1, 1), le=datetime(2000, 1, 3))],
[datetime(2000, 1, 2), datetime(2000, 1, 3)],
[datetime(2000, 1, 1), datetime(2000, 1, 4)],
)
yield Case(Annotated[int, at.MultipleOf(multiple_of=3)], (0, 3, 9), (1, 2, 4))
yield Case(Annotated[float, at.MultipleOf(multiple_of=0.5)], (0, 0.5, 1, 1.5), (0.4, 1.1))
# lengths
yield Case(Annotated[str, at.MinLen(3)], ('123', '1234', 'x' * 10), ('', '1', '12'))
yield Case(Annotated[str, at.Len(3)], ('123', '1234', 'x' * 10), ('', '1', '12'))
yield Case(Annotated[List[int], at.MinLen(3)], ([1, 2, 3], [1, 2, 3, 4], [1] * 10), ([], [1], [1, 2]))
yield Case(Annotated[List[int], at.Len(3)], ([1, 2, 3], [1, 2, 3, 4], [1] * 10), ([], [1], [1, 2]))
yield Case(Annotated[str, at.MaxLen(4)], ('', '1234'), ('12345', 'x' * 10))
yield Case(Annotated[str, at.Len(0, 4)], ('', '1234'), ('12345', 'x' * 10))
yield Case(Annotated[List[str], at.MaxLen(4)], ([], ['a', 'bcdef'], ['a', 'b', 'c']), (['a'] * 5, ['b'] * 10))
yield Case(Annotated[List[str], at.Len(0, 4)], ([], ['a', 'bcdef'], ['a', 'b', 'c']), (['a'] * 5, ['b'] * 10))
yield Case(Annotated[str, at.Len(3, 5)], ('123', '12345'), ('', '1', '12', '123456', 'x' * 10))
yield Case(Annotated[str, at.Len(3, 3)], ('123',), ('12', '1234'))
yield Case(Annotated[Dict[int, int], at.Len(2, 3)], [{1: 1, 2: 2}], [{}, {1: 1}, {1: 1, 2: 2, 3: 3, 4: 4}])
yield Case(Annotated[Set[int], at.Len(2, 3)], ({1, 2}, {1, 2, 3}), (set(), {1}, {1, 2, 3, 4}))
yield Case(Annotated[Tuple[int, ...], at.Len(2, 3)], ((1, 2), (1, 2, 3)), ((), (1,), (1, 2, 3, 4)))
# Timezone
yield Case(
Annotated[datetime, at.Timezone(None)], [datetime(2000, 1, 1)], [datetime(2000, 1, 1, tzinfo=timezone.utc)]
)
yield Case(
Annotated[datetime, at.Timezone(...)], [datetime(2000, 1, 1, tzinfo=timezone.utc)], [datetime(2000, 1, 1)]
)
yield Case(
Annotated[datetime, at.Timezone(timezone.utc)],
[datetime(2000, 1, 1, tzinfo=timezone.utc)],
[datetime(2000, 1, 1), datetime(2000, 1, 1, tzinfo=timezone(timedelta(hours=6)))],
)
yield Case(
Annotated[datetime, at.Timezone('Europe/London')],
[datetime(2000, 1, 1, tzinfo=timezone(timedelta(0), name='Europe/London'))],
[datetime(2000, 1, 1), datetime(2000, 1, 1, tzinfo=timezone(timedelta(hours=6)))],
)
# predicate types
yield Case(at.LowerCase[str], ['abc', 'foobar'], ['', 'A', 'Boom'])
yield Case(at.UpperCase[str], ['ABC', 'DEFO'], ['', 'a', 'abc', 'AbC'])
yield Case(at.IsDigits[str], ['123'], ['', 'ab', 'a1b2'])
yield Case(at.IsAscii[str], ['123', 'foo bar'], ['£100', '😊', 'whatever 👀'])
yield Case(Annotated[int, at.Predicate(lambda x: x % 2 == 0)], [0, 2, 4], [1, 3, 5])
yield Case(at.IsFinite[float], [1.23], [math.nan, math.inf, -math.inf])
yield Case(at.IsNotFinite[float], [math.nan, math.inf], [1.23])
yield Case(at.IsNan[float], [math.nan], [1.23, math.inf])
yield Case(at.IsNotNan[float], [1.23, math.inf], [math.nan])
yield Case(at.IsInfinite[float], [math.inf], [math.nan, 1.23])
yield Case(at.IsNotInfinite[float], [math.nan, 1.23], [math.inf])
# check stacked predicates
yield Case(at.IsInfinite[Annotated[float, at.Predicate(lambda x: x > 0)]], [math.inf], [-math.inf, 1.23, math.nan])
# doc
yield Case(Annotated[int, at.doc("A number")], [1, 2], [])
# custom GroupedMetadata
class MyCustomGroupedMetadata(at.GroupedMetadata):
def __iter__(self) -> Iterator[at.Predicate]:
yield at.Predicate(lambda x: float(x).is_integer())
yield Case(Annotated[float, MyCustomGroupedMetadata()], [0, 2.0], [0.01, 1.5])
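
A downstream test suite would typically parametrize over cases(); one hypothetical consumer (the validate() helper below is assumed, not provided anywhere in this commit) might look like:

import pytest
from annotated_types.test_cases import Case, cases  # import path assumed to match upstream

@pytest.mark.parametrize('case', cases())
def test_constraints(case: Case) -> None:
    for value in case.valid_cases:
        assert validate(case.annotation, value)       # validate() is hypothetical
    for value in case.invalid_cases:
        assert not validate(case.annotation, value)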

View file

@ -20,7 +20,7 @@ from functools import wraps
from inspect import signature from inspect import signature
def _launch_forever_coro(coro, args, kwargs, loop): async def _run_forever_coro(coro, args, kwargs, loop):
''' '''
This helper function launches an async main function that was tagged with This helper function launches an async main function that was tagged with
forever=True. There are two possibilities: forever=True. There are two possibilities:
@ -48,7 +48,7 @@ def _launch_forever_coro(coro, args, kwargs, loop):
# forever=True feature from autoasync at some point in the future. # forever=True feature from autoasync at some point in the future.
thing = coro(*args, **kwargs) thing = coro(*args, **kwargs)
if iscoroutine(thing): if iscoroutine(thing):
loop.create_task(thing) await thing
def autoasync(coro=None, *, loop=None, forever=False, pass_loop=False): def autoasync(coro=None, *, loop=None, forever=False, pass_loop=False):
@ -127,7 +127,9 @@ def autoasync(coro=None, *, loop=None, forever=False, pass_loop=False):
args, kwargs = bound_args.args, bound_args.kwargs args, kwargs = bound_args.args, bound_args.kwargs
if forever: if forever:
_launch_forever_coro(coro, args, kwargs, local_loop) local_loop.create_task(_run_forever_coro(
coro, args, kwargs, local_loop
))
local_loop.run_forever() local_loop.run_forever()
else: else:
return local_loop.run_until_complete(coro(*args, **kwargs)) return local_loop.run_until_complete(coro(*args, **kwargs))
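
For context, a usage sketch of the decorator this hunk touches (illustrative only; the import path assumes the vendored autocommand package): with forever=True, the decorated coroutine is now wrapped in an async helper and scheduled via loop.create_task() before run_forever(), rather than being launched from synchronous code.

import asyncio
from autocommand.autoasync import autoasync  # assumed module path

@autoasync(forever=True)
async def main():
    # Runs until the loop is stopped; the task is created by the new
    # _run_forever_coro wrapper shown in the diff.
    while True:
        await asyncio.sleep(60)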

View file

@ -452,6 +452,6 @@ class WSGIErrorHandler(logging.Handler):
class LazyRfc3339UtcTime(object): class LazyRfc3339UtcTime(object):
def __str__(self): def __str__(self):
"""Return now() in RFC3339 UTC Format.""" """Return utcnow() in RFC3339 UTC Format."""
now = datetime.datetime.now() iso_formatted_now = datetime.datetime.utcnow().isoformat('T')
return now.isoformat('T') + 'Z' return f'{iso_formatted_now!s}Z'
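
A minimal sketch of the behavioural change (timestamp value is illustrative): the logged string is now derived from UTC rather than local time, so the trailing 'Z' designator is accurate.

import datetime

iso_formatted_now = datetime.datetime.utcnow().isoformat('T')
print(f'{iso_formatted_now!s}Z')  # e.g. '2024-03-25T00:55:12.123456Z'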

View file

@ -622,13 +622,15 @@ def autovary(ignore=None, debug=False):
def convert_params(exception=ValueError, error=400): def convert_params(exception=ValueError, error=400):
"""Convert request params based on function annotations, with error handling. """Convert request params based on function annotations.
exception This function also processes errors that are subclasses of ``exception``.
Exception class to catch.
status :param BaseException exception: Exception class to catch.
The HTTP error code to return to the client on failure. :type exception: BaseException
:param error: The HTTP status code to return to the client on failure.
:type error: int
""" """
request = cherrypy.serving.request request = cherrypy.serving.request
types = request.handler.callable.__annotations__ types = request.handler.callable.__annotations__

View file

@ -47,7 +47,9 @@ try:
import pstats import pstats
def new_func_strip_path(func_name): def new_func_strip_path(func_name):
"""Make profiler output more readable by adding `__init__` modules' parents """Add ``__init__`` modules' parents.
This makes the profiler output more readable.
""" """
filename, line, name = func_name filename, line, name = func_name
if filename.endswith('__init__.py'): if filename.endswith('__init__.py'):

View file

@ -188,7 +188,7 @@ class Parser(configparser.ConfigParser):
def dict_from_file(self, file): def dict_from_file(self, file):
if hasattr(file, 'read'): if hasattr(file, 'read'):
self.readfp(file) self.read_file(file)
else: else:
self.read(file) self.read(file)
return self.as_dict() return self.as_dict()
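
For reference, ConfigParser.readfp() was deprecated in Python 3.2 and removed in Python 3.12; read_file() is the drop-in replacement:

import configparser
import io

parser = configparser.ConfigParser()
parser.read_file(io.StringIO('[global]\nserver.socket_port = 8080\n'))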

View file

@ -1,19 +1,18 @@
"""Module with helpers for serving static files.""" """Module with helpers for serving static files."""
import mimetypes
import os import os
import platform import platform
import re import re
import stat import stat
import mimetypes
import urllib.parse
import unicodedata import unicodedata
import urllib.parse
from email.generator import _make_boundary as make_boundary from email.generator import _make_boundary as make_boundary
from io import UnsupportedOperation from io import UnsupportedOperation
import cherrypy import cherrypy
from cherrypy._cpcompat import ntob from cherrypy._cpcompat import ntob
from cherrypy.lib import cptools, httputil, file_generator_limited from cherrypy.lib import cptools, file_generator_limited, httputil
def _setup_mimetypes(): def _setup_mimetypes():
@ -185,7 +184,10 @@ def serve_fileobj(fileobj, content_type=None, disposition=None, name=None,
def _serve_fileobj(fileobj, content_type, content_length, debug=False): def _serve_fileobj(fileobj, content_type, content_length, debug=False):
"""Internal. Set response.body to the given file object, perhaps ranged.""" """Set ``response.body`` to the given file object, perhaps ranged.
Internal helper.
"""
response = cherrypy.serving.response response = cherrypy.serving.response
# HTTP/1.0 didn't have Range/Accept-Ranges headers, or the 206 code # HTTP/1.0 didn't have Range/Accept-Ranges headers, or the 206 code

View file

@ -494,7 +494,7 @@ class Bus(object):
"Cannot reconstruct command from '-c'. " "Cannot reconstruct command from '-c'. "
'Ref: https://github.com/cherrypy/cherrypy/issues/1545') 'Ref: https://github.com/cherrypy/cherrypy/issues/1545')
except AttributeError: except AttributeError:
"""It looks Py_GetArgcArgv is completely absent in some environments """It looks Py_GetArgcArgv's completely absent in some environments
It is known, that there's no Py_GetArgcArgv in MS Windows and It is known, that there's no Py_GetArgcArgv in MS Windows and
``ctypes`` module is completely absent in Google AppEngine ``ctypes`` module is completely absent in Google AppEngine

View file

@ -136,6 +136,9 @@ class HTTPTests(helper.CPWebCase):
self.assertStatus(200) self.assertStatus(200)
self.assertBody(b'Hello world!') self.assertBody(b'Hello world!')
response.close()
c.close()
# Now send a message that has no Content-Length, but does send a body. # Now send a message that has no Content-Length, but does send a body.
# Verify that CP times out the socket and responds # Verify that CP times out the socket and responds
# with 411 Length Required. # with 411 Length Required.
@ -159,6 +162,9 @@ class HTTPTests(helper.CPWebCase):
self.status = str(response.status) self.status = str(response.status)
self.assertStatus(411) self.assertStatus(411)
response.close()
c.close()
def test_post_multipart(self): def test_post_multipart(self):
alphabet = 'abcdefghijklmnopqrstuvwxyz' alphabet = 'abcdefghijklmnopqrstuvwxyz'
# generate file contents for a large post # generate file contents for a large post
@ -184,6 +190,9 @@ class HTTPTests(helper.CPWebCase):
parts = ['%s * 65536' % ch for ch in alphabet] parts = ['%s * 65536' % ch for ch in alphabet]
self.assertBody(', '.join(parts)) self.assertBody(', '.join(parts))
response.close()
c.close()
def test_post_filename_with_special_characters(self): def test_post_filename_with_special_characters(self):
"""Testing that we can handle filenames with special characters. """Testing that we can handle filenames with special characters.
@ -217,6 +226,9 @@ class HTTPTests(helper.CPWebCase):
self.assertStatus(200) self.assertStatus(200)
self.assertBody(fname) self.assertBody(fname)
response.close()
c.close()
def test_malformed_request_line(self): def test_malformed_request_line(self):
if getattr(cherrypy.server, 'using_apache', False): if getattr(cherrypy.server, 'using_apache', False):
return self.skip('skipped due to known Apache differences...') return self.skip('skipped due to known Apache differences...')
@ -264,6 +276,9 @@ class HTTPTests(helper.CPWebCase):
self.body = response.fp.read(20) self.body = response.fp.read(20)
self.assertBody('Illegal header line.') self.assertBody('Illegal header line.')
response.close()
c.close()
def test_http_over_https(self): def test_http_over_https(self):
if self.scheme != 'https': if self.scheme != 'https':
return self.skip('skipped (not running HTTPS)... ') return self.skip('skipped (not running HTTPS)... ')

View file

@ -150,6 +150,8 @@ class IteratorTest(helper.CPWebCase):
self.assertStatus(200) self.assertStatus(200)
self.assertBody('0') self.assertBody('0')
itr_conn.close()
# Now we do the same check with streaming - some classes will # Now we do the same check with streaming - some classes will
# be automatically closed, while others cannot. # be automatically closed, while others cannot.
stream_counts = {} stream_counts = {}

View file

@ -1,5 +1,6 @@
"""Basic tests for the CherryPy core: request handling.""" """Basic tests for the CherryPy core: request handling."""
import datetime
import logging import logging
from cheroot.test import webtest from cheroot.test import webtest
@ -197,6 +198,33 @@ def test_custom_log_format(log_tracker, monkeypatch, server):
) )
def test_utc_in_timez(monkeypatch):
"""Test that ``LazyRfc3339UtcTime`` is rendered as ``str`` using UTC timestamp."""
utcoffset8_local_time_in_naive_utc = (
datetime.datetime(
year=2020,
month=1,
day=1,
hour=1,
minute=23,
second=45,
tzinfo=datetime.timezone(datetime.timedelta(hours=8)),
)
.astimezone(datetime.timezone.utc)
.replace(tzinfo=None)
)
class mock_datetime:
@classmethod
def utcnow(cls):
return utcoffset8_local_time_in_naive_utc
monkeypatch.setattr('datetime.datetime', mock_datetime)
rfc3339_utc_time = str(cherrypy._cplogging.LazyRfc3339UtcTime())
expected_time = '2019-12-31T17:23:45Z'
assert rfc3339_utc_time == expected_time
def test_timez_log_format(log_tracker, monkeypatch, server): def test_timez_log_format(log_tracker, monkeypatch, server):
"""Test a customized access_log_format string, which is a """Test a customized access_log_format string, which is a
feature of _cplogging.LogManager.access().""" feature of _cplogging.LogManager.access()."""

View file

@ -3,8 +3,6 @@ inflect: english language inflection
- correctly generate plurals, ordinals, indefinite articles - correctly generate plurals, ordinals, indefinite articles
- convert numbers to words - convert numbers to words
Copyright (C) 2010 Paul Dyson
Based upon the Perl module Based upon the Perl module
`Lingua::EN::Inflect <https://metacpan.org/pod/Lingua::EN::Inflect>`_. `Lingua::EN::Inflect <https://metacpan.org/pod/Lingua::EN::Inflect>`_.
@ -70,11 +68,16 @@ from typing import (
cast, cast,
Any, Any,
) )
from typing_extensions import Literal
from numbers import Number from numbers import Number
from pydantic import Field, validate_arguments from pydantic import Field
from pydantic.typing import Annotated from typing_extensions import Annotated
from .compat.pydantic1 import validate_call
from .compat.pydantic import same_method
class UnknownClassicalModeError(Exception): class UnknownClassicalModeError(Exception):
@ -105,14 +108,6 @@ class BadGenderError(Exception):
pass pass
STDOUT_ON = False
def print3(txt: str) -> None:
if STDOUT_ON:
print(txt)
def enclose(s: str) -> str: def enclose(s: str) -> str:
return f"(?:{s})" return f"(?:{s})"
@ -1727,66 +1722,44 @@ plverb_irregular_pres = {
"is": "are", "is": "are",
"was": "were", "was": "were",
"were": "were", "were": "were",
"was": "were",
"have": "have",
"have": "have", "have": "have",
"has": "have", "has": "have",
"do": "do", "do": "do",
"do": "do",
"does": "do", "does": "do",
} }
plverb_ambiguous_pres = { plverb_ambiguous_pres = {
"act": "act",
"act": "act", "act": "act",
"acts": "act", "acts": "act",
"blame": "blame", "blame": "blame",
"blame": "blame",
"blames": "blame", "blames": "blame",
"can": "can", "can": "can",
"can": "can",
"can": "can",
"must": "must", "must": "must",
"must": "must",
"must": "must",
"fly": "fly",
"fly": "fly", "fly": "fly",
"flies": "fly", "flies": "fly",
"copy": "copy", "copy": "copy",
"copy": "copy",
"copies": "copy", "copies": "copy",
"drink": "drink", "drink": "drink",
"drink": "drink",
"drinks": "drink", "drinks": "drink",
"fight": "fight", "fight": "fight",
"fight": "fight",
"fights": "fight", "fights": "fight",
"fire": "fire", "fire": "fire",
"fire": "fire",
"fires": "fire", "fires": "fire",
"like": "like", "like": "like",
"like": "like",
"likes": "like", "likes": "like",
"look": "look", "look": "look",
"look": "look",
"looks": "look", "looks": "look",
"make": "make", "make": "make",
"make": "make",
"makes": "make", "makes": "make",
"reach": "reach", "reach": "reach",
"reach": "reach",
"reaches": "reach", "reaches": "reach",
"run": "run", "run": "run",
"run": "run",
"runs": "run", "runs": "run",
"sink": "sink", "sink": "sink",
"sink": "sink",
"sinks": "sink", "sinks": "sink",
"sleep": "sleep", "sleep": "sleep",
"sleep": "sleep",
"sleeps": "sleep", "sleeps": "sleep",
"view": "view", "view": "view",
"view": "view",
"views": "view", "views": "view",
} }
@ -1854,7 +1827,7 @@ pl_adj_poss_keys = re.compile(fr"^({enclose('|'.join(pl_adj_poss))})$", re.IGNOR
A_abbrev = re.compile( A_abbrev = re.compile(
r""" r"""
(?! FJO | [HLMNS]Y. | RY[EO] | SQU ^(?! FJO | [HLMNS]Y. | RY[EO] | SQU
| ( F[LR]? | [HL] | MN? | N | RH? | S[CHKLMNPTVW]? | X(YL)?) [AEIOU]) | ( F[LR]? | [HL] | MN? | N | RH? | S[CHKLMNPTVW]? | X(YL)?) [AEIOU])
[FHLMNRSX][A-Z] [FHLMNRSX][A-Z]
""", """,
@ -2053,15 +2026,14 @@ Falsish = Any # ideally, falsish would only validate on bool(value) is False
class engine: class engine:
def __init__(self) -> None: def __init__(self) -> None:
self.classical_dict = def_classical.copy() self.classical_dict = def_classical.copy()
self.persistent_count: Optional[int] = None self.persistent_count: Optional[int] = None
self.mill_count = 0 self.mill_count = 0
self.pl_sb_user_defined: List[str] = [] self.pl_sb_user_defined: List[Optional[Word]] = []
self.pl_v_user_defined: List[str] = [] self.pl_v_user_defined: List[Optional[Word]] = []
self.pl_adj_user_defined: List[str] = [] self.pl_adj_user_defined: List[Optional[Word]] = []
self.si_sb_user_defined: List[str] = [] self.si_sb_user_defined: List[Optional[Word]] = []
self.A_a_user_defined: List[str] = [] self.A_a_user_defined: List[Optional[Word]] = []
self.thegender = "neuter" self.thegender = "neuter"
self.__number_args: Optional[Dict[str, str]] = None self.__number_args: Optional[Dict[str, str]] = None
@ -2073,28 +2045,8 @@ class engine:
def _number_args(self, val): def _number_args(self, val):
self.__number_args = val self.__number_args = val
deprecated_methods = dict( @validate_call
pl="plural", def defnoun(self, singular: Optional[Word], plural: Optional[Word]) -> int:
plnoun="plural_noun",
plverb="plural_verb",
pladj="plural_adj",
sinoun="single_noun",
prespart="present_participle",
numwords="number_to_words",
plequal="compare",
plnounequal="compare_nouns",
plverbequal="compare_verbs",
pladjequal="compare_adjs",
wordlist="join",
)
def __getattr__(self, meth):
if meth in self.deprecated_methods:
print3(f"{meth}() deprecated, use {self.deprecated_methods[meth]}()")
raise DeprecationWarning
raise AttributeError
def defnoun(self, singular: str, plural: str) -> int:
""" """
Set the noun plural of singular to plural. Set the noun plural of singular to plural.
@ -2105,7 +2057,16 @@ class engine:
self.si_sb_user_defined.extend((plural, singular)) self.si_sb_user_defined.extend((plural, singular))
return 1 return 1
def defverb(self, s1: str, p1: str, s2: str, p2: str, s3: str, p3: str) -> int: @validate_call
def defverb(
self,
s1: Optional[Word],
p1: Optional[Word],
s2: Optional[Word],
p2: Optional[Word],
s3: Optional[Word],
p3: Optional[Word],
) -> int:
""" """
Set the verb plurals for s1, s2 and s3 to p1, p2 and p3 respectively. Set the verb plurals for s1, s2 and s3 to p1, p2 and p3 respectively.
@ -2121,7 +2082,8 @@ class engine:
self.pl_v_user_defined.extend((s1, p1, s2, p2, s3, p3)) self.pl_v_user_defined.extend((s1, p1, s2, p2, s3, p3))
return 1 return 1
def defadj(self, singular: str, plural: str) -> int: @validate_call
def defadj(self, singular: Optional[Word], plural: Optional[Word]) -> int:
""" """
Set the adjective plural of singular to plural. Set the adjective plural of singular to plural.
@ -2131,7 +2093,8 @@ class engine:
self.pl_adj_user_defined.extend((singular, plural)) self.pl_adj_user_defined.extend((singular, plural))
return 1 return 1
def defa(self, pattern: str) -> int: @validate_call
def defa(self, pattern: Optional[Word]) -> int:
""" """
Define the indefinite article as 'a' for words matching pattern. Define the indefinite article as 'a' for words matching pattern.
@ -2140,7 +2103,8 @@ class engine:
self.A_a_user_defined.extend((pattern, "a")) self.A_a_user_defined.extend((pattern, "a"))
return 1 return 1
def defan(self, pattern: str) -> int: @validate_call
def defan(self, pattern: Optional[Word]) -> int:
""" """
Define the indefinite article as 'an' for words matching pattern. Define the indefinite article as 'an' for words matching pattern.
@ -2149,7 +2113,7 @@ class engine:
self.A_a_user_defined.extend((pattern, "an")) self.A_a_user_defined.extend((pattern, "an"))
return 1 return 1
def checkpat(self, pattern: Optional[str]) -> None: def checkpat(self, pattern: Optional[Word]) -> None:
""" """
check for errors in a regex pattern check for errors in a regex pattern
""" """
@ -2158,16 +2122,15 @@ class engine:
try: try:
re.match(pattern, "") re.match(pattern, "")
except re.error: except re.error:
print3(f"\nBad user-defined singular pattern:\n\t{pattern}\n") raise BadUserDefinedPatternError(pattern)
raise BadUserDefinedPatternError
def checkpatplural(self, pattern: str) -> None: def checkpatplural(self, pattern: Optional[Word]) -> None:
""" """
check for errors in a regex replace pattern check for errors in a regex replace pattern
""" """
return return
@validate_arguments @validate_call
def ud_match(self, word: Word, wordlist: Sequence[Optional[Word]]) -> Optional[str]: def ud_match(self, word: Word, wordlist: Sequence[Optional[Word]]) -> Optional[str]:
for i in range(len(wordlist) - 2, -2, -2): # backwards through even elements for i in range(len(wordlist) - 2, -2, -2): # backwards through even elements
mo = re.search(fr"^{wordlist[i]}$", word, re.IGNORECASE) mo = re.search(fr"^{wordlist[i]}$", word, re.IGNORECASE)
@ -2307,7 +2270,7 @@ class engine:
# 0. PERFORM GENERAL INFLECTIONS IN A STRING # 0. PERFORM GENERAL INFLECTIONS IN A STRING
@validate_arguments @validate_call
def inflect(self, text: Word) -> str: def inflect(self, text: Word) -> str:
""" """
Perform inflections in a string. Perform inflections in a string.
@ -2384,7 +2347,7 @@ class engine:
else: else:
return "", "", "" return "", "", ""
@validate_arguments @validate_call
def plural(self, text: Word, count: Optional[Union[str, int, Any]] = None) -> str: def plural(self, text: Word, count: Optional[Union[str, int, Any]] = None) -> str:
""" """
Return the plural of text. Return the plural of text.
@ -2408,7 +2371,7 @@ class engine:
) )
return f"{pre}{plural}{post}" return f"{pre}{plural}{post}"
@validate_arguments @validate_call
def plural_noun( def plural_noun(
self, text: Word, count: Optional[Union[str, int, Any]] = None self, text: Word, count: Optional[Union[str, int, Any]] = None
) -> str: ) -> str:
@ -2429,7 +2392,7 @@ class engine:
plural = self.postprocess(word, self._plnoun(word, count)) plural = self.postprocess(word, self._plnoun(word, count))
return f"{pre}{plural}{post}" return f"{pre}{plural}{post}"
@validate_arguments @validate_call
def plural_verb( def plural_verb(
self, text: Word, count: Optional[Union[str, int, Any]] = None self, text: Word, count: Optional[Union[str, int, Any]] = None
) -> str: ) -> str:
@ -2453,7 +2416,7 @@ class engine:
) )
return f"{pre}{plural}{post}" return f"{pre}{plural}{post}"
@validate_arguments @validate_call
def plural_adj( def plural_adj(
self, text: Word, count: Optional[Union[str, int, Any]] = None self, text: Word, count: Optional[Union[str, int, Any]] = None
) -> str: ) -> str:
@ -2474,7 +2437,7 @@ class engine:
plural = self.postprocess(word, self._pl_special_adjective(word, count) or word) plural = self.postprocess(word, self._pl_special_adjective(word, count) or word)
return f"{pre}{plural}{post}" return f"{pre}{plural}{post}"
@validate_arguments @validate_call
def compare(self, word1: Word, word2: Word) -> Union[str, bool]: def compare(self, word1: Word, word2: Word) -> Union[str, bool]:
""" """
compare word1 and word2 for equality regardless of plurality compare word1 and word2 for equality regardless of plurality
@ -2497,15 +2460,15 @@ class engine:
>>> compare('egg', '') >>> compare('egg', '')
Traceback (most recent call last): Traceback (most recent call last):
... ...
pydantic.error_wrappers.ValidationError: 1 validation error for Compare pydantic...ValidationError: ...
word2 ...
ensure this value has at least 1 characters... ...at least 1 characters...
""" """
norms = self.plural_noun, self.plural_verb, self.plural_adj norms = self.plural_noun, self.plural_verb, self.plural_adj
results = (self._plequal(word1, word2, norm) for norm in norms) results = (self._plequal(word1, word2, norm) for norm in norms)
return next(filter(None, results), False) return next(filter(None, results), False)
@validate_arguments @validate_call
def compare_nouns(self, word1: Word, word2: Word) -> Union[str, bool]: def compare_nouns(self, word1: Word, word2: Word) -> Union[str, bool]:
""" """
compare word1 and word2 for equality regardless of plurality compare word1 and word2 for equality regardless of plurality
@ -2521,7 +2484,7 @@ class engine:
""" """
return self._plequal(word1, word2, self.plural_noun) return self._plequal(word1, word2, self.plural_noun)
@validate_arguments @validate_call
def compare_verbs(self, word1: Word, word2: Word) -> Union[str, bool]: def compare_verbs(self, word1: Word, word2: Word) -> Union[str, bool]:
""" """
compare word1 and word2 for equality regardless of plurality compare word1 and word2 for equality regardless of plurality
@ -2537,7 +2500,7 @@ class engine:
""" """
return self._plequal(word1, word2, self.plural_verb) return self._plequal(word1, word2, self.plural_verb)
@validate_arguments @validate_call
def compare_adjs(self, word1: Word, word2: Word) -> Union[str, bool]: def compare_adjs(self, word1: Word, word2: Word) -> Union[str, bool]:
""" """
compare word1 and word2 for equality regardless of plurality compare word1 and word2 for equality regardless of plurality
@ -2553,13 +2516,13 @@ class engine:
""" """
return self._plequal(word1, word2, self.plural_adj) return self._plequal(word1, word2, self.plural_adj)
@validate_arguments @validate_call
def singular_noun( def singular_noun(
self, self,
text: Word, text: Word,
count: Optional[Union[int, str, Any]] = None, count: Optional[Union[int, str, Any]] = None,
gender: Optional[str] = None, gender: Optional[str] = None,
) -> Union[str, bool]: ) -> Union[str, Literal[False]]:
""" """
Return the singular of text, where text is a plural noun. Return the singular of text, where text is a plural noun.
@ -2611,12 +2574,12 @@ class engine:
return "s:p" return "s:p"
self.classical_dict = classval.copy() self.classical_dict = classval.copy()
if pl == self.plural or pl == self.plural_noun: if same_method(pl, self.plural) or same_method(pl, self.plural_noun):
if self._pl_check_plurals_N(word1, word2): if self._pl_check_plurals_N(word1, word2):
return "p:p" return "p:p"
if self._pl_check_plurals_N(word2, word1): if self._pl_check_plurals_N(word2, word1):
return "p:p" return "p:p"
if pl == self.plural or pl == self.plural_adj: if same_method(pl, self.plural) or same_method(pl, self.plural_adj):
if self._pl_check_plurals_adj(word1, word2): if self._pl_check_plurals_adj(word1, word2):
return "p:p" return "p:p"
return False return False
@ -3266,11 +3229,11 @@ class engine:
if words.last in si_sb_irregular_caps: if words.last in si_sb_irregular_caps:
llen = len(words.last) llen = len(words.last)
return "{}{}".format(word[:-llen], si_sb_irregular_caps[words.last]) return f"{word[:-llen]}{si_sb_irregular_caps[words.last]}"
if words.last.lower() in si_sb_irregular: if words.last.lower() in si_sb_irregular:
llen = len(words.last.lower()) llen = len(words.last.lower())
return "{}{}".format(word[:-llen], si_sb_irregular[words.last.lower()]) return f"{word[:-llen]}{si_sb_irregular[words.last.lower()]}"
dash_split = words.lowered.split("-") dash_split = words.lowered.split("-")
if (" ".join(dash_split[-2:])).lower() in si_sb_irregular_compound: if (" ".join(dash_split[-2:])).lower() in si_sb_irregular_compound:
@ -3341,7 +3304,6 @@ class engine:
# HANDLE INCOMPLETELY ASSIMILATED IMPORTS # HANDLE INCOMPLETELY ASSIMILATED IMPORTS
if self.classical_dict["ancient"]: if self.classical_dict["ancient"]:
if words.lowered[-6:] == "trices": if words.lowered[-6:] == "trices":
return word[:-3] + "x" return word[:-3] + "x"
if words.lowered[-4:] in ("eaux", "ieux"): if words.lowered[-4:] in ("eaux", "ieux"):
@ -3459,7 +3421,6 @@ class engine:
# HANDLE ...o # HANDLE ...o
if words.lowered[-2:] == "os": if words.lowered[-2:] == "os":
if words.last.lower() in si_sb_U_o_os_complete: if words.last.lower() in si_sb_U_o_os_complete:
return word[:-1] return word[:-1]
@ -3489,7 +3450,7 @@ class engine:
# ADJECTIVES # ADJECTIVES
@validate_arguments @validate_call
def a(self, text: Word, count: Optional[Union[int, str, Any]] = 1) -> str: def a(self, text: Word, count: Optional[Union[int, str, Any]] = 1) -> str:
""" """
Return the appropriate indefinite article followed by text. Return the appropriate indefinite article followed by text.
@ -3570,7 +3531,7 @@ class engine:
# 2. TRANSLATE ZERO-QUANTIFIED $word TO "no plural($word)" # 2. TRANSLATE ZERO-QUANTIFIED $word TO "no plural($word)"
@validate_arguments @validate_call
def no(self, text: Word, count: Optional[Union[int, str]] = None) -> str: def no(self, text: Word, count: Optional[Union[int, str]] = None) -> str:
""" """
If count is 0, no, zero or nil, return 'no' followed by the plural If count is 0, no, zero or nil, return 'no' followed by the plural
@ -3608,7 +3569,7 @@ class engine:
# PARTICIPLES # PARTICIPLES
@validate_arguments @validate_call
def present_participle(self, word: Word) -> str: def present_participle(self, word: Word) -> str:
""" """
Return the present participle for word. Return the present participle for word.
@ -3627,31 +3588,31 @@ class engine:
# NUMERICAL INFLECTIONS # NUMERICAL INFLECTIONS
@validate_arguments @validate_call(config=dict(arbitrary_types_allowed=True))
def ordinal(self, num: Union[int, Word]) -> str: # noqa: C901 def ordinal(self, num: Union[Number, Word]) -> str:
""" """
Return the ordinal of num. Return the ordinal of num.
num can be an integer or text >>> ordinal = engine().ordinal
>>> ordinal(1)
e.g. ordinal(1) returns '1st' '1st'
ordinal('one') returns 'first' >>> ordinal('one')
'first'
""" """
if DIGIT.match(str(num)): if DIGIT.match(str(num)):
if isinstance(num, (int, float)): if isinstance(num, (float, int)) and int(num) == num:
n = int(num) n = int(num)
else: else:
if "." in str(num): if "." in str(num):
try: try:
# numbers after decimal, # numbers after decimal,
# so only need last one for ordinal # so only need last one for ordinal
n = int(num[-1]) n = int(str(num)[-1])
except ValueError: # ends with '.', so need to use whole string except ValueError: # ends with '.', so need to use whole string
n = int(num[:-1]) n = int(str(num)[:-1])
else: else:
n = int(num) n = int(num) # type: ignore
try: try:
post = nth[n % 100] post = nth[n % 100]
except KeyError: except KeyError:
@ -3671,7 +3632,6 @@ class engine:
def millfn(self, ind: int = 0) -> str: def millfn(self, ind: int = 0) -> str:
if ind > len(mill) - 1: if ind > len(mill) - 1:
print3("number out of range")
raise NumOutOfRangeError raise NumOutOfRangeError
return mill[ind] return mill[ind]
@ -3787,7 +3747,7 @@ class engine:
num = ONE_DIGIT_WORD.sub(self.unitsub, num, 1) num = ONE_DIGIT_WORD.sub(self.unitsub, num, 1)
return num return num
@validate_arguments(config=dict(arbitrary_types_allowed=True)) # noqa: C901 @validate_call(config=dict(arbitrary_types_allowed=True)) # noqa: C901
def number_to_words( # noqa: C901 def number_to_words( # noqa: C901
self, self,
num: Union[Number, Word], num: Union[Number, Word],
@ -3939,7 +3899,7 @@ class engine:
# Join words with commas and a trailing 'and' (when appropriate)... # Join words with commas and a trailing 'and' (when appropriate)...
@validate_arguments @validate_call
def join( def join(
self, self,
words: Optional[Sequence[Word]], words: Optional[Sequence[Word]],

View file

View file

@ -0,0 +1,19 @@
class ValidateCallWrapperWrapper:
def __init__(self, wrapped):
self.orig = wrapped
def __eq__(self, other):
return self.raw_function == other.raw_function
@property
def raw_function(self):
return getattr(self.orig, 'raw_function') or self.orig
def same_method(m1, m2) -> bool:
"""
Return whether m1 and m2 are the same method.
Workaround for pydantic/pydantic#6390.
"""
return ValidateCallWrapperWrapper(m1) == ValidateCallWrapperWrapper(m2)

View file

@ -0,0 +1,8 @@
try:
from pydantic import validate_call # type: ignore
except ImportError:
# Pydantic 1
from pydantic import validate_arguments as validate_call # type: ignore
__all__ = ['validate_call']
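
Illustratively, anything decorated with this compat shim gains pydantic argument validation under either major version (defnoun() below is a made-up example, not code from the diff):

try:
    from pydantic import validate_call                          # pydantic 2
except ImportError:
    from pydantic import validate_arguments as validate_call    # pydantic 1

@validate_call
def defnoun(singular: str, plural: str) -> int:
    # pydantic validates (and may coerce) the arguments before the body runs
    return 1

defnoun('cow', 'kine')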

View file

@ -5,23 +5,49 @@ import itertools
import copy import copy
import functools import functools
import random import random
from collections.abc import Container, Iterable, Mapping
from typing import Callable, Union
from jaraco.classes.properties import NonDataProperty
import jaraco.text import jaraco.text
_Matchable = Union[Callable, Container, Iterable, re.Pattern]
def _dispatch(obj: _Matchable) -> Callable:
# can't rely on singledispatch for Union[Container, Iterable]
# due to ambiguity
# (https://peps.python.org/pep-0443/#abstract-base-classes).
if isinstance(obj, re.Pattern):
return obj.fullmatch
if not isinstance(obj, Callable): # type: ignore
if not isinstance(obj, Container):
obj = set(obj) # type: ignore
obj = obj.__contains__
return obj # type: ignore
class Projection(collections.abc.Mapping): class Projection(collections.abc.Mapping):
""" """
Project a set of keys over a mapping Project a set of keys over a mapping
>>> sample = {'a': 1, 'b': 2, 'c': 3} >>> sample = {'a': 1, 'b': 2, 'c': 3}
>>> prj = Projection(['a', 'c', 'd'], sample) >>> prj = Projection(['a', 'c', 'd'], sample)
>>> prj == {'a': 1, 'c': 3} >>> dict(prj)
{'a': 1, 'c': 3}
Projection also accepts an iterable or callable or pattern.
>>> iter_prj = Projection(iter('acd'), sample)
>>> call_prj = Projection(lambda k: ord(k) in (97, 99, 100), sample)
>>> pat_prj = Projection(re.compile(r'[acd]'), sample)
>>> prj == iter_prj == call_prj == pat_prj
True True
Keys should only appear if they were specified and exist in the space. Keys should only appear if they were specified and exist in the space.
Order is retained.
>>> sorted(list(prj.keys())) >>> list(prj)
['a', 'c'] ['a', 'c']
Attempting to access a key not in the projection Attempting to access a key not in the projection
@ -36,119 +62,58 @@ class Projection(collections.abc.Mapping):
>>> target = {'a': 2, 'b': 2} >>> target = {'a': 2, 'b': 2}
>>> target.update(prj) >>> target.update(prj)
>>> target == {'a': 1, 'b': 2, 'c': 3} >>> target
True {'a': 1, 'b': 2, 'c': 3}
Also note that Projection keeps a reference to the original dict, so Projection keeps a reference to the original dict, so
if you modify the original dict, that could modify the Projection. modifying the original dict may modify the Projection.
>>> del sample['a'] >>> del sample['a']
>>> dict(prj) >>> dict(prj)
{'c': 3} {'c': 3}
""" """
def __init__(self, keys, space): def __init__(self, keys: _Matchable, space: Mapping):
self._keys = tuple(keys) self._match = _dispatch(keys)
self._space = space self._space = space
def __getitem__(self, key): def __getitem__(self, key):
if key not in self._keys: if not self._match(key):
raise KeyError(key) raise KeyError(key)
return self._space[key] return self._space[key]
def _keys_resolved(self):
return filter(self._match, self._space)
def __iter__(self): def __iter__(self):
return iter(set(self._keys).intersection(self._space)) return self._keys_resolved()
def __len__(self): def __len__(self):
return len(tuple(iter(self))) return len(tuple(self._keys_resolved()))
class DictFilter(collections.abc.Mapping): class Mask(Projection):
""" """
Takes a dict, and simulates a sub-dict based on the keys. The inverse of a :class:`Projection`, masking out keys.
>>> sample = {'a': 1, 'b': 2, 'c': 3} >>> sample = {'a': 1, 'b': 2, 'c': 3}
>>> filtered = DictFilter(sample, ['a', 'c']) >>> msk = Mask(['a', 'c', 'd'], sample)
>>> filtered == {'a': 1, 'c': 3} >>> dict(msk)
True
>>> set(filtered.values()) == {1, 3}
True
>>> set(filtered.items()) == {('a', 1), ('c', 3)}
True
One can also filter by a regular expression pattern
>>> sample['d'] = 4
>>> sample['ef'] = 5
Here we filter for only single-character keys
>>> filtered = DictFilter(sample, include_pattern='.$')
>>> filtered == {'a': 1, 'b': 2, 'c': 3, 'd': 4}
True
>>> filtered['e']
Traceback (most recent call last):
...
KeyError: 'e'
>>> 'e' in filtered
False
Pattern is useful for excluding keys with a prefix.
>>> filtered = DictFilter(sample, include_pattern=r'(?![ace])')
>>> dict(filtered)
{'b': 2, 'd': 4}
Also note that DictFilter keeps a reference to the original dict, so
if you modify the original dict, that could modify the filtered dict.
>>> del sample['d']
>>> dict(filtered)
{'b': 2} {'b': 2}
""" """
def __init__(self, dict, include_keys=[], include_pattern=None): def __init__(self, *args, **kwargs):
self.dict = dict super().__init__(*args, **kwargs)
self.specified_keys = set(include_keys) # self._match = compose(operator.not_, self._match)
if include_pattern is not None: self._match = lambda key, orig=self._match: not orig(key)
self.include_pattern = re.compile(include_pattern)
else:
# for performance, replace the pattern_keys property
self.pattern_keys = set()
def get_pattern_keys(self):
keys = filter(self.include_pattern.match, self.dict.keys())
return set(keys)
pattern_keys = NonDataProperty(get_pattern_keys)
@property
def include_keys(self):
return self.specified_keys | self.pattern_keys
def __getitem__(self, i):
if i not in self.include_keys:
raise KeyError(i)
return self.dict[i]
def __iter__(self):
return filter(self.include_keys.__contains__, self.dict.keys())
def __len__(self):
return len(list(self))
def dict_map(function, dictionary): def dict_map(function, dictionary):
""" """
dict_map is much like the built-in function map. It takes a dictionary Return a new dict with function applied to values of dictionary.
and applys a function to the values of that dictionary, returning a
new dictionary with the mapped values in the original keys.
>>> d = dict_map(lambda x:x+1, dict(a=1, b=2)) >>> dict_map(lambda x: x+1, dict(a=1, b=2))
>>> d == dict(a=2,b=3) {'a': 2, 'b': 3}
True
""" """
return dict((key, function(value)) for key, value in dictionary.items()) return dict((key, function(value)) for key, value in dictionary.items())
@ -164,7 +129,7 @@ class RangeMap(dict):
One may supply keyword parameters to be passed to the sort function used One may supply keyword parameters to be passed to the sort function used
to sort keys (i.e. key, reverse) as sort_params. to sort keys (i.e. key, reverse) as sort_params.
Let's create a map that maps 1-3 -> 'a', 4-6 -> 'b' Create a map that maps 1-3 -> 'a', 4-6 -> 'b'
>>> r = RangeMap({3: 'a', 6: 'b'}) # boy, that was easy >>> r = RangeMap({3: 'a', 6: 'b'}) # boy, that was easy
>>> r[1], r[2], r[3], r[4], r[5], r[6] >>> r[1], r[2], r[3], r[4], r[5], r[6]
@ -176,7 +141,7 @@ class RangeMap(dict):
>>> r[4.5] >>> r[4.5]
'b' 'b'
But you'll notice that the way rangemap is defined, it must be open-ended Notice that the way rangemap is defined, it must be open-ended
on one side. on one side.
>>> r[0] >>> r[0]
@ -279,7 +244,7 @@ class RangeMap(dict):
return (sorted_keys[RangeMap.first_item], sorted_keys[RangeMap.last_item]) return (sorted_keys[RangeMap.first_item], sorted_keys[RangeMap.last_item])
# some special values for the RangeMap # some special values for the RangeMap
undefined_value = type(str('RangeValueUndefined'), (), {})() undefined_value = type('RangeValueUndefined', (), {})()
class Item(int): class Item(int):
"RangeMap Item" "RangeMap Item"
@ -294,7 +259,7 @@ def __identity(x):
def sorted_items(d, key=__identity, reverse=False): def sorted_items(d, key=__identity, reverse=False):
""" """
Return the items of the dictionary sorted by the keys Return the items of the dictionary sorted by the keys.
>>> sample = dict(foo=20, bar=42, baz=10) >>> sample = dict(foo=20, bar=42, baz=10)
>>> tuple(sorted_items(sample)) >>> tuple(sorted_items(sample))
@ -307,6 +272,7 @@ def sorted_items(d, key=__identity, reverse=False):
>>> tuple(sorted_items(sample, reverse=True)) >>> tuple(sorted_items(sample, reverse=True))
(('foo', 20), ('baz', 10), ('bar', 42)) (('foo', 20), ('baz', 10), ('bar', 42))
""" """
# wrap the key func so it operates on the first element of each item # wrap the key func so it operates on the first element of each item
def pairkey_key(item): def pairkey_key(item):
return key(item[0]) return key(item[0])
@ -475,7 +441,7 @@ class ItemsAsAttributes:
Mix-in class to enable a mapping object to provide items as Mix-in class to enable a mapping object to provide items as
attributes. attributes.
>>> C = type(str('C'), (dict, ItemsAsAttributes), dict()) >>> C = type('C', (dict, ItemsAsAttributes), dict())
>>> i = C() >>> i = C()
>>> i['foo'] = 'bar' >>> i['foo'] = 'bar'
>>> i.foo >>> i.foo
@ -504,7 +470,7 @@ class ItemsAsAttributes:
>>> missing_func = lambda self, key: 'missing item' >>> missing_func = lambda self, key: 'missing item'
>>> C = type( >>> C = type(
... str('C'), ... 'C',
... (dict, ItemsAsAttributes), ... (dict, ItemsAsAttributes),
... dict(__missing__ = missing_func), ... dict(__missing__ = missing_func),
... ) ... )
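
To summarize the new matcher behaviour, a brief sketch mirroring the doctests above (illustrative only):

import re
from jaraco.collections import Projection, Mask

sample = {'a': 1, 'b': 2, 'c': 3}
# Keys may now be given as an iterable, a callable, or a compiled pattern.
assert dict(Projection(re.compile(r'[ac]'), sample)) == {'a': 1, 'c': 3}
assert dict(Projection(lambda k: k in 'ac', sample)) == {'a': 1, 'c': 3}
# Mask is the inverse of Projection: matching keys are hidden.
assert dict(Mask(['a', 'c'], sample)) == {'b': 2}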

View file

View file

@ -5,10 +5,18 @@ import functools
import tempfile import tempfile
import shutil import shutil
import operator import operator
import warnings
@contextlib.contextmanager @contextlib.contextmanager
def pushd(dir): def pushd(dir):
"""
>>> tmp_path = getfixture('tmp_path')
>>> with pushd(tmp_path):
... assert os.getcwd() == os.fspath(tmp_path)
>>> assert os.getcwd() != os.fspath(tmp_path)
"""
orig = os.getcwd() orig = os.getcwd()
os.chdir(dir) os.chdir(dir)
try: try:
@ -29,6 +37,8 @@ def tarball_context(url, target_dir=None, runner=None, pushd=pushd):
target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '') target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '')
if runner is None: if runner is None:
runner = functools.partial(subprocess.check_call, shell=True) runner = functools.partial(subprocess.check_call, shell=True)
else:
warnings.warn("runner parameter is deprecated", DeprecationWarning)
# In the tar command, use --strip-components=1 to strip the first path and # In the tar command, use --strip-components=1 to strip the first path and
# then # then
# use -C to cause the files to be extracted to {target_dir}. This ensures # use -C to cause the files to be extracted to {target_dir}. This ensures
@ -48,6 +58,15 @@ def tarball_context(url, target_dir=None, runner=None, pushd=pushd):
def infer_compression(url): def infer_compression(url):
""" """
Given a URL or filename, infer the compression code for tar. Given a URL or filename, infer the compression code for tar.
>>> infer_compression('http://foo/bar.tar.gz')
'z'
>>> infer_compression('http://foo/bar.tgz')
'z'
>>> infer_compression('file.bz')
'j'
>>> infer_compression('file.xz')
'J'
""" """
# cheat and just assume it's the last two characters # cheat and just assume it's the last two characters
compression_indicator = url[-2:] compression_indicator = url[-2:]
@ -61,6 +80,12 @@ def temp_dir(remover=shutil.rmtree):
""" """
Create a temporary directory context. Pass a custom remover Create a temporary directory context. Pass a custom remover
to override the removal behavior. to override the removal behavior.
>>> import pathlib
>>> with temp_dir() as the_dir:
... assert os.path.isdir(the_dir)
... _ = pathlib.Path(the_dir).joinpath('somefile').write_text('contents')
>>> assert not os.path.exists(the_dir)
""" """
temp_dir = tempfile.mkdtemp() temp_dir = tempfile.mkdtemp()
try: try:
@ -90,6 +115,12 @@ def repo_context(url, branch=None, quiet=True, dest_ctx=temp_dir):
@contextlib.contextmanager @contextlib.contextmanager
def null(): def null():
"""
A null context suitable to stand in for a meaningful context.
>>> with null() as value:
... assert value is None
"""
yield yield
@ -112,6 +143,10 @@ class ExceptionTrap:
... raise ValueError("1 + 1 is not 3") ... raise ValueError("1 + 1 is not 3")
>>> bool(trap) >>> bool(trap)
True True
>>> trap.value
ValueError('1 + 1 is not 3')
>>> trap.tb
<traceback object at ...>
>>> with ExceptionTrap(ValueError) as trap: >>> with ExceptionTrap(ValueError) as trap:
... raise Exception() ... raise Exception()
@ -211,3 +246,43 @@ class suppress(contextlib.suppress, contextlib.ContextDecorator):
... {}[''] ... {}['']
>>> key_error() >>> key_error()
""" """
class on_interrupt(contextlib.ContextDecorator):
"""
Replace a KeyboardInterrupt with SystemExit(1)
>>> def do_interrupt():
... raise KeyboardInterrupt()
>>> on_interrupt('error')(do_interrupt)()
Traceback (most recent call last):
...
SystemExit: 1
>>> on_interrupt('error', code=255)(do_interrupt)()
Traceback (most recent call last):
...
SystemExit: 255
>>> on_interrupt('suppress')(do_interrupt)()
>>> with __import__('pytest').raises(KeyboardInterrupt):
... on_interrupt('ignore')(do_interrupt)()
"""
def __init__(
self,
action='error',
# py3.7 compat
# /,
code=1,
):
self.action = action
self.code = code
def __enter__(self):
return self
def __exit__(self, exctype, excinst, exctb):
if exctype is not KeyboardInterrupt or self.action == 'ignore':
return
elif self.action == 'error':
raise SystemExit(self.code) from excinst
return self.action == 'suppress'
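A minimal usage sketch (not part of the diff; it assumes the vendored module is importable as jaraco.context): decorating a CLI entry point so that Ctrl-C becomes a clean exit status instead of a traceback.

import sys
from jaraco.context import on_interrupt

@on_interrupt('error', code=130)
def main():
    # long-running work that the user may interrupt with Ctrl-C
    for line in sys.stdin:
        print(line.rstrip())

main()  # KeyboardInterrupt becomes SystemExit(130); 'suppress' would swallow it instead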
@ -1,4 +1,4 @@
import collections import collections.abc
import functools import functools
import inspect import inspect
import itertools import itertools
@ -9,11 +9,6 @@ import warnings
import more_itertools import more_itertools
from typing import Callable, TypeVar
CallableT = TypeVar("CallableT", bound=Callable[..., object])
def compose(*funcs): def compose(*funcs):
""" """
@ -39,24 +34,6 @@ def compose(*funcs):
return functools.reduce(compose_two, funcs) return functools.reduce(compose_two, funcs)
def method_caller(method_name, *args, **kwargs):
"""
Return a function that will call a named method on the
target object with optional positional and keyword
arguments.
>>> lower = method_caller('lower')
>>> lower('MyString')
'mystring'
"""
def call_method(target):
func = getattr(target, method_name)
return func(*args, **kwargs)
return call_method
def once(func): def once(func):
""" """
Decorate func so it's only ever called the first time. Decorate func so it's only ever called the first time.
@ -99,12 +76,7 @@ def once(func):
return wrapper return wrapper
def method_cache( def method_cache(method, cache_wrapper=functools.lru_cache()):
method: CallableT,
cache_wrapper: Callable[
[CallableT], CallableT
] = functools.lru_cache(), # type: ignore[assignment]
) -> CallableT:
""" """
Wrap lru_cache to support storing the cache data in the object instances. Wrap lru_cache to support storing the cache data in the object instances.
@ -172,22 +144,17 @@ def method_cache(
for another implementation and additional justification. for another implementation and additional justification.
""" """
def wrapper(self: object, *args: object, **kwargs: object) -> object: def wrapper(self, *args, **kwargs):
# it's the first call, replace the method with a cached, bound method # it's the first call, replace the method with a cached, bound method
bound_method: CallableT = types.MethodType( # type: ignore[assignment] bound_method = types.MethodType(method, self)
method, self
)
cached_method = cache_wrapper(bound_method) cached_method = cache_wrapper(bound_method)
setattr(self, method.__name__, cached_method) setattr(self, method.__name__, cached_method)
return cached_method(*args, **kwargs) return cached_method(*args, **kwargs)
# Support cache clear even before cache has been created. # Support cache clear even before cache has been created.
wrapper.cache_clear = lambda: None # type: ignore[attr-defined] wrapper.cache_clear = lambda: None
return ( return _special_method_cache(method, cache_wrapper) or wrapper
_special_method_cache(method, cache_wrapper) # type: ignore[return-value]
or wrapper
)
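For orientation, a short usage sketch of method_cache (an illustration, not part of the change): the cache is stored on the instance itself, so it is garbage-collected along with the object.

from jaraco.functools import method_cache

class Squares:
    @method_cache
    def compute(self, n):
        print('computing', n)
        return n * n

s = Squares()
s.compute(4)             # prints "computing 4", returns 16
s.compute(4)             # served from the per-instance cache, no print
s.compute.cache_clear()  # available both before and after the first call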
def _special_method_cache(method, cache_wrapper): def _special_method_cache(method, cache_wrapper):
@ -203,12 +170,13 @@ def _special_method_cache(method, cache_wrapper):
""" """
name = method.__name__ name = method.__name__
special_names = '__getattr__', '__getitem__' special_names = '__getattr__', '__getitem__'
if name not in special_names: if name not in special_names:
return return None
wrapper_name = '__cached' + name wrapper_name = '__cached' + name
def proxy(self, *args, **kwargs): def proxy(self, /, *args, **kwargs):
if wrapper_name not in vars(self): if wrapper_name not in vars(self):
bound = types.MethodType(method, self) bound = types.MethodType(method, self)
cache = cache_wrapper(bound) cache = cache_wrapper(bound)
@ -245,7 +213,7 @@ def result_invoke(action):
r""" r"""
Decorate a function with an action function that is Decorate a function with an action function that is
invoked on the results returned from the decorated invoked on the results returned from the decorated
function (for its side-effect), then return the original function (for its side effect), then return the original
result. result.
>>> @result_invoke(print) >>> @result_invoke(print)
@ -269,7 +237,7 @@ def result_invoke(action):
return wrap return wrap
def invoke(f, *args, **kwargs): def invoke(f, /, *args, **kwargs):
""" """
Call a function for its side effect after initialization. Call a function for its side effect after initialization.
@ -304,25 +272,15 @@ def invoke(f, *args, **kwargs):
Use functools.partial to pass parameters to the initial call Use functools.partial to pass parameters to the initial call
>>> @functools.partial(invoke, name='bingo') >>> @functools.partial(invoke, name='bingo')
... def func(name): print("called with", name) ... def func(name): print('called with', name)
called with bingo called with bingo
""" """
f(*args, **kwargs) f(*args, **kwargs)
return f return f
def call_aside(*args, **kwargs):
"""
Deprecated name for invoke.
"""
warnings.warn("call_aside is deprecated, use invoke", DeprecationWarning)
return invoke(*args, **kwargs)
class Throttler: class Throttler:
""" """Rate-limit a function (or other callable)."""
Rate-limit a function (or other callable)
"""
def __init__(self, func, max_rate=float('Inf')): def __init__(self, func, max_rate=float('Inf')):
if isinstance(func, Throttler): if isinstance(func, Throttler):
@ -339,20 +297,20 @@ class Throttler:
return self.func(*args, **kwargs) return self.func(*args, **kwargs)
def _wait(self): def _wait(self):
"ensure at least 1/max_rate seconds from last call" """Ensure at least 1/max_rate seconds from last call."""
elapsed = time.time() - self.last_called elapsed = time.time() - self.last_called
must_wait = 1 / self.max_rate - elapsed must_wait = 1 / self.max_rate - elapsed
time.sleep(max(0, must_wait)) time.sleep(max(0, must_wait))
self.last_called = time.time() self.last_called = time.time()
def __get__(self, obj, type=None): def __get__(self, obj, owner=None):
return first_invoke(self._wait, functools.partial(self.func, obj)) return first_invoke(self._wait, functools.partial(self.func, obj))
def first_invoke(func1, func2): def first_invoke(func1, func2):
""" """
Return a function that when invoked will invoke func1 without Return a function that when invoked will invoke func1 without
any parameters (for its side-effect) and then invoke func2 any parameters (for its side effect) and then invoke func2
with whatever parameters were passed, returning its result. with whatever parameters were passed, returning its result.
""" """
@ -363,6 +321,17 @@ def first_invoke(func1, func2):
return wrapper return wrapper
method_caller = first_invoke(
lambda: warnings.warn(
'`jaraco.functools.method_caller` is deprecated, '
'use `operator.methodcaller` instead',
DeprecationWarning,
stacklevel=3,
),
operator.methodcaller,
)
def retry_call(func, cleanup=lambda: None, retries=0, trap=()): def retry_call(func, cleanup=lambda: None, retries=0, trap=()):
""" """
Given a callable func, trap the indicated exceptions Given a callable func, trap the indicated exceptions
@ -371,7 +340,7 @@ def retry_call(func, cleanup=lambda: None, retries=0, trap=()):
to propagate. to propagate.
""" """
attempts = itertools.count() if retries == float('inf') else range(retries) attempts = itertools.count() if retries == float('inf') else range(retries)
for attempt in attempts: for _ in attempts:
try: try:
return func() return func()
except trap: except trap:
@ -408,7 +377,7 @@ def retry(*r_args, **r_kwargs):
def print_yielded(func): def print_yielded(func):
""" """
Convert a generator into a function that prints all yielded elements Convert a generator into a function that prints all yielded elements.
>>> @print_yielded >>> @print_yielded
... def x(): ... def x():
@ -424,7 +393,7 @@ def print_yielded(func):
def pass_none(func): def pass_none(func):
""" """
Wrap func so it's not called if its first param is None Wrap func so it's not called if its first param is None.
>>> print_text = pass_none(print) >>> print_text = pass_none(print)
>>> print_text('text') >>> print_text('text')
@ -433,9 +402,10 @@ def pass_none(func):
""" """
@functools.wraps(func) @functools.wraps(func)
def wrapper(param, *args, **kwargs): def wrapper(param, /, *args, **kwargs):
if param is not None: if param is not None:
return func(param, *args, **kwargs) return func(param, *args, **kwargs)
return None
return wrapper return wrapper
@ -509,7 +479,7 @@ def save_method_args(method):
args_and_kwargs = collections.namedtuple('args_and_kwargs', 'args kwargs') args_and_kwargs = collections.namedtuple('args_and_kwargs', 'args kwargs')
@functools.wraps(method) @functools.wraps(method)
def wrapper(self, *args, **kwargs): def wrapper(self, /, *args, **kwargs):
attr_name = '_saved_' + method.__name__ attr_name = '_saved_' + method.__name__
attr = args_and_kwargs(args, kwargs) attr = args_and_kwargs(args, kwargs)
setattr(self, attr_name, attr) setattr(self, attr_name, attr)
@ -559,6 +529,13 @@ def except_(*exceptions, replace=None, use=None):
def identity(x): def identity(x):
"""
Return the argument.
>>> o = object()
>>> identity(o) is o
True
"""
return x return x
@ -580,7 +557,7 @@ def bypass_when(check, *, _op=identity):
def decorate(func): def decorate(func):
@functools.wraps(func) @functools.wraps(func)
def wrapper(param): def wrapper(param, /):
return param if _op(check) else func(param) return param if _op(check) else func(param)
return wrapper return wrapper
@ -604,3 +581,53 @@ def bypass_unless(check):
2 2
""" """
return bypass_when(check, _op=operator.not_) return bypass_when(check, _op=operator.not_)
@functools.singledispatch
def _splat_inner(args, func):
"""Splat args to func."""
return func(*args)
@_splat_inner.register
def _(args: collections.abc.Mapping, func):
"""Splat kargs to func as kwargs."""
return func(**args)
def splat(func):
"""
Wrap func to expect its parameters to be passed positionally in a tuple.
Has a similar effect to that of ``itertools.starmap`` over
simple ``map``.
>>> pairs = [(-1, 1), (0, 2)]
>>> more_itertools.consume(itertools.starmap(print, pairs))
-1 1
0 2
>>> more_itertools.consume(map(splat(print), pairs))
-1 1
0 2
The approach generalizes to other iterators that don't have a "star"
equivalent, such as a "starfilter".
>>> list(filter(splat(operator.add), pairs))
[(0, 2)]
Splat also accepts a mapping argument.
>>> def is_nice(msg, code):
... return "smile" in msg or code == 0
>>> msgs = [
... dict(msg='smile!', code=20),
... dict(msg='error :(', code=1),
... dict(msg='unknown', code=0),
... ]
>>> for msg in filter(splat(is_nice), msgs):
... print(msg)
{'msg': 'smile!', 'code': 20}
{'msg': 'unknown', 'code': 0}
"""
return functools.wraps(func)(functools.partial(_splat_inner, func=func))
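One more illustrative use, not from the source: because splat turns an n-ary function into a unary one, it also works as a "star key" wherever a single-argument callable is expected.

import operator
from jaraco.functools import splat

pairs = [(3, -1), (1, 5), (2, 2)]
# sort the pairs by the sum of their elements
print(sorted(pairs, key=splat(operator.add)))  # [(3, -1), (2, 2), (1, 5)]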
@ -0,0 +1,128 @@
from collections.abc import Callable, Hashable, Iterator
from functools import partial
from operator import methodcaller
import sys
from typing import (
Any,
Generic,
Protocol,
TypeVar,
overload,
)
if sys.version_info >= (3, 10):
from typing import Concatenate, ParamSpec
else:
from typing_extensions import Concatenate, ParamSpec
_P = ParamSpec('_P')
_R = TypeVar('_R')
_T = TypeVar('_T')
_R1 = TypeVar('_R1')
_R2 = TypeVar('_R2')
_V = TypeVar('_V')
_S = TypeVar('_S')
_R_co = TypeVar('_R_co', covariant=True)
class _OnceCallable(Protocol[_P, _R]):
saved_result: _R
reset: Callable[[], None]
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
class _ProxyMethodCacheWrapper(Protocol[_R_co]):
cache_clear: Callable[[], None]
def __call__(self, *args: Hashable, **kwargs: Hashable) -> _R_co: ...
class _MethodCacheWrapper(Protocol[_R_co]):
def cache_clear(self) -> None: ...
def __call__(self, *args: Hashable, **kwargs: Hashable) -> _R_co: ...
# `compose()` overloads below will cover most use cases.
@overload
def compose(
__func1: Callable[[_R], _T],
__func2: Callable[_P, _R],
/,
) -> Callable[_P, _T]: ...
@overload
def compose(
__func1: Callable[[_R], _T],
__func2: Callable[[_R1], _R],
__func3: Callable[_P, _R1],
/,
) -> Callable[_P, _T]: ...
@overload
def compose(
__func1: Callable[[_R], _T],
__func2: Callable[[_R2], _R],
__func3: Callable[[_R1], _R2],
__func4: Callable[_P, _R1],
/,
) -> Callable[_P, _T]: ...
def once(func: Callable[_P, _R]) -> _OnceCallable[_P, _R]: ...
def method_cache(
method: Callable[..., _R],
cache_wrapper: Callable[[Callable[..., _R]], _MethodCacheWrapper[_R]] = ...,
) -> _MethodCacheWrapper[_R] | _ProxyMethodCacheWrapper[_R]: ...
def apply(
transform: Callable[[_R], _T]
) -> Callable[[Callable[_P, _R]], Callable[_P, _T]]: ...
def result_invoke(
action: Callable[[_R], Any]
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: ...
def invoke(
f: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs
) -> Callable[_P, _R]: ...
def call_aside(
f: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs
) -> Callable[_P, _R]: ...
class Throttler(Generic[_R]):
last_called: float
func: Callable[..., _R]
max_rate: float
def __init__(
self, func: Callable[..., _R] | Throttler[_R], max_rate: float = ...
) -> None: ...
def reset(self) -> None: ...
def __call__(self, *args: Any, **kwargs: Any) -> _R: ...
def __get__(self, obj: Any, owner: type[Any] | None = ...) -> Callable[..., _R]: ...
def first_invoke(
func1: Callable[..., Any], func2: Callable[_P, _R]
) -> Callable[_P, _R]: ...
method_caller: Callable[..., methodcaller]
def retry_call(
func: Callable[..., _R],
cleanup: Callable[..., None] = ...,
retries: int | float = ...,
trap: type[BaseException] | tuple[type[BaseException], ...] = ...,
) -> _R: ...
def retry(
cleanup: Callable[..., None] = ...,
retries: int | float = ...,
trap: type[BaseException] | tuple[type[BaseException], ...] = ...,
) -> Callable[[Callable[..., _R]], Callable[..., _R]]: ...
def print_yielded(func: Callable[_P, Iterator[Any]]) -> Callable[_P, None]: ...
def pass_none(
func: Callable[Concatenate[_T, _P], _R]
) -> Callable[Concatenate[_T, _P], _R]: ...
def assign_params(
func: Callable[..., _R], namespace: dict[str, Any]
) -> partial[_R]: ...
def save_method_args(
method: Callable[Concatenate[_S, _P], _R]
) -> Callable[Concatenate[_S, _P], _R]: ...
def except_(
*exceptions: type[BaseException], replace: Any = ..., use: Any = ...
) -> Callable[[Callable[_P, Any]], Callable[_P, Any]]: ...
def identity(x: _T) -> _T: ...
def bypass_when(
check: _V, *, _op: Callable[[_V], Any] = ...
) -> Callable[[Callable[[_T], _R]], Callable[[_T], _T | _R]]: ...
def bypass_unless(
check: Any,
) -> Callable[[Callable[[_T], _R]], Callable[[_T], _T | _R]]: ...
@ -227,10 +227,12 @@ def unwrap(s):
return '\n'.join(cleaned) return '\n'.join(cleaned)
lorem_ipsum: str = files(__name__).joinpath('Lorem ipsum.txt').read_text() lorem_ipsum: str = (
files(__name__).joinpath('Lorem ipsum.txt').read_text(encoding='utf-8')
)
class Splitter(object): class Splitter:
"""object that will split a string with the given arguments for each call """object that will split a string with the given arguments for each call
>>> s = Splitter(',') >>> s = Splitter(',')
@ -367,7 +369,7 @@ class WordSet(tuple):
return self.trim_left(item).trim_right(item) return self.trim_left(item).trim_right(item)
def __getitem__(self, item): def __getitem__(self, item):
result = super(WordSet, self).__getitem__(item) result = super().__getitem__(item)
if isinstance(item, slice): if isinstance(item, slice):
result = WordSet(result) result = WordSet(result)
return result return result
@ -582,7 +584,7 @@ def join_continuation(lines):
['foobarbaz'] ['foobarbaz']
Not sure why, but... Not sure why, but...
The character preceeding the backslash is also elided. The character preceding the backslash is also elided.
>>> list(join_continuation(['goo\\', 'dly'])) >>> list(join_continuation(['goo\\', 'dly']))
['godly'] ['godly']
@ -607,16 +609,16 @@ def read_newlines(filename, limit=1024):
r""" r"""
>>> tmp_path = getfixture('tmp_path') >>> tmp_path = getfixture('tmp_path')
>>> filename = tmp_path / 'out.txt' >>> filename = tmp_path / 'out.txt'
>>> _ = filename.write_text('foo\n', newline='') >>> _ = filename.write_text('foo\n', newline='', encoding='utf-8')
>>> read_newlines(filename) >>> read_newlines(filename)
'\n' '\n'
>>> _ = filename.write_text('foo\r\n', newline='') >>> _ = filename.write_text('foo\r\n', newline='', encoding='utf-8')
>>> read_newlines(filename) >>> read_newlines(filename)
'\r\n' '\r\n'
>>> _ = filename.write_text('foo\r\nbar\nbing\r', newline='') >>> _ = filename.write_text('foo\r\nbar\nbing\r', newline='', encoding='utf-8')
>>> read_newlines(filename) >>> read_newlines(filename)
('\r', '\n', '\r\n') ('\r', '\n', '\r\n')
""" """
with open(filename) as fp: with open(filename, encoding='utf-8') as fp:
fp.read(limit) fp.read(limit)
return fp.newlines return fp.newlines
@ -12,11 +12,11 @@ def report_newlines(filename):
>>> tmp_path = getfixture('tmp_path') >>> tmp_path = getfixture('tmp_path')
>>> filename = tmp_path / 'out.txt' >>> filename = tmp_path / 'out.txt'
>>> _ = filename.write_text('foo\nbar\n', newline='') >>> _ = filename.write_text('foo\nbar\n', newline='', encoding='utf-8')
>>> report_newlines(filename) >>> report_newlines(filename)
newline is '\n' newline is '\n'
>>> filename = tmp_path / 'out.txt' >>> filename = tmp_path / 'out.txt'
>>> _ = filename.write_text('foo\nbar\r\n', newline='') >>> _ = filename.write_text('foo\nbar\r\n', newline='', encoding='utf-8')
>>> report_newlines(filename) >>> report_newlines(filename)
newlines are ('\n', '\r\n') newlines are ('\n', '\r\n')
""" """
@ -0,0 +1,21 @@
import sys
import autocommand
from jaraco.text import Stripper
def strip_prefix():
r"""
Strip any common prefix from stdin.
>>> import io, pytest
>>> getfixture('monkeypatch').setattr('sys.stdin', io.StringIO('abcdef\nabc123'))
>>> strip_prefix()
def
123
"""
sys.stdout.writelines(Stripper.strip_prefix(sys.stdin).lines)
autocommand.autocommand(__name__)(strip_prefix)
@ -3,4 +3,4 @@
from .more import * # noqa from .more import * # noqa
from .recipes import * # noqa from .recipes import * # noqa
__version__ = '10.1.0' __version__ = '10.2.0'
@ -19,7 +19,7 @@ from itertools import (
zip_longest, zip_longest,
product, product,
) )
from math import exp, factorial, floor, log from math import exp, factorial, floor, log, perm, comb
from queue import Empty, Queue from queue import Empty, Queue
from random import random, randrange, uniform from random import random, randrange, uniform
from operator import itemgetter, mul, sub, gt, lt, ge, le from operator import itemgetter, mul, sub, gt, lt, ge, le
@ -68,8 +68,10 @@ __all__ = [
'divide', 'divide',
'duplicates_everseen', 'duplicates_everseen',
'duplicates_justseen', 'duplicates_justseen',
'classify_unique',
'exactly_n', 'exactly_n',
'filter_except', 'filter_except',
'filter_map',
'first', 'first',
'gray_product', 'gray_product',
'groupby_transform', 'groupby_transform',
@ -83,6 +85,7 @@ __all__ = [
'is_sorted', 'is_sorted',
'islice_extended', 'islice_extended',
'iterate', 'iterate',
'iter_suppress',
'last', 'last',
'locate', 'locate',
'longest_common_prefix', 'longest_common_prefix',
@ -198,14 +201,13 @@ def first(iterable, default=_marker):
``next(iter(iterable), default)``. ``next(iter(iterable), default)``.
""" """
try: for item in iterable:
return next(iter(iterable)) return item
except StopIteration as e:
if default is _marker: if default is _marker:
raise ValueError( raise ValueError(
'first() was called on an empty iterable, and no ' 'first() was called on an empty iterable, and no '
'default value was provided.' 'default value was provided.'
) from e )
return default return default
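Behavior sketch for the rewritten first() (assuming this more_itertools version): the loop-based form returns the first item, falls back to the default, and raises ValueError only when neither is available.

from more_itertools import first

assert first([3, 2, 1]) == 3
assert first([], default='missing') == 'missing'
# first([]) with no default raises ValueError rather than a bare StopIteration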
@ -582,6 +584,9 @@ def strictly_n(iterable, n, too_short=None, too_long=None):
>>> list(strictly_n(iterable, n)) >>> list(strictly_n(iterable, n))
['a', 'b', 'c', 'd'] ['a', 'b', 'c', 'd']
Note that the returned iterable must be consumed in order for the check to
be made.
By default, *too_short* and *too_long* are functions that raise By default, *too_short* and *too_long* are functions that raise
``ValueError``. ``ValueError``.
@ -919,7 +924,7 @@ def substrings_indexes(seq, reverse=False):
class bucket: class bucket:
"""Wrap *iterable* and return an object that buckets it iterable into """Wrap *iterable* and return an object that buckets the iterable into
child iterables based on a *key* function. child iterables based on a *key* function.
>>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3'] >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
@ -3222,6 +3227,8 @@ class time_limited:
stops if the time elapsed is greater than *limit_seconds*. If your time stops if the time elapsed is greater than *limit_seconds*. If your time
limit is 1 second, but it takes 2 seconds to generate the first item from limit is 1 second, but it takes 2 seconds to generate the first item from
the iterable, the function will run for 2 seconds and not yield anything. the iterable, the function will run for 2 seconds and not yield anything.
As a special case, when *limit_seconds* is zero, the iterator never
returns anything.
""" """
@ -3237,6 +3244,9 @@ class time_limited:
return self return self
def __next__(self): def __next__(self):
if self.limit_seconds == 0:
self.timed_out = True
raise StopIteration
item = next(self._iterable) item = next(self._iterable)
if monotonic() - self._start_time > self.limit_seconds: if monotonic() - self._start_time > self.limit_seconds:
self.timed_out = True self.timed_out = True
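A quick sketch of the new zero-limit special case (assuming the more_itertools version shown here):

from more_itertools import time_limited

limited = time_limited(0, range(5))
assert list(limited) == []         # a zero-second limit yields nothing
assert limited.timed_out is True   # and the flag is set immediately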
@ -3356,7 +3366,7 @@ def iequals(*iterables):
>>> iequals("abc", "acb") >>> iequals("abc", "acb")
False False
Not to be confused with :func:`all_equals`, which checks whether all Not to be confused with :func:`all_equal`, which checks whether all
elements of iterable are equal to each other. elements of iterable are equal to each other.
""" """
@ -3853,7 +3863,7 @@ def nth_permutation(iterable, r, index):
elif not 0 <= r < n: elif not 0 <= r < n:
raise ValueError raise ValueError
else: else:
c = factorial(n) // factorial(n - r) c = perm(n, r)
if index < 0: if index < 0:
index += c index += c
@ -3898,7 +3908,7 @@ def nth_combination_with_replacement(iterable, r, index):
if (r < 0) or (r > n): if (r < 0) or (r > n):
raise ValueError raise ValueError
c = factorial(n + r - 1) // (factorial(r) * factorial(n - 1)) c = comb(n + r - 1, r)
if index < 0: if index < 0:
index += c index += c
@ -3911,9 +3921,7 @@ def nth_combination_with_replacement(iterable, r, index):
while r: while r:
r -= 1 r -= 1
while n >= 0: while n >= 0:
num_combs = factorial(n + r - 1) // ( num_combs = comb(n + r - 1, r)
factorial(r) * factorial(n - 1)
)
if index < num_combs: if index < num_combs:
break break
n -= 1 n -= 1
@ -4015,9 +4023,9 @@ def combination_index(element, iterable):
for i, j in enumerate(reversed(indexes), start=1): for i, j in enumerate(reversed(indexes), start=1):
j = n - j j = n - j
if i <= j: if i <= j:
index += factorial(j) // (factorial(i) * factorial(j - i)) index += comb(j, i)
return factorial(n + 1) // (factorial(k + 1) * factorial(n - k)) - index return comb(n + 1, k + 1) - index
def combination_with_replacement_index(element, iterable): def combination_with_replacement_index(element, iterable):
@ -4057,7 +4065,7 @@ def combination_with_replacement_index(element, iterable):
break break
else: else:
raise ValueError( raise ValueError(
'element is not a combination with replacment of iterable' 'element is not a combination with replacement of iterable'
) )
n = len(pool) n = len(pool)
@ -4066,11 +4074,13 @@ def combination_with_replacement_index(element, iterable):
occupations[p] += 1 occupations[p] += 1
index = 0 index = 0
cumulative_sum = 0
for k in range(1, n): for k in range(1, n):
j = l + n - 1 - k - sum(occupations[:k]) cumulative_sum += occupations[k - 1]
j = l + n - 1 - k - cumulative_sum
i = n - k i = n - k
if i <= j: if i <= j:
index += factorial(j) // (factorial(i) * factorial(j - i)) index += comb(j, i)
return index return index
@ -4296,7 +4306,7 @@ def duplicates_everseen(iterable, key=None):
>>> list(duplicates_everseen('AaaBbbCccAaa', str.lower)) >>> list(duplicates_everseen('AaaBbbCccAaa', str.lower))
['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a'] ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a']
This function is analagous to :func:`unique_everseen` and is subject to This function is analogous to :func:`unique_everseen` and is subject to
the same performance considerations. the same performance considerations.
""" """
@ -4326,12 +4336,54 @@ def duplicates_justseen(iterable, key=None):
>>> list(duplicates_justseen('AaaBbbCccAaa', str.lower)) >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower))
['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a'] ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a']
This function is analagous to :func:`unique_justseen`. This function is analogous to :func:`unique_justseen`.
""" """
return flatten(g for _, g in groupby(iterable, key) for _ in g) return flatten(g for _, g in groupby(iterable, key) for _ in g)
def classify_unique(iterable, key=None):
"""Classify each element in terms of its uniqueness.
For each element in the input iterable, return a 3-tuple consisting of:
1. The element itself
2. ``False`` if the element is equal to the one preceding it in the input,
``True`` otherwise (i.e. the equivalent of :func:`unique_justseen`)
3. ``False`` if this element has been seen anywhere in the input before,
``True`` otherwise (i.e. the equivalent of :func:`unique_everseen`)
>>> list(classify_unique('otto')) # doctest: +NORMALIZE_WHITESPACE
[('o', True, True),
('t', True, True),
('t', False, False),
('o', True, False)]
This function is analogous to :func:`unique_everseen` and is subject to
the same performance considerations.
"""
seen_set = set()
seen_list = []
use_key = key is not None
previous = None
for i, element in enumerate(iterable):
k = key(element) if use_key else element
is_unique_justseen = not i or previous != k
previous = k
is_unique_everseen = False
try:
if k not in seen_set:
seen_set.add(k)
is_unique_everseen = True
except TypeError:
if k not in seen_list:
seen_list.append(k)
is_unique_everseen = True
yield element, is_unique_justseen, is_unique_everseen
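The docstring's claim that the second and third flags mirror unique_justseen and unique_everseen can be checked directly; a small sketch (assuming more_itertools 10.2.0 as vendored here):

from more_itertools import classify_unique, unique_everseen, unique_justseen

s = 'AaaBbA'
triples = list(classify_unique(s))
assert [e for e, justseen, _ in triples if justseen] == list(unique_justseen(s))
assert [e for e, _, everseen in triples if everseen] == list(unique_everseen(s))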
def minmax(iterable_or_value, *others, key=None, default=_marker): def minmax(iterable_or_value, *others, key=None, default=_marker):
"""Returns both the smallest and largest items in an iterable """Returns both the smallest and largest items in an iterable
or the largest of two or more arguments. or the largest of two or more arguments.
@ -4529,10 +4581,8 @@ def takewhile_inclusive(predicate, iterable):
:func:`takewhile` would return ``[1, 4]``. :func:`takewhile` would return ``[1, 4]``.
""" """
for x in iterable: for x in iterable:
if predicate(x):
yield x
else:
yield x yield x
if not predicate(x):
break break
@ -4567,3 +4617,40 @@ def outer_product(func, xs, ys, *args, **kwargs):
starmap(lambda x, y: func(x, y, *args, **kwargs), product(xs, ys)), starmap(lambda x, y: func(x, y, *args, **kwargs), product(xs, ys)),
n=len(ys), n=len(ys),
) )
def iter_suppress(iterable, *exceptions):
"""Yield each of the items from *iterable*. If the iteration raises one of
the specified *exceptions*, that exception will be suppressed and iteration
will stop.
>>> from itertools import chain
>>> def breaks_at_five(x):
... while True:
... if x >= 5:
... raise RuntimeError
... yield x
... x += 1
>>> it_1 = iter_suppress(breaks_at_five(1), RuntimeError)
>>> it_2 = iter_suppress(breaks_at_five(2), RuntimeError)
>>> list(chain(it_1, it_2))
[1, 2, 3, 4, 2, 3, 4]
"""
try:
yield from iterable
except exceptions:
return
def filter_map(func, iterable):
"""Apply *func* to every element of *iterable*, yielding only those which
are not ``None``.
>>> elems = ['1', 'a', '2', 'b', '3']
>>> list(filter_map(lambda s: int(s) if s.isnumeric() else None, elems))
[1, 2, 3]
"""
for x in iterable:
y = func(x)
if y is not None:
yield y
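For comparison, a sketch of what filter_map fuses into one pass (an illustration, not library code):

from more_itertools import filter_map

words = ['1', 'two', '3', '']
func = lambda s: int(s) if s.isdigit() else None
assert list(filter_map(func, words)) == [1, 3]
# equivalent, but in two passes: [y for y in map(func, words) if y is not None]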
@ -29,7 +29,7 @@ _U = TypeVar('_U')
_V = TypeVar('_V') _V = TypeVar('_V')
_W = TypeVar('_W') _W = TypeVar('_W')
_T_co = TypeVar('_T_co', covariant=True) _T_co = TypeVar('_T_co', covariant=True)
_GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[object]]) _GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[Any]])
_Raisable = BaseException | Type[BaseException] _Raisable = BaseException | Type[BaseException]
@type_check_only @type_check_only
@ -74,7 +74,7 @@ class peekable(Generic[_T], Iterator[_T]):
def __getitem__(self, index: slice) -> list[_T]: ... def __getitem__(self, index: slice) -> list[_T]: ...
def consumer(func: _GenFn) -> _GenFn: ... def consumer(func: _GenFn) -> _GenFn: ...
def ilen(iterable: Iterable[object]) -> int: ... def ilen(iterable: Iterable[_T]) -> int: ...
def iterate(func: Callable[[_T], _T], start: _T) -> Iterator[_T]: ... def iterate(func: Callable[[_T], _T], start: _T) -> Iterator[_T]: ...
def with_iter( def with_iter(
context_manager: ContextManager[Iterable[_T]], context_manager: ContextManager[Iterable[_T]],
@ -116,7 +116,7 @@ class bucket(Generic[_T, _U], Container[_U]):
self, self,
iterable: Iterable[_T], iterable: Iterable[_T],
key: Callable[[_T], _U], key: Callable[[_T], _U],
validator: Callable[[object], object] | None = ..., validator: Callable[[_U], object] | None = ...,
) -> None: ... ) -> None: ...
def __contains__(self, value: object) -> bool: ... def __contains__(self, value: object) -> bool: ...
def __iter__(self) -> Iterator[_U]: ... def __iter__(self) -> Iterator[_U]: ...
@ -383,7 +383,7 @@ def mark_ends(
iterable: Iterable[_T], iterable: Iterable[_T],
) -> Iterable[tuple[bool, bool, _T]]: ... ) -> Iterable[tuple[bool, bool, _T]]: ...
def locate( def locate(
iterable: Iterable[object], iterable: Iterable[_T],
pred: Callable[..., Any] = ..., pred: Callable[..., Any] = ...,
window_size: int | None = ..., window_size: int | None = ...,
) -> Iterator[int]: ... ) -> Iterator[int]: ...
@ -618,6 +618,9 @@ def duplicates_everseen(
def duplicates_justseen( def duplicates_justseen(
iterable: Iterable[_T], key: Callable[[_T], _U] | None = ... iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
) -> Iterator[_T]: ... ) -> Iterator[_T]: ...
def classify_unique(
iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
) -> Iterator[tuple[_T, bool, bool]]: ...
class _SupportsLessThan(Protocol): class _SupportsLessThan(Protocol):
def __lt__(self, __other: Any) -> bool: ... def __lt__(self, __other: Any) -> bool: ...
@ -662,9 +665,9 @@ def minmax(
def longest_common_prefix( def longest_common_prefix(
iterables: Iterable[Iterable[_T]], iterables: Iterable[Iterable[_T]],
) -> Iterator[_T]: ... ) -> Iterator[_T]: ...
def iequals(*iterables: Iterable[object]) -> bool: ... def iequals(*iterables: Iterable[Any]) -> bool: ...
def constrained_batches( def constrained_batches(
iterable: Iterable[object], iterable: Iterable[_T],
max_size: int, max_size: int,
max_count: int | None = ..., max_count: int | None = ...,
get_len: Callable[[_T], object] = ..., get_len: Callable[[_T], object] = ...,
@ -682,3 +685,11 @@ def outer_product(
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> Iterator[tuple[_V, ...]]: ... ) -> Iterator[tuple[_V, ...]]: ...
def iter_suppress(
iterable: Iterable[_T],
*exceptions: Type[BaseException],
) -> Iterator[_T]: ...
def filter_map(
func: Callable[[_T], _V | None],
iterable: Iterable[_T],
) -> Iterator[_V]: ...
@ -28,6 +28,7 @@ from itertools import (
zip_longest, zip_longest,
) )
from random import randrange, sample, choice from random import randrange, sample, choice
from sys import hexversion
__all__ = [ __all__ = [
'all_equal', 'all_equal',
@ -56,6 +57,7 @@ __all__ = [
'powerset', 'powerset',
'prepend', 'prepend',
'quantify', 'quantify',
'reshape',
'random_combination_with_replacement', 'random_combination_with_replacement',
'random_combination', 'random_combination',
'random_permutation', 'random_permutation',
@ -69,6 +71,7 @@ __all__ = [
'tabulate', 'tabulate',
'tail', 'tail',
'take', 'take',
'totient',
'transpose', 'transpose',
'triplewise', 'triplewise',
'unique_everseen', 'unique_everseen',
@ -492,7 +495,7 @@ def unique_everseen(iterable, key=None):
>>> list(unique_everseen(iterable, key=tuple)) # Faster >>> list(unique_everseen(iterable, key=tuple)) # Faster
[[1, 2], [2, 3]] [[1, 2], [2, 3]]
Similary, you may want to convert unhashable ``set`` objects with Similarly, you may want to convert unhashable ``set`` objects with
``key=frozenset``. For ``dict`` objects, ``key=frozenset``. For ``dict`` objects,
``key=lambda x: frozenset(x.items())`` can be used. ``key=lambda x: frozenset(x.items())`` can be used.
@ -524,6 +527,9 @@ def unique_justseen(iterable, key=None):
['A', 'B', 'C', 'A', 'D'] ['A', 'B', 'C', 'A', 'D']
""" """
if key is None:
return map(operator.itemgetter(0), groupby(iterable))
return map(next, map(operator.itemgetter(1), groupby(iterable, key))) return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
@ -817,35 +823,34 @@ def polynomial_from_roots(roots):
return list(reduce(convolve, factors, [1])) return list(reduce(convolve, factors, [1]))
def iter_index(iterable, value, start=0): def iter_index(iterable, value, start=0, stop=None):
"""Yield the index of each place in *iterable* that *value* occurs, """Yield the index of each place in *iterable* that *value* occurs,
beginning with index *start*. beginning with index *start* and ending before index *stop*.
See :func:`locate` for a more general means of finding the indexes See :func:`locate` for a more general means of finding the indexes
associated with particular values. associated with particular values.
>>> list(iter_index('AABCADEAF', 'A')) >>> list(iter_index('AABCADEAF', 'A'))
[0, 1, 4, 7] [0, 1, 4, 7]
>>> list(iter_index('AABCADEAF', 'A', 1)) # start index is inclusive
[1, 4, 7]
>>> list(iter_index('AABCADEAF', 'A', 1, 7)) # stop index is not inclusive
[1, 4]
""" """
try: seq_index = getattr(iterable, 'index', None)
seq_index = iterable.index if seq_index is None:
except AttributeError:
# Slow path for general iterables # Slow path for general iterables
it = islice(iterable, start, None) it = islice(iterable, start, stop)
i = start - 1 for i, element in enumerate(it, start):
try: if element is value or element == value:
while True:
i = i + operator.indexOf(it, value) + 1
yield i yield i
except ValueError:
pass
else: else:
# Fast path for sequences # Fast path for sequences
stop = len(iterable) if stop is None else stop
i = start - 1 i = start - 1
try: try:
while True: while True:
i = seq_index(value, i + 1) yield (i := seq_index(value, i + 1, stop))
yield i
except ValueError: except ValueError:
pass pass
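Both code paths honor the new stop parameter; a small sketch showing that the sequence fast path (via .index()) and the generic iterator path agree:

from more_itertools import iter_index

assert list(iter_index('AABCADEAF', 'A', start=1, stop=7)) == [1, 4]        # str: fast path
assert list(iter_index(iter('AABCADEAF'), 'A', start=1, stop=7)) == [1, 4]  # iterator: slow path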
@ -856,47 +861,52 @@ def sieve(n):
>>> list(sieve(30)) >>> list(sieve(30))
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29] [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
""" """
if n > 2:
yield 2
start = 3
data = bytearray((0, 1)) * (n // 2) data = bytearray((0, 1)) * (n // 2)
data[:3] = 0, 0, 0
limit = math.isqrt(n) + 1 limit = math.isqrt(n) + 1
for p in compress(range(limit), data): for p in iter_index(data, 1, start, limit):
yield from iter_index(data, 1, start, p * p)
data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p))) data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p)))
data[2] = 1 start = p * p
return iter_index(data, 1) if n > 2 else iter([]) yield from iter_index(data, 1, start)
def _batched(iterable, n): def _batched(iterable, n, *, strict=False):
"""Batch data into lists of length *n*. The last batch may be shorter. """Batch data into tuples of length *n*. If the number of items in
*iterable* is not divisible by *n*:
* The last batch will be shorter if *strict* is ``False``.
* :exc:`ValueError` will be raised if *strict* is ``True``.
>>> list(batched('ABCDEFG', 3)) >>> list(batched('ABCDEFG', 3))
[('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)] [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
On Python 3.12 and above, this is an alias for :func:`itertools.batched`. On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
""" """
if n < 1: if n < 1:
raise ValueError('n must be at least one') raise ValueError('n must be at least one')
it = iter(iterable) it = iter(iterable)
while True: while batch := tuple(islice(it, n)):
batch = tuple(islice(it, n)) if strict and len(batch) != n:
if not batch: raise ValueError('batched(): incomplete batch')
break
yield batch yield batch
try: if hexversion >= 0x30D00A2:
from itertools import batched as itertools_batched from itertools import batched as itertools_batched
except ImportError:
batched = _batched
else:
def batched(iterable, n): def batched(iterable, n, *, strict=False):
return itertools_batched(iterable, n) return itertools_batched(iterable, n, strict=strict)
else:
batched = _batched
batched.__doc__ = _batched.__doc__ batched.__doc__ = _batched.__doc__
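A sketch of the new strict flag (assuming this more_itertools version, or Python 3.13's itertools.batched):

from more_itertools import batched

print(list(batched('ABCDEF', 3, strict=True)))   # [('A', 'B', 'C'), ('D', 'E', 'F')]
try:
    list(batched('ABCDEFG', 3, strict=True))
except ValueError as exc:
    print(exc)                                   # incomplete batch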
def transpose(it): def transpose(it):
"""Swap the rows and columns of the input. """Swap the rows and columns of the input matrix.
>>> list(transpose([(1, 2, 3), (11, 22, 33)])) >>> list(transpose([(1, 2, 3), (11, 22, 33)]))
[(1, 11), (2, 22), (3, 33)] [(1, 11), (2, 22), (3, 33)]
@ -907,8 +917,20 @@ def transpose(it):
return _zip_strict(*it) return _zip_strict(*it)
def reshape(matrix, cols):
"""Reshape the 2-D input *matrix* to have a column count given by *cols*.
>>> matrix = [(0, 1), (2, 3), (4, 5)]
>>> cols = 3
>>> list(reshape(matrix, cols))
[(0, 1, 2), (3, 4, 5)]
"""
return batched(chain.from_iterable(matrix), cols)
def matmul(m1, m2): def matmul(m1, m2):
"""Multiply two matrices. """Multiply two matrices.
>>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)])) >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
[(49, 80), (41, 60)] [(49, 80), (41, 60)]
@ -921,13 +943,12 @@ def matmul(m1, m2):
def factor(n): def factor(n):
"""Yield the prime factors of n. """Yield the prime factors of n.
>>> list(factor(360)) >>> list(factor(360))
[2, 2, 2, 3, 3, 5] [2, 2, 2, 3, 3, 5]
""" """
for prime in sieve(math.isqrt(n) + 1): for prime in sieve(math.isqrt(n) + 1):
while True: while not n % prime:
if n % prime:
break
yield prime yield prime
n //= prime n //= prime
if n == 1: if n == 1:
@ -975,3 +996,17 @@ def polynomial_derivative(coefficients):
n = len(coefficients) n = len(coefficients)
powers = reversed(range(1, n)) powers = reversed(range(1, n))
return list(map(operator.mul, coefficients, powers)) return list(map(operator.mul, coefficients, powers))
def totient(n):
"""Return the count of natural numbers up to *n* that are coprime with *n*.
>>> totient(9)
6
>>> totient(12)
4
"""
for p in unique_justseen(factor(n)):
n = n // p * (p - 1)
return n
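The loop above applies Euler's product formula, φ(n) = n·∏(1 − 1/p) over the distinct prime factors p of n; a brute-force cross-check (illustrative only):

import math
from more_itertools import totient

def phi_brute(n):
    # count 1 <= k <= n that are coprime with n
    return sum(1 for k in range(1, n + 1) if math.gcd(k, n) == 1)

assert all(totient(n) == phi_brute(n) for n in range(1, 200))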
@ -14,6 +14,8 @@ from typing import (
# Type and type variable definitions # Type and type variable definitions
_T = TypeVar('_T') _T = TypeVar('_T')
_T1 = TypeVar('_T1')
_T2 = TypeVar('_T2')
_U = TypeVar('_U') _U = TypeVar('_U')
def take(n: int, iterable: Iterable[_T]) -> list[_T]: ... def take(n: int, iterable: Iterable[_T]) -> list[_T]: ...
@ -26,14 +28,14 @@ def consume(iterator: Iterable[_T], n: int | None = ...) -> None: ...
def nth(iterable: Iterable[_T], n: int) -> _T | None: ... def nth(iterable: Iterable[_T], n: int) -> _T | None: ...
@overload @overload
def nth(iterable: Iterable[_T], n: int, default: _U) -> _T | _U: ... def nth(iterable: Iterable[_T], n: int, default: _U) -> _T | _U: ...
def all_equal(iterable: Iterable[object]) -> bool: ... def all_equal(iterable: Iterable[_T]) -> bool: ...
def quantify( def quantify(
iterable: Iterable[_T], pred: Callable[[_T], bool] = ... iterable: Iterable[_T], pred: Callable[[_T], bool] = ...
) -> int: ... ) -> int: ...
def pad_none(iterable: Iterable[_T]) -> Iterator[_T | None]: ... def pad_none(iterable: Iterable[_T]) -> Iterator[_T | None]: ...
def padnone(iterable: Iterable[_T]) -> Iterator[_T | None]: ... def padnone(iterable: Iterable[_T]) -> Iterator[_T | None]: ...
def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]: ... def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]: ...
def dotproduct(vec1: Iterable[object], vec2: Iterable[object]) -> object: ... def dotproduct(vec1: Iterable[_T1], vec2: Iterable[_T2]) -> Any: ...
def flatten(listOfLists: Iterable[Iterable[_T]]) -> Iterator[_T]: ... def flatten(listOfLists: Iterable[Iterable[_T]]) -> Iterator[_T]: ...
def repeatfunc( def repeatfunc(
func: Callable[..., _U], times: int | None = ..., *args: Any func: Callable[..., _U], times: int | None = ..., *args: Any
@ -103,20 +105,24 @@ def sliding_window(
def subslices(iterable: Iterable[_T]) -> Iterator[list[_T]]: ... def subslices(iterable: Iterable[_T]) -> Iterator[list[_T]]: ...
def polynomial_from_roots(roots: Sequence[_T]) -> list[_T]: ... def polynomial_from_roots(roots: Sequence[_T]) -> list[_T]: ...
def iter_index( def iter_index(
iterable: Iterable[object], iterable: Iterable[_T],
value: Any, value: Any,
start: int | None = ..., start: int | None = ...,
stop: int | None = ...,
) -> Iterator[int]: ... ) -> Iterator[int]: ...
def sieve(n: int) -> Iterator[int]: ... def sieve(n: int) -> Iterator[int]: ...
def batched( def batched(
iterable: Iterable[_T], iterable: Iterable[_T], n: int, *, strict: bool = False
n: int,
) -> Iterator[tuple[_T]]: ... ) -> Iterator[tuple[_T]]: ...
def transpose( def transpose(
it: Iterable[Iterable[_T]], it: Iterable[Iterable[_T]],
) -> Iterator[tuple[_T, ...]]: ... ) -> Iterator[tuple[_T, ...]]: ...
def reshape(
matrix: Iterable[Iterable[_T]], cols: int
) -> Iterator[tuple[_T, ...]]: ...
def matmul(m1: Sequence[_T], m2: Sequence[_T]) -> Iterator[tuple[_T]]: ... def matmul(m1: Sequence[_T], m2: Sequence[_T]) -> Iterator[tuple[_T]]: ...
def factor(n: int) -> Iterator[int]: ... def factor(n: int) -> Iterator[int]: ...
def polynomial_eval(coefficients: Sequence[_T], x: _U) -> _U: ... def polynomial_eval(coefficients: Sequence[_T], x: _U) -> _U: ...
def sum_of_squares(it: Iterable[_T]) -> _T: ... def sum_of_squares(it: Iterable[_T]) -> _T: ...
def polynomial_derivative(coefficients: Sequence[_T]) -> list[_T]: ... def polynomial_derivative(coefficients: Sequence[_T]) -> list[_T]: ...
def totient(n: int) -> int: ...
@ -1,56 +1,114 @@
# flake8: noqa import typing
from . import dataclasses
from .annotated_types import create_model_from_namedtuple, create_model_from_typeddict from ._migration import getattr_migration
from .class_validators import root_validator, validator from .version import VERSION
from .config import BaseConfig, ConfigDict, Extra
from .decorator import validate_arguments if typing.TYPE_CHECKING:
from .env_settings import BaseSettings # import of virtually everything is supported via `__getattr__` below,
from .error_wrappers import ValidationError # but we need them here for type checking and IDE support
from .errors import * import pydantic_core
from .fields import Field, PrivateAttr, Required from pydantic_core.core_schema import (
from .main import * FieldSerializationInfo,
from .networks import * SerializationInfo,
from .parse import Protocol SerializerFunctionWrapHandler,
from .tools import * ValidationInfo,
from .types import * ValidatorFunctionWrapHandler,
from .version import VERSION, compiled )
from . import dataclasses
from ._internal._generate_schema import GenerateSchema as GenerateSchema
from .aliases import AliasChoices, AliasGenerator, AliasPath
from .annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
from .config import ConfigDict
from .errors import *
from .fields import Field, PrivateAttr, computed_field
from .functional_serializers import (
PlainSerializer,
SerializeAsAny,
WrapSerializer,
field_serializer,
model_serializer,
)
from .functional_validators import (
AfterValidator,
BeforeValidator,
InstanceOf,
PlainValidator,
SkipValidation,
WrapValidator,
field_validator,
model_validator,
)
from .json_schema import WithJsonSchema
from .main import *
from .networks import *
from .type_adapter import TypeAdapter
from .types import *
from .validate_call_decorator import validate_call
from .warnings import PydanticDeprecatedSince20, PydanticDeprecatedSince26, PydanticDeprecationWarning
# this encourages pycharm to import `ValidationError` from here, not pydantic_core
ValidationError = pydantic_core.ValidationError
from .deprecated.class_validators import root_validator, validator
from .deprecated.config import BaseConfig, Extra
from .deprecated.tools import *
from .root_model import RootModel
__version__ = VERSION __version__ = VERSION
__all__ = (
# WARNING __all__ from .errors is not included here, it will be removed as an export here in v2
# please use "from pydantic.errors import ..." instead
__all__ = [
# annotated types utils
'create_model_from_namedtuple',
'create_model_from_typeddict',
# dataclasses # dataclasses
'dataclasses', 'dataclasses',
# class_validators # functional validators
'field_validator',
'model_validator',
'AfterValidator',
'BeforeValidator',
'PlainValidator',
'WrapValidator',
'SkipValidation',
'InstanceOf',
# JSON Schema
'WithJsonSchema',
# deprecated V1 functional validators, these are imported via `__getattr__` below
'root_validator', 'root_validator',
'validator', 'validator',
# functional serializers
'field_serializer',
'model_serializer',
'PlainSerializer',
'SerializeAsAny',
'WrapSerializer',
# config # config
'BaseConfig',
'ConfigDict', 'ConfigDict',
# deprecated V1 config, these are imported via `__getattr__` below
'BaseConfig',
'Extra', 'Extra',
# decorator # validate_call
'validate_arguments', 'validate_call',
# env_settings # errors
'BaseSettings', 'PydanticErrorCodes',
# error_wrappers 'PydanticUserError',
'ValidationError', 'PydanticSchemaGenerationError',
'PydanticImportError',
'PydanticUndefinedAnnotation',
'PydanticInvalidForJsonSchema',
# fields # fields
'Field', 'Field',
'Required', 'computed_field',
'PrivateAttr',
# alias
'AliasChoices',
'AliasGenerator',
'AliasPath',
# main # main
'BaseModel', 'BaseModel',
'create_model', 'create_model',
'validate_model',
# network # network
'AnyUrl', 'AnyUrl',
'AnyHttpUrl', 'AnyHttpUrl',
'FileUrl', 'FileUrl',
'HttpUrl', 'HttpUrl',
'stricturl', 'UrlConstraints',
'EmailStr', 'EmailStr',
'NameEmail', 'NameEmail',
'IPvAnyAddress', 'IPvAnyAddress',
@ -62,48 +120,38 @@ __all__ = [
'RedisDsn', 'RedisDsn',
'MongoDsn', 'MongoDsn',
'KafkaDsn', 'KafkaDsn',
'NatsDsn',
'MySQLDsn',
'MariaDBDsn',
'validate_email', 'validate_email',
# parse # root_model
'Protocol', 'RootModel',
# tools # deprecated tools, these are imported via `__getattr__` below
'parse_file_as',
'parse_obj_as', 'parse_obj_as',
'parse_raw_as',
'schema_of', 'schema_of',
'schema_json_of', 'schema_json_of',
# types # types
'NoneStr', 'Strict',
'NoneBytes',
'StrBytes',
'NoneStrBytes',
'StrictStr', 'StrictStr',
'ConstrainedBytes',
'conbytes', 'conbytes',
'ConstrainedList',
'conlist', 'conlist',
'ConstrainedSet',
'conset', 'conset',
'ConstrainedFrozenSet',
'confrozenset', 'confrozenset',
'ConstrainedStr',
'constr', 'constr',
'PyObject', 'StringConstraints',
'ConstrainedInt', 'ImportString',
'conint', 'conint',
'PositiveInt', 'PositiveInt',
'NegativeInt', 'NegativeInt',
'NonNegativeInt', 'NonNegativeInt',
'NonPositiveInt', 'NonPositiveInt',
'ConstrainedFloat',
'confloat', 'confloat',
'PositiveFloat', 'PositiveFloat',
'NegativeFloat', 'NegativeFloat',
'NonNegativeFloat', 'NonNegativeFloat',
'NonPositiveFloat', 'NonPositiveFloat',
'FiniteFloat', 'FiniteFloat',
'ConstrainedDecimal',
'condecimal', 'condecimal',
'ConstrainedDate',
'condate', 'condate',
'UUID1', 'UUID1',
'UUID3', 'UUID3',
@ -111,9 +159,8 @@ __all__ = [
'UUID5', 'UUID5',
'FilePath', 'FilePath',
'DirectoryPath', 'DirectoryPath',
'NewPath',
'Json', 'Json',
'JsonWrapper',
'SecretField',
'SecretStr', 'SecretStr',
'SecretBytes', 'SecretBytes',
'StrictBool', 'StrictBool',
@ -121,11 +168,221 @@ __all__ = [
'StrictInt', 'StrictInt',
'StrictFloat', 'StrictFloat',
'PaymentCardNumber', 'PaymentCardNumber',
'PrivateAttr',
'ByteSize', 'ByteSize',
'PastDate', 'PastDate',
'FutureDate', 'FutureDate',
'PastDatetime',
'FutureDatetime',
'AwareDatetime',
'NaiveDatetime',
'AllowInfNan',
'EncoderProtocol',
'EncodedBytes',
'EncodedStr',
'Base64Encoder',
'Base64Bytes',
'Base64Str',
'Base64UrlBytes',
'Base64UrlStr',
'GetPydanticSchema',
'Tag',
'Discriminator',
'JsonValue',
# type_adapter
'TypeAdapter',
# version # version
'compiled', '__version__',
'VERSION', 'VERSION',
] # warnings
'PydanticDeprecatedSince20',
'PydanticDeprecatedSince26',
'PydanticDeprecationWarning',
# annotated handlers
'GetCoreSchemaHandler',
'GetJsonSchemaHandler',
# generate schema from ._internal
'GenerateSchema',
# pydantic_core
'ValidationError',
'ValidationInfo',
'SerializationInfo',
'ValidatorFunctionWrapHandler',
'FieldSerializationInfo',
'SerializerFunctionWrapHandler',
'OnErrorOmit',
)
# A mapping of {<member name>: (package, <module name>)} defining dynamic imports
_dynamic_imports: 'dict[str, tuple[str, str]]' = {
'dataclasses': (__package__, '__module__'),
# functional validators
'field_validator': (__package__, '.functional_validators'),
'model_validator': (__package__, '.functional_validators'),
'AfterValidator': (__package__, '.functional_validators'),
'BeforeValidator': (__package__, '.functional_validators'),
'PlainValidator': (__package__, '.functional_validators'),
'WrapValidator': (__package__, '.functional_validators'),
'SkipValidation': (__package__, '.functional_validators'),
'InstanceOf': (__package__, '.functional_validators'),
# JSON Schema
'WithJsonSchema': (__package__, '.json_schema'),
# functional serializers
'field_serializer': (__package__, '.functional_serializers'),
'model_serializer': (__package__, '.functional_serializers'),
'PlainSerializer': (__package__, '.functional_serializers'),
'SerializeAsAny': (__package__, '.functional_serializers'),
'WrapSerializer': (__package__, '.functional_serializers'),
# config
'ConfigDict': (__package__, '.config'),
# validate call
'validate_call': (__package__, '.validate_call_decorator'),
# errors
'PydanticErrorCodes': (__package__, '.errors'),
'PydanticUserError': (__package__, '.errors'),
'PydanticSchemaGenerationError': (__package__, '.errors'),
'PydanticImportError': (__package__, '.errors'),
'PydanticUndefinedAnnotation': (__package__, '.errors'),
'PydanticInvalidForJsonSchema': (__package__, '.errors'),
# fields
'Field': (__package__, '.fields'),
'computed_field': (__package__, '.fields'),
'PrivateAttr': (__package__, '.fields'),
# alias
'AliasChoices': (__package__, '.aliases'),
'AliasGenerator': (__package__, '.aliases'),
'AliasPath': (__package__, '.aliases'),
# main
'BaseModel': (__package__, '.main'),
'create_model': (__package__, '.main'),
# network
'AnyUrl': (__package__, '.networks'),
'AnyHttpUrl': (__package__, '.networks'),
'FileUrl': (__package__, '.networks'),
'HttpUrl': (__package__, '.networks'),
'UrlConstraints': (__package__, '.networks'),
'EmailStr': (__package__, '.networks'),
'NameEmail': (__package__, '.networks'),
'IPvAnyAddress': (__package__, '.networks'),
'IPvAnyInterface': (__package__, '.networks'),
'IPvAnyNetwork': (__package__, '.networks'),
'PostgresDsn': (__package__, '.networks'),
'CockroachDsn': (__package__, '.networks'),
'AmqpDsn': (__package__, '.networks'),
'RedisDsn': (__package__, '.networks'),
'MongoDsn': (__package__, '.networks'),
'KafkaDsn': (__package__, '.networks'),
'NatsDsn': (__package__, '.networks'),
'MySQLDsn': (__package__, '.networks'),
'MariaDBDsn': (__package__, '.networks'),
'validate_email': (__package__, '.networks'),
# root_model
'RootModel': (__package__, '.root_model'),
# types
'Strict': (__package__, '.types'),
'StrictStr': (__package__, '.types'),
'conbytes': (__package__, '.types'),
'conlist': (__package__, '.types'),
'conset': (__package__, '.types'),
'confrozenset': (__package__, '.types'),
'constr': (__package__, '.types'),
'StringConstraints': (__package__, '.types'),
'ImportString': (__package__, '.types'),
'conint': (__package__, '.types'),
'PositiveInt': (__package__, '.types'),
'NegativeInt': (__package__, '.types'),
'NonNegativeInt': (__package__, '.types'),
'NonPositiveInt': (__package__, '.types'),
'confloat': (__package__, '.types'),
'PositiveFloat': (__package__, '.types'),
'NegativeFloat': (__package__, '.types'),
'NonNegativeFloat': (__package__, '.types'),
'NonPositiveFloat': (__package__, '.types'),
'FiniteFloat': (__package__, '.types'),
'condecimal': (__package__, '.types'),
'condate': (__package__, '.types'),
'UUID1': (__package__, '.types'),
'UUID3': (__package__, '.types'),
'UUID4': (__package__, '.types'),
'UUID5': (__package__, '.types'),
'FilePath': (__package__, '.types'),
'DirectoryPath': (__package__, '.types'),
'NewPath': (__package__, '.types'),
'Json': (__package__, '.types'),
'SecretStr': (__package__, '.types'),
'SecretBytes': (__package__, '.types'),
'StrictBool': (__package__, '.types'),
'StrictBytes': (__package__, '.types'),
'StrictInt': (__package__, '.types'),
'StrictFloat': (__package__, '.types'),
'PaymentCardNumber': (__package__, '.types'),
'ByteSize': (__package__, '.types'),
'PastDate': (__package__, '.types'),
'FutureDate': (__package__, '.types'),
'PastDatetime': (__package__, '.types'),
'FutureDatetime': (__package__, '.types'),
'AwareDatetime': (__package__, '.types'),
'NaiveDatetime': (__package__, '.types'),
'AllowInfNan': (__package__, '.types'),
'EncoderProtocol': (__package__, '.types'),
'EncodedBytes': (__package__, '.types'),
'EncodedStr': (__package__, '.types'),
'Base64Encoder': (__package__, '.types'),
'Base64Bytes': (__package__, '.types'),
'Base64Str': (__package__, '.types'),
'Base64UrlBytes': (__package__, '.types'),
'Base64UrlStr': (__package__, '.types'),
'GetPydanticSchema': (__package__, '.types'),
'Tag': (__package__, '.types'),
'Discriminator': (__package__, '.types'),
'JsonValue': (__package__, '.types'),
'OnErrorOmit': (__package__, '.types'),
# type_adapter
'TypeAdapter': (__package__, '.type_adapter'),
# warnings
'PydanticDeprecatedSince20': (__package__, '.warnings'),
'PydanticDeprecatedSince26': (__package__, '.warnings'),
'PydanticDeprecationWarning': (__package__, '.warnings'),
# annotated handlers
'GetCoreSchemaHandler': (__package__, '.annotated_handlers'),
'GetJsonSchemaHandler': (__package__, '.annotated_handlers'),
# generate schema from ._internal
'GenerateSchema': (__package__, '._internal._generate_schema'),
# pydantic_core stuff
'ValidationError': ('pydantic_core', '.'),
'ValidationInfo': ('pydantic_core', '.core_schema'),
'SerializationInfo': ('pydantic_core', '.core_schema'),
'ValidatorFunctionWrapHandler': ('pydantic_core', '.core_schema'),
'FieldSerializationInfo': ('pydantic_core', '.core_schema'),
'SerializerFunctionWrapHandler': ('pydantic_core', '.core_schema'),
# deprecated, mostly not included in __all__
'root_validator': (__package__, '.deprecated.class_validators'),
'validator': (__package__, '.deprecated.class_validators'),
'BaseConfig': (__package__, '.deprecated.config'),
'Extra': (__package__, '.deprecated.config'),
'parse_obj_as': (__package__, '.deprecated.tools'),
'schema_of': (__package__, '.deprecated.tools'),
'schema_json_of': (__package__, '.deprecated.tools'),
'FieldValidationInfo': ('pydantic_core', '.core_schema'),
}
_getattr_migration = getattr_migration(__name__)
def __getattr__(attr_name: str) -> object:
dynamic_attr = _dynamic_imports.get(attr_name)
if dynamic_attr is None:
return _getattr_migration(attr_name)
package, module_name = dynamic_attr
from importlib import import_module
if module_name == '__module__':
return import_module(f'.{attr_name}', package=package)
else:
module = import_module(module_name, package=package)
return getattr(module, attr_name)
def __dir__() -> 'list[str]':
return list(__all__)
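The table-driven __getattr__ above is an instance of PEP 562 module-level attribute access; a standalone sketch of the same pattern with hypothetical names (not pydantic's actual code):

# hypothetical mypkg/__init__.py
_lazy = {'HeavyThing': '.heavy_module'}  # member name -> submodule (hypothetical)

def __getattr__(name):
    try:
        module_name = _lazy[name]
    except KeyError:
        raise AttributeError(f'module {__name__!r} has no attribute {name!r}') from None
    from importlib import import_module
    module = import_module(module_name, package=__package__)
    return getattr(module, name)

def __dir__():
    return sorted(set(globals()) | set(_lazy))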
@@ -0,0 +1,322 @@
from __future__ import annotations as _annotations
import warnings
from contextlib import contextmanager
from typing import (
TYPE_CHECKING,
Any,
Callable,
cast,
)
from pydantic_core import core_schema
from typing_extensions import (
Literal,
Self,
)
from ..aliases import AliasGenerator
from ..config import ConfigDict, ExtraValues, JsonDict, JsonEncoder, JsonSchemaExtraCallable
from ..errors import PydanticUserError
from ..warnings import PydanticDeprecatedSince20
if not TYPE_CHECKING:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
# and https://youtrack.jetbrains.com/issue/PY-51428
DeprecationWarning = PydanticDeprecatedSince20
if TYPE_CHECKING:
from .._internal._schema_generation_shared import GenerateSchema
DEPRECATION_MESSAGE = 'Support for class-based `config` is deprecated, use ConfigDict instead.'
class ConfigWrapper:
"""Internal wrapper for Config which exposes ConfigDict items as attributes."""
__slots__ = ('config_dict',)
config_dict: ConfigDict
# all annotations are copied directly from ConfigDict, and should be kept up to date, a test will fail if they
# stop matching
title: str | None
str_to_lower: bool
str_to_upper: bool
str_strip_whitespace: bool
str_min_length: int
str_max_length: int | None
extra: ExtraValues | None
frozen: bool
populate_by_name: bool
use_enum_values: bool
validate_assignment: bool
arbitrary_types_allowed: bool
from_attributes: bool
# whether to use the actual key provided in the data (e.g. alias or first alias for "field required" errors) instead of field_names
# to construct error `loc`s, default `True`
loc_by_alias: bool
alias_generator: Callable[[str], str] | AliasGenerator | None
ignored_types: tuple[type, ...]
allow_inf_nan: bool
json_schema_extra: JsonDict | JsonSchemaExtraCallable | None
json_encoders: dict[type[object], JsonEncoder] | None
# new in V2
strict: bool
# whether instances of models and dataclasses (including subclass instances) should re-validate, default 'never'
revalidate_instances: Literal['always', 'never', 'subclass-instances']
ser_json_timedelta: Literal['iso8601', 'float']
ser_json_bytes: Literal['utf8', 'base64']
ser_json_inf_nan: Literal['null', 'constants']
# whether to validate default values during validation, default False
validate_default: bool
validate_return: bool
protected_namespaces: tuple[str, ...]
hide_input_in_errors: bool
defer_build: bool
plugin_settings: dict[str, object] | None
schema_generator: type[GenerateSchema] | None
json_schema_serialization_defaults_required: bool
json_schema_mode_override: Literal['validation', 'serialization', None]
coerce_numbers_to_str: bool
regex_engine: Literal['rust-regex', 'python-re']
validation_error_cause: bool
def __init__(self, config: ConfigDict | dict[str, Any] | type[Any] | None, *, check: bool = True):
if check:
self.config_dict = prepare_config(config)
else:
self.config_dict = cast(ConfigDict, config)
@classmethod
def for_model(cls, bases: tuple[type[Any], ...], namespace: dict[str, Any], kwargs: dict[str, Any]) -> Self:
"""Build a new `ConfigWrapper` instance for a `BaseModel`.
        The config wrapper is built based on (in descending order of priority):
- options from `kwargs`
- options from the `namespace`
- options from the base classes (`bases`)
Args:
bases: A tuple of base classes.
namespace: The namespace of the class being created.
kwargs: The kwargs passed to the class being created.
Returns:
A `ConfigWrapper` instance for `BaseModel`.
"""
config_new = ConfigDict()
for base in bases:
config = getattr(base, 'model_config', None)
if config:
config_new.update(config.copy())
config_class_from_namespace = namespace.get('Config')
config_dict_from_namespace = namespace.get('model_config')
if config_class_from_namespace and config_dict_from_namespace:
raise PydanticUserError('"Config" and "model_config" cannot be used together', code='config-both')
config_from_namespace = config_dict_from_namespace or prepare_config(config_class_from_namespace)
config_new.update(config_from_namespace)
for k in list(kwargs.keys()):
if k in config_keys:
config_new[k] = kwargs.pop(k)
return cls(config_new)
# we don't show `__getattr__` to type checkers so missing attributes cause errors
if not TYPE_CHECKING: # pragma: no branch
def __getattr__(self, name: str) -> Any:
try:
return self.config_dict[name]
except KeyError:
try:
return config_defaults[name]
except KeyError:
raise AttributeError(f'Config has no attribute {name!r}') from None
def core_config(self, obj: Any) -> core_schema.CoreConfig:
"""Create a pydantic-core config, `obj` is just used to populate `title` if not set in config.
Pass `obj=None` if you do not want to attempt to infer the `title`.
We don't use getattr here since we don't want to populate with defaults.
Args:
obj: An object used to populate `title` if not set in config.
Returns:
A `CoreConfig` object created from config.
"""
def dict_not_none(**kwargs: Any) -> Any:
return {k: v for k, v in kwargs.items() if v is not None}
core_config = core_schema.CoreConfig(
**dict_not_none(
title=self.config_dict.get('title') or (obj and obj.__name__),
extra_fields_behavior=self.config_dict.get('extra'),
allow_inf_nan=self.config_dict.get('allow_inf_nan'),
populate_by_name=self.config_dict.get('populate_by_name'),
str_strip_whitespace=self.config_dict.get('str_strip_whitespace'),
str_to_lower=self.config_dict.get('str_to_lower'),
str_to_upper=self.config_dict.get('str_to_upper'),
strict=self.config_dict.get('strict'),
ser_json_timedelta=self.config_dict.get('ser_json_timedelta'),
ser_json_bytes=self.config_dict.get('ser_json_bytes'),
ser_json_inf_nan=self.config_dict.get('ser_json_inf_nan'),
from_attributes=self.config_dict.get('from_attributes'),
loc_by_alias=self.config_dict.get('loc_by_alias'),
revalidate_instances=self.config_dict.get('revalidate_instances'),
validate_default=self.config_dict.get('validate_default'),
str_max_length=self.config_dict.get('str_max_length'),
str_min_length=self.config_dict.get('str_min_length'),
hide_input_in_errors=self.config_dict.get('hide_input_in_errors'),
coerce_numbers_to_str=self.config_dict.get('coerce_numbers_to_str'),
regex_engine=self.config_dict.get('regex_engine'),
validation_error_cause=self.config_dict.get('validation_error_cause'),
)
)
return core_config
def __repr__(self):
c = ', '.join(f'{k}={v!r}' for k, v in self.config_dict.items())
return f'ConfigWrapper({c})'
class ConfigWrapperStack:
"""A stack of `ConfigWrapper` instances."""
def __init__(self, config_wrapper: ConfigWrapper):
self._config_wrapper_stack: list[ConfigWrapper] = [config_wrapper]
@property
def tail(self) -> ConfigWrapper:
return self._config_wrapper_stack[-1]
@contextmanager
def push(self, config_wrapper: ConfigWrapper | ConfigDict | None):
if config_wrapper is None:
yield
return
if not isinstance(config_wrapper, ConfigWrapper):
config_wrapper = ConfigWrapper(config_wrapper, check=False)
self._config_wrapper_stack.append(config_wrapper)
try:
yield
finally:
self._config_wrapper_stack.pop()
config_defaults = ConfigDict(
title=None,
str_to_lower=False,
str_to_upper=False,
str_strip_whitespace=False,
str_min_length=0,
str_max_length=None,
# let the model / dataclass decide how to handle it
extra=None,
frozen=False,
populate_by_name=False,
use_enum_values=False,
validate_assignment=False,
arbitrary_types_allowed=False,
from_attributes=False,
loc_by_alias=True,
alias_generator=None,
ignored_types=(),
allow_inf_nan=True,
json_schema_extra=None,
strict=False,
revalidate_instances='never',
ser_json_timedelta='iso8601',
ser_json_bytes='utf8',
ser_json_inf_nan='null',
validate_default=False,
validate_return=False,
protected_namespaces=('model_',),
hide_input_in_errors=False,
json_encoders=None,
defer_build=False,
plugin_settings=None,
schema_generator=None,
json_schema_serialization_defaults_required=False,
json_schema_mode_override=None,
coerce_numbers_to_str=False,
regex_engine='rust-regex',
validation_error_cause=False,
)
def prepare_config(config: ConfigDict | dict[str, Any] | type[Any] | None) -> ConfigDict:
"""Create a `ConfigDict` instance from an existing dict, a class (e.g. old class-based config) or None.
Args:
config: The input config.
Returns:
A ConfigDict object created from config.
"""
if config is None:
return ConfigDict()
if not isinstance(config, dict):
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)
config = {k: getattr(config, k) for k in dir(config) if not k.startswith('__')}
config_dict = cast(ConfigDict, config)
check_deprecated(config_dict)
return config_dict
config_keys = set(ConfigDict.__annotations__.keys())
V2_REMOVED_KEYS = {
'allow_mutation',
'error_msg_templates',
'fields',
'getter_dict',
'smart_union',
'underscore_attrs_are_private',
'json_loads',
'json_dumps',
'copy_on_model_validation',
'post_init_call',
}
V2_RENAMED_KEYS = {
'allow_population_by_field_name': 'populate_by_name',
'anystr_lower': 'str_to_lower',
'anystr_strip_whitespace': 'str_strip_whitespace',
'anystr_upper': 'str_to_upper',
'keep_untouched': 'ignored_types',
'max_anystr_length': 'str_max_length',
'min_anystr_length': 'str_min_length',
'orm_mode': 'from_attributes',
'schema_extra': 'json_schema_extra',
'validate_all': 'validate_default',
}
def check_deprecated(config_dict: ConfigDict) -> None:
"""Check for deprecated config keys and warn the user.
Args:
config_dict: The input config.
"""
deprecated_removed_keys = V2_REMOVED_KEYS & config_dict.keys()
deprecated_renamed_keys = V2_RENAMED_KEYS.keys() & config_dict.keys()
if deprecated_removed_keys or deprecated_renamed_keys:
renamings = {k: V2_RENAMED_KEYS[k] for k in sorted(deprecated_renamed_keys)}
renamed_bullets = [f'* {k!r} has been renamed to {v!r}' for k, v in renamings.items()]
removed_bullets = [f'* {k!r} has been removed' for k in sorted(deprecated_removed_keys)]
message = '\n'.join(['Valid config keys have changed in V2:'] + renamed_bullets + removed_bullets)
warnings.warn(message, UserWarning)
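A brief usage sketch (not part of the diff) tying the helpers above together; it assumes the vendored module is importable as pydantic._internal._config, and OldStyleConfig is an invented V1-style class used to trigger the deprecation paths:

from pydantic._internal._config import ConfigWrapper, prepare_config

class OldStyleConfig:
    # V1 key name; prepare_config()/check_deprecated() warn that class-based config is
    # deprecated and that this key was renamed to populate_by_name
    allow_population_by_field_name = True

config_dict = prepare_config(OldStyleConfig)
print(config_dict)              # {'allow_population_by_field_name': True}

wrapper = ConfigWrapper({'title': 'Demo', 'strict': True})
print(wrapper.title)            # 'Demo' -- read from config_dict
print(wrapper.ser_json_bytes)   # 'utf8' -- unset, so it falls back to config_defaults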


@@ -0,0 +1,92 @@
from __future__ import annotations as _annotations
import typing
from typing import Any
import typing_extensions
if typing.TYPE_CHECKING:
from ._schema_generation_shared import (
CoreSchemaOrField as CoreSchemaOrField,
)
from ._schema_generation_shared import (
GetJsonSchemaFunction,
)
class CoreMetadata(typing_extensions.TypedDict, total=False):
"""A `TypedDict` for holding the metadata dict of the schema.
Attributes:
pydantic_js_functions: List of JSON schema functions.
pydantic_js_prefer_positional_arguments: Whether JSON schema generator will
prefer positional over keyword arguments for an 'arguments' schema.
"""
pydantic_js_functions: list[GetJsonSchemaFunction]
pydantic_js_annotation_functions: list[GetJsonSchemaFunction]
# If `pydantic_js_prefer_positional_arguments` is True, the JSON schema generator will
# prefer positional over keyword arguments for an 'arguments' schema.
pydantic_js_prefer_positional_arguments: bool | None
pydantic_typed_dict_cls: type[Any] | None # TODO: Consider moving this into the pydantic-core TypedDictSchema
class CoreMetadataHandler:
"""Because the metadata field in pydantic_core is of type `Any`, we can't assume much about its contents.
This class is used to interact with the metadata field on a CoreSchema object in a consistent
way throughout pydantic.
"""
__slots__ = ('_schema',)
def __init__(self, schema: CoreSchemaOrField):
self._schema = schema
metadata = schema.get('metadata')
if metadata is None:
schema['metadata'] = CoreMetadata()
elif not isinstance(metadata, dict):
raise TypeError(f'CoreSchema metadata should be a dict; got {metadata!r}.')
@property
def metadata(self) -> CoreMetadata:
"""Retrieves the metadata dict from the schema, initializing it to a dict if it is None
and raises an error if it is not a dict.
"""
metadata = self._schema.get('metadata')
if metadata is None:
self._schema['metadata'] = metadata = CoreMetadata()
if not isinstance(metadata, dict):
raise TypeError(f'CoreSchema metadata should be a dict; got {metadata!r}.')
return metadata
def build_metadata_dict(
*, # force keyword arguments to make it easier to modify this signature in a backwards-compatible way
js_functions: list[GetJsonSchemaFunction] | None = None,
js_annotation_functions: list[GetJsonSchemaFunction] | None = None,
js_prefer_positional_arguments: bool | None = None,
typed_dict_cls: type[Any] | None = None,
initial_metadata: Any | None = None,
) -> Any:
"""Builds a dict to use as the metadata field of a CoreSchema object in a manner that is consistent
with the CoreMetadataHandler class.
"""
if initial_metadata is not None and not isinstance(initial_metadata, dict):
raise TypeError(f'CoreSchema metadata should be a dict; got {initial_metadata!r}.')
metadata = CoreMetadata(
pydantic_js_functions=js_functions or [],
pydantic_js_annotation_functions=js_annotation_functions or [],
pydantic_js_prefer_positional_arguments=js_prefer_positional_arguments,
pydantic_typed_dict_cls=typed_dict_cls,
)
metadata = {k: v for k, v in metadata.items() if v is not None}
if initial_metadata is not None:
metadata = {**initial_metadata, **metadata}
return metadata
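A short sketch (not from the diff) of these pieces working together; it assumes pydantic_core and the vendored pydantic._internal._core_metadata are importable, and add_format is an invented JSON-schema function:

from pydantic_core import core_schema
from pydantic._internal._core_metadata import CoreMetadataHandler, build_metadata_dict

def add_format(schema_or_field, handler):
    # a GetJsonSchemaFunction-style callable: tweak the generated JSON schema
    json_schema = handler(schema_or_field)
    json_schema['format'] = 'int64'
    return json_schema

schema = core_schema.int_schema(metadata=build_metadata_dict(js_functions=[add_format]))
print(CoreMetadataHandler(schema).metadata['pydantic_js_functions'])  # [<function add_format ...>]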


@@ -0,0 +1,570 @@
from __future__ import annotations
import os
from collections import defaultdict
from typing import (
Any,
Callable,
Hashable,
TypeVar,
Union,
)
from pydantic_core import CoreSchema, core_schema
from pydantic_core import validate_core_schema as _validate_core_schema
from typing_extensions import TypeAliasType, TypeGuard, get_args, get_origin
from . import _repr
from ._typing_extra import is_generic_alias
AnyFunctionSchema = Union[
core_schema.AfterValidatorFunctionSchema,
core_schema.BeforeValidatorFunctionSchema,
core_schema.WrapValidatorFunctionSchema,
core_schema.PlainValidatorFunctionSchema,
]
FunctionSchemaWithInnerSchema = Union[
core_schema.AfterValidatorFunctionSchema,
core_schema.BeforeValidatorFunctionSchema,
core_schema.WrapValidatorFunctionSchema,
]
CoreSchemaField = Union[
core_schema.ModelField, core_schema.DataclassField, core_schema.TypedDictField, core_schema.ComputedField
]
CoreSchemaOrField = Union[core_schema.CoreSchema, CoreSchemaField]
_CORE_SCHEMA_FIELD_TYPES = {'typed-dict-field', 'dataclass-field', 'model-field', 'computed-field'}
_FUNCTION_WITH_INNER_SCHEMA_TYPES = {'function-before', 'function-after', 'function-wrap'}
_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'set', 'frozenset'}
_DEFINITIONS_CACHE_METADATA_KEY = 'pydantic.definitions_cache'
TAGGED_UNION_TAG_KEY = 'pydantic.internal.tagged_union_tag'
"""
Used in a `Tag` schema to specify the tag used for a discriminated union.
"""
HAS_INVALID_SCHEMAS_METADATA_KEY = 'pydantic.internal.invalid'
"""Used to mark a schema that is invalid because it refers to a definition that was not yet defined when the
schema was first encountered.
"""
def is_core_schema(
schema: CoreSchemaOrField,
) -> TypeGuard[CoreSchema]:
return schema['type'] not in _CORE_SCHEMA_FIELD_TYPES
def is_core_schema_field(
schema: CoreSchemaOrField,
) -> TypeGuard[CoreSchemaField]:
return schema['type'] in _CORE_SCHEMA_FIELD_TYPES
def is_function_with_inner_schema(
schema: CoreSchemaOrField,
) -> TypeGuard[FunctionSchemaWithInnerSchema]:
return schema['type'] in _FUNCTION_WITH_INNER_SCHEMA_TYPES
def is_list_like_schema_with_items_schema(
schema: CoreSchema,
) -> TypeGuard[core_schema.ListSchema | core_schema.SetSchema | core_schema.FrozenSetSchema]:
return schema['type'] in _LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES
def get_type_ref(type_: type[Any], args_override: tuple[type[Any], ...] | None = None) -> str:
"""Produces the ref to be used for this type by pydantic_core's core schemas.
This `args_override` argument was added for the purpose of creating valid recursive references
when creating generic models without needing to create a concrete class.
"""
origin = get_origin(type_) or type_
args = get_args(type_) if is_generic_alias(type_) else (args_override or ())
generic_metadata = getattr(type_, '__pydantic_generic_metadata__', None)
if generic_metadata:
origin = generic_metadata['origin'] or origin
args = generic_metadata['args'] or args
module_name = getattr(origin, '__module__', '<No __module__>')
if isinstance(origin, TypeAliasType):
type_ref = f'{module_name}.{origin.__name__}:{id(origin)}'
else:
try:
qualname = getattr(origin, '__qualname__', f'<No __qualname__: {origin}>')
except Exception:
qualname = getattr(origin, '__qualname__', '<No __qualname__>')
type_ref = f'{module_name}.{qualname}:{id(origin)}'
arg_refs: list[str] = []
for arg in args:
if isinstance(arg, str):
# Handle string literals as a special case; we may be able to remove this special handling if we
# wrap them in a ForwardRef at some point.
arg_ref = f'{arg}:str-{id(arg)}'
else:
arg_ref = f'{_repr.display_as_type(arg)}:{id(arg)}'
arg_refs.append(arg_ref)
if arg_refs:
type_ref = f'{type_ref}[{",".join(arg_refs)}]'
return type_ref
def get_ref(s: core_schema.CoreSchema) -> None | str:
"""Get the ref from the schema if it has one.
This exists just for type checking to work correctly.
"""
return s.get('ref', None)
def collect_definitions(schema: core_schema.CoreSchema) -> dict[str, core_schema.CoreSchema]:
defs: dict[str, CoreSchema] = {}
def _record_valid_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
ref = get_ref(s)
if ref:
defs[ref] = s
return recurse(s, _record_valid_refs)
walk_core_schema(schema, _record_valid_refs)
return defs
def define_expected_missing_refs(
schema: core_schema.CoreSchema, allowed_missing_refs: set[str]
) -> core_schema.CoreSchema | None:
if not allowed_missing_refs:
# in this case, there are no missing refs to potentially substitute, so there's no need to walk the schema
# this is a common case (will be hit for all non-generic models), so it's worth optimizing for
return None
refs = collect_definitions(schema).keys()
expected_missing_refs = allowed_missing_refs.difference(refs)
if expected_missing_refs:
definitions: list[core_schema.CoreSchema] = [
# TODO: Replace this with a (new) CoreSchema that, if present at any level, makes validation fail
# Issue: https://github.com/pydantic/pydantic-core/issues/619
core_schema.none_schema(ref=ref, metadata={HAS_INVALID_SCHEMAS_METADATA_KEY: True})
for ref in expected_missing_refs
]
return core_schema.definitions_schema(schema, definitions)
return None
def collect_invalid_schemas(schema: core_schema.CoreSchema) -> bool:
invalid = False
def _is_schema_valid(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
nonlocal invalid
if 'metadata' in s:
metadata = s['metadata']
if HAS_INVALID_SCHEMAS_METADATA_KEY in metadata:
invalid = metadata[HAS_INVALID_SCHEMAS_METADATA_KEY]
return s
return recurse(s, _is_schema_valid)
walk_core_schema(schema, _is_schema_valid)
return invalid
T = TypeVar('T')
Recurse = Callable[[core_schema.CoreSchema, 'Walk'], core_schema.CoreSchema]
Walk = Callable[[core_schema.CoreSchema, Recurse], core_schema.CoreSchema]
# TODO: Should we move _WalkCoreSchema into pydantic_core proper?
# Issue: https://github.com/pydantic/pydantic-core/issues/615
class _WalkCoreSchema:
def __init__(self):
self._schema_type_to_method = self._build_schema_type_to_method()
def _build_schema_type_to_method(self) -> dict[core_schema.CoreSchemaType, Recurse]:
mapping: dict[core_schema.CoreSchemaType, Recurse] = {}
key: core_schema.CoreSchemaType
for key in get_args(core_schema.CoreSchemaType):
method_name = f"handle_{key.replace('-', '_')}_schema"
mapping[key] = getattr(self, method_name, self._handle_other_schemas)
return mapping
def walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
return f(schema, self._walk)
def _walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
schema = self._schema_type_to_method[schema['type']](schema.copy(), f)
ser_schema: core_schema.SerSchema | None = schema.get('serialization') # type: ignore
if ser_schema:
schema['serialization'] = self._handle_ser_schemas(ser_schema, f)
return schema
def _handle_other_schemas(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
sub_schema = schema.get('schema', None)
if sub_schema is not None:
schema['schema'] = self.walk(sub_schema, f) # type: ignore
return schema
def _handle_ser_schemas(self, ser_schema: core_schema.SerSchema, f: Walk) -> core_schema.SerSchema:
schema: core_schema.CoreSchema | None = ser_schema.get('schema', None)
if schema is not None:
ser_schema['schema'] = self.walk(schema, f) # type: ignore
return_schema: core_schema.CoreSchema | None = ser_schema.get('return_schema', None)
if return_schema is not None:
ser_schema['return_schema'] = self.walk(return_schema, f) # type: ignore
return ser_schema
def handle_definitions_schema(self, schema: core_schema.DefinitionsSchema, f: Walk) -> core_schema.CoreSchema:
new_definitions: list[core_schema.CoreSchema] = []
for definition in schema['definitions']:
if 'schema_ref' in definition and 'ref' in definition:
# This indicates a purposely indirect reference
# We want to keep such references around for implications related to JSON schema, etc.:
new_definitions.append(definition)
# However, we still need to walk the referenced definition:
self.walk(definition, f)
continue
updated_definition = self.walk(definition, f)
if 'ref' in updated_definition:
# If the updated definition schema doesn't have a 'ref', it shouldn't go in the definitions
# This is most likely to happen due to replacing something with a definition reference, in
# which case it should certainly not go in the definitions list
new_definitions.append(updated_definition)
new_inner_schema = self.walk(schema['schema'], f)
if not new_definitions and len(schema) == 3:
# This means we'd be returning a "trivial" definitions schema that just wrapped the inner schema
return new_inner_schema
new_schema = schema.copy()
new_schema['schema'] = new_inner_schema
new_schema['definitions'] = new_definitions
return new_schema
def handle_list_schema(self, schema: core_schema.ListSchema, f: Walk) -> core_schema.CoreSchema:
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema
def handle_set_schema(self, schema: core_schema.SetSchema, f: Walk) -> core_schema.CoreSchema:
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema
def handle_frozenset_schema(self, schema: core_schema.FrozenSetSchema, f: Walk) -> core_schema.CoreSchema:
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema
def handle_generator_schema(self, schema: core_schema.GeneratorSchema, f: Walk) -> core_schema.CoreSchema:
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema
def handle_tuple_schema(self, schema: core_schema.TupleSchema, f: Walk) -> core_schema.CoreSchema:
schema['items_schema'] = [self.walk(v, f) for v in schema['items_schema']]
return schema
def handle_dict_schema(self, schema: core_schema.DictSchema, f: Walk) -> core_schema.CoreSchema:
keys_schema = schema.get('keys_schema')
if keys_schema is not None:
schema['keys_schema'] = self.walk(keys_schema, f)
values_schema = schema.get('values_schema')
if values_schema:
schema['values_schema'] = self.walk(values_schema, f)
return schema
def handle_function_schema(self, schema: AnyFunctionSchema, f: Walk) -> core_schema.CoreSchema:
if not is_function_with_inner_schema(schema):
return schema
schema['schema'] = self.walk(schema['schema'], f)
return schema
def handle_union_schema(self, schema: core_schema.UnionSchema, f: Walk) -> core_schema.CoreSchema:
new_choices: list[CoreSchema | tuple[CoreSchema, str]] = []
for v in schema['choices']:
if isinstance(v, tuple):
new_choices.append((self.walk(v[0], f), v[1]))
else:
new_choices.append(self.walk(v, f))
schema['choices'] = new_choices
return schema
def handle_tagged_union_schema(self, schema: core_schema.TaggedUnionSchema, f: Walk) -> core_schema.CoreSchema:
new_choices: dict[Hashable, core_schema.CoreSchema] = {}
for k, v in schema['choices'].items():
new_choices[k] = v if isinstance(v, (str, int)) else self.walk(v, f)
schema['choices'] = new_choices
return schema
def handle_chain_schema(self, schema: core_schema.ChainSchema, f: Walk) -> core_schema.CoreSchema:
schema['steps'] = [self.walk(v, f) for v in schema['steps']]
return schema
def handle_lax_or_strict_schema(self, schema: core_schema.LaxOrStrictSchema, f: Walk) -> core_schema.CoreSchema:
schema['lax_schema'] = self.walk(schema['lax_schema'], f)
schema['strict_schema'] = self.walk(schema['strict_schema'], f)
return schema
def handle_json_or_python_schema(self, schema: core_schema.JsonOrPythonSchema, f: Walk) -> core_schema.CoreSchema:
schema['json_schema'] = self.walk(schema['json_schema'], f)
schema['python_schema'] = self.walk(schema['python_schema'], f)
return schema
def handle_model_fields_schema(self, schema: core_schema.ModelFieldsSchema, f: Walk) -> core_schema.CoreSchema:
extras_schema = schema.get('extras_schema')
if extras_schema is not None:
schema['extras_schema'] = self.walk(extras_schema, f)
replaced_fields: dict[str, core_schema.ModelField] = {}
replaced_computed_fields: list[core_schema.ComputedField] = []
for computed_field in schema.get('computed_fields', ()):
replaced_field = computed_field.copy()
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
replaced_computed_fields.append(replaced_field)
if replaced_computed_fields:
schema['computed_fields'] = replaced_computed_fields
for k, v in schema['fields'].items():
replaced_field = v.copy()
replaced_field['schema'] = self.walk(v['schema'], f)
replaced_fields[k] = replaced_field
schema['fields'] = replaced_fields
return schema
def handle_typed_dict_schema(self, schema: core_schema.TypedDictSchema, f: Walk) -> core_schema.CoreSchema:
extras_schema = schema.get('extras_schema')
if extras_schema is not None:
schema['extras_schema'] = self.walk(extras_schema, f)
replaced_computed_fields: list[core_schema.ComputedField] = []
for computed_field in schema.get('computed_fields', ()):
replaced_field = computed_field.copy()
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
replaced_computed_fields.append(replaced_field)
if replaced_computed_fields:
schema['computed_fields'] = replaced_computed_fields
replaced_fields: dict[str, core_schema.TypedDictField] = {}
for k, v in schema['fields'].items():
replaced_field = v.copy()
replaced_field['schema'] = self.walk(v['schema'], f)
replaced_fields[k] = replaced_field
schema['fields'] = replaced_fields
return schema
def handle_dataclass_args_schema(self, schema: core_schema.DataclassArgsSchema, f: Walk) -> core_schema.CoreSchema:
replaced_fields: list[core_schema.DataclassField] = []
replaced_computed_fields: list[core_schema.ComputedField] = []
for computed_field in schema.get('computed_fields', ()):
replaced_field = computed_field.copy()
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
replaced_computed_fields.append(replaced_field)
if replaced_computed_fields:
schema['computed_fields'] = replaced_computed_fields
for field in schema['fields']:
replaced_field = field.copy()
replaced_field['schema'] = self.walk(field['schema'], f)
replaced_fields.append(replaced_field)
schema['fields'] = replaced_fields
return schema
def handle_arguments_schema(self, schema: core_schema.ArgumentsSchema, f: Walk) -> core_schema.CoreSchema:
replaced_arguments_schema: list[core_schema.ArgumentsParameter] = []
for param in schema['arguments_schema']:
replaced_param = param.copy()
replaced_param['schema'] = self.walk(param['schema'], f)
replaced_arguments_schema.append(replaced_param)
schema['arguments_schema'] = replaced_arguments_schema
if 'var_args_schema' in schema:
schema['var_args_schema'] = self.walk(schema['var_args_schema'], f)
if 'var_kwargs_schema' in schema:
schema['var_kwargs_schema'] = self.walk(schema['var_kwargs_schema'], f)
return schema
def handle_call_schema(self, schema: core_schema.CallSchema, f: Walk) -> core_schema.CoreSchema:
schema['arguments_schema'] = self.walk(schema['arguments_schema'], f)
if 'return_schema' in schema:
schema['return_schema'] = self.walk(schema['return_schema'], f)
return schema
_dispatch = _WalkCoreSchema().walk
def walk_core_schema(schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
"""Recursively traverse a CoreSchema.
Args:
schema (core_schema.CoreSchema): The CoreSchema to process, it will not be modified.
f (Walk): A function to apply. This function takes two arguments:
1. The current CoreSchema that is being processed
(not the same one you passed into this function, one level down).
2. The "next" `f` to call. This lets you for example use `f=functools.partial(some_method, some_context)`
to pass data down the recursive calls without using globals or other mutable state.
Returns:
core_schema.CoreSchema: A processed CoreSchema.
"""
return f(schema.copy(), _dispatch)
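# Illustrative sketch (not part of the upstream file): a typical `f` for
# walk_core_schema receives each schema plus a `recurse` callback and must call it
# to keep descending. The helper name `_demo_collect_schema_types` is invented here.
def _demo_collect_schema_types(schema: core_schema.CoreSchema) -> set[str]:
    seen: set[str] = set()

    def collect(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
        seen.add(s['type'])
        return recurse(s, collect)  # keep walking into nested schemas

    walk_core_schema(schema, collect)
    # e.g. a list[int] schema yields {'list', 'int'}; dict[str, int] yields {'dict', 'str', 'int'}
    return seen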
def simplify_schema_references(schema: core_schema.CoreSchema) -> core_schema.CoreSchema: # noqa: C901
definitions: dict[str, core_schema.CoreSchema] = {}
ref_counts: dict[str, int] = defaultdict(int)
involved_in_recursion: dict[str, bool] = {}
current_recursion_ref_count: dict[str, int] = defaultdict(int)
def collect_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if s['type'] == 'definitions':
for definition in s['definitions']:
ref = get_ref(definition)
assert ref is not None
if ref not in definitions:
definitions[ref] = definition
recurse(definition, collect_refs)
return recurse(s['schema'], collect_refs)
else:
ref = get_ref(s)
if ref is not None:
new = recurse(s, collect_refs)
new_ref = get_ref(new)
if new_ref:
definitions[new_ref] = new
return core_schema.definition_reference_schema(schema_ref=ref)
else:
return recurse(s, collect_refs)
schema = walk_core_schema(schema, collect_refs)
def count_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if s['type'] != 'definition-ref':
return recurse(s, count_refs)
ref = s['schema_ref']
ref_counts[ref] += 1
if ref_counts[ref] >= 2:
# If this model is involved in a recursion this should be detected
# on its second encounter, we can safely stop the walk here.
if current_recursion_ref_count[ref] != 0:
involved_in_recursion[ref] = True
return s
current_recursion_ref_count[ref] += 1
recurse(definitions[ref], count_refs)
current_recursion_ref_count[ref] -= 1
return s
schema = walk_core_schema(schema, count_refs)
assert all(c == 0 for c in current_recursion_ref_count.values()), 'this is a bug! please report it'
def can_be_inlined(s: core_schema.DefinitionReferenceSchema, ref: str) -> bool:
if ref_counts[ref] > 1:
return False
if involved_in_recursion.get(ref, False):
return False
if 'serialization' in s:
return False
if 'metadata' in s:
metadata = s['metadata']
for k in (
'pydantic_js_functions',
'pydantic_js_annotation_functions',
'pydantic.internal.union_discriminator',
):
if k in metadata:
# we need to keep this as a ref
return False
return True
def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if s['type'] == 'definition-ref':
ref = s['schema_ref']
# Check if the reference is only used once, not involved in recursion and does not have
# any extra keys (like 'serialization')
if can_be_inlined(s, ref):
# Inline the reference by replacing the reference with the actual schema
new = definitions.pop(ref)
ref_counts[ref] -= 1 # because we just replaced it!
# put all other keys that were on the def-ref schema into the inlined version
# in particular this is needed for `serialization`
if 'serialization' in s:
new['serialization'] = s['serialization']
s = recurse(new, inline_refs)
return s
else:
return recurse(s, inline_refs)
else:
return recurse(s, inline_refs)
schema = walk_core_schema(schema, inline_refs)
def_values = [v for v in definitions.values() if ref_counts[v['ref']] > 0] # type: ignore
if def_values:
schema = core_schema.definitions_schema(schema=schema, definitions=def_values)
return schema
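# Illustrative sketch (not upstream code): a definition that is referenced exactly once
# and is not involved in recursion gets inlined by simplify_schema_references. The
# helper name `_demo_inline_single_use_ref` is invented here.
def _demo_inline_single_use_ref() -> core_schema.CoreSchema:
    schema = core_schema.definitions_schema(
        core_schema.list_schema(core_schema.definition_reference_schema('demo-int')),
        [core_schema.int_schema(ref='demo-int')],
    )
    simplified = simplify_schema_references(schema)
    # simplified['type'] == 'list' and simplified['items_schema']['type'] == 'int':
    # the single-use 'demo-int' reference was replaced by the int schema itself
    return simplified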
def _strip_metadata(schema: CoreSchema) -> CoreSchema:
def strip_metadata(s: CoreSchema, recurse: Recurse) -> CoreSchema:
s = s.copy()
s.pop('metadata', None)
if s['type'] == 'model-fields':
s = s.copy()
s['fields'] = {k: v.copy() for k, v in s['fields'].items()}
for field_name, field_schema in s['fields'].items():
field_schema.pop('metadata', None)
s['fields'][field_name] = field_schema
computed_fields = s.get('computed_fields', None)
if computed_fields:
s['computed_fields'] = [cf.copy() for cf in computed_fields]
for cf in computed_fields:
cf.pop('metadata', None)
else:
s.pop('computed_fields', None)
elif s['type'] == 'model':
# remove some defaults
if s.get('custom_init', True) is False:
s.pop('custom_init')
if s.get('root_model', True) is False:
s.pop('root_model')
if {'title'}.issuperset(s.get('config', {}).keys()):
s.pop('config', None)
return recurse(s, strip_metadata)
return walk_core_schema(schema, strip_metadata)
def pretty_print_core_schema(
schema: CoreSchema,
include_metadata: bool = False,
) -> None:
"""Pretty print a CoreSchema using rich.
This is intended for debugging purposes.
Args:
schema: The CoreSchema to print.
include_metadata: Whether to include metadata in the output. Defaults to `False`.
"""
from rich import print # type: ignore # install it manually in your dev env
if not include_metadata:
schema = _strip_metadata(schema)
return print(schema)
def validate_core_schema(schema: CoreSchema) -> CoreSchema:
if 'PYDANTIC_SKIP_VALIDATING_CORE_SCHEMAS' in os.environ:
return schema
return _validate_core_schema(schema)


@@ -0,0 +1,225 @@
"""Private logic for creating pydantic dataclasses."""
from __future__ import annotations as _annotations
import dataclasses
import typing
import warnings
from functools import partial, wraps
from typing import Any, Callable, ClassVar
from pydantic_core import (
ArgsKwargs,
SchemaSerializer,
SchemaValidator,
core_schema,
)
from typing_extensions import TypeGuard
from ..errors import PydanticUndefinedAnnotation
from ..fields import FieldInfo
from ..plugin._schema_validator import create_schema_validator
from ..warnings import PydanticDeprecatedSince20
from . import _config, _decorators, _typing_extra
from ._fields import collect_dataclass_fields
from ._generate_schema import GenerateSchema
from ._generics import get_standard_typevars_map
from ._mock_val_ser import set_dataclass_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._signature import generate_pydantic_signature
if typing.TYPE_CHECKING:
from ..config import ConfigDict
class StandardDataclass(typing.Protocol):
__dataclass_fields__: ClassVar[dict[str, Any]]
__dataclass_params__: ClassVar[Any] # in reality `dataclasses._DataclassParams`
__post_init__: ClassVar[Callable[..., None]]
def __init__(self, *args: object, **kwargs: object) -> None:
pass
class PydanticDataclass(StandardDataclass, typing.Protocol):
"""A protocol containing attributes only available once a class has been decorated as a Pydantic dataclass.
Attributes:
__pydantic_config__: Pydantic-specific configuration settings for the dataclass.
__pydantic_complete__: Whether dataclass building is completed, or if there are still undefined fields.
__pydantic_core_schema__: The pydantic-core schema used to build the SchemaValidator and SchemaSerializer.
__pydantic_decorators__: Metadata containing the decorators defined on the dataclass.
__pydantic_fields__: Metadata about the fields defined on the dataclass.
__pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the dataclass.
__pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the dataclass.
"""
__pydantic_config__: ClassVar[ConfigDict]
__pydantic_complete__: ClassVar[bool]
__pydantic_core_schema__: ClassVar[core_schema.CoreSchema]
__pydantic_decorators__: ClassVar[_decorators.DecoratorInfos]
__pydantic_fields__: ClassVar[dict[str, FieldInfo]]
__pydantic_serializer__: ClassVar[SchemaSerializer]
__pydantic_validator__: ClassVar[SchemaValidator]
else:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
# and https://youtrack.jetbrains.com/issue/PY-51428
DeprecationWarning = PydanticDeprecatedSince20
def set_dataclass_fields(cls: type[StandardDataclass], types_namespace: dict[str, Any] | None = None) -> None:
"""Collect and set `cls.__pydantic_fields__`.
Args:
cls: The class.
types_namespace: The types namespace, defaults to `None`.
"""
typevars_map = get_standard_typevars_map(cls)
fields = collect_dataclass_fields(cls, types_namespace, typevars_map=typevars_map)
cls.__pydantic_fields__ = fields # type: ignore
def complete_dataclass(
cls: type[Any],
config_wrapper: _config.ConfigWrapper,
*,
raise_errors: bool = True,
types_namespace: dict[str, Any] | None,
) -> bool:
"""Finish building a pydantic dataclass.
This logic is called on a class which has already been wrapped in `dataclasses.dataclass()`.
This is somewhat analogous to `pydantic._internal._model_construction.complete_model_class`.
Args:
cls: The class.
config_wrapper: The config wrapper instance.
raise_errors: Whether to raise errors, defaults to `True`.
types_namespace: The types namespace.
Returns:
`True` if building a pydantic dataclass is successfully completed, `False` otherwise.
Raises:
        PydanticUndefinedAnnotation: If `raise_errors` is `True` and an undefined annotation is encountered.
"""
if hasattr(cls, '__post_init_post_parse__'):
warnings.warn(
'Support for `__post_init_post_parse__` has been dropped, the method will not be called', DeprecationWarning
)
if types_namespace is None:
types_namespace = _typing_extra.get_cls_types_namespace(cls)
set_dataclass_fields(cls, types_namespace)
typevars_map = get_standard_typevars_map(cls)
gen_schema = GenerateSchema(
config_wrapper,
types_namespace,
typevars_map,
)
# This needs to be called before we change the __init__
sig = generate_pydantic_signature(
init=cls.__init__,
fields=cls.__pydantic_fields__, # type: ignore
config_wrapper=config_wrapper,
is_dataclass=True,
)
# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied.
def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -> None:
__tracebackhide__ = True
s = __dataclass_self__
s.__pydantic_validator__.validate_python(ArgsKwargs(args, kwargs), self_instance=s)
__init__.__qualname__ = f'{cls.__qualname__}.__init__'
cls.__init__ = __init__ # type: ignore
cls.__pydantic_config__ = config_wrapper.config_dict # type: ignore
cls.__signature__ = sig # type: ignore
get_core_schema = getattr(cls, '__get_pydantic_core_schema__', None)
try:
if get_core_schema:
schema = get_core_schema(
cls,
CallbackGetCoreSchemaHandler(
partial(gen_schema.generate_schema, from_dunder_get_core_schema=False),
gen_schema,
ref_mode='unpack',
),
)
else:
schema = gen_schema.generate_schema(cls, from_dunder_get_core_schema=False)
except PydanticUndefinedAnnotation as e:
if raise_errors:
raise
set_dataclass_mocks(cls, cls.__name__, f'`{e.name}`')
return False
core_config = config_wrapper.core_config(cls)
try:
schema = gen_schema.clean_schema(schema)
except gen_schema.CollectedInvalid:
set_dataclass_mocks(cls, cls.__name__, 'all referenced types')
return False
# We are about to set all the remaining required properties expected for this cast;
# __pydantic_decorators__ and __pydantic_fields__ should already be set
cls = typing.cast('type[PydanticDataclass]', cls)
# debug(schema)
cls.__pydantic_core_schema__ = schema
cls.__pydantic_validator__ = validator = create_schema_validator(
schema, cls, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings
)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
if config_wrapper.validate_assignment:
@wraps(cls.__setattr__)
def validated_setattr(instance: Any, __field: str, __value: str) -> None:
validator.validate_assignment(instance, __field, __value)
cls.__setattr__ = validated_setattr.__get__(None, cls) # type: ignore
return True
def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
"""Returns True if a class is a stdlib dataclass and *not* a pydantic dataclass.
We check that
- `_cls` is a dataclass
- `_cls` does not inherit from a processed pydantic dataclass (and thus have a `__pydantic_validator__`)
- `_cls` does not have any annotations that are not dataclass fields
e.g.
```py
import dataclasses
import pydantic.dataclasses
@dataclasses.dataclass
class A:
x: int
@pydantic.dataclasses.dataclass
class B(A):
y: int
```
In this case, when we first check `B`, we make an extra check and look at the annotations ('y'),
which won't be a superset of all the dataclass fields (only the stdlib fields i.e. 'x')
Args:
        _cls: The class.
Returns:
`True` if the class is a stdlib dataclass, `False` otherwise.
"""
return (
dataclasses.is_dataclass(_cls)
and not hasattr(_cls, '__pydantic_validator__')
and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {})))
)
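A runnable version of the docstring example above (assuming the vendored modules are importable as pydantic.dataclasses and pydantic._internal._dataclasses), showing why only A counts as a plain stdlib dataclass:

import dataclasses
import pydantic.dataclasses
from pydantic._internal._dataclasses import is_builtin_dataclass

@dataclasses.dataclass
class A:
    x: int

@pydantic.dataclasses.dataclass
class B(A):
    y: int

print(is_builtin_dataclass(A))  # True  -- plain stdlib dataclass
print(is_builtin_dataclass(B))  # False -- already carries __pydantic_validator__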


@@ -0,0 +1,791 @@
"""Logic related to validators applied to models etc. via the `@field_validator` and `@model_validator` decorators."""
from __future__ import annotations as _annotations
from collections import deque
from dataclasses import dataclass, field
from functools import cached_property, partial, partialmethod
from inspect import Parameter, Signature, isdatadescriptor, ismethoddescriptor, signature
from itertools import islice
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Iterable, TypeVar, Union
from pydantic_core import PydanticUndefined, core_schema
from typing_extensions import Literal, TypeAlias, is_typeddict
from ..errors import PydanticUserError
from ._core_utils import get_type_ref
from ._internal_dataclass import slots_true
from ._typing_extra import get_function_type_hints
if TYPE_CHECKING:
from ..fields import ComputedFieldInfo
from ..functional_validators import FieldValidatorModes
@dataclass(**slots_true)
class ValidatorDecoratorInfo:
"""A container for data from `@validator` so that we can access it
while building the pydantic-core schema.
Attributes:
decorator_repr: A class variable representing the decorator string, '@validator'.
fields: A tuple of field names the validator should be called on.
mode: The proposed validator mode.
each_item: For complex objects (sets, lists etc.) whether to validate individual
elements rather than the whole object.
always: Whether this method and other validators should be called even if the value is missing.
check_fields: Whether to check that the fields actually exist on the model.
"""
decorator_repr: ClassVar[str] = '@validator'
fields: tuple[str, ...]
mode: Literal['before', 'after']
each_item: bool
always: bool
check_fields: bool | None
@dataclass(**slots_true)
class FieldValidatorDecoratorInfo:
"""A container for data from `@field_validator` so that we can access it
while building the pydantic-core schema.
Attributes:
decorator_repr: A class variable representing the decorator string, '@field_validator'.
fields: A tuple of field names the validator should be called on.
mode: The proposed validator mode.
check_fields: Whether to check that the fields actually exist on the model.
"""
decorator_repr: ClassVar[str] = '@field_validator'
fields: tuple[str, ...]
mode: FieldValidatorModes
check_fields: bool | None
@dataclass(**slots_true)
class RootValidatorDecoratorInfo:
"""A container for data from `@root_validator` so that we can access it
while building the pydantic-core schema.
Attributes:
decorator_repr: A class variable representing the decorator string, '@root_validator'.
mode: The proposed validator mode.
"""
decorator_repr: ClassVar[str] = '@root_validator'
mode: Literal['before', 'after']
@dataclass(**slots_true)
class FieldSerializerDecoratorInfo:
"""A container for data from `@field_serializer` so that we can access it
while building the pydantic-core schema.
Attributes:
decorator_repr: A class variable representing the decorator string, '@field_serializer'.
fields: A tuple of field names the serializer should be called on.
mode: The proposed serializer mode.
return_type: The type of the serializer's return value.
when_used: The serialization condition. Accepts a string with values `'always'`, `'unless-none'`, `'json'`,
and `'json-unless-none'`.
check_fields: Whether to check that the fields actually exist on the model.
"""
decorator_repr: ClassVar[str] = '@field_serializer'
fields: tuple[str, ...]
mode: Literal['plain', 'wrap']
return_type: Any
when_used: core_schema.WhenUsed
check_fields: bool | None
@dataclass(**slots_true)
class ModelSerializerDecoratorInfo:
"""A container for data from `@model_serializer` so that we can access it
while building the pydantic-core schema.
Attributes:
decorator_repr: A class variable representing the decorator string, '@model_serializer'.
mode: The proposed serializer mode.
return_type: The type of the serializer's return value.
when_used: The serialization condition. Accepts a string with values `'always'`, `'unless-none'`, `'json'`,
and `'json-unless-none'`.
"""
decorator_repr: ClassVar[str] = '@model_serializer'
mode: Literal['plain', 'wrap']
return_type: Any
when_used: core_schema.WhenUsed
@dataclass(**slots_true)
class ModelValidatorDecoratorInfo:
"""A container for data from `@model_validator` so that we can access it
while building the pydantic-core schema.
Attributes:
        decorator_repr: A class variable representing the decorator string, '@model_validator'.
        mode: The proposed validator mode.
"""
decorator_repr: ClassVar[str] = '@model_validator'
mode: Literal['wrap', 'before', 'after']
DecoratorInfo: TypeAlias = """Union[
ValidatorDecoratorInfo,
FieldValidatorDecoratorInfo,
RootValidatorDecoratorInfo,
FieldSerializerDecoratorInfo,
ModelSerializerDecoratorInfo,
ModelValidatorDecoratorInfo,
ComputedFieldInfo,
]"""
ReturnType = TypeVar('ReturnType')
DecoratedType: TypeAlias = (
'Union[classmethod[Any, Any, ReturnType], staticmethod[Any, ReturnType], Callable[..., ReturnType], property]'
)
@dataclass # can't use slots here since we set attributes on `__post_init__`
class PydanticDescriptorProxy(Generic[ReturnType]):
"""Wrap a classmethod, staticmethod, property or unbound function
and act as a descriptor that allows us to detect decorated items
from the class' attributes.
This class' __get__ returns the wrapped item's __get__ result,
which makes it transparent for classmethods and staticmethods.
Attributes:
wrapped: The decorator that has to be wrapped.
decorator_info: The decorator info.
shim: A wrapper function to wrap V1 style function.
"""
wrapped: DecoratedType[ReturnType]
decorator_info: DecoratorInfo
shim: Callable[[Callable[..., Any]], Callable[..., Any]] | None = None
def __post_init__(self):
for attr in 'setter', 'deleter':
if hasattr(self.wrapped, attr):
f = partial(self._call_wrapped_attr, name=attr)
setattr(self, attr, f)
def _call_wrapped_attr(self, func: Callable[[Any], None], *, name: str) -> PydanticDescriptorProxy[ReturnType]:
self.wrapped = getattr(self.wrapped, name)(func)
return self
def __get__(self, obj: object | None, obj_type: type[object] | None = None) -> PydanticDescriptorProxy[ReturnType]:
try:
return self.wrapped.__get__(obj, obj_type)
except AttributeError:
# not a descriptor, e.g. a partial object
return self.wrapped # type: ignore[return-value]
def __set_name__(self, instance: Any, name: str) -> None:
if hasattr(self.wrapped, '__set_name__'):
self.wrapped.__set_name__(instance, name) # pyright: ignore[reportFunctionMemberAccess]
def __getattr__(self, __name: str) -> Any:
"""Forward checks for __isabstractmethod__ and such."""
return getattr(self.wrapped, __name)
DecoratorInfoType = TypeVar('DecoratorInfoType', bound=DecoratorInfo)
@dataclass(**slots_true)
class Decorator(Generic[DecoratorInfoType]):
"""A generic container class to join together the decorator metadata
(metadata from decorator itself, which we have when the
decorator is called but not when we are building the core-schema)
and the bound function (which we have after the class itself is created).
Attributes:
cls_ref: The class ref.
cls_var_name: The decorated function name.
func: The decorated function.
shim: A wrapper function to wrap V1 style function.
info: The decorator info.
"""
cls_ref: str
cls_var_name: str
func: Callable[..., Any]
shim: Callable[[Any], Any] | None
info: DecoratorInfoType
@staticmethod
def build(
cls_: Any,
*,
cls_var_name: str,
shim: Callable[[Any], Any] | None,
info: DecoratorInfoType,
) -> Decorator[DecoratorInfoType]:
"""Build a new decorator.
Args:
cls_: The class.
cls_var_name: The decorated function name.
shim: A wrapper function to wrap V1 style function.
info: The decorator info.
Returns:
The new decorator instance.
"""
func = get_attribute_from_bases(cls_, cls_var_name)
if shim is not None:
func = shim(func)
func = unwrap_wrapped_function(func, unwrap_partial=False)
if not callable(func):
# This branch will get hit for classmethod properties
attribute = get_attribute_from_base_dicts(cls_, cls_var_name) # prevents the binding call to `__get__`
if isinstance(attribute, PydanticDescriptorProxy):
func = unwrap_wrapped_function(attribute.wrapped)
return Decorator(
cls_ref=get_type_ref(cls_),
cls_var_name=cls_var_name,
func=func,
shim=shim,
info=info,
)
def bind_to_cls(self, cls: Any) -> Decorator[DecoratorInfoType]:
"""Bind the decorator to a class.
Args:
cls: the class.
Returns:
The new decorator instance.
"""
return self.build(
cls,
cls_var_name=self.cls_var_name,
shim=self.shim,
info=self.info,
)
def get_bases(tp: type[Any]) -> tuple[type[Any], ...]:
"""Get the base classes of a class or typeddict.
Args:
tp: The type or class to get the bases.
Returns:
The base classes.
"""
if is_typeddict(tp):
return tp.__orig_bases__ # type: ignore
try:
return tp.__bases__
except AttributeError:
return ()
def mro(tp: type[Any]) -> tuple[type[Any], ...]:
"""Calculate the Method Resolution Order of bases using the C3 algorithm.
See https://www.python.org/download/releases/2.3/mro/
"""
# try to use the existing mro, for performance mainly
# but also because it helps verify the implementation below
if not is_typeddict(tp):
try:
return tp.__mro__
except AttributeError:
# GenericAlias and some other cases
pass
bases = get_bases(tp)
return (tp,) + mro_for_bases(bases)
def mro_for_bases(bases: tuple[type[Any], ...]) -> tuple[type[Any], ...]:
def merge_seqs(seqs: list[deque[type[Any]]]) -> Iterable[type[Any]]:
while True:
non_empty = [seq for seq in seqs if seq]
if not non_empty:
# Nothing left to process, we're done.
return
candidate: type[Any] | None = None
for seq in non_empty: # Find merge candidates among seq heads.
candidate = seq[0]
not_head = [s for s in non_empty if candidate in islice(s, 1, None)]
if not_head:
# Reject the candidate.
candidate = None
else:
break
if not candidate:
raise TypeError('Inconsistent hierarchy, no C3 MRO is possible')
yield candidate
for seq in non_empty:
# Remove candidate.
if seq[0] == candidate:
seq.popleft()
seqs = [deque(mro(base)) for base in bases] + [deque(bases)]
return tuple(merge_seqs(seqs))
_sentinel = object()
def get_attribute_from_bases(tp: type[Any] | tuple[type[Any], ...], name: str) -> Any:
"""Get the attribute from the next class in the MRO that has it,
aiming to simulate calling the method on the actual class.
The reason for iterating over the mro instead of just getting
the attribute (which would do that for us) is to support TypedDict,
which lacks a real __mro__, but can have a virtual one constructed
from its bases (as done here).
Args:
tp: The type or class to search for the attribute. If a tuple, this is treated as a set of base classes.
name: The name of the attribute to retrieve.
Returns:
Any: The attribute value, if found.
Raises:
AttributeError: If the attribute is not found in any class in the MRO.
"""
if isinstance(tp, tuple):
for base in mro_for_bases(tp):
attribute = base.__dict__.get(name, _sentinel)
if attribute is not _sentinel:
attribute_get = getattr(attribute, '__get__', None)
if attribute_get is not None:
return attribute_get(None, tp)
return attribute
raise AttributeError(f'{name} not found in {tp}')
else:
try:
return getattr(tp, name)
except AttributeError:
return get_attribute_from_bases(mro(tp), name)
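# Illustrative sketch (not upstream code): mro() falls back to a computed C3
# linearization for types without a real __mro__ (e.g. TypedDict), and
# get_attribute_from_bases() looks the attribute up in that order. The names
# _DemoBase/_DemoChild are invented here.
def _demo_attribute_lookup() -> Any:
    class _DemoBase:
        def marker(self) -> str:
            return 'base'

    class _DemoChild(_DemoBase):
        pass

    assert [c.__name__ for c in mro(_DemoChild)] == ['_DemoChild', '_DemoBase', 'object']
    # returns the plain function _DemoBase.marker, found on the first class that defines it
    return get_attribute_from_bases(_DemoChild, 'marker')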
def get_attribute_from_base_dicts(tp: type[Any], name: str) -> Any:
"""Get an attribute out of the `__dict__` following the MRO.
This prevents the call to `__get__` on the descriptor, and allows
us to get the original function for classmethod properties.
Args:
tp: The type or class to search for the attribute.
name: The name of the attribute to retrieve.
Returns:
Any: The attribute value, if found.
Raises:
KeyError: If the attribute is not found in any class's `__dict__` in the MRO.
"""
for base in reversed(mro(tp)):
if name in base.__dict__:
return base.__dict__[name]
return tp.__dict__[name] # raise the error
@dataclass(**slots_true)
class DecoratorInfos:
"""Mapping of name in the class namespace to decorator info.
    Note that the name in the class namespace is the function or attribute name,
    not the field name!
"""
validators: dict[str, Decorator[ValidatorDecoratorInfo]] = field(default_factory=dict)
field_validators: dict[str, Decorator[FieldValidatorDecoratorInfo]] = field(default_factory=dict)
root_validators: dict[str, Decorator[RootValidatorDecoratorInfo]] = field(default_factory=dict)
field_serializers: dict[str, Decorator[FieldSerializerDecoratorInfo]] = field(default_factory=dict)
model_serializers: dict[str, Decorator[ModelSerializerDecoratorInfo]] = field(default_factory=dict)
model_validators: dict[str, Decorator[ModelValidatorDecoratorInfo]] = field(default_factory=dict)
computed_fields: dict[str, Decorator[ComputedFieldInfo]] = field(default_factory=dict)
@staticmethod
def build(model_dc: type[Any]) -> DecoratorInfos: # noqa: C901 (ignore complexity)
"""We want to collect all DecFunc instances that exist as
attributes in the namespace of the class (a BaseModel or dataclass)
        that called us.
        But we want to collect these in the order of the bases.
So instead of getting them all from the leaf class (the class that called us),
we traverse the bases from root (the oldest ancestor class) to leaf
and collect all of the instances as we go, taking care to replace
any duplicate ones with the last one we see to mimic how function overriding
works with inheritance.
        If we do replace any functions, we put the replacement into the position
the replaced function was in; that is, we maintain the order.
"""
# reminder: dicts are ordered and replacement does not alter the order
res = DecoratorInfos()
for base in reversed(mro(model_dc)[1:]):
existing: DecoratorInfos | None = base.__dict__.get('__pydantic_decorators__')
if existing is None:
existing = DecoratorInfos.build(base)
res.validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.validators.items()})
res.field_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_validators.items()})
res.root_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.root_validators.items()})
res.field_serializers.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_serializers.items()})
res.model_serializers.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_serializers.items()})
res.model_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_validators.items()})
res.computed_fields.update({k: v.bind_to_cls(model_dc) for k, v in existing.computed_fields.items()})
to_replace: list[tuple[str, Any]] = []
for var_name, var_value in vars(model_dc).items():
if isinstance(var_value, PydanticDescriptorProxy):
info = var_value.decorator_info
if isinstance(info, ValidatorDecoratorInfo):
res.validators[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
)
elif isinstance(info, FieldValidatorDecoratorInfo):
res.field_validators[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
)
elif isinstance(info, RootValidatorDecoratorInfo):
res.root_validators[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
)
elif isinstance(info, FieldSerializerDecoratorInfo):
# check whether a serializer function is already registered for fields
for field_serializer_decorator in res.field_serializers.values():
# check that each field has at most one serializer function.
# serializer functions for the same field in subclasses are allowed,
# and are treated as overrides
if field_serializer_decorator.cls_var_name == var_name:
continue
for f in info.fields:
if f in field_serializer_decorator.info.fields:
raise PydanticUserError(
'Multiple field serializer functions were defined '
f'for field {f!r}, this is not allowed.',
code='multiple-field-serializers',
)
res.field_serializers[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
)
elif isinstance(info, ModelValidatorDecoratorInfo):
res.model_validators[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
)
elif isinstance(info, ModelSerializerDecoratorInfo):
res.model_serializers[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
)
else:
from ..fields import ComputedFieldInfo
# sanity check: the only remaining decorator info type here is ComputedFieldInfo
assert isinstance(info, ComputedFieldInfo)
res.computed_fields[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=None, info=info
)
to_replace.append((var_name, var_value.wrapped))
if to_replace:
# If we can save `__pydantic_decorators__` on the class we'll be able to check for it above
# so then we don't need to re-process the type, which means we can discard our descriptor wrappers
# and replace them with the thing they are wrapping (see the other setattr call below)
# which allows validator class methods to also function as regular class methods
setattr(model_dc, '__pydantic_decorators__', res)
for name, value in to_replace:
setattr(model_dc, name, value)
return res
def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes) -> bool:
"""Look at a field or model validator function and determine whether it takes an info argument.
An error is raised if the function has an invalid signature.
Args:
validator: The validator function to inspect.
mode: The proposed validator mode.
Returns:
Whether the validator takes an info argument.
"""
try:
sig = signature(validator)
except ValueError:
# builtins and some C extensions don't have signatures
# assume that they don't take an info argument and only take a single argument
# e.g. `str.strip` or `datetime.datetime`
return False
n_positional = count_positional_params(sig)
if mode == 'wrap':
if n_positional == 3:
return True
elif n_positional == 2:
return False
else:
assert mode in {'before', 'after', 'plain'}, f"invalid mode: {mode!r}, expected 'before', 'after' or 'plain'"
if n_positional == 2:
return True
elif n_positional == 1:
return False
raise PydanticUserError(
f'Unrecognized field_validator function signature for {validator} with `mode={mode}`:{sig}',
code='validator-signature',
)
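# --- Illustrative sketch, not part of the vendored module ---
# Shows how `inspect_validator` above infers the presence of an `info` argument
# purely from the count of positional parameters; the sample validators below
# are hypothetical.
def _example_inspect_validator() -> None:
    def no_info(value):
        return value

    def with_info(value, info):
        return value

    assert inspect_validator(no_info, mode='after') is False
    assert inspect_validator(with_info, mode='after') is True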
def inspect_field_serializer(
serializer: Callable[..., Any], mode: Literal['plain', 'wrap'], computed_field: bool = False
) -> tuple[bool, bool]:
"""Look at a field serializer function and determine if it is a field serializer,
and whether it takes an info argument.
An error is raised if the function has an invalid signature.
Args:
serializer: The serializer function to inspect.
mode: The serializer mode, either 'plain' or 'wrap'.
computed_field: Whether the serializer is applied to a computed_field; if so,
an info argument is not required.
Returns:
Tuple of (is_field_serializer, info_arg).
"""
sig = signature(serializer)
first = next(iter(sig.parameters.values()), None)
is_field_serializer = first is not None and first.name == 'self'
n_positional = count_positional_params(sig)
if is_field_serializer:
# -1 to correct for self parameter
info_arg = _serializer_info_arg(mode, n_positional - 1)
else:
info_arg = _serializer_info_arg(mode, n_positional)
if info_arg is None:
raise PydanticUserError(
f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
code='field-serializer-signature',
)
if info_arg and computed_field:
raise PydanticUserError(
'field_serializer on computed_field does not use info signature', code='field-serializer-signature'
)
else:
return is_field_serializer, info_arg
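# --- Illustrative sketch, not part of the vendored module ---
# A field serializer written as an instance method (first parameter `self`) has
# that parameter discounted before the mode/positional-count check; the sample
# functions below are hypothetical.
def _example_inspect_field_serializer() -> None:
    def as_method(self, value, info):
        return value

    def as_plain_function(value):
        return value

    assert inspect_field_serializer(as_method, mode='plain') == (True, True)
    assert inspect_field_serializer(as_plain_function, mode='plain') == (False, False)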
def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
"""Look at a serializer function used via `Annotated` and determine whether it takes an info argument.
An error is raised if the function has an invalid signature.
Args:
serializer: The serializer function to check.
mode: The serializer mode, either 'plain' or 'wrap'.
Returns:
`info_arg` - whether the function expects an info argument.
"""
sig = signature(serializer)
info_arg = _serializer_info_arg(mode, count_positional_params(sig))
if info_arg is None:
raise PydanticUserError(
f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
code='field-serializer-signature',
)
else:
return info_arg
def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
"""Look at a model serializer function and determine whether it takes an info argument.
An error is raised if the function has an invalid signature.
Args:
serializer: The serializer function to check.
mode: The serializer mode, either 'plain' or 'wrap'.
Returns:
`info_arg` - whether the function expects an info argument.
"""
if isinstance(serializer, (staticmethod, classmethod)) or not is_instance_method_from_sig(serializer):
raise PydanticUserError(
'`@model_serializer` must be applied to instance methods', code='model-serializer-instance-method'
)
sig = signature(serializer)
info_arg = _serializer_info_arg(mode, count_positional_params(sig))
if info_arg is None:
raise PydanticUserError(
f'Unrecognized model_serializer function signature for {serializer} with `mode={mode}`:{sig}',
code='model-serializer-signature',
)
else:
return info_arg
def _serializer_info_arg(mode: Literal['plain', 'wrap'], n_positional: int) -> bool | None:
if mode == 'plain':
if n_positional == 1:
# (__input_value: Any) -> Any
return False
elif n_positional == 2:
# (__model: Any, __input_value: Any) -> Any
return True
else:
assert mode == 'wrap', f"invalid mode: {mode!r}, expected 'plain' or 'wrap'"
if n_positional == 2:
# (__input_value: Any, __serializer: SerializerFunctionWrapHandler) -> Any
return False
elif n_positional == 3:
# (__input_value: Any, __serializer: SerializerFunctionWrapHandler, __info: SerializationInfo) -> Any
return True
return None
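# --- Illustrative sketch, not part of the vendored module ---
# Summary of `_serializer_info_arg`: per mode, the positional-parameter count
# decides whether an `info` argument is expected (True/False) or the signature
# is rejected (None).
def _example_serializer_info_arg() -> None:
    assert _serializer_info_arg('plain', 1) is False  # (value)
    assert _serializer_info_arg('plain', 2) is True   # (value, info)
    assert _serializer_info_arg('wrap', 2) is False   # (value, handler)
    assert _serializer_info_arg('wrap', 3) is True    # (value, handler, info)
    assert _serializer_info_arg('plain', 3) is None   # unrecognized signature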
AnyDecoratorCallable: TypeAlias = (
'Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any], Callable[..., Any]]'
)
def is_instance_method_from_sig(function: AnyDecoratorCallable) -> bool:
"""Whether the function is an instance method.
It will consider a function as instance method if the first parameter of
function is `self`.
Args:
function: The function to check.
Returns:
`True` if the function is an instance method, `False` otherwise.
"""
sig = signature(unwrap_wrapped_function(function))
first = next(iter(sig.parameters.values()), None)
if first and first.name == 'self':
return True
return False
def ensure_classmethod_based_on_signature(function: AnyDecoratorCallable) -> Any:
"""Apply the `@classmethod` decorator on the function.
Args:
function: The function to apply the decorator on.
Returns:
The `@classmethod` decorator applied function.
"""
if not isinstance(
unwrap_wrapped_function(function, unwrap_class_static_method=False), classmethod
) and _is_classmethod_from_sig(function):
return classmethod(function) # type: ignore[arg-type]
return function
def _is_classmethod_from_sig(function: AnyDecoratorCallable) -> bool:
sig = signature(unwrap_wrapped_function(function))
first = next(iter(sig.parameters.values()), None)
if first and first.name == 'cls':
return True
return False
def unwrap_wrapped_function(
func: Any,
*,
unwrap_partial: bool = True,
unwrap_class_static_method: bool = True,
) -> Any:
"""Recursively unwraps a wrapped function until the underlying function is reached.
This handles property, functools.partial, functools.partialmethod, staticmethod and classmethod.
Args:
func: The function to unwrap.
unwrap_partial: If True (default), unwrap partial and partialmethod decorators; otherwise don't.
unwrap_class_static_method: If True (default), also unwrap classmethod and staticmethod
decorators. If False, only unwrap partial and partialmethod decorators.
Returns:
The underlying function of the wrapped function.
"""
unwrap_types: set[Any] = {property, cached_property}
if unwrap_partial:
unwrap_types.update({partial, partialmethod})
if unwrap_class_static_method:
unwrap_types.update({staticmethod, classmethod})
while isinstance(func, tuple(unwrap_types)):
if unwrap_class_static_method and isinstance(func, (classmethod, staticmethod)):
func = func.__func__
elif isinstance(func, (partial, partialmethod)):
func = func.func
elif isinstance(func, property):
func = func.fget # arbitrary choice, convenient for computed fields
else:
# Make coverage happy as it can only get here in the last possible case
assert isinstance(func, cached_property)
func = func.func # type: ignore
return func
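# --- Illustrative sketch, not part of the vendored module ---
# `unwrap_wrapped_function` peels nested wrappers until the plain function is
# reached; with class/static-method unwrapping disabled, a classmethod wrapper
# is left untouched. The wrappers below are assembled ad hoc for demonstration.
def _example_unwrap_wrapped_function() -> None:
    from functools import partial as _partial

    def plain(x: int) -> int:
        return x

    wrapped = classmethod(_partial(plain, 1))
    assert unwrap_wrapped_function(wrapped) is plain
    assert isinstance(unwrap_wrapped_function(wrapped, unwrap_class_static_method=False), classmethod)
    assert unwrap_wrapped_function(property(plain)) is plain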
def get_function_return_type(
func: Any, explicit_return_type: Any, types_namespace: dict[str, Any] | None = None
) -> Any:
"""Get the function return type.
It gets the return type from the type annotation if `explicit_return_type` is `None`.
Otherwise, it returns `explicit_return_type`.
Args:
func: The function to get its return type.
explicit_return_type: The explicit return type.
types_namespace: The types namespace, defaults to `None`.
Returns:
The function return type.
"""
if explicit_return_type is PydanticUndefined:
# try to get it from the type annotation
hints = get_function_type_hints(
unwrap_wrapped_function(func), include_keys={'return'}, types_namespace=types_namespace
)
return hints.get('return', PydanticUndefined)
else:
return explicit_return_type
def count_positional_params(sig: Signature) -> int:
return sum(1 for param in sig.parameters.values() if can_be_positional(param))
def can_be_positional(param: Parameter) -> bool:
return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
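# --- Illustrative sketch, not part of the vendored module ---
# `count_positional_params` ignores keyword-only and variadic keyword
# parameters, so only `a` and `b` are counted in the hypothetical function below.
def _example_count_positional_params() -> None:
    from inspect import signature as _signature

    def f(a, b, *, c=None, **kwargs):
        return a, b, c, kwargs

    assert count_positional_params(_signature(f)) == 2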
def ensure_property(f: Any) -> Any:
"""Ensure that a function is a `property` or `cached_property`, or is a valid descriptor.
Args:
f: The function to check.
Returns:
The function, or a `property` or `cached_property` instance wrapping the function.
"""
if ismethoddescriptor(f) or isdatadescriptor(f):
return f
else:
return property(f)

View file

@@ -0,0 +1,181 @@
"""Logic for V1 validators, e.g. `@validator` and `@root_validator`."""
from __future__ import annotations as _annotations
from inspect import Parameter, signature
from typing import Any, Dict, Tuple, Union, cast
from pydantic_core import core_schema
from typing_extensions import Protocol
from ..errors import PydanticUserError
from ._decorators import can_be_positional
class V1OnlyValueValidator(Protocol):
"""A simple validator, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any) -> Any:
...
class V1ValidatorWithValues(Protocol):
"""A validator with `values` argument, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any, values: dict[str, Any]) -> Any:
...
class V1ValidatorWithValuesKwOnly(Protocol):
"""A validator with keyword only `values` argument, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any, *, values: dict[str, Any]) -> Any:
...
class V1ValidatorWithKwargs(Protocol):
"""A validator with `kwargs` argument, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any, **kwargs: Any) -> Any:
...
class V1ValidatorWithValuesAndKwargs(Protocol):
"""A validator with `values` and `kwargs` arguments, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any, values: dict[str, Any], **kwargs: Any) -> Any:
...
V1Validator = Union[
V1ValidatorWithValues, V1ValidatorWithValuesKwOnly, V1ValidatorWithKwargs, V1ValidatorWithValuesAndKwargs
]
def can_be_keyword(param: Parameter) -> bool:
return param.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
def make_generic_v1_field_validator(validator: V1Validator) -> core_schema.WithInfoValidatorFunction:
"""Wrap a V1 style field validator for V2 compatibility.
Args:
validator: The V1 style field validator.
Returns:
A wrapped V2 style field validator.
Raises:
PydanticUserError: If the signature is not supported or the parameters are
not available in Pydantic V2.
"""
sig = signature(validator)
needs_values_kw = False
for param_num, (param_name, parameter) in enumerate(sig.parameters.items()):
if can_be_keyword(parameter) and param_name in ('field', 'config'):
raise PydanticUserError(
'The `field` and `config` parameters are not available in Pydantic V2, '
'please use the `info` parameter instead.',
code='validator-field-config-info',
)
if parameter.kind is Parameter.VAR_KEYWORD:
needs_values_kw = True
elif can_be_keyword(parameter) and param_name == 'values':
needs_values_kw = True
elif can_be_positional(parameter) and param_num == 0:
# value
continue
elif parameter.default is Parameter.empty: # ignore params with defaults e.g. bound by functools.partial
raise PydanticUserError(
f'Unsupported signature for V1 style validator {validator}: {sig} is not supported.',
code='validator-v1-signature',
)
if needs_values_kw:
# (v, **kwargs), (v, values, **kwargs), (v, *, values, **kwargs) or (v, *, values)
val1 = cast(V1ValidatorWithValues, validator)
def wrapper1(value: Any, info: core_schema.ValidationInfo) -> Any:
return val1(value, values=info.data)
return wrapper1
else:
val2 = cast(V1OnlyValueValidator, validator)
def wrapper2(value: Any, _: core_schema.ValidationInfo) -> Any:
return val2(value)
return wrapper2
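# --- Illustrative sketch, not part of the vendored module ---
# Shows the wrapper produced for a V1 validator that asks for `values`; the
# `_FakeInfo` class only stands in for `core_schema.ValidationInfo`, of which
# the wrapper uses nothing but the `data` attribute.
def _example_make_generic_v1_field_validator() -> None:
    class _FakeInfo:
        data = {'other_field': 1}

    def v1_style(value, values):
        return value + values['other_field']

    wrapped = make_generic_v1_field_validator(v1_style)
    assert wrapped(41, _FakeInfo()) == 42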
RootValidatorValues = Dict[str, Any]
# technically tuple[model_dict, model_extra, fields_set] | tuple[dataclass_dict, init_vars]
RootValidatorFieldsTuple = Tuple[Any, ...]
class V1RootValidatorFunction(Protocol):
"""A simple root validator, supported for V1 validators and V2 validators."""
def __call__(self, __values: RootValidatorValues) -> RootValidatorValues:
...
class V2CoreBeforeRootValidator(Protocol):
"""V2 validator with mode='before'."""
def __call__(self, __values: RootValidatorValues, __info: core_schema.ValidationInfo) -> RootValidatorValues:
...
class V2CoreAfterRootValidator(Protocol):
"""V2 validator with mode='after'."""
def __call__(
self, __fields_tuple: RootValidatorFieldsTuple, __info: core_schema.ValidationInfo
) -> RootValidatorFieldsTuple:
...
def make_v1_generic_root_validator(
validator: V1RootValidatorFunction, pre: bool
) -> V2CoreBeforeRootValidator | V2CoreAfterRootValidator:
"""Wrap a V1 style root validator for V2 compatibility.
Args:
validator: The V1 style field validator.
pre: Whether the validator is a pre validator.
Returns:
A wrapped V2 style validator.
"""
if pre is True:
# mode='before' for pydantic-core
def _wrapper1(values: RootValidatorValues, _: core_schema.ValidationInfo) -> RootValidatorValues:
return validator(values)
return _wrapper1
# mode='after' for pydantic-core
def _wrapper2(fields_tuple: RootValidatorFieldsTuple, _: core_schema.ValidationInfo) -> RootValidatorFieldsTuple:
if len(fields_tuple) == 2:
# dataclass, this is easy
values, init_vars = fields_tuple
values = validator(values)
return values, init_vars
else:
# ugly hack: to match v1 behaviour, we merge values and model_extra, then split them up based on fields
# afterwards
model_dict, model_extra, fields_set = fields_tuple
if model_extra:
fields = set(model_dict.keys())
model_dict.update(model_extra)
model_dict_new = validator(model_dict)
for k in list(model_dict_new.keys()):
if k not in fields:
model_extra[k] = model_dict_new.pop(k)
else:
model_dict_new = validator(model_dict)
return model_dict_new, model_extra, fields_set
return _wrapper2
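# --- Illustrative sketch, not part of the vendored module ---
# A pre (mode='before') V1 root validator just receives and returns the raw
# values dict; the wrapper drops the core-schema info argument, which is why
# `None` can be passed for it in this hypothetical example.
def _example_make_v1_generic_root_validator() -> None:
    def v1_root(values):
        return {**values, 'checked': True}

    wrapped = make_v1_generic_root_validator(v1_root, pre=True)
    assert wrapped({'x': 1}, None) == {'x': 1, 'checked': True}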

View file

@@ -0,0 +1,506 @@
from __future__ import annotations as _annotations
from typing import TYPE_CHECKING, Any, Hashable, Sequence
from pydantic_core import CoreSchema, core_schema
from ..errors import PydanticUserError
from . import _core_utils
from ._core_utils import (
CoreSchemaField,
collect_definitions,
simplify_schema_references,
)
if TYPE_CHECKING:
from ..types import Discriminator
CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator'
class MissingDefinitionForUnionRef(Exception):
"""Raised when applying a discriminated union discriminator to a schema
requires a definition that is not yet defined
"""
def __init__(self, ref: str) -> None:
self.ref = ref
super().__init__(f'Missing definition for ref {self.ref!r}')
def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None:
schema.setdefault('metadata', {})
metadata = schema.get('metadata')
assert metadata is not None
metadata[CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY] = discriminator
def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
definitions: dict[str, CoreSchema] | None = None
def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema:
nonlocal definitions
s = recurse(s, inner)
if s['type'] == 'tagged-union':
return s
metadata = s.get('metadata', {})
discriminator = metadata.pop(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
if discriminator is not None:
if definitions is None:
definitions = collect_definitions(schema)
s = apply_discriminator(s, discriminator, definitions)
return s
return simplify_schema_references(_core_utils.walk_core_schema(schema, inner))
def apply_discriminator(
schema: core_schema.CoreSchema,
discriminator: str | Discriminator,
definitions: dict[str, core_schema.CoreSchema] | None = None,
) -> core_schema.CoreSchema:
"""Applies the discriminator and returns a new core schema.
Args:
schema: The input schema.
discriminator: The name of the field which will serve as the discriminator.
definitions: A mapping of schema ref to schema.
Returns:
The new core schema.
Raises:
TypeError:
- If `discriminator` is used with invalid union variant.
- If `discriminator` is used with `Union` type with one variant.
- If `discriminator` value mapped to multiple choices.
MissingDefinitionForUnionRef:
If the definition for ref is missing.
PydanticUserError:
- If a model in union doesn't have a discriminator field.
- If discriminator field has a non-string alias.
- If discriminator fields have different aliases.
- If discriminator field not of type `Literal`.
"""
from ..types import Discriminator
if isinstance(discriminator, Discriminator):
if isinstance(discriminator.discriminator, str):
discriminator = discriminator.discriminator
else:
return discriminator._convert_schema(schema)
return _ApplyInferredDiscriminator(discriminator, definitions or {}).apply(schema)
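# --- Illustrative sketch, not part of the vendored module ---
# Applies the discriminator to a hand-built union of two typed-dict schemas,
# assuming the `pydantic_core.core_schema` helpers used below behave as their
# names suggest; the result should be a tagged-union keyed on the literal
# values of the `kind` field.
def _example_apply_discriminator() -> None:
    cat = core_schema.typed_dict_schema(
        {'kind': core_schema.typed_dict_field(core_schema.literal_schema(['cat']))}
    )
    dog = core_schema.typed_dict_schema(
        {'kind': core_schema.typed_dict_field(core_schema.literal_schema(['dog']))}
    )
    tagged = apply_discriminator(core_schema.union_schema([cat, dog]), 'kind')
    assert tagged['type'] == 'tagged-union'
    assert set(tagged['choices']) == {'cat', 'dog'}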
class _ApplyInferredDiscriminator:
"""This class is used to convert an input schema containing a union schema into one where that union is
replaced with a tagged-union, with all the associated debugging and performance benefits.
This is done by:
* Validating that the input schema is compatible with the provided discriminator
* Introspecting the schema to determine which discriminator values should map to which union choices
* Handling various edge cases such as 'definitions', 'default', 'nullable' schemas, and more
I have chosen to implement the conversion algorithm in this class, rather than a function,
to make it easier to maintain state while recursively walking the provided CoreSchema.
"""
def __init__(self, discriminator: str, definitions: dict[str, core_schema.CoreSchema]):
# `discriminator` should be the name of the field which will serve as the discriminator.
# It must be the python name of the field, and *not* the field's alias. Note that as of now,
# all members of a discriminated union _must_ use a field with the same name as the discriminator.
# This may change if/when we expose a way to manually specify the TaggedUnionSchema's choices.
self.discriminator = discriminator
# `definitions` should contain a mapping of schema ref to schema for all schemas which might
# be referenced by some choice
self.definitions = definitions
# `_discriminator_alias` will hold the value, if present, of the alias for the discriminator
#
# Note: following the v1 implementation, we currently disallow the use of different aliases
# for different choices. This is not a limitation of pydantic_core, but if we try to handle
# this, the inference logic gets complicated very quickly, and could result in confusing
# debugging challenges for users making subtle mistakes.
#
# Rather than trying to do the most powerful inference possible, I think we should eventually
# expose a way to more-manually control the way the TaggedUnionSchema is constructed through
# the use of a new type which would be placed as an Annotation on the Union type. This would
# provide the full flexibility/power of pydantic_core's TaggedUnionSchema where necessary for
# more complex cases, without over-complicating the inference logic for the common cases.
self._discriminator_alias: str | None = None
# `_should_be_nullable` indicates whether the converted union has `None` as an allowed value.
# If `None` is an acceptable value of the (possibly-wrapped) union, we ignore it while
# constructing the TaggedUnionSchema, but set the `_should_be_nullable` attribute to True.
# Once we have constructed the TaggedUnionSchema, if `_should_be_nullable` is True, we ensure
# that the final schema gets wrapped as a NullableSchema. This has the same semantics on the
# python side, but resolves the issue that `None` cannot correspond to any discriminator values.
self._should_be_nullable = False
# `_is_nullable` is used to track if the final produced schema will definitely be nullable;
# we set it to True if the input schema is wrapped in a nullable schema that we know will be preserved
# as an indication that, even if None is discovered as one of the union choices, we will not need to wrap
# the final value in another nullable schema.
#
# This is more complicated than just checking for the final outermost schema having type 'nullable' thanks
# to the possible presence of other wrapper schemas such as DefinitionsSchema, WithDefaultSchema, etc.
self._is_nullable = False
# `_choices_to_handle` serves as a stack of choices to add to the tagged union. Initially, choices
# from the union in the wrapped schema will be appended to this list, and the recursive choice-handling
# algorithm may add more choices to this stack as (nested) unions are encountered.
self._choices_to_handle: list[core_schema.CoreSchema] = []
# `_tagged_union_choices` is built during the call to `apply`, and will hold the choices to be included
# in the output TaggedUnionSchema that will replace the union from the input schema
self._tagged_union_choices: dict[Hashable, core_schema.CoreSchema] = {}
# `_used` is changed to True after applying the discriminator to prevent accidental re-use
self._used = False
def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
"""Return a new CoreSchema based on `schema` that uses a tagged-union with the discriminator provided
to this class.
Args:
schema: The input schema.
Returns:
The new core schema.
Raises:
TypeError:
- If `discriminator` is used with invalid union variant.
- If `discriminator` is used with `Union` type with one variant.
- If `discriminator` value mapped to multiple choices.
MissingDefinitionForUnionRef:
If the definition for ref is missing.
PydanticUserError:
- If a model in union doesn't have a discriminator field.
- If discriminator field has a non-string alias.
- If discriminator fields have different aliases.
- If discriminator field not of type `Literal`.
"""
self.definitions.update(collect_definitions(schema))
assert not self._used
schema = self._apply_to_root(schema)
if self._should_be_nullable and not self._is_nullable:
schema = core_schema.nullable_schema(schema)
self._used = True
new_defs = collect_definitions(schema)
missing_defs = self.definitions.keys() - new_defs.keys()
if missing_defs:
schema = core_schema.definitions_schema(schema, [self.definitions[ref] for ref in missing_defs])
return schema
def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
"""This method handles the outer-most stage of recursion over the input schema:
unwrapping nullable or definitions schemas, and calling the `_handle_choice`
method iteratively on the choices extracted (recursively) from the possibly-wrapped union.
"""
if schema['type'] == 'nullable':
self._is_nullable = True
wrapped = self._apply_to_root(schema['schema'])
nullable_wrapper = schema.copy()
nullable_wrapper['schema'] = wrapped
return nullable_wrapper
if schema['type'] == 'definitions':
wrapped = self._apply_to_root(schema['schema'])
definitions_wrapper = schema.copy()
definitions_wrapper['schema'] = wrapped
return definitions_wrapper
if schema['type'] != 'union':
# If the schema is not a union, it probably means it just had a single member and
# was flattened by pydantic_core.
# However, it still may make sense to apply the discriminator to this schema,
# as a way to get discriminated-union-style error messages, so we allow this here.
schema = core_schema.union_schema([schema])
# Reverse the choices list before extending the stack so that they get handled in the order they occur
choices_schemas = [v[0] if isinstance(v, tuple) else v for v in schema['choices'][::-1]]
self._choices_to_handle.extend(choices_schemas)
while self._choices_to_handle:
choice = self._choices_to_handle.pop()
self._handle_choice(choice)
if self._discriminator_alias is not None and self._discriminator_alias != self.discriminator:
# * We need to annotate `discriminator` as a union here to handle both branches of this conditional
# * We need to annotate `discriminator` as list[list[str | int]] and not list[list[str]] due to the
# invariance of list, and because list[list[str | int]] is the type of the discriminator argument
# to tagged_union_schema below
# * See the docstring of pydantic_core.core_schema.tagged_union_schema for more details about how to
# interpret the value of the discriminator argument to tagged_union_schema. (The list[list[str]] here
# is the appropriate way to provide a list of fallback attributes to check for a discriminator value.)
discriminator: str | list[list[str | int]] = [[self.discriminator], [self._discriminator_alias]]
else:
discriminator = self.discriminator
return core_schema.tagged_union_schema(
choices=self._tagged_union_choices,
discriminator=discriminator,
custom_error_type=schema.get('custom_error_type'),
custom_error_message=schema.get('custom_error_message'),
custom_error_context=schema.get('custom_error_context'),
strict=False,
from_attributes=True,
ref=schema.get('ref'),
metadata=schema.get('metadata'),
serialization=schema.get('serialization'),
)
def _handle_choice(self, choice: core_schema.CoreSchema) -> None:
"""This method handles the "middle" stage of recursion over the input schema.
Specifically, it is responsible for handling each choice of the outermost union
(and any "coalesced" choices obtained from inner unions).
Here, "handling" entails:
* Coalescing nested unions and compatible tagged-unions
* Tracking the presence of 'none' and 'nullable' schemas occurring as choices
* Validating that each allowed discriminator value maps to a unique choice
* Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema.
"""
if choice['type'] == 'definition-ref':
if choice['schema_ref'] not in self.definitions:
raise MissingDefinitionForUnionRef(choice['schema_ref'])
if choice['type'] == 'none':
self._should_be_nullable = True
elif choice['type'] == 'definitions':
self._handle_choice(choice['schema'])
elif choice['type'] == 'nullable':
self._should_be_nullable = True
self._handle_choice(choice['schema']) # unwrap the nullable schema
elif choice['type'] == 'union':
# Reverse the choices list before extending the stack so that they get handled in the order they occur
choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]]
self._choices_to_handle.extend(choices_schemas)
elif choice['type'] not in {
'model',
'typed-dict',
'tagged-union',
'lax-or-strict',
'dataclass',
'dataclass-args',
'definition-ref',
} and not _core_utils.is_function_with_inner_schema(choice):
# We should eventually handle 'definition-ref' as well
raise TypeError(
f'{choice["type"]!r} is not a valid discriminated union variant;'
' should be a `BaseModel` or `dataclass`'
)
else:
if choice['type'] == 'tagged-union' and self._is_discriminator_shared(choice):
# In this case, this inner tagged-union is compatible with the outer tagged-union,
# and its choices can be coalesced into the outer TaggedUnionSchema.
subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
# Reverse the choices list before extending the stack so that they get handled in the order they occur
self._choices_to_handle.extend(subchoices[::-1])
return
inferred_discriminator_values = self._infer_discriminator_values_for_choice(choice, source_name=None)
self._set_unique_choice_for_values(choice, inferred_discriminator_values)
def _is_discriminator_shared(self, choice: core_schema.TaggedUnionSchema) -> bool:
"""This method returns a boolean indicating whether the discriminator for the `choice`
is the same as that being used for the outermost tagged union. This is used to
determine whether this TaggedUnionSchema choice should be "coalesced" into the top level,
or whether it should be treated as a separate (nested) choice.
"""
inner_discriminator = choice['discriminator']
return inner_discriminator == self.discriminator or (
isinstance(inner_discriminator, list)
and (self.discriminator in inner_discriminator or [self.discriminator] in inner_discriminator)
)
def _infer_discriminator_values_for_choice( # noqa C901
self, choice: core_schema.CoreSchema, source_name: str | None
) -> list[str | int]:
"""This function recurses over `choice`, extracting all discriminator values that should map to this choice.
`source_name` is accepted for the purpose of producing useful error messages.
"""
if choice['type'] == 'definitions':
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
elif choice['type'] == 'function-plain':
raise TypeError(
f'{choice["type"]!r} is not a valid discriminated union variant;'
' should be a `BaseModel` or `dataclass`'
)
elif _core_utils.is_function_with_inner_schema(choice):
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
elif choice['type'] == 'lax-or-strict':
return sorted(
set(
self._infer_discriminator_values_for_choice(choice['lax_schema'], source_name=None)
+ self._infer_discriminator_values_for_choice(choice['strict_schema'], source_name=None)
)
)
elif choice['type'] == 'tagged-union':
values: list[str | int] = []
# Ignore str/int "choices" since these are just references to other choices
subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
for subchoice in subchoices:
subchoice_values = self._infer_discriminator_values_for_choice(subchoice, source_name=None)
values.extend(subchoice_values)
return values
elif choice['type'] == 'union':
values = []
for subchoice in choice['choices']:
subchoice_schema = subchoice[0] if isinstance(subchoice, tuple) else subchoice
subchoice_values = self._infer_discriminator_values_for_choice(subchoice_schema, source_name=None)
values.extend(subchoice_values)
return values
elif choice['type'] == 'nullable':
self._should_be_nullable = True
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=None)
elif choice['type'] == 'model':
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)
elif choice['type'] == 'dataclass':
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)
elif choice['type'] == 'model-fields':
return self._infer_discriminator_values_for_model_choice(choice, source_name=source_name)
elif choice['type'] == 'dataclass-args':
return self._infer_discriminator_values_for_dataclass_choice(choice, source_name=source_name)
elif choice['type'] == 'typed-dict':
return self._infer_discriminator_values_for_typed_dict_choice(choice, source_name=source_name)
elif choice['type'] == 'definition-ref':
schema_ref = choice['schema_ref']
if schema_ref not in self.definitions:
raise MissingDefinitionForUnionRef(schema_ref)
return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name)
else:
raise TypeError(
f'{choice["type"]!r} is not a valid discriminated union variant;'
' should be a `BaseModel` or `dataclass`'
)
def _infer_discriminator_values_for_typed_dict_choice(
self, choice: core_schema.TypedDictSchema, source_name: str | None = None
) -> list[str | int]:
"""This method just extracts the _infer_discriminator_values_for_choice logic specific to TypedDictSchema
for the sake of readability.
"""
source = 'TypedDict' if source_name is None else f'TypedDict {source_name!r}'
field = choice['fields'].get(self.discriminator)
if field is None:
raise PydanticUserError(
f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
)
return self._infer_discriminator_values_for_field(field, source)
def _infer_discriminator_values_for_model_choice(
self, choice: core_schema.ModelFieldsSchema, source_name: str | None = None
) -> list[str | int]:
source = 'ModelFields' if source_name is None else f'Model {source_name!r}'
field = choice['fields'].get(self.discriminator)
if field is None:
raise PydanticUserError(
f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
)
return self._infer_discriminator_values_for_field(field, source)
def _infer_discriminator_values_for_dataclass_choice(
self, choice: core_schema.DataclassArgsSchema, source_name: str | None = None
) -> list[str | int]:
source = 'DataclassArgs' if source_name is None else f'Dataclass {source_name!r}'
for field in choice['fields']:
if field['name'] == self.discriminator:
break
else:
raise PydanticUserError(
f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
)
return self._infer_discriminator_values_for_field(field, source)
def _infer_discriminator_values_for_field(self, field: CoreSchemaField, source: str) -> list[str | int]:
if field['type'] == 'computed-field':
# This should never occur as a discriminator, as it is only relevant to serialization
return []
alias = field.get('validation_alias', self.discriminator)
if not isinstance(alias, str):
raise PydanticUserError(
f'Alias {alias!r} is not supported in a discriminated union', code='discriminator-alias-type'
)
if self._discriminator_alias is None:
self._discriminator_alias = alias
elif self._discriminator_alias != alias:
raise PydanticUserError(
f'Aliases for discriminator {self.discriminator!r} must be the same '
f'(got {alias}, {self._discriminator_alias})',
code='discriminator-alias',
)
return self._infer_discriminator_values_for_inner_schema(field['schema'], source)
def _infer_discriminator_values_for_inner_schema(
self, schema: core_schema.CoreSchema, source: str
) -> list[str | int]:
"""When inferring discriminator values for a field, we typically extract the expected values from a literal
schema. This function does that, but also handles nested unions and defaults.
"""
if schema['type'] == 'literal':
return schema['expected']
elif schema['type'] == 'union':
# Generally when multiple values are allowed they should be placed in a single `Literal`, but
# we add this case to handle the situation where a field is annotated as a `Union` of `Literal`s.
# For example, this lets us handle `Union[Literal['key'], Union[Literal['Key'], Literal['KEY']]]`
values: list[Any] = []
for choice in schema['choices']:
choice_schema = choice[0] if isinstance(choice, tuple) else choice
choice_values = self._infer_discriminator_values_for_inner_schema(choice_schema, source)
values.extend(choice_values)
return values
elif schema['type'] == 'default':
# This will happen if the field has a default value; we ignore it while extracting the discriminator values
return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)
elif schema['type'] == 'function-after':
# After validators don't affect the discriminator values
return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)
elif schema['type'] in {'function-before', 'function-wrap', 'function-plain'}:
validator_type = repr(schema['type'].split('-')[1])
raise PydanticUserError(
f'Cannot use a mode={validator_type} validator in the'
f' discriminator field {self.discriminator!r} of {source}',
code='discriminator-validator',
)
else:
raise PydanticUserError(
f'{source} needs field {self.discriminator!r} to be of type `Literal`',
code='discriminator-needs-literal',
)
def _set_unique_choice_for_values(self, choice: core_schema.CoreSchema, values: Sequence[str | int]) -> None:
"""This method updates `self.tagged_union_choices` so that all provided (discriminator) `values` map to the
provided `choice`, validating that none of these values already map to another (different) choice.
"""
for discriminator_value in values:
if discriminator_value in self._tagged_union_choices:
# It is okay if `value` is already in tagged_union_choices as long as it maps to the same value.
# Because tagged_union_choices may map values to other values, we need to walk the choices dict
# until we get to a "real" choice, and confirm that is equal to the one assigned.
existing_choice = self._tagged_union_choices[discriminator_value]
if existing_choice != choice:
raise TypeError(
f'Value {discriminator_value!r} for discriminator '
f'{self.discriminator!r} mapped to multiple choices'
)
else:
self._tagged_union_choices[discriminator_value] = choice

View file

@@ -0,0 +1,319 @@
"""Private logic related to fields (the `Field()` function and `FieldInfo` class), and arguments to `Annotated`."""
from __future__ import annotations as _annotations
import dataclasses
import sys
import warnings
from copy import copy
from functools import lru_cache
from typing import TYPE_CHECKING, Any
from pydantic_core import PydanticUndefined
from pydantic.errors import PydanticUserError
from . import _typing_extra
from ._config import ConfigWrapper
from ._repr import Representation
from ._typing_extra import get_cls_type_hints_lenient, get_type_hints, is_classvar, is_finalvar
if TYPE_CHECKING:
from annotated_types import BaseMetadata
from ..fields import FieldInfo
from ..main import BaseModel
from ._dataclasses import StandardDataclass
from ._decorators import DecoratorInfos
def get_type_hints_infer_globalns(
obj: Any,
localns: dict[str, Any] | None = None,
include_extras: bool = False,
) -> dict[str, Any]:
"""Gets type hints for an object by inferring the global namespace.
It uses `typing.get_type_hints`; the only thing we do here is fetch the
global namespace from `obj.__module__` if it is not `None`.
Args:
obj: The object to get its type hints.
localns: The local namespaces.
include_extras: Whether to recursively include annotation metadata.
Returns:
The object type hints.
"""
module_name = getattr(obj, '__module__', None)
globalns: dict[str, Any] | None = None
if module_name:
try:
globalns = sys.modules[module_name].__dict__
except KeyError:
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
pass
return get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)
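# --- Illustrative sketch, not part of the vendored module ---
# String annotations are resolved against the globals of the module the object
# was defined in, so no local namespace needs to be supplied for simple cases;
# the `Sample` class is a hypothetical example.
def _example_get_type_hints_infer_globalns() -> None:
    class Sample:
        x: 'int'
        y: str

    assert get_type_hints_infer_globalns(Sample) == {'x': int, 'y': str}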
class PydanticMetadata(Representation):
"""Base class for annotation markers like `Strict`."""
__slots__ = ()
def pydantic_general_metadata(**metadata: Any) -> BaseMetadata:
"""Create a new `_PydanticGeneralMetadata` class with the given metadata.
Args:
**metadata: The metadata to add.
Returns:
The new `_PydanticGeneralMetadata` instance.
"""
return _general_metadata_cls()(metadata) # type: ignore
@lru_cache(maxsize=None)
def _general_metadata_cls() -> type[BaseMetadata]:
"""Do it this way to avoid importing `annotated_types` at import time."""
from annotated_types import BaseMetadata
class _PydanticGeneralMetadata(PydanticMetadata, BaseMetadata):
"""Pydantic general metadata like `max_digits`."""
def __init__(self, metadata: Any):
self.__dict__ = metadata
return _PydanticGeneralMetadata # type: ignore
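# --- Illustrative sketch, not part of the vendored module ---
# `pydantic_general_metadata` packs arbitrary keyword metadata into a single
# annotation marker whose attributes mirror the keywords; the keywords below
# are arbitrary examples.
def _example_pydantic_general_metadata() -> None:
    marker = pydantic_general_metadata(max_digits=6, decimal_places=2)
    assert marker.max_digits == 6
    assert marker.decimal_places == 2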
def collect_model_fields( # noqa: C901
cls: type[BaseModel],
bases: tuple[type[Any], ...],
config_wrapper: ConfigWrapper,
types_namespace: dict[str, Any] | None,
*,
typevars_map: dict[Any, Any] | None = None,
) -> tuple[dict[str, FieldInfo], set[str]]:
"""Collect the fields of a nascent pydantic model.
Also collect the names of any ClassVars present in the type hints.
The returned value is a tuple of two items: the fields dict, and the set of ClassVar names.
Args:
cls: BaseModel or dataclass.
bases: Parents of the class, generally `cls.__bases__`.
config_wrapper: The config wrapper instance.
types_namespace: Optional extra namespace to look for types in.
typevars_map: A dictionary mapping type variables to their concrete types.
Returns:
A two-tuple containing the fields dict and the set of ClassVar names.
Raises:
NameError:
- If there is a conflict between a field name and protected namespaces.
- If there is a field other than `root` in `RootModel`.
- If a field shadows an attribute in the parent model.
"""
from ..fields import FieldInfo
type_hints = get_cls_type_hints_lenient(cls, types_namespace)
# https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
# annotations is only used for finding fields in parent classes
annotations = cls.__dict__.get('__annotations__', {})
fields: dict[str, FieldInfo] = {}
class_vars: set[str] = set()
for ann_name, ann_type in type_hints.items():
if ann_name == 'model_config':
# We never want to treat `model_config` as a field
# Note: we may need to change this logic if/when we introduce a `BareModel` class with no
# protected namespaces (where `model_config` might be allowed as a field name)
continue
for protected_namespace in config_wrapper.protected_namespaces:
if ann_name.startswith(protected_namespace):
for b in bases:
if hasattr(b, ann_name):
from ..main import BaseModel
if not (issubclass(b, BaseModel) and ann_name in b.model_fields):
raise NameError(
f'Field "{ann_name}" conflicts with member {getattr(b, ann_name)}'
f' of protected namespace "{protected_namespace}".'
)
else:
valid_namespaces = tuple(
x for x in config_wrapper.protected_namespaces if not ann_name.startswith(x)
)
warnings.warn(
f'Field "{ann_name}" has conflict with protected namespace "{protected_namespace}".'
'\n\nYou may be able to resolve this warning by setting'
f" `model_config['protected_namespaces'] = {valid_namespaces}`.",
UserWarning,
)
if is_classvar(ann_type):
class_vars.add(ann_name)
continue
if _is_finalvar_with_default_val(ann_type, getattr(cls, ann_name, PydanticUndefined)):
class_vars.add(ann_name)
continue
if not is_valid_field_name(ann_name):
continue
if cls.__pydantic_root_model__ and ann_name != 'root':
raise NameError(
f"Unexpected field with name {ann_name!r}; only 'root' is allowed as a field of a `RootModel`"
)
# when building a generic model with `MyModel[int]`, the generic_origin check makes sure we don't get
# "... shadows an attribute" errors
generic_origin = getattr(cls, '__pydantic_generic_metadata__', {}).get('origin')
for base in bases:
dataclass_fields = {
field.name for field in (dataclasses.fields(base) if dataclasses.is_dataclass(base) else ())
}
if hasattr(base, ann_name):
if base is generic_origin:
# Don't error when "shadowing" of attributes in parametrized generics
continue
if ann_name in dataclass_fields:
# Don't error when inheriting stdlib dataclasses whose fields are "shadowed" by defaults being set
# on the class instance.
continue
warnings.warn(
f'Field name "{ann_name}" shadows an attribute in parent "{base.__qualname__}"; ',
UserWarning,
)
try:
default = getattr(cls, ann_name, PydanticUndefined)
if default is PydanticUndefined:
raise AttributeError
except AttributeError:
if ann_name in annotations:
field_info = FieldInfo.from_annotation(ann_type)
else:
# if field has no default value and is not in __annotations__ this means that it is
# defined in a base class and we can take it from there
model_fields_lookup: dict[str, FieldInfo] = {}
for x in cls.__bases__[::-1]:
model_fields_lookup.update(getattr(x, 'model_fields', {}))
if ann_name in model_fields_lookup:
# The field was present on one of the (possibly multiple) base classes
# copy the field to make sure typevar substitutions don't cause issues with the base classes
field_info = copy(model_fields_lookup[ann_name])
else:
# The field was not found on any base classes; this seems to be caused by fields not getting
# generated thanks to models not being fully defined while initializing recursive models.
# Nothing stops us from just creating a new FieldInfo for this type hint, so we do this.
field_info = FieldInfo.from_annotation(ann_type)
else:
field_info = FieldInfo.from_annotated_attribute(ann_type, default)
# attributes which are fields are removed from the class namespace:
# 1. To match the behaviour of annotation-only fields
# 2. To avoid false positives in the NameError check above
try:
delattr(cls, ann_name)
except AttributeError:
pass # indicates the attribute was on a parent class
# Use cls.__dict__['__pydantic_decorators__'] instead of cls.__pydantic_decorators__
# to make sure the decorators have already been built for this exact class
decorators: DecoratorInfos = cls.__dict__['__pydantic_decorators__']
if ann_name in decorators.computed_fields:
raise ValueError("you can't override a field with a computed field")
fields[ann_name] = field_info
if typevars_map:
for field in fields.values():
field.apply_typevars_map(typevars_map, types_namespace)
return fields, class_vars
def _is_finalvar_with_default_val(type_: type[Any], val: Any) -> bool:
from ..fields import FieldInfo
if not is_finalvar(type_):
return False
elif val is PydanticUndefined:
return False
elif isinstance(val, FieldInfo) and (val.default is PydanticUndefined and val.default_factory is None):
return False
else:
return True
def collect_dataclass_fields(
cls: type[StandardDataclass], types_namespace: dict[str, Any] | None, *, typevars_map: dict[Any, Any] | None = None
) -> dict[str, FieldInfo]:
"""Collect the fields of a dataclass.
Args:
cls: dataclass.
types_namespace: Optional extra namespace to look for types in.
typevars_map: A dictionary mapping type variables to their concrete types.
Returns:
The dataclass fields.
"""
from ..fields import FieldInfo
fields: dict[str, FieldInfo] = {}
dataclass_fields: dict[str, dataclasses.Field] = cls.__dataclass_fields__
cls_localns = dict(vars(cls)) # this matches get_cls_type_hints_lenient, but all tests pass with `= None` instead
source_module = sys.modules.get(cls.__module__)
if source_module is not None:
types_namespace = {**source_module.__dict__, **(types_namespace or {})}
for ann_name, dataclass_field in dataclass_fields.items():
ann_type = _typing_extra.eval_type_lenient(dataclass_field.type, types_namespace, cls_localns)
if is_classvar(ann_type):
continue
if (
not dataclass_field.init
and dataclass_field.default == dataclasses.MISSING
and dataclass_field.default_factory == dataclasses.MISSING
):
# TODO: We should probably do something with this so that validate_assignment behaves properly
# Issue: https://github.com/pydantic/pydantic/issues/5470
continue
if isinstance(dataclass_field.default, FieldInfo):
if dataclass_field.default.init_var:
if dataclass_field.default.init is False:
raise PydanticUserError(
f'Dataclass field {ann_name} has init=False and init_var=True, but these are mutually exclusive.',
code='clashing-init-and-init-var',
)
# TODO: same note as above re validate_assignment
continue
field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field.default)
else:
field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field)
fields[ann_name] = field_info
if field_info.default is not PydanticUndefined and isinstance(getattr(cls, ann_name, field_info), FieldInfo):
# We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo
setattr(cls, ann_name, field_info.default)
if typevars_map:
for field in fields.values():
field.apply_typevars_map(typevars_map, types_namespace)
return fields
def is_valid_field_name(name: str) -> bool:
return not name.startswith('_')
def is_valid_privateattr_name(name: str) -> bool:
return name.startswith('_') and not name.startswith('__')
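# --- Illustrative sketch, not part of the vendored module ---
# Field-name conventions used above: public names become model fields, a single
# leading underscore marks a private attribute, and dunder names are neither.
def _example_field_name_rules() -> None:
    assert is_valid_field_name('name') and not is_valid_privateattr_name('name')
    assert not is_valid_field_name('_secret') and is_valid_privateattr_name('_secret')
    assert not is_valid_field_name('__dict__') and not is_valid_privateattr_name('__dict__')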

View file

@@ -0,0 +1,23 @@
from __future__ import annotations as _annotations
from dataclasses import dataclass
from typing import Union
@dataclass
class PydanticRecursiveRef:
type_ref: str
__name__ = 'PydanticRecursiveRef'
__hash__ = object.__hash__
def __call__(self) -> None:
"""Defining __call__ is necessary for the `typing` module to let you use an instance of
this class as the result of resolving a standard ForwardRef.
"""
def __or__(self, other):
return Union[self, other] # type: ignore
def __ror__(self, other):
return Union[other, self] # type: ignore

File diff suppressed because it is too large

View file

@@ -0,0 +1,517 @@
from __future__ import annotations
import sys
import types
import typing
from collections import ChainMap
from contextlib import contextmanager
from contextvars import ContextVar
from types import prepare_class
from typing import TYPE_CHECKING, Any, Iterator, List, Mapping, MutableMapping, Tuple, TypeVar
from weakref import WeakValueDictionary
import typing_extensions
from ._core_utils import get_type_ref
from ._forward_ref import PydanticRecursiveRef
from ._typing_extra import TypeVarType, typing_base
from ._utils import all_identical, is_model_class
if sys.version_info >= (3, 10):
from typing import _UnionGenericAlias # type: ignore[attr-defined]
if TYPE_CHECKING:
from ..main import BaseModel
GenericTypesCacheKey = Tuple[Any, Any, Tuple[Any, ...]]
# Note: We want to remove LimitedDict, but to do this, we'd need to improve the handling of generics caching.
# Right now, to handle recursive generics, some types must remain cached for brief periods without references.
# By chaining the WeakValueDictionary with a LimitedDict, we have a way to retain caching for all types with references,
# while also retaining a limited number of types even without references. This is generally enough to build
# specific recursive generic models without losing required items out of the cache.
KT = TypeVar('KT')
VT = TypeVar('VT')
_LIMITED_DICT_SIZE = 100
if TYPE_CHECKING:
class LimitedDict(dict, MutableMapping[KT, VT]):
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE):
...
else:
class LimitedDict(dict):
"""Limit the size/length of a dict used for caching to avoid unlimited increase in memory usage.
Since the dict is ordered, and we always remove elements from the beginning, this is effectively a FIFO cache.
"""
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE):
self.size_limit = size_limit
super().__init__()
def __setitem__(self, __key: Any, __value: Any) -> None:
super().__setitem__(__key, __value)
if len(self) > self.size_limit:
excess = len(self) - self.size_limit + self.size_limit // 10
to_remove = list(self.keys())[:excess]
for key in to_remove:
del self[key]
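# --- Illustrative sketch, not part of the vendored module ---
# Demonstrates the FIFO eviction of `LimitedDict`: once the size limit is
# exceeded, the oldest keys (plus a ~10% margin) are dropped.
def _example_limited_dict() -> None:
    cache = LimitedDict(size_limit=10)
    for i in range(11):
        cache[i] = i
    assert len(cache) <= 10
    assert 0 not in cache  # the oldest key was evicted first
    assert 10 in cache     # the newest key survives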
# weak dictionaries allow the dynamically created parametrized versions of generic models to get collected
# once they are no longer referenced by the caller.
if sys.version_info >= (3, 9): # Typing for weak dictionaries available at 3.9
GenericTypesCache = WeakValueDictionary[GenericTypesCacheKey, 'type[BaseModel]']
else:
GenericTypesCache = WeakValueDictionary
if TYPE_CHECKING:
class DeepChainMap(ChainMap[KT, VT]): # type: ignore
...
else:
class DeepChainMap(ChainMap):
"""Variant of ChainMap that allows direct updates to inner scopes.
Taken from https://docs.python.org/3/library/collections.html#collections.ChainMap,
with some light modifications for this use case.
"""
def clear(self) -> None:
for mapping in self.maps:
mapping.clear()
def __setitem__(self, key: KT, value: VT) -> None:
for mapping in self.maps:
mapping[key] = value
def __delitem__(self, key: KT) -> None:
hit = False
for mapping in self.maps:
if key in mapping:
del mapping[key]
hit = True
if not hit:
raise KeyError(key)
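# --- Illustrative sketch, not part of the vendored module ---
# `DeepChainMap` writes through to every underlying mapping, unlike the
# standard ChainMap, which only mutates the first one.
def _example_deep_chain_map() -> None:
    first, second = {'a': 1}, {'a': 2, 'b': 3}
    chain = DeepChainMap(first, second)
    chain['a'] = 9
    assert first['a'] == 9 and second['a'] == 9
    del chain['b']
    assert 'b' not in second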
# Despite the fact that LimitedDict _seems_ no longer necessary, I'm very nervous to actually remove it
# and discover later on that we need to re-add all this infrastructure...
# _GENERIC_TYPES_CACHE = DeepChainMap(GenericTypesCache(), LimitedDict())
_GENERIC_TYPES_CACHE = GenericTypesCache()
class PydanticGenericMetadata(typing_extensions.TypedDict):
origin: type[BaseModel] | None # analogous to typing._GenericAlias.__origin__
args: tuple[Any, ...] # analogous to typing._GenericAlias.__args__
parameters: tuple[type[Any], ...] # analogous to typing.Generic.__parameters__
def create_generic_submodel(
model_name: str, origin: type[BaseModel], args: tuple[Any, ...], params: tuple[Any, ...]
) -> type[BaseModel]:
"""Dynamically create a submodel of a provided (generic) BaseModel.
This is used when producing concrete parametrizations of generic models. This function
only *creates* the new subclass; the schema/validators/serialization must be updated to
reflect a concrete parametrization elsewhere.
Args:
model_name: The name of the newly created model.
origin: The base class for the new model to inherit from.
args: A tuple of generic metadata arguments.
params: A tuple of generic metadata parameters.
Returns:
The created submodel.
"""
namespace: dict[str, Any] = {'__module__': origin.__module__}
bases = (origin,)
meta, ns, kwds = prepare_class(model_name, bases)
namespace.update(ns)
created_model = meta(
model_name,
bases,
namespace,
__pydantic_generic_metadata__={
'origin': origin,
'args': args,
'parameters': params,
},
__pydantic_reset_parent_namespace__=False,
**kwds,
)
model_module, called_globally = _get_caller_frame_info(depth=3)
if called_globally: # create global reference and therefore allow pickling
object_by_reference = None
reference_name = model_name
reference_module_globals = sys.modules[created_model.__module__].__dict__
while object_by_reference is not created_model:
object_by_reference = reference_module_globals.setdefault(reference_name, created_model)
reference_name += '_'
return created_model
def _get_caller_frame_info(depth: int = 2) -> tuple[str | None, bool]:
"""Used inside a function to check whether it was called globally.
Args:
depth: The depth to get the frame.
Returns:
A tuple containing `module_name` and `called_globally`.
Raises:
RuntimeError: If the function is not called inside a function.
"""
try:
previous_caller_frame = sys._getframe(depth)
except ValueError as e:
raise RuntimeError('This function must be used inside another function') from e
except AttributeError: # sys module does not have _getframe function, so there's nothing we can do about it
return None, False
frame_globals = previous_caller_frame.f_globals
return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals
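# --- Illustrative sketch, not part of the vendored module ---
# A function calls `_get_caller_frame_info` to learn about *its own* caller:
# here the hypothetical `probe` is invoked from inside another function, so
# `called_globally` is False and the module name is this module's.
def _example_get_caller_frame_info() -> None:
    def probe():
        return _get_caller_frame_info()

    module_name, called_globally = probe()
    assert module_name == __name__
    assert called_globally is False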
DictValues: type[Any] = {}.values().__class__
def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]:
"""Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found.
This is intended as an alternative to directly accessing the `__parameters__` attribute of a GenericAlias,
since __parameters__ of (nested) generic BaseModel subclasses won't show up in that list.
"""
if isinstance(v, TypeVar):
yield v
elif is_model_class(v):
yield from v.__pydantic_generic_metadata__['parameters']
elif isinstance(v, (DictValues, list)):
for var in v:
yield from iter_contained_typevars(var)
else:
args = get_args(v)
for arg in args:
yield from iter_contained_typevars(arg)
def get_args(v: Any) -> Any:
pydantic_generic_metadata: PydanticGenericMetadata | None = getattr(v, '__pydantic_generic_metadata__', None)
if pydantic_generic_metadata:
return pydantic_generic_metadata.get('args')
return typing_extensions.get_args(v)
def get_origin(v: Any) -> Any:
pydantic_generic_metadata: PydanticGenericMetadata | None = getattr(v, '__pydantic_generic_metadata__', None)
if pydantic_generic_metadata:
return pydantic_generic_metadata.get('origin')
return typing_extensions.get_origin(v)
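# --- Illustrative sketch, not part of the vendored module ---
# For plain typing constructs, `get_origin`/`get_args` fall back to
# typing_extensions, and `iter_contained_typevars` walks nested arguments;
# the alias below is an arbitrary example.
def _example_generic_introspection() -> None:
    from typing import Dict as _Dict, List as _List, TypeVar as _TypeVar

    K = _TypeVar('K')
    T = _TypeVar('T')
    alias = _Dict[K, _List[T]]

    assert get_origin(alias) is dict
    assert get_args(alias) == (K, _List[T])
    assert list(iter_contained_typevars(alias)) == [K, T]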
def get_standard_typevars_map(cls: type[Any]) -> dict[TypeVarType, Any] | None:
"""Package a generic type's typevars and parametrization (if present) into a dictionary compatible with the
`replace_types` function. Specifically, this works with standard typing generics and typing._GenericAlias.
"""
origin = get_origin(cls)
if origin is None:
return None
if not hasattr(origin, '__parameters__'):
return None
# In this case, we know that cls is a _GenericAlias, and origin is the generic type
# So it is safe to access cls.__args__ and origin.__parameters__
args: tuple[Any, ...] = cls.__args__ # type: ignore
parameters: tuple[TypeVarType, ...] = origin.__parameters__
return dict(zip(parameters, args))
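# --- Illustrative sketch, not part of the vendored module ---
# For a parametrized standard generic, the origin's type variables are paired
# with the concrete arguments; the `Box` class is a hypothetical example.
def _example_get_standard_typevars_map() -> None:
    from typing import Generic, TypeVar as _TypeVar

    T = _TypeVar('T')

    class Box(Generic[T]):
        pass

    assert get_standard_typevars_map(Box[int]) == {T: int}
    assert get_standard_typevars_map(int) is None  # not a parametrized generic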
def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVarType, Any] | None:
"""Package a generic BaseModel's typevars and concrete parametrization (if present) into a dictionary compatible
with the `replace_types` function.
Since BaseModel.__class_getitem__ does not produce a typing._GenericAlias, and the BaseModel generic info is
stored in the __pydantic_generic_metadata__ attribute, we need special handling here.
"""
# TODO: This could be unified with `get_standard_typevars_map` if we stored the generic metadata
# in the __origin__, __args__, and __parameters__ attributes of the model.
generic_metadata = cls.__pydantic_generic_metadata__
origin = generic_metadata['origin']
args = generic_metadata['args']
return dict(zip(iter_contained_typevars(origin), args))
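# For example (illustrative): for `class Box(BaseModel, Generic[T])`,
# get_model_typevars_map(Box[int]) returns roughly {T: int}.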
def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
"""Return type with all occurrences of `type_map` keys recursively replaced with their values.
Args:
type_: The class or generic alias.
type_map: Mapping from `TypeVar` instance to concrete types.
Returns:
A new type representing the basic structure of `type_` with all
`type_map` keys recursively replaced.
Example:
```py
from typing import List, Tuple, Union
from pydantic._internal._generics import replace_types
replace_types(Tuple[str, Union[List[str], float]], {str: int})
#> Tuple[int, Union[List[int], float]]
```
"""
if not type_map:
return type_
type_args = get_args(type_)
origin_type = get_origin(type_)
if origin_type is typing_extensions.Annotated:
annotated_type, *annotations = type_args
annotated = replace_types(annotated_type, type_map)
for annotation in annotations:
annotated = typing_extensions.Annotated[annotated, annotation]
return annotated
# Having type args is a good indicator that this is a typing module
# class instantiation or a generic alias of some sort.
if type_args:
resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
if all_identical(type_args, resolved_type_args):
# If all arguments are the same, there is no need to modify the
# type or create a new object at all
return type_
if (
origin_type is not None
and isinstance(type_, typing_base)
and not isinstance(origin_type, typing_base)
and getattr(type_, '_name', None) is not None
):
# In python < 3.9 generic aliases don't exist so any of these like `list`,
# `type` or `collections.abc.Callable` need to be translated.
# See: https://www.python.org/dev/peps/pep-0585
origin_type = getattr(typing, type_._name)
assert origin_type is not None
# PEP-604 syntax (Ex.: list | str) is represented with a types.UnionType object that does not have __getitem__.
# We also cannot use isinstance() since we have to compare types.
if sys.version_info >= (3, 10) and origin_type is types.UnionType:
return _UnionGenericAlias(origin_type, resolved_type_args)
# NotRequired[T] and Required[T] don't support tuple type resolved_type_args, hence the condition below
return origin_type[resolved_type_args[0] if len(resolved_type_args) == 1 else resolved_type_args]
# We handle pydantic generic models separately as they don't have the same
# semantics as "typing" classes or generic aliases
if not origin_type and is_model_class(type_):
parameters = type_.__pydantic_generic_metadata__['parameters']
if not parameters:
return type_
resolved_type_args = tuple(replace_types(t, type_map) for t in parameters)
if all_identical(parameters, resolved_type_args):
return type_
return type_[resolved_type_args]
# Handle special case for typehints that can have lists as arguments.
`typing.Callable[[int, str], int]` is an example of this.
if isinstance(type_, (List, list)):
resolved_list = list(replace_types(element, type_map) for element in type_)
if all_identical(type_, resolved_list):
return type_
return resolved_list
# If all else fails, we try to resolve the type directly and otherwise just
# return the input with no modifications.
return type_map.get(type_, type_)
def has_instance_in_type(type_: Any, isinstance_target: Any) -> bool:
"""Checks if the type, or any of its arbitrary nested args, satisfy
`isinstance(<type>, isinstance_target)`.
"""
if isinstance(type_, isinstance_target):
return True
type_args = get_args(type_)
origin_type = get_origin(type_)
if origin_type is typing_extensions.Annotated:
annotated_type, *annotations = type_args
return has_instance_in_type(annotated_type, isinstance_target)
# Having type args is a good indicator that this is a typing module
# class instantiation or a generic alias of some sort.
if any(has_instance_in_type(a, isinstance_target) for a in type_args):
return True
# Handle special case for typehints that can have lists as arguments.
`typing.Callable[[int, str], int]` is an example of this.
if isinstance(type_, (List, list)) and not isinstance(type_, typing_extensions.ParamSpec):
if any(has_instance_in_type(element, isinstance_target) for element in type_):
return True
return False
def check_parameters_count(cls: type[BaseModel], parameters: tuple[Any, ...]) -> None:
"""Check the generic model parameters count is equal.
Args:
cls: The generic model.
parameters: A tuple of passed parameters to the generic model.
Raises:
TypeError: If the passed parameters count is not equal to generic model parameters count.
"""
actual = len(parameters)
expected = len(cls.__pydantic_generic_metadata__['parameters'])
if actual != expected:
description = 'many' if actual > expected else 'few'
raise TypeError(f'Too {description} parameters for {cls}; actual {actual}, expected {expected}')
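# For instance (illustrative): given `class Box(BaseModel, Generic[T])`,
# check_parameters_count(Box, (int, str)) raises roughly:
#   TypeError: Too many parameters for <class 'Box'>; actual 2, expected 1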
_generic_recursion_cache: ContextVar[set[str] | None] = ContextVar('_generic_recursion_cache', default=None)
@contextmanager
def generic_recursion_self_type(
origin: type[BaseModel], args: tuple[Any, ...]
) -> Iterator[PydanticRecursiveRef | None]:
"""This contextmanager should be placed around the recursive calls used to build a generic type,
and accept as arguments the generic origin type and the type arguments being passed to it.
If the same origin and arguments are observed twice, it implies that a self-reference placeholder
can be used while building the core schema, and will produce a schema_ref that will be valid in the
final parent schema.
"""
previously_seen_type_refs = _generic_recursion_cache.get()
if previously_seen_type_refs is None:
previously_seen_type_refs = set()
token = _generic_recursion_cache.set(previously_seen_type_refs)
else:
token = None
try:
type_ref = get_type_ref(origin, args_override=args)
if type_ref in previously_seen_type_refs:
self_type = PydanticRecursiveRef(type_ref=type_ref)
yield self_type
else:
previously_seen_type_refs.add(type_ref)
yield None
finally:
if token:
_generic_recursion_cache.reset(token)
def recursively_defined_type_refs() -> set[str]:
visited = _generic_recursion_cache.get()
if not visited:
return set() # not in a generic recursion, so there are no types
return visited.copy() # don't allow modifications
def get_cached_generic_type_early(parent: type[BaseModel], typevar_values: Any) -> type[BaseModel] | None:
"""The use of a two-stage cache lookup approach was necessary to have the highest performance possible for
repeated calls to `__class_getitem__` on generic types (which may happen in tighter loops during runtime),
while still ensuring that certain alternative parametrizations ultimately resolve to the same type.
As a concrete example, this approach was necessary to make Model[List[T]][int] equal to Model[List[int]].
The approach could be modified to not use two different cache keys at different points, but the
_early_cache_key is optimized to be as quick to compute as possible (for repeated-access speed), and the
_late_cache_key is optimized to be as "correct" as possible, so that two types that will ultimately be the
same after resolving the type arguments will always produce cache hits.
If we wanted to move to only using a single cache key per type, we would either need to always use the
slower/more computationally intensive logic associated with _late_cache_key, or would need to accept
that Model[List[T]][int] is a different type than Model[List[int]]. Because we rely on subclass relationships
during validation, I think it is worthwhile to ensure that types that are functionally equivalent are actually
equal.
"""
return _GENERIC_TYPES_CACHE.get(_early_cache_key(parent, typevar_values))
def get_cached_generic_type_late(
parent: type[BaseModel], typevar_values: Any, origin: type[BaseModel], args: tuple[Any, ...]
) -> type[BaseModel] | None:
"""See the docstring of `get_cached_generic_type_early` for more information about the two-stage cache lookup."""
cached = _GENERIC_TYPES_CACHE.get(_late_cache_key(origin, args, typevar_values))
if cached is not None:
set_cached_generic_type(parent, typevar_values, cached, origin, args)
return cached
def set_cached_generic_type(
parent: type[BaseModel],
typevar_values: tuple[Any, ...],
type_: type[BaseModel],
origin: type[BaseModel] | None = None,
args: tuple[Any, ...] | None = None,
) -> None:
"""See the docstring of `get_cached_generic_type_early` for more information about why items are cached with
two different keys.
"""
_GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values)] = type_
if len(typevar_values) == 1:
_GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values[0])] = type_
if origin and args:
_GENERIC_TYPES_CACHE[_late_cache_key(origin, args, typevar_values)] = type_
def _union_orderings_key(typevar_values: Any) -> Any:
"""This is intended to help differentiate between Union types with the same arguments in different order.
Thanks to caching internal to the `typing` module, it is not possible to distinguish between
List[Union[int, float]] and List[Union[float, int]] (and similarly for other "parent" origins besides List)
because `typing` considers Union[int, float] to be equal to Union[float, int].
However, you _can_ distinguish between (top-level) Union[int, float] vs. Union[float, int].
Because we validate against union members in order and use the first that succeeds, we get slightly more consistent behavior
if we make an effort to distinguish the ordering of items in a union. It would be best if we could _always_
get the exact-correct order of items in the union, but that would require a change to the `typing` module itself.
(See https://github.com/python/cpython/issues/86483 for reference.)
"""
if isinstance(typevar_values, tuple):
args_data = []
for value in typevar_values:
args_data.append(_union_orderings_key(value))
return tuple(args_data)
elif typing_extensions.get_origin(typevar_values) is typing.Union:
return get_args(typevar_values)
else:
return ()
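# Roughly (illustrative):
#   _union_orderings_key((int, Union[int, str]))  # -> ((), (int, str))
#   _union_orderings_key(Union[str, int])         # -> (str, int)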
def _early_cache_key(cls: type[BaseModel], typevar_values: Any) -> GenericTypesCacheKey:
"""This is intended for minimal computational overhead during lookups of cached types.
Note that this is overly simplistic, and it's possible that two different cls/typevar_values
inputs would ultimately result in the same type being created in BaseModel.__class_getitem__.
To handle this, we have a fallback _late_cache_key that is checked later if the _early_cache_key
lookup fails, and should result in a cache hit _precisely_ when the inputs to __class_getitem__
would result in the same type.
"""
return cls, typevar_values, _union_orderings_key(typevar_values)
def _late_cache_key(origin: type[BaseModel], args: tuple[Any, ...], typevar_values: Any) -> GenericTypesCacheKey:
"""This is intended for use later in the process of creating a new type, when we have more information
about the exact args that will be passed. If it turns out that a different set of inputs to
__class_getitem__ resulted in the same inputs to the generic type creation process, we can still
return the cached type, and update the cache with the _early_cache_key as well.
"""
# The _union_orderings_key is placed at the start here to ensure there cannot be a collision with an
# _early_cache_key, as that function will always produce a BaseModel subclass as the first item in the key,
# whereas this function will always produce a tuple as the first item in the key.
return _union_orderings_key(typevar_values), origin, args

View file

@@ -0,0 +1,26 @@
"""Git utilities, adopted from mypy's git utilities (https://github.com/python/mypy/blob/master/mypy/git.py)."""
from __future__ import annotations
import os
import subprocess
def is_git_repo(dir: str) -> bool:
"""Is the given directory version-controlled with git?"""
return os.path.exists(os.path.join(dir, '.git'))
def have_git() -> bool:
"""Can we run the git executable?"""
try:
subprocess.check_output(['git', '--help'])
return True
except subprocess.CalledProcessError:
return False
except OSError:
return False
def git_revision(dir: str) -> str:
"""Get the SHA-1 of the HEAD of a git repository."""
return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], cwd=dir).decode('utf-8').strip()
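# Example usage (a sketch, assuming git is installed and the current directory is a repository):
#   if have_git() and is_git_repo('.'):
#       print(git_revision('.'))  # e.g. 'a1b2c3d'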

View file

@@ -0,0 +1,10 @@
import sys
from typing import Any, Dict
dataclass_kwargs: Dict[str, Any]
# `slots` is available on Python >= 3.10
if sys.version_info >= (3, 10):
slots_true = {'slots': True}
else:
slots_true = {}
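# Typical usage (illustrative): pass these keyword arguments straight through to
# `dataclasses.dataclass`, e.g. `@dataclasses.dataclass(frozen=True, **slots_true)`,
# so that `slots=True` is only applied on Python >= 3.10.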

View file

@@ -0,0 +1,410 @@
from __future__ import annotations
from collections import defaultdict
from copy import copy
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Iterable
from pydantic_core import CoreSchema, PydanticCustomError, to_jsonable_python
from pydantic_core import core_schema as cs
from ._fields import PydanticMetadata
if TYPE_CHECKING:
from ..annotated_handlers import GetJsonSchemaHandler
STRICT = {'strict'}
SEQUENCE_CONSTRAINTS = {'min_length', 'max_length'}
INEQUALITY = {'le', 'ge', 'lt', 'gt'}
NUMERIC_CONSTRAINTS = {'multiple_of', 'allow_inf_nan', *INEQUALITY}
STR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT, 'strip_whitespace', 'to_lower', 'to_upper', 'pattern'}
BYTES_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
LIST_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
TUPLE_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
SET_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
DICT_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
GENERATOR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
FLOAT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
INT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
BOOL_CONSTRAINTS = STRICT
UUID_CONSTRAINTS = STRICT
DATE_TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
TIMEDELTA_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
LAX_OR_STRICT_CONSTRAINTS = STRICT
UNION_CONSTRAINTS = {'union_mode'}
URL_CONSTRAINTS = {
'max_length',
'allowed_schemes',
'host_required',
'default_host',
'default_port',
'default_path',
}
TEXT_SCHEMA_TYPES = ('str', 'bytes', 'url', 'multi-host-url')
SEQUENCE_SCHEMA_TYPES = ('list', 'tuple', 'set', 'frozenset', 'generator', *TEXT_SCHEMA_TYPES)
NUMERIC_SCHEMA_TYPES = ('float', 'int', 'date', 'time', 'timedelta', 'datetime')
CONSTRAINTS_TO_ALLOWED_SCHEMAS: dict[str, set[str]] = defaultdict(set)
for constraint in STR_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(TEXT_SCHEMA_TYPES)
for constraint in BYTES_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('bytes',))
for constraint in LIST_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('list',))
for constraint in TUPLE_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('tuple',))
for constraint in SET_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('set', 'frozenset'))
for constraint in DICT_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('dict',))
for constraint in GENERATOR_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('generator',))
for constraint in FLOAT_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('float',))
for constraint in INT_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('int',))
for constraint in DATE_TIME_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('date', 'time', 'datetime'))
for constraint in TIMEDELTA_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('timedelta',))
for constraint in TIME_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('time',))
for schema_type in (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model'):
CONSTRAINTS_TO_ALLOWED_SCHEMAS['strict'].add(schema_type)
for constraint in UNION_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('union',))
for constraint in URL_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('url', 'multi-host-url'))
for constraint in BOOL_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('bool',))
for constraint in UUID_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('uuid',))
for constraint in LAX_OR_STRICT_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('lax-or-strict',))
def add_js_update_schema(s: cs.CoreSchema, f: Callable[[], dict[str, Any]]) -> None:
def update_js_schema(s: cs.CoreSchema, handler: GetJsonSchemaHandler) -> dict[str, Any]:
js_schema = handler(s)
js_schema.update(f())
return js_schema
if 'metadata' in s:
metadata = s['metadata']
if 'pydantic_js_functions' in s:
metadata['pydantic_js_functions'].append(update_js_schema)
else:
metadata['pydantic_js_functions'] = [update_js_schema]
else:
s['metadata'] = {'pydantic_js_functions': [update_js_schema]}
def as_jsonable_value(v: Any) -> Any:
if type(v) not in (int, str, float, bytes, bool, type(None)):
return to_jsonable_python(v)
return v
def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]:
"""Expand the annotations.
Args:
annotations: An iterable of annotations.
Returns:
An iterable of expanded annotations.
Example:
```py
from annotated_types import Ge, Len
from pydantic._internal._known_annotated_metadata import expand_grouped_metadata
print(list(expand_grouped_metadata([Ge(4), Len(5)])))
#> [Ge(ge=4), MinLen(min_length=5)]
```
"""
import annotated_types as at
from pydantic.fields import FieldInfo # circular import
for annotation in annotations:
if isinstance(annotation, at.GroupedMetadata):
yield from annotation
elif isinstance(annotation, FieldInfo):
yield from annotation.metadata
# this is a bit problematic in that it results in duplicate metadata
# all of our "consumers" can handle it, but it is not ideal
# we probably should split up FieldInfo into:
# - annotated types metadata
# - individual metadata known only to Pydantic
annotation = copy(annotation)
annotation.metadata = []
yield annotation
else:
yield annotation
def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | None: # noqa: C901
"""Apply `annotation` to `schema` if it is an annotation we know about (Gt, Le, etc.).
Otherwise return `None`.
This does not handle all known annotations. If / when it does, it can always
return a CoreSchema, returning the unmodified schema if the annotation should be ignored.
Assumes that GroupedMetadata has already been expanded via `expand_grouped_metadata`.
Args:
annotation: The annotation.
schema: The schema.
Returns:
An updated schema with annotation if it is an annotation we know about, `None` otherwise.
Raises:
PydanticCustomError: If `Predicate` fails.
"""
import annotated_types as at
from . import _validators
schema = schema.copy()
schema_update, other_metadata = collect_known_metadata([annotation])
schema_type = schema['type']
for constraint, value in schema_update.items():
if constraint not in CONSTRAINTS_TO_ALLOWED_SCHEMAS:
raise ValueError(f'Unknown constraint {constraint}')
allowed_schemas = CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint]
if schema_type in allowed_schemas:
if constraint == 'union_mode' and schema_type == 'union':
schema['mode'] = value # type: ignore # schema is UnionSchema
else:
schema[constraint] = value
continue
if constraint == 'allow_inf_nan' and value is False:
return cs.no_info_after_validator_function(
_validators.forbid_inf_nan_check,
schema,
)
elif constraint == 'pattern':
# insert a str schema to make sure the regex engine matches
return cs.chain_schema(
[
schema,
cs.str_schema(pattern=value),
]
)
elif constraint == 'gt':
s = cs.no_info_after_validator_function(
partial(_validators.greater_than_validator, gt=value),
schema,
)
add_js_update_schema(s, lambda: {'gt': as_jsonable_value(value)})
return s
elif constraint == 'ge':
return cs.no_info_after_validator_function(
partial(_validators.greater_than_or_equal_validator, ge=value),
schema,
)
elif constraint == 'lt':
return cs.no_info_after_validator_function(
partial(_validators.less_than_validator, lt=value),
schema,
)
elif constraint == 'le':
return cs.no_info_after_validator_function(
partial(_validators.less_than_or_equal_validator, le=value),
schema,
)
elif constraint == 'multiple_of':
return cs.no_info_after_validator_function(
partial(_validators.multiple_of_validator, multiple_of=value),
schema,
)
elif constraint == 'min_length':
s = cs.no_info_after_validator_function(
partial(_validators.min_length_validator, min_length=value),
schema,
)
add_js_update_schema(s, lambda: {'minLength': (as_jsonable_value(value))})
return s
elif constraint == 'max_length':
s = cs.no_info_after_validator_function(
partial(_validators.max_length_validator, max_length=value),
schema,
)
add_js_update_schema(s, lambda: {'maxLength': (as_jsonable_value(value))})
return s
elif constraint == 'strip_whitespace':
return cs.chain_schema(
[
schema,
cs.str_schema(strip_whitespace=True),
]
)
elif constraint == 'to_lower':
return cs.chain_schema(
[
schema,
cs.str_schema(to_lower=True),
]
)
elif constraint == 'to_upper':
return cs.chain_schema(
[
schema,
cs.str_schema(to_upper=True),
]
)
elif constraint == 'min_length':
return cs.no_info_after_validator_function(
partial(_validators.min_length_validator, min_length=annotation.min_length),
schema,
)
elif constraint == 'max_length':
return cs.no_info_after_validator_function(
partial(_validators.max_length_validator, max_length=annotation.max_length),
schema,
)
else:
raise RuntimeError(f'Unable to apply constraint {constraint} to schema {schema_type}')
for annotation in other_metadata:
if isinstance(annotation, at.Gt):
return cs.no_info_after_validator_function(
partial(_validators.greater_than_validator, gt=annotation.gt),
schema,
)
elif isinstance(annotation, at.Ge):
return cs.no_info_after_validator_function(
partial(_validators.greater_than_or_equal_validator, ge=annotation.ge),
schema,
)
elif isinstance(annotation, at.Lt):
return cs.no_info_after_validator_function(
partial(_validators.less_than_validator, lt=annotation.lt),
schema,
)
elif isinstance(annotation, at.Le):
return cs.no_info_after_validator_function(
partial(_validators.less_than_or_equal_validator, le=annotation.le),
schema,
)
elif isinstance(annotation, at.MultipleOf):
return cs.no_info_after_validator_function(
partial(_validators.multiple_of_validator, multiple_of=annotation.multiple_of),
schema,
)
elif isinstance(annotation, at.MinLen):
return cs.no_info_after_validator_function(
partial(_validators.min_length_validator, min_length=annotation.min_length),
schema,
)
elif isinstance(annotation, at.MaxLen):
return cs.no_info_after_validator_function(
partial(_validators.max_length_validator, max_length=annotation.max_length),
schema,
)
elif isinstance(annotation, at.Predicate):
predicate_name = f'{annotation.func.__qualname__} ' if hasattr(annotation.func, '__qualname__') else ''
def val_func(v: Any) -> Any:
# annotation.func may also raise an exception, let it pass through
if not annotation.func(v):
raise PydanticCustomError(
'predicate_failed',
f'Predicate {predicate_name}failed', # type: ignore
)
return v
return cs.no_info_after_validator_function(val_func, schema)
# ignore any other unknown metadata
return None
return schema
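# For example (illustrative): applying annotated_types.Gt(5) to an int schema mutates the copied
# schema in place, yielding roughly {'type': 'int', 'gt': 5}; constraints the schema type cannot
# carry natively are instead wrapped in a no_info_after_validator_function around the original schema.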
def collect_known_metadata(annotations: Iterable[Any]) -> tuple[dict[str, Any], list[Any]]:
"""Split `annotations` into known metadata and unknown annotations.
Args:
annotations: An iterable of annotations.
Returns:
A tuple containing a dict of known metadata and a list of unknown annotations.
Example:
```py
from annotated_types import Gt, Len
from pydantic._internal._known_annotated_metadata import collect_known_metadata
print(collect_known_metadata([Gt(1), Len(42), ...]))
#> ({'gt': 1, 'min_length': 42}, [Ellipsis])
```
"""
import annotated_types as at
annotations = expand_grouped_metadata(annotations)
res: dict[str, Any] = {}
remaining: list[Any] = []
for annotation in annotations:
# isinstance(annotation, PydanticMetadata) also covers ._fields:_PydanticGeneralMetadata
if isinstance(annotation, PydanticMetadata):
res.update(annotation.__dict__)
# we don't use dataclasses.asdict because that recursively calls asdict on the field values
elif isinstance(annotation, at.MinLen):
res.update({'min_length': annotation.min_length})
elif isinstance(annotation, at.MaxLen):
res.update({'max_length': annotation.max_length})
elif isinstance(annotation, at.Gt):
res.update({'gt': annotation.gt})
elif isinstance(annotation, at.Ge):
res.update({'ge': annotation.ge})
elif isinstance(annotation, at.Lt):
res.update({'lt': annotation.lt})
elif isinstance(annotation, at.Le):
res.update({'le': annotation.le})
elif isinstance(annotation, at.MultipleOf):
res.update({'multiple_of': annotation.multiple_of})
elif isinstance(annotation, type) and issubclass(annotation, PydanticMetadata):
# also support PydanticMetadata classes being used without initialisation,
# e.g. `Annotated[int, Strict]` as well as `Annotated[int, Strict()]`
res.update({k: v for k, v in vars(annotation).items() if not k.startswith('_')})
else:
remaining.append(annotation)
# Nones can sneak in but pydantic-core will reject them
# it'd be nice to clean things up so we don't put in None (we probably don't _need_ to, it was just easier)
# but this is simple enough to kick that can down the road
res = {k: v for k, v in res.items() if v is not None}
return res, remaining
def check_metadata(metadata: dict[str, Any], allowed: Iterable[str], source_type: Any) -> None:
"""A small utility function to validate that the given metadata can be applied to the target.
More than saving lines of code, this gives us a consistent error message for all of our internal implementations.
Args:
metadata: A dict of metadata.
allowed: An iterable of allowed metadata.
source_type: The source type.
Raises:
TypeError: If there is metadata that can't be applied to the source type.
"""
unknown = metadata.keys() - set(allowed)
if unknown:
raise TypeError(
f'The following constraints cannot be applied to {source_type!r}: {", ".join([f"{k!r}" for k in unknown])}'
)
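# e.g. (illustrative):
#   check_metadata({'gt': 0, 'pattern': 'abc'}, allowed=['gt'], source_type=int)
# raises roughly:
#   TypeError: The following constraints cannot be applied to <class 'int'>: 'pattern'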

View file

@@ -0,0 +1,140 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Callable, Generic, TypeVar
from pydantic_core import SchemaSerializer, SchemaValidator
from typing_extensions import Literal
from ..errors import PydanticErrorCodes, PydanticUserError
if TYPE_CHECKING:
from ..dataclasses import PydanticDataclass
from ..main import BaseModel
ValSer = TypeVar('ValSer', SchemaValidator, SchemaSerializer)
class MockValSer(Generic[ValSer]):
"""Mocker for `pydantic_core.SchemaValidator` or `pydantic_core.SchemaSerializer` which optionally attempts to
rebuild the thing it's mocking when one of its methods is accessed and raises an error if that fails.
"""
__slots__ = '_error_message', '_code', '_val_or_ser', '_attempt_rebuild'
def __init__(
self,
error_message: str,
*,
code: PydanticErrorCodes,
val_or_ser: Literal['validator', 'serializer'],
attempt_rebuild: Callable[[], ValSer | None] | None = None,
) -> None:
self._error_message = error_message
self._val_or_ser = SchemaValidator if val_or_ser == 'validator' else SchemaSerializer
self._code: PydanticErrorCodes = code
self._attempt_rebuild = attempt_rebuild
def __getattr__(self, item: str) -> None:
__tracebackhide__ = True
if self._attempt_rebuild:
val_ser = self._attempt_rebuild()
if val_ser is not None:
return getattr(val_ser, item)
# raise an AttributeError if `item` doesn't exist
getattr(self._val_or_ser, item)
raise PydanticUserError(self._error_message, code=self._code)
def rebuild(self) -> ValSer | None:
if self._attempt_rebuild:
val_ser = self._attempt_rebuild()
if val_ser is not None:
return val_ser
else:
raise PydanticUserError(self._error_message, code=self._code)
return None
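# In practice (an illustrative sketch): accessing an attribute such as
# `Model.__pydantic_validator__.validate_python(...)` on a not-yet-complete model first calls
# `attempt_rebuild()`; if that returns a real validator/serializer the attribute access is
# forwarded to it, otherwise a PydanticUserError with the configured error code is raised.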
def set_model_mocks(cls: type[BaseModel], cls_name: str, undefined_name: str = 'all referenced types') -> None:
"""Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a model.
Args:
cls: The model class to set the mocks on
cls_name: Name of the model class, used in error messages
undefined_name: Name of the undefined thing, used in error messages
"""
undefined_type_error_message = (
f'`{cls_name}` is not fully defined; you should define {undefined_name},'
f' then call `{cls_name}.model_rebuild()`.'
)
def attempt_rebuild_validator() -> SchemaValidator | None:
if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
return cls.__pydantic_validator__
else:
return None
cls.__pydantic_validator__ = MockValSer( # type: ignore[assignment]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='validator',
attempt_rebuild=attempt_rebuild_validator,
)
def attempt_rebuild_serializer() -> SchemaSerializer | None:
if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
return cls.__pydantic_serializer__
else:
return None
cls.__pydantic_serializer__ = MockValSer( # type: ignore[assignment]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='serializer',
attempt_rebuild=attempt_rebuild_serializer,
)
def set_dataclass_mocks(
cls: type[PydanticDataclass], cls_name: str, undefined_name: str = 'all referenced types'
) -> None:
"""Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a dataclass.
Args:
cls: The model class to set the mocks on
cls_name: Name of the model class, used in error messages
undefined_name: Name of the undefined thing, used in error messages
"""
from ..dataclasses import rebuild_dataclass
undefined_type_error_message = (
f'`{cls_name}` is not fully defined; you should define {undefined_name},'
f' then call `pydantic.dataclasses.rebuild_dataclass({cls_name})`.'
)
def attempt_rebuild_validator() -> SchemaValidator | None:
if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5) is not False:
return cls.__pydantic_validator__
else:
return None
cls.__pydantic_validator__ = MockValSer( # type: ignore[assignment]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='validator',
attempt_rebuild=attempt_rebuild_validator,
)
def attempt_rebuild_serializer() -> SchemaSerializer | None:
if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5) is not False:
return cls.__pydantic_serializer__
else:
return None
cls.__pydantic_serializer__ = MockValSer( # type: ignore[assignment]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='serializer',
attempt_rebuild=attempt_rebuild_serializer,
)

View file

@@ -0,0 +1,637 @@
"""Private logic for creating models."""
from __future__ import annotations as _annotations
import operator
import typing
import warnings
import weakref
from abc import ABCMeta
from functools import partial
from types import FunctionType
from typing import Any, Callable, Generic
import typing_extensions
from pydantic_core import PydanticUndefined, SchemaSerializer
from typing_extensions import dataclass_transform, deprecated
from ..errors import PydanticUndefinedAnnotation, PydanticUserError
from ..plugin._schema_validator import create_schema_validator
from ..warnings import GenericBeforeBaseModelWarning, PydanticDeprecatedSince20
from ._config import ConfigWrapper
from ._decorators import DecoratorInfos, PydanticDescriptorProxy, get_attribute_from_bases
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
from ._generate_schema import GenerateSchema
from ._generics import PydanticGenericMetadata, get_model_typevars_map
from ._mock_val_ser import MockValSer, set_model_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._signature import generate_pydantic_signature
from ._typing_extra import get_cls_types_namespace, is_annotated, is_classvar, parent_frame_namespace
from ._utils import ClassAttribute, SafeGetItemProxy
from ._validate_call import ValidateCallWrapper
if typing.TYPE_CHECKING:
from ..fields import Field as PydanticModelField
from ..fields import FieldInfo, ModelPrivateAttr
from ..main import BaseModel
else:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
# and https://youtrack.jetbrains.com/issue/PY-51428
DeprecationWarning = PydanticDeprecatedSince20
PydanticModelField = object()
object_setattr = object.__setattr__
class _ModelNamespaceDict(dict):
"""A dictionary subclass that intercepts attribute setting on model classes and
warns about overriding of decorators.
"""
def __setitem__(self, k: str, v: object) -> None:
existing: Any = self.get(k, None)
if existing and v is not existing and isinstance(existing, PydanticDescriptorProxy):
warnings.warn(f'`{k}` overrides an existing Pydantic `{existing.decorator_info.decorator_repr}` decorator')
return super().__setitem__(k, v)
@dataclass_transform(kw_only_default=True, field_specifiers=(PydanticModelField,))
class ModelMetaclass(ABCMeta):
def __new__(
mcs,
cls_name: str,
bases: tuple[type[Any], ...],
namespace: dict[str, Any],
__pydantic_generic_metadata__: PydanticGenericMetadata | None = None,
__pydantic_reset_parent_namespace__: bool = True,
_create_model_module: str | None = None,
**kwargs: Any,
) -> type:
"""Metaclass for creating Pydantic models.
Args:
cls_name: The name of the class to be created.
bases: The base classes of the class to be created.
namespace: The attribute dictionary of the class to be created.
__pydantic_generic_metadata__: Metadata for generic models.
__pydantic_reset_parent_namespace__: Reset parent namespace.
_create_model_module: The module of the class to be created, if created by `create_model`.
**kwargs: Catch-all for any other keyword arguments.
Returns:
The new class created by the metaclass.
"""
# Note `ModelMetaclass` refers to `BaseModel`, but is also used to *create* `BaseModel`, so we rely on the fact
# that `BaseModel` itself won't have any bases, but any subclass of it will, to determine whether the `__new__`
# call we're in the middle of is for the `BaseModel` class.
if bases:
base_field_names, class_vars, base_private_attributes = mcs._collect_bases_data(bases)
config_wrapper = ConfigWrapper.for_model(bases, namespace, kwargs)
namespace['model_config'] = config_wrapper.config_dict
private_attributes = inspect_namespace(
namespace, config_wrapper.ignored_types, class_vars, base_field_names
)
if private_attributes:
original_model_post_init = get_model_post_init(namespace, bases)
if original_model_post_init is not None:
# if there are private_attributes and a model_post_init function, we handle both
def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
"""We need to both initialize private attributes and call the user-defined model_post_init
method.
"""
init_private_attributes(self, __context)
original_model_post_init(self, __context)
namespace['model_post_init'] = wrapped_model_post_init
else:
namespace['model_post_init'] = init_private_attributes
namespace['__class_vars__'] = class_vars
namespace['__private_attributes__'] = {**base_private_attributes, **private_attributes}
cls: type[BaseModel] = super().__new__(mcs, cls_name, bases, namespace, **kwargs) # type: ignore
from ..main import BaseModel
mro = cls.__mro__
if Generic in mro and mro.index(Generic) < mro.index(BaseModel):
warnings.warn(
GenericBeforeBaseModelWarning(
'Classes should inherit from `BaseModel` before generic classes (e.g. `typing.Generic[T]`) '
'for pydantic generics to work properly.'
),
stacklevel=2,
)
cls.__pydantic_custom_init__ = not getattr(cls.__init__, '__pydantic_base_init__', False)
cls.__pydantic_post_init__ = None if cls.model_post_init is BaseModel.model_post_init else 'model_post_init'
cls.__pydantic_decorators__ = DecoratorInfos.build(cls)
# Use the getattr below to grab the __parameters__ from the `typing.Generic` parent class
if __pydantic_generic_metadata__:
cls.__pydantic_generic_metadata__ = __pydantic_generic_metadata__
else:
parent_parameters = getattr(cls, '__pydantic_generic_metadata__', {}).get('parameters', ())
parameters = getattr(cls, '__parameters__', None) or parent_parameters
if parameters and parent_parameters and not all(x in parameters for x in parent_parameters):
combined_parameters = parent_parameters + tuple(x for x in parameters if x not in parent_parameters)
parameters_str = ', '.join([str(x) for x in combined_parameters])
generic_type_label = f'typing.Generic[{parameters_str}]'
error_message = (
f'All parameters must be present on typing.Generic;'
f' you should inherit from {generic_type_label}.'
)
if Generic not in bases: # pragma: no cover
# We raise an error here not because it is desirable, but because some cases are mishandled.
# It would be nice to remove this error and still have things behave as expected, it's just
# challenging because we are using a custom `__class_getitem__` to parametrize generic models,
# and not returning a typing._GenericAlias from it.
bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label])
error_message += (
f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)'
)
raise TypeError(error_message)
cls.__pydantic_generic_metadata__ = {
'origin': None,
'args': (),
'parameters': parameters,
}
cls.__pydantic_complete__ = False # Ensure this specific class gets completed
# preserve `__set_name__` protocol defined in https://peps.python.org/pep-0487
# for attributes not in `new_namespace` (e.g. private attributes)
for name, obj in private_attributes.items():
obj.__set_name__(cls, name)
if __pydantic_reset_parent_namespace__:
cls.__pydantic_parent_namespace__ = build_lenient_weakvaluedict(parent_frame_namespace())
parent_namespace = getattr(cls, '__pydantic_parent_namespace__', None)
if isinstance(parent_namespace, dict):
parent_namespace = unpack_lenient_weakvaluedict(parent_namespace)
types_namespace = get_cls_types_namespace(cls, parent_namespace)
set_model_fields(cls, bases, config_wrapper, types_namespace)
if config_wrapper.frozen and '__hash__' not in namespace:
set_default_hash_func(cls, bases)
complete_model_class(
cls,
cls_name,
config_wrapper,
raise_errors=False,
types_namespace=types_namespace,
create_model_module=_create_model_module,
)
# If this is placed before the complete_model_class call above,
# the generic computed fields return type is set to PydanticUndefined
cls.model_computed_fields = {k: v.info for k, v in cls.__pydantic_decorators__.computed_fields.items()}
# using super(cls, cls) on the next line ensures we only call the parent class's __pydantic_init_subclass__
# I believe the `type: ignore` is only necessary because mypy doesn't realize that this code branch is
# only hit for _proper_ subclasses of BaseModel
super(cls, cls).__pydantic_init_subclass__(**kwargs) # type: ignore[misc]
return cls
else:
# this is the BaseModel class itself being created, no logic required
return super().__new__(mcs, cls_name, bases, namespace, **kwargs)
if not typing.TYPE_CHECKING: # pragma: no branch
# We put `__getattr__` in a non-TYPE_CHECKING block because otherwise, mypy allows arbitrary attribute access
def __getattr__(self, item: str) -> Any:
"""This is necessary to keep attribute access working for class attribute access."""
private_attributes = self.__dict__.get('__private_attributes__')
if private_attributes and item in private_attributes:
return private_attributes[item]
if item == '__pydantic_core_schema__':
# This means the class didn't get a schema generated for it, likely because there was an undefined reference
maybe_mock_validator = getattr(self, '__pydantic_validator__', None)
if isinstance(maybe_mock_validator, MockValSer):
rebuilt_validator = maybe_mock_validator.rebuild()
if rebuilt_validator is not None:
# In this case, a validator was built, and so `__pydantic_core_schema__` should now be set
return getattr(self, '__pydantic_core_schema__')
raise AttributeError(item)
@classmethod
def __prepare__(cls, *args: Any, **kwargs: Any) -> dict[str, object]:
return _ModelNamespaceDict()
def __instancecheck__(self, instance: Any) -> bool:
"""Avoid calling ABC _abc_subclasscheck unless we're pretty sure.
See #3829 and python/cpython#92810
"""
return hasattr(instance, '__pydantic_validator__') and super().__instancecheck__(instance)
@staticmethod
def _collect_bases_data(bases: tuple[type[Any], ...]) -> tuple[set[str], set[str], dict[str, ModelPrivateAttr]]:
from ..main import BaseModel
field_names: set[str] = set()
class_vars: set[str] = set()
private_attributes: dict[str, ModelPrivateAttr] = {}
for base in bases:
if issubclass(base, BaseModel) and base is not BaseModel:
# model_fields might not be defined yet in the case of generics, so we use getattr here:
field_names.update(getattr(base, 'model_fields', {}).keys())
class_vars.update(base.__class_vars__)
private_attributes.update(base.__private_attributes__)
return field_names, class_vars, private_attributes
@property
@deprecated('The `__fields__` attribute is deprecated, use `model_fields` instead.', category=None)
def __fields__(self) -> dict[str, FieldInfo]:
warnings.warn(
'The `__fields__` attribute is deprecated, use `model_fields` instead.', PydanticDeprecatedSince20
)
return self.model_fields # type: ignore
def __dir__(self) -> list[str]:
attributes = list(super().__dir__())
if '__fields__' in attributes:
attributes.remove('__fields__')
return attributes
def init_private_attributes(self: BaseModel, __context: Any) -> None:
"""This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Args:
self: The BaseModel instance.
__context: The context.
"""
if getattr(self, '__pydantic_private__', None) is None:
pydantic_private = {}
for name, private_attr in self.__private_attributes__.items():
default = private_attr.get_default()
if default is not PydanticUndefined:
pydantic_private[name] = default
object_setattr(self, '__pydantic_private__', pydantic_private)
def get_model_post_init(namespace: dict[str, Any], bases: tuple[type[Any], ...]) -> Callable[..., Any] | None:
"""Get the `model_post_init` method from the namespace or the class bases, or `None` if not defined."""
if 'model_post_init' in namespace:
return namespace['model_post_init']
from ..main import BaseModel
model_post_init = get_attribute_from_bases(bases, 'model_post_init')
if model_post_init is not BaseModel.model_post_init:
return model_post_init
def inspect_namespace(  # noqa: C901
namespace: dict[str, Any],
ignored_types: tuple[type[Any], ...],
base_class_vars: set[str],
base_class_fields: set[str],
) -> dict[str, ModelPrivateAttr]:
"""Iterate over the namespace and:
* gather private attributes
* check for items which look like fields but are not (e.g. have no annotation) and warn.
Args:
namespace: The attribute dictionary of the class to be created.
ignored_types: A tuple of types to ignore.
base_class_vars: A set of base class class variables.
base_class_fields: A set of base class fields.
Returns:
A dict containing private attribute info.
Raises:
TypeError: If there is a `__root__` field in model.
NameError: If private attribute name is invalid.
PydanticUserError:
- If a field does not have a type annotation.
- If a field on base class was overridden by a non-annotated attribute.
"""
from ..fields import FieldInfo, ModelPrivateAttr, PrivateAttr
all_ignored_types = ignored_types + default_ignored_types()
private_attributes: dict[str, ModelPrivateAttr] = {}
raw_annotations = namespace.get('__annotations__', {})
if '__root__' in raw_annotations or '__root__' in namespace:
raise TypeError("To define root models, use `pydantic.RootModel` rather than a field called '__root__'")
ignored_names: set[str] = set()
for var_name, value in list(namespace.items()):
if var_name == 'model_config':
continue
elif (
isinstance(value, type)
and value.__module__ == namespace['__module__']
and value.__qualname__.startswith(namespace['__qualname__'])
):
# `value` is a nested type defined in this namespace; don't error
continue
elif isinstance(value, all_ignored_types) or value.__class__.__module__ == 'functools':
ignored_names.add(var_name)
continue
elif isinstance(value, ModelPrivateAttr):
if var_name.startswith('__'):
raise NameError(
'Private attributes must not use dunder names;'
f' use a single underscore prefix instead of {var_name!r}.'
)
elif is_valid_field_name(var_name):
raise NameError(
'Private attributes must not use valid field names;'
f' use sunder names, e.g. {"_" + var_name!r} instead of {var_name!r}.'
)
private_attributes[var_name] = value
del namespace[var_name]
elif isinstance(value, FieldInfo) and not is_valid_field_name(var_name):
suggested_name = var_name.lstrip('_') or 'my_field' # don't suggest '' for all-underscore name
raise NameError(
f'Fields must not use names with leading underscores;'
f' e.g., use {suggested_name!r} instead of {var_name!r}.'
)
elif var_name.startswith('__'):
continue
elif is_valid_privateattr_name(var_name):
if var_name not in raw_annotations or not is_classvar(raw_annotations[var_name]):
private_attributes[var_name] = PrivateAttr(default=value)
del namespace[var_name]
elif var_name in base_class_vars:
continue
elif var_name not in raw_annotations:
if var_name in base_class_fields:
raise PydanticUserError(
f'Field {var_name!r} defined on a base class was overridden by a non-annotated attribute. '
f'All field definitions, including overrides, require a type annotation.',
code='model-field-overridden',
)
elif isinstance(value, FieldInfo):
raise PydanticUserError(
f'Field {var_name!r} requires a type annotation', code='model-field-missing-annotation'
)
else:
raise PydanticUserError(
f'A non-annotated attribute was detected: `{var_name} = {value!r}`. All model fields require a '
f'type annotation; if `{var_name}` is not meant to be a field, you may be able to resolve this '
f"error by annotating it as a `ClassVar` or updating `model_config['ignored_types']`.",
code='model-field-missing-annotation',
)
for ann_name, ann_type in raw_annotations.items():
if (
is_valid_privateattr_name(ann_name)
and ann_name not in private_attributes
and ann_name not in ignored_names
and not is_classvar(ann_type)
and ann_type not in all_ignored_types
and getattr(ann_type, '__module__', None) != 'functools'
):
if is_annotated(ann_type):
_, *metadata = typing_extensions.get_args(ann_type)
private_attr = next((v for v in metadata if isinstance(v, ModelPrivateAttr)), None)
if private_attr is not None:
private_attributes[ann_name] = private_attr
continue
private_attributes[ann_name] = PrivateAttr()
return private_attributes
def set_default_hash_func(cls: type[BaseModel], bases: tuple[type[Any], ...]) -> None:
base_hash_func = get_attribute_from_bases(bases, '__hash__')
new_hash_func = make_hash_func(cls)
if base_hash_func in {None, object.__hash__} or getattr(base_hash_func, '__code__', None) == new_hash_func.__code__:
# If `__hash__` is some default, we generate a hash function.
# It will be `None` if not overridden from BaseModel.
# It may be `object.__hash__` if there is another
# parent class earlier in the bases which doesn't override `__hash__` (e.g. `typing.Generic`).
# It may be a value set by `set_default_hash_func` if `cls` is a subclass of another frozen model.
# In the last case we still need a new hash function to account for new `model_fields`.
cls.__hash__ = new_hash_func
def make_hash_func(cls: type[BaseModel]) -> Any:
getter = operator.itemgetter(*cls.model_fields.keys()) if cls.model_fields else lambda _: 0
def hash_func(self: Any) -> int:
try:
return hash(getter(self.__dict__))
except KeyError:
# In rare cases (such as when using the deprecated copy method), the __dict__ may not contain
# all model fields, which is how we can get here.
# getter(self.__dict__) is much faster than any 'safe' method that accounts for missing keys,
# and wrapping it in a `try` doesn't slow things down much in the common case.
return hash(getter(SafeGetItemProxy(self.__dict__)))
return hash_func
def set_model_fields(
cls: type[BaseModel], bases: tuple[type[Any], ...], config_wrapper: ConfigWrapper, types_namespace: dict[str, Any]
) -> None:
"""Collect and set `cls.model_fields` and `cls.__class_vars__`.
Args:
cls: BaseModel or dataclass.
bases: Parents of the class, generally `cls.__bases__`.
config_wrapper: The config wrapper instance.
types_namespace: Optional extra namespace to look for types in.
"""
typevars_map = get_model_typevars_map(cls)
fields, class_vars = collect_model_fields(cls, bases, config_wrapper, types_namespace, typevars_map=typevars_map)
cls.model_fields = fields
cls.__class_vars__.update(class_vars)
for k in class_vars:
# Class vars should not be private attributes
# We remove them _here_ and not earlier because we rely on inspecting the class to determine its classvars,
# but private attributes are determined by inspecting the namespace _prior_ to class creation.
# In the case that a classvar with a leading-'_' is defined via a ForwardRef (e.g., when using
# `__future__.annotations`), we want to remove the private attribute which was detected _before_ we knew it
# evaluated to a classvar
value = cls.__private_attributes__.pop(k, None)
if value is not None and value.default is not PydanticUndefined:
setattr(cls, k, value.default)
def complete_model_class(
cls: type[BaseModel],
cls_name: str,
config_wrapper: ConfigWrapper,
*,
raise_errors: bool = True,
types_namespace: dict[str, Any] | None,
create_model_module: str | None = None,
) -> bool:
"""Finish building a model class.
This logic must be called after class has been created since validation functions must be bound
and `get_type_hints` requires a class object.
Args:
cls: BaseModel or dataclass.
cls_name: The model or dataclass name.
config_wrapper: The config wrapper instance.
raise_errors: Whether to raise errors.
types_namespace: Optional extra namespace to look for types in.
create_model_module: The module of the class to be created, if created by `create_model`.
Returns:
`True` if the model is successfully completed, else `False`.
Raises:
PydanticUndefinedAnnotation: If `PydanticUndefinedAnnotation` occurs in `__get_pydantic_core_schema__`
and `raise_errors=True`.
"""
typevars_map = get_model_typevars_map(cls)
gen_schema = GenerateSchema(
config_wrapper,
types_namespace,
typevars_map,
)
handler = CallbackGetCoreSchemaHandler(
partial(gen_schema.generate_schema, from_dunder_get_core_schema=False),
gen_schema,
ref_mode='unpack',
)
if config_wrapper.defer_build:
set_model_mocks(cls, cls_name)
return False
try:
schema = cls.__get_pydantic_core_schema__(cls, handler)
except PydanticUndefinedAnnotation as e:
if raise_errors:
raise
set_model_mocks(cls, cls_name, f'`{e.name}`')
return False
core_config = config_wrapper.core_config(cls)
try:
schema = gen_schema.clean_schema(schema)
except gen_schema.CollectedInvalid:
set_model_mocks(cls, cls_name)
return False
# debug(schema)
cls.__pydantic_core_schema__ = schema
cls.__pydantic_validator__ = create_schema_validator(
schema,
cls,
create_model_module or cls.__module__,
cls.__qualname__,
'create_model' if create_model_module else 'BaseModel',
core_config,
config_wrapper.plugin_settings,
)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
cls.__pydantic_complete__ = True
# set __signature__ attr only for model class, but not for its instances
cls.__signature__ = ClassAttribute(
'__signature__',
generate_pydantic_signature(init=cls.__init__, fields=cls.model_fields, config_wrapper=config_wrapper),
)
return True
class _PydanticWeakRef:
"""Wrapper for `weakref.ref` that enables `pickle` serialization.
Cloudpickle fails to serialize `weakref.ref` objects due to an arcane error related
to abstract base classes (`abc.ABC`). This class works around the issue by wrapping
`weakref.ref` instead of subclassing it.
See https://github.com/pydantic/pydantic/issues/6763 for context.
Semantics:
- If not pickled, behaves the same as a `weakref.ref`.
- If pickled along with the referenced object, the same `weakref.ref` behavior
will be maintained between them after unpickling.
- If pickled without the referenced object, after unpickling the underlying
reference will be cleared (`__call__` will always return `None`).
"""
def __init__(self, obj: Any):
if obj is None:
# The object will be `None` upon deserialization if the serialized weakref
# had lost its underlying object.
self._wr = None
else:
self._wr = weakref.ref(obj)
def __call__(self) -> Any:
if self._wr is None:
return None
else:
return self._wr()
def __reduce__(self) -> tuple[Callable, tuple[weakref.ReferenceType | None]]:
return _PydanticWeakRef, (self(),)
def build_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | None:
"""Takes an input dictionary, and produces a new value that (invertibly) replaces the values with weakrefs.
We can't just use a WeakValueDictionary because many types (including int, str, etc.) can't be stored as values
in a WeakValueDictionary.
The `unpack_lenient_weakvaluedict` function can be used to reverse this operation.
"""
if d is None:
return None
result = {}
for k, v in d.items():
try:
proxy = _PydanticWeakRef(v)
except TypeError:
proxy = v
result[k] = proxy
return result
def unpack_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | None:
"""Inverts the transform performed by `build_lenient_weakvaluedict`."""
if d is None:
return None
result = {}
for k, v in d.items():
if isinstance(v, _PydanticWeakRef):
v = v()
if v is not None:
result[k] = v
else:
result[k] = v
return result
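# Round-trip sketch (illustrative): values that support weak references (classes, functions, ...)
# are wrapped in _PydanticWeakRef by build_lenient_weakvaluedict and kept as-is otherwise;
# unpack_lenient_weakvaluedict reverses this, dropping entries whose referent has been
# garbage collected in the meantime.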
def default_ignored_types() -> tuple[type[Any], ...]:
from ..fields import ComputedFieldInfo
return (
FunctionType,
property,
classmethod,
staticmethod,
PydanticDescriptorProxy,
ComputedFieldInfo,
ValidateCallWrapper,
)

View file

@@ -0,0 +1,117 @@
"""Tools to provide pretty/human-readable display of objects."""
from __future__ import annotations as _annotations
import types
import typing
from typing import Any
import typing_extensions
from . import _typing_extra
if typing.TYPE_CHECKING:
ReprArgs: typing_extensions.TypeAlias = 'typing.Iterable[tuple[str | None, Any]]'
RichReprResult: typing_extensions.TypeAlias = (
'typing.Iterable[Any | tuple[Any] | tuple[str, Any] | tuple[str, Any, Any]]'
)
class PlainRepr(str):
"""String class where repr doesn't include quotes. Useful with Representation when you want to return a string
representation of something that is valid (or pseudo-valid) python.
"""
def __repr__(self) -> str:
return str(self)
class Representation:
# Mixin to provide `__str__`, `__repr__`, and `__pretty__` and `__rich_repr__` methods.
# `__pretty__` is used by [devtools](https://python-devtools.helpmanual.io/).
# `__rich_repr__` is used by [rich](https://rich.readthedocs.io/en/stable/pretty.html).
# (this is not a docstring to avoid adding a docstring to classes which inherit from Representation)
# we don't want to use a type annotation here as it can break get_type_hints
__slots__ = tuple() # type: typing.Collection[str]
def __repr_args__(self) -> ReprArgs:
"""Returns the attributes to show in __str__, __repr__, and __pretty__ this is generally overridden.
Can either return:
* name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]`
* or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]`
"""
attrs_names = self.__slots__
if not attrs_names and hasattr(self, '__dict__'):
attrs_names = self.__dict__.keys()
attrs = ((s, getattr(self, s)) for s in attrs_names)
return [(a, v) for a, v in attrs if v is not None]
def __repr_name__(self) -> str:
"""Name of the instance's class, used in __repr__."""
return self.__class__.__name__
def __repr_str__(self, join_str: str) -> str:
return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())
def __pretty__(self, fmt: typing.Callable[[Any], Any], **kwargs: Any) -> typing.Generator[Any, None, None]:
"""Used by devtools (https://python-devtools.helpmanual.io/) to pretty print objects."""
yield self.__repr_name__() + '('
yield 1
for name, value in self.__repr_args__():
if name is not None:
yield name + '='
yield fmt(value)
yield ','
yield 0
yield -1
yield ')'
def __rich_repr__(self) -> RichReprResult:
"""Used by Rich (https://rich.readthedocs.io/en/stable/pretty.html) to pretty print objects."""
for name, field_repr in self.__repr_args__():
if name is None:
yield field_repr
else:
yield name, field_repr
def __str__(self) -> str:
return self.__repr_str__(' ')
def __repr__(self) -> str:
return f'{self.__repr_name__()}({self.__repr_str__(", ")})'
def display_as_type(obj: Any) -> str:
"""Pretty representation of a type, should be as close as possible to the original type definition string.
Takes some logic from `typing._type_repr`.
"""
if isinstance(obj, types.FunctionType):
return obj.__name__
elif obj is ...:
return '...'
elif isinstance(obj, Representation):
return repr(obj)
elif isinstance(obj, typing_extensions.TypeAliasType):
return str(obj)
if not isinstance(obj, (_typing_extra.typing_base, _typing_extra.WithArgsTypes, type)):
obj = obj.__class__
if _typing_extra.origin_is_union(typing_extensions.get_origin(obj)):
args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
return f'Union[{args}]'
elif isinstance(obj, _typing_extra.WithArgsTypes):
if typing_extensions.get_origin(obj) == typing_extensions.Literal:
args = ', '.join(map(repr, typing_extensions.get_args(obj)))
else:
args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
try:
return f'{obj.__qualname__}[{args}]'
except AttributeError:
return str(obj) # handles TypeAliasType in 3.12
elif isinstance(obj, type):
return obj.__qualname__
else:
return repr(obj).replace('typing.', '').replace('typing_extensions.', '')
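# --- Illustrative sketch (not part of the module): how Representation and display_as_type behave.
# `Point` is a hypothetical class; the import path assumes the standard pydantic layout, and
# `list[int]` needs Python 3.9+.
from pydantic._internal._repr import Representation, display_as_type

class Point(Representation):
    __slots__ = ('x', 'y')

    def __init__(self, x: int, y: int) -> None:
        self.x = x
        self.y = y

print(repr(Point(1, 2)))           # Point(x=1, y=2)
print(str(Point(1, 2)))            # x=1 y=2
print(display_as_type(list[int]))  # roughly 'list[int]' (exact output may vary by Python version)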

View file

@@ -0,0 +1,124 @@
"""Types and utility functions used by various other internal tools."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable
from pydantic_core import core_schema
from typing_extensions import Literal
from ..annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
if TYPE_CHECKING:
from ..json_schema import GenerateJsonSchema, JsonSchemaValue
from ._core_utils import CoreSchemaOrField
from ._generate_schema import GenerateSchema
GetJsonSchemaFunction = Callable[[CoreSchemaOrField, GetJsonSchemaHandler], JsonSchemaValue]
HandlerOverride = Callable[[CoreSchemaOrField], JsonSchemaValue]
class GenerateJsonSchemaHandler(GetJsonSchemaHandler):
"""JsonSchemaHandler implementation that doesn't do ref unwrapping by default.
This is used for any Annotated metadata so that we don't end up with conflicting
modifications to the definition schema.
Used internally by Pydantic, please do not rely on this implementation.
See `GetJsonSchemaHandler` for the handler API.
"""
def __init__(self, generate_json_schema: GenerateJsonSchema, handler_override: HandlerOverride | None) -> None:
self.generate_json_schema = generate_json_schema
self.handler = handler_override or generate_json_schema.generate_inner
self.mode = generate_json_schema.mode
def __call__(self, __core_schema: CoreSchemaOrField) -> JsonSchemaValue:
return self.handler(__core_schema)
def resolve_ref_schema(self, maybe_ref_json_schema: JsonSchemaValue) -> JsonSchemaValue:
"""Resolves `$ref` in the json schema.
This returns the input json schema if there is no `$ref` in json schema.
Args:
maybe_ref_json_schema: The input json schema that may contain `$ref`.
Returns:
Resolved json schema.
Raises:
LookupError: If it can't find the definition for `$ref`.
"""
if '$ref' not in maybe_ref_json_schema:
return maybe_ref_json_schema
ref = maybe_ref_json_schema['$ref']
json_schema = self.generate_json_schema.get_schema_from_definitions(ref)
if json_schema is None:
raise LookupError(
f'Could not find a ref for {ref}.'
' Maybe you tried to call resolve_ref_schema from within a recursive model?'
)
return json_schema
class CallbackGetCoreSchemaHandler(GetCoreSchemaHandler):
"""Wrapper to use an arbitrary function as a `GetCoreSchemaHandler`.
Used internally by Pydantic, please do not rely on this implementation.
See `GetCoreSchemaHandler` for the handler API.
"""
def __init__(
self,
handler: Callable[[Any], core_schema.CoreSchema],
generate_schema: GenerateSchema,
ref_mode: Literal['to-def', 'unpack'] = 'to-def',
) -> None:
self._handler = handler
self._generate_schema = generate_schema
self._ref_mode = ref_mode
def __call__(self, __source_type: Any) -> core_schema.CoreSchema:
schema = self._handler(__source_type)
ref = schema.get('ref')
if self._ref_mode == 'to-def':
if ref is not None:
self._generate_schema.defs.definitions[ref] = schema
return core_schema.definition_reference_schema(ref)
return schema
else:  # ref_mode == 'unpack'
return self.resolve_ref_schema(schema)
def _get_types_namespace(self) -> dict[str, Any] | None:
return self._generate_schema._types_namespace
def generate_schema(self, __source_type: Any) -> core_schema.CoreSchema:
return self._generate_schema.generate_schema(__source_type)
@property
def field_name(self) -> str | None:
return self._generate_schema.field_name_stack.get()
def resolve_ref_schema(self, maybe_ref_schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
"""Resolves reference in the core schema.
Args:
maybe_ref_schema: The input core schema that may contain a reference.
Returns:
Resolved core schema.
Raises:
LookupError: If it can't find the definition for reference.
"""
if maybe_ref_schema['type'] == 'definition-ref':
ref = maybe_ref_schema['schema_ref']
if ref not in self._generate_schema.defs.definitions:
raise LookupError(
f'Could not find a ref for {ref}.'
' Maybe you tried to call resolve_ref_schema from within a recursive model?'
)
return self._generate_schema.defs.definitions[ref]
elif maybe_ref_schema['type'] == 'definitions':
return self.resolve_ref_schema(maybe_ref_schema['schema'])
return maybe_ref_schema
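# --- Minimal standalone sketch of the ref-resolution logic above. A plain dict stands in for
# `GenerateSchema.defs.definitions`; this is illustrative only, not the real handler wiring.
from pydantic_core import core_schema as _cs

_definitions = {'MyModel:123': _cs.int_schema()}  # hypothetical ref -> schema store
_ref_schema = _cs.definition_reference_schema('MyModel:123')

def _resolve(maybe_ref, defs):
    # mirrors CallbackGetCoreSchemaHandler.resolve_ref_schema: unwrap 'definition-ref'
    # (and 'definitions') schemas, raising LookupError for unknown refs
    if maybe_ref['type'] == 'definition-ref':
        ref = maybe_ref['schema_ref']
        if ref not in defs:
            raise LookupError(f'Could not find a ref for {ref}.')
        return defs[ref]
    if maybe_ref['type'] == 'definitions':
        return _resolve(maybe_ref['schema'], defs)
    return maybe_ref

assert _resolve(_ref_schema, _definitions)['type'] == 'int'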

View file

@@ -0,0 +1,164 @@
from __future__ import annotations
import dataclasses
from inspect import Parameter, Signature, signature
from typing import TYPE_CHECKING, Any, Callable
from pydantic_core import PydanticUndefined
from ._config import ConfigWrapper
from ._utils import is_valid_identifier
if TYPE_CHECKING:
from ..fields import FieldInfo
def _field_name_for_signature(field_name: str, field_info: FieldInfo) -> str:
"""Extract the correct name to use for the field when generating a signature.
Assuming the field has a valid alias, this will return the alias. Otherwise, it will return the field name.
First priority is given to the alias, then the validation_alias, then the field name.
Args:
field_name: The name of the field
field_info: The corresponding FieldInfo object.
Returns:
The correct name to use when generating a signature.
"""
def _alias_if_valid(x: Any) -> str | None:
"""Return the alias if it is a valid alias and identifier, else None."""
return x if isinstance(x, str) and is_valid_identifier(x) else None
return _alias_if_valid(field_info.alias) or _alias_if_valid(field_info.validation_alias) or field_name
def _process_param_defaults(param: Parameter) -> Parameter:
"""Modify the signature for a parameter in a dataclass where the default value is a FieldInfo instance.
Args:
param (Parameter): The parameter
Returns:
Parameter: The custom processed parameter
"""
from ..fields import FieldInfo
param_default = param.default
if isinstance(param_default, FieldInfo):
annotation = param.annotation
# Replace the annotation if appropriate
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if annotation == 'Any':
annotation = Any
# Replace the field default
default = param_default.default
if default is PydanticUndefined:
if param_default.default_factory is PydanticUndefined:
default = Signature.empty
else:
# this is used by dataclasses to indicate a factory exists:
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore
return param.replace(
annotation=annotation, name=_field_name_for_signature(param.name, param_default), default=default
)
return param
def _generate_signature_parameters( # noqa: C901 (ignore complexity, could use a refactor)
init: Callable[..., None],
fields: dict[str, FieldInfo],
config_wrapper: ConfigWrapper,
) -> dict[str, Parameter]:
"""Generate a mapping of parameter names to Parameter objects for a pydantic BaseModel or dataclass."""
from itertools import islice
present_params = signature(init).parameters.values()
merged_params: dict[str, Parameter] = {}
var_kw = None
use_var_kw = False
for param in islice(present_params, 1, None): # skip self arg
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main, we don't want that
if fields.get(param.name):
# exclude params with init=False
if getattr(fields[param.name], 'init', True) is False:
continue
param = param.replace(name=_field_name_for_signature(param.name, fields[param.name]))
if param.annotation == 'Any':
param = param.replace(annotation=Any)
if param.kind is param.VAR_KEYWORD:
var_kw = param
continue
merged_params[param.name] = param
if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
allow_names = config_wrapper.populate_by_name
for field_name, field in fields.items():
# when alias is a str it should be used for signature generation
param_name = _field_name_for_signature(field_name, field)
if field_name in merged_params or param_name in merged_params:
continue
if not is_valid_identifier(param_name):
if allow_names:
param_name = field_name
else:
use_var_kw = True
continue
kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)}
merged_params[param_name] = Parameter(
param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs
)
if config_wrapper.extra == 'allow':
use_var_kw = True
if var_kw and use_var_kw:
# Make sure the parameter for extra kwargs
# does not have the same name as a field
default_model_signature = [
('self', Parameter.POSITIONAL_ONLY),
('data', Parameter.VAR_KEYWORD),
]
if [(p.name, p.kind) for p in present_params] == default_model_signature:
# if this is the standard model signature, use extra_data as the extra args name
var_kw_name = 'extra_data'
else:
# else start from var_kw
var_kw_name = var_kw.name
# generate a name that's definitely unique
while var_kw_name in fields:
var_kw_name += '_'
merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)
return merged_params
def generate_pydantic_signature(
init: Callable[..., None], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper, is_dataclass: bool = False
) -> Signature:
"""Generate signature for a pydantic BaseModel or dataclass.
Args:
init: The class init.
fields: The model fields.
config_wrapper: The config wrapper instance.
is_dataclass: Whether the model is a dataclass.
Returns:
The dataclass/BaseModel subclass signature.
"""
merged_params = _generate_signature_parameters(init, fields, config_wrapper)
if is_dataclass:
merged_params = {k: _process_param_defaults(v) for k, v in merged_params.items()}
return Signature(parameters=list(merged_params.values()), return_annotation=None)
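# --- Quick check of the effect (hedged; the exact rendering depends on the pydantic version):
# the generated signature uses a field's alias when it is a valid identifier. `User` is a
# hypothetical model used only for illustration.
import inspect
from pydantic import BaseModel, Field

class User(BaseModel):
    user_name: str = Field(alias='userName')
    age: int = 0

# Expected to look roughly like: (*, userName: str, age: int = 0) -> None
print(inspect.signature(User))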

View file

@@ -0,0 +1,714 @@
"""Logic for generating pydantic-core schemas for standard library types.
Import of this module is deferred since it contains imports of many standard library modules.
"""
from __future__ import annotations as _annotations
import collections
import collections.abc
import dataclasses
import decimal
import inspect
import os
import typing
from enum import Enum
from functools import partial
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from typing import Any, Callable, Iterable, TypeVar
import typing_extensions
from pydantic_core import (
CoreSchema,
MultiHostUrl,
PydanticCustomError,
PydanticOmit,
Url,
core_schema,
)
from typing_extensions import get_args, get_origin
from pydantic.errors import PydanticSchemaGenerationError
from pydantic.fields import FieldInfo
from pydantic.types import Strict
from ..config import ConfigDict
from ..json_schema import JsonSchemaValue, update_json_schema
from . import _known_annotated_metadata, _typing_extra, _validators
from ._core_utils import get_type_ref
from ._internal_dataclass import slots_true
from ._schema_generation_shared import GetCoreSchemaHandler, GetJsonSchemaHandler
if typing.TYPE_CHECKING:
from ._generate_schema import GenerateSchema
StdSchemaFunction = Callable[[GenerateSchema, type[Any]], core_schema.CoreSchema]
@dataclasses.dataclass(**slots_true)
class SchemaTransformer:
get_core_schema: Callable[[Any, GetCoreSchemaHandler], CoreSchema]
get_json_schema: Callable[[CoreSchema, GetJsonSchemaHandler], JsonSchemaValue]
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
return self.get_core_schema(source_type, handler)
def __get_pydantic_json_schema__(self, schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
return self.get_json_schema(schema, handler)
def get_enum_core_schema(enum_type: type[Enum], config: ConfigDict) -> CoreSchema:
cases: list[Any] = list(enum_type.__members__.values())
enum_ref = get_type_ref(enum_type)
description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__)
if description == 'An enumeration.': # This is the default value provided by enum.EnumMeta.__new__; don't use it
description = None
updates = {'title': enum_type.__name__, 'description': description}
updates = {k: v for k, v in updates.items() if v is not None}
def get_json_schema(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
json_schema = handler(core_schema.literal_schema([x.value for x in cases], ref=enum_ref))
original_schema = handler.resolve_ref_schema(json_schema)
update_json_schema(original_schema, updates)
return json_schema
if not cases:
# Use an isinstance check for enums with no cases.
# The most important use case for this is creating TypeVar bounds for generics that should
# be restricted to enums. This is more consistent than it might seem at first, since you can only
# subclass enum.Enum (or subclasses of enum.Enum) if all parent classes have no cases.
# We use the get_json_schema function when an Enum subclass has been declared with no cases
# so that we can still generate a valid json schema.
return core_schema.is_instance_schema(enum_type, metadata={'pydantic_js_functions': [get_json_schema]})
use_enum_values = config.get('use_enum_values', False)
if len(cases) == 1:
expected = repr(cases[0].value)
else:
expected = ', '.join([repr(case.value) for case in cases[:-1]]) + f' or {cases[-1].value!r}'
def to_enum(__input_value: Any) -> Enum:
try:
enum_field = enum_type(__input_value)
if use_enum_values:
return enum_field.value
return enum_field
except ValueError:
# The type: ignore on the next line is to ignore the requirement of LiteralString
raise PydanticCustomError('enum', f'Input should be {expected}', {'expected': expected}) # type: ignore
strict_python_schema = core_schema.is_instance_schema(enum_type)
if use_enum_values:
strict_python_schema = core_schema.chain_schema(
[strict_python_schema, core_schema.no_info_plain_validator_function(lambda x: x.value)]
)
to_enum_validator = core_schema.no_info_plain_validator_function(to_enum)
if issubclass(enum_type, int):
# this handles `IntEnum`, and also `Foobar(int, Enum)`
updates['type'] = 'integer'
lax = core_schema.chain_schema([core_schema.int_schema(), to_enum_validator])
# Disallow float from JSON due to strict mode
strict = core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.int_schema()),
python_schema=strict_python_schema,
)
elif issubclass(enum_type, str):
# this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)`
updates['type'] = 'string'
lax = core_schema.chain_schema([core_schema.str_schema(), to_enum_validator])
strict = core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.str_schema()),
python_schema=strict_python_schema,
)
elif issubclass(enum_type, float):
updates['type'] = 'numeric'
lax = core_schema.chain_schema([core_schema.float_schema(), to_enum_validator])
strict = core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.float_schema()),
python_schema=strict_python_schema,
)
else:
lax = to_enum_validator
strict = core_schema.json_or_python_schema(json_schema=to_enum_validator, python_schema=strict_python_schema)
return core_schema.lax_or_strict_schema(
lax_schema=lax, strict_schema=strict, ref=enum_ref, metadata={'pydantic_js_functions': [get_json_schema]}
)
@dataclasses.dataclass(**slots_true)
class InnerSchemaValidator:
"""Use a fixed CoreSchema, avoiding interference from outward annotations."""
core_schema: CoreSchema
js_schema: JsonSchemaValue | None = None
js_core_schema: CoreSchema | None = None
js_schema_update: JsonSchemaValue | None = None
def __get_pydantic_json_schema__(self, _schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
if self.js_schema is not None:
return self.js_schema
js_schema = handler(self.js_core_schema or self.core_schema)
if self.js_schema_update is not None:
js_schema.update(self.js_schema_update)
return js_schema
def __get_pydantic_core_schema__(self, _source_type: Any, _handler: GetCoreSchemaHandler) -> CoreSchema:
return self.core_schema
def decimal_prepare_pydantic_annotations(
source: Any, annotations: Iterable[Any], config: ConfigDict
) -> tuple[Any, list[Any]] | None:
if source is not decimal.Decimal:
return None
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
config_allow_inf_nan = config.get('allow_inf_nan')
if config_allow_inf_nan is not None:
metadata.setdefault('allow_inf_nan', config_allow_inf_nan)
_known_annotated_metadata.check_metadata(
metadata, {*_known_annotated_metadata.FLOAT_CONSTRAINTS, 'max_digits', 'decimal_places'}, decimal.Decimal
)
return source, [InnerSchemaValidator(core_schema.decimal_schema(**metadata)), *remaining_annotations]
def datetime_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
import datetime
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
if source_type is datetime.date:
sv = InnerSchemaValidator(core_schema.date_schema(**metadata))
elif source_type is datetime.datetime:
sv = InnerSchemaValidator(core_schema.datetime_schema(**metadata))
elif source_type is datetime.time:
sv = InnerSchemaValidator(core_schema.time_schema(**metadata))
elif source_type is datetime.timedelta:
sv = InnerSchemaValidator(core_schema.timedelta_schema(**metadata))
else:
return None
# check now that we know the source type is correct
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.DATE_TIME_CONSTRAINTS, source_type)
return (source_type, [sv, *remaining_annotations])
def uuid_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
# UUIDs have no constraints - they are fixed length, constructing a UUID instance checks the length
from uuid import UUID
if source_type is not UUID:
return None
return (source_type, [InnerSchemaValidator(core_schema.uuid_schema()), *annotations])
def path_schema_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
import pathlib
if source_type not in {
os.PathLike,
pathlib.Path,
pathlib.PurePath,
pathlib.PosixPath,
pathlib.PurePosixPath,
pathlib.PureWindowsPath,
}:
return None
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.STR_CONSTRAINTS, source_type)
construct_path = pathlib.PurePath if source_type is os.PathLike else source_type
def path_validator(input_value: str) -> os.PathLike[Any]:
try:
return construct_path(input_value)
except TypeError as e:
raise PydanticCustomError('path_type', 'Input is not a valid path') from e
constrained_str_schema = core_schema.str_schema(**metadata)
instance_schema = core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(path_validator, constrained_str_schema),
python_schema=core_schema.is_instance_schema(source_type),
)
strict: bool | None = None
for annotation in annotations:
if isinstance(annotation, Strict):
strict = annotation.strict
schema = core_schema.lax_or_strict_schema(
lax_schema=core_schema.union_schema(
[
instance_schema,
core_schema.no_info_after_validator_function(path_validator, constrained_str_schema),
],
custom_error_type='path_type',
custom_error_message='Input is not a valid path',
strict=True,
),
strict_schema=instance_schema,
serialization=core_schema.to_string_ser_schema(),
strict=strict,
)
return (
source_type,
[
InnerSchemaValidator(schema, js_core_schema=constrained_str_schema, js_schema_update={'format': 'path'}),
*remaining_annotations,
],
)
def dequeue_validator(
input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, maxlen: None | int
) -> collections.deque[Any]:
if isinstance(input_value, collections.deque):
maxlens = [v for v in (input_value.maxlen, maxlen) if v is not None]
if maxlens:
maxlen = min(maxlens)
return collections.deque(handler(input_value), maxlen=maxlen)
else:
return collections.deque(handler(input_value), maxlen=maxlen)
@dataclasses.dataclass(**slots_true)
class SequenceValidator:
mapped_origin: type[Any]
item_source_type: type[Any]
min_length: int | None = None
max_length: int | None = None
strict: bool = False
def serialize_sequence_via_list(
self, v: Any, handler: core_schema.SerializerFunctionWrapHandler, info: core_schema.SerializationInfo
) -> Any:
items: list[Any] = []
for index, item in enumerate(v):
try:
v = handler(item, index)
except PydanticOmit:
pass
else:
items.append(v)
if info.mode_is_json():
return items
else:
return self.mapped_origin(items)
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
if self.item_source_type is Any:
items_schema = None
else:
items_schema = handler.generate_schema(self.item_source_type)
metadata = {'min_length': self.min_length, 'max_length': self.max_length, 'strict': self.strict}
if self.mapped_origin in (list, set, frozenset):
if self.mapped_origin is list:
constrained_schema = core_schema.list_schema(items_schema, **metadata)
elif self.mapped_origin is set:
constrained_schema = core_schema.set_schema(items_schema, **metadata)
else:
assert self.mapped_origin is frozenset # safety check in case we forget to add a case
constrained_schema = core_schema.frozenset_schema(items_schema, **metadata)
schema = constrained_schema
else:
# safety check in case we forget to add a case
assert self.mapped_origin in (collections.deque, collections.Counter)
if self.mapped_origin is collections.deque:
# if we have a MaxLen annotation, we might as well set that as the default maxlen on the deque
# this lets us reuse existing metadata annotations to let users set the maxlen on a deque
# that e.g. comes from JSON
coerce_instance_wrap = partial(
core_schema.no_info_wrap_validator_function,
partial(dequeue_validator, maxlen=metadata.get('max_length', None)),
)
else:
coerce_instance_wrap = partial(core_schema.no_info_after_validator_function, self.mapped_origin)
constrained_schema = core_schema.list_schema(items_schema, **metadata)
check_instance = core_schema.json_or_python_schema(
json_schema=core_schema.list_schema(),
python_schema=core_schema.is_instance_schema(self.mapped_origin),
)
serialization = core_schema.wrap_serializer_function_ser_schema(
self.serialize_sequence_via_list, schema=items_schema or core_schema.any_schema(), info_arg=True
)
strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)])
if metadata.get('strict', False):
schema = strict
else:
lax = coerce_instance_wrap(constrained_schema)
schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict)
schema['serialization'] = serialization
return schema
SEQUENCE_ORIGIN_MAP: dict[Any, Any] = {
typing.Deque: collections.deque,
collections.deque: collections.deque,
list: list,
typing.List: list,
set: set,
typing.AbstractSet: set,
typing.Set: set,
frozenset: frozenset,
typing.FrozenSet: frozenset,
typing.Sequence: list,
typing.MutableSequence: list,
typing.MutableSet: set,
# this doesn't handle subclasses of these
# parametrized typing.Set creates one of these
collections.abc.MutableSet: set,
collections.abc.Set: frozenset,
}
def identity(s: CoreSchema) -> CoreSchema:
return s
def sequence_like_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
origin: Any = get_origin(source_type)
mapped_origin = SEQUENCE_ORIGIN_MAP.get(origin, None) if origin else SEQUENCE_ORIGIN_MAP.get(source_type, None)
if mapped_origin is None:
return None
args = get_args(source_type)
if not args:
args = (Any,)
elif len(args) != 1:
raise ValueError('Expected sequence to have exactly 1 generic parameter')
item_source_type = args[0]
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type)
return (source_type, [SequenceValidator(mapped_origin, item_source_type, **metadata), *remaining_annotations])
MAPPING_ORIGIN_MAP: dict[Any, Any] = {
typing.DefaultDict: collections.defaultdict,
collections.defaultdict: collections.defaultdict,
collections.OrderedDict: collections.OrderedDict,
typing_extensions.OrderedDict: collections.OrderedDict,
dict: dict,
typing.Dict: dict,
collections.Counter: collections.Counter,
typing.Counter: collections.Counter,
# this doesn't handle subclasses of these
typing.Mapping: dict,
typing.MutableMapping: dict,
# parametrized typing.{Mutable}Mapping creates one of these
collections.abc.MutableMapping: dict,
collections.abc.Mapping: dict,
}
def defaultdict_validator(
input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any]
) -> collections.defaultdict[Any, Any]:
if isinstance(input_value, collections.defaultdict):
default_factory = input_value.default_factory
return collections.defaultdict(default_factory, handler(input_value))
else:
return collections.defaultdict(default_default_factory, handler(input_value))
def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]:
def infer_default() -> Callable[[], Any]:
allowed_default_types: dict[Any, Any] = {
typing.Tuple: tuple,
tuple: tuple,
collections.abc.Sequence: tuple,
collections.abc.MutableSequence: list,
typing.List: list,
list: list,
typing.Sequence: list,
typing.Set: set,
set: set,
typing.MutableSet: set,
collections.abc.MutableSet: set,
collections.abc.Set: frozenset,
typing.MutableMapping: dict,
typing.Mapping: dict,
collections.abc.Mapping: dict,
collections.abc.MutableMapping: dict,
float: float,
int: int,
str: str,
bool: bool,
}
values_type_origin = get_origin(values_source_type) or values_source_type
instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`'
if isinstance(values_type_origin, TypeVar):
def type_var_default_factory() -> None:
raise RuntimeError(
'Generic defaultdict cannot be used without a concrete value type or an'
' explicit default factory, ' + instructions
)
return type_var_default_factory
elif values_type_origin not in allowed_default_types:
# a somewhat subjective set of types that have reasonable default values
allowed_msg = ', '.join([t.__name__ for t in set(allowed_default_types.values())])
raise PydanticSchemaGenerationError(
f'Unable to infer a default factory for keys of type {values_source_type}.'
f' Only {allowed_msg} are supported, other types require an explicit default factory'
' ' + instructions
)
return allowed_default_types[values_type_origin]
# Assume Annotated[..., Field(...)]
if _typing_extra.is_annotated(values_source_type):
field_info = next((v for v in get_args(values_source_type) if isinstance(v, FieldInfo)), None)
else:
field_info = None
if field_info and field_info.default_factory:
default_default_factory = field_info.default_factory
else:
default_default_factory = infer_default()
return default_default_factory
@dataclasses.dataclass(**slots_true)
class MappingValidator:
mapped_origin: type[Any]
keys_source_type: type[Any]
values_source_type: type[Any]
min_length: int | None = None
max_length: int | None = None
strict: bool = False
def serialize_mapping_via_dict(self, v: Any, handler: core_schema.SerializerFunctionWrapHandler) -> Any:
return handler(v)
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
if self.keys_source_type is Any:
keys_schema = None
else:
keys_schema = handler.generate_schema(self.keys_source_type)
if self.values_source_type is Any:
values_schema = None
else:
values_schema = handler.generate_schema(self.values_source_type)
metadata = {'min_length': self.min_length, 'max_length': self.max_length, 'strict': self.strict}
if self.mapped_origin is dict:
schema = core_schema.dict_schema(keys_schema, values_schema, **metadata)
else:
constrained_schema = core_schema.dict_schema(keys_schema, values_schema, **metadata)
check_instance = core_schema.json_or_python_schema(
json_schema=core_schema.dict_schema(),
python_schema=core_schema.is_instance_schema(self.mapped_origin),
)
if self.mapped_origin is collections.defaultdict:
default_default_factory = get_defaultdict_default_default_factory(self.values_source_type)
coerce_instance_wrap = partial(
core_schema.no_info_wrap_validator_function,
partial(defaultdict_validator, default_default_factory=default_default_factory),
)
else:
coerce_instance_wrap = partial(core_schema.no_info_after_validator_function, self.mapped_origin)
serialization = core_schema.wrap_serializer_function_ser_schema(
self.serialize_mapping_via_dict,
schema=core_schema.dict_schema(
keys_schema or core_schema.any_schema(), values_schema or core_schema.any_schema()
),
info_arg=False,
)
strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)])
if metadata.get('strict', False):
schema = strict
else:
lax = coerce_instance_wrap(constrained_schema)
schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict)
schema['serialization'] = serialization
return schema
def mapping_like_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
origin: Any = get_origin(source_type)
mapped_origin = MAPPING_ORIGIN_MAP.get(origin, None) if origin else MAPPING_ORIGIN_MAP.get(source_type, None)
if mapped_origin is None:
return None
args = get_args(source_type)
if not args:
args = (Any, Any)
elif mapped_origin is collections.Counter:
# a single generic
if len(args) != 1:
raise ValueError('Expected Counter to have exactly 1 generic parameter')
args = (args[0], int)  # Counter values are always int
elif len(args) != 2:
raise ValueError('Expected mapping to have exactly 2 generic parameters')
keys_source_type, values_source_type = args
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type)
return (
source_type,
[
MappingValidator(mapped_origin, keys_source_type, values_source_type, **metadata),
*remaining_annotations,
],
)
def ip_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
def make_strict_ip_schema(tp: type[Any]) -> CoreSchema:
return core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(tp, core_schema.str_schema()),
python_schema=core_schema.is_instance_schema(tp),
)
if source_type is IPv4Address:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.lax_or_strict_schema(
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_address_validator),
strict_schema=make_strict_ip_schema(IPv4Address),
serialization=core_schema.to_string_ser_schema(),
),
lambda _1, _2: {'type': 'string', 'format': 'ipv4'},
),
*annotations,
]
if source_type is IPv4Network:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.lax_or_strict_schema(
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_network_validator),
strict_schema=make_strict_ip_schema(IPv4Network),
serialization=core_schema.to_string_ser_schema(),
),
lambda _1, _2: {'type': 'string', 'format': 'ipv4network'},
),
*annotations,
]
if source_type is IPv4Interface:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.lax_or_strict_schema(
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_interface_validator),
strict_schema=make_strict_ip_schema(IPv4Interface),
serialization=core_schema.to_string_ser_schema(),
),
lambda _1, _2: {'type': 'string', 'format': 'ipv4interface'},
),
*annotations,
]
if source_type is IPv6Address:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.lax_or_strict_schema(
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_address_validator),
strict_schema=make_strict_ip_schema(IPv6Address),
serialization=core_schema.to_string_ser_schema(),
),
lambda _1, _2: {'type': 'string', 'format': 'ipv6'},
),
*annotations,
]
if source_type is IPv6Network:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.lax_or_strict_schema(
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_network_validator),
strict_schema=make_strict_ip_schema(IPv6Network),
serialization=core_schema.to_string_ser_schema(),
),
lambda _1, _2: {'type': 'string', 'format': 'ipv6network'},
),
*annotations,
]
if source_type is IPv6Interface:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.lax_or_strict_schema(
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_interface_validator),
strict_schema=make_strict_ip_schema(IPv6Interface),
serialization=core_schema.to_string_ser_schema(),
),
lambda _1, _2: {'type': 'string', 'format': 'ipv6interface'},
),
*annotations,
]
return None
def url_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
if source_type is Url:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.url_schema(),
lambda cs, handler: handler(cs),
),
*annotations,
]
if source_type is MultiHostUrl:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.multi_host_url_schema(),
lambda cs, handler: handler(cs),
),
*annotations,
]
PREPARE_METHODS: tuple[Callable[[Any, Iterable[Any], ConfigDict], tuple[Any, list[Any]] | None], ...] = (
decimal_prepare_pydantic_annotations,
sequence_like_prepare_pydantic_annotations,
datetime_prepare_pydantic_annotations,
uuid_prepare_pydantic_annotations,
path_schema_prepare_pydantic_annotations,
mapping_like_prepare_pydantic_annotations,
ip_prepare_pydantic_annotations,
url_prepare_pydantic_annotations,
)
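# --- For context, a small public-API sketch of the `use_enum_values` branch handled in
# get_enum_core_schema above: with that config option, validation stores the member's value
# rather than the Enum member. `Color` and `Palette` are hypothetical names.
from enum import Enum
from pydantic import BaseModel, ConfigDict

class Color(Enum):
    RED = 1
    GREEN = 2

class Palette(BaseModel):
    model_config = ConfigDict(use_enum_values=True)
    color: Color

assert Palette(color=1).color == 1            # coerced to Color.RED, then unwrapped to the raw value
assert Palette(color=Color.GREEN).color == 2  # members are also unwrapped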

View file

@@ -0,0 +1,469 @@
"""Logic for interacting with type annotations, mostly extensions, shims and hacks to wrap python's typing module."""
from __future__ import annotations as _annotations
import dataclasses
import sys
import types
import typing
from collections.abc import Callable
from functools import partial
from types import GetSetDescriptorType
from typing import TYPE_CHECKING, Any, Final
from typing_extensions import Annotated, Literal, TypeAliasType, TypeGuard, get_args, get_origin
if TYPE_CHECKING:
from ._dataclasses import StandardDataclass
try:
from typing import _TypingBase # type: ignore[attr-defined]
except ImportError:
from typing import _Final as _TypingBase # type: ignore[attr-defined]
typing_base = _TypingBase
if sys.version_info < (3, 9):
# python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
TypingGenericAlias = ()
else:
from typing import GenericAlias as TypingGenericAlias # type: ignore
if sys.version_info < (3, 11):
from typing_extensions import NotRequired, Required
else:
from typing import NotRequired, Required # noqa: F401
if sys.version_info < (3, 10):
def origin_is_union(tp: type[Any] | None) -> bool:
return tp is typing.Union
WithArgsTypes = (TypingGenericAlias,)
else:
def origin_is_union(tp: type[Any] | None) -> bool:
return tp is typing.Union or tp is types.UnionType
WithArgsTypes = typing._GenericAlias, types.GenericAlias, types.UnionType # type: ignore[attr-defined]
if sys.version_info < (3, 10):
NoneType = type(None)
EllipsisType = type(Ellipsis)
else:
from types import NoneType as NoneType
LITERAL_TYPES: set[Any] = {Literal}
if hasattr(typing, 'Literal'):
LITERAL_TYPES.add(typing.Literal) # type: ignore
NONE_TYPES: tuple[Any, ...] = (None, NoneType, *(tp[None] for tp in LITERAL_TYPES))
TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type
def is_none_type(type_: Any) -> bool:
return type_ in NONE_TYPES
def is_callable_type(type_: type[Any]) -> bool:
return type_ is Callable or get_origin(type_) is Callable
def is_literal_type(type_: type[Any]) -> bool:
return Literal is not None and get_origin(type_) in LITERAL_TYPES
def literal_values(type_: type[Any]) -> tuple[Any, ...]:
return get_args(type_)
def all_literal_values(type_: type[Any]) -> list[Any]:
"""This method is used to retrieve all Literal values as
Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586)
e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]`.
"""
if not is_literal_type(type_):
return [type_]
values = literal_values(type_)
return list(x for value in values for x in all_literal_values(value))
def is_annotated(ann_type: Any) -> bool:
from ._utils import lenient_issubclass
origin = get_origin(ann_type)
return origin is not None and lenient_issubclass(origin, Annotated)
def is_namedtuple(type_: type[Any]) -> bool:
"""Check if a given class is a named tuple.
It can be either a `typing.NamedTuple` or `collections.namedtuple`.
"""
from ._utils import lenient_issubclass
return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields')
test_new_type = typing.NewType('test_new_type', str)
def is_new_type(type_: type[Any]) -> bool:
"""Check whether type_ was created using typing.NewType.
Can't use isinstance because it fails <3.10.
"""
return isinstance(type_, test_new_type.__class__) and hasattr(type_, '__supertype__') # type: ignore[arg-type]
def _check_classvar(v: type[Any] | None) -> bool:
if v is None:
return False
return v.__class__ == typing.ClassVar.__class__ and getattr(v, '_name', None) == 'ClassVar'
def is_classvar(ann_type: type[Any]) -> bool:
if _check_classvar(ann_type) or _check_classvar(get_origin(ann_type)):
return True
# this is an ugly workaround for class vars that contain forward references and are therefore themselves
# forward references, see #3679
if ann_type.__class__ == typing.ForwardRef and ann_type.__forward_arg__.startswith('ClassVar['): # type: ignore
return True
return False
def _check_finalvar(v: type[Any] | None) -> bool:
"""Check if a given type is a `typing.Final` type."""
if v is None:
return False
return v.__class__ == Final.__class__ and (sys.version_info < (3, 8) or getattr(v, '_name', None) == 'Final')
def is_finalvar(ann_type: Any) -> bool:
return _check_finalvar(ann_type) or _check_finalvar(get_origin(ann_type))
def parent_frame_namespace(*, parent_depth: int = 2) -> dict[str, Any] | None:
"""We allow use of items in parent namespace to get around the issue with `get_type_hints` only looking in the
global module namespace. See https://github.com/pydantic/pydantic/issues/2678#issuecomment-1008139014 -> Scope
and suggestion at the end of the next comment by @gvanrossum.
WARNING 1: it matters exactly where this is called. By default, this function will build a namespace from the
parent of where it is called.
WARNING 2: this only looks in the parent namespace, not other parents since (AFAIK) there's no way to collect a
dict of exactly what's in scope. Using `f_back` would work sometimes but would be very wrong and confusing in many
other cases. See https://discuss.python.org/t/is-there-a-way-to-access-parent-nested-namespaces/20659.
"""
frame = sys._getframe(parent_depth)
# if f_back is None, it's the global module namespace and we don't need to include it here
if frame.f_back is None:
return None
else:
return frame.f_locals
def add_module_globals(obj: Any, globalns: dict[str, Any] | None = None) -> dict[str, Any]:
module_name = getattr(obj, '__module__', None)
if module_name:
try:
module_globalns = sys.modules[module_name].__dict__
except KeyError:
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
pass
else:
if globalns:
return {**module_globalns, **globalns}
else:
# copy module globals to make sure it can't be updated later
return module_globalns.copy()
return globalns or {}
def get_cls_types_namespace(cls: type[Any], parent_namespace: dict[str, Any] | None = None) -> dict[str, Any]:
ns = add_module_globals(cls, parent_namespace)
ns[cls.__name__] = cls
return ns
def get_cls_type_hints_lenient(obj: Any, globalns: dict[str, Any] | None = None) -> dict[str, Any]:
"""Collect annotations from a class, including those from parent classes.
Unlike `typing.get_type_hints`, this function will not error if a forward reference is not resolvable.
"""
hints = {}
for base in reversed(obj.__mro__):
ann = base.__dict__.get('__annotations__')
localns = dict(vars(base))
if ann is not None and ann is not GetSetDescriptorType:
for name, value in ann.items():
hints[name] = eval_type_lenient(value, globalns, localns)
return hints
def eval_type_lenient(value: Any, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None) -> Any:
"""Behaves like typing._eval_type, except it won't raise an error if a forward reference can't be resolved."""
if value is None:
value = NoneType
elif isinstance(value, str):
value = _make_forward_ref(value, is_argument=False, is_class=True)
try:
return eval_type_backport(value, globalns, localns)
except NameError:
# the point of this function is to be tolerant to this case
return value
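# --- Illustration of the "lenient" behaviour (import path assumes the standard pydantic layout;
# the undefined name is deliberate, and `list[int]` needs Python 3.9+): resolvable annotations
# are evaluated, unresolvable forward references are returned unevaluated instead of raising NameError.
import typing as _t
from pydantic._internal._typing_extra import eval_type_lenient as _eval_lenient

assert _eval_lenient('list[int]') == list[int]
assert isinstance(_eval_lenient('SomeUndefinedName'), _t.ForwardRef)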
def eval_type_backport(
value: Any, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None
) -> Any:
"""Like `typing._eval_type`, but falls back to the `eval_type_backport` package if it's
installed to let older Python versions use newer typing features.
Specifically, this transforms `X | Y` into `typing.Union[X, Y]`
and `list[X]` into `typing.List[X]` etc. (for all the types made generic in PEP 585)
if the original syntax is not supported in the current Python version.
"""
try:
return typing._eval_type( # type: ignore
value, globalns, localns
)
except TypeError as e:
if not (isinstance(value, typing.ForwardRef) and is_backport_fixable_error(e)):
raise
try:
from eval_type_backport import eval_type_backport
except ImportError:
raise TypeError(
f'You have a type annotation {value.__forward_arg__!r} '
f'which makes use of newer typing features than are supported in your version of Python. '
f'To handle this error, you should either remove the use of new syntax '
f'or install the `eval_type_backport` package.'
) from e
return eval_type_backport(value, globalns, localns, try_default=False)
def is_backport_fixable_error(e: TypeError) -> bool:
msg = str(e)
return msg.startswith('unsupported operand type(s) for |: ') or "' object is not subscriptable" in msg
def get_function_type_hints(
function: Callable[..., Any], *, include_keys: set[str] | None = None, types_namespace: dict[str, Any] | None = None
) -> dict[str, Any]:
"""Like `typing.get_type_hints`, but doesn't convert `X` to `Optional[X]` if the default value is `None`, also
copes with `partial`.
"""
if isinstance(function, partial):
annotations = function.func.__annotations__
else:
annotations = function.__annotations__
globalns = add_module_globals(function)
type_hints = {}
for name, value in annotations.items():
if include_keys is not None and name not in include_keys:
continue
if value is None:
value = NoneType
elif isinstance(value, str):
value = _make_forward_ref(value)
type_hints[name] = eval_type_backport(value, globalns, types_namespace)
return type_hints
if sys.version_info < (3, 9, 8) or (3, 10) <= sys.version_info < (3, 10, 1):
def _make_forward_ref(
arg: Any,
is_argument: bool = True,
*,
is_class: bool = False,
) -> typing.ForwardRef:
"""Wrapper for ForwardRef that accounts for the `is_class` argument missing in older versions.
The `module` argument is omitted as it breaks <3.9.8, =3.10.0 and isn't used in the calls below.
See https://github.com/python/cpython/pull/28560 for some background.
The backport happened on 3.9.8, see:
https://github.com/pydantic/pydantic/discussions/6244#discussioncomment-6275458,
and on 3.10.1 for the 3.10 branch, see:
https://github.com/pydantic/pydantic/issues/6912
Implemented as EAFP with memory.
"""
return typing.ForwardRef(arg, is_argument)
else:
_make_forward_ref = typing.ForwardRef
if sys.version_info >= (3, 10):
get_type_hints = typing.get_type_hints
else:
"""
For older versions of python, we have a custom implementation of `get_type_hints` which is as close as possible to
the implementation in CPython 3.10.8.
"""
@typing.no_type_check
def get_type_hints( # noqa: C901
obj: Any,
globalns: dict[str, Any] | None = None,
localns: dict[str, Any] | None = None,
include_extras: bool = False,
) -> dict[str, Any]: # pragma: no cover
"""Taken verbatim from python 3.10.8 unchanged, except:
* type annotations of the function definition above.
* prefixing `typing.` where appropriate
* Use `_make_forward_ref` instead of `typing.ForwardRef` to handle the `is_class` argument.
https://github.com/python/cpython/blob/aaaf5174241496afca7ce4d4584570190ff972fe/Lib/typing.py#L1773-L1875
DO NOT CHANGE THIS METHOD UNLESS ABSOLUTELY NECESSARY.
======================================================
Return type hints for an object.
This is often the same as obj.__annotations__, but it handles
forward references encoded as string literals, adds Optional[t] if a
default value equal to None is set and recursively replaces all
'Annotated[T, ...]' with 'T' (unless 'include_extras=True').
The argument may be a module, class, method, or function. The annotations
are returned as a dictionary. For classes, annotations include also
inherited members.
TypeError is raised if the argument is not of a type that can contain
annotations, and an empty dictionary is returned if no annotations are
present.
BEWARE -- the behavior of globalns and localns is counterintuitive
(unless you are familiar with how eval() and exec() work). The
search order is locals first, then globals.
- If no dict arguments are passed, an attempt is made to use the
globals from obj (or the respective module's globals for classes),
and these are also used as the locals. If the object does not appear
to have globals, an empty dictionary is used. For classes, the search
order is globals first then locals.
- If one dict argument is passed, it is used for both globals and
locals.
- If two dict arguments are passed, they specify globals and
locals, respectively.
"""
if getattr(obj, '__no_type_check__', None):
return {}
# Classes require a special treatment.
if isinstance(obj, type):
hints = {}
for base in reversed(obj.__mro__):
if globalns is None:
base_globals = getattr(sys.modules.get(base.__module__, None), '__dict__', {})
else:
base_globals = globalns
ann = base.__dict__.get('__annotations__', {})
if isinstance(ann, types.GetSetDescriptorType):
ann = {}
base_locals = dict(vars(base)) if localns is None else localns
if localns is None and globalns is None:
# This is surprising, but required. Before Python 3.10,
# get_type_hints only evaluated the globalns of
# a class. To maintain backwards compatibility, we reverse
# the globalns and localns order so that eval() looks into
# *base_globals* first rather than *base_locals*.
# This only affects ForwardRefs.
base_globals, base_locals = base_locals, base_globals
for name, value in ann.items():
if value is None:
value = type(None)
if isinstance(value, str):
value = _make_forward_ref(value, is_argument=False, is_class=True)
value = eval_type_backport(value, base_globals, base_locals)
hints[name] = value
if not include_extras and hasattr(typing, '_strip_annotations'):
return {
k: typing._strip_annotations(t) # type: ignore
for k, t in hints.items()
}
else:
return hints
if globalns is None:
if isinstance(obj, types.ModuleType):
globalns = obj.__dict__
else:
nsobj = obj
# Find globalns for the unwrapped object.
while hasattr(nsobj, '__wrapped__'):
nsobj = nsobj.__wrapped__
globalns = getattr(nsobj, '__globals__', {})
if localns is None:
localns = globalns
elif localns is None:
localns = globalns
hints = getattr(obj, '__annotations__', None)
if hints is None:
# Return empty annotations for something that _could_ have them.
if isinstance(obj, typing._allowed_types): # type: ignore
return {}
else:
raise TypeError(f'{obj!r} is not a module, class, method, ' 'or function.')
defaults = typing._get_defaults(obj) # type: ignore
hints = dict(hints)
for name, value in hints.items():
if value is None:
value = type(None)
if isinstance(value, str):
# class-level forward refs were handled above, this must be either
# a module-level annotation or a function argument annotation
value = _make_forward_ref(
value,
is_argument=not isinstance(obj, types.ModuleType),
is_class=False,
)
value = eval_type_backport(value, globalns, localns)
if name in defaults and defaults[name] is None:
value = typing.Optional[value]
hints[name] = value
return hints if include_extras else {k: typing._strip_annotations(t) for k, t in hints.items()} # type: ignore
def is_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
# The dataclasses.is_dataclass function doesn't seem to provide TypeGuard functionality,
# so I created this convenience function
return dataclasses.is_dataclass(_cls)
def origin_is_type_alias_type(origin: Any) -> TypeGuard[TypeAliasType]:
return isinstance(origin, TypeAliasType)
if sys.version_info >= (3, 10):
def is_generic_alias(type_: type[Any]) -> bool:
return isinstance(type_, (types.GenericAlias, typing._GenericAlias)) # type: ignore[attr-defined]
else:
def is_generic_alias(type_: type[Any]) -> bool:
return isinstance(type_, typing._GenericAlias) # type: ignore

View file

@@ -0,0 +1,362 @@
"""Bucket of reusable internal utilities.
This should be reduced as much as possible with functions only used in one place, moved to that place.
"""
from __future__ import annotations as _annotations
import dataclasses
import keyword
import typing
import weakref
from collections import OrderedDict, defaultdict, deque
from copy import deepcopy
from itertools import zip_longest
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
from typing import Any, Mapping, TypeVar
from typing_extensions import TypeAlias, TypeGuard
from . import _repr, _typing_extra
if typing.TYPE_CHECKING:
MappingIntStrAny: TypeAlias = 'typing.Mapping[int, Any] | typing.Mapping[str, Any]'
AbstractSetIntStr: TypeAlias = 'typing.AbstractSet[int] | typing.AbstractSet[str]'
from ..main import BaseModel
# these are types that are returned unchanged by deepcopy
IMMUTABLE_NON_COLLECTIONS_TYPES: set[type[Any]] = {
int,
float,
complex,
str,
bool,
bytes,
type,
_typing_extra.NoneType,
FunctionType,
BuiltinFunctionType,
LambdaType,
weakref.ref,
CodeType,
# note: including ModuleType will differ from the behaviour of deepcopy by not producing an error.
# It might not be a good idea in general, but considering that this function is used only internally
# against default values of fields, this allows a field to actually have a module as its default value
ModuleType,
NotImplemented.__class__,
Ellipsis.__class__,
}
# these are types that if empty, might be copied with simple copy() instead of deepcopy()
BUILTIN_COLLECTIONS: set[type[Any]] = {
list,
set,
tuple,
frozenset,
dict,
OrderedDict,
defaultdict,
deque,
}
def sequence_like(v: Any) -> bool:
return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque))
def lenient_isinstance(o: Any, class_or_tuple: type[Any] | tuple[type[Any], ...] | None) -> bool: # pragma: no cover
try:
return isinstance(o, class_or_tuple) # type: ignore[arg-type]
except TypeError:
return False
def lenient_issubclass(cls: Any, class_or_tuple: Any) -> bool: # pragma: no cover
try:
return isinstance(cls, type) and issubclass(cls, class_or_tuple)
except TypeError:
if isinstance(cls, _typing_extra.WithArgsTypes):
return False
raise # pragma: no cover
def is_model_class(cls: Any) -> TypeGuard[type[BaseModel]]:
"""Returns true if cls is a _proper_ subclass of BaseModel, and provides proper type-checking,
unlike raw calls to lenient_issubclass.
"""
from ..main import BaseModel
return lenient_issubclass(cls, BaseModel) and cls is not BaseModel
def is_valid_identifier(identifier: str) -> bool:
"""Checks that a string is a valid identifier and not a Python keyword.
:param identifier: The identifier to test.
:return: True if the identifier is valid.
"""
return identifier.isidentifier() and not keyword.iskeyword(identifier)
KeyType = TypeVar('KeyType')
def deep_update(mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]) -> dict[KeyType, Any]:
updated_mapping = mapping.copy()
for updating_mapping in updating_mappings:
for k, v in updating_mapping.items():
if k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict):
updated_mapping[k] = deep_update(updated_mapping[k], v)
else:
updated_mapping[k] = v
return updated_mapping
def update_not_none(mapping: dict[Any, Any], **update: Any) -> None:
mapping.update({k: v for k, v in update.items() if v is not None})
T = TypeVar('T')
def unique_list(
input_list: list[T] | tuple[T, ...],
*,
name_factory: typing.Callable[[T], str] = str,
) -> list[T]:
"""Make a list unique while maintaining order.
We update the list if another one with the same name is set
(e.g. model validator overridden in subclass).
"""
result: list[T] = []
result_names: list[str] = []
for v in input_list:
v_name = name_factory(v)
if v_name not in result_names:
result_names.append(v_name)
result.append(v)
else:
result[result_names.index(v_name)] = v
return result
class ValueItems(_repr.Representation):
"""Class for more convenient calculation of excluded or included fields on values."""
__slots__ = ('_items', '_type')
def __init__(self, value: Any, items: AbstractSetIntStr | MappingIntStrAny) -> None:
items = self._coerce_items(items)
if isinstance(value, (list, tuple)):
items = self._normalize_indexes(items, len(value)) # type: ignore
self._items: MappingIntStrAny = items # type: ignore
def is_excluded(self, item: Any) -> bool:
"""Check if item is fully excluded.
:param item: key or index of a value
"""
return self.is_true(self._items.get(item))
def is_included(self, item: Any) -> bool:
"""Check if value is contained in self._items.
:param item: key or index of value
"""
return item in self._items
def for_element(self, e: int | str) -> AbstractSetIntStr | MappingIntStrAny | None:
""":param e: key or index of element on value
:return: raw values for element if self._items is dict and contain needed element
"""
item = self._items.get(e) # type: ignore
return item if not self.is_true(item) else None
def _normalize_indexes(self, items: MappingIntStrAny, v_length: int) -> dict[int | str, Any]:
""":param items: dict or set of indexes which will be normalized
:param v_length: length of the sequence whose indexes will be normalized
>>> self._normalize_indexes({0: True, -2: True, -1: True}, 4)
{0: True, 2: True, 3: True}
>>> self._normalize_indexes({'__all__': True}, 4)
{0: True, 1: True, 2: True, 3: True}
"""
normalized_items: dict[int | str, Any] = {}
all_items = None
for i, v in items.items():
if not (isinstance(v, typing.Mapping) or isinstance(v, typing.AbstractSet) or self.is_true(v)):
raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
if i == '__all__':
all_items = self._coerce_value(v)
continue
if not isinstance(i, int):
raise TypeError(
'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
'expected integer keys or keyword "__all__"'
)
normalized_i = v_length + i if i < 0 else i
normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i))
if not all_items:
return normalized_items
if self.is_true(all_items):
for i in range(v_length):
normalized_items.setdefault(i, ...)
return normalized_items
for i in range(v_length):
normalized_item = normalized_items.setdefault(i, {})
if not self.is_true(normalized_item):
normalized_items[i] = self.merge(all_items, normalized_item)
return normalized_items
@classmethod
def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any:
"""Merge a `base` item with an `override` item.
Both `base` and `override` are converted to dictionaries if possible.
Sets are converted to dictionaries with the sets entries as keys and
Ellipsis as values.
Each key-value pair existing in `base` is merged with `override`,
while the rest of the key-value pairs are updated recursively with this function.
Merging takes place based on the "union" of keys if `intersect` is
set to `False` (default) and on the intersection of keys if
`intersect` is set to `True`.
"""
override = cls._coerce_value(override)
base = cls._coerce_value(base)
if override is None:
return base
if cls.is_true(base) or base is None:
return override
if cls.is_true(override):
return base if intersect else override
# intersection or union of keys while preserving ordering:
if intersect:
merge_keys = [k for k in base if k in override] + [k for k in override if k in base]
else:
merge_keys = list(base) + [k for k in override if k not in base]
merged: dict[int | str, Any] = {}
for k in merge_keys:
merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect)
if merged_item is not None:
merged[k] = merged_item
return merged
@staticmethod
def _coerce_items(items: AbstractSetIntStr | MappingIntStrAny) -> MappingIntStrAny:
if isinstance(items, typing.Mapping):
pass
elif isinstance(items, typing.AbstractSet):
items = dict.fromkeys(items, ...) # type: ignore
else:
class_name = getattr(items, '__class__', '???')
raise TypeError(f'Unexpected type of exclude value {class_name}')
return items # type: ignore
@classmethod
def _coerce_value(cls, value: Any) -> Any:
if value is None or cls.is_true(value):
return value
return cls._coerce_items(value)
@staticmethod
def is_true(v: Any) -> bool:
return v is True or v is ...
def __repr_args__(self) -> _repr.ReprArgs:
return [(None, self._items)]
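# Illustrative sketch only (not part of the original module): the merge semantics
# documented on `ValueItems.merge` above, applied to hypothetical include/exclude values.
def _example_value_items_merge() -> None:
    # sets are coerced to dicts keyed by their entries, with Ellipsis as the values
    assert ValueItems.merge({'a'}, {'b': {'c'}}) == {'a': ..., 'b': {'c': ...}}
    # with intersect=True only keys present on both sides survive
    assert ValueItems.merge({'a': ..., 'b': ...}, {'b': ...}, intersect=True) == {'b': ...}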
if typing.TYPE_CHECKING:
def ClassAttribute(name: str, value: T) -> T:
...
else:
class ClassAttribute:
"""Hide class attribute from its instances."""
__slots__ = 'name', 'value'
def __init__(self, name: str, value: Any) -> None:
self.name = name
self.value = value
def __get__(self, instance: Any, owner: type[Any]) -> None:
if instance is None:
return self.value
raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only')
Obj = TypeVar('Obj')
def smart_deepcopy(obj: Obj) -> Obj:
"""Return type as is for immutable built-in types
Use obj.copy() for built-in empty collections
Use copy.deepcopy() for non-empty collections and unknown objects.
"""
obj_type = obj.__class__
if obj_type in IMMUTABLE_NON_COLLECTIONS_TYPES:
return obj # fastest case: obj is immutable and not collection therefore will not be copied anyway
try:
if not obj and obj_type in BUILTIN_COLLECTIONS:
# faster way for empty collections, no need to copy its members
return obj if obj_type is tuple else obj.copy() # tuple doesn't have copy method # type: ignore
except (TypeError, ValueError, RuntimeError):
# do we really dare to catch ALL errors? Seems a bit risky
pass
return deepcopy(obj) # slowest way when we actually might need a deepcopy
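# Illustrative sketch only (not part of the original module): the three branches above.
def _example_smart_deepcopy() -> None:
    assert smart_deepcopy(42) == 42  # immutable non-collection: returned as is
    empty: list = []
    assert smart_deepcopy(empty) == [] and smart_deepcopy(empty) is not empty  # empty collection: .copy()
    nested = [[1, 2]]
    copied = smart_deepcopy(nested)  # non-empty collection: full deepcopy
    assert copied == nested and copied[0] is not nested[0]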
_SENTINEL = object()
def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bool:
"""Check that the items of `left` are the same objects as those in `right`.
>>> a, b = object(), object()
>>> all_identical([a, b, a], [a, b, a])
True
>>> all_identical([a, b, [a]], [a, b, [a]]) # new list object, while "equal" is not "identical"
False
"""
for left_item, right_item in zip_longest(left, right, fillvalue=_SENTINEL):
if left_item is not right_item:
return False
return True
@dataclasses.dataclass(frozen=True)
class SafeGetItemProxy:
"""Wrapper redirecting `__getitem__` to `get` with a sentinel value as default
This makes it safe to use in `operator.itemgetter` when some keys may be missing
"""
# Define __slots__ manually for performance
# @dataclasses.dataclass() only supports slots=True in Python >= 3.10
__slots__ = ('wrapped',)
wrapped: Mapping[str, Any]
def __getitem__(self, __key: str) -> Any:
return self.wrapped.get(__key, _SENTINEL)
# required to pass the object to operator.itemgetter() instances due to a quirk of typeshed
# https://github.com/python/mypy/issues/13713
# https://github.com/python/typeshed/pull/8785
# Since this is typing-only, hide it in a typing.TYPE_CHECKING block
if typing.TYPE_CHECKING:
def __contains__(self, __key: str) -> bool:
return self.wrapped.__contains__(__key)
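# Hypothetical usage sketch (not part of the original module): `operator.itemgetter`
# over a mapping that may be missing keys, relying on the sentinel default above.
def _example_safe_get_item_proxy() -> None:
    import operator

    getter = operator.itemgetter('present', 'missing')
    proxy = SafeGetItemProxy({'present': 1})
    assert getter(proxy) == (1, _SENTINEL)  # the missing key yields the sentinel, not a KeyError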

View file

@ -0,0 +1,84 @@
from __future__ import annotations as _annotations
import inspect
from functools import partial
from typing import Any, Awaitable, Callable
import pydantic_core
from ..config import ConfigDict
from ..plugin._schema_validator import create_schema_validator
from . import _generate_schema, _typing_extra
from ._config import ConfigWrapper
class ValidateCallWrapper:
"""This is a wrapper around a function that validates the arguments passed to it, and optionally the return value."""
__slots__ = (
'__pydantic_validator__',
'__name__',
'__qualname__',
'__annotations__',
'__dict__', # required for __module__
)
def __init__(self, function: Callable[..., Any], config: ConfigDict | None, validate_return: bool):
if isinstance(function, partial):
func = function.func
schema_type = func
self.__name__ = f'partial({func.__name__})'
self.__qualname__ = f'partial({func.__qualname__})'
self.__module__ = func.__module__
else:
schema_type = function
self.__name__ = function.__name__
self.__qualname__ = function.__qualname__
self.__module__ = function.__module__
namespace = _typing_extra.add_module_globals(function, None)
config_wrapper = ConfigWrapper(config)
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
schema = gen_schema.clean_schema(gen_schema.generate_schema(function))
core_config = config_wrapper.core_config(self)
self.__pydantic_validator__ = create_schema_validator(
schema,
schema_type,
self.__module__,
self.__qualname__,
'validate_call',
core_config,
config_wrapper.plugin_settings,
)
if validate_return:
signature = inspect.signature(function)
return_type = signature.return_annotation if signature.return_annotation is not signature.empty else Any
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
validator = create_schema_validator(
schema,
schema_type,
self.__module__,
self.__qualname__,
'validate_call',
core_config,
config_wrapper.plugin_settings,
)
if inspect.iscoroutinefunction(function):
async def return_val_wrapper(aw: Awaitable[Any]) -> None:
return validator.validate_python(await aw)
self.__return_pydantic_validator__ = return_val_wrapper
else:
self.__return_pydantic_validator__ = validator.validate_python
else:
self.__return_pydantic_validator__ = None
def __call__(self, *args: Any, **kwargs: Any) -> Any:
res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
if self.__return_pydantic_validator__:
return self.__return_pydantic_validator__(res)
return res
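# Hedged sketch (not part of the original module): this wrapper backs the public
# `pydantic.validate_call` decorator; the decorated function below is hypothetical.
def _example_validate_call() -> None:
    from pydantic import validate_call

    @validate_call(validate_return=True)
    def repeat(text: str, count: int) -> str:
        return text * count

    assert repeat('ab', 3) == 'ababab'
    assert repeat('ab', '3') == 'ababab'  # '3' is coerced to int before the call runs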

View file

@ -0,0 +1,278 @@
"""Validator functions for standard library types.
Import of this module is deferred since it contains imports of many standard library modules.
"""
from __future__ import annotations as _annotations
import math
import re
import typing
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from typing import Any
from pydantic_core import PydanticCustomError, core_schema
from pydantic_core._pydantic_core import PydanticKnownError
def sequence_validator(
__input_value: typing.Sequence[Any],
validator: core_schema.ValidatorFunctionWrapHandler,
) -> typing.Sequence[Any]:
"""Validator for `Sequence` types, isinstance(v, Sequence) has already been called."""
value_type = type(__input_value)
# We don't accept any plain string as a sequence
# Relevant issue: https://github.com/pydantic/pydantic/issues/5595
if issubclass(value_type, (str, bytes)):
raise PydanticCustomError(
'sequence_str',
"'{type_name}' instances are not allowed as a Sequence value",
{'type_name': value_type.__name__},
)
v_list = validator(__input_value)
# the rest of the logic is just re-creating the original type from `v_list`
if value_type == list:
return v_list
elif issubclass(value_type, range):
# return the list as we probably can't re-create the range
return v_list
else:
# best guess at how to re-create the original type, more custom construction logic might be required
return value_type(v_list) # type: ignore[call-arg]
def import_string(value: Any) -> Any:
if isinstance(value, str):
try:
return _import_string_logic(value)
except ImportError as e:
raise PydanticCustomError('import_error', 'Invalid python path: {error}', {'error': str(e)}) from e
else:
# otherwise we just return the value and let the next validator do the rest of the work
return value
def _import_string_logic(dotted_path: str) -> Any:
"""Inspired by uvicorn — dotted paths should include a colon before the final item if that item is not a module.
(This is necessary to distinguish between a submodule and an attribute when there is a conflict.).
If the dotted path does not include a colon and the final item is not a valid module, importing as an attribute
rather than a submodule will be attempted automatically.
So, for example, the following values of `dotted_path` result in the following returned values:
* 'collections': <module 'collections'>
* 'collections.abc': <module 'collections.abc'>
* 'collections.abc:Mapping': <class 'collections.abc.Mapping'>
* `collections.abc.Mapping`: <class 'collections.abc.Mapping'> (though this is a bit slower than the previous line)
An error will be raised under any of the following scenarios:
* `dotted_path` contains more than one colon (e.g., 'collections:abc:Mapping')
* the substring of `dotted_path` before the colon is not a valid module in the environment (e.g., '123:Mapping')
* the substring of `dotted_path` after the colon is not an attribute of the module (e.g., 'collections:abc123')
"""
from importlib import import_module
components = dotted_path.strip().split(':')
if len(components) > 2:
raise ImportError(f"Import strings should have at most one ':'; received {dotted_path!r}")
module_path = components[0]
if not module_path:
raise ImportError(f'Import strings should have a nonempty module name; received {dotted_path!r}')
try:
module = import_module(module_path)
except ModuleNotFoundError as e:
if '.' in module_path:
# Check if it would be valid if the final item was separated from its module with a `:`
maybe_module_path, maybe_attribute = dotted_path.strip().rsplit('.', 1)
try:
return _import_string_logic(f'{maybe_module_path}:{maybe_attribute}')
except ImportError:
pass
raise ImportError(f'No module named {module_path!r}') from e
raise e
if len(components) > 1:
attribute = components[1]
try:
return getattr(module, attribute)
except AttributeError as e:
raise ImportError(f'cannot import name {attribute!r} from {module_path!r}') from e
else:
return module
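# Hedged example (not part of the original module) of the colon convention described above.
def _example_import_string_logic() -> None:
    from collections.abc import Mapping

    assert _import_string_logic('collections.abc:Mapping') is Mapping
    assert _import_string_logic('collections.abc.Mapping') is Mapping  # slower fallback path
    assert _import_string_logic('collections').__name__ == 'collections'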
def pattern_either_validator(__input_value: Any) -> typing.Pattern[Any]:
if isinstance(__input_value, typing.Pattern):
return __input_value
elif isinstance(__input_value, (str, bytes)):
# todo strict mode
return compile_pattern(__input_value) # type: ignore
else:
raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
def pattern_str_validator(__input_value: Any) -> typing.Pattern[str]:
if isinstance(__input_value, typing.Pattern):
if isinstance(__input_value.pattern, str):
return __input_value
else:
raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
elif isinstance(__input_value, str):
return compile_pattern(__input_value)
elif isinstance(__input_value, bytes):
raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
else:
raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
def pattern_bytes_validator(__input_value: Any) -> typing.Pattern[bytes]:
if isinstance(__input_value, typing.Pattern):
if isinstance(__input_value.pattern, bytes):
return __input_value
else:
raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
elif isinstance(__input_value, bytes):
return compile_pattern(__input_value)
elif isinstance(__input_value, str):
raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
else:
raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
PatternType = typing.TypeVar('PatternType', str, bytes)
def compile_pattern(pattern: PatternType) -> typing.Pattern[PatternType]:
try:
return re.compile(pattern)
except re.error:
raise PydanticCustomError('pattern_regex', 'Input should be a valid regular expression')
def ip_v4_address_validator(__input_value: Any) -> IPv4Address:
if isinstance(__input_value, IPv4Address):
return __input_value
try:
return IPv4Address(__input_value)
except ValueError:
raise PydanticCustomError('ip_v4_address', 'Input is not a valid IPv4 address')
def ip_v6_address_validator(__input_value: Any) -> IPv6Address:
if isinstance(__input_value, IPv6Address):
return __input_value
try:
return IPv6Address(__input_value)
except ValueError:
raise PydanticCustomError('ip_v6_address', 'Input is not a valid IPv6 address')
def ip_v4_network_validator(__input_value: Any) -> IPv4Network:
"""Assume IPv4Network initialised with a default `strict` argument.
See more:
https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network
"""
if isinstance(__input_value, IPv4Network):
return __input_value
try:
return IPv4Network(__input_value)
except ValueError:
raise PydanticCustomError('ip_v4_network', 'Input is not a valid IPv4 network')
def ip_v6_network_validator(__input_value: Any) -> IPv6Network:
"""Assume IPv6Network initialised with a default `strict` argument.
See more:
https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network
"""
if isinstance(__input_value, IPv6Network):
return __input_value
try:
return IPv6Network(__input_value)
except ValueError:
raise PydanticCustomError('ip_v6_network', 'Input is not a valid IPv6 network')
def ip_v4_interface_validator(__input_value: Any) -> IPv4Interface:
if isinstance(__input_value, IPv4Interface):
return __input_value
try:
return IPv4Interface(__input_value)
except ValueError:
raise PydanticCustomError('ip_v4_interface', 'Input is not a valid IPv4 interface')
def ip_v6_interface_validator(__input_value: Any) -> IPv6Interface:
if isinstance(__input_value, IPv6Interface):
return __input_value
try:
return IPv6Interface(__input_value)
except ValueError:
raise PydanticCustomError('ip_v6_interface', 'Input is not a valid IPv6 interface')
def greater_than_validator(x: Any, gt: Any) -> Any:
if not (x > gt):
raise PydanticKnownError('greater_than', {'gt': gt})
return x
def greater_than_or_equal_validator(x: Any, ge: Any) -> Any:
if not (x >= ge):
raise PydanticKnownError('greater_than_equal', {'ge': ge})
return x
def less_than_validator(x: Any, lt: Any) -> Any:
if not (x < lt):
raise PydanticKnownError('less_than', {'lt': lt})
return x
def less_than_or_equal_validator(x: Any, le: Any) -> Any:
if not (x <= le):
raise PydanticKnownError('less_than_equal', {'le': le})
return x
def multiple_of_validator(x: Any, multiple_of: Any) -> Any:
if not (x % multiple_of == 0):
raise PydanticKnownError('multiple_of', {'multiple_of': multiple_of})
return x
def min_length_validator(x: Any, min_length: Any) -> Any:
if not (len(x) >= min_length):
raise PydanticKnownError(
'too_short',
{'field_type': 'Value', 'min_length': min_length, 'actual_length': len(x)},
)
return x
def max_length_validator(x: Any, max_length: Any) -> Any:
if len(x) > max_length:
raise PydanticKnownError(
'too_long',
{'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)},
)
return x
def forbid_inf_nan_check(x: Any) -> Any:
if not math.isfinite(x):
raise PydanticKnownError('finite_number')
return x
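# Illustrative sketch only (not part of the original module): the bound checks above
# return the value unchanged when the constraint holds and raise PydanticKnownError otherwise.
def _example_constraint_validators() -> None:
    assert greater_than_validator(5, gt=3) == 5
    assert multiple_of_validator(15, multiple_of=5) == 15
    try:
        less_than_validator(10, lt=10)
    except PydanticKnownError:
        pass  # 10 < 10 is false, so a 'less_than' error is raised
    else:
        raise AssertionError('expected a PydanticKnownError')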

308
lib/pydantic/_migration.py Normal file
View file

@ -0,0 +1,308 @@
import sys
from typing import Any, Callable, Dict
from .version import version_short
MOVED_IN_V2 = {
'pydantic.utils:version_info': 'pydantic.version:version_info',
'pydantic.error_wrappers:ValidationError': 'pydantic:ValidationError',
'pydantic.utils:to_camel': 'pydantic.alias_generators:to_pascal',
'pydantic.utils:to_lower_camel': 'pydantic.alias_generators:to_camel',
'pydantic:PyObject': 'pydantic.types:ImportString',
'pydantic.types:PyObject': 'pydantic.types:ImportString',
'pydantic.generics:GenericModel': 'pydantic.BaseModel',
}
DEPRECATED_MOVED_IN_V2 = {
'pydantic.tools:schema_of': 'pydantic.deprecated.tools:schema_of',
'pydantic.tools:parse_obj_as': 'pydantic.deprecated.tools:parse_obj_as',
'pydantic.tools:schema_json_of': 'pydantic.deprecated.tools:schema_json_of',
'pydantic.json:pydantic_encoder': 'pydantic.deprecated.json:pydantic_encoder',
'pydantic:validate_arguments': 'pydantic.deprecated.decorator:validate_arguments',
'pydantic.json:custom_pydantic_encoder': 'pydantic.deprecated.json:custom_pydantic_encoder',
'pydantic.json:timedelta_isoformat': 'pydantic.deprecated.json:timedelta_isoformat',
'pydantic.decorator:validate_arguments': 'pydantic.deprecated.decorator:validate_arguments',
'pydantic.class_validators:validator': 'pydantic.deprecated.class_validators:validator',
'pydantic.class_validators:root_validator': 'pydantic.deprecated.class_validators:root_validator',
'pydantic.config:BaseConfig': 'pydantic.deprecated.config:BaseConfig',
'pydantic.config:Extra': 'pydantic.deprecated.config:Extra',
}
REDIRECT_TO_V1 = {
f'pydantic.utils:{obj}': f'pydantic.v1.utils:{obj}'
for obj in (
'deep_update',
'GetterDict',
'lenient_issubclass',
'lenient_isinstance',
'is_valid_field',
'update_not_none',
'import_string',
'Representation',
'ROOT_KEY',
'smart_deepcopy',
'sequence_like',
)
}
REMOVED_IN_V2 = {
'pydantic:ConstrainedBytes',
'pydantic:ConstrainedDate',
'pydantic:ConstrainedDecimal',
'pydantic:ConstrainedFloat',
'pydantic:ConstrainedFrozenSet',
'pydantic:ConstrainedInt',
'pydantic:ConstrainedList',
'pydantic:ConstrainedSet',
'pydantic:ConstrainedStr',
'pydantic:JsonWrapper',
'pydantic:NoneBytes',
'pydantic:NoneStr',
'pydantic:NoneStrBytes',
'pydantic:Protocol',
'pydantic:Required',
'pydantic:StrBytes',
'pydantic:compiled',
'pydantic.config:get_config',
'pydantic.config:inherit_config',
'pydantic.config:prepare_config',
'pydantic:create_model_from_namedtuple',
'pydantic:create_model_from_typeddict',
'pydantic.dataclasses:create_pydantic_model_from_dataclass',
'pydantic.dataclasses:make_dataclass_validator',
'pydantic.dataclasses:set_validation',
'pydantic.datetime_parse:parse_date',
'pydantic.datetime_parse:parse_time',
'pydantic.datetime_parse:parse_datetime',
'pydantic.datetime_parse:parse_duration',
'pydantic.error_wrappers:ErrorWrapper',
'pydantic.errors:AnyStrMaxLengthError',
'pydantic.errors:AnyStrMinLengthError',
'pydantic.errors:ArbitraryTypeError',
'pydantic.errors:BoolError',
'pydantic.errors:BytesError',
'pydantic.errors:CallableError',
'pydantic.errors:ClassError',
'pydantic.errors:ColorError',
'pydantic.errors:ConfigError',
'pydantic.errors:DataclassTypeError',
'pydantic.errors:DateError',
'pydantic.errors:DateNotInTheFutureError',
'pydantic.errors:DateNotInThePastError',
'pydantic.errors:DateTimeError',
'pydantic.errors:DecimalError',
'pydantic.errors:DecimalIsNotFiniteError',
'pydantic.errors:DecimalMaxDigitsError',
'pydantic.errors:DecimalMaxPlacesError',
'pydantic.errors:DecimalWholeDigitsError',
'pydantic.errors:DictError',
'pydantic.errors:DurationError',
'pydantic.errors:EmailError',
'pydantic.errors:EnumError',
'pydantic.errors:EnumMemberError',
'pydantic.errors:ExtraError',
'pydantic.errors:FloatError',
'pydantic.errors:FrozenSetError',
'pydantic.errors:FrozenSetMaxLengthError',
'pydantic.errors:FrozenSetMinLengthError',
'pydantic.errors:HashableError',
'pydantic.errors:IPv4AddressError',
'pydantic.errors:IPv4InterfaceError',
'pydantic.errors:IPv4NetworkError',
'pydantic.errors:IPv6AddressError',
'pydantic.errors:IPv6InterfaceError',
'pydantic.errors:IPv6NetworkError',
'pydantic.errors:IPvAnyAddressError',
'pydantic.errors:IPvAnyInterfaceError',
'pydantic.errors:IPvAnyNetworkError',
'pydantic.errors:IntEnumError',
'pydantic.errors:IntegerError',
'pydantic.errors:InvalidByteSize',
'pydantic.errors:InvalidByteSizeUnit',
'pydantic.errors:InvalidDiscriminator',
'pydantic.errors:InvalidLengthForBrand',
'pydantic.errors:JsonError',
'pydantic.errors:JsonTypeError',
'pydantic.errors:ListError',
'pydantic.errors:ListMaxLengthError',
'pydantic.errors:ListMinLengthError',
'pydantic.errors:ListUniqueItemsError',
'pydantic.errors:LuhnValidationError',
'pydantic.errors:MissingDiscriminator',
'pydantic.errors:MissingError',
'pydantic.errors:NoneIsAllowedError',
'pydantic.errors:NoneIsNotAllowedError',
'pydantic.errors:NotDigitError',
'pydantic.errors:NotNoneError',
'pydantic.errors:NumberNotGeError',
'pydantic.errors:NumberNotGtError',
'pydantic.errors:NumberNotLeError',
'pydantic.errors:NumberNotLtError',
'pydantic.errors:NumberNotMultipleError',
'pydantic.errors:PathError',
'pydantic.errors:PathNotADirectoryError',
'pydantic.errors:PathNotAFileError',
'pydantic.errors:PathNotExistsError',
'pydantic.errors:PatternError',
'pydantic.errors:PyObjectError',
'pydantic.errors:PydanticTypeError',
'pydantic.errors:PydanticValueError',
'pydantic.errors:SequenceError',
'pydantic.errors:SetError',
'pydantic.errors:SetMaxLengthError',
'pydantic.errors:SetMinLengthError',
'pydantic.errors:StrError',
'pydantic.errors:StrRegexError',
'pydantic.errors:StrictBoolError',
'pydantic.errors:SubclassError',
'pydantic.errors:TimeError',
'pydantic.errors:TupleError',
'pydantic.errors:TupleLengthError',
'pydantic.errors:UUIDError',
'pydantic.errors:UUIDVersionError',
'pydantic.errors:UrlError',
'pydantic.errors:UrlExtraError',
'pydantic.errors:UrlHostError',
'pydantic.errors:UrlHostTldError',
'pydantic.errors:UrlPortError',
'pydantic.errors:UrlSchemeError',
'pydantic.errors:UrlSchemePermittedError',
'pydantic.errors:UrlUserInfoError',
'pydantic.errors:WrongConstantError',
'pydantic.main:validate_model',
'pydantic.networks:stricturl',
'pydantic:parse_file_as',
'pydantic:parse_raw_as',
'pydantic:stricturl',
'pydantic.tools:parse_file_as',
'pydantic.tools:parse_raw_as',
'pydantic.types:ConstrainedBytes',
'pydantic.types:ConstrainedDate',
'pydantic.types:ConstrainedDecimal',
'pydantic.types:ConstrainedFloat',
'pydantic.types:ConstrainedFrozenSet',
'pydantic.types:ConstrainedInt',
'pydantic.types:ConstrainedList',
'pydantic.types:ConstrainedSet',
'pydantic.types:ConstrainedStr',
'pydantic.types:JsonWrapper',
'pydantic.types:NoneBytes',
'pydantic.types:NoneStr',
'pydantic.types:NoneStrBytes',
'pydantic.types:StrBytes',
'pydantic.typing:evaluate_forwardref',
'pydantic.typing:AbstractSetIntStr',
'pydantic.typing:AnyCallable',
'pydantic.typing:AnyClassMethod',
'pydantic.typing:CallableGenerator',
'pydantic.typing:DictAny',
'pydantic.typing:DictIntStrAny',
'pydantic.typing:DictStrAny',
'pydantic.typing:IntStr',
'pydantic.typing:ListStr',
'pydantic.typing:MappingIntStrAny',
'pydantic.typing:NoArgAnyCallable',
'pydantic.typing:NoneType',
'pydantic.typing:ReprArgs',
'pydantic.typing:SetStr',
'pydantic.typing:StrPath',
'pydantic.typing:TupleGenerator',
'pydantic.typing:WithArgsTypes',
'pydantic.typing:all_literal_values',
'pydantic.typing:display_as_type',
'pydantic.typing:get_all_type_hints',
'pydantic.typing:get_args',
'pydantic.typing:get_origin',
'pydantic.typing:get_sub_types',
'pydantic.typing:is_callable_type',
'pydantic.typing:is_classvar',
'pydantic.typing:is_finalvar',
'pydantic.typing:is_literal_type',
'pydantic.typing:is_namedtuple',
'pydantic.typing:is_new_type',
'pydantic.typing:is_none_type',
'pydantic.typing:is_typeddict',
'pydantic.typing:is_typeddict_special',
'pydantic.typing:is_union',
'pydantic.typing:new_type_supertype',
'pydantic.typing:resolve_annotations',
'pydantic.typing:typing_base',
'pydantic.typing:update_field_forward_refs',
'pydantic.typing:update_model_forward_refs',
'pydantic.utils:ClassAttribute',
'pydantic.utils:DUNDER_ATTRIBUTES',
'pydantic.utils:PyObjectStr',
'pydantic.utils:ValueItems',
'pydantic.utils:almost_equal_floats',
'pydantic.utils:get_discriminator_alias_and_values',
'pydantic.utils:get_model',
'pydantic.utils:get_unique_discriminator_alias',
'pydantic.utils:in_ipython',
'pydantic.utils:is_valid_identifier',
'pydantic.utils:path_type',
'pydantic.utils:validate_field_name',
'pydantic:validate_model',
}
def getattr_migration(module: str) -> Callable[[str], Any]:
"""Implement PEP 562 for objects that were either moved or removed on the migration
to V2.
Args:
module: The module name.
Returns:
A callable that will raise an error if the object is not found.
"""
# This avoids circular import with errors.py.
from .errors import PydanticImportError
def wrapper(name: str) -> object:
"""Raise an error if the object is not found, or warn if it was moved.
In case it was moved, it still returns the object.
Args:
name: The object name.
Returns:
The object.
"""
if name == '__path__':
raise AttributeError(f'module {module!r} has no attribute {name!r}')
import warnings
from ._internal._validators import import_string
import_path = f'{module}:{name}'
if import_path in MOVED_IN_V2.keys():
new_location = MOVED_IN_V2[import_path]
warnings.warn(f'`{import_path}` has been moved to `{new_location}`.')
return import_string(MOVED_IN_V2[import_path])
if import_path in DEPRECATED_MOVED_IN_V2:
# skip the warning here because a deprecation warning will be raised elsewhere
return import_string(DEPRECATED_MOVED_IN_V2[import_path])
if import_path in REDIRECT_TO_V1:
new_location = REDIRECT_TO_V1[import_path]
warnings.warn(
f'`{import_path}` has been removed. We are importing from `{new_location}` instead.'
' See the migration guide for more details: https://docs.pydantic.dev/latest/migration/'
)
return import_string(REDIRECT_TO_V1[import_path])
if import_path == 'pydantic:BaseSettings':
raise PydanticImportError(
'`BaseSettings` has been moved to the `pydantic-settings` package. '
f'See https://docs.pydantic.dev/{version_short()}/migration/#basesettings-has-moved-to-pydantic-settings '
'for more details.'
)
if import_path in REMOVED_IN_V2:
raise PydanticImportError(f'`{import_path}` has been removed in V2.')
globals: Dict[str, Any] = sys.modules[module].__dict__
if name in globals:
return globals[name]
raise AttributeError(f'module {module!r} has no attribute {name!r}')
return wrapper
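# Hedged usage sketch (not part of the original module): V1-era shim modules such as
# pydantic/class_validators.py reduce to a PEP 562 hook built from this factory:
#
#     from ._migration import getattr_migration
#     __getattr__ = getattr_migration(__name__)
#
# The helper below (illustrative only) shows the wrapper resolving a moved name.
def _example_getattr_migration() -> None:
    wrapper = getattr_migration('pydantic.utils')
    version_info = wrapper('version_info')  # moved in V2, resolved from pydantic.version with a warning
    assert callable(version_info)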

View file

@ -0,0 +1,50 @@
"""Alias generators for converting between different capitalization conventions."""
import re
__all__ = ('to_pascal', 'to_camel', 'to_snake')
def to_pascal(snake: str) -> str:
"""Convert a snake_case string to PascalCase.
Args:
snake: The string to convert.
Returns:
The PascalCase string.
"""
camel = snake.title()
return re.sub('([0-9A-Za-z])_(?=[0-9A-Z])', lambda m: m.group(1), camel)
def to_camel(snake: str) -> str:
"""Convert a snake_case string to camelCase.
Args:
snake: The string to convert.
Returns:
The converted camelCase string.
"""
camel = to_pascal(snake)
return re.sub('(^_*[A-Z])', lambda m: m.group(1).lower(), camel)
def to_snake(camel: str) -> str:
"""Convert a PascalCase or camelCase string to snake_case.
Args:
camel: The string to convert.
Returns:
The converted string in snake_case.
"""
# Handle the sequence of uppercase letters followed by a lowercase letter
snake = re.sub(r'([A-Z]+)([A-Z][a-z])', lambda m: f'{m.group(1)}_{m.group(2)}', camel)
# Insert an underscore between a lowercase letter and an uppercase letter
snake = re.sub(r'([a-z])([A-Z])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
# Insert an underscore between a digit and an uppercase letter
snake = re.sub(r'([0-9])([A-Z])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
# Insert an underscore between a lowercase letter and a digit
snake = re.sub(r'([a-z])([0-9])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
return snake.lower()
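# Quick illustration (not part of the original module) of the three converters above.
def _example_alias_generators() -> None:
    assert to_pascal('snake_to_pascal') == 'SnakeToPascal'
    assert to_camel('snake_to_camel') == 'snakeToCamel'
    assert to_snake('HTTPResponse2xx') == 'http_response_2xx'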

112
lib/pydantic/aliases.py Normal file
View file

@ -0,0 +1,112 @@
"""Support for alias configurations."""
from __future__ import annotations
import dataclasses
from typing import Callable, Literal
from ._internal import _internal_dataclass
__all__ = ('AliasGenerator', 'AliasPath', 'AliasChoices')
@dataclasses.dataclass(**_internal_dataclass.slots_true)
class AliasPath:
"""Usage docs: https://docs.pydantic.dev/2.6/concepts/alias#aliaspath-and-aliaschoices
A data class used by `validation_alias` as a convenience to create aliases.
Attributes:
path: A list of string or integer aliases.
"""
path: list[int | str]
def __init__(self, first_arg: str, *args: str | int) -> None:
self.path = [first_arg] + list(args)
def convert_to_aliases(self) -> list[str | int]:
"""Converts arguments to a list of string or integer aliases.
Returns:
The list of aliases.
"""
return self.path
@dataclasses.dataclass(**_internal_dataclass.slots_true)
class AliasChoices:
"""Usage docs: https://docs.pydantic.dev/2.6/concepts/alias#aliaspath-and-aliaschoices
A data class used by `validation_alias` as a convenience to create aliases.
Attributes:
choices: A list containing a string or `AliasPath`.
"""
choices: list[str | AliasPath]
def __init__(self, first_choice: str | AliasPath, *choices: str | AliasPath) -> None:
self.choices = [first_choice] + list(choices)
def convert_to_aliases(self) -> list[list[str | int]]:
"""Converts arguments to a list of lists containing string or integer aliases.
Returns:
The list of aliases.
"""
aliases: list[list[str | int]] = []
for c in self.choices:
if isinstance(c, AliasPath):
aliases.append(c.convert_to_aliases())
else:
aliases.append([c])
return aliases
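# Minimal sketch (not part of the original module): how the two classes above expand
# into the alias lists consumed during validation.
def _example_alias_path_and_choices() -> None:
    assert AliasPath('names', 0).convert_to_aliases() == ['names', 0]
    assert AliasChoices('first_name', AliasPath('names', 0)).convert_to_aliases() == [
        ['first_name'],
        ['names', 0],
    ]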
@dataclasses.dataclass(**_internal_dataclass.slots_true)
class AliasGenerator:
"""Usage docs: https://docs.pydantic.dev/2.6/concepts/alias#using-an-aliasgenerator
A data class used by `alias_generator` as a convenience to create various aliases.
Attributes:
alias: A callable that takes a field name and returns an alias for it.
validation_alias: A callable that takes a field name and returns a validation alias for it.
serialization_alias: A callable that takes a field name and returns a serialization alias for it.
"""
alias: Callable[[str], str] | None = None
validation_alias: Callable[[str], str | AliasPath | AliasChoices] | None = None
serialization_alias: Callable[[str], str] | None = None
def _generate_alias(
self,
alias_kind: Literal['alias', 'validation_alias', 'serialization_alias'],
allowed_types: tuple[type[str] | type[AliasPath] | type[AliasChoices], ...],
field_name: str,
) -> str | AliasPath | AliasChoices | None:
"""Generate an alias of the specified kind. Returns None if the alias generator is None.
Raises:
TypeError: If the alias generator produces an invalid type.
"""
alias = None
if alias_generator := getattr(self, alias_kind):
alias = alias_generator(field_name)
if alias and not isinstance(alias, allowed_types):
raise TypeError(
f'Invalid `{alias_kind}` type. `{alias_kind}` generator must produce one of `{allowed_types}`'
)
return alias
def generate_aliases(self, field_name: str) -> tuple[str | None, str | AliasPath | AliasChoices | None, str | None]:
"""Generate `alias`, `validation_alias`, and `serialization_alias` for a field.
Returns:
A tuple of three aliases: alias, validation alias, and serialization alias, in that order.
"""
alias = self._generate_alias('alias', (str,), field_name)
validation_alias = self._generate_alias('validation_alias', (str, AliasChoices, AliasPath), field_name)
serialization_alias = self._generate_alias('serialization_alias', (str,), field_name)
return alias, validation_alias, serialization_alias # type: ignore
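# Hypothetical usage sketch (not part of the original module), assuming `to_camel`
# from pydantic.alias_generators is available.
def _example_alias_generator() -> None:
    from pydantic.alias_generators import to_camel

    gen = AliasGenerator(validation_alias=to_camel, serialization_alias=str.upper)
    # no plain `alias` callable was supplied, so the first element is None
    assert gen.generate_aliases('user_name') == (None, 'userName', 'USER_NAME')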

View file

@ -0,0 +1,120 @@
"""Type annotations to use with `__get_pydantic_core_schema__` and `__get_pydantic_json_schema__`."""
from __future__ import annotations as _annotations
from typing import TYPE_CHECKING, Any, Union
from pydantic_core import core_schema
if TYPE_CHECKING:
from .json_schema import JsonSchemaMode, JsonSchemaValue
CoreSchemaOrField = Union[
core_schema.CoreSchema,
core_schema.ModelField,
core_schema.DataclassField,
core_schema.TypedDictField,
core_schema.ComputedField,
]
__all__ = 'GetJsonSchemaHandler', 'GetCoreSchemaHandler'
class GetJsonSchemaHandler:
"""Handler to call into the next JSON schema generation function.
Attributes:
mode: Json schema mode, can be `validation` or `serialization`.
"""
mode: JsonSchemaMode
def __call__(self, __core_schema: CoreSchemaOrField) -> JsonSchemaValue:
"""Call the inner handler and get the JsonSchemaValue it returns.
This will call the next JSON schema modifying function up until it calls
into `pydantic.json_schema.GenerateJsonSchema`, which will raise a
`pydantic.errors.PydanticInvalidForJsonSchema` error if it cannot generate
a JSON schema.
Args:
__core_schema: A `pydantic_core.core_schema.CoreSchema`.
Returns:
JsonSchemaValue: The JSON schema generated by the inner JSON schema modify
functions.
"""
raise NotImplementedError
def resolve_ref_schema(self, __maybe_ref_json_schema: JsonSchemaValue) -> JsonSchemaValue:
"""Get the real schema for a `{"$ref": ...}` schema.
If the schema given is not a `$ref` schema, it will be returned as is.
This means you don't have to check before calling this function.
Args:
__maybe_ref_json_schema: A JsonSchemaValue which may be a `$ref` schema.
Raises:
LookupError: If the ref is not found.
Returns:
JsonSchemaValue: A JsonSchemaValue that has no `$ref`.
"""
raise NotImplementedError
class GetCoreSchemaHandler:
"""Handler to call into the next CoreSchema schema generation function."""
def __call__(self, __source_type: Any) -> core_schema.CoreSchema:
"""Call the inner handler and get the CoreSchema it returns.
This will call the next CoreSchema modifying function up until it calls
into Pydantic's internal schema generation machinery, which will raise a
`pydantic.errors.PydanticSchemaGenerationError` error if it cannot generate
a CoreSchema for the given source type.
Args:
__source_type: The input type.
Returns:
CoreSchema: The `pydantic-core` CoreSchema generated.
"""
raise NotImplementedError
def generate_schema(self, __source_type: Any) -> core_schema.CoreSchema:
"""Generate a schema unrelated to the current context.
Use this function if e.g. you are handling schema generation for a sequence
and want to generate a schema for its items.
Otherwise, you may end up doing something like applying a `min_length` constraint
that was intended for the sequence itself to its items!
Args:
__source_type: The input type.
Returns:
CoreSchema: The `pydantic-core` CoreSchema generated.
"""
raise NotImplementedError
def resolve_ref_schema(self, __maybe_ref_schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
"""Get the real schema for a `definition-ref` schema.
If the schema given is not a `definition-ref` schema, it will be returned as is.
This means you don't have to check before calling this function.
Args:
__maybe_ref_schema: A `CoreSchema`, `ref`-based or not.
Raises:
LookupError: If the `ref` is not found.
Returns:
A concrete `CoreSchema`.
"""
raise NotImplementedError
@property
def field_name(self) -> str | None:
"""Get the name of the closest field to this validator."""
raise NotImplementedError
def _get_types_namespace(self) -> dict[str, Any] | None:
"""Internal method used during type resolution for serializer annotations."""
raise NotImplementedError
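# Hedged sketch (not part of the original module): a metadata object that delegates to
# the handler and then post-processes the schema; all names below are illustrative.
def _example_get_core_schema_handler() -> None:
    from typing import Annotated

    from pydantic import BaseModel

    class Lowercase:
        @classmethod
        def __get_pydantic_core_schema__(
            cls, source_type: Any, handler: GetCoreSchemaHandler
        ) -> core_schema.CoreSchema:
            inner = handler(source_type)  # schema the next handler in the chain produces
            return core_schema.no_info_after_validator_function(str.lower, inner)

    class Model(BaseModel):
        name: Annotated[str, Lowercase()]

    assert Model(name='ABC').name == 'abc'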

View file

@ -1,342 +1,4 @@
"""`class_validators` module is a backport module from V1."""
from ._migration import getattr_migration
__getattr__ = getattr_migration(__name__)
import warnings
from collections import ChainMap
from functools import wraps
from itertools import chain
from types import FunctionType
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, overload
from .errors import ConfigError
from .typing import AnyCallable
from .utils import ROOT_KEY, in_ipython
if TYPE_CHECKING:
from .typing import AnyClassMethod
class Validator:
__slots__ = 'func', 'pre', 'each_item', 'always', 'check_fields', 'skip_on_failure'
def __init__(
self,
func: AnyCallable,
pre: bool = False,
each_item: bool = False,
always: bool = False,
check_fields: bool = False,
skip_on_failure: bool = False,
):
self.func = func
self.pre = pre
self.each_item = each_item
self.always = always
self.check_fields = check_fields
self.skip_on_failure = skip_on_failure
if TYPE_CHECKING:
from inspect import Signature
from .config import BaseConfig
from .fields import ModelField
from .types import ModelOrDc
ValidatorCallable = Callable[[Optional[ModelOrDc], Any, Dict[str, Any], ModelField, Type[BaseConfig]], Any]
ValidatorsList = List[ValidatorCallable]
ValidatorListDict = Dict[str, List[Validator]]
_FUNCS: Set[str] = set()
VALIDATOR_CONFIG_KEY = '__validator_config__'
ROOT_VALIDATOR_CONFIG_KEY = '__root_validator_config__'
def validator(
*fields: str,
pre: bool = False,
each_item: bool = False,
always: bool = False,
check_fields: bool = True,
whole: bool = None,
allow_reuse: bool = False,
) -> Callable[[AnyCallable], 'AnyClassMethod']:
"""
Decorate methods on the class indicating that they should be used to validate fields
:param fields: which field(s) the method should be called on
:param pre: whether or not this validator should be called before the standard validators (else after)
:param each_item: for complex objects (sets, lists etc.) whether to validate individual elements rather than the
whole object
:param always: whether this method and other validators should be called even if the value is missing
:param check_fields: whether to check that the fields actually exist on the model
:param allow_reuse: whether to track and raise an error if another validator refers to the decorated function
"""
if not fields:
raise ConfigError('validator with no fields specified')
elif isinstance(fields[0], FunctionType):
raise ConfigError(
"validators should be used with fields and keyword arguments, not bare. " # noqa: Q000
"E.g. usage should be `@validator('<field_name>', ...)`"
)
elif not all(isinstance(field, str) for field in fields):
raise ConfigError(
"validator fields should be passed as separate string args. " # noqa: Q000
"E.g. usage should be `@validator('<field_name_1>', '<field_name_2>', ...)`"
)
if whole is not None:
warnings.warn(
'The "whole" keyword argument is deprecated, use "each_item" (inverse meaning, default False) instead',
DeprecationWarning,
)
assert each_item is False, '"each_item" and "whole" conflict, remove "whole"'
each_item = not whole
def dec(f: AnyCallable) -> 'AnyClassMethod':
f_cls = _prepare_validator(f, allow_reuse)
setattr(
f_cls,
VALIDATOR_CONFIG_KEY,
(
fields,
Validator(func=f_cls.__func__, pre=pre, each_item=each_item, always=always, check_fields=check_fields),
),
)
return f_cls
return dec
@overload
def root_validator(_func: AnyCallable) -> 'AnyClassMethod':
...
@overload
def root_validator(
*, pre: bool = False, allow_reuse: bool = False, skip_on_failure: bool = False
) -> Callable[[AnyCallable], 'AnyClassMethod']:
...
def root_validator(
_func: Optional[AnyCallable] = None, *, pre: bool = False, allow_reuse: bool = False, skip_on_failure: bool = False
) -> Union['AnyClassMethod', Callable[[AnyCallable], 'AnyClassMethod']]:
"""
Decorate methods on a model indicating that they should be used to validate (and perhaps modify) data either
before or after standard model parsing/validation is performed.
"""
if _func:
f_cls = _prepare_validator(_func, allow_reuse)
setattr(
f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure)
)
return f_cls
def dec(f: AnyCallable) -> 'AnyClassMethod':
f_cls = _prepare_validator(f, allow_reuse)
setattr(
f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure)
)
return f_cls
return dec
def _prepare_validator(function: AnyCallable, allow_reuse: bool) -> 'AnyClassMethod':
"""
Avoid validators with duplicated names: without this check, validators could silently overwrite one another,
which is generally not the intended behaviour. The check is skipped when running in IPython (see #312)
or when `allow_reuse` is True.
"""
f_cls = function if isinstance(function, classmethod) else classmethod(function)
if not in_ipython() and not allow_reuse:
ref = f_cls.__func__.__module__ + '.' + f_cls.__func__.__qualname__
if ref in _FUNCS:
raise ConfigError(f'duplicate validator function "{ref}"; if this is intended, set `allow_reuse=True`')
_FUNCS.add(ref)
return f_cls
class ValidatorGroup:
def __init__(self, validators: 'ValidatorListDict') -> None:
self.validators = validators
self.used_validators = {'*'}
def get_validators(self, name: str) -> Optional[Dict[str, Validator]]:
self.used_validators.add(name)
validators = self.validators.get(name, [])
if name != ROOT_KEY:
validators += self.validators.get('*', [])
if validators:
return {v.func.__name__: v for v in validators}
else:
return None
def check_for_unused(self) -> None:
unused_validators = set(
chain.from_iterable(
(v.func.__name__ for v in self.validators[f] if v.check_fields)
for f in (self.validators.keys() - self.used_validators)
)
)
if unused_validators:
fn = ', '.join(unused_validators)
raise ConfigError(
f"Validators defined with incorrect fields: {fn} " # noqa: Q000
f"(use check_fields=False if you're inheriting from the model and intended this)"
)
def extract_validators(namespace: Dict[str, Any]) -> Dict[str, List[Validator]]:
validators: Dict[str, List[Validator]] = {}
for var_name, value in namespace.items():
validator_config = getattr(value, VALIDATOR_CONFIG_KEY, None)
if validator_config:
fields, v = validator_config
for field in fields:
if field in validators:
validators[field].append(v)
else:
validators[field] = [v]
return validators
def extract_root_validators(namespace: Dict[str, Any]) -> Tuple[List[AnyCallable], List[Tuple[bool, AnyCallable]]]:
from inspect import signature
pre_validators: List[AnyCallable] = []
post_validators: List[Tuple[bool, AnyCallable]] = []
for name, value in namespace.items():
validator_config: Optional[Validator] = getattr(value, ROOT_VALIDATOR_CONFIG_KEY, None)
if validator_config:
sig = signature(validator_config.func)
args = list(sig.parameters.keys())
if args[0] == 'self':
raise ConfigError(
f'Invalid signature for root validator {name}: {sig}, "self" not permitted as first argument, '
f'should be: (cls, values).'
)
if len(args) != 2:
raise ConfigError(f'Invalid signature for root validator {name}: {sig}, should be: (cls, values).')
# check function signature
if validator_config.pre:
pre_validators.append(validator_config.func)
else:
post_validators.append((validator_config.skip_on_failure, validator_config.func))
return pre_validators, post_validators
def inherit_validators(base_validators: 'ValidatorListDict', validators: 'ValidatorListDict') -> 'ValidatorListDict':
for field, field_validators in base_validators.items():
if field not in validators:
validators[field] = []
validators[field] += field_validators
return validators
def make_generic_validator(validator: AnyCallable) -> 'ValidatorCallable':
"""
Make a generic function which calls a validator with the right arguments.
Unfortunately other approaches (eg. returning a partial of a function that builds the arguments) are slow,
hence this laborious way of doing things.
It's done like this so validators don't all need **kwargs in their signature, eg. any combination of
the arguments "values", "fields" and/or "config" are permitted.
"""
from inspect import signature
sig = signature(validator)
args = list(sig.parameters.keys())
first_arg = args.pop(0)
if first_arg == 'self':
raise ConfigError(
f'Invalid signature for validator {validator}: {sig}, "self" not permitted as first argument, '
f'should be: (cls, value, values, config, field), "values", "config" and "field" are all optional.'
)
elif first_arg == 'cls':
# assume the second argument is value
return wraps(validator)(_generic_validator_cls(validator, sig, set(args[1:])))
else:
# assume the first argument was value which has already been removed
return wraps(validator)(_generic_validator_basic(validator, sig, set(args)))
def prep_validators(v_funcs: Iterable[AnyCallable]) -> 'ValidatorsList':
return [make_generic_validator(f) for f in v_funcs if f]
all_kwargs = {'values', 'field', 'config'}
def _generic_validator_cls(validator: AnyCallable, sig: 'Signature', args: Set[str]) -> 'ValidatorCallable':
# assume the first argument is value
has_kwargs = False
if 'kwargs' in args:
has_kwargs = True
args -= {'kwargs'}
if not args.issubset(all_kwargs):
raise ConfigError(
f'Invalid signature for validator {validator}: {sig}, should be: '
f'(cls, value, values, config, field), "values", "config" and "field" are all optional.'
)
if has_kwargs:
return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field, config=config)
elif args == set():
return lambda cls, v, values, field, config: validator(cls, v)
elif args == {'values'}:
return lambda cls, v, values, field, config: validator(cls, v, values=values)
elif args == {'field'}:
return lambda cls, v, values, field, config: validator(cls, v, field=field)
elif args == {'config'}:
return lambda cls, v, values, field, config: validator(cls, v, config=config)
elif args == {'values', 'field'}:
return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field)
elif args == {'values', 'config'}:
return lambda cls, v, values, field, config: validator(cls, v, values=values, config=config)
elif args == {'field', 'config'}:
return lambda cls, v, values, field, config: validator(cls, v, field=field, config=config)
else:
# args == {'values', 'field', 'config'}
return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field, config=config)
def _generic_validator_basic(validator: AnyCallable, sig: 'Signature', args: Set[str]) -> 'ValidatorCallable':
has_kwargs = False
if 'kwargs' in args:
has_kwargs = True
args -= {'kwargs'}
if not args.issubset(all_kwargs):
raise ConfigError(
f'Invalid signature for validator {validator}: {sig}, should be: '
f'(value, values, config, field), "values", "config" and "field" are all optional.'
)
if has_kwargs:
return lambda cls, v, values, field, config: validator(v, values=values, field=field, config=config)
elif args == set():
return lambda cls, v, values, field, config: validator(v)
elif args == {'values'}:
return lambda cls, v, values, field, config: validator(v, values=values)
elif args == {'field'}:
return lambda cls, v, values, field, config: validator(v, field=field)
elif args == {'config'}:
return lambda cls, v, values, field, config: validator(v, config=config)
elif args == {'values', 'field'}:
return lambda cls, v, values, field, config: validator(v, values=values, field=field)
elif args == {'values', 'config'}:
return lambda cls, v, values, field, config: validator(v, values=values, config=config)
elif args == {'field', 'config'}:
return lambda cls, v, values, field, config: validator(v, field=field, config=config)
else:
# args == {'values', 'field', 'config'}
return lambda cls, v, values, field, config: validator(v, values=values, field=field, config=config)
def gather_all_validators(type_: 'ModelOrDc') -> Dict[str, 'AnyClassMethod']:
all_attributes = ChainMap(*[cls.__dict__ for cls in type_.__mro__]) # type: ignore[arg-type,var-annotated]
return {
k: v
for k, v in all_attributes.items()
if hasattr(v, VALIDATOR_CONFIG_KEY) or hasattr(v, ROOT_VALIDATOR_CONFIG_KEY)
}

View file

@ -1,22 +1,28 @@
""" """Color definitions are used as per the CSS3
Color definitions are used as per CSS3 specification: [CSS Color Module Level 3](http://www.w3.org/TR/css3-color/#svg-color) specification.
http://www.w3.org/TR/css3-color/#svg-color
A few colors have multiple names referring to the sames colors, eg. `grey` and `gray` or `aqua` and `cyan`. A few colors have multiple names referring to the sames colors, eg. `grey` and `gray` or `aqua` and `cyan`.
In these cases the LAST color when sorted alphabetically takes preferences, In these cases the _last_ color when sorted alphabetically takes preferences,
eg. Color((0, 255, 255)).as_named() == 'cyan' because "cyan" comes after "aqua". eg. `Color((0, 255, 255)).as_named() == 'cyan'` because "cyan" comes after "aqua".
Warning: Deprecated
The `Color` class is deprecated, use `pydantic_extra_types` instead.
See [`pydantic-extra-types.Color`](../usage/types/extra_types/color_types.md)
for more information.
""" """
import math import math
import re import re
from colorsys import hls_to_rgb, rgb_to_hls from colorsys import hls_to_rgb, rgb_to_hls
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast from typing import Any, Callable, Optional, Tuple, Type, Union, cast
from .errors import ColorError from pydantic_core import CoreSchema, PydanticCustomError, core_schema
from .utils import Representation, almost_equal_floats from typing_extensions import deprecated
if TYPE_CHECKING: from ._internal import _repr
from .typing import CallableGenerator, ReprArgs from ._internal._schema_generation_shared import GetJsonSchemaHandler as _GetJsonSchemaHandler
from .json_schema import JsonSchemaValue
from .warnings import PydanticDeprecatedSince20
ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, float]] ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, float]]
ColorType = Union[ColorTuple, str] ColorType = Union[ColorTuple, str]
@ -24,9 +30,7 @@ HslColorTuple = Union[Tuple[float, float, float], Tuple[float, float, float, flo
class RGBA:
"""Internal use only as a representation of a color."""
__slots__ = 'r', 'g', 'b', 'alpha', '_tuple'
@ -43,24 +47,35 @@ class RGBA:
# these are not compiled here to avoid import slowdown, they'll be compiled the first time they're used, then cached # these are not compiled here to avoid import slowdown, they'll be compiled the first time they're used, then cached
r_hex_short = r'\s*(?:#|0x)?([0-9a-f])([0-9a-f])([0-9a-f])([0-9a-f])?\s*'
r_hex_long = r'\s*(?:#|0x)?([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})?\s*'
_r_255 = r'(\d{1,3}(?:\.\d+)?)' _r_255 = r'(\d{1,3}(?:\.\d+)?)'
_r_comma = r'\s*,\s*' _r_comma = r'\s*,\s*'
r_rgb = fr'\s*rgb\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}\)\s*'
_r_alpha = r'(\d(?:\.\d+)?|\.\d+|\d{1,2}%)' _r_alpha = r'(\d(?:\.\d+)?|\.\d+|\d{1,2}%)'
r_rgba = fr'\s*rgba\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_alpha}\s*\)\s*'
_r_h = r'(-?\d+(?:\.\d+)?|-?\.\d+)(deg|rad|turn)?' _r_h = r'(-?\d+(?:\.\d+)?|-?\.\d+)(deg|rad|turn)?'
_r_sl = r'(\d{1,3}(?:\.\d+)?)%' _r_sl = r'(\d{1,3}(?:\.\d+)?)%'
r_hsl = fr'\s*hsl\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}\s*\)\s*' r_hex_short = r'\s*(?:#|0x)?([0-9a-f])([0-9a-f])([0-9a-f])([0-9a-f])?\s*'
r_hsla = fr'\s*hsl\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}{_r_comma}{_r_alpha}\s*\)\s*' r_hex_long = r'\s*(?:#|0x)?([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})?\s*'
# CSS3 RGB examples: rgb(0, 0, 0), rgba(0, 0, 0, 0.5), rgba(0, 0, 0, 50%)
r_rgb = rf'\s*rgba?\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}(?:{_r_comma}{_r_alpha})?\s*\)\s*'
# CSS3 HSL examples: hsl(270, 60%, 50%), hsla(270, 60%, 50%, 0.5), hsla(270, 60%, 50%, 50%)
r_hsl = rf'\s*hsla?\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}(?:{_r_comma}{_r_alpha})?\s*\)\s*'
# CSS4 RGB examples: rgb(0 0 0), rgb(0 0 0 / 0.5), rgb(0 0 0 / 50%), rgba(0 0 0 / 50%)
r_rgb_v4_style = rf'\s*rgba?\(\s*{_r_255}\s+{_r_255}\s+{_r_255}(?:\s*/\s*{_r_alpha})?\s*\)\s*'
# CSS4 HSL examples: hsl(270 60% 50%), hsl(270 60% 50% / 0.5), hsl(270 60% 50% / 50%), hsla(270 60% 50% / 50%)
r_hsl_v4_style = rf'\s*hsla?\(\s*{_r_h}\s+{_r_sl}\s+{_r_sl}(?:\s*/\s*{_r_alpha})?\s*\)\s*'
# colors where the two hex characters are the same, if all colors match this the short version of hex colors can be used # colors where the two hex characters are the same, if all colors match this the short version of hex colors can be used
repeat_colors = {int(c * 2, 16) for c in '0123456789abcdef'} repeat_colors = {int(c * 2, 16) for c in '0123456789abcdef'}
rads = 2 * math.pi rads = 2 * math.pi
class Color(Representation): @deprecated(
'The `Color` class is deprecated, use `pydantic_extra_types` instead. '
'See https://docs.pydantic.dev/latest/api/pydantic_extra_types_color/.',
category=PydanticDeprecatedSince20,
)
class Color(_repr.Representation):
"""Represents a color."""
__slots__ = '_original', '_rgba' __slots__ = '_original', '_rgba'
def __init__(self, value: ColorType) -> None: def __init__(self, value: ColorType) -> None:
@ -74,22 +89,39 @@ class Color(Representation):
self._rgba = value._rgba self._rgba = value._rgba
value = value._original value = value._original
else: else:
raise ColorError(reason='value must be a tuple, list or string') raise PydanticCustomError(
'color_error', 'value is not a valid color: value must be a tuple, list or string'
)
# if we've got here value must be a valid color # if we've got here value must be a valid color
self._original = value self._original = value
@classmethod @classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: def __get_pydantic_json_schema__(
cls, core_schema: core_schema.CoreSchema, handler: _GetJsonSchemaHandler
) -> JsonSchemaValue:
field_schema = {}
field_schema.update(type='string', format='color') field_schema.update(type='string', format='color')
return field_schema
def original(self) -> ColorType: def original(self) -> ColorType:
""" """Original value passed to `Color`."""
Original value passed to Color
"""
return self._original return self._original
def as_named(self, *, fallback: bool = False) -> str: def as_named(self, *, fallback: bool = False) -> str:
"""Returns the name of the color if it can be found in `COLORS_BY_VALUE` dictionary,
otherwise returns the hexadecimal representation of the color or raises `ValueError`.
Args:
fallback: If True, falls back to returning the hexadecimal representation of
the color instead of raising a ValueError when no named color is found.
Returns:
The name of the color, or the hexadecimal representation of the color.
Raises:
ValueError: When no named color is found and fallback is `False`.
"""
if self._rgba.alpha is None: if self._rgba.alpha is None:
rgb = cast(Tuple[int, int, int], self.as_rgb_tuple()) rgb = cast(Tuple[int, int, int], self.as_rgb_tuple())
try: try:
@ -103,9 +135,13 @@ class Color(Representation):
return self.as_hex() return self.as_hex()
def as_hex(self) -> str: def as_hex(self) -> str:
""" """Returns the hexadecimal representation of the color.
Hex string representing the color can be 3, 4, 6 or 8 characters depending on whether the string
Hex string representing the color can be 3, 4, 6, or 8 characters depending on whether the string
a "short" representation of the color is possible and whether there's an alpha channel. a "short" representation of the color is possible and whether there's an alpha channel.
Returns:
The hexadecimal representation of the color.
""" """
values = [float_to_255(c) for c in self._rgba[:3]] values = [float_to_255(c) for c in self._rgba[:3]]
if self._rgba.alpha is not None: if self._rgba.alpha is not None:
@ -117,9 +153,7 @@ class Color(Representation):
return '#' + as_hex return '#' + as_hex
def as_rgb(self) -> str: def as_rgb(self) -> str:
""" """Color as an `rgb(<r>, <g>, <b>)` or `rgba(<r>, <g>, <b>, <a>)` string."""
Color as an rgb(<r>, <g>, <b>) or rgba(<r>, <g>, <b>, <a>) string.
"""
if self._rgba.alpha is None: if self._rgba.alpha is None:
return f'rgb({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)})' return f'rgb({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)})'
else: else:
@ -129,14 +163,18 @@ class Color(Representation):
) )
def as_rgb_tuple(self, *, alpha: Optional[bool] = None) -> ColorTuple: def as_rgb_tuple(self, *, alpha: Optional[bool] = None) -> ColorTuple:
""" """Returns the color as an RGB or RGBA tuple.
Color as an RGB or RGBA tuple; red, green and blue are in the range 0 to 255, alpha if included is
in the range 0 to 1.
:param alpha: whether to include the alpha channel, options are Args:
None - (default) include alpha only if it's set (e.g. not None) alpha: Whether to include the alpha channel. There are three options for this input:
True - always include alpha,
False - always omit alpha, - `None` (default): Include alpha only if it's set. (e.g. not `None`)
- `True`: Always include alpha.
- `False`: Always omit alpha.
Returns:
A tuple that contains the values of the red, green, and blue channels in the range 0 to 255.
If alpha is included, it is in the range 0 to 1.
""" """
r, g, b = (float_to_255(c) for c in self._rgba[:3]) r, g, b = (float_to_255(c) for c in self._rgba[:3])
if alpha is None: if alpha is None:
@ -151,9 +189,7 @@ class Color(Representation):
return r, g, b return r, g, b
def as_hsl(self) -> str: def as_hsl(self) -> str:
""" """Color as an `hsl(<h>, <s>, <l>)` or `hsl(<h>, <s>, <l>, <a>)` string."""
Color as an hsl(<h>, <s>, <l>) or hsl(<h>, <s>, <l>, <a>) string.
"""
if self._rgba.alpha is None: if self._rgba.alpha is None:
h, s, li = self.as_hsl_tuple(alpha=False) # type: ignore h, s, li = self.as_hsl_tuple(alpha=False) # type: ignore
return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%})' return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%})'
@@ -162,18 +198,23 @@ class Color(Representation):
            return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%}, {round(a, 2)})'

    def as_hsl_tuple(self, *, alpha: Optional[bool] = None) -> HslColorTuple:
        """Returns the color as an HSL or HSLA tuple.

        Args:
            alpha: Whether to include the alpha channel.

                - `None` (default): Include the alpha channel only if it's set (e.g. not `None`).
                - `True`: Always include alpha.
                - `False`: Always omit alpha.

        Returns:
            The color as a tuple of hue, saturation, lightness, and alpha (if included).
            All elements are in the range 0 to 1.

        Note:
            This is HSL as used in HTML and most other places, not HLS as used in Python's `colorsys`.
        """
        h, l, s = rgb_to_hls(self._rgba.r, self._rgba.g, self._rgba.b)  # noqa: E741
        if alpha is None:
            if self._rgba.alpha is None:
                return h, s, l
@@ -189,14 +230,22 @@ class Color(Representation):
        return 1 if self._rgba.alpha is None else self._rgba.alpha

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source: Type[Any], handler: Callable[[Any], CoreSchema]
    ) -> core_schema.CoreSchema:
        return core_schema.with_info_plain_validator_function(
            cls._validate, serialization=core_schema.to_string_ser_schema()
        )

    @classmethod
    def _validate(cls, __input_value: Any, _: Any) -> 'Color':
        return cls(__input_value)

    def __str__(self) -> str:
        return self.as_named(fallback=True)

    def __repr_args__(self) -> '_repr.ReprArgs':
        return [(None, self.as_named(fallback=True))] + [('rgb', self.as_rgb_tuple())]

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, Color) and self.as_rgb_tuple() == other.as_rgb_tuple()
@@ -206,8 +255,16 @@ class Color(Representation):

def parse_tuple(value: Tuple[Any, ...]) -> RGBA:
    """Parse a tuple or list to get RGBA values.

    Args:
        value: A tuple or list.

    Returns:
        An `RGBA` tuple parsed from the input tuple.

    Raises:
        PydanticCustomError: If tuple is not valid.
    """
    if len(value) == 3:
        r, g, b = (parse_color_value(v) for v in value)
@@ -216,17 +273,28 @@ def parse_tuple(value: Tuple[Any, ...]) -> RGBA:
        r, g, b = (parse_color_value(v) for v in value[:3])
        return RGBA(r, g, b, parse_float_alpha(value[3]))
    else:
        raise PydanticCustomError('color_error', 'value is not a valid color: tuples must have length 3 or 4')


def parse_str(value: str) -> RGBA:
    """Parse a string representing a color to an RGBA tuple.

    Possible formats for the input string include:

    * named color, see `COLORS_BY_NAME`
    * hex short eg. `<prefix>fff` (prefix can be `#`, `0x` or nothing)
    * hex long eg. `<prefix>ffffff` (prefix can be `#`, `0x` or nothing)
    * `rgb(<r>, <g>, <b>)`
    * `rgba(<r>, <g>, <b>, <a>)`

    Args:
        value: A string representing a color.

    Returns:
        An `RGBA` tuple parsed from the input string.

    Raises:
        ValueError: If the input string cannot be parsed to an RGBA tuple.
    """
    value_lower = value.lower()
    try:
@@ -256,49 +324,70 @@ def parse_str(value: str) -> RGBA:
            alpha = None
        return ints_to_rgba(r, g, b, alpha)

    m = re.fullmatch(r_rgb, value_lower) or re.fullmatch(r_rgb_v4_style, value_lower)
    if m:
        return ints_to_rgba(*m.groups())  # type: ignore

    m = re.fullmatch(r_hsl, value_lower) or re.fullmatch(r_hsl_v4_style, value_lower)
    if m:
        return parse_hsl(*m.groups())  # type: ignore

    raise PydanticCustomError('color_error', 'value is not a valid color: string not recognised as a valid color')
def ints_to_rgba(r: Union[int, str], g: Union[int, str], b: Union[int, str], alpha: Optional[float] = None) -> RGBA:
    """Converts integer or string values for RGB color and an optional alpha value to an `RGBA` object.

    Args:
        r: An integer or string representing the red color value.
        g: An integer or string representing the green color value.
        b: An integer or string representing the blue color value.
        alpha: A float representing the alpha value. Defaults to None.

    Returns:
        An instance of the `RGBA` class with the corresponding color and alpha values.
    """
    return RGBA(parse_color_value(r), parse_color_value(g), parse_color_value(b), parse_float_alpha(alpha))


def parse_color_value(value: Union[int, str], max_val: int = 255) -> float:
    """Parse the color value provided and return a number between 0 and 1.

    Args:
        value: An integer or string color value.
        max_val: Maximum range value. Defaults to 255.

    Raises:
        PydanticCustomError: If the value is not a valid color.

    Returns:
        A number between 0 and 1.
    """
    try:
        color = float(value)
    except ValueError:
        raise PydanticCustomError('color_error', 'value is not a valid color: color values must be a valid number')
    if 0 <= color <= max_val:
        return color / max_val
    else:
        raise PydanticCustomError(
            'color_error',
            'value is not a valid color: color values must be in the range 0 to {max_val}',
            {'max_val': max_val},
        )


def parse_float_alpha(value: Union[None, str, float, int]) -> Optional[float]:
    """Parse an alpha value checking it's a valid float in the range 0 to 1.

    Args:
        value: The input value to parse.

    Returns:
        The parsed value as a float, or `None` if the value was None or equal 1.

    Raises:
        PydanticCustomError: If the input value cannot be successfully parsed as a float in the expected range.
    """
    if value is None:
        return None
@@ -308,19 +397,28 @@ def parse_float_alpha(value: Union[None, str, float, int]) -> Optional[float]:
        else:
            alpha = float(value)
    except ValueError:
        raise PydanticCustomError('color_error', 'value is not a valid color: alpha values must be a valid float')

    if math.isclose(alpha, 1):
        return None
    elif 0 <= alpha <= 1:
        return alpha
    else:
        raise PydanticCustomError('color_error', 'value is not a valid color: alpha values must be in the range 0 to 1')
def parse_hsl(h: str, h_units: str, sat: str, light: str, alpha: Optional[float] = None) -> RGBA:
    """Parse raw hue, saturation, lightness, and alpha values and convert to RGBA.

    Args:
        h: The hue value.
        h_units: The unit for hue value.
        sat: The saturation value.
        light: The lightness value.
        alpha: Alpha value.

    Returns:
        An instance of `RGBA`.
    """
    s_value, l_value = parse_color_value(sat, 100), parse_color_value(light, 100)
@@ -334,10 +432,21 @@ def parse_hsl(h: str, h_units: str, sat: str, light: str, alpha: Optional[float]
    h_value = h_value % 1
    r, g, b = hls_to_rgb(h_value, l_value, s_value)
    return RGBA(r, g, b, parse_float_alpha(alpha))


def float_to_255(c: float) -> int:
    """Converts a float value between 0 and 1 (inclusive) to an integer between 0 and 255 (inclusive).

    Args:
        c: The float value to be converted. Must be between 0 and 1 (inclusive).

    Returns:
        The integer equivalent of the given float value rounded to the nearest whole number.

    Raises:
        ValueError: If the given float value is outside the acceptable range of 0 to 1 (inclusive).
    """
    return int(round(c * 255))
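The conversion helpers above all round-trip through the internal `RGBA` representation. A minimal usage sketch (not part of the diff; it assumes the vendored module is importable as `pydantic.color`):

```py
from pydantic.color import Color  # the module shown in the diff above

c = Color('#ff8000')
assert c.as_rgb_tuple() == (255, 128, 0)   # no alpha was set, so none is included
assert c.as_rgb() == 'rgb(255, 128, 0)'
assert c.as_hex() == '#ff8000'

# A "short" hex form is produced when every channel collapses to a repeated digit:
assert Color('red').as_hex() == '#f00'
```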

File diff suppressed because it is too large
@@ -1,479 +1,327 @@
""" """Provide an enhanced dataclass that performs validation."""
The main purpose is to enhance stdlib dataclasses by adding validation from __future__ import annotations as _annotations
A pydantic dataclass can be generated from scratch or from a stdlib one.
Behind the scene, a pydantic dataclass is just like a regular one on which we attach import dataclasses
a `BaseModel` and magic methods to trigger the validation of the data.
`__init__` and `__post_init__` are hence overridden and have extra logic to be
able to validate input data.
When a pydantic dataclass is generated from scratch, it's just a plain dataclass
with validation triggered at initialization
The tricky part if for stdlib dataclasses that are converted after into pydantic ones e.g.
```py
@dataclasses.dataclass
class M:
x: int
ValidatedM = pydantic.dataclasses.dataclass(M)
```
We indeed still want to support equality, hashing, repr, ... as if it was the stdlib one!
```py
assert isinstance(ValidatedM(x=1), M)
assert ValidatedM(x=1) == M(x=1)
```
This means we **don't want to create a new dataclass that inherits from it**
The trick is to create a wrapper around `M` that will act as a proxy to trigger
validation without altering default `M` behaviour.
"""
import sys import sys
from contextlib import contextmanager import types
from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Generic, NoReturn, TypeVar, overload
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Dict,
Generator,
Optional,
Set,
Type,
TypeVar,
Union,
overload,
)
from typing_extensions import dataclass_transform from typing_extensions import Literal, TypeGuard, dataclass_transform
from .class_validators import gather_all_validators from ._internal import _config, _decorators, _typing_extra
from .config import BaseConfig, ConfigDict, Extra, get_config from ._internal import _dataclasses as _pydantic_dataclasses
from .error_wrappers import ValidationError from ._migration import getattr_migration
from .errors import DataclassTypeError from .config import ConfigDict
from .fields import Field, FieldInfo, Required, Undefined from .fields import Field, FieldInfo
from .main import create_model, validate_model
from .utils import ClassAttribute
if TYPE_CHECKING: if TYPE_CHECKING:
from .main import BaseModel from ._internal._dataclasses import PydanticDataclass
from .typing import CallableGenerator, NoArgAnyCallable
DataclassT = TypeVar('DataclassT', bound='Dataclass') __all__ = 'dataclass', 'rebuild_dataclass'
DataclassClassOrWrapper = Union[Type['Dataclass'], 'DataclassProxy']
class Dataclass:
# stdlib attributes
__dataclass_fields__: ClassVar[Dict[str, Any]]
__dataclass_params__: ClassVar[Any] # in reality `dataclasses._DataclassParams`
__post_init__: ClassVar[Callable[..., None]]
# Added by pydantic
__pydantic_run_validation__: ClassVar[bool]
__post_init_post_parse__: ClassVar[Callable[..., None]]
__pydantic_initialised__: ClassVar[bool]
__pydantic_model__: ClassVar[Type[BaseModel]]
__pydantic_validate_values__: ClassVar[Callable[['Dataclass'], None]]
__pydantic_has_field_info_default__: ClassVar[bool] # whether a `pydantic.Field` is used as default value
def __init__(self, *args: object, **kwargs: object) -> None:
pass
@classmethod
def __get_validators__(cls: Type['Dataclass']) -> 'CallableGenerator':
pass
@classmethod
def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
pass
__all__ = [
'dataclass',
'set_validation',
'create_pydantic_model_from_dataclass',
'is_builtin_dataclass',
'make_dataclass_validator',
]
_T = TypeVar('_T') _T = TypeVar('_T')
if sys.version_info >= (3, 10): if sys.version_info >= (3, 10):
@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo)) @dataclass_transform(field_specifiers=(dataclasses.field, Field))
@overload @overload
def dataclass( def dataclass(
*, *,
init: bool = True, init: Literal[False] = False,
repr: bool = True, repr: bool = True,
eq: bool = True, eq: bool = True,
order: bool = False, order: bool = False,
unsafe_hash: bool = False, unsafe_hash: bool = False,
frozen: bool = False, frozen: bool = False,
config: Union[ConfigDict, Type[object], None] = None, config: ConfigDict | type[object] | None = None,
validate_on_init: Optional[bool] = None, validate_on_init: bool | None = None,
kw_only: bool = ..., kw_only: bool = ...,
) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']: slots: bool = ...,
) -> Callable[[type[_T]], type[PydanticDataclass]]: # type: ignore
... ...
@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo)) @dataclass_transform(field_specifiers=(dataclasses.field, Field))
@overload @overload
def dataclass( def dataclass(
_cls: Type[_T], _cls: type[_T], # type: ignore
*, *,
init: bool = True, init: Literal[False] = False,
repr: bool = True, repr: bool = True,
eq: bool = True, eq: bool = True,
order: bool = False, order: bool = False,
unsafe_hash: bool = False, unsafe_hash: bool = False,
frozen: bool = False, frozen: bool = False,
config: Union[ConfigDict, Type[object], None] = None, config: ConfigDict | type[object] | None = None,
validate_on_init: Optional[bool] = None, validate_on_init: bool | None = None,
kw_only: bool = ..., kw_only: bool = ...,
) -> 'DataclassClassOrWrapper': slots: bool = ...,
) -> type[PydanticDataclass]:
... ...
else: else:
@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo)) @dataclass_transform(field_specifiers=(dataclasses.field, Field))
@overload @overload
def dataclass( def dataclass(
*, *,
init: bool = True, init: Literal[False] = False,
repr: bool = True, repr: bool = True,
eq: bool = True, eq: bool = True,
order: bool = False, order: bool = False,
unsafe_hash: bool = False, unsafe_hash: bool = False,
frozen: bool = False, frozen: bool = False,
config: Union[ConfigDict, Type[object], None] = None, config: ConfigDict | type[object] | None = None,
validate_on_init: Optional[bool] = None, validate_on_init: bool | None = None,
) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']: ) -> Callable[[type[_T]], type[PydanticDataclass]]: # type: ignore
... ...
@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo)) @dataclass_transform(field_specifiers=(dataclasses.field, Field))
@overload @overload
def dataclass( def dataclass(
_cls: Type[_T], _cls: type[_T], # type: ignore
*, *,
init: bool = True, init: Literal[False] = False,
repr: bool = True, repr: bool = True,
eq: bool = True, eq: bool = True,
order: bool = False, order: bool = False,
unsafe_hash: bool = False, unsafe_hash: bool = False,
frozen: bool = False, frozen: bool = False,
config: Union[ConfigDict, Type[object], None] = None, config: ConfigDict | type[object] | None = None,
validate_on_init: Optional[bool] = None, validate_on_init: bool | None = None,
) -> 'DataclassClassOrWrapper': ) -> type[PydanticDataclass]:
... ...
@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo)) @dataclass_transform(field_specifiers=(dataclasses.field, Field))
def dataclass( def dataclass( # noqa: C901
_cls: Optional[Type[_T]] = None, _cls: type[_T] | None = None,
*, *,
init: bool = True, init: Literal[False] = False,
repr: bool = True, repr: bool = True,
eq: bool = True, eq: bool = True,
order: bool = False, order: bool = False,
unsafe_hash: bool = False, unsafe_hash: bool = False,
frozen: bool = False, frozen: bool = False,
config: Union[ConfigDict, Type[object], None] = None, config: ConfigDict | type[object] | None = None,
validate_on_init: Optional[bool] = None, validate_on_init: bool | None = None,
kw_only: bool = False, kw_only: bool = False,
) -> Union[Callable[[Type[_T]], 'DataclassClassOrWrapper'], 'DataclassClassOrWrapper']: slots: bool = False,
""" ) -> Callable[[type[_T]], type[PydanticDataclass]] | type[PydanticDataclass]:
Like the python standard lib dataclasses but with type validation. """Usage docs: https://docs.pydantic.dev/2.6/concepts/dataclasses/
The result is either a pydantic dataclass that will validate input data
or a wrapper that will trigger validation around a stdlib dataclass
to avoid modifying it directly
"""
the_config = get_config(config)
def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper': A decorator used to create a Pydantic-enhanced dataclass, similar to the standard Python `dataclass`,
import dataclasses but with added validation.
This function should be used similarly to `dataclasses.dataclass`.
Args:
_cls: The target `dataclass`.
init: Included for signature compatibility with `dataclasses.dataclass`, and is passed through to
`dataclasses.dataclass` when appropriate. If specified, must be set to `False`, as pydantic inserts its
own `__init__` function.
repr: A boolean indicating whether to include the field in the `__repr__` output.
eq: Determines if a `__eq__` method should be generated for the class.
order: Determines if comparison magic methods should be generated, such as `__lt__`, but not `__eq__`.
unsafe_hash: Determines if a `__hash__` method should be included in the class, as in `dataclasses.dataclass`.
frozen: Determines if the generated class should be a 'frozen' `dataclass`, which does not allow its
attributes to be modified after it has been initialized.
config: The Pydantic config to use for the `dataclass`.
validate_on_init: A deprecated parameter included for backwards compatibility; in V2, all Pydantic dataclasses
are validated on init.
kw_only: Determines if `__init__` method parameters must be specified by keyword only. Defaults to `False`.
slots: Determines if the generated class should be a 'slots' `dataclass`, which does not allow the addition of
new attributes after instantiation.
Returns:
A decorator that accepts a class as its argument and returns a Pydantic `dataclass`.
Raises:
AssertionError: Raised if `init` is not `False` or `validate_on_init` is `False`.
"""
assert init is False, 'pydantic.dataclasses.dataclass only supports init=False'
assert validate_on_init is not False, 'validate_on_init=False is no longer supported'
if is_builtin_dataclass(cls) and _extra_dc_args(_cls) == _extra_dc_args(_cls.__bases__[0]): # type: ignore
dc_cls_doc = ''
dc_cls = DataclassProxy(cls)
default_validate_on_init = False
else:
dc_cls_doc = cls.__doc__ or '' # needs to be done before generating dataclass
if sys.version_info >= (3, 10): if sys.version_info >= (3, 10):
dc_cls = dataclasses.dataclass( kwargs = dict(kw_only=kw_only, slots=slots)
else:
kwargs = {}
def make_pydantic_fields_compatible(cls: type[Any]) -> None:
"""Make sure that stdlib `dataclasses` understands `Field` kwargs like `kw_only`
To do that, we simply change
`x: int = pydantic.Field(..., kw_only=True)`
into
`x: int = dataclasses.field(default=pydantic.Field(..., kw_only=True), kw_only=True)`
"""
for annotation_cls in cls.__mro__:
# In Python < 3.9, `__annotations__` might not be present if there are no fields.
# we therefore need to use `getattr` to avoid an `AttributeError`.
annotations = getattr(annotation_cls, '__annotations__', [])
for field_name in annotations:
field_value = getattr(cls, field_name, None)
# Process only if this is an instance of `FieldInfo`.
if not isinstance(field_value, FieldInfo):
continue
# Initialize arguments for the standard `dataclasses.field`.
field_args: dict = {'default': field_value}
# Handle `kw_only` for Python 3.10+
if sys.version_info >= (3, 10) and field_value.kw_only:
field_args['kw_only'] = True
# Set `repr` attribute if it's explicitly specified to be not `True`.
if field_value.repr is not True:
field_args['repr'] = field_value.repr
setattr(cls, field_name, dataclasses.field(**field_args))
# In Python 3.8, dataclasses checks cls.__dict__['__annotations__'] for annotations,
# so we must make sure it's initialized before we add to it.
if cls.__dict__.get('__annotations__') is None:
cls.__annotations__ = {}
cls.__annotations__[field_name] = annotations[field_name]
def create_dataclass(cls: type[Any]) -> type[PydanticDataclass]:
"""Create a Pydantic dataclass from a regular dataclass.
Args:
cls: The class to create the Pydantic dataclass from.
Returns:
A Pydantic dataclass.
"""
original_cls = cls
config_dict = config
if config_dict is None:
# if not explicitly provided, read from the type
cls_config = getattr(cls, '__pydantic_config__', None)
if cls_config is not None:
config_dict = cls_config
config_wrapper = _config.ConfigWrapper(config_dict)
decorators = _decorators.DecoratorInfos.build(cls)
# Keep track of the original __doc__ so that we can restore it after applying the dataclasses decorator
# Otherwise, classes with no __doc__ will have their signature added into the JSON schema description,
# since dataclasses.dataclass will set this as the __doc__
original_doc = cls.__doc__
if _pydantic_dataclasses.is_builtin_dataclass(cls):
# Don't preserve the docstring for vanilla dataclasses, as it may include the signature
# This matches v1 behavior, and there was an explicit test for it
original_doc = None
# We don't want to add validation to the existing std lib dataclass, so we will subclass it
# If the class is generic, we need to make sure the subclass also inherits from Generic
# with all the same parameters.
bases = (cls,)
if issubclass(cls, Generic):
generic_base = Generic[cls.__parameters__] # type: ignore
bases = bases + (generic_base,)
cls = types.new_class(cls.__name__, bases)
make_pydantic_fields_compatible(cls)
cls = dataclasses.dataclass( # type: ignore[call-overload]
cls, cls,
init=init, # the value of init here doesn't affect anything except that it makes it easier to generate a signature
init=True,
repr=repr, repr=repr,
eq=eq, eq=eq,
order=order, order=order,
unsafe_hash=unsafe_hash, unsafe_hash=unsafe_hash,
frozen=frozen, frozen=frozen,
kw_only=kw_only, **kwargs,
) )
else:
dc_cls = dataclasses.dataclass( # type: ignore
cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
)
default_validate_on_init = True
should_validate_on_init = default_validate_on_init if validate_on_init is None else validate_on_init cls.__pydantic_decorators__ = decorators # type: ignore
_add_pydantic_validation_attributes(cls, the_config, should_validate_on_init, dc_cls_doc) cls.__doc__ = original_doc
dc_cls.__pydantic_model__.__try_update_forward_refs__(**{cls.__name__: cls}) cls.__module__ = original_cls.__module__
return dc_cls cls.__qualname__ = original_cls.__qualname__
pydantic_complete = _pydantic_dataclasses.complete_dataclass(
cls, config_wrapper, raise_errors=False, types_namespace=None
)
cls.__pydantic_complete__ = pydantic_complete # type: ignore
return cls
if _cls is None: if _cls is None:
return wrap return create_dataclass
return wrap(_cls) return create_dataclass(_cls)
@contextmanager __getattr__ = getattr_migration(__name__)
def set_validation(cls: Type['DataclassT'], value: bool) -> Generator[Type['DataclassT'], None, None]:
original_run_validation = cls.__pydantic_run_validation__
try:
cls.__pydantic_run_validation__ = value
yield cls
finally:
cls.__pydantic_run_validation__ = original_run_validation
if (3, 8) <= sys.version_info < (3, 11):
# Monkeypatch dataclasses.InitVar so that typing doesn't error if it occurs as a type when evaluating type hints
# Starting in 3.11, typing.get_type_hints will not raise an error if the retrieved type hints are not callable.
class DataclassProxy: def _call_initvar(*args: Any, **kwargs: Any) -> NoReturn:
__slots__ = '__dataclass__' """This function does nothing but raise an error that is as similar as possible to what you'd get
if you were to try calling `InitVar[int]()` without this monkeypatch. The whole purpose is just
def __init__(self, dc_cls: Type['Dataclass']) -> None: to ensure typing._type_check does not error if the type hint evaluates to `InitVar[<parameter>]`.
object.__setattr__(self, '__dataclass__', dc_cls)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
with set_validation(self.__dataclass__, True):
return self.__dataclass__(*args, **kwargs)
def __getattr__(self, name: str) -> Any:
return getattr(self.__dataclass__, name)
def __instancecheck__(self, instance: Any) -> bool:
return isinstance(instance, self.__dataclass__)
def _add_pydantic_validation_attributes( # noqa: C901 (ignore complexity)
dc_cls: Type['Dataclass'],
config: Type[BaseConfig],
validate_on_init: bool,
dc_cls_doc: str,
) -> None:
""" """
We need to replace the right method. If no `__post_init__` has been set in the stdlib dataclass raise TypeError("'InitVar' object is not callable")
it won't even exist (code is generated on the fly by `dataclasses`)
By default, we run validation after `__init__` or `__post_init__` if defined dataclasses.InitVar.__call__ = _call_initvar
def rebuild_dataclass(
cls: type[PydanticDataclass],
*,
force: bool = False,
raise_errors: bool = True,
_parent_namespace_depth: int = 2,
_types_namespace: dict[str, Any] | None = None,
) -> bool | None:
"""Try to rebuild the pydantic-core schema for the dataclass.
This may be necessary when one of the annotations is a ForwardRef which could not be resolved during
the initial attempt to build the schema, and automatic rebuilding fails.
This is analogous to `BaseModel.model_rebuild`.
Args:
cls: The class to rebuild the pydantic-core schema for.
force: Whether to force the rebuilding of the schema, defaults to `False`.
raise_errors: Whether to raise errors, defaults to `True`.
_parent_namespace_depth: The depth level of the parent namespace, defaults to 2.
_types_namespace: The types namespace, defaults to `None`.
Returns:
Returns `None` if the schema is already "complete" and rebuilding was not required.
If rebuilding _was_ required, returns `True` if rebuilding was successful, otherwise `False`.
""" """
init = dc_cls.__init__ if not force and cls.__pydantic_complete__:
return None
@wraps(init)
def handle_extra_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
if config.extra == Extra.ignore:
init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})
elif config.extra == Extra.allow:
for k, v in kwargs.items():
self.__dict__.setdefault(k, v)
init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__})
else: else:
init(self, *args, **kwargs) if _types_namespace is not None:
types_namespace: dict[str, Any] | None = _types_namespace.copy()
if hasattr(dc_cls, '__post_init__'):
post_init = dc_cls.__post_init__
@wraps(post_init)
def new_post_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None:
if config.post_init_call == 'before_validation':
post_init(self, *args, **kwargs)
if self.__class__.__pydantic_run_validation__:
self.__pydantic_validate_values__()
if hasattr(self, '__post_init_post_parse__'):
self.__post_init_post_parse__(*args, **kwargs)
if config.post_init_call == 'after_validation':
post_init(self, *args, **kwargs)
setattr(dc_cls, '__init__', handle_extra_init)
setattr(dc_cls, '__post_init__', new_post_init)
else: else:
if _parent_namespace_depth > 0:
@wraps(init) frame_parent_ns = _typing_extra.parent_frame_namespace(parent_depth=_parent_namespace_depth) or {}
def new_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None: # Note: we may need to add something similar to cls.__pydantic_parent_namespace__ from BaseModel
handle_extra_init(self, *args, **kwargs) # here when implementing handling of recursive generics. See BaseModel.model_rebuild for reference.
types_namespace = frame_parent_ns
if self.__class__.__pydantic_run_validation__:
self.__pydantic_validate_values__()
if hasattr(self, '__post_init_post_parse__'):
# We need to find again the initvars. To do that we use `__dataclass_fields__` instead of
# public method `dataclasses.fields`
import dataclasses
# get all initvars and their default values
initvars_and_values: Dict[str, Any] = {}
for i, f in enumerate(self.__class__.__dataclass_fields__.values()):
if f._field_type is dataclasses._FIELD_INITVAR: # type: ignore[attr-defined]
try:
# set arg value by default
initvars_and_values[f.name] = args[i]
except IndexError:
initvars_and_values[f.name] = kwargs.get(f.name, f.default)
self.__post_init_post_parse__(**initvars_and_values)
setattr(dc_cls, '__init__', new_init)
setattr(dc_cls, '__pydantic_run_validation__', ClassAttribute('__pydantic_run_validation__', validate_on_init))
setattr(dc_cls, '__pydantic_initialised__', False)
setattr(dc_cls, '__pydantic_model__', create_pydantic_model_from_dataclass(dc_cls, config, dc_cls_doc))
setattr(dc_cls, '__pydantic_validate_values__', _dataclass_validate_values)
setattr(dc_cls, '__validate__', classmethod(_validate_dataclass))
setattr(dc_cls, '__get_validators__', classmethod(_get_validators))
if dc_cls.__pydantic_model__.__config__.validate_assignment and not dc_cls.__dataclass_params__.frozen:
setattr(dc_cls, '__setattr__', _dataclass_validate_assignment_setattr)
def _get_validators(cls: 'DataclassClassOrWrapper') -> 'CallableGenerator':
yield cls.__validate__
def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT':
with set_validation(cls, True):
if isinstance(v, cls):
v.__pydantic_validate_values__()
return v
elif isinstance(v, (list, tuple)):
return cls(*v)
elif isinstance(v, dict):
return cls(**v)
else: else:
raise DataclassTypeError(class_name=cls.__name__) types_namespace = {}
types_namespace = _typing_extra.get_cls_types_namespace(cls, types_namespace)
def create_pydantic_model_from_dataclass( return _pydantic_dataclasses.complete_dataclass(
dc_cls: Type['Dataclass'], cls,
config: Type[Any] = BaseConfig, _config.ConfigWrapper(cls.__pydantic_config__, check=False),
dc_cls_doc: Optional[str] = None, raise_errors=raise_errors,
) -> Type['BaseModel']: types_namespace=types_namespace,
import dataclasses
field_definitions: Dict[str, Any] = {}
for field in dataclasses.fields(dc_cls):
default: Any = Undefined
default_factory: Optional['NoArgAnyCallable'] = None
field_info: FieldInfo
if field.default is not dataclasses.MISSING:
default = field.default
elif field.default_factory is not dataclasses.MISSING:
default_factory = field.default_factory
else:
default = Required
if isinstance(default, FieldInfo):
field_info = default
dc_cls.__pydantic_has_field_info_default__ = True
else:
field_info = Field(default=default, default_factory=default_factory, **field.metadata)
field_definitions[field.name] = (field.type, field_info)
validators = gather_all_validators(dc_cls)
model: Type['BaseModel'] = create_model(
dc_cls.__name__,
__config__=config,
__module__=dc_cls.__module__,
__validators__=validators,
__cls_kwargs__={'__resolve_forward_refs__': False},
**field_definitions,
)
model.__doc__ = dc_cls_doc if dc_cls_doc is not None else dc_cls.__doc__ or ''
return model
def _dataclass_validate_values(self: 'Dataclass') -> None:
# validation errors can occur if this function is called twice on an already initialised dataclass.
# for example if Extra.forbid is enabled, it would consider __pydantic_initialised__ an invalid extra property
if getattr(self, '__pydantic_initialised__'):
return
if getattr(self, '__pydantic_has_field_info_default__', False):
# We need to remove `FieldInfo` values since they are not valid as input
# It's ok to do that because they are obviously the default values!
input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)}
else:
input_data = self.__dict__
d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__)
if validation_error:
raise validation_error
self.__dict__.update(d)
object.__setattr__(self, '__pydantic_initialised__', True)
def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value: Any) -> None:
if self.__pydantic_initialised__:
d = dict(self.__dict__)
d.pop(name, None)
known_field = self.__pydantic_model__.__fields__.get(name, None)
if known_field:
value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__)
if error_:
raise ValidationError([error_], self.__class__)
object.__setattr__(self, name, value)
def _extra_dc_args(cls: Type[Any]) -> Set[str]:
return {
x
for x in dir(cls)
if x not in getattr(cls, '__dataclass_fields__', {}) and not (x.startswith('__') and x.endswith('__'))
}
def is_builtin_dataclass(_cls: Type[Any]) -> bool:
"""
Whether a class is a stdlib dataclass
(useful to discriminated a pydantic dataclass that is actually a wrapper around a stdlib dataclass)
we check that
- `_cls` is a dataclass
- `_cls` is not a processed pydantic dataclass (with a basemodel attached)
- `_cls` is not a pydantic dataclass inheriting directly from a stdlib dataclass
e.g.
```
@dataclasses.dataclass
class A:
x: int
@pydantic.dataclasses.dataclass
class B(A):
y: int
```
In this case, when we first check `B`, we make an extra check and look at the annotations ('y'),
which won't be a superset of all the dataclass fields (only the stdlib fields i.e. 'x')
"""
import dataclasses
return (
dataclasses.is_dataclass(_cls)
and not hasattr(_cls, '__pydantic_model__')
and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {})))
) )
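`rebuild_dataclass` mirrors `BaseModel.model_rebuild` for dataclasses. A small sketch (not part of the diff) of how it is called once a previously unresolvable annotation becomes available:

```py
from typing import Optional

import pydantic.dataclasses

@pydantic.dataclasses.dataclass
class Node:
    value: int
    next: Optional['Node'] = None   # forward reference, resolved when the schema is built

# Analogous to BaseModel.model_rebuild(): returns None if nothing needed rebuilding,
# and True/False for the outcome when force=True.
assert pydantic.dataclasses.rebuild_dataclass(Node, force=True) is True
assert Node(value='1').value == 1
```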
def make_dataclass_validator(dc_cls: Type['Dataclass'], config: Type[BaseConfig]) -> 'CallableGenerator': def is_pydantic_dataclass(__cls: type[Any]) -> TypeGuard[type[PydanticDataclass]]:
"""Whether a class is a pydantic dataclass.
Args:
__cls: The class.
Returns:
`True` if the class is a pydantic dataclass, `False` otherwise.
""" """
Create a pydantic.dataclass from a builtin dataclass to add type validation return dataclasses.is_dataclass(__cls) and '__pydantic_validator__' in __cls.__dict__
and yield the validators
It retrieves the parameters of the dataclass and forwards them to the newly created dataclass
"""
yield from _get_validators(dataclass(dc_cls, config=config, validate_on_init=False))
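To see the V2 decorator's effect end to end, a short sketch (not part of the diff; it assumes the vendored package is importable as `pydantic`):

```py
import pydantic
import pydantic.dataclasses

@pydantic.dataclasses.dataclass
class Point:
    x: int
    y: int = 0

p = Point(x='1')   # input is validated and coerced on __init__
assert (p.x, p.y) == (1, 0)
assert pydantic.dataclasses.is_pydantic_dataclass(Point)

try:
    Point(x='not a number')
except pydantic.ValidationError as exc:
    assert exc.errors()[0]['type'] == 'int_parsing'
```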


@@ -1,248 +1,4 @@
""" """The `datetime_parse` module is a backport module from V1."""
Functions to parse datetime objects. from ._migration import getattr_migration
We're using regular expressions rather than time.strptime because: __getattr__ = getattr_migration(__name__)
- They provide both validation and parsing.
- They're more flexible for datetimes.
- The date/datetime/time constructors produce friendlier error messages.
Stolen from https://raw.githubusercontent.com/django/django/main/django/utils/dateparse.py at
9718fa2e8abe430c3526a9278dd976443d4ae3c6
Changed to:
* use standard python datetime types not django.utils.timezone
* raise ValueError when regex doesn't match rather than returning None
* support parsing unix timestamps for dates and datetimes
"""
import re
from datetime import date, datetime, time, timedelta, timezone
from typing import Dict, Optional, Type, Union
from . import errors
date_expr = r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})'
time_expr = (
r'(?P<hour>\d{1,2}):(?P<minute>\d{1,2})'
r'(?::(?P<second>\d{1,2})(?:\.(?P<microsecond>\d{1,6})\d{0,6})?)?'
r'(?P<tzinfo>Z|[+-]\d{2}(?::?\d{2})?)?$'
)
date_re = re.compile(f'{date_expr}$')
time_re = re.compile(time_expr)
datetime_re = re.compile(f'{date_expr}[T ]{time_expr}')
standard_duration_re = re.compile(
r'^'
r'(?:(?P<days>-?\d+) (days?, )?)?'
r'((?:(?P<hours>-?\d+):)(?=\d+:\d+))?'
r'(?:(?P<minutes>-?\d+):)?'
r'(?P<seconds>-?\d+)'
r'(?:\.(?P<microseconds>\d{1,6})\d{0,6})?'
r'$'
)
# Support the sections of ISO 8601 date representation that are accepted by timedelta
iso8601_duration_re = re.compile(
r'^(?P<sign>[-+]?)'
r'P'
r'(?:(?P<days>\d+(.\d+)?)D)?'
r'(?:T'
r'(?:(?P<hours>\d+(.\d+)?)H)?'
r'(?:(?P<minutes>\d+(.\d+)?)M)?'
r'(?:(?P<seconds>\d+(.\d+)?)S)?'
r')?'
r'$'
)
EPOCH = datetime(1970, 1, 1)
# if greater than this, the number is in ms, if less than or equal it's in seconds
# (in seconds this is 11th October 2603, in ms it's 20th August 1970)
MS_WATERSHED = int(2e10)
# slightly more than datetime.max in ns - (datetime.max - EPOCH).total_seconds() * 1e9
MAX_NUMBER = int(3e20)
StrBytesIntFloat = Union[str, bytes, int, float]
def get_numeric(value: StrBytesIntFloat, native_expected_type: str) -> Union[None, int, float]:
if isinstance(value, (int, float)):
return value
try:
return float(value)
except ValueError:
return None
except TypeError:
raise TypeError(f'invalid type; expected {native_expected_type}, string, bytes, int or float')
def from_unix_seconds(seconds: Union[int, float]) -> datetime:
if seconds > MAX_NUMBER:
return datetime.max
elif seconds < -MAX_NUMBER:
return datetime.min
while abs(seconds) > MS_WATERSHED:
seconds /= 1000
dt = EPOCH + timedelta(seconds=seconds)
return dt.replace(tzinfo=timezone.utc)
def _parse_timezone(value: Optional[str], error: Type[Exception]) -> Union[None, int, timezone]:
if value == 'Z':
return timezone.utc
elif value is not None:
offset_mins = int(value[-2:]) if len(value) > 3 else 0
offset = 60 * int(value[1:3]) + offset_mins
if value[0] == '-':
offset = -offset
try:
return timezone(timedelta(minutes=offset))
except ValueError:
raise error()
else:
return None
def parse_date(value: Union[date, StrBytesIntFloat]) -> date:
"""
Parse a date/int/float/string and return a datetime.date.
Raise ValueError if the input is well formatted but not a valid date.
Raise ValueError if the input isn't well formatted.
"""
if isinstance(value, date):
if isinstance(value, datetime):
return value.date()
else:
return value
number = get_numeric(value, 'date')
if number is not None:
return from_unix_seconds(number).date()
if isinstance(value, bytes):
value = value.decode()
match = date_re.match(value) # type: ignore
if match is None:
raise errors.DateError()
kw = {k: int(v) for k, v in match.groupdict().items()}
try:
return date(**kw)
except ValueError:
raise errors.DateError()
def parse_time(value: Union[time, StrBytesIntFloat]) -> time:
"""
Parse a time/string and return a datetime.time.
Raise ValueError if the input is well formatted but not a valid time.
Raise ValueError if the input isn't well formatted, in particular if it contains an offset.
"""
if isinstance(value, time):
return value
number = get_numeric(value, 'time')
if number is not None:
if number >= 86400:
# doesn't make sense since the time time loop back around to 0
raise errors.TimeError()
return (datetime.min + timedelta(seconds=number)).time()
if isinstance(value, bytes):
value = value.decode()
match = time_re.match(value) # type: ignore
if match is None:
raise errors.TimeError()
kw = match.groupdict()
if kw['microsecond']:
kw['microsecond'] = kw['microsecond'].ljust(6, '0')
tzinfo = _parse_timezone(kw.pop('tzinfo'), errors.TimeError)
kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None}
kw_['tzinfo'] = tzinfo
try:
return time(**kw_) # type: ignore
except ValueError:
raise errors.TimeError()
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime:
"""
Parse a datetime/int/float/string and return a datetime.datetime.
This function supports time zone offsets. When the input contains one,
the output uses a timezone with a fixed offset from UTC.
Raise ValueError if the input is well formatted but not a valid datetime.
Raise ValueError if the input isn't well formatted.
"""
if isinstance(value, datetime):
return value
number = get_numeric(value, 'datetime')
if number is not None:
return from_unix_seconds(number)
if isinstance(value, bytes):
value = value.decode()
match = datetime_re.match(value) # type: ignore
if match is None:
raise errors.DateTimeError()
kw = match.groupdict()
if kw['microsecond']:
kw['microsecond'] = kw['microsecond'].ljust(6, '0')
tzinfo = _parse_timezone(kw.pop('tzinfo'), errors.DateTimeError)
kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None}
kw_['tzinfo'] = tzinfo
try:
return datetime(**kw_) # type: ignore
except ValueError:
raise errors.DateTimeError()
def parse_duration(value: StrBytesIntFloat) -> timedelta:
"""
Parse a duration int/float/string and return a datetime.timedelta.
The preferred format for durations in Django is '%d %H:%M:%S.%f'.
Also supports ISO 8601 representation.
"""
if isinstance(value, timedelta):
return value
if isinstance(value, (int, float)):
# below code requires a string
value = f'{value:f}'
elif isinstance(value, bytes):
value = value.decode()
try:
match = standard_duration_re.match(value) or iso8601_duration_re.match(value)
except TypeError:
raise TypeError('invalid type; expected timedelta, string, bytes, int or float')
if not match:
raise errors.DurationError()
kw = match.groupdict()
sign = -1 if kw.pop('sign', '+') == '-' else 1
if kw.get('microseconds'):
kw['microseconds'] = kw['microseconds'].ljust(6, '0')
if kw.get('seconds') and kw.get('microseconds') and kw['seconds'].startswith('-'):
kw['microseconds'] = '-' + kw['microseconds']
kw_ = {k: float(v) for k, v in kw.items() if v is not None}
return sign * timedelta(**kw_)
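For reference, a sketch (not part of the diff) that re-uses the `standard_duration_re` pattern defined above to show how the removed V1 `parse_duration` assembled a `timedelta`; the ISO 8601 path via `iso8601_duration_re` works the same way:

```py
import re
from datetime import timedelta

# Django-style '%d %H:%M:%S.%f' pattern, copied from standard_duration_re above
pattern = re.compile(
    r'^'
    r'(?:(?P<days>-?\d+) (days?, )?)?'
    r'((?:(?P<hours>-?\d+):)(?=\d+:\d+))?'
    r'(?:(?P<minutes>-?\d+):)?'
    r'(?P<seconds>-?\d+)'
    r'(?:\.(?P<microseconds>\d{1,6})\d{0,6})?'
    r'$'
)

match = pattern.match('1 2:03:04.5')
kw = match.groupdict()
if kw.get('microseconds'):
    kw['microseconds'] = kw['microseconds'].ljust(6, '0')   # '.5' means 500000 microseconds
delta = timedelta(**{k: float(v) for k, v in kw.items() if v is not None})
assert delta == timedelta(days=1, hours=2, minutes=3, seconds=4, microseconds=500000)
```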


@@ -1,264 +1,4 @@
"""The `decorator` module is a backport module from V1."""
from ._migration import getattr_migration

__getattr__ = getattr_migration(__name__)

from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload

from . import validator
from .config import Extra
from .errors import ConfigError
from .main import BaseModel, create_model
from .typing import get_all_type_hints
from .utils import to_camel
__all__ = ('validate_arguments',)
if TYPE_CHECKING:
from .typing import AnyCallable
AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable)
ConfigType = Union[None, Type[Any], Dict[str, Any]]
@overload
def validate_arguments(func: None = None, *, config: 'ConfigType' = None) -> Callable[['AnyCallableT'], 'AnyCallableT']:
...
@overload
def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT':
...
def validate_arguments(func: Optional['AnyCallableT'] = None, *, config: 'ConfigType' = None) -> Any:
"""
Decorator to validate the arguments passed to a function.
"""
def validate(_func: 'AnyCallable') -> 'AnyCallable':
vd = ValidatedFunction(_func, config)
@wraps(_func)
def wrapper_function(*args: Any, **kwargs: Any) -> Any:
return vd.call(*args, **kwargs)
wrapper_function.vd = vd # type: ignore
wrapper_function.validate = vd.init_model_instance # type: ignore
wrapper_function.raw_function = vd.raw_function # type: ignore
wrapper_function.model = vd.model # type: ignore
return wrapper_function
if func:
return validate(func)
else:
return validate
ALT_V_ARGS = 'v__args'
ALT_V_KWARGS = 'v__kwargs'
V_POSITIONAL_ONLY_NAME = 'v__positional_only'
V_DUPLICATE_KWARGS = 'v__duplicate_kwargs'
class ValidatedFunction:
def __init__(self, function: 'AnyCallableT', config: 'ConfigType'): # noqa C901
from inspect import Parameter, signature
parameters: Mapping[str, Parameter] = signature(function).parameters
if parameters.keys() & {ALT_V_ARGS, ALT_V_KWARGS, V_POSITIONAL_ONLY_NAME, V_DUPLICATE_KWARGS}:
raise ConfigError(
f'"{ALT_V_ARGS}", "{ALT_V_KWARGS}", "{V_POSITIONAL_ONLY_NAME}" and "{V_DUPLICATE_KWARGS}" '
f'are not permitted as argument names when using the "{validate_arguments.__name__}" decorator'
)
self.raw_function = function
self.arg_mapping: Dict[int, str] = {}
self.positional_only_args = set()
self.v_args_name = 'args'
self.v_kwargs_name = 'kwargs'
type_hints = get_all_type_hints(function)
takes_args = False
takes_kwargs = False
fields: Dict[str, Tuple[Any, Any]] = {}
for i, (name, p) in enumerate(parameters.items()):
if p.annotation is p.empty:
annotation = Any
else:
annotation = type_hints[name]
default = ... if p.default is p.empty else p.default
if p.kind == Parameter.POSITIONAL_ONLY:
self.arg_mapping[i] = name
fields[name] = annotation, default
fields[V_POSITIONAL_ONLY_NAME] = List[str], None
self.positional_only_args.add(name)
elif p.kind == Parameter.POSITIONAL_OR_KEYWORD:
self.arg_mapping[i] = name
fields[name] = annotation, default
fields[V_DUPLICATE_KWARGS] = List[str], None
elif p.kind == Parameter.KEYWORD_ONLY:
fields[name] = annotation, default
elif p.kind == Parameter.VAR_POSITIONAL:
self.v_args_name = name
fields[name] = Tuple[annotation, ...], None
takes_args = True
else:
assert p.kind == Parameter.VAR_KEYWORD, p.kind
self.v_kwargs_name = name
fields[name] = Dict[str, annotation], None # type: ignore
takes_kwargs = True
# these checks avoid a clash between "args" and a field with that name
if not takes_args and self.v_args_name in fields:
self.v_args_name = ALT_V_ARGS
# same with "kwargs"
if not takes_kwargs and self.v_kwargs_name in fields:
self.v_kwargs_name = ALT_V_KWARGS
if not takes_args:
# we add the field so validation below can raise the correct exception
fields[self.v_args_name] = List[Any], None
if not takes_kwargs:
# same with kwargs
fields[self.v_kwargs_name] = Dict[Any, Any], None
self.create_model(fields, takes_args, takes_kwargs, config)
def init_model_instance(self, *args: Any, **kwargs: Any) -> BaseModel:
values = self.build_values(args, kwargs)
return self.model(**values)
def call(self, *args: Any, **kwargs: Any) -> Any:
m = self.init_model_instance(*args, **kwargs)
return self.execute(m)
def build_values(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
values: Dict[str, Any] = {}
if args:
arg_iter = enumerate(args)
while True:
try:
i, a = next(arg_iter)
except StopIteration:
break
arg_name = self.arg_mapping.get(i)
if arg_name is not None:
values[arg_name] = a
else:
values[self.v_args_name] = [a] + [a for _, a in arg_iter]
break
var_kwargs: Dict[str, Any] = {}
wrong_positional_args = []
duplicate_kwargs = []
fields_alias = [
field.alias
for name, field in self.model.__fields__.items()
if name not in (self.v_args_name, self.v_kwargs_name)
]
non_var_fields = set(self.model.__fields__) - {self.v_args_name, self.v_kwargs_name}
for k, v in kwargs.items():
if k in non_var_fields or k in fields_alias:
if k in self.positional_only_args:
wrong_positional_args.append(k)
if k in values:
duplicate_kwargs.append(k)
values[k] = v
else:
var_kwargs[k] = v
if var_kwargs:
values[self.v_kwargs_name] = var_kwargs
if wrong_positional_args:
values[V_POSITIONAL_ONLY_NAME] = wrong_positional_args
if duplicate_kwargs:
values[V_DUPLICATE_KWARGS] = duplicate_kwargs
return values
def execute(self, m: BaseModel) -> Any:
d = {k: v for k, v in m._iter() if k in m.__fields_set__ or m.__fields__[k].default_factory}
var_kwargs = d.pop(self.v_kwargs_name, {})
if self.v_args_name in d:
args_: List[Any] = []
in_kwargs = False
kwargs = {}
for name, value in d.items():
if in_kwargs:
kwargs[name] = value
elif name == self.v_args_name:
args_ += value
in_kwargs = True
else:
args_.append(value)
return self.raw_function(*args_, **kwargs, **var_kwargs)
elif self.positional_only_args:
args_ = []
kwargs = {}
for name, value in d.items():
if name in self.positional_only_args:
args_.append(value)
else:
kwargs[name] = value
return self.raw_function(*args_, **kwargs, **var_kwargs)
else:
return self.raw_function(**d, **var_kwargs)
def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None:
pos_args = len(self.arg_mapping)
class CustomConfig:
pass
if not TYPE_CHECKING: # pragma: no branch
if isinstance(config, dict):
CustomConfig = type('Config', (), config) # noqa: F811
elif config is not None:
CustomConfig = config # noqa: F811
if hasattr(CustomConfig, 'fields') or hasattr(CustomConfig, 'alias_generator'):
raise ConfigError(
'Setting the "fields" and "alias_generator" property on custom Config for '
'@validate_arguments is not yet supported, please remove.'
)
class DecoratorBaseModel(BaseModel):
@validator(self.v_args_name, check_fields=False, allow_reuse=True)
def check_args(cls, v: Optional[List[Any]]) -> Optional[List[Any]]:
if takes_args or v is None:
return v
raise TypeError(f'{pos_args} positional arguments expected but {pos_args + len(v)} given')
@validator(self.v_kwargs_name, check_fields=False, allow_reuse=True)
def check_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
if takes_kwargs or v is None:
return v
plural = '' if len(v) == 1 else 's'
keys = ', '.join(map(repr, v.keys()))
raise TypeError(f'unexpected keyword argument{plural}: {keys}')
@validator(V_POSITIONAL_ONLY_NAME, check_fields=False, allow_reuse=True)
def check_positional_only(cls, v: Optional[List[str]]) -> None:
if v is None:
return
plural = '' if len(v) == 1 else 's'
keys = ', '.join(map(repr, v))
raise TypeError(f'positional-only argument{plural} passed as keyword argument{plural}: {keys}')
@validator(V_DUPLICATE_KWARGS, check_fields=False, allow_reuse=True)
def check_duplicate_kwargs(cls, v: Optional[List[str]]) -> None:
if v is None:
return
plural = '' if len(v) == 1 else 's'
keys = ', '.join(map(repr, v))
raise TypeError(f'multiple values for argument{plural}: {keys}')
class Config(CustomConfig):
extra = getattr(CustomConfig, 'extra', Extra.forbid)
self.model = create_model(to_camel(self.raw_function.__name__), __base__=DecoratorBaseModel, **fields)
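The `ValidatedFunction` machinery above backed V1's `@validate_arguments`. A brief usage sketch of the removed decorator (not part of the diff; this is the V1-era API, which V2 replaces with `validate_call`):

```py
from pydantic import validate_arguments  # V1 API; pydantic V2 ships validate_call instead

@validate_arguments
def repeat(text: str, count: int) -> str:
    return text * count

assert repeat('ab', 3) == 'ababab'
assert repeat('ab', '3') == 'ababab'   # '3' is coerced to an int before the call
repeat.validate('ab', 3)               # build the argument model without calling the function
```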


@@ -0,0 +1,253 @@
"""Old `@validator` and `@root_validator` function validators from V1."""
from __future__ import annotations as _annotations
from functools import partial, partialmethod
from types import FunctionType
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, overload
from warnings import warn
from typing_extensions import Literal, Protocol, TypeAlias
from .._internal import _decorators, _decorators_v1
from ..errors import PydanticUserError
from ..warnings import PydanticDeprecatedSince20
_ALLOW_REUSE_WARNING_MESSAGE = '`allow_reuse` is deprecated and will be ignored; it should no longer be necessary'
if TYPE_CHECKING:
class _OnlyValueValidatorClsMethod(Protocol):
def __call__(self, __cls: Any, __value: Any) -> Any:
...
class _V1ValidatorWithValuesClsMethod(Protocol):
def __call__(self, __cls: Any, __value: Any, values: dict[str, Any]) -> Any:
...
class _V1ValidatorWithValuesKwOnlyClsMethod(Protocol):
def __call__(self, __cls: Any, __value: Any, *, values: dict[str, Any]) -> Any:
...
class _V1ValidatorWithKwargsClsMethod(Protocol):
def __call__(self, __cls: Any, **kwargs: Any) -> Any:
...
class _V1ValidatorWithValuesAndKwargsClsMethod(Protocol):
def __call__(self, __cls: Any, values: dict[str, Any], **kwargs: Any) -> Any:
...
class _V1RootValidatorClsMethod(Protocol):
def __call__(
self, __cls: Any, __values: _decorators_v1.RootValidatorValues
) -> _decorators_v1.RootValidatorValues:
...
V1Validator = Union[
_OnlyValueValidatorClsMethod,
_V1ValidatorWithValuesClsMethod,
_V1ValidatorWithValuesKwOnlyClsMethod,
_V1ValidatorWithKwargsClsMethod,
_V1ValidatorWithValuesAndKwargsClsMethod,
_decorators_v1.V1ValidatorWithValues,
_decorators_v1.V1ValidatorWithValuesKwOnly,
_decorators_v1.V1ValidatorWithKwargs,
_decorators_v1.V1ValidatorWithValuesAndKwargs,
]
V1RootValidator = Union[
_V1RootValidatorClsMethod,
_decorators_v1.V1RootValidatorFunction,
]
_PartialClsOrStaticMethod: TypeAlias = Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any]]
# Allow both a V1 (assumed pre=False) or V2 (assumed mode='after') validator
# We lie to type checkers and say we return the same thing we get
# but in reality we return a proxy object that _mostly_ behaves like the wrapped thing
_V1ValidatorType = TypeVar('_V1ValidatorType', V1Validator, _PartialClsOrStaticMethod)
_V1RootValidatorFunctionType = TypeVar(
'_V1RootValidatorFunctionType',
_decorators_v1.V1RootValidatorFunction,
_V1RootValidatorClsMethod,
_PartialClsOrStaticMethod,
)
else:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
# and https://youtrack.jetbrains.com/issue/PY-51428
DeprecationWarning = PydanticDeprecatedSince20
def validator(
__field: str,
*fields: str,
pre: bool = False,
each_item: bool = False,
always: bool = False,
check_fields: bool | None = None,
allow_reuse: bool = False,
) -> Callable[[_V1ValidatorType], _V1ValidatorType]:
"""Decorate methods on the class indicating that they should be used to validate fields.
Args:
__field (str): The first field the validator should be called on; this is separate
from `fields` to ensure an error is raised if you don't pass at least one.
*fields (str): Additional field(s) the validator should be called on.
pre (bool, optional): Whether this validator should be called before the standard
validators (else after). Defaults to False.
each_item (bool, optional): For complex objects (sets, lists etc.) whether to validate
individual elements rather than the whole object. Defaults to False.
always (bool, optional): Whether this method and other validators should be called even if
the value is missing. Defaults to False.
check_fields (bool | None, optional): Whether to check that the fields actually exist on the model.
Defaults to None.
allow_reuse (bool, optional): Whether to track and raise an error if another validator refers to
the decorated function. Defaults to False.
Returns:
Callable: A decorator that can be used to decorate a
function to be used as a validator.
"""
if allow_reuse is True: # pragma: no cover
warn(_ALLOW_REUSE_WARNING_MESSAGE, DeprecationWarning)
fields = tuple((__field, *fields))
if isinstance(fields[0], FunctionType):
raise PydanticUserError(
'`@validator` should be used with fields and keyword arguments, not bare. '
"E.g. usage should be `@validator('<field_name>', ...)`",
code='validator-no-fields',
)
elif not all(isinstance(field, str) for field in fields):
raise PydanticUserError(
'`@validator` fields should be passed as separate string args. '
"E.g. usage should be `@validator('<field_name_1>', '<field_name_2>', ...)`",
code='validator-invalid-fields',
)
warn(
'Pydantic V1 style `@validator` validators are deprecated.'
' You should migrate to Pydantic V2 style `@field_validator` validators,'
' see the migration guide for more details',
DeprecationWarning,
stacklevel=2,
)
mode: Literal['before', 'after'] = 'before' if pre is True else 'after'
def dec(f: Any) -> _decorators.PydanticDescriptorProxy[Any]:
if _decorators.is_instance_method_from_sig(f):
raise PydanticUserError(
'`@validator` cannot be applied to instance methods', code='validator-instance-method'
)
# auto apply the @classmethod decorator
f = _decorators.ensure_classmethod_based_on_signature(f)
wrap = _decorators_v1.make_generic_v1_field_validator
validator_wrapper_info = _decorators.ValidatorDecoratorInfo(
fields=fields,
mode=mode,
each_item=each_item,
always=always,
check_fields=check_fields,
)
return _decorators.PydanticDescriptorProxy(f, validator_wrapper_info, shim=wrap)
return dec # type: ignore[return-value]
@overload
def root_validator(
*,
# if you don't specify `pre` the default is `pre=False`
# which means you need to specify `skip_on_failure=True`
skip_on_failure: Literal[True],
allow_reuse: bool = ...,
) -> Callable[
[_V1RootValidatorFunctionType],
_V1RootValidatorFunctionType,
]:
...
@overload
def root_validator(
*,
# if you specify `pre=True` then you don't need to specify
# `skip_on_failure`, in fact it is not allowed as an argument!
pre: Literal[True],
allow_reuse: bool = ...,
) -> Callable[
[_V1RootValidatorFunctionType],
_V1RootValidatorFunctionType,
]:
...
@overload
def root_validator(
*,
# if you explicitly specify `pre=False` then you
# MUST specify `skip_on_failure=True`
pre: Literal[False],
skip_on_failure: Literal[True],
allow_reuse: bool = ...,
) -> Callable[
[_V1RootValidatorFunctionType],
_V1RootValidatorFunctionType,
]:
...
def root_validator(
*__args,
pre: bool = False,
skip_on_failure: bool = False,
allow_reuse: bool = False,
) -> Any:
"""Decorate methods on a model indicating that they should be used to validate (and perhaps
modify) data either before or after standard model parsing/validation is performed.
Args:
pre (bool, optional): Whether this validator should be called before the standard
validators (else after). Defaults to False.
skip_on_failure (bool, optional): Whether to stop validation and return as soon as a
failure is encountered. Defaults to False.
allow_reuse (bool, optional): Whether to track and raise an error if another validator
refers to the decorated function. Defaults to False.
Returns:
Any: A decorator that can be used to decorate a function to be used as a root_validator.
"""
warn(
'Pydantic V1 style `@root_validator` validators are deprecated.'
' You should migrate to Pydantic V2 style `@model_validator` validators,'
' see the migration guide for more details',
DeprecationWarning,
stacklevel=2,
)
if __args:
# Ensure a nice error is raised if someone attempts to use the bare decorator
return root_validator()(*__args) # type: ignore
if allow_reuse is True: # pragma: no cover
warn(_ALLOW_REUSE_WARNING_MESSAGE, DeprecationWarning)
mode: Literal['before', 'after'] = 'before' if pre is True else 'after'
if pre is False and skip_on_failure is not True:
raise PydanticUserError(
'If you use `@root_validator` with pre=False (the default) you MUST specify `skip_on_failure=True`.'
' Note that `@root_validator` is deprecated and should be replaced with `@model_validator`.',
code='root-validator-pre-skip',
)
wrap = partial(_decorators_v1.make_v1_generic_root_validator, pre=pre)
def dec(f: Callable[..., Any] | classmethod[Any, Any, Any] | staticmethod[Any, Any]) -> Any:
if _decorators.is_instance_method_from_sig(f):
raise TypeError('`@root_validator` cannot be applied to instance methods')
# auto apply the @classmethod decorator
res = _decorators.ensure_classmethod_based_on_signature(f)
dec_info = _decorators.RootValidatorDecoratorInfo(mode=mode)
return _decorators.PydanticDescriptorProxy(res, dec_info, shim=wrap)
return dec
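And a corresponding sketch for the deprecated V1-style root validator (class name illustrative); note that `skip_on_failure=True` is mandatory when `pre=False`, as enforced above:
from pydantic import BaseModel, root_validator  # V1 shim; emits a DeprecationWarning

class Window(BaseModel):
    width: int
    height: int

    @root_validator(skip_on_failure=True)
    def check_area(cls, values: dict) -> dict:
        if values['width'] * values['height'] > 10_000:
            raise ValueError('window area too large')
        return values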

View file

@ -0,0 +1,72 @@
from __future__ import annotations as _annotations
import warnings
from typing import TYPE_CHECKING, Any
from typing_extensions import Literal, deprecated
from .._internal import _config
from ..warnings import PydanticDeprecatedSince20
if not TYPE_CHECKING:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
# and https://youtrack.jetbrains.com/issue/PY-51428
DeprecationWarning = PydanticDeprecatedSince20
__all__ = 'BaseConfig', 'Extra'
class _ConfigMetaclass(type):
def __getattr__(self, item: str) -> Any:
try:
obj = _config.config_defaults[item]
warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning)
return obj
except KeyError as exc:
raise AttributeError(f"type object '{self.__name__}' has no attribute {exc}") from exc
@deprecated('BaseConfig is deprecated. Use the `pydantic.ConfigDict` instead.', category=PydanticDeprecatedSince20)
class BaseConfig(metaclass=_ConfigMetaclass):
"""This class is only retained for backwards compatibility.
!!! Warning "Deprecated"
BaseConfig is deprecated. Use the [`pydantic.ConfigDict`][pydantic.ConfigDict] instead.
"""
def __getattr__(self, item: str) -> Any:
try:
obj = super().__getattribute__(item)
warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning)
return obj
except AttributeError as exc:
try:
return getattr(type(self), item)
except AttributeError:
# re-raising changes the displayed text to reflect that `self` is not a type
raise AttributeError(str(exc)) from exc
def __init_subclass__(cls, **kwargs: Any) -> None:
warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning)
return super().__init_subclass__(**kwargs)
class _ExtraMeta(type):
def __getattribute__(self, __name: str) -> Any:
# The @deprecated decorator accesses other attributes, so we only emit a warning for the expected ones
if __name in {'allow', 'ignore', 'forbid'}:
warnings.warn(
"`pydantic.config.Extra` is deprecated, use literal values instead (e.g. `extra='allow'`)",
DeprecationWarning,
stacklevel=2,
)
return super().__getattribute__(__name)
@deprecated(
"Extra is deprecated. Use literal values instead (e.g. `extra='allow'`)", category=PydanticDeprecatedSince20
)
class Extra(metaclass=_ExtraMeta):
allow: Literal['allow'] = 'allow'
ignore: Literal['ignore'] = 'ignore'
forbid: Literal['forbid'] = 'forbid'
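A short migration sketch for the shims above: accessing `Extra.allow`/`Extra.forbid` or subclassing `BaseConfig` warns, while the V2 spelling uses `ConfigDict` with literal values (model name illustrative):
from pydantic import BaseModel, ConfigDict

class StrictModel(BaseModel):
    # preferred over `class Config(BaseConfig): extra = Extra.forbid`
    model_config = ConfigDict(extra='forbid')
    name: str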

View file

@ -0,0 +1,224 @@
from __future__ import annotations as _annotations
import typing
from copy import deepcopy
from enum import Enum
from typing import Any, Tuple
import typing_extensions
from .._internal import (
_model_construction,
_typing_extra,
_utils,
)
if typing.TYPE_CHECKING:
from .. import BaseModel
from .._internal._utils import AbstractSetIntStr, MappingIntStrAny
AnyClassMethod = classmethod[Any, Any, Any]
TupleGenerator = typing.Generator[Tuple[str, Any], None, None]
Model = typing.TypeVar('Model', bound='BaseModel')
# should be `set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None`, but mypy can't cope
IncEx: typing_extensions.TypeAlias = 'set[int] | set[str] | dict[int, Any] | dict[str, Any] | None'
_object_setattr = _model_construction.object_setattr
def _iter(
self: BaseModel,
to_dict: bool = False,
by_alias: bool = False,
include: AbstractSetIntStr | MappingIntStrAny | None = None,
exclude: AbstractSetIntStr | MappingIntStrAny | None = None,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> TupleGenerator:
# Merge field set excludes with the explicit exclude parameter, with the explicit parameter overriding field set options.
# The extra "is not None" guards are not logically necessary but optimize performance for the simple case.
if exclude is not None:
exclude = _utils.ValueItems.merge(
{k: v.exclude for k, v in self.model_fields.items() if v.exclude is not None}, exclude
)
if include is not None:
include = _utils.ValueItems.merge({k: True for k in self.model_fields}, include, intersect=True)
allowed_keys = _calculate_keys(self, include=include, exclude=exclude, exclude_unset=exclude_unset) # type: ignore
if allowed_keys is None and not (to_dict or by_alias or exclude_unset or exclude_defaults or exclude_none):
# huge boost for plain _iter()
yield from self.__dict__.items()
if self.__pydantic_extra__:
yield from self.__pydantic_extra__.items()
return
value_exclude = _utils.ValueItems(self, exclude) if exclude is not None else None
value_include = _utils.ValueItems(self, include) if include is not None else None
if self.__pydantic_extra__ is None:
items = self.__dict__.items()
else:
items = list(self.__dict__.items()) + list(self.__pydantic_extra__.items())
for field_key, v in items:
if (allowed_keys is not None and field_key not in allowed_keys) or (exclude_none and v is None):
continue
if exclude_defaults:
try:
field = self.model_fields[field_key]
except KeyError:
pass
else:
if not field.is_required() and field.default == v:
continue
if by_alias and field_key in self.model_fields:
dict_key = self.model_fields[field_key].alias or field_key
else:
dict_key = field_key
if to_dict or value_include or value_exclude:
v = _get_value(
type(self),
v,
to_dict=to_dict,
by_alias=by_alias,
include=value_include and value_include.for_element(field_key),
exclude=value_exclude and value_exclude.for_element(field_key),
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
yield dict_key, v
def _copy_and_set_values(
self: Model,
values: dict[str, Any],
fields_set: set[str],
extra: dict[str, Any] | None = None,
private: dict[str, Any] | None = None,
*,
deep: bool, # UP006
) -> Model:
if deep:
# chances of having empty dict here are quite low for using smart_deepcopy
values = deepcopy(values)
extra = deepcopy(extra)
private = deepcopy(private)
cls = self.__class__
m = cls.__new__(cls)
_object_setattr(m, '__dict__', values)
_object_setattr(m, '__pydantic_extra__', extra)
_object_setattr(m, '__pydantic_fields_set__', fields_set)
_object_setattr(m, '__pydantic_private__', private)
return m
@typing.no_type_check
def _get_value(
cls: type[BaseModel],
v: Any,
to_dict: bool,
by_alias: bool,
include: AbstractSetIntStr | MappingIntStrAny | None,
exclude: AbstractSetIntStr | MappingIntStrAny | None,
exclude_unset: bool,
exclude_defaults: bool,
exclude_none: bool,
) -> Any:
from .. import BaseModel
if isinstance(v, BaseModel):
if to_dict:
return v.model_dump(
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
include=include, # type: ignore
exclude=exclude, # type: ignore
exclude_none=exclude_none,
)
else:
return v.copy(include=include, exclude=exclude)
value_exclude = _utils.ValueItems(v, exclude) if exclude else None
value_include = _utils.ValueItems(v, include) if include else None
if isinstance(v, dict):
return {
k_: _get_value(
cls,
v_,
to_dict=to_dict,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
include=value_include and value_include.for_element(k_),
exclude=value_exclude and value_exclude.for_element(k_),
exclude_none=exclude_none,
)
for k_, v_ in v.items()
if (not value_exclude or not value_exclude.is_excluded(k_))
and (not value_include or value_include.is_included(k_))
}
elif _utils.sequence_like(v):
seq_args = (
_get_value(
cls,
v_,
to_dict=to_dict,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
include=value_include and value_include.for_element(i),
exclude=value_exclude and value_exclude.for_element(i),
exclude_none=exclude_none,
)
for i, v_ in enumerate(v)
if (not value_exclude or not value_exclude.is_excluded(i))
and (not value_include or value_include.is_included(i))
)
return v.__class__(*seq_args) if _typing_extra.is_namedtuple(v.__class__) else v.__class__(seq_args)
elif isinstance(v, Enum) and getattr(cls.model_config, 'use_enum_values', False):
return v.value
else:
return v
def _calculate_keys(
self: BaseModel,
include: MappingIntStrAny | None,
exclude: MappingIntStrAny | None,
exclude_unset: bool,
update: typing.Dict[str, Any] | None = None, # noqa UP006
) -> typing.AbstractSet[str] | None:
if include is None and exclude is None and exclude_unset is False:
return None
keys: typing.AbstractSet[str]
if exclude_unset:
keys = self.__pydantic_fields_set__.copy()
else:
keys = set(self.__dict__.keys())
keys = keys | (self.__pydantic_extra__ or {}).keys()
if include is not None:
keys &= include.keys()
if update:
keys -= update.keys()
if exclude:
keys -= {k for k, v in exclude.items() if _utils.ValueItems.is_true(v)}
return keys
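These helpers back the deprecated V1-style `BaseModel.copy()`; a rough usage sketch under that assumption (model name illustrative):
from pydantic import BaseModel

class Point(BaseModel):
    x: int
    y: int

p = Point(x=1, y=2)
clone = p.copy(update={'y': 5}, deep=True)  # deprecated in V2; `model_copy` is the replacement
assert (clone.x, clone.y) == (1, 5)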

View file

@ -0,0 +1,279 @@
import warnings
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload
from typing_extensions import deprecated
from .._internal import _config, _typing_extra
from ..alias_generators import to_pascal
from ..errors import PydanticUserError
from ..functional_validators import field_validator
from ..main import BaseModel, create_model
from ..warnings import PydanticDeprecatedSince20
if not TYPE_CHECKING:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
# and https://youtrack.jetbrains.com/issue/PY-51428
DeprecationWarning = PydanticDeprecatedSince20
__all__ = ('validate_arguments',)
if TYPE_CHECKING:
AnyCallable = Callable[..., Any]
AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable)
ConfigType = Union[None, Type[Any], Dict[str, Any]]
@overload
def validate_arguments(func: None = None, *, config: 'ConfigType' = None) -> Callable[['AnyCallableT'], 'AnyCallableT']:
...
@overload
def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT':
...
@deprecated(
'The `validate_arguments` method is deprecated; use `validate_call` instead.',
category=None,
)
def validate_arguments(func: Optional['AnyCallableT'] = None, *, config: 'ConfigType' = None) -> Any:
"""Decorator to validate the arguments passed to a function."""
warnings.warn(
'The `validate_arguments` method is deprecated; use `validate_call` instead.',
PydanticDeprecatedSince20,
stacklevel=2,
)
def validate(_func: 'AnyCallable') -> 'AnyCallable':
vd = ValidatedFunction(_func, config)
@wraps(_func)
def wrapper_function(*args: Any, **kwargs: Any) -> Any:
return vd.call(*args, **kwargs)
wrapper_function.vd = vd # type: ignore
wrapper_function.validate = vd.init_model_instance # type: ignore
wrapper_function.raw_function = vd.raw_function # type: ignore
wrapper_function.model = vd.model # type: ignore
return wrapper_function
if func:
return validate(func)
else:
return validate
ALT_V_ARGS = 'v__args'
ALT_V_KWARGS = 'v__kwargs'
V_POSITIONAL_ONLY_NAME = 'v__positional_only'
V_DUPLICATE_KWARGS = 'v__duplicate_kwargs'
class ValidatedFunction:
def __init__(self, function: 'AnyCallable', config: 'ConfigType'):
from inspect import Parameter, signature
parameters: Mapping[str, Parameter] = signature(function).parameters
if parameters.keys() & {ALT_V_ARGS, ALT_V_KWARGS, V_POSITIONAL_ONLY_NAME, V_DUPLICATE_KWARGS}:
raise PydanticUserError(
f'"{ALT_V_ARGS}", "{ALT_V_KWARGS}", "{V_POSITIONAL_ONLY_NAME}" and "{V_DUPLICATE_KWARGS}" '
f'are not permitted as argument names when using the "{validate_arguments.__name__}" decorator',
code=None,
)
self.raw_function = function
self.arg_mapping: Dict[int, str] = {}
self.positional_only_args: set[str] = set()
self.v_args_name = 'args'
self.v_kwargs_name = 'kwargs'
type_hints = _typing_extra.get_type_hints(function, include_extras=True)
takes_args = False
takes_kwargs = False
fields: Dict[str, Tuple[Any, Any]] = {}
for i, (name, p) in enumerate(parameters.items()):
if p.annotation is p.empty:
annotation = Any
else:
annotation = type_hints[name]
default = ... if p.default is p.empty else p.default
if p.kind == Parameter.POSITIONAL_ONLY:
self.arg_mapping[i] = name
fields[name] = annotation, default
fields[V_POSITIONAL_ONLY_NAME] = List[str], None
self.positional_only_args.add(name)
elif p.kind == Parameter.POSITIONAL_OR_KEYWORD:
self.arg_mapping[i] = name
fields[name] = annotation, default
fields[V_DUPLICATE_KWARGS] = List[str], None
elif p.kind == Parameter.KEYWORD_ONLY:
fields[name] = annotation, default
elif p.kind == Parameter.VAR_POSITIONAL:
self.v_args_name = name
fields[name] = Tuple[annotation, ...], None
takes_args = True
else:
assert p.kind == Parameter.VAR_KEYWORD, p.kind
self.v_kwargs_name = name
fields[name] = Dict[str, annotation], None
takes_kwargs = True
# these checks avoid a clash between "args" and a field with that name
if not takes_args and self.v_args_name in fields:
self.v_args_name = ALT_V_ARGS
# same with "kwargs"
if not takes_kwargs and self.v_kwargs_name in fields:
self.v_kwargs_name = ALT_V_KWARGS
if not takes_args:
# we add the field so validation below can raise the correct exception
fields[self.v_args_name] = List[Any], None
if not takes_kwargs:
# same with kwargs
fields[self.v_kwargs_name] = Dict[Any, Any], None
self.create_model(fields, takes_args, takes_kwargs, config)
def init_model_instance(self, *args: Any, **kwargs: Any) -> BaseModel:
values = self.build_values(args, kwargs)
return self.model(**values)
def call(self, *args: Any, **kwargs: Any) -> Any:
m = self.init_model_instance(*args, **kwargs)
return self.execute(m)
def build_values(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
values: Dict[str, Any] = {}
if args:
arg_iter = enumerate(args)
while True:
try:
i, a = next(arg_iter)
except StopIteration:
break
arg_name = self.arg_mapping.get(i)
if arg_name is not None:
values[arg_name] = a
else:
values[self.v_args_name] = [a] + [a for _, a in arg_iter]
break
var_kwargs: Dict[str, Any] = {}
wrong_positional_args = []
duplicate_kwargs = []
fields_alias = [
field.alias
for name, field in self.model.model_fields.items()
if name not in (self.v_args_name, self.v_kwargs_name)
]
non_var_fields = set(self.model.model_fields) - {self.v_args_name, self.v_kwargs_name}
for k, v in kwargs.items():
if k in non_var_fields or k in fields_alias:
if k in self.positional_only_args:
wrong_positional_args.append(k)
if k in values:
duplicate_kwargs.append(k)
values[k] = v
else:
var_kwargs[k] = v
if var_kwargs:
values[self.v_kwargs_name] = var_kwargs
if wrong_positional_args:
values[V_POSITIONAL_ONLY_NAME] = wrong_positional_args
if duplicate_kwargs:
values[V_DUPLICATE_KWARGS] = duplicate_kwargs
return values
def execute(self, m: BaseModel) -> Any:
d = {k: v for k, v in m.__dict__.items() if k in m.__pydantic_fields_set__ or m.model_fields[k].default_factory}
var_kwargs = d.pop(self.v_kwargs_name, {})
if self.v_args_name in d:
args_: List[Any] = []
in_kwargs = False
kwargs = {}
for name, value in d.items():
if in_kwargs:
kwargs[name] = value
elif name == self.v_args_name:
args_ += value
in_kwargs = True
else:
args_.append(value)
return self.raw_function(*args_, **kwargs, **var_kwargs)
elif self.positional_only_args:
args_ = []
kwargs = {}
for name, value in d.items():
if name in self.positional_only_args:
args_.append(value)
else:
kwargs[name] = value
return self.raw_function(*args_, **kwargs, **var_kwargs)
else:
return self.raw_function(**d, **var_kwargs)
def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None:
pos_args = len(self.arg_mapping)
config_wrapper = _config.ConfigWrapper(config)
if config_wrapper.alias_generator:
raise PydanticUserError(
'Setting the "alias_generator" property on custom Config for '
'@validate_arguments is not yet supported, please remove.',
code=None,
)
if config_wrapper.extra is None:
config_wrapper.config_dict['extra'] = 'forbid'
class DecoratorBaseModel(BaseModel):
@field_validator(self.v_args_name, check_fields=False)
@classmethod
def check_args(cls, v: Optional[List[Any]]) -> Optional[List[Any]]:
if takes_args or v is None:
return v
raise TypeError(f'{pos_args} positional arguments expected but {pos_args + len(v)} given')
@field_validator(self.v_kwargs_name, check_fields=False)
@classmethod
def check_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
if takes_kwargs or v is None:
return v
plural = '' if len(v) == 1 else 's'
keys = ', '.join(map(repr, v.keys()))
raise TypeError(f'unexpected keyword argument{plural}: {keys}')
@field_validator(V_POSITIONAL_ONLY_NAME, check_fields=False)
@classmethod
def check_positional_only(cls, v: Optional[List[str]]) -> None:
if v is None:
return
plural = '' if len(v) == 1 else 's'
keys = ', '.join(map(repr, v))
raise TypeError(f'positional-only argument{plural} passed as keyword argument{plural}: {keys}')
@field_validator(V_DUPLICATE_KWARGS, check_fields=False)
@classmethod
def check_duplicate_kwargs(cls, v: Optional[List[str]]) -> None:
if v is None:
return
plural = '' if len(v) == 1 else 's'
keys = ', '.join(map(repr, v))
raise TypeError(f'multiple values for argument{plural}: {keys}')
model_config = config_wrapper.config_dict
self.model = create_model(to_pascal(self.raw_function.__name__), __base__=DecoratorBaseModel, **fields)
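A usage sketch for the deprecated decorator above (the decorated function is illustrative); arguments are coerced and validated against the generated model before the wrapped function runs:
from pydantic import validate_arguments  # deprecated shim; emits a DeprecationWarning

@validate_arguments
def repeat(s: str, count: int) -> str:
    return s * count

repeat('ab', '3')       # '3' is coerced to 3 -> 'ababab'
repeat('ab', count=2)   # keyword arguments are validated the same way
# repeat('ab', 'x')     # raises a ValidationError: 'x' is not a valid integer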

View file

@ -0,0 +1,140 @@
import datetime
import warnings
from collections import deque
from decimal import Decimal
from enum import Enum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from pathlib import Path
from re import Pattern
from types import GeneratorType
from typing import TYPE_CHECKING, Any, Callable, Dict, Type, Union
from uuid import UUID
from typing_extensions import deprecated
from ..color import Color
from ..networks import NameEmail
from ..types import SecretBytes, SecretStr
from ..warnings import PydanticDeprecatedSince20
if not TYPE_CHECKING:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
# and https://youtrack.jetbrains.com/issue/PY-51428
DeprecationWarning = PydanticDeprecatedSince20
__all__ = 'pydantic_encoder', 'custom_pydantic_encoder', 'timedelta_isoformat'
def isoformat(o: Union[datetime.date, datetime.time]) -> str:
return o.isoformat()
def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
"""Encodes a Decimal as int of there's no exponent, otherwise float.
This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
where a integer (but not int typed) is used. Encoding this as a float
results in failed round-tripping between encode and parse.
Our Id type is a prime example of this.
>>> decimal_encoder(Decimal("1.0"))
1.0
>>> decimal_encoder(Decimal("1"))
1
"""
exponent = dec_value.as_tuple().exponent
if isinstance(exponent, int) and exponent >= 0:
return int(dec_value)
else:
return float(dec_value)
ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
bytes: lambda o: o.decode(),
Color: str,
datetime.date: isoformat,
datetime.datetime: isoformat,
datetime.time: isoformat,
datetime.timedelta: lambda td: td.total_seconds(),
Decimal: decimal_encoder,
Enum: lambda o: o.value,
frozenset: list,
deque: list,
GeneratorType: list,
IPv4Address: str,
IPv4Interface: str,
IPv4Network: str,
IPv6Address: str,
IPv6Interface: str,
IPv6Network: str,
NameEmail: str,
Path: str,
Pattern: lambda o: o.pattern,
SecretBytes: str,
SecretStr: str,
set: list,
UUID: str,
}
@deprecated(
'`pydantic_encoder` is deprecated, use `pydantic_core.to_jsonable_python` instead.',
category=None,
)
def pydantic_encoder(obj: Any) -> Any:
warnings.warn(
'`pydantic_encoder` is deprecated, use `pydantic_core.to_jsonable_python` instead.',
category=PydanticDeprecatedSince20,
stacklevel=2,
)
from dataclasses import asdict, is_dataclass
from ..main import BaseModel
if isinstance(obj, BaseModel):
return obj.model_dump()
elif is_dataclass(obj):
return asdict(obj)
# Check the class type and its superclasses for a matching encoder
for base in obj.__class__.__mro__[:-1]:
try:
encoder = ENCODERS_BY_TYPE[base]
except KeyError:
continue
return encoder(obj)
else: # We have exited the for loop without finding a suitable encoder
raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable")
# TODO: Add a suggested migration path once there is a way to use custom encoders
@deprecated(
'`custom_pydantic_encoder` is deprecated, use `BaseModel.model_dump` instead.',
category=None,
)
def custom_pydantic_encoder(type_encoders: Dict[Any, Callable[[Type[Any]], Any]], obj: Any) -> Any:
warnings.warn(
'`custom_pydantic_encoder` is deprecated, use `BaseModel.model_dump` instead.',
category=PydanticDeprecatedSince20,
stacklevel=2,
)
# Check the class type and its superclasses for a matching encoder
for base in obj.__class__.__mro__[:-1]:
try:
encoder = type_encoders[base]
except KeyError:
continue
return encoder(obj)
else: # We have exited the for loop without finding a suitable encoder
return pydantic_encoder(obj)
@deprecated('`timedelta_isoformat` is deprecated.', category=None)
def timedelta_isoformat(td: datetime.timedelta) -> str:
"""ISO 8601 encoding for Python timedelta object."""
warnings.warn('`timedelta_isoformat` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2)
minutes, seconds = divmod(td.seconds, 60)
hours, minutes = divmod(minutes, 60)
return f'{"-" if td.days < 0 else ""}P{abs(td.days)}DT{hours:d}H{minutes:d}M{seconds:d}.{td.microseconds:06d}S'

View file

@ -0,0 +1,80 @@
from __future__ import annotations
import json
import pickle
import warnings
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable
from typing_extensions import deprecated
from ..warnings import PydanticDeprecatedSince20
if not TYPE_CHECKING:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
# and https://youtrack.jetbrains.com/issue/PY-51428
DeprecationWarning = PydanticDeprecatedSince20
class Protocol(str, Enum):
json = 'json'
pickle = 'pickle'
@deprecated('`load_str_bytes` is deprecated.', category=None)
def load_str_bytes(
b: str | bytes,
*,
content_type: str | None = None,
encoding: str = 'utf8',
proto: Protocol | None = None,
allow_pickle: bool = False,
json_loads: Callable[[str], Any] = json.loads,
) -> Any:
warnings.warn('`load_str_bytes` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2)
if proto is None and content_type:
if content_type.endswith(('json', 'javascript')):
pass
elif allow_pickle and content_type.endswith('pickle'):
proto = Protocol.pickle
else:
raise TypeError(f'Unknown content-type: {content_type}')
proto = proto or Protocol.json
if proto == Protocol.json:
if isinstance(b, bytes):
b = b.decode(encoding)
return json_loads(b) # type: ignore
elif proto == Protocol.pickle:
if not allow_pickle:
raise RuntimeError('Trying to decode with pickle with allow_pickle=False')
bb = b if isinstance(b, bytes) else b.encode() # type: ignore
return pickle.loads(bb)
else:
raise TypeError(f'Unknown protocol: {proto}')
@deprecated('`load_file` is deprecated.', category=None)
def load_file(
path: str | Path,
*,
content_type: str | None = None,
encoding: str = 'utf8',
proto: Protocol | None = None,
allow_pickle: bool = False,
json_loads: Callable[[str], Any] = json.loads,
) -> Any:
warnings.warn('`load_file` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2)
path = Path(path)
b = path.read_bytes()
if content_type is None:
if path.suffix in ('.js', '.json'):
proto = Protocol.json
elif path.suffix == '.pkl':
proto = Protocol.pickle
return load_str_bytes(
b, proto=proto, content_type=content_type, encoding=encoding, allow_pickle=allow_pickle, json_loads=json_loads
)
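A brief sketch of the deprecated loaders above (both emit a DeprecationWarning); JSON is the default protocol and pickle must be opted into explicitly:
payload = load_str_bytes('{"a": 1}', content_type='application/json')
assert payload == {'a': 1}

# load_str_bytes(pickled_bytes, proto=Protocol.pickle)                     # RuntimeError: allow_pickle=False
# load_str_bytes(pickled_bytes, proto=Protocol.pickle, allow_pickle=True)  # ok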

View file

@ -0,0 +1,103 @@
from __future__ import annotations
import json
import warnings
from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar, Union
from typing_extensions import deprecated
from ..json_schema import DEFAULT_REF_TEMPLATE, GenerateJsonSchema
from ..type_adapter import TypeAdapter
from ..warnings import PydanticDeprecatedSince20
if not TYPE_CHECKING:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
# and https://youtrack.jetbrains.com/issue/PY-51428
DeprecationWarning = PydanticDeprecatedSince20
__all__ = 'parse_obj_as', 'schema_of', 'schema_json_of'
NameFactory = Union[str, Callable[[Type[Any]], str]]
T = TypeVar('T')
@deprecated(
'`parse_obj_as` is deprecated. Use `pydantic.TypeAdapter.validate_python` instead.',
category=None,
)
def parse_obj_as(type_: type[T], obj: Any, type_name: NameFactory | None = None) -> T:
warnings.warn(
'`parse_obj_as` is deprecated. Use `pydantic.TypeAdapter.validate_python` instead.',
category=PydanticDeprecatedSince20,
stacklevel=2,
)
if type_name is not None: # pragma: no cover
warnings.warn(
'The type_name parameter is deprecated. parse_obj_as no longer creates temporary models',
DeprecationWarning,
stacklevel=2,
)
return TypeAdapter(type_).validate_python(obj)
@deprecated(
'`schema_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
category=None,
)
def schema_of(
type_: Any,
*,
title: NameFactory | None = None,
by_alias: bool = True,
ref_template: str = DEFAULT_REF_TEMPLATE,
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
) -> dict[str, Any]:
"""Generate a JSON schema (as dict) for the passed model or dynamically generated one."""
warnings.warn(
'`schema_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
category=PydanticDeprecatedSince20,
stacklevel=2,
)
res = TypeAdapter(type_).json_schema(
by_alias=by_alias,
schema_generator=schema_generator,
ref_template=ref_template,
)
if title is not None:
if isinstance(title, str):
res['title'] = title
else:
warnings.warn(
'Passing a callable for the `title` parameter is deprecated and no longer supported',
DeprecationWarning,
stacklevel=2,
)
res['title'] = title(type_)
return res
@deprecated(
'`schema_json_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
category=None,
)
def schema_json_of(
type_: Any,
*,
title: NameFactory | None = None,
by_alias: bool = True,
ref_template: str = DEFAULT_REF_TEMPLATE,
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
**dumps_kwargs: Any,
) -> str:
"""Generate a JSON schema (as JSON) for the passed model or dynamically generated one."""
warnings.warn(
'`schema_json_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.',
category=PydanticDeprecatedSince20,
stacklevel=2,
)
return json.dumps(
schema_of(type_, title=title, by_alias=by_alias, ref_template=ref_template, schema_generator=schema_generator),
**dumps_kwargs,
)
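Usage sketch for the deprecated helpers above (each call emits a DeprecationWarning):
from typing import List

parse_obj_as(List[int], ['1', 2, '3'])       # -> [1, 2, 3]
schema_of(List[int], title='Ints')['title']  # -> 'Ints'
schema_json_of(List[int], indent=2)          # JSON schema for a list of integers, pretty-printed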

View file

@ -1,346 +1,4 @@
import os """The `env_settings` module is a backport module from V1."""
import warnings from ._migration import getattr_migration
from pathlib import Path
from typing import AbstractSet, Any, Callable, ClassVar, Dict, List, Mapping, Optional, Tuple, Type, Union
from .config import BaseConfig, Extra __getattr__ = getattr_migration(__name__)
from .fields import ModelField
from .main import BaseModel
from .typing import StrPath, display_as_type, get_origin, is_union
from .utils import deep_update, path_type, sequence_like
env_file_sentinel = str(object())
SettingsSourceCallable = Callable[['BaseSettings'], Dict[str, Any]]
DotenvType = Union[StrPath, List[StrPath], Tuple[StrPath, ...]]
class SettingsError(ValueError):
pass
class BaseSettings(BaseModel):
"""
Base class for settings, allowing values to be overridden by environment variables.
This is useful in production for secrets you do not wish to save in code; it plays nicely with docker(-compose),
Heroku and any 12 factor app design.
"""
def __init__(
__pydantic_self__,
_env_file: Optional[DotenvType] = env_file_sentinel,
_env_file_encoding: Optional[str] = None,
_env_nested_delimiter: Optional[str] = None,
_secrets_dir: Optional[StrPath] = None,
**values: Any,
) -> None:
# Uses something other than `self` as the first arg to allow "self" as a settable attribute
super().__init__(
**__pydantic_self__._build_values(
values,
_env_file=_env_file,
_env_file_encoding=_env_file_encoding,
_env_nested_delimiter=_env_nested_delimiter,
_secrets_dir=_secrets_dir,
)
)
def _build_values(
self,
init_kwargs: Dict[str, Any],
_env_file: Optional[DotenvType] = None,
_env_file_encoding: Optional[str] = None,
_env_nested_delimiter: Optional[str] = None,
_secrets_dir: Optional[StrPath] = None,
) -> Dict[str, Any]:
# Configure built-in sources
init_settings = InitSettingsSource(init_kwargs=init_kwargs)
env_settings = EnvSettingsSource(
env_file=(_env_file if _env_file != env_file_sentinel else self.__config__.env_file),
env_file_encoding=(
_env_file_encoding if _env_file_encoding is not None else self.__config__.env_file_encoding
),
env_nested_delimiter=(
_env_nested_delimiter if _env_nested_delimiter is not None else self.__config__.env_nested_delimiter
),
env_prefix_len=len(self.__config__.env_prefix),
)
file_secret_settings = SecretsSettingsSource(secrets_dir=_secrets_dir or self.__config__.secrets_dir)
# Provide a hook to set built-in sources priority and add / remove sources
sources = self.__config__.customise_sources(
init_settings=init_settings, env_settings=env_settings, file_secret_settings=file_secret_settings
)
if sources:
return deep_update(*reversed([source(self) for source in sources]))
else:
# no one should mean to do this, but I think returning an empty dict is marginally preferable
# to an informative error and much better than a confusing error
return {}
class Config(BaseConfig):
env_prefix: str = ''
env_file: Optional[DotenvType] = None
env_file_encoding: Optional[str] = None
env_nested_delimiter: Optional[str] = None
secrets_dir: Optional[StrPath] = None
validate_all: bool = True
extra: Extra = Extra.forbid
arbitrary_types_allowed: bool = True
case_sensitive: bool = False
@classmethod
def prepare_field(cls, field: ModelField) -> None:
env_names: Union[List[str], AbstractSet[str]]
field_info_from_config = cls.get_field_info(field.name)
env = field_info_from_config.get('env') or field.field_info.extra.get('env')
if env is None:
if field.has_alias:
warnings.warn(
'aliases are no longer used by BaseSettings to define which environment variables to read. '
'Instead use the "env" field setting. '
'See https://pydantic-docs.helpmanual.io/usage/settings/#environment-variable-names',
FutureWarning,
)
env_names = {cls.env_prefix + field.name}
elif isinstance(env, str):
env_names = {env}
elif isinstance(env, (set, frozenset)):
env_names = env
elif sequence_like(env):
env_names = list(env)
else:
raise TypeError(f'invalid field env: {env!r} ({display_as_type(env)}); should be string, list or set')
if not cls.case_sensitive:
env_names = env_names.__class__(n.lower() for n in env_names)
field.field_info.extra['env_names'] = env_names
@classmethod
def customise_sources(
cls,
init_settings: SettingsSourceCallable,
env_settings: SettingsSourceCallable,
file_secret_settings: SettingsSourceCallable,
) -> Tuple[SettingsSourceCallable, ...]:
return init_settings, env_settings, file_secret_settings
@classmethod
def parse_env_var(cls, field_name: str, raw_val: str) -> Any:
return cls.json_loads(raw_val)
# populated by the metaclass using the Config class defined above, annotated here to help IDEs only
__config__: ClassVar[Type[Config]]
class InitSettingsSource:
__slots__ = ('init_kwargs',)
def __init__(self, init_kwargs: Dict[str, Any]):
self.init_kwargs = init_kwargs
def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
return self.init_kwargs
def __repr__(self) -> str:
return f'InitSettingsSource(init_kwargs={self.init_kwargs!r})'
class EnvSettingsSource:
__slots__ = ('env_file', 'env_file_encoding', 'env_nested_delimiter', 'env_prefix_len')
def __init__(
self,
env_file: Optional[DotenvType],
env_file_encoding: Optional[str],
env_nested_delimiter: Optional[str] = None,
env_prefix_len: int = 0,
):
self.env_file: Optional[DotenvType] = env_file
self.env_file_encoding: Optional[str] = env_file_encoding
self.env_nested_delimiter: Optional[str] = env_nested_delimiter
self.env_prefix_len: int = env_prefix_len
def __call__(self, settings: BaseSettings) -> Dict[str, Any]: # noqa C901
"""
Build environment variables suitable for passing to the Model.
"""
d: Dict[str, Any] = {}
if settings.__config__.case_sensitive:
env_vars: Mapping[str, Optional[str]] = os.environ
else:
env_vars = {k.lower(): v for k, v in os.environ.items()}
dotenv_vars = self._read_env_files(settings.__config__.case_sensitive)
if dotenv_vars:
env_vars = {**dotenv_vars, **env_vars}
for field in settings.__fields__.values():
env_val: Optional[str] = None
for env_name in field.field_info.extra['env_names']:
env_val = env_vars.get(env_name)
if env_val is not None:
break
is_complex, allow_parse_failure = self.field_is_complex(field)
if is_complex:
if env_val is None:
# field is complex but no value found so far, try explode_env_vars
env_val_built = self.explode_env_vars(field, env_vars)
if env_val_built:
d[field.alias] = env_val_built
else:
# field is complex and there's a value, decode that as JSON, then add explode_env_vars
try:
env_val = settings.__config__.parse_env_var(field.name, env_val)
except ValueError as e:
if not allow_parse_failure:
raise SettingsError(f'error parsing env var "{env_name}"') from e
if isinstance(env_val, dict):
d[field.alias] = deep_update(env_val, self.explode_env_vars(field, env_vars))
else:
d[field.alias] = env_val
elif env_val is not None:
# simplest case, field is not complex, we only need to add the value if it was found
d[field.alias] = env_val
return d
def _read_env_files(self, case_sensitive: bool) -> Dict[str, Optional[str]]:
env_files = self.env_file
if env_files is None:
return {}
if isinstance(env_files, (str, os.PathLike)):
env_files = [env_files]
dotenv_vars = {}
for env_file in env_files:
env_path = Path(env_file).expanduser()
if env_path.is_file():
dotenv_vars.update(
read_env_file(env_path, encoding=self.env_file_encoding, case_sensitive=case_sensitive)
)
return dotenv_vars
def field_is_complex(self, field: ModelField) -> Tuple[bool, bool]:
"""
Find out if a field is complex, and if so whether JSON errors should be ignored
"""
if field.is_complex():
allow_parse_failure = False
elif is_union(get_origin(field.type_)) and field.sub_fields and any(f.is_complex() for f in field.sub_fields):
allow_parse_failure = True
else:
return False, False
return True, allow_parse_failure
def explode_env_vars(self, field: ModelField, env_vars: Mapping[str, Optional[str]]) -> Dict[str, Any]:
"""
Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.
This is applied to a single field, hence filtering by env_var prefix.
"""
prefixes = [f'{env_name}{self.env_nested_delimiter}' for env_name in field.field_info.extra['env_names']]
result: Dict[str, Any] = {}
for env_name, env_val in env_vars.items():
if not any(env_name.startswith(prefix) for prefix in prefixes):
continue
# we remove the prefix before splitting in case the prefix has characters in common with the delimiter
env_name_without_prefix = env_name[self.env_prefix_len :]
_, *keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter)
env_var = result
for key in keys:
env_var = env_var.setdefault(key, {})
env_var[last_key] = env_val
return result
def __repr__(self) -> str:
return (
f'EnvSettingsSource(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, '
f'env_nested_delimiter={self.env_nested_delimiter!r})'
)
class SecretsSettingsSource:
__slots__ = ('secrets_dir',)
def __init__(self, secrets_dir: Optional[StrPath]):
self.secrets_dir: Optional[StrPath] = secrets_dir
def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
"""
Build fields from "secrets" files.
"""
secrets: Dict[str, Optional[str]] = {}
if self.secrets_dir is None:
return secrets
secrets_path = Path(self.secrets_dir).expanduser()
if not secrets_path.exists():
warnings.warn(f'directory "{secrets_path}" does not exist')
return secrets
if not secrets_path.is_dir():
raise SettingsError(f'secrets_dir must reference a directory, not a {path_type(secrets_path)}')
for field in settings.__fields__.values():
for env_name in field.field_info.extra['env_names']:
path = find_case_path(secrets_path, env_name, settings.__config__.case_sensitive)
if not path:
# path does not exist, we currently don't return a warning for this
continue
if path.is_file():
secret_value = path.read_text().strip()
if field.is_complex():
try:
secret_value = settings.__config__.parse_env_var(field.name, secret_value)
except ValueError as e:
raise SettingsError(f'error parsing env var "{env_name}"') from e
secrets[field.alias] = secret_value
else:
warnings.warn(
f'attempted to load secret file "{path}" but found a {path_type(path)} instead.',
stacklevel=4,
)
return secrets
def __repr__(self) -> str:
return f'SecretsSettingsSource(secrets_dir={self.secrets_dir!r})'
def read_env_file(
file_path: StrPath, *, encoding: str = None, case_sensitive: bool = False
) -> Dict[str, Optional[str]]:
try:
from dotenv import dotenv_values
except ImportError as e:
raise ImportError('python-dotenv is not installed, run `pip install pydantic[dotenv]`') from e
file_vars: Dict[str, Optional[str]] = dotenv_values(file_path, encoding=encoding or 'utf8')
if not case_sensitive:
return {k.lower(): v for k, v in file_vars.items()}
else:
return file_vars
def find_case_path(dir_path: Path, file_name: str, case_sensitive: bool) -> Optional[Path]:
"""
Find a file within path's directory matching filename, optionally ignoring case.
"""
for f in dir_path.iterdir():
if f.name == file_name:
return f
elif not case_sensitive and f.name.lower() == file_name.lower():
return f
return None
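A sketch of the V1-era behavior implemented above (settings class and environment variable are illustrative); in Pydantic V2 this functionality moved to the separate `pydantic-settings` package:
import os

class AppSettings(BaseSettings):
    api_key: str = ''

    class Config:
        env_prefix = 'MYAPP_'

os.environ['MYAPP_API_KEY'] = 'secret'
assert AppSettings().api_key == 'secret'   # value overridden from the environment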

View file

@ -1,162 +1,4 @@
import json """The `error_wrappers` module is a backport module from V1."""
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple, Type, Union from ._migration import getattr_migration
from .json import pydantic_encoder __getattr__ = getattr_migration(__name__)
from .utils import Representation
if TYPE_CHECKING:
from typing_extensions import TypedDict
from .config import BaseConfig
from .types import ModelOrDc
from .typing import ReprArgs
Loc = Tuple[Union[int, str], ...]
class _ErrorDictRequired(TypedDict):
loc: Loc
msg: str
type: str
class ErrorDict(_ErrorDictRequired, total=False):
ctx: Dict[str, Any]
__all__ = 'ErrorWrapper', 'ValidationError'
class ErrorWrapper(Representation):
__slots__ = 'exc', '_loc'
def __init__(self, exc: Exception, loc: Union[str, 'Loc']) -> None:
self.exc = exc
self._loc = loc
def loc_tuple(self) -> 'Loc':
if isinstance(self._loc, tuple):
return self._loc
else:
return (self._loc,)
def __repr_args__(self) -> 'ReprArgs':
return [('exc', self.exc), ('loc', self.loc_tuple())]
# ErrorList is something like Union[List[Union[List[ErrorWrapper], ErrorWrapper]], ErrorWrapper]
# but recursive, therefore just use:
ErrorList = Union[Sequence[Any], ErrorWrapper]
class ValidationError(Representation, ValueError):
__slots__ = 'raw_errors', 'model', '_error_cache'
def __init__(self, errors: Sequence[ErrorList], model: 'ModelOrDc') -> None:
self.raw_errors = errors
self.model = model
self._error_cache: Optional[List['ErrorDict']] = None
def errors(self) -> List['ErrorDict']:
if self._error_cache is None:
try:
config = self.model.__config__ # type: ignore
except AttributeError:
config = self.model.__pydantic_model__.__config__ # type: ignore
self._error_cache = list(flatten_errors(self.raw_errors, config))
return self._error_cache
def json(self, *, indent: Union[None, int, str] = 2) -> str:
return json.dumps(self.errors(), indent=indent, default=pydantic_encoder)
def __str__(self) -> str:
errors = self.errors()
no_errors = len(errors)
return (
f'{no_errors} validation error{"" if no_errors == 1 else "s"} for {self.model.__name__}\n'
f'{display_errors(errors)}'
)
def __repr_args__(self) -> 'ReprArgs':
return [('model', self.model.__name__), ('errors', self.errors())]
def display_errors(errors: List['ErrorDict']) -> str:
return '\n'.join(f'{_display_error_loc(e)}\n {e["msg"]} ({_display_error_type_and_ctx(e)})' for e in errors)
def _display_error_loc(error: 'ErrorDict') -> str:
return ' -> '.join(str(e) for e in error['loc'])
def _display_error_type_and_ctx(error: 'ErrorDict') -> str:
t = 'type=' + error['type']
ctx = error.get('ctx')
if ctx:
return t + ''.join(f'; {k}={v}' for k, v in ctx.items())
else:
return t
def flatten_errors(
errors: Sequence[Any], config: Type['BaseConfig'], loc: Optional['Loc'] = None
) -> Generator['ErrorDict', None, None]:
for error in errors:
if isinstance(error, ErrorWrapper):
if loc:
error_loc = loc + error.loc_tuple()
else:
error_loc = error.loc_tuple()
if isinstance(error.exc, ValidationError):
yield from flatten_errors(error.exc.raw_errors, config, error_loc)
else:
yield error_dict(error.exc, config, error_loc)
elif isinstance(error, list):
yield from flatten_errors(error, config, loc=loc)
else:
raise RuntimeError(f'Unknown error object: {error}')
def error_dict(exc: Exception, config: Type['BaseConfig'], loc: 'Loc') -> 'ErrorDict':
type_ = get_exc_type(exc.__class__)
msg_template = config.error_msg_templates.get(type_) or getattr(exc, 'msg_template', None)
ctx = exc.__dict__
if msg_template:
msg = msg_template.format(**ctx)
else:
msg = str(exc)
d: 'ErrorDict' = {'loc': loc, 'msg': msg, 'type': type_}
if ctx:
d['ctx'] = ctx
return d
_EXC_TYPE_CACHE: Dict[Type[Exception], str] = {}
def get_exc_type(cls: Type[Exception]) -> str:
# slightly more efficient than using lru_cache since we don't need to worry about the cache filling up
try:
return _EXC_TYPE_CACHE[cls]
except KeyError:
r = _get_exc_type(cls)
_EXC_TYPE_CACHE[cls] = r
return r
def _get_exc_type(cls: Type[Exception]) -> str:
if issubclass(cls, AssertionError):
return 'assertion_error'
base_name = 'type_error' if issubclass(cls, TypeError) else 'value_error'
if cls in (TypeError, ValueError):
# just TypeError or ValueError, no extra code
return base_name
# if it's not a TypeError or ValueError, we just take the lowercase of the exception name
# no chaining or snake case logic, use "code" for more complex error types.
code = getattr(cls, 'code', None) or cls.__name__.replace('Error', '').lower()
return base_name + '.' + code
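Worked examples of the error-code derivation above (the ValueError subclass is illustrative):
class LuhnValidationError(ValueError):
    code = 'payment_card_number.luhn_check'

get_exc_type(AssertionError)       # -> 'assertion_error'
get_exc_type(TypeError)            # -> 'type_error'
get_exc_type(LuhnValidationError)  # -> 'value_error.payment_card_number.luhn_check'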

View file

@ -1,646 +1,152 @@
"""Pydantic-specific errors."""
from __future__ import annotations as _annotations
import re
from typing_extensions import Literal, Self
from ._migration import getattr_migration
from .version import version_short
__all__ = (
    'PydanticUserError',
    'PydanticUndefinedAnnotation',
    'PydanticImportError',
    'PydanticSchemaGenerationError',
    'PydanticInvalidForJsonSchema',
    'PydanticErrorCodes',
)
# We use this URL to allow for future flexibility about how we host the docs, while allowing for Pydantic
# code in the wild with "old" URLs to still work.
# 'u' refers to "user errors" - e.g. errors caused by developers using pydantic, as opposed to validation errors.
DEV_ERROR_DOCS_URL = f'https://errors.pydantic.dev/{version_short()}/u/'
PydanticErrorCodes = Literal[
    'class-not-fully-defined',
    'custom-json-schema',
    'decorator-missing-field',
    'discriminator-no-field',
    'discriminator-alias-type',
    'discriminator-needs-literal',
    'discriminator-alias',
    'discriminator-validator',
    'callable-discriminator-no-tag',
    'typed-dict-version',
    'model-field-overridden',
    'model-field-missing-annotation',
    'config-both',
    'removed-kwargs',
    'invalid-for-json-schema',
    'json-schema-already-used',
    'base-model-instantiated',
    'undefined-annotation',
    'schema-for-unknown-type',
    'import-error',
    'create-model-field-definitions',
    'create-model-config-base',
    'validator-no-fields',
    'validator-invalid-fields',
    'validator-instance-method',
    'root-validator-pre-skip',
    'model-serializer-instance-method',
    'validator-field-config-info',
    'validator-v1-signature',
    'validator-signature',
    'field-serializer-signature',
    'model-serializer-signature',
    'multiple-field-serializers',
    'invalid_annotated_type',
    'type-adapter-config-unused',
    'root-model-extra',
    'unevaluable-type-annotation',
    'dataclass-init-false-extra-allow',
    'clashing-init-and-init-var',
]
class PydanticErrorMixin:
    """A mixin class for common functionality shared by all Pydantic-specific errors.
    Attributes:
        message: A message describing the error.
        code: An optional error code from PydanticErrorCodes enum.
    """
    def __init__(self, message: str, *, code: PydanticErrorCodes | None) -> None:
        self.message = message
        self.code = code
    def __str__(self) -> str:
        if self.code is None:
            return self.message
        else:
            return f'{self.message}\n\nFor further information visit {DEV_ERROR_DOCS_URL}{self.code}'
class PydanticUserError(PydanticErrorMixin, TypeError):
    """An error raised due to incorrect use of Pydantic."""
class PydanticUndefinedAnnotation(PydanticErrorMixin, NameError):
    """A subclass of `NameError` raised when handling undefined annotations during `CoreSchema` generation.
    Attributes:
        name: Name of the error.
        message: Description of the error.
    """
    def __init__(self, name: str, message: str) -> None:
        self.name = name
        super().__init__(message=message, code='undefined-annotation')
    @classmethod
    def from_name_error(cls, name_error: NameError) -> Self:
        """Convert a `NameError` to a `PydanticUndefinedAnnotation` error.
        Args:
            name_error: `NameError` to be converted.
        Returns:
            Converted `PydanticUndefinedAnnotation` error.
        """
        try:
            name = name_error.name  # type: ignore  # python > 3.10
        except AttributeError:
            name = re.search(r".*'(.+?)'", str(name_error)).group(1)  # type: ignore[union-attr]
        return cls(name=name, message=str(name_error))
class PydanticImportError(PydanticErrorMixin, ImportError):
    """An error raised when an import fails due to module changes between V1 and V2.
    Attributes:
        message: Description of the error.
    """
    def __init__(self, message: str) -> None:
        super().__init__(message, code='import-error')
class PydanticSchemaGenerationError(PydanticUserError):
    """An error raised during failures to generate a `CoreSchema` for some type.
    Attributes:
        message: Description of the error.
    """
    def __init__(self, message: str) -> None:
        super().__init__(message, code='schema-for-unknown-type')
class PydanticInvalidForJsonSchema(PydanticUserError):
    """An error raised during failures to generate a JSON schema for some `CoreSchema`.
    Attributes:
        message: Description of the error.
    """
    def __init__(self, message: str) -> None:
        super().__init__(message, code='invalid-for-json-schema')
__getattr__ = getattr_migration(__name__)
from decimal import Decimal
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Sequence, Set, Tuple, Type, Union
from .typing import display_as_type
if TYPE_CHECKING:
    from .typing import DictStrAny
# explicitly state exports to avoid "from .errors import *" also importing Decimal, Path etc.
__all__ = (
'PydanticTypeError',
'PydanticValueError',
'ConfigError',
'MissingError',
'ExtraError',
'NoneIsNotAllowedError',
'NoneIsAllowedError',
'WrongConstantError',
'NotNoneError',
'BoolError',
'BytesError',
'DictError',
'EmailError',
'UrlError',
'UrlSchemeError',
'UrlSchemePermittedError',
'UrlUserInfoError',
'UrlHostError',
'UrlHostTldError',
'UrlPortError',
'UrlExtraError',
'EnumError',
'IntEnumError',
'EnumMemberError',
'IntegerError',
'FloatError',
'PathError',
'PathNotExistsError',
'PathNotAFileError',
'PathNotADirectoryError',
'PyObjectError',
'SequenceError',
'ListError',
'SetError',
'FrozenSetError',
'TupleError',
'TupleLengthError',
'ListMinLengthError',
'ListMaxLengthError',
'ListUniqueItemsError',
'SetMinLengthError',
'SetMaxLengthError',
'FrozenSetMinLengthError',
'FrozenSetMaxLengthError',
'AnyStrMinLengthError',
'AnyStrMaxLengthError',
'StrError',
'StrRegexError',
'NumberNotGtError',
'NumberNotGeError',
'NumberNotLtError',
'NumberNotLeError',
'NumberNotMultipleError',
'DecimalError',
'DecimalIsNotFiniteError',
'DecimalMaxDigitsError',
'DecimalMaxPlacesError',
'DecimalWholeDigitsError',
'DateTimeError',
'DateError',
'DateNotInThePastError',
'DateNotInTheFutureError',
'TimeError',
'DurationError',
'HashableError',
'UUIDError',
'UUIDVersionError',
'ArbitraryTypeError',
'ClassError',
'SubclassError',
'JsonError',
'JsonTypeError',
'PatternError',
'DataclassTypeError',
'CallableError',
'IPvAnyAddressError',
'IPvAnyInterfaceError',
'IPvAnyNetworkError',
'IPv4AddressError',
'IPv6AddressError',
'IPv4NetworkError',
'IPv6NetworkError',
'IPv4InterfaceError',
'IPv6InterfaceError',
'ColorError',
'StrictBoolError',
'NotDigitError',
'LuhnValidationError',
'InvalidLengthForBrand',
'InvalidByteSize',
'InvalidByteSizeUnit',
'MissingDiscriminator',
'InvalidDiscriminator',
)
def cls_kwargs(cls: Type['PydanticErrorMixin'], ctx: 'DictStrAny') -> 'PydanticErrorMixin':
    """
    For built-in exceptions like ValueError or TypeError, we need to implement
    __reduce__ to override the default behaviour (instead of __getstate__/__setstate__)
    By default pickle protocol 2 calls `cls.__new__(cls, *args)`.
    Since we only use kwargs, we need a little constructor to change that.
    Note: the callable can't be a lambda as pickle looks in the namespace to find it
    """
    return cls(**ctx)
class PydanticErrorMixin:
    code: str
    msg_template: str
    def __init__(self, **ctx: Any) -> None:
        self.__dict__ = ctx
    def __str__(self) -> str:
        return self.msg_template.format(**self.__dict__)
    def __reduce__(self) -> Tuple[Callable[..., 'PydanticErrorMixin'], Tuple[Type['PydanticErrorMixin'], 'DictStrAny']]:
        return cls_kwargs, (self.__class__, self.__dict__)
class PydanticTypeError(PydanticErrorMixin, TypeError):
    pass
class PydanticValueError(PydanticErrorMixin, ValueError):
    pass
class ConfigError(RuntimeError):
    pass
class MissingError(PydanticValueError):
    msg_template = 'field required'
class ExtraError(PydanticValueError):
    msg_template = 'extra fields not permitted'
class NoneIsNotAllowedError(PydanticTypeError):
    code = 'none.not_allowed'
    msg_template = 'none is not an allowed value'
class NoneIsAllowedError(PydanticTypeError):
    code = 'none.allowed'
    msg_template = 'value is not none'
class WrongConstantError(PydanticValueError):
    code = 'const'
    def __str__(self) -> str:
        permitted = ', '.join(repr(v) for v in self.permitted)  # type: ignore
        return f'unexpected value; permitted: {permitted}'
class NotNoneError(PydanticTypeError):
    code = 'not_none'
    msg_template = 'value is not None'
class BoolError(PydanticTypeError):
    msg_template = 'value could not be parsed to a boolean'
class BytesError(PydanticTypeError):
msg_template = 'byte type expected'
class DictError(PydanticTypeError):
msg_template = 'value is not a valid dict'
class EmailError(PydanticValueError):
msg_template = 'value is not a valid email address'
class UrlError(PydanticValueError):
code = 'url'
class UrlSchemeError(UrlError):
code = 'url.scheme'
msg_template = 'invalid or missing URL scheme'
class UrlSchemePermittedError(UrlError):
code = 'url.scheme'
msg_template = 'URL scheme not permitted'
def __init__(self, allowed_schemes: Set[str]):
super().__init__(allowed_schemes=allowed_schemes)
class UrlUserInfoError(UrlError):
code = 'url.userinfo'
msg_template = 'userinfo required in URL but missing'
class UrlHostError(UrlError):
code = 'url.host'
msg_template = 'URL host invalid'
class UrlHostTldError(UrlError):
code = 'url.host'
msg_template = 'URL host invalid, top level domain required'
class UrlPortError(UrlError):
code = 'url.port'
msg_template = 'URL port invalid, port cannot exceed 65535'
class UrlExtraError(UrlError):
code = 'url.extra'
msg_template = 'URL invalid, extra characters found after valid URL: {extra!r}'
class EnumMemberError(PydanticTypeError):
code = 'enum'
def __str__(self) -> str:
permitted = ', '.join(repr(v.value) for v in self.enum_values) # type: ignore
return f'value is not a valid enumeration member; permitted: {permitted}'
class IntegerError(PydanticTypeError):
msg_template = 'value is not a valid integer'
class FloatError(PydanticTypeError):
msg_template = 'value is not a valid float'
class PathError(PydanticTypeError):
msg_template = 'value is not a valid path'
class _PathValueError(PydanticValueError):
def __init__(self, *, path: Path) -> None:
super().__init__(path=str(path))
class PathNotExistsError(_PathValueError):
code = 'path.not_exists'
msg_template = 'file or directory at path "{path}" does not exist'
class PathNotAFileError(_PathValueError):
code = 'path.not_a_file'
msg_template = 'path "{path}" does not point to a file'
class PathNotADirectoryError(_PathValueError):
code = 'path.not_a_directory'
msg_template = 'path "{path}" does not point to a directory'
class PyObjectError(PydanticTypeError):
msg_template = 'ensure this value contains valid import path or valid callable: {error_message}'
class SequenceError(PydanticTypeError):
msg_template = 'value is not a valid sequence'
class IterableError(PydanticTypeError):
msg_template = 'value is not a valid iterable'
class ListError(PydanticTypeError):
msg_template = 'value is not a valid list'
class SetError(PydanticTypeError):
msg_template = 'value is not a valid set'
class FrozenSetError(PydanticTypeError):
msg_template = 'value is not a valid frozenset'
class DequeError(PydanticTypeError):
msg_template = 'value is not a valid deque'
class TupleError(PydanticTypeError):
msg_template = 'value is not a valid tuple'
class TupleLengthError(PydanticValueError):
code = 'tuple.length'
msg_template = 'wrong tuple length {actual_length}, expected {expected_length}'
def __init__(self, *, actual_length: int, expected_length: int) -> None:
super().__init__(actual_length=actual_length, expected_length=expected_length)
class ListMinLengthError(PydanticValueError):
code = 'list.min_items'
msg_template = 'ensure this value has at least {limit_value} items'
def __init__(self, *, limit_value: int) -> None:
super().__init__(limit_value=limit_value)
class ListMaxLengthError(PydanticValueError):
code = 'list.max_items'
msg_template = 'ensure this value has at most {limit_value} items'
def __init__(self, *, limit_value: int) -> None:
super().__init__(limit_value=limit_value)
class ListUniqueItemsError(PydanticValueError):
code = 'list.unique_items'
msg_template = 'the list has duplicated items'
class SetMinLengthError(PydanticValueError):
code = 'set.min_items'
msg_template = 'ensure this value has at least {limit_value} items'
def __init__(self, *, limit_value: int) -> None:
super().__init__(limit_value=limit_value)
class SetMaxLengthError(PydanticValueError):
code = 'set.max_items'
msg_template = 'ensure this value has at most {limit_value} items'
def __init__(self, *, limit_value: int) -> None:
super().__init__(limit_value=limit_value)
class FrozenSetMinLengthError(PydanticValueError):
code = 'frozenset.min_items'
msg_template = 'ensure this value has at least {limit_value} items'
def __init__(self, *, limit_value: int) -> None:
super().__init__(limit_value=limit_value)
class FrozenSetMaxLengthError(PydanticValueError):
code = 'frozenset.max_items'
msg_template = 'ensure this value has at most {limit_value} items'
def __init__(self, *, limit_value: int) -> None:
super().__init__(limit_value=limit_value)
class AnyStrMinLengthError(PydanticValueError):
code = 'any_str.min_length'
msg_template = 'ensure this value has at least {limit_value} characters'
def __init__(self, *, limit_value: int) -> None:
super().__init__(limit_value=limit_value)
class AnyStrMaxLengthError(PydanticValueError):
code = 'any_str.max_length'
msg_template = 'ensure this value has at most {limit_value} characters'
def __init__(self, *, limit_value: int) -> None:
super().__init__(limit_value=limit_value)
class StrError(PydanticTypeError):
msg_template = 'str type expected'
class StrRegexError(PydanticValueError):
code = 'str.regex'
msg_template = 'string does not match regex "{pattern}"'
def __init__(self, *, pattern: str) -> None:
super().__init__(pattern=pattern)
class _NumberBoundError(PydanticValueError):
def __init__(self, *, limit_value: Union[int, float, Decimal]) -> None:
super().__init__(limit_value=limit_value)
class NumberNotGtError(_NumberBoundError):
code = 'number.not_gt'
msg_template = 'ensure this value is greater than {limit_value}'
class NumberNotGeError(_NumberBoundError):
code = 'number.not_ge'
msg_template = 'ensure this value is greater than or equal to {limit_value}'
class NumberNotLtError(_NumberBoundError):
code = 'number.not_lt'
msg_template = 'ensure this value is less than {limit_value}'
class NumberNotLeError(_NumberBoundError):
code = 'number.not_le'
msg_template = 'ensure this value is less than or equal to {limit_value}'
class NumberNotFiniteError(PydanticValueError):
code = 'number.not_finite_number'
msg_template = 'ensure this value is a finite number'
class NumberNotMultipleError(PydanticValueError):
code = 'number.not_multiple'
msg_template = 'ensure this value is a multiple of {multiple_of}'
def __init__(self, *, multiple_of: Union[int, float, Decimal]) -> None:
super().__init__(multiple_of=multiple_of)
class DecimalError(PydanticTypeError):
msg_template = 'value is not a valid decimal'
class DecimalIsNotFiniteError(PydanticValueError):
code = 'decimal.not_finite'
msg_template = 'value is not a valid decimal'
class DecimalMaxDigitsError(PydanticValueError):
code = 'decimal.max_digits'
msg_template = 'ensure that there are no more than {max_digits} digits in total'
def __init__(self, *, max_digits: int) -> None:
super().__init__(max_digits=max_digits)
class DecimalMaxPlacesError(PydanticValueError):
code = 'decimal.max_places'
msg_template = 'ensure that there are no more than {decimal_places} decimal places'
def __init__(self, *, decimal_places: int) -> None:
super().__init__(decimal_places=decimal_places)
class DecimalWholeDigitsError(PydanticValueError):
code = 'decimal.whole_digits'
msg_template = 'ensure that there are no more than {whole_digits} digits before the decimal point'
def __init__(self, *, whole_digits: int) -> None:
super().__init__(whole_digits=whole_digits)
class DateTimeError(PydanticValueError):
msg_template = 'invalid datetime format'
class DateError(PydanticValueError):
msg_template = 'invalid date format'
class DateNotInThePastError(PydanticValueError):
code = 'date.not_in_the_past'
msg_template = 'date is not in the past'
class DateNotInTheFutureError(PydanticValueError):
code = 'date.not_in_the_future'
msg_template = 'date is not in the future'
class TimeError(PydanticValueError):
msg_template = 'invalid time format'
class DurationError(PydanticValueError):
msg_template = 'invalid duration format'
class HashableError(PydanticTypeError):
msg_template = 'value is not a valid hashable'
class UUIDError(PydanticTypeError):
msg_template = 'value is not a valid uuid'
class UUIDVersionError(PydanticValueError):
code = 'uuid.version'
msg_template = 'uuid version {required_version} expected'
def __init__(self, *, required_version: int) -> None:
super().__init__(required_version=required_version)
class ArbitraryTypeError(PydanticTypeError):
code = 'arbitrary_type'
msg_template = 'instance of {expected_arbitrary_type} expected'
def __init__(self, *, expected_arbitrary_type: Type[Any]) -> None:
super().__init__(expected_arbitrary_type=display_as_type(expected_arbitrary_type))
class ClassError(PydanticTypeError):
code = 'class'
msg_template = 'a class is expected'
class SubclassError(PydanticTypeError):
code = 'subclass'
msg_template = 'subclass of {expected_class} expected'
def __init__(self, *, expected_class: Type[Any]) -> None:
super().__init__(expected_class=display_as_type(expected_class))
class JsonError(PydanticValueError):
msg_template = 'Invalid JSON'
class JsonTypeError(PydanticTypeError):
code = 'json'
msg_template = 'JSON object must be str, bytes or bytearray'
class PatternError(PydanticValueError):
code = 'regex_pattern'
msg_template = 'Invalid regular expression'
class DataclassTypeError(PydanticTypeError):
code = 'dataclass'
msg_template = 'instance of {class_name}, tuple or dict expected'
class CallableError(PydanticTypeError):
msg_template = '{value} is not callable'
class EnumError(PydanticTypeError):
code = 'enum_instance'
msg_template = '{value} is not a valid Enum instance'
class IntEnumError(PydanticTypeError):
code = 'int_enum_instance'
msg_template = '{value} is not a valid IntEnum instance'
class IPvAnyAddressError(PydanticValueError):
msg_template = 'value is not a valid IPv4 or IPv6 address'
class IPvAnyInterfaceError(PydanticValueError):
msg_template = 'value is not a valid IPv4 or IPv6 interface'
class IPvAnyNetworkError(PydanticValueError):
msg_template = 'value is not a valid IPv4 or IPv6 network'
class IPv4AddressError(PydanticValueError):
msg_template = 'value is not a valid IPv4 address'
class IPv6AddressError(PydanticValueError):
msg_template = 'value is not a valid IPv6 address'
class IPv4NetworkError(PydanticValueError):
msg_template = 'value is not a valid IPv4 network'
class IPv6NetworkError(PydanticValueError):
msg_template = 'value is not a valid IPv6 network'
class IPv4InterfaceError(PydanticValueError):
msg_template = 'value is not a valid IPv4 interface'
class IPv6InterfaceError(PydanticValueError):
msg_template = 'value is not a valid IPv6 interface'
class ColorError(PydanticValueError):
msg_template = 'value is not a valid color: {reason}'
class StrictBoolError(PydanticValueError):
msg_template = 'value is not a valid boolean'
class NotDigitError(PydanticValueError):
code = 'payment_card_number.digits'
msg_template = 'card number is not all digits'
class LuhnValidationError(PydanticValueError):
code = 'payment_card_number.luhn_check'
msg_template = 'card number is not luhn valid'
class InvalidLengthForBrand(PydanticValueError):
code = 'payment_card_number.invalid_length_for_brand'
msg_template = 'Length for a {brand} card must be {required_length}'
class InvalidByteSize(PydanticValueError):
msg_template = 'could not parse value and unit from byte string'
class InvalidByteSizeUnit(PydanticValueError):
msg_template = 'could not interpret byte unit: {unit}'
class MissingDiscriminator(PydanticValueError):
code = 'discriminated_union.missing_discriminator'
msg_template = 'Discriminator {discriminator_key!r} is missing in value'
class InvalidDiscriminator(PydanticValueError):
code = 'discriminated_union.invalid_discriminator'
msg_template = (
'No match for discriminator {discriminator_key!r} and value {discriminator_value!r} '
'(allowed values: {allowed_values})'
)
def __init__(self, *, discriminator_key: str, discriminator_value: Any, allowed_values: Sequence[Any]) -> None:
super().__init__(
discriminator_key=discriminator_key,
discriminator_value=discriminator_value,
allowed_values=', '.join(map(repr, allowed_values)),
)
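Taken together, these classes follow one pattern: `code` identifies the error type and `msg_template` is formatted with the keyword arguments passed to `__init__`. A minimal hedged sketch of how such an error renders, assuming the v1-style `PydanticValueError` base formats `msg_template` from the stored context:

```python
# Hedged sketch: assumes the v1-style PydanticValueError base, whose __str__
# formats `msg_template` with the keyword arguments stored by __init__.
err = ListMinLengthError(limit_value=3)  # class defined above
print(err.code)  # 'list.min_items'
print(str(err))  # 'ensure this value has at least 3 items'
```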

File diff suppressed because it is too large

View file

@@ -0,0 +1,395 @@
"""This module contains related classes and functions for serialization."""
from __future__ import annotations
import dataclasses
from functools import partialmethod
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, overload
from pydantic_core import PydanticUndefined, core_schema
from pydantic_core import core_schema as _core_schema
from typing_extensions import Annotated, Literal, TypeAlias
from . import PydanticUndefinedAnnotation
from ._internal import _decorators, _internal_dataclass
from .annotated_handlers import GetCoreSchemaHandler
@dataclasses.dataclass(**_internal_dataclass.slots_true, frozen=True)
class PlainSerializer:
"""Plain serializers use a function to modify the output of serialization.
This is particularly helpful when you want to customize the serialization for annotated types.
Consider an input of `list`, which will be serialized into a space-delimited string.
```python
from typing import List
from typing_extensions import Annotated
from pydantic import BaseModel, PlainSerializer
CustomStr = Annotated[
List, PlainSerializer(lambda x: ' '.join(x), return_type=str)
]
class StudentModel(BaseModel):
courses: CustomStr
student = StudentModel(courses=['Math', 'Chemistry', 'English'])
print(student.model_dump())
#> {'courses': 'Math Chemistry English'}
```
Attributes:
func: The serializer function.
return_type: The return type for the function. If omitted it will be inferred from the type annotation.
when_used: Determines when this serializer should be used. Accepts a string with values `'always'`,
`'unless-none'`, `'json'`, and `'json-unless-none'`. Defaults to 'always'.
"""
func: core_schema.SerializerFunction
return_type: Any = PydanticUndefined
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always'
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
"""Gets the Pydantic core schema.
Args:
source_type: The source type.
handler: The `GetCoreSchemaHandler` instance.
Returns:
The Pydantic core schema.
"""
schema = handler(source_type)
try:
return_type = _decorators.get_function_return_type(
self.func, self.return_type, handler._get_types_namespace()
)
except NameError as e:
raise PydanticUndefinedAnnotation.from_name_error(e) from e
return_schema = None if return_type is PydanticUndefined else handler.generate_schema(return_type)
schema['serialization'] = core_schema.plain_serializer_function_ser_schema(
function=self.func,
info_arg=_decorators.inspect_annotated_serializer(self.func, 'plain'),
return_schema=return_schema,
when_used=self.when_used,
)
return schema
@dataclasses.dataclass(**_internal_dataclass.slots_true, frozen=True)
class WrapSerializer:
"""Wrap serializers receive the raw inputs along with a handler function that applies the standard serialization
logic, and can modify the resulting value before returning it as the final output of serialization.
For example, here's a scenario in which a wrap serializer transforms timezones to UTC **and** utilizes the existing `datetime` serialization logic.
```python
from datetime import datetime, timezone
from typing import Any, Dict
from typing_extensions import Annotated
from pydantic import BaseModel, WrapSerializer
class EventDatetime(BaseModel):
start: datetime
end: datetime
def convert_to_utc(value: Any, handler, info) -> Dict[str, datetime]:
# Note that `handler` can actually help serialize the `value` for further custom serialization in case it's a subclass.
partial_result = handler(value, info)
if info.mode == 'json':
return {
k: datetime.fromisoformat(v).astimezone(timezone.utc)
for k, v in partial_result.items()
}
return {k: v.astimezone(timezone.utc) for k, v in partial_result.items()}
UTCEventDatetime = Annotated[EventDatetime, WrapSerializer(convert_to_utc)]
class EventModel(BaseModel):
event_datetime: UTCEventDatetime
dt = EventDatetime(
start='2024-01-01T07:00:00-08:00', end='2024-01-03T20:00:00+06:00'
)
event = EventModel(event_datetime=dt)
print(event.model_dump())
'''
{
'event_datetime': {
'start': datetime.datetime(
2024, 1, 1, 15, 0, tzinfo=datetime.timezone.utc
),
'end': datetime.datetime(
2024, 1, 3, 14, 0, tzinfo=datetime.timezone.utc
),
}
}
'''
print(event.model_dump_json())
'''
{"event_datetime":{"start":"2024-01-01T15:00:00Z","end":"2024-01-03T14:00:00Z"}}
'''
```
Attributes:
func: The serializer function to be wrapped.
return_type: The return type for the function. If omitted it will be inferred from the type annotation.
when_used: Determines when this serializer should be used. Accepts a string with values `'always'`,
`'unless-none'`, `'json'`, and `'json-unless-none'`. Defaults to 'always'.
"""
func: core_schema.WrapSerializerFunction
return_type: Any = PydanticUndefined
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always'
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
"""This method is used to get the Pydantic core schema of the class.
Args:
source_type: Source type.
handler: Core schema handler.
Returns:
The generated core schema of the class.
"""
schema = handler(source_type)
try:
return_type = _decorators.get_function_return_type(
self.func, self.return_type, handler._get_types_namespace()
)
except NameError as e:
raise PydanticUndefinedAnnotation.from_name_error(e) from e
return_schema = None if return_type is PydanticUndefined else handler.generate_schema(return_type)
schema['serialization'] = core_schema.wrap_serializer_function_ser_schema(
function=self.func,
info_arg=_decorators.inspect_annotated_serializer(self.func, 'wrap'),
return_schema=return_schema,
when_used=self.when_used,
)
return schema
if TYPE_CHECKING:
_PartialClsOrStaticMethod: TypeAlias = Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any]]
_PlainSerializationFunction = Union[_core_schema.SerializerFunction, _PartialClsOrStaticMethod]
_WrapSerializationFunction = Union[_core_schema.WrapSerializerFunction, _PartialClsOrStaticMethod]
_PlainSerializeMethodType = TypeVar('_PlainSerializeMethodType', bound=_PlainSerializationFunction)
_WrapSerializeMethodType = TypeVar('_WrapSerializeMethodType', bound=_WrapSerializationFunction)
@overload
def field_serializer(
__field: str,
*fields: str,
return_type: Any = ...,
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = ...,
check_fields: bool | None = ...,
) -> Callable[[_PlainSerializeMethodType], _PlainSerializeMethodType]:
...
@overload
def field_serializer(
__field: str,
*fields: str,
mode: Literal['plain'],
return_type: Any = ...,
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = ...,
check_fields: bool | None = ...,
) -> Callable[[_PlainSerializeMethodType], _PlainSerializeMethodType]:
...
@overload
def field_serializer(
__field: str,
*fields: str,
mode: Literal['wrap'],
return_type: Any = ...,
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = ...,
check_fields: bool | None = ...,
) -> Callable[[_WrapSerializeMethodType], _WrapSerializeMethodType]:
...
def field_serializer(
*fields: str,
mode: Literal['plain', 'wrap'] = 'plain',
return_type: Any = PydanticUndefined,
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always',
check_fields: bool | None = None,
) -> Callable[[Any], Any]:
"""Decorator that enables custom field serialization.
In the example below, a field of type `set` is used to avoid duplicate entries. A `field_serializer` is used to serialize the data as a sorted list.
```python
from typing import Set
from pydantic import BaseModel, field_serializer
class StudentModel(BaseModel):
name: str = 'Jane'
courses: Set[str]
@field_serializer('courses', when_used='json')
def serialize_courses_in_order(courses: Set[str]):
return sorted(courses)
student = StudentModel(courses={'Math', 'Chemistry', 'English'})
print(student.model_dump_json())
#> {"name":"Jane","courses":["Chemistry","English","Math"]}
```
See [Custom serializers](../concepts/serialization.md#custom-serializers) for more information.
Four signatures are supported:
- `(self, value: Any, info: FieldSerializationInfo)`
- `(self, value: Any, nxt: SerializerFunctionWrapHandler, info: FieldSerializationInfo)`
- `(value: Any, info: SerializationInfo)`
- `(value: Any, nxt: SerializerFunctionWrapHandler, info: SerializationInfo)`
Args:
fields: Which field(s) the method should be called on.
mode: The serialization mode.
- `plain` means the function will be called instead of the default serialization logic,
- `wrap` means the function will be called with an argument to optionally call the
default serialization logic.
return_type: Optional return type for the function, if omitted it will be inferred from the type annotation.
when_used: Determines when this serializer should be used for serialization.
check_fields: Whether to check that the fields actually exist on the model.
Returns:
The decorator function.
"""
def dec(
f: Callable[..., Any] | staticmethod[Any, Any] | classmethod[Any, Any, Any],
) -> _decorators.PydanticDescriptorProxy[Any]:
dec_info = _decorators.FieldSerializerDecoratorInfo(
fields=fields,
mode=mode,
return_type=return_type,
when_used=when_used,
check_fields=check_fields,
)
return _decorators.PydanticDescriptorProxy(f, dec_info)
return dec
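The wrap-mode signature listed above is easiest to see in a short sketch; the model and field names below are illustrative, not taken from the source:

```python
from typing import Set

from pydantic import BaseModel, SerializerFunctionWrapHandler, field_serializer


class CourseModel(BaseModel):
    courses: Set[str]

    # 'wrap' mode: `nxt` applies the default serialization, which is then post-processed.
    @field_serializer('courses', mode='wrap')
    def sort_courses(self, value, nxt: SerializerFunctionWrapHandler, info):
        return sorted(nxt(value))


print(CourseModel(courses={'b', 'a'}).model_dump())
#> {'courses': ['a', 'b']}
```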
FuncType = TypeVar('FuncType', bound=Callable[..., Any])
@overload
def model_serializer(__f: FuncType) -> FuncType:
...
@overload
def model_serializer(
*,
mode: Literal['plain', 'wrap'] = ...,
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always',
return_type: Any = ...,
) -> Callable[[FuncType], FuncType]:
...
def model_serializer(
__f: Callable[..., Any] | None = None,
*,
mode: Literal['plain', 'wrap'] = 'plain',
when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always',
return_type: Any = PydanticUndefined,
) -> Callable[[Any], Any]:
"""Decorator that enables custom model serialization.
This is useful when a model needs to be serialized in a customized manner, allowing for flexibility beyond specific fields.
An example would be to serialize temperature to the same temperature scale, such as degrees Celsius.
```python
from typing import Literal
from pydantic import BaseModel, model_serializer
class TemperatureModel(BaseModel):
unit: Literal['C', 'F']
value: int
@model_serializer()
def serialize_model(self):
if self.unit == 'F':
return {'unit': 'C', 'value': int((self.value - 32) / 1.8)}
return {'unit': self.unit, 'value': self.value}
temperature = TemperatureModel(unit='F', value=212)
print(temperature.model_dump())
#> {'unit': 'C', 'value': 100}
```
See [Custom serializers](../concepts/serialization.md#custom-serializers) for more information.
Args:
__f: The function to be decorated.
mode: The serialization mode.
- `'plain'` means the function will be called instead of the default serialization logic
- `'wrap'` means the function will be called with an argument to optionally call the default
serialization logic.
when_used: Determines when this serializer should be used.
return_type: The return type for the function. If omitted it will be inferred from the type annotation.
Returns:
The decorator function.
"""
def dec(f: Callable[..., Any]) -> _decorators.PydanticDescriptorProxy[Any]:
dec_info = _decorators.ModelSerializerDecoratorInfo(mode=mode, return_type=return_type, when_used=when_used)
return _decorators.PydanticDescriptorProxy(f, dec_info)
if __f is None:
return dec
else:
return dec(__f) # type: ignore
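A hedged sketch of the `'wrap'` mode described above, with an illustrative model; the handler argument applies the default serialization, which the method then extends:

```python
from pydantic import BaseModel, SerializerFunctionWrapHandler, model_serializer


class UserModel(BaseModel):
    username: str

    # 'wrap' mode: run the default serialization via `nxt`, then extend the result.
    @model_serializer(mode='wrap')
    def add_metadata(self, nxt: SerializerFunctionWrapHandler):
        result = nxt(self)
        result['username_length'] = len(self.username)
        return result


print(UserModel(username='alice').model_dump())
#> {'username': 'alice', 'username_length': 5}
```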
AnyType = TypeVar('AnyType')
if TYPE_CHECKING:
SerializeAsAny = Annotated[AnyType, ...] # SerializeAsAny[list[str]] will be treated by type checkers as list[str]
"""Force serialization to ignore whatever is defined in the schema and instead ask the object
itself how it should be serialized.
In particular, this means that when model subclasses are serialized, fields present in the subclass
but not in the original schema will be included.
"""
else:
@dataclasses.dataclass(**_internal_dataclass.slots_true)
class SerializeAsAny: # noqa: D101
def __class_getitem__(cls, item: Any) -> Any:
return Annotated[item, SerializeAsAny()]
def __get_pydantic_core_schema__(
self, source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
schema = handler(source_type)
schema_to_update = schema
while schema_to_update['type'] == 'definitions':
schema_to_update = schema_to_update.copy()
schema_to_update = schema_to_update['schema']
schema_to_update['serialization'] = core_schema.wrap_serializer_function_ser_schema(
lambda x, h: h(x), schema=core_schema.any_schema()
)
return schema
__hash__ = object.__hash__
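A short hedged sketch of the subclass behaviour described in the docstring above (class names are illustrative):

```python
from pydantic import BaseModel, SerializeAsAny


class User(BaseModel):
    name: str


class AdminUser(User):
    access_level: int


class Outer(BaseModel):
    as_schema: User               # serialized against the declared schema
    as_any: SerializeAsAny[User]  # serialized according to the runtime type


admin = AdminUser(name='root', access_level=3)
print(Outer(as_schema=admin, as_any=admin).model_dump())
#> {'as_schema': {'name': 'root'}, 'as_any': {'name': 'root', 'access_level': 3}}
```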

View file

@@ -0,0 +1,706 @@
"""This module contains related classes and functions for validation."""
from __future__ import annotations as _annotations
import dataclasses
import sys
from functools import partialmethod
from types import FunctionType
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, cast, overload
from pydantic_core import core_schema
from pydantic_core import core_schema as _core_schema
from typing_extensions import Annotated, Literal, TypeAlias
from . import GetCoreSchemaHandler as _GetCoreSchemaHandler
from ._internal import _core_metadata, _decorators, _generics, _internal_dataclass
from .annotated_handlers import GetCoreSchemaHandler
from .errors import PydanticUserError
if sys.version_info < (3, 11):
from typing_extensions import Protocol
else:
from typing import Protocol
_inspect_validator = _decorators.inspect_validator
@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
class AfterValidator:
"""Usage docs: https://docs.pydantic.dev/2.6/concepts/validators/#annotated-validators
A metadata class that indicates that a validation should be applied **after** the inner validation logic.
Attributes:
func: The validator function.
Example:
```py
from typing_extensions import Annotated
from pydantic import AfterValidator, BaseModel, ValidationError
MyInt = Annotated[int, AfterValidator(lambda v: v + 1)]
class Model(BaseModel):
a: MyInt
print(Model(a=1).a)
#> 2
try:
Model(a='a')
except ValidationError as e:
print(e.json(indent=2))
'''
[
{
"type": "int_parsing",
"loc": [
"a"
],
"msg": "Input should be a valid integer, unable to parse string as an integer",
"input": "a",
"url": "https://errors.pydantic.dev/2/v/int_parsing"
}
]
'''
```
"""
func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction
def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema:
schema = handler(source_type)
info_arg = _inspect_validator(self.func, 'after')
if info_arg:
func = cast(core_schema.WithInfoValidatorFunction, self.func)
return core_schema.with_info_after_validator_function(func, schema=schema, field_name=handler.field_name)
else:
func = cast(core_schema.NoInfoValidatorFunction, self.func)
return core_schema.no_info_after_validator_function(func, schema=schema)
@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
class BeforeValidator:
"""Usage docs: https://docs.pydantic.dev/2.6/concepts/validators/#annotated-validators
A metadata class that indicates that a validation should be applied **before** the inner validation logic.
Attributes:
func: The validator function.
Example:
```py
from typing_extensions import Annotated
from pydantic import BaseModel, BeforeValidator
MyInt = Annotated[int, BeforeValidator(lambda v: v + 1)]
class Model(BaseModel):
a: MyInt
print(Model(a=1).a)
#> 2
try:
Model(a='a')
except TypeError as e:
print(e)
#> can only concatenate str (not "int") to str
```
"""
func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction
def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema:
schema = handler(source_type)
info_arg = _inspect_validator(self.func, 'before')
if info_arg:
func = cast(core_schema.WithInfoValidatorFunction, self.func)
return core_schema.with_info_before_validator_function(func, schema=schema, field_name=handler.field_name)
else:
func = cast(core_schema.NoInfoValidatorFunction, self.func)
return core_schema.no_info_before_validator_function(func, schema=schema)
@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
class PlainValidator:
"""Usage docs: https://docs.pydantic.dev/2.6/concepts/validators/#annotated-validators
A metadata class that indicates that a validation should be applied **instead** of the inner validation logic.
Attributes:
func: The validator function.
Example:
```py
from typing_extensions import Annotated
from pydantic import BaseModel, PlainValidator
MyInt = Annotated[int, PlainValidator(lambda v: int(v) + 1)]
class Model(BaseModel):
a: MyInt
print(Model(a='1').a)
#> 2
```
"""
func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction
def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema:
# Note that for some valid uses of PlainValidator, it is not possible to generate a core schema for the
# source_type, so calling `handler(source_type)` will error, which prevents us from generating a proper
# serialization schema. To work around this for use cases that will not involve serialization, we simply
# catch any PydanticSchemaGenerationError that may be raised while attempting to build the serialization schema
# and abort any attempts to handle special serialization.
from pydantic import PydanticSchemaGenerationError
try:
schema = handler(source_type)
serialization = core_schema.wrap_serializer_function_ser_schema(function=lambda v, h: h(v), schema=schema)
except PydanticSchemaGenerationError:
serialization = None
info_arg = _inspect_validator(self.func, 'plain')
if info_arg:
func = cast(core_schema.WithInfoValidatorFunction, self.func)
return core_schema.with_info_plain_validator_function(
func, field_name=handler.field_name, serialization=serialization
)
else:
func = cast(core_schema.NoInfoValidatorFunction, self.func)
return core_schema.no_info_plain_validator_function(func, serialization=serialization)
@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true)
class WrapValidator:
"""Usage docs: https://docs.pydantic.dev/2.6/concepts/validators/#annotated-validators
A metadata class that indicates that a validation should be applied **around** the inner validation logic.
Attributes:
func: The validator function.
```py
from datetime import datetime
from typing_extensions import Annotated
from pydantic import BaseModel, ValidationError, WrapValidator
def validate_timestamp(v, handler):
if v == 'now':
# we don't want to bother with further validation, just return the new value
return datetime.now()
try:
return handler(v)
except ValidationError:
# validation failed, in this case we want to return a default value
return datetime(2000, 1, 1)
MyTimestamp = Annotated[datetime, WrapValidator(validate_timestamp)]
class Model(BaseModel):
a: MyTimestamp
print(Model(a='now').a)
#> 2032-01-02 03:04:05.000006
print(Model(a='invalid').a)
#> 2000-01-01 00:00:00
```
"""
func: core_schema.NoInfoWrapValidatorFunction | core_schema.WithInfoWrapValidatorFunction
def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema:
schema = handler(source_type)
info_arg = _inspect_validator(self.func, 'wrap')
if info_arg:
func = cast(core_schema.WithInfoWrapValidatorFunction, self.func)
return core_schema.with_info_wrap_validator_function(func, schema=schema, field_name=handler.field_name)
else:
func = cast(core_schema.NoInfoWrapValidatorFunction, self.func)
return core_schema.no_info_wrap_validator_function(func, schema=schema)
if TYPE_CHECKING:
class _OnlyValueValidatorClsMethod(Protocol):
def __call__(self, cls: Any, value: Any, /) -> Any:
...
class _V2ValidatorClsMethod(Protocol):
def __call__(self, cls: Any, value: Any, info: _core_schema.ValidationInfo, /) -> Any:
...
class _V2WrapValidatorClsMethod(Protocol):
def __call__(
self,
cls: Any,
value: Any,
handler: _core_schema.ValidatorFunctionWrapHandler,
info: _core_schema.ValidationInfo,
/,
) -> Any:
...
_V2Validator = Union[
_V2ValidatorClsMethod,
_core_schema.WithInfoValidatorFunction,
_OnlyValueValidatorClsMethod,
_core_schema.NoInfoValidatorFunction,
]
_V2WrapValidator = Union[
_V2WrapValidatorClsMethod,
_core_schema.WithInfoWrapValidatorFunction,
]
_PartialClsOrStaticMethod: TypeAlias = Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any]]
_V2BeforeAfterOrPlainValidatorType = TypeVar(
'_V2BeforeAfterOrPlainValidatorType',
_V2Validator,
_PartialClsOrStaticMethod,
)
_V2WrapValidatorType = TypeVar('_V2WrapValidatorType', _V2WrapValidator, _PartialClsOrStaticMethod)
@overload
def field_validator(
__field: str,
*fields: str,
mode: Literal['before', 'after', 'plain'] = ...,
check_fields: bool | None = ...,
) -> Callable[[_V2BeforeAfterOrPlainValidatorType], _V2BeforeAfterOrPlainValidatorType]:
...
@overload
def field_validator(
__field: str,
*fields: str,
mode: Literal['wrap'],
check_fields: bool | None = ...,
) -> Callable[[_V2WrapValidatorType], _V2WrapValidatorType]:
...
FieldValidatorModes: TypeAlias = Literal['before', 'after', 'wrap', 'plain']
def field_validator(
__field: str,
*fields: str,
mode: FieldValidatorModes = 'after',
check_fields: bool | None = None,
) -> Callable[[Any], Any]:
"""Usage docs: https://docs.pydantic.dev/2.6/concepts/validators/#field-validators
Decorate methods on the class indicating that they should be used to validate fields.
Example usage:
```py
from typing import Any
from pydantic import (
BaseModel,
ValidationError,
field_validator,
)
class Model(BaseModel):
a: str
@field_validator('a')
@classmethod
def ensure_foobar(cls, v: Any):
if 'foobar' not in v:
raise ValueError('"foobar" not found in a')
return v
print(repr(Model(a='this is foobar good')))
#> Model(a='this is foobar good')
try:
Model(a='snap')
except ValidationError as exc_info:
print(exc_info)
'''
1 validation error for Model
a
Value error, "foobar" not found in a [type=value_error, input_value='snap', input_type=str]
'''
```
For more in depth examples, see [Field Validators](../concepts/validators.md#field-validators).
Args:
__field: The first field the `field_validator` should be called on; this is separate
from `fields` to ensure an error is raised if you don't pass at least one.
*fields: Additional field(s) the `field_validator` should be called on.
mode: Specifies whether the validator runs before or after the field's standard validation.
check_fields: Whether to check that the fields actually exist on the model.
Returns:
A decorator that can be used to decorate a function to be used as a field_validator.
Raises:
PydanticUserError:
- If `@field_validator` is used bare (with no fields).
- If the args passed to `@field_validator` as fields are not strings.
- If `@field_validator` applied to instance methods.
"""
if isinstance(__field, FunctionType):
raise PydanticUserError(
'`@field_validator` should be used with fields and keyword arguments, not bare. '
"E.g. usage should be `@validator('<field_name>', ...)`",
code='validator-no-fields',
)
fields = __field, *fields
if not all(isinstance(field, str) for field in fields):
raise PydanticUserError(
'`@field_validator` fields should be passed as separate string args. '
"E.g. usage should be `@validator('<field_name_1>', '<field_name_2>', ...)`",
code='validator-invalid-fields',
)
def dec(
f: Callable[..., Any] | staticmethod[Any, Any] | classmethod[Any, Any, Any],
) -> _decorators.PydanticDescriptorProxy[Any]:
if _decorators.is_instance_method_from_sig(f):
raise PydanticUserError(
'`@field_validator` cannot be applied to instance methods', code='validator-instance-method'
)
# auto apply the @classmethod decorator
f = _decorators.ensure_classmethod_based_on_signature(f)
dec_info = _decorators.FieldValidatorDecoratorInfo(fields=fields, mode=mode, check_fields=check_fields)
return _decorators.PydanticDescriptorProxy(f, dec_info)
return dec
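A hedged sketch of a `'before'` field validator applied to several fields at once; the model and the pixel-suffix convention are illustrative only:

```python
from pydantic import BaseModel, field_validator


class Window(BaseModel):
    width: int
    height: int

    # 'before' mode runs on the raw input, prior to the int coercion.
    @field_validator('width', 'height', mode='before')
    @classmethod
    def strip_px_suffix(cls, v):
        if isinstance(v, str) and v.endswith('px'):
            return v[:-2]
        return v


print(Window(width='800px', height=600))
#> width=800 height=600
```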
_ModelType = TypeVar('_ModelType')
_ModelTypeCo = TypeVar('_ModelTypeCo', covariant=True)
class ModelWrapValidatorHandler(_core_schema.ValidatorFunctionWrapHandler, Protocol[_ModelTypeCo]):
"""@model_validator decorated function handler argument type. This is used when `mode='wrap'`."""
def __call__( # noqa: D102
self,
value: Any,
outer_location: str | int | None = None,
/,
) -> _ModelTypeCo: # pragma: no cover
...
class ModelWrapValidatorWithoutInfo(Protocol[_ModelType]):
"""A @model_validator decorated function signature.
This is used when `mode='wrap'` and the function does not have info argument.
"""
def __call__( # noqa: D102
self,
cls: type[_ModelType],
# this can be a dict, a model instance
# or anything else that gets passed to validate_python
# thus validators _must_ handle all cases
value: Any,
handler: ModelWrapValidatorHandler[_ModelType],
/,
) -> _ModelType:
...
class ModelWrapValidator(Protocol[_ModelType]):
"""A @model_validator decorated function signature. This is used when `mode='wrap'`."""
def __call__( # noqa: D102
self,
cls: type[_ModelType],
# this can be a dict, a model instance
# or anything else that gets passed to validate_python
# thus validators _must_ handle all cases
value: Any,
handler: ModelWrapValidatorHandler[_ModelType],
info: _core_schema.ValidationInfo,
/,
) -> _ModelType:
...
class FreeModelBeforeValidatorWithoutInfo(Protocol):
"""A @model_validator decorated function signature.
This is used when `mode='before'` and the function does not have info argument.
"""
def __call__( # noqa: D102
self,
# this can be a dict, a model instance
# or anything else that gets passed to validate_python
# thus validators _must_ handle all cases
value: Any,
/,
) -> Any:
...
class ModelBeforeValidatorWithoutInfo(Protocol):
"""A @model_validator decorated function signature.
This is used when `mode='before'` and the function does not have info argument.
"""
def __call__( # noqa: D102
self,
cls: Any,
# this can be a dict, a model instance
# or anything else that gets passed to validate_python
# thus validators _must_ handle all cases
value: Any,
/,
) -> Any:
...
class FreeModelBeforeValidator(Protocol):
"""A `@model_validator` decorated function signature. This is used when `mode='before'`."""
def __call__( # noqa: D102
self,
# this can be a dict, a model instance
# or anything else that gets passed to validate_python
# thus validators _must_ handle all cases
value: Any,
info: _core_schema.ValidationInfo,
/,
) -> Any:
...
class ModelBeforeValidator(Protocol):
"""A `@model_validator` decorated function signature. This is used when `mode='before'`."""
def __call__( # noqa: D102
self,
cls: Any,
# this can be a dict, a model instance
# or anything else that gets passed to validate_python
# thus validators _must_ handle all cases
value: Any,
info: _core_schema.ValidationInfo,
/,
) -> Any:
...
ModelAfterValidatorWithoutInfo = Callable[[_ModelType], _ModelType]
"""A `@model_validator` decorated function signature. This is used when `mode='after'` and the function does not
have info argument.
"""
ModelAfterValidator = Callable[[_ModelType, _core_schema.ValidationInfo], _ModelType]
"""A `@model_validator` decorated function signature. This is used when `mode='after'`."""
_AnyModelWrapValidator = Union[ModelWrapValidator[_ModelType], ModelWrapValidatorWithoutInfo[_ModelType]]
_AnyModeBeforeValidator = Union[
FreeModelBeforeValidator, ModelBeforeValidator, FreeModelBeforeValidatorWithoutInfo, ModelBeforeValidatorWithoutInfo
]
_AnyModelAfterValidator = Union[ModelAfterValidator[_ModelType], ModelAfterValidatorWithoutInfo[_ModelType]]
@overload
def model_validator(
*,
mode: Literal['wrap'],
) -> Callable[
[_AnyModelWrapValidator[_ModelType]], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]
]:
...
@overload
def model_validator(
*,
mode: Literal['before'],
) -> Callable[[_AnyModeBeforeValidator], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]]:
...
@overload
def model_validator(
*,
mode: Literal['after'],
) -> Callable[
[_AnyModelAfterValidator[_ModelType]], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]
]:
...
def model_validator(
*,
mode: Literal['wrap', 'before', 'after'],
) -> Any:
"""Usage docs: https://docs.pydantic.dev/2.6/concepts/validators/#model-validators
Decorate model methods for validation purposes.
Example usage:
```py
from typing_extensions import Self
from pydantic import BaseModel, ValidationError, model_validator
class Square(BaseModel):
width: float
height: float
@model_validator(mode='after')
def verify_square(self) -> Self:
if self.width != self.height:
raise ValueError('width and height do not match')
return self
s = Square(width=1, height=1)
print(repr(s))
#> Square(width=1.0, height=1.0)
try:
Square(width=1, height=2)
except ValidationError as e:
print(e)
'''
1 validation error for Square
Value error, width and height do not match [type=value_error, input_value={'width': 1, 'height': 2}, input_type=dict]
'''
```
For more in depth examples, see [Model Validators](../concepts/validators.md#model-validators).
Args:
mode: A required string literal that specifies the validation mode.
It can be one of the following: 'wrap', 'before', or 'after'.
Returns:
A decorator that can be used to decorate a function to be used as a model validator.
"""
def dec(f: Any) -> _decorators.PydanticDescriptorProxy[Any]:
# auto apply the @classmethod decorator
f = _decorators.ensure_classmethod_based_on_signature(f)
dec_info = _decorators.ModelValidatorDecoratorInfo(mode=mode)
return _decorators.PydanticDescriptorProxy(f, dec_info)
return dec
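A hedged sketch of the `'before'` mode, which receives the raw input before any field validation; the CSV-style input format is illustrative only:

```python
from typing import Any

from pydantic import BaseModel, model_validator


class Point(BaseModel):
    x: float
    y: float

    # 'before' mode receives the raw input and may return anything the
    # normal field validation can then handle.
    @model_validator(mode='before')
    @classmethod
    def parse_csv(cls, data: Any) -> Any:
        if isinstance(data, str):
            x, y = data.split(',')
            return {'x': x, 'y': y}
        return data


print(Point.model_validate('1.5,2'))
#> x=1.5 y=2.0
```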
AnyType = TypeVar('AnyType')
if TYPE_CHECKING:
# If we add configurable attributes to IsInstance, we'd probably need to stop hiding it from type checkers like this
InstanceOf = Annotated[AnyType, ...] # `IsInstance[Sequence]` will be recognized by type checkers as `Sequence`
else:
@dataclasses.dataclass(**_internal_dataclass.slots_true)
class InstanceOf:
'''Generic type for annotating a type that is an instance of a given class.
Example:
```py
from pydantic import BaseModel, InstanceOf, ValidationError
class Foo:
...
class Bar(BaseModel):
foo: InstanceOf[Foo]
Bar(foo=Foo())
try:
Bar(foo=42)
except ValidationError as e:
print(e)
"""
[
{
'type': 'is_instance_of',
'loc': ('foo',),
'msg': 'Input should be an instance of Foo',
'input': 42,
'ctx': {'class': 'Foo'},
'url': 'https://errors.pydantic.dev/0.38.0/v/is_instance_of'
}
]
"""
```
'''
@classmethod
def __class_getitem__(cls, item: AnyType) -> AnyType:
return Annotated[item, cls()]
@classmethod
def __get_pydantic_core_schema__(cls, source: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
from pydantic import PydanticSchemaGenerationError
# use the generic _origin_ as the second argument to isinstance when appropriate
instance_of_schema = core_schema.is_instance_schema(_generics.get_origin(source) or source)
try:
# Try to generate the "standard" schema, which will be used when loading from JSON
original_schema = handler(source)
except PydanticSchemaGenerationError:
# If that fails, just produce a schema that can validate from python
return instance_of_schema
else:
# Use the "original" approach to serialization
instance_of_schema['serialization'] = core_schema.wrap_serializer_function_ser_schema(
function=lambda v, h: h(v), schema=original_schema
)
return core_schema.json_or_python_schema(python_schema=instance_of_schema, json_schema=original_schema)
__hash__ = object.__hash__
if TYPE_CHECKING:
SkipValidation = Annotated[AnyType, ...] # SkipValidation[list[str]] will be treated by type checkers as list[str]
else:
@dataclasses.dataclass(**_internal_dataclass.slots_true)
class SkipValidation:
"""If this is applied as an annotation (e.g., via `x: Annotated[int, SkipValidation]`), validation will be
skipped. You can also use `SkipValidation[int]` as a shorthand for `Annotated[int, SkipValidation]`.
This can be useful if you want to use a type annotation for documentation/IDE/type-checking purposes,
and know that it is safe to skip validation for one or more of the fields.
Because this converts the validation schema to `any_schema`, subsequent annotation-applied transformations
may not have the expected effects. Therefore, when used, this annotation should generally be the final
annotation applied to a type.
"""
def __class_getitem__(cls, item: Any) -> Any:
return Annotated[item, SkipValidation()]
@classmethod
def __get_pydantic_core_schema__(cls, source: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
original_schema = handler(source)
metadata = _core_metadata.build_metadata_dict(js_annotation_functions=[lambda _c, h: h(original_schema)])
return core_schema.any_schema(
metadata=metadata,
serialization=core_schema.wrap_serializer_function_ser_schema(
function=lambda v, h: h(v), schema=original_schema
),
)
__hash__ = object.__hash__
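A brief hedged sketch of the skipping behaviour described above (the field type and values are illustrative):

```python
from typing import List

from pydantic import BaseModel, SkipValidation


class Model(BaseModel):
    # Validation is skipped entirely: the raw value is stored as-is.
    names: SkipValidation[List[int]]


m = Model(names=['not', 'ints'])
print(m.names)
#> ['not', 'ints']
```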

View file

@@ -1,364 +1,4 @@
import sys """The `generics` module is a backport module from V1."""
import typing from ._migration import getattr_migration
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Generic,
Iterator,
List,
Mapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from typing_extensions import Annotated
from .class_validators import gather_all_validators
from .fields import DeferredType
from .main import BaseModel, create_model
from .types import JsonWrapper
from .typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base
from .utils import LimitedDict, all_identical, lenient_issubclass
GenericModelT = TypeVar('GenericModelT', bound='GenericModel')
TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type
Parametrization = Mapping[TypeVarType, Type[Any]]
_generic_types_cache: LimitedDict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel]] = LimitedDict()
# _assigned_parameters is a Mapping from parametrized version of generic models to assigned types of parametrizations
# as captured during construction of the class (not instances).
# E.g., for generic model `Model[A, B]`, when parametrized model `Model[int, str]` is created,
# `Model[int, str]`: {A: int, B: str}` will be stored in `_assigned_parameters`.
# (This information is only otherwise available after creation from the class name string).
_assigned_parameters: LimitedDict[Type[Any], Parametrization] = LimitedDict()
class GenericModel(BaseModel):
__slots__ = ()
__concrete__: ClassVar[bool] = False
if TYPE_CHECKING:
# Putting this in a TYPE_CHECKING block allows us to replace `if Generic not in cls.__bases__` with
# `not hasattr(cls, "__parameters__")`. This means we don't need to force non-concrete subclasses of
# `GenericModel` to also inherit from `Generic`, which would require changes to the use of `create_model` below.
__parameters__: ClassVar[Tuple[TypeVarType, ...]]
# Setting the return type as Type[Any] instead of Type[BaseModel] prevents PyCharm warnings
def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[Type[Any], ...]]) -> Type[Any]:
"""Instantiates a new class from a generic class `cls` and type variables `params`.
:param params: Tuple of types the class is parameterized with. Given a generic class
`Model` with 2 type variables and a concrete model `Model[str, int]`,
the value `(str, int)` would be passed to `params`.
:return: New model class inheriting from `cls` with instantiated
types described by `params`. If no parameters are given, `cls` is
returned as is.
"""
def _cache_key(_params: Any) -> Tuple[Type[GenericModelT], Any, Tuple[Any, ...]]:
return cls, _params, get_args(_params)
cached = _generic_types_cache.get(_cache_key(params))
if cached is not None:
return cached
if cls.__concrete__ and Generic not in cls.__bases__:
raise TypeError('Cannot parameterize a concrete instantiation of a generic model')
if not isinstance(params, tuple):
params = (params,)
if cls is GenericModel and any(isinstance(param, TypeVar) for param in params):
raise TypeError('Type parameters should be placed on typing.Generic, not GenericModel')
if not hasattr(cls, '__parameters__'):
raise TypeError(f'Type {cls.__name__} must inherit from typing.Generic before being parameterized')
check_parameters_count(cls, params)
# Build map from generic typevars to passed params
typevars_map: Dict[TypeVarType, Type[Any]] = dict(zip(cls.__parameters__, params))
if all_identical(typevars_map.keys(), typevars_map.values()) and typevars_map:
return cls # if arguments are equal to parameters it's the same object
# Create new model with original model as parent inserting fields with DeferredType.
model_name = cls.__concrete_name__(params)
validators = gather_all_validators(cls)
type_hints = get_all_type_hints(cls).items()
instance_type_hints = {k: v for k, v in type_hints if get_origin(v) is not ClassVar}
fields = {k: (DeferredType(), cls.__fields__[k].field_info) for k in instance_type_hints if k in cls.__fields__}
model_module, called_globally = get_caller_frame_info()
created_model = cast(
Type[GenericModel], # casting ensures mypy is aware of the __concrete__ and __parameters__ attributes
create_model(
model_name,
__module__=model_module or cls.__module__,
__base__=(cls,) + tuple(cls.__parameterized_bases__(typevars_map)),
__config__=None,
__validators__=validators,
__cls_kwargs__=None,
**fields,
),
)
_assigned_parameters[created_model] = typevars_map
if called_globally: # create global reference and therefore allow pickling
object_by_reference = None
reference_name = model_name
reference_module_globals = sys.modules[created_model.__module__].__dict__
while object_by_reference is not created_model:
object_by_reference = reference_module_globals.setdefault(reference_name, created_model)
reference_name += '_'
created_model.Config = cls.Config
# Find any typevars that are still present in the model.
# If none are left, the model is fully "concrete", otherwise the new
# class is a generic class as well taking the found typevars as
# parameters.
new_params = tuple(
{param: None for param in iter_contained_typevars(typevars_map.values())}
) # use dict as ordered set
created_model.__concrete__ = not new_params
if new_params:
created_model.__parameters__ = new_params
# Save created model in cache so we don't end up creating duplicate
# models that should be identical.
_generic_types_cache[_cache_key(params)] = created_model
if len(params) == 1:
_generic_types_cache[_cache_key(params[0])] = created_model
# Recursively walk class type hints and replace generic typevars
# with concrete types that were passed.
_prepare_model_fields(created_model, fields, instance_type_hints, typevars_map)
return created_model
@classmethod
def __concrete_name__(cls: Type[Any], params: Tuple[Type[Any], ...]) -> str:
"""Compute class name for child classes.
:param params: Tuple of types the class is parameterized with. Given a generic class
`Model` with 2 type variables and a concrete model `Model[str, int]`,
the value `(str, int)` would be passed to `params`.
:return: String representing the new class where `params` are
passed to `cls` as type variables.
This method can be overridden to achieve a custom naming scheme for GenericModels.
"""
param_names = [display_as_type(param) for param in params]
params_component = ', '.join(param_names)
return f'{cls.__name__}[{params_component}]'
@classmethod
def __parameterized_bases__(cls, typevars_map: Parametrization) -> Iterator[Type[Any]]:
"""
Returns unbound bases of cls parameterised to given type variables
:param typevars_map: Dictionary of type applications for binding subclasses.
Given a generic class `Model` with 2 type variables [S, T]
and a concrete model `Model[str, int]`,
the value `{S: str, T: int}` would be passed to `typevars_map`.
:return: an iterator of generic sub classes, parameterised by `typevars_map`
and other assigned parameters of `cls`
e.g.:
```
class A(GenericModel, Generic[T]):
...
class B(A[V], Generic[V]):
...
assert A[int] in B.__parameterized_bases__({V: int})
```
"""
def build_base_model(
base_model: Type[GenericModel], mapped_types: Parametrization
) -> Iterator[Type[GenericModel]]:
base_parameters = tuple(mapped_types[param] for param in base_model.__parameters__)
parameterized_base = base_model.__class_getitem__(base_parameters)
if parameterized_base is base_model or parameterized_base is cls:
# Avoid duplication in MRO
return
yield parameterized_base
for base_model in cls.__bases__:
if not issubclass(base_model, GenericModel):
# not a class that can be meaningfully parameterized
continue
elif not getattr(base_model, '__parameters__', None):
# base_model is "GenericModel" (and has no __parameters__)
# or
# base_model is already concrete, and will be included transitively via cls.
continue
elif cls in _assigned_parameters:
if base_model in _assigned_parameters:
# cls is partially parameterised but not from base_model
# e.g. cls = B[S], base_model = A[S]
# B[S][int] should subclass A[int], (and will be transitively via B[int])
# but it's not viable to consistently subclass types with arbitrary construction
# So don't attempt to include A[S][int]
continue
else: # base_model not in _assigned_parameters:
# cls is partially parameterized, base_model is original generic
# e.g. cls = B[str, T], base_model = B[S, T]
# Need to determine the mapping for the base_model parameters
mapped_types: Parametrization = {
key: typevars_map.get(value, value) for key, value in _assigned_parameters[cls].items()
}
yield from build_base_model(base_model, mapped_types)
else:
# cls is base generic, so base_class has a distinct base
# can construct the Parameterised base model using typevars_map directly
yield from build_base_model(base_model, typevars_map)
def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any:
"""Return type with all occurrences of `type_map` keys recursively replaced with their values.
:param type_: Any type, class or generic alias
:param type_map: Mapping from `TypeVar` instance to concrete types.
:return: New type representing the basic structure of `type_` with all
`type_map` keys recursively replaced.
>>> replace_types(Tuple[str, Union[List[str], float]], {str: int})
Tuple[int, Union[List[int], float]]
"""
if not type_map:
return type_
type_args = get_args(type_)
origin_type = get_origin(type_)
if origin_type is Annotated:
annotated_type, *annotations = type_args
return Annotated[replace_types(annotated_type, type_map), tuple(annotations)]
# Having type args is a good indicator that this is a typing module
# class instantiation or a generic alias of some sort.
if type_args:
resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
if all_identical(type_args, resolved_type_args):
# If all arguments are the same, there is no need to modify the
# type or create a new object at all
return type_
if (
origin_type is not None
and isinstance(type_, typing_base)
and not isinstance(origin_type, typing_base)
and getattr(type_, '_name', None) is not None
):
# In python < 3.9 generic aliases don't exist so any of these like `list`,
# `type` or `collections.abc.Callable` need to be translated.
# See: https://www.python.org/dev/peps/pep-0585
origin_type = getattr(typing, type_._name)
assert origin_type is not None
return origin_type[resolved_type_args]
# We handle pydantic generic models separately as they don't have the same
# semantics as "typing" classes or generic aliases
if not origin_type and lenient_issubclass(type_, GenericModel) and not type_.__concrete__:
type_args = type_.__parameters__
resolved_type_args = tuple(replace_types(t, type_map) for t in type_args)
if all_identical(type_args, resolved_type_args):
return type_
return type_[resolved_type_args]
# Handle special case for typehints that can have lists as arguments.
# `typing.Callable[[int, str], int]` is an example for this.
if isinstance(type_, (List, list)):
resolved_list = list(replace_types(element, type_map) for element in type_)
if all_identical(type_, resolved_list):
return type_
return resolved_list
# For JsonWrapperValue, need to handle its inner type to allow correct parsing
# of generic Json arguments like Json[T]
if not origin_type and lenient_issubclass(type_, JsonWrapper):
type_.inner_type = replace_types(type_.inner_type, type_map)
return type_
# If all else fails, we try to resolve the type directly and otherwise just
# return the input with no modifications.
return type_map.get(type_, type_)
def check_parameters_count(cls: Type[GenericModel], parameters: Tuple[Any, ...]) -> None:
actual = len(parameters)
expected = len(cls.__parameters__)
if actual != expected:
description = 'many' if actual > expected else 'few'
raise TypeError(f'Too {description} parameters for {cls.__name__}; actual {actual}, expected {expected}')
DictValues: Type[Any] = {}.values().__class__
def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]:
"""Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found."""
if isinstance(v, TypeVar):
yield v
elif hasattr(v, '__parameters__') and not get_origin(v) and lenient_issubclass(v, GenericModel):
yield from v.__parameters__
elif isinstance(v, (DictValues, list)):
for var in v:
yield from iter_contained_typevars(var)
else:
args = get_args(v)
for arg in args:
yield from iter_contained_typevars(arg)
def get_caller_frame_info() -> Tuple[Optional[str], bool]:
"""
Used inside a function to check whether it was called globally
Will only work against non-compiled code, therefore used only in pydantic.generics
:returns Tuple[module_name, called_globally]
"""
try:
previous_caller_frame = sys._getframe(2)
except ValueError as e:
raise RuntimeError('This function must be used inside another function') from e
except AttributeError: # sys module does not have _getframe function, so there's nothing we can do about it
return None, False
frame_globals = previous_caller_frame.f_globals
return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals
def _prepare_model_fields(
created_model: Type[GenericModel],
fields: Mapping[str, Any],
instance_type_hints: Mapping[str, type],
typevars_map: Mapping[Any, type],
) -> None:
"""
Replace DeferredType fields with concrete type hints and prepare them.
"""
for key, field in created_model.__fields__.items():
if key not in fields:
assert field.type_.__class__ is not DeferredType
# https://github.com/nedbat/coveragepy/issues/198
continue # pragma: no cover
assert field.type_.__class__ is DeferredType, field.type_.__class__
field_type_hint = instance_type_hints[key]
concrete_type = replace_types(field_type_hint, typevars_map)
field.type_ = concrete_type
field.outer_type_ = concrete_type
field.prepare()
created_model.__annotations__[key] = concrete_type

View file

@@ -1,112 +1,4 @@
import datetime """The `json` module is a backport module from V1."""
from collections import deque from ._migration import getattr_migration
from decimal import Decimal
from enum import Enum
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from pathlib import Path
from re import Pattern
from types import GeneratorType
from typing import Any, Callable, Dict, Type, Union
from uuid import UUID
from .color import Color
from .networks import NameEmail
from .types import SecretBytes, SecretStr
__all__ = 'pydantic_encoder', 'custom_pydantic_encoder', 'timedelta_isoformat'
def isoformat(o: Union[datetime.date, datetime.time]) -> str:
return o.isoformat()
def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
"""
Encodes a Decimal as an int if there's no exponent, otherwise as a float.
This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
where an integer (but not int typed) is used. Encoding this as a float
results in failed round-tripping between encode and parse.
Our Id type is a prime example of this.
>>> decimal_encoder(Decimal("1.0"))
1.0
>>> decimal_encoder(Decimal("1"))
1
"""
if dec_value.as_tuple().exponent >= 0:
return int(dec_value)
else:
return float(dec_value)
ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
bytes: lambda o: o.decode(),
Color: str,
datetime.date: isoformat,
datetime.datetime: isoformat,
datetime.time: isoformat,
datetime.timedelta: lambda td: td.total_seconds(),
Decimal: decimal_encoder,
Enum: lambda o: o.value,
frozenset: list,
deque: list,
GeneratorType: list,
IPv4Address: str,
IPv4Interface: str,
IPv4Network: str,
IPv6Address: str,
IPv6Interface: str,
IPv6Network: str,
NameEmail: str,
Path: str,
Pattern: lambda o: o.pattern,
SecretBytes: str,
SecretStr: str,
set: list,
UUID: str,
}
def pydantic_encoder(obj: Any) -> Any:
from dataclasses import asdict, is_dataclass
from .main import BaseModel
if isinstance(obj, BaseModel):
return obj.dict()
elif is_dataclass(obj):
return asdict(obj)
# Check the class type and its superclasses for a matching encoder
for base in obj.__class__.__mro__[:-1]:
try:
encoder = ENCODERS_BY_TYPE[base]
except KeyError:
continue
return encoder(obj)
else: # We have exited the for loop without finding a suitable encoder
raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable")
def custom_pydantic_encoder(type_encoders: Dict[Any, Callable[[Type[Any]], Any]], obj: Any) -> Any:
# Check the class type and its superclasses for a matching encoder
for base in obj.__class__.__mro__[:-1]:
try:
encoder = type_encoders[base]
except KeyError:
continue
return encoder(obj)
else: # We have exited the for loop without finding a suitable encoder
return pydantic_encoder(obj)
def timedelta_isoformat(td: datetime.timedelta) -> str:
"""
ISO 8601 encoding for Python timedelta object.
"""
minutes, seconds = divmod(td.seconds, 60)
hours, minutes = divmod(minutes, 60)
return f'{"-" if td.days < 0 else ""}P{abs(td.days)}DT{hours:d}H{minutes:d}M{seconds:d}.{td.microseconds:06d}S'

lib/pydantic/json_schema.py (new file, 2425 lines)

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View file

@@ -1,66 +1,4 @@
"""The `parse` module is a backport module from V1."""
from ._migration import getattr_migration
__getattr__ = getattr_migration(__name__)
import json
import pickle
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Union
from .types import StrBytes
class Protocol(str, Enum):
json = 'json'
pickle = 'pickle'
def load_str_bytes(
b: StrBytes,
*,
content_type: str = None,
encoding: str = 'utf8',
proto: Protocol = None,
allow_pickle: bool = False,
json_loads: Callable[[str], Any] = json.loads,
) -> Any:
if proto is None and content_type:
if content_type.endswith(('json', 'javascript')):
pass
elif allow_pickle and content_type.endswith('pickle'):
proto = Protocol.pickle
else:
raise TypeError(f'Unknown content-type: {content_type}')
proto = proto or Protocol.json
if proto == Protocol.json:
if isinstance(b, bytes):
b = b.decode(encoding)
return json_loads(b)
elif proto == Protocol.pickle:
if not allow_pickle:
raise RuntimeError('Trying to decode with pickle with allow_pickle=False')
bb = b if isinstance(b, bytes) else b.encode()
return pickle.loads(bb)
else:
raise TypeError(f'Unknown protocol: {proto}')
def load_file(
path: Union[str, Path],
*,
content_type: str = None,
encoding: str = 'utf8',
proto: Protocol = None,
allow_pickle: bool = False,
json_loads: Callable[[str], Any] = json.loads,
) -> Any:
path = Path(path)
b = path.read_bytes()
if content_type is None:
if path.suffix in ('.js', '.json'):
proto = Protocol.json
elif path.suffix == '.pkl':
proto = Protocol.pickle
return load_str_bytes(
b, proto=proto, content_type=content_type, encoding=encoding, allow_pickle=allow_pickle, json_loads=json_loads
)
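A hedged sketch (not part of this commit) of how the removed V1 loaders above dispatched on protocol; the `pydantic.parse` import path is the old V1 location:
from pydantic.parse import load_str_bytes  # V1 import path
obj = load_str_bytes('{"a": 1}', content_type='application/json')
assert obj == {'a': 1}
# load_file('settings.json') read the bytes and inferred Protocol.json from the '.json'
# suffix in the same way; 'settings.json' is a hypothetical path.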

View file

@@ -0,0 +1,170 @@
"""Usage docs: https://docs.pydantic.dev/2.6/concepts/plugins#build-a-plugin
Plugin interface for Pydantic plugins, and related types.
"""
from __future__ import annotations
from typing import Any, Callable, NamedTuple
from pydantic_core import CoreConfig, CoreSchema, ValidationError
from typing_extensions import Literal, Protocol, TypeAlias
__all__ = (
'PydanticPluginProtocol',
'BaseValidateHandlerProtocol',
'ValidatePythonHandlerProtocol',
'ValidateJsonHandlerProtocol',
'ValidateStringsHandlerProtocol',
'NewSchemaReturns',
'SchemaTypePath',
'SchemaKind',
)
NewSchemaReturns: TypeAlias = 'tuple[ValidatePythonHandlerProtocol | None, ValidateJsonHandlerProtocol | None, ValidateStringsHandlerProtocol | None]'
class SchemaTypePath(NamedTuple):
"""Path defining where `schema_type` was defined, or where `TypeAdapter` was called."""
module: str
name: str
SchemaKind: TypeAlias = Literal['BaseModel', 'TypeAdapter', 'dataclass', 'create_model', 'validate_call']
class PydanticPluginProtocol(Protocol):
"""Protocol defining the interface for Pydantic plugins."""
def new_schema_validator(
self,
schema: CoreSchema,
schema_type: Any,
schema_type_path: SchemaTypePath,
schema_kind: SchemaKind,
config: CoreConfig | None,
plugin_settings: dict[str, object],
) -> tuple[
ValidatePythonHandlerProtocol | None, ValidateJsonHandlerProtocol | None, ValidateStringsHandlerProtocol | None
]:
"""This method is called for each plugin every time a new [`SchemaValidator`][pydantic_core.SchemaValidator]
is created.
It should return an event handler for each of the three validation methods, or `None` if the plugin does not
implement that method.
Args:
schema: The schema to validate against.
schema_type: The original type which the schema was created from, e.g. the model class.
schema_type_path: Path defining where `schema_type` was defined, or where `TypeAdapter` was called.
schema_kind: The kind of schema to validate against.
config: The config to use for validation.
plugin_settings: Any plugin settings.
Returns:
A tuple of optional event handlers for each of the three validation methods -
`validate_python`, `validate_json`, `validate_strings`.
"""
raise NotImplementedError('Pydantic plugins should implement `new_schema_validator`.')
class BaseValidateHandlerProtocol(Protocol):
"""Base class for plugin callbacks protocols.
You shouldn't implement this protocol directly; instead use one of the subclasses, which add the correctly
typed `on_error` method.
"""
on_enter: Callable[..., None]
"""`on_enter` is changed to be more specific on all subclasses"""
def on_success(self, result: Any) -> None:
"""Callback to be notified of successful validation.
Args:
result: The result of the validation.
"""
return
def on_error(self, error: ValidationError) -> None:
"""Callback to be notified of validation errors.
Args:
error: The validation error.
"""
return
def on_exception(self, exception: Exception) -> None:
"""Callback to be notified of validation exceptions.
Args:
exception: The exception raised during validation.
"""
return
class ValidatePythonHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
"""Event handler for `SchemaValidator.validate_python`."""
def on_enter(
self,
input: Any,
*,
strict: bool | None = None,
from_attributes: bool | None = None,
context: dict[str, Any] | None = None,
self_instance: Any | None = None,
) -> None:
"""Callback to be notified of validation start, and create an instance of the event handler.
Args:
input: The input to be validated.
strict: Whether to validate the object in strict mode.
from_attributes: Whether to validate objects as inputs by extracting attributes.
context: The context to use for validation, this is passed to functional validators.
self_instance: An instance of a model to set attributes on from validation, this is used when running
validation from the `__init__` method of a model.
"""
pass
class ValidateJsonHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
"""Event handler for `SchemaValidator.validate_json`."""
def on_enter(
self,
input: str | bytes | bytearray,
*,
strict: bool | None = None,
context: dict[str, Any] | None = None,
self_instance: Any | None = None,
) -> None:
"""Callback to be notified of validation start, and create an instance of the event handler.
Args:
input: The JSON data to be validated.
strict: Whether to validate the object in strict mode.
context: The context to use for validation, this is passed to functional validators.
self_instance: An instance of a model to set attributes on from validation, this is used when running
validation from the `__init__` method of a model.
"""
pass
StringInput: TypeAlias = 'dict[str, StringInput]'
class ValidateStringsHandlerProtocol(BaseValidateHandlerProtocol, Protocol):
"""Event handler for `SchemaValidator.validate_strings`."""
def on_enter(
self, input: StringInput, *, strict: bool | None = None, context: dict[str, Any] | None = None
) -> None:
"""Callback to be notified of validation start, and create an instance of the event handler.
Args:
input: The string data to be validated.
strict: Whether to validate the object in strict mode.
context: The context to use for validation, this is passed to functional validators.
"""
pass
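A minimal sketch, not part of this commit, of a plugin object satisfying the protocols above; all names are hypothetical, and the object would be exposed via the `pydantic` entry-point group (see `_loader.py` below):
from typing import Any
class LoggingOnValidatePython:
    # matches ValidatePythonHandlerProtocol by defining its own on_enter/on_success
    def on_enter(self, input: Any, **kwargs: Any) -> None:
        print(f'validate_python starting with {input!r}')
    def on_success(self, result: Any) -> None:
        print(f'validated: {result!r}')
class LoggingPlugin:
    def new_schema_validator(self, schema, schema_type, schema_type_path, schema_kind, config, plugin_settings):
        # hook only validate_python; return None for validate_json and validate_strings
        return LoggingOnValidatePython(), None, None
plugin = LoggingPlugin()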

View file

@@ -0,0 +1,50 @@
from __future__ import annotations
import importlib.metadata as importlib_metadata
import warnings
from typing import TYPE_CHECKING, Final, Iterable
if TYPE_CHECKING:
from . import PydanticPluginProtocol
PYDANTIC_ENTRY_POINT_GROUP: Final[str] = 'pydantic'
# cache of plugins
_plugins: dict[str, PydanticPluginProtocol] | None = None
# return no plugins while loading plugins to avoid recursion and errors while importing plugins
# this means that if plugins themselves use pydantic, they will not see any plugins while being imported
_loading_plugins: bool = False
def get_plugins() -> Iterable[PydanticPluginProtocol]:
"""Load plugins for Pydantic.
Inspired by: https://github.com/pytest-dev/pluggy/blob/1.3.0/src/pluggy/_manager.py#L376-L402
"""
global _plugins, _loading_plugins
if _loading_plugins:
# this happens when plugins themselves use pydantic, we return no plugins
return ()
elif _plugins is None:
_plugins = {}
# set _loading_plugins so any plugins that use pydantic don't themselves use plugins
_loading_plugins = True
try:
for dist in importlib_metadata.distributions():
for entry_point in dist.entry_points:
if entry_point.group != PYDANTIC_ENTRY_POINT_GROUP:
continue
if entry_point.value in _plugins:
continue
try:
_plugins[entry_point.value] = entry_point.load()
except (ImportError, AttributeError) as e:
warnings.warn(
f'{e.__class__.__name__} while loading the `{entry_point.name}` Pydantic plugin, '
f'this plugin will not be installed.\n\n{e!r}'
)
finally:
_loading_plugins = False
return _plugins.values()
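Plugins are discovered through the `pydantic` entry-point group; a hedged sketch of the lookup the loader above performs (the `my_package.plugin:plugin` value in the comment is hypothetical):
import importlib.metadata as importlib_metadata
# A distribution would register its plugin object in pyproject.toml, e.g.
#   [project.entry-points.pydantic]
#   my_plugin = "my_package.plugin:plugin"
for dist in importlib_metadata.distributions():
    for entry_point in dist.entry_points:
        if entry_point.group == 'pydantic':
            print(entry_point.name, '->', entry_point.value)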

View file

@@ -0,0 +1,138 @@
"""Pluggable schema validator for pydantic."""
from __future__ import annotations
import functools
from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar
from pydantic_core import CoreConfig, CoreSchema, SchemaValidator, ValidationError
from typing_extensions import Literal, ParamSpec
if TYPE_CHECKING:
from . import BaseValidateHandlerProtocol, PydanticPluginProtocol, SchemaKind, SchemaTypePath
P = ParamSpec('P')
R = TypeVar('R')
Event = Literal['on_validate_python', 'on_validate_json', 'on_validate_strings']
events: list[Event] = list(Event.__args__) # type: ignore
def create_schema_validator(
schema: CoreSchema,
schema_type: Any,
schema_type_module: str,
schema_type_name: str,
schema_kind: SchemaKind,
config: CoreConfig | None = None,
plugin_settings: dict[str, Any] | None = None,
) -> SchemaValidator:
"""Create a `SchemaValidator` or `PluggableSchemaValidator` if plugins are installed.
Returns:
If plugins are installed then return `PluggableSchemaValidator`, otherwise return `SchemaValidator`.
"""
from . import SchemaTypePath
from ._loader import get_plugins
plugins = get_plugins()
if plugins:
return PluggableSchemaValidator(
schema,
schema_type,
SchemaTypePath(schema_type_module, schema_type_name),
schema_kind,
config,
plugins,
plugin_settings or {},
) # type: ignore
else:
return SchemaValidator(schema, config)
class PluggableSchemaValidator:
"""Pluggable schema validator."""
__slots__ = '_schema_validator', 'validate_json', 'validate_python', 'validate_strings'
def __init__(
self,
schema: CoreSchema,
schema_type: Any,
schema_type_path: SchemaTypePath,
schema_kind: SchemaKind,
config: CoreConfig | None,
plugins: Iterable[PydanticPluginProtocol],
plugin_settings: dict[str, Any],
) -> None:
self._schema_validator = SchemaValidator(schema, config)
python_event_handlers: list[BaseValidateHandlerProtocol] = []
json_event_handlers: list[BaseValidateHandlerProtocol] = []
strings_event_handlers: list[BaseValidateHandlerProtocol] = []
for plugin in plugins:
try:
p, j, s = plugin.new_schema_validator(
schema, schema_type, schema_type_path, schema_kind, config, plugin_settings
)
except TypeError as e: # pragma: no cover
raise TypeError(f'Error using plugin `{plugin.__module__}:{plugin.__class__.__name__}`: {e}') from e
if p is not None:
python_event_handlers.append(p)
if j is not None:
json_event_handlers.append(j)
if s is not None:
strings_event_handlers.append(s)
self.validate_python = build_wrapper(self._schema_validator.validate_python, python_event_handlers)
self.validate_json = build_wrapper(self._schema_validator.validate_json, json_event_handlers)
self.validate_strings = build_wrapper(self._schema_validator.validate_strings, strings_event_handlers)
def __getattr__(self, name: str) -> Any:
return getattr(self._schema_validator, name)
def build_wrapper(func: Callable[P, R], event_handlers: list[BaseValidateHandlerProtocol]) -> Callable[P, R]:
if not event_handlers:
return func
else:
on_enters = tuple(h.on_enter for h in event_handlers if filter_handlers(h, 'on_enter'))
on_successes = tuple(h.on_success for h in event_handlers if filter_handlers(h, 'on_success'))
on_errors = tuple(h.on_error for h in event_handlers if filter_handlers(h, 'on_error'))
on_exceptions = tuple(h.on_exception for h in event_handlers if filter_handlers(h, 'on_exception'))
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
for on_enter_handler in on_enters:
on_enter_handler(*args, **kwargs)
try:
result = func(*args, **kwargs)
except ValidationError as error:
for on_error_handler in on_errors:
on_error_handler(error)
raise
except Exception as exception:
for on_exception_handler in on_exceptions:
on_exception_handler(exception)
raise
else:
for on_success_handler in on_successes:
on_success_handler(result)
return result
return wrapper
def filter_handlers(handler_cls: BaseValidateHandlerProtocol, method_name: str) -> bool:
"""Filter out handler methods which are not implemented by the plugin directly - e.g. are missing
or are inherited from the protocol.
"""
handler = getattr(handler_cls, method_name, None)
if handler is None:
return False
elif handler.__module__ == 'pydantic.plugin':
# this is the original handler, from the protocol due to runtime inheritance
# we don't want to call it
return False
else:
return True
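A hedged demonstration of how `build_wrapper` above composes event handlers around a validation callable; `EchoHandler` is a hypothetical handler, and the import path is the module added in this diff:
from pydantic.plugin._schema_validator import build_wrapper
class EchoHandler:
    def on_enter(self, *args, **kwargs):
        print('enter', args, kwargs)
    def on_success(self, result):
        print('success', result)
wrapped = build_wrapper(lambda x: x * 2, [EchoHandler()])
assert wrapped(3) == 6  # prints 'enter (3,) {}' then 'success 6'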

lib/pydantic/root_model.py (new file, 149 lines)
View file

@@ -0,0 +1,149 @@
"""RootModel class and type definitions."""
from __future__ import annotations as _annotations
import typing
from copy import copy, deepcopy
from pydantic_core import PydanticUndefined
from . import PydanticUserError
from ._internal import _model_construction, _repr
from .main import BaseModel, _object_setattr
if typing.TYPE_CHECKING:
from typing import Any
from typing_extensions import Literal, dataclass_transform
from .fields import Field as PydanticModelField
# dataclass_transform could be applied to RootModel directly, but `ModelMetaclass`'s dataclass_transform
# takes priority (at least with pyright). We trick type checkers into thinking we apply dataclass_transform
# on a new metaclass.
@dataclass_transform(kw_only_default=False, field_specifiers=(PydanticModelField,))
class _RootModelMetaclass(_model_construction.ModelMetaclass):
...
Model = typing.TypeVar('Model', bound='BaseModel')
else:
_RootModelMetaclass = _model_construction.ModelMetaclass
__all__ = ('RootModel',)
RootModelRootType = typing.TypeVar('RootModelRootType')
class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootModelMetaclass):
"""Usage docs: https://docs.pydantic.dev/2.6/concepts/models/#rootmodel-and-custom-root-types
A Pydantic `BaseModel` for the root object of the model.
Attributes:
root: The root object of the model.
__pydantic_root_model__: Whether the model is a RootModel.
__pydantic_private__: Private fields in the model.
__pydantic_extra__: Extra fields in the model.
"""
__pydantic_root_model__ = True
__pydantic_private__ = None
__pydantic_extra__ = None
root: RootModelRootType
def __init_subclass__(cls, **kwargs):
extra = cls.model_config.get('extra')
if extra is not None:
raise PydanticUserError(
"`RootModel` does not support setting `model_config['extra']`", code='root-model-extra'
)
super().__init_subclass__(**kwargs)
def __init__(self, /, root: RootModelRootType = PydanticUndefined, **data) -> None: # type: ignore
__tracebackhide__ = True
if data:
if root is not PydanticUndefined:
raise ValueError(
'"RootModel.__init__" accepts either a single positional argument or arbitrary keyword arguments'
)
root = data # type: ignore
self.__pydantic_validator__.validate_python(root, self_instance=self)
__init__.__pydantic_base_init__ = True # pyright: ignore[reportFunctionMemberAccess]
@classmethod
def model_construct(cls: type[Model], root: RootModelRootType, _fields_set: set[str] | None = None) -> Model: # type: ignore
"""Create a new model using the provided root object and update fields set.
Args:
root: The root object of the model.
_fields_set: The set of fields to be updated.
Returns:
The new model.
Raises:
NotImplementedError: If the model is not a subclass of `RootModel`.
"""
return super().model_construct(root=root, _fields_set=_fields_set)
def __getstate__(self) -> dict[Any, Any]:
return {
'__dict__': self.__dict__,
'__pydantic_fields_set__': self.__pydantic_fields_set__,
}
def __setstate__(self, state: dict[Any, Any]) -> None:
_object_setattr(self, '__pydantic_fields_set__', state['__pydantic_fields_set__'])
_object_setattr(self, '__dict__', state['__dict__'])
def __copy__(self: Model) -> Model:
"""Returns a shallow copy of the model."""
cls = type(self)
m = cls.__new__(cls)
_object_setattr(m, '__dict__', copy(self.__dict__))
_object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
return m
def __deepcopy__(self: Model, memo: dict[int, Any] | None = None) -> Model:
"""Returns a deep copy of the model."""
cls = type(self)
m = cls.__new__(cls)
_object_setattr(m, '__dict__', deepcopy(self.__dict__, memo=memo))
# This next line doesn't need a deepcopy because __pydantic_fields_set__ is a set[str],
# and attempting a deepcopy would be marginally slower.
_object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__))
return m
if typing.TYPE_CHECKING:
def model_dump( # type: ignore
self,
*,
mode: Literal['json', 'python'] | str = 'python',
include: Any = None,
exclude: Any = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool = True,
) -> RootModelRootType:
"""This method is included just to get a more accurate return type for type checkers.
It is included in this `if TYPE_CHECKING:` block since no override is actually necessary.
See the documentation of `BaseModel.model_dump` for more details about the arguments.
"""
...
def __eq__(self, other: Any) -> bool:
if not isinstance(other, RootModel):
return NotImplemented
return self.model_fields['root'].annotation == other.model_fields['root'].annotation and super().__eq__(other)
def __repr_args__(self) -> _repr.ReprArgs:
yield 'root', self.root
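A brief, hedged usage sketch of the `RootModel` class added above:
from typing import Dict
from pydantic import RootModel
class Counts(RootModel[Dict[str, int]]):
    pass
c = Counts({'a': 1, 'b': 2})
assert c.root == {'a': 1, 'b': 2}
print(c.model_dump_json())  # {"a":1,"b":2}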

File diff suppressed because it is too large

View file

@@ -1,92 +1,4 @@
"""The `tools` module is a backport module from V1."""
from ._migration import getattr_migration
__getattr__ = getattr_migration(__name__)
import json
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar, Union
from .parse import Protocol, load_file, load_str_bytes
from .types import StrBytes
from .typing import display_as_type
__all__ = ('parse_file_as', 'parse_obj_as', 'parse_raw_as', 'schema_of', 'schema_json_of')
NameFactory = Union[str, Callable[[Type[Any]], str]]
if TYPE_CHECKING:
from .typing import DictStrAny
def _generate_parsing_type_name(type_: Any) -> str:
return f'ParsingModel[{display_as_type(type_)}]'
@lru_cache(maxsize=2048)
def _get_parsing_type(type_: Any, *, type_name: Optional[NameFactory] = None) -> Any:
from pydantic.main import create_model
if type_name is None:
type_name = _generate_parsing_type_name
if not isinstance(type_name, str):
type_name = type_name(type_)
return create_model(type_name, __root__=(type_, ...))
T = TypeVar('T')
def parse_obj_as(type_: Type[T], obj: Any, *, type_name: Optional[NameFactory] = None) -> T:
model_type = _get_parsing_type(type_, type_name=type_name) # type: ignore[arg-type]
return model_type(__root__=obj).__root__
def parse_file_as(
type_: Type[T],
path: Union[str, Path],
*,
content_type: str = None,
encoding: str = 'utf8',
proto: Protocol = None,
allow_pickle: bool = False,
json_loads: Callable[[str], Any] = json.loads,
type_name: Optional[NameFactory] = None,
) -> T:
obj = load_file(
path,
proto=proto,
content_type=content_type,
encoding=encoding,
allow_pickle=allow_pickle,
json_loads=json_loads,
)
return parse_obj_as(type_, obj, type_name=type_name)
def parse_raw_as(
type_: Type[T],
b: StrBytes,
*,
content_type: str = None,
encoding: str = 'utf8',
proto: Protocol = None,
allow_pickle: bool = False,
json_loads: Callable[[str], Any] = json.loads,
type_name: Optional[NameFactory] = None,
) -> T:
obj = load_str_bytes(
b,
proto=proto,
content_type=content_type,
encoding=encoding,
allow_pickle=allow_pickle,
json_loads=json_loads,
)
return parse_obj_as(type_, obj, type_name=type_name)
def schema_of(type_: Any, *, title: Optional[NameFactory] = None, **schema_kwargs: Any) -> 'DictStrAny':
"""Generate a JSON schema (as dict) for the passed model or dynamically generated one"""
return _get_parsing_type(type_, type_name=title).schema(**schema_kwargs)
def schema_json_of(type_: Any, *, title: Optional[NameFactory] = None, **schema_json_kwargs: Any) -> str:
"""Generate a JSON schema (as JSON) for the passed model or dynamically generated one"""
return _get_parsing_type(type_, type_name=title).schema_json(**schema_json_kwargs)
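A hedged example of the removed V1 helpers above in use (`parse_obj_as` survives in V2 only as a deprecated shim):
from typing import List
from pydantic import parse_obj_as  # V1-style API
items = parse_obj_as(List[int], ['1', 2, 3.0])
assert items == [1, 2, 3]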

View file

@@ -0,0 +1,460 @@
"""Type adapter specification."""
from __future__ import annotations as _annotations
import sys
from dataclasses import is_dataclass
from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Set, TypeVar, Union, cast, final, overload
from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator, Some
from typing_extensions import Literal, get_args, is_typeddict
from pydantic.errors import PydanticUserError
from pydantic.main import BaseModel
from ._internal import _config, _generate_schema, _typing_extra
from .config import ConfigDict
from .json_schema import (
DEFAULT_REF_TEMPLATE,
GenerateJsonSchema,
JsonSchemaKeyT,
JsonSchemaMode,
JsonSchemaValue,
)
from .plugin._schema_validator import create_schema_validator
T = TypeVar('T')
if TYPE_CHECKING:
# should be `set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None`, but mypy can't cope
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
def _get_schema(type_: Any, config_wrapper: _config.ConfigWrapper, parent_depth: int) -> CoreSchema:
"""`BaseModel` uses its own `__module__` to find out where it was defined
and then looks for symbols to resolve forward references in those globals.
On the other hand this function can be called with arbitrary objects,
including type aliases, where `__module__` (always `typing.py`) is not useful.
So instead we look at the globals in our parent stack frame.
This works for the case where this function is called in a module that
has the target of forward references in its scope, but
does not always work for more complex cases.
For example, take the following:
a.py
```python
from typing import Dict, List
IntList = List[int]
OuterDict = Dict[str, 'IntList']
```
b.py
```python test="skip"
from a import OuterDict
from pydantic import TypeAdapter
IntList = int # replaces the symbol the forward reference is looking for
v = TypeAdapter(OuterDict)
v({'x': 1}) # should fail but doesn't
```
If `OuterDict` were a `BaseModel`, this would work because it would resolve
the forward reference within the `a.py` namespace.
But `TypeAdapter(OuterDict)` can't determine what module `OuterDict` came from.
In other words, the assumption that _all_ forward references exist in the
module we are being called from is not technically always true.
Although most of the time it is and it works fine for recursive models and such,
`BaseModel`'s behavior isn't perfect either and _can_ break in similar ways,
so there is no right or wrong between the two.
But at the very least this behavior is _subtly_ different from `BaseModel`'s.
"""
local_ns = _typing_extra.parent_frame_namespace(parent_depth=parent_depth)
global_ns = sys._getframe(max(parent_depth - 1, 1)).f_globals.copy()
global_ns.update(local_ns or {})
gen = _generate_schema.GenerateSchema(config_wrapper, types_namespace=global_ns, typevars_map={})
schema = gen.generate_schema(type_)
schema = gen.clean_schema(schema)
return schema
def _getattr_no_parents(obj: Any, attribute: str) -> Any:
"""Returns the attribute value without attempting to look up attributes from parent types."""
if hasattr(obj, '__dict__'):
try:
return obj.__dict__[attribute]
except KeyError:
pass
slots = getattr(obj, '__slots__', None)
if slots is not None and attribute in slots:
return getattr(obj, attribute)
else:
raise AttributeError(attribute)
def _type_has_config(type_: Any) -> bool:
"""Returns whether the type has config."""
try:
return issubclass(type_, BaseModel) or is_dataclass(type_) or is_typeddict(type_)
except TypeError:
# type is not a class
return False
@final
class TypeAdapter(Generic[T]):
"""Usage docs: https://docs.pydantic.dev/2.6/concepts/type_adapter/
Type adapters provide a flexible way to perform validation and serialization based on a Python type.
A `TypeAdapter` instance exposes some of the functionality from `BaseModel` instance methods
for types that do not have such methods (such as dataclasses, primitive types, and more).
**Note:** `TypeAdapter` instances are not types, and cannot be used as type annotations for fields.
Attributes:
core_schema: The core schema for the type.
validator (SchemaValidator): The schema validator for the type.
serializer: The schema serializer for the type.
"""
@overload
def __init__(
self,
type: type[T],
*,
config: ConfigDict | None = ...,
_parent_depth: int = ...,
module: str | None = ...,
) -> None:
...
# This second overload is for unsupported special forms (such as Union). `pyright` handles them fine, but `mypy` does not match
# them against `type: type[T]`, so an explicit overload with `type: T` is needed.
@overload
def __init__( # pyright: ignore[reportOverlappingOverload]
self,
type: T,
*,
config: ConfigDict | None = ...,
_parent_depth: int = ...,
module: str | None = ...,
) -> None:
...
def __init__(
self,
type: type[T] | T,
*,
config: ConfigDict | None = None,
_parent_depth: int = 2,
module: str | None = None,
) -> None:
"""Initializes the TypeAdapter object.
Args:
type: The type associated with the `TypeAdapter`.
config: Configuration for the `TypeAdapter`, should be a dictionary conforming to [`ConfigDict`][pydantic.config.ConfigDict].
_parent_depth: depth at which to search the parent namespace to construct the local namespace.
module: The module that passes to plugin if provided.
!!! note
You cannot use the `config` argument when instantiating a `TypeAdapter` if the type you're using has its own
config that cannot be overridden (ex: `BaseModel`, `TypedDict`, and `dataclass`). A
[`type-adapter-config-unused`](../errors/usage_errors.md#type-adapter-config-unused) error will be raised in this case.
!!! note
The `_parent_depth` argument is named with an underscore to suggest its private nature and discourage use.
It may be deprecated in a minor version, so we only recommend using it if you're
comfortable with potential change in behavior / support.
??? tip "Compatibility with `mypy`"
Depending on the type used, `mypy` might raise an error when instantiating a `TypeAdapter`. As a workaround, you can explicitly
annotate your variable:
```py
from typing import Union
from pydantic import TypeAdapter
ta: TypeAdapter[Union[str, int]] = TypeAdapter(Union[str, int]) # type: ignore[arg-type]
```
Returns:
A type adapter configured for the specified `type`.
"""
type_is_annotated: bool = _typing_extra.is_annotated(type)
annotated_type: Any = get_args(type)[0] if type_is_annotated else None
type_has_config: bool = _type_has_config(annotated_type if type_is_annotated else type)
if type_has_config and config is not None:
raise PydanticUserError(
'Cannot use `config` when the type is a BaseModel, dataclass or TypedDict.'
' These types can have their own config and setting the config via the `config`'
' parameter to TypeAdapter will not override it, thus the `config` you passed to'
' TypeAdapter becomes meaningless, which is probably not what you want.',
code='type-adapter-config-unused',
)
config_wrapper = _config.ConfigWrapper(config)
core_schema: CoreSchema
try:
core_schema = _getattr_no_parents(type, '__pydantic_core_schema__')
except AttributeError:
core_schema = _get_schema(type, config_wrapper, parent_depth=_parent_depth + 1)
core_config = config_wrapper.core_config(None)
validator: SchemaValidator
try:
validator = _getattr_no_parents(type, '__pydantic_validator__')
except AttributeError:
if module is None:
f = sys._getframe(1)
module = cast(str, f.f_globals.get('__name__', ''))
validator = create_schema_validator(
core_schema, type, module, str(type), 'TypeAdapter', core_config, config_wrapper.plugin_settings
) # type: ignore
serializer: SchemaSerializer
try:
serializer = _getattr_no_parents(type, '__pydantic_serializer__')
except AttributeError:
serializer = SchemaSerializer(core_schema, core_config)
self.core_schema = core_schema
self.validator = validator
self.serializer = serializer
def validate_python(
self,
__object: Any,
*,
strict: bool | None = None,
from_attributes: bool | None = None,
context: dict[str, Any] | None = None,
) -> T:
"""Validate a Python object against the model.
Args:
__object: The Python object to validate against the model.
strict: Whether to strictly check types.
from_attributes: Whether to extract data from object attributes.
context: Additional context to pass to the validator.
!!! note
When using `TypeAdapter` with a Pydantic `dataclass`, the use of the `from_attributes`
argument is not supported.
Returns:
The validated object.
"""
return self.validator.validate_python(__object, strict=strict, from_attributes=from_attributes, context=context)
def validate_json(
self, __data: str | bytes, *, strict: bool | None = None, context: dict[str, Any] | None = None
) -> T:
"""Usage docs: https://docs.pydantic.dev/2.6/concepts/json/#json-parsing
Validate a JSON string or bytes against the model.
Args:
__data: The JSON data to validate against the model.
strict: Whether to strictly check types.
context: Additional context to use during validation.
Returns:
The validated object.
"""
return self.validator.validate_json(__data, strict=strict, context=context)
def validate_strings(self, __obj: Any, *, strict: bool | None = None, context: dict[str, Any] | None = None) -> T:
"""Validate object contains string data against the model.
Args:
__obj: The object contains string data to validate.
strict: Whether to strictly check types.
context: Additional context to use during validation.
Returns:
The validated object.
"""
return self.validator.validate_strings(__obj, strict=strict, context=context)
def get_default_value(self, *, strict: bool | None = None, context: dict[str, Any] | None = None) -> Some[T] | None:
"""Get the default value for the wrapped type.
Args:
strict: Whether to strictly check types.
context: Additional context to pass to the validator.
Returns:
The default value wrapped in a `Some` if there is one or None if not.
"""
return self.validator.get_default_value(strict=strict, context=context)
def dump_python(
self,
__instance: T,
*,
mode: Literal['json', 'python'] = 'python',
include: IncEx | None = None,
exclude: IncEx | None = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool = True,
) -> Any:
"""Dump an instance of the adapted type to a Python object.
Args:
__instance: The Python object to serialize.
mode: The output format.
include: Fields to include in the output.
exclude: Fields to exclude from the output.
by_alias: Whether to use alias names for field names.
exclude_unset: Whether to exclude unset fields.
exclude_defaults: Whether to exclude fields with default values.
exclude_none: Whether to exclude fields with None values.
round_trip: Whether to output the serialized data in a way that is compatible with deserialization.
warnings: Whether to display serialization warnings.
Returns:
The serialized object.
"""
return self.serializer.to_python(
__instance,
mode=mode,
by_alias=by_alias,
include=include,
exclude=exclude,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
round_trip=round_trip,
warnings=warnings,
)
def dump_json(
self,
__instance: T,
*,
indent: int | None = None,
include: IncEx | None = None,
exclude: IncEx | None = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool = True,
) -> bytes:
"""Usage docs: https://docs.pydantic.dev/2.6/concepts/json/#json-serialization
Serialize an instance of the adapted type to JSON.
Args:
__instance: The instance to be serialized.
indent: Number of spaces for JSON indentation.
include: Fields to include.
exclude: Fields to exclude.
by_alias: Whether to use alias names for field names.
exclude_unset: Whether to exclude unset fields.
exclude_defaults: Whether to exclude fields with default values.
exclude_none: Whether to exclude fields with a value of `None`.
round_trip: Whether to serialize and deserialize the instance to ensure round-tripping.
warnings: Whether to emit serialization warnings.
Returns:
The JSON representation of the given instance as bytes.
"""
return self.serializer.to_json(
__instance,
indent=indent,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
round_trip=round_trip,
warnings=warnings,
)
def json_schema(
self,
*,
by_alias: bool = True,
ref_template: str = DEFAULT_REF_TEMPLATE,
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
mode: JsonSchemaMode = 'validation',
) -> dict[str, Any]:
"""Generate a JSON schema for the adapted type.
Args:
by_alias: Whether to use alias names for field names.
ref_template: The format string used for generating $ref strings.
schema_generator: The generator class used for creating the schema.
mode: The mode to use for schema generation.
Returns:
The JSON schema for the model as a dictionary.
"""
schema_generator_instance = schema_generator(by_alias=by_alias, ref_template=ref_template)
return schema_generator_instance.generate(self.core_schema, mode=mode)
@staticmethod
def json_schemas(
__inputs: Iterable[tuple[JsonSchemaKeyT, JsonSchemaMode, TypeAdapter[Any]]],
*,
by_alias: bool = True,
title: str | None = None,
description: str | None = None,
ref_template: str = DEFAULT_REF_TEMPLATE,
schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema,
) -> tuple[dict[tuple[JsonSchemaKeyT, JsonSchemaMode], JsonSchemaValue], JsonSchemaValue]:
"""Generate a JSON schema including definitions from multiple type adapters.
Args:
__inputs: Inputs to schema generation. The first two items will form the keys of the (first)
output mapping; the type adapters will provide the core schemas that get converted into
definitions in the output JSON schema.
by_alias: Whether to use alias names.
title: The title for the schema.
description: The description for the schema.
ref_template: The format string used for generating $ref strings.
schema_generator: The generator class used for creating the schema.
Returns:
A tuple where:
- The first element is a dictionary whose keys are tuples of JSON schema key type and JSON mode, and
whose values are the JSON schema corresponding to that pair of inputs. (These schemas may have
JsonRef references to definitions that are defined in the second returned element.)
- The second element is a JSON schema containing all definitions referenced in the first returned
element, along with the optional title and description keys.
"""
schema_generator_instance = schema_generator(by_alias=by_alias, ref_template=ref_template)
inputs = [(key, mode, adapter.core_schema) for key, mode, adapter in __inputs]
json_schemas_map, definitions = schema_generator_instance.generate_definitions(inputs)
json_schema: dict[str, Any] = {}
if definitions:
json_schema['$defs'] = definitions
if title:
json_schema['title'] = title
if description:
json_schema['description'] = description
return json_schemas_map, json_schema
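A brief, hedged usage sketch of the `TypeAdapter` API added above:
from typing import List
from pydantic import TypeAdapter
ta = TypeAdapter(List[int])
assert ta.validate_python(['1', 2]) == [1, 2]
assert ta.validate_json('[1, 2, 3]') == [1, 2, 3]
print(ta.json_schema())  # {'items': {'type': 'integer'}, 'type': 'array'}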

File diff suppressed because it is too large

View file

@@ -1,602 +1,4 @@
"""`typing` module is a backport module from V1."""
from ._migration import getattr_migration
__getattr__ = getattr_migration(__name__)
import sys
from collections.abc import Callable
from os import PathLike
from typing import ( # type: ignore
TYPE_CHECKING,
AbstractSet,
Any,
Callable as TypingCallable,
ClassVar,
Dict,
ForwardRef,
Generator,
Iterable,
List,
Mapping,
NewType,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
_eval_type,
cast,
get_type_hints,
)
from typing_extensions import (
Annotated,
Final,
Literal,
NotRequired as TypedDictNotRequired,
Required as TypedDictRequired,
)
try:
from typing import _TypingBase as typing_base # type: ignore
except ImportError:
from typing import _Final as typing_base # type: ignore
try:
from typing import GenericAlias as TypingGenericAlias # type: ignore
except ImportError:
# python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
TypingGenericAlias = ()
try:
from types import UnionType as TypesUnionType # type: ignore
except ImportError:
# python < 3.10 does not have UnionType (str | int, byte | bool and so on)
TypesUnionType = ()
if sys.version_info < (3, 9):
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
return type_._evaluate(globalns, localns)
else:
def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any:
# Even though it is the right signature for python 3.9, mypy complains with
# `error: Too many arguments for "_evaluate" of "ForwardRef"` hence the cast...
return cast(Any, type_)._evaluate(globalns, localns, set())
if sys.version_info < (3, 9):
# Ensure we always get all the whole `Annotated` hint, not just the annotated type.
# For 3.7 to 3.8, `get_type_hints` doesn't recognize `typing_extensions.Annotated`,
# so it already returns the full annotation
get_all_type_hints = get_type_hints
else:
def get_all_type_hints(obj: Any, globalns: Any = None, localns: Any = None) -> Any:
return get_type_hints(obj, globalns, localns, include_extras=True)
_T = TypeVar('_T')
AnyCallable = TypingCallable[..., Any]
NoArgAnyCallable = TypingCallable[[], Any]
# workaround for https://github.com/python/mypy/issues/9496
AnyArgTCallable = TypingCallable[..., _T]
# Annotated[...] is implemented by returning an instance of one of these classes, depending on
# python/typing_extensions version.
AnnotatedTypeNames = {'AnnotatedMeta', '_AnnotatedAlias'}
if sys.version_info < (3, 8):
def get_origin(t: Type[Any]) -> Optional[Type[Any]]:
if type(t).__name__ in AnnotatedTypeNames:
# weirdly this is a runtime requirement, as well as for mypy
return cast(Type[Any], Annotated)
return getattr(t, '__origin__', None)
else:
from typing import get_origin as _typing_get_origin
def get_origin(tp: Type[Any]) -> Optional[Type[Any]]:
"""
We can't directly use `typing.get_origin` since we need a fallback to support
custom generic classes like `ConstrainedList`
It should be useless once https://github.com/cython/cython/issues/3537 is
solved and https://github.com/pydantic/pydantic/pull/1753 is merged.
"""
if type(tp).__name__ in AnnotatedTypeNames:
return cast(Type[Any], Annotated) # mypy complains about _SpecialForm
return _typing_get_origin(tp) or getattr(tp, '__origin__', None)
if sys.version_info < (3, 8):
from typing import _GenericAlias
def get_args(t: Type[Any]) -> Tuple[Any, ...]:
"""Compatibility version of get_args for python 3.7.
Mostly compatible with the python 3.8 `typing` module version
and able to handle almost all use cases.
"""
if type(t).__name__ in AnnotatedTypeNames:
return t.__args__ + t.__metadata__
if isinstance(t, _GenericAlias):
res = t.__args__
if t.__origin__ is Callable and res and res[0] is not Ellipsis:
res = (list(res[:-1]), res[-1])
return res
return getattr(t, '__args__', ())
else:
from typing import get_args as _typing_get_args
def _generic_get_args(tp: Type[Any]) -> Tuple[Any, ...]:
"""
In python 3.9, `typing.Dict`, `typing.List`, ...
do have an empty `__args__` by default (instead of the generic ~T for example).
In order to still support `Dict` for example and consider it as `Dict[Any, Any]`,
we retrieve the `_nparams` value that tells us how many parameters it needs.
"""
if hasattr(tp, '_nparams'):
return (Any,) * tp._nparams
# Special case for `tuple[()]`, which used to return ((),) with `typing.Tuple`
# in python 3.10- but now returns () for `tuple` and `Tuple`.
# This will probably be clarified in pydantic v2
try:
if tp == Tuple[()] or sys.version_info >= (3, 9) and tp == tuple[()]: # type: ignore[misc]
return ((),)
# there is a TypeError when compiled with cython
except TypeError: # pragma: no cover
pass
return ()
def get_args(tp: Type[Any]) -> Tuple[Any, ...]:
"""Get type arguments with all substitutions performed.
For unions, basic simplifications used by Union constructor are performed.
Examples::
get_args(Dict[str, int]) == (str, int)
get_args(int) == ()
get_args(Union[int, Union[T, int], str][int]) == (int, str)
get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])
get_args(Callable[[], T][int]) == ([], int)
"""
if type(tp).__name__ in AnnotatedTypeNames:
return tp.__args__ + tp.__metadata__
# the fallback is needed for the same reasons as `get_origin` (see above)
return _typing_get_args(tp) or getattr(tp, '__args__', ()) or _generic_get_args(tp)
if sys.version_info < (3, 9):
def convert_generics(tp: Type[Any]) -> Type[Any]:
"""Python 3.9 and older only supports generics from `typing` module.
They convert strings to ForwardRef automatically.
Examples::
typing.List['Hero'] == typing.List[ForwardRef('Hero')]
"""
return tp
else:
from typing import _UnionGenericAlias # type: ignore
from typing_extensions import _AnnotatedAlias
def convert_generics(tp: Type[Any]) -> Type[Any]:
"""
Recursively searches for `str` type hints and replaces them with ForwardRef.
Examples::
convert_generics(list['Hero']) == list[ForwardRef('Hero')]
convert_generics(dict['Hero', 'Team']) == dict[ForwardRef('Hero'), ForwardRef('Team')]
convert_generics(typing.Dict['Hero', 'Team']) == typing.Dict[ForwardRef('Hero'), ForwardRef('Team')]
convert_generics(list[str | 'Hero'] | int) == list[str | ForwardRef('Hero')] | int
"""
origin = get_origin(tp)
if not origin or not hasattr(tp, '__args__'):
return tp
args = get_args(tp)
# typing.Annotated needs special treatment
if origin is Annotated:
return _AnnotatedAlias(convert_generics(args[0]), args[1:])
# recursively replace `str` instances inside of `GenericAlias` with `ForwardRef(arg)`
converted = tuple(
ForwardRef(arg) if isinstance(arg, str) and isinstance(tp, TypingGenericAlias) else convert_generics(arg)
for arg in args
)
if converted == args:
return tp
elif isinstance(tp, TypingGenericAlias):
return TypingGenericAlias(origin, converted)
elif isinstance(tp, TypesUnionType):
# recreate types.UnionType (PEP604, Python >= 3.10)
return _UnionGenericAlias(origin, converted)
else:
try:
setattr(tp, '__args__', converted)
except AttributeError:
pass
return tp
if sys.version_info < (3, 10):
def is_union(tp: Optional[Type[Any]]) -> bool:
return tp is Union
WithArgsTypes = (TypingGenericAlias,)
else:
import types
import typing
def is_union(tp: Optional[Type[Any]]) -> bool:
return tp is Union or tp is types.UnionType # noqa: E721
WithArgsTypes = (typing._GenericAlias, types.GenericAlias, types.UnionType)
if sys.version_info < (3, 9):
StrPath = Union[str, PathLike]
else:
StrPath = Union[str, PathLike]
# TODO: Once we switch to Cython 3 to handle generics properly
# (https://github.com/cython/cython/issues/2753), use following lines instead
# of the one above
# # os.PathLike only becomes subscriptable from Python 3.9 onwards
# StrPath = Union[str, PathLike[str]]
if TYPE_CHECKING:
from .fields import ModelField
TupleGenerator = Generator[Tuple[str, Any], None, None]
DictStrAny = Dict[str, Any]
DictAny = Dict[Any, Any]
SetStr = Set[str]
ListStr = List[str]
IntStr = Union[int, str]
AbstractSetIntStr = AbstractSet[IntStr]
DictIntStrAny = Dict[IntStr, Any]
MappingIntStrAny = Mapping[IntStr, Any]
CallableGenerator = Generator[AnyCallable, None, None]
ReprArgs = Sequence[Tuple[Optional[str], Any]]
AnyClassMethod = classmethod[Any]
__all__ = (
'AnyCallable',
'NoArgAnyCallable',
'NoneType',
'is_none_type',
'display_as_type',
'resolve_annotations',
'is_callable_type',
'is_literal_type',
'all_literal_values',
'is_namedtuple',
'is_typeddict',
'is_typeddict_special',
'is_new_type',
'new_type_supertype',
'is_classvar',
'is_finalvar',
'update_field_forward_refs',
'update_model_forward_refs',
'TupleGenerator',
'DictStrAny',
'DictAny',
'SetStr',
'ListStr',
'IntStr',
'AbstractSetIntStr',
'DictIntStrAny',
'CallableGenerator',
'ReprArgs',
'AnyClassMethod',
'CallableGenerator',
'WithArgsTypes',
'get_args',
'get_origin',
'get_sub_types',
'typing_base',
'get_all_type_hints',
'is_union',
'StrPath',
'MappingIntStrAny',
)
NoneType = None.__class__
NONE_TYPES: Tuple[Any, Any, Any] = (None, NoneType, Literal[None])
if sys.version_info < (3, 8):
# Even though this implementation is slower, we need it for python 3.7:
# In python 3.7 "Literal" is not a builtin type and uses a different
# mechanism.
# for this reason `Literal[None] is Literal[None]` evaluates to `False`,
# breaking the faster implementation used for the other python versions.
def is_none_type(type_: Any) -> bool:
return type_ in NONE_TYPES
elif sys.version_info[:2] == (3, 8):
def is_none_type(type_: Any) -> bool:
for none_type in NONE_TYPES:
if type_ is none_type:
return True
# With python 3.8, specifically 3.8.10, Literal "is" checks are very flakey
# can change on very subtle changes like use of types in other modules,
# hopefully this check avoids that issue.
if is_literal_type(type_): # pragma: no cover
return all_literal_values(type_) == (None,)
return False
else:
def is_none_type(type_: Any) -> bool:
for none_type in NONE_TYPES:
if type_ is none_type:
return True
return False
def display_as_type(v: Type[Any]) -> str:
if not isinstance(v, typing_base) and not isinstance(v, WithArgsTypes) and not isinstance(v, type):
v = v.__class__
if is_union(get_origin(v)):
return f'Union[{", ".join(map(display_as_type, get_args(v)))}]'
if isinstance(v, WithArgsTypes):
# Generic alias are constructs like `list[int]`
return str(v).replace('typing.', '')
try:
return v.__name__
except AttributeError:
# happens with typing objects
return str(v).replace('typing.', '')
def resolve_annotations(raw_annotations: Dict[str, Type[Any]], module_name: Optional[str]) -> Dict[str, Type[Any]]:
"""
Partially taken from typing.get_type_hints.
Resolve string or ForwardRef annotations into type objects if possible.
"""
base_globals: Optional[Dict[str, Any]] = None
if module_name:
try:
module = sys.modules[module_name]
except KeyError:
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
pass
else:
base_globals = module.__dict__
annotations = {}
for name, value in raw_annotations.items():
if isinstance(value, str):
if (3, 10) > sys.version_info >= (3, 9, 8) or sys.version_info >= (3, 10, 1):
value = ForwardRef(value, is_argument=False, is_class=True)
else:
value = ForwardRef(value, is_argument=False)
try:
value = _eval_type(value, base_globals, None)
except NameError:
# this is ok, it can be fixed with update_forward_refs
pass
annotations[name] = value
return annotations
def is_callable_type(type_: Type[Any]) -> bool:
return type_ is Callable or get_origin(type_) is Callable
def is_literal_type(type_: Type[Any]) -> bool:
return Literal is not None and get_origin(type_) is Literal
def literal_values(type_: Type[Any]) -> Tuple[Any, ...]:
return get_args(type_)
def all_literal_values(type_: Type[Any]) -> Tuple[Any, ...]:
"""
This method is used to retrieve all Literal values as
Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586)
e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]`
"""
if not is_literal_type(type_):
return (type_,)
values = literal_values(type_)
return tuple(x for value in values for x in all_literal_values(value))
def is_namedtuple(type_: Type[Any]) -> bool:
"""
Check if a given class is a named tuple.
It can be either a `typing.NamedTuple` or `collections.namedtuple`
"""
from .utils import lenient_issubclass
return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields')
def is_typeddict(type_: Type[Any]) -> bool:
"""
Check if a given class is a typed dict (from `typing` or `typing_extensions`)
In 3.10, there will be a public method (https://docs.python.org/3.10/library/typing.html#typing.is_typeddict)
"""
from .utils import lenient_issubclass
return lenient_issubclass(type_, dict) and hasattr(type_, '__total__')
def _check_typeddict_special(type_: Any) -> bool:
return type_ is TypedDictRequired or type_ is TypedDictNotRequired
def is_typeddict_special(type_: Any) -> bool:
"""
Check if type is a TypedDict special form (Required or NotRequired).
"""
return _check_typeddict_special(type_) or _check_typeddict_special(get_origin(type_))
test_type = NewType('test_type', str)
def is_new_type(type_: Type[Any]) -> bool:
"""
Check whether type_ was created using typing.NewType
"""
return isinstance(type_, test_type.__class__) and hasattr(type_, '__supertype__') # type: ignore
def new_type_supertype(type_: Type[Any]) -> Type[Any]:
while hasattr(type_, '__supertype__'):
type_ = type_.__supertype__
return type_
def _check_classvar(v: Optional[Type[Any]]) -> bool:
if v is None:
return False
return v.__class__ == ClassVar.__class__ and getattr(v, '_name', None) == 'ClassVar'
def _check_finalvar(v: Optional[Type[Any]]) -> bool:
"""
Check if a given type is a `typing.Final` type.
"""
if v is None:
return False
return v.__class__ == Final.__class__ and (sys.version_info < (3, 8) or getattr(v, '_name', None) == 'Final')
def is_classvar(ann_type: Type[Any]) -> bool:
if _check_classvar(ann_type) or _check_classvar(get_origin(ann_type)):
return True
# this is an ugly workaround for class vars that contain forward references and are therefore themselves
# forward references, see #3679
if ann_type.__class__ == ForwardRef and ann_type.__forward_arg__.startswith('ClassVar['):
return True
return False
def is_finalvar(ann_type: Type[Any]) -> bool:
return _check_finalvar(ann_type) or _check_finalvar(get_origin(ann_type))
def update_field_forward_refs(field: 'ModelField', globalns: Any, localns: Any) -> None:
"""
Try to update ForwardRefs on fields based on this ModelField, globalns and localns.
"""
prepare = False
if field.type_.__class__ == ForwardRef:
prepare = True
field.type_ = evaluate_forwardref(field.type_, globalns, localns or None)
if field.outer_type_.__class__ == ForwardRef:
prepare = True
field.outer_type_ = evaluate_forwardref(field.outer_type_, globalns, localns or None)
if prepare:
field.prepare()
if field.sub_fields:
for sub_f in field.sub_fields:
update_field_forward_refs(sub_f, globalns=globalns, localns=localns)
if field.discriminator_key is not None:
field.prepare_discriminated_union_sub_fields()
def update_model_forward_refs(
model: Type[Any],
fields: Iterable['ModelField'],
json_encoders: Dict[Union[Type[Any], str, ForwardRef], AnyCallable],
localns: 'DictStrAny',
exc_to_suppress: Tuple[Type[BaseException], ...] = (),
) -> None:
"""
Try to update model fields ForwardRefs based on model and localns.
"""
if model.__module__ in sys.modules:
globalns = sys.modules[model.__module__].__dict__.copy()
else:
globalns = {}
globalns.setdefault(model.__name__, model)
for f in fields:
try:
update_field_forward_refs(f, globalns=globalns, localns=localns)
except exc_to_suppress:
pass
for key in set(json_encoders.keys()):
if isinstance(key, str):
fr: ForwardRef = ForwardRef(key)
elif isinstance(key, ForwardRef):
fr = key
else:
continue
try:
new_key = evaluate_forwardref(fr, globalns, localns or None)
except exc_to_suppress: # pragma: no cover
continue
json_encoders[new_key] = json_encoders.pop(key)
def get_class(type_: Type[Any]) -> Union[None, bool, Type[Any]]:
"""
Tries to get the class of a Type[T] annotation. Returns True if Type is used
without brackets. Otherwise returns None.
"""
if type_ is type:
return True
if get_origin(type_) is None:
return None
args = get_args(type_)
if not args or not isinstance(args[0], type):
return True
else:
return args[0]
def get_sub_types(tp: Any) -> List[Any]:
"""
Return all the types that are allowed by type `tp`
`tp` can be a `Union` of allowed types or an `Annotated` type
"""
origin = get_origin(tp)
if origin is Annotated:
return get_sub_types(get_args(tp)[0])
elif is_union(origin):
return [x for t in get_args(tp) for x in get_sub_types(t)]
else:
return [tp]
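A hedged illustration of the removed V1 typing helpers above; the `pydantic.typing` import path is the old V1 location:
from typing import Optional, Union
from typing_extensions import Annotated
from pydantic.typing import get_args, get_origin, get_sub_types, is_union  # V1 import path
assert get_args(Optional[int]) == (int, type(None))
assert is_union(get_origin(Union[int, str]))
assert get_sub_types(Annotated[Union[int, str], 'metadata']) == [int, str]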

View file

@@ -1,841 +1,4 @@
"""The `utils` module is a backport module from V1."""
from ._migration import getattr_migration
__getattr__ = getattr_migration(__name__)
import keyword
import warnings
import weakref
from collections import OrderedDict, defaultdict, deque
from copy import deepcopy
from itertools import islice, zip_longest
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
Collection,
Dict,
Generator,
Iterable,
Iterator,
List,
Mapping,
MutableMapping,
NoReturn,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
)
from typing_extensions import Annotated
from .errors import ConfigError
from .typing import (
NoneType,
WithArgsTypes,
all_literal_values,
display_as_type,
get_args,
get_origin,
is_literal_type,
is_union,
)
from .version import version_info
if TYPE_CHECKING:
from inspect import Signature
from pathlib import Path
from .config import BaseConfig
from .dataclasses import Dataclass
from .fields import ModelField
from .main import BaseModel
from .typing import AbstractSetIntStr, DictIntStrAny, IntStr, MappingIntStrAny, ReprArgs
RichReprResult = Iterable[Union[Any, Tuple[Any], Tuple[str, Any], Tuple[str, Any, Any]]]
__all__ = (
'import_string',
'sequence_like',
'validate_field_name',
'lenient_isinstance',
'lenient_issubclass',
'in_ipython',
'is_valid_identifier',
'deep_update',
'update_not_none',
'almost_equal_floats',
'get_model',
'to_camel',
'is_valid_field',
'smart_deepcopy',
'PyObjectStr',
'Representation',
'GetterDict',
'ValueItems',
'version_info', # required here to match behaviour in v1.3
'ClassAttribute',
'path_type',
'ROOT_KEY',
'get_unique_discriminator_alias',
'get_discriminator_alias_and_values',
'DUNDER_ATTRIBUTES',
'LimitedDict',
)
ROOT_KEY = '__root__'
# these are types that are returned unchanged by deepcopy
IMMUTABLE_NON_COLLECTIONS_TYPES: Set[Type[Any]] = {
int,
float,
complex,
str,
bool,
bytes,
type,
NoneType,
FunctionType,
BuiltinFunctionType,
LambdaType,
weakref.ref,
CodeType,
# note: including ModuleType will differ from behaviour of deepcopy by not producing error.
# It might be not a good idea in general, but considering that this function used only internally
# against default values of fields, this will allow to actually have a field with module as default value
ModuleType,
NotImplemented.__class__,
Ellipsis.__class__,
}
# these are types that if empty, might be copied with simple copy() instead of deepcopy()
BUILTIN_COLLECTIONS: Set[Type[Any]] = {
list,
set,
tuple,
frozenset,
dict,
OrderedDict,
defaultdict,
deque,
}
def import_string(dotted_path: str) -> Any:
"""
Stolen approximately from django. Import a dotted module path and return the attribute/class designated by the
last name in the path. Raise ImportError if the import fails.
"""
from importlib import import_module
try:
module_path, class_name = dotted_path.strip(' ').rsplit('.', 1)
except ValueError as e:
raise ImportError(f'"{dotted_path}" doesn\'t look like a module path') from e
module = import_module(module_path)
try:
return getattr(module, class_name)
except AttributeError as e:
raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute') from e
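# Illustrative example, not part of the original module:
# `import_string` resolves a dotted path to the final attribute.
assert abs(import_string('math.pi') - 3.141592653589793) < 1e-12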
def truncate(v: Union[str], *, max_len: int = 80) -> str:
"""
Truncate a value and add a unicode ellipsis (three dots) to the end if it was too long
"""
warnings.warn('`truncate` is no-longer used by pydantic and is deprecated', DeprecationWarning)
if isinstance(v, str) and len(v) > (max_len - 2):
# -3 so quote + string + … + quote has correct length
return (v[: (max_len - 3)] + '…').__repr__()
try:
v = v.__repr__()
except TypeError:
v = v.__class__.__repr__(v) # in case v is a type
if len(v) > max_len:
v = v[: max_len - 1] + '…'
return v
def sequence_like(v: Any) -> bool:
return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque))
def validate_field_name(bases: List[Type['BaseModel']], field_name: str) -> None:
"""
Ensure that the field's name does not shadow an existing attribute of the model.
"""
for base in bases:
if getattr(base, field_name, None):
raise NameError(
f'Field name "{field_name}" shadows a BaseModel attribute; '
f'use a different field name with "alias=\'{field_name}\'".'
)
def lenient_isinstance(o: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
try:
return isinstance(o, class_or_tuple) # type: ignore[arg-type]
except TypeError:
return False
def lenient_issubclass(cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool:
try:
return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
except TypeError:
if isinstance(cls, WithArgsTypes):
return False
raise # pragma: no cover
def in_ipython() -> bool:
"""
Check whether we're in an ipython environment, including jupyter notebooks.
"""
try:
eval('__IPYTHON__')
except NameError:
return False
else: # pragma: no cover
return True
def is_valid_identifier(identifier: str) -> bool:
"""
Checks that a string is a valid identifier and not a Python keyword.
:param identifier: The identifier to test.
:return: True if the identifier is valid.
"""
return identifier.isidentifier() and not keyword.iskeyword(identifier)
KeyType = TypeVar('KeyType')
def deep_update(mapping: Dict[KeyType, Any], *updating_mappings: Dict[KeyType, Any]) -> Dict[KeyType, Any]:
updated_mapping = mapping.copy()
for updating_mapping in updating_mappings:
for k, v in updating_mapping.items():
if k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict):
updated_mapping[k] = deep_update(updated_mapping[k], v)
else:
updated_mapping[k] = v
return updated_mapping
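# Illustrative example (values assumed, not from the original source): nested dicts are
# merged recursively, while non-dict values from later mappings win.
#     >>> deep_update({'a': {'x': 1}, 'b': 1}, {'a': {'y': 2}, 'b': 2})
#     {'a': {'x': 1, 'y': 2}, 'b': 2}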
def update_not_none(mapping: Dict[Any, Any], **update: Any) -> None:
mapping.update({k: v for k, v in update.items() if v is not None})
def almost_equal_floats(value_1: float, value_2: float, *, delta: float = 1e-8) -> bool:
"""
Return True if two floats are almost equal
"""
return abs(value_1 - value_2) <= delta
def generate_model_signature(
init: Callable[..., None], fields: Dict[str, 'ModelField'], config: Type['BaseConfig']
) -> 'Signature':
"""
Generate signature for model based on its fields
"""
from inspect import Parameter, Signature, signature
from .config import Extra
present_params = signature(init).parameters.values()
merged_params: Dict[str, Parameter] = {}
var_kw = None
use_var_kw = False
for param in islice(present_params, 1, None): # skip self arg
if param.kind is param.VAR_KEYWORD:
var_kw = param
continue
merged_params[param.name] = param
if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
allow_names = config.allow_population_by_field_name
for field_name, field in fields.items():
param_name = field.alias
if field_name in merged_params or param_name in merged_params:
continue
elif not is_valid_identifier(param_name):
if allow_names and is_valid_identifier(field_name):
param_name = field_name
else:
use_var_kw = True
continue
# TODO: replace annotation with actual expected types once #1055 solved
kwargs = {'default': field.default} if not field.required else {}
merged_params[param_name] = Parameter(
param_name, Parameter.KEYWORD_ONLY, annotation=field.annotation, **kwargs
)
if config.extra is Extra.allow:
use_var_kw = True
if var_kw and use_var_kw:
# Make sure the parameter for extra kwargs
# does not have the same name as a field
default_model_signature = [
('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD),
('data', Parameter.VAR_KEYWORD),
]
if [(p.name, p.kind) for p in present_params] == default_model_signature:
# if this is the standard model signature, use extra_data as the extra args name
var_kw_name = 'extra_data'
else:
# else start from var_kw
var_kw_name = var_kw.name
# generate a name that's definitely unique
while var_kw_name in fields:
var_kw_name += '_'
merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)
return Signature(parameters=list(merged_params.values()), return_annotation=None)
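# Rough illustration (hypothetical model, not from the original source): for something like
#     class User(BaseModel):
#         id: int
#         name: str = 'Jane'
# the signature built here is what `inspect.signature(User)` then reports, roughly
# `(*, id: int, name: str = 'Jane') -> None`, with every field exposed as keyword-only.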
def get_model(obj: Union[Type['BaseModel'], Type['Dataclass']]) -> Type['BaseModel']:
from .main import BaseModel
try:
model_cls = obj.__pydantic_model__ # type: ignore
except AttributeError:
model_cls = obj
if not issubclass(model_cls, BaseModel):
raise TypeError('Unsupported type, must be either BaseModel or dataclass')
return model_cls
def to_camel(string: str) -> str:
return ''.join(word.capitalize() for word in string.split('_'))
def to_lower_camel(string: str) -> str:
if len(string) >= 1:
pascal_string = to_camel(string)
return pascal_string[0].lower() + pascal_string[1:]
return string.lower()
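# Example (illustrative):
#     >>> to_camel('http_response_code')
#     'HttpResponseCode'
#     >>> to_lower_camel('http_response_code')
#     'httpResponseCode'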
T = TypeVar('T')
def unique_list(
input_list: Union[List[T], Tuple[T, ...]],
*,
name_factory: Callable[[T], str] = str,
) -> List[T]:
"""
Make a list unique while maintaining order.
If a later item produces the same name as an earlier one, it replaces the earlier item in place
(e.g. a root validator overridden in a subclass).
"""
result: List[T] = []
result_names: List[str] = []
for v in input_list:
v_name = name_factory(v)
if v_name not in result_names:
result_names.append(v_name)
result.append(v)
else:
result[result_names.index(v_name)] = v
return result
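# Illustrative example (hypothetical data): later items with the same derived name replace
# earlier ones but keep the original position.
#     >>> unique_list([('v', 1), ('w', 2), ('v', 3)], name_factory=lambda t: t[0])
#     [('v', 3), ('w', 2)]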
class PyObjectStr(str):
"""
String class where repr doesn't include quotes. Useful with Representation when you want to return a string
representation of something that is valid (or pseudo-valid) Python.
"""
def __repr__(self) -> str:
return str(self)
class Representation:
"""
Mixin to provide __str__, __repr__, and __pretty__ methods. See #884 for more details.
__pretty__ is used by [devtools](https://python-devtools.helpmanual.io/) to provide human readable representations
of objects.
"""
__slots__: Tuple[str, ...] = tuple()
def __repr_args__(self) -> 'ReprArgs':
"""
Returns the attributes to show in __str__, __repr__, and __pretty__; this is generally overridden.
Can either return:
* name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]`
* or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]`
"""
attrs = ((s, getattr(self, s)) for s in self.__slots__)
return [(a, v) for a, v in attrs if v is not None]
def __repr_name__(self) -> str:
"""
Name of the instance's class, used in __repr__.
"""
return self.__class__.__name__
def __repr_str__(self, join_str: str) -> str:
return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())
def __pretty__(self, fmt: Callable[[Any], Any], **kwargs: Any) -> Generator[Any, None, None]:
"""
Used by devtools (https://python-devtools.helpmanual.io/) to provide a human-readable representation of objects
"""
yield self.__repr_name__() + '('
yield 1
for name, value in self.__repr_args__():
if name is not None:
yield name + '='
yield fmt(value)
yield ','
yield 0
yield -1
yield ')'
def __str__(self) -> str:
return self.__repr_str__(' ')
def __repr__(self) -> str:
return f'{self.__repr_name__()}({self.__repr_str__(", ")})'
def __rich_repr__(self) -> 'RichReprResult':
"""Get fields for Rich library"""
for name, field_repr in self.__repr_args__():
if name is None:
yield field_repr
else:
yield name, field_repr
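# Minimal sketch (hypothetical subclass) of how Representation drives repr/str:
#     class Point(Representation):
#         __slots__ = ('x', 'y')
#         def __init__(self, x, y):
#             self.x, self.y = x, y
#     repr(Point(1, 2))  # 'Point(x=1, y=2)'
#     str(Point(1, 2))   # 'x=1 y=2'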
class GetterDict(Representation):
"""
Hack to make objects smell just enough like dicts for validate_model.
We can't inherit from Mapping[str, Any] because it upsets cython so we have to implement all methods ourselves.
"""
__slots__ = ('_obj',)
def __init__(self, obj: Any):
self._obj = obj
def __getitem__(self, key: str) -> Any:
try:
return getattr(self._obj, key)
except AttributeError as e:
raise KeyError(key) from e
def get(self, key: Any, default: Any = None) -> Any:
return getattr(self._obj, key, default)
def extra_keys(self) -> Set[Any]:
"""
We don't want to get any other attributes of obj if the model didn't explicitly ask for them
"""
return set()
def keys(self) -> List[Any]:
"""
Keys of the pseudo dictionary; uses a list rather than a set so that ordering information is maintained,
as in Python dictionaries.
"""
return list(self)
def values(self) -> List[Any]:
return [self[k] for k in self]
def items(self) -> Iterator[Tuple[str, Any]]:
for k in self:
yield k, self.get(k)
def __iter__(self) -> Iterator[str]:
for name in dir(self._obj):
if not name.startswith('_'):
yield name
def __len__(self) -> int:
return sum(1 for _ in self)
def __contains__(self, item: Any) -> bool:
return item in self.keys()
def __eq__(self, other: Any) -> bool:
return dict(self) == dict(other.items())
def __repr_args__(self) -> 'ReprArgs':
return [(None, dict(self))]
def __repr_name__(self) -> str:
return f'GetterDict[{display_as_type(self._obj)}]'
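# Illustrative use (hypothetical object): GetterDict exposes the public attributes of an
# arbitrary object through a read-only, dict-like interface.
#     class Point:
#         x = 1
#         y = 2
#     gd = GetterDict(Point())
#     gd['x']             # 1
#     gd.get('z', None)   # None
#     dict(gd)            # {'x': 1, 'y': 2}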
class ValueItems(Representation):
"""
Class for more convenient calculation of excluded or included fields on values.
"""
__slots__ = ('_items', '_type')
def __init__(self, value: Any, items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> None:
items = self._coerce_items(items)
if isinstance(value, (list, tuple)):
items = self._normalize_indexes(items, len(value))
self._items: 'MappingIntStrAny' = items
def is_excluded(self, item: Any) -> bool:
"""
Check if item is fully excluded.
:param item: key or index of a value
"""
return self.is_true(self._items.get(item))
def is_included(self, item: Any) -> bool:
"""
Check if value is contained in self._items
:param item: key or index of value
"""
return item in self._items
def for_element(self, e: 'IntStr') -> Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']]:
"""
:param e: key or index of element on value
:return: raw values for element if self._items is a dict and contains the needed element
"""
item = self._items.get(e)
return item if not self.is_true(item) else None
def _normalize_indexes(self, items: 'MappingIntStrAny', v_length: int) -> 'DictIntStrAny':
"""
:param items: dict or set of indexes which will be normalized
:param v_length: length of the sequence whose indexes are being normalized
>>> self._normalize_indexes({0: True, -2: True, -1: True}, 4)
{0: True, 2: True, 3: True}
>>> self._normalize_indexes({'__all__': True}, 4)
{0: True, 1: True, 2: True, 3: True}
"""
normalized_items: 'DictIntStrAny' = {}
all_items = None
for i, v in items.items():
if not (isinstance(v, Mapping) or isinstance(v, AbstractSet) or self.is_true(v)):
raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
if i == '__all__':
all_items = self._coerce_value(v)
continue
if not isinstance(i, int):
raise TypeError(
'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
'expected integer keys or keyword "__all__"'
)
normalized_i = v_length + i if i < 0 else i
normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i))
if not all_items:
return normalized_items
if self.is_true(all_items):
for i in range(v_length):
normalized_items.setdefault(i, ...)
return normalized_items
for i in range(v_length):
normalized_item = normalized_items.setdefault(i, {})
if not self.is_true(normalized_item):
normalized_items[i] = self.merge(all_items, normalized_item)
return normalized_items
@classmethod
def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any:
"""
Merge a ``base`` item with an ``override`` item.
Both ``base`` and ``override`` are converted to dictionaries if possible.
Sets are converted to dictionaries with the set's entries as keys and
Ellipsis as values.
Each key-value pair existing in ``base`` is merged with ``override``,
while the rest of the key-value pairs are updated recursively with this function.
Merging takes place based on the "union" of keys if ``intersect`` is
set to ``False`` (default) and on the intersection of keys if
``intersect`` is set to ``True``.
"""
override = cls._coerce_value(override)
base = cls._coerce_value(base)
if override is None:
return base
if cls.is_true(base) or base is None:
return override
if cls.is_true(override):
return base if intersect else override
# intersection or union of keys while preserving ordering:
if intersect:
merge_keys = [k for k in base if k in override] + [k for k in override if k in base]
else:
merge_keys = list(base) + [k for k in override if k not in base]
merged: 'DictIntStrAny' = {}
for k in merge_keys:
merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect)
if merged_item is not None:
merged[k] = merged_item
return merged
@staticmethod
def _coerce_items(items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> 'MappingIntStrAny':
if isinstance(items, Mapping):
pass
elif isinstance(items, AbstractSet):
items = dict.fromkeys(items, ...)
else:
class_name = getattr(items, '__class__', '???')
assert_never(
items,
f'Unexpected type of exclude value {class_name}',
)
return items
@classmethod
def _coerce_value(cls, value: Any) -> Any:
if value is None or cls.is_true(value):
return value
return cls._coerce_items(value)
@staticmethod
def is_true(v: Any) -> bool:
return v is True or v is ...
def __repr_args__(self) -> 'ReprArgs':
return [(None, self._items)]
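# Behaviour sketch for ValueItems.merge (illustrative values):
#     ValueItems.merge({'a'}, {'b': {'c'}})
#     # -> {'a': ..., 'b': {'c': ...}}   (union of keys; sets coerced to {key: ...})
#     ValueItems.merge({'a', 'b'}, {'b', 'c'}, intersect=True)
#     # -> {'b': ...}                    (only keys present on both sides survive)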
class ClassAttribute:
"""
Hide a class attribute from its instances.
"""
__slots__ = (
'name',
'value',
)
def __init__(self, name: str, value: Any) -> None:
self.name = name
self.value = value
def __get__(self, instance: Any, owner: Type[Any]) -> None:
if instance is None:
return self.value
raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only')
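# Example (illustrative): the attribute is readable on the class but not on instances.
#     class Model:
#         kind = ClassAttribute('kind', 'base')
#     Model.kind     # 'base'
#     Model().kind   # AttributeError: 'kind' attribute of 'Model' is class-only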
path_types = {
'is_dir': 'directory',
'is_file': 'file',
'is_mount': 'mount point',
'is_symlink': 'symlink',
'is_block_device': 'block device',
'is_char_device': 'char device',
'is_fifo': 'FIFO',
'is_socket': 'socket',
}
def path_type(p: 'Path') -> str:
"""
Find out what sort of thing a path is.
"""
assert p.exists(), 'path does not exist'
for method, name in path_types.items():
if getattr(p, method)():
return name
return 'unknown'
Obj = TypeVar('Obj')
def smart_deepcopy(obj: Obj) -> Obj:
"""
Return obj as-is for immutable built-in types
Use obj.copy() for empty built-in collections
Use copy.deepcopy() for non-empty collections and unknown objects
"""
obj_type = obj.__class__
if obj_type in IMMUTABLE_NON_COLLECTIONS_TYPES:
return obj # fastest case: obj is immutable and not collection therefore will not be copied anyway
try:
if not obj and obj_type in BUILTIN_COLLECTIONS:
# faster way for empty collections, no need to copy its members
return obj if obj_type is tuple else obj.copy() # type: ignore # tuple doesn't have copy method
except (TypeError, ValueError, RuntimeError):
# do we really dare to catch ALL errors? Seems a bit risky
pass
return deepcopy(obj) # slowest way when we actually might need a deepcopy
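# Illustrative behaviour (values assumed):
#     smart_deepcopy(42) is 42      # True  - immutable, returned untouched
#     smart_deepcopy({}) == {}      # True  - empty builtin, shallow .copy()
#     smart_deepcopy([{'a': 1}])    # full copy.deepcopy of the nested structure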
def is_valid_field(name: str) -> bool:
if not name.startswith('_'):
return True
return ROOT_KEY == name
DUNDER_ATTRIBUTES = {
'__annotations__',
'__classcell__',
'__doc__',
'__module__',
'__orig_bases__',
'__orig_class__',
'__qualname__',
}
def is_valid_private_name(name: str) -> bool:
return not is_valid_field(name) and name not in DUNDER_ATTRIBUTES
_EMPTY = object()
def all_identical(left: Iterable[Any], right: Iterable[Any]) -> bool:
"""
Check that the items of `left` are the same objects as those in `right`.
>>> a, b = object(), object()
>>> all_identical([a, b, a], [a, b, a])
True
>>> all_identical([a, b, [a]], [a, b, [a]]) # new list object, while "equal" is not "identical"
False
"""
for left_item, right_item in zip_longest(left, right, fillvalue=_EMPTY):
if left_item is not right_item:
return False
return True
def assert_never(obj: NoReturn, msg: str) -> NoReturn:
"""
Helper to make sure that we have covered all possible types.
This is mostly useful for ``mypy``, docs:
https://mypy.readthedocs.io/en/latest/literal_types.html#exhaustive-checks
"""
raise TypeError(msg)
def get_unique_discriminator_alias(all_aliases: Collection[str], discriminator_key: str) -> str:
"""Validate that all aliases are the same and if that's the case return the alias"""
unique_aliases = set(all_aliases)
if len(unique_aliases) > 1:
raise ConfigError(
f'Aliases for discriminator {discriminator_key!r} must be the same (got {", ".join(sorted(all_aliases))})'
)
return unique_aliases.pop()
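# Example (illustrative): all submodels must use the same alias for the discriminator.
#     get_unique_discriminator_alias(['pet_type', 'pet_type'], 'pet_type')  # 'pet_type'
#     get_unique_discriminator_alias(['pet_type', 'petType'], 'pet_type')   # raises ConfigError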
def get_discriminator_alias_and_values(tp: Any, discriminator_key: str) -> Tuple[str, Tuple[str, ...]]:
"""
Get alias and all valid values in the `Literal` type of the discriminator field
`tp` can be a `BaseModel` class or directly an `Annotated` `Union` of several of them.
"""
is_root_model = getattr(tp, '__custom_root_type__', False)
if get_origin(tp) is Annotated:
tp = get_args(tp)[0]
if hasattr(tp, '__pydantic_model__'):
tp = tp.__pydantic_model__
if is_union(get_origin(tp)):
alias, all_values = _get_union_alias_and_all_values(tp, discriminator_key)
return alias, tuple(v for values in all_values for v in values)
elif is_root_model:
union_type = tp.__fields__[ROOT_KEY].type_
alias, all_values = _get_union_alias_and_all_values(union_type, discriminator_key)
if len(set(all_values)) > 1:
raise ConfigError(
f'Field {discriminator_key!r} is not the same for all submodels of {display_as_type(tp)!r}'
)
return alias, all_values[0]
else:
try:
t_discriminator_type = tp.__fields__[discriminator_key].type_
except AttributeError as e:
raise TypeError(f'Type {tp.__name__!r} is not a valid `BaseModel` or `dataclass`') from e
except KeyError as e:
raise ConfigError(f'Model {tp.__name__!r} needs a discriminator field for key {discriminator_key!r}') from e
if not is_literal_type(t_discriminator_type):
raise ConfigError(f'Field {discriminator_key!r} of model {tp.__name__!r} needs to be a `Literal`')
return tp.__fields__[discriminator_key].alias, all_literal_values(t_discriminator_type)
def _get_union_alias_and_all_values(
union_type: Type[Any], discriminator_key: str
) -> Tuple[str, Tuple[Tuple[str, ...], ...]]:
zipped_aliases_values = [get_discriminator_alias_and_values(t, discriminator_key) for t in get_args(union_type)]
# unzip: [('alias_a', ('v1', 'v2')), ('alias_b', ('v3',))] => [('alias_a', 'alias_b'), (('v1', 'v2'), ('v3',))]
all_aliases, all_values = zip(*zipped_aliases_values)
return get_unique_discriminator_alias(all_aliases, discriminator_key), all_values
KT = TypeVar('KT')
VT = TypeVar('VT')
if TYPE_CHECKING:
# Annoyingly, inheriting from both `MutableMapping` and `dict` breaks cython, hence this workaround
class LimitedDict(dict, MutableMapping[KT, VT]): # type: ignore[type-arg]
def __init__(self, size_limit: int = 1000):
...
else:
class LimitedDict(dict):
"""
Limit the size/length of a dict used for caching to avoid unbounded growth in memory usage.
Since the dict is ordered, and we always remove elements from the beginning, this is effectively a FIFO cache.
Annoyingly, inheriting from `MutableMapping` breaks cython.
"""
def __init__(self, size_limit: int = 1000):
self.size_limit = size_limit
super().__init__()
def __setitem__(self, __key: Any, __value: Any) -> None:
super().__setitem__(__key, __value)
if len(self) > self.size_limit:
excess = len(self) - self.size_limit + self.size_limit // 10
to_remove = list(self.keys())[:excess]
for key in to_remove:
del self[key]
def __class_getitem__(cls, *args: Any) -> Any:
# to avoid errors with 3.7
pass
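# Rough usage sketch (illustrative sizes): entries are evicted oldest-first once the limit
# is exceeded, so the cache never grows beyond size_limit.
#     cache = LimitedDict(size_limit=100)
#     for i in range(1000):
#         cache[i] = i
#     assert len(cache) <= 100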

lib/pydantic/v1/__init__.py Normal file
View file

@ -0,0 +1,131 @@
# flake8: noqa
from . import dataclasses
from .annotated_types import create_model_from_namedtuple, create_model_from_typeddict
from .class_validators import root_validator, validator
from .config import BaseConfig, ConfigDict, Extra
from .decorator import validate_arguments
from .env_settings import BaseSettings
from .error_wrappers import ValidationError
from .errors import *
from .fields import Field, PrivateAttr, Required
from .main import *
from .networks import *
from .parse import Protocol
from .tools import *
from .types import *
from .version import VERSION, compiled
__version__ = VERSION
# WARNING __all__ from .errors is not included here, it will be removed as an export here in v2
# please use "from pydantic.errors import ..." instead
__all__ = [
# annotated types utils
'create_model_from_namedtuple',
'create_model_from_typeddict',
# dataclasses
'dataclasses',
# class_validators
'root_validator',
'validator',
# config
'BaseConfig',
'ConfigDict',
'Extra',
# decorator
'validate_arguments',
# env_settings
'BaseSettings',
# error_wrappers
'ValidationError',
# fields
'Field',
'Required',
# main
'BaseModel',
'create_model',
'validate_model',
# network
'AnyUrl',
'AnyHttpUrl',
'FileUrl',
'HttpUrl',
'stricturl',
'EmailStr',
'NameEmail',
'IPvAnyAddress',
'IPvAnyInterface',
'IPvAnyNetwork',
'PostgresDsn',
'CockroachDsn',
'AmqpDsn',
'RedisDsn',
'MongoDsn',
'KafkaDsn',
'validate_email',
# parse
'Protocol',
# tools
'parse_file_as',
'parse_obj_as',
'parse_raw_as',
'schema_of',
'schema_json_of',
# types
'NoneStr',
'NoneBytes',
'StrBytes',
'NoneStrBytes',
'StrictStr',
'ConstrainedBytes',
'conbytes',
'ConstrainedList',
'conlist',
'ConstrainedSet',
'conset',
'ConstrainedFrozenSet',
'confrozenset',
'ConstrainedStr',
'constr',
'PyObject',
'ConstrainedInt',
'conint',
'PositiveInt',
'NegativeInt',
'NonNegativeInt',
'NonPositiveInt',
'ConstrainedFloat',
'confloat',
'PositiveFloat',
'NegativeFloat',
'NonNegativeFloat',
'NonPositiveFloat',
'FiniteFloat',
'ConstrainedDecimal',
'condecimal',
'ConstrainedDate',
'condate',
'UUID1',
'UUID3',
'UUID4',
'UUID5',
'FilePath',
'DirectoryPath',
'Json',
'JsonWrapper',
'SecretField',
'SecretStr',
'SecretBytes',
'StrictBool',
'StrictBytes',
'StrictInt',
'StrictFloat',
'PaymentCardNumber',
'PrivateAttr',
'ByteSize',
'PastDate',
'FutureDate',
# version
'compiled',
'VERSION',
]

View file

@ -10,7 +10,7 @@ Pydantic is installed. See also:
https://hypothesis.readthedocs.io/en/latest/strategies.html#registering-strategies-via-setuptools-entry-points
https://hypothesis.readthedocs.io/en/latest/data.html#hypothesis.strategies.register_type_strategy
https://hypothesis.readthedocs.io/en/latest/strategies.html#interaction-with-pytest-cov
- https://pydantic-docs.helpmanual.io/usage/types/#pydantic-types
+ https://docs.pydantic.dev/usage/types/#pydantic-types

Note that because our motivation is to *improve user experience*, the strategies
are always sound (never generate invalid data) but sacrifice completeness for
@ -46,7 +46,7 @@ from pydantic.utils import lenient_issubclass
#
# conlist() and conset() are unsupported for now, because the workarounds for
# Cython and Hypothesis to handle parametrized generic types are incompatible.
- # Once Cython can support 'normal' generics we'll revisit this.
+ # We are rethinking Hypothesis compatibility in Pydantic v2.

# Emails
try:
@ -168,6 +168,11 @@ st.register_type_strategy(pydantic.StrictBool, st.booleans())
st.register_type_strategy(pydantic.StrictStr, st.text())

+ # FutureDate, PastDate
+ st.register_type_strategy(pydantic.FutureDate, st.dates(min_value=datetime.date.today() + datetime.timedelta(days=1)))
+ st.register_type_strategy(pydantic.PastDate, st.dates(max_value=datetime.date.today() - datetime.timedelta(days=1)))

# Constrained-type resolver functions
#
# For these ones, we actually want to inspect the type in order to work out a

Some files were not shown because too many files have changed in this diff.