From 51196a7fb1b709ca81e56e7e3534d7c8f0914a90 Mon Sep 17 00:00:00 2001 From: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com> Date: Sun, 24 Mar 2024 17:55:12 -0700 Subject: [PATCH] Update cherrypy==18.9.0 --- lib/annotated_types/__init__.py | 396 ++ .../__init__.py => annotated_types/py.typed} | 0 lib/annotated_types/test_cases.py | 147 + lib/autocommand/autoasync.py | 8 +- lib/cherrypy/_cplogging.py | 6 +- lib/cherrypy/lib/cptools.py | 12 +- lib/cherrypy/lib/profiler.py | 4 +- lib/cherrypy/lib/reprconf.py | 2 +- lib/cherrypy/lib/static.py | 12 +- lib/cherrypy/process/wspbus.py | 2 +- lib/cherrypy/test/test_http.py | 15 + lib/cherrypy/test/test_iterator.py | 2 + lib/cherrypy/test/test_logging.py | 28 + lib/inflect/__init__.py | 180 +- lib/inflect/compat/__init__.py | 0 lib/inflect/compat/pydantic.py | 19 + lib/inflect/compat/pydantic1.py | 8 + .../__init__.py} | 152 +- lib/jaraco/collections/py.typed | 0 lib/jaraco/context.py | 75 + .../{functools.py => functools/__init__.py} | 155 +- lib/jaraco/functools/__init__.pyi | 128 + lib/jaraco/functools/py.typed | 0 lib/jaraco/text/__init__.py | 18 +- lib/jaraco/text/show-newlines.py | 4 +- lib/jaraco/text/strip-prefix.py | 21 + lib/more_itertools/__init__.py | 2 +- lib/more_itertools/more.py | 143 +- lib/more_itertools/more.pyi | 23 +- lib/more_itertools/recipes.py | 107 +- lib/more_itertools/recipes.pyi | 16 +- lib/pydantic/__init__.py | 373 +- lib/pydantic/_internal/__init__.py | 0 lib/pydantic/_internal/_config.py | 322 ++ lib/pydantic/_internal/_core_metadata.py | 92 + lib/pydantic/_internal/_core_utils.py | 570 +++ lib/pydantic/_internal/_dataclasses.py | 225 + lib/pydantic/_internal/_decorators.py | 791 ++++ lib/pydantic/_internal/_decorators_v1.py | 181 + .../_internal/_discriminated_union.py | 506 +++ lib/pydantic/_internal/_fields.py | 319 ++ lib/pydantic/_internal/_forward_ref.py | 23 + lib/pydantic/_internal/_generate_schema.py | 2231 +++++++++ lib/pydantic/_internal/_generics.py | 517 +++ lib/pydantic/_internal/_git.py | 26 + lib/pydantic/_internal/_internal_dataclass.py | 10 + .../_internal/_known_annotated_metadata.py | 410 ++ lib/pydantic/_internal/_mock_val_ser.py | 140 + lib/pydantic/_internal/_model_construction.py | 637 +++ lib/pydantic/_internal/_repr.py | 117 + .../_internal/_schema_generation_shared.py | 124 + lib/pydantic/_internal/_signature.py | 164 + lib/pydantic/_internal/_std_types_schema.py | 714 +++ lib/pydantic/_internal/_typing_extra.py | 469 ++ lib/pydantic/_internal/_utils.py | 362 ++ lib/pydantic/_internal/_validate_call.py | 84 + lib/pydantic/_internal/_validators.py | 278 ++ lib/pydantic/_migration.py | 308 ++ lib/pydantic/alias_generators.py | 50 + lib/pydantic/aliases.py | 112 + lib/pydantic/annotated_handlers.py | 120 + lib/pydantic/class_validators.py | 344 +- lib/pydantic/color.py | 285 +- lib/pydantic/config.py | 1048 ++++- lib/pydantic/dataclasses.py | 632 +-- lib/pydantic/datetime_parse.py | 250 +- lib/pydantic/decorator.py | 266 +- lib/pydantic/deprecated/__init__.py | 0 lib/pydantic/deprecated/class_validators.py | 253 ++ lib/pydantic/deprecated/config.py | 72 + lib/pydantic/deprecated/copy_internals.py | 224 + lib/pydantic/deprecated/decorator.py | 279 ++ lib/pydantic/deprecated/json.py | 140 + lib/pydantic/deprecated/parse.py | 80 + lib/pydantic/deprecated/tools.py | 103 + lib/pydantic/env_settings.py | 348 +- lib/pydantic/error_wrappers.py | 164 +- lib/pydantic/errors.py | 728 +-- lib/pydantic/fields.py | 2201 +++++---- lib/pydantic/functional_serializers.py | 395 ++ 
lib/pydantic/functional_validators.py | 706 +++ lib/pydantic/generics.py | 366 +- lib/pydantic/json.py | 114 +- lib/pydantic/json_schema.py | 2425 ++++++++++ lib/pydantic/main.py | 2273 ++++++---- lib/pydantic/mypy.py | 1199 +++-- lib/pydantic/networks.py | 1202 +++-- lib/pydantic/parse.py | 68 +- lib/pydantic/plugin/__init__.py | 170 + lib/pydantic/plugin/_loader.py | 50 + lib/pydantic/plugin/_schema_validator.py | 138 + lib/pydantic/root_model.py | 149 + lib/pydantic/schema.py | 1155 +---- lib/pydantic/tools.py | 94 +- lib/pydantic/type_adapter.py | 460 ++ lib/pydantic/types.py | 3382 ++++++++++---- lib/pydantic/typing.py | 604 +-- lib/pydantic/utils.py | 843 +--- lib/pydantic/v1/__init__.py | 131 + lib/pydantic/{ => v1}/_hypothesis_plugin.py | 9 +- lib/pydantic/{ => v1}/annotated_types.py | 0 lib/pydantic/v1/class_validators.py | 361 ++ lib/pydantic/v1/color.py | 494 ++ lib/pydantic/v1/config.py | 191 + lib/pydantic/v1/dataclasses.py | 500 +++ lib/pydantic/v1/datetime_parse.py | 248 + lib/pydantic/v1/decorator.py | 264 ++ lib/pydantic/v1/env_settings.py | 350 ++ lib/pydantic/v1/error_wrappers.py | 161 + lib/pydantic/v1/errors.py | 646 +++ lib/pydantic/v1/fields.py | 1253 ++++++ lib/pydantic/v1/generics.py | 400 ++ lib/pydantic/v1/json.py | 112 + lib/pydantic/v1/main.py | 1107 +++++ lib/pydantic/v1/mypy.py | 944 ++++ lib/pydantic/v1/networks.py | 747 ++++ lib/pydantic/v1/parse.py | 66 + lib/pydantic/v1/py.typed | 0 lib/pydantic/v1/schema.py | 1163 +++++ lib/pydantic/v1/tools.py | 92 + lib/pydantic/v1/types.py | 1205 +++++ lib/pydantic/v1/typing.py | 603 +++ lib/pydantic/v1/utils.py | 803 ++++ lib/pydantic/v1/validators.py | 765 ++++ lib/pydantic/v1/version.py | 38 + lib/pydantic/validate_call_decorator.py | 67 + lib/pydantic/validators.py | 767 +--- lib/pydantic/version.py | 84 +- lib/pydantic/warnings.py | 58 + lib/pydantic_core/__init__.py | 139 + lib/pydantic_core/_pydantic_core.pyi | 882 ++++ lib/pydantic_core/core_schema.py | 3980 +++++++++++++++++ lib/pydantic_core/py.typed | 0 lib/tempora/timing.py | 13 +- lib/typing_extensions.py | 2142 ++++++--- lib/zc/lockfile/__init__.py | 23 +- lib/zc/lockfile/tests.py | 46 +- 137 files changed, 44442 insertions(+), 11582 deletions(-) create mode 100644 lib/annotated_types/__init__.py rename lib/{jaraco/classes/__init__.py => annotated_types/py.typed} (100%) create mode 100644 lib/annotated_types/test_cases.py create mode 100644 lib/inflect/compat/__init__.py create mode 100644 lib/inflect/compat/pydantic.py create mode 100644 lib/inflect/compat/pydantic1.py rename lib/jaraco/{collections.py => collections/__init__.py} (88%) create mode 100644 lib/jaraco/collections/py.typed rename lib/jaraco/{functools.py => functools/__init__.py} (84%) create mode 100644 lib/jaraco/functools/__init__.pyi create mode 100644 lib/jaraco/functools/py.typed create mode 100644 lib/jaraco/text/strip-prefix.py create mode 100644 lib/pydantic/_internal/__init__.py create mode 100644 lib/pydantic/_internal/_config.py create mode 100644 lib/pydantic/_internal/_core_metadata.py create mode 100644 lib/pydantic/_internal/_core_utils.py create mode 100644 lib/pydantic/_internal/_dataclasses.py create mode 100644 lib/pydantic/_internal/_decorators.py create mode 100644 lib/pydantic/_internal/_decorators_v1.py create mode 100644 lib/pydantic/_internal/_discriminated_union.py create mode 100644 lib/pydantic/_internal/_fields.py create mode 100644 lib/pydantic/_internal/_forward_ref.py create mode 100644 lib/pydantic/_internal/_generate_schema.py create mode 100644 
lib/pydantic/_internal/_generics.py create mode 100644 lib/pydantic/_internal/_git.py create mode 100644 lib/pydantic/_internal/_internal_dataclass.py create mode 100644 lib/pydantic/_internal/_known_annotated_metadata.py create mode 100644 lib/pydantic/_internal/_mock_val_ser.py create mode 100644 lib/pydantic/_internal/_model_construction.py create mode 100644 lib/pydantic/_internal/_repr.py create mode 100644 lib/pydantic/_internal/_schema_generation_shared.py create mode 100644 lib/pydantic/_internal/_signature.py create mode 100644 lib/pydantic/_internal/_std_types_schema.py create mode 100644 lib/pydantic/_internal/_typing_extra.py create mode 100644 lib/pydantic/_internal/_utils.py create mode 100644 lib/pydantic/_internal/_validate_call.py create mode 100644 lib/pydantic/_internal/_validators.py create mode 100644 lib/pydantic/_migration.py create mode 100644 lib/pydantic/alias_generators.py create mode 100644 lib/pydantic/aliases.py create mode 100644 lib/pydantic/annotated_handlers.py create mode 100644 lib/pydantic/deprecated/__init__.py create mode 100644 lib/pydantic/deprecated/class_validators.py create mode 100644 lib/pydantic/deprecated/config.py create mode 100644 lib/pydantic/deprecated/copy_internals.py create mode 100644 lib/pydantic/deprecated/decorator.py create mode 100644 lib/pydantic/deprecated/json.py create mode 100644 lib/pydantic/deprecated/parse.py create mode 100644 lib/pydantic/deprecated/tools.py create mode 100644 lib/pydantic/functional_serializers.py create mode 100644 lib/pydantic/functional_validators.py create mode 100644 lib/pydantic/json_schema.py create mode 100644 lib/pydantic/plugin/__init__.py create mode 100644 lib/pydantic/plugin/_loader.py create mode 100644 lib/pydantic/plugin/_schema_validator.py create mode 100644 lib/pydantic/root_model.py create mode 100644 lib/pydantic/type_adapter.py create mode 100644 lib/pydantic/v1/__init__.py rename lib/pydantic/{ => v1}/_hypothesis_plugin.py (97%) rename lib/pydantic/{ => v1}/annotated_types.py (100%) create mode 100644 lib/pydantic/v1/class_validators.py create mode 100644 lib/pydantic/v1/color.py create mode 100644 lib/pydantic/v1/config.py create mode 100644 lib/pydantic/v1/dataclasses.py create mode 100644 lib/pydantic/v1/datetime_parse.py create mode 100644 lib/pydantic/v1/decorator.py create mode 100644 lib/pydantic/v1/env_settings.py create mode 100644 lib/pydantic/v1/error_wrappers.py create mode 100644 lib/pydantic/v1/errors.py create mode 100644 lib/pydantic/v1/fields.py create mode 100644 lib/pydantic/v1/generics.py create mode 100644 lib/pydantic/v1/json.py create mode 100644 lib/pydantic/v1/main.py create mode 100644 lib/pydantic/v1/mypy.py create mode 100644 lib/pydantic/v1/networks.py create mode 100644 lib/pydantic/v1/parse.py create mode 100644 lib/pydantic/v1/py.typed create mode 100644 lib/pydantic/v1/schema.py create mode 100644 lib/pydantic/v1/tools.py create mode 100644 lib/pydantic/v1/types.py create mode 100644 lib/pydantic/v1/typing.py create mode 100644 lib/pydantic/v1/utils.py create mode 100644 lib/pydantic/v1/validators.py create mode 100644 lib/pydantic/v1/version.py create mode 100644 lib/pydantic/validate_call_decorator.py create mode 100644 lib/pydantic/warnings.py create mode 100644 lib/pydantic_core/__init__.py create mode 100644 lib/pydantic_core/_pydantic_core.pyi create mode 100644 lib/pydantic_core/core_schema.py create mode 100644 lib/pydantic_core/py.typed diff --git a/lib/annotated_types/__init__.py b/lib/annotated_types/__init__.py new file mode 100644 
index 00000000..2f989504 --- /dev/null +++ b/lib/annotated_types/__init__.py @@ -0,0 +1,396 @@ +import math +import sys +from dataclasses import dataclass +from datetime import timezone +from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, SupportsFloat, SupportsIndex, TypeVar, Union + +if sys.version_info < (3, 8): + from typing_extensions import Protocol, runtime_checkable +else: + from typing import Protocol, runtime_checkable + +if sys.version_info < (3, 9): + from typing_extensions import Annotated, Literal +else: + from typing import Annotated, Literal + +if sys.version_info < (3, 10): + EllipsisType = type(Ellipsis) + KW_ONLY = {} + SLOTS = {} +else: + from types import EllipsisType + + KW_ONLY = {"kw_only": True} + SLOTS = {"slots": True} + + +__all__ = ( + 'BaseMetadata', + 'GroupedMetadata', + 'Gt', + 'Ge', + 'Lt', + 'Le', + 'Interval', + 'MultipleOf', + 'MinLen', + 'MaxLen', + 'Len', + 'Timezone', + 'Predicate', + 'LowerCase', + 'UpperCase', + 'IsDigits', + 'IsFinite', + 'IsNotFinite', + 'IsNan', + 'IsNotNan', + 'IsInfinite', + 'IsNotInfinite', + 'doc', + 'DocInfo', + '__version__', +) + +__version__ = '0.6.0' + + +T = TypeVar('T') + + +# arguments that start with __ are considered +# positional only +# see https://peps.python.org/pep-0484/#positional-only-arguments + + +class SupportsGt(Protocol): + def __gt__(self: T, __other: T) -> bool: + ... + + +class SupportsGe(Protocol): + def __ge__(self: T, __other: T) -> bool: + ... + + +class SupportsLt(Protocol): + def __lt__(self: T, __other: T) -> bool: + ... + + +class SupportsLe(Protocol): + def __le__(self: T, __other: T) -> bool: + ... + + +class SupportsMod(Protocol): + def __mod__(self: T, __other: T) -> T: + ... + + +class SupportsDiv(Protocol): + def __div__(self: T, __other: T) -> T: + ... + + +class BaseMetadata: + """Base class for all metadata. + + This exists mainly so that implementers + can do `isinstance(..., BaseMetadata)` while traversing field annotations. + """ + + __slots__ = () + + +@dataclass(frozen=True, **SLOTS) +class Gt(BaseMetadata): + """Gt(gt=x) implies that the value must be greater than x. + + It can be used with any type that supports the ``>`` operator, + including numbers, dates and times, strings, sets, and so on. + """ + + gt: SupportsGt + + +@dataclass(frozen=True, **SLOTS) +class Ge(BaseMetadata): + """Ge(ge=x) implies that the value must be greater than or equal to x. + + It can be used with any type that supports the ``>=`` operator, + including numbers, dates and times, strings, sets, and so on. + """ + + ge: SupportsGe + + +@dataclass(frozen=True, **SLOTS) +class Lt(BaseMetadata): + """Lt(lt=x) implies that the value must be less than x. + + It can be used with any type that supports the ``<`` operator, + including numbers, dates and times, strings, sets, and so on. + """ + + lt: SupportsLt + + +@dataclass(frozen=True, **SLOTS) +class Le(BaseMetadata): + """Le(le=x) implies that the value must be less than or equal to x. + + It can be used with any type that supports the ``<=`` operator, + including numbers, dates and times, strings, sets, and so on. + """ + + le: SupportsLe + + +@runtime_checkable +class GroupedMetadata(Protocol): + """A grouping of multiple BaseMetadata objects. + + `GroupedMetadata` on its own is not metadata and has no meaning. + All of the constraints and metadata should be fully expressible + in terms of the `BaseMetadata` instances returned by `GroupedMetadata.__iter__()`.
+ + Concrete implementations should override `GroupedMetadata.__iter__()` + to add their own metadata. + For example: + + >>> @dataclass + >>> class Field(GroupedMetadata): + >>> gt: float | None = None + >>> description: str | None = None + ... + >>> def __iter__(self) -> Iterable[BaseMetadata]: + >>> if self.gt is not None: + >>> yield Gt(self.gt) + >>> if self.description is not None: + >>> yield Description(self.gt) + + Also see the implementation of `Interval` below for an example. + + Parsers should recognize this and unpack it so that it can be used + both with and without unpacking: + + - `Annotated[int, Field(...)]` (parser must unpack Field) + - `Annotated[int, *Field(...)]` (PEP-646) + """ # noqa: trailing-whitespace + + @property + def __is_annotated_types_grouped_metadata__(self) -> Literal[True]: + return True + + def __iter__(self) -> Iterator[BaseMetadata]: + ... + + if not TYPE_CHECKING: + __slots__ = () # allow subclasses to use slots + + def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None: + # Basic ABC like functionality without the complexity of an ABC + super().__init_subclass__(*args, **kwargs) + if cls.__iter__ is GroupedMetadata.__iter__: + raise TypeError("Can't subclass GroupedMetadata without implementing __iter__") + + def __iter__(self) -> Iterator[BaseMetadata]: # noqa: F811 + raise NotImplementedError # more helpful than "None has no attribute..." type errors + + +@dataclass(frozen=True, **KW_ONLY, **SLOTS) +class Interval(GroupedMetadata): + """Interval can express inclusive or exclusive bounds with a single object. + + It accepts keyword arguments ``gt``, ``ge``, ``lt``, and/or ``le``, which + are interpreted the same way as the single-bound constraints. + """ + + gt: Union[SupportsGt, None] = None + ge: Union[SupportsGe, None] = None + lt: Union[SupportsLt, None] = None + le: Union[SupportsLe, None] = None + + def __iter__(self) -> Iterator[BaseMetadata]: + """Unpack an Interval into zero or more single-bounds.""" + if self.gt is not None: + yield Gt(self.gt) + if self.ge is not None: + yield Ge(self.ge) + if self.lt is not None: + yield Lt(self.lt) + if self.le is not None: + yield Le(self.le) + + +@dataclass(frozen=True, **SLOTS) +class MultipleOf(BaseMetadata): + """MultipleOf(multiple_of=x) might be interpreted in two ways: + + 1. Python semantics, implying ``value % multiple_of == 0``, or + 2. JSONschema semantics, where ``int(value / multiple_of) == value / multiple_of`` + + We encourage users to be aware of these two common interpretations, + and libraries to carefully document which they implement. + """ + + multiple_of: Union[SupportsDiv, SupportsMod] + + +@dataclass(frozen=True, **SLOTS) +class MinLen(BaseMetadata): + """ + MinLen() implies minimum inclusive length, + e.g. ``len(value) >= min_length``. + """ + + min_length: Annotated[int, Ge(0)] + + +@dataclass(frozen=True, **SLOTS) +class MaxLen(BaseMetadata): + """ + MaxLen() implies maximum inclusive length, + e.g. ``len(value) <= max_length``. + """ + + max_length: Annotated[int, Ge(0)] + + +@dataclass(frozen=True, **SLOTS) +class Len(GroupedMetadata): + """ + Len() implies that ``min_length <= len(value) <= max_length``. + + Upper bound may be omitted or ``None`` to indicate no upper length bound. 
+ """ + + min_length: Annotated[int, Ge(0)] = 0 + max_length: Optional[Annotated[int, Ge(0)]] = None + + def __iter__(self) -> Iterator[BaseMetadata]: + """Unpack a Len into zone or more single-bounds.""" + if self.min_length > 0: + yield MinLen(self.min_length) + if self.max_length is not None: + yield MaxLen(self.max_length) + + +@dataclass(frozen=True, **SLOTS) +class Timezone(BaseMetadata): + """Timezone(tz=...) requires a datetime to be aware (or ``tz=None``, naive). + + ``Annotated[datetime, Timezone(None)]`` must be a naive datetime. + ``Timezone[...]`` (the ellipsis literal) expresses that the datetime must be + tz-aware but any timezone is allowed. + + You may also pass a specific timezone string or timezone object such as + ``Timezone(timezone.utc)`` or ``Timezone("Africa/Abidjan")`` to express that + you only allow a specific timezone, though we note that this is often + a symptom of poor design. + """ + + tz: Union[str, timezone, EllipsisType, None] + + +@dataclass(frozen=True, **SLOTS) +class Predicate(BaseMetadata): + """``Predicate(func: Callable)`` implies `func(value)` is truthy for valid values. + + Users should prefer statically inspectable metadata, but if you need the full + power and flexibility of arbitrary runtime predicates... here it is. + + We provide a few predefined predicates for common string constraints: + ``IsLower = Predicate(str.islower)``, ``IsUpper = Predicate(str.isupper)``, and + ``IsDigit = Predicate(str.isdigit)``. Users are encouraged to use methods which + can be given special handling, and avoid indirection like ``lambda s: s.lower()``. + + Some libraries might have special logic to handle certain predicates, e.g. by + checking for `str.isdigit` and using its presence to both call custom logic to + enforce digit-only strings, and customise some generated external schema. + + We do not specify what behaviour should be expected for predicates that raise + an exception. For example `Annotated[int, Predicate(str.isdigit)]` might silently + skip invalid constraints, or statically raise an error; or it might try calling it + and then propogate or discard the resulting exception. + """ + + func: Callable[[Any], bool] + + +@dataclass +class Not: + func: Callable[[Any], bool] + + def __call__(self, __v: Any) -> bool: + return not self.func(__v) + + +_StrType = TypeVar("_StrType", bound=str) + +LowerCase = Annotated[_StrType, Predicate(str.islower)] +""" +Return True if the string is a lowercase string, False otherwise. + +A string is lowercase if all cased characters in the string are lowercase and there is at least one cased character in the string. +""" # noqa: E501 +UpperCase = Annotated[_StrType, Predicate(str.isupper)] +""" +Return True if the string is an uppercase string, False otherwise. + +A string is uppercase if all cased characters in the string are uppercase and there is at least one cased character in the string. +""" # noqa: E501 +IsDigits = Annotated[_StrType, Predicate(str.isdigit)] +""" +Return True if the string is a digit string, False otherwise. + +A string is a digit string if all characters in the string are digits and there is at least one character in the string. +""" # noqa: E501 +IsAscii = Annotated[_StrType, Predicate(str.isascii)] +""" +Return True if all characters in the string are ASCII, False otherwise. + +ASCII characters have code points in the range U+0000-U+007F. Empty string is ASCII too. 
+""" + +_NumericType = TypeVar('_NumericType', bound=Union[SupportsFloat, SupportsIndex]) +IsFinite = Annotated[_NumericType, Predicate(math.isfinite)] +"""Return True if x is neither an infinity nor a NaN, and False otherwise.""" +IsNotFinite = Annotated[_NumericType, Predicate(Not(math.isfinite))] +"""Return True if x is one of infinity or NaN, and False otherwise""" +IsNan = Annotated[_NumericType, Predicate(math.isnan)] +"""Return True if x is a NaN (not a number), and False otherwise.""" +IsNotNan = Annotated[_NumericType, Predicate(Not(math.isnan))] +"""Return True if x is anything but NaN (not a number), and False otherwise.""" +IsInfinite = Annotated[_NumericType, Predicate(math.isinf)] +"""Return True if x is a positive or negative infinity, and False otherwise.""" +IsNotInfinite = Annotated[_NumericType, Predicate(Not(math.isinf))] +"""Return True if x is neither a positive or negative infinity, and False otherwise.""" + +try: + from typing_extensions import DocInfo, doc # type: ignore [attr-defined] +except ImportError: + + @dataclass(frozen=True, **SLOTS) + class DocInfo: # type: ignore [no-redef] + """ " + The return value of doc(), mainly to be used by tools that want to extract the + Annotated documentation at runtime. + """ + + documentation: str + """The documentation string passed to doc().""" + + def doc( + documentation: str, + ) -> DocInfo: + """ + Add documentation to a type annotation inside of Annotated. + + For example: + + >>> def hi(name: Annotated[int, doc("The name of the user")]) -> None: ... + """ + return DocInfo(documentation) diff --git a/lib/jaraco/classes/__init__.py b/lib/annotated_types/py.typed similarity index 100% rename from lib/jaraco/classes/__init__.py rename to lib/annotated_types/py.typed diff --git a/lib/annotated_types/test_cases.py b/lib/annotated_types/test_cases.py new file mode 100644 index 00000000..f54df700 --- /dev/null +++ b/lib/annotated_types/test_cases.py @@ -0,0 +1,147 @@ +import math +import sys +from datetime import date, datetime, timedelta, timezone +from decimal import Decimal +from typing import Any, Dict, Iterable, Iterator, List, NamedTuple, Set, Tuple + +if sys.version_info < (3, 9): + from typing_extensions import Annotated +else: + from typing import Annotated + +import annotated_types as at + + +class Case(NamedTuple): + """ + A test case for `annotated_types`. 
+ """ + + annotation: Any + valid_cases: Iterable[Any] + invalid_cases: Iterable[Any] + + +def cases() -> Iterable[Case]: + # Gt, Ge, Lt, Le + yield Case(Annotated[int, at.Gt(4)], (5, 6, 1000), (4, 0, -1)) + yield Case(Annotated[float, at.Gt(0.5)], (0.6, 0.7, 0.8, 0.9), (0.5, 0.0, -0.1)) + yield Case( + Annotated[datetime, at.Gt(datetime(2000, 1, 1))], + [datetime(2000, 1, 2), datetime(2000, 1, 3)], + [datetime(2000, 1, 1), datetime(1999, 12, 31)], + ) + yield Case( + Annotated[datetime, at.Gt(date(2000, 1, 1))], + [date(2000, 1, 2), date(2000, 1, 3)], + [date(2000, 1, 1), date(1999, 12, 31)], + ) + yield Case( + Annotated[datetime, at.Gt(Decimal('1.123'))], + [Decimal('1.1231'), Decimal('123')], + [Decimal('1.123'), Decimal('0')], + ) + + yield Case(Annotated[int, at.Ge(4)], (4, 5, 6, 1000, 4), (0, -1)) + yield Case(Annotated[float, at.Ge(0.5)], (0.5, 0.6, 0.7, 0.8, 0.9), (0.4, 0.0, -0.1)) + yield Case( + Annotated[datetime, at.Ge(datetime(2000, 1, 1))], + [datetime(2000, 1, 2), datetime(2000, 1, 3)], + [datetime(1998, 1, 1), datetime(1999, 12, 31)], + ) + + yield Case(Annotated[int, at.Lt(4)], (0, -1), (4, 5, 6, 1000, 4)) + yield Case(Annotated[float, at.Lt(0.5)], (0.4, 0.0, -0.1), (0.5, 0.6, 0.7, 0.8, 0.9)) + yield Case( + Annotated[datetime, at.Lt(datetime(2000, 1, 1))], + [datetime(1999, 12, 31), datetime(1999, 12, 31)], + [datetime(2000, 1, 2), datetime(2000, 1, 3)], + ) + + yield Case(Annotated[int, at.Le(4)], (4, 0, -1), (5, 6, 1000)) + yield Case(Annotated[float, at.Le(0.5)], (0.5, 0.0, -0.1), (0.6, 0.7, 0.8, 0.9)) + yield Case( + Annotated[datetime, at.Le(datetime(2000, 1, 1))], + [datetime(2000, 1, 1), datetime(1999, 12, 31)], + [datetime(2000, 1, 2), datetime(2000, 1, 3)], + ) + + # Interval + yield Case(Annotated[int, at.Interval(gt=4)], (5, 6, 1000), (4, 0, -1)) + yield Case(Annotated[int, at.Interval(gt=4, lt=10)], (5, 6), (4, 10, 1000, 0, -1)) + yield Case(Annotated[float, at.Interval(ge=0.5, le=1)], (0.5, 0.9, 1), (0.49, 1.1)) + yield Case( + Annotated[datetime, at.Interval(gt=datetime(2000, 1, 1), le=datetime(2000, 1, 3))], + [datetime(2000, 1, 2), datetime(2000, 1, 3)], + [datetime(2000, 1, 1), datetime(2000, 1, 4)], + ) + + yield Case(Annotated[int, at.MultipleOf(multiple_of=3)], (0, 3, 9), (1, 2, 4)) + yield Case(Annotated[float, at.MultipleOf(multiple_of=0.5)], (0, 0.5, 1, 1.5), (0.4, 1.1)) + + # lengths + + yield Case(Annotated[str, at.MinLen(3)], ('123', '1234', 'x' * 10), ('', '1', '12')) + yield Case(Annotated[str, at.Len(3)], ('123', '1234', 'x' * 10), ('', '1', '12')) + yield Case(Annotated[List[int], at.MinLen(3)], ([1, 2, 3], [1, 2, 3, 4], [1] * 10), ([], [1], [1, 2])) + yield Case(Annotated[List[int], at.Len(3)], ([1, 2, 3], [1, 2, 3, 4], [1] * 10), ([], [1], [1, 2])) + + yield Case(Annotated[str, at.MaxLen(4)], ('', '1234'), ('12345', 'x' * 10)) + yield Case(Annotated[str, at.Len(0, 4)], ('', '1234'), ('12345', 'x' * 10)) + yield Case(Annotated[List[str], at.MaxLen(4)], ([], ['a', 'bcdef'], ['a', 'b', 'c']), (['a'] * 5, ['b'] * 10)) + yield Case(Annotated[List[str], at.Len(0, 4)], ([], ['a', 'bcdef'], ['a', 'b', 'c']), (['a'] * 5, ['b'] * 10)) + + yield Case(Annotated[str, at.Len(3, 5)], ('123', '12345'), ('', '1', '12', '123456', 'x' * 10)) + yield Case(Annotated[str, at.Len(3, 3)], ('123',), ('12', '1234')) + + yield Case(Annotated[Dict[int, int], at.Len(2, 3)], [{1: 1, 2: 2}], [{}, {1: 1}, {1: 1, 2: 2, 3: 3, 4: 4}]) + yield Case(Annotated[Set[int], at.Len(2, 3)], ({1, 2}, {1, 2, 3}), (set(), {1}, {1, 2, 3, 4})) + yield Case(Annotated[Tuple[int, ...], 
at.Len(2, 3)], ((1, 2), (1, 2, 3)), ((), (1,), (1, 2, 3, 4))) + + # Timezone + + yield Case( + Annotated[datetime, at.Timezone(None)], [datetime(2000, 1, 1)], [datetime(2000, 1, 1, tzinfo=timezone.utc)] + ) + yield Case( + Annotated[datetime, at.Timezone(...)], [datetime(2000, 1, 1, tzinfo=timezone.utc)], [datetime(2000, 1, 1)] + ) + yield Case( + Annotated[datetime, at.Timezone(timezone.utc)], + [datetime(2000, 1, 1, tzinfo=timezone.utc)], + [datetime(2000, 1, 1), datetime(2000, 1, 1, tzinfo=timezone(timedelta(hours=6)))], + ) + yield Case( + Annotated[datetime, at.Timezone('Europe/London')], + [datetime(2000, 1, 1, tzinfo=timezone(timedelta(0), name='Europe/London'))], + [datetime(2000, 1, 1), datetime(2000, 1, 1, tzinfo=timezone(timedelta(hours=6)))], + ) + + # predicate types + + yield Case(at.LowerCase[str], ['abc', 'foobar'], ['', 'A', 'Boom']) + yield Case(at.UpperCase[str], ['ABC', 'DEFO'], ['', 'a', 'abc', 'AbC']) + yield Case(at.IsDigits[str], ['123'], ['', 'ab', 'a1b2']) + yield Case(at.IsAscii[str], ['123', 'foo bar'], ['£100', '😊', 'whatever 👀']) + + yield Case(Annotated[int, at.Predicate(lambda x: x % 2 == 0)], [0, 2, 4], [1, 3, 5]) + + yield Case(at.IsFinite[float], [1.23], [math.nan, math.inf, -math.inf]) + yield Case(at.IsNotFinite[float], [math.nan, math.inf], [1.23]) + yield Case(at.IsNan[float], [math.nan], [1.23, math.inf]) + yield Case(at.IsNotNan[float], [1.23, math.inf], [math.nan]) + yield Case(at.IsInfinite[float], [math.inf], [math.nan, 1.23]) + yield Case(at.IsNotInfinite[float], [math.nan, 1.23], [math.inf]) + + # check stacked predicates + yield Case(at.IsInfinite[Annotated[float, at.Predicate(lambda x: x > 0)]], [math.inf], [-math.inf, 1.23, math.nan]) + + # doc + yield Case(Annotated[int, at.doc("A number")], [1, 2], []) + + # custom GroupedMetadata + class MyCustomGroupedMetadata(at.GroupedMetadata): + def __iter__(self) -> Iterator[at.Predicate]: + yield at.Predicate(lambda x: float(x).is_integer()) + + yield Case(Annotated[float, MyCustomGroupedMetadata()], [0, 2.0], [0.01, 1.5]) diff --git a/lib/autocommand/autoasync.py b/lib/autocommand/autoasync.py index 3c8ebdcf..688f7e05 100644 --- a/lib/autocommand/autoasync.py +++ b/lib/autocommand/autoasync.py @@ -20,7 +20,7 @@ from functools import wraps from inspect import signature -def _launch_forever_coro(coro, args, kwargs, loop): +async def _run_forever_coro(coro, args, kwargs, loop): ''' This helper function launches an async main function that was tagged with forever=True. There are two possibilities: @@ -48,7 +48,7 @@ def _launch_forever_coro(coro, args, kwargs, loop): # forever=True feature from autoasync at some point in the future. 
thing = coro(*args, **kwargs) if iscoroutine(thing): - loop.create_task(thing) + await thing def autoasync(coro=None, *, loop=None, forever=False, pass_loop=False): @@ -127,7 +127,9 @@ def autoasync(coro=None, *, loop=None, forever=False, pass_loop=False): args, kwargs = bound_args.args, bound_args.kwargs if forever: - _launch_forever_coro(coro, args, kwargs, local_loop) + local_loop.create_task(_run_forever_coro( + coro, args, kwargs, local_loop + )) local_loop.run_forever() else: return local_loop.run_until_complete(coro(*args, **kwargs)) diff --git a/lib/cherrypy/_cplogging.py b/lib/cherrypy/_cplogging.py index 151d3b40..bce1c87b 100644 --- a/lib/cherrypy/_cplogging.py +++ b/lib/cherrypy/_cplogging.py @@ -452,6 +452,6 @@ class WSGIErrorHandler(logging.Handler): class LazyRfc3339UtcTime(object): def __str__(self): - """Return now() in RFC3339 UTC Format.""" - now = datetime.datetime.now() - return now.isoformat('T') + 'Z' + """Return utcnow() in RFC3339 UTC Format.""" + iso_formatted_now = datetime.datetime.utcnow().isoformat('T') + return f'{iso_formatted_now!s}Z' diff --git a/lib/cherrypy/lib/cptools.py b/lib/cherrypy/lib/cptools.py index 613a8995..13b4c567 100644 --- a/lib/cherrypy/lib/cptools.py +++ b/lib/cherrypy/lib/cptools.py @@ -622,13 +622,15 @@ def autovary(ignore=None, debug=False): def convert_params(exception=ValueError, error=400): - """Convert request params based on function annotations, with error handling. + """Convert request params based on function annotations. - exception - Exception class to catch. + This function also processes errors that are subclasses of ``exception``. - status - The HTTP error code to return to the client on failure. + :param BaseException exception: Exception class to catch. + :type exception: BaseException + + :param error: The HTTP status code to return to the client on failure. + :type error: int """ request = cherrypy.serving.request types = request.handler.callable.__annotations__ diff --git a/lib/cherrypy/lib/profiler.py b/lib/cherrypy/lib/profiler.py index fccf2eb8..7182278a 100644 --- a/lib/cherrypy/lib/profiler.py +++ b/lib/cherrypy/lib/profiler.py @@ -47,7 +47,9 @@ try: import pstats def new_func_strip_path(func_name): - """Make profiler output more readable by adding `__init__` modules' parents + """Add ``__init__`` modules' parents. + + This makes the profiler output more readable. 
""" filename, line, name = func_name if filename.endswith('__init__.py'): diff --git a/lib/cherrypy/lib/reprconf.py b/lib/cherrypy/lib/reprconf.py index 76381d7b..536b9417 100644 --- a/lib/cherrypy/lib/reprconf.py +++ b/lib/cherrypy/lib/reprconf.py @@ -188,7 +188,7 @@ class Parser(configparser.ConfigParser): def dict_from_file(self, file): if hasattr(file, 'read'): - self.readfp(file) + self.read_file(file) else: self.read(file) return self.as_dict() diff --git a/lib/cherrypy/lib/static.py b/lib/cherrypy/lib/static.py index 66a5a947..c1ad95f3 100644 --- a/lib/cherrypy/lib/static.py +++ b/lib/cherrypy/lib/static.py @@ -1,19 +1,18 @@ """Module with helpers for serving static files.""" +import mimetypes import os import platform import re import stat -import mimetypes -import urllib.parse import unicodedata - +import urllib.parse from email.generator import _make_boundary as make_boundary from io import UnsupportedOperation import cherrypy from cherrypy._cpcompat import ntob -from cherrypy.lib import cptools, httputil, file_generator_limited +from cherrypy.lib import cptools, file_generator_limited, httputil def _setup_mimetypes(): @@ -185,7 +184,10 @@ def serve_fileobj(fileobj, content_type=None, disposition=None, name=None, def _serve_fileobj(fileobj, content_type, content_length, debug=False): - """Internal. Set response.body to the given file object, perhaps ranged.""" + """Set ``response.body`` to the given file object, perhaps ranged. + + Internal helper. + """ response = cherrypy.serving.response # HTTP/1.0 didn't have Range/Accept-Ranges headers, or the 206 code diff --git a/lib/cherrypy/process/wspbus.py b/lib/cherrypy/process/wspbus.py index 1d2789b1..a60cd51e 100644 --- a/lib/cherrypy/process/wspbus.py +++ b/lib/cherrypy/process/wspbus.py @@ -494,7 +494,7 @@ class Bus(object): "Cannot reconstruct command from '-c'. " 'Ref: https://github.com/cherrypy/cherrypy/issues/1545') except AttributeError: - """It looks Py_GetArgcArgv is completely absent in some environments + """It looks Py_GetArgcArgv's completely absent in some environments It is known, that there's no Py_GetArgcArgv in MS Windows and ``ctypes`` module is completely absent in Google AppEngine diff --git a/lib/cherrypy/test/test_http.py b/lib/cherrypy/test/test_http.py index a955be43..9a7e9331 100644 --- a/lib/cherrypy/test/test_http.py +++ b/lib/cherrypy/test/test_http.py @@ -136,6 +136,9 @@ class HTTPTests(helper.CPWebCase): self.assertStatus(200) self.assertBody(b'Hello world!') + response.close() + c.close() + # Now send a message that has no Content-Length, but does send a body. # Verify that CP times out the socket and responds # with 411 Length Required. @@ -159,6 +162,9 @@ class HTTPTests(helper.CPWebCase): self.status = str(response.status) self.assertStatus(411) + response.close() + c.close() + def test_post_multipart(self): alphabet = 'abcdefghijklmnopqrstuvwxyz' # generate file contents for a large post @@ -184,6 +190,9 @@ class HTTPTests(helper.CPWebCase): parts = ['%s * 65536' % ch for ch in alphabet] self.assertBody(', '.join(parts)) + response.close() + c.close() + def test_post_filename_with_special_characters(self): """Testing that we can handle filenames with special characters. 
@@ -217,6 +226,9 @@ class HTTPTests(helper.CPWebCase): self.assertStatus(200) self.assertBody(fname) + response.close() + c.close() + def test_malformed_request_line(self): if getattr(cherrypy.server, 'using_apache', False): return self.skip('skipped due to known Apache differences...') @@ -264,6 +276,9 @@ class HTTPTests(helper.CPWebCase): self.body = response.fp.read(20) self.assertBody('Illegal header line.') + response.close() + c.close() + def test_http_over_https(self): if self.scheme != 'https': return self.skip('skipped (not running HTTPS)... ') diff --git a/lib/cherrypy/test/test_iterator.py b/lib/cherrypy/test/test_iterator.py index 6600a78d..5bad59be 100644 --- a/lib/cherrypy/test/test_iterator.py +++ b/lib/cherrypy/test/test_iterator.py @@ -150,6 +150,8 @@ class IteratorTest(helper.CPWebCase): self.assertStatus(200) self.assertBody('0') + itr_conn.close() + # Now we do the same check with streaming - some classes will # be automatically closed, while others cannot. stream_counts = {} diff --git a/lib/cherrypy/test/test_logging.py b/lib/cherrypy/test/test_logging.py index 2d4aa56f..49d41d0a 100644 --- a/lib/cherrypy/test/test_logging.py +++ b/lib/cherrypy/test/test_logging.py @@ -1,5 +1,6 @@ """Basic tests for the CherryPy core: request handling.""" +import datetime import logging from cheroot.test import webtest @@ -197,6 +198,33 @@ def test_custom_log_format(log_tracker, monkeypatch, server): ) +def test_utc_in_timez(monkeypatch): + """Test that ``LazyRfc3339UtcTime`` is rendered as ``str`` using UTC timestamp.""" + utcoffset8_local_time_in_naive_utc = ( + datetime.datetime( + year=2020, + month=1, + day=1, + hour=1, + minute=23, + second=45, + tzinfo=datetime.timezone(datetime.timedelta(hours=8)), + ) + .astimezone(datetime.timezone.utc) + .replace(tzinfo=None) + ) + + class mock_datetime: + @classmethod + def utcnow(cls): + return utcoffset8_local_time_in_naive_utc + + monkeypatch.setattr('datetime.datetime', mock_datetime) + rfc3339_utc_time = str(cherrypy._cplogging.LazyRfc3339UtcTime()) + expected_time = '2019-12-31T17:23:45Z' + assert rfc3339_utc_time == expected_time + + def test_timez_log_format(log_tracker, monkeypatch, server): """Test a customized access_log_format string, which is a feature of _cplogging.LogManager.access().""" diff --git a/lib/inflect/__init__.py b/lib/inflect/__init__.py index 78d2e33c..b638c6b8 100644 --- a/lib/inflect/__init__.py +++ b/lib/inflect/__init__.py @@ -3,8 +3,6 @@ inflect: english language inflection - correctly generate plurals, ordinals, indefinite articles - convert numbers to words -Copyright (C) 2010 Paul Dyson - Based upon the Perl module `Lingua::EN::Inflect `_. 
@@ -70,11 +68,16 @@ from typing import ( cast, Any, ) +from typing_extensions import Literal from numbers import Number -from pydantic import Field, validate_arguments -from pydantic.typing import Annotated +from pydantic import Field +from typing_extensions import Annotated + + +from .compat.pydantic1 import validate_call +from .compat.pydantic import same_method class UnknownClassicalModeError(Exception): @@ -105,14 +108,6 @@ class BadGenderError(Exception): pass -STDOUT_ON = False - - -def print3(txt: str) -> None: - if STDOUT_ON: - print(txt) - - def enclose(s: str) -> str: return f"(?:{s})" @@ -1727,66 +1722,44 @@ plverb_irregular_pres = { "is": "are", "was": "were", "were": "were", - "was": "were", - "have": "have", "have": "have", "has": "have", "do": "do", - "do": "do", "does": "do", } plverb_ambiguous_pres = { - "act": "act", "act": "act", "acts": "act", "blame": "blame", - "blame": "blame", "blames": "blame", "can": "can", - "can": "can", - "can": "can", "must": "must", - "must": "must", - "must": "must", - "fly": "fly", "fly": "fly", "flies": "fly", "copy": "copy", - "copy": "copy", "copies": "copy", "drink": "drink", - "drink": "drink", "drinks": "drink", "fight": "fight", - "fight": "fight", "fights": "fight", "fire": "fire", - "fire": "fire", "fires": "fire", "like": "like", - "like": "like", "likes": "like", "look": "look", - "look": "look", "looks": "look", "make": "make", - "make": "make", "makes": "make", "reach": "reach", - "reach": "reach", "reaches": "reach", "run": "run", - "run": "run", "runs": "run", "sink": "sink", - "sink": "sink", "sinks": "sink", "sleep": "sleep", - "sleep": "sleep", "sleeps": "sleep", "view": "view", - "view": "view", "views": "view", } @@ -1854,7 +1827,7 @@ pl_adj_poss_keys = re.compile(fr"^({enclose('|'.join(pl_adj_poss))})$", re.IGNOR A_abbrev = re.compile( r""" -(?! FJO | [HLMNS]Y. | RY[EO] | SQU +^(?! FJO | [HLMNS]Y. | RY[EO] | SQU | ( F[LR]? | [HL] | MN? | N | RH? | S[CHKLMNPTVW]? | X(YL)?) 
[AEIOU]) [FHLMNRSX][A-Z] """, @@ -2053,15 +2026,14 @@ Falsish = Any # ideally, falsish would only validate on bool(value) is False class engine: def __init__(self) -> None: - self.classical_dict = def_classical.copy() self.persistent_count: Optional[int] = None self.mill_count = 0 - self.pl_sb_user_defined: List[str] = [] - self.pl_v_user_defined: List[str] = [] - self.pl_adj_user_defined: List[str] = [] - self.si_sb_user_defined: List[str] = [] - self.A_a_user_defined: List[str] = [] + self.pl_sb_user_defined: List[Optional[Word]] = [] + self.pl_v_user_defined: List[Optional[Word]] = [] + self.pl_adj_user_defined: List[Optional[Word]] = [] + self.si_sb_user_defined: List[Optional[Word]] = [] + self.A_a_user_defined: List[Optional[Word]] = [] self.thegender = "neuter" self.__number_args: Optional[Dict[str, str]] = None @@ -2073,28 +2045,8 @@ class engine: def _number_args(self, val): self.__number_args = val - deprecated_methods = dict( - pl="plural", - plnoun="plural_noun", - plverb="plural_verb", - pladj="plural_adj", - sinoun="single_noun", - prespart="present_participle", - numwords="number_to_words", - plequal="compare", - plnounequal="compare_nouns", - plverbequal="compare_verbs", - pladjequal="compare_adjs", - wordlist="join", - ) - - def __getattr__(self, meth): - if meth in self.deprecated_methods: - print3(f"{meth}() deprecated, use {self.deprecated_methods[meth]}()") - raise DeprecationWarning - raise AttributeError - - def defnoun(self, singular: str, plural: str) -> int: + @validate_call + def defnoun(self, singular: Optional[Word], plural: Optional[Word]) -> int: """ Set the noun plural of singular to plural. @@ -2105,7 +2057,16 @@ class engine: self.si_sb_user_defined.extend((plural, singular)) return 1 - def defverb(self, s1: str, p1: str, s2: str, p2: str, s3: str, p3: str) -> int: + @validate_call + def defverb( + self, + s1: Optional[Word], + p1: Optional[Word], + s2: Optional[Word], + p2: Optional[Word], + s3: Optional[Word], + p3: Optional[Word], + ) -> int: """ Set the verb plurals for s1, s2 and s3 to p1, p2 and p3 respectively. @@ -2121,7 +2082,8 @@ class engine: self.pl_v_user_defined.extend((s1, p1, s2, p2, s3, p3)) return 1 - def defadj(self, singular: str, plural: str) -> int: + @validate_call + def defadj(self, singular: Optional[Word], plural: Optional[Word]) -> int: """ Set the adjective plural of singular to plural. @@ -2131,7 +2093,8 @@ class engine: self.pl_adj_user_defined.extend((singular, plural)) return 1 - def defa(self, pattern: str) -> int: + @validate_call + def defa(self, pattern: Optional[Word]) -> int: """ Define the indefinite article as 'a' for words matching pattern. @@ -2140,7 +2103,8 @@ class engine: self.A_a_user_defined.extend((pattern, "a")) return 1 - def defan(self, pattern: str) -> int: + @validate_call + def defan(self, pattern: Optional[Word]) -> int: """ Define the indefinite article as 'an' for words matching pattern. 
@@ -2149,7 +2113,7 @@ class engine: self.A_a_user_defined.extend((pattern, "an")) return 1 - def checkpat(self, pattern: Optional[str]) -> None: + def checkpat(self, pattern: Optional[Word]) -> None: """ check for errors in a regex pattern """ @@ -2158,16 +2122,15 @@ class engine: try: re.match(pattern, "") except re.error: - print3(f"\nBad user-defined singular pattern:\n\t{pattern}\n") - raise BadUserDefinedPatternError + raise BadUserDefinedPatternError(pattern) - def checkpatplural(self, pattern: str) -> None: + def checkpatplural(self, pattern: Optional[Word]) -> None: """ check for errors in a regex replace pattern """ return - @validate_arguments + @validate_call def ud_match(self, word: Word, wordlist: Sequence[Optional[Word]]) -> Optional[str]: for i in range(len(wordlist) - 2, -2, -2): # backwards through even elements mo = re.search(fr"^{wordlist[i]}$", word, re.IGNORECASE) @@ -2307,7 +2270,7 @@ class engine: # 0. PERFORM GENERAL INFLECTIONS IN A STRING - @validate_arguments + @validate_call def inflect(self, text: Word) -> str: """ Perform inflections in a string. @@ -2384,7 +2347,7 @@ class engine: else: return "", "", "" - @validate_arguments + @validate_call def plural(self, text: Word, count: Optional[Union[str, int, Any]] = None) -> str: """ Return the plural of text. @@ -2408,7 +2371,7 @@ class engine: ) return f"{pre}{plural}{post}" - @validate_arguments + @validate_call def plural_noun( self, text: Word, count: Optional[Union[str, int, Any]] = None ) -> str: @@ -2429,7 +2392,7 @@ class engine: plural = self.postprocess(word, self._plnoun(word, count)) return f"{pre}{plural}{post}" - @validate_arguments + @validate_call def plural_verb( self, text: Word, count: Optional[Union[str, int, Any]] = None ) -> str: @@ -2453,7 +2416,7 @@ class engine: ) return f"{pre}{plural}{post}" - @validate_arguments + @validate_call def plural_adj( self, text: Word, count: Optional[Union[str, int, Any]] = None ) -> str: @@ -2474,7 +2437,7 @@ class engine: plural = self.postprocess(word, self._pl_special_adjective(word, count) or word) return f"{pre}{plural}{post}" - @validate_arguments + @validate_call def compare(self, word1: Word, word2: Word) -> Union[str, bool]: """ compare word1 and word2 for equality regardless of plurality @@ -2497,15 +2460,15 @@ class engine: >>> compare('egg', '') Traceback (most recent call last): ... - pydantic.error_wrappers.ValidationError: 1 validation error for Compare - word2 - ensure this value has at least 1 characters... + pydantic...ValidationError: ... + ... + ...at least 1 characters... 
""" norms = self.plural_noun, self.plural_verb, self.plural_adj results = (self._plequal(word1, word2, norm) for norm in norms) return next(filter(None, results), False) - @validate_arguments + @validate_call def compare_nouns(self, word1: Word, word2: Word) -> Union[str, bool]: """ compare word1 and word2 for equality regardless of plurality @@ -2521,7 +2484,7 @@ class engine: """ return self._plequal(word1, word2, self.plural_noun) - @validate_arguments + @validate_call def compare_verbs(self, word1: Word, word2: Word) -> Union[str, bool]: """ compare word1 and word2 for equality regardless of plurality @@ -2537,7 +2500,7 @@ class engine: """ return self._plequal(word1, word2, self.plural_verb) - @validate_arguments + @validate_call def compare_adjs(self, word1: Word, word2: Word) -> Union[str, bool]: """ compare word1 and word2 for equality regardless of plurality @@ -2553,13 +2516,13 @@ class engine: """ return self._plequal(word1, word2, self.plural_adj) - @validate_arguments + @validate_call def singular_noun( self, text: Word, count: Optional[Union[int, str, Any]] = None, gender: Optional[str] = None, - ) -> Union[str, bool]: + ) -> Union[str, Literal[False]]: """ Return the singular of text, where text is a plural noun. @@ -2611,12 +2574,12 @@ class engine: return "s:p" self.classical_dict = classval.copy() - if pl == self.plural or pl == self.plural_noun: + if same_method(pl, self.plural) or same_method(pl, self.plural_noun): if self._pl_check_plurals_N(word1, word2): return "p:p" if self._pl_check_plurals_N(word2, word1): return "p:p" - if pl == self.plural or pl == self.plural_adj: + if same_method(pl, self.plural) or same_method(pl, self.plural_adj): if self._pl_check_plurals_adj(word1, word2): return "p:p" return False @@ -3266,11 +3229,11 @@ class engine: if words.last in si_sb_irregular_caps: llen = len(words.last) - return "{}{}".format(word[:-llen], si_sb_irregular_caps[words.last]) + return f"{word[:-llen]}{si_sb_irregular_caps[words.last]}" if words.last.lower() in si_sb_irregular: llen = len(words.last.lower()) - return "{}{}".format(word[:-llen], si_sb_irregular[words.last.lower()]) + return f"{word[:-llen]}{si_sb_irregular[words.last.lower()]}" dash_split = words.lowered.split("-") if (" ".join(dash_split[-2:])).lower() in si_sb_irregular_compound: @@ -3341,7 +3304,6 @@ class engine: # HANDLE INCOMPLETELY ASSIMILATED IMPORTS if self.classical_dict["ancient"]: - if words.lowered[-6:] == "trices": return word[:-3] + "x" if words.lowered[-4:] in ("eaux", "ieux"): @@ -3459,7 +3421,6 @@ class engine: # HANDLE ...o if words.lowered[-2:] == "os": - if words.last.lower() in si_sb_U_o_os_complete: return word[:-1] @@ -3489,7 +3450,7 @@ class engine: # ADJECTIVES - @validate_arguments + @validate_call def a(self, text: Word, count: Optional[Union[int, str, Any]] = 1) -> str: """ Return the appropriate indefinite article followed by text. @@ -3570,7 +3531,7 @@ class engine: # 2. TRANSLATE ZERO-QUANTIFIED $word TO "no plural($word)" - @validate_arguments + @validate_call def no(self, text: Word, count: Optional[Union[int, str]] = None) -> str: """ If count is 0, no, zero or nil, return 'no' followed by the plural @@ -3608,7 +3569,7 @@ class engine: # PARTICIPLES - @validate_arguments + @validate_call def present_participle(self, word: Word) -> str: """ Return the present participle for word. 
@@ -3627,31 +3588,31 @@ class engine: # NUMERICAL INFLECTIONS - @validate_arguments - def ordinal(self, num: Union[int, Word]) -> str: # noqa: C901 + @validate_call(config=dict(arbitrary_types_allowed=True)) + def ordinal(self, num: Union[Number, Word]) -> str: """ Return the ordinal of num. - num can be an integer or text - - e.g. ordinal(1) returns '1st' - ordinal('one') returns 'first' - + >>> ordinal = engine().ordinal + >>> ordinal(1) + '1st' + >>> ordinal('one') + 'first' """ if DIGIT.match(str(num)): - if isinstance(num, (int, float)): + if isinstance(num, (float, int)) and int(num) == num: n = int(num) else: if "." in str(num): try: # numbers after decimal, # so only need last one for ordinal - n = int(num[-1]) + n = int(str(num)[-1]) except ValueError: # ends with '.', so need to use whole string - n = int(num[:-1]) + n = int(str(num)[:-1]) else: - n = int(num) + n = int(num) # type: ignore try: post = nth[n % 100] except KeyError: @@ -3660,7 +3621,7 @@ class engine: else: # Mad props to Damian Conway (?) whose ordinal() # algorithm is type-bendy enough to foil MyPy - str_num: str = num # type: ignore[assignment] + str_num: str = num # type: ignore[assignment] mo = ordinal_suff.search(str_num) if mo: post = ordinal[mo.group(1)] @@ -3671,7 +3632,6 @@ class engine: def millfn(self, ind: int = 0) -> str: if ind > len(mill) - 1: - print3("number out of range") raise NumOutOfRangeError return mill[ind] @@ -3787,7 +3747,7 @@ class engine: num = ONE_DIGIT_WORD.sub(self.unitsub, num, 1) return num - @validate_arguments(config=dict(arbitrary_types_allowed=True)) # noqa: C901 + @validate_call(config=dict(arbitrary_types_allowed=True)) # noqa: C901 def number_to_words( # noqa: C901 self, num: Union[Number, Word], @@ -3939,7 +3899,7 @@ class engine: # Join words with commas and a trailing 'and' (when appropriate)... - @validate_arguments + @validate_call def join( self, words: Optional[Sequence[Word]], diff --git a/lib/inflect/compat/__init__.py b/lib/inflect/compat/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/inflect/compat/pydantic.py b/lib/inflect/compat/pydantic.py new file mode 100644 index 00000000..d777564a --- /dev/null +++ b/lib/inflect/compat/pydantic.py @@ -0,0 +1,19 @@ +class ValidateCallWrapperWrapper: + def __init__(self, wrapped): + self.orig = wrapped + + def __eq__(self, other): + return self.raw_function == other.raw_function + + @property + def raw_function(self): + return getattr(self.orig, 'raw_function') or self.orig + + +def same_method(m1, m2) -> bool: + """ + Return whether m1 and m2 are the same method. + + Workaround for pydantic/pydantic#6390. 
+ """ + return ValidateCallWrapperWrapper(m1) == ValidateCallWrapperWrapper(m2) diff --git a/lib/inflect/compat/pydantic1.py b/lib/inflect/compat/pydantic1.py new file mode 100644 index 00000000..8262fdcf --- /dev/null +++ b/lib/inflect/compat/pydantic1.py @@ -0,0 +1,8 @@ +try: + from pydantic import validate_call # type: ignore +except ImportError: + # Pydantic 1 + from pydantic import validate_arguments as validate_call # type: ignore + + +__all__ = ['validate_call'] diff --git a/lib/jaraco/collections.py b/lib/jaraco/collections/__init__.py similarity index 88% rename from lib/jaraco/collections.py rename to lib/jaraco/collections/__init__.py index db89b122..abedf002 100644 --- a/lib/jaraco/collections.py +++ b/lib/jaraco/collections/__init__.py @@ -5,23 +5,49 @@ import itertools import copy import functools import random +from collections.abc import Container, Iterable, Mapping +from typing import Callable, Union -from jaraco.classes.properties import NonDataProperty import jaraco.text +_Matchable = Union[Callable, Container, Iterable, re.Pattern] + + +def _dispatch(obj: _Matchable) -> Callable: + # can't rely on singledispatch for Union[Container, Iterable] + # due to ambiguity + # (https://peps.python.org/pep-0443/#abstract-base-classes). + if isinstance(obj, re.Pattern): + return obj.fullmatch + if not isinstance(obj, Callable): # type: ignore + if not isinstance(obj, Container): + obj = set(obj) # type: ignore + obj = obj.__contains__ + return obj # type: ignore + + class Projection(collections.abc.Mapping): """ Project a set of keys over a mapping >>> sample = {'a': 1, 'b': 2, 'c': 3} >>> prj = Projection(['a', 'c', 'd'], sample) - >>> prj == {'a': 1, 'c': 3} + >>> dict(prj) + {'a': 1, 'c': 3} + + Projection also accepts an iterable or callable or pattern. + + >>> iter_prj = Projection(iter('acd'), sample) + >>> call_prj = Projection(lambda k: ord(k) in (97, 99, 100), sample) + >>> pat_prj = Projection(re.compile(r'[acd]'), sample) + >>> prj == iter_prj == call_prj == pat_prj True Keys should only appear if they were specified and exist in the space. + Order is retained. - >>> sorted(list(prj.keys())) + >>> list(prj) ['a', 'c'] Attempting to access a key not in the projection @@ -36,119 +62,58 @@ class Projection(collections.abc.Mapping): >>> target = {'a': 2, 'b': 2} >>> target.update(prj) - >>> target == {'a': 1, 'b': 2, 'c': 3} - True + >>> target + {'a': 1, 'b': 2, 'c': 3} - Also note that Projection keeps a reference to the original dict, so - if you modify the original dict, that could modify the Projection. + Projection keeps a reference to the original dict, so + modifying the original dict may modify the Projection. >>> del sample['a'] >>> dict(prj) {'c': 3} """ - def __init__(self, keys, space): - self._keys = tuple(keys) + def __init__(self, keys: _Matchable, space: Mapping): + self._match = _dispatch(keys) self._space = space def __getitem__(self, key): - if key not in self._keys: + if not self._match(key): raise KeyError(key) return self._space[key] + def _keys_resolved(self): + return filter(self._match, self._space) + def __iter__(self): - return iter(set(self._keys).intersection(self._space)) + return self._keys_resolved() def __len__(self): - return len(tuple(iter(self))) + return len(tuple(self._keys_resolved())) -class DictFilter(collections.abc.Mapping): +class Mask(Projection): """ - Takes a dict, and simulates a sub-dict based on the keys. + The inverse of a :class:`Projection`, masking out keys. 
>>> sample = {'a': 1, 'b': 2, 'c': 3} - >>> filtered = DictFilter(sample, ['a', 'c']) - >>> filtered == {'a': 1, 'c': 3} - True - >>> set(filtered.values()) == {1, 3} - True - >>> set(filtered.items()) == {('a', 1), ('c', 3)} - True - - One can also filter by a regular expression pattern - - >>> sample['d'] = 4 - >>> sample['ef'] = 5 - - Here we filter for only single-character keys - - >>> filtered = DictFilter(sample, include_pattern='.$') - >>> filtered == {'a': 1, 'b': 2, 'c': 3, 'd': 4} - True - - >>> filtered['e'] - Traceback (most recent call last): - ... - KeyError: 'e' - - >>> 'e' in filtered - False - - Pattern is useful for excluding keys with a prefix. - - >>> filtered = DictFilter(sample, include_pattern=r'(?![ace])') - >>> dict(filtered) - {'b': 2, 'd': 4} - - Also note that DictFilter keeps a reference to the original dict, so - if you modify the original dict, that could modify the filtered dict. - - >>> del sample['d'] - >>> dict(filtered) + >>> msk = Mask(['a', 'c', 'd'], sample) + >>> dict(msk) {'b': 2} """ - def __init__(self, dict, include_keys=[], include_pattern=None): - self.dict = dict - self.specified_keys = set(include_keys) - if include_pattern is not None: - self.include_pattern = re.compile(include_pattern) - else: - # for performance, replace the pattern_keys property - self.pattern_keys = set() - - def get_pattern_keys(self): - keys = filter(self.include_pattern.match, self.dict.keys()) - return set(keys) - - pattern_keys = NonDataProperty(get_pattern_keys) - - @property - def include_keys(self): - return self.specified_keys | self.pattern_keys - - def __getitem__(self, i): - if i not in self.include_keys: - raise KeyError(i) - return self.dict[i] - - def __iter__(self): - return filter(self.include_keys.__contains__, self.dict.keys()) - - def __len__(self): - return len(list(self)) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # self._match = compose(operator.not_, self._match) + self._match = lambda key, orig=self._match: not orig(key) def dict_map(function, dictionary): """ - dict_map is much like the built-in function map. It takes a dictionary - and applys a function to the values of that dictionary, returning a - new dictionary with the mapped values in the original keys. + Return a new dict with function applied to values of dictionary. - >>> d = dict_map(lambda x:x+1, dict(a=1, b=2)) - >>> d == dict(a=2,b=3) - True + >>> dict_map(lambda x: x+1, dict(a=1, b=2)) + {'a': 2, 'b': 3} """ return dict((key, function(value)) for key, value in dictionary.items()) @@ -164,7 +129,7 @@ class RangeMap(dict): One may supply keyword parameters to be passed to the sort function used to sort keys (i.e. key, reverse) as sort_params. - Let's create a map that maps 1-3 -> 'a', 4-6 -> 'b' + Create a map that maps 1-3 -> 'a', 4-6 -> 'b' >>> r = RangeMap({3: 'a', 6: 'b'}) # boy, that was easy >>> r[1], r[2], r[3], r[4], r[5], r[6] @@ -176,7 +141,7 @@ class RangeMap(dict): >>> r[4.5] 'b' - But you'll notice that the way rangemap is defined, it must be open-ended + Notice that the way rangemap is defined, it must be open-ended on one side. 
>>> r[0] @@ -279,7 +244,7 @@ class RangeMap(dict): return (sorted_keys[RangeMap.first_item], sorted_keys[RangeMap.last_item]) # some special values for the RangeMap - undefined_value = type(str('RangeValueUndefined'), (), {})() + undefined_value = type('RangeValueUndefined', (), {})() class Item(int): "RangeMap Item" @@ -294,7 +259,7 @@ def __identity(x): def sorted_items(d, key=__identity, reverse=False): """ - Return the items of the dictionary sorted by the keys + Return the items of the dictionary sorted by the keys. >>> sample = dict(foo=20, bar=42, baz=10) >>> tuple(sorted_items(sample)) @@ -307,6 +272,7 @@ def sorted_items(d, key=__identity, reverse=False): >>> tuple(sorted_items(sample, reverse=True)) (('foo', 20), ('baz', 10), ('bar', 42)) """ + # wrap the key func so it operates on the first element of each item def pairkey_key(item): return key(item[0]) @@ -475,7 +441,7 @@ class ItemsAsAttributes: Mix-in class to enable a mapping object to provide items as attributes. - >>> C = type(str('C'), (dict, ItemsAsAttributes), dict()) + >>> C = type('C', (dict, ItemsAsAttributes), dict()) >>> i = C() >>> i['foo'] = 'bar' >>> i.foo @@ -504,7 +470,7 @@ class ItemsAsAttributes: >>> missing_func = lambda self, key: 'missing item' >>> C = type( - ... str('C'), + ... 'C', ... (dict, ItemsAsAttributes), ... dict(__missing__ = missing_func), ... ) diff --git a/lib/jaraco/collections/py.typed b/lib/jaraco/collections/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/lib/jaraco/context.py b/lib/jaraco/context.py index 87a4e3dc..b0d1ef37 100644 --- a/lib/jaraco/context.py +++ b/lib/jaraco/context.py @@ -5,10 +5,18 @@ import functools import tempfile import shutil import operator +import warnings @contextlib.contextmanager def pushd(dir): + """ + >>> tmp_path = getfixture('tmp_path') + >>> with pushd(tmp_path): + ... assert os.getcwd() == os.fspath(tmp_path) + >>> assert os.getcwd() != os.fspath(tmp_path) + """ + orig = os.getcwd() os.chdir(dir) try: @@ -29,6 +37,8 @@ def tarball_context(url, target_dir=None, runner=None, pushd=pushd): target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '') if runner is None: runner = functools.partial(subprocess.check_call, shell=True) + else: + warnings.warn("runner parameter is deprecated", DeprecationWarning) # In the tar command, use --strip-components=1 to strip the first path and # then # use -C to cause the files to be extracted to {target_dir}. This ensures @@ -48,6 +58,15 @@ def tarball_context(url, target_dir=None, runner=None, pushd=pushd): def infer_compression(url): """ Given a URL or filename, infer the compression code for tar. + + >>> infer_compression('http://foo/bar.tar.gz') + 'z' + >>> infer_compression('http://foo/bar.tgz') + 'z' + >>> infer_compression('file.bz') + 'j' + >>> infer_compression('file.xz') + 'J' """ # cheat and just assume it's the last two characters compression_indicator = url[-2:] @@ -61,6 +80,12 @@ def temp_dir(remover=shutil.rmtree): """ Create a temporary directory context. Pass a custom remover to override the removal behavior. + + >>> import pathlib + >>> with temp_dir() as the_dir: + ... assert os.path.isdir(the_dir) + ... _ = pathlib.Path(the_dir).joinpath('somefile').write_text('contents') + >>> assert not os.path.exists(the_dir) """ temp_dir = tempfile.mkdtemp() try: @@ -90,6 +115,12 @@ def repo_context(url, branch=None, quiet=True, dest_ctx=temp_dir): @contextlib.contextmanager def null(): + """ + A null context suitable to stand in for a meaningful context. 
+ + >>> with null() as value: + ... assert value is None + """ yield @@ -112,6 +143,10 @@ class ExceptionTrap: ... raise ValueError("1 + 1 is not 3") >>> bool(trap) True + >>> trap.value + ValueError('1 + 1 is not 3') + >>> trap.tb + >>> with ExceptionTrap(ValueError) as trap: ... raise Exception() @@ -211,3 +246,43 @@ class suppress(contextlib.suppress, contextlib.ContextDecorator): ... {}[''] >>> key_error() """ + + +class on_interrupt(contextlib.ContextDecorator): + """ + Replace a KeyboardInterrupt with SystemExit(1) + + >>> def do_interrupt(): + ... raise KeyboardInterrupt() + >>> on_interrupt('error')(do_interrupt)() + Traceback (most recent call last): + ... + SystemExit: 1 + >>> on_interrupt('error', code=255)(do_interrupt)() + Traceback (most recent call last): + ... + SystemExit: 255 + >>> on_interrupt('suppress')(do_interrupt)() + >>> with __import__('pytest').raises(KeyboardInterrupt): + ... on_interrupt('ignore')(do_interrupt)() + """ + + def __init__( + self, + action='error', + # py3.7 compat + # /, + code=1, + ): + self.action = action + self.code = code + + def __enter__(self): + return self + + def __exit__(self, exctype, excinst, exctb): + if exctype is not KeyboardInterrupt or self.action == 'ignore': + return + elif self.action == 'error': + raise SystemExit(self.code) from excinst + return self.action == 'suppress' diff --git a/lib/jaraco/functools.py b/lib/jaraco/functools/__init__.py similarity index 84% rename from lib/jaraco/functools.py rename to lib/jaraco/functools/__init__.py index 43c009f9..ca6c22fa 100644 --- a/lib/jaraco/functools.py +++ b/lib/jaraco/functools/__init__.py @@ -1,4 +1,4 @@ -import collections +import collections.abc import functools import inspect import itertools @@ -9,11 +9,6 @@ import warnings import more_itertools -from typing import Callable, TypeVar - - -CallableT = TypeVar("CallableT", bound=Callable[..., object]) - def compose(*funcs): """ @@ -39,24 +34,6 @@ def compose(*funcs): return functools.reduce(compose_two, funcs) -def method_caller(method_name, *args, **kwargs): - """ - Return a function that will call a named method on the - target object with optional positional and keyword - arguments. - - >>> lower = method_caller('lower') - >>> lower('MyString') - 'mystring' - """ - - def call_method(target): - func = getattr(target, method_name) - return func(*args, **kwargs) - - return call_method - - def once(func): """ Decorate func so it's only ever called the first time. @@ -99,12 +76,7 @@ def once(func): return wrapper -def method_cache( - method: CallableT, - cache_wrapper: Callable[ - [CallableT], CallableT - ] = functools.lru_cache(), # type: ignore[assignment] -) -> CallableT: +def method_cache(method, cache_wrapper=functools.lru_cache()): """ Wrap lru_cache to support storing the cache data in the object instances. @@ -172,22 +144,17 @@ def method_cache( for another implementation and additional justification. """ - def wrapper(self: object, *args: object, **kwargs: object) -> object: + def wrapper(self, *args, **kwargs): # it's the first call, replace the method with a cached, bound method - bound_method: CallableT = types.MethodType( # type: ignore[assignment] - method, self - ) + bound_method = types.MethodType(method, self) cached_method = cache_wrapper(bound_method) setattr(self, method.__name__, cached_method) return cached_method(*args, **kwargs) # Support cache clear even before cache has been created. 
- wrapper.cache_clear = lambda: None # type: ignore[attr-defined] + wrapper.cache_clear = lambda: None - return ( - _special_method_cache(method, cache_wrapper) # type: ignore[return-value] - or wrapper - ) + return _special_method_cache(method, cache_wrapper) or wrapper def _special_method_cache(method, cache_wrapper): @@ -203,12 +170,13 @@ def _special_method_cache(method, cache_wrapper): """ name = method.__name__ special_names = '__getattr__', '__getitem__' + if name not in special_names: - return + return None wrapper_name = '__cached' + name - def proxy(self, *args, **kwargs): + def proxy(self, /, *args, **kwargs): if wrapper_name not in vars(self): bound = types.MethodType(method, self) cache = cache_wrapper(bound) @@ -245,7 +213,7 @@ def result_invoke(action): r""" Decorate a function with an action function that is invoked on the results returned from the decorated - function (for its side-effect), then return the original + function (for its side effect), then return the original result. >>> @result_invoke(print) @@ -269,7 +237,7 @@ def result_invoke(action): return wrap -def invoke(f, *args, **kwargs): +def invoke(f, /, *args, **kwargs): """ Call a function for its side effect after initialization. @@ -304,25 +272,15 @@ def invoke(f, *args, **kwargs): Use functools.partial to pass parameters to the initial call >>> @functools.partial(invoke, name='bingo') - ... def func(name): print("called with", name) + ... def func(name): print('called with', name) called with bingo """ f(*args, **kwargs) return f -def call_aside(*args, **kwargs): - """ - Deprecated name for invoke. - """ - warnings.warn("call_aside is deprecated, use invoke", DeprecationWarning) - return invoke(*args, **kwargs) - - class Throttler: - """ - Rate-limit a function (or other callable) - """ + """Rate-limit a function (or other callable).""" def __init__(self, func, max_rate=float('Inf')): if isinstance(func, Throttler): @@ -339,20 +297,20 @@ class Throttler: return self.func(*args, **kwargs) def _wait(self): - "ensure at least 1/max_rate seconds from last call" + """Ensure at least 1/max_rate seconds from last call.""" elapsed = time.time() - self.last_called must_wait = 1 / self.max_rate - elapsed time.sleep(max(0, must_wait)) self.last_called = time.time() - def __get__(self, obj, type=None): + def __get__(self, obj, owner=None): return first_invoke(self._wait, functools.partial(self.func, obj)) def first_invoke(func1, func2): """ Return a function that when invoked will invoke func1 without - any parameters (for its side-effect) and then invoke func2 + any parameters (for its side effect) and then invoke func2 with whatever parameters were passed, returning its result. """ @@ -363,6 +321,17 @@ def first_invoke(func1, func2): return wrapper +method_caller = first_invoke( + lambda: warnings.warn( + '`jaraco.functools.method_caller` is deprecated, ' + 'use `operator.methodcaller` instead', + DeprecationWarning, + stacklevel=3, + ), + operator.methodcaller, +) + + def retry_call(func, cleanup=lambda: None, retries=0, trap=()): """ Given a callable func, trap the indicated exceptions @@ -371,7 +340,7 @@ def retry_call(func, cleanup=lambda: None, retries=0, trap=()): to propagate. 
""" attempts = itertools.count() if retries == float('inf') else range(retries) - for attempt in attempts: + for _ in attempts: try: return func() except trap: @@ -408,7 +377,7 @@ def retry(*r_args, **r_kwargs): def print_yielded(func): """ - Convert a generator into a function that prints all yielded elements + Convert a generator into a function that prints all yielded elements. >>> @print_yielded ... def x(): @@ -424,7 +393,7 @@ def print_yielded(func): def pass_none(func): """ - Wrap func so it's not called if its first param is None + Wrap func so it's not called if its first param is None. >>> print_text = pass_none(print) >>> print_text('text') @@ -433,9 +402,10 @@ def pass_none(func): """ @functools.wraps(func) - def wrapper(param, *args, **kwargs): + def wrapper(param, /, *args, **kwargs): if param is not None: return func(param, *args, **kwargs) + return None return wrapper @@ -509,7 +479,7 @@ def save_method_args(method): args_and_kwargs = collections.namedtuple('args_and_kwargs', 'args kwargs') @functools.wraps(method) - def wrapper(self, *args, **kwargs): + def wrapper(self, /, *args, **kwargs): attr_name = '_saved_' + method.__name__ attr = args_and_kwargs(args, kwargs) setattr(self, attr_name, attr) @@ -559,6 +529,13 @@ def except_(*exceptions, replace=None, use=None): def identity(x): + """ + Return the argument. + + >>> o = object() + >>> identity(o) is o + True + """ return x @@ -580,7 +557,7 @@ def bypass_when(check, *, _op=identity): def decorate(func): @functools.wraps(func) - def wrapper(param): + def wrapper(param, /): return param if _op(check) else func(param) return wrapper @@ -604,3 +581,53 @@ def bypass_unless(check): 2 """ return bypass_when(check, _op=operator.not_) + + +@functools.singledispatch +def _splat_inner(args, func): + """Splat args to func.""" + return func(*args) + + +@_splat_inner.register +def _(args: collections.abc.Mapping, func): + """Splat kargs to func as kwargs.""" + return func(**args) + + +def splat(func): + """ + Wrap func to expect its parameters to be passed positionally in a tuple. + + Has a similar effect to that of ``itertools.starmap`` over + simple ``map``. + + >>> pairs = [(-1, 1), (0, 2)] + >>> more_itertools.consume(itertools.starmap(print, pairs)) + -1 1 + 0 2 + >>> more_itertools.consume(map(splat(print), pairs)) + -1 1 + 0 2 + + The approach generalizes to other iterators that don't have a "star" + equivalent, such as a "starfilter". + + >>> list(filter(splat(operator.add), pairs)) + [(0, 2)] + + Splat also accepts a mapping argument. + + >>> def is_nice(msg, code): + ... return "smile" in msg or code == 0 + >>> msgs = [ + ... dict(msg='smile!', code=20), + ... dict(msg='error :(', code=1), + ... dict(msg='unknown', code=0), + ... ] + >>> for msg in filter(splat(is_nice), msgs): + ... 
print(msg) + {'msg': 'smile!', 'code': 20} + {'msg': 'unknown', 'code': 0} + """ + return functools.wraps(func)(functools.partial(_splat_inner, func=func)) diff --git a/lib/jaraco/functools/__init__.pyi b/lib/jaraco/functools/__init__.pyi new file mode 100644 index 00000000..c2b9ab17 --- /dev/null +++ b/lib/jaraco/functools/__init__.pyi @@ -0,0 +1,128 @@ +from collections.abc import Callable, Hashable, Iterator +from functools import partial +from operator import methodcaller +import sys +from typing import ( + Any, + Generic, + Protocol, + TypeVar, + overload, +) + +if sys.version_info >= (3, 10): + from typing import Concatenate, ParamSpec +else: + from typing_extensions import Concatenate, ParamSpec + +_P = ParamSpec('_P') +_R = TypeVar('_R') +_T = TypeVar('_T') +_R1 = TypeVar('_R1') +_R2 = TypeVar('_R2') +_V = TypeVar('_V') +_S = TypeVar('_S') +_R_co = TypeVar('_R_co', covariant=True) + +class _OnceCallable(Protocol[_P, _R]): + saved_result: _R + reset: Callable[[], None] + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ... + +class _ProxyMethodCacheWrapper(Protocol[_R_co]): + cache_clear: Callable[[], None] + def __call__(self, *args: Hashable, **kwargs: Hashable) -> _R_co: ... + +class _MethodCacheWrapper(Protocol[_R_co]): + def cache_clear(self) -> None: ... + def __call__(self, *args: Hashable, **kwargs: Hashable) -> _R_co: ... + +# `compose()` overloads below will cover most use cases. + +@overload +def compose( + __func1: Callable[[_R], _T], + __func2: Callable[_P, _R], + /, +) -> Callable[_P, _T]: ... +@overload +def compose( + __func1: Callable[[_R], _T], + __func2: Callable[[_R1], _R], + __func3: Callable[_P, _R1], + /, +) -> Callable[_P, _T]: ... +@overload +def compose( + __func1: Callable[[_R], _T], + __func2: Callable[[_R2], _R], + __func3: Callable[[_R1], _R2], + __func4: Callable[_P, _R1], + /, +) -> Callable[_P, _T]: ... +def once(func: Callable[_P, _R]) -> _OnceCallable[_P, _R]: ... +def method_cache( + method: Callable[..., _R], + cache_wrapper: Callable[[Callable[..., _R]], _MethodCacheWrapper[_R]] = ..., +) -> _MethodCacheWrapper[_R] | _ProxyMethodCacheWrapper[_R]: ... +def apply( + transform: Callable[[_R], _T] +) -> Callable[[Callable[_P, _R]], Callable[_P, _T]]: ... +def result_invoke( + action: Callable[[_R], Any] +) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: ... +def invoke( + f: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs +) -> Callable[_P, _R]: ... +def call_aside( + f: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs +) -> Callable[_P, _R]: ... + +class Throttler(Generic[_R]): + last_called: float + func: Callable[..., _R] + max_rate: float + def __init__( + self, func: Callable[..., _R] | Throttler[_R], max_rate: float = ... + ) -> None: ... + def reset(self) -> None: ... + def __call__(self, *args: Any, **kwargs: Any) -> _R: ... + def __get__(self, obj: Any, owner: type[Any] | None = ...) -> Callable[..., _R]: ... + +def first_invoke( + func1: Callable[..., Any], func2: Callable[_P, _R] +) -> Callable[_P, _R]: ... + +method_caller: Callable[..., methodcaller] + +def retry_call( + func: Callable[..., _R], + cleanup: Callable[..., None] = ..., + retries: int | float = ..., + trap: type[BaseException] | tuple[type[BaseException], ...] = ..., +) -> _R: ... +def retry( + cleanup: Callable[..., None] = ..., + retries: int | float = ..., + trap: type[BaseException] | tuple[type[BaseException], ...] = ..., +) -> Callable[[Callable[..., _R]], Callable[..., _R]]: ... 
+def print_yielded(func: Callable[_P, Iterator[Any]]) -> Callable[_P, None]: ... +def pass_none( + func: Callable[Concatenate[_T, _P], _R] +) -> Callable[Concatenate[_T, _P], _R]: ... +def assign_params( + func: Callable[..., _R], namespace: dict[str, Any] +) -> partial[_R]: ... +def save_method_args( + method: Callable[Concatenate[_S, _P], _R] +) -> Callable[Concatenate[_S, _P], _R]: ... +def except_( + *exceptions: type[BaseException], replace: Any = ..., use: Any = ... +) -> Callable[[Callable[_P, Any]], Callable[_P, Any]]: ... +def identity(x: _T) -> _T: ... +def bypass_when( + check: _V, *, _op: Callable[[_V], Any] = ... +) -> Callable[[Callable[[_T], _R]], Callable[[_T], _T | _R]]: ... +def bypass_unless( + check: Any, +) -> Callable[[Callable[[_T], _R]], Callable[[_T], _T | _R]]: ... diff --git a/lib/jaraco/functools/py.typed b/lib/jaraco/functools/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/lib/jaraco/text/__init__.py b/lib/jaraco/text/__init__.py index e51101c2..0fabd0c3 100644 --- a/lib/jaraco/text/__init__.py +++ b/lib/jaraco/text/__init__.py @@ -227,10 +227,12 @@ def unwrap(s): return '\n'.join(cleaned) -lorem_ipsum: str = files(__name__).joinpath('Lorem ipsum.txt').read_text() +lorem_ipsum: str = ( + files(__name__).joinpath('Lorem ipsum.txt').read_text(encoding='utf-8') +) -class Splitter(object): +class Splitter: """object that will split a string with the given arguments for each call >>> s = Splitter(',') @@ -367,7 +369,7 @@ class WordSet(tuple): return self.trim_left(item).trim_right(item) def __getitem__(self, item): - result = super(WordSet, self).__getitem__(item) + result = super().__getitem__(item) if isinstance(item, slice): result = WordSet(result) return result @@ -582,7 +584,7 @@ def join_continuation(lines): ['foobarbaz'] Not sure why, but... - The character preceeding the backslash is also elided. + The character preceding the backslash is also elided. 
>>> list(join_continuation(['goo\\', 'dly'])) ['godly'] @@ -607,16 +609,16 @@ def read_newlines(filename, limit=1024): r""" >>> tmp_path = getfixture('tmp_path') >>> filename = tmp_path / 'out.txt' - >>> _ = filename.write_text('foo\n', newline='') + >>> _ = filename.write_text('foo\n', newline='', encoding='utf-8') >>> read_newlines(filename) '\n' - >>> _ = filename.write_text('foo\r\n', newline='') + >>> _ = filename.write_text('foo\r\n', newline='', encoding='utf-8') >>> read_newlines(filename) '\r\n' - >>> _ = filename.write_text('foo\r\nbar\nbing\r', newline='') + >>> _ = filename.write_text('foo\r\nbar\nbing\r', newline='', encoding='utf-8') >>> read_newlines(filename) ('\r', '\n', '\r\n') """ - with open(filename) as fp: + with open(filename, encoding='utf-8') as fp: fp.read(limit) return fp.newlines diff --git a/lib/jaraco/text/show-newlines.py b/lib/jaraco/text/show-newlines.py index 2ba32062..e11d1ba4 100644 --- a/lib/jaraco/text/show-newlines.py +++ b/lib/jaraco/text/show-newlines.py @@ -12,11 +12,11 @@ def report_newlines(filename): >>> tmp_path = getfixture('tmp_path') >>> filename = tmp_path / 'out.txt' - >>> _ = filename.write_text('foo\nbar\n', newline='') + >>> _ = filename.write_text('foo\nbar\n', newline='', encoding='utf-8') >>> report_newlines(filename) newline is '\n' >>> filename = tmp_path / 'out.txt' - >>> _ = filename.write_text('foo\nbar\r\n', newline='') + >>> _ = filename.write_text('foo\nbar\r\n', newline='', encoding='utf-8') >>> report_newlines(filename) newlines are ('\n', '\r\n') """ diff --git a/lib/jaraco/text/strip-prefix.py b/lib/jaraco/text/strip-prefix.py new file mode 100644 index 00000000..761717a9 --- /dev/null +++ b/lib/jaraco/text/strip-prefix.py @@ -0,0 +1,21 @@ +import sys + +import autocommand + +from jaraco.text import Stripper + + +def strip_prefix(): + r""" + Strip any common prefix from stdin. + + >>> import io, pytest + >>> getfixture('monkeypatch').setattr('sys.stdin', io.StringIO('abcdef\nabc123')) + >>> strip_prefix() + def + 123 + """ + sys.stdout.writelines(Stripper.strip_prefix(sys.stdin).lines) + + +autocommand.autocommand(__name__)(strip_prefix) diff --git a/lib/more_itertools/__init__.py b/lib/more_itertools/__init__.py index 28ffadcf..aff94a9a 100644 --- a/lib/more_itertools/__init__.py +++ b/lib/more_itertools/__init__.py @@ -3,4 +3,4 @@ from .more import * # noqa from .recipes import * # noqa -__version__ = '10.1.0' +__version__ = '10.2.0' diff --git a/lib/more_itertools/more.py b/lib/more_itertools/more.py index 59c2f1a4..dd711a47 100755 --- a/lib/more_itertools/more.py +++ b/lib/more_itertools/more.py @@ -19,7 +19,7 @@ from itertools import ( zip_longest, product, ) -from math import exp, factorial, floor, log +from math import exp, factorial, floor, log, perm, comb from queue import Empty, Queue from random import random, randrange, uniform from operator import itemgetter, mul, sub, gt, lt, ge, le @@ -68,8 +68,10 @@ __all__ = [ 'divide', 'duplicates_everseen', 'duplicates_justseen', + 'classify_unique', 'exactly_n', 'filter_except', + 'filter_map', 'first', 'gray_product', 'groupby_transform', @@ -83,6 +85,7 @@ __all__ = [ 'is_sorted', 'islice_extended', 'iterate', + 'iter_suppress', 'last', 'locate', 'longest_common_prefix', @@ -198,15 +201,14 @@ def first(iterable, default=_marker): ``next(iter(iterable), default)``. """ - try: - return next(iter(iterable)) - except StopIteration as e: - if default is _marker: - raise ValueError( - 'first() was called on an empty iterable, and no ' - 'default value was provided.' 
- ) from e - return default + for item in iterable: + return item + if default is _marker: + raise ValueError( + 'first() was called on an empty iterable, and no ' + 'default value was provided.' + ) + return default def last(iterable, default=_marker): @@ -582,6 +584,9 @@ def strictly_n(iterable, n, too_short=None, too_long=None): >>> list(strictly_n(iterable, n)) ['a', 'b', 'c', 'd'] + Note that the returned iterable must be consumed in order for the check to + be made. + By default, *too_short* and *too_long* are functions that raise ``ValueError``. @@ -919,7 +924,7 @@ def substrings_indexes(seq, reverse=False): class bucket: - """Wrap *iterable* and return an object that buckets it iterable into + """Wrap *iterable* and return an object that buckets the iterable into child iterables based on a *key* function. >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3'] @@ -3222,6 +3227,8 @@ class time_limited: stops if the time elapsed is greater than *limit_seconds*. If your time limit is 1 second, but it takes 2 seconds to generate the first item from the iterable, the function will run for 2 seconds and not yield anything. + As a special case, when *limit_seconds* is zero, the iterator never + returns anything. """ @@ -3237,6 +3244,9 @@ class time_limited: return self def __next__(self): + if self.limit_seconds == 0: + self.timed_out = True + raise StopIteration item = next(self._iterable) if monotonic() - self._start_time > self.limit_seconds: self.timed_out = True @@ -3356,7 +3366,7 @@ def iequals(*iterables): >>> iequals("abc", "acb") False - Not to be confused with :func:`all_equals`, which checks whether all + Not to be confused with :func:`all_equal`, which checks whether all elements of iterable are equal to each other. """ @@ -3853,7 +3863,7 @@ def nth_permutation(iterable, r, index): elif not 0 <= r < n: raise ValueError else: - c = factorial(n) // factorial(n - r) + c = perm(n, r) if index < 0: index += c @@ -3898,7 +3908,7 @@ def nth_combination_with_replacement(iterable, r, index): if (r < 0) or (r > n): raise ValueError - c = factorial(n + r - 1) // (factorial(r) * factorial(n - 1)) + c = comb(n + r - 1, r) if index < 0: index += c @@ -3911,9 +3921,7 @@ def nth_combination_with_replacement(iterable, r, index): while r: r -= 1 while n >= 0: - num_combs = factorial(n + r - 1) // ( - factorial(r) * factorial(n - 1) - ) + num_combs = comb(n + r - 1, r) if index < num_combs: break n -= 1 @@ -4015,9 +4023,9 @@ def combination_index(element, iterable): for i, j in enumerate(reversed(indexes), start=1): j = n - j if i <= j: - index += factorial(j) // (factorial(i) * factorial(j - i)) + index += comb(j, i) - return factorial(n + 1) // (factorial(k + 1) * factorial(n - k)) - index + return comb(n + 1, k + 1) - index def combination_with_replacement_index(element, iterable): @@ -4057,7 +4065,7 @@ def combination_with_replacement_index(element, iterable): break else: raise ValueError( - 'element is not a combination with replacment of iterable' + 'element is not a combination with replacement of iterable' ) n = len(pool) @@ -4066,11 +4074,13 @@ def combination_with_replacement_index(element, iterable): occupations[p] += 1 index = 0 + cumulative_sum = 0 for k in range(1, n): - j = l + n - 1 - k - sum(occupations[:k]) + cumulative_sum += occupations[k - 1] + j = l + n - 1 - k - cumulative_sum i = n - k if i <= j: - index += factorial(j) // (factorial(i) * factorial(j - i)) + index += comb(j, i) return index @@ -4296,7 +4306,7 @@ def duplicates_everseen(iterable, key=None): >>> 
list(duplicates_everseen('AaaBbbCccAaa', str.lower)) ['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a'] - This function is analagous to :func:`unique_everseen` and is subject to + This function is analogous to :func:`unique_everseen` and is subject to the same performance considerations. """ @@ -4326,12 +4336,54 @@ def duplicates_justseen(iterable, key=None): >>> list(duplicates_justseen('AaaBbbCccAaa', str.lower)) ['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a'] - This function is analagous to :func:`unique_justseen`. + This function is analogous to :func:`unique_justseen`. """ return flatten(g for _, g in groupby(iterable, key) for _ in g) +def classify_unique(iterable, key=None): + """Classify each element in terms of its uniqueness. + + For each element in the input iterable, return a 3-tuple consisting of: + + 1. The element itself + 2. ``False`` if the element is equal to the one preceding it in the input, + ``True`` otherwise (i.e. the equivalent of :func:`unique_justseen`) + 3. ``False`` if this element has been seen anywhere in the input before, + ``True`` otherwise (i.e. the equivalent of :func:`unique_everseen`) + + >>> list(classify_unique('otto')) # doctest: +NORMALIZE_WHITESPACE + [('o', True, True), + ('t', True, True), + ('t', False, False), + ('o', True, False)] + + This function is analogous to :func:`unique_everseen` and is subject to + the same performance considerations. + + """ + seen_set = set() + seen_list = [] + use_key = key is not None + previous = None + + for i, element in enumerate(iterable): + k = key(element) if use_key else element + is_unique_justseen = not i or previous != k + previous = k + is_unique_everseen = False + try: + if k not in seen_set: + seen_set.add(k) + is_unique_everseen = True + except TypeError: + if k not in seen_list: + seen_list.append(k) + is_unique_everseen = True + yield element, is_unique_justseen, is_unique_everseen + + def minmax(iterable_or_value, *others, key=None, default=_marker): """Returns both the smallest and largest items in an iterable or the largest of two or more arguments. @@ -4529,10 +4581,8 @@ def takewhile_inclusive(predicate, iterable): :func:`takewhile` would return ``[1, 4]``. """ for x in iterable: - if predicate(x): - yield x - else: - yield x + yield x + if not predicate(x): break @@ -4567,3 +4617,40 @@ def outer_product(func, xs, ys, *args, **kwargs): starmap(lambda x, y: func(x, y, *args, **kwargs), product(xs, ys)), n=len(ys), ) + + +def iter_suppress(iterable, *exceptions): + """Yield each of the items from *iterable*. If the iteration raises one of + the specified *exceptions*, that exception will be suppressed and iteration + will stop. + + >>> from itertools import chain + >>> def breaks_at_five(x): + ... while True: + ... if x >= 5: + ... raise RuntimeError + ... yield x + ... x += 1 + >>> it_1 = iter_suppress(breaks_at_five(1), RuntimeError) + >>> it_2 = iter_suppress(breaks_at_five(2), RuntimeError) + >>> list(chain(it_1, it_2)) + [1, 2, 3, 4, 2, 3, 4] + """ + try: + yield from iterable + except exceptions: + return + + +def filter_map(func, iterable): + """Apply *func* to every element of *iterable*, yielding only those which + are not ``None``. 
+ + >>> elems = ['1', 'a', '2', 'b', '3'] + >>> list(filter_map(lambda s: int(s) if s.isnumeric() else None, elems)) + [1, 2, 3] + """ + for x in iterable: + y = func(x) + if y is not None: + yield y diff --git a/lib/more_itertools/more.pyi b/lib/more_itertools/more.pyi index 07bfc155..9a5fc911 100644 --- a/lib/more_itertools/more.pyi +++ b/lib/more_itertools/more.pyi @@ -29,7 +29,7 @@ _U = TypeVar('_U') _V = TypeVar('_V') _W = TypeVar('_W') _T_co = TypeVar('_T_co', covariant=True) -_GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[object]]) +_GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[Any]]) _Raisable = BaseException | Type[BaseException] @type_check_only @@ -74,7 +74,7 @@ class peekable(Generic[_T], Iterator[_T]): def __getitem__(self, index: slice) -> list[_T]: ... def consumer(func: _GenFn) -> _GenFn: ... -def ilen(iterable: Iterable[object]) -> int: ... +def ilen(iterable: Iterable[_T]) -> int: ... def iterate(func: Callable[[_T], _T], start: _T) -> Iterator[_T]: ... def with_iter( context_manager: ContextManager[Iterable[_T]], @@ -116,7 +116,7 @@ class bucket(Generic[_T, _U], Container[_U]): self, iterable: Iterable[_T], key: Callable[[_T], _U], - validator: Callable[[object], object] | None = ..., + validator: Callable[[_U], object] | None = ..., ) -> None: ... def __contains__(self, value: object) -> bool: ... def __iter__(self) -> Iterator[_U]: ... @@ -383,7 +383,7 @@ def mark_ends( iterable: Iterable[_T], ) -> Iterable[tuple[bool, bool, _T]]: ... def locate( - iterable: Iterable[object], + iterable: Iterable[_T], pred: Callable[..., Any] = ..., window_size: int | None = ..., ) -> Iterator[int]: ... @@ -618,6 +618,9 @@ def duplicates_everseen( def duplicates_justseen( iterable: Iterable[_T], key: Callable[[_T], _U] | None = ... ) -> Iterator[_T]: ... +def classify_unique( + iterable: Iterable[_T], key: Callable[[_T], _U] | None = ... +) -> Iterator[tuple[_T, bool, bool]]: ... class _SupportsLessThan(Protocol): def __lt__(self, __other: Any) -> bool: ... @@ -662,9 +665,9 @@ def minmax( def longest_common_prefix( iterables: Iterable[Iterable[_T]], ) -> Iterator[_T]: ... -def iequals(*iterables: Iterable[object]) -> bool: ... +def iequals(*iterables: Iterable[Any]) -> bool: ... def constrained_batches( - iterable: Iterable[object], + iterable: Iterable[_T], max_size: int, max_count: int | None = ..., get_len: Callable[[_T], object] = ..., @@ -682,3 +685,11 @@ def outer_product( *args: Any, **kwargs: Any, ) -> Iterator[tuple[_V, ...]]: ... +def iter_suppress( + iterable: Iterable[_T], + *exceptions: Type[BaseException], +) -> Iterator[_T]: ... +def filter_map( + func: Callable[[_T], _V | None], + iterable: Iterable[_T], +) -> Iterator[_V]: ... 
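
For orientation, a minimal usage sketch of the helpers newly vendored above (`filter_map`, `iter_suppress`, `classify_unique` from more_itertools 10.2.0). This is an illustrative example, not part of the patched sources; the sample data and function names are assumptions.

    # Illustrative sketch only -- not part of the vendored diff above.
    from more_itertools import classify_unique, filter_map, iter_suppress

    def parse_ints(tokens):
        # filter_map() keeps only non-None results, so malformed tokens
        # are dropped without a separate filter pass.
        return list(filter_map(lambda s: int(s) if s.isnumeric() else None, tokens))

    def drain(gen):
        # iter_suppress() turns a trailing RuntimeError into a clean stop,
        # similar in spirit to contextlib.suppress but for iteration.
        return list(iter_suppress(gen, RuntimeError))

    if __name__ == '__main__':
        print(parse_ints(['1', 'a', '2', 'b', '3']))  # [1, 2, 3]
        # classify_unique() flags first-ever occurrences in its third field.
        print([elem for elem, _, first_seen in classify_unique('otto') if first_seen])  # ['o', 't']
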
diff --git a/lib/more_itertools/recipes.py b/lib/more_itertools/recipes.py index a0bdbece..145e3cb5 100644 --- a/lib/more_itertools/recipes.py +++ b/lib/more_itertools/recipes.py @@ -28,6 +28,7 @@ from itertools import ( zip_longest, ) from random import randrange, sample, choice +from sys import hexversion __all__ = [ 'all_equal', @@ -56,6 +57,7 @@ __all__ = [ 'powerset', 'prepend', 'quantify', + 'reshape', 'random_combination_with_replacement', 'random_combination', 'random_permutation', @@ -69,6 +71,7 @@ __all__ = [ 'tabulate', 'tail', 'take', + 'totient', 'transpose', 'triplewise', 'unique_everseen', @@ -492,7 +495,7 @@ def unique_everseen(iterable, key=None): >>> list(unique_everseen(iterable, key=tuple)) # Faster [[1, 2], [2, 3]] - Similary, you may want to convert unhashable ``set`` objects with + Similarly, you may want to convert unhashable ``set`` objects with ``key=frozenset``. For ``dict`` objects, ``key=lambda x: frozenset(x.items())`` can be used. @@ -524,6 +527,9 @@ def unique_justseen(iterable, key=None): ['A', 'B', 'C', 'A', 'D'] """ + if key is None: + return map(operator.itemgetter(0), groupby(iterable)) + return map(next, map(operator.itemgetter(1), groupby(iterable, key))) @@ -817,35 +823,34 @@ def polynomial_from_roots(roots): return list(reduce(convolve, factors, [1])) -def iter_index(iterable, value, start=0): +def iter_index(iterable, value, start=0, stop=None): """Yield the index of each place in *iterable* that *value* occurs, - beginning with index *start*. + beginning with index *start* and ending before index *stop*. See :func:`locate` for a more general means of finding the indexes associated with particular values. >>> list(iter_index('AABCADEAF', 'A')) [0, 1, 4, 7] + >>> list(iter_index('AABCADEAF', 'A', 1)) # start index is inclusive + [1, 4, 7] + >>> list(iter_index('AABCADEAF', 'A', 1, 7)) # stop index is not inclusive + [1, 4] """ - try: - seq_index = iterable.index - except AttributeError: + seq_index = getattr(iterable, 'index', None) + if seq_index is None: # Slow path for general iterables - it = islice(iterable, start, None) - i = start - 1 - try: - while True: - i = i + operator.indexOf(it, value) + 1 + it = islice(iterable, start, stop) + for i, element in enumerate(it, start): + if element is value or element == value: yield i - except ValueError: - pass else: # Fast path for sequences + stop = len(iterable) if stop is None else stop i = start - 1 try: while True: - i = seq_index(value, i + 1) - yield i + yield (i := seq_index(value, i + 1, stop)) except ValueError: pass @@ -856,47 +861,52 @@ def sieve(n): >>> list(sieve(30)) [2, 3, 5, 7, 11, 13, 17, 19, 23, 29] """ + if n > 2: + yield 2 + start = 3 data = bytearray((0, 1)) * (n // 2) - data[:3] = 0, 0, 0 limit = math.isqrt(n) + 1 - for p in compress(range(limit), data): + for p in iter_index(data, 1, start, limit): + yield from iter_index(data, 1, start, p * p) data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p))) - data[2] = 1 - return iter_index(data, 1) if n > 2 else iter([]) + start = p * p + yield from iter_index(data, 1, start) -def _batched(iterable, n): - """Batch data into lists of length *n*. The last batch may be shorter. +def _batched(iterable, n, *, strict=False): + """Batch data into tuples of length *n*. If the number of items in + *iterable* is not divisible by *n*: + * The last batch will be shorter if *strict* is ``False``. + * :exc:`ValueError` will be raised if *strict* is ``True``. 
>>> list(batched('ABCDEFG', 3)) [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)] - On Python 3.12 and above, this is an alias for :func:`itertools.batched`. + On Python 3.13 and above, this is an alias for :func:`itertools.batched`. """ if n < 1: raise ValueError('n must be at least one') it = iter(iterable) - while True: - batch = tuple(islice(it, n)) - if not batch: - break + while batch := tuple(islice(it, n)): + if strict and len(batch) != n: + raise ValueError('batched(): incomplete batch') yield batch -try: +if hexversion >= 0x30D00A2: from itertools import batched as itertools_batched -except ImportError: - batched = _batched -else: - def batched(iterable, n): - return itertools_batched(iterable, n) + def batched(iterable, n, *, strict=False): + return itertools_batched(iterable, n, strict=strict) + +else: + batched = _batched batched.__doc__ = _batched.__doc__ def transpose(it): - """Swap the rows and columns of the input. + """Swap the rows and columns of the input matrix. >>> list(transpose([(1, 2, 3), (11, 22, 33)])) [(1, 11), (2, 22), (3, 33)] @@ -907,8 +917,20 @@ def transpose(it): return _zip_strict(*it) +def reshape(matrix, cols): + """Reshape the 2-D input *matrix* to have a column count given by *cols*. + + >>> matrix = [(0, 1), (2, 3), (4, 5)] + >>> cols = 3 + >>> list(reshape(matrix, cols)) + [(0, 1, 2), (3, 4, 5)] + """ + return batched(chain.from_iterable(matrix), cols) + + def matmul(m1, m2): """Multiply two matrices. + >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)])) [(49, 80), (41, 60)] @@ -921,13 +943,12 @@ def matmul(m1, m2): def factor(n): """Yield the prime factors of n. + >>> list(factor(360)) [2, 2, 2, 3, 3, 5] """ for prime in sieve(math.isqrt(n) + 1): - while True: - if n % prime: - break + while not n % prime: yield prime n //= prime if n == 1: @@ -975,3 +996,17 @@ def polynomial_derivative(coefficients): n = len(coefficients) powers = reversed(range(1, n)) return list(map(operator.mul, coefficients, powers)) + + +def totient(n): + """Return the count of natural numbers up to *n* that are coprime with *n*. + + >>> totient(9) + 6 + >>> totient(12) + 4 + """ + for p in unique_justseen(factor(n)): + n = n // p * (p - 1) + + return n diff --git a/lib/more_itertools/recipes.pyi b/lib/more_itertools/recipes.pyi index ef883864..ed4c19db 100644 --- a/lib/more_itertools/recipes.pyi +++ b/lib/more_itertools/recipes.pyi @@ -14,6 +14,8 @@ from typing import ( # Type and type variable definitions _T = TypeVar('_T') +_T1 = TypeVar('_T1') +_T2 = TypeVar('_T2') _U = TypeVar('_U') def take(n: int, iterable: Iterable[_T]) -> list[_T]: ... @@ -26,14 +28,14 @@ def consume(iterator: Iterable[_T], n: int | None = ...) -> None: ... def nth(iterable: Iterable[_T], n: int) -> _T | None: ... @overload def nth(iterable: Iterable[_T], n: int, default: _U) -> _T | _U: ... -def all_equal(iterable: Iterable[object]) -> bool: ... +def all_equal(iterable: Iterable[_T]) -> bool: ... def quantify( iterable: Iterable[_T], pred: Callable[[_T], bool] = ... ) -> int: ... def pad_none(iterable: Iterable[_T]) -> Iterator[_T | None]: ... def padnone(iterable: Iterable[_T]) -> Iterator[_T | None]: ... def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]: ... -def dotproduct(vec1: Iterable[object], vec2: Iterable[object]) -> object: ... +def dotproduct(vec1: Iterable[_T1], vec2: Iterable[_T2]) -> Any: ... def flatten(listOfLists: Iterable[Iterable[_T]]) -> Iterator[_T]: ... 
def repeatfunc( func: Callable[..., _U], times: int | None = ..., *args: Any @@ -103,20 +105,24 @@ def sliding_window( def subslices(iterable: Iterable[_T]) -> Iterator[list[_T]]: ... def polynomial_from_roots(roots: Sequence[_T]) -> list[_T]: ... def iter_index( - iterable: Iterable[object], + iterable: Iterable[_T], value: Any, start: int | None = ..., + stop: int | None = ..., ) -> Iterator[int]: ... def sieve(n: int) -> Iterator[int]: ... def batched( - iterable: Iterable[_T], - n: int, + iterable: Iterable[_T], n: int, *, strict: bool = False ) -> Iterator[tuple[_T]]: ... def transpose( it: Iterable[Iterable[_T]], ) -> Iterator[tuple[_T, ...]]: ... +def reshape( + matrix: Iterable[Iterable[_T]], cols: int +) -> Iterator[tuple[_T, ...]]: ... def matmul(m1: Sequence[_T], m2: Sequence[_T]) -> Iterator[tuple[_T]]: ... def factor(n: int) -> Iterator[int]: ... def polynomial_eval(coefficients: Sequence[_T], x: _U) -> _U: ... def sum_of_squares(it: Iterable[_T]) -> _T: ... def polynomial_derivative(coefficients: Sequence[_T]) -> list[_T]: ... +def totient(n: int) -> int: ... diff --git a/lib/pydantic/__init__.py b/lib/pydantic/__init__.py index 3bf1418f..85a8c18b 100644 --- a/lib/pydantic/__init__.py +++ b/lib/pydantic/__init__.py @@ -1,56 +1,114 @@ -# flake8: noqa -from . import dataclasses -from .annotated_types import create_model_from_namedtuple, create_model_from_typeddict -from .class_validators import root_validator, validator -from .config import BaseConfig, ConfigDict, Extra -from .decorator import validate_arguments -from .env_settings import BaseSettings -from .error_wrappers import ValidationError -from .errors import * -from .fields import Field, PrivateAttr, Required -from .main import * -from .networks import * -from .parse import Protocol -from .tools import * -from .types import * -from .version import VERSION, compiled +import typing + +from ._migration import getattr_migration +from .version import VERSION + +if typing.TYPE_CHECKING: + # import of virtually everything is supported via `__getattr__` below, + # but we need them here for type checking and IDE support + import pydantic_core + from pydantic_core.core_schema import ( + FieldSerializationInfo, + SerializationInfo, + SerializerFunctionWrapHandler, + ValidationInfo, + ValidatorFunctionWrapHandler, + ) + + from . 
import dataclasses + from ._internal._generate_schema import GenerateSchema as GenerateSchema + from .aliases import AliasChoices, AliasGenerator, AliasPath + from .annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler + from .config import ConfigDict + from .errors import * + from .fields import Field, PrivateAttr, computed_field + from .functional_serializers import ( + PlainSerializer, + SerializeAsAny, + WrapSerializer, + field_serializer, + model_serializer, + ) + from .functional_validators import ( + AfterValidator, + BeforeValidator, + InstanceOf, + PlainValidator, + SkipValidation, + WrapValidator, + field_validator, + model_validator, + ) + from .json_schema import WithJsonSchema + from .main import * + from .networks import * + from .type_adapter import TypeAdapter + from .types import * + from .validate_call_decorator import validate_call + from .warnings import PydanticDeprecatedSince20, PydanticDeprecatedSince26, PydanticDeprecationWarning + + # this encourages pycharm to import `ValidationError` from here, not pydantic_core + ValidationError = pydantic_core.ValidationError + from .deprecated.class_validators import root_validator, validator + from .deprecated.config import BaseConfig, Extra + from .deprecated.tools import * + from .root_model import RootModel __version__ = VERSION - -# WARNING __all__ from .errors is not included here, it will be removed as an export here in v2 -# please use "from pydantic.errors import ..." instead -__all__ = [ - # annotated types utils - 'create_model_from_namedtuple', - 'create_model_from_typeddict', +__all__ = ( # dataclasses 'dataclasses', - # class_validators + # functional validators + 'field_validator', + 'model_validator', + 'AfterValidator', + 'BeforeValidator', + 'PlainValidator', + 'WrapValidator', + 'SkipValidation', + 'InstanceOf', + # JSON Schema + 'WithJsonSchema', + # deprecated V1 functional validators, these are imported via `__getattr__` below 'root_validator', 'validator', + # functional serializers + 'field_serializer', + 'model_serializer', + 'PlainSerializer', + 'SerializeAsAny', + 'WrapSerializer', # config - 'BaseConfig', 'ConfigDict', + # deprecated V1 config, these are imported via `__getattr__` below + 'BaseConfig', 'Extra', - # decorator - 'validate_arguments', - # env_settings - 'BaseSettings', - # error_wrappers - 'ValidationError', + # validate_call + 'validate_call', + # errors + 'PydanticErrorCodes', + 'PydanticUserError', + 'PydanticSchemaGenerationError', + 'PydanticImportError', + 'PydanticUndefinedAnnotation', + 'PydanticInvalidForJsonSchema', # fields 'Field', - 'Required', + 'computed_field', + 'PrivateAttr', + # alias + 'AliasChoices', + 'AliasGenerator', + 'AliasPath', # main 'BaseModel', 'create_model', - 'validate_model', # network 'AnyUrl', 'AnyHttpUrl', 'FileUrl', 'HttpUrl', - 'stricturl', + 'UrlConstraints', 'EmailStr', 'NameEmail', 'IPvAnyAddress', @@ -62,48 +120,38 @@ __all__ = [ 'RedisDsn', 'MongoDsn', 'KafkaDsn', + 'NatsDsn', + 'MySQLDsn', + 'MariaDBDsn', 'validate_email', - # parse - 'Protocol', - # tools - 'parse_file_as', + # root_model + 'RootModel', + # deprecated tools, these are imported via `__getattr__` below 'parse_obj_as', - 'parse_raw_as', 'schema_of', 'schema_json_of', # types - 'NoneStr', - 'NoneBytes', - 'StrBytes', - 'NoneStrBytes', + 'Strict', 'StrictStr', - 'ConstrainedBytes', 'conbytes', - 'ConstrainedList', 'conlist', - 'ConstrainedSet', 'conset', - 'ConstrainedFrozenSet', 'confrozenset', - 'ConstrainedStr', 'constr', - 'PyObject', - 'ConstrainedInt', + 
'StringConstraints', + 'ImportString', 'conint', 'PositiveInt', 'NegativeInt', 'NonNegativeInt', 'NonPositiveInt', - 'ConstrainedFloat', 'confloat', 'PositiveFloat', 'NegativeFloat', 'NonNegativeFloat', 'NonPositiveFloat', 'FiniteFloat', - 'ConstrainedDecimal', 'condecimal', - 'ConstrainedDate', 'condate', 'UUID1', 'UUID3', @@ -111,9 +159,8 @@ __all__ = [ 'UUID5', 'FilePath', 'DirectoryPath', + 'NewPath', 'Json', - 'JsonWrapper', - 'SecretField', 'SecretStr', 'SecretBytes', 'StrictBool', @@ -121,11 +168,221 @@ __all__ = [ 'StrictInt', 'StrictFloat', 'PaymentCardNumber', - 'PrivateAttr', 'ByteSize', 'PastDate', 'FutureDate', + 'PastDatetime', + 'FutureDatetime', + 'AwareDatetime', + 'NaiveDatetime', + 'AllowInfNan', + 'EncoderProtocol', + 'EncodedBytes', + 'EncodedStr', + 'Base64Encoder', + 'Base64Bytes', + 'Base64Str', + 'Base64UrlBytes', + 'Base64UrlStr', + 'GetPydanticSchema', + 'Tag', + 'Discriminator', + 'JsonValue', + # type_adapter + 'TypeAdapter', # version - 'compiled', + '__version__', 'VERSION', -] + # warnings + 'PydanticDeprecatedSince20', + 'PydanticDeprecatedSince26', + 'PydanticDeprecationWarning', + # annotated handlers + 'GetCoreSchemaHandler', + 'GetJsonSchemaHandler', + # generate schema from ._internal + 'GenerateSchema', + # pydantic_core + 'ValidationError', + 'ValidationInfo', + 'SerializationInfo', + 'ValidatorFunctionWrapHandler', + 'FieldSerializationInfo', + 'SerializerFunctionWrapHandler', + 'OnErrorOmit', +) + +# A mapping of {: (package, )} defining dynamic imports +_dynamic_imports: 'dict[str, tuple[str, str]]' = { + 'dataclasses': (__package__, '__module__'), + # functional validators + 'field_validator': (__package__, '.functional_validators'), + 'model_validator': (__package__, '.functional_validators'), + 'AfterValidator': (__package__, '.functional_validators'), + 'BeforeValidator': (__package__, '.functional_validators'), + 'PlainValidator': (__package__, '.functional_validators'), + 'WrapValidator': (__package__, '.functional_validators'), + 'SkipValidation': (__package__, '.functional_validators'), + 'InstanceOf': (__package__, '.functional_validators'), + # JSON Schema + 'WithJsonSchema': (__package__, '.json_schema'), + # functional serializers + 'field_serializer': (__package__, '.functional_serializers'), + 'model_serializer': (__package__, '.functional_serializers'), + 'PlainSerializer': (__package__, '.functional_serializers'), + 'SerializeAsAny': (__package__, '.functional_serializers'), + 'WrapSerializer': (__package__, '.functional_serializers'), + # config + 'ConfigDict': (__package__, '.config'), + # validate call + 'validate_call': (__package__, '.validate_call_decorator'), + # errors + 'PydanticErrorCodes': (__package__, '.errors'), + 'PydanticUserError': (__package__, '.errors'), + 'PydanticSchemaGenerationError': (__package__, '.errors'), + 'PydanticImportError': (__package__, '.errors'), + 'PydanticUndefinedAnnotation': (__package__, '.errors'), + 'PydanticInvalidForJsonSchema': (__package__, '.errors'), + # fields + 'Field': (__package__, '.fields'), + 'computed_field': (__package__, '.fields'), + 'PrivateAttr': (__package__, '.fields'), + # alias + 'AliasChoices': (__package__, '.aliases'), + 'AliasGenerator': (__package__, '.aliases'), + 'AliasPath': (__package__, '.aliases'), + # main + 'BaseModel': (__package__, '.main'), + 'create_model': (__package__, '.main'), + # network + 'AnyUrl': (__package__, '.networks'), + 'AnyHttpUrl': (__package__, '.networks'), + 'FileUrl': (__package__, '.networks'), + 'HttpUrl': (__package__, 
'.networks'), + 'UrlConstraints': (__package__, '.networks'), + 'EmailStr': (__package__, '.networks'), + 'NameEmail': (__package__, '.networks'), + 'IPvAnyAddress': (__package__, '.networks'), + 'IPvAnyInterface': (__package__, '.networks'), + 'IPvAnyNetwork': (__package__, '.networks'), + 'PostgresDsn': (__package__, '.networks'), + 'CockroachDsn': (__package__, '.networks'), + 'AmqpDsn': (__package__, '.networks'), + 'RedisDsn': (__package__, '.networks'), + 'MongoDsn': (__package__, '.networks'), + 'KafkaDsn': (__package__, '.networks'), + 'NatsDsn': (__package__, '.networks'), + 'MySQLDsn': (__package__, '.networks'), + 'MariaDBDsn': (__package__, '.networks'), + 'validate_email': (__package__, '.networks'), + # root_model + 'RootModel': (__package__, '.root_model'), + # types + 'Strict': (__package__, '.types'), + 'StrictStr': (__package__, '.types'), + 'conbytes': (__package__, '.types'), + 'conlist': (__package__, '.types'), + 'conset': (__package__, '.types'), + 'confrozenset': (__package__, '.types'), + 'constr': (__package__, '.types'), + 'StringConstraints': (__package__, '.types'), + 'ImportString': (__package__, '.types'), + 'conint': (__package__, '.types'), + 'PositiveInt': (__package__, '.types'), + 'NegativeInt': (__package__, '.types'), + 'NonNegativeInt': (__package__, '.types'), + 'NonPositiveInt': (__package__, '.types'), + 'confloat': (__package__, '.types'), + 'PositiveFloat': (__package__, '.types'), + 'NegativeFloat': (__package__, '.types'), + 'NonNegativeFloat': (__package__, '.types'), + 'NonPositiveFloat': (__package__, '.types'), + 'FiniteFloat': (__package__, '.types'), + 'condecimal': (__package__, '.types'), + 'condate': (__package__, '.types'), + 'UUID1': (__package__, '.types'), + 'UUID3': (__package__, '.types'), + 'UUID4': (__package__, '.types'), + 'UUID5': (__package__, '.types'), + 'FilePath': (__package__, '.types'), + 'DirectoryPath': (__package__, '.types'), + 'NewPath': (__package__, '.types'), + 'Json': (__package__, '.types'), + 'SecretStr': (__package__, '.types'), + 'SecretBytes': (__package__, '.types'), + 'StrictBool': (__package__, '.types'), + 'StrictBytes': (__package__, '.types'), + 'StrictInt': (__package__, '.types'), + 'StrictFloat': (__package__, '.types'), + 'PaymentCardNumber': (__package__, '.types'), + 'ByteSize': (__package__, '.types'), + 'PastDate': (__package__, '.types'), + 'FutureDate': (__package__, '.types'), + 'PastDatetime': (__package__, '.types'), + 'FutureDatetime': (__package__, '.types'), + 'AwareDatetime': (__package__, '.types'), + 'NaiveDatetime': (__package__, '.types'), + 'AllowInfNan': (__package__, '.types'), + 'EncoderProtocol': (__package__, '.types'), + 'EncodedBytes': (__package__, '.types'), + 'EncodedStr': (__package__, '.types'), + 'Base64Encoder': (__package__, '.types'), + 'Base64Bytes': (__package__, '.types'), + 'Base64Str': (__package__, '.types'), + 'Base64UrlBytes': (__package__, '.types'), + 'Base64UrlStr': (__package__, '.types'), + 'GetPydanticSchema': (__package__, '.types'), + 'Tag': (__package__, '.types'), + 'Discriminator': (__package__, '.types'), + 'JsonValue': (__package__, '.types'), + 'OnErrorOmit': (__package__, '.types'), + # type_adapter + 'TypeAdapter': (__package__, '.type_adapter'), + # warnings + 'PydanticDeprecatedSince20': (__package__, '.warnings'), + 'PydanticDeprecatedSince26': (__package__, '.warnings'), + 'PydanticDeprecationWarning': (__package__, '.warnings'), + # annotated handlers + 'GetCoreSchemaHandler': (__package__, '.annotated_handlers'), + 
'GetJsonSchemaHandler': (__package__, '.annotated_handlers'), + # generate schema from ._internal + 'GenerateSchema': (__package__, '._internal._generate_schema'), + # pydantic_core stuff + 'ValidationError': ('pydantic_core', '.'), + 'ValidationInfo': ('pydantic_core', '.core_schema'), + 'SerializationInfo': ('pydantic_core', '.core_schema'), + 'ValidatorFunctionWrapHandler': ('pydantic_core', '.core_schema'), + 'FieldSerializationInfo': ('pydantic_core', '.core_schema'), + 'SerializerFunctionWrapHandler': ('pydantic_core', '.core_schema'), + # deprecated, mostly not included in __all__ + 'root_validator': (__package__, '.deprecated.class_validators'), + 'validator': (__package__, '.deprecated.class_validators'), + 'BaseConfig': (__package__, '.deprecated.config'), + 'Extra': (__package__, '.deprecated.config'), + 'parse_obj_as': (__package__, '.deprecated.tools'), + 'schema_of': (__package__, '.deprecated.tools'), + 'schema_json_of': (__package__, '.deprecated.tools'), + 'FieldValidationInfo': ('pydantic_core', '.core_schema'), +} + +_getattr_migration = getattr_migration(__name__) + + +def __getattr__(attr_name: str) -> object: + dynamic_attr = _dynamic_imports.get(attr_name) + if dynamic_attr is None: + return _getattr_migration(attr_name) + + package, module_name = dynamic_attr + + from importlib import import_module + + if module_name == '__module__': + return import_module(f'.{attr_name}', package=package) + else: + module = import_module(module_name, package=package) + return getattr(module, attr_name) + + +def __dir__() -> 'list[str]': + return list(__all__) diff --git a/lib/pydantic/_internal/__init__.py b/lib/pydantic/_internal/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/pydantic/_internal/_config.py b/lib/pydantic/_internal/_config.py new file mode 100644 index 00000000..52c4cc42 --- /dev/null +++ b/lib/pydantic/_internal/_config.py @@ -0,0 +1,322 @@ +from __future__ import annotations as _annotations + +import warnings +from contextlib import contextmanager +from typing import ( + TYPE_CHECKING, + Any, + Callable, + cast, +) + +from pydantic_core import core_schema +from typing_extensions import ( + Literal, + Self, +) + +from ..aliases import AliasGenerator +from ..config import ConfigDict, ExtraValues, JsonDict, JsonEncoder, JsonSchemaExtraCallable +from ..errors import PydanticUserError +from ..warnings import PydanticDeprecatedSince20 + +if not TYPE_CHECKING: + # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915 + # and https://youtrack.jetbrains.com/issue/PY-51428 + DeprecationWarning = PydanticDeprecatedSince20 + +if TYPE_CHECKING: + from .._internal._schema_generation_shared import GenerateSchema + +DEPRECATION_MESSAGE = 'Support for class-based `config` is deprecated, use ConfigDict instead.' + + +class ConfigWrapper: + """Internal wrapper for Config which exposes ConfigDict items as attributes.""" + + __slots__ = ('config_dict',) + + config_dict: ConfigDict + + # all annotations are copied directly from ConfigDict, and should be kept up to date, a test will fail if they + # stop matching + title: str | None + str_to_lower: bool + str_to_upper: bool + str_strip_whitespace: bool + str_min_length: int + str_max_length: int | None + extra: ExtraValues | None + frozen: bool + populate_by_name: bool + use_enum_values: bool + validate_assignment: bool + arbitrary_types_allowed: bool + from_attributes: bool + # whether to use the actual key provided in the data (e.g. 
alias or first alias for "field required" errors) instead of field_names + # to construct error `loc`s, default `True` + loc_by_alias: bool + alias_generator: Callable[[str], str] | AliasGenerator | None + ignored_types: tuple[type, ...] + allow_inf_nan: bool + json_schema_extra: JsonDict | JsonSchemaExtraCallable | None + json_encoders: dict[type[object], JsonEncoder] | None + + # new in V2 + strict: bool + # whether instances of models and dataclasses (including subclass instances) should re-validate, default 'never' + revalidate_instances: Literal['always', 'never', 'subclass-instances'] + ser_json_timedelta: Literal['iso8601', 'float'] + ser_json_bytes: Literal['utf8', 'base64'] + ser_json_inf_nan: Literal['null', 'constants'] + # whether to validate default values during validation, default False + validate_default: bool + validate_return: bool + protected_namespaces: tuple[str, ...] + hide_input_in_errors: bool + defer_build: bool + plugin_settings: dict[str, object] | None + schema_generator: type[GenerateSchema] | None + json_schema_serialization_defaults_required: bool + json_schema_mode_override: Literal['validation', 'serialization', None] + coerce_numbers_to_str: bool + regex_engine: Literal['rust-regex', 'python-re'] + validation_error_cause: bool + + def __init__(self, config: ConfigDict | dict[str, Any] | type[Any] | None, *, check: bool = True): + if check: + self.config_dict = prepare_config(config) + else: + self.config_dict = cast(ConfigDict, config) + + @classmethod + def for_model(cls, bases: tuple[type[Any], ...], namespace: dict[str, Any], kwargs: dict[str, Any]) -> Self: + """Build a new `ConfigWrapper` instance for a `BaseModel`. + + The config wrapper built based on (in descending order of priority): + - options from `kwargs` + - options from the `namespace` + - options from the base classes (`bases`) + + Args: + bases: A tuple of base classes. + namespace: The namespace of the class being created. + kwargs: The kwargs passed to the class being created. + + Returns: + A `ConfigWrapper` instance for `BaseModel`. + """ + config_new = ConfigDict() + for base in bases: + config = getattr(base, 'model_config', None) + if config: + config_new.update(config.copy()) + + config_class_from_namespace = namespace.get('Config') + config_dict_from_namespace = namespace.get('model_config') + + if config_class_from_namespace and config_dict_from_namespace: + raise PydanticUserError('"Config" and "model_config" cannot be used together', code='config-both') + + config_from_namespace = config_dict_from_namespace or prepare_config(config_class_from_namespace) + + config_new.update(config_from_namespace) + + for k in list(kwargs.keys()): + if k in config_keys: + config_new[k] = kwargs.pop(k) + + return cls(config_new) + + # we don't show `__getattr__` to type checkers so missing attributes cause errors + if not TYPE_CHECKING: # pragma: no branch + + def __getattr__(self, name: str) -> Any: + try: + return self.config_dict[name] + except KeyError: + try: + return config_defaults[name] + except KeyError: + raise AttributeError(f'Config has no attribute {name!r}') from None + + def core_config(self, obj: Any) -> core_schema.CoreConfig: + """Create a pydantic-core config, `obj` is just used to populate `title` if not set in config. + + Pass `obj=None` if you do not want to attempt to infer the `title`. + + We don't use getattr here since we don't want to populate with defaults. + + Args: + obj: An object used to populate `title` if not set in config. 
+ + Returns: + A `CoreConfig` object created from config. + """ + + def dict_not_none(**kwargs: Any) -> Any: + return {k: v for k, v in kwargs.items() if v is not None} + + core_config = core_schema.CoreConfig( + **dict_not_none( + title=self.config_dict.get('title') or (obj and obj.__name__), + extra_fields_behavior=self.config_dict.get('extra'), + allow_inf_nan=self.config_dict.get('allow_inf_nan'), + populate_by_name=self.config_dict.get('populate_by_name'), + str_strip_whitespace=self.config_dict.get('str_strip_whitespace'), + str_to_lower=self.config_dict.get('str_to_lower'), + str_to_upper=self.config_dict.get('str_to_upper'), + strict=self.config_dict.get('strict'), + ser_json_timedelta=self.config_dict.get('ser_json_timedelta'), + ser_json_bytes=self.config_dict.get('ser_json_bytes'), + ser_json_inf_nan=self.config_dict.get('ser_json_inf_nan'), + from_attributes=self.config_dict.get('from_attributes'), + loc_by_alias=self.config_dict.get('loc_by_alias'), + revalidate_instances=self.config_dict.get('revalidate_instances'), + validate_default=self.config_dict.get('validate_default'), + str_max_length=self.config_dict.get('str_max_length'), + str_min_length=self.config_dict.get('str_min_length'), + hide_input_in_errors=self.config_dict.get('hide_input_in_errors'), + coerce_numbers_to_str=self.config_dict.get('coerce_numbers_to_str'), + regex_engine=self.config_dict.get('regex_engine'), + validation_error_cause=self.config_dict.get('validation_error_cause'), + ) + ) + return core_config + + def __repr__(self): + c = ', '.join(f'{k}={v!r}' for k, v in self.config_dict.items()) + return f'ConfigWrapper({c})' + + +class ConfigWrapperStack: + """A stack of `ConfigWrapper` instances.""" + + def __init__(self, config_wrapper: ConfigWrapper): + self._config_wrapper_stack: list[ConfigWrapper] = [config_wrapper] + + @property + def tail(self) -> ConfigWrapper: + return self._config_wrapper_stack[-1] + + @contextmanager + def push(self, config_wrapper: ConfigWrapper | ConfigDict | None): + if config_wrapper is None: + yield + return + + if not isinstance(config_wrapper, ConfigWrapper): + config_wrapper = ConfigWrapper(config_wrapper, check=False) + + self._config_wrapper_stack.append(config_wrapper) + try: + yield + finally: + self._config_wrapper_stack.pop() + + +config_defaults = ConfigDict( + title=None, + str_to_lower=False, + str_to_upper=False, + str_strip_whitespace=False, + str_min_length=0, + str_max_length=None, + # let the model / dataclass decide how to handle it + extra=None, + frozen=False, + populate_by_name=False, + use_enum_values=False, + validate_assignment=False, + arbitrary_types_allowed=False, + from_attributes=False, + loc_by_alias=True, + alias_generator=None, + ignored_types=(), + allow_inf_nan=True, + json_schema_extra=None, + strict=False, + revalidate_instances='never', + ser_json_timedelta='iso8601', + ser_json_bytes='utf8', + ser_json_inf_nan='null', + validate_default=False, + validate_return=False, + protected_namespaces=('model_',), + hide_input_in_errors=False, + json_encoders=None, + defer_build=False, + plugin_settings=None, + schema_generator=None, + json_schema_serialization_defaults_required=False, + json_schema_mode_override=None, + coerce_numbers_to_str=False, + regex_engine='rust-regex', + validation_error_cause=False, +) + + +def prepare_config(config: ConfigDict | dict[str, Any] | type[Any] | None) -> ConfigDict: + """Create a `ConfigDict` instance from an existing dict, a class (e.g. old class-based config) or None. 
+ + Args: + config: The input config. + + Returns: + A ConfigDict object created from config. + """ + if config is None: + return ConfigDict() + + if not isinstance(config, dict): + warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning) + config = {k: getattr(config, k) for k in dir(config) if not k.startswith('__')} + + config_dict = cast(ConfigDict, config) + check_deprecated(config_dict) + return config_dict + + +config_keys = set(ConfigDict.__annotations__.keys()) + + +V2_REMOVED_KEYS = { + 'allow_mutation', + 'error_msg_templates', + 'fields', + 'getter_dict', + 'smart_union', + 'underscore_attrs_are_private', + 'json_loads', + 'json_dumps', + 'copy_on_model_validation', + 'post_init_call', +} +V2_RENAMED_KEYS = { + 'allow_population_by_field_name': 'populate_by_name', + 'anystr_lower': 'str_to_lower', + 'anystr_strip_whitespace': 'str_strip_whitespace', + 'anystr_upper': 'str_to_upper', + 'keep_untouched': 'ignored_types', + 'max_anystr_length': 'str_max_length', + 'min_anystr_length': 'str_min_length', + 'orm_mode': 'from_attributes', + 'schema_extra': 'json_schema_extra', + 'validate_all': 'validate_default', +} + + +def check_deprecated(config_dict: ConfigDict) -> None: + """Check for deprecated config keys and warn the user. + + Args: + config_dict: The input config. + """ + deprecated_removed_keys = V2_REMOVED_KEYS & config_dict.keys() + deprecated_renamed_keys = V2_RENAMED_KEYS.keys() & config_dict.keys() + if deprecated_removed_keys or deprecated_renamed_keys: + renamings = {k: V2_RENAMED_KEYS[k] for k in sorted(deprecated_renamed_keys)} + renamed_bullets = [f'* {k!r} has been renamed to {v!r}' for k, v in renamings.items()] + removed_bullets = [f'* {k!r} has been removed' for k in sorted(deprecated_removed_keys)] + message = '\n'.join(['Valid config keys have changed in V2:'] + renamed_bullets + removed_bullets) + warnings.warn(message, UserWarning) diff --git a/lib/pydantic/_internal/_core_metadata.py b/lib/pydantic/_internal/_core_metadata.py new file mode 100644 index 00000000..296d49f5 --- /dev/null +++ b/lib/pydantic/_internal/_core_metadata.py @@ -0,0 +1,92 @@ +from __future__ import annotations as _annotations + +import typing +from typing import Any + +import typing_extensions + +if typing.TYPE_CHECKING: + from ._schema_generation_shared import ( + CoreSchemaOrField as CoreSchemaOrField, + ) + from ._schema_generation_shared import ( + GetJsonSchemaFunction, + ) + + +class CoreMetadata(typing_extensions.TypedDict, total=False): + """A `TypedDict` for holding the metadata dict of the schema. + + Attributes: + pydantic_js_functions: List of JSON schema functions. + pydantic_js_prefer_positional_arguments: Whether JSON schema generator will + prefer positional over keyword arguments for an 'arguments' schema. + """ + + pydantic_js_functions: list[GetJsonSchemaFunction] + pydantic_js_annotation_functions: list[GetJsonSchemaFunction] + + # If `pydantic_js_prefer_positional_arguments` is True, the JSON schema generator will + # prefer positional over keyword arguments for an 'arguments' schema. + pydantic_js_prefer_positional_arguments: bool | None + + pydantic_typed_dict_cls: type[Any] | None # TODO: Consider moving this into the pydantic-core TypedDictSchema + + +class CoreMetadataHandler: + """Because the metadata field in pydantic_core is of type `Any`, we can't assume much about its contents. + + This class is used to interact with the metadata field on a CoreSchema object in a consistent + way throughout pydantic. 
+ """ + + __slots__ = ('_schema',) + + def __init__(self, schema: CoreSchemaOrField): + self._schema = schema + + metadata = schema.get('metadata') + if metadata is None: + schema['metadata'] = CoreMetadata() + elif not isinstance(metadata, dict): + raise TypeError(f'CoreSchema metadata should be a dict; got {metadata!r}.') + + @property + def metadata(self) -> CoreMetadata: + """Retrieves the metadata dict from the schema, initializing it to a dict if it is None + and raises an error if it is not a dict. + """ + metadata = self._schema.get('metadata') + if metadata is None: + self._schema['metadata'] = metadata = CoreMetadata() + if not isinstance(metadata, dict): + raise TypeError(f'CoreSchema metadata should be a dict; got {metadata!r}.') + return metadata + + +def build_metadata_dict( + *, # force keyword arguments to make it easier to modify this signature in a backwards-compatible way + js_functions: list[GetJsonSchemaFunction] | None = None, + js_annotation_functions: list[GetJsonSchemaFunction] | None = None, + js_prefer_positional_arguments: bool | None = None, + typed_dict_cls: type[Any] | None = None, + initial_metadata: Any | None = None, +) -> Any: + """Builds a dict to use as the metadata field of a CoreSchema object in a manner that is consistent + with the CoreMetadataHandler class. + """ + if initial_metadata is not None and not isinstance(initial_metadata, dict): + raise TypeError(f'CoreSchema metadata should be a dict; got {initial_metadata!r}.') + + metadata = CoreMetadata( + pydantic_js_functions=js_functions or [], + pydantic_js_annotation_functions=js_annotation_functions or [], + pydantic_js_prefer_positional_arguments=js_prefer_positional_arguments, + pydantic_typed_dict_cls=typed_dict_cls, + ) + metadata = {k: v for k, v in metadata.items() if v is not None} + + if initial_metadata is not None: + metadata = {**initial_metadata, **metadata} + + return metadata diff --git a/lib/pydantic/_internal/_core_utils.py b/lib/pydantic/_internal/_core_utils.py new file mode 100644 index 00000000..e74c74ac --- /dev/null +++ b/lib/pydantic/_internal/_core_utils.py @@ -0,0 +1,570 @@ +from __future__ import annotations + +import os +from collections import defaultdict +from typing import ( + Any, + Callable, + Hashable, + TypeVar, + Union, +) + +from pydantic_core import CoreSchema, core_schema +from pydantic_core import validate_core_schema as _validate_core_schema +from typing_extensions import TypeAliasType, TypeGuard, get_args, get_origin + +from . 
import _repr
+from ._typing_extra import is_generic_alias
+
+AnyFunctionSchema = Union[
+    core_schema.AfterValidatorFunctionSchema,
+    core_schema.BeforeValidatorFunctionSchema,
+    core_schema.WrapValidatorFunctionSchema,
+    core_schema.PlainValidatorFunctionSchema,
+]
+
+
+FunctionSchemaWithInnerSchema = Union[
+    core_schema.AfterValidatorFunctionSchema,
+    core_schema.BeforeValidatorFunctionSchema,
+    core_schema.WrapValidatorFunctionSchema,
+]
+
+CoreSchemaField = Union[
+    core_schema.ModelField, core_schema.DataclassField, core_schema.TypedDictField, core_schema.ComputedField
+]
+CoreSchemaOrField = Union[core_schema.CoreSchema, CoreSchemaField]
+
+_CORE_SCHEMA_FIELD_TYPES = {'typed-dict-field', 'dataclass-field', 'model-field', 'computed-field'}
+_FUNCTION_WITH_INNER_SCHEMA_TYPES = {'function-before', 'function-after', 'function-wrap'}
+_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'set', 'frozenset'}
+
+_DEFINITIONS_CACHE_METADATA_KEY = 'pydantic.definitions_cache'
+
+TAGGED_UNION_TAG_KEY = 'pydantic.internal.tagged_union_tag'
+"""
+Used in a `Tag` schema to specify the tag used for a discriminated union.
+"""
+HAS_INVALID_SCHEMAS_METADATA_KEY = 'pydantic.internal.invalid'
+"""Used to mark a schema that is invalid because it refers to a definition that was not yet defined when the
+schema was first encountered.
+"""
+
+
+def is_core_schema(
+    schema: CoreSchemaOrField,
+) -> TypeGuard[CoreSchema]:
+    return schema['type'] not in _CORE_SCHEMA_FIELD_TYPES
+
+
+def is_core_schema_field(
+    schema: CoreSchemaOrField,
+) -> TypeGuard[CoreSchemaField]:
+    return schema['type'] in _CORE_SCHEMA_FIELD_TYPES
+
+
+def is_function_with_inner_schema(
+    schema: CoreSchemaOrField,
+) -> TypeGuard[FunctionSchemaWithInnerSchema]:
+    return schema['type'] in _FUNCTION_WITH_INNER_SCHEMA_TYPES
+
+
+def is_list_like_schema_with_items_schema(
+    schema: CoreSchema,
+) -> TypeGuard[core_schema.ListSchema | core_schema.SetSchema | core_schema.FrozenSetSchema]:
+    return schema['type'] in _LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES
+
+
+def get_type_ref(type_: type[Any], args_override: tuple[type[Any], ...] | None = None) -> str:
+    """Produces the ref to be used for this type by pydantic_core's core schemas.
+
+    This `args_override` argument was added for the purpose of creating valid recursive references
+    when creating generic models without needing to create a concrete class.
+    """
+    origin = get_origin(type_) or type_
+
+    args = get_args(type_) if is_generic_alias(type_) else (args_override or ())
+    generic_metadata = getattr(type_, '__pydantic_generic_metadata__', None)
+    if generic_metadata:
+        origin = generic_metadata['origin'] or origin
+        args = generic_metadata['args'] or args
+
+    module_name = getattr(origin, '__module__', '<No __module__>')
+    if isinstance(origin, TypeAliasType):
+        type_ref = f'{module_name}.{origin.__name__}:{id(origin)}'
+    else:
+        try:
+            qualname = getattr(origin, '__qualname__', f'<No __qualname__: {origin}>')
+        except Exception:
+            qualname = getattr(origin, '__qualname__', '<No __qualname__>')
+        type_ref = f'{module_name}.{qualname}:{id(origin)}'
+
+    arg_refs: list[str] = []
+    for arg in args:
+        if isinstance(arg, str):
+            # Handle string literals as a special case; we may be able to remove this special handling if we
+            # wrap them in a ForwardRef at some point.
+ arg_ref = f'{arg}:str-{id(arg)}' + else: + arg_ref = f'{_repr.display_as_type(arg)}:{id(arg)}' + arg_refs.append(arg_ref) + if arg_refs: + type_ref = f'{type_ref}[{",".join(arg_refs)}]' + return type_ref + + +def get_ref(s: core_schema.CoreSchema) -> None | str: + """Get the ref from the schema if it has one. + This exists just for type checking to work correctly. + """ + return s.get('ref', None) + + +def collect_definitions(schema: core_schema.CoreSchema) -> dict[str, core_schema.CoreSchema]: + defs: dict[str, CoreSchema] = {} + + def _record_valid_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema: + ref = get_ref(s) + if ref: + defs[ref] = s + return recurse(s, _record_valid_refs) + + walk_core_schema(schema, _record_valid_refs) + + return defs + + +def define_expected_missing_refs( + schema: core_schema.CoreSchema, allowed_missing_refs: set[str] +) -> core_schema.CoreSchema | None: + if not allowed_missing_refs: + # in this case, there are no missing refs to potentially substitute, so there's no need to walk the schema + # this is a common case (will be hit for all non-generic models), so it's worth optimizing for + return None + + refs = collect_definitions(schema).keys() + + expected_missing_refs = allowed_missing_refs.difference(refs) + if expected_missing_refs: + definitions: list[core_schema.CoreSchema] = [ + # TODO: Replace this with a (new) CoreSchema that, if present at any level, makes validation fail + # Issue: https://github.com/pydantic/pydantic-core/issues/619 + core_schema.none_schema(ref=ref, metadata={HAS_INVALID_SCHEMAS_METADATA_KEY: True}) + for ref in expected_missing_refs + ] + return core_schema.definitions_schema(schema, definitions) + return None + + +def collect_invalid_schemas(schema: core_schema.CoreSchema) -> bool: + invalid = False + + def _is_schema_valid(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema: + nonlocal invalid + if 'metadata' in s: + metadata = s['metadata'] + if HAS_INVALID_SCHEMAS_METADATA_KEY in metadata: + invalid = metadata[HAS_INVALID_SCHEMAS_METADATA_KEY] + return s + return recurse(s, _is_schema_valid) + + walk_core_schema(schema, _is_schema_valid) + return invalid + + +T = TypeVar('T') + + +Recurse = Callable[[core_schema.CoreSchema, 'Walk'], core_schema.CoreSchema] +Walk = Callable[[core_schema.CoreSchema, Recurse], core_schema.CoreSchema] + +# TODO: Should we move _WalkCoreSchema into pydantic_core proper? 
+# Issue: https://github.com/pydantic/pydantic-core/issues/615 + + +class _WalkCoreSchema: + def __init__(self): + self._schema_type_to_method = self._build_schema_type_to_method() + + def _build_schema_type_to_method(self) -> dict[core_schema.CoreSchemaType, Recurse]: + mapping: dict[core_schema.CoreSchemaType, Recurse] = {} + key: core_schema.CoreSchemaType + for key in get_args(core_schema.CoreSchemaType): + method_name = f"handle_{key.replace('-', '_')}_schema" + mapping[key] = getattr(self, method_name, self._handle_other_schemas) + return mapping + + def walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema: + return f(schema, self._walk) + + def _walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema: + schema = self._schema_type_to_method[schema['type']](schema.copy(), f) + ser_schema: core_schema.SerSchema | None = schema.get('serialization') # type: ignore + if ser_schema: + schema['serialization'] = self._handle_ser_schemas(ser_schema, f) + return schema + + def _handle_other_schemas(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema: + sub_schema = schema.get('schema', None) + if sub_schema is not None: + schema['schema'] = self.walk(sub_schema, f) # type: ignore + return schema + + def _handle_ser_schemas(self, ser_schema: core_schema.SerSchema, f: Walk) -> core_schema.SerSchema: + schema: core_schema.CoreSchema | None = ser_schema.get('schema', None) + if schema is not None: + ser_schema['schema'] = self.walk(schema, f) # type: ignore + return_schema: core_schema.CoreSchema | None = ser_schema.get('return_schema', None) + if return_schema is not None: + ser_schema['return_schema'] = self.walk(return_schema, f) # type: ignore + return ser_schema + + def handle_definitions_schema(self, schema: core_schema.DefinitionsSchema, f: Walk) -> core_schema.CoreSchema: + new_definitions: list[core_schema.CoreSchema] = [] + for definition in schema['definitions']: + if 'schema_ref' in definition and 'ref' in definition: + # This indicates a purposely indirect reference + # We want to keep such references around for implications related to JSON schema, etc.: + new_definitions.append(definition) + # However, we still need to walk the referenced definition: + self.walk(definition, f) + continue + + updated_definition = self.walk(definition, f) + if 'ref' in updated_definition: + # If the updated definition schema doesn't have a 'ref', it shouldn't go in the definitions + # This is most likely to happen due to replacing something with a definition reference, in + # which case it should certainly not go in the definitions list + new_definitions.append(updated_definition) + new_inner_schema = self.walk(schema['schema'], f) + + if not new_definitions and len(schema) == 3: + # This means we'd be returning a "trivial" definitions schema that just wrapped the inner schema + return new_inner_schema + + new_schema = schema.copy() + new_schema['schema'] = new_inner_schema + new_schema['definitions'] = new_definitions + return new_schema + + def handle_list_schema(self, schema: core_schema.ListSchema, f: Walk) -> core_schema.CoreSchema: + items_schema = schema.get('items_schema') + if items_schema is not None: + schema['items_schema'] = self.walk(items_schema, f) + return schema + + def handle_set_schema(self, schema: core_schema.SetSchema, f: Walk) -> core_schema.CoreSchema: + items_schema = schema.get('items_schema') + if items_schema is not None: + schema['items_schema'] = self.walk(items_schema, f) + return schema + + def 
handle_frozenset_schema(self, schema: core_schema.FrozenSetSchema, f: Walk) -> core_schema.CoreSchema: + items_schema = schema.get('items_schema') + if items_schema is not None: + schema['items_schema'] = self.walk(items_schema, f) + return schema + + def handle_generator_schema(self, schema: core_schema.GeneratorSchema, f: Walk) -> core_schema.CoreSchema: + items_schema = schema.get('items_schema') + if items_schema is not None: + schema['items_schema'] = self.walk(items_schema, f) + return schema + + def handle_tuple_schema(self, schema: core_schema.TupleSchema, f: Walk) -> core_schema.CoreSchema: + schema['items_schema'] = [self.walk(v, f) for v in schema['items_schema']] + return schema + + def handle_dict_schema(self, schema: core_schema.DictSchema, f: Walk) -> core_schema.CoreSchema: + keys_schema = schema.get('keys_schema') + if keys_schema is not None: + schema['keys_schema'] = self.walk(keys_schema, f) + values_schema = schema.get('values_schema') + if values_schema: + schema['values_schema'] = self.walk(values_schema, f) + return schema + + def handle_function_schema(self, schema: AnyFunctionSchema, f: Walk) -> core_schema.CoreSchema: + if not is_function_with_inner_schema(schema): + return schema + schema['schema'] = self.walk(schema['schema'], f) + return schema + + def handle_union_schema(self, schema: core_schema.UnionSchema, f: Walk) -> core_schema.CoreSchema: + new_choices: list[CoreSchema | tuple[CoreSchema, str]] = [] + for v in schema['choices']: + if isinstance(v, tuple): + new_choices.append((self.walk(v[0], f), v[1])) + else: + new_choices.append(self.walk(v, f)) + schema['choices'] = new_choices + return schema + + def handle_tagged_union_schema(self, schema: core_schema.TaggedUnionSchema, f: Walk) -> core_schema.CoreSchema: + new_choices: dict[Hashable, core_schema.CoreSchema] = {} + for k, v in schema['choices'].items(): + new_choices[k] = v if isinstance(v, (str, int)) else self.walk(v, f) + schema['choices'] = new_choices + return schema + + def handle_chain_schema(self, schema: core_schema.ChainSchema, f: Walk) -> core_schema.CoreSchema: + schema['steps'] = [self.walk(v, f) for v in schema['steps']] + return schema + + def handle_lax_or_strict_schema(self, schema: core_schema.LaxOrStrictSchema, f: Walk) -> core_schema.CoreSchema: + schema['lax_schema'] = self.walk(schema['lax_schema'], f) + schema['strict_schema'] = self.walk(schema['strict_schema'], f) + return schema + + def handle_json_or_python_schema(self, schema: core_schema.JsonOrPythonSchema, f: Walk) -> core_schema.CoreSchema: + schema['json_schema'] = self.walk(schema['json_schema'], f) + schema['python_schema'] = self.walk(schema['python_schema'], f) + return schema + + def handle_model_fields_schema(self, schema: core_schema.ModelFieldsSchema, f: Walk) -> core_schema.CoreSchema: + extras_schema = schema.get('extras_schema') + if extras_schema is not None: + schema['extras_schema'] = self.walk(extras_schema, f) + replaced_fields: dict[str, core_schema.ModelField] = {} + replaced_computed_fields: list[core_schema.ComputedField] = [] + for computed_field in schema.get('computed_fields', ()): + replaced_field = computed_field.copy() + replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f) + replaced_computed_fields.append(replaced_field) + if replaced_computed_fields: + schema['computed_fields'] = replaced_computed_fields + for k, v in schema['fields'].items(): + replaced_field = v.copy() + replaced_field['schema'] = self.walk(v['schema'], f) + replaced_fields[k] = replaced_field 
+ schema['fields'] = replaced_fields + return schema + + def handle_typed_dict_schema(self, schema: core_schema.TypedDictSchema, f: Walk) -> core_schema.CoreSchema: + extras_schema = schema.get('extras_schema') + if extras_schema is not None: + schema['extras_schema'] = self.walk(extras_schema, f) + replaced_computed_fields: list[core_schema.ComputedField] = [] + for computed_field in schema.get('computed_fields', ()): + replaced_field = computed_field.copy() + replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f) + replaced_computed_fields.append(replaced_field) + if replaced_computed_fields: + schema['computed_fields'] = replaced_computed_fields + replaced_fields: dict[str, core_schema.TypedDictField] = {} + for k, v in schema['fields'].items(): + replaced_field = v.copy() + replaced_field['schema'] = self.walk(v['schema'], f) + replaced_fields[k] = replaced_field + schema['fields'] = replaced_fields + return schema + + def handle_dataclass_args_schema(self, schema: core_schema.DataclassArgsSchema, f: Walk) -> core_schema.CoreSchema: + replaced_fields: list[core_schema.DataclassField] = [] + replaced_computed_fields: list[core_schema.ComputedField] = [] + for computed_field in schema.get('computed_fields', ()): + replaced_field = computed_field.copy() + replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f) + replaced_computed_fields.append(replaced_field) + if replaced_computed_fields: + schema['computed_fields'] = replaced_computed_fields + for field in schema['fields']: + replaced_field = field.copy() + replaced_field['schema'] = self.walk(field['schema'], f) + replaced_fields.append(replaced_field) + schema['fields'] = replaced_fields + return schema + + def handle_arguments_schema(self, schema: core_schema.ArgumentsSchema, f: Walk) -> core_schema.CoreSchema: + replaced_arguments_schema: list[core_schema.ArgumentsParameter] = [] + for param in schema['arguments_schema']: + replaced_param = param.copy() + replaced_param['schema'] = self.walk(param['schema'], f) + replaced_arguments_schema.append(replaced_param) + schema['arguments_schema'] = replaced_arguments_schema + if 'var_args_schema' in schema: + schema['var_args_schema'] = self.walk(schema['var_args_schema'], f) + if 'var_kwargs_schema' in schema: + schema['var_kwargs_schema'] = self.walk(schema['var_kwargs_schema'], f) + return schema + + def handle_call_schema(self, schema: core_schema.CallSchema, f: Walk) -> core_schema.CoreSchema: + schema['arguments_schema'] = self.walk(schema['arguments_schema'], f) + if 'return_schema' in schema: + schema['return_schema'] = self.walk(schema['return_schema'], f) + return schema + + +_dispatch = _WalkCoreSchema().walk + + +def walk_core_schema(schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema: + """Recursively traverse a CoreSchema. + + Args: + schema (core_schema.CoreSchema): The CoreSchema to process, it will not be modified. + f (Walk): A function to apply. This function takes two arguments: + 1. The current CoreSchema that is being processed + (not the same one you passed into this function, one level down). + 2. The "next" `f` to call. This lets you for example use `f=functools.partial(some_method, some_context)` + to pass data down the recursive calls without using globals or other mutable state. + + Returns: + core_schema.CoreSchema: A processed CoreSchema. 
+ """ + return f(schema.copy(), _dispatch) + + +def simplify_schema_references(schema: core_schema.CoreSchema) -> core_schema.CoreSchema: # noqa: C901 + definitions: dict[str, core_schema.CoreSchema] = {} + ref_counts: dict[str, int] = defaultdict(int) + involved_in_recursion: dict[str, bool] = {} + current_recursion_ref_count: dict[str, int] = defaultdict(int) + + def collect_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema: + if s['type'] == 'definitions': + for definition in s['definitions']: + ref = get_ref(definition) + assert ref is not None + if ref not in definitions: + definitions[ref] = definition + recurse(definition, collect_refs) + return recurse(s['schema'], collect_refs) + else: + ref = get_ref(s) + if ref is not None: + new = recurse(s, collect_refs) + new_ref = get_ref(new) + if new_ref: + definitions[new_ref] = new + return core_schema.definition_reference_schema(schema_ref=ref) + else: + return recurse(s, collect_refs) + + schema = walk_core_schema(schema, collect_refs) + + def count_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema: + if s['type'] != 'definition-ref': + return recurse(s, count_refs) + ref = s['schema_ref'] + ref_counts[ref] += 1 + + if ref_counts[ref] >= 2: + # If this model is involved in a recursion this should be detected + # on its second encounter, we can safely stop the walk here. + if current_recursion_ref_count[ref] != 0: + involved_in_recursion[ref] = True + return s + + current_recursion_ref_count[ref] += 1 + recurse(definitions[ref], count_refs) + current_recursion_ref_count[ref] -= 1 + return s + + schema = walk_core_schema(schema, count_refs) + + assert all(c == 0 for c in current_recursion_ref_count.values()), 'this is a bug! please report it' + + def can_be_inlined(s: core_schema.DefinitionReferenceSchema, ref: str) -> bool: + if ref_counts[ref] > 1: + return False + if involved_in_recursion.get(ref, False): + return False + if 'serialization' in s: + return False + if 'metadata' in s: + metadata = s['metadata'] + for k in ( + 'pydantic_js_functions', + 'pydantic_js_annotation_functions', + 'pydantic.internal.union_discriminator', + ): + if k in metadata: + # we need to keep this as a ref + return False + return True + + def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema: + if s['type'] == 'definition-ref': + ref = s['schema_ref'] + # Check if the reference is only used once, not involved in recursion and does not have + # any extra keys (like 'serialization') + if can_be_inlined(s, ref): + # Inline the reference by replacing the reference with the actual schema + new = definitions.pop(ref) + ref_counts[ref] -= 1 # because we just replaced it! 
+ # put all other keys that were on the def-ref schema into the inlined version + # in particular this is needed for `serialization` + if 'serialization' in s: + new['serialization'] = s['serialization'] + s = recurse(new, inline_refs) + return s + else: + return recurse(s, inline_refs) + else: + return recurse(s, inline_refs) + + schema = walk_core_schema(schema, inline_refs) + + def_values = [v for v in definitions.values() if ref_counts[v['ref']] > 0] # type: ignore + + if def_values: + schema = core_schema.definitions_schema(schema=schema, definitions=def_values) + return schema + + +def _strip_metadata(schema: CoreSchema) -> CoreSchema: + def strip_metadata(s: CoreSchema, recurse: Recurse) -> CoreSchema: + s = s.copy() + s.pop('metadata', None) + if s['type'] == 'model-fields': + s = s.copy() + s['fields'] = {k: v.copy() for k, v in s['fields'].items()} + for field_name, field_schema in s['fields'].items(): + field_schema.pop('metadata', None) + s['fields'][field_name] = field_schema + computed_fields = s.get('computed_fields', None) + if computed_fields: + s['computed_fields'] = [cf.copy() for cf in computed_fields] + for cf in computed_fields: + cf.pop('metadata', None) + else: + s.pop('computed_fields', None) + elif s['type'] == 'model': + # remove some defaults + if s.get('custom_init', True) is False: + s.pop('custom_init') + if s.get('root_model', True) is False: + s.pop('root_model') + if {'title'}.issuperset(s.get('config', {}).keys()): + s.pop('config', None) + + return recurse(s, strip_metadata) + + return walk_core_schema(schema, strip_metadata) + + +def pretty_print_core_schema( + schema: CoreSchema, + include_metadata: bool = False, +) -> None: + """Pretty print a CoreSchema using rich. + This is intended for debugging purposes. + + Args: + schema: The CoreSchema to print. + include_metadata: Whether to include metadata in the output. Defaults to `False`. + """ + from rich import print # type: ignore # install it manually in your dev env + + if not include_metadata: + schema = _strip_metadata(schema) + + return print(schema) + + +def validate_core_schema(schema: CoreSchema) -> CoreSchema: + if 'PYDANTIC_SKIP_VALIDATING_CORE_SCHEMAS' in os.environ: + return schema + return _validate_core_schema(schema) diff --git a/lib/pydantic/_internal/_dataclasses.py b/lib/pydantic/_internal/_dataclasses.py new file mode 100644 index 00000000..1ec23044 --- /dev/null +++ b/lib/pydantic/_internal/_dataclasses.py @@ -0,0 +1,225 @@ +"""Private logic for creating pydantic dataclasses.""" +from __future__ import annotations as _annotations + +import dataclasses +import typing +import warnings +from functools import partial, wraps +from typing import Any, Callable, ClassVar + +from pydantic_core import ( + ArgsKwargs, + SchemaSerializer, + SchemaValidator, + core_schema, +) +from typing_extensions import TypeGuard + +from ..errors import PydanticUndefinedAnnotation +from ..fields import FieldInfo +from ..plugin._schema_validator import create_schema_validator +from ..warnings import PydanticDeprecatedSince20 +from . 
import _config, _decorators, _typing_extra +from ._fields import collect_dataclass_fields +from ._generate_schema import GenerateSchema +from ._generics import get_standard_typevars_map +from ._mock_val_ser import set_dataclass_mocks +from ._schema_generation_shared import CallbackGetCoreSchemaHandler +from ._signature import generate_pydantic_signature + +if typing.TYPE_CHECKING: + from ..config import ConfigDict + + class StandardDataclass(typing.Protocol): + __dataclass_fields__: ClassVar[dict[str, Any]] + __dataclass_params__: ClassVar[Any] # in reality `dataclasses._DataclassParams` + __post_init__: ClassVar[Callable[..., None]] + + def __init__(self, *args: object, **kwargs: object) -> None: + pass + + class PydanticDataclass(StandardDataclass, typing.Protocol): + """A protocol containing attributes only available once a class has been decorated as a Pydantic dataclass. + + Attributes: + __pydantic_config__: Pydantic-specific configuration settings for the dataclass. + __pydantic_complete__: Whether dataclass building is completed, or if there are still undefined fields. + __pydantic_core_schema__: The pydantic-core schema used to build the SchemaValidator and SchemaSerializer. + __pydantic_decorators__: Metadata containing the decorators defined on the dataclass. + __pydantic_fields__: Metadata about the fields defined on the dataclass. + __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the dataclass. + __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the dataclass. + """ + + __pydantic_config__: ClassVar[ConfigDict] + __pydantic_complete__: ClassVar[bool] + __pydantic_core_schema__: ClassVar[core_schema.CoreSchema] + __pydantic_decorators__: ClassVar[_decorators.DecoratorInfos] + __pydantic_fields__: ClassVar[dict[str, FieldInfo]] + __pydantic_serializer__: ClassVar[SchemaSerializer] + __pydantic_validator__: ClassVar[SchemaValidator] + +else: + # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915 + # and https://youtrack.jetbrains.com/issue/PY-51428 + DeprecationWarning = PydanticDeprecatedSince20 + + +def set_dataclass_fields(cls: type[StandardDataclass], types_namespace: dict[str, Any] | None = None) -> None: + """Collect and set `cls.__pydantic_fields__`. + + Args: + cls: The class. + types_namespace: The types namespace, defaults to `None`. + """ + typevars_map = get_standard_typevars_map(cls) + fields = collect_dataclass_fields(cls, types_namespace, typevars_map=typevars_map) + + cls.__pydantic_fields__ = fields # type: ignore + + +def complete_dataclass( + cls: type[Any], + config_wrapper: _config.ConfigWrapper, + *, + raise_errors: bool = True, + types_namespace: dict[str, Any] | None, +) -> bool: + """Finish building a pydantic dataclass. + + This logic is called on a class which has already been wrapped in `dataclasses.dataclass()`. + + This is somewhat analogous to `pydantic._internal._model_construction.complete_model_class`. + + Args: + cls: The class. + config_wrapper: The config wrapper instance. + raise_errors: Whether to raise errors, defaults to `True`. + types_namespace: The types namespace. + + Returns: + `True` if building a pydantic dataclass is successfully completed, `False` otherwise. + + Raises: + PydanticUndefinedAnnotation: If `raise_error` is `True` and there is an undefined annotations. 
+ """ + if hasattr(cls, '__post_init_post_parse__'): + warnings.warn( + 'Support for `__post_init_post_parse__` has been dropped, the method will not be called', DeprecationWarning + ) + + if types_namespace is None: + types_namespace = _typing_extra.get_cls_types_namespace(cls) + + set_dataclass_fields(cls, types_namespace) + + typevars_map = get_standard_typevars_map(cls) + gen_schema = GenerateSchema( + config_wrapper, + types_namespace, + typevars_map, + ) + + # This needs to be called before we change the __init__ + sig = generate_pydantic_signature( + init=cls.__init__, + fields=cls.__pydantic_fields__, # type: ignore + config_wrapper=config_wrapper, + is_dataclass=True, + ) + + # dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied. + def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -> None: + __tracebackhide__ = True + s = __dataclass_self__ + s.__pydantic_validator__.validate_python(ArgsKwargs(args, kwargs), self_instance=s) + + __init__.__qualname__ = f'{cls.__qualname__}.__init__' + + cls.__init__ = __init__ # type: ignore + cls.__pydantic_config__ = config_wrapper.config_dict # type: ignore + cls.__signature__ = sig # type: ignore + get_core_schema = getattr(cls, '__get_pydantic_core_schema__', None) + try: + if get_core_schema: + schema = get_core_schema( + cls, + CallbackGetCoreSchemaHandler( + partial(gen_schema.generate_schema, from_dunder_get_core_schema=False), + gen_schema, + ref_mode='unpack', + ), + ) + else: + schema = gen_schema.generate_schema(cls, from_dunder_get_core_schema=False) + except PydanticUndefinedAnnotation as e: + if raise_errors: + raise + set_dataclass_mocks(cls, cls.__name__, f'`{e.name}`') + return False + + core_config = config_wrapper.core_config(cls) + + try: + schema = gen_schema.clean_schema(schema) + except gen_schema.CollectedInvalid: + set_dataclass_mocks(cls, cls.__name__, 'all referenced types') + return False + + # We are about to set all the remaining required properties expected for this cast; + # __pydantic_decorators__ and __pydantic_fields__ should already be set + cls = typing.cast('type[PydanticDataclass]', cls) + # debug(schema) + + cls.__pydantic_core_schema__ = schema + cls.__pydantic_validator__ = validator = create_schema_validator( + schema, cls, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings + ) + cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config) + + if config_wrapper.validate_assignment: + + @wraps(cls.__setattr__) + def validated_setattr(instance: Any, __field: str, __value: str) -> None: + validator.validate_assignment(instance, __field, __value) + + cls.__setattr__ = validated_setattr.__get__(None, cls) # type: ignore + + return True + + +def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]: + """Returns True if a class is a stdlib dataclass and *not* a pydantic dataclass. + + We check that + - `_cls` is a dataclass + - `_cls` does not inherit from a processed pydantic dataclass (and thus have a `__pydantic_validator__`) + - `_cls` does not have any annotations that are not dataclass fields + e.g. + ```py + import dataclasses + + import pydantic.dataclasses + + @dataclasses.dataclass + class A: + x: int + + @pydantic.dataclasses.dataclass + class B(A): + y: int + ``` + In this case, when we first check `B`, we make an extra check and look at the annotations ('y'), + which won't be a superset of all the dataclass fields (only the stdlib fields i.e. 
'x') + + Args: + cls: The class. + + Returns: + `True` if the class is a stdlib dataclass, `False` otherwise. + """ + return ( + dataclasses.is_dataclass(_cls) + and not hasattr(_cls, '__pydantic_validator__') + and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {}))) + ) diff --git a/lib/pydantic/_internal/_decorators.py b/lib/pydantic/_internal/_decorators.py new file mode 100644 index 00000000..5672464c --- /dev/null +++ b/lib/pydantic/_internal/_decorators.py @@ -0,0 +1,791 @@ +"""Logic related to validators applied to models etc. via the `@field_validator` and `@model_validator` decorators.""" +from __future__ import annotations as _annotations + +from collections import deque +from dataclasses import dataclass, field +from functools import cached_property, partial, partialmethod +from inspect import Parameter, Signature, isdatadescriptor, ismethoddescriptor, signature +from itertools import islice +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Iterable, TypeVar, Union + +from pydantic_core import PydanticUndefined, core_schema +from typing_extensions import Literal, TypeAlias, is_typeddict + +from ..errors import PydanticUserError +from ._core_utils import get_type_ref +from ._internal_dataclass import slots_true +from ._typing_extra import get_function_type_hints + +if TYPE_CHECKING: + from ..fields import ComputedFieldInfo + from ..functional_validators import FieldValidatorModes + + +@dataclass(**slots_true) +class ValidatorDecoratorInfo: + """A container for data from `@validator` so that we can access it + while building the pydantic-core schema. + + Attributes: + decorator_repr: A class variable representing the decorator string, '@validator'. + fields: A tuple of field names the validator should be called on. + mode: The proposed validator mode. + each_item: For complex objects (sets, lists etc.) whether to validate individual + elements rather than the whole object. + always: Whether this method and other validators should be called even if the value is missing. + check_fields: Whether to check that the fields actually exist on the model. + """ + + decorator_repr: ClassVar[str] = '@validator' + + fields: tuple[str, ...] + mode: Literal['before', 'after'] + each_item: bool + always: bool + check_fields: bool | None + + +@dataclass(**slots_true) +class FieldValidatorDecoratorInfo: + """A container for data from `@field_validator` so that we can access it + while building the pydantic-core schema. + + Attributes: + decorator_repr: A class variable representing the decorator string, '@field_validator'. + fields: A tuple of field names the validator should be called on. + mode: The proposed validator mode. + check_fields: Whether to check that the fields actually exist on the model. + """ + + decorator_repr: ClassVar[str] = '@field_validator' + + fields: tuple[str, ...] + mode: FieldValidatorModes + check_fields: bool | None + + +@dataclass(**slots_true) +class RootValidatorDecoratorInfo: + """A container for data from `@root_validator` so that we can access it + while building the pydantic-core schema. + + Attributes: + decorator_repr: A class variable representing the decorator string, '@root_validator'. + mode: The proposed validator mode. + """ + + decorator_repr: ClassVar[str] = '@root_validator' + mode: Literal['before', 'after'] + + +@dataclass(**slots_true) +class FieldSerializerDecoratorInfo: + """A container for data from `@field_serializer` so that we can access it + while building the pydantic-core schema. 
+ + Attributes: + decorator_repr: A class variable representing the decorator string, '@field_serializer'. + fields: A tuple of field names the serializer should be called on. + mode: The proposed serializer mode. + return_type: The type of the serializer's return value. + when_used: The serialization condition. Accepts a string with values `'always'`, `'unless-none'`, `'json'`, + and `'json-unless-none'`. + check_fields: Whether to check that the fields actually exist on the model. + """ + + decorator_repr: ClassVar[str] = '@field_serializer' + fields: tuple[str, ...] + mode: Literal['plain', 'wrap'] + return_type: Any + when_used: core_schema.WhenUsed + check_fields: bool | None + + +@dataclass(**slots_true) +class ModelSerializerDecoratorInfo: + """A container for data from `@model_serializer` so that we can access it + while building the pydantic-core schema. + + Attributes: + decorator_repr: A class variable representing the decorator string, '@model_serializer'. + mode: The proposed serializer mode. + return_type: The type of the serializer's return value. + when_used: The serialization condition. Accepts a string with values `'always'`, `'unless-none'`, `'json'`, + and `'json-unless-none'`. + """ + + decorator_repr: ClassVar[str] = '@model_serializer' + mode: Literal['plain', 'wrap'] + return_type: Any + when_used: core_schema.WhenUsed + + +@dataclass(**slots_true) +class ModelValidatorDecoratorInfo: + """A container for data from `@model_validator` so that we can access it + while building the pydantic-core schema. + + Attributes: + decorator_repr: A class variable representing the decorator string, '@model_serializer'. + mode: The proposed serializer mode. + """ + + decorator_repr: ClassVar[str] = '@model_validator' + mode: Literal['wrap', 'before', 'after'] + + +DecoratorInfo: TypeAlias = """Union[ + ValidatorDecoratorInfo, + FieldValidatorDecoratorInfo, + RootValidatorDecoratorInfo, + FieldSerializerDecoratorInfo, + ModelSerializerDecoratorInfo, + ModelValidatorDecoratorInfo, + ComputedFieldInfo, +]""" + +ReturnType = TypeVar('ReturnType') +DecoratedType: TypeAlias = ( + 'Union[classmethod[Any, Any, ReturnType], staticmethod[Any, ReturnType], Callable[..., ReturnType], property]' +) + + +@dataclass # can't use slots here since we set attributes on `__post_init__` +class PydanticDescriptorProxy(Generic[ReturnType]): + """Wrap a classmethod, staticmethod, property or unbound function + and act as a descriptor that allows us to detect decorated items + from the class' attributes. + + This class' __get__ returns the wrapped item's __get__ result, + which makes it transparent for classmethods and staticmethods. + + Attributes: + wrapped: The decorator that has to be wrapped. + decorator_info: The decorator info. + shim: A wrapper function to wrap V1 style function. + """ + + wrapped: DecoratedType[ReturnType] + decorator_info: DecoratorInfo + shim: Callable[[Callable[..., Any]], Callable[..., Any]] | None = None + + def __post_init__(self): + for attr in 'setter', 'deleter': + if hasattr(self.wrapped, attr): + f = partial(self._call_wrapped_attr, name=attr) + setattr(self, attr, f) + + def _call_wrapped_attr(self, func: Callable[[Any], None], *, name: str) -> PydanticDescriptorProxy[ReturnType]: + self.wrapped = getattr(self.wrapped, name)(func) + return self + + def __get__(self, obj: object | None, obj_type: type[object] | None = None) -> PydanticDescriptorProxy[ReturnType]: + try: + return self.wrapped.__get__(obj, obj_type) + except AttributeError: + # not a descriptor, e.g. 
a partial object + return self.wrapped # type: ignore[return-value] + + def __set_name__(self, instance: Any, name: str) -> None: + if hasattr(self.wrapped, '__set_name__'): + self.wrapped.__set_name__(instance, name) # pyright: ignore[reportFunctionMemberAccess] + + def __getattr__(self, __name: str) -> Any: + """Forward checks for __isabstractmethod__ and such.""" + return getattr(self.wrapped, __name) + + +DecoratorInfoType = TypeVar('DecoratorInfoType', bound=DecoratorInfo) + + +@dataclass(**slots_true) +class Decorator(Generic[DecoratorInfoType]): + """A generic container class to join together the decorator metadata + (metadata from decorator itself, which we have when the + decorator is called but not when we are building the core-schema) + and the bound function (which we have after the class itself is created). + + Attributes: + cls_ref: The class ref. + cls_var_name: The decorated function name. + func: The decorated function. + shim: A wrapper function to wrap V1 style function. + info: The decorator info. + """ + + cls_ref: str + cls_var_name: str + func: Callable[..., Any] + shim: Callable[[Any], Any] | None + info: DecoratorInfoType + + @staticmethod + def build( + cls_: Any, + *, + cls_var_name: str, + shim: Callable[[Any], Any] | None, + info: DecoratorInfoType, + ) -> Decorator[DecoratorInfoType]: + """Build a new decorator. + + Args: + cls_: The class. + cls_var_name: The decorated function name. + shim: A wrapper function to wrap V1 style function. + info: The decorator info. + + Returns: + The new decorator instance. + """ + func = get_attribute_from_bases(cls_, cls_var_name) + if shim is not None: + func = shim(func) + func = unwrap_wrapped_function(func, unwrap_partial=False) + if not callable(func): + # This branch will get hit for classmethod properties + attribute = get_attribute_from_base_dicts(cls_, cls_var_name) # prevents the binding call to `__get__` + if isinstance(attribute, PydanticDescriptorProxy): + func = unwrap_wrapped_function(attribute.wrapped) + return Decorator( + cls_ref=get_type_ref(cls_), + cls_var_name=cls_var_name, + func=func, + shim=shim, + info=info, + ) + + def bind_to_cls(self, cls: Any) -> Decorator[DecoratorInfoType]: + """Bind the decorator to a class. + + Args: + cls: the class. + + Returns: + The new decorator instance. + """ + return self.build( + cls, + cls_var_name=self.cls_var_name, + shim=self.shim, + info=self.info, + ) + + +def get_bases(tp: type[Any]) -> tuple[type[Any], ...]: + """Get the base classes of a class or typeddict. + + Args: + tp: The type or class to get the bases. + + Returns: + The base classes. + """ + if is_typeddict(tp): + return tp.__orig_bases__ # type: ignore + try: + return tp.__bases__ + except AttributeError: + return () + + +def mro(tp: type[Any]) -> tuple[type[Any], ...]: + """Calculate the Method Resolution Order of bases using the C3 algorithm. + + See https://www.python.org/download/releases/2.3/mro/ + """ + # try to use the existing mro, for performance mainly + # but also because it helps verify the implementation below + if not is_typeddict(tp): + try: + return tp.__mro__ + except AttributeError: + # GenericAlias and some other cases + pass + + bases = get_bases(tp) + return (tp,) + mro_for_bases(bases) + + +def mro_for_bases(bases: tuple[type[Any], ...]) -> tuple[type[Any], ...]: + def merge_seqs(seqs: list[deque[type[Any]]]) -> Iterable[type[Any]]: + while True: + non_empty = [seq for seq in seqs if seq] + if not non_empty: + # Nothing left to process, we're done. 
+ return + candidate: type[Any] | None = None + for seq in non_empty: # Find merge candidates among seq heads. + candidate = seq[0] + not_head = [s for s in non_empty if candidate in islice(s, 1, None)] + if not_head: + # Reject the candidate. + candidate = None + else: + break + if not candidate: + raise TypeError('Inconsistent hierarchy, no C3 MRO is possible') + yield candidate + for seq in non_empty: + # Remove candidate. + if seq[0] == candidate: + seq.popleft() + + seqs = [deque(mro(base)) for base in bases] + [deque(bases)] + return tuple(merge_seqs(seqs)) + + +_sentinel = object() + + +def get_attribute_from_bases(tp: type[Any] | tuple[type[Any], ...], name: str) -> Any: + """Get the attribute from the next class in the MRO that has it, + aiming to simulate calling the method on the actual class. + + The reason for iterating over the mro instead of just getting + the attribute (which would do that for us) is to support TypedDict, + which lacks a real __mro__, but can have a virtual one constructed + from its bases (as done here). + + Args: + tp: The type or class to search for the attribute. If a tuple, this is treated as a set of base classes. + name: The name of the attribute to retrieve. + + Returns: + Any: The attribute value, if found. + + Raises: + AttributeError: If the attribute is not found in any class in the MRO. + """ + if isinstance(tp, tuple): + for base in mro_for_bases(tp): + attribute = base.__dict__.get(name, _sentinel) + if attribute is not _sentinel: + attribute_get = getattr(attribute, '__get__', None) + if attribute_get is not None: + return attribute_get(None, tp) + return attribute + raise AttributeError(f'{name} not found in {tp}') + else: + try: + return getattr(tp, name) + except AttributeError: + return get_attribute_from_bases(mro(tp), name) + + +def get_attribute_from_base_dicts(tp: type[Any], name: str) -> Any: + """Get an attribute out of the `__dict__` following the MRO. + This prevents the call to `__get__` on the descriptor, and allows + us to get the original function for classmethod properties. + + Args: + tp: The type or class to search for the attribute. + name: The name of the attribute to retrieve. + + Returns: + Any: The attribute value, if found. + + Raises: + KeyError: If the attribute is not found in any class's `__dict__` in the MRO. + """ + for base in reversed(mro(tp)): + if name in base.__dict__: + return base.__dict__[name] + return tp.__dict__[name] # raise the error + + +@dataclass(**slots_true) +class DecoratorInfos: + """Mapping of name in the class namespace to decorator info. + + note that the name in the class namespace is the function or attribute name + not the field name! 
+ """ + + validators: dict[str, Decorator[ValidatorDecoratorInfo]] = field(default_factory=dict) + field_validators: dict[str, Decorator[FieldValidatorDecoratorInfo]] = field(default_factory=dict) + root_validators: dict[str, Decorator[RootValidatorDecoratorInfo]] = field(default_factory=dict) + field_serializers: dict[str, Decorator[FieldSerializerDecoratorInfo]] = field(default_factory=dict) + model_serializers: dict[str, Decorator[ModelSerializerDecoratorInfo]] = field(default_factory=dict) + model_validators: dict[str, Decorator[ModelValidatorDecoratorInfo]] = field(default_factory=dict) + computed_fields: dict[str, Decorator[ComputedFieldInfo]] = field(default_factory=dict) + + @staticmethod + def build(model_dc: type[Any]) -> DecoratorInfos: # noqa: C901 (ignore complexity) + """We want to collect all DecFunc instances that exist as + attributes in the namespace of the class (a BaseModel or dataclass) + that called us + But we want to collect these in the order of the bases + So instead of getting them all from the leaf class (the class that called us), + we traverse the bases from root (the oldest ancestor class) to leaf + and collect all of the instances as we go, taking care to replace + any duplicate ones with the last one we see to mimic how function overriding + works with inheritance. + If we do replace any functions we put the replacement into the position + the replaced function was in; that is, we maintain the order. + """ + # reminder: dicts are ordered and replacement does not alter the order + res = DecoratorInfos() + for base in reversed(mro(model_dc)[1:]): + existing: DecoratorInfos | None = base.__dict__.get('__pydantic_decorators__') + if existing is None: + existing = DecoratorInfos.build(base) + res.validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.validators.items()}) + res.field_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_validators.items()}) + res.root_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.root_validators.items()}) + res.field_serializers.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_serializers.items()}) + res.model_serializers.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_serializers.items()}) + res.model_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_validators.items()}) + res.computed_fields.update({k: v.bind_to_cls(model_dc) for k, v in existing.computed_fields.items()}) + + to_replace: list[tuple[str, Any]] = [] + + for var_name, var_value in vars(model_dc).items(): + if isinstance(var_value, PydanticDescriptorProxy): + info = var_value.decorator_info + if isinstance(info, ValidatorDecoratorInfo): + res.validators[var_name] = Decorator.build( + model_dc, cls_var_name=var_name, shim=var_value.shim, info=info + ) + elif isinstance(info, FieldValidatorDecoratorInfo): + res.field_validators[var_name] = Decorator.build( + model_dc, cls_var_name=var_name, shim=var_value.shim, info=info + ) + elif isinstance(info, RootValidatorDecoratorInfo): + res.root_validators[var_name] = Decorator.build( + model_dc, cls_var_name=var_name, shim=var_value.shim, info=info + ) + elif isinstance(info, FieldSerializerDecoratorInfo): + # check whether a serializer function is already registered for fields + for field_serializer_decorator in res.field_serializers.values(): + # check that each field has at most one serializer function. 
+ # serializer functions for the same field in subclasses are allowed, + # and are treated as overrides + if field_serializer_decorator.cls_var_name == var_name: + continue + for f in info.fields: + if f in field_serializer_decorator.info.fields: + raise PydanticUserError( + 'Multiple field serializer functions were defined ' + f'for field {f!r}, this is not allowed.', + code='multiple-field-serializers', + ) + res.field_serializers[var_name] = Decorator.build( + model_dc, cls_var_name=var_name, shim=var_value.shim, info=info + ) + elif isinstance(info, ModelValidatorDecoratorInfo): + res.model_validators[var_name] = Decorator.build( + model_dc, cls_var_name=var_name, shim=var_value.shim, info=info + ) + elif isinstance(info, ModelSerializerDecoratorInfo): + res.model_serializers[var_name] = Decorator.build( + model_dc, cls_var_name=var_name, shim=var_value.shim, info=info + ) + else: + from ..fields import ComputedFieldInfo + + isinstance(var_value, ComputedFieldInfo) + res.computed_fields[var_name] = Decorator.build( + model_dc, cls_var_name=var_name, shim=None, info=info + ) + to_replace.append((var_name, var_value.wrapped)) + if to_replace: + # If we can save `__pydantic_decorators__` on the class we'll be able to check for it above + # so then we don't need to re-process the type, which means we can discard our descriptor wrappers + # and replace them with the thing they are wrapping (see the other setattr call below) + # which allows validator class methods to also function as regular class methods + setattr(model_dc, '__pydantic_decorators__', res) + for name, value in to_replace: + setattr(model_dc, name, value) + return res + + +def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes) -> bool: + """Look at a field or model validator function and determine whether it takes an info argument. + + An error is raised if the function has an invalid signature. + + Args: + validator: The validator function to inspect. + mode: The proposed validator mode. + + Returns: + Whether the validator takes an info argument. + """ + try: + sig = signature(validator) + except ValueError: + # builtins and some C extensions don't have signatures + # assume that they don't take an info argument and only take a single argument + # e.g. `str.strip` or `datetime.datetime` + return False + n_positional = count_positional_params(sig) + if mode == 'wrap': + if n_positional == 3: + return True + elif n_positional == 2: + return False + else: + assert mode in {'before', 'after', 'plain'}, f"invalid mode: {mode!r}, expected 'before', 'after' or 'plain" + if n_positional == 2: + return True + elif n_positional == 1: + return False + + raise PydanticUserError( + f'Unrecognized field_validator function signature for {validator} with `mode={mode}`:{sig}', + code='validator-signature', + ) + + +def inspect_field_serializer( + serializer: Callable[..., Any], mode: Literal['plain', 'wrap'], computed_field: bool = False +) -> tuple[bool, bool]: + """Look at a field serializer function and determine if it is a field serializer, + and whether it takes an info argument. + + An error is raised if the function has an invalid signature. + + Args: + serializer: The serializer function to inspect. + mode: The serializer mode, either 'plain' or 'wrap'. + computed_field: When serializer is applied on computed_field. It doesn't require + info signature. + + Returns: + Tuple of (is_field_serializer, info_arg). 
+ """ + sig = signature(serializer) + + first = next(iter(sig.parameters.values()), None) + is_field_serializer = first is not None and first.name == 'self' + + n_positional = count_positional_params(sig) + if is_field_serializer: + # -1 to correct for self parameter + info_arg = _serializer_info_arg(mode, n_positional - 1) + else: + info_arg = _serializer_info_arg(mode, n_positional) + + if info_arg is None: + raise PydanticUserError( + f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}', + code='field-serializer-signature', + ) + if info_arg and computed_field: + raise PydanticUserError( + 'field_serializer on computed_field does not use info signature', code='field-serializer-signature' + ) + + else: + return is_field_serializer, info_arg + + +def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool: + """Look at a serializer function used via `Annotated` and determine whether it takes an info argument. + + An error is raised if the function has an invalid signature. + + Args: + serializer: The serializer function to check. + mode: The serializer mode, either 'plain' or 'wrap'. + + Returns: + info_arg + """ + sig = signature(serializer) + info_arg = _serializer_info_arg(mode, count_positional_params(sig)) + if info_arg is None: + raise PydanticUserError( + f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}', + code='field-serializer-signature', + ) + else: + return info_arg + + +def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool: + """Look at a model serializer function and determine whether it takes an info argument. + + An error is raised if the function has an invalid signature. + + Args: + serializer: The serializer function to check. + mode: The serializer mode, either 'plain' or 'wrap'. + + Returns: + `info_arg` - whether the function expects an info argument. + """ + if isinstance(serializer, (staticmethod, classmethod)) or not is_instance_method_from_sig(serializer): + raise PydanticUserError( + '`@model_serializer` must be applied to instance methods', code='model-serializer-instance-method' + ) + + sig = signature(serializer) + info_arg = _serializer_info_arg(mode, count_positional_params(sig)) + if info_arg is None: + raise PydanticUserError( + f'Unrecognized model_serializer function signature for {serializer} with `mode={mode}`:{sig}', + code='model-serializer-signature', + ) + else: + return info_arg + + +def _serializer_info_arg(mode: Literal['plain', 'wrap'], n_positional: int) -> bool | None: + if mode == 'plain': + if n_positional == 1: + # (__input_value: Any) -> Any + return False + elif n_positional == 2: + # (__model: Any, __input_value: Any) -> Any + return True + else: + assert mode == 'wrap', f"invalid mode: {mode!r}, expected 'plain' or 'wrap'" + if n_positional == 2: + # (__input_value: Any, __serializer: SerializerFunctionWrapHandler) -> Any + return False + elif n_positional == 3: + # (__input_value: Any, __serializer: SerializerFunctionWrapHandler, __info: SerializationInfo) -> Any + return True + + return None + + +AnyDecoratorCallable: TypeAlias = ( + 'Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any], Callable[..., Any]]' +) + + +def is_instance_method_from_sig(function: AnyDecoratorCallable) -> bool: + """Whether the function is an instance method. + + It will consider a function as instance method if the first parameter of + function is `self`. 
+ + Args: + function: The function to check. + + Returns: + `True` if the function is an instance method, `False` otherwise. + """ + sig = signature(unwrap_wrapped_function(function)) + first = next(iter(sig.parameters.values()), None) + if first and first.name == 'self': + return True + return False + + +def ensure_classmethod_based_on_signature(function: AnyDecoratorCallable) -> Any: + """Apply the `@classmethod` decorator on the function. + + Args: + function: The function to apply the decorator on. + + Return: + The `@classmethod` decorator applied function. + """ + if not isinstance( + unwrap_wrapped_function(function, unwrap_class_static_method=False), classmethod + ) and _is_classmethod_from_sig(function): + return classmethod(function) # type: ignore[arg-type] + return function + + +def _is_classmethod_from_sig(function: AnyDecoratorCallable) -> bool: + sig = signature(unwrap_wrapped_function(function)) + first = next(iter(sig.parameters.values()), None) + if first and first.name == 'cls': + return True + return False + + +def unwrap_wrapped_function( + func: Any, + *, + unwrap_partial: bool = True, + unwrap_class_static_method: bool = True, +) -> Any: + """Recursively unwraps a wrapped function until the underlying function is reached. + This handles property, functools.partial, functools.partialmethod, staticmethod and classmethod. + + Args: + func: The function to unwrap. + unwrap_partial: If True (default), unwrap partial and partialmethod decorators, otherwise don't. + decorators. + unwrap_class_static_method: If True (default), also unwrap classmethod and staticmethod + decorators. If False, only unwrap partial and partialmethod decorators. + + Returns: + The underlying function of the wrapped function. + """ + all: set[Any] = {property, cached_property} + + if unwrap_partial: + all.update({partial, partialmethod}) + + if unwrap_class_static_method: + all.update({staticmethod, classmethod}) + + while isinstance(func, tuple(all)): + if unwrap_class_static_method and isinstance(func, (classmethod, staticmethod)): + func = func.__func__ + elif isinstance(func, (partial, partialmethod)): + func = func.func + elif isinstance(func, property): + func = func.fget # arbitrary choice, convenient for computed fields + else: + # Make coverage happy as it can only get here in the last possible case + assert isinstance(func, cached_property) + func = func.func # type: ignore + + return func + + +def get_function_return_type( + func: Any, explicit_return_type: Any, types_namespace: dict[str, Any] | None = None +) -> Any: + """Get the function return type. + + It gets the return type from the type annotation if `explicit_return_type` is `None`. + Otherwise, it returns `explicit_return_type`. + + Args: + func: The function to get its return type. + explicit_return_type: The explicit return type. + types_namespace: The types namespace, defaults to `None`. + + Returns: + The function return type. 
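# Illustrative sketch, not part of the vendored pydantic code: a trimmed-down version of
# the unwrapping loop above, peeling `classmethod`/`staticmethod`/`functools.partial`
# layers until the plain function underneath is reached.
from functools import partial

def _unwrap(func):
    while True:
        if isinstance(func, (classmethod, staticmethod)):
            func = func.__func__
        elif isinstance(func, partial):
            func = func.func
        else:
            return func

def greet(name: str) -> str:
    return f'hello {name}'

assert _unwrap(classmethod(partial(greet, 'x'))) is greet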
+ """ + if explicit_return_type is PydanticUndefined: + # try to get it from the type annotation + hints = get_function_type_hints( + unwrap_wrapped_function(func), include_keys={'return'}, types_namespace=types_namespace + ) + return hints.get('return', PydanticUndefined) + else: + return explicit_return_type + + +def count_positional_params(sig: Signature) -> int: + return sum(1 for param in sig.parameters.values() if can_be_positional(param)) + + +def can_be_positional(param: Parameter) -> bool: + return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD) + + +def ensure_property(f: Any) -> Any: + """Ensure that a function is a `property` or `cached_property`, or is a valid descriptor. + + Args: + f: The function to check. + + Returns: + The function, or a `property` or `cached_property` instance wrapping the function. + """ + if ismethoddescriptor(f) or isdatadescriptor(f): + return f + else: + return property(f) diff --git a/lib/pydantic/_internal/_decorators_v1.py b/lib/pydantic/_internal/_decorators_v1.py new file mode 100644 index 00000000..4f81e6d4 --- /dev/null +++ b/lib/pydantic/_internal/_decorators_v1.py @@ -0,0 +1,181 @@ +"""Logic for V1 validators, e.g. `@validator` and `@root_validator`.""" +from __future__ import annotations as _annotations + +from inspect import Parameter, signature +from typing import Any, Dict, Tuple, Union, cast + +from pydantic_core import core_schema +from typing_extensions import Protocol + +from ..errors import PydanticUserError +from ._decorators import can_be_positional + + +class V1OnlyValueValidator(Protocol): + """A simple validator, supported for V1 validators and V2 validators.""" + + def __call__(self, __value: Any) -> Any: + ... + + +class V1ValidatorWithValues(Protocol): + """A validator with `values` argument, supported for V1 validators and V2 validators.""" + + def __call__(self, __value: Any, values: dict[str, Any]) -> Any: + ... + + +class V1ValidatorWithValuesKwOnly(Protocol): + """A validator with keyword only `values` argument, supported for V1 validators and V2 validators.""" + + def __call__(self, __value: Any, *, values: dict[str, Any]) -> Any: + ... + + +class V1ValidatorWithKwargs(Protocol): + """A validator with `kwargs` argument, supported for V1 validators and V2 validators.""" + + def __call__(self, __value: Any, **kwargs: Any) -> Any: + ... + + +class V1ValidatorWithValuesAndKwargs(Protocol): + """A validator with `values` and `kwargs` arguments, supported for V1 validators and V2 validators.""" + + def __call__(self, __value: Any, values: dict[str, Any], **kwargs: Any) -> Any: + ... + + +V1Validator = Union[ + V1ValidatorWithValues, V1ValidatorWithValuesKwOnly, V1ValidatorWithKwargs, V1ValidatorWithValuesAndKwargs +] + + +def can_be_keyword(param: Parameter) -> bool: + return param.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY) + + +def make_generic_v1_field_validator(validator: V1Validator) -> core_schema.WithInfoValidatorFunction: + """Wrap a V1 style field validator for V2 compatibility. + + Args: + validator: The V1 style field validator. + + Returns: + A wrapped V2 style field validator. + + Raises: + PydanticUserError: If the signature is not supported or the parameters are + not available in Pydantic V2. 
+ """ + sig = signature(validator) + + needs_values_kw = False + + for param_num, (param_name, parameter) in enumerate(sig.parameters.items()): + if can_be_keyword(parameter) and param_name in ('field', 'config'): + raise PydanticUserError( + 'The `field` and `config` parameters are not available in Pydantic V2, ' + 'please use the `info` parameter instead.', + code='validator-field-config-info', + ) + if parameter.kind is Parameter.VAR_KEYWORD: + needs_values_kw = True + elif can_be_keyword(parameter) and param_name == 'values': + needs_values_kw = True + elif can_be_positional(parameter) and param_num == 0: + # value + continue + elif parameter.default is Parameter.empty: # ignore params with defaults e.g. bound by functools.partial + raise PydanticUserError( + f'Unsupported signature for V1 style validator {validator}: {sig} is not supported.', + code='validator-v1-signature', + ) + + if needs_values_kw: + # (v, **kwargs), (v, values, **kwargs), (v, *, values, **kwargs) or (v, *, values) + val1 = cast(V1ValidatorWithValues, validator) + + def wrapper1(value: Any, info: core_schema.ValidationInfo) -> Any: + return val1(value, values=info.data) + + return wrapper1 + else: + val2 = cast(V1OnlyValueValidator, validator) + + def wrapper2(value: Any, _: core_schema.ValidationInfo) -> Any: + return val2(value) + + return wrapper2 + + +RootValidatorValues = Dict[str, Any] +# technically tuple[model_dict, model_extra, fields_set] | tuple[dataclass_dict, init_vars] +RootValidatorFieldsTuple = Tuple[Any, ...] + + +class V1RootValidatorFunction(Protocol): + """A simple root validator, supported for V1 validators and V2 validators.""" + + def __call__(self, __values: RootValidatorValues) -> RootValidatorValues: + ... + + +class V2CoreBeforeRootValidator(Protocol): + """V2 validator with mode='before'.""" + + def __call__(self, __values: RootValidatorValues, __info: core_schema.ValidationInfo) -> RootValidatorValues: + ... + + +class V2CoreAfterRootValidator(Protocol): + """V2 validator with mode='after'.""" + + def __call__( + self, __fields_tuple: RootValidatorFieldsTuple, __info: core_schema.ValidationInfo + ) -> RootValidatorFieldsTuple: + ... + + +def make_v1_generic_root_validator( + validator: V1RootValidatorFunction, pre: bool +) -> V2CoreBeforeRootValidator | V2CoreAfterRootValidator: + """Wrap a V1 style root validator for V2 compatibility. + + Args: + validator: The V1 style field validator. + pre: Whether the validator is a pre validator. + + Returns: + A wrapped V2 style validator. 
+ """ + if pre is True: + # mode='before' for pydantic-core + def _wrapper1(values: RootValidatorValues, _: core_schema.ValidationInfo) -> RootValidatorValues: + return validator(values) + + return _wrapper1 + + # mode='after' for pydantic-core + def _wrapper2(fields_tuple: RootValidatorFieldsTuple, _: core_schema.ValidationInfo) -> RootValidatorFieldsTuple: + if len(fields_tuple) == 2: + # dataclass, this is easy + values, init_vars = fields_tuple + values = validator(values) + return values, init_vars + else: + # ugly hack: to match v1 behaviour, we merge values and model_extra, then split them up based on fields + # afterwards + model_dict, model_extra, fields_set = fields_tuple + if model_extra: + fields = set(model_dict.keys()) + model_dict.update(model_extra) + model_dict_new = validator(model_dict) + for k in list(model_dict_new.keys()): + if k not in fields: + model_extra[k] = model_dict_new.pop(k) + else: + model_dict_new = validator(model_dict) + return model_dict_new, model_extra, fields_set + + return _wrapper2 diff --git a/lib/pydantic/_internal/_discriminated_union.py b/lib/pydantic/_internal/_discriminated_union.py new file mode 100644 index 00000000..c40117d3 --- /dev/null +++ b/lib/pydantic/_internal/_discriminated_union.py @@ -0,0 +1,506 @@ +from __future__ import annotations as _annotations + +from typing import TYPE_CHECKING, Any, Hashable, Sequence + +from pydantic_core import CoreSchema, core_schema + +from ..errors import PydanticUserError +from . import _core_utils +from ._core_utils import ( + CoreSchemaField, + collect_definitions, + simplify_schema_references, +) + +if TYPE_CHECKING: + from ..types import Discriminator + +CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator' + + +class MissingDefinitionForUnionRef(Exception): + """Raised when applying a discriminated union discriminator to a schema + requires a definition that is not yet defined + """ + + def __init__(self, ref: str) -> None: + self.ref = ref + super().__init__(f'Missing definition for ref {self.ref!r}') + + +def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None: + schema.setdefault('metadata', {}) + metadata = schema.get('metadata') + assert metadata is not None + metadata[CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY] = discriminator + + +def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSchema: + definitions: dict[str, CoreSchema] | None = None + + def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema: + nonlocal definitions + + s = recurse(s, inner) + if s['type'] == 'tagged-union': + return s + + metadata = s.get('metadata', {}) + discriminator = metadata.pop(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None) + if discriminator is not None: + if definitions is None: + definitions = collect_definitions(schema) + s = apply_discriminator(s, discriminator, definitions) + return s + + return simplify_schema_references(_core_utils.walk_core_schema(schema, inner)) + + +def apply_discriminator( + schema: core_schema.CoreSchema, + discriminator: str | Discriminator, + definitions: dict[str, core_schema.CoreSchema] | None = None, +) -> core_schema.CoreSchema: + """Applies the discriminator and returns a new core schema. + + Args: + schema: The input schema. + discriminator: The name of the field which will serve as the discriminator. + definitions: A mapping of schema ref to schema. + + Returns: + The new core schema. 
+ + Raises: + TypeError: + - If `discriminator` is used with invalid union variant. + - If `discriminator` is used with `Union` type with one variant. + - If `discriminator` value mapped to multiple choices. + MissingDefinitionForUnionRef: + If the definition for ref is missing. + PydanticUserError: + - If a model in union doesn't have a discriminator field. + - If discriminator field has a non-string alias. + - If discriminator fields have different aliases. + - If discriminator field not of type `Literal`. + """ + from ..types import Discriminator + + if isinstance(discriminator, Discriminator): + if isinstance(discriminator.discriminator, str): + discriminator = discriminator.discriminator + else: + return discriminator._convert_schema(schema) + + return _ApplyInferredDiscriminator(discriminator, definitions or {}).apply(schema) + + +class _ApplyInferredDiscriminator: + """This class is used to convert an input schema containing a union schema into one where that union is + replaced with a tagged-union, with all the associated debugging and performance benefits. + + This is done by: + * Validating that the input schema is compatible with the provided discriminator + * Introspecting the schema to determine which discriminator values should map to which union choices + * Handling various edge cases such as 'definitions', 'default', 'nullable' schemas, and more + + I have chosen to implement the conversion algorithm in this class, rather than a function, + to make it easier to maintain state while recursively walking the provided CoreSchema. + """ + + def __init__(self, discriminator: str, definitions: dict[str, core_schema.CoreSchema]): + # `discriminator` should be the name of the field which will serve as the discriminator. + # It must be the python name of the field, and *not* the field's alias. Note that as of now, + # all members of a discriminated union _must_ use a field with the same name as the discriminator. + # This may change if/when we expose a way to manually specify the TaggedUnionSchema's choices. + self.discriminator = discriminator + + # `definitions` should contain a mapping of schema ref to schema for all schemas which might + # be referenced by some choice + self.definitions = definitions + + # `_discriminator_alias` will hold the value, if present, of the alias for the discriminator + # + # Note: following the v1 implementation, we currently disallow the use of different aliases + # for different choices. This is not a limitation of pydantic_core, but if we try to handle + # this, the inference logic gets complicated very quickly, and could result in confusing + # debugging challenges for users making subtle mistakes. + # + # Rather than trying to do the most powerful inference possible, I think we should eventually + # expose a way to more-manually control the way the TaggedUnionSchema is constructed through + # the use of a new type which would be placed as an Annotation on the Union type. This would + # provide the full flexibility/power of pydantic_core's TaggedUnionSchema where necessary for + # more complex cases, without over-complicating the inference logic for the common cases. + self._discriminator_alias: str | None = None + + # `_should_be_nullable` indicates whether the converted union has `None` as an allowed value. + # If `None` is an acceptable value of the (possibly-wrapped) union, we ignore it while + # constructing the TaggedUnionSchema, but set the `_should_be_nullable` attribute to True. 
+ # Once we have constructed the TaggedUnionSchema, if `_should_be_nullable` is True, we ensure + # that the final schema gets wrapped as a NullableSchema. This has the same semantics on the + # python side, but resolves the issue that `None` cannot correspond to any discriminator values. + self._should_be_nullable = False + + # `_is_nullable` is used to track if the final produced schema will definitely be nullable; + # we set it to True if the input schema is wrapped in a nullable schema that we know will be preserved + # as an indication that, even if None is discovered as one of the union choices, we will not need to wrap + # the final value in another nullable schema. + # + # This is more complicated than just checking for the final outermost schema having type 'nullable' thanks + # to the possible presence of other wrapper schemas such as DefinitionsSchema, WithDefaultSchema, etc. + self._is_nullable = False + + # `_choices_to_handle` serves as a stack of choices to add to the tagged union. Initially, choices + # from the union in the wrapped schema will be appended to this list, and the recursive choice-handling + # algorithm may add more choices to this stack as (nested) unions are encountered. + self._choices_to_handle: list[core_schema.CoreSchema] = [] + + # `_tagged_union_choices` is built during the call to `apply`, and will hold the choices to be included + # in the output TaggedUnionSchema that will replace the union from the input schema + self._tagged_union_choices: dict[Hashable, core_schema.CoreSchema] = {} + + # `_used` is changed to True after applying the discriminator to prevent accidental re-use + self._used = False + + def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema: + """Return a new CoreSchema based on `schema` that uses a tagged-union with the discriminator provided + to this class. + + Args: + schema: The input schema. + + Returns: + The new core schema. + + Raises: + TypeError: + - If `discriminator` is used with invalid union variant. + - If `discriminator` is used with `Union` type with one variant. + - If `discriminator` value mapped to multiple choices. + ValueError: + If the definition for ref is missing. + PydanticUserError: + - If a model in union doesn't have a discriminator field. + - If discriminator field has a non-string alias. + - If discriminator fields have different aliases. + - If discriminator field not of type `Literal`. + """ + self.definitions.update(collect_definitions(schema)) + assert not self._used + schema = self._apply_to_root(schema) + if self._should_be_nullable and not self._is_nullable: + schema = core_schema.nullable_schema(schema) + self._used = True + new_defs = collect_definitions(schema) + missing_defs = self.definitions.keys() - new_defs.keys() + if missing_defs: + schema = core_schema.definitions_schema(schema, [self.definitions[ref] for ref in missing_defs]) + return schema + + def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema: + """This method handles the outer-most stage of recursion over the input schema: + unwrapping nullable or definitions schemas, and calling the `_handle_choice` + method iteratively on the choices extracted (recursively) from the possibly-wrapped union. 
+ """ + if schema['type'] == 'nullable': + self._is_nullable = True + wrapped = self._apply_to_root(schema['schema']) + nullable_wrapper = schema.copy() + nullable_wrapper['schema'] = wrapped + return nullable_wrapper + + if schema['type'] == 'definitions': + wrapped = self._apply_to_root(schema['schema']) + definitions_wrapper = schema.copy() + definitions_wrapper['schema'] = wrapped + return definitions_wrapper + + if schema['type'] != 'union': + # If the schema is not a union, it probably means it just had a single member and + # was flattened by pydantic_core. + # However, it still may make sense to apply the discriminator to this schema, + # as a way to get discriminated-union-style error messages, so we allow this here. + schema = core_schema.union_schema([schema]) + + # Reverse the choices list before extending the stack so that they get handled in the order they occur + choices_schemas = [v[0] if isinstance(v, tuple) else v for v in schema['choices'][::-1]] + self._choices_to_handle.extend(choices_schemas) + while self._choices_to_handle: + choice = self._choices_to_handle.pop() + self._handle_choice(choice) + + if self._discriminator_alias is not None and self._discriminator_alias != self.discriminator: + # * We need to annotate `discriminator` as a union here to handle both branches of this conditional + # * We need to annotate `discriminator` as list[list[str | int]] and not list[list[str]] due to the + # invariance of list, and because list[list[str | int]] is the type of the discriminator argument + # to tagged_union_schema below + # * See the docstring of pydantic_core.core_schema.tagged_union_schema for more details about how to + # interpret the value of the discriminator argument to tagged_union_schema. (The list[list[str]] here + # is the appropriate way to provide a list of fallback attributes to check for a discriminator value.) + discriminator: str | list[list[str | int]] = [[self.discriminator], [self._discriminator_alias]] + else: + discriminator = self.discriminator + return core_schema.tagged_union_schema( + choices=self._tagged_union_choices, + discriminator=discriminator, + custom_error_type=schema.get('custom_error_type'), + custom_error_message=schema.get('custom_error_message'), + custom_error_context=schema.get('custom_error_context'), + strict=False, + from_attributes=True, + ref=schema.get('ref'), + metadata=schema.get('metadata'), + serialization=schema.get('serialization'), + ) + + def _handle_choice(self, choice: core_schema.CoreSchema) -> None: + """This method handles the "middle" stage of recursion over the input schema. + Specifically, it is responsible for handling each choice of the outermost union + (and any "coalesced" choices obtained from inner unions). + + Here, "handling" entails: + * Coalescing nested unions and compatible tagged-unions + * Tracking the presence of 'none' and 'nullable' schemas occurring as choices + * Validating that each allowed discriminator value maps to a unique choice + * Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema. 
+ """ + if choice['type'] == 'definition-ref': + if choice['schema_ref'] not in self.definitions: + raise MissingDefinitionForUnionRef(choice['schema_ref']) + + if choice['type'] == 'none': + self._should_be_nullable = True + elif choice['type'] == 'definitions': + self._handle_choice(choice['schema']) + elif choice['type'] == 'nullable': + self._should_be_nullable = True + self._handle_choice(choice['schema']) # unwrap the nullable schema + elif choice['type'] == 'union': + # Reverse the choices list before extending the stack so that they get handled in the order they occur + choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]] + self._choices_to_handle.extend(choices_schemas) + elif choice['type'] not in { + 'model', + 'typed-dict', + 'tagged-union', + 'lax-or-strict', + 'dataclass', + 'dataclass-args', + 'definition-ref', + } and not _core_utils.is_function_with_inner_schema(choice): + # We should eventually handle 'definition-ref' as well + raise TypeError( + f'{choice["type"]!r} is not a valid discriminated union variant;' + ' should be a `BaseModel` or `dataclass`' + ) + else: + if choice['type'] == 'tagged-union' and self._is_discriminator_shared(choice): + # In this case, this inner tagged-union is compatible with the outer tagged-union, + # and its choices can be coalesced into the outer TaggedUnionSchema. + subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))] + # Reverse the choices list before extending the stack so that they get handled in the order they occur + self._choices_to_handle.extend(subchoices[::-1]) + return + + inferred_discriminator_values = self._infer_discriminator_values_for_choice(choice, source_name=None) + self._set_unique_choice_for_values(choice, inferred_discriminator_values) + + def _is_discriminator_shared(self, choice: core_schema.TaggedUnionSchema) -> bool: + """This method returns a boolean indicating whether the discriminator for the `choice` + is the same as that being used for the outermost tagged union. This is used to + determine whether this TaggedUnionSchema choice should be "coalesced" into the top level, + or whether it should be treated as a separate (nested) choice. + """ + inner_discriminator = choice['discriminator'] + return inner_discriminator == self.discriminator or ( + isinstance(inner_discriminator, list) + and (self.discriminator in inner_discriminator or [self.discriminator] in inner_discriminator) + ) + + def _infer_discriminator_values_for_choice( # noqa C901 + self, choice: core_schema.CoreSchema, source_name: str | None + ) -> list[str | int]: + """This function recurses over `choice`, extracting all discriminator values that should map to this choice. + + `model_name` is accepted for the purpose of producing useful error messages. 
+ """ + if choice['type'] == 'definitions': + return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name) + elif choice['type'] == 'function-plain': + raise TypeError( + f'{choice["type"]!r} is not a valid discriminated union variant;' + ' should be a `BaseModel` or `dataclass`' + ) + elif _core_utils.is_function_with_inner_schema(choice): + return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name) + elif choice['type'] == 'lax-or-strict': + return sorted( + set( + self._infer_discriminator_values_for_choice(choice['lax_schema'], source_name=None) + + self._infer_discriminator_values_for_choice(choice['strict_schema'], source_name=None) + ) + ) + + elif choice['type'] == 'tagged-union': + values: list[str | int] = [] + # Ignore str/int "choices" since these are just references to other choices + subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))] + for subchoice in subchoices: + subchoice_values = self._infer_discriminator_values_for_choice(subchoice, source_name=None) + values.extend(subchoice_values) + return values + + elif choice['type'] == 'union': + values = [] + for subchoice in choice['choices']: + subchoice_schema = subchoice[0] if isinstance(subchoice, tuple) else subchoice + subchoice_values = self._infer_discriminator_values_for_choice(subchoice_schema, source_name=None) + values.extend(subchoice_values) + return values + + elif choice['type'] == 'nullable': + self._should_be_nullable = True + return self._infer_discriminator_values_for_choice(choice['schema'], source_name=None) + + elif choice['type'] == 'model': + return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__) + + elif choice['type'] == 'dataclass': + return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__) + + elif choice['type'] == 'model-fields': + return self._infer_discriminator_values_for_model_choice(choice, source_name=source_name) + + elif choice['type'] == 'dataclass-args': + return self._infer_discriminator_values_for_dataclass_choice(choice, source_name=source_name) + + elif choice['type'] == 'typed-dict': + return self._infer_discriminator_values_for_typed_dict_choice(choice, source_name=source_name) + + elif choice['type'] == 'definition-ref': + schema_ref = choice['schema_ref'] + if schema_ref not in self.definitions: + raise MissingDefinitionForUnionRef(schema_ref) + return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name) + else: + raise TypeError( + f'{choice["type"]!r} is not a valid discriminated union variant;' + ' should be a `BaseModel` or `dataclass`' + ) + + def _infer_discriminator_values_for_typed_dict_choice( + self, choice: core_schema.TypedDictSchema, source_name: str | None = None + ) -> list[str | int]: + """This method just extracts the _infer_discriminator_values_for_choice logic specific to TypedDictSchema + for the sake of readability. 
+ """ + source = 'TypedDict' if source_name is None else f'TypedDict {source_name!r}' + field = choice['fields'].get(self.discriminator) + if field is None: + raise PydanticUserError( + f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field' + ) + return self._infer_discriminator_values_for_field(field, source) + + def _infer_discriminator_values_for_model_choice( + self, choice: core_schema.ModelFieldsSchema, source_name: str | None = None + ) -> list[str | int]: + source = 'ModelFields' if source_name is None else f'Model {source_name!r}' + field = choice['fields'].get(self.discriminator) + if field is None: + raise PydanticUserError( + f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field' + ) + return self._infer_discriminator_values_for_field(field, source) + + def _infer_discriminator_values_for_dataclass_choice( + self, choice: core_schema.DataclassArgsSchema, source_name: str | None = None + ) -> list[str | int]: + source = 'DataclassArgs' if source_name is None else f'Dataclass {source_name!r}' + for field in choice['fields']: + if field['name'] == self.discriminator: + break + else: + raise PydanticUserError( + f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field' + ) + return self._infer_discriminator_values_for_field(field, source) + + def _infer_discriminator_values_for_field(self, field: CoreSchemaField, source: str) -> list[str | int]: + if field['type'] == 'computed-field': + # This should never occur as a discriminator, as it is only relevant to serialization + return [] + alias = field.get('validation_alias', self.discriminator) + if not isinstance(alias, str): + raise PydanticUserError( + f'Alias {alias!r} is not supported in a discriminated union', code='discriminator-alias-type' + ) + if self._discriminator_alias is None: + self._discriminator_alias = alias + elif self._discriminator_alias != alias: + raise PydanticUserError( + f'Aliases for discriminator {self.discriminator!r} must be the same ' + f'(got {alias}, {self._discriminator_alias})', + code='discriminator-alias', + ) + return self._infer_discriminator_values_for_inner_schema(field['schema'], source) + + def _infer_discriminator_values_for_inner_schema( + self, schema: core_schema.CoreSchema, source: str + ) -> list[str | int]: + """When inferring discriminator values for a field, we typically extract the expected values from a literal + schema. This function does that, but also handles nested unions and defaults. + """ + if schema['type'] == 'literal': + return schema['expected'] + + elif schema['type'] == 'union': + # Generally when multiple values are allowed they should be placed in a single `Literal`, but + # we add this case to handle the situation where a field is annotated as a `Union` of `Literal`s. 
+ # For example, this lets us handle `Union[Literal['key'], Union[Literal['Key'], Literal['KEY']]]` + values: list[Any] = [] + for choice in schema['choices']: + choice_schema = choice[0] if isinstance(choice, tuple) else choice + choice_values = self._infer_discriminator_values_for_inner_schema(choice_schema, source) + values.extend(choice_values) + return values + + elif schema['type'] == 'default': + # This will happen if the field has a default value; we ignore it while extracting the discriminator values + return self._infer_discriminator_values_for_inner_schema(schema['schema'], source) + + elif schema['type'] == 'function-after': + # After validators don't affect the discriminator values + return self._infer_discriminator_values_for_inner_schema(schema['schema'], source) + + elif schema['type'] in {'function-before', 'function-wrap', 'function-plain'}: + validator_type = repr(schema['type'].split('-')[1]) + raise PydanticUserError( + f'Cannot use a mode={validator_type} validator in the' + f' discriminator field {self.discriminator!r} of {source}', + code='discriminator-validator', + ) + + else: + raise PydanticUserError( + f'{source} needs field {self.discriminator!r} to be of type `Literal`', + code='discriminator-needs-literal', + ) + + def _set_unique_choice_for_values(self, choice: core_schema.CoreSchema, values: Sequence[str | int]) -> None: + """This method updates `self.tagged_union_choices` so that all provided (discriminator) `values` map to the + provided `choice`, validating that none of these values already map to another (different) choice. + """ + for discriminator_value in values: + if discriminator_value in self._tagged_union_choices: + # It is okay if `value` is already in tagged_union_choices as long as it maps to the same value. + # Because tagged_union_choices may map values to other values, we need to walk the choices dict + # until we get to a "real" choice, and confirm that is equal to the one assigned. + existing_choice = self._tagged_union_choices[discriminator_value] + if existing_choice != choice: + raise TypeError( + f'Value {discriminator_value!r} for discriminator ' + f'{self.discriminator!r} mapped to multiple choices' + ) + else: + self._tagged_union_choices[discriminator_value] = choice diff --git a/lib/pydantic/_internal/_fields.py b/lib/pydantic/_internal/_fields.py new file mode 100644 index 00000000..94de3062 --- /dev/null +++ b/lib/pydantic/_internal/_fields.py @@ -0,0 +1,319 @@ +"""Private logic related to fields (the `Field()` function and `FieldInfo` class), and arguments to `Annotated`.""" +from __future__ import annotations as _annotations + +import dataclasses +import sys +import warnings +from copy import copy +from functools import lru_cache +from typing import TYPE_CHECKING, Any + +from pydantic_core import PydanticUndefined + +from pydantic.errors import PydanticUserError + +from . import _typing_extra +from ._config import ConfigWrapper +from ._repr import Representation +from ._typing_extra import get_cls_type_hints_lenient, get_type_hints, is_classvar, is_finalvar + +if TYPE_CHECKING: + from annotated_types import BaseMetadata + + from ..fields import FieldInfo + from ..main import BaseModel + from ._dataclasses import StandardDataclass + from ._decorators import DecoratorInfos + + +def get_type_hints_infer_globalns( + obj: Any, + localns: dict[str, Any] | None = None, + include_extras: bool = False, +) -> dict[str, Any]: + """Gets type hints for an object by inferring the global namespace. 
+ + It uses the `typing.get_type_hints`, The only thing that we do here is fetching + global namespace from `obj.__module__` if it is not `None`. + + Args: + obj: The object to get its type hints. + localns: The local namespaces. + include_extras: Whether to recursively include annotation metadata. + + Returns: + The object type hints. + """ + module_name = getattr(obj, '__module__', None) + globalns: dict[str, Any] | None = None + if module_name: + try: + globalns = sys.modules[module_name].__dict__ + except KeyError: + # happens occasionally, see https://github.com/pydantic/pydantic/issues/2363 + pass + return get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras) + + +class PydanticMetadata(Representation): + """Base class for annotation markers like `Strict`.""" + + __slots__ = () + + +def pydantic_general_metadata(**metadata: Any) -> BaseMetadata: + """Create a new `_PydanticGeneralMetadata` class with the given metadata. + + Args: + **metadata: The metadata to add. + + Returns: + The new `_PydanticGeneralMetadata` class. + """ + return _general_metadata_cls()(metadata) # type: ignore + + +@lru_cache(maxsize=None) +def _general_metadata_cls() -> type[BaseMetadata]: + """Do it this way to avoid importing `annotated_types` at import time.""" + from annotated_types import BaseMetadata + + class _PydanticGeneralMetadata(PydanticMetadata, BaseMetadata): + """Pydantic general metadata like `max_digits`.""" + + def __init__(self, metadata: Any): + self.__dict__ = metadata + + return _PydanticGeneralMetadata # type: ignore + + +def collect_model_fields( # noqa: C901 + cls: type[BaseModel], + bases: tuple[type[Any], ...], + config_wrapper: ConfigWrapper, + types_namespace: dict[str, Any] | None, + *, + typevars_map: dict[Any, Any] | None = None, +) -> tuple[dict[str, FieldInfo], set[str]]: + """Collect the fields of a nascent pydantic model. + + Also collect the names of any ClassVars present in the type hints. + + The returned value is a tuple of two items: the fields dict, and the set of ClassVar names. + + Args: + cls: BaseModel or dataclass. + bases: Parents of the class, generally `cls.__bases__`. + config_wrapper: The config wrapper instance. + types_namespace: Optional extra namespace to look for types in. + typevars_map: A dictionary mapping type variables to their concrete types. + + Returns: + A tuple contains fields and class variables. + + Raises: + NameError: + - If there is a conflict between a field name and protected namespaces. + - If there is a field other than `root` in `RootModel`. + - If a field shadows an attribute in the parent model. 
+ """ + from ..fields import FieldInfo + + type_hints = get_cls_type_hints_lenient(cls, types_namespace) + + # https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older + # annotations is only used for finding fields in parent classes + annotations = cls.__dict__.get('__annotations__', {}) + fields: dict[str, FieldInfo] = {} + + class_vars: set[str] = set() + for ann_name, ann_type in type_hints.items(): + if ann_name == 'model_config': + # We never want to treat `model_config` as a field + # Note: we may need to change this logic if/when we introduce a `BareModel` class with no + # protected namespaces (where `model_config` might be allowed as a field name) + continue + for protected_namespace in config_wrapper.protected_namespaces: + if ann_name.startswith(protected_namespace): + for b in bases: + if hasattr(b, ann_name): + from ..main import BaseModel + + if not (issubclass(b, BaseModel) and ann_name in b.model_fields): + raise NameError( + f'Field "{ann_name}" conflicts with member {getattr(b, ann_name)}' + f' of protected namespace "{protected_namespace}".' + ) + else: + valid_namespaces = tuple( + x for x in config_wrapper.protected_namespaces if not ann_name.startswith(x) + ) + warnings.warn( + f'Field "{ann_name}" has conflict with protected namespace "{protected_namespace}".' + '\n\nYou may be able to resolve this warning by setting' + f" `model_config['protected_namespaces'] = {valid_namespaces}`.", + UserWarning, + ) + if is_classvar(ann_type): + class_vars.add(ann_name) + continue + if _is_finalvar_with_default_val(ann_type, getattr(cls, ann_name, PydanticUndefined)): + class_vars.add(ann_name) + continue + if not is_valid_field_name(ann_name): + continue + if cls.__pydantic_root_model__ and ann_name != 'root': + raise NameError( + f"Unexpected field with name {ann_name!r}; only 'root' is allowed as a field of a `RootModel`" + ) + + # when building a generic model with `MyModel[int]`, the generic_origin check makes sure we don't get + # "... shadows an attribute" errors + generic_origin = getattr(cls, '__pydantic_generic_metadata__', {}).get('origin') + for base in bases: + dataclass_fields = { + field.name for field in (dataclasses.fields(base) if dataclasses.is_dataclass(base) else ()) + } + if hasattr(base, ann_name): + if base is generic_origin: + # Don't error when "shadowing" of attributes in parametrized generics + continue + + if ann_name in dataclass_fields: + # Don't error when inheriting stdlib dataclasses whose fields are "shadowed" by defaults being set + # on the class instance. 
+ continue + warnings.warn( + f'Field name "{ann_name}" shadows an attribute in parent "{base.__qualname__}"; ', + UserWarning, + ) + + try: + default = getattr(cls, ann_name, PydanticUndefined) + if default is PydanticUndefined: + raise AttributeError + except AttributeError: + if ann_name in annotations: + field_info = FieldInfo.from_annotation(ann_type) + else: + # if field has no default value and is not in __annotations__ this means that it is + # defined in a base class and we can take it from there + model_fields_lookup: dict[str, FieldInfo] = {} + for x in cls.__bases__[::-1]: + model_fields_lookup.update(getattr(x, 'model_fields', {})) + if ann_name in model_fields_lookup: + # The field was present on one of the (possibly multiple) base classes + # copy the field to make sure typevar substitutions don't cause issues with the base classes + field_info = copy(model_fields_lookup[ann_name]) + else: + # The field was not found on any base classes; this seems to be caused by fields not getting + # generated thanks to models not being fully defined while initializing recursive models. + # Nothing stops us from just creating a new FieldInfo for this type hint, so we do this. + field_info = FieldInfo.from_annotation(ann_type) + else: + field_info = FieldInfo.from_annotated_attribute(ann_type, default) + # attributes which are fields are removed from the class namespace: + # 1. To match the behaviour of annotation-only fields + # 2. To avoid false positives in the NameError check above + try: + delattr(cls, ann_name) + except AttributeError: + pass # indicates the attribute was on a parent class + + # Use cls.__dict__['__pydantic_decorators__'] instead of cls.__pydantic_decorators__ + # to make sure the decorators have already been built for this exact class + decorators: DecoratorInfos = cls.__dict__['__pydantic_decorators__'] + if ann_name in decorators.computed_fields: + raise ValueError("you can't override a field with a computed field") + fields[ann_name] = field_info + + if typevars_map: + for field in fields.values(): + field.apply_typevars_map(typevars_map, types_namespace) + + return fields, class_vars + + +def _is_finalvar_with_default_val(type_: type[Any], val: Any) -> bool: + from ..fields import FieldInfo + + if not is_finalvar(type_): + return False + elif val is PydanticUndefined: + return False + elif isinstance(val, FieldInfo) and (val.default is PydanticUndefined and val.default_factory is None): + return False + else: + return True + + +def collect_dataclass_fields( + cls: type[StandardDataclass], types_namespace: dict[str, Any] | None, *, typevars_map: dict[Any, Any] | None = None +) -> dict[str, FieldInfo]: + """Collect the fields of a dataclass. + + Args: + cls: dataclass. + types_namespace: Optional extra namespace to look for types in. + typevars_map: A dictionary mapping type variables to their concrete types. + + Returns: + The dataclass fields. 
+ """ + from ..fields import FieldInfo + + fields: dict[str, FieldInfo] = {} + dataclass_fields: dict[str, dataclasses.Field] = cls.__dataclass_fields__ + cls_localns = dict(vars(cls)) # this matches get_cls_type_hints_lenient, but all tests pass with `= None` instead + + source_module = sys.modules.get(cls.__module__) + if source_module is not None: + types_namespace = {**source_module.__dict__, **(types_namespace or {})} + + for ann_name, dataclass_field in dataclass_fields.items(): + ann_type = _typing_extra.eval_type_lenient(dataclass_field.type, types_namespace, cls_localns) + if is_classvar(ann_type): + continue + + if ( + not dataclass_field.init + and dataclass_field.default == dataclasses.MISSING + and dataclass_field.default_factory == dataclasses.MISSING + ): + # TODO: We should probably do something with this so that validate_assignment behaves properly + # Issue: https://github.com/pydantic/pydantic/issues/5470 + continue + + if isinstance(dataclass_field.default, FieldInfo): + if dataclass_field.default.init_var: + if dataclass_field.default.init is False: + raise PydanticUserError( + f'Dataclass field {ann_name} has init=False and init_var=True, but these are mutually exclusive.', + code='clashing-init-and-init-var', + ) + + # TODO: same note as above re validate_assignment + continue + field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field.default) + else: + field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field) + + fields[ann_name] = field_info + + if field_info.default is not PydanticUndefined and isinstance(getattr(cls, ann_name, field_info), FieldInfo): + # We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo + setattr(cls, ann_name, field_info.default) + + if typevars_map: + for field in fields.values(): + field.apply_typevars_map(typevars_map, types_namespace) + + return fields + + +def is_valid_field_name(name: str) -> bool: + return not name.startswith('_') + + +def is_valid_privateattr_name(name: str) -> bool: + return name.startswith('_') and not name.startswith('__') diff --git a/lib/pydantic/_internal/_forward_ref.py b/lib/pydantic/_internal/_forward_ref.py new file mode 100644 index 00000000..231f81d1 --- /dev/null +++ b/lib/pydantic/_internal/_forward_ref.py @@ -0,0 +1,23 @@ +from __future__ import annotations as _annotations + +from dataclasses import dataclass +from typing import Union + + +@dataclass +class PydanticRecursiveRef: + type_ref: str + + __name__ = 'PydanticRecursiveRef' + __hash__ = object.__hash__ + + def __call__(self) -> None: + """Defining __call__ is necessary for the `typing` module to let you use an instance of + this class as the result of resolving a standard ForwardRef. 
+ """ + + def __or__(self, other): + return Union[self, other] # type: ignore + + def __ror__(self, other): + return Union[other, self] # type: ignore diff --git a/lib/pydantic/_internal/_generate_schema.py b/lib/pydantic/_internal/_generate_schema.py new file mode 100644 index 00000000..6ab7ec19 --- /dev/null +++ b/lib/pydantic/_internal/_generate_schema.py @@ -0,0 +1,2231 @@ +"""Convert python types to pydantic-core schema.""" +from __future__ import annotations as _annotations + +import collections.abc +import dataclasses +import inspect +import re +import sys +import typing +import warnings +from contextlib import contextmanager +from copy import copy, deepcopy +from enum import Enum +from functools import partial +from inspect import Parameter, _ParameterKind, signature +from itertools import chain +from operator import attrgetter +from types import FunctionType, LambdaType, MethodType +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Final, + ForwardRef, + Iterable, + Iterator, + Mapping, + Type, + TypeVar, + Union, + cast, + overload, +) +from warnings import warn + +from pydantic_core import CoreSchema, PydanticUndefined, core_schema, to_jsonable_python +from typing_extensions import Annotated, Literal, TypeAliasType, TypedDict, get_args, get_origin, is_typeddict + +from ..aliases import AliasGenerator +from ..annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler +from ..config import ConfigDict, JsonDict, JsonEncoder +from ..errors import PydanticSchemaGenerationError, PydanticUndefinedAnnotation, PydanticUserError +from ..json_schema import JsonSchemaValue +from ..version import version_short +from ..warnings import PydanticDeprecatedSince20 +from . import _core_utils, _decorators, _discriminated_union, _known_annotated_metadata, _typing_extra +from ._config import ConfigWrapper, ConfigWrapperStack +from ._core_metadata import CoreMetadataHandler, build_metadata_dict +from ._core_utils import ( + CoreSchemaOrField, + collect_invalid_schemas, + define_expected_missing_refs, + get_ref, + get_type_ref, + is_function_with_inner_schema, + is_list_like_schema_with_items_schema, + simplify_schema_references, + validate_core_schema, +) +from ._decorators import ( + Decorator, + DecoratorInfos, + FieldSerializerDecoratorInfo, + FieldValidatorDecoratorInfo, + ModelSerializerDecoratorInfo, + ModelValidatorDecoratorInfo, + RootValidatorDecoratorInfo, + ValidatorDecoratorInfo, + get_attribute_from_bases, + inspect_field_serializer, + inspect_model_serializer, + inspect_validator, +) +from ._fields import collect_dataclass_fields, get_type_hints_infer_globalns +from ._forward_ref import PydanticRecursiveRef +from ._generics import get_standard_typevars_map, has_instance_in_type, recursively_defined_type_refs, replace_types +from ._schema_generation_shared import ( + CallbackGetCoreSchemaHandler, +) +from ._typing_extra import is_finalvar +from ._utils import lenient_issubclass + +if TYPE_CHECKING: + from ..fields import ComputedFieldInfo, FieldInfo + from ..main import BaseModel + from ..types import Discriminator + from ..validators import FieldValidatorModes + from ._dataclasses import StandardDataclass + from ._schema_generation_shared import GetJsonSchemaFunction + +_SUPPORTS_TYPEDDICT = sys.version_info >= (3, 12) +_AnnotatedType = type(Annotated[int, 123]) + +FieldDecoratorInfo = Union[ValidatorDecoratorInfo, FieldValidatorDecoratorInfo, FieldSerializerDecoratorInfo] +FieldDecoratorInfoType = TypeVar('FieldDecoratorInfoType', bound=FieldDecoratorInfo) 
+AnyFieldDecorator = Union[ + Decorator[ValidatorDecoratorInfo], + Decorator[FieldValidatorDecoratorInfo], + Decorator[FieldSerializerDecoratorInfo], +] + +ModifyCoreSchemaWrapHandler = GetCoreSchemaHandler +GetCoreSchemaFunction = Callable[[Any, ModifyCoreSchemaWrapHandler], core_schema.CoreSchema] + + +TUPLE_TYPES: list[type] = [tuple, typing.Tuple] +LIST_TYPES: list[type] = [list, typing.List, collections.abc.MutableSequence] +SET_TYPES: list[type] = [set, typing.Set, collections.abc.MutableSet] +FROZEN_SET_TYPES: list[type] = [frozenset, typing.FrozenSet, collections.abc.Set] +DICT_TYPES: list[type] = [dict, typing.Dict, collections.abc.MutableMapping, collections.abc.Mapping] + + +def check_validator_fields_against_field_name( + info: FieldDecoratorInfo, + field: str, +) -> bool: + """Check if field name is in validator fields. + + Args: + info: The field info. + field: The field name to check. + + Returns: + `True` if field name is in validator fields, `False` otherwise. + """ + if isinstance(info, (ValidatorDecoratorInfo, FieldValidatorDecoratorInfo)): + if '*' in info.fields: + return True + for v_field_name in info.fields: + if v_field_name == field: + return True + return False + + +def check_decorator_fields_exist(decorators: Iterable[AnyFieldDecorator], fields: Iterable[str]) -> None: + """Check if the defined fields in decorators exist in `fields` param. + + It ignores the check for a decorator if the decorator has `*` as field or `check_fields=False`. + + Args: + decorators: An iterable of decorators. + fields: An iterable of fields name. + + Raises: + PydanticUserError: If one of the field names does not exist in `fields` param. + """ + fields = set(fields) + for dec in decorators: + if isinstance(dec.info, (ValidatorDecoratorInfo, FieldValidatorDecoratorInfo)) and '*' in dec.info.fields: + continue + if dec.info.check_fields is False: + continue + for field in dec.info.fields: + if field not in fields: + raise PydanticUserError( + f'Decorators defined with incorrect fields: {dec.cls_ref}.{dec.cls_var_name}' + " (use check_fields=False if you're inheriting from the model and intended this)", + code='decorator-missing-field', + ) + + +def filter_field_decorator_info_by_field( + validator_functions: Iterable[Decorator[FieldDecoratorInfoType]], field: str +) -> list[Decorator[FieldDecoratorInfoType]]: + return [dec for dec in validator_functions if check_validator_fields_against_field_name(dec.info, field)] + + +def apply_each_item_validators( + schema: core_schema.CoreSchema, + each_item_validators: list[Decorator[ValidatorDecoratorInfo]], + field_name: str | None, +) -> core_schema.CoreSchema: + # This V1 compatibility shim should eventually be removed + + # push down any `each_item=True` validators + # note that this won't work for any Annotated types that get wrapped by a function validator + # but that's okay because that didn't exist in V1 + if schema['type'] == 'nullable': + schema['schema'] = apply_each_item_validators(schema['schema'], each_item_validators, field_name) + return schema + elif schema['type'] == 'tuple': + if (variadic_item_index := schema.get('variadic_item_index')) is not None: + schema['items_schema'][variadic_item_index] = apply_validators( + schema['items_schema'][variadic_item_index], each_item_validators, field_name + ) + elif is_list_like_schema_with_items_schema(schema): + inner_schema = schema.get('items_schema', None) + if inner_schema is None: + inner_schema = core_schema.any_schema() + schema['items_schema'] = apply_validators(inner_schema, 
each_item_validators, field_name) + elif schema['type'] == 'dict': + # push down any `each_item=True` validators onto dict _values_ + # this is super arbitrary but it's the V1 behavior + inner_schema = schema.get('values_schema', None) + if inner_schema is None: + inner_schema = core_schema.any_schema() + schema['values_schema'] = apply_validators(inner_schema, each_item_validators, field_name) + elif each_item_validators: + raise TypeError( + f"`@validator(..., each_item=True)` cannot be applied to fields with a schema of {schema['type']}" + ) + return schema + + +def modify_model_json_schema( + schema_or_field: CoreSchemaOrField, handler: GetJsonSchemaHandler, *, cls: Any +) -> JsonSchemaValue: + """Add title and description for model-like classes' JSON schema. + + Args: + schema_or_field: The schema data to generate a JSON schema from. + handler: The `GetCoreSchemaHandler` instance. + cls: The model-like class. + + Returns: + JsonSchemaValue: The updated JSON schema. + """ + from ..main import BaseModel + + json_schema = handler(schema_or_field) + original_schema = handler.resolve_ref_schema(json_schema) + # Preserve the fact that definitions schemas should never have sibling keys: + if '$ref' in original_schema: + ref = original_schema['$ref'] + original_schema.clear() + original_schema['allOf'] = [{'$ref': ref}] + if 'title' not in original_schema: + original_schema['title'] = cls.__name__ + # BaseModel; don't use cls.__doc__ as it will contain the verbose class signature by default + docstring = None if cls is BaseModel else cls.__doc__ + if docstring and 'description' not in original_schema: + original_schema['description'] = inspect.cleandoc(docstring) + return json_schema + + +JsonEncoders = Dict[Type[Any], JsonEncoder] + + +def _add_custom_serialization_from_json_encoders( + json_encoders: JsonEncoders | None, tp: Any, schema: CoreSchema +) -> CoreSchema: + """Iterate over the json_encoders and add the first matching encoder to the schema. + + Args: + json_encoders: A dictionary of types and their encoder functions. + tp: The type to check for a matching encoder. + schema: The schema to add the encoder to. + """ + if not json_encoders: + return schema + if 'serialization' in schema: + return schema + # Check the class type and its superclasses for a matching encoder + # Decimal.__class__.__mro__ (and probably other cases) doesn't include Decimal itself + # if the type is a GenericAlias (e.g. from list[int]) we need to use __class__ instead of .__mro__ + for base in (tp, *getattr(tp, '__mro__', tp.__class__.__mro__)[:-1]): + encoder = json_encoders.get(base) + if encoder is None: + continue + + warnings.warn( + f'`json_encoders` is deprecated. 
See https://docs.pydantic.dev/{version_short()}/concepts/serialization/#custom-serializers for alternatives', + PydanticDeprecatedSince20, + ) + + # TODO: in theory we should check that the schema accepts a serialization key + schema['serialization'] = core_schema.plain_serializer_function_ser_schema(encoder, when_used='json') + return schema + + return schema + + +TypesNamespace = Union[Dict[str, Any], None] + + +class TypesNamespaceStack: + """A stack of types namespaces.""" + + def __init__(self, types_namespace: TypesNamespace): + self._types_namespace_stack: list[TypesNamespace] = [types_namespace] + + @property + def tail(self) -> TypesNamespace: + return self._types_namespace_stack[-1] + + @contextmanager + def push(self, for_type: type[Any]): + types_namespace = {**_typing_extra.get_cls_types_namespace(for_type), **(self.tail or {})} + self._types_namespace_stack.append(types_namespace) + try: + yield + finally: + self._types_namespace_stack.pop() + + +class GenerateSchema: + """Generate core schema for a Pydantic model, dataclass and types like `str`, `datetime`, ... .""" + + __slots__ = ( + '_config_wrapper_stack', + '_types_namespace_stack', + '_typevars_map', + '_has_invalid_schema', + 'field_name_stack', + 'defs', + ) + + def __init__( + self, + config_wrapper: ConfigWrapper, + types_namespace: dict[str, Any] | None, + typevars_map: dict[Any, Any] | None = None, + ) -> None: + # we need a stack for recursing into child models + self._config_wrapper_stack = ConfigWrapperStack(config_wrapper) + self._types_namespace_stack = TypesNamespaceStack(types_namespace) + self._typevars_map = typevars_map + self._has_invalid_schema = False + self.field_name_stack = _FieldNameStack() + self.defs = _Definitions() + + @classmethod + def __from_parent( + cls, + config_wrapper_stack: ConfigWrapperStack, + types_namespace_stack: TypesNamespaceStack, + typevars_map: dict[Any, Any] | None, + defs: _Definitions, + ) -> GenerateSchema: + obj = cls.__new__(cls) + obj._config_wrapper_stack = config_wrapper_stack + obj._types_namespace_stack = types_namespace_stack + obj._typevars_map = typevars_map + obj._has_invalid_schema = False + obj.field_name_stack = _FieldNameStack() + obj.defs = defs + return obj + + @property + def _config_wrapper(self) -> ConfigWrapper: + return self._config_wrapper_stack.tail + + @property + def _types_namespace(self) -> dict[str, Any] | None: + return self._types_namespace_stack.tail + + @property + def _current_generate_schema(self) -> GenerateSchema: + cls = self._config_wrapper.schema_generator or GenerateSchema + return cls.__from_parent( + self._config_wrapper_stack, + self._types_namespace_stack, + self._typevars_map, + self.defs, + ) + + @property + def _arbitrary_types(self) -> bool: + return self._config_wrapper.arbitrary_types_allowed + + def str_schema(self) -> CoreSchema: + """Generate a CoreSchema for `str`""" + return core_schema.str_schema() + + # the following methods can be overridden but should be considered + # unstable / private APIs + def _list_schema(self, tp: Any, items_type: Any) -> CoreSchema: + return core_schema.list_schema(self.generate_schema(items_type)) + + def _dict_schema(self, tp: Any, keys_type: Any, values_type: Any) -> CoreSchema: + return core_schema.dict_schema(self.generate_schema(keys_type), self.generate_schema(values_type)) + + def _set_schema(self, tp: Any, items_type: Any) -> CoreSchema: + return core_schema.set_schema(self.generate_schema(items_type)) + + def _frozenset_schema(self, tp: Any, items_type: Any) -> CoreSchema: + 
return core_schema.frozenset_schema(self.generate_schema(items_type)) + + def _arbitrary_type_schema(self, tp: Any) -> CoreSchema: + if not isinstance(tp, type): + warn( + f'{tp!r} is not a Python type (it may be an instance of an object),' + ' Pydantic will allow any object with no validation since we cannot even' + ' enforce that the input is an instance of the given type.' + ' To get rid of this error wrap the type with `pydantic.SkipValidation`.', + UserWarning, + ) + return core_schema.any_schema() + return core_schema.is_instance_schema(tp) + + def _unknown_type_schema(self, obj: Any) -> CoreSchema: + raise PydanticSchemaGenerationError( + f'Unable to generate pydantic-core schema for {obj!r}. ' + 'Set `arbitrary_types_allowed=True` in the model_config to ignore this error' + ' or implement `__get_pydantic_core_schema__` on your type to fully support it.' + '\n\nIf you got this error by calling handler(<some type>) within' + ' `__get_pydantic_core_schema__` then you likely need to call' + ' `handler.generate_schema(<some type>)` since we do not call' + ' `__get_pydantic_core_schema__` on `<some type>` otherwise to avoid infinite recursion.' + ) + + def _apply_discriminator_to_union( + self, schema: CoreSchema, discriminator: str | Discriminator | None + ) -> CoreSchema: + if discriminator is None: + return schema + try: + return _discriminated_union.apply_discriminator( + schema, + discriminator, + ) + except _discriminated_union.MissingDefinitionForUnionRef: + # defer until defs are resolved + _discriminated_union.set_discriminator_in_metadata( + schema, + discriminator, + ) + return schema + + class CollectedInvalid(Exception): + pass + + def clean_schema(self, schema: CoreSchema) -> CoreSchema: + schema = self.collect_definitions(schema) + schema = simplify_schema_references(schema) + schema = _discriminated_union.apply_discriminators(schema) + if collect_invalid_schemas(schema): + raise self.CollectedInvalid() + schema = validate_core_schema(schema) + return schema + + def collect_definitions(self, schema: CoreSchema) -> CoreSchema: + ref = cast('str | None', schema.get('ref', None)) + if ref: + self.defs.definitions[ref] = schema + if 'ref' in schema: + schema = core_schema.definition_reference_schema(schema['ref']) + return core_schema.definitions_schema( + schema, + list(self.defs.definitions.values()), + ) + + def _add_js_function(self, metadata_schema: CoreSchema, js_function: Callable[..., Any]) -> None: + metadata = CoreMetadataHandler(metadata_schema).metadata + pydantic_js_functions = metadata.setdefault('pydantic_js_functions', []) + # because of how we generate core schemas for nested generic models + # we can end up adding `BaseModel.__get_pydantic_json_schema__` multiple times + # this check may fail to catch duplicates if the function is a `functools.partial` + # or something like that + # but if it does it'll fail by inserting the duplicate + if js_function not in pydantic_js_functions: + pydantic_js_functions.append(js_function) + + def generate_schema( + self, + obj: Any, + from_dunder_get_core_schema: bool = True, + ) -> core_schema.CoreSchema: + """Generate core schema. + + Args: + obj: The object to generate core schema for. + from_dunder_get_core_schema: Whether to generate schema from either the + `__get_pydantic_core_schema__` function or `__pydantic_core_schema__` property. + + Returns: + The generated core schema. + + Raises: + PydanticUndefinedAnnotation: + If it is not possible to evaluate forward reference.
+ PydanticSchemaGenerationError: + If it is not possible to generate pydantic-core schema. + TypeError: + - If `alias_generator` returns a disallowed type (must be str, AliasPath or AliasChoices). + - If V1 style validator with `each_item=True` applied on a wrong field. + PydanticUserError: + - If `typing.TypedDict` is used instead of `typing_extensions.TypedDict` on Python < 3.12. + - If `__modify_schema__` method is used instead of `__get_pydantic_json_schema__`. + """ + schema: CoreSchema | None = None + + if from_dunder_get_core_schema: + from_property = self._generate_schema_from_property(obj, obj) + if from_property is not None: + schema = from_property + + if schema is None: + schema = self._generate_schema(obj) + + metadata_js_function = _extract_get_pydantic_json_schema(obj, schema) + if metadata_js_function is not None: + metadata_schema = resolve_original_schema(schema, self.defs.definitions) + if metadata_schema: + self._add_js_function(metadata_schema, metadata_js_function) + + schema = _add_custom_serialization_from_json_encoders(self._config_wrapper.json_encoders, obj, schema) + + schema = self._post_process_generated_schema(schema) + + return schema + + def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema: + """Generate schema for a Pydantic model.""" + with self.defs.get_schema_or_ref(cls) as (model_ref, maybe_schema): + if maybe_schema is not None: + return maybe_schema + + fields = cls.model_fields + decorators = cls.__pydantic_decorators__ + computed_fields = decorators.computed_fields + check_decorator_fields_exist( + chain( + decorators.field_validators.values(), + decorators.field_serializers.values(), + decorators.validators.values(), + ), + {*fields.keys(), *computed_fields.keys()}, + ) + config_wrapper = ConfigWrapper(cls.model_config, check=False) + core_config = config_wrapper.core_config(cls) + metadata = build_metadata_dict(js_functions=[partial(modify_model_json_schema, cls=cls)]) + + model_validators = decorators.model_validators.values() + + extras_schema = None + if core_config.get('extra_fields_behavior') == 'allow': + for tp in (cls, *cls.__mro__): + extras_annotation = cls.__annotations__.get('__pydantic_extra__', None) + if extras_annotation is not None: + tp = get_origin(extras_annotation) + if tp not in (Dict, dict): + raise PydanticSchemaGenerationError( + 'The type annotation for `__pydantic_extra__` must be `Dict[str, ...]`' + ) + extra_items_type = self._get_args_resolving_forward_refs( + cls.__annotations__['__pydantic_extra__'], + required=True, + )[1] + if extra_items_type is not Any: + extras_schema = self.generate_schema(extra_items_type) + break + + with self._config_wrapper_stack.push(config_wrapper), self._types_namespace_stack.push(cls): + self = self._current_generate_schema + if cls.__pydantic_root_model__: + root_field = self._common_field_schema('root', fields['root'], decorators) + inner_schema = root_field['schema'] + inner_schema = apply_model_validators(inner_schema, model_validators, 'inner') + model_schema = core_schema.model_schema( + cls, + inner_schema, + custom_init=getattr(cls, '__pydantic_custom_init__', None), + root_model=True, + post_init=getattr(cls, '__pydantic_post_init__', None), + config=core_config, + ref=model_ref, + metadata=metadata, + ) + else: + fields_schema: core_schema.CoreSchema = core_schema.model_fields_schema( + {k: self._generate_md_field_schema(k, v, decorators) for k, v in fields.items()}, + computed_fields=[ + self._computed_field_schema(d, decorators.field_serializers) + for d 
in computed_fields.values() + ], + extras_schema=extras_schema, + model_name=cls.__name__, + ) + inner_schema = apply_validators(fields_schema, decorators.root_validators.values(), None) + new_inner_schema = define_expected_missing_refs(inner_schema, recursively_defined_type_refs()) + if new_inner_schema is not None: + inner_schema = new_inner_schema + inner_schema = apply_model_validators(inner_schema, model_validators, 'inner') + + model_schema = core_schema.model_schema( + cls, + inner_schema, + custom_init=getattr(cls, '__pydantic_custom_init__', None), + root_model=False, + post_init=getattr(cls, '__pydantic_post_init__', None), + config=core_config, + ref=model_ref, + metadata=metadata, + ) + + schema = self._apply_model_serializers(model_schema, decorators.model_serializers.values()) + schema = apply_model_validators(schema, model_validators, 'outer') + self.defs.definitions[model_ref] = self._post_process_generated_schema(schema) + return core_schema.definition_reference_schema(model_ref) + + def _unpack_refs_defs(self, schema: CoreSchema) -> CoreSchema: + """Unpack all 'definitions' schemas into `GenerateSchema.defs.definitions` + and return the inner schema. + """ + + def get_ref(s: CoreSchema) -> str: + return s['ref'] # type: ignore + + if schema['type'] == 'definitions': + self.defs.definitions.update({get_ref(s): s for s in schema['definitions']}) + schema = schema['schema'] + return schema + + def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.CoreSchema | None: + """Try to generate schema from either the `__get_pydantic_core_schema__` function or + `__pydantic_core_schema__` property. + + Note: `__get_pydantic_core_schema__` takes priority so it can + decide whether to use a `__pydantic_core_schema__` attribute, or generate a fresh schema. 
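For context, a minimal sketch of the hook this method consumes, assuming the pydantic v2 public API (`LowerStr` is a hypothetical type, not part of this change):

    from pydantic import GetCoreSchemaHandler, TypeAdapter
    from pydantic_core import core_schema

    class LowerStr(str):
        @classmethod
        def __get_pydantic_core_schema__(cls, source_type, handler: GetCoreSchemaHandler):
            # validate as plain `str` first, then normalize to lower case
            return core_schema.no_info_after_validator_function(str.lower, handler(str))

    TypeAdapter(LowerStr).validate_python('HeLLo')  # -> 'hello'

Because `__get_pydantic_core_schema__` is defined, the schema it returns is used in preference to any `__pydantic_core_schema__` attribute.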
+ """ + # avoid calling `__get_pydantic_core_schema__` if we've already visited this object + with self.defs.get_schema_or_ref(obj) as (_, maybe_schema): + if maybe_schema is not None: + return maybe_schema + if obj is source: + ref_mode = 'unpack' + else: + ref_mode = 'to-def' + + schema: CoreSchema + get_schema = getattr(obj, '__get_pydantic_core_schema__', None) + if get_schema is None: + validators = getattr(obj, '__get_validators__', None) + if validators is None: + return None + warn( + '`__get_validators__` is deprecated and will be removed, use `__get_pydantic_core_schema__` instead.', + PydanticDeprecatedSince20, + ) + schema = core_schema.chain_schema([core_schema.with_info_plain_validator_function(v) for v in validators()]) + else: + if len(inspect.signature(get_schema).parameters) == 1: + # (source) -> CoreSchema + schema = get_schema(source) + else: + schema = get_schema( + source, CallbackGetCoreSchemaHandler(self._generate_schema, self, ref_mode=ref_mode) + ) + + schema = self._unpack_refs_defs(schema) + + if is_function_with_inner_schema(schema): + ref = schema['schema'].pop('ref', None) # pyright: ignore[reportGeneralTypeIssues] + if ref: + schema['ref'] = ref + else: + ref = get_ref(schema) + + if ref: + self.defs.definitions[ref] = self._post_process_generated_schema(schema) + return core_schema.definition_reference_schema(ref) + + schema = self._post_process_generated_schema(schema) + + return schema + + def _resolve_forward_ref(self, obj: Any) -> Any: + # we assume that types_namespace has the target of forward references in its scope, + # but this could fail, for example, if calling Validator on an imported type which contains + # forward references to other types only defined in the module from which it was imported + # `Validator(SomeImportedTypeAliasWithAForwardReference)` + # or the equivalent for BaseModel + # class Model(BaseModel): + # x: SomeImportedTypeAliasWithAForwardReference + try: + obj = _typing_extra.eval_type_backport(obj, globalns=self._types_namespace) + except NameError as e: + raise PydanticUndefinedAnnotation.from_name_error(e) from e + + # if obj is still a ForwardRef, it means we can't evaluate it, raise PydanticUndefinedAnnotation + if isinstance(obj, ForwardRef): + raise PydanticUndefinedAnnotation(obj.__forward_arg__, f'Unable to evaluate forward reference {obj}') + + if self._typevars_map: + obj = replace_types(obj, self._typevars_map) + + return obj + + @overload + def _get_args_resolving_forward_refs(self, obj: Any, required: Literal[True]) -> tuple[Any, ...]: + ... + + @overload + def _get_args_resolving_forward_refs(self, obj: Any) -> tuple[Any, ...] | None: + ... + + def _get_args_resolving_forward_refs(self, obj: Any, required: bool = False) -> tuple[Any, ...] 
| None: + args = get_args(obj) + if args: + args = tuple([self._resolve_forward_ref(a) if isinstance(a, ForwardRef) else a for a in args]) + elif required: # pragma: no cover + raise TypeError(f'Expected {obj} to have generic parameters but it had none') + return args + + def _get_first_arg_or_any(self, obj: Any) -> Any: + args = self._get_args_resolving_forward_refs(obj) + if not args: + return Any + return args[0] + + def _get_first_two_args_or_any(self, obj: Any) -> tuple[Any, Any]: + args = self._get_args_resolving_forward_refs(obj) + if not args: + return (Any, Any) + if len(args) < 2: + origin = get_origin(obj) + raise TypeError(f'Expected two type arguments for {origin}, got 1') + return args[0], args[1] + + def _post_process_generated_schema(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema: + if 'metadata' not in schema: + schema['metadata'] = {} + return schema + + def _generate_schema(self, obj: Any) -> core_schema.CoreSchema: + """Recursively generate a pydantic-core schema for any supported python type.""" + has_invalid_schema = self._has_invalid_schema + self._has_invalid_schema = False + schema = self._generate_schema_inner(obj) + self._has_invalid_schema = self._has_invalid_schema or has_invalid_schema + return schema + + def _generate_schema_inner(self, obj: Any) -> core_schema.CoreSchema: + if isinstance(obj, _AnnotatedType): + return self._annotated_schema(obj) + + if isinstance(obj, dict): + # we assume this is already a valid schema + return obj # type: ignore[return-value] + + if isinstance(obj, str): + obj = ForwardRef(obj) + + if isinstance(obj, ForwardRef): + return self.generate_schema(self._resolve_forward_ref(obj)) + + from ..main import BaseModel + + if lenient_issubclass(obj, BaseModel): + return self._model_schema(obj) + + if isinstance(obj, PydanticRecursiveRef): + return core_schema.definition_reference_schema(schema_ref=obj.type_ref) + + return self.match_type(obj) + + def match_type(self, obj: Any) -> core_schema.CoreSchema: # noqa: C901 + """Main mapping of types to schemas. + + The general structure is a series of if statements starting with the simple cases + (non-generic primitive types) and then handling generics and other more complex cases. + + Each case either generates a schema directly, calls into a public user-overridable method + (like `GenerateSchema.tuple_variable_schema`) or calls into a private method that handles some + boilerplate before calling into the user-facing method (e.g. `GenerateSchema._tuple_schema`). + + The idea is that we'll evolve this into adding more and more user facing methods over time + as they get requested and we figure out what the right API for them is. 
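An illustrative sketch of the dispatch described above, assuming the public `TypeAdapter` API, which exposes the generated schema on its `core_schema` attribute:

    from pydantic import TypeAdapter

    TypeAdapter(int).core_schema['type']             # -> 'int'
    TypeAdapter(list[int]).core_schema['type']       # -> 'list'
    TypeAdapter(dict[str, int]).core_schema['type']  # -> 'dict'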
+ """ + if obj is str: + return self.str_schema() + elif obj is bytes: + return core_schema.bytes_schema() + elif obj is int: + return core_schema.int_schema() + elif obj is float: + return core_schema.float_schema() + elif obj is bool: + return core_schema.bool_schema() + elif obj is Any or obj is object: + return core_schema.any_schema() + elif obj is None or obj is _typing_extra.NoneType: + return core_schema.none_schema() + elif obj in TUPLE_TYPES: + return self._tuple_schema(obj) + elif obj in LIST_TYPES: + return self._list_schema(obj, self._get_first_arg_or_any(obj)) + elif obj in SET_TYPES: + return self._set_schema(obj, self._get_first_arg_or_any(obj)) + elif obj in FROZEN_SET_TYPES: + return self._frozenset_schema(obj, self._get_first_arg_or_any(obj)) + elif obj in DICT_TYPES: + return self._dict_schema(obj, *self._get_first_two_args_or_any(obj)) + elif isinstance(obj, TypeAliasType): + return self._type_alias_type_schema(obj) + elif obj == type: + return self._type_schema() + elif _typing_extra.is_callable_type(obj): + return core_schema.callable_schema() + elif _typing_extra.is_literal_type(obj): + return self._literal_schema(obj) + elif is_typeddict(obj): + return self._typed_dict_schema(obj, None) + elif _typing_extra.is_namedtuple(obj): + return self._namedtuple_schema(obj, None) + elif _typing_extra.is_new_type(obj): + # NewType, can't use isinstance because it fails <3.10 + return self.generate_schema(obj.__supertype__) + elif obj == re.Pattern: + return self._pattern_schema(obj) + elif obj is collections.abc.Hashable or obj is typing.Hashable: + return self._hashable_schema() + elif isinstance(obj, typing.TypeVar): + return self._unsubstituted_typevar_schema(obj) + elif is_finalvar(obj): + if obj is Final: + return core_schema.any_schema() + return self.generate_schema( + self._get_first_arg_or_any(obj), + ) + elif isinstance(obj, (FunctionType, LambdaType, MethodType, partial)): + return self._callable_schema(obj) + elif inspect.isclass(obj) and issubclass(obj, Enum): + from ._std_types_schema import get_enum_core_schema + + return get_enum_core_schema(obj, self._config_wrapper.config_dict) + + if _typing_extra.is_dataclass(obj): + return self._dataclass_schema(obj, None) + + res = self._get_prepare_pydantic_annotations_for_known_type(obj, ()) + if res is not None: + source_type, annotations = res + return self._apply_annotations(source_type, annotations) + + origin = get_origin(obj) + if origin is not None: + return self._match_generic_type(obj, origin) + + if self._arbitrary_types: + return self._arbitrary_type_schema(obj) + return self._unknown_type_schema(obj) + + def _match_generic_type(self, obj: Any, origin: Any) -> CoreSchema: # noqa: C901 + if isinstance(origin, TypeAliasType): + return self._type_alias_type_schema(obj) + + # Need to handle generic dataclasses before looking for the schema properties because attribute accesses + # on _GenericAlias delegate to the origin type, so lose the information about the concrete parametrization + # As a result, currently, there is no way to cache the schema for generic dataclasses. This may be possible + # to resolve by modifying the value returned by `Generic.__class_getitem__`, but that is a dangerous game. 
+ if _typing_extra.is_dataclass(origin): + return self._dataclass_schema(obj, origin) + if _typing_extra.is_namedtuple(origin): + return self._namedtuple_schema(obj, origin) + + from_property = self._generate_schema_from_property(origin, obj) + if from_property is not None: + return from_property + + if _typing_extra.origin_is_union(origin): + return self._union_schema(obj) + elif origin in TUPLE_TYPES: + return self._tuple_schema(obj) + elif origin in LIST_TYPES: + return self._list_schema(obj, self._get_first_arg_or_any(obj)) + elif origin in SET_TYPES: + return self._set_schema(obj, self._get_first_arg_or_any(obj)) + elif origin in FROZEN_SET_TYPES: + return self._frozenset_schema(obj, self._get_first_arg_or_any(obj)) + elif origin in DICT_TYPES: + return self._dict_schema(obj, *self._get_first_two_args_or_any(obj)) + elif is_typeddict(origin): + return self._typed_dict_schema(obj, origin) + elif origin in (typing.Type, type): + return self._subclass_schema(obj) + elif origin in {typing.Sequence, collections.abc.Sequence}: + return self._sequence_schema(obj) + elif origin in {typing.Iterable, collections.abc.Iterable, typing.Generator, collections.abc.Generator}: + return self._iterable_schema(obj) + elif origin in (re.Pattern, typing.Pattern): + return self._pattern_schema(obj) + + if self._arbitrary_types: + return self._arbitrary_type_schema(origin) + return self._unknown_type_schema(obj) + + def _generate_td_field_schema( + self, + name: str, + field_info: FieldInfo, + decorators: DecoratorInfos, + *, + required: bool = True, + ) -> core_schema.TypedDictField: + """Prepare a TypedDictField to represent a model or typeddict field.""" + common_field = self._common_field_schema(name, field_info, decorators) + return core_schema.typed_dict_field( + common_field['schema'], + required=False if not field_info.is_required() else required, + serialization_exclude=common_field['serialization_exclude'], + validation_alias=common_field['validation_alias'], + serialization_alias=common_field['serialization_alias'], + metadata=common_field['metadata'], + ) + + def _generate_md_field_schema( + self, + name: str, + field_info: FieldInfo, + decorators: DecoratorInfos, + ) -> core_schema.ModelField: + """Prepare a ModelField to represent a model field.""" + common_field = self._common_field_schema(name, field_info, decorators) + return core_schema.model_field( + common_field['schema'], + serialization_exclude=common_field['serialization_exclude'], + validation_alias=common_field['validation_alias'], + serialization_alias=common_field['serialization_alias'], + frozen=common_field['frozen'], + metadata=common_field['metadata'], + ) + + def _generate_dc_field_schema( + self, + name: str, + field_info: FieldInfo, + decorators: DecoratorInfos, + ) -> core_schema.DataclassField: + """Prepare a DataclassField to represent the parameter/field, of a dataclass.""" + common_field = self._common_field_schema(name, field_info, decorators) + return core_schema.dataclass_field( + name, + common_field['schema'], + init=field_info.init, + init_only=field_info.init_var or None, + kw_only=None if field_info.kw_only else False, + serialization_exclude=common_field['serialization_exclude'], + validation_alias=common_field['validation_alias'], + serialization_alias=common_field['serialization_alias'], + frozen=common_field['frozen'], + metadata=common_field['metadata'], + ) + + @staticmethod + def _apply_alias_generator_to_field_info( + alias_generator: Callable[[str], str] | AliasGenerator, field_info: FieldInfo, 
field_name: str + ) -> None: + """Apply an alias_generator to aliases on a FieldInfo instance if appropriate. + + Args: + alias_generator: A callable that takes a string and returns a string, or an AliasGenerator instance. + field_info: The FieldInfo instance to which the alias_generator is (maybe) applied. + field_name: The name of the field from which to generate the alias. + """ + # Apply an alias_generator if + # 1. An alias is not specified + # 2. An alias is specified, but the priority is <= 1 + if ( + field_info.alias_priority is None + or field_info.alias_priority <= 1 + or field_info.alias is None + or field_info.validation_alias is None + or field_info.serialization_alias is None + ): + alias, validation_alias, serialization_alias = None, None, None + + if isinstance(alias_generator, AliasGenerator): + alias, validation_alias, serialization_alias = alias_generator.generate_aliases(field_name) + elif isinstance(alias_generator, Callable): + alias = alias_generator(field_name) + if not isinstance(alias, str): + raise TypeError(f'alias_generator {alias_generator} must return str, not {alias.__class__}') + + # if priority is not set, we set to 1 + # which supports the case where the alias_generator from a child class is used + # to generate an alias for a field in a parent class + if field_info.alias_priority is None or field_info.alias_priority <= 1: + field_info.alias_priority = 1 + + # if the priority is 1, then we set the aliases to the generated alias + if field_info.alias_priority == 1: + field_info.serialization_alias = serialization_alias or alias + field_info.validation_alias = validation_alias or alias + field_info.alias = alias + + # if any of the aliases are not set, then we set them to the corresponding generated alias + if field_info.alias is None: + field_info.alias = alias + if field_info.serialization_alias is None: + field_info.serialization_alias = serialization_alias or alias + if field_info.validation_alias is None: + field_info.validation_alias = validation_alias or alias + + @staticmethod + def _apply_alias_generator_to_computed_field_info( + alias_generator: Callable[[str], str] | AliasGenerator, + computed_field_info: ComputedFieldInfo, + computed_field_name: str, + ): + """Apply an alias_generator to alias on a ComputedFieldInfo instance if appropriate. + + Args: + alias_generator: A callable that takes a string and returns a string, or an AliasGenerator instance. + computed_field_info: The ComputedFieldInfo instance to which the alias_generator is (maybe) applied. + computed_field_name: The name of the computed field from which to generate the alias. + """ + # Apply an alias_generator if + # 1. An alias is not specified + # 2. 
An alias is specified, but the priority is <= 1 + + if ( + computed_field_info.alias_priority is None + or computed_field_info.alias_priority <= 1 + or computed_field_info.alias is None + ): + alias, validation_alias, serialization_alias = None, None, None + + if isinstance(alias_generator, AliasGenerator): + alias, validation_alias, serialization_alias = alias_generator.generate_aliases(computed_field_name) + elif isinstance(alias_generator, Callable): + alias = alias_generator(computed_field_name) + if not isinstance(alias, str): + raise TypeError(f'alias_generator {alias_generator} must return str, not {alias.__class__}') + + # if priority is not set, we set to 1 + # which supports the case where the alias_generator from a child class is used + # to generate an alias for a field in a parent class + if computed_field_info.alias_priority is None or computed_field_info.alias_priority <= 1: + computed_field_info.alias_priority = 1 + + # if the priority is 1, then we set the aliases to the generated alias + # note that we use the serialization_alias with priority over alias, as computed_field + # aliases are used for serialization only (not validation) + if computed_field_info.alias_priority == 1: + computed_field_info.alias = serialization_alias or alias + + def _common_field_schema( # C901 + self, name: str, field_info: FieldInfo, decorators: DecoratorInfos + ) -> _CommonField: + # Update FieldInfo annotation if appropriate: + from .. import AliasChoices, AliasPath + from ..fields import FieldInfo + + if has_instance_in_type(field_info.annotation, (ForwardRef, str)): + types_namespace = self._types_namespace + if self._typevars_map: + types_namespace = (types_namespace or {}).copy() + # Ensure that typevars get mapped to their concrete types: + types_namespace.update({k.__name__: v for k, v in self._typevars_map.items()}) + + evaluated = _typing_extra.eval_type_lenient(field_info.annotation, types_namespace) + if evaluated is not field_info.annotation and not has_instance_in_type(evaluated, PydanticRecursiveRef): + new_field_info = FieldInfo.from_annotation(evaluated) + field_info.annotation = new_field_info.annotation + + # Handle any field info attributes that may have been obtained from now-resolved annotations + for k, v in new_field_info._attributes_set.items(): + # If an attribute is already set, it means it was set by assigning to a call to Field (or just a + # default value), and that should take the highest priority. So don't overwrite existing attributes. + # We skip over "attributes" that are present in the metadata_lookup dict because these won't + # actually end up as attributes of the `FieldInfo` instance. + if k not in field_info._attributes_set and k not in field_info.metadata_lookup: + setattr(field_info, k, v) + + # Finally, ensure the field info also reflects all the `_attributes_set` that are actually metadata. 
+ field_info.metadata = [*new_field_info.metadata, *field_info.metadata] + + source_type, annotations = field_info.annotation, field_info.metadata + + def set_discriminator(schema: CoreSchema) -> CoreSchema: + schema = self._apply_discriminator_to_union(schema, field_info.discriminator) + return schema + + with self.field_name_stack.push(name): + if field_info.discriminator is not None: + schema = self._apply_annotations(source_type, annotations, transform_inner_schema=set_discriminator) + else: + schema = self._apply_annotations( + source_type, + annotations, + ) + + # This V1 compatibility shim should eventually be removed + # push down any `each_item=True` validators + # note that this won't work for any Annotated types that get wrapped by a function validator + # but that's okay because that didn't exist in V1 + this_field_validators = filter_field_decorator_info_by_field(decorators.validators.values(), name) + if _validators_require_validate_default(this_field_validators): + field_info.validate_default = True + each_item_validators = [v for v in this_field_validators if v.info.each_item is True] + this_field_validators = [v for v in this_field_validators if v not in each_item_validators] + schema = apply_each_item_validators(schema, each_item_validators, name) + + schema = apply_validators(schema, filter_field_decorator_info_by_field(this_field_validators, name), name) + schema = apply_validators( + schema, filter_field_decorator_info_by_field(decorators.field_validators.values(), name), name + ) + + # the default validator needs to go outside of any other validators + # so that it is the topmost validator for the field validator + # which uses it to check if the field has a default value or not + if not field_info.is_required(): + schema = wrap_default(field_info, schema) + + schema = self._apply_field_serializers( + schema, filter_field_decorator_info_by_field(decorators.field_serializers.values(), name) + ) + json_schema_updates = { + 'title': field_info.title, + 'description': field_info.description, + 'examples': to_jsonable_python(field_info.examples), + } + json_schema_updates = {k: v for k, v in json_schema_updates.items() if v is not None} + + json_schema_extra = field_info.json_schema_extra + + metadata = build_metadata_dict( + js_annotation_functions=[get_json_schema_update_func(json_schema_updates, json_schema_extra)] + ) + + alias_generator = self._config_wrapper.alias_generator + if alias_generator is not None: + self._apply_alias_generator_to_field_info(alias_generator, field_info, name) + + if isinstance(field_info.validation_alias, (AliasChoices, AliasPath)): + validation_alias = field_info.validation_alias.convert_to_aliases() + else: + validation_alias = field_info.validation_alias + + return _common_field( + schema, + serialization_exclude=True if field_info.exclude else None, + validation_alias=validation_alias, + serialization_alias=field_info.serialization_alias, + frozen=field_info.frozen, + metadata=metadata, + ) + + def _union_schema(self, union_type: Any) -> core_schema.CoreSchema: + """Generate schema for a Union.""" + args = self._get_args_resolving_forward_refs(union_type, required=True) + choices: list[CoreSchema] = [] + nullable = False + for arg in args: + if arg is None or arg is _typing_extra.NoneType: + nullable = True + else: + choices.append(self.generate_schema(arg)) + + if len(choices) == 1: + s = choices[0] + else: + choices_with_tags: list[CoreSchema | tuple[CoreSchema, str]] = [] + for choice in choices: + metadata = choice.get('metadata') + 
if isinstance(metadata, dict): + tag = metadata.get(_core_utils.TAGGED_UNION_TAG_KEY) + if tag is not None: + choices_with_tags.append((choice, tag)) + else: + choices_with_tags.append(choice) + s = core_schema.union_schema(choices_with_tags) + + if nullable: + s = core_schema.nullable_schema(s) + return s + + def _type_alias_type_schema( + self, + obj: Any, # TypeAliasType + ) -> CoreSchema: + with self.defs.get_schema_or_ref(obj) as (ref, maybe_schema): + if maybe_schema is not None: + return maybe_schema + + origin = get_origin(obj) or obj + + annotation = origin.__value__ + typevars_map = get_standard_typevars_map(obj) + + with self._types_namespace_stack.push(origin): + annotation = _typing_extra.eval_type_lenient(annotation, self._types_namespace) + annotation = replace_types(annotation, typevars_map) + schema = self.generate_schema(annotation) + assert schema['type'] != 'definitions' + schema['ref'] = ref # type: ignore + self.defs.definitions[ref] = schema + return core_schema.definition_reference_schema(ref) + + def _literal_schema(self, literal_type: Any) -> CoreSchema: + """Generate schema for a Literal.""" + expected = _typing_extra.all_literal_values(literal_type) + assert expected, f'literal "expected" cannot be empty, obj={literal_type}' + return core_schema.literal_schema(expected) + + def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.CoreSchema: + """Generate schema for a TypedDict. + + It is not possible to track required/optional keys in TypedDict without __required_keys__ + since TypedDict.__new__ erases the base classes (it replaces them with just `dict`) + and thus we can track usage of total=True/False + __required_keys__ was added in Python 3.9 + (https://github.com/miss-islington/cpython/blob/1e9939657dd1f8eb9f596f77c1084d2d351172fc/Doc/library/typing.rst?plain=1#L1546-L1548) + however it is buggy + (https://github.com/python/typing_extensions/blob/ac52ac5f2cb0e00e7988bae1e2a1b8257ac88d6d/src/typing_extensions.py#L657-L666). + + On 3.11 but < 3.12 TypedDict does not preserve inheritance information. 
+ + Hence to avoid creating validators that do not do what users expect we only + support typing.TypedDict on Python >= 3.12 or typing_extension.TypedDict on all versions + """ + from ..fields import FieldInfo + + with self.defs.get_schema_or_ref(typed_dict_cls) as (typed_dict_ref, maybe_schema): + if maybe_schema is not None: + return maybe_schema + + typevars_map = get_standard_typevars_map(typed_dict_cls) + if origin is not None: + typed_dict_cls = origin + + if not _SUPPORTS_TYPEDDICT and type(typed_dict_cls).__module__ == 'typing': + raise PydanticUserError( + 'Please use `typing_extensions.TypedDict` instead of `typing.TypedDict` on Python < 3.12.', + code='typed-dict-version', + ) + + try: + config: ConfigDict | None = get_attribute_from_bases(typed_dict_cls, '__pydantic_config__') + except AttributeError: + config = None + + with self._config_wrapper_stack.push(config), self._types_namespace_stack.push(typed_dict_cls): + core_config = self._config_wrapper.core_config(typed_dict_cls) + + self = self._current_generate_schema + + required_keys: frozenset[str] = typed_dict_cls.__required_keys__ + + fields: dict[str, core_schema.TypedDictField] = {} + + decorators = DecoratorInfos.build(typed_dict_cls) + + for field_name, annotation in get_type_hints_infer_globalns( + typed_dict_cls, localns=self._types_namespace, include_extras=True + ).items(): + annotation = replace_types(annotation, typevars_map) + required = field_name in required_keys + + if get_origin(annotation) == _typing_extra.Required: + required = True + annotation = self._get_args_resolving_forward_refs( + annotation, + required=True, + )[0] + elif get_origin(annotation) == _typing_extra.NotRequired: + required = False + annotation = self._get_args_resolving_forward_refs( + annotation, + required=True, + )[0] + + field_info = FieldInfo.from_annotation(annotation) + fields[field_name] = self._generate_td_field_schema( + field_name, field_info, decorators, required=required + ) + + metadata = build_metadata_dict( + js_functions=[partial(modify_model_json_schema, cls=typed_dict_cls)], typed_dict_cls=typed_dict_cls + ) + + td_schema = core_schema.typed_dict_schema( + fields, + computed_fields=[ + self._computed_field_schema(d, decorators.field_serializers) + for d in decorators.computed_fields.values() + ], + ref=typed_dict_ref, + metadata=metadata, + config=core_config, + ) + + schema = self._apply_model_serializers(td_schema, decorators.model_serializers.values()) + schema = apply_model_validators(schema, decorators.model_validators.values(), 'all') + self.defs.definitions[typed_dict_ref] = self._post_process_generated_schema(schema) + return core_schema.definition_reference_schema(typed_dict_ref) + + def _namedtuple_schema(self, namedtuple_cls: Any, origin: Any) -> core_schema.CoreSchema: + """Generate schema for a NamedTuple.""" + with self.defs.get_schema_or_ref(namedtuple_cls) as (namedtuple_ref, maybe_schema): + if maybe_schema is not None: + return maybe_schema + typevars_map = get_standard_typevars_map(namedtuple_cls) + if origin is not None: + namedtuple_cls = origin + + annotations: dict[str, Any] = get_type_hints_infer_globalns( + namedtuple_cls, include_extras=True, localns=self._types_namespace + ) + if not annotations: + # annotations is empty, happens if namedtuple_cls defined via collections.namedtuple(...) 
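A small usage sketch of the NamedTuple support being built here, assuming pydantic v2 lax-mode coercion (`Point` is a hypothetical example class):

    from typing import NamedTuple
    from pydantic import TypeAdapter

    class Point(NamedTuple):
        x: int
        y: int

    TypeAdapter(Point).validate_python(('1', 2))  # -> Point(x=1, y=2)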
+ annotations = {k: Any for k in namedtuple_cls._fields} + + if typevars_map: + annotations = { + field_name: replace_types(annotation, typevars_map) + for field_name, annotation in annotations.items() + } + + arguments_schema = core_schema.arguments_schema( + [ + self._generate_parameter_schema( + field_name, annotation, default=namedtuple_cls._field_defaults.get(field_name, Parameter.empty) + ) + for field_name, annotation in annotations.items() + ], + metadata=build_metadata_dict(js_prefer_positional_arguments=True), + ) + return core_schema.call_schema(arguments_schema, namedtuple_cls, ref=namedtuple_ref) + + def _generate_parameter_schema( + self, + name: str, + annotation: type[Any], + default: Any = Parameter.empty, + mode: Literal['positional_only', 'positional_or_keyword', 'keyword_only'] | None = None, + ) -> core_schema.ArgumentsParameter: + """Prepare a ArgumentsParameter to represent a field in a namedtuple or function signature.""" + from ..fields import FieldInfo + + if default is Parameter.empty: + field = FieldInfo.from_annotation(annotation) + else: + field = FieldInfo.from_annotated_attribute(annotation, default) + assert field.annotation is not None, 'field.annotation should not be None when generating a schema' + source_type, annotations = field.annotation, field.metadata + with self.field_name_stack.push(name): + schema = self._apply_annotations(source_type, annotations) + + if not field.is_required(): + schema = wrap_default(field, schema) + + parameter_schema = core_schema.arguments_parameter(name, schema) + if mode is not None: + parameter_schema['mode'] = mode + if field.alias is not None: + parameter_schema['alias'] = field.alias + else: + alias_generator = self._config_wrapper.alias_generator + if isinstance(alias_generator, AliasGenerator) and alias_generator.alias is not None: + parameter_schema['alias'] = alias_generator.alias(name) + elif isinstance(alias_generator, Callable): + parameter_schema['alias'] = alias_generator(name) + return parameter_schema + + def _tuple_schema(self, tuple_type: Any) -> core_schema.CoreSchema: + """Generate schema for a Tuple, e.g. `tuple[int, str]` or `tuple[int, ...]`.""" + # TODO: do we really need to resolve type vars here? + typevars_map = get_standard_typevars_map(tuple_type) + params = self._get_args_resolving_forward_refs(tuple_type) + + if typevars_map and params: + params = tuple(replace_types(param, typevars_map) for param in params) + + # NOTE: subtle difference: `tuple[()]` gives `params=()`, whereas `typing.Tuple[()]` gives `params=((),)` + # This is only true for <3.11, on Python 3.11+ `typing.Tuple[()]` gives `params=()` + if not params: + if tuple_type in TUPLE_TYPES: + return core_schema.tuple_schema([core_schema.any_schema()], variadic_item_index=0) + else: + # special case for `tuple[()]` which means `tuple[]` - an empty tuple + return core_schema.tuple_schema([]) + elif params[-1] is Ellipsis: + if len(params) == 2: + return core_schema.tuple_schema([self.generate_schema(params[0])], variadic_item_index=0) + else: + # TODO: something like https://github.com/pydantic/pydantic/issues/5952 + raise ValueError('Variable tuples can only have one type') + elif len(params) == 1 and params[0] == (): + # special case for `Tuple[()]` which means `Tuple[]` - an empty tuple + # NOTE: This conditional can be removed when we drop support for Python 3.10. 
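A usage sketch of the tuple handling above, assuming pydantic v2 lax-mode coercion:

    from pydantic import TypeAdapter

    TypeAdapter(tuple[int, ...]).validate_python(['1', '2'])  # -> (1, 2)
    TypeAdapter(tuple[int, str]).validate_python(['1', 'a'])  # -> (1, 'a')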
+ return core_schema.tuple_schema([]) + else: + return core_schema.tuple_schema([self.generate_schema(param) for param in params]) + + def _type_schema(self) -> core_schema.CoreSchema: + return core_schema.custom_error_schema( + core_schema.is_instance_schema(type), + custom_error_type='is_type', + custom_error_message='Input should be a type', + ) + + def _union_is_subclass_schema(self, union_type: Any) -> core_schema.CoreSchema: + """Generate schema for `Type[Union[X, ...]]`.""" + args = self._get_args_resolving_forward_refs(union_type, required=True) + return core_schema.union_schema([self.generate_schema(typing.Type[args]) for args in args]) + + def _subclass_schema(self, type_: Any) -> core_schema.CoreSchema: + """Generate schema for a Type, e.g. `Type[int]`.""" + type_param = self._get_first_arg_or_any(type_) + if type_param == Any: + return self._type_schema() + elif isinstance(type_param, typing.TypeVar): + if type_param.__bound__: + if _typing_extra.origin_is_union(get_origin(type_param.__bound__)): + return self._union_is_subclass_schema(type_param.__bound__) + return core_schema.is_subclass_schema(type_param.__bound__) + elif type_param.__constraints__: + return core_schema.union_schema( + [self.generate_schema(typing.Type[c]) for c in type_param.__constraints__] + ) + else: + return self._type_schema() + elif _typing_extra.origin_is_union(get_origin(type_param)): + return self._union_is_subclass_schema(type_param) + else: + return core_schema.is_subclass_schema(type_param) + + def _sequence_schema(self, sequence_type: Any) -> core_schema.CoreSchema: + """Generate schema for a Sequence, e.g. `Sequence[int]`.""" + item_type = self._get_first_arg_or_any(sequence_type) + + list_schema = core_schema.list_schema(self.generate_schema(item_type)) + python_schema = core_schema.is_instance_schema(typing.Sequence, cls_repr='Sequence') + if item_type != Any: + from ._validators import sequence_validator + + python_schema = core_schema.chain_schema( + [python_schema, core_schema.no_info_wrap_validator_function(sequence_validator, list_schema)], + ) + return core_schema.json_or_python_schema(json_schema=list_schema, python_schema=python_schema) + + def _iterable_schema(self, type_: Any) -> core_schema.GeneratorSchema: + """Generate a schema for an `Iterable`.""" + item_type = self._get_first_arg_or_any(type_) + + return core_schema.generator_schema(self.generate_schema(item_type)) + + def _pattern_schema(self, pattern_type: Any) -> core_schema.CoreSchema: + from . 
import _validators + + metadata = build_metadata_dict(js_functions=[lambda _1, _2: {'type': 'string', 'format': 'regex'}]) + ser = core_schema.plain_serializer_function_ser_schema( + attrgetter('pattern'), when_used='json', return_schema=core_schema.str_schema() + ) + if pattern_type == typing.Pattern or pattern_type == re.Pattern: + # bare type + return core_schema.no_info_plain_validator_function( + _validators.pattern_either_validator, serialization=ser, metadata=metadata + ) + + param = self._get_args_resolving_forward_refs( + pattern_type, + required=True, + )[0] + if param == str: + return core_schema.no_info_plain_validator_function( + _validators.pattern_str_validator, serialization=ser, metadata=metadata + ) + elif param == bytes: + return core_schema.no_info_plain_validator_function( + _validators.pattern_bytes_validator, serialization=ser, metadata=metadata + ) + else: + raise PydanticSchemaGenerationError(f'Unable to generate pydantic-core schema for {pattern_type!r}.') + + def _hashable_schema(self) -> core_schema.CoreSchema: + return core_schema.custom_error_schema( + core_schema.is_instance_schema(collections.abc.Hashable), + custom_error_type='is_hashable', + custom_error_message='Input should be hashable', + ) + + def _dataclass_schema( + self, dataclass: type[StandardDataclass], origin: type[StandardDataclass] | None + ) -> core_schema.CoreSchema: + """Generate schema for a dataclass.""" + with self.defs.get_schema_or_ref(dataclass) as (dataclass_ref, maybe_schema): + if maybe_schema is not None: + return maybe_schema + + typevars_map = get_standard_typevars_map(dataclass) + if origin is not None: + dataclass = origin + + config = getattr(dataclass, '__pydantic_config__', None) + with self._config_wrapper_stack.push(config), self._types_namespace_stack.push(dataclass): + core_config = self._config_wrapper.core_config(dataclass) + + self = self._current_generate_schema + + from ..dataclasses import is_pydantic_dataclass + + if is_pydantic_dataclass(dataclass): + fields = deepcopy(dataclass.__pydantic_fields__) + if typevars_map: + for field in fields.values(): + field.apply_typevars_map(typevars_map, self._types_namespace) + else: + fields = collect_dataclass_fields( + dataclass, + self._types_namespace, + typevars_map=typevars_map, + ) + + # disallow combination of init=False on a dataclass field and extra='allow' on a dataclass + if config and config.get('extra') == 'allow': + # disallow combination of init=False on a dataclass field and extra='allow' on a dataclass + for field_name, field in fields.items(): + if field.init is False: + raise PydanticUserError( + f'Field {field_name} has `init=False` and dataclass has config setting `extra="allow"`. ' + f'This combination is not allowed.', + code='dataclass-init-false-extra-allow', + ) + + decorators = dataclass.__dict__.get('__pydantic_decorators__') or DecoratorInfos.build(dataclass) + # Move kw_only=False args to the start of the list, as this is how vanilla dataclasses work. 
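A usage sketch of the dataclass schema generation above, assuming pydantic v2 (`User` is a hypothetical example class):

    from pydantic.dataclasses import dataclass

    @dataclass
    class User:
        id: int
        name: str = 'Jane Doe'

    User(id='1')  # -> User(id=1, name='Jane Doe'); '1' is coerced to int at init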
+ # Note that when kw_only is missing or None, it is treated as equivalent to kw_only=True + args = sorted( + (self._generate_dc_field_schema(k, v, decorators) for k, v in fields.items()), + key=lambda a: a.get('kw_only') is not False, + ) + has_post_init = hasattr(dataclass, '__post_init__') + has_slots = hasattr(dataclass, '__slots__') + + args_schema = core_schema.dataclass_args_schema( + dataclass.__name__, + args, + computed_fields=[ + self._computed_field_schema(d, decorators.field_serializers) + for d in decorators.computed_fields.values() + ], + collect_init_only=has_post_init, + ) + + inner_schema = apply_validators(args_schema, decorators.root_validators.values(), None) + + model_validators = decorators.model_validators.values() + inner_schema = apply_model_validators(inner_schema, model_validators, 'inner') + + dc_schema = core_schema.dataclass_schema( + dataclass, + inner_schema, + post_init=has_post_init, + ref=dataclass_ref, + fields=[field.name for field in dataclasses.fields(dataclass)], + slots=has_slots, + config=core_config, + ) + schema = self._apply_model_serializers(dc_schema, decorators.model_serializers.values()) + schema = apply_model_validators(schema, model_validators, 'outer') + self.defs.definitions[dataclass_ref] = self._post_process_generated_schema(schema) + return core_schema.definition_reference_schema(dataclass_ref) + + def _callable_schema(self, function: Callable[..., Any]) -> core_schema.CallSchema: + """Generate schema for a Callable. + + TODO support functional validators once we support them in Config + """ + sig = signature(function) + + type_hints = _typing_extra.get_function_type_hints(function) + + mode_lookup: dict[_ParameterKind, Literal['positional_only', 'positional_or_keyword', 'keyword_only']] = { + Parameter.POSITIONAL_ONLY: 'positional_only', + Parameter.POSITIONAL_OR_KEYWORD: 'positional_or_keyword', + Parameter.KEYWORD_ONLY: 'keyword_only', + } + + arguments_list: list[core_schema.ArgumentsParameter] = [] + var_args_schema: core_schema.CoreSchema | None = None + var_kwargs_schema: core_schema.CoreSchema | None = None + + for name, p in sig.parameters.items(): + if p.annotation is sig.empty: + annotation = Any + else: + annotation = type_hints[name] + + parameter_mode = mode_lookup.get(p.kind) + if parameter_mode is not None: + arg_schema = self._generate_parameter_schema(name, annotation, p.default, parameter_mode) + arguments_list.append(arg_schema) + elif p.kind == Parameter.VAR_POSITIONAL: + var_args_schema = self.generate_schema(annotation) + else: + assert p.kind == Parameter.VAR_KEYWORD, p.kind + var_kwargs_schema = self.generate_schema(annotation) + + return_schema: core_schema.CoreSchema | None = None + config_wrapper = self._config_wrapper + if config_wrapper.validate_return: + return_hint = type_hints.get('return') + if return_hint is not None: + return_schema = self.generate_schema(return_hint) + + return core_schema.call_schema( + core_schema.arguments_schema( + arguments_list, + var_args_schema=var_args_schema, + var_kwargs_schema=var_kwargs_schema, + populate_by_name=config_wrapper.populate_by_name, + ), + function, + return_schema=return_schema, + ) + + def _unsubstituted_typevar_schema(self, typevar: typing.TypeVar) -> core_schema.CoreSchema: + assert isinstance(typevar, typing.TypeVar) + + bound = typevar.__bound__ + constraints = typevar.__constraints__ + default = getattr(typevar, '__default__', None) + + if (bound is not None) + (len(constraints) != 0) + (default is not None) > 1: + raise NotImplementedError( + 
'Pydantic does not support mixing more than one of TypeVar bounds, constraints and defaults' + ) + + if default is not None: + return self.generate_schema(default) + elif constraints: + return self._union_schema(typing.Union[constraints]) # type: ignore + elif bound: + schema = self.generate_schema(bound) + schema['serialization'] = core_schema.wrap_serializer_function_ser_schema( + lambda x, h: h(x), schema=core_schema.any_schema() + ) + return schema + else: + return core_schema.any_schema() + + def _computed_field_schema( + self, + d: Decorator[ComputedFieldInfo], + field_serializers: dict[str, Decorator[FieldSerializerDecoratorInfo]], + ) -> core_schema.ComputedField: + try: + return_type = _decorators.get_function_return_type(d.func, d.info.return_type, self._types_namespace) + except NameError as e: + raise PydanticUndefinedAnnotation.from_name_error(e) from e + if return_type is PydanticUndefined: + raise PydanticUserError( + 'Computed field is missing return type annotation or specifying `return_type`' + ' to the `@computed_field` decorator (e.g. `@computed_field(return_type=int|str)`)', + code='model-field-missing-annotation', + ) + + return_type = replace_types(return_type, self._typevars_map) + # Create a new ComputedFieldInfo so that different type parametrizations of the same + # generic model's computed field can have different return types. + d.info = dataclasses.replace(d.info, return_type=return_type) + return_type_schema = self.generate_schema(return_type) + # Apply serializers to computed field if there exist + return_type_schema = self._apply_field_serializers( + return_type_schema, + filter_field_decorator_info_by_field(field_serializers.values(), d.cls_var_name), + computed_field=True, + ) + + alias_generator = self._config_wrapper.alias_generator + if alias_generator is not None: + self._apply_alias_generator_to_computed_field_info( + alias_generator=alias_generator, computed_field_info=d.info, computed_field_name=d.cls_var_name + ) + + def set_computed_field_metadata(schema: CoreSchemaOrField, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + json_schema = handler(schema) + + json_schema['readOnly'] = True + + title = d.info.title + if title is not None: + json_schema['title'] = title + + description = d.info.description + if description is not None: + json_schema['description'] = description + + examples = d.info.examples + if examples is not None: + json_schema['examples'] = to_jsonable_python(examples) + + json_schema_extra = d.info.json_schema_extra + if json_schema_extra is not None: + add_json_schema_extra(json_schema, json_schema_extra) + + return json_schema + + metadata = build_metadata_dict(js_annotation_functions=[set_computed_field_metadata]) + return core_schema.computed_field( + d.cls_var_name, return_schema=return_type_schema, alias=d.info.alias, metadata=metadata + ) + + def _annotated_schema(self, annotated_type: Any) -> core_schema.CoreSchema: + """Generate schema for an Annotated type, e.g. 
`Annotated[int, Field(...)]` or `Annotated[int, Gt(0)]`.""" + from ..fields import FieldInfo + + source_type, *annotations = self._get_args_resolving_forward_refs( + annotated_type, + required=True, + ) + schema = self._apply_annotations(source_type, annotations) + # put the default validator last so that TypeAdapter.get_default_value() works + # even if there are function validators involved + for annotation in annotations: + if isinstance(annotation, FieldInfo): + schema = wrap_default(annotation, schema) + return schema + + def _get_prepare_pydantic_annotations_for_known_type( + self, obj: Any, annotations: tuple[Any, ...] + ) -> tuple[Any, list[Any]] | None: + from ._std_types_schema import PREPARE_METHODS + + # Check for hashability + try: + hash(obj) + except TypeError: + # obj is definitely not a known type if this fails + return None + + for gen in PREPARE_METHODS: + res = gen(obj, annotations, self._config_wrapper.config_dict) + if res is not None: + return res + + return None + + def _apply_annotations( + self, + source_type: Any, + annotations: list[Any], + transform_inner_schema: Callable[[CoreSchema], CoreSchema] = lambda x: x, + ) -> CoreSchema: + """Apply arguments from `Annotated` or from `FieldInfo` to a schema. + + This gets called by `GenerateSchema._annotated_schema` but differs from it in that it does + not expect `source_type` to be an `Annotated` object, it expects it to be the first argument of that + (in other words, `GenerateSchema._annotated_schema` just unpacks `Annotated`, this process it). + """ + annotations = list(_known_annotated_metadata.expand_grouped_metadata(annotations)) + res = self._get_prepare_pydantic_annotations_for_known_type(source_type, tuple(annotations)) + if res is not None: + source_type, annotations = res + + pydantic_js_annotation_functions: list[GetJsonSchemaFunction] = [] + + def inner_handler(obj: Any) -> CoreSchema: + from_property = self._generate_schema_from_property(obj, obj) + if from_property is None: + schema = self._generate_schema(obj) + else: + schema = from_property + metadata_js_function = _extract_get_pydantic_json_schema(obj, schema) + if metadata_js_function is not None: + metadata_schema = resolve_original_schema(schema, self.defs.definitions) + if metadata_schema is not None: + self._add_js_function(metadata_schema, metadata_js_function) + return transform_inner_schema(schema) + + get_inner_schema = CallbackGetCoreSchemaHandler(inner_handler, self) + + for annotation in annotations: + if annotation is None: + continue + get_inner_schema = self._get_wrapped_inner_schema( + get_inner_schema, annotation, pydantic_js_annotation_functions + ) + + schema = get_inner_schema(source_type) + if pydantic_js_annotation_functions: + metadata = CoreMetadataHandler(schema).metadata + metadata.setdefault('pydantic_js_annotation_functions', []).extend(pydantic_js_annotation_functions) + return _add_custom_serialization_from_json_encoders(self._config_wrapper.json_encoders, source_type, schema) + + def _apply_single_annotation(self, schema: core_schema.CoreSchema, metadata: Any) -> core_schema.CoreSchema: + from ..fields import FieldInfo + + if isinstance(metadata, FieldInfo): + for field_metadata in metadata.metadata: + schema = self._apply_single_annotation(schema, field_metadata) + + if metadata.discriminator is not None: + schema = self._apply_discriminator_to_union(schema, metadata.discriminator) + return schema + + if schema['type'] == 'nullable': + # for nullable schemas, metadata is automatically applied to the inner schema + 
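A usage sketch of how annotation metadata is pushed into a nullable schema, per the comment above, assuming pydantic v2 and the `annotated_types` package:

    from typing import Annotated, Optional
    from annotated_types import Gt
    from pydantic import TypeAdapter

    adapter = TypeAdapter(Annotated[Optional[int], Gt(0)])
    adapter.validate_python(5)     # -> 5
    adapter.validate_python(None)  # -> None; the Gt(0) constraint applies only to the inner int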
inner = schema.get('schema', core_schema.any_schema()) + inner = self._apply_single_annotation(inner, metadata) + if inner: + schema['schema'] = inner + return schema + + original_schema = schema + ref = schema.get('ref', None) + if ref is not None: + schema = schema.copy() + new_ref = ref + f'_{repr(metadata)}' + if new_ref in self.defs.definitions: + return self.defs.definitions[new_ref] + schema['ref'] = new_ref # type: ignore + elif schema['type'] == 'definition-ref': + ref = schema['schema_ref'] + if ref in self.defs.definitions: + schema = self.defs.definitions[ref].copy() + new_ref = ref + f'_{repr(metadata)}' + if new_ref in self.defs.definitions: + return self.defs.definitions[new_ref] + schema['ref'] = new_ref # type: ignore + + maybe_updated_schema = _known_annotated_metadata.apply_known_metadata(metadata, schema.copy()) + + if maybe_updated_schema is not None: + return maybe_updated_schema + return original_schema + + def _apply_single_annotation_json_schema( + self, schema: core_schema.CoreSchema, metadata: Any + ) -> core_schema.CoreSchema: + from ..fields import FieldInfo + + if isinstance(metadata, FieldInfo): + for field_metadata in metadata.metadata: + schema = self._apply_single_annotation_json_schema(schema, field_metadata) + json_schema_update: JsonSchemaValue = {} + if metadata.title: + json_schema_update['title'] = metadata.title + if metadata.description: + json_schema_update['description'] = metadata.description + if metadata.examples: + json_schema_update['examples'] = to_jsonable_python(metadata.examples) + + json_schema_extra = metadata.json_schema_extra + if json_schema_update or json_schema_extra: + CoreMetadataHandler(schema).metadata.setdefault('pydantic_js_annotation_functions', []).append( + get_json_schema_update_func(json_schema_update, json_schema_extra) + ) + return schema + + def _get_wrapped_inner_schema( + self, + get_inner_schema: GetCoreSchemaHandler, + annotation: Any, + pydantic_js_annotation_functions: list[GetJsonSchemaFunction], + ) -> CallbackGetCoreSchemaHandler: + metadata_get_schema: GetCoreSchemaFunction = getattr(annotation, '__get_pydantic_core_schema__', None) or ( + lambda source, handler: handler(source) + ) + + def new_handler(source: Any) -> core_schema.CoreSchema: + schema = metadata_get_schema(source, get_inner_schema) + schema = self._apply_single_annotation(schema, annotation) + schema = self._apply_single_annotation_json_schema(schema, annotation) + + metadata_js_function = _extract_get_pydantic_json_schema(annotation, schema) + if metadata_js_function is not None: + pydantic_js_annotation_functions.append(metadata_js_function) + return schema + + return CallbackGetCoreSchemaHandler(new_handler, self) + + def _apply_field_serializers( + self, + schema: core_schema.CoreSchema, + serializers: list[Decorator[FieldSerializerDecoratorInfo]], + computed_field: bool = False, + ) -> core_schema.CoreSchema: + """Apply field serializers to a schema.""" + if serializers: + schema = copy(schema) + if schema['type'] == 'definitions': + inner_schema = schema['schema'] + schema['schema'] = self._apply_field_serializers(inner_schema, serializers) + return schema + else: + ref = typing.cast('str|None', schema.get('ref', None)) + if ref is not None: + schema = core_schema.definition_reference_schema(ref) + + # use the last serializer to make it easy to override a serializer set on a parent model + serializer = serializers[-1] + is_field_serializer, info_arg = inspect_field_serializer( + serializer.func, serializer.info.mode, 
computed_field=computed_field + ) + + try: + return_type = _decorators.get_function_return_type( + serializer.func, serializer.info.return_type, self._types_namespace + ) + except NameError as e: + raise PydanticUndefinedAnnotation.from_name_error(e) from e + + if return_type is PydanticUndefined: + return_schema = None + else: + return_schema = self.generate_schema(return_type) + + if serializer.info.mode == 'wrap': + schema['serialization'] = core_schema.wrap_serializer_function_ser_schema( + serializer.func, + is_field_serializer=is_field_serializer, + info_arg=info_arg, + return_schema=return_schema, + when_used=serializer.info.when_used, + ) + else: + assert serializer.info.mode == 'plain' + schema['serialization'] = core_schema.plain_serializer_function_ser_schema( + serializer.func, + is_field_serializer=is_field_serializer, + info_arg=info_arg, + return_schema=return_schema, + when_used=serializer.info.when_used, + ) + return schema + + def _apply_model_serializers( + self, schema: core_schema.CoreSchema, serializers: Iterable[Decorator[ModelSerializerDecoratorInfo]] + ) -> core_schema.CoreSchema: + """Apply model serializers to a schema.""" + ref: str | None = schema.pop('ref', None) # type: ignore + if serializers: + serializer = list(serializers)[-1] + info_arg = inspect_model_serializer(serializer.func, serializer.info.mode) + + try: + return_type = _decorators.get_function_return_type( + serializer.func, serializer.info.return_type, self._types_namespace + ) + except NameError as e: + raise PydanticUndefinedAnnotation.from_name_error(e) from e + if return_type is PydanticUndefined: + return_schema = None + else: + return_schema = self.generate_schema(return_type) + + if serializer.info.mode == 'wrap': + ser_schema: core_schema.SerSchema = core_schema.wrap_serializer_function_ser_schema( + serializer.func, + info_arg=info_arg, + return_schema=return_schema, + when_used=serializer.info.when_used, + ) + else: + # plain + ser_schema = core_schema.plain_serializer_function_ser_schema( + serializer.func, + info_arg=info_arg, + return_schema=return_schema, + when_used=serializer.info.when_used, + ) + schema['serialization'] = ser_schema + if ref: + schema['ref'] = ref # type: ignore + return schema + + +_VALIDATOR_F_MATCH: Mapping[ + tuple[FieldValidatorModes, Literal['no-info', 'with-info']], + Callable[[Callable[..., Any], core_schema.CoreSchema, str | None], core_schema.CoreSchema], +] = { + ('before', 'no-info'): lambda f, schema, _: core_schema.no_info_before_validator_function(f, schema), + ('after', 'no-info'): lambda f, schema, _: core_schema.no_info_after_validator_function(f, schema), + ('plain', 'no-info'): lambda f, _1, _2: core_schema.no_info_plain_validator_function(f), + ('wrap', 'no-info'): lambda f, schema, _: core_schema.no_info_wrap_validator_function(f, schema), + ('before', 'with-info'): lambda f, schema, field_name: core_schema.with_info_before_validator_function( + f, schema, field_name=field_name + ), + ('after', 'with-info'): lambda f, schema, field_name: core_schema.with_info_after_validator_function( + f, schema, field_name=field_name + ), + ('plain', 'with-info'): lambda f, _, field_name: core_schema.with_info_plain_validator_function( + f, field_name=field_name + ), + ('wrap', 'with-info'): lambda f, schema, field_name: core_schema.with_info_wrap_validator_function( + f, schema, field_name=field_name + ), +} + + +def apply_validators( + schema: core_schema.CoreSchema, + validators: Iterable[Decorator[RootValidatorDecoratorInfo]] + | 
Iterable[Decorator[ValidatorDecoratorInfo]] + | Iterable[Decorator[FieldValidatorDecoratorInfo]], + field_name: str | None, +) -> core_schema.CoreSchema: + """Apply validators to a schema. + + Args: + schema: The schema to apply validators on. + validators: An iterable of validators. + field_name: The name of the field if validators are being applied to a model field. + + Returns: + The updated schema. + """ + for validator in validators: + info_arg = inspect_validator(validator.func, validator.info.mode) + val_type = 'with-info' if info_arg else 'no-info' + + schema = _VALIDATOR_F_MATCH[(validator.info.mode, val_type)](validator.func, schema, field_name) + return schema + + +def _validators_require_validate_default(validators: Iterable[Decorator[ValidatorDecoratorInfo]]) -> bool: + """In v1, if any of the validators for a field had `always=True`, the default value would be validated. + + This serves as an auxiliary function for re-implementing that logic, by looping over a provided + collection of (v1-style) ValidatorDecoratorInfo's and checking if any of them have `always=True`. + + We should be able to drop this function and the associated logic calling it once we drop support + for v1-style validator decorators. (Or we can extend it and keep it if we add something equivalent + to the v1-validator `always` kwarg to `field_validator`.) + """ + for validator in validators: + if validator.info.always: + return True + return False + + +def apply_model_validators( + schema: core_schema.CoreSchema, + validators: Iterable[Decorator[ModelValidatorDecoratorInfo]], + mode: Literal['inner', 'outer', 'all'], +) -> core_schema.CoreSchema: + """Apply model validators to a schema. + + If mode == 'inner', only "before" validators are applied + If mode == 'outer', validators other than "before" are applied + If mode == 'all', all validators are applied + + Args: + schema: The schema to apply validators on. + validators: An iterable of validators. + mode: The validator mode. + + Returns: + The updated schema. + """ + ref: str | None = schema.pop('ref', None) # type: ignore + for validator in validators: + if mode == 'inner' and validator.info.mode != 'before': + continue + if mode == 'outer' and validator.info.mode == 'before': + continue + info_arg = inspect_validator(validator.func, validator.info.mode) + if validator.info.mode == 'wrap': + if info_arg: + schema = core_schema.with_info_wrap_validator_function(function=validator.func, schema=schema) + else: + schema = core_schema.no_info_wrap_validator_function(function=validator.func, schema=schema) + elif validator.info.mode == 'before': + if info_arg: + schema = core_schema.with_info_before_validator_function(function=validator.func, schema=schema) + else: + schema = core_schema.no_info_before_validator_function(function=validator.func, schema=schema) + else: + assert validator.info.mode == 'after' + if info_arg: + schema = core_schema.with_info_after_validator_function(function=validator.func, schema=schema) + else: + schema = core_schema.no_info_after_validator_function(function=validator.func, schema=schema) + if ref: + schema['ref'] = ref # type: ignore + return schema + + +def wrap_default(field_info: FieldInfo, schema: core_schema.CoreSchema) -> core_schema.CoreSchema: + """Wrap schema with default schema if default value or `default_factory` are available. + + Args: + field_info: The field info object. + schema: The schema to apply default on. + + Returns: + Updated schema by default value or `default_factory`. 
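+ For example (illustrative): a field declared as `x: int = 3` is wrapped as
+ `with_default_schema(<int schema>, default=3, validate_default=...)`, while a field
+ with neither a default nor a `default_factory` is returned unchanged.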
+ """ + if field_info.default_factory: + return core_schema.with_default_schema( + schema, default_factory=field_info.default_factory, validate_default=field_info.validate_default + ) + elif field_info.default is not PydanticUndefined: + return core_schema.with_default_schema( + schema, default=field_info.default, validate_default=field_info.validate_default + ) + else: + return schema + + +def _extract_get_pydantic_json_schema(tp: Any, schema: CoreSchema) -> GetJsonSchemaFunction | None: + """Extract `__get_pydantic_json_schema__` from a type, handling the deprecated `__modify_schema__`.""" + js_modify_function = getattr(tp, '__get_pydantic_json_schema__', None) + + if hasattr(tp, '__modify_schema__'): + from pydantic import BaseModel # circular reference + + has_custom_v2_modify_js_func = ( + js_modify_function is not None + and BaseModel.__get_pydantic_json_schema__.__func__ # type: ignore + not in (js_modify_function, getattr(js_modify_function, '__func__', None)) + ) + + if not has_custom_v2_modify_js_func: + raise PydanticUserError( + 'The `__modify_schema__` method is not supported in Pydantic v2. ' + 'Use `__get_pydantic_json_schema__` instead.', + code='custom-json-schema', + ) + + # handle GenericAlias' but ignore Annotated which "lies" about its origin (in this case it would be `int`) + if hasattr(tp, '__origin__') and not isinstance(tp, type(Annotated[int, 'placeholder'])): + return _extract_get_pydantic_json_schema(tp.__origin__, schema) + + if js_modify_function is None: + return None + + return js_modify_function + + +def get_json_schema_update_func( + json_schema_update: JsonSchemaValue, json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None +) -> GetJsonSchemaFunction: + def json_schema_update_func( + core_schema_or_field: CoreSchemaOrField, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + json_schema = {**handler(core_schema_or_field), **json_schema_update} + add_json_schema_extra(json_schema, json_schema_extra) + return json_schema + + return json_schema_update_func + + +def add_json_schema_extra( + json_schema: JsonSchemaValue, json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None +): + if isinstance(json_schema_extra, dict): + json_schema.update(to_jsonable_python(json_schema_extra)) + elif callable(json_schema_extra): + json_schema_extra(json_schema) + + +class _CommonField(TypedDict): + schema: core_schema.CoreSchema + validation_alias: str | list[str | int] | list[list[str | int]] | None + serialization_alias: str | None + serialization_exclude: bool | None + frozen: bool | None + metadata: dict[str, Any] + + +def _common_field( + schema: core_schema.CoreSchema, + *, + validation_alias: str | list[str | int] | list[list[str | int]] | None = None, + serialization_alias: str | None = None, + serialization_exclude: bool | None = None, + frozen: bool | None = None, + metadata: Any = None, +) -> _CommonField: + return { + 'schema': schema, + 'validation_alias': validation_alias, + 'serialization_alias': serialization_alias, + 'serialization_exclude': serialization_exclude, + 'frozen': frozen, + 'metadata': metadata, + } + + +class _Definitions: + """Keeps track of references and definitions.""" + + def __init__(self) -> None: + self.seen: set[str] = set() + self.definitions: dict[str, core_schema.CoreSchema] = {} + + @contextmanager + def get_schema_or_ref(self, tp: Any) -> Iterator[tuple[str, None] | tuple[str, CoreSchema]]: + """Get a definition for `tp` if one exists. 
+ + If a definition exists, a tuple of `(ref_string, CoreSchema)` is returned. + If no definition exists yet, a tuple of `(ref_string, None)` is returned. + + Note that the returned `CoreSchema` will always be a `DefinitionReferenceSchema`, + not the actual definition itself. + + This should be called for any type that can be identified by reference. + This includes any recursive types. + + At present the following types can be named/recursive: + + - BaseModel + - Dataclasses + - TypedDict + - TypeAliasType + """ + ref = get_type_ref(tp) + # return the reference if we're either (1) in a cycle or (2) it was already defined + if ref in self.seen or ref in self.definitions: + yield (ref, core_schema.definition_reference_schema(ref)) + else: + self.seen.add(ref) + try: + yield (ref, None) + finally: + self.seen.discard(ref) + + +def resolve_original_schema(schema: CoreSchema, definitions: dict[str, CoreSchema]) -> CoreSchema | None: + if schema['type'] == 'definition-ref': + return definitions.get(schema['schema_ref'], None) + elif schema['type'] == 'definitions': + return schema['schema'] + else: + return schema + + +class _FieldNameStack: + __slots__ = ('_stack',) + + def __init__(self) -> None: + self._stack: list[str] = [] + + @contextmanager + def push(self, field_name: str) -> Iterator[None]: + self._stack.append(field_name) + yield + self._stack.pop() + + def get(self) -> str | None: + if self._stack: + return self._stack[-1] + else: + return None diff --git a/lib/pydantic/_internal/_generics.py b/lib/pydantic/_internal/_generics.py new file mode 100644 index 00000000..5a66eaa9 --- /dev/null +++ b/lib/pydantic/_internal/_generics.py @@ -0,0 +1,517 @@ +from __future__ import annotations + +import sys +import types +import typing +from collections import ChainMap +from contextlib import contextmanager +from contextvars import ContextVar +from types import prepare_class +from typing import TYPE_CHECKING, Any, Iterator, List, Mapping, MutableMapping, Tuple, TypeVar +from weakref import WeakValueDictionary + +import typing_extensions + +from ._core_utils import get_type_ref +from ._forward_ref import PydanticRecursiveRef +from ._typing_extra import TypeVarType, typing_base +from ._utils import all_identical, is_model_class + +if sys.version_info >= (3, 10): + from typing import _UnionGenericAlias # type: ignore[attr-defined] + +if TYPE_CHECKING: + from ..main import BaseModel + +GenericTypesCacheKey = Tuple[Any, Any, Tuple[Any, ...]] + +# Note: We want to remove LimitedDict, but to do this, we'd need to improve the handling of generics caching. +# Right now, to handle recursive generics, we some types must remain cached for brief periods without references. +# By chaining the WeakValuesDict with a LimitedDict, we have a way to retain caching for all types with references, +# while also retaining a limited number of types even without references. This is generally enough to build +# specific recursive generic models without losing required items out of the cache. + +KT = TypeVar('KT') +VT = TypeVar('VT') +_LIMITED_DICT_SIZE = 100 +if TYPE_CHECKING: + + class LimitedDict(dict, MutableMapping[KT, VT]): + def __init__(self, size_limit: int = _LIMITED_DICT_SIZE): + ... + +else: + + class LimitedDict(dict): + """Limit the size/length of a dict used for caching to avoid unlimited increase in memory usage. + + Since the dict is ordered, and we always remove elements from the beginning, this is effectively a FIFO cache. 
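+ For example, with the default size_limit of 100, inserting the 101st item evicts the 11 oldest
+ entries (the 1-item excess plus size_limit // 10), oldest first.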
+ """ + + def __init__(self, size_limit: int = _LIMITED_DICT_SIZE): + self.size_limit = size_limit + super().__init__() + + def __setitem__(self, __key: Any, __value: Any) -> None: + super().__setitem__(__key, __value) + if len(self) > self.size_limit: + excess = len(self) - self.size_limit + self.size_limit // 10 + to_remove = list(self.keys())[:excess] + for key in to_remove: + del self[key] + + +# weak dictionaries allow the dynamically created parametrized versions of generic models to get collected +# once they are no longer referenced by the caller. +if sys.version_info >= (3, 9): # Typing for weak dictionaries available at 3.9 + GenericTypesCache = WeakValueDictionary[GenericTypesCacheKey, 'type[BaseModel]'] +else: + GenericTypesCache = WeakValueDictionary + +if TYPE_CHECKING: + + class DeepChainMap(ChainMap[KT, VT]): # type: ignore + ... + +else: + + class DeepChainMap(ChainMap): + """Variant of ChainMap that allows direct updates to inner scopes. + + Taken from https://docs.python.org/3/library/collections.html#collections.ChainMap, + with some light modifications for this use case. + """ + + def clear(self) -> None: + for mapping in self.maps: + mapping.clear() + + def __setitem__(self, key: KT, value: VT) -> None: + for mapping in self.maps: + mapping[key] = value + + def __delitem__(self, key: KT) -> None: + hit = False + for mapping in self.maps: + if key in mapping: + del mapping[key] + hit = True + if not hit: + raise KeyError(key) + + +# Despite the fact that LimitedDict _seems_ no longer necessary, I'm very nervous to actually remove it +# and discover later on that we need to re-add all this infrastructure... +# _GENERIC_TYPES_CACHE = DeepChainMap(GenericTypesCache(), LimitedDict()) + +_GENERIC_TYPES_CACHE = GenericTypesCache() + + +class PydanticGenericMetadata(typing_extensions.TypedDict): + origin: type[BaseModel] | None # analogous to typing._GenericAlias.__origin__ + args: tuple[Any, ...] # analogous to typing._GenericAlias.__args__ + parameters: tuple[type[Any], ...] # analogous to typing.Generic.__parameters__ + + +def create_generic_submodel( + model_name: str, origin: type[BaseModel], args: tuple[Any, ...], params: tuple[Any, ...] +) -> type[BaseModel]: + """Dynamically create a submodel of a provided (generic) BaseModel. + + This is used when producing concrete parametrizations of generic models. This function + only *creates* the new subclass; the schema/validators/serialization must be updated to + reflect a concrete parametrization elsewhere. + + Args: + model_name: The name of the newly created model. + origin: The base class for the new model to inherit from. + args: A tuple of generic metadata arguments. + params: A tuple of generic metadata parameters. + + Returns: + The created submodel. 
+ """ + namespace: dict[str, Any] = {'__module__': origin.__module__} + bases = (origin,) + meta, ns, kwds = prepare_class(model_name, bases) + namespace.update(ns) + created_model = meta( + model_name, + bases, + namespace, + __pydantic_generic_metadata__={ + 'origin': origin, + 'args': args, + 'parameters': params, + }, + __pydantic_reset_parent_namespace__=False, + **kwds, + ) + + model_module, called_globally = _get_caller_frame_info(depth=3) + if called_globally: # create global reference and therefore allow pickling + object_by_reference = None + reference_name = model_name + reference_module_globals = sys.modules[created_model.__module__].__dict__ + while object_by_reference is not created_model: + object_by_reference = reference_module_globals.setdefault(reference_name, created_model) + reference_name += '_' + + return created_model + + +def _get_caller_frame_info(depth: int = 2) -> tuple[str | None, bool]: + """Used inside a function to check whether it was called globally. + + Args: + depth: The depth to get the frame. + + Returns: + A tuple contains `module_name` and `called_globally`. + + Raises: + RuntimeError: If the function is not called inside a function. + """ + try: + previous_caller_frame = sys._getframe(depth) + except ValueError as e: + raise RuntimeError('This function must be used inside another function') from e + except AttributeError: # sys module does not have _getframe function, so there's nothing we can do about it + return None, False + frame_globals = previous_caller_frame.f_globals + return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals + + +DictValues: type[Any] = {}.values().__class__ + + +def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]: + """Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found. + + This is inspired as an alternative to directly accessing the `__parameters__` attribute of a GenericAlias, + since __parameters__ of (nested) generic BaseModel subclasses won't show up in that list. + """ + if isinstance(v, TypeVar): + yield v + elif is_model_class(v): + yield from v.__pydantic_generic_metadata__['parameters'] + elif isinstance(v, (DictValues, list)): + for var in v: + yield from iter_contained_typevars(var) + else: + args = get_args(v) + for arg in args: + yield from iter_contained_typevars(arg) + + +def get_args(v: Any) -> Any: + pydantic_generic_metadata: PydanticGenericMetadata | None = getattr(v, '__pydantic_generic_metadata__', None) + if pydantic_generic_metadata: + return pydantic_generic_metadata.get('args') + return typing_extensions.get_args(v) + + +def get_origin(v: Any) -> Any: + pydantic_generic_metadata: PydanticGenericMetadata | None = getattr(v, '__pydantic_generic_metadata__', None) + if pydantic_generic_metadata: + return pydantic_generic_metadata.get('origin') + return typing_extensions.get_origin(v) + + +def get_standard_typevars_map(cls: type[Any]) -> dict[TypeVarType, Any] | None: + """Package a generic type's typevars and parametrization (if present) into a dictionary compatible with the + `replace_types` function. Specifically, this works with standard typing generics and typing._GenericAlias. + """ + origin = get_origin(cls) + if origin is None: + return None + if not hasattr(origin, '__parameters__'): + return None + + # In this case, we know that cls is a _GenericAlias, and origin is the generic type + # So it is safe to access cls.__args__ and origin.__parameters__ + args: tuple[Any, ...] 
= cls.__args__ # type: ignore + parameters: tuple[TypeVarType, ...] = origin.__parameters__ + return dict(zip(parameters, args)) + + +def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVarType, Any] | None: + """Package a generic BaseModel's typevars and concrete parametrization (if present) into a dictionary compatible + with the `replace_types` function. + + Since BaseModel.__class_getitem__ does not produce a typing._GenericAlias, and the BaseModel generic info is + stored in the __pydantic_generic_metadata__ attribute, we need special handling here. + """ + # TODO: This could be unified with `get_standard_typevars_map` if we stored the generic metadata + # in the __origin__, __args__, and __parameters__ attributes of the model. + generic_metadata = cls.__pydantic_generic_metadata__ + origin = generic_metadata['origin'] + args = generic_metadata['args'] + return dict(zip(iter_contained_typevars(origin), args)) + + +def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any: + """Return type with all occurrences of `type_map` keys recursively replaced with their values. + + Args: + type_: The class or generic alias. + type_map: Mapping from `TypeVar` instance to concrete types. + + Returns: + A new type representing the basic structure of `type_` with all + `typevar_map` keys recursively replaced. + + Example: + ```py + from typing import List, Tuple, Union + + from pydantic._internal._generics import replace_types + + replace_types(Tuple[str, Union[List[str], float]], {str: int}) + #> Tuple[int, Union[List[int], float]] + ``` + """ + if not type_map: + return type_ + + type_args = get_args(type_) + origin_type = get_origin(type_) + + if origin_type is typing_extensions.Annotated: + annotated_type, *annotations = type_args + annotated = replace_types(annotated_type, type_map) + for annotation in annotations: + annotated = typing_extensions.Annotated[annotated, annotation] + return annotated + + # Having type args is a good indicator that this is a typing module + # class instantiation or a generic alias of some sort. + if type_args: + resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args) + if all_identical(type_args, resolved_type_args): + # If all arguments are the same, there is no need to modify the + # type or create a new object at all + return type_ + if ( + origin_type is not None + and isinstance(type_, typing_base) + and not isinstance(origin_type, typing_base) + and getattr(type_, '_name', None) is not None + ): + # In python < 3.9 generic aliases don't exist so any of these like `list`, + # `type` or `collections.abc.Callable` need to be translated. + # See: https://www.python.org/dev/peps/pep-0585 + origin_type = getattr(typing, type_._name) + assert origin_type is not None + # PEP-604 syntax (Ex.: list | str) is represented with a types.UnionType object that does not have __getitem__. + # We also cannot use isinstance() since we have to compare types. 
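+ # e.g. replace_types(list[int] | None, {int: str}) takes this branch on Python 3.10+, rebuilding the union as Union[list[str], None]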
+ if sys.version_info >= (3, 10) and origin_type is types.UnionType: + return _UnionGenericAlias(origin_type, resolved_type_args) + # NotRequired[T] and Required[T] don't support tuple type resolved_type_args, hence the condition below + return origin_type[resolved_type_args[0] if len(resolved_type_args) == 1 else resolved_type_args] + + # We handle pydantic generic models separately as they don't have the same + # semantics as "typing" classes or generic aliases + + if not origin_type and is_model_class(type_): + parameters = type_.__pydantic_generic_metadata__['parameters'] + if not parameters: + return type_ + resolved_type_args = tuple(replace_types(t, type_map) for t in parameters) + if all_identical(parameters, resolved_type_args): + return type_ + return type_[resolved_type_args] + + # Handle special case for typehints that can have lists as arguments. + # `typing.Callable[[int, str], int]` is an example for this. + if isinstance(type_, (List, list)): + resolved_list = list(replace_types(element, type_map) for element in type_) + if all_identical(type_, resolved_list): + return type_ + return resolved_list + + # If all else fails, we try to resolve the type directly and otherwise just + # return the input with no modifications. + return type_map.get(type_, type_) + + +def has_instance_in_type(type_: Any, isinstance_target: Any) -> bool: + """Checks if the type, or any of its arbitrary nested args, satisfy + `isinstance(, isinstance_target)`. + """ + if isinstance(type_, isinstance_target): + return True + + type_args = get_args(type_) + origin_type = get_origin(type_) + + if origin_type is typing_extensions.Annotated: + annotated_type, *annotations = type_args + return has_instance_in_type(annotated_type, isinstance_target) + + # Having type args is a good indicator that this is a typing module + # class instantiation or a generic alias of some sort. + if any(has_instance_in_type(a, isinstance_target) for a in type_args): + return True + + # Handle special case for typehints that can have lists as arguments. + # `typing.Callable[[int, str], int]` is an example for this. + if isinstance(type_, (List, list)) and not isinstance(type_, typing_extensions.ParamSpec): + if any(has_instance_in_type(element, isinstance_target) for element in type_): + return True + + return False + + +def check_parameters_count(cls: type[BaseModel], parameters: tuple[Any, ...]) -> None: + """Check the generic model parameters count is equal. + + Args: + cls: The generic model. + parameters: A tuple of passed parameters to the generic model. + + Raises: + TypeError: If the passed parameters count is not equal to generic model parameters count. + """ + actual = len(parameters) + expected = len(cls.__pydantic_generic_metadata__['parameters']) + if actual != expected: + description = 'many' if actual > expected else 'few' + raise TypeError(f'Too {description} parameters for {cls}; actual {actual}, expected {expected}') + + +_generic_recursion_cache: ContextVar[set[str] | None] = ContextVar('_generic_recursion_cache', default=None) + + +@contextmanager +def generic_recursion_self_type( + origin: type[BaseModel], args: tuple[Any, ...] +) -> Iterator[PydanticRecursiveRef | None]: + """This contextmanager should be placed around the recursive calls used to build a generic type, + and accept as arguments the generic origin type and the type arguments being passed to it. 
+ + If the same origin and arguments are observed twice, it implies that a self-reference placeholder + can be used while building the core schema, and will produce a schema_ref that will be valid in the + final parent schema. + """ + previously_seen_type_refs = _generic_recursion_cache.get() + if previously_seen_type_refs is None: + previously_seen_type_refs = set() + token = _generic_recursion_cache.set(previously_seen_type_refs) + else: + token = None + + try: + type_ref = get_type_ref(origin, args_override=args) + if type_ref in previously_seen_type_refs: + self_type = PydanticRecursiveRef(type_ref=type_ref) + yield self_type + else: + previously_seen_type_refs.add(type_ref) + yield None + finally: + if token: + _generic_recursion_cache.reset(token) + + +def recursively_defined_type_refs() -> set[str]: + visited = _generic_recursion_cache.get() + if not visited: + return set() # not in a generic recursion, so there are no types + + return visited.copy() # don't allow modifications + + +def get_cached_generic_type_early(parent: type[BaseModel], typevar_values: Any) -> type[BaseModel] | None: + """The use of a two-stage cache lookup approach was necessary to have the highest performance possible for + repeated calls to `__class_getitem__` on generic types (which may happen in tighter loops during runtime), + while still ensuring that certain alternative parametrizations ultimately resolve to the same type. + + As a concrete example, this approach was necessary to make Model[List[T]][int] equal to Model[List[int]]. + The approach could be modified to not use two different cache keys at different points, but the + _early_cache_key is optimized to be as quick to compute as possible (for repeated-access speed), and the + _late_cache_key is optimized to be as "correct" as possible, so that two types that will ultimately be the + same after resolving the type arguments will always produce cache hits. + + If we wanted to move to only using a single cache key per type, we would either need to always use the + slower/more computationally intensive logic associated with _late_cache_key, or would need to accept + that Model[List[T]][int] is a different type than Model[List[T]][int]. Because we rely on subclass relationships + during validation, I think it is worthwhile to ensure that types that are functionally equivalent are actually + equal. + """ + return _GENERIC_TYPES_CACHE.get(_early_cache_key(parent, typevar_values)) + + +def get_cached_generic_type_late( + parent: type[BaseModel], typevar_values: Any, origin: type[BaseModel], args: tuple[Any, ...] +) -> type[BaseModel] | None: + """See the docstring of `get_cached_generic_type_early` for more information about the two-stage cache lookup.""" + cached = _GENERIC_TYPES_CACHE.get(_late_cache_key(origin, args, typevar_values)) + if cached is not None: + set_cached_generic_type(parent, typevar_values, cached, origin, args) + return cached + + +def set_cached_generic_type( + parent: type[BaseModel], + typevar_values: tuple[Any, ...], + type_: type[BaseModel], + origin: type[BaseModel] | None = None, + args: tuple[Any, ...] | None = None, +) -> None: + """See the docstring of `get_cached_generic_type_early` for more information about why items are cached with + two different keys. 
+ """ + _GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values)] = type_ + if len(typevar_values) == 1: + _GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values[0])] = type_ + if origin and args: + _GENERIC_TYPES_CACHE[_late_cache_key(origin, args, typevar_values)] = type_ + + +def _union_orderings_key(typevar_values: Any) -> Any: + """This is intended to help differentiate between Union types with the same arguments in different order. + + Thanks to caching internal to the `typing` module, it is not possible to distinguish between + List[Union[int, float]] and List[Union[float, int]] (and similarly for other "parent" origins besides List) + because `typing` considers Union[int, float] to be equal to Union[float, int]. + + However, you _can_ distinguish between (top-level) Union[int, float] vs. Union[float, int]. + Because we parse items as the first Union type that is successful, we get slightly more consistent behavior + if we make an effort to distinguish the ordering of items in a union. It would be best if we could _always_ + get the exact-correct order of items in the union, but that would require a change to the `typing` module itself. + (See https://github.com/python/cpython/issues/86483 for reference.) + """ + if isinstance(typevar_values, tuple): + args_data = [] + for value in typevar_values: + args_data.append(_union_orderings_key(value)) + return tuple(args_data) + elif typing_extensions.get_origin(typevar_values) is typing.Union: + return get_args(typevar_values) + else: + return () + + +def _early_cache_key(cls: type[BaseModel], typevar_values: Any) -> GenericTypesCacheKey: + """This is intended for minimal computational overhead during lookups of cached types. + + Note that this is overly simplistic, and it's possible that two different cls/typevar_values + inputs would ultimately result in the same type being created in BaseModel.__class_getitem__. + To handle this, we have a fallback _late_cache_key that is checked later if the _early_cache_key + lookup fails, and should result in a cache hit _precisely_ when the inputs to __class_getitem__ + would result in the same type. + """ + return cls, typevar_values, _union_orderings_key(typevar_values) + + +def _late_cache_key(origin: type[BaseModel], args: tuple[Any, ...], typevar_values: Any) -> GenericTypesCacheKey: + """This is intended for use later in the process of creating a new type, when we have more information + about the exact args that will be passed. If it turns out that a different set of inputs to + __class_getitem__ resulted in the same inputs to the generic type creation process, we can still + return the cached type, and update the cache with the _early_cache_key as well. + """ + # The _union_orderings_key is placed at the start here to ensure there cannot be a collision with an + # _early_cache_key, as that function will always produce a BaseModel subclass as the first item in the key, + # whereas this function will always produce a tuple as the first item in the key. 
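+ # Illustrative shapes, assuming Model[int] with typevar_values == (int,): the early key looks like
+ # (Model, (int,), ((),)) while this late key looks like (((),), Model, (int,)), so the two can never collide.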
+ return _union_orderings_key(typevar_values), origin, args diff --git a/lib/pydantic/_internal/_git.py b/lib/pydantic/_internal/_git.py new file mode 100644 index 00000000..9de7aaf9 --- /dev/null +++ b/lib/pydantic/_internal/_git.py @@ -0,0 +1,26 @@ +"""Git utilities, adopted from mypy's git utilities (https://github.com/python/mypy/blob/master/mypy/git.py).""" +from __future__ import annotations + +import os +import subprocess + + +def is_git_repo(dir: str) -> bool: + """Is the given directory version-controlled with git?""" + return os.path.exists(os.path.join(dir, '.git')) + + +def have_git() -> bool: + """Can we run the git executable?""" + try: + subprocess.check_output(['git', '--help']) + return True + except subprocess.CalledProcessError: + return False + except OSError: + return False + + +def git_revision(dir: str) -> str: + """Get the SHA-1 of the HEAD of a git repository.""" + return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], cwd=dir).decode('utf-8').strip() diff --git a/lib/pydantic/_internal/_internal_dataclass.py b/lib/pydantic/_internal/_internal_dataclass.py new file mode 100644 index 00000000..317a3d9c --- /dev/null +++ b/lib/pydantic/_internal/_internal_dataclass.py @@ -0,0 +1,10 @@ +import sys +from typing import Any, Dict + +dataclass_kwargs: Dict[str, Any] + +# `slots` is available on Python >= 3.10 +if sys.version_info >= (3, 10): + slots_true = {'slots': True} +else: + slots_true = {} diff --git a/lib/pydantic/_internal/_known_annotated_metadata.py b/lib/pydantic/_internal/_known_annotated_metadata.py new file mode 100644 index 00000000..77caf705 --- /dev/null +++ b/lib/pydantic/_internal/_known_annotated_metadata.py @@ -0,0 +1,410 @@ +from __future__ import annotations + +from collections import defaultdict +from copy import copy +from functools import partial +from typing import TYPE_CHECKING, Any, Callable, Iterable + +from pydantic_core import CoreSchema, PydanticCustomError, to_jsonable_python +from pydantic_core import core_schema as cs + +from ._fields import PydanticMetadata + +if TYPE_CHECKING: + from ..annotated_handlers import GetJsonSchemaHandler + + +STRICT = {'strict'} +SEQUENCE_CONSTRAINTS = {'min_length', 'max_length'} +INEQUALITY = {'le', 'ge', 'lt', 'gt'} +NUMERIC_CONSTRAINTS = {'multiple_of', 'allow_inf_nan', *INEQUALITY} + +STR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT, 'strip_whitespace', 'to_lower', 'to_upper', 'pattern'} +BYTES_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT} + +LIST_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT} +TUPLE_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT} +SET_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT} +DICT_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT} +GENERATOR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT} + +FLOAT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT} +INT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT} +BOOL_CONSTRAINTS = STRICT +UUID_CONSTRAINTS = STRICT + +DATE_TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT} +TIMEDELTA_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT} +TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT} +LAX_OR_STRICT_CONSTRAINTS = STRICT + +UNION_CONSTRAINTS = {'union_mode'} +URL_CONSTRAINTS = { + 'max_length', + 'allowed_schemes', + 'host_required', + 'default_host', + 'default_port', + 'default_path', +} + +TEXT_SCHEMA_TYPES = ('str', 'bytes', 'url', 'multi-host-url') +SEQUENCE_SCHEMA_TYPES = ('list', 'tuple', 'set', 'frozenset', 'generator', *TEXT_SCHEMA_TYPES) +NUMERIC_SCHEMA_TYPES = ('float', 'int', 'date', 'time', 'timedelta', 'datetime') + 
+CONSTRAINTS_TO_ALLOWED_SCHEMAS: dict[str, set[str]] = defaultdict(set) +for constraint in STR_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(TEXT_SCHEMA_TYPES) +for constraint in BYTES_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('bytes',)) +for constraint in LIST_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('list',)) +for constraint in TUPLE_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('tuple',)) +for constraint in SET_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('set', 'frozenset')) +for constraint in DICT_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('dict',)) +for constraint in GENERATOR_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('generator',)) +for constraint in FLOAT_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('float',)) +for constraint in INT_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('int',)) +for constraint in DATE_TIME_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('date', 'time', 'datetime')) +for constraint in TIMEDELTA_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('timedelta',)) +for constraint in TIME_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('time',)) +for schema_type in (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model'): + CONSTRAINTS_TO_ALLOWED_SCHEMAS['strict'].add(schema_type) +for constraint in UNION_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('union',)) +for constraint in URL_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('url', 'multi-host-url')) +for constraint in BOOL_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('bool',)) +for constraint in UUID_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('uuid',)) +for constraint in LAX_OR_STRICT_CONSTRAINTS: + CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('lax-or-strict',)) + + +def add_js_update_schema(s: cs.CoreSchema, f: Callable[[], dict[str, Any]]) -> None: + def update_js_schema(s: cs.CoreSchema, handler: GetJsonSchemaHandler) -> dict[str, Any]: + js_schema = handler(s) + js_schema.update(f()) + return js_schema + + if 'metadata' in s: + metadata = s['metadata'] + if 'pydantic_js_functions' in s: + metadata['pydantic_js_functions'].append(update_js_schema) + else: + metadata['pydantic_js_functions'] = [update_js_schema] + else: + s['metadata'] = {'pydantic_js_functions': [update_js_schema]} + + +def as_jsonable_value(v: Any) -> Any: + if type(v) not in (int, str, float, bytes, bool, type(None)): + return to_jsonable_python(v) + return v + + +def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]: + """Expand the annotations. + + Args: + annotations: An iterable of annotations. + + Returns: + An iterable of expanded annotations. 
+ + Example: + ```py + from annotated_types import Ge, Len + + from pydantic._internal._known_annotated_metadata import expand_grouped_metadata + + print(list(expand_grouped_metadata([Ge(4), Len(5)]))) + #> [Ge(ge=4), MinLen(min_length=5)] + ``` + """ + import annotated_types as at + + from pydantic.fields import FieldInfo # circular import + + for annotation in annotations: + if isinstance(annotation, at.GroupedMetadata): + yield from annotation + elif isinstance(annotation, FieldInfo): + yield from annotation.metadata + # this is a bit problematic in that it results in duplicate metadata + # all of our "consumers" can handle it, but it is not ideal + # we probably should split up FieldInfo into: + # - annotated types metadata + # - individual metadata known only to Pydantic + annotation = copy(annotation) + annotation.metadata = [] + yield annotation + else: + yield annotation + + +def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | None: # noqa: C901 + """Apply `annotation` to `schema` if it is an annotation we know about (Gt, Le, etc.). + Otherwise return `None`. + + This does not handle all known annotations. If / when it does, it can always + return a CoreSchema and return the unmodified schema if the annotation should be ignored. + + Assumes that GroupedMetadata has already been expanded via `expand_grouped_metadata`. + + Args: + annotation: The annotation. + schema: The schema. + + Returns: + An updated schema with annotation if it is an annotation we know about, `None` otherwise. + + Raises: + PydanticCustomError: If `Predicate` fails. + """ + import annotated_types as at + + from . import _validators + + schema = schema.copy() + schema_update, other_metadata = collect_known_metadata([annotation]) + schema_type = schema['type'] + for constraint, value in schema_update.items(): + if constraint not in CONSTRAINTS_TO_ALLOWED_SCHEMAS: + raise ValueError(f'Unknown constraint {constraint}') + allowed_schemas = CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint] + + if schema_type in allowed_schemas: + if constraint == 'union_mode' and schema_type == 'union': + schema['mode'] = value # type: ignore # schema is UnionSchema + else: + schema[constraint] = value + continue + + if constraint == 'allow_inf_nan' and value is False: + return cs.no_info_after_validator_function( + _validators.forbid_inf_nan_check, + schema, + ) + elif constraint == 'pattern': + # insert a str schema to make sure the regex engine matches + return cs.chain_schema( + [ + schema, + cs.str_schema(pattern=value), + ] + ) + elif constraint == 'gt': + s = cs.no_info_after_validator_function( + partial(_validators.greater_than_validator, gt=value), + schema, + ) + add_js_update_schema(s, lambda: {'gt': as_jsonable_value(value)}) + return s + elif constraint == 'ge': + return cs.no_info_after_validator_function( + partial(_validators.greater_than_or_equal_validator, ge=value), + schema, + ) + elif constraint == 'lt': + return cs.no_info_after_validator_function( + partial(_validators.less_than_validator, lt=value), + schema, + ) + elif constraint == 'le': + return cs.no_info_after_validator_function( + partial(_validators.less_than_or_equal_validator, le=value), + schema, + ) + elif constraint == 'multiple_of': + return cs.no_info_after_validator_function( + partial(_validators.multiple_of_validator, multiple_of=value), + schema, + ) + elif constraint == 'min_length': + s = cs.no_info_after_validator_function( + partial(_validators.min_length_validator, min_length=value), + schema, + ) + 
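+ # the wrapping validator function hides the constraint from JSON schema generation, so mirror it as 'minLength' here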
add_js_update_schema(s, lambda: {'minLength': (as_jsonable_value(value))}) + return s + elif constraint == 'max_length': + s = cs.no_info_after_validator_function( + partial(_validators.max_length_validator, max_length=value), + schema, + ) + add_js_update_schema(s, lambda: {'maxLength': (as_jsonable_value(value))}) + return s + elif constraint == 'strip_whitespace': + return cs.chain_schema( + [ + schema, + cs.str_schema(strip_whitespace=True), + ] + ) + elif constraint == 'to_lower': + return cs.chain_schema( + [ + schema, + cs.str_schema(to_lower=True), + ] + ) + elif constraint == 'to_upper': + return cs.chain_schema( + [ + schema, + cs.str_schema(to_upper=True), + ] + ) + elif constraint == 'min_length': + return cs.no_info_after_validator_function( + partial(_validators.min_length_validator, min_length=annotation.min_length), + schema, + ) + elif constraint == 'max_length': + return cs.no_info_after_validator_function( + partial(_validators.max_length_validator, max_length=annotation.max_length), + schema, + ) + else: + raise RuntimeError(f'Unable to apply constraint {constraint} to schema {schema_type}') + + for annotation in other_metadata: + if isinstance(annotation, at.Gt): + return cs.no_info_after_validator_function( + partial(_validators.greater_than_validator, gt=annotation.gt), + schema, + ) + elif isinstance(annotation, at.Ge): + return cs.no_info_after_validator_function( + partial(_validators.greater_than_or_equal_validator, ge=annotation.ge), + schema, + ) + elif isinstance(annotation, at.Lt): + return cs.no_info_after_validator_function( + partial(_validators.less_than_validator, lt=annotation.lt), + schema, + ) + elif isinstance(annotation, at.Le): + return cs.no_info_after_validator_function( + partial(_validators.less_than_or_equal_validator, le=annotation.le), + schema, + ) + elif isinstance(annotation, at.MultipleOf): + return cs.no_info_after_validator_function( + partial(_validators.multiple_of_validator, multiple_of=annotation.multiple_of), + schema, + ) + elif isinstance(annotation, at.MinLen): + return cs.no_info_after_validator_function( + partial(_validators.min_length_validator, min_length=annotation.min_length), + schema, + ) + elif isinstance(annotation, at.MaxLen): + return cs.no_info_after_validator_function( + partial(_validators.max_length_validator, max_length=annotation.max_length), + schema, + ) + elif isinstance(annotation, at.Predicate): + predicate_name = f'{annotation.func.__qualname__} ' if hasattr(annotation.func, '__qualname__') else '' + + def val_func(v: Any) -> Any: + # annotation.func may also raise an exception, let it pass through + if not annotation.func(v): + raise PydanticCustomError( + 'predicate_failed', + f'Predicate {predicate_name}failed', # type: ignore + ) + return v + + return cs.no_info_after_validator_function(val_func, schema) + # ignore any other unknown metadata + return None + + return schema + + +def collect_known_metadata(annotations: Iterable[Any]) -> tuple[dict[str, Any], list[Any]]: + """Split `annotations` into known metadata and unknown annotations. + + Args: + annotations: An iterable of annotations. + + Returns: + A tuple contains a dict of known metadata and a list of unknown annotations. 
+ + Example: + ```py + from annotated_types import Gt, Len + + from pydantic._internal._known_annotated_metadata import collect_known_metadata + + print(collect_known_metadata([Gt(1), Len(42), ...])) + #> ({'gt': 1, 'min_length': 42}, [Ellipsis]) + ``` + """ + import annotated_types as at + + annotations = expand_grouped_metadata(annotations) + + res: dict[str, Any] = {} + remaining: list[Any] = [] + for annotation in annotations: + # isinstance(annotation, PydanticMetadata) also covers ._fields:_PydanticGeneralMetadata + if isinstance(annotation, PydanticMetadata): + res.update(annotation.__dict__) + # we don't use dataclasses.asdict because that recursively calls asdict on the field values + elif isinstance(annotation, at.MinLen): + res.update({'min_length': annotation.min_length}) + elif isinstance(annotation, at.MaxLen): + res.update({'max_length': annotation.max_length}) + elif isinstance(annotation, at.Gt): + res.update({'gt': annotation.gt}) + elif isinstance(annotation, at.Ge): + res.update({'ge': annotation.ge}) + elif isinstance(annotation, at.Lt): + res.update({'lt': annotation.lt}) + elif isinstance(annotation, at.Le): + res.update({'le': annotation.le}) + elif isinstance(annotation, at.MultipleOf): + res.update({'multiple_of': annotation.multiple_of}) + elif isinstance(annotation, type) and issubclass(annotation, PydanticMetadata): + # also support PydanticMetadata classes being used without initialisation, + # e.g. `Annotated[int, Strict]` as well as `Annotated[int, Strict()]` + res.update({k: v for k, v in vars(annotation).items() if not k.startswith('_')}) + else: + remaining.append(annotation) + # Nones can sneak in but pydantic-core will reject them + # it'd be nice to clean things up so we don't put in None (we probably don't _need_ to, it was just easier) + # but this is simple enough to kick that can down the road + res = {k: v for k, v in res.items() if v is not None} + return res, remaining + + +def check_metadata(metadata: dict[str, Any], allowed: Iterable[str], source_type: Any) -> None: + """A small utility function to validate that the given metadata can be applied to the target. + More than saving lines of code, this gives us a consistent error message for all of our internal implementations. + + Args: + metadata: A dict of metadata. + allowed: An iterable of allowed metadata. + source_type: The source type. + + Raises: + TypeError: If there is metadatas that can't be applied on source type. 
+ """ + unknown = metadata.keys() - set(allowed) + if unknown: + raise TypeError( + f'The following constraints cannot be applied to {source_type!r}: {", ".join([f"{k!r}" for k in unknown])}' + ) diff --git a/lib/pydantic/_internal/_mock_val_ser.py b/lib/pydantic/_internal/_mock_val_ser.py new file mode 100644 index 00000000..b303fed2 --- /dev/null +++ b/lib/pydantic/_internal/_mock_val_ser.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Generic, TypeVar + +from pydantic_core import SchemaSerializer, SchemaValidator +from typing_extensions import Literal + +from ..errors import PydanticErrorCodes, PydanticUserError + +if TYPE_CHECKING: + from ..dataclasses import PydanticDataclass + from ..main import BaseModel + + +ValSer = TypeVar('ValSer', SchemaValidator, SchemaSerializer) + + +class MockValSer(Generic[ValSer]): + """Mocker for `pydantic_core.SchemaValidator` or `pydantic_core.SchemaSerializer` which optionally attempts to + rebuild the thing it's mocking when one of its methods is accessed and raises an error if that fails. + """ + + __slots__ = '_error_message', '_code', '_val_or_ser', '_attempt_rebuild' + + def __init__( + self, + error_message: str, + *, + code: PydanticErrorCodes, + val_or_ser: Literal['validator', 'serializer'], + attempt_rebuild: Callable[[], ValSer | None] | None = None, + ) -> None: + self._error_message = error_message + self._val_or_ser = SchemaValidator if val_or_ser == 'validator' else SchemaSerializer + self._code: PydanticErrorCodes = code + self._attempt_rebuild = attempt_rebuild + + def __getattr__(self, item: str) -> None: + __tracebackhide__ = True + if self._attempt_rebuild: + val_ser = self._attempt_rebuild() + if val_ser is not None: + return getattr(val_ser, item) + + # raise an AttributeError if `item` doesn't exist + getattr(self._val_or_ser, item) + raise PydanticUserError(self._error_message, code=self._code) + + def rebuild(self) -> ValSer | None: + if self._attempt_rebuild: + val_ser = self._attempt_rebuild() + if val_ser is not None: + return val_ser + else: + raise PydanticUserError(self._error_message, code=self._code) + return None + + +def set_model_mocks(cls: type[BaseModel], cls_name: str, undefined_name: str = 'all referenced types') -> None: + """Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a model. + + Args: + cls: The model class to set the mocks on + cls_name: Name of the model class, used in error messages + undefined_name: Name of the undefined thing, used in error messages + """ + undefined_type_error_message = ( + f'`{cls_name}` is not fully defined; you should define {undefined_name},' + f' then call `{cls_name}.model_rebuild()`.' 
+ ) + + def attempt_rebuild_validator() -> SchemaValidator | None: + if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5) is not False: + return cls.__pydantic_validator__ + else: + return None + + cls.__pydantic_validator__ = MockValSer( # type: ignore[assignment] + undefined_type_error_message, + code='class-not-fully-defined', + val_or_ser='validator', + attempt_rebuild=attempt_rebuild_validator, + ) + + def attempt_rebuild_serializer() -> SchemaSerializer | None: + if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5) is not False: + return cls.__pydantic_serializer__ + else: + return None + + cls.__pydantic_serializer__ = MockValSer( # type: ignore[assignment] + undefined_type_error_message, + code='class-not-fully-defined', + val_or_ser='serializer', + attempt_rebuild=attempt_rebuild_serializer, + ) + + +def set_dataclass_mocks( + cls: type[PydanticDataclass], cls_name: str, undefined_name: str = 'all referenced types' +) -> None: + """Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a dataclass. + + Args: + cls: The model class to set the mocks on + cls_name: Name of the model class, used in error messages + undefined_name: Name of the undefined thing, used in error messages + """ + from ..dataclasses import rebuild_dataclass + + undefined_type_error_message = ( + f'`{cls_name}` is not fully defined; you should define {undefined_name},' + f' then call `pydantic.dataclasses.rebuild_dataclass({cls_name})`.' + ) + + def attempt_rebuild_validator() -> SchemaValidator | None: + if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5) is not False: + return cls.__pydantic_validator__ + else: + return None + + cls.__pydantic_validator__ = MockValSer( # type: ignore[assignment] + undefined_type_error_message, + code='class-not-fully-defined', + val_or_ser='validator', + attempt_rebuild=attempt_rebuild_validator, + ) + + def attempt_rebuild_serializer() -> SchemaSerializer | None: + if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5) is not False: + return cls.__pydantic_serializer__ + else: + return None + + cls.__pydantic_serializer__ = MockValSer( # type: ignore[assignment] + undefined_type_error_message, + code='class-not-fully-defined', + val_or_ser='validator', + attempt_rebuild=attempt_rebuild_serializer, + ) diff --git a/lib/pydantic/_internal/_model_construction.py b/lib/pydantic/_internal/_model_construction.py new file mode 100644 index 00000000..543f73e9 --- /dev/null +++ b/lib/pydantic/_internal/_model_construction.py @@ -0,0 +1,637 @@ +"""Private logic for creating models.""" +from __future__ import annotations as _annotations + +import operator +import typing +import warnings +import weakref +from abc import ABCMeta +from functools import partial +from types import FunctionType +from typing import Any, Callable, Generic + +import typing_extensions +from pydantic_core import PydanticUndefined, SchemaSerializer +from typing_extensions import dataclass_transform, deprecated + +from ..errors import PydanticUndefinedAnnotation, PydanticUserError +from ..plugin._schema_validator import create_schema_validator +from ..warnings import GenericBeforeBaseModelWarning, PydanticDeprecatedSince20 +from ._config import ConfigWrapper +from ._decorators import DecoratorInfos, PydanticDescriptorProxy, get_attribute_from_bases +from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name +from ._generate_schema import GenerateSchema +from ._generics import 
PydanticGenericMetadata, get_model_typevars_map +from ._mock_val_ser import MockValSer, set_model_mocks +from ._schema_generation_shared import CallbackGetCoreSchemaHandler +from ._signature import generate_pydantic_signature +from ._typing_extra import get_cls_types_namespace, is_annotated, is_classvar, parent_frame_namespace +from ._utils import ClassAttribute, SafeGetItemProxy +from ._validate_call import ValidateCallWrapper + +if typing.TYPE_CHECKING: + from ..fields import Field as PydanticModelField + from ..fields import FieldInfo, ModelPrivateAttr + from ..main import BaseModel +else: + # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915 + # and https://youtrack.jetbrains.com/issue/PY-51428 + DeprecationWarning = PydanticDeprecatedSince20 + PydanticModelField = object() + +object_setattr = object.__setattr__ + + +class _ModelNamespaceDict(dict): + """A dictionary subclass that intercepts attribute setting on model classes and + warns about overriding of decorators. + """ + + def __setitem__(self, k: str, v: object) -> None: + existing: Any = self.get(k, None) + if existing and v is not existing and isinstance(existing, PydanticDescriptorProxy): + warnings.warn(f'`{k}` overrides an existing Pydantic `{existing.decorator_info.decorator_repr}` decorator') + + return super().__setitem__(k, v) + + +@dataclass_transform(kw_only_default=True, field_specifiers=(PydanticModelField,)) +class ModelMetaclass(ABCMeta): + def __new__( + mcs, + cls_name: str, + bases: tuple[type[Any], ...], + namespace: dict[str, Any], + __pydantic_generic_metadata__: PydanticGenericMetadata | None = None, + __pydantic_reset_parent_namespace__: bool = True, + _create_model_module: str | None = None, + **kwargs: Any, + ) -> type: + """Metaclass for creating Pydantic models. + + Args: + cls_name: The name of the class to be created. + bases: The base classes of the class to be created. + namespace: The attribute dictionary of the class to be created. + __pydantic_generic_metadata__: Metadata for generic models. + __pydantic_reset_parent_namespace__: Reset parent namespace. + _create_model_module: The module of the class to be created, if created by `create_model`. + **kwargs: Catch-all for any other keyword arguments. + + Returns: + The new class created by the metaclass. + """ + # Note `ModelMetaclass` refers to `BaseModel`, but is also used to *create* `BaseModel`, so we rely on the fact + # that `BaseModel` itself won't have any bases, but any subclass of it will, to determine whether the `__new__` + # call we're in the middle of is for the `BaseModel` class. + if bases: + base_field_names, class_vars, base_private_attributes = mcs._collect_bases_data(bases) + + config_wrapper = ConfigWrapper.for_model(bases, namespace, kwargs) + namespace['model_config'] = config_wrapper.config_dict + private_attributes = inspect_namespace( + namespace, config_wrapper.ignored_types, class_vars, base_field_names + ) + if private_attributes: + original_model_post_init = get_model_post_init(namespace, bases) + if original_model_post_init is not None: + # if there are private_attributes and a model_post_init function, we handle both + + def wrapped_model_post_init(self: BaseModel, __context: Any) -> None: + """We need to both initialize private attributes and call the user-defined model_post_init + method. 
+ """ + init_private_attributes(self, __context) + original_model_post_init(self, __context) + + namespace['model_post_init'] = wrapped_model_post_init + else: + namespace['model_post_init'] = init_private_attributes + + namespace['__class_vars__'] = class_vars + namespace['__private_attributes__'] = {**base_private_attributes, **private_attributes} + + cls: type[BaseModel] = super().__new__(mcs, cls_name, bases, namespace, **kwargs) # type: ignore + + from ..main import BaseModel + + mro = cls.__mro__ + if Generic in mro and mro.index(Generic) < mro.index(BaseModel): + warnings.warn( + GenericBeforeBaseModelWarning( + 'Classes should inherit from `BaseModel` before generic classes (e.g. `typing.Generic[T]`) ' + 'for pydantic generics to work properly.' + ), + stacklevel=2, + ) + + cls.__pydantic_custom_init__ = not getattr(cls.__init__, '__pydantic_base_init__', False) + cls.__pydantic_post_init__ = None if cls.model_post_init is BaseModel.model_post_init else 'model_post_init' + + cls.__pydantic_decorators__ = DecoratorInfos.build(cls) + + # Use the getattr below to grab the __parameters__ from the `typing.Generic` parent class + if __pydantic_generic_metadata__: + cls.__pydantic_generic_metadata__ = __pydantic_generic_metadata__ + else: + parent_parameters = getattr(cls, '__pydantic_generic_metadata__', {}).get('parameters', ()) + parameters = getattr(cls, '__parameters__', None) or parent_parameters + if parameters and parent_parameters and not all(x in parameters for x in parent_parameters): + combined_parameters = parent_parameters + tuple(x for x in parameters if x not in parent_parameters) + parameters_str = ', '.join([str(x) for x in combined_parameters]) + generic_type_label = f'typing.Generic[{parameters_str}]' + error_message = ( + f'All parameters must be present on typing.Generic;' + f' you should inherit from {generic_type_label}.' + ) + if Generic not in bases: # pragma: no cover + # We raise an error here not because it is desirable, but because some cases are mishandled. + # It would be nice to remove this error and still have things behave as expected, it's just + # challenging because we are using a custom `__class_getitem__` to parametrize generic models, + # and not returning a typing._GenericAlias from it. + bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label]) + error_message += ( + f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)' + ) + raise TypeError(error_message) + + cls.__pydantic_generic_metadata__ = { + 'origin': None, + 'args': (), + 'parameters': parameters, + } + + cls.__pydantic_complete__ = False # Ensure this specific class gets completed + + # preserve `__set_name__` protocol defined in https://peps.python.org/pep-0487 + # for attributes not in `new_namespace` (e.g. 
private attributes) + for name, obj in private_attributes.items(): + obj.__set_name__(cls, name) + + if __pydantic_reset_parent_namespace__: + cls.__pydantic_parent_namespace__ = build_lenient_weakvaluedict(parent_frame_namespace()) + parent_namespace = getattr(cls, '__pydantic_parent_namespace__', None) + if isinstance(parent_namespace, dict): + parent_namespace = unpack_lenient_weakvaluedict(parent_namespace) + + types_namespace = get_cls_types_namespace(cls, parent_namespace) + set_model_fields(cls, bases, config_wrapper, types_namespace) + + if config_wrapper.frozen and '__hash__' not in namespace: + set_default_hash_func(cls, bases) + + complete_model_class( + cls, + cls_name, + config_wrapper, + raise_errors=False, + types_namespace=types_namespace, + create_model_module=_create_model_module, + ) + + # If this is placed before the complete_model_class call above, + # the generic computed fields return type is set to PydanticUndefined + cls.model_computed_fields = {k: v.info for k, v in cls.__pydantic_decorators__.computed_fields.items()} + + # using super(cls, cls) on the next line ensures we only call the parent class's __pydantic_init_subclass__ + # I believe the `type: ignore` is only necessary because mypy doesn't realize that this code branch is + # only hit for _proper_ subclasses of BaseModel + super(cls, cls).__pydantic_init_subclass__(**kwargs) # type: ignore[misc] + return cls + else: + # this is the BaseModel class itself being created, no logic required + return super().__new__(mcs, cls_name, bases, namespace, **kwargs) + + if not typing.TYPE_CHECKING: # pragma: no branch + # We put `__getattr__` in a non-TYPE_CHECKING block because otherwise, mypy allows arbitrary attribute access + + def __getattr__(self, item: str) -> Any: + """This is necessary to keep attribute access working for class attribute access.""" + private_attributes = self.__dict__.get('__private_attributes__') + if private_attributes and item in private_attributes: + return private_attributes[item] + if item == '__pydantic_core_schema__': + # This means the class didn't get a schema generated for it, likely because there was an undefined reference + maybe_mock_validator = getattr(self, '__pydantic_validator__', None) + if isinstance(maybe_mock_validator, MockValSer): + rebuilt_validator = maybe_mock_validator.rebuild() + if rebuilt_validator is not None: + # In this case, a validator was built, and so `__pydantic_core_schema__` should now be set + return getattr(self, '__pydantic_core_schema__') + raise AttributeError(item) + + @classmethod + def __prepare__(cls, *args: Any, **kwargs: Any) -> dict[str, object]: + return _ModelNamespaceDict() + + def __instancecheck__(self, instance: Any) -> bool: + """Avoid calling ABC _abc_subclasscheck unless we're pretty sure. 
+ + See #3829 and python/cpython#92810 + """ + return hasattr(instance, '__pydantic_validator__') and super().__instancecheck__(instance) + + @staticmethod + def _collect_bases_data(bases: tuple[type[Any], ...]) -> tuple[set[str], set[str], dict[str, ModelPrivateAttr]]: + from ..main import BaseModel + + field_names: set[str] = set() + class_vars: set[str] = set() + private_attributes: dict[str, ModelPrivateAttr] = {} + for base in bases: + if issubclass(base, BaseModel) and base is not BaseModel: + # model_fields might not be defined yet in the case of generics, so we use getattr here: + field_names.update(getattr(base, 'model_fields', {}).keys()) + class_vars.update(base.__class_vars__) + private_attributes.update(base.__private_attributes__) + return field_names, class_vars, private_attributes + + @property + @deprecated('The `__fields__` attribute is deprecated, use `model_fields` instead.', category=None) + def __fields__(self) -> dict[str, FieldInfo]: + warnings.warn( + 'The `__fields__` attribute is deprecated, use `model_fields` instead.', PydanticDeprecatedSince20 + ) + return self.model_fields # type: ignore + + def __dir__(self) -> list[str]: + attributes = list(super().__dir__()) + if '__fields__' in attributes: + attributes.remove('__fields__') + return attributes + + +def init_private_attributes(self: BaseModel, __context: Any) -> None: + """This function is meant to behave like a BaseModel method to initialise private attributes. + + It takes context as an argument since that's what pydantic-core passes when calling it. + + Args: + self: The BaseModel instance. + __context: The context. + """ + if getattr(self, '__pydantic_private__', None) is None: + pydantic_private = {} + for name, private_attr in self.__private_attributes__.items(): + default = private_attr.get_default() + if default is not PydanticUndefined: + pydantic_private[name] = default + object_setattr(self, '__pydantic_private__', pydantic_private) + + +def get_model_post_init(namespace: dict[str, Any], bases: tuple[type[Any], ...]) -> Callable[..., Any] | None: + """Get the `model_post_init` method from the namespace or the class bases, or `None` if not defined.""" + if 'model_post_init' in namespace: + return namespace['model_post_init'] + + from ..main import BaseModel + + model_post_init = get_attribute_from_bases(bases, 'model_post_init') + if model_post_init is not BaseModel.model_post_init: + return model_post_init + + +def inspect_namespace( # noqa C901 + namespace: dict[str, Any], + ignored_types: tuple[type[Any], ...], + base_class_vars: set[str], + base_class_fields: set[str], +) -> dict[str, ModelPrivateAttr]: + """Iterate over the namespace and: + * gather private attributes + * check for items which look like fields but are not (e.g. have no annotation) and warn. + + Args: + namespace: The attribute dictionary of the class to be created. + ignored_types: A tuple of ignored types. + base_class_vars: A set of base class class variables. + base_class_fields: A set of base class fields. + + Returns: + A dict containing private attributes info. + + Raises: + TypeError: If there is a `__root__` field in the model. + NameError: If a private attribute name is invalid. + PydanticUserError: + - If a field does not have a type annotation. + - If a field on a base class was overridden by a non-annotated attribute. 
+ """ + from ..fields import FieldInfo, ModelPrivateAttr, PrivateAttr + + all_ignored_types = ignored_types + default_ignored_types() + + private_attributes: dict[str, ModelPrivateAttr] = {} + raw_annotations = namespace.get('__annotations__', {}) + + if '__root__' in raw_annotations or '__root__' in namespace: + raise TypeError("To define root models, use `pydantic.RootModel` rather than a field called '__root__'") + + ignored_names: set[str] = set() + for var_name, value in list(namespace.items()): + if var_name == 'model_config': + continue + elif ( + isinstance(value, type) + and value.__module__ == namespace['__module__'] + and value.__qualname__.startswith(namespace['__qualname__']) + ): + # `value` is a nested type defined in this namespace; don't error + continue + elif isinstance(value, all_ignored_types) or value.__class__.__module__ == 'functools': + ignored_names.add(var_name) + continue + elif isinstance(value, ModelPrivateAttr): + if var_name.startswith('__'): + raise NameError( + 'Private attributes must not use dunder names;' + f' use a single underscore prefix instead of {var_name!r}.' + ) + elif is_valid_field_name(var_name): + raise NameError( + 'Private attributes must not use valid field names;' + f' use sunder names, e.g. {"_" + var_name!r} instead of {var_name!r}.' + ) + private_attributes[var_name] = value + del namespace[var_name] + elif isinstance(value, FieldInfo) and not is_valid_field_name(var_name): + suggested_name = var_name.lstrip('_') or 'my_field' # don't suggest '' for all-underscore name + raise NameError( + f'Fields must not use names with leading underscores;' + f' e.g., use {suggested_name!r} instead of {var_name!r}.' + ) + + elif var_name.startswith('__'): + continue + elif is_valid_privateattr_name(var_name): + if var_name not in raw_annotations or not is_classvar(raw_annotations[var_name]): + private_attributes[var_name] = PrivateAttr(default=value) + del namespace[var_name] + elif var_name in base_class_vars: + continue + elif var_name not in raw_annotations: + if var_name in base_class_fields: + raise PydanticUserError( + f'Field {var_name!r} defined on a base class was overridden by a non-annotated attribute. ' + f'All field definitions, including overrides, require a type annotation.', + code='model-field-overridden', + ) + elif isinstance(value, FieldInfo): + raise PydanticUserError( + f'Field {var_name!r} requires a type annotation', code='model-field-missing-annotation' + ) + else: + raise PydanticUserError( + f'A non-annotated attribute was detected: `{var_name} = {value!r}`. 
All model fields require a ' + f'type annotation; if `{var_name}` is not meant to be a field, you may be able to resolve this ' + f"error by annotating it as a `ClassVar` or updating `model_config['ignored_types']`.", + code='model-field-missing-annotation', + ) + + for ann_name, ann_type in raw_annotations.items(): + if ( + is_valid_privateattr_name(ann_name) + and ann_name not in private_attributes + and ann_name not in ignored_names + and not is_classvar(ann_type) + and ann_type not in all_ignored_types + and getattr(ann_type, '__module__', None) != 'functools' + ): + if is_annotated(ann_type): + _, *metadata = typing_extensions.get_args(ann_type) + private_attr = next((v for v in metadata if isinstance(v, ModelPrivateAttr)), None) + if private_attr is not None: + private_attributes[ann_name] = private_attr + continue + private_attributes[ann_name] = PrivateAttr() + + return private_attributes + + +def set_default_hash_func(cls: type[BaseModel], bases: tuple[type[Any], ...]) -> None: + base_hash_func = get_attribute_from_bases(bases, '__hash__') + new_hash_func = make_hash_func(cls) + if base_hash_func in {None, object.__hash__} or getattr(base_hash_func, '__code__', None) == new_hash_func.__code__: + # If `__hash__` is some default, we generate a hash function. + # It will be `None` if not overridden from BaseModel. + # It may be `object.__hash__` if there is another + # parent class earlier in the bases which doesn't override `__hash__` (e.g. `typing.Generic`). + # It may be a value set by `set_default_hash_func` if `cls` is a subclass of another frozen model. + # In the last case we still need a new hash function to account for new `model_fields`. + cls.__hash__ = new_hash_func + + +def make_hash_func(cls: type[BaseModel]) -> Any: + getter = operator.itemgetter(*cls.model_fields.keys()) if cls.model_fields else lambda _: 0 + + def hash_func(self: Any) -> int: + try: + return hash(getter(self.__dict__)) + except KeyError: + # In rare cases (such as when using the deprecated copy method), the __dict__ may not contain + # all model fields, which is how we can get here. + # getter(self.__dict__) is much faster than any 'safe' method that accounts for missing keys, + # and wrapping it in a `try` doesn't slow things down much in the common case. + return hash(getter(SafeGetItemProxy(self.__dict__))) + + return hash_func + + +def set_model_fields( + cls: type[BaseModel], bases: tuple[type[Any], ...], config_wrapper: ConfigWrapper, types_namespace: dict[str, Any] +) -> None: + """Collect and set `cls.model_fields` and `cls.__class_vars__`. + + Args: + cls: BaseModel or dataclass. + bases: Parents of the class, generally `cls.__bases__`. + config_wrapper: The config wrapper instance. + types_namespace: Optional extra namespace to look for types in. + """ + typevars_map = get_model_typevars_map(cls) + fields, class_vars = collect_model_fields(cls, bases, config_wrapper, types_namespace, typevars_map=typevars_map) + + cls.model_fields = fields + cls.__class_vars__.update(class_vars) + + for k in class_vars: + # Class vars should not be private attributes + # We remove them _here_ and not earlier because we rely on inspecting the class to determine its classvars, + # but private attributes are determined by inspecting the namespace _prior_ to class creation. 
+ # In the case that a classvar with a leading-'_' is defined via a ForwardRef (e.g., when using + # `__future__.annotations`), we want to remove the private attribute which was detected _before_ we knew it + # evaluated to a classvar + + value = cls.__private_attributes__.pop(k, None) + if value is not None and value.default is not PydanticUndefined: + setattr(cls, k, value.default) + + +def complete_model_class( + cls: type[BaseModel], + cls_name: str, + config_wrapper: ConfigWrapper, + *, + raise_errors: bool = True, + types_namespace: dict[str, Any] | None, + create_model_module: str | None = None, +) -> bool: + """Finish building a model class. + + This logic must be called after class has been created since validation functions must be bound + and `get_type_hints` requires a class object. + + Args: + cls: BaseModel or dataclass. + cls_name: The model or dataclass name. + config_wrapper: The config wrapper instance. + raise_errors: Whether to raise errors. + types_namespace: Optional extra namespace to look for types in. + create_model_module: The module of the class to be created, if created by `create_model`. + + Returns: + `True` if the model is successfully completed, else `False`. + + Raises: + PydanticUndefinedAnnotation: If `PydanticUndefinedAnnotation` occurs in`__get_pydantic_core_schema__` + and `raise_errors=True`. + """ + typevars_map = get_model_typevars_map(cls) + gen_schema = GenerateSchema( + config_wrapper, + types_namespace, + typevars_map, + ) + + handler = CallbackGetCoreSchemaHandler( + partial(gen_schema.generate_schema, from_dunder_get_core_schema=False), + gen_schema, + ref_mode='unpack', + ) + + if config_wrapper.defer_build: + set_model_mocks(cls, cls_name) + return False + + try: + schema = cls.__get_pydantic_core_schema__(cls, handler) + except PydanticUndefinedAnnotation as e: + if raise_errors: + raise + set_model_mocks(cls, cls_name, f'`{e.name}`') + return False + + core_config = config_wrapper.core_config(cls) + + try: + schema = gen_schema.clean_schema(schema) + except gen_schema.CollectedInvalid: + set_model_mocks(cls, cls_name) + return False + + # debug(schema) + cls.__pydantic_core_schema__ = schema + + cls.__pydantic_validator__ = create_schema_validator( + schema, + cls, + create_model_module or cls.__module__, + cls.__qualname__, + 'create_model' if create_model_module else 'BaseModel', + core_config, + config_wrapper.plugin_settings, + ) + cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config) + cls.__pydantic_complete__ = True + + # set __signature__ attr only for model class, but not for its instances + cls.__signature__ = ClassAttribute( + '__signature__', + generate_pydantic_signature(init=cls.__init__, fields=cls.model_fields, config_wrapper=config_wrapper), + ) + return True + + +class _PydanticWeakRef: + """Wrapper for `weakref.ref` that enables `pickle` serialization. + + Cloudpickle fails to serialize `weakref.ref` objects due to an arcane error related + to abstract base classes (`abc.ABC`). This class works around the issue by wrapping + `weakref.ref` instead of subclassing it. + + See https://github.com/pydantic/pydantic/issues/6763 for context. + + Semantics: + - If not pickled, behaves the same as a `weakref.ref`. + - If pickled along with the referenced object, the same `weakref.ref` behavior + will be maintained between them after unpickling. + - If pickled without the referenced object, after unpickling the underlying + reference will be cleared (`__call__` will always return `None`). 
+ """ + + def __init__(self, obj: Any): + if obj is None: + # The object will be `None` upon deserialization if the serialized weakref + # had lost its underlying object. + self._wr = None + else: + self._wr = weakref.ref(obj) + + def __call__(self) -> Any: + if self._wr is None: + return None + else: + return self._wr() + + def __reduce__(self) -> tuple[Callable, tuple[weakref.ReferenceType | None]]: + return _PydanticWeakRef, (self(),) + + +def build_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | None: + """Takes an input dictionary, and produces a new value that (invertibly) replaces the values with weakrefs. + + We can't just use a WeakValueDictionary because many types (including int, str, etc.) can't be stored as values + in a WeakValueDictionary. + + The `unpack_lenient_weakvaluedict` function can be used to reverse this operation. + """ + if d is None: + return None + result = {} + for k, v in d.items(): + try: + proxy = _PydanticWeakRef(v) + except TypeError: + proxy = v + result[k] = proxy + return result + + +def unpack_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | None: + """Inverts the transform performed by `build_lenient_weakvaluedict`.""" + if d is None: + return None + + result = {} + for k, v in d.items(): + if isinstance(v, _PydanticWeakRef): + v = v() + if v is not None: + result[k] = v + else: + result[k] = v + return result + + +def default_ignored_types() -> tuple[type[Any], ...]: + from ..fields import ComputedFieldInfo + + return ( + FunctionType, + property, + classmethod, + staticmethod, + PydanticDescriptorProxy, + ComputedFieldInfo, + ValidateCallWrapper, + ) diff --git a/lib/pydantic/_internal/_repr.py b/lib/pydantic/_internal/_repr.py new file mode 100644 index 00000000..479b4479 --- /dev/null +++ b/lib/pydantic/_internal/_repr.py @@ -0,0 +1,117 @@ +"""Tools to provide pretty/human-readable display of objects.""" +from __future__ import annotations as _annotations + +import types +import typing +from typing import Any + +import typing_extensions + +from . import _typing_extra + +if typing.TYPE_CHECKING: + ReprArgs: typing_extensions.TypeAlias = 'typing.Iterable[tuple[str | None, Any]]' + RichReprResult: typing_extensions.TypeAlias = ( + 'typing.Iterable[Any | tuple[Any] | tuple[str, Any] | tuple[str, Any, Any]]' + ) + + +class PlainRepr(str): + """String class where repr doesn't include quotes. Useful with Representation when you want to return a string + representation of something that is valid (or pseudo-valid) python. + """ + + def __repr__(self) -> str: + return str(self) + + +class Representation: + # Mixin to provide `__str__`, `__repr__`, and `__pretty__` and `__rich_repr__` methods. + # `__pretty__` is used by [devtools](https://python-devtools.helpmanual.io/). + # `__rich_repr__` is used by [rich](https://rich.readthedocs.io/en/stable/pretty.html). + # (this is not a docstring to avoid adding a docstring to classes which inherit from Representation) + + # we don't want to use a type annotation here as it can break get_type_hints + __slots__ = tuple() # type: typing.Collection[str] + + def __repr_args__(self) -> ReprArgs: + """Returns the attributes to show in __str__, __repr__, and __pretty__ this is generally overridden. 
+ + Can either return: + * name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]` + * or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]` + """ + attrs_names = self.__slots__ + if not attrs_names and hasattr(self, '__dict__'): + attrs_names = self.__dict__.keys() + attrs = ((s, getattr(self, s)) for s in attrs_names) + return [(a, v) for a, v in attrs if v is not None] + + def __repr_name__(self) -> str: + """Name of the instance's class, used in __repr__.""" + return self.__class__.__name__ + + def __repr_str__(self, join_str: str) -> str: + return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__()) + + def __pretty__(self, fmt: typing.Callable[[Any], Any], **kwargs: Any) -> typing.Generator[Any, None, None]: + """Used by devtools (https://python-devtools.helpmanual.io/) to pretty print objects.""" + yield self.__repr_name__() + '(' + yield 1 + for name, value in self.__repr_args__(): + if name is not None: + yield name + '=' + yield fmt(value) + yield ',' + yield 0 + yield -1 + yield ')' + + def __rich_repr__(self) -> RichReprResult: + """Used by Rich (https://rich.readthedocs.io/en/stable/pretty.html) to pretty print objects.""" + for name, field_repr in self.__repr_args__(): + if name is None: + yield field_repr + else: + yield name, field_repr + + def __str__(self) -> str: + return self.__repr_str__(' ') + + def __repr__(self) -> str: + return f'{self.__repr_name__()}({self.__repr_str__(", ")})' + + +def display_as_type(obj: Any) -> str: + """Pretty representation of a type, should be as close as possible to the original type definition string. + + Takes some logic from `typing._type_repr`. + """ + if isinstance(obj, types.FunctionType): + return obj.__name__ + elif obj is ...: + return '...' 
+ elif isinstance(obj, Representation): + return repr(obj) + elif isinstance(obj, typing_extensions.TypeAliasType): + return str(obj) + + if not isinstance(obj, (_typing_extra.typing_base, _typing_extra.WithArgsTypes, type)): + obj = obj.__class__ + + if _typing_extra.origin_is_union(typing_extensions.get_origin(obj)): + args = ', '.join(map(display_as_type, typing_extensions.get_args(obj))) + return f'Union[{args}]' + elif isinstance(obj, _typing_extra.WithArgsTypes): + if typing_extensions.get_origin(obj) == typing_extensions.Literal: + args = ', '.join(map(repr, typing_extensions.get_args(obj))) + else: + args = ', '.join(map(display_as_type, typing_extensions.get_args(obj))) + try: + return f'{obj.__qualname__}[{args}]' + except AttributeError: + return str(obj) # handles TypeAliasType in 3.12 + elif isinstance(obj, type): + return obj.__qualname__ + else: + return repr(obj).replace('typing.', '').replace('typing_extensions.', '') diff --git a/lib/pydantic/_internal/_schema_generation_shared.py b/lib/pydantic/_internal/_schema_generation_shared.py new file mode 100644 index 00000000..1a9aa852 --- /dev/null +++ b/lib/pydantic/_internal/_schema_generation_shared.py @@ -0,0 +1,124 @@ +"""Types and utility functions used by various other internal tools.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable + +from pydantic_core import core_schema +from typing_extensions import Literal + +from ..annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler + +if TYPE_CHECKING: + from ..json_schema import GenerateJsonSchema, JsonSchemaValue + from ._core_utils import CoreSchemaOrField + from ._generate_schema import GenerateSchema + + GetJsonSchemaFunction = Callable[[CoreSchemaOrField, GetJsonSchemaHandler], JsonSchemaValue] + HandlerOverride = Callable[[CoreSchemaOrField], JsonSchemaValue] + + +class GenerateJsonSchemaHandler(GetJsonSchemaHandler): + """JsonSchemaHandler implementation that doesn't do ref unwrapping by default. + + This is used for any Annotated metadata so that we don't end up with conflicting + modifications to the definition schema. + + Used internally by Pydantic, please do not rely on this implementation. + See `GetJsonSchemaHandler` for the handler API. + """ + + def __init__(self, generate_json_schema: GenerateJsonSchema, handler_override: HandlerOverride | None) -> None: + self.generate_json_schema = generate_json_schema + self.handler = handler_override or generate_json_schema.generate_inner + self.mode = generate_json_schema.mode + + def __call__(self, __core_schema: CoreSchemaOrField) -> JsonSchemaValue: + return self.handler(__core_schema) + + def resolve_ref_schema(self, maybe_ref_json_schema: JsonSchemaValue) -> JsonSchemaValue: + """Resolves `$ref` in the json schema. + + This returns the input json schema if there is no `$ref` in the json schema. + + Args: + maybe_ref_json_schema: The input json schema that may contain `$ref`. + + Returns: + Resolved json schema. + + Raises: + LookupError: If it can't find the definition for `$ref`. + """ + if '$ref' not in maybe_ref_json_schema: + return maybe_ref_json_schema + ref = maybe_ref_json_schema['$ref'] + json_schema = self.generate_json_schema.get_schema_from_definitions(ref) + if json_schema is None: + raise LookupError( + f'Could not find a ref for {ref}.' + ' Maybe you tried to call resolve_ref_schema from within a recursive model?' 
+ ) + return json_schema + + +class CallbackGetCoreSchemaHandler(GetCoreSchemaHandler): + """Wrapper to use an arbitrary function as a `GetCoreSchemaHandler`. + + Used internally by Pydantic, please do not rely on this implementation. + See `GetCoreSchemaHandler` for the handler API. + """ + + def __init__( + self, + handler: Callable[[Any], core_schema.CoreSchema], + generate_schema: GenerateSchema, + ref_mode: Literal['to-def', 'unpack'] = 'to-def', + ) -> None: + self._handler = handler + self._generate_schema = generate_schema + self._ref_mode = ref_mode + + def __call__(self, __source_type: Any) -> core_schema.CoreSchema: + schema = self._handler(__source_type) + ref = schema.get('ref') + if self._ref_mode == 'to-def': + if ref is not None: + self._generate_schema.defs.definitions[ref] = schema + return core_schema.definition_reference_schema(ref) + return schema + else: # ref_mode = 'unpack' + return self.resolve_ref_schema(schema) + + def _get_types_namespace(self) -> dict[str, Any] | None: + return self._generate_schema._types_namespace + + def generate_schema(self, __source_type: Any) -> core_schema.CoreSchema: + return self._generate_schema.generate_schema(__source_type) + + @property + def field_name(self) -> str | None: + return self._generate_schema.field_name_stack.get() + + def resolve_ref_schema(self, maybe_ref_schema: core_schema.CoreSchema) -> core_schema.CoreSchema: + """Resolves a reference in the core schema. + + Args: + maybe_ref_schema: The input core schema that may contain a reference. + + Returns: + Resolved core schema. + + Raises: + LookupError: If it can't find the definition for the reference. + """ + if maybe_ref_schema['type'] == 'definition-ref': + ref = maybe_ref_schema['schema_ref'] + if ref not in self._generate_schema.defs.definitions: + raise LookupError( + f'Could not find a ref for {ref}.' + ' Maybe you tried to call resolve_ref_schema from within a recursive model?' + ) + return self._generate_schema.defs.definitions[ref] + elif maybe_ref_schema['type'] == 'definitions': + return self.resolve_ref_schema(maybe_ref_schema['schema']) + return maybe_ref_schema diff --git a/lib/pydantic/_internal/_signature.py b/lib/pydantic/_internal/_signature.py new file mode 100644 index 00000000..816a1651 --- /dev/null +++ b/lib/pydantic/_internal/_signature.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +import dataclasses +from inspect import Parameter, Signature, signature +from typing import TYPE_CHECKING, Any, Callable + +from pydantic_core import PydanticUndefined + +from ._config import ConfigWrapper +from ._utils import is_valid_identifier + +if TYPE_CHECKING: + from ..fields import FieldInfo + + +def _field_name_for_signature(field_name: str, field_info: FieldInfo) -> str: + """Extract the correct name to use for the field when generating a signature. + + Assuming the field has a valid alias, this will return the alias. Otherwise, it will return the field name. + First priority is given to the validation_alias, then the alias, then the field name. + + Args: + field_name: The name of the field + field_info: The corresponding FieldInfo object. + + Returns: + The correct name to use when generating a signature. 
+ """ + + def _alias_if_valid(x: Any) -> str | None: + """Return the alias if it is a valid alias and identifier, else None.""" + return x if isinstance(x, str) and is_valid_identifier(x) else None + + return _alias_if_valid(field_info.alias) or _alias_if_valid(field_info.validation_alias) or field_name + + +def _process_param_defaults(param: Parameter) -> Parameter: + """Modify the signature for a parameter in a dataclass where the default value is a FieldInfo instance. + + Args: + param (Parameter): The parameter + + Returns: + Parameter: The custom processed parameter + """ + from ..fields import FieldInfo + + param_default = param.default + if isinstance(param_default, FieldInfo): + annotation = param.annotation + # Replace the annotation if appropriate + # inspect does "clever" things to show annotations as strings because we have + # `from __future__ import annotations` in main, we don't want that + if annotation == 'Any': + annotation = Any + + # Replace the field default + default = param_default.default + if default is PydanticUndefined: + if param_default.default_factory is PydanticUndefined: + default = Signature.empty + else: + # this is used by dataclasses to indicate a factory exists: + default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore + return param.replace( + annotation=annotation, name=_field_name_for_signature(param.name, param_default), default=default + ) + return param + + +def _generate_signature_parameters( # noqa: C901 (ignore complexity, could use a refactor) + init: Callable[..., None], + fields: dict[str, FieldInfo], + config_wrapper: ConfigWrapper, +) -> dict[str, Parameter]: + """Generate a mapping of parameter names to Parameter objects for a pydantic BaseModel or dataclass.""" + from itertools import islice + + present_params = signature(init).parameters.values() + merged_params: dict[str, Parameter] = {} + var_kw = None + use_var_kw = False + + for param in islice(present_params, 1, None): # skip self arg + # inspect does "clever" things to show annotations as strings because we have + # `from __future__ import annotations` in main, we don't want that + if fields.get(param.name): + # exclude params with init=False + if getattr(fields[param.name], 'init', True) is False: + continue + param = param.replace(name=_field_name_for_signature(param.name, fields[param.name])) + if param.annotation == 'Any': + param = param.replace(annotation=Any) + if param.kind is param.VAR_KEYWORD: + var_kw = param + continue + merged_params[param.name] = param + + if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through + allow_names = config_wrapper.populate_by_name + for field_name, field in fields.items(): + # when alias is a str it should be used for signature generation + param_name = _field_name_for_signature(field_name, field) + + if field_name in merged_params or param_name in merged_params: + continue + + if not is_valid_identifier(param_name): + if allow_names: + param_name = field_name + else: + use_var_kw = True + continue + + kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)} + merged_params[param_name] = Parameter( + param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs + ) + + if config_wrapper.extra == 'allow': + use_var_kw = True + + if var_kw and use_var_kw: + # Make sure the parameter for extra kwargs + # does not have the same name as a field + default_model_signature = [ + ('self', Parameter.POSITIONAL_ONLY), + ('data', 
Parameter.VAR_KEYWORD), + ] + if [(p.name, p.kind) for p in present_params] == default_model_signature: + # if this is the standard model signature, use extra_data as the extra args name + var_kw_name = 'extra_data' + else: + # else start from var_kw + var_kw_name = var_kw.name + + # generate a name that's definitely unique + while var_kw_name in fields: + var_kw_name += '_' + merged_params[var_kw_name] = var_kw.replace(name=var_kw_name) + + return merged_params + + +def generate_pydantic_signature( + init: Callable[..., None], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper, is_dataclass: bool = False +) -> Signature: + """Generate signature for a pydantic BaseModel or dataclass. + + Args: + init: The class init. + fields: The model fields. + config_wrapper: The config wrapper instance. + is_dataclass: Whether the model is a dataclass. + + Returns: + The dataclass/BaseModel subclass signature. + """ + merged_params = _generate_signature_parameters(init, fields, config_wrapper) + + if is_dataclass: + merged_params = {k: _process_param_defaults(v) for k, v in merged_params.items()} + + return Signature(parameters=list(merged_params.values()), return_annotation=None) diff --git a/lib/pydantic/_internal/_std_types_schema.py b/lib/pydantic/_internal/_std_types_schema.py new file mode 100644 index 00000000..c8523bf4 --- /dev/null +++ b/lib/pydantic/_internal/_std_types_schema.py @@ -0,0 +1,714 @@ +"""Logic for generating pydantic-core schemas for standard library types. + +Import of this module is deferred since it contains imports of many standard library modules. +""" +from __future__ import annotations as _annotations + +import collections +import collections.abc +import dataclasses +import decimal +import inspect +import os +import typing +from enum import Enum +from functools import partial +from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network +from typing import Any, Callable, Iterable, TypeVar + +import typing_extensions +from pydantic_core import ( + CoreSchema, + MultiHostUrl, + PydanticCustomError, + PydanticOmit, + Url, + core_schema, +) +from typing_extensions import get_args, get_origin + +from pydantic.errors import PydanticSchemaGenerationError +from pydantic.fields import FieldInfo +from pydantic.types import Strict + +from ..config import ConfigDict +from ..json_schema import JsonSchemaValue, update_json_schema +from . 
import _known_annotated_metadata, _typing_extra, _validators +from ._core_utils import get_type_ref +from ._internal_dataclass import slots_true +from ._schema_generation_shared import GetCoreSchemaHandler, GetJsonSchemaHandler + +if typing.TYPE_CHECKING: + from ._generate_schema import GenerateSchema + + StdSchemaFunction = Callable[[GenerateSchema, type[Any]], core_schema.CoreSchema] + + +@dataclasses.dataclass(**slots_true) +class SchemaTransformer: + get_core_schema: Callable[[Any, GetCoreSchemaHandler], CoreSchema] + get_json_schema: Callable[[CoreSchema, GetJsonSchemaHandler], JsonSchemaValue] + + def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: + return self.get_core_schema(source_type, handler) + + def __get_pydantic_json_schema__(self, schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + return self.get_json_schema(schema, handler) + + +def get_enum_core_schema(enum_type: type[Enum], config: ConfigDict) -> CoreSchema: + cases: list[Any] = list(enum_type.__members__.values()) + + enum_ref = get_type_ref(enum_type) + description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__) + if description == 'An enumeration.': # This is the default value provided by enum.EnumMeta.__new__; don't use it + description = None + updates = {'title': enum_type.__name__, 'description': description} + updates = {k: v for k, v in updates.items() if v is not None} + + def get_json_schema(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + json_schema = handler(core_schema.literal_schema([x.value for x in cases], ref=enum_ref)) + original_schema = handler.resolve_ref_schema(json_schema) + update_json_schema(original_schema, updates) + return json_schema + + if not cases: + # Use an isinstance check for enums with no cases. + # The most important use case for this is creating TypeVar bounds for generics that should + # be restricted to enums. This is more consistent than it might seem at first, since you can only + # subclass enum.Enum (or subclasses of enum.Enum) if all parent classes have no cases. + # We use the get_json_schema function when an Enum subclass has been declared with no cases + # so that we can still generate a valid json schema. 
+ return core_schema.is_instance_schema(enum_type, metadata={'pydantic_js_functions': [get_json_schema]}) + + use_enum_values = config.get('use_enum_values', False) + + if len(cases) == 1: + expected = repr(cases[0].value) + else: + expected = ', '.join([repr(case.value) for case in cases[:-1]]) + f' or {cases[-1].value!r}' + + def to_enum(__input_value: Any) -> Enum: + try: + enum_field = enum_type(__input_value) + if use_enum_values: + return enum_field.value + return enum_field + except ValueError: + # The type: ignore on the next line is to ignore the requirement of LiteralString + raise PydanticCustomError('enum', f'Input should be {expected}', {'expected': expected}) # type: ignore + + strict_python_schema = core_schema.is_instance_schema(enum_type) + if use_enum_values: + strict_python_schema = core_schema.chain_schema( + [strict_python_schema, core_schema.no_info_plain_validator_function(lambda x: x.value)] + ) + + to_enum_validator = core_schema.no_info_plain_validator_function(to_enum) + if issubclass(enum_type, int): + # this handles `IntEnum`, and also `Foobar(int, Enum)` + updates['type'] = 'integer' + lax = core_schema.chain_schema([core_schema.int_schema(), to_enum_validator]) + # Disallow float from JSON due to strict mode + strict = core_schema.json_or_python_schema( + json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.int_schema()), + python_schema=strict_python_schema, + ) + elif issubclass(enum_type, str): + # this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)` + updates['type'] = 'string' + lax = core_schema.chain_schema([core_schema.str_schema(), to_enum_validator]) + strict = core_schema.json_or_python_schema( + json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.str_schema()), + python_schema=strict_python_schema, + ) + elif issubclass(enum_type, float): + updates['type'] = 'numeric' + lax = core_schema.chain_schema([core_schema.float_schema(), to_enum_validator]) + strict = core_schema.json_or_python_schema( + json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.float_schema()), + python_schema=strict_python_schema, + ) + else: + lax = to_enum_validator + strict = core_schema.json_or_python_schema(json_schema=to_enum_validator, python_schema=strict_python_schema) + return core_schema.lax_or_strict_schema( + lax_schema=lax, strict_schema=strict, ref=enum_ref, metadata={'pydantic_js_functions': [get_json_schema]} + ) + + +@dataclasses.dataclass(**slots_true) +class InnerSchemaValidator: + """Use a fixed CoreSchema, avoiding interference from outward annotations.""" + + core_schema: CoreSchema + js_schema: JsonSchemaValue | None = None + js_core_schema: CoreSchema | None = None + js_schema_update: JsonSchemaValue | None = None + + def __get_pydantic_json_schema__(self, _schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue: + if self.js_schema is not None: + return self.js_schema + js_schema = handler(self.js_core_schema or self.core_schema) + if self.js_schema_update is not None: + js_schema.update(self.js_schema_update) + return js_schema + + def __get_pydantic_core_schema__(self, _source_type: Any, _handler: GetCoreSchemaHandler) -> CoreSchema: + return self.core_schema + + +def decimal_prepare_pydantic_annotations( + source: Any, annotations: Iterable[Any], config: ConfigDict +) -> tuple[Any, list[Any]] | None: + if source is not decimal.Decimal: + return None + + metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations) 
+ + config_allow_inf_nan = config.get('allow_inf_nan') + if config_allow_inf_nan is not None: + metadata.setdefault('allow_inf_nan', config_allow_inf_nan) + + _known_annotated_metadata.check_metadata( + metadata, {*_known_annotated_metadata.FLOAT_CONSTRAINTS, 'max_digits', 'decimal_places'}, decimal.Decimal + ) + return source, [InnerSchemaValidator(core_schema.decimal_schema(**metadata)), *remaining_annotations] + + +def datetime_prepare_pydantic_annotations( + source_type: Any, annotations: Iterable[Any], _config: ConfigDict +) -> tuple[Any, list[Any]] | None: + import datetime + + metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations) + if source_type is datetime.date: + sv = InnerSchemaValidator(core_schema.date_schema(**metadata)) + elif source_type is datetime.datetime: + sv = InnerSchemaValidator(core_schema.datetime_schema(**metadata)) + elif source_type is datetime.time: + sv = InnerSchemaValidator(core_schema.time_schema(**metadata)) + elif source_type is datetime.timedelta: + sv = InnerSchemaValidator(core_schema.timedelta_schema(**metadata)) + else: + return None + # check now that we know the source type is correct + _known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.DATE_TIME_CONSTRAINTS, source_type) + return (source_type, [sv, *remaining_annotations]) + + +def uuid_prepare_pydantic_annotations( + source_type: Any, annotations: Iterable[Any], _config: ConfigDict +) -> tuple[Any, list[Any]] | None: + # UUIDs have no constraints - they are fixed length, constructing a UUID instance checks the length + + from uuid import UUID + + if source_type is not UUID: + return None + + return (source_type, [InnerSchemaValidator(core_schema.uuid_schema()), *annotations]) + + +def path_schema_prepare_pydantic_annotations( + source_type: Any, annotations: Iterable[Any], _config: ConfigDict +) -> tuple[Any, list[Any]] | None: + import pathlib + + if source_type not in { + os.PathLike, + pathlib.Path, + pathlib.PurePath, + pathlib.PosixPath, + pathlib.PurePosixPath, + pathlib.PureWindowsPath, + }: + return None + + metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations) + _known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.STR_CONSTRAINTS, source_type) + + construct_path = pathlib.PurePath if source_type is os.PathLike else source_type + + def path_validator(input_value: str) -> os.PathLike[Any]: + try: + return construct_path(input_value) + except TypeError as e: + raise PydanticCustomError('path_type', 'Input is not a valid path') from e + + constrained_str_schema = core_schema.str_schema(**metadata) + + instance_schema = core_schema.json_or_python_schema( + json_schema=core_schema.no_info_after_validator_function(path_validator, constrained_str_schema), + python_schema=core_schema.is_instance_schema(source_type), + ) + + strict: bool | None = None + for annotation in annotations: + if isinstance(annotation, Strict): + strict = annotation.strict + + schema = core_schema.lax_or_strict_schema( + lax_schema=core_schema.union_schema( + [ + instance_schema, + core_schema.no_info_after_validator_function(path_validator, constrained_str_schema), + ], + custom_error_type='path_type', + custom_error_message='Input is not a valid path', + strict=True, + ), + strict_schema=instance_schema, + serialization=core_schema.to_string_ser_schema(), + strict=strict, + ) + + return ( + source_type, + [ + InnerSchemaValidator(schema, js_core_schema=constrained_str_schema, 
js_schema_update={'format': 'path'}), + *remaining_annotations, + ], + ) + + +def dequeue_validator( + input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, maxlen: None | int +) -> collections.deque[Any]: + if isinstance(input_value, collections.deque): + maxlens = [v for v in (input_value.maxlen, maxlen) if v is not None] + if maxlens: + maxlen = min(maxlens) + return collections.deque(handler(input_value), maxlen=maxlen) + else: + return collections.deque(handler(input_value), maxlen=maxlen) + + +@dataclasses.dataclass(**slots_true) +class SequenceValidator: + mapped_origin: type[Any] + item_source_type: type[Any] + min_length: int | None = None + max_length: int | None = None + strict: bool = False + + def serialize_sequence_via_list( + self, v: Any, handler: core_schema.SerializerFunctionWrapHandler, info: core_schema.SerializationInfo + ) -> Any: + items: list[Any] = [] + for index, item in enumerate(v): + try: + v = handler(item, index) + except PydanticOmit: + pass + else: + items.append(v) + + if info.mode_is_json(): + return items + else: + return self.mapped_origin(items) + + def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: + if self.item_source_type is Any: + items_schema = None + else: + items_schema = handler.generate_schema(self.item_source_type) + + metadata = {'min_length': self.min_length, 'max_length': self.max_length, 'strict': self.strict} + + if self.mapped_origin in (list, set, frozenset): + if self.mapped_origin is list: + constrained_schema = core_schema.list_schema(items_schema, **metadata) + elif self.mapped_origin is set: + constrained_schema = core_schema.set_schema(items_schema, **metadata) + else: + assert self.mapped_origin is frozenset # safety check in case we forget to add a case + constrained_schema = core_schema.frozenset_schema(items_schema, **metadata) + + schema = constrained_schema + else: + # safety check in case we forget to add a case + assert self.mapped_origin in (collections.deque, collections.Counter) + + if self.mapped_origin is collections.deque: + # if we have a MaxLen annotation might as well set that as the default maxlen on the deque + # this lets us re-use existing metadata annotations to let users set the maxlen on a dequeue + # that e.g. 
comes from JSON + coerce_instance_wrap = partial( + core_schema.no_info_wrap_validator_function, + partial(dequeue_validator, maxlen=metadata.get('max_length', None)), + ) + else: + coerce_instance_wrap = partial(core_schema.no_info_after_validator_function, self.mapped_origin) + + constrained_schema = core_schema.list_schema(items_schema, **metadata) + + check_instance = core_schema.json_or_python_schema( + json_schema=core_schema.list_schema(), + python_schema=core_schema.is_instance_schema(self.mapped_origin), + ) + + serialization = core_schema.wrap_serializer_function_ser_schema( + self.serialize_sequence_via_list, schema=items_schema or core_schema.any_schema(), info_arg=True + ) + + strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)]) + + if metadata.get('strict', False): + schema = strict + else: + lax = coerce_instance_wrap(constrained_schema) + schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict) + schema['serialization'] = serialization + + return schema + + +SEQUENCE_ORIGIN_MAP: dict[Any, Any] = { + typing.Deque: collections.deque, + collections.deque: collections.deque, + list: list, + typing.List: list, + set: set, + typing.AbstractSet: set, + typing.Set: set, + frozenset: frozenset, + typing.FrozenSet: frozenset, + typing.Sequence: list, + typing.MutableSequence: list, + typing.MutableSet: set, + # this doesn't handle subclasses of these + # parametrized typing.Set creates one of these + collections.abc.MutableSet: set, + collections.abc.Set: frozenset, +} + + +def identity(s: CoreSchema) -> CoreSchema: + return s + + +def sequence_like_prepare_pydantic_annotations( + source_type: Any, annotations: Iterable[Any], _config: ConfigDict +) -> tuple[Any, list[Any]] | None: + origin: Any = get_origin(source_type) + + mapped_origin = SEQUENCE_ORIGIN_MAP.get(origin, None) if origin else SEQUENCE_ORIGIN_MAP.get(source_type, None) + if mapped_origin is None: + return None + + args = get_args(source_type) + + if not args: + args = (Any,) + elif len(args) != 1: + raise ValueError('Expected sequence to have exactly 1 generic parameter') + + item_source_type = args[0] + + metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations) + _known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type) + + return (source_type, [SequenceValidator(mapped_origin, item_source_type, **metadata), *remaining_annotations]) + + +MAPPING_ORIGIN_MAP: dict[Any, Any] = { + typing.DefaultDict: collections.defaultdict, + collections.defaultdict: collections.defaultdict, + collections.OrderedDict: collections.OrderedDict, + typing_extensions.OrderedDict: collections.OrderedDict, + dict: dict, + typing.Dict: dict, + collections.Counter: collections.Counter, + typing.Counter: collections.Counter, + # this doesn't handle subclasses of these + typing.Mapping: dict, + typing.MutableMapping: dict, + # parametrized typing.{Mutable}Mapping creates one of these + collections.abc.MutableMapping: dict, + collections.abc.Mapping: dict, +} + + +def defaultdict_validator( + input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any] +) -> collections.defaultdict[Any, Any]: + if isinstance(input_value, collections.defaultdict): + default_factory = input_value.default_factory + return collections.defaultdict(default_factory, handler(input_value)) + else: + return collections.defaultdict(default_default_factory, 
handler(input_value)) + + +def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]: + def infer_default() -> Callable[[], Any]: + allowed_default_types: dict[Any, Any] = { + typing.Tuple: tuple, + tuple: tuple, + collections.abc.Sequence: tuple, + collections.abc.MutableSequence: list, + typing.List: list, + list: list, + typing.Sequence: list, + typing.Set: set, + set: set, + typing.MutableSet: set, + collections.abc.MutableSet: set, + collections.abc.Set: frozenset, + typing.MutableMapping: dict, + typing.Mapping: dict, + collections.abc.Mapping: dict, + collections.abc.MutableMapping: dict, + float: float, + int: int, + str: str, + bool: bool, + } + values_type_origin = get_origin(values_source_type) or values_source_type + instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`' + if isinstance(values_type_origin, TypeVar): + + def type_var_default_factory() -> None: + raise RuntimeError( + 'Generic defaultdict cannot be used without a concrete value type or an' + ' explicit default factory, ' + instructions + ) + + return type_var_default_factory + elif values_type_origin not in allowed_default_types: + # a somewhat subjective set of types that have reasonable default values + allowed_msg = ', '.join([t.__name__ for t in set(allowed_default_types.values())]) + raise PydanticSchemaGenerationError( + f'Unable to infer a default factory for keys of type {values_source_type}.' + f' Only {allowed_msg} are supported, other types require an explicit default factory' + ' ' + instructions + ) + return allowed_default_types[values_type_origin] + + # Assume Annotated[..., Field(...)] + if _typing_extra.is_annotated(values_source_type): + field_info = next((v for v in get_args(values_source_type) if isinstance(v, FieldInfo)), None) + else: + field_info = None + if field_info and field_info.default_factory: + default_default_factory = field_info.default_factory + else: + default_default_factory = infer_default() + return default_default_factory + + +@dataclasses.dataclass(**slots_true) +class MappingValidator: + mapped_origin: type[Any] + keys_source_type: type[Any] + values_source_type: type[Any] + min_length: int | None = None + max_length: int | None = None + strict: bool = False + + def serialize_mapping_via_dict(self, v: Any, handler: core_schema.SerializerFunctionWrapHandler) -> Any: + return handler(v) + + def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: + if self.keys_source_type is Any: + keys_schema = None + else: + keys_schema = handler.generate_schema(self.keys_source_type) + if self.values_source_type is Any: + values_schema = None + else: + values_schema = handler.generate_schema(self.values_source_type) + + metadata = {'min_length': self.min_length, 'max_length': self.max_length, 'strict': self.strict} + + if self.mapped_origin is dict: + schema = core_schema.dict_schema(keys_schema, values_schema, **metadata) + else: + constrained_schema = core_schema.dict_schema(keys_schema, values_schema, **metadata) + check_instance = core_schema.json_or_python_schema( + json_schema=core_schema.dict_schema(), + python_schema=core_schema.is_instance_schema(self.mapped_origin), + ) + + if self.mapped_origin is collections.defaultdict: + default_default_factory = get_defaultdict_default_default_factory(self.values_source_type) + coerce_instance_wrap = partial( + core_schema.no_info_wrap_validator_function, + partial(defaultdict_validator, 
default_default_factory=default_default_factory), + ) + else: + coerce_instance_wrap = partial(core_schema.no_info_after_validator_function, self.mapped_origin) + + serialization = core_schema.wrap_serializer_function_ser_schema( + self.serialize_mapping_via_dict, + schema=core_schema.dict_schema( + keys_schema or core_schema.any_schema(), values_schema or core_schema.any_schema() + ), + info_arg=False, + ) + + strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)]) + + if metadata.get('strict', False): + schema = strict + else: + lax = coerce_instance_wrap(constrained_schema) + schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict) + schema['serialization'] = serialization + + return schema + + +def mapping_like_prepare_pydantic_annotations( + source_type: Any, annotations: Iterable[Any], _config: ConfigDict +) -> tuple[Any, list[Any]] | None: + origin: Any = get_origin(source_type) + + mapped_origin = MAPPING_ORIGIN_MAP.get(origin, None) if origin else MAPPING_ORIGIN_MAP.get(source_type, None) + if mapped_origin is None: + return None + + args = get_args(source_type) + + if not args: + args = (Any, Any) + elif mapped_origin is collections.Counter: + # a single generic + if len(args) != 1: + raise ValueError('Expected Counter to have exactly 1 generic parameter') + args = (args[0], int) # keys are always an int + elif len(args) != 2: + raise ValueError('Expected mapping to have exactly 2 generic parameters') + + keys_source_type, values_source_type = args + + metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations) + _known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type) + + return ( + source_type, + [ + MappingValidator(mapped_origin, keys_source_type, values_source_type, **metadata), + *remaining_annotations, + ], + ) + + +def ip_prepare_pydantic_annotations( + source_type: Any, annotations: Iterable[Any], _config: ConfigDict +) -> tuple[Any, list[Any]] | None: + def make_strict_ip_schema(tp: type[Any]) -> CoreSchema: + return core_schema.json_or_python_schema( + json_schema=core_schema.no_info_after_validator_function(tp, core_schema.str_schema()), + python_schema=core_schema.is_instance_schema(tp), + ) + + if source_type is IPv4Address: + return source_type, [ + SchemaTransformer( + lambda _1, _2: core_schema.lax_or_strict_schema( + lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_address_validator), + strict_schema=make_strict_ip_schema(IPv4Address), + serialization=core_schema.to_string_ser_schema(), + ), + lambda _1, _2: {'type': 'string', 'format': 'ipv4'}, + ), + *annotations, + ] + if source_type is IPv4Network: + return source_type, [ + SchemaTransformer( + lambda _1, _2: core_schema.lax_or_strict_schema( + lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_network_validator), + strict_schema=make_strict_ip_schema(IPv4Network), + serialization=core_schema.to_string_ser_schema(), + ), + lambda _1, _2: {'type': 'string', 'format': 'ipv4network'}, + ), + *annotations, + ] + if source_type is IPv4Interface: + return source_type, [ + SchemaTransformer( + lambda _1, _2: core_schema.lax_or_strict_schema( + lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_interface_validator), + strict_schema=make_strict_ip_schema(IPv4Interface), + serialization=core_schema.to_string_ser_schema(), + ), + lambda _1, _2: {'type': 'string', 'format': 'ipv4interface'}, + ), + 
*annotations, + ] + + if source_type is IPv6Address: + return source_type, [ + SchemaTransformer( + lambda _1, _2: core_schema.lax_or_strict_schema( + lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_address_validator), + strict_schema=make_strict_ip_schema(IPv6Address), + serialization=core_schema.to_string_ser_schema(), + ), + lambda _1, _2: {'type': 'string', 'format': 'ipv6'}, + ), + *annotations, + ] + if source_type is IPv6Network: + return source_type, [ + SchemaTransformer( + lambda _1, _2: core_schema.lax_or_strict_schema( + lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_network_validator), + strict_schema=make_strict_ip_schema(IPv6Network), + serialization=core_schema.to_string_ser_schema(), + ), + lambda _1, _2: {'type': 'string', 'format': 'ipv6network'}, + ), + *annotations, + ] + if source_type is IPv6Interface: + return source_type, [ + SchemaTransformer( + lambda _1, _2: core_schema.lax_or_strict_schema( + lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_interface_validator), + strict_schema=make_strict_ip_schema(IPv6Interface), + serialization=core_schema.to_string_ser_schema(), + ), + lambda _1, _2: {'type': 'string', 'format': 'ipv6interface'}, + ), + *annotations, + ] + + return None + + +def url_prepare_pydantic_annotations( + source_type: Any, annotations: Iterable[Any], _config: ConfigDict +) -> tuple[Any, list[Any]] | None: + if source_type is Url: + return source_type, [ + SchemaTransformer( + lambda _1, _2: core_schema.url_schema(), + lambda cs, handler: handler(cs), + ), + *annotations, + ] + if source_type is MultiHostUrl: + return source_type, [ + SchemaTransformer( + lambda _1, _2: core_schema.multi_host_url_schema(), + lambda cs, handler: handler(cs), + ), + *annotations, + ] + + +PREPARE_METHODS: tuple[Callable[[Any, Iterable[Any], ConfigDict], tuple[Any, list[Any]] | None], ...] = ( + decimal_prepare_pydantic_annotations, + sequence_like_prepare_pydantic_annotations, + datetime_prepare_pydantic_annotations, + uuid_prepare_pydantic_annotations, + path_schema_prepare_pydantic_annotations, + mapping_like_prepare_pydantic_annotations, + ip_prepare_pydantic_annotations, + url_prepare_pydantic_annotations, +) diff --git a/lib/pydantic/_internal/_typing_extra.py b/lib/pydantic/_internal/_typing_extra.py new file mode 100644 index 00000000..1d5d3b3f --- /dev/null +++ b/lib/pydantic/_internal/_typing_extra.py @@ -0,0 +1,469 @@ +"""Logic for interacting with type annotations, mostly extensions, shims and hacks to wrap python's typing module.""" +from __future__ import annotations as _annotations + +import dataclasses +import sys +import types +import typing +from collections.abc import Callable +from functools import partial +from types import GetSetDescriptorType +from typing import TYPE_CHECKING, Any, Final + +from typing_extensions import Annotated, Literal, TypeAliasType, TypeGuard, get_args, get_origin + +if TYPE_CHECKING: + from ._dataclasses import StandardDataclass + +try: + from typing import _TypingBase # type: ignore[attr-defined] +except ImportError: + from typing import _Final as _TypingBase # type: ignore[attr-defined] + +typing_base = _TypingBase + + +if sys.version_info < (3, 9): + # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] 
and so on) + TypingGenericAlias = () +else: + from typing import GenericAlias as TypingGenericAlias # type: ignore + + +if sys.version_info < (3, 11): + from typing_extensions import NotRequired, Required +else: + from typing import NotRequired, Required # noqa: F401 + + +if sys.version_info < (3, 10): + + def origin_is_union(tp: type[Any] | None) -> bool: + return tp is typing.Union + + WithArgsTypes = (TypingGenericAlias,) + +else: + + def origin_is_union(tp: type[Any] | None) -> bool: + return tp is typing.Union or tp is types.UnionType + + WithArgsTypes = typing._GenericAlias, types.GenericAlias, types.UnionType # type: ignore[attr-defined] + + +if sys.version_info < (3, 10): + NoneType = type(None) + EllipsisType = type(Ellipsis) +else: + from types import NoneType as NoneType + + +LITERAL_TYPES: set[Any] = {Literal} +if hasattr(typing, 'Literal'): + LITERAL_TYPES.add(typing.Literal) # type: ignore + +NONE_TYPES: tuple[Any, ...] = (None, NoneType, *(tp[None] for tp in LITERAL_TYPES)) + + +TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type + + +def is_none_type(type_: Any) -> bool: + return type_ in NONE_TYPES + + +def is_callable_type(type_: type[Any]) -> bool: + return type_ is Callable or get_origin(type_) is Callable + + +def is_literal_type(type_: type[Any]) -> bool: + return Literal is not None and get_origin(type_) in LITERAL_TYPES + + +def literal_values(type_: type[Any]) -> tuple[Any, ...]: + return get_args(type_) + + +def all_literal_values(type_: type[Any]) -> list[Any]: + """This method is used to retrieve all Literal values as + Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586) + e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]`. + """ + if not is_literal_type(type_): + return [type_] + + values = literal_values(type_) + return list(x for value in values for x in all_literal_values(value)) + + +def is_annotated(ann_type: Any) -> bool: + from ._utils import lenient_issubclass + + origin = get_origin(ann_type) + return origin is not None and lenient_issubclass(origin, Annotated) + + +def is_namedtuple(type_: type[Any]) -> bool: + """Check if a given class is a named tuple. + It can be either a `typing.NamedTuple` or `collections.namedtuple`. + """ + from ._utils import lenient_issubclass + + return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields') + + +test_new_type = typing.NewType('test_new_type', str) + + +def is_new_type(type_: type[Any]) -> bool: + """Check whether type_ was created using typing.NewType. + + Can't use isinstance because it fails <3.10. 
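For reference, a standalone sketch of the probe-based NewType detection used here; the names below are illustrative and the snippet is not part of the patch.

import typing

UserId = typing.NewType('UserId', int)
_probe = typing.NewType('_probe', str)

# On 3.10+ both are instances of typing.NewType; before 3.10 both are plain functions,
# so checking isinstance against a sample NewType works either way.
print(isinstance(UserId, _probe.__class__) and hasattr(UserId, '__supertype__'))  # True
print(hasattr(int, '__supertype__'))                                              # False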
+ """ + return isinstance(type_, test_new_type.__class__) and hasattr(type_, '__supertype__') # type: ignore[arg-type] + + +def _check_classvar(v: type[Any] | None) -> bool: + if v is None: + return False + + return v.__class__ == typing.ClassVar.__class__ and getattr(v, '_name', None) == 'ClassVar' + + +def is_classvar(ann_type: type[Any]) -> bool: + if _check_classvar(ann_type) or _check_classvar(get_origin(ann_type)): + return True + + # this is an ugly workaround for class vars that contain forward references and are therefore themselves + # forward references, see #3679 + if ann_type.__class__ == typing.ForwardRef and ann_type.__forward_arg__.startswith('ClassVar['): # type: ignore + return True + + return False + + +def _check_finalvar(v: type[Any] | None) -> bool: + """Check if a given type is a `typing.Final` type.""" + if v is None: + return False + + return v.__class__ == Final.__class__ and (sys.version_info < (3, 8) or getattr(v, '_name', None) == 'Final') + + +def is_finalvar(ann_type: Any) -> bool: + return _check_finalvar(ann_type) or _check_finalvar(get_origin(ann_type)) + + +def parent_frame_namespace(*, parent_depth: int = 2) -> dict[str, Any] | None: + """We allow use of items in parent namespace to get around the issue with `get_type_hints` only looking in the + global module namespace. See https://github.com/pydantic/pydantic/issues/2678#issuecomment-1008139014 -> Scope + and suggestion at the end of the next comment by @gvanrossum. + + WARNING 1: it matters exactly where this is called. By default, this function will build a namespace from the + parent of where it is called. + + WARNING 2: this only looks in the parent namespace, not other parents since (AFAIK) there's no way to collect a + dict of exactly what's in scope. Using `f_back` would work sometimes but would be very wrong and confusing in many + other cases. See https://discuss.python.org/t/is-there-a-way-to-access-parent-nested-namespaces/20659. + """ + frame = sys._getframe(parent_depth) + # if f_back is None, it's the global module namespace and we don't need to include it here + if frame.f_back is None: + return None + else: + return frame.f_locals + + +def add_module_globals(obj: Any, globalns: dict[str, Any] | None = None) -> dict[str, Any]: + module_name = getattr(obj, '__module__', None) + if module_name: + try: + module_globalns = sys.modules[module_name].__dict__ + except KeyError: + # happens occasionally, see https://github.com/pydantic/pydantic/issues/2363 + pass + else: + if globalns: + return {**module_globalns, **globalns} + else: + # copy module globals to make sure it can't be updated later + return module_globalns.copy() + + return globalns or {} + + +def get_cls_types_namespace(cls: type[Any], parent_namespace: dict[str, Any] | None = None) -> dict[str, Any]: + ns = add_module_globals(cls, parent_namespace) + ns[cls.__name__] = cls + return ns + + +def get_cls_type_hints_lenient(obj: Any, globalns: dict[str, Any] | None = None) -> dict[str, Any]: + """Collect annotations from a class, including those from parent classes. + + Unlike `typing.get_type_hints`, this function will not error if a forward reference is not resolvable. 
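A hedged, end-to-end illustration of why the parent-frame namespace collected above matters: a model defined inside a function can resolve a string annotation that points at another class defined in the same local scope. The class and field names below are illustrative.

from pydantic import BaseModel

def build_model():
    class Inner(BaseModel):
        x: int = 0

    class Outer(BaseModel):
        inner: 'Inner'   # only resolvable through the enclosing frame's locals

    return Outer

Outer = build_model()
print(Outer(inner={'x': 1}))   # inner=Inner(x=1)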
+ """ + hints = {} + for base in reversed(obj.__mro__): + ann = base.__dict__.get('__annotations__') + localns = dict(vars(base)) + if ann is not None and ann is not GetSetDescriptorType: + for name, value in ann.items(): + hints[name] = eval_type_lenient(value, globalns, localns) + return hints + + +def eval_type_lenient(value: Any, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None) -> Any: + """Behaves like typing._eval_type, except it won't raise an error if a forward reference can't be resolved.""" + if value is None: + value = NoneType + elif isinstance(value, str): + value = _make_forward_ref(value, is_argument=False, is_class=True) + + try: + return eval_type_backport(value, globalns, localns) + except NameError: + # the point of this function is to be tolerant to this case + return value + + +def eval_type_backport( + value: Any, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None +) -> Any: + """Like `typing._eval_type`, but falls back to the `eval_type_backport` package if it's + installed to let older Python versions use newer typing features. + Specifically, this transforms `X | Y` into `typing.Union[X, Y]` + and `list[X]` into `typing.List[X]` etc. (for all the types made generic in PEP 585) + if the original syntax is not supported in the current Python version. + """ + try: + return typing._eval_type( # type: ignore + value, globalns, localns + ) + except TypeError as e: + if not (isinstance(value, typing.ForwardRef) and is_backport_fixable_error(e)): + raise + try: + from eval_type_backport import eval_type_backport + except ImportError: + raise TypeError( + f'You have a type annotation {value.__forward_arg__!r} ' + f'which makes use of newer typing features than are supported in your version of Python. ' + f'To handle this error, you should either remove the use of new syntax ' + f'or install the `eval_type_backport` package.' + ) from e + + return eval_type_backport(value, globalns, localns, try_default=False) + + +def is_backport_fixable_error(e: TypeError) -> bool: + msg = str(e) + return msg.startswith('unsupported operand type(s) for |: ') or "' object is not subscriptable" in msg + + +def get_function_type_hints( + function: Callable[..., Any], *, include_keys: set[str] | None = None, types_namespace: dict[str, Any] | None = None +) -> dict[str, Any]: + """Like `typing.get_type_hints`, but doesn't convert `X` to `Optional[X]` if the default value is `None`, also + copes with `partial`. + """ + if isinstance(function, partial): + annotations = function.func.__annotations__ + else: + annotations = function.__annotations__ + + globalns = add_module_globals(function) + type_hints = {} + for name, value in annotations.items(): + if include_keys is not None and name not in include_keys: + continue + if value is None: + value = NoneType + elif isinstance(value, str): + value = _make_forward_ref(value) + + type_hints[name] = eval_type_backport(value, globalns, types_namespace) + + return type_hints + + +if sys.version_info < (3, 9, 8) or (3, 10) <= sys.version_info < (3, 10, 1): + + def _make_forward_ref( + arg: Any, + is_argument: bool = True, + *, + is_class: bool = False, + ) -> typing.ForwardRef: + """Wrapper for ForwardRef that accounts for the `is_class` argument missing in older versions. + The `module` argument is omitted as it breaks <3.9.8, =3.10.0 and isn't used in the calls below. + + See https://github.com/python/cpython/pull/28560 for some background. 
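For context, the failure mode that the eval_type_backport fallback above guards against: on Python older than 3.10, evaluating a PEP 604 string annotation raises exactly the kind of TypeError matched by is_backport_fixable_error. A small self-contained check:

import sys

if sys.version_info < (3, 10):
    try:
        eval('int | None', {}, {})
    except TypeError as exc:
        print(exc)   # unsupported operand type(s) for |: 'type' and 'NoneType'
else:
    print(eval('int | None', {}, {}))   # 3.10+: evaluates to the union type int | None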
+ The backport happened on 3.9.8, see: + https://github.com/pydantic/pydantic/discussions/6244#discussioncomment-6275458, + and on 3.10.1 for the 3.10 branch, see: + https://github.com/pydantic/pydantic/issues/6912 + + Implemented as EAFP with memory. + """ + return typing.ForwardRef(arg, is_argument) + +else: + _make_forward_ref = typing.ForwardRef + + +if sys.version_info >= (3, 10): + get_type_hints = typing.get_type_hints + +else: + """ + For older versions of python, we have a custom implementation of `get_type_hints` which is a close as possible to + the implementation in CPython 3.10.8. + """ + + @typing.no_type_check + def get_type_hints( # noqa: C901 + obj: Any, + globalns: dict[str, Any] | None = None, + localns: dict[str, Any] | None = None, + include_extras: bool = False, + ) -> dict[str, Any]: # pragma: no cover + """Taken verbatim from python 3.10.8 unchanged, except: + * type annotations of the function definition above. + * prefixing `typing.` where appropriate + * Use `_make_forward_ref` instead of `typing.ForwardRef` to handle the `is_class` argument. + + https://github.com/python/cpython/blob/aaaf5174241496afca7ce4d4584570190ff972fe/Lib/typing.py#L1773-L1875 + + DO NOT CHANGE THIS METHOD UNLESS ABSOLUTELY NECESSARY. + ====================================================== + + Return type hints for an object. + + This is often the same as obj.__annotations__, but it handles + forward references encoded as string literals, adds Optional[t] if a + default value equal to None is set and recursively replaces all + 'Annotated[T, ...]' with 'T' (unless 'include_extras=True'). + + The argument may be a module, class, method, or function. The annotations + are returned as a dictionary. For classes, annotations include also + inherited members. + + TypeError is raised if the argument is not of a type that can contain + annotations, and an empty dictionary is returned if no annotations are + present. + + BEWARE -- the behavior of globalns and localns is counterintuitive + (unless you are familiar with how eval() and exec() work). The + search order is locals first, then globals. + + - If no dict arguments are passed, an attempt is made to use the + globals from obj (or the respective module's globals for classes), + and these are also used as the locals. If the object does not appear + to have globals, an empty dictionary is used. For classes, the search + order is globals first then locals. + + - If one dict argument is passed, it is used for both globals and + locals. + + - If two dict arguments are passed, they specify globals and + locals, respectively. + """ + if getattr(obj, '__no_type_check__', None): + return {} + # Classes require a special treatment. + if isinstance(obj, type): + hints = {} + for base in reversed(obj.__mro__): + if globalns is None: + base_globals = getattr(sys.modules.get(base.__module__, None), '__dict__', {}) + else: + base_globals = globalns + ann = base.__dict__.get('__annotations__', {}) + if isinstance(ann, types.GetSetDescriptorType): + ann = {} + base_locals = dict(vars(base)) if localns is None else localns + if localns is None and globalns is None: + # This is surprising, but required. Before Python 3.10, + # get_type_hints only evaluated the globalns of + # a class. To maintain backwards compatibility, we reverse + # the globalns and localns order so that eval() looks into + # *base_globals* first rather than *base_locals*. + # This only affects ForwardRefs. 
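As a quick reminder of what the include_extras flag handled by this backport does (standard typing behaviour, shown with typing_extensions so it also runs on the older interpreters this code targets):

from typing_extensions import Annotated, get_type_hints

class Point:
    x: Annotated[int, 'metres']

print(get_type_hints(Point))                        # {'x': <class 'int'>}
print(get_type_hints(Point, include_extras=True))   # {'x': Annotated[int, 'metres']}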
+ base_globals, base_locals = base_locals, base_globals + for name, value in ann.items(): + if value is None: + value = type(None) + if isinstance(value, str): + value = _make_forward_ref(value, is_argument=False, is_class=True) + + value = eval_type_backport(value, base_globals, base_locals) + hints[name] = value + if not include_extras and hasattr(typing, '_strip_annotations'): + return { + k: typing._strip_annotations(t) # type: ignore + for k, t in hints.items() + } + else: + return hints + + if globalns is None: + if isinstance(obj, types.ModuleType): + globalns = obj.__dict__ + else: + nsobj = obj + # Find globalns for the unwrapped object. + while hasattr(nsobj, '__wrapped__'): + nsobj = nsobj.__wrapped__ + globalns = getattr(nsobj, '__globals__', {}) + if localns is None: + localns = globalns + elif localns is None: + localns = globalns + hints = getattr(obj, '__annotations__', None) + if hints is None: + # Return empty annotations for something that _could_ have them. + if isinstance(obj, typing._allowed_types): # type: ignore + return {} + else: + raise TypeError(f'{obj!r} is not a module, class, method, ' 'or function.') + defaults = typing._get_defaults(obj) # type: ignore + hints = dict(hints) + for name, value in hints.items(): + if value is None: + value = type(None) + if isinstance(value, str): + # class-level forward refs were handled above, this must be either + # a module-level annotation or a function argument annotation + + value = _make_forward_ref( + value, + is_argument=not isinstance(obj, types.ModuleType), + is_class=False, + ) + value = eval_type_backport(value, globalns, localns) + if name in defaults and defaults[name] is None: + value = typing.Optional[value] + hints[name] = value + return hints if include_extras else {k: typing._strip_annotations(t) for k, t in hints.items()} # type: ignore + + +def is_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]: + # The dataclasses.is_dataclass function doesn't seem to provide TypeGuard functionality, + # so I created this convenience function + return dataclasses.is_dataclass(_cls) + + +def origin_is_type_alias_type(origin: Any) -> TypeGuard[TypeAliasType]: + return isinstance(origin, TypeAliasType) + + +if sys.version_info >= (3, 10): + + def is_generic_alias(type_: type[Any]) -> bool: + return isinstance(type_, (types.GenericAlias, typing._GenericAlias)) # type: ignore[attr-defined] + +else: + + def is_generic_alias(type_: type[Any]) -> bool: + return isinstance(type_, typing._GenericAlias) # type: ignore diff --git a/lib/pydantic/_internal/_utils.py b/lib/pydantic/_internal/_utils.py new file mode 100644 index 00000000..31f5b2c5 --- /dev/null +++ b/lib/pydantic/_internal/_utils.py @@ -0,0 +1,362 @@ +"""Bucket of reusable internal utilities. + +This should be reduced as much as possible with functions only used in one place, moved to that place. +""" +from __future__ import annotations as _annotations + +import dataclasses +import keyword +import typing +import weakref +from collections import OrderedDict, defaultdict, deque +from copy import deepcopy +from itertools import zip_longest +from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType +from typing import Any, Mapping, TypeVar + +from typing_extensions import TypeAlias, TypeGuard + +from . 
import _repr, _typing_extra + +if typing.TYPE_CHECKING: + MappingIntStrAny: TypeAlias = 'typing.Mapping[int, Any] | typing.Mapping[str, Any]' + AbstractSetIntStr: TypeAlias = 'typing.AbstractSet[int] | typing.AbstractSet[str]' + from ..main import BaseModel + + +# these are types that are returned unchanged by deepcopy +IMMUTABLE_NON_COLLECTIONS_TYPES: set[type[Any]] = { + int, + float, + complex, + str, + bool, + bytes, + type, + _typing_extra.NoneType, + FunctionType, + BuiltinFunctionType, + LambdaType, + weakref.ref, + CodeType, + # note: including ModuleType will differ from behaviour of deepcopy by not producing error. + # It might be not a good idea in general, but considering that this function used only internally + # against default values of fields, this will allow to actually have a field with module as default value + ModuleType, + NotImplemented.__class__, + Ellipsis.__class__, +} + +# these are types that if empty, might be copied with simple copy() instead of deepcopy() +BUILTIN_COLLECTIONS: set[type[Any]] = { + list, + set, + tuple, + frozenset, + dict, + OrderedDict, + defaultdict, + deque, +} + + +def sequence_like(v: Any) -> bool: + return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque)) + + +def lenient_isinstance(o: Any, class_or_tuple: type[Any] | tuple[type[Any], ...] | None) -> bool: # pragma: no cover + try: + return isinstance(o, class_or_tuple) # type: ignore[arg-type] + except TypeError: + return False + + +def lenient_issubclass(cls: Any, class_or_tuple: Any) -> bool: # pragma: no cover + try: + return isinstance(cls, type) and issubclass(cls, class_or_tuple) + except TypeError: + if isinstance(cls, _typing_extra.WithArgsTypes): + return False + raise # pragma: no cover + + +def is_model_class(cls: Any) -> TypeGuard[type[BaseModel]]: + """Returns true if cls is a _proper_ subclass of BaseModel, and provides proper type-checking, + unlike raw calls to lenient_issubclass. + """ + from ..main import BaseModel + + return lenient_issubclass(cls, BaseModel) and cls is not BaseModel + + +def is_valid_identifier(identifier: str) -> bool: + """Checks that a string is a valid identifier and not a Python keyword. + :param identifier: The identifier to test. + :return: True if the identifier is valid. + """ + return identifier.isidentifier() and not keyword.iskeyword(identifier) + + +KeyType = TypeVar('KeyType') + + +def deep_update(mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]) -> dict[KeyType, Any]: + updated_mapping = mapping.copy() + for updating_mapping in updating_mappings: + for k, v in updating_mapping.items(): + if k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict): + updated_mapping[k] = deep_update(updated_mapping[k], v) + else: + updated_mapping[k] = v + return updated_mapping + + +def update_not_none(mapping: dict[Any, Any], **update: Any) -> None: + mapping.update({k: v for k, v in update.items() if v is not None}) + + +T = TypeVar('T') + + +def unique_list( + input_list: list[T] | tuple[T, ...], + *, + name_factory: typing.Callable[[T], str] = str, +) -> list[T]: + """Make a list unique while maintaining order. + We update the list if another one with the same name is set + (e.g. model validator overridden in subclass). 
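A standalone restatement of the deep_update semantics above (typing stripped for brevity, values purely illustrative): nested dicts are merged recursively, while any other value is simply overwritten.

def deep_update(mapping, *updating_mappings):
    updated = mapping.copy()
    for updating in updating_mappings:
        for k, v in updating.items():
            if k in updated and isinstance(updated[k], dict) and isinstance(v, dict):
                updated[k] = deep_update(updated[k], v)   # recurse into nested dicts
            else:
                updated[k] = v                            # everything else is replaced
    return updated

base = {'db': {'host': 'localhost', 'port': 5432}, 'debug': False}
print(deep_update(base, {'db': {'port': 6432}, 'debug': True}))
# {'db': {'host': 'localhost', 'port': 6432}, 'debug': True}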
+ """ + result: list[T] = [] + result_names: list[str] = [] + for v in input_list: + v_name = name_factory(v) + if v_name not in result_names: + result_names.append(v_name) + result.append(v) + else: + result[result_names.index(v_name)] = v + + return result + + +class ValueItems(_repr.Representation): + """Class for more convenient calculation of excluded or included fields on values.""" + + __slots__ = ('_items', '_type') + + def __init__(self, value: Any, items: AbstractSetIntStr | MappingIntStrAny) -> None: + items = self._coerce_items(items) + + if isinstance(value, (list, tuple)): + items = self._normalize_indexes(items, len(value)) # type: ignore + + self._items: MappingIntStrAny = items # type: ignore + + def is_excluded(self, item: Any) -> bool: + """Check if item is fully excluded. + + :param item: key or index of a value + """ + return self.is_true(self._items.get(item)) + + def is_included(self, item: Any) -> bool: + """Check if value is contained in self._items. + + :param item: key or index of value + """ + return item in self._items + + def for_element(self, e: int | str) -> AbstractSetIntStr | MappingIntStrAny | None: + """:param e: key or index of element on value + :return: raw values for element if self._items is dict and contain needed element + """ + item = self._items.get(e) # type: ignore + return item if not self.is_true(item) else None + + def _normalize_indexes(self, items: MappingIntStrAny, v_length: int) -> dict[int | str, Any]: + """:param items: dict or set of indexes which will be normalized + :param v_length: length of sequence indexes of which will be + + >>> self._normalize_indexes({0: True, -2: True, -1: True}, 4) + {0: True, 2: True, 3: True} + >>> self._normalize_indexes({'__all__': True}, 4) + {0: True, 1: True, 2: True, 3: True} + """ + normalized_items: dict[int | str, Any] = {} + all_items = None + for i, v in items.items(): + if not (isinstance(v, typing.Mapping) or isinstance(v, typing.AbstractSet) or self.is_true(v)): + raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}') + if i == '__all__': + all_items = self._coerce_value(v) + continue + if not isinstance(i, int): + raise TypeError( + 'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: ' + 'expected integer keys or keyword "__all__"' + ) + normalized_i = v_length + i if i < 0 else i + normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i)) + + if not all_items: + return normalized_items + if self.is_true(all_items): + for i in range(v_length): + normalized_items.setdefault(i, ...) + return normalized_items + for i in range(v_length): + normalized_item = normalized_items.setdefault(i, {}) + if not self.is_true(normalized_item): + normalized_items[i] = self.merge(all_items, normalized_item) + return normalized_items + + @classmethod + def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any: + """Merge a `base` item with an `override` item. + + Both `base` and `override` are converted to dictionaries if possible. + Sets are converted to dictionaries with the sets entries as keys and + Ellipsis as values. + + Each key-value pair existing in `base` is merged with `override`, + while the rest of the key-value pairs are updated recursively with this function. + + Merging takes place based on the "union" of keys if `intersect` is + set to `False` (default) and on the intersection of keys if + `intersect` is set to `True`. 
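A simplified, standalone restatement of the index normalisation documented above (the merge handling for overlapping keys is omitted); it reproduces the expected results from the doctest:

def normalize_indexes(items, v_length):
    normalized = {}
    all_value = items.get('__all__')
    for i, v in items.items():
        if i == '__all__':
            continue
        normalized[v_length + i if i < 0 else i] = v   # shift negative indexes by the length
    if all_value:
        for i in range(v_length):
            normalized.setdefault(i, all_value)        # expand '__all__' to every position
    return normalized

print(normalize_indexes({0: True, -2: True, -1: True}, 4))   # {0: True, 2: True, 3: True}
print(normalize_indexes({'__all__': True}, 4))               # {0: True, 1: True, 2: True, 3: True}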
+ """ + override = cls._coerce_value(override) + base = cls._coerce_value(base) + if override is None: + return base + if cls.is_true(base) or base is None: + return override + if cls.is_true(override): + return base if intersect else override + + # intersection or union of keys while preserving ordering: + if intersect: + merge_keys = [k for k in base if k in override] + [k for k in override if k in base] + else: + merge_keys = list(base) + [k for k in override if k not in base] + + merged: dict[int | str, Any] = {} + for k in merge_keys: + merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect) + if merged_item is not None: + merged[k] = merged_item + + return merged + + @staticmethod + def _coerce_items(items: AbstractSetIntStr | MappingIntStrAny) -> MappingIntStrAny: + if isinstance(items, typing.Mapping): + pass + elif isinstance(items, typing.AbstractSet): + items = dict.fromkeys(items, ...) # type: ignore + else: + class_name = getattr(items, '__class__', '???') + raise TypeError(f'Unexpected type of exclude value {class_name}') + return items # type: ignore + + @classmethod + def _coerce_value(cls, value: Any) -> Any: + if value is None or cls.is_true(value): + return value + return cls._coerce_items(value) + + @staticmethod + def is_true(v: Any) -> bool: + return v is True or v is ... + + def __repr_args__(self) -> _repr.ReprArgs: + return [(None, self._items)] + + +if typing.TYPE_CHECKING: + + def ClassAttribute(name: str, value: T) -> T: + ... + +else: + + class ClassAttribute: + """Hide class attribute from its instances.""" + + __slots__ = 'name', 'value' + + def __init__(self, name: str, value: Any) -> None: + self.name = name + self.value = value + + def __get__(self, instance: Any, owner: type[Any]) -> None: + if instance is None: + return self.value + raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only') + + +Obj = TypeVar('Obj') + + +def smart_deepcopy(obj: Obj) -> Obj: + """Return type as is for immutable built-in types + Use obj.copy() for built-in empty collections + Use copy.deepcopy() for non-empty collections and unknown objects. + """ + obj_type = obj.__class__ + if obj_type in IMMUTABLE_NON_COLLECTIONS_TYPES: + return obj # fastest case: obj is immutable and not collection therefore will not be copied anyway + try: + if not obj and obj_type in BUILTIN_COLLECTIONS: + # faster way for empty collections, no need to copy its members + return obj if obj_type is tuple else obj.copy() # tuple doesn't have copy method # type: ignore + except (TypeError, ValueError, RuntimeError): + # do we really dare to catch ALL errors? Seems a bit risky + pass + + return deepcopy(obj) # slowest way when we actually might need a deepcopy + + +_SENTINEL = object() + + +def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bool: + """Check that the items of `left` are the same objects as those in `right`. 
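A minimal sketch of the ClassAttribute descriptor defined above, restated standalone to show the intended behaviour: the value is reachable on the class but deliberately hidden from instances (the attribute name below is illustrative).

class ClassOnly:
    def __init__(self, name, value):
        self.name, self.value = name, value

    def __get__(self, instance, owner):
        if instance is None:
            return self.value   # class-level access works
        raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only')

class Model:
    marker = ClassOnly('marker', 'class only value')

print(Model.marker)             # 'class only value'
try:
    Model().marker
except AttributeError as exc:
    print(exc)                  # 'marker' attribute of 'Model' is class-only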
+ + >>> a, b = object(), object() + >>> all_identical([a, b, a], [a, b, a]) + True + >>> all_identical([a, b, [a]], [a, b, [a]]) # new list object, while "equal" is not "identical" + False + """ + for left_item, right_item in zip_longest(left, right, fillvalue=_SENTINEL): + if left_item is not right_item: + return False + return True + + +@dataclasses.dataclass(frozen=True) +class SafeGetItemProxy: + """Wrapper redirecting `__getitem__` to `get` with a sentinel value as default + + This makes is safe to use in `operator.itemgetter` when some keys may be missing + """ + + # Define __slots__manually for performances + # @dataclasses.dataclass() only support slots=True in python>=3.10 + __slots__ = ('wrapped',) + + wrapped: Mapping[str, Any] + + def __getitem__(self, __key: str) -> Any: + return self.wrapped.get(__key, _SENTINEL) + + # required to pass the object to operator.itemgetter() instances due to a quirk of typeshed + # https://github.com/python/mypy/issues/13713 + # https://github.com/python/typeshed/pull/8785 + # Since this is typing-only, hide it in a typing.TYPE_CHECKING block + if typing.TYPE_CHECKING: + + def __contains__(self, __key: str) -> bool: + return self.wrapped.__contains__(__key) diff --git a/lib/pydantic/_internal/_validate_call.py b/lib/pydantic/_internal/_validate_call.py new file mode 100644 index 00000000..664c0630 --- /dev/null +++ b/lib/pydantic/_internal/_validate_call.py @@ -0,0 +1,84 @@ +from __future__ import annotations as _annotations + +import inspect +from functools import partial +from typing import Any, Awaitable, Callable + +import pydantic_core + +from ..config import ConfigDict +from ..plugin._schema_validator import create_schema_validator +from . import _generate_schema, _typing_extra +from ._config import ConfigWrapper + + +class ValidateCallWrapper: + """This is a wrapper around a function that validates the arguments passed to it, and optionally the return value.""" + + __slots__ = ( + '__pydantic_validator__', + '__name__', + '__qualname__', + '__annotations__', + '__dict__', # required for __module__ + ) + + def __init__(self, function: Callable[..., Any], config: ConfigDict | None, validate_return: bool): + if isinstance(function, partial): + func = function.func + schema_type = func + self.__name__ = f'partial({func.__name__})' + self.__qualname__ = f'partial({func.__qualname__})' + self.__module__ = func.__module__ + else: + schema_type = function + self.__name__ = function.__name__ + self.__qualname__ = function.__qualname__ + self.__module__ = function.__module__ + + namespace = _typing_extra.add_module_globals(function, None) + config_wrapper = ConfigWrapper(config) + gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace) + schema = gen_schema.clean_schema(gen_schema.generate_schema(function)) + core_config = config_wrapper.core_config(self) + + self.__pydantic_validator__ = create_schema_validator( + schema, + schema_type, + self.__module__, + self.__qualname__, + 'validate_call', + core_config, + config_wrapper.plugin_settings, + ) + + if validate_return: + signature = inspect.signature(function) + return_type = signature.return_annotation if signature.return_annotation is not signature.empty else Any + gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace) + schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type)) + validator = create_schema_validator( + schema, + schema_type, + self.__module__, + self.__qualname__, + 'validate_call', + core_config, + config_wrapper.plugin_settings, 
+ ) + if inspect.iscoroutinefunction(function): + + async def return_val_wrapper(aw: Awaitable[Any]) -> None: + return validator.validate_python(await aw) + + self.__return_pydantic_validator__ = return_val_wrapper + else: + self.__return_pydantic_validator__ = validator.validate_python + else: + self.__return_pydantic_validator__ = None + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs)) + if self.__return_pydantic_validator__: + return self.__return_pydantic_validator__(res) + return res diff --git a/lib/pydantic/_internal/_validators.py b/lib/pydantic/_internal/_validators.py new file mode 100644 index 00000000..7193fe5c --- /dev/null +++ b/lib/pydantic/_internal/_validators.py @@ -0,0 +1,278 @@ +"""Validator functions for standard library types. + +Import of this module is deferred since it contains imports of many standard library modules. +""" + +from __future__ import annotations as _annotations + +import math +import re +import typing +from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network +from typing import Any + +from pydantic_core import PydanticCustomError, core_schema +from pydantic_core._pydantic_core import PydanticKnownError + + +def sequence_validator( + __input_value: typing.Sequence[Any], + validator: core_schema.ValidatorFunctionWrapHandler, +) -> typing.Sequence[Any]: + """Validator for `Sequence` types, isinstance(v, Sequence) has already been called.""" + value_type = type(__input_value) + + # We don't accept any plain string as a sequence + # Relevant issue: https://github.com/pydantic/pydantic/issues/5595 + if issubclass(value_type, (str, bytes)): + raise PydanticCustomError( + 'sequence_str', + "'{type_name}' instances are not allowed as a Sequence value", + {'type_name': value_type.__name__}, + ) + + v_list = validator(__input_value) + + # the rest of the logic is just re-creating the original type from `v_list` + if value_type == list: + return v_list + elif issubclass(value_type, range): + # return the list as we probably can't re-create the range + return v_list + else: + # best guess at how to re-create the original type, more custom construction logic might be required + return value_type(v_list) # type: ignore[call-arg] + + +def import_string(value: Any) -> Any: + if isinstance(value, str): + try: + return _import_string_logic(value) + except ImportError as e: + raise PydanticCustomError('import_error', 'Invalid python path: {error}', {'error': str(e)}) from e + else: + # otherwise we just return the value and let the next validator do the rest of the work + return value + + +def _import_string_logic(dotted_path: str) -> Any: + """Inspired by uvicorn — dotted paths should include a colon before the final item if that item is not a module. + (This is necessary to distinguish between a submodule and an attribute when there is a conflict.). + + If the dotted path does not include a colon and the final item is not a valid module, importing as an attribute + rather than a submodule will be attempted automatically. 
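A hedged usage sketch of the public ImportString type that this import-string logic ultimately powers (the model below is illustrative): both colon-separated and dotted paths resolve to the imported object.

from pydantic import BaseModel, ImportString

class Settings(BaseModel):
    handler: ImportString

print(Settings(handler='math:cos').handler)   # <built-in function cos>
print(Settings(handler='math.cos').handler)   # same object, found via the '.' fallback described above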
+ + So, for example, the following values of `dotted_path` result in the following returned values: + * 'collections': + * 'collections.abc': + * 'collections.abc:Mapping': + * `collections.abc.Mapping`: (though this is a bit slower than the previous line) + + An error will be raised under any of the following scenarios: + * `dotted_path` contains more than one colon (e.g., 'collections:abc:Mapping') + * the substring of `dotted_path` before the colon is not a valid module in the environment (e.g., '123:Mapping') + * the substring of `dotted_path` after the colon is not an attribute of the module (e.g., 'collections:abc123') + """ + from importlib import import_module + + components = dotted_path.strip().split(':') + if len(components) > 2: + raise ImportError(f"Import strings should have at most one ':'; received {dotted_path!r}") + + module_path = components[0] + if not module_path: + raise ImportError(f'Import strings should have a nonempty module name; received {dotted_path!r}') + + try: + module = import_module(module_path) + except ModuleNotFoundError as e: + if '.' in module_path: + # Check if it would be valid if the final item was separated from its module with a `:` + maybe_module_path, maybe_attribute = dotted_path.strip().rsplit('.', 1) + try: + return _import_string_logic(f'{maybe_module_path}:{maybe_attribute}') + except ImportError: + pass + raise ImportError(f'No module named {module_path!r}') from e + raise e + + if len(components) > 1: + attribute = components[1] + try: + return getattr(module, attribute) + except AttributeError as e: + raise ImportError(f'cannot import name {attribute!r} from {module_path!r}') from e + else: + return module + + +def pattern_either_validator(__input_value: Any) -> typing.Pattern[Any]: + if isinstance(__input_value, typing.Pattern): + return __input_value + elif isinstance(__input_value, (str, bytes)): + # todo strict mode + return compile_pattern(__input_value) # type: ignore + else: + raise PydanticCustomError('pattern_type', 'Input should be a valid pattern') + + +def pattern_str_validator(__input_value: Any) -> typing.Pattern[str]: + if isinstance(__input_value, typing.Pattern): + if isinstance(__input_value.pattern, str): + return __input_value + else: + raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern') + elif isinstance(__input_value, str): + return compile_pattern(__input_value) + elif isinstance(__input_value, bytes): + raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern') + else: + raise PydanticCustomError('pattern_type', 'Input should be a valid pattern') + + +def pattern_bytes_validator(__input_value: Any) -> typing.Pattern[bytes]: + if isinstance(__input_value, typing.Pattern): + if isinstance(__input_value.pattern, bytes): + return __input_value + else: + raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern') + elif isinstance(__input_value, bytes): + return compile_pattern(__input_value) + elif isinstance(__input_value, str): + raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern') + else: + raise PydanticCustomError('pattern_type', 'Input should be a valid pattern') + + +PatternType = typing.TypeVar('PatternType', str, bytes) + + +def compile_pattern(pattern: PatternType) -> typing.Pattern[PatternType]: + try: + return re.compile(pattern) + except re.error: + raise PydanticCustomError('pattern_regex', 'Input should be a valid regular expression') + + +def ip_v4_address_validator(__input_value: Any) -> 
IPv4Address: + if isinstance(__input_value, IPv4Address): + return __input_value + + try: + return IPv4Address(__input_value) + except ValueError: + raise PydanticCustomError('ip_v4_address', 'Input is not a valid IPv4 address') + + +def ip_v6_address_validator(__input_value: Any) -> IPv6Address: + if isinstance(__input_value, IPv6Address): + return __input_value + + try: + return IPv6Address(__input_value) + except ValueError: + raise PydanticCustomError('ip_v6_address', 'Input is not a valid IPv6 address') + + +def ip_v4_network_validator(__input_value: Any) -> IPv4Network: + """Assume IPv4Network initialised with a default `strict` argument. + + See more: + https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network + """ + if isinstance(__input_value, IPv4Network): + return __input_value + + try: + return IPv4Network(__input_value) + except ValueError: + raise PydanticCustomError('ip_v4_network', 'Input is not a valid IPv4 network') + + +def ip_v6_network_validator(__input_value: Any) -> IPv6Network: + """Assume IPv6Network initialised with a default `strict` argument. + + See more: + https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network + """ + if isinstance(__input_value, IPv6Network): + return __input_value + + try: + return IPv6Network(__input_value) + except ValueError: + raise PydanticCustomError('ip_v6_network', 'Input is not a valid IPv6 network') + + +def ip_v4_interface_validator(__input_value: Any) -> IPv4Interface: + if isinstance(__input_value, IPv4Interface): + return __input_value + + try: + return IPv4Interface(__input_value) + except ValueError: + raise PydanticCustomError('ip_v4_interface', 'Input is not a valid IPv4 interface') + + +def ip_v6_interface_validator(__input_value: Any) -> IPv6Interface: + if isinstance(__input_value, IPv6Interface): + return __input_value + + try: + return IPv6Interface(__input_value) + except ValueError: + raise PydanticCustomError('ip_v6_interface', 'Input is not a valid IPv6 interface') + + +def greater_than_validator(x: Any, gt: Any) -> Any: + if not (x > gt): + raise PydanticKnownError('greater_than', {'gt': gt}) + return x + + +def greater_than_or_equal_validator(x: Any, ge: Any) -> Any: + if not (x >= ge): + raise PydanticKnownError('greater_than_equal', {'ge': ge}) + return x + + +def less_than_validator(x: Any, lt: Any) -> Any: + if not (x < lt): + raise PydanticKnownError('less_than', {'lt': lt}) + return x + + +def less_than_or_equal_validator(x: Any, le: Any) -> Any: + if not (x <= le): + raise PydanticKnownError('less_than_equal', {'le': le}) + return x + + +def multiple_of_validator(x: Any, multiple_of: Any) -> Any: + if not (x % multiple_of == 0): + raise PydanticKnownError('multiple_of', {'multiple_of': multiple_of}) + return x + + +def min_length_validator(x: Any, min_length: Any) -> Any: + if not (len(x) >= min_length): + raise PydanticKnownError( + 'too_short', + {'field_type': 'Value', 'min_length': min_length, 'actual_length': len(x)}, + ) + return x + + +def max_length_validator(x: Any, max_length: Any) -> Any: + if len(x) > max_length: + raise PydanticKnownError( + 'too_long', + {'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)}, + ) + return x + + +def forbid_inf_nan_check(x: Any) -> Any: + if not math.isfinite(x): + raise PydanticKnownError('finite_number') + return x diff --git a/lib/pydantic/_migration.py b/lib/pydantic/_migration.py new file mode 100644 index 00000000..c8478a62 --- /dev/null +++ b/lib/pydantic/_migration.py @@ -0,0 +1,308 @@ +import sys +from 
typing import Any, Callable, Dict + +from .version import version_short + +MOVED_IN_V2 = { + 'pydantic.utils:version_info': 'pydantic.version:version_info', + 'pydantic.error_wrappers:ValidationError': 'pydantic:ValidationError', + 'pydantic.utils:to_camel': 'pydantic.alias_generators:to_pascal', + 'pydantic.utils:to_lower_camel': 'pydantic.alias_generators:to_camel', + 'pydantic:PyObject': 'pydantic.types:ImportString', + 'pydantic.types:PyObject': 'pydantic.types:ImportString', + 'pydantic.generics:GenericModel': 'pydantic.BaseModel', +} + +DEPRECATED_MOVED_IN_V2 = { + 'pydantic.tools:schema_of': 'pydantic.deprecated.tools:schema_of', + 'pydantic.tools:parse_obj_as': 'pydantic.deprecated.tools:parse_obj_as', + 'pydantic.tools:schema_json_of': 'pydantic.deprecated.tools:schema_json_of', + 'pydantic.json:pydantic_encoder': 'pydantic.deprecated.json:pydantic_encoder', + 'pydantic:validate_arguments': 'pydantic.deprecated.decorator:validate_arguments', + 'pydantic.json:custom_pydantic_encoder': 'pydantic.deprecated.json:custom_pydantic_encoder', + 'pydantic.json:timedelta_isoformat': 'pydantic.deprecated.json:timedelta_isoformat', + 'pydantic.decorator:validate_arguments': 'pydantic.deprecated.decorator:validate_arguments', + 'pydantic.class_validators:validator': 'pydantic.deprecated.class_validators:validator', + 'pydantic.class_validators:root_validator': 'pydantic.deprecated.class_validators:root_validator', + 'pydantic.config:BaseConfig': 'pydantic.deprecated.config:BaseConfig', + 'pydantic.config:Extra': 'pydantic.deprecated.config:Extra', +} + +REDIRECT_TO_V1 = { + f'pydantic.utils:{obj}': f'pydantic.v1.utils:{obj}' + for obj in ( + 'deep_update', + 'GetterDict', + 'lenient_issubclass', + 'lenient_isinstance', + 'is_valid_field', + 'update_not_none', + 'import_string', + 'Representation', + 'ROOT_KEY', + 'smart_deepcopy', + 'sequence_like', + ) +} + + +REMOVED_IN_V2 = { + 'pydantic:ConstrainedBytes', + 'pydantic:ConstrainedDate', + 'pydantic:ConstrainedDecimal', + 'pydantic:ConstrainedFloat', + 'pydantic:ConstrainedFrozenSet', + 'pydantic:ConstrainedInt', + 'pydantic:ConstrainedList', + 'pydantic:ConstrainedSet', + 'pydantic:ConstrainedStr', + 'pydantic:JsonWrapper', + 'pydantic:NoneBytes', + 'pydantic:NoneStr', + 'pydantic:NoneStrBytes', + 'pydantic:Protocol', + 'pydantic:Required', + 'pydantic:StrBytes', + 'pydantic:compiled', + 'pydantic.config:get_config', + 'pydantic.config:inherit_config', + 'pydantic.config:prepare_config', + 'pydantic:create_model_from_namedtuple', + 'pydantic:create_model_from_typeddict', + 'pydantic.dataclasses:create_pydantic_model_from_dataclass', + 'pydantic.dataclasses:make_dataclass_validator', + 'pydantic.dataclasses:set_validation', + 'pydantic.datetime_parse:parse_date', + 'pydantic.datetime_parse:parse_time', + 'pydantic.datetime_parse:parse_datetime', + 'pydantic.datetime_parse:parse_duration', + 'pydantic.error_wrappers:ErrorWrapper', + 'pydantic.errors:AnyStrMaxLengthError', + 'pydantic.errors:AnyStrMinLengthError', + 'pydantic.errors:ArbitraryTypeError', + 'pydantic.errors:BoolError', + 'pydantic.errors:BytesError', + 'pydantic.errors:CallableError', + 'pydantic.errors:ClassError', + 'pydantic.errors:ColorError', + 'pydantic.errors:ConfigError', + 'pydantic.errors:DataclassTypeError', + 'pydantic.errors:DateError', + 'pydantic.errors:DateNotInTheFutureError', + 'pydantic.errors:DateNotInThePastError', + 'pydantic.errors:DateTimeError', + 'pydantic.errors:DecimalError', + 'pydantic.errors:DecimalIsNotFiniteError', + 
'pydantic.errors:DecimalMaxDigitsError', + 'pydantic.errors:DecimalMaxPlacesError', + 'pydantic.errors:DecimalWholeDigitsError', + 'pydantic.errors:DictError', + 'pydantic.errors:DurationError', + 'pydantic.errors:EmailError', + 'pydantic.errors:EnumError', + 'pydantic.errors:EnumMemberError', + 'pydantic.errors:ExtraError', + 'pydantic.errors:FloatError', + 'pydantic.errors:FrozenSetError', + 'pydantic.errors:FrozenSetMaxLengthError', + 'pydantic.errors:FrozenSetMinLengthError', + 'pydantic.errors:HashableError', + 'pydantic.errors:IPv4AddressError', + 'pydantic.errors:IPv4InterfaceError', + 'pydantic.errors:IPv4NetworkError', + 'pydantic.errors:IPv6AddressError', + 'pydantic.errors:IPv6InterfaceError', + 'pydantic.errors:IPv6NetworkError', + 'pydantic.errors:IPvAnyAddressError', + 'pydantic.errors:IPvAnyInterfaceError', + 'pydantic.errors:IPvAnyNetworkError', + 'pydantic.errors:IntEnumError', + 'pydantic.errors:IntegerError', + 'pydantic.errors:InvalidByteSize', + 'pydantic.errors:InvalidByteSizeUnit', + 'pydantic.errors:InvalidDiscriminator', + 'pydantic.errors:InvalidLengthForBrand', + 'pydantic.errors:JsonError', + 'pydantic.errors:JsonTypeError', + 'pydantic.errors:ListError', + 'pydantic.errors:ListMaxLengthError', + 'pydantic.errors:ListMinLengthError', + 'pydantic.errors:ListUniqueItemsError', + 'pydantic.errors:LuhnValidationError', + 'pydantic.errors:MissingDiscriminator', + 'pydantic.errors:MissingError', + 'pydantic.errors:NoneIsAllowedError', + 'pydantic.errors:NoneIsNotAllowedError', + 'pydantic.errors:NotDigitError', + 'pydantic.errors:NotNoneError', + 'pydantic.errors:NumberNotGeError', + 'pydantic.errors:NumberNotGtError', + 'pydantic.errors:NumberNotLeError', + 'pydantic.errors:NumberNotLtError', + 'pydantic.errors:NumberNotMultipleError', + 'pydantic.errors:PathError', + 'pydantic.errors:PathNotADirectoryError', + 'pydantic.errors:PathNotAFileError', + 'pydantic.errors:PathNotExistsError', + 'pydantic.errors:PatternError', + 'pydantic.errors:PyObjectError', + 'pydantic.errors:PydanticTypeError', + 'pydantic.errors:PydanticValueError', + 'pydantic.errors:SequenceError', + 'pydantic.errors:SetError', + 'pydantic.errors:SetMaxLengthError', + 'pydantic.errors:SetMinLengthError', + 'pydantic.errors:StrError', + 'pydantic.errors:StrRegexError', + 'pydantic.errors:StrictBoolError', + 'pydantic.errors:SubclassError', + 'pydantic.errors:TimeError', + 'pydantic.errors:TupleError', + 'pydantic.errors:TupleLengthError', + 'pydantic.errors:UUIDError', + 'pydantic.errors:UUIDVersionError', + 'pydantic.errors:UrlError', + 'pydantic.errors:UrlExtraError', + 'pydantic.errors:UrlHostError', + 'pydantic.errors:UrlHostTldError', + 'pydantic.errors:UrlPortError', + 'pydantic.errors:UrlSchemeError', + 'pydantic.errors:UrlSchemePermittedError', + 'pydantic.errors:UrlUserInfoError', + 'pydantic.errors:WrongConstantError', + 'pydantic.main:validate_model', + 'pydantic.networks:stricturl', + 'pydantic:parse_file_as', + 'pydantic:parse_raw_as', + 'pydantic:stricturl', + 'pydantic.tools:parse_file_as', + 'pydantic.tools:parse_raw_as', + 'pydantic.types:ConstrainedBytes', + 'pydantic.types:ConstrainedDate', + 'pydantic.types:ConstrainedDecimal', + 'pydantic.types:ConstrainedFloat', + 'pydantic.types:ConstrainedFrozenSet', + 'pydantic.types:ConstrainedInt', + 'pydantic.types:ConstrainedList', + 'pydantic.types:ConstrainedSet', + 'pydantic.types:ConstrainedStr', + 'pydantic.types:JsonWrapper', + 'pydantic.types:NoneBytes', + 'pydantic.types:NoneStr', + 'pydantic.types:NoneStrBytes', + 
'pydantic.types:StrBytes', + 'pydantic.typing:evaluate_forwardref', + 'pydantic.typing:AbstractSetIntStr', + 'pydantic.typing:AnyCallable', + 'pydantic.typing:AnyClassMethod', + 'pydantic.typing:CallableGenerator', + 'pydantic.typing:DictAny', + 'pydantic.typing:DictIntStrAny', + 'pydantic.typing:DictStrAny', + 'pydantic.typing:IntStr', + 'pydantic.typing:ListStr', + 'pydantic.typing:MappingIntStrAny', + 'pydantic.typing:NoArgAnyCallable', + 'pydantic.typing:NoneType', + 'pydantic.typing:ReprArgs', + 'pydantic.typing:SetStr', + 'pydantic.typing:StrPath', + 'pydantic.typing:TupleGenerator', + 'pydantic.typing:WithArgsTypes', + 'pydantic.typing:all_literal_values', + 'pydantic.typing:display_as_type', + 'pydantic.typing:get_all_type_hints', + 'pydantic.typing:get_args', + 'pydantic.typing:get_origin', + 'pydantic.typing:get_sub_types', + 'pydantic.typing:is_callable_type', + 'pydantic.typing:is_classvar', + 'pydantic.typing:is_finalvar', + 'pydantic.typing:is_literal_type', + 'pydantic.typing:is_namedtuple', + 'pydantic.typing:is_new_type', + 'pydantic.typing:is_none_type', + 'pydantic.typing:is_typeddict', + 'pydantic.typing:is_typeddict_special', + 'pydantic.typing:is_union', + 'pydantic.typing:new_type_supertype', + 'pydantic.typing:resolve_annotations', + 'pydantic.typing:typing_base', + 'pydantic.typing:update_field_forward_refs', + 'pydantic.typing:update_model_forward_refs', + 'pydantic.utils:ClassAttribute', + 'pydantic.utils:DUNDER_ATTRIBUTES', + 'pydantic.utils:PyObjectStr', + 'pydantic.utils:ValueItems', + 'pydantic.utils:almost_equal_floats', + 'pydantic.utils:get_discriminator_alias_and_values', + 'pydantic.utils:get_model', + 'pydantic.utils:get_unique_discriminator_alias', + 'pydantic.utils:in_ipython', + 'pydantic.utils:is_valid_identifier', + 'pydantic.utils:path_type', + 'pydantic.utils:validate_field_name', + 'pydantic:validate_model', +} + + +def getattr_migration(module: str) -> Callable[[str], Any]: + """Implement PEP 562 for objects that were either moved or removed on the migration + to V2. + + Args: + module: The module name. + + Returns: + A callable that will raise an error if the object is not found. + """ + # This avoids circular import with errors.py. + from .errors import PydanticImportError + + def wrapper(name: str) -> object: + """Raise an error if the object is not found, or warn if it was moved. + + In case it was moved, it still returns the object. + + Args: + name: The object name. + + Returns: + The object. + """ + if name == '__path__': + raise AttributeError(f'module {module!r} has no attribute {name!r}') + + import warnings + + from ._internal._validators import import_string + + import_path = f'{module}:{name}' + if import_path in MOVED_IN_V2.keys(): + new_location = MOVED_IN_V2[import_path] + warnings.warn(f'`{import_path}` has been moved to `{new_location}`.') + return import_string(MOVED_IN_V2[import_path]) + if import_path in DEPRECATED_MOVED_IN_V2: + # skip the warning here because a deprecation warning will be raised elsewhere + return import_string(DEPRECATED_MOVED_IN_V2[import_path]) + if import_path in REDIRECT_TO_V1: + new_location = REDIRECT_TO_V1[import_path] + warnings.warn( + f'`{import_path}` has been removed. We are importing from `{new_location}` instead.' 
+ 'See the migration guide for more details: https://docs.pydantic.dev/latest/migration/' + ) + return import_string(REDIRECT_TO_V1[import_path]) + if import_path == 'pydantic:BaseSettings': + raise PydanticImportError( + '`BaseSettings` has been moved to the `pydantic-settings` package. ' + f'See https://docs.pydantic.dev/{version_short()}/migration/#basesettings-has-moved-to-pydantic-settings ' + 'for more details.' + ) + if import_path in REMOVED_IN_V2: + raise PydanticImportError(f'`{import_path}` has been removed in V2.') + globals: Dict[str, Any] = sys.modules[module].__dict__ + if name in globals: + return globals[name] + raise AttributeError(f'module {module!r} has no attribute {name!r}') + + return wrapper diff --git a/lib/pydantic/alias_generators.py b/lib/pydantic/alias_generators.py new file mode 100644 index 00000000..155e66e0 --- /dev/null +++ b/lib/pydantic/alias_generators.py @@ -0,0 +1,50 @@ +"""Alias generators for converting between different capitalization conventions.""" +import re + +__all__ = ('to_pascal', 'to_camel', 'to_snake') + + +def to_pascal(snake: str) -> str: + """Convert a snake_case string to PascalCase. + + Args: + snake: The string to convert. + + Returns: + The PascalCase string. + """ + camel = snake.title() + return re.sub('([0-9A-Za-z])_(?=[0-9A-Z])', lambda m: m.group(1), camel) + + +def to_camel(snake: str) -> str: + """Convert a snake_case string to camelCase. + + Args: + snake: The string to convert. + + Returns: + The converted camelCase string. + """ + camel = to_pascal(snake) + return re.sub('(^_*[A-Z])', lambda m: m.group(1).lower(), camel) + + +def to_snake(camel: str) -> str: + """Convert a PascalCase or camelCase string to snake_case. + + Args: + camel: The string to convert. + + Returns: + The converted string in snake_case. + """ + # Handle the sequence of uppercase letters followed by a lowercase letter + snake = re.sub(r'([A-Z]+)([A-Z][a-z])', lambda m: f'{m.group(1)}_{m.group(2)}', camel) + # Insert an underscore between a lowercase letter and an uppercase letter + snake = re.sub(r'([a-z])([A-Z])', lambda m: f'{m.group(1)}_{m.group(2)}', snake) + # Insert an underscore between a digit and an uppercase letter + snake = re.sub(r'([0-9])([A-Z])', lambda m: f'{m.group(1)}_{m.group(2)}', snake) + # Insert an underscore between a lowercase letter and a digit + snake = re.sub(r'([a-z])([0-9])', lambda m: f'{m.group(1)}_{m.group(2)}', snake) + return snake.lower() diff --git a/lib/pydantic/aliases.py b/lib/pydantic/aliases.py new file mode 100644 index 00000000..b53557b1 --- /dev/null +++ b/lib/pydantic/aliases.py @@ -0,0 +1,112 @@ +"""Support for alias configurations.""" +from __future__ import annotations + +import dataclasses +from typing import Callable, Literal + +from ._internal import _internal_dataclass + +__all__ = ('AliasGenerator', 'AliasPath', 'AliasChoices') + + +@dataclasses.dataclass(**_internal_dataclass.slots_true) +class AliasPath: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/alias#aliaspath-and-aliaschoices + + A data class used by `validation_alias` as a convenience to create aliases. + + Attributes: + path: A list of string or integer aliases. + """ + + path: list[int | str] + + def __init__(self, first_arg: str, *args: str | int) -> None: + self.path = [first_arg] + list(args) + + def convert_to_aliases(self) -> list[str | int]: + """Converts arguments to a list of string or integer aliases. + + Returns: + The list of aliases. 
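Quick usage sketch of the three alias generators defined above; the expected outputs follow directly from their regular expressions.

from pydantic.alias_generators import to_camel, to_pascal, to_snake

print(to_pascal('snake_case_name'))    # 'SnakeCaseName'
print(to_camel('snake_case_name'))     # 'snakeCaseName'
print(to_snake('HTTPResponseCode'))    # 'http_response_code'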
+ """ + return self.path + + +@dataclasses.dataclass(**_internal_dataclass.slots_true) +class AliasChoices: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/alias#aliaspath-and-aliaschoices + + A data class used by `validation_alias` as a convenience to create aliases. + + Attributes: + choices: A list containing a string or `AliasPath`. + """ + + choices: list[str | AliasPath] + + def __init__(self, first_choice: str | AliasPath, *choices: str | AliasPath) -> None: + self.choices = [first_choice] + list(choices) + + def convert_to_aliases(self) -> list[list[str | int]]: + """Converts arguments to a list of lists containing string or integer aliases. + + Returns: + The list of aliases. + """ + aliases: list[list[str | int]] = [] + for c in self.choices: + if isinstance(c, AliasPath): + aliases.append(c.convert_to_aliases()) + else: + aliases.append([c]) + return aliases + + +@dataclasses.dataclass(**_internal_dataclass.slots_true) +class AliasGenerator: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/alias#using-an-aliasgenerator + + A data class used by `alias_generator` as a convenience to create various aliases. + + Attributes: + alias: A callable that takes a field name and returns an alias for it. + validation_alias: A callable that takes a field name and returns a validation alias for it. + serialization_alias: A callable that takes a field name and returns a serialization alias for it. + """ + + alias: Callable[[str], str] | None = None + validation_alias: Callable[[str], str | AliasPath | AliasChoices] | None = None + serialization_alias: Callable[[str], str] | None = None + + def _generate_alias( + self, + alias_kind: Literal['alias', 'validation_alias', 'serialization_alias'], + allowed_types: tuple[type[str] | type[AliasPath] | type[AliasChoices], ...], + field_name: str, + ) -> str | AliasPath | AliasChoices | None: + """Generate an alias of the specified kind. Returns None if the alias generator is None. + + Raises: + TypeError: If the alias generator produces an invalid type. + """ + alias = None + if alias_generator := getattr(self, alias_kind): + alias = alias_generator(field_name) + if alias and not isinstance(alias, allowed_types): + raise TypeError( + f'Invalid `{alias_kind}` type. `{alias_kind}` generator must produce one of `{allowed_types}`' + ) + return alias + + def generate_aliases(self, field_name: str) -> tuple[str | None, str | AliasPath | AliasChoices | None, str | None]: + """Generate `alias`, `validation_alias`, and `serialization_alias` for a field. + + Returns: + A tuple of three aliases - validation, alias, and serialization. 
+ """ + alias = self._generate_alias('alias', (str,), field_name) + validation_alias = self._generate_alias('validation_alias', (str, AliasChoices, AliasPath), field_name) + serialization_alias = self._generate_alias('serialization_alias', (str,), field_name) + + return alias, validation_alias, serialization_alias # type: ignore diff --git a/lib/pydantic/annotated_handlers.py b/lib/pydantic/annotated_handlers.py new file mode 100644 index 00000000..081949a8 --- /dev/null +++ b/lib/pydantic/annotated_handlers.py @@ -0,0 +1,120 @@ +"""Type annotations to use with `__get_pydantic_core_schema__` and `__get_pydantic_json_schema__`.""" +from __future__ import annotations as _annotations + +from typing import TYPE_CHECKING, Any, Union + +from pydantic_core import core_schema + +if TYPE_CHECKING: + from .json_schema import JsonSchemaMode, JsonSchemaValue + + CoreSchemaOrField = Union[ + core_schema.CoreSchema, + core_schema.ModelField, + core_schema.DataclassField, + core_schema.TypedDictField, + core_schema.ComputedField, + ] + +__all__ = 'GetJsonSchemaHandler', 'GetCoreSchemaHandler' + + +class GetJsonSchemaHandler: + """Handler to call into the next JSON schema generation function. + + Attributes: + mode: Json schema mode, can be `validation` or `serialization`. + """ + + mode: JsonSchemaMode + + def __call__(self, __core_schema: CoreSchemaOrField) -> JsonSchemaValue: + """Call the inner handler and get the JsonSchemaValue it returns. + This will call the next JSON schema modifying function up until it calls + into `pydantic.json_schema.GenerateJsonSchema`, which will raise a + `pydantic.errors.PydanticInvalidForJsonSchema` error if it cannot generate + a JSON schema. + + Args: + __core_schema: A `pydantic_core.core_schema.CoreSchema`. + + Returns: + JsonSchemaValue: The JSON schema generated by the inner JSON schema modify + functions. + """ + raise NotImplementedError + + def resolve_ref_schema(self, __maybe_ref_json_schema: JsonSchemaValue) -> JsonSchemaValue: + """Get the real schema for a `{"$ref": ...}` schema. + If the schema given is not a `$ref` schema, it will be returned as is. + This means you don't have to check before calling this function. + + Args: + __maybe_ref_json_schema: A JsonSchemaValue which may be a `$ref` schema. + + Raises: + LookupError: If the ref is not found. + + Returns: + JsonSchemaValue: A JsonSchemaValue that has no `$ref`. + """ + raise NotImplementedError + + +class GetCoreSchemaHandler: + """Handler to call into the next CoreSchema schema generation function.""" + + def __call__(self, __source_type: Any) -> core_schema.CoreSchema: + """Call the inner handler and get the CoreSchema it returns. + This will call the next CoreSchema modifying function up until it calls + into Pydantic's internal schema generation machinery, which will raise a + `pydantic.errors.PydanticSchemaGenerationError` error if it cannot generate + a CoreSchema for the given source type. + + Args: + __source_type: The input type. + + Returns: + CoreSchema: The `pydantic-core` CoreSchema generated. + """ + raise NotImplementedError + + def generate_schema(self, __source_type: Any) -> core_schema.CoreSchema: + """Generate a schema unrelated to the current context. + Use this function if e.g. you are handling schema generation for a sequence + and want to generate a schema for its items. + Otherwise, you may end up doing something like applying a `min_length` constraint + that was intended for the sequence itself to its items! + + Args: + __source_type: The input type. 
+ + Returns: + CoreSchema: The `pydantic-core` CoreSchema generated. + """ + raise NotImplementedError + + def resolve_ref_schema(self, __maybe_ref_schema: core_schema.CoreSchema) -> core_schema.CoreSchema: + """Get the real schema for a `definition-ref` schema. + If the schema given is not a `definition-ref` schema, it will be returned as is. + This means you don't have to check before calling this function. + + Args: + __maybe_ref_schema: A `CoreSchema`, `ref`-based or not. + + Raises: + LookupError: If the `ref` is not found. + + Returns: + A concrete `CoreSchema`. + """ + raise NotImplementedError + + @property + def field_name(self) -> str | None: + """Get the name of the closest field to this validator.""" + raise NotImplementedError + + def _get_types_namespace(self) -> dict[str, Any] | None: + """Internal method used during type resolution for serializer annotations.""" + raise NotImplementedError diff --git a/lib/pydantic/class_validators.py b/lib/pydantic/class_validators.py index 87190610..2ff72ae5 100644 --- a/lib/pydantic/class_validators.py +++ b/lib/pydantic/class_validators.py @@ -1,342 +1,4 @@ -import warnings -from collections import ChainMap -from functools import wraps -from itertools import chain -from types import FunctionType -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, overload +"""`class_validators` module is a backport module from V1.""" +from ._migration import getattr_migration -from .errors import ConfigError -from .typing import AnyCallable -from .utils import ROOT_KEY, in_ipython - -if TYPE_CHECKING: - from .typing import AnyClassMethod - - -class Validator: - __slots__ = 'func', 'pre', 'each_item', 'always', 'check_fields', 'skip_on_failure' - - def __init__( - self, - func: AnyCallable, - pre: bool = False, - each_item: bool = False, - always: bool = False, - check_fields: bool = False, - skip_on_failure: bool = False, - ): - self.func = func - self.pre = pre - self.each_item = each_item - self.always = always - self.check_fields = check_fields - self.skip_on_failure = skip_on_failure - - -if TYPE_CHECKING: - from inspect import Signature - - from .config import BaseConfig - from .fields import ModelField - from .types import ModelOrDc - - ValidatorCallable = Callable[[Optional[ModelOrDc], Any, Dict[str, Any], ModelField, Type[BaseConfig]], Any] - ValidatorsList = List[ValidatorCallable] - ValidatorListDict = Dict[str, List[Validator]] - -_FUNCS: Set[str] = set() -VALIDATOR_CONFIG_KEY = '__validator_config__' -ROOT_VALIDATOR_CONFIG_KEY = '__root_validator_config__' - - -def validator( - *fields: str, - pre: bool = False, - each_item: bool = False, - always: bool = False, - check_fields: bool = True, - whole: bool = None, - allow_reuse: bool = False, -) -> Callable[[AnyCallable], 'AnyClassMethod']: - """ - Decorate methods on the class indicating that they should be used to validate fields - :param fields: which field(s) the method should be called on - :param pre: whether or not this validator should be called before the standard validators (else after) - :param each_item: for complex objects (sets, lists etc.) 
whether to validate individual elements rather than the - whole object - :param always: whether this method and other validators should be called even if the value is missing - :param check_fields: whether to check that the fields actually exist on the model - :param allow_reuse: whether to track and raise an error if another validator refers to the decorated function - """ - if not fields: - raise ConfigError('validator with no fields specified') - elif isinstance(fields[0], FunctionType): - raise ConfigError( - "validators should be used with fields and keyword arguments, not bare. " # noqa: Q000 - "E.g. usage should be `@validator('', ...)`" - ) - elif not all(isinstance(field, str) for field in fields): - raise ConfigError( - "validator fields should be passed as separate string args. " # noqa: Q000 - "E.g. usage should be `@validator('', '', ...)`" - ) - - if whole is not None: - warnings.warn( - 'The "whole" keyword argument is deprecated, use "each_item" (inverse meaning, default False) instead', - DeprecationWarning, - ) - assert each_item is False, '"each_item" and "whole" conflict, remove "whole"' - each_item = not whole - - def dec(f: AnyCallable) -> 'AnyClassMethod': - f_cls = _prepare_validator(f, allow_reuse) - setattr( - f_cls, - VALIDATOR_CONFIG_KEY, - ( - fields, - Validator(func=f_cls.__func__, pre=pre, each_item=each_item, always=always, check_fields=check_fields), - ), - ) - return f_cls - - return dec - - -@overload -def root_validator(_func: AnyCallable) -> 'AnyClassMethod': - ... - - -@overload -def root_validator( - *, pre: bool = False, allow_reuse: bool = False, skip_on_failure: bool = False -) -> Callable[[AnyCallable], 'AnyClassMethod']: - ... - - -def root_validator( - _func: Optional[AnyCallable] = None, *, pre: bool = False, allow_reuse: bool = False, skip_on_failure: bool = False -) -> Union['AnyClassMethod', Callable[[AnyCallable], 'AnyClassMethod']]: - """ - Decorate methods on a model indicating that they should be used to validate (and perhaps modify) data either - before or after standard model parsing/validation is performed. - """ - if _func: - f_cls = _prepare_validator(_func, allow_reuse) - setattr( - f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure) - ) - return f_cls - - def dec(f: AnyCallable) -> 'AnyClassMethod': - f_cls = _prepare_validator(f, allow_reuse) - setattr( - f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure) - ) - return f_cls - - return dec - - -def _prepare_validator(function: AnyCallable, allow_reuse: bool) -> 'AnyClassMethod': - """ - Avoid validators with duplicated names since without this, validators can be overwritten silently - which generally isn't the intended behaviour, don't run in ipython (see #312) or if allow_reuse is False. - """ - f_cls = function if isinstance(function, classmethod) else classmethod(function) - if not in_ipython() and not allow_reuse: - ref = f_cls.__func__.__module__ + '.' 
+ f_cls.__func__.__qualname__ - if ref in _FUNCS: - raise ConfigError(f'duplicate validator function "{ref}"; if this is intended, set `allow_reuse=True`') - _FUNCS.add(ref) - return f_cls - - -class ValidatorGroup: - def __init__(self, validators: 'ValidatorListDict') -> None: - self.validators = validators - self.used_validators = {'*'} - - def get_validators(self, name: str) -> Optional[Dict[str, Validator]]: - self.used_validators.add(name) - validators = self.validators.get(name, []) - if name != ROOT_KEY: - validators += self.validators.get('*', []) - if validators: - return {v.func.__name__: v for v in validators} - else: - return None - - def check_for_unused(self) -> None: - unused_validators = set( - chain.from_iterable( - (v.func.__name__ for v in self.validators[f] if v.check_fields) - for f in (self.validators.keys() - self.used_validators) - ) - ) - if unused_validators: - fn = ', '.join(unused_validators) - raise ConfigError( - f"Validators defined with incorrect fields: {fn} " # noqa: Q000 - f"(use check_fields=False if you're inheriting from the model and intended this)" - ) - - -def extract_validators(namespace: Dict[str, Any]) -> Dict[str, List[Validator]]: - validators: Dict[str, List[Validator]] = {} - for var_name, value in namespace.items(): - validator_config = getattr(value, VALIDATOR_CONFIG_KEY, None) - if validator_config: - fields, v = validator_config - for field in fields: - if field in validators: - validators[field].append(v) - else: - validators[field] = [v] - return validators - - -def extract_root_validators(namespace: Dict[str, Any]) -> Tuple[List[AnyCallable], List[Tuple[bool, AnyCallable]]]: - from inspect import signature - - pre_validators: List[AnyCallable] = [] - post_validators: List[Tuple[bool, AnyCallable]] = [] - for name, value in namespace.items(): - validator_config: Optional[Validator] = getattr(value, ROOT_VALIDATOR_CONFIG_KEY, None) - if validator_config: - sig = signature(validator_config.func) - args = list(sig.parameters.keys()) - if args[0] == 'self': - raise ConfigError( - f'Invalid signature for root validator {name}: {sig}, "self" not permitted as first argument, ' - f'should be: (cls, values).' - ) - if len(args) != 2: - raise ConfigError(f'Invalid signature for root validator {name}: {sig}, should be: (cls, values).') - # check function signature - if validator_config.pre: - pre_validators.append(validator_config.func) - else: - post_validators.append((validator_config.skip_on_failure, validator_config.func)) - return pre_validators, post_validators - - -def inherit_validators(base_validators: 'ValidatorListDict', validators: 'ValidatorListDict') -> 'ValidatorListDict': - for field, field_validators in base_validators.items(): - if field not in validators: - validators[field] = [] - validators[field] += field_validators - return validators - - -def make_generic_validator(validator: AnyCallable) -> 'ValidatorCallable': - """ - Make a generic function which calls a validator with the right arguments. - - Unfortunately other approaches (eg. return a partial of a function that builds the arguments) is slow, - hence this laborious way of doing things. - - It's done like this so validators don't all need **kwargs in their signature, eg. any combination of - the arguments "values", "fields" and/or "config" are permitted. 
- """ - from inspect import signature - - sig = signature(validator) - args = list(sig.parameters.keys()) - first_arg = args.pop(0) - if first_arg == 'self': - raise ConfigError( - f'Invalid signature for validator {validator}: {sig}, "self" not permitted as first argument, ' - f'should be: (cls, value, values, config, field), "values", "config" and "field" are all optional.' - ) - elif first_arg == 'cls': - # assume the second argument is value - return wraps(validator)(_generic_validator_cls(validator, sig, set(args[1:]))) - else: - # assume the first argument was value which has already been removed - return wraps(validator)(_generic_validator_basic(validator, sig, set(args))) - - -def prep_validators(v_funcs: Iterable[AnyCallable]) -> 'ValidatorsList': - return [make_generic_validator(f) for f in v_funcs if f] - - -all_kwargs = {'values', 'field', 'config'} - - -def _generic_validator_cls(validator: AnyCallable, sig: 'Signature', args: Set[str]) -> 'ValidatorCallable': - # assume the first argument is value - has_kwargs = False - if 'kwargs' in args: - has_kwargs = True - args -= {'kwargs'} - - if not args.issubset(all_kwargs): - raise ConfigError( - f'Invalid signature for validator {validator}: {sig}, should be: ' - f'(cls, value, values, config, field), "values", "config" and "field" are all optional.' - ) - - if has_kwargs: - return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field, config=config) - elif args == set(): - return lambda cls, v, values, field, config: validator(cls, v) - elif args == {'values'}: - return lambda cls, v, values, field, config: validator(cls, v, values=values) - elif args == {'field'}: - return lambda cls, v, values, field, config: validator(cls, v, field=field) - elif args == {'config'}: - return lambda cls, v, values, field, config: validator(cls, v, config=config) - elif args == {'values', 'field'}: - return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field) - elif args == {'values', 'config'}: - return lambda cls, v, values, field, config: validator(cls, v, values=values, config=config) - elif args == {'field', 'config'}: - return lambda cls, v, values, field, config: validator(cls, v, field=field, config=config) - else: - # args == {'values', 'field', 'config'} - return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field, config=config) - - -def _generic_validator_basic(validator: AnyCallable, sig: 'Signature', args: Set[str]) -> 'ValidatorCallable': - has_kwargs = False - if 'kwargs' in args: - has_kwargs = True - args -= {'kwargs'} - - if not args.issubset(all_kwargs): - raise ConfigError( - f'Invalid signature for validator {validator}: {sig}, should be: ' - f'(value, values, config, field), "values", "config" and "field" are all optional.' 
- ) - - if has_kwargs: - return lambda cls, v, values, field, config: validator(v, values=values, field=field, config=config) - elif args == set(): - return lambda cls, v, values, field, config: validator(v) - elif args == {'values'}: - return lambda cls, v, values, field, config: validator(v, values=values) - elif args == {'field'}: - return lambda cls, v, values, field, config: validator(v, field=field) - elif args == {'config'}: - return lambda cls, v, values, field, config: validator(v, config=config) - elif args == {'values', 'field'}: - return lambda cls, v, values, field, config: validator(v, values=values, field=field) - elif args == {'values', 'config'}: - return lambda cls, v, values, field, config: validator(v, values=values, config=config) - elif args == {'field', 'config'}: - return lambda cls, v, values, field, config: validator(v, field=field, config=config) - else: - # args == {'values', 'field', 'config'} - return lambda cls, v, values, field, config: validator(v, values=values, field=field, config=config) - - -def gather_all_validators(type_: 'ModelOrDc') -> Dict[str, 'AnyClassMethod']: - all_attributes = ChainMap(*[cls.__dict__ for cls in type_.__mro__]) # type: ignore[arg-type,var-annotated] - return { - k: v - for k, v in all_attributes.items() - if hasattr(v, VALIDATOR_CONFIG_KEY) or hasattr(v, ROOT_VALIDATOR_CONFIG_KEY) - } +__getattr__ = getattr_migration(__name__) diff --git a/lib/pydantic/color.py b/lib/pydantic/color.py index 6fdc9fb1..108bb8fa 100644 --- a/lib/pydantic/color.py +++ b/lib/pydantic/color.py @@ -1,22 +1,28 @@ -""" -Color definitions are used as per CSS3 specification: -http://www.w3.org/TR/css3-color/#svg-color +"""Color definitions are used as per the CSS3 +[CSS Color Module Level 3](http://www.w3.org/TR/css3-color/#svg-color) specification. A few colors have multiple names referring to the sames colors, eg. `grey` and `gray` or `aqua` and `cyan`. -In these cases the LAST color when sorted alphabetically takes preferences, -eg. Color((0, 255, 255)).as_named() == 'cyan' because "cyan" comes after "aqua". +In these cases the _last_ color when sorted alphabetically takes preferences, +eg. `Color((0, 255, 255)).as_named() == 'cyan'` because "cyan" comes after "aqua". + +Warning: Deprecated + The `Color` class is deprecated, use `pydantic_extra_types` instead. + See [`pydantic-extra-types.Color`](../usage/types/extra_types/color_types.md) + for more information. """ import math import re from colorsys import hls_to_rgb, rgb_to_hls -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast +from typing import Any, Callable, Optional, Tuple, Type, Union, cast -from .errors import ColorError -from .utils import Representation, almost_equal_floats +from pydantic_core import CoreSchema, PydanticCustomError, core_schema +from typing_extensions import deprecated -if TYPE_CHECKING: - from .typing import CallableGenerator, ReprArgs +from ._internal import _repr +from ._internal._schema_generation_shared import GetJsonSchemaHandler as _GetJsonSchemaHandler +from .json_schema import JsonSchemaValue +from .warnings import PydanticDeprecatedSince20 ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, float]] ColorType = Union[ColorTuple, str] @@ -24,9 +30,7 @@ HslColorTuple = Union[Tuple[float, float, float], Tuple[float, float, float, flo class RGBA: - """ - Internal use only as a representation of a color. 
- """ + """Internal use only as a representation of a color.""" __slots__ = 'r', 'g', 'b', 'alpha', '_tuple' @@ -43,24 +47,35 @@ class RGBA: # these are not compiled here to avoid import slowdown, they'll be compiled the first time they're used, then cached -r_hex_short = r'\s*(?:#|0x)?([0-9a-f])([0-9a-f])([0-9a-f])([0-9a-f])?\s*' -r_hex_long = r'\s*(?:#|0x)?([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})?\s*' _r_255 = r'(\d{1,3}(?:\.\d+)?)' _r_comma = r'\s*,\s*' -r_rgb = fr'\s*rgb\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}\)\s*' _r_alpha = r'(\d(?:\.\d+)?|\.\d+|\d{1,2}%)' -r_rgba = fr'\s*rgba\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_alpha}\s*\)\s*' _r_h = r'(-?\d+(?:\.\d+)?|-?\.\d+)(deg|rad|turn)?' _r_sl = r'(\d{1,3}(?:\.\d+)?)%' -r_hsl = fr'\s*hsl\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}\s*\)\s*' -r_hsla = fr'\s*hsl\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}{_r_comma}{_r_alpha}\s*\)\s*' +r_hex_short = r'\s*(?:#|0x)?([0-9a-f])([0-9a-f])([0-9a-f])([0-9a-f])?\s*' +r_hex_long = r'\s*(?:#|0x)?([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})?\s*' +# CSS3 RGB examples: rgb(0, 0, 0), rgba(0, 0, 0, 0.5), rgba(0, 0, 0, 50%) +r_rgb = rf'\s*rgba?\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}(?:{_r_comma}{_r_alpha})?\s*\)\s*' +# CSS3 HSL examples: hsl(270, 60%, 50%), hsla(270, 60%, 50%, 0.5), hsla(270, 60%, 50%, 50%) +r_hsl = rf'\s*hsla?\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}(?:{_r_comma}{_r_alpha})?\s*\)\s*' +# CSS4 RGB examples: rgb(0 0 0), rgb(0 0 0 / 0.5), rgb(0 0 0 / 50%), rgba(0 0 0 / 50%) +r_rgb_v4_style = rf'\s*rgba?\(\s*{_r_255}\s+{_r_255}\s+{_r_255}(?:\s*/\s*{_r_alpha})?\s*\)\s*' +# CSS4 HSL examples: hsl(270 60% 50%), hsl(270 60% 50% / 0.5), hsl(270 60% 50% / 50%), hsla(270 60% 50% / 50%) +r_hsl_v4_style = rf'\s*hsla?\(\s*{_r_h}\s+{_r_sl}\s+{_r_sl}(?:\s*/\s*{_r_alpha})?\s*\)\s*' # colors where the two hex characters are the same, if all colors match this the short version of hex colors can be used repeat_colors = {int(c * 2, 16) for c in '0123456789abcdef'} rads = 2 * math.pi -class Color(Representation): +@deprecated( + 'The `Color` class is deprecated, use `pydantic_extra_types` instead. ' + 'See https://docs.pydantic.dev/latest/api/pydantic_extra_types_color/.', + category=PydanticDeprecatedSince20, +) +class Color(_repr.Representation): + """Represents a color.""" + __slots__ = '_original', '_rgba' def __init__(self, value: ColorType) -> None: @@ -74,22 +89,39 @@ class Color(Representation): self._rgba = value._rgba value = value._original else: - raise ColorError(reason='value must be a tuple, list or string') + raise PydanticCustomError( + 'color_error', 'value is not a valid color: value must be a tuple, list or string' + ) # if we've got here value must be a valid color self._original = value @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + def __get_pydantic_json_schema__( + cls, core_schema: core_schema.CoreSchema, handler: _GetJsonSchemaHandler + ) -> JsonSchemaValue: + field_schema = {} field_schema.update(type='string', format='color') + return field_schema def original(self) -> ColorType: - """ - Original value passed to Color - """ + """Original value passed to `Color`.""" return self._original def as_named(self, *, fallback: bool = False) -> str: + """Returns the name of the color if it can be found in `COLORS_BY_VALUE` dictionary, + otherwise returns the hexadecimal representation of the color or raises `ValueError`. 
+ + Args: + fallback: If True, falls back to returning the hexadecimal representation of + the color instead of raising a ValueError when no named color is found. + + Returns: + The name of the color, or the hexadecimal representation of the color. + + Raises: + ValueError: When no named color is found and fallback is `False`. + """ if self._rgba.alpha is None: rgb = cast(Tuple[int, int, int], self.as_rgb_tuple()) try: @@ -103,9 +135,13 @@ class Color(Representation): return self.as_hex() def as_hex(self) -> str: - """ - Hex string representing the color can be 3, 4, 6 or 8 characters depending on whether the string + """Returns the hexadecimal representation of the color. + + Hex string representing the color can be 3, 4, 6, or 8 characters depending on whether the string a "short" representation of the color is possible and whether there's an alpha channel. + + Returns: + The hexadecimal representation of the color. """ values = [float_to_255(c) for c in self._rgba[:3]] if self._rgba.alpha is not None: @@ -117,9 +153,7 @@ class Color(Representation): return '#' + as_hex def as_rgb(self) -> str: - """ - Color as an rgb(, , ) or rgba(, , , ) string. - """ + """Color as an `rgb(, , )` or `rgba(, , , )` string.""" if self._rgba.alpha is None: return f'rgb({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)})' else: @@ -129,14 +163,18 @@ class Color(Representation): ) def as_rgb_tuple(self, *, alpha: Optional[bool] = None) -> ColorTuple: - """ - Color as an RGB or RGBA tuple; red, green and blue are in the range 0 to 255, alpha if included is - in the range 0 to 1. + """Returns the color as an RGB or RGBA tuple. - :param alpha: whether to include the alpha channel, options are - None - (default) include alpha only if it's set (e.g. not None) - True - always include alpha, - False - always omit alpha, + Args: + alpha: Whether to include the alpha channel. There are three options for this input: + + - `None` (default): Include alpha only if it's set. (e.g. not `None`) + - `True`: Always include alpha. + - `False`: Always omit alpha. + + Returns: + A tuple that contains the values of the red, green, and blue channels in the range 0 to 255. + If alpha is included, it is in the range 0 to 1. """ r, g, b = (float_to_255(c) for c in self._rgba[:3]) if alpha is None: @@ -151,9 +189,7 @@ class Color(Representation): return r, g, b def as_hsl(self) -> str: - """ - Color as an hsl(, , ) or hsl(, , , ) string. - """ + """Color as an `hsl(, , )` or `hsl(, , , )` string.""" if self._rgba.alpha is None: h, s, li = self.as_hsl_tuple(alpha=False) # type: ignore return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%})' @@ -162,18 +198,23 @@ class Color(Representation): return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%}, {round(a, 2)})' def as_hsl_tuple(self, *, alpha: Optional[bool] = None) -> HslColorTuple: - """ - Color as an HSL or HSLA tuple, e.g. hue, saturation, lightness and optionally alpha; all elements are in - the range 0 to 1. + """Returns the color as an HSL or HSLA tuple. - NOTE: this is HSL as used in HTML and most other places, not HLS as used in python's colorsys. + Args: + alpha: Whether to include the alpha channel. - :param alpha: whether to include the alpha channel, options are - None - (default) include alpha only if it's set (e.g. not None) - True - always include alpha, - False - always omit alpha, + - `None` (default): Include the alpha channel only if it's set (e.g. not `None`). + - `True`: Always include alpha. + - `False`: Always omit alpha. 
+ + Returns: + The color as a tuple of hue, saturation, lightness, and alpha (if included). + All elements are in the range 0 to 1. + + Note: + This is HSL as used in HTML and most other places, not HLS as used in Python's `colorsys`. """ - h, l, s = rgb_to_hls(self._rgba.r, self._rgba.g, self._rgba.b) + h, l, s = rgb_to_hls(self._rgba.r, self._rgba.g, self._rgba.b) # noqa: E741 if alpha is None: if self._rgba.alpha is None: return h, s, l @@ -189,14 +230,22 @@ class Color(Representation): return 1 if self._rgba.alpha is None else self._rgba.alpha @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield cls + def __get_pydantic_core_schema__( + cls, source: Type[Any], handler: Callable[[Any], CoreSchema] + ) -> core_schema.CoreSchema: + return core_schema.with_info_plain_validator_function( + cls._validate, serialization=core_schema.to_string_ser_schema() + ) + + @classmethod + def _validate(cls, __input_value: Any, _: Any) -> 'Color': + return cls(__input_value) def __str__(self) -> str: return self.as_named(fallback=True) - def __repr_args__(self) -> 'ReprArgs': - return [(None, self.as_named(fallback=True))] + [('rgb', self.as_rgb_tuple())] # type: ignore + def __repr_args__(self) -> '_repr.ReprArgs': + return [(None, self.as_named(fallback=True))] + [('rgb', self.as_rgb_tuple())] def __eq__(self, other: Any) -> bool: return isinstance(other, Color) and self.as_rgb_tuple() == other.as_rgb_tuple() @@ -206,8 +255,16 @@ class Color(Representation): def parse_tuple(value: Tuple[Any, ...]) -> RGBA: - """ - Parse a tuple or list as a color. + """Parse a tuple or list to get RGBA values. + + Args: + value: A tuple or list. + + Returns: + An `RGBA` tuple parsed from the input tuple. + + Raises: + PydanticCustomError: If tuple is not valid. """ if len(value) == 3: r, g, b = (parse_color_value(v) for v in value) @@ -216,17 +273,28 @@ def parse_tuple(value: Tuple[Any, ...]) -> RGBA: r, g, b = (parse_color_value(v) for v in value[:3]) return RGBA(r, g, b, parse_float_alpha(value[3])) else: - raise ColorError(reason='tuples must have length 3 or 4') + raise PydanticCustomError('color_error', 'value is not a valid color: tuples must have length 3 or 4') def parse_str(value: str) -> RGBA: - """ - Parse a string to an RGBA tuple, trying the following formats (in this order): - * named color, see COLORS_BY_NAME below + """Parse a string representing a color to an RGBA tuple. + + Possible formats for the input string include: + + * named color, see `COLORS_BY_NAME` * hex short eg. `fff` (prefix can be `#`, `0x` or nothing) * hex long eg. `ffffff` (prefix can be `#`, `0x` or nothing) - * `rgb(, , ) ` + * `rgb(, , )` * `rgba(, , , )` + + Args: + value: A string representing a color. + + Returns: + An `RGBA` tuple parsed from the input string. + + Raises: + ValueError: If the input string cannot be parsed to an RGBA tuple. 
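Taken together, the parsing helpers above let `Color` accept named, hex, and both CSS3- and CSS4-style strings. A rough illustration (note the class is deprecated in favour of `pydantic_extra_types`, so constructing it emits a deprecation warning):

```py
from pydantic.color import Color

# "cyan" wins over "aqua": the last name alphabetically takes preference
print(Color('#00ffff').as_named())
#> cyan
print(Color('#00ffff').as_rgb_tuple())
#> (0, 255, 255)
# CSS4-style syntax with a slash-separated alpha channel
print(Color('rgb(0 255 255 / 50%)').as_rgb_tuple())
#> (0, 255, 255, 0.5)
```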
""" value_lower = value.lower() try: @@ -256,49 +324,70 @@ def parse_str(value: str) -> RGBA: alpha = None return ints_to_rgba(r, g, b, alpha) - m = re.fullmatch(r_rgb, value_lower) - if m: - return ints_to_rgba(*m.groups(), None) # type: ignore - - m = re.fullmatch(r_rgba, value_lower) + m = re.fullmatch(r_rgb, value_lower) or re.fullmatch(r_rgb_v4_style, value_lower) if m: return ints_to_rgba(*m.groups()) # type: ignore - m = re.fullmatch(r_hsl, value_lower) + m = re.fullmatch(r_hsl, value_lower) or re.fullmatch(r_hsl_v4_style, value_lower) if m: - h, h_units, s, l_ = m.groups() - return parse_hsl(h, h_units, s, l_) + return parse_hsl(*m.groups()) # type: ignore - m = re.fullmatch(r_hsla, value_lower) - if m: - h, h_units, s, l_, a = m.groups() - return parse_hsl(h, h_units, s, l_, parse_float_alpha(a)) - - raise ColorError(reason='string not recognised as a valid color') + raise PydanticCustomError('color_error', 'value is not a valid color: string not recognised as a valid color') -def ints_to_rgba(r: Union[int, str], g: Union[int, str], b: Union[int, str], alpha: Optional[float]) -> RGBA: +def ints_to_rgba(r: Union[int, str], g: Union[int, str], b: Union[int, str], alpha: Optional[float] = None) -> RGBA: + """Converts integer or string values for RGB color and an optional alpha value to an `RGBA` object. + + Args: + r: An integer or string representing the red color value. + g: An integer or string representing the green color value. + b: An integer or string representing the blue color value. + alpha: A float representing the alpha value. Defaults to None. + + Returns: + An instance of the `RGBA` class with the corresponding color and alpha values. + """ return RGBA(parse_color_value(r), parse_color_value(g), parse_color_value(b), parse_float_alpha(alpha)) def parse_color_value(value: Union[int, str], max_val: int = 255) -> float: - """ - Parse a value checking it's a valid int in the range 0 to max_val and divide by max_val to give a number - in the range 0 to 1 + """Parse the color value provided and return a number between 0 and 1. + + Args: + value: An integer or string color value. + max_val: Maximum range value. Defaults to 255. + + Raises: + PydanticCustomError: If the value is not a valid color. + + Returns: + A number between 0 and 1. """ try: color = float(value) except ValueError: - raise ColorError(reason='color values must be a valid number') + raise PydanticCustomError('color_error', 'value is not a valid color: color values must be a valid number') if 0 <= color <= max_val: return color / max_val else: - raise ColorError(reason=f'color values must be in the range 0 to {max_val}') + raise PydanticCustomError( + 'color_error', + 'value is not a valid color: color values must be in the range 0 to {max_val}', + {'max_val': max_val}, + ) def parse_float_alpha(value: Union[None, str, float, int]) -> Optional[float]: - """ - Parse a value checking it's a valid float in the range 0 to 1 + """Parse an alpha value checking it's a valid float in the range 0 to 1. + + Args: + value: The input value to parse. + + Returns: + The parsed value as a float, or `None` if the value was None or equal 1. + + Raises: + PydanticCustomError: If the input value cannot be successfully parsed as a float in the expected range. 
""" if value is None: return None @@ -308,19 +397,28 @@ def parse_float_alpha(value: Union[None, str, float, int]) -> Optional[float]: else: alpha = float(value) except ValueError: - raise ColorError(reason='alpha values must be a valid float') + raise PydanticCustomError('color_error', 'value is not a valid color: alpha values must be a valid float') - if almost_equal_floats(alpha, 1): + if math.isclose(alpha, 1): return None elif 0 <= alpha <= 1: return alpha else: - raise ColorError(reason='alpha values must be in the range 0 to 1') + raise PydanticCustomError('color_error', 'value is not a valid color: alpha values must be in the range 0 to 1') def parse_hsl(h: str, h_units: str, sat: str, light: str, alpha: Optional[float] = None) -> RGBA: - """ - Parse raw hue, saturation, lightness and alpha values and convert to RGBA. + """Parse raw hue, saturation, lightness, and alpha values and convert to RGBA. + + Args: + h: The hue value. + h_units: The unit for hue value. + sat: The saturation value. + light: The lightness value. + alpha: Alpha value. + + Returns: + An instance of `RGBA`. """ s_value, l_value = parse_color_value(sat, 100), parse_color_value(light, 100) @@ -334,10 +432,21 @@ def parse_hsl(h: str, h_units: str, sat: str, light: str, alpha: Optional[float] h_value = h_value % 1 r, g, b = hls_to_rgb(h_value, l_value, s_value) - return RGBA(r, g, b, alpha) + return RGBA(r, g, b, parse_float_alpha(alpha)) def float_to_255(c: float) -> int: + """Converts a float value between 0 and 1 (inclusive) to an integer between 0 and 255 (inclusive). + + Args: + c: The float value to be converted. Must be between 0 and 1 (inclusive). + + Returns: + The integer equivalent of the given float value rounded to the nearest whole number. + + Raises: + ValueError: If the given float value is outside the acceptable range of 0 to 1 (inclusive). 
+ """ return int(round(c * 255)) diff --git a/lib/pydantic/config.py b/lib/pydantic/config.py index 74687ca0..6b22586b 100644 --- a/lib/pydantic/config.py +++ b/lib/pydantic/config.py @@ -1,192 +1,912 @@ -import json -from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Dict, ForwardRef, Optional, Tuple, Type, Union +"""Configuration for Pydantic models.""" +from __future__ import annotations as _annotations -from typing_extensions import Literal, Protocol +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Type, Union -from .typing import AnyArgTCallable, AnyCallable -from .utils import GetterDict -from .version import compiled +from typing_extensions import Literal, TypeAlias, TypedDict + +from ._migration import getattr_migration +from .aliases import AliasGenerator if TYPE_CHECKING: - from typing import overload + from ._internal._generate_schema import GenerateSchema as _GenerateSchema - from .fields import ModelField - from .main import BaseModel - - ConfigType = Type['BaseConfig'] - - class SchemaExtraCallable(Protocol): - @overload - def __call__(self, schema: Dict[str, Any]) -> None: - pass - - @overload - def __call__(self, schema: Dict[str, Any], model_class: Type[BaseModel]) -> None: - pass - -else: - SchemaExtraCallable = Callable[..., None] - -__all__ = 'BaseConfig', 'ConfigDict', 'get_config', 'Extra', 'inherit_config', 'prepare_config' +__all__ = ('ConfigDict',) -class Extra(str, Enum): - allow = 'allow' - ignore = 'ignore' - forbid = 'forbid' +JsonValue: TypeAlias = Union[int, float, str, bool, None, List['JsonValue'], 'JsonDict'] +JsonDict: TypeAlias = Dict[str, JsonValue] + +JsonEncoder = Callable[[Any], Any] + +JsonSchemaExtraCallable: TypeAlias = Union[ + Callable[[JsonDict], None], + Callable[[JsonDict, Type[Any]], None], +] + +ExtraValues = Literal['allow', 'ignore', 'forbid'] -# https://github.com/cython/cython/issues/4003 -# Will be fixed with Cython 3 but still in alpha right now -if not compiled: - from typing_extensions import TypedDict +class ConfigDict(TypedDict, total=False): + """A TypedDict for configuring Pydantic behaviour.""" - class ConfigDict(TypedDict, total=False): - title: Optional[str] - anystr_lower: bool - anystr_strip_whitespace: bool - min_anystr_length: int - max_anystr_length: Optional[int] - validate_all: bool - extra: Extra - allow_mutation: bool - frozen: bool - allow_population_by_field_name: bool - use_enum_values: bool - fields: Dict[str, Union[str, Dict[str, str]]] - validate_assignment: bool - error_msg_templates: Dict[str, str] - arbitrary_types_allowed: bool - orm_mode: bool - getter_dict: Type[GetterDict] - alias_generator: Optional[Callable[[str], str]] - keep_untouched: Tuple[type, ...] - schema_extra: Union[Dict[str, object], 'SchemaExtraCallable'] - json_loads: Callable[[str], object] - json_dumps: AnyArgTCallable[str] - json_encoders: Dict[Type[object], AnyCallable] - underscore_attrs_are_private: bool - allow_inf_nan: bool + title: str | None + """The title for the generated JSON schema, defaults to the model's name""" - # whether or not inherited models as fields should be reconstructed as base model - copy_on_model_validation: bool - # whether dataclass `__post_init__` should be run after validation - post_init_call: Literal['before_validation', 'after_validation'] + str_to_lower: bool + """Whether to convert all characters to lowercase for str types. 
Defaults to `False`.""" -else: - ConfigDict = dict # type: ignore + str_to_upper: bool + """Whether to convert all characters to uppercase for str types. Defaults to `False`.""" + str_strip_whitespace: bool + """Whether to strip leading and trailing whitespace for str types.""" + + str_min_length: int + """The minimum length for str types. Defaults to `None`.""" + + str_max_length: int | None + """The maximum length for str types. Defaults to `None`.""" + + extra: ExtraValues | None + """ + Whether to ignore, allow, or forbid extra attributes during model initialization. Defaults to `'ignore'`. + + You can configure how pydantic handles the attributes that are not defined in the model: + + * `allow` - Allow any extra attributes. + * `forbid` - Forbid any extra attributes. + * `ignore` - Ignore any extra attributes. + + ```py + from pydantic import BaseModel, ConfigDict -class BaseConfig: - title: Optional[str] = None - anystr_lower: bool = False - anystr_upper: bool = False - anystr_strip_whitespace: bool = False - min_anystr_length: int = 0 - max_anystr_length: Optional[int] = None - validate_all: bool = False - extra: Extra = Extra.ignore - allow_mutation: bool = True - frozen: bool = False - allow_population_by_field_name: bool = False - use_enum_values: bool = False - fields: Dict[str, Union[str, Dict[str, str]]] = {} - validate_assignment: bool = False - error_msg_templates: Dict[str, str] = {} - arbitrary_types_allowed: bool = False - orm_mode: bool = False - getter_dict: Type[GetterDict] = GetterDict - alias_generator: Optional[Callable[[str], str]] = None - keep_untouched: Tuple[type, ...] = () - schema_extra: Union[Dict[str, Any], 'SchemaExtraCallable'] = {} - json_loads: Callable[[str], Any] = json.loads - json_dumps: Callable[..., str] = json.dumps - json_encoders: Dict[Union[Type[Any], str, ForwardRef], AnyCallable] = {} - underscore_attrs_are_private: bool = False - allow_inf_nan: bool = True + class User(BaseModel): + model_config = ConfigDict(extra='ignore') # (1)! - # whether inherited models as fields should be reconstructed as base model, - # and whether such a copy should be shallow or deep - copy_on_model_validation: Literal['none', 'deep', 'shallow'] = 'shallow' - - # whether `Union` should check all allowed types before even trying to coerce - smart_union: bool = False - # whether dataclass `__post_init__` should be run before or after validation - post_init_call: Literal['before_validation', 'after_validation'] = 'before_validation' - - @classmethod - def get_field_info(cls, name: str) -> Dict[str, Any]: - """ - Get properties of FieldInfo from the `fields` property of the config class. - """ - - fields_value = cls.fields.get(name) - - if isinstance(fields_value, str): - field_info: Dict[str, Any] = {'alias': fields_value} - elif isinstance(fields_value, dict): - field_info = fields_value - else: - field_info = {} - - if 'alias' in field_info: - field_info.setdefault('alias_priority', 2) - - if field_info.get('alias_priority', 0) <= 1 and cls.alias_generator: - alias = cls.alias_generator(name) - if not isinstance(alias, str): - raise TypeError(f'Config.alias_generator must return str, not {alias.__class__}') - field_info.update(alias=alias, alias_priority=1) - return field_info - - @classmethod - def prepare_field(cls, field: 'ModelField') -> None: - """ - Optional hook to check or modify fields during model creation. 
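For the string-handling options above (`str_strip_whitespace`, `str_to_lower`, `str_max_length`), a small illustrative model (field name and values invented for the example):

```py
from pydantic import BaseModel, ConfigDict

class Tag(BaseModel):
    model_config = ConfigDict(
        str_strip_whitespace=True,
        str_to_lower=True,
        str_max_length=16,
    )

    name: str

# whitespace is stripped and the value lowercased before the length check
print(Tag(name='  PyDantic  ').name)
#> pydantic
```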
- """ - pass + name: str -def get_config(config: Union[ConfigDict, Type[object], None]) -> Type[BaseConfig]: - if config is None: - return BaseConfig + user = User(name='John Doe', age=20) # (2)! + print(user) + #> name='John Doe' + ``` - else: - config_dict = ( - config - if isinstance(config, dict) - else {k: getattr(config, k) for k in dir(config) if not k.startswith('__')} + 1. This is the default behaviour. + 2. The `age` argument is ignored. + + Instead, with `extra='allow'`, the `age` argument is included: + + ```py + from pydantic import BaseModel, ConfigDict + + + class User(BaseModel): + model_config = ConfigDict(extra='allow') + + name: str + + + user = User(name='John Doe', age=20) # (1)! + print(user) + #> name='John Doe' age=20 + ``` + + 1. The `age` argument is included. + + With `extra='forbid'`, an error is raised: + + ```py + from pydantic import BaseModel, ConfigDict, ValidationError + + + class User(BaseModel): + model_config = ConfigDict(extra='forbid') + + name: str + + + try: + User(name='John Doe', age=20) + except ValidationError as e: + print(e) + ''' + 1 validation error for User + age + Extra inputs are not permitted [type=extra_forbidden, input_value=20, input_type=int] + ''' + ``` + """ + + frozen: bool + """ + Whether models are faux-immutable, i.e. whether `__setattr__` is allowed, and also generates + a `__hash__()` method for the model. This makes instances of the model potentially hashable if all the + attributes are hashable. Defaults to `False`. + + Note: + On V1, the inverse of this setting was called `allow_mutation`, and was `True` by default. + """ + + populate_by_name: bool + """ + Whether an aliased field may be populated by its name as given by the model + attribute, as well as the alias. Defaults to `False`. + + Note: + The name of this configuration setting was changed in **v2.0** from + `allow_population_by_field_name` to `populate_by_name`. + + ```py + from pydantic import BaseModel, ConfigDict, Field + + + class User(BaseModel): + model_config = ConfigDict(populate_by_name=True) + + name: str = Field(alias='full_name') # (1)! + age: int + + + user = User(full_name='John Doe', age=20) # (2)! + print(user) + #> name='John Doe' age=20 + user = User(name='John Doe', age=20) # (3)! + print(user) + #> name='John Doe' age=20 + ``` + + 1. The field `'name'` has an alias `'full_name'`. + 2. The model is populated by the alias `'full_name'`. + 3. The model is populated by the field name `'name'`. + """ + + use_enum_values: bool + """ + Whether to populate models with the `value` property of enums, rather than the raw enum. + This may be useful if you want to serialize `model.model_dump()` later. Defaults to `False`. + + !!! note + If you have an `Optional[Enum]` value that you set a default for, you need to use `validate_default=True` + for said Field to ensure that the `use_enum_values` flag takes effect on the default, as extracting an + enum's value occurs during validation, not serialization. 
+ + ```py + from enum import Enum + from typing import Optional + + from pydantic import BaseModel, ConfigDict, Field + + + class SomeEnum(Enum): + FOO = 'foo' + BAR = 'bar' + BAZ = 'baz' + + + class SomeModel(BaseModel): + model_config = ConfigDict(use_enum_values=True) + + some_enum: SomeEnum + another_enum: Optional[SomeEnum] = Field(default=SomeEnum.FOO, validate_default=True) + + + model1 = SomeModel(some_enum=SomeEnum.BAR) + print(model1.model_dump()) + # {'some_enum': 'bar', 'another_enum': 'foo'} + + model2 = SomeModel(some_enum=SomeEnum.BAR, another_enum=SomeEnum.BAZ) + print(model2.model_dump()) + #> {'some_enum': 'bar', 'another_enum': 'baz'} + ``` + """ + + validate_assignment: bool + """ + Whether to validate the data when the model is changed. Defaults to `False`. + + The default behavior of Pydantic is to validate the data when the model is created. + + In case the user changes the data after the model is created, the model is _not_ revalidated. + + ```py + from pydantic import BaseModel + + class User(BaseModel): + name: str + + user = User(name='John Doe') # (1)! + print(user) + #> name='John Doe' + user.name = 123 # (1)! + print(user) + #> name=123 + ``` + + 1. The validation happens only when the model is created. + 2. The validation does not happen when the data is changed. + + In case you want to revalidate the model when the data is changed, you can use `validate_assignment=True`: + + ```py + from pydantic import BaseModel, ValidationError + + class User(BaseModel, validate_assignment=True): # (1)! + name: str + + user = User(name='John Doe') # (2)! + print(user) + #> name='John Doe' + try: + user.name = 123 # (3)! + except ValidationError as e: + print(e) + ''' + 1 validation error for User + name + Input should be a valid string [type=string_type, input_value=123, input_type=int] + ''' + ``` + + 1. You can either use class keyword arguments, or `model_config` to set `validate_assignment=True`. + 2. The validation happens when the model is created. + 3. The validation _also_ happens when the data is changed. + """ + + arbitrary_types_allowed: bool + """ + Whether arbitrary types are allowed for field types. Defaults to `False`. 
+ + ```py + from pydantic import BaseModel, ConfigDict, ValidationError + + # This is not a pydantic model, it's an arbitrary class + class Pet: + def __init__(self, name: str): + self.name = name + + class Model(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + pet: Pet + owner: str + + pet = Pet(name='Hedwig') + # A simple check of instance type is used to validate the data + model = Model(owner='Harry', pet=pet) + print(model) + #> pet=<__main__.Pet object at 0x0123456789ab> owner='Harry' + print(model.pet) + #> <__main__.Pet object at 0x0123456789ab> + print(model.pet.name) + #> Hedwig + print(type(model.pet)) + #> + try: + # If the value is not an instance of the type, it's invalid + Model(owner='Harry', pet='Hedwig') + except ValidationError as e: + print(e) + ''' + 1 validation error for Model + pet + Input should be an instance of Pet [type=is_instance_of, input_value='Hedwig', input_type=str] + ''' + + # Nothing in the instance of the arbitrary type is checked + # Here name probably should have been a str, but it's not validated + pet2 = Pet(name=42) + model2 = Model(owner='Harry', pet=pet2) + print(model2) + #> pet=<__main__.Pet object at 0x0123456789ab> owner='Harry' + print(model2.pet) + #> <__main__.Pet object at 0x0123456789ab> + print(model2.pet.name) + #> 42 + print(type(model2.pet)) + #> + ``` + """ + + from_attributes: bool + """ + Whether to build models and look up discriminators of tagged unions using python object attributes. + """ + + loc_by_alias: bool + """Whether to use the actual key provided in the data (e.g. alias) for error `loc`s rather than the field's name. Defaults to `True`.""" + + alias_generator: Callable[[str], str] | AliasGenerator | None + """ + A callable that takes a field name and returns an alias for it + or an instance of [`AliasGenerator`][pydantic.aliases.AliasGenerator]. Defaults to `None`. + + When using a callable, the alias generator is used for both validation and serialization. + If you want to use different alias generators for validation and serialization, you can use + [`AliasGenerator`][pydantic.aliases.AliasGenerator] instead. + + If data source field names do not match your code style (e. g. CamelCase fields), + you can automatically generate aliases using `alias_generator`. Here's an example with + a basic callable: + + ```py + from pydantic import BaseModel, ConfigDict + from pydantic.alias_generators import to_pascal + + class Voice(BaseModel): + model_config = ConfigDict(alias_generator=to_pascal) + + name: str + language_code: str + + voice = Voice(Name='Filiz', LanguageCode='tr-TR') + print(voice.language_code) + #> tr-TR + print(voice.model_dump(by_alias=True)) + #> {'Name': 'Filiz', 'LanguageCode': 'tr-TR'} + ``` + + If you want to use different alias generators for validation and serialization, you can use + [`AliasGenerator`][pydantic.aliases.AliasGenerator]. + + ```py + from pydantic import AliasGenerator, BaseModel, ConfigDict + from pydantic.alias_generators import to_camel, to_pascal + + class Athlete(BaseModel): + first_name: str + last_name: str + sport: str + + model_config = ConfigDict( + alias_generator=AliasGenerator( + validation_alias=to_camel, + serialization_alias=to_pascal, + ) ) - class Config(BaseConfig): - ... 
+ athlete = Athlete(firstName='John', lastName='Doe', sport='track') + print(athlete.model_dump(by_alias=True)) + #> {'FirstName': 'John', 'LastName': 'Doe', 'Sport': 'track'} + ``` - for k, v in config_dict.items(): - setattr(Config, k, v) - return Config + Note: + Pydantic offers three built-in alias generators: [`to_pascal`][pydantic.alias_generators.to_pascal], + [`to_camel`][pydantic.alias_generators.to_camel], and [`to_snake`][pydantic.alias_generators.to_snake]. + """ + ignored_types: tuple[type, ...] + """A tuple of types that may occur as values of class attributes without annotations. This is + typically used for custom descriptors (classes that behave like `property`). If an attribute is set on a + class without an annotation and has a type that is not in this tuple (or otherwise recognized by + _pydantic_), an error will be raised. Defaults to `()`. + """ -def inherit_config(self_config: 'ConfigType', parent_config: 'ConfigType', **namespace: Any) -> 'ConfigType': - if not self_config: - base_classes: Tuple['ConfigType', ...] = (parent_config,) - elif self_config == parent_config: - base_classes = (self_config,) - else: - base_classes = self_config, parent_config + allow_inf_nan: bool + """Whether to allow infinity (`+inf` an `-inf`) and NaN values to float fields. Defaults to `True`.""" - namespace['json_encoders'] = { - **getattr(parent_config, 'json_encoders', {}), - **getattr(self_config, 'json_encoders', {}), - **namespace.get('json_encoders', {}), + json_schema_extra: JsonDict | JsonSchemaExtraCallable | None + """A dict or callable to provide extra JSON schema properties. Defaults to `None`.""" + + json_encoders: dict[type[object], JsonEncoder] | None + """ + A `dict` of custom JSON encoders for specific types. Defaults to `None`. + + !!! warning "Deprecated" + This config option is a carryover from v1. + We originally planned to remove it in v2 but didn't have a 1:1 replacement so we are keeping it for now. + It is still deprecated and will likely be removed in the future. + """ + + # new in V2 + strict: bool + """ + _(new in V2)_ If `True`, strict validation is applied to all fields on the model. + + By default, Pydantic attempts to coerce values to the correct type, when possible. + + There are situations in which you may want to disable this behavior, and instead raise an error if a value's type + does not match the field's type annotation. + + To configure strict mode for all fields on a model, you can set `strict=True` on the model. + + ```py + from pydantic import BaseModel, ConfigDict + + class Model(BaseModel): + model_config = ConfigDict(strict=True) + + name: str + age: int + ``` + + See [Strict Mode](../concepts/strict_mode.md) for more details. + + See the [Conversion Table](../concepts/conversion_table.md) for more details on how Pydantic converts data in both + strict and lax modes. + """ + # whether instances of models and dataclasses (including subclass instances) should re-validate, default 'never' + revalidate_instances: Literal['always', 'never', 'subclass-instances'] + """ + When and how to revalidate models and dataclasses during validation. Accepts the string + values of `'never'`, `'always'` and `'subclass-instances'`. Defaults to `'never'`. 
+ + - `'never'` will not revalidate models and dataclasses during validation + - `'always'` will revalidate models and dataclasses during validation + - `'subclass-instances'` will revalidate models and dataclasses during validation if the instance is a + subclass of the model or dataclass + + By default, model and dataclass instances are not revalidated during validation. + + ```py + from typing import List + + from pydantic import BaseModel + + class User(BaseModel, revalidate_instances='never'): # (1)! + hobbies: List[str] + + class SubUser(User): + sins: List[str] + + class Transaction(BaseModel): + user: User + + my_user = User(hobbies=['reading']) + t = Transaction(user=my_user) + print(t) + #> user=User(hobbies=['reading']) + + my_user.hobbies = [1] # (2)! + t = Transaction(user=my_user) # (3)! + print(t) + #> user=User(hobbies=[1]) + + my_sub_user = SubUser(hobbies=['scuba diving'], sins=['lying']) + t = Transaction(user=my_sub_user) + print(t) + #> user=SubUser(hobbies=['scuba diving'], sins=['lying']) + ``` + + 1. `revalidate_instances` is set to `'never'` by **default. + 2. The assignment is not validated, unless you set `validate_assignment` to `True` in the model's config. + 3. Since `revalidate_instances` is set to `never`, this is not revalidated. + + If you want to revalidate instances during validation, you can set `revalidate_instances` to `'always'` + in the model's config. + + ```py + from typing import List + + from pydantic import BaseModel, ValidationError + + class User(BaseModel, revalidate_instances='always'): # (1)! + hobbies: List[str] + + class SubUser(User): + sins: List[str] + + class Transaction(BaseModel): + user: User + + my_user = User(hobbies=['reading']) + t = Transaction(user=my_user) + print(t) + #> user=User(hobbies=['reading']) + + my_user.hobbies = [1] + try: + t = Transaction(user=my_user) # (2)! + except ValidationError as e: + print(e) + ''' + 1 validation error for Transaction + user.hobbies.0 + Input should be a valid string [type=string_type, input_value=1, input_type=int] + ''' + + my_sub_user = SubUser(hobbies=['scuba diving'], sins=['lying']) + t = Transaction(user=my_sub_user) + print(t) # (3)! + #> user=User(hobbies=['scuba diving']) + ``` + + 1. `revalidate_instances` is set to `'always'`. + 2. The model is revalidated, since `revalidate_instances` is set to `'always'`. + 3. Using `'never'` we would have gotten `user=SubUser(hobbies=['scuba diving'], sins=['lying'])`. + + It's also possible to set `revalidate_instances` to `'subclass-instances'` to only revalidate instances + of subclasses of the model. + + ```py + from typing import List + + from pydantic import BaseModel + + class User(BaseModel, revalidate_instances='subclass-instances'): # (1)! + hobbies: List[str] + + class SubUser(User): + sins: List[str] + + class Transaction(BaseModel): + user: User + + my_user = User(hobbies=['reading']) + t = Transaction(user=my_user) + print(t) + #> user=User(hobbies=['reading']) + + my_user.hobbies = [1] + t = Transaction(user=my_user) # (2)! + print(t) + #> user=User(hobbies=[1]) + + my_sub_user = SubUser(hobbies=['scuba diving'], sins=['lying']) + t = Transaction(user=my_sub_user) + print(t) # (3)! + #> user=User(hobbies=['scuba diving']) + ``` + + 1. `revalidate_instances` is set to `'subclass-instances'`. + 2. This is not revalidated, since `my_user` is not a subclass of `User`. + 3. Using `'never'` we would have gotten `user=SubUser(hobbies=['scuba diving'], sins=['lying'])`. 
+    """
+
+    ser_json_timedelta: Literal['iso8601', 'float']
+    """
+    The format of JSON serialized timedeltas. Accepts the string values of `'iso8601'` and
+    `'float'`. Defaults to `'iso8601'`.
+
+    - `'iso8601'` will serialize timedeltas to ISO 8601 durations.
+    - `'float'` will serialize timedeltas to the total number of seconds.
+    """
+
+    ser_json_bytes: Literal['utf8', 'base64']
+    """
+    The encoding of JSON serialized bytes. Accepts the string values of `'utf8'` and `'base64'`.
+    Defaults to `'utf8'`.
+
+    - `'utf8'` will serialize bytes to UTF-8 strings.
+    - `'base64'` will serialize bytes to URL safe base64 strings.
+    """
+
+    ser_json_inf_nan: Literal['null', 'constants']
+    """
+    The encoding of JSON serialized infinity and NaN float values. Accepts the string values of `'null'` and `'constants'`.
+    Defaults to `'null'`.
+
+    - `'null'` will serialize infinity and NaN values as `null`.
+    - `'constants'` will serialize infinity and NaN values as `Infinity` and `NaN`.
+    """
+
+    # whether to validate default values during validation, default False
+    validate_default: bool
+    """Whether to validate default values during validation. Defaults to `False`."""
+
+    validate_return: bool
+    """Whether to validate the return value from call validators. Defaults to `False`."""
+
+    protected_namespaces: tuple[str, ...]
+    """
+    A `tuple` of strings that prevent models from having fields whose names conflict with them.
+    Defaults to `('model_',)`.
+
+    Pydantic prevents collisions between model attributes and `BaseModel`'s own methods by
+    namespacing them with the prefix `model_`.
+
+    ```py
+    import warnings
+
+    from pydantic import BaseModel
+
+    warnings.filterwarnings('error')  # Raise warnings as errors
+
+    try:
+
+        class Model(BaseModel):
+            model_prefixed_field: str
+
+    except UserWarning as e:
+        print(e)
+        '''
+        Field "model_prefixed_field" has conflict with protected namespace "model_".
+
+        You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ()`.
+        '''
+    ```
+
+    You can customize this behavior using the `protected_namespaces` setting:
+
+    ```py
+    import warnings
+
+    from pydantic import BaseModel, ConfigDict
+
+    warnings.filterwarnings('error')  # Raise warnings as errors
+
+    try:
+
+        class Model(BaseModel):
+            model_prefixed_field: str
+            also_protect_field: str
+
+            model_config = ConfigDict(
+                protected_namespaces=('protect_me_', 'also_protect_')
+            )
+
+    except UserWarning as e:
+        print(e)
+        '''
+        Field "also_protect_field" has conflict with protected namespace "also_protect_".
+
+        You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ('protect_me_',)`.
+        '''
+    ```
+
+    While Pydantic will only emit a warning when an item is in a protected namespace but does not actually have a collision,
+    an error _is_ raised if there is an actual collision with an existing attribute:
+
+    ```py
+    from pydantic import BaseModel
+
+    try:
+
+        class Model(BaseModel):
+            model_validate: str
+
+    except NameError as e:
+        print(e)
+        '''
+        Field "model_validate" conflicts with member <bound method BaseModel.model_validate of <class 'pydantic.main.BaseModel'>> of protected namespace "model_".
+        '''
+    ```
+    """
+
+    hide_input_in_errors: bool
+    """
+    Whether to hide inputs when printing errors. Defaults to `False`.
+
+    Pydantic shows the input value and type when it raises `ValidationError` during the validation.
+ + ```py + from pydantic import BaseModel, ValidationError + + class Model(BaseModel): + a: str + + try: + Model(a=123) + except ValidationError as e: + print(e) + ''' + 1 validation error for Model + a + Input should be a valid string [type=string_type, input_value=123, input_type=int] + ''' + ``` + + You can hide the input value and type by setting the `hide_input_in_errors` config to `True`. + + ```py + from pydantic import BaseModel, ConfigDict, ValidationError + + class Model(BaseModel): + a: str + model_config = ConfigDict(hide_input_in_errors=True) + + try: + Model(a=123) + except ValidationError as e: + print(e) + ''' + 1 validation error for Model + a + Input should be a valid string [type=string_type] + ''' + ``` + """ + + defer_build: bool + """ + Whether to defer model validator and serializer construction until the first model validation. + + This can be useful to avoid the overhead of building models which are only + used nested within other models, or when you want to manually define type namespace via + [`Model.model_rebuild(_types_namespace=...)`][pydantic.BaseModel.model_rebuild]. Defaults to False. + """ + + plugin_settings: dict[str, object] | None + """A `dict` of settings for plugins. Defaults to `None`. + + See [Pydantic Plugins](../concepts/plugins.md) for details. + """ + + schema_generator: type[_GenerateSchema] | None + """ + A custom core schema generator class to use when generating JSON schemas. + Useful if you want to change the way types are validated across an entire model/schema. Defaults to `None`. + + The `GenerateSchema` interface is subject to change, currently only the `string_schema` method is public. + + See [#6737](https://github.com/pydantic/pydantic/pull/6737) for details. + """ + + json_schema_serialization_defaults_required: bool + """ + Whether fields with default values should be marked as required in the serialization schema. Defaults to `False`. + + This ensures that the serialization schema will reflect the fact a field with a default will always be present + when serializing the model, even though it is not required for validation. + + However, there are scenarios where this may be undesirable — in particular, if you want to share the schema + between validation and serialization, and don't mind fields with defaults being marked as not required during + serialization. See [#7209](https://github.com/pydantic/pydantic/issues/7209) for more details. + + ```py + from pydantic import BaseModel, ConfigDict + + class Model(BaseModel): + a: str = 'a' + + model_config = ConfigDict(json_schema_serialization_defaults_required=True) + + print(Model.model_json_schema(mode='validation')) + ''' + { + 'properties': {'a': {'default': 'a', 'title': 'A', 'type': 'string'}}, + 'title': 'Model', + 'type': 'object', } + ''' + print(Model.model_json_schema(mode='serialization')) + ''' + { + 'properties': {'a': {'default': 'a', 'title': 'A', 'type': 'string'}}, + 'required': ['a'], + 'title': 'Model', + 'type': 'object', + } + ''' + ``` + """ - return type('Config', base_classes, namespace) + json_schema_mode_override: Literal['validation', 'serialization', None] + """ + If not `None`, the specified mode will be used to generate the JSON schema regardless of what `mode` was passed to + the function call. Defaults to `None`. + + This provides a way to force the JSON schema generation to reflect a specific mode, e.g., to always use the + validation schema. 
+
+    It can be useful when using frameworks (such as FastAPI) that may generate different schemas for validation
+    and serialization that must both be referenced from the same schema; when this happens, we automatically append
+    `-Input` to the definition reference for the validation schema and `-Output` to the definition reference for the
+    serialization schema. By specifying a `json_schema_mode_override` though, this prevents the conflict between
+    the validation and serialization schemas (since both will use the specified schema), and so prevents the suffixes
+    from being added to the definition references.
+
+    ```py
+    from pydantic import BaseModel, ConfigDict, Json
+
+    class Model(BaseModel):
+        a: Json[int]  # requires a string to validate, but will dump an int
+
+    print(Model.model_json_schema(mode='serialization'))
+    '''
+    {
+        'properties': {'a': {'title': 'A', 'type': 'integer'}},
+        'required': ['a'],
+        'title': 'Model',
+        'type': 'object',
+    }
+    '''
+
+    class ForceInputModel(Model):
+        # the following ensures that even with mode='serialization', we
+        # will get the schema that would be generated for validation.
+        model_config = ConfigDict(json_schema_mode_override='validation')
+
+    print(ForceInputModel.model_json_schema(mode='serialization'))
+    '''
+    {
+        'properties': {
+            'a': {
+                'contentMediaType': 'application/json',
+                'contentSchema': {'type': 'integer'},
+                'title': 'A',
+                'type': 'string',
+            }
+        },
+        'required': ['a'],
+        'title': 'ForceInputModel',
+        'type': 'object',
+    }
+    '''
+    ```
+    """
+
+    coerce_numbers_to_str: bool
+    """
+    If `True`, enables automatic coercion of any `Number` type to `str` in "lax" (non-strict) mode. Defaults to `False`.
+
+    Pydantic doesn't allow number types (`int`, `float`, `Decimal`) to be coerced as type `str` by default.
+
+    ```py
+    from decimal import Decimal
+
+    from pydantic import BaseModel, ConfigDict, ValidationError
+
+    class Model(BaseModel):
+        value: str
+
+    try:
+        print(Model(value=42))
+    except ValidationError as e:
+        print(e)
+        '''
+        1 validation error for Model
+        value
+          Input should be a valid string [type=string_type, input_value=42, input_type=int]
+        '''
+
+    class Model(BaseModel):
+        model_config = ConfigDict(coerce_numbers_to_str=True)
+
+        value: str
+
+    repr(Model(value=42).value)
+    #> "42"
+    repr(Model(value=42.13).value)
+    #> "42.13"
+    repr(Model(value=Decimal('42.13')).value)
+    #> "42.13"
+    ```
+    """
+
+    regex_engine: Literal['rust-regex', 'python-re']
+    """
+    The regex engine to be used for pattern validation.
+    Defaults to `'rust-regex'`.
+
+    - `rust-regex` uses the [`regex`](https://docs.rs/regex) Rust crate,
+      which is non-backtracking and therefore more DDoS resistant, but does not support all regex features.
+    - `python-re` uses the [`re`](https://docs.python.org/3/library/re.html) module,
+      which supports all regex features, but may be slower.
+
+    ```py
+    from pydantic import BaseModel, ConfigDict, Field, ValidationError
+
+    class Model(BaseModel):
+        model_config = ConfigDict(regex_engine='python-re')
+
+        value: str = Field(pattern=r'^abc(?=def)')
+
+    print(Model(value='abcdef').value)
+    #> abcdef
+
+    try:
+        print(Model(value='abxyzcdef'))
+    except ValidationError as e:
+        print(e)
+        '''
+        1 validation error for Model
+        value
+          String should match pattern '^abc(?=def)' [type=string_pattern_mismatch, input_value='abxyzcdef', input_type=str]
+        '''
+    ```
+    """
+
+    validation_error_cause: bool
+    """
+    If `True`, Python exceptions that were part of a validation failure will be shown as an exception group as a cause.
Can be useful for debugging. Defaults to `False`.
+
+    Note:
+        Python 3.10 and older don't support exception groups natively. On Python <=3.10 the backport must be installed: `pip install exceptiongroup`.
+
+    Note:
+        The structure of validation errors is likely to change in future pydantic versions. Pydantic offers no guarantees about the structure of validation errors. It should be used for visual traceback debugging only.
+    """
-def prepare_config(config: Type[BaseConfig], cls_name: str) -> None:
-    if not isinstance(config.extra, Extra):
-        try:
-            config.extra = Extra(config.extra)
-        except ValueError:
-            raise ValueError(f'"{cls_name}": {config.extra} is not a valid value for "extra"')
+__getattr__ = getattr_migration(__name__)
diff --git a/lib/pydantic/dataclasses.py b/lib/pydantic/dataclasses.py
index 68331127..d9c9c903 100644
--- a/lib/pydantic/dataclasses.py
+++ b/lib/pydantic/dataclasses.py
@@ -1,479 +1,327 @@
-"""
-The main purpose is to enhance stdlib dataclasses by adding validation
-A pydantic dataclass can be generated from scratch or from a stdlib one.
+"""Provide an enhanced dataclass that performs validation."""
+from __future__ import annotations as _annotations
-Behind the scene, a pydantic dataclass is just like a regular one on which we attach
-a `BaseModel` and magic methods to trigger the validation of the data.
-`__init__` and `__post_init__` are hence overridden and have extra logic to be
-able to validate input data.
-
-When a pydantic dataclass is generated from scratch, it's just a plain dataclass
-with validation triggered at initialization
-
-The tricky part if for stdlib dataclasses that are converted after into pydantic ones e.g.
-
-```py
-@dataclasses.dataclass
-class M:
-    x: int
-
-ValidatedM = pydantic.dataclasses.dataclass(M)
-```
-
-We indeed still want to support equality, hashing, repr, ... as if it was the stdlib one!
-
-```py
-assert isinstance(ValidatedM(x=1), M)
-assert ValidatedM(x=1) == M(x=1)
-```
-
-This means we **don't want to create a new dataclass that inherits from it**
-The trick is to create a wrapper around `M` that will act as a proxy to trigger
-validation without altering default `M` behaviour.
-""" +import dataclasses import sys -from contextlib import contextmanager -from functools import wraps -from typing import ( - TYPE_CHECKING, - Any, - Callable, - ClassVar, - Dict, - Generator, - Optional, - Set, - Type, - TypeVar, - Union, - overload, -) +import types +from typing import TYPE_CHECKING, Any, Callable, Generic, NoReturn, TypeVar, overload -from typing_extensions import dataclass_transform +from typing_extensions import Literal, TypeGuard, dataclass_transform -from .class_validators import gather_all_validators -from .config import BaseConfig, ConfigDict, Extra, get_config -from .error_wrappers import ValidationError -from .errors import DataclassTypeError -from .fields import Field, FieldInfo, Required, Undefined -from .main import create_model, validate_model -from .utils import ClassAttribute +from ._internal import _config, _decorators, _typing_extra +from ._internal import _dataclasses as _pydantic_dataclasses +from ._migration import getattr_migration +from .config import ConfigDict +from .fields import Field, FieldInfo if TYPE_CHECKING: - from .main import BaseModel - from .typing import CallableGenerator, NoArgAnyCallable + from ._internal._dataclasses import PydanticDataclass - DataclassT = TypeVar('DataclassT', bound='Dataclass') - - DataclassClassOrWrapper = Union[Type['Dataclass'], 'DataclassProxy'] - - class Dataclass: - # stdlib attributes - __dataclass_fields__: ClassVar[Dict[str, Any]] - __dataclass_params__: ClassVar[Any] # in reality `dataclasses._DataclassParams` - __post_init__: ClassVar[Callable[..., None]] - - # Added by pydantic - __pydantic_run_validation__: ClassVar[bool] - __post_init_post_parse__: ClassVar[Callable[..., None]] - __pydantic_initialised__: ClassVar[bool] - __pydantic_model__: ClassVar[Type[BaseModel]] - __pydantic_validate_values__: ClassVar[Callable[['Dataclass'], None]] - __pydantic_has_field_info_default__: ClassVar[bool] # whether a `pydantic.Field` is used as default value - - def __init__(self, *args: object, **kwargs: object) -> None: - pass - - @classmethod - def __get_validators__(cls: Type['Dataclass']) -> 'CallableGenerator': - pass - - @classmethod - def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT': - pass - - -__all__ = [ - 'dataclass', - 'set_validation', - 'create_pydantic_model_from_dataclass', - 'is_builtin_dataclass', - 'make_dataclass_validator', -] +__all__ = 'dataclass', 'rebuild_dataclass' _T = TypeVar('_T') if sys.version_info >= (3, 10): - @dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo)) + @dataclass_transform(field_specifiers=(dataclasses.field, Field)) @overload def dataclass( *, - init: bool = True, + init: Literal[False] = False, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, - config: Union[ConfigDict, Type[object], None] = None, - validate_on_init: Optional[bool] = None, + config: ConfigDict | type[object] | None = None, + validate_on_init: bool | None = None, kw_only: bool = ..., - ) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']: + slots: bool = ..., + ) -> Callable[[type[_T]], type[PydanticDataclass]]: # type: ignore ... 
- @dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo)) + @dataclass_transform(field_specifiers=(dataclasses.field, Field)) @overload def dataclass( - _cls: Type[_T], + _cls: type[_T], # type: ignore *, - init: bool = True, + init: Literal[False] = False, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, - config: Union[ConfigDict, Type[object], None] = None, - validate_on_init: Optional[bool] = None, + config: ConfigDict | type[object] | None = None, + validate_on_init: bool | None = None, kw_only: bool = ..., - ) -> 'DataclassClassOrWrapper': + slots: bool = ..., + ) -> type[PydanticDataclass]: ... else: - @dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo)) + @dataclass_transform(field_specifiers=(dataclasses.field, Field)) @overload def dataclass( *, - init: bool = True, + init: Literal[False] = False, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, - config: Union[ConfigDict, Type[object], None] = None, - validate_on_init: Optional[bool] = None, - ) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']: + config: ConfigDict | type[object] | None = None, + validate_on_init: bool | None = None, + ) -> Callable[[type[_T]], type[PydanticDataclass]]: # type: ignore ... - @dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo)) + @dataclass_transform(field_specifiers=(dataclasses.field, Field)) @overload def dataclass( - _cls: Type[_T], + _cls: type[_T], # type: ignore *, - init: bool = True, + init: Literal[False] = False, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, - config: Union[ConfigDict, Type[object], None] = None, - validate_on_init: Optional[bool] = None, - ) -> 'DataclassClassOrWrapper': + config: ConfigDict | type[object] | None = None, + validate_on_init: bool | None = None, + ) -> type[PydanticDataclass]: ... -@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo)) -def dataclass( - _cls: Optional[Type[_T]] = None, +@dataclass_transform(field_specifiers=(dataclasses.field, Field)) +def dataclass( # noqa: C901 + _cls: type[_T] | None = None, *, - init: bool = True, + init: Literal[False] = False, repr: bool = True, eq: bool = True, order: bool = False, unsafe_hash: bool = False, frozen: bool = False, - config: Union[ConfigDict, Type[object], None] = None, - validate_on_init: Optional[bool] = None, + config: ConfigDict | type[object] | None = None, + validate_on_init: bool | None = None, kw_only: bool = False, -) -> Union[Callable[[Type[_T]], 'DataclassClassOrWrapper'], 'DataclassClassOrWrapper']: + slots: bool = False, +) -> Callable[[type[_T]], type[PydanticDataclass]] | type[PydanticDataclass]: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/dataclasses/ + + A decorator used to create a Pydantic-enhanced dataclass, similar to the standard Python `dataclass`, + but with added validation. + + This function should be used similarly to `dataclasses.dataclass`. + + Args: + _cls: The target `dataclass`. + init: Included for signature compatibility with `dataclasses.dataclass`, and is passed through to + `dataclasses.dataclass` when appropriate. If specified, must be set to `False`, as pydantic inserts its + own `__init__` function. + repr: A boolean indicating whether to include the field in the `__repr__` output. + eq: Determines if a `__eq__` method should be generated for the class. 
+ order: Determines if comparison magic methods should be generated, such as `__lt__`, but not `__eq__`. + unsafe_hash: Determines if a `__hash__` method should be included in the class, as in `dataclasses.dataclass`. + frozen: Determines if the generated class should be a 'frozen' `dataclass`, which does not allow its + attributes to be modified after it has been initialized. + config: The Pydantic config to use for the `dataclass`. + validate_on_init: A deprecated parameter included for backwards compatibility; in V2, all Pydantic dataclasses + are validated on init. + kw_only: Determines if `__init__` method parameters must be specified by keyword only. Defaults to `False`. + slots: Determines if the generated class should be a 'slots' `dataclass`, which does not allow the addition of + new attributes after instantiation. + + Returns: + A decorator that accepts a class as its argument and returns a Pydantic `dataclass`. + + Raises: + AssertionError: Raised if `init` is not `False` or `validate_on_init` is `False`. """ - Like the python standard lib dataclasses but with type validation. - The result is either a pydantic dataclass that will validate input data - or a wrapper that will trigger validation around a stdlib dataclass - to avoid modifying it directly - """ - the_config = get_config(config) + assert init is False, 'pydantic.dataclasses.dataclass only supports init=False' + assert validate_on_init is not False, 'validate_on_init=False is no longer supported' - def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper': - import dataclasses + if sys.version_info >= (3, 10): + kwargs = dict(kw_only=kw_only, slots=slots) + else: + kwargs = {} - if is_builtin_dataclass(cls) and _extra_dc_args(_cls) == _extra_dc_args(_cls.__bases__[0]): # type: ignore - dc_cls_doc = '' - dc_cls = DataclassProxy(cls) - default_validate_on_init = False - else: - dc_cls_doc = cls.__doc__ or '' # needs to be done before generating dataclass - if sys.version_info >= (3, 10): - dc_cls = dataclasses.dataclass( - cls, - init=init, - repr=repr, - eq=eq, - order=order, - unsafe_hash=unsafe_hash, - frozen=frozen, - kw_only=kw_only, - ) - else: - dc_cls = dataclasses.dataclass( # type: ignore - cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen - ) - default_validate_on_init = True + def make_pydantic_fields_compatible(cls: type[Any]) -> None: + """Make sure that stdlib `dataclasses` understands `Field` kwargs like `kw_only` + To do that, we simply change + `x: int = pydantic.Field(..., kw_only=True)` + into + `x: int = dataclasses.field(default=pydantic.Field(..., kw_only=True), kw_only=True)` + """ + for annotation_cls in cls.__mro__: + # In Python < 3.9, `__annotations__` might not be present if there are no fields. + # we therefore need to use `getattr` to avoid an `AttributeError`. + annotations = getattr(annotation_cls, '__annotations__', []) + for field_name in annotations: + field_value = getattr(cls, field_name, None) + # Process only if this is an instance of `FieldInfo`. + if not isinstance(field_value, FieldInfo): + continue - should_validate_on_init = default_validate_on_init if validate_on_init is None else validate_on_init - _add_pydantic_validation_attributes(cls, the_config, should_validate_on_init, dc_cls_doc) - dc_cls.__pydantic_model__.__try_update_forward_refs__(**{cls.__name__: cls}) - return dc_cls + # Initialize arguments for the standard `dataclasses.field`. 
+ field_args: dict = {'default': field_value} + + # Handle `kw_only` for Python 3.10+ + if sys.version_info >= (3, 10) and field_value.kw_only: + field_args['kw_only'] = True + + # Set `repr` attribute if it's explicitly specified to be not `True`. + if field_value.repr is not True: + field_args['repr'] = field_value.repr + + setattr(cls, field_name, dataclasses.field(**field_args)) + # In Python 3.8, dataclasses checks cls.__dict__['__annotations__'] for annotations, + # so we must make sure it's initialized before we add to it. + if cls.__dict__.get('__annotations__') is None: + cls.__annotations__ = {} + cls.__annotations__[field_name] = annotations[field_name] + + def create_dataclass(cls: type[Any]) -> type[PydanticDataclass]: + """Create a Pydantic dataclass from a regular dataclass. + + Args: + cls: The class to create the Pydantic dataclass from. + + Returns: + A Pydantic dataclass. + """ + original_cls = cls + + config_dict = config + if config_dict is None: + # if not explicitly provided, read from the type + cls_config = getattr(cls, '__pydantic_config__', None) + if cls_config is not None: + config_dict = cls_config + config_wrapper = _config.ConfigWrapper(config_dict) + decorators = _decorators.DecoratorInfos.build(cls) + + # Keep track of the original __doc__ so that we can restore it after applying the dataclasses decorator + # Otherwise, classes with no __doc__ will have their signature added into the JSON schema description, + # since dataclasses.dataclass will set this as the __doc__ + original_doc = cls.__doc__ + + if _pydantic_dataclasses.is_builtin_dataclass(cls): + # Don't preserve the docstring for vanilla dataclasses, as it may include the signature + # This matches v1 behavior, and there was an explicit test for it + original_doc = None + + # We don't want to add validation to the existing std lib dataclass, so we will subclass it + # If the class is generic, we need to make sure the subclass also inherits from Generic + # with all the same parameters. 
+ bases = (cls,) + if issubclass(cls, Generic): + generic_base = Generic[cls.__parameters__] # type: ignore + bases = bases + (generic_base,) + cls = types.new_class(cls.__name__, bases) + + make_pydantic_fields_compatible(cls) + + cls = dataclasses.dataclass( # type: ignore[call-overload] + cls, + # the value of init here doesn't affect anything except that it makes it easier to generate a signature + init=True, + repr=repr, + eq=eq, + order=order, + unsafe_hash=unsafe_hash, + frozen=frozen, + **kwargs, + ) + + cls.__pydantic_decorators__ = decorators # type: ignore + cls.__doc__ = original_doc + cls.__module__ = original_cls.__module__ + cls.__qualname__ = original_cls.__qualname__ + pydantic_complete = _pydantic_dataclasses.complete_dataclass( + cls, config_wrapper, raise_errors=False, types_namespace=None + ) + cls.__pydantic_complete__ = pydantic_complete # type: ignore + return cls if _cls is None: - return wrap + return create_dataclass - return wrap(_cls) + return create_dataclass(_cls) -@contextmanager -def set_validation(cls: Type['DataclassT'], value: bool) -> Generator[Type['DataclassT'], None, None]: - original_run_validation = cls.__pydantic_run_validation__ - try: - cls.__pydantic_run_validation__ = value - yield cls - finally: - cls.__pydantic_run_validation__ = original_run_validation +__getattr__ = getattr_migration(__name__) + +if (3, 8) <= sys.version_info < (3, 11): + # Monkeypatch dataclasses.InitVar so that typing doesn't error if it occurs as a type when evaluating type hints + # Starting in 3.11, typing.get_type_hints will not raise an error if the retrieved type hints are not callable. + + def _call_initvar(*args: Any, **kwargs: Any) -> NoReturn: + """This function does nothing but raise an error that is as similar as possible to what you'd get + if you were to try calling `InitVar[int]()` without this monkeypatch. The whole purpose is just + to ensure typing._type_check does not error if the type hint evaluates to `InitVar[]`. + """ + raise TypeError("'InitVar' object is not callable") + + dataclasses.InitVar.__call__ = _call_initvar -class DataclassProxy: - __slots__ = '__dataclass__' +def rebuild_dataclass( + cls: type[PydanticDataclass], + *, + force: bool = False, + raise_errors: bool = True, + _parent_namespace_depth: int = 2, + _types_namespace: dict[str, Any] | None = None, +) -> bool | None: + """Try to rebuild the pydantic-core schema for the dataclass. - def __init__(self, dc_cls: Type['Dataclass']) -> None: - object.__setattr__(self, '__dataclass__', dc_cls) + This may be necessary when one of the annotations is a ForwardRef which could not be resolved during + the initial attempt to build the schema, and automatic rebuilding fails. - def __call__(self, *args: Any, **kwargs: Any) -> Any: - with set_validation(self.__dataclass__, True): - return self.__dataclass__(*args, **kwargs) + This is analogous to `BaseModel.model_rebuild`. - def __getattr__(self, name: str) -> Any: - return getattr(self.__dataclass__, name) + Args: + cls: The class to rebuild the pydantic-core schema for. + force: Whether to force the rebuilding of the schema, defaults to `False`. + raise_errors: Whether to raise errors, defaults to `True`. + _parent_namespace_depth: The depth level of the parent namespace, defaults to 2. + _types_namespace: The types namespace, defaults to `None`. 
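+
+    A minimal usage sketch (the `Outer` and `Inner` names are purely illustrative): the forward reference
+    cannot be resolved while `Outer` is being decorated, so the schema is rebuilt once `Inner` exists.
+
+    ```py
+    from pydantic.dataclasses import dataclass, rebuild_dataclass
+
+    @dataclass
+    class Outer:  # illustrative example class
+        inner: 'Inner'  # forward reference, not yet defined
+
+    @dataclass
+    class Inner:  # illustrative example class
+        x: int
+
+    rebuild_dataclass(Outer)
+    print(Outer(inner=Inner(x=1)))
+    #> Outer(inner=Inner(x=1))
+    ```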
- def __instancecheck__(self, instance: Any) -> bool: - return isinstance(instance, self.__dataclass__) - - -def _add_pydantic_validation_attributes( # noqa: C901 (ignore complexity) - dc_cls: Type['Dataclass'], - config: Type[BaseConfig], - validate_on_init: bool, - dc_cls_doc: str, -) -> None: + Returns: + Returns `None` if the schema is already "complete" and rebuilding was not required. + If rebuilding _was_ required, returns `True` if rebuilding was successful, otherwise `False`. """ - We need to replace the right method. If no `__post_init__` has been set in the stdlib dataclass - it won't even exist (code is generated on the fly by `dataclasses`) - By default, we run validation after `__init__` or `__post_init__` if defined - """ - init = dc_cls.__init__ - - @wraps(init) - def handle_extra_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None: - if config.extra == Extra.ignore: - init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__}) - - elif config.extra == Extra.allow: - for k, v in kwargs.items(): - self.__dict__.setdefault(k, v) - init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__}) - - else: - init(self, *args, **kwargs) - - if hasattr(dc_cls, '__post_init__'): - post_init = dc_cls.__post_init__ - - @wraps(post_init) - def new_post_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None: - if config.post_init_call == 'before_validation': - post_init(self, *args, **kwargs) - - if self.__class__.__pydantic_run_validation__: - self.__pydantic_validate_values__() - if hasattr(self, '__post_init_post_parse__'): - self.__post_init_post_parse__(*args, **kwargs) - - if config.post_init_call == 'after_validation': - post_init(self, *args, **kwargs) - - setattr(dc_cls, '__init__', handle_extra_init) - setattr(dc_cls, '__post_init__', new_post_init) - + if not force and cls.__pydantic_complete__: + return None else: - - @wraps(init) - def new_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None: - handle_extra_init(self, *args, **kwargs) - - if self.__class__.__pydantic_run_validation__: - self.__pydantic_validate_values__() - - if hasattr(self, '__post_init_post_parse__'): - # We need to find again the initvars. 
To do that we use `__dataclass_fields__` instead of - # public method `dataclasses.fields` - import dataclasses - - # get all initvars and their default values - initvars_and_values: Dict[str, Any] = {} - for i, f in enumerate(self.__class__.__dataclass_fields__.values()): - if f._field_type is dataclasses._FIELD_INITVAR: # type: ignore[attr-defined] - try: - # set arg value by default - initvars_and_values[f.name] = args[i] - except IndexError: - initvars_and_values[f.name] = kwargs.get(f.name, f.default) - - self.__post_init_post_parse__(**initvars_and_values) - - setattr(dc_cls, '__init__', new_init) - - setattr(dc_cls, '__pydantic_run_validation__', ClassAttribute('__pydantic_run_validation__', validate_on_init)) - setattr(dc_cls, '__pydantic_initialised__', False) - setattr(dc_cls, '__pydantic_model__', create_pydantic_model_from_dataclass(dc_cls, config, dc_cls_doc)) - setattr(dc_cls, '__pydantic_validate_values__', _dataclass_validate_values) - setattr(dc_cls, '__validate__', classmethod(_validate_dataclass)) - setattr(dc_cls, '__get_validators__', classmethod(_get_validators)) - - if dc_cls.__pydantic_model__.__config__.validate_assignment and not dc_cls.__dataclass_params__.frozen: - setattr(dc_cls, '__setattr__', _dataclass_validate_assignment_setattr) - - -def _get_validators(cls: 'DataclassClassOrWrapper') -> 'CallableGenerator': - yield cls.__validate__ - - -def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT': - with set_validation(cls, True): - if isinstance(v, cls): - v.__pydantic_validate_values__() - return v - elif isinstance(v, (list, tuple)): - return cls(*v) - elif isinstance(v, dict): - return cls(**v) + if _types_namespace is not None: + types_namespace: dict[str, Any] | None = _types_namespace.copy() else: - raise DataclassTypeError(class_name=cls.__name__) + if _parent_namespace_depth > 0: + frame_parent_ns = _typing_extra.parent_frame_namespace(parent_depth=_parent_namespace_depth) or {} + # Note: we may need to add something similar to cls.__pydantic_parent_namespace__ from BaseModel + # here when implementing handling of recursive generics. See BaseModel.model_rebuild for reference. + types_namespace = frame_parent_ns + else: + types_namespace = {} + + types_namespace = _typing_extra.get_cls_types_namespace(cls, types_namespace) + return _pydantic_dataclasses.complete_dataclass( + cls, + _config.ConfigWrapper(cls.__pydantic_config__, check=False), + raise_errors=raise_errors, + types_namespace=types_namespace, + ) -def create_pydantic_model_from_dataclass( - dc_cls: Type['Dataclass'], - config: Type[Any] = BaseConfig, - dc_cls_doc: Optional[str] = None, -) -> Type['BaseModel']: - import dataclasses +def is_pydantic_dataclass(__cls: type[Any]) -> TypeGuard[type[PydanticDataclass]]: + """Whether a class is a pydantic dataclass. - field_definitions: Dict[str, Any] = {} - for field in dataclasses.fields(dc_cls): - default: Any = Undefined - default_factory: Optional['NoArgAnyCallable'] = None - field_info: FieldInfo + Args: + __cls: The class. 
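+
+    A quick sketch (`StdPoint` and `PydPoint` are illustrative names): a plain stdlib dataclass is not a
+    pydantic dataclass, while one created through this module is.
+
+    ```py
+    import dataclasses
+
+    from pydantic.dataclasses import dataclass, is_pydantic_dataclass
+
+    @dataclasses.dataclass
+    class StdPoint:  # plain stdlib dataclass
+        x: int
+
+    @dataclass
+    class PydPoint:  # pydantic dataclass
+        x: int
+
+    print(is_pydantic_dataclass(StdPoint), is_pydantic_dataclass(PydPoint))
+    #> False True
+    ```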
- if field.default is not dataclasses.MISSING: - default = field.default - elif field.default_factory is not dataclasses.MISSING: - default_factory = field.default_factory - else: - default = Required - - if isinstance(default, FieldInfo): - field_info = default - dc_cls.__pydantic_has_field_info_default__ = True - else: - field_info = Field(default=default, default_factory=default_factory, **field.metadata) - - field_definitions[field.name] = (field.type, field_info) - - validators = gather_all_validators(dc_cls) - model: Type['BaseModel'] = create_model( - dc_cls.__name__, - __config__=config, - __module__=dc_cls.__module__, - __validators__=validators, - __cls_kwargs__={'__resolve_forward_refs__': False}, - **field_definitions, - ) - model.__doc__ = dc_cls_doc if dc_cls_doc is not None else dc_cls.__doc__ or '' - return model - - -def _dataclass_validate_values(self: 'Dataclass') -> None: - # validation errors can occur if this function is called twice on an already initialised dataclass. - # for example if Extra.forbid is enabled, it would consider __pydantic_initialised__ an invalid extra property - if getattr(self, '__pydantic_initialised__'): - return - if getattr(self, '__pydantic_has_field_info_default__', False): - # We need to remove `FieldInfo` values since they are not valid as input - # It's ok to do that because they are obviously the default values! - input_data = {k: v for k, v in self.__dict__.items() if not isinstance(v, FieldInfo)} - else: - input_data = self.__dict__ - d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__) - if validation_error: - raise validation_error - self.__dict__.update(d) - object.__setattr__(self, '__pydantic_initialised__', True) - - -def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value: Any) -> None: - if self.__pydantic_initialised__: - d = dict(self.__dict__) - d.pop(name, None) - known_field = self.__pydantic_model__.__fields__.get(name, None) - if known_field: - value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__) - if error_: - raise ValidationError([error_], self.__class__) - - object.__setattr__(self, name, value) - - -def _extra_dc_args(cls: Type[Any]) -> Set[str]: - return { - x - for x in dir(cls) - if x not in getattr(cls, '__dataclass_fields__', {}) and not (x.startswith('__') and x.endswith('__')) - } - - -def is_builtin_dataclass(_cls: Type[Any]) -> bool: + Returns: + `True` if the class is a pydantic dataclass, `False` otherwise. """ - Whether a class is a stdlib dataclass - (useful to discriminated a pydantic dataclass that is actually a wrapper around a stdlib dataclass) - - we check that - - `_cls` is a dataclass - - `_cls` is not a processed pydantic dataclass (with a basemodel attached) - - `_cls` is not a pydantic dataclass inheriting directly from a stdlib dataclass - e.g. - ``` - @dataclasses.dataclass - class A: - x: int - - @pydantic.dataclasses.dataclass - class B(A): - y: int - ``` - In this case, when we first check `B`, we make an extra check and look at the annotations ('y'), - which won't be a superset of all the dataclass fields (only the stdlib fields i.e. 
'x') - """ - import dataclasses - - return ( - dataclasses.is_dataclass(_cls) - and not hasattr(_cls, '__pydantic_model__') - and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {}))) - ) - - -def make_dataclass_validator(dc_cls: Type['Dataclass'], config: Type[BaseConfig]) -> 'CallableGenerator': - """ - Create a pydantic.dataclass from a builtin dataclass to add type validation - and yield the validators - It retrieves the parameters of the dataclass and forwards them to the newly created dataclass - """ - yield from _get_validators(dataclass(dc_cls, config=config, validate_on_init=False)) + return dataclasses.is_dataclass(__cls) and '__pydantic_validator__' in __cls.__dict__ diff --git a/lib/pydantic/datetime_parse.py b/lib/pydantic/datetime_parse.py index cfd54593..902219df 100644 --- a/lib/pydantic/datetime_parse.py +++ b/lib/pydantic/datetime_parse.py @@ -1,248 +1,4 @@ -""" -Functions to parse datetime objects. +"""The `datetime_parse` module is a backport module from V1.""" +from ._migration import getattr_migration -We're using regular expressions rather than time.strptime because: -- They provide both validation and parsing. -- They're more flexible for datetimes. -- The date/datetime/time constructors produce friendlier error messages. - -Stolen from https://raw.githubusercontent.com/django/django/main/django/utils/dateparse.py at -9718fa2e8abe430c3526a9278dd976443d4ae3c6 - -Changed to: -* use standard python datetime types not django.utils.timezone -* raise ValueError when regex doesn't match rather than returning None -* support parsing unix timestamps for dates and datetimes -""" -import re -from datetime import date, datetime, time, timedelta, timezone -from typing import Dict, Optional, Type, Union - -from . import errors - -date_expr = r'(?P\d{4})-(?P\d{1,2})-(?P\d{1,2})' -time_expr = ( - r'(?P\d{1,2}):(?P\d{1,2})' - r'(?::(?P\d{1,2})(?:\.(?P\d{1,6})\d{0,6})?)?' - r'(?PZ|[+-]\d{2}(?::?\d{2})?)?$' -) - -date_re = re.compile(f'{date_expr}$') -time_re = re.compile(time_expr) -datetime_re = re.compile(f'{date_expr}[T ]{time_expr}') - -standard_duration_re = re.compile( - r'^' - r'(?:(?P-?\d+) (days?, )?)?' - r'((?:(?P-?\d+):)(?=\d+:\d+))?' - r'(?:(?P-?\d+):)?' - r'(?P-?\d+)' - r'(?:\.(?P\d{1,6})\d{0,6})?' - r'$' -) - -# Support the sections of ISO 8601 date representation that are accepted by timedelta -iso8601_duration_re = re.compile( - r'^(?P[-+]?)' - r'P' - r'(?:(?P\d+(.\d+)?)D)?' - r'(?:T' - r'(?:(?P\d+(.\d+)?)H)?' - r'(?:(?P\d+(.\d+)?)M)?' - r'(?:(?P\d+(.\d+)?)S)?' - r')?' 
- r'$' -) - -EPOCH = datetime(1970, 1, 1) -# if greater than this, the number is in ms, if less than or equal it's in seconds -# (in seconds this is 11th October 2603, in ms it's 20th August 1970) -MS_WATERSHED = int(2e10) -# slightly more than datetime.max in ns - (datetime.max - EPOCH).total_seconds() * 1e9 -MAX_NUMBER = int(3e20) -StrBytesIntFloat = Union[str, bytes, int, float] - - -def get_numeric(value: StrBytesIntFloat, native_expected_type: str) -> Union[None, int, float]: - if isinstance(value, (int, float)): - return value - try: - return float(value) - except ValueError: - return None - except TypeError: - raise TypeError(f'invalid type; expected {native_expected_type}, string, bytes, int or float') - - -def from_unix_seconds(seconds: Union[int, float]) -> datetime: - if seconds > MAX_NUMBER: - return datetime.max - elif seconds < -MAX_NUMBER: - return datetime.min - - while abs(seconds) > MS_WATERSHED: - seconds /= 1000 - dt = EPOCH + timedelta(seconds=seconds) - return dt.replace(tzinfo=timezone.utc) - - -def _parse_timezone(value: Optional[str], error: Type[Exception]) -> Union[None, int, timezone]: - if value == 'Z': - return timezone.utc - elif value is not None: - offset_mins = int(value[-2:]) if len(value) > 3 else 0 - offset = 60 * int(value[1:3]) + offset_mins - if value[0] == '-': - offset = -offset - try: - return timezone(timedelta(minutes=offset)) - except ValueError: - raise error() - else: - return None - - -def parse_date(value: Union[date, StrBytesIntFloat]) -> date: - """ - Parse a date/int/float/string and return a datetime.date. - - Raise ValueError if the input is well formatted but not a valid date. - Raise ValueError if the input isn't well formatted. - """ - if isinstance(value, date): - if isinstance(value, datetime): - return value.date() - else: - return value - - number = get_numeric(value, 'date') - if number is not None: - return from_unix_seconds(number).date() - - if isinstance(value, bytes): - value = value.decode() - - match = date_re.match(value) # type: ignore - if match is None: - raise errors.DateError() - - kw = {k: int(v) for k, v in match.groupdict().items()} - - try: - return date(**kw) - except ValueError: - raise errors.DateError() - - -def parse_time(value: Union[time, StrBytesIntFloat]) -> time: - """ - Parse a time/string and return a datetime.time. - - Raise ValueError if the input is well formatted but not a valid time. - Raise ValueError if the input isn't well formatted, in particular if it contains an offset. - """ - if isinstance(value, time): - return value - - number = get_numeric(value, 'time') - if number is not None: - if number >= 86400: - # doesn't make sense since the time time loop back around to 0 - raise errors.TimeError() - return (datetime.min + timedelta(seconds=number)).time() - - if isinstance(value, bytes): - value = value.decode() - - match = time_re.match(value) # type: ignore - if match is None: - raise errors.TimeError() - - kw = match.groupdict() - if kw['microsecond']: - kw['microsecond'] = kw['microsecond'].ljust(6, '0') - - tzinfo = _parse_timezone(kw.pop('tzinfo'), errors.TimeError) - kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None} - kw_['tzinfo'] = tzinfo - - try: - return time(**kw_) # type: ignore - except ValueError: - raise errors.TimeError() - - -def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: - """ - Parse a datetime/int/float/string and return a datetime.datetime. - - This function supports time zone offsets. 
When the input contains one, - the output uses a timezone with a fixed offset from UTC. - - Raise ValueError if the input is well formatted but not a valid datetime. - Raise ValueError if the input isn't well formatted. - """ - if isinstance(value, datetime): - return value - - number = get_numeric(value, 'datetime') - if number is not None: - return from_unix_seconds(number) - - if isinstance(value, bytes): - value = value.decode() - - match = datetime_re.match(value) # type: ignore - if match is None: - raise errors.DateTimeError() - - kw = match.groupdict() - if kw['microsecond']: - kw['microsecond'] = kw['microsecond'].ljust(6, '0') - - tzinfo = _parse_timezone(kw.pop('tzinfo'), errors.DateTimeError) - kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None} - kw_['tzinfo'] = tzinfo - - try: - return datetime(**kw_) # type: ignore - except ValueError: - raise errors.DateTimeError() - - -def parse_duration(value: StrBytesIntFloat) -> timedelta: - """ - Parse a duration int/float/string and return a datetime.timedelta. - - The preferred format for durations in Django is '%d %H:%M:%S.%f'. - - Also supports ISO 8601 representation. - """ - if isinstance(value, timedelta): - return value - - if isinstance(value, (int, float)): - # below code requires a string - value = f'{value:f}' - elif isinstance(value, bytes): - value = value.decode() - - try: - match = standard_duration_re.match(value) or iso8601_duration_re.match(value) - except TypeError: - raise TypeError('invalid type; expected timedelta, string, bytes, int or float') - - if not match: - raise errors.DurationError() - - kw = match.groupdict() - sign = -1 if kw.pop('sign', '+') == '-' else 1 - if kw.get('microseconds'): - kw['microseconds'] = kw['microseconds'].ljust(6, '0') - - if kw.get('seconds') and kw.get('microseconds') and kw['seconds'].startswith('-'): - kw['microseconds'] = '-' + kw['microseconds'] - - kw_ = {k: float(v) for k, v in kw.items() if v is not None} - - return sign * timedelta(**kw_) +__getattr__ = getattr_migration(__name__) diff --git a/lib/pydantic/decorator.py b/lib/pydantic/decorator.py index 089aab65..c3643468 100644 --- a/lib/pydantic/decorator.py +++ b/lib/pydantic/decorator.py @@ -1,264 +1,4 @@ -from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload +"""The `decorator` module is a backport module from V1.""" +from ._migration import getattr_migration -from . import validator -from .config import Extra -from .errors import ConfigError -from .main import BaseModel, create_model -from .typing import get_all_type_hints -from .utils import to_camel - -__all__ = ('validate_arguments',) - -if TYPE_CHECKING: - from .typing import AnyCallable - - AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable) - ConfigType = Union[None, Type[Any], Dict[str, Any]] - - -@overload -def validate_arguments(func: None = None, *, config: 'ConfigType' = None) -> Callable[['AnyCallableT'], 'AnyCallableT']: - ... - - -@overload -def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT': - ... - - -def validate_arguments(func: Optional['AnyCallableT'] = None, *, config: 'ConfigType' = None) -> Any: - """ - Decorator to validate the arguments passed to a function. 
- """ - - def validate(_func: 'AnyCallable') -> 'AnyCallable': - vd = ValidatedFunction(_func, config) - - @wraps(_func) - def wrapper_function(*args: Any, **kwargs: Any) -> Any: - return vd.call(*args, **kwargs) - - wrapper_function.vd = vd # type: ignore - wrapper_function.validate = vd.init_model_instance # type: ignore - wrapper_function.raw_function = vd.raw_function # type: ignore - wrapper_function.model = vd.model # type: ignore - return wrapper_function - - if func: - return validate(func) - else: - return validate - - -ALT_V_ARGS = 'v__args' -ALT_V_KWARGS = 'v__kwargs' -V_POSITIONAL_ONLY_NAME = 'v__positional_only' -V_DUPLICATE_KWARGS = 'v__duplicate_kwargs' - - -class ValidatedFunction: - def __init__(self, function: 'AnyCallableT', config: 'ConfigType'): # noqa C901 - from inspect import Parameter, signature - - parameters: Mapping[str, Parameter] = signature(function).parameters - - if parameters.keys() & {ALT_V_ARGS, ALT_V_KWARGS, V_POSITIONAL_ONLY_NAME, V_DUPLICATE_KWARGS}: - raise ConfigError( - f'"{ALT_V_ARGS}", "{ALT_V_KWARGS}", "{V_POSITIONAL_ONLY_NAME}" and "{V_DUPLICATE_KWARGS}" ' - f'are not permitted as argument names when using the "{validate_arguments.__name__}" decorator' - ) - - self.raw_function = function - self.arg_mapping: Dict[int, str] = {} - self.positional_only_args = set() - self.v_args_name = 'args' - self.v_kwargs_name = 'kwargs' - - type_hints = get_all_type_hints(function) - takes_args = False - takes_kwargs = False - fields: Dict[str, Tuple[Any, Any]] = {} - for i, (name, p) in enumerate(parameters.items()): - if p.annotation is p.empty: - annotation = Any - else: - annotation = type_hints[name] - - default = ... if p.default is p.empty else p.default - if p.kind == Parameter.POSITIONAL_ONLY: - self.arg_mapping[i] = name - fields[name] = annotation, default - fields[V_POSITIONAL_ONLY_NAME] = List[str], None - self.positional_only_args.add(name) - elif p.kind == Parameter.POSITIONAL_OR_KEYWORD: - self.arg_mapping[i] = name - fields[name] = annotation, default - fields[V_DUPLICATE_KWARGS] = List[str], None - elif p.kind == Parameter.KEYWORD_ONLY: - fields[name] = annotation, default - elif p.kind == Parameter.VAR_POSITIONAL: - self.v_args_name = name - fields[name] = Tuple[annotation, ...], None - takes_args = True - else: - assert p.kind == Parameter.VAR_KEYWORD, p.kind - self.v_kwargs_name = name - fields[name] = Dict[str, annotation], None # type: ignore - takes_kwargs = True - - # these checks avoid a clash between "args" and a field with that name - if not takes_args and self.v_args_name in fields: - self.v_args_name = ALT_V_ARGS - - # same with "kwargs" - if not takes_kwargs and self.v_kwargs_name in fields: - self.v_kwargs_name = ALT_V_KWARGS - - if not takes_args: - # we add the field so validation below can raise the correct exception - fields[self.v_args_name] = List[Any], None - - if not takes_kwargs: - # same with kwargs - fields[self.v_kwargs_name] = Dict[Any, Any], None - - self.create_model(fields, takes_args, takes_kwargs, config) - - def init_model_instance(self, *args: Any, **kwargs: Any) -> BaseModel: - values = self.build_values(args, kwargs) - return self.model(**values) - - def call(self, *args: Any, **kwargs: Any) -> Any: - m = self.init_model_instance(*args, **kwargs) - return self.execute(m) - - def build_values(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]: - values: Dict[str, Any] = {} - if args: - arg_iter = enumerate(args) - while True: - try: - i, a = next(arg_iter) - except StopIteration: - break 
- arg_name = self.arg_mapping.get(i) - if arg_name is not None: - values[arg_name] = a - else: - values[self.v_args_name] = [a] + [a for _, a in arg_iter] - break - - var_kwargs: Dict[str, Any] = {} - wrong_positional_args = [] - duplicate_kwargs = [] - fields_alias = [ - field.alias - for name, field in self.model.__fields__.items() - if name not in (self.v_args_name, self.v_kwargs_name) - ] - non_var_fields = set(self.model.__fields__) - {self.v_args_name, self.v_kwargs_name} - for k, v in kwargs.items(): - if k in non_var_fields or k in fields_alias: - if k in self.positional_only_args: - wrong_positional_args.append(k) - if k in values: - duplicate_kwargs.append(k) - values[k] = v - else: - var_kwargs[k] = v - - if var_kwargs: - values[self.v_kwargs_name] = var_kwargs - if wrong_positional_args: - values[V_POSITIONAL_ONLY_NAME] = wrong_positional_args - if duplicate_kwargs: - values[V_DUPLICATE_KWARGS] = duplicate_kwargs - return values - - def execute(self, m: BaseModel) -> Any: - d = {k: v for k, v in m._iter() if k in m.__fields_set__ or m.__fields__[k].default_factory} - var_kwargs = d.pop(self.v_kwargs_name, {}) - - if self.v_args_name in d: - args_: List[Any] = [] - in_kwargs = False - kwargs = {} - for name, value in d.items(): - if in_kwargs: - kwargs[name] = value - elif name == self.v_args_name: - args_ += value - in_kwargs = True - else: - args_.append(value) - return self.raw_function(*args_, **kwargs, **var_kwargs) - elif self.positional_only_args: - args_ = [] - kwargs = {} - for name, value in d.items(): - if name in self.positional_only_args: - args_.append(value) - else: - kwargs[name] = value - return self.raw_function(*args_, **kwargs, **var_kwargs) - else: - return self.raw_function(**d, **var_kwargs) - - def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None: - pos_args = len(self.arg_mapping) - - class CustomConfig: - pass - - if not TYPE_CHECKING: # pragma: no branch - if isinstance(config, dict): - CustomConfig = type('Config', (), config) # noqa: F811 - elif config is not None: - CustomConfig = config # noqa: F811 - - if hasattr(CustomConfig, 'fields') or hasattr(CustomConfig, 'alias_generator'): - raise ConfigError( - 'Setting the "fields" and "alias_generator" property on custom Config for ' - '@validate_arguments is not yet supported, please remove.' 
- ) - - class DecoratorBaseModel(BaseModel): - @validator(self.v_args_name, check_fields=False, allow_reuse=True) - def check_args(cls, v: Optional[List[Any]]) -> Optional[List[Any]]: - if takes_args or v is None: - return v - - raise TypeError(f'{pos_args} positional arguments expected but {pos_args + len(v)} given') - - @validator(self.v_kwargs_name, check_fields=False, allow_reuse=True) - def check_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: - if takes_kwargs or v is None: - return v - - plural = '' if len(v) == 1 else 's' - keys = ', '.join(map(repr, v.keys())) - raise TypeError(f'unexpected keyword argument{plural}: {keys}') - - @validator(V_POSITIONAL_ONLY_NAME, check_fields=False, allow_reuse=True) - def check_positional_only(cls, v: Optional[List[str]]) -> None: - if v is None: - return - - plural = '' if len(v) == 1 else 's' - keys = ', '.join(map(repr, v)) - raise TypeError(f'positional-only argument{plural} passed as keyword argument{plural}: {keys}') - - @validator(V_DUPLICATE_KWARGS, check_fields=False, allow_reuse=True) - def check_duplicate_kwargs(cls, v: Optional[List[str]]) -> None: - if v is None: - return - - plural = '' if len(v) == 1 else 's' - keys = ', '.join(map(repr, v)) - raise TypeError(f'multiple values for argument{plural}: {keys}') - - class Config(CustomConfig): - extra = getattr(CustomConfig, 'extra', Extra.forbid) - - self.model = create_model(to_camel(self.raw_function.__name__), __base__=DecoratorBaseModel, **fields) +__getattr__ = getattr_migration(__name__) diff --git a/lib/pydantic/deprecated/__init__.py b/lib/pydantic/deprecated/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lib/pydantic/deprecated/class_validators.py b/lib/pydantic/deprecated/class_validators.py new file mode 100644 index 00000000..7b48afd2 --- /dev/null +++ b/lib/pydantic/deprecated/class_validators.py @@ -0,0 +1,253 @@ +"""Old `@validator` and `@root_validator` function validators from V1.""" + +from __future__ import annotations as _annotations + +from functools import partial, partialmethod +from types import FunctionType +from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, overload +from warnings import warn + +from typing_extensions import Literal, Protocol, TypeAlias + +from .._internal import _decorators, _decorators_v1 +from ..errors import PydanticUserError +from ..warnings import PydanticDeprecatedSince20 + +_ALLOW_REUSE_WARNING_MESSAGE = '`allow_reuse` is deprecated and will be ignored; it should no longer be necessary' + + +if TYPE_CHECKING: + + class _OnlyValueValidatorClsMethod(Protocol): + def __call__(self, __cls: Any, __value: Any) -> Any: + ... + + class _V1ValidatorWithValuesClsMethod(Protocol): + def __call__(self, __cls: Any, __value: Any, values: dict[str, Any]) -> Any: + ... + + class _V1ValidatorWithValuesKwOnlyClsMethod(Protocol): + def __call__(self, __cls: Any, __value: Any, *, values: dict[str, Any]) -> Any: + ... + + class _V1ValidatorWithKwargsClsMethod(Protocol): + def __call__(self, __cls: Any, **kwargs: Any) -> Any: + ... + + class _V1ValidatorWithValuesAndKwargsClsMethod(Protocol): + def __call__(self, __cls: Any, values: dict[str, Any], **kwargs: Any) -> Any: + ... + + class _V1RootValidatorClsMethod(Protocol): + def __call__( + self, __cls: Any, __values: _decorators_v1.RootValidatorValues + ) -> _decorators_v1.RootValidatorValues: + ... 
+ + V1Validator = Union[ + _OnlyValueValidatorClsMethod, + _V1ValidatorWithValuesClsMethod, + _V1ValidatorWithValuesKwOnlyClsMethod, + _V1ValidatorWithKwargsClsMethod, + _V1ValidatorWithValuesAndKwargsClsMethod, + _decorators_v1.V1ValidatorWithValues, + _decorators_v1.V1ValidatorWithValuesKwOnly, + _decorators_v1.V1ValidatorWithKwargs, + _decorators_v1.V1ValidatorWithValuesAndKwargs, + ] + + V1RootValidator = Union[ + _V1RootValidatorClsMethod, + _decorators_v1.V1RootValidatorFunction, + ] + + _PartialClsOrStaticMethod: TypeAlias = Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any]] + + # Allow both a V1 (assumed pre=False) or V2 (assumed mode='after') validator + # We lie to type checkers and say we return the same thing we get + # but in reality we return a proxy object that _mostly_ behaves like the wrapped thing + _V1ValidatorType = TypeVar('_V1ValidatorType', V1Validator, _PartialClsOrStaticMethod) + _V1RootValidatorFunctionType = TypeVar( + '_V1RootValidatorFunctionType', + _decorators_v1.V1RootValidatorFunction, + _V1RootValidatorClsMethod, + _PartialClsOrStaticMethod, + ) +else: + # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915 + # and https://youtrack.jetbrains.com/issue/PY-51428 + DeprecationWarning = PydanticDeprecatedSince20 + + +def validator( + __field: str, + *fields: str, + pre: bool = False, + each_item: bool = False, + always: bool = False, + check_fields: bool | None = None, + allow_reuse: bool = False, +) -> Callable[[_V1ValidatorType], _V1ValidatorType]: + """Decorate methods on the class indicating that they should be used to validate fields. + + Args: + __field (str): The first field the validator should be called on; this is separate + from `fields` to ensure an error is raised if you don't pass at least one. + *fields (str): Additional field(s) the validator should be called on. + pre (bool, optional): Whether this validator should be called before the standard + validators (else after). Defaults to False. + each_item (bool, optional): For complex objects (sets, lists etc.) whether to validate + individual elements rather than the whole object. Defaults to False. + always (bool, optional): Whether this method and other validators should be called even if + the value is missing. Defaults to False. + check_fields (bool | None, optional): Whether to check that the fields actually exist on the model. + Defaults to None. + allow_reuse (bool, optional): Whether to track and raise an error if another validator refers to + the decorated function. Defaults to False. + + Returns: + Callable: A decorator that can be used to decorate a + function to be used as a validator. + """ + if allow_reuse is True: # pragma: no cover + warn(_ALLOW_REUSE_WARNING_MESSAGE, DeprecationWarning) + fields = tuple((__field, *fields)) + if isinstance(fields[0], FunctionType): + raise PydanticUserError( + '`@validator` should be used with fields and keyword arguments, not bare. ' + "E.g. usage should be `@validator('', ...)`", + code='validator-no-fields', + ) + elif not all(isinstance(field, str) for field in fields): + raise PydanticUserError( + '`@validator` fields should be passed as separate string args. ' + "E.g. usage should be `@validator('', '', ...)`", + code='validator-invalid-fields', + ) + + warn( + 'Pydantic V1 style `@validator` validators are deprecated.' 
+ ' You should migrate to Pydantic V2 style `@field_validator` validators,' + ' see the migration guide for more details', + DeprecationWarning, + stacklevel=2, + ) + + mode: Literal['before', 'after'] = 'before' if pre is True else 'after' + + def dec(f: Any) -> _decorators.PydanticDescriptorProxy[Any]: + if _decorators.is_instance_method_from_sig(f): + raise PydanticUserError( + '`@validator` cannot be applied to instance methods', code='validator-instance-method' + ) + # auto apply the @classmethod decorator + f = _decorators.ensure_classmethod_based_on_signature(f) + wrap = _decorators_v1.make_generic_v1_field_validator + validator_wrapper_info = _decorators.ValidatorDecoratorInfo( + fields=fields, + mode=mode, + each_item=each_item, + always=always, + check_fields=check_fields, + ) + return _decorators.PydanticDescriptorProxy(f, validator_wrapper_info, shim=wrap) + + return dec # type: ignore[return-value] + + +@overload +def root_validator( + *, + # if you don't specify `pre` the default is `pre=False` + # which means you need to specify `skip_on_failure=True` + skip_on_failure: Literal[True], + allow_reuse: bool = ..., +) -> Callable[ + [_V1RootValidatorFunctionType], + _V1RootValidatorFunctionType, +]: + ... + + +@overload +def root_validator( + *, + # if you specify `pre=True` then you don't need to specify + # `skip_on_failure`, in fact it is not allowed as an argument! + pre: Literal[True], + allow_reuse: bool = ..., +) -> Callable[ + [_V1RootValidatorFunctionType], + _V1RootValidatorFunctionType, +]: + ... + + +@overload +def root_validator( + *, + # if you explicitly specify `pre=False` then you + # MUST specify `skip_on_failure=True` + pre: Literal[False], + skip_on_failure: Literal[True], + allow_reuse: bool = ..., +) -> Callable[ + [_V1RootValidatorFunctionType], + _V1RootValidatorFunctionType, +]: + ... + + +def root_validator( + *__args, + pre: bool = False, + skip_on_failure: bool = False, + allow_reuse: bool = False, +) -> Any: + """Decorate methods on a model indicating that they should be used to validate (and perhaps + modify) data either before or after standard model parsing/validation is performed. + + Args: + pre (bool, optional): Whether this validator should be called before the standard + validators (else after). Defaults to False. + skip_on_failure (bool, optional): Whether to stop validation and return as soon as a + failure is encountered. Defaults to False. + allow_reuse (bool, optional): Whether to track and raise an error if another validator + refers to the decorated function. Defaults to False. + + Returns: + Any: A decorator that can be used to decorate a function to be used as a root_validator. + """ + warn( + 'Pydantic V1 style `@root_validator` validators are deprecated.' + ' You should migrate to Pydantic V2 style `@model_validator` validators,' + ' see the migration guide for more details', + DeprecationWarning, + stacklevel=2, + ) + + if __args: + # Ensure a nice error is raised if someone attempts to use the bare decorator + return root_validator()(*__args) # type: ignore + + if allow_reuse is True: # pragma: no cover + warn(_ALLOW_REUSE_WARNING_MESSAGE, DeprecationWarning) + mode: Literal['before', 'after'] = 'before' if pre is True else 'after' + if pre is False and skip_on_failure is not True: + raise PydanticUserError( + 'If you use `@root_validator` with pre=False (the default) you MUST specify `skip_on_failure=True`.' 
+ ' Note that `@root_validator` is deprecated and should be replaced with `@model_validator`.', + code='root-validator-pre-skip', + ) + + wrap = partial(_decorators_v1.make_v1_generic_root_validator, pre=pre) + + def dec(f: Callable[..., Any] | classmethod[Any, Any, Any] | staticmethod[Any, Any]) -> Any: + if _decorators.is_instance_method_from_sig(f): + raise TypeError('`@root_validator` cannot be applied to instance methods') + # auto apply the @classmethod decorator + res = _decorators.ensure_classmethod_based_on_signature(f) + dec_info = _decorators.RootValidatorDecoratorInfo(mode=mode) + return _decorators.PydanticDescriptorProxy(res, dec_info, shim=wrap) + + return dec diff --git a/lib/pydantic/deprecated/config.py b/lib/pydantic/deprecated/config.py new file mode 100644 index 00000000..45400c65 --- /dev/null +++ b/lib/pydantic/deprecated/config.py @@ -0,0 +1,72 @@ +from __future__ import annotations as _annotations + +import warnings +from typing import TYPE_CHECKING, Any + +from typing_extensions import Literal, deprecated + +from .._internal import _config +from ..warnings import PydanticDeprecatedSince20 + +if not TYPE_CHECKING: + # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915 + # and https://youtrack.jetbrains.com/issue/PY-51428 + DeprecationWarning = PydanticDeprecatedSince20 + +__all__ = 'BaseConfig', 'Extra' + + +class _ConfigMetaclass(type): + def __getattr__(self, item: str) -> Any: + try: + obj = _config.config_defaults[item] + warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning) + return obj + except KeyError as exc: + raise AttributeError(f"type object '{self.__name__}' has no attribute {exc}") from exc + + +@deprecated('BaseConfig is deprecated. Use the `pydantic.ConfigDict` instead.', category=PydanticDeprecatedSince20) +class BaseConfig(metaclass=_ConfigMetaclass): + """This class is only retained for backwards compatibility. + + !!! Warning "Deprecated" + BaseConfig is deprecated. Use the [`pydantic.ConfigDict`][pydantic.ConfigDict] instead. + """ + + def __getattr__(self, item: str) -> Any: + try: + obj = super().__getattribute__(item) + warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning) + return obj + except AttributeError as exc: + try: + return getattr(type(self), item) + except AttributeError: + # re-raising changes the displayed text to reflect that `self` is not a type + raise AttributeError(str(exc)) from exc + + def __init_subclass__(cls, **kwargs: Any) -> None: + warnings.warn(_config.DEPRECATION_MESSAGE, DeprecationWarning) + return super().__init_subclass__(**kwargs) + + +class _ExtraMeta(type): + def __getattribute__(self, __name: str) -> Any: + # The @deprecated decorator accesses other attributes, so we only emit a warning for the expected ones + if __name in {'allow', 'ignore', 'forbid'}: + warnings.warn( + "`pydantic.config.Extra` is deprecated, use literal values instead (e.g. `extra='allow'`)", + DeprecationWarning, + stacklevel=2, + ) + return super().__getattribute__(__name) + + +@deprecated( + "Extra is deprecated. Use literal values instead (e.g. 
`extra='allow'`)", category=PydanticDeprecatedSince20 +) +class Extra(metaclass=_ExtraMeta): + allow: Literal['allow'] = 'allow' + ignore: Literal['ignore'] = 'ignore' + forbid: Literal['forbid'] = 'forbid' diff --git a/lib/pydantic/deprecated/copy_internals.py b/lib/pydantic/deprecated/copy_internals.py new file mode 100644 index 00000000..efe5de28 --- /dev/null +++ b/lib/pydantic/deprecated/copy_internals.py @@ -0,0 +1,224 @@ +from __future__ import annotations as _annotations + +import typing +from copy import deepcopy +from enum import Enum +from typing import Any, Tuple + +import typing_extensions + +from .._internal import ( + _model_construction, + _typing_extra, + _utils, +) + +if typing.TYPE_CHECKING: + from .. import BaseModel + from .._internal._utils import AbstractSetIntStr, MappingIntStrAny + + AnyClassMethod = classmethod[Any, Any, Any] + TupleGenerator = typing.Generator[Tuple[str, Any], None, None] + Model = typing.TypeVar('Model', bound='BaseModel') + # should be `set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None`, but mypy can't cope + IncEx: typing_extensions.TypeAlias = 'set[int] | set[str] | dict[int, Any] | dict[str, Any] | None' + +_object_setattr = _model_construction.object_setattr + + +def _iter( + self: BaseModel, + to_dict: bool = False, + by_alias: bool = False, + include: AbstractSetIntStr | MappingIntStrAny | None = None, + exclude: AbstractSetIntStr | MappingIntStrAny | None = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, +) -> TupleGenerator: + # Merge field set excludes with explicit exclude parameter with explicit overriding field set options. + # The extra "is not None" guards are not logically necessary but optimizes performance for the simple case. 
+ if exclude is not None: + exclude = _utils.ValueItems.merge( + {k: v.exclude for k, v in self.model_fields.items() if v.exclude is not None}, exclude + ) + + if include is not None: + include = _utils.ValueItems.merge({k: True for k in self.model_fields}, include, intersect=True) + + allowed_keys = _calculate_keys(self, include=include, exclude=exclude, exclude_unset=exclude_unset) # type: ignore + if allowed_keys is None and not (to_dict or by_alias or exclude_unset or exclude_defaults or exclude_none): + # huge boost for plain _iter() + yield from self.__dict__.items() + if self.__pydantic_extra__: + yield from self.__pydantic_extra__.items() + return + + value_exclude = _utils.ValueItems(self, exclude) if exclude is not None else None + value_include = _utils.ValueItems(self, include) if include is not None else None + + if self.__pydantic_extra__ is None: + items = self.__dict__.items() + else: + items = list(self.__dict__.items()) + list(self.__pydantic_extra__.items()) + + for field_key, v in items: + if (allowed_keys is not None and field_key not in allowed_keys) or (exclude_none and v is None): + continue + + if exclude_defaults: + try: + field = self.model_fields[field_key] + except KeyError: + pass + else: + if not field.is_required() and field.default == v: + continue + + if by_alias and field_key in self.model_fields: + dict_key = self.model_fields[field_key].alias or field_key + else: + dict_key = field_key + + if to_dict or value_include or value_exclude: + v = _get_value( + type(self), + v, + to_dict=to_dict, + by_alias=by_alias, + include=value_include and value_include.for_element(field_key), + exclude=value_exclude and value_exclude.for_element(field_key), + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + yield dict_key, v + + +def _copy_and_set_values( + self: Model, + values: dict[str, Any], + fields_set: set[str], + extra: dict[str, Any] | None = None, + private: dict[str, Any] | None = None, + *, + deep: bool, # UP006 +) -> Model: + if deep: + # chances of having empty dict here are quite low for using smart_deepcopy + values = deepcopy(values) + extra = deepcopy(extra) + private = deepcopy(private) + + cls = self.__class__ + m = cls.__new__(cls) + _object_setattr(m, '__dict__', values) + _object_setattr(m, '__pydantic_extra__', extra) + _object_setattr(m, '__pydantic_fields_set__', fields_set) + _object_setattr(m, '__pydantic_private__', private) + + return m + + +@typing.no_type_check +def _get_value( + cls: type[BaseModel], + v: Any, + to_dict: bool, + by_alias: bool, + include: AbstractSetIntStr | MappingIntStrAny | None, + exclude: AbstractSetIntStr | MappingIntStrAny | None, + exclude_unset: bool, + exclude_defaults: bool, + exclude_none: bool, +) -> Any: + from .. 
import BaseModel + + if isinstance(v, BaseModel): + if to_dict: + return v.model_dump( + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + include=include, # type: ignore + exclude=exclude, # type: ignore + exclude_none=exclude_none, + ) + else: + return v.copy(include=include, exclude=exclude) + + value_exclude = _utils.ValueItems(v, exclude) if exclude else None + value_include = _utils.ValueItems(v, include) if include else None + + if isinstance(v, dict): + return { + k_: _get_value( + cls, + v_, + to_dict=to_dict, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + include=value_include and value_include.for_element(k_), + exclude=value_exclude and value_exclude.for_element(k_), + exclude_none=exclude_none, + ) + for k_, v_ in v.items() + if (not value_exclude or not value_exclude.is_excluded(k_)) + and (not value_include or value_include.is_included(k_)) + } + + elif _utils.sequence_like(v): + seq_args = ( + _get_value( + cls, + v_, + to_dict=to_dict, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + include=value_include and value_include.for_element(i), + exclude=value_exclude and value_exclude.for_element(i), + exclude_none=exclude_none, + ) + for i, v_ in enumerate(v) + if (not value_exclude or not value_exclude.is_excluded(i)) + and (not value_include or value_include.is_included(i)) + ) + + return v.__class__(*seq_args) if _typing_extra.is_namedtuple(v.__class__) else v.__class__(seq_args) + + elif isinstance(v, Enum) and getattr(cls.model_config, 'use_enum_values', False): + return v.value + + else: + return v + + +def _calculate_keys( + self: BaseModel, + include: MappingIntStrAny | None, + exclude: MappingIntStrAny | None, + exclude_unset: bool, + update: typing.Dict[str, Any] | None = None, # noqa UP006 +) -> typing.AbstractSet[str] | None: + if include is None and exclude is None and exclude_unset is False: + return None + + keys: typing.AbstractSet[str] + if exclude_unset: + keys = self.__pydantic_fields_set__.copy() + else: + keys = set(self.__dict__.keys()) + keys = keys | (self.__pydantic_extra__ or {}).keys() + + if include is not None: + keys &= include.keys() + + if update: + keys -= update.keys() + + if exclude: + keys -= {k for k, v in exclude.items() if _utils.ValueItems.is_true(v)} + + return keys diff --git a/lib/pydantic/deprecated/decorator.py b/lib/pydantic/deprecated/decorator.py new file mode 100644 index 00000000..36bd0690 --- /dev/null +++ b/lib/pydantic/deprecated/decorator.py @@ -0,0 +1,279 @@ +import warnings +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload + +from typing_extensions import deprecated + +from .._internal import _config, _typing_extra +from ..alias_generators import to_pascal +from ..errors import PydanticUserError +from ..functional_validators import field_validator +from ..main import BaseModel, create_model +from ..warnings import PydanticDeprecatedSince20 + +if not TYPE_CHECKING: + # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915 + # and https://youtrack.jetbrains.com/issue/PY-51428 + DeprecationWarning = PydanticDeprecatedSince20 + +__all__ = ('validate_arguments',) + +if TYPE_CHECKING: + AnyCallable = Callable[..., Any] + + AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable) + ConfigType = Union[None, Type[Any], Dict[str, Any]] + + +@overload +def validate_arguments(func: None = None, *, 
config: 'ConfigType' = None) -> Callable[['AnyCallableT'], 'AnyCallableT']: + ... + + +@overload +def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT': + ... + + +@deprecated( + 'The `validate_arguments` method is deprecated; use `validate_call` instead.', + category=None, +) +def validate_arguments(func: Optional['AnyCallableT'] = None, *, config: 'ConfigType' = None) -> Any: + """Decorator to validate the arguments passed to a function.""" + warnings.warn( + 'The `validate_arguments` method is deprecated; use `validate_call` instead.', + PydanticDeprecatedSince20, + stacklevel=2, + ) + + def validate(_func: 'AnyCallable') -> 'AnyCallable': + vd = ValidatedFunction(_func, config) + + @wraps(_func) + def wrapper_function(*args: Any, **kwargs: Any) -> Any: + return vd.call(*args, **kwargs) + + wrapper_function.vd = vd # type: ignore + wrapper_function.validate = vd.init_model_instance # type: ignore + wrapper_function.raw_function = vd.raw_function # type: ignore + wrapper_function.model = vd.model # type: ignore + return wrapper_function + + if func: + return validate(func) + else: + return validate + + +ALT_V_ARGS = 'v__args' +ALT_V_KWARGS = 'v__kwargs' +V_POSITIONAL_ONLY_NAME = 'v__positional_only' +V_DUPLICATE_KWARGS = 'v__duplicate_kwargs' + + +class ValidatedFunction: + def __init__(self, function: 'AnyCallable', config: 'ConfigType'): + from inspect import Parameter, signature + + parameters: Mapping[str, Parameter] = signature(function).parameters + + if parameters.keys() & {ALT_V_ARGS, ALT_V_KWARGS, V_POSITIONAL_ONLY_NAME, V_DUPLICATE_KWARGS}: + raise PydanticUserError( + f'"{ALT_V_ARGS}", "{ALT_V_KWARGS}", "{V_POSITIONAL_ONLY_NAME}" and "{V_DUPLICATE_KWARGS}" ' + f'are not permitted as argument names when using the "{validate_arguments.__name__}" decorator', + code=None, + ) + + self.raw_function = function + self.arg_mapping: Dict[int, str] = {} + self.positional_only_args: set[str] = set() + self.v_args_name = 'args' + self.v_kwargs_name = 'kwargs' + + type_hints = _typing_extra.get_type_hints(function, include_extras=True) + takes_args = False + takes_kwargs = False + fields: Dict[str, Tuple[Any, Any]] = {} + for i, (name, p) in enumerate(parameters.items()): + if p.annotation is p.empty: + annotation = Any + else: + annotation = type_hints[name] + + default = ... 
if p.default is p.empty else p.default + if p.kind == Parameter.POSITIONAL_ONLY: + self.arg_mapping[i] = name + fields[name] = annotation, default + fields[V_POSITIONAL_ONLY_NAME] = List[str], None + self.positional_only_args.add(name) + elif p.kind == Parameter.POSITIONAL_OR_KEYWORD: + self.arg_mapping[i] = name + fields[name] = annotation, default + fields[V_DUPLICATE_KWARGS] = List[str], None + elif p.kind == Parameter.KEYWORD_ONLY: + fields[name] = annotation, default + elif p.kind == Parameter.VAR_POSITIONAL: + self.v_args_name = name + fields[name] = Tuple[annotation, ...], None + takes_args = True + else: + assert p.kind == Parameter.VAR_KEYWORD, p.kind + self.v_kwargs_name = name + fields[name] = Dict[str, annotation], None + takes_kwargs = True + + # these checks avoid a clash between "args" and a field with that name + if not takes_args and self.v_args_name in fields: + self.v_args_name = ALT_V_ARGS + + # same with "kwargs" + if not takes_kwargs and self.v_kwargs_name in fields: + self.v_kwargs_name = ALT_V_KWARGS + + if not takes_args: + # we add the field so validation below can raise the correct exception + fields[self.v_args_name] = List[Any], None + + if not takes_kwargs: + # same with kwargs + fields[self.v_kwargs_name] = Dict[Any, Any], None + + self.create_model(fields, takes_args, takes_kwargs, config) + + def init_model_instance(self, *args: Any, **kwargs: Any) -> BaseModel: + values = self.build_values(args, kwargs) + return self.model(**values) + + def call(self, *args: Any, **kwargs: Any) -> Any: + m = self.init_model_instance(*args, **kwargs) + return self.execute(m) + + def build_values(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]: + values: Dict[str, Any] = {} + if args: + arg_iter = enumerate(args) + while True: + try: + i, a = next(arg_iter) + except StopIteration: + break + arg_name = self.arg_mapping.get(i) + if arg_name is not None: + values[arg_name] = a + else: + values[self.v_args_name] = [a] + [a for _, a in arg_iter] + break + + var_kwargs: Dict[str, Any] = {} + wrong_positional_args = [] + duplicate_kwargs = [] + fields_alias = [ + field.alias + for name, field in self.model.model_fields.items() + if name not in (self.v_args_name, self.v_kwargs_name) + ] + non_var_fields = set(self.model.model_fields) - {self.v_args_name, self.v_kwargs_name} + for k, v in kwargs.items(): + if k in non_var_fields or k in fields_alias: + if k in self.positional_only_args: + wrong_positional_args.append(k) + if k in values: + duplicate_kwargs.append(k) + values[k] = v + else: + var_kwargs[k] = v + + if var_kwargs: + values[self.v_kwargs_name] = var_kwargs + if wrong_positional_args: + values[V_POSITIONAL_ONLY_NAME] = wrong_positional_args + if duplicate_kwargs: + values[V_DUPLICATE_KWARGS] = duplicate_kwargs + return values + + def execute(self, m: BaseModel) -> Any: + d = {k: v for k, v in m.__dict__.items() if k in m.__pydantic_fields_set__ or m.model_fields[k].default_factory} + var_kwargs = d.pop(self.v_kwargs_name, {}) + + if self.v_args_name in d: + args_: List[Any] = [] + in_kwargs = False + kwargs = {} + for name, value in d.items(): + if in_kwargs: + kwargs[name] = value + elif name == self.v_args_name: + args_ += value + in_kwargs = True + else: + args_.append(value) + return self.raw_function(*args_, **kwargs, **var_kwargs) + elif self.positional_only_args: + args_ = [] + kwargs = {} + for name, value in d.items(): + if name in self.positional_only_args: + args_.append(value) + else: + kwargs[name] = value + return 
self.raw_function(*args_, **kwargs, **var_kwargs)
+ else:
+ return self.raw_function(**d, **var_kwargs)
+
+ def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None:
+ pos_args = len(self.arg_mapping)
+
+ config_wrapper = _config.ConfigWrapper(config)
+
+ if config_wrapper.alias_generator:
+ raise PydanticUserError(
+ 'Setting the "alias_generator" property on custom Config for '
+ '@validate_arguments is not yet supported, please remove.',
+ code=None,
+ )
+ if config_wrapper.extra is None:
+ config_wrapper.config_dict['extra'] = 'forbid'
+
+ class DecoratorBaseModel(BaseModel):
+ @field_validator(self.v_args_name, check_fields=False)
+ @classmethod
+ def check_args(cls, v: Optional[List[Any]]) -> Optional[List[Any]]:
+ if takes_args or v is None:
+ return v
+
+ raise TypeError(f'{pos_args} positional arguments expected but {pos_args + len(v)} given')
+
+ @field_validator(self.v_kwargs_name, check_fields=False)
+ @classmethod
+ def check_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+ if takes_kwargs or v is None:
+ return v
+
+ plural = '' if len(v) == 1 else 's'
+ keys = ', '.join(map(repr, v.keys()))
+ raise TypeError(f'unexpected keyword argument{plural}: {keys}')
+
+ @field_validator(V_POSITIONAL_ONLY_NAME, check_fields=False)
+ @classmethod
+ def check_positional_only(cls, v: Optional[List[str]]) -> None:
+ if v is None:
+ return
+
+ plural = '' if len(v) == 1 else 's'
+ keys = ', '.join(map(repr, v))
+ raise TypeError(f'positional-only argument{plural} passed as keyword argument{plural}: {keys}')
+
+ @field_validator(V_DUPLICATE_KWARGS, check_fields=False)
+ @classmethod
+ def check_duplicate_kwargs(cls, v: Optional[List[str]]) -> None:
+ if v is None:
+ return
+
+ plural = '' if len(v) == 1 else 's'
+ keys = ', '.join(map(repr, v))
+ raise TypeError(f'multiple values for argument{plural}: {keys}')
+
+ model_config = config_wrapper.config_dict
+
+ self.model = create_model(to_pascal(self.raw_function.__name__), __base__=DecoratorBaseModel, **fields)
diff --git a/lib/pydantic/deprecated/json.py b/lib/pydantic/deprecated/json.py
new file mode 100644
index 00000000..79e2f44a
--- /dev/null
+++ b/lib/pydantic/deprecated/json.py
@@ -0,0 +1,140 @@
+import datetime
+import warnings
+from collections import deque
+from decimal import Decimal
+from enum import Enum
+from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
+from pathlib import Path
+from re import Pattern
+from types import GeneratorType
+from typing import TYPE_CHECKING, Any, Callable, Dict, Type, Union
+from uuid import UUID
+
+from typing_extensions import deprecated
+
+from ..color import Color
+from ..networks import NameEmail
+from ..types import SecretBytes, SecretStr
+from ..warnings import PydanticDeprecatedSince20
+
+if not TYPE_CHECKING:
+ # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
+ # and https://youtrack.jetbrains.com/issue/PY-51428
+ DeprecationWarning = PydanticDeprecatedSince20
+
+__all__ = 'pydantic_encoder', 'custom_pydantic_encoder', 'timedelta_isoformat'
+
+
+def isoformat(o: Union[datetime.date, datetime.time]) -> str:
+ return o.isoformat()
+
+
+def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
+ """Encodes a Decimal as int if there's no exponent, otherwise float.
+
+ This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
+ where an integer (but not int typed) is used.
Encoding this as a float + results in failed round-tripping between encode and parse. + Our Id type is a prime example of this. + + >>> decimal_encoder(Decimal("1.0")) + 1.0 + + >>> decimal_encoder(Decimal("1")) + 1 + """ + exponent = dec_value.as_tuple().exponent + if isinstance(exponent, int) and exponent >= 0: + return int(dec_value) + else: + return float(dec_value) + + +ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { + bytes: lambda o: o.decode(), + Color: str, + datetime.date: isoformat, + datetime.datetime: isoformat, + datetime.time: isoformat, + datetime.timedelta: lambda td: td.total_seconds(), + Decimal: decimal_encoder, + Enum: lambda o: o.value, + frozenset: list, + deque: list, + GeneratorType: list, + IPv4Address: str, + IPv4Interface: str, + IPv4Network: str, + IPv6Address: str, + IPv6Interface: str, + IPv6Network: str, + NameEmail: str, + Path: str, + Pattern: lambda o: o.pattern, + SecretBytes: str, + SecretStr: str, + set: list, + UUID: str, +} + + +@deprecated( + '`pydantic_encoder` is deprecated, use `pydantic_core.to_jsonable_python` instead.', + category=None, +) +def pydantic_encoder(obj: Any) -> Any: + warnings.warn( + '`pydantic_encoder` is deprecated, use `pydantic_core.to_jsonable_python` instead.', + category=PydanticDeprecatedSince20, + stacklevel=2, + ) + from dataclasses import asdict, is_dataclass + + from ..main import BaseModel + + if isinstance(obj, BaseModel): + return obj.model_dump() + elif is_dataclass(obj): + return asdict(obj) + + # Check the class type and its superclasses for a matching encoder + for base in obj.__class__.__mro__[:-1]: + try: + encoder = ENCODERS_BY_TYPE[base] + except KeyError: + continue + return encoder(obj) + else: # We have exited the for loop without finding a suitable encoder + raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable") + + +# TODO: Add a suggested migration path once there is a way to use custom encoders +@deprecated( + '`custom_pydantic_encoder` is deprecated, use `BaseModel.model_dump` instead.', + category=None, +) +def custom_pydantic_encoder(type_encoders: Dict[Any, Callable[[Type[Any]], Any]], obj: Any) -> Any: + warnings.warn( + '`custom_pydantic_encoder` is deprecated, use `BaseModel.model_dump` instead.', + category=PydanticDeprecatedSince20, + stacklevel=2, + ) + # Check the class type and its superclasses for a matching encoder + for base in obj.__class__.__mro__[:-1]: + try: + encoder = type_encoders[base] + except KeyError: + continue + + return encoder(obj) + else: # We have exited the for loop without finding a suitable encoder + return pydantic_encoder(obj) + + +@deprecated('`timedelta_isoformat` is deprecated.', category=None) +def timedelta_isoformat(td: datetime.timedelta) -> str: + """ISO 8601 encoding for Python timedelta object.""" + warnings.warn('`timedelta_isoformat` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2) + minutes, seconds = divmod(td.seconds, 60) + hours, minutes = divmod(minutes, 60) + return f'{"-" if td.days < 0 else ""}P{abs(td.days)}DT{hours:d}H{minutes:d}M{seconds:d}.{td.microseconds:06d}S' diff --git a/lib/pydantic/deprecated/parse.py b/lib/pydantic/deprecated/parse.py new file mode 100644 index 00000000..2a92e62b --- /dev/null +++ b/lib/pydantic/deprecated/parse.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import json +import pickle +import warnings +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable + +from typing_extensions import deprecated + 
+from ..warnings import PydanticDeprecatedSince20 + +if not TYPE_CHECKING: + # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915 + # and https://youtrack.jetbrains.com/issue/PY-51428 + DeprecationWarning = PydanticDeprecatedSince20 + + +class Protocol(str, Enum): + json = 'json' + pickle = 'pickle' + + +@deprecated('`load_str_bytes` is deprecated.', category=None) +def load_str_bytes( + b: str | bytes, + *, + content_type: str | None = None, + encoding: str = 'utf8', + proto: Protocol | None = None, + allow_pickle: bool = False, + json_loads: Callable[[str], Any] = json.loads, +) -> Any: + warnings.warn('`load_str_bytes` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2) + if proto is None and content_type: + if content_type.endswith(('json', 'javascript')): + pass + elif allow_pickle and content_type.endswith('pickle'): + proto = Protocol.pickle + else: + raise TypeError(f'Unknown content-type: {content_type}') + + proto = proto or Protocol.json + + if proto == Protocol.json: + if isinstance(b, bytes): + b = b.decode(encoding) + return json_loads(b) # type: ignore + elif proto == Protocol.pickle: + if not allow_pickle: + raise RuntimeError('Trying to decode with pickle with allow_pickle=False') + bb = b if isinstance(b, bytes) else b.encode() # type: ignore + return pickle.loads(bb) + else: + raise TypeError(f'Unknown protocol: {proto}') + + +@deprecated('`load_file` is deprecated.', category=None) +def load_file( + path: str | Path, + *, + content_type: str | None = None, + encoding: str = 'utf8', + proto: Protocol | None = None, + allow_pickle: bool = False, + json_loads: Callable[[str], Any] = json.loads, +) -> Any: + warnings.warn('`load_file` is deprecated.', category=PydanticDeprecatedSince20, stacklevel=2) + path = Path(path) + b = path.read_bytes() + if content_type is None: + if path.suffix in ('.js', '.json'): + proto = Protocol.json + elif path.suffix == '.pkl': + proto = Protocol.pickle + + return load_str_bytes( + b, proto=proto, content_type=content_type, encoding=encoding, allow_pickle=allow_pickle, json_loads=json_loads + ) diff --git a/lib/pydantic/deprecated/tools.py b/lib/pydantic/deprecated/tools.py new file mode 100644 index 00000000..b04eae40 --- /dev/null +++ b/lib/pydantic/deprecated/tools.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import json +import warnings +from typing import TYPE_CHECKING, Any, Callable, Type, TypeVar, Union + +from typing_extensions import deprecated + +from ..json_schema import DEFAULT_REF_TEMPLATE, GenerateJsonSchema +from ..type_adapter import TypeAdapter +from ..warnings import PydanticDeprecatedSince20 + +if not TYPE_CHECKING: + # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915 + # and https://youtrack.jetbrains.com/issue/PY-51428 + DeprecationWarning = PydanticDeprecatedSince20 + +__all__ = 'parse_obj_as', 'schema_of', 'schema_json_of' + +NameFactory = Union[str, Callable[[Type[Any]], str]] + + +T = TypeVar('T') + + +@deprecated( + '`parse_obj_as` is deprecated. Use `pydantic.TypeAdapter.validate_python` instead.', + category=None, +) +def parse_obj_as(type_: type[T], obj: Any, type_name: NameFactory | None = None) -> T: + warnings.warn( + '`parse_obj_as` is deprecated. Use `pydantic.TypeAdapter.validate_python` instead.', + category=PydanticDeprecatedSince20, + stacklevel=2, + ) + if type_name is not None: # pragma: no cover + warnings.warn( + 'The type_name parameter is deprecated. 
parse_obj_as no longer creates temporary models', + DeprecationWarning, + stacklevel=2, + ) + return TypeAdapter(type_).validate_python(obj) + + +@deprecated( + '`schema_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.', + category=None, +) +def schema_of( + type_: Any, + *, + title: NameFactory | None = None, + by_alias: bool = True, + ref_template: str = DEFAULT_REF_TEMPLATE, + schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema, +) -> dict[str, Any]: + """Generate a JSON schema (as dict) for the passed model or dynamically generated one.""" + warnings.warn( + '`schema_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.', + category=PydanticDeprecatedSince20, + stacklevel=2, + ) + res = TypeAdapter(type_).json_schema( + by_alias=by_alias, + schema_generator=schema_generator, + ref_template=ref_template, + ) + if title is not None: + if isinstance(title, str): + res['title'] = title + else: + warnings.warn( + 'Passing a callable for the `title` parameter is deprecated and no longer supported', + DeprecationWarning, + stacklevel=2, + ) + res['title'] = title(type_) + return res + + +@deprecated( + '`schema_json_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.', + category=None, +) +def schema_json_of( + type_: Any, + *, + title: NameFactory | None = None, + by_alias: bool = True, + ref_template: str = DEFAULT_REF_TEMPLATE, + schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema, + **dumps_kwargs: Any, +) -> str: + """Generate a JSON schema (as JSON) for the passed model or dynamically generated one.""" + warnings.warn( + '`schema_json_of` is deprecated. Use `pydantic.TypeAdapter.json_schema` instead.', + category=PydanticDeprecatedSince20, + stacklevel=2, + ) + return json.dumps( + schema_of(type_, title=title, by_alias=by_alias, ref_template=ref_template, schema_generator=schema_generator), + **dumps_kwargs, + ) diff --git a/lib/pydantic/env_settings.py b/lib/pydantic/env_settings.py index e9988c01..662f5900 100644 --- a/lib/pydantic/env_settings.py +++ b/lib/pydantic/env_settings.py @@ -1,346 +1,4 @@ -import os -import warnings -from pathlib import Path -from typing import AbstractSet, Any, Callable, ClassVar, Dict, List, Mapping, Optional, Tuple, Type, Union +"""The `env_settings` module is a backport module from V1.""" +from ._migration import getattr_migration -from .config import BaseConfig, Extra -from .fields import ModelField -from .main import BaseModel -from .typing import StrPath, display_as_type, get_origin, is_union -from .utils import deep_update, path_type, sequence_like - -env_file_sentinel = str(object()) - -SettingsSourceCallable = Callable[['BaseSettings'], Dict[str, Any]] -DotenvType = Union[StrPath, List[StrPath], Tuple[StrPath, ...]] - - -class SettingsError(ValueError): - pass - - -class BaseSettings(BaseModel): - """ - Base class for settings, allowing values to be overridden by environment variables. - - This is useful in production for secrets you do not wish to save in code, it plays nicely with docker(-compose), - Heroku and any 12 factor app design. 
- """ - - def __init__( - __pydantic_self__, - _env_file: Optional[DotenvType] = env_file_sentinel, - _env_file_encoding: Optional[str] = None, - _env_nested_delimiter: Optional[str] = None, - _secrets_dir: Optional[StrPath] = None, - **values: Any, - ) -> None: - # Uses something other than `self` the first arg to allow "self" as a settable attribute - super().__init__( - **__pydantic_self__._build_values( - values, - _env_file=_env_file, - _env_file_encoding=_env_file_encoding, - _env_nested_delimiter=_env_nested_delimiter, - _secrets_dir=_secrets_dir, - ) - ) - - def _build_values( - self, - init_kwargs: Dict[str, Any], - _env_file: Optional[DotenvType] = None, - _env_file_encoding: Optional[str] = None, - _env_nested_delimiter: Optional[str] = None, - _secrets_dir: Optional[StrPath] = None, - ) -> Dict[str, Any]: - # Configure built-in sources - init_settings = InitSettingsSource(init_kwargs=init_kwargs) - env_settings = EnvSettingsSource( - env_file=(_env_file if _env_file != env_file_sentinel else self.__config__.env_file), - env_file_encoding=( - _env_file_encoding if _env_file_encoding is not None else self.__config__.env_file_encoding - ), - env_nested_delimiter=( - _env_nested_delimiter if _env_nested_delimiter is not None else self.__config__.env_nested_delimiter - ), - env_prefix_len=len(self.__config__.env_prefix), - ) - file_secret_settings = SecretsSettingsSource(secrets_dir=_secrets_dir or self.__config__.secrets_dir) - # Provide a hook to set built-in sources priority and add / remove sources - sources = self.__config__.customise_sources( - init_settings=init_settings, env_settings=env_settings, file_secret_settings=file_secret_settings - ) - if sources: - return deep_update(*reversed([source(self) for source in sources])) - else: - # no one should mean to do this, but I think returning an empty dict is marginally preferable - # to an informative error and much better than a confusing error - return {} - - class Config(BaseConfig): - env_prefix: str = '' - env_file: Optional[DotenvType] = None - env_file_encoding: Optional[str] = None - env_nested_delimiter: Optional[str] = None - secrets_dir: Optional[StrPath] = None - validate_all: bool = True - extra: Extra = Extra.forbid - arbitrary_types_allowed: bool = True - case_sensitive: bool = False - - @classmethod - def prepare_field(cls, field: ModelField) -> None: - env_names: Union[List[str], AbstractSet[str]] - field_info_from_config = cls.get_field_info(field.name) - - env = field_info_from_config.get('env') or field.field_info.extra.get('env') - if env is None: - if field.has_alias: - warnings.warn( - 'aliases are no longer used by BaseSettings to define which environment variables to read. ' - 'Instead use the "env" field setting. 
' - 'See https://pydantic-docs.helpmanual.io/usage/settings/#environment-variable-names', - FutureWarning, - ) - env_names = {cls.env_prefix + field.name} - elif isinstance(env, str): - env_names = {env} - elif isinstance(env, (set, frozenset)): - env_names = env - elif sequence_like(env): - env_names = list(env) - else: - raise TypeError(f'invalid field env: {env!r} ({display_as_type(env)}); should be string, list or set') - - if not cls.case_sensitive: - env_names = env_names.__class__(n.lower() for n in env_names) - field.field_info.extra['env_names'] = env_names - - @classmethod - def customise_sources( - cls, - init_settings: SettingsSourceCallable, - env_settings: SettingsSourceCallable, - file_secret_settings: SettingsSourceCallable, - ) -> Tuple[SettingsSourceCallable, ...]: - return init_settings, env_settings, file_secret_settings - - @classmethod - def parse_env_var(cls, field_name: str, raw_val: str) -> Any: - return cls.json_loads(raw_val) - - # populated by the metaclass using the Config class defined above, annotated here to help IDEs only - __config__: ClassVar[Type[Config]] - - -class InitSettingsSource: - __slots__ = ('init_kwargs',) - - def __init__(self, init_kwargs: Dict[str, Any]): - self.init_kwargs = init_kwargs - - def __call__(self, settings: BaseSettings) -> Dict[str, Any]: - return self.init_kwargs - - def __repr__(self) -> str: - return f'InitSettingsSource(init_kwargs={self.init_kwargs!r})' - - -class EnvSettingsSource: - __slots__ = ('env_file', 'env_file_encoding', 'env_nested_delimiter', 'env_prefix_len') - - def __init__( - self, - env_file: Optional[DotenvType], - env_file_encoding: Optional[str], - env_nested_delimiter: Optional[str] = None, - env_prefix_len: int = 0, - ): - self.env_file: Optional[DotenvType] = env_file - self.env_file_encoding: Optional[str] = env_file_encoding - self.env_nested_delimiter: Optional[str] = env_nested_delimiter - self.env_prefix_len: int = env_prefix_len - - def __call__(self, settings: BaseSettings) -> Dict[str, Any]: # noqa C901 - """ - Build environment variables suitable for passing to the Model. 
- """ - d: Dict[str, Any] = {} - - if settings.__config__.case_sensitive: - env_vars: Mapping[str, Optional[str]] = os.environ - else: - env_vars = {k.lower(): v for k, v in os.environ.items()} - - dotenv_vars = self._read_env_files(settings.__config__.case_sensitive) - if dotenv_vars: - env_vars = {**dotenv_vars, **env_vars} - - for field in settings.__fields__.values(): - env_val: Optional[str] = None - for env_name in field.field_info.extra['env_names']: - env_val = env_vars.get(env_name) - if env_val is not None: - break - - is_complex, allow_parse_failure = self.field_is_complex(field) - if is_complex: - if env_val is None: - # field is complex but no value found so far, try explode_env_vars - env_val_built = self.explode_env_vars(field, env_vars) - if env_val_built: - d[field.alias] = env_val_built - else: - # field is complex and there's a value, decode that as JSON, then add explode_env_vars - try: - env_val = settings.__config__.parse_env_var(field.name, env_val) - except ValueError as e: - if not allow_parse_failure: - raise SettingsError(f'error parsing env var "{env_name}"') from e - - if isinstance(env_val, dict): - d[field.alias] = deep_update(env_val, self.explode_env_vars(field, env_vars)) - else: - d[field.alias] = env_val - elif env_val is not None: - # simplest case, field is not complex, we only need to add the value if it was found - d[field.alias] = env_val - - return d - - def _read_env_files(self, case_sensitive: bool) -> Dict[str, Optional[str]]: - env_files = self.env_file - if env_files is None: - return {} - - if isinstance(env_files, (str, os.PathLike)): - env_files = [env_files] - - dotenv_vars = {} - for env_file in env_files: - env_path = Path(env_file).expanduser() - if env_path.is_file(): - dotenv_vars.update( - read_env_file(env_path, encoding=self.env_file_encoding, case_sensitive=case_sensitive) - ) - - return dotenv_vars - - def field_is_complex(self, field: ModelField) -> Tuple[bool, bool]: - """ - Find out if a field is complex, and if so whether JSON errors should be ignored - """ - if field.is_complex(): - allow_parse_failure = False - elif is_union(get_origin(field.type_)) and field.sub_fields and any(f.is_complex() for f in field.sub_fields): - allow_parse_failure = True - else: - return False, False - - return True, allow_parse_failure - - def explode_env_vars(self, field: ModelField, env_vars: Mapping[str, Optional[str]]) -> Dict[str, Any]: - """ - Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries. - - This is applied to a single field, hence filtering by env_var prefix. 
- """ - prefixes = [f'{env_name}{self.env_nested_delimiter}' for env_name in field.field_info.extra['env_names']] - result: Dict[str, Any] = {} - for env_name, env_val in env_vars.items(): - if not any(env_name.startswith(prefix) for prefix in prefixes): - continue - # we remove the prefix before splitting in case the prefix has characters in common with the delimiter - env_name_without_prefix = env_name[self.env_prefix_len :] - _, *keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter) - env_var = result - for key in keys: - env_var = env_var.setdefault(key, {}) - env_var[last_key] = env_val - - return result - - def __repr__(self) -> str: - return ( - f'EnvSettingsSource(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, ' - f'env_nested_delimiter={self.env_nested_delimiter!r})' - ) - - -class SecretsSettingsSource: - __slots__ = ('secrets_dir',) - - def __init__(self, secrets_dir: Optional[StrPath]): - self.secrets_dir: Optional[StrPath] = secrets_dir - - def __call__(self, settings: BaseSettings) -> Dict[str, Any]: - """ - Build fields from "secrets" files. - """ - secrets: Dict[str, Optional[str]] = {} - - if self.secrets_dir is None: - return secrets - - secrets_path = Path(self.secrets_dir).expanduser() - - if not secrets_path.exists(): - warnings.warn(f'directory "{secrets_path}" does not exist') - return secrets - - if not secrets_path.is_dir(): - raise SettingsError(f'secrets_dir must reference a directory, not a {path_type(secrets_path)}') - - for field in settings.__fields__.values(): - for env_name in field.field_info.extra['env_names']: - path = find_case_path(secrets_path, env_name, settings.__config__.case_sensitive) - if not path: - # path does not exist, we curently don't return a warning for this - continue - - if path.is_file(): - secret_value = path.read_text().strip() - if field.is_complex(): - try: - secret_value = settings.__config__.parse_env_var(field.name, secret_value) - except ValueError as e: - raise SettingsError(f'error parsing env var "{env_name}"') from e - - secrets[field.alias] = secret_value - else: - warnings.warn( - f'attempted to load secret file "{path}" but found a {path_type(path)} instead.', - stacklevel=4, - ) - return secrets - - def __repr__(self) -> str: - return f'SecretsSettingsSource(secrets_dir={self.secrets_dir!r})' - - -def read_env_file( - file_path: StrPath, *, encoding: str = None, case_sensitive: bool = False -) -> Dict[str, Optional[str]]: - try: - from dotenv import dotenv_values - except ImportError as e: - raise ImportError('python-dotenv is not installed, run `pip install pydantic[dotenv]`') from e - - file_vars: Dict[str, Optional[str]] = dotenv_values(file_path, encoding=encoding or 'utf8') - if not case_sensitive: - return {k.lower(): v for k, v in file_vars.items()} - else: - return file_vars - - -def find_case_path(dir_path: Path, file_name: str, case_sensitive: bool) -> Optional[Path]: - """ - Find a file within path's directory matching filename, optionally ignoring case. 
- """ - for f in dir_path.iterdir(): - if f.name == file_name: - return f - elif not case_sensitive and f.name.lower() == file_name.lower(): - return f - return None +__getattr__ = getattr_migration(__name__) diff --git a/lib/pydantic/error_wrappers.py b/lib/pydantic/error_wrappers.py index 5d3204f4..5144eeee 100644 --- a/lib/pydantic/error_wrappers.py +++ b/lib/pydantic/error_wrappers.py @@ -1,162 +1,4 @@ -import json -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple, Type, Union +"""The `error_wrappers` module is a backport module from V1.""" +from ._migration import getattr_migration -from .json import pydantic_encoder -from .utils import Representation - -if TYPE_CHECKING: - from typing_extensions import TypedDict - - from .config import BaseConfig - from .types import ModelOrDc - from .typing import ReprArgs - - Loc = Tuple[Union[int, str], ...] - - class _ErrorDictRequired(TypedDict): - loc: Loc - msg: str - type: str - - class ErrorDict(_ErrorDictRequired, total=False): - ctx: Dict[str, Any] - - -__all__ = 'ErrorWrapper', 'ValidationError' - - -class ErrorWrapper(Representation): - __slots__ = 'exc', '_loc' - - def __init__(self, exc: Exception, loc: Union[str, 'Loc']) -> None: - self.exc = exc - self._loc = loc - - def loc_tuple(self) -> 'Loc': - if isinstance(self._loc, tuple): - return self._loc - else: - return (self._loc,) - - def __repr_args__(self) -> 'ReprArgs': - return [('exc', self.exc), ('loc', self.loc_tuple())] - - -# ErrorList is something like Union[List[Union[List[ErrorWrapper], ErrorWrapper]], ErrorWrapper] -# but recursive, therefore just use: -ErrorList = Union[Sequence[Any], ErrorWrapper] - - -class ValidationError(Representation, ValueError): - __slots__ = 'raw_errors', 'model', '_error_cache' - - def __init__(self, errors: Sequence[ErrorList], model: 'ModelOrDc') -> None: - self.raw_errors = errors - self.model = model - self._error_cache: Optional[List['ErrorDict']] = None - - def errors(self) -> List['ErrorDict']: - if self._error_cache is None: - try: - config = self.model.__config__ # type: ignore - except AttributeError: - config = self.model.__pydantic_model__.__config__ # type: ignore - self._error_cache = list(flatten_errors(self.raw_errors, config)) - return self._error_cache - - def json(self, *, indent: Union[None, int, str] = 2) -> str: - return json.dumps(self.errors(), indent=indent, default=pydantic_encoder) - - def __str__(self) -> str: - errors = self.errors() - no_errors = len(errors) - return ( - f'{no_errors} validation error{"" if no_errors == 1 else "s"} for {self.model.__name__}\n' - f'{display_errors(errors)}' - ) - - def __repr_args__(self) -> 'ReprArgs': - return [('model', self.model.__name__), ('errors', self.errors())] - - -def display_errors(errors: List['ErrorDict']) -> str: - return '\n'.join(f'{_display_error_loc(e)}\n {e["msg"]} ({_display_error_type_and_ctx(e)})' for e in errors) - - -def _display_error_loc(error: 'ErrorDict') -> str: - return ' -> '.join(str(e) for e in error['loc']) - - -def _display_error_type_and_ctx(error: 'ErrorDict') -> str: - t = 'type=' + error['type'] - ctx = error.get('ctx') - if ctx: - return t + ''.join(f'; {k}={v}' for k, v in ctx.items()) - else: - return t - - -def flatten_errors( - errors: Sequence[Any], config: Type['BaseConfig'], loc: Optional['Loc'] = None -) -> Generator['ErrorDict', None, None]: - for error in errors: - if isinstance(error, ErrorWrapper): - - if loc: - error_loc = loc + error.loc_tuple() - else: - error_loc = error.loc_tuple() - - 
if isinstance(error.exc, ValidationError): - yield from flatten_errors(error.exc.raw_errors, config, error_loc) - else: - yield error_dict(error.exc, config, error_loc) - elif isinstance(error, list): - yield from flatten_errors(error, config, loc=loc) - else: - raise RuntimeError(f'Unknown error object: {error}') - - -def error_dict(exc: Exception, config: Type['BaseConfig'], loc: 'Loc') -> 'ErrorDict': - type_ = get_exc_type(exc.__class__) - msg_template = config.error_msg_templates.get(type_) or getattr(exc, 'msg_template', None) - ctx = exc.__dict__ - if msg_template: - msg = msg_template.format(**ctx) - else: - msg = str(exc) - - d: 'ErrorDict' = {'loc': loc, 'msg': msg, 'type': type_} - - if ctx: - d['ctx'] = ctx - - return d - - -_EXC_TYPE_CACHE: Dict[Type[Exception], str] = {} - - -def get_exc_type(cls: Type[Exception]) -> str: - # slightly more efficient than using lru_cache since we don't need to worry about the cache filling up - try: - return _EXC_TYPE_CACHE[cls] - except KeyError: - r = _get_exc_type(cls) - _EXC_TYPE_CACHE[cls] = r - return r - - -def _get_exc_type(cls: Type[Exception]) -> str: - if issubclass(cls, AssertionError): - return 'assertion_error' - - base_name = 'type_error' if issubclass(cls, TypeError) else 'value_error' - if cls in (TypeError, ValueError): - # just TypeError or ValueError, no extra code - return base_name - - # if it's not a TypeError or ValueError, we just take the lowercase of the exception name - # no chaining or snake case logic, use "code" for more complex error types. - code = getattr(cls, 'code', None) or cls.__name__.replace('Error', '').lower() - return base_name + '.' + code +__getattr__ = getattr_migration(__name__) diff --git a/lib/pydantic/errors.py b/lib/pydantic/errors.py index 7bdafdd1..c5fa9612 100644 --- a/lib/pydantic/errors.py +++ b/lib/pydantic/errors.py @@ -1,646 +1,152 @@ -from decimal import Decimal -from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Sequence, Set, Tuple, Type, Union +"""Pydantic-specific errors.""" +from __future__ import annotations as _annotations -from .typing import display_as_type +import re -if TYPE_CHECKING: - from .typing import DictStrAny +from typing_extensions import Literal, Self + +from ._migration import getattr_migration +from .version import version_short -# explicitly state exports to avoid "from .errors import *" also importing Decimal, Path etc. 
__all__ = ( - 'PydanticTypeError', - 'PydanticValueError', - 'ConfigError', - 'MissingError', - 'ExtraError', - 'NoneIsNotAllowedError', - 'NoneIsAllowedError', - 'WrongConstantError', - 'NotNoneError', - 'BoolError', - 'BytesError', - 'DictError', - 'EmailError', - 'UrlError', - 'UrlSchemeError', - 'UrlSchemePermittedError', - 'UrlUserInfoError', - 'UrlHostError', - 'UrlHostTldError', - 'UrlPortError', - 'UrlExtraError', - 'EnumError', - 'IntEnumError', - 'EnumMemberError', - 'IntegerError', - 'FloatError', - 'PathError', - 'PathNotExistsError', - 'PathNotAFileError', - 'PathNotADirectoryError', - 'PyObjectError', - 'SequenceError', - 'ListError', - 'SetError', - 'FrozenSetError', - 'TupleError', - 'TupleLengthError', - 'ListMinLengthError', - 'ListMaxLengthError', - 'ListUniqueItemsError', - 'SetMinLengthError', - 'SetMaxLengthError', - 'FrozenSetMinLengthError', - 'FrozenSetMaxLengthError', - 'AnyStrMinLengthError', - 'AnyStrMaxLengthError', - 'StrError', - 'StrRegexError', - 'NumberNotGtError', - 'NumberNotGeError', - 'NumberNotLtError', - 'NumberNotLeError', - 'NumberNotMultipleError', - 'DecimalError', - 'DecimalIsNotFiniteError', - 'DecimalMaxDigitsError', - 'DecimalMaxPlacesError', - 'DecimalWholeDigitsError', - 'DateTimeError', - 'DateError', - 'DateNotInThePastError', - 'DateNotInTheFutureError', - 'TimeError', - 'DurationError', - 'HashableError', - 'UUIDError', - 'UUIDVersionError', - 'ArbitraryTypeError', - 'ClassError', - 'SubclassError', - 'JsonError', - 'JsonTypeError', - 'PatternError', - 'DataclassTypeError', - 'CallableError', - 'IPvAnyAddressError', - 'IPvAnyInterfaceError', - 'IPvAnyNetworkError', - 'IPv4AddressError', - 'IPv6AddressError', - 'IPv4NetworkError', - 'IPv6NetworkError', - 'IPv4InterfaceError', - 'IPv6InterfaceError', - 'ColorError', - 'StrictBoolError', - 'NotDigitError', - 'LuhnValidationError', - 'InvalidLengthForBrand', - 'InvalidByteSize', - 'InvalidByteSizeUnit', - 'MissingDiscriminator', - 'InvalidDiscriminator', + 'PydanticUserError', + 'PydanticUndefinedAnnotation', + 'PydanticImportError', + 'PydanticSchemaGenerationError', + 'PydanticInvalidForJsonSchema', + 'PydanticErrorCodes', ) - -def cls_kwargs(cls: Type['PydanticErrorMixin'], ctx: 'DictStrAny') -> 'PydanticErrorMixin': - """ - For built-in exceptions like ValueError or TypeError, we need to implement - __reduce__ to override the default behaviour (instead of __getstate__/__setstate__) - By default pickle protocol 2 calls `cls.__new__(cls, *args)`. - Since we only use kwargs, we need a little constructor to change that. - Note: the callable can't be a lambda as pickle looks in the namespace to find it - """ - return cls(**ctx) +# We use this URL to allow for future flexibility about how we host the docs, while allowing for Pydantic +# code in the while with "old" URLs to still work. +# 'u' refers to "user errors" - e.g. errors caused by developers using pydantic, as opposed to validation errors. 
+DEV_ERROR_DOCS_URL = f'https://errors.pydantic.dev/{version_short()}/u/' +PydanticErrorCodes = Literal[ + 'class-not-fully-defined', + 'custom-json-schema', + 'decorator-missing-field', + 'discriminator-no-field', + 'discriminator-alias-type', + 'discriminator-needs-literal', + 'discriminator-alias', + 'discriminator-validator', + 'callable-discriminator-no-tag', + 'typed-dict-version', + 'model-field-overridden', + 'model-field-missing-annotation', + 'config-both', + 'removed-kwargs', + 'invalid-for-json-schema', + 'json-schema-already-used', + 'base-model-instantiated', + 'undefined-annotation', + 'schema-for-unknown-type', + 'import-error', + 'create-model-field-definitions', + 'create-model-config-base', + 'validator-no-fields', + 'validator-invalid-fields', + 'validator-instance-method', + 'root-validator-pre-skip', + 'model-serializer-instance-method', + 'validator-field-config-info', + 'validator-v1-signature', + 'validator-signature', + 'field-serializer-signature', + 'model-serializer-signature', + 'multiple-field-serializers', + 'invalid_annotated_type', + 'type-adapter-config-unused', + 'root-model-extra', + 'unevaluable-type-annotation', + 'dataclass-init-false-extra-allow', + 'clashing-init-and-init-var', +] class PydanticErrorMixin: - code: str - msg_template: str + """A mixin class for common functionality shared by all Pydantic-specific errors. - def __init__(self, **ctx: Any) -> None: - self.__dict__ = ctx + Attributes: + message: A message describing the error. + code: An optional error code from PydanticErrorCodes enum. + """ + + def __init__(self, message: str, *, code: PydanticErrorCodes | None) -> None: + self.message = message + self.code = code def __str__(self) -> str: - return self.msg_template.format(**self.__dict__) + if self.code is None: + return self.message + else: + return f'{self.message}\n\nFor further information visit {DEV_ERROR_DOCS_URL}{self.code}' - def __reduce__(self) -> Tuple[Callable[..., 'PydanticErrorMixin'], Tuple[Type['PydanticErrorMixin'], 'DictStrAny']]: - return cls_kwargs, (self.__class__, self.__dict__) +class PydanticUserError(PydanticErrorMixin, TypeError): + """An error raised due to incorrect use of Pydantic.""" -class PydanticTypeError(PydanticErrorMixin, TypeError): - pass +class PydanticUndefinedAnnotation(PydanticErrorMixin, NameError): + """A subclass of `NameError` raised when handling undefined annotations during `CoreSchema` generation. -class PydanticValueError(PydanticErrorMixin, ValueError): - pass + Attributes: + name: Name of the error. + message: Description of the error. + """ + def __init__(self, name: str, message: str) -> None: + self.name = name + super().__init__(message=message, code='undefined-annotation') -class ConfigError(RuntimeError): - pass + @classmethod + def from_name_error(cls, name_error: NameError) -> Self: + """Convert a `NameError` to a `PydanticUndefinedAnnotation` error. + Args: + name_error: `NameError` to be converted. -class MissingError(PydanticValueError): - msg_template = 'field required' + Returns: + Converted `PydanticUndefinedAnnotation` error. 
+ """ + try: + name = name_error.name # type: ignore # python > 3.10 + except AttributeError: + name = re.search(r".*'(.+?)'", str(name_error)).group(1) # type: ignore[union-attr] + return cls(name=name, message=str(name_error)) -class ExtraError(PydanticValueError): - msg_template = 'extra fields not permitted' +class PydanticImportError(PydanticErrorMixin, ImportError): + """An error raised when an import fails due to module changes between V1 and V2. + Attributes: + message: Description of the error. + """ -class NoneIsNotAllowedError(PydanticTypeError): - code = 'none.not_allowed' - msg_template = 'none is not an allowed value' + def __init__(self, message: str) -> None: + super().__init__(message, code='import-error') -class NoneIsAllowedError(PydanticTypeError): - code = 'none.allowed' - msg_template = 'value is not none' +class PydanticSchemaGenerationError(PydanticUserError): + """An error raised during failures to generate a `CoreSchema` for some type. + Attributes: + message: Description of the error. + """ -class WrongConstantError(PydanticValueError): - code = 'const' + def __init__(self, message: str) -> None: + super().__init__(message, code='schema-for-unknown-type') - def __str__(self) -> str: - permitted = ', '.join(repr(v) for v in self.permitted) # type: ignore - return f'unexpected value; permitted: {permitted}' +class PydanticInvalidForJsonSchema(PydanticUserError): + """An error raised during failures to generate a JSON schema for some `CoreSchema`. -class NotNoneError(PydanticTypeError): - code = 'not_none' - msg_template = 'value is not None' + Attributes: + message: Description of the error. + """ + def __init__(self, message: str) -> None: + super().__init__(message, code='invalid-for-json-schema') -class BoolError(PydanticTypeError): - msg_template = 'value could not be parsed to a boolean' - -class BytesError(PydanticTypeError): - msg_template = 'byte type expected' - - -class DictError(PydanticTypeError): - msg_template = 'value is not a valid dict' - - -class EmailError(PydanticValueError): - msg_template = 'value is not a valid email address' - - -class UrlError(PydanticValueError): - code = 'url' - - -class UrlSchemeError(UrlError): - code = 'url.scheme' - msg_template = 'invalid or missing URL scheme' - - -class UrlSchemePermittedError(UrlError): - code = 'url.scheme' - msg_template = 'URL scheme not permitted' - - def __init__(self, allowed_schemes: Set[str]): - super().__init__(allowed_schemes=allowed_schemes) - - -class UrlUserInfoError(UrlError): - code = 'url.userinfo' - msg_template = 'userinfo required in URL but missing' - - -class UrlHostError(UrlError): - code = 'url.host' - msg_template = 'URL host invalid' - - -class UrlHostTldError(UrlError): - code = 'url.host' - msg_template = 'URL host invalid, top level domain required' - - -class UrlPortError(UrlError): - code = 'url.port' - msg_template = 'URL port invalid, port cannot exceed 65535' - - -class UrlExtraError(UrlError): - code = 'url.extra' - msg_template = 'URL invalid, extra characters found after valid URL: {extra!r}' - - -class EnumMemberError(PydanticTypeError): - code = 'enum' - - def __str__(self) -> str: - permitted = ', '.join(repr(v.value) for v in self.enum_values) # type: ignore - return f'value is not a valid enumeration member; permitted: {permitted}' - - -class IntegerError(PydanticTypeError): - msg_template = 'value is not a valid integer' - - -class FloatError(PydanticTypeError): - msg_template = 'value is not a valid float' - - -class PathError(PydanticTypeError): - 
msg_template = 'value is not a valid path' - - -class _PathValueError(PydanticValueError): - def __init__(self, *, path: Path) -> None: - super().__init__(path=str(path)) - - -class PathNotExistsError(_PathValueError): - code = 'path.not_exists' - msg_template = 'file or directory at path "{path}" does not exist' - - -class PathNotAFileError(_PathValueError): - code = 'path.not_a_file' - msg_template = 'path "{path}" does not point to a file' - - -class PathNotADirectoryError(_PathValueError): - code = 'path.not_a_directory' - msg_template = 'path "{path}" does not point to a directory' - - -class PyObjectError(PydanticTypeError): - msg_template = 'ensure this value contains valid import path or valid callable: {error_message}' - - -class SequenceError(PydanticTypeError): - msg_template = 'value is not a valid sequence' - - -class IterableError(PydanticTypeError): - msg_template = 'value is not a valid iterable' - - -class ListError(PydanticTypeError): - msg_template = 'value is not a valid list' - - -class SetError(PydanticTypeError): - msg_template = 'value is not a valid set' - - -class FrozenSetError(PydanticTypeError): - msg_template = 'value is not a valid frozenset' - - -class DequeError(PydanticTypeError): - msg_template = 'value is not a valid deque' - - -class TupleError(PydanticTypeError): - msg_template = 'value is not a valid tuple' - - -class TupleLengthError(PydanticValueError): - code = 'tuple.length' - msg_template = 'wrong tuple length {actual_length}, expected {expected_length}' - - def __init__(self, *, actual_length: int, expected_length: int) -> None: - super().__init__(actual_length=actual_length, expected_length=expected_length) - - -class ListMinLengthError(PydanticValueError): - code = 'list.min_items' - msg_template = 'ensure this value has at least {limit_value} items' - - def __init__(self, *, limit_value: int) -> None: - super().__init__(limit_value=limit_value) - - -class ListMaxLengthError(PydanticValueError): - code = 'list.max_items' - msg_template = 'ensure this value has at most {limit_value} items' - - def __init__(self, *, limit_value: int) -> None: - super().__init__(limit_value=limit_value) - - -class ListUniqueItemsError(PydanticValueError): - code = 'list.unique_items' - msg_template = 'the list has duplicated items' - - -class SetMinLengthError(PydanticValueError): - code = 'set.min_items' - msg_template = 'ensure this value has at least {limit_value} items' - - def __init__(self, *, limit_value: int) -> None: - super().__init__(limit_value=limit_value) - - -class SetMaxLengthError(PydanticValueError): - code = 'set.max_items' - msg_template = 'ensure this value has at most {limit_value} items' - - def __init__(self, *, limit_value: int) -> None: - super().__init__(limit_value=limit_value) - - -class FrozenSetMinLengthError(PydanticValueError): - code = 'frozenset.min_items' - msg_template = 'ensure this value has at least {limit_value} items' - - def __init__(self, *, limit_value: int) -> None: - super().__init__(limit_value=limit_value) - - -class FrozenSetMaxLengthError(PydanticValueError): - code = 'frozenset.max_items' - msg_template = 'ensure this value has at most {limit_value} items' - - def __init__(self, *, limit_value: int) -> None: - super().__init__(limit_value=limit_value) - - -class AnyStrMinLengthError(PydanticValueError): - code = 'any_str.min_length' - msg_template = 'ensure this value has at least {limit_value} characters' - - def __init__(self, *, limit_value: int) -> None: - super().__init__(limit_value=limit_value) - - -class 
AnyStrMaxLengthError(PydanticValueError): - code = 'any_str.max_length' - msg_template = 'ensure this value has at most {limit_value} characters' - - def __init__(self, *, limit_value: int) -> None: - super().__init__(limit_value=limit_value) - - -class StrError(PydanticTypeError): - msg_template = 'str type expected' - - -class StrRegexError(PydanticValueError): - code = 'str.regex' - msg_template = 'string does not match regex "{pattern}"' - - def __init__(self, *, pattern: str) -> None: - super().__init__(pattern=pattern) - - -class _NumberBoundError(PydanticValueError): - def __init__(self, *, limit_value: Union[int, float, Decimal]) -> None: - super().__init__(limit_value=limit_value) - - -class NumberNotGtError(_NumberBoundError): - code = 'number.not_gt' - msg_template = 'ensure this value is greater than {limit_value}' - - -class NumberNotGeError(_NumberBoundError): - code = 'number.not_ge' - msg_template = 'ensure this value is greater than or equal to {limit_value}' - - -class NumberNotLtError(_NumberBoundError): - code = 'number.not_lt' - msg_template = 'ensure this value is less than {limit_value}' - - -class NumberNotLeError(_NumberBoundError): - code = 'number.not_le' - msg_template = 'ensure this value is less than or equal to {limit_value}' - - -class NumberNotFiniteError(PydanticValueError): - code = 'number.not_finite_number' - msg_template = 'ensure this value is a finite number' - - -class NumberNotMultipleError(PydanticValueError): - code = 'number.not_multiple' - msg_template = 'ensure this value is a multiple of {multiple_of}' - - def __init__(self, *, multiple_of: Union[int, float, Decimal]) -> None: - super().__init__(multiple_of=multiple_of) - - -class DecimalError(PydanticTypeError): - msg_template = 'value is not a valid decimal' - - -class DecimalIsNotFiniteError(PydanticValueError): - code = 'decimal.not_finite' - msg_template = 'value is not a valid decimal' - - -class DecimalMaxDigitsError(PydanticValueError): - code = 'decimal.max_digits' - msg_template = 'ensure that there are no more than {max_digits} digits in total' - - def __init__(self, *, max_digits: int) -> None: - super().__init__(max_digits=max_digits) - - -class DecimalMaxPlacesError(PydanticValueError): - code = 'decimal.max_places' - msg_template = 'ensure that there are no more than {decimal_places} decimal places' - - def __init__(self, *, decimal_places: int) -> None: - super().__init__(decimal_places=decimal_places) - - -class DecimalWholeDigitsError(PydanticValueError): - code = 'decimal.whole_digits' - msg_template = 'ensure that there are no more than {whole_digits} digits before the decimal point' - - def __init__(self, *, whole_digits: int) -> None: - super().__init__(whole_digits=whole_digits) - - -class DateTimeError(PydanticValueError): - msg_template = 'invalid datetime format' - - -class DateError(PydanticValueError): - msg_template = 'invalid date format' - - -class DateNotInThePastError(PydanticValueError): - code = 'date.not_in_the_past' - msg_template = 'date is not in the past' - - -class DateNotInTheFutureError(PydanticValueError): - code = 'date.not_in_the_future' - msg_template = 'date is not in the future' - - -class TimeError(PydanticValueError): - msg_template = 'invalid time format' - - -class DurationError(PydanticValueError): - msg_template = 'invalid duration format' - - -class HashableError(PydanticTypeError): - msg_template = 'value is not a valid hashable' - - -class UUIDError(PydanticTypeError): - msg_template = 'value is not a valid uuid' - - -class 
UUIDVersionError(PydanticValueError): - code = 'uuid.version' - msg_template = 'uuid version {required_version} expected' - - def __init__(self, *, required_version: int) -> None: - super().__init__(required_version=required_version) - - -class ArbitraryTypeError(PydanticTypeError): - code = 'arbitrary_type' - msg_template = 'instance of {expected_arbitrary_type} expected' - - def __init__(self, *, expected_arbitrary_type: Type[Any]) -> None: - super().__init__(expected_arbitrary_type=display_as_type(expected_arbitrary_type)) - - -class ClassError(PydanticTypeError): - code = 'class' - msg_template = 'a class is expected' - - -class SubclassError(PydanticTypeError): - code = 'subclass' - msg_template = 'subclass of {expected_class} expected' - - def __init__(self, *, expected_class: Type[Any]) -> None: - super().__init__(expected_class=display_as_type(expected_class)) - - -class JsonError(PydanticValueError): - msg_template = 'Invalid JSON' - - -class JsonTypeError(PydanticTypeError): - code = 'json' - msg_template = 'JSON object must be str, bytes or bytearray' - - -class PatternError(PydanticValueError): - code = 'regex_pattern' - msg_template = 'Invalid regular expression' - - -class DataclassTypeError(PydanticTypeError): - code = 'dataclass' - msg_template = 'instance of {class_name}, tuple or dict expected' - - -class CallableError(PydanticTypeError): - msg_template = '{value} is not callable' - - -class EnumError(PydanticTypeError): - code = 'enum_instance' - msg_template = '{value} is not a valid Enum instance' - - -class IntEnumError(PydanticTypeError): - code = 'int_enum_instance' - msg_template = '{value} is not a valid IntEnum instance' - - -class IPvAnyAddressError(PydanticValueError): - msg_template = 'value is not a valid IPv4 or IPv6 address' - - -class IPvAnyInterfaceError(PydanticValueError): - msg_template = 'value is not a valid IPv4 or IPv6 interface' - - -class IPvAnyNetworkError(PydanticValueError): - msg_template = 'value is not a valid IPv4 or IPv6 network' - - -class IPv4AddressError(PydanticValueError): - msg_template = 'value is not a valid IPv4 address' - - -class IPv6AddressError(PydanticValueError): - msg_template = 'value is not a valid IPv6 address' - - -class IPv4NetworkError(PydanticValueError): - msg_template = 'value is not a valid IPv4 network' - - -class IPv6NetworkError(PydanticValueError): - msg_template = 'value is not a valid IPv6 network' - - -class IPv4InterfaceError(PydanticValueError): - msg_template = 'value is not a valid IPv4 interface' - - -class IPv6InterfaceError(PydanticValueError): - msg_template = 'value is not a valid IPv6 interface' - - -class ColorError(PydanticValueError): - msg_template = 'value is not a valid color: {reason}' - - -class StrictBoolError(PydanticValueError): - msg_template = 'value is not a valid boolean' - - -class NotDigitError(PydanticValueError): - code = 'payment_card_number.digits' - msg_template = 'card number is not all digits' - - -class LuhnValidationError(PydanticValueError): - code = 'payment_card_number.luhn_check' - msg_template = 'card number is not luhn valid' - - -class InvalidLengthForBrand(PydanticValueError): - code = 'payment_card_number.invalid_length_for_brand' - msg_template = 'Length for a {brand} card must be {required_length}' - - -class InvalidByteSize(PydanticValueError): - msg_template = 'could not parse value and unit from byte string' - - -class InvalidByteSizeUnit(PydanticValueError): - msg_template = 'could not interpret byte unit: {unit}' - - -class 
MissingDiscriminator(PydanticValueError): - code = 'discriminated_union.missing_discriminator' - msg_template = 'Discriminator {discriminator_key!r} is missing in value' - - -class InvalidDiscriminator(PydanticValueError): - code = 'discriminated_union.invalid_discriminator' - msg_template = ( - 'No match for discriminator {discriminator_key!r} and value {discriminator_value!r} ' - '(allowed values: {allowed_values})' - ) - - def __init__(self, *, discriminator_key: str, discriminator_value: Any, allowed_values: Sequence[Any]) -> None: - super().__init__( - discriminator_key=discriminator_key, - discriminator_value=discriminator_value, - allowed_values=', '.join(map(repr, allowed_values)), - ) +__getattr__ = getattr_migration(__name__) diff --git a/lib/pydantic/fields.py b/lib/pydantic/fields.py index cecd3d20..b416bb7d 100644 --- a/lib/pydantic/fields.py +++ b/lib/pydantic/fields.py @@ -1,1209 +1,875 @@ -import copy -import re -from collections import Counter as CollectionCounter, defaultdict, deque -from collections.abc import Callable, Hashable as CollectionsHashable, Iterable as CollectionsIterable -from typing import ( - TYPE_CHECKING, - Any, - Counter, - DefaultDict, - Deque, - Dict, - ForwardRef, - FrozenSet, - Generator, - Iterable, - Iterator, - List, - Mapping, - Optional, - Pattern, - Sequence, - Set, - Tuple, - Type, - TypeVar, - Union, -) +"""Defining fields on models.""" +from __future__ import annotations as _annotations -from typing_extensions import Annotated, Final +import dataclasses +import inspect +import typing +from copy import copy +from dataclasses import Field as DataclassField +from functools import cached_property +from typing import Any, ClassVar +from warnings import warn -from . import errors as errors_ -from .class_validators import Validator, make_generic_validator, prep_validators -from .error_wrappers import ErrorWrapper -from .errors import ConfigError, InvalidDiscriminator, MissingDiscriminator, NoneIsNotAllowedError -from .types import Json, JsonWrapper -from .typing import ( - NoArgAnyCallable, - convert_generics, - display_as_type, - get_args, - get_origin, - is_finalvar, - is_literal_type, - is_new_type, - is_none_type, - is_typeddict, - is_typeddict_special, - is_union, - new_type_supertype, -) -from .utils import ( - PyObjectStr, - Representation, - ValueItems, - get_discriminator_alias_and_values, - get_unique_discriminator_alias, - lenient_isinstance, - lenient_issubclass, - sequence_like, - smart_deepcopy, -) -from .validators import constant_validator, dict_validator, find_validators, validate_json +import annotated_types +import typing_extensions +from pydantic_core import PydanticUndefined +from typing_extensions import Literal, Unpack -Required: Any = Ellipsis +from . 
import types +from ._internal import _decorators, _fields, _generics, _internal_dataclass, _repr, _typing_extra, _utils +from .aliases import AliasChoices, AliasPath +from .config import JsonDict +from .errors import PydanticUserError +from .warnings import PydanticDeprecatedSince20 -T = TypeVar('T') +if typing.TYPE_CHECKING: + from ._internal._repr import ReprArgs +else: + # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915 + # and https://youtrack.jetbrains.com/issue/PY-51428 + DeprecationWarning = PydanticDeprecatedSince20 -class UndefinedType: - def __repr__(self) -> str: - return 'PydanticUndefined' - - def __copy__(self: T) -> T: - return self - - def __reduce__(self) -> str: - return 'Undefined' - - def __deepcopy__(self: T, _: Any) -> T: - return self +_Unset: Any = PydanticUndefined -Undefined = UndefinedType() +class _FromFieldInfoInputs(typing_extensions.TypedDict, total=False): + """This class exists solely to add type checking for the `**kwargs` in `FieldInfo.from_field`.""" -if TYPE_CHECKING: - from .class_validators import ValidatorsList - from .config import BaseConfig - from .error_wrappers import ErrorList - from .types import ModelOrDc - from .typing import AbstractSetIntStr, MappingIntStrAny, ReprArgs - - ValidateReturn = Tuple[Optional[Any], Optional[ErrorList]] - LocStr = Union[Tuple[Union[int, str], ...], str] - BoolUndefined = Union[bool, UndefinedType] + annotation: type[Any] | None + default_factory: typing.Callable[[], Any] | None + alias: str | None + alias_priority: int | None + validation_alias: str | AliasPath | AliasChoices | None + serialization_alias: str | None + title: str | None + description: str | None + examples: list[Any] | None + exclude: bool | None + gt: float | None + ge: float | None + lt: float | None + le: float | None + multiple_of: float | None + strict: bool | None + min_length: int | None + max_length: int | None + pattern: str | None + allow_inf_nan: bool | None + max_digits: int | None + decimal_places: int | None + union_mode: Literal['smart', 'left_to_right'] | None + discriminator: str | types.Discriminator | None + json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None + frozen: bool | None + validate_default: bool | None + repr: bool + init: bool | None + init_var: bool | None + kw_only: bool | None -class FieldInfo(Representation): - """ - Captures extra information about a field. +class _FieldInfoInputs(_FromFieldInfoInputs, total=False): + """This class exists solely to add type checking for the `**kwargs` in `FieldInfo.__init__`.""" + + default: Any + + +class FieldInfo(_repr.Representation): + """This class holds information about a field. + + `FieldInfo` is used for any field definition regardless of whether the [`Field()`][pydantic.fields.Field] + function is explicitly used. + + !!! warning + You generally shouldn't be creating `FieldInfo` directly, you'll only need to use it when accessing + [`BaseModel`][pydantic.main.BaseModel] `.model_fields` internals. + + Attributes: + annotation: The type annotation of the field. + default: The default value of the field. + default_factory: The factory function used to construct the default for the field. + alias: The alias name of the field. + alias_priority: The priority of the field's alias. + validation_alias: The validation alias of the field. + serialization_alias: The serialization alias of the field. + title: The title of the field. + description: The description of the field. + examples: List of examples of the field. 
+ exclude: Whether to exclude the field from the model serialization. + discriminator: Field name or Discriminator for discriminating the type in a tagged union. + json_schema_extra: A dict or callable to provide extra JSON schema properties. + frozen: Whether the field is frozen. + validate_default: Whether to validate the default value of the field. + repr: Whether to include the field in representation of the model. + init: Whether the field should be included in the constructor of the dataclass. + init_var: Whether the field should _only_ be included in the constructor of the dataclass, and not stored. + kw_only: Whether the field should be a keyword-only argument in the constructor of the dataclass. + metadata: List of metadata constraints. """ + annotation: type[Any] | None + default: Any + default_factory: typing.Callable[[], Any] | None + alias: str | None + alias_priority: int | None + validation_alias: str | AliasPath | AliasChoices | None + serialization_alias: str | None + title: str | None + description: str | None + examples: list[Any] | None + exclude: bool | None + discriminator: str | types.Discriminator | None + json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None + frozen: bool | None + validate_default: bool | None + repr: bool + init: bool | None + init_var: bool | None + kw_only: bool | None + metadata: list[Any] + __slots__ = ( + 'annotation', 'default', 'default_factory', 'alias', 'alias_priority', + 'validation_alias', + 'serialization_alias', 'title', 'description', + 'examples', 'exclude', - 'include', - 'const', - 'gt', - 'ge', - 'lt', - 'le', - 'multiple_of', - 'allow_inf_nan', - 'max_digits', - 'decimal_places', - 'min_items', - 'max_items', - 'unique_items', - 'min_length', - 'max_length', - 'allow_mutation', - 'repr', - 'regex', 'discriminator', - 'extra', + 'json_schema_extra', + 'frozen', + 'validate_default', + 'repr', + 'init', + 'init_var', + 'kw_only', + 'metadata', + '_attributes_set', ) - # field constraints with the default value, it's also used in update_from_config below - __field_constraints__ = { - 'min_length': None, - 'max_length': None, - 'regex': None, - 'gt': None, - 'lt': None, - 'ge': None, - 'le': None, - 'multiple_of': None, + # used to convert kwargs to metadata/constraints, + # None has a special meaning - these items are collected into a `PydanticGeneralMetadata` + metadata_lookup: ClassVar[dict[str, typing.Callable[[Any], Any] | None]] = { + 'strict': types.Strict, + 'gt': annotated_types.Gt, + 'ge': annotated_types.Ge, + 'lt': annotated_types.Lt, + 'le': annotated_types.Le, + 'multiple_of': annotated_types.MultipleOf, + 'min_length': annotated_types.MinLen, + 'max_length': annotated_types.MaxLen, + 'pattern': None, 'allow_inf_nan': None, 'max_digits': None, 'decimal_places': None, - 'min_items': None, - 'max_items': None, - 'unique_items': None, - 'allow_mutation': True, + 'union_mode': None, } - def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: - self.default = default + def __init__(self, **kwargs: Unpack[_FieldInfoInputs]) -> None: + """This class should generally not be initialized directly; instead, use the `pydantic.fields.Field` function + or one of the constructor classmethods. + + See the signature of `pydantic.fields.Field` for more details about the expected arguments. 
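[Illustrative note, not part of the vendored patch] To make the role of `FieldInfo` concrete: every field on a model, whether declared with `Field()` or a bare annotation, ends up as a `FieldInfo` exposed via `model_fields`. A short sketch assuming pydantic 2.x:

```python
# Sketch: FieldInfo instances are what BaseModel.model_fields exposes.
from pydantic import BaseModel, Field

class User(BaseModel):
    id: int
    name: str = Field(default='anonymous', min_length=1)

info = User.model_fields['name']
print(type(info).__name__)  # FieldInfo
print(info.default)         # 'anonymous'
print(info.metadata)        # [MinLen(min_length=1)]
print(info.is_required())   # False
```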
+ """ + self._attributes_set = {k: v for k, v in kwargs.items() if v is not _Unset} + kwargs = {k: _DefaultValues.get(k) if v is _Unset else v for k, v in kwargs.items()} # type: ignore + self.annotation, annotation_metadata = self._extract_metadata(kwargs.get('annotation')) + + default = kwargs.pop('default', PydanticUndefined) + if default is Ellipsis: + self.default = PydanticUndefined + else: + self.default = default + self.default_factory = kwargs.pop('default_factory', None) - self.alias = kwargs.pop('alias', None) - self.alias_priority = kwargs.pop('alias_priority', 2 if self.alias is not None else None) + + if self.default is not PydanticUndefined and self.default_factory is not None: + raise TypeError('cannot specify both default and default_factory') + self.title = kwargs.pop('title', None) + self.alias = kwargs.pop('alias', None) + self.validation_alias = kwargs.pop('validation_alias', None) + self.serialization_alias = kwargs.pop('serialization_alias', None) + alias_is_set = any(alias is not None for alias in (self.alias, self.validation_alias, self.serialization_alias)) + self.alias_priority = kwargs.pop('alias_priority', None) or 2 if alias_is_set else None self.description = kwargs.pop('description', None) + self.examples = kwargs.pop('examples', None) self.exclude = kwargs.pop('exclude', None) - self.include = kwargs.pop('include', None) - self.const = kwargs.pop('const', None) - self.gt = kwargs.pop('gt', None) - self.ge = kwargs.pop('ge', None) - self.lt = kwargs.pop('lt', None) - self.le = kwargs.pop('le', None) - self.multiple_of = kwargs.pop('multiple_of', None) - self.allow_inf_nan = kwargs.pop('allow_inf_nan', None) - self.max_digits = kwargs.pop('max_digits', None) - self.decimal_places = kwargs.pop('decimal_places', None) - self.min_items = kwargs.pop('min_items', None) - self.max_items = kwargs.pop('max_items', None) - self.unique_items = kwargs.pop('unique_items', None) - self.min_length = kwargs.pop('min_length', None) - self.max_length = kwargs.pop('max_length', None) - self.allow_mutation = kwargs.pop('allow_mutation', True) - self.regex = kwargs.pop('regex', None) self.discriminator = kwargs.pop('discriminator', None) self.repr = kwargs.pop('repr', True) - self.extra = kwargs + self.json_schema_extra = kwargs.pop('json_schema_extra', None) + self.validate_default = kwargs.pop('validate_default', None) + self.frozen = kwargs.pop('frozen', None) + # currently only used on dataclasses + self.init = kwargs.pop('init', None) + self.init_var = kwargs.pop('init_var', None) + self.kw_only = kwargs.pop('kw_only', None) - def __repr_args__(self) -> 'ReprArgs': + self.metadata = self._collect_metadata(kwargs) + annotation_metadata # type: ignore - field_defaults_to_hide: Dict[str, Any] = { - 'repr': True, - **self.__field_constraints__, - } + @staticmethod + def from_field(default: Any = PydanticUndefined, **kwargs: Unpack[_FromFieldInfoInputs]) -> FieldInfo: + """Create a new `FieldInfo` object with the `Field` function. - attrs = ((s, getattr(self, s)) for s in self.__slots__) - return [(a, v) for a, v in attrs if v != field_defaults_to_hide.get(a, None)] + Args: + default: The default value for the field. Defaults to Undefined. + **kwargs: Additional arguments dictionary. - def get_constraints(self) -> Set[str]: + Raises: + TypeError: If 'annotation' is passed as a keyword argument. + + Returns: + A new FieldInfo object with the given parameters. 
+ + Example: + This is how you can create a field with default value like this: + + ```python + import pydantic + + class MyModel(pydantic.BaseModel): + foo: int = pydantic.Field(4) + ``` """ - Gets the constraints set on the field by comparing the constraint value with its default value + if 'annotation' in kwargs: + raise TypeError('"annotation" is not permitted as a Field keyword argument') + return FieldInfo(default=default, **kwargs) - :return: the constraints set on field_info - """ - return {attr for attr, default in self.__field_constraints__.items() if getattr(self, attr) != default} + @staticmethod + def from_annotation(annotation: type[Any]) -> FieldInfo: + """Creates a `FieldInfo` instance from a bare annotation. - def update_from_config(self, from_config: Dict[str, Any]) -> None: + This function is used internally to create a `FieldInfo` from a bare annotation like this: + + ```python + import pydantic + + class MyModel(pydantic.BaseModel): + foo: int # <-- like this + ``` + + We also account for the case where the annotation can be an instance of `Annotated` and where + one of the (not first) arguments in `Annotated` is an instance of `FieldInfo`, e.g.: + + ```python + import annotated_types + from typing_extensions import Annotated + + import pydantic + + class MyModel(pydantic.BaseModel): + foo: Annotated[int, annotated_types.Gt(42)] + bar: Annotated[int, pydantic.Field(gt=42)] + ``` + + Args: + annotation: An annotation object. + + Returns: + An instance of the field metadata. """ - Update this FieldInfo based on a dict from get_field_info, only fields which have not been set are dated. + final = False + if _typing_extra.is_finalvar(annotation): + final = True + if annotation is not typing_extensions.Final: + annotation = typing_extensions.get_args(annotation)[0] + + if _typing_extra.is_annotated(annotation): + first_arg, *extra_args = typing_extensions.get_args(annotation) + if _typing_extra.is_finalvar(first_arg): + final = True + field_info_annotations = [a for a in extra_args if isinstance(a, FieldInfo)] + field_info = FieldInfo.merge_field_infos(*field_info_annotations, annotation=first_arg) + if field_info: + new_field_info = copy(field_info) + new_field_info.annotation = first_arg + new_field_info.frozen = final or field_info.frozen + metadata: list[Any] = [] + for a in extra_args: + if not isinstance(a, FieldInfo): + metadata.append(a) + else: + metadata.extend(a.metadata) + new_field_info.metadata = metadata + return new_field_info + + return FieldInfo(annotation=annotation, frozen=final or None) + + @staticmethod + def from_annotated_attribute(annotation: type[Any], default: Any) -> FieldInfo: + """Create `FieldInfo` from an annotation with a default value. + + This is used in cases like the following: + + ```python + import annotated_types + from typing_extensions import Annotated + + import pydantic + + class MyModel(pydantic.BaseModel): + foo: int = 4 # <-- like this + bar: Annotated[int, annotated_types.Gt(4)] = 4 # <-- or this + spam: Annotated[int, pydantic.Field(gt=4)] = 4 # <-- or this + ``` + + Args: + annotation: The type annotation of the field. + default: The default value of the field. + + Returns: + A field object with the passed values. """ - for attr_name, value in from_config.items(): + if annotation is default: + raise PydanticUserError( + 'Error when building FieldInfo from annotated attribute. 
' + "Make sure you don't have any field name clashing with a type annotation ", + code='unevaluable-type-annotation', + ) + + final = False + if _typing_extra.is_finalvar(annotation): + final = True + if annotation is not typing_extensions.Final: + annotation = typing_extensions.get_args(annotation)[0] + + if isinstance(default, FieldInfo): + default.annotation, annotation_metadata = FieldInfo._extract_metadata(annotation) + default.metadata += annotation_metadata + default = default.merge_field_infos( + *[x for x in annotation_metadata if isinstance(x, FieldInfo)], default, annotation=default.annotation + ) + default.frozen = final or default.frozen + return default + elif isinstance(default, dataclasses.Field): + init_var = False + if annotation is dataclasses.InitVar: + init_var = True + annotation = Any + elif isinstance(annotation, dataclasses.InitVar): + init_var = True + annotation = annotation.type + pydantic_field = FieldInfo._from_dataclass_field(default) + pydantic_field.annotation, annotation_metadata = FieldInfo._extract_metadata(annotation) + pydantic_field.metadata += annotation_metadata + pydantic_field = pydantic_field.merge_field_infos( + *[x for x in annotation_metadata if isinstance(x, FieldInfo)], + pydantic_field, + annotation=pydantic_field.annotation, + ) + pydantic_field.frozen = final or pydantic_field.frozen + pydantic_field.init_var = init_var + pydantic_field.init = getattr(default, 'init', None) + pydantic_field.kw_only = getattr(default, 'kw_only', None) + return pydantic_field + else: + if _typing_extra.is_annotated(annotation): + first_arg, *extra_args = typing_extensions.get_args(annotation) + field_infos = [a for a in extra_args if isinstance(a, FieldInfo)] + field_info = FieldInfo.merge_field_infos(*field_infos, annotation=first_arg, default=default) + metadata: list[Any] = [] + for a in extra_args: + if not isinstance(a, FieldInfo): + metadata.append(a) + else: + metadata.extend(a.metadata) + field_info.metadata = metadata + return field_info + + return FieldInfo(annotation=annotation, default=default, frozen=final or None) + + @staticmethod + def merge_field_infos(*field_infos: FieldInfo, **overrides: Any) -> FieldInfo: + """Merge `FieldInfo` instances keeping only explicitly set attributes. + + Later `FieldInfo` instances override earlier ones. + + Returns: + FieldInfo: A merged FieldInfo instance. + """ + flattened_field_infos: list[FieldInfo] = [] + for field_info in field_infos: + flattened_field_infos.extend(x for x in field_info.metadata if isinstance(x, FieldInfo)) + flattened_field_infos.append(field_info) + field_infos = tuple(flattened_field_infos) + if len(field_infos) == 1: + # No merging necessary, but we still need to make a copy and apply the overrides + field_info = copy(field_infos[0]) + field_info._attributes_set.update(overrides) + for k, v in overrides.items(): + setattr(field_info, k, v) + return field_info # type: ignore + + new_kwargs: dict[str, Any] = {} + metadata = {} + for field_info in field_infos: + new_kwargs.update(field_info._attributes_set) + for x in field_info.metadata: + if not isinstance(x, FieldInfo): + metadata[type(x)] = x + new_kwargs.update(overrides) + field_info = FieldInfo(**new_kwargs) + field_info.metadata = list(metadata.values()) + return field_info + + @staticmethod + def _from_dataclass_field(dc_field: DataclassField[Any]) -> FieldInfo: + """Return a new `FieldInfo` instance from a `dataclasses.Field` instance. + + Args: + dc_field: The `dataclasses.Field` instance to convert. 
+ + Returns: + The corresponding `FieldInfo` instance. + + Raises: + TypeError: If any of the `FieldInfo` kwargs does not match the `dataclass.Field` kwargs. + """ + default = dc_field.default + if default is dataclasses.MISSING: + default = PydanticUndefined + + if dc_field.default_factory is dataclasses.MISSING: + default_factory: typing.Callable[[], Any] | None = None + else: + default_factory = dc_field.default_factory + + # use the `Field` function so in correct kwargs raise the correct `TypeError` + dc_field_metadata = {k: v for k, v in dc_field.metadata.items() if k in _FIELD_ARG_NAMES} + return Field(default=default, default_factory=default_factory, repr=dc_field.repr, **dc_field_metadata) + + @staticmethod + def _extract_metadata(annotation: type[Any] | None) -> tuple[type[Any] | None, list[Any]]: + """Tries to extract metadata/constraints from an annotation if it uses `Annotated`. + + Args: + annotation: The type hint annotation for which metadata has to be extracted. + + Returns: + A tuple containing the extracted metadata type and the list of extra arguments. + """ + if annotation is not None: + if _typing_extra.is_annotated(annotation): + first_arg, *extra_args = typing_extensions.get_args(annotation) + return first_arg, list(extra_args) + + return annotation, [] + + @staticmethod + def _collect_metadata(kwargs: dict[str, Any]) -> list[Any]: + """Collect annotations from kwargs. + + Args: + kwargs: Keyword arguments passed to the function. + + Returns: + A list of metadata objects - a combination of `annotated_types.BaseMetadata` and + `PydanticMetadata`. + """ + metadata: list[Any] = [] + general_metadata = {} + for key, value in list(kwargs.items()): try: - current_value = getattr(self, attr_name) - except AttributeError: - # attr_name is not an attribute of FieldInfo, it should therefore be added to extra - # (except if extra already has this value!) - self.extra.setdefault(attr_name, value) + marker = FieldInfo.metadata_lookup[key] + except KeyError: + continue + + del kwargs[key] + if value is not None: + if marker is None: + general_metadata[key] = value + else: + metadata.append(marker(value)) + if general_metadata: + metadata.append(_fields.pydantic_general_metadata(**general_metadata)) + return metadata + + def get_default(self, *, call_default_factory: bool = False) -> Any: + """Get the default value. + + We expose an option for whether to call the default_factory (if present), as calling it may + result in side effects that we want to avoid. However, there are times when it really should + be called (namely, when instantiating a model via `model_construct`). + + Args: + call_default_factory: Whether to call the default_factory or not. Defaults to `False`. + + Returns: + The default value, calling the default factory if requested or `None` if not set. + """ + if self.default_factory is None: + return _utils.smart_deepcopy(self.default) + elif call_default_factory: + return self.default_factory() + else: + return None + + def is_required(self) -> bool: + """Check if the field is required (i.e., does not have a default value or factory). + + Returns: + `True` if the field is required, `False` otherwise. + """ + return self.default is PydanticUndefined and self.default_factory is None + + def rebuild_annotation(self) -> Any: + """Attempts to rebuild the original annotation for use in function signatures. + + If metadata is present, it adds it to the original annotation using + `Annotated`. Otherwise, it returns the original annotation as-is. 
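[Illustrative note, not part of the vendored patch] The `is_required`/`get_default` pair documented above can be sketched like this; `call_default_factory` defaults to `False` precisely so that factories with side effects are not invoked accidentally:

```python
# Sketch: default handling on FieldInfo, including the call_default_factory flag.
from pydantic.fields import FieldInfo

plain = FieldInfo(default=3)
factory = FieldInfo(default_factory=list)

print(plain.is_required())    # False
print(plain.get_default())    # 3 (deep-copied)

print(factory.is_required())  # False
print(factory.get_default())                             # None - factory not called
print(factory.get_default(call_default_factory=True))    # [] - factory called explicitly
```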
+ + Note that because the metadata has been flattened, the original annotation + may not be reconstructed exactly as originally provided, e.g. if the original + type had unrecognized annotations, or was annotated with a call to `pydantic.Field`. + + Returns: + The rebuilt annotation. + """ + if not self.metadata: + return self.annotation + else: + # Annotated arguments must be a tuple + return typing_extensions.Annotated[(self.annotation, *self.metadata)] # type: ignore + + def apply_typevars_map(self, typevars_map: dict[Any, Any] | None, types_namespace: dict[str, Any] | None) -> None: + """Apply a `typevars_map` to the annotation. + + This method is used when analyzing parametrized generic types to replace typevars with their concrete types. + + This method applies the `typevars_map` to the annotation in place. + + Args: + typevars_map: A dictionary mapping type variables to their concrete types. + types_namespace (dict | None): A dictionary containing related types to the annotated type. + + See Also: + pydantic._internal._generics.replace_types is used for replacing the typevars with + their concrete types. + """ + annotation = _typing_extra.eval_type_lenient(self.annotation, types_namespace) + self.annotation = _generics.replace_types(annotation, typevars_map) + + def __repr_args__(self) -> ReprArgs: + yield 'annotation', _repr.PlainRepr(_repr.display_as_type(self.annotation)) + yield 'required', self.is_required() + + for s in self.__slots__: + if s == '_attributes_set': + continue + if s == 'annotation': + continue + elif s == 'metadata' and not self.metadata: + continue + elif s == 'repr' and self.repr is True: + continue + if s == 'frozen' and self.frozen is False: + continue + if s == 'validation_alias' and self.validation_alias == self.alias: + continue + if s == 'serialization_alias' and self.serialization_alias == self.alias: + continue + if s == 'default_factory' and self.default_factory is not None: + yield 'default_factory', _repr.PlainRepr(_repr.display_as_type(self.default_factory)) else: - if current_value is self.__field_constraints__.get(attr_name, None): - setattr(self, attr_name, value) - elif attr_name == 'exclude': - self.exclude = ValueItems.merge(value, current_value) - elif attr_name == 'include': - self.include = ValueItems.merge(value, current_value, intersect=True) - - def _validate(self) -> None: - if self.default is not Undefined and self.default_factory is not None: - raise ValueError('cannot specify both default and default_factory') + value = getattr(self, s) + if value is not None and value is not PydanticUndefined: + yield s, value -def Field( - default: Any = Undefined, +class _EmptyKwargs(typing_extensions.TypedDict): + """This class exists solely to ensure that type checking warns about passing `**extra` in `Field`.""" + + +_DefaultValues = dict( + default=..., + default_factory=None, + alias=None, + alias_priority=None, + validation_alias=None, + serialization_alias=None, + title=None, + description=None, + examples=None, + exclude=None, + discriminator=None, + json_schema_extra=None, + frozen=None, + validate_default=None, + repr=True, + init=None, + init_var=None, + kw_only=None, + pattern=None, + strict=None, + gt=None, + ge=None, + lt=None, + le=None, + multiple_of=None, + allow_inf_nan=None, + max_digits=None, + decimal_places=None, + min_length=None, + max_length=None, +) + + +def Field( # noqa: C901 + default: Any = PydanticUndefined, *, - default_factory: Optional[NoArgAnyCallable] = None, - alias: str = None, - title: str = None, - 
description: str = None, - exclude: Union['AbstractSetIntStr', 'MappingIntStrAny', Any] = None, - include: Union['AbstractSetIntStr', 'MappingIntStrAny', Any] = None, - const: bool = None, - gt: float = None, - ge: float = None, - lt: float = None, - le: float = None, - multiple_of: float = None, - allow_inf_nan: bool = None, - max_digits: int = None, - decimal_places: int = None, - min_items: int = None, - max_items: int = None, - unique_items: bool = None, - min_length: int = None, - max_length: int = None, - allow_mutation: bool = True, - regex: str = None, - discriminator: str = None, - repr: bool = True, - **extra: Any, + default_factory: typing.Callable[[], Any] | None = _Unset, + alias: str | None = _Unset, + alias_priority: int | None = _Unset, + validation_alias: str | AliasPath | AliasChoices | None = _Unset, + serialization_alias: str | None = _Unset, + title: str | None = _Unset, + description: str | None = _Unset, + examples: list[Any] | None = _Unset, + exclude: bool | None = _Unset, + discriminator: str | types.Discriminator | None = _Unset, + json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None = _Unset, + frozen: bool | None = _Unset, + validate_default: bool | None = _Unset, + repr: bool = _Unset, + init: bool | None = _Unset, + init_var: bool | None = _Unset, + kw_only: bool | None = _Unset, + pattern: str | None = _Unset, + strict: bool | None = _Unset, + gt: float | None = _Unset, + ge: float | None = _Unset, + lt: float | None = _Unset, + le: float | None = _Unset, + multiple_of: float | None = _Unset, + allow_inf_nan: bool | None = _Unset, + max_digits: int | None = _Unset, + decimal_places: int | None = _Unset, + min_length: int | None = _Unset, + max_length: int | None = _Unset, + union_mode: Literal['smart', 'left_to_right'] = _Unset, + **extra: Unpack[_EmptyKwargs], ) -> Any: - """ - Used to provide extra information about a field, either for the model schema or complex validation. Some arguments - apply only to number fields (``int``, ``float``, ``Decimal``) and some apply only to ``str``. + """Usage docs: https://docs.pydantic.dev/2.6/concepts/fields - :param default: since this is replacing the field’s default, its first argument is used - to set the default, use ellipsis (``...``) to indicate the field is required - :param default_factory: callable that will be called when a default value is needed for this field - If both `default` and `default_factory` are set, an error is raised. - :param alias: the public name of the field - :param title: can be any string, used in the schema - :param description: can be any string, used in the schema - :param exclude: exclude this field while dumping. - Takes same values as the ``include`` and ``exclude`` arguments on the ``.dict`` method. - :param include: include this field while dumping. - Takes same values as the ``include`` and ``exclude`` arguments on the ``.dict`` method. - :param const: this field is required and *must* take it's default value - :param gt: only applies to numbers, requires the field to be "greater than". The schema - will have an ``exclusiveMinimum`` validation keyword - :param ge: only applies to numbers, requires the field to be "greater than or equal to". The - schema will have a ``minimum`` validation keyword - :param lt: only applies to numbers, requires the field to be "less than". The schema - will have an ``exclusiveMaximum`` validation keyword - :param le: only applies to numbers, requires the field to be "less than or equal to". 
The - schema will have a ``maximum`` validation keyword - :param multiple_of: only applies to numbers, requires the field to be "a multiple of". The - schema will have a ``multipleOf`` validation keyword - :param allow_inf_nan: only applies to numbers, allows the field to be NaN or infinity (+inf or -inf), - which is a valid Python float. Default True, set to False for compatibility with JSON. - :param max_digits: only applies to Decimals, requires the field to have a maximum number - of digits within the decimal. It does not include a zero before the decimal point or trailing decimal zeroes. - :param decimal_places: only applies to Decimals, requires the field to have at most a number of decimal places - allowed. It does not include trailing decimal zeroes. - :param min_items: only applies to lists, requires the field to have a minimum number of - elements. The schema will have a ``minItems`` validation keyword - :param max_items: only applies to lists, requires the field to have a maximum number of - elements. The schema will have a ``maxItems`` validation keyword - :param unique_items: only applies to lists, requires the field not to have duplicated - elements. The schema will have a ``uniqueItems`` validation keyword - :param min_length: only applies to strings, requires the field to have a minimum length. The - schema will have a ``maximum`` validation keyword - :param max_length: only applies to strings, requires the field to have a maximum length. The - schema will have a ``maxLength`` validation keyword - :param allow_mutation: a boolean which defaults to True. When False, the field raises a TypeError if the field is - assigned on an instance. The BaseModel Config must set validate_assignment to True - :param regex: only applies to strings, requires the field match against a regular expression - pattern string. The schema will have a ``pattern`` validation keyword - :param discriminator: only useful with a (discriminated a.k.a. tagged) `Union` of sub models with a common field. - The `discriminator` is the name of this common field to shorten validation and improve generated schema - :param repr: show this field in the representation - :param **extra: any additional keyword arguments will be added as is to the schema + Create a field for objects that can be configured. + + Used to provide extra information about a field, either for the model schema or complex validation. Some arguments + apply only to number fields (`int`, `float`, `Decimal`) and some apply only to `str`. + + Note: + - Any `_Unset` objects will be replaced by the corresponding value defined in the `_DefaultValues` dictionary. If a key for the `_Unset` object is not found in the `_DefaultValues` dictionary, it will default to `None` + + Args: + default: Default value if the field is not set. + default_factory: A callable to generate the default value, such as :func:`~datetime.utcnow`. + alias: The name to use for the attribute when validating or serializing by alias. + This is often used for things like converting between snake and camel case. + alias_priority: Priority of the alias. This affects whether an alias generator is used. + validation_alias: Like `alias`, but only affects validation, not serialization. + serialization_alias: Like `alias`, but only affects serialization, not validation. + title: Human-readable title. + description: Human-readable description. + examples: Example values for this field. + exclude: Whether to exclude the field from the model serialization. 
+ discriminator: Field name or Discriminator for discriminating the type in a tagged union. + json_schema_extra: A dict or callable to provide extra JSON schema properties. + frozen: Whether the field is frozen. If true, attempts to change the value on an instance will raise an error. + validate_default: If `True`, apply validation to the default value every time you create an instance. + Otherwise, for performance reasons, the default value of the field is trusted and not validated. + repr: A boolean indicating whether to include the field in the `__repr__` output. + init: Whether the field should be included in the constructor of the dataclass. + (Only applies to dataclasses.) + init_var: Whether the field should _only_ be included in the constructor of the dataclass. + (Only applies to dataclasses.) + kw_only: Whether the field should be a keyword-only argument in the constructor of the dataclass. + (Only applies to dataclasses.) + strict: If `True`, strict validation is applied to the field. + See [Strict Mode](../concepts/strict_mode.md) for details. + gt: Greater than. If set, value must be greater than this. Only applicable to numbers. + ge: Greater than or equal. If set, value must be greater than or equal to this. Only applicable to numbers. + lt: Less than. If set, value must be less than this. Only applicable to numbers. + le: Less than or equal. If set, value must be less than or equal to this. Only applicable to numbers. + multiple_of: Value must be a multiple of this. Only applicable to numbers. + min_length: Minimum length for strings. + max_length: Maximum length for strings. + pattern: Pattern for strings (a regular expression). + allow_inf_nan: Allow `inf`, `-inf`, `nan`. Only applicable to numbers. + max_digits: Maximum number of allow digits for strings. + decimal_places: Maximum number of decimal places allowed for numbers. + union_mode: The strategy to apply when validating a union. Can be `smart` (the default), or `left_to_right`. + See [Union Mode](standard_library_types.md#union-mode) for details. + extra: (Deprecated) Extra fields that will be included in the JSON schema. + + !!! warning Deprecated + The `extra` kwargs is deprecated. Use `json_schema_extra` instead. + + Returns: + A new [`FieldInfo`][pydantic.fields.FieldInfo]. The return annotation is `Any` so `Field` can be used on + type-annotated fields without causing a type error. """ - field_info = FieldInfo( + # Check deprecated and removed params from V1. This logic should eventually be removed. 
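[Illustrative note, not part of the vendored patch] For orientation, the v2 `Field()` signature documented above is typically used as follows; the model and field names are made up for the sketch:

```python
# Sketch: typical use of the v2 Field() signature.
from pydantic import BaseModel, Field

class Product(BaseModel):
    name: str = Field(min_length=1, max_length=50, description='Display name')
    price: float = Field(gt=0, allow_inf_nan=False)
    sku: str = Field(alias='SKU', pattern=r'^[A-Z0-9-]+$')

p = Product.model_validate({'name': 'Widget', 'price': 9.99, 'SKU': 'WID-001'})
print(p.sku)                                    # 'WID-001'
print(Product.model_fields['price'].metadata)   # includes Gt(gt=0)
```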
+ const = extra.pop('const', None) # type: ignore + if const is not None: + raise PydanticUserError('`const` is removed, use `Literal` instead', code='removed-kwargs') + + min_items = extra.pop('min_items', None) # type: ignore + if min_items is not None: + warn('`min_items` is deprecated and will be removed, use `min_length` instead', DeprecationWarning) + if min_length in (None, _Unset): + min_length = min_items # type: ignore + + max_items = extra.pop('max_items', None) # type: ignore + if max_items is not None: + warn('`max_items` is deprecated and will be removed, use `max_length` instead', DeprecationWarning) + if max_length in (None, _Unset): + max_length = max_items # type: ignore + + unique_items = extra.pop('unique_items', None) # type: ignore + if unique_items is not None: + raise PydanticUserError( + ( + '`unique_items` is removed, use `Set` instead' + '(this feature is discussed in https://github.com/pydantic/pydantic-core/issues/296)' + ), + code='removed-kwargs', + ) + + allow_mutation = extra.pop('allow_mutation', None) # type: ignore + if allow_mutation is not None: + warn('`allow_mutation` is deprecated and will be removed. use `frozen` instead', DeprecationWarning) + if allow_mutation is False: + frozen = True + + regex = extra.pop('regex', None) # type: ignore + if regex is not None: + raise PydanticUserError('`regex` is removed. use `pattern` instead', code='removed-kwargs') + + if extra: + warn( + 'Using extra keyword arguments on `Field` is deprecated and will be removed.' + ' Use `json_schema_extra` instead.' + f' (Extra keys: {", ".join(k.__repr__() for k in extra.keys())})', + DeprecationWarning, + ) + if not json_schema_extra or json_schema_extra is _Unset: + json_schema_extra = extra # type: ignore + + if ( + validation_alias + and validation_alias is not _Unset + and not isinstance(validation_alias, (str, AliasChoices, AliasPath)) + ): + raise TypeError('Invalid `validation_alias` type. it should be `str`, `AliasChoices`, or `AliasPath`') + + if serialization_alias in (_Unset, None) and isinstance(alias, str): + serialization_alias = alias + + if validation_alias in (_Unset, None): + validation_alias = alias + + include = extra.pop('include', None) # type: ignore + if include is not None: + warn('`include` is deprecated and does nothing. 
It will be removed, use `exclude` instead', DeprecationWarning) + + return FieldInfo.from_field( default, default_factory=default_factory, alias=alias, + alias_priority=alias_priority, + validation_alias=validation_alias, + serialization_alias=serialization_alias, title=title, description=description, + examples=examples, exclude=exclude, - include=include, - const=const, + discriminator=discriminator, + json_schema_extra=json_schema_extra, + frozen=frozen, + pattern=pattern, + validate_default=validate_default, + repr=repr, + init=init, + init_var=init_var, + kw_only=kw_only, + strict=strict, gt=gt, ge=ge, lt=lt, le=le, multiple_of=multiple_of, + min_length=min_length, + max_length=max_length, allow_inf_nan=allow_inf_nan, max_digits=max_digits, decimal_places=decimal_places, - min_items=min_items, - max_items=max_items, - unique_items=unique_items, - min_length=min_length, - max_length=max_length, - allow_mutation=allow_mutation, - regex=regex, - discriminator=discriminator, - repr=repr, - **extra, + union_mode=union_mode, ) - field_info._validate() - return field_info -# used to be an enum but changed to int's for small performance improvement as less access overhead -SHAPE_SINGLETON = 1 -SHAPE_LIST = 2 -SHAPE_SET = 3 -SHAPE_MAPPING = 4 -SHAPE_TUPLE = 5 -SHAPE_TUPLE_ELLIPSIS = 6 -SHAPE_SEQUENCE = 7 -SHAPE_FROZENSET = 8 -SHAPE_ITERABLE = 9 -SHAPE_GENERIC = 10 -SHAPE_DEQUE = 11 -SHAPE_DICT = 12 -SHAPE_DEFAULTDICT = 13 -SHAPE_COUNTER = 14 -SHAPE_NAME_LOOKUP = { - SHAPE_LIST: 'List[{}]', - SHAPE_SET: 'Set[{}]', - SHAPE_TUPLE_ELLIPSIS: 'Tuple[{}, ...]', - SHAPE_SEQUENCE: 'Sequence[{}]', - SHAPE_FROZENSET: 'FrozenSet[{}]', - SHAPE_ITERABLE: 'Iterable[{}]', - SHAPE_DEQUE: 'Deque[{}]', - SHAPE_DICT: 'Dict[{}]', - SHAPE_DEFAULTDICT: 'DefaultDict[{}]', - SHAPE_COUNTER: 'Counter[{}]', -} - -MAPPING_LIKE_SHAPES: Set[int] = {SHAPE_DEFAULTDICT, SHAPE_DICT, SHAPE_MAPPING, SHAPE_COUNTER} +_FIELD_ARG_NAMES = set(inspect.signature(Field).parameters) +_FIELD_ARG_NAMES.remove('extra') # do not include the varkwargs parameter -class ModelField(Representation): - __slots__ = ( - 'type_', - 'outer_type_', - 'annotation', - 'sub_fields', - 'sub_fields_mapping', - 'key_field', - 'validators', - 'pre_validators', - 'post_validators', - 'default', - 'default_factory', - 'required', - 'final', - 'model_config', - 'name', - 'alias', - 'has_alias', - 'field_info', - 'discriminator_key', - 'discriminator_alias', - 'validate_always', - 'allow_none', - 'shape', - 'class_validators', - 'parse_json', - ) +class ModelPrivateAttr(_repr.Representation): + """A descriptor for private attributes in class models. + + !!! warning + You generally shouldn't be creating `ModelPrivateAttr` instances directly, instead use + `pydantic.fields.PrivateAttr`. (This is similar to `FieldInfo` vs. `Field`.) + + Attributes: + default: The default value of the attribute if not provided. + default_factory: A callable function that generates the default value of the + attribute if not provided. 
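[Illustrative note, not part of the vendored patch] The V1-compatibility shims above behave roughly as sketched here: deprecated keywords such as `min_items` still work but warn and are mapped onto their v2 equivalents, while removed ones such as `regex` raise `PydanticUserError` with the `removed-kwargs` code:

```python
# Sketch: what the deprecated/removed V1 keyword handling in Field() does.
import warnings

from pydantic import Field, PydanticUserError

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    fi = Field(min_items=2)            # deprecated: mapped onto min_length
print(caught[0].category.__name__)     # 'PydanticDeprecatedSince20'
print(fi.metadata)                     # [MinLen(min_length=2)]

try:
    Field(regex=r'^\d+$')              # removed in V2
except PydanticUserError as exc:
    print(exc.code)                    # 'removed-kwargs'
```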
+ """ + + __slots__ = 'default', 'default_factory' def __init__( - self, - *, - name: str, - type_: Type[Any], - class_validators: Optional[Dict[str, Validator]], - model_config: Type['BaseConfig'], - default: Any = None, - default_factory: Optional[NoArgAnyCallable] = None, - required: 'BoolUndefined' = Undefined, - final: bool = False, - alias: str = None, - field_info: Optional[FieldInfo] = None, + self, default: Any = PydanticUndefined, *, default_factory: typing.Callable[[], Any] | None = None ) -> None: - - self.name: str = name - self.has_alias: bool = alias is not None - self.alias: str = alias if alias is not None else name - self.annotation = type_ - self.type_: Any = convert_generics(type_) - self.outer_type_: Any = type_ - self.class_validators = class_validators or {} - self.default: Any = default - self.default_factory: Optional[NoArgAnyCallable] = default_factory - self.required: 'BoolUndefined' = required - self.final: bool = final - self.model_config = model_config - self.field_info: FieldInfo = field_info or FieldInfo(default) - self.discriminator_key: Optional[str] = self.field_info.discriminator - self.discriminator_alias: Optional[str] = self.discriminator_key - - self.allow_none: bool = False - self.validate_always: bool = False - self.sub_fields: Optional[List[ModelField]] = None - self.sub_fields_mapping: Optional[Dict[str, 'ModelField']] = None # used for discriminated union - self.key_field: Optional[ModelField] = None - self.validators: 'ValidatorsList' = [] - self.pre_validators: Optional['ValidatorsList'] = None - self.post_validators: Optional['ValidatorsList'] = None - self.parse_json: bool = False - self.shape: int = SHAPE_SINGLETON - self.model_config.prepare_field(self) - self.prepare() - - def get_default(self) -> Any: - return smart_deepcopy(self.default) if self.default_factory is None else self.default_factory() - - @staticmethod - def _get_field_info( - field_name: str, annotation: Any, value: Any, config: Type['BaseConfig'] - ) -> Tuple[FieldInfo, Any]: - """ - Get a FieldInfo from a root typing.Annotated annotation, value, or config default. - - The FieldInfo may be set in typing.Annotated or the value, but not both. If neither contain - a FieldInfo, a new one will be created using the config. - - :param field_name: name of the field for use in error messages - :param annotation: a type hint such as `str` or `Annotated[str, Field(..., min_length=5)]` - :param value: the field's assigned value - :param config: the model's config object - :return: the FieldInfo contained in the `annotation`, the value, or a new one from the config. 
- """ - field_info_from_config = config.get_field_info(field_name) - - field_info = None - if get_origin(annotation) is Annotated: - field_infos = [arg for arg in get_args(annotation)[1:] if isinstance(arg, FieldInfo)] - if len(field_infos) > 1: - raise ValueError(f'cannot specify multiple `Annotated` `Field`s for {field_name!r}') - field_info = next(iter(field_infos), None) - if field_info is not None: - field_info = copy.copy(field_info) - field_info.update_from_config(field_info_from_config) - if field_info.default not in (Undefined, Required): - raise ValueError(f'`Field` default cannot be set in `Annotated` for {field_name!r}') - if value is not Undefined and value is not Required: - # check also `Required` because of `validate_arguments` that sets `...` as default value - field_info.default = value - - if isinstance(value, FieldInfo): - if field_info is not None: - raise ValueError(f'cannot specify `Annotated` and value `Field`s together for {field_name!r}') - field_info = value - field_info.update_from_config(field_info_from_config) - elif field_info is None: - field_info = FieldInfo(value, **field_info_from_config) - value = None if field_info.default_factory is not None else field_info.default - field_info._validate() - return field_info, value - - @classmethod - def infer( - cls, - *, - name: str, - value: Any, - annotation: Any, - class_validators: Optional[Dict[str, Validator]], - config: Type['BaseConfig'], - ) -> 'ModelField': - from .schema import get_annotation_from_field_info - - field_info, value = cls._get_field_info(name, annotation, value, config) - required: 'BoolUndefined' = Undefined - if value is Required: - required = True - value = None - elif value is not Undefined: - required = False - annotation = get_annotation_from_field_info(annotation, field_info, name, config.validate_assignment) - - return cls( - name=name, - type_=annotation, - alias=field_info.alias, - class_validators=class_validators, - default=value, - default_factory=field_info.default_factory, - required=required, - model_config=config, - field_info=field_info, - ) - - def set_config(self, config: Type['BaseConfig']) -> None: - self.model_config = config - info_from_config = config.get_field_info(self.name) - config.prepare_field(self) - new_alias = info_from_config.get('alias') - new_alias_priority = info_from_config.get('alias_priority') or 0 - if new_alias and new_alias_priority >= (self.field_info.alias_priority or 0): - self.field_info.alias = new_alias - self.field_info.alias_priority = new_alias_priority - self.alias = new_alias - new_exclude = info_from_config.get('exclude') - if new_exclude is not None: - self.field_info.exclude = ValueItems.merge(self.field_info.exclude, new_exclude) - new_include = info_from_config.get('include') - if new_include is not None: - self.field_info.include = ValueItems.merge(self.field_info.include, new_include, intersect=True) - - @property - def alt_alias(self) -> bool: - return self.name != self.alias - - def prepare(self) -> None: - """ - Prepare the field but inspecting self.default, self.type_ etc. - - Note: this method is **not** idempotent (because _type_analysis is not idempotent), - e.g. calling it it multiple times may modify the field and configure it incorrectly. 
- """ - self._set_default_and_type() - if self.type_.__class__ is ForwardRef or self.type_.__class__ is DeferredType: - # self.type_ is currently a ForwardRef and there's nothing we can do now, - # user will need to call model.update_forward_refs() - return - - self._type_analysis() - if self.required is Undefined: - self.required = True - if self.default is Undefined and self.default_factory is None: - self.default = None - self.populate_validators() - - def _set_default_and_type(self) -> None: - """ - Set the default value, infer the type if needed and check if `None` value is valid. - """ - if self.default_factory is not None: - if self.type_ is Undefined: - raise errors_.ConfigError( - f'you need to set the type of field {self.name!r} when using `default_factory`' - ) - return - - default_value = self.get_default() - - if default_value is not None and self.type_ is Undefined: - self.type_ = default_value.__class__ - self.outer_type_ = self.type_ - self.annotation = self.type_ - - if self.type_ is Undefined: - raise errors_.ConfigError(f'unable to infer type for attribute "{self.name}"') - - if self.required is False and default_value is None: - self.allow_none = True - - def _type_analysis(self) -> None: # noqa: C901 (ignore complexity) - # typing interface is horrible, we have to do some ugly checks - if lenient_issubclass(self.type_, JsonWrapper): - self.type_ = self.type_.inner_type - self.parse_json = True - elif lenient_issubclass(self.type_, Json): - self.type_ = Any - self.parse_json = True - elif isinstance(self.type_, TypeVar): - if self.type_.__bound__: - self.type_ = self.type_.__bound__ - elif self.type_.__constraints__: - self.type_ = Union[self.type_.__constraints__] - else: - self.type_ = Any - elif is_new_type(self.type_): - self.type_ = new_type_supertype(self.type_) - - if self.type_ is Any or self.type_ is object: - if self.required is Undefined: - self.required = False - self.allow_none = True - return - elif self.type_ is Pattern or self.type_ is re.Pattern: - # python 3.7 only, Pattern is a typing object but without sub fields - return - elif is_literal_type(self.type_): - return - elif is_typeddict(self.type_): - return - - if is_finalvar(self.type_): - self.final = True - - if self.type_ is Final: - self.type_ = Any - else: - self.type_ = get_args(self.type_)[0] - - self._type_analysis() - return - - origin = get_origin(self.type_) - - if origin is Annotated or is_typeddict_special(origin): - self.type_ = get_args(self.type_)[0] - self._type_analysis() - return - - if self.discriminator_key is not None and not is_union(origin): - raise TypeError('`discriminator` can only be used with `Union` type with more than one variant') - - # add extra check for `collections.abc.Hashable` for python 3.10+ where origin is not `None` - if origin is None or origin is CollectionsHashable: - # field is not "typing" object eg. Union, Dict, List etc. - # allow None for virtual superclasses of NoneType, e.g. 
Hashable - if isinstance(self.type_, type) and isinstance(None, self.type_): - self.allow_none = True - return - elif origin is Callable: - return - elif is_union(origin): - types_ = [] - for type_ in get_args(self.type_): - if is_none_type(type_) or type_ is Any or type_ is object: - if self.required is Undefined: - self.required = False - self.allow_none = True - if is_none_type(type_): - continue - types_.append(type_) - - if len(types_) == 1: - # Optional[] - self.type_ = types_[0] - # this is the one case where the "outer type" isn't just the original type - self.outer_type_ = self.type_ - # re-run to correctly interpret the new self.type_ - self._type_analysis() - else: - self.sub_fields = [self._create_sub_type(t, f'{self.name}_{display_as_type(t)}') for t in types_] - - if self.discriminator_key is not None: - self.prepare_discriminated_union_sub_fields() - return - elif issubclass(origin, Tuple): # type: ignore - # origin == Tuple without item type - args = get_args(self.type_) - if not args: # plain tuple - self.type_ = Any - self.shape = SHAPE_TUPLE_ELLIPSIS - elif len(args) == 2 and args[1] is Ellipsis: # e.g. Tuple[int, ...] - self.type_ = args[0] - self.shape = SHAPE_TUPLE_ELLIPSIS - self.sub_fields = [self._create_sub_type(args[0], f'{self.name}_0')] - elif args == ((),): # Tuple[()] means empty tuple - self.shape = SHAPE_TUPLE - self.type_ = Any - self.sub_fields = [] - else: - self.shape = SHAPE_TUPLE - self.sub_fields = [self._create_sub_type(t, f'{self.name}_{i}') for i, t in enumerate(args)] - return - elif issubclass(origin, List): - # Create self validators - get_validators = getattr(self.type_, '__get_validators__', None) - if get_validators: - self.class_validators.update( - {f'list_{i}': Validator(validator, pre=True) for i, validator in enumerate(get_validators())} - ) - - self.type_ = get_args(self.type_)[0] - self.shape = SHAPE_LIST - elif issubclass(origin, Set): - # Create self validators - get_validators = getattr(self.type_, '__get_validators__', None) - if get_validators: - self.class_validators.update( - {f'set_{i}': Validator(validator, pre=True) for i, validator in enumerate(get_validators())} - ) - - self.type_ = get_args(self.type_)[0] - self.shape = SHAPE_SET - elif issubclass(origin, FrozenSet): - # Create self validators - get_validators = getattr(self.type_, '__get_validators__', None) - if get_validators: - self.class_validators.update( - {f'frozenset_{i}': Validator(validator, pre=True) for i, validator in enumerate(get_validators())} - ) - - self.type_ = get_args(self.type_)[0] - self.shape = SHAPE_FROZENSET - elif issubclass(origin, Deque): - self.type_ = get_args(self.type_)[0] - self.shape = SHAPE_DEQUE - elif issubclass(origin, Sequence): - self.type_ = get_args(self.type_)[0] - self.shape = SHAPE_SEQUENCE - # priority to most common mapping: dict - elif origin is dict or origin is Dict: - self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True) - self.type_ = get_args(self.type_)[1] - self.shape = SHAPE_DICT - elif issubclass(origin, DefaultDict): - self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True) - self.type_ = get_args(self.type_)[1] - self.shape = SHAPE_DEFAULTDICT - elif issubclass(origin, Counter): - self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True) - self.type_ = int - self.shape = SHAPE_COUNTER - elif issubclass(origin, Mapping): - self.key_field = self._create_sub_type(get_args(self.type_)[0], 
'key_' + self.name, for_keys=True) - self.type_ = get_args(self.type_)[1] - self.shape = SHAPE_MAPPING - # Equality check as almost everything inherits form Iterable, including str - # check for Iterable and CollectionsIterable, as it could receive one even when declared with the other - elif origin in {Iterable, CollectionsIterable}: - self.type_ = get_args(self.type_)[0] - self.shape = SHAPE_ITERABLE - self.sub_fields = [self._create_sub_type(self.type_, f'{self.name}_type')] - elif issubclass(origin, Type): # type: ignore - return - elif hasattr(origin, '__get_validators__') or self.model_config.arbitrary_types_allowed: - # Is a Pydantic-compatible generic that handles itself - # or we have arbitrary_types_allowed = True - self.shape = SHAPE_GENERIC - self.sub_fields = [self._create_sub_type(t, f'{self.name}_{i}') for i, t in enumerate(get_args(self.type_))] - self.type_ = origin - return - else: - raise TypeError(f'Fields of type "{origin}" are not supported.') - - # type_ has been refined eg. as the type of a List and sub_fields needs to be populated - self.sub_fields = [self._create_sub_type(self.type_, '_' + self.name)] - - def prepare_discriminated_union_sub_fields(self) -> None: - """ - Prepare the mapping -> and update `sub_fields` - Note that this process can be aborted if a `ForwardRef` is encountered - """ - assert self.discriminator_key is not None - - if self.type_.__class__ is DeferredType: - return - - assert self.sub_fields is not None - sub_fields_mapping: Dict[str, 'ModelField'] = {} - all_aliases: Set[str] = set() - - for sub_field in self.sub_fields: - t = sub_field.type_ - if t.__class__ is ForwardRef: - # Stopping everything...will need to call `update_forward_refs` - return - - alias, discriminator_values = get_discriminator_alias_and_values(t, self.discriminator_key) - all_aliases.add(alias) - for discriminator_value in discriminator_values: - sub_fields_mapping[discriminator_value] = sub_field - - self.sub_fields_mapping = sub_fields_mapping - self.discriminator_alias = get_unique_discriminator_alias(all_aliases, self.discriminator_key) - - def _create_sub_type(self, type_: Type[Any], name: str, *, for_keys: bool = False) -> 'ModelField': - if for_keys: - class_validators = None - else: - # validators for sub items should not have `each_item` as we want to check only the first sublevel - class_validators = { - k: Validator( - func=v.func, - pre=v.pre, - each_item=False, - always=v.always, - check_fields=v.check_fields, - skip_on_failure=v.skip_on_failure, - ) - for k, v in self.class_validators.items() - if v.each_item - } - - field_info, _ = self._get_field_info(name, type_, None, self.model_config) - - return self.__class__( - type_=type_, - name=name, - class_validators=class_validators, - model_config=self.model_config, - field_info=field_info, - ) - - def populate_validators(self) -> None: - """ - Prepare self.pre_validators, self.validators, and self.post_validators based on self.type_'s __get_validators__ - and class validators. This method should be idempotent, e.g. it should be safe to call multiple times - without mis-configuring the field. 
- """ - self.validate_always = getattr(self.type_, 'validate_always', False) or any( - v.always for v in self.class_validators.values() - ) - - class_validators_ = self.class_validators.values() - if not self.sub_fields or self.shape == SHAPE_GENERIC: - get_validators = getattr(self.type_, '__get_validators__', None) - v_funcs = ( - *[v.func for v in class_validators_ if v.each_item and v.pre], - *(get_validators() if get_validators else list(find_validators(self.type_, self.model_config))), - *[v.func for v in class_validators_ if v.each_item and not v.pre], - ) - self.validators = prep_validators(v_funcs) - - self.pre_validators = [] - self.post_validators = [] - - if self.field_info and self.field_info.const: - self.post_validators.append(make_generic_validator(constant_validator)) - - if class_validators_: - self.pre_validators += prep_validators(v.func for v in class_validators_ if not v.each_item and v.pre) - self.post_validators += prep_validators(v.func for v in class_validators_ if not v.each_item and not v.pre) - - if self.parse_json: - self.pre_validators.append(make_generic_validator(validate_json)) - - self.pre_validators = self.pre_validators or None - self.post_validators = self.post_validators or None - - def validate( - self, v: Any, values: Dict[str, Any], *, loc: 'LocStr', cls: Optional['ModelOrDc'] = None - ) -> 'ValidateReturn': - - assert self.type_.__class__ is not DeferredType - - if self.type_.__class__ is ForwardRef: - assert cls is not None - raise ConfigError( - f'field "{self.name}" not yet prepared so type is still a ForwardRef, ' - f'you might need to call {cls.__name__}.update_forward_refs().' - ) - - errors: Optional['ErrorList'] - if self.pre_validators: - v, errors = self._apply_validators(v, values, loc, cls, self.pre_validators) - if errors: - return v, errors - - if v is None: - if is_none_type(self.type_): - # keep validating - pass - elif self.allow_none: - if self.post_validators: - return self._apply_validators(v, values, loc, cls, self.post_validators) - else: - return None, None - else: - return v, ErrorWrapper(NoneIsNotAllowedError(), loc) - - if self.shape == SHAPE_SINGLETON: - v, errors = self._validate_singleton(v, values, loc, cls) - elif self.shape in MAPPING_LIKE_SHAPES: - v, errors = self._validate_mapping_like(v, values, loc, cls) - elif self.shape == SHAPE_TUPLE: - v, errors = self._validate_tuple(v, values, loc, cls) - elif self.shape == SHAPE_ITERABLE: - v, errors = self._validate_iterable(v, values, loc, cls) - elif self.shape == SHAPE_GENERIC: - v, errors = self._apply_validators(v, values, loc, cls, self.validators) - else: - # sequence, list, set, generator, tuple with ellipsis, frozen set - v, errors = self._validate_sequence_like(v, values, loc, cls) - - if not errors and self.post_validators: - v, errors = self._apply_validators(v, values, loc, cls, self.post_validators) - return v, errors - - def _validate_sequence_like( # noqa: C901 (ignore complexity) - self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] - ) -> 'ValidateReturn': - """ - Validate sequence-like containers: lists, tuples, sets and generators - Note that large if-else blocks are necessary to enable Cython - optimization, which is why we disable the complexity check above. 
- """ - if not sequence_like(v): - e: errors_.PydanticTypeError - if self.shape == SHAPE_LIST: - e = errors_.ListError() - elif self.shape in (SHAPE_TUPLE, SHAPE_TUPLE_ELLIPSIS): - e = errors_.TupleError() - elif self.shape == SHAPE_SET: - e = errors_.SetError() - elif self.shape == SHAPE_FROZENSET: - e = errors_.FrozenSetError() - else: - e = errors_.SequenceError() - return v, ErrorWrapper(e, loc) - - loc = loc if isinstance(loc, tuple) else (loc,) - result = [] - errors: List[ErrorList] = [] - for i, v_ in enumerate(v): - v_loc = *loc, i - r, ee = self._validate_singleton(v_, values, v_loc, cls) - if ee: - errors.append(ee) - else: - result.append(r) - - if errors: - return v, errors - - converted: Union[List[Any], Set[Any], FrozenSet[Any], Tuple[Any, ...], Iterator[Any], Deque[Any]] = result - - if self.shape == SHAPE_SET: - converted = set(result) - elif self.shape == SHAPE_FROZENSET: - converted = frozenset(result) - elif self.shape == SHAPE_TUPLE_ELLIPSIS: - converted = tuple(result) - elif self.shape == SHAPE_DEQUE: - converted = deque(result) - elif self.shape == SHAPE_SEQUENCE: - if isinstance(v, tuple): - converted = tuple(result) - elif isinstance(v, set): - converted = set(result) - elif isinstance(v, Generator): - converted = iter(result) - elif isinstance(v, deque): - converted = deque(result) - return converted, None - - def _validate_iterable( - self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] - ) -> 'ValidateReturn': - """ - Validate Iterables. - - This intentionally doesn't validate values to allow infinite generators. - """ - - try: - iterable = iter(v) - except TypeError: - return v, ErrorWrapper(errors_.IterableError(), loc) - return iterable, None - - def _validate_tuple( - self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] - ) -> 'ValidateReturn': - e: Optional[Exception] = None - if not sequence_like(v): - e = errors_.TupleError() - else: - actual_length, expected_length = len(v), len(self.sub_fields) # type: ignore - if actual_length != expected_length: - e = errors_.TupleLengthError(actual_length=actual_length, expected_length=expected_length) - - if e: - return v, ErrorWrapper(e, loc) - - loc = loc if isinstance(loc, tuple) else (loc,) - result = [] - errors: List[ErrorList] = [] - for i, (v_, field) in enumerate(zip(v, self.sub_fields)): # type: ignore - v_loc = *loc, i - r, ee = field.validate(v_, values, loc=v_loc, cls=cls) - if ee: - errors.append(ee) - else: - result.append(r) - - if errors: - return v, errors - else: - return tuple(result), None - - def _validate_mapping_like( - self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] - ) -> 'ValidateReturn': - try: - v_iter = dict_validator(v) - except TypeError as exc: - return v, ErrorWrapper(exc, loc) - - loc = loc if isinstance(loc, tuple) else (loc,) - result, errors = {}, [] - for k, v_ in v_iter.items(): - v_loc = *loc, '__key__' - key_result, key_errors = self.key_field.validate(k, values, loc=v_loc, cls=cls) # type: ignore - if key_errors: - errors.append(key_errors) - continue - - v_loc = *loc, k - value_result, value_errors = self._validate_singleton(v_, values, v_loc, cls) - if value_errors: - errors.append(value_errors) - continue - - result[key_result] = value_result - if errors: - return v, errors - elif self.shape == SHAPE_DICT: - return result, None - elif self.shape == SHAPE_DEFAULTDICT: - return defaultdict(self.type_, result), None - elif self.shape == SHAPE_COUNTER: - return CollectionCounter(result), 
None - else: - return self._get_mapping_value(v, result), None - - def _get_mapping_value(self, original: T, converted: Dict[Any, Any]) -> Union[T, Dict[Any, Any]]: - """ - When type is `Mapping[KT, KV]` (or another unsupported mapping), we try to avoid - coercing to `dict` unwillingly. - """ - original_cls = original.__class__ - - if original_cls == dict or original_cls == Dict: - return converted - elif original_cls in {defaultdict, DefaultDict}: - return defaultdict(self.type_, converted) - else: - try: - # Counter, OrderedDict, UserDict, ... - return original_cls(converted) # type: ignore - except TypeError: - raise RuntimeError(f'Could not convert dictionary to {original_cls.__name__!r}') from None - - def _validate_singleton( - self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] - ) -> 'ValidateReturn': - if self.sub_fields: - if self.discriminator_key is not None: - return self._validate_discriminated_union(v, values, loc, cls) - - errors = [] - - if self.model_config.smart_union and is_union(get_origin(self.type_)): - # 1st pass: check if the value is an exact instance of one of the Union types - # (e.g. to avoid coercing a bool into an int) - for field in self.sub_fields: - if v.__class__ is field.outer_type_: - return v, None - - # 2nd pass: check if the value is an instance of any subclass of the Union types - for field in self.sub_fields: - # This whole logic will be improved later on to support more complex `isinstance` checks - # It will probably be done once a strict mode is added and be something like: - # ``` - # value, error = field.validate(v, values, strict=True) - # if error is None: - # return value, None - # ``` - try: - if isinstance(v, field.outer_type_): - return v, None - except TypeError: - # compound type - if lenient_isinstance(v, get_origin(field.outer_type_)): - value, error = field.validate(v, values, loc=loc, cls=cls) - if not error: - return value, None - - # 1st pass by default or 3rd pass with `smart_union` enabled: - # check if the value can be coerced into one of the Union types - for field in self.sub_fields: - value, error = field.validate(v, values, loc=loc, cls=cls) - if error: - errors.append(error) - else: - return value, None - return v, errors - else: - return self._apply_validators(v, values, loc, cls, self.validators) - - def _validate_discriminated_union( - self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] - ) -> 'ValidateReturn': - assert self.discriminator_key is not None - assert self.discriminator_alias is not None - - try: - discriminator_value = v[self.discriminator_alias] - except KeyError: - return v, ErrorWrapper(MissingDiscriminator(discriminator_key=self.discriminator_key), loc) - except TypeError: - try: - # BaseModel or dataclass - discriminator_value = getattr(v, self.discriminator_key) - except (AttributeError, TypeError): - return v, ErrorWrapper(MissingDiscriminator(discriminator_key=self.discriminator_key), loc) - - try: - sub_field = self.sub_fields_mapping[discriminator_value] # type: ignore[index] - except TypeError: - assert cls is not None - raise ConfigError( - f'field "{self.name}" not yet prepared so type is still a ForwardRef, ' - f'you might need to call {cls.__name__}.update_forward_refs().' 
- ) - except KeyError: - assert self.sub_fields_mapping is not None - return v, ErrorWrapper( - InvalidDiscriminator( - discriminator_key=self.discriminator_key, - discriminator_value=discriminator_value, - allowed_values=list(self.sub_fields_mapping), - ), - loc, - ) - else: - if not isinstance(loc, tuple): - loc = (loc,) - return sub_field.validate(v, values, loc=(*loc, display_as_type(sub_field.type_)), cls=cls) - - def _apply_validators( - self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'], validators: 'ValidatorsList' - ) -> 'ValidateReturn': - for validator in validators: - try: - v = validator(cls, v, values, self, self.model_config) - except (ValueError, TypeError, AssertionError) as exc: - return v, ErrorWrapper(exc, loc) - return v, None - - def is_complex(self) -> bool: - """ - Whether the field is "complex" eg. env variables should be parsed as JSON. - """ - from .main import BaseModel - - return ( - self.shape != SHAPE_SINGLETON - or hasattr(self.type_, '__pydantic_model__') - or lenient_issubclass(self.type_, (BaseModel, list, set, frozenset, dict)) - ) - - def _type_display(self) -> PyObjectStr: - t = display_as_type(self.type_) - - if self.shape in MAPPING_LIKE_SHAPES: - t = f'Mapping[{display_as_type(self.key_field.type_)}, {t}]' # type: ignore - elif self.shape == SHAPE_TUPLE: - t = 'Tuple[{}]'.format(', '.join(display_as_type(f.type_) for f in self.sub_fields)) # type: ignore - elif self.shape == SHAPE_GENERIC: - assert self.sub_fields - t = '{}[{}]'.format( - display_as_type(self.type_), ', '.join(display_as_type(f.type_) for f in self.sub_fields) - ) - elif self.shape != SHAPE_SINGLETON: - t = SHAPE_NAME_LOOKUP[self.shape].format(t) - - if self.allow_none and (self.shape != SHAPE_SINGLETON or not self.sub_fields): - t = f'Optional[{t}]' - return PyObjectStr(t) - - def __repr_args__(self) -> 'ReprArgs': - args = [('name', self.name), ('type', self._type_display()), ('required', self.required)] - - if not self.required: - if self.default_factory is not None: - args.append(('default_factory', f'<function {self.default_factory.__name__}>')) - else: - args.append(('default', self.default)) - - if self.alt_alias: - args.append(('alias', self.alias)) - return args - - -class ModelPrivateAttr(Representation): - __slots__ = ('default', 'default_factory') - - def __init__(self, default: Any = Undefined, *, default_factory: Optional[NoArgAnyCallable] = None) -> None: self.default = default self.default_factory = default_factory + if not typing.TYPE_CHECKING: + # We put `__getattr__` in a non-TYPE_CHECKING block because otherwise, mypy allows arbitrary attribute access + + def __getattr__(self, item: str) -> Any: + """This function improves compatibility with custom descriptors by ensuring delegation happens + as expected when the default value of a private attribute is a descriptor.
+ """ + if item in {'__get__', '__set__', '__delete__'}: + if hasattr(self.default, item): + return getattr(self.default, item) + raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}') + + def __set_name__(self, cls: type[Any], name: str) -> None: + """Preserve `__set_name__` protocol defined in https://peps.python.org/pep-0487.""" + if self.default is PydanticUndefined: + return + if not hasattr(self.default, '__set_name__'): + return + set_name = self.default.__set_name__ + if callable(set_name): + set_name(cls, name) + def get_default(self) -> Any: - return smart_deepcopy(self.default) if self.default_factory is None else self.default_factory() + """Retrieve the default value of the object. + + If `self.default_factory` is `None`, the method will return a deep copy of the `self.default` object. + + If `self.default_factory` is not `None`, it will call `self.default_factory` and return the value returned. + + Returns: + The default value of the object. + """ + return _utils.smart_deepcopy(self.default) if self.default_factory is None else self.default_factory() def __eq__(self, other: Any) -> bool: return isinstance(other, self.__class__) and (self.default, self.default_factory) == ( @@ -1213,23 +879,32 @@ class ModelPrivateAttr(Representation): def PrivateAttr( - default: Any = Undefined, + default: Any = PydanticUndefined, *, - default_factory: Optional[NoArgAnyCallable] = None, + default_factory: typing.Callable[[], Any] | None = None, ) -> Any: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/models/#private-model-attributes + + Indicates that an attribute is intended for private use and not handled during normal validation/serialization. + + Private attributes are not validated by Pydantic, so it's up to you to ensure they are used in a type-safe manner. + + Private attributes are stored in `__private_attributes__` on the model. + + Args: + default: The attribute's default value. Defaults to Undefined. + default_factory: Callable that will be + called when a default value is needed for this attribute. + If both `default` and `default_factory` are set, an error will be raised. + + Returns: + An instance of [`ModelPrivateAttr`][pydantic.fields.ModelPrivateAttr] class. + + Raises: + ValueError: If both `default` and `default_factory` are set. """ - Indicates that attribute is only used internally and never mixed with regular fields. - - Types or values of private attrs are not checked by pydantic and it's up to you to keep them relevant. - - Private attrs are stored in model __slots__. - - :param default: the attribute’s default value - :param default_factory: callable that will be called when a default value is needed for this attribute - If both `default` and `default_factory` are set, an error is raised. - """ - if default is not Undefined and default_factory is not None: - raise ValueError('cannot specify both default and default_factory') + if default is not PydanticUndefined and default_factory is not None: + raise TypeError('cannot specify both default and default_factory') return ModelPrivateAttr( default, @@ -1237,11 +912,243 @@ def PrivateAttr( ) -class DeferredType: - """ - Used to postpone field preparation, while creating recursive generic models. +@dataclasses.dataclass(**_internal_dataclass.slots_true) +class ComputedFieldInfo: + """A container for data from `@computed_field` so that we can access it while building the pydantic-core schema. 
+ + Attributes: + decorator_repr: A class variable representing the decorator string, '@computed_field'. + wrapped_property: The wrapped computed field property. + return_type: The type of the computed field property's return value. + alias: The alias of the property to be used during serialization. + alias_priority: The priority of the alias. This affects whether an alias generator is used. + title: Title of the computed field to include in the serialization JSON schema. + description: Description of the computed field to include in the serialization JSON schema. + examples: Example values of the computed field to include in the serialization JSON schema. + json_schema_extra: A dict or callable to provide extra JSON schema properties. + repr: A boolean indicating whether to include the field in the __repr__ output. """ + decorator_repr: ClassVar[str] = '@computed_field' + wrapped_property: property + return_type: Any + alias: str | None + alias_priority: int | None + title: str | None + description: str | None + examples: list[Any] | None + json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None + repr: bool -def is_finalvar_with_default_val(type_: Type[Any], val: Any) -> bool: - return is_finalvar(type_) and val is not Undefined and not isinstance(val, FieldInfo) + +def _wrapped_property_is_private(property_: cached_property | property) -> bool: # type: ignore + """Returns true if provided property is private, False otherwise.""" + wrapped_name: str = '' + + if isinstance(property_, property): + wrapped_name = getattr(property_.fget, '__name__', '') + elif isinstance(property_, cached_property): # type: ignore + wrapped_name = getattr(property_.func, '__name__', '') # type: ignore + + return wrapped_name.startswith('_') and not wrapped_name.startswith('__') + + +# this should really be `property[T], cached_property[T]` but property is not generic unlike cached_property +# See https://github.com/python/typing/issues/985 and linked issues +PropertyT = typing.TypeVar('PropertyT') + + +@typing.overload +def computed_field( + *, + alias: str | None = None, + alias_priority: int | None = None, + title: str | None = None, + description: str | None = None, + examples: list[Any] | None = None, + json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None = None, + repr: bool = True, + return_type: Any = PydanticUndefined, +) -> typing.Callable[[PropertyT], PropertyT]: + ... + + +@typing.overload +def computed_field(__func: PropertyT) -> PropertyT: + ... + + +def computed_field( + __f: PropertyT | None = None, + *, + alias: str | None = None, + alias_priority: int | None = None, + title: str | None = None, + description: str | None = None, + examples: list[Any] | None = None, + json_schema_extra: JsonDict | typing.Callable[[JsonDict], None] | None = None, + repr: bool | None = None, + return_type: Any = PydanticUndefined, +) -> PropertyT | typing.Callable[[PropertyT], PropertyT]: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/fields#the-computed_field-decorator + + Decorator to include `property` and `cached_property` when serializing models or dataclasses. + + This is useful for fields that are computed from other fields, or for fields that are expensive to compute and should be cached. 
+ + ```py + from pydantic import BaseModel, computed_field + + class Rectangle(BaseModel): + width: int + length: int + + @computed_field + @property + def area(self) -> int: + return self.width * self.length + + print(Rectangle(width=3, length=2).model_dump()) + #> {'width': 3, 'length': 2, 'area': 6} + ``` + + If applied to functions not yet decorated with `@property` or `@cached_property`, the function is + automatically wrapped with `property`. Although this is more concise, you will lose IntelliSense in your IDE, + and confuse static type checkers, thus explicit use of `@property` is recommended. + + !!! warning "Mypy Warning" + Even with the `@property` or `@cached_property` applied to your function before `@computed_field`, + mypy may throw a `Decorated property not supported` error. + See [mypy issue #1362](https://github.com/python/mypy/issues/1362), for more information. + To avoid this error message, add `# type: ignore[misc]` to the `@computed_field` line. + + [pyright](https://github.com/microsoft/pyright) supports `@computed_field` without error. + + ```py + import random + + from pydantic import BaseModel, computed_field + + class Square(BaseModel): + width: float + + @computed_field + def area(self) -> float: # converted to a `property` by `computed_field` + return round(self.width**2, 2) + + @area.setter + def area(self, new_area: float) -> None: + self.width = new_area**0.5 + + @computed_field(alias='the magic number', repr=False) + def random_number(self) -> int: + return random.randint(0, 1_000) + + square = Square(width=1.3) + + # `random_number` does not appear in representation + print(repr(square)) + #> Square(width=1.3, area=1.69) + + print(square.random_number) + #> 3 + + square.area = 4 + + print(square.model_dump_json(by_alias=True)) + #> {"width":2.0,"area":4.0,"the magic number":3} + ``` + + !!! warning "Overriding with `computed_field`" + You can't override a field from a parent class with a `computed_field` in the child class. + `mypy` complains about this behavior if allowed, and `dataclasses` doesn't allow this pattern either. + See the example below: + + ```py + from pydantic import BaseModel, computed_field + + class Parent(BaseModel): + a: str + + try: + + class Child(Parent): + @computed_field + @property + def a(self) -> str: + return 'new a' + + except ValueError as e: + print(repr(e)) + #> ValueError("you can't override a field with a computed field") + ``` + + Private properties decorated with `@computed_field` have `repr=False` by default. + + ```py + from functools import cached_property + + from pydantic import BaseModel, computed_field + + class Model(BaseModel): + foo: int + + @computed_field + @cached_property + def _private_cached_property(self) -> int: + return -self.foo + + @computed_field + @property + def _private_property(self) -> int: + return -self.foo + + m = Model(foo=1) + print(repr(m)) + #> Model(foo=1) + ``` + + Args: + __f: the function to wrap. + alias: alias to use when serializing this computed field, only used when `by_alias=True` + alias_priority: priority of the alias. This affects whether an alias generator is used + title: Title to use when including this computed field in JSON Schema + description: Description to use when including this computed field in JSON Schema, defaults to the function's + docstring + examples: Example values to use when including this computed field in JSON Schema + json_schema_extra: A dict or callable to provide extra JSON schema properties.
+ repr: whether to include this computed field in model repr. + Default is `False` for private properties and `True` for public properties. + return_type: optional return for serialization logic to expect when serializing to JSON, if included + this must be correct, otherwise a `TypeError` is raised. + If you don't include a return type Any is used, which does runtime introspection to handle arbitrary + objects. + + Returns: + A proxy wrapper for the property. + """ + + def dec(f: Any) -> Any: + nonlocal description, return_type, alias_priority + unwrapped = _decorators.unwrap_wrapped_function(f) + if description is None and unwrapped.__doc__: + description = inspect.cleandoc(unwrapped.__doc__) + + # if the function isn't already decorated with `@property` (or another descriptor), then we wrap it now + f = _decorators.ensure_property(f) + alias_priority = (alias_priority or 2) if alias is not None else None + + if repr is None: + repr_: bool = False if _wrapped_property_is_private(property_=f) else True + else: + repr_ = repr + + dec_info = ComputedFieldInfo( + f, return_type, alias, alias_priority, title, description, examples, json_schema_extra, repr_ + ) + return _decorators.PydanticDescriptorProxy(f, dec_info) + + if __f is None: + return dec + else: + return dec(__f) diff --git a/lib/pydantic/functional_serializers.py b/lib/pydantic/functional_serializers.py new file mode 100644 index 00000000..6e31bf67 --- /dev/null +++ b/lib/pydantic/functional_serializers.py @@ -0,0 +1,395 @@ +"""This module contains related classes and functions for serialization.""" +from __future__ import annotations + +import dataclasses +from functools import partialmethod +from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, overload + +from pydantic_core import PydanticUndefined, core_schema +from pydantic_core import core_schema as _core_schema +from typing_extensions import Annotated, Literal, TypeAlias + +from . import PydanticUndefinedAnnotation +from ._internal import _decorators, _internal_dataclass +from .annotated_handlers import GetCoreSchemaHandler + + +@dataclasses.dataclass(**_internal_dataclass.slots_true, frozen=True) +class PlainSerializer: + """Plain serializers use a function to modify the output of serialization. + + This is particularly helpful when you want to customize the serialization for annotated types. + Consider an input of `list`, which will be serialized into a space-delimited string. + + ```python + from typing import List + + from typing_extensions import Annotated + + from pydantic import BaseModel, PlainSerializer + + CustomStr = Annotated[ + List, PlainSerializer(lambda x: ' '.join(x), return_type=str) + ] + + class StudentModel(BaseModel): + courses: CustomStr + + student = StudentModel(courses=['Math', 'Chemistry', 'English']) + print(student.model_dump()) + #> {'courses': 'Math Chemistry English'} + ``` + + Attributes: + func: The serializer function. + return_type: The return type for the function. If omitted it will be inferred from the type annotation. + when_used: Determines when this serializer should be used. Accepts a string with values `'always'`, + `'unless-none'`, `'json'`, and `'json-unless-none'`. Defaults to 'always'. + """ + + func: core_schema.SerializerFunction + return_type: Any = PydanticUndefined + when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always' + + def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + """Gets the Pydantic core schema. 
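The `when_used` option described above can restrict a plain serializer to JSON output only; a small sketch under that assumption (the `JsonJoined` alias name is invented for illustration):

```py
from typing import List

from typing_extensions import Annotated

from pydantic import BaseModel, PlainSerializer

# Applied only when dumping to JSON; python-mode dumps keep the original list.
JsonJoined = Annotated[
    List[str],
    PlainSerializer(lambda v: ' '.join(v), return_type=str, when_used='json'),
]

class StudentModel(BaseModel):
    courses: JsonJoined

student = StudentModel(courses=['Math', 'Chemistry'])
print(student.model_dump())
#> {'courses': ['Math', 'Chemistry']}
print(student.model_dump_json())
#> {"courses":"Math Chemistry"}
```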
+ + Args: + source_type: The source type. + handler: The `GetCoreSchemaHandler` instance. + + Returns: + The Pydantic core schema. + """ + schema = handler(source_type) + try: + return_type = _decorators.get_function_return_type( + self.func, self.return_type, handler._get_types_namespace() + ) + except NameError as e: + raise PydanticUndefinedAnnotation.from_name_error(e) from e + return_schema = None if return_type is PydanticUndefined else handler.generate_schema(return_type) + schema['serialization'] = core_schema.plain_serializer_function_ser_schema( + function=self.func, + info_arg=_decorators.inspect_annotated_serializer(self.func, 'plain'), + return_schema=return_schema, + when_used=self.when_used, + ) + return schema + + +@dataclasses.dataclass(**_internal_dataclass.slots_true, frozen=True) +class WrapSerializer: + """Wrap serializers receive the raw inputs along with a handler function that applies the standard serialization + logic, and can modify the resulting value before returning it as the final output of serialization. + + For example, here's a scenario in which a wrap serializer transforms timezones to UTC **and** utilizes the existing `datetime` serialization logic. + + ```python + from datetime import datetime, timezone + from typing import Any, Dict + + from typing_extensions import Annotated + + from pydantic import BaseModel, WrapSerializer + + class EventDatetime(BaseModel): + start: datetime + end: datetime + + def convert_to_utc(value: Any, handler, info) -> Dict[str, datetime]: + # Note that `helper` can actually help serialize the `value` for further custom serialization in case it's a subclass. + partial_result = handler(value, info) + if info.mode == 'json': + return { + k: datetime.fromisoformat(v).astimezone(timezone.utc) + for k, v in partial_result.items() + } + return {k: v.astimezone(timezone.utc) for k, v in partial_result.items()} + + UTCEventDatetime = Annotated[EventDatetime, WrapSerializer(convert_to_utc)] + + class EventModel(BaseModel): + event_datetime: UTCEventDatetime + + dt = EventDatetime( + start='2024-01-01T07:00:00-08:00', end='2024-01-03T20:00:00+06:00' + ) + event = EventModel(event_datetime=dt) + print(event.model_dump()) + ''' + { + 'event_datetime': { + 'start': datetime.datetime( + 2024, 1, 1, 15, 0, tzinfo=datetime.timezone.utc + ), + 'end': datetime.datetime( + 2024, 1, 3, 14, 0, tzinfo=datetime.timezone.utc + ), + } + } + ''' + + print(event.model_dump_json()) + ''' + {"event_datetime":{"start":"2024-01-01T15:00:00Z","end":"2024-01-03T14:00:00Z"}} + ''' + ``` + + Attributes: + func: The serializer function to be wrapped. + return_type: The return type for the function. If omitted it will be inferred from the type annotation. + when_used: Determines when this serializer should be used. Accepts a string with values `'always'`, + `'unless-none'`, `'json'`, and `'json-unless-none'`. Defaults to 'always'. + """ + + func: core_schema.WrapSerializerFunction + return_type: Any = PydanticUndefined + when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always' + + def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + """This method is used to get the Pydantic core schema of the class. + + Args: + source_type: Source type. + handler: Core schema handler. + + Returns: + The generated core schema of the class. 
+ """ + schema = handler(source_type) + try: + return_type = _decorators.get_function_return_type( + self.func, self.return_type, handler._get_types_namespace() + ) + except NameError as e: + raise PydanticUndefinedAnnotation.from_name_error(e) from e + return_schema = None if return_type is PydanticUndefined else handler.generate_schema(return_type) + schema['serialization'] = core_schema.wrap_serializer_function_ser_schema( + function=self.func, + info_arg=_decorators.inspect_annotated_serializer(self.func, 'wrap'), + return_schema=return_schema, + when_used=self.when_used, + ) + return schema + + +if TYPE_CHECKING: + _PartialClsOrStaticMethod: TypeAlias = Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any]] + _PlainSerializationFunction = Union[_core_schema.SerializerFunction, _PartialClsOrStaticMethod] + _WrapSerializationFunction = Union[_core_schema.WrapSerializerFunction, _PartialClsOrStaticMethod] + _PlainSerializeMethodType = TypeVar('_PlainSerializeMethodType', bound=_PlainSerializationFunction) + _WrapSerializeMethodType = TypeVar('_WrapSerializeMethodType', bound=_WrapSerializationFunction) + + +@overload +def field_serializer( + __field: str, + *fields: str, + return_type: Any = ..., + when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = ..., + check_fields: bool | None = ..., +) -> Callable[[_PlainSerializeMethodType], _PlainSerializeMethodType]: + ... + + +@overload +def field_serializer( + __field: str, + *fields: str, + mode: Literal['plain'], + return_type: Any = ..., + when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = ..., + check_fields: bool | None = ..., +) -> Callable[[_PlainSerializeMethodType], _PlainSerializeMethodType]: + ... + + +@overload +def field_serializer( + __field: str, + *fields: str, + mode: Literal['wrap'], + return_type: Any = ..., + when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = ..., + check_fields: bool | None = ..., +) -> Callable[[_WrapSerializeMethodType], _WrapSerializeMethodType]: + ... + + +def field_serializer( + *fields: str, + mode: Literal['plain', 'wrap'] = 'plain', + return_type: Any = PydanticUndefined, + when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always', + check_fields: bool | None = None, +) -> Callable[[Any], Any]: + """Decorator that enables custom field serialization. + + In the below example, a field of type `set` is used to mitigate duplication. A `field_serializer` is used to serialize the data as a sorted list. + + ```python + from typing import Set + + from pydantic import BaseModel, field_serializer + + class StudentModel(BaseModel): + name: str = 'Jane' + courses: Set[str] + + @field_serializer('courses', when_used='json') + def serialize_courses_in_order(courses: Set[str]): + return sorted(courses) + + student = StudentModel(courses={'Math', 'Chemistry', 'English'}) + print(student.model_dump_json()) + #> {"name":"Jane","courses":["Chemistry","English","Math"]} + ``` + + See [Custom serializers](../concepts/serialization.md#custom-serializers) for more information. + + Four signatures are supported: + + - `(self, value: Any, info: FieldSerializationInfo)` + - `(self, value: Any, nxt: SerializerFunctionWrapHandler, info: FieldSerializationInfo)` + - `(value: Any, info: SerializationInfo)` + - `(value: Any, nxt: SerializerFunctionWrapHandler, info: SerializationInfo)` + + Args: + fields: Which field(s) the method should be called on. + mode: The serialization mode. 
+ + - `plain` means the function will be called instead of the default serialization logic, + - `wrap` means the function will be called with an argument to optionally call the + default serialization logic. + return_type: Optional return type for the function, if omitted it will be inferred from the type annotation. + when_used: Determines the serializer will be used for serialization. + check_fields: Whether to check that the fields actually exist on the model. + + Returns: + The decorator function. + """ + + def dec( + f: Callable[..., Any] | staticmethod[Any, Any] | classmethod[Any, Any, Any], + ) -> _decorators.PydanticDescriptorProxy[Any]: + dec_info = _decorators.FieldSerializerDecoratorInfo( + fields=fields, + mode=mode, + return_type=return_type, + when_used=when_used, + check_fields=check_fields, + ) + return _decorators.PydanticDescriptorProxy(f, dec_info) + + return dec + + +FuncType = TypeVar('FuncType', bound=Callable[..., Any]) + + +@overload +def model_serializer(__f: FuncType) -> FuncType: + ... + + +@overload +def model_serializer( + *, + mode: Literal['plain', 'wrap'] = ..., + when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always', + return_type: Any = ..., +) -> Callable[[FuncType], FuncType]: + ... + + +def model_serializer( + __f: Callable[..., Any] | None = None, + *, + mode: Literal['plain', 'wrap'] = 'plain', + when_used: Literal['always', 'unless-none', 'json', 'json-unless-none'] = 'always', + return_type: Any = PydanticUndefined, +) -> Callable[[Any], Any]: + """Decorator that enables custom model serialization. + + This is useful when a model need to be serialized in a customized manner, allowing for flexibility beyond just specific fields. + + An example would be to serialize temperature to the same temperature scale, such as degrees Celsius. + + ```python + from typing import Literal + + from pydantic import BaseModel, model_serializer + + class TemperatureModel(BaseModel): + unit: Literal['C', 'F'] + value: int + + @model_serializer() + def serialize_model(self): + if self.unit == 'F': + return {'unit': 'C', 'value': int((self.value - 32) / 1.8)} + return {'unit': self.unit, 'value': self.value} + + temperature = TemperatureModel(unit='F', value=212) + print(temperature.model_dump()) + #> {'unit': 'C', 'value': 100} + ``` + + See [Custom serializers](../concepts/serialization.md#custom-serializers) for more information. + + Args: + __f: The function to be decorated. + mode: The serialization mode. + + - `'plain'` means the function will be called instead of the default serialization logic + - `'wrap'` means the function will be called with an argument to optionally call the default + serialization logic. + when_used: Determines when this serializer should be used. + return_type: The return type for the function. If omitted it will be inferred from the type annotation. + + Returns: + The decorator function. + """ + + def dec(f: Callable[..., Any]) -> _decorators.PydanticDescriptorProxy[Any]: + dec_info = _decorators.ModelSerializerDecoratorInfo(mode=mode, return_type=return_type, when_used=when_used) + return _decorators.PydanticDescriptorProxy(f, dec_info) + + if __f is None: + return dec + else: + return dec(__f) # type: ignore + + +AnyType = TypeVar('AnyType') + + +if TYPE_CHECKING: + SerializeAsAny = Annotated[AnyType, ...] # SerializeAsAny[list[str]] will be treated by type checkers as list[str] + """Force serialization to ignore whatever is defined in the schema and instead ask the object + itself how it should be serialized. 
+ In particular, this means that when model subclasses are serialized, fields present in the subclass + but not in the original schema will be included. + """ +else: + + @dataclasses.dataclass(**_internal_dataclass.slots_true) + class SerializeAsAny: # noqa: D101 + def __class_getitem__(cls, item: Any) -> Any: + return Annotated[item, SerializeAsAny()] + + def __get_pydantic_core_schema__( + self, source_type: Any, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + schema = handler(source_type) + schema_to_update = schema + while schema_to_update['type'] == 'definitions': + schema_to_update = schema_to_update.copy() + schema_to_update = schema_to_update['schema'] + schema_to_update['serialization'] = core_schema.wrap_serializer_function_ser_schema( + lambda x, h: h(x), schema=core_schema.any_schema() + ) + return schema + + __hash__ = object.__hash__ diff --git a/lib/pydantic/functional_validators.py b/lib/pydantic/functional_validators.py new file mode 100644 index 00000000..b547755b --- /dev/null +++ b/lib/pydantic/functional_validators.py @@ -0,0 +1,706 @@ +"""This module contains related classes and functions for validation.""" + +from __future__ import annotations as _annotations + +import dataclasses +import sys +from functools import partialmethod +from types import FunctionType +from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, cast, overload + +from pydantic_core import core_schema +from pydantic_core import core_schema as _core_schema +from typing_extensions import Annotated, Literal, TypeAlias + +from . import GetCoreSchemaHandler as _GetCoreSchemaHandler +from ._internal import _core_metadata, _decorators, _generics, _internal_dataclass +from .annotated_handlers import GetCoreSchemaHandler +from .errors import PydanticUserError + +if sys.version_info < (3, 11): + from typing_extensions import Protocol +else: + from typing import Protocol + +_inspect_validator = _decorators.inspect_validator + + +@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true) +class AfterValidator: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/validators/#annotated-validators + + A metadata class that indicates that a validation should be applied **after** the inner validation logic. + + Attributes: + func: The validator function. 
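A short sketch of the duck-typed serialization that the `SerializeAsAny` marker above enables; the model names are illustrative:

```py
from pydantic import BaseModel, SerializeAsAny

class User(BaseModel):
    name: str

class UserLogin(User):
    password: str

class OuterModel(BaseModel):
    as_any: SerializeAsAny[User]
    as_user: User

user = UserLogin(name='pydantic', password='password')
# the SerializeAsAny field dumps the subclass's extra fields, the plain field does not
print(OuterModel(as_any=user, as_user=user).model_dump())
#> {'as_any': {'name': 'pydantic', 'password': 'password'}, 'as_user': {'name': 'pydantic'}}
```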
+ + Example: + ```py + from typing_extensions import Annotated + + from pydantic import AfterValidator, BaseModel, ValidationError + + MyInt = Annotated[int, AfterValidator(lambda v: v + 1)] + + class Model(BaseModel): + a: MyInt + + print(Model(a=1).a) + #> 2 + + try: + Model(a='a') + except ValidationError as e: + print(e.json(indent=2)) + ''' + [ + { + "type": "int_parsing", + "loc": [ + "a" + ], + "msg": "Input should be a valid integer, unable to parse string as an integer", + "input": "a", + "url": "https://errors.pydantic.dev/2/v/int_parsing" + } + ] + ''' + ``` + """ + + func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction + + def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema: + schema = handler(source_type) + info_arg = _inspect_validator(self.func, 'after') + if info_arg: + func = cast(core_schema.WithInfoValidatorFunction, self.func) + return core_schema.with_info_after_validator_function(func, schema=schema, field_name=handler.field_name) + else: + func = cast(core_schema.NoInfoValidatorFunction, self.func) + return core_schema.no_info_after_validator_function(func, schema=schema) + + +@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true) +class BeforeValidator: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/validators/#annotated-validators + + A metadata class that indicates that a validation should be applied **before** the inner validation logic. + + Attributes: + func: The validator function. + + Example: + ```py + from typing_extensions import Annotated + + from pydantic import BaseModel, BeforeValidator + + MyInt = Annotated[int, BeforeValidator(lambda v: v + 1)] + + class Model(BaseModel): + a: MyInt + + print(Model(a=1).a) + #> 2 + + try: + Model(a='a') + except TypeError as e: + print(e) + #> can only concatenate str (not "int") to str + ``` + """ + + func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction + + def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema: + schema = handler(source_type) + info_arg = _inspect_validator(self.func, 'before') + if info_arg: + func = cast(core_schema.WithInfoValidatorFunction, self.func) + return core_schema.with_info_before_validator_function(func, schema=schema, field_name=handler.field_name) + else: + func = cast(core_schema.NoInfoValidatorFunction, self.func) + return core_schema.no_info_before_validator_function(func, schema=schema) + + +@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true) +class PlainValidator: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/validators/#annotated-validators + + A metadata class that indicates that a validation should be applied **instead** of the inner validation logic. + + Attributes: + func: The validator function. 
+ + Example: + ```py + from typing_extensions import Annotated + + from pydantic import BaseModel, PlainValidator + + MyInt = Annotated[int, PlainValidator(lambda v: int(v) + 1)] + + class Model(BaseModel): + a: MyInt + + print(Model(a='1').a) + #> 2 + ``` + """ + + func: core_schema.NoInfoValidatorFunction | core_schema.WithInfoValidatorFunction + + def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema: + # Note that for some valid uses of PlainValidator, it is not possible to generate a core schema for the + # source_type, so calling `handler(source_type)` will error, which prevents us from generating a proper + # serialization schema. To work around this for use cases that will not involve serialization, we simply + # catch any PydanticSchemaGenerationError that may be raised while attempting to build the serialization schema + # and abort any attempts to handle special serialization. + from pydantic import PydanticSchemaGenerationError + + try: + schema = handler(source_type) + serialization = core_schema.wrap_serializer_function_ser_schema(function=lambda v, h: h(v), schema=schema) + except PydanticSchemaGenerationError: + serialization = None + + info_arg = _inspect_validator(self.func, 'plain') + if info_arg: + func = cast(core_schema.WithInfoValidatorFunction, self.func) + return core_schema.with_info_plain_validator_function( + func, field_name=handler.field_name, serialization=serialization + ) + else: + func = cast(core_schema.NoInfoValidatorFunction, self.func) + return core_schema.no_info_plain_validator_function(func, serialization=serialization) + + +@dataclasses.dataclass(frozen=True, **_internal_dataclass.slots_true) +class WrapValidator: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/validators/#annotated-validators + + A metadata class that indicates that a validation should be applied **around** the inner validation logic. + + Attributes: + func: The validator function. + + ```py + from datetime import datetime + + from typing_extensions import Annotated + + from pydantic import BaseModel, ValidationError, WrapValidator + + def validate_timestamp(v, handler): + if v == 'now': + # we don't want to bother with further validation, just return the new value + return datetime.now() + try: + return handler(v) + except ValidationError: + # validation failed, in this case we want to return a default value + return datetime(2000, 1, 1) + + MyTimestamp = Annotated[datetime, WrapValidator(validate_timestamp)] + + class Model(BaseModel): + a: MyTimestamp + + print(Model(a='now').a) + #> 2032-01-02 03:04:05.000006 + print(Model(a='invalid').a) + #> 2000-01-01 00:00:00 + ``` + """ + + func: core_schema.NoInfoWrapValidatorFunction | core_schema.WithInfoWrapValidatorFunction + + def __get_pydantic_core_schema__(self, source_type: Any, handler: _GetCoreSchemaHandler) -> core_schema.CoreSchema: + schema = handler(source_type) + info_arg = _inspect_validator(self.func, 'wrap') + if info_arg: + func = cast(core_schema.WithInfoWrapValidatorFunction, self.func) + return core_schema.with_info_wrap_validator_function(func, schema=schema, field_name=handler.field_name) + else: + func = cast(core_schema.NoInfoWrapValidatorFunction, self.func) + return core_schema.no_info_wrap_validator_function(func, schema=schema) + + +if TYPE_CHECKING: + + class _OnlyValueValidatorClsMethod(Protocol): + def __call__(self, cls: Any, value: Any, /) -> Any: + ... 
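The protocols above distinguish validators that receive only the value from those that also receive a `ValidationInfo`; a hedged sketch of the two-argument form used with `AfterValidator` (the function and field names are assumptions):

```py
from typing_extensions import Annotated

from pydantic import AfterValidator, BaseModel, ValidationInfo

def add_one(v: int, info: ValidationInfo) -> int:
    # a two-argument validator also receives context such as the field name
    print(f'validating {info.field_name!r}')
    return v + 1

class Model(BaseModel):
    a: Annotated[int, AfterValidator(add_one)]

print(Model(a=1).a)
#> validating 'a'
#> 2
```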
+ + class _V2ValidatorClsMethod(Protocol): + def __call__(self, cls: Any, value: Any, info: _core_schema.ValidationInfo, /) -> Any: + ... + + class _V2WrapValidatorClsMethod(Protocol): + def __call__( + self, + cls: Any, + value: Any, + handler: _core_schema.ValidatorFunctionWrapHandler, + info: _core_schema.ValidationInfo, + /, + ) -> Any: + ... + + _V2Validator = Union[ + _V2ValidatorClsMethod, + _core_schema.WithInfoValidatorFunction, + _OnlyValueValidatorClsMethod, + _core_schema.NoInfoValidatorFunction, + ] + + _V2WrapValidator = Union[ + _V2WrapValidatorClsMethod, + _core_schema.WithInfoWrapValidatorFunction, + ] + + _PartialClsOrStaticMethod: TypeAlias = Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any]] + + _V2BeforeAfterOrPlainValidatorType = TypeVar( + '_V2BeforeAfterOrPlainValidatorType', + _V2Validator, + _PartialClsOrStaticMethod, + ) + _V2WrapValidatorType = TypeVar('_V2WrapValidatorType', _V2WrapValidator, _PartialClsOrStaticMethod) + + +@overload +def field_validator( + __field: str, + *fields: str, + mode: Literal['before', 'after', 'plain'] = ..., + check_fields: bool | None = ..., +) -> Callable[[_V2BeforeAfterOrPlainValidatorType], _V2BeforeAfterOrPlainValidatorType]: + ... + + +@overload +def field_validator( + __field: str, + *fields: str, + mode: Literal['wrap'], + check_fields: bool | None = ..., +) -> Callable[[_V2WrapValidatorType], _V2WrapValidatorType]: + ... + + +FieldValidatorModes: TypeAlias = Literal['before', 'after', 'wrap', 'plain'] + + +def field_validator( + __field: str, + *fields: str, + mode: FieldValidatorModes = 'after', + check_fields: bool | None = None, +) -> Callable[[Any], Any]: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/validators/#field-validators + + Decorate methods on the class indicating that they should be used to validate fields. + + Example usage: + ```py + from typing import Any + + from pydantic import ( + BaseModel, + ValidationError, + field_validator, + ) + + class Model(BaseModel): + a: str + + @field_validator('a') + @classmethod + def ensure_foobar(cls, v: Any): + if 'foobar' not in v: + raise ValueError('"foobar" not found in a') + return v + + print(repr(Model(a='this is foobar good'))) + #> Model(a='this is foobar good') + + try: + Model(a='snap') + except ValidationError as exc_info: + print(exc_info) + ''' + 1 validation error for Model + a + Value error, "foobar" not found in a [type=value_error, input_value='snap', input_type=str] + ''' + ``` + + For more in depth examples, see [Field Validators](../concepts/validators.md#field-validators). + + Args: + __field: The first field the `field_validator` should be called on; this is separate + from `fields` to ensure an error is raised if you don't pass at least one. + *fields: Additional field(s) the `field_validator` should be called on. + mode: Specifies whether to validate the fields before or after validation. + check_fields: Whether to check that the fields actually exist on the model. + + Returns: + A decorator that can be used to decorate a function to be used as a field_validator. + + Raises: + PydanticUserError: + - If `@field_validator` is used bare (with no fields). + - If the args passed to `@field_validator` as fields are not strings. + - If `@field_validator` applied to instance methods. + """ + if isinstance(__field, FunctionType): + raise PydanticUserError( + '`@field_validator` should be used with fields and keyword arguments, not bare. ' + "E.g. 
usage should be `@validator('', ...)`", + code='validator-no-fields', + ) + fields = __field, *fields + if not all(isinstance(field, str) for field in fields): + raise PydanticUserError( + '`@field_validator` fields should be passed as separate string args. ' + "E.g. usage should be `@validator('', '', ...)`", + code='validator-invalid-fields', + ) + + def dec( + f: Callable[..., Any] | staticmethod[Any, Any] | classmethod[Any, Any, Any], + ) -> _decorators.PydanticDescriptorProxy[Any]: + if _decorators.is_instance_method_from_sig(f): + raise PydanticUserError( + '`@field_validator` cannot be applied to instance methods', code='validator-instance-method' + ) + + # auto apply the @classmethod decorator + f = _decorators.ensure_classmethod_based_on_signature(f) + + dec_info = _decorators.FieldValidatorDecoratorInfo(fields=fields, mode=mode, check_fields=check_fields) + return _decorators.PydanticDescriptorProxy(f, dec_info) + + return dec + + +_ModelType = TypeVar('_ModelType') +_ModelTypeCo = TypeVar('_ModelTypeCo', covariant=True) + + +class ModelWrapValidatorHandler(_core_schema.ValidatorFunctionWrapHandler, Protocol[_ModelTypeCo]): + """@model_validator decorated function handler argument type. This is used when `mode='wrap'`.""" + + def __call__( # noqa: D102 + self, + value: Any, + outer_location: str | int | None = None, + /, + ) -> _ModelTypeCo: # pragma: no cover + ... + + +class ModelWrapValidatorWithoutInfo(Protocol[_ModelType]): + """A @model_validator decorated function signature. + This is used when `mode='wrap'` and the function does not have info argument. + """ + + def __call__( # noqa: D102 + self, + cls: type[_ModelType], + # this can be a dict, a model instance + # or anything else that gets passed to validate_python + # thus validators _must_ handle all cases + value: Any, + handler: ModelWrapValidatorHandler[_ModelType], + /, + ) -> _ModelType: + ... + + +class ModelWrapValidator(Protocol[_ModelType]): + """A @model_validator decorated function signature. This is used when `mode='wrap'`.""" + + def __call__( # noqa: D102 + self, + cls: type[_ModelType], + # this can be a dict, a model instance + # or anything else that gets passed to validate_python + # thus validators _must_ handle all cases + value: Any, + handler: ModelWrapValidatorHandler[_ModelType], + info: _core_schema.ValidationInfo, + /, + ) -> _ModelType: + ... + + +class FreeModelBeforeValidatorWithoutInfo(Protocol): + """A @model_validator decorated function signature. + This is used when `mode='before'` and the function does not have info argument. + """ + + def __call__( # noqa: D102 + self, + # this can be a dict, a model instance + # or anything else that gets passed to validate_python + # thus validators _must_ handle all cases + value: Any, + /, + ) -> Any: + ... + + +class ModelBeforeValidatorWithoutInfo(Protocol): + """A @model_validator decorated function signature. + This is used when `mode='before'` and the function does not have info argument. + """ + + def __call__( # noqa: D102 + self, + cls: Any, + # this can be a dict, a model instance + # or anything else that gets passed to validate_python + # thus validators _must_ handle all cases + value: Any, + /, + ) -> Any: + ... + + +class FreeModelBeforeValidator(Protocol): + """A `@model_validator` decorated function signature. 
This is used when `mode='before'`.""" + + def __call__( # noqa: D102 + self, + # this can be a dict, a model instance + # or anything else that gets passed to validate_python + # thus validators _must_ handle all cases + value: Any, + info: _core_schema.ValidationInfo, + /, + ) -> Any: + ... + + +class ModelBeforeValidator(Protocol): + """A `@model_validator` decorated function signature. This is used when `mode='before'`.""" + + def __call__( # noqa: D102 + self, + cls: Any, + # this can be a dict, a model instance + # or anything else that gets passed to validate_python + # thus validators _must_ handle all cases + value: Any, + info: _core_schema.ValidationInfo, + /, + ) -> Any: + ... + + +ModelAfterValidatorWithoutInfo = Callable[[_ModelType], _ModelType] +"""A `@model_validator` decorated function signature. This is used when `mode='after'` and the function does not +have info argument. +""" + +ModelAfterValidator = Callable[[_ModelType, _core_schema.ValidationInfo], _ModelType] +"""A `@model_validator` decorated function signature. This is used when `mode='after'`.""" + +_AnyModelWrapValidator = Union[ModelWrapValidator[_ModelType], ModelWrapValidatorWithoutInfo[_ModelType]] +_AnyModeBeforeValidator = Union[ + FreeModelBeforeValidator, ModelBeforeValidator, FreeModelBeforeValidatorWithoutInfo, ModelBeforeValidatorWithoutInfo +] +_AnyModelAfterValidator = Union[ModelAfterValidator[_ModelType], ModelAfterValidatorWithoutInfo[_ModelType]] + + +@overload +def model_validator( + *, + mode: Literal['wrap'], +) -> Callable[ + [_AnyModelWrapValidator[_ModelType]], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo] +]: + ... + + +@overload +def model_validator( + *, + mode: Literal['before'], +) -> Callable[[_AnyModeBeforeValidator], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo]]: + ... + + +@overload +def model_validator( + *, + mode: Literal['after'], +) -> Callable[ + [_AnyModelAfterValidator[_ModelType]], _decorators.PydanticDescriptorProxy[_decorators.ModelValidatorDecoratorInfo] +]: + ... + + +def model_validator( + *, + mode: Literal['wrap', 'before', 'after'], +) -> Any: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/validators/#model-validators + + Decorate model methods for validation purposes. + + Example usage: + ```py + from typing_extensions import Self + + from pydantic import BaseModel, ValidationError, model_validator + + class Square(BaseModel): + width: float + height: float + + @model_validator(mode='after') + def verify_square(self) -> Self: + if self.width != self.height: + raise ValueError('width and height do not match') + return self + + s = Square(width=1, height=1) + print(repr(s)) + #> Square(width=1.0, height=1.0) + + try: + Square(width=1, height=2) + except ValidationError as e: + print(e) + ''' + 1 validation error for Square + Value error, width and height do not match [type=value_error, input_value={'width': 1, 'height': 2}, input_type=dict] + ''' + ``` + + For more in depth examples, see [Model Validators](../concepts/validators.md#model-validators). + + Args: + mode: A required string literal that specifies the validation mode. + It can be one of the following: 'wrap', 'before', or 'after'. + + Returns: + A decorator that can be used to decorate a function to be used as a model validator. 
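+
+    A minimal sketch of `mode='before'` (the payload shape and the names `UserModel` /
+    `unwrap_payload` are illustrative assumptions, not part of the documented API):
+
+    ```py
+    from typing import Any
+
+    from pydantic import BaseModel, model_validator
+
+    class UserModel(BaseModel):
+        username: str
+
+        @model_validator(mode='before')
+        @classmethod
+        def unwrap_payload(cls, data: Any) -> Any:
+            # 'before' validators receive the raw input (often a dict) prior to field validation
+            if isinstance(data, dict) and 'user' in data:
+                return data['user']
+            return data
+
+    print(UserModel.model_validate({'user': {'username': 'scolvin'}}).username)
+    #> scolvin
+    ```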
+ """ + + def dec(f: Any) -> _decorators.PydanticDescriptorProxy[Any]: + # auto apply the @classmethod decorator + f = _decorators.ensure_classmethod_based_on_signature(f) + dec_info = _decorators.ModelValidatorDecoratorInfo(mode=mode) + return _decorators.PydanticDescriptorProxy(f, dec_info) + + return dec + + +AnyType = TypeVar('AnyType') + + +if TYPE_CHECKING: + # If we add configurable attributes to IsInstance, we'd probably need to stop hiding it from type checkers like this + InstanceOf = Annotated[AnyType, ...] # `IsInstance[Sequence]` will be recognized by type checkers as `Sequence` + +else: + + @dataclasses.dataclass(**_internal_dataclass.slots_true) + class InstanceOf: + '''Generic type for annotating a type that is an instance of a given class. + + Example: + ```py + from pydantic import BaseModel, InstanceOf + + class Foo: + ... + + class Bar(BaseModel): + foo: InstanceOf[Foo] + + Bar(foo=Foo()) + try: + Bar(foo=42) + except ValidationError as e: + print(e) + """ + [ + │ { + │ │ 'type': 'is_instance_of', + │ │ 'loc': ('foo',), + │ │ 'msg': 'Input should be an instance of Foo', + │ │ 'input': 42, + │ │ 'ctx': {'class': 'Foo'}, + │ │ 'url': 'https://errors.pydantic.dev/0.38.0/v/is_instance_of' + │ } + ] + """ + ``` + ''' + + @classmethod + def __class_getitem__(cls, item: AnyType) -> AnyType: + return Annotated[item, cls()] + + @classmethod + def __get_pydantic_core_schema__(cls, source: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + from pydantic import PydanticSchemaGenerationError + + # use the generic _origin_ as the second argument to isinstance when appropriate + instance_of_schema = core_schema.is_instance_schema(_generics.get_origin(source) or source) + + try: + # Try to generate the "standard" schema, which will be used when loading from JSON + original_schema = handler(source) + except PydanticSchemaGenerationError: + # If that fails, just produce a schema that can validate from python + return instance_of_schema + else: + # Use the "original" approach to serialization + instance_of_schema['serialization'] = core_schema.wrap_serializer_function_ser_schema( + function=lambda v, h: h(v), schema=original_schema + ) + return core_schema.json_or_python_schema(python_schema=instance_of_schema, json_schema=original_schema) + + __hash__ = object.__hash__ + + +if TYPE_CHECKING: + SkipValidation = Annotated[AnyType, ...] # SkipValidation[list[str]] will be treated by type checkers as list[str] +else: + + @dataclasses.dataclass(**_internal_dataclass.slots_true) + class SkipValidation: + """If this is applied as an annotation (e.g., via `x: Annotated[int, SkipValidation]`), validation will be + skipped. You can also use `SkipValidation[int]` as a shorthand for `Annotated[int, SkipValidation]`. + + This can be useful if you want to use a type annotation for documentation/IDE/type-checking purposes, + and know that it is safe to skip validation for one or more of the fields. + + Because this converts the validation schema to `any_schema`, subsequent annotation-applied transformations + may not have the expected effects. Therefore, when used, this annotation should generally be the final + annotation applied to a type. 
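+
+        A minimal usage sketch (the field name and values below are illustrative assumptions):
+
+        ```py
+        from pydantic import BaseModel, SkipValidation
+
+        class Model(BaseModel):
+            names: SkipValidation[list[str]]
+
+        # validation is skipped, so the non-string items pass through unchanged
+        print(Model(names=[1, 2]).names)
+        #> [1, 2]
+        ```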
+ """ + + def __class_getitem__(cls, item: Any) -> Any: + return Annotated[item, SkipValidation()] + + @classmethod + def __get_pydantic_core_schema__(cls, source: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + original_schema = handler(source) + metadata = _core_metadata.build_metadata_dict(js_annotation_functions=[lambda _c, h: h(original_schema)]) + return core_schema.any_schema( + metadata=metadata, + serialization=core_schema.wrap_serializer_function_ser_schema( + function=lambda v, h: h(v), schema=original_schema + ), + ) + + __hash__ = object.__hash__ diff --git a/lib/pydantic/generics.py b/lib/pydantic/generics.py index a3f52bfe..5f6f7f7a 100644 --- a/lib/pydantic/generics.py +++ b/lib/pydantic/generics.py @@ -1,364 +1,4 @@ -import sys -import typing -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Dict, - Generic, - Iterator, - List, - Mapping, - Optional, - Tuple, - Type, - TypeVar, - Union, - cast, -) +"""The `generics` module is a backport module from V1.""" +from ._migration import getattr_migration -from typing_extensions import Annotated - -from .class_validators import gather_all_validators -from .fields import DeferredType -from .main import BaseModel, create_model -from .types import JsonWrapper -from .typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base -from .utils import LimitedDict, all_identical, lenient_issubclass - -GenericModelT = TypeVar('GenericModelT', bound='GenericModel') -TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type - -Parametrization = Mapping[TypeVarType, Type[Any]] - -_generic_types_cache: LimitedDict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel]] = LimitedDict() -# _assigned_parameters is a Mapping from parametrized version of generic models to assigned types of parametrizations -# as captured during construction of the class (not instances). -# E.g., for generic model `Model[A, B]`, when parametrized model `Model[int, str]` is created, -# `Model[int, str]`: {A: int, B: str}` will be stored in `_assigned_parameters`. -# (This information is only otherwise available after creation from the class name string). -_assigned_parameters: LimitedDict[Type[Any], Parametrization] = LimitedDict() - - -class GenericModel(BaseModel): - __slots__ = () - __concrete__: ClassVar[bool] = False - - if TYPE_CHECKING: - # Putting this in a TYPE_CHECKING block allows us to replace `if Generic not in cls.__bases__` with - # `not hasattr(cls, "__parameters__")`. This means we don't need to force non-concrete subclasses of - # `GenericModel` to also inherit from `Generic`, which would require changes to the use of `create_model` below. - __parameters__: ClassVar[Tuple[TypeVarType, ...]] - - # Setting the return type as Type[Any] instead of Type[BaseModel] prevents PyCharm warnings - def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[Type[Any], ...]]) -> Type[Any]: - """Instantiates a new class from a generic class `cls` and type variables `params`. - - :param params: Tuple of types the class . Given a generic class - `Model` with 2 type variables and a concrete model `Model[str, int]`, - the value `(str, int)` would be passed to `params`. - :return: New model class inheriting from `cls` with instantiated - types described by `params`. If no parameters are given, `cls` is - returned as is. 
- - """ - - def _cache_key(_params: Any) -> Tuple[Type[GenericModelT], Any, Tuple[Any, ...]]: - return cls, _params, get_args(_params) - - cached = _generic_types_cache.get(_cache_key(params)) - if cached is not None: - return cached - if cls.__concrete__ and Generic not in cls.__bases__: - raise TypeError('Cannot parameterize a concrete instantiation of a generic model') - if not isinstance(params, tuple): - params = (params,) - if cls is GenericModel and any(isinstance(param, TypeVar) for param in params): - raise TypeError('Type parameters should be placed on typing.Generic, not GenericModel') - if not hasattr(cls, '__parameters__'): - raise TypeError(f'Type {cls.__name__} must inherit from typing.Generic before being parameterized') - - check_parameters_count(cls, params) - # Build map from generic typevars to passed params - typevars_map: Dict[TypeVarType, Type[Any]] = dict(zip(cls.__parameters__, params)) - if all_identical(typevars_map.keys(), typevars_map.values()) and typevars_map: - return cls # if arguments are equal to parameters it's the same object - - # Create new model with original model as parent inserting fields with DeferredType. - model_name = cls.__concrete_name__(params) - validators = gather_all_validators(cls) - - type_hints = get_all_type_hints(cls).items() - instance_type_hints = {k: v for k, v in type_hints if get_origin(v) is not ClassVar} - - fields = {k: (DeferredType(), cls.__fields__[k].field_info) for k in instance_type_hints if k in cls.__fields__} - - model_module, called_globally = get_caller_frame_info() - created_model = cast( - Type[GenericModel], # casting ensures mypy is aware of the __concrete__ and __parameters__ attributes - create_model( - model_name, - __module__=model_module or cls.__module__, - __base__=(cls,) + tuple(cls.__parameterized_bases__(typevars_map)), - __config__=None, - __validators__=validators, - __cls_kwargs__=None, - **fields, - ), - ) - - _assigned_parameters[created_model] = typevars_map - - if called_globally: # create global reference and therefore allow pickling - object_by_reference = None - reference_name = model_name - reference_module_globals = sys.modules[created_model.__module__].__dict__ - while object_by_reference is not created_model: - object_by_reference = reference_module_globals.setdefault(reference_name, created_model) - reference_name += '_' - - created_model.Config = cls.Config - - # Find any typevars that are still present in the model. - # If none are left, the model is fully "concrete", otherwise the new - # class is a generic class as well taking the found typevars as - # parameters. - new_params = tuple( - {param: None for param in iter_contained_typevars(typevars_map.values())} - ) # use dict as ordered set - created_model.__concrete__ = not new_params - if new_params: - created_model.__parameters__ = new_params - - # Save created model in cache so we don't end up creating duplicate - # models that should be identical. - _generic_types_cache[_cache_key(params)] = created_model - if len(params) == 1: - _generic_types_cache[_cache_key(params[0])] = created_model - - # Recursively walk class type hints and replace generic typevars - # with concrete types that were passed. - _prepare_model_fields(created_model, fields, instance_type_hints, typevars_map) - - return created_model - - @classmethod - def __concrete_name__(cls: Type[Any], params: Tuple[Type[Any], ...]) -> str: - """Compute class name for child classes. - - :param params: Tuple of types the class . 
Given a generic class - `Model` with 2 type variables and a concrete model `Model[str, int]`, - the value `(str, int)` would be passed to `params`. - :return: String representing a the new class where `params` are - passed to `cls` as type variables. - - This method can be overridden to achieve a custom naming scheme for GenericModels. - """ - param_names = [display_as_type(param) for param in params] - params_component = ', '.join(param_names) - return f'{cls.__name__}[{params_component}]' - - @classmethod - def __parameterized_bases__(cls, typevars_map: Parametrization) -> Iterator[Type[Any]]: - """ - Returns unbound bases of cls parameterised to given type variables - - :param typevars_map: Dictionary of type applications for binding subclasses. - Given a generic class `Model` with 2 type variables [S, T] - and a concrete model `Model[str, int]`, - the value `{S: str, T: int}` would be passed to `typevars_map`. - :return: an iterator of generic sub classes, parameterised by `typevars_map` - and other assigned parameters of `cls` - - e.g.: - ``` - class A(GenericModel, Generic[T]): - ... - - class B(A[V], Generic[V]): - ... - - assert A[int] in B.__parameterized_bases__({V: int}) - ``` - """ - - def build_base_model( - base_model: Type[GenericModel], mapped_types: Parametrization - ) -> Iterator[Type[GenericModel]]: - base_parameters = tuple(mapped_types[param] for param in base_model.__parameters__) - parameterized_base = base_model.__class_getitem__(base_parameters) - if parameterized_base is base_model or parameterized_base is cls: - # Avoid duplication in MRO - return - yield parameterized_base - - for base_model in cls.__bases__: - if not issubclass(base_model, GenericModel): - # not a class that can be meaningfully parameterized - continue - elif not getattr(base_model, '__parameters__', None): - # base_model is "GenericModel" (and has no __parameters__) - # or - # base_model is already concrete, and will be included transitively via cls. - continue - elif cls in _assigned_parameters: - if base_model in _assigned_parameters: - # cls is partially parameterised but not from base_model - # e.g. cls = B[S], base_model = A[S] - # B[S][int] should subclass A[int], (and will be transitively via B[int]) - # but it's not viable to consistently subclass types with arbitrary construction - # So don't attempt to include A[S][int] - continue - else: # base_model not in _assigned_parameters: - # cls is partially parameterized, base_model is original generic - # e.g. cls = B[str, T], base_model = B[S, T] - # Need to determine the mapping for the base_model parameters - mapped_types: Parametrization = { - key: typevars_map.get(value, value) for key, value in _assigned_parameters[cls].items() - } - yield from build_base_model(base_model, mapped_types) - else: - # cls is base generic, so base_class has a distinct base - # can construct the Parameterised base model using typevars_map directly - yield from build_base_model(base_model, typevars_map) - - -def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any: - """Return type with all occurrences of `type_map` keys recursively replaced with their values. - - :param type_: Any type, class or generic alias - :param type_map: Mapping from `TypeVar` instance to concrete types. - :return: New type representing the basic structure of `type_` with all - `typevar_map` keys recursively replaced. 
- - >>> replace_types(Tuple[str, Union[List[str], float]], {str: int}) - Tuple[int, Union[List[int], float]] - - """ - if not type_map: - return type_ - - type_args = get_args(type_) - origin_type = get_origin(type_) - - if origin_type is Annotated: - annotated_type, *annotations = type_args - return Annotated[replace_types(annotated_type, type_map), tuple(annotations)] - - # Having type args is a good indicator that this is a typing module - # class instantiation or a generic alias of some sort. - if type_args: - resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args) - if all_identical(type_args, resolved_type_args): - # If all arguments are the same, there is no need to modify the - # type or create a new object at all - return type_ - if ( - origin_type is not None - and isinstance(type_, typing_base) - and not isinstance(origin_type, typing_base) - and getattr(type_, '_name', None) is not None - ): - # In python < 3.9 generic aliases don't exist so any of these like `list`, - # `type` or `collections.abc.Callable` need to be translated. - # See: https://www.python.org/dev/peps/pep-0585 - origin_type = getattr(typing, type_._name) - assert origin_type is not None - return origin_type[resolved_type_args] - - # We handle pydantic generic models separately as they don't have the same - # semantics as "typing" classes or generic aliases - if not origin_type and lenient_issubclass(type_, GenericModel) and not type_.__concrete__: - type_args = type_.__parameters__ - resolved_type_args = tuple(replace_types(t, type_map) for t in type_args) - if all_identical(type_args, resolved_type_args): - return type_ - return type_[resolved_type_args] - - # Handle special case for typehints that can have lists as arguments. - # `typing.Callable[[int, str], int]` is an example for this. - if isinstance(type_, (List, list)): - resolved_list = list(replace_types(element, type_map) for element in type_) - if all_identical(type_, resolved_list): - return type_ - return resolved_list - - # For JsonWrapperValue, need to handle its inner type to allow correct parsing - # of generic Json arguments like Json[T] - if not origin_type and lenient_issubclass(type_, JsonWrapper): - type_.inner_type = replace_types(type_.inner_type, type_map) - return type_ - - # If all else fails, we try to resolve the type directly and otherwise just - # return the input with no modifications. 
- return type_map.get(type_, type_) - - -def check_parameters_count(cls: Type[GenericModel], parameters: Tuple[Any, ...]) -> None: - actual = len(parameters) - expected = len(cls.__parameters__) - if actual != expected: - description = 'many' if actual > expected else 'few' - raise TypeError(f'Too {description} parameters for {cls.__name__}; actual {actual}, expected {expected}') - - -DictValues: Type[Any] = {}.values().__class__ - - -def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]: - """Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found.""" - if isinstance(v, TypeVar): - yield v - elif hasattr(v, '__parameters__') and not get_origin(v) and lenient_issubclass(v, GenericModel): - yield from v.__parameters__ - elif isinstance(v, (DictValues, list)): - for var in v: - yield from iter_contained_typevars(var) - else: - args = get_args(v) - for arg in args: - yield from iter_contained_typevars(arg) - - -def get_caller_frame_info() -> Tuple[Optional[str], bool]: - """ - Used inside a function to check whether it was called globally - - Will only work against non-compiled code, therefore used only in pydantic.generics - - :returns Tuple[module_name, called_globally] - """ - try: - previous_caller_frame = sys._getframe(2) - except ValueError as e: - raise RuntimeError('This function must be used inside another function') from e - except AttributeError: # sys module does not have _getframe function, so there's nothing we can do about it - return None, False - frame_globals = previous_caller_frame.f_globals - return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals - - -def _prepare_model_fields( - created_model: Type[GenericModel], - fields: Mapping[str, Any], - instance_type_hints: Mapping[str, type], - typevars_map: Mapping[Any, type], -) -> None: - """ - Replace DeferredType fields with concrete type hints and prepare them. 
- """ - - for key, field in created_model.__fields__.items(): - if key not in fields: - assert field.type_.__class__ is not DeferredType - # https://github.com/nedbat/coveragepy/issues/198 - continue # pragma: no cover - - assert field.type_.__class__ is DeferredType, field.type_.__class__ - - field_type_hint = instance_type_hints[key] - concrete_type = replace_types(field_type_hint, typevars_map) - field.type_ = concrete_type - field.outer_type_ = concrete_type - field.prepare() - created_model.__annotations__[key] = concrete_type +__getattr__ = getattr_migration(__name__) diff --git a/lib/pydantic/json.py b/lib/pydantic/json.py index b358b850..020fb6d2 100644 --- a/lib/pydantic/json.py +++ b/lib/pydantic/json.py @@ -1,112 +1,4 @@ -import datetime -from collections import deque -from decimal import Decimal -from enum import Enum -from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network -from pathlib import Path -from re import Pattern -from types import GeneratorType -from typing import Any, Callable, Dict, Type, Union -from uuid import UUID +"""The `json` module is a backport module from V1.""" +from ._migration import getattr_migration -from .color import Color -from .networks import NameEmail -from .types import SecretBytes, SecretStr - -__all__ = 'pydantic_encoder', 'custom_pydantic_encoder', 'timedelta_isoformat' - - -def isoformat(o: Union[datetime.date, datetime.time]) -> str: - return o.isoformat() - - -def decimal_encoder(dec_value: Decimal) -> Union[int, float]: - """ - Encodes a Decimal as int of there's no exponent, otherwise float - - This is useful when we use ConstrainedDecimal to represent Numeric(x,0) - where a integer (but not int typed) is used. Encoding this as a float - results in failed round-tripping between encode and parse. - Our Id type is a prime example of this. 
- - >>> decimal_encoder(Decimal("1.0")) - 1.0 - - >>> decimal_encoder(Decimal("1")) - 1 - """ - if dec_value.as_tuple().exponent >= 0: - return int(dec_value) - else: - return float(dec_value) - - -ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { - bytes: lambda o: o.decode(), - Color: str, - datetime.date: isoformat, - datetime.datetime: isoformat, - datetime.time: isoformat, - datetime.timedelta: lambda td: td.total_seconds(), - Decimal: decimal_encoder, - Enum: lambda o: o.value, - frozenset: list, - deque: list, - GeneratorType: list, - IPv4Address: str, - IPv4Interface: str, - IPv4Network: str, - IPv6Address: str, - IPv6Interface: str, - IPv6Network: str, - NameEmail: str, - Path: str, - Pattern: lambda o: o.pattern, - SecretBytes: str, - SecretStr: str, - set: list, - UUID: str, -} - - -def pydantic_encoder(obj: Any) -> Any: - from dataclasses import asdict, is_dataclass - - from .main import BaseModel - - if isinstance(obj, BaseModel): - return obj.dict() - elif is_dataclass(obj): - return asdict(obj) - - # Check the class type and its superclasses for a matching encoder - for base in obj.__class__.__mro__[:-1]: - try: - encoder = ENCODERS_BY_TYPE[base] - except KeyError: - continue - return encoder(obj) - else: # We have exited the for loop without finding a suitable encoder - raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable") - - -def custom_pydantic_encoder(type_encoders: Dict[Any, Callable[[Type[Any]], Any]], obj: Any) -> Any: - # Check the class type and its superclasses for a matching encoder - for base in obj.__class__.__mro__[:-1]: - try: - encoder = type_encoders[base] - except KeyError: - continue - - return encoder(obj) - else: # We have exited the for loop without finding a suitable encoder - return pydantic_encoder(obj) - - -def timedelta_isoformat(td: datetime.timedelta) -> str: - """ - ISO 8601 encoding for Python timedelta object. - """ - minutes, seconds = divmod(td.seconds, 60) - hours, minutes = divmod(minutes, 60) - return f'{"-" if td.days < 0 else ""}P{abs(td.days)}DT{hours:d}H{minutes:d}M{seconds:d}.{td.microseconds:06d}S' +__getattr__ = getattr_migration(__name__) diff --git a/lib/pydantic/json_schema.py b/lib/pydantic/json_schema.py new file mode 100644 index 00000000..eee9e60e --- /dev/null +++ b/lib/pydantic/json_schema.py @@ -0,0 +1,2425 @@ +""" +Usage docs: https://docs.pydantic.dev/2.5/concepts/json_schema/ + +The `json_schema` module contains classes and functions to allow the way [JSON Schema](https://json-schema.org/) +is generated to be customized. + +In general you shouldn't need to use this module directly; instead, you can +[`BaseModel.model_json_schema`][pydantic.BaseModel.model_json_schema] and +[`TypeAdapter.json_schema`][pydantic.TypeAdapter.json_schema]. 
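+
+For example (a minimal sketch; the model below is illustrative):
+
+```py
+from pydantic import BaseModel
+
+
+class Item(BaseModel):
+    name: str
+    price: float
+
+
+# the usual entry point; most users never need to touch this module directly
+print(Item.model_json_schema()['required'])
+#> ['name', 'price']
+```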
+""" +from __future__ import annotations as _annotations + +import dataclasses +import inspect +import math +import re +import warnings +from collections import defaultdict +from copy import deepcopy +from dataclasses import is_dataclass +from enum import Enum +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Counter, + Dict, + Hashable, + Iterable, + NewType, + Sequence, + Tuple, + TypeVar, + Union, + cast, +) + +import pydantic_core +from pydantic_core import CoreSchema, PydanticOmit, core_schema, to_jsonable_python +from pydantic_core.core_schema import ComputedField +from typing_extensions import Annotated, Literal, TypeAlias, assert_never, deprecated, final + +from pydantic.warnings import PydanticDeprecatedSince26 + +from ._internal import ( + _config, + _core_metadata, + _core_utils, + _decorators, + _internal_dataclass, + _mock_val_ser, + _schema_generation_shared, + _typing_extra, +) +from .annotated_handlers import GetJsonSchemaHandler +from .config import JsonDict, JsonSchemaExtraCallable, JsonValue +from .errors import PydanticInvalidForJsonSchema, PydanticUserError + +if TYPE_CHECKING: + from . import ConfigDict + from ._internal._core_utils import CoreSchemaField, CoreSchemaOrField + from ._internal._dataclasses import PydanticDataclass + from ._internal._schema_generation_shared import GetJsonSchemaFunction + from .main import BaseModel + + +CoreSchemaOrFieldType = Literal[core_schema.CoreSchemaType, core_schema.CoreSchemaFieldType] +""" +A type alias for defined schema types that represents a union of +`core_schema.CoreSchemaType` and +`core_schema.CoreSchemaFieldType`. +""" + +JsonSchemaValue = Dict[str, Any] +""" +A type alias for a JSON schema value. This is a dictionary of string keys to arbitrary JSON values. +""" + +JsonSchemaMode = Literal['validation', 'serialization'] +""" +A type alias that represents the mode of a JSON schema; either 'validation' or 'serialization'. + +For some types, the inputs to validation differ from the outputs of serialization. For example, +computed fields will only be present when serializing, and should not be provided when +validating. This flag provides a way to indicate whether you want the JSON schema required +for validation inputs, or that will be matched by serialization outputs. +""" + +_MODE_TITLE_MAPPING: dict[JsonSchemaMode, str] = {'validation': 'Input', 'serialization': 'Output'} + + +def update_json_schema(schema: JsonSchemaValue, updates: dict[str, Any]) -> JsonSchemaValue: + """Update a JSON schema in-place by providing a dictionary of updates. + + This function sets the provided key-value pairs in the schema and returns the updated schema. + + Args: + schema: The JSON schema to update. + updates: A dictionary of key-value pairs to set in the schema. + + Returns: + The updated JSON schema. + """ + schema.update(updates) + return schema + + +JsonSchemaWarningKind = Literal['skipped-choice', 'non-serializable-default'] +""" +A type alias representing the kinds of warnings that can be emitted during JSON schema generation. + +See [`GenerateJsonSchema.render_warning_message`][pydantic.json_schema.GenerateJsonSchema.render_warning_message] +for more details. +""" + + +class PydanticJsonSchemaWarning(UserWarning): + """This class is used to emit warnings produced during JSON schema generation. 
+ See the [`GenerateJsonSchema.emit_warning`][pydantic.json_schema.GenerateJsonSchema.emit_warning] and + [`GenerateJsonSchema.render_warning_message`][pydantic.json_schema.GenerateJsonSchema.render_warning_message] + methods for more details; these can be overridden to control warning behavior. + """ + + +# ##### JSON Schema Generation ##### +DEFAULT_REF_TEMPLATE = '#/$defs/{model}' +"""The default format string used to generate reference names.""" + +# There are three types of references relevant to building JSON schemas: +# 1. core_schema "ref" values; these are not exposed as part of the JSON schema +# * these might look like the fully qualified path of a model, its id, or something similar +CoreRef = NewType('CoreRef', str) +# 2. keys of the "definitions" object that will eventually go into the JSON schema +# * by default, these look like "MyModel", though may change in the presence of collisions +# * eventually, we may want to make it easier to modify the way these names are generated +DefsRef = NewType('DefsRef', str) +# 3. the values corresponding to the "$ref" key in the schema +# * By default, these look like "#/$defs/MyModel", as in {"$ref": "#/$defs/MyModel"} +JsonRef = NewType('JsonRef', str) + +CoreModeRef = Tuple[CoreRef, JsonSchemaMode] +JsonSchemaKeyT = TypeVar('JsonSchemaKeyT', bound=Hashable) + + +@dataclasses.dataclass(**_internal_dataclass.slots_true) +class _DefinitionsRemapping: + defs_remapping: dict[DefsRef, DefsRef] + json_remapping: dict[JsonRef, JsonRef] + + @staticmethod + def from_prioritized_choices( + prioritized_choices: dict[DefsRef, list[DefsRef]], + defs_to_json: dict[DefsRef, JsonRef], + definitions: dict[DefsRef, JsonSchemaValue], + ) -> _DefinitionsRemapping: + """ + This function should produce a remapping that replaces complex DefsRef with the simpler ones from the + prioritized_choices such that applying the name remapping would result in an equivalent JSON schema. + """ + # We need to iteratively simplify the definitions until we reach a fixed point. + # The reason for this is that outer definitions may reference inner definitions that get simplified + # into an equivalent reference, and the outer definitions won't be equivalent until we've simplified + # the inner definitions. + copied_definitions = deepcopy(definitions) + definitions_schema = {'$defs': copied_definitions} + for _iter in range(100): # prevent an infinite loop in the case of a bug, 100 iterations should be enough + # For every possible remapped DefsRef, collect all schemas that that DefsRef might be used for: + schemas_for_alternatives: dict[DefsRef, list[JsonSchemaValue]] = defaultdict(list) + for defs_ref in copied_definitions: + alternatives = prioritized_choices[defs_ref] + for alternative in alternatives: + schemas_for_alternatives[alternative].append(copied_definitions[defs_ref]) + + # Deduplicate the schemas for each alternative; the idea is that we only want to remap to a new DefsRef + # if it introduces no ambiguity, i.e., there is only one distinct schema for that DefsRef. 
+ for defs_ref, schemas in schemas_for_alternatives.items(): + schemas_for_alternatives[defs_ref] = _deduplicate_schemas(schemas_for_alternatives[defs_ref]) + + # Build the remapping + defs_remapping: dict[DefsRef, DefsRef] = {} + json_remapping: dict[JsonRef, JsonRef] = {} + for original_defs_ref in definitions: + alternatives = prioritized_choices[original_defs_ref] + # Pick the first alternative that has only one schema, since that means there is no collision + remapped_defs_ref = next(x for x in alternatives if len(schemas_for_alternatives[x]) == 1) + defs_remapping[original_defs_ref] = remapped_defs_ref + json_remapping[defs_to_json[original_defs_ref]] = defs_to_json[remapped_defs_ref] + remapping = _DefinitionsRemapping(defs_remapping, json_remapping) + new_definitions_schema = remapping.remap_json_schema({'$defs': copied_definitions}) + if definitions_schema == new_definitions_schema: + # We've reached the fixed point + return remapping + definitions_schema = new_definitions_schema + + raise PydanticInvalidForJsonSchema('Failed to simplify the JSON schema definitions') + + def remap_defs_ref(self, ref: DefsRef) -> DefsRef: + return self.defs_remapping.get(ref, ref) + + def remap_json_ref(self, ref: JsonRef) -> JsonRef: + return self.json_remapping.get(ref, ref) + + def remap_json_schema(self, schema: Any) -> Any: + """ + Recursively update the JSON schema replacing all $refs + """ + if isinstance(schema, str): + # Note: this may not really be a JsonRef; we rely on having no collisions between JsonRefs and other strings + return self.remap_json_ref(JsonRef(schema)) + elif isinstance(schema, list): + return [self.remap_json_schema(item) for item in schema] + elif isinstance(schema, dict): + for key, value in schema.items(): + if key == '$ref' and isinstance(value, str): + schema['$ref'] = self.remap_json_ref(JsonRef(value)) + elif key == '$defs': + schema['$defs'] = { + self.remap_defs_ref(DefsRef(key)): self.remap_json_schema(value) + for key, value in schema['$defs'].items() + } + else: + schema[key] = self.remap_json_schema(value) + return schema + + +class GenerateJsonSchema: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/json_schema/#customizing-the-json-schema-generation-process + + A class for generating JSON schemas. + + This class generates JSON schemas based on configured parameters. The default schema dialect + is [https://json-schema.org/draft/2020-12/schema](https://json-schema.org/draft/2020-12/schema). + The class uses `by_alias` to configure how fields with + multiple names are handled and `ref_template` to format reference names. + + Attributes: + schema_dialect: The JSON schema dialect used to generate the schema. See + [Declaring a Dialect](https://json-schema.org/understanding-json-schema/reference/schema.html#id4) + in the JSON Schema documentation for more information about dialects. + ignored_warning_kinds: Warnings to ignore when generating the schema. `self.render_warning_message` will + do nothing if its argument `kind` is in `ignored_warning_kinds`; + this value can be modified on subclasses to easily control which warnings are emitted. + by_alias: Whether to use field aliases when generating the schema. + ref_template: The format string used when generating reference names. + core_to_json_refs: A mapping of core refs to JSON refs. + core_to_defs_refs: A mapping of core refs to definition refs. + defs_to_core_refs: A mapping of definition refs to core refs. + json_to_defs_refs: A mapping of JSON refs to definition refs. 
+ definitions: Definitions in the schema. + + Args: + by_alias: Whether to use field aliases in the generated schemas. + ref_template: The format string to use when generating reference names. + + Raises: + JsonSchemaError: If the instance of the class is inadvertently re-used after generating a schema. + """ + + schema_dialect = 'https://json-schema.org/draft/2020-12/schema' + + # `self.render_warning_message` will do nothing if its argument `kind` is in `ignored_warning_kinds`; + # this value can be modified on subclasses to easily control which warnings are emitted + ignored_warning_kinds: set[JsonSchemaWarningKind] = {'skipped-choice'} + + def __init__(self, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLATE): + self.by_alias = by_alias + self.ref_template = ref_template + + self.core_to_json_refs: dict[CoreModeRef, JsonRef] = {} + self.core_to_defs_refs: dict[CoreModeRef, DefsRef] = {} + self.defs_to_core_refs: dict[DefsRef, CoreModeRef] = {} + self.json_to_defs_refs: dict[JsonRef, DefsRef] = {} + + self.definitions: dict[DefsRef, JsonSchemaValue] = {} + self._config_wrapper_stack = _config.ConfigWrapperStack(_config.ConfigWrapper({})) + + self._mode: JsonSchemaMode = 'validation' + + # The following includes a mapping of a fully-unique defs ref choice to a list of preferred + # alternatives, which are generally simpler, such as only including the class name. + # At the end of schema generation, we use these to produce a JSON schema with more human-readable + # definitions, which would also work better in a generated OpenAPI client, etc. + self._prioritized_defsref_choices: dict[DefsRef, list[DefsRef]] = {} + self._collision_counter: dict[str, int] = defaultdict(int) + self._collision_index: dict[str, int] = {} + + self._schema_type_to_method = self.build_schema_type_to_method() + + # When we encounter definitions we need to try to build them immediately + # so that they are available schemas that reference them + # But it's possible that CoreSchema was never going to be used + # (e.g. because the CoreSchema that references short circuits is JSON schema generation without needing + # the reference) so instead of failing altogether if we can't build a definition we + # store the error raised and re-throw it if we end up needing that def + self._core_defs_invalid_for_json_schema: dict[DefsRef, PydanticInvalidForJsonSchema] = {} + + # This changes to True after generating a schema, to prevent issues caused by accidental re-use + # of a single instance of a schema generator + self._used = False + + @property + def _config(self) -> _config.ConfigWrapper: + return self._config_wrapper_stack.tail + + @property + def mode(self) -> JsonSchemaMode: + if self._config.json_schema_mode_override is not None: + return self._config.json_schema_mode_override + else: + return self._mode + + def build_schema_type_to_method( + self, + ) -> dict[CoreSchemaOrFieldType, Callable[[CoreSchemaOrField], JsonSchemaValue]]: + """Builds a dictionary mapping fields to methods for generating JSON schemas. + + Returns: + A dictionary containing the mapping of `CoreSchemaOrFieldType` to a handler method. + + Raises: + TypeError: If no method has been defined for generating a JSON schema for a given pydantic core schema type. 
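+
+        As an illustrative sketch (the subclass and model names are assumptions), a handler
+        overridden on a subclass is picked up through this mapping:
+
+        ```py
+        from pydantic import BaseModel
+        from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
+        from pydantic_core import core_schema
+
+        class MyGenerateJsonSchema(GenerateJsonSchema):
+            def bool_schema(self, schema: core_schema.BoolSchema) -> JsonSchemaValue:
+                # tag booleans so the override is easy to spot in the output
+                return {'type': 'boolean', 'description': 'custom'}
+
+        class Model(BaseModel):
+            active: bool
+
+        schema = Model.model_json_schema(schema_generator=MyGenerateJsonSchema)
+        print(schema['properties']['active']['description'])
+        #> custom
+        ```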
+ """ + mapping: dict[CoreSchemaOrFieldType, Callable[[CoreSchemaOrField], JsonSchemaValue]] = {} + core_schema_types: list[CoreSchemaOrFieldType] = _typing_extra.all_literal_values( + CoreSchemaOrFieldType # type: ignore + ) + for key in core_schema_types: + method_name = f"{key.replace('-', '_')}_schema" + try: + mapping[key] = getattr(self, method_name) + except AttributeError as e: # pragma: no cover + raise TypeError( + f'No method for generating JsonSchema for core_schema.type={key!r} ' + f'(expected: {type(self).__name__}.{method_name})' + ) from e + return mapping + + def generate_definitions( + self, inputs: Sequence[tuple[JsonSchemaKeyT, JsonSchemaMode, core_schema.CoreSchema]] + ) -> tuple[dict[tuple[JsonSchemaKeyT, JsonSchemaMode], JsonSchemaValue], dict[DefsRef, JsonSchemaValue]]: + """Generates JSON schema definitions from a list of core schemas, pairing the generated definitions with a + mapping that links the input keys to the definition references. + + Args: + inputs: A sequence of tuples, where: + + - The first element is a JSON schema key type. + - The second element is the JSON mode: either 'validation' or 'serialization'. + - The third element is a core schema. + + Returns: + A tuple where: + + - The first element is a dictionary whose keys are tuples of JSON schema key type and JSON mode, and + whose values are the JSON schema corresponding to that pair of inputs. (These schemas may have + JsonRef references to definitions that are defined in the second returned element.) + - The second element is a dictionary whose keys are definition references for the JSON schemas + from the first returned element, and whose values are the actual JSON schema definitions. + + Raises: + PydanticUserError: Raised if the JSON schema generator has already been used to generate a JSON schema. + """ + if self._used: + raise PydanticUserError( + 'This JSON schema generator has already been used to generate a JSON schema. ' + f'You must create a new instance of {type(self).__name__} to generate a new JSON schema.', + code='json-schema-already-used', + ) + + for key, mode, schema in inputs: + self._mode = mode + self.generate_inner(schema) + + definitions_remapping = self._build_definitions_remapping() + + json_schemas_map: dict[tuple[JsonSchemaKeyT, JsonSchemaMode], DefsRef] = {} + for key, mode, schema in inputs: + self._mode = mode + json_schema = self.generate_inner(schema) + json_schemas_map[(key, mode)] = definitions_remapping.remap_json_schema(json_schema) + + json_schema = {'$defs': self.definitions} + json_schema = definitions_remapping.remap_json_schema(json_schema) + self._used = True + return json_schemas_map, _sort_json_schema(json_schema['$defs']) # type: ignore + + def generate(self, schema: CoreSchema, mode: JsonSchemaMode = 'validation') -> JsonSchemaValue: + """Generates a JSON schema for a specified schema in a specified mode. + + Args: + schema: A Pydantic model. + mode: The mode in which to generate the schema. Defaults to 'validation'. + + Returns: + A JSON schema representing the specified schema. + + Raises: + PydanticUserError: If the JSON schema generator has already been used to generate a JSON schema. + """ + self._mode = mode + if self._used: + raise PydanticUserError( + 'This JSON schema generator has already been used to generate a JSON schema. 
' + f'You must create a new instance of {type(self).__name__} to generate a new JSON schema.', + code='json-schema-already-used', + ) + + json_schema: JsonSchemaValue = self.generate_inner(schema) + json_ref_counts = self.get_json_ref_counts(json_schema) + + # Remove the top-level $ref if present; note that the _generate method already ensures there are no sibling keys + ref = cast(JsonRef, json_schema.get('$ref')) + while ref is not None: # may need to unpack multiple levels + ref_json_schema = self.get_schema_from_definitions(ref) + if json_ref_counts[ref] > 1 or ref_json_schema is None: + # Keep the ref, but use an allOf to remove the top level $ref + json_schema = {'allOf': [{'$ref': ref}]} + else: + # "Unpack" the ref since this is the only reference + json_schema = ref_json_schema.copy() # copy to prevent recursive dict reference + json_ref_counts[ref] -= 1 + ref = cast(JsonRef, json_schema.get('$ref')) + + self._garbage_collect_definitions(json_schema) + definitions_remapping = self._build_definitions_remapping() + + if self.definitions: + json_schema['$defs'] = self.definitions + + json_schema = definitions_remapping.remap_json_schema(json_schema) + + # For now, we will not set the $schema key. However, if desired, this can be easily added by overriding + # this method and adding the following line after a call to super().generate(schema): + # json_schema['$schema'] = self.schema_dialect + + self._used = True + return _sort_json_schema(json_schema) + + def generate_inner(self, schema: CoreSchemaOrField) -> JsonSchemaValue: # noqa: C901 + """Generates a JSON schema for a given core schema. + + Args: + schema: The given core schema. + + Returns: + The generated JSON schema. + """ + # If a schema with the same CoreRef has been handled, just return a reference to it + # Note that this assumes that it will _never_ be the case that the same CoreRef is used + # on types that should have different JSON schemas + if 'ref' in schema: + core_ref = CoreRef(schema['ref']) # type: ignore[typeddict-item] + core_mode_ref = (core_ref, self.mode) + if core_mode_ref in self.core_to_defs_refs and self.core_to_defs_refs[core_mode_ref] in self.definitions: + return {'$ref': self.core_to_json_refs[core_mode_ref]} + + # Generate the JSON schema, accounting for the json_schema_override and core_schema_override + metadata_handler = _core_metadata.CoreMetadataHandler(schema) + + def populate_defs(core_schema: CoreSchema, json_schema: JsonSchemaValue) -> JsonSchemaValue: + if 'ref' in core_schema: + core_ref = CoreRef(core_schema['ref']) # type: ignore[typeddict-item] + defs_ref, ref_json_schema = self.get_cache_defs_ref_schema(core_ref) + json_ref = JsonRef(ref_json_schema['$ref']) + self.json_to_defs_refs[json_ref] = defs_ref + # Replace the schema if it's not a reference to itself + # What we want to avoid is having the def be just a ref to itself + # which is what would happen if we blindly assigned any + if json_schema.get('$ref', None) != json_ref: + self.definitions[defs_ref] = json_schema + self._core_defs_invalid_for_json_schema.pop(defs_ref, None) + json_schema = ref_json_schema + return json_schema + + def convert_to_all_of(json_schema: JsonSchemaValue) -> JsonSchemaValue: + if '$ref' in json_schema and len(json_schema.keys()) > 1: + # technically you can't have any other keys next to a "$ref" + # but it's an easy mistake to make and not hard to correct automatically here + json_schema = json_schema.copy() + ref = json_schema.pop('$ref') + json_schema = {'allOf': [{'$ref': ref}], **json_schema} + 
return json_schema + + def handler_func(schema_or_field: CoreSchemaOrField) -> JsonSchemaValue: + """Generate a JSON schema based on the input schema. + + Args: + schema_or_field: The core schema to generate a JSON schema from. + + Returns: + The generated JSON schema. + + Raises: + TypeError: If an unexpected schema type is encountered. + """ + # Generate the core-schema-type-specific bits of the schema generation: + json_schema: JsonSchemaValue | None = None + if self.mode == 'serialization' and 'serialization' in schema_or_field: + ser_schema = schema_or_field['serialization'] # type: ignore + json_schema = self.ser_schema(ser_schema) + if json_schema is None: + if _core_utils.is_core_schema(schema_or_field) or _core_utils.is_core_schema_field(schema_or_field): + generate_for_schema_type = self._schema_type_to_method[schema_or_field['type']] + json_schema = generate_for_schema_type(schema_or_field) + else: + raise TypeError(f'Unexpected schema type: schema={schema_or_field}') + if _core_utils.is_core_schema(schema_or_field): + json_schema = populate_defs(schema_or_field, json_schema) + json_schema = convert_to_all_of(json_schema) + return json_schema + + current_handler = _schema_generation_shared.GenerateJsonSchemaHandler(self, handler_func) + + for js_modify_function in metadata_handler.metadata.get('pydantic_js_functions', ()): + + def new_handler_func( + schema_or_field: CoreSchemaOrField, + current_handler: GetJsonSchemaHandler = current_handler, + js_modify_function: GetJsonSchemaFunction = js_modify_function, + ) -> JsonSchemaValue: + json_schema = js_modify_function(schema_or_field, current_handler) + if _core_utils.is_core_schema(schema_or_field): + json_schema = populate_defs(schema_or_field, json_schema) + original_schema = current_handler.resolve_ref_schema(json_schema) + ref = json_schema.pop('$ref', None) + if ref and json_schema: + original_schema.update(json_schema) + return original_schema + + current_handler = _schema_generation_shared.GenerateJsonSchemaHandler(self, new_handler_func) + + for js_modify_function in metadata_handler.metadata.get('pydantic_js_annotation_functions', ()): + + def new_handler_func( + schema_or_field: CoreSchemaOrField, + current_handler: GetJsonSchemaHandler = current_handler, + js_modify_function: GetJsonSchemaFunction = js_modify_function, + ) -> JsonSchemaValue: + json_schema = js_modify_function(schema_or_field, current_handler) + if _core_utils.is_core_schema(schema_or_field): + json_schema = populate_defs(schema_or_field, json_schema) + json_schema = convert_to_all_of(json_schema) + return json_schema + + current_handler = _schema_generation_shared.GenerateJsonSchemaHandler(self, new_handler_func) + + json_schema = current_handler(schema) + if _core_utils.is_core_schema(schema): + json_schema = populate_defs(schema, json_schema) + json_schema = convert_to_all_of(json_schema) + return json_schema + + # ### Schema generation methods + def any_schema(self, schema: core_schema.AnySchema) -> JsonSchemaValue: + """Generates a JSON schema that matches any value. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return {} + + def none_schema(self, schema: core_schema.NoneSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches `None`. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return {'type': 'null'} + + def bool_schema(self, schema: core_schema.BoolSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a bool value. 
+ + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return {'type': 'boolean'} + + def int_schema(self, schema: core_schema.IntSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches an int value. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + json_schema: dict[str, Any] = {'type': 'integer'} + self.update_with_validations(json_schema, schema, self.ValidationsMapping.numeric) + json_schema = {k: v for k, v in json_schema.items() if v not in {math.inf, -math.inf}} + return json_schema + + def float_schema(self, schema: core_schema.FloatSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a float value. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + json_schema: dict[str, Any] = {'type': 'number'} + self.update_with_validations(json_schema, schema, self.ValidationsMapping.numeric) + json_schema = {k: v for k, v in json_schema.items() if v not in {math.inf, -math.inf}} + return json_schema + + def decimal_schema(self, schema: core_schema.DecimalSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a decimal value. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + json_schema = self.str_schema(core_schema.str_schema()) + if self.mode == 'validation': + multiple_of = schema.get('multiple_of') + le = schema.get('le') + ge = schema.get('ge') + lt = schema.get('lt') + gt = schema.get('gt') + json_schema = { + 'anyOf': [ + self.float_schema( + core_schema.float_schema( + allow_inf_nan=schema.get('allow_inf_nan'), + multiple_of=None if multiple_of is None else float(multiple_of), + le=None if le is None else float(le), + ge=None if ge is None else float(ge), + lt=None if lt is None else float(lt), + gt=None if gt is None else float(gt), + ) + ), + json_schema, + ], + } + return json_schema + + def str_schema(self, schema: core_schema.StringSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a string value. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + json_schema = {'type': 'string'} + self.update_with_validations(json_schema, schema, self.ValidationsMapping.string) + return json_schema + + def bytes_schema(self, schema: core_schema.BytesSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a bytes value. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + json_schema = {'type': 'string', 'format': 'base64url' if self._config.ser_json_bytes == 'base64' else 'binary'} + self.update_with_validations(json_schema, schema, self.ValidationsMapping.bytes) + return json_schema + + def date_schema(self, schema: core_schema.DateSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a date value. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + json_schema = {'type': 'string', 'format': 'date'} + self.update_with_validations(json_schema, schema, self.ValidationsMapping.date) + return json_schema + + def time_schema(self, schema: core_schema.TimeSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a time value. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return {'type': 'string', 'format': 'time'} + + def datetime_schema(self, schema: core_schema.DatetimeSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a datetime value. + + Args: + schema: The core schema. 
+ + Returns: + The generated JSON schema. + """ + return {'type': 'string', 'format': 'date-time'} + + def timedelta_schema(self, schema: core_schema.TimedeltaSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a timedelta value. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + if self._config.ser_json_timedelta == 'float': + return {'type': 'number'} + return {'type': 'string', 'format': 'duration'} + + def literal_schema(self, schema: core_schema.LiteralSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a literal value. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + expected = [v.value if isinstance(v, Enum) else v for v in schema['expected']] + # jsonify the expected values + expected = [to_jsonable_python(v) for v in expected] + + if len(expected) == 1: + return {'const': expected[0]} + + types = {type(e) for e in expected} + if types == {str}: + return {'enum': expected, 'type': 'string'} + elif types == {int}: + return {'enum': expected, 'type': 'integer'} + elif types == {float}: + return {'enum': expected, 'type': 'number'} + elif types == {bool}: + return {'enum': expected, 'type': 'boolean'} + elif types == {list}: + return {'enum': expected, 'type': 'array'} + # there is not None case because if it's mixed it hits the final `else` + # if it's a single Literal[None] then it becomes a `const` schema above + else: + return {'enum': expected} + + def is_instance_schema(self, schema: core_schema.IsInstanceSchema) -> JsonSchemaValue: + """Handles JSON schema generation for a core schema that checks if a value is an instance of a class. + + Unless overridden in a subclass, this raises an error. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return self.handle_invalid_for_json_schema(schema, f'core_schema.IsInstanceSchema ({schema["cls"]})') + + def is_subclass_schema(self, schema: core_schema.IsSubclassSchema) -> JsonSchemaValue: + """Handles JSON schema generation for a core schema that checks if a value is a subclass of a class. + + For backwards compatibility with v1, this does not raise an error, but can be overridden to change this. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + # Note: This is for compatibility with V1; you can override if you want different behavior. + return {} + + def callable_schema(self, schema: core_schema.CallableSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a callable value. + + Unless overridden in a subclass, this raises an error. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return self.handle_invalid_for_json_schema(schema, 'core_schema.CallableSchema') + + def list_schema(self, schema: core_schema.ListSchema) -> JsonSchemaValue: + """Returns a schema that matches a list schema. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + items_schema = {} if 'items_schema' not in schema else self.generate_inner(schema['items_schema']) + json_schema = {'type': 'array', 'items': items_schema} + self.update_with_validations(json_schema, schema, self.ValidationsMapping.array) + return json_schema + + @deprecated('`tuple_positional_schema` is deprecated. 
Use `tuple_schema` instead.', category=None) + @final + def tuple_positional_schema(self, schema: core_schema.TupleSchema) -> JsonSchemaValue: + """Replaced by `tuple_schema`.""" + warnings.warn( + '`tuple_positional_schema` is deprecated. Use `tuple_schema` instead.', + PydanticDeprecatedSince26, + stacklevel=2, + ) + return self.tuple_schema(schema) + + @deprecated('`tuple_variable_schema` is deprecated. Use `tuple_schema` instead.', category=None) + @final + def tuple_variable_schema(self, schema: core_schema.TupleSchema) -> JsonSchemaValue: + """Replaced by `tuple_schema`.""" + warnings.warn( + '`tuple_variable_schema` is deprecated. Use `tuple_schema` instead.', + PydanticDeprecatedSince26, + stacklevel=2, + ) + return self.tuple_schema(schema) + + def tuple_schema(self, schema: core_schema.TupleSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a tuple schema e.g. `Tuple[int, + str, bool]` or `Tuple[int, ...]`. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + json_schema: JsonSchemaValue = {'type': 'array'} + if 'variadic_item_index' in schema: + variadic_item_index = schema['variadic_item_index'] + if variadic_item_index > 0: + json_schema['minItems'] = variadic_item_index + json_schema['prefixItems'] = [ + self.generate_inner(item) for item in schema['items_schema'][:variadic_item_index] + ] + if variadic_item_index + 1 == len(schema['items_schema']): + # if the variadic item is the last item, then represent it faithfully + json_schema['items'] = self.generate_inner(schema['items_schema'][variadic_item_index]) + else: + # otherwise, 'items' represents the schema for the variadic + # item plus the suffix, so just allow anything for simplicity + # for now + json_schema['items'] = True + else: + prefixItems = [self.generate_inner(item) for item in schema['items_schema']] + if prefixItems: + json_schema['prefixItems'] = prefixItems + json_schema['minItems'] = len(prefixItems) + json_schema['maxItems'] = len(prefixItems) + self.update_with_validations(json_schema, schema, self.ValidationsMapping.array) + return json_schema + + def set_schema(self, schema: core_schema.SetSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a set schema. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return self._common_set_schema(schema) + + def frozenset_schema(self, schema: core_schema.FrozenSetSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a frozenset schema. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return self._common_set_schema(schema) + + def _common_set_schema(self, schema: core_schema.SetSchema | core_schema.FrozenSetSchema) -> JsonSchemaValue: + items_schema = {} if 'items_schema' not in schema else self.generate_inner(schema['items_schema']) + json_schema = {'type': 'array', 'uniqueItems': True, 'items': items_schema} + self.update_with_validations(json_schema, schema, self.ValidationsMapping.array) + return json_schema + + def generator_schema(self, schema: core_schema.GeneratorSchema) -> JsonSchemaValue: + """Returns a JSON schema that represents the provided GeneratorSchema. + + Args: + schema: The schema. + + Returns: + The generated JSON schema. 
+ """ + items_schema = {} if 'items_schema' not in schema else self.generate_inner(schema['items_schema']) + json_schema = {'type': 'array', 'items': items_schema} + self.update_with_validations(json_schema, schema, self.ValidationsMapping.array) + return json_schema + + def dict_schema(self, schema: core_schema.DictSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a dict schema. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + json_schema: JsonSchemaValue = {'type': 'object'} + + keys_schema = self.generate_inner(schema['keys_schema']).copy() if 'keys_schema' in schema else {} + keys_pattern = keys_schema.pop('pattern', None) + + values_schema = self.generate_inner(schema['values_schema']).copy() if 'values_schema' in schema else {} + values_schema.pop('title', None) # don't give a title to the additionalProperties + if values_schema or keys_pattern is not None: # don't add additionalProperties if it's empty + if keys_pattern is None: + json_schema['additionalProperties'] = values_schema + else: + json_schema['patternProperties'] = {keys_pattern: values_schema} + + self.update_with_validations(json_schema, schema, self.ValidationsMapping.object) + return json_schema + + def _function_schema( + self, + schema: _core_utils.AnyFunctionSchema, + ) -> JsonSchemaValue: + if _core_utils.is_function_with_inner_schema(schema): + # This could be wrong if the function's mode is 'before', but in practice will often be right, and when it + # isn't, I think it would be hard to automatically infer what the desired schema should be. + return self.generate_inner(schema['schema']) + + # function-plain + return self.handle_invalid_for_json_schema( + schema, f'core_schema.PlainValidatorFunctionSchema ({schema["function"]})' + ) + + def function_before_schema(self, schema: core_schema.BeforeValidatorFunctionSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a function-before schema. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return self._function_schema(schema) + + def function_after_schema(self, schema: core_schema.AfterValidatorFunctionSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a function-after schema. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return self._function_schema(schema) + + def function_plain_schema(self, schema: core_schema.PlainValidatorFunctionSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a function-plain schema. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return self._function_schema(schema) + + def function_wrap_schema(self, schema: core_schema.WrapValidatorFunctionSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a function-wrap schema. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return self._function_schema(schema) + + def default_schema(self, schema: core_schema.WithDefaultSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema with a default value. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. 
+ """ + json_schema = self.generate_inner(schema['schema']) + + if 'default' not in schema: + return json_schema + default = schema['default'] + # Note: if you want to include the value returned by the default_factory, + # override this method and replace the code above with: + # if 'default' in schema: + # default = schema['default'] + # elif 'default_factory' in schema: + # default = schema['default_factory']() + # else: + # return json_schema + + try: + encoded_default = self.encode_default(default) + except pydantic_core.PydanticSerializationError: + self.emit_warning( + 'non-serializable-default', + f'Default value {default} is not JSON serializable; excluding default from JSON schema', + ) + # Return the inner schema, as though there was no default + return json_schema + + if '$ref' in json_schema: + # Since reference schemas do not support child keys, we wrap the reference schema in a single-case allOf: + return {'allOf': [json_schema], 'default': encoded_default} + else: + json_schema['default'] = encoded_default + return json_schema + + def nullable_schema(self, schema: core_schema.NullableSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that allows null values. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + null_schema = {'type': 'null'} + inner_json_schema = self.generate_inner(schema['schema']) + + if inner_json_schema == null_schema: + return null_schema + else: + # Thanks to the equality check against `null_schema` above, I think 'oneOf' would also be valid here; + # I'll use 'anyOf' for now, but it could be changed it if it would work better with some external tooling + return self.get_flattened_anyof([inner_json_schema, null_schema]) + + def union_schema(self, schema: core_schema.UnionSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that allows values matching any of the given schemas. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + generated: list[JsonSchemaValue] = [] + + choices = schema['choices'] + for choice in choices: + # choice will be a tuple if an explicit label was provided + choice_schema = choice[0] if isinstance(choice, tuple) else choice + try: + generated.append(self.generate_inner(choice_schema)) + except PydanticOmit: + continue + except PydanticInvalidForJsonSchema as exc: + self.emit_warning('skipped-choice', exc.message) + if len(generated) == 1: + return generated[0] + return self.get_flattened_anyof(generated) + + def tagged_union_schema(self, schema: core_schema.TaggedUnionSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that allows values matching any of the given schemas, where + the schemas are tagged with a discriminator field that indicates which schema should be used to validate + the value. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. 
+ """ + generated: dict[str, JsonSchemaValue] = {} + for k, v in schema['choices'].items(): + if isinstance(k, Enum): + k = k.value + try: + # Use str(k) since keys must be strings for json; while not technically correct, + # it's the closest that can be represented in valid JSON + generated[str(k)] = self.generate_inner(v).copy() + except PydanticOmit: + continue + except PydanticInvalidForJsonSchema as exc: + self.emit_warning('skipped-choice', exc.message) + + one_of_choices = _deduplicate_schemas(generated.values()) + json_schema: JsonSchemaValue = {'oneOf': one_of_choices} + + # This reflects the v1 behavior; TODO: we should make it possible to exclude OpenAPI stuff from the JSON schema + openapi_discriminator = self._extract_discriminator(schema, one_of_choices) + if openapi_discriminator is not None: + json_schema['discriminator'] = { + 'propertyName': openapi_discriminator, + 'mapping': {k: v.get('$ref', v) for k, v in generated.items()}, + } + + return json_schema + + def _extract_discriminator( + self, schema: core_schema.TaggedUnionSchema, one_of_choices: list[JsonDict] + ) -> str | None: + """Extract a compatible OpenAPI discriminator from the schema and one_of choices that end up in the final + schema.""" + openapi_discriminator: str | None = None + + if isinstance(schema['discriminator'], str): + return schema['discriminator'] + + if isinstance(schema['discriminator'], list): + # If the discriminator is a single item list containing a string, that is equivalent to the string case + if len(schema['discriminator']) == 1 and isinstance(schema['discriminator'][0], str): + return schema['discriminator'][0] + # When an alias is used that is different from the field name, the discriminator will be a list of single + # str lists, one for the attribute and one for the actual alias. The logic here will work even if there is + # more than one possible attribute, and looks for whether a single alias choice is present as a documented + # property on all choices. If so, that property will be used as the OpenAPI discriminator. + for alias_path in schema['discriminator']: + if not isinstance(alias_path, list): + break # this means that the discriminator is not a list of alias paths + if len(alias_path) != 1: + continue # this means that the "alias" does not represent a single field + alias = alias_path[0] + if not isinstance(alias, str): + continue # this means that the "alias" does not represent a field + alias_is_present_on_all_choices = True + for choice in one_of_choices: + while '$ref' in choice: + assert isinstance(choice['$ref'], str) + choice = self.get_schema_from_definitions(JsonRef(choice['$ref'])) or {} + properties = choice.get('properties', {}) + if not isinstance(properties, dict) or alias not in properties: + alias_is_present_on_all_choices = False + break + if alias_is_present_on_all_choices: + openapi_discriminator = alias + break + return openapi_discriminator + + def chain_schema(self, schema: core_schema.ChainSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a core_schema.ChainSchema. + + When generating a schema for validation, we return the validation JSON schema for the first step in the chain. + For serialization, we return the serialization JSON schema for the last step in the chain. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. 
+ """ + step_index = 0 if self.mode == 'validation' else -1 # use first step for validation, last for serialization + return self.generate_inner(schema['steps'][step_index]) + + def lax_or_strict_schema(self, schema: core_schema.LaxOrStrictSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that allows values matching either the lax schema or the + strict schema. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + # TODO: Need to read the default value off of model config or whatever + use_strict = schema.get('strict', False) # TODO: replace this default False + # If your JSON schema fails to generate it is probably + # because one of the following two branches failed. + if use_strict: + return self.generate_inner(schema['strict_schema']) + else: + return self.generate_inner(schema['lax_schema']) + + def json_or_python_schema(self, schema: core_schema.JsonOrPythonSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that allows values matching either the JSON schema or the + Python schema. + + The JSON schema is used instead of the Python schema. If you want to use the Python schema, you should override + this method. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return self.generate_inner(schema['json_schema']) + + def typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that defines a typed dict. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + total = schema.get('total', True) + named_required_fields: list[tuple[str, bool, CoreSchemaField]] = [ + (name, self.field_is_required(field, total), field) + for name, field in schema['fields'].items() + if self.field_is_present(field) + ] + if self.mode == 'serialization': + named_required_fields.extend(self._name_required_computed_fields(schema.get('computed_fields', []))) + + config = _get_typed_dict_config(schema) + with self._config_wrapper_stack.push(config): + json_schema = self._named_required_fields_schema(named_required_fields) + + extra = schema.get('extra_behavior') + if extra is None: + extra = config.get('extra', 'ignore') + if extra == 'forbid': + json_schema['additionalProperties'] = False + elif extra == 'allow': + json_schema['additionalProperties'] = True + + return json_schema + + @staticmethod + def _name_required_computed_fields( + computed_fields: list[ComputedField], + ) -> list[tuple[str, bool, core_schema.ComputedField]]: + return [(field['property_name'], True, field) for field in computed_fields] + + def _named_required_fields_schema( + self, named_required_fields: Sequence[tuple[str, bool, CoreSchemaField]] + ) -> JsonSchemaValue: + properties: dict[str, JsonSchemaValue] = {} + required_fields: list[str] = [] + for name, required, field in named_required_fields: + if self.by_alias: + name = self._get_alias_name(field, name) + try: + field_json_schema = self.generate_inner(field).copy() + except PydanticOmit: + continue + if 'title' not in field_json_schema and self.field_title_should_be_set(field): + title = self.get_title_from_name(name) + field_json_schema['title'] = title + field_json_schema = self.handle_ref_overrides(field_json_schema) + properties[name] = field_json_schema + if required: + required_fields.append(name) + + json_schema = {'type': 'object', 'properties': properties} + if required_fields: + json_schema['required'] = required_fields + return json_schema + + 
def _get_alias_name(self, field: CoreSchemaField, name: str) -> str: + if field['type'] == 'computed-field': + alias: Any = field.get('alias', name) + elif self.mode == 'validation': + alias = field.get('validation_alias', name) + else: + alias = field.get('serialization_alias', name) + if isinstance(alias, str): + name = alias + elif isinstance(alias, list): + alias = cast('list[str] | str', alias) + for path in alias: + if isinstance(path, list) and len(path) == 1 and isinstance(path[0], str): + # Use the first valid single-item string path; the code that constructs the alias array + # should ensure the first such item is what belongs in the JSON schema + name = path[0] + break + else: + assert_never(alias) + return name + + def typed_dict_field_schema(self, schema: core_schema.TypedDictField) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that defines a typed dict field. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return self.generate_inner(schema['schema']) + + def dataclass_field_schema(self, schema: core_schema.DataclassField) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that defines a dataclass field. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return self.generate_inner(schema['schema']) + + def model_field_schema(self, schema: core_schema.ModelField) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that defines a model field. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return self.generate_inner(schema['schema']) + + def computed_field_schema(self, schema: core_schema.ComputedField) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that defines a computed field. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return self.generate_inner(schema['return_schema']) + + def model_schema(self, schema: core_schema.ModelSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that defines a model. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + # We do not use schema['model'].model_json_schema() here + # because it could lead to inconsistent refs handling, etc. 
+ cls = cast('type[BaseModel]', schema['cls']) + config = cls.model_config + title = config.get('title') + + with self._config_wrapper_stack.push(config): + json_schema = self.generate_inner(schema['schema']) + + json_schema_extra = config.get('json_schema_extra') + if cls.__pydantic_root_model__: + root_json_schema_extra = cls.model_fields['root'].json_schema_extra + if json_schema_extra and root_json_schema_extra: + raise ValueError( + '"model_config[\'json_schema_extra\']" and "Field.json_schema_extra" on "RootModel.root"' + ' field must not be set simultaneously' + ) + if root_json_schema_extra: + json_schema_extra = root_json_schema_extra + + json_schema = self._update_class_schema(json_schema, title, config.get('extra', None), cls, json_schema_extra) + + return json_schema + + def _update_class_schema( + self, + json_schema: JsonSchemaValue, + title: str | None, + extra: Literal['allow', 'ignore', 'forbid'] | None, + cls: type[Any], + json_schema_extra: JsonDict | JsonSchemaExtraCallable | None, + ) -> JsonSchemaValue: + if '$ref' in json_schema: + schema_to_update = self.get_schema_from_definitions(JsonRef(json_schema['$ref'])) or json_schema + else: + schema_to_update = json_schema + + if title is not None: + # referenced_schema['title'] = title + schema_to_update.setdefault('title', title) + + if 'additionalProperties' not in schema_to_update: + if extra == 'allow': + schema_to_update['additionalProperties'] = True + elif extra == 'forbid': + schema_to_update['additionalProperties'] = False + + if isinstance(json_schema_extra, (staticmethod, classmethod)): + # In older versions of python, this is necessary to ensure staticmethod/classmethods are callable + json_schema_extra = json_schema_extra.__get__(cls) + + if isinstance(json_schema_extra, dict): + schema_to_update.update(json_schema_extra) + elif callable(json_schema_extra): + if len(inspect.signature(json_schema_extra).parameters) > 1: + json_schema_extra(schema_to_update, cls) # type: ignore + else: + json_schema_extra(schema_to_update) # type: ignore + elif json_schema_extra is not None: + raise ValueError( + f"model_config['json_schema_extra']={json_schema_extra} should be a dict, callable, or None" + ) + + return json_schema + + def resolve_schema_to_update(self, json_schema: JsonSchemaValue) -> JsonSchemaValue: + """Resolve a JsonSchemaValue to the non-ref schema if it is a $ref schema. + + Args: + json_schema: The schema to resolve. + + Returns: + The resolved schema. + """ + if '$ref' in json_schema: + schema_to_update = self.get_schema_from_definitions(JsonRef(json_schema['$ref'])) + if schema_to_update is None: + raise RuntimeError(f'Cannot update undefined schema for $ref={json_schema["$ref"]}') + return self.resolve_schema_to_update(schema_to_update) + else: + schema_to_update = json_schema + return schema_to_update + + def model_fields_schema(self, schema: core_schema.ModelFieldsSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that defines a model's fields. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. 
+ """ + named_required_fields: list[tuple[str, bool, CoreSchemaField]] = [ + (name, self.field_is_required(field, total=True), field) + for name, field in schema['fields'].items() + if self.field_is_present(field) + ] + if self.mode == 'serialization': + named_required_fields.extend(self._name_required_computed_fields(schema.get('computed_fields', []))) + json_schema = self._named_required_fields_schema(named_required_fields) + extras_schema = schema.get('extras_schema', None) + if extras_schema is not None: + schema_to_update = self.resolve_schema_to_update(json_schema) + schema_to_update['additionalProperties'] = self.generate_inner(extras_schema) + return json_schema + + def field_is_present(self, field: CoreSchemaField) -> bool: + """Whether the field should be included in the generated JSON schema. + + Args: + field: The schema for the field itself. + + Returns: + `True` if the field should be included in the generated JSON schema, `False` otherwise. + """ + if self.mode == 'serialization': + # If you still want to include the field in the generated JSON schema, + # override this method and return True + return not field.get('serialization_exclude') + elif self.mode == 'validation': + return True + else: + assert_never(self.mode) + + def field_is_required( + self, + field: core_schema.ModelField | core_schema.DataclassField | core_schema.TypedDictField, + total: bool, + ) -> bool: + """Whether the field should be marked as required in the generated JSON schema. + (Note that this is irrelevant if the field is not present in the JSON schema.). + + Args: + field: The schema for the field itself. + total: Only applies to `TypedDictField`s. + Indicates if the `TypedDict` this field belongs to is total, in which case any fields that don't + explicitly specify `required=False` are required. + + Returns: + `True` if the field should be marked as required in the generated JSON schema, `False` otherwise. + """ + if self.mode == 'serialization' and self._config.json_schema_serialization_defaults_required: + return not field.get('serialization_exclude') + else: + if field['type'] == 'typed-dict-field': + return field.get('required', total) + else: + return field['schema']['type'] != 'default' + + def dataclass_args_schema(self, schema: core_schema.DataclassArgsSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that defines a dataclass's constructor arguments. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + named_required_fields: list[tuple[str, bool, CoreSchemaField]] = [ + (field['name'], self.field_is_required(field, total=True), field) + for field in schema['fields'] + if self.field_is_present(field) + ] + if self.mode == 'serialization': + named_required_fields.extend(self._name_required_computed_fields(schema.get('computed_fields', []))) + return self._named_required_fields_schema(named_required_fields) + + def dataclass_schema(self, schema: core_schema.DataclassSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that defines a dataclass. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. 
+ """ + cls = schema['cls'] + config: ConfigDict = getattr(cls, '__pydantic_config__', cast('ConfigDict', {})) + title = config.get('title') or cls.__name__ + + with self._config_wrapper_stack.push(config): + json_schema = self.generate_inner(schema['schema']).copy() + + json_schema_extra = config.get('json_schema_extra') + json_schema = self._update_class_schema(json_schema, title, config.get('extra', None), cls, json_schema_extra) + + # Dataclass-specific handling of description + if is_dataclass(cls) and not hasattr(cls, '__pydantic_validator__'): + # vanilla dataclass; don't use cls.__doc__ as it will contain the class signature by default + description = None + else: + description = None if cls.__doc__ is None else inspect.cleandoc(cls.__doc__) + if description: + json_schema['description'] = description + + return json_schema + + def arguments_schema(self, schema: core_schema.ArgumentsSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that defines a function's arguments. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + metadata = _core_metadata.CoreMetadataHandler(schema).metadata + prefer_positional = metadata.get('pydantic_js_prefer_positional_arguments') + + arguments = schema['arguments_schema'] + kw_only_arguments = [a for a in arguments if a.get('mode') == 'keyword_only'] + kw_or_p_arguments = [a for a in arguments if a.get('mode') in {'positional_or_keyword', None}] + p_only_arguments = [a for a in arguments if a.get('mode') == 'positional_only'] + var_args_schema = schema.get('var_args_schema') + var_kwargs_schema = schema.get('var_kwargs_schema') + + if prefer_positional: + positional_possible = not kw_only_arguments and not var_kwargs_schema + if positional_possible: + return self.p_arguments_schema(p_only_arguments + kw_or_p_arguments, var_args_schema) + + keyword_possible = not p_only_arguments and not var_args_schema + if keyword_possible: + return self.kw_arguments_schema(kw_or_p_arguments + kw_only_arguments, var_kwargs_schema) + + if not prefer_positional: + positional_possible = not kw_only_arguments and not var_kwargs_schema + if positional_possible: + return self.p_arguments_schema(p_only_arguments + kw_or_p_arguments, var_args_schema) + + raise PydanticInvalidForJsonSchema( + 'Unable to generate JSON schema for arguments validator with positional-only and keyword-only arguments' + ) + + def kw_arguments_schema( + self, arguments: list[core_schema.ArgumentsParameter], var_kwargs_schema: CoreSchema | None + ) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that defines a function's keyword arguments. + + Args: + arguments: The core schema. + + Returns: + The generated JSON schema. + """ + properties: dict[str, JsonSchemaValue] = {} + required: list[str] = [] + for argument in arguments: + name = self.get_argument_name(argument) + argument_schema = self.generate_inner(argument['schema']).copy() + argument_schema['title'] = self.get_title_from_name(name) + properties[name] = argument_schema + + if argument['schema']['type'] != 'default': + # This assumes that if the argument has a default value, + # the inner schema must be of type WithDefaultSchema. 
+ # I believe this is true, but I am not 100% sure + required.append(name) + + json_schema: JsonSchemaValue = {'type': 'object', 'properties': properties} + if required: + json_schema['required'] = required + + if var_kwargs_schema: + additional_properties_schema = self.generate_inner(var_kwargs_schema) + if additional_properties_schema: + json_schema['additionalProperties'] = additional_properties_schema + else: + json_schema['additionalProperties'] = False + return json_schema + + def p_arguments_schema( + self, arguments: list[core_schema.ArgumentsParameter], var_args_schema: CoreSchema | None + ) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that defines a function's positional arguments. + + Args: + arguments: The core schema. + + Returns: + The generated JSON schema. + """ + prefix_items: list[JsonSchemaValue] = [] + min_items = 0 + + for argument in arguments: + name = self.get_argument_name(argument) + + argument_schema = self.generate_inner(argument['schema']).copy() + argument_schema['title'] = self.get_title_from_name(name) + prefix_items.append(argument_schema) + + if argument['schema']['type'] != 'default': + # This assumes that if the argument has a default value, + # the inner schema must be of type WithDefaultSchema. + # I believe this is true, but I am not 100% sure + min_items += 1 + + json_schema: JsonSchemaValue = {'type': 'array', 'prefixItems': prefix_items} + if min_items: + json_schema['minItems'] = min_items + + if var_args_schema: + items_schema = self.generate_inner(var_args_schema) + if items_schema: + json_schema['items'] = items_schema + else: + json_schema['maxItems'] = len(prefix_items) + + return json_schema + + def get_argument_name(self, argument: core_schema.ArgumentsParameter) -> str: + """Retrieves the name of an argument. + + Args: + argument: The core schema. + + Returns: + The name of the argument. + """ + name = argument['name'] + if self.by_alias: + alias = argument.get('alias') + if isinstance(alias, str): + name = alias + else: + pass # might want to do something else? + return name + + def call_schema(self, schema: core_schema.CallSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that defines a function call. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return self.generate_inner(schema['arguments_schema']) + + def custom_error_schema(self, schema: core_schema.CustomErrorSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that defines a custom error. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return self.generate_inner(schema['schema']) + + def json_schema(self, schema: core_schema.JsonSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that defines a JSON object. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + content_core_schema = schema.get('schema') or core_schema.any_schema() + content_json_schema = self.generate_inner(content_core_schema) + if self.mode == 'validation': + return {'type': 'string', 'contentMediaType': 'application/json', 'contentSchema': content_json_schema} + else: + # self.mode == 'serialization' + return content_json_schema + + def url_schema(self, schema: core_schema.UrlSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that defines a URL. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. 
+ """ + json_schema = {'type': 'string', 'format': 'uri', 'minLength': 1} + self.update_with_validations(json_schema, schema, self.ValidationsMapping.string) + return json_schema + + def multi_host_url_schema(self, schema: core_schema.MultiHostUrlSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that defines a URL that can be used with multiple hosts. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + # Note: 'multi-host-uri' is a custom/pydantic-specific format, not part of the JSON Schema spec + json_schema = {'type': 'string', 'format': 'multi-host-uri', 'minLength': 1} + self.update_with_validations(json_schema, schema, self.ValidationsMapping.string) + return json_schema + + def uuid_schema(self, schema: core_schema.UuidSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a UUID. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + return {'type': 'string', 'format': 'uuid'} + + def definitions_schema(self, schema: core_schema.DefinitionsSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that defines a JSON object with definitions. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + for definition in schema['definitions']: + try: + self.generate_inner(definition) + except PydanticInvalidForJsonSchema as e: + core_ref: CoreRef = CoreRef(definition['ref']) # type: ignore + self._core_defs_invalid_for_json_schema[self.get_defs_ref((core_ref, self.mode))] = e + continue + return self.generate_inner(schema['schema']) + + def definition_ref_schema(self, schema: core_schema.DefinitionReferenceSchema) -> JsonSchemaValue: + """Generates a JSON schema that matches a schema that references a definition. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + core_ref = CoreRef(schema['schema_ref']) + _, ref_json_schema = self.get_cache_defs_ref_schema(core_ref) + return ref_json_schema + + def ser_schema( + self, schema: core_schema.SerSchema | core_schema.IncExSeqSerSchema | core_schema.IncExDictSerSchema + ) -> JsonSchemaValue | None: + """Generates a JSON schema that matches a schema that defines a serialized object. + + Args: + schema: The core schema. + + Returns: + The generated JSON schema. + """ + schema_type = schema['type'] + if schema_type == 'function-plain' or schema_type == 'function-wrap': + # PlainSerializerFunctionSerSchema or WrapSerializerFunctionSerSchema + return_schema = schema.get('return_schema') + if return_schema is not None: + return self.generate_inner(return_schema) + elif schema_type == 'format' or schema_type == 'to-string': + # FormatSerSchema or ToStringSerSchema + return self.str_schema(core_schema.str_schema()) + elif schema['type'] == 'model': + # ModelSerSchema + return self.generate_inner(schema['schema']) + return None + + # ### Utility methods + + def get_title_from_name(self, name: str) -> str: + """Retrieves a title from a name. + + Args: + name: The name to retrieve a title from. + + Returns: + The title. + """ + return name.title().replace('_', ' ') + + def field_title_should_be_set(self, schema: CoreSchemaOrField) -> bool: + """Returns true if a field with the given schema should have a title set based on the field name. + + Intuitively, we want this to return true for schemas that wouldn't otherwise provide their own title + (e.g., int, float, str), and false for those that would (e.g., BaseModel subclasses). 
+ + Args: + schema: The schema to check. + + Returns: + `True` if the field should have a title set, `False` otherwise. + """ + if _core_utils.is_core_schema_field(schema): + if schema['type'] == 'computed-field': + field_schema = schema['return_schema'] + else: + field_schema = schema['schema'] + return self.field_title_should_be_set(field_schema) + + elif _core_utils.is_core_schema(schema): + if schema.get('ref'): # things with refs, such as models and enums, should not have titles set + return False + if schema['type'] in {'default', 'nullable', 'definitions'}: + return self.field_title_should_be_set(schema['schema']) # type: ignore[typeddict-item] + if _core_utils.is_function_with_inner_schema(schema): + return self.field_title_should_be_set(schema['schema']) + if schema['type'] == 'definition-ref': + # Referenced schemas should not have titles set for the same reason + # schemas with refs should not + return False + return True # anything else should have title set + + else: + raise PydanticInvalidForJsonSchema(f'Unexpected schema type: schema={schema}') # pragma: no cover + + def normalize_name(self, name: str) -> str: + """Normalizes a name to be used as a key in a dictionary. + + Args: + name: The name to normalize. + + Returns: + The normalized name. + """ + return re.sub(r'[^a-zA-Z0-9.\-_]', '_', name).replace('.', '__') + + def get_defs_ref(self, core_mode_ref: CoreModeRef) -> DefsRef: + """Override this method to change the way that definitions keys are generated from a core reference. + + Args: + core_mode_ref: The core reference. + + Returns: + The definitions key. + """ + # Split the core ref into "components"; generic origins and arguments are each separate components + core_ref, mode = core_mode_ref + components = re.split(r'([\][,])', core_ref) + # Remove IDs from each component + components = [x.rsplit(':', 1)[0] for x in components] + core_ref_no_id = ''.join(components) + # Remove everything before the last period from each "component" + components = [re.sub(r'(?:[^.[\]]+\.)+((?:[^.[\]]+))', r'\1', x) for x in components] + short_ref = ''.join(components) + + mode_title = _MODE_TITLE_MAPPING[mode] + + # It is important that the generated defs_ref values be such that at least one choice will not + # be generated for any other core_ref. 
Currently, this should be the case because we include + # the id of the source type in the core_ref + name = DefsRef(self.normalize_name(short_ref)) + name_mode = DefsRef(self.normalize_name(short_ref) + f'-{mode_title}') + module_qualname = DefsRef(self.normalize_name(core_ref_no_id)) + module_qualname_mode = DefsRef(f'{module_qualname}-{mode_title}') + module_qualname_id = DefsRef(self.normalize_name(core_ref)) + occurrence_index = self._collision_index.get(module_qualname_id) + if occurrence_index is None: + self._collision_counter[module_qualname] += 1 + occurrence_index = self._collision_index[module_qualname_id] = self._collision_counter[module_qualname] + + module_qualname_occurrence = DefsRef(f'{module_qualname}__{occurrence_index}') + module_qualname_occurrence_mode = DefsRef(f'{module_qualname_mode}__{occurrence_index}') + + self._prioritized_defsref_choices[module_qualname_occurrence_mode] = [ + name, + name_mode, + module_qualname, + module_qualname_mode, + module_qualname_occurrence, + module_qualname_occurrence_mode, + ] + + return module_qualname_occurrence_mode + + def get_cache_defs_ref_schema(self, core_ref: CoreRef) -> tuple[DefsRef, JsonSchemaValue]: + """This method wraps the get_defs_ref method with some cache-lookup/population logic, + and returns both the produced defs_ref and the JSON schema that will refer to the right definition. + + Args: + core_ref: The core reference to get the definitions reference for. + + Returns: + A tuple of the definitions reference and the JSON schema that will refer to it. + """ + core_mode_ref = (core_ref, self.mode) + maybe_defs_ref = self.core_to_defs_refs.get(core_mode_ref) + if maybe_defs_ref is not None: + json_ref = self.core_to_json_refs[core_mode_ref] + return maybe_defs_ref, {'$ref': json_ref} + + defs_ref = self.get_defs_ref(core_mode_ref) + + # populate the ref translation mappings + self.core_to_defs_refs[core_mode_ref] = defs_ref + self.defs_to_core_refs[defs_ref] = core_mode_ref + + json_ref = JsonRef(self.ref_template.format(model=defs_ref)) + self.core_to_json_refs[core_mode_ref] = json_ref + self.json_to_defs_refs[json_ref] = defs_ref + ref_json_schema = {'$ref': json_ref} + return defs_ref, ref_json_schema + + def handle_ref_overrides(self, json_schema: JsonSchemaValue) -> JsonSchemaValue: + """It is not valid for a schema with a top-level $ref to have sibling keys. + + During our own schema generation, we treat sibling keys as overrides to the referenced schema, + but this is not how the official JSON schema spec works. + + Because of this, we first remove any sibling keys that are redundant with the referenced schema, then if + any remain, we transform the schema from a top-level '$ref' to use allOf to move the $ref out of the top level. + (See bottom of https://swagger.io/docs/specification/using-ref/ for a reference about this behavior) + """ + if '$ref' in json_schema: + # prevent modifications to the input; this copy may be safe to drop if there is significant overhead + json_schema = json_schema.copy() + + referenced_json_schema = self.get_schema_from_definitions(JsonRef(json_schema['$ref'])) + if referenced_json_schema is None: + # This can happen when building schemas for models with not-yet-defined references. + # It may be a good idea to do a recursive pass at the end of the generation to remove + # any redundant override keys. 
+ if len(json_schema) > 1: + # Make it an allOf to at least resolve the sibling keys issue + json_schema = json_schema.copy() + json_schema.setdefault('allOf', []) + json_schema['allOf'].append({'$ref': json_schema['$ref']}) + del json_schema['$ref'] + + return json_schema + for k, v in list(json_schema.items()): + if k == '$ref': + continue + if k in referenced_json_schema and referenced_json_schema[k] == v: + del json_schema[k] # redundant key + if len(json_schema) > 1: + # There is a remaining "override" key, so we need to move $ref out of the top level + json_ref = JsonRef(json_schema['$ref']) + del json_schema['$ref'] + assert 'allOf' not in json_schema # this should never happen, but just in case + json_schema['allOf'] = [{'$ref': json_ref}] + + return json_schema + + def get_schema_from_definitions(self, json_ref: JsonRef) -> JsonSchemaValue | None: + def_ref = self.json_to_defs_refs[json_ref] + if def_ref in self._core_defs_invalid_for_json_schema: + raise self._core_defs_invalid_for_json_schema[def_ref] + return self.definitions.get(def_ref, None) + + def encode_default(self, dft: Any) -> Any: + """Encode a default value to a JSON-serializable value. + + This is used to encode default values for fields in the generated JSON schema. + + Args: + dft: The default value to encode. + + Returns: + The encoded default value. + """ + config = self._config + return pydantic_core.to_jsonable_python( + dft, + timedelta_mode=config.ser_json_timedelta, + bytes_mode=config.ser_json_bytes, + ) + + def update_with_validations( + self, json_schema: JsonSchemaValue, core_schema: CoreSchema, mapping: dict[str, str] + ) -> None: + """Update the json_schema with the corresponding validations specified in the core_schema, + using the provided mapping to translate keys in core_schema to the appropriate keys for a JSON schema. + + Args: + json_schema: The JSON schema to update. + core_schema: The core schema to get the validations from. + mapping: A mapping from core_schema attribute names to the corresponding JSON schema attribute names. + """ + for core_key, json_schema_key in mapping.items(): + if core_key in core_schema: + json_schema[json_schema_key] = core_schema[core_key] + + class ValidationsMapping: + """This class just contains mappings from core_schema attribute names to the corresponding + JSON schema attribute names. While I suspect it is unlikely to be necessary, you can in + principle override this class in a subclass of GenerateJsonSchema (by inheriting from + GenerateJsonSchema.ValidationsMapping) to change these mappings. 
+ """ + + numeric = { + 'multiple_of': 'multipleOf', + 'le': 'maximum', + 'ge': 'minimum', + 'lt': 'exclusiveMaximum', + 'gt': 'exclusiveMinimum', + } + bytes = { + 'min_length': 'minLength', + 'max_length': 'maxLength', + } + string = { + 'min_length': 'minLength', + 'max_length': 'maxLength', + 'pattern': 'pattern', + } + array = { + 'min_length': 'minItems', + 'max_length': 'maxItems', + } + object = { + 'min_length': 'minProperties', + 'max_length': 'maxProperties', + } + date = { + 'le': 'maximum', + 'ge': 'minimum', + 'lt': 'exclusiveMaximum', + 'gt': 'exclusiveMinimum', + } + + def get_flattened_anyof(self, schemas: list[JsonSchemaValue]) -> JsonSchemaValue: + members = [] + for schema in schemas: + if len(schema) == 1 and 'anyOf' in schema: + members.extend(schema['anyOf']) + else: + members.append(schema) + members = _deduplicate_schemas(members) + if len(members) == 1: + return members[0] + return {'anyOf': members} + + def get_json_ref_counts(self, json_schema: JsonSchemaValue) -> dict[JsonRef, int]: + """Get all values corresponding to the key '$ref' anywhere in the json_schema.""" + json_refs: dict[JsonRef, int] = Counter() + + def _add_json_refs(schema: Any) -> None: + if isinstance(schema, dict): + if '$ref' in schema: + json_ref = JsonRef(schema['$ref']) + if not isinstance(json_ref, str): + return # in this case, '$ref' might have been the name of a property + already_visited = json_ref in json_refs + json_refs[json_ref] += 1 + if already_visited: + return # prevent recursion on a definition that was already visited + defs_ref = self.json_to_defs_refs[json_ref] + if defs_ref in self._core_defs_invalid_for_json_schema: + raise self._core_defs_invalid_for_json_schema[defs_ref] + _add_json_refs(self.definitions[defs_ref]) + + for v in schema.values(): + _add_json_refs(v) + elif isinstance(schema, list): + for v in schema: + _add_json_refs(v) + + _add_json_refs(json_schema) + return json_refs + + def handle_invalid_for_json_schema(self, schema: CoreSchemaOrField, error_info: str) -> JsonSchemaValue: + raise PydanticInvalidForJsonSchema(f'Cannot generate a JsonSchema for {error_info}') + + def emit_warning(self, kind: JsonSchemaWarningKind, detail: str) -> None: + """This method simply emits PydanticJsonSchemaWarnings based on handling in the `warning_message` method.""" + message = self.render_warning_message(kind, detail) + if message is not None: + warnings.warn(message, PydanticJsonSchemaWarning) + + def render_warning_message(self, kind: JsonSchemaWarningKind, detail: str) -> str | None: + """This method is responsible for ignoring warnings as desired, and for formatting the warning messages. + + You can override the value of `ignored_warning_kinds` in a subclass of GenerateJsonSchema + to modify what warnings are generated. If you want more control, you can override this method; + just return None in situations where you don't want warnings to be emitted. + + Args: + kind: The kind of warning to render. It can be one of the following: + + - 'skipped-choice': A choice field was skipped because it had no valid choices. + - 'non-serializable-default': A default value was skipped because it was not JSON-serializable. + detail: A string with additional details about the warning. + + Returns: + The formatted warning message, or `None` if no warning should be emitted. 
+ """ + if kind in self.ignored_warning_kinds: + return None + return f'{detail} [{kind}]' + + def _build_definitions_remapping(self) -> _DefinitionsRemapping: + defs_to_json: dict[DefsRef, JsonRef] = {} + for defs_refs in self._prioritized_defsref_choices.values(): + for defs_ref in defs_refs: + json_ref = JsonRef(self.ref_template.format(model=defs_ref)) + defs_to_json[defs_ref] = json_ref + + return _DefinitionsRemapping.from_prioritized_choices( + self._prioritized_defsref_choices, defs_to_json, self.definitions + ) + + def _garbage_collect_definitions(self, schema: JsonSchemaValue) -> None: + visited_defs_refs: set[DefsRef] = set() + unvisited_json_refs = _get_all_json_refs(schema) + while unvisited_json_refs: + next_json_ref = unvisited_json_refs.pop() + next_defs_ref = self.json_to_defs_refs[next_json_ref] + if next_defs_ref in visited_defs_refs: + continue + visited_defs_refs.add(next_defs_ref) + unvisited_json_refs.update(_get_all_json_refs(self.definitions[next_defs_ref])) + + self.definitions = {k: v for k, v in self.definitions.items() if k in visited_defs_refs} + + +# ##### Start JSON Schema Generation Functions ##### + + +def model_json_schema( + cls: type[BaseModel] | type[PydanticDataclass], + by_alias: bool = True, + ref_template: str = DEFAULT_REF_TEMPLATE, + schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema, + mode: JsonSchemaMode = 'validation', +) -> dict[str, Any]: + """Utility function to generate a JSON Schema for a model. + + Args: + cls: The model class to generate a JSON Schema for. + by_alias: If `True` (the default), fields will be serialized according to their alias. + If `False`, fields will be serialized according to their attribute name. + ref_template: The template to use for generating JSON Schema references. + schema_generator: The class to use for generating the JSON Schema. + mode: The mode to use for generating the JSON Schema. It can be one of the following: + + - 'validation': Generate a JSON Schema for validating data. + - 'serialization': Generate a JSON Schema for serializing data. + + Returns: + The generated JSON Schema. + """ + schema_generator_instance = schema_generator(by_alias=by_alias, ref_template=ref_template) + if isinstance(cls.__pydantic_validator__, _mock_val_ser.MockValSer): + cls.__pydantic_validator__.rebuild() + assert '__pydantic_core_schema__' in cls.__dict__, 'this is a bug! please report it' + return schema_generator_instance.generate(cls.__pydantic_core_schema__, mode=mode) + + +def models_json_schema( + models: Sequence[tuple[type[BaseModel] | type[PydanticDataclass], JsonSchemaMode]], + *, + by_alias: bool = True, + title: str | None = None, + description: str | None = None, + ref_template: str = DEFAULT_REF_TEMPLATE, + schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema, +) -> tuple[dict[tuple[type[BaseModel] | type[PydanticDataclass], JsonSchemaMode], JsonSchemaValue], JsonSchemaValue]: + """Utility function to generate a JSON Schema for multiple models. + + Args: + models: A sequence of tuples of the form (model, mode). + by_alias: Whether field aliases should be used as keys in the generated JSON Schema. + title: The title of the generated JSON Schema. + description: The description of the generated JSON Schema. + ref_template: The reference template to use for generating JSON Schema references. + schema_generator: The schema generator to use for generating the JSON Schema. 
+ + Returns: + A tuple where: + - The first element is a dictionary whose keys are tuples of JSON schema key type and JSON mode, and + whose values are the JSON schema corresponding to that pair of inputs. (These schemas may have + JsonRef references to definitions that are defined in the second returned element.) + - The second element is a JSON schema containing all definitions referenced in the first returned + element, along with the optional title and description keys. + """ + for cls, _ in models: + if isinstance(cls.__pydantic_validator__, _mock_val_ser.MockValSer): + cls.__pydantic_validator__.rebuild() + + instance = schema_generator(by_alias=by_alias, ref_template=ref_template) + inputs = [(m, mode, m.__pydantic_core_schema__) for m, mode in models] + json_schemas_map, definitions = instance.generate_definitions(inputs) + + json_schema: dict[str, Any] = {} + if definitions: + json_schema['$defs'] = definitions + if title: + json_schema['title'] = title + if description: + json_schema['description'] = description + + return json_schemas_map, json_schema + + +# ##### End JSON Schema Generation Functions ##### + + +_HashableJsonValue: TypeAlias = Union[ + int, float, str, bool, None, Tuple['_HashableJsonValue', ...], Tuple[Tuple[str, '_HashableJsonValue'], ...] +] + + +def _deduplicate_schemas(schemas: Iterable[JsonDict]) -> list[JsonDict]: + return list({_make_json_hashable(schema): schema for schema in schemas}.values()) + + +def _make_json_hashable(value: JsonValue) -> _HashableJsonValue: + if isinstance(value, dict): + return tuple(sorted((k, _make_json_hashable(v)) for k, v in value.items())) + elif isinstance(value, list): + return tuple(_make_json_hashable(v) for v in value) + else: + return value + + +def _sort_json_schema(value: JsonSchemaValue, parent_key: str | None = None) -> JsonSchemaValue: + if isinstance(value, dict): + sorted_dict: dict[str, JsonSchemaValue] = {} + keys = value.keys() + if (parent_key != 'properties') and (parent_key != 'default'): + keys = sorted(keys) + for key in keys: + sorted_dict[key] = _sort_json_schema(value[key], parent_key=key) + return sorted_dict + elif isinstance(value, list): + sorted_list: list[JsonSchemaValue] = [] + for item in value: # type: ignore + sorted_list.append(_sort_json_schema(item, parent_key)) + return sorted_list # type: ignore + else: + return value + + +@dataclasses.dataclass(**_internal_dataclass.slots_true) +class WithJsonSchema: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/json_schema/#withjsonschema-annotation + + Add this as an annotation on a field to override the (base) JSON schema that would be generated for that field. + This provides a way to set a JSON schema for types that would otherwise raise errors when producing a JSON schema, + such as Callable, or types that have an is-instance core schema, without needing to go so far as creating a + custom subclass of pydantic.json_schema.GenerateJsonSchema. + Note that any _modifications_ to the schema that would normally be made (such as setting the title for model fields) + will still be performed. + + If `mode` is set this will only apply to that schema generation mode, allowing you + to set different json schemas for validation and serialization. 
+ """ + + json_schema: JsonSchemaValue | None + mode: Literal['validation', 'serialization'] | None = None + + def __get_pydantic_json_schema__( + self, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + mode = self.mode or handler.mode + if mode != handler.mode: + return handler(core_schema) + if self.json_schema is None: + # This exception is handled in pydantic.json_schema.GenerateJsonSchema._named_required_fields_schema + raise PydanticOmit + else: + return self.json_schema + + def __hash__(self) -> int: + return hash(type(self.mode)) + + +@dataclasses.dataclass(**_internal_dataclass.slots_true) +class Examples: + """Add examples to a JSON schema. + + Examples should be a map of example names (strings) + to example values (any valid JSON). + + If `mode` is set this will only apply to that schema generation mode, + allowing you to add different examples for validation and serialization. + """ + + examples: dict[str, Any] + mode: Literal['validation', 'serialization'] | None = None + + def __get_pydantic_json_schema__( + self, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + mode = self.mode or handler.mode + json_schema = handler(core_schema) + if mode != handler.mode: + return json_schema + examples = json_schema.get('examples', {}) + examples.update(to_jsonable_python(self.examples)) + json_schema['examples'] = examples + return json_schema + + def __hash__(self) -> int: + return hash(type(self.mode)) + + +def _get_all_json_refs(item: Any) -> set[JsonRef]: + """Get all the definitions references from a JSON schema.""" + refs: set[JsonRef] = set() + if isinstance(item, dict): + for key, value in item.items(): + if key == '$ref' and isinstance(value, str): + # the isinstance check ensures that '$ref' isn't the name of a property, etc. + refs.add(JsonRef(value)) + elif isinstance(value, dict): + refs.update(_get_all_json_refs(value)) + elif isinstance(value, list): + for item in value: + refs.update(_get_all_json_refs(item)) + elif isinstance(item, list): + for item in item: + refs.update(_get_all_json_refs(item)) + return refs + + +AnyType = TypeVar('AnyType') + +if TYPE_CHECKING: + SkipJsonSchema = Annotated[AnyType, ...] +else: + + @dataclasses.dataclass(**_internal_dataclass.slots_true) + class SkipJsonSchema: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/json_schema/#skipjsonschema-annotation + + Add this as an annotation on a field to skip generating a JSON schema for that field. + + Example: + ```py + from typing import Union + + from pydantic import BaseModel + from pydantic.json_schema import SkipJsonSchema + + from pprint import pprint + + + class Model(BaseModel): + a: Union[int, None] = None # (1)! + b: Union[int, SkipJsonSchema[None]] = None # (2)! + c: SkipJsonSchema[Union[int, None]] = None # (3)! + + + pprint(Model.model_json_schema()) + ''' + { + 'properties': { + 'a': { + 'anyOf': [ + {'type': 'integer'}, + {'type': 'null'} + ], + 'default': None, + 'title': 'A' + }, + 'b': { + 'default': None, + 'title': 'B', + 'type': 'integer' + } + }, + 'title': 'Model', + 'type': 'object' + } + ''' + ``` + + 1. The integer and null types are both included in the schema for `a`. + 2. The integer type is the only type included in the schema for `b`. + 3. The entirety of the `c` field is omitted from the schema. 
+ """ + + def __class_getitem__(cls, item: AnyType) -> AnyType: + return Annotated[item, cls()] + + def __get_pydantic_json_schema__( + self, core_schema: CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + raise PydanticOmit + + def __hash__(self) -> int: + return hash(type(self)) + + +def _get_typed_dict_config(schema: core_schema.TypedDictSchema) -> ConfigDict: + metadata = _core_metadata.CoreMetadataHandler(schema).metadata + cls = metadata.get('pydantic_typed_dict_cls') + if cls is not None: + try: + return _decorators.get_attribute_from_bases(cls, '__pydantic_config__') + except AttributeError: + pass + return {} diff --git a/lib/pydantic/main.py b/lib/pydantic/main.py index 69f3b751..8c7ebbbf 100644 --- a/lib/pydantic/main.py +++ b/lib/pydantic/main.py @@ -1,980 +1,1434 @@ +"""Logic for creating models.""" +from __future__ import annotations as _annotations + +import operator +import sys +import types +import typing import warnings -from abc import ABCMeta -from copy import deepcopy -from enum import Enum -from functools import partial -from pathlib import Path -from types import FunctionType, prepare_class, resolve_bases -from typing import ( - TYPE_CHECKING, - AbstractSet, - Any, - Callable, - ClassVar, - Dict, - List, - Mapping, - Optional, - Tuple, - Type, - TypeVar, - Union, - cast, - no_type_check, - overload, -) +from copy import copy, deepcopy +from typing import Any, ClassVar -from typing_extensions import dataclass_transform +import pydantic_core +import typing_extensions +from pydantic_core import PydanticUndefined -from .class_validators import ValidatorGroup, extract_root_validators, extract_validators, inherit_validators -from .config import BaseConfig, Extra, inherit_config, prepare_config -from .error_wrappers import ErrorWrapper, ValidationError -from .errors import ConfigError, DictError, ExtraError, MissingError -from .fields import ( - MAPPING_LIKE_SHAPES, - Field, - FieldInfo, - ModelField, - ModelPrivateAttr, - PrivateAttr, - Undefined, - is_finalvar_with_default_val, -) -from .json import custom_pydantic_encoder, pydantic_encoder -from .parse import Protocol, load_file, load_str_bytes -from .schema import default_ref_template, model_schema -from .types import PyObject, StrBytes -from .typing import ( - AnyCallable, - get_args, - get_origin, - is_classvar, - is_namedtuple, - is_union, - resolve_annotations, - update_model_forward_refs, -) -from .utils import ( - DUNDER_ATTRIBUTES, - ROOT_KEY, - ClassAttribute, - GetterDict, - Representation, - ValueItems, - generate_model_signature, - is_valid_field, - is_valid_private_name, - lenient_issubclass, - sequence_like, - smart_deepcopy, - unique_list, - validate_field_name, +from ._internal import ( + _config, + _decorators, + _fields, + _forward_ref, + _generics, + _mock_val_ser, + _model_construction, + _repr, + _typing_extra, + _utils, ) +from ._migration import getattr_migration +from .annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler +from .config import ConfigDict +from .errors import PydanticUndefinedAnnotation, PydanticUserError +from .json_schema import DEFAULT_REF_TEMPLATE, GenerateJsonSchema, JsonSchemaMode, JsonSchemaValue, model_json_schema +from .warnings import PydanticDeprecatedSince20 -if TYPE_CHECKING: +if typing.TYPE_CHECKING: from inspect import Signature + from pathlib import Path - from .class_validators import ValidatorListDict - from .types import ModelOrDc - from .typing import ( - AbstractSetIntStr, - AnyClassMethod, - CallableGenerator, - DictAny, - 
DictStrAny, - MappingIntStrAny, - ReprArgs, - SetStr, - TupleGenerator, - ) + from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator + from typing_extensions import Literal, Unpack - Model = TypeVar('Model', bound='BaseModel') + from ._internal._utils import AbstractSetIntStr, MappingIntStrAny + from .deprecated.parse import Protocol as DeprecatedParseProtocol + from .fields import ComputedFieldInfo, FieldInfo, ModelPrivateAttr + from .fields import Field as _Field -__all__ = 'BaseModel', 'create_model', 'validate_model' + TupleGenerator = typing.Generator[typing.Tuple[str, Any], None, None] + Model = typing.TypeVar('Model', bound='BaseModel') + # should be `set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None`, but mypy can't cope + IncEx: typing_extensions.TypeAlias = 'set[int] | set[str] | dict[int, Any] | dict[str, Any] | None' +else: + # See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915 + # and https://youtrack.jetbrains.com/issue/PY-51428 + DeprecationWarning = PydanticDeprecatedSince20 -_T = TypeVar('_T') +__all__ = 'BaseModel', 'create_model' + +_object_setattr = _model_construction.object_setattr -def validate_custom_root_type(fields: Dict[str, ModelField]) -> None: - if len(fields) > 1: - raise ValueError(f'{ROOT_KEY} cannot be mixed with other fields') +class BaseModel(metaclass=_model_construction.ModelMetaclass): + """Usage docs: https://docs.pydantic.dev/2.6/concepts/models/ + A base class for creating Pydantic models. -def generate_hash_function(frozen: bool) -> Optional[Callable[[Any], int]]: - def hash_function(self_: Any) -> int: - return hash(self_.__class__) + hash(tuple(self_.__dict__.values())) + Attributes: + __class_vars__: The names of classvars defined on the model. + __private_attributes__: Metadata about the private attributes of the model. + __signature__: The signature for instantiating the model. - return hash_function if frozen else None + __pydantic_complete__: Whether model building is completed, or if there are still undefined fields. + __pydantic_core_schema__: The pydantic-core schema used to build the SchemaValidator and SchemaSerializer. + __pydantic_custom_init__: Whether the model has a custom `__init__` function. + __pydantic_decorators__: Metadata containing the decorators defined on the model. + This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1. + __pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to + __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these. + __pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models. + __pydantic_post_init__: The name of the post-init method for the model, if defined. + __pydantic_root_model__: Whether the model is a `RootModel`. + __pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the model. + __pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the model. + __pydantic_extra__: An instance attribute with the values of extra fields from validation when + `model_config['extra'] == 'allow'`. + __pydantic_fields_set__: An instance attribute with the names of fields explicitly set. + __pydantic_private__: Instance attribute with the values of private attributes set on the model instance. + """ -# If a field is of type `Callable`, its default value should be a function and cannot to ignored. 
-ANNOTATED_FIELD_UNTOUCHED_TYPES: Tuple[Any, ...] = (property, type, classmethod, staticmethod) -# When creating a `BaseModel` instance, we bypass all the methods, properties... added to the model -UNTOUCHED_TYPES: Tuple[Any, ...] = (FunctionType,) + ANNOTATED_FIELD_UNTOUCHED_TYPES -# Note `ModelMetaclass` refers to `BaseModel`, but is also used to *create* `BaseModel`, so we need to add this extra -# (somewhat hacky) boolean to keep track of whether we've created the `BaseModel` class yet, and therefore whether it's -# safe to refer to it. If it *hasn't* been created, we assume that the `__new__` call we're in the middle of is for -# the `BaseModel` class, since that's defined immediately after the metaclass. -_is_base_model_class_defined = False + if typing.TYPE_CHECKING: + # Here we provide annotations for the attributes of BaseModel. + # Many of these are populated by the metaclass, which is why this section is in a `TYPE_CHECKING` block. + # However, for the sake of easy review, we have included type annotations of all class and instance attributes + # of `BaseModel` here: + # Class attributes + model_config: ClassVar[ConfigDict] + """ + Configuration for the model, should be a dictionary conforming to [`ConfigDict`][pydantic.config.ConfigDict]. + """ -@dataclass_transform(kw_only_default=True, field_descriptors=(Field, FieldInfo)) -class ModelMetaclass(ABCMeta): - @no_type_check # noqa C901 - def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901 - fields: Dict[str, ModelField] = {} - config = BaseConfig - validators: 'ValidatorListDict' = {} + model_fields: ClassVar[dict[str, FieldInfo]] + """ + Metadata about the fields defined on the model, + mapping of field names to [`FieldInfo`][pydantic.fields.FieldInfo]. - pre_root_validators, post_root_validators = [], [] - private_attributes: Dict[str, ModelPrivateAttr] = {} - base_private_attributes: Dict[str, ModelPrivateAttr] = {} - slots: SetStr = namespace.get('__slots__', ()) - slots = {slots} if isinstance(slots, str) else set(slots) - class_vars: SetStr = set() - hash_func: Optional[Callable[[Any], int]] = None + This replaces `Model.__fields__` from Pydantic V1. 
+ """ - for base in reversed(bases): - if _is_base_model_class_defined and issubclass(base, BaseModel) and base != BaseModel: - fields.update(smart_deepcopy(base.__fields__)) - config = inherit_config(base.__config__, config) - validators = inherit_validators(base.__validators__, validators) - pre_root_validators += base.__pre_root_validators__ - post_root_validators += base.__post_root_validators__ - base_private_attributes.update(base.__private_attributes__) - class_vars.update(base.__class_vars__) - hash_func = base.__hash__ + model_computed_fields: ClassVar[dict[str, ComputedFieldInfo]] + """A dictionary of computed field names and their corresponding `ComputedFieldInfo` objects.""" - resolve_forward_refs = kwargs.pop('__resolve_forward_refs__', True) - allowed_config_kwargs: SetStr = { - key - for key in dir(config) - if not (key.startswith('__') and key.endswith('__')) # skip dunder methods and attributes - } - config_kwargs = {key: kwargs.pop(key) for key in kwargs.keys() & allowed_config_kwargs} - config_from_namespace = namespace.get('Config') - if config_kwargs and config_from_namespace: - raise TypeError('Specifying config in two places is ambiguous, use either Config attribute or class kwargs') - config = inherit_config(config_from_namespace, config, **config_kwargs) + __class_vars__: ClassVar[set[str]] + __private_attributes__: ClassVar[dict[str, ModelPrivateAttr]] + __signature__: ClassVar[Signature] - validators = inherit_validators(extract_validators(namespace), validators) - vg = ValidatorGroup(validators) + __pydantic_complete__: ClassVar[bool] + __pydantic_core_schema__: ClassVar[CoreSchema] + __pydantic_custom_init__: ClassVar[bool] + __pydantic_decorators__: ClassVar[_decorators.DecoratorInfos] + __pydantic_generic_metadata__: ClassVar[_generics.PydanticGenericMetadata] + __pydantic_parent_namespace__: ClassVar[dict[str, Any] | None] + __pydantic_post_init__: ClassVar[None | Literal['model_post_init']] + __pydantic_root_model__: ClassVar[bool] + __pydantic_serializer__: ClassVar[SchemaSerializer] + __pydantic_validator__: ClassVar[SchemaValidator] - for f in fields.values(): - f.set_config(config) - extra_validators = vg.get_validators(f.name) - if extra_validators: - f.class_validators.update(extra_validators) - # re-run prepare to add extra validators - f.populate_validators() + # Instance attributes + # Note: we use the non-existent kwarg `init=False` in pydantic.fields.Field below so that @dataclass_transform + # doesn't think these are valid as keyword arguments to the class initializer. 
+ __pydantic_extra__: dict[str, Any] | None = _Field(init=False) # type: ignore + __pydantic_fields_set__: set[str] = _Field(init=False) # type: ignore + __pydantic_private__: dict[str, Any] | None = _Field(init=False) # type: ignore - prepare_config(config, name) + else: + # `model_fields` and `__pydantic_decorators__` must be set for + # pydantic._internal._generate_schema.GenerateSchema.model_schema to work for a plain BaseModel annotation + model_fields = {} + model_computed_fields = {} - untouched_types = ANNOTATED_FIELD_UNTOUCHED_TYPES + __pydantic_decorators__ = _decorators.DecoratorInfos() + __pydantic_parent_namespace__ = None + # Prevent `BaseModel` from being instantiated directly: + __pydantic_validator__ = _mock_val_ser.MockValSer( + 'Pydantic models should inherit from BaseModel, BaseModel cannot be instantiated directly', + val_or_ser='validator', + code='base-model-instantiated', + ) + __pydantic_serializer__ = _mock_val_ser.MockValSer( + 'Pydantic models should inherit from BaseModel, BaseModel cannot be instantiated directly', + val_or_ser='serializer', + code='base-model-instantiated', + ) - def is_untouched(v: Any) -> bool: - return isinstance(v, untouched_types) or v.__class__.__name__ == 'cython_function_or_method' + __slots__ = '__dict__', '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__' - if (namespace.get('__module__'), namespace.get('__qualname__')) != ('pydantic.main', 'BaseModel'): - annotations = resolve_annotations(namespace.get('__annotations__', {}), namespace.get('__module__', None)) - # annotation only fields need to come first in fields - for ann_name, ann_type in annotations.items(): - if is_classvar(ann_type): - class_vars.add(ann_name) - elif is_finalvar_with_default_val(ann_type, namespace.get(ann_name, Undefined)): - class_vars.add(ann_name) - elif is_valid_field(ann_name): - validate_field_name(bases, ann_name) - value = namespace.get(ann_name, Undefined) - allowed_types = get_args(ann_type) if is_union(get_origin(ann_type)) else (ann_type,) - if ( - is_untouched(value) - and ann_type != PyObject - and not any( - lenient_issubclass(get_origin(allowed_type), Type) for allowed_type in allowed_types - ) - ): - continue - fields[ann_name] = ModelField.infer( - name=ann_name, - value=value, - annotation=ann_type, - class_validators=vg.get_validators(ann_name), - config=config, - ) - elif ann_name not in namespace and config.underscore_attrs_are_private: - private_attributes[ann_name] = PrivateAttr() + model_config = ConfigDict() + __pydantic_complete__ = False + __pydantic_root_model__ = False - untouched_types = UNTOUCHED_TYPES + config.keep_untouched - for var_name, value in namespace.items(): - can_be_changed = var_name not in class_vars and not is_untouched(value) - if isinstance(value, ModelPrivateAttr): - if not is_valid_private_name(var_name): - raise NameError( - f'Private attributes "{var_name}" must not be a valid field name; ' - f'Use sunder or dunder names, e. g. 
"_{var_name}" or "__{var_name}__"' - ) - private_attributes[var_name] = value - elif config.underscore_attrs_are_private and is_valid_private_name(var_name) and can_be_changed: - private_attributes[var_name] = PrivateAttr(default=value) - elif is_valid_field(var_name) and var_name not in annotations and can_be_changed: - validate_field_name(bases, var_name) - inferred = ModelField.infer( - name=var_name, - value=value, - annotation=annotations.get(var_name, Undefined), - class_validators=vg.get_validators(var_name), - config=config, - ) - if var_name in fields: - if lenient_issubclass(inferred.type_, fields[var_name].type_): - inferred.type_ = fields[var_name].type_ - else: - raise TypeError( - f'The type of {name}.{var_name} differs from the new default value; ' - f'if you wish to change the type of this field, please use a type annotation' - ) - fields[var_name] = inferred + def __init__(self, /, **data: Any) -> None: # type: ignore + """Create a new model by parsing and validating input data from keyword arguments. - _custom_root_type = ROOT_KEY in fields - if _custom_root_type: - validate_custom_root_type(fields) - vg.check_for_unused() - if config.json_encoders: - json_encoder = partial(custom_pydantic_encoder, config.json_encoders) + Raises [`ValidationError`][pydantic_core.ValidationError] if the input data cannot be + validated to form a valid model. + + `self` is explicitly positional-only to allow `self` as a field name. + """ + # `__tracebackhide__` tells pytest and some other tools to omit this function from tracebacks + __tracebackhide__ = True + self.__pydantic_validator__.validate_python(data, self_instance=self) + + # The following line sets a flag that we use to determine when `__init__` gets overridden by the user + __init__.__pydantic_base_init__ = True # pyright: ignore[reportFunctionMemberAccess] + + @property + def model_extra(self) -> dict[str, Any] | None: + """Get extra fields set during validation. + + Returns: + A dictionary of extra fields, or `None` if `config.extra` is not set to `"allow"`. + """ + return self.__pydantic_extra__ + + @property + def model_fields_set(self) -> set[str]: + """Returns the set of fields that have been explicitly set on this model instance. + + Returns: + A set of strings representing the fields that have been set, + i.e. that were not filled from defaults. + """ + return self.__pydantic_fields_set__ + + @classmethod + def model_construct(cls: type[Model], _fields_set: set[str] | None = None, **values: Any) -> Model: + """Creates a new instance of the `Model` class with validated data. + + Creates a new model setting `__dict__` and `__pydantic_fields_set__` from trusted or pre-validated data. + Default values are respected, but no other validation is performed. + Behaves as if `Config.extra = 'allow'` was set since it adds all passed values + + Args: + _fields_set: The set of field names accepted for the Model instance. + values: Trusted or pre-validated data dictionary. + + Returns: + A new instance of the `Model` class with validated data. 
+ """ + m = cls.__new__(cls) + fields_values: dict[str, Any] = {} + fields_set = set() + + for name, field in cls.model_fields.items(): + if field.alias and field.alias in values: + fields_values[name] = values.pop(field.alias) + fields_set.add(name) + elif name in values: + fields_values[name] = values.pop(name) + fields_set.add(name) + elif not field.is_required(): + fields_values[name] = field.get_default(call_default_factory=True) + if _fields_set is None: + _fields_set = fields_set + + _extra: dict[str, Any] | None = None + if cls.model_config.get('extra') == 'allow': + _extra = {} + for k, v in values.items(): + _extra[k] = v else: - json_encoder = pydantic_encoder - pre_rv_new, post_rv_new = extract_root_validators(namespace) + fields_values.update(values) + _object_setattr(m, '__dict__', fields_values) + _object_setattr(m, '__pydantic_fields_set__', _fields_set) + if not cls.__pydantic_root_model__: + _object_setattr(m, '__pydantic_extra__', _extra) - if hash_func is None: - hash_func = generate_hash_function(config.frozen) + if cls.__pydantic_post_init__: + m.model_post_init(None) + # update private attributes with values set + if hasattr(m, '__pydantic_private__') and m.__pydantic_private__ is not None: + for k, v in values.items(): + if k in m.__private_attributes__: + m.__pydantic_private__[k] = v - exclude_from_namespace = fields | private_attributes.keys() | {'__slots__'} - new_namespace = { - '__config__': config, - '__fields__': fields, - '__exclude_fields__': { - name: field.field_info.exclude for name, field in fields.items() if field.field_info.exclude is not None - } - or None, - '__include_fields__': { - name: field.field_info.include for name, field in fields.items() if field.field_info.include is not None - } - or None, - '__validators__': vg.validators, - '__pre_root_validators__': unique_list( - pre_root_validators + pre_rv_new, - name_factory=lambda v: v.__name__, - ), - '__post_root_validators__': unique_list( - post_root_validators + post_rv_new, - name_factory=lambda skip_on_failure_and_v: skip_on_failure_and_v[1].__name__, - ), - '__schema_cache__': {}, - '__json_encoder__': staticmethod(json_encoder), - '__custom_root_type__': _custom_root_type, - '__private_attributes__': {**base_private_attributes, **private_attributes}, - '__slots__': slots | private_attributes.keys(), - '__hash__': hash_func, - '__class_vars__': class_vars, - **{n: v for n, v in namespace.items() if n not in exclude_from_namespace}, - } + elif not cls.__pydantic_root_model__: + # Note: if there are any private attributes, cls.__pydantic_post_init__ would exist + # Since it doesn't, that means that `__pydantic_private__` should be set to None + _object_setattr(m, '__pydantic_private__', None) - cls = super().__new__(mcs, name, bases, new_namespace, **kwargs) - # set __signature__ attr only for model class, but not for its instances - cls.__signature__ = ClassAttribute('__signature__', generate_model_signature(cls.__init__, fields, config)) - if resolve_forward_refs: - cls.__try_update_forward_refs__() + return m - # preserve `__set_name__` protocol defined in https://peps.python.org/pep-0487 - # for attributes not in `new_namespace` (e.g. 
private attributes) - for name, obj in namespace.items(): - if name not in new_namespace: - set_name = getattr(obj, '__set_name__', None) - if callable(set_name): - set_name(cls, name) + def model_copy(self: Model, *, update: dict[str, Any] | None = None, deep: bool = False) -> Model: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/serialization/#model_copy - return cls + Returns a copy of the model. - def __instancecheck__(self, instance: Any) -> bool: + Args: + update: Values to change/add in the new model. Note: the data is not validated + before creating the new model. You should trust this data. + deep: Set to `True` to make a deep copy of the model. + + Returns: + New model instance. """ - Avoid calling ABC _abc_subclasscheck unless we're pretty sure. + copied = self.__deepcopy__() if deep else self.__copy__() + if update: + if self.model_config.get('extra') == 'allow': + for k, v in update.items(): + if k in self.model_fields: + copied.__dict__[k] = v + else: + if copied.__pydantic_extra__ is None: + copied.__pydantic_extra__ = {} + copied.__pydantic_extra__[k] = v + else: + copied.__dict__.update(update) + copied.__pydantic_fields_set__.update(update.keys()) + return copied - See #3829 and python/cpython#92810 - """ - return hasattr(instance, '__fields__') and super().__instancecheck__(instance) - - -object_setattr = object.__setattr__ - - -class BaseModel(Representation, metaclass=ModelMetaclass): - if TYPE_CHECKING: - # populated by the metaclass, defined here to help IDEs only - __fields__: ClassVar[Dict[str, ModelField]] = {} - __include_fields__: ClassVar[Optional[Mapping[str, Any]]] = None - __exclude_fields__: ClassVar[Optional[Mapping[str, Any]]] = None - __validators__: ClassVar[Dict[str, AnyCallable]] = {} - __pre_root_validators__: ClassVar[List[AnyCallable]] - __post_root_validators__: ClassVar[List[Tuple[bool, AnyCallable]]] - __config__: ClassVar[Type[BaseConfig]] = BaseConfig - __json_encoder__: ClassVar[Callable[[Any], Any]] = lambda x: x - __schema_cache__: ClassVar['DictAny'] = {} - __custom_root_type__: ClassVar[bool] = False - __signature__: ClassVar['Signature'] - __private_attributes__: ClassVar[Dict[str, ModelPrivateAttr]] - __class_vars__: ClassVar[SetStr] - __fields_set__: ClassVar[SetStr] = set() - - Config = BaseConfig - __slots__ = ('__dict__', '__fields_set__') - __doc__ = '' # Null out the Representation docstring - - def __init__(__pydantic_self__, **data: Any) -> None: - """ - Create a new model by parsing and validating input data from keyword arguments. - - Raises ValidationError if the input data cannot be parsed to form a valid model. 
- """ - # Uses something other than `self` the first arg to allow "self" as a settable attribute - values, fields_set, validation_error = validate_model(__pydantic_self__.__class__, data) - if validation_error: - raise validation_error - try: - object_setattr(__pydantic_self__, '__dict__', values) - except TypeError as e: - raise TypeError( - 'Model values must be a dict; you may not have returned a dictionary from a root validator' - ) from e - object_setattr(__pydantic_self__, '__fields_set__', fields_set) - __pydantic_self__._init_private_attributes() - - @no_type_check - def __setattr__(self, name, value): # noqa: C901 (ignore complexity) - if name in self.__private_attributes__ or name in DUNDER_ATTRIBUTES: - return object_setattr(self, name, value) - - if self.__config__.extra is not Extra.allow and name not in self.__fields__: - raise ValueError(f'"{self.__class__.__name__}" object has no field "{name}"') - elif not self.__config__.allow_mutation or self.__config__.frozen: - raise TypeError(f'"{self.__class__.__name__}" is immutable and does not support item assignment') - elif name in self.__fields__ and self.__fields__[name].final: - raise TypeError( - f'"{self.__class__.__name__}" object "{name}" field is final and does not support reassignment' - ) - elif self.__config__.validate_assignment: - new_values = {**self.__dict__, name: value} - - for validator in self.__pre_root_validators__: - try: - new_values = validator(self.__class__, new_values) - except (ValueError, TypeError, AssertionError) as exc: - raise ValidationError([ErrorWrapper(exc, loc=ROOT_KEY)], self.__class__) - - known_field = self.__fields__.get(name, None) - if known_field: - # We want to - # - make sure validators are called without the current value for this field inside `values` - # - keep other values (e.g. submodels) untouched (using `BaseModel.dict()` will change them into dicts) - # - keep the order of the fields - if not known_field.field_info.allow_mutation: - raise TypeError(f'"{known_field.name}" has allow_mutation set to False and cannot be assigned') - dict_without_original_value = {k: v for k, v in self.__dict__.items() if k != name} - value, error_ = known_field.validate(value, dict_without_original_value, loc=name, cls=self.__class__) - if error_: - raise ValidationError([error_], self.__class__) - else: - new_values[name] = value - - errors = [] - for skip_on_failure, validator in self.__post_root_validators__: - if skip_on_failure and errors: - continue - try: - new_values = validator(self.__class__, new_values) - except (ValueError, TypeError, AssertionError) as exc: - errors.append(ErrorWrapper(exc, loc=ROOT_KEY)) - if errors: - raise ValidationError(errors, self.__class__) - - # update the whole __dict__ as other values than just `value` - # may be changed (e.g. 
with `root_validator`) - object_setattr(self, '__dict__', new_values) - else: - self.__dict__[name] = value - - self.__fields_set__.add(name) - - def __getstate__(self) -> 'DictAny': - private_attrs = ((k, getattr(self, k, Undefined)) for k in self.__private_attributes__) - return { - '__dict__': self.__dict__, - '__fields_set__': self.__fields_set__, - '__private_attribute_values__': {k: v for k, v in private_attrs if v is not Undefined}, - } - - def __setstate__(self, state: 'DictAny') -> None: - object_setattr(self, '__dict__', state['__dict__']) - object_setattr(self, '__fields_set__', state['__fields_set__']) - for name, value in state.get('__private_attribute_values__', {}).items(): - object_setattr(self, name, value) - - def _init_private_attributes(self) -> None: - for name, private_attr in self.__private_attributes__.items(): - default = private_attr.get_default() - if default is not Undefined: - object_setattr(self, name, default) - - def dict( + def model_dump( self, *, - include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, - exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + mode: Literal['json', 'python'] | str = 'python', + include: IncEx = None, + exclude: IncEx = None, by_alias: bool = False, - skip_defaults: Optional[bool] = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, - ) -> 'DictStrAny': - """ + round_trip: bool = False, + warnings: bool = True, + ) -> dict[str, Any]: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/serialization/#modelmodel_dump + Generate a dictionary representation of the model, optionally specifying which fields to include or exclude. - """ - if skip_defaults is not None: - warnings.warn( - f'{self.__class__.__name__}.dict(): "skip_defaults" is deprecated and replaced by "exclude_unset"', - DeprecationWarning, - ) - exclude_unset = skip_defaults + Args: + mode: The mode in which `to_python` should run. + If mode is 'json', the output will only contain JSON serializable types. + If mode is 'python', the output may contain non-JSON-serializable Python objects. + include: A list of fields to include in the output. + exclude: A list of fields to exclude from the output. + by_alias: Whether to use the field's alias in the dictionary key if defined. + exclude_unset: Whether to exclude fields that have not been explicitly set. + exclude_defaults: Whether to exclude fields that are set to their default value. + exclude_none: Whether to exclude fields that have a value of `None`. + round_trip: If True, dumped values should be valid as input for non-idempotent types such as Json[T]. + warnings: Whether to log warnings when invalid fields are encountered. - return dict( - self._iter( - to_dict=True, - by_alias=by_alias, - include=include, - exclude=exclude, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) + Returns: + A dictionary representation of the model. 
+ """ + return self.__pydantic_serializer__.to_python( + self, + mode=mode, + by_alias=by_alias, + include=include, + exclude=exclude, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + round_trip=round_trip, + warnings=warnings, ) - def json( + def model_dump_json( self, *, - include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, - exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + indent: int | None = None, + include: IncEx = None, + exclude: IncEx = None, by_alias: bool = False, - skip_defaults: Optional[bool] = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, - encoder: Optional[Callable[[Any], Any]] = None, - models_as_dict: bool = True, + round_trip: bool = False, + warnings: bool = True, + ) -> str: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/serialization/#modelmodel_dump_json + + Generates a JSON representation of the model using Pydantic's `to_json` method. + + Args: + indent: Indentation to use in the JSON output. If None is passed, the output will be compact. + include: Field(s) to include in the JSON output. + exclude: Field(s) to exclude from the JSON output. + by_alias: Whether to serialize using field aliases. + exclude_unset: Whether to exclude fields that have not been explicitly set. + exclude_defaults: Whether to exclude fields that are set to their default value. + exclude_none: Whether to exclude fields that have a value of `None`. + round_trip: If True, dumped values should be valid as input for non-idempotent types such as Json[T]. + warnings: Whether to log warnings when invalid fields are encountered. + + Returns: + A JSON string representation of the model. + """ + return self.__pydantic_serializer__.to_json( + self, + indent=indent, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + round_trip=round_trip, + warnings=warnings, + ).decode() + + @classmethod + def model_json_schema( + cls, + by_alias: bool = True, + ref_template: str = DEFAULT_REF_TEMPLATE, + schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema, + mode: JsonSchemaMode = 'validation', + ) -> dict[str, Any]: + """Generates a JSON schema for a model class. + + Args: + by_alias: Whether to use attribute aliases or not. + ref_template: The reference template. + schema_generator: To override the logic used to generate the JSON schema, as a subclass of + `GenerateJsonSchema` with your desired modifications + mode: The mode in which to generate the schema. + + Returns: + The JSON schema for the given model class. + """ + return model_json_schema( + cls, by_alias=by_alias, ref_template=ref_template, schema_generator=schema_generator, mode=mode + ) + + @classmethod + def model_parametrized_name(cls, params: tuple[type[Any], ...]) -> str: + """Compute the class name for parametrizations of generic classes. + + This method can be overridden to achieve a custom naming scheme for generic BaseModels. + + Args: + params: Tuple of types of the class. Given a generic class + `Model` with 2 type variables and a concrete model `Model[str, int]`, + the value `(str, int)` would be passed to `params`. + + Returns: + String representing the new class where `params` are passed to `cls` as type variables. + + Raises: + TypeError: Raised when trying to generate concrete names for non-generic models. 
+ """ + if not issubclass(cls, typing.Generic): + raise TypeError('Concrete names should only be generated for generic models.') + + # Any strings received should represent forward references, so we handle them specially below. + # If we eventually move toward wrapping them in a ForwardRef in __class_getitem__ in the future, + # we may be able to remove this special case. + param_names = [param if isinstance(param, str) else _repr.display_as_type(param) for param in params] + params_component = ', '.join(param_names) + return f'{cls.__name__}[{params_component}]' + + def model_post_init(self, __context: Any) -> None: + """Override this method to perform additional initialization after `__init__` and `model_construct`. + This is useful if you want to do some validation that requires the entire model to be initialized. + """ + pass + + @classmethod + def model_rebuild( + cls, + *, + force: bool = False, + raise_errors: bool = True, + _parent_namespace_depth: int = 2, + _types_namespace: dict[str, Any] | None = None, + ) -> bool | None: + """Try to rebuild the pydantic-core schema for the model. + + This may be necessary when one of the annotations is a ForwardRef which could not be resolved during + the initial attempt to build the schema, and automatic rebuilding fails. + + Args: + force: Whether to force the rebuilding of the model schema, defaults to `False`. + raise_errors: Whether to raise errors, defaults to `True`. + _parent_namespace_depth: The depth level of the parent namespace, defaults to 2. + _types_namespace: The types namespace, defaults to `None`. + + Returns: + Returns `None` if the schema is already "complete" and rebuilding was not required. + If rebuilding _was_ required, returns `True` if rebuilding was successful, otherwise `False`. + """ + if not force and cls.__pydantic_complete__: + return None + else: + if '__pydantic_core_schema__' in cls.__dict__: + delattr(cls, '__pydantic_core_schema__') # delete cached value to ensure full rebuild happens + if _types_namespace is not None: + types_namespace: dict[str, Any] | None = _types_namespace.copy() + else: + if _parent_namespace_depth > 0: + frame_parent_ns = _typing_extra.parent_frame_namespace(parent_depth=_parent_namespace_depth) or {} + cls_parent_ns = ( + _model_construction.unpack_lenient_weakvaluedict(cls.__pydantic_parent_namespace__) or {} + ) + types_namespace = {**cls_parent_ns, **frame_parent_ns} + cls.__pydantic_parent_namespace__ = _model_construction.build_lenient_weakvaluedict(types_namespace) + else: + types_namespace = _model_construction.unpack_lenient_weakvaluedict( + cls.__pydantic_parent_namespace__ + ) + + types_namespace = _typing_extra.get_cls_types_namespace(cls, types_namespace) + + # manually override defer_build so complete_model_class doesn't skip building the model again + config = {**cls.model_config, 'defer_build': False} + return _model_construction.complete_model_class( + cls, + cls.__name__, + _config.ConfigWrapper(config, check=False), + raise_errors=raise_errors, + types_namespace=types_namespace, + ) + + @classmethod + def model_validate( + cls: type[Model], + obj: Any, + *, + strict: bool | None = None, + from_attributes: bool | None = None, + context: dict[str, Any] | None = None, + ) -> Model: + """Validate a pydantic model instance. + + Args: + obj: The object to validate. + strict: Whether to enforce types strictly. + from_attributes: Whether to extract data from object attributes. + context: Additional context to pass to the validator. 
+ + Raises: + ValidationError: If the object could not be validated. + + Returns: + The validated model instance. + """ + # `__tracebackhide__` tells pytest and some other tools to omit this function from tracebacks + __tracebackhide__ = True + return cls.__pydantic_validator__.validate_python( + obj, strict=strict, from_attributes=from_attributes, context=context + ) + + @classmethod + def model_validate_json( + cls: type[Model], + json_data: str | bytes | bytearray, + *, + strict: bool | None = None, + context: dict[str, Any] | None = None, + ) -> Model: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/json/#json-parsing + + Validate the given JSON data against the Pydantic model. + + Args: + json_data: The JSON data to validate. + strict: Whether to enforce types strictly. + context: Extra variables to pass to the validator. + + Returns: + The validated Pydantic model. + + Raises: + ValueError: If `json_data` is not a JSON string. + """ + # `__tracebackhide__` tells pytest and some other tools to omit this function from tracebacks + __tracebackhide__ = True + return cls.__pydantic_validator__.validate_json(json_data, strict=strict, context=context) + + @classmethod + def model_validate_strings( + cls: type[Model], + obj: Any, + *, + strict: bool | None = None, + context: dict[str, Any] | None = None, + ) -> Model: + """Validate the given object contains string data against the Pydantic model. + + Args: + obj: The object contains string data to validate. + strict: Whether to enforce types strictly. + context: Extra variables to pass to the validator. + + Returns: + The validated Pydantic model. + """ + # `__tracebackhide__` tells pytest and some other tools to omit this function from tracebacks + __tracebackhide__ = True + return cls.__pydantic_validator__.validate_strings(obj, strict=strict, context=context) + + @classmethod + def __get_pydantic_core_schema__(cls, __source: type[BaseModel], __handler: GetCoreSchemaHandler) -> CoreSchema: + """Hook into generating the model's CoreSchema. + + Args: + __source: The class we are generating a schema for. + This will generally be the same as the `cls` argument if this is a classmethod. + __handler: Call into Pydantic's internal JSON schema generation. + A callable that calls into Pydantic's internal CoreSchema generation logic. + + Returns: + A `pydantic-core` `CoreSchema`. + """ + # Only use the cached value from this _exact_ class; we don't want one from a parent class + # This is why we check `cls.__dict__` and don't use `cls.__pydantic_core_schema__` or similar. + if '__pydantic_core_schema__' in cls.__dict__: + # Due to the way generic classes are built, it's possible that an invalid schema may be temporarily + # set on generic classes. I think we could resolve this to ensure that we get proper schema caching + # for generics, but for simplicity for now, we just always rebuild if the class has a generic origin. + if not cls.__pydantic_generic_metadata__['origin']: + return cls.__pydantic_core_schema__ + + return __handler(__source) + + @classmethod + def __get_pydantic_json_schema__( + cls, + __core_schema: CoreSchema, + __handler: GetJsonSchemaHandler, + ) -> JsonSchemaValue: + """Hook into generating the model's JSON schema. + + Args: + __core_schema: A `pydantic-core` CoreSchema. + You can ignore this argument and call the handler with a new CoreSchema, + wrap this CoreSchema (`{'type': 'nullable', 'schema': current_schema}`), + or just call the handler with the original schema. 
+ __handler: Call into Pydantic's internal JSON schema generation. + This will raise a `pydantic.errors.PydanticInvalidForJsonSchema` if JSON schema + generation fails. + Since this gets called by `BaseModel.model_json_schema` you can override the + `schema_generator` argument to that function to change JSON schema generation globally + for a type. + + Returns: + A JSON schema, as a Python object. + """ + return __handler(__core_schema) + + @classmethod + def __pydantic_init_subclass__(cls, **kwargs: Any) -> None: + """This is intended to behave just like `__init_subclass__`, but is called by `ModelMetaclass` + only after the class is actually fully initialized. In particular, attributes like `model_fields` will + be present when this is called. + + This is necessary because `__init_subclass__` will always be called by `type.__new__`, + and it would require a prohibitively large refactor to the `ModelMetaclass` to ensure that + `type.__new__` was called in such a manner that the class would already be sufficiently initialized. + + This will receive the same `kwargs` that would be passed to the standard `__init_subclass__`, namely, + any kwargs passed to the class definition that aren't used internally by pydantic. + + Args: + **kwargs: Any keyword arguments passed to the class definition that aren't used internally + by pydantic. + """ + pass + + def __class_getitem__( + cls, typevar_values: type[Any] | tuple[type[Any], ...] + ) -> type[BaseModel] | _forward_ref.PydanticRecursiveRef: + cached = _generics.get_cached_generic_type_early(cls, typevar_values) + if cached is not None: + return cached + + if cls is BaseModel: + raise TypeError('Type parameters should be placed on typing.Generic, not BaseModel') + if not hasattr(cls, '__parameters__'): + raise TypeError(f'{cls} cannot be parametrized because it does not inherit from typing.Generic') + if not cls.__pydantic_generic_metadata__['parameters'] and typing.Generic not in cls.__bases__: + raise TypeError(f'{cls} is not a generic class') + + if not isinstance(typevar_values, tuple): + typevar_values = (typevar_values,) + _generics.check_parameters_count(cls, typevar_values) + + # Build map from generic typevars to passed params + typevars_map: dict[_typing_extra.TypeVarType, type[Any]] = dict( + zip(cls.__pydantic_generic_metadata__['parameters'], typevar_values) + ) + + if _utils.all_identical(typevars_map.keys(), typevars_map.values()) and typevars_map: + submodel = cls # if arguments are equal to parameters it's the same object + _generics.set_cached_generic_type(cls, typevar_values, submodel) + else: + parent_args = cls.__pydantic_generic_metadata__['args'] + if not parent_args: + args = typevar_values + else: + args = tuple(_generics.replace_types(arg, typevars_map) for arg in parent_args) + + origin = cls.__pydantic_generic_metadata__['origin'] or cls + model_name = origin.model_parametrized_name(args) + params = tuple( + {param: None for param in _generics.iter_contained_typevars(typevars_map.values())} + ) # use dict as ordered set + + with _generics.generic_recursion_self_type(origin, args) as maybe_self_type: + if maybe_self_type is not None: + return maybe_self_type + + cached = _generics.get_cached_generic_type_late(cls, typevar_values, origin, args) + if cached is not None: + return cached + + # Attempt to rebuild the origin in case new types have been defined + try: + # depth 3 gets you above this __class_getitem__ call + origin.model_rebuild(_parent_namespace_depth=3) + except PydanticUndefinedAnnotation: + # It's okay if it 
fails, it just means there are still undefined types + # that could be evaluated later. + # TODO: Make sure validation fails if there are still undefined types, perhaps using MockValidator + pass + + submodel = _generics.create_generic_submodel(model_name, origin, args, params) + + # Update cache + _generics.set_cached_generic_type(cls, typevar_values, submodel, origin, args) + + return submodel + + def __copy__(self: Model) -> Model: + """Returns a shallow copy of the model.""" + cls = type(self) + m = cls.__new__(cls) + _object_setattr(m, '__dict__', copy(self.__dict__)) + _object_setattr(m, '__pydantic_extra__', copy(self.__pydantic_extra__)) + _object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__)) + + if self.__pydantic_private__ is None: + _object_setattr(m, '__pydantic_private__', None) + else: + _object_setattr( + m, + '__pydantic_private__', + {k: v for k, v in self.__pydantic_private__.items() if v is not PydanticUndefined}, + ) + + return m + + def __deepcopy__(self: Model, memo: dict[int, Any] | None = None) -> Model: + """Returns a deep copy of the model.""" + cls = type(self) + m = cls.__new__(cls) + _object_setattr(m, '__dict__', deepcopy(self.__dict__, memo=memo)) + _object_setattr(m, '__pydantic_extra__', deepcopy(self.__pydantic_extra__, memo=memo)) + # This next line doesn't need a deepcopy because __pydantic_fields_set__ is a set[str], + # and attempting a deepcopy would be marginally slower. + _object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__)) + + if self.__pydantic_private__ is None: + _object_setattr(m, '__pydantic_private__', None) + else: + _object_setattr( + m, + '__pydantic_private__', + deepcopy({k: v for k, v in self.__pydantic_private__.items() if v is not PydanticUndefined}, memo=memo), + ) + + return m + + if not typing.TYPE_CHECKING: + # We put `__getattr__` in a non-TYPE_CHECKING block because otherwise, mypy allows arbitrary attribute access + + def __getattr__(self, item: str) -> Any: + private_attributes = object.__getattribute__(self, '__private_attributes__') + if item in private_attributes: + attribute = private_attributes[item] + if hasattr(attribute, '__get__'): + return attribute.__get__(self, type(self)) # type: ignore + + try: + # Note: self.__pydantic_private__ cannot be None if self.__private_attributes__ has items + return self.__pydantic_private__[item] # type: ignore + except KeyError as exc: + raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}') from exc + else: + # `__pydantic_extra__` can fail to be set if the model is not yet fully initialized. + # See `BaseModel.__repr_args__` for more details + try: + pydantic_extra = object.__getattribute__(self, '__pydantic_extra__') + except AttributeError: + pydantic_extra = None + + if pydantic_extra is not None: + try: + return pydantic_extra[item] + except KeyError as exc: + raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}') from exc + else: + if hasattr(self.__class__, item): + return super().__getattribute__(item) # Raises AttributeError if appropriate + else: + # this is the current error + raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}') + + def __setattr__(self, name: str, value: Any) -> None: + if name in self.__class_vars__: + raise AttributeError( + f'{name!r} is a ClassVar of `{self.__class__.__name__}` and cannot be set on an instance. ' + f'If you want to set a value on the class, use `{self.__class__.__name__}.{name} = value`.' 
+ ) + elif not _fields.is_valid_field_name(name): + if self.__pydantic_private__ is None or name not in self.__private_attributes__: + _object_setattr(self, name, value) + else: + attribute = self.__private_attributes__[name] + if hasattr(attribute, '__set__'): + attribute.__set__(self, value) # type: ignore + else: + self.__pydantic_private__[name] = value + return + + self._check_frozen(name, value) + + attr = getattr(self.__class__, name, None) + if isinstance(attr, property): + attr.__set__(self, value) + elif self.model_config.get('validate_assignment', None): + self.__pydantic_validator__.validate_assignment(self, name, value) + elif self.model_config.get('extra') != 'allow' and name not in self.model_fields: + # TODO - matching error + raise ValueError(f'"{self.__class__.__name__}" object has no field "{name}"') + elif self.model_config.get('extra') == 'allow' and name not in self.model_fields: + if self.model_extra and name in self.model_extra: + self.__pydantic_extra__[name] = value # type: ignore + else: + try: + getattr(self, name) + except AttributeError: + # attribute does not already exist on instance, so put it in extra + self.__pydantic_extra__[name] = value # type: ignore + else: + # attribute _does_ already exist on instance, and was not in extra, so update it + _object_setattr(self, name, value) + else: + self.__dict__[name] = value + self.__pydantic_fields_set__.add(name) + + def __delattr__(self, item: str) -> Any: + if item in self.__private_attributes__: + attribute = self.__private_attributes__[item] + if hasattr(attribute, '__delete__'): + attribute.__delete__(self) # type: ignore + return + + try: + # Note: self.__pydantic_private__ cannot be None if self.__private_attributes__ has items + del self.__pydantic_private__[item] # type: ignore + return + except KeyError as exc: + raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}') from exc + + self._check_frozen(item, None) + + if item in self.model_fields: + object.__delattr__(self, item) + elif self.__pydantic_extra__ is not None and item in self.__pydantic_extra__: + del self.__pydantic_extra__[item] + else: + try: + object.__delattr__(self, item) + except AttributeError: + raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}') + + def _check_frozen(self, name: str, value: Any) -> None: + if self.model_config.get('frozen', None): + typ = 'frozen_instance' + elif getattr(self.model_fields.get(name), 'frozen', False): + typ = 'frozen_field' + else: + return + error: pydantic_core.InitErrorDetails = { + 'type': typ, + 'loc': (name,), + 'input': value, + } + raise pydantic_core.ValidationError.from_exception_data(self.__class__.__name__, [error]) + + def __getstate__(self) -> dict[Any, Any]: + private = self.__pydantic_private__ + if private: + private = {k: v for k, v in private.items() if v is not PydanticUndefined} + return { + '__dict__': self.__dict__, + '__pydantic_extra__': self.__pydantic_extra__, + '__pydantic_fields_set__': self.__pydantic_fields_set__, + '__pydantic_private__': private, + } + + def __setstate__(self, state: dict[Any, Any]) -> None: + _object_setattr(self, '__pydantic_fields_set__', state['__pydantic_fields_set__']) + _object_setattr(self, '__pydantic_extra__', state['__pydantic_extra__']) + _object_setattr(self, '__pydantic_private__', state['__pydantic_private__']) + _object_setattr(self, '__dict__', state['__dict__']) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, BaseModel): + # When comparing instances of generic 
types for equality, as long as all field values are equal, + # only require their generic origin types to be equal, rather than exact type equality. + # This prevents headaches like MyGeneric(x=1) != MyGeneric[Any](x=1). + self_type = self.__pydantic_generic_metadata__['origin'] or self.__class__ + other_type = other.__pydantic_generic_metadata__['origin'] or other.__class__ + + # Perform common checks first + if not ( + self_type == other_type + and self.__pydantic_private__ == other.__pydantic_private__ + and self.__pydantic_extra__ == other.__pydantic_extra__ + ): + return False + + # We only want to compare pydantic fields but ignoring fields is costly. + # We'll perform a fast check first, and fallback only when needed + # See GH-7444 and GH-7825 for rationale and a performance benchmark + + # First, do the fast (and sometimes faulty) __dict__ comparison + if self.__dict__ == other.__dict__: + # If the check above passes, then pydantic fields are equal, we can return early + return True + + # We don't want to trigger unnecessary costly filtering of __dict__ on all unequal objects, so we return + # early if there are no keys to ignore (we would just return False later on anyway) + model_fields = type(self).model_fields.keys() + if self.__dict__.keys() <= model_fields and other.__dict__.keys() <= model_fields: + return False + + # If we reach here, there are non-pydantic-fields keys, mapped to unequal values, that we need to ignore + # Resort to costly filtering of the __dict__ objects + # We use operator.itemgetter because it is much faster than dict comprehensions + # NOTE: Contrary to standard python class and instances, when the Model class has a default value for an + # attribute and the model instance doesn't have a corresponding attribute, accessing the missing attribute + # raises an error in BaseModel.__getattr__ instead of returning the class attribute + # So we can use operator.itemgetter() instead of operator.attrgetter() + getter = operator.itemgetter(*model_fields) if model_fields else lambda _: _utils._SENTINEL + try: + return getter(self.__dict__) == getter(other.__dict__) + except KeyError: + # In rare cases (such as when using the deprecated BaseModel.copy() method), + # the __dict__ may not contain all model fields, which is how we can get here. + # getter(self.__dict__) is much faster than any 'safe' method that accounts + # for missing keys, and wrapping it in a `try` doesn't slow things down much + # in the common case. + self_fields_proxy = _utils.SafeGetItemProxy(self.__dict__) + other_fields_proxy = _utils.SafeGetItemProxy(other.__dict__) + return getter(self_fields_proxy) == getter(other_fields_proxy) + + # other instance is not a BaseModel + else: + return NotImplemented # delegate to the other item in the comparison + + if typing.TYPE_CHECKING: + # We put `__init_subclass__` in a TYPE_CHECKING block because, even though we want the type-checking benefits + # described in the signature of `__init_subclass__` below, we don't want to modify the default behavior of + # subclass initialization. + + def __init_subclass__(cls, **kwargs: Unpack[ConfigDict]): + """This signature is included purely to help type-checkers check arguments to class declaration, which + provides a way to conveniently set model_config key/value pairs. + + ```py + from pydantic import BaseModel + + class MyModel(BaseModel, extra='allow'): + ... 
+ ``` + + However, this may be deceiving, since the _actual_ calls to `__init_subclass__` will not receive any + of the config arguments, and will only receive any keyword arguments passed during class initialization + that are _not_ expected keys in ConfigDict. (This is due to the way `ModelMetaclass.__new__` works.) + + Args: + **kwargs: Keyword arguments passed to the class definition, which set model_config + + Note: + You may want to override `__pydantic_init_subclass__` instead, which behaves similarly but is called + *after* the class is fully initialized. + """ + + def __iter__(self) -> TupleGenerator: + """So `dict(model)` works.""" + yield from [(k, v) for (k, v) in self.__dict__.items() if not k.startswith('_')] + extra = self.__pydantic_extra__ + if extra: + yield from extra.items() + + def __repr__(self) -> str: + return f'{self.__repr_name__()}({self.__repr_str__(", ")})' + + def __repr_args__(self) -> _repr.ReprArgs: + for k, v in self.__dict__.items(): + field = self.model_fields.get(k) + if field and field.repr: + yield k, v + + # `__pydantic_extra__` can fail to be set if the model is not yet fully initialized. + # This can happen if a `ValidationError` is raised during initialization and the instance's + # repr is generated as part of the exception handling. Therefore, we use `getattr` here + # with a fallback, even though the type hints indicate the attribute will always be present. + try: + pydantic_extra = object.__getattribute__(self, '__pydantic_extra__') + except AttributeError: + pydantic_extra = None + + if pydantic_extra is not None: + yield from ((k, v) for k, v in pydantic_extra.items()) + yield from ((k, getattr(self, k)) for k, v in self.model_computed_fields.items() if v.repr) + + # take logic from `_repr.Representation` without the side effects of inheritance, see #5740 + __repr_name__ = _repr.Representation.__repr_name__ + __repr_str__ = _repr.Representation.__repr_str__ + __pretty__ = _repr.Representation.__pretty__ + __rich_repr__ = _repr.Representation.__rich_repr__ + + def __str__(self) -> str: + return self.__repr_str__(' ') + + # ##### Deprecated methods from v1 ##### + @property + @typing_extensions.deprecated( + 'The `__fields__` attribute is deprecated, use `model_fields` instead.', category=None + ) + def __fields__(self) -> dict[str, FieldInfo]: + warnings.warn( + 'The `__fields__` attribute is deprecated, use `model_fields` instead.', category=PydanticDeprecatedSince20 + ) + return self.model_fields + + @property + @typing_extensions.deprecated( + 'The `__fields_set__` attribute is deprecated, use `model_fields_set` instead.', + category=None, + ) + def __fields_set__(self) -> set[str]: + warnings.warn( + 'The `__fields_set__` attribute is deprecated, use `model_fields_set` instead.', + category=PydanticDeprecatedSince20, + ) + return self.__pydantic_fields_set__ + + @typing_extensions.deprecated('The `dict` method is deprecated; use `model_dump` instead.', category=None) + def dict( # noqa: D102 + self, + *, + include: IncEx = None, + exclude: IncEx = None, + by_alias: bool = False, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> typing.Dict[str, Any]: # noqa UP006 + warnings.warn('The `dict` method is deprecated; use `model_dump` instead.', category=PydanticDeprecatedSince20) + return self.model_dump( + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + + 
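[Editor's note, not part of the patch: the deprecated `dict()` shim above and the `json()` shim below simply forward to the V2 serialization API documented earlier in this diff. A minimal, hedged sketch of that migration follows; the `User` model and its fields are hypothetical, while `model_dump`, `model_dump_json`, `exclude_unset`, and `PydanticDeprecatedSince20` are taken from the vendored code shown in this patch.]

```py
# Illustrative sketch only (not part of the vendored pydantic code):
# migrating callers from the deprecated V1-style shims to the V2 API.
from pydantic import BaseModel


class User(BaseModel):  # hypothetical example model
    id: int
    name: str = 'anonymous'


u = User(id=1)

# Pydantic V1 style still works, but the shims above emit
# PydanticDeprecatedSince20 warnings:
#   u.dict(exclude_unset=True)
#   u.json()

# Pydantic V2 replacements, per the docstrings in this diff:
print(u.model_dump(exclude_unset=True))  # {'id': 1}
print(u.model_dump_json())               # {"id":1,"name":"anonymous"}
```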
@typing_extensions.deprecated('The `json` method is deprecated; use `model_dump_json` instead.', category=None) + def json( # noqa: D102 + self, + *, + include: IncEx = None, + exclude: IncEx = None, + by_alias: bool = False, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + encoder: typing.Callable[[Any], Any] | None = PydanticUndefined, # type: ignore[assignment] + models_as_dict: bool = PydanticUndefined, # type: ignore[assignment] **dumps_kwargs: Any, ) -> str: - """ - Generate a JSON representation of the model, `include` and `exclude` arguments as per `dict()`. - - `encoder` is an optional function to supply as `default` to json.dumps(), other arguments as per `json.dumps()`. - """ - if skip_defaults is not None: - warnings.warn( - f'{self.__class__.__name__}.json(): "skip_defaults" is deprecated and replaced by "exclude_unset"', - DeprecationWarning, - ) - exclude_unset = skip_defaults - encoder = cast(Callable[[Any], Any], encoder or self.__json_encoder__) - - # We don't directly call `self.dict()`, which does exactly this with `to_dict=True` - # because we want to be able to keep raw `BaseModel` instances and not as `dict`. - # This allows users to write custom JSON encoders for given `BaseModel` classes. - data = dict( - self._iter( - to_dict=models_as_dict, - by_alias=by_alias, - include=include, - exclude=exclude, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) + warnings.warn( + 'The `json` method is deprecated; use `model_dump_json` instead.', category=PydanticDeprecatedSince20 + ) + if encoder is not PydanticUndefined: + raise TypeError('The `encoder` argument is no longer supported; use field serializers instead.') + if models_as_dict is not PydanticUndefined: + raise TypeError('The `models_as_dict` argument is no longer supported; use a model serializer instead.') + if dumps_kwargs: + raise TypeError('`dumps_kwargs` keyword arguments are no longer supported.') + return self.model_dump_json( + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, ) - if self.__custom_root_type__: - data = data[ROOT_KEY] - return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs) @classmethod - def _enforce_dict_if_root(cls, obj: Any) -> Any: - if cls.__custom_root_type__ and ( - not (isinstance(obj, dict) and obj.keys() == {ROOT_KEY}) - or cls.__fields__[ROOT_KEY].shape in MAPPING_LIKE_SHAPES - ): - return {ROOT_KEY: obj} - else: - return obj + @typing_extensions.deprecated('The `parse_obj` method is deprecated; use `model_validate` instead.', category=None) + def parse_obj(cls: type[Model], obj: Any) -> Model: # noqa: D102 + warnings.warn( + 'The `parse_obj` method is deprecated; use `model_validate` instead.', category=PydanticDeprecatedSince20 + ) + return cls.model_validate(obj) @classmethod - def parse_obj(cls: Type['Model'], obj: Any) -> 'Model': - obj = cls._enforce_dict_if_root(obj) - if not isinstance(obj, dict): - try: - obj = dict(obj) - except (TypeError, ValueError) as e: - exc = TypeError(f'{cls.__name__} expected dict not {obj.__class__.__name__}') - raise ValidationError([ErrorWrapper(exc, loc=ROOT_KEY)], cls) from e - return cls(**obj) - - @classmethod - def parse_raw( - cls: Type['Model'], - b: StrBytes, + @typing_extensions.deprecated( + 'The `parse_raw` method is deprecated; if your data is JSON use `model_validate_json`, ' + 'otherwise load the data then use 
`model_validate` instead.', + category=None, + ) + def parse_raw( # noqa: D102 + cls: type[Model], + b: str | bytes, *, - content_type: str = None, + content_type: str | None = None, encoding: str = 'utf8', - proto: Protocol = None, + proto: DeprecatedParseProtocol | None = None, allow_pickle: bool = False, - ) -> 'Model': + ) -> Model: # pragma: no cover + warnings.warn( + 'The `parse_raw` method is deprecated; if your data is JSON use `model_validate_json`, ' + 'otherwise load the data then use `model_validate` instead.', + category=PydanticDeprecatedSince20, + ) + from .deprecated import parse + try: - obj = load_str_bytes( + obj = parse.load_str_bytes( b, proto=proto, content_type=content_type, encoding=encoding, allow_pickle=allow_pickle, - json_loads=cls.__config__.json_loads, ) - except (ValueError, TypeError, UnicodeDecodeError) as e: - raise ValidationError([ErrorWrapper(e, loc=ROOT_KEY)], cls) - return cls.parse_obj(obj) + except (ValueError, TypeError) as exc: + import json + + # try to match V1 + if isinstance(exc, UnicodeDecodeError): + type_str = 'value_error.unicodedecode' + elif isinstance(exc, json.JSONDecodeError): + type_str = 'value_error.jsondecode' + elif isinstance(exc, ValueError): + type_str = 'value_error' + else: + type_str = 'type_error' + + # ctx is missing here, but since we've added `input` to the error, we're not pretending it's the same + error: pydantic_core.InitErrorDetails = { + # The type: ignore on the next line is to ignore the requirement of LiteralString + 'type': pydantic_core.PydanticCustomError(type_str, str(exc)), # type: ignore + 'loc': ('__root__',), + 'input': b, + } + raise pydantic_core.ValidationError.from_exception_data(cls.__name__, [error]) + return cls.model_validate(obj) @classmethod - def parse_file( - cls: Type['Model'], - path: Union[str, Path], + @typing_extensions.deprecated( + 'The `parse_file` method is deprecated; load the data from file, then if your data is JSON ' + 'use `model_validate_json`, otherwise `model_validate` instead.', + category=None, + ) + def parse_file( # noqa: D102 + cls: type[Model], + path: str | Path, *, - content_type: str = None, + content_type: str | None = None, encoding: str = 'utf8', - proto: Protocol = None, + proto: DeprecatedParseProtocol | None = None, allow_pickle: bool = False, - ) -> 'Model': - obj = load_file( + ) -> Model: + warnings.warn( + 'The `parse_file` method is deprecated; load the data from file, then if your data is JSON ' + 'use `model_validate_json`, otherwise `model_validate` instead.', + category=PydanticDeprecatedSince20, + ) + from .deprecated import parse + + obj = parse.load_file( path, proto=proto, content_type=content_type, encoding=encoding, allow_pickle=allow_pickle, - json_loads=cls.__config__.json_loads, ) return cls.parse_obj(obj) @classmethod - def from_orm(cls: Type['Model'], obj: Any) -> 'Model': - if not cls.__config__.orm_mode: - raise ConfigError('You must have the config attribute orm_mode=True to use from_orm') - obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj) - m = cls.__new__(cls) - values, fields_set, validation_error = validate_model(cls, obj) - if validation_error: - raise validation_error - object_setattr(m, '__dict__', values) - object_setattr(m, '__fields_set__', fields_set) - m._init_private_attributes() - return m + @typing_extensions.deprecated( + 'The `from_orm` method is deprecated; set ' + "`model_config['from_attributes']=True` and use `model_validate` instead.", + category=None, + ) + def from_orm(cls: 
type[Model], obj: Any) -> Model: # noqa: D102 + warnings.warn( + 'The `from_orm` method is deprecated; set ' + "`model_config['from_attributes']=True` and use `model_validate` instead.", + category=PydanticDeprecatedSince20, + ) + if not cls.model_config.get('from_attributes', None): + raise PydanticUserError( + 'You must set the config attribute `from_attributes=True` to use from_orm', code=None + ) + return cls.model_validate(obj) @classmethod - def construct(cls: Type['Model'], _fields_set: Optional['SetStr'] = None, **values: Any) -> 'Model': - """ - Creates a new model setting __dict__ and __fields_set__ from trusted or pre-validated data. - Default values are respected, but no other validation is performed. - Behaves as if `Config.extra = 'allow'` was set since it adds all passed values - """ - m = cls.__new__(cls) - fields_values: Dict[str, Any] = {} - for name, field in cls.__fields__.items(): - if field.alt_alias and field.alias in values: - fields_values[name] = values[field.alias] - elif name in values: - fields_values[name] = values[name] - elif not field.required: - fields_values[name] = field.get_default() - fields_values.update(values) - object_setattr(m, '__dict__', fields_values) - if _fields_set is None: - _fields_set = set(values.keys()) - object_setattr(m, '__fields_set__', _fields_set) - m._init_private_attributes() - return m - - def _copy_and_set_values(self: 'Model', values: 'DictStrAny', fields_set: 'SetStr', *, deep: bool) -> 'Model': - if deep: - # chances of having empty dict here are quite low for using smart_deepcopy - values = deepcopy(values) - - cls = self.__class__ - m = cls.__new__(cls) - object_setattr(m, '__dict__', values) - object_setattr(m, '__fields_set__', fields_set) - for name in self.__private_attributes__: - value = getattr(self, name, Undefined) - if value is not Undefined: - if deep: - value = deepcopy(value) - object_setattr(m, name, value) - - return m + @typing_extensions.deprecated('The `construct` method is deprecated; use `model_construct` instead.', category=None) + def construct(cls: type[Model], _fields_set: set[str] | None = None, **values: Any) -> Model: # noqa: D102 + warnings.warn( + 'The `construct` method is deprecated; use `model_construct` instead.', category=PydanticDeprecatedSince20 + ) + return cls.model_construct(_fields_set=_fields_set, **values) + @typing_extensions.deprecated( + 'The `copy` method is deprecated; use `model_copy` instead. ' + 'See the docstring of `BaseModel.copy` for details about how to handle `include` and `exclude`.', + category=None, + ) def copy( - self: 'Model', + self: Model, *, - include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, - exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, - update: Optional['DictStrAny'] = None, + include: AbstractSetIntStr | MappingIntStrAny | None = None, + exclude: AbstractSetIntStr | MappingIntStrAny | None = None, + update: typing.Dict[str, Any] | None = None, # noqa UP006 deep: bool = False, - ) -> 'Model': - """ - Duplicate a model, optionally choose which fields to include, exclude and change. + ) -> Model: # pragma: no cover + """Returns a copy of the model. - :param include: fields to include in new model - :param exclude: fields to exclude from new model, as with values this takes precedence over include - :param update: values to change/add in the new model. 
Note: the data is not validated before creating - the new model: you should trust this data - :param deep: set to `True` to make a deep copy of the model - :return: new model instance + !!! warning "Deprecated" + This method is now deprecated; use `model_copy` instead. + + If you need `include` or `exclude`, use: + + ```py + data = self.model_dump(include=include, exclude=exclude, round_trip=True) + data = {**data, **(update or {})} + copied = self.model_validate(data) + ``` + + Args: + include: Optional set or mapping specifying which fields to include in the copied model. + exclude: Optional set or mapping specifying which fields to exclude in the copied model. + update: Optional dictionary of field-value pairs to override field values in the copied model. + deep: If True, the values of fields that are Pydantic models will be deep-copied. + + Returns: + A copy of the model with included, excluded and updated fields as specified. """ + warnings.warn( + 'The `copy` method is deprecated; use `model_copy` instead. ' + 'See the docstring of `BaseModel.copy` for details about how to handle `include` and `exclude`.', + category=PydanticDeprecatedSince20, + ) + from .deprecated import copy_internals values = dict( - self._iter(to_dict=False, by_alias=False, include=include, exclude=exclude, exclude_unset=False), + copy_internals._iter( + self, to_dict=False, by_alias=False, include=include, exclude=exclude, exclude_unset=False + ), **(update or {}), ) + if self.__pydantic_private__ is None: + private = None + else: + private = {k: v for k, v in self.__pydantic_private__.items() if v is not PydanticUndefined} - # new `__fields_set__` can have unset optional fields with a set value in `update` kwarg + if self.__pydantic_extra__ is None: + extra: dict[str, Any] | None = None + else: + extra = self.__pydantic_extra__.copy() + for k in list(self.__pydantic_extra__): + if k not in values: # k was in the exclude + extra.pop(k) + for k in list(values): + if k in self.__pydantic_extra__: # k must have come from extra + extra[k] = values.pop(k) + + # new `__pydantic_fields_set__` can have unset optional fields with a set value in `update` kwarg if update: - fields_set = self.__fields_set__ | update.keys() + fields_set = self.__pydantic_fields_set__ | update.keys() else: - fields_set = set(self.__fields_set__) - - return self._copy_and_set_values(values, fields_set, deep=deep) - - @classmethod - def schema(cls, by_alias: bool = True, ref_template: str = default_ref_template) -> 'DictStrAny': - cached = cls.__schema_cache__.get((by_alias, ref_template)) - if cached is not None: - return cached - s = model_schema(cls, by_alias=by_alias, ref_template=ref_template) - cls.__schema_cache__[(by_alias, ref_template)] = s - return s - - @classmethod - def schema_json( - cls, *, by_alias: bool = True, ref_template: str = default_ref_template, **dumps_kwargs: Any - ) -> str: - from .json import pydantic_encoder - - return cls.__config__.json_dumps( - cls.schema(by_alias=by_alias, ref_template=ref_template), default=pydantic_encoder, **dumps_kwargs - ) - - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield cls.validate - - @classmethod - def validate(cls: Type['Model'], value: Any) -> 'Model': - if isinstance(value, cls): - copy_on_model_validation = cls.__config__.copy_on_model_validation - # whether to deep or shallow copy the model on validation, None means do not copy - deep_copy: Optional[bool] = None - if copy_on_model_validation not in {'deep', 'shallow', 'none'}: - # Warn about 
deprecated behavior - warnings.warn( - "`copy_on_model_validation` should be a string: 'deep', 'shallow' or 'none'", DeprecationWarning - ) - if copy_on_model_validation: - deep_copy = False - - if copy_on_model_validation == 'shallow': - # shallow copy - deep_copy = False - elif copy_on_model_validation == 'deep': - # deep copy - deep_copy = True - - if deep_copy is None: - return value - else: - return value._copy_and_set_values(value.__dict__, value.__fields_set__, deep=deep_copy) - - value = cls._enforce_dict_if_root(value) - - if isinstance(value, dict): - return cls(**value) - elif cls.__config__.orm_mode: - return cls.from_orm(value) - else: - try: - value_as_dict = dict(value) - except (TypeError, ValueError) as e: - raise DictError() from e - return cls(**value_as_dict) - - @classmethod - def _decompose_class(cls: Type['Model'], obj: Any) -> GetterDict: - if isinstance(obj, GetterDict): - return obj - return cls.__config__.getter_dict(obj) - - @classmethod - @no_type_check - def _get_value( - cls, - v: Any, - to_dict: bool, - by_alias: bool, - include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']], - exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']], - exclude_unset: bool, - exclude_defaults: bool, - exclude_none: bool, - ) -> Any: - - if isinstance(v, BaseModel): - if to_dict: - v_dict = v.dict( - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - include=include, - exclude=exclude, - exclude_none=exclude_none, - ) - if ROOT_KEY in v_dict: - return v_dict[ROOT_KEY] - return v_dict - else: - return v.copy(include=include, exclude=exclude) - - value_exclude = ValueItems(v, exclude) if exclude else None - value_include = ValueItems(v, include) if include else None - - if isinstance(v, dict): - return { - k_: cls._get_value( - v_, - to_dict=to_dict, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - include=value_include and value_include.for_element(k_), - exclude=value_exclude and value_exclude.for_element(k_), - exclude_none=exclude_none, - ) - for k_, v_ in v.items() - if (not value_exclude or not value_exclude.is_excluded(k_)) - and (not value_include or value_include.is_included(k_)) - } - - elif sequence_like(v): - seq_args = ( - cls._get_value( - v_, - to_dict=to_dict, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - include=value_include and value_include.for_element(i), - exclude=value_exclude and value_exclude.for_element(i), - exclude_none=exclude_none, - ) - for i, v_ in enumerate(v) - if (not value_exclude or not value_exclude.is_excluded(i)) - and (not value_include or value_include.is_included(i)) - ) - - return v.__class__(*seq_args) if is_namedtuple(v.__class__) else v.__class__(seq_args) - - elif isinstance(v, Enum) and getattr(cls.Config, 'use_enum_values', False): - return v.value - - else: - return v - - @classmethod - def __try_update_forward_refs__(cls, **localns: Any) -> None: - """ - Same as update_forward_refs but will not raise exception - when forward references are not defined. - """ - update_model_forward_refs(cls, cls.__fields__.values(), cls.__config__.json_encoders, localns, (NameError,)) - - @classmethod - def update_forward_refs(cls, **localns: Any) -> None: - """ - Try to update ForwardRefs on fields based on this Model, globalns and localns. 
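For readers tracking the v1-to-v2 renames that these hunks wire up, a minimal sketch of the replacement calls suggested by the deprecation messages (the `User` model and its fields are invented for illustration and are not part of the patch):

```py
from pydantic import BaseModel

class User(BaseModel):
    # illustrative model, not part of the patch
    model_config = {'from_attributes': True}
    id: int
    name: str = 'Jane'

u = User(id=1)
u.model_dump_json()                    # replaces the deprecated u.json()
User.model_validate({'id': 2})         # replaces User.parse_obj(...)
User.model_validate_json('{"id": 3}')  # replaces User.parse_raw(...) for JSON input
User.model_construct(id=4)             # replaces User.construct(...)
u.model_copy(update={'name': 'Joe'})   # replaces u.copy(...)
User.model_json_schema()               # replaces User.schema()
User.model_rebuild(force=True)         # replaces User.update_forward_refs()
```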
- """ - update_model_forward_refs(cls, cls.__fields__.values(), cls.__config__.json_encoders, localns) - - def __iter__(self) -> 'TupleGenerator': - """ - so `dict(model)` works - """ - yield from self.__dict__.items() - - def _iter( - self, - to_dict: bool = False, - by_alias: bool = False, - include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, - exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - ) -> 'TupleGenerator': - - # Merge field set excludes with explicit exclude parameter with explicit overriding field set options. - # The extra "is not None" guards are not logically necessary but optimizes performance for the simple case. - if exclude is not None or self.__exclude_fields__ is not None: - exclude = ValueItems.merge(self.__exclude_fields__, exclude) - - if include is not None or self.__include_fields__ is not None: - include = ValueItems.merge(self.__include_fields__, include, intersect=True) - - allowed_keys = self._calculate_keys( - include=include, exclude=exclude, exclude_unset=exclude_unset # type: ignore - ) - if allowed_keys is None and not (to_dict or by_alias or exclude_unset or exclude_defaults or exclude_none): - # huge boost for plain _iter() - yield from self.__dict__.items() - return - - value_exclude = ValueItems(self, exclude) if exclude is not None else None - value_include = ValueItems(self, include) if include is not None else None - - for field_key, v in self.__dict__.items(): - if (allowed_keys is not None and field_key not in allowed_keys) or (exclude_none and v is None): - continue - - if exclude_defaults: - model_field = self.__fields__.get(field_key) - if not getattr(model_field, 'required', True) and getattr(model_field, 'default', _missing) == v: - continue - - if by_alias and field_key in self.__fields__: - dict_key = self.__fields__[field_key].alias - else: - dict_key = field_key - - if to_dict or value_include or value_exclude: - v = self._get_value( - v, - to_dict=to_dict, - by_alias=by_alias, - include=value_include and value_include.for_element(field_key), - exclude=value_exclude and value_exclude.for_element(field_key), - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - ) - yield dict_key, v - - def _calculate_keys( - self, - include: Optional['MappingIntStrAny'], - exclude: Optional['MappingIntStrAny'], - exclude_unset: bool, - update: Optional['DictStrAny'] = None, - ) -> Optional[AbstractSet[str]]: - if include is None and exclude is None and exclude_unset is False: - return None - - keys: AbstractSet[str] - if exclude_unset: - keys = self.__fields_set__.copy() - else: - keys = self.__dict__.keys() - - if include is not None: - keys &= include.keys() - - if update: - keys -= update.keys() + fields_set = set(self.__pydantic_fields_set__) + # removing excluded fields from `__pydantic_fields_set__` if exclude: - keys -= {k for k, v in exclude.items() if ValueItems.is_true(v)} + fields_set -= set(exclude) - return keys + return copy_internals._copy_and_set_values(self, values, fields_set, extra, private, deep=deep) - def __eq__(self, other: Any) -> bool: - if isinstance(other, BaseModel): - return self.dict() == other.dict() - else: - return self.dict() == other + @classmethod + @typing_extensions.deprecated('The `schema` method is deprecated; use `model_json_schema` instead.', category=None) + def schema( # noqa: D102 + cls, by_alias: bool = True, ref_template: 
str = DEFAULT_REF_TEMPLATE + ) -> typing.Dict[str, Any]: # noqa UP006 + warnings.warn( + 'The `schema` method is deprecated; use `model_json_schema` instead.', category=PydanticDeprecatedSince20 + ) + return cls.model_json_schema(by_alias=by_alias, ref_template=ref_template) - def __repr_args__(self) -> 'ReprArgs': - return [ - (k, v) - for k, v in self.__dict__.items() - if k not in DUNDER_ATTRIBUTES and (k not in self.__fields__ or self.__fields__[k].field_info.repr) - ] + @classmethod + @typing_extensions.deprecated( + 'The `schema_json` method is deprecated; use `model_json_schema` and json.dumps instead.', + category=None, + ) + def schema_json( # noqa: D102 + cls, *, by_alias: bool = True, ref_template: str = DEFAULT_REF_TEMPLATE, **dumps_kwargs: Any + ) -> str: # pragma: no cover + warnings.warn( + 'The `schema_json` method is deprecated; use `model_json_schema` and json.dumps instead.', + category=PydanticDeprecatedSince20, + ) + import json + + from .deprecated.json import pydantic_encoder + + return json.dumps( + cls.model_json_schema(by_alias=by_alias, ref_template=ref_template), + default=pydantic_encoder, + **dumps_kwargs, + ) + + @classmethod + @typing_extensions.deprecated('The `validate` method is deprecated; use `model_validate` instead.', category=None) + def validate(cls: type[Model], value: Any) -> Model: # noqa: D102 + warnings.warn( + 'The `validate` method is deprecated; use `model_validate` instead.', category=PydanticDeprecatedSince20 + ) + return cls.model_validate(value) + + @classmethod + @typing_extensions.deprecated( + 'The `update_forward_refs` method is deprecated; use `model_rebuild` instead.', + category=None, + ) + def update_forward_refs(cls, **localns: Any) -> None: # noqa: D102 + warnings.warn( + 'The `update_forward_refs` method is deprecated; use `model_rebuild` instead.', + category=PydanticDeprecatedSince20, + ) + if localns: # pragma: no cover + raise TypeError('`localns` arguments are not longer accepted.') + cls.model_rebuild(force=True) + + @typing_extensions.deprecated( + 'The private method `_iter` will be removed and should no longer be used.', category=None + ) + def _iter(self, *args: Any, **kwargs: Any) -> Any: + warnings.warn( + 'The private method `_iter` will be removed and should no longer be used.', + category=PydanticDeprecatedSince20, + ) + from .deprecated import copy_internals + + return copy_internals._iter(self, *args, **kwargs) + + @typing_extensions.deprecated( + 'The private method `_copy_and_set_values` will be removed and should no longer be used.', + category=None, + ) + def _copy_and_set_values(self, *args: Any, **kwargs: Any) -> Any: + warnings.warn( + 'The private method `_copy_and_set_values` will be removed and should no longer be used.', + category=PydanticDeprecatedSince20, + ) + from .deprecated import copy_internals + + return copy_internals._copy_and_set_values(self, *args, **kwargs) + + @classmethod + @typing_extensions.deprecated( + 'The private method `_get_value` will be removed and should no longer be used.', + category=None, + ) + def _get_value(cls, *args: Any, **kwargs: Any) -> Any: + warnings.warn( + 'The private method `_get_value` will be removed and should no longer be used.', + category=PydanticDeprecatedSince20, + ) + from .deprecated import copy_internals + + return copy_internals._get_value(cls, *args, **kwargs) + + @typing_extensions.deprecated( + 'The private method `_calculate_keys` will be removed and should no longer be used.', + category=None, + ) + def _calculate_keys(self, *args: Any, 
**kwargs: Any) -> Any: + warnings.warn( + 'The private method `_calculate_keys` will be removed and should no longer be used.', + category=PydanticDeprecatedSince20, + ) + from .deprecated import copy_internals + + return copy_internals._calculate_keys(self, *args, **kwargs) -_is_base_model_class_defined = True - - -@overload +@typing.overload def create_model( __model_name: str, *, - __config__: Optional[Type[BaseConfig]] = None, + __config__: ConfigDict | None = None, + __doc__: str | None = None, __base__: None = None, __module__: str = __name__, - __validators__: Dict[str, 'AnyClassMethod'] = None, - __cls_kwargs__: Dict[str, Any] = None, + __validators__: dict[str, classmethod] | None = None, + __cls_kwargs__: dict[str, Any] | None = None, **field_definitions: Any, -) -> Type['BaseModel']: +) -> type[BaseModel]: ... -@overload +@typing.overload def create_model( __model_name: str, *, - __config__: Optional[Type[BaseConfig]] = None, - __base__: Union[Type['Model'], Tuple[Type['Model'], ...]], + __config__: ConfigDict | None = None, + __doc__: str | None = None, + __base__: type[Model] | tuple[type[Model], ...], __module__: str = __name__, - __validators__: Dict[str, 'AnyClassMethod'] = None, - __cls_kwargs__: Dict[str, Any] = None, + __validators__: dict[str, classmethod] | None = None, + __cls_kwargs__: dict[str, Any] | None = None, **field_definitions: Any, -) -> Type['Model']: +) -> type[Model]: ... -def create_model( +def create_model( # noqa: C901 __model_name: str, *, - __config__: Optional[Type[BaseConfig]] = None, - __base__: Union[None, Type['Model'], Tuple[Type['Model'], ...]] = None, - __module__: str = __name__, - __validators__: Dict[str, 'AnyClassMethod'] = None, - __cls_kwargs__: Dict[str, Any] = None, - __slots__: Optional[Tuple[str, ...]] = None, + __config__: ConfigDict | None = None, + __doc__: str | None = None, + __base__: type[Model] | tuple[type[Model], ...] | None = None, + __module__: str | None = None, + __validators__: dict[str, classmethod] | None = None, + __cls_kwargs__: dict[str, Any] | None = None, + __slots__: tuple[str, ...] | None = None, **field_definitions: Any, -) -> Type['Model']: - """ - Dynamically create a model. - :param __model_name: name of the created model - :param __config__: config class to use for the new model - :param __base__: base class for the new model to inherit from - :param __module__: module of the created model - :param __validators__: a dict of method names and @validator class methods - :param __cls_kwargs__: a dict for class creation - :param __slots__: Deprecated, `__slots__` should not be passed to `create_model` - :param field_definitions: fields of the model (or extra fields if a base is supplied) - in the format `=(, )` or `=, e.g. - `foobar=(str, ...)` or `foobar=123`, or, for complex use-cases, in the format - `=` or `=(, )`, e.g. - `foo=Field(datetime, default_factory=datetime.utcnow, alias='bar')` or - `foo=(str, FieldInfo(title='Foo'))` +) -> type[Model]: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/models/#dynamic-model-creation + + Dynamically creates and returns a new Pydantic model, in other words, `create_model` dynamically creates a + subclass of [`BaseModel`][pydantic.BaseModel]. + + Args: + __model_name: The name of the newly created model. + __config__: The configuration of the new model. + __doc__: The docstring of the new model. + __base__: The base class or classes for the new model. 
+ __module__: The name of the module that the model belongs to; + if `None`, the value is taken from `sys._getframe(1)` + __validators__: A dictionary of methods that validate fields. + __cls_kwargs__: A dictionary of keyword arguments for class creation, such as `metaclass`. + __slots__: Deprecated. Should not be passed to `create_model`. + **field_definitions: Attributes of the new model. They should be passed in the format: + `<name>=(<type>, <default value>)` or `<name>=(<type>, <FieldInfo>)`. + + Returns: + The new [model][pydantic.BaseModel]. + + Raises: + PydanticUserError: If `__base__` and `__config__` are both passed. """ if __slots__ is not None: # __slots__ will be ignored from here on @@ -982,11 +1436,14 @@ def create_model( if __base__ is not None: if __config__ is not None: - raise ConfigError('to avoid confusion __config__ and __base__ cannot be used together') + raise PydanticUserError( + 'to avoid confusion `__config__` and `__base__` cannot be used together', + code='create-model-config-base', + ) if not isinstance(__base__, tuple): __base__ = (__base__,) else: - __base__ = (cast(Type['Model'], BaseModel),) + __base__ = (typing.cast(typing.Type['Model'], BaseModel),) __cls_kwargs__ = __cls_kwargs__ or {} @@ -994,16 +1451,16 @@ def create_model( annotations = {} for f_name, f_def in field_definitions.items(): - if not is_valid_field(f_name): + if not _fields.is_valid_field_name(f_name): warnings.warn(f'fields may not start with an underscore, ignoring "{f_name}"', RuntimeWarning) if isinstance(f_def, tuple): + f_def = typing.cast('tuple[str, Any]', f_def) try: f_annotation, f_value = f_def except ValueError as e: - raise ConfigError( - 'field definitions should either be a tuple of (<type>, <default>) or just a ' - 'default value, unfortunately this means tuples as ' - 'default values are not allowed' + raise PydanticUserError( + 'Field definitions should be a `(<type>, <default>)`.', + code='create-model-field-definitions', ) from e else: f_annotation, f_value = None, f_def @@ -1012,98 +1469,32 @@ def create_model( annotations[f_name] = f_annotation fields[f_name] = f_value - namespace: 'DictStrAny' = {'__annotations__': annotations, '__module__': __module__} + if __module__ is None: + f = sys._getframe(1) + __module__ = f.f_globals['__name__'] + + namespace: dict[str, Any] = {'__annotations__': annotations, '__module__': __module__} + if __doc__: + namespace.update({'__doc__': __doc__}) if __validators__: namespace.update(__validators__) namespace.update(fields) if __config__: - namespace['Config'] = inherit_config(__config__, BaseConfig) - resolved_bases = resolve_bases(__base__) - meta, ns, kwds = prepare_class(__model_name, resolved_bases, kwds=__cls_kwargs__) + namespace['model_config'] = _config.ConfigWrapper(__config__).config_dict + resolved_bases = types.resolve_bases(__base__) + meta, ns, kwds = types.prepare_class(__model_name, resolved_bases, kwds=__cls_kwargs__) if resolved_bases is not __base__: ns['__orig_bases__'] = __base__ namespace.update(ns) - return meta(__model_name, resolved_bases, namespace, **kwds) + + return meta( + __model_name, + resolved_bases, + namespace, + __pydantic_reset_parent_namespace__=False, + _create_model_module=__module__, + **kwds, + ) -_missing = object() - - -def validate_model( # noqa: C901 (ignore complexity) - model: Type[BaseModel], input_data: 'DictStrAny', cls: 'ModelOrDc' = None -) -> Tuple['DictStrAny', 'SetStr', Optional[ValidationError]]: - """ - validate data against a model. 
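A quick usage sketch of the `field_definitions` format documented above; the model and field names are invented for the example and the exact behaviour follows pydantic v2 semantics as described in the docstring:

```py
from pydantic import Field, create_model

# Each keyword is either `<name>=(<type>, <default>)` or `<name>=(<type>, FieldInfo)`.
Foo = create_model(
    'Foo',
    bar=(str, ...),                  # required field
    baz=(int, 42),                   # field with a default value
    qux=(str, Field(alias='quux')),  # field configured via FieldInfo
)

print(Foo(bar='hello', quux='aliased'))
```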
- """ - values = {} - errors = [] - # input_data names, possibly alias - names_used = set() - # field names, never aliases - fields_set = set() - config = model.__config__ - check_extra = config.extra is not Extra.ignore - cls_ = cls or model - - for validator in model.__pre_root_validators__: - try: - input_data = validator(cls_, input_data) - except (ValueError, TypeError, AssertionError) as exc: - return {}, set(), ValidationError([ErrorWrapper(exc, loc=ROOT_KEY)], cls_) - - for name, field in model.__fields__.items(): - value = input_data.get(field.alias, _missing) - using_name = False - if value is _missing and config.allow_population_by_field_name and field.alt_alias: - value = input_data.get(field.name, _missing) - using_name = True - - if value is _missing: - if field.required: - errors.append(ErrorWrapper(MissingError(), loc=field.alias)) - continue - - value = field.get_default() - - if not config.validate_all and not field.validate_always: - values[name] = value - continue - else: - fields_set.add(name) - if check_extra: - names_used.add(field.name if using_name else field.alias) - - v_, errors_ = field.validate(value, values, loc=field.alias, cls=cls_) - if isinstance(errors_, ErrorWrapper): - errors.append(errors_) - elif isinstance(errors_, list): - errors.extend(errors_) - else: - values[name] = v_ - - if check_extra: - if isinstance(input_data, GetterDict): - extra = input_data.extra_keys() - names_used - else: - extra = input_data.keys() - names_used - if extra: - fields_set |= extra - if config.extra is Extra.allow: - for f in extra: - values[f] = input_data[f] - else: - for f in sorted(extra): - errors.append(ErrorWrapper(ExtraError(), loc=f)) - - for skip_on_failure, validator in model.__post_root_validators__: - if skip_on_failure and errors: - continue - try: - values = validator(cls_, values) - except (ValueError, TypeError, AssertionError) as exc: - errors.append(ErrorWrapper(exc, loc=ROOT_KEY)) - - if errors: - return values, fields_set, ValidationError(errors, cls_) - else: - return values, fields_set, None +__getattr__ = getattr_migration(__name__) diff --git a/lib/pydantic/mypy.py b/lib/pydantic/mypy.py index 6bd9db18..0e70eab5 100644 --- a/lib/pydantic/mypy.py +++ b/lib/pydantic/mypy.py @@ -1,8 +1,13 @@ +"""This module includes classes and functions designed specifically for use with the mypy plugin.""" + +from __future__ import annotations + import sys from configparser import ConfigParser -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type as TypingType, Union +from typing import Any, Callable, Iterator from mypy.errorcodes import ErrorCode +from mypy.expandtype import expand_type, expand_type_by_instance from mypy.nodes import ( ARG_NAMED, ARG_NAMED_OPT, @@ -17,21 +22,23 @@ from mypy.nodes import ( ClassDef, Context, Decorator, + DictExpr, EllipsisExpr, - FuncBase, + Expression, FuncDef, + IfStmt, JsonDict, MemberExpr, NameExpr, PassStmt, PlaceholderNode, RefExpr, + Statement, StrExpr, - SymbolNode, SymbolTableNode, TempNode, + TypeAlias, TypeInfo, - TypeVarExpr, Var, ) from mypy.options import Options @@ -41,11 +48,17 @@ from mypy.plugin import ( FunctionContext, MethodContext, Plugin, + ReportConfigContext, SemanticAnalyzerPluginInterface, ) from mypy.plugins import dataclasses -from mypy.semanal import set_callable_name # type: ignore +from mypy.plugins.common import ( + deserialize_and_fixup_type, +) +from mypy.semanal import set_callable_name from mypy.server.trigger import make_wildcard_trigger +from mypy.state import state +from 
mypy.typeops import map_type_from_supertype from mypy.types import ( AnyType, CallableType, @@ -63,76 +76,131 @@ from mypy.typevars import fill_typevars from mypy.util import get_unique_redefinition_name from mypy.version import __version__ as mypy_version -from pydantic.utils import is_valid_field +from pydantic._internal import _fields +from pydantic.version import parse_mypy_version try: from mypy.types import TypeVarDef # type: ignore[attr-defined] except ImportError: # pragma: no cover - # Backward-compatible with TypeVarDef from Mypy 0.910. + # Backward-compatible with TypeVarDef from Mypy 0.930. from mypy.types import TypeVarType as TypeVarDef CONFIGFILE_KEY = 'pydantic-mypy' METADATA_KEY = 'pydantic-mypy-metadata' BASEMODEL_FULLNAME = 'pydantic.main.BaseModel' -BASESETTINGS_FULLNAME = 'pydantic.env_settings.BaseSettings' +BASESETTINGS_FULLNAME = 'pydantic_settings.main.BaseSettings' +ROOT_MODEL_FULLNAME = 'pydantic.root_model.RootModel' +MODEL_METACLASS_FULLNAME = 'pydantic._internal._model_construction.ModelMetaclass' FIELD_FULLNAME = 'pydantic.fields.Field' DATACLASS_FULLNAME = 'pydantic.dataclasses.dataclass' - - -def parse_mypy_version(version: str) -> Tuple[int, ...]: - return tuple(int(part) for part in version.split('+', 1)[0].split('.')) +MODEL_VALIDATOR_FULLNAME = 'pydantic.functional_validators.model_validator' +DECORATOR_FULLNAMES = { + 'pydantic.functional_validators.field_validator', + 'pydantic.functional_validators.model_validator', + 'pydantic.functional_serializers.serializer', + 'pydantic.functional_serializers.model_serializer', + 'pydantic.deprecated.class_validators.validator', + 'pydantic.deprecated.class_validators.root_validator', +} MYPY_VERSION_TUPLE = parse_mypy_version(mypy_version) BUILTINS_NAME = 'builtins' if MYPY_VERSION_TUPLE >= (0, 930) else '__builtins__' +# Increment version if plugin changes and mypy caches should be invalidated +__version__ = 2 -def plugin(version: str) -> 'TypingType[Plugin]': - """ - `version` is the mypy version string + +def plugin(version: str) -> type[Plugin]: + """`version` is the mypy version string. We might want to use this to print a warning if the mypy version being used is newer, or especially older, than we expect (or need). + + Args: + version: The mypy version string. + + Return: + The Pydantic mypy plugin type. 
""" return PydanticPlugin +class _DeferAnalysis(Exception): + pass + + class PydanticPlugin(Plugin): + """The Pydantic mypy plugin.""" + def __init__(self, options: Options) -> None: self.plugin_config = PydanticPluginConfig(options) + self._plugin_data = self.plugin_config.to_data() super().__init__(options) - def get_base_class_hook(self, fullname: str) -> 'Optional[Callable[[ClassDefContext], None]]': + def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], bool] | None: + """Update Pydantic model class.""" sym = self.lookup_fully_qualified(fullname) if sym and isinstance(sym.node, TypeInfo): # pragma: no branch # No branching may occur if the mypy cache has not been cleared - if any(get_fullname(base) == BASEMODEL_FULLNAME for base in sym.node.mro): + if any(base.fullname == BASEMODEL_FULLNAME for base in sym.node.mro): return self._pydantic_model_class_maker_callback return None - def get_function_hook(self, fullname: str) -> 'Optional[Callable[[FunctionContext], Type]]': + def get_metaclass_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None: + """Update Pydantic `ModelMetaclass` definition.""" + if fullname == MODEL_METACLASS_FULLNAME: + return self._pydantic_model_metaclass_marker_callback + return None + + def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None: + """Adjust the return type of the `Field` function.""" sym = self.lookup_fully_qualified(fullname) if sym and sym.fullname == FIELD_FULLNAME: return self._pydantic_field_callback return None - def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], Type]]: + def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None: + """Adjust return type of `from_orm` method call.""" if fullname.endswith('.from_orm'): - return from_orm_callback + return from_attributes_callback return None - def get_class_decorator_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]: - if fullname == DATACLASS_FULLNAME: + def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None: + """Mark pydantic.dataclasses as dataclass. + + Mypy version 1.1.1 added support for `@dataclass_transform` decorator. + """ + if fullname == DATACLASS_FULLNAME and MYPY_VERSION_TUPLE < (1, 1): return dataclasses.dataclass_class_maker_callback # type: ignore[return-value] return None - def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> None: - transformer = PydanticModelTransformer(ctx, self.plugin_config) - transformer.transform() + def report_config_data(self, ctx: ReportConfigContext) -> dict[str, Any]: + """Return all plugin config data. - def _pydantic_field_callback(self, ctx: FunctionContext) -> 'Type': + Used by mypy to determine if cache needs to be discarded. """ - Extract the type of the `default` argument from the Field function, and use it as the return type. + return self._plugin_data + + def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> bool: + transformer = PydanticModelTransformer(ctx.cls, ctx.reason, ctx.api, self.plugin_config) + return transformer.transform() + + def _pydantic_model_metaclass_marker_callback(self, ctx: ClassDefContext) -> None: + """Reset dataclass_transform_spec attribute of ModelMetaclass. + + Let the plugin handle it. This behavior can be disabled + if 'debug_dataclass_transform' is set to True', for testing purposes. 
+ """ + if self.plugin_config.debug_dataclass_transform: + return + info_metaclass = ctx.cls.info.declared_metaclass + assert info_metaclass, "callback not passed from 'get_metaclass_hook'" + if getattr(info_metaclass.type, 'dataclass_transform_spec', None): + info_metaclass.type.dataclass_transform_spec = None + + def _pydantic_field_callback(self, ctx: FunctionContext) -> Type: + """Extract the type of the `default` argument from the Field function, and use it as the return type. In particular: * Check whether the default and default_factory argument is specified. @@ -164,11 +232,7 @@ class PydanticPlugin(Plugin): # Functions which use `ParamSpec` can be overloaded, exposing the callable's types as a parameter # Pydantic calls the default factory without any argument, so we retrieve the first item if isinstance(default_factory_type, Overloaded): - if MYPY_VERSION_TUPLE > (0, 910): - default_factory_type = default_factory_type.items[0] - else: - # Mypy0.910 exposes the items of overloaded types in a function - default_factory_type = default_factory_type.items()[0] # type: ignore[operator] + default_factory_type = default_factory_type.items[0] if isinstance(default_factory_type, CallableType): ret_type = default_factory_type.ret_type @@ -185,11 +249,26 @@ class PydanticPlugin(Plugin): class PydanticPluginConfig: - __slots__ = ('init_forbid_extra', 'init_typed', 'warn_required_dynamic_aliases', 'warn_untyped_fields') + """A Pydantic mypy plugin config holder. + + Attributes: + init_forbid_extra: Whether to add a `**kwargs` at the end of the generated `__init__` signature. + init_typed: Whether to annotate fields in the generated `__init__`. + warn_required_dynamic_aliases: Whether to raise required dynamic aliases error. + debug_dataclass_transform: Whether to not reset `dataclass_transform_spec` attribute + of `ModelMetaclass` for testing purposes. 
+ """ + + __slots__ = ( + 'init_forbid_extra', + 'init_typed', + 'warn_required_dynamic_aliases', + 'debug_dataclass_transform', + ) init_forbid_extra: bool init_typed: bool warn_required_dynamic_aliases: bool - warn_untyped_fields: bool + debug_dataclass_transform: bool # undocumented def __init__(self, options: Options) -> None: if options.config_file is None: # pragma: no cover @@ -210,343 +289,724 @@ class PydanticPluginConfig: setting = plugin_config.getboolean(CONFIGFILE_KEY, key, fallback=False) setattr(self, key, setting) + def to_data(self) -> dict[str, Any]: + """Returns a dict of config names to their values.""" + return {key: getattr(self, key) for key in self.__slots__} -def from_orm_callback(ctx: MethodContext) -> Type: - """ - Raise an error if orm_mode is not enabled - """ + +def from_attributes_callback(ctx: MethodContext) -> Type: + """Raise an error if from_attributes is not enabled.""" model_type: Instance - if isinstance(ctx.type, CallableType) and isinstance(ctx.type.ret_type, Instance): - model_type = ctx.type.ret_type # called on the class - elif isinstance(ctx.type, Instance): - model_type = ctx.type # called on an instance (unusual, but still valid) + ctx_type = ctx.type + if isinstance(ctx_type, TypeType): + ctx_type = ctx_type.item + if isinstance(ctx_type, CallableType) and isinstance(ctx_type.ret_type, Instance): + model_type = ctx_type.ret_type # called on the class + elif isinstance(ctx_type, Instance): + model_type = ctx_type # called on an instance (unusual, but still valid) else: # pragma: no cover - detail = f'ctx.type: {ctx.type} (of type {ctx.type.__class__.__name__})' + detail = f'ctx.type: {ctx_type} (of type {ctx_type.__class__.__name__})' error_unexpected_behavior(detail, ctx.api, ctx.context) return ctx.default_return_type pydantic_metadata = model_type.type.metadata.get(METADATA_KEY) if pydantic_metadata is None: return ctx.default_return_type - orm_mode = pydantic_metadata.get('config', {}).get('orm_mode') - if orm_mode is not True: - error_from_orm(get_name(model_type.type), ctx.api, ctx.context) + from_attributes = pydantic_metadata.get('config', {}).get('from_attributes') + if from_attributes is not True: + error_from_attributes(model_type.type.name, ctx.api, ctx.context) return ctx.default_return_type +class PydanticModelField: + """Based on mypy.plugins.dataclasses.DataclassAttribute.""" + + def __init__( + self, + name: str, + alias: str | None, + has_dynamic_alias: bool, + has_default: bool, + line: int, + column: int, + type: Type | None, + info: TypeInfo, + ): + self.name = name + self.alias = alias + self.has_dynamic_alias = has_dynamic_alias + self.has_default = has_default + self.line = line + self.column = column + self.type = type + self.info = info + + def to_argument( + self, + current_info: TypeInfo, + typed: bool, + force_optional: bool, + use_alias: bool, + api: SemanticAnalyzerPluginInterface, + ) -> Argument: + """Based on mypy.plugins.dataclasses.DataclassAttribute.to_argument.""" + variable = self.to_var(current_info, api, use_alias) + type_annotation = self.expand_type(current_info, api) if typed else AnyType(TypeOfAny.explicit) + return Argument( + variable=variable, + type_annotation=type_annotation, + initializer=None, + kind=ARG_NAMED_OPT if force_optional or self.has_default else ARG_NAMED, + ) + + def expand_type(self, current_info: TypeInfo, api: SemanticAnalyzerPluginInterface) -> Type | None: + """Based on mypy.plugins.dataclasses.DataclassAttribute.expand_type.""" + # The getattr in the next line is used to 
prevent errors in legacy versions of mypy without this attribute + if self.type is not None and getattr(self.info, 'self_type', None) is not None: + # In general, it is not safe to call `expand_type()` during semantic analyzis, + # however this plugin is called very late, so all types should be fully ready. + # Also, it is tricky to avoid eager expansion of Self types here (e.g. because + # we serialize attributes). + expanded_type = expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)}) + if isinstance(self.type, UnionType) and not isinstance(expanded_type, UnionType): + if not api.final_iteration: + raise _DeferAnalysis() + return expanded_type + return self.type + + def to_var(self, current_info: TypeInfo, api: SemanticAnalyzerPluginInterface, use_alias: bool) -> Var: + """Based on mypy.plugins.dataclasses.DataclassAttribute.to_var.""" + if use_alias and self.alias is not None: + name = self.alias + else: + name = self.name + + return Var(name, self.expand_type(current_info, api)) + + def serialize(self) -> JsonDict: + """Based on mypy.plugins.dataclasses.DataclassAttribute.serialize.""" + assert self.type + return { + 'name': self.name, + 'alias': self.alias, + 'has_dynamic_alias': self.has_dynamic_alias, + 'has_default': self.has_default, + 'line': self.line, + 'column': self.column, + 'type': self.type.serialize(), + } + + @classmethod + def deserialize(cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface) -> PydanticModelField: + """Based on mypy.plugins.dataclasses.DataclassAttribute.deserialize.""" + data = data.copy() + typ = deserialize_and_fixup_type(data.pop('type'), api) + return cls(type=typ, info=info, **data) + + def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: + """Expands type vars in the context of a subtype when an attribute is inherited + from a generic super type. + """ + if self.type is not None: + self.type = map_type_from_supertype(self.type, sub_type, self.info) + + +class PydanticModelClassVar: + """Based on mypy.plugins.dataclasses.DataclassAttribute. + + ClassVars are ignored by subclasses. + + Attributes: + name: the ClassVar name + """ + + def __init__(self, name): + self.name = name + + @classmethod + def deserialize(cls, data: JsonDict) -> PydanticModelClassVar: + """Based on mypy.plugins.dataclasses.DataclassAttribute.deserialize.""" + data = data.copy() + return cls(**data) + + def serialize(self) -> JsonDict: + """Based on mypy.plugins.dataclasses.DataclassAttribute.serialize.""" + return { + 'name': self.name, + } + + class PydanticModelTransformer: - tracked_config_fields: Set[str] = { + """Transform the BaseModel subclass according to the plugin settings. + + Attributes: + tracked_config_fields: A set of field configs that the plugin has to track their value. + """ + + tracked_config_fields: set[str] = { 'extra', - 'allow_mutation', 'frozen', - 'orm_mode', - 'allow_population_by_field_name', + 'from_attributes', + 'populate_by_name', 'alias_generator', } - def __init__(self, ctx: ClassDefContext, plugin_config: PydanticPluginConfig) -> None: - self._ctx = ctx + def __init__( + self, + cls: ClassDef, + reason: Expression | Statement, + api: SemanticAnalyzerPluginInterface, + plugin_config: PydanticPluginConfig, + ) -> None: + self._cls = cls + self._reason = reason + self._api = api + self.plugin_config = plugin_config - def transform(self) -> None: - """ - Configures the BaseModel subclass according to the plugin settings. 
+ def transform(self) -> bool: + """Configures the BaseModel subclass according to the plugin settings. In particular: + * determines the model config and fields, * adds a fields-aware signature for the initializer and construct methods - * freezes the class if allow_mutation = False or frozen = True + * freezes the class if frozen = True * stores the fields, config, and if the class is settings in the mypy metadata for access by subclasses """ - ctx = self._ctx - info = self._ctx.cls.info - - self.adjust_validator_signatures() + info = self._cls.info + is_root_model = any(ROOT_MODEL_FULLNAME in base.fullname for base in info.mro[:-1]) config = self.collect_config() - fields = self.collect_fields(config) + fields, class_vars = self.collect_fields_and_class_vars(config, is_root_model) + if fields is None or class_vars is None: + # Some definitions are not ready. We need another pass. + return False for field in fields: - if info[field.name].type is None: - if not ctx.api.final_iteration: - ctx.api.defer() - is_settings = any(get_fullname(base) == BASESETTINGS_FULLNAME for base in info.mro[:-1]) - self.add_initializer(fields, config, is_settings) - self.add_construct_method(fields) - self.set_frozen(fields, frozen=config.allow_mutation is False or config.frozen is True) + if field.type is None: + return False + + is_settings = any(base.fullname == BASESETTINGS_FULLNAME for base in info.mro[:-1]) + try: + self.add_initializer(fields, config, is_settings, is_root_model) + self.add_model_construct_method(fields, config, is_settings) + self.set_frozen(fields, self._api, frozen=config.frozen is True) + except _DeferAnalysis: + if not self._api.final_iteration: + self._api.defer() + + self.adjust_decorator_signatures() + info.metadata[METADATA_KEY] = { 'fields': {field.name: field.serialize() for field in fields}, - 'config': config.set_values_dict(), + 'class_vars': {class_var.name: class_var.serialize() for class_var in class_vars}, + 'config': config.get_values_dict(), } - def adjust_validator_signatures(self) -> None: - """When we decorate a function `f` with `pydantic.validator(...), mypy sees - `f` as a regular method taking a `self` instance, even though pydantic - internally wraps `f` with `classmethod` if necessary. + return True - Teach mypy this by marking any function whose outermost decorator is a - `validator()` call as a classmethod. + def adjust_decorator_signatures(self) -> None: + """When we decorate a function `f` with `pydantic.validator(...)`, `pydantic.field_validator` + or `pydantic.serializer(...)`, mypy sees `f` as a regular method taking a `self` instance, + even though pydantic internally wraps `f` with `classmethod` if necessary. + + Teach mypy this by marking any function whose outermost decorator is a `validator()`, + `field_validator()` or `serializer()` call as a `classmethod`. 
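A small runnable sketch of the behaviour this docstring describes, using a hypothetical model: `field_validator` methods are implicitly wrapped as classmethods by pydantic, while `model_validator(mode='after')` runs as a regular instance method, which is why the plugin special-cases it:

```py
from pydantic import BaseModel, field_validator, model_validator

class Point(BaseModel):
    x: int
    y: int

    @field_validator('x')
    def check_x(cls, v: int) -> int:  # pydantic treats this as a classmethod
        return v

    @model_validator(mode='after')
    def check_total(self) -> 'Point':  # regular method: receives the validated instance
        assert self.x + self.y >= 0
        return self
```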
""" - for name, sym in self._ctx.cls.info.names.items(): + for name, sym in self._cls.info.names.items(): if isinstance(sym.node, Decorator): first_dec = sym.node.original_decorators[0] if ( isinstance(first_dec, CallExpr) and isinstance(first_dec.callee, NameExpr) - and first_dec.callee.fullname == 'pydantic.class_validators.validator' + and first_dec.callee.fullname in DECORATOR_FULLNAMES + # @model_validator(mode="after") is an exception, it expects a regular method + and not ( + first_dec.callee.fullname == MODEL_VALIDATOR_FULLNAME + and any( + first_dec.arg_names[i] == 'mode' and isinstance(arg, StrExpr) and arg.value == 'after' + for i, arg in enumerate(first_dec.args) + ) + ) ): + # TODO: Only do this if the first argument of the decorated function is `cls` sym.node.func.is_class = True - def collect_config(self) -> 'ModelConfigData': - """ - Collects the values of the config attributes that are used by the plugin, accounting for parent classes. - """ - ctx = self._ctx - cls = ctx.cls + def collect_config(self) -> ModelConfigData: # noqa: C901 (ignore complexity) + """Collects the values of the config attributes that are used by the plugin, accounting for parent classes.""" + cls = self._cls config = ModelConfigData() + + has_config_kwargs = False + has_config_from_namespace = False + + # Handle `class MyModel(BaseModel, =, ...):` + for name, expr in cls.keywords.items(): + config_data = self.get_config_update(name, expr) + if config_data: + has_config_kwargs = True + config.update(config_data) + + # Handle `model_config` + stmt: Statement | None = None for stmt in cls.defs.body: - if not isinstance(stmt, ClassDef): + if not isinstance(stmt, (AssignmentStmt, ClassDef)): continue - if stmt.name == 'Config': + + if isinstance(stmt, AssignmentStmt): + lhs = stmt.lvalues[0] + if not isinstance(lhs, NameExpr) or lhs.name != 'model_config': + continue + + if isinstance(stmt.rvalue, CallExpr): # calls to `dict` or `ConfigDict` + for arg_name, arg in zip(stmt.rvalue.arg_names, stmt.rvalue.args): + if arg_name is None: + continue + config.update(self.get_config_update(arg_name, arg)) + elif isinstance(stmt.rvalue, DictExpr): # dict literals + for key_expr, value_expr in stmt.rvalue.items: + if not isinstance(key_expr, StrExpr): + continue + config.update(self.get_config_update(key_expr.value, value_expr)) + + elif isinstance(stmt, ClassDef): + if stmt.name != 'Config': # 'deprecated' Config-class + continue for substmt in stmt.defs.body: if not isinstance(substmt, AssignmentStmt): continue - config.update(self.get_config_update(substmt)) - if ( - config.has_alias_generator - and not config.allow_population_by_field_name - and self.plugin_config.warn_required_dynamic_aliases - ): - error_required_dynamic_aliases(ctx.api, stmt) + lhs = substmt.lvalues[0] + if not isinstance(lhs, NameExpr): + continue + config.update(self.get_config_update(lhs.name, substmt.rvalue)) + + if has_config_kwargs: + self._api.fail( + 'Specifying config in two places is ambiguous, use either Config attribute or class kwargs', + cls, + ) + break + + has_config_from_namespace = True + + if has_config_kwargs or has_config_from_namespace: + if ( + stmt + and config.has_alias_generator + and not config.populate_by_name + and self.plugin_config.warn_required_dynamic_aliases + ): + error_required_dynamic_aliases(self._api, stmt) + for info in cls.info.mro[1:]: # 0 is the current class if METADATA_KEY not in info.metadata: continue # Each class depends on the set of fields in its ancestors - 
ctx.api.add_plugin_dependency(make_wildcard_trigger(get_fullname(info))) + self._api.add_plugin_dependency(make_wildcard_trigger(info.fullname)) for name, value in info.metadata[METADATA_KEY]['config'].items(): config.setdefault(name, value) return config - def collect_fields(self, model_config: 'ModelConfigData') -> List['PydanticModelField']: - """ - Collects the fields for the model, accounting for parent classes - """ - # First, collect fields belonging to the current class. - ctx = self._ctx - cls = self._ctx.cls - fields = [] # type: List[PydanticModelField] - known_fields = set() # type: Set[str] - for stmt in cls.defs.body: - if not isinstance(stmt, AssignmentStmt): # `and stmt.new_syntax` to require annotation - continue + def collect_fields_and_class_vars( + self, model_config: ModelConfigData, is_root_model: bool + ) -> tuple[list[PydanticModelField] | None, list[PydanticModelClassVar] | None]: + """Collects the fields for the model, accounting for parent classes.""" + cls = self._cls - lhs = stmt.lvalues[0] - if not isinstance(lhs, NameExpr) or not is_valid_field(lhs.name): - continue - - if not stmt.new_syntax and self.plugin_config.warn_untyped_fields: - error_untyped_fields(ctx.api, stmt) - - # if lhs.name == '__config__': # BaseConfig not well handled; I'm not sure why yet - # continue - - sym = cls.info.names.get(lhs.name) - if sym is None: # pragma: no cover - # This is likely due to a star import (see the dataclasses plugin for a more detailed explanation) - # This is the same logic used in the dataclasses plugin - continue - - node = sym.node - if isinstance(node, PlaceholderNode): # pragma: no cover - # See the PlaceholderNode docstring for more detail about how this can occur - # Basically, it is an edge case when dealing with complex import logic - # This is the same logic used in the dataclasses plugin - continue - if not isinstance(node, Var): # pragma: no cover - # Don't know if this edge case still happens with the `is_valid_field` check above - # but better safe than sorry - continue - - # x: ClassVar[int] is ignored by dataclasses. - if node.is_classvar: - continue - - is_required = self.get_is_required(cls, stmt, lhs) - alias, has_dynamic_alias = self.get_alias_info(stmt) - if ( - has_dynamic_alias - and not model_config.allow_population_by_field_name - and self.plugin_config.warn_required_dynamic_aliases - ): - error_required_dynamic_aliases(ctx.api, stmt) - fields.append( - PydanticModelField( - name=lhs.name, - is_required=is_required, - alias=alias, - has_dynamic_alias=has_dynamic_alias, - line=stmt.line, - column=stmt.column, - ) - ) - known_fields.add(lhs.name) - all_fields = fields.copy() - for info in cls.info.mro[1:]: # 0 is the current class, -2 is BaseModel, -1 is object + # First, collect fields and ClassVars belonging to any class in the MRO, ignoring duplicates. + # + # We iterate through the MRO in reverse because attrs defined in the parent must appear + # earlier in the attributes list than attrs defined in the child. See: + # https://docs.python.org/3/library/dataclasses.html#inheritance + # + # However, we also want fields defined in the subtype to override ones defined + # in the parent. We can implement this via a dict without disrupting the attr order + # because dicts preserve insertion order in Python 3.7+. 
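The ordering rule described in the comment above can be seen with plain pydantic models (hypothetical classes for illustration): parent fields come first, and a child redefinition overrides the parent's annotation without moving the field:

```py
from pydantic import BaseModel

class Parent(BaseModel):
    a: int
    b: int

class Child(Parent):
    c: int
    a: str  # overrides Parent.a but keeps its original position

print(list(Child.model_fields))  # typically ['a', 'b', 'c']
```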
+ found_fields: dict[str, PydanticModelField] = {} + found_class_vars: dict[str, PydanticModelClassVar] = {} + for info in reversed(cls.info.mro[1:-1]): # 0 is the current class, -2 is BaseModel, -1 is object + # if BASEMODEL_METADATA_TAG_KEY in info.metadata and BASEMODEL_METADATA_KEY not in info.metadata: + # # We haven't processed the base class yet. Need another pass. + # return None, None if METADATA_KEY not in info.metadata: continue - superclass_fields = [] - # Each class depends on the set of fields in its ancestors - ctx.api.add_plugin_dependency(make_wildcard_trigger(get_fullname(info))) + # Each class depends on the set of attributes in its dataclass ancestors. + self._api.add_plugin_dependency(make_wildcard_trigger(info.fullname)) for name, data in info.metadata[METADATA_KEY]['fields'].items(): - if name not in known_fields: - field = PydanticModelField.deserialize(info, data) - known_fields.add(name) - superclass_fields.append(field) - else: - (field,) = (a for a in all_fields if a.name == name) - all_fields.remove(field) - superclass_fields.append(field) - all_fields = superclass_fields + all_fields - return all_fields + field = PydanticModelField.deserialize(info, data, self._api) + # (The following comment comes directly from the dataclasses plugin) + # TODO: We shouldn't be performing type operations during the main + # semantic analysis pass, since some TypeInfo attributes might + # still be in flux. This should be performed in a later phase. + with state.strict_optional_set(self._api.options.strict_optional): + field.expand_typevar_from_subtype(cls.info) + found_fields[name] = field - def add_initializer(self, fields: List['PydanticModelField'], config: 'ModelConfigData', is_settings: bool) -> None: + sym_node = cls.info.names.get(name) + if sym_node and sym_node.node and not isinstance(sym_node.node, Var): + self._api.fail( + 'BaseModel field may only be overridden by another field', + sym_node.node, + ) + # Collect ClassVars + for name, data in info.metadata[METADATA_KEY]['class_vars'].items(): + found_class_vars[name] = PydanticModelClassVar.deserialize(data) + + # Second, collect fields and ClassVars belonging to the current class. 
+ current_field_names: set[str] = set() + current_class_vars_names: set[str] = set() + for stmt in self._get_assignment_statements_from_block(cls.defs): + maybe_field = self.collect_field_or_class_var_from_stmt(stmt, model_config, found_class_vars) + if isinstance(maybe_field, PydanticModelField): + lhs = stmt.lvalues[0] + if is_root_model and lhs.name != 'root': + error_extra_fields_on_root_model(self._api, stmt) + else: + current_field_names.add(lhs.name) + found_fields[lhs.name] = maybe_field + elif isinstance(maybe_field, PydanticModelClassVar): + lhs = stmt.lvalues[0] + current_class_vars_names.add(lhs.name) + found_class_vars[lhs.name] = maybe_field + + return list(found_fields.values()), list(found_class_vars.values()) + + def _get_assignment_statements_from_if_statement(self, stmt: IfStmt) -> Iterator[AssignmentStmt]: + for body in stmt.body: + if not body.is_unreachable: + yield from self._get_assignment_statements_from_block(body) + if stmt.else_body is not None and not stmt.else_body.is_unreachable: + yield from self._get_assignment_statements_from_block(stmt.else_body) + + def _get_assignment_statements_from_block(self, block: Block) -> Iterator[AssignmentStmt]: + for stmt in block.body: + if isinstance(stmt, AssignmentStmt): + yield stmt + elif isinstance(stmt, IfStmt): + yield from self._get_assignment_statements_from_if_statement(stmt) + + def collect_field_or_class_var_from_stmt( # noqa C901 + self, stmt: AssignmentStmt, model_config: ModelConfigData, class_vars: dict[str, PydanticModelClassVar] + ) -> PydanticModelField | PydanticModelClassVar | None: + """Get pydantic model field from statement. + + Args: + stmt: The statement. + model_config: Configuration settings for the model. + class_vars: ClassVars already known to be defined on the model. + + Returns: + A pydantic model field if it could find the field in statement. Otherwise, `None`. """ - Adds a fields-aware `__init__` method to the class. 
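To illustrate what the `_get_assignment_statements_from_block` and `_get_assignment_statements_from_if_statement` helpers above make visible, here is a hypothetical model (class and field names invented) whose field assignment sits inside an `if` block and is still collected:

```py
import sys

from pydantic import BaseModel

class Event(BaseModel):
    kind: str
    # This assignment lives inside an IfStmt; the helpers above yield it to the
    # field collector, so the plugin sees `source` just like a top-level field.
    if sys.version_info >= (3, 8):
        source: str = 'modern'
    else:
        source: str = 'legacy'

print(Event(kind='deploy').source)
#> modern
```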
+ cls = self._cls + + lhs = stmt.lvalues[0] + if not isinstance(lhs, NameExpr) or not _fields.is_valid_field_name(lhs.name) or lhs.name == 'model_config': + return None + + if not stmt.new_syntax: + if ( + isinstance(stmt.rvalue, CallExpr) + and isinstance(stmt.rvalue.callee, CallExpr) + and isinstance(stmt.rvalue.callee.callee, NameExpr) + and stmt.rvalue.callee.callee.fullname in DECORATOR_FULLNAMES + ): + # This is a (possibly-reused) validator or serializer, not a field + # In particular, it looks something like: my_validator = validator('my_field')(f) + # Eventually, we may want to attempt to respect model_config['ignored_types'] + return None + + if lhs.name in class_vars: + # Class vars are not fields and are not required to be annotated + return None + + # The assignment does not have an annotation, and it's not anything else we recognize + error_untyped_fields(self._api, stmt) + return None + + lhs = stmt.lvalues[0] + if not isinstance(lhs, NameExpr): + return None + + if not _fields.is_valid_field_name(lhs.name) or lhs.name == 'model_config': + return None + + sym = cls.info.names.get(lhs.name) + if sym is None: # pragma: no cover + # This is likely due to a star import (see the dataclasses plugin for a more detailed explanation) + # This is the same logic used in the dataclasses plugin + return None + + node = sym.node + if isinstance(node, PlaceholderNode): # pragma: no cover + # See the PlaceholderNode docstring for more detail about how this can occur + # Basically, it is an edge case when dealing with complex import logic + + # The dataclasses plugin now asserts this cannot happen, but I'd rather not error if it does.. + return None + + if isinstance(node, TypeAlias): + self._api.fail( + 'Type aliases inside BaseModel definitions are not supported at runtime', + node, + ) + # Skip processing this node. This doesn't match the runtime behaviour, + # but the only alternative would be to modify the SymbolTable, + # and it's a little hairy to do that in a plugin. + return None + + if not isinstance(node, Var): # pragma: no cover + # Don't know if this edge case still happens with the `is_valid_field` check above + # but better safe than sorry + + # The dataclasses plugin now asserts this cannot happen, but I'd rather not error if it does.. + return None + + # x: ClassVar[int] is not a field + if node.is_classvar: + return PydanticModelClassVar(lhs.name) + + # x: InitVar[int] is not supported in BaseModel + node_type = get_proper_type(node.type) + if isinstance(node_type, Instance) and node_type.type.fullname == 'dataclasses.InitVar': + self._api.fail( + 'InitVar is not supported in BaseModel', + node, + ) + + has_default = self.get_has_default(stmt) + + if sym.type is None and node.is_final and node.is_inferred: + # This follows the logic from the dataclasses plugin. The following comment is taken verbatim: + # + # This is a special case, assignment like x: Final = 42 is classified + # annotated above, but mypy strips the `Final` turning it into x = 42. + # We do not support inferred types in dataclasses, so we can try inferring + # type for simple literals, and otherwise require an explicit type + # argument for Final[...]. + typ = self._api.analyze_simple_literal_type(stmt.rvalue, is_final=True) + if typ: + node.type = typ + else: + self._api.fail( + 'Need type argument for Final[...] 
with non-literal default in BaseModel', + stmt, + ) + node.type = AnyType(TypeOfAny.from_error) + + alias, has_dynamic_alias = self.get_alias_info(stmt) + if has_dynamic_alias and not model_config.populate_by_name and self.plugin_config.warn_required_dynamic_aliases: + error_required_dynamic_aliases(self._api, stmt) + + init_type = self._infer_dataclass_attr_init_type(sym, lhs.name, stmt) + return PydanticModelField( + name=lhs.name, + has_dynamic_alias=has_dynamic_alias, + has_default=has_default, + alias=alias, + line=stmt.line, + column=stmt.column, + type=init_type, + info=cls.info, + ) + + def _infer_dataclass_attr_init_type(self, sym: SymbolTableNode, name: str, context: Context) -> Type | None: + """Infer __init__ argument type for an attribute. + + In particular, possibly use the signature of __set__. + """ + default = sym.type + if sym.implicit: + return default + t = get_proper_type(sym.type) + + # Perform a simple-minded inference from the signature of __set__, if present. + # We can't use mypy.checkmember here, since this plugin runs before type checking. + # We only support some basic scanerios here, which is hopefully sufficient for + # the vast majority of use cases. + if not isinstance(t, Instance): + return default + setter = t.type.get('__set__') + if setter: + if isinstance(setter.node, FuncDef): + super_info = t.type.get_containing_type_info('__set__') + assert super_info + if setter.type: + setter_type = get_proper_type(map_type_from_supertype(setter.type, t.type, super_info)) + else: + return AnyType(TypeOfAny.unannotated) + if isinstance(setter_type, CallableType) and setter_type.arg_kinds == [ + ARG_POS, + ARG_POS, + ARG_POS, + ]: + return expand_type_by_instance(setter_type.arg_types[2], t) + else: + self._api.fail(f'Unsupported signature for "__set__" in "{t.type.name}"', context) + else: + self._api.fail(f'Unsupported "__set__" in "{t.type.name}"', context) + + return default + + def add_initializer( + self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool, is_root_model: bool + ) -> None: + """Adds a fields-aware `__init__` method to the class. The added `__init__` will be annotated with types vs. all `Any` depending on the plugin settings. 
""" - ctx = self._ctx + if '__init__' in self._cls.info.names and not self._cls.info.names['__init__'].plugin_generated: + return # Don't generate an __init__ if one already exists + typed = self.plugin_config.init_typed - use_alias = config.allow_population_by_field_name is not True - force_all_optional = is_settings or bool( - config.has_alias_generator and not config.allow_population_by_field_name - ) - init_arguments = self.get_field_arguments( - fields, typed=typed, force_all_optional=force_all_optional, use_alias=use_alias - ) + use_alias = config.populate_by_name is not True + requires_dynamic_aliases = bool(config.has_alias_generator and not config.populate_by_name) + with state.strict_optional_set(self._api.options.strict_optional): + args = self.get_field_arguments( + fields, + typed=typed, + requires_dynamic_aliases=requires_dynamic_aliases, + use_alias=use_alias, + is_settings=is_settings, + ) + + if is_root_model and MYPY_VERSION_TUPLE <= (1, 0, 1): + # convert root argument to positional argument + # This is needed because mypy support for `dataclass_transform` isn't complete on 1.0.1 + args[0].kind = ARG_POS if args[0].kind == ARG_NAMED else ARG_OPT + + if is_settings: + base_settings_node = self._api.lookup_fully_qualified(BASESETTINGS_FULLNAME).node + if '__init__' in base_settings_node.names: + base_settings_init_node = base_settings_node.names['__init__'].node + if base_settings_init_node is not None and base_settings_init_node.type is not None: + func_type = base_settings_init_node.type + for arg_idx, arg_name in enumerate(func_type.arg_names): + if arg_name.startswith('__') or not arg_name.startswith('_'): + continue + analyzed_variable_type = self._api.anal_type(func_type.arg_types[arg_idx]) + variable = Var(arg_name, analyzed_variable_type) + args.append(Argument(variable, analyzed_variable_type, None, ARG_OPT)) + if not self.should_init_forbid_extra(fields, config): var = Var('kwargs') - init_arguments.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2)) + args.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2)) - if '__init__' not in ctx.cls.info.names: - add_method(ctx, '__init__', init_arguments, NoneType()) + add_method(self._api, self._cls, '__init__', args=args, return_type=NoneType()) - def add_construct_method(self, fields: List['PydanticModelField']) -> None: - """ - Adds a fully typed `construct` classmethod to the class. + def add_model_construct_method( + self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool + ) -> None: + """Adds a fully typed `model_construct` classmethod to the class. Similar to the fields-aware __init__ method, but always uses the field names (not aliases), and does not treat settings fields as optional. 
""" - ctx = self._ctx - set_str = ctx.api.named_type(f'{BUILTINS_NAME}.set', [ctx.api.named_type(f'{BUILTINS_NAME}.str')]) + set_str = self._api.named_type(f'{BUILTINS_NAME}.set', [self._api.named_type(f'{BUILTINS_NAME}.str')]) optional_set_str = UnionType([set_str, NoneType()]) fields_set_argument = Argument(Var('_fields_set', optional_set_str), optional_set_str, None, ARG_OPT) - construct_arguments = self.get_field_arguments(fields, typed=True, force_all_optional=False, use_alias=False) - construct_arguments = [fields_set_argument] + construct_arguments + with state.strict_optional_set(self._api.options.strict_optional): + args = self.get_field_arguments( + fields, typed=True, requires_dynamic_aliases=False, use_alias=False, is_settings=is_settings + ) + if not self.should_init_forbid_extra(fields, config): + var = Var('kwargs') + args.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2)) - obj_type = ctx.api.named_type(f'{BUILTINS_NAME}.object') - self_tvar_name = '_PydanticBaseModel' # Make sure it does not conflict with other names in the class - tvar_fullname = ctx.cls.fullname + '.' + self_tvar_name - tvd = TypeVarDef(self_tvar_name, tvar_fullname, -1, [], obj_type) - self_tvar_expr = TypeVarExpr(self_tvar_name, tvar_fullname, [], obj_type) - ctx.cls.info.names[self_tvar_name] = SymbolTableNode(MDEF, self_tvar_expr) - - # Backward-compatible with TypeVarDef from Mypy 0.910. - if isinstance(tvd, TypeVarType): - self_type = tvd - else: - self_type = TypeVarType(tvd) # type: ignore[call-arg] + args = [fields_set_argument] + args add_method( - ctx, - 'construct', - construct_arguments, - return_type=self_type, - self_type=self_type, - tvar_def=tvd, + self._api, + self._cls, + 'model_construct', + args=args, + return_type=fill_typevars(self._cls.info), is_classmethod=True, ) - def set_frozen(self, fields: List['PydanticModelField'], frozen: bool) -> None: - """ - Marks all fields as properties so that attempts to set them trigger mypy errors. + def set_frozen(self, fields: list[PydanticModelField], api: SemanticAnalyzerPluginInterface, frozen: bool) -> None: + """Marks all fields as properties so that attempts to set them trigger mypy errors. This is the same approach used by the attrs and dataclasses plugins. """ - info = self._ctx.cls.info + info = self._cls.info for field in fields: sym_node = info.names.get(field.name) if sym_node is not None: var = sym_node.node - assert isinstance(var, Var) - var.is_property = frozen + if isinstance(var, Var): + var.is_property = frozen + elif isinstance(var, PlaceholderNode) and not self._api.final_iteration: + # See https://github.com/pydantic/pydantic/issues/5191 to hit this branch for test coverage + self._api.defer() + else: # pragma: no cover + # I don't know whether it's possible to hit this branch, but I've added it for safety + try: + var_str = str(var) + except TypeError: + # This happens for PlaceholderNode; perhaps it will happen for other types in the future.. + var_str = repr(var) + detail = f'sym_node.node: {var_str} (of type {var.__class__})' + error_unexpected_behavior(detail, self._api, self._cls) else: - var = field.to_var(info, use_alias=False) + var = field.to_var(info, api, use_alias=False) var.info = info var.is_property = frozen - var._fullname = get_fullname(info) + '.' + get_name(var) - info.names[get_name(var)] = SymbolTableNode(MDEF, var) + var._fullname = info.fullname + '.' 
+ var.name + info.names[var.name] = SymbolTableNode(MDEF, var) - def get_config_update(self, substmt: AssignmentStmt) -> Optional['ModelConfigData']: - """ - Determines the config update due to a single statement in the Config class definition. + def get_config_update(self, name: str, arg: Expression) -> ModelConfigData | None: + """Determines the config update due to a single kwarg in the ConfigDict definition. Warns if a tracked config attribute is set to a value the plugin doesn't know how to interpret (e.g., an int) """ - lhs = substmt.lvalues[0] - if not (isinstance(lhs, NameExpr) and lhs.name in self.tracked_config_fields): + if name not in self.tracked_config_fields: return None - if lhs.name == 'extra': - if isinstance(substmt.rvalue, StrExpr): - forbid_extra = substmt.rvalue.value == 'forbid' - elif isinstance(substmt.rvalue, MemberExpr): - forbid_extra = substmt.rvalue.name == 'forbid' + if name == 'extra': + if isinstance(arg, StrExpr): + forbid_extra = arg.value == 'forbid' + elif isinstance(arg, MemberExpr): + forbid_extra = arg.name == 'forbid' else: - error_invalid_config_value(lhs.name, self._ctx.api, substmt) + error_invalid_config_value(name, self._api, arg) return None return ModelConfigData(forbid_extra=forbid_extra) - if lhs.name == 'alias_generator': + if name == 'alias_generator': has_alias_generator = True - if isinstance(substmt.rvalue, NameExpr) and substmt.rvalue.fullname == 'builtins.None': + if isinstance(arg, NameExpr) and arg.fullname == 'builtins.None': has_alias_generator = False return ModelConfigData(has_alias_generator=has_alias_generator) - if isinstance(substmt.rvalue, NameExpr) and substmt.rvalue.fullname in ('builtins.True', 'builtins.False'): - return ModelConfigData(**{lhs.name: substmt.rvalue.fullname == 'builtins.True'}) - error_invalid_config_value(lhs.name, self._ctx.api, substmt) + if isinstance(arg, NameExpr) and arg.fullname in ('builtins.True', 'builtins.False'): + return ModelConfigData(**{name: arg.fullname == 'builtins.True'}) + error_invalid_config_value(name, self._api, arg) return None @staticmethod - def get_is_required(cls: ClassDef, stmt: AssignmentStmt, lhs: NameExpr) -> bool: - """ - Returns a boolean indicating whether the field defined in `stmt` is a required field. - """ + def get_has_default(stmt: AssignmentStmt) -> bool: + """Returns a boolean indicating whether the field defined in `stmt` is a required field.""" expr = stmt.rvalue if isinstance(expr, TempNode): - # TempNode means annotation-only, so only non-required if Optional - value_type = get_proper_type(cls.info[lhs.name].type) - if isinstance(value_type, UnionType) and any(isinstance(item, NoneType) for item in value_type.items): - # Annotated as Optional, or otherwise having NoneType in the union - return False - return True + # TempNode means annotation-only, so has no default + return False if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME: - # The "default value" is a call to `Field`; at this point, the field is - # only required if default is Ellipsis (i.e., `field_name: Annotation = Field(...)`) or if default_factory - # is specified. 
+ # The "default value" is a call to `Field`; at this point, the field has a default if and only if: + # * there is a positional argument that is not `...` + # * there is a keyword argument named "default" that is not `...` + # * there is a "default_factory" that is not `None` for arg, name in zip(expr.args, expr.arg_names): - # If name is None, then this arg is the default because it is the only positonal argument. + # If name is None, then this arg is the default because it is the only positional argument. if name is None or name == 'default': - return arg.__class__ is EllipsisExpr + return arg.__class__ is not EllipsisExpr if name == 'default_factory': - return False - return True - # Only required if the "default value" is Ellipsis (i.e., `field_name: Annotation = ...`) - return isinstance(expr, EllipsisExpr) + return not (isinstance(arg, NameExpr) and arg.fullname == 'builtins.None') + return False + # Has no default if the "default value" is Ellipsis (i.e., `field_name: Annotation = ...`) + return not isinstance(expr, EllipsisExpr) @staticmethod - def get_alias_info(stmt: AssignmentStmt) -> Tuple[Optional[str], bool]: - """ - Returns a pair (alias, has_dynamic_alias), extracted from the declaration of the field defined in `stmt`. + def get_alias_info(stmt: AssignmentStmt) -> tuple[str | None, bool]: + """Returns a pair (alias, has_dynamic_alias), extracted from the declaration of the field defined in `stmt`. `has_dynamic_alias` is True if and only if an alias is provided, but not as a string literal. If `has_dynamic_alias` is True, `alias` will be None. @@ -573,29 +1033,38 @@ class PydanticModelTransformer: return None, False def get_field_arguments( - self, fields: List['PydanticModelField'], typed: bool, force_all_optional: bool, use_alias: bool - ) -> List[Argument]: - """ - Helper function used during the construction of the `__init__` and `construct` method signatures. + self, + fields: list[PydanticModelField], + typed: bool, + use_alias: bool, + requires_dynamic_aliases: bool, + is_settings: bool, + ) -> list[Argument]: + """Helper function used during the construction of the `__init__` and `model_construct` method signatures. Returns a list of mypy Argument instances for use in the generated signatures. """ - info = self._ctx.cls.info + info = self._cls.info arguments = [ - field.to_argument(info, typed=typed, force_optional=force_all_optional, use_alias=use_alias) + field.to_argument( + info, + typed=typed, + force_optional=requires_dynamic_aliases or is_settings, + use_alias=use_alias, + api=self._api, + ) for field in fields if not (use_alias and field.has_dynamic_alias) ] return arguments - def should_init_forbid_extra(self, fields: List['PydanticModelField'], config: 'ModelConfigData') -> bool: - """ - Indicates whether the generated `__init__` should get a `**kwargs` at the end of its signature + def should_init_forbid_extra(self, fields: list[PydanticModelField], config: ModelConfigData) -> bool: + """Indicates whether the generated `__init__` should get a `**kwargs` at the end of its signature. We disallow arbitrary kwargs if the extra config setting is "forbid", or if the plugin config says to, *unless* a required dynamic alias is present (since then we can't determine a valid signature). 
""" - if not config.allow_population_by_field_name: + if not config.populate_by_name: if self.is_dynamic_alias_present(fields, bool(config.has_alias_generator)): return False if config.forbid_extra: @@ -603,9 +1072,8 @@ class PydanticModelTransformer: return self.plugin_config.init_forbid_extra @staticmethod - def is_dynamic_alias_present(fields: List['PydanticModelField'], has_alias_generator: bool) -> bool: - """ - Returns whether any fields on the model have a "dynamic alias", i.e., an alias that cannot be + def is_dynamic_alias_present(fields: list[PydanticModelField], has_alias_generator: bool) -> bool: + """Returns whether any fields on the model have a "dynamic alias", i.e., an alias that cannot be determined during static analysis. """ for field in fields: @@ -618,95 +1086,74 @@ class PydanticModelTransformer: return False -class PydanticModelField: - def __init__( - self, name: str, is_required: bool, alias: Optional[str], has_dynamic_alias: bool, line: int, column: int - ): - self.name = name - self.is_required = is_required - self.alias = alias - self.has_dynamic_alias = has_dynamic_alias - self.line = line - self.column = column - - def to_var(self, info: TypeInfo, use_alias: bool) -> Var: - name = self.name - if use_alias and self.alias is not None: - name = self.alias - return Var(name, info[self.name].type) - - def to_argument(self, info: TypeInfo, typed: bool, force_optional: bool, use_alias: bool) -> Argument: - if typed and info[self.name].type is not None: - type_annotation = info[self.name].type - else: - type_annotation = AnyType(TypeOfAny.explicit) - return Argument( - variable=self.to_var(info, use_alias), - type_annotation=type_annotation, - initializer=None, - kind=ARG_NAMED_OPT if force_optional or not self.is_required else ARG_NAMED, - ) - - def serialize(self) -> JsonDict: - return self.__dict__ - - @classmethod - def deserialize(cls, info: TypeInfo, data: JsonDict) -> 'PydanticModelField': - return cls(**data) - - class ModelConfigData: + """Pydantic mypy plugin model config class.""" + def __init__( self, - forbid_extra: Optional[bool] = None, - allow_mutation: Optional[bool] = None, - frozen: Optional[bool] = None, - orm_mode: Optional[bool] = None, - allow_population_by_field_name: Optional[bool] = None, - has_alias_generator: Optional[bool] = None, + forbid_extra: bool | None = None, + frozen: bool | None = None, + from_attributes: bool | None = None, + populate_by_name: bool | None = None, + has_alias_generator: bool | None = None, ): self.forbid_extra = forbid_extra - self.allow_mutation = allow_mutation self.frozen = frozen - self.orm_mode = orm_mode - self.allow_population_by_field_name = allow_population_by_field_name + self.from_attributes = from_attributes + self.populate_by_name = populate_by_name self.has_alias_generator = has_alias_generator - def set_values_dict(self) -> Dict[str, Any]: + def get_values_dict(self) -> dict[str, Any]: + """Returns a dict of Pydantic model config names to their values. + + It includes the config if config value is not `None`. 
+ """ return {k: v for k, v in self.__dict__.items() if v is not None} - def update(self, config: Optional['ModelConfigData']) -> None: + def update(self, config: ModelConfigData | None) -> None: + """Update Pydantic model config values.""" if config is None: return - for k, v in config.set_values_dict().items(): + for k, v in config.get_values_dict().items(): setattr(self, k, v) def setdefault(self, key: str, value: Any) -> None: + """Set default value for Pydantic model config if config value is `None`.""" if getattr(self, key) is None: setattr(self, key, value) -ERROR_ORM = ErrorCode('pydantic-orm', 'Invalid from_orm call', 'Pydantic') +ERROR_ORM = ErrorCode('pydantic-orm', 'Invalid from_attributes call', 'Pydantic') ERROR_CONFIG = ErrorCode('pydantic-config', 'Invalid config value', 'Pydantic') ERROR_ALIAS = ErrorCode('pydantic-alias', 'Dynamic alias disallowed', 'Pydantic') ERROR_UNEXPECTED = ErrorCode('pydantic-unexpected', 'Unexpected behavior', 'Pydantic') ERROR_UNTYPED = ErrorCode('pydantic-field', 'Untyped field disallowed', 'Pydantic') ERROR_FIELD_DEFAULTS = ErrorCode('pydantic-field', 'Invalid Field defaults', 'Pydantic') +ERROR_EXTRA_FIELD_ROOT_MODEL = ErrorCode('pydantic-field', 'Extra field on RootModel subclass', 'Pydantic') -def error_from_orm(model_name: str, api: CheckerPluginInterface, context: Context) -> None: - api.fail(f'"{model_name}" does not have orm_mode=True', context, code=ERROR_ORM) +def error_from_attributes(model_name: str, api: CheckerPluginInterface, context: Context) -> None: + """Emits an error when the model does not have `from_attributes=True`.""" + api.fail(f'"{model_name}" does not have from_attributes=True', context, code=ERROR_ORM) def error_invalid_config_value(name: str, api: SemanticAnalyzerPluginInterface, context: Context) -> None: + """Emits an error when the config value is invalid.""" api.fail(f'Invalid value for "Config.{name}"', context, code=ERROR_CONFIG) def error_required_dynamic_aliases(api: SemanticAnalyzerPluginInterface, context: Context) -> None: + """Emits required dynamic aliases error. + + This will be called when `warn_required_dynamic_aliases=True`. 
+ """ api.fail('Required dynamic aliases disallowed', context, code=ERROR_ALIAS) -def error_unexpected_behavior(detail: str, api: CheckerPluginInterface, context: Context) -> None: # pragma: no cover +def error_unexpected_behavior( + detail: str, api: CheckerPluginInterface | SemanticAnalyzerPluginInterface, context: Context +) -> None: # pragma: no cover + """Emits unexpected behavior error.""" # Can't think of a good way to test this, but I confirmed it renders as desired by adding to a non-error path link = 'https://github.com/pydantic/pydantic/issues/new/choose' full_message = f'The pydantic mypy plugin ran into unexpected behavior: {detail}\n' @@ -715,55 +1162,70 @@ def error_unexpected_behavior(detail: str, api: CheckerPluginInterface, context: def error_untyped_fields(api: SemanticAnalyzerPluginInterface, context: Context) -> None: + """Emits an error when there is an untyped field in the model.""" api.fail('Untyped fields disallowed', context, code=ERROR_UNTYPED) +def error_extra_fields_on_root_model(api: CheckerPluginInterface, context: Context) -> None: + """Emits an error when there is more than just a root field defined for a subclass of RootModel.""" + api.fail('Only `root` is allowed as a field of a `RootModel`', context, code=ERROR_EXTRA_FIELD_ROOT_MODEL) + + def error_default_and_default_factory_specified(api: CheckerPluginInterface, context: Context) -> None: + """Emits an error when `Field` has both `default` and `default_factory` together.""" api.fail('Field default and default_factory cannot be specified together', context, code=ERROR_FIELD_DEFAULTS) def add_method( - ctx: ClassDefContext, + api: SemanticAnalyzerPluginInterface | CheckerPluginInterface, + cls: ClassDef, name: str, - args: List[Argument], + args: list[Argument], return_type: Type, - self_type: Optional[Type] = None, - tvar_def: Optional[TypeVarDef] = None, + self_type: Type | None = None, + tvar_def: TypeVarDef | None = None, is_classmethod: bool = False, - is_new: bool = False, - # is_staticmethod: bool = False, ) -> None: - """ - Adds a new method to a class. - - This can be dropped if/when https://github.com/python/mypy/issues/7301 is merged - """ - info = ctx.cls.info + """Very closely related to `mypy.plugins.common.add_method_to_class`, with a few pydantic-specific changes.""" + info = cls.info # First remove any previously generated methods with the same name # to avoid clashes and problems in the semantic analyzer. if name in info.names: sym = info.names[name] if sym.plugin_generated and isinstance(sym.node, FuncDef): - ctx.cls.defs.body.remove(sym.node) # pragma: no cover + cls.defs.body.remove(sym.node) # pragma: no cover - self_type = self_type or fill_typevars(info) - if is_classmethod or is_new: - first = [Argument(Var('_cls'), TypeType.make_normalized(self_type), None, ARG_POS)] - # elif is_staticmethod: - # first = [] + if isinstance(api, SemanticAnalyzerPluginInterface): + function_type = api.named_type('builtins.function') + else: + function_type = api.named_generic_type('builtins.function', []) + + if is_classmethod: + self_type = self_type or TypeType(fill_typevars(info)) + first = [Argument(Var('_cls'), self_type, None, ARG_POS, True)] else: self_type = self_type or fill_typevars(info) + # `self` is positional *ONLY* here, but this can't be expressed + # fully in the mypy internal API. ARG_POS is the closest we can get. 
+ # Using ARG_POS will, however, give mypy errors if a `self` field + # is present on a model: + # + # Name "self" already defined (possibly by an import) [no-redef] + # + # As a workaround, we give this argument a name that will + # never conflict. By its positional nature, this name will not + # be used or exposed to users. first = [Argument(Var('__pydantic_self__'), self_type, None, ARG_POS)] args = first + args + arg_types, arg_names, arg_kinds = [], [], [] for arg in args: assert arg.type_annotation, 'All arguments must be fully typed.' arg_types.append(arg.type_annotation) - arg_names.append(get_name(arg.variable)) + arg_names.append(arg.variable.name) arg_kinds.append(arg.kind) - function_type = ctx.api.named_type(f'{BUILTINS_NAME}.function') signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type) if tvar_def: signature.variables = [tvar_def] @@ -772,8 +1234,7 @@ def add_method( func.info = info func.type = set_callable_name(signature, func) func.is_class = is_classmethod - # func.is_static = is_staticmethod - func._fullname = get_fullname(info) + '.' + name + func._fullname = info.fullname + '.' + name func.line = info.line # NOTE: we would like the plugin generated node to dominate, but we still @@ -783,68 +1244,44 @@ def add_method( r_name = get_unique_redefinition_name(name, info.names) info.names[r_name] = info.names[name] - if is_classmethod: # or is_staticmethod: + # Add decorator for is_classmethod + # The dataclasses plugin claims this is unnecessary for classmethods, but not including it results in a + # signature incompatible with the superclass, which causes mypy errors to occur for every subclass of BaseModel. + if is_classmethod: func.is_decorated = True v = Var(name, func.type) v.info = info v._fullname = func._fullname - # if is_classmethod: v.is_classmethod = True dec = Decorator(func, [NameExpr('classmethod')], v) - # else: - # v.is_staticmethod = True - # dec = Decorator(func, [NameExpr('staticmethod')], v) - dec.line = info.line sym = SymbolTableNode(MDEF, dec) else: sym = SymbolTableNode(MDEF, func) sym.plugin_generated = True - info.names[name] = sym + info.defn.defs.body.append(func) -def get_fullname(x: Union[FuncBase, SymbolNode]) -> str: - """ - Used for compatibility with mypy 0.740; can be dropped once support for 0.740 is dropped. - """ - fn = x.fullname - if callable(fn): # pragma: no cover - return fn() - return fn +def parse_toml(config_file: str) -> dict[str, Any] | None: + """Returns a dict of config keys to values. - -def get_name(x: Union[FuncBase, SymbolNode]) -> str: + It reads configs from toml file and returns `None` if the file is not a toml file. """ - Used for compatibility with mypy 0.740; can be dropped once support for 0.740 is dropped. 
- """ - fn = x.name - if callable(fn): # pragma: no cover - return fn() - return fn - - -def parse_toml(config_file: str) -> Optional[Dict[str, Any]]: if not config_file.endswith('.toml'): return None - read_mode = 'rb' if sys.version_info >= (3, 11): import tomllib as toml_ else: try: import tomli as toml_ - except ImportError: - # older versions of mypy have toml as a dependency, not tomli - read_mode = 'r' - try: - import toml as toml_ # type: ignore[no-redef] - except ImportError: # pragma: no cover - import warnings + except ImportError: # pragma: no cover + import warnings - warnings.warn('No TOML parser installed, cannot read configuration from `pyproject.toml`.') - return None + warnings.warn('No TOML parser installed, cannot read configuration from `pyproject.toml`.') + return None - with open(config_file, read_mode) as rf: - return toml_.load(rf) # type: ignore[arg-type] + with open(config_file, 'rb') as rf: + return toml_.load(rf) diff --git a/lib/pydantic/networks.py b/lib/pydantic/networks.py index c7d97186..6d9d292f 100644 --- a/lib/pydantic/networks.py +++ b/lib/pydantic/networks.py @@ -1,80 +1,35 @@ -import re -from ipaddress import ( - IPv4Address, - IPv4Interface, - IPv4Network, - IPv6Address, - IPv6Interface, - IPv6Network, - _BaseAddress, - _BaseNetwork, -) -from typing import ( - TYPE_CHECKING, - Any, - Collection, - Dict, - Generator, - List, - Match, - Optional, - Pattern, - Set, - Tuple, - Type, - Union, - cast, - no_type_check, -) +"""The networks module contains types for common network-related fields.""" +from __future__ import annotations as _annotations -from . import errors -from .utils import Representation, update_not_none -from .validators import constr_length_validator, str_validator +import dataclasses as _dataclasses +import re +from importlib.metadata import version +from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network +from typing import TYPE_CHECKING, Any + +from pydantic_core import MultiHostUrl, PydanticCustomError, Url, core_schema +from typing_extensions import Annotated, TypeAlias + +from ._internal import _fields, _repr, _schema_generation_shared +from ._migration import getattr_migration +from .annotated_handlers import GetCoreSchemaHandler +from .json_schema import JsonSchemaValue if TYPE_CHECKING: import email_validator - from typing_extensions import TypedDict - from .config import BaseConfig - from .fields import ModelField - from .typing import AnyCallable - - CallableGenerator = Generator[AnyCallable, None, None] - - class Parts(TypedDict, total=False): - scheme: str - user: Optional[str] - password: Optional[str] - ipv4: Optional[str] - ipv6: Optional[str] - domain: Optional[str] - port: Optional[str] - path: Optional[str] - query: Optional[str] - fragment: Optional[str] - - class HostParts(TypedDict, total=False): - host: str - tld: Optional[str] - host_type: Optional[str] - port: Optional[str] - rebuild: bool + NetworkType: TypeAlias = 'str | bytes | int | tuple[str | bytes | int, str | int]' else: email_validator = None - class Parts(dict): - pass - - -NetworkType = Union[str, bytes, int, Tuple[Union[str, bytes, int], Union[str, int]]] __all__ = [ 'AnyUrl', 'AnyHttpUrl', 'FileUrl', 'HttpUrl', - 'stricturl', + 'UrlConstraints', 'EmailStr', 'NameEmail', 'IPvAnyAddress', @@ -86,489 +41,321 @@ __all__ = [ 'RedisDsn', 'MongoDsn', 'KafkaDsn', + 'NatsDsn', 'validate_email', + 'MySQLDsn', + 'MariaDBDsn', ] -_url_regex_cache = None -_multi_host_url_regex_cache = None -_ascii_domain_regex_cache 
= None -_int_domain_regex_cache = None -_host_regex_cache = None -_host_regex = ( - r'(?:' - r'(?P<ipv4>(?:\d{1,3}\.){3}\d{1,3})(?=$|[/:#?])|' # ipv4 - r'(?P<ipv6>\[[A-F0-9]*:[A-F0-9:]+\])(?=$|[/:#?])|' # ipv6 - r'(?P<domain>[^\s/:?#]+)' # domain, validation occurs later - r')?' - r'(?::(?P<port>\d+))?' # port +@_dataclasses.dataclass +class UrlConstraints(_fields.PydanticMetadata): + """Url constraints. + + Attributes: + max_length: The maximum length of the url. Defaults to `None`. + allowed_schemes: The allowed schemes. Defaults to `None`. + host_required: Whether the host is required. Defaults to `None`. + default_host: The default host. Defaults to `None`. + default_port: The default port. Defaults to `None`. + default_path: The default path. Defaults to `None`. + """ + + max_length: int | None = None + allowed_schemes: list[str] | None = None + host_required: bool | None = None + default_host: str | None = None + default_port: int | None = None + default_path: str | None = None + + def __hash__(self) -> int: + return hash( + ( + self.max_length, + tuple(self.allowed_schemes) if self.allowed_schemes is not None else None, + self.host_required, + self.default_host, + self.default_port, + self.default_path, + ) + ) + + +AnyUrl = Url +"""Base type for all URLs. + +* Any scheme allowed +* Top-level domain (TLD) not required +* Host required + +Assuming an input URL of `http://samuel:pass@example.com:8000/the/path/?query=here#fragment=is;this=bit`, +the types export the following properties: + +- `scheme`: the URL scheme (`http`), always set. +- `host`: the URL host (`example.com`), always set. +- `username`: optional username if included (`samuel`). +- `password`: optional password if included (`pass`). +- `port`: optional port (`8000`). +- `path`: optional path (`/the/path/`). +- `query`: optional URL query (for example, `GET` arguments or "search string", such as `query=here`). +- `fragment`: optional fragment (`fragment=is;this=bit`). +""" +AnyHttpUrl = Annotated[Url, UrlConstraints(allowed_schemes=['http', 'https'])] +"""A type that will accept any http or https URL. + +* TLD not required +* Host required +""" +HttpUrl = Annotated[Url, UrlConstraints(max_length=2083, allowed_schemes=['http', 'https'])] +"""A type that will accept any http or https URL. + +* TLD required +* Host required +* Max length 2083 + +```py +from pydantic import BaseModel, HttpUrl, ValidationError + +class MyModel(BaseModel): + url: HttpUrl + +m = MyModel(url='http://www.example.com') +print(m.url) +#> http://www.example.com/ + +try: + MyModel(url='ftp://invalid.url') +except ValidationError as e: + print(e) + ''' + 1 validation error for MyModel + url + URL scheme should be 'http' or 'https' [type=url_scheme, input_value='ftp://invalid.url', input_type=str] + ''' + +try: + MyModel(url='not a url') +except ValidationError as e: + print(e) + ''' + 1 validation error for MyModel + url + Input should be a valid URL, relative URL without a base [type=url_parsing, input_value='not a url', input_type=str] + ''' +``` + +"International domains" (e.g.
a URL where the host or TLD includes non-ascii characters) will be encoded via +[punycode](https://en.wikipedia.org/wiki/Punycode) (see +[this article](https://www.xudongz.com/blog/2017/idn-phishing/) for a good description of why this is important): + +```py +from pydantic import BaseModel, HttpUrl + +class MyModel(BaseModel): + url: HttpUrl + +m1 = MyModel(url='http://puny£code.com') +print(m1.url) +#> http://xn--punycode-eja.com/ +m2 = MyModel(url='https://www.аррӏе.com/') +print(m2.url) +#> https://www.xn--80ak6aa92e.com/ +m3 = MyModel(url='https://www.example.珠宝/') +print(m3.url) +#> https://www.example.xn--pbt977c/ +``` + + +!!! warning "Underscores in Hostnames" + In Pydantic, underscores are allowed in all parts of a domain except the TLD. + Technically this might be wrong - in theory the hostname cannot have underscores, but subdomains can. + + To explain this; consider the following two cases: + + - `exam_ple.co.uk`: the hostname is `exam_ple`, which should not be allowed since it contains an underscore. + - `foo_bar.example.com` the hostname is `example`, which should be allowed since the underscore is in the subdomain. + + Without having an exhaustive list of TLDs, it would be impossible to differentiate between these two. Therefore + underscores are allowed, but you can always do further validation in a validator if desired. + + Also, Chrome, Firefox, and Safari all currently accept `http://exam_ple.com` as a URL, so we're in good + (or at least big) company. +""" +FileUrl = Annotated[Url, UrlConstraints(allowed_schemes=['file'])] +"""A type that will accept any file URL. + +* Host not required +""" +PostgresDsn = Annotated[ + MultiHostUrl, + UrlConstraints( + host_required=True, + allowed_schemes=[ + 'postgres', + 'postgresql', + 'postgresql+asyncpg', + 'postgresql+pg8000', + 'postgresql+psycopg', + 'postgresql+psycopg2', + 'postgresql+psycopg2cffi', + 'postgresql+py-postgresql', + 'postgresql+pygresql', + ], + ), +] +"""A type that will accept any Postgres DSN. + +* User info required +* TLD not required +* Host required +* Supports multiple hosts + +If further validation is required, these properties can be used by validators to enforce specific behaviour: + +```py +from pydantic import ( + BaseModel, + HttpUrl, + PostgresDsn, + ValidationError, + field_validator, ) -_scheme_regex = r'(?:(?P<scheme>[a-z][a-z0-9+\-.]+)://)?' # scheme https://tools.ietf.org/html/rfc3986#appendix-A -_user_info_regex = r'(?:(?P<user>[^\s:/]*)(?::(?P<password>[^\s/]*))?@)?' -_path_regex = r'(?P<path>/[^\s?#]*)?' -_query_regex = r'(?:\?(?P<query>[^\s#]*))?' -_fragment_regex = r'(?:#(?P<fragment>[^\s#]*))?' - -def url_regex() -> Pattern[str]: - global _url_regex_cache - if _url_regex_cache is None: - _url_regex_cache = re.compile( - rf'{_scheme_regex}{_user_info_regex}{_host_regex}{_path_regex}{_query_regex}{_fragment_regex}', - re.IGNORECASE, - ) - return _url_regex_cache - - -def multi_host_url_regex() -> Pattern[str]: - """ - Compiled multi host url regex. - - Additionally to `url_regex` it allows to match multiple hosts. - E.g.
host1.db.net,host2.db.net - """ - global _multi_host_url_regex_cache - if _multi_host_url_regex_cache is None: - _multi_host_url_regex_cache = re.compile( - rf'{_scheme_regex}{_user_info_regex}' - r'(?P<hosts>([^/]*))' # validation occurs later - rf'{_path_regex}{_query_regex}{_fragment_regex}', - re.IGNORECASE, - ) - return _multi_host_url_regex_cache - - -def ascii_domain_regex() -> Pattern[str]: - global _ascii_domain_regex_cache - if _ascii_domain_regex_cache is None: - ascii_chunk = r'[_0-9a-z](?:[-_0-9a-z]{0,61}[_0-9a-z])?' - ascii_domain_ending = r'(?P<tld>\.[a-z]{2,63})?\.?' - _ascii_domain_regex_cache = re.compile( - fr'(?:{ascii_chunk}\.)*?{ascii_chunk}{ascii_domain_ending}', re.IGNORECASE - ) - return _ascii_domain_regex_cache - - -def int_domain_regex() -> Pattern[str]: - global _int_domain_regex_cache - if _int_domain_regex_cache is None: - int_chunk = r'[_0-9a-\U00040000](?:[-_0-9a-\U00040000]{0,61}[_0-9a-\U00040000])?' - int_domain_ending = r'(?P<tld>(\.[^\W\d_]{2,63})|(\.(?:xn--)[_0-9a-z-]{2,63}))?\.?' - _int_domain_regex_cache = re.compile(fr'(?:{int_chunk}\.)*?{int_chunk}{int_domain_ending}', re.IGNORECASE) - return _int_domain_regex_cache - - -def host_regex() -> Pattern[str]: - global _host_regex_cache - if _host_regex_cache is None: - _host_regex_cache = re.compile( - _host_regex, - re.IGNORECASE, - ) - return _host_regex_cache - - -class AnyUrl(str): - strip_whitespace = True - min_length = 1 - max_length = 2**16 - allowed_schemes: Optional[Collection[str]] = None - tld_required: bool = False - user_required: bool = False - host_required: bool = True - hidden_parts: Set[str] = set() - - __slots__ = ('scheme', 'user', 'password', 'host', 'tld', 'host_type', 'port', 'path', 'query', 'fragment') - - @no_type_check - def __new__(cls, url: Optional[str], **kwargs) -> object: - return str.__new__(cls, cls.build(**kwargs) if url is None else url) - - def __init__( - self, - url: str, - *, - scheme: str, - user: Optional[str] = None, - password: Optional[str] = None, - host: Optional[str] = None, - tld: Optional[str] = None, - host_type: str = 'domain', - port: Optional[str] = None, - path: Optional[str] = None, - query: Optional[str] = None, - fragment: Optional[str] = None, - ) -> None: - str.__init__(url) - self.scheme = scheme - self.user = user - self.password = password - self.host = host - self.tld = tld - self.host_type = host_type - self.port = port - self.path = path - self.query = query - self.fragment = fragment - - @classmethod - def build( - cls, - *, - scheme: str, - user: Optional[str] = None, - password: Optional[str] = None, - host: str, - port: Optional[str] = None, - path: Optional[str] = None, - query: Optional[str] = None, - fragment: Optional[str] = None, - **_kwargs: str, - ) -> str: - parts = Parts( - scheme=scheme, - user=user, - password=password, - host=host, - port=port, - path=path, - query=query, - fragment=fragment, - **_kwargs, # type: ignore[misc] - ) - - url = scheme + '://' - if user: - url += user - if password: - url += ':' + password - if user or password: - url += '@' - url += host - if port and ('port' not in cls.hidden_parts or cls.get_default_parts(parts).get('port') != port): - url += ':' + port - if path: - url += path - if query: - url += '?'
+ query - if fragment: - url += '#' + fragment - return url - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none(field_schema, minLength=cls.min_length, maxLength=cls.max_length, format='uri') - - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield cls.validate - - @classmethod - def validate(cls, value: Any, field: 'ModelField', config: 'BaseConfig') -> 'AnyUrl': - if value.__class__ == cls: - return value - value = str_validator(value) - if cls.strip_whitespace: - value = value.strip() - url: str = cast(str, constr_length_validator(value, field, config)) - - m = cls._match_url(url) - # the regex should always match, if it doesn't please report with details of the URL tried - assert m, 'URL regex failed unexpectedly' - - original_parts = cast('Parts', m.groupdict()) - parts = cls.apply_default_parts(original_parts) - parts = cls.validate_parts(parts) - - if m.end() != len(url): - raise errors.UrlExtraError(extra=url[m.end() :]) - - return cls._build_url(m, url, parts) - - @classmethod - def _build_url(cls, m: Match[str], url: str, parts: 'Parts') -> 'AnyUrl': - """ - Validate hosts and build the AnyUrl object. Split from `validate` so this method - can be altered in `MultiHostDsn`. - """ - host, tld, host_type, rebuild = cls.validate_host(parts) - - return cls( - None if rebuild else url, - scheme=parts['scheme'], - user=parts['user'], - password=parts['password'], - host=host, - tld=tld, - host_type=host_type, - port=parts['port'], - path=parts['path'], - query=parts['query'], - fragment=parts['fragment'], - ) - - @staticmethod - def _match_url(url: str) -> Optional[Match[str]]: - return url_regex().match(url) - - @staticmethod - def _validate_port(port: Optional[str]) -> None: - if port is not None and int(port) > 65_535: - raise errors.UrlPortError() - - @classmethod - def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts': - """ - A method used to validate parts of a URL. 
- Could be overridden to set default values for parts if missing - """ - scheme = parts['scheme'] - if scheme is None: - raise errors.UrlSchemeError() - - if cls.allowed_schemes and scheme.lower() not in cls.allowed_schemes: - raise errors.UrlSchemePermittedError(set(cls.allowed_schemes)) - - if validate_port: - cls._validate_port(parts['port']) - - user = parts['user'] - if cls.user_required and user is None: - raise errors.UrlUserInfoError() - - return parts - - @classmethod - def validate_host(cls, parts: 'Parts') -> Tuple[str, Optional[str], str, bool]: - tld, host_type, rebuild = None, None, False - for f in ('domain', 'ipv4', 'ipv6'): - host = parts[f] # type: ignore[literal-required] - if host: - host_type = f - break - - if host is None: - if cls.host_required: - raise errors.UrlHostError() - elif host_type == 'domain': - is_international = False - d = ascii_domain_regex().fullmatch(host) - if d is None: - d = int_domain_regex().fullmatch(host) - if d is None: - raise errors.UrlHostError() - is_international = True - - tld = d.group('tld') - if tld is None and not is_international: - d = int_domain_regex().fullmatch(host) - assert d is not None - tld = d.group('tld') - is_international = True - - if tld is not None: - tld = tld[1:] - elif cls.tld_required: - raise errors.UrlHostTldError() - - if is_international: - host_type = 'int_domain' - rebuild = True - host = host.encode('idna').decode('ascii') - if tld is not None: - tld = tld.encode('idna').decode('ascii') - - return host, tld, host_type, rebuild # type: ignore - - @staticmethod - def get_default_parts(parts: 'Parts') -> 'Parts': - return {} - - @classmethod - def apply_default_parts(cls, parts: 'Parts') -> 'Parts': - for key, value in cls.get_default_parts(parts).items(): - if not parts[key]: # type: ignore[literal-required] - parts[key] = value # type: ignore[literal-required] - return parts - - def __repr__(self) -> str: - extra = ', '.join(f'{n}={getattr(self, n)!r}' for n in self.__slots__ if getattr(self, n) is not None) - return f'{self.__class__.__name__}({super().__repr__()}, {extra})' - - -class AnyHttpUrl(AnyUrl): - allowed_schemes = {'http', 'https'} - - __slots__ = () - - -class HttpUrl(AnyHttpUrl): - tld_required = True - # https://stackoverflow.com/questions/417142/what-is-the-maximum-length-of-a-url-in-different-browsers - max_length = 2083 - hidden_parts = {'port'} - - @staticmethod - def get_default_parts(parts: 'Parts') -> 'Parts': - return {'port': '80' if parts['scheme'] == 'http' else '443'} - - -class FileUrl(AnyUrl): - allowed_schemes = {'file'} - host_required = False - - __slots__ = () - - -class MultiHostDsn(AnyUrl): - __slots__ = AnyUrl.__slots__ + ('hosts',) - - def __init__(self, *args: Any, hosts: Optional[List['HostParts']] = None, **kwargs: Any): - super().__init__(*args, **kwargs) - self.hosts = hosts - - @staticmethod - def _match_url(url: str) -> Optional[Match[str]]: - return multi_host_url_regex().match(url) - - @classmethod - def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts': - return super().validate_parts(parts, validate_port=False) - - @classmethod - def _build_url(cls, m: Match[str], url: str, parts: 'Parts') -> 'MultiHostDsn': - hosts_parts: List['HostParts'] = [] - host_re = host_regex() - for host in m.groupdict()['hosts'].split(','): - d: Parts = host_re.match(host).groupdict() # type: ignore - host, tld, host_type, rebuild = cls.validate_host(d) - port = d.get('port') - cls._validate_port(port) - hosts_parts.append( - { - 'host': host, - 
'host_type': host_type, - 'tld': tld, - 'rebuild': rebuild, - 'port': port, - } - ) - - if len(hosts_parts) > 1: - return cls( - None if any([hp['rebuild'] for hp in hosts_parts]) else url, - scheme=parts['scheme'], - user=parts['user'], - password=parts['password'], - path=parts['path'], - query=parts['query'], - fragment=parts['fragment'], - host_type=None, - hosts=hosts_parts, - ) - else: - # backwards compatibility with single host - host_part = hosts_parts[0] - return cls( - None if host_part['rebuild'] else url, - scheme=parts['scheme'], - user=parts['user'], - password=parts['password'], - host=host_part['host'], - tld=host_part['tld'], - host_type=host_part['host_type'], - port=host_part.get('port'), - path=parts['path'], - query=parts['query'], - fragment=parts['fragment'], - ) - - -class PostgresDsn(MultiHostDsn): - allowed_schemes = { - 'postgres', - 'postgresql', - 'postgresql+asyncpg', - 'postgresql+pg8000', - 'postgresql+psycopg2', - 'postgresql+psycopg2cffi', - 'postgresql+py-postgresql', - 'postgresql+pygresql', - } - user_required = True - - __slots__ = () - - -class CockroachDsn(AnyUrl): - allowed_schemes = { - 'cockroachdb', - 'cockroachdb+psycopg2', - 'cockroachdb+asyncpg', - } - user_required = True - - -class AmqpDsn(AnyUrl): - allowed_schemes = {'amqp', 'amqps'} - host_required = False - - -class RedisDsn(AnyUrl): - __slots__ = () - allowed_schemes = {'redis', 'rediss'} - host_required = False - - @staticmethod - def get_default_parts(parts: 'Parts') -> 'Parts': - return { - 'domain': 'localhost' if not (parts['ipv4'] or parts['ipv6']) else '', - 'port': '6379', - 'path': '/0', - } - - -class MongoDsn(AnyUrl): - allowed_schemes = {'mongodb'} - - # TODO: Needed to generic "Parts" for "Replica Set", "Sharded Cluster", and other mongodb deployment modes - @staticmethod - def get_default_parts(parts: 'Parts') -> 'Parts': - return { - 'port': '27017', - } - - -class KafkaDsn(AnyUrl): - allowed_schemes = {'kafka'} - - @staticmethod - def get_default_parts(parts: 'Parts') -> 'Parts': - return { - 'domain': 'localhost', - 'port': '9092', - } - - -def stricturl( - *, - strip_whitespace: bool = True, - min_length: int = 1, - max_length: int = 2**16, - tld_required: bool = True, - host_required: bool = True, - allowed_schemes: Optional[Collection[str]] = None, -) -> Type[AnyUrl]: - # use kwargs then define conf in a dict to aid with IDE type hinting - namespace = dict( - strip_whitespace=strip_whitespace, - min_length=min_length, - max_length=max_length, - tld_required=tld_required, - host_required=host_required, - allowed_schemes=allowed_schemes, - ) - return type('UrlValue', (AnyUrl,), namespace) +class MyModel(BaseModel): + url: HttpUrl + +m = MyModel(url='http://www.example.com') + +# the repr() method for a url will display all properties of the url +print(repr(m.url)) +#> Url('http://www.example.com/') +print(m.url.scheme) +#> http +print(m.url.host) +#> www.example.com +print(m.url.port) +#> 80 + +class MyDatabaseModel(BaseModel): + db: PostgresDsn + + @field_validator('db') + def check_db_name(cls, v): + assert v.path and len(v.path) > 1, 'database must be provided' + return v + +m = MyDatabaseModel(db='postgres://user:pass@localhost:5432/foobar') +print(m.db) +#> postgres://user:pass@localhost:5432/foobar + +try: + MyDatabaseModel(db='postgres://user:pass@localhost:5432') +except ValidationError as e: + print(e) + ''' + 1 validation error for MyDatabaseModel + db + Assertion failed, database must be provided + assert (None) + + where None = 
MultiHostUrl('postgres://user:pass@localhost:5432').path [type=assertion_error, input_value='postgres://user:pass@localhost:5432', input_type=str] + ''' +``` +""" + +CockroachDsn = Annotated[ + Url, + UrlConstraints( + host_required=True, + allowed_schemes=[ + 'cockroachdb', + 'cockroachdb+psycopg2', + 'cockroachdb+asyncpg', + ], + ), +] +"""A type that will accept any Cockroach DSN. + +* User info required +* TLD not required +* Host required +""" +AmqpDsn = Annotated[Url, UrlConstraints(allowed_schemes=['amqp', 'amqps'])] +"""A type that will accept any AMQP DSN. + +* User info required +* TLD not required +* Host required +""" +RedisDsn = Annotated[ + Url, + UrlConstraints(allowed_schemes=['redis', 'rediss'], default_host='localhost', default_port=6379, default_path='/0'), +] +"""A type that will accept any Redis DSN. + +* User info required +* TLD not required +* Host required (e.g., `rediss://:pass@localhost`) +""" +MongoDsn = Annotated[MultiHostUrl, UrlConstraints(allowed_schemes=['mongodb', 'mongodb+srv'], default_port=27017)] +"""A type that will accept any MongoDB DSN. + +* User info not required +* Database name not required +* Port not required +* User info may be passed without user part (e.g., `mongodb://mongodb0.example.com:27017`). +""" +KafkaDsn = Annotated[Url, UrlConstraints(allowed_schemes=['kafka'], default_host='localhost', default_port=9092)] +"""A type that will accept any Kafka DSN. + +* User info required +* TLD not required +* Host required +""" +NatsDsn = Annotated[ + MultiHostUrl, UrlConstraints(allowed_schemes=['nats', 'tls', 'ws'], default_host='localhost', default_port=4222) +] +"""A type that will accept any NATS DSN. + +NATS is a connective technology built for the ever increasingly hyper-connected world. +It is a single technology that enables applications to securely communicate across +any combination of cloud vendors, on-premise, edge, web and mobile, and devices. +More: https://nats.io +""" +MySQLDsn = Annotated[ + Url, + UrlConstraints( + allowed_schemes=[ + 'mysql', + 'mysql+mysqlconnector', + 'mysql+aiomysql', + 'mysql+asyncmy', + 'mysql+mysqldb', + 'mysql+pymysql', + 'mysql+cymysql', + 'mysql+pyodbc', + ], + default_port=3306, + ), +] +"""A type that will accept any MySQL DSN. + +* User info required +* TLD not required +* Host required +""" +MariaDBDsn = Annotated[ + Url, + UrlConstraints( + allowed_schemes=['mariadb', 'mariadb+mariadbconnector', 'mariadb+pymysql'], + default_port=3306, + ), +] +"""A type that will accept any MariaDB DSN. + +* User info required +* TLD not required +* Host required +""" def import_email_validator() -> None: @@ -577,27 +364,95 @@ def import_email_validator() -> None: import email_validator except ImportError as e: raise ImportError('email-validator is not installed, run `pip install pydantic[email]`') from e + if not version('email-validator').partition('.')[0] == '2': + raise ImportError('email-validator version >= 2.0 required, run pip install -U email-validator') -class EmailStr(str): - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type='string', format='email') +if TYPE_CHECKING: + EmailStr = Annotated[str, ...] 
+else: - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - # included here and below so the error happens straight away - import_email_validator() + class EmailStr: + """ + Info: + To use this type, you need to install the optional + [`email-validator`](https://github.com/JoshData/python-email-validator) package: - yield str_validator - yield cls.validate + ```bash + pip install email-validator + ``` - @classmethod - def validate(cls, value: Union[str]) -> str: - return validate_email(value)[1] + Validate email addresses. + + ```py + from pydantic import BaseModel, EmailStr + + class Model(BaseModel): + email: EmailStr + + print(Model(email='contact@mail.com')) + #> email='contact@mail.com' + ``` + """ # noqa: D212 + + @classmethod + def __get_pydantic_core_schema__( + cls, + _source: type[Any], + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + import_email_validator() + return core_schema.no_info_after_validator_function(cls._validate, core_schema.str_schema()) + + @classmethod + def __get_pydantic_json_schema__( + cls, core_schema: core_schema.CoreSchema, handler: _schema_generation_shared.GetJsonSchemaHandler + ) -> JsonSchemaValue: + field_schema = handler(core_schema) + field_schema.update(type='string', format='email') + return field_schema + + @classmethod + def _validate(cls, __input_value: str) -> str: + return validate_email(__input_value)[1] -class NameEmail(Representation): +class NameEmail(_repr.Representation): + """ + Info: + To use this type, you need to install the optional + [`email-validator`](https://github.com/JoshData/python-email-validator) package: + + ```bash + pip install email-validator + ``` + + Validate a name and email address combination, as specified by + [RFC 5322](https://datatracker.ietf.org/doc/html/rfc5322#section-3.4). + + The `NameEmail` has two properties: `name` and `email`. + In case the `name` is not provided, it's inferred from the email address. 
+ + ```py + from pydantic import BaseModel, NameEmail + + class User(BaseModel): + email: NameEmail + + user = User(email='Fred Bloggs ') + print(user.email) + #> Fred Bloggs + print(user.email.name) + #> Fred Bloggs + + user = User(email='fred.bloggs@example.com') + print(user.email) + #> fred.bloggs + print(user.email.name) + #> fred.bloggs + ``` + """ # noqa: D212 + __slots__ = 'name', 'email' def __init__(self, name: str, email: str): @@ -608,39 +463,76 @@ class NameEmail(Representation): return isinstance(other, NameEmail) and (self.name, self.email) == (other.name, other.email) @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + def __get_pydantic_json_schema__( + cls, core_schema: core_schema.CoreSchema, handler: _schema_generation_shared.GetJsonSchemaHandler + ) -> JsonSchemaValue: + field_schema = handler(core_schema) field_schema.update(type='string', format='name-email') + return field_schema @classmethod - def __get_validators__(cls) -> 'CallableGenerator': + def __get_pydantic_core_schema__( + cls, + _source: type[Any], + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: import_email_validator() - - yield cls.validate + return core_schema.no_info_after_validator_function( + cls._validate, + core_schema.union_schema( + [core_schema.is_instance_schema(cls), core_schema.str_schema()], + custom_error_type='name_email_type', + custom_error_message='Input is not a valid NameEmail', + ), + serialization=core_schema.to_string_ser_schema(), + ) @classmethod - def validate(cls, value: Any) -> 'NameEmail': - if value.__class__ == cls: - return value - value = str_validator(value) - return cls(*validate_email(value)) + def _validate(cls, __input_value: NameEmail | str) -> NameEmail: + if isinstance(__input_value, cls): + return __input_value + else: + name, email = validate_email(__input_value) # type: ignore[arg-type] + return cls(name, email) def __str__(self) -> str: return f'{self.name} <{self.email}>' -class IPvAnyAddress(_BaseAddress): +class IPvAnyAddress: + """Validate an IPv4 or IPv6 address. 
+ + ```py + from pydantic import BaseModel + from pydantic.networks import IPvAnyAddress + + class IpModel(BaseModel): + ip: IPvAnyAddress + + print(IpModel(ip='127.0.0.1')) + #> ip=IPv4Address('127.0.0.1') + + try: + IpModel(ip='http://www.example.com') + except ValueError as e: + print(e.errors()) + ''' + [ + { + 'type': 'ip_any_address', + 'loc': ('ip',), + 'msg': 'value is not a valid IPv4 or IPv6 address', + 'input': 'http://www.example.com', + } + ] + ''' + ``` + """ + __slots__ = () - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type='string', format='ipvanyaddress') - - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield cls.validate - - @classmethod - def validate(cls, value: Union[str, bytes, int]) -> Union[IPv4Address, IPv6Address]: + def __new__(cls, value: Any) -> IPv4Address | IPv6Address: + """Validate an IPv4 or IPv6 address.""" try: return IPv4Address(value) except ValueError: @@ -649,22 +541,38 @@ class IPvAnyAddress(_BaseAddress): try: return IPv6Address(value) except ValueError: - raise errors.IPvAnyAddressError() + raise PydanticCustomError('ip_any_address', 'value is not a valid IPv4 or IPv6 address') + + @classmethod + def __get_pydantic_json_schema__( + cls, core_schema: core_schema.CoreSchema, handler: _schema_generation_shared.GetJsonSchemaHandler + ) -> JsonSchemaValue: + field_schema = {} + field_schema.update(type='string', format='ipvanyaddress') + return field_schema + + @classmethod + def __get_pydantic_core_schema__( + cls, + _source: type[Any], + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + return core_schema.no_info_plain_validator_function( + cls._validate, serialization=core_schema.to_string_ser_schema() + ) + + @classmethod + def _validate(cls, __input_value: Any) -> IPv4Address | IPv6Address: + return cls(__input_value) # type: ignore[return-value] -class IPvAnyInterface(_BaseAddress): +class IPvAnyInterface: + """Validate an IPv4 or IPv6 interface.""" + __slots__ = () - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type='string', format='ipvanyinterface') - - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield cls.validate - - @classmethod - def validate(cls, value: NetworkType) -> Union[IPv4Interface, IPv6Interface]: + def __new__(cls, value: NetworkType) -> IPv4Interface | IPv6Interface: + """Validate an IPv4 or IPv6 interface.""" try: return IPv4Interface(value) except ValueError: @@ -673,21 +581,39 @@ class IPvAnyInterface(_BaseAddress): try: return IPv6Interface(value) except ValueError: - raise errors.IPvAnyInterfaceError() - - -class IPvAnyNetwork(_BaseNetwork): # type: ignore - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type='string', format='ipvanynetwork') + raise PydanticCustomError('ip_any_interface', 'value is not a valid IPv4 or IPv6 interface') @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield cls.validate + def __get_pydantic_json_schema__( + cls, core_schema: core_schema.CoreSchema, handler: _schema_generation_shared.GetJsonSchemaHandler + ) -> JsonSchemaValue: + field_schema = {} + field_schema.update(type='string', format='ipvanyinterface') + return field_schema @classmethod - def validate(cls, value: NetworkType) -> Union[IPv4Network, IPv6Network]: - # Assume IP Network is defined with a default value for ``strict`` argument. 
+ def __get_pydantic_core_schema__( + cls, + _source: type[Any], + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + return core_schema.no_info_plain_validator_function( + cls._validate, serialization=core_schema.to_string_ser_schema() + ) + + @classmethod + def _validate(cls, __input_value: NetworkType) -> IPv4Interface | IPv6Interface: + return cls(__input_value) # type: ignore[return-value] + + +class IPvAnyNetwork: + """Validate an IPv4 or IPv6 network.""" + + __slots__ = () + + def __new__(cls, value: NetworkType) -> IPv4Network | IPv6Network: + """Validate an IPv4 or IPv6 network.""" + # Assume IP Network is defined with a default value for `strict` argument. # Define your own class if you want to specify network address check strictness. try: return IPv4Network(value) @@ -697,40 +623,86 @@ class IPvAnyNetwork(_BaseNetwork): # type: ignore try: return IPv6Network(value) except ValueError: - raise errors.IPvAnyNetworkError() + raise PydanticCustomError('ip_any_network', 'value is not a valid IPv4 or IPv6 network') + + @classmethod + def __get_pydantic_json_schema__( + cls, core_schema: core_schema.CoreSchema, handler: _schema_generation_shared.GetJsonSchemaHandler + ) -> JsonSchemaValue: + field_schema = {} + field_schema.update(type='string', format='ipvanynetwork') + return field_schema + + @classmethod + def __get_pydantic_core_schema__( + cls, + _source: type[Any], + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + return core_schema.no_info_plain_validator_function( + cls._validate, serialization=core_schema.to_string_ser_schema() + ) + + @classmethod + def _validate(cls, __input_value: NetworkType) -> IPv4Network | IPv6Network: + return cls(__input_value) # type: ignore[return-value] -pretty_email_regex = re.compile(r'([\w ]*?) *<(.*)> *') +def _build_pretty_email_regex() -> re.Pattern[str]: + name_chars = r'[\w!#$%&\'*+\-/=?^_`{|}~]' + unquoted_name_group = rf'((?:{name_chars}+\s+)*{name_chars}+)' + quoted_name_group = r'"((?:[^"]|\")+)"' + email_group = r'<\s*(.+)\s*>' + return re.compile(rf'\s*(?:{unquoted_name_group}|{quoted_name_group})?\s*{email_group}\s*') -def validate_email(value: Union[str]) -> Tuple[str, str]: - """ - Brutally simple email address validation. Note unlike most email address validation - * raw ip address (literal) domain parts are not allowed. - * "John Doe " style "pretty" email addresses are processed - * the local part check is extremely basic. This raises the possibility of unicode spoofing, but no better - solution is really possible. - * spaces are striped from the beginning and end of addresses but no error is raised +pretty_email_regex = _build_pretty_email_regex() - See RFC 5322 but treat it with suspicion, there seems to exist no universally acknowledged test for a valid email! +MAX_EMAIL_LENGTH = 2048 +"""Maximum length for an email. +A somewhat arbitrary but very generous number compared to what is allowed by most implementations. +""" + + +def validate_email(value: str) -> tuple[str, str]: + """Email address validation using [email-validator](https://pypi.org/project/email-validator/). + + Note: + Note that: + + * Raw IP address (literal) domain parts are not allowed. + * `"John Doe "` style "pretty" email addresses are processed. + * Spaces are striped from the beginning and end of addresses, but no error is raised. 
""" if email_validator is None: import_email_validator() + if len(value) > MAX_EMAIL_LENGTH: + raise PydanticCustomError( + 'value_error', + 'value is not a valid email address: {reason}', + {'reason': f'Length must not exceed {MAX_EMAIL_LENGTH} characters'}, + ) + m = pretty_email_regex.fullmatch(value) - name: Optional[str] = None + name: str | None = None if m: - name, value = m.groups() + unquoted_name, quoted_name, value = m.groups() + name = unquoted_name or quoted_name email = value.strip() try: - email_validator.validate_email(email, check_deliverability=False) + parts = email_validator.validate_email(email, check_deliverability=False) except email_validator.EmailNotValidError as e: - raise errors.EmailError() from e + raise PydanticCustomError( + 'value_error', 'value is not a valid email address: {reason}', {'reason': str(e.args[0])} + ) from e - at_index = email.index('@') - local_part = email[:at_index] # RFC 5321, local part must be case-sensitive. - global_part = email[at_index:].lower() + email = parts.normalized + assert email is not None + name = name or parts.local_part + return name, email - return name or local_part, local_part + global_part + +__getattr__ = getattr_migration(__name__) diff --git a/lib/pydantic/parse.py b/lib/pydantic/parse.py index 7ac330ca..ceee6342 100644 --- a/lib/pydantic/parse.py +++ b/lib/pydantic/parse.py @@ -1,66 +1,4 @@ -import json -import pickle -from enum import Enum -from pathlib import Path -from typing import Any, Callable, Union +"""The `parse` module is a backport module from V1.""" +from ._migration import getattr_migration -from .types import StrBytes - - -class Protocol(str, Enum): - json = 'json' - pickle = 'pickle' - - -def load_str_bytes( - b: StrBytes, - *, - content_type: str = None, - encoding: str = 'utf8', - proto: Protocol = None, - allow_pickle: bool = False, - json_loads: Callable[[str], Any] = json.loads, -) -> Any: - if proto is None and content_type: - if content_type.endswith(('json', 'javascript')): - pass - elif allow_pickle and content_type.endswith('pickle'): - proto = Protocol.pickle - else: - raise TypeError(f'Unknown content-type: {content_type}') - - proto = proto or Protocol.json - - if proto == Protocol.json: - if isinstance(b, bytes): - b = b.decode(encoding) - return json_loads(b) - elif proto == Protocol.pickle: - if not allow_pickle: - raise RuntimeError('Trying to decode with pickle with allow_pickle=False') - bb = b if isinstance(b, bytes) else b.encode() - return pickle.loads(bb) - else: - raise TypeError(f'Unknown protocol: {proto}') - - -def load_file( - path: Union[str, Path], - *, - content_type: str = None, - encoding: str = 'utf8', - proto: Protocol = None, - allow_pickle: bool = False, - json_loads: Callable[[str], Any] = json.loads, -) -> Any: - path = Path(path) - b = path.read_bytes() - if content_type is None: - if path.suffix in ('.js', '.json'): - proto = Protocol.json - elif path.suffix == '.pkl': - proto = Protocol.pickle - - return load_str_bytes( - b, proto=proto, content_type=content_type, encoding=encoding, allow_pickle=allow_pickle, json_loads=json_loads - ) +__getattr__ = getattr_migration(__name__) diff --git a/lib/pydantic/plugin/__init__.py b/lib/pydantic/plugin/__init__.py new file mode 100644 index 00000000..84197006 --- /dev/null +++ b/lib/pydantic/plugin/__init__.py @@ -0,0 +1,170 @@ +"""Usage docs: https://docs.pydantic.dev/2.6/concepts/plugins#build-a-plugin + +Plugin interface for Pydantic plugins, and related types. 
+""" +from __future__ import annotations + +from typing import Any, Callable, NamedTuple + +from pydantic_core import CoreConfig, CoreSchema, ValidationError +from typing_extensions import Literal, Protocol, TypeAlias + +__all__ = ( + 'PydanticPluginProtocol', + 'BaseValidateHandlerProtocol', + 'ValidatePythonHandlerProtocol', + 'ValidateJsonHandlerProtocol', + 'ValidateStringsHandlerProtocol', + 'NewSchemaReturns', + 'SchemaTypePath', + 'SchemaKind', +) + +NewSchemaReturns: TypeAlias = 'tuple[ValidatePythonHandlerProtocol | None, ValidateJsonHandlerProtocol | None, ValidateStringsHandlerProtocol | None]' + + +class SchemaTypePath(NamedTuple): + """Path defining where `schema_type` was defined, or where `TypeAdapter` was called.""" + + module: str + name: str + + +SchemaKind: TypeAlias = Literal['BaseModel', 'TypeAdapter', 'dataclass', 'create_model', 'validate_call'] + + +class PydanticPluginProtocol(Protocol): + """Protocol defining the interface for Pydantic plugins.""" + + def new_schema_validator( + self, + schema: CoreSchema, + schema_type: Any, + schema_type_path: SchemaTypePath, + schema_kind: SchemaKind, + config: CoreConfig | None, + plugin_settings: dict[str, object], + ) -> tuple[ + ValidatePythonHandlerProtocol | None, ValidateJsonHandlerProtocol | None, ValidateStringsHandlerProtocol | None + ]: + """This method is called for each plugin every time a new [`SchemaValidator`][pydantic_core.SchemaValidator] + is created. + + It should return an event handler for each of the three validation methods, or `None` if the plugin does not + implement that method. + + Args: + schema: The schema to validate against. + schema_type: The original type which the schema was created from, e.g. the model class. + schema_type_path: Path defining where `schema_type` was defined, or where `TypeAdapter` was called. + schema_kind: The kind of schema to validate against. + config: The config to use for validation. + plugin_settings: Any plugin settings. + + Returns: + A tuple of optional event handlers for each of the three validation methods - + `validate_python`, `validate_json`, `validate_strings`. + """ + raise NotImplementedError('Pydantic plugins should implement `new_schema_validator`.') + + +class BaseValidateHandlerProtocol(Protocol): + """Base class for plugin callbacks protocols. + + You shouldn't implement this protocol directly, instead use one of the subclasses with adds the correctly + typed `on_error` method. + """ + + on_enter: Callable[..., None] + """`on_enter` is changed to be more specific on all subclasses""" + + def on_success(self, result: Any) -> None: + """Callback to be notified of successful validation. + + Args: + result: The result of the validation. + """ + return + + def on_error(self, error: ValidationError) -> None: + """Callback to be notified of validation errors. + + Args: + error: The validation error. + """ + return + + def on_exception(self, exception: Exception) -> None: + """Callback to be notified of validation exceptions. + + Args: + exception: The exception raised during validation. + """ + return + + +class ValidatePythonHandlerProtocol(BaseValidateHandlerProtocol, Protocol): + """Event handler for `SchemaValidator.validate_python`.""" + + def on_enter( + self, + input: Any, + *, + strict: bool | None = None, + from_attributes: bool | None = None, + context: dict[str, Any] | None = None, + self_instance: Any | None = None, + ) -> None: + """Callback to be notified of validation start, and create an instance of the event handler. 
+ + Args: + input: The input to be validated. + strict: Whether to validate the object in strict mode. + from_attributes: Whether to validate objects as inputs by extracting attributes. + context: The context to use for validation, this is passed to functional validators. + self_instance: An instance of a model to set attributes on from validation, this is used when running + validation from the `__init__` method of a model. + """ + pass + + +class ValidateJsonHandlerProtocol(BaseValidateHandlerProtocol, Protocol): + """Event handler for `SchemaValidator.validate_json`.""" + + def on_enter( + self, + input: str | bytes | bytearray, + *, + strict: bool | None = None, + context: dict[str, Any] | None = None, + self_instance: Any | None = None, + ) -> None: + """Callback to be notified of validation start, and create an instance of the event handler. + + Args: + input: The JSON data to be validated. + strict: Whether to validate the object in strict mode. + context: The context to use for validation, this is passed to functional validators. + self_instance: An instance of a model to set attributes on from validation, this is used when running + validation from the `__init__` method of a model. + """ + pass + + +StringInput: TypeAlias = 'dict[str, StringInput]' + + +class ValidateStringsHandlerProtocol(BaseValidateHandlerProtocol, Protocol): + """Event handler for `SchemaValidator.validate_strings`.""" + + def on_enter( + self, input: StringInput, *, strict: bool | None = None, context: dict[str, Any] | None = None + ) -> None: + """Callback to be notified of validation start, and create an instance of the event handler. + + Args: + input: The string data to be validated. + strict: Whether to validate the object in strict mode. + context: The context to use for validation, this is passed to functional validators. + """ + pass diff --git a/lib/pydantic/plugin/_loader.py b/lib/pydantic/plugin/_loader.py new file mode 100644 index 00000000..9e0e33ca --- /dev/null +++ b/lib/pydantic/plugin/_loader.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import importlib.metadata as importlib_metadata +import warnings +from typing import TYPE_CHECKING, Final, Iterable + +if TYPE_CHECKING: + from . import PydanticPluginProtocol + + +PYDANTIC_ENTRY_POINT_GROUP: Final[str] = 'pydantic' + +# cache of plugins +_plugins: dict[str, PydanticPluginProtocol] | None = None +# return no plugins while loading plugins to avoid recursion and errors while import plugins +# this means that if plugins use pydantic +_loading_plugins: bool = False + + +def get_plugins() -> Iterable[PydanticPluginProtocol]: + """Load plugins for Pydantic. 
+ + Inspired by: https://github.com/pytest-dev/pluggy/blob/1.3.0/src/pluggy/_manager.py#L376-L402 + """ + global _plugins, _loading_plugins + if _loading_plugins: + # this happens when plugins themselves use pydantic, we return no plugins + return () + elif _plugins is None: + _plugins = {} + # set _loading_plugins so any plugins that use pydantic don't themselves use plugins + _loading_plugins = True + try: + for dist in importlib_metadata.distributions(): + for entry_point in dist.entry_points: + if entry_point.group != PYDANTIC_ENTRY_POINT_GROUP: + continue + if entry_point.value in _plugins: + continue + try: + _plugins[entry_point.value] = entry_point.load() + except (ImportError, AttributeError) as e: + warnings.warn( + f'{e.__class__.__name__} while loading the `{entry_point.name}` Pydantic plugin, ' + f'this plugin will not be installed.\n\n{e!r}' + ) + finally: + _loading_plugins = False + + return _plugins.values() diff --git a/lib/pydantic/plugin/_schema_validator.py b/lib/pydantic/plugin/_schema_validator.py new file mode 100644 index 00000000..7186ece6 --- /dev/null +++ b/lib/pydantic/plugin/_schema_validator.py @@ -0,0 +1,138 @@ +"""Pluggable schema validator for pydantic.""" +from __future__ import annotations + +import functools +from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar + +from pydantic_core import CoreConfig, CoreSchema, SchemaValidator, ValidationError +from typing_extensions import Literal, ParamSpec + +if TYPE_CHECKING: + from . import BaseValidateHandlerProtocol, PydanticPluginProtocol, SchemaKind, SchemaTypePath + + +P = ParamSpec('P') +R = TypeVar('R') +Event = Literal['on_validate_python', 'on_validate_json', 'on_validate_strings'] +events: list[Event] = list(Event.__args__) # type: ignore + + +def create_schema_validator( + schema: CoreSchema, + schema_type: Any, + schema_type_module: str, + schema_type_name: str, + schema_kind: SchemaKind, + config: CoreConfig | None = None, + plugin_settings: dict[str, Any] | None = None, +) -> SchemaValidator: + """Create a `SchemaValidator` or `PluggableSchemaValidator` if plugins are installed. + + Returns: + If plugins are installed then return `PluggableSchemaValidator`, otherwise return `SchemaValidator`. + """ + from . 
import SchemaTypePath + from ._loader import get_plugins + + plugins = get_plugins() + if plugins: + return PluggableSchemaValidator( + schema, + schema_type, + SchemaTypePath(schema_type_module, schema_type_name), + schema_kind, + config, + plugins, + plugin_settings or {}, + ) # type: ignore + else: + return SchemaValidator(schema, config) + + +class PluggableSchemaValidator: + """Pluggable schema validator.""" + + __slots__ = '_schema_validator', 'validate_json', 'validate_python', 'validate_strings' + + def __init__( + self, + schema: CoreSchema, + schema_type: Any, + schema_type_path: SchemaTypePath, + schema_kind: SchemaKind, + config: CoreConfig | None, + plugins: Iterable[PydanticPluginProtocol], + plugin_settings: dict[str, Any], + ) -> None: + self._schema_validator = SchemaValidator(schema, config) + + python_event_handlers: list[BaseValidateHandlerProtocol] = [] + json_event_handlers: list[BaseValidateHandlerProtocol] = [] + strings_event_handlers: list[BaseValidateHandlerProtocol] = [] + for plugin in plugins: + try: + p, j, s = plugin.new_schema_validator( + schema, schema_type, schema_type_path, schema_kind, config, plugin_settings + ) + except TypeError as e: # pragma: no cover + raise TypeError(f'Error using plugin `{plugin.__module__}:{plugin.__class__.__name__}`: {e}') from e + if p is not None: + python_event_handlers.append(p) + if j is not None: + json_event_handlers.append(j) + if s is not None: + strings_event_handlers.append(s) + + self.validate_python = build_wrapper(self._schema_validator.validate_python, python_event_handlers) + self.validate_json = build_wrapper(self._schema_validator.validate_json, json_event_handlers) + self.validate_strings = build_wrapper(self._schema_validator.validate_strings, strings_event_handlers) + + def __getattr__(self, name: str) -> Any: + return getattr(self._schema_validator, name) + + +def build_wrapper(func: Callable[P, R], event_handlers: list[BaseValidateHandlerProtocol]) -> Callable[P, R]: + if not event_handlers: + return func + else: + on_enters = tuple(h.on_enter for h in event_handlers if filter_handlers(h, 'on_enter')) + on_successes = tuple(h.on_success for h in event_handlers if filter_handlers(h, 'on_success')) + on_errors = tuple(h.on_error for h in event_handlers if filter_handlers(h, 'on_error')) + on_exceptions = tuple(h.on_exception for h in event_handlers if filter_handlers(h, 'on_exception')) + + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + for on_enter_handler in on_enters: + on_enter_handler(*args, **kwargs) + + try: + result = func(*args, **kwargs) + except ValidationError as error: + for on_error_handler in on_errors: + on_error_handler(error) + raise + except Exception as exception: + for on_exception_handler in on_exceptions: + on_exception_handler(exception) + raise + else: + for on_success_handler in on_successes: + on_success_handler(result) + return result + + return wrapper + + +def filter_handlers(handler_cls: BaseValidateHandlerProtocol, method_name: str) -> bool: + """Filter out handler methods which are not implemented by the plugin directly - e.g. are missing + or are inherited from the protocol. 
+ """ + handler = getattr(handler_cls, method_name, None) + if handler is None: + return False + elif handler.__module__ == 'pydantic.plugin': + # this is the original handler, from the protocol due to runtime inheritance + # we don't want to call it + return False + else: + return True diff --git a/lib/pydantic/root_model.py b/lib/pydantic/root_model.py new file mode 100644 index 00000000..42186b9d --- /dev/null +++ b/lib/pydantic/root_model.py @@ -0,0 +1,149 @@ +"""RootModel class and type definitions.""" + +from __future__ import annotations as _annotations + +import typing +from copy import copy, deepcopy + +from pydantic_core import PydanticUndefined + +from . import PydanticUserError +from ._internal import _model_construction, _repr +from .main import BaseModel, _object_setattr + +if typing.TYPE_CHECKING: + from typing import Any + + from typing_extensions import Literal, dataclass_transform + + from .fields import Field as PydanticModelField + + # dataclass_transform could be applied to RootModel directly, but `ModelMetaclass`'s dataclass_transform + # takes priority (at least with pyright). We trick type checkers into thinking we apply dataclass_transform + # on a new metaclass. + @dataclass_transform(kw_only_default=False, field_specifiers=(PydanticModelField,)) + class _RootModelMetaclass(_model_construction.ModelMetaclass): + ... + + Model = typing.TypeVar('Model', bound='BaseModel') +else: + _RootModelMetaclass = _model_construction.ModelMetaclass + +__all__ = ('RootModel',) + + +RootModelRootType = typing.TypeVar('RootModelRootType') + + +class RootModel(BaseModel, typing.Generic[RootModelRootType], metaclass=_RootModelMetaclass): + """Usage docs: https://docs.pydantic.dev/2.6/concepts/models/#rootmodel-and-custom-root-types + + A Pydantic `BaseModel` for the root object of the model. + + Attributes: + root: The root object of the model. + __pydantic_root_model__: Whether the model is a RootModel. + __pydantic_private__: Private fields in the model. + __pydantic_extra__: Extra fields in the model. + + """ + + __pydantic_root_model__ = True + __pydantic_private__ = None + __pydantic_extra__ = None + + root: RootModelRootType + + def __init_subclass__(cls, **kwargs): + extra = cls.model_config.get('extra') + if extra is not None: + raise PydanticUserError( + "`RootModel` does not support setting `model_config['extra']`", code='root-model-extra' + ) + super().__init_subclass__(**kwargs) + + def __init__(self, /, root: RootModelRootType = PydanticUndefined, **data) -> None: # type: ignore + __tracebackhide__ = True + if data: + if root is not PydanticUndefined: + raise ValueError( + '"RootModel.__init__" accepts either a single positional argument or arbitrary keyword arguments' + ) + root = data # type: ignore + self.__pydantic_validator__.validate_python(root, self_instance=self) + + __init__.__pydantic_base_init__ = True # pyright: ignore[reportFunctionMemberAccess] + + @classmethod + def model_construct(cls: type[Model], root: RootModelRootType, _fields_set: set[str] | None = None) -> Model: # type: ignore + """Create a new model using the provided root object and update fields set. + + Args: + root: The root object of the model. + _fields_set: The set of fields to be updated. + + Returns: + The new model. + + Raises: + NotImplemented: If the model is not a subclass of `RootModel`. 
+ """ + return super().model_construct(root=root, _fields_set=_fields_set) + + def __getstate__(self) -> dict[Any, Any]: + return { + '__dict__': self.__dict__, + '__pydantic_fields_set__': self.__pydantic_fields_set__, + } + + def __setstate__(self, state: dict[Any, Any]) -> None: + _object_setattr(self, '__pydantic_fields_set__', state['__pydantic_fields_set__']) + _object_setattr(self, '__dict__', state['__dict__']) + + def __copy__(self: Model) -> Model: + """Returns a shallow copy of the model.""" + cls = type(self) + m = cls.__new__(cls) + _object_setattr(m, '__dict__', copy(self.__dict__)) + _object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__)) + return m + + def __deepcopy__(self: Model, memo: dict[int, Any] | None = None) -> Model: + """Returns a deep copy of the model.""" + cls = type(self) + m = cls.__new__(cls) + _object_setattr(m, '__dict__', deepcopy(self.__dict__, memo=memo)) + # This next line doesn't need a deepcopy because __pydantic_fields_set__ is a set[str], + # and attempting a deepcopy would be marginally slower. + _object_setattr(m, '__pydantic_fields_set__', copy(self.__pydantic_fields_set__)) + return m + + if typing.TYPE_CHECKING: + + def model_dump( # type: ignore + self, + *, + mode: Literal['json', 'python'] | str = 'python', + include: Any = None, + exclude: Any = None, + by_alias: bool = False, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + round_trip: bool = False, + warnings: bool = True, + ) -> RootModelRootType: + """This method is included just to get a more accurate return type for type checkers. + It is included in this `if TYPE_CHECKING:` block since no override is actually necessary. + + See the documentation of `BaseModel.model_dump` for more details about the arguments. + """ + ... 
+ + def __eq__(self, other: Any) -> bool: + if not isinstance(other, RootModel): + return NotImplemented + return self.model_fields['root'].annotation == other.model_fields['root'].annotation and super().__eq__(other) + + def __repr_args__(self) -> _repr.ReprArgs: + yield 'root', self.root diff --git a/lib/pydantic/schema.py b/lib/pydantic/schema.py index e7af56f1..e290aed9 100644 --- a/lib/pydantic/schema.py +++ b/lib/pydantic/schema.py @@ -1,1153 +1,4 @@ -import re -import warnings -from collections import defaultdict -from datetime import date, datetime, time, timedelta -from decimal import Decimal -from enum import Enum -from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network -from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - ForwardRef, - FrozenSet, - Generic, - Iterable, - List, - Optional, - Pattern, - Sequence, - Set, - Tuple, - Type, - TypeVar, - Union, - cast, -) -from uuid import UUID +"""The `schema` module is a backport module from V1.""" +from ._migration import getattr_migration -from typing_extensions import Annotated, Literal - -from .fields import ( - MAPPING_LIKE_SHAPES, - SHAPE_DEQUE, - SHAPE_FROZENSET, - SHAPE_GENERIC, - SHAPE_ITERABLE, - SHAPE_LIST, - SHAPE_SEQUENCE, - SHAPE_SET, - SHAPE_SINGLETON, - SHAPE_TUPLE, - SHAPE_TUPLE_ELLIPSIS, - FieldInfo, - ModelField, -) -from .json import pydantic_encoder -from .networks import AnyUrl, EmailStr -from .types import ( - ConstrainedDecimal, - ConstrainedFloat, - ConstrainedFrozenSet, - ConstrainedInt, - ConstrainedList, - ConstrainedSet, - SecretBytes, - SecretStr, - StrictBytes, - StrictStr, - conbytes, - condecimal, - confloat, - confrozenset, - conint, - conlist, - conset, - constr, -) -from .typing import ( - all_literal_values, - get_args, - get_origin, - get_sub_types, - is_callable_type, - is_literal_type, - is_namedtuple, - is_none_type, - is_union, -) -from .utils import ROOT_KEY, get_model, lenient_issubclass - -if TYPE_CHECKING: - from .dataclasses import Dataclass - from .main import BaseModel - -default_prefix = '#/definitions/' -default_ref_template = '#/definitions/{model}' - -TypeModelOrEnum = Union[Type['BaseModel'], Type[Enum]] -TypeModelSet = Set[TypeModelOrEnum] - - -def _apply_modify_schema( - modify_schema: Callable[..., None], field: Optional[ModelField], field_schema: Dict[str, Any] -) -> None: - from inspect import signature - - sig = signature(modify_schema) - args = set(sig.parameters.keys()) - if 'field' in args or 'kwargs' in args: - modify_schema(field_schema, field=field) - else: - modify_schema(field_schema) - - -def schema( - models: Sequence[Union[Type['BaseModel'], Type['Dataclass']]], - *, - by_alias: bool = True, - title: Optional[str] = None, - description: Optional[str] = None, - ref_prefix: Optional[str] = None, - ref_template: str = default_ref_template, -) -> Dict[str, Any]: - """ - Process a list of models and generate a single JSON Schema with all of them defined in the ``definitions`` - top-level JSON key, including their sub-models. - - :param models: a list of models to include in the generated JSON Schema - :param by_alias: generate the schemas using the aliases defined, if any - :param title: title for the generated schema that includes the definitions - :param description: description for the generated schema - :param ref_prefix: the JSON Pointer prefix for schema references with ``$ref``, if None, will be set to the - default of ``#/definitions/``. 
Update it if you want the schemas to reference the definitions somewhere - else, e.g. for OpenAPI use ``#/components/schemas/``. The resulting generated schemas will still be at the - top-level key ``definitions``, so you can extract them from there. But all the references will have the set - prefix. - :param ref_template: Use a ``string.format()`` template for ``$ref`` instead of a prefix. This can be useful - for references that cannot be represented by ``ref_prefix`` such as a definition stored in another file. For - a sibling json file in a ``/schemas`` directory use ``"/schemas/${model}.json#"``. - :return: dict with the JSON Schema with a ``definitions`` top-level key including the schema definitions for - the models and sub-models passed in ``models``. - """ - clean_models = [get_model(model) for model in models] - flat_models = get_flat_models_from_models(clean_models) - model_name_map = get_model_name_map(flat_models) - definitions = {} - output_schema: Dict[str, Any] = {} - if title: - output_schema['title'] = title - if description: - output_schema['description'] = description - for model in clean_models: - m_schema, m_definitions, m_nested_models = model_process_schema( - model, - by_alias=by_alias, - model_name_map=model_name_map, - ref_prefix=ref_prefix, - ref_template=ref_template, - ) - definitions.update(m_definitions) - model_name = model_name_map[model] - definitions[model_name] = m_schema - if definitions: - output_schema['definitions'] = definitions - return output_schema - - -def model_schema( - model: Union[Type['BaseModel'], Type['Dataclass']], - by_alias: bool = True, - ref_prefix: Optional[str] = None, - ref_template: str = default_ref_template, -) -> Dict[str, Any]: - """ - Generate a JSON Schema for one model. With all the sub-models defined in the ``definitions`` top-level - JSON key. - - :param model: a Pydantic model (a class that inherits from BaseModel) - :param by_alias: generate the schemas using the aliases defined, if any - :param ref_prefix: the JSON Pointer prefix for schema references with ``$ref``, if None, will be set to the - default of ``#/definitions/``. Update it if you want the schemas to reference the definitions somewhere - else, e.g. for OpenAPI use ``#/components/schemas/``. The resulting generated schemas will still be at the - top-level key ``definitions``, so you can extract them from there. But all the references will have the set - prefix. - :param ref_template: Use a ``string.format()`` template for ``$ref`` instead of a prefix. This can be useful for - references that cannot be represented by ``ref_prefix`` such as a definition stored in another file. For a - sibling json file in a ``/schemas`` directory use ``"/schemas/${model}.json#"``. 
- :return: dict with the JSON Schema for the passed ``model`` - """ - model = get_model(model) - flat_models = get_flat_models_from_model(model) - model_name_map = get_model_name_map(flat_models) - model_name = model_name_map[model] - m_schema, m_definitions, nested_models = model_process_schema( - model, by_alias=by_alias, model_name_map=model_name_map, ref_prefix=ref_prefix, ref_template=ref_template - ) - if model_name in nested_models: - # model_name is in Nested models, it has circular references - m_definitions[model_name] = m_schema - m_schema = get_schema_ref(model_name, ref_prefix, ref_template, False) - if m_definitions: - m_schema.update({'definitions': m_definitions}) - return m_schema - - -def get_field_info_schema(field: ModelField, schema_overrides: bool = False) -> Tuple[Dict[str, Any], bool]: - - # If no title is explicitly set, we don't set title in the schema for enums. - # The behaviour is the same as `BaseModel` reference, where the default title - # is in the definitions part of the schema. - schema_: Dict[str, Any] = {} - if field.field_info.title or not lenient_issubclass(field.type_, Enum): - schema_['title'] = field.field_info.title or field.alias.title().replace('_', ' ') - - if field.field_info.title: - schema_overrides = True - - if field.field_info.description: - schema_['description'] = field.field_info.description - schema_overrides = True - - if not field.required and field.default is not None and not is_callable_type(field.outer_type_): - schema_['default'] = encode_default(field.default) - schema_overrides = True - - return schema_, schema_overrides - - -def field_schema( - field: ModelField, - *, - by_alias: bool = True, - model_name_map: Dict[TypeModelOrEnum, str], - ref_prefix: Optional[str] = None, - ref_template: str = default_ref_template, - known_models: TypeModelSet = None, -) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: - """ - Process a Pydantic field and return a tuple with a JSON Schema for it as the first item. - Also return a dictionary of definitions with models as keys and their schemas as values. If the passed field - is a model and has sub-models, and those sub-models don't have overrides (as ``title``, ``default``, etc), they - will be included in the definitions and referenced in the schema instead of included recursively. - - :param field: a Pydantic ``ModelField`` - :param by_alias: use the defined alias (if any) in the returned schema - :param model_name_map: used to generate the JSON Schema references to other models included in the definitions - :param ref_prefix: the JSON Pointer prefix to use for references to other schemas, if None, the default of - #/definitions/ will be used - :param ref_template: Use a ``string.format()`` template for ``$ref`` instead of a prefix. This can be useful for - references that cannot be represented by ``ref_prefix`` such as a definition stored in another file. For a - sibling json file in a ``/schemas`` directory use ``"/schemas/${model}.json#"``. 
- :param known_models: used to solve circular references - :return: tuple of the schema for this field and additional definitions - """ - s, schema_overrides = get_field_info_schema(field) - - validation_schema = get_field_schema_validations(field) - if validation_schema: - s.update(validation_schema) - schema_overrides = True - - f_schema, f_definitions, f_nested_models = field_type_schema( - field, - by_alias=by_alias, - model_name_map=model_name_map, - schema_overrides=schema_overrides, - ref_prefix=ref_prefix, - ref_template=ref_template, - known_models=known_models or set(), - ) - - # $ref will only be returned when there are no schema_overrides - if '$ref' in f_schema: - return f_schema, f_definitions, f_nested_models - else: - s.update(f_schema) - return s, f_definitions, f_nested_models - - -numeric_types = (int, float, Decimal) -_str_types_attrs: Tuple[Tuple[str, Union[type, Tuple[type, ...]], str], ...] = ( - ('max_length', numeric_types, 'maxLength'), - ('min_length', numeric_types, 'minLength'), - ('regex', str, 'pattern'), -) - -_numeric_types_attrs: Tuple[Tuple[str, Union[type, Tuple[type, ...]], str], ...] = ( - ('gt', numeric_types, 'exclusiveMinimum'), - ('lt', numeric_types, 'exclusiveMaximum'), - ('ge', numeric_types, 'minimum'), - ('le', numeric_types, 'maximum'), - ('multiple_of', numeric_types, 'multipleOf'), -) - - -def get_field_schema_validations(field: ModelField) -> Dict[str, Any]: - """ - Get the JSON Schema validation keywords for a ``field`` with an annotation of - a Pydantic ``FieldInfo`` with validation arguments. - """ - f_schema: Dict[str, Any] = {} - - if lenient_issubclass(field.type_, Enum): - # schema is already updated by `enum_process_schema`; just update with field extra - if field.field_info.extra: - f_schema.update(field.field_info.extra) - return f_schema - - if lenient_issubclass(field.type_, (str, bytes)): - for attr_name, t, keyword in _str_types_attrs: - attr = getattr(field.field_info, attr_name, None) - if isinstance(attr, t): - f_schema[keyword] = attr - if lenient_issubclass(field.type_, numeric_types) and not issubclass(field.type_, bool): - for attr_name, t, keyword in _numeric_types_attrs: - attr = getattr(field.field_info, attr_name, None) - if isinstance(attr, t): - f_schema[keyword] = attr - if field.field_info is not None and field.field_info.const: - f_schema['const'] = field.default - if field.field_info.extra: - f_schema.update(field.field_info.extra) - modify_schema = getattr(field.outer_type_, '__modify_schema__', None) - if modify_schema: - _apply_modify_schema(modify_schema, field, f_schema) - return f_schema - - -def get_model_name_map(unique_models: TypeModelSet) -> Dict[TypeModelOrEnum, str]: - """ - Process a set of models and generate unique names for them to be used as keys in the JSON Schema - definitions. By default the names are the same as the class name. But if two models in different Python - modules have the same name (e.g. "users.Model" and "items.Model"), the generated names will be - based on the Python module path for those conflicting models to prevent name collisions. 
- - :param unique_models: a Python set of models - :return: dict mapping models to names - """ - name_model_map = {} - conflicting_names: Set[str] = set() - for model in unique_models: - model_name = normalize_name(model.__name__) - if model_name in conflicting_names: - model_name = get_long_model_name(model) - name_model_map[model_name] = model - elif model_name in name_model_map: - conflicting_names.add(model_name) - conflicting_model = name_model_map.pop(model_name) - name_model_map[get_long_model_name(conflicting_model)] = conflicting_model - name_model_map[get_long_model_name(model)] = model - else: - name_model_map[model_name] = model - return {v: k for k, v in name_model_map.items()} - - -def get_flat_models_from_model(model: Type['BaseModel'], known_models: TypeModelSet = None) -> TypeModelSet: - """ - Take a single ``model`` and generate a set with itself and all the sub-models in the tree. I.e. if you pass - model ``Foo`` (subclass of Pydantic ``BaseModel``) as ``model``, and it has a field of type ``Bar`` (also - subclass of ``BaseModel``) and that model ``Bar`` has a field of type ``Baz`` (also subclass of ``BaseModel``), - the return value will be ``set([Foo, Bar, Baz])``. - - :param model: a Pydantic ``BaseModel`` subclass - :param known_models: used to solve circular references - :return: a set with the initial model and all its sub-models - """ - known_models = known_models or set() - flat_models: TypeModelSet = set() - flat_models.add(model) - known_models |= flat_models - fields = cast(Sequence[ModelField], model.__fields__.values()) - flat_models |= get_flat_models_from_fields(fields, known_models=known_models) - return flat_models - - -def get_flat_models_from_field(field: ModelField, known_models: TypeModelSet) -> TypeModelSet: - """ - Take a single Pydantic ``ModelField`` (from a model) that could have been declared as a sublcass of BaseModel - (so, it could be a submodel), and generate a set with its model and all the sub-models in the tree. - I.e. if you pass a field that was declared to be of type ``Foo`` (subclass of BaseModel) as ``field``, and that - model ``Foo`` has a field of type ``Bar`` (also subclass of ``BaseModel``) and that model ``Bar`` has a field of - type ``Baz`` (also subclass of ``BaseModel``), the return value will be ``set([Foo, Bar, Baz])``. - - :param field: a Pydantic ``ModelField`` - :param known_models: used to solve circular references - :return: a set with the model used in the declaration for this field, if any, and all its sub-models - """ - from .main import BaseModel - - flat_models: TypeModelSet = set() - - field_type = field.type_ - if lenient_issubclass(getattr(field_type, '__pydantic_model__', None), BaseModel): - field_type = field_type.__pydantic_model__ - - if field.sub_fields and not lenient_issubclass(field_type, BaseModel): - flat_models |= get_flat_models_from_fields(field.sub_fields, known_models=known_models) - elif lenient_issubclass(field_type, BaseModel) and field_type not in known_models: - flat_models |= get_flat_models_from_model(field_type, known_models=known_models) - elif lenient_issubclass(field_type, Enum): - flat_models.add(field_type) - return flat_models - - -def get_flat_models_from_fields(fields: Sequence[ModelField], known_models: TypeModelSet) -> TypeModelSet: - """ - Take a list of Pydantic ``ModelField``s (from a model) that could have been declared as subclasses of ``BaseModel`` - (so, any of them could be a submodel), and generate a set with their models and all the sub-models in the tree. - I.e. 
if you pass a the fields of a model ``Foo`` (subclass of ``BaseModel``) as ``fields``, and on of them has a - field of type ``Bar`` (also subclass of ``BaseModel``) and that model ``Bar`` has a field of type ``Baz`` (also - subclass of ``BaseModel``), the return value will be ``set([Foo, Bar, Baz])``. - - :param fields: a list of Pydantic ``ModelField``s - :param known_models: used to solve circular references - :return: a set with any model declared in the fields, and all their sub-models - """ - flat_models: TypeModelSet = set() - for field in fields: - flat_models |= get_flat_models_from_field(field, known_models=known_models) - return flat_models - - -def get_flat_models_from_models(models: Sequence[Type['BaseModel']]) -> TypeModelSet: - """ - Take a list of ``models`` and generate a set with them and all their sub-models in their trees. I.e. if you pass - a list of two models, ``Foo`` and ``Bar``, both subclasses of Pydantic ``BaseModel`` as models, and ``Bar`` has - a field of type ``Baz`` (also subclass of ``BaseModel``), the return value will be ``set([Foo, Bar, Baz])``. - """ - flat_models: TypeModelSet = set() - for model in models: - flat_models |= get_flat_models_from_model(model) - return flat_models - - -def get_long_model_name(model: TypeModelOrEnum) -> str: - return f'{model.__module__}__{model.__qualname__}'.replace('.', '__') - - -def field_type_schema( - field: ModelField, - *, - by_alias: bool, - model_name_map: Dict[TypeModelOrEnum, str], - ref_template: str, - schema_overrides: bool = False, - ref_prefix: Optional[str] = None, - known_models: TypeModelSet, -) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: - """ - Used by ``field_schema()``, you probably should be using that function. - - Take a single ``field`` and generate the schema for its type only, not including additional - information as title, etc. Also return additional schema definitions, from sub-models. 
- """ - from .main import BaseModel # noqa: F811 - - definitions = {} - nested_models: Set[str] = set() - f_schema: Dict[str, Any] - if field.shape in { - SHAPE_LIST, - SHAPE_TUPLE_ELLIPSIS, - SHAPE_SEQUENCE, - SHAPE_SET, - SHAPE_FROZENSET, - SHAPE_ITERABLE, - SHAPE_DEQUE, - }: - items_schema, f_definitions, f_nested_models = field_singleton_schema( - field, - by_alias=by_alias, - model_name_map=model_name_map, - ref_prefix=ref_prefix, - ref_template=ref_template, - known_models=known_models, - ) - definitions.update(f_definitions) - nested_models.update(f_nested_models) - f_schema = {'type': 'array', 'items': items_schema} - if field.shape in {SHAPE_SET, SHAPE_FROZENSET}: - f_schema['uniqueItems'] = True - - elif field.shape in MAPPING_LIKE_SHAPES: - f_schema = {'type': 'object'} - key_field = cast(ModelField, field.key_field) - regex = getattr(key_field.type_, 'regex', None) - items_schema, f_definitions, f_nested_models = field_singleton_schema( - field, - by_alias=by_alias, - model_name_map=model_name_map, - ref_prefix=ref_prefix, - ref_template=ref_template, - known_models=known_models, - ) - definitions.update(f_definitions) - nested_models.update(f_nested_models) - if regex: - # Dict keys have a regex pattern - # items_schema might be a schema or empty dict, add it either way - f_schema['patternProperties'] = {regex.pattern: items_schema} - elif items_schema: - # The dict values are not simply Any, so they need a schema - f_schema['additionalProperties'] = items_schema - elif field.shape == SHAPE_TUPLE or (field.shape == SHAPE_GENERIC and not issubclass(field.type_, BaseModel)): - sub_schema = [] - sub_fields = cast(List[ModelField], field.sub_fields) - for sf in sub_fields: - sf_schema, sf_definitions, sf_nested_models = field_type_schema( - sf, - by_alias=by_alias, - model_name_map=model_name_map, - ref_prefix=ref_prefix, - ref_template=ref_template, - known_models=known_models, - ) - definitions.update(sf_definitions) - nested_models.update(sf_nested_models) - sub_schema.append(sf_schema) - - sub_fields_len = len(sub_fields) - if field.shape == SHAPE_GENERIC: - all_of_schemas = sub_schema[0] if sub_fields_len == 1 else {'type': 'array', 'items': sub_schema} - f_schema = {'allOf': [all_of_schemas]} - else: - f_schema = { - 'type': 'array', - 'minItems': sub_fields_len, - 'maxItems': sub_fields_len, - } - if sub_fields_len >= 1: - f_schema['items'] = sub_schema - else: - assert field.shape in {SHAPE_SINGLETON, SHAPE_GENERIC}, field.shape - f_schema, f_definitions, f_nested_models = field_singleton_schema( - field, - by_alias=by_alias, - model_name_map=model_name_map, - schema_overrides=schema_overrides, - ref_prefix=ref_prefix, - ref_template=ref_template, - known_models=known_models, - ) - definitions.update(f_definitions) - nested_models.update(f_nested_models) - - # check field type to avoid repeated calls to the same __modify_schema__ method - if field.type_ != field.outer_type_: - if field.shape == SHAPE_GENERIC: - field_type = field.type_ - else: - field_type = field.outer_type_ - modify_schema = getattr(field_type, '__modify_schema__', None) - if modify_schema: - _apply_modify_schema(modify_schema, field, f_schema) - return f_schema, definitions, nested_models - - -def model_process_schema( - model: TypeModelOrEnum, - *, - by_alias: bool = True, - model_name_map: Dict[TypeModelOrEnum, str], - ref_prefix: Optional[str] = None, - ref_template: str = default_ref_template, - known_models: TypeModelSet = None, - field: Optional[ModelField] = None, -) -> Tuple[Dict[str, Any], 
Dict[str, Any], Set[str]]: - """ - Used by ``model_schema()``, you probably should be using that function. - - Take a single ``model`` and generate its schema. Also return additional schema definitions, from sub-models. The - sub-models of the returned schema will be referenced, but their definitions will not be included in the schema. All - the definitions are returned as the second value. - """ - from inspect import getdoc, signature - - known_models = known_models or set() - if lenient_issubclass(model, Enum): - model = cast(Type[Enum], model) - s = enum_process_schema(model, field=field) - return s, {}, set() - model = cast(Type['BaseModel'], model) - s = {'title': model.__config__.title or model.__name__} - doc = getdoc(model) - if doc: - s['description'] = doc - known_models.add(model) - m_schema, m_definitions, nested_models = model_type_schema( - model, - by_alias=by_alias, - model_name_map=model_name_map, - ref_prefix=ref_prefix, - ref_template=ref_template, - known_models=known_models, - ) - s.update(m_schema) - schema_extra = model.__config__.schema_extra - if callable(schema_extra): - if len(signature(schema_extra).parameters) == 1: - schema_extra(s) - else: - schema_extra(s, model) - else: - s.update(schema_extra) - return s, m_definitions, nested_models - - -def model_type_schema( - model: Type['BaseModel'], - *, - by_alias: bool, - model_name_map: Dict[TypeModelOrEnum, str], - ref_template: str, - ref_prefix: Optional[str] = None, - known_models: TypeModelSet, -) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: - """ - You probably should be using ``model_schema()``, this function is indirectly used by that function. - - Take a single ``model`` and generate the schema for its type only, not including additional - information as title, etc. Also return additional schema definitions, from sub-models. - """ - properties = {} - required = [] - definitions: Dict[str, Any] = {} - nested_models: Set[str] = set() - for k, f in model.__fields__.items(): - try: - f_schema, f_definitions, f_nested_models = field_schema( - f, - by_alias=by_alias, - model_name_map=model_name_map, - ref_prefix=ref_prefix, - ref_template=ref_template, - known_models=known_models, - ) - except SkipField as skip: - warnings.warn(skip.message, UserWarning) - continue - definitions.update(f_definitions) - nested_models.update(f_nested_models) - if by_alias: - properties[f.alias] = f_schema - if f.required: - required.append(f.alias) - else: - properties[k] = f_schema - if f.required: - required.append(k) - if ROOT_KEY in properties: - out_schema = properties[ROOT_KEY] - out_schema['title'] = model.__config__.title or model.__name__ - else: - out_schema = {'type': 'object', 'properties': properties} - if required: - out_schema['required'] = required - if model.__config__.extra == 'forbid': - out_schema['additionalProperties'] = False - return out_schema, definitions, nested_models - - -def enum_process_schema(enum: Type[Enum], *, field: Optional[ModelField] = None) -> Dict[str, Any]: - """ - Take a single `enum` and generate its schema. - - This is similar to the `model_process_schema` function, but applies to ``Enum`` objects. - """ - schema_: Dict[str, Any] = { - 'title': enum.__name__, - # Python assigns all enums a default docstring value of 'An enumeration', so - # all enums will have a description field even if not explicitly provided. - 'description': enum.__doc__ or 'An enumeration.', - # Add enum values and the enum field type to the schema. 
- 'enum': [item.value for item in cast(Iterable[Enum], enum)], - } - - add_field_type_to_schema(enum, schema_) - - modify_schema = getattr(enum, '__modify_schema__', None) - if modify_schema: - _apply_modify_schema(modify_schema, field, schema_) - - return schema_ - - -def field_singleton_sub_fields_schema( - field: ModelField, - *, - by_alias: bool, - model_name_map: Dict[TypeModelOrEnum, str], - ref_template: str, - schema_overrides: bool = False, - ref_prefix: Optional[str] = None, - known_models: TypeModelSet, -) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: - """ - This function is indirectly used by ``field_schema()``, you probably should be using that function. - - Take a list of Pydantic ``ModelField`` from the declaration of a type with parameters, and generate their - schema. I.e., fields used as "type parameters", like ``str`` and ``int`` in ``Tuple[str, int]``. - """ - sub_fields = cast(List[ModelField], field.sub_fields) - definitions = {} - nested_models: Set[str] = set() - if len(sub_fields) == 1: - return field_type_schema( - sub_fields[0], - by_alias=by_alias, - model_name_map=model_name_map, - schema_overrides=schema_overrides, - ref_prefix=ref_prefix, - ref_template=ref_template, - known_models=known_models, - ) - else: - s: Dict[str, Any] = {} - # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#discriminator-object - field_has_discriminator: bool = field.discriminator_key is not None - if field_has_discriminator: - assert field.sub_fields_mapping is not None - - discriminator_models_refs: Dict[str, Union[str, Dict[str, Any]]] = {} - - for discriminator_value, sub_field in field.sub_fields_mapping.items(): - # sub_field is either a `BaseModel` or directly an `Annotated` `Union` of many - if is_union(get_origin(sub_field.type_)): - sub_models = get_sub_types(sub_field.type_) - discriminator_models_refs[discriminator_value] = { - model_name_map[sub_model]: get_schema_ref( - model_name_map[sub_model], ref_prefix, ref_template, False - ) - for sub_model in sub_models - } - else: - sub_field_type = sub_field.type_ - if hasattr(sub_field_type, '__pydantic_model__'): - sub_field_type = sub_field_type.__pydantic_model__ - - discriminator_model_name = model_name_map[sub_field_type] - discriminator_model_ref = get_schema_ref(discriminator_model_name, ref_prefix, ref_template, False) - discriminator_models_refs[discriminator_value] = discriminator_model_ref['$ref'] - - s['discriminator'] = { - 'propertyName': field.discriminator_alias, - 'mapping': discriminator_models_refs, - } - - sub_field_schemas = [] - for sf in sub_fields: - sub_schema, sub_definitions, sub_nested_models = field_type_schema( - sf, - by_alias=by_alias, - model_name_map=model_name_map, - schema_overrides=schema_overrides, - ref_prefix=ref_prefix, - ref_template=ref_template, - known_models=known_models, - ) - definitions.update(sub_definitions) - if schema_overrides and 'allOf' in sub_schema: - # if the sub_field is a referenced schema we only need the referenced - # object. Otherwise we will end up with several allOf inside anyOf/oneOf. 
- # See https://github.com/pydantic/pydantic/issues/1209 - sub_schema = sub_schema['allOf'][0] - - if sub_schema.keys() == {'discriminator', 'oneOf'}: - # we don't want discriminator information inside oneOf choices, this is dealt with elsewhere - sub_schema.pop('discriminator') - sub_field_schemas.append(sub_schema) - nested_models.update(sub_nested_models) - s['oneOf' if field_has_discriminator else 'anyOf'] = sub_field_schemas - return s, definitions, nested_models - - -# Order is important, e.g. subclasses of str must go before str -# this is used only for standard library types, custom types should use __modify_schema__ instead -field_class_to_schema: Tuple[Tuple[Any, Dict[str, Any]], ...] = ( - (Path, {'type': 'string', 'format': 'path'}), - (datetime, {'type': 'string', 'format': 'date-time'}), - (date, {'type': 'string', 'format': 'date'}), - (time, {'type': 'string', 'format': 'time'}), - (timedelta, {'type': 'number', 'format': 'time-delta'}), - (IPv4Network, {'type': 'string', 'format': 'ipv4network'}), - (IPv6Network, {'type': 'string', 'format': 'ipv6network'}), - (IPv4Interface, {'type': 'string', 'format': 'ipv4interface'}), - (IPv6Interface, {'type': 'string', 'format': 'ipv6interface'}), - (IPv4Address, {'type': 'string', 'format': 'ipv4'}), - (IPv6Address, {'type': 'string', 'format': 'ipv6'}), - (Pattern, {'type': 'string', 'format': 'regex'}), - (str, {'type': 'string'}), - (bytes, {'type': 'string', 'format': 'binary'}), - (bool, {'type': 'boolean'}), - (int, {'type': 'integer'}), - (float, {'type': 'number'}), - (Decimal, {'type': 'number'}), - (UUID, {'type': 'string', 'format': 'uuid'}), - (dict, {'type': 'object'}), - (list, {'type': 'array', 'items': {}}), - (tuple, {'type': 'array', 'items': {}}), - (set, {'type': 'array', 'items': {}, 'uniqueItems': True}), - (frozenset, {'type': 'array', 'items': {}, 'uniqueItems': True}), -) - -json_scheme = {'type': 'string', 'format': 'json-string'} - - -def add_field_type_to_schema(field_type: Any, schema_: Dict[str, Any]) -> None: - """ - Update the given `schema` with the type-specific metadata for the given `field_type`. - - This function looks through `field_class_to_schema` for a class that matches the given `field_type`, - and then modifies the given `schema` with the information from that type. - """ - for type_, t_schema in field_class_to_schema: - # Fallback for `typing.Pattern` and `re.Pattern` as they are not a valid class - if lenient_issubclass(field_type, type_) or field_type is type_ is Pattern: - schema_.update(t_schema) - break - - -def get_schema_ref(name: str, ref_prefix: Optional[str], ref_template: str, schema_overrides: bool) -> Dict[str, Any]: - if ref_prefix: - schema_ref = {'$ref': ref_prefix + name} - else: - schema_ref = {'$ref': ref_template.format(model=name)} - return {'allOf': [schema_ref]} if schema_overrides else schema_ref - - -def field_singleton_schema( # noqa: C901 (ignore complexity) - field: ModelField, - *, - by_alias: bool, - model_name_map: Dict[TypeModelOrEnum, str], - ref_template: str, - schema_overrides: bool = False, - ref_prefix: Optional[str] = None, - known_models: TypeModelSet, -) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: - """ - This function is indirectly used by ``field_schema()``, you should probably be using that function. - - Take a single Pydantic ``ModelField``, and return its schema and any additional definitions from sub-models. 
- """ - from .main import BaseModel - - definitions: Dict[str, Any] = {} - nested_models: Set[str] = set() - field_type = field.type_ - - # Recurse into this field if it contains sub_fields and is NOT a - # BaseModel OR that BaseModel is a const - if field.sub_fields and ( - (field.field_info and field.field_info.const) or not lenient_issubclass(field_type, BaseModel) - ): - return field_singleton_sub_fields_schema( - field, - by_alias=by_alias, - model_name_map=model_name_map, - schema_overrides=schema_overrides, - ref_prefix=ref_prefix, - ref_template=ref_template, - known_models=known_models, - ) - if field_type is Any or field_type is object or field_type.__class__ == TypeVar or get_origin(field_type) is type: - return {}, definitions, nested_models # no restrictions - if is_none_type(field_type): - return {'type': 'null'}, definitions, nested_models - if is_callable_type(field_type): - raise SkipField(f'Callable {field.name} was excluded from schema since JSON schema has no equivalent type.') - f_schema: Dict[str, Any] = {} - if field.field_info is not None and field.field_info.const: - f_schema['const'] = field.default - - if is_literal_type(field_type): - values = all_literal_values(field_type) - - if len({v.__class__ for v in values}) > 1: - return field_schema( - multitypes_literal_field_for_schema(values, field), - by_alias=by_alias, - model_name_map=model_name_map, - ref_prefix=ref_prefix, - ref_template=ref_template, - known_models=known_models, - ) - - # All values have the same type - field_type = values[0].__class__ - f_schema['enum'] = list(values) - add_field_type_to_schema(field_type, f_schema) - elif lenient_issubclass(field_type, Enum): - enum_name = model_name_map[field_type] - f_schema, schema_overrides = get_field_info_schema(field, schema_overrides) - f_schema.update(get_schema_ref(enum_name, ref_prefix, ref_template, schema_overrides)) - definitions[enum_name] = enum_process_schema(field_type, field=field) - elif is_namedtuple(field_type): - sub_schema, *_ = model_process_schema( - field_type.__pydantic_model__, - by_alias=by_alias, - model_name_map=model_name_map, - ref_prefix=ref_prefix, - ref_template=ref_template, - known_models=known_models, - field=field, - ) - items_schemas = list(sub_schema['properties'].values()) - f_schema.update( - { - 'type': 'array', - 'items': items_schemas, - 'minItems': len(items_schemas), - 'maxItems': len(items_schemas), - } - ) - elif not hasattr(field_type, '__pydantic_model__'): - add_field_type_to_schema(field_type, f_schema) - - modify_schema = getattr(field_type, '__modify_schema__', None) - if modify_schema: - _apply_modify_schema(modify_schema, field, f_schema) - - if f_schema: - return f_schema, definitions, nested_models - - # Handle dataclass-based models - if lenient_issubclass(getattr(field_type, '__pydantic_model__', None), BaseModel): - field_type = field_type.__pydantic_model__ - - if issubclass(field_type, BaseModel): - model_name = model_name_map[field_type] - if field_type not in known_models: - sub_schema, sub_definitions, sub_nested_models = model_process_schema( - field_type, - by_alias=by_alias, - model_name_map=model_name_map, - ref_prefix=ref_prefix, - ref_template=ref_template, - known_models=known_models, - field=field, - ) - definitions.update(sub_definitions) - definitions[model_name] = sub_schema - nested_models.update(sub_nested_models) - else: - nested_models.add(model_name) - schema_ref = get_schema_ref(model_name, ref_prefix, ref_template, schema_overrides) - return schema_ref, definitions, 
nested_models - - # For generics with no args - args = get_args(field_type) - if args is not None and not args and Generic in field_type.__bases__: - return f_schema, definitions, nested_models - - raise ValueError(f'Value not declarable with JSON Schema, field: {field}') - - -def multitypes_literal_field_for_schema(values: Tuple[Any, ...], field: ModelField) -> ModelField: - """ - To support `Literal` with values of different types, we split it into multiple `Literal` with same type - e.g. `Literal['qwe', 'asd', 1, 2]` becomes `Union[Literal['qwe', 'asd'], Literal[1, 2]]` - """ - literal_distinct_types = defaultdict(list) - for v in values: - literal_distinct_types[v.__class__].append(v) - distinct_literals = (Literal[tuple(same_type_values)] for same_type_values in literal_distinct_types.values()) - - return ModelField( - name=field.name, - type_=Union[tuple(distinct_literals)], # type: ignore - class_validators=field.class_validators, - model_config=field.model_config, - default=field.default, - required=field.required, - alias=field.alias, - field_info=field.field_info, - ) - - -def encode_default(dft: Any) -> Any: - if isinstance(dft, Enum): - return dft.value - elif isinstance(dft, (int, float, str)): - return dft - elif isinstance(dft, (list, tuple)): - t = dft.__class__ - seq_args = (encode_default(v) for v in dft) - return t(*seq_args) if is_namedtuple(t) else t(seq_args) - elif isinstance(dft, dict): - return {encode_default(k): encode_default(v) for k, v in dft.items()} - elif dft is None: - return None - else: - return pydantic_encoder(dft) - - -_map_types_constraint: Dict[Any, Callable[..., type]] = {int: conint, float: confloat, Decimal: condecimal} - - -def get_annotation_from_field_info( - annotation: Any, field_info: FieldInfo, field_name: str, validate_assignment: bool = False -) -> Type[Any]: - """ - Get an annotation with validation implemented for numbers and strings based on the field_info. - :param annotation: an annotation from a field specification, as ``str``, ``ConstrainedStr`` - :param field_info: an instance of FieldInfo, possibly with declarations for validations and JSON Schema - :param field_name: name of the field for use in error messages - :param validate_assignment: default False, flag for BaseModel Config value of validate_assignment - :return: the same ``annotation`` if unmodified or a new annotation with validation in place - """ - constraints = field_info.get_constraints() - used_constraints: Set[str] = set() - if constraints: - annotation, used_constraints = get_annotation_with_constraints(annotation, field_info) - if validate_assignment: - used_constraints.add('allow_mutation') - - unused_constraints = constraints - used_constraints - if unused_constraints: - raise ValueError( - f'On field "{field_name}" the following field constraints are set but not enforced: ' - f'{", ".join(unused_constraints)}. ' - f'\nFor more details see https://pydantic-docs.helpmanual.io/usage/schema/#unenforced-field-constraints' - ) - - return annotation - - -def get_annotation_with_constraints(annotation: Any, field_info: FieldInfo) -> Tuple[Type[Any], Set[str]]: # noqa: C901 - """ - Get an annotation with used constraints implemented for numbers and strings based on the field_info. 
- - :param annotation: an annotation from a field specification, as ``str``, ``ConstrainedStr`` - :param field_info: an instance of FieldInfo, possibly with declarations for validations and JSON Schema - :return: the same ``annotation`` if unmodified or a new annotation along with the used constraints. - """ - used_constraints: Set[str] = set() - - def go(type_: Any) -> Type[Any]: - if ( - is_literal_type(type_) - or isinstance(type_, ForwardRef) - or lenient_issubclass(type_, (ConstrainedList, ConstrainedSet, ConstrainedFrozenSet)) - ): - return type_ - origin = get_origin(type_) - if origin is not None: - args: Tuple[Any, ...] = get_args(type_) - if any(isinstance(a, ForwardRef) for a in args): - # forward refs cause infinite recursion below - return type_ - - if origin is Annotated: - return go(args[0]) - if is_union(origin): - return Union[tuple(go(a) for a in args)] # type: ignore - - if issubclass(origin, List) and ( - field_info.min_items is not None - or field_info.max_items is not None - or field_info.unique_items is not None - ): - used_constraints.update({'min_items', 'max_items', 'unique_items'}) - return conlist( - go(args[0]), - min_items=field_info.min_items, - max_items=field_info.max_items, - unique_items=field_info.unique_items, - ) - - if issubclass(origin, Set) and (field_info.min_items is not None or field_info.max_items is not None): - used_constraints.update({'min_items', 'max_items'}) - return conset(go(args[0]), min_items=field_info.min_items, max_items=field_info.max_items) - - if issubclass(origin, FrozenSet) and (field_info.min_items is not None or field_info.max_items is not None): - used_constraints.update({'min_items', 'max_items'}) - return confrozenset(go(args[0]), min_items=field_info.min_items, max_items=field_info.max_items) - - for t in (Tuple, List, Set, FrozenSet, Sequence): - if issubclass(origin, t): # type: ignore - return t[tuple(go(a) for a in args)] # type: ignore - - if issubclass(origin, Dict): - return Dict[args[0], go(args[1])] # type: ignore - - attrs: Optional[Tuple[str, ...]] = None - constraint_func: Optional[Callable[..., type]] = None - if isinstance(type_, type): - if issubclass(type_, (SecretStr, SecretBytes)): - attrs = ('max_length', 'min_length') - - def constraint_func(**kw: Any) -> Type[Any]: - return type(type_.__name__, (type_,), kw) - - elif issubclass(type_, str) and not issubclass(type_, (EmailStr, AnyUrl)): - attrs = ('max_length', 'min_length', 'regex') - if issubclass(type_, StrictStr): - - def constraint_func(**kw: Any) -> Type[Any]: - return type(type_.__name__, (type_,), kw) - - else: - constraint_func = constr - elif issubclass(type_, bytes): - attrs = ('max_length', 'min_length', 'regex') - if issubclass(type_, StrictBytes): - - def constraint_func(**kw: Any) -> Type[Any]: - return type(type_.__name__, (type_,), kw) - - else: - constraint_func = conbytes - elif issubclass(type_, numeric_types) and not issubclass( - type_, - ( - ConstrainedInt, - ConstrainedFloat, - ConstrainedDecimal, - ConstrainedList, - ConstrainedSet, - ConstrainedFrozenSet, - bool, - ), - ): - # Is numeric type - attrs = ('gt', 'lt', 'ge', 'le', 'multiple_of') - if issubclass(type_, float): - attrs += ('allow_inf_nan',) - if issubclass(type_, Decimal): - attrs += ('max_digits', 'decimal_places') - numeric_type = next(t for t in numeric_types if issubclass(type_, t)) # pragma: no branch - constraint_func = _map_types_constraint[numeric_type] - - if attrs: - used_constraints.update(set(attrs)) - kwargs = { - attr_name: attr - for attr_name, attr 
in ((attr_name, getattr(field_info, attr_name)) for attr_name in attrs) - if attr is not None - } - if kwargs: - constraint_func = cast(Callable[..., type], constraint_func) - return constraint_func(**kwargs) - return type_ - - return go(annotation), used_constraints - - -def normalize_name(name: str) -> str: - """ - Normalizes the given name. This can be applied to either a model *or* enum. - """ - return re.sub(r'[^a-zA-Z0-9.\-_]', '_', name) - - -class SkipField(Exception): - """ - Utility exception used to exclude fields from schema. - """ - - def __init__(self, message: str) -> None: - self.message = message +__getattr__ = getattr_migration(__name__) diff --git a/lib/pydantic/tools.py b/lib/pydantic/tools.py index 9cdb4538..8e317c92 100644 --- a/lib/pydantic/tools.py +++ b/lib/pydantic/tools.py @@ -1,92 +1,4 @@ -import json -from functools import lru_cache -from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar, Union +"""The `tools` module is a backport module from V1.""" +from ._migration import getattr_migration -from .parse import Protocol, load_file, load_str_bytes -from .types import StrBytes -from .typing import display_as_type - -__all__ = ('parse_file_as', 'parse_obj_as', 'parse_raw_as', 'schema_of', 'schema_json_of') - -NameFactory = Union[str, Callable[[Type[Any]], str]] - -if TYPE_CHECKING: - from .typing import DictStrAny - - -def _generate_parsing_type_name(type_: Any) -> str: - return f'ParsingModel[{display_as_type(type_)}]' - - -@lru_cache(maxsize=2048) -def _get_parsing_type(type_: Any, *, type_name: Optional[NameFactory] = None) -> Any: - from pydantic.main import create_model - - if type_name is None: - type_name = _generate_parsing_type_name - if not isinstance(type_name, str): - type_name = type_name(type_) - return create_model(type_name, __root__=(type_, ...)) - - -T = TypeVar('T') - - -def parse_obj_as(type_: Type[T], obj: Any, *, type_name: Optional[NameFactory] = None) -> T: - model_type = _get_parsing_type(type_, type_name=type_name) # type: ignore[arg-type] - return model_type(__root__=obj).__root__ - - -def parse_file_as( - type_: Type[T], - path: Union[str, Path], - *, - content_type: str = None, - encoding: str = 'utf8', - proto: Protocol = None, - allow_pickle: bool = False, - json_loads: Callable[[str], Any] = json.loads, - type_name: Optional[NameFactory] = None, -) -> T: - obj = load_file( - path, - proto=proto, - content_type=content_type, - encoding=encoding, - allow_pickle=allow_pickle, - json_loads=json_loads, - ) - return parse_obj_as(type_, obj, type_name=type_name) - - -def parse_raw_as( - type_: Type[T], - b: StrBytes, - *, - content_type: str = None, - encoding: str = 'utf8', - proto: Protocol = None, - allow_pickle: bool = False, - json_loads: Callable[[str], Any] = json.loads, - type_name: Optional[NameFactory] = None, -) -> T: - obj = load_str_bytes( - b, - proto=proto, - content_type=content_type, - encoding=encoding, - allow_pickle=allow_pickle, - json_loads=json_loads, - ) - return parse_obj_as(type_, obj, type_name=type_name) - - -def schema_of(type_: Any, *, title: Optional[NameFactory] = None, **schema_kwargs: Any) -> 'DictStrAny': - """Generate a JSON schema (as dict) for the passed model or dynamically generated one""" - return _get_parsing_type(type_, type_name=title).schema(**schema_kwargs) - - -def schema_json_of(type_: Any, *, title: Optional[NameFactory] = None, **schema_json_kwargs: Any) -> str: - """Generate a JSON schema (as JSON) for the passed model or dynamically generated 
one""" - return _get_parsing_type(type_, type_name=title).schema_json(**schema_json_kwargs) +__getattr__ = getattr_migration(__name__) diff --git a/lib/pydantic/type_adapter.py b/lib/pydantic/type_adapter.py new file mode 100644 index 00000000..366262fe --- /dev/null +++ b/lib/pydantic/type_adapter.py @@ -0,0 +1,460 @@ +"""Type adapter specification.""" +from __future__ import annotations as _annotations + +import sys +from dataclasses import is_dataclass +from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Set, TypeVar, Union, cast, final, overload + +from pydantic_core import CoreSchema, SchemaSerializer, SchemaValidator, Some +from typing_extensions import Literal, get_args, is_typeddict + +from pydantic.errors import PydanticUserError +from pydantic.main import BaseModel + +from ._internal import _config, _generate_schema, _typing_extra +from .config import ConfigDict +from .json_schema import ( + DEFAULT_REF_TEMPLATE, + GenerateJsonSchema, + JsonSchemaKeyT, + JsonSchemaMode, + JsonSchemaValue, +) +from .plugin._schema_validator import create_schema_validator + +T = TypeVar('T') + + +if TYPE_CHECKING: + # should be `set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None`, but mypy can't cope + IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]] + + +def _get_schema(type_: Any, config_wrapper: _config.ConfigWrapper, parent_depth: int) -> CoreSchema: + """`BaseModel` uses its own `__module__` to find out where it was defined + and then looks for symbols to resolve forward references in those globals. + On the other hand this function can be called with arbitrary objects, + including type aliases, where `__module__` (always `typing.py`) is not useful. + So instead we look at the globals in our parent stack frame. + + This works for the case where this function is called in a module that + has the target of forward references in its scope, but + does not always work for more complex cases. + + For example, take the following: + + a.py + ```python + from typing import Dict, List + + IntList = List[int] + OuterDict = Dict[str, 'IntList'] + ``` + + b.py + ```python test="skip" + from a import OuterDict + + from pydantic import TypeAdapter + + IntList = int # replaces the symbol the forward reference is looking for + v = TypeAdapter(OuterDict) + v({'x': 1}) # should fail but doesn't + ``` + + If `OuterDict` were a `BaseModel`, this would work because it would resolve + the forward reference within the `a.py` namespace. + But `TypeAdapter(OuterDict)` can't determine what module `OuterDict` came from. + + In other words, the assumption that _all_ forward references exist in the + module we are being called from is not technically always true. + Although most of the time it is and it works fine for recursive models and such, + `BaseModel`'s behavior isn't perfect either and _can_ break in similar ways, + so there is no right or wrong between the two. + + But at the very least this behavior is _subtly_ different from `BaseModel`'s. 
+ """ + local_ns = _typing_extra.parent_frame_namespace(parent_depth=parent_depth) + global_ns = sys._getframe(max(parent_depth - 1, 1)).f_globals.copy() + global_ns.update(local_ns or {}) + gen = _generate_schema.GenerateSchema(config_wrapper, types_namespace=global_ns, typevars_map={}) + schema = gen.generate_schema(type_) + schema = gen.clean_schema(schema) + return schema + + +def _getattr_no_parents(obj: Any, attribute: str) -> Any: + """Returns the attribute value without attempting to look up attributes from parent types.""" + if hasattr(obj, '__dict__'): + try: + return obj.__dict__[attribute] + except KeyError: + pass + + slots = getattr(obj, '__slots__', None) + if slots is not None and attribute in slots: + return getattr(obj, attribute) + else: + raise AttributeError(attribute) + + +def _type_has_config(type_: Any) -> bool: + """Returns whether the type has config.""" + try: + return issubclass(type_, BaseModel) or is_dataclass(type_) or is_typeddict(type_) + except TypeError: + # type is not a class + return False + + +@final +class TypeAdapter(Generic[T]): + """Usage docs: https://docs.pydantic.dev/2.6/concepts/type_adapter/ + + Type adapters provide a flexible way to perform validation and serialization based on a Python type. + + A `TypeAdapter` instance exposes some of the functionality from `BaseModel` instance methods + for types that do not have such methods (such as dataclasses, primitive types, and more). + + **Note:** `TypeAdapter` instances are not types, and cannot be used as type annotations for fields. + + Attributes: + core_schema: The core schema for the type. + validator (SchemaValidator): The schema validator for the type. + serializer: The schema serializer for the type. + """ + + @overload + def __init__( + self, + type: type[T], + *, + config: ConfigDict | None = ..., + _parent_depth: int = ..., + module: str | None = ..., + ) -> None: + ... + + # This second overload is for unsupported special forms (such as Union). `pyright` handles them fine, but `mypy` does not match + # them against `type: type[T]`, so an explicit overload with `type: T` is needed. + @overload + def __init__( # pyright: ignore[reportOverlappingOverload] + self, + type: T, + *, + config: ConfigDict | None = ..., + _parent_depth: int = ..., + module: str | None = ..., + ) -> None: + ... + + def __init__( + self, + type: type[T] | T, + *, + config: ConfigDict | None = None, + _parent_depth: int = 2, + module: str | None = None, + ) -> None: + """Initializes the TypeAdapter object. + + Args: + type: The type associated with the `TypeAdapter`. + config: Configuration for the `TypeAdapter`, should be a dictionary conforming to [`ConfigDict`][pydantic.config.ConfigDict]. + _parent_depth: depth at which to search the parent namespace to construct the local namespace. + module: The module that passes to plugin if provided. + + !!! note + You cannot use the `config` argument when instantiating a `TypeAdapter` if the type you're using has its own + config that cannot be overridden (ex: `BaseModel`, `TypedDict`, and `dataclass`). A + [`type-adapter-config-unused`](../errors/usage_errors.md#type-adapter-config-unused) error will be raised in this case. + + !!! note + The `_parent_depth` argument is named with an underscore to suggest its private nature and discourage use. + It may be deprecated in a minor version, so we only recommend using it if you're + comfortable with potential change in behavior / support. + + ??? 
tip "Compatibility with `mypy`" + Depending on the type used, `mypy` might raise an error when instantiating a `TypeAdapter`. As a workaround, you can explicitly + annotate your variable: + + ```py + from typing import Union + + from pydantic import TypeAdapter + + ta: TypeAdapter[Union[str, int]] = TypeAdapter(Union[str, int]) # type: ignore[arg-type] + ``` + + Returns: + A type adapter configured for the specified `type`. + """ + type_is_annotated: bool = _typing_extra.is_annotated(type) + annotated_type: Any = get_args(type)[0] if type_is_annotated else None + type_has_config: bool = _type_has_config(annotated_type if type_is_annotated else type) + + if type_has_config and config is not None: + raise PydanticUserError( + 'Cannot use `config` when the type is a BaseModel, dataclass or TypedDict.' + ' These types can have their own config and setting the config via the `config`' + ' parameter to TypeAdapter will not override it, thus the `config` you passed to' + ' TypeAdapter becomes meaningless, which is probably not what you want.', + code='type-adapter-config-unused', + ) + + config_wrapper = _config.ConfigWrapper(config) + + core_schema: CoreSchema + try: + core_schema = _getattr_no_parents(type, '__pydantic_core_schema__') + except AttributeError: + core_schema = _get_schema(type, config_wrapper, parent_depth=_parent_depth + 1) + + core_config = config_wrapper.core_config(None) + validator: SchemaValidator + try: + validator = _getattr_no_parents(type, '__pydantic_validator__') + except AttributeError: + if module is None: + f = sys._getframe(1) + module = cast(str, f.f_globals.get('__name__', '')) + validator = create_schema_validator( + core_schema, type, module, str(type), 'TypeAdapter', core_config, config_wrapper.plugin_settings + ) # type: ignore + + serializer: SchemaSerializer + try: + serializer = _getattr_no_parents(type, '__pydantic_serializer__') + except AttributeError: + serializer = SchemaSerializer(core_schema, core_config) + + self.core_schema = core_schema + self.validator = validator + self.serializer = serializer + + def validate_python( + self, + __object: Any, + *, + strict: bool | None = None, + from_attributes: bool | None = None, + context: dict[str, Any] | None = None, + ) -> T: + """Validate a Python object against the model. + + Args: + __object: The Python object to validate against the model. + strict: Whether to strictly check types. + from_attributes: Whether to extract data from object attributes. + context: Additional context to pass to the validator. + + !!! note + When using `TypeAdapter` with a Pydantic `dataclass`, the use of the `from_attributes` + argument is not supported. + + Returns: + The validated object. + """ + return self.validator.validate_python(__object, strict=strict, from_attributes=from_attributes, context=context) + + def validate_json( + self, __data: str | bytes, *, strict: bool | None = None, context: dict[str, Any] | None = None + ) -> T: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/json/#json-parsing + + Validate a JSON string or bytes against the model. + + Args: + __data: The JSON data to validate against the model. + strict: Whether to strictly check types. + context: Additional context to use during validation. + + Returns: + The validated object. + """ + return self.validator.validate_json(__data, strict=strict, context=context) + + def validate_strings(self, __obj: Any, *, strict: bool | None = None, context: dict[str, Any] | None = None) -> T: + """Validate object contains string data against the model. 
+ + Args: + __obj: The object contains string data to validate. + strict: Whether to strictly check types. + context: Additional context to use during validation. + + Returns: + The validated object. + """ + return self.validator.validate_strings(__obj, strict=strict, context=context) + + def get_default_value(self, *, strict: bool | None = None, context: dict[str, Any] | None = None) -> Some[T] | None: + """Get the default value for the wrapped type. + + Args: + strict: Whether to strictly check types. + context: Additional context to pass to the validator. + + Returns: + The default value wrapped in a `Some` if there is one or None if not. + """ + return self.validator.get_default_value(strict=strict, context=context) + + def dump_python( + self, + __instance: T, + *, + mode: Literal['json', 'python'] = 'python', + include: IncEx | None = None, + exclude: IncEx | None = None, + by_alias: bool = False, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + round_trip: bool = False, + warnings: bool = True, + ) -> Any: + """Dump an instance of the adapted type to a Python object. + + Args: + __instance: The Python object to serialize. + mode: The output format. + include: Fields to include in the output. + exclude: Fields to exclude from the output. + by_alias: Whether to use alias names for field names. + exclude_unset: Whether to exclude unset fields. + exclude_defaults: Whether to exclude fields with default values. + exclude_none: Whether to exclude fields with None values. + round_trip: Whether to output the serialized data in a way that is compatible with deserialization. + warnings: Whether to display serialization warnings. + + Returns: + The serialized object. + """ + return self.serializer.to_python( + __instance, + mode=mode, + by_alias=by_alias, + include=include, + exclude=exclude, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + round_trip=round_trip, + warnings=warnings, + ) + + def dump_json( + self, + __instance: T, + *, + indent: int | None = None, + include: IncEx | None = None, + exclude: IncEx | None = None, + by_alias: bool = False, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + round_trip: bool = False, + warnings: bool = True, + ) -> bytes: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/json/#json-serialization + + Serialize an instance of the adapted type to JSON. + + Args: + __instance: The instance to be serialized. + indent: Number of spaces for JSON indentation. + include: Fields to include. + exclude: Fields to exclude. + by_alias: Whether to use alias names for field names. + exclude_unset: Whether to exclude unset fields. + exclude_defaults: Whether to exclude fields with default values. + exclude_none: Whether to exclude fields with a value of `None`. + round_trip: Whether to serialize and deserialize the instance to ensure round-tripping. + warnings: Whether to emit serialization warnings. + + Returns: + The JSON representation of the given instance as bytes. 
+ """ + return self.serializer.to_json( + __instance, + indent=indent, + include=include, + exclude=exclude, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + round_trip=round_trip, + warnings=warnings, + ) + + def json_schema( + self, + *, + by_alias: bool = True, + ref_template: str = DEFAULT_REF_TEMPLATE, + schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema, + mode: JsonSchemaMode = 'validation', + ) -> dict[str, Any]: + """Generate a JSON schema for the adapted type. + + Args: + by_alias: Whether to use alias names for field names. + ref_template: The format string used for generating $ref strings. + schema_generator: The generator class used for creating the schema. + mode: The mode to use for schema generation. + + Returns: + The JSON schema for the model as a dictionary. + """ + schema_generator_instance = schema_generator(by_alias=by_alias, ref_template=ref_template) + return schema_generator_instance.generate(self.core_schema, mode=mode) + + @staticmethod + def json_schemas( + __inputs: Iterable[tuple[JsonSchemaKeyT, JsonSchemaMode, TypeAdapter[Any]]], + *, + by_alias: bool = True, + title: str | None = None, + description: str | None = None, + ref_template: str = DEFAULT_REF_TEMPLATE, + schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema, + ) -> tuple[dict[tuple[JsonSchemaKeyT, JsonSchemaMode], JsonSchemaValue], JsonSchemaValue]: + """Generate a JSON schema including definitions from multiple type adapters. + + Args: + __inputs: Inputs to schema generation. The first two items will form the keys of the (first) + output mapping; the type adapters will provide the core schemas that get converted into + definitions in the output JSON schema. + by_alias: Whether to use alias names. + title: The title for the schema. + description: The description for the schema. + ref_template: The format string used for generating $ref strings. + schema_generator: The generator class used for creating the schema. + + Returns: + A tuple where: + + - The first element is a dictionary whose keys are tuples of JSON schema key type and JSON mode, and + whose values are the JSON schema corresponding to that pair of inputs. (These schemas may have + JsonRef references to definitions that are defined in the second returned element.) + - The second element is a JSON schema containing all definitions referenced in the first returned + element, along with the optional title and description keys. 
+ + """ + schema_generator_instance = schema_generator(by_alias=by_alias, ref_template=ref_template) + + inputs = [(key, mode, adapter.core_schema) for key, mode, adapter in __inputs] + + json_schemas_map, definitions = schema_generator_instance.generate_definitions(inputs) + + json_schema: dict[str, Any] = {} + if definitions: + json_schema['$defs'] = definitions + if title: + json_schema['title'] = title + if description: + json_schema['description'] = description + + return json_schemas_map, json_schema diff --git a/lib/pydantic/types.py b/lib/pydantic/types.py index f98dba3d..c2534c88 100644 --- a/lib/pydantic/types.py +++ b/lib/pydantic/types.py @@ -1,12 +1,14 @@ -import abc -import math +"""The types module contains custom types used by pydantic.""" +from __future__ import annotations as _annotations + +import base64 +import dataclasses as _dataclasses import re -import warnings -from datetime import date +from datetime import date, datetime from decimal import Decimal from enum import Enum from pathlib import Path -from types import new_class +from types import ModuleType from typing import ( TYPE_CHECKING, Any, @@ -14,78 +16,56 @@ from typing import ( ClassVar, Dict, FrozenSet, + Generic, + Hashable, + Iterator, List, - Optional, - Pattern, Set, - Tuple, - Type, TypeVar, Union, cast, - overload, ) from uuid import UUID -from weakref import WeakSet -from . import errors -from .datetime_parse import parse_date -from .utils import import_string, update_not_none -from .validators import ( - bytes_validator, - constr_length_validator, - constr_lower, - constr_strip_whitespace, - constr_upper, - decimal_validator, - float_finite_validator, - float_validator, - frozenset_validator, - int_validator, - list_validator, - number_multiple_validator, - number_size_validator, - path_exists_validator, - path_validator, - set_validator, - str_validator, - strict_bytes_validator, - strict_float_validator, - strict_int_validator, - strict_str_validator, +import annotated_types +from annotated_types import BaseMetadata, MaxLen, MinLen +from pydantic_core import CoreSchema, PydanticCustomError, core_schema +from typing_extensions import Annotated, Literal, Protocol, TypeAlias, TypeAliasType, deprecated + +from ._internal import ( + _core_utils, + _fields, + _internal_dataclass, + _typing_extra, + _utils, + _validators, ) +from ._migration import getattr_migration +from .annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler +from .errors import PydanticUserError +from .json_schema import JsonSchemaValue +from .warnings import PydanticDeprecatedSince20 -__all__ = [ - 'NoneStr', - 'NoneBytes', - 'StrBytes', - 'NoneStrBytes', +__all__ = ( + 'Strict', 'StrictStr', - 'ConstrainedBytes', 'conbytes', - 'ConstrainedList', 'conlist', - 'ConstrainedSet', 'conset', - 'ConstrainedFrozenSet', 'confrozenset', - 'ConstrainedStr', 'constr', - 'PyObject', - 'ConstrainedInt', + 'ImportString', 'conint', 'PositiveInt', 'NegativeInt', 'NonNegativeInt', 'NonPositiveInt', - 'ConstrainedFloat', 'confloat', 'PositiveFloat', 'NegativeFloat', 'NonNegativeFloat', 'NonPositiveFloat', 'FiniteFloat', - 'ConstrainedDecimal', 'condecimal', 'UUID1', 'UUID3', @@ -93,9 +73,8 @@ __all__ = [ 'UUID5', 'FilePath', 'DirectoryPath', + 'NewPath', 'Json', - 'JsonWrapper', - 'SecretField', 'SecretStr', 'SecretBytes', 'StrictBool', @@ -106,845 +85,1512 @@ __all__ = [ 'ByteSize', 'PastDate', 'FutureDate', - 'ConstrainedDate', + 'PastDatetime', + 'FutureDatetime', 'condate', -] + 'AwareDatetime', + 'NaiveDatetime', + 'AllowInfNan', + 
'EncoderProtocol', + 'EncodedBytes', + 'EncodedStr', + 'Base64Encoder', + 'Base64Bytes', + 'Base64Str', + 'Base64UrlBytes', + 'Base64UrlStr', + 'GetPydanticSchema', + 'StringConstraints', + 'Tag', + 'Discriminator', + 'JsonValue', + 'OnErrorOmit', +) -NoneStr = Optional[str] -NoneBytes = Optional[bytes] -StrBytes = Union[str, bytes] -NoneStrBytes = Optional[StrBytes] -OptionalInt = Optional[int] -OptionalIntFloat = Union[OptionalInt, float] -OptionalIntFloatDecimal = Union[OptionalIntFloat, Decimal] -OptionalDate = Optional[date] -StrIntFloat = Union[str, int, float] - -if TYPE_CHECKING: - from typing_extensions import Annotated - - from .dataclasses import Dataclass - from .main import BaseModel - from .typing import CallableGenerator - - ModelOrDc = Type[Union[BaseModel, Dataclass]] T = TypeVar('T') -_DEFINED_TYPES: 'WeakSet[type]' = WeakSet() -@overload -def _registered(typ: Type[T]) -> Type[T]: - pass +@_dataclasses.dataclass +class Strict(_fields.PydanticMetadata, BaseMetadata): + """Usage docs: https://docs.pydantic.dev/2.6/concepts/strict_mode/#strict-mode-with-annotated-strict + A field metadata class to indicate that a field should be validated in strict mode. -@overload -def _registered(typ: 'ConstrainedNumberMeta') -> 'ConstrainedNumberMeta': - pass + Attributes: + strict: Whether to validate the field in strict mode. + Example: + ```python + from typing_extensions import Annotated -def _registered(typ: Union[Type[T], 'ConstrainedNumberMeta']) -> Union[Type[T], 'ConstrainedNumberMeta']: - # In order to generate valid examples of constrained types, Hypothesis needs - # to inspect the type object - so we keep a weakref to each contype object - # until it can be registered. When (or if) our Hypothesis plugin is loaded, - # it monkeypatches this function. - # If Hypothesis is never used, the total effect is to keep a weak reference - # which has minimal memory usage and doesn't even affect garbage collection. - _DEFINED_TYPES.add(typ) - return typ + from pydantic.types import Strict + StrictBool = Annotated[bool, Strict()] + ``` + """ -class ConstrainedNumberMeta(type): - def __new__(cls, name: str, bases: Any, dct: Dict[str, Any]) -> 'ConstrainedInt': # type: ignore - new_cls = cast('ConstrainedInt', type.__new__(cls, name, bases, dct)) + strict: bool = True - if new_cls.gt is not None and new_cls.ge is not None: - raise errors.ConfigError('bounds gt and ge cannot be specified at the same time') - if new_cls.lt is not None and new_cls.le is not None: - raise errors.ConfigError('bounds lt and le cannot be specified at the same time') - - return _registered(new_cls) # type: ignore + def __hash__(self) -> int: + return hash(self.strict) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BOOLEAN TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -if TYPE_CHECKING: - StrictBool = bool -else: - - class StrictBool(int): - """ - StrictBool to allow for bools which are not type-coerced. - """ - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type='boolean') - - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield cls.validate - - @classmethod - def validate(cls, value: Any) -> bool: - """ - Ensure that we only allow bools. 
- """ - if isinstance(value, bool): - return value - - raise errors.StrictBoolError() - +StrictBool = Annotated[bool, Strict()] +"""A boolean that must be either ``True`` or ``False``.""" # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTEGER TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -class ConstrainedInt(int, metaclass=ConstrainedNumberMeta): - strict: bool = False - gt: OptionalInt = None - ge: OptionalInt = None - lt: OptionalInt = None - le: OptionalInt = None - multiple_of: OptionalInt = None - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none( - field_schema, - exclusiveMinimum=cls.gt, - exclusiveMaximum=cls.lt, - minimum=cls.ge, - maximum=cls.le, - multipleOf=cls.multiple_of, - ) - - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield strict_int_validator if cls.strict else int_validator - yield number_size_validator - yield number_multiple_validator - - def conint( - *, strict: bool = False, gt: int = None, ge: int = None, lt: int = None, le: int = None, multiple_of: int = None -) -> Type[int]: - # use kwargs then define conf in a dict to aid with IDE type hinting - namespace = dict(strict=strict, gt=gt, ge=ge, lt=lt, le=le, multiple_of=multiple_of) - return type('ConstrainedIntValue', (ConstrainedInt,), namespace) + *, + strict: bool | None = None, + gt: int | None = None, + ge: int | None = None, + lt: int | None = None, + le: int | None = None, + multiple_of: int | None = None, +) -> type[int]: + """ + !!! warning "Discouraged" + This function is **discouraged** in favor of using + [`Annotated`](https://docs.python.org/3/library/typing.html#typing.Annotated) with + [`Field`][pydantic.fields.Field] instead. + + This function will be **deprecated** in Pydantic 3.0. + + The reason is that `conint` returns a type, which doesn't play well with static analysis tools. + + === ":x: Don't do this" + ```py + from pydantic import BaseModel, conint + + class Foo(BaseModel): + bar: conint(strict=True, gt=0) + ``` + + === ":white_check_mark: Do this" + ```py + from typing_extensions import Annotated + + from pydantic import BaseModel, Field + + class Foo(BaseModel): + bar: Annotated[int, Field(strict=True, gt=0)] + ``` + + A wrapper around `int` that allows for additional constraints. + + Args: + strict: Whether to validate the integer in strict mode. Defaults to `None`. + gt: The value must be greater than this. + ge: The value must be greater than or equal to this. + lt: The value must be less than this. + le: The value must be less than or equal to this. + multiple_of: The value must be a multiple of this. + + Returns: + The wrapped integer type. 
+ + ```py + from pydantic import BaseModel, ValidationError, conint + + class ConstrainedExample(BaseModel): + constrained_int: conint(gt=1) + + m = ConstrainedExample(constrained_int=2) + print(repr(m)) + #> ConstrainedExample(constrained_int=2) + + try: + ConstrainedExample(constrained_int=0) + except ValidationError as e: + print(e.errors()) + ''' + [ + { + 'type': 'greater_than', + 'loc': ('constrained_int',), + 'msg': 'Input should be greater than 1', + 'input': 0, + 'ctx': {'gt': 1}, + 'url': 'https://errors.pydantic.dev/2/v/greater_than', + } + ] + ''' + ``` + + """ # noqa: D212 + return Annotated[ + int, + Strict(strict) if strict is not None else None, + annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le), + annotated_types.MultipleOf(multiple_of) if multiple_of is not None else None, + ] -if TYPE_CHECKING: - PositiveInt = int - NegativeInt = int - NonPositiveInt = int - NonNegativeInt = int - StrictInt = int -else: +PositiveInt = Annotated[int, annotated_types.Gt(0)] +"""An integer that must be greater than zero. - class PositiveInt(ConstrainedInt): - gt = 0 +```py +from pydantic import BaseModel, PositiveInt, ValidationError - class NegativeInt(ConstrainedInt): - lt = 0 +class Model(BaseModel): + positive_int: PositiveInt - class NonPositiveInt(ConstrainedInt): - le = 0 +m = Model(positive_int=1) +print(repr(m)) +#> Model(positive_int=1) - class NonNegativeInt(ConstrainedInt): - ge = 0 +try: + Model(positive_int=-1) +except ValidationError as e: + print(e.errors()) + ''' + [ + { + 'type': 'greater_than', + 'loc': ('positive_int',), + 'msg': 'Input should be greater than 0', + 'input': -1, + 'ctx': {'gt': 0}, + 'url': 'https://errors.pydantic.dev/2/v/greater_than', + } + ] + ''' +``` +""" +NegativeInt = Annotated[int, annotated_types.Lt(0)] +"""An integer that must be less than zero. - class StrictInt(ConstrainedInt): - strict = True +```py +from pydantic import BaseModel, NegativeInt, ValidationError +class Model(BaseModel): + negative_int: NegativeInt + +m = Model(negative_int=-1) +print(repr(m)) +#> Model(negative_int=-1) + +try: + Model(negative_int=1) +except ValidationError as e: + print(e.errors()) + ''' + [ + { + 'type': 'less_than', + 'loc': ('negative_int',), + 'msg': 'Input should be less than 0', + 'input': 1, + 'ctx': {'lt': 0}, + 'url': 'https://errors.pydantic.dev/2/v/less_than', + } + ] + ''' +``` +""" +NonPositiveInt = Annotated[int, annotated_types.Le(0)] +"""An integer that must be less than or equal to zero. + +```py +from pydantic import BaseModel, NonPositiveInt, ValidationError + +class Model(BaseModel): + non_positive_int: NonPositiveInt + +m = Model(non_positive_int=0) +print(repr(m)) +#> Model(non_positive_int=0) + +try: + Model(non_positive_int=1) +except ValidationError as e: + print(e.errors()) + ''' + [ + { + 'type': 'less_than_equal', + 'loc': ('non_positive_int',), + 'msg': 'Input should be less than or equal to 0', + 'input': 1, + 'ctx': {'le': 0}, + 'url': 'https://errors.pydantic.dev/2/v/less_than_equal', + } + ] + ''' +``` +""" +NonNegativeInt = Annotated[int, annotated_types.Ge(0)] +"""An integer that must be greater than or equal to zero. 
+ +```py +from pydantic import BaseModel, NonNegativeInt, ValidationError + +class Model(BaseModel): + non_negative_int: NonNegativeInt + +m = Model(non_negative_int=0) +print(repr(m)) +#> Model(non_negative_int=0) + +try: + Model(non_negative_int=-1) +except ValidationError as e: + print(e.errors()) + ''' + [ + { + 'type': 'greater_than_equal', + 'loc': ('non_negative_int',), + 'msg': 'Input should be greater than or equal to 0', + 'input': -1, + 'ctx': {'ge': 0}, + 'url': 'https://errors.pydantic.dev/2/v/greater_than_equal', + } + ] + ''' +``` +""" +StrictInt = Annotated[int, Strict()] +"""An integer that must be validated in strict mode. + +```py +from pydantic import BaseModel, StrictInt, ValidationError + +class StrictIntModel(BaseModel): + strict_int: StrictInt + +try: + StrictIntModel(strict_int=3.14159) +except ValidationError as e: + print(e) + ''' + 1 validation error for StrictIntModel + strict_int + Input should be a valid integer [type=int_type, input_value=3.14159, input_type=float] + ''' +``` +""" # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLOAT TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -class ConstrainedFloat(float, metaclass=ConstrainedNumberMeta): - strict: bool = False - gt: OptionalIntFloat = None - ge: OptionalIntFloat = None - lt: OptionalIntFloat = None - le: OptionalIntFloat = None - multiple_of: OptionalIntFloat = None - allow_inf_nan: Optional[bool] = None +@_dataclasses.dataclass +class AllowInfNan(_fields.PydanticMetadata): + """A field metadata class to indicate that a field should allow ``-inf``, ``inf``, and ``nan``.""" - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none( - field_schema, - exclusiveMinimum=cls.gt, - exclusiveMaximum=cls.lt, - minimum=cls.ge, - maximum=cls.le, - multipleOf=cls.multiple_of, - ) - # Modify constraints to account for differences between IEEE floats and JSON - if field_schema.get('exclusiveMinimum') == -math.inf: - del field_schema['exclusiveMinimum'] - if field_schema.get('minimum') == -math.inf: - del field_schema['minimum'] - if field_schema.get('exclusiveMaximum') == math.inf: - del field_schema['exclusiveMaximum'] - if field_schema.get('maximum') == math.inf: - del field_schema['maximum'] + allow_inf_nan: bool = True - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield strict_float_validator if cls.strict else float_validator - yield number_size_validator - yield number_multiple_validator - yield float_finite_validator + def __hash__(self) -> int: + return hash(self.allow_inf_nan) def confloat( *, - strict: bool = False, - gt: float = None, - ge: float = None, - lt: float = None, - le: float = None, - multiple_of: float = None, - allow_inf_nan: Optional[bool] = None, -) -> Type[float]: - # use kwargs then define conf in a dict to aid with IDE type hinting - namespace = dict(strict=strict, gt=gt, ge=ge, lt=lt, le=le, multiple_of=multiple_of, allow_inf_nan=allow_inf_nan) - return type('ConstrainedFloatValue', (ConstrainedFloat,), namespace) + strict: bool | None = None, + gt: float | None = None, + ge: float | None = None, + lt: float | None = None, + le: float | None = None, + multiple_of: float | None = None, + allow_inf_nan: bool | None = None, +) -> type[float]: + """ + !!! warning "Discouraged" + This function is **discouraged** in favor of using + [`Annotated`](https://docs.python.org/3/library/typing.html#typing.Annotated) with + [`Field`][pydantic.fields.Field] instead. + + This function will be **deprecated** in Pydantic 3.0. 
+ + The reason is that `confloat` returns a type, which doesn't play well with static analysis tools. + + === ":x: Don't do this" + ```py + from pydantic import BaseModel, confloat + + class Foo(BaseModel): + bar: confloat(strict=True, gt=0) + ``` + + === ":white_check_mark: Do this" + ```py + from typing_extensions import Annotated + + from pydantic import BaseModel, Field + + class Foo(BaseModel): + bar: Annotated[float, Field(strict=True, gt=0)] + ``` + + A wrapper around `float` that allows for additional constraints. + + Args: + strict: Whether to validate the float in strict mode. + gt: The value must be greater than this. + ge: The value must be greater than or equal to this. + lt: The value must be less than this. + le: The value must be less than or equal to this. + multiple_of: The value must be a multiple of this. + allow_inf_nan: Whether to allow `-inf`, `inf`, and `nan`. + + Returns: + The wrapped float type. + + ```py + from pydantic import BaseModel, ValidationError, confloat + + class ConstrainedExample(BaseModel): + constrained_float: confloat(gt=1.0) + + m = ConstrainedExample(constrained_float=1.1) + print(repr(m)) + #> ConstrainedExample(constrained_float=1.1) + + try: + ConstrainedExample(constrained_float=0.9) + except ValidationError as e: + print(e.errors()) + ''' + [ + { + 'type': 'greater_than', + 'loc': ('constrained_float',), + 'msg': 'Input should be greater than 1', + 'input': 0.9, + 'ctx': {'gt': 1.0}, + 'url': 'https://errors.pydantic.dev/2/v/greater_than', + } + ] + ''' + ``` + """ # noqa: D212 + return Annotated[ + float, + Strict(strict) if strict is not None else None, + annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le), + annotated_types.MultipleOf(multiple_of) if multiple_of is not None else None, + AllowInfNan(allow_inf_nan) if allow_inf_nan is not None else None, + ] -if TYPE_CHECKING: - PositiveFloat = float - NegativeFloat = float - NonPositiveFloat = float - NonNegativeFloat = float - StrictFloat = float - FiniteFloat = float -else: +PositiveFloat = Annotated[float, annotated_types.Gt(0)] +"""A float that must be greater than zero. - class PositiveFloat(ConstrainedFloat): - gt = 0 +```py +from pydantic import BaseModel, PositiveFloat, ValidationError - class NegativeFloat(ConstrainedFloat): - lt = 0 +class Model(BaseModel): + positive_float: PositiveFloat - class NonPositiveFloat(ConstrainedFloat): - le = 0 +m = Model(positive_float=1.0) +print(repr(m)) +#> Model(positive_float=1.0) - class NonNegativeFloat(ConstrainedFloat): - ge = 0 +try: + Model(positive_float=-1.0) +except ValidationError as e: + print(e.errors()) + ''' + [ + { + 'type': 'greater_than', + 'loc': ('positive_float',), + 'msg': 'Input should be greater than 0', + 'input': -1.0, + 'ctx': {'gt': 0.0}, + 'url': 'https://errors.pydantic.dev/2/v/greater_than', + } + ] + ''' +``` +""" +NegativeFloat = Annotated[float, annotated_types.Lt(0)] +"""A float that must be less than zero. 
- class StrictFloat(ConstrainedFloat): - strict = True +```py +from pydantic import BaseModel, NegativeFloat, ValidationError - class FiniteFloat(ConstrainedFloat): - allow_inf_nan = False +class Model(BaseModel): + negative_float: NegativeFloat + +m = Model(negative_float=-1.0) +print(repr(m)) +#> Model(negative_float=-1.0) + +try: + Model(negative_float=1.0) +except ValidationError as e: + print(e.errors()) + ''' + [ + { + 'type': 'less_than', + 'loc': ('negative_float',), + 'msg': 'Input should be less than 0', + 'input': 1.0, + 'ctx': {'lt': 0.0}, + 'url': 'https://errors.pydantic.dev/2/v/less_than', + } + ] + ''' +``` +""" +NonPositiveFloat = Annotated[float, annotated_types.Le(0)] +"""A float that must be less than or equal to zero. + +```py +from pydantic import BaseModel, NonPositiveFloat, ValidationError + +class Model(BaseModel): + non_positive_float: NonPositiveFloat + +m = Model(non_positive_float=0.0) +print(repr(m)) +#> Model(non_positive_float=0.0) + +try: + Model(non_positive_float=1.0) +except ValidationError as e: + print(e.errors()) + ''' + [ + { + 'type': 'less_than_equal', + 'loc': ('non_positive_float',), + 'msg': 'Input should be less than or equal to 0', + 'input': 1.0, + 'ctx': {'le': 0.0}, + 'url': 'https://errors.pydantic.dev/2/v/less_than_equal', + } + ] + ''' +``` +""" +NonNegativeFloat = Annotated[float, annotated_types.Ge(0)] +"""A float that must be greater than or equal to zero. + +```py +from pydantic import BaseModel, NonNegativeFloat, ValidationError + +class Model(BaseModel): + non_negative_float: NonNegativeFloat + +m = Model(non_negative_float=0.0) +print(repr(m)) +#> Model(non_negative_float=0.0) + +try: + Model(non_negative_float=-1.0) +except ValidationError as e: + print(e.errors()) + ''' + [ + { + 'type': 'greater_than_equal', + 'loc': ('non_negative_float',), + 'msg': 'Input should be greater than or equal to 0', + 'input': -1.0, + 'ctx': {'ge': 0.0}, + 'url': 'https://errors.pydantic.dev/2/v/greater_than_equal', + } + ] + ''' +``` +""" +StrictFloat = Annotated[float, Strict(True)] +"""A float that must be validated in strict mode. + +```py +from pydantic import BaseModel, StrictFloat, ValidationError + +class StrictFloatModel(BaseModel): + strict_float: StrictFloat + +try: + StrictFloatModel(strict_float='1.0') +except ValidationError as e: + print(e) + ''' + 1 validation error for StrictFloatModel + strict_float + Input should be a valid number [type=float_type, input_value='1.0', input_type=str] + ''' +``` +""" +FiniteFloat = Annotated[float, AllowInfNan(False)] +"""A float that must be finite (not ``-inf``, ``inf``, or ``nan``). 
+ +```py +from pydantic import BaseModel, FiniteFloat + +class Model(BaseModel): + finite: FiniteFloat + +m = Model(finite=1.0) +print(m) +#> finite=1.0 +``` +""" # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BYTES TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -class ConstrainedBytes(bytes): - strip_whitespace = False - to_upper = False - to_lower = False - min_length: OptionalInt = None - max_length: OptionalInt = None - strict: bool = False - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none(field_schema, minLength=cls.min_length, maxLength=cls.max_length) - - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield strict_bytes_validator if cls.strict else bytes_validator - yield constr_strip_whitespace - yield constr_upper - yield constr_lower - yield constr_length_validator - - def conbytes( *, - strip_whitespace: bool = False, - to_upper: bool = False, - to_lower: bool = False, - min_length: int = None, - max_length: int = None, - strict: bool = False, -) -> Type[bytes]: - # use kwargs then define conf in a dict to aid with IDE type hinting - namespace = dict( - strip_whitespace=strip_whitespace, - to_upper=to_upper, - to_lower=to_lower, - min_length=min_length, - max_length=max_length, - strict=strict, - ) - return _registered(type('ConstrainedBytesValue', (ConstrainedBytes,), namespace)) + min_length: int | None = None, + max_length: int | None = None, + strict: bool | None = None, +) -> type[bytes]: + """A wrapper around `bytes` that allows for additional constraints. + + Args: + min_length: The minimum length of the bytes. + max_length: The maximum length of the bytes. + strict: Whether to validate the bytes in strict mode. + + Returns: + The wrapped bytes type. + """ + return Annotated[ + bytes, + Strict(strict) if strict is not None else None, + annotated_types.Len(min_length or 0, max_length), + ] -if TYPE_CHECKING: - StrictBytes = bytes -else: - - class StrictBytes(ConstrainedBytes): - strict = True +StrictBytes = Annotated[bytes, Strict()] +"""A bytes that must be validated in strict mode.""" # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ STRING TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -class ConstrainedStr(str): - strip_whitespace = False - to_upper = False - to_lower = False - min_length: OptionalInt = None - max_length: OptionalInt = None - curtail_length: OptionalInt = None - regex: Optional[Pattern[str]] = None - strict = False +@_dataclasses.dataclass(frozen=True) +class StringConstraints(annotated_types.GroupedMetadata): + """Usage docs: https://docs.pydantic.dev/2.6/concepts/fields/#string-constraints - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none( - field_schema, - minLength=cls.min_length, - maxLength=cls.max_length, - pattern=cls.regex and cls.regex.pattern, - ) + Apply constraints to `str` types. - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield strict_str_validator if cls.strict else str_validator - yield constr_strip_whitespace - yield constr_upper - yield constr_lower - yield constr_length_validator - yield cls.validate + Attributes: + strip_whitespace: Whether to strip whitespace from the string. + to_upper: Whether to convert the string to uppercase. + to_lower: Whether to convert the string to lowercase. + strict: Whether to validate the string in strict mode. + min_length: The minimum length of the string. + max_length: The maximum length of the string. + pattern: A regex pattern that the string must match. 
+ """ - @classmethod - def validate(cls, value: Union[str]) -> Union[str]: - if cls.curtail_length and len(value) > cls.curtail_length: - value = value[: cls.curtail_length] + strip_whitespace: bool | None = None + to_upper: bool | None = None + to_lower: bool | None = None + strict: bool | None = None + min_length: int | None = None + max_length: int | None = None + pattern: str | None = None - if cls.regex: - if not cls.regex.match(value): - raise errors.StrRegexError(pattern=cls.regex.pattern) - - return value + def __iter__(self) -> Iterator[BaseMetadata]: + if self.min_length is not None: + yield MinLen(self.min_length) + if self.max_length is not None: + yield MaxLen(self.max_length) + if self.strict is not None: + yield Strict() + if ( + self.strip_whitespace is not None + or self.pattern is not None + or self.to_lower is not None + or self.to_upper is not None + ): + yield _fields.pydantic_general_metadata( + strip_whitespace=self.strip_whitespace, + to_upper=self.to_upper, + to_lower=self.to_lower, + pattern=self.pattern, + ) def constr( *, - strip_whitespace: bool = False, - to_upper: bool = False, - to_lower: bool = False, - strict: bool = False, - min_length: int = None, - max_length: int = None, - curtail_length: int = None, - regex: str = None, -) -> Type[str]: - # use kwargs then define conf in a dict to aid with IDE type hinting - namespace = dict( - strip_whitespace=strip_whitespace, - to_upper=to_upper, - to_lower=to_lower, - strict=strict, - min_length=min_length, - max_length=max_length, - curtail_length=curtail_length, - regex=regex and re.compile(regex), - ) - return _registered(type('ConstrainedStrValue', (ConstrainedStr,), namespace)) + strip_whitespace: bool | None = None, + to_upper: bool | None = None, + to_lower: bool | None = None, + strict: bool | None = None, + min_length: int | None = None, + max_length: int | None = None, + pattern: str | None = None, +) -> type[str]: + """ + !!! warning "Discouraged" + This function is **discouraged** in favor of using + [`Annotated`](https://docs.python.org/3/library/typing.html#typing.Annotated) with + [`StringConstraints`][pydantic.types.StringConstraints] instead. + + This function will be **deprecated** in Pydantic 3.0. + + The reason is that `constr` returns a type, which doesn't play well with static analysis tools. + + === ":x: Don't do this" + ```py + from pydantic import BaseModel, constr + + class Foo(BaseModel): + bar: constr(strip_whitespace=True, to_upper=True, pattern=r'^[A-Z]+$') + ``` + + === ":white_check_mark: Do this" + ```py + from typing_extensions import Annotated + + from pydantic import BaseModel, StringConstraints + + class Foo(BaseModel): + bar: Annotated[str, StringConstraints(strip_whitespace=True, to_upper=True, pattern=r'^[A-Z]+$')] + ``` + + A wrapper around `str` that allows for additional constraints. + + ```py + from pydantic import BaseModel, constr + + class Foo(BaseModel): + bar: constr(strip_whitespace=True, to_upper=True, pattern=r'^[A-Z]+$') -if TYPE_CHECKING: - StrictStr = str -else: + foo = Foo(bar=' hello ') + print(foo) + #> bar='HELLO' + ``` - class StrictStr(ConstrainedStr): - strict = True + Args: + strip_whitespace: Whether to remove leading and trailing whitespace. + to_upper: Whether to turn all characters to uppercase. + to_lower: Whether to turn all characters to lowercase. + strict: Whether to validate the string in strict mode. + min_length: The minimum length of the string. + max_length: The maximum length of the string. 
+ pattern: A regex pattern to validate the string against. + + Returns: + The wrapped string type. + """ # noqa: D212 + return Annotated[ + str, + StringConstraints( + strip_whitespace=strip_whitespace, + to_upper=to_upper, + to_lower=to_lower, + strict=strict, + min_length=min_length, + max_length=max_length, + pattern=pattern, + ), + ] -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SET TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -# This types superclass should be Set[T], but cython chokes on that... -class ConstrainedSet(set): # type: ignore - # Needed for pydantic to detect that this is a set - __origin__ = set - __args__: Set[Type[T]] # type: ignore - - min_items: Optional[int] = None - max_items: Optional[int] = None - item_type: Type[T] # type: ignore - - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield cls.set_length_validator - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none(field_schema, minItems=cls.min_items, maxItems=cls.max_items) - - @classmethod - def set_length_validator(cls, v: 'Optional[Set[T]]') -> 'Optional[Set[T]]': - if v is None: - return None - - v = set_validator(v) - v_len = len(v) - - if cls.min_items is not None and v_len < cls.min_items: - raise errors.SetMinLengthError(limit_value=cls.min_items) - - if cls.max_items is not None and v_len > cls.max_items: - raise errors.SetMaxLengthError(limit_value=cls.max_items) - - return v +StrictStr = Annotated[str, Strict()] +"""A string that must be validated in strict mode.""" -def conset(item_type: Type[T], *, min_items: int = None, max_items: int = None) -> Type[Set[T]]: - # __args__ is needed to conform to typing generics api - namespace = {'min_items': min_items, 'max_items': max_items, 'item_type': item_type, '__args__': [item_type]} - # We use new_class to be able to deal with Generic types - return new_class('ConstrainedSetValue', (ConstrainedSet,), {}, lambda ns: ns.update(namespace)) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~ COLLECTION TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +HashableItemType = TypeVar('HashableItemType', bound=Hashable) -# This types superclass should be FrozenSet[T], but cython chokes on that... -class ConstrainedFrozenSet(frozenset): # type: ignore - # Needed for pydantic to detect that this is a set - __origin__ = frozenset - __args__: FrozenSet[Type[T]] # type: ignore +def conset( + item_type: type[HashableItemType], *, min_length: int | None = None, max_length: int | None = None +) -> type[set[HashableItemType]]: + """A wrapper around `typing.Set` that allows for additional constraints. - min_items: Optional[int] = None - max_items: Optional[int] = None - item_type: Type[T] # type: ignore + Args: + item_type: The type of the items in the set. + min_length: The minimum length of the set. + max_length: The maximum length of the set. 
- @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield cls.frozenset_length_validator - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none(field_schema, minItems=cls.min_items, maxItems=cls.max_items) - - @classmethod - def frozenset_length_validator(cls, v: 'Optional[FrozenSet[T]]') -> 'Optional[FrozenSet[T]]': - if v is None: - return None - - v = frozenset_validator(v) - v_len = len(v) - - if cls.min_items is not None and v_len < cls.min_items: - raise errors.FrozenSetMinLengthError(limit_value=cls.min_items) - - if cls.max_items is not None and v_len > cls.max_items: - raise errors.FrozenSetMaxLengthError(limit_value=cls.max_items) - - return v + Returns: + The wrapped set type. + """ + return Annotated[Set[item_type], annotated_types.Len(min_length or 0, max_length)] -def confrozenset(item_type: Type[T], *, min_items: int = None, max_items: int = None) -> Type[FrozenSet[T]]: - # __args__ is needed to conform to typing generics api - namespace = {'min_items': min_items, 'max_items': max_items, 'item_type': item_type, '__args__': [item_type]} - # We use new_class to be able to deal with Generic types - return new_class('ConstrainedFrozenSetValue', (ConstrainedFrozenSet,), {}, lambda ns: ns.update(namespace)) +def confrozenset( + item_type: type[HashableItemType], *, min_length: int | None = None, max_length: int | None = None +) -> type[frozenset[HashableItemType]]: + """A wrapper around `typing.FrozenSet` that allows for additional constraints. + + Args: + item_type: The type of the items in the frozenset. + min_length: The minimum length of the frozenset. + max_length: The maximum length of the frozenset. + + Returns: + The wrapped frozenset type. + """ + return Annotated[FrozenSet[item_type], annotated_types.Len(min_length or 0, max_length)] -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LIST TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -# This types superclass should be List[T], but cython chokes on that... -class ConstrainedList(list): # type: ignore - # Needed for pydantic to detect that this is a list - __origin__ = list - __args__: Tuple[Type[T], ...] 
# type: ignore - - min_items: Optional[int] = None - max_items: Optional[int] = None - unique_items: Optional[bool] = None - item_type: Type[T] # type: ignore - - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield cls.list_length_validator - if cls.unique_items: - yield cls.unique_items_validator - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none(field_schema, minItems=cls.min_items, maxItems=cls.max_items, uniqueItems=cls.unique_items) - - @classmethod - def list_length_validator(cls, v: 'Optional[List[T]]') -> 'Optional[List[T]]': - if v is None: - return None - - v = list_validator(v) - v_len = len(v) - - if cls.min_items is not None and v_len < cls.min_items: - raise errors.ListMinLengthError(limit_value=cls.min_items) - - if cls.max_items is not None and v_len > cls.max_items: - raise errors.ListMaxLengthError(limit_value=cls.max_items) - - return v - - @classmethod - def unique_items_validator(cls, v: 'List[T]') -> 'List[T]': - for i, value in enumerate(v, start=1): - if value in v[i:]: - raise errors.ListUniqueItemsError() - - return v +AnyItemType = TypeVar('AnyItemType') def conlist( - item_type: Type[T], *, min_items: int = None, max_items: int = None, unique_items: bool = None -) -> Type[List[T]]: - # __args__ is needed to conform to typing generics api - namespace = dict( - min_items=min_items, max_items=max_items, unique_items=unique_items, item_type=item_type, __args__=(item_type,) - ) - # We use new_class to be able to deal with Generic types - return new_class('ConstrainedListValue', (ConstrainedList,), {}, lambda ns: ns.update(namespace)) + item_type: type[AnyItemType], + *, + min_length: int | None = None, + max_length: int | None = None, + unique_items: bool | None = None, +) -> type[list[AnyItemType]]: + """A wrapper around typing.List that adds validation. + + Args: + item_type: The type of the items in the list. + min_length: The minimum length of the list. Defaults to None. + max_length: The maximum length of the list. Defaults to None. + unique_items: Whether the items in the list must be unique. Defaults to None. + !!! warning Deprecated + The `unique_items` parameter is deprecated, use `Set` instead. + See [this issue](https://github.com/pydantic/pydantic-core/issues/296) for more details. + + Returns: + The wrapped list type. + """ + if unique_items is not None: + raise PydanticUserError( + ( + '`unique_items` is removed, use `Set` instead' + '(this feature is discussed in https://github.com/pydantic/pydantic-core/issues/296)' + ), + code='removed-kwargs', + ) + return Annotated[List[item_type], annotated_types.Len(min_length or 0, max_length)] -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PYOBJECT TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - +# ~~~~~~~~~~~~~~~~~~~~~~~~~~ IMPORT STRING TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +AnyType = TypeVar('AnyType') if TYPE_CHECKING: - PyObject = Callable[..., Any] + ImportString = Annotated[AnyType, ...] else: - class PyObject: - validate_always = True + class ImportString: + """A type that can be used to import a type from a string. + + `ImportString` expects a string and loads the Python object importable at that dotted path. + Attributes of modules may be separated from the module by `:` or `.`, e.g. if `'math:cos'` was provided, + the resulting field value would be the function`cos`. If a `.` is used and both an attribute and submodule + are present at the same path, the module will be preferred. + + On model instantiation, pointers will be evaluated and imported. 
There is + some nuance to this behavior, demonstrated in the examples below. + + **Good behavior:** + ```py + from math import cos + + from pydantic import BaseModel, Field, ImportString, ValidationError + + + class ImportThings(BaseModel): + obj: ImportString + + + # A string value will cause an automatic import + my_cos = ImportThings(obj='math.cos') + + # You can use the imported function as you would expect + cos_of_0 = my_cos.obj(0) + assert cos_of_0 == 1 + + + # A string whose value cannot be imported will raise an error + try: + ImportThings(obj='foo.bar') + except ValidationError as e: + print(e) + ''' + 1 validation error for ImportThings + obj + Invalid python path: No module named 'foo.bar' [type=import_error, input_value='foo.bar', input_type=str] + ''' + + + # Actual python objects can be assigned as well + my_cos = ImportThings(obj=cos) + my_cos_2 = ImportThings(obj='math.cos') + my_cos_3 = ImportThings(obj='math:cos') + assert my_cos == my_cos_2 == my_cos_3 + + + # You can set default field value either as Python object: + class ImportThingsDefaultPyObj(BaseModel): + obj: ImportString = math.cos + + + # or as a string value (but only if used with `validate_default=True`) + class ImportThingsDefaultString(BaseModel): + obj: ImportString = Field(default='math.cos', validate_default=True) + + + my_cos_default1 = ImportThingsDefaultPyObj() + my_cos_default2 = ImportThingsDefaultString() + assert my_cos_default1.obj == my_cos_default2.obj == math.cos + + + # note: this will not work! + class ImportThingsMissingValidateDefault(BaseModel): + obj: ImportString = 'math.cos' + + my_cos_default3 = ImportThingsMissingValidateDefault() + assert my_cos_default3.obj == 'math.cos' # just string, not evaluated + ``` + + Serializing an `ImportString` type to json is also possible. 
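Looking back at the collection constructors above (`conset`, `confrozenset`, `conlist`), none of them ships an inline example; a short sketch of the length constraints they apply, using a hypothetical `Playlist` model, might be:

```py
from pydantic import BaseModel, ValidationError, conlist, conset

class Playlist(BaseModel):
    tracks: conlist(str, min_length=1, max_length=5)
    tags: conset(str, max_length=3)

print(Playlist(tracks=['intro'], tags={'rock'}))
#> tracks=['intro'] tags={'rock'}

try:
    Playlist(tracks=[], tags=set())
except ValidationError as e:
    print(e.errors()[0]['type'])
    #> too_short
```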
+ + ```py + from pydantic import BaseModel, ImportString + + + class ImportThings(BaseModel): + obj: ImportString + + + # Create an instance + m = ImportThings(obj='math.cos') + print(m) + #> obj= + print(m.model_dump_json()) + #> {"obj":"math.cos"} + ``` + """ @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield cls.validate + def __class_getitem__(cls, item: AnyType) -> AnyType: + return Annotated[item, cls()] @classmethod - def validate(cls, value: Any) -> Any: - if isinstance(value, Callable): - return value + def __get_pydantic_core_schema__( + cls, source: type[Any], handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + serializer = core_schema.plain_serializer_function_ser_schema(cls._serialize, when_used='json') + if cls is source: + # Treat bare usage of ImportString (`schema is None`) as the same as ImportString[Any] + return core_schema.no_info_plain_validator_function( + function=_validators.import_string, serialization=serializer + ) + else: + return core_schema.no_info_before_validator_function( + function=_validators.import_string, schema=handler(source), serialization=serializer + ) - try: - value = str_validator(value) - except errors.StrError: - raise errors.PyObjectError(error_message='value is neither a valid import path not a valid callable') + @staticmethod + def _serialize(v: Any) -> str: + if isinstance(v, ModuleType): + return v.__name__ + elif hasattr(v, '__module__') and hasattr(v, '__name__'): + return f'{v.__module__}.{v.__name__}' + else: + return v - try: - return import_string(value) - except ImportError as e: - raise errors.PyObjectError(error_message=str(e)) + def __repr__(self) -> str: + return 'ImportString' # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DECIMAL TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -class ConstrainedDecimal(Decimal, metaclass=ConstrainedNumberMeta): - gt: OptionalIntFloatDecimal = None - ge: OptionalIntFloatDecimal = None - lt: OptionalIntFloatDecimal = None - le: OptionalIntFloatDecimal = None - max_digits: OptionalInt = None - decimal_places: OptionalInt = None - multiple_of: OptionalIntFloatDecimal = None - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none( - field_schema, - exclusiveMinimum=cls.gt, - exclusiveMaximum=cls.lt, - minimum=cls.ge, - maximum=cls.le, - multipleOf=cls.multiple_of, - ) - - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield decimal_validator - yield number_size_validator - yield number_multiple_validator - yield cls.validate - - @classmethod - def validate(cls, value: Decimal) -> Decimal: - digit_tuple, exponent = value.as_tuple()[1:] - if exponent in {'F', 'n', 'N'}: - raise errors.DecimalIsNotFiniteError() - - if exponent >= 0: - # A positive exponent adds that many trailing zeros. - digits = len(digit_tuple) + exponent - decimals = 0 - else: - # If the absolute value of the negative exponent is larger than the - # number of digits, then it's the same as the number of digits, - # because it'll consume all of the digits in digit_tuple and then - # add abs(exponent) - len(digit_tuple) leading zeros after the - # decimal point. 
- if abs(exponent) > len(digit_tuple): - digits = decimals = abs(exponent) - else: - digits = len(digit_tuple) - decimals = abs(exponent) - whole_digits = digits - decimals - - if cls.max_digits is not None and digits > cls.max_digits: - raise errors.DecimalMaxDigitsError(max_digits=cls.max_digits) - - if cls.decimal_places is not None and decimals > cls.decimal_places: - raise errors.DecimalMaxPlacesError(decimal_places=cls.decimal_places) - - if cls.max_digits is not None and cls.decimal_places is not None: - expected = cls.max_digits - cls.decimal_places - if whole_digits > expected: - raise errors.DecimalWholeDigitsError(whole_digits=expected) - - return value - - def condecimal( *, - gt: Decimal = None, - ge: Decimal = None, - lt: Decimal = None, - le: Decimal = None, - max_digits: int = None, - decimal_places: int = None, - multiple_of: Decimal = None, -) -> Type[Decimal]: - # use kwargs then define conf in a dict to aid with IDE type hinting - namespace = dict( - gt=gt, ge=ge, lt=lt, le=le, max_digits=max_digits, decimal_places=decimal_places, multiple_of=multiple_of - ) - return type('ConstrainedDecimalValue', (ConstrainedDecimal,), namespace) + strict: bool | None = None, + gt: int | Decimal | None = None, + ge: int | Decimal | None = None, + lt: int | Decimal | None = None, + le: int | Decimal | None = None, + multiple_of: int | Decimal | None = None, + max_digits: int | None = None, + decimal_places: int | None = None, + allow_inf_nan: bool | None = None, +) -> type[Decimal]: + """ + !!! warning "Discouraged" + This function is **discouraged** in favor of using + [`Annotated`](https://docs.python.org/3/library/typing.html#typing.Annotated) with + [`Field`][pydantic.fields.Field] instead. + + This function will be **deprecated** in Pydantic 3.0. + + The reason is that `condecimal` returns a type, which doesn't play well with static analysis tools. + + === ":x: Don't do this" + ```py + from pydantic import BaseModel, condecimal + + class Foo(BaseModel): + bar: condecimal(strict=True, allow_inf_nan=True) + ``` + + === ":white_check_mark: Do this" + ```py + from decimal import Decimal + + from typing_extensions import Annotated + + from pydantic import BaseModel, Field + + class Foo(BaseModel): + bar: Annotated[Decimal, Field(strict=True, allow_inf_nan=True)] + ``` + + A wrapper around Decimal that adds validation. + + Args: + strict: Whether to validate the value in strict mode. Defaults to `None`. + gt: The value must be greater than this. Defaults to `None`. + ge: The value must be greater than or equal to this. Defaults to `None`. + lt: The value must be less than this. Defaults to `None`. + le: The value must be less than or equal to this. Defaults to `None`. + multiple_of: The value must be a multiple of this. Defaults to `None`. + max_digits: The maximum number of digits. Defaults to `None`. + decimal_places: The number of decimal places. Defaults to `None`. + allow_inf_nan: Whether to allow infinity and NaN. Defaults to `None`. 
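An equivalent of the constraints described above, written in the encouraged `Annotated` style with `Field` (the `Invoice` model is hypothetical):

```py
from decimal import Decimal

from typing_extensions import Annotated

from pydantic import BaseModel, Field

class Invoice(BaseModel):
    # roughly equivalent to condecimal(gt=0, max_digits=6, decimal_places=2)
    amount: Annotated[Decimal, Field(gt=0, max_digits=6, decimal_places=2)]

print(Invoice(amount='1234.56').amount)
#> 1234.56
```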
+ + ```py + from decimal import Decimal + + from pydantic import BaseModel, ValidationError, condecimal + + class ConstrainedExample(BaseModel): + constrained_decimal: condecimal(gt=Decimal('1.0')) + + m = ConstrainedExample(constrained_decimal=Decimal('1.1')) + print(repr(m)) + #> ConstrainedExample(constrained_decimal=Decimal('1.1')) + + try: + ConstrainedExample(constrained_decimal=Decimal('0.9')) + except ValidationError as e: + print(e.errors()) + ''' + [ + { + 'type': 'greater_than', + 'loc': ('constrained_decimal',), + 'msg': 'Input should be greater than 1.0', + 'input': Decimal('0.9'), + 'ctx': {'gt': Decimal('1.0')}, + 'url': 'https://errors.pydantic.dev/2/v/greater_than', + } + ] + ''' + ``` + """ # noqa: D212 + return Annotated[ + Decimal, + Strict(strict) if strict is not None else None, + annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le), + annotated_types.MultipleOf(multiple_of) if multiple_of is not None else None, + _fields.pydantic_general_metadata(max_digits=max_digits, decimal_places=decimal_places), + AllowInfNan(allow_inf_nan) if allow_inf_nan is not None else None, + ] # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ UUID TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -if TYPE_CHECKING: - UUID1 = UUID - UUID3 = UUID - UUID4 = UUID - UUID5 = UUID -else: - class UUID1(UUID): - _required_version = 1 +@_dataclasses.dataclass(**_internal_dataclass.slots_true) +class UuidVersion: + """A field metadata class to indicate a [UUID](https://docs.python.org/3/library/uuid.html) version.""" - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type='string', format=f'uuid{cls._required_version}') + uuid_version: Literal[1, 3, 4, 5] - class UUID3(UUID1): - _required_version = 3 + def __get_pydantic_json_schema__( + self, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + field_schema = handler(core_schema) + field_schema.pop('anyOf', None) # remove the bytes/str union + field_schema.update(type='string', format=f'uuid{self.uuid_version}') + return field_schema - class UUID4(UUID1): - _required_version = 4 + def __get_pydantic_core_schema__(self, source: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + if isinstance(self, source): + # used directly as a type + return core_schema.uuid_schema(version=self.uuid_version) + else: + # update existing schema with self.uuid_version + schema = handler(source) + _check_annotated_type(schema['type'], 'uuid', self.__class__.__name__) + schema['version'] = self.uuid_version # type: ignore + return schema - class UUID5(UUID1): - _required_version = 5 + def __hash__(self) -> int: + return hash(type(self.uuid_version)) + + +UUID1 = Annotated[UUID, UuidVersion(1)] +"""A [UUID](https://docs.python.org/3/library/uuid.html) that must be version 1. + +```py +import uuid + +from pydantic import UUID1, BaseModel + +class Model(BaseModel): + uuid1: UUID1 + +Model(uuid1=uuid.uuid1()) +``` +""" +UUID3 = Annotated[UUID, UuidVersion(3)] +"""A [UUID](https://docs.python.org/3/library/uuid.html) that must be version 3. + +```py +import uuid + +from pydantic import UUID3, BaseModel + +class Model(BaseModel): + uuid3: UUID3 + +Model(uuid3=uuid.uuid3(uuid.NAMESPACE_DNS, 'pydantic.org')) +``` +""" +UUID4 = Annotated[UUID, UuidVersion(4)] +"""A [UUID](https://docs.python.org/3/library/uuid.html) that must be version 4. 
+ +```py +import uuid + +from pydantic import UUID4, BaseModel + +class Model(BaseModel): + uuid4: UUID4 + +Model(uuid4=uuid.uuid4()) +``` +""" +UUID5 = Annotated[UUID, UuidVersion(5)] +"""A [UUID](https://docs.python.org/3/library/uuid.html) that must be version 5. + +```py +import uuid + +from pydantic import UUID5, BaseModel + +class Model(BaseModel): + uuid5: UUID5 + +Model(uuid5=uuid.uuid5(uuid.NAMESPACE_DNS, 'pydantic.org')) +``` +""" # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PATH TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -if TYPE_CHECKING: - FilePath = Path - DirectoryPath = Path -else: - class FilePath(Path): - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(format='file-path') +@_dataclasses.dataclass +class PathType: + path_type: Literal['file', 'dir', 'new'] - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield path_validator - yield path_exists_validator - yield cls.validate + def __get_pydantic_json_schema__( + self, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + field_schema = handler(core_schema) + format_conversion = {'file': 'file-path', 'dir': 'directory-path'} + field_schema.update(format=format_conversion.get(self.path_type, 'path'), type='string') + return field_schema - @classmethod - def validate(cls, value: Path) -> Path: - if not value.is_file(): - raise errors.PathNotAFileError(path=value) + def __get_pydantic_core_schema__(self, source: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + function_lookup = { + 'file': cast(core_schema.WithInfoValidatorFunction, self.validate_file), + 'dir': cast(core_schema.WithInfoValidatorFunction, self.validate_directory), + 'new': cast(core_schema.WithInfoValidatorFunction, self.validate_new), + } - return value + return core_schema.with_info_after_validator_function( + function_lookup[self.path_type], + handler(source), + ) - class DirectoryPath(Path): - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(format='directory-path') + @staticmethod + def validate_file(path: Path, _: core_schema.ValidationInfo) -> Path: + if path.is_file(): + return path + else: + raise PydanticCustomError('path_not_file', 'Path does not point to a file') - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield path_validator - yield path_exists_validator - yield cls.validate + @staticmethod + def validate_directory(path: Path, _: core_schema.ValidationInfo) -> Path: + if path.is_dir(): + return path + else: + raise PydanticCustomError('path_not_directory', 'Path does not point to a directory') - @classmethod - def validate(cls, value: Path) -> Path: - if not value.is_dir(): - raise errors.PathNotADirectoryError(path=value) + @staticmethod + def validate_new(path: Path, _: core_schema.ValidationInfo) -> Path: + if path.exists(): + raise PydanticCustomError('path_exists', 'Path already exists') + elif not path.parent.exists(): + raise PydanticCustomError('parent_does_not_exist', 'Parent directory does not exist') + else: + return path - return value + def __hash__(self) -> int: + return hash(type(self.path_type)) + + +FilePath = Annotated[Path, PathType('file')] +"""A path that must point to a file. 
+ +```py +from pathlib import Path + +from pydantic import BaseModel, FilePath, ValidationError + +class Model(BaseModel): + f: FilePath + +path = Path('text.txt') +path.touch() +m = Model(f='text.txt') +print(m.model_dump()) +#> {'f': PosixPath('text.txt')} +path.unlink() + +path = Path('directory') +path.mkdir(exist_ok=True) +try: + Model(f='directory') # directory +except ValidationError as e: + print(e) + ''' + 1 validation error for Model + f + Path does not point to a file [type=path_not_file, input_value='directory', input_type=str] + ''' +path.rmdir() + +try: + Model(f='not-exists-file') +except ValidationError as e: + print(e) + ''' + 1 validation error for Model + f + Path does not point to a file [type=path_not_file, input_value='not-exists-file', input_type=str] + ''' +``` +""" +DirectoryPath = Annotated[Path, PathType('dir')] +"""A path that must point to a directory. + +```py +from pathlib import Path + +from pydantic import BaseModel, DirectoryPath, ValidationError + +class Model(BaseModel): + f: DirectoryPath + +path = Path('directory/') +path.mkdir() +m = Model(f='directory/') +print(m.model_dump()) +#> {'f': PosixPath('directory')} +path.rmdir() + +path = Path('file.txt') +path.touch() +try: + Model(f='file.txt') # file +except ValidationError as e: + print(e) + ''' + 1 validation error for Model + f + Path does not point to a directory [type=path_not_directory, input_value='file.txt', input_type=str] + ''' +path.unlink() + +try: + Model(f='not-exists-directory') +except ValidationError as e: + print(e) + ''' + 1 validation error for Model + f + Path does not point to a directory [type=path_not_directory, input_value='not-exists-directory', input_type=str] + ''' +``` +""" +NewPath = Annotated[Path, PathType('new')] +"""A path for a new file or directory that must not already exist.""" # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ JSON TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -class JsonWrapper: - pass - - -class JsonMeta(type): - def __getitem__(self, t: Type[Any]) -> Type[JsonWrapper]: - if t is Any: - return Json # allow Json[Any] to replecate plain Json - return _registered(type('JsonWrapperValue', (JsonWrapper,), {'inner_type': t})) - - if TYPE_CHECKING: - Json = Annotated[T, ...] # Json[list[str]] will be recognized by type checkers as list[str] + Json = Annotated[AnyType, ...] # Json[list[str]] will be recognized by type checkers as list[str] else: - class Json(metaclass=JsonMeta): + class Json: + """A special type wrapper which loads JSON before parsing. 
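Unlike `FilePath` and `DirectoryPath`, the `NewPath` alias above carries no example; a minimal sketch, assuming `fresh-report.csv` does not already exist in the working directory, could be:

```py
from pydantic import BaseModel, NewPath, ValidationError

class Export(BaseModel):
    target: NewPath

m = Export(target='fresh-report.csv')  # assumed not to exist yet
print(m.target.name)
#> fresh-report.csv

try:
    Export(target='.')  # the current directory always exists
except ValidationError as e:
    print(e.errors()[0]['type'])
    #> path_exists
```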
+ + You can use the `Json` data type to make Pydantic first load a raw JSON string before + validating the loaded data into the parametrized type: + + ```py + from typing import Any, List + + from pydantic import BaseModel, Json, ValidationError + + + class AnyJsonModel(BaseModel): + json_obj: Json[Any] + + + class ConstrainedJsonModel(BaseModel): + json_obj: Json[List[int]] + + + print(AnyJsonModel(json_obj='{"b": 1}')) + #> json_obj={'b': 1} + print(ConstrainedJsonModel(json_obj='[1, 2, 3]')) + #> json_obj=[1, 2, 3] + + try: + ConstrainedJsonModel(json_obj=12) + except ValidationError as e: + print(e) + ''' + 1 validation error for ConstrainedJsonModel + json_obj + JSON input should be string, bytes or bytearray [type=json_type, input_value=12, input_type=int] + ''' + + try: + ConstrainedJsonModel(json_obj='[a, b]') + except ValidationError as e: + print(e) + ''' + 1 validation error for ConstrainedJsonModel + json_obj + Invalid JSON: expected value at line 1 column 2 [type=json_invalid, input_value='[a, b]', input_type=str] + ''' + + try: + ConstrainedJsonModel(json_obj='["a", "b"]') + except ValidationError as e: + print(e) + ''' + 2 validation errors for ConstrainedJsonModel + json_obj.0 + Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='a', input_type=str] + json_obj.1 + Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='b', input_type=str] + ''' + ``` + + When you dump the model using `model_dump` or `model_dump_json`, the dumped value will be the result of validation, + not the original JSON string. However, you can use the argument `round_trip=True` to get the original JSON string back: + + ```py + from typing import List + + from pydantic import BaseModel, Json + + + class ConstrainedJsonModel(BaseModel): + json_obj: Json[List[int]] + + + print(ConstrainedJsonModel(json_obj='[1, 2, 3]').model_dump_json()) + #> {"json_obj":[1,2,3]} + print( + ConstrainedJsonModel(json_obj='[1, 2, 3]').model_dump_json(round_trip=True) + ) + #> {"json_obj":"[1,2,3]"} + ``` + """ + @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - field_schema.update(type='string', format='json-string') + def __class_getitem__(cls, item: AnyType) -> AnyType: + return Annotated[item, cls()] + + @classmethod + def __get_pydantic_core_schema__(cls, source: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + if cls is source: + return core_schema.json_schema(None) + else: + return core_schema.json_schema(handler(source)) + + def __repr__(self) -> str: + return 'Json' + + def __hash__(self) -> int: + return hash(type(self)) + + def __eq__(self, other: Any) -> bool: + return type(other) == type(self) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SECRET TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +SecretType = TypeVar('SecretType', str, bytes) -class SecretField(abc.ABC): - """ - Note: this should be implemented as a generic like `SecretField(ABC, Generic[T])`, - the `__init__()` should be part of the abstract class and the - `get_secret_value()` method should use the generic `T` type. - However Cython doesn't support very well generics at the moment and - the generated code fails to be imported (see - https://github.com/cython/cython/issues/2753). - """ +class _SecretField(Generic[SecretType]): + def __init__(self, secret_value: SecretType) -> None: + self._secret_value: SecretType = secret_value + + def get_secret_value(self) -> SecretType: + """Get the secret value. 
+ + Returns: + The secret value. + """ + return self._secret_value def __eq__(self, other: Any) -> bool: return isinstance(other, self.__class__) and self.get_secret_value() == other.get_secret_value() - def __str__(self) -> str: - return '**********' if self.get_secret_value() else '' - def __hash__(self) -> int: return hash(self.get_secret_value()) - @abc.abstractmethod - def get_secret_value(self) -> Any: # pragma: no cover - ... - - -class SecretStr(SecretField): - min_length: OptionalInt = None - max_length: OptionalInt = None - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none( - field_schema, - type='string', - writeOnly=True, - format='password', - minLength=cls.min_length, - maxLength=cls.max_length, - ) - - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield cls.validate - yield constr_length_validator - - @classmethod - def validate(cls, value: Any) -> 'SecretStr': - if isinstance(value, cls): - return value - value = str_validator(value) - return cls(value) - - def __init__(self, value: str): - self._secret_value = value - - def __repr__(self) -> str: - return f"SecretStr('{self}')" - def __len__(self) -> int: return len(self._secret_value) - def display(self) -> str: - warnings.warn('`secret_str.display()` is deprecated, use `str(secret_str)` instead', DeprecationWarning) - return str(self) - - def get_secret_value(self) -> str: - return self._secret_value - - -class SecretBytes(SecretField): - min_length: OptionalInt = None - max_length: OptionalInt = None - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none( - field_schema, - type='string', - writeOnly=True, - format='password', - minLength=cls.min_length, - maxLength=cls.max_length, - ) - - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield cls.validate - yield constr_length_validator - - @classmethod - def validate(cls, value: Any) -> 'SecretBytes': - if isinstance(value, cls): - return value - value = bytes_validator(value) - return cls(value) - - def __init__(self, value: bytes): - self._secret_value = value + def __str__(self) -> str: + return str(self._display()) def __repr__(self) -> str: - return f"SecretBytes(b'{self}')" + return f'{self.__class__.__name__}({self._display()!r})' - def __len__(self) -> int: - return len(self._secret_value) + def _display(self) -> SecretType: + raise NotImplementedError - def display(self) -> str: - warnings.warn('`secret_bytes.display()` is deprecated, use `str(secret_bytes)` instead', DeprecationWarning) - return str(self) + @classmethod + def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + if issubclass(source, SecretStr): + field_type = str + inner_schema = core_schema.str_schema() + else: + assert issubclass(source, SecretBytes) + field_type = bytes + inner_schema = core_schema.bytes_schema() + error_kind = 'string_type' if field_type is str else 'bytes_type' - def get_secret_value(self) -> bytes: - return self._secret_value + def serialize( + value: _SecretField[SecretType], info: core_schema.SerializationInfo + ) -> str | _SecretField[SecretType]: + if info.mode == 'json': + # we want the output to always be string without the `b'` prefix for bytes, + # hence we just use `secret_display` + return _secret_display(value.get_secret_value()) + else: + return value + + def get_json_schema(_core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler) -> 
JsonSchemaValue: + json_schema = handler(inner_schema) + _utils.update_not_none( + json_schema, + type='string', + writeOnly=True, + format='password', + ) + return json_schema + + json_schema = core_schema.no_info_after_validator_function( + source, # construct the type + inner_schema, + ) + s = core_schema.json_or_python_schema( + python_schema=core_schema.union_schema( + [ + core_schema.is_instance_schema(source), + json_schema, + ], + strict=True, + custom_error_type=error_kind, + ), + json_schema=json_schema, + serialization=core_schema.plain_serializer_function_ser_schema( + serialize, + info_arg=True, + return_schema=core_schema.str_schema(), + when_used='json', + ), + ) + s.setdefault('metadata', {}).setdefault('pydantic_js_functions', []).append(get_json_schema) + return s + + +def _secret_display(value: str | bytes) -> str: + return '**********' if value else '' + + +class SecretStr(_SecretField[str]): + """A string used for storing sensitive information that you do not want to be visible in logging or tracebacks. + + When the secret value is nonempty, it is displayed as `'**********'` instead of the underlying value in + calls to `repr()` and `str()`. If the value _is_ empty, it is displayed as `''`. + + ```py + from pydantic import BaseModel, SecretStr + + class User(BaseModel): + username: str + password: SecretStr + + user = User(username='scolvin', password='password1') + + print(user) + #> username='scolvin' password=SecretStr('**********') + print(user.password.get_secret_value()) + #> password1 + print((SecretStr('password'), SecretStr(''))) + #> (SecretStr('**********'), SecretStr('')) + ``` + """ + + def _display(self) -> str: + return _secret_display(self.get_secret_value()) + + +class SecretBytes(_SecretField[bytes]): + """A bytes used for storing sensitive information that you do not want to be visible in logging or tracebacks. + + It displays `b'**********'` instead of the string value on `repr()` and `str()` calls. + When the secret value is nonempty, it is displayed as `b'**********'` instead of the underlying value in + calls to `repr()` and `str()`. If the value _is_ empty, it is displayed as `b''`. + + ```py + from pydantic import BaseModel, SecretBytes + + class User(BaseModel): + username: str + password: SecretBytes + + user = User(username='scolvin', password=b'password1') + #> username='scolvin' password=SecretBytes(b'**********') + print(user.password.get_secret_value()) + #> b'password1' + print((SecretBytes(b'password'), SecretBytes(b''))) + #> (SecretBytes(b'**********'), SecretBytes(b'')) + ``` + """ + + def _display(self) -> bytes: + return _secret_display(self.get_secret_value()).encode() # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PAYMENT CARD TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ class PaymentCardBrand(str, Enum): - # If you add another card type, please also add it to the - # Hypothesis strategy in `pydantic._hypothesis_plugin`. amex = 'American Express' mastercard = 'Mastercard' visa = 'Visa' @@ -954,10 +1600,13 @@ class PaymentCardBrand(str, Enum): return self.value +@deprecated( + 'The `PaymentCardNumber` class is deprecated, use `pydantic_extra_types` instead. 
' + 'See https://docs.pydantic.dev/latest/api/pydantic_extra_types_payment/#pydantic_extra_types.payment.PaymentCardNumber.', + category=PydanticDeprecatedSince20, +) class PaymentCardNumber(str): - """ - Based on: https://en.wikipedia.org/wiki/Payment_card_number - """ + """Based on: https://en.wikipedia.org/wiki/Payment_card_number.""" strip_whitespace: ClassVar[bool] = True min_length: ClassVar[int] = 12 @@ -967,36 +1616,47 @@ class PaymentCardNumber(str): brand: PaymentCardBrand def __init__(self, card_number: str): + self.validate_digits(card_number) + + card_number = self.validate_luhn_check_digit(card_number) + self.bin = card_number[:6] self.last4 = card_number[-4:] - self.brand = self._get_brand(card_number) + self.brand = self.validate_brand(card_number) @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield str_validator - yield constr_strip_whitespace - yield constr_length_validator - yield cls.validate_digits - yield cls.validate_luhn_check_digit - yield cls - yield cls.validate_length_for_brand + def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + return core_schema.with_info_after_validator_function( + cls.validate, + core_schema.str_schema( + min_length=cls.min_length, max_length=cls.max_length, strip_whitespace=cls.strip_whitespace + ), + ) + + @classmethod + def validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> PaymentCardNumber: + """Validate the card number and return a `PaymentCardNumber` instance.""" + return cls(__input_value) @property def masked(self) -> str: + """Mask all but the last 4 digits of the card number. + + Returns: + A masked card number string. + """ num_masked = len(self) - 10 # len(bin) + len(last4) == 10 return f'{self.bin}{"*" * num_masked}{self.last4}' @classmethod - def validate_digits(cls, card_number: str) -> str: + def validate_digits(cls, card_number: str) -> None: + """Validate that the card number is all digits.""" if not card_number.isdigit(): - raise errors.NotDigitError - return card_number + raise PydanticCustomError('payment_card_number_digits', 'Card number is not all digits') @classmethod def validate_luhn_check_digit(cls, card_number: str) -> str: - """ - Based on: https://en.wikipedia.org/wiki/Luhn_algorithm - """ + """Based on: https://en.wikipedia.org/wiki/Luhn_algorithm.""" sum_ = int(card_number[-1]) length = len(card_number) parity = length % 2 @@ -1009,33 +1669,14 @@ class PaymentCardNumber(str): sum_ += digit valid = sum_ % 10 == 0 if not valid: - raise errors.LuhnValidationError - return card_number - - @classmethod - def validate_length_for_brand(cls, card_number: 'PaymentCardNumber') -> 'PaymentCardNumber': - """ - Validate length based on BIN for major brands: - https://en.wikipedia.org/wiki/Payment_card_number#Issuer_identification_number_(IIN) - """ - required_length: Union[None, int, str] = None - if card_number.brand in PaymentCardBrand.mastercard: - required_length = 16 - valid = len(card_number) == required_length - elif card_number.brand == PaymentCardBrand.visa: - required_length = '13, 16 or 19' - valid = len(card_number) in {13, 16, 19} - elif card_number.brand == PaymentCardBrand.amex: - required_length = 15 - valid = len(card_number) == required_length - else: - valid = True - if not valid: - raise errors.InvalidLengthForBrand(brand=card_number.brand, required_length=required_length) + raise PydanticCustomError('payment_card_number_luhn', 'Card number is not luhn valid') return card_number @staticmethod - 
def _get_brand(card_number: str) -> PaymentCardBrand: + def validate_brand(card_number: str) -> PaymentCardBrand: + """Validate length based on BIN for major brands: + https://en.wikipedia.org/wiki/Payment_card_number#Issuer_identification_number_(IIN). + """ if card_number[0] == '4': brand = PaymentCardBrand.visa elif 51 <= int(card_number[:2]) <= 55: @@ -1044,144 +1685,1195 @@ class PaymentCardNumber(str): brand = PaymentCardBrand.amex else: brand = PaymentCardBrand.other + + required_length: None | int | str = None + if brand in PaymentCardBrand.mastercard: + required_length = 16 + valid = len(card_number) == required_length + elif brand == PaymentCardBrand.visa: + required_length = '13, 16 or 19' + valid = len(card_number) in {13, 16, 19} + elif brand == PaymentCardBrand.amex: + required_length = 15 + valid = len(card_number) == required_length + else: + valid = True + + if not valid: + raise PydanticCustomError( + 'payment_card_number_brand', + 'Length for a {brand} card must be {required_length}', + {'brand': brand, 'required_length': required_length}, + ) return brand # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BYTE SIZE TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -BYTE_SIZES = { - 'b': 1, - 'kb': 10**3, - 'mb': 10**6, - 'gb': 10**9, - 'tb': 10**12, - 'pb': 10**15, - 'eb': 10**18, - 'kib': 2**10, - 'mib': 2**20, - 'gib': 2**30, - 'tib': 2**40, - 'pib': 2**50, - 'eib': 2**60, -} -BYTE_SIZES.update({k.lower()[0]: v for k, v in BYTE_SIZES.items() if 'i' not in k}) -byte_string_re = re.compile(r'^\s*(\d*\.?\d+)\s*(\w+)?', re.IGNORECASE) - class ByteSize(int): - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield cls.validate + """Converts a string representing a number of bytes with units (such as `'1KB'` or `'11.5MiB'`) into an integer. + + You can use the `ByteSize` data type to (case-insensitively) convert a string representation of a number of bytes into + an integer, and also to print out human-readable strings representing a number of bytes. + + In conformance with [IEC 80000-13 Standard](https://en.wikipedia.org/wiki/ISO/IEC_80000) we interpret `'1KB'` to mean 1000 bytes, + and `'1KiB'` to mean 1024 bytes. In general, including a middle `'i'` will cause the unit to be interpreted as a power of 2, + rather than a power of 10 (so, for example, `'1 MB'` is treated as `1_000_000` bytes, whereas `'1 MiB'` is treated as `1_048_576` bytes). + + !!! info + Note that `1b` will be parsed as "1 byte" and not "1 bit". + + ```py + from pydantic import BaseModel, ByteSize + + class MyModel(BaseModel): + size: ByteSize + + print(MyModel(size=52000).size) + #> 52000 + print(MyModel(size='3000 KiB').size) + #> 3072000 + + m = MyModel(size='50 PB') + print(m.size.human_readable()) + #> 44.4PiB + print(m.size.human_readable(decimal=True)) + #> 50.0PB + + print(m.size.to('TiB')) + #> 45474.73508864641 + ``` + """ + + byte_sizes = { + 'b': 1, + 'kb': 10**3, + 'mb': 10**6, + 'gb': 10**9, + 'tb': 10**12, + 'pb': 10**15, + 'eb': 10**18, + 'kib': 2**10, + 'mib': 2**20, + 'gib': 2**30, + 'tib': 2**40, + 'pib': 2**50, + 'eib': 2**60, + 'bit': 1 / 8, + 'kbit': 10**3 / 8, + 'mbit': 10**6 / 8, + 'gbit': 10**9 / 8, + 'tbit': 10**12 / 8, + 'pbit': 10**15 / 8, + 'ebit': 10**18 / 8, + 'kibit': 2**10 / 8, + 'mibit': 2**20 / 8, + 'gibit': 2**30 / 8, + 'tibit': 2**40 / 8, + 'pibit': 2**50 / 8, + 'eibit': 2**60 / 8, + } + byte_sizes.update({k.lower()[0]: v for k, v in byte_sizes.items() if 'i' not in k}) + + byte_string_pattern = r'^\s*(\d*\.?\d+)\s*(\w+)?' 
+ byte_string_re = re.compile(byte_string_pattern, re.IGNORECASE) @classmethod - def validate(cls, v: StrIntFloat) -> 'ByteSize': + def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + return core_schema.with_info_after_validator_function( + function=cls._validate, + schema=core_schema.union_schema( + [ + core_schema.str_schema(pattern=cls.byte_string_pattern), + core_schema.int_schema(ge=0), + ], + custom_error_type='byte_size', + custom_error_message='could not parse value and unit from byte string', + ), + serialization=core_schema.plain_serializer_function_ser_schema( + int, return_schema=core_schema.int_schema(ge=0) + ), + ) + @classmethod + def _validate(cls, __input_value: Any, _: core_schema.ValidationInfo) -> ByteSize: try: - return cls(int(v)) + return cls(int(__input_value)) except ValueError: pass - str_match = byte_string_re.match(str(v)) + str_match = cls.byte_string_re.match(str(__input_value)) if str_match is None: - raise errors.InvalidByteSize() + raise PydanticCustomError('byte_size', 'could not parse value and unit from byte string') scalar, unit = str_match.groups() if unit is None: unit = 'b' try: - unit_mult = BYTE_SIZES[unit.lower()] + unit_mult = cls.byte_sizes[unit.lower()] except KeyError: - raise errors.InvalidByteSizeUnit(unit=unit) + raise PydanticCustomError('byte_size_unit', 'could not interpret byte unit: {unit}', {'unit': unit}) return cls(int(float(scalar) * unit_mult)) def human_readable(self, decimal: bool = False) -> str: + """Converts a byte size to a human readable string. + Args: + decimal: If True, use decimal units (e.g. 1000 bytes per KB). If False, use binary units + (e.g. 1024 bytes per KiB). + + Returns: + A human readable string representation of the byte size. + """ if decimal: divisor = 1000 - units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB'] + units = 'B', 'KB', 'MB', 'GB', 'TB', 'PB' final_unit = 'EB' else: divisor = 1024 - units = ['B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB'] + units = 'B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB' final_unit = 'EiB' num = float(self) for unit in units: if abs(num) < divisor: - return f'{num:0.1f}{unit}' + if unit == 'B': + return f'{num:0.0f}{unit}' + else: + return f'{num:0.1f}{unit}' num /= divisor return f'{num:0.1f}{final_unit}' def to(self, unit: str) -> float: + """Converts a byte size to another unit, including both byte and bit units. + Args: + unit: The unit to convert to. Must be one of the following: B, KB, MB, GB, TB, PB, EB, + KiB, MiB, GiB, TiB, PiB, EiB (byte units) and + bit, kbit, mbit, gbit, tbit, pbit, ebit, + kibit, mibit, gibit, tibit, pibit, eibit (bit units). + + Returns: + The byte size in the new unit. + """ try: - unit_div = BYTE_SIZES[unit.lower()] + unit_div = self.byte_sizes[unit.lower()] except KeyError: - raise errors.InvalidByteSizeUnit(unit=unit) + raise PydanticCustomError('byte_size_unit', 'Could not interpret byte unit: {unit}', {'unit': unit}) return self / unit_div # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DATE TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +def _check_annotated_type(annotated_type: str, expected_type: str, annotation: str) -> None: + if annotated_type != expected_type: + raise PydanticUserError(f"'{annotation}' cannot annotate '{annotated_type}'.", code='invalid_annotated_type') + + if TYPE_CHECKING: - PastDate = date - FutureDate = date + PastDate = Annotated[date, ...] + FutureDate = Annotated[date, ...] 
else: - class PastDate(date): - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield parse_date - yield cls.validate + class PastDate: + """A date in the past.""" @classmethod - def validate(cls, value: date) -> date: - if value >= date.today(): - raise errors.DateNotInThePastError() + def __get_pydantic_core_schema__( + cls, source: type[Any], handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + if cls is source: + # used directly as a type + return core_schema.date_schema(now_op='past') + else: + schema = handler(source) + _check_annotated_type(schema['type'], 'date', cls.__name__) + schema['now_op'] = 'past' + return schema - return value + def __repr__(self) -> str: + return 'PastDate' - class FutureDate(date): - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield parse_date - yield cls.validate + class FutureDate: + """A date in the future.""" @classmethod - def validate(cls, value: date) -> date: - if value <= date.today(): - raise errors.DateNotInTheFutureError() + def __get_pydantic_core_schema__( + cls, source: type[Any], handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + if cls is source: + # used directly as a type + return core_schema.date_schema(now_op='future') + else: + schema = handler(source) + _check_annotated_type(schema['type'], 'date', cls.__name__) + schema['now_op'] = 'future' + return schema - return value - - -class ConstrainedDate(date, metaclass=ConstrainedNumberMeta): - gt: OptionalDate = None - ge: OptionalDate = None - lt: OptionalDate = None - le: OptionalDate = None - - @classmethod - def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: - update_not_none(field_schema, exclusiveMinimum=cls.gt, exclusiveMaximum=cls.lt, minimum=cls.ge, maximum=cls.le) - - @classmethod - def __get_validators__(cls) -> 'CallableGenerator': - yield parse_date - yield number_size_validator + def __repr__(self) -> str: + return 'FutureDate' def condate( *, - gt: date = None, - ge: date = None, - lt: date = None, - le: date = None, -) -> Type[date]: - # use kwargs then define conf in a dict to aid with IDE type hinting - namespace = dict(gt=gt, ge=ge, lt=lt, le=le) - return type('ConstrainedDateValue', (ConstrainedDate,), namespace) + strict: bool | None = None, + gt: date | None = None, + ge: date | None = None, + lt: date | None = None, + le: date | None = None, +) -> type[date]: + """A wrapper for date that adds constraints. + + Args: + strict: Whether to validate the date value in strict mode. Defaults to `None`. + gt: The value must be greater than this. Defaults to `None`. + ge: The value must be greater than or equal to this. Defaults to `None`. + lt: The value must be less than this. Defaults to `None`. + le: The value must be less than or equal to this. Defaults to `None`. + + Returns: + A date type with the specified constraints. + """ + return Annotated[ + date, + Strict(strict) if strict is not None else None, + annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le), + ] + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DATETIME TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +if TYPE_CHECKING: + AwareDatetime = Annotated[datetime, ...] + NaiveDatetime = Annotated[datetime, ...] + PastDatetime = Annotated[datetime, ...] + FutureDatetime = Annotated[datetime, ...] 
+ +else: + + class AwareDatetime: + """A datetime that requires timezone info.""" + + @classmethod + def __get_pydantic_core_schema__( + cls, source: type[Any], handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + if cls is source: + # used directly as a type + return core_schema.datetime_schema(tz_constraint='aware') + else: + schema = handler(source) + _check_annotated_type(schema['type'], 'datetime', cls.__name__) + schema['tz_constraint'] = 'aware' + return schema + + def __repr__(self) -> str: + return 'AwareDatetime' + + class NaiveDatetime: + """A datetime that doesn't require timezone info.""" + + @classmethod + def __get_pydantic_core_schema__( + cls, source: type[Any], handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + if cls is source: + # used directly as a type + return core_schema.datetime_schema(tz_constraint='naive') + else: + schema = handler(source) + _check_annotated_type(schema['type'], 'datetime', cls.__name__) + schema['tz_constraint'] = 'naive' + return schema + + def __repr__(self) -> str: + return 'NaiveDatetime' + + class PastDatetime: + """A datetime that must be in the past.""" + + @classmethod + def __get_pydantic_core_schema__( + cls, source: type[Any], handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + if cls is source: + # used directly as a type + return core_schema.datetime_schema(now_op='past') + else: + schema = handler(source) + _check_annotated_type(schema['type'], 'datetime', cls.__name__) + schema['now_op'] = 'past' + return schema + + def __repr__(self) -> str: + return 'PastDatetime' + + class FutureDatetime: + """A datetime that must be in the future.""" + + @classmethod + def __get_pydantic_core_schema__( + cls, source: type[Any], handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + if cls is source: + # used directly as a type + return core_schema.datetime_schema(now_op='future') + else: + schema = handler(source) + _check_annotated_type(schema['type'], 'datetime', cls.__name__) + schema['now_op'] = 'future' + return schema + + def __repr__(self) -> str: + return 'FutureDatetime' + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Encoded TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +class EncoderProtocol(Protocol): + """Protocol for encoding and decoding data to and from bytes.""" + + @classmethod + def decode(cls, data: bytes) -> bytes: + """Decode the data using the encoder. + + Args: + data: The data to decode. + + Returns: + The decoded data. + """ + ... + + @classmethod + def encode(cls, value: bytes) -> bytes: + """Encode the data using the encoder. + + Args: + value: The data to encode. + + Returns: + The encoded data. + """ + ... + + @classmethod + def get_json_format(cls) -> str: + """Get the JSON format for the encoded data. + + Returns: + The JSON format for the encoded data. + """ + ... + + +class Base64Encoder(EncoderProtocol): + """Standard (non-URL-safe) Base64 encoder.""" + + @classmethod + def decode(cls, data: bytes) -> bytes: + """Decode the data from base64 encoded bytes to original bytes data. + + Args: + data: The data to decode. + + Returns: + The decoded data. + """ + try: + return base64.decodebytes(data) + except ValueError as e: + raise PydanticCustomError('base64_decode', "Base64 decoding error: '{error}'", {'error': str(e)}) + + @classmethod + def encode(cls, value: bytes) -> bytes: + """Encode the data from bytes to a base64 encoded bytes. + + Args: + value: The data to encode. + + Returns: + The encoded data. 
+ """ + return base64.encodebytes(value) + + @classmethod + def get_json_format(cls) -> Literal['base64']: + """Get the JSON format for the encoded data. + + Returns: + The JSON format for the encoded data. + """ + return 'base64' + + +class Base64UrlEncoder(EncoderProtocol): + """URL-safe Base64 encoder.""" + + @classmethod + def decode(cls, data: bytes) -> bytes: + """Decode the data from base64 encoded bytes to original bytes data. + + Args: + data: The data to decode. + + Returns: + The decoded data. + """ + try: + return base64.urlsafe_b64decode(data) + except ValueError as e: + raise PydanticCustomError('base64_decode', "Base64 decoding error: '{error}'", {'error': str(e)}) + + @classmethod + def encode(cls, value: bytes) -> bytes: + """Encode the data from bytes to a base64 encoded bytes. + + Args: + value: The data to encode. + + Returns: + The encoded data. + """ + return base64.urlsafe_b64encode(value) + + @classmethod + def get_json_format(cls) -> Literal['base64url']: + """Get the JSON format for the encoded data. + + Returns: + The JSON format for the encoded data. + """ + return 'base64url' + + +@_dataclasses.dataclass(**_internal_dataclass.slots_true) +class EncodedBytes: + """A bytes type that is encoded and decoded using the specified encoder. + + `EncodedBytes` needs an encoder that implements `EncoderProtocol` to operate. + + ```py + from typing_extensions import Annotated + + from pydantic import BaseModel, EncodedBytes, EncoderProtocol, ValidationError + + class MyEncoder(EncoderProtocol): + @classmethod + def decode(cls, data: bytes) -> bytes: + if data == b'**undecodable**': + raise ValueError('Cannot decode data') + return data[13:] + + @classmethod + def encode(cls, value: bytes) -> bytes: + return b'**encoded**: ' + value + + @classmethod + def get_json_format(cls) -> str: + return 'my-encoder' + + MyEncodedBytes = Annotated[bytes, EncodedBytes(encoder=MyEncoder)] + + class Model(BaseModel): + my_encoded_bytes: MyEncodedBytes + + # Initialize the model with encoded data + m = Model(my_encoded_bytes=b'**encoded**: some bytes') + + # Access decoded value + print(m.my_encoded_bytes) + #> b'some bytes' + + # Serialize into the encoded form + print(m.model_dump()) + #> {'my_encoded_bytes': b'**encoded**: some bytes'} + + # Validate encoded data + try: + Model(my_encoded_bytes=b'**undecodable**') + except ValidationError as e: + print(e) + ''' + 1 validation error for Model + my_encoded_bytes + Value error, Cannot decode data [type=value_error, input_value=b'**undecodable**', input_type=bytes] + ''' + ``` + """ + + encoder: type[EncoderProtocol] + + def __get_pydantic_json_schema__( + self, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + field_schema = handler(core_schema) + field_schema.update(type='string', format=self.encoder.get_json_format()) + return field_schema + + def __get_pydantic_core_schema__(self, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + return core_schema.with_info_after_validator_function( + function=self.decode, + schema=core_schema.bytes_schema(), + serialization=core_schema.plain_serializer_function_ser_schema(function=self.encode), + ) + + def decode(self, data: bytes, _: core_schema.ValidationInfo) -> bytes: + """Decode the data using the specified encoder. + + Args: + data: The data to decode. + + Returns: + The decoded data. + """ + return self.encoder.decode(data) + + def encode(self, value: bytes) -> bytes: + """Encode the data using the specified encoder. 
+ + Args: + value: The data to encode. + + Returns: + The encoded data. + """ + return self.encoder.encode(value) + + def __hash__(self) -> int: + return hash(self.encoder) + + +@_dataclasses.dataclass(**_internal_dataclass.slots_true) +class EncodedStr(EncodedBytes): + """A str type that is encoded and decoded using the specified encoder. + + `EncodedStr` needs an encoder that implements `EncoderProtocol` to operate. + + ```py + from typing_extensions import Annotated + + from pydantic import BaseModel, EncodedStr, EncoderProtocol, ValidationError + + class MyEncoder(EncoderProtocol): + @classmethod + def decode(cls, data: bytes) -> bytes: + if data == b'**undecodable**': + raise ValueError('Cannot decode data') + return data[13:] + + @classmethod + def encode(cls, value: bytes) -> bytes: + return b'**encoded**: ' + value + + @classmethod + def get_json_format(cls) -> str: + return 'my-encoder' + + MyEncodedStr = Annotated[str, EncodedStr(encoder=MyEncoder)] + + class Model(BaseModel): + my_encoded_str: MyEncodedStr + + # Initialize the model with encoded data + m = Model(my_encoded_str='**encoded**: some str') + + # Access decoded value + print(m.my_encoded_str) + #> some str + + # Serialize into the encoded form + print(m.model_dump()) + #> {'my_encoded_str': '**encoded**: some str'} + + # Validate encoded data + try: + Model(my_encoded_str='**undecodable**') + except ValidationError as e: + print(e) + ''' + 1 validation error for Model + my_encoded_str + Value error, Cannot decode data [type=value_error, input_value='**undecodable**', input_type=str] + ''' + ``` + """ + + def __get_pydantic_core_schema__(self, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + return core_schema.with_info_after_validator_function( + function=self.decode_str, + schema=super(EncodedStr, self).__get_pydantic_core_schema__(source=source, handler=handler), # noqa: UP008 + serialization=core_schema.plain_serializer_function_ser_schema(function=self.encode_str), + ) + + def decode_str(self, data: bytes, _: core_schema.ValidationInfo) -> str: + """Decode the data using the specified encoder. + + Args: + data: The data to decode. + + Returns: + The decoded data. + """ + return data.decode() + + def encode_str(self, value: str) -> str: + """Encode the data using the specified encoder. + + Args: + value: The data to encode. + + Returns: + The encoded data. + """ + return super(EncodedStr, self).encode(value=value.encode()).decode() # noqa: UP008 + + def __hash__(self) -> int: + return hash(self.encoder) + + +Base64Bytes = Annotated[bytes, EncodedBytes(encoder=Base64Encoder)] +"""A bytes type that is encoded and decoded using the standard (non-URL-safe) base64 encoder. + +Note: + Under the hood, `Base64Bytes` use standard library `base64.encodebytes` and `base64.decodebytes` functions. + + As a result, attempting to decode url-safe base64 data using the `Base64Bytes` type may fail or produce an incorrect + decoding. 
+ +```py +from pydantic import Base64Bytes, BaseModel, ValidationError + +class Model(BaseModel): + base64_bytes: Base64Bytes + +# Initialize the model with base64 data +m = Model(base64_bytes=b'VGhpcyBpcyB0aGUgd2F5') + +# Access decoded value +print(m.base64_bytes) +#> b'This is the way' + +# Serialize into the base64 form +print(m.model_dump()) +#> {'base64_bytes': b'VGhpcyBpcyB0aGUgd2F5\n'} + +# Validate base64 data +try: + print(Model(base64_bytes=b'undecodable').base64_bytes) +except ValidationError as e: + print(e) + ''' + 1 validation error for Model + base64_bytes + Base64 decoding error: 'Incorrect padding' [type=base64_decode, input_value=b'undecodable', input_type=bytes] + ''' +``` +""" +Base64Str = Annotated[str, EncodedStr(encoder=Base64Encoder)] +"""A str type that is encoded and decoded using the standard (non-URL-safe) base64 encoder. + +Note: + Under the hood, `Base64Bytes` use standard library `base64.encodebytes` and `base64.decodebytes` functions. + + As a result, attempting to decode url-safe base64 data using the `Base64Str` type may fail or produce an incorrect + decoding. + +```py +from pydantic import Base64Str, BaseModel, ValidationError + +class Model(BaseModel): + base64_str: Base64Str + +# Initialize the model with base64 data +m = Model(base64_str='VGhlc2UgYXJlbid0IHRoZSBkcm9pZHMgeW91J3JlIGxvb2tpbmcgZm9y') + +# Access decoded value +print(m.base64_str) +#> These aren't the droids you're looking for + +# Serialize into the base64 form +print(m.model_dump()) +#> {'base64_str': 'VGhlc2UgYXJlbid0IHRoZSBkcm9pZHMgeW91J3JlIGxvb2tpbmcgZm9y\n'} + +# Validate base64 data +try: + print(Model(base64_str='undecodable').base64_str) +except ValidationError as e: + print(e) + ''' + 1 validation error for Model + base64_str + Base64 decoding error: 'Incorrect padding' [type=base64_decode, input_value='undecodable', input_type=str] + ''' +``` +""" +Base64UrlBytes = Annotated[bytes, EncodedBytes(encoder=Base64UrlEncoder)] +"""A bytes type that is encoded and decoded using the URL-safe base64 encoder. + +Note: + Under the hood, `Base64UrlBytes` use standard library `base64.urlsafe_b64encode` and `base64.urlsafe_b64decode` + functions. + + As a result, the `Base64UrlBytes` type can be used to faithfully decode "vanilla" base64 data + (using `'+'` and `'/'`). + +```py +from pydantic import Base64UrlBytes, BaseModel + +class Model(BaseModel): + base64url_bytes: Base64UrlBytes + +# Initialize the model with base64 data +m = Model(base64url_bytes=b'SHc_dHc-TXc==') +print(m) +#> base64url_bytes=b'Hw?tw>Mw' +``` +""" +Base64UrlStr = Annotated[str, EncodedStr(encoder=Base64UrlEncoder)] +"""A str type that is encoded and decoded using the URL-safe base64 encoder. + +Note: + Under the hood, `Base64UrlStr` use standard library `base64.urlsafe_b64encode` and `base64.urlsafe_b64decode` + functions. + + As a result, the `Base64UrlStr` type can be used to faithfully decode "vanilla" base64 data (using `'+'` and `'/'`). + +```py +from pydantic import Base64UrlStr, BaseModel + +class Model(BaseModel): + base64url_str: Base64UrlStr + +# Initialize the model with base64 data +m = Model(base64url_str='SHc_dHc-TXc==') +print(m) +#> base64url_str='Hw?tw>Mw' +``` +""" + + +__getattr__ = getattr_migration(__name__) + + +@_dataclasses.dataclass(**_internal_dataclass.slots_true) +class GetPydanticSchema: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/types/#using-getpydanticschema-to-reduce-boilerplate + + A convenience class for creating an annotation that provides pydantic custom type hooks. 
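One aspect the Base64 aliases above do not demonstrate is the JSON-schema side: `EncodedBytes.__get_pydantic_json_schema__` stamps the encoder's `get_json_format()` onto the field schema. A small sketch of the expected result, asserting only the `type` and `format` keys:

```py
from pydantic import Base64Bytes, BaseModel

class Model(BaseModel):
    data: Base64Bytes

field_schema = Model.model_json_schema()['properties']['data']
print(field_schema['type'], field_schema['format'])
#> string base64
```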
+ + This class is intended to eliminate the need to create a custom "marker" which defines the + `__get_pydantic_core_schema__` and `__get_pydantic_json_schema__` custom hook methods. + + For example, to have a field treated by type checkers as `int`, but by pydantic as `Any`, you can do: + ```python + from typing import Any + + from typing_extensions import Annotated + + from pydantic import BaseModel, GetPydanticSchema + + HandleAsAny = GetPydanticSchema(lambda _s, h: h(Any)) + + class Model(BaseModel): + x: Annotated[int, HandleAsAny] # pydantic sees `x: Any` + + print(repr(Model(x='abc').x)) + #> 'abc' + ``` + """ + + get_pydantic_core_schema: Callable[[Any, GetCoreSchemaHandler], CoreSchema] | None = None + get_pydantic_json_schema: Callable[[Any, GetJsonSchemaHandler], JsonSchemaValue] | None = None + + # Note: we may want to consider adding a convenience staticmethod `def for_type(type_: Any) -> GetPydanticSchema:` + # which returns `GetPydanticSchema(lambda _s, h: h(type_))` + + if not TYPE_CHECKING: + # We put `__getattr__` in a non-TYPE_CHECKING block because otherwise, mypy allows arbitrary attribute access + + def __getattr__(self, item: str) -> Any: + """Use this rather than defining `__get_pydantic_core_schema__` etc. to reduce the number of nested calls.""" + if item == '__get_pydantic_core_schema__' and self.get_pydantic_core_schema: + return self.get_pydantic_core_schema + elif item == '__get_pydantic_json_schema__' and self.get_pydantic_json_schema: + return self.get_pydantic_json_schema + else: + return object.__getattribute__(self, item) + + __hash__ = object.__hash__ + + +@_dataclasses.dataclass(**_internal_dataclass.slots_true, frozen=True) +class Tag: + """Provides a way to specify the expected tag to use for a case of a (callable) discriminated union. + + Also provides a way to label a union case in error messages. + + When using a callable `Discriminator`, attach a `Tag` to each case in the `Union` to specify the tag that + should be used to identify that case. For example, in the below example, the `Tag` is used to specify that + if `get_discriminator_value` returns `'apple'`, the input should be validated as an `ApplePie`, and if it + returns `'pumpkin'`, the input should be validated as a `PumpkinPie`. + + The primary role of the `Tag` here is to map the return value from the callable `Discriminator` function to + the appropriate member of the `Union` in question. 
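`GetPydanticSchema` also carries a `get_pydantic_json_schema` hook that the docstring example above does not exercise. A hedged sketch of overriding only the JSON schema while leaving validation untouched; the `AsStringInSchema` marker is illustrative, not part of the patch:

```py
from typing_extensions import Annotated

from pydantic import BaseModel, GetPydanticSchema

# Only the JSON-schema hook is provided; the core (validation) schema stays as-is.
AsStringInSchema = GetPydanticSchema(get_pydantic_json_schema=lambda _schema, _handler: {'type': 'string'})

class Model(BaseModel):
    x: Annotated[int, AsStringInSchema]

print(Model(x=1).x)
#> 1
print(Model.model_json_schema()['properties']['x']['type'])
#> string
```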
+ + ```py + from typing import Any, Union + + from typing_extensions import Annotated, Literal + + from pydantic import BaseModel, Discriminator, Tag + + class Pie(BaseModel): + time_to_cook: int + num_ingredients: int + + class ApplePie(Pie): + fruit: Literal['apple'] = 'apple' + + class PumpkinPie(Pie): + filling: Literal['pumpkin'] = 'pumpkin' + + def get_discriminator_value(v: Any) -> str: + if isinstance(v, dict): + return v.get('fruit', v.get('filling')) + return getattr(v, 'fruit', getattr(v, 'filling', None)) + + class ThanksgivingDinner(BaseModel): + dessert: Annotated[ + Union[ + Annotated[ApplePie, Tag('apple')], + Annotated[PumpkinPie, Tag('pumpkin')], + ], + Discriminator(get_discriminator_value), + ] + + apple_variation = ThanksgivingDinner.model_validate( + {'dessert': {'fruit': 'apple', 'time_to_cook': 60, 'num_ingredients': 8}} + ) + print(repr(apple_variation)) + ''' + ThanksgivingDinner(dessert=ApplePie(time_to_cook=60, num_ingredients=8, fruit='apple')) + ''' + + pumpkin_variation = ThanksgivingDinner.model_validate( + { + 'dessert': { + 'filling': 'pumpkin', + 'time_to_cook': 40, + 'num_ingredients': 6, + } + } + ) + print(repr(pumpkin_variation)) + ''' + ThanksgivingDinner(dessert=PumpkinPie(time_to_cook=40, num_ingredients=6, filling='pumpkin')) + ''' + ``` + + !!! note + You must specify a `Tag` for every case in a `Tag` that is associated with a + callable `Discriminator`. Failing to do so will result in a `PydanticUserError` with code + [`callable-discriminator-no-tag`](../errors/usage_errors.md#callable-discriminator-no-tag). + + See the [Discriminated Unions] concepts docs for more details on how to use `Tag`s. + + [Discriminated Unions]: ../concepts/unions.md#discriminated-unions + """ + + tag: str + + def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: + schema = handler(source_type) + metadata = schema.setdefault('metadata', {}) + assert isinstance(metadata, dict) + metadata[_core_utils.TAGGED_UNION_TAG_KEY] = self.tag + return schema + + +@_dataclasses.dataclass(**_internal_dataclass.slots_true, frozen=True) +class Discriminator: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/unions/#discriminated-unions-with-callable-discriminator + + Provides a way to use a custom callable as the way to extract the value of a union discriminator. + + This allows you to get validation behavior like you'd get from `Field(discriminator=)`, + but without needing to have a single shared field across all the union choices. This also makes it + possible to handle unions of models and primitive types with discriminated-union-style validation errors. + Finally, this allows you to use a custom callable as the way to identify which member of a union a value + belongs to, while still seeing all the performance benefits of a discriminated union. + + Consider this example, which is much more performant with the use of `Discriminator` and thus a `TaggedUnion` + than it would be as a normal `Union`. 
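The note above about `callable-discriminator-no-tag` can be made concrete: applying a callable `Discriminator` to a union whose members carry no `Tag` fails at class-definition time, when the core schema is built. A minimal sketch; the `kind` callable and `Broken` model are illustrative only:

```py
from typing import Union

from typing_extensions import Annotated

from pydantic import BaseModel, Discriminator, PydanticUserError

def kind(v: object) -> str:
    return 'int' if isinstance(v, int) else 'str'

try:

    class Broken(BaseModel):
        # neither union member is wrapped in Annotated[..., Tag(...)]
        value: Annotated[Union[int, str], Discriminator(kind)]

except PydanticUserError as e:
    print(e.code)
    #> callable-discriminator-no-tag
```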
+ + ```py + from typing import Any, Union + + from typing_extensions import Annotated, Literal + + from pydantic import BaseModel, Discriminator, Tag + + class Pie(BaseModel): + time_to_cook: int + num_ingredients: int + + class ApplePie(Pie): + fruit: Literal['apple'] = 'apple' + + class PumpkinPie(Pie): + filling: Literal['pumpkin'] = 'pumpkin' + + def get_discriminator_value(v: Any) -> str: + if isinstance(v, dict): + return v.get('fruit', v.get('filling')) + return getattr(v, 'fruit', getattr(v, 'filling', None)) + + class ThanksgivingDinner(BaseModel): + dessert: Annotated[ + Union[ + Annotated[ApplePie, Tag('apple')], + Annotated[PumpkinPie, Tag('pumpkin')], + ], + Discriminator(get_discriminator_value), + ] + + apple_variation = ThanksgivingDinner.model_validate( + {'dessert': {'fruit': 'apple', 'time_to_cook': 60, 'num_ingredients': 8}} + ) + print(repr(apple_variation)) + ''' + ThanksgivingDinner(dessert=ApplePie(time_to_cook=60, num_ingredients=8, fruit='apple')) + ''' + + pumpkin_variation = ThanksgivingDinner.model_validate( + { + 'dessert': { + 'filling': 'pumpkin', + 'time_to_cook': 40, + 'num_ingredients': 6, + } + } + ) + print(repr(pumpkin_variation)) + ''' + ThanksgivingDinner(dessert=PumpkinPie(time_to_cook=40, num_ingredients=6, filling='pumpkin')) + ''' + ``` + + See the [Discriminated Unions] concepts docs for more details on how to use `Discriminator`s. + + [Discriminated Unions]: ../concepts/unions.md#discriminated-unions + """ + + discriminator: str | Callable[[Any], Hashable] + """The callable or field name for discriminating the type in a tagged union. + + A `Callable` discriminator must extract the value of the discriminator from the input. + A `str` discriminator must be the name of a field to discriminate against. + """ + custom_error_type: str | None = None + """Type to use in [custom errors](../errors/errors.md#custom-errors) replacing the standard discriminated union + validation errors. + """ + custom_error_message: str | None = None + """Message to use in custom errors.""" + custom_error_context: dict[str, int | str | float] | None = None + """Context to use in custom errors.""" + + def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: + origin = _typing_extra.get_origin(source_type) + if not origin or not _typing_extra.origin_is_union(origin): + raise TypeError(f'{type(self).__name__} must be used with a Union type, not {source_type}') + + if isinstance(self.discriminator, str): + from pydantic import Field + + return handler(Annotated[source_type, Field(discriminator=self.discriminator)]) + else: + original_schema = handler(source_type) + return self._convert_schema(original_schema) + + def _convert_schema(self, original_schema: core_schema.CoreSchema) -> core_schema.TaggedUnionSchema: + if original_schema['type'] != 'union': + # This likely indicates that the schema was a single-item union that was simplified. + # In this case, we do the same thing we do in + # `pydantic._internal._discriminated_union._ApplyInferredDiscriminator._apply_to_root`, namely, + # package the generated schema back into a single-item union. 
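As the `__get_pydantic_core_schema__` above shows, `Discriminator` also accepts a plain field name, in which case it simply defers to `Field(discriminator=...)`. A short sketch of that string form; the pet models are illustrative:

```py
from typing import Union

from typing_extensions import Annotated, Literal

from pydantic import BaseModel, Discriminator

class Cat(BaseModel):
    pet_type: Literal['cat'] = 'cat'

class Dog(BaseModel):
    pet_type: Literal['dog'] = 'dog'

class Owner(BaseModel):
    pet: Annotated[Union[Cat, Dog], Discriminator('pet_type')]

print(Owner.model_validate({'pet': {'pet_type': 'dog'}}))
#> pet=Dog(pet_type='dog')
```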
+ original_schema = core_schema.union_schema([original_schema]) + + tagged_union_choices = {} + for i, choice in enumerate(original_schema['choices']): + tag = None + if isinstance(choice, tuple): + choice, tag = choice + metadata = choice.get('metadata') + if metadata is not None: + metadata_tag = metadata.get(_core_utils.TAGGED_UNION_TAG_KEY) + if metadata_tag is not None: + tag = metadata_tag + if tag is None: + raise PydanticUserError( + f'`Tag` not provided for choice {choice} used with `Discriminator`', + code='callable-discriminator-no-tag', + ) + tagged_union_choices[tag] = choice + + # Have to do these verbose checks to ensure falsy values ('' and {}) don't get ignored + custom_error_type = self.custom_error_type + if custom_error_type is None: + custom_error_type = original_schema.get('custom_error_type') + + custom_error_message = self.custom_error_message + if custom_error_message is None: + custom_error_message = original_schema.get('custom_error_message') + + custom_error_context = self.custom_error_context + if custom_error_context is None: + custom_error_context = original_schema.get('custom_error_context') + + custom_error_type = original_schema.get('custom_error_type') if custom_error_type is None else custom_error_type + return core_schema.tagged_union_schema( + tagged_union_choices, + self.discriminator, + custom_error_type=custom_error_type, + custom_error_message=custom_error_message, + custom_error_context=custom_error_context, + strict=original_schema.get('strict'), + ref=original_schema.get('ref'), + metadata=original_schema.get('metadata'), + serialization=original_schema.get('serialization'), + ) + + +_JSON_TYPES = {int, float, str, bool, list, dict, type(None)} + + +def _get_type_name(x: Any) -> str: + type_ = type(x) + if type_ in _JSON_TYPES: + return type_.__name__ + + # Handle proper subclasses; note we don't need to handle None or bool here + if isinstance(x, int): + return 'int' + if isinstance(x, float): + return 'float' + if isinstance(x, str): + return 'str' + if isinstance(x, list): + return 'list' + if isinstance(x, dict): + return 'dict' + + # Fail by returning the type's actual name + return getattr(type_, '__name__', '') + + +class _AllowAnyJson: + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: + python_schema = handler(source_type) + return core_schema.json_or_python_schema(json_schema=core_schema.any_schema(), python_schema=python_schema) + + +if TYPE_CHECKING: + # This seems to only be necessary for mypy + JsonValue: TypeAlias = Union[ + List['JsonValue'], + Dict[str, 'JsonValue'], + str, + bool, + int, + float, + None, + ] + """A `JsonValue` is used to represent a value that can be serialized to JSON. + + It may be one of: + + * `List['JsonValue']` + * `Dict[str, 'JsonValue']` + * `str` + * `bool` + * `int` + * `float` + * `None` + + The following example demonstrates how to use `JsonValue` to validate JSON data, + and what kind of errors to expect when input data is not json serializable. 
+ + ```py + import json + + from pydantic import BaseModel, JsonValue, ValidationError + + class Model(BaseModel): + j: JsonValue + + valid_json_data = {'j': {'a': {'b': {'c': 1, 'd': [2, None]}}}} + invalid_json_data = {'j': {'a': {'b': ...}}} + + print(repr(Model.model_validate(valid_json_data))) + #> Model(j={'a': {'b': {'c': 1, 'd': [2, None]}}}) + print(repr(Model.model_validate_json(json.dumps(valid_json_data)))) + #> Model(j={'a': {'b': {'c': 1, 'd': [2, None]}}}) + + try: + Model.model_validate(invalid_json_data) + except ValidationError as e: + print(e) + ''' + 1 validation error for Model + j.dict.a.dict.b + input was not a valid JSON value [type=invalid-json-value, input_value=Ellipsis, input_type=ellipsis] + ''' + ``` + """ + +else: + JsonValue = TypeAliasType( + 'JsonValue', + Annotated[ + Union[ + Annotated[List['JsonValue'], Tag('list')], + Annotated[Dict[str, 'JsonValue'], Tag('dict')], + Annotated[str, Tag('str')], + Annotated[bool, Tag('bool')], + Annotated[int, Tag('int')], + Annotated[float, Tag('float')], + Annotated[None, Tag('NoneType')], + ], + Discriminator( + _get_type_name, + custom_error_type='invalid-json-value', + custom_error_message='input was not a valid JSON value', + ), + _AllowAnyJson, + ], + ) + + +class _OnErrorOmit: + @classmethod + def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: + # there is no actual default value here but we use with_default_schema since it already has the on_error + # behavior implemented and it would be no more efficient to implement it on every other validator + # or as a standalone validator + return core_schema.with_default_schema(schema=handler(source_type), on_error='omit') + + +OnErrorOmit = Annotated[T, _OnErrorOmit] +""" +When used as an item in a list, the key type in a dict, optional values of a TypedDict, etc. +this annotation omits the item from the iteration if there is any error validating it. +That is, instead of a [`ValidationError`][pydantic_core.ValidationError] being propagated up and the entire iterable being discarded +any invalid items are discarded and the valid ones are returned. +""" diff --git a/lib/pydantic/typing.py b/lib/pydantic/typing.py index 5ccf266c..f1b32ba2 100644 --- a/lib/pydantic/typing.py +++ b/lib/pydantic/typing.py @@ -1,602 +1,4 @@ -import sys -from collections.abc import Callable -from os import PathLike -from typing import ( # type: ignore - TYPE_CHECKING, - AbstractSet, - Any, - Callable as TypingCallable, - ClassVar, - Dict, - ForwardRef, - Generator, - Iterable, - List, - Mapping, - NewType, - Optional, - Sequence, - Set, - Tuple, - Type, - TypeVar, - Union, - _eval_type, - cast, - get_type_hints, -) +"""`typing` module is a backport module from V1.""" +from ._migration import getattr_migration -from typing_extensions import ( - Annotated, - Final, - Literal, - NotRequired as TypedDictNotRequired, - Required as TypedDictRequired, -) - -try: - from typing import _TypingBase as typing_base # type: ignore -except ImportError: - from typing import _Final as typing_base # type: ignore - -try: - from typing import GenericAlias as TypingGenericAlias # type: ignore -except ImportError: - # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] 
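`OnErrorOmit`, defined just before the `typing.py` hunk begins, has no example in this file. A minimal sketch of the omit-on-error behaviour for list items, assuming pydantic 2.6 semantics; the input values are illustrative:

```py
from typing import List

from pydantic import BaseModel, OnErrorOmit

class Model(BaseModel):
    numbers: List[OnErrorOmit[int]]

# the item that fails int validation is dropped instead of failing the whole list
print(Model(numbers=[1, 'two', 3]).numbers)
#> [1, 3]
```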
and so on) - TypingGenericAlias = () - -try: - from types import UnionType as TypesUnionType # type: ignore -except ImportError: - # python < 3.10 does not have UnionType (str | int, byte | bool and so on) - TypesUnionType = () - - -if sys.version_info < (3, 9): - - def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any: - return type_._evaluate(globalns, localns) - -else: - - def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any: - # Even though it is the right signature for python 3.9, mypy complains with - # `error: Too many arguments for "_evaluate" of "ForwardRef"` hence the cast... - return cast(Any, type_)._evaluate(globalns, localns, set()) - - -if sys.version_info < (3, 9): - # Ensure we always get all the whole `Annotated` hint, not just the annotated type. - # For 3.7 to 3.8, `get_type_hints` doesn't recognize `typing_extensions.Annotated`, - # so it already returns the full annotation - get_all_type_hints = get_type_hints - -else: - - def get_all_type_hints(obj: Any, globalns: Any = None, localns: Any = None) -> Any: - return get_type_hints(obj, globalns, localns, include_extras=True) - - -_T = TypeVar('_T') - -AnyCallable = TypingCallable[..., Any] -NoArgAnyCallable = TypingCallable[[], Any] - -# workaround for https://github.com/python/mypy/issues/9496 -AnyArgTCallable = TypingCallable[..., _T] - - -# Annotated[...] is implemented by returning an instance of one of these classes, depending on -# python/typing_extensions version. -AnnotatedTypeNames = {'AnnotatedMeta', '_AnnotatedAlias'} - - -if sys.version_info < (3, 8): - - def get_origin(t: Type[Any]) -> Optional[Type[Any]]: - if type(t).__name__ in AnnotatedTypeNames: - # weirdly this is a runtime requirement, as well as for mypy - return cast(Type[Any], Annotated) - return getattr(t, '__origin__', None) - -else: - from typing import get_origin as _typing_get_origin - - def get_origin(tp: Type[Any]) -> Optional[Type[Any]]: - """ - We can't directly use `typing.get_origin` since we need a fallback to support - custom generic classes like `ConstrainedList` - It should be useless once https://github.com/cython/cython/issues/3537 is - solved and https://github.com/pydantic/pydantic/pull/1753 is merged. - """ - if type(tp).__name__ in AnnotatedTypeNames: - return cast(Type[Any], Annotated) # mypy complains about _SpecialForm - return _typing_get_origin(tp) or getattr(tp, '__origin__', None) - - -if sys.version_info < (3, 8): - from typing import _GenericAlias - - def get_args(t: Type[Any]) -> Tuple[Any, ...]: - """Compatibility version of get_args for python 3.7. - - Mostly compatible with the python 3.8 `typing` module version - and able to handle almost all use cases. - """ - if type(t).__name__ in AnnotatedTypeNames: - return t.__args__ + t.__metadata__ - if isinstance(t, _GenericAlias): - res = t.__args__ - if t.__origin__ is Callable and res and res[0] is not Ellipsis: - res = (list(res[:-1]), res[-1]) - return res - return getattr(t, '__args__', ()) - -else: - from typing import get_args as _typing_get_args - - def _generic_get_args(tp: Type[Any]) -> Tuple[Any, ...]: - """ - In python 3.9, `typing.Dict`, `typing.List`, ... - do have an empty `__args__` by default (instead of the generic ~T for example). - In order to still support `Dict` for example and consider it as `Dict[Any, Any]`, - we retrieve the `_nparams` value that tells us how many parameters it needs. 
- """ - if hasattr(tp, '_nparams'): - return (Any,) * tp._nparams - # Special case for `tuple[()]`, which used to return ((),) with `typing.Tuple` - # in python 3.10- but now returns () for `tuple` and `Tuple`. - # This will probably be clarified in pydantic v2 - try: - if tp == Tuple[()] or sys.version_info >= (3, 9) and tp == tuple[()]: # type: ignore[misc] - return ((),) - # there is a TypeError when compiled with cython - except TypeError: # pragma: no cover - pass - return () - - def get_args(tp: Type[Any]) -> Tuple[Any, ...]: - """Get type arguments with all substitutions performed. - - For unions, basic simplifications used by Union constructor are performed. - Examples:: - get_args(Dict[str, int]) == (str, int) - get_args(int) == () - get_args(Union[int, Union[T, int], str][int]) == (int, str) - get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int]) - get_args(Callable[[], T][int]) == ([], int) - """ - if type(tp).__name__ in AnnotatedTypeNames: - return tp.__args__ + tp.__metadata__ - # the fallback is needed for the same reasons as `get_origin` (see above) - return _typing_get_args(tp) or getattr(tp, '__args__', ()) or _generic_get_args(tp) - - -if sys.version_info < (3, 9): - - def convert_generics(tp: Type[Any]) -> Type[Any]: - """Python 3.9 and older only supports generics from `typing` module. - They convert strings to ForwardRef automatically. - - Examples:: - typing.List['Hero'] == typing.List[ForwardRef('Hero')] - """ - return tp - -else: - from typing import _UnionGenericAlias # type: ignore - - from typing_extensions import _AnnotatedAlias - - def convert_generics(tp: Type[Any]) -> Type[Any]: - """ - Recursively searches for `str` type hints and replaces them with ForwardRef. - - Examples:: - convert_generics(list['Hero']) == list[ForwardRef('Hero')] - convert_generics(dict['Hero', 'Team']) == dict[ForwardRef('Hero'), ForwardRef('Team')] - convert_generics(typing.Dict['Hero', 'Team']) == typing.Dict[ForwardRef('Hero'), ForwardRef('Team')] - convert_generics(list[str | 'Hero'] | int) == list[str | ForwardRef('Hero')] | int - """ - origin = get_origin(tp) - if not origin or not hasattr(tp, '__args__'): - return tp - - args = get_args(tp) - - # typing.Annotated needs special treatment - if origin is Annotated: - return _AnnotatedAlias(convert_generics(args[0]), args[1:]) - - # recursively replace `str` instances inside of `GenericAlias` with `ForwardRef(arg)` - converted = tuple( - ForwardRef(arg) if isinstance(arg, str) and isinstance(tp, TypingGenericAlias) else convert_generics(arg) - for arg in args - ) - - if converted == args: - return tp - elif isinstance(tp, TypingGenericAlias): - return TypingGenericAlias(origin, converted) - elif isinstance(tp, TypesUnionType): - # recreate types.UnionType (PEP604, Python >= 3.10) - return _UnionGenericAlias(origin, converted) - else: - try: - setattr(tp, '__args__', converted) - except AttributeError: - pass - return tp - - -if sys.version_info < (3, 10): - - def is_union(tp: Optional[Type[Any]]) -> bool: - return tp is Union - - WithArgsTypes = (TypingGenericAlias,) - -else: - import types - import typing - - def is_union(tp: Optional[Type[Any]]) -> bool: - return tp is Union or tp is types.UnionType # noqa: E721 - - WithArgsTypes = (typing._GenericAlias, types.GenericAlias, types.UnionType) - - -if sys.version_info < (3, 9): - StrPath = Union[str, PathLike] -else: - StrPath = Union[str, PathLike] - # TODO: Once we switch to Cython 3 to handle generics properly - # (https://github.com/cython/cython/issues/2753), use 
following lines instead - # of the one above - # # os.PathLike only becomes subscriptable from Python 3.9 onwards - # StrPath = Union[str, PathLike[str]] - - -if TYPE_CHECKING: - from .fields import ModelField - - TupleGenerator = Generator[Tuple[str, Any], None, None] - DictStrAny = Dict[str, Any] - DictAny = Dict[Any, Any] - SetStr = Set[str] - ListStr = List[str] - IntStr = Union[int, str] - AbstractSetIntStr = AbstractSet[IntStr] - DictIntStrAny = Dict[IntStr, Any] - MappingIntStrAny = Mapping[IntStr, Any] - CallableGenerator = Generator[AnyCallable, None, None] - ReprArgs = Sequence[Tuple[Optional[str], Any]] - AnyClassMethod = classmethod[Any] - -__all__ = ( - 'AnyCallable', - 'NoArgAnyCallable', - 'NoneType', - 'is_none_type', - 'display_as_type', - 'resolve_annotations', - 'is_callable_type', - 'is_literal_type', - 'all_literal_values', - 'is_namedtuple', - 'is_typeddict', - 'is_typeddict_special', - 'is_new_type', - 'new_type_supertype', - 'is_classvar', - 'is_finalvar', - 'update_field_forward_refs', - 'update_model_forward_refs', - 'TupleGenerator', - 'DictStrAny', - 'DictAny', - 'SetStr', - 'ListStr', - 'IntStr', - 'AbstractSetIntStr', - 'DictIntStrAny', - 'CallableGenerator', - 'ReprArgs', - 'AnyClassMethod', - 'CallableGenerator', - 'WithArgsTypes', - 'get_args', - 'get_origin', - 'get_sub_types', - 'typing_base', - 'get_all_type_hints', - 'is_union', - 'StrPath', - 'MappingIntStrAny', -) - - -NoneType = None.__class__ - - -NONE_TYPES: Tuple[Any, Any, Any] = (None, NoneType, Literal[None]) - - -if sys.version_info < (3, 8): - # Even though this implementation is slower, we need it for python 3.7: - # In python 3.7 "Literal" is not a builtin type and uses a different - # mechanism. - # for this reason `Literal[None] is Literal[None]` evaluates to `False`, - # breaking the faster implementation used for the other python versions. - - def is_none_type(type_: Any) -> bool: - return type_ in NONE_TYPES - -elif sys.version_info[:2] == (3, 8): - - def is_none_type(type_: Any) -> bool: - for none_type in NONE_TYPES: - if type_ is none_type: - return True - # With python 3.8, specifically 3.8.10, Literal "is" check sare very flakey - # can change on very subtle changes like use of types in other modules, - # hopefully this check avoids that issue. - if is_literal_type(type_): # pragma: no cover - return all_literal_values(type_) == (None,) - return False - -else: - - def is_none_type(type_: Any) -> bool: - for none_type in NONE_TYPES: - if type_ is none_type: - return True - return False - - -def display_as_type(v: Type[Any]) -> str: - if not isinstance(v, typing_base) and not isinstance(v, WithArgsTypes) and not isinstance(v, type): - v = v.__class__ - - if is_union(get_origin(v)): - return f'Union[{", ".join(map(display_as_type, get_args(v)))}]' - - if isinstance(v, WithArgsTypes): - # Generic alias are constructs like `list[int]` - return str(v).replace('typing.', '') - - try: - return v.__name__ - except AttributeError: - # happens with typing objects - return str(v).replace('typing.', '') - - -def resolve_annotations(raw_annotations: Dict[str, Type[Any]], module_name: Optional[str]) -> Dict[str, Type[Any]]: - """ - Partially taken from typing.get_type_hints. - - Resolve string or ForwardRef annotations into type objects if possible. 
- """ - base_globals: Optional[Dict[str, Any]] = None - if module_name: - try: - module = sys.modules[module_name] - except KeyError: - # happens occasionally, see https://github.com/pydantic/pydantic/issues/2363 - pass - else: - base_globals = module.__dict__ - - annotations = {} - for name, value in raw_annotations.items(): - if isinstance(value, str): - if (3, 10) > sys.version_info >= (3, 9, 8) or sys.version_info >= (3, 10, 1): - value = ForwardRef(value, is_argument=False, is_class=True) - else: - value = ForwardRef(value, is_argument=False) - try: - value = _eval_type(value, base_globals, None) - except NameError: - # this is ok, it can be fixed with update_forward_refs - pass - annotations[name] = value - return annotations - - -def is_callable_type(type_: Type[Any]) -> bool: - return type_ is Callable or get_origin(type_) is Callable - - -def is_literal_type(type_: Type[Any]) -> bool: - return Literal is not None and get_origin(type_) is Literal - - -def literal_values(type_: Type[Any]) -> Tuple[Any, ...]: - return get_args(type_) - - -def all_literal_values(type_: Type[Any]) -> Tuple[Any, ...]: - """ - This method is used to retrieve all Literal values as - Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586) - e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]` - """ - if not is_literal_type(type_): - return (type_,) - - values = literal_values(type_) - return tuple(x for value in values for x in all_literal_values(value)) - - -def is_namedtuple(type_: Type[Any]) -> bool: - """ - Check if a given class is a named tuple. - It can be either a `typing.NamedTuple` or `collections.namedtuple` - """ - from .utils import lenient_issubclass - - return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields') - - -def is_typeddict(type_: Type[Any]) -> bool: - """ - Check if a given class is a typed dict (from `typing` or `typing_extensions`) - In 3.10, there will be a public method (https://docs.python.org/3.10/library/typing.html#typing.is_typeddict) - """ - from .utils import lenient_issubclass - - return lenient_issubclass(type_, dict) and hasattr(type_, '__total__') - - -def _check_typeddict_special(type_: Any) -> bool: - return type_ is TypedDictRequired or type_ is TypedDictNotRequired - - -def is_typeddict_special(type_: Any) -> bool: - """ - Check if type is a TypedDict special form (Required or NotRequired). - """ - return _check_typeddict_special(type_) or _check_typeddict_special(get_origin(type_)) - - -test_type = NewType('test_type', str) - - -def is_new_type(type_: Type[Any]) -> bool: - """ - Check whether type_ was created using typing.NewType - """ - return isinstance(type_, test_type.__class__) and hasattr(type_, '__supertype__') # type: ignore - - -def new_type_supertype(type_: Type[Any]) -> Type[Any]: - while hasattr(type_, '__supertype__'): - type_ = type_.__supertype__ - return type_ - - -def _check_classvar(v: Optional[Type[Any]]) -> bool: - if v is None: - return False - - return v.__class__ == ClassVar.__class__ and getattr(v, '_name', None) == 'ClassVar' - - -def _check_finalvar(v: Optional[Type[Any]]) -> bool: - """ - Check if a given type is a `typing.Final` type. 
- """ - if v is None: - return False - - return v.__class__ == Final.__class__ and (sys.version_info < (3, 8) or getattr(v, '_name', None) == 'Final') - - -def is_classvar(ann_type: Type[Any]) -> bool: - if _check_classvar(ann_type) or _check_classvar(get_origin(ann_type)): - return True - - # this is an ugly workaround for class vars that contain forward references and are therefore themselves - # forward references, see #3679 - if ann_type.__class__ == ForwardRef and ann_type.__forward_arg__.startswith('ClassVar['): - return True - - return False - - -def is_finalvar(ann_type: Type[Any]) -> bool: - return _check_finalvar(ann_type) or _check_finalvar(get_origin(ann_type)) - - -def update_field_forward_refs(field: 'ModelField', globalns: Any, localns: Any) -> None: - """ - Try to update ForwardRefs on fields based on this ModelField, globalns and localns. - """ - prepare = False - if field.type_.__class__ == ForwardRef: - prepare = True - field.type_ = evaluate_forwardref(field.type_, globalns, localns or None) - if field.outer_type_.__class__ == ForwardRef: - prepare = True - field.outer_type_ = evaluate_forwardref(field.outer_type_, globalns, localns or None) - if prepare: - field.prepare() - - if field.sub_fields: - for sub_f in field.sub_fields: - update_field_forward_refs(sub_f, globalns=globalns, localns=localns) - - if field.discriminator_key is not None: - field.prepare_discriminated_union_sub_fields() - - -def update_model_forward_refs( - model: Type[Any], - fields: Iterable['ModelField'], - json_encoders: Dict[Union[Type[Any], str, ForwardRef], AnyCallable], - localns: 'DictStrAny', - exc_to_suppress: Tuple[Type[BaseException], ...] = (), -) -> None: - """ - Try to update model fields ForwardRefs based on model and localns. - """ - if model.__module__ in sys.modules: - globalns = sys.modules[model.__module__].__dict__.copy() - else: - globalns = {} - - globalns.setdefault(model.__name__, model) - - for f in fields: - try: - update_field_forward_refs(f, globalns=globalns, localns=localns) - except exc_to_suppress: - pass - - for key in set(json_encoders.keys()): - if isinstance(key, str): - fr: ForwardRef = ForwardRef(key) - elif isinstance(key, ForwardRef): - fr = key - else: - continue - - try: - new_key = evaluate_forwardref(fr, globalns, localns or None) - except exc_to_suppress: # pragma: no cover - continue - - json_encoders[new_key] = json_encoders.pop(key) - - -def get_class(type_: Type[Any]) -> Union[None, bool, Type[Any]]: - """ - Tries to get the class of a Type[T] annotation. Returns True if Type is used - without brackets. Otherwise returns None. 
- """ - if type_ is type: - return True - - if get_origin(type_) is None: - return None - - args = get_args(type_) - if not args or not isinstance(args[0], type): - return True - else: - return args[0] - - -def get_sub_types(tp: Any) -> List[Any]: - """ - Return all the types that are allowed by type `tp` - `tp` can be a `Union` of allowed types or an `Annotated` type - """ - origin = get_origin(tp) - if origin is Annotated: - return get_sub_types(get_args(tp)[0]) - elif is_union(origin): - return [x for t in get_args(tp) for x in get_sub_types(t)] - else: - return [tp] +__getattr__ = getattr_migration(__name__) diff --git a/lib/pydantic/utils.py b/lib/pydantic/utils.py index 1d016c0e..1619d1db 100644 --- a/lib/pydantic/utils.py +++ b/lib/pydantic/utils.py @@ -1,841 +1,4 @@ -import keyword -import warnings -import weakref -from collections import OrderedDict, defaultdict, deque -from copy import deepcopy -from itertools import islice, zip_longest -from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType -from typing import ( - TYPE_CHECKING, - AbstractSet, - Any, - Callable, - Collection, - Dict, - Generator, - Iterable, - Iterator, - List, - Mapping, - MutableMapping, - NoReturn, - Optional, - Set, - Tuple, - Type, - TypeVar, - Union, -) +"""The `utils` module is a backport module from V1.""" +from ._migration import getattr_migration -from typing_extensions import Annotated - -from .errors import ConfigError -from .typing import ( - NoneType, - WithArgsTypes, - all_literal_values, - display_as_type, - get_args, - get_origin, - is_literal_type, - is_union, -) -from .version import version_info - -if TYPE_CHECKING: - from inspect import Signature - from pathlib import Path - - from .config import BaseConfig - from .dataclasses import Dataclass - from .fields import ModelField - from .main import BaseModel - from .typing import AbstractSetIntStr, DictIntStrAny, IntStr, MappingIntStrAny, ReprArgs - - RichReprResult = Iterable[Union[Any, Tuple[Any], Tuple[str, Any], Tuple[str, Any, Any]]] - -__all__ = ( - 'import_string', - 'sequence_like', - 'validate_field_name', - 'lenient_isinstance', - 'lenient_issubclass', - 'in_ipython', - 'is_valid_identifier', - 'deep_update', - 'update_not_none', - 'almost_equal_floats', - 'get_model', - 'to_camel', - 'is_valid_field', - 'smart_deepcopy', - 'PyObjectStr', - 'Representation', - 'GetterDict', - 'ValueItems', - 'version_info', # required here to match behaviour in v1.3 - 'ClassAttribute', - 'path_type', - 'ROOT_KEY', - 'get_unique_discriminator_alias', - 'get_discriminator_alias_and_values', - 'DUNDER_ATTRIBUTES', - 'LimitedDict', -) - -ROOT_KEY = '__root__' -# these are types that are returned unchanged by deepcopy -IMMUTABLE_NON_COLLECTIONS_TYPES: Set[Type[Any]] = { - int, - float, - complex, - str, - bool, - bytes, - type, - NoneType, - FunctionType, - BuiltinFunctionType, - LambdaType, - weakref.ref, - CodeType, - # note: including ModuleType will differ from behaviour of deepcopy by not producing error. 
- # It might be not a good idea in general, but considering that this function used only internally - # against default values of fields, this will allow to actually have a field with module as default value - ModuleType, - NotImplemented.__class__, - Ellipsis.__class__, -} - -# these are types that if empty, might be copied with simple copy() instead of deepcopy() -BUILTIN_COLLECTIONS: Set[Type[Any]] = { - list, - set, - tuple, - frozenset, - dict, - OrderedDict, - defaultdict, - deque, -} - - -def import_string(dotted_path: str) -> Any: - """ - Stolen approximately from django. Import a dotted module path and return the attribute/class designated by the - last name in the path. Raise ImportError if the import fails. - """ - from importlib import import_module - - try: - module_path, class_name = dotted_path.strip(' ').rsplit('.', 1) - except ValueError as e: - raise ImportError(f'"{dotted_path}" doesn\'t look like a module path') from e - - module = import_module(module_path) - try: - return getattr(module, class_name) - except AttributeError as e: - raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute') from e - - -def truncate(v: Union[str], *, max_len: int = 80) -> str: - """ - Truncate a value and add a unicode ellipsis (three dots) to the end if it was too long - """ - warnings.warn('`truncate` is no-longer used by pydantic and is deprecated', DeprecationWarning) - if isinstance(v, str) and len(v) > (max_len - 2): - # -3 so quote + string + … + quote has correct length - return (v[: (max_len - 3)] + '…').__repr__() - try: - v = v.__repr__() - except TypeError: - v = v.__class__.__repr__(v) # in case v is a type - if len(v) > max_len: - v = v[: max_len - 1] + '…' - return v - - -def sequence_like(v: Any) -> bool: - return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque)) - - -def validate_field_name(bases: List[Type['BaseModel']], field_name: str) -> None: - """ - Ensure that the field's name does not shadow an existing attribute of the model. - """ - for base in bases: - if getattr(base, field_name, None): - raise NameError( - f'Field name "{field_name}" shadows a BaseModel attribute; ' - f'use a different field name with "alias=\'{field_name}\'".' - ) - - -def lenient_isinstance(o: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool: - try: - return isinstance(o, class_or_tuple) # type: ignore[arg-type] - except TypeError: - return False - - -def lenient_issubclass(cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool: - try: - return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type] - except TypeError: - if isinstance(cls, WithArgsTypes): - return False - raise # pragma: no cover - - -def in_ipython() -> bool: - """ - Check whether we're in an ipython environment, including jupyter notebooks. - """ - try: - eval('__IPYTHON__') - except NameError: - return False - else: # pragma: no cover - return True - - -def is_valid_identifier(identifier: str) -> bool: - """ - Checks that a string is a valid identifier and not a Python keyword. - :param identifier: The identifier to test. - :return: True if the identifier is valid. 
- """ - return identifier.isidentifier() and not keyword.iskeyword(identifier) - - -KeyType = TypeVar('KeyType') - - -def deep_update(mapping: Dict[KeyType, Any], *updating_mappings: Dict[KeyType, Any]) -> Dict[KeyType, Any]: - updated_mapping = mapping.copy() - for updating_mapping in updating_mappings: - for k, v in updating_mapping.items(): - if k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict): - updated_mapping[k] = deep_update(updated_mapping[k], v) - else: - updated_mapping[k] = v - return updated_mapping - - -def update_not_none(mapping: Dict[Any, Any], **update: Any) -> None: - mapping.update({k: v for k, v in update.items() if v is not None}) - - -def almost_equal_floats(value_1: float, value_2: float, *, delta: float = 1e-8) -> bool: - """ - Return True if two floats are almost equal - """ - return abs(value_1 - value_2) <= delta - - -def generate_model_signature( - init: Callable[..., None], fields: Dict[str, 'ModelField'], config: Type['BaseConfig'] -) -> 'Signature': - """ - Generate signature for model based on its fields - """ - from inspect import Parameter, Signature, signature - - from .config import Extra - - present_params = signature(init).parameters.values() - merged_params: Dict[str, Parameter] = {} - var_kw = None - use_var_kw = False - - for param in islice(present_params, 1, None): # skip self arg - if param.kind is param.VAR_KEYWORD: - var_kw = param - continue - merged_params[param.name] = param - - if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through - allow_names = config.allow_population_by_field_name - for field_name, field in fields.items(): - param_name = field.alias - if field_name in merged_params or param_name in merged_params: - continue - elif not is_valid_identifier(param_name): - if allow_names and is_valid_identifier(field_name): - param_name = field_name - else: - use_var_kw = True - continue - - # TODO: replace annotation with actual expected types once #1055 solved - kwargs = {'default': field.default} if not field.required else {} - merged_params[param_name] = Parameter( - param_name, Parameter.KEYWORD_ONLY, annotation=field.annotation, **kwargs - ) - - if config.extra is Extra.allow: - use_var_kw = True - - if var_kw and use_var_kw: - # Make sure the parameter for extra kwargs - # does not have the same name as a field - default_model_signature = [ - ('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD), - ('data', Parameter.VAR_KEYWORD), - ] - if [(p.name, p.kind) for p in present_params] == default_model_signature: - # if this is the standard model signature, use extra_data as the extra args name - var_kw_name = 'extra_data' - else: - # else start from var_kw - var_kw_name = var_kw.name - - # generate a name that's definitely unique - while var_kw_name in fields: - var_kw_name += '_' - merged_params[var_kw_name] = var_kw.replace(name=var_kw_name) - - return Signature(parameters=list(merged_params.values()), return_annotation=None) - - -def get_model(obj: Union[Type['BaseModel'], Type['Dataclass']]) -> Type['BaseModel']: - from .main import BaseModel - - try: - model_cls = obj.__pydantic_model__ # type: ignore - except AttributeError: - model_cls = obj - - if not issubclass(model_cls, BaseModel): - raise TypeError('Unsupported type, must be either BaseModel or dataclass') - return model_cls - - -def to_camel(string: str) -> str: - return ''.join(word.capitalize() for word in string.split('_')) - - -def to_lower_camel(string: str) -> str: - if len(string) 
>= 1: - pascal_string = to_camel(string) - return pascal_string[0].lower() + pascal_string[1:] - return string.lower() - - -T = TypeVar('T') - - -def unique_list( - input_list: Union[List[T], Tuple[T, ...]], - *, - name_factory: Callable[[T], str] = str, -) -> List[T]: - """ - Make a list unique while maintaining order. - We update the list if another one with the same name is set - (e.g. root validator overridden in subclass) - """ - result: List[T] = [] - result_names: List[str] = [] - for v in input_list: - v_name = name_factory(v) - if v_name not in result_names: - result_names.append(v_name) - result.append(v) - else: - result[result_names.index(v_name)] = v - - return result - - -class PyObjectStr(str): - """ - String class where repr doesn't include quotes. Useful with Representation when you want to return a string - representation of something that valid (or pseudo-valid) python. - """ - - def __repr__(self) -> str: - return str(self) - - -class Representation: - """ - Mixin to provide __str__, __repr__, and __pretty__ methods. See #884 for more details. - - __pretty__ is used by [devtools](https://python-devtools.helpmanual.io/) to provide human readable representations - of objects. - """ - - __slots__: Tuple[str, ...] = tuple() - - def __repr_args__(self) -> 'ReprArgs': - """ - Returns the attributes to show in __str__, __repr__, and __pretty__ this is generally overridden. - - Can either return: - * name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]` - * or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]` - """ - attrs = ((s, getattr(self, s)) for s in self.__slots__) - return [(a, v) for a, v in attrs if v is not None] - - def __repr_name__(self) -> str: - """ - Name of the instance's class, used in __repr__. - """ - return self.__class__.__name__ - - def __repr_str__(self, join_str: str) -> str: - return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__()) - - def __pretty__(self, fmt: Callable[[Any], Any], **kwargs: Any) -> Generator[Any, None, None]: - """ - Used by devtools (https://python-devtools.helpmanual.io/) to provide a human readable representations of objects - """ - yield self.__repr_name__() + '(' - yield 1 - for name, value in self.__repr_args__(): - if name is not None: - yield name + '=' - yield fmt(value) - yield ',' - yield 0 - yield -1 - yield ')' - - def __str__(self) -> str: - return self.__repr_str__(' ') - - def __repr__(self) -> str: - return f'{self.__repr_name__()}({self.__repr_str__(", ")})' - - def __rich_repr__(self) -> 'RichReprResult': - """Get fields for Rich library""" - for name, field_repr in self.__repr_args__(): - if name is None: - yield field_repr - else: - yield name, field_repr - - -class GetterDict(Representation): - """ - Hack to make object's smell just enough like dicts for validate_model. - - We can't inherit from Mapping[str, Any] because it upsets cython so we have to implement all methods ourselves. 
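Similarly, the `deep_update` helper removed a little earlier in this hunk performed a recursive dict merge. A self-contained sketch of that behaviour, reimplemented for illustration rather than imported from pydantic:

```py
from typing import Any, Dict

def deep_update(mapping: Dict[Any, Any], *updates: Dict[Any, Any]) -> Dict[Any, Any]:
    # nested dicts are merged key by key; any other value is overwritten
    merged = mapping.copy()
    for update in updates:
        for key, value in update.items():
            if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
                merged[key] = deep_update(merged[key], value)
            else:
                merged[key] = value
    return merged

print(deep_update({'a': {'x': 1}, 'b': 2}, {'a': {'y': 3}, 'b': 4}))
#> {'a': {'x': 1, 'y': 3}, 'b': 4}
```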
- """ - - __slots__ = ('_obj',) - - def __init__(self, obj: Any): - self._obj = obj - - def __getitem__(self, key: str) -> Any: - try: - return getattr(self._obj, key) - except AttributeError as e: - raise KeyError(key) from e - - def get(self, key: Any, default: Any = None) -> Any: - return getattr(self._obj, key, default) - - def extra_keys(self) -> Set[Any]: - """ - We don't want to get any other attributes of obj if the model didn't explicitly ask for them - """ - return set() - - def keys(self) -> List[Any]: - """ - Keys of the pseudo dictionary, uses a list not set so order information can be maintained like python - dictionaries. - """ - return list(self) - - def values(self) -> List[Any]: - return [self[k] for k in self] - - def items(self) -> Iterator[Tuple[str, Any]]: - for k in self: - yield k, self.get(k) - - def __iter__(self) -> Iterator[str]: - for name in dir(self._obj): - if not name.startswith('_'): - yield name - - def __len__(self) -> int: - return sum(1 for _ in self) - - def __contains__(self, item: Any) -> bool: - return item in self.keys() - - def __eq__(self, other: Any) -> bool: - return dict(self) == dict(other.items()) - - def __repr_args__(self) -> 'ReprArgs': - return [(None, dict(self))] - - def __repr_name__(self) -> str: - return f'GetterDict[{display_as_type(self._obj)}]' - - -class ValueItems(Representation): - """ - Class for more convenient calculation of excluded or included fields on values. - """ - - __slots__ = ('_items', '_type') - - def __init__(self, value: Any, items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> None: - items = self._coerce_items(items) - - if isinstance(value, (list, tuple)): - items = self._normalize_indexes(items, len(value)) - - self._items: 'MappingIntStrAny' = items - - def is_excluded(self, item: Any) -> bool: - """ - Check if item is fully excluded. 
- - :param item: key or index of a value - """ - return self.is_true(self._items.get(item)) - - def is_included(self, item: Any) -> bool: - """ - Check if value is contained in self._items - - :param item: key or index of value - """ - return item in self._items - - def for_element(self, e: 'IntStr') -> Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']]: - """ - :param e: key or index of element on value - :return: raw values for element if self._items is dict and contain needed element - """ - - item = self._items.get(e) - return item if not self.is_true(item) else None - - def _normalize_indexes(self, items: 'MappingIntStrAny', v_length: int) -> 'DictIntStrAny': - """ - :param items: dict or set of indexes which will be normalized - :param v_length: length of sequence indexes of which will be - - >>> self._normalize_indexes({0: True, -2: True, -1: True}, 4) - {0: True, 2: True, 3: True} - >>> self._normalize_indexes({'__all__': True}, 4) - {0: True, 1: True, 2: True, 3: True} - """ - - normalized_items: 'DictIntStrAny' = {} - all_items = None - for i, v in items.items(): - if not (isinstance(v, Mapping) or isinstance(v, AbstractSet) or self.is_true(v)): - raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}') - if i == '__all__': - all_items = self._coerce_value(v) - continue - if not isinstance(i, int): - raise TypeError( - 'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: ' - 'expected integer keys or keyword "__all__"' - ) - normalized_i = v_length + i if i < 0 else i - normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i)) - - if not all_items: - return normalized_items - if self.is_true(all_items): - for i in range(v_length): - normalized_items.setdefault(i, ...) - return normalized_items - for i in range(v_length): - normalized_item = normalized_items.setdefault(i, {}) - if not self.is_true(normalized_item): - normalized_items[i] = self.merge(all_items, normalized_item) - return normalized_items - - @classmethod - def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any: - """ - Merge a ``base`` item with an ``override`` item. - - Both ``base`` and ``override`` are converted to dictionaries if possible. - Sets are converted to dictionaries with the sets entries as keys and - Ellipsis as values. - - Each key-value pair existing in ``base`` is merged with ``override``, - while the rest of the key-value pairs are updated recursively with this function. - - Merging takes place based on the "union" of keys if ``intersect`` is - set to ``False`` (default) and on the intersection of keys if - ``intersect`` is set to ``True``. 
- """ - override = cls._coerce_value(override) - base = cls._coerce_value(base) - if override is None: - return base - if cls.is_true(base) or base is None: - return override - if cls.is_true(override): - return base if intersect else override - - # intersection or union of keys while preserving ordering: - if intersect: - merge_keys = [k for k in base if k in override] + [k for k in override if k in base] - else: - merge_keys = list(base) + [k for k in override if k not in base] - - merged: 'DictIntStrAny' = {} - for k in merge_keys: - merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect) - if merged_item is not None: - merged[k] = merged_item - - return merged - - @staticmethod - def _coerce_items(items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> 'MappingIntStrAny': - if isinstance(items, Mapping): - pass - elif isinstance(items, AbstractSet): - items = dict.fromkeys(items, ...) - else: - class_name = getattr(items, '__class__', '???') - assert_never( - items, - f'Unexpected type of exclude value {class_name}', - ) - return items - - @classmethod - def _coerce_value(cls, value: Any) -> Any: - if value is None or cls.is_true(value): - return value - return cls._coerce_items(value) - - @staticmethod - def is_true(v: Any) -> bool: - return v is True or v is ... - - def __repr_args__(self) -> 'ReprArgs': - return [(None, self._items)] - - -class ClassAttribute: - """ - Hide class attribute from its instances - """ - - __slots__ = ( - 'name', - 'value', - ) - - def __init__(self, name: str, value: Any) -> None: - self.name = name - self.value = value - - def __get__(self, instance: Any, owner: Type[Any]) -> None: - if instance is None: - return self.value - raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only') - - -path_types = { - 'is_dir': 'directory', - 'is_file': 'file', - 'is_mount': 'mount point', - 'is_symlink': 'symlink', - 'is_block_device': 'block device', - 'is_char_device': 'char device', - 'is_fifo': 'FIFO', - 'is_socket': 'socket', -} - - -def path_type(p: 'Path') -> str: - """ - Find out what sort of thing a path is. - """ - assert p.exists(), 'path does not exist' - for method, name in path_types.items(): - if getattr(p, method)(): - return name - - return 'unknown' - - -Obj = TypeVar('Obj') - - -def smart_deepcopy(obj: Obj) -> Obj: - """ - Return type as is for immutable built-in types - Use obj.copy() for built-in empty collections - Use copy.deepcopy() for non-empty collections and unknown objects - """ - - obj_type = obj.__class__ - if obj_type in IMMUTABLE_NON_COLLECTIONS_TYPES: - return obj # fastest case: obj is immutable and not collection therefore will not be copied anyway - try: - if not obj and obj_type in BUILTIN_COLLECTIONS: - # faster way for empty collections, no need to copy its members - return obj if obj_type is tuple else obj.copy() # type: ignore # tuple doesn't have copy method - except (TypeError, ValueError, RuntimeError): - # do we really dare to catch ALL errors? 
Seems a bit risky - pass - - return deepcopy(obj) # slowest way when we actually might need a deepcopy - - -def is_valid_field(name: str) -> bool: - if not name.startswith('_'): - return True - return ROOT_KEY == name - - -DUNDER_ATTRIBUTES = { - '__annotations__', - '__classcell__', - '__doc__', - '__module__', - '__orig_bases__', - '__orig_class__', - '__qualname__', -} - - -def is_valid_private_name(name: str) -> bool: - return not is_valid_field(name) and name not in DUNDER_ATTRIBUTES - - -_EMPTY = object() - - -def all_identical(left: Iterable[Any], right: Iterable[Any]) -> bool: - """ - Check that the items of `left` are the same objects as those in `right`. - - >>> a, b = object(), object() - >>> all_identical([a, b, a], [a, b, a]) - True - >>> all_identical([a, b, [a]], [a, b, [a]]) # new list object, while "equal" is not "identical" - False - """ - for left_item, right_item in zip_longest(left, right, fillvalue=_EMPTY): - if left_item is not right_item: - return False - return True - - -def assert_never(obj: NoReturn, msg: str) -> NoReturn: - """ - Helper to make sure that we have covered all possible types. - - This is mostly useful for ``mypy``, docs: - https://mypy.readthedocs.io/en/latest/literal_types.html#exhaustive-checks - """ - raise TypeError(msg) - - -def get_unique_discriminator_alias(all_aliases: Collection[str], discriminator_key: str) -> str: - """Validate that all aliases are the same and if that's the case return the alias""" - unique_aliases = set(all_aliases) - if len(unique_aliases) > 1: - raise ConfigError( - f'Aliases for discriminator {discriminator_key!r} must be the same (got {", ".join(sorted(all_aliases))})' - ) - return unique_aliases.pop() - - -def get_discriminator_alias_and_values(tp: Any, discriminator_key: str) -> Tuple[str, Tuple[str, ...]]: - """ - Get alias and all valid values in the `Literal` type of the discriminator field - `tp` can be a `BaseModel` class or directly an `Annotated` `Union` of many. 
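# Rough sketch of the smart_deepcopy() fast paths above: immutable scalars are
# returned untouched, empty built-in collections get a cheap .copy(), and anything
# else falls back to copy.deepcopy(); the objects below are illustrative only.
from copy import deepcopy

x = 42                    # immutable non-collection: returned as-is
empty: list = []          # empty builtin collection: copied via list.copy()
nested = {'a': [1, 2]}    # non-empty: needs deepcopy so the inner list is not shared
copied = deepcopy(nested)
assert copied == nested and copied['a'] is not nested['a']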
- """ - is_root_model = getattr(tp, '__custom_root_type__', False) - - if get_origin(tp) is Annotated: - tp = get_args(tp)[0] - - if hasattr(tp, '__pydantic_model__'): - tp = tp.__pydantic_model__ - - if is_union(get_origin(tp)): - alias, all_values = _get_union_alias_and_all_values(tp, discriminator_key) - return alias, tuple(v for values in all_values for v in values) - elif is_root_model: - union_type = tp.__fields__[ROOT_KEY].type_ - alias, all_values = _get_union_alias_and_all_values(union_type, discriminator_key) - - if len(set(all_values)) > 1: - raise ConfigError( - f'Field {discriminator_key!r} is not the same for all submodels of {display_as_type(tp)!r}' - ) - - return alias, all_values[0] - - else: - try: - t_discriminator_type = tp.__fields__[discriminator_key].type_ - except AttributeError as e: - raise TypeError(f'Type {tp.__name__!r} is not a valid `BaseModel` or `dataclass`') from e - except KeyError as e: - raise ConfigError(f'Model {tp.__name__!r} needs a discriminator field for key {discriminator_key!r}') from e - - if not is_literal_type(t_discriminator_type): - raise ConfigError(f'Field {discriminator_key!r} of model {tp.__name__!r} needs to be a `Literal`') - - return tp.__fields__[discriminator_key].alias, all_literal_values(t_discriminator_type) - - -def _get_union_alias_and_all_values( - union_type: Type[Any], discriminator_key: str -) -> Tuple[str, Tuple[Tuple[str, ...], ...]]: - zipped_aliases_values = [get_discriminator_alias_and_values(t, discriminator_key) for t in get_args(union_type)] - # unzip: [('alias_a',('v1', 'v2)), ('alias_b', ('v3',))] => [('alias_a', 'alias_b'), (('v1', 'v2'), ('v3',))] - all_aliases, all_values = zip(*zipped_aliases_values) - return get_unique_discriminator_alias(all_aliases, discriminator_key), all_values - - -KT = TypeVar('KT') -VT = TypeVar('VT') -if TYPE_CHECKING: - # Annoying inheriting from `MutableMapping` and `dict` breaks cython, hence this work around - class LimitedDict(dict, MutableMapping[KT, VT]): # type: ignore[type-arg] - def __init__(self, size_limit: int = 1000): - ... - -else: - - class LimitedDict(dict): - """ - Limit the size/length of a dict used for caching to avoid unlimited increase in memory usage. - - Since the dict is ordered, and we always remove elements from the beginning, this is effectively a FIFO cache. - - Annoying inheriting from `MutableMapping` breaks cython. - """ - - def __init__(self, size_limit: int = 1000): - self.size_limit = size_limit - super().__init__() - - def __setitem__(self, __key: Any, __value: Any) -> None: - super().__setitem__(__key, __value) - if len(self) > self.size_limit: - excess = len(self) - self.size_limit + self.size_limit // 10 - to_remove = list(self.keys())[:excess] - for key in to_remove: - del self[key] - - def __class_getitem__(cls, *args: Any) -> Any: - # to avoid errors with 3.7 - pass +__getattr__ = getattr_migration(__name__) diff --git a/lib/pydantic/v1/__init__.py b/lib/pydantic/v1/__init__.py new file mode 100644 index 00000000..3bf1418f --- /dev/null +++ b/lib/pydantic/v1/__init__.py @@ -0,0 +1,131 @@ +# flake8: noqa +from . 
import dataclasses +from .annotated_types import create_model_from_namedtuple, create_model_from_typeddict +from .class_validators import root_validator, validator +from .config import BaseConfig, ConfigDict, Extra +from .decorator import validate_arguments +from .env_settings import BaseSettings +from .error_wrappers import ValidationError +from .errors import * +from .fields import Field, PrivateAttr, Required +from .main import * +from .networks import * +from .parse import Protocol +from .tools import * +from .types import * +from .version import VERSION, compiled + +__version__ = VERSION + +# WARNING __all__ from .errors is not included here, it will be removed as an export here in v2 +# please use "from pydantic.errors import ..." instead +__all__ = [ + # annotated types utils + 'create_model_from_namedtuple', + 'create_model_from_typeddict', + # dataclasses + 'dataclasses', + # class_validators + 'root_validator', + 'validator', + # config + 'BaseConfig', + 'ConfigDict', + 'Extra', + # decorator + 'validate_arguments', + # env_settings + 'BaseSettings', + # error_wrappers + 'ValidationError', + # fields + 'Field', + 'Required', + # main + 'BaseModel', + 'create_model', + 'validate_model', + # network + 'AnyUrl', + 'AnyHttpUrl', + 'FileUrl', + 'HttpUrl', + 'stricturl', + 'EmailStr', + 'NameEmail', + 'IPvAnyAddress', + 'IPvAnyInterface', + 'IPvAnyNetwork', + 'PostgresDsn', + 'CockroachDsn', + 'AmqpDsn', + 'RedisDsn', + 'MongoDsn', + 'KafkaDsn', + 'validate_email', + # parse + 'Protocol', + # tools + 'parse_file_as', + 'parse_obj_as', + 'parse_raw_as', + 'schema_of', + 'schema_json_of', + # types + 'NoneStr', + 'NoneBytes', + 'StrBytes', + 'NoneStrBytes', + 'StrictStr', + 'ConstrainedBytes', + 'conbytes', + 'ConstrainedList', + 'conlist', + 'ConstrainedSet', + 'conset', + 'ConstrainedFrozenSet', + 'confrozenset', + 'ConstrainedStr', + 'constr', + 'PyObject', + 'ConstrainedInt', + 'conint', + 'PositiveInt', + 'NegativeInt', + 'NonNegativeInt', + 'NonPositiveInt', + 'ConstrainedFloat', + 'confloat', + 'PositiveFloat', + 'NegativeFloat', + 'NonNegativeFloat', + 'NonPositiveFloat', + 'FiniteFloat', + 'ConstrainedDecimal', + 'condecimal', + 'ConstrainedDate', + 'condate', + 'UUID1', + 'UUID3', + 'UUID4', + 'UUID5', + 'FilePath', + 'DirectoryPath', + 'Json', + 'JsonWrapper', + 'SecretField', + 'SecretStr', + 'SecretBytes', + 'StrictBool', + 'StrictBytes', + 'StrictInt', + 'StrictFloat', + 'PaymentCardNumber', + 'PrivateAttr', + 'ByteSize', + 'PastDate', + 'FutureDate', + # version + 'compiled', + 'VERSION', +] diff --git a/lib/pydantic/_hypothesis_plugin.py b/lib/pydantic/v1/_hypothesis_plugin.py similarity index 97% rename from lib/pydantic/_hypothesis_plugin.py rename to lib/pydantic/v1/_hypothesis_plugin.py index a56d2b98..0c529620 100644 --- a/lib/pydantic/_hypothesis_plugin.py +++ b/lib/pydantic/v1/_hypothesis_plugin.py @@ -10,7 +10,7 @@ Pydantic is installed. 
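# The new pydantic/v1/__init__.py above re-exports the legacy v1 API, so code written
# against pydantic v1 can (in principle) import through the compatibility package.
# A minimal sketch; the import path is assumed from the file layout in this patch:
from pydantic.v1 import BaseModel, Field, ValidationError

class User(BaseModel):
    name: str
    age: int = Field(..., ge=0)

print(User(name='ada', age='42').age)    # v1-style coercion: '42' -> 42
try:
    User(name='ada', age=-1)
except ValidationError as exc:
    print(exc.errors()[0]['type'])       # typically 'value_error.number.not_ge'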
See also: https://hypothesis.readthedocs.io/en/latest/strategies.html#registering-strategies-via-setuptools-entry-points https://hypothesis.readthedocs.io/en/latest/data.html#hypothesis.strategies.register_type_strategy https://hypothesis.readthedocs.io/en/latest/strategies.html#interaction-with-pytest-cov -https://pydantic-docs.helpmanual.io/usage/types/#pydantic-types +https://docs.pydantic.dev/usage/types/#pydantic-types Note that because our motivation is to *improve user experience*, the strategies are always sound (never generate invalid data) but sacrifice completeness for @@ -46,7 +46,7 @@ from pydantic.utils import lenient_issubclass # # conlist() and conset() are unsupported for now, because the workarounds for # Cython and Hypothesis to handle parametrized generic types are incompatible. -# Once Cython can support 'normal' generics we'll revisit this. +# We are rethinking Hypothesis compatibility in Pydantic v2. # Emails try: @@ -168,6 +168,11 @@ st.register_type_strategy(pydantic.StrictBool, st.booleans()) st.register_type_strategy(pydantic.StrictStr, st.text()) +# FutureDate, PastDate +st.register_type_strategy(pydantic.FutureDate, st.dates(min_value=datetime.date.today() + datetime.timedelta(days=1))) +st.register_type_strategy(pydantic.PastDate, st.dates(max_value=datetime.date.today() - datetime.timedelta(days=1))) + + # Constrained-type resolver functions # # For these ones, we actually want to inspect the type in order to work out a diff --git a/lib/pydantic/annotated_types.py b/lib/pydantic/v1/annotated_types.py similarity index 100% rename from lib/pydantic/annotated_types.py rename to lib/pydantic/v1/annotated_types.py diff --git a/lib/pydantic/v1/class_validators.py b/lib/pydantic/v1/class_validators.py new file mode 100644 index 00000000..71e66509 --- /dev/null +++ b/lib/pydantic/v1/class_validators.py @@ -0,0 +1,361 @@ +import warnings +from collections import ChainMap +from functools import partial, partialmethod, wraps +from itertools import chain +from types import FunctionType +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, overload + +from .errors import ConfigError +from .typing import AnyCallable +from .utils import ROOT_KEY, in_ipython + +if TYPE_CHECKING: + from .typing import AnyClassMethod + + +class Validator: + __slots__ = 'func', 'pre', 'each_item', 'always', 'check_fields', 'skip_on_failure' + + def __init__( + self, + func: AnyCallable, + pre: bool = False, + each_item: bool = False, + always: bool = False, + check_fields: bool = False, + skip_on_failure: bool = False, + ): + self.func = func + self.pre = pre + self.each_item = each_item + self.always = always + self.check_fields = check_fields + self.skip_on_failure = skip_on_failure + + +if TYPE_CHECKING: + from inspect import Signature + + from .config import BaseConfig + from .fields import ModelField + from .types import ModelOrDc + + ValidatorCallable = Callable[[Optional[ModelOrDc], Any, Dict[str, Any], ModelField, Type[BaseConfig]], Any] + ValidatorsList = List[ValidatorCallable] + ValidatorListDict = Dict[str, List[Validator]] + +_FUNCS: Set[str] = set() +VALIDATOR_CONFIG_KEY = '__validator_config__' +ROOT_VALIDATOR_CONFIG_KEY = '__root_validator_config__' + + +def validator( + *fields: str, + pre: bool = False, + each_item: bool = False, + always: bool = False, + check_fields: bool = True, + whole: Optional[bool] = None, + allow_reuse: bool = False, +) -> Callable[[AnyCallable], 'AnyClassMethod']: + """ + Decorate methods on the 
class indicating that they should be used to validate fields + :param fields: which field(s) the method should be called on + :param pre: whether or not this validator should be called before the standard validators (else after) + :param each_item: for complex objects (sets, lists etc.) whether to validate individual elements rather than the + whole object + :param always: whether this method and other validators should be called even if the value is missing + :param check_fields: whether to check that the fields actually exist on the model + :param allow_reuse: whether to track and raise an error if another validator refers to the decorated function + """ + if not fields: + raise ConfigError('validator with no fields specified') + elif isinstance(fields[0], FunctionType): + raise ConfigError( + "validators should be used with fields and keyword arguments, not bare. " # noqa: Q000 + "E.g. usage should be `@validator('', ...)`" + ) + elif not all(isinstance(field, str) for field in fields): + raise ConfigError( + "validator fields should be passed as separate string args. " # noqa: Q000 + "E.g. usage should be `@validator('', '', ...)`" + ) + + if whole is not None: + warnings.warn( + 'The "whole" keyword argument is deprecated, use "each_item" (inverse meaning, default False) instead', + DeprecationWarning, + ) + assert each_item is False, '"each_item" and "whole" conflict, remove "whole"' + each_item = not whole + + def dec(f: AnyCallable) -> 'AnyClassMethod': + f_cls = _prepare_validator(f, allow_reuse) + setattr( + f_cls, + VALIDATOR_CONFIG_KEY, + ( + fields, + Validator(func=f_cls.__func__, pre=pre, each_item=each_item, always=always, check_fields=check_fields), + ), + ) + return f_cls + + return dec + + +@overload +def root_validator(_func: AnyCallable) -> 'AnyClassMethod': + ... + + +@overload +def root_validator( + *, pre: bool = False, allow_reuse: bool = False, skip_on_failure: bool = False +) -> Callable[[AnyCallable], 'AnyClassMethod']: + ... + + +def root_validator( + _func: Optional[AnyCallable] = None, *, pre: bool = False, allow_reuse: bool = False, skip_on_failure: bool = False +) -> Union['AnyClassMethod', Callable[[AnyCallable], 'AnyClassMethod']]: + """ + Decorate methods on a model indicating that they should be used to validate (and perhaps modify) data either + before or after standard model parsing/validation is performed. + """ + if _func: + f_cls = _prepare_validator(_func, allow_reuse) + setattr( + f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure) + ) + return f_cls + + def dec(f: AnyCallable) -> 'AnyClassMethod': + f_cls = _prepare_validator(f, allow_reuse) + setattr( + f_cls, ROOT_VALIDATOR_CONFIG_KEY, Validator(func=f_cls.__func__, pre=pre, skip_on_failure=skip_on_failure) + ) + return f_cls + + return dec + + +def _prepare_validator(function: AnyCallable, allow_reuse: bool) -> 'AnyClassMethod': + """ + Avoid validators with duplicated names since without this, validators can be overwritten silently + which generally isn't the intended behaviour, don't run in ipython (see #312) or if allow_reuse is False. + """ + f_cls = function if isinstance(function, classmethod) else classmethod(function) + if not in_ipython() and not allow_reuse: + ref = ( + getattr(f_cls.__func__, '__module__', '') + + '.' 
+ + getattr(f_cls.__func__, '__qualname__', f'') + ) + if ref in _FUNCS: + raise ConfigError(f'duplicate validator function "{ref}"; if this is intended, set `allow_reuse=True`') + _FUNCS.add(ref) + return f_cls + + +class ValidatorGroup: + def __init__(self, validators: 'ValidatorListDict') -> None: + self.validators = validators + self.used_validators = {'*'} + + def get_validators(self, name: str) -> Optional[Dict[str, Validator]]: + self.used_validators.add(name) + validators = self.validators.get(name, []) + if name != ROOT_KEY: + validators += self.validators.get('*', []) + if validators: + return {getattr(v.func, '__name__', f''): v for v in validators} + else: + return None + + def check_for_unused(self) -> None: + unused_validators = set( + chain.from_iterable( + ( + getattr(v.func, '__name__', f'') + for v in self.validators[f] + if v.check_fields + ) + for f in (self.validators.keys() - self.used_validators) + ) + ) + if unused_validators: + fn = ', '.join(unused_validators) + raise ConfigError( + f"Validators defined with incorrect fields: {fn} " # noqa: Q000 + f"(use check_fields=False if you're inheriting from the model and intended this)" + ) + + +def extract_validators(namespace: Dict[str, Any]) -> Dict[str, List[Validator]]: + validators: Dict[str, List[Validator]] = {} + for var_name, value in namespace.items(): + validator_config = getattr(value, VALIDATOR_CONFIG_KEY, None) + if validator_config: + fields, v = validator_config + for field in fields: + if field in validators: + validators[field].append(v) + else: + validators[field] = [v] + return validators + + +def extract_root_validators(namespace: Dict[str, Any]) -> Tuple[List[AnyCallable], List[Tuple[bool, AnyCallable]]]: + from inspect import signature + + pre_validators: List[AnyCallable] = [] + post_validators: List[Tuple[bool, AnyCallable]] = [] + for name, value in namespace.items(): + validator_config: Optional[Validator] = getattr(value, ROOT_VALIDATOR_CONFIG_KEY, None) + if validator_config: + sig = signature(validator_config.func) + args = list(sig.parameters.keys()) + if args[0] == 'self': + raise ConfigError( + f'Invalid signature for root validator {name}: {sig}, "self" not permitted as first argument, ' + f'should be: (cls, values).' + ) + if len(args) != 2: + raise ConfigError(f'Invalid signature for root validator {name}: {sig}, should be: (cls, values).') + # check function signature + if validator_config.pre: + pre_validators.append(validator_config.func) + else: + post_validators.append((validator_config.skip_on_failure, validator_config.func)) + return pre_validators, post_validators + + +def inherit_validators(base_validators: 'ValidatorListDict', validators: 'ValidatorListDict') -> 'ValidatorListDict': + for field, field_validators in base_validators.items(): + if field not in validators: + validators[field] = [] + validators[field] += field_validators + return validators + + +def make_generic_validator(validator: AnyCallable) -> 'ValidatorCallable': + """ + Make a generic function which calls a validator with the right arguments. + + Unfortunately other approaches (eg. return a partial of a function that builds the arguments) is slow, + hence this laborious way of doing things. + + It's done like this so validators don't all need **kwargs in their signature, eg. any combination of + the arguments "values", "fields" and/or "config" are permitted. 
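# Hedged illustration of the signature flexibility described above: a v1 validator may
# ask for any subset of "values", "field" and "config", and make_generic_validator()
# wraps it so the model always calls it with the full argument set. Model and field
# names below are illustrative; the import path is assumed from this patch's layout:
from pydantic.v1 import BaseModel, validator

class Order(BaseModel):
    quantity: int
    total: float

    @validator('total')
    def check_total(cls, v, values):   # declares only the `values` argument
        if values.get('quantity') == 0 and v != 0:
            raise ValueError('total must be 0 when quantity is 0')
        return v

print(Order(quantity=0, total=0).total)   # 0.0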
+ """ + from inspect import signature + + if not isinstance(validator, (partial, partialmethod)): + # This should be the default case, so overhead is reduced + sig = signature(validator) + args = list(sig.parameters.keys()) + else: + # Fix the generated argument lists of partial methods + sig = signature(validator.func) + args = [ + k + for k in signature(validator.func).parameters.keys() + if k not in validator.args | validator.keywords.keys() + ] + + first_arg = args.pop(0) + if first_arg == 'self': + raise ConfigError( + f'Invalid signature for validator {validator}: {sig}, "self" not permitted as first argument, ' + f'should be: (cls, value, values, config, field), "values", "config" and "field" are all optional.' + ) + elif first_arg == 'cls': + # assume the second argument is value + return wraps(validator)(_generic_validator_cls(validator, sig, set(args[1:]))) + else: + # assume the first argument was value which has already been removed + return wraps(validator)(_generic_validator_basic(validator, sig, set(args))) + + +def prep_validators(v_funcs: Iterable[AnyCallable]) -> 'ValidatorsList': + return [make_generic_validator(f) for f in v_funcs if f] + + +all_kwargs = {'values', 'field', 'config'} + + +def _generic_validator_cls(validator: AnyCallable, sig: 'Signature', args: Set[str]) -> 'ValidatorCallable': + # assume the first argument is value + has_kwargs = False + if 'kwargs' in args: + has_kwargs = True + args -= {'kwargs'} + + if not args.issubset(all_kwargs): + raise ConfigError( + f'Invalid signature for validator {validator}: {sig}, should be: ' + f'(cls, value, values, config, field), "values", "config" and "field" are all optional.' + ) + + if has_kwargs: + return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field, config=config) + elif args == set(): + return lambda cls, v, values, field, config: validator(cls, v) + elif args == {'values'}: + return lambda cls, v, values, field, config: validator(cls, v, values=values) + elif args == {'field'}: + return lambda cls, v, values, field, config: validator(cls, v, field=field) + elif args == {'config'}: + return lambda cls, v, values, field, config: validator(cls, v, config=config) + elif args == {'values', 'field'}: + return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field) + elif args == {'values', 'config'}: + return lambda cls, v, values, field, config: validator(cls, v, values=values, config=config) + elif args == {'field', 'config'}: + return lambda cls, v, values, field, config: validator(cls, v, field=field, config=config) + else: + # args == {'values', 'field', 'config'} + return lambda cls, v, values, field, config: validator(cls, v, values=values, field=field, config=config) + + +def _generic_validator_basic(validator: AnyCallable, sig: 'Signature', args: Set[str]) -> 'ValidatorCallable': + has_kwargs = False + if 'kwargs' in args: + has_kwargs = True + args -= {'kwargs'} + + if not args.issubset(all_kwargs): + raise ConfigError( + f'Invalid signature for validator {validator}: {sig}, should be: ' + f'(value, values, config, field), "values", "config" and "field" are all optional.' 
+ ) + + if has_kwargs: + return lambda cls, v, values, field, config: validator(v, values=values, field=field, config=config) + elif args == set(): + return lambda cls, v, values, field, config: validator(v) + elif args == {'values'}: + return lambda cls, v, values, field, config: validator(v, values=values) + elif args == {'field'}: + return lambda cls, v, values, field, config: validator(v, field=field) + elif args == {'config'}: + return lambda cls, v, values, field, config: validator(v, config=config) + elif args == {'values', 'field'}: + return lambda cls, v, values, field, config: validator(v, values=values, field=field) + elif args == {'values', 'config'}: + return lambda cls, v, values, field, config: validator(v, values=values, config=config) + elif args == {'field', 'config'}: + return lambda cls, v, values, field, config: validator(v, field=field, config=config) + else: + # args == {'values', 'field', 'config'} + return lambda cls, v, values, field, config: validator(v, values=values, field=field, config=config) + + +def gather_all_validators(type_: 'ModelOrDc') -> Dict[str, 'AnyClassMethod']: + all_attributes = ChainMap(*[cls.__dict__ for cls in type_.__mro__]) # type: ignore[arg-type,var-annotated] + return { + k: v + for k, v in all_attributes.items() + if hasattr(v, VALIDATOR_CONFIG_KEY) or hasattr(v, ROOT_VALIDATOR_CONFIG_KEY) + } diff --git a/lib/pydantic/v1/color.py b/lib/pydantic/v1/color.py new file mode 100644 index 00000000..6fdc9fb1 --- /dev/null +++ b/lib/pydantic/v1/color.py @@ -0,0 +1,494 @@ +""" +Color definitions are used as per CSS3 specification: +http://www.w3.org/TR/css3-color/#svg-color + +A few colors have multiple names referring to the sames colors, eg. `grey` and `gray` or `aqua` and `cyan`. + +In these cases the LAST color when sorted alphabetically takes preferences, +eg. Color((0, 255, 255)).as_named() == 'cyan' because "cyan" comes after "aqua". +""" +import math +import re +from colorsys import hls_to_rgb, rgb_to_hls +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast + +from .errors import ColorError +from .utils import Representation, almost_equal_floats + +if TYPE_CHECKING: + from .typing import CallableGenerator, ReprArgs + +ColorTuple = Union[Tuple[int, int, int], Tuple[int, int, int, float]] +ColorType = Union[ColorTuple, str] +HslColorTuple = Union[Tuple[float, float, float], Tuple[float, float, float, float]] + + +class RGBA: + """ + Internal use only as a representation of a color. + """ + + __slots__ = 'r', 'g', 'b', 'alpha', '_tuple' + + def __init__(self, r: float, g: float, b: float, alpha: Optional[float]): + self.r = r + self.g = g + self.b = b + self.alpha = alpha + + self._tuple: Tuple[float, float, float, Optional[float]] = (r, g, b, alpha) + + def __getitem__(self, item: Any) -> Any: + return self._tuple[item] + + +# these are not compiled here to avoid import slowdown, they'll be compiled the first time they're used, then cached +r_hex_short = r'\s*(?:#|0x)?([0-9a-f])([0-9a-f])([0-9a-f])([0-9a-f])?\s*' +r_hex_long = r'\s*(?:#|0x)?([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})([0-9a-f]{2})?\s*' +_r_255 = r'(\d{1,3}(?:\.\d+)?)' +_r_comma = r'\s*,\s*' +r_rgb = fr'\s*rgb\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}\)\s*' +_r_alpha = r'(\d(?:\.\d+)?|\.\d+|\d{1,2}%)' +r_rgba = fr'\s*rgba\(\s*{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_255}{_r_comma}{_r_alpha}\s*\)\s*' +_r_h = r'(-?\d+(?:\.\d+)?|-?\.\d+)(deg|rad|turn)?' 
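# Small standalone sketch of how the raw colour patterns above are used: parse_str()
# (defined further below) tries each pattern with re.fullmatch and rebuilds an RGBA.
# The pattern is repeated here only for the demo:
import re

hex_short_demo = r'\s*(?:#|0x)?([0-9a-f])([0-9a-f])([0-9a-f])([0-9a-f])?\s*'
m = re.fullmatch(hex_short_demo, '#f0c')
assert m is not None
*rgb, a = m.groups()
r, g, b = (int(v * 2, 16) for v in rgb)   # 'f' -> 'ff' -> 255, and so on
print(r, g, b, a)                          # 255 0 204 None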
+_r_sl = r'(\d{1,3}(?:\.\d+)?)%' +r_hsl = fr'\s*hsl\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}\s*\)\s*' +r_hsla = fr'\s*hsl\(\s*{_r_h}{_r_comma}{_r_sl}{_r_comma}{_r_sl}{_r_comma}{_r_alpha}\s*\)\s*' + +# colors where the two hex characters are the same, if all colors match this the short version of hex colors can be used +repeat_colors = {int(c * 2, 16) for c in '0123456789abcdef'} +rads = 2 * math.pi + + +class Color(Representation): + __slots__ = '_original', '_rgba' + + def __init__(self, value: ColorType) -> None: + self._rgba: RGBA + self._original: ColorType + if isinstance(value, (tuple, list)): + self._rgba = parse_tuple(value) + elif isinstance(value, str): + self._rgba = parse_str(value) + elif isinstance(value, Color): + self._rgba = value._rgba + value = value._original + else: + raise ColorError(reason='value must be a tuple, list or string') + + # if we've got here value must be a valid color + self._original = value + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(type='string', format='color') + + def original(self) -> ColorType: + """ + Original value passed to Color + """ + return self._original + + def as_named(self, *, fallback: bool = False) -> str: + if self._rgba.alpha is None: + rgb = cast(Tuple[int, int, int], self.as_rgb_tuple()) + try: + return COLORS_BY_VALUE[rgb] + except KeyError as e: + if fallback: + return self.as_hex() + else: + raise ValueError('no named color found, use fallback=True, as_hex() or as_rgb()') from e + else: + return self.as_hex() + + def as_hex(self) -> str: + """ + Hex string representing the color can be 3, 4, 6 or 8 characters depending on whether the string + a "short" representation of the color is possible and whether there's an alpha channel. + """ + values = [float_to_255(c) for c in self._rgba[:3]] + if self._rgba.alpha is not None: + values.append(float_to_255(self._rgba.alpha)) + + as_hex = ''.join(f'{v:02x}' for v in values) + if all(c in repeat_colors for c in values): + as_hex = ''.join(as_hex[c] for c in range(0, len(as_hex), 2)) + return '#' + as_hex + + def as_rgb(self) -> str: + """ + Color as an rgb(, , ) or rgba(, , , ) string. + """ + if self._rgba.alpha is None: + return f'rgb({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)})' + else: + return ( + f'rgba({float_to_255(self._rgba.r)}, {float_to_255(self._rgba.g)}, {float_to_255(self._rgba.b)}, ' + f'{round(self._alpha_float(), 2)})' + ) + + def as_rgb_tuple(self, *, alpha: Optional[bool] = None) -> ColorTuple: + """ + Color as an RGB or RGBA tuple; red, green and blue are in the range 0 to 255, alpha if included is + in the range 0 to 1. + + :param alpha: whether to include the alpha channel, options are + None - (default) include alpha only if it's set (e.g. not None) + True - always include alpha, + False - always omit alpha, + """ + r, g, b = (float_to_255(c) for c in self._rgba[:3]) + if alpha is None: + if self._rgba.alpha is None: + return r, g, b + else: + return r, g, b, self._alpha_float() + elif alpha: + return r, g, b, self._alpha_float() + else: + # alpha is False + return r, g, b + + def as_hsl(self) -> str: + """ + Color as an hsl(, , ) or hsl(, , , ) string. 
+ """ + if self._rgba.alpha is None: + h, s, li = self.as_hsl_tuple(alpha=False) # type: ignore + return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%})' + else: + h, s, li, a = self.as_hsl_tuple(alpha=True) # type: ignore + return f'hsl({h * 360:0.0f}, {s:0.0%}, {li:0.0%}, {round(a, 2)})' + + def as_hsl_tuple(self, *, alpha: Optional[bool] = None) -> HslColorTuple: + """ + Color as an HSL or HSLA tuple, e.g. hue, saturation, lightness and optionally alpha; all elements are in + the range 0 to 1. + + NOTE: this is HSL as used in HTML and most other places, not HLS as used in python's colorsys. + + :param alpha: whether to include the alpha channel, options are + None - (default) include alpha only if it's set (e.g. not None) + True - always include alpha, + False - always omit alpha, + """ + h, l, s = rgb_to_hls(self._rgba.r, self._rgba.g, self._rgba.b) + if alpha is None: + if self._rgba.alpha is None: + return h, s, l + else: + return h, s, l, self._alpha_float() + if alpha: + return h, s, l, self._alpha_float() + else: + # alpha is False + return h, s, l + + def _alpha_float(self) -> float: + return 1 if self._rgba.alpha is None else self._rgba.alpha + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls + + def __str__(self) -> str: + return self.as_named(fallback=True) + + def __repr_args__(self) -> 'ReprArgs': + return [(None, self.as_named(fallback=True))] + [('rgb', self.as_rgb_tuple())] # type: ignore + + def __eq__(self, other: Any) -> bool: + return isinstance(other, Color) and self.as_rgb_tuple() == other.as_rgb_tuple() + + def __hash__(self) -> int: + return hash(self.as_rgb_tuple()) + + +def parse_tuple(value: Tuple[Any, ...]) -> RGBA: + """ + Parse a tuple or list as a color. + """ + if len(value) == 3: + r, g, b = (parse_color_value(v) for v in value) + return RGBA(r, g, b, None) + elif len(value) == 4: + r, g, b = (parse_color_value(v) for v in value[:3]) + return RGBA(r, g, b, parse_float_alpha(value[3])) + else: + raise ColorError(reason='tuples must have length 3 or 4') + + +def parse_str(value: str) -> RGBA: + """ + Parse a string to an RGBA tuple, trying the following formats (in this order): + * named color, see COLORS_BY_NAME below + * hex short eg. `fff` (prefix can be `#`, `0x` or nothing) + * hex long eg. 
`ffffff` (prefix can be `#`, `0x` or nothing) + * `rgb(, , ) ` + * `rgba(, , , )` + """ + value_lower = value.lower() + try: + r, g, b = COLORS_BY_NAME[value_lower] + except KeyError: + pass + else: + return ints_to_rgba(r, g, b, None) + + m = re.fullmatch(r_hex_short, value_lower) + if m: + *rgb, a = m.groups() + r, g, b = (int(v * 2, 16) for v in rgb) + if a: + alpha: Optional[float] = int(a * 2, 16) / 255 + else: + alpha = None + return ints_to_rgba(r, g, b, alpha) + + m = re.fullmatch(r_hex_long, value_lower) + if m: + *rgb, a = m.groups() + r, g, b = (int(v, 16) for v in rgb) + if a: + alpha = int(a, 16) / 255 + else: + alpha = None + return ints_to_rgba(r, g, b, alpha) + + m = re.fullmatch(r_rgb, value_lower) + if m: + return ints_to_rgba(*m.groups(), None) # type: ignore + + m = re.fullmatch(r_rgba, value_lower) + if m: + return ints_to_rgba(*m.groups()) # type: ignore + + m = re.fullmatch(r_hsl, value_lower) + if m: + h, h_units, s, l_ = m.groups() + return parse_hsl(h, h_units, s, l_) + + m = re.fullmatch(r_hsla, value_lower) + if m: + h, h_units, s, l_, a = m.groups() + return parse_hsl(h, h_units, s, l_, parse_float_alpha(a)) + + raise ColorError(reason='string not recognised as a valid color') + + +def ints_to_rgba(r: Union[int, str], g: Union[int, str], b: Union[int, str], alpha: Optional[float]) -> RGBA: + return RGBA(parse_color_value(r), parse_color_value(g), parse_color_value(b), parse_float_alpha(alpha)) + + +def parse_color_value(value: Union[int, str], max_val: int = 255) -> float: + """ + Parse a value checking it's a valid int in the range 0 to max_val and divide by max_val to give a number + in the range 0 to 1 + """ + try: + color = float(value) + except ValueError: + raise ColorError(reason='color values must be a valid number') + if 0 <= color <= max_val: + return color / max_val + else: + raise ColorError(reason=f'color values must be in the range 0 to {max_val}') + + +def parse_float_alpha(value: Union[None, str, float, int]) -> Optional[float]: + """ + Parse a value checking it's a valid float in the range 0 to 1 + """ + if value is None: + return None + try: + if isinstance(value, str) and value.endswith('%'): + alpha = float(value[:-1]) / 100 + else: + alpha = float(value) + except ValueError: + raise ColorError(reason='alpha values must be a valid float') + + if almost_equal_floats(alpha, 1): + return None + elif 0 <= alpha <= 1: + return alpha + else: + raise ColorError(reason='alpha values must be in the range 0 to 1') + + +def parse_hsl(h: str, h_units: str, sat: str, light: str, alpha: Optional[float] = None) -> RGBA: + """ + Parse raw hue, saturation, lightness and alpha values and convert to RGBA. 
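# Hedged end-to-end sketch of the Color type defined above; the import path is
# assumed from this patch's lib/pydantic/v1/color.py:
from pydantic.v1.color import Color

c = Color('hsl(210, 50%, 40%)')
print(c.as_rgb_tuple())            # (51, 102, 153)
print(c.as_hex())                  # '#369' - collapses to the short form when possible
print(str(Color((0, 255, 255))))   # 'cyan' - the last name in alphabetical order wins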
+ """ + s_value, l_value = parse_color_value(sat, 100), parse_color_value(light, 100) + + h_value = float(h) + if h_units in {None, 'deg'}: + h_value = h_value % 360 / 360 + elif h_units == 'rad': + h_value = h_value % rads / rads + else: + # turns + h_value = h_value % 1 + + r, g, b = hls_to_rgb(h_value, l_value, s_value) + return RGBA(r, g, b, alpha) + + +def float_to_255(c: float) -> int: + return int(round(c * 255)) + + +COLORS_BY_NAME = { + 'aliceblue': (240, 248, 255), + 'antiquewhite': (250, 235, 215), + 'aqua': (0, 255, 255), + 'aquamarine': (127, 255, 212), + 'azure': (240, 255, 255), + 'beige': (245, 245, 220), + 'bisque': (255, 228, 196), + 'black': (0, 0, 0), + 'blanchedalmond': (255, 235, 205), + 'blue': (0, 0, 255), + 'blueviolet': (138, 43, 226), + 'brown': (165, 42, 42), + 'burlywood': (222, 184, 135), + 'cadetblue': (95, 158, 160), + 'chartreuse': (127, 255, 0), + 'chocolate': (210, 105, 30), + 'coral': (255, 127, 80), + 'cornflowerblue': (100, 149, 237), + 'cornsilk': (255, 248, 220), + 'crimson': (220, 20, 60), + 'cyan': (0, 255, 255), + 'darkblue': (0, 0, 139), + 'darkcyan': (0, 139, 139), + 'darkgoldenrod': (184, 134, 11), + 'darkgray': (169, 169, 169), + 'darkgreen': (0, 100, 0), + 'darkgrey': (169, 169, 169), + 'darkkhaki': (189, 183, 107), + 'darkmagenta': (139, 0, 139), + 'darkolivegreen': (85, 107, 47), + 'darkorange': (255, 140, 0), + 'darkorchid': (153, 50, 204), + 'darkred': (139, 0, 0), + 'darksalmon': (233, 150, 122), + 'darkseagreen': (143, 188, 143), + 'darkslateblue': (72, 61, 139), + 'darkslategray': (47, 79, 79), + 'darkslategrey': (47, 79, 79), + 'darkturquoise': (0, 206, 209), + 'darkviolet': (148, 0, 211), + 'deeppink': (255, 20, 147), + 'deepskyblue': (0, 191, 255), + 'dimgray': (105, 105, 105), + 'dimgrey': (105, 105, 105), + 'dodgerblue': (30, 144, 255), + 'firebrick': (178, 34, 34), + 'floralwhite': (255, 250, 240), + 'forestgreen': (34, 139, 34), + 'fuchsia': (255, 0, 255), + 'gainsboro': (220, 220, 220), + 'ghostwhite': (248, 248, 255), + 'gold': (255, 215, 0), + 'goldenrod': (218, 165, 32), + 'gray': (128, 128, 128), + 'green': (0, 128, 0), + 'greenyellow': (173, 255, 47), + 'grey': (128, 128, 128), + 'honeydew': (240, 255, 240), + 'hotpink': (255, 105, 180), + 'indianred': (205, 92, 92), + 'indigo': (75, 0, 130), + 'ivory': (255, 255, 240), + 'khaki': (240, 230, 140), + 'lavender': (230, 230, 250), + 'lavenderblush': (255, 240, 245), + 'lawngreen': (124, 252, 0), + 'lemonchiffon': (255, 250, 205), + 'lightblue': (173, 216, 230), + 'lightcoral': (240, 128, 128), + 'lightcyan': (224, 255, 255), + 'lightgoldenrodyellow': (250, 250, 210), + 'lightgray': (211, 211, 211), + 'lightgreen': (144, 238, 144), + 'lightgrey': (211, 211, 211), + 'lightpink': (255, 182, 193), + 'lightsalmon': (255, 160, 122), + 'lightseagreen': (32, 178, 170), + 'lightskyblue': (135, 206, 250), + 'lightslategray': (119, 136, 153), + 'lightslategrey': (119, 136, 153), + 'lightsteelblue': (176, 196, 222), + 'lightyellow': (255, 255, 224), + 'lime': (0, 255, 0), + 'limegreen': (50, 205, 50), + 'linen': (250, 240, 230), + 'magenta': (255, 0, 255), + 'maroon': (128, 0, 0), + 'mediumaquamarine': (102, 205, 170), + 'mediumblue': (0, 0, 205), + 'mediumorchid': (186, 85, 211), + 'mediumpurple': (147, 112, 219), + 'mediumseagreen': (60, 179, 113), + 'mediumslateblue': (123, 104, 238), + 'mediumspringgreen': (0, 250, 154), + 'mediumturquoise': (72, 209, 204), + 'mediumvioletred': (199, 21, 133), + 'midnightblue': (25, 25, 112), + 'mintcream': (245, 255, 250), + 'mistyrose': (255, 228, 
225), + 'moccasin': (255, 228, 181), + 'navajowhite': (255, 222, 173), + 'navy': (0, 0, 128), + 'oldlace': (253, 245, 230), + 'olive': (128, 128, 0), + 'olivedrab': (107, 142, 35), + 'orange': (255, 165, 0), + 'orangered': (255, 69, 0), + 'orchid': (218, 112, 214), + 'palegoldenrod': (238, 232, 170), + 'palegreen': (152, 251, 152), + 'paleturquoise': (175, 238, 238), + 'palevioletred': (219, 112, 147), + 'papayawhip': (255, 239, 213), + 'peachpuff': (255, 218, 185), + 'peru': (205, 133, 63), + 'pink': (255, 192, 203), + 'plum': (221, 160, 221), + 'powderblue': (176, 224, 230), + 'purple': (128, 0, 128), + 'red': (255, 0, 0), + 'rosybrown': (188, 143, 143), + 'royalblue': (65, 105, 225), + 'saddlebrown': (139, 69, 19), + 'salmon': (250, 128, 114), + 'sandybrown': (244, 164, 96), + 'seagreen': (46, 139, 87), + 'seashell': (255, 245, 238), + 'sienna': (160, 82, 45), + 'silver': (192, 192, 192), + 'skyblue': (135, 206, 235), + 'slateblue': (106, 90, 205), + 'slategray': (112, 128, 144), + 'slategrey': (112, 128, 144), + 'snow': (255, 250, 250), + 'springgreen': (0, 255, 127), + 'steelblue': (70, 130, 180), + 'tan': (210, 180, 140), + 'teal': (0, 128, 128), + 'thistle': (216, 191, 216), + 'tomato': (255, 99, 71), + 'turquoise': (64, 224, 208), + 'violet': (238, 130, 238), + 'wheat': (245, 222, 179), + 'white': (255, 255, 255), + 'whitesmoke': (245, 245, 245), + 'yellow': (255, 255, 0), + 'yellowgreen': (154, 205, 50), +} + +COLORS_BY_VALUE = {v: k for k, v in COLORS_BY_NAME.items()} diff --git a/lib/pydantic/v1/config.py b/lib/pydantic/v1/config.py new file mode 100644 index 00000000..a25973af --- /dev/null +++ b/lib/pydantic/v1/config.py @@ -0,0 +1,191 @@ +import json +from enum import Enum +from typing import TYPE_CHECKING, Any, Callable, Dict, ForwardRef, Optional, Tuple, Type, Union + +from typing_extensions import Literal, Protocol + +from .typing import AnyArgTCallable, AnyCallable +from .utils import GetterDict +from .version import compiled + +if TYPE_CHECKING: + from typing import overload + + from .fields import ModelField + from .main import BaseModel + + ConfigType = Type['BaseConfig'] + + class SchemaExtraCallable(Protocol): + @overload + def __call__(self, schema: Dict[str, Any]) -> None: + pass + + @overload + def __call__(self, schema: Dict[str, Any], model_class: Type[BaseModel]) -> None: + pass + +else: + SchemaExtraCallable = Callable[..., None] + +__all__ = 'BaseConfig', 'ConfigDict', 'get_config', 'Extra', 'inherit_config', 'prepare_config' + + +class Extra(str, Enum): + allow = 'allow' + ignore = 'ignore' + forbid = 'forbid' + + +# https://github.com/cython/cython/issues/4003 +# Fixed in Cython 3 and Pydantic v1 won't support Cython 3. +# Pydantic v2 doesn't depend on Cython at all. +if not compiled: + from typing_extensions import TypedDict + + class ConfigDict(TypedDict, total=False): + title: Optional[str] + anystr_lower: bool + anystr_strip_whitespace: bool + min_anystr_length: int + max_anystr_length: Optional[int] + validate_all: bool + extra: Extra + allow_mutation: bool + frozen: bool + allow_population_by_field_name: bool + use_enum_values: bool + fields: Dict[str, Union[str, Dict[str, str]]] + validate_assignment: bool + error_msg_templates: Dict[str, str] + arbitrary_types_allowed: bool + orm_mode: bool + getter_dict: Type[GetterDict] + alias_generator: Optional[Callable[[str], str]] + keep_untouched: Tuple[type, ...] 
+ schema_extra: Union[Dict[str, object], 'SchemaExtraCallable'] + json_loads: Callable[[str], object] + json_dumps: AnyArgTCallable[str] + json_encoders: Dict[Type[object], AnyCallable] + underscore_attrs_are_private: bool + allow_inf_nan: bool + copy_on_model_validation: Literal['none', 'deep', 'shallow'] + # whether dataclass `__post_init__` should be run after validation + post_init_call: Literal['before_validation', 'after_validation'] + +else: + ConfigDict = dict # type: ignore + + +class BaseConfig: + title: Optional[str] = None + anystr_lower: bool = False + anystr_upper: bool = False + anystr_strip_whitespace: bool = False + min_anystr_length: int = 0 + max_anystr_length: Optional[int] = None + validate_all: bool = False + extra: Extra = Extra.ignore + allow_mutation: bool = True + frozen: bool = False + allow_population_by_field_name: bool = False + use_enum_values: bool = False + fields: Dict[str, Union[str, Dict[str, str]]] = {} + validate_assignment: bool = False + error_msg_templates: Dict[str, str] = {} + arbitrary_types_allowed: bool = False + orm_mode: bool = False + getter_dict: Type[GetterDict] = GetterDict + alias_generator: Optional[Callable[[str], str]] = None + keep_untouched: Tuple[type, ...] = () + schema_extra: Union[Dict[str, Any], 'SchemaExtraCallable'] = {} + json_loads: Callable[[str], Any] = json.loads + json_dumps: Callable[..., str] = json.dumps + json_encoders: Dict[Union[Type[Any], str, ForwardRef], AnyCallable] = {} + underscore_attrs_are_private: bool = False + allow_inf_nan: bool = True + + # whether inherited models as fields should be reconstructed as base model, + # and whether such a copy should be shallow or deep + copy_on_model_validation: Literal['none', 'deep', 'shallow'] = 'shallow' + + # whether `Union` should check all allowed types before even trying to coerce + smart_union: bool = False + # whether dataclass `__post_init__` should be run before or after validation + post_init_call: Literal['before_validation', 'after_validation'] = 'before_validation' + + @classmethod + def get_field_info(cls, name: str) -> Dict[str, Any]: + """ + Get properties of FieldInfo from the `fields` property of the config class. + """ + + fields_value = cls.fields.get(name) + + if isinstance(fields_value, str): + field_info: Dict[str, Any] = {'alias': fields_value} + elif isinstance(fields_value, dict): + field_info = fields_value + else: + field_info = {} + + if 'alias' in field_info: + field_info.setdefault('alias_priority', 2) + + if field_info.get('alias_priority', 0) <= 1 and cls.alias_generator: + alias = cls.alias_generator(name) + if not isinstance(alias, str): + raise TypeError(f'Config.alias_generator must return str, not {alias.__class__}') + field_info.update(alias=alias, alias_priority=1) + return field_info + + @classmethod + def prepare_field(cls, field: 'ModelField') -> None: + """ + Optional hook to check or modify fields during model creation. + """ + pass + + +def get_config(config: Union[ConfigDict, Type[object], None]) -> Type[BaseConfig]: + if config is None: + return BaseConfig + + else: + config_dict = ( + config + if isinstance(config, dict) + else {k: getattr(config, k) for k in dir(config) if not k.startswith('__')} + ) + + class Config(BaseConfig): + ... + + for k, v in config_dict.items(): + setattr(Config, k, v) + return Config + + +def inherit_config(self_config: 'ConfigType', parent_config: 'ConfigType', **namespace: Any) -> 'ConfigType': + if not self_config: + base_classes: Tuple['ConfigType', ...] 
= (parent_config,) + elif self_config == parent_config: + base_classes = (self_config,) + else: + base_classes = self_config, parent_config + + namespace['json_encoders'] = { + **getattr(parent_config, 'json_encoders', {}), + **getattr(self_config, 'json_encoders', {}), + **namespace.get('json_encoders', {}), + } + + return type('Config', base_classes, namespace) + + +def prepare_config(config: Type[BaseConfig], cls_name: str) -> None: + if not isinstance(config.extra, Extra): + try: + config.extra = Extra(config.extra) + except ValueError: + raise ValueError(f'"{cls_name}": {config.extra} is not a valid value for "extra"') diff --git a/lib/pydantic/v1/dataclasses.py b/lib/pydantic/v1/dataclasses.py new file mode 100644 index 00000000..2df3987a --- /dev/null +++ b/lib/pydantic/v1/dataclasses.py @@ -0,0 +1,500 @@ +""" +The main purpose is to enhance stdlib dataclasses by adding validation +A pydantic dataclass can be generated from scratch or from a stdlib one. + +Behind the scene, a pydantic dataclass is just like a regular one on which we attach +a `BaseModel` and magic methods to trigger the validation of the data. +`__init__` and `__post_init__` are hence overridden and have extra logic to be +able to validate input data. + +When a pydantic dataclass is generated from scratch, it's just a plain dataclass +with validation triggered at initialization + +The tricky part if for stdlib dataclasses that are converted after into pydantic ones e.g. + +```py +@dataclasses.dataclass +class M: + x: int + +ValidatedM = pydantic.dataclasses.dataclass(M) +``` + +We indeed still want to support equality, hashing, repr, ... as if it was the stdlib one! + +```py +assert isinstance(ValidatedM(x=1), M) +assert ValidatedM(x=1) == M(x=1) +``` + +This means we **don't want to create a new dataclass that inherits from it** +The trick is to create a wrapper around `M` that will act as a proxy to trigger +validation without altering default `M` behaviour. 
+""" +import copy +import dataclasses +import sys +from contextlib import contextmanager +from functools import wraps + +try: + from functools import cached_property +except ImportError: + # cached_property available only for python3.8+ + pass + +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, Optional, Type, TypeVar, Union, overload + +from typing_extensions import dataclass_transform + +from .class_validators import gather_all_validators +from .config import BaseConfig, ConfigDict, Extra, get_config +from .error_wrappers import ValidationError +from .errors import DataclassTypeError +from .fields import Field, FieldInfo, Required, Undefined +from .main import create_model, validate_model +from .utils import ClassAttribute + +if TYPE_CHECKING: + from .main import BaseModel + from .typing import CallableGenerator, NoArgAnyCallable + + DataclassT = TypeVar('DataclassT', bound='Dataclass') + + DataclassClassOrWrapper = Union[Type['Dataclass'], 'DataclassProxy'] + + class Dataclass: + # stdlib attributes + __dataclass_fields__: ClassVar[Dict[str, Any]] + __dataclass_params__: ClassVar[Any] # in reality `dataclasses._DataclassParams` + __post_init__: ClassVar[Callable[..., None]] + + # Added by pydantic + __pydantic_run_validation__: ClassVar[bool] + __post_init_post_parse__: ClassVar[Callable[..., None]] + __pydantic_initialised__: ClassVar[bool] + __pydantic_model__: ClassVar[Type[BaseModel]] + __pydantic_validate_values__: ClassVar[Callable[['Dataclass'], None]] + __pydantic_has_field_info_default__: ClassVar[bool] # whether a `pydantic.Field` is used as default value + + def __init__(self, *args: object, **kwargs: object) -> None: + pass + + @classmethod + def __get_validators__(cls: Type['Dataclass']) -> 'CallableGenerator': + pass + + @classmethod + def __validate__(cls: Type['DataclassT'], v: Any) -> 'DataclassT': + pass + + +__all__ = [ + 'dataclass', + 'set_validation', + 'create_pydantic_model_from_dataclass', + 'is_builtin_dataclass', + 'make_dataclass_validator', +] + +_T = TypeVar('_T') + +if sys.version_info >= (3, 10): + + @dataclass_transform(field_specifiers=(dataclasses.field, Field)) + @overload + def dataclass( + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + config: Union[ConfigDict, Type[object], None] = None, + validate_on_init: Optional[bool] = None, + use_proxy: Optional[bool] = None, + kw_only: bool = ..., + ) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']: + ... + + @dataclass_transform(field_specifiers=(dataclasses.field, Field)) + @overload + def dataclass( + _cls: Type[_T], + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + config: Union[ConfigDict, Type[object], None] = None, + validate_on_init: Optional[bool] = None, + use_proxy: Optional[bool] = None, + kw_only: bool = ..., + ) -> 'DataclassClassOrWrapper': + ... + +else: + + @dataclass_transform(field_specifiers=(dataclasses.field, Field)) + @overload + def dataclass( + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + config: Union[ConfigDict, Type[object], None] = None, + validate_on_init: Optional[bool] = None, + use_proxy: Optional[bool] = None, + ) -> Callable[[Type[_T]], 'DataclassClassOrWrapper']: + ... 
+ + @dataclass_transform(field_specifiers=(dataclasses.field, Field)) + @overload + def dataclass( + _cls: Type[_T], + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + config: Union[ConfigDict, Type[object], None] = None, + validate_on_init: Optional[bool] = None, + use_proxy: Optional[bool] = None, + ) -> 'DataclassClassOrWrapper': + ... + + +@dataclass_transform(field_specifiers=(dataclasses.field, Field)) +def dataclass( + _cls: Optional[Type[_T]] = None, + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, + config: Union[ConfigDict, Type[object], None] = None, + validate_on_init: Optional[bool] = None, + use_proxy: Optional[bool] = None, + kw_only: bool = False, +) -> Union[Callable[[Type[_T]], 'DataclassClassOrWrapper'], 'DataclassClassOrWrapper']: + """ + Like the python standard lib dataclasses but with type validation. + The result is either a pydantic dataclass that will validate input data + or a wrapper that will trigger validation around a stdlib dataclass + to avoid modifying it directly + """ + the_config = get_config(config) + + def wrap(cls: Type[Any]) -> 'DataclassClassOrWrapper': + should_use_proxy = ( + use_proxy + if use_proxy is not None + else ( + is_builtin_dataclass(cls) + and (cls.__bases__[0] is object or set(dir(cls)) == set(dir(cls.__bases__[0]))) + ) + ) + if should_use_proxy: + dc_cls_doc = '' + dc_cls = DataclassProxy(cls) + default_validate_on_init = False + else: + dc_cls_doc = cls.__doc__ or '' # needs to be done before generating dataclass + if sys.version_info >= (3, 10): + dc_cls = dataclasses.dataclass( + cls, + init=init, + repr=repr, + eq=eq, + order=order, + unsafe_hash=unsafe_hash, + frozen=frozen, + kw_only=kw_only, + ) + else: + dc_cls = dataclasses.dataclass( # type: ignore + cls, init=init, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen + ) + default_validate_on_init = True + + should_validate_on_init = default_validate_on_init if validate_on_init is None else validate_on_init + _add_pydantic_validation_attributes(cls, the_config, should_validate_on_init, dc_cls_doc) + dc_cls.__pydantic_model__.__try_update_forward_refs__(**{cls.__name__: cls}) + return dc_cls + + if _cls is None: + return wrap + + return wrap(_cls) + + +@contextmanager +def set_validation(cls: Type['DataclassT'], value: bool) -> Generator[Type['DataclassT'], None, None]: + original_run_validation = cls.__pydantic_run_validation__ + try: + cls.__pydantic_run_validation__ = value + yield cls + finally: + cls.__pydantic_run_validation__ = original_run_validation + + +class DataclassProxy: + __slots__ = '__dataclass__' + + def __init__(self, dc_cls: Type['Dataclass']) -> None: + object.__setattr__(self, '__dataclass__', dc_cls) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + with set_validation(self.__dataclass__, True): + return self.__dataclass__(*args, **kwargs) + + def __getattr__(self, name: str) -> Any: + return getattr(self.__dataclass__, name) + + def __setattr__(self, __name: str, __value: Any) -> None: + return setattr(self.__dataclass__, __name, __value) + + def __instancecheck__(self, instance: Any) -> bool: + return isinstance(instance, self.__dataclass__) + + def __copy__(self) -> 'DataclassProxy': + return DataclassProxy(copy.copy(self.__dataclass__)) + + def __deepcopy__(self, memo: Any) -> 'DataclassProxy': + return 
DataclassProxy(copy.deepcopy(self.__dataclass__, memo)) + + +def _add_pydantic_validation_attributes( # noqa: C901 (ignore complexity) + dc_cls: Type['Dataclass'], + config: Type[BaseConfig], + validate_on_init: bool, + dc_cls_doc: str, +) -> None: + """ + We need to replace the right method. If no `__post_init__` has been set in the stdlib dataclass + it won't even exist (code is generated on the fly by `dataclasses`) + By default, we run validation after `__init__` or `__post_init__` if defined + """ + init = dc_cls.__init__ + + @wraps(init) + def handle_extra_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None: + if config.extra == Extra.ignore: + init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__}) + + elif config.extra == Extra.allow: + for k, v in kwargs.items(): + self.__dict__.setdefault(k, v) + init(self, *args, **{k: v for k, v in kwargs.items() if k in self.__dataclass_fields__}) + + else: + init(self, *args, **kwargs) + + if hasattr(dc_cls, '__post_init__'): + try: + post_init = dc_cls.__post_init__.__wrapped__ # type: ignore[attr-defined] + except AttributeError: + post_init = dc_cls.__post_init__ + + @wraps(post_init) + def new_post_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None: + if config.post_init_call == 'before_validation': + post_init(self, *args, **kwargs) + + if self.__class__.__pydantic_run_validation__: + self.__pydantic_validate_values__() + if hasattr(self, '__post_init_post_parse__'): + self.__post_init_post_parse__(*args, **kwargs) + + if config.post_init_call == 'after_validation': + post_init(self, *args, **kwargs) + + setattr(dc_cls, '__init__', handle_extra_init) + setattr(dc_cls, '__post_init__', new_post_init) + + else: + + @wraps(init) + def new_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None: + handle_extra_init(self, *args, **kwargs) + + if self.__class__.__pydantic_run_validation__: + self.__pydantic_validate_values__() + + if hasattr(self, '__post_init_post_parse__'): + # We need to find again the initvars. 
To do that we use `__dataclass_fields__` instead of + # public method `dataclasses.fields` + + # get all initvars and their default values + initvars_and_values: Dict[str, Any] = {} + for i, f in enumerate(self.__class__.__dataclass_fields__.values()): + if f._field_type is dataclasses._FIELD_INITVAR: # type: ignore[attr-defined] + try: + # set arg value by default + initvars_and_values[f.name] = args[i] + except IndexError: + initvars_and_values[f.name] = kwargs.get(f.name, f.default) + + self.__post_init_post_parse__(**initvars_and_values) + + setattr(dc_cls, '__init__', new_init) + + setattr(dc_cls, '__pydantic_run_validation__', ClassAttribute('__pydantic_run_validation__', validate_on_init)) + setattr(dc_cls, '__pydantic_initialised__', False) + setattr(dc_cls, '__pydantic_model__', create_pydantic_model_from_dataclass(dc_cls, config, dc_cls_doc)) + setattr(dc_cls, '__pydantic_validate_values__', _dataclass_validate_values) + setattr(dc_cls, '__validate__', classmethod(_validate_dataclass)) + setattr(dc_cls, '__get_validators__', classmethod(_get_validators)) + + if dc_cls.__pydantic_model__.__config__.validate_assignment and not dc_cls.__dataclass_params__.frozen: + setattr(dc_cls, '__setattr__', _dataclass_validate_assignment_setattr) + + +def _get_validators(cls: 'DataclassClassOrWrapper') -> 'CallableGenerator': + yield cls.__validate__ + + +def _validate_dataclass(cls: Type['DataclassT'], v: Any) -> 'DataclassT': + with set_validation(cls, True): + if isinstance(v, cls): + v.__pydantic_validate_values__() + return v + elif isinstance(v, (list, tuple)): + return cls(*v) + elif isinstance(v, dict): + return cls(**v) + else: + raise DataclassTypeError(class_name=cls.__name__) + + +def create_pydantic_model_from_dataclass( + dc_cls: Type['Dataclass'], + config: Type[Any] = BaseConfig, + dc_cls_doc: Optional[str] = None, +) -> Type['BaseModel']: + field_definitions: Dict[str, Any] = {} + for field in dataclasses.fields(dc_cls): + default: Any = Undefined + default_factory: Optional['NoArgAnyCallable'] = None + field_info: FieldInfo + + if field.default is not dataclasses.MISSING: + default = field.default + elif field.default_factory is not dataclasses.MISSING: + default_factory = field.default_factory + else: + default = Required + + if isinstance(default, FieldInfo): + field_info = default + dc_cls.__pydantic_has_field_info_default__ = True + else: + field_info = Field(default=default, default_factory=default_factory, **field.metadata) + + field_definitions[field.name] = (field.type, field_info) + + validators = gather_all_validators(dc_cls) + model: Type['BaseModel'] = create_model( + dc_cls.__name__, + __config__=config, + __module__=dc_cls.__module__, + __validators__=validators, + __cls_kwargs__={'__resolve_forward_refs__': False}, + **field_definitions, + ) + model.__doc__ = dc_cls_doc if dc_cls_doc is not None else dc_cls.__doc__ or '' + return model + + +if sys.version_info >= (3, 8): + + def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool: + return isinstance(getattr(type(obj), k, None), cached_property) + +else: + + def _is_field_cached_property(obj: 'Dataclass', k: str) -> bool: + return False + + +def _dataclass_validate_values(self: 'Dataclass') -> None: + # validation errors can occur if this function is called twice on an already initialised dataclass. 
+ # for example if Extra.forbid is enabled, it would consider __pydantic_initialised__ an invalid extra property + if getattr(self, '__pydantic_initialised__'): + return + if getattr(self, '__pydantic_has_field_info_default__', False): + # We need to remove `FieldInfo` values since they are not valid as input + # It's ok to do that because they are obviously the default values! + input_data = { + k: v + for k, v in self.__dict__.items() + if not (isinstance(v, FieldInfo) or _is_field_cached_property(self, k)) + } + else: + input_data = {k: v for k, v in self.__dict__.items() if not _is_field_cached_property(self, k)} + d, _, validation_error = validate_model(self.__pydantic_model__, input_data, cls=self.__class__) + if validation_error: + raise validation_error + self.__dict__.update(d) + object.__setattr__(self, '__pydantic_initialised__', True) + + +def _dataclass_validate_assignment_setattr(self: 'Dataclass', name: str, value: Any) -> None: + if self.__pydantic_initialised__: + d = dict(self.__dict__) + d.pop(name, None) + known_field = self.__pydantic_model__.__fields__.get(name, None) + if known_field: + value, error_ = known_field.validate(value, d, loc=name, cls=self.__class__) + if error_: + raise ValidationError([error_], self.__class__) + + object.__setattr__(self, name, value) + + +def is_builtin_dataclass(_cls: Type[Any]) -> bool: + """ + Whether a class is a stdlib dataclass + (useful to discriminated a pydantic dataclass that is actually a wrapper around a stdlib dataclass) + + we check that + - `_cls` is a dataclass + - `_cls` is not a processed pydantic dataclass (with a basemodel attached) + - `_cls` is not a pydantic dataclass inheriting directly from a stdlib dataclass + e.g. + ``` + @dataclasses.dataclass + class A: + x: int + + @pydantic.dataclasses.dataclass + class B(A): + y: int + ``` + In this case, when we first check `B`, we make an extra check and look at the annotations ('y'), + which won't be a superset of all the dataclass fields (only the stdlib fields i.e. 'x') + """ + return ( + dataclasses.is_dataclass(_cls) + and not hasattr(_cls, '__pydantic_model__') + and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {}))) + ) + + +def make_dataclass_validator(dc_cls: Type['Dataclass'], config: Type[BaseConfig]) -> 'CallableGenerator': + """ + Create a pydantic.dataclass from a builtin dataclass to add type validation + and yield the validators + It retrieves the parameters of the dataclass and forwards them to the newly created dataclass + """ + yield from _get_validators(dataclass(dc_cls, config=config, use_proxy=True)) diff --git a/lib/pydantic/v1/datetime_parse.py b/lib/pydantic/v1/datetime_parse.py new file mode 100644 index 00000000..cfd54593 --- /dev/null +++ b/lib/pydantic/v1/datetime_parse.py @@ -0,0 +1,248 @@ +""" +Functions to parse datetime objects. + +We're using regular expressions rather than time.strptime because: +- They provide both validation and parsing. +- They're more flexible for datetimes. +- The date/datetime/time constructors produce friendlier error messages. 
+
+Stolen from https://raw.githubusercontent.com/django/django/main/django/utils/dateparse.py at
+9718fa2e8abe430c3526a9278dd976443d4ae3c6
+
+Changed to:
+* use standard python datetime types not django.utils.timezone
+* raise ValueError when regex doesn't match rather than returning None
+* support parsing unix timestamps for dates and datetimes
+"""
+import re
+from datetime import date, datetime, time, timedelta, timezone
+from typing import Dict, Optional, Type, Union
+
+from . import errors
+
+date_expr = r'(?P<year>\d{4})-(?P<month>\d{1,2})-(?P<day>\d{1,2})'
+time_expr = (
+    r'(?P<hour>\d{1,2}):(?P<minute>\d{1,2})'
+    r'(?::(?P<second>\d{1,2})(?:\.(?P<microsecond>\d{1,6})\d{0,6})?)?'
+    r'(?P<tzinfo>Z|[+-]\d{2}(?::?\d{2})?)?$'
+)
+
+date_re = re.compile(f'{date_expr}$')
+time_re = re.compile(time_expr)
+datetime_re = re.compile(f'{date_expr}[T ]{time_expr}')
+
+standard_duration_re = re.compile(
+    r'^'
+    r'(?:(?P<days>-?\d+) (days?, )?)?'
+    r'((?:(?P<hours>-?\d+):)(?=\d+:\d+))?'
+    r'(?:(?P<minutes>-?\d+):)?'
+    r'(?P<seconds>-?\d+)'
+    r'(?:\.(?P<microseconds>\d{1,6})\d{0,6})?'
+    r'$'
+)
+
+# Support the sections of ISO 8601 date representation that are accepted by timedelta
+iso8601_duration_re = re.compile(
+    r'^(?P<sign>[-+]?)'
+    r'P'
+    r'(?:(?P<days>\d+(.\d+)?)D)?'
+    r'(?:T'
+    r'(?:(?P<hours>\d+(.\d+)?)H)?'
+    r'(?:(?P<minutes>\d+(.\d+)?)M)?'
+    r'(?:(?P<seconds>\d+(.\d+)?)S)?'
+    r')?'
+    r'$'
+)
+
+EPOCH = datetime(1970, 1, 1)
+# if greater than this, the number is in ms, if less than or equal it's in seconds
+# (in seconds this is 11th October 2603, in ms it's 20th August 1970)
+MS_WATERSHED = int(2e10)
+# slightly more than datetime.max in ns - (datetime.max - EPOCH).total_seconds() * 1e9
+MAX_NUMBER = int(3e20)
+StrBytesIntFloat = Union[str, bytes, int, float]
+
+
+def get_numeric(value: StrBytesIntFloat, native_expected_type: str) -> Union[None, int, float]:
+    if isinstance(value, (int, float)):
+        return value
+    try:
+        return float(value)
+    except ValueError:
+        return None
+    except TypeError:
+        raise TypeError(f'invalid type; expected {native_expected_type}, string, bytes, int or float')
+
+
+def from_unix_seconds(seconds: Union[int, float]) -> datetime:
+    if seconds > MAX_NUMBER:
+        return datetime.max
+    elif seconds < -MAX_NUMBER:
+        return datetime.min
+
+    while abs(seconds) > MS_WATERSHED:
+        seconds /= 1000
+    dt = EPOCH + timedelta(seconds=seconds)
+    return dt.replace(tzinfo=timezone.utc)
+
+
+def _parse_timezone(value: Optional[str], error: Type[Exception]) -> Union[None, int, timezone]:
+    if value == 'Z':
+        return timezone.utc
+    elif value is not None:
+        offset_mins = int(value[-2:]) if len(value) > 3 else 0
+        offset = 60 * int(value[1:3]) + offset_mins
+        if value[0] == '-':
+            offset = -offset
+        try:
+            return timezone(timedelta(minutes=offset))
+        except ValueError:
+            raise error()
+    else:
+        return None
+
+
+def parse_date(value: Union[date, StrBytesIntFloat]) -> date:
+    """
+    Parse a date/int/float/string and return a datetime.date.
+
+    Raise ValueError if the input is well formatted but not a valid date.
+    Raise ValueError if the input isn't well formatted.
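Review note (illustration only, not part of the patch): a quick sketch of the MS_WATERSHED behaviour defined above, where numeric timestamps larger than 2e10 are interpreted as milliseconds and scaled down. The import path assumes the vendored package is importable as `pydantic.v1`.

    from pydantic.v1.datetime_parse import from_unix_seconds

    # 1_600_000_000 is below MS_WATERSHED, so it is read as seconds since the epoch
    print(from_unix_seconds(1_600_000_000))      # 2020-09-13 12:26:40+00:00
    # 1_600_000_000_000 exceeds MS_WATERSHED, so it is divided by 1000 -> same instant
    print(from_unix_seconds(1_600_000_000_000))  # 2020-09-13 12:26:40+00:00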
+ """ + if isinstance(value, date): + if isinstance(value, datetime): + return value.date() + else: + return value + + number = get_numeric(value, 'date') + if number is not None: + return from_unix_seconds(number).date() + + if isinstance(value, bytes): + value = value.decode() + + match = date_re.match(value) # type: ignore + if match is None: + raise errors.DateError() + + kw = {k: int(v) for k, v in match.groupdict().items()} + + try: + return date(**kw) + except ValueError: + raise errors.DateError() + + +def parse_time(value: Union[time, StrBytesIntFloat]) -> time: + """ + Parse a time/string and return a datetime.time. + + Raise ValueError if the input is well formatted but not a valid time. + Raise ValueError if the input isn't well formatted, in particular if it contains an offset. + """ + if isinstance(value, time): + return value + + number = get_numeric(value, 'time') + if number is not None: + if number >= 86400: + # doesn't make sense since the time time loop back around to 0 + raise errors.TimeError() + return (datetime.min + timedelta(seconds=number)).time() + + if isinstance(value, bytes): + value = value.decode() + + match = time_re.match(value) # type: ignore + if match is None: + raise errors.TimeError() + + kw = match.groupdict() + if kw['microsecond']: + kw['microsecond'] = kw['microsecond'].ljust(6, '0') + + tzinfo = _parse_timezone(kw.pop('tzinfo'), errors.TimeError) + kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None} + kw_['tzinfo'] = tzinfo + + try: + return time(**kw_) # type: ignore + except ValueError: + raise errors.TimeError() + + +def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: + """ + Parse a datetime/int/float/string and return a datetime.datetime. + + This function supports time zone offsets. When the input contains one, + the output uses a timezone with a fixed offset from UTC. + + Raise ValueError if the input is well formatted but not a valid datetime. + Raise ValueError if the input isn't well formatted. + """ + if isinstance(value, datetime): + return value + + number = get_numeric(value, 'datetime') + if number is not None: + return from_unix_seconds(number) + + if isinstance(value, bytes): + value = value.decode() + + match = datetime_re.match(value) # type: ignore + if match is None: + raise errors.DateTimeError() + + kw = match.groupdict() + if kw['microsecond']: + kw['microsecond'] = kw['microsecond'].ljust(6, '0') + + tzinfo = _parse_timezone(kw.pop('tzinfo'), errors.DateTimeError) + kw_: Dict[str, Union[None, int, timezone]] = {k: int(v) for k, v in kw.items() if v is not None} + kw_['tzinfo'] = tzinfo + + try: + return datetime(**kw_) # type: ignore + except ValueError: + raise errors.DateTimeError() + + +def parse_duration(value: StrBytesIntFloat) -> timedelta: + """ + Parse a duration int/float/string and return a datetime.timedelta. + + The preferred format for durations in Django is '%d %H:%M:%S.%f'. + + Also supports ISO 8601 representation. 
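Review note (illustration only, not part of the patch): a short usage sketch for the parsers defined in this file, assuming the vendored package is importable as `pydantic.v1`.

    from pydantic.v1.datetime_parse import parse_date, parse_datetime, parse_duration

    parse_date('2024-03-24')                # datetime.date(2024, 3, 24)
    parse_datetime('2024-03-24T17:55:12Z')  # timezone-aware datetime (UTC)
    parse_duration('1 05:30:00')            # timedelta(days=1, seconds=19800)
    parse_duration('P1DT5H30M')             # same duration, ISO 8601 form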
+ """ + if isinstance(value, timedelta): + return value + + if isinstance(value, (int, float)): + # below code requires a string + value = f'{value:f}' + elif isinstance(value, bytes): + value = value.decode() + + try: + match = standard_duration_re.match(value) or iso8601_duration_re.match(value) + except TypeError: + raise TypeError('invalid type; expected timedelta, string, bytes, int or float') + + if not match: + raise errors.DurationError() + + kw = match.groupdict() + sign = -1 if kw.pop('sign', '+') == '-' else 1 + if kw.get('microseconds'): + kw['microseconds'] = kw['microseconds'].ljust(6, '0') + + if kw.get('seconds') and kw.get('microseconds') and kw['seconds'].startswith('-'): + kw['microseconds'] = '-' + kw['microseconds'] + + kw_ = {k: float(v) for k, v in kw.items() if v is not None} + + return sign * timedelta(**kw_) diff --git a/lib/pydantic/v1/decorator.py b/lib/pydantic/v1/decorator.py new file mode 100644 index 00000000..089aab65 --- /dev/null +++ b/lib/pydantic/v1/decorator.py @@ -0,0 +1,264 @@ +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload + +from . import validator +from .config import Extra +from .errors import ConfigError +from .main import BaseModel, create_model +from .typing import get_all_type_hints +from .utils import to_camel + +__all__ = ('validate_arguments',) + +if TYPE_CHECKING: + from .typing import AnyCallable + + AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable) + ConfigType = Union[None, Type[Any], Dict[str, Any]] + + +@overload +def validate_arguments(func: None = None, *, config: 'ConfigType' = None) -> Callable[['AnyCallableT'], 'AnyCallableT']: + ... + + +@overload +def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT': + ... + + +def validate_arguments(func: Optional['AnyCallableT'] = None, *, config: 'ConfigType' = None) -> Any: + """ + Decorator to validate the arguments passed to a function. 
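Review note (illustration only, not part of the patch): a minimal sketch of the decorator's behaviour, assuming the vendored package is importable as `pydantic.v1`.

    from pydantic.v1 import ValidationError, validate_arguments

    @validate_arguments
    def repeat(word: str, times: int) -> str:
        return word * times

    repeat('ab', 3)        # 'ababab'; '3' would also be accepted and coerced to int
    try:
        repeat('ab', 'x')  # 'x' cannot be coerced to int
    except ValidationError as exc:
        print(exc)         # 1 validation error for Repeat ...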
+ """ + + def validate(_func: 'AnyCallable') -> 'AnyCallable': + vd = ValidatedFunction(_func, config) + + @wraps(_func) + def wrapper_function(*args: Any, **kwargs: Any) -> Any: + return vd.call(*args, **kwargs) + + wrapper_function.vd = vd # type: ignore + wrapper_function.validate = vd.init_model_instance # type: ignore + wrapper_function.raw_function = vd.raw_function # type: ignore + wrapper_function.model = vd.model # type: ignore + return wrapper_function + + if func: + return validate(func) + else: + return validate + + +ALT_V_ARGS = 'v__args' +ALT_V_KWARGS = 'v__kwargs' +V_POSITIONAL_ONLY_NAME = 'v__positional_only' +V_DUPLICATE_KWARGS = 'v__duplicate_kwargs' + + +class ValidatedFunction: + def __init__(self, function: 'AnyCallableT', config: 'ConfigType'): # noqa C901 + from inspect import Parameter, signature + + parameters: Mapping[str, Parameter] = signature(function).parameters + + if parameters.keys() & {ALT_V_ARGS, ALT_V_KWARGS, V_POSITIONAL_ONLY_NAME, V_DUPLICATE_KWARGS}: + raise ConfigError( + f'"{ALT_V_ARGS}", "{ALT_V_KWARGS}", "{V_POSITIONAL_ONLY_NAME}" and "{V_DUPLICATE_KWARGS}" ' + f'are not permitted as argument names when using the "{validate_arguments.__name__}" decorator' + ) + + self.raw_function = function + self.arg_mapping: Dict[int, str] = {} + self.positional_only_args = set() + self.v_args_name = 'args' + self.v_kwargs_name = 'kwargs' + + type_hints = get_all_type_hints(function) + takes_args = False + takes_kwargs = False + fields: Dict[str, Tuple[Any, Any]] = {} + for i, (name, p) in enumerate(parameters.items()): + if p.annotation is p.empty: + annotation = Any + else: + annotation = type_hints[name] + + default = ... if p.default is p.empty else p.default + if p.kind == Parameter.POSITIONAL_ONLY: + self.arg_mapping[i] = name + fields[name] = annotation, default + fields[V_POSITIONAL_ONLY_NAME] = List[str], None + self.positional_only_args.add(name) + elif p.kind == Parameter.POSITIONAL_OR_KEYWORD: + self.arg_mapping[i] = name + fields[name] = annotation, default + fields[V_DUPLICATE_KWARGS] = List[str], None + elif p.kind == Parameter.KEYWORD_ONLY: + fields[name] = annotation, default + elif p.kind == Parameter.VAR_POSITIONAL: + self.v_args_name = name + fields[name] = Tuple[annotation, ...], None + takes_args = True + else: + assert p.kind == Parameter.VAR_KEYWORD, p.kind + self.v_kwargs_name = name + fields[name] = Dict[str, annotation], None # type: ignore + takes_kwargs = True + + # these checks avoid a clash between "args" and a field with that name + if not takes_args and self.v_args_name in fields: + self.v_args_name = ALT_V_ARGS + + # same with "kwargs" + if not takes_kwargs and self.v_kwargs_name in fields: + self.v_kwargs_name = ALT_V_KWARGS + + if not takes_args: + # we add the field so validation below can raise the correct exception + fields[self.v_args_name] = List[Any], None + + if not takes_kwargs: + # same with kwargs + fields[self.v_kwargs_name] = Dict[Any, Any], None + + self.create_model(fields, takes_args, takes_kwargs, config) + + def init_model_instance(self, *args: Any, **kwargs: Any) -> BaseModel: + values = self.build_values(args, kwargs) + return self.model(**values) + + def call(self, *args: Any, **kwargs: Any) -> Any: + m = self.init_model_instance(*args, **kwargs) + return self.execute(m) + + def build_values(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]: + values: Dict[str, Any] = {} + if args: + arg_iter = enumerate(args) + while True: + try: + i, a = next(arg_iter) + except StopIteration: + break 
+ arg_name = self.arg_mapping.get(i) + if arg_name is not None: + values[arg_name] = a + else: + values[self.v_args_name] = [a] + [a for _, a in arg_iter] + break + + var_kwargs: Dict[str, Any] = {} + wrong_positional_args = [] + duplicate_kwargs = [] + fields_alias = [ + field.alias + for name, field in self.model.__fields__.items() + if name not in (self.v_args_name, self.v_kwargs_name) + ] + non_var_fields = set(self.model.__fields__) - {self.v_args_name, self.v_kwargs_name} + for k, v in kwargs.items(): + if k in non_var_fields or k in fields_alias: + if k in self.positional_only_args: + wrong_positional_args.append(k) + if k in values: + duplicate_kwargs.append(k) + values[k] = v + else: + var_kwargs[k] = v + + if var_kwargs: + values[self.v_kwargs_name] = var_kwargs + if wrong_positional_args: + values[V_POSITIONAL_ONLY_NAME] = wrong_positional_args + if duplicate_kwargs: + values[V_DUPLICATE_KWARGS] = duplicate_kwargs + return values + + def execute(self, m: BaseModel) -> Any: + d = {k: v for k, v in m._iter() if k in m.__fields_set__ or m.__fields__[k].default_factory} + var_kwargs = d.pop(self.v_kwargs_name, {}) + + if self.v_args_name in d: + args_: List[Any] = [] + in_kwargs = False + kwargs = {} + for name, value in d.items(): + if in_kwargs: + kwargs[name] = value + elif name == self.v_args_name: + args_ += value + in_kwargs = True + else: + args_.append(value) + return self.raw_function(*args_, **kwargs, **var_kwargs) + elif self.positional_only_args: + args_ = [] + kwargs = {} + for name, value in d.items(): + if name in self.positional_only_args: + args_.append(value) + else: + kwargs[name] = value + return self.raw_function(*args_, **kwargs, **var_kwargs) + else: + return self.raw_function(**d, **var_kwargs) + + def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None: + pos_args = len(self.arg_mapping) + + class CustomConfig: + pass + + if not TYPE_CHECKING: # pragma: no branch + if isinstance(config, dict): + CustomConfig = type('Config', (), config) # noqa: F811 + elif config is not None: + CustomConfig = config # noqa: F811 + + if hasattr(CustomConfig, 'fields') or hasattr(CustomConfig, 'alias_generator'): + raise ConfigError( + 'Setting the "fields" and "alias_generator" property on custom Config for ' + '@validate_arguments is not yet supported, please remove.' 
+ ) + + class DecoratorBaseModel(BaseModel): + @validator(self.v_args_name, check_fields=False, allow_reuse=True) + def check_args(cls, v: Optional[List[Any]]) -> Optional[List[Any]]: + if takes_args or v is None: + return v + + raise TypeError(f'{pos_args} positional arguments expected but {pos_args + len(v)} given') + + @validator(self.v_kwargs_name, check_fields=False, allow_reuse=True) + def check_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + if takes_kwargs or v is None: + return v + + plural = '' if len(v) == 1 else 's' + keys = ', '.join(map(repr, v.keys())) + raise TypeError(f'unexpected keyword argument{plural}: {keys}') + + @validator(V_POSITIONAL_ONLY_NAME, check_fields=False, allow_reuse=True) + def check_positional_only(cls, v: Optional[List[str]]) -> None: + if v is None: + return + + plural = '' if len(v) == 1 else 's' + keys = ', '.join(map(repr, v)) + raise TypeError(f'positional-only argument{plural} passed as keyword argument{plural}: {keys}') + + @validator(V_DUPLICATE_KWARGS, check_fields=False, allow_reuse=True) + def check_duplicate_kwargs(cls, v: Optional[List[str]]) -> None: + if v is None: + return + + plural = '' if len(v) == 1 else 's' + keys = ', '.join(map(repr, v)) + raise TypeError(f'multiple values for argument{plural}: {keys}') + + class Config(CustomConfig): + extra = getattr(CustomConfig, 'extra', Extra.forbid) + + self.model = create_model(to_camel(self.raw_function.__name__), __base__=DecoratorBaseModel, **fields) diff --git a/lib/pydantic/v1/env_settings.py b/lib/pydantic/v1/env_settings.py new file mode 100644 index 00000000..6c446e51 --- /dev/null +++ b/lib/pydantic/v1/env_settings.py @@ -0,0 +1,350 @@ +import os +import warnings +from pathlib import Path +from typing import AbstractSet, Any, Callable, ClassVar, Dict, List, Mapping, Optional, Tuple, Type, Union + +from .config import BaseConfig, Extra +from .fields import ModelField +from .main import BaseModel +from .types import JsonWrapper +from .typing import StrPath, display_as_type, get_origin, is_union +from .utils import deep_update, lenient_issubclass, path_type, sequence_like + +env_file_sentinel = str(object()) + +SettingsSourceCallable = Callable[['BaseSettings'], Dict[str, Any]] +DotenvType = Union[StrPath, List[StrPath], Tuple[StrPath, ...]] + + +class SettingsError(ValueError): + pass + + +class BaseSettings(BaseModel): + """ + Base class for settings, allowing values to be overridden by environment variables. + + This is useful in production for secrets you do not wish to save in code, it plays nicely with docker(-compose), + Heroku and any 12 factor app design. 
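Review note (illustration only, not part of the patch): a minimal sketch of environment-variable overrides with the BaseSettings class being added here; the APP_ prefix and field names are invented for the example, and the import assumes the vendored package is importable as `pydantic.v1`.

    import os
    from pydantic.v1.env_settings import BaseSettings

    class AppSettings(BaseSettings):
        api_key: str = 'dev-key'
        debug: bool = False

        class Config:
            env_prefix = 'APP_'   # fields are read from APP_API_KEY, APP_DEBUG, ...

    os.environ['APP_DEBUG'] = 'true'
    print(AppSettings().debug)    # True -- the environment variable overrides the default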
+ """ + + def __init__( + __pydantic_self__, + _env_file: Optional[DotenvType] = env_file_sentinel, + _env_file_encoding: Optional[str] = None, + _env_nested_delimiter: Optional[str] = None, + _secrets_dir: Optional[StrPath] = None, + **values: Any, + ) -> None: + # Uses something other than `self` the first arg to allow "self" as a settable attribute + super().__init__( + **__pydantic_self__._build_values( + values, + _env_file=_env_file, + _env_file_encoding=_env_file_encoding, + _env_nested_delimiter=_env_nested_delimiter, + _secrets_dir=_secrets_dir, + ) + ) + + def _build_values( + self, + init_kwargs: Dict[str, Any], + _env_file: Optional[DotenvType] = None, + _env_file_encoding: Optional[str] = None, + _env_nested_delimiter: Optional[str] = None, + _secrets_dir: Optional[StrPath] = None, + ) -> Dict[str, Any]: + # Configure built-in sources + init_settings = InitSettingsSource(init_kwargs=init_kwargs) + env_settings = EnvSettingsSource( + env_file=(_env_file if _env_file != env_file_sentinel else self.__config__.env_file), + env_file_encoding=( + _env_file_encoding if _env_file_encoding is not None else self.__config__.env_file_encoding + ), + env_nested_delimiter=( + _env_nested_delimiter if _env_nested_delimiter is not None else self.__config__.env_nested_delimiter + ), + env_prefix_len=len(self.__config__.env_prefix), + ) + file_secret_settings = SecretsSettingsSource(secrets_dir=_secrets_dir or self.__config__.secrets_dir) + # Provide a hook to set built-in sources priority and add / remove sources + sources = self.__config__.customise_sources( + init_settings=init_settings, env_settings=env_settings, file_secret_settings=file_secret_settings + ) + if sources: + return deep_update(*reversed([source(self) for source in sources])) + else: + # no one should mean to do this, but I think returning an empty dict is marginally preferable + # to an informative error and much better than a confusing error + return {} + + class Config(BaseConfig): + env_prefix: str = '' + env_file: Optional[DotenvType] = None + env_file_encoding: Optional[str] = None + env_nested_delimiter: Optional[str] = None + secrets_dir: Optional[StrPath] = None + validate_all: bool = True + extra: Extra = Extra.forbid + arbitrary_types_allowed: bool = True + case_sensitive: bool = False + + @classmethod + def prepare_field(cls, field: ModelField) -> None: + env_names: Union[List[str], AbstractSet[str]] + field_info_from_config = cls.get_field_info(field.name) + + env = field_info_from_config.get('env') or field.field_info.extra.get('env') + if env is None: + if field.has_alias: + warnings.warn( + 'aliases are no longer used by BaseSettings to define which environment variables to read. ' + 'Instead use the "env" field setting. 
' + 'See https://pydantic-docs.helpmanual.io/usage/settings/#environment-variable-names', + FutureWarning, + ) + env_names = {cls.env_prefix + field.name} + elif isinstance(env, str): + env_names = {env} + elif isinstance(env, (set, frozenset)): + env_names = env + elif sequence_like(env): + env_names = list(env) + else: + raise TypeError(f'invalid field env: {env!r} ({display_as_type(env)}); should be string, list or set') + + if not cls.case_sensitive: + env_names = env_names.__class__(n.lower() for n in env_names) + field.field_info.extra['env_names'] = env_names + + @classmethod + def customise_sources( + cls, + init_settings: SettingsSourceCallable, + env_settings: SettingsSourceCallable, + file_secret_settings: SettingsSourceCallable, + ) -> Tuple[SettingsSourceCallable, ...]: + return init_settings, env_settings, file_secret_settings + + @classmethod + def parse_env_var(cls, field_name: str, raw_val: str) -> Any: + return cls.json_loads(raw_val) + + # populated by the metaclass using the Config class defined above, annotated here to help IDEs only + __config__: ClassVar[Type[Config]] + + +class InitSettingsSource: + __slots__ = ('init_kwargs',) + + def __init__(self, init_kwargs: Dict[str, Any]): + self.init_kwargs = init_kwargs + + def __call__(self, settings: BaseSettings) -> Dict[str, Any]: + return self.init_kwargs + + def __repr__(self) -> str: + return f'InitSettingsSource(init_kwargs={self.init_kwargs!r})' + + +class EnvSettingsSource: + __slots__ = ('env_file', 'env_file_encoding', 'env_nested_delimiter', 'env_prefix_len') + + def __init__( + self, + env_file: Optional[DotenvType], + env_file_encoding: Optional[str], + env_nested_delimiter: Optional[str] = None, + env_prefix_len: int = 0, + ): + self.env_file: Optional[DotenvType] = env_file + self.env_file_encoding: Optional[str] = env_file_encoding + self.env_nested_delimiter: Optional[str] = env_nested_delimiter + self.env_prefix_len: int = env_prefix_len + + def __call__(self, settings: BaseSettings) -> Dict[str, Any]: # noqa C901 + """ + Build environment variables suitable for passing to the Model. 
+ """ + d: Dict[str, Any] = {} + + if settings.__config__.case_sensitive: + env_vars: Mapping[str, Optional[str]] = os.environ + else: + env_vars = {k.lower(): v for k, v in os.environ.items()} + + dotenv_vars = self._read_env_files(settings.__config__.case_sensitive) + if dotenv_vars: + env_vars = {**dotenv_vars, **env_vars} + + for field in settings.__fields__.values(): + env_val: Optional[str] = None + for env_name in field.field_info.extra['env_names']: + env_val = env_vars.get(env_name) + if env_val is not None: + break + + is_complex, allow_parse_failure = self.field_is_complex(field) + if is_complex: + if env_val is None: + # field is complex but no value found so far, try explode_env_vars + env_val_built = self.explode_env_vars(field, env_vars) + if env_val_built: + d[field.alias] = env_val_built + else: + # field is complex and there's a value, decode that as JSON, then add explode_env_vars + try: + env_val = settings.__config__.parse_env_var(field.name, env_val) + except ValueError as e: + if not allow_parse_failure: + raise SettingsError(f'error parsing env var "{env_name}"') from e + + if isinstance(env_val, dict): + d[field.alias] = deep_update(env_val, self.explode_env_vars(field, env_vars)) + else: + d[field.alias] = env_val + elif env_val is not None: + # simplest case, field is not complex, we only need to add the value if it was found + d[field.alias] = env_val + + return d + + def _read_env_files(self, case_sensitive: bool) -> Dict[str, Optional[str]]: + env_files = self.env_file + if env_files is None: + return {} + + if isinstance(env_files, (str, os.PathLike)): + env_files = [env_files] + + dotenv_vars = {} + for env_file in env_files: + env_path = Path(env_file).expanduser() + if env_path.is_file(): + dotenv_vars.update( + read_env_file(env_path, encoding=self.env_file_encoding, case_sensitive=case_sensitive) + ) + + return dotenv_vars + + def field_is_complex(self, field: ModelField) -> Tuple[bool, bool]: + """ + Find out if a field is complex, and if so whether JSON errors should be ignored + """ + if lenient_issubclass(field.annotation, JsonWrapper): + return False, False + + if field.is_complex(): + allow_parse_failure = False + elif is_union(get_origin(field.type_)) and field.sub_fields and any(f.is_complex() for f in field.sub_fields): + allow_parse_failure = True + else: + return False, False + + return True, allow_parse_failure + + def explode_env_vars(self, field: ModelField, env_vars: Mapping[str, Optional[str]]) -> Dict[str, Any]: + """ + Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries. + + This is applied to a single field, hence filtering by env_var prefix. 
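Review note (illustration only, not part of the patch): a sketch of how the nested delimiter described above turns flat environment variables into nested values; the field and variable names are invented for the example, and the import assumes the vendored package is importable as `pydantic.v1`.

    import os
    from typing import Dict
    from pydantic.v1.env_settings import BaseSettings

    class Settings(BaseSettings):
        database: Dict[str, str] = {}

        class Config:
            env_nested_delimiter = '__'

    os.environ['DATABASE__HOST'] = 'localhost'
    os.environ['DATABASE__PORT'] = '5432'
    print(Settings().database)    # {'host': 'localhost', 'port': '5432'}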
+ """ + prefixes = [f'{env_name}{self.env_nested_delimiter}' for env_name in field.field_info.extra['env_names']] + result: Dict[str, Any] = {} + for env_name, env_val in env_vars.items(): + if not any(env_name.startswith(prefix) for prefix in prefixes): + continue + # we remove the prefix before splitting in case the prefix has characters in common with the delimiter + env_name_without_prefix = env_name[self.env_prefix_len :] + _, *keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter) + env_var = result + for key in keys: + env_var = env_var.setdefault(key, {}) + env_var[last_key] = env_val + + return result + + def __repr__(self) -> str: + return ( + f'EnvSettingsSource(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, ' + f'env_nested_delimiter={self.env_nested_delimiter!r})' + ) + + +class SecretsSettingsSource: + __slots__ = ('secrets_dir',) + + def __init__(self, secrets_dir: Optional[StrPath]): + self.secrets_dir: Optional[StrPath] = secrets_dir + + def __call__(self, settings: BaseSettings) -> Dict[str, Any]: + """ + Build fields from "secrets" files. + """ + secrets: Dict[str, Optional[str]] = {} + + if self.secrets_dir is None: + return secrets + + secrets_path = Path(self.secrets_dir).expanduser() + + if not secrets_path.exists(): + warnings.warn(f'directory "{secrets_path}" does not exist') + return secrets + + if not secrets_path.is_dir(): + raise SettingsError(f'secrets_dir must reference a directory, not a {path_type(secrets_path)}') + + for field in settings.__fields__.values(): + for env_name in field.field_info.extra['env_names']: + path = find_case_path(secrets_path, env_name, settings.__config__.case_sensitive) + if not path: + # path does not exist, we currently don't return a warning for this + continue + + if path.is_file(): + secret_value = path.read_text().strip() + if field.is_complex(): + try: + secret_value = settings.__config__.parse_env_var(field.name, secret_value) + except ValueError as e: + raise SettingsError(f'error parsing env var "{env_name}"') from e + + secrets[field.alias] = secret_value + else: + warnings.warn( + f'attempted to load secret file "{path}" but found a {path_type(path)} instead.', + stacklevel=4, + ) + return secrets + + def __repr__(self) -> str: + return f'SecretsSettingsSource(secrets_dir={self.secrets_dir!r})' + + +def read_env_file( + file_path: StrPath, *, encoding: str = None, case_sensitive: bool = False +) -> Dict[str, Optional[str]]: + try: + from dotenv import dotenv_values + except ImportError as e: + raise ImportError('python-dotenv is not installed, run `pip install pydantic[dotenv]`') from e + + file_vars: Dict[str, Optional[str]] = dotenv_values(file_path, encoding=encoding or 'utf8') + if not case_sensitive: + return {k.lower(): v for k, v in file_vars.items()} + else: + return file_vars + + +def find_case_path(dir_path: Path, file_name: str, case_sensitive: bool) -> Optional[Path]: + """ + Find a file within path's directory matching filename, optionally ignoring case. 
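Review note (illustration only, not part of the patch): a sketch of the helper defined below, which SecretsSettingsSource uses to locate secret files case-insensitively; the /run/secrets directory and file name are invented for the example.

    from pathlib import Path
    from pydantic.v1.env_settings import find_case_path

    # returns the Path whose name matches 'DB_PASSWORD' ignoring case, or None
    secret = find_case_path(Path('/run/secrets'), 'DB_PASSWORD', case_sensitive=False)
    if secret is not None:
        print(secret.read_text().strip())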
+ """ + for f in dir_path.iterdir(): + if f.name == file_name: + return f + elif not case_sensitive and f.name.lower() == file_name.lower(): + return f + return None diff --git a/lib/pydantic/v1/error_wrappers.py b/lib/pydantic/v1/error_wrappers.py new file mode 100644 index 00000000..d89a500c --- /dev/null +++ b/lib/pydantic/v1/error_wrappers.py @@ -0,0 +1,161 @@ +import json +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple, Type, Union + +from .json import pydantic_encoder +from .utils import Representation + +if TYPE_CHECKING: + from typing_extensions import TypedDict + + from .config import BaseConfig + from .types import ModelOrDc + from .typing import ReprArgs + + Loc = Tuple[Union[int, str], ...] + + class _ErrorDictRequired(TypedDict): + loc: Loc + msg: str + type: str + + class ErrorDict(_ErrorDictRequired, total=False): + ctx: Dict[str, Any] + + +__all__ = 'ErrorWrapper', 'ValidationError' + + +class ErrorWrapper(Representation): + __slots__ = 'exc', '_loc' + + def __init__(self, exc: Exception, loc: Union[str, 'Loc']) -> None: + self.exc = exc + self._loc = loc + + def loc_tuple(self) -> 'Loc': + if isinstance(self._loc, tuple): + return self._loc + else: + return (self._loc,) + + def __repr_args__(self) -> 'ReprArgs': + return [('exc', self.exc), ('loc', self.loc_tuple())] + + +# ErrorList is something like Union[List[Union[List[ErrorWrapper], ErrorWrapper]], ErrorWrapper] +# but recursive, therefore just use: +ErrorList = Union[Sequence[Any], ErrorWrapper] + + +class ValidationError(Representation, ValueError): + __slots__ = 'raw_errors', 'model', '_error_cache' + + def __init__(self, errors: Sequence[ErrorList], model: 'ModelOrDc') -> None: + self.raw_errors = errors + self.model = model + self._error_cache: Optional[List['ErrorDict']] = None + + def errors(self) -> List['ErrorDict']: + if self._error_cache is None: + try: + config = self.model.__config__ # type: ignore + except AttributeError: + config = self.model.__pydantic_model__.__config__ # type: ignore + self._error_cache = list(flatten_errors(self.raw_errors, config)) + return self._error_cache + + def json(self, *, indent: Union[None, int, str] = 2) -> str: + return json.dumps(self.errors(), indent=indent, default=pydantic_encoder) + + def __str__(self) -> str: + errors = self.errors() + no_errors = len(errors) + return ( + f'{no_errors} validation error{"" if no_errors == 1 else "s"} for {self.model.__name__}\n' + f'{display_errors(errors)}' + ) + + def __repr_args__(self) -> 'ReprArgs': + return [('model', self.model.__name__), ('errors', self.errors())] + + +def display_errors(errors: List['ErrorDict']) -> str: + return '\n'.join(f'{_display_error_loc(e)}\n {e["msg"]} ({_display_error_type_and_ctx(e)})' for e in errors) + + +def _display_error_loc(error: 'ErrorDict') -> str: + return ' -> '.join(str(e) for e in error['loc']) + + +def _display_error_type_and_ctx(error: 'ErrorDict') -> str: + t = 'type=' + error['type'] + ctx = error.get('ctx') + if ctx: + return t + ''.join(f'; {k}={v}' for k, v in ctx.items()) + else: + return t + + +def flatten_errors( + errors: Sequence[Any], config: Type['BaseConfig'], loc: Optional['Loc'] = None +) -> Generator['ErrorDict', None, None]: + for error in errors: + if isinstance(error, ErrorWrapper): + if loc: + error_loc = loc + error.loc_tuple() + else: + error_loc = error.loc_tuple() + + if isinstance(error.exc, ValidationError): + yield from flatten_errors(error.exc.raw_errors, config, error_loc) + else: + yield error_dict(error.exc, 
config, error_loc) + elif isinstance(error, list): + yield from flatten_errors(error, config, loc=loc) + else: + raise RuntimeError(f'Unknown error object: {error}') + + +def error_dict(exc: Exception, config: Type['BaseConfig'], loc: 'Loc') -> 'ErrorDict': + type_ = get_exc_type(exc.__class__) + msg_template = config.error_msg_templates.get(type_) or getattr(exc, 'msg_template', None) + ctx = exc.__dict__ + if msg_template: + msg = msg_template.format(**ctx) + else: + msg = str(exc) + + d: 'ErrorDict' = {'loc': loc, 'msg': msg, 'type': type_} + + if ctx: + d['ctx'] = ctx + + return d + + +_EXC_TYPE_CACHE: Dict[Type[Exception], str] = {} + + +def get_exc_type(cls: Type[Exception]) -> str: + # slightly more efficient than using lru_cache since we don't need to worry about the cache filling up + try: + return _EXC_TYPE_CACHE[cls] + except KeyError: + r = _get_exc_type(cls) + _EXC_TYPE_CACHE[cls] = r + return r + + +def _get_exc_type(cls: Type[Exception]) -> str: + if issubclass(cls, AssertionError): + return 'assertion_error' + + base_name = 'type_error' if issubclass(cls, TypeError) else 'value_error' + if cls in (TypeError, ValueError): + # just TypeError or ValueError, no extra code + return base_name + + # if it's not a TypeError or ValueError, we just take the lowercase of the exception name + # no chaining or snake case logic, use "code" for more complex error types. + code = getattr(cls, 'code', None) or cls.__name__.replace('Error', '').lower() + return base_name + '.' + code diff --git a/lib/pydantic/v1/errors.py b/lib/pydantic/v1/errors.py new file mode 100644 index 00000000..7bdafdd1 --- /dev/null +++ b/lib/pydantic/v1/errors.py @@ -0,0 +1,646 @@ +from decimal import Decimal +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Sequence, Set, Tuple, Type, Union + +from .typing import display_as_type + +if TYPE_CHECKING: + from .typing import DictStrAny + +# explicitly state exports to avoid "from .errors import *" also importing Decimal, Path etc. 
+__all__ = ( + 'PydanticTypeError', + 'PydanticValueError', + 'ConfigError', + 'MissingError', + 'ExtraError', + 'NoneIsNotAllowedError', + 'NoneIsAllowedError', + 'WrongConstantError', + 'NotNoneError', + 'BoolError', + 'BytesError', + 'DictError', + 'EmailError', + 'UrlError', + 'UrlSchemeError', + 'UrlSchemePermittedError', + 'UrlUserInfoError', + 'UrlHostError', + 'UrlHostTldError', + 'UrlPortError', + 'UrlExtraError', + 'EnumError', + 'IntEnumError', + 'EnumMemberError', + 'IntegerError', + 'FloatError', + 'PathError', + 'PathNotExistsError', + 'PathNotAFileError', + 'PathNotADirectoryError', + 'PyObjectError', + 'SequenceError', + 'ListError', + 'SetError', + 'FrozenSetError', + 'TupleError', + 'TupleLengthError', + 'ListMinLengthError', + 'ListMaxLengthError', + 'ListUniqueItemsError', + 'SetMinLengthError', + 'SetMaxLengthError', + 'FrozenSetMinLengthError', + 'FrozenSetMaxLengthError', + 'AnyStrMinLengthError', + 'AnyStrMaxLengthError', + 'StrError', + 'StrRegexError', + 'NumberNotGtError', + 'NumberNotGeError', + 'NumberNotLtError', + 'NumberNotLeError', + 'NumberNotMultipleError', + 'DecimalError', + 'DecimalIsNotFiniteError', + 'DecimalMaxDigitsError', + 'DecimalMaxPlacesError', + 'DecimalWholeDigitsError', + 'DateTimeError', + 'DateError', + 'DateNotInThePastError', + 'DateNotInTheFutureError', + 'TimeError', + 'DurationError', + 'HashableError', + 'UUIDError', + 'UUIDVersionError', + 'ArbitraryTypeError', + 'ClassError', + 'SubclassError', + 'JsonError', + 'JsonTypeError', + 'PatternError', + 'DataclassTypeError', + 'CallableError', + 'IPvAnyAddressError', + 'IPvAnyInterfaceError', + 'IPvAnyNetworkError', + 'IPv4AddressError', + 'IPv6AddressError', + 'IPv4NetworkError', + 'IPv6NetworkError', + 'IPv4InterfaceError', + 'IPv6InterfaceError', + 'ColorError', + 'StrictBoolError', + 'NotDigitError', + 'LuhnValidationError', + 'InvalidLengthForBrand', + 'InvalidByteSize', + 'InvalidByteSizeUnit', + 'MissingDiscriminator', + 'InvalidDiscriminator', +) + + +def cls_kwargs(cls: Type['PydanticErrorMixin'], ctx: 'DictStrAny') -> 'PydanticErrorMixin': + """ + For built-in exceptions like ValueError or TypeError, we need to implement + __reduce__ to override the default behaviour (instead of __getstate__/__setstate__) + By default pickle protocol 2 calls `cls.__new__(cls, *args)`. + Since we only use kwargs, we need a little constructor to change that. 
+ Note: the callable can't be a lambda as pickle looks in the namespace to find it + """ + return cls(**ctx) + + +class PydanticErrorMixin: + code: str + msg_template: str + + def __init__(self, **ctx: Any) -> None: + self.__dict__ = ctx + + def __str__(self) -> str: + return self.msg_template.format(**self.__dict__) + + def __reduce__(self) -> Tuple[Callable[..., 'PydanticErrorMixin'], Tuple[Type['PydanticErrorMixin'], 'DictStrAny']]: + return cls_kwargs, (self.__class__, self.__dict__) + + +class PydanticTypeError(PydanticErrorMixin, TypeError): + pass + + +class PydanticValueError(PydanticErrorMixin, ValueError): + pass + + +class ConfigError(RuntimeError): + pass + + +class MissingError(PydanticValueError): + msg_template = 'field required' + + +class ExtraError(PydanticValueError): + msg_template = 'extra fields not permitted' + + +class NoneIsNotAllowedError(PydanticTypeError): + code = 'none.not_allowed' + msg_template = 'none is not an allowed value' + + +class NoneIsAllowedError(PydanticTypeError): + code = 'none.allowed' + msg_template = 'value is not none' + + +class WrongConstantError(PydanticValueError): + code = 'const' + + def __str__(self) -> str: + permitted = ', '.join(repr(v) for v in self.permitted) # type: ignore + return f'unexpected value; permitted: {permitted}' + + +class NotNoneError(PydanticTypeError): + code = 'not_none' + msg_template = 'value is not None' + + +class BoolError(PydanticTypeError): + msg_template = 'value could not be parsed to a boolean' + + +class BytesError(PydanticTypeError): + msg_template = 'byte type expected' + + +class DictError(PydanticTypeError): + msg_template = 'value is not a valid dict' + + +class EmailError(PydanticValueError): + msg_template = 'value is not a valid email address' + + +class UrlError(PydanticValueError): + code = 'url' + + +class UrlSchemeError(UrlError): + code = 'url.scheme' + msg_template = 'invalid or missing URL scheme' + + +class UrlSchemePermittedError(UrlError): + code = 'url.scheme' + msg_template = 'URL scheme not permitted' + + def __init__(self, allowed_schemes: Set[str]): + super().__init__(allowed_schemes=allowed_schemes) + + +class UrlUserInfoError(UrlError): + code = 'url.userinfo' + msg_template = 'userinfo required in URL but missing' + + +class UrlHostError(UrlError): + code = 'url.host' + msg_template = 'URL host invalid' + + +class UrlHostTldError(UrlError): + code = 'url.host' + msg_template = 'URL host invalid, top level domain required' + + +class UrlPortError(UrlError): + code = 'url.port' + msg_template = 'URL port invalid, port cannot exceed 65535' + + +class UrlExtraError(UrlError): + code = 'url.extra' + msg_template = 'URL invalid, extra characters found after valid URL: {extra!r}' + + +class EnumMemberError(PydanticTypeError): + code = 'enum' + + def __str__(self) -> str: + permitted = ', '.join(repr(v.value) for v in self.enum_values) # type: ignore + return f'value is not a valid enumeration member; permitted: {permitted}' + + +class IntegerError(PydanticTypeError): + msg_template = 'value is not a valid integer' + + +class FloatError(PydanticTypeError): + msg_template = 'value is not a valid float' + + +class PathError(PydanticTypeError): + msg_template = 'value is not a valid path' + + +class _PathValueError(PydanticValueError): + def __init__(self, *, path: Path) -> None: + super().__init__(path=str(path)) + + +class PathNotExistsError(_PathValueError): + code = 'path.not_exists' + msg_template = 'file or directory at path "{path}" does not exist' + + +class 
PathNotAFileError(_PathValueError): + code = 'path.not_a_file' + msg_template = 'path "{path}" does not point to a file' + + +class PathNotADirectoryError(_PathValueError): + code = 'path.not_a_directory' + msg_template = 'path "{path}" does not point to a directory' + + +class PyObjectError(PydanticTypeError): + msg_template = 'ensure this value contains valid import path or valid callable: {error_message}' + + +class SequenceError(PydanticTypeError): + msg_template = 'value is not a valid sequence' + + +class IterableError(PydanticTypeError): + msg_template = 'value is not a valid iterable' + + +class ListError(PydanticTypeError): + msg_template = 'value is not a valid list' + + +class SetError(PydanticTypeError): + msg_template = 'value is not a valid set' + + +class FrozenSetError(PydanticTypeError): + msg_template = 'value is not a valid frozenset' + + +class DequeError(PydanticTypeError): + msg_template = 'value is not a valid deque' + + +class TupleError(PydanticTypeError): + msg_template = 'value is not a valid tuple' + + +class TupleLengthError(PydanticValueError): + code = 'tuple.length' + msg_template = 'wrong tuple length {actual_length}, expected {expected_length}' + + def __init__(self, *, actual_length: int, expected_length: int) -> None: + super().__init__(actual_length=actual_length, expected_length=expected_length) + + +class ListMinLengthError(PydanticValueError): + code = 'list.min_items' + msg_template = 'ensure this value has at least {limit_value} items' + + def __init__(self, *, limit_value: int) -> None: + super().__init__(limit_value=limit_value) + + +class ListMaxLengthError(PydanticValueError): + code = 'list.max_items' + msg_template = 'ensure this value has at most {limit_value} items' + + def __init__(self, *, limit_value: int) -> None: + super().__init__(limit_value=limit_value) + + +class ListUniqueItemsError(PydanticValueError): + code = 'list.unique_items' + msg_template = 'the list has duplicated items' + + +class SetMinLengthError(PydanticValueError): + code = 'set.min_items' + msg_template = 'ensure this value has at least {limit_value} items' + + def __init__(self, *, limit_value: int) -> None: + super().__init__(limit_value=limit_value) + + +class SetMaxLengthError(PydanticValueError): + code = 'set.max_items' + msg_template = 'ensure this value has at most {limit_value} items' + + def __init__(self, *, limit_value: int) -> None: + super().__init__(limit_value=limit_value) + + +class FrozenSetMinLengthError(PydanticValueError): + code = 'frozenset.min_items' + msg_template = 'ensure this value has at least {limit_value} items' + + def __init__(self, *, limit_value: int) -> None: + super().__init__(limit_value=limit_value) + + +class FrozenSetMaxLengthError(PydanticValueError): + code = 'frozenset.max_items' + msg_template = 'ensure this value has at most {limit_value} items' + + def __init__(self, *, limit_value: int) -> None: + super().__init__(limit_value=limit_value) + + +class AnyStrMinLengthError(PydanticValueError): + code = 'any_str.min_length' + msg_template = 'ensure this value has at least {limit_value} characters' + + def __init__(self, *, limit_value: int) -> None: + super().__init__(limit_value=limit_value) + + +class AnyStrMaxLengthError(PydanticValueError): + code = 'any_str.max_length' + msg_template = 'ensure this value has at most {limit_value} characters' + + def __init__(self, *, limit_value: int) -> None: + super().__init__(limit_value=limit_value) + + +class StrError(PydanticTypeError): + msg_template = 'str type expected' + + 
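Review note (illustration only, not part of the patch): these error classes are plain exception subclasses carrying a `code` and a `msg_template`; a custom one can be defined the same way. The class below is invented for the example, and the import assumes the vendored package is importable as `pydantic.v1`.

    from pydantic.v1.errors import PydanticValueError

    class NotEvenError(PydanticValueError):
        code = 'number.not_even'
        msg_template = 'value {value} is not an even number'

    err = NotEvenError(value=3)
    print(str(err))   # value 3 is not an even number
    print(err.code)   # number.not_even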
+class StrRegexError(PydanticValueError): + code = 'str.regex' + msg_template = 'string does not match regex "{pattern}"' + + def __init__(self, *, pattern: str) -> None: + super().__init__(pattern=pattern) + + +class _NumberBoundError(PydanticValueError): + def __init__(self, *, limit_value: Union[int, float, Decimal]) -> None: + super().__init__(limit_value=limit_value) + + +class NumberNotGtError(_NumberBoundError): + code = 'number.not_gt' + msg_template = 'ensure this value is greater than {limit_value}' + + +class NumberNotGeError(_NumberBoundError): + code = 'number.not_ge' + msg_template = 'ensure this value is greater than or equal to {limit_value}' + + +class NumberNotLtError(_NumberBoundError): + code = 'number.not_lt' + msg_template = 'ensure this value is less than {limit_value}' + + +class NumberNotLeError(_NumberBoundError): + code = 'number.not_le' + msg_template = 'ensure this value is less than or equal to {limit_value}' + + +class NumberNotFiniteError(PydanticValueError): + code = 'number.not_finite_number' + msg_template = 'ensure this value is a finite number' + + +class NumberNotMultipleError(PydanticValueError): + code = 'number.not_multiple' + msg_template = 'ensure this value is a multiple of {multiple_of}' + + def __init__(self, *, multiple_of: Union[int, float, Decimal]) -> None: + super().__init__(multiple_of=multiple_of) + + +class DecimalError(PydanticTypeError): + msg_template = 'value is not a valid decimal' + + +class DecimalIsNotFiniteError(PydanticValueError): + code = 'decimal.not_finite' + msg_template = 'value is not a valid decimal' + + +class DecimalMaxDigitsError(PydanticValueError): + code = 'decimal.max_digits' + msg_template = 'ensure that there are no more than {max_digits} digits in total' + + def __init__(self, *, max_digits: int) -> None: + super().__init__(max_digits=max_digits) + + +class DecimalMaxPlacesError(PydanticValueError): + code = 'decimal.max_places' + msg_template = 'ensure that there are no more than {decimal_places} decimal places' + + def __init__(self, *, decimal_places: int) -> None: + super().__init__(decimal_places=decimal_places) + + +class DecimalWholeDigitsError(PydanticValueError): + code = 'decimal.whole_digits' + msg_template = 'ensure that there are no more than {whole_digits} digits before the decimal point' + + def __init__(self, *, whole_digits: int) -> None: + super().__init__(whole_digits=whole_digits) + + +class DateTimeError(PydanticValueError): + msg_template = 'invalid datetime format' + + +class DateError(PydanticValueError): + msg_template = 'invalid date format' + + +class DateNotInThePastError(PydanticValueError): + code = 'date.not_in_the_past' + msg_template = 'date is not in the past' + + +class DateNotInTheFutureError(PydanticValueError): + code = 'date.not_in_the_future' + msg_template = 'date is not in the future' + + +class TimeError(PydanticValueError): + msg_template = 'invalid time format' + + +class DurationError(PydanticValueError): + msg_template = 'invalid duration format' + + +class HashableError(PydanticTypeError): + msg_template = 'value is not a valid hashable' + + +class UUIDError(PydanticTypeError): + msg_template = 'value is not a valid uuid' + + +class UUIDVersionError(PydanticValueError): + code = 'uuid.version' + msg_template = 'uuid version {required_version} expected' + + def __init__(self, *, required_version: int) -> None: + super().__init__(required_version=required_version) + + +class ArbitraryTypeError(PydanticTypeError): + code = 'arbitrary_type' + msg_template = 
'instance of {expected_arbitrary_type} expected' + + def __init__(self, *, expected_arbitrary_type: Type[Any]) -> None: + super().__init__(expected_arbitrary_type=display_as_type(expected_arbitrary_type)) + + +class ClassError(PydanticTypeError): + code = 'class' + msg_template = 'a class is expected' + + +class SubclassError(PydanticTypeError): + code = 'subclass' + msg_template = 'subclass of {expected_class} expected' + + def __init__(self, *, expected_class: Type[Any]) -> None: + super().__init__(expected_class=display_as_type(expected_class)) + + +class JsonError(PydanticValueError): + msg_template = 'Invalid JSON' + + +class JsonTypeError(PydanticTypeError): + code = 'json' + msg_template = 'JSON object must be str, bytes or bytearray' + + +class PatternError(PydanticValueError): + code = 'regex_pattern' + msg_template = 'Invalid regular expression' + + +class DataclassTypeError(PydanticTypeError): + code = 'dataclass' + msg_template = 'instance of {class_name}, tuple or dict expected' + + +class CallableError(PydanticTypeError): + msg_template = '{value} is not callable' + + +class EnumError(PydanticTypeError): + code = 'enum_instance' + msg_template = '{value} is not a valid Enum instance' + + +class IntEnumError(PydanticTypeError): + code = 'int_enum_instance' + msg_template = '{value} is not a valid IntEnum instance' + + +class IPvAnyAddressError(PydanticValueError): + msg_template = 'value is not a valid IPv4 or IPv6 address' + + +class IPvAnyInterfaceError(PydanticValueError): + msg_template = 'value is not a valid IPv4 or IPv6 interface' + + +class IPvAnyNetworkError(PydanticValueError): + msg_template = 'value is not a valid IPv4 or IPv6 network' + + +class IPv4AddressError(PydanticValueError): + msg_template = 'value is not a valid IPv4 address' + + +class IPv6AddressError(PydanticValueError): + msg_template = 'value is not a valid IPv6 address' + + +class IPv4NetworkError(PydanticValueError): + msg_template = 'value is not a valid IPv4 network' + + +class IPv6NetworkError(PydanticValueError): + msg_template = 'value is not a valid IPv6 network' + + +class IPv4InterfaceError(PydanticValueError): + msg_template = 'value is not a valid IPv4 interface' + + +class IPv6InterfaceError(PydanticValueError): + msg_template = 'value is not a valid IPv6 interface' + + +class ColorError(PydanticValueError): + msg_template = 'value is not a valid color: {reason}' + + +class StrictBoolError(PydanticValueError): + msg_template = 'value is not a valid boolean' + + +class NotDigitError(PydanticValueError): + code = 'payment_card_number.digits' + msg_template = 'card number is not all digits' + + +class LuhnValidationError(PydanticValueError): + code = 'payment_card_number.luhn_check' + msg_template = 'card number is not luhn valid' + + +class InvalidLengthForBrand(PydanticValueError): + code = 'payment_card_number.invalid_length_for_brand' + msg_template = 'Length for a {brand} card must be {required_length}' + + +class InvalidByteSize(PydanticValueError): + msg_template = 'could not parse value and unit from byte string' + + +class InvalidByteSizeUnit(PydanticValueError): + msg_template = 'could not interpret byte unit: {unit}' + + +class MissingDiscriminator(PydanticValueError): + code = 'discriminated_union.missing_discriminator' + msg_template = 'Discriminator {discriminator_key!r} is missing in value' + + +class InvalidDiscriminator(PydanticValueError): + code = 'discriminated_union.invalid_discriminator' + msg_template = ( + 'No match for discriminator {discriminator_key!r} and value 
{discriminator_value!r} ' + '(allowed values: {allowed_values})' + ) + + def __init__(self, *, discriminator_key: str, discriminator_value: Any, allowed_values: Sequence[Any]) -> None: + super().__init__( + discriminator_key=discriminator_key, + discriminator_value=discriminator_value, + allowed_values=', '.join(map(repr, allowed_values)), + ) diff --git a/lib/pydantic/v1/fields.py b/lib/pydantic/v1/fields.py new file mode 100644 index 00000000..60d260e9 --- /dev/null +++ b/lib/pydantic/v1/fields.py @@ -0,0 +1,1253 @@ +import copy +import re +from collections import Counter as CollectionCounter, defaultdict, deque +from collections.abc import Callable, Hashable as CollectionsHashable, Iterable as CollectionsIterable +from typing import ( + TYPE_CHECKING, + Any, + Counter, + DefaultDict, + Deque, + Dict, + ForwardRef, + FrozenSet, + Generator, + Iterable, + Iterator, + List, + Mapping, + Optional, + Pattern, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +from typing_extensions import Annotated, Final + +from . import errors as errors_ +from .class_validators import Validator, make_generic_validator, prep_validators +from .error_wrappers import ErrorWrapper +from .errors import ConfigError, InvalidDiscriminator, MissingDiscriminator, NoneIsNotAllowedError +from .types import Json, JsonWrapper +from .typing import ( + NoArgAnyCallable, + convert_generics, + display_as_type, + get_args, + get_origin, + is_finalvar, + is_literal_type, + is_new_type, + is_none_type, + is_typeddict, + is_typeddict_special, + is_union, + new_type_supertype, +) +from .utils import ( + PyObjectStr, + Representation, + ValueItems, + get_discriminator_alias_and_values, + get_unique_discriminator_alias, + lenient_isinstance, + lenient_issubclass, + sequence_like, + smart_deepcopy, +) +from .validators import constant_validator, dict_validator, find_validators, validate_json + +Required: Any = Ellipsis + +T = TypeVar('T') + + +class UndefinedType: + def __repr__(self) -> str: + return 'PydanticUndefined' + + def __copy__(self: T) -> T: + return self + + def __reduce__(self) -> str: + return 'Undefined' + + def __deepcopy__(self: T, _: Any) -> T: + return self + + +Undefined = UndefinedType() + +if TYPE_CHECKING: + from .class_validators import ValidatorsList + from .config import BaseConfig + from .error_wrappers import ErrorList + from .types import ModelOrDc + from .typing import AbstractSetIntStr, MappingIntStrAny, ReprArgs + + ValidateReturn = Tuple[Optional[Any], Optional[ErrorList]] + LocStr = Union[Tuple[Union[int, str], ...], str] + BoolUndefined = Union[bool, UndefinedType] + + +class FieldInfo(Representation): + """ + Captures extra information about a field. 
+ """ + + __slots__ = ( + 'default', + 'default_factory', + 'alias', + 'alias_priority', + 'title', + 'description', + 'exclude', + 'include', + 'const', + 'gt', + 'ge', + 'lt', + 'le', + 'multiple_of', + 'allow_inf_nan', + 'max_digits', + 'decimal_places', + 'min_items', + 'max_items', + 'unique_items', + 'min_length', + 'max_length', + 'allow_mutation', + 'repr', + 'regex', + 'discriminator', + 'extra', + ) + + # field constraints with the default value, it's also used in update_from_config below + __field_constraints__ = { + 'min_length': None, + 'max_length': None, + 'regex': None, + 'gt': None, + 'lt': None, + 'ge': None, + 'le': None, + 'multiple_of': None, + 'allow_inf_nan': None, + 'max_digits': None, + 'decimal_places': None, + 'min_items': None, + 'max_items': None, + 'unique_items': None, + 'allow_mutation': True, + } + + def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: + self.default = default + self.default_factory = kwargs.pop('default_factory', None) + self.alias = kwargs.pop('alias', None) + self.alias_priority = kwargs.pop('alias_priority', 2 if self.alias is not None else None) + self.title = kwargs.pop('title', None) + self.description = kwargs.pop('description', None) + self.exclude = kwargs.pop('exclude', None) + self.include = kwargs.pop('include', None) + self.const = kwargs.pop('const', None) + self.gt = kwargs.pop('gt', None) + self.ge = kwargs.pop('ge', None) + self.lt = kwargs.pop('lt', None) + self.le = kwargs.pop('le', None) + self.multiple_of = kwargs.pop('multiple_of', None) + self.allow_inf_nan = kwargs.pop('allow_inf_nan', None) + self.max_digits = kwargs.pop('max_digits', None) + self.decimal_places = kwargs.pop('decimal_places', None) + self.min_items = kwargs.pop('min_items', None) + self.max_items = kwargs.pop('max_items', None) + self.unique_items = kwargs.pop('unique_items', None) + self.min_length = kwargs.pop('min_length', None) + self.max_length = kwargs.pop('max_length', None) + self.allow_mutation = kwargs.pop('allow_mutation', True) + self.regex = kwargs.pop('regex', None) + self.discriminator = kwargs.pop('discriminator', None) + self.repr = kwargs.pop('repr', True) + self.extra = kwargs + + def __repr_args__(self) -> 'ReprArgs': + field_defaults_to_hide: Dict[str, Any] = { + 'repr': True, + **self.__field_constraints__, + } + + attrs = ((s, getattr(self, s)) for s in self.__slots__) + return [(a, v) for a, v in attrs if v != field_defaults_to_hide.get(a, None)] + + def get_constraints(self) -> Set[str]: + """ + Gets the constraints set on the field by comparing the constraint value with its default value + + :return: the constraints set on field_info + """ + return {attr for attr, default in self.__field_constraints__.items() if getattr(self, attr) != default} + + def update_from_config(self, from_config: Dict[str, Any]) -> None: + """ + Update this FieldInfo based on a dict from get_field_info, only fields which have not been set are dated. + """ + for attr_name, value in from_config.items(): + try: + current_value = getattr(self, attr_name) + except AttributeError: + # attr_name is not an attribute of FieldInfo, it should therefore be added to extra + # (except if extra already has this value!) 
+ self.extra.setdefault(attr_name, value) + else: + if current_value is self.__field_constraints__.get(attr_name, None): + setattr(self, attr_name, value) + elif attr_name == 'exclude': + self.exclude = ValueItems.merge(value, current_value) + elif attr_name == 'include': + self.include = ValueItems.merge(value, current_value, intersect=True) + + def _validate(self) -> None: + if self.default is not Undefined and self.default_factory is not None: + raise ValueError('cannot specify both default and default_factory') + + +def Field( + default: Any = Undefined, + *, + default_factory: Optional[NoArgAnyCallable] = None, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny', Any]] = None, + include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny', Any]] = None, + const: Optional[bool] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + multiple_of: Optional[float] = None, + allow_inf_nan: Optional[bool] = None, + max_digits: Optional[int] = None, + decimal_places: Optional[int] = None, + min_items: Optional[int] = None, + max_items: Optional[int] = None, + unique_items: Optional[bool] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + allow_mutation: bool = True, + regex: Optional[str] = None, + discriminator: Optional[str] = None, + repr: bool = True, + **extra: Any, +) -> Any: + """ + Used to provide extra information about a field, either for the model schema or complex validation. Some arguments + apply only to number fields (``int``, ``float``, ``Decimal``) and some apply only to ``str``. + + :param default: since this is replacing the field’s default, its first argument is used + to set the default, use ellipsis (``...``) to indicate the field is required + :param default_factory: callable that will be called when a default value is needed for this field + If both `default` and `default_factory` are set, an error is raised. + :param alias: the public name of the field + :param title: can be any string, used in the schema + :param description: can be any string, used in the schema + :param exclude: exclude this field while dumping. + Takes same values as the ``include`` and ``exclude`` arguments on the ``.dict`` method. + :param include: include this field while dumping. + Takes same values as the ``include`` and ``exclude`` arguments on the ``.dict`` method. + :param const: this field is required and *must* take it's default value + :param gt: only applies to numbers, requires the field to be "greater than". The schema + will have an ``exclusiveMinimum`` validation keyword + :param ge: only applies to numbers, requires the field to be "greater than or equal to". The + schema will have a ``minimum`` validation keyword + :param lt: only applies to numbers, requires the field to be "less than". The schema + will have an ``exclusiveMaximum`` validation keyword + :param le: only applies to numbers, requires the field to be "less than or equal to". The + schema will have a ``maximum`` validation keyword + :param multiple_of: only applies to numbers, requires the field to be "a multiple of". The + schema will have a ``multipleOf`` validation keyword + :param allow_inf_nan: only applies to numbers, allows the field to be NaN or infinity (+inf or -inf), + which is a valid Python float. Default True, set to False for compatibility with JSON. 
+ :param max_digits: only applies to Decimals, requires the field to have a maximum number + of digits within the decimal. It does not include a zero before the decimal point or trailing decimal zeroes. + :param decimal_places: only applies to Decimals, requires the field to have at most a number of decimal places + allowed. It does not include trailing decimal zeroes. + :param min_items: only applies to lists, requires the field to have a minimum number of + elements. The schema will have a ``minItems`` validation keyword + :param max_items: only applies to lists, requires the field to have a maximum number of + elements. The schema will have a ``maxItems`` validation keyword + :param unique_items: only applies to lists, requires the field not to have duplicated + elements. The schema will have a ``uniqueItems`` validation keyword + :param min_length: only applies to strings, requires the field to have a minimum length. The + schema will have a ``minLength`` validation keyword + :param max_length: only applies to strings, requires the field to have a maximum length. The + schema will have a ``maxLength`` validation keyword + :param allow_mutation: a boolean which defaults to True. When False, the field raises a TypeError if the field is + assigned on an instance. The BaseModel Config must set validate_assignment to True + :param regex: only applies to strings, requires the field match against a regular expression + pattern string. The schema will have a ``pattern`` validation keyword + :param discriminator: only useful with a (discriminated a.k.a. tagged) `Union` of sub models with a common field. + The `discriminator` is the name of this common field to shorten validation and improve generated schema + :param repr: show this field in the representation + :param **extra: any additional keyword arguments will be added as is to the schema + """ + field_info = FieldInfo( + default, + default_factory=default_factory, + alias=alias, + title=title, + description=description, + exclude=exclude, + include=include, + const=const, + gt=gt, + ge=ge, + lt=lt, + le=le, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + min_items=min_items, + max_items=max_items, + unique_items=unique_items, + min_length=min_length, + max_length=max_length, + allow_mutation=allow_mutation, + regex=regex, + discriminator=discriminator, + repr=repr, + **extra, + ) + field_info._validate() + return field_info + + +# used to be an enum but changed to int's for small performance improvement as less access overhead +SHAPE_SINGLETON = 1 +SHAPE_LIST = 2 +SHAPE_SET = 3 +SHAPE_MAPPING = 4 +SHAPE_TUPLE = 5 +SHAPE_TUPLE_ELLIPSIS = 6 +SHAPE_SEQUENCE = 7 +SHAPE_FROZENSET = 8 +SHAPE_ITERABLE = 9 +SHAPE_GENERIC = 10 +SHAPE_DEQUE = 11 +SHAPE_DICT = 12 +SHAPE_DEFAULTDICT = 13 +SHAPE_COUNTER = 14 +SHAPE_NAME_LOOKUP = { + SHAPE_LIST: 'List[{}]', + SHAPE_SET: 'Set[{}]', + SHAPE_TUPLE_ELLIPSIS: 'Tuple[{}, ...]', + SHAPE_SEQUENCE: 'Sequence[{}]', + SHAPE_FROZENSET: 'FrozenSet[{}]', + SHAPE_ITERABLE: 'Iterable[{}]', + SHAPE_DEQUE: 'Deque[{}]', + SHAPE_DICT: 'Dict[{}]', + SHAPE_DEFAULTDICT: 'DefaultDict[{}]', + SHAPE_COUNTER: 'Counter[{}]', +} + +MAPPING_LIKE_SHAPES: Set[int] = {SHAPE_DEFAULTDICT, SHAPE_DICT, SHAPE_MAPPING, SHAPE_COUNTER} + + +class ModelField(Representation): + __slots__ = ( + 'type_', + 'outer_type_', + 'annotation', + 'sub_fields', + 'sub_fields_mapping', + 'key_field', + 'validators', + 'pre_validators', + 'post_validators', + 'default', + 
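# Illustrative usage sketch: how the constraint keywords accepted by Field() above look
# from the caller's side. Assumes a pydantic v1.x install importable as `pydantic`
# (`pydantic.v1` in this vendored layout); the model and field names are invented.
from pydantic import BaseModel, Field, ValidationError

class Product(BaseModel):
    name: str = Field(..., min_length=1, max_length=50)    # `...` marks the field as required
    price: float = Field(..., gt=0, description='unit price')
    sku: str = Field('UNSET', regex=r'^[A-Z0-9-]+$')

try:
    Product(name='', price=-1, sku='bad sku')
except ValidationError as exc:
    # one error entry per violated constraint: min_length, gt and regex
    print(exc.errors())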
'default_factory', + 'required', + 'final', + 'model_config', + 'name', + 'alias', + 'has_alias', + 'field_info', + 'discriminator_key', + 'discriminator_alias', + 'validate_always', + 'allow_none', + 'shape', + 'class_validators', + 'parse_json', + ) + + def __init__( + self, + *, + name: str, + type_: Type[Any], + class_validators: Optional[Dict[str, Validator]], + model_config: Type['BaseConfig'], + default: Any = None, + default_factory: Optional[NoArgAnyCallable] = None, + required: 'BoolUndefined' = Undefined, + final: bool = False, + alias: Optional[str] = None, + field_info: Optional[FieldInfo] = None, + ) -> None: + self.name: str = name + self.has_alias: bool = alias is not None + self.alias: str = alias if alias is not None else name + self.annotation = type_ + self.type_: Any = convert_generics(type_) + self.outer_type_: Any = type_ + self.class_validators = class_validators or {} + self.default: Any = default + self.default_factory: Optional[NoArgAnyCallable] = default_factory + self.required: 'BoolUndefined' = required + self.final: bool = final + self.model_config = model_config + self.field_info: FieldInfo = field_info or FieldInfo(default) + self.discriminator_key: Optional[str] = self.field_info.discriminator + self.discriminator_alias: Optional[str] = self.discriminator_key + + self.allow_none: bool = False + self.validate_always: bool = False + self.sub_fields: Optional[List[ModelField]] = None + self.sub_fields_mapping: Optional[Dict[str, 'ModelField']] = None # used for discriminated union + self.key_field: Optional[ModelField] = None + self.validators: 'ValidatorsList' = [] + self.pre_validators: Optional['ValidatorsList'] = None + self.post_validators: Optional['ValidatorsList'] = None + self.parse_json: bool = False + self.shape: int = SHAPE_SINGLETON + self.model_config.prepare_field(self) + self.prepare() + + def get_default(self) -> Any: + return smart_deepcopy(self.default) if self.default_factory is None else self.default_factory() + + @staticmethod + def _get_field_info( + field_name: str, annotation: Any, value: Any, config: Type['BaseConfig'] + ) -> Tuple[FieldInfo, Any]: + """ + Get a FieldInfo from a root typing.Annotated annotation, value, or config default. + + The FieldInfo may be set in typing.Annotated or the value, but not both. If neither contain + a FieldInfo, a new one will be created using the config. + + :param field_name: name of the field for use in error messages + :param annotation: a type hint such as `str` or `Annotated[str, Field(..., min_length=5)]` + :param value: the field's assigned value + :param config: the model's config object + :return: the FieldInfo contained in the `annotation`, the value, or a new one from the config. 
+ """ + field_info_from_config = config.get_field_info(field_name) + + field_info = None + if get_origin(annotation) is Annotated: + field_infos = [arg for arg in get_args(annotation)[1:] if isinstance(arg, FieldInfo)] + if len(field_infos) > 1: + raise ValueError(f'cannot specify multiple `Annotated` `Field`s for {field_name!r}') + field_info = next(iter(field_infos), None) + if field_info is not None: + field_info = copy.copy(field_info) + field_info.update_from_config(field_info_from_config) + if field_info.default not in (Undefined, Required): + raise ValueError(f'`Field` default cannot be set in `Annotated` for {field_name!r}') + if value is not Undefined and value is not Required: + # check also `Required` because of `validate_arguments` that sets `...` as default value + field_info.default = value + + if isinstance(value, FieldInfo): + if field_info is not None: + raise ValueError(f'cannot specify `Annotated` and value `Field`s together for {field_name!r}') + field_info = value + field_info.update_from_config(field_info_from_config) + elif field_info is None: + field_info = FieldInfo(value, **field_info_from_config) + value = None if field_info.default_factory is not None else field_info.default + field_info._validate() + return field_info, value + + @classmethod + def infer( + cls, + *, + name: str, + value: Any, + annotation: Any, + class_validators: Optional[Dict[str, Validator]], + config: Type['BaseConfig'], + ) -> 'ModelField': + from .schema import get_annotation_from_field_info + + field_info, value = cls._get_field_info(name, annotation, value, config) + required: 'BoolUndefined' = Undefined + if value is Required: + required = True + value = None + elif value is not Undefined: + required = False + annotation = get_annotation_from_field_info(annotation, field_info, name, config.validate_assignment) + + return cls( + name=name, + type_=annotation, + alias=field_info.alias, + class_validators=class_validators, + default=value, + default_factory=field_info.default_factory, + required=required, + model_config=config, + field_info=field_info, + ) + + def set_config(self, config: Type['BaseConfig']) -> None: + self.model_config = config + info_from_config = config.get_field_info(self.name) + config.prepare_field(self) + new_alias = info_from_config.get('alias') + new_alias_priority = info_from_config.get('alias_priority') or 0 + if new_alias and new_alias_priority >= (self.field_info.alias_priority or 0): + self.field_info.alias = new_alias + self.field_info.alias_priority = new_alias_priority + self.alias = new_alias + new_exclude = info_from_config.get('exclude') + if new_exclude is not None: + self.field_info.exclude = ValueItems.merge(self.field_info.exclude, new_exclude) + new_include = info_from_config.get('include') + if new_include is not None: + self.field_info.include = ValueItems.merge(self.field_info.include, new_include, intersect=True) + + @property + def alt_alias(self) -> bool: + return self.name != self.alias + + def prepare(self) -> None: + """ + Prepare the field but inspecting self.default, self.type_ etc. + + Note: this method is **not** idempotent (because _type_analysis is not idempotent), + e.g. calling it it multiple times may modify the field and configure it incorrectly. 
+ """ + self._set_default_and_type() + if self.type_.__class__ is ForwardRef or self.type_.__class__ is DeferredType: + # self.type_ is currently a ForwardRef and there's nothing we can do now, + # user will need to call model.update_forward_refs() + return + + self._type_analysis() + if self.required is Undefined: + self.required = True + if self.default is Undefined and self.default_factory is None: + self.default = None + self.populate_validators() + + def _set_default_and_type(self) -> None: + """ + Set the default value, infer the type if needed and check if `None` value is valid. + """ + if self.default_factory is not None: + if self.type_ is Undefined: + raise errors_.ConfigError( + f'you need to set the type of field {self.name!r} when using `default_factory`' + ) + return + + default_value = self.get_default() + + if default_value is not None and self.type_ is Undefined: + self.type_ = default_value.__class__ + self.outer_type_ = self.type_ + self.annotation = self.type_ + + if self.type_ is Undefined: + raise errors_.ConfigError(f'unable to infer type for attribute "{self.name}"') + + if self.required is False and default_value is None: + self.allow_none = True + + def _type_analysis(self) -> None: # noqa: C901 (ignore complexity) + # typing interface is horrible, we have to do some ugly checks + if lenient_issubclass(self.type_, JsonWrapper): + self.type_ = self.type_.inner_type + self.parse_json = True + elif lenient_issubclass(self.type_, Json): + self.type_ = Any + self.parse_json = True + elif isinstance(self.type_, TypeVar): + if self.type_.__bound__: + self.type_ = self.type_.__bound__ + elif self.type_.__constraints__: + self.type_ = Union[self.type_.__constraints__] + else: + self.type_ = Any + elif is_new_type(self.type_): + self.type_ = new_type_supertype(self.type_) + + if self.type_ is Any or self.type_ is object: + if self.required is Undefined: + self.required = False + self.allow_none = True + return + elif self.type_ is Pattern or self.type_ is re.Pattern: + # python 3.7 only, Pattern is a typing object but without sub fields + return + elif is_literal_type(self.type_): + return + elif is_typeddict(self.type_): + return + + if is_finalvar(self.type_): + self.final = True + + if self.type_ is Final: + self.type_ = Any + else: + self.type_ = get_args(self.type_)[0] + + self._type_analysis() + return + + origin = get_origin(self.type_) + + if origin is Annotated or is_typeddict_special(origin): + self.type_ = get_args(self.type_)[0] + self._type_analysis() + return + + if self.discriminator_key is not None and not is_union(origin): + raise TypeError('`discriminator` can only be used with `Union` type with more than one variant') + + # add extra check for `collections.abc.Hashable` for python 3.10+ where origin is not `None` + if origin is None or origin is CollectionsHashable: + # field is not "typing" object eg. Union, Dict, List etc. + # allow None for virtual superclasses of NoneType, e.g. 
Hashable + if isinstance(self.type_, type) and isinstance(None, self.type_): + self.allow_none = True + return + elif origin is Callable: + return + elif is_union(origin): + types_ = [] + for type_ in get_args(self.type_): + if is_none_type(type_) or type_ is Any or type_ is object: + if self.required is Undefined: + self.required = False + self.allow_none = True + if is_none_type(type_): + continue + types_.append(type_) + + if len(types_) == 1: + # Optional[] + self.type_ = types_[0] + # this is the one case where the "outer type" isn't just the original type + self.outer_type_ = self.type_ + # re-run to correctly interpret the new self.type_ + self._type_analysis() + else: + self.sub_fields = [self._create_sub_type(t, f'{self.name}_{display_as_type(t)}') for t in types_] + + if self.discriminator_key is not None: + self.prepare_discriminated_union_sub_fields() + return + elif issubclass(origin, Tuple): # type: ignore + # origin == Tuple without item type + args = get_args(self.type_) + if not args: # plain tuple + self.type_ = Any + self.shape = SHAPE_TUPLE_ELLIPSIS + elif len(args) == 2 and args[1] is Ellipsis: # e.g. Tuple[int, ...] + self.type_ = args[0] + self.shape = SHAPE_TUPLE_ELLIPSIS + self.sub_fields = [self._create_sub_type(args[0], f'{self.name}_0')] + elif args == ((),): # Tuple[()] means empty tuple + self.shape = SHAPE_TUPLE + self.type_ = Any + self.sub_fields = [] + else: + self.shape = SHAPE_TUPLE + self.sub_fields = [self._create_sub_type(t, f'{self.name}_{i}') for i, t in enumerate(args)] + return + elif issubclass(origin, List): + # Create self validators + get_validators = getattr(self.type_, '__get_validators__', None) + if get_validators: + self.class_validators.update( + {f'list_{i}': Validator(validator, pre=True) for i, validator in enumerate(get_validators())} + ) + + self.type_ = get_args(self.type_)[0] + self.shape = SHAPE_LIST + elif issubclass(origin, Set): + # Create self validators + get_validators = getattr(self.type_, '__get_validators__', None) + if get_validators: + self.class_validators.update( + {f'set_{i}': Validator(validator, pre=True) for i, validator in enumerate(get_validators())} + ) + + self.type_ = get_args(self.type_)[0] + self.shape = SHAPE_SET + elif issubclass(origin, FrozenSet): + # Create self validators + get_validators = getattr(self.type_, '__get_validators__', None) + if get_validators: + self.class_validators.update( + {f'frozenset_{i}': Validator(validator, pre=True) for i, validator in enumerate(get_validators())} + ) + + self.type_ = get_args(self.type_)[0] + self.shape = SHAPE_FROZENSET + elif issubclass(origin, Deque): + self.type_ = get_args(self.type_)[0] + self.shape = SHAPE_DEQUE + elif issubclass(origin, Sequence): + self.type_ = get_args(self.type_)[0] + self.shape = SHAPE_SEQUENCE + # priority to most common mapping: dict + elif origin is dict or origin is Dict: + self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True) + self.type_ = get_args(self.type_)[1] + self.shape = SHAPE_DICT + elif issubclass(origin, DefaultDict): + self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True) + self.type_ = get_args(self.type_)[1] + self.shape = SHAPE_DEFAULTDICT + elif issubclass(origin, Counter): + self.key_field = self._create_sub_type(get_args(self.type_)[0], 'key_' + self.name, for_keys=True) + self.type_ = int + self.shape = SHAPE_COUNTER + elif issubclass(origin, Mapping): + self.key_field = self._create_sub_type(get_args(self.type_)[0], 
'key_' + self.name, for_keys=True) + self.type_ = get_args(self.type_)[1] + self.shape = SHAPE_MAPPING + # Equality check as almost everything inherits form Iterable, including str + # check for Iterable and CollectionsIterable, as it could receive one even when declared with the other + elif origin in {Iterable, CollectionsIterable}: + self.type_ = get_args(self.type_)[0] + self.shape = SHAPE_ITERABLE + self.sub_fields = [self._create_sub_type(self.type_, f'{self.name}_type')] + elif issubclass(origin, Type): # type: ignore + return + elif hasattr(origin, '__get_validators__') or self.model_config.arbitrary_types_allowed: + # Is a Pydantic-compatible generic that handles itself + # or we have arbitrary_types_allowed = True + self.shape = SHAPE_GENERIC + self.sub_fields = [self._create_sub_type(t, f'{self.name}_{i}') for i, t in enumerate(get_args(self.type_))] + self.type_ = origin + return + else: + raise TypeError(f'Fields of type "{origin}" are not supported.') + + # type_ has been refined eg. as the type of a List and sub_fields needs to be populated + self.sub_fields = [self._create_sub_type(self.type_, '_' + self.name)] + + def prepare_discriminated_union_sub_fields(self) -> None: + """ + Prepare the mapping -> and update `sub_fields` + Note that this process can be aborted if a `ForwardRef` is encountered + """ + assert self.discriminator_key is not None + + if self.type_.__class__ is DeferredType: + return + + assert self.sub_fields is not None + sub_fields_mapping: Dict[str, 'ModelField'] = {} + all_aliases: Set[str] = set() + + for sub_field in self.sub_fields: + t = sub_field.type_ + if t.__class__ is ForwardRef: + # Stopping everything...will need to call `update_forward_refs` + return + + alias, discriminator_values = get_discriminator_alias_and_values(t, self.discriminator_key) + all_aliases.add(alias) + for discriminator_value in discriminator_values: + sub_fields_mapping[discriminator_value] = sub_field + + self.sub_fields_mapping = sub_fields_mapping + self.discriminator_alias = get_unique_discriminator_alias(all_aliases, self.discriminator_key) + + def _create_sub_type(self, type_: Type[Any], name: str, *, for_keys: bool = False) -> 'ModelField': + if for_keys: + class_validators = None + else: + # validators for sub items should not have `each_item` as we want to check only the first sublevel + class_validators = { + k: Validator( + func=v.func, + pre=v.pre, + each_item=False, + always=v.always, + check_fields=v.check_fields, + skip_on_failure=v.skip_on_failure, + ) + for k, v in self.class_validators.items() + if v.each_item + } + + field_info, _ = self._get_field_info(name, type_, None, self.model_config) + + return self.__class__( + type_=type_, + name=name, + class_validators=class_validators, + model_config=self.model_config, + field_info=field_info, + ) + + def populate_validators(self) -> None: + """ + Prepare self.pre_validators, self.validators, and self.post_validators based on self.type_'s __get_validators__ + and class validators. This method should be idempotent, e.g. it should be safe to call multiple times + without mis-configuring the field. 
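# Illustrative sketch of what _type_analysis above records for common annotations, poking
# at ModelField internals (shape, key_field, allow_none) purely for inspection. Assumes
# pydantic v1.x importable as `pydantic`; the model is invented.
from typing import Dict, List, Optional
from pydantic import BaseModel

class Box(BaseModel):
    items: List[int]
    index: Dict[str, int]
    label: Optional[str]

fields = Box.__fields__
print(fields['items'].shape)       # 2 == SHAPE_LIST; type_ has been refined to int
print(fields['index'].key_field)   # sub ModelField that validates the str keys
print(fields['label'].allow_none)  # True: Optional[...] flips allow_none during analysis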
+ """ + self.validate_always = getattr(self.type_, 'validate_always', False) or any( + v.always for v in self.class_validators.values() + ) + + class_validators_ = self.class_validators.values() + if not self.sub_fields or self.shape == SHAPE_GENERIC: + get_validators = getattr(self.type_, '__get_validators__', None) + v_funcs = ( + *[v.func for v in class_validators_ if v.each_item and v.pre], + *(get_validators() if get_validators else list(find_validators(self.type_, self.model_config))), + *[v.func for v in class_validators_ if v.each_item and not v.pre], + ) + self.validators = prep_validators(v_funcs) + + self.pre_validators = [] + self.post_validators = [] + + if self.field_info and self.field_info.const: + self.post_validators.append(make_generic_validator(constant_validator)) + + if class_validators_: + self.pre_validators += prep_validators(v.func for v in class_validators_ if not v.each_item and v.pre) + self.post_validators += prep_validators(v.func for v in class_validators_ if not v.each_item and not v.pre) + + if self.parse_json: + self.pre_validators.append(make_generic_validator(validate_json)) + + self.pre_validators = self.pre_validators or None + self.post_validators = self.post_validators or None + + def validate( + self, v: Any, values: Dict[str, Any], *, loc: 'LocStr', cls: Optional['ModelOrDc'] = None + ) -> 'ValidateReturn': + assert self.type_.__class__ is not DeferredType + + if self.type_.__class__ is ForwardRef: + assert cls is not None + raise ConfigError( + f'field "{self.name}" not yet prepared so type is still a ForwardRef, ' + f'you might need to call {cls.__name__}.update_forward_refs().' + ) + + errors: Optional['ErrorList'] + if self.pre_validators: + v, errors = self._apply_validators(v, values, loc, cls, self.pre_validators) + if errors: + return v, errors + + if v is None: + if is_none_type(self.type_): + # keep validating + pass + elif self.allow_none: + if self.post_validators: + return self._apply_validators(v, values, loc, cls, self.post_validators) + else: + return None, None + else: + return v, ErrorWrapper(NoneIsNotAllowedError(), loc) + + if self.shape == SHAPE_SINGLETON: + v, errors = self._validate_singleton(v, values, loc, cls) + elif self.shape in MAPPING_LIKE_SHAPES: + v, errors = self._validate_mapping_like(v, values, loc, cls) + elif self.shape == SHAPE_TUPLE: + v, errors = self._validate_tuple(v, values, loc, cls) + elif self.shape == SHAPE_ITERABLE: + v, errors = self._validate_iterable(v, values, loc, cls) + elif self.shape == SHAPE_GENERIC: + v, errors = self._apply_validators(v, values, loc, cls, self.validators) + else: + # sequence, list, set, generator, tuple with ellipsis, frozen set + v, errors = self._validate_sequence_like(v, values, loc, cls) + + if not errors and self.post_validators: + v, errors = self._apply_validators(v, values, loc, cls, self.post_validators) + return v, errors + + def _validate_sequence_like( # noqa: C901 (ignore complexity) + self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] + ) -> 'ValidateReturn': + """ + Validate sequence-like containers: lists, tuples, sets and generators + Note that large if-else blocks are necessary to enable Cython + optimization, which is why we disable the complexity check above. 
+ """ + if not sequence_like(v): + e: errors_.PydanticTypeError + if self.shape == SHAPE_LIST: + e = errors_.ListError() + elif self.shape in (SHAPE_TUPLE, SHAPE_TUPLE_ELLIPSIS): + e = errors_.TupleError() + elif self.shape == SHAPE_SET: + e = errors_.SetError() + elif self.shape == SHAPE_FROZENSET: + e = errors_.FrozenSetError() + else: + e = errors_.SequenceError() + return v, ErrorWrapper(e, loc) + + loc = loc if isinstance(loc, tuple) else (loc,) + result = [] + errors: List[ErrorList] = [] + for i, v_ in enumerate(v): + v_loc = *loc, i + r, ee = self._validate_singleton(v_, values, v_loc, cls) + if ee: + errors.append(ee) + else: + result.append(r) + + if errors: + return v, errors + + converted: Union[List[Any], Set[Any], FrozenSet[Any], Tuple[Any, ...], Iterator[Any], Deque[Any]] = result + + if self.shape == SHAPE_SET: + converted = set(result) + elif self.shape == SHAPE_FROZENSET: + converted = frozenset(result) + elif self.shape == SHAPE_TUPLE_ELLIPSIS: + converted = tuple(result) + elif self.shape == SHAPE_DEQUE: + converted = deque(result, maxlen=getattr(v, 'maxlen', None)) + elif self.shape == SHAPE_SEQUENCE: + if isinstance(v, tuple): + converted = tuple(result) + elif isinstance(v, set): + converted = set(result) + elif isinstance(v, Generator): + converted = iter(result) + elif isinstance(v, deque): + converted = deque(result, maxlen=getattr(v, 'maxlen', None)) + return converted, None + + def _validate_iterable( + self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] + ) -> 'ValidateReturn': + """ + Validate Iterables. + + This intentionally doesn't validate values to allow infinite generators. + """ + + try: + iterable = iter(v) + except TypeError: + return v, ErrorWrapper(errors_.IterableError(), loc) + return iterable, None + + def _validate_tuple( + self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] + ) -> 'ValidateReturn': + e: Optional[Exception] = None + if not sequence_like(v): + e = errors_.TupleError() + else: + actual_length, expected_length = len(v), len(self.sub_fields) # type: ignore + if actual_length != expected_length: + e = errors_.TupleLengthError(actual_length=actual_length, expected_length=expected_length) + + if e: + return v, ErrorWrapper(e, loc) + + loc = loc if isinstance(loc, tuple) else (loc,) + result = [] + errors: List[ErrorList] = [] + for i, (v_, field) in enumerate(zip(v, self.sub_fields)): # type: ignore + v_loc = *loc, i + r, ee = field.validate(v_, values, loc=v_loc, cls=cls) + if ee: + errors.append(ee) + else: + result.append(r) + + if errors: + return v, errors + else: + return tuple(result), None + + def _validate_mapping_like( + self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] + ) -> 'ValidateReturn': + try: + v_iter = dict_validator(v) + except TypeError as exc: + return v, ErrorWrapper(exc, loc) + + loc = loc if isinstance(loc, tuple) else (loc,) + result, errors = {}, [] + for k, v_ in v_iter.items(): + v_loc = *loc, '__key__' + key_result, key_errors = self.key_field.validate(k, values, loc=v_loc, cls=cls) # type: ignore + if key_errors: + errors.append(key_errors) + continue + + v_loc = *loc, k + value_result, value_errors = self._validate_singleton(v_, values, v_loc, cls) + if value_errors: + errors.append(value_errors) + continue + + result[key_result] = value_result + if errors: + return v, errors + elif self.shape == SHAPE_DICT: + return result, None + elif self.shape == SHAPE_DEFAULTDICT: + return defaultdict(self.type_, result), None + 
elif self.shape == SHAPE_COUNTER: + return CollectionCounter(result), None + else: + return self._get_mapping_value(v, result), None + + def _get_mapping_value(self, original: T, converted: Dict[Any, Any]) -> Union[T, Dict[Any, Any]]: + """ + When type is `Mapping[KT, KV]` (or another unsupported mapping), we try to avoid + coercing to `dict` unwillingly. + """ + original_cls = original.__class__ + + if original_cls == dict or original_cls == Dict: + return converted + elif original_cls in {defaultdict, DefaultDict}: + return defaultdict(self.type_, converted) + else: + try: + # Counter, OrderedDict, UserDict, ... + return original_cls(converted) # type: ignore + except TypeError: + raise RuntimeError(f'Could not convert dictionary to {original_cls.__name__!r}') from None + + def _validate_singleton( + self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] + ) -> 'ValidateReturn': + if self.sub_fields: + if self.discriminator_key is not None: + return self._validate_discriminated_union(v, values, loc, cls) + + errors = [] + + if self.model_config.smart_union and is_union(get_origin(self.type_)): + # 1st pass: check if the value is an exact instance of one of the Union types + # (e.g. to avoid coercing a bool into an int) + for field in self.sub_fields: + if v.__class__ is field.outer_type_: + return v, None + + # 2nd pass: check if the value is an instance of any subclass of the Union types + for field in self.sub_fields: + # This whole logic will be improved later on to support more complex `isinstance` checks + # It will probably be done once a strict mode is added and be something like: + # ``` + # value, error = field.validate(v, values, strict=True) + # if error is None: + # return value, None + # ``` + try: + if isinstance(v, field.outer_type_): + return v, None + except TypeError: + # compound type + if lenient_isinstance(v, get_origin(field.outer_type_)): + value, error = field.validate(v, values, loc=loc, cls=cls) + if not error: + return value, None + + # 1st pass by default or 3rd pass with `smart_union` enabled: + # check if the value can be coerced into one of the Union types + for field in self.sub_fields: + value, error = field.validate(v, values, loc=loc, cls=cls) + if error: + errors.append(error) + else: + return value, None + return v, errors + else: + return self._apply_validators(v, values, loc, cls, self.validators) + + def _validate_discriminated_union( + self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'] + ) -> 'ValidateReturn': + assert self.discriminator_key is not None + assert self.discriminator_alias is not None + + try: + try: + discriminator_value = v[self.discriminator_alias] + except KeyError: + if self.model_config.allow_population_by_field_name: + discriminator_value = v[self.discriminator_key] + else: + raise + except KeyError: + return v, ErrorWrapper(MissingDiscriminator(discriminator_key=self.discriminator_key), loc) + except TypeError: + try: + # BaseModel or dataclass + discriminator_value = getattr(v, self.discriminator_key) + except (AttributeError, TypeError): + return v, ErrorWrapper(MissingDiscriminator(discriminator_key=self.discriminator_key), loc) + + if self.sub_fields_mapping is None: + assert cls is not None + raise ConfigError( + f'field "{self.name}" not yet prepared so type is still a ForwardRef, ' + f'you might need to call {cls.__name__}.update_forward_refs().' 
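# Illustrative sketch of the Union handling in _validate_singleton above: by default the
# first member that coerces successfully wins, while Config.smart_union first looks for an
# exact instance match. Assumes pydantic v1.9+; the models are invented.
from typing import Union
from pydantic import BaseModel

class LeftToRight(BaseModel):
    x: Union[int, str]

class Smart(BaseModel):
    class Config:
        smart_union = True
    x: Union[int, str]

print(LeftToRight(x='1').x)   # 1   -- coerced by the first (int) sub-field
print(Smart(x='1').x)         # '1' -- exact-instance pass matches the str sub-field first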
+ ) + + try: + sub_field = self.sub_fields_mapping[discriminator_value] + except (KeyError, TypeError): + # KeyError: `discriminator_value` is not in the dictionary. + # TypeError: `discriminator_value` is unhashable. + assert self.sub_fields_mapping is not None + return v, ErrorWrapper( + InvalidDiscriminator( + discriminator_key=self.discriminator_key, + discriminator_value=discriminator_value, + allowed_values=list(self.sub_fields_mapping), + ), + loc, + ) + else: + if not isinstance(loc, tuple): + loc = (loc,) + return sub_field.validate(v, values, loc=(*loc, display_as_type(sub_field.type_)), cls=cls) + + def _apply_validators( + self, v: Any, values: Dict[str, Any], loc: 'LocStr', cls: Optional['ModelOrDc'], validators: 'ValidatorsList' + ) -> 'ValidateReturn': + for validator in validators: + try: + v = validator(cls, v, values, self, self.model_config) + except (ValueError, TypeError, AssertionError) as exc: + return v, ErrorWrapper(exc, loc) + return v, None + + def is_complex(self) -> bool: + """ + Whether the field is "complex" eg. env variables should be parsed as JSON. + """ + from .main import BaseModel + + return ( + self.shape != SHAPE_SINGLETON + or hasattr(self.type_, '__pydantic_model__') + or lenient_issubclass(self.type_, (BaseModel, list, set, frozenset, dict)) + ) + + def _type_display(self) -> PyObjectStr: + t = display_as_type(self.type_) + + if self.shape in MAPPING_LIKE_SHAPES: + t = f'Mapping[{display_as_type(self.key_field.type_)}, {t}]' # type: ignore + elif self.shape == SHAPE_TUPLE: + t = 'Tuple[{}]'.format(', '.join(display_as_type(f.type_) for f in self.sub_fields)) # type: ignore + elif self.shape == SHAPE_GENERIC: + assert self.sub_fields + t = '{}[{}]'.format( + display_as_type(self.type_), ', '.join(display_as_type(f.type_) for f in self.sub_fields) + ) + elif self.shape != SHAPE_SINGLETON: + t = SHAPE_NAME_LOOKUP[self.shape].format(t) + + if self.allow_none and (self.shape != SHAPE_SINGLETON or not self.sub_fields): + t = f'Optional[{t}]' + return PyObjectStr(t) + + def __repr_args__(self) -> 'ReprArgs': + args = [('name', self.name), ('type', self._type_display()), ('required', self.required)] + + if not self.required: + if self.default_factory is not None: + args.append(('default_factory', f'')) + else: + args.append(('default', self.default)) + + if self.alt_alias: + args.append(('alias', self.alias)) + return args + + +class ModelPrivateAttr(Representation): + __slots__ = ('default', 'default_factory') + + def __init__(self, default: Any = Undefined, *, default_factory: Optional[NoArgAnyCallable] = None) -> None: + self.default = default + self.default_factory = default_factory + + def get_default(self) -> Any: + return smart_deepcopy(self.default) if self.default_factory is None else self.default_factory() + + def __eq__(self, other: Any) -> bool: + return isinstance(other, self.__class__) and (self.default, self.default_factory) == ( + other.default, + other.default_factory, + ) + + +def PrivateAttr( + default: Any = Undefined, + *, + default_factory: Optional[NoArgAnyCallable] = None, +) -> Any: + """ + Indicates that attribute is only used internally and never mixed with regular fields. + + Types or values of private attrs are not checked by pydantic and it's up to you to keep them relevant. + + Private attrs are stored in model __slots__. 
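# Illustrative sketch of the discriminated-union machinery above
# (prepare_discriminated_union_sub_fields / _validate_discriminated_union) from the
# caller's side. Assumes pydantic v1.9+ importable as `pydantic`; the models are invented.
from typing import Literal, Union
from pydantic import BaseModel, Field, ValidationError

class Cat(BaseModel):
    pet_type: Literal['cat']
    meows: int

class Dog(BaseModel):
    pet_type: Literal['dog']
    barks: float

class Owner(BaseModel):
    pet: Union[Cat, Dog] = Field(..., discriminator='pet_type')

print(Owner(pet={'pet_type': 'dog', 'barks': 2.5}).pet)   # only the Dog sub-field is tried

try:
    Owner(pet={'pet_type': 'fish'})
except ValidationError as exc:
    print(exc.errors()[0])   # InvalidDiscriminator error listing the allowed values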
+ + :param default: the attribute’s default value + :param default_factory: callable that will be called when a default value is needed for this attribute + If both `default` and `default_factory` are set, an error is raised. + """ + if default is not Undefined and default_factory is not None: + raise ValueError('cannot specify both default and default_factory') + + return ModelPrivateAttr( + default, + default_factory=default_factory, + ) + + +class DeferredType: + """ + Used to postpone field preparation, while creating recursive generic models. + """ + + +def is_finalvar_with_default_val(type_: Type[Any], val: Any) -> bool: + return is_finalvar(type_) and val is not Undefined and not isinstance(val, FieldInfo) diff --git a/lib/pydantic/v1/generics.py b/lib/pydantic/v1/generics.py new file mode 100644 index 00000000..a75b6b98 --- /dev/null +++ b/lib/pydantic/v1/generics.py @@ -0,0 +1,400 @@ +import sys +import types +import typing +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + ForwardRef, + Generic, + Iterator, + List, + Mapping, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, +) +from weakref import WeakKeyDictionary, WeakValueDictionary + +from typing_extensions import Annotated, Literal as ExtLiteral + +from .class_validators import gather_all_validators +from .fields import DeferredType +from .main import BaseModel, create_model +from .types import JsonWrapper +from .typing import display_as_type, get_all_type_hints, get_args, get_origin, typing_base +from .utils import all_identical, lenient_issubclass + +if sys.version_info >= (3, 10): + from typing import _UnionGenericAlias +if sys.version_info >= (3, 8): + from typing import Literal + +GenericModelT = TypeVar('GenericModelT', bound='GenericModel') +TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type + +CacheKey = Tuple[Type[Any], Any, Tuple[Any, ...]] +Parametrization = Mapping[TypeVarType, Type[Any]] + +# weak dictionaries allow the dynamically created parametrized versions of generic models to get collected +# once they are no longer referenced by the caller. +if sys.version_info >= (3, 9): # Typing for weak dictionaries available at 3.9 + GenericTypesCache = WeakValueDictionary[CacheKey, Type[BaseModel]] + AssignedParameters = WeakKeyDictionary[Type[BaseModel], Parametrization] +else: + GenericTypesCache = WeakValueDictionary + AssignedParameters = WeakKeyDictionary + +# _generic_types_cache is a Mapping from __class_getitem__ arguments to the parametrized version of generic models. +# This ensures multiple calls of e.g. A[B] return always the same class. +_generic_types_cache = GenericTypesCache() + +# _assigned_parameters is a Mapping from parametrized version of generic models to assigned types of parametrizations +# as captured during construction of the class (not instances). +# E.g., for generic model `Model[A, B]`, when parametrized model `Model[int, str]` is created, +# `Model[int, str]`: {A: int, B: str}` will be stored in `_assigned_parameters`. +# (This information is only otherwise available after creation from the class name string). +_assigned_parameters = AssignedParameters() + + +class GenericModel(BaseModel): + __slots__ = () + __concrete__: ClassVar[bool] = False + + if TYPE_CHECKING: + # Putting this in a TYPE_CHECKING block allows us to replace `if Generic not in cls.__bases__` with + # `not hasattr(cls, "__parameters__")`. 
This means we don't need to force non-concrete subclasses of + # `GenericModel` to also inherit from `Generic`, which would require changes to the use of `create_model` below. + __parameters__: ClassVar[Tuple[TypeVarType, ...]] + + # Setting the return type as Type[Any] instead of Type[BaseModel] prevents PyCharm warnings + def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[Type[Any], ...]]) -> Type[Any]: + """Instantiates a new class from a generic class `cls` and type variables `params`. + + :param params: Tuple of types the class . Given a generic class + `Model` with 2 type variables and a concrete model `Model[str, int]`, + the value `(str, int)` would be passed to `params`. + :return: New model class inheriting from `cls` with instantiated + types described by `params`. If no parameters are given, `cls` is + returned as is. + + """ + + def _cache_key(_params: Any) -> CacheKey: + args = get_args(_params) + # python returns a list for Callables, which is not hashable + if len(args) == 2 and isinstance(args[0], list): + args = (tuple(args[0]), args[1]) + return cls, _params, args + + cached = _generic_types_cache.get(_cache_key(params)) + if cached is not None: + return cached + if cls.__concrete__ and Generic not in cls.__bases__: + raise TypeError('Cannot parameterize a concrete instantiation of a generic model') + if not isinstance(params, tuple): + params = (params,) + if cls is GenericModel and any(isinstance(param, TypeVar) for param in params): + raise TypeError('Type parameters should be placed on typing.Generic, not GenericModel') + if not hasattr(cls, '__parameters__'): + raise TypeError(f'Type {cls.__name__} must inherit from typing.Generic before being parameterized') + + check_parameters_count(cls, params) + # Build map from generic typevars to passed params + typevars_map: Dict[TypeVarType, Type[Any]] = dict(zip(cls.__parameters__, params)) + if all_identical(typevars_map.keys(), typevars_map.values()) and typevars_map: + return cls # if arguments are equal to parameters it's the same object + + # Create new model with original model as parent inserting fields with DeferredType. + model_name = cls.__concrete_name__(params) + validators = gather_all_validators(cls) + + type_hints = get_all_type_hints(cls).items() + instance_type_hints = {k: v for k, v in type_hints if get_origin(v) is not ClassVar} + + fields = {k: (DeferredType(), cls.__fields__[k].field_info) for k in instance_type_hints if k in cls.__fields__} + + model_module, called_globally = get_caller_frame_info() + created_model = cast( + Type[GenericModel], # casting ensures mypy is aware of the __concrete__ and __parameters__ attributes + create_model( + model_name, + __module__=model_module or cls.__module__, + __base__=(cls,) + tuple(cls.__parameterized_bases__(typevars_map)), + __config__=None, + __validators__=validators, + __cls_kwargs__=None, + **fields, + ), + ) + + _assigned_parameters[created_model] = typevars_map + + if called_globally: # create global reference and therefore allow pickling + object_by_reference = None + reference_name = model_name + reference_module_globals = sys.modules[created_model.__module__].__dict__ + while object_by_reference is not created_model: + object_by_reference = reference_module_globals.setdefault(reference_name, created_model) + reference_name += '_' + + created_model.Config = cls.Config + + # Find any typevars that are still present in the model. 
+ # If none are left, the model is fully "concrete", otherwise the new + # class is a generic class as well taking the found typevars as + # parameters. + new_params = tuple( + {param: None for param in iter_contained_typevars(typevars_map.values())} + ) # use dict as ordered set + created_model.__concrete__ = not new_params + if new_params: + created_model.__parameters__ = new_params + + # Save created model in cache so we don't end up creating duplicate + # models that should be identical. + _generic_types_cache[_cache_key(params)] = created_model + if len(params) == 1: + _generic_types_cache[_cache_key(params[0])] = created_model + + # Recursively walk class type hints and replace generic typevars + # with concrete types that were passed. + _prepare_model_fields(created_model, fields, instance_type_hints, typevars_map) + + return created_model + + @classmethod + def __concrete_name__(cls: Type[Any], params: Tuple[Type[Any], ...]) -> str: + """Compute class name for child classes. + + :param params: Tuple of types the class . Given a generic class + `Model` with 2 type variables and a concrete model `Model[str, int]`, + the value `(str, int)` would be passed to `params`. + :return: String representing a the new class where `params` are + passed to `cls` as type variables. + + This method can be overridden to achieve a custom naming scheme for GenericModels. + """ + param_names = [display_as_type(param) for param in params] + params_component = ', '.join(param_names) + return f'{cls.__name__}[{params_component}]' + + @classmethod + def __parameterized_bases__(cls, typevars_map: Parametrization) -> Iterator[Type[Any]]: + """ + Returns unbound bases of cls parameterised to given type variables + + :param typevars_map: Dictionary of type applications for binding subclasses. + Given a generic class `Model` with 2 type variables [S, T] + and a concrete model `Model[str, int]`, + the value `{S: str, T: int}` would be passed to `typevars_map`. + :return: an iterator of generic sub classes, parameterised by `typevars_map` + and other assigned parameters of `cls` + + e.g.: + ``` + class A(GenericModel, Generic[T]): + ... + + class B(A[V], Generic[V]): + ... + + assert A[int] in B.__parameterized_bases__({V: int}) + ``` + """ + + def build_base_model( + base_model: Type[GenericModel], mapped_types: Parametrization + ) -> Iterator[Type[GenericModel]]: + base_parameters = tuple(mapped_types[param] for param in base_model.__parameters__) + parameterized_base = base_model.__class_getitem__(base_parameters) + if parameterized_base is base_model or parameterized_base is cls: + # Avoid duplication in MRO + return + yield parameterized_base + + for base_model in cls.__bases__: + if not issubclass(base_model, GenericModel): + # not a class that can be meaningfully parameterized + continue + elif not getattr(base_model, '__parameters__', None): + # base_model is "GenericModel" (and has no __parameters__) + # or + # base_model is already concrete, and will be included transitively via cls. + continue + elif cls in _assigned_parameters: + if base_model in _assigned_parameters: + # cls is partially parameterised but not from base_model + # e.g. cls = B[S], base_model = A[S] + # B[S][int] should subclass A[int], (and will be transitively via B[int]) + # but it's not viable to consistently subclass types with arbitrary construction + # So don't attempt to include A[S][int] + continue + else: # base_model not in _assigned_parameters: + # cls is partially parameterized, base_model is original generic + # e.g. 
cls = B[str, T], base_model = B[S, T] + # Need to determine the mapping for the base_model parameters + mapped_types: Parametrization = { + key: typevars_map.get(value, value) for key, value in _assigned_parameters[cls].items() + } + yield from build_base_model(base_model, mapped_types) + else: + # cls is base generic, so base_class has a distinct base + # can construct the Parameterised base model using typevars_map directly + yield from build_base_model(base_model, typevars_map) + + +def replace_types(type_: Any, type_map: Mapping[Any, Any]) -> Any: + """Return type with all occurrences of `type_map` keys recursively replaced with their values. + + :param type_: Any type, class or generic alias + :param type_map: Mapping from `TypeVar` instance to concrete types. + :return: New type representing the basic structure of `type_` with all + `typevar_map` keys recursively replaced. + + >>> replace_types(Tuple[str, Union[List[str], float]], {str: int}) + Tuple[int, Union[List[int], float]] + + """ + if not type_map: + return type_ + + type_args = get_args(type_) + origin_type = get_origin(type_) + + if origin_type is Annotated: + annotated_type, *annotations = type_args + return Annotated[replace_types(annotated_type, type_map), tuple(annotations)] + + if (origin_type is ExtLiteral) or (sys.version_info >= (3, 8) and origin_type is Literal): + return type_map.get(type_, type_) + # Having type args is a good indicator that this is a typing module + # class instantiation or a generic alias of some sort. + if type_args: + resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args) + if all_identical(type_args, resolved_type_args): + # If all arguments are the same, there is no need to modify the + # type or create a new object at all + return type_ + if ( + origin_type is not None + and isinstance(type_, typing_base) + and not isinstance(origin_type, typing_base) + and getattr(type_, '_name', None) is not None + ): + # In python < 3.9 generic aliases don't exist so any of these like `list`, + # `type` or `collections.abc.Callable` need to be translated. + # See: https://www.python.org/dev/peps/pep-0585 + origin_type = getattr(typing, type_._name) + assert origin_type is not None + # PEP-604 syntax (Ex.: list | str) is represented with a types.UnionType object that does not have __getitem__. + # We also cannot use isinstance() since we have to compare types. + if sys.version_info >= (3, 10) and origin_type is types.UnionType: # noqa: E721 + return _UnionGenericAlias(origin_type, resolved_type_args) + return origin_type[resolved_type_args] + + # We handle pydantic generic models separately as they don't have the same + # semantics as "typing" classes or generic aliases + if not origin_type and lenient_issubclass(type_, GenericModel) and not type_.__concrete__: + type_args = type_.__parameters__ + resolved_type_args = tuple(replace_types(t, type_map) for t in type_args) + if all_identical(type_args, resolved_type_args): + return type_ + return type_[resolved_type_args] + + # Handle special case for typehints that can have lists as arguments. + # `typing.Callable[[int, str], int]` is an example for this. 
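# Illustrative sketch of GenericModel parametrization as implemented by __class_getitem__
# above. Assumes pydantic v1.x, where GenericModel lives in pydantic.generics
# (pydantic.v1.generics in this vendored layout); the model is invented.
from typing import Generic, List, TypeVar
from pydantic.generics import GenericModel

T = TypeVar('T')

class Page(GenericModel, Generic[T]):
    items: List[T]
    total: int

IntPage = Page[int]                              # concrete subclass with T replaced by int
print(IntPage(items=['1', 2], total=2).items)    # [1, 2]
assert Page[int] is IntPage                      # _generic_types_cache returns the same class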
+ if isinstance(type_, (List, list)): + resolved_list = list(replace_types(element, type_map) for element in type_) + if all_identical(type_, resolved_list): + return type_ + return resolved_list + + # For JsonWrapperValue, need to handle its inner type to allow correct parsing + # of generic Json arguments like Json[T] + if not origin_type and lenient_issubclass(type_, JsonWrapper): + type_.inner_type = replace_types(type_.inner_type, type_map) + return type_ + + # If all else fails, we try to resolve the type directly and otherwise just + # return the input with no modifications. + new_type = type_map.get(type_, type_) + # Convert string to ForwardRef + if isinstance(new_type, str): + return ForwardRef(new_type) + else: + return new_type + + +def check_parameters_count(cls: Type[GenericModel], parameters: Tuple[Any, ...]) -> None: + actual = len(parameters) + expected = len(cls.__parameters__) + if actual != expected: + description = 'many' if actual > expected else 'few' + raise TypeError(f'Too {description} parameters for {cls.__name__}; actual {actual}, expected {expected}') + + +DictValues: Type[Any] = {}.values().__class__ + + +def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]: + """Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found.""" + if isinstance(v, TypeVar): + yield v + elif hasattr(v, '__parameters__') and not get_origin(v) and lenient_issubclass(v, GenericModel): + yield from v.__parameters__ + elif isinstance(v, (DictValues, list)): + for var in v: + yield from iter_contained_typevars(var) + else: + args = get_args(v) + for arg in args: + yield from iter_contained_typevars(arg) + + +def get_caller_frame_info() -> Tuple[Optional[str], bool]: + """ + Used inside a function to check whether it was called globally + + Will only work against non-compiled code, therefore used only in pydantic.generics + + :returns Tuple[module_name, called_globally] + """ + try: + previous_caller_frame = sys._getframe(2) + except ValueError as e: + raise RuntimeError('This function must be used inside another function') from e + except AttributeError: # sys module does not have _getframe function, so there's nothing we can do about it + return None, False + frame_globals = previous_caller_frame.f_globals + return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals + + +def _prepare_model_fields( + created_model: Type[GenericModel], + fields: Mapping[str, Any], + instance_type_hints: Mapping[str, type], + typevars_map: Mapping[Any, type], +) -> None: + """ + Replace DeferredType fields with concrete type hints and prepare them. 
+ """ + + for key, field in created_model.__fields__.items(): + if key not in fields: + assert field.type_.__class__ is not DeferredType + # https://github.com/nedbat/coveragepy/issues/198 + continue # pragma: no cover + + assert field.type_.__class__ is DeferredType, field.type_.__class__ + + field_type_hint = instance_type_hints[key] + concrete_type = replace_types(field_type_hint, typevars_map) + field.type_ = concrete_type + field.outer_type_ = concrete_type + field.prepare() + created_model.__annotations__[key] = concrete_type diff --git a/lib/pydantic/v1/json.py b/lib/pydantic/v1/json.py new file mode 100644 index 00000000..b358b850 --- /dev/null +++ b/lib/pydantic/v1/json.py @@ -0,0 +1,112 @@ +import datetime +from collections import deque +from decimal import Decimal +from enum import Enum +from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network +from pathlib import Path +from re import Pattern +from types import GeneratorType +from typing import Any, Callable, Dict, Type, Union +from uuid import UUID + +from .color import Color +from .networks import NameEmail +from .types import SecretBytes, SecretStr + +__all__ = 'pydantic_encoder', 'custom_pydantic_encoder', 'timedelta_isoformat' + + +def isoformat(o: Union[datetime.date, datetime.time]) -> str: + return o.isoformat() + + +def decimal_encoder(dec_value: Decimal) -> Union[int, float]: + """ + Encodes a Decimal as int of there's no exponent, otherwise float + + This is useful when we use ConstrainedDecimal to represent Numeric(x,0) + where a integer (but not int typed) is used. Encoding this as a float + results in failed round-tripping between encode and parse. + Our Id type is a prime example of this. + + >>> decimal_encoder(Decimal("1.0")) + 1.0 + + >>> decimal_encoder(Decimal("1")) + 1 + """ + if dec_value.as_tuple().exponent >= 0: + return int(dec_value) + else: + return float(dec_value) + + +ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { + bytes: lambda o: o.decode(), + Color: str, + datetime.date: isoformat, + datetime.datetime: isoformat, + datetime.time: isoformat, + datetime.timedelta: lambda td: td.total_seconds(), + Decimal: decimal_encoder, + Enum: lambda o: o.value, + frozenset: list, + deque: list, + GeneratorType: list, + IPv4Address: str, + IPv4Interface: str, + IPv4Network: str, + IPv6Address: str, + IPv6Interface: str, + IPv6Network: str, + NameEmail: str, + Path: str, + Pattern: lambda o: o.pattern, + SecretBytes: str, + SecretStr: str, + set: list, + UUID: str, +} + + +def pydantic_encoder(obj: Any) -> Any: + from dataclasses import asdict, is_dataclass + + from .main import BaseModel + + if isinstance(obj, BaseModel): + return obj.dict() + elif is_dataclass(obj): + return asdict(obj) + + # Check the class type and its superclasses for a matching encoder + for base in obj.__class__.__mro__[:-1]: + try: + encoder = ENCODERS_BY_TYPE[base] + except KeyError: + continue + return encoder(obj) + else: # We have exited the for loop without finding a suitable encoder + raise TypeError(f"Object of type '{obj.__class__.__name__}' is not JSON serializable") + + +def custom_pydantic_encoder(type_encoders: Dict[Any, Callable[[Type[Any]], Any]], obj: Any) -> Any: + # Check the class type and its superclasses for a matching encoder + for base in obj.__class__.__mro__[:-1]: + try: + encoder = type_encoders[base] + except KeyError: + continue + + return encoder(obj) + else: # We have exited the for loop without finding a suitable encoder + return 
pydantic_encoder(obj) + + +def timedelta_isoformat(td: datetime.timedelta) -> str: + """ + ISO 8601 encoding for Python timedelta object. + """ + minutes, seconds = divmod(td.seconds, 60) + hours, minutes = divmod(minutes, 60) + return f'{"-" if td.days < 0 else ""}P{abs(td.days)}DT{hours:d}H{minutes:d}M{seconds:d}.{td.microseconds:06d}S' diff --git a/lib/pydantic/v1/main.py b/lib/pydantic/v1/main.py new file mode 100644 index 00000000..08b8af55 --- /dev/null +++ b/lib/pydantic/v1/main.py @@ -0,0 +1,1107 @@ +import warnings +from abc import ABCMeta +from copy import deepcopy +from enum import Enum +from functools import partial +from pathlib import Path +from types import FunctionType, prepare_class, resolve_bases +from typing import ( + TYPE_CHECKING, + AbstractSet, + Any, + Callable, + ClassVar, + Dict, + List, + Mapping, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, + no_type_check, + overload, +) + +from typing_extensions import dataclass_transform + +from .class_validators import ValidatorGroup, extract_root_validators, extract_validators, inherit_validators +from .config import BaseConfig, Extra, inherit_config, prepare_config +from .error_wrappers import ErrorWrapper, ValidationError +from .errors import ConfigError, DictError, ExtraError, MissingError +from .fields import ( + MAPPING_LIKE_SHAPES, + Field, + ModelField, + ModelPrivateAttr, + PrivateAttr, + Undefined, + is_finalvar_with_default_val, +) +from .json import custom_pydantic_encoder, pydantic_encoder +from .parse import Protocol, load_file, load_str_bytes +from .schema import default_ref_template, model_schema +from .types import PyObject, StrBytes +from .typing import ( + AnyCallable, + get_args, + get_origin, + is_classvar, + is_namedtuple, + is_union, + resolve_annotations, + update_model_forward_refs, +) +from .utils import ( + DUNDER_ATTRIBUTES, + ROOT_KEY, + ClassAttribute, + GetterDict, + Representation, + ValueItems, + generate_model_signature, + is_valid_field, + is_valid_private_name, + lenient_issubclass, + sequence_like, + smart_deepcopy, + unique_list, + validate_field_name, +) + +if TYPE_CHECKING: + from inspect import Signature + + from .class_validators import ValidatorListDict + from .types import ModelOrDc + from .typing import ( + AbstractSetIntStr, + AnyClassMethod, + CallableGenerator, + DictAny, + DictStrAny, + MappingIntStrAny, + ReprArgs, + SetStr, + TupleGenerator, + ) + + Model = TypeVar('Model', bound='BaseModel') + +__all__ = 'BaseModel', 'create_model', 'validate_model' + +_T = TypeVar('_T') + + +def validate_custom_root_type(fields: Dict[str, ModelField]) -> None: + if len(fields) > 1: + raise ValueError(f'{ROOT_KEY} cannot be mixed with other fields') + + +def generate_hash_function(frozen: bool) -> Optional[Callable[[Any], int]]: + def hash_function(self_: Any) -> int: + return hash(self_.__class__) + hash(tuple(self_.__dict__.values())) + + return hash_function if frozen else None + + +# If a field is of type `Callable`, its default value should be a function and cannot to ignored. +ANNOTATED_FIELD_UNTOUCHED_TYPES: Tuple[Any, ...] = (property, type, classmethod, staticmethod) +# When creating a `BaseModel` instance, we bypass all the methods, properties... added to the model +UNTOUCHED_TYPES: Tuple[Any, ...] 
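# Illustrative sketch of how the encoder table above is used through .json(), and of the
# opt-in timedelta_isoformat helper. Assumes pydantic v1.x importable as `pydantic`
# (pydantic.v1 in this vendored layout); the model is invented.
import datetime
from pydantic import BaseModel
from pydantic.json import timedelta_isoformat

class Job(BaseModel):
    class Config:
        # replaces the default "total seconds as float" timedelta encoder
        json_encoders = {datetime.timedelta: timedelta_isoformat}

    runtime: datetime.timedelta

print(Job(runtime=datetime.timedelta(hours=1, minutes=30)).json())
# {"runtime": "P0DT1H30M0.000000S"}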
= (FunctionType,) + ANNOTATED_FIELD_UNTOUCHED_TYPES +# Note `ModelMetaclass` refers to `BaseModel`, but is also used to *create* `BaseModel`, so we need to add this extra +# (somewhat hacky) boolean to keep track of whether we've created the `BaseModel` class yet, and therefore whether it's +# safe to refer to it. If it *hasn't* been created, we assume that the `__new__` call we're in the middle of is for +# the `BaseModel` class, since that's defined immediately after the metaclass. +_is_base_model_class_defined = False + + +@dataclass_transform(kw_only_default=True, field_specifiers=(Field,)) +class ModelMetaclass(ABCMeta): + @no_type_check # noqa C901 + def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901 + fields: Dict[str, ModelField] = {} + config = BaseConfig + validators: 'ValidatorListDict' = {} + + pre_root_validators, post_root_validators = [], [] + private_attributes: Dict[str, ModelPrivateAttr] = {} + base_private_attributes: Dict[str, ModelPrivateAttr] = {} + slots: SetStr = namespace.get('__slots__', ()) + slots = {slots} if isinstance(slots, str) else set(slots) + class_vars: SetStr = set() + hash_func: Optional[Callable[[Any], int]] = None + + for base in reversed(bases): + if _is_base_model_class_defined and issubclass(base, BaseModel) and base != BaseModel: + fields.update(smart_deepcopy(base.__fields__)) + config = inherit_config(base.__config__, config) + validators = inherit_validators(base.__validators__, validators) + pre_root_validators += base.__pre_root_validators__ + post_root_validators += base.__post_root_validators__ + base_private_attributes.update(base.__private_attributes__) + class_vars.update(base.__class_vars__) + hash_func = base.__hash__ + + resolve_forward_refs = kwargs.pop('__resolve_forward_refs__', True) + allowed_config_kwargs: SetStr = { + key + for key in dir(config) + if not (key.startswith('__') and key.endswith('__')) # skip dunder methods and attributes + } + config_kwargs = {key: kwargs.pop(key) for key in kwargs.keys() & allowed_config_kwargs} + config_from_namespace = namespace.get('Config') + if config_kwargs and config_from_namespace: + raise TypeError('Specifying config in two places is ambiguous, use either Config attribute or class kwargs') + config = inherit_config(config_from_namespace, config, **config_kwargs) + + validators = inherit_validators(extract_validators(namespace), validators) + vg = ValidatorGroup(validators) + + for f in fields.values(): + f.set_config(config) + extra_validators = vg.get_validators(f.name) + if extra_validators: + f.class_validators.update(extra_validators) + # re-run prepare to add extra validators + f.populate_validators() + + prepare_config(config, name) + + untouched_types = ANNOTATED_FIELD_UNTOUCHED_TYPES + + def is_untouched(v: Any) -> bool: + return isinstance(v, untouched_types) or v.__class__.__name__ == 'cython_function_or_method' + + if (namespace.get('__module__'), namespace.get('__qualname__')) != ('pydantic.main', 'BaseModel'): + annotations = resolve_annotations(namespace.get('__annotations__', {}), namespace.get('__module__', None)) + # annotation only fields need to come first in fields + for ann_name, ann_type in annotations.items(): + if is_classvar(ann_type): + class_vars.add(ann_name) + elif is_finalvar_with_default_val(ann_type, namespace.get(ann_name, Undefined)): + class_vars.add(ann_name) + elif is_valid_field(ann_name): + validate_field_name(bases, ann_name) + value = namespace.get(ann_name, Undefined) + allowed_types = get_args(ann_type) if 
is_union(get_origin(ann_type)) else (ann_type,) + if ( + is_untouched(value) + and ann_type != PyObject + and not any( + lenient_issubclass(get_origin(allowed_type), Type) for allowed_type in allowed_types + ) + ): + continue + fields[ann_name] = ModelField.infer( + name=ann_name, + value=value, + annotation=ann_type, + class_validators=vg.get_validators(ann_name), + config=config, + ) + elif ann_name not in namespace and config.underscore_attrs_are_private: + private_attributes[ann_name] = PrivateAttr() + + untouched_types = UNTOUCHED_TYPES + config.keep_untouched + for var_name, value in namespace.items(): + can_be_changed = var_name not in class_vars and not is_untouched(value) + if isinstance(value, ModelPrivateAttr): + if not is_valid_private_name(var_name): + raise NameError( + f'Private attributes "{var_name}" must not be a valid field name; ' + f'Use sunder or dunder names, e. g. "_{var_name}" or "__{var_name}__"' + ) + private_attributes[var_name] = value + elif config.underscore_attrs_are_private and is_valid_private_name(var_name) and can_be_changed: + private_attributes[var_name] = PrivateAttr(default=value) + elif is_valid_field(var_name) and var_name not in annotations and can_be_changed: + validate_field_name(bases, var_name) + inferred = ModelField.infer( + name=var_name, + value=value, + annotation=annotations.get(var_name, Undefined), + class_validators=vg.get_validators(var_name), + config=config, + ) + if var_name in fields: + if lenient_issubclass(inferred.type_, fields[var_name].type_): + inferred.type_ = fields[var_name].type_ + else: + raise TypeError( + f'The type of {name}.{var_name} differs from the new default value; ' + f'if you wish to change the type of this field, please use a type annotation' + ) + fields[var_name] = inferred + + _custom_root_type = ROOT_KEY in fields + if _custom_root_type: + validate_custom_root_type(fields) + vg.check_for_unused() + if config.json_encoders: + json_encoder = partial(custom_pydantic_encoder, config.json_encoders) + else: + json_encoder = pydantic_encoder + pre_rv_new, post_rv_new = extract_root_validators(namespace) + + if hash_func is None: + hash_func = generate_hash_function(config.frozen) + + exclude_from_namespace = fields | private_attributes.keys() | {'__slots__'} + new_namespace = { + '__config__': config, + '__fields__': fields, + '__exclude_fields__': { + name: field.field_info.exclude for name, field in fields.items() if field.field_info.exclude is not None + } + or None, + '__include_fields__': { + name: field.field_info.include for name, field in fields.items() if field.field_info.include is not None + } + or None, + '__validators__': vg.validators, + '__pre_root_validators__': unique_list( + pre_root_validators + pre_rv_new, + name_factory=lambda v: v.__name__, + ), + '__post_root_validators__': unique_list( + post_root_validators + post_rv_new, + name_factory=lambda skip_on_failure_and_v: skip_on_failure_and_v[1].__name__, + ), + '__schema_cache__': {}, + '__json_encoder__': staticmethod(json_encoder), + '__custom_root_type__': _custom_root_type, + '__private_attributes__': {**base_private_attributes, **private_attributes}, + '__slots__': slots | private_attributes.keys(), + '__hash__': hash_func, + '__class_vars__': class_vars, + **{n: v for n, v in namespace.items() if n not in exclude_from_namespace}, + } + + cls = super().__new__(mcs, name, bases, new_namespace, **kwargs) + # set __signature__ attr only for model class, but not for its instances + cls.__signature__ = ClassAttribute('__signature__', 
generate_model_signature(cls.__init__, fields, config)) + if resolve_forward_refs: + cls.__try_update_forward_refs__() + + # preserve `__set_name__` protocol defined in https://peps.python.org/pep-0487 + # for attributes not in `new_namespace` (e.g. private attributes) + for name, obj in namespace.items(): + if name not in new_namespace: + set_name = getattr(obj, '__set_name__', None) + if callable(set_name): + set_name(cls, name) + + return cls + + def __instancecheck__(self, instance: Any) -> bool: + """ + Avoid calling ABC _abc_subclasscheck unless we're pretty sure. + + See #3829 and python/cpython#92810 + """ + return hasattr(instance, '__fields__') and super().__instancecheck__(instance) + + +object_setattr = object.__setattr__ + + +class BaseModel(Representation, metaclass=ModelMetaclass): + if TYPE_CHECKING: + # populated by the metaclass, defined here to help IDEs only + __fields__: ClassVar[Dict[str, ModelField]] = {} + __include_fields__: ClassVar[Optional[Mapping[str, Any]]] = None + __exclude_fields__: ClassVar[Optional[Mapping[str, Any]]] = None + __validators__: ClassVar[Dict[str, AnyCallable]] = {} + __pre_root_validators__: ClassVar[List[AnyCallable]] + __post_root_validators__: ClassVar[List[Tuple[bool, AnyCallable]]] + __config__: ClassVar[Type[BaseConfig]] = BaseConfig + __json_encoder__: ClassVar[Callable[[Any], Any]] = lambda x: x + __schema_cache__: ClassVar['DictAny'] = {} + __custom_root_type__: ClassVar[bool] = False + __signature__: ClassVar['Signature'] + __private_attributes__: ClassVar[Dict[str, ModelPrivateAttr]] + __class_vars__: ClassVar[SetStr] + __fields_set__: ClassVar[SetStr] = set() + + Config = BaseConfig + __slots__ = ('__dict__', '__fields_set__') + __doc__ = '' # Null out the Representation docstring + + def __init__(__pydantic_self__, **data: Any) -> None: + """ + Create a new model by parsing and validating input data from keyword arguments. + + Raises ValidationError if the input data cannot be parsed to form a valid model. 
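+
+        A minimal illustrative sketch (the `User` model below is a made-up example)::
+
+            class User(BaseModel):
+                id: int
+                name: str = 'Jane'
+
+            User(id='123')    # the str '123' is coerced to the int 123; name defaults to 'Jane'
+            User(name='Ann')  # raises ValidationError because the required `id` field is missing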
+ """ + # Uses something other than `self` the first arg to allow "self" as a settable attribute + values, fields_set, validation_error = validate_model(__pydantic_self__.__class__, data) + if validation_error: + raise validation_error + try: + object_setattr(__pydantic_self__, '__dict__', values) + except TypeError as e: + raise TypeError( + 'Model values must be a dict; you may not have returned a dictionary from a root validator' + ) from e + object_setattr(__pydantic_self__, '__fields_set__', fields_set) + __pydantic_self__._init_private_attributes() + + @no_type_check + def __setattr__(self, name, value): # noqa: C901 (ignore complexity) + if name in self.__private_attributes__ or name in DUNDER_ATTRIBUTES: + return object_setattr(self, name, value) + + if self.__config__.extra is not Extra.allow and name not in self.__fields__: + raise ValueError(f'"{self.__class__.__name__}" object has no field "{name}"') + elif not self.__config__.allow_mutation or self.__config__.frozen: + raise TypeError(f'"{self.__class__.__name__}" is immutable and does not support item assignment') + elif name in self.__fields__ and self.__fields__[name].final: + raise TypeError( + f'"{self.__class__.__name__}" object "{name}" field is final and does not support reassignment' + ) + elif self.__config__.validate_assignment: + new_values = {**self.__dict__, name: value} + + for validator in self.__pre_root_validators__: + try: + new_values = validator(self.__class__, new_values) + except (ValueError, TypeError, AssertionError) as exc: + raise ValidationError([ErrorWrapper(exc, loc=ROOT_KEY)], self.__class__) + + known_field = self.__fields__.get(name, None) + if known_field: + # We want to + # - make sure validators are called without the current value for this field inside `values` + # - keep other values (e.g. submodels) untouched (using `BaseModel.dict()` will change them into dicts) + # - keep the order of the fields + if not known_field.field_info.allow_mutation: + raise TypeError(f'"{known_field.name}" has allow_mutation set to False and cannot be assigned') + dict_without_original_value = {k: v for k, v in self.__dict__.items() if k != name} + value, error_ = known_field.validate(value, dict_without_original_value, loc=name, cls=self.__class__) + if error_: + raise ValidationError([error_], self.__class__) + else: + new_values[name] = value + + errors = [] + for skip_on_failure, validator in self.__post_root_validators__: + if skip_on_failure and errors: + continue + try: + new_values = validator(self.__class__, new_values) + except (ValueError, TypeError, AssertionError) as exc: + errors.append(ErrorWrapper(exc, loc=ROOT_KEY)) + if errors: + raise ValidationError(errors, self.__class__) + + # update the whole __dict__ as other values than just `value` + # may be changed (e.g. 
with `root_validator`) + object_setattr(self, '__dict__', new_values) + else: + self.__dict__[name] = value + + self.__fields_set__.add(name) + + def __getstate__(self) -> 'DictAny': + private_attrs = ((k, getattr(self, k, Undefined)) for k in self.__private_attributes__) + return { + '__dict__': self.__dict__, + '__fields_set__': self.__fields_set__, + '__private_attribute_values__': {k: v for k, v in private_attrs if v is not Undefined}, + } + + def __setstate__(self, state: 'DictAny') -> None: + object_setattr(self, '__dict__', state['__dict__']) + object_setattr(self, '__fields_set__', state['__fields_set__']) + for name, value in state.get('__private_attribute_values__', {}).items(): + object_setattr(self, name, value) + + def _init_private_attributes(self) -> None: + for name, private_attr in self.__private_attributes__.items(): + default = private_attr.get_default() + if default is not Undefined: + object_setattr(self, name, default) + + def dict( + self, + *, + include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> 'DictStrAny': + """ + Generate a dictionary representation of the model, optionally specifying which fields to include or exclude. + + """ + if skip_defaults is not None: + warnings.warn( + f'{self.__class__.__name__}.dict(): "skip_defaults" is deprecated and replaced by "exclude_unset"', + DeprecationWarning, + ) + exclude_unset = skip_defaults + + return dict( + self._iter( + to_dict=True, + by_alias=by_alias, + include=include, + exclude=exclude, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + ) + + def json( + self, + *, + include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + by_alias: bool = False, + skip_defaults: Optional[bool] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + encoder: Optional[Callable[[Any], Any]] = None, + models_as_dict: bool = True, + **dumps_kwargs: Any, + ) -> str: + """ + Generate a JSON representation of the model, `include` and `exclude` arguments as per `dict()`. + + `encoder` is an optional function to supply as `default` to json.dumps(), other arguments as per `json.dumps()`. + """ + if skip_defaults is not None: + warnings.warn( + f'{self.__class__.__name__}.json(): "skip_defaults" is deprecated and replaced by "exclude_unset"', + DeprecationWarning, + ) + exclude_unset = skip_defaults + encoder = cast(Callable[[Any], Any], encoder or self.__json_encoder__) + + # We don't directly call `self.dict()`, which does exactly this with `to_dict=True` + # because we want to be able to keep raw `BaseModel` instances and not as `dict`. + # This allows users to write custom JSON encoders for given `BaseModel` classes. 
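+        # As an illustrative sketch (not something defined in this module), a user model can
+        # register such an encoder through its Config, e.g.:
+        #
+        #     class Inner(BaseModel):
+        #         x: int
+        #
+        #     class Outer(BaseModel):
+        #         inner: Inner
+        #
+        #         class Config:
+        #             json_encoders = {Inner: lambda m: f'Inner({m.x})'}
+        #
+        # Outer(inner=Inner(x=1)).json(models_as_dict=False) then passes the raw Inner
+        # instance to that lambda instead of first converting it to a dict.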
+ data = dict( + self._iter( + to_dict=models_as_dict, + by_alias=by_alias, + include=include, + exclude=exclude, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + ) + if self.__custom_root_type__: + data = data[ROOT_KEY] + return self.__config__.json_dumps(data, default=encoder, **dumps_kwargs) + + @classmethod + def _enforce_dict_if_root(cls, obj: Any) -> Any: + if cls.__custom_root_type__ and ( + not (isinstance(obj, dict) and obj.keys() == {ROOT_KEY}) + and not (isinstance(obj, BaseModel) and obj.__fields__.keys() == {ROOT_KEY}) + or cls.__fields__[ROOT_KEY].shape in MAPPING_LIKE_SHAPES + ): + return {ROOT_KEY: obj} + else: + return obj + + @classmethod + def parse_obj(cls: Type['Model'], obj: Any) -> 'Model': + obj = cls._enforce_dict_if_root(obj) + if not isinstance(obj, dict): + try: + obj = dict(obj) + except (TypeError, ValueError) as e: + exc = TypeError(f'{cls.__name__} expected dict not {obj.__class__.__name__}') + raise ValidationError([ErrorWrapper(exc, loc=ROOT_KEY)], cls) from e + return cls(**obj) + + @classmethod + def parse_raw( + cls: Type['Model'], + b: StrBytes, + *, + content_type: str = None, + encoding: str = 'utf8', + proto: Protocol = None, + allow_pickle: bool = False, + ) -> 'Model': + try: + obj = load_str_bytes( + b, + proto=proto, + content_type=content_type, + encoding=encoding, + allow_pickle=allow_pickle, + json_loads=cls.__config__.json_loads, + ) + except (ValueError, TypeError, UnicodeDecodeError) as e: + raise ValidationError([ErrorWrapper(e, loc=ROOT_KEY)], cls) + return cls.parse_obj(obj) + + @classmethod + def parse_file( + cls: Type['Model'], + path: Union[str, Path], + *, + content_type: str = None, + encoding: str = 'utf8', + proto: Protocol = None, + allow_pickle: bool = False, + ) -> 'Model': + obj = load_file( + path, + proto=proto, + content_type=content_type, + encoding=encoding, + allow_pickle=allow_pickle, + json_loads=cls.__config__.json_loads, + ) + return cls.parse_obj(obj) + + @classmethod + def from_orm(cls: Type['Model'], obj: Any) -> 'Model': + if not cls.__config__.orm_mode: + raise ConfigError('You must have the config attribute orm_mode=True to use from_orm') + obj = {ROOT_KEY: obj} if cls.__custom_root_type__ else cls._decompose_class(obj) + m = cls.__new__(cls) + values, fields_set, validation_error = validate_model(cls, obj) + if validation_error: + raise validation_error + object_setattr(m, '__dict__', values) + object_setattr(m, '__fields_set__', fields_set) + m._init_private_attributes() + return m + + @classmethod + def construct(cls: Type['Model'], _fields_set: Optional['SetStr'] = None, **values: Any) -> 'Model': + """ + Creates a new model setting __dict__ and __fields_set__ from trusted or pre-validated data. + Default values are respected, but no other validation is performed. 
+ Behaves as if `Config.extra = 'allow'` was set since it adds all passed values + """ + m = cls.__new__(cls) + fields_values: Dict[str, Any] = {} + for name, field in cls.__fields__.items(): + if field.alt_alias and field.alias in values: + fields_values[name] = values[field.alias] + elif name in values: + fields_values[name] = values[name] + elif not field.required: + fields_values[name] = field.get_default() + fields_values.update(values) + object_setattr(m, '__dict__', fields_values) + if _fields_set is None: + _fields_set = set(values.keys()) + object_setattr(m, '__fields_set__', _fields_set) + m._init_private_attributes() + return m + + def _copy_and_set_values(self: 'Model', values: 'DictStrAny', fields_set: 'SetStr', *, deep: bool) -> 'Model': + if deep: + # chances of having empty dict here are quite low for using smart_deepcopy + values = deepcopy(values) + + cls = self.__class__ + m = cls.__new__(cls) + object_setattr(m, '__dict__', values) + object_setattr(m, '__fields_set__', fields_set) + for name in self.__private_attributes__: + value = getattr(self, name, Undefined) + if value is not Undefined: + if deep: + value = deepcopy(value) + object_setattr(m, name, value) + + return m + + def copy( + self: 'Model', + *, + include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + update: Optional['DictStrAny'] = None, + deep: bool = False, + ) -> 'Model': + """ + Duplicate a model, optionally choose which fields to include, exclude and change. + + :param include: fields to include in new model + :param exclude: fields to exclude from new model, as with values this takes precedence over include + :param update: values to change/add in the new model. Note: the data is not validated before creating + the new model: you should trust this data + :param deep: set to `True` to make a deep copy of the model + :return: new model instance + """ + + values = dict( + self._iter(to_dict=False, by_alias=False, include=include, exclude=exclude, exclude_unset=False), + **(update or {}), + ) + + # new `__fields_set__` can have unset optional fields with a set value in `update` kwarg + if update: + fields_set = self.__fields_set__ | update.keys() + else: + fields_set = set(self.__fields_set__) + + return self._copy_and_set_values(values, fields_set, deep=deep) + + @classmethod + def schema(cls, by_alias: bool = True, ref_template: str = default_ref_template) -> 'DictStrAny': + cached = cls.__schema_cache__.get((by_alias, ref_template)) + if cached is not None: + return cached + s = model_schema(cls, by_alias=by_alias, ref_template=ref_template) + cls.__schema_cache__[(by_alias, ref_template)] = s + return s + + @classmethod + def schema_json( + cls, *, by_alias: bool = True, ref_template: str = default_ref_template, **dumps_kwargs: Any + ) -> str: + from .json import pydantic_encoder + + return cls.__config__.json_dumps( + cls.schema(by_alias=by_alias, ref_template=ref_template), default=pydantic_encoder, **dumps_kwargs + ) + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.validate + + @classmethod + def validate(cls: Type['Model'], value: Any) -> 'Model': + if isinstance(value, cls): + copy_on_model_validation = cls.__config__.copy_on_model_validation + # whether to deep or shallow copy the model on validation, None means do not copy + deep_copy: Optional[bool] = None + if copy_on_model_validation not in {'deep', 'shallow', 'none'}: + # Warn about deprecated behavior + 
warnings.warn( + "`copy_on_model_validation` should be a string: 'deep', 'shallow' or 'none'", DeprecationWarning + ) + if copy_on_model_validation: + deep_copy = False + + if copy_on_model_validation == 'shallow': + # shallow copy + deep_copy = False + elif copy_on_model_validation == 'deep': + # deep copy + deep_copy = True + + if deep_copy is None: + return value + else: + return value._copy_and_set_values(value.__dict__, value.__fields_set__, deep=deep_copy) + + value = cls._enforce_dict_if_root(value) + + if isinstance(value, dict): + return cls(**value) + elif cls.__config__.orm_mode: + return cls.from_orm(value) + else: + try: + value_as_dict = dict(value) + except (TypeError, ValueError) as e: + raise DictError() from e + return cls(**value_as_dict) + + @classmethod + def _decompose_class(cls: Type['Model'], obj: Any) -> GetterDict: + if isinstance(obj, GetterDict): + return obj + return cls.__config__.getter_dict(obj) + + @classmethod + @no_type_check + def _get_value( + cls, + v: Any, + to_dict: bool, + by_alias: bool, + include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']], + exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']], + exclude_unset: bool, + exclude_defaults: bool, + exclude_none: bool, + ) -> Any: + if isinstance(v, BaseModel): + if to_dict: + v_dict = v.dict( + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + include=include, + exclude=exclude, + exclude_none=exclude_none, + ) + if ROOT_KEY in v_dict: + return v_dict[ROOT_KEY] + return v_dict + else: + return v.copy(include=include, exclude=exclude) + + value_exclude = ValueItems(v, exclude) if exclude else None + value_include = ValueItems(v, include) if include else None + + if isinstance(v, dict): + return { + k_: cls._get_value( + v_, + to_dict=to_dict, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + include=value_include and value_include.for_element(k_), + exclude=value_exclude and value_exclude.for_element(k_), + exclude_none=exclude_none, + ) + for k_, v_ in v.items() + if (not value_exclude or not value_exclude.is_excluded(k_)) + and (not value_include or value_include.is_included(k_)) + } + + elif sequence_like(v): + seq_args = ( + cls._get_value( + v_, + to_dict=to_dict, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + include=value_include and value_include.for_element(i), + exclude=value_exclude and value_exclude.for_element(i), + exclude_none=exclude_none, + ) + for i, v_ in enumerate(v) + if (not value_exclude or not value_exclude.is_excluded(i)) + and (not value_include or value_include.is_included(i)) + ) + + return v.__class__(*seq_args) if is_namedtuple(v.__class__) else v.__class__(seq_args) + + elif isinstance(v, Enum) and getattr(cls.Config, 'use_enum_values', False): + return v.value + + else: + return v + + @classmethod + def __try_update_forward_refs__(cls, **localns: Any) -> None: + """ + Same as update_forward_refs but will not raise exception + when forward references are not defined. + """ + update_model_forward_refs(cls, cls.__fields__.values(), cls.__config__.json_encoders, localns, (NameError,)) + + @classmethod + def update_forward_refs(cls, **localns: Any) -> None: + """ + Try to update ForwardRefs on fields based on this Model, globalns and localns. 
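+
+        Illustrative usage (a minimal sketch; `Foo` and `Bar` are made-up models)::
+
+            class Foo(BaseModel):
+                b: 'Bar' = None
+
+            class Bar(BaseModel):
+                x: int = 1
+
+            Foo.update_forward_refs()  # 'Bar' can be resolved now that it is defined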
+ """ + update_model_forward_refs(cls, cls.__fields__.values(), cls.__config__.json_encoders, localns) + + def __iter__(self) -> 'TupleGenerator': + """ + so `dict(model)` works + """ + yield from self.__dict__.items() + + def _iter( + self, + to_dict: bool = False, + by_alias: bool = False, + include: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + exclude: Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']] = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + ) -> 'TupleGenerator': + # Merge field set excludes with explicit exclude parameter with explicit overriding field set options. + # The extra "is not None" guards are not logically necessary but optimizes performance for the simple case. + if exclude is not None or self.__exclude_fields__ is not None: + exclude = ValueItems.merge(self.__exclude_fields__, exclude) + + if include is not None or self.__include_fields__ is not None: + include = ValueItems.merge(self.__include_fields__, include, intersect=True) + + allowed_keys = self._calculate_keys( + include=include, exclude=exclude, exclude_unset=exclude_unset # type: ignore + ) + if allowed_keys is None and not (to_dict or by_alias or exclude_unset or exclude_defaults or exclude_none): + # huge boost for plain _iter() + yield from self.__dict__.items() + return + + value_exclude = ValueItems(self, exclude) if exclude is not None else None + value_include = ValueItems(self, include) if include is not None else None + + for field_key, v in self.__dict__.items(): + if (allowed_keys is not None and field_key not in allowed_keys) or (exclude_none and v is None): + continue + + if exclude_defaults: + model_field = self.__fields__.get(field_key) + if not getattr(model_field, 'required', True) and getattr(model_field, 'default', _missing) == v: + continue + + if by_alias and field_key in self.__fields__: + dict_key = self.__fields__[field_key].alias + else: + dict_key = field_key + + if to_dict or value_include or value_exclude: + v = self._get_value( + v, + to_dict=to_dict, + by_alias=by_alias, + include=value_include and value_include.for_element(field_key), + exclude=value_exclude and value_exclude.for_element(field_key), + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + ) + yield dict_key, v + + def _calculate_keys( + self, + include: Optional['MappingIntStrAny'], + exclude: Optional['MappingIntStrAny'], + exclude_unset: bool, + update: Optional['DictStrAny'] = None, + ) -> Optional[AbstractSet[str]]: + if include is None and exclude is None and exclude_unset is False: + return None + + keys: AbstractSet[str] + if exclude_unset: + keys = self.__fields_set__.copy() + else: + keys = self.__dict__.keys() + + if include is not None: + keys &= include.keys() + + if update: + keys -= update.keys() + + if exclude: + keys -= {k for k, v in exclude.items() if ValueItems.is_true(v)} + + return keys + + def __eq__(self, other: Any) -> bool: + if isinstance(other, BaseModel): + return self.dict() == other.dict() + else: + return self.dict() == other + + def __repr_args__(self) -> 'ReprArgs': + return [ + (k, v) + for k, v in self.__dict__.items() + if k not in DUNDER_ATTRIBUTES and (k not in self.__fields__ or self.__fields__[k].field_info.repr) + ] + + +_is_base_model_class_defined = True + + +@overload +def create_model( + __model_name: str, + *, + __config__: Optional[Type[BaseConfig]] = None, + __base__: None = None, + __module__: str = __name__, + __validators__: 
Dict[str, 'AnyClassMethod'] = None, + __cls_kwargs__: Dict[str, Any] = None, + **field_definitions: Any, +) -> Type['BaseModel']: + ... + + +@overload +def create_model( + __model_name: str, + *, + __config__: Optional[Type[BaseConfig]] = None, + __base__: Union[Type['Model'], Tuple[Type['Model'], ...]], + __module__: str = __name__, + __validators__: Dict[str, 'AnyClassMethod'] = None, + __cls_kwargs__: Dict[str, Any] = None, + **field_definitions: Any, +) -> Type['Model']: + ... + + +def create_model( + __model_name: str, + *, + __config__: Optional[Type[BaseConfig]] = None, + __base__: Union[None, Type['Model'], Tuple[Type['Model'], ...]] = None, + __module__: str = __name__, + __validators__: Dict[str, 'AnyClassMethod'] = None, + __cls_kwargs__: Dict[str, Any] = None, + __slots__: Optional[Tuple[str, ...]] = None, + **field_definitions: Any, +) -> Type['Model']: + """ + Dynamically create a model. + :param __model_name: name of the created model + :param __config__: config class to use for the new model + :param __base__: base class for the new model to inherit from + :param __module__: module of the created model + :param __validators__: a dict of method names and @validator class methods + :param __cls_kwargs__: a dict for class creation + :param __slots__: Deprecated, `__slots__` should not be passed to `create_model` + :param field_definitions: fields of the model (or extra fields if a base is supplied) + in the format `=(, )` or `=, e.g. + `foobar=(str, ...)` or `foobar=123`, or, for complex use-cases, in the format + `=` or `=(, )`, e.g. + `foo=Field(datetime, default_factory=datetime.utcnow, alias='bar')` or + `foo=(str, FieldInfo(title='Foo'))` + """ + if __slots__ is not None: + # __slots__ will be ignored from here on + warnings.warn('__slots__ should not be passed to create_model', RuntimeWarning) + + if __base__ is not None: + if __config__ is not None: + raise ConfigError('to avoid confusion __config__ and __base__ cannot be used together') + if not isinstance(__base__, tuple): + __base__ = (__base__,) + else: + __base__ = (cast(Type['Model'], BaseModel),) + + __cls_kwargs__ = __cls_kwargs__ or {} + + fields = {} + annotations = {} + + for f_name, f_def in field_definitions.items(): + if not is_valid_field(f_name): + warnings.warn(f'fields may not start with an underscore, ignoring "{f_name}"', RuntimeWarning) + if isinstance(f_def, tuple): + try: + f_annotation, f_value = f_def + except ValueError as e: + raise ConfigError( + 'field definitions should either be a tuple of (, ) or just a ' + 'default value, unfortunately this means tuples as ' + 'default values are not allowed' + ) from e + else: + f_annotation, f_value = None, f_def + + if f_annotation: + annotations[f_name] = f_annotation + fields[f_name] = f_value + + namespace: 'DictStrAny' = {'__annotations__': annotations, '__module__': __module__} + if __validators__: + namespace.update(__validators__) + namespace.update(fields) + if __config__: + namespace['Config'] = inherit_config(__config__, BaseConfig) + resolved_bases = resolve_bases(__base__) + meta, ns, kwds = prepare_class(__model_name, resolved_bases, kwds=__cls_kwargs__) + if resolved_bases is not __base__: + ns['__orig_bases__'] = __base__ + namespace.update(ns) + return meta(__model_name, resolved_bases, namespace, **kwds) + + +_missing = object() + + +def validate_model( # noqa: C901 (ignore complexity) + model: Type[BaseModel], input_data: 'DictStrAny', cls: 'ModelOrDc' = None +) -> Tuple['DictStrAny', 'SetStr', Optional[ValidationError]]: + """ + 
validate data against a model. + """ + values = {} + errors = [] + # input_data names, possibly alias + names_used = set() + # field names, never aliases + fields_set = set() + config = model.__config__ + check_extra = config.extra is not Extra.ignore + cls_ = cls or model + + for validator in model.__pre_root_validators__: + try: + input_data = validator(cls_, input_data) + except (ValueError, TypeError, AssertionError) as exc: + return {}, set(), ValidationError([ErrorWrapper(exc, loc=ROOT_KEY)], cls_) + + for name, field in model.__fields__.items(): + value = input_data.get(field.alias, _missing) + using_name = False + if value is _missing and config.allow_population_by_field_name and field.alt_alias: + value = input_data.get(field.name, _missing) + using_name = True + + if value is _missing: + if field.required: + errors.append(ErrorWrapper(MissingError(), loc=field.alias)) + continue + + value = field.get_default() + + if not config.validate_all and not field.validate_always: + values[name] = value + continue + else: + fields_set.add(name) + if check_extra: + names_used.add(field.name if using_name else field.alias) + + v_, errors_ = field.validate(value, values, loc=field.alias, cls=cls_) + if isinstance(errors_, ErrorWrapper): + errors.append(errors_) + elif isinstance(errors_, list): + errors.extend(errors_) + else: + values[name] = v_ + + if check_extra: + if isinstance(input_data, GetterDict): + extra = input_data.extra_keys() - names_used + else: + extra = input_data.keys() - names_used + if extra: + fields_set |= extra + if config.extra is Extra.allow: + for f in extra: + values[f] = input_data[f] + else: + for f in sorted(extra): + errors.append(ErrorWrapper(ExtraError(), loc=f)) + + for skip_on_failure, validator in model.__post_root_validators__: + if skip_on_failure and errors: + continue + try: + values = validator(cls_, values) + except (ValueError, TypeError, AssertionError) as exc: + errors.append(ErrorWrapper(exc, loc=ROOT_KEY)) + + if errors: + return values, fields_set, ValidationError(errors, cls_) + else: + return values, fields_set, None diff --git a/lib/pydantic/v1/mypy.py b/lib/pydantic/v1/mypy.py new file mode 100644 index 00000000..1d6d5ae2 --- /dev/null +++ b/lib/pydantic/v1/mypy.py @@ -0,0 +1,944 @@ +import sys +from configparser import ConfigParser +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type as TypingType, Union + +from mypy.errorcodes import ErrorCode +from mypy.nodes import ( + ARG_NAMED, + ARG_NAMED_OPT, + ARG_OPT, + ARG_POS, + ARG_STAR2, + MDEF, + Argument, + AssignmentStmt, + Block, + CallExpr, + ClassDef, + Context, + Decorator, + EllipsisExpr, + FuncBase, + FuncDef, + JsonDict, + MemberExpr, + NameExpr, + PassStmt, + PlaceholderNode, + RefExpr, + StrExpr, + SymbolNode, + SymbolTableNode, + TempNode, + TypeInfo, + TypeVarExpr, + Var, +) +from mypy.options import Options +from mypy.plugin import ( + CheckerPluginInterface, + ClassDefContext, + FunctionContext, + MethodContext, + Plugin, + ReportConfigContext, + SemanticAnalyzerPluginInterface, +) +from mypy.plugins import dataclasses +from mypy.semanal import set_callable_name # type: ignore +from mypy.server.trigger import make_wildcard_trigger +from mypy.types import ( + AnyType, + CallableType, + Instance, + NoneType, + Overloaded, + ProperType, + Type, + TypeOfAny, + TypeType, + TypeVarType, + UnionType, + get_proper_type, +) +from mypy.typevars import fill_typevars +from mypy.util import get_unique_redefinition_name +from mypy.version import __version__ as 
mypy_version + +from pydantic.utils import is_valid_field + +try: + from mypy.types import TypeVarDef # type: ignore[attr-defined] +except ImportError: # pragma: no cover + # Backward-compatible with TypeVarDef from Mypy 0.910. + from mypy.types import TypeVarType as TypeVarDef + +CONFIGFILE_KEY = 'pydantic-mypy' +METADATA_KEY = 'pydantic-mypy-metadata' +_NAMESPACE = __name__[:-5] # 'pydantic' in 1.10.X, 'pydantic.v1' in v2.X +BASEMODEL_FULLNAME = f'{_NAMESPACE}.main.BaseModel' +BASESETTINGS_FULLNAME = f'{_NAMESPACE}.env_settings.BaseSettings' +MODEL_METACLASS_FULLNAME = f'{_NAMESPACE}.main.ModelMetaclass' +FIELD_FULLNAME = f'{_NAMESPACE}.fields.Field' +DATACLASS_FULLNAME = f'{_NAMESPACE}.dataclasses.dataclass' + + +def parse_mypy_version(version: str) -> Tuple[int, ...]: + return tuple(map(int, version.partition('+')[0].split('.'))) + + +MYPY_VERSION_TUPLE = parse_mypy_version(mypy_version) +BUILTINS_NAME = 'builtins' if MYPY_VERSION_TUPLE >= (0, 930) else '__builtins__' + +# Increment version if plugin changes and mypy caches should be invalidated +__version__ = 2 + + +def plugin(version: str) -> 'TypingType[Plugin]': + """ + `version` is the mypy version string + + We might want to use this to print a warning if the mypy version being used is + newer, or especially older, than we expect (or need). + """ + return PydanticPlugin + + +class PydanticPlugin(Plugin): + def __init__(self, options: Options) -> None: + self.plugin_config = PydanticPluginConfig(options) + self._plugin_data = self.plugin_config.to_data() + super().__init__(options) + + def get_base_class_hook(self, fullname: str) -> 'Optional[Callable[[ClassDefContext], None]]': + sym = self.lookup_fully_qualified(fullname) + if sym and isinstance(sym.node, TypeInfo): # pragma: no branch + # No branching may occur if the mypy cache has not been cleared + if any(get_fullname(base) == BASEMODEL_FULLNAME for base in sym.node.mro): + return self._pydantic_model_class_maker_callback + return None + + def get_metaclass_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]: + if fullname == MODEL_METACLASS_FULLNAME: + return self._pydantic_model_metaclass_marker_callback + return None + + def get_function_hook(self, fullname: str) -> 'Optional[Callable[[FunctionContext], Type]]': + sym = self.lookup_fully_qualified(fullname) + if sym and sym.fullname == FIELD_FULLNAME: + return self._pydantic_field_callback + return None + + def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], Type]]: + if fullname.endswith('.from_orm'): + return from_orm_callback + return None + + def get_class_decorator_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]: + """Mark pydantic.dataclasses as dataclass. + + Mypy version 1.1.1 added support for `@dataclass_transform` decorator. + """ + if fullname == DATACLASS_FULLNAME and MYPY_VERSION_TUPLE < (1, 1): + return dataclasses.dataclass_class_maker_callback # type: ignore[return-value] + return None + + def report_config_data(self, ctx: ReportConfigContext) -> Dict[str, Any]: + """Return all plugin config data. + + Used by mypy to determine if cache needs to be discarded. + """ + return self._plugin_data + + def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> None: + transformer = PydanticModelTransformer(ctx, self.plugin_config) + transformer.transform() + + def _pydantic_model_metaclass_marker_callback(self, ctx: ClassDefContext) -> None: + """Reset dataclass_transform_spec attribute of ModelMetaclass. 
+ + Let the plugin handle it. This behavior can be disabled + if 'debug_dataclass_transform' is set to True', for testing purposes. + """ + if self.plugin_config.debug_dataclass_transform: + return + info_metaclass = ctx.cls.info.declared_metaclass + assert info_metaclass, "callback not passed from 'get_metaclass_hook'" + if getattr(info_metaclass.type, 'dataclass_transform_spec', None): + info_metaclass.type.dataclass_transform_spec = None # type: ignore[attr-defined] + + def _pydantic_field_callback(self, ctx: FunctionContext) -> 'Type': + """ + Extract the type of the `default` argument from the Field function, and use it as the return type. + + In particular: + * Check whether the default and default_factory argument is specified. + * Output an error if both are specified. + * Retrieve the type of the argument which is specified, and use it as return type for the function. + """ + default_any_type = ctx.default_return_type + + assert ctx.callee_arg_names[0] == 'default', '"default" is no longer first argument in Field()' + assert ctx.callee_arg_names[1] == 'default_factory', '"default_factory" is no longer second argument in Field()' + default_args = ctx.args[0] + default_factory_args = ctx.args[1] + + if default_args and default_factory_args: + error_default_and_default_factory_specified(ctx.api, ctx.context) + return default_any_type + + if default_args: + default_type = ctx.arg_types[0][0] + default_arg = default_args[0] + + # Fallback to default Any type if the field is required + if not isinstance(default_arg, EllipsisExpr): + return default_type + + elif default_factory_args: + default_factory_type = ctx.arg_types[1][0] + + # Functions which use `ParamSpec` can be overloaded, exposing the callable's types as a parameter + # Pydantic calls the default factory without any argument, so we retrieve the first item + if isinstance(default_factory_type, Overloaded): + if MYPY_VERSION_TUPLE > (0, 910): + default_factory_type = default_factory_type.items[0] + else: + # Mypy0.910 exposes the items of overloaded types in a function + default_factory_type = default_factory_type.items()[0] # type: ignore[operator] + + if isinstance(default_factory_type, CallableType): + ret_type = default_factory_type.ret_type + # mypy doesn't think `ret_type` has `args`, you'd think mypy should know, + # add this check in case it varies by version + args = getattr(ret_type, 'args', None) + if args: + if all(isinstance(arg, TypeVarType) for arg in args): + # Looks like the default factory is a type like `list` or `dict`, replace all args with `Any` + ret_type.args = tuple(default_any_type for _ in args) # type: ignore[attr-defined] + return ret_type + + return default_any_type + + +class PydanticPluginConfig: + __slots__ = ( + 'init_forbid_extra', + 'init_typed', + 'warn_required_dynamic_aliases', + 'warn_untyped_fields', + 'debug_dataclass_transform', + ) + init_forbid_extra: bool + init_typed: bool + warn_required_dynamic_aliases: bool + warn_untyped_fields: bool + debug_dataclass_transform: bool # undocumented + + def __init__(self, options: Options) -> None: + if options.config_file is None: # pragma: no cover + return + + toml_config = parse_toml(options.config_file) + if toml_config is not None: + config = toml_config.get('tool', {}).get('pydantic-mypy', {}) + for key in self.__slots__: + setting = config.get(key, False) + if not isinstance(setting, bool): + raise ValueError(f'Configuration value must be a boolean for key: {key}') + setattr(self, key, setting) + else: + plugin_config = ConfigParser() + 
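+            # The ini-style file read here is expected to contain a section like the
+            # following (an illustrative sketch; the values are examples only):
+            #
+            #     [pydantic-mypy]
+            #     init_forbid_extra = True
+            #     init_typed = True
+            #     warn_required_dynamic_aliases = True
+            #     warn_untyped_fields = True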
plugin_config.read(options.config_file) + for key in self.__slots__: + setting = plugin_config.getboolean(CONFIGFILE_KEY, key, fallback=False) + setattr(self, key, setting) + + def to_data(self) -> Dict[str, Any]: + return {key: getattr(self, key) for key in self.__slots__} + + +def from_orm_callback(ctx: MethodContext) -> Type: + """ + Raise an error if orm_mode is not enabled + """ + model_type: Instance + ctx_type = ctx.type + if isinstance(ctx_type, TypeType): + ctx_type = ctx_type.item + if isinstance(ctx_type, CallableType) and isinstance(ctx_type.ret_type, Instance): + model_type = ctx_type.ret_type # called on the class + elif isinstance(ctx_type, Instance): + model_type = ctx_type # called on an instance (unusual, but still valid) + else: # pragma: no cover + detail = f'ctx.type: {ctx_type} (of type {ctx_type.__class__.__name__})' + error_unexpected_behavior(detail, ctx.api, ctx.context) + return ctx.default_return_type + pydantic_metadata = model_type.type.metadata.get(METADATA_KEY) + if pydantic_metadata is None: + return ctx.default_return_type + orm_mode = pydantic_metadata.get('config', {}).get('orm_mode') + if orm_mode is not True: + error_from_orm(get_name(model_type.type), ctx.api, ctx.context) + return ctx.default_return_type + + +class PydanticModelTransformer: + tracked_config_fields: Set[str] = { + 'extra', + 'allow_mutation', + 'frozen', + 'orm_mode', + 'allow_population_by_field_name', + 'alias_generator', + } + + def __init__(self, ctx: ClassDefContext, plugin_config: PydanticPluginConfig) -> None: + self._ctx = ctx + self.plugin_config = plugin_config + + def transform(self) -> None: + """ + Configures the BaseModel subclass according to the plugin settings. + + In particular: + * determines the model config and fields, + * adds a fields-aware signature for the initializer and construct methods + * freezes the class if allow_mutation = False or frozen = True + * stores the fields, config, and if the class is settings in the mypy metadata for access by subclasses + """ + ctx = self._ctx + info = ctx.cls.info + + self.adjust_validator_signatures() + config = self.collect_config() + fields = self.collect_fields(config) + is_settings = any(get_fullname(base) == BASESETTINGS_FULLNAME for base in info.mro[:-1]) + self.add_initializer(fields, config, is_settings) + self.add_construct_method(fields) + self.set_frozen(fields, frozen=config.allow_mutation is False or config.frozen is True) + info.metadata[METADATA_KEY] = { + 'fields': {field.name: field.serialize() for field in fields}, + 'config': config.set_values_dict(), + } + + def adjust_validator_signatures(self) -> None: + """When we decorate a function `f` with `pydantic.validator(...), mypy sees + `f` as a regular method taking a `self` instance, even though pydantic + internally wraps `f` with `classmethod` if necessary. + + Teach mypy this by marking any function whose outermost decorator is a + `validator()` call as a classmethod. + """ + for name, sym in self._ctx.cls.info.names.items(): + if isinstance(sym.node, Decorator): + first_dec = sym.node.original_decorators[0] + if ( + isinstance(first_dec, CallExpr) + and isinstance(first_dec.callee, NameExpr) + and first_dec.callee.fullname == f'{_NAMESPACE}.class_validators.validator' + ): + sym.node.func.is_class = True + + def collect_config(self) -> 'ModelConfigData': + """ + Collects the values of the config attributes that are used by the plugin, accounting for parent classes. 
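+
+        For example (an illustrative sketch), a model whose Config contains::
+
+            class Config:
+                extra = 'forbid'
+                allow_mutation = False
+
+        is collected as `ModelConfigData(forbid_extra=True, allow_mutation=False)`.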
+ """ + ctx = self._ctx + cls = ctx.cls + config = ModelConfigData() + for stmt in cls.defs.body: + if not isinstance(stmt, ClassDef): + continue + if stmt.name == 'Config': + for substmt in stmt.defs.body: + if not isinstance(substmt, AssignmentStmt): + continue + config.update(self.get_config_update(substmt)) + if ( + config.has_alias_generator + and not config.allow_population_by_field_name + and self.plugin_config.warn_required_dynamic_aliases + ): + error_required_dynamic_aliases(ctx.api, stmt) + for info in cls.info.mro[1:]: # 0 is the current class + if METADATA_KEY not in info.metadata: + continue + + # Each class depends on the set of fields in its ancestors + ctx.api.add_plugin_dependency(make_wildcard_trigger(get_fullname(info))) + for name, value in info.metadata[METADATA_KEY]['config'].items(): + config.setdefault(name, value) + return config + + def collect_fields(self, model_config: 'ModelConfigData') -> List['PydanticModelField']: + """ + Collects the fields for the model, accounting for parent classes + """ + # First, collect fields belonging to the current class. + ctx = self._ctx + cls = self._ctx.cls + fields = [] # type: List[PydanticModelField] + known_fields = set() # type: Set[str] + for stmt in cls.defs.body: + if not isinstance(stmt, AssignmentStmt): # `and stmt.new_syntax` to require annotation + continue + + lhs = stmt.lvalues[0] + if not isinstance(lhs, NameExpr) or not is_valid_field(lhs.name): + continue + + if not stmt.new_syntax and self.plugin_config.warn_untyped_fields: + error_untyped_fields(ctx.api, stmt) + + # if lhs.name == '__config__': # BaseConfig not well handled; I'm not sure why yet + # continue + + sym = cls.info.names.get(lhs.name) + if sym is None: # pragma: no cover + # This is likely due to a star import (see the dataclasses plugin for a more detailed explanation) + # This is the same logic used in the dataclasses plugin + continue + + node = sym.node + if isinstance(node, PlaceholderNode): # pragma: no cover + # See the PlaceholderNode docstring for more detail about how this can occur + # Basically, it is an edge case when dealing with complex import logic + # This is the same logic used in the dataclasses plugin + continue + if not isinstance(node, Var): # pragma: no cover + # Don't know if this edge case still happens with the `is_valid_field` check above + # but better safe than sorry + continue + + # x: ClassVar[int] is ignored by dataclasses. 
+ if node.is_classvar: + continue + + is_required = self.get_is_required(cls, stmt, lhs) + alias, has_dynamic_alias = self.get_alias_info(stmt) + if ( + has_dynamic_alias + and not model_config.allow_population_by_field_name + and self.plugin_config.warn_required_dynamic_aliases + ): + error_required_dynamic_aliases(ctx.api, stmt) + fields.append( + PydanticModelField( + name=lhs.name, + is_required=is_required, + alias=alias, + has_dynamic_alias=has_dynamic_alias, + line=stmt.line, + column=stmt.column, + ) + ) + known_fields.add(lhs.name) + all_fields = fields.copy() + for info in cls.info.mro[1:]: # 0 is the current class, -2 is BaseModel, -1 is object + if METADATA_KEY not in info.metadata: + continue + + superclass_fields = [] + # Each class depends on the set of fields in its ancestors + ctx.api.add_plugin_dependency(make_wildcard_trigger(get_fullname(info))) + + for name, data in info.metadata[METADATA_KEY]['fields'].items(): + if name not in known_fields: + field = PydanticModelField.deserialize(info, data) + known_fields.add(name) + superclass_fields.append(field) + else: + (field,) = (a for a in all_fields if a.name == name) + all_fields.remove(field) + superclass_fields.append(field) + all_fields = superclass_fields + all_fields + return all_fields + + def add_initializer(self, fields: List['PydanticModelField'], config: 'ModelConfigData', is_settings: bool) -> None: + """ + Adds a fields-aware `__init__` method to the class. + + The added `__init__` will be annotated with types vs. all `Any` depending on the plugin settings. + """ + ctx = self._ctx + typed = self.plugin_config.init_typed + use_alias = config.allow_population_by_field_name is not True + force_all_optional = is_settings or bool( + config.has_alias_generator and not config.allow_population_by_field_name + ) + init_arguments = self.get_field_arguments( + fields, typed=typed, force_all_optional=force_all_optional, use_alias=use_alias + ) + if not self.should_init_forbid_extra(fields, config): + var = Var('kwargs') + init_arguments.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2)) + + if '__init__' not in ctx.cls.info.names: + add_method(ctx, '__init__', init_arguments, NoneType()) + + def add_construct_method(self, fields: List['PydanticModelField']) -> None: + """ + Adds a fully typed `construct` classmethod to the class. + + Similar to the fields-aware __init__ method, but always uses the field names (not aliases), + and does not treat settings fields as optional. + """ + ctx = self._ctx + set_str = ctx.api.named_type(f'{BUILTINS_NAME}.set', [ctx.api.named_type(f'{BUILTINS_NAME}.str')]) + optional_set_str = UnionType([set_str, NoneType()]) + fields_set_argument = Argument(Var('_fields_set', optional_set_str), optional_set_str, None, ARG_OPT) + construct_arguments = self.get_field_arguments(fields, typed=True, force_all_optional=False, use_alias=False) + construct_arguments = [fields_set_argument] + construct_arguments + + obj_type = ctx.api.named_type(f'{BUILTINS_NAME}.object') + self_tvar_name = '_PydanticBaseModel' # Make sure it does not conflict with other names in the class + tvar_fullname = ctx.cls.fullname + '.' 
+ self_tvar_name + if MYPY_VERSION_TUPLE >= (1, 4): + tvd = TypeVarType( + self_tvar_name, + tvar_fullname, + -1, + [], + obj_type, + AnyType(TypeOfAny.from_omitted_generics), # type: ignore[arg-type] + ) + self_tvar_expr = TypeVarExpr( + self_tvar_name, + tvar_fullname, + [], + obj_type, + AnyType(TypeOfAny.from_omitted_generics), # type: ignore[arg-type] + ) + else: + tvd = TypeVarDef(self_tvar_name, tvar_fullname, -1, [], obj_type) + self_tvar_expr = TypeVarExpr(self_tvar_name, tvar_fullname, [], obj_type) + ctx.cls.info.names[self_tvar_name] = SymbolTableNode(MDEF, self_tvar_expr) + + # Backward-compatible with TypeVarDef from Mypy 0.910. + if isinstance(tvd, TypeVarType): + self_type = tvd + else: + self_type = TypeVarType(tvd) + + add_method( + ctx, + 'construct', + construct_arguments, + return_type=self_type, + self_type=self_type, + tvar_def=tvd, + is_classmethod=True, + ) + + def set_frozen(self, fields: List['PydanticModelField'], frozen: bool) -> None: + """ + Marks all fields as properties so that attempts to set them trigger mypy errors. + + This is the same approach used by the attrs and dataclasses plugins. + """ + ctx = self._ctx + info = ctx.cls.info + for field in fields: + sym_node = info.names.get(field.name) + if sym_node is not None: + var = sym_node.node + if isinstance(var, Var): + var.is_property = frozen + elif isinstance(var, PlaceholderNode) and not ctx.api.final_iteration: + # See https://github.com/pydantic/pydantic/issues/5191 to hit this branch for test coverage + ctx.api.defer() + else: # pragma: no cover + # I don't know whether it's possible to hit this branch, but I've added it for safety + try: + var_str = str(var) + except TypeError: + # This happens for PlaceholderNode; perhaps it will happen for other types in the future.. + var_str = repr(var) + detail = f'sym_node.node: {var_str} (of type {var.__class__})' + error_unexpected_behavior(detail, ctx.api, ctx.cls) + else: + var = field.to_var(info, use_alias=False) + var.info = info + var.is_property = frozen + var._fullname = get_fullname(info) + '.' + get_name(var) + info.names[get_name(var)] = SymbolTableNode(MDEF, var) + + def get_config_update(self, substmt: AssignmentStmt) -> Optional['ModelConfigData']: + """ + Determines the config update due to a single statement in the Config class definition. 
+ + Warns if a tracked config attribute is set to a value the plugin doesn't know how to interpret (e.g., an int) + """ + lhs = substmt.lvalues[0] + if not (isinstance(lhs, NameExpr) and lhs.name in self.tracked_config_fields): + return None + if lhs.name == 'extra': + if isinstance(substmt.rvalue, StrExpr): + forbid_extra = substmt.rvalue.value == 'forbid' + elif isinstance(substmt.rvalue, MemberExpr): + forbid_extra = substmt.rvalue.name == 'forbid' + else: + error_invalid_config_value(lhs.name, self._ctx.api, substmt) + return None + return ModelConfigData(forbid_extra=forbid_extra) + if lhs.name == 'alias_generator': + has_alias_generator = True + if isinstance(substmt.rvalue, NameExpr) and substmt.rvalue.fullname == 'builtins.None': + has_alias_generator = False + return ModelConfigData(has_alias_generator=has_alias_generator) + if isinstance(substmt.rvalue, NameExpr) and substmt.rvalue.fullname in ('builtins.True', 'builtins.False'): + return ModelConfigData(**{lhs.name: substmt.rvalue.fullname == 'builtins.True'}) + error_invalid_config_value(lhs.name, self._ctx.api, substmt) + return None + + @staticmethod + def get_is_required(cls: ClassDef, stmt: AssignmentStmt, lhs: NameExpr) -> bool: + """ + Returns a boolean indicating whether the field defined in `stmt` is a required field. + """ + expr = stmt.rvalue + if isinstance(expr, TempNode): + # TempNode means annotation-only, so only non-required if Optional + value_type = get_proper_type(cls.info[lhs.name].type) + return not PydanticModelTransformer.type_has_implicit_default(value_type) + if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME: + # The "default value" is a call to `Field`; at this point, the field is + # only required if default is Ellipsis (i.e., `field_name: Annotation = Field(...)`) or if default_factory + # is specified. + for arg, name in zip(expr.args, expr.arg_names): + # If name is None, then this arg is the default because it is the only positional argument. + if name is None or name == 'default': + return arg.__class__ is EllipsisExpr + if name == 'default_factory': + return False + # In this case, default and default_factory are not specified, so we need to look at the annotation + value_type = get_proper_type(cls.info[lhs.name].type) + return not PydanticModelTransformer.type_has_implicit_default(value_type) + # Only required if the "default value" is Ellipsis (i.e., `field_name: Annotation = ...`) + return isinstance(expr, EllipsisExpr) + + @staticmethod + def type_has_implicit_default(type_: Optional[ProperType]) -> bool: + """ + Returns True if the passed type will be given an implicit default value. + + In pydantic v1, this is the case for Optional types and Any (with default value None). + """ + if isinstance(type_, AnyType): + # Annotated as Any + return True + if isinstance(type_, UnionType) and any( + isinstance(item, NoneType) or isinstance(item, AnyType) for item in type_.items + ): + # Annotated as Optional, or otherwise having NoneType or AnyType in the union + return True + return False + + @staticmethod + def get_alias_info(stmt: AssignmentStmt) -> Tuple[Optional[str], bool]: + """ + Returns a pair (alias, has_dynamic_alias), extracted from the declaration of the field defined in `stmt`. + + `has_dynamic_alias` is True if and only if an alias is provided, but not as a string literal. + If `has_dynamic_alias` is True, `alias` will be None. 
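+
+        For example (an illustrative sketch; `make_alias` is a made-up helper, not a real API)::
+
+            x: int = Field(..., alias='y')           # -> ('y', False)
+            x: int = Field(..., alias=make_alias())  # -> (None, True), dynamic alias
+            x: int = 5                               # -> (None, False), not a Field() call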
+ """ + expr = stmt.rvalue + if isinstance(expr, TempNode): + # TempNode means annotation-only + return None, False + + if not ( + isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME + ): + # Assigned value is not a call to pydantic.fields.Field + return None, False + + for i, arg_name in enumerate(expr.arg_names): + if arg_name != 'alias': + continue + arg = expr.args[i] + if isinstance(arg, StrExpr): + return arg.value, False + else: + return None, True + return None, False + + def get_field_arguments( + self, fields: List['PydanticModelField'], typed: bool, force_all_optional: bool, use_alias: bool + ) -> List[Argument]: + """ + Helper function used during the construction of the `__init__` and `construct` method signatures. + + Returns a list of mypy Argument instances for use in the generated signatures. + """ + info = self._ctx.cls.info + arguments = [ + field.to_argument(info, typed=typed, force_optional=force_all_optional, use_alias=use_alias) + for field in fields + if not (use_alias and field.has_dynamic_alias) + ] + return arguments + + def should_init_forbid_extra(self, fields: List['PydanticModelField'], config: 'ModelConfigData') -> bool: + """ + Indicates whether the generated `__init__` should get a `**kwargs` at the end of its signature + + We disallow arbitrary kwargs if the extra config setting is "forbid", or if the plugin config says to, + *unless* a required dynamic alias is present (since then we can't determine a valid signature). + """ + if not config.allow_population_by_field_name: + if self.is_dynamic_alias_present(fields, bool(config.has_alias_generator)): + return False + if config.forbid_extra: + return True + return self.plugin_config.init_forbid_extra + + @staticmethod + def is_dynamic_alias_present(fields: List['PydanticModelField'], has_alias_generator: bool) -> bool: + """ + Returns whether any fields on the model have a "dynamic alias", i.e., an alias that cannot be + determined during static analysis. 
+ """ + for field in fields: + if field.has_dynamic_alias: + return True + if has_alias_generator: + for field in fields: + if field.alias is None: + return True + return False + + +class PydanticModelField: + def __init__( + self, name: str, is_required: bool, alias: Optional[str], has_dynamic_alias: bool, line: int, column: int + ): + self.name = name + self.is_required = is_required + self.alias = alias + self.has_dynamic_alias = has_dynamic_alias + self.line = line + self.column = column + + def to_var(self, info: TypeInfo, use_alias: bool) -> Var: + name = self.name + if use_alias and self.alias is not None: + name = self.alias + return Var(name, info[self.name].type) + + def to_argument(self, info: TypeInfo, typed: bool, force_optional: bool, use_alias: bool) -> Argument: + if typed and info[self.name].type is not None: + type_annotation = info[self.name].type + else: + type_annotation = AnyType(TypeOfAny.explicit) + return Argument( + variable=self.to_var(info, use_alias), + type_annotation=type_annotation, + initializer=None, + kind=ARG_NAMED_OPT if force_optional or not self.is_required else ARG_NAMED, + ) + + def serialize(self) -> JsonDict: + return self.__dict__ + + @classmethod + def deserialize(cls, info: TypeInfo, data: JsonDict) -> 'PydanticModelField': + return cls(**data) + + +class ModelConfigData: + def __init__( + self, + forbid_extra: Optional[bool] = None, + allow_mutation: Optional[bool] = None, + frozen: Optional[bool] = None, + orm_mode: Optional[bool] = None, + allow_population_by_field_name: Optional[bool] = None, + has_alias_generator: Optional[bool] = None, + ): + self.forbid_extra = forbid_extra + self.allow_mutation = allow_mutation + self.frozen = frozen + self.orm_mode = orm_mode + self.allow_population_by_field_name = allow_population_by_field_name + self.has_alias_generator = has_alias_generator + + def set_values_dict(self) -> Dict[str, Any]: + return {k: v for k, v in self.__dict__.items() if v is not None} + + def update(self, config: Optional['ModelConfigData']) -> None: + if config is None: + return + for k, v in config.set_values_dict().items(): + setattr(self, k, v) + + def setdefault(self, key: str, value: Any) -> None: + if getattr(self, key) is None: + setattr(self, key, value) + + +ERROR_ORM = ErrorCode('pydantic-orm', 'Invalid from_orm call', 'Pydantic') +ERROR_CONFIG = ErrorCode('pydantic-config', 'Invalid config value', 'Pydantic') +ERROR_ALIAS = ErrorCode('pydantic-alias', 'Dynamic alias disallowed', 'Pydantic') +ERROR_UNEXPECTED = ErrorCode('pydantic-unexpected', 'Unexpected behavior', 'Pydantic') +ERROR_UNTYPED = ErrorCode('pydantic-field', 'Untyped field disallowed', 'Pydantic') +ERROR_FIELD_DEFAULTS = ErrorCode('pydantic-field', 'Invalid Field defaults', 'Pydantic') + + +def error_from_orm(model_name: str, api: CheckerPluginInterface, context: Context) -> None: + api.fail(f'"{model_name}" does not have orm_mode=True', context, code=ERROR_ORM) + + +def error_invalid_config_value(name: str, api: SemanticAnalyzerPluginInterface, context: Context) -> None: + api.fail(f'Invalid value for "Config.{name}"', context, code=ERROR_CONFIG) + + +def error_required_dynamic_aliases(api: SemanticAnalyzerPluginInterface, context: Context) -> None: + api.fail('Required dynamic aliases disallowed', context, code=ERROR_ALIAS) + + +def error_unexpected_behavior( + detail: str, api: Union[CheckerPluginInterface, SemanticAnalyzerPluginInterface], context: Context +) -> None: # pragma: no cover + # Can't think of a good way to test this, but I confirmed it 
renders as desired by adding to a non-error path + link = 'https://github.com/pydantic/pydantic/issues/new/choose' + full_message = f'The pydantic mypy plugin ran into unexpected behavior: {detail}\n' + full_message += f'Please consider reporting this bug at {link} so we can try to fix it!' + api.fail(full_message, context, code=ERROR_UNEXPECTED) + + +def error_untyped_fields(api: SemanticAnalyzerPluginInterface, context: Context) -> None: + api.fail('Untyped fields disallowed', context, code=ERROR_UNTYPED) + + +def error_default_and_default_factory_specified(api: CheckerPluginInterface, context: Context) -> None: + api.fail('Field default and default_factory cannot be specified together', context, code=ERROR_FIELD_DEFAULTS) + + +def add_method( + ctx: ClassDefContext, + name: str, + args: List[Argument], + return_type: Type, + self_type: Optional[Type] = None, + tvar_def: Optional[TypeVarDef] = None, + is_classmethod: bool = False, + is_new: bool = False, + # is_staticmethod: bool = False, +) -> None: + """ + Adds a new method to a class. + + This can be dropped if/when https://github.com/python/mypy/issues/7301 is merged + """ + info = ctx.cls.info + + # First remove any previously generated methods with the same name + # to avoid clashes and problems in the semantic analyzer. + if name in info.names: + sym = info.names[name] + if sym.plugin_generated and isinstance(sym.node, FuncDef): + ctx.cls.defs.body.remove(sym.node) # pragma: no cover + + self_type = self_type or fill_typevars(info) + if is_classmethod or is_new: + first = [Argument(Var('_cls'), TypeType.make_normalized(self_type), None, ARG_POS)] + # elif is_staticmethod: + # first = [] + else: + self_type = self_type or fill_typevars(info) + first = [Argument(Var('__pydantic_self__'), self_type, None, ARG_POS)] + args = first + args + arg_types, arg_names, arg_kinds = [], [], [] + for arg in args: + assert arg.type_annotation, 'All arguments must be fully typed.' + arg_types.append(arg.type_annotation) + arg_names.append(get_name(arg.variable)) + arg_kinds.append(arg.kind) + + function_type = ctx.api.named_type(f'{BUILTINS_NAME}.function') + signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type) + if tvar_def: + signature.variables = [tvar_def] + + func = FuncDef(name, args, Block([PassStmt()])) + func.info = info + func.type = set_callable_name(signature, func) + func.is_class = is_classmethod + # func.is_static = is_staticmethod + func._fullname = get_fullname(info) + '.' + name + func.line = info.line + + # NOTE: we would like the plugin generated node to dominate, but we still + # need to keep any existing definitions so they get semantically analyzed. + if name in info.names: + # Get a nice unique name instead. + r_name = get_unique_redefinition_name(name, info.names) + info.names[r_name] = info.names[name] + + if is_classmethod: # or is_staticmethod: + func.is_decorated = True + v = Var(name, func.type) + v.info = info + v._fullname = func._fullname + # if is_classmethod: + v.is_classmethod = True + dec = Decorator(func, [NameExpr('classmethod')], v) + # else: + # v.is_staticmethod = True + # dec = Decorator(func, [NameExpr('staticmethod')], v) + + dec.line = info.line + sym = SymbolTableNode(MDEF, dec) + else: + sym = SymbolTableNode(MDEF, func) + sym.plugin_generated = True + + info.names[name] = sym + info.defn.defs.body.append(func) + + +def get_fullname(x: Union[FuncBase, SymbolNode]) -> str: + """ + Used for compatibility with mypy 0.740; can be dropped once support for 0.740 is dropped. 
+ """ + fn = x.fullname + if callable(fn): # pragma: no cover + return fn() + return fn + + +def get_name(x: Union[FuncBase, SymbolNode]) -> str: + """ + Used for compatibility with mypy 0.740; can be dropped once support for 0.740 is dropped. + """ + fn = x.name + if callable(fn): # pragma: no cover + return fn() + return fn + + +def parse_toml(config_file: str) -> Optional[Dict[str, Any]]: + if not config_file.endswith('.toml'): + return None + + read_mode = 'rb' + if sys.version_info >= (3, 11): + import tomllib as toml_ + else: + try: + import tomli as toml_ + except ImportError: + # older versions of mypy have toml as a dependency, not tomli + read_mode = 'r' + try: + import toml as toml_ # type: ignore[no-redef] + except ImportError: # pragma: no cover + import warnings + + warnings.warn('No TOML parser installed, cannot read configuration from `pyproject.toml`.') + return None + + with open(config_file, read_mode) as rf: + return toml_.load(rf) # type: ignore[arg-type] diff --git a/lib/pydantic/v1/networks.py b/lib/pydantic/v1/networks.py new file mode 100644 index 00000000..cfebe588 --- /dev/null +++ b/lib/pydantic/v1/networks.py @@ -0,0 +1,747 @@ +import re +from ipaddress import ( + IPv4Address, + IPv4Interface, + IPv4Network, + IPv6Address, + IPv6Interface, + IPv6Network, + _BaseAddress, + _BaseNetwork, +) +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Generator, + List, + Match, + Optional, + Pattern, + Set, + Tuple, + Type, + Union, + cast, + no_type_check, +) + +from . import errors +from .utils import Representation, update_not_none +from .validators import constr_length_validator, str_validator + +if TYPE_CHECKING: + import email_validator + from typing_extensions import TypedDict + + from .config import BaseConfig + from .fields import ModelField + from .typing import AnyCallable + + CallableGenerator = Generator[AnyCallable, None, None] + + class Parts(TypedDict, total=False): + scheme: str + user: Optional[str] + password: Optional[str] + ipv4: Optional[str] + ipv6: Optional[str] + domain: Optional[str] + port: Optional[str] + path: Optional[str] + query: Optional[str] + fragment: Optional[str] + + class HostParts(TypedDict, total=False): + host: str + tld: Optional[str] + host_type: Optional[str] + port: Optional[str] + rebuild: bool + +else: + email_validator = None + + class Parts(dict): + pass + + +NetworkType = Union[str, bytes, int, Tuple[Union[str, bytes, int], Union[str, int]]] + +__all__ = [ + 'AnyUrl', + 'AnyHttpUrl', + 'FileUrl', + 'HttpUrl', + 'stricturl', + 'EmailStr', + 'NameEmail', + 'IPvAnyAddress', + 'IPvAnyInterface', + 'IPvAnyNetwork', + 'PostgresDsn', + 'CockroachDsn', + 'AmqpDsn', + 'RedisDsn', + 'MongoDsn', + 'KafkaDsn', + 'validate_email', +] + +_url_regex_cache = None +_multi_host_url_regex_cache = None +_ascii_domain_regex_cache = None +_int_domain_regex_cache = None +_host_regex_cache = None + +_host_regex = ( + r'(?:' + r'(?P(?:\d{1,3}\.){3}\d{1,3})(?=$|[/:#?])|' # ipv4 + r'(?P\[[A-F0-9]*:[A-F0-9:]+\])(?=$|[/:#?])|' # ipv6 + r'(?P[^\s/:?#]+)' # domain, validation occurs later + r')?' + r'(?::(?P\d+))?' # port +) +_scheme_regex = r'(?:(?P[a-z][a-z0-9+\-.]+)://)?' # scheme https://tools.ietf.org/html/rfc3986#appendix-A +_user_info_regex = r'(?:(?P[^\s:/]*)(?::(?P[^\s/]*))?@)?' +_path_regex = r'(?P/[^\s?#]*)?' +_query_regex = r'(?:\?(?P[^\s#]*))?' +_fragment_regex = r'(?:#(?P[^\s#]*))?' 
+
+
+def url_regex() -> Pattern[str]:
+    global _url_regex_cache
+    if _url_regex_cache is None:
+        _url_regex_cache = re.compile(
+            rf'{_scheme_regex}{_user_info_regex}{_host_regex}{_path_regex}{_query_regex}{_fragment_regex}',
+            re.IGNORECASE,
+        )
+    return _url_regex_cache
+
+
+def multi_host_url_regex() -> Pattern[str]:
+    """
+    Compiled multi host url regex.
+
+    Additionally to `url_regex` it allows to match multiple hosts.
+    E.g. host1.db.net,host2.db.net
+    """
+    global _multi_host_url_regex_cache
+    if _multi_host_url_regex_cache is None:
+        _multi_host_url_regex_cache = re.compile(
+            rf'{_scheme_regex}{_user_info_regex}'
+            r'(?P<hosts>([^/]*))'  # validation occurs later
+            rf'{_path_regex}{_query_regex}{_fragment_regex}',
+            re.IGNORECASE,
+        )
+    return _multi_host_url_regex_cache
+
+
+def ascii_domain_regex() -> Pattern[str]:
+    global _ascii_domain_regex_cache
+    if _ascii_domain_regex_cache is None:
+        ascii_chunk = r'[_0-9a-z](?:[-_0-9a-z]{0,61}[_0-9a-z])?'
+        ascii_domain_ending = r'(?P<tld>\.[a-z]{2,63})?\.?'
+        _ascii_domain_regex_cache = re.compile(
+            fr'(?:{ascii_chunk}\.)*?{ascii_chunk}{ascii_domain_ending}', re.IGNORECASE
+        )
+    return _ascii_domain_regex_cache
+
+
+def int_domain_regex() -> Pattern[str]:
+    global _int_domain_regex_cache
+    if _int_domain_regex_cache is None:
+        int_chunk = r'[_0-9a-\U00040000](?:[-_0-9a-\U00040000]{0,61}[_0-9a-\U00040000])?'
+        int_domain_ending = r'(?P<tld>(\.[^\W\d_]{2,63})|(\.(?:xn--)[_0-9a-z-]{2,63}))?\.?'
+        _int_domain_regex_cache = re.compile(fr'(?:{int_chunk}\.)*?{int_chunk}{int_domain_ending}', re.IGNORECASE)
+    return _int_domain_regex_cache
+
+
+def host_regex() -> Pattern[str]:
+    global _host_regex_cache
+    if _host_regex_cache is None:
+        _host_regex_cache = re.compile(
+            _host_regex,
+            re.IGNORECASE,
+        )
+    return _host_regex_cache
+
+
+class AnyUrl(str):
+    strip_whitespace = True
+    min_length = 1
+    max_length = 2**16
+    allowed_schemes: Optional[Collection[str]] = None
+    tld_required: bool = False
+    user_required: bool = False
+    host_required: bool = True
+    hidden_parts: Set[str] = set()
+
+    __slots__ = ('scheme', 'user', 'password', 'host', 'tld', 'host_type', 'port', 'path', 'query', 'fragment')
+
+    @no_type_check
+    def __new__(cls, url: Optional[str], **kwargs) -> object:
+        return str.__new__(cls, cls.build(**kwargs) if url is None else url)
+
+    def __init__(
+        self,
+        url: str,
+        *,
+        scheme: str,
+        user: Optional[str] = None,
+        password: Optional[str] = None,
+        host: Optional[str] = None,
+        tld: Optional[str] = None,
+        host_type: str = 'domain',
+        port: Optional[str] = None,
+        path: Optional[str] = None,
+        query: Optional[str] = None,
+        fragment: Optional[str] = None,
+    ) -> None:
+        str.__init__(url)
+        self.scheme = scheme
+        self.user = user
+        self.password = password
+        self.host = host
+        self.tld = tld
+        self.host_type = host_type
+        self.port = port
+        self.path = path
+        self.query = query
+        self.fragment = fragment
+
+    @classmethod
+    def build(
+        cls,
+        *,
+        scheme: str,
+        user: Optional[str] = None,
+        password: Optional[str] = None,
+        host: str,
+        port: Optional[str] = None,
+        path: Optional[str] = None,
+        query: Optional[str] = None,
+        fragment: Optional[str] = None,
+        **_kwargs: str,
+    ) -> str:
+        parts = Parts(
+            scheme=scheme,
+            user=user,
+            password=password,
+            host=host,
+            port=port,
+            path=path,
+            query=query,
+            fragment=fragment,
+            **_kwargs,  # type: ignore[misc]
+        )
+
+        url = scheme + '://'
+        if user:
+            url += user
+        if password:
+            url += ':' + password
+        if user or password:
+            url += '@'
+        url += host
+        if port and ('port' not in cls.hidden_parts
or cls.get_default_parts(parts).get('port') != port): + url += ':' + port + if path: + url += path + if query: + url += '?' + query + if fragment: + url += '#' + fragment + return url + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none(field_schema, minLength=cls.min_length, maxLength=cls.max_length, format='uri') + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.validate + + @classmethod + def validate(cls, value: Any, field: 'ModelField', config: 'BaseConfig') -> 'AnyUrl': + if value.__class__ == cls: + return value + value = str_validator(value) + if cls.strip_whitespace: + value = value.strip() + url: str = cast(str, constr_length_validator(value, field, config)) + + m = cls._match_url(url) + # the regex should always match, if it doesn't please report with details of the URL tried + assert m, 'URL regex failed unexpectedly' + + original_parts = cast('Parts', m.groupdict()) + parts = cls.apply_default_parts(original_parts) + parts = cls.validate_parts(parts) + + if m.end() != len(url): + raise errors.UrlExtraError(extra=url[m.end() :]) + + return cls._build_url(m, url, parts) + + @classmethod + def _build_url(cls, m: Match[str], url: str, parts: 'Parts') -> 'AnyUrl': + """ + Validate hosts and build the AnyUrl object. Split from `validate` so this method + can be altered in `MultiHostDsn`. + """ + host, tld, host_type, rebuild = cls.validate_host(parts) + + return cls( + None if rebuild else url, + scheme=parts['scheme'], + user=parts['user'], + password=parts['password'], + host=host, + tld=tld, + host_type=host_type, + port=parts['port'], + path=parts['path'], + query=parts['query'], + fragment=parts['fragment'], + ) + + @staticmethod + def _match_url(url: str) -> Optional[Match[str]]: + return url_regex().match(url) + + @staticmethod + def _validate_port(port: Optional[str]) -> None: + if port is not None and int(port) > 65_535: + raise errors.UrlPortError() + + @classmethod + def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts': + """ + A method used to validate parts of a URL. 
+ Could be overridden to set default values for parts if missing + """ + scheme = parts['scheme'] + if scheme is None: + raise errors.UrlSchemeError() + + if cls.allowed_schemes and scheme.lower() not in cls.allowed_schemes: + raise errors.UrlSchemePermittedError(set(cls.allowed_schemes)) + + if validate_port: + cls._validate_port(parts['port']) + + user = parts['user'] + if cls.user_required and user is None: + raise errors.UrlUserInfoError() + + return parts + + @classmethod + def validate_host(cls, parts: 'Parts') -> Tuple[str, Optional[str], str, bool]: + tld, host_type, rebuild = None, None, False + for f in ('domain', 'ipv4', 'ipv6'): + host = parts[f] # type: ignore[literal-required] + if host: + host_type = f + break + + if host is None: + if cls.host_required: + raise errors.UrlHostError() + elif host_type == 'domain': + is_international = False + d = ascii_domain_regex().fullmatch(host) + if d is None: + d = int_domain_regex().fullmatch(host) + if d is None: + raise errors.UrlHostError() + is_international = True + + tld = d.group('tld') + if tld is None and not is_international: + d = int_domain_regex().fullmatch(host) + assert d is not None + tld = d.group('tld') + is_international = True + + if tld is not None: + tld = tld[1:] + elif cls.tld_required: + raise errors.UrlHostTldError() + + if is_international: + host_type = 'int_domain' + rebuild = True + host = host.encode('idna').decode('ascii') + if tld is not None: + tld = tld.encode('idna').decode('ascii') + + return host, tld, host_type, rebuild # type: ignore + + @staticmethod + def get_default_parts(parts: 'Parts') -> 'Parts': + return {} + + @classmethod + def apply_default_parts(cls, parts: 'Parts') -> 'Parts': + for key, value in cls.get_default_parts(parts).items(): + if not parts[key]: # type: ignore[literal-required] + parts[key] = value # type: ignore[literal-required] + return parts + + def __repr__(self) -> str: + extra = ', '.join(f'{n}={getattr(self, n)!r}' for n in self.__slots__ if getattr(self, n) is not None) + return f'{self.__class__.__name__}({super().__repr__()}, {extra})' + + +class AnyHttpUrl(AnyUrl): + allowed_schemes = {'http', 'https'} + + __slots__ = () + + +class HttpUrl(AnyHttpUrl): + tld_required = True + # https://stackoverflow.com/questions/417142/what-is-the-maximum-length-of-a-url-in-different-browsers + max_length = 2083 + hidden_parts = {'port'} + + @staticmethod + def get_default_parts(parts: 'Parts') -> 'Parts': + return {'port': '80' if parts['scheme'] == 'http' else '443'} + + +class FileUrl(AnyUrl): + allowed_schemes = {'file'} + host_required = False + + __slots__ = () + + +class MultiHostDsn(AnyUrl): + __slots__ = AnyUrl.__slots__ + ('hosts',) + + def __init__(self, *args: Any, hosts: Optional[List['HostParts']] = None, **kwargs: Any): + super().__init__(*args, **kwargs) + self.hosts = hosts + + @staticmethod + def _match_url(url: str) -> Optional[Match[str]]: + return multi_host_url_regex().match(url) + + @classmethod + def validate_parts(cls, parts: 'Parts', validate_port: bool = True) -> 'Parts': + return super().validate_parts(parts, validate_port=False) + + @classmethod + def _build_url(cls, m: Match[str], url: str, parts: 'Parts') -> 'MultiHostDsn': + hosts_parts: List['HostParts'] = [] + host_re = host_regex() + for host in m.groupdict()['hosts'].split(','): + d: Parts = host_re.match(host).groupdict() # type: ignore + host, tld, host_type, rebuild = cls.validate_host(d) + port = d.get('port') + cls._validate_port(port) + hosts_parts.append( + { + 'host': host, + 
'host_type': host_type, + 'tld': tld, + 'rebuild': rebuild, + 'port': port, + } + ) + + if len(hosts_parts) > 1: + return cls( + None if any([hp['rebuild'] for hp in hosts_parts]) else url, + scheme=parts['scheme'], + user=parts['user'], + password=parts['password'], + path=parts['path'], + query=parts['query'], + fragment=parts['fragment'], + host_type=None, + hosts=hosts_parts, + ) + else: + # backwards compatibility with single host + host_part = hosts_parts[0] + return cls( + None if host_part['rebuild'] else url, + scheme=parts['scheme'], + user=parts['user'], + password=parts['password'], + host=host_part['host'], + tld=host_part['tld'], + host_type=host_part['host_type'], + port=host_part.get('port'), + path=parts['path'], + query=parts['query'], + fragment=parts['fragment'], + ) + + +class PostgresDsn(MultiHostDsn): + allowed_schemes = { + 'postgres', + 'postgresql', + 'postgresql+asyncpg', + 'postgresql+pg8000', + 'postgresql+psycopg', + 'postgresql+psycopg2', + 'postgresql+psycopg2cffi', + 'postgresql+py-postgresql', + 'postgresql+pygresql', + } + user_required = True + + __slots__ = () + + +class CockroachDsn(AnyUrl): + allowed_schemes = { + 'cockroachdb', + 'cockroachdb+psycopg2', + 'cockroachdb+asyncpg', + } + user_required = True + + +class AmqpDsn(AnyUrl): + allowed_schemes = {'amqp', 'amqps'} + host_required = False + + +class RedisDsn(AnyUrl): + __slots__ = () + allowed_schemes = {'redis', 'rediss'} + host_required = False + + @staticmethod + def get_default_parts(parts: 'Parts') -> 'Parts': + return { + 'domain': 'localhost' if not (parts['ipv4'] or parts['ipv6']) else '', + 'port': '6379', + 'path': '/0', + } + + +class MongoDsn(AnyUrl): + allowed_schemes = {'mongodb'} + + # TODO: Needed to generic "Parts" for "Replica Set", "Sharded Cluster", and other mongodb deployment modes + @staticmethod + def get_default_parts(parts: 'Parts') -> 'Parts': + return { + 'port': '27017', + } + + +class KafkaDsn(AnyUrl): + allowed_schemes = {'kafka'} + + @staticmethod + def get_default_parts(parts: 'Parts') -> 'Parts': + return { + 'domain': 'localhost', + 'port': '9092', + } + + +def stricturl( + *, + strip_whitespace: bool = True, + min_length: int = 1, + max_length: int = 2**16, + tld_required: bool = True, + host_required: bool = True, + allowed_schemes: Optional[Collection[str]] = None, +) -> Type[AnyUrl]: + # use kwargs then define conf in a dict to aid with IDE type hinting + namespace = dict( + strip_whitespace=strip_whitespace, + min_length=min_length, + max_length=max_length, + tld_required=tld_required, + host_required=host_required, + allowed_schemes=allowed_schemes, + ) + return type('UrlValue', (AnyUrl,), namespace) + + +def import_email_validator() -> None: + global email_validator + try: + import email_validator + except ImportError as e: + raise ImportError('email-validator is not installed, run `pip install pydantic[email]`') from e + + +class EmailStr(str): + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(type='string', format='email') + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + # included here and below so the error happens straight away + import_email_validator() + + yield str_validator + yield cls.validate + + @classmethod + def validate(cls, value: Union[str]) -> str: + return validate_email(value)[1] + + +class NameEmail(Representation): + __slots__ = 'name', 'email' + + def __init__(self, name: str, email: str): + self.name = name + self.email = email + + def __eq__(self, 
other: Any) -> bool: + return isinstance(other, NameEmail) and (self.name, self.email) == (other.name, other.email) + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(type='string', format='name-email') + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + import_email_validator() + + yield cls.validate + + @classmethod + def validate(cls, value: Any) -> 'NameEmail': + if value.__class__ == cls: + return value + value = str_validator(value) + return cls(*validate_email(value)) + + def __str__(self) -> str: + return f'{self.name} <{self.email}>' + + +class IPvAnyAddress(_BaseAddress): + __slots__ = () + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(type='string', format='ipvanyaddress') + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.validate + + @classmethod + def validate(cls, value: Union[str, bytes, int]) -> Union[IPv4Address, IPv6Address]: + try: + return IPv4Address(value) + except ValueError: + pass + + try: + return IPv6Address(value) + except ValueError: + raise errors.IPvAnyAddressError() + + +class IPvAnyInterface(_BaseAddress): + __slots__ = () + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(type='string', format='ipvanyinterface') + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.validate + + @classmethod + def validate(cls, value: NetworkType) -> Union[IPv4Interface, IPv6Interface]: + try: + return IPv4Interface(value) + except ValueError: + pass + + try: + return IPv6Interface(value) + except ValueError: + raise errors.IPvAnyInterfaceError() + + +class IPvAnyNetwork(_BaseNetwork): # type: ignore + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(type='string', format='ipvanynetwork') + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.validate + + @classmethod + def validate(cls, value: NetworkType) -> Union[IPv4Network, IPv6Network]: + # Assume IP Network is defined with a default value for ``strict`` argument. + # Define your own class if you want to specify network address check strictness. + try: + return IPv4Network(value) + except ValueError: + pass + + try: + return IPv6Network(value) + except ValueError: + raise errors.IPvAnyNetworkError() + + +pretty_email_regex = re.compile(r'([\w ]*?) *<(.*)> *') +MAX_EMAIL_LENGTH = 2048 +"""Maximum length for an email. +A somewhat arbitrary but very generous number compared to what is allowed by most implementations. +""" + + +def validate_email(value: Union[str]) -> Tuple[str, str]: + """ + Email address validation using https://pypi.org/project/email-validator/ + Notes: + * raw ip address (literal) domain parts are not allowed. 
+ * "John Doe " style "pretty" email addresses are processed + * spaces are striped from the beginning and end of addresses but no error is raised + """ + if email_validator is None: + import_email_validator() + + if len(value) > MAX_EMAIL_LENGTH: + raise errors.EmailError() + + m = pretty_email_regex.fullmatch(value) + name: Union[str, None] = None + if m: + name, value = m.groups() + email = value.strip() + try: + parts = email_validator.validate_email(email, check_deliverability=False) + except email_validator.EmailNotValidError as e: + raise errors.EmailError from e + + if hasattr(parts, 'normalized'): + # email-validator >= 2 + email = parts.normalized + assert email is not None + name = name or parts.local_part + return name, email + else: + # email-validator >1, <2 + at_index = email.index('@') + local_part = email[:at_index] # RFC 5321, local part must be case-sensitive. + global_part = email[at_index:].lower() + + return name or local_part, local_part + global_part diff --git a/lib/pydantic/v1/parse.py b/lib/pydantic/v1/parse.py new file mode 100644 index 00000000..7ac330ca --- /dev/null +++ b/lib/pydantic/v1/parse.py @@ -0,0 +1,66 @@ +import json +import pickle +from enum import Enum +from pathlib import Path +from typing import Any, Callable, Union + +from .types import StrBytes + + +class Protocol(str, Enum): + json = 'json' + pickle = 'pickle' + + +def load_str_bytes( + b: StrBytes, + *, + content_type: str = None, + encoding: str = 'utf8', + proto: Protocol = None, + allow_pickle: bool = False, + json_loads: Callable[[str], Any] = json.loads, +) -> Any: + if proto is None and content_type: + if content_type.endswith(('json', 'javascript')): + pass + elif allow_pickle and content_type.endswith('pickle'): + proto = Protocol.pickle + else: + raise TypeError(f'Unknown content-type: {content_type}') + + proto = proto or Protocol.json + + if proto == Protocol.json: + if isinstance(b, bytes): + b = b.decode(encoding) + return json_loads(b) + elif proto == Protocol.pickle: + if not allow_pickle: + raise RuntimeError('Trying to decode with pickle with allow_pickle=False') + bb = b if isinstance(b, bytes) else b.encode() + return pickle.loads(bb) + else: + raise TypeError(f'Unknown protocol: {proto}') + + +def load_file( + path: Union[str, Path], + *, + content_type: str = None, + encoding: str = 'utf8', + proto: Protocol = None, + allow_pickle: bool = False, + json_loads: Callable[[str], Any] = json.loads, +) -> Any: + path = Path(path) + b = path.read_bytes() + if content_type is None: + if path.suffix in ('.js', '.json'): + proto = Protocol.json + elif path.suffix == '.pkl': + proto = Protocol.pickle + + return load_str_bytes( + b, proto=proto, content_type=content_type, encoding=encoding, allow_pickle=allow_pickle, json_loads=json_loads + ) diff --git a/lib/pydantic/v1/py.typed b/lib/pydantic/v1/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/lib/pydantic/v1/schema.py b/lib/pydantic/v1/schema.py new file mode 100644 index 00000000..ea16a72a --- /dev/null +++ b/lib/pydantic/v1/schema.py @@ -0,0 +1,1163 @@ +import re +import warnings +from collections import defaultdict +from dataclasses import is_dataclass +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from enum import Enum +from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + ForwardRef, + FrozenSet, + Generic, + Iterable, + List, + 
Optional, + Pattern, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) +from uuid import UUID + +from typing_extensions import Annotated, Literal + +from .fields import ( + MAPPING_LIKE_SHAPES, + SHAPE_DEQUE, + SHAPE_FROZENSET, + SHAPE_GENERIC, + SHAPE_ITERABLE, + SHAPE_LIST, + SHAPE_SEQUENCE, + SHAPE_SET, + SHAPE_SINGLETON, + SHAPE_TUPLE, + SHAPE_TUPLE_ELLIPSIS, + FieldInfo, + ModelField, +) +from .json import pydantic_encoder +from .networks import AnyUrl, EmailStr +from .types import ( + ConstrainedDecimal, + ConstrainedFloat, + ConstrainedFrozenSet, + ConstrainedInt, + ConstrainedList, + ConstrainedSet, + ConstrainedStr, + SecretBytes, + SecretStr, + StrictBytes, + StrictStr, + conbytes, + condecimal, + confloat, + confrozenset, + conint, + conlist, + conset, + constr, +) +from .typing import ( + all_literal_values, + get_args, + get_origin, + get_sub_types, + is_callable_type, + is_literal_type, + is_namedtuple, + is_none_type, + is_union, +) +from .utils import ROOT_KEY, get_model, lenient_issubclass + +if TYPE_CHECKING: + from .dataclasses import Dataclass + from .main import BaseModel + +default_prefix = '#/definitions/' +default_ref_template = '#/definitions/{model}' + +TypeModelOrEnum = Union[Type['BaseModel'], Type[Enum]] +TypeModelSet = Set[TypeModelOrEnum] + + +def _apply_modify_schema( + modify_schema: Callable[..., None], field: Optional[ModelField], field_schema: Dict[str, Any] +) -> None: + from inspect import signature + + sig = signature(modify_schema) + args = set(sig.parameters.keys()) + if 'field' in args or 'kwargs' in args: + modify_schema(field_schema, field=field) + else: + modify_schema(field_schema) + + +def schema( + models: Sequence[Union[Type['BaseModel'], Type['Dataclass']]], + *, + by_alias: bool = True, + title: Optional[str] = None, + description: Optional[str] = None, + ref_prefix: Optional[str] = None, + ref_template: str = default_ref_template, +) -> Dict[str, Any]: + """ + Process a list of models and generate a single JSON Schema with all of them defined in the ``definitions`` + top-level JSON key, including their sub-models. + + :param models: a list of models to include in the generated JSON Schema + :param by_alias: generate the schemas using the aliases defined, if any + :param title: title for the generated schema that includes the definitions + :param description: description for the generated schema + :param ref_prefix: the JSON Pointer prefix for schema references with ``$ref``, if None, will be set to the + default of ``#/definitions/``. Update it if you want the schemas to reference the definitions somewhere + else, e.g. for OpenAPI use ``#/components/schemas/``. The resulting generated schemas will still be at the + top-level key ``definitions``, so you can extract them from there. But all the references will have the set + prefix. + :param ref_template: Use a ``string.format()`` template for ``$ref`` instead of a prefix. This can be useful + for references that cannot be represented by ``ref_prefix`` such as a definition stored in another file. For + a sibling json file in a ``/schemas`` directory use ``"/schemas/${model}.json#"``. + :return: dict with the JSON Schema with a ``definitions`` top-level key including the schema definitions for + the models and sub-models passed in ``models``. 
+ """ + clean_models = [get_model(model) for model in models] + flat_models = get_flat_models_from_models(clean_models) + model_name_map = get_model_name_map(flat_models) + definitions = {} + output_schema: Dict[str, Any] = {} + if title: + output_schema['title'] = title + if description: + output_schema['description'] = description + for model in clean_models: + m_schema, m_definitions, m_nested_models = model_process_schema( + model, + by_alias=by_alias, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ref_template=ref_template, + ) + definitions.update(m_definitions) + model_name = model_name_map[model] + definitions[model_name] = m_schema + if definitions: + output_schema['definitions'] = definitions + return output_schema + + +def model_schema( + model: Union[Type['BaseModel'], Type['Dataclass']], + by_alias: bool = True, + ref_prefix: Optional[str] = None, + ref_template: str = default_ref_template, +) -> Dict[str, Any]: + """ + Generate a JSON Schema for one model. With all the sub-models defined in the ``definitions`` top-level + JSON key. + + :param model: a Pydantic model (a class that inherits from BaseModel) + :param by_alias: generate the schemas using the aliases defined, if any + :param ref_prefix: the JSON Pointer prefix for schema references with ``$ref``, if None, will be set to the + default of ``#/definitions/``. Update it if you want the schemas to reference the definitions somewhere + else, e.g. for OpenAPI use ``#/components/schemas/``. The resulting generated schemas will still be at the + top-level key ``definitions``, so you can extract them from there. But all the references will have the set + prefix. + :param ref_template: Use a ``string.format()`` template for ``$ref`` instead of a prefix. This can be useful for + references that cannot be represented by ``ref_prefix`` such as a definition stored in another file. For a + sibling json file in a ``/schemas`` directory use ``"/schemas/${model}.json#"``. + :return: dict with the JSON Schema for the passed ``model`` + """ + model = get_model(model) + flat_models = get_flat_models_from_model(model) + model_name_map = get_model_name_map(flat_models) + model_name = model_name_map[model] + m_schema, m_definitions, nested_models = model_process_schema( + model, by_alias=by_alias, model_name_map=model_name_map, ref_prefix=ref_prefix, ref_template=ref_template + ) + if model_name in nested_models: + # model_name is in Nested models, it has circular references + m_definitions[model_name] = m_schema + m_schema = get_schema_ref(model_name, ref_prefix, ref_template, False) + if m_definitions: + m_schema.update({'definitions': m_definitions}) + return m_schema + + +def get_field_info_schema(field: ModelField, schema_overrides: bool = False) -> Tuple[Dict[str, Any], bool]: + # If no title is explicitly set, we don't set title in the schema for enums. + # The behaviour is the same as `BaseModel` reference, where the default title + # is in the definitions part of the schema. 
+ schema_: Dict[str, Any] = {} + if field.field_info.title or not lenient_issubclass(field.type_, Enum): + schema_['title'] = field.field_info.title or field.alias.title().replace('_', ' ') + + if field.field_info.title: + schema_overrides = True + + if field.field_info.description: + schema_['description'] = field.field_info.description + schema_overrides = True + + if not field.required and field.default is not None and not is_callable_type(field.outer_type_): + schema_['default'] = encode_default(field.default) + schema_overrides = True + + return schema_, schema_overrides + + +def field_schema( + field: ModelField, + *, + by_alias: bool = True, + model_name_map: Dict[TypeModelOrEnum, str], + ref_prefix: Optional[str] = None, + ref_template: str = default_ref_template, + known_models: Optional[TypeModelSet] = None, +) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: + """ + Process a Pydantic field and return a tuple with a JSON Schema for it as the first item. + Also return a dictionary of definitions with models as keys and their schemas as values. If the passed field + is a model and has sub-models, and those sub-models don't have overrides (as ``title``, ``default``, etc), they + will be included in the definitions and referenced in the schema instead of included recursively. + + :param field: a Pydantic ``ModelField`` + :param by_alias: use the defined alias (if any) in the returned schema + :param model_name_map: used to generate the JSON Schema references to other models included in the definitions + :param ref_prefix: the JSON Pointer prefix to use for references to other schemas, if None, the default of + #/definitions/ will be used + :param ref_template: Use a ``string.format()`` template for ``$ref`` instead of a prefix. This can be useful for + references that cannot be represented by ``ref_prefix`` such as a definition stored in another file. For a + sibling json file in a ``/schemas`` directory use ``"/schemas/${model}.json#"``. + :param known_models: used to solve circular references + :return: tuple of the schema for this field and additional definitions + """ + s, schema_overrides = get_field_info_schema(field) + + validation_schema = get_field_schema_validations(field) + if validation_schema: + s.update(validation_schema) + schema_overrides = True + + f_schema, f_definitions, f_nested_models = field_type_schema( + field, + by_alias=by_alias, + model_name_map=model_name_map, + schema_overrides=schema_overrides, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models or set(), + ) + + # $ref will only be returned when there are no schema_overrides + if '$ref' in f_schema: + return f_schema, f_definitions, f_nested_models + else: + s.update(f_schema) + return s, f_definitions, f_nested_models + + +numeric_types = (int, float, Decimal) +_str_types_attrs: Tuple[Tuple[str, Union[type, Tuple[type, ...]], str], ...] = ( + ('max_length', numeric_types, 'maxLength'), + ('min_length', numeric_types, 'minLength'), + ('regex', str, 'pattern'), +) + +_numeric_types_attrs: Tuple[Tuple[str, Union[type, Tuple[type, ...]], str], ...] = ( + ('gt', numeric_types, 'exclusiveMinimum'), + ('lt', numeric_types, 'exclusiveMaximum'), + ('ge', numeric_types, 'minimum'), + ('le', numeric_types, 'maximum'), + ('multiple_of', numeric_types, 'multipleOf'), +) + + +def get_field_schema_validations(field: ModelField) -> Dict[str, Any]: + """ + Get the JSON Schema validation keywords for a ``field`` with an annotation of + a Pydantic ``FieldInfo`` with validation arguments. 
+ """ + f_schema: Dict[str, Any] = {} + + if lenient_issubclass(field.type_, Enum): + # schema is already updated by `enum_process_schema`; just update with field extra + if field.field_info.extra: + f_schema.update(field.field_info.extra) + return f_schema + + if lenient_issubclass(field.type_, (str, bytes)): + for attr_name, t, keyword in _str_types_attrs: + attr = getattr(field.field_info, attr_name, None) + if isinstance(attr, t): + f_schema[keyword] = attr + if lenient_issubclass(field.type_, numeric_types) and not issubclass(field.type_, bool): + for attr_name, t, keyword in _numeric_types_attrs: + attr = getattr(field.field_info, attr_name, None) + if isinstance(attr, t): + f_schema[keyword] = attr + if field.field_info is not None and field.field_info.const: + f_schema['const'] = field.default + if field.field_info.extra: + f_schema.update(field.field_info.extra) + modify_schema = getattr(field.outer_type_, '__modify_schema__', None) + if modify_schema: + _apply_modify_schema(modify_schema, field, f_schema) + return f_schema + + +def get_model_name_map(unique_models: TypeModelSet) -> Dict[TypeModelOrEnum, str]: + """ + Process a set of models and generate unique names for them to be used as keys in the JSON Schema + definitions. By default the names are the same as the class name. But if two models in different Python + modules have the same name (e.g. "users.Model" and "items.Model"), the generated names will be + based on the Python module path for those conflicting models to prevent name collisions. + + :param unique_models: a Python set of models + :return: dict mapping models to names + """ + name_model_map = {} + conflicting_names: Set[str] = set() + for model in unique_models: + model_name = normalize_name(model.__name__) + if model_name in conflicting_names: + model_name = get_long_model_name(model) + name_model_map[model_name] = model + elif model_name in name_model_map: + conflicting_names.add(model_name) + conflicting_model = name_model_map.pop(model_name) + name_model_map[get_long_model_name(conflicting_model)] = conflicting_model + name_model_map[get_long_model_name(model)] = model + else: + name_model_map[model_name] = model + return {v: k for k, v in name_model_map.items()} + + +def get_flat_models_from_model(model: Type['BaseModel'], known_models: Optional[TypeModelSet] = None) -> TypeModelSet: + """ + Take a single ``model`` and generate a set with itself and all the sub-models in the tree. I.e. if you pass + model ``Foo`` (subclass of Pydantic ``BaseModel``) as ``model``, and it has a field of type ``Bar`` (also + subclass of ``BaseModel``) and that model ``Bar`` has a field of type ``Baz`` (also subclass of ``BaseModel``), + the return value will be ``set([Foo, Bar, Baz])``. + + :param model: a Pydantic ``BaseModel`` subclass + :param known_models: used to solve circular references + :return: a set with the initial model and all its sub-models + """ + known_models = known_models or set() + flat_models: TypeModelSet = set() + flat_models.add(model) + known_models |= flat_models + fields = cast(Sequence[ModelField], model.__fields__.values()) + flat_models |= get_flat_models_from_fields(fields, known_models=known_models) + return flat_models + + +def get_flat_models_from_field(field: ModelField, known_models: TypeModelSet) -> TypeModelSet: + """ + Take a single Pydantic ``ModelField`` (from a model) that could have been declared as a subclass of BaseModel + (so, it could be a submodel), and generate a set with its model and all the sub-models in the tree. + I.e. 
if you pass a field that was declared to be of type ``Foo`` (subclass of BaseModel) as ``field``, and that + model ``Foo`` has a field of type ``Bar`` (also subclass of ``BaseModel``) and that model ``Bar`` has a field of + type ``Baz`` (also subclass of ``BaseModel``), the return value will be ``set([Foo, Bar, Baz])``. + + :param field: a Pydantic ``ModelField`` + :param known_models: used to solve circular references + :return: a set with the model used in the declaration for this field, if any, and all its sub-models + """ + from .main import BaseModel + + flat_models: TypeModelSet = set() + + field_type = field.type_ + if lenient_issubclass(getattr(field_type, '__pydantic_model__', None), BaseModel): + field_type = field_type.__pydantic_model__ + + if field.sub_fields and not lenient_issubclass(field_type, BaseModel): + flat_models |= get_flat_models_from_fields(field.sub_fields, known_models=known_models) + elif lenient_issubclass(field_type, BaseModel) and field_type not in known_models: + flat_models |= get_flat_models_from_model(field_type, known_models=known_models) + elif lenient_issubclass(field_type, Enum): + flat_models.add(field_type) + return flat_models + + +def get_flat_models_from_fields(fields: Sequence[ModelField], known_models: TypeModelSet) -> TypeModelSet: + """ + Take a list of Pydantic ``ModelField``s (from a model) that could have been declared as subclasses of ``BaseModel`` + (so, any of them could be a submodel), and generate a set with their models and all the sub-models in the tree. + I.e. if you pass a the fields of a model ``Foo`` (subclass of ``BaseModel``) as ``fields``, and on of them has a + field of type ``Bar`` (also subclass of ``BaseModel``) and that model ``Bar`` has a field of type ``Baz`` (also + subclass of ``BaseModel``), the return value will be ``set([Foo, Bar, Baz])``. + + :param fields: a list of Pydantic ``ModelField``s + :param known_models: used to solve circular references + :return: a set with any model declared in the fields, and all their sub-models + """ + flat_models: TypeModelSet = set() + for field in fields: + flat_models |= get_flat_models_from_field(field, known_models=known_models) + return flat_models + + +def get_flat_models_from_models(models: Sequence[Type['BaseModel']]) -> TypeModelSet: + """ + Take a list of ``models`` and generate a set with them and all their sub-models in their trees. I.e. if you pass + a list of two models, ``Foo`` and ``Bar``, both subclasses of Pydantic ``BaseModel`` as models, and ``Bar`` has + a field of type ``Baz`` (also subclass of ``BaseModel``), the return value will be ``set([Foo, Bar, Baz])``. + """ + flat_models: TypeModelSet = set() + for model in models: + flat_models |= get_flat_models_from_model(model) + return flat_models + + +def get_long_model_name(model: TypeModelOrEnum) -> str: + return f'{model.__module__}__{model.__qualname__}'.replace('.', '__') + + +def field_type_schema( + field: ModelField, + *, + by_alias: bool, + model_name_map: Dict[TypeModelOrEnum, str], + ref_template: str, + schema_overrides: bool = False, + ref_prefix: Optional[str] = None, + known_models: TypeModelSet, +) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: + """ + Used by ``field_schema()``, you probably should be using that function. + + Take a single ``field`` and generate the schema for its type only, not including additional + information as title, etc. Also return additional schema definitions, from sub-models. 
+ """ + from .main import BaseModel # noqa: F811 + + definitions = {} + nested_models: Set[str] = set() + f_schema: Dict[str, Any] + if field.shape in { + SHAPE_LIST, + SHAPE_TUPLE_ELLIPSIS, + SHAPE_SEQUENCE, + SHAPE_SET, + SHAPE_FROZENSET, + SHAPE_ITERABLE, + SHAPE_DEQUE, + }: + items_schema, f_definitions, f_nested_models = field_singleton_schema( + field, + by_alias=by_alias, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + definitions.update(f_definitions) + nested_models.update(f_nested_models) + f_schema = {'type': 'array', 'items': items_schema} + if field.shape in {SHAPE_SET, SHAPE_FROZENSET}: + f_schema['uniqueItems'] = True + + elif field.shape in MAPPING_LIKE_SHAPES: + f_schema = {'type': 'object'} + key_field = cast(ModelField, field.key_field) + regex = getattr(key_field.type_, 'regex', None) + items_schema, f_definitions, f_nested_models = field_singleton_schema( + field, + by_alias=by_alias, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + definitions.update(f_definitions) + nested_models.update(f_nested_models) + if regex: + # Dict keys have a regex pattern + # items_schema might be a schema or empty dict, add it either way + f_schema['patternProperties'] = {ConstrainedStr._get_pattern(regex): items_schema} + if items_schema: + # The dict values are not simply Any, so they need a schema + f_schema['additionalProperties'] = items_schema + elif field.shape == SHAPE_TUPLE or (field.shape == SHAPE_GENERIC and not issubclass(field.type_, BaseModel)): + sub_schema = [] + sub_fields = cast(List[ModelField], field.sub_fields) + for sf in sub_fields: + sf_schema, sf_definitions, sf_nested_models = field_type_schema( + sf, + by_alias=by_alias, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + definitions.update(sf_definitions) + nested_models.update(sf_nested_models) + sub_schema.append(sf_schema) + + sub_fields_len = len(sub_fields) + if field.shape == SHAPE_GENERIC: + all_of_schemas = sub_schema[0] if sub_fields_len == 1 else {'type': 'array', 'items': sub_schema} + f_schema = {'allOf': [all_of_schemas]} + else: + f_schema = { + 'type': 'array', + 'minItems': sub_fields_len, + 'maxItems': sub_fields_len, + } + if sub_fields_len >= 1: + f_schema['items'] = sub_schema + else: + assert field.shape in {SHAPE_SINGLETON, SHAPE_GENERIC}, field.shape + f_schema, f_definitions, f_nested_models = field_singleton_schema( + field, + by_alias=by_alias, + model_name_map=model_name_map, + schema_overrides=schema_overrides, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + definitions.update(f_definitions) + nested_models.update(f_nested_models) + + # check field type to avoid repeated calls to the same __modify_schema__ method + if field.type_ != field.outer_type_: + if field.shape == SHAPE_GENERIC: + field_type = field.type_ + else: + field_type = field.outer_type_ + modify_schema = getattr(field_type, '__modify_schema__', None) + if modify_schema: + _apply_modify_schema(modify_schema, field, f_schema) + return f_schema, definitions, nested_models + + +def model_process_schema( + model: TypeModelOrEnum, + *, + by_alias: bool = True, + model_name_map: Dict[TypeModelOrEnum, str], + ref_prefix: Optional[str] = None, + ref_template: str = default_ref_template, + known_models: Optional[TypeModelSet] = None, + field: Optional[ModelField] = None, +) 
-> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: + """ + Used by ``model_schema()``, you probably should be using that function. + + Take a single ``model`` and generate its schema. Also return additional schema definitions, from sub-models. The + sub-models of the returned schema will be referenced, but their definitions will not be included in the schema. All + the definitions are returned as the second value. + """ + from inspect import getdoc, signature + + known_models = known_models or set() + if lenient_issubclass(model, Enum): + model = cast(Type[Enum], model) + s = enum_process_schema(model, field=field) + return s, {}, set() + model = cast(Type['BaseModel'], model) + s = {'title': model.__config__.title or model.__name__} + doc = getdoc(model) + if doc: + s['description'] = doc + known_models.add(model) + m_schema, m_definitions, nested_models = model_type_schema( + model, + by_alias=by_alias, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + s.update(m_schema) + schema_extra = model.__config__.schema_extra + if callable(schema_extra): + if len(signature(schema_extra).parameters) == 1: + schema_extra(s) + else: + schema_extra(s, model) + else: + s.update(schema_extra) + return s, m_definitions, nested_models + + +def model_type_schema( + model: Type['BaseModel'], + *, + by_alias: bool, + model_name_map: Dict[TypeModelOrEnum, str], + ref_template: str, + ref_prefix: Optional[str] = None, + known_models: TypeModelSet, +) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: + """ + You probably should be using ``model_schema()``, this function is indirectly used by that function. + + Take a single ``model`` and generate the schema for its type only, not including additional + information as title, etc. Also return additional schema definitions, from sub-models. + """ + properties = {} + required = [] + definitions: Dict[str, Any] = {} + nested_models: Set[str] = set() + for k, f in model.__fields__.items(): + try: + f_schema, f_definitions, f_nested_models = field_schema( + f, + by_alias=by_alias, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + except SkipField as skip: + warnings.warn(skip.message, UserWarning) + continue + definitions.update(f_definitions) + nested_models.update(f_nested_models) + if by_alias: + properties[f.alias] = f_schema + if f.required: + required.append(f.alias) + else: + properties[k] = f_schema + if f.required: + required.append(k) + if ROOT_KEY in properties: + out_schema = properties[ROOT_KEY] + out_schema['title'] = model.__config__.title or model.__name__ + else: + out_schema = {'type': 'object', 'properties': properties} + if required: + out_schema['required'] = required + if model.__config__.extra == 'forbid': + out_schema['additionalProperties'] = False + return out_schema, definitions, nested_models + + +def enum_process_schema(enum: Type[Enum], *, field: Optional[ModelField] = None) -> Dict[str, Any]: + """ + Take a single `enum` and generate its schema. + + This is similar to the `model_process_schema` function, but applies to ``Enum`` objects. + """ + import inspect + + schema_: Dict[str, Any] = { + 'title': enum.__name__, + # Python assigns all enums a default docstring value of 'An enumeration', so + # all enums will have a description field even if not explicitly provided. 
+ 'description': inspect.cleandoc(enum.__doc__ or 'An enumeration.'), + # Add enum values and the enum field type to the schema. + 'enum': [item.value for item in cast(Iterable[Enum], enum)], + } + + add_field_type_to_schema(enum, schema_) + + modify_schema = getattr(enum, '__modify_schema__', None) + if modify_schema: + _apply_modify_schema(modify_schema, field, schema_) + + return schema_ + + +def field_singleton_sub_fields_schema( + field: ModelField, + *, + by_alias: bool, + model_name_map: Dict[TypeModelOrEnum, str], + ref_template: str, + schema_overrides: bool = False, + ref_prefix: Optional[str] = None, + known_models: TypeModelSet, +) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: + """ + This function is indirectly used by ``field_schema()``, you probably should be using that function. + + Take a list of Pydantic ``ModelField`` from the declaration of a type with parameters, and generate their + schema. I.e., fields used as "type parameters", like ``str`` and ``int`` in ``Tuple[str, int]``. + """ + sub_fields = cast(List[ModelField], field.sub_fields) + definitions = {} + nested_models: Set[str] = set() + if len(sub_fields) == 1: + return field_type_schema( + sub_fields[0], + by_alias=by_alias, + model_name_map=model_name_map, + schema_overrides=schema_overrides, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + else: + s: Dict[str, Any] = {} + # https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#discriminator-object + field_has_discriminator: bool = field.discriminator_key is not None + if field_has_discriminator: + assert field.sub_fields_mapping is not None + + discriminator_models_refs: Dict[str, Union[str, Dict[str, Any]]] = {} + + for discriminator_value, sub_field in field.sub_fields_mapping.items(): + if isinstance(discriminator_value, Enum): + discriminator_value = str(discriminator_value.value) + # sub_field is either a `BaseModel` or directly an `Annotated` `Union` of many + if is_union(get_origin(sub_field.type_)): + sub_models = get_sub_types(sub_field.type_) + discriminator_models_refs[discriminator_value] = { + model_name_map[sub_model]: get_schema_ref( + model_name_map[sub_model], ref_prefix, ref_template, False + ) + for sub_model in sub_models + } + else: + sub_field_type = sub_field.type_ + if hasattr(sub_field_type, '__pydantic_model__'): + sub_field_type = sub_field_type.__pydantic_model__ + + discriminator_model_name = model_name_map[sub_field_type] + discriminator_model_ref = get_schema_ref(discriminator_model_name, ref_prefix, ref_template, False) + discriminator_models_refs[discriminator_value] = discriminator_model_ref['$ref'] + + s['discriminator'] = { + 'propertyName': field.discriminator_alias, + 'mapping': discriminator_models_refs, + } + + sub_field_schemas = [] + for sf in sub_fields: + sub_schema, sub_definitions, sub_nested_models = field_type_schema( + sf, + by_alias=by_alias, + model_name_map=model_name_map, + schema_overrides=schema_overrides, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + definitions.update(sub_definitions) + if schema_overrides and 'allOf' in sub_schema: + # if the sub_field is a referenced schema we only need the referenced + # object. Otherwise we will end up with several allOf inside anyOf/oneOf. 
+ # See https://github.com/pydantic/pydantic/issues/1209 + sub_schema = sub_schema['allOf'][0] + + if sub_schema.keys() == {'discriminator', 'oneOf'}: + # we don't want discriminator information inside oneOf choices, this is dealt with elsewhere + sub_schema.pop('discriminator') + sub_field_schemas.append(sub_schema) + nested_models.update(sub_nested_models) + s['oneOf' if field_has_discriminator else 'anyOf'] = sub_field_schemas + return s, definitions, nested_models + + +# Order is important, e.g. subclasses of str must go before str +# this is used only for standard library types, custom types should use __modify_schema__ instead +field_class_to_schema: Tuple[Tuple[Any, Dict[str, Any]], ...] = ( + (Path, {'type': 'string', 'format': 'path'}), + (datetime, {'type': 'string', 'format': 'date-time'}), + (date, {'type': 'string', 'format': 'date'}), + (time, {'type': 'string', 'format': 'time'}), + (timedelta, {'type': 'number', 'format': 'time-delta'}), + (IPv4Network, {'type': 'string', 'format': 'ipv4network'}), + (IPv6Network, {'type': 'string', 'format': 'ipv6network'}), + (IPv4Interface, {'type': 'string', 'format': 'ipv4interface'}), + (IPv6Interface, {'type': 'string', 'format': 'ipv6interface'}), + (IPv4Address, {'type': 'string', 'format': 'ipv4'}), + (IPv6Address, {'type': 'string', 'format': 'ipv6'}), + (Pattern, {'type': 'string', 'format': 'regex'}), + (str, {'type': 'string'}), + (bytes, {'type': 'string', 'format': 'binary'}), + (bool, {'type': 'boolean'}), + (int, {'type': 'integer'}), + (float, {'type': 'number'}), + (Decimal, {'type': 'number'}), + (UUID, {'type': 'string', 'format': 'uuid'}), + (dict, {'type': 'object'}), + (list, {'type': 'array', 'items': {}}), + (tuple, {'type': 'array', 'items': {}}), + (set, {'type': 'array', 'items': {}, 'uniqueItems': True}), + (frozenset, {'type': 'array', 'items': {}, 'uniqueItems': True}), +) + +json_scheme = {'type': 'string', 'format': 'json-string'} + + +def add_field_type_to_schema(field_type: Any, schema_: Dict[str, Any]) -> None: + """ + Update the given `schema` with the type-specific metadata for the given `field_type`. + + This function looks through `field_class_to_schema` for a class that matches the given `field_type`, + and then modifies the given `schema` with the information from that type. + """ + for type_, t_schema in field_class_to_schema: + # Fallback for `typing.Pattern` and `re.Pattern` as they are not a valid class + if lenient_issubclass(field_type, type_) or field_type is type_ is Pattern: + schema_.update(t_schema) + break + + +def get_schema_ref(name: str, ref_prefix: Optional[str], ref_template: str, schema_overrides: bool) -> Dict[str, Any]: + if ref_prefix: + schema_ref = {'$ref': ref_prefix + name} + else: + schema_ref = {'$ref': ref_template.format(model=name)} + return {'allOf': [schema_ref]} if schema_overrides else schema_ref + + +def field_singleton_schema( # noqa: C901 (ignore complexity) + field: ModelField, + *, + by_alias: bool, + model_name_map: Dict[TypeModelOrEnum, str], + ref_template: str, + schema_overrides: bool = False, + ref_prefix: Optional[str] = None, + known_models: TypeModelSet, +) -> Tuple[Dict[str, Any], Dict[str, Any], Set[str]]: + """ + This function is indirectly used by ``field_schema()``, you should probably be using that function. + + Take a single Pydantic ``ModelField``, and return its schema and any additional definitions from sub-models. 
+ """ + from .main import BaseModel + + definitions: Dict[str, Any] = {} + nested_models: Set[str] = set() + field_type = field.type_ + + # Recurse into this field if it contains sub_fields and is NOT a + # BaseModel OR that BaseModel is a const + if field.sub_fields and ( + (field.field_info and field.field_info.const) or not lenient_issubclass(field_type, BaseModel) + ): + return field_singleton_sub_fields_schema( + field, + by_alias=by_alias, + model_name_map=model_name_map, + schema_overrides=schema_overrides, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + if field_type is Any or field_type is object or field_type.__class__ == TypeVar or get_origin(field_type) is type: + return {}, definitions, nested_models # no restrictions + if is_none_type(field_type): + return {'type': 'null'}, definitions, nested_models + if is_callable_type(field_type): + raise SkipField(f'Callable {field.name} was excluded from schema since JSON schema has no equivalent type.') + f_schema: Dict[str, Any] = {} + if field.field_info is not None and field.field_info.const: + f_schema['const'] = field.default + + if is_literal_type(field_type): + values = tuple(x.value if isinstance(x, Enum) else x for x in all_literal_values(field_type)) + + if len({v.__class__ for v in values}) > 1: + return field_schema( + multitypes_literal_field_for_schema(values, field), + by_alias=by_alias, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + ) + + # All values have the same type + field_type = values[0].__class__ + f_schema['enum'] = list(values) + add_field_type_to_schema(field_type, f_schema) + elif lenient_issubclass(field_type, Enum): + enum_name = model_name_map[field_type] + f_schema, schema_overrides = get_field_info_schema(field, schema_overrides) + f_schema.update(get_schema_ref(enum_name, ref_prefix, ref_template, schema_overrides)) + definitions[enum_name] = enum_process_schema(field_type, field=field) + elif is_namedtuple(field_type): + sub_schema, *_ = model_process_schema( + field_type.__pydantic_model__, + by_alias=by_alias, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + field=field, + ) + items_schemas = list(sub_schema['properties'].values()) + f_schema.update( + { + 'type': 'array', + 'items': items_schemas, + 'minItems': len(items_schemas), + 'maxItems': len(items_schemas), + } + ) + elif not hasattr(field_type, '__pydantic_model__'): + add_field_type_to_schema(field_type, f_schema) + + modify_schema = getattr(field_type, '__modify_schema__', None) + if modify_schema: + _apply_modify_schema(modify_schema, field, f_schema) + + if f_schema: + return f_schema, definitions, nested_models + + # Handle dataclass-based models + if lenient_issubclass(getattr(field_type, '__pydantic_model__', None), BaseModel): + field_type = field_type.__pydantic_model__ + + if issubclass(field_type, BaseModel): + model_name = model_name_map[field_type] + if field_type not in known_models: + sub_schema, sub_definitions, sub_nested_models = model_process_schema( + field_type, + by_alias=by_alias, + model_name_map=model_name_map, + ref_prefix=ref_prefix, + ref_template=ref_template, + known_models=known_models, + field=field, + ) + definitions.update(sub_definitions) + definitions[model_name] = sub_schema + nested_models.update(sub_nested_models) + else: + nested_models.add(model_name) + schema_ref = get_schema_ref(model_name, ref_prefix, ref_template, 
schema_overrides) + return schema_ref, definitions, nested_models + + # For generics with no args + args = get_args(field_type) + if args is not None and not args and Generic in field_type.__bases__: + return f_schema, definitions, nested_models + + raise ValueError(f'Value not declarable with JSON Schema, field: {field}') + + +def multitypes_literal_field_for_schema(values: Tuple[Any, ...], field: ModelField) -> ModelField: + """ + To support `Literal` with values of different types, we split it into multiple `Literal` with same type + e.g. `Literal['qwe', 'asd', 1, 2]` becomes `Union[Literal['qwe', 'asd'], Literal[1, 2]]` + """ + literal_distinct_types = defaultdict(list) + for v in values: + literal_distinct_types[v.__class__].append(v) + distinct_literals = (Literal[tuple(same_type_values)] for same_type_values in literal_distinct_types.values()) + + return ModelField( + name=field.name, + type_=Union[tuple(distinct_literals)], # type: ignore + class_validators=field.class_validators, + model_config=field.model_config, + default=field.default, + required=field.required, + alias=field.alias, + field_info=field.field_info, + ) + + +def encode_default(dft: Any) -> Any: + from .main import BaseModel + + if isinstance(dft, BaseModel) or is_dataclass(dft): + dft = cast('dict[str, Any]', pydantic_encoder(dft)) + + if isinstance(dft, dict): + return {encode_default(k): encode_default(v) for k, v in dft.items()} + elif isinstance(dft, Enum): + return dft.value + elif isinstance(dft, (int, float, str)): + return dft + elif isinstance(dft, (list, tuple)): + t = dft.__class__ + seq_args = (encode_default(v) for v in dft) + return t(*seq_args) if is_namedtuple(t) else t(seq_args) + elif dft is None: + return None + else: + return pydantic_encoder(dft) + + +_map_types_constraint: Dict[Any, Callable[..., type]] = {int: conint, float: confloat, Decimal: condecimal} + + +def get_annotation_from_field_info( + annotation: Any, field_info: FieldInfo, field_name: str, validate_assignment: bool = False +) -> Type[Any]: + """ + Get an annotation with validation implemented for numbers and strings based on the field_info. + :param annotation: an annotation from a field specification, as ``str``, ``ConstrainedStr`` + :param field_info: an instance of FieldInfo, possibly with declarations for validations and JSON Schema + :param field_name: name of the field for use in error messages + :param validate_assignment: default False, flag for BaseModel Config value of validate_assignment + :return: the same ``annotation`` if unmodified or a new annotation with validation in place + """ + constraints = field_info.get_constraints() + used_constraints: Set[str] = set() + if constraints: + annotation, used_constraints = get_annotation_with_constraints(annotation, field_info) + if validate_assignment: + used_constraints.add('allow_mutation') + + unused_constraints = constraints - used_constraints + if unused_constraints: + raise ValueError( + f'On field "{field_name}" the following field constraints are set but not enforced: ' + f'{", ".join(unused_constraints)}. ' + f'\nFor more details see https://docs.pydantic.dev/usage/schema/#unenforced-field-constraints' + ) + + return annotation + + +def get_annotation_with_constraints(annotation: Any, field_info: FieldInfo) -> Tuple[Type[Any], Set[str]]: # noqa: C901 + """ + Get an annotation with used constraints implemented for numbers and strings based on the field_info. 
+ + :param annotation: an annotation from a field specification, as ``str``, ``ConstrainedStr`` + :param field_info: an instance of FieldInfo, possibly with declarations for validations and JSON Schema + :return: the same ``annotation`` if unmodified or a new annotation along with the used constraints. + """ + used_constraints: Set[str] = set() + + def go(type_: Any) -> Type[Any]: + if ( + is_literal_type(type_) + or isinstance(type_, ForwardRef) + or lenient_issubclass(type_, (ConstrainedList, ConstrainedSet, ConstrainedFrozenSet)) + ): + return type_ + origin = get_origin(type_) + if origin is not None: + args: Tuple[Any, ...] = get_args(type_) + if any(isinstance(a, ForwardRef) for a in args): + # forward refs cause infinite recursion below + return type_ + + if origin is Annotated: + return go(args[0]) + if is_union(origin): + return Union[tuple(go(a) for a in args)] # type: ignore + + if issubclass(origin, List) and ( + field_info.min_items is not None + or field_info.max_items is not None + or field_info.unique_items is not None + ): + used_constraints.update({'min_items', 'max_items', 'unique_items'}) + return conlist( + go(args[0]), + min_items=field_info.min_items, + max_items=field_info.max_items, + unique_items=field_info.unique_items, + ) + + if issubclass(origin, Set) and (field_info.min_items is not None or field_info.max_items is not None): + used_constraints.update({'min_items', 'max_items'}) + return conset(go(args[0]), min_items=field_info.min_items, max_items=field_info.max_items) + + if issubclass(origin, FrozenSet) and (field_info.min_items is not None or field_info.max_items is not None): + used_constraints.update({'min_items', 'max_items'}) + return confrozenset(go(args[0]), min_items=field_info.min_items, max_items=field_info.max_items) + + for t in (Tuple, List, Set, FrozenSet, Sequence): + if issubclass(origin, t): # type: ignore + return t[tuple(go(a) for a in args)] # type: ignore + + if issubclass(origin, Dict): + return Dict[args[0], go(args[1])] # type: ignore + + attrs: Optional[Tuple[str, ...]] = None + constraint_func: Optional[Callable[..., type]] = None + if isinstance(type_, type): + if issubclass(type_, (SecretStr, SecretBytes)): + attrs = ('max_length', 'min_length') + + def constraint_func(**kw: Any) -> Type[Any]: + return type(type_.__name__, (type_,), kw) + + elif issubclass(type_, str) and not issubclass(type_, (EmailStr, AnyUrl)): + attrs = ('max_length', 'min_length', 'regex') + if issubclass(type_, StrictStr): + + def constraint_func(**kw: Any) -> Type[Any]: + return type(type_.__name__, (type_,), kw) + + else: + constraint_func = constr + elif issubclass(type_, bytes): + attrs = ('max_length', 'min_length', 'regex') + if issubclass(type_, StrictBytes): + + def constraint_func(**kw: Any) -> Type[Any]: + return type(type_.__name__, (type_,), kw) + + else: + constraint_func = conbytes + elif issubclass(type_, numeric_types) and not issubclass( + type_, + ( + ConstrainedInt, + ConstrainedFloat, + ConstrainedDecimal, + ConstrainedList, + ConstrainedSet, + ConstrainedFrozenSet, + bool, + ), + ): + # Is numeric type + attrs = ('gt', 'lt', 'ge', 'le', 'multiple_of') + if issubclass(type_, float): + attrs += ('allow_inf_nan',) + if issubclass(type_, Decimal): + attrs += ('max_digits', 'decimal_places') + numeric_type = next(t for t in numeric_types if issubclass(type_, t)) # pragma: no branch + constraint_func = _map_types_constraint[numeric_type] + + if attrs: + used_constraints.update(set(attrs)) + kwargs = { + attr_name: attr + for attr_name, attr 
in ((attr_name, getattr(field_info, attr_name)) for attr_name in attrs) + if attr is not None + } + if kwargs: + constraint_func = cast(Callable[..., type], constraint_func) + return constraint_func(**kwargs) + return type_ + + return go(annotation), used_constraints + + +def normalize_name(name: str) -> str: + """ + Normalizes the given name. This can be applied to either a model *or* enum. + """ + return re.sub(r'[^a-zA-Z0-9.\-_]', '_', name) + + +class SkipField(Exception): + """ + Utility exception used to exclude fields from schema. + """ + + def __init__(self, message: str) -> None: + self.message = message diff --git a/lib/pydantic/v1/tools.py b/lib/pydantic/v1/tools.py new file mode 100644 index 00000000..45be2770 --- /dev/null +++ b/lib/pydantic/v1/tools.py @@ -0,0 +1,92 @@ +import json +from functools import lru_cache +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeVar, Union + +from .parse import Protocol, load_file, load_str_bytes +from .types import StrBytes +from .typing import display_as_type + +__all__ = ('parse_file_as', 'parse_obj_as', 'parse_raw_as', 'schema_of', 'schema_json_of') + +NameFactory = Union[str, Callable[[Type[Any]], str]] + +if TYPE_CHECKING: + from .typing import DictStrAny + + +def _generate_parsing_type_name(type_: Any) -> str: + return f'ParsingModel[{display_as_type(type_)}]' + + +@lru_cache(maxsize=2048) +def _get_parsing_type(type_: Any, *, type_name: Optional[NameFactory] = None) -> Any: + from .main import create_model + + if type_name is None: + type_name = _generate_parsing_type_name + if not isinstance(type_name, str): + type_name = type_name(type_) + return create_model(type_name, __root__=(type_, ...)) + + +T = TypeVar('T') + + +def parse_obj_as(type_: Type[T], obj: Any, *, type_name: Optional[NameFactory] = None) -> T: + model_type = _get_parsing_type(type_, type_name=type_name) # type: ignore[arg-type] + return model_type(__root__=obj).__root__ + + +def parse_file_as( + type_: Type[T], + path: Union[str, Path], + *, + content_type: str = None, + encoding: str = 'utf8', + proto: Protocol = None, + allow_pickle: bool = False, + json_loads: Callable[[str], Any] = json.loads, + type_name: Optional[NameFactory] = None, +) -> T: + obj = load_file( + path, + proto=proto, + content_type=content_type, + encoding=encoding, + allow_pickle=allow_pickle, + json_loads=json_loads, + ) + return parse_obj_as(type_, obj, type_name=type_name) + + +def parse_raw_as( + type_: Type[T], + b: StrBytes, + *, + content_type: str = None, + encoding: str = 'utf8', + proto: Protocol = None, + allow_pickle: bool = False, + json_loads: Callable[[str], Any] = json.loads, + type_name: Optional[NameFactory] = None, +) -> T: + obj = load_str_bytes( + b, + proto=proto, + content_type=content_type, + encoding=encoding, + allow_pickle=allow_pickle, + json_loads=json_loads, + ) + return parse_obj_as(type_, obj, type_name=type_name) + + +def schema_of(type_: Any, *, title: Optional[NameFactory] = None, **schema_kwargs: Any) -> 'DictStrAny': + """Generate a JSON schema (as dict) for the passed model or dynamically generated one""" + return _get_parsing_type(type_, type_name=title).schema(**schema_kwargs) + + +def schema_json_of(type_: Any, *, title: Optional[NameFactory] = None, **schema_json_kwargs: Any) -> str: + """Generate a JSON schema (as JSON) for the passed model or dynamically generated one""" + return _get_parsing_type(type_, type_name=title).schema_json(**schema_json_kwargs) diff --git a/lib/pydantic/v1/types.py 
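A minimal usage sketch of the tools helpers added above (``parse_obj_as`` and ``schema_of`` both wrap the target type in a throwaway ``__root__`` model via ``_get_parsing_type``); the ``pydantic.v1`` import path is an assumption about how this vendored copy is exposed, not something stated in the patch:

    from typing import List
    from pydantic.v1 import parse_obj_as, schema_of  # import path assumed for this vendored copy

    items = parse_obj_as(List[int], ['1', 2, '3'])  # values coerced to [1, 2, 3]
    schema = schema_of(List[int], title='Items')    # JSON schema dict for the wrapper model, titled 'Items'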
b/lib/pydantic/v1/types.py new file mode 100644 index 00000000..754e58ff --- /dev/null +++ b/lib/pydantic/v1/types.py @@ -0,0 +1,1205 @@ +import abc +import math +import re +import warnings +from datetime import date +from decimal import Decimal, InvalidOperation +from enum import Enum +from pathlib import Path +from types import new_class +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + FrozenSet, + List, + Optional, + Pattern, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) +from uuid import UUID +from weakref import WeakSet + +from . import errors +from .datetime_parse import parse_date +from .utils import import_string, update_not_none +from .validators import ( + bytes_validator, + constr_length_validator, + constr_lower, + constr_strip_whitespace, + constr_upper, + decimal_validator, + float_finite_validator, + float_validator, + frozenset_validator, + int_validator, + list_validator, + number_multiple_validator, + number_size_validator, + path_exists_validator, + path_validator, + set_validator, + str_validator, + strict_bytes_validator, + strict_float_validator, + strict_int_validator, + strict_str_validator, +) + +__all__ = [ + 'NoneStr', + 'NoneBytes', + 'StrBytes', + 'NoneStrBytes', + 'StrictStr', + 'ConstrainedBytes', + 'conbytes', + 'ConstrainedList', + 'conlist', + 'ConstrainedSet', + 'conset', + 'ConstrainedFrozenSet', + 'confrozenset', + 'ConstrainedStr', + 'constr', + 'PyObject', + 'ConstrainedInt', + 'conint', + 'PositiveInt', + 'NegativeInt', + 'NonNegativeInt', + 'NonPositiveInt', + 'ConstrainedFloat', + 'confloat', + 'PositiveFloat', + 'NegativeFloat', + 'NonNegativeFloat', + 'NonPositiveFloat', + 'FiniteFloat', + 'ConstrainedDecimal', + 'condecimal', + 'UUID1', + 'UUID3', + 'UUID4', + 'UUID5', + 'FilePath', + 'DirectoryPath', + 'Json', + 'JsonWrapper', + 'SecretField', + 'SecretStr', + 'SecretBytes', + 'StrictBool', + 'StrictBytes', + 'StrictInt', + 'StrictFloat', + 'PaymentCardNumber', + 'ByteSize', + 'PastDate', + 'FutureDate', + 'ConstrainedDate', + 'condate', +] + +NoneStr = Optional[str] +NoneBytes = Optional[bytes] +StrBytes = Union[str, bytes] +NoneStrBytes = Optional[StrBytes] +OptionalInt = Optional[int] +OptionalIntFloat = Union[OptionalInt, float] +OptionalIntFloatDecimal = Union[OptionalIntFloat, Decimal] +OptionalDate = Optional[date] +StrIntFloat = Union[str, int, float] + +if TYPE_CHECKING: + from typing_extensions import Annotated + + from .dataclasses import Dataclass + from .main import BaseModel + from .typing import CallableGenerator + + ModelOrDc = Type[Union[BaseModel, Dataclass]] + +T = TypeVar('T') +_DEFINED_TYPES: 'WeakSet[type]' = WeakSet() + + +@overload +def _registered(typ: Type[T]) -> Type[T]: + pass + + +@overload +def _registered(typ: 'ConstrainedNumberMeta') -> 'ConstrainedNumberMeta': + pass + + +def _registered(typ: Union[Type[T], 'ConstrainedNumberMeta']) -> Union[Type[T], 'ConstrainedNumberMeta']: + # In order to generate valid examples of constrained types, Hypothesis needs + # to inspect the type object - so we keep a weakref to each contype object + # until it can be registered. When (or if) our Hypothesis plugin is loaded, + # it monkeypatches this function. + # If Hypothesis is never used, the total effect is to keep a weak reference + # which has minimal memory usage and doesn't even affect garbage collection. 
+ _DEFINED_TYPES.add(typ) + return typ + + +class ConstrainedNumberMeta(type): + def __new__(cls, name: str, bases: Any, dct: Dict[str, Any]) -> 'ConstrainedInt': # type: ignore + new_cls = cast('ConstrainedInt', type.__new__(cls, name, bases, dct)) + + if new_cls.gt is not None and new_cls.ge is not None: + raise errors.ConfigError('bounds gt and ge cannot be specified at the same time') + if new_cls.lt is not None and new_cls.le is not None: + raise errors.ConfigError('bounds lt and le cannot be specified at the same time') + + return _registered(new_cls) # type: ignore + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BOOLEAN TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +if TYPE_CHECKING: + StrictBool = bool +else: + + class StrictBool(int): + """ + StrictBool to allow for bools which are not type-coerced. + """ + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(type='boolean') + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.validate + + @classmethod + def validate(cls, value: Any) -> bool: + """ + Ensure that we only allow bools. + """ + if isinstance(value, bool): + return value + + raise errors.StrictBoolError() + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTEGER TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +class ConstrainedInt(int, metaclass=ConstrainedNumberMeta): + strict: bool = False + gt: OptionalInt = None + ge: OptionalInt = None + lt: OptionalInt = None + le: OptionalInt = None + multiple_of: OptionalInt = None + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none( + field_schema, + exclusiveMinimum=cls.gt, + exclusiveMaximum=cls.lt, + minimum=cls.ge, + maximum=cls.le, + multipleOf=cls.multiple_of, + ) + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield strict_int_validator if cls.strict else int_validator + yield number_size_validator + yield number_multiple_validator + + +def conint( + *, + strict: bool = False, + gt: Optional[int] = None, + ge: Optional[int] = None, + lt: Optional[int] = None, + le: Optional[int] = None, + multiple_of: Optional[int] = None, +) -> Type[int]: + # use kwargs then define conf in a dict to aid with IDE type hinting + namespace = dict(strict=strict, gt=gt, ge=ge, lt=lt, le=le, multiple_of=multiple_of) + return type('ConstrainedIntValue', (ConstrainedInt,), namespace) + + +if TYPE_CHECKING: + PositiveInt = int + NegativeInt = int + NonPositiveInt = int + NonNegativeInt = int + StrictInt = int +else: + + class PositiveInt(ConstrainedInt): + gt = 0 + + class NegativeInt(ConstrainedInt): + lt = 0 + + class NonPositiveInt(ConstrainedInt): + le = 0 + + class NonNegativeInt(ConstrainedInt): + ge = 0 + + class StrictInt(ConstrainedInt): + strict = True + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLOAT TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +class ConstrainedFloat(float, metaclass=ConstrainedNumberMeta): + strict: bool = False + gt: OptionalIntFloat = None + ge: OptionalIntFloat = None + lt: OptionalIntFloat = None + le: OptionalIntFloat = None + multiple_of: OptionalIntFloat = None + allow_inf_nan: Optional[bool] = None + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none( + field_schema, + exclusiveMinimum=cls.gt, + exclusiveMaximum=cls.lt, + minimum=cls.ge, + maximum=cls.le, + multipleOf=cls.multiple_of, + ) + # Modify constraints to account for differences between IEEE floats and JSON + if field_schema.get('exclusiveMinimum') == -math.inf: + del 
field_schema['exclusiveMinimum'] + if field_schema.get('minimum') == -math.inf: + del field_schema['minimum'] + if field_schema.get('exclusiveMaximum') == math.inf: + del field_schema['exclusiveMaximum'] + if field_schema.get('maximum') == math.inf: + del field_schema['maximum'] + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield strict_float_validator if cls.strict else float_validator + yield number_size_validator + yield number_multiple_validator + yield float_finite_validator + + +def confloat( + *, + strict: bool = False, + gt: float = None, + ge: float = None, + lt: float = None, + le: float = None, + multiple_of: float = None, + allow_inf_nan: Optional[bool] = None, +) -> Type[float]: + # use kwargs then define conf in a dict to aid with IDE type hinting + namespace = dict(strict=strict, gt=gt, ge=ge, lt=lt, le=le, multiple_of=multiple_of, allow_inf_nan=allow_inf_nan) + return type('ConstrainedFloatValue', (ConstrainedFloat,), namespace) + + +if TYPE_CHECKING: + PositiveFloat = float + NegativeFloat = float + NonPositiveFloat = float + NonNegativeFloat = float + StrictFloat = float + FiniteFloat = float +else: + + class PositiveFloat(ConstrainedFloat): + gt = 0 + + class NegativeFloat(ConstrainedFloat): + lt = 0 + + class NonPositiveFloat(ConstrainedFloat): + le = 0 + + class NonNegativeFloat(ConstrainedFloat): + ge = 0 + + class StrictFloat(ConstrainedFloat): + strict = True + + class FiniteFloat(ConstrainedFloat): + allow_inf_nan = False + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BYTES TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +class ConstrainedBytes(bytes): + strip_whitespace = False + to_upper = False + to_lower = False + min_length: OptionalInt = None + max_length: OptionalInt = None + strict: bool = False + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none(field_schema, minLength=cls.min_length, maxLength=cls.max_length) + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield strict_bytes_validator if cls.strict else bytes_validator + yield constr_strip_whitespace + yield constr_upper + yield constr_lower + yield constr_length_validator + + +def conbytes( + *, + strip_whitespace: bool = False, + to_upper: bool = False, + to_lower: bool = False, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + strict: bool = False, +) -> Type[bytes]: + # use kwargs then define conf in a dict to aid with IDE type hinting + namespace = dict( + strip_whitespace=strip_whitespace, + to_upper=to_upper, + to_lower=to_lower, + min_length=min_length, + max_length=max_length, + strict=strict, + ) + return _registered(type('ConstrainedBytesValue', (ConstrainedBytes,), namespace)) + + +if TYPE_CHECKING: + StrictBytes = bytes +else: + + class StrictBytes(ConstrainedBytes): + strict = True + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ STRING TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +class ConstrainedStr(str): + strip_whitespace = False + to_upper = False + to_lower = False + min_length: OptionalInt = None + max_length: OptionalInt = None + curtail_length: OptionalInt = None + regex: Optional[Union[str, Pattern[str]]] = None + strict = False + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none( + field_schema, + minLength=cls.min_length, + maxLength=cls.max_length, + pattern=cls.regex and cls._get_pattern(cls.regex), + ) + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield strict_str_validator if cls.strict else 
str_validator + yield constr_strip_whitespace + yield constr_upper + yield constr_lower + yield constr_length_validator + yield cls.validate + + @classmethod + def validate(cls, value: Union[str]) -> Union[str]: + if cls.curtail_length and len(value) > cls.curtail_length: + value = value[: cls.curtail_length] + + if cls.regex: + if not re.match(cls.regex, value): + raise errors.StrRegexError(pattern=cls._get_pattern(cls.regex)) + + return value + + @staticmethod + def _get_pattern(regex: Union[str, Pattern[str]]) -> str: + return regex if isinstance(regex, str) else regex.pattern + + +def constr( + *, + strip_whitespace: bool = False, + to_upper: bool = False, + to_lower: bool = False, + strict: bool = False, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + curtail_length: Optional[int] = None, + regex: Optional[str] = None, +) -> Type[str]: + # use kwargs then define conf in a dict to aid with IDE type hinting + namespace = dict( + strip_whitespace=strip_whitespace, + to_upper=to_upper, + to_lower=to_lower, + strict=strict, + min_length=min_length, + max_length=max_length, + curtail_length=curtail_length, + regex=regex and re.compile(regex), + ) + return _registered(type('ConstrainedStrValue', (ConstrainedStr,), namespace)) + + +if TYPE_CHECKING: + StrictStr = str +else: + + class StrictStr(ConstrainedStr): + strict = True + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SET TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +# This types superclass should be Set[T], but cython chokes on that... +class ConstrainedSet(set): # type: ignore + # Needed for pydantic to detect that this is a set + __origin__ = set + __args__: Set[Type[T]] # type: ignore + + min_items: Optional[int] = None + max_items: Optional[int] = None + item_type: Type[T] # type: ignore + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.set_length_validator + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none(field_schema, minItems=cls.min_items, maxItems=cls.max_items) + + @classmethod + def set_length_validator(cls, v: 'Optional[Set[T]]') -> 'Optional[Set[T]]': + if v is None: + return None + + v = set_validator(v) + v_len = len(v) + + if cls.min_items is not None and v_len < cls.min_items: + raise errors.SetMinLengthError(limit_value=cls.min_items) + + if cls.max_items is not None and v_len > cls.max_items: + raise errors.SetMaxLengthError(limit_value=cls.max_items) + + return v + + +def conset(item_type: Type[T], *, min_items: Optional[int] = None, max_items: Optional[int] = None) -> Type[Set[T]]: + # __args__ is needed to conform to typing generics api + namespace = {'min_items': min_items, 'max_items': max_items, 'item_type': item_type, '__args__': [item_type]} + # We use new_class to be able to deal with Generic types + return new_class('ConstrainedSetValue', (ConstrainedSet,), {}, lambda ns: ns.update(namespace)) + + +# This types superclass should be FrozenSet[T], but cython chokes on that... 
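A rough usage sketch of the ``constr``/``conset`` factories defined above: they return one-off subclasses whose class attributes carry the constraints consumed by the validators (the ``pydantic.v1`` import path is assumed, not taken from the patch):

    from pydantic.v1 import BaseModel, conset, constr  # import path assumed

    class Tag(BaseModel):
        name: constr(strip_whitespace=True, min_length=1, max_length=32)
        aliases: conset(str, max_items=5) = set()

    tag = Tag(name='  web  ', aliases={'www'})
    assert tag.name == 'web'  # whitespace is stripped before the length check runs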
+class ConstrainedFrozenSet(frozenset): # type: ignore + # Needed for pydantic to detect that this is a set + __origin__ = frozenset + __args__: FrozenSet[Type[T]] # type: ignore + + min_items: Optional[int] = None + max_items: Optional[int] = None + item_type: Type[T] # type: ignore + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.frozenset_length_validator + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none(field_schema, minItems=cls.min_items, maxItems=cls.max_items) + + @classmethod + def frozenset_length_validator(cls, v: 'Optional[FrozenSet[T]]') -> 'Optional[FrozenSet[T]]': + if v is None: + return None + + v = frozenset_validator(v) + v_len = len(v) + + if cls.min_items is not None and v_len < cls.min_items: + raise errors.FrozenSetMinLengthError(limit_value=cls.min_items) + + if cls.max_items is not None and v_len > cls.max_items: + raise errors.FrozenSetMaxLengthError(limit_value=cls.max_items) + + return v + + +def confrozenset( + item_type: Type[T], *, min_items: Optional[int] = None, max_items: Optional[int] = None +) -> Type[FrozenSet[T]]: + # __args__ is needed to conform to typing generics api + namespace = {'min_items': min_items, 'max_items': max_items, 'item_type': item_type, '__args__': [item_type]} + # We use new_class to be able to deal with Generic types + return new_class('ConstrainedFrozenSetValue', (ConstrainedFrozenSet,), {}, lambda ns: ns.update(namespace)) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LIST TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +# This types superclass should be List[T], but cython chokes on that... +class ConstrainedList(list): # type: ignore + # Needed for pydantic to detect that this is a list + __origin__ = list + __args__: Tuple[Type[T], ...] 
# type: ignore + + min_items: Optional[int] = None + max_items: Optional[int] = None + unique_items: Optional[bool] = None + item_type: Type[T] # type: ignore + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.list_length_validator + if cls.unique_items: + yield cls.unique_items_validator + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none(field_schema, minItems=cls.min_items, maxItems=cls.max_items, uniqueItems=cls.unique_items) + + @classmethod + def list_length_validator(cls, v: 'Optional[List[T]]') -> 'Optional[List[T]]': + if v is None: + return None + + v = list_validator(v) + v_len = len(v) + + if cls.min_items is not None and v_len < cls.min_items: + raise errors.ListMinLengthError(limit_value=cls.min_items) + + if cls.max_items is not None and v_len > cls.max_items: + raise errors.ListMaxLengthError(limit_value=cls.max_items) + + return v + + @classmethod + def unique_items_validator(cls, v: 'Optional[List[T]]') -> 'Optional[List[T]]': + if v is None: + return None + + for i, value in enumerate(v, start=1): + if value in v[i:]: + raise errors.ListUniqueItemsError() + + return v + + +def conlist( + item_type: Type[T], *, min_items: Optional[int] = None, max_items: Optional[int] = None, unique_items: bool = None +) -> Type[List[T]]: + # __args__ is needed to conform to typing generics api + namespace = dict( + min_items=min_items, max_items=max_items, unique_items=unique_items, item_type=item_type, __args__=(item_type,) + ) + # We use new_class to be able to deal with Generic types + return new_class('ConstrainedListValue', (ConstrainedList,), {}, lambda ns: ns.update(namespace)) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PYOBJECT TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +if TYPE_CHECKING: + PyObject = Callable[..., Any] +else: + + class PyObject: + validate_always = True + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.validate + + @classmethod + def validate(cls, value: Any) -> Any: + if isinstance(value, Callable): + return value + + try: + value = str_validator(value) + except errors.StrError: + raise errors.PyObjectError(error_message='value is neither a valid import path not a valid callable') + + try: + return import_string(value) + except ImportError as e: + raise errors.PyObjectError(error_message=str(e)) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DECIMAL TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +class ConstrainedDecimal(Decimal, metaclass=ConstrainedNumberMeta): + gt: OptionalIntFloatDecimal = None + ge: OptionalIntFloatDecimal = None + lt: OptionalIntFloatDecimal = None + le: OptionalIntFloatDecimal = None + max_digits: OptionalInt = None + decimal_places: OptionalInt = None + multiple_of: OptionalIntFloatDecimal = None + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none( + field_schema, + exclusiveMinimum=cls.gt, + exclusiveMaximum=cls.lt, + minimum=cls.ge, + maximum=cls.le, + multipleOf=cls.multiple_of, + ) + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield decimal_validator + yield number_size_validator + yield number_multiple_validator + yield cls.validate + + @classmethod + def validate(cls, value: Decimal) -> Decimal: + try: + normalized_value = value.normalize() + except InvalidOperation: + normalized_value = value + digit_tuple, exponent = normalized_value.as_tuple()[1:] + if exponent in {'F', 'n', 'N'}: + raise errors.DecimalIsNotFiniteError() + + if exponent >= 0: + # A 
positive exponent adds that many trailing zeros. + digits = len(digit_tuple) + exponent + decimals = 0 + else: + # If the absolute value of the negative exponent is larger than the + # number of digits, then it's the same as the number of digits, + # because it'll consume all of the digits in digit_tuple and then + # add abs(exponent) - len(digit_tuple) leading zeros after the + # decimal point. + if abs(exponent) > len(digit_tuple): + digits = decimals = abs(exponent) + else: + digits = len(digit_tuple) + decimals = abs(exponent) + whole_digits = digits - decimals + + if cls.max_digits is not None and digits > cls.max_digits: + raise errors.DecimalMaxDigitsError(max_digits=cls.max_digits) + + if cls.decimal_places is not None and decimals > cls.decimal_places: + raise errors.DecimalMaxPlacesError(decimal_places=cls.decimal_places) + + if cls.max_digits is not None and cls.decimal_places is not None: + expected = cls.max_digits - cls.decimal_places + if whole_digits > expected: + raise errors.DecimalWholeDigitsError(whole_digits=expected) + + return value + + +def condecimal( + *, + gt: Decimal = None, + ge: Decimal = None, + lt: Decimal = None, + le: Decimal = None, + max_digits: Optional[int] = None, + decimal_places: Optional[int] = None, + multiple_of: Decimal = None, +) -> Type[Decimal]: + # use kwargs then define conf in a dict to aid with IDE type hinting + namespace = dict( + gt=gt, ge=ge, lt=lt, le=le, max_digits=max_digits, decimal_places=decimal_places, multiple_of=multiple_of + ) + return type('ConstrainedDecimalValue', (ConstrainedDecimal,), namespace) + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ UUID TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +if TYPE_CHECKING: + UUID1 = UUID + UUID3 = UUID + UUID4 = UUID + UUID5 = UUID +else: + + class UUID1(UUID): + _required_version = 1 + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(type='string', format=f'uuid{cls._required_version}') + + class UUID3(UUID1): + _required_version = 3 + + class UUID4(UUID1): + _required_version = 4 + + class UUID5(UUID1): + _required_version = 5 + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PATH TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +if TYPE_CHECKING: + FilePath = Path + DirectoryPath = Path +else: + + class FilePath(Path): + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(format='file-path') + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield path_validator + yield path_exists_validator + yield cls.validate + + @classmethod + def validate(cls, value: Path) -> Path: + if not value.is_file(): + raise errors.PathNotAFileError(path=value) + + return value + + class DirectoryPath(Path): + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(format='directory-path') + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield path_validator + yield path_exists_validator + yield cls.validate + + @classmethod + def validate(cls, value: Path) -> Path: + if not value.is_dir(): + raise errors.PathNotADirectoryError(path=value) + + return value + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ JSON TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +class JsonWrapper: + pass + + +class JsonMeta(type): + def __getitem__(self, t: Type[Any]) -> Type[JsonWrapper]: + if t is Any: + return Json # allow Json[Any] to replecate plain Json + return _registered(type('JsonWrapperValue', (JsonWrapper,), {'inner_type': t})) + + +if TYPE_CHECKING: + 
Json = Annotated[T, ...] # Json[list[str]] will be recognized by type checkers as list[str] + +else: + + class Json(metaclass=JsonMeta): + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + field_schema.update(type='string', format='json-string') + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SECRET TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +class SecretField(abc.ABC): + """ + Note: this should be implemented as a generic like `SecretField(ABC, Generic[T])`, + the `__init__()` should be part of the abstract class and the + `get_secret_value()` method should use the generic `T` type. + + However Cython doesn't support very well generics at the moment and + the generated code fails to be imported (see + https://github.com/cython/cython/issues/2753). + """ + + def __eq__(self, other: Any) -> bool: + return isinstance(other, self.__class__) and self.get_secret_value() == other.get_secret_value() + + def __str__(self) -> str: + return '**********' if self.get_secret_value() else '' + + def __hash__(self) -> int: + return hash(self.get_secret_value()) + + @abc.abstractmethod + def get_secret_value(self) -> Any: # pragma: no cover + ... + + +class SecretStr(SecretField): + min_length: OptionalInt = None + max_length: OptionalInt = None + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none( + field_schema, + type='string', + writeOnly=True, + format='password', + minLength=cls.min_length, + maxLength=cls.max_length, + ) + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.validate + yield constr_length_validator + + @classmethod + def validate(cls, value: Any) -> 'SecretStr': + if isinstance(value, cls): + return value + value = str_validator(value) + return cls(value) + + def __init__(self, value: str): + self._secret_value = value + + def __repr__(self) -> str: + return f"SecretStr('{self}')" + + def __len__(self) -> int: + return len(self._secret_value) + + def display(self) -> str: + warnings.warn('`secret_str.display()` is deprecated, use `str(secret_str)` instead', DeprecationWarning) + return str(self) + + def get_secret_value(self) -> str: + return self._secret_value + + +class SecretBytes(SecretField): + min_length: OptionalInt = None + max_length: OptionalInt = None + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none( + field_schema, + type='string', + writeOnly=True, + format='password', + minLength=cls.min_length, + maxLength=cls.max_length, + ) + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.validate + yield constr_length_validator + + @classmethod + def validate(cls, value: Any) -> 'SecretBytes': + if isinstance(value, cls): + return value + value = bytes_validator(value) + return cls(value) + + def __init__(self, value: bytes): + self._secret_value = value + + def __repr__(self) -> str: + return f"SecretBytes(b'{self}')" + + def __len__(self) -> int: + return len(self._secret_value) + + def display(self) -> str: + warnings.warn('`secret_bytes.display()` is deprecated, use `str(secret_bytes)` instead', DeprecationWarning) + return str(self) + + def get_secret_value(self) -> bytes: + return self._secret_value + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PAYMENT CARD TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + +class PaymentCardBrand(str, Enum): + # If you add another card type, please also add it to the + # Hypothesis strategy in `pydantic._hypothesis_plugin`. 
+ amex = 'American Express' + mastercard = 'Mastercard' + visa = 'Visa' + other = 'other' + + def __str__(self) -> str: + return self.value + + +class PaymentCardNumber(str): + """ + Based on: https://en.wikipedia.org/wiki/Payment_card_number + """ + + strip_whitespace: ClassVar[bool] = True + min_length: ClassVar[int] = 12 + max_length: ClassVar[int] = 19 + bin: str + last4: str + brand: PaymentCardBrand + + def __init__(self, card_number: str): + self.bin = card_number[:6] + self.last4 = card_number[-4:] + self.brand = self._get_brand(card_number) + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield str_validator + yield constr_strip_whitespace + yield constr_length_validator + yield cls.validate_digits + yield cls.validate_luhn_check_digit + yield cls + yield cls.validate_length_for_brand + + @property + def masked(self) -> str: + num_masked = len(self) - 10 # len(bin) + len(last4) == 10 + return f'{self.bin}{"*" * num_masked}{self.last4}' + + @classmethod + def validate_digits(cls, card_number: str) -> str: + if not card_number.isdigit(): + raise errors.NotDigitError + return card_number + + @classmethod + def validate_luhn_check_digit(cls, card_number: str) -> str: + """ + Based on: https://en.wikipedia.org/wiki/Luhn_algorithm + """ + sum_ = int(card_number[-1]) + length = len(card_number) + parity = length % 2 + for i in range(length - 1): + digit = int(card_number[i]) + if i % 2 == parity: + digit *= 2 + if digit > 9: + digit -= 9 + sum_ += digit + valid = sum_ % 10 == 0 + if not valid: + raise errors.LuhnValidationError + return card_number + + @classmethod + def validate_length_for_brand(cls, card_number: 'PaymentCardNumber') -> 'PaymentCardNumber': + """ + Validate length based on BIN for major brands: + https://en.wikipedia.org/wiki/Payment_card_number#Issuer_identification_number_(IIN) + """ + required_length: Union[None, int, str] = None + if card_number.brand in PaymentCardBrand.mastercard: + required_length = 16 + valid = len(card_number) == required_length + elif card_number.brand == PaymentCardBrand.visa: + required_length = '13, 16 or 19' + valid = len(card_number) in {13, 16, 19} + elif card_number.brand == PaymentCardBrand.amex: + required_length = 15 + valid = len(card_number) == required_length + else: + valid = True + if not valid: + raise errors.InvalidLengthForBrand(brand=card_number.brand, required_length=required_length) + return card_number + + @staticmethod + def _get_brand(card_number: str) -> PaymentCardBrand: + if card_number[0] == '4': + brand = PaymentCardBrand.visa + elif 51 <= int(card_number[:2]) <= 55: + brand = PaymentCardBrand.mastercard + elif card_number[:2] in {'34', '37'}: + brand = PaymentCardBrand.amex + else: + brand = PaymentCardBrand.other + return brand + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BYTE SIZE TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +BYTE_SIZES = { + 'b': 1, + 'kb': 10**3, + 'mb': 10**6, + 'gb': 10**9, + 'tb': 10**12, + 'pb': 10**15, + 'eb': 10**18, + 'kib': 2**10, + 'mib': 2**20, + 'gib': 2**30, + 'tib': 2**40, + 'pib': 2**50, + 'eib': 2**60, +} +BYTE_SIZES.update({k.lower()[0]: v for k, v in BYTE_SIZES.items() if 'i' not in k}) +byte_string_re = re.compile(r'^\s*(\d*\.?\d+)\s*(\w+)?', re.IGNORECASE) + + +class ByteSize(int): + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield cls.validate + + @classmethod + def validate(cls, v: StrIntFloat) -> 'ByteSize': + try: + return cls(int(v)) + except ValueError: + pass + + str_match = byte_string_re.match(str(v)) + if str_match is None: + 
raise errors.InvalidByteSize() + + scalar, unit = str_match.groups() + if unit is None: + unit = 'b' + + try: + unit_mult = BYTE_SIZES[unit.lower()] + except KeyError: + raise errors.InvalidByteSizeUnit(unit=unit) + + return cls(int(float(scalar) * unit_mult)) + + def human_readable(self, decimal: bool = False) -> str: + if decimal: + divisor = 1000 + units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB'] + final_unit = 'EB' + else: + divisor = 1024 + units = ['B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB'] + final_unit = 'EiB' + + num = float(self) + for unit in units: + if abs(num) < divisor: + return f'{num:0.1f}{unit}' + num /= divisor + + return f'{num:0.1f}{final_unit}' + + def to(self, unit: str) -> float: + try: + unit_div = BYTE_SIZES[unit.lower()] + except KeyError: + raise errors.InvalidByteSizeUnit(unit=unit) + + return self / unit_div + + +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DATE TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +if TYPE_CHECKING: + PastDate = date + FutureDate = date +else: + + class PastDate(date): + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield parse_date + yield cls.validate + + @classmethod + def validate(cls, value: date) -> date: + if value >= date.today(): + raise errors.DateNotInThePastError() + + return value + + class FutureDate(date): + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield parse_date + yield cls.validate + + @classmethod + def validate(cls, value: date) -> date: + if value <= date.today(): + raise errors.DateNotInTheFutureError() + + return value + + +class ConstrainedDate(date, metaclass=ConstrainedNumberMeta): + gt: OptionalDate = None + ge: OptionalDate = None + lt: OptionalDate = None + le: OptionalDate = None + + @classmethod + def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None: + update_not_none(field_schema, exclusiveMinimum=cls.gt, exclusiveMaximum=cls.lt, minimum=cls.ge, maximum=cls.le) + + @classmethod + def __get_validators__(cls) -> 'CallableGenerator': + yield parse_date + yield number_size_validator + + +def condate( + *, + gt: date = None, + ge: date = None, + lt: date = None, + le: date = None, +) -> Type[date]: + # use kwargs then define conf in a dict to aid with IDE type hinting + namespace = dict(gt=gt, ge=ge, lt=lt, le=le) + return type('ConstrainedDateValue', (ConstrainedDate,), namespace) diff --git a/lib/pydantic/v1/typing.py b/lib/pydantic/v1/typing.py new file mode 100644 index 00000000..a690a053 --- /dev/null +++ b/lib/pydantic/v1/typing.py @@ -0,0 +1,603 @@ +import sys +import typing +from collections.abc import Callable +from os import PathLike +from typing import ( # type: ignore + TYPE_CHECKING, + AbstractSet, + Any, + Callable as TypingCallable, + ClassVar, + Dict, + ForwardRef, + Generator, + Iterable, + List, + Mapping, + NewType, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + _eval_type, + cast, + get_type_hints, +) + +from typing_extensions import ( + Annotated, + Final, + Literal, + NotRequired as TypedDictNotRequired, + Required as TypedDictRequired, +) + +try: + from typing import _TypingBase as typing_base # type: ignore +except ImportError: + from typing import _Final as typing_base # type: ignore + +try: + from typing import GenericAlias as TypingGenericAlias # type: ignore +except ImportError: + # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] 
and so on) + TypingGenericAlias = () + +try: + from types import UnionType as TypesUnionType # type: ignore +except ImportError: + # python < 3.10 does not have UnionType (str | int, byte | bool and so on) + TypesUnionType = () + + +if sys.version_info < (3, 9): + + def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any: + return type_._evaluate(globalns, localns) + +else: + + def evaluate_forwardref(type_: ForwardRef, globalns: Any, localns: Any) -> Any: + # Even though it is the right signature for python 3.9, mypy complains with + # `error: Too many arguments for "_evaluate" of "ForwardRef"` hence the cast... + return cast(Any, type_)._evaluate(globalns, localns, set()) + + +if sys.version_info < (3, 9): + # Ensure we always get all the whole `Annotated` hint, not just the annotated type. + # For 3.7 to 3.8, `get_type_hints` doesn't recognize `typing_extensions.Annotated`, + # so it already returns the full annotation + get_all_type_hints = get_type_hints + +else: + + def get_all_type_hints(obj: Any, globalns: Any = None, localns: Any = None) -> Any: + return get_type_hints(obj, globalns, localns, include_extras=True) + + +_T = TypeVar('_T') + +AnyCallable = TypingCallable[..., Any] +NoArgAnyCallable = TypingCallable[[], Any] + +# workaround for https://github.com/python/mypy/issues/9496 +AnyArgTCallable = TypingCallable[..., _T] + + +# Annotated[...] is implemented by returning an instance of one of these classes, depending on +# python/typing_extensions version. +AnnotatedTypeNames = {'AnnotatedMeta', '_AnnotatedAlias'} + + +LITERAL_TYPES: Set[Any] = {Literal} +if hasattr(typing, 'Literal'): + LITERAL_TYPES.add(typing.Literal) + + +if sys.version_info < (3, 8): + + def get_origin(t: Type[Any]) -> Optional[Type[Any]]: + if type(t).__name__ in AnnotatedTypeNames: + # weirdly this is a runtime requirement, as well as for mypy + return cast(Type[Any], Annotated) + return getattr(t, '__origin__', None) + +else: + from typing import get_origin as _typing_get_origin + + def get_origin(tp: Type[Any]) -> Optional[Type[Any]]: + """ + We can't directly use `typing.get_origin` since we need a fallback to support + custom generic classes like `ConstrainedList` + It should be useless once https://github.com/cython/cython/issues/3537 is + solved and https://github.com/pydantic/pydantic/pull/1753 is merged. + """ + if type(tp).__name__ in AnnotatedTypeNames: + return cast(Type[Any], Annotated) # mypy complains about _SpecialForm + return _typing_get_origin(tp) or getattr(tp, '__origin__', None) + + +if sys.version_info < (3, 8): + from typing import _GenericAlias + + def get_args(t: Type[Any]) -> Tuple[Any, ...]: + """Compatibility version of get_args for python 3.7. + + Mostly compatible with the python 3.8 `typing` module version + and able to handle almost all use cases. + """ + if type(t).__name__ in AnnotatedTypeNames: + return t.__args__ + t.__metadata__ + if isinstance(t, _GenericAlias): + res = t.__args__ + if t.__origin__ is Callable and res and res[0] is not Ellipsis: + res = (list(res[:-1]), res[-1]) + return res + return getattr(t, '__args__', ()) + +else: + from typing import get_args as _typing_get_args + + def _generic_get_args(tp: Type[Any]) -> Tuple[Any, ...]: + """ + In python 3.9, `typing.Dict`, `typing.List`, ... + do have an empty `__args__` by default (instead of the generic ~T for example). + In order to still support `Dict` for example and consider it as `Dict[Any, Any]`, + we retrieve the `_nparams` value that tells us how many parameters it needs. 
+ """ + if hasattr(tp, '_nparams'): + return (Any,) * tp._nparams + # Special case for `tuple[()]`, which used to return ((),) with `typing.Tuple` + # in python 3.10- but now returns () for `tuple` and `Tuple`. + # This will probably be clarified in pydantic v2 + try: + if tp == Tuple[()] or sys.version_info >= (3, 9) and tp == tuple[()]: # type: ignore[misc] + return ((),) + # there is a TypeError when compiled with cython + except TypeError: # pragma: no cover + pass + return () + + def get_args(tp: Type[Any]) -> Tuple[Any, ...]: + """Get type arguments with all substitutions performed. + + For unions, basic simplifications used by Union constructor are performed. + Examples:: + get_args(Dict[str, int]) == (str, int) + get_args(int) == () + get_args(Union[int, Union[T, int], str][int]) == (int, str) + get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int]) + get_args(Callable[[], T][int]) == ([], int) + """ + if type(tp).__name__ in AnnotatedTypeNames: + return tp.__args__ + tp.__metadata__ + # the fallback is needed for the same reasons as `get_origin` (see above) + return _typing_get_args(tp) or getattr(tp, '__args__', ()) or _generic_get_args(tp) + + +if sys.version_info < (3, 9): + + def convert_generics(tp: Type[Any]) -> Type[Any]: + """Python 3.9 and older only supports generics from `typing` module. + They convert strings to ForwardRef automatically. + + Examples:: + typing.List['Hero'] == typing.List[ForwardRef('Hero')] + """ + return tp + +else: + from typing import _UnionGenericAlias # type: ignore + + from typing_extensions import _AnnotatedAlias + + def convert_generics(tp: Type[Any]) -> Type[Any]: + """ + Recursively searches for `str` type hints and replaces them with ForwardRef. + + Examples:: + convert_generics(list['Hero']) == list[ForwardRef('Hero')] + convert_generics(dict['Hero', 'Team']) == dict[ForwardRef('Hero'), ForwardRef('Team')] + convert_generics(typing.Dict['Hero', 'Team']) == typing.Dict[ForwardRef('Hero'), ForwardRef('Team')] + convert_generics(list[str | 'Hero'] | int) == list[str | ForwardRef('Hero')] | int + """ + origin = get_origin(tp) + if not origin or not hasattr(tp, '__args__'): + return tp + + args = get_args(tp) + + # typing.Annotated needs special treatment + if origin is Annotated: + return _AnnotatedAlias(convert_generics(args[0]), args[1:]) + + # recursively replace `str` instances inside of `GenericAlias` with `ForwardRef(arg)` + converted = tuple( + ForwardRef(arg) if isinstance(arg, str) and isinstance(tp, TypingGenericAlias) else convert_generics(arg) + for arg in args + ) + + if converted == args: + return tp + elif isinstance(tp, TypingGenericAlias): + return TypingGenericAlias(origin, converted) + elif isinstance(tp, TypesUnionType): + # recreate types.UnionType (PEP604, Python >= 3.10) + return _UnionGenericAlias(origin, converted) + else: + try: + setattr(tp, '__args__', converted) + except AttributeError: + pass + return tp + + +if sys.version_info < (3, 10): + + def is_union(tp: Optional[Type[Any]]) -> bool: + return tp is Union + + WithArgsTypes = (TypingGenericAlias,) + +else: + import types + import typing + + def is_union(tp: Optional[Type[Any]]) -> bool: + return tp is Union or tp is types.UnionType # noqa: E721 + + WithArgsTypes = (typing._GenericAlias, types.GenericAlias, types.UnionType) + + +StrPath = Union[str, PathLike] + + +if TYPE_CHECKING: + from .fields import ModelField + + TupleGenerator = Generator[Tuple[str, Any], None, None] + DictStrAny = Dict[str, Any] + DictAny = Dict[Any, Any] + SetStr = Set[str] 
+ ListStr = List[str] + IntStr = Union[int, str] + AbstractSetIntStr = AbstractSet[IntStr] + DictIntStrAny = Dict[IntStr, Any] + MappingIntStrAny = Mapping[IntStr, Any] + CallableGenerator = Generator[AnyCallable, None, None] + ReprArgs = Sequence[Tuple[Optional[str], Any]] + + MYPY = False + if MYPY: + AnyClassMethod = classmethod[Any] + else: + # classmethod[TargetType, CallableParamSpecType, CallableReturnType] + AnyClassMethod = classmethod[Any, Any, Any] + +__all__ = ( + 'AnyCallable', + 'NoArgAnyCallable', + 'NoneType', + 'is_none_type', + 'display_as_type', + 'resolve_annotations', + 'is_callable_type', + 'is_literal_type', + 'all_literal_values', + 'is_namedtuple', + 'is_typeddict', + 'is_typeddict_special', + 'is_new_type', + 'new_type_supertype', + 'is_classvar', + 'is_finalvar', + 'update_field_forward_refs', + 'update_model_forward_refs', + 'TupleGenerator', + 'DictStrAny', + 'DictAny', + 'SetStr', + 'ListStr', + 'IntStr', + 'AbstractSetIntStr', + 'DictIntStrAny', + 'CallableGenerator', + 'ReprArgs', + 'AnyClassMethod', + 'CallableGenerator', + 'WithArgsTypes', + 'get_args', + 'get_origin', + 'get_sub_types', + 'typing_base', + 'get_all_type_hints', + 'is_union', + 'StrPath', + 'MappingIntStrAny', +) + + +NoneType = None.__class__ + + +NONE_TYPES: Tuple[Any, Any, Any] = (None, NoneType, Literal[None]) + + +if sys.version_info < (3, 8): + # Even though this implementation is slower, we need it for python 3.7: + # In python 3.7 "Literal" is not a builtin type and uses a different + # mechanism. + # for this reason `Literal[None] is Literal[None]` evaluates to `False`, + # breaking the faster implementation used for the other python versions. + + def is_none_type(type_: Any) -> bool: + return type_ in NONE_TYPES + +elif sys.version_info[:2] == (3, 8): + + def is_none_type(type_: Any) -> bool: + for none_type in NONE_TYPES: + if type_ is none_type: + return True + # With python 3.8, specifically 3.8.10, Literal "is" check sare very flakey + # can change on very subtle changes like use of types in other modules, + # hopefully this check avoids that issue. + if is_literal_type(type_): # pragma: no cover + return all_literal_values(type_) == (None,) + return False + +else: + + def is_none_type(type_: Any) -> bool: + return type_ in NONE_TYPES + + +def display_as_type(v: Type[Any]) -> str: + if not isinstance(v, typing_base) and not isinstance(v, WithArgsTypes) and not isinstance(v, type): + v = v.__class__ + + if is_union(get_origin(v)): + return f'Union[{", ".join(map(display_as_type, get_args(v)))}]' + + if isinstance(v, WithArgsTypes): + # Generic alias are constructs like `list[int]` + return str(v).replace('typing.', '') + + try: + return v.__name__ + except AttributeError: + # happens with typing objects + return str(v).replace('typing.', '') + + +def resolve_annotations(raw_annotations: Dict[str, Type[Any]], module_name: Optional[str]) -> Dict[str, Type[Any]]: + """ + Partially taken from typing.get_type_hints. + + Resolve string or ForwardRef annotations into type objects if possible. 
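+
+    Example (illustrative; ``my_module`` and ``Hero`` are placeholder names)::
+        # 'Hero' is evaluated against my_module's globals when possible; if the
+        # name is not defined yet, the ForwardRef is kept and can be resolved
+        # later via update_forward_refs().
+        resolve_annotations({'team': 'Hero'}, 'my_module')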
+ """ + base_globals: Optional[Dict[str, Any]] = None + if module_name: + try: + module = sys.modules[module_name] + except KeyError: + # happens occasionally, see https://github.com/pydantic/pydantic/issues/2363 + pass + else: + base_globals = module.__dict__ + + annotations = {} + for name, value in raw_annotations.items(): + if isinstance(value, str): + if (3, 10) > sys.version_info >= (3, 9, 8) or sys.version_info >= (3, 10, 1): + value = ForwardRef(value, is_argument=False, is_class=True) + else: + value = ForwardRef(value, is_argument=False) + try: + value = _eval_type(value, base_globals, None) + except NameError: + # this is ok, it can be fixed with update_forward_refs + pass + annotations[name] = value + return annotations + + +def is_callable_type(type_: Type[Any]) -> bool: + return type_ is Callable or get_origin(type_) is Callable + + +def is_literal_type(type_: Type[Any]) -> bool: + return Literal is not None and get_origin(type_) in LITERAL_TYPES + + +def literal_values(type_: Type[Any]) -> Tuple[Any, ...]: + return get_args(type_) + + +def all_literal_values(type_: Type[Any]) -> Tuple[Any, ...]: + """ + This method is used to retrieve all Literal values as + Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586) + e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]` + """ + if not is_literal_type(type_): + return (type_,) + + values = literal_values(type_) + return tuple(x for value in values for x in all_literal_values(value)) + + +def is_namedtuple(type_: Type[Any]) -> bool: + """ + Check if a given class is a named tuple. + It can be either a `typing.NamedTuple` or `collections.namedtuple` + """ + from .utils import lenient_issubclass + + return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields') + + +def is_typeddict(type_: Type[Any]) -> bool: + """ + Check if a given class is a typed dict (from `typing` or `typing_extensions`) + In 3.10, there will be a public method (https://docs.python.org/3.10/library/typing.html#typing.is_typeddict) + """ + from .utils import lenient_issubclass + + return lenient_issubclass(type_, dict) and hasattr(type_, '__total__') + + +def _check_typeddict_special(type_: Any) -> bool: + return type_ is TypedDictRequired or type_ is TypedDictNotRequired + + +def is_typeddict_special(type_: Any) -> bool: + """ + Check if type is a TypedDict special form (Required or NotRequired). + """ + return _check_typeddict_special(type_) or _check_typeddict_special(get_origin(type_)) + + +test_type = NewType('test_type', str) + + +def is_new_type(type_: Type[Any]) -> bool: + """ + Check whether type_ was created using typing.NewType + """ + return isinstance(type_, test_type.__class__) and hasattr(type_, '__supertype__') # type: ignore + + +def new_type_supertype(type_: Type[Any]) -> Type[Any]: + while hasattr(type_, '__supertype__'): + type_ = type_.__supertype__ + return type_ + + +def _check_classvar(v: Optional[Type[Any]]) -> bool: + if v is None: + return False + + return v.__class__ == ClassVar.__class__ and getattr(v, '_name', None) == 'ClassVar' + + +def _check_finalvar(v: Optional[Type[Any]]) -> bool: + """ + Check if a given type is a `typing.Final` type. 
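+
+    Both the bare ``Final`` form and the origin of ``Final[int]`` satisfy this check,
+    which is why ``is_finalvar`` below tests the annotation and its origin.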
+ """ + if v is None: + return False + + return v.__class__ == Final.__class__ and (sys.version_info < (3, 8) or getattr(v, '_name', None) == 'Final') + + +def is_classvar(ann_type: Type[Any]) -> bool: + if _check_classvar(ann_type) or _check_classvar(get_origin(ann_type)): + return True + + # this is an ugly workaround for class vars that contain forward references and are therefore themselves + # forward references, see #3679 + if ann_type.__class__ == ForwardRef and ann_type.__forward_arg__.startswith('ClassVar['): + return True + + return False + + +def is_finalvar(ann_type: Type[Any]) -> bool: + return _check_finalvar(ann_type) or _check_finalvar(get_origin(ann_type)) + + +def update_field_forward_refs(field: 'ModelField', globalns: Any, localns: Any) -> None: + """ + Try to update ForwardRefs on fields based on this ModelField, globalns and localns. + """ + prepare = False + if field.type_.__class__ == ForwardRef: + prepare = True + field.type_ = evaluate_forwardref(field.type_, globalns, localns or None) + if field.outer_type_.__class__ == ForwardRef: + prepare = True + field.outer_type_ = evaluate_forwardref(field.outer_type_, globalns, localns or None) + if prepare: + field.prepare() + + if field.sub_fields: + for sub_f in field.sub_fields: + update_field_forward_refs(sub_f, globalns=globalns, localns=localns) + + if field.discriminator_key is not None: + field.prepare_discriminated_union_sub_fields() + + +def update_model_forward_refs( + model: Type[Any], + fields: Iterable['ModelField'], + json_encoders: Dict[Union[Type[Any], str, ForwardRef], AnyCallable], + localns: 'DictStrAny', + exc_to_suppress: Tuple[Type[BaseException], ...] = (), +) -> None: + """ + Try to update model fields ForwardRefs based on model and localns. + """ + if model.__module__ in sys.modules: + globalns = sys.modules[model.__module__].__dict__.copy() + else: + globalns = {} + + globalns.setdefault(model.__name__, model) + + for f in fields: + try: + update_field_forward_refs(f, globalns=globalns, localns=localns) + except exc_to_suppress: + pass + + for key in set(json_encoders.keys()): + if isinstance(key, str): + fr: ForwardRef = ForwardRef(key) + elif isinstance(key, ForwardRef): + fr = key + else: + continue + + try: + new_key = evaluate_forwardref(fr, globalns, localns or None) + except exc_to_suppress: # pragma: no cover + continue + + json_encoders[new_key] = json_encoders.pop(key) + + +def get_class(type_: Type[Any]) -> Union[None, bool, Type[Any]]: + """ + Tries to get the class of a Type[T] annotation. Returns True if Type is used + without brackets. Otherwise returns None. 
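+
+    Examples::
+        get_class(Type)       # True  (bare ``Type``)
+        get_class(Type[int])  # int
+        get_class(int)        # None  (not a ``Type[...]`` annotation)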
+ """ + if type_ is type: + return True + + if get_origin(type_) is None: + return None + + args = get_args(type_) + if not args or not isinstance(args[0], type): + return True + else: + return args[0] + + +def get_sub_types(tp: Any) -> List[Any]: + """ + Return all the types that are allowed by type `tp` + `tp` can be a `Union` of allowed types or an `Annotated` type + """ + origin = get_origin(tp) + if origin is Annotated: + return get_sub_types(get_args(tp)[0]) + elif is_union(origin): + return [x for t in get_args(tp) for x in get_sub_types(t)] + else: + return [tp] diff --git a/lib/pydantic/v1/utils.py b/lib/pydantic/v1/utils.py new file mode 100644 index 00000000..4d0f68ed --- /dev/null +++ b/lib/pydantic/v1/utils.py @@ -0,0 +1,803 @@ +import keyword +import warnings +import weakref +from collections import OrderedDict, defaultdict, deque +from copy import deepcopy +from itertools import islice, zip_longest +from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType +from typing import ( + TYPE_CHECKING, + AbstractSet, + Any, + Callable, + Collection, + Dict, + Generator, + Iterable, + Iterator, + List, + Mapping, + NoReturn, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +from typing_extensions import Annotated + +from .errors import ConfigError +from .typing import ( + NoneType, + WithArgsTypes, + all_literal_values, + display_as_type, + get_args, + get_origin, + is_literal_type, + is_union, +) +from .version import version_info + +if TYPE_CHECKING: + from inspect import Signature + from pathlib import Path + + from .config import BaseConfig + from .dataclasses import Dataclass + from .fields import ModelField + from .main import BaseModel + from .typing import AbstractSetIntStr, DictIntStrAny, IntStr, MappingIntStrAny, ReprArgs + + RichReprResult = Iterable[Union[Any, Tuple[Any], Tuple[str, Any], Tuple[str, Any, Any]]] + +__all__ = ( + 'import_string', + 'sequence_like', + 'validate_field_name', + 'lenient_isinstance', + 'lenient_issubclass', + 'in_ipython', + 'is_valid_identifier', + 'deep_update', + 'update_not_none', + 'almost_equal_floats', + 'get_model', + 'to_camel', + 'is_valid_field', + 'smart_deepcopy', + 'PyObjectStr', + 'Representation', + 'GetterDict', + 'ValueItems', + 'version_info', # required here to match behaviour in v1.3 + 'ClassAttribute', + 'path_type', + 'ROOT_KEY', + 'get_unique_discriminator_alias', + 'get_discriminator_alias_and_values', + 'DUNDER_ATTRIBUTES', +) + +ROOT_KEY = '__root__' +# these are types that are returned unchanged by deepcopy +IMMUTABLE_NON_COLLECTIONS_TYPES: Set[Type[Any]] = { + int, + float, + complex, + str, + bool, + bytes, + type, + NoneType, + FunctionType, + BuiltinFunctionType, + LambdaType, + weakref.ref, + CodeType, + # note: including ModuleType will differ from behaviour of deepcopy by not producing error. + # It might be not a good idea in general, but considering that this function used only internally + # against default values of fields, this will allow to actually have a field with module as default value + ModuleType, + NotImplemented.__class__, + Ellipsis.__class__, +} + +# these are types that if empty, might be copied with simple copy() instead of deepcopy() +BUILTIN_COLLECTIONS: Set[Type[Any]] = { + list, + set, + tuple, + frozenset, + dict, + OrderedDict, + defaultdict, + deque, +} + + +def import_string(dotted_path: str) -> Any: + """ + Stolen approximately from django. 
Import a dotted module path and return the attribute/class designated by the + last name in the path. Raise ImportError if the import fails. + """ + from importlib import import_module + + try: + module_path, class_name = dotted_path.strip(' ').rsplit('.', 1) + except ValueError as e: + raise ImportError(f'"{dotted_path}" doesn\'t look like a module path') from e + + module = import_module(module_path) + try: + return getattr(module, class_name) + except AttributeError as e: + raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute') from e + + +def truncate(v: Union[str], *, max_len: int = 80) -> str: + """ + Truncate a value and add a unicode ellipsis (three dots) to the end if it was too long + """ + warnings.warn('`truncate` is no-longer used by pydantic and is deprecated', DeprecationWarning) + if isinstance(v, str) and len(v) > (max_len - 2): + # -3 so quote + string + … + quote has correct length + return (v[: (max_len - 3)] + '…').__repr__() + try: + v = v.__repr__() + except TypeError: + v = v.__class__.__repr__(v) # in case v is a type + if len(v) > max_len: + v = v[: max_len - 1] + '…' + return v + + +def sequence_like(v: Any) -> bool: + return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque)) + + +def validate_field_name(bases: List[Type['BaseModel']], field_name: str) -> None: + """ + Ensure that the field's name does not shadow an existing attribute of the model. + """ + for base in bases: + if getattr(base, field_name, None): + raise NameError( + f'Field name "{field_name}" shadows a BaseModel attribute; ' + f'use a different field name with "alias=\'{field_name}\'".' + ) + + +def lenient_isinstance(o: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool: + try: + return isinstance(o, class_or_tuple) # type: ignore[arg-type] + except TypeError: + return False + + +def lenient_issubclass(cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]) -> bool: + try: + return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type] + except TypeError: + if isinstance(cls, WithArgsTypes): + return False + raise # pragma: no cover + + +def in_ipython() -> bool: + """ + Check whether we're in an ipython environment, including jupyter notebooks. + """ + try: + eval('__IPYTHON__') + except NameError: + return False + else: # pragma: no cover + return True + + +def is_valid_identifier(identifier: str) -> bool: + """ + Checks that a string is a valid identifier and not a Python keyword. + :param identifier: The identifier to test. + :return: True if the identifier is valid. 
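+
+    For example, ``user_name`` is a valid identifier, while ``class`` (a keyword)
+    and ``1st_place`` (starts with a digit) are not.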
+ """ + return identifier.isidentifier() and not keyword.iskeyword(identifier) + + +KeyType = TypeVar('KeyType') + + +def deep_update(mapping: Dict[KeyType, Any], *updating_mappings: Dict[KeyType, Any]) -> Dict[KeyType, Any]: + updated_mapping = mapping.copy() + for updating_mapping in updating_mappings: + for k, v in updating_mapping.items(): + if k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict): + updated_mapping[k] = deep_update(updated_mapping[k], v) + else: + updated_mapping[k] = v + return updated_mapping + + +def update_not_none(mapping: Dict[Any, Any], **update: Any) -> None: + mapping.update({k: v for k, v in update.items() if v is not None}) + + +def almost_equal_floats(value_1: float, value_2: float, *, delta: float = 1e-8) -> bool: + """ + Return True if two floats are almost equal + """ + return abs(value_1 - value_2) <= delta + + +def generate_model_signature( + init: Callable[..., None], fields: Dict[str, 'ModelField'], config: Type['BaseConfig'] +) -> 'Signature': + """ + Generate signature for model based on its fields + """ + from inspect import Parameter, Signature, signature + + from .config import Extra + + present_params = signature(init).parameters.values() + merged_params: Dict[str, Parameter] = {} + var_kw = None + use_var_kw = False + + for param in islice(present_params, 1, None): # skip self arg + if param.kind is param.VAR_KEYWORD: + var_kw = param + continue + merged_params[param.name] = param + + if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through + allow_names = config.allow_population_by_field_name + for field_name, field in fields.items(): + param_name = field.alias + if field_name in merged_params or param_name in merged_params: + continue + elif not is_valid_identifier(param_name): + if allow_names and is_valid_identifier(field_name): + param_name = field_name + else: + use_var_kw = True + continue + + # TODO: replace annotation with actual expected types once #1055 solved + kwargs = {'default': field.default} if not field.required else {} + merged_params[param_name] = Parameter( + param_name, Parameter.KEYWORD_ONLY, annotation=field.annotation, **kwargs + ) + + if config.extra is Extra.allow: + use_var_kw = True + + if var_kw and use_var_kw: + # Make sure the parameter for extra kwargs + # does not have the same name as a field + default_model_signature = [ + ('__pydantic_self__', Parameter.POSITIONAL_OR_KEYWORD), + ('data', Parameter.VAR_KEYWORD), + ] + if [(p.name, p.kind) for p in present_params] == default_model_signature: + # if this is the standard model signature, use extra_data as the extra args name + var_kw_name = 'extra_data' + else: + # else start from var_kw + var_kw_name = var_kw.name + + # generate a name that's definitely unique + while var_kw_name in fields: + var_kw_name += '_' + merged_params[var_kw_name] = var_kw.replace(name=var_kw_name) + + return Signature(parameters=list(merged_params.values()), return_annotation=None) + + +def get_model(obj: Union[Type['BaseModel'], Type['Dataclass']]) -> Type['BaseModel']: + from .main import BaseModel + + try: + model_cls = obj.__pydantic_model__ # type: ignore + except AttributeError: + model_cls = obj + + if not issubclass(model_cls, BaseModel): + raise TypeError('Unsupported type, must be either BaseModel or dataclass') + return model_cls + + +def to_camel(string: str) -> str: + return ''.join(word.capitalize() for word in string.split('_')) + + +def to_lower_camel(string: str) -> str: + if len(string) 
>= 1: + pascal_string = to_camel(string) + return pascal_string[0].lower() + pascal_string[1:] + return string.lower() + + +T = TypeVar('T') + + +def unique_list( + input_list: Union[List[T], Tuple[T, ...]], + *, + name_factory: Callable[[T], str] = str, +) -> List[T]: + """ + Make a list unique while maintaining order. + We update the list if another one with the same name is set + (e.g. root validator overridden in subclass) + """ + result: List[T] = [] + result_names: List[str] = [] + for v in input_list: + v_name = name_factory(v) + if v_name not in result_names: + result_names.append(v_name) + result.append(v) + else: + result[result_names.index(v_name)] = v + + return result + + +class PyObjectStr(str): + """ + String class where repr doesn't include quotes. Useful with Representation when you want to return a string + representation of something that valid (or pseudo-valid) python. + """ + + def __repr__(self) -> str: + return str(self) + + +class Representation: + """ + Mixin to provide __str__, __repr__, and __pretty__ methods. See #884 for more details. + + __pretty__ is used by [devtools](https://python-devtools.helpmanual.io/) to provide human readable representations + of objects. + """ + + __slots__: Tuple[str, ...] = tuple() + + def __repr_args__(self) -> 'ReprArgs': + """ + Returns the attributes to show in __str__, __repr__, and __pretty__ this is generally overridden. + + Can either return: + * name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]` + * or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]` + """ + attrs = ((s, getattr(self, s)) for s in self.__slots__) + return [(a, v) for a, v in attrs if v is not None] + + def __repr_name__(self) -> str: + """ + Name of the instance's class, used in __repr__. + """ + return self.__class__.__name__ + + def __repr_str__(self, join_str: str) -> str: + return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__()) + + def __pretty__(self, fmt: Callable[[Any], Any], **kwargs: Any) -> Generator[Any, None, None]: + """ + Used by devtools (https://python-devtools.helpmanual.io/) to provide a human readable representations of objects + """ + yield self.__repr_name__() + '(' + yield 1 + for name, value in self.__repr_args__(): + if name is not None: + yield name + '=' + yield fmt(value) + yield ',' + yield 0 + yield -1 + yield ')' + + def __str__(self) -> str: + return self.__repr_str__(' ') + + def __repr__(self) -> str: + return f'{self.__repr_name__()}({self.__repr_str__(", ")})' + + def __rich_repr__(self) -> 'RichReprResult': + """Get fields for Rich library""" + for name, field_repr in self.__repr_args__(): + if name is None: + yield field_repr + else: + yield name, field_repr + + +class GetterDict(Representation): + """ + Hack to make object's smell just enough like dicts for validate_model. + + We can't inherit from Mapping[str, Any] because it upsets cython so we have to implement all methods ourselves. 
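+
+    For example, ``GetterDict(user_orm_obj)['name']`` returns ``user_orm_obj.name`` and
+    raises ``KeyError`` if the attribute does not exist, so arbitrary objects can be
+    read like mappings during validation.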
+ """ + + __slots__ = ('_obj',) + + def __init__(self, obj: Any): + self._obj = obj + + def __getitem__(self, key: str) -> Any: + try: + return getattr(self._obj, key) + except AttributeError as e: + raise KeyError(key) from e + + def get(self, key: Any, default: Any = None) -> Any: + return getattr(self._obj, key, default) + + def extra_keys(self) -> Set[Any]: + """ + We don't want to get any other attributes of obj if the model didn't explicitly ask for them + """ + return set() + + def keys(self) -> List[Any]: + """ + Keys of the pseudo dictionary, uses a list not set so order information can be maintained like python + dictionaries. + """ + return list(self) + + def values(self) -> List[Any]: + return [self[k] for k in self] + + def items(self) -> Iterator[Tuple[str, Any]]: + for k in self: + yield k, self.get(k) + + def __iter__(self) -> Iterator[str]: + for name in dir(self._obj): + if not name.startswith('_'): + yield name + + def __len__(self) -> int: + return sum(1 for _ in self) + + def __contains__(self, item: Any) -> bool: + return item in self.keys() + + def __eq__(self, other: Any) -> bool: + return dict(self) == dict(other.items()) + + def __repr_args__(self) -> 'ReprArgs': + return [(None, dict(self))] + + def __repr_name__(self) -> str: + return f'GetterDict[{display_as_type(self._obj)}]' + + +class ValueItems(Representation): + """ + Class for more convenient calculation of excluded or included fields on values. + """ + + __slots__ = ('_items', '_type') + + def __init__(self, value: Any, items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> None: + items = self._coerce_items(items) + + if isinstance(value, (list, tuple)): + items = self._normalize_indexes(items, len(value)) + + self._items: 'MappingIntStrAny' = items + + def is_excluded(self, item: Any) -> bool: + """ + Check if item is fully excluded. 
+ + :param item: key or index of a value + """ + return self.is_true(self._items.get(item)) + + def is_included(self, item: Any) -> bool: + """ + Check if value is contained in self._items + + :param item: key or index of value + """ + return item in self._items + + def for_element(self, e: 'IntStr') -> Optional[Union['AbstractSetIntStr', 'MappingIntStrAny']]: + """ + :param e: key or index of element on value + :return: raw values for element if self._items is dict and contain needed element + """ + + item = self._items.get(e) + return item if not self.is_true(item) else None + + def _normalize_indexes(self, items: 'MappingIntStrAny', v_length: int) -> 'DictIntStrAny': + """ + :param items: dict or set of indexes which will be normalized + :param v_length: length of sequence indexes of which will be + + >>> self._normalize_indexes({0: True, -2: True, -1: True}, 4) + {0: True, 2: True, 3: True} + >>> self._normalize_indexes({'__all__': True}, 4) + {0: True, 1: True, 2: True, 3: True} + """ + + normalized_items: 'DictIntStrAny' = {} + all_items = None + for i, v in items.items(): + if not (isinstance(v, Mapping) or isinstance(v, AbstractSet) or self.is_true(v)): + raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}') + if i == '__all__': + all_items = self._coerce_value(v) + continue + if not isinstance(i, int): + raise TypeError( + 'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: ' + 'expected integer keys or keyword "__all__"' + ) + normalized_i = v_length + i if i < 0 else i + normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i)) + + if not all_items: + return normalized_items + if self.is_true(all_items): + for i in range(v_length): + normalized_items.setdefault(i, ...) + return normalized_items + for i in range(v_length): + normalized_item = normalized_items.setdefault(i, {}) + if not self.is_true(normalized_item): + normalized_items[i] = self.merge(all_items, normalized_item) + return normalized_items + + @classmethod + def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any: + """ + Merge a ``base`` item with an ``override`` item. + + Both ``base`` and ``override`` are converted to dictionaries if possible. + Sets are converted to dictionaries with the sets entries as keys and + Ellipsis as values. + + Each key-value pair existing in ``base`` is merged with ``override``, + while the rest of the key-value pairs are updated recursively with this function. + + Merging takes place based on the "union" of keys if ``intersect`` is + set to ``False`` (default) and on the intersection of keys if + ``intersect`` is set to ``True``. 
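+
+        Examples (illustrative)::
+            ValueItems.merge({'a': ...}, {'b': ...})                  # {'a': ..., 'b': ...}
+            ValueItems.merge({'a': ...}, {'b': ...}, intersect=True)  # {}
+            ValueItems.merge({'a', 'b'}, {'a': {'x'}})                # {'a': {'x': ...}, 'b': ...}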
+ """ + override = cls._coerce_value(override) + base = cls._coerce_value(base) + if override is None: + return base + if cls.is_true(base) or base is None: + return override + if cls.is_true(override): + return base if intersect else override + + # intersection or union of keys while preserving ordering: + if intersect: + merge_keys = [k for k in base if k in override] + [k for k in override if k in base] + else: + merge_keys = list(base) + [k for k in override if k not in base] + + merged: 'DictIntStrAny' = {} + for k in merge_keys: + merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect) + if merged_item is not None: + merged[k] = merged_item + + return merged + + @staticmethod + def _coerce_items(items: Union['AbstractSetIntStr', 'MappingIntStrAny']) -> 'MappingIntStrAny': + if isinstance(items, Mapping): + pass + elif isinstance(items, AbstractSet): + items = dict.fromkeys(items, ...) + else: + class_name = getattr(items, '__class__', '???') + assert_never( + items, + f'Unexpected type of exclude value {class_name}', + ) + return items + + @classmethod + def _coerce_value(cls, value: Any) -> Any: + if value is None or cls.is_true(value): + return value + return cls._coerce_items(value) + + @staticmethod + def is_true(v: Any) -> bool: + return v is True or v is ... + + def __repr_args__(self) -> 'ReprArgs': + return [(None, self._items)] + + +class ClassAttribute: + """ + Hide class attribute from its instances + """ + + __slots__ = ( + 'name', + 'value', + ) + + def __init__(self, name: str, value: Any) -> None: + self.name = name + self.value = value + + def __get__(self, instance: Any, owner: Type[Any]) -> None: + if instance is None: + return self.value + raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only') + + +path_types = { + 'is_dir': 'directory', + 'is_file': 'file', + 'is_mount': 'mount point', + 'is_symlink': 'symlink', + 'is_block_device': 'block device', + 'is_char_device': 'char device', + 'is_fifo': 'FIFO', + 'is_socket': 'socket', +} + + +def path_type(p: 'Path') -> str: + """ + Find out what sort of thing a path is. + """ + assert p.exists(), 'path does not exist' + for method, name in path_types.items(): + if getattr(p, method)(): + return name + + return 'unknown' + + +Obj = TypeVar('Obj') + + +def smart_deepcopy(obj: Obj) -> Obj: + """ + Return type as is for immutable built-in types + Use obj.copy() for built-in empty collections + Use copy.deepcopy() for non-empty collections and unknown objects + """ + + obj_type = obj.__class__ + if obj_type in IMMUTABLE_NON_COLLECTIONS_TYPES: + return obj # fastest case: obj is immutable and not collection therefore will not be copied anyway + try: + if not obj and obj_type in BUILTIN_COLLECTIONS: + # faster way for empty collections, no need to copy its members + return obj if obj_type is tuple else obj.copy() # type: ignore # tuple doesn't have copy method + except (TypeError, ValueError, RuntimeError): + # do we really dare to catch ALL errors? 
Seems a bit risky + pass + + return deepcopy(obj) # slowest way when we actually might need a deepcopy + + +def is_valid_field(name: str) -> bool: + if not name.startswith('_'): + return True + return ROOT_KEY == name + + +DUNDER_ATTRIBUTES = { + '__annotations__', + '__classcell__', + '__doc__', + '__module__', + '__orig_bases__', + '__orig_class__', + '__qualname__', +} + + +def is_valid_private_name(name: str) -> bool: + return not is_valid_field(name) and name not in DUNDER_ATTRIBUTES + + +_EMPTY = object() + + +def all_identical(left: Iterable[Any], right: Iterable[Any]) -> bool: + """ + Check that the items of `left` are the same objects as those in `right`. + + >>> a, b = object(), object() + >>> all_identical([a, b, a], [a, b, a]) + True + >>> all_identical([a, b, [a]], [a, b, [a]]) # new list object, while "equal" is not "identical" + False + """ + for left_item, right_item in zip_longest(left, right, fillvalue=_EMPTY): + if left_item is not right_item: + return False + return True + + +def assert_never(obj: NoReturn, msg: str) -> NoReturn: + """ + Helper to make sure that we have covered all possible types. + + This is mostly useful for ``mypy``, docs: + https://mypy.readthedocs.io/en/latest/literal_types.html#exhaustive-checks + """ + raise TypeError(msg) + + +def get_unique_discriminator_alias(all_aliases: Collection[str], discriminator_key: str) -> str: + """Validate that all aliases are the same and if that's the case return the alias""" + unique_aliases = set(all_aliases) + if len(unique_aliases) > 1: + raise ConfigError( + f'Aliases for discriminator {discriminator_key!r} must be the same (got {", ".join(sorted(all_aliases))})' + ) + return unique_aliases.pop() + + +def get_discriminator_alias_and_values(tp: Any, discriminator_key: str) -> Tuple[str, Tuple[str, ...]]: + """ + Get alias and all valid values in the `Literal` type of the discriminator field + `tp` can be a `BaseModel` class or directly an `Annotated` `Union` of many. 
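+
+    For example (illustrative), given ``class Cat(BaseModel): pet_type: Literal['cat']``
+    and ``discriminator_key='pet_type'``, this returns ``('pet_type', ('cat',))``.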
+ """ + is_root_model = getattr(tp, '__custom_root_type__', False) + + if get_origin(tp) is Annotated: + tp = get_args(tp)[0] + + if hasattr(tp, '__pydantic_model__'): + tp = tp.__pydantic_model__ + + if is_union(get_origin(tp)): + alias, all_values = _get_union_alias_and_all_values(tp, discriminator_key) + return alias, tuple(v for values in all_values for v in values) + elif is_root_model: + union_type = tp.__fields__[ROOT_KEY].type_ + alias, all_values = _get_union_alias_and_all_values(union_type, discriminator_key) + + if len(set(all_values)) > 1: + raise ConfigError( + f'Field {discriminator_key!r} is not the same for all submodels of {display_as_type(tp)!r}' + ) + + return alias, all_values[0] + + else: + try: + t_discriminator_type = tp.__fields__[discriminator_key].type_ + except AttributeError as e: + raise TypeError(f'Type {tp.__name__!r} is not a valid `BaseModel` or `dataclass`') from e + except KeyError as e: + raise ConfigError(f'Model {tp.__name__!r} needs a discriminator field for key {discriminator_key!r}') from e + + if not is_literal_type(t_discriminator_type): + raise ConfigError(f'Field {discriminator_key!r} of model {tp.__name__!r} needs to be a `Literal`') + + return tp.__fields__[discriminator_key].alias, all_literal_values(t_discriminator_type) + + +def _get_union_alias_and_all_values( + union_type: Type[Any], discriminator_key: str +) -> Tuple[str, Tuple[Tuple[str, ...], ...]]: + zipped_aliases_values = [get_discriminator_alias_and_values(t, discriminator_key) for t in get_args(union_type)] + # unzip: [('alias_a',('v1', 'v2)), ('alias_b', ('v3',))] => [('alias_a', 'alias_b'), (('v1', 'v2'), ('v3',))] + all_aliases, all_values = zip(*zipped_aliases_values) + return get_unique_discriminator_alias(all_aliases, discriminator_key), all_values diff --git a/lib/pydantic/v1/validators.py b/lib/pydantic/v1/validators.py new file mode 100644 index 00000000..549a235e --- /dev/null +++ b/lib/pydantic/v1/validators.py @@ -0,0 +1,765 @@ +import math +import re +from collections import OrderedDict, deque +from collections.abc import Hashable as CollectionsHashable +from datetime import date, datetime, time, timedelta +from decimal import Decimal, DecimalException +from enum import Enum, IntEnum +from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Deque, + Dict, + ForwardRef, + FrozenSet, + Generator, + Hashable, + List, + NamedTuple, + Pattern, + Set, + Tuple, + Type, + TypeVar, + Union, +) +from uuid import UUID + +from . 
import errors +from .datetime_parse import parse_date, parse_datetime, parse_duration, parse_time +from .typing import ( + AnyCallable, + all_literal_values, + display_as_type, + get_class, + is_callable_type, + is_literal_type, + is_namedtuple, + is_none_type, + is_typeddict, +) +from .utils import almost_equal_floats, lenient_issubclass, sequence_like + +if TYPE_CHECKING: + from typing_extensions import Literal, TypedDict + + from .config import BaseConfig + from .fields import ModelField + from .types import ConstrainedDecimal, ConstrainedFloat, ConstrainedInt + + ConstrainedNumber = Union[ConstrainedDecimal, ConstrainedFloat, ConstrainedInt] + AnyOrderedDict = OrderedDict[Any, Any] + Number = Union[int, float, Decimal] + StrBytes = Union[str, bytes] + + +def str_validator(v: Any) -> Union[str]: + if isinstance(v, str): + if isinstance(v, Enum): + return v.value + else: + return v + elif isinstance(v, (float, int, Decimal)): + # is there anything else we want to add here? If you think so, create an issue. + return str(v) + elif isinstance(v, (bytes, bytearray)): + return v.decode() + else: + raise errors.StrError() + + +def strict_str_validator(v: Any) -> Union[str]: + if isinstance(v, str) and not isinstance(v, Enum): + return v + raise errors.StrError() + + +def bytes_validator(v: Any) -> Union[bytes]: + if isinstance(v, bytes): + return v + elif isinstance(v, bytearray): + return bytes(v) + elif isinstance(v, str): + return v.encode() + elif isinstance(v, (float, int, Decimal)): + return str(v).encode() + else: + raise errors.BytesError() + + +def strict_bytes_validator(v: Any) -> Union[bytes]: + if isinstance(v, bytes): + return v + elif isinstance(v, bytearray): + return bytes(v) + else: + raise errors.BytesError() + + +BOOL_FALSE = {0, '0', 'off', 'f', 'false', 'n', 'no'} +BOOL_TRUE = {1, '1', 'on', 't', 'true', 'y', 'yes'} + + +def bool_validator(v: Any) -> bool: + if v is True or v is False: + return v + if isinstance(v, bytes): + v = v.decode() + if isinstance(v, str): + v = v.lower() + try: + if v in BOOL_TRUE: + return True + if v in BOOL_FALSE: + return False + except TypeError: + raise errors.BoolError() + raise errors.BoolError() + + +# matches the default limit cpython, see https://github.com/python/cpython/pull/96500 +max_str_int = 4_300 + + +def int_validator(v: Any) -> int: + if isinstance(v, int) and not (v is True or v is False): + return v + + # see https://github.com/pydantic/pydantic/issues/1477 and in turn, https://github.com/python/cpython/issues/95778 + # this check should be unnecessary once patch releases are out for 3.7, 3.8, 3.9 and 3.10 + # but better to check here until then. + # NOTICE: this does not fully protect user from the DOS risk since the standard library JSON implementation + # (and other std lib modules like xml) use `int()` and are likely called before this, the best workaround is to + # 1. update to the latest patch release of python once released, 2. 
use a different JSON library like ujson + if isinstance(v, (str, bytes, bytearray)) and len(v) > max_str_int: + raise errors.IntegerError() + + try: + return int(v) + except (TypeError, ValueError, OverflowError): + raise errors.IntegerError() + + +def strict_int_validator(v: Any) -> int: + if isinstance(v, int) and not (v is True or v is False): + return v + raise errors.IntegerError() + + +def float_validator(v: Any) -> float: + if isinstance(v, float): + return v + + try: + return float(v) + except (TypeError, ValueError): + raise errors.FloatError() + + +def strict_float_validator(v: Any) -> float: + if isinstance(v, float): + return v + raise errors.FloatError() + + +def float_finite_validator(v: 'Number', field: 'ModelField', config: 'BaseConfig') -> 'Number': + allow_inf_nan = getattr(field.type_, 'allow_inf_nan', None) + if allow_inf_nan is None: + allow_inf_nan = config.allow_inf_nan + + if allow_inf_nan is False and (math.isnan(v) or math.isinf(v)): + raise errors.NumberNotFiniteError() + return v + + +def number_multiple_validator(v: 'Number', field: 'ModelField') -> 'Number': + field_type: ConstrainedNumber = field.type_ + if field_type.multiple_of is not None: + mod = float(v) / float(field_type.multiple_of) % 1 + if not almost_equal_floats(mod, 0.0) and not almost_equal_floats(mod, 1.0): + raise errors.NumberNotMultipleError(multiple_of=field_type.multiple_of) + return v + + +def number_size_validator(v: 'Number', field: 'ModelField') -> 'Number': + field_type: ConstrainedNumber = field.type_ + if field_type.gt is not None and not v > field_type.gt: + raise errors.NumberNotGtError(limit_value=field_type.gt) + elif field_type.ge is not None and not v >= field_type.ge: + raise errors.NumberNotGeError(limit_value=field_type.ge) + + if field_type.lt is not None and not v < field_type.lt: + raise errors.NumberNotLtError(limit_value=field_type.lt) + if field_type.le is not None and not v <= field_type.le: + raise errors.NumberNotLeError(limit_value=field_type.le) + + return v + + +def constant_validator(v: 'Any', field: 'ModelField') -> 'Any': + """Validate ``const`` fields. + + The value provided for a ``const`` field must be equal to the default value + of the field. This is to support the keyword of the same name in JSON + Schema. 
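+
+    For example, with ``x: int = Field(3, const=True)`` a value of ``3`` is accepted,
+    while any other value raises ``WrongConstantError``.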
+ """ + if v != field.default: + raise errors.WrongConstantError(given=v, permitted=[field.default]) + + return v + + +def anystr_length_validator(v: 'StrBytes', config: 'BaseConfig') -> 'StrBytes': + v_len = len(v) + + min_length = config.min_anystr_length + if v_len < min_length: + raise errors.AnyStrMinLengthError(limit_value=min_length) + + max_length = config.max_anystr_length + if max_length is not None and v_len > max_length: + raise errors.AnyStrMaxLengthError(limit_value=max_length) + + return v + + +def anystr_strip_whitespace(v: 'StrBytes') -> 'StrBytes': + return v.strip() + + +def anystr_upper(v: 'StrBytes') -> 'StrBytes': + return v.upper() + + +def anystr_lower(v: 'StrBytes') -> 'StrBytes': + return v.lower() + + +def ordered_dict_validator(v: Any) -> 'AnyOrderedDict': + if isinstance(v, OrderedDict): + return v + + try: + return OrderedDict(v) + except (TypeError, ValueError): + raise errors.DictError() + + +def dict_validator(v: Any) -> Dict[Any, Any]: + if isinstance(v, dict): + return v + + try: + return dict(v) + except (TypeError, ValueError): + raise errors.DictError() + + +def list_validator(v: Any) -> List[Any]: + if isinstance(v, list): + return v + elif sequence_like(v): + return list(v) + else: + raise errors.ListError() + + +def tuple_validator(v: Any) -> Tuple[Any, ...]: + if isinstance(v, tuple): + return v + elif sequence_like(v): + return tuple(v) + else: + raise errors.TupleError() + + +def set_validator(v: Any) -> Set[Any]: + if isinstance(v, set): + return v + elif sequence_like(v): + return set(v) + else: + raise errors.SetError() + + +def frozenset_validator(v: Any) -> FrozenSet[Any]: + if isinstance(v, frozenset): + return v + elif sequence_like(v): + return frozenset(v) + else: + raise errors.FrozenSetError() + + +def deque_validator(v: Any) -> Deque[Any]: + if isinstance(v, deque): + return v + elif sequence_like(v): + return deque(v) + else: + raise errors.DequeError() + + +def enum_member_validator(v: Any, field: 'ModelField', config: 'BaseConfig') -> Enum: + try: + enum_v = field.type_(v) + except ValueError: + # field.type_ should be an enum, so will be iterable + raise errors.EnumMemberError(enum_values=list(field.type_)) + return enum_v.value if config.use_enum_values else enum_v + + +def uuid_validator(v: Any, field: 'ModelField') -> UUID: + try: + if isinstance(v, str): + v = UUID(v) + elif isinstance(v, (bytes, bytearray)): + try: + v = UUID(v.decode()) + except ValueError: + # 16 bytes in big-endian order as the bytes argument fail + # the above check + v = UUID(bytes=v) + except ValueError: + raise errors.UUIDError() + + if not isinstance(v, UUID): + raise errors.UUIDError() + + required_version = getattr(field.type_, '_required_version', None) + if required_version and v.version != required_version: + raise errors.UUIDVersionError(required_version=required_version) + + return v + + +def decimal_validator(v: Any) -> Decimal: + if isinstance(v, Decimal): + return v + elif isinstance(v, (bytes, bytearray)): + v = v.decode() + + v = str(v).strip() + + try: + v = Decimal(v) + except DecimalException: + raise errors.DecimalError() + + if not v.is_finite(): + raise errors.DecimalIsNotFiniteError() + + return v + + +def hashable_validator(v: Any) -> Hashable: + if isinstance(v, Hashable): + return v + + raise errors.HashableError() + + +def ip_v4_address_validator(v: Any) -> IPv4Address: + if isinstance(v, IPv4Address): + return v + + try: + return IPv4Address(v) + except ValueError: + raise errors.IPv4AddressError() + + +def 
ip_v6_address_validator(v: Any) -> IPv6Address: + if isinstance(v, IPv6Address): + return v + + try: + return IPv6Address(v) + except ValueError: + raise errors.IPv6AddressError() + + +def ip_v4_network_validator(v: Any) -> IPv4Network: + """ + Assume IPv4Network initialised with a default ``strict`` argument + + See more: + https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network + """ + if isinstance(v, IPv4Network): + return v + + try: + return IPv4Network(v) + except ValueError: + raise errors.IPv4NetworkError() + + +def ip_v6_network_validator(v: Any) -> IPv6Network: + """ + Assume IPv6Network initialised with a default ``strict`` argument + + See more: + https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network + """ + if isinstance(v, IPv6Network): + return v + + try: + return IPv6Network(v) + except ValueError: + raise errors.IPv6NetworkError() + + +def ip_v4_interface_validator(v: Any) -> IPv4Interface: + if isinstance(v, IPv4Interface): + return v + + try: + return IPv4Interface(v) + except ValueError: + raise errors.IPv4InterfaceError() + + +def ip_v6_interface_validator(v: Any) -> IPv6Interface: + if isinstance(v, IPv6Interface): + return v + + try: + return IPv6Interface(v) + except ValueError: + raise errors.IPv6InterfaceError() + + +def path_validator(v: Any) -> Path: + if isinstance(v, Path): + return v + + try: + return Path(v) + except TypeError: + raise errors.PathError() + + +def path_exists_validator(v: Any) -> Path: + if not v.exists(): + raise errors.PathNotExistsError(path=v) + + return v + + +def callable_validator(v: Any) -> AnyCallable: + """ + Perform a simple check if the value is callable. + + Note: complete matching of argument type hints and return types is not performed + """ + if callable(v): + return v + + raise errors.CallableError(value=v) + + +def enum_validator(v: Any) -> Enum: + if isinstance(v, Enum): + return v + + raise errors.EnumError(value=v) + + +def int_enum_validator(v: Any) -> IntEnum: + if isinstance(v, IntEnum): + return v + + raise errors.IntEnumError(value=v) + + +def make_literal_validator(type_: Any) -> Callable[[Any], Any]: + permitted_choices = all_literal_values(type_) + + # To have a O(1) complexity and still return one of the values set inside the `Literal`, + # we create a dict with the set values (a set causes some problems with the way intersection works). 
+ # In some cases the set value and checked value can indeed be different (see `test_literal_validator_str_enum`) + allowed_choices = {v: v for v in permitted_choices} + + def literal_validator(v: Any) -> Any: + try: + return allowed_choices[v] + except (KeyError, TypeError): + raise errors.WrongConstantError(given=v, permitted=permitted_choices) + + return literal_validator + + +def constr_length_validator(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes': + v_len = len(v) + + min_length = field.type_.min_length if field.type_.min_length is not None else config.min_anystr_length + if v_len < min_length: + raise errors.AnyStrMinLengthError(limit_value=min_length) + + max_length = field.type_.max_length if field.type_.max_length is not None else config.max_anystr_length + if max_length is not None and v_len > max_length: + raise errors.AnyStrMaxLengthError(limit_value=max_length) + + return v + + +def constr_strip_whitespace(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes': + strip_whitespace = field.type_.strip_whitespace or config.anystr_strip_whitespace + if strip_whitespace: + v = v.strip() + + return v + + +def constr_upper(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes': + upper = field.type_.to_upper or config.anystr_upper + if upper: + v = v.upper() + + return v + + +def constr_lower(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes': + lower = field.type_.to_lower or config.anystr_lower + if lower: + v = v.lower() + return v + + +def validate_json(v: Any, config: 'BaseConfig') -> Any: + if v is None: + # pass None through to other validators + return v + try: + return config.json_loads(v) # type: ignore + except ValueError: + raise errors.JsonError() + except TypeError: + raise errors.JsonTypeError() + + +T = TypeVar('T') + + +def make_arbitrary_type_validator(type_: Type[T]) -> Callable[[T], T]: + def arbitrary_type_validator(v: Any) -> T: + if isinstance(v, type_): + return v + raise errors.ArbitraryTypeError(expected_arbitrary_type=type_) + + return arbitrary_type_validator + + +def make_class_validator(type_: Type[T]) -> Callable[[Any], Type[T]]: + def class_validator(v: Any) -> Type[T]: + if lenient_issubclass(v, type_): + return v + raise errors.SubclassError(expected_class=type_) + + return class_validator + + +def any_class_validator(v: Any) -> Type[T]: + if isinstance(v, type): + return v + raise errors.ClassError() + + +def none_validator(v: Any) -> 'Literal[None]': + if v is None: + return v + raise errors.NotNoneError() + + +def pattern_validator(v: Any) -> Pattern[str]: + if isinstance(v, Pattern): + return v + + str_value = str_validator(v) + + try: + return re.compile(str_value) + except re.error: + raise errors.PatternError() + + +NamedTupleT = TypeVar('NamedTupleT', bound=NamedTuple) + + +def make_namedtuple_validator( + namedtuple_cls: Type[NamedTupleT], config: Type['BaseConfig'] +) -> Callable[[Tuple[Any, ...]], NamedTupleT]: + from .annotated_types import create_model_from_namedtuple + + NamedTupleModel = create_model_from_namedtuple( + namedtuple_cls, + __config__=config, + __module__=namedtuple_cls.__module__, + ) + namedtuple_cls.__pydantic_model__ = NamedTupleModel # type: ignore[attr-defined] + + def namedtuple_validator(values: Tuple[Any, ...]) -> NamedTupleT: + annotations = NamedTupleModel.__annotations__ + + if len(values) > len(annotations): + raise errors.ListMaxLengthError(limit_value=len(annotations)) + + dict_values: Dict[str, Any] = dict(zip(annotations, 
values)) + validated_dict_values: Dict[str, Any] = dict(NamedTupleModel(**dict_values)) + return namedtuple_cls(**validated_dict_values) + + return namedtuple_validator + + +def make_typeddict_validator( + typeddict_cls: Type['TypedDict'], config: Type['BaseConfig'] # type: ignore[valid-type] +) -> Callable[[Any], Dict[str, Any]]: + from .annotated_types import create_model_from_typeddict + + TypedDictModel = create_model_from_typeddict( + typeddict_cls, + __config__=config, + __module__=typeddict_cls.__module__, + ) + typeddict_cls.__pydantic_model__ = TypedDictModel # type: ignore[attr-defined] + + def typeddict_validator(values: 'TypedDict') -> Dict[str, Any]: # type: ignore[valid-type] + return TypedDictModel.parse_obj(values).dict(exclude_unset=True) + + return typeddict_validator + + +class IfConfig: + def __init__(self, validator: AnyCallable, *config_attr_names: str, ignored_value: Any = False) -> None: + self.validator = validator + self.config_attr_names = config_attr_names + self.ignored_value = ignored_value + + def check(self, config: Type['BaseConfig']) -> bool: + return any(getattr(config, name) not in {None, self.ignored_value} for name in self.config_attr_names) + + +# order is important here, for example: bool is a subclass of int so has to come first, datetime before date same, +# IPv4Interface before IPv4Address, etc +_VALIDATORS: List[Tuple[Type[Any], List[Any]]] = [ + (IntEnum, [int_validator, enum_member_validator]), + (Enum, [enum_member_validator]), + ( + str, + [ + str_validator, + IfConfig(anystr_strip_whitespace, 'anystr_strip_whitespace'), + IfConfig(anystr_upper, 'anystr_upper'), + IfConfig(anystr_lower, 'anystr_lower'), + IfConfig(anystr_length_validator, 'min_anystr_length', 'max_anystr_length'), + ], + ), + ( + bytes, + [ + bytes_validator, + IfConfig(anystr_strip_whitespace, 'anystr_strip_whitespace'), + IfConfig(anystr_upper, 'anystr_upper'), + IfConfig(anystr_lower, 'anystr_lower'), + IfConfig(anystr_length_validator, 'min_anystr_length', 'max_anystr_length'), + ], + ), + (bool, [bool_validator]), + (int, [int_validator]), + (float, [float_validator, IfConfig(float_finite_validator, 'allow_inf_nan', ignored_value=True)]), + (Path, [path_validator]), + (datetime, [parse_datetime]), + (date, [parse_date]), + (time, [parse_time]), + (timedelta, [parse_duration]), + (OrderedDict, [ordered_dict_validator]), + (dict, [dict_validator]), + (list, [list_validator]), + (tuple, [tuple_validator]), + (set, [set_validator]), + (frozenset, [frozenset_validator]), + (deque, [deque_validator]), + (UUID, [uuid_validator]), + (Decimal, [decimal_validator]), + (IPv4Interface, [ip_v4_interface_validator]), + (IPv6Interface, [ip_v6_interface_validator]), + (IPv4Address, [ip_v4_address_validator]), + (IPv6Address, [ip_v6_address_validator]), + (IPv4Network, [ip_v4_network_validator]), + (IPv6Network, [ip_v6_network_validator]), +] + + +def find_validators( # noqa: C901 (ignore complexity) + type_: Type[Any], config: Type['BaseConfig'] +) -> Generator[AnyCallable, None, None]: + from .dataclasses import is_builtin_dataclass, make_dataclass_validator + + if type_ is Any or type_ is object: + return + type_type = type_.__class__ + if type_type == ForwardRef or type_type == TypeVar: + return + + if is_none_type(type_): + yield none_validator + return + if type_ is Pattern or type_ is re.Pattern: + yield pattern_validator + return + if type_ is Hashable or type_ is CollectionsHashable: + yield hashable_validator + return + if is_callable_type(type_): + yield callable_validator + 
return + if is_literal_type(type_): + yield make_literal_validator(type_) + return + if is_builtin_dataclass(type_): + yield from make_dataclass_validator(type_, config) + return + if type_ is Enum: + yield enum_validator + return + if type_ is IntEnum: + yield int_enum_validator + return + if is_namedtuple(type_): + yield tuple_validator + yield make_namedtuple_validator(type_, config) + return + if is_typeddict(type_): + yield make_typeddict_validator(type_, config) + return + + class_ = get_class(type_) + if class_ is not None: + if class_ is not Any and isinstance(class_, type): + yield make_class_validator(class_) + else: + yield any_class_validator + return + + for val_type, validators in _VALIDATORS: + try: + if issubclass(type_, val_type): + for v in validators: + if isinstance(v, IfConfig): + if v.check(config): + yield v.validator + else: + yield v + return + except TypeError: + raise RuntimeError(f'error checking inheritance of {type_!r} (type: {display_as_type(type_)})') + + if config.arbitrary_types_allowed: + yield make_arbitrary_type_validator(type_) + else: + raise RuntimeError(f'no validator found for {type_}, see `arbitrary_types_allowed` in Config') diff --git a/lib/pydantic/v1/version.py b/lib/pydantic/v1/version.py new file mode 100644 index 00000000..ec982ba7 --- /dev/null +++ b/lib/pydantic/v1/version.py @@ -0,0 +1,38 @@ +__all__ = 'compiled', 'VERSION', 'version_info' + +VERSION = '1.10.14' + +try: + import cython # type: ignore +except ImportError: + compiled: bool = False +else: # pragma: no cover + try: + compiled = cython.compiled + except AttributeError: + compiled = False + + +def version_info() -> str: + import platform + import sys + from importlib import import_module + from pathlib import Path + + optional_deps = [] + for p in ('devtools', 'dotenv', 'email-validator', 'typing-extensions'): + try: + import_module(p.replace('-', '_')) + except ImportError: + continue + optional_deps.append(p) + + info = { + 'pydantic version': VERSION, + 'pydantic compiled': compiled, + 'install path': Path(__file__).resolve().parent, + 'python version': sys.version, + 'platform': platform.platform(), + 'optional deps. installed': optional_deps, + } + return '\n'.join('{:>30} {}'.format(k + ':', str(v).replace('\n', ' ')) for k, v in info.items()) diff --git a/lib/pydantic/validate_call_decorator.py b/lib/pydantic/validate_call_decorator.py new file mode 100644 index 00000000..b95fa3b6 --- /dev/null +++ b/lib/pydantic/validate_call_decorator.py @@ -0,0 +1,67 @@ +"""Decorator for validating function calls.""" +from __future__ import annotations as _annotations + +import functools +from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload + +from ._internal import _validate_call + +__all__ = ('validate_call',) + +if TYPE_CHECKING: + from .config import ConfigDict + + AnyCallableT = TypeVar('AnyCallableT', bound=Callable[..., Any]) + + +@overload +def validate_call( + *, config: ConfigDict | None = None, validate_return: bool = False +) -> Callable[[AnyCallableT], AnyCallableT]: + ... + + +@overload +def validate_call(__func: AnyCallableT) -> AnyCallableT: + ... + + +def validate_call( + __func: AnyCallableT | None = None, + *, + config: ConfigDict | None = None, + validate_return: bool = False, +) -> AnyCallableT | Callable[[AnyCallableT], AnyCallableT]: + """Usage docs: https://docs.pydantic.dev/2.6/concepts/validation_decorator/ + + Returns a decorated wrapper around the function that validates the arguments and, optionally, the return value. 
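+
+    A minimal illustrative example::
+
+        @validate_call
+        def repeat(text: str, count: int) -> str:
+            return text * count
+
+        repeat('ab', 3)    # returns 'ababab'
+        repeat('ab', 'x')  # raises a pydantic ValidationError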
+ + Usage may be either as a plain decorator `@validate_call` or with arguments `@validate_call(...)`. + + Args: + __func: The function to be decorated. + config: The configuration dictionary. + validate_return: Whether to validate the return value. + + Returns: + The decorated function. + """ + + def validate(function: AnyCallableT) -> AnyCallableT: + if isinstance(function, (classmethod, staticmethod)): + name = type(function).__name__ + raise TypeError(f'The `@{name}` decorator should be applied after `@validate_call` (put `@{name}` on top)') + validate_call_wrapper = _validate_call.ValidateCallWrapper(function, config, validate_return) + + @functools.wraps(function) + def wrapper_function(*args, **kwargs): + return validate_call_wrapper(*args, **kwargs) + + wrapper_function.raw_function = function # type: ignore + + return wrapper_function # type: ignore + + if __func: + return validate(__func) + else: + return validate diff --git a/lib/pydantic/validators.py b/lib/pydantic/validators.py index fb6d0418..55b0339e 100644 --- a/lib/pydantic/validators.py +++ b/lib/pydantic/validators.py @@ -1,765 +1,4 @@ -import math -import re -from collections import OrderedDict, deque -from collections.abc import Hashable as CollectionsHashable -from datetime import date, datetime, time, timedelta -from decimal import Decimal, DecimalException -from enum import Enum, IntEnum -from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network -from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Deque, - Dict, - ForwardRef, - FrozenSet, - Generator, - Hashable, - List, - NamedTuple, - Pattern, - Set, - Tuple, - Type, - TypeVar, - Union, -) -from uuid import UUID +"""The `validators` module is a backport module from V1.""" +from ._migration import getattr_migration -from . import errors -from .datetime_parse import parse_date, parse_datetime, parse_duration, parse_time -from .typing import ( - AnyCallable, - all_literal_values, - display_as_type, - get_class, - is_callable_type, - is_literal_type, - is_namedtuple, - is_none_type, - is_typeddict, -) -from .utils import almost_equal_floats, lenient_issubclass, sequence_like - -if TYPE_CHECKING: - from typing_extensions import Literal, TypedDict - - from .config import BaseConfig - from .fields import ModelField - from .types import ConstrainedDecimal, ConstrainedFloat, ConstrainedInt - - ConstrainedNumber = Union[ConstrainedDecimal, ConstrainedFloat, ConstrainedInt] - AnyOrderedDict = OrderedDict[Any, Any] - Number = Union[int, float, Decimal] - StrBytes = Union[str, bytes] - - -def str_validator(v: Any) -> Union[str]: - if isinstance(v, str): - if isinstance(v, Enum): - return v.value - else: - return v - elif isinstance(v, (float, int, Decimal)): - # is there anything else we want to add here? If you think so, create an issue. 
- return str(v) - elif isinstance(v, (bytes, bytearray)): - return v.decode() - else: - raise errors.StrError() - - -def strict_str_validator(v: Any) -> Union[str]: - if isinstance(v, str) and not isinstance(v, Enum): - return v - raise errors.StrError() - - -def bytes_validator(v: Any) -> Union[bytes]: - if isinstance(v, bytes): - return v - elif isinstance(v, bytearray): - return bytes(v) - elif isinstance(v, str): - return v.encode() - elif isinstance(v, (float, int, Decimal)): - return str(v).encode() - else: - raise errors.BytesError() - - -def strict_bytes_validator(v: Any) -> Union[bytes]: - if isinstance(v, bytes): - return v - elif isinstance(v, bytearray): - return bytes(v) - else: - raise errors.BytesError() - - -BOOL_FALSE = {0, '0', 'off', 'f', 'false', 'n', 'no'} -BOOL_TRUE = {1, '1', 'on', 't', 'true', 'y', 'yes'} - - -def bool_validator(v: Any) -> bool: - if v is True or v is False: - return v - if isinstance(v, bytes): - v = v.decode() - if isinstance(v, str): - v = v.lower() - try: - if v in BOOL_TRUE: - return True - if v in BOOL_FALSE: - return False - except TypeError: - raise errors.BoolError() - raise errors.BoolError() - - -# matches the default limit cpython, see https://github.com/python/cpython/pull/96500 -max_str_int = 4_300 - - -def int_validator(v: Any) -> int: - if isinstance(v, int) and not (v is True or v is False): - return v - - # see https://github.com/pydantic/pydantic/issues/1477 and in turn, https://github.com/python/cpython/issues/95778 - # this check should be unnecessary once patch releases are out for 3.7, 3.8, 3.9 and 3.10 - # but better to check here until then. - # NOTICE: this does not fully protect user from the DOS risk since the standard library JSON implementation - # (and other std lib modules like xml) use `int()` and are likely called before this, the best workaround is to - # 1. update to the latest patch release of python once released, 2. 
use a different JSON library like ujson - if isinstance(v, (str, bytes, bytearray)) and len(v) > max_str_int: - raise errors.IntegerError() - - try: - return int(v) - except (TypeError, ValueError, OverflowError): - raise errors.IntegerError() - - -def strict_int_validator(v: Any) -> int: - if isinstance(v, int) and not (v is True or v is False): - return v - raise errors.IntegerError() - - -def float_validator(v: Any) -> float: - if isinstance(v, float): - return v - - try: - return float(v) - except (TypeError, ValueError): - raise errors.FloatError() - - -def strict_float_validator(v: Any) -> float: - if isinstance(v, float): - return v - raise errors.FloatError() - - -def float_finite_validator(v: 'Number', field: 'ModelField', config: 'BaseConfig') -> 'Number': - allow_inf_nan = getattr(field.type_, 'allow_inf_nan', None) - if allow_inf_nan is None: - allow_inf_nan = config.allow_inf_nan - - if allow_inf_nan is False and (math.isnan(v) or math.isinf(v)): - raise errors.NumberNotFiniteError() - return v - - -def number_multiple_validator(v: 'Number', field: 'ModelField') -> 'Number': - field_type: ConstrainedNumber = field.type_ - if field_type.multiple_of is not None: - mod = float(v) / float(field_type.multiple_of) % 1 - if not almost_equal_floats(mod, 0.0) and not almost_equal_floats(mod, 1.0): - raise errors.NumberNotMultipleError(multiple_of=field_type.multiple_of) - return v - - -def number_size_validator(v: 'Number', field: 'ModelField') -> 'Number': - field_type: ConstrainedNumber = field.type_ - if field_type.gt is not None and not v > field_type.gt: - raise errors.NumberNotGtError(limit_value=field_type.gt) - elif field_type.ge is not None and not v >= field_type.ge: - raise errors.NumberNotGeError(limit_value=field_type.ge) - - if field_type.lt is not None and not v < field_type.lt: - raise errors.NumberNotLtError(limit_value=field_type.lt) - if field_type.le is not None and not v <= field_type.le: - raise errors.NumberNotLeError(limit_value=field_type.le) - - return v - - -def constant_validator(v: 'Any', field: 'ModelField') -> 'Any': - """Validate ``const`` fields. - - The value provided for a ``const`` field must be equal to the default value - of the field. This is to support the keyword of the same name in JSON - Schema. 
- """ - if v != field.default: - raise errors.WrongConstantError(given=v, permitted=[field.default]) - - return v - - -def anystr_length_validator(v: 'StrBytes', config: 'BaseConfig') -> 'StrBytes': - v_len = len(v) - - min_length = config.min_anystr_length - if v_len < min_length: - raise errors.AnyStrMinLengthError(limit_value=min_length) - - max_length = config.max_anystr_length - if max_length is not None and v_len > max_length: - raise errors.AnyStrMaxLengthError(limit_value=max_length) - - return v - - -def anystr_strip_whitespace(v: 'StrBytes') -> 'StrBytes': - return v.strip() - - -def anystr_upper(v: 'StrBytes') -> 'StrBytes': - return v.upper() - - -def anystr_lower(v: 'StrBytes') -> 'StrBytes': - return v.lower() - - -def ordered_dict_validator(v: Any) -> 'AnyOrderedDict': - if isinstance(v, OrderedDict): - return v - - try: - return OrderedDict(v) - except (TypeError, ValueError): - raise errors.DictError() - - -def dict_validator(v: Any) -> Dict[Any, Any]: - if isinstance(v, dict): - return v - - try: - return dict(v) - except (TypeError, ValueError): - raise errors.DictError() - - -def list_validator(v: Any) -> List[Any]: - if isinstance(v, list): - return v - elif sequence_like(v): - return list(v) - else: - raise errors.ListError() - - -def tuple_validator(v: Any) -> Tuple[Any, ...]: - if isinstance(v, tuple): - return v - elif sequence_like(v): - return tuple(v) - else: - raise errors.TupleError() - - -def set_validator(v: Any) -> Set[Any]: - if isinstance(v, set): - return v - elif sequence_like(v): - return set(v) - else: - raise errors.SetError() - - -def frozenset_validator(v: Any) -> FrozenSet[Any]: - if isinstance(v, frozenset): - return v - elif sequence_like(v): - return frozenset(v) - else: - raise errors.FrozenSetError() - - -def deque_validator(v: Any) -> Deque[Any]: - if isinstance(v, deque): - return v - elif sequence_like(v): - return deque(v) - else: - raise errors.DequeError() - - -def enum_member_validator(v: Any, field: 'ModelField', config: 'BaseConfig') -> Enum: - try: - enum_v = field.type_(v) - except ValueError: - # field.type_ should be an enum, so will be iterable - raise errors.EnumMemberError(enum_values=list(field.type_)) - return enum_v.value if config.use_enum_values else enum_v - - -def uuid_validator(v: Any, field: 'ModelField') -> UUID: - try: - if isinstance(v, str): - v = UUID(v) - elif isinstance(v, (bytes, bytearray)): - try: - v = UUID(v.decode()) - except ValueError: - # 16 bytes in big-endian order as the bytes argument fail - # the above check - v = UUID(bytes=v) - except ValueError: - raise errors.UUIDError() - - if not isinstance(v, UUID): - raise errors.UUIDError() - - required_version = getattr(field.type_, '_required_version', None) - if required_version and v.version != required_version: - raise errors.UUIDVersionError(required_version=required_version) - - return v - - -def decimal_validator(v: Any) -> Decimal: - if isinstance(v, Decimal): - return v - elif isinstance(v, (bytes, bytearray)): - v = v.decode() - - v = str(v).strip() - - try: - v = Decimal(v) - except DecimalException: - raise errors.DecimalError() - - if not v.is_finite(): - raise errors.DecimalIsNotFiniteError() - - return v - - -def hashable_validator(v: Any) -> Hashable: - if isinstance(v, Hashable): - return v - - raise errors.HashableError() - - -def ip_v4_address_validator(v: Any) -> IPv4Address: - if isinstance(v, IPv4Address): - return v - - try: - return IPv4Address(v) - except ValueError: - raise errors.IPv4AddressError() - - -def 
ip_v6_address_validator(v: Any) -> IPv6Address: - if isinstance(v, IPv6Address): - return v - - try: - return IPv6Address(v) - except ValueError: - raise errors.IPv6AddressError() - - -def ip_v4_network_validator(v: Any) -> IPv4Network: - """ - Assume IPv4Network initialised with a default ``strict`` argument - - See more: - https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network - """ - if isinstance(v, IPv4Network): - return v - - try: - return IPv4Network(v) - except ValueError: - raise errors.IPv4NetworkError() - - -def ip_v6_network_validator(v: Any) -> IPv6Network: - """ - Assume IPv6Network initialised with a default ``strict`` argument - - See more: - https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network - """ - if isinstance(v, IPv6Network): - return v - - try: - return IPv6Network(v) - except ValueError: - raise errors.IPv6NetworkError() - - -def ip_v4_interface_validator(v: Any) -> IPv4Interface: - if isinstance(v, IPv4Interface): - return v - - try: - return IPv4Interface(v) - except ValueError: - raise errors.IPv4InterfaceError() - - -def ip_v6_interface_validator(v: Any) -> IPv6Interface: - if isinstance(v, IPv6Interface): - return v - - try: - return IPv6Interface(v) - except ValueError: - raise errors.IPv6InterfaceError() - - -def path_validator(v: Any) -> Path: - if isinstance(v, Path): - return v - - try: - return Path(v) - except TypeError: - raise errors.PathError() - - -def path_exists_validator(v: Any) -> Path: - if not v.exists(): - raise errors.PathNotExistsError(path=v) - - return v - - -def callable_validator(v: Any) -> AnyCallable: - """ - Perform a simple check if the value is callable. - - Note: complete matching of argument type hints and return types is not performed - """ - if callable(v): - return v - - raise errors.CallableError(value=v) - - -def enum_validator(v: Any) -> Enum: - if isinstance(v, Enum): - return v - - raise errors.EnumError(value=v) - - -def int_enum_validator(v: Any) -> IntEnum: - if isinstance(v, IntEnum): - return v - - raise errors.IntEnumError(value=v) - - -def make_literal_validator(type_: Any) -> Callable[[Any], Any]: - permitted_choices = all_literal_values(type_) - - # To have a O(1) complexity and still return one of the values set inside the `Literal`, - # we create a dict with the set values (a set causes some problems with the way intersection works). 
- # In some cases the set value and checked value can indeed be different (see `test_literal_validator_str_enum`) - allowed_choices = {v: v for v in permitted_choices} - - def literal_validator(v: Any) -> Any: - try: - return allowed_choices[v] - except KeyError: - raise errors.WrongConstantError(given=v, permitted=permitted_choices) - - return literal_validator - - -def constr_length_validator(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes': - v_len = len(v) - - min_length = field.type_.min_length if field.type_.min_length is not None else config.min_anystr_length - if v_len < min_length: - raise errors.AnyStrMinLengthError(limit_value=min_length) - - max_length = field.type_.max_length if field.type_.max_length is not None else config.max_anystr_length - if max_length is not None and v_len > max_length: - raise errors.AnyStrMaxLengthError(limit_value=max_length) - - return v - - -def constr_strip_whitespace(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes': - strip_whitespace = field.type_.strip_whitespace or config.anystr_strip_whitespace - if strip_whitespace: - v = v.strip() - - return v - - -def constr_upper(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes': - upper = field.type_.to_upper or config.anystr_upper - if upper: - v = v.upper() - - return v - - -def constr_lower(v: 'StrBytes', field: 'ModelField', config: 'BaseConfig') -> 'StrBytes': - lower = field.type_.to_lower or config.anystr_lower - if lower: - v = v.lower() - return v - - -def validate_json(v: Any, config: 'BaseConfig') -> Any: - if v is None: - # pass None through to other validators - return v - try: - return config.json_loads(v) # type: ignore - except ValueError: - raise errors.JsonError() - except TypeError: - raise errors.JsonTypeError() - - -T = TypeVar('T') - - -def make_arbitrary_type_validator(type_: Type[T]) -> Callable[[T], T]: - def arbitrary_type_validator(v: Any) -> T: - if isinstance(v, type_): - return v - raise errors.ArbitraryTypeError(expected_arbitrary_type=type_) - - return arbitrary_type_validator - - -def make_class_validator(type_: Type[T]) -> Callable[[Any], Type[T]]: - def class_validator(v: Any) -> Type[T]: - if lenient_issubclass(v, type_): - return v - raise errors.SubclassError(expected_class=type_) - - return class_validator - - -def any_class_validator(v: Any) -> Type[T]: - if isinstance(v, type): - return v - raise errors.ClassError() - - -def none_validator(v: Any) -> 'Literal[None]': - if v is None: - return v - raise errors.NotNoneError() - - -def pattern_validator(v: Any) -> Pattern[str]: - if isinstance(v, Pattern): - return v - - str_value = str_validator(v) - - try: - return re.compile(str_value) - except re.error: - raise errors.PatternError() - - -NamedTupleT = TypeVar('NamedTupleT', bound=NamedTuple) - - -def make_namedtuple_validator( - namedtuple_cls: Type[NamedTupleT], config: Type['BaseConfig'] -) -> Callable[[Tuple[Any, ...]], NamedTupleT]: - from .annotated_types import create_model_from_namedtuple - - NamedTupleModel = create_model_from_namedtuple( - namedtuple_cls, - __config__=config, - __module__=namedtuple_cls.__module__, - ) - namedtuple_cls.__pydantic_model__ = NamedTupleModel # type: ignore[attr-defined] - - def namedtuple_validator(values: Tuple[Any, ...]) -> NamedTupleT: - annotations = NamedTupleModel.__annotations__ - - if len(values) > len(annotations): - raise errors.ListMaxLengthError(limit_value=len(annotations)) - - dict_values: Dict[str, Any] = dict(zip(annotations, values)) - 
validated_dict_values: Dict[str, Any] = dict(NamedTupleModel(**dict_values)) - return namedtuple_cls(**validated_dict_values) - - return namedtuple_validator - - -def make_typeddict_validator( - typeddict_cls: Type['TypedDict'], config: Type['BaseConfig'] # type: ignore[valid-type] -) -> Callable[[Any], Dict[str, Any]]: - from .annotated_types import create_model_from_typeddict - - TypedDictModel = create_model_from_typeddict( - typeddict_cls, - __config__=config, - __module__=typeddict_cls.__module__, - ) - typeddict_cls.__pydantic_model__ = TypedDictModel # type: ignore[attr-defined] - - def typeddict_validator(values: 'TypedDict') -> Dict[str, Any]: # type: ignore[valid-type] - return TypedDictModel.parse_obj(values).dict(exclude_unset=True) - - return typeddict_validator - - -class IfConfig: - def __init__(self, validator: AnyCallable, *config_attr_names: str, ignored_value: Any = False) -> None: - self.validator = validator - self.config_attr_names = config_attr_names - self.ignored_value = ignored_value - - def check(self, config: Type['BaseConfig']) -> bool: - return any(getattr(config, name) not in {None, self.ignored_value} for name in self.config_attr_names) - - -# order is important here, for example: bool is a subclass of int so has to come first, datetime before date same, -# IPv4Interface before IPv4Address, etc -_VALIDATORS: List[Tuple[Type[Any], List[Any]]] = [ - (IntEnum, [int_validator, enum_member_validator]), - (Enum, [enum_member_validator]), - ( - str, - [ - str_validator, - IfConfig(anystr_strip_whitespace, 'anystr_strip_whitespace'), - IfConfig(anystr_upper, 'anystr_upper'), - IfConfig(anystr_lower, 'anystr_lower'), - IfConfig(anystr_length_validator, 'min_anystr_length', 'max_anystr_length'), - ], - ), - ( - bytes, - [ - bytes_validator, - IfConfig(anystr_strip_whitespace, 'anystr_strip_whitespace'), - IfConfig(anystr_upper, 'anystr_upper'), - IfConfig(anystr_lower, 'anystr_lower'), - IfConfig(anystr_length_validator, 'min_anystr_length', 'max_anystr_length'), - ], - ), - (bool, [bool_validator]), - (int, [int_validator]), - (float, [float_validator, IfConfig(float_finite_validator, 'allow_inf_nan', ignored_value=True)]), - (Path, [path_validator]), - (datetime, [parse_datetime]), - (date, [parse_date]), - (time, [parse_time]), - (timedelta, [parse_duration]), - (OrderedDict, [ordered_dict_validator]), - (dict, [dict_validator]), - (list, [list_validator]), - (tuple, [tuple_validator]), - (set, [set_validator]), - (frozenset, [frozenset_validator]), - (deque, [deque_validator]), - (UUID, [uuid_validator]), - (Decimal, [decimal_validator]), - (IPv4Interface, [ip_v4_interface_validator]), - (IPv6Interface, [ip_v6_interface_validator]), - (IPv4Address, [ip_v4_address_validator]), - (IPv6Address, [ip_v6_address_validator]), - (IPv4Network, [ip_v4_network_validator]), - (IPv6Network, [ip_v6_network_validator]), -] - - -def find_validators( # noqa: C901 (ignore complexity) - type_: Type[Any], config: Type['BaseConfig'] -) -> Generator[AnyCallable, None, None]: - from .dataclasses import is_builtin_dataclass, make_dataclass_validator - - if type_ is Any or type_ is object: - return - type_type = type_.__class__ - if type_type == ForwardRef or type_type == TypeVar: - return - - if is_none_type(type_): - yield none_validator - return - if type_ is Pattern or type_ is re.Pattern: - yield pattern_validator - return - if type_ is Hashable or type_ is CollectionsHashable: - yield hashable_validator - return - if is_callable_type(type_): - yield callable_validator - return - if 
is_literal_type(type_): - yield make_literal_validator(type_) - return - if is_builtin_dataclass(type_): - yield from make_dataclass_validator(type_, config) - return - if type_ is Enum: - yield enum_validator - return - if type_ is IntEnum: - yield int_enum_validator - return - if is_namedtuple(type_): - yield tuple_validator - yield make_namedtuple_validator(type_, config) - return - if is_typeddict(type_): - yield make_typeddict_validator(type_, config) - return - - class_ = get_class(type_) - if class_ is not None: - if class_ is not Any and isinstance(class_, type): - yield make_class_validator(class_) - else: - yield any_class_validator - return - - for val_type, validators in _VALIDATORS: - try: - if issubclass(type_, val_type): - for v in validators: - if isinstance(v, IfConfig): - if v.check(config): - yield v.validator - else: - yield v - return - except TypeError: - raise RuntimeError(f'error checking inheritance of {type_!r} (type: {display_as_type(type_)})') - - if config.arbitrary_types_allowed: - yield make_arbitrary_type_validator(type_) - else: - raise RuntimeError(f'no validator found for {type_}, see `arbitrary_types_allowed` in Config') +__getattr__ = getattr_migration(__name__) diff --git a/lib/pydantic/version.py b/lib/pydantic/version.py index 32c61633..3e233771 100644 --- a/lib/pydantic/version.py +++ b/lib/pydantic/version.py @@ -1,38 +1,80 @@ -__all__ = 'compiled', 'VERSION', 'version_info' +"""The `version` module holds the version information for Pydantic.""" +from __future__ import annotations as _annotations -VERSION = '1.10.2' +__all__ = 'VERSION', 'version_info' -try: - import cython # type: ignore -except ImportError: - compiled: bool = False -else: # pragma: no cover - try: - compiled = cython.compiled - except AttributeError: - compiled = False +VERSION = '2.6.4' +"""The version of Pydantic.""" + + +def version_short() -> str: + """Return the `major.minor` part of Pydantic version. + + It returns '2.1' if Pydantic version is '2.1.1'. 
+ """ + return '.'.join(VERSION.split('.')[:2]) def version_info() -> str: + """Return complete version information for Pydantic and its dependencies.""" + import importlib.metadata as importlib_metadata + import os import platform import sys - from importlib import import_module from pathlib import Path - optional_deps = [] - for p in ('devtools', 'dotenv', 'email-validator', 'typing-extensions'): - try: - import_module(p.replace('-', '_')) - except ImportError: - continue - optional_deps.append(p) + import pydantic_core._pydantic_core as pdc + + from ._internal import _git as git + + # get data about packages that are closely related to pydantic, use pydantic or often conflict with pydantic + package_names = { + 'email-validator', + 'fastapi', + 'mypy', + 'pydantic-extra-types', + 'pydantic-settings', + 'pyright', + 'typing_extensions', + } + related_packages = [] + + for dist in importlib_metadata.distributions(): + name = dist.metadata['Name'] + if name in package_names: + related_packages.append(f'{name}-{dist.version}') + + pydantic_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) + most_recent_commit = ( + git.git_revision(pydantic_dir) if git.is_git_repo(pydantic_dir) and git.have_git() else 'unknown' + ) info = { 'pydantic version': VERSION, - 'pydantic compiled': compiled, + 'pydantic-core version': pdc.__version__, + 'pydantic-core build': getattr(pdc, 'build_info', None) or pdc.build_profile, 'install path': Path(__file__).resolve().parent, 'python version': sys.version, 'platform': platform.platform(), - 'optional deps. installed': optional_deps, + 'related packages': ' '.join(related_packages), + 'commit': most_recent_commit, } return '\n'.join('{:>30} {}'.format(k + ':', str(v).replace('\n', ' ')) for k, v in info.items()) + + +def parse_mypy_version(version: str) -> tuple[int, ...]: + """Parse mypy string version to tuple of ints. + + This function is included here rather than the mypy plugin file because the mypy plugin file cannot be imported + outside a mypy run. + + It parses normal version like `0.930` and dev version + like `0.940+dev.04cac4b5d911c4f9529e6ce86a27b44f28846f5d.dirty`. + + Args: + version: The mypy version string. + + Returns: + A tuple of ints. e.g. (0, 930). + """ + return tuple(map(int, version.partition('+')[0].split('.'))) diff --git a/lib/pydantic/warnings.py b/lib/pydantic/warnings.py new file mode 100644 index 00000000..aedd4fba --- /dev/null +++ b/lib/pydantic/warnings.py @@ -0,0 +1,58 @@ +"""Pydantic-specific warnings.""" +from __future__ import annotations as _annotations + +from .version import version_short + +__all__ = 'PydanticDeprecatedSince20', 'PydanticDeprecationWarning' + + +class PydanticDeprecationWarning(DeprecationWarning): + """A Pydantic specific deprecation warning. + + This warning is raised when using deprecated functionality in Pydantic. It provides information on when the + deprecation was introduced and the expected version in which the corresponding functionality will be removed. + + Attributes: + message: Description of the warning. + since: Pydantic version in what the deprecation was introduced. + expected_removal: Pydantic version in what the corresponding functionality expected to be removed. 
+ """ + + message: str + since: tuple[int, int] + expected_removal: tuple[int, int] + + def __init__( + self, message: str, *args: object, since: tuple[int, int], expected_removal: tuple[int, int] | None = None + ) -> None: + super().__init__(message, *args) + self.message = message.rstrip('.') + self.since = since + self.expected_removal = expected_removal if expected_removal is not None else (since[0] + 1, 0) + + def __str__(self) -> str: + message = ( + f'{self.message}. Deprecated in Pydantic V{self.since[0]}.{self.since[1]}' + f' to be removed in V{self.expected_removal[0]}.{self.expected_removal[1]}.' + ) + if self.since == (2, 0): + message += f' See Pydantic V2 Migration Guide at https://errors.pydantic.dev/{version_short()}/migration/' + return message + + +class PydanticDeprecatedSince20(PydanticDeprecationWarning): + """A specific `PydanticDeprecationWarning` subclass defining functionality deprecated since Pydantic 2.0.""" + + def __init__(self, message: str, *args: object) -> None: + super().__init__(message, *args, since=(2, 0), expected_removal=(3, 0)) + + +class PydanticDeprecatedSince26(PydanticDeprecationWarning): + """A specific `PydanticDeprecationWarning` subclass defining functionality deprecated since Pydantic 2.6.""" + + def __init__(self, message: str, *args: object) -> None: + super().__init__(message, *args, since=(2, 0), expected_removal=(3, 0)) + + +class GenericBeforeBaseModelWarning(Warning): + pass diff --git a/lib/pydantic_core/__init__.py b/lib/pydantic_core/__init__.py new file mode 100644 index 00000000..5b2655c9 --- /dev/null +++ b/lib/pydantic_core/__init__.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import sys as _sys +from typing import Any as _Any + +from ._pydantic_core import ( + ArgsKwargs, + MultiHostUrl, + PydanticCustomError, + PydanticKnownError, + PydanticOmit, + PydanticSerializationError, + PydanticSerializationUnexpectedValue, + PydanticUndefined, + PydanticUndefinedType, + PydanticUseDefault, + SchemaError, + SchemaSerializer, + SchemaValidator, + Some, + TzInfo, + Url, + ValidationError, + __version__, + from_json, + to_json, + to_jsonable_python, + validate_core_schema, +) +from .core_schema import CoreConfig, CoreSchema, CoreSchemaType, ErrorType + +if _sys.version_info < (3, 11): + from typing_extensions import NotRequired as _NotRequired +else: + from typing import NotRequired as _NotRequired + +if _sys.version_info < (3, 9): + from typing_extensions import TypedDict as _TypedDict +else: + from typing import TypedDict as _TypedDict + +__all__ = [ + '__version__', + 'CoreConfig', + 'CoreSchema', + 'CoreSchemaType', + 'SchemaValidator', + 'SchemaSerializer', + 'Some', + 'Url', + 'MultiHostUrl', + 'ArgsKwargs', + 'PydanticUndefined', + 'PydanticUndefinedType', + 'SchemaError', + 'ErrorDetails', + 'InitErrorDetails', + 'ValidationError', + 'PydanticCustomError', + 'PydanticKnownError', + 'PydanticOmit', + 'PydanticUseDefault', + 'PydanticSerializationError', + 'PydanticSerializationUnexpectedValue', + 'TzInfo', + 'to_json', + 'from_json', + 'to_jsonable_python', + 'validate_core_schema', +] + + +class ErrorDetails(_TypedDict): + type: str + """ + The type of error that occurred, this is an identifier designed for + programmatic use that will change rarely or never. + + `type` is unique for each error message, and can hence be used as an identifier to build custom error messages. + """ + loc: tuple[int | str, ...] 
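# Illustrative sketch of how the deprecation warnings defined above render; the
# message text here is hypothetical.
from pydantic.warnings import PydanticDeprecatedSince20

warning = PydanticDeprecatedSince20('The `dict` method is deprecated; use `model_dump` instead')
# str(warning) combines the message with the `since`/`expected_removal` versions, e.g.
# '... Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration
# Guide at https://errors.pydantic.dev/2.6/migration/'
print(str(warning))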
+ """Tuple of strings and ints identifying where in the schema the error occurred.""" + msg: str + """A human readable error message.""" + input: _Any + """The input data at this `loc` that caused the error.""" + ctx: _NotRequired[dict[str, _Any]] + """ + Values which are required to render the error message, and could hence be useful in rendering custom error messages. + Also useful for passing custom error data forward. + """ + + +class InitErrorDetails(_TypedDict): + type: str | PydanticCustomError + """The type of error that occurred, this should a "slug" identifier that changes rarely or never.""" + loc: _NotRequired[tuple[int | str, ...]] + """Tuple of strings and ints identifying where in the schema the error occurred.""" + input: _Any + """The input data at this `loc` that caused the error.""" + ctx: _NotRequired[dict[str, _Any]] + """ + Values which are required to render the error message, and could hence be useful in rendering custom error messages. + Also useful for passing custom error data forward. + """ + + +class ErrorTypeInfo(_TypedDict): + """ + Gives information about errors. + """ + + type: ErrorType + """The type of error that occurred, this should a "slug" identifier that changes rarely or never.""" + message_template_python: str + """String template to render a human readable error message from using context, when the input is Python.""" + example_message_python: str + """Example of a human readable error message, when the input is Python.""" + message_template_json: _NotRequired[str] + """String template to render a human readable error message from using context, when the input is JSON data.""" + example_message_json: _NotRequired[str] + """Example of a human readable error message, when the input is JSON data.""" + example_context: dict[str, _Any] | None + """Example of context values.""" + + +class MultiHostHost(_TypedDict): + """ + A host part of a multi-host URL. 
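# Illustrative shape of a single `ErrorDetails` entry as returned by
# `ValidationError.errors()`; the concrete values below are hypothetical.
from pydantic_core import ErrorDetails

example_error: ErrorDetails = {
    'type': 'int_parsing',                      # programmatic error identifier
    'loc': ('count',),                          # where in the schema the error occurred
    'msg': 'Input should be a valid integer',   # human readable message
    'input': 'not-an-int',                      # offending input at this `loc`
    # 'ctx' appears only when the message template needs extra values
}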
+ """ + + username: str | None + """The username part of this host, or `None`.""" + password: str | None + """The password part of this host, or `None`.""" + host: str | None + """The host part of this host, or `None`.""" + port: int | None + """The port part of this host, or `None`.""" diff --git a/lib/pydantic_core/_pydantic_core.pyi b/lib/pydantic_core/_pydantic_core.pyi new file mode 100644 index 00000000..a7b727f8 --- /dev/null +++ b/lib/pydantic_core/_pydantic_core.pyi @@ -0,0 +1,882 @@ +from __future__ import annotations + +import datetime +import sys +from typing import Any, Callable, Generic, Optional, Type, TypeVar + +from pydantic_core import ErrorDetails, ErrorTypeInfo, InitErrorDetails, MultiHostHost +from pydantic_core.core_schema import CoreConfig, CoreSchema, ErrorType + +if sys.version_info < (3, 8): + from typing_extensions import final +else: + from typing import final + +if sys.version_info < (3, 11): + from typing_extensions import Literal, LiteralString, Self, TypeAlias +else: + from typing import Literal, LiteralString, Self, TypeAlias + +from _typeshed import SupportsAllComparisons + +__all__ = [ + '__version__', + 'build_profile', + 'build_info', + '_recursion_limit', + 'ArgsKwargs', + 'SchemaValidator', + 'SchemaSerializer', + 'Url', + 'MultiHostUrl', + 'SchemaError', + 'ValidationError', + 'PydanticCustomError', + 'PydanticKnownError', + 'PydanticOmit', + 'PydanticUseDefault', + 'PydanticSerializationError', + 'PydanticSerializationUnexpectedValue', + 'PydanticUndefined', + 'PydanticUndefinedType', + 'Some', + 'to_json', + 'from_json', + 'to_jsonable_python', + 'list_all_errors', + 'TzInfo', + 'validate_core_schema', +] +__version__: str +build_profile: str +build_info: str +_recursion_limit: int + +_T = TypeVar('_T', default=Any, covariant=True) + +_StringInput: TypeAlias = 'dict[str, _StringInput]' + +@final +class Some(Generic[_T]): + """ + Similar to Rust's [`Option::Some`](https://doc.rust-lang.org/std/option/enum.Option.html) type, this + identifies a value as being present, and provides a way to access it. + + Generally used in a union with `None` to different between "some value which could be None" and no value. + """ + + __match_args__ = ('value',) + + @property + def value(self) -> _T: + """ + Returns the value wrapped by `Some`. + """ + @classmethod + def __class_getitem__(cls, __item: Any) -> Type[Self]: ... + +@final +class SchemaValidator: + """ + `SchemaValidator` is the Python wrapper for `pydantic-core`'s Rust validation logic, internally it owns one + `CombinedValidator` which may in turn own more `CombinedValidator`s which make up the full schema validator. + """ + + def __new__(cls, schema: CoreSchema, config: CoreConfig | None = None) -> Self: + """ + Create a new SchemaValidator. + + Arguments: + schema: The [`CoreSchema`][pydantic_core.core_schema.CoreSchema] to use for validation. + config: Optionally a [`CoreConfig`][pydantic_core.core_schema.CoreConfig] to configure validation. + """ + @property + def title(self) -> str: + """ + The title of the schema, as used in the heading of [`ValidationError.__str__()`][pydantic_core.ValidationError]. + """ + def validate_python( + self, + input: Any, + *, + strict: bool | None = None, + from_attributes: bool | None = None, + context: 'dict[str, Any] | None' = None, + self_instance: Any | None = None, + ) -> Any: + """ + Validate a Python object against the schema and return the validated object. + + Arguments: + input: The Python object to validate. 
+ strict: Whether to validate the object in strict mode. + If `None`, the value of [`CoreConfig.strict`][pydantic_core.core_schema.CoreConfig] is used. + from_attributes: Whether to validate objects as inputs to models by extracting attributes. + If `None`, the value of [`CoreConfig.from_attributes`][pydantic_core.core_schema.CoreConfig] is used. + context: The context to use for validation, this is passed to functional validators as + [`info.context`][pydantic_core.core_schema.ValidationInfo.context]. + self_instance: An instance of a model set attributes on from validation, this is used when running + validation from the `__init__` method of a model. + + Raises: + ValidationError: If validation fails. + Exception: Other error types maybe raised if internal errors occur. + + Returns: + The validated object. + """ + def isinstance_python( + self, + input: Any, + *, + strict: bool | None = None, + from_attributes: bool | None = None, + context: 'dict[str, Any] | None' = None, + self_instance: Any | None = None, + ) -> bool: + """ + Similar to [`validate_python()`][pydantic_core.SchemaValidator.validate_python] but returns a boolean. + + Arguments match `validate_python()`. This method will not raise `ValidationError`s but will raise internal + errors. + + Returns: + `True` if validation succeeds, `False` if validation fails. + """ + def validate_json( + self, + input: str | bytes | bytearray, + *, + strict: bool | None = None, + context: 'dict[str, Any] | None' = None, + self_instance: Any | None = None, + ) -> Any: + """ + Validate JSON data directly against the schema and return the validated Python object. + + This method should be significantly faster than `validate_python(json.loads(json_data))` as it avoids the + need to create intermediate Python objects + + It also handles constructing the correct Python type even in strict mode, where + `validate_python(json.loads(json_data))` would fail validation. + + Arguments: + input: The JSON data to validate. + strict: Whether to validate the object in strict mode. + If `None`, the value of [`CoreConfig.strict`][pydantic_core.core_schema.CoreConfig] is used. + context: The context to use for validation, this is passed to functional validators as + [`info.context`][pydantic_core.core_schema.ValidationInfo.context]. + self_instance: An instance of a model set attributes on from validation. + + Raises: + ValidationError: If validation fails or if the JSON data is invalid. + Exception: Other error types maybe raised if internal errors occur. + + Returns: + The validated Python object. + """ + def validate_strings( + self, input: _StringInput, *, strict: bool | None = None, context: 'dict[str, Any] | None' = None + ) -> Any: + """ + Validate a string against the schema and return the validated Python object. + + This is similar to `validate_json` but applies to scenarios where the input will be a string but not + JSON data, e.g. URL fragments, query parameters, etc. + + Arguments: + input: The input as a string, or bytes/bytearray if `strict=False`. + strict: Whether to validate the object in strict mode. + If `None`, the value of [`CoreConfig.strict`][pydantic_core.core_schema.CoreConfig] is used. + context: The context to use for validation, this is passed to functional validators as + [`info.context`][pydantic_core.core_schema.ValidationInfo.context]. + + Raises: + ValidationError: If validation fails or if the JSON data is invalid. + Exception: Other error types maybe raised if internal errors occur. 
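# A minimal sketch of the `SchemaValidator` API described above, assuming the
# default (lax) validation mode.
from pydantic_core import SchemaValidator, ValidationError, core_schema

v = SchemaValidator(core_schema.int_schema())
assert v.validate_python('42') == 42    # lax mode coerces the numeric string
assert v.validate_json(b'123') == 123   # parses and validates JSON in one step
try:
    v.validate_python('not a number')
except ValidationError as exc:
    print(exc.error_count())            # 1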
+ + Returns: + The validated Python object. + """ + def validate_assignment( + self, + obj: Any, + field_name: str, + field_value: Any, + *, + strict: bool | None = None, + from_attributes: bool | None = None, + context: 'dict[str, Any] | None' = None, + ) -> dict[str, Any] | tuple[dict[str, Any], dict[str, Any] | None, set[str]]: + """ + Validate an assignment to a field on a model. + + Arguments: + obj: The model instance being assigned to. + field_name: The name of the field to validate assignment for. + field_value: The value to assign to the field. + strict: Whether to validate the object in strict mode. + If `None`, the value of [`CoreConfig.strict`][pydantic_core.core_schema.CoreConfig] is used. + from_attributes: Whether to validate objects as inputs to models by extracting attributes. + If `None`, the value of [`CoreConfig.from_attributes`][pydantic_core.core_schema.CoreConfig] is used. + context: The context to use for validation, this is passed to functional validators as + [`info.context`][pydantic_core.core_schema.ValidationInfo.context]. + + Raises: + ValidationError: If validation fails. + Exception: Other error types maybe raised if internal errors occur. + + Returns: + Either the model dict or a tuple of `(model_data, model_extra, fields_set)` + """ + def get_default_value(self, *, strict: bool | None = None, context: Any = None) -> Some | None: + """ + Get the default value for the schema, including running default value validation. + + Arguments: + strict: Whether to validate the default value in strict mode. + If `None`, the value of [`CoreConfig.strict`][pydantic_core.core_schema.CoreConfig] is used. + context: The context to use for validation, this is passed to functional validators as + [`info.context`][pydantic_core.core_schema.ValidationInfo.context]. + + Raises: + ValidationError: If validation fails. + Exception: Other error types maybe raised if internal errors occur. + + Returns: + `None` if the schema has no default value, otherwise a [`Some`][pydantic_core.Some] containing the default. + """ + +_IncEx: TypeAlias = set[int] | set[str] | dict[int, _IncEx] | dict[str, _IncEx] | None + +@final +class SchemaSerializer: + """ + `SchemaSerializer` is the Python wrapper for `pydantic-core`'s Rust serialization logic, internally it owns one + `CombinedSerializer` which may in turn own more `CombinedSerializer`s which make up the full schema serializer. + """ + + def __new__(cls, schema: CoreSchema, config: CoreConfig | None = None) -> Self: + """ + Create a new SchemaSerializer. + + Arguments: + schema: The [`CoreSchema`][pydantic_core.core_schema.CoreSchema] to use for serialization. + config: Optionally a [`CoreConfig`][pydantic_core.core_schema.CoreConfig] to to configure serialization. + """ + def to_python( + self, + value: Any, + *, + mode: str | None = None, + include: _IncEx = None, + exclude: _IncEx = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + round_trip: bool = False, + warnings: bool = True, + fallback: Callable[[Any], Any] | None = None, + ) -> Any: + """ + Serialize/marshal a Python object to a Python object including transforming and filtering data. + + Arguments: + value: The Python object to serialize. + mode: The serialization mode to use, either `'python'` or `'json'`, defaults to `'python'`. In JSON mode, + all values are converted to JSON compatible types, e.g. `None`, `int`, `float`, `str`, `list`, `dict`. 
+ include: A set of fields to include, if `None` all fields are included. + exclude: A set of fields to exclude, if `None` no fields are excluded. + by_alias: Whether to use the alias names of fields. + exclude_unset: Whether to exclude fields that are not set, + e.g. are not included in `__pydantic_fields_set__`. + exclude_defaults: Whether to exclude fields that are equal to their default value. + exclude_none: Whether to exclude fields that have a value of `None`. + round_trip: Whether to enable serialization and validation round-trip support. + warnings: Whether to log warnings when invalid fields are encountered. + fallback: A function to call when an unknown value is encountered, + if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. + + Raises: + PydanticSerializationError: If serialization fails and no `fallback` function is provided. + + Returns: + The serialized Python object. + """ + def to_json( + self, + value: Any, + *, + indent: int | None = None, + include: _IncEx = None, + exclude: _IncEx = None, + by_alias: bool = True, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + round_trip: bool = False, + warnings: bool = True, + fallback: Callable[[Any], Any] | None = None, + ) -> bytes: + """ + Serialize a Python object to JSON including transforming and filtering data. + + Arguments: + value: The Python object to serialize. + indent: If `None`, the JSON will be compact, otherwise it will be pretty-printed with the indent provided. + include: A set of fields to include, if `None` all fields are included. + exclude: A set of fields to exclude, if `None` no fields are excluded. + by_alias: Whether to use the alias names of fields. + exclude_unset: Whether to exclude fields that are not set, + e.g. are not included in `__pydantic_fields_set__`. + exclude_defaults: Whether to exclude fields that are equal to their default value. + exclude_none: Whether to exclude fields that have a value of `None`. + round_trip: Whether to enable serialization and validation round-trip support. + warnings: Whether to log warnings when invalid fields are encountered. + fallback: A function to call when an unknown value is encountered, + if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. + + Raises: + PydanticSerializationError: If serialization fails and no `fallback` function is provided. + + Returns: + JSON bytes. + """ + +def to_json( + value: Any, + *, + indent: int | None = None, + include: _IncEx = None, + exclude: _IncEx = None, + by_alias: bool = True, + exclude_none: bool = False, + round_trip: bool = False, + timedelta_mode: Literal['iso8601', 'float'] = 'iso8601', + bytes_mode: Literal['utf8', 'base64'] = 'utf8', + inf_nan_mode: Literal['null', 'constants'] = 'constants', + serialize_unknown: bool = False, + fallback: Callable[[Any], Any] | None = None, +) -> bytes: + """ + Serialize a Python object to JSON including transforming and filtering data. + + This is effectively a standalone version of [`SchemaSerializer.to_json`][pydantic_core.SchemaSerializer.to_json]. + + Arguments: + value: The Python object to serialize. + indent: If `None`, the JSON will be compact, otherwise it will be pretty-printed with the indent provided. + include: A set of fields to include, if `None` all fields are included. + exclude: A set of fields to exclude, if `None` no fields are excluded. + by_alias: Whether to use the alias names of fields. 
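# A small sketch of `SchemaSerializer.to_python`/`to_json` described above, using a
# datetime schema as an illustrative example.
from datetime import datetime
from pydantic_core import SchemaSerializer, core_schema

s = SchemaSerializer(core_schema.datetime_schema())
dt = datetime(2024, 3, 24, 17, 55)
assert s.to_python(dt) == dt            # python mode keeps the native object
print(s.to_python(dt, mode='json'))     # JSON mode: an ISO-8601 string
print(s.to_json(dt))                    # the same value as JSON bytes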
+ exclude_none: Whether to exclude fields that have a value of `None`. + round_trip: Whether to enable serialization and validation round-trip support. + timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`. + bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`. + inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'` or `'constants'`. + serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails + `""` will be used. + fallback: A function to call when an unknown value is encountered, + if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. + + Raises: + PydanticSerializationError: If serialization fails and no `fallback` function is provided. + + Returns: + JSON bytes. + """ + +def from_json(data: str | bytes | bytearray, *, allow_inf_nan: bool = True, cache_strings: bool = True) -> Any: + """ + Deserialize JSON data to a Python object. + + This is effectively a faster version of `json.loads()`. + + Arguments: + data: The JSON data to deserialize. + allow_inf_nan: Whether to allow `Infinity`, `-Infinity` and `NaN` values as `json.loads()` does by default. + cache_strings: Whether to cache strings to avoid constructing new Python objects, + this should have a significant impact on performance while increasing memory usage slightly. + + Raises: + ValueError: If deserialization fails. + + Returns: + The deserialized Python object. + """ + +def to_jsonable_python( + value: Any, + *, + include: _IncEx = None, + exclude: _IncEx = None, + by_alias: bool = True, + exclude_none: bool = False, + round_trip: bool = False, + timedelta_mode: Literal['iso8601', 'float'] = 'iso8601', + bytes_mode: Literal['utf8', 'base64'] = 'utf8', + inf_nan_mode: Literal['null', 'constants'] = 'constants', + serialize_unknown: bool = False, + fallback: Callable[[Any], Any] | None = None, +) -> Any: + """ + Serialize/marshal a Python object to a JSON-serializable Python object including transforming and filtering data. + + This is effectively a standalone version of + [`SchemaSerializer.to_python(mode='json')`][pydantic_core.SchemaSerializer.to_python]. + + Args: + value: The Python object to serialize. + include: A set of fields to include, if `None` all fields are included. + exclude: A set of fields to exclude, if `None` no fields are excluded. + by_alias: Whether to use the alias names of fields. + exclude_none: Whether to exclude fields that have a value of `None`. + round_trip: Whether to enable serialization and validation round-trip support. + timedelta_mode: How to serialize `timedelta` objects, either `'iso8601'` or `'float'`. + bytes_mode: How to serialize `bytes` objects, either `'utf8'` or `'base64'`. + inf_nan_mode: How to serialize `Infinity`, `-Infinity` and `NaN` values, either `'null'` or `'constants'`. + serialize_unknown: Attempt to serialize unknown types, `str(value)` will be used, if that fails + `""` will be used. + fallback: A function to call when an unknown value is encountered, + if `None` a [`PydanticSerializationError`][pydantic_core.PydanticSerializationError] error is raised. + + Raises: + PydanticSerializationError: If serialization fails and no `fallback` function is provided. + + Returns: + The serialized Python object. + """ + +class Url(SupportsAllComparisons): + """ + A URL type, internal logic uses the [url rust crate](https://docs.rs/url/latest/url/) originally developed + by Mozilla. 
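# A short sketch of the standalone `to_json`/`from_json` helpers documented above;
# the payload is illustrative.
from datetime import timedelta
from pydantic_core import from_json, to_json

raw = to_json({'delay': timedelta(minutes=1), 'tags': ('a', 'b')})
print(raw)                              # JSON bytes; the timedelta uses ISO 8601 by default
assert from_json(raw)['tags'] == ['a', 'b']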
+ """ + + def __new__(cls, url: str) -> Self: + """ + Create a new `Url` instance. + + Args: + url: String representation of a URL. + + Returns: + A new `Url` instance. + + Raises: + ValidationError: If the URL is invalid. + """ + @property + def scheme(self) -> str: + """ + The scheme part of the URL. + + e.g. `https` in `https://user:pass@host:port/path?query#fragment` + """ + @property + def username(self) -> str | None: + """ + The username part of the URL, or `None`. + + e.g. `user` in `https://user:pass@host:port/path?query#fragment` + """ + @property + def password(self) -> str | None: + """ + The password part of the URL, or `None`. + + e.g. `pass` in `https://user:pass@host:port/path?query#fragment` + """ + @property + def host(self) -> str | None: + """ + The host part of the URL, or `None`. + + If the URL must be punycode encoded, this is the encoded host, e.g if the input URL is `https://£££.com`, + `host` will be `xn--9aaa.com` + """ + def unicode_host(self) -> str | None: + """ + The host part of the URL as a unicode string, or `None`. + + e.g. `host` in `https://user:pass@host:port/path?query#fragment` + + If the URL must be punycode encoded, this is the decoded host, e.g if the input URL is `https://£££.com`, + `unicode_host()` will be `£££.com` + """ + @property + def port(self) -> int | None: + """ + The port part of the URL, or `None`. + + e.g. `port` in `https://user:pass@host:port/path?query#fragment` + """ + @property + def path(self) -> str | None: + """ + The path part of the URL, or `None`. + + e.g. `/path` in `https://user:pass@host:port/path?query#fragment` + """ + @property + def query(self) -> str | None: + """ + The query part of the URL, or `None`. + + e.g. `query` in `https://user:pass@host:port/path?query#fragment` + """ + def query_params(self) -> list[tuple[str, str]]: + """ + The query part of the URL as a list of key-value pairs. + + e.g. `[('foo', 'bar')]` in `https://user:pass@host:port/path?foo=bar#fragment` + """ + @property + def fragment(self) -> str | None: + """ + The fragment part of the URL, or `None`. + + e.g. `fragment` in `https://user:pass@host:port/path?query#fragment` + """ + def unicode_string(self) -> str: + """ + The URL as a unicode string, unlike `__str__()` this will not punycode encode the host. + + If the URL must be punycode encoded, this is the decoded string, e.g if the input URL is `https://£££.com`, + `unicode_string()` will be `https://£££.com` + """ + def __repr__(self) -> str: ... + def __str__(self) -> str: + """ + The URL as a string, this will punycode encode the host if required. + """ + def __deepcopy__(self, memo: dict) -> str: ... + @classmethod + def build( + cls, + *, + scheme: str, + username: Optional[str] = None, + password: Optional[str] = None, + host: str, + port: Optional[int] = None, + path: Optional[str] = None, + query: Optional[str] = None, + fragment: Optional[str] = None, + ) -> Self: + """ + Build a new `Url` instance from its component parts. + + Args: + scheme: The scheme part of the URL. + username: The username part of the URL, or omit for no username. + password: The password part of the URL, or omit for no password. + host: The host part of the URL. + port: The port part of the URL, or omit for no port. + path: The path part of the URL, or omit for no path. + query: The query part of the URL, or omit for no query. + fragment: The fragment part of the URL, or omit for no fragment. 
+ + Returns: + An instance of URL + """ + +class MultiHostUrl(SupportsAllComparisons): + """ + A URL type with support for multiple hosts, as used by some databases for DSNs, e.g. `https://foo.com,bar.com/path`. + + Internal URL logic uses the [url rust crate](https://docs.rs/url/latest/url/) originally developed + by Mozilla. + """ + + def __new__(cls, url: str) -> Self: + """ + Create a new `MultiHostUrl` instance. + + Args: + url: String representation of a URL. + + Returns: + A new `MultiHostUrl` instance. + + Raises: + ValidationError: If the URL is invalid. + """ + @property + def scheme(self) -> str: + """ + The scheme part of the URL. + + e.g. `https` in `https://foo.com,bar.com/path?query#fragment` + """ + @property + def path(self) -> str | None: + """ + The path part of the URL, or `None`. + + e.g. `/path` in `https://foo.com,bar.com/path?query#fragment` + """ + @property + def query(self) -> str | None: + """ + The query part of the URL, or `None`. + + e.g. `query` in `https://foo.com,bar.com/path?query#fragment` + """ + def query_params(self) -> list[tuple[str, str]]: + """ + The query part of the URL as a list of key-value pairs. + + e.g. `[('foo', 'bar')]` in `https://foo.com,bar.com/path?query#fragment` + """ + @property + def fragment(self) -> str | None: + """ + The fragment part of the URL, or `None`. + + e.g. `fragment` in `https://foo.com,bar.com/path?query#fragment` + """ + def hosts(self) -> list[MultiHostHost]: + ''' + + The hosts of the `MultiHostUrl` as [`MultiHostHost`][pydantic_core.MultiHostHost] typed dicts. + + ```py + from pydantic_core import MultiHostUrl + + mhu = MultiHostUrl('https://foo.com:123,foo:bar@bar.com/path') + print(mhu.hosts()) + """ + [ + {'username': None, 'password': None, 'host': 'foo.com', 'port': 123}, + {'username': 'foo', 'password': 'bar', 'host': 'bar.com', 'port': 443} + ] + ``` + Returns: + A list of dicts, each representing a host. + ''' + def unicode_string(self) -> str: + """ + The URL as a unicode string, unlike `__str__()` this will not punycode encode the hosts. + """ + def __repr__(self) -> str: ... + def __str__(self) -> str: + """ + The URL as a string, this will punycode encode the hosts if required. + """ + def __deepcopy__(self, memo: dict) -> Self: ... + @classmethod + def build( + cls, + *, + scheme: str, + hosts: Optional[list[MultiHostHost]] = None, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + path: Optional[str] = None, + query: Optional[str] = None, + fragment: Optional[str] = None, + ) -> Self: + """ + Build a new `MultiHostUrl` instance from its component parts. + + This method takes either `hosts` - a list of `MultiHostHost` typed dicts, or the individual components + `username`, `password`, `host` and `port`. + + Args: + scheme: The scheme part of the URL. + hosts: Multiple hosts to build the URL from. + username: The username part of the URL. + password: The password part of the URL. + host: The host part of the URL. + port: The port part of the URL. + path: The path part of the URL. + query: The query part of the URL, or omit for no query. + fragment: The fragment part of the URL, or omit for no fragment. + + Returns: + An instance of `MultiHostUrl` + """ + +@final +class SchemaError(Exception): + """ + Information about errors that occur while building a [`SchemaValidator`][pydantic_core.SchemaValidator] + or [`SchemaSerializer`][pydantic_core.SchemaSerializer]. 
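# A brief sketch of the `Url` class described above; the URL itself is illustrative.
from pydantic_core import Url

u = Url('https://user:pass@example.com:8080/path?q=1#frag')
assert u.scheme == 'https'
assert u.host == 'example.com'
assert u.port == 8080
assert u.query_params() == [('q', '1')]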
+ """ + + def error_count(self) -> int: + """ + Returns: + The number of errors in the schema. + """ + def errors(self) -> list[ErrorDetails]: + """ + Returns: + A list of [`ErrorDetails`][pydantic_core.ErrorDetails] for each error in the schema. + """ + +@final +class ValidationError(ValueError): + """ + `ValidationError` is the exception raised by `pydantic-core` when validation fails, it contains a list of errors + which detail why validation failed. + """ + + @staticmethod + def from_exception_data( + title: str, + line_errors: list[InitErrorDetails], + input_type: Literal['python', 'json'] = 'python', + hide_input: bool = False, + ) -> ValidationError: + """ + Python constructor for a Validation Error. + + The API for constructing validation errors will probably change in the future, + hence the static method rather than `__init__`. + + Arguments: + title: The title of the error, as used in the heading of `str(validation_error)` + line_errors: A list of [`InitErrorDetails`][pydantic_core.InitErrorDetails] which contain information + about errors that occurred during validation. + input_type: Whether the error is for a Python object or JSON. + hide_input: Whether to hide the input value in the error message. + """ + @property + def title(self) -> str: + """ + The title of the error, as used in the heading of `str(validation_error)`. + """ + def error_count(self) -> int: + """ + Returns: + The number of errors in the validation error. + """ + def errors( + self, *, include_url: bool = True, include_context: bool = True, include_input: bool = True + ) -> list[ErrorDetails]: + """ + Details about each error in the validation error. + + Args: + include_url: Whether to include a URL to documentation on the error each error. + include_context: Whether to include the context of each error. + include_input: Whether to include the input value of each error. + + Returns: + A list of [`ErrorDetails`][pydantic_core.ErrorDetails] for each error in the validation error. + """ + def json( + self, + *, + indent: int | None = None, + include_url: bool = True, + include_context: bool = True, + include_input: bool = True, + ) -> str: + """ + Same as [`errors()`][pydantic_core.ValidationError.errors] but returns a JSON string. + + Args: + indent: The number of spaces to indent the JSON by, or `None` for no indentation - compact JSON. + include_url: Whether to include a URL to documentation on the error each error. + include_context: Whether to include the context of each error. + include_input: Whether to include the input value of each error. + + Returns: + a JSON string. + """ + + def __repr__(self) -> str: + """ + A string representation of the validation error. + + Whether or not documentation URLs are included in the repr is controlled by the + environment variable `PYDANTIC_ERRORS_INCLUDE_URL` being set to `1` or + `true`; by default, URLs are shown. + + Due to implementation details, this environment variable can only be set once, + before the first validation error is created. + """ + +@final +class PydanticCustomError(ValueError): + def __new__( + cls, error_type: LiteralString, message_template: LiteralString, context: dict[str, Any] | None = None + ) -> Self: ... + @property + def context(self) -> dict[str, Any] | None: ... + @property + def type(self) -> str: ... + @property + def message_template(self) -> str: ... + def message(self) -> str: ... + +@final +class PydanticKnownError(ValueError): + def __new__(cls, error_type: ErrorType, context: dict[str, Any] | None = None) -> Self: ... 
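# Illustrative construction of the `PydanticCustomError` stubbed above; the error
# slug and message template are hypothetical.
from pydantic_core import PydanticCustomError

err = PydanticCustomError('not_a_thing', '{value} is not a thing', {'value': 42})
assert err.type == 'not_a_thing'
assert err.message() == '42 is not a thing'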
+ @property + def context(self) -> dict[str, Any] | None: ... + @property + def type(self) -> ErrorType: ... + @property + def message_template(self) -> str: ... + def message(self) -> str: ... + +@final +class PydanticOmit(Exception): + def __new__(cls) -> Self: ... + +@final +class PydanticUseDefault(Exception): + def __new__(cls) -> Self: ... + +@final +class PydanticSerializationError(ValueError): + def __new__(cls, message: str) -> Self: ... + +@final +class PydanticSerializationUnexpectedValue(ValueError): + def __new__(cls, message: str | None = None) -> Self: ... + +@final +class ArgsKwargs: + def __new__(cls, args: tuple[Any, ...], kwargs: dict[str, Any] | None = None) -> Self: ... + @property + def args(self) -> tuple[Any, ...]: ... + @property + def kwargs(self) -> dict[str, Any] | None: ... + +@final +class PydanticUndefinedType: + def __copy__(self) -> Self: ... + def __deepcopy__(self, memo: Any) -> Self: ... + +PydanticUndefined: PydanticUndefinedType + +def list_all_errors() -> list[ErrorTypeInfo]: + """ + Get information about all built-in errors. + + Returns: + A list of `ErrorTypeInfo` typed dicts. + """ +@final +class TzInfo(datetime.tzinfo): + def tzname(self, _dt: datetime.datetime | None) -> str | None: ... + def utcoffset(self, _dt: datetime.datetime | None) -> datetime.timedelta: ... + def dst(self, _dt: datetime.datetime | None) -> datetime.timedelta: ... + def fromutc(self, dt: datetime.datetime) -> datetime.datetime: ... + def __deepcopy__(self, _memo: dict[Any, Any]) -> 'TzInfo': ... + +def validate_core_schema(schema: CoreSchema, *, strict: bool | None = None) -> CoreSchema: + """Validate a CoreSchema + This currently uses lax mode for validation (i.e. will coerce strings to dates and such) + but may use strict mode in the future. + We may also remove this function altogether, do not rely on it being present if you are + using pydantic-core directly. + """ diff --git a/lib/pydantic_core/core_schema.py b/lib/pydantic_core/core_schema.py new file mode 100644 index 00000000..31bf4878 --- /dev/null +++ b/lib/pydantic_core/core_schema.py @@ -0,0 +1,3980 @@ +""" +This module contains definitions to build schemas which `pydantic_core` can +validate and serialize. +""" + +from __future__ import annotations as _annotations + +import sys +import warnings +from collections.abc import Mapping +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Set, Tuple, Type, Union + +from typing_extensions import deprecated + +if sys.version_info < (3, 12): + from typing_extensions import TypedDict +else: + from typing import TypedDict + +if sys.version_info < (3, 11): + from typing_extensions import Protocol, Required, TypeAlias +else: + from typing import Protocol, Required, TypeAlias + +if sys.version_info < (3, 9): + from typing_extensions import Literal +else: + from typing import Literal + +if TYPE_CHECKING: + from pydantic_core import PydanticUndefined +else: + # The initial build of pydantic_core requires PydanticUndefined to generate + # the core schema; so we need to conditionally skip it. mypy doesn't like + # this at all, hence the TYPE_CHECKING branch above. + try: + from pydantic_core import PydanticUndefined + except ImportError: + PydanticUndefined = object() + + +ExtraBehavior = Literal['allow', 'forbid', 'ignore'] + + +class CoreConfig(TypedDict, total=False): + """ + Base class for schema configuration options. + + Attributes: + title: The name of the configuration. 
+ strict: Whether the configuration should strictly adhere to specified rules. + extra_fields_behavior: The behavior for handling extra fields. + typed_dict_total: Whether the TypedDict should be considered total. Default is `True`. + from_attributes: Whether to use attributes for models, dataclasses, and tagged union keys. + loc_by_alias: Whether to use the used alias (or first alias for "field required" errors) instead of + `field_names` to construct error `loc`s. Default is `True`. + revalidate_instances: Whether instances of models and dataclasses should re-validate. Default is 'never'. + validate_default: Whether to validate default values during validation. Default is `False`. + populate_by_name: Whether an aliased field may be populated by its name as given by the model attribute, + as well as the alias. (Replaces 'allow_population_by_field_name' in Pydantic v1.) Default is `False`. + str_max_length: The maximum length for string fields. + str_min_length: The minimum length for string fields. + str_strip_whitespace: Whether to strip whitespace from string fields. + str_to_lower: Whether to convert string fields to lowercase. + str_to_upper: Whether to convert string fields to uppercase. + allow_inf_nan: Whether to allow infinity and NaN values for float fields. Default is `True`. + ser_json_timedelta: The serialization option for `timedelta` values. Default is 'iso8601'. + ser_json_bytes: The serialization option for `bytes` values. Default is 'utf8'. + ser_json_inf_nan: The serialization option for infinity and NaN values + in float fields. Default is 'null'. + hide_input_in_errors: Whether to hide input data from `ValidationError` representation. + validation_error_cause: Whether to add user-python excs to the __cause__ of a ValidationError. + Requires exceptiongroup backport pre Python 3.11. + coerce_numbers_to_str: Whether to enable coercion of any `Number` type to `str` (not applicable in `strict` mode). + regex_engine: The regex engine to use for regex pattern validation. Default is 'rust-regex'. See `StringSchema`. 
+ """ + + title: str + strict: bool + # settings related to typed dicts, model fields, dataclass fields + extra_fields_behavior: ExtraBehavior + typed_dict_total: bool # default: True + # used for models, dataclasses, and tagged union keys + from_attributes: bool + # whether to use the used alias (or first alias for "field required" errors) instead of field_names + # to construct error `loc`s, default True + loc_by_alias: bool + # whether instances of models and dataclasses (including subclass instances) should re-validate, default 'never' + revalidate_instances: Literal['always', 'never', 'subclass-instances'] + # whether to validate default values during validation, default False + validate_default: bool + # used on typed-dicts and arguments + populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1 + # fields related to string fields only + str_max_length: int + str_min_length: int + str_strip_whitespace: bool + str_to_lower: bool + str_to_upper: bool + # fields related to float fields only + allow_inf_nan: bool # default: True + # the config options are used to customise serialization to JSON + ser_json_timedelta: Literal['iso8601', 'float'] # default: 'iso8601' + ser_json_bytes: Literal['utf8', 'base64', 'hex'] # default: 'utf8' + ser_json_inf_nan: Literal['null', 'constants'] # default: 'null' + # used to hide input data from ValidationError repr + hide_input_in_errors: bool + validation_error_cause: bool # default: False + coerce_numbers_to_str: bool # default: False + regex_engine: Literal['rust-regex', 'python-re'] # default: 'rust-regex' + + +IncExCall: TypeAlias = 'set[int | str] | dict[int | str, IncExCall] | None' + + +class SerializationInfo(Protocol): + @property + def include(self) -> IncExCall: + ... + + @property + def exclude(self) -> IncExCall: + ... + + @property + def mode(self) -> str: + ... + + @property + def by_alias(self) -> bool: + ... + + @property + def exclude_unset(self) -> bool: + ... + + @property + def exclude_defaults(self) -> bool: + ... + + @property + def exclude_none(self) -> bool: + ... + + @property + def round_trip(self) -> bool: + ... + + def mode_is_json(self) -> bool: + ... + + def __str__(self) -> str: + ... + + def __repr__(self) -> str: + ... + + +class FieldSerializationInfo(SerializationInfo, Protocol): + @property + def field_name(self) -> str: + ... + + +class ValidationInfo(Protocol): + """ + Argument passed to validation functions. + """ + + @property + def context(self) -> Any | None: + """Current validation context.""" + ... + + @property + def config(self) -> CoreConfig | None: + """The CoreConfig that applies to this validation.""" + ... + + @property + def mode(self) -> Literal['python', 'json']: + """The type of input data we are currently validating""" + ... + + @property + def data(self) -> Dict[str, Any]: + """The data being validated for this model.""" + ... + + @property + def field_name(self) -> str | None: + """ + The name of the current field being validated if this validator is + attached to a model field. + """ + ... 
+ + +ExpectedSerializationTypes = Literal[ + 'none', + 'int', + 'bool', + 'float', + 'str', + 'bytes', + 'bytearray', + 'list', + 'tuple', + 'set', + 'frozenset', + 'generator', + 'dict', + 'datetime', + 'date', + 'time', + 'timedelta', + 'url', + 'multi-host-url', + 'json', + 'uuid', +] + + +class SimpleSerSchema(TypedDict, total=False): + type: Required[ExpectedSerializationTypes] + + +def simple_ser_schema(type: ExpectedSerializationTypes) -> SimpleSerSchema: + """ + Returns a schema for serialization with a custom type. + + Args: + type: The type to use for serialization + """ + return SimpleSerSchema(type=type) + + +# (__input_value: Any) -> Any +GeneralPlainNoInfoSerializerFunction = Callable[[Any], Any] +# (__input_value: Any, __info: FieldSerializationInfo) -> Any +GeneralPlainInfoSerializerFunction = Callable[[Any, SerializationInfo], Any] +# (__model: Any, __input_value: Any) -> Any +FieldPlainNoInfoSerializerFunction = Callable[[Any, Any], Any] +# (__model: Any, __input_value: Any, __info: FieldSerializationInfo) -> Any +FieldPlainInfoSerializerFunction = Callable[[Any, Any, FieldSerializationInfo], Any] +SerializerFunction = Union[ + GeneralPlainNoInfoSerializerFunction, + GeneralPlainInfoSerializerFunction, + FieldPlainNoInfoSerializerFunction, + FieldPlainInfoSerializerFunction, +] + +WhenUsed = Literal['always', 'unless-none', 'json', 'json-unless-none'] +""" +Values have the following meanings: + +* `'always'` means always use +* `'unless-none'` means use unless the value is `None` +* `'json'` means use when serializing to JSON +* `'json-unless-none'` means use when serializing to JSON and the value is not `None` +""" + + +class PlainSerializerFunctionSerSchema(TypedDict, total=False): + type: Required[Literal['function-plain']] + function: Required[SerializerFunction] + is_field_serializer: bool # default False + info_arg: bool # default False + return_schema: CoreSchema # if omitted, AnySchema is used + when_used: WhenUsed # default: 'always' + + +def plain_serializer_function_ser_schema( + function: SerializerFunction, + *, + is_field_serializer: bool | None = None, + info_arg: bool | None = None, + return_schema: CoreSchema | None = None, + when_used: WhenUsed = 'always', +) -> PlainSerializerFunctionSerSchema: + """ + Returns a schema for serialization with a function, can be either a "general" or "field" function. + + Args: + function: The function to use for serialization + is_field_serializer: Whether the serializer is for a field, e.g. takes `model` as the first argument, + and `info` includes `field_name` + info_arg: Whether the function takes an `__info` argument + return_schema: Schema to use for serializing return value + when_used: When the function should be called + """ + if when_used == 'always': + # just to avoid extra elements in schema, and to use the actual default defined in rust + when_used = None # type: ignore + return _dict_not_none( + type='function-plain', + function=function, + is_field_serializer=is_field_serializer, + info_arg=info_arg, + return_schema=return_schema, + when_used=when_used, + ) + + +class SerializerFunctionWrapHandler(Protocol): # pragma: no cover + def __call__(self, __input_value: Any, __index_key: int | str | None = None) -> Any: + ... 
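To show how the serialization schemas above attach to a value schema, here is a hedged sketch using `SchemaSerializer` (a pydantic-core class not defined in this file): a plain serializer function that only runs when dumping to JSON, per the `WhenUsed` values listed above.

```py
from pydantic_core import SchemaSerializer, core_schema

# Illustrative: serialize ints as hex strings, but only in JSON mode.
ser = core_schema.plain_serializer_function_ser_schema(lambda v: hex(v), when_used='json')
s = SchemaSerializer(core_schema.int_schema(serialization=ser))
assert s.to_python(255) == 255       # Python mode: serializer skipped
assert s.to_json(255) == b'"0xff"'   # JSON mode: serializer applied
```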
+ + +# (__input_value: Any, __serializer: SerializerFunctionWrapHandler) -> Any +GeneralWrapNoInfoSerializerFunction = Callable[[Any, SerializerFunctionWrapHandler], Any] +# (__input_value: Any, __serializer: SerializerFunctionWrapHandler, __info: SerializationInfo) -> Any +GeneralWrapInfoSerializerFunction = Callable[[Any, SerializerFunctionWrapHandler, SerializationInfo], Any] +# (__model: Any, __input_value: Any, __serializer: SerializerFunctionWrapHandler) -> Any +FieldWrapNoInfoSerializerFunction = Callable[[Any, Any, SerializerFunctionWrapHandler], Any] +# (__model: Any, __input_value: Any, __serializer: SerializerFunctionWrapHandler, __info: FieldSerializationInfo) -> Any +FieldWrapInfoSerializerFunction = Callable[[Any, Any, SerializerFunctionWrapHandler, FieldSerializationInfo], Any] +WrapSerializerFunction = Union[ + GeneralWrapNoInfoSerializerFunction, + GeneralWrapInfoSerializerFunction, + FieldWrapNoInfoSerializerFunction, + FieldWrapInfoSerializerFunction, +] + + +class WrapSerializerFunctionSerSchema(TypedDict, total=False): + type: Required[Literal['function-wrap']] + function: Required[WrapSerializerFunction] + is_field_serializer: bool # default False + info_arg: bool # default False + schema: CoreSchema # if omitted, the schema on which this serializer is defined is used + return_schema: CoreSchema # if omitted, AnySchema is used + when_used: WhenUsed # default: 'always' + + +def wrap_serializer_function_ser_schema( + function: WrapSerializerFunction, + *, + is_field_serializer: bool | None = None, + info_arg: bool | None = None, + schema: CoreSchema | None = None, + return_schema: CoreSchema | None = None, + when_used: WhenUsed = 'always', +) -> WrapSerializerFunctionSerSchema: + """ + Returns a schema for serialization with a wrap function, can be either a "general" or "field" function. + + Args: + function: The function to use for serialization + is_field_serializer: Whether the serializer is for a field, e.g. takes `model` as the first argument, + and `info` includes `field_name` + info_arg: Whether the function takes an `__info` argument + schema: The schema to use for the inner serialization + return_schema: Schema to use for serializing return value + when_used: When the function should be called + """ + if when_used == 'always': + # just to avoid extra elements in schema, and to use the actual default defined in rust + when_used = None # type: ignore + return _dict_not_none( + type='function-wrap', + function=function, + is_field_serializer=is_field_serializer, + info_arg=info_arg, + schema=schema, + return_schema=return_schema, + when_used=when_used, + ) + + +class FormatSerSchema(TypedDict, total=False): + type: Required[Literal['format']] + formatting_string: Required[str] + when_used: WhenUsed # default: 'json-unless-none' + + +def format_ser_schema(formatting_string: str, *, when_used: WhenUsed = 'json-unless-none') -> FormatSerSchema: + """ + Returns a schema for serialization using python's `format` method. 
+ + Args: + formatting_string: String defining the format to use + when_used: Same meaning as for [general_function_plain_ser_schema], but with a different default + """ + if when_used == 'json-unless-none': + # just to avoid extra elements in schema, and to use the actual default defined in rust + when_used = None # type: ignore + return _dict_not_none(type='format', formatting_string=formatting_string, when_used=when_used) + + +class ToStringSerSchema(TypedDict, total=False): + type: Required[Literal['to-string']] + when_used: WhenUsed # default: 'json-unless-none' + + +def to_string_ser_schema(*, when_used: WhenUsed = 'json-unless-none') -> ToStringSerSchema: + """ + Returns a schema for serialization using python's `str()` / `__str__` method. + + Args: + when_used: Same meaning as for [general_function_plain_ser_schema], but with a different default + """ + s = dict(type='to-string') + if when_used != 'json-unless-none': + # just to avoid extra elements in schema, and to use the actual default defined in rust + s['when_used'] = when_used + return s # type: ignore + + +class ModelSerSchema(TypedDict, total=False): + type: Required[Literal['model']] + cls: Required[Type[Any]] + schema: Required[CoreSchema] + + +def model_ser_schema(cls: Type[Any], schema: CoreSchema) -> ModelSerSchema: + """ + Returns a schema for serialization using a model. + + Args: + cls: The expected class type, used to generate warnings if the wrong type is passed + schema: Internal schema to use to serialize the model dict + """ + return ModelSerSchema(type='model', cls=cls, schema=schema) + + +SerSchema = Union[ + SimpleSerSchema, + PlainSerializerFunctionSerSchema, + WrapSerializerFunctionSerSchema, + FormatSerSchema, + ToStringSerSchema, + ModelSerSchema, +] + + +class ComputedField(TypedDict, total=False): + type: Required[Literal['computed-field']] + property_name: Required[str] + return_schema: Required[CoreSchema] + alias: str + metadata: Any + + +def computed_field( + property_name: str, return_schema: CoreSchema, *, alias: str | None = None, metadata: Any = None +) -> ComputedField: + """ + ComputedFields are properties of a model or dataclass that are included in serialization. 
+ + Args: + property_name: The name of the property on the model or dataclass + return_schema: The schema used for the type returned by the computed field + alias: The name to use in the serialized output + metadata: Any other information you want to include with the schema, not used by pydantic-core + """ + return _dict_not_none( + type='computed-field', property_name=property_name, return_schema=return_schema, alias=alias, metadata=metadata + ) + + +class AnySchema(TypedDict, total=False): + type: Required[Literal['any']] + ref: str + metadata: Any + serialization: SerSchema + + +def any_schema(*, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None) -> AnySchema: + """ + Returns a schema that matches any value, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.any_schema() + v = SchemaValidator(schema) + assert v.validate_python(1) == 1 + ``` + + Args: + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none(type='any', ref=ref, metadata=metadata, serialization=serialization) + + +class NoneSchema(TypedDict, total=False): + type: Required[Literal['none']] + ref: str + metadata: Any + serialization: SerSchema + + +def none_schema(*, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None) -> NoneSchema: + """ + Returns a schema that matches a None value, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.none_schema() + v = SchemaValidator(schema) + assert v.validate_python(None) is None + ``` + + Args: + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none(type='none', ref=ref, metadata=metadata, serialization=serialization) + + +class BoolSchema(TypedDict, total=False): + type: Required[Literal['bool']] + strict: bool + ref: str + metadata: Any + serialization: SerSchema + + +def bool_schema( + strict: bool | None = None, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None +) -> BoolSchema: + """ + Returns a schema that matches a bool value, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.bool_schema() + v = SchemaValidator(schema) + assert v.validate_python('True') is True + ``` + + Args: + strict: Whether the value should be a bool or a value that can be converted to a bool + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none(type='bool', strict=strict, ref=ref, metadata=metadata, serialization=serialization) + + +class IntSchema(TypedDict, total=False): + type: Required[Literal['int']] + multiple_of: int + le: int + ge: int + lt: int + gt: int + strict: bool + ref: str + metadata: Any + serialization: SerSchema + + +def int_schema( + *, + multiple_of: int | None = None, + le: int | None = None, + ge: int | None = None, + lt: int | None = None, + gt: int | None = None, + strict: bool | None = None, + ref: str | None = None, + 
metadata: Any = None, + serialization: SerSchema | None = None, +) -> IntSchema: + """ + Returns a schema that matches a int value, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.int_schema(multiple_of=2, le=6, ge=2) + v = SchemaValidator(schema) + assert v.validate_python('4') == 4 + ``` + + Args: + multiple_of: The value must be a multiple of this number + le: The value must be less than or equal to this number + ge: The value must be greater than or equal to this number + lt: The value must be strictly less than this number + gt: The value must be strictly greater than this number + strict: Whether the value should be a int or a value that can be converted to a int + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='int', + multiple_of=multiple_of, + le=le, + ge=ge, + lt=lt, + gt=gt, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class FloatSchema(TypedDict, total=False): + type: Required[Literal['float']] + allow_inf_nan: bool # whether 'NaN', '+inf', '-inf' should be forbidden. default: True + multiple_of: float + le: float + ge: float + lt: float + gt: float + strict: bool + ref: str + metadata: Any + serialization: SerSchema + + +def float_schema( + *, + allow_inf_nan: bool | None = None, + multiple_of: float | None = None, + le: float | None = None, + ge: float | None = None, + lt: float | None = None, + gt: float | None = None, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> FloatSchema: + """ + Returns a schema that matches a float value, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.float_schema(le=0.8, ge=0.2) + v = SchemaValidator(schema) + assert v.validate_python('0.5') == 0.5 + ``` + + Args: + allow_inf_nan: Whether to allow inf and nan values + multiple_of: The value must be a multiple of this number + le: The value must be less than or equal to this number + ge: The value must be greater than or equal to this number + lt: The value must be strictly less than this number + gt: The value must be strictly greater than this number + strict: Whether the value should be a float or a value that can be converted to a float + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='float', + allow_inf_nan=allow_inf_nan, + multiple_of=multiple_of, + le=le, + ge=ge, + lt=lt, + gt=gt, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class DecimalSchema(TypedDict, total=False): + type: Required[Literal['decimal']] + allow_inf_nan: bool # whether 'NaN', '+inf', '-inf' should be forbidden. 
default: False + multiple_of: Decimal + le: Decimal + ge: Decimal + lt: Decimal + gt: Decimal + max_digits: int + decimal_places: int + strict: bool + ref: str + metadata: Any + serialization: SerSchema + + +def decimal_schema( + *, + allow_inf_nan: bool = None, + multiple_of: Decimal | None = None, + le: Decimal | None = None, + ge: Decimal | None = None, + lt: Decimal | None = None, + gt: Decimal | None = None, + max_digits: int | None = None, + decimal_places: int | None = None, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> DecimalSchema: + """ + Returns a schema that matches a decimal value, e.g.: + + ```py + from decimal import Decimal + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.decimal_schema(le=0.8, ge=0.2) + v = SchemaValidator(schema) + assert v.validate_python('0.5') == Decimal('0.5') + ``` + + Args: + allow_inf_nan: Whether to allow inf and nan values + multiple_of: The value must be a multiple of this number + le: The value must be less than or equal to this number + ge: The value must be greater than or equal to this number + lt: The value must be strictly less than this number + gt: The value must be strictly greater than this number + max_digits: The maximum number of decimal digits allowed + decimal_places: The maximum number of decimal places allowed + strict: Whether the value should be a float or a value that can be converted to a float + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='decimal', + gt=gt, + ge=ge, + lt=lt, + le=le, + max_digits=max_digits, + decimal_places=decimal_places, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class StringSchema(TypedDict, total=False): + type: Required[Literal['str']] + pattern: str + max_length: int + min_length: int + strip_whitespace: bool + to_lower: bool + to_upper: bool + regex_engine: Literal['rust-regex', 'python-re'] # default: 'rust-regex' + strict: bool + ref: str + metadata: Any + serialization: SerSchema + + +def str_schema( + *, + pattern: str | None = None, + max_length: int | None = None, + min_length: int | None = None, + strip_whitespace: bool | None = None, + to_lower: bool | None = None, + to_upper: bool | None = None, + regex_engine: Literal['rust-regex', 'python-re'] | None = None, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> StringSchema: + """ + Returns a schema that matches a string value, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.str_schema(max_length=10, min_length=2) + v = SchemaValidator(schema) + assert v.validate_python('hello') == 'hello' + ``` + + Args: + pattern: A regex pattern that the value must match + max_length: The value must be at most this length + min_length: The value must be at least this length + strip_whitespace: Whether to strip whitespace from the value + to_lower: Whether to convert the value to lowercase + to_upper: Whether to convert the value to uppercase + regex_engine: The regex engine to use for pattern validation. Default is 'rust-regex'. 
+ - `rust-regex` uses the [`regex`](https://docs.rs/regex) Rust + crate, which is non-backtracking and therefore more DDoS + resistant, but does not support all regex features. + - `python-re` use the [`re`](https://docs.python.org/3/library/re.html) module, + which supports all regex features, but may be slower. + strict: Whether the value should be a string or a value that can be converted to a string + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='str', + pattern=pattern, + max_length=max_length, + min_length=min_length, + strip_whitespace=strip_whitespace, + to_lower=to_lower, + to_upper=to_upper, + regex_engine=regex_engine, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class BytesSchema(TypedDict, total=False): + type: Required[Literal['bytes']] + max_length: int + min_length: int + strict: bool + ref: str + metadata: Any + serialization: SerSchema + + +def bytes_schema( + *, + max_length: int | None = None, + min_length: int | None = None, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> BytesSchema: + """ + Returns a schema that matches a bytes value, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.bytes_schema(max_length=10, min_length=2) + v = SchemaValidator(schema) + assert v.validate_python(b'hello') == b'hello' + ``` + + Args: + max_length: The value must be at most this length + min_length: The value must be at least this length + strict: Whether the value should be a bytes or a value that can be converted to a bytes + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='bytes', + max_length=max_length, + min_length=min_length, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class DateSchema(TypedDict, total=False): + type: Required[Literal['date']] + strict: bool + le: date + ge: date + lt: date + gt: date + now_op: Literal['past', 'future'] + # defaults to current local utc offset from `time.localtime().tm_gmtoff` + # value is restricted to -86_400 < offset < 86_400 by bounds in generate_self_schema.py + now_utc_offset: int + ref: str + metadata: Any + serialization: SerSchema + + +def date_schema( + *, + strict: bool | None = None, + le: date | None = None, + ge: date | None = None, + lt: date | None = None, + gt: date | None = None, + now_op: Literal['past', 'future'] | None = None, + now_utc_offset: int | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> DateSchema: + """ + Returns a schema that matches a date value, e.g.: + + ```py + from datetime import date + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.date_schema(le=date(2020, 1, 1), ge=date(2019, 1, 1)) + v = SchemaValidator(schema) + assert v.validate_python(date(2019, 6, 1)) == date(2019, 6, 1) + ``` + + Args: + strict: Whether the value should be a date or a value that can be converted to a date + le: The value must be less than or equal to this date + ge: The value 
must be greater than or equal to this date + lt: The value must be strictly less than this date + gt: The value must be strictly greater than this date + now_op: The value must be in the past or future relative to the current date + now_utc_offset: The value must be in the past or future relative to the current date with this utc offset + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='date', + strict=strict, + le=le, + ge=ge, + lt=lt, + gt=gt, + now_op=now_op, + now_utc_offset=now_utc_offset, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class TimeSchema(TypedDict, total=False): + type: Required[Literal['time']] + strict: bool + le: time + ge: time + lt: time + gt: time + tz_constraint: Union[Literal['aware', 'naive'], int] + microseconds_precision: Literal['truncate', 'error'] + ref: str + metadata: Any + serialization: SerSchema + + +def time_schema( + *, + strict: bool | None = None, + le: time | None = None, + ge: time | None = None, + lt: time | None = None, + gt: time | None = None, + tz_constraint: Literal['aware', 'naive'] | int | None = None, + microseconds_precision: Literal['truncate', 'error'] = 'truncate', + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> TimeSchema: + """ + Returns a schema that matches a time value, e.g.: + + ```py + from datetime import time + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.time_schema(le=time(12, 0, 0), ge=time(6, 0, 0)) + v = SchemaValidator(schema) + assert v.validate_python(time(9, 0, 0)) == time(9, 0, 0) + ``` + + Args: + strict: Whether the value should be a time or a value that can be converted to a time + le: The value must be less than or equal to this time + ge: The value must be greater than or equal to this time + lt: The value must be strictly less than this time + gt: The value must be strictly greater than this time + tz_constraint: The value must be timezone aware or naive, or an int to indicate required tz offset + microseconds_precision: The behavior when seconds have more than 6 digits or microseconds is too large + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='time', + strict=strict, + le=le, + ge=ge, + lt=lt, + gt=gt, + tz_constraint=tz_constraint, + microseconds_precision=microseconds_precision, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class DatetimeSchema(TypedDict, total=False): + type: Required[Literal['datetime']] + strict: bool + le: datetime + ge: datetime + lt: datetime + gt: datetime + now_op: Literal['past', 'future'] + tz_constraint: Union[Literal['aware', 'naive'], int] + # defaults to current local utc offset from `time.localtime().tm_gmtoff` + # value is restricted to -86_400 < offset < 86_400 by bounds in generate_self_schema.py + now_utc_offset: int + microseconds_precision: Literal['truncate', 'error'] # default: 'truncate' + ref: str + metadata: Any + serialization: SerSchema + + +def datetime_schema( + *, + strict: bool | None = None, + le: datetime | None = None, + ge: datetime | None = None, + lt: datetime | None = None, 
+ gt: datetime | None = None, + now_op: Literal['past', 'future'] | None = None, + tz_constraint: Literal['aware', 'naive'] | int | None = None, + now_utc_offset: int | None = None, + microseconds_precision: Literal['truncate', 'error'] = 'truncate', + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> DatetimeSchema: + """ + Returns a schema that matches a datetime value, e.g.: + + ```py + from datetime import datetime + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.datetime_schema() + v = SchemaValidator(schema) + now = datetime.now() + assert v.validate_python(str(now)) == now + ``` + + Args: + strict: Whether the value should be a datetime or a value that can be converted to a datetime + le: The value must be less than or equal to this datetime + ge: The value must be greater than or equal to this datetime + lt: The value must be strictly less than this datetime + gt: The value must be strictly greater than this datetime + now_op: The value must be in the past or future relative to the current datetime + tz_constraint: The value must be timezone aware or naive, or an int to indicate required tz offset + TODO: use of a tzinfo where offset changes based on the datetime is not yet supported + now_utc_offset: The value must be in the past or future relative to the current datetime with this utc offset + microseconds_precision: The behavior when seconds have more than 6 digits or microseconds is too large + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='datetime', + strict=strict, + le=le, + ge=ge, + lt=lt, + gt=gt, + now_op=now_op, + tz_constraint=tz_constraint, + now_utc_offset=now_utc_offset, + microseconds_precision=microseconds_precision, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class TimedeltaSchema(TypedDict, total=False): + type: Required[Literal['timedelta']] + strict: bool + le: timedelta + ge: timedelta + lt: timedelta + gt: timedelta + microseconds_precision: Literal['truncate', 'error'] + ref: str + metadata: Any + serialization: SerSchema + + +def timedelta_schema( + *, + strict: bool | None = None, + le: timedelta | None = None, + ge: timedelta | None = None, + lt: timedelta | None = None, + gt: timedelta | None = None, + microseconds_precision: Literal['truncate', 'error'] = 'truncate', + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> TimedeltaSchema: + """ + Returns a schema that matches a timedelta value, e.g.: + + ```py + from datetime import timedelta + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.timedelta_schema(le=timedelta(days=1), ge=timedelta(days=0)) + v = SchemaValidator(schema) + assert v.validate_python(timedelta(hours=12)) == timedelta(hours=12) + ``` + + Args: + strict: Whether the value should be a timedelta or a value that can be converted to a timedelta + le: The value must be less than or equal to this timedelta + ge: The value must be greater than or equal to this timedelta + lt: The value must be strictly less than this timedelta + gt: The value must be strictly greater than this timedelta + microseconds_precision: The behavior when seconds have more than 6 digits or microseconds is too large + ref: optional unique identifier of the schema, 
used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='timedelta', + strict=strict, + le=le, + ge=ge, + lt=lt, + gt=gt, + microseconds_precision=microseconds_precision, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class LiteralSchema(TypedDict, total=False): + type: Required[Literal['literal']] + expected: Required[List[Any]] + ref: str + metadata: Any + serialization: SerSchema + + +def literal_schema( + expected: list[Any], *, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None +) -> LiteralSchema: + """ + Returns a schema that matches a literal value, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.literal_schema(['hello', 'world']) + v = SchemaValidator(schema) + assert v.validate_python('hello') == 'hello' + ``` + + Args: + expected: The value must be one of these values + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none(type='literal', expected=expected, ref=ref, metadata=metadata, serialization=serialization) + + +# must match input/parse_json.rs::JsonType::try_from +JsonType = Literal['null', 'bool', 'int', 'float', 'str', 'list', 'dict'] + + +class IsInstanceSchema(TypedDict, total=False): + type: Required[Literal['is-instance']] + cls: Required[Any] + cls_repr: str + ref: str + metadata: Any + serialization: SerSchema + + +def is_instance_schema( + cls: Any, + *, + cls_repr: str | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> IsInstanceSchema: + """ + Returns a schema that checks if a value is an instance of a class, equivalent to python's `isinstance` method, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + class A: + pass + + schema = core_schema.is_instance_schema(cls=A) + v = SchemaValidator(schema) + v.validate_python(A()) + ``` + + Args: + cls: The value must be an instance of this class + cls_repr: If provided this string is used in the validator name instead of `repr(cls)` + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='is-instance', cls=cls, cls_repr=cls_repr, ref=ref, metadata=metadata, serialization=serialization + ) + + +class IsSubclassSchema(TypedDict, total=False): + type: Required[Literal['is-subclass']] + cls: Required[Type[Any]] + cls_repr: str + ref: str + metadata: Any + serialization: SerSchema + + +def is_subclass_schema( + cls: Type[Any], + *, + cls_repr: str | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> IsInstanceSchema: + """ + Returns a schema that checks if a value is a subtype of a class, equivalent to python's `issubclass` method, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + class A: + pass + + class B(A): + pass + + schema = core_schema.is_subclass_schema(cls=A) + v = SchemaValidator(schema) + v.validate_python(B) + ``` + + Args: + cls: The value must 
be a subclass of this class + cls_repr: If provided this string is used in the validator name instead of `repr(cls)` + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='is-subclass', cls=cls, cls_repr=cls_repr, ref=ref, metadata=metadata, serialization=serialization + ) + + +class CallableSchema(TypedDict, total=False): + type: Required[Literal['callable']] + ref: str + metadata: Any + serialization: SerSchema + + +def callable_schema( + *, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None +) -> CallableSchema: + """ + Returns a schema that checks if a value is callable, equivalent to python's `callable` method, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.callable_schema() + v = SchemaValidator(schema) + v.validate_python(min) + ``` + + Args: + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none(type='callable', ref=ref, metadata=metadata, serialization=serialization) + + +class UuidSchema(TypedDict, total=False): + type: Required[Literal['uuid']] + version: Literal[1, 3, 4, 5] + strict: bool + ref: str + metadata: Any + serialization: SerSchema + + +def uuid_schema( + *, + version: Literal[1, 3, 4, 5] | None = None, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> UuidSchema: + return _dict_not_none( + type='uuid', version=version, strict=strict, ref=ref, metadata=metadata, serialization=serialization + ) + + +class IncExSeqSerSchema(TypedDict, total=False): + type: Required[Literal['include-exclude-sequence']] + include: Set[int] + exclude: Set[int] + + +def filter_seq_schema(*, include: Set[int] | None = None, exclude: Set[int] | None = None) -> IncExSeqSerSchema: + return _dict_not_none(type='include-exclude-sequence', include=include, exclude=exclude) + + +IncExSeqOrElseSerSchema = Union[IncExSeqSerSchema, SerSchema] + + +class ListSchema(TypedDict, total=False): + type: Required[Literal['list']] + items_schema: CoreSchema + min_length: int + max_length: int + strict: bool + ref: str + metadata: Any + serialization: IncExSeqOrElseSerSchema + + +def list_schema( + items_schema: CoreSchema | None = None, + *, + min_length: int | None = None, + max_length: int | None = None, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: IncExSeqOrElseSerSchema | None = None, +) -> ListSchema: + """ + Returns a schema that matches a list value, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.list_schema(core_schema.int_schema(), min_length=0, max_length=10) + v = SchemaValidator(schema) + assert v.validate_python(['4']) == [4] + ``` + + Args: + items_schema: The value must be a list of items that match this schema + min_length: The value must be a list with at least this many items + max_length: The value must be a list with at most this many items + strict: The value must be a list with exactly this many items + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: 
Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='list', + items_schema=items_schema, + min_length=min_length, + max_length=max_length, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +# @deprecated('tuple_positional_schema is deprecated. Use pydantic_core.core_schema.tuple_schema instead.') +def tuple_positional_schema( + items_schema: list[CoreSchema], + *, + extras_schema: CoreSchema | None = None, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: IncExSeqOrElseSerSchema | None = None, +) -> TupleSchema: + """ + Returns a schema that matches a tuple of schemas, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.tuple_positional_schema( + [core_schema.int_schema(), core_schema.str_schema()] + ) + v = SchemaValidator(schema) + assert v.validate_python((1, 'hello')) == (1, 'hello') + ``` + + Args: + items_schema: The value must be a tuple with items that match these schemas + extras_schema: The value must be a tuple with items that match this schema + This was inspired by JSON schema's `prefixItems` and `items` fields. + In python's `typing.Tuple`, you can't specify a type for "extra" items -- they must all be the same type + if the length is variable. So this field won't be set from a `typing.Tuple` annotation on a pydantic model. + strict: The value must be a tuple with exactly this many items + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + if extras_schema is not None: + variadic_item_index = len(items_schema) + items_schema = items_schema + [extras_schema] + else: + variadic_item_index = None + return tuple_schema( + items_schema=items_schema, + variadic_item_index=variadic_item_index, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +# @deprecated('tuple_variable_schema is deprecated. 
Use pydantic_core.core_schema.tuple_schema instead.') +def tuple_variable_schema( + items_schema: CoreSchema | None = None, + *, + min_length: int | None = None, + max_length: int | None = None, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: IncExSeqOrElseSerSchema | None = None, +) -> TupleSchema: + """ + Returns a schema that matches a tuple of a given schema, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.tuple_variable_schema( + items_schema=core_schema.int_schema(), min_length=0, max_length=10 + ) + v = SchemaValidator(schema) + assert v.validate_python(('1', 2, 3)) == (1, 2, 3) + ``` + + Args: + items_schema: The value must be a tuple with items that match this schema + min_length: The value must be a tuple with at least this many items + max_length: The value must be a tuple with at most this many items + strict: The value must be a tuple with exactly this many items + ref: Optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return tuple_schema( + items_schema=[items_schema or any_schema()], + variadic_item_index=0, + min_length=min_length, + max_length=max_length, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class TupleSchema(TypedDict, total=False): + type: Required[Literal['tuple']] + items_schema: Required[List[CoreSchema]] + variadic_item_index: int + min_length: int + max_length: int + strict: bool + ref: str + metadata: Any + serialization: IncExSeqOrElseSerSchema + + +def tuple_schema( + items_schema: list[CoreSchema], + *, + variadic_item_index: int | None = None, + min_length: int | None = None, + max_length: int | None = None, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: IncExSeqOrElseSerSchema | None = None, +) -> TupleSchema: + """ + Returns a schema that matches a tuple of schemas, with an optional variadic item, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.tuple_schema( + [core_schema.int_schema(), core_schema.str_schema(), core_schema.float_schema()], + variadic_item_index=1, + ) + v = SchemaValidator(schema) + assert v.validate_python((1, 'hello', 'world', 1.5)) == (1, 'hello', 'world', 1.5) + ``` + + Args: + items_schema: The value must be a tuple with items that match these schemas + variadic_item_index: The index of the schema in `items_schema` to be treated as variadic (following PEP 646) + min_length: The value must be a tuple with at least this many items + max_length: The value must be a tuple with at most this many items + strict: The value must be a tuple with exactly this many items + ref: Optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='tuple', + items_schema=items_schema, + variadic_item_index=variadic_item_index, + min_length=min_length, + max_length=max_length, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class SetSchema(TypedDict, total=False): + type: Required[Literal['set']] + items_schema: CoreSchema + min_length: int + max_length: int + strict: bool + ref: str + 
metadata: Any + serialization: SerSchema + + +def set_schema( + items_schema: CoreSchema | None = None, + *, + min_length: int | None = None, + max_length: int | None = None, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> SetSchema: + """ + Returns a schema that matches a set of a given schema, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.set_schema( + items_schema=core_schema.int_schema(), min_length=0, max_length=10 + ) + v = SchemaValidator(schema) + assert v.validate_python({1, '2', 3}) == {1, 2, 3} + ``` + + Args: + items_schema: The value must be a set with items that match this schema + min_length: The value must be a set with at least this many items + max_length: The value must be a set with at most this many items + strict: The value must be a set with exactly this many items + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='set', + items_schema=items_schema, + min_length=min_length, + max_length=max_length, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class FrozenSetSchema(TypedDict, total=False): + type: Required[Literal['frozenset']] + items_schema: CoreSchema + min_length: int + max_length: int + strict: bool + ref: str + metadata: Any + serialization: SerSchema + + +def frozenset_schema( + items_schema: CoreSchema | None = None, + *, + min_length: int | None = None, + max_length: int | None = None, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> FrozenSetSchema: + """ + Returns a schema that matches a frozenset of a given schema, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.frozenset_schema( + items_schema=core_schema.int_schema(), min_length=0, max_length=10 + ) + v = SchemaValidator(schema) + assert v.validate_python(frozenset(range(3))) == frozenset({0, 1, 2}) + ``` + + Args: + items_schema: The value must be a frozenset with items that match this schema + min_length: The value must be a frozenset with at least this many items + max_length: The value must be a frozenset with at most this many items + strict: The value must be a frozenset with exactly this many items + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='frozenset', + items_schema=items_schema, + min_length=min_length, + max_length=max_length, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class GeneratorSchema(TypedDict, total=False): + type: Required[Literal['generator']] + items_schema: CoreSchema + min_length: int + max_length: int + ref: str + metadata: Any + serialization: IncExSeqOrElseSerSchema + + +def generator_schema( + items_schema: CoreSchema | None = None, + *, + min_length: int | None = None, + max_length: int | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: IncExSeqOrElseSerSchema | None = None, +) -> GeneratorSchema: + """ + Returns a schema that matches a generator value, 
e.g.: + + ```py + from typing import Iterator + from pydantic_core import SchemaValidator, core_schema + + def gen() -> Iterator[int]: + yield 1 + + schema = core_schema.generator_schema(items_schema=core_schema.int_schema()) + v = SchemaValidator(schema) + v.validate_python(gen()) + ``` + + Unlike other types, validated generators do not raise ValidationErrors eagerly, + but instead will raise a ValidationError when a violating value is actually read from the generator. + This is to ensure that "validated" generators retain the benefit of lazy evaluation. + + Args: + items_schema: The value must be a generator with items that match this schema + min_length: The value must be a generator that yields at least this many items + max_length: The value must be a generator that yields at most this many items + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='generator', + items_schema=items_schema, + min_length=min_length, + max_length=max_length, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +IncExDict = Set[Union[int, str]] + + +class IncExDictSerSchema(TypedDict, total=False): + type: Required[Literal['include-exclude-dict']] + include: IncExDict + exclude: IncExDict + + +def filter_dict_schema(*, include: IncExDict | None = None, exclude: IncExDict | None = None) -> IncExDictSerSchema: + return _dict_not_none(type='include-exclude-dict', include=include, exclude=exclude) + + +IncExDictOrElseSerSchema = Union[IncExDictSerSchema, SerSchema] + + +class DictSchema(TypedDict, total=False): + type: Required[Literal['dict']] + keys_schema: CoreSchema # default: AnySchema + values_schema: CoreSchema # default: AnySchema + min_length: int + max_length: int + strict: bool + ref: str + metadata: Any + serialization: IncExDictOrElseSerSchema + + +def dict_schema( + keys_schema: CoreSchema | None = None, + values_schema: CoreSchema | None = None, + *, + min_length: int | None = None, + max_length: int | None = None, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> DictSchema: + """ + Returns a schema that matches a dict value, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.dict_schema( + keys_schema=core_schema.str_schema(), values_schema=core_schema.int_schema() + ) + v = SchemaValidator(schema) + assert v.validate_python({'a': '1', 'b': 2}) == {'a': 1, 'b': 2} + ``` + + Args: + keys_schema: The value must be a dict with keys that match this schema + values_schema: The value must be a dict with values that match this schema + min_length: The value must be a dict with at least this many items + max_length: The value must be a dict with at most this many items + strict: Whether the keys and values should be validated with strict mode + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='dict', + keys_schema=keys_schema, + values_schema=values_schema, + min_length=min_length, + max_length=max_length, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +# (__input_value: Any) -> Any 
+NoInfoValidatorFunction = Callable[[Any], Any] + + +class NoInfoValidatorFunctionSchema(TypedDict): + type: Literal['no-info'] + function: NoInfoValidatorFunction + + +# (__input_value: Any, __info: ValidationInfo) -> Any +WithInfoValidatorFunction = Callable[[Any, ValidationInfo], Any] + + +class WithInfoValidatorFunctionSchema(TypedDict, total=False): + type: Required[Literal['with-info']] + function: Required[WithInfoValidatorFunction] + field_name: str + + +ValidationFunction = Union[NoInfoValidatorFunctionSchema, WithInfoValidatorFunctionSchema] + + +class _ValidatorFunctionSchema(TypedDict, total=False): + function: Required[ValidationFunction] + schema: Required[CoreSchema] + ref: str + metadata: Any + serialization: SerSchema + + +class BeforeValidatorFunctionSchema(_ValidatorFunctionSchema, total=False): + type: Required[Literal['function-before']] + + +def no_info_before_validator_function( + function: NoInfoValidatorFunction, + schema: CoreSchema, + *, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> BeforeValidatorFunctionSchema: + """ + Returns a schema that calls a validator function before validating, no `info` argument is provided, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + def fn(v: bytes) -> str: + return v.decode() + 'world' + + func_schema = core_schema.no_info_before_validator_function( + function=fn, schema=core_schema.str_schema() + ) + schema = core_schema.typed_dict_schema({'a': core_schema.typed_dict_field(func_schema)}) + + v = SchemaValidator(schema) + assert v.validate_python({'a': b'hello '}) == {'a': 'hello world'} + ``` + + Args: + function: The validator function to call + schema: The schema to validate the output of the validator function + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='function-before', + function={'type': 'no-info', 'function': function}, + schema=schema, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +def with_info_before_validator_function( + function: WithInfoValidatorFunction, + schema: CoreSchema, + *, + field_name: str | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> BeforeValidatorFunctionSchema: + """ + Returns a schema that calls a validator function before validation, the function is called with + an `info` argument, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + def fn(v: bytes, info: core_schema.ValidationInfo) -> str: + assert info.data is not None + assert info.field_name is not None + return v.decode() + 'world' + + func_schema = core_schema.with_info_before_validator_function( + function=fn, schema=core_schema.str_schema(), field_name='a' + ) + schema = core_schema.typed_dict_schema({'a': core_schema.typed_dict_field(func_schema)}) + + v = SchemaValidator(schema) + assert v.validate_python({'a': b'hello '}) == {'a': 'hello world'} + ``` + + Args: + function: The validator function to call + field_name: The name of the field + schema: The schema to validate the output of the validator function + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: 
Custom serialization schema + """ + return _dict_not_none( + type='function-before', + function=_dict_not_none(type='with-info', function=function, field_name=field_name), + schema=schema, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class AfterValidatorFunctionSchema(_ValidatorFunctionSchema, total=False): + type: Required[Literal['function-after']] + + +def no_info_after_validator_function( + function: NoInfoValidatorFunction, + schema: CoreSchema, + *, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> AfterValidatorFunctionSchema: + """ + Returns a schema that calls a validator function after validating, no `info` argument is provided, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + def fn(v: str) -> str: + return v + 'world' + + func_schema = core_schema.no_info_after_validator_function(fn, core_schema.str_schema()) + schema = core_schema.typed_dict_schema({'a': core_schema.typed_dict_field(func_schema)}) + + v = SchemaValidator(schema) + assert v.validate_python({'a': b'hello '}) == {'a': 'hello world'} + ``` + + Args: + function: The validator function to call after the schema is validated + schema: The schema to validate before the validator function + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='function-after', + function={'type': 'no-info', 'function': function}, + schema=schema, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +def with_info_after_validator_function( + function: WithInfoValidatorFunction, + schema: CoreSchema, + *, + field_name: str | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> AfterValidatorFunctionSchema: + """ + Returns a schema that calls a validator function after validation, the function is called with + an `info` argument, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + def fn(v: str, info: core_schema.ValidationInfo) -> str: + assert info.data is not None + assert info.field_name is not None + return v + 'world' + + func_schema = core_schema.with_info_after_validator_function( + function=fn, schema=core_schema.str_schema(), field_name='a' + ) + schema = core_schema.typed_dict_schema({'a': core_schema.typed_dict_field(func_schema)}) + + v = SchemaValidator(schema) + assert v.validate_python({'a': b'hello '}) == {'a': 'hello world'} + ``` + + Args: + function: The validator function to call after the schema is validated + schema: The schema to validate before the validator function + field_name: The name of the field this validators is applied to, if any + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='function-after', + function=_dict_not_none(type='with-info', function=function, field_name=field_name), + schema=schema, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class ValidatorFunctionWrapHandler(Protocol): + def __call__(self, input_value: Any, outer_location: str | int | None = None) -> Any: # pragma: no cover + ... 
+ + +# (__input_value: Any, __validator: ValidatorFunctionWrapHandler) -> Any +NoInfoWrapValidatorFunction = Callable[[Any, ValidatorFunctionWrapHandler], Any] + + +class NoInfoWrapValidatorFunctionSchema(TypedDict): + type: Literal['no-info'] + function: NoInfoWrapValidatorFunction + + +# (__input_value: Any, __validator: ValidatorFunctionWrapHandler, __info: ValidationInfo) -> Any +WithInfoWrapValidatorFunction = Callable[[Any, ValidatorFunctionWrapHandler, ValidationInfo], Any] + + +class WithInfoWrapValidatorFunctionSchema(TypedDict, total=False): + type: Required[Literal['with-info']] + function: Required[WithInfoWrapValidatorFunction] + field_name: str + + +WrapValidatorFunction = Union[NoInfoWrapValidatorFunctionSchema, WithInfoWrapValidatorFunctionSchema] + + +class WrapValidatorFunctionSchema(TypedDict, total=False): + type: Required[Literal['function-wrap']] + function: Required[WrapValidatorFunction] + schema: Required[CoreSchema] + ref: str + metadata: Any + serialization: SerSchema + + +def no_info_wrap_validator_function( + function: NoInfoWrapValidatorFunction, + schema: CoreSchema, + *, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> WrapValidatorFunctionSchema: + """ + Returns a schema which calls a function with a `validator` callable argument which can + optionally be used to call inner validation with the function logic, this is much like the + "onion" implementation of middleware in many popular web frameworks, no `info` argument is passed, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + def fn( + v: str, + validator: core_schema.ValidatorFunctionWrapHandler, + ) -> str: + return validator(input_value=v) + 'world' + + schema = core_schema.no_info_wrap_validator_function( + function=fn, schema=core_schema.str_schema() + ) + v = SchemaValidator(schema) + assert v.validate_python('hello ') == 'hello world' + ``` + + Args: + function: The validator function to call + schema: The schema to validate the output of the validator function + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='function-wrap', + function={'type': 'no-info', 'function': function}, + schema=schema, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +def with_info_wrap_validator_function( + function: WithInfoWrapValidatorFunction, + schema: CoreSchema, + *, + field_name: str | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> WrapValidatorFunctionSchema: + """ + Returns a schema which calls a function with a `validator` callable argument which can + optionally be used to call inner validation with the function logic, this is much like the + "onion" implementation of middleware in many popular web frameworks, an `info` argument is also passed, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + def fn( + v: str, + validator: core_schema.ValidatorFunctionWrapHandler, + info: core_schema.ValidationInfo, + ) -> str: + return validator(input_value=v) + 'world' + + schema = core_schema.with_info_wrap_validator_function( + function=fn, schema=core_schema.str_schema() + ) + v = SchemaValidator(schema) + assert v.validate_python('hello ') == 'hello world' + ``` + + Args: + function: The validator function to 
call + schema: The schema to validate the output of the validator function + field_name: The name of the field this validators is applied to, if any + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='function-wrap', + function=_dict_not_none(type='with-info', function=function, field_name=field_name), + schema=schema, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class PlainValidatorFunctionSchema(TypedDict, total=False): + type: Required[Literal['function-plain']] + function: Required[ValidationFunction] + ref: str + metadata: Any + serialization: SerSchema + + +def no_info_plain_validator_function( + function: NoInfoValidatorFunction, + *, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> PlainValidatorFunctionSchema: + """ + Returns a schema that uses the provided function for validation, no `info` argument is passed, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + def fn(v: str) -> str: + assert 'hello' in v + return v + 'world' + + schema = core_schema.no_info_plain_validator_function(function=fn) + v = SchemaValidator(schema) + assert v.validate_python('hello ') == 'hello world' + ``` + + Args: + function: The validator function to call + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='function-plain', + function={'type': 'no-info', 'function': function}, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +def with_info_plain_validator_function( + function: WithInfoValidatorFunction, + *, + field_name: str | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> PlainValidatorFunctionSchema: + """ + Returns a schema that uses the provided function for validation, an `info` argument is passed, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + def fn(v: str, info: core_schema.ValidationInfo) -> str: + assert 'hello' in v + return v + 'world' + + schema = core_schema.with_info_plain_validator_function(function=fn) + v = SchemaValidator(schema) + assert v.validate_python('hello ') == 'hello world' + ``` + + Args: + function: The validator function to call + field_name: The name of the field this validators is applied to, if any + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='function-plain', + function=_dict_not_none(type='with-info', function=function, field_name=field_name), + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class WithDefaultSchema(TypedDict, total=False): + type: Required[Literal['default']] + schema: Required[CoreSchema] + default: Any + default_factory: Callable[[], Any] + on_error: Literal['raise', 'omit', 'default'] # default: 'raise' + validate_default: bool # default: False + strict: bool + ref: str + metadata: Any + serialization: SerSchema + + +def with_default_schema( + 
schema: CoreSchema, + *, + default: Any = PydanticUndefined, + default_factory: Callable[[], Any] | None = None, + on_error: Literal['raise', 'omit', 'default'] | None = None, + validate_default: bool | None = None, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> WithDefaultSchema: + """ + Returns a schema that adds a default value to the given schema, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.with_default_schema(core_schema.str_schema(), default='hello') + wrapper_schema = core_schema.typed_dict_schema( + {'a': core_schema.typed_dict_field(schema)} + ) + v = SchemaValidator(wrapper_schema) + assert v.validate_python({}) == v.validate_python({'a': 'hello'}) + ``` + + Args: + schema: The schema to add a default value to + default: The default value to use + default_factory: A function that returns the default value to use + on_error: What to do if the schema validation fails. One of 'raise', 'omit', 'default' + validate_default: Whether the default value should be validated + strict: Whether the underlying schema should be validated with strict mode + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + s = _dict_not_none( + type='default', + schema=schema, + default_factory=default_factory, + on_error=on_error, + validate_default=validate_default, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + if default is not PydanticUndefined: + s['default'] = default + return s + + +class NullableSchema(TypedDict, total=False): + type: Required[Literal['nullable']] + schema: Required[CoreSchema] + strict: bool + ref: str + metadata: Any + serialization: SerSchema + + +def nullable_schema( + schema: CoreSchema, + *, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> NullableSchema: + """ + Returns a schema that matches a nullable value, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.nullable_schema(core_schema.str_schema()) + v = SchemaValidator(schema) + assert v.validate_python(None) is None + ``` + + Args: + schema: The schema to wrap + strict: Whether the underlying schema should be validated with strict mode + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='nullable', schema=schema, strict=strict, ref=ref, metadata=metadata, serialization=serialization + ) + + +class UnionSchema(TypedDict, total=False): + type: Required[Literal['union']] + choices: Required[List[Union[CoreSchema, Tuple[CoreSchema, str]]]] + # default true, whether to automatically collapse unions with one element to the inner validator + auto_collapse: bool + custom_error_type: str + custom_error_message: str + custom_error_context: Dict[str, Union[str, int, float]] + mode: Literal['smart', 'left_to_right'] # default: 'smart' + strict: bool + ref: str + metadata: Any + serialization: SerSchema + + +def union_schema( + choices: list[CoreSchema | tuple[CoreSchema, str]], + *, + auto_collapse: bool | None = None, + 
custom_error_type: str | None = None, + custom_error_message: str | None = None, + custom_error_context: dict[str, str | int] | None = None, + mode: Literal['smart', 'left_to_right'] | None = None, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> UnionSchema: + """ + Returns a schema that matches a union value, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.union_schema([core_schema.str_schema(), core_schema.int_schema()]) + v = SchemaValidator(schema) + assert v.validate_python('hello') == 'hello' + assert v.validate_python(1) == 1 + ``` + + Args: + choices: The schemas to match. If a tuple, the second item is used as the label for the case. + auto_collapse: whether to automatically collapse unions with one element to the inner validator, default true + custom_error_type: The custom error type to use if the validation fails + custom_error_message: The custom error message to use if the validation fails + custom_error_context: The custom error context to use if the validation fails + mode: How to select which choice to return + * `smart` (default) will try to return the choice which is the closest match to the input value + * `left_to_right` will return the first choice in `choices` which succeeds validation + strict: Whether the underlying schemas should be validated with strict mode + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='union', + choices=choices, + auto_collapse=auto_collapse, + custom_error_type=custom_error_type, + custom_error_message=custom_error_message, + custom_error_context=custom_error_context, + mode=mode, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class TaggedUnionSchema(TypedDict, total=False): + type: Required[Literal['tagged-union']] + choices: Required[Dict[Hashable, CoreSchema]] + discriminator: Required[Union[str, List[Union[str, int]], List[List[Union[str, int]]], Callable[[Any], Hashable]]] + custom_error_type: str + custom_error_message: str + custom_error_context: Dict[str, Union[str, int, float]] + strict: bool + from_attributes: bool # default: True + ref: str + metadata: Any + serialization: SerSchema + + +def tagged_union_schema( + choices: Dict[Hashable, CoreSchema], + discriminator: str | list[str | int] | list[list[str | int]] | Callable[[Any], Hashable], + *, + custom_error_type: str | None = None, + custom_error_message: str | None = None, + custom_error_context: dict[str, int | str | float] | None = None, + strict: bool | None = None, + from_attributes: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> TaggedUnionSchema: + """ + Returns a schema that matches a tagged union value, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + apple_schema = core_schema.typed_dict_schema( + { + 'foo': core_schema.typed_dict_field(core_schema.str_schema()), + 'bar': core_schema.typed_dict_field(core_schema.int_schema()), + } + ) + banana_schema = core_schema.typed_dict_schema( + { + 'foo': core_schema.typed_dict_field(core_schema.str_schema()), + 'spam': core_schema.typed_dict_field( + core_schema.list_schema(items_schema=core_schema.int_schema()) + ), + } + ) + schema = 
core_schema.tagged_union_schema( + choices={ + 'apple': apple_schema, + 'banana': banana_schema, + }, + discriminator='foo', + ) + v = SchemaValidator(schema) + assert v.validate_python({'foo': 'apple', 'bar': '123'}) == {'foo': 'apple', 'bar': 123} + assert v.validate_python({'foo': 'banana', 'spam': [1, 2, 3]}) == { + 'foo': 'banana', + 'spam': [1, 2, 3], + } + ``` + + Args: + choices: The schemas to match + When retrieving a schema from `choices` using the discriminator value, if the value is a str, + it should be fed back into the `choices` map until a schema is obtained + (This approach is to prevent multiple ownership of a single schema in Rust) + discriminator: The discriminator to use to determine the schema to use + * If `discriminator` is a str, it is the name of the attribute to use as the discriminator + * If `discriminator` is a list of int/str, it should be used as a "path" to access the discriminator + * If `discriminator` is a list of lists, each inner list is a path, and the first path that exists is used + * If `discriminator` is a callable, it should return the discriminator when called on the value to validate; + the callable can return `None` to indicate that there is no matching discriminator present on the input + custom_error_type: The custom error type to use if the validation fails + custom_error_message: The custom error message to use if the validation fails + custom_error_context: The custom error context to use if the validation fails + strict: Whether the underlying schemas should be validated with strict mode + from_attributes: Whether to use the attributes of the object to retrieve the discriminator value + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='tagged-union', + choices=choices, + discriminator=discriminator, + custom_error_type=custom_error_type, + custom_error_message=custom_error_message, + custom_error_context=custom_error_context, + strict=strict, + from_attributes=from_attributes, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class ChainSchema(TypedDict, total=False): + type: Required[Literal['chain']] + steps: Required[List[CoreSchema]] + ref: str + metadata: Any + serialization: SerSchema + + +def chain_schema( + steps: list[CoreSchema], *, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None +) -> ChainSchema: + """ + Returns a schema that chains the provided validation schemas, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + def fn(v: str, info: core_schema.ValidationInfo) -> str: + assert 'hello' in v + return v + ' world' + + fn_schema = core_schema.with_info_plain_validator_function(function=fn) + schema = core_schema.chain_schema( + [fn_schema, fn_schema, fn_schema, core_schema.str_schema()] + ) + v = SchemaValidator(schema) + assert v.validate_python('hello') == 'hello world world world' + ``` + + Args: + steps: The schemas to chain + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none(type='chain', steps=steps, ref=ref, metadata=metadata, serialization=serialization) + + +class LaxOrStrictSchema(TypedDict, 
total=False): + type: Required[Literal['lax-or-strict']] + lax_schema: Required[CoreSchema] + strict_schema: Required[CoreSchema] + strict: bool + ref: str + metadata: Any + serialization: SerSchema + + +def lax_or_strict_schema( + lax_schema: CoreSchema, + strict_schema: CoreSchema, + *, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> LaxOrStrictSchema: + """ + Returns a schema that uses the lax or strict schema, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + def fn(v: str, info: core_schema.ValidationInfo) -> str: + assert 'hello' in v + return v + ' world' + + lax_schema = core_schema.int_schema(strict=False) + strict_schema = core_schema.int_schema(strict=True) + + schema = core_schema.lax_or_strict_schema( + lax_schema=lax_schema, strict_schema=strict_schema, strict=True + ) + v = SchemaValidator(schema) + assert v.validate_python(123) == 123 + + schema = core_schema.lax_or_strict_schema( + lax_schema=lax_schema, strict_schema=strict_schema, strict=False + ) + v = SchemaValidator(schema) + assert v.validate_python('123') == 123 + ``` + + Args: + lax_schema: The lax schema to use + strict_schema: The strict schema to use + strict: Whether the strict schema should be used + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='lax-or-strict', + lax_schema=lax_schema, + strict_schema=strict_schema, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class JsonOrPythonSchema(TypedDict, total=False): + type: Required[Literal['json-or-python']] + json_schema: Required[CoreSchema] + python_schema: Required[CoreSchema] + ref: str + metadata: Any + serialization: SerSchema + + +def json_or_python_schema( + json_schema: CoreSchema, + python_schema: CoreSchema, + *, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> JsonOrPythonSchema: + """ + Returns a schema that uses the Json or Python schema depending on the input: + + ```py + from pydantic_core import SchemaValidator, ValidationError, core_schema + + v = SchemaValidator( + core_schema.json_or_python_schema( + json_schema=core_schema.int_schema(), + python_schema=core_schema.int_schema(strict=True), + ) + ) + + assert v.validate_json('"123"') == 123 + + try: + v.validate_python('123') + except ValidationError: + pass + else: + raise AssertionError('Validation should have failed') + ``` + + Args: + json_schema: The schema to use for Json inputs + python_schema: The schema to use for Python inputs + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='json-or-python', + json_schema=json_schema, + python_schema=python_schema, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class TypedDictField(TypedDict, total=False): + type: Required[Literal['typed-dict-field']] + schema: Required[CoreSchema] + required: bool + validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]] + serialization_alias: str + serialization_exclude: bool # default: False + metadata: Any + + +def 
typed_dict_field( + schema: CoreSchema, + *, + required: bool | None = None, + validation_alias: str | list[str | int] | list[list[str | int]] | None = None, + serialization_alias: str | None = None, + serialization_exclude: bool | None = None, + metadata: Any = None, +) -> TypedDictField: + """ + Returns a schema that matches a typed dict field, e.g.: + + ```py + from pydantic_core import core_schema + + field = core_schema.typed_dict_field(schema=core_schema.int_schema(), required=True) + ``` + + Args: + schema: The schema to use for the field + required: Whether the field is required + validation_alias: The alias(es) to use to find the field in the validation data + serialization_alias: The alias to use as a key when serializing + serialization_exclude: Whether to exclude the field when serializing + metadata: Any other information you want to include with the schema, not used by pydantic-core + """ + return _dict_not_none( + type='typed-dict-field', + schema=schema, + required=required, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + serialization_exclude=serialization_exclude, + metadata=metadata, + ) + + +class TypedDictSchema(TypedDict, total=False): + type: Required[Literal['typed-dict']] + fields: Required[Dict[str, TypedDictField]] + computed_fields: List[ComputedField] + strict: bool + extras_schema: CoreSchema + # all these values can be set via config, equivalent fields have `typed_dict_` prefix + extra_behavior: ExtraBehavior + total: bool # default: True + populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1 + ref: str + metadata: Any + serialization: SerSchema + config: CoreConfig + + +def typed_dict_schema( + fields: Dict[str, TypedDictField], + *, + computed_fields: list[ComputedField] | None = None, + strict: bool | None = None, + extras_schema: CoreSchema | None = None, + extra_behavior: ExtraBehavior | None = None, + total: bool | None = None, + populate_by_name: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, + config: CoreConfig | None = None, +) -> TypedDictSchema: + """ + Returns a schema that matches a typed dict, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + wrapper_schema = core_schema.typed_dict_schema( + {'a': core_schema.typed_dict_field(core_schema.str_schema())} + ) + v = SchemaValidator(wrapper_schema) + assert v.validate_python({'a': 'hello'}) == {'a': 'hello'} + ``` + + Args: + fields: The fields to use for the typed dict + computed_fields: Computed fields to use when serializing the model, only applies when directly inside a model + strict: Whether the typed dict is strict + extras_schema: The extra validator to use for the typed dict + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + extra_behavior: The extra behavior to use for the typed dict + total: Whether the typed dict is total + populate_by_name: Whether the typed dict should populate by name + serialization: Custom serialization schema + """ + return _dict_not_none( + type='typed-dict', + fields=fields, + computed_fields=computed_fields, + strict=strict, + extras_schema=extras_schema, + extra_behavior=extra_behavior, + total=total, + populate_by_name=populate_by_name, + ref=ref, + metadata=metadata, + serialization=serialization, + config=config, + ) + + +class ModelField(TypedDict, total=False): + 
type: Required[Literal['model-field']] + schema: Required[CoreSchema] + validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]] + serialization_alias: str + serialization_exclude: bool # default: False + frozen: bool + metadata: Any + + +def model_field( + schema: CoreSchema, + *, + validation_alias: str | list[str | int] | list[list[str | int]] | None = None, + serialization_alias: str | None = None, + serialization_exclude: bool | None = None, + frozen: bool | None = None, + metadata: Any = None, +) -> ModelField: + """ + Returns a schema for a model field, e.g.: + + ```py + from pydantic_core import core_schema + + field = core_schema.model_field(schema=core_schema.int_schema()) + ``` + + Args: + schema: The schema to use for the field + validation_alias: The alias(es) to use to find the field in the validation data + serialization_alias: The alias to use as a key when serializing + serialization_exclude: Whether to exclude the field when serializing + frozen: Whether the field is frozen + metadata: Any other information you want to include with the schema, not used by pydantic-core + """ + return _dict_not_none( + type='model-field', + schema=schema, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + serialization_exclude=serialization_exclude, + frozen=frozen, + metadata=metadata, + ) + + +class ModelFieldsSchema(TypedDict, total=False): + type: Required[Literal['model-fields']] + fields: Required[Dict[str, ModelField]] + model_name: str + computed_fields: List[ComputedField] + strict: bool + extras_schema: CoreSchema + # all these values can be set via config, equivalent fields have `typed_dict_` prefix + extra_behavior: ExtraBehavior + populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1 + from_attributes: bool + ref: str + metadata: Any + serialization: SerSchema + + +def model_fields_schema( + fields: Dict[str, ModelField], + *, + model_name: str | None = None, + computed_fields: list[ComputedField] | None = None, + strict: bool | None = None, + extras_schema: CoreSchema | None = None, + extra_behavior: ExtraBehavior | None = None, + populate_by_name: bool | None = None, + from_attributes: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> ModelFieldsSchema: + """ + Returns a schema that matches a typed dict, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + wrapper_schema = core_schema.model_fields_schema( + {'a': core_schema.model_field(core_schema.str_schema())} + ) + v = SchemaValidator(wrapper_schema) + print(v.validate_python({'a': 'hello'})) + #> ({'a': 'hello'}, None, {'a'}) + ``` + + Args: + fields: The fields to use for the typed dict + model_name: The name of the model, used for error messages, defaults to "Model" + computed_fields: Computed fields to use when serializing the model, only applies when directly inside a model + strict: Whether the typed dict is strict + extras_schema: The extra validator to use for the typed dict + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + extra_behavior: The extra behavior to use for the typed dict + populate_by_name: Whether the typed dict should populate by name + from_attributes: Whether the typed dict should be populated from attributes + serialization: Custom serialization schema + """ + return _dict_not_none( 
+ type='model-fields', + fields=fields, + model_name=model_name, + computed_fields=computed_fields, + strict=strict, + extras_schema=extras_schema, + extra_behavior=extra_behavior, + populate_by_name=populate_by_name, + from_attributes=from_attributes, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class ModelSchema(TypedDict, total=False): + type: Required[Literal['model']] + cls: Required[Type[Any]] + schema: Required[CoreSchema] + custom_init: bool + root_model: bool + post_init: str + revalidate_instances: Literal['always', 'never', 'subclass-instances'] # default: 'never' + strict: bool + frozen: bool + extra_behavior: ExtraBehavior + config: CoreConfig + ref: str + metadata: Any + serialization: SerSchema + + +def model_schema( + cls: Type[Any], + schema: CoreSchema, + *, + custom_init: bool | None = None, + root_model: bool | None = None, + post_init: str | None = None, + revalidate_instances: Literal['always', 'never', 'subclass-instances'] | None = None, + strict: bool | None = None, + frozen: bool | None = None, + extra_behavior: ExtraBehavior | None = None, + config: CoreConfig | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> ModelSchema: + """ + A model schema generally contains a typed-dict schema. + It will run the typed dict validator, then create a new class + and set the dict and fields set returned from the typed dict validator + to `__dict__` and `__pydantic_fields_set__` respectively. + + Example: + + ```py + from pydantic_core import CoreConfig, SchemaValidator, core_schema + + class MyModel: + __slots__ = ( + '__dict__', + '__pydantic_fields_set__', + '__pydantic_extra__', + '__pydantic_private__', + ) + + schema = core_schema.model_schema( + cls=MyModel, + config=CoreConfig(str_max_length=5), + schema=core_schema.model_fields_schema( + fields={'a': core_schema.model_field(core_schema.str_schema())}, + ), + ) + v = SchemaValidator(schema) + assert v.isinstance_python({'a': 'hello'}) is True + assert v.isinstance_python({'a': 'too long'}) is False + ``` + + Args: + cls: The class to use for the model + schema: The schema to use for the model + custom_init: Whether the model has a custom init method + root_model: Whether the model is a `RootModel` + post_init: The call after init to use for the model + revalidate_instances: whether instances of models and dataclasses (including subclass instances) + should re-validate defaults to config.revalidate_instances, else 'never' + strict: Whether the model is strict + frozen: Whether the model is frozen + extra_behavior: The extra behavior to use for the model, used in serialization + config: The config to use for the model + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='model', + cls=cls, + schema=schema, + custom_init=custom_init, + root_model=root_model, + post_init=post_init, + revalidate_instances=revalidate_instances, + strict=strict, + frozen=frozen, + extra_behavior=extra_behavior, + config=config, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class DataclassField(TypedDict, total=False): + type: Required[Literal['dataclass-field']] + name: Required[str] + schema: Required[CoreSchema] + kw_only: bool # default: True + init: bool # default: True + init_only: bool # default: False + 
frozen: bool # default: False + validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]] + serialization_alias: str + serialization_exclude: bool # default: False + metadata: Any + + +def dataclass_field( + name: str, + schema: CoreSchema, + *, + kw_only: bool | None = None, + init: bool | None = None, + init_only: bool | None = None, + validation_alias: str | list[str | int] | list[list[str | int]] | None = None, + serialization_alias: str | None = None, + serialization_exclude: bool | None = None, + metadata: Any = None, + frozen: bool | None = None, +) -> DataclassField: + """ + Returns a schema for a dataclass field, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + field = core_schema.dataclass_field( + name='a', schema=core_schema.str_schema(), kw_only=False + ) + schema = core_schema.dataclass_args_schema('Foobar', [field]) + v = SchemaValidator(schema) + assert v.validate_python({'a': 'hello'}) == ({'a': 'hello'}, None) + ``` + + Args: + name: The name to use for the argument parameter + schema: The schema to use for the argument parameter + kw_only: Whether the field can be set with a positional argument as well as a keyword argument + init: Whether the field should be validated during initialization + init_only: Whether the field should be omitted from `__dict__` and passed to `__post_init__` + validation_alias: The alias(es) to use to find the field in the validation data + serialization_alias: The alias to use as a key when serializing + serialization_exclude: Whether to exclude the field when serializing + metadata: Any other information you want to include with the schema, not used by pydantic-core + frozen: Whether the field is frozen + """ + return _dict_not_none( + type='dataclass-field', + name=name, + schema=schema, + kw_only=kw_only, + init=init, + init_only=init_only, + validation_alias=validation_alias, + serialization_alias=serialization_alias, + serialization_exclude=serialization_exclude, + metadata=metadata, + frozen=frozen, + ) + + +class DataclassArgsSchema(TypedDict, total=False): + type: Required[Literal['dataclass-args']] + dataclass_name: Required[str] + fields: Required[List[DataclassField]] + computed_fields: List[ComputedField] + populate_by_name: bool # default: False + collect_init_only: bool # default: False + ref: str + metadata: Any + serialization: SerSchema + extra_behavior: ExtraBehavior + + +def dataclass_args_schema( + dataclass_name: str, + fields: list[DataclassField], + *, + computed_fields: List[ComputedField] | None = None, + populate_by_name: bool | None = None, + collect_init_only: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, + extra_behavior: ExtraBehavior | None = None, +) -> DataclassArgsSchema: + """ + Returns a schema for validating dataclass arguments, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + field_a = core_schema.dataclass_field( + name='a', schema=core_schema.str_schema(), kw_only=False + ) + field_b = core_schema.dataclass_field( + name='b', schema=core_schema.bool_schema(), kw_only=False + ) + schema = core_schema.dataclass_args_schema('Foobar', [field_a, field_b]) + v = SchemaValidator(schema) + assert v.validate_python({'a': 'hello', 'b': True}) == ({'a': 'hello', 'b': True}, None) + ``` + + Args: + dataclass_name: The name of the dataclass being validated + fields: The fields to use for the dataclass + computed_fields: Computed fields to use when serializing the dataclass + 
populate_by_name: Whether to populate by name + collect_init_only: Whether to collect init only fields into a dict to pass to `__post_init__` + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + extra_behavior: How to handle extra fields + """ + return _dict_not_none( + type='dataclass-args', + dataclass_name=dataclass_name, + fields=fields, + computed_fields=computed_fields, + populate_by_name=populate_by_name, + collect_init_only=collect_init_only, + ref=ref, + metadata=metadata, + serialization=serialization, + extra_behavior=extra_behavior, + ) + + +class DataclassSchema(TypedDict, total=False): + type: Required[Literal['dataclass']] + cls: Required[Type[Any]] + schema: Required[CoreSchema] + fields: Required[List[str]] + cls_name: str + post_init: bool # default: False + revalidate_instances: Literal['always', 'never', 'subclass-instances'] # default: 'never' + strict: bool # default: False + frozen: bool # default False + ref: str + metadata: Any + serialization: SerSchema + slots: bool + config: CoreConfig + + +def dataclass_schema( + cls: Type[Any], + schema: CoreSchema, + fields: List[str], + *, + cls_name: str | None = None, + post_init: bool | None = None, + revalidate_instances: Literal['always', 'never', 'subclass-instances'] | None = None, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, + frozen: bool | None = None, + slots: bool | None = None, + config: CoreConfig | None = None, +) -> DataclassSchema: + """ + Returns a schema for a dataclass. As with `ModelSchema`, this schema can only be used as a field within + another schema, not as the root type. 
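+
+    A minimal sketch of how the pieces fit together (the dataclass `Foo` and the names
+    `args_schema`/`schema` below are illustrative assumptions), e.g.:
+
+    ```py
+    from dataclasses import dataclass
+
+    from pydantic_core import core_schema
+
+    @dataclass
+    class Foo:
+        a: str
+
+    # describe the single field, wrap it in a dataclass-args schema, then in the dataclass schema
+    args_schema = core_schema.dataclass_args_schema(
+        'Foo',
+        [core_schema.dataclass_field(name='a', schema=core_schema.str_schema(), kw_only=False)],
+    )
+    schema = core_schema.dataclass_schema(Foo, args_schema, ['a'])
+    ```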
+ + Args: + cls: The dataclass type, used to perform subclass checks + schema: The schema to use for the dataclass fields + fields: Fields of the dataclass, this is used in serialization and in validation during re-validation + and while validating assignment + cls_name: The name to use in error locs, etc; this is useful for generics (default: `cls.__name__`) + post_init: Whether to call `__post_init__` after validation + revalidate_instances: whether instances of models and dataclasses (including subclass instances) + should re-validate defaults to config.revalidate_instances, else 'never' + strict: Whether to require an exact instance of `cls` + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + frozen: Whether the dataclass is frozen + slots: Whether `slots=True` on the dataclass, means each field is assigned independently, rather than + simply setting `__dict__`, default false + """ + return _dict_not_none( + type='dataclass', + cls=cls, + fields=fields, + cls_name=cls_name, + schema=schema, + post_init=post_init, + revalidate_instances=revalidate_instances, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + frozen=frozen, + slots=slots, + config=config, + ) + + +class ArgumentsParameter(TypedDict, total=False): + name: Required[str] + schema: Required[CoreSchema] + mode: Literal['positional_only', 'positional_or_keyword', 'keyword_only'] # default positional_or_keyword + alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]] + + +def arguments_parameter( + name: str, + schema: CoreSchema, + *, + mode: Literal['positional_only', 'positional_or_keyword', 'keyword_only'] | None = None, + alias: str | list[str | int] | list[list[str | int]] | None = None, +) -> ArgumentsParameter: + """ + Returns a schema that matches an argument parameter, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + param = core_schema.arguments_parameter( + name='a', schema=core_schema.str_schema(), mode='positional_only' + ) + schema = core_schema.arguments_schema([param]) + v = SchemaValidator(schema) + assert v.validate_python(('hello',)) == (('hello',), {}) + ``` + + Args: + name: The name to use for the argument parameter + schema: The schema to use for the argument parameter + mode: The mode to use for the argument parameter + alias: The alias to use for the argument parameter + """ + return _dict_not_none(name=name, schema=schema, mode=mode, alias=alias) + + +class ArgumentsSchema(TypedDict, total=False): + type: Required[Literal['arguments']] + arguments_schema: Required[List[ArgumentsParameter]] + populate_by_name: bool + var_args_schema: CoreSchema + var_kwargs_schema: CoreSchema + ref: str + metadata: Any + serialization: SerSchema + + +def arguments_schema( + arguments: list[ArgumentsParameter], + *, + populate_by_name: bool | None = None, + var_args_schema: CoreSchema | None = None, + var_kwargs_schema: CoreSchema | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> ArgumentsSchema: + """ + Returns a schema that matches an arguments schema, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + param_a = core_schema.arguments_parameter( + name='a', schema=core_schema.str_schema(), mode='positional_only' + ) + param_b = core_schema.arguments_parameter( + name='b', 
schema=core_schema.bool_schema(), mode='positional_only' + ) + schema = core_schema.arguments_schema([param_a, param_b]) + v = SchemaValidator(schema) + assert v.validate_python(('hello', True)) == (('hello', True), {}) + ``` + + Args: + arguments: The arguments to use for the arguments schema + populate_by_name: Whether to populate by name + var_args_schema: The variable args schema to use for the arguments schema + var_kwargs_schema: The variable kwargs schema to use for the arguments schema + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='arguments', + arguments_schema=arguments, + populate_by_name=populate_by_name, + var_args_schema=var_args_schema, + var_kwargs_schema=var_kwargs_schema, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class CallSchema(TypedDict, total=False): + type: Required[Literal['call']] + arguments_schema: Required[CoreSchema] + function: Required[Callable[..., Any]] + function_name: str # default function.__name__ + return_schema: CoreSchema + ref: str + metadata: Any + serialization: SerSchema + + +def call_schema( + arguments: CoreSchema, + function: Callable[..., Any], + *, + function_name: str | None = None, + return_schema: CoreSchema | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> CallSchema: + """ + Returns a schema that matches an arguments schema, then calls a function, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + param_a = core_schema.arguments_parameter( + name='a', schema=core_schema.str_schema(), mode='positional_only' + ) + param_b = core_schema.arguments_parameter( + name='b', schema=core_schema.bool_schema(), mode='positional_only' + ) + args_schema = core_schema.arguments_schema([param_a, param_b]) + + schema = core_schema.call_schema( + arguments=args_schema, + function=lambda a, b: a + str(not b), + return_schema=core_schema.str_schema(), + ) + v = SchemaValidator(schema) + assert v.validate_python((('hello', True))) == 'helloFalse' + ``` + + Args: + arguments: The arguments to use for the arguments schema + function: The function to use for the call schema + function_name: The function name to use for the call schema, if not provided `function.__name__` is used + return_schema: The return schema to use for the call schema + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='call', + arguments_schema=arguments, + function=function, + function_name=function_name, + return_schema=return_schema, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class CustomErrorSchema(TypedDict, total=False): + type: Required[Literal['custom-error']] + schema: Required[CoreSchema] + custom_error_type: Required[str] + custom_error_message: str + custom_error_context: Dict[str, Union[str, int, float]] + ref: str + metadata: Any + serialization: SerSchema + + +def custom_error_schema( + schema: CoreSchema, + custom_error_type: str, + *, + custom_error_message: str | None = None, + custom_error_context: dict[str, Any] | None = None, + ref: str | None = None, + metadata: Any = None, + 
serialization: SerSchema | None = None, +) -> CustomErrorSchema: + """ + Returns a schema that matches a custom error value, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.custom_error_schema( + schema=core_schema.int_schema(), + custom_error_type='MyError', + custom_error_message='Error msg', + ) + v = SchemaValidator(schema) + v.validate_python(1) + ``` + + Args: + schema: The schema to use for the custom error schema + custom_error_type: The custom error type to use for the custom error schema + custom_error_message: The custom error message to use for the custom error schema + custom_error_context: The custom error context to use for the custom error schema + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='custom-error', + schema=schema, + custom_error_type=custom_error_type, + custom_error_message=custom_error_message, + custom_error_context=custom_error_context, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class JsonSchema(TypedDict, total=False): + type: Required[Literal['json']] + schema: CoreSchema + ref: str + metadata: Any + serialization: SerSchema + + +def json_schema( + schema: CoreSchema | None = None, + *, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> JsonSchema: + """ + Returns a schema that matches a JSON value, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + dict_schema = core_schema.model_fields_schema( + { + 'field_a': core_schema.model_field(core_schema.str_schema()), + 'field_b': core_schema.model_field(core_schema.bool_schema()), + }, + ) + + class MyModel: + __slots__ = ( + '__dict__', + '__pydantic_fields_set__', + '__pydantic_extra__', + '__pydantic_private__', + ) + field_a: str + field_b: bool + + json_schema = core_schema.json_schema(schema=dict_schema) + schema = core_schema.model_schema(cls=MyModel, schema=json_schema) + v = SchemaValidator(schema) + m = v.validate_python('{"field_a": "hello", "field_b": true}') + assert isinstance(m, MyModel) + ``` + + Args: + schema: The schema to use for the JSON schema + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none(type='json', schema=schema, ref=ref, metadata=metadata, serialization=serialization) + + +class UrlSchema(TypedDict, total=False): + type: Required[Literal['url']] + max_length: int + allowed_schemes: List[str] + host_required: bool # default False + default_host: str + default_port: int + default_path: str + strict: bool + ref: str + metadata: Any + serialization: SerSchema + + +def url_schema( + *, + max_length: int | None = None, + allowed_schemes: list[str] | None = None, + host_required: bool | None = None, + default_host: str | None = None, + default_port: int | None = None, + default_path: str | None = None, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> UrlSchema: + """ + Returns a schema that matches a URL value, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.url_schema() + v = 
SchemaValidator(schema) + print(v.validate_python('https://example.com')) + #> https://example.com/ + ``` + + Args: + max_length: The maximum length of the URL + allowed_schemes: The allowed URL schemes + host_required: Whether the URL must have a host + default_host: The default host to use if the URL does not have a host + default_port: The default port to use if the URL does not have a port + default_path: The default path to use if the URL does not have a path + strict: Whether to use strict URL parsing + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='url', + max_length=max_length, + allowed_schemes=allowed_schemes, + host_required=host_required, + default_host=default_host, + default_port=default_port, + default_path=default_path, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class MultiHostUrlSchema(TypedDict, total=False): + type: Required[Literal['multi-host-url']] + max_length: int + allowed_schemes: List[str] + host_required: bool # default False + default_host: str + default_port: int + default_path: str + strict: bool + ref: str + metadata: Any + serialization: SerSchema + + +def multi_host_url_schema( + *, + max_length: int | None = None, + allowed_schemes: list[str] | None = None, + host_required: bool | None = None, + default_host: str | None = None, + default_port: int | None = None, + default_path: str | None = None, + strict: bool | None = None, + ref: str | None = None, + metadata: Any = None, + serialization: SerSchema | None = None, +) -> MultiHostUrlSchema: + """ + Returns a schema that matches a URL value with possibly multiple hosts, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.multi_host_url_schema() + v = SchemaValidator(schema) + print(v.validate_python('redis://localhost,0.0.0.0,127.0.0.1')) + #> redis://localhost,0.0.0.0,127.0.0.1 + ``` + + Args: + max_length: The maximum length of the URL + allowed_schemes: The allowed URL schemes + host_required: Whether the URL must have a host + default_host: The default host to use if the URL does not have a host + default_port: The default port to use if the URL does not have a port + default_path: The default path to use if the URL does not have a path + strict: Whether to use strict URL parsing + ref: optional unique identifier of the schema, used to reference the schema in other places + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='multi-host-url', + max_length=max_length, + allowed_schemes=allowed_schemes, + host_required=host_required, + default_host=default_host, + default_port=default_port, + default_path=default_path, + strict=strict, + ref=ref, + metadata=metadata, + serialization=serialization, + ) + + +class DefinitionsSchema(TypedDict, total=False): + type: Required[Literal['definitions']] + schema: Required[CoreSchema] + definitions: Required[List[CoreSchema]] + metadata: Any + serialization: SerSchema + + +def definitions_schema(schema: CoreSchema, definitions: list[CoreSchema]) -> DefinitionsSchema: + """ + Build a schema that contains both an inner schema and a list of definitions which can be used + within the inner schema. 
+ + ```py + from pydantic_core import SchemaValidator, core_schema + + schema = core_schema.definitions_schema( + core_schema.list_schema(core_schema.definition_reference_schema('foobar')), + [core_schema.int_schema(ref='foobar')], + ) + v = SchemaValidator(schema) + assert v.validate_python([1, 2, '3']) == [1, 2, 3] + ``` + + Args: + schema: The inner schema + definitions: List of definitions which can be referenced within inner schema + """ + return DefinitionsSchema(type='definitions', schema=schema, definitions=definitions) + + +class DefinitionReferenceSchema(TypedDict, total=False): + type: Required[Literal['definition-ref']] + schema_ref: Required[str] + ref: str + metadata: Any + serialization: SerSchema + + +def definition_reference_schema( + schema_ref: str, ref: str | None = None, metadata: Any = None, serialization: SerSchema | None = None +) -> DefinitionReferenceSchema: + """ + Returns a schema that points to a schema stored in "definitions", this is useful for nested recursive + models and also when you want to define validators separately from the main schema, e.g.: + + ```py + from pydantic_core import SchemaValidator, core_schema + + schema_definition = core_schema.definition_reference_schema('list-schema') + schema = core_schema.definitions_schema( + schema=schema_definition, + definitions=[ + core_schema.list_schema(items_schema=schema_definition, ref='list-schema'), + ], + ) + v = SchemaValidator(schema) + assert v.validate_python([()]) == [[]] + ``` + + Args: + schema_ref: The schema ref to use for the definition reference schema + metadata: Any other information you want to include with the schema, not used by pydantic-core + serialization: Custom serialization schema + """ + return _dict_not_none( + type='definition-ref', schema_ref=schema_ref, ref=ref, metadata=metadata, serialization=serialization + ) + + +MYPY = False +# See https://github.com/python/mypy/issues/14034 for details, in summary mypy is extremely slow to process this +# union which kills performance not just for pydantic, but even for code using pydantic +if not MYPY: + CoreSchema = Union[ + AnySchema, + NoneSchema, + BoolSchema, + IntSchema, + FloatSchema, + DecimalSchema, + StringSchema, + BytesSchema, + DateSchema, + TimeSchema, + DatetimeSchema, + TimedeltaSchema, + LiteralSchema, + IsInstanceSchema, + IsSubclassSchema, + CallableSchema, + ListSchema, + TupleSchema, + SetSchema, + FrozenSetSchema, + GeneratorSchema, + DictSchema, + AfterValidatorFunctionSchema, + BeforeValidatorFunctionSchema, + WrapValidatorFunctionSchema, + PlainValidatorFunctionSchema, + WithDefaultSchema, + NullableSchema, + UnionSchema, + TaggedUnionSchema, + ChainSchema, + LaxOrStrictSchema, + JsonOrPythonSchema, + TypedDictSchema, + ModelFieldsSchema, + ModelSchema, + DataclassArgsSchema, + DataclassSchema, + ArgumentsSchema, + CallSchema, + CustomErrorSchema, + JsonSchema, + UrlSchema, + MultiHostUrlSchema, + DefinitionsSchema, + DefinitionReferenceSchema, + UuidSchema, + ] +elif False: + CoreSchema: TypeAlias = Mapping[str, Any] + + +# to update this, call `pytest -k test_core_schema_type_literal` and copy the output +CoreSchemaType = Literal[ + 'any', + 'none', + 'bool', + 'int', + 'float', + 'decimal', + 'str', + 'bytes', + 'date', + 'time', + 'datetime', + 'timedelta', + 'literal', + 'is-instance', + 'is-subclass', + 'callable', + 'list', + 'tuple', + 'set', + 'frozenset', + 'generator', + 'dict', + 'function-after', + 'function-before', + 'function-wrap', + 'function-plain', + 'default', + 'nullable', + 'union', + 
'tagged-union', + 'chain', + 'lax-or-strict', + 'json-or-python', + 'typed-dict', + 'model-fields', + 'model', + 'dataclass-args', + 'dataclass', + 'arguments', + 'call', + 'custom-error', + 'json', + 'url', + 'multi-host-url', + 'definitions', + 'definition-ref', + 'uuid', +] + +CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field'] + + +# used in _pydantic_core.pyi::PydanticKnownError +# to update this, call `pytest -k test_all_errors` and copy the output +ErrorType = Literal[ + 'no_such_attribute', + 'json_invalid', + 'json_type', + 'recursion_loop', + 'missing', + 'frozen_field', + 'frozen_instance', + 'extra_forbidden', + 'invalid_key', + 'get_attribute_error', + 'model_type', + 'model_attributes_type', + 'dataclass_type', + 'dataclass_exact_type', + 'none_required', + 'greater_than', + 'greater_than_equal', + 'less_than', + 'less_than_equal', + 'multiple_of', + 'finite_number', + 'too_short', + 'too_long', + 'iterable_type', + 'iteration_error', + 'string_type', + 'string_sub_type', + 'string_unicode', + 'string_too_short', + 'string_too_long', + 'string_pattern_mismatch', + 'enum', + 'dict_type', + 'mapping_type', + 'list_type', + 'tuple_type', + 'set_type', + 'bool_type', + 'bool_parsing', + 'int_type', + 'int_parsing', + 'int_parsing_size', + 'int_from_float', + 'float_type', + 'float_parsing', + 'bytes_type', + 'bytes_too_short', + 'bytes_too_long', + 'value_error', + 'assertion_error', + 'literal_error', + 'date_type', + 'date_parsing', + 'date_from_datetime_parsing', + 'date_from_datetime_inexact', + 'date_past', + 'date_future', + 'time_type', + 'time_parsing', + 'datetime_type', + 'datetime_parsing', + 'datetime_object_invalid', + 'datetime_from_date_parsing', + 'datetime_past', + 'datetime_future', + 'timezone_naive', + 'timezone_aware', + 'timezone_offset', + 'time_delta_type', + 'time_delta_parsing', + 'frozen_set_type', + 'is_instance_of', + 'is_subclass_of', + 'callable_type', + 'union_tag_invalid', + 'union_tag_not_found', + 'arguments_type', + 'missing_argument', + 'unexpected_keyword_argument', + 'missing_keyword_only_argument', + 'unexpected_positional_argument', + 'missing_positional_only_argument', + 'multiple_argument_values', + 'url_type', + 'url_parsing', + 'url_syntax_violation', + 'url_too_long', + 'url_scheme', + 'uuid_type', + 'uuid_parsing', + 'uuid_version', + 'decimal_type', + 'decimal_parsing', + 'decimal_max_digits', + 'decimal_max_places', + 'decimal_whole_digits', +] + + +def _dict_not_none(**kwargs: Any) -> Any: + return {k: v for k, v in kwargs.items() if v is not None} + + +############################################################################### +# All this stuff is deprecated by #980 and will be removed eventually +# They're kept because some code external code will be using them + + +@deprecated('`field_before_validator_function` is deprecated, use `with_info_before_validator_function` instead.') +def field_before_validator_function(function: WithInfoValidatorFunction, field_name: str, schema: CoreSchema, **kwargs): + warnings.warn( + '`field_before_validator_function` is deprecated, use `with_info_before_validator_function` instead.', + DeprecationWarning, + ) + return with_info_before_validator_function(function, schema, field_name=field_name, **kwargs) + + +@deprecated('`general_before_validator_function` is deprecated, use `with_info_before_validator_function` instead.') +def general_before_validator_function(*args, **kwargs): + warnings.warn( + '`general_before_validator_function` is 
deprecated, use `with_info_before_validator_function` instead.', + DeprecationWarning, + ) + return with_info_before_validator_function(*args, **kwargs) + + +@deprecated('`field_after_validator_function` is deprecated, use `with_info_after_validator_function` instead.') +def field_after_validator_function(function: WithInfoValidatorFunction, field_name: str, schema: CoreSchema, **kwargs): + warnings.warn( + '`field_after_validator_function` is deprecated, use `with_info_after_validator_function` instead.', + DeprecationWarning, + ) + return with_info_after_validator_function(function, schema, field_name=field_name, **kwargs) + + +@deprecated('`general_after_validator_function` is deprecated, use `with_info_after_validator_function` instead.') +def general_after_validator_function(*args, **kwargs): + warnings.warn( + '`general_after_validator_function` is deprecated, use `with_info_after_validator_function` instead.', + DeprecationWarning, + ) + return with_info_after_validator_function(*args, **kwargs) + + +@deprecated('`field_wrap_validator_function` is deprecated, use `with_info_wrap_validator_function` instead.') +def field_wrap_validator_function( + function: WithInfoWrapValidatorFunction, field_name: str, schema: CoreSchema, **kwargs +): + warnings.warn( + '`field_wrap_validator_function` is deprecated, use `with_info_wrap_validator_function` instead.', + DeprecationWarning, + ) + return with_info_wrap_validator_function(function, schema, field_name=field_name, **kwargs) + + +@deprecated('`general_wrap_validator_function` is deprecated, use `with_info_wrap_validator_function` instead.') +def general_wrap_validator_function(*args, **kwargs): + warnings.warn( + '`general_wrap_validator_function` is deprecated, use `with_info_wrap_validator_function` instead.', + DeprecationWarning, + ) + return with_info_wrap_validator_function(*args, **kwargs) + + +@deprecated('`field_plain_validator_function` is deprecated, use `with_info_plain_validator_function` instead.') +def field_plain_validator_function(function: WithInfoValidatorFunction, field_name: str, **kwargs): + warnings.warn( + '`field_plain_validator_function` is deprecated, use `with_info_plain_validator_function` instead.', + DeprecationWarning, + ) + return with_info_plain_validator_function(function, field_name=field_name, **kwargs) + + +@deprecated('`general_plain_validator_function` is deprecated, use `with_info_plain_validator_function` instead.') +def general_plain_validator_function(*args, **kwargs): + warnings.warn( + '`general_plain_validator_function` is deprecated, use `with_info_plain_validator_function` instead.', + DeprecationWarning, + ) + return with_info_plain_validator_function(*args, **kwargs) + + +_deprecated_import_lookup = { + 'FieldValidationInfo': ValidationInfo, + 'FieldValidatorFunction': WithInfoValidatorFunction, + 'GeneralValidatorFunction': WithInfoValidatorFunction, + 'FieldWrapValidatorFunction': WithInfoWrapValidatorFunction, +} + +if TYPE_CHECKING: + FieldValidationInfo = ValidationInfo + + +def __getattr__(attr_name: str) -> object: + new_attr = _deprecated_import_lookup.get(attr_name) + if new_attr is None: + raise AttributeError(f"module 'pydantic_core' has no attribute '{attr_name}'") + else: + import warnings + + msg = f'`{attr_name}` is deprecated, use `{new_attr.__name__}` instead.' 
+ warnings.warn(msg, DeprecationWarning, stacklevel=1) + return new_attr diff --git a/lib/pydantic_core/py.typed b/lib/pydantic_core/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/lib/tempora/timing.py b/lib/tempora/timing.py index aed0d336..e74b8962 100644 --- a/lib/tempora/timing.py +++ b/lib/tempora/timing.py @@ -1,22 +1,21 @@ +import collections.abc +import contextlib import datetime import functools import numbers import time -import collections.abc -import contextlib import jaraco.functools class Stopwatch: """ - A simple stopwatch which starts automatically. + A simple stopwatch that starts automatically. >>> w = Stopwatch() >>> _1_sec = datetime.timedelta(seconds=1) >>> w.split() < _1_sec True - >>> import time >>> time.sleep(1.0) >>> w.split() >= _1_sec True @@ -27,13 +26,13 @@ class Stopwatch: >>> w.split() < _1_sec True - It should be possible to launch the Stopwatch in a context: + Launch the Stopwatch in a context: >>> with Stopwatch() as watch: ... assert isinstance(watch.split(), datetime.timedelta) - In that case, the watch is stopped when the context is exited, - so to read the elapsed time: + After exiting the context, the watch is stopped; read the + elapsed time directly: >>> watch.elapsed datetime.timedelta(...) diff --git a/lib/typing_extensions.py b/lib/typing_extensions.py index ef42417c..f3132ea4 100644 --- a/lib/typing_extensions.py +++ b/lib/typing_extensions.py @@ -2,11 +2,12 @@ import abc import collections import collections.abc import functools +import inspect import operator import sys import types as _types import typing - +import warnings __all__ = [ # Super-special typing primitives. @@ -31,6 +32,7 @@ __all__ = [ 'Coroutine', 'AsyncGenerator', 'AsyncContextManager', + 'Buffer', 'ChainMap', # Concrete collection types. @@ -43,7 +45,13 @@ __all__ = [ 'TypedDict', # Structural checks, a.k.a. protocols. + 'SupportsAbs', + 'SupportsBytes', + 'SupportsComplex', + 'SupportsFloat', 'SupportsIndex', + 'SupportsInt', + 'SupportsRound', # One-off things. 'Annotated', @@ -51,12 +59,17 @@ __all__ = [ 'assert_type', 'clear_overloads', 'dataclass_transform', + 'deprecated', + 'Doc', 'get_overloads', 'final', 'get_args', 'get_origin', + 'get_original_bases', + 'get_protocol_members', 'get_type_hints', 'IntVar', + 'is_protocol', 'is_typeddict', 'Literal', 'NewType', @@ -68,12 +81,54 @@ __all__ = [ 'runtime_checkable', 'Text', 'TypeAlias', + 'TypeAliasType', 'TypeGuard', + 'TypeIs', 'TYPE_CHECKING', 'Never', 'NoReturn', + 'ReadOnly', 'Required', 'NotRequired', + + # Pure aliases, have always been in typing + 'AbstractSet', + 'AnyStr', + 'BinaryIO', + 'Callable', + 'Collection', + 'Container', + 'Dict', + 'ForwardRef', + 'FrozenSet', + 'Generator', + 'Generic', + 'Hashable', + 'IO', + 'ItemsView', + 'Iterable', + 'Iterator', + 'KeysView', + 'List', + 'Mapping', + 'MappingView', + 'Match', + 'MutableMapping', + 'MutableSequence', + 'MutableSet', + 'Optional', + 'Pattern', + 'Reversible', + 'Sequence', + 'Set', + 'Sized', + 'TextIO', + 'Tuple', + 'Union', + 'ValuesView', + 'cast', + 'no_type_check', + 'no_type_check_decorator', ] # for backward compatibility @@ -83,7 +138,13 @@ GenericMeta = type # The functions below are modified copies of typing internal helpers. # They are needed by _ProtocolMeta and they provide support for PEP 646. 
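# A minimal sketch, not part of the vendored file, of the sentinel-default
# pattern introduced just below (`_Sentinel` / `_marker`): a unique object
# distinguishes "argument omitted" from an explicit None. Names here are
# illustrative only.
class _Missing:
    def __repr__(self):
        return "<missing>"

_MISSING = _Missing()

def configure(timeout=_MISSING):
    if timeout is _MISSING:
        return "use the library default"
    if timeout is None:
        return "caller explicitly disabled the timeout"
    return f"timeout set to {timeout!r}"

print(configure())      # use the library default
print(configure(None))  # caller explicitly disabled the timeout
print(configure(30))    # timeout set to 30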
-_marker = object() + +class _Sentinel: + def __repr__(self): + return "" + + +_marker = _Sentinel() def _check_generic(cls, parameters, elen=_marker): @@ -184,36 +245,13 @@ else: ClassVar = typing.ClassVar -# On older versions of typing there is an internal class named "Final". -# 3.8+ -if hasattr(typing, 'Final') and sys.version_info[:2] >= (3, 7): - Final = typing.Final -# 3.7 -else: - class _FinalForm(typing._SpecialForm, _root=True): - def __repr__(self): - return 'typing_extensions.' + self._name +class _ExtensionsSpecialForm(typing._SpecialForm, _root=True): + def __repr__(self): + return 'typing_extensions.' + self._name - def __getitem__(self, parameters): - item = typing._type_check(parameters, - f'{self._name} accepts only a single type.') - return typing._GenericAlias(self, (item,)) - Final = _FinalForm('Final', - doc="""A special typing construct to indicate that a name - cannot be re-assigned or overridden in a subclass. - For example: - - MAX_SIZE: Final = 9000 - MAX_SIZE += 1 # Error reported by type checker - - class Connection: - TIMEOUT: Final[int] = 10 - class FastConnector(Connection): - TIMEOUT = 1 # Error reported by type checker - - There is no runtime checking of these properties.""") +Final = typing.Final if sys.version_info >= (3, 11): final = typing.final @@ -257,21 +295,67 @@ def IntVar(name): return typing.TypeVar(name) -# 3.8+: -if hasattr(typing, 'Literal'): +# A Literal bug was fixed in 3.11.0, 3.10.1 and 3.9.8 +if sys.version_info >= (3, 10, 1): Literal = typing.Literal -# 3.7: else: - class _LiteralForm(typing._SpecialForm, _root=True): + def _flatten_literal_params(parameters): + """An internal helper for Literal creation: flatten Literals among parameters""" + params = [] + for p in parameters: + if isinstance(p, _LiteralGenericAlias): + params.extend(p.__args__) + else: + params.append(p) + return tuple(params) - def __repr__(self): - return 'typing_extensions.' 
+ self._name + def _value_and_type_iter(params): + for p in params: + yield p, type(p) + + class _LiteralGenericAlias(typing._GenericAlias, _root=True): + def __eq__(self, other): + if not isinstance(other, _LiteralGenericAlias): + return NotImplemented + these_args_deduped = set(_value_and_type_iter(self.__args__)) + other_args_deduped = set(_value_and_type_iter(other.__args__)) + return these_args_deduped == other_args_deduped + + def __hash__(self): + return hash(frozenset(_value_and_type_iter(self.__args__))) + + class _LiteralForm(_ExtensionsSpecialForm, _root=True): + def __init__(self, doc: str): + self._name = 'Literal' + self._doc = self.__doc__ = doc def __getitem__(self, parameters): - return typing._GenericAlias(self, parameters) + if not isinstance(parameters, tuple): + parameters = (parameters,) - Literal = _LiteralForm('Literal', - doc="""A type that can be used to indicate to type checkers + parameters = _flatten_literal_params(parameters) + + val_type_pairs = list(_value_and_type_iter(parameters)) + try: + deduped_pairs = set(val_type_pairs) + except TypeError: + # unhashable parameters + pass + else: + # similar logic to typing._deduplicate on Python 3.9+ + if len(deduped_pairs) < len(val_type_pairs): + new_parameters = [] + for pair in val_type_pairs: + if pair in deduped_pairs: + new_parameters.append(pair[0]) + deduped_pairs.remove(pair) + assert not deduped_pairs, deduped_pairs + parameters = tuple(new_parameters) + + return _LiteralGenericAlias(self, parameters) + + Literal = _LiteralForm(doc="""\ + A type that can be used to indicate to type checkers that the corresponding value has a value literally equivalent to the provided parameter. For example: @@ -285,7 +369,7 @@ else: instead of a type.""") -_overload_dummy = typing._overload_dummy # noqa +_overload_dummy = typing._overload_dummy if hasattr(typing, "get_overloads"): # 3.11+ @@ -359,8 +443,6 @@ Type = typing.Type # Various ABCs mimicking those in collections.abc. # A few are simply re-exported for completeness. 
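# A minimal sketch, not part of the vendored file, of the (value, type)
# de-duplication and order-insensitive equality implemented above for the
# backported Literal; the behaviour is intended to match typing.Literal on
# Python 3.10.1 and newer.
from typing_extensions import Literal, get_args

assert get_args(Literal[1, 1, 2]) == (1, 2)      # exact duplicates collapse
assert get_args(Literal[1, True]) == (1, True)   # 1 == True, but the types differ
assert Literal[1, 2] == Literal[2, 1]            # equality ignores parameter order
assert Literal[1, 2] != Literal[1, 3]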
- - Awaitable = typing.Awaitable Coroutine = typing.Coroutine AsyncIterable = typing.AsyncIterable @@ -369,278 +451,343 @@ Deque = typing.Deque ContextManager = typing.ContextManager AsyncContextManager = typing.AsyncContextManager DefaultDict = typing.DefaultDict - -# 3.7.2+ -if hasattr(typing, 'OrderedDict'): - OrderedDict = typing.OrderedDict -# 3.7.0-3.7.2 -else: - OrderedDict = typing._alias(collections.OrderedDict, (KT, VT)) - +OrderedDict = typing.OrderedDict Counter = typing.Counter ChainMap = typing.ChainMap AsyncGenerator = typing.AsyncGenerator -NewType = typing.NewType Text = typing.Text TYPE_CHECKING = typing.TYPE_CHECKING -_PROTO_WHITELIST = ['Callable', 'Awaitable', - 'Iterable', 'Iterator', 'AsyncIterable', 'AsyncIterator', - 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', - 'ContextManager', 'AsyncContextManager'] +_PROTO_ALLOWLIST = { + 'collections.abc': [ + 'Callable', 'Awaitable', 'Iterable', 'Iterator', 'AsyncIterable', + 'Hashable', 'Sized', 'Container', 'Collection', 'Reversible', 'Buffer', + ], + 'contextlib': ['AbstractContextManager', 'AbstractAsyncContextManager'], + 'typing_extensions': ['Buffer'], +} + + +_EXCLUDED_ATTRS = { + "__abstractmethods__", "__annotations__", "__weakref__", "_is_protocol", + "_is_runtime_protocol", "__dict__", "__slots__", "__parameters__", + "__orig_bases__", "__module__", "_MutableMapping__marker", "__doc__", + "__subclasshook__", "__orig_class__", "__init__", "__new__", + "__protocol_attrs__", "__non_callable_proto_members__", + "__match_args__", +} + +if sys.version_info >= (3, 9): + _EXCLUDED_ATTRS.add("__class_getitem__") + +if sys.version_info >= (3, 12): + _EXCLUDED_ATTRS.add("__type_params__") + +_EXCLUDED_ATTRS = frozenset(_EXCLUDED_ATTRS) def _get_protocol_attrs(cls): attrs = set() for base in cls.__mro__[:-1]: # without object - if base.__name__ in ('Protocol', 'Generic'): + if base.__name__ in {'Protocol', 'Generic'}: continue annotations = getattr(base, '__annotations__', {}) - for attr in list(base.__dict__.keys()) + list(annotations.keys()): - if (not attr.startswith('_abc_') and attr not in ( - '__abstractmethods__', '__annotations__', '__weakref__', - '_is_protocol', '_is_runtime_protocol', '__dict__', - '__args__', '__slots__', - '__next_in_mro__', '__parameters__', '__origin__', - '__orig_bases__', '__extra__', '__tree_hash__', - '__doc__', '__subclasshook__', '__init__', '__new__', - '__module__', '_MutableMapping__marker', '_gorg')): + for attr in (*base.__dict__, *annotations): + if (not attr.startswith('_abc_') and attr not in _EXCLUDED_ATTRS): attrs.add(attr) return attrs -def _is_callable_members_only(cls): - return all(callable(getattr(cls, attr, None)) for attr in _get_protocol_attrs(cls)) +def _caller(depth=2): + try: + return sys._getframe(depth).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): # For platforms without _getframe() + return None -def _maybe_adjust_parameters(cls): - """Helper function used in Protocol.__init_subclass__ and _TypedDictMeta.__new__. - - The contents of this function are very similar - to logic found in typing.Generic.__init_subclass__ - on the CPython main branch. - """ - tvars = [] - if '__orig_bases__' in cls.__dict__: - tvars = typing._collect_type_vars(cls.__orig_bases__) - # Look for Generic[T1, ..., Tn] or Protocol[T1, ..., Tn]. - # If found, tvars must be a subset of it. - # If not found, tvars is it. - # Also check for and reject plain Generic, - # and reject multiple Generic[...] and/or Protocol[...]. 
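# A minimal sketch, not part of the vendored file, of what the
# member-collection helper above gathers: both methods and annotated
# attributes count as protocol members. The public `get_protocol_members`
# listed in __all__ exposes the same information.
from typing_extensions import Protocol, get_protocol_members

class HasArea(Protocol):
    units: str

    def area(self) -> float: ...

assert get_protocol_members(HasArea) == {'units', 'area'}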
- gvars = None - for base in cls.__orig_bases__: - if (isinstance(base, typing._GenericAlias) and - base.__origin__ in (typing.Generic, Protocol)): - # for error messages - the_base = base.__origin__.__name__ - if gvars is not None: - raise TypeError( - "Cannot inherit from Generic[...]" - " and/or Protocol[...] multiple types.") - gvars = base.__parameters__ - if gvars is None: - gvars = tvars - else: - tvarset = set(tvars) - gvarset = set(gvars) - if not tvarset <= gvarset: - s_vars = ', '.join(str(t) for t in tvars if t not in gvarset) - s_args = ', '.join(str(g) for g in gvars) - raise TypeError(f"Some type variables ({s_vars}) are" - f" not listed in {the_base}[{s_args}]") - tvars = gvars - cls.__parameters__ = tuple(tvars) - - -# 3.8+ -if hasattr(typing, 'Protocol'): +# `__match_args__` attribute was removed from protocol members in 3.13, +# we want to backport this change to older Python versions. +if sys.version_info >= (3, 13): Protocol = typing.Protocol -# 3.7 else: + def _allow_reckless_class_checks(depth=3): + """Allow instance and class checks for special stdlib modules. + The abc and functools modules indiscriminately call isinstance() and + issubclass() on the whole MRO of a user class, which may contain protocols. + """ + return _caller(depth) in {'abc', 'functools', None} def _no_init(self, *args, **kwargs): if type(self)._is_protocol: raise TypeError('Protocols cannot be instantiated') - class _ProtocolMeta(abc.ABCMeta): # noqa: B024 - # This metaclass is a bit unfortunate and exists only because of the lack - # of __instancehook__. + def _type_check_issubclass_arg_1(arg): + """Raise TypeError if `arg` is not an instance of `type` + in `issubclass(arg, )`. + + In most cases, this is verified by type.__subclasscheck__. + Checking it again unnecessarily would slow down issubclass() checks, + so, we don't perform this check unless we absolutely have to. + + For various error paths, however, + we want to ensure that *this* error message is shown to the user + where relevant, rather than a typing.py-specific error message. + """ + if not isinstance(arg, type): + # Same error message as for issubclass(1, int). + raise TypeError('issubclass() arg 1 must be a class') + + # Inheriting from typing._ProtocolMeta isn't actually desirable, + # but is necessary to allow typing.Protocol and typing_extensions.Protocol + # to mix without getting TypeErrors about "metaclass conflict" + class _ProtocolMeta(type(typing.Protocol)): + # This metaclass is somewhat unfortunate, + # but is necessary for several reasons... 
+ # + # NOTE: DO NOT call super() in any methods in this class + # That would call the methods on typing._ProtocolMeta on Python 3.8-3.11 + # and those are slow + def __new__(mcls, name, bases, namespace, **kwargs): + if name == "Protocol" and len(bases) < 2: + pass + elif {Protocol, typing.Protocol} & set(bases): + for base in bases: + if not ( + base in {object, typing.Generic, Protocol, typing.Protocol} + or base.__name__ in _PROTO_ALLOWLIST.get(base.__module__, []) + or is_protocol(base) + ): + raise TypeError( + f"Protocols can only inherit from other protocols, " + f"got {base!r}" + ) + return abc.ABCMeta.__new__(mcls, name, bases, namespace, **kwargs) + + def __init__(cls, *args, **kwargs): + abc.ABCMeta.__init__(cls, *args, **kwargs) + if getattr(cls, "_is_protocol", False): + cls.__protocol_attrs__ = _get_protocol_attrs(cls) + + def __subclasscheck__(cls, other): + if cls is Protocol: + return type.__subclasscheck__(cls, other) + if ( + getattr(cls, '_is_protocol', False) + and not _allow_reckless_class_checks() + ): + if not getattr(cls, '_is_runtime_protocol', False): + _type_check_issubclass_arg_1(other) + raise TypeError( + "Instance and class checks can only be used with " + "@runtime_checkable protocols" + ) + if ( + # this attribute is set by @runtime_checkable: + cls.__non_callable_proto_members__ + and cls.__dict__.get("__subclasshook__") is _proto_hook + ): + _type_check_issubclass_arg_1(other) + non_method_attrs = sorted(cls.__non_callable_proto_members__) + raise TypeError( + "Protocols with non-method members don't support issubclass()." + f" Non-method members: {str(non_method_attrs)[1:-1]}." + ) + return abc.ABCMeta.__subclasscheck__(cls, other) + def __instancecheck__(cls, instance): # We need this method for situations where attributes are # assigned in __init__. - if ((not getattr(cls, '_is_protocol', False) or - _is_callable_members_only(cls)) and - issubclass(instance.__class__, cls)): + if cls is Protocol: + return type.__instancecheck__(cls, instance) + if not getattr(cls, "_is_protocol", False): + # i.e., it's a concrete subclass of a protocol + return abc.ABCMeta.__instancecheck__(cls, instance) + + if ( + not getattr(cls, '_is_runtime_protocol', False) and + not _allow_reckless_class_checks() + ): + raise TypeError("Instance and class checks can only be used with" + " @runtime_checkable protocols") + + if abc.ABCMeta.__instancecheck__(cls, instance): return True - if cls._is_protocol: - if all(hasattr(instance, attr) and - (not callable(getattr(cls, attr, None)) or - getattr(instance, attr) is not None) - for attr in _get_protocol_attrs(cls)): - return True - return super().__instancecheck__(instance) - class Protocol(metaclass=_ProtocolMeta): - # There is quite a lot of overlapping code with typing.Generic. - # Unfortunately it is hard to avoid this while these live in two different - # modules. The duplicated code will be removed when Protocol is moved to typing. - """Base class for protocol classes. Protocol classes are defined as:: + for attr in cls.__protocol_attrs__: + try: + val = inspect.getattr_static(instance, attr) + except AttributeError: + break + # this attribute is set by @runtime_checkable: + if val is None and attr not in cls.__non_callable_proto_members__: + break + else: + return True - class Proto(Protocol): - def meth(self) -> int: - ... 
+ return False - Such classes are primarily used with static type checkers that recognize - structural subtyping (static duck-typing), for example:: + def __eq__(cls, other): + # Hack so that typing.Generic.__class_getitem__ + # treats typing_extensions.Protocol + # as equivalent to typing.Protocol + if abc.ABCMeta.__eq__(cls, other) is True: + return True + return cls is Protocol and other is typing.Protocol - class C: - def meth(self) -> int: - return 0 + # This has to be defined, or the abc-module cache + # complains about classes with this metaclass being unhashable, + # if we define only __eq__! + def __hash__(cls) -> int: + return type.__hash__(cls) - def func(x: Proto) -> int: - return x.meth() + @classmethod + def _proto_hook(cls, other): + if not cls.__dict__.get('_is_protocol', False): + return NotImplemented - func(C()) # Passes static type check + for attr in cls.__protocol_attrs__: + for base in other.__mro__: + # Check if the members appears in the class dictionary... + if attr in base.__dict__: + if base.__dict__[attr] is None: + return NotImplemented + break - See PEP 544 for details. Protocol classes decorated with - @typing_extensions.runtime act as simple-minded runtime protocol that checks - only the presence of given attributes, ignoring their type signatures. + # ...or in annotations, if it is a sub-protocol. + annotations = getattr(base, '__annotations__', {}) + if ( + isinstance(annotations, collections.abc.Mapping) + and attr in annotations + and is_protocol(other) + ): + break + else: + return NotImplemented + return True - Protocol classes can be generic, they are defined as:: - - class GenProto(Protocol[T]): - def meth(self) -> T: - ... - """ + class Protocol(typing.Generic, metaclass=_ProtocolMeta): + __doc__ = typing.Protocol.__doc__ __slots__ = () _is_protocol = True - - def __new__(cls, *args, **kwds): - if cls is Protocol: - raise TypeError("Type Protocol cannot be instantiated; " - "it can only be used as a base class") - return super().__new__(cls) - - @typing._tp_cache - def __class_getitem__(cls, params): - if not isinstance(params, tuple): - params = (params,) - if not params and cls is not typing.Tuple: - raise TypeError( - f"Parameter list to {cls.__qualname__}[...] cannot be empty") - msg = "Parameters to generic types must be types." - params = tuple(typing._type_check(p, msg) for p in params) # noqa - if cls is Protocol: - # Generic can only be subscripted with unique type variables. - if not all(isinstance(p, typing.TypeVar) for p in params): - i = 0 - while isinstance(params[i], typing.TypeVar): - i += 1 - raise TypeError( - "Parameters to Protocol[...] must all be type variables." - f" Parameter {i + 1} is {params[i]}") - if len(set(params)) != len(params): - raise TypeError( - "Parameters to Protocol[...] must all be unique") - else: - # Subscripting a regular Generic subclass. - _check_generic(cls, params, len(cls.__parameters__)) - return typing._GenericAlias(cls, params) + _is_runtime_protocol = False def __init_subclass__(cls, *args, **kwargs): - if '__orig_bases__' in cls.__dict__: - error = typing.Generic in cls.__orig_bases__ - else: - error = typing.Generic in cls.__bases__ - if error: - raise TypeError("Cannot inherit from plain Generic") - _maybe_adjust_parameters(cls) + super().__init_subclass__(*args, **kwargs) # Determine if this is a protocol or a concrete subclass. 
- if not cls.__dict__.get('_is_protocol', None): + if not cls.__dict__.get('_is_protocol', False): cls._is_protocol = any(b is Protocol for b in cls.__bases__) # Set (or override) the protocol subclass hook. - def _proto_hook(other): - if not cls.__dict__.get('_is_protocol', None): - return NotImplemented - if not getattr(cls, '_is_runtime_protocol', False): - if sys._getframe(2).f_globals['__name__'] in ['abc', 'functools']: - return NotImplemented - raise TypeError("Instance and class checks can only be used with" - " @runtime protocols") - if not _is_callable_members_only(cls): - if sys._getframe(2).f_globals['__name__'] in ['abc', 'functools']: - return NotImplemented - raise TypeError("Protocols with non-method members" - " don't support issubclass()") - if not isinstance(other, type): - # Same error as for issubclass(1, int) - raise TypeError('issubclass() arg 1 must be a class') - for attr in _get_protocol_attrs(cls): - for base in other.__mro__: - if attr in base.__dict__: - if base.__dict__[attr] is None: - return NotImplemented - break - annotations = getattr(base, '__annotations__', {}) - if (isinstance(annotations, typing.Mapping) and - attr in annotations and - isinstance(other, _ProtocolMeta) and - other._is_protocol): - break - else: - return NotImplemented - return True if '__subclasshook__' not in cls.__dict__: cls.__subclasshook__ = _proto_hook - # We have nothing more to do for non-protocols. - if not cls._is_protocol: - return - - # Check consistency of bases. - for base in cls.__bases__: - if not (base in (object, typing.Generic) or - base.__module__ == 'collections.abc' and - base.__name__ in _PROTO_WHITELIST or - isinstance(base, _ProtocolMeta) and base._is_protocol): - raise TypeError('Protocols can only inherit from other' - f' protocols, got {repr(base)}') - cls.__init__ = _no_init + # Prohibit instantiation for protocol classes + if cls._is_protocol and cls.__init__ is Protocol.__init__: + cls.__init__ = _no_init -# 3.8+ -if hasattr(typing, 'runtime_checkable'): +if sys.version_info >= (3, 13): runtime_checkable = typing.runtime_checkable -# 3.7 else: def runtime_checkable(cls): - """Mark a protocol class as a runtime protocol, so that it - can be used with isinstance() and issubclass(). Raise TypeError - if applied to a non-protocol class. + """Mark a protocol class as a runtime protocol. - This allows a simple-minded structural check very similar to the - one-offs in collections.abc such as Hashable. + Such protocol can be used with isinstance() and issubclass(). + Raise TypeError if applied to a non-protocol class. + This allows a simple-minded structural check very similar to + one trick ponies in collections.abc such as Iterable. + + For example:: + + @runtime_checkable + class Closable(Protocol): + def close(self): ... + + assert isinstance(open('/some/file'), Closable) + + Warning: this will check only the presence of the required methods, + not their type signatures! """ - if not isinstance(cls, _ProtocolMeta) or not cls._is_protocol: + if not issubclass(cls, typing.Generic) or not getattr(cls, '_is_protocol', False): raise TypeError('@runtime_checkable can be only applied to protocol classes,' - f' got {cls!r}') + ' got %r' % cls) cls._is_runtime_protocol = True + + # Only execute the following block if it's a typing_extensions.Protocol class. + # typing.Protocol classes don't need it. + if isinstance(cls, _ProtocolMeta): + # PEP 544 prohibits using issubclass() + # with protocols that have non-method members. 
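# A minimal sketch, not part of the vendored file, of the rule referenced in
# the comment above: a runtime-checkable protocol with a non-method member
# supports isinstance() but rejects issubclass().
from typing_extensions import Protocol, runtime_checkable

@runtime_checkable
class Named(Protocol):
    name: str

class User:
    def __init__(self) -> None:
        self.name = "alice"

assert isinstance(User(), Named)  # structural instance check passes
try:
    issubclass(User, Named)
except TypeError as exc:
    print(exc)  # non-method members don't support issubclass()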
+ # See gh-113320 for why we compute this attribute here, + # rather than in `_ProtocolMeta.__init__` + cls.__non_callable_proto_members__ = set() + for attr in cls.__protocol_attrs__: + try: + is_callable = callable(getattr(cls, attr, None)) + except Exception as e: + raise TypeError( + f"Failed to determine whether protocol member {attr!r} " + "is a method member" + ) from e + else: + if not is_callable: + cls.__non_callable_proto_members__.add(attr) + return cls -# Exists for backwards compatibility. +# The "runtime" alias exists for backwards compatibility. runtime = runtime_checkable -# 3.8+ -if hasattr(typing, 'SupportsIndex'): +# Our version of runtime-checkable protocols is faster on Python 3.8-3.11 +if sys.version_info >= (3, 12): + SupportsInt = typing.SupportsInt + SupportsFloat = typing.SupportsFloat + SupportsComplex = typing.SupportsComplex + SupportsBytes = typing.SupportsBytes SupportsIndex = typing.SupportsIndex -# 3.7 + SupportsAbs = typing.SupportsAbs + SupportsRound = typing.SupportsRound else: + @runtime_checkable + class SupportsInt(Protocol): + """An ABC with one abstract method __int__.""" + __slots__ = () + + @abc.abstractmethod + def __int__(self) -> int: + pass + + @runtime_checkable + class SupportsFloat(Protocol): + """An ABC with one abstract method __float__.""" + __slots__ = () + + @abc.abstractmethod + def __float__(self) -> float: + pass + + @runtime_checkable + class SupportsComplex(Protocol): + """An ABC with one abstract method __complex__.""" + __slots__ = () + + @abc.abstractmethod + def __complex__(self) -> complex: + pass + + @runtime_checkable + class SupportsBytes(Protocol): + """An ABC with one abstract method __bytes__.""" + __slots__ = () + + @abc.abstractmethod + def __bytes__(self) -> bytes: + pass + @runtime_checkable class SupportsIndex(Protocol): __slots__ = () @@ -649,8 +796,45 @@ else: def __index__(self) -> int: pass + @runtime_checkable + class SupportsAbs(Protocol[T_co]): + """ + An ABC with one abstract method __abs__ that is covariant in its return type. + """ + __slots__ = () -if hasattr(typing, "Required"): + @abc.abstractmethod + def __abs__(self) -> T_co: + pass + + @runtime_checkable + class SupportsRound(Protocol[T_co]): + """ + An ABC with one abstract method __round__ that is covariant in its return type. + """ + __slots__ = () + + @abc.abstractmethod + def __round__(self, ndigits: int = 0) -> T_co: + pass + + +def _ensure_subclassable(mro_entries): + def inner(func): + if sys.implementation.name == "pypy" and sys.version_info < (3, 9): + cls_dict = { + "__call__": staticmethod(func), + "__mro_entries__": staticmethod(mro_entries) + } + t = type(func.__name__, (), cls_dict) + return functools.update_wrapper(t(), func) + else: + func.__mro_entries__ = mro_entries + return func + return inner + + +if hasattr(typing, "ReadOnly"): # The standard library TypedDict in Python 3.8 does not store runtime information # about which (if any) keys are optional. See https://bugs.python.org/issue38834 # The standard library TypedDict in Python 3.9.0/1 does not honour the "total" @@ -658,148 +842,164 @@ if hasattr(typing, "Required"): # The standard library TypedDict below Python 3.11 does not store runtime # information about optional and required keys when using Required or NotRequired. # Generic TypedDicts are also impossible using typing.TypedDict on Python <3.11. + # Aaaand on 3.12 we add __orig_bases__ to TypedDict + # to enable better runtime introspection. + # On 3.13 we deprecate some odd ways of creating TypedDicts. 
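# A minimal sketch, not part of the vendored file: quick structural checks
# with the Supports* protocols re-exported / backported above.
from typing_extensions import SupportsAbs, SupportsIndex, SupportsInt

assert isinstance(3.25, SupportsInt)     # float defines __int__
assert not isinstance("3", SupportsInt)  # str does not
assert isinstance(7, SupportsIndex)      # int defines __index__
assert isinstance(-2.5, SupportsAbs)     # float defines __abs__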
+ # PEP 705 proposes adding the ReadOnly[] qualifier. TypedDict = typing.TypedDict _TypedDictMeta = typing._TypedDictMeta is_typeddict = typing.is_typeddict else: - def _check_fails(cls, other): - try: - if sys._getframe(1).f_globals['__name__'] not in ['abc', - 'functools', - 'typing']: - # Typed dicts are only for static structural subtyping. - raise TypeError('TypedDict does not support instance and class checks') - except (AttributeError, ValueError): - pass - return False + # 3.10.0 and later + _TAKES_MODULE = "module" in inspect.signature(typing._type_check).parameters - def _dict_new(*args, **kwargs): - if not args: - raise TypeError('TypedDict.__new__(): not enough arguments') - _, args = args[0], args[1:] # allow the "cls" keyword be passed - return dict(*args, **kwargs) - - _dict_new.__text_signature__ = '($cls, _typename, _fields=None, /, **kwargs)' - - def _typeddict_new(*args, total=True, **kwargs): - if not args: - raise TypeError('TypedDict.__new__(): not enough arguments') - _, args = args[0], args[1:] # allow the "cls" keyword be passed - if args: - typename, args = args[0], args[1:] # allow the "_typename" keyword be passed - elif '_typename' in kwargs: - typename = kwargs.pop('_typename') - import warnings - warnings.warn("Passing '_typename' as keyword argument is deprecated", - DeprecationWarning, stacklevel=2) - else: - raise TypeError("TypedDict.__new__() missing 1 required positional " - "argument: '_typename'") - if args: - try: - fields, = args # allow the "_fields" keyword be passed - except ValueError: - raise TypeError('TypedDict.__new__() takes from 2 to 3 ' - f'positional arguments but {len(args) + 2} ' - 'were given') - elif '_fields' in kwargs and len(kwargs) == 1: - fields = kwargs.pop('_fields') - import warnings - warnings.warn("Passing '_fields' as keyword argument is deprecated", - DeprecationWarning, stacklevel=2) - else: - fields = None - - if fields is None: - fields = kwargs - elif kwargs: - raise TypeError("TypedDict takes either a dict or keyword arguments," - " but not both") - - ns = {'__annotations__': dict(fields)} - try: - # Setting correct module is necessary to make typed dict classes pickleable. - ns['__module__'] = sys._getframe(1).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): - pass - - return _TypedDictMeta(typename, (), ns, total=total) - - _typeddict_new.__text_signature__ = ('($cls, _typename, _fields=None,' - ' /, *, total=True, **kwargs)') + def _get_typeddict_qualifiers(annotation_type): + while True: + annotation_origin = get_origin(annotation_type) + if annotation_origin is Annotated: + annotation_args = get_args(annotation_type) + if annotation_args: + annotation_type = annotation_args[0] + else: + break + elif annotation_origin is Required: + yield Required + annotation_type, = get_args(annotation_type) + elif annotation_origin is NotRequired: + yield NotRequired + annotation_type, = get_args(annotation_type) + elif annotation_origin is ReadOnly: + yield ReadOnly + annotation_type, = get_args(annotation_type) + else: + break class _TypedDictMeta(type): - def __init__(cls, name, bases, ns, total=True): - super().__init__(name, bases, ns) + def __new__(cls, name, bases, ns, *, total=True, closed=False): + """Create new typed dict class object. - def __new__(cls, name, bases, ns, total=True): - # Create new typed dict class object. - # This method is called directly when TypedDict is subclassed, - # or via _typeddict_new when TypedDict is instantiated. 
This way - # TypedDict supports all three syntaxes described in its docstring. - # Subclasses and instances of TypedDict return actual dictionaries - # via _dict_new. - ns['__new__'] = _typeddict_new if name == 'TypedDict' else _dict_new - # Don't insert typing.Generic into __bases__ here, - # or Generic.__init_subclass__ will raise TypeError - # in the super().__new__() call. - # Instead, monkey-patch __bases__ onto the class after it's been created. - tp_dict = super().__new__(cls, name, (dict,), ns) + This method is called when TypedDict is subclassed, + or when TypedDict is instantiated. This way + TypedDict supports all three syntax forms described in its docstring. + Subclasses and instances of TypedDict return actual dictionaries. + """ + for base in bases: + if type(base) is not _TypedDictMeta and base is not typing.Generic: + raise TypeError('cannot inherit from both a TypedDict type ' + 'and a non-TypedDict base class') - if any(issubclass(base, typing.Generic) for base in bases): - tp_dict.__bases__ = (typing.Generic, dict) - _maybe_adjust_parameters(tp_dict) + if any(issubclass(b, typing.Generic) for b in bases): + generic_base = (typing.Generic,) + else: + generic_base = () + + # typing.py generally doesn't let you inherit from plain Generic, unless + # the name of the class happens to be "Protocol" + tp_dict = type.__new__(_TypedDictMeta, "Protocol", (*generic_base, dict), ns) + tp_dict.__name__ = name + if tp_dict.__qualname__ == "Protocol": + tp_dict.__qualname__ = name + + if not hasattr(tp_dict, '__orig_bases__'): + tp_dict.__orig_bases__ = bases annotations = {} own_annotations = ns.get('__annotations__', {}) msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type" - own_annotations = { - n: typing._type_check(tp, msg) for n, tp in own_annotations.items() - } + if _TAKES_MODULE: + own_annotations = { + n: typing._type_check(tp, msg, module=tp_dict.__module__) + for n, tp in own_annotations.items() + } + else: + own_annotations = { + n: typing._type_check(tp, msg) + for n, tp in own_annotations.items() + } required_keys = set() optional_keys = set() + readonly_keys = set() + mutable_keys = set() + extra_items_type = None for base in bases: - annotations.update(base.__dict__.get('__annotations__', {})) - required_keys.update(base.__dict__.get('__required_keys__', ())) - optional_keys.update(base.__dict__.get('__optional_keys__', ())) + base_dict = base.__dict__ + + annotations.update(base_dict.get('__annotations__', {})) + required_keys.update(base_dict.get('__required_keys__', ())) + optional_keys.update(base_dict.get('__optional_keys__', ())) + readonly_keys.update(base_dict.get('__readonly_keys__', ())) + mutable_keys.update(base_dict.get('__mutable_keys__', ())) + base_extra_items_type = base_dict.get('__extra_items__', None) + if base_extra_items_type is not None: + extra_items_type = base_extra_items_type + + if closed and extra_items_type is None: + extra_items_type = Never + if closed and "__extra_items__" in own_annotations: + annotation_type = own_annotations.pop("__extra_items__") + qualifiers = set(_get_typeddict_qualifiers(annotation_type)) + if Required in qualifiers: + raise TypeError( + "Special key __extra_items__ does not support " + "Required" + ) + if NotRequired in qualifiers: + raise TypeError( + "Special key __extra_items__ does not support " + "NotRequired" + ) + extra_items_type = annotation_type annotations.update(own_annotations) for annotation_key, annotation_type in own_annotations.items(): - annotation_origin = 
get_origin(annotation_type) - if annotation_origin is Annotated: - annotation_args = get_args(annotation_type) - if annotation_args: - annotation_type = annotation_args[0] - annotation_origin = get_origin(annotation_type) + qualifiers = set(_get_typeddict_qualifiers(annotation_type)) - if annotation_origin is Required: + if Required in qualifiers: required_keys.add(annotation_key) - elif annotation_origin is NotRequired: + elif NotRequired in qualifiers: optional_keys.add(annotation_key) elif total: required_keys.add(annotation_key) else: optional_keys.add(annotation_key) + if ReadOnly in qualifiers: + mutable_keys.discard(annotation_key) + readonly_keys.add(annotation_key) + else: + mutable_keys.add(annotation_key) + readonly_keys.discard(annotation_key) tp_dict.__annotations__ = annotations tp_dict.__required_keys__ = frozenset(required_keys) tp_dict.__optional_keys__ = frozenset(optional_keys) + tp_dict.__readonly_keys__ = frozenset(readonly_keys) + tp_dict.__mutable_keys__ = frozenset(mutable_keys) if not hasattr(tp_dict, '__total__'): tp_dict.__total__ = total + tp_dict.__closed__ = closed + tp_dict.__extra_items__ = extra_items_type return tp_dict - __instancecheck__ = __subclasscheck__ = _check_fails + __call__ = dict # static method - TypedDict = _TypedDictMeta('TypedDict', (dict,), {}) - TypedDict.__module__ = __name__ - TypedDict.__doc__ = \ - """A simple typed name space. At runtime it is equivalent to a plain dict. + def __subclasscheck__(cls, other): + # Typed dicts are only for static structural subtyping. + raise TypeError('TypedDict does not support instance and class checks') - TypedDict creates a dictionary type that expects all of its - instances to have a certain set of keys, with each key + __instancecheck__ = __subclasscheck__ + + _TypedDict = type.__new__(_TypedDictMeta, 'TypedDict', (), {}) + + @_ensure_subclassable(lambda bases: (_TypedDict,)) + def TypedDict(typename, fields=_marker, /, *, total=True, closed=False, **kwargs): + """A simple typed namespace. At runtime it is equivalent to a plain dict. + + TypedDict creates a dictionary type such that a type checker will expect all + instances to have a certain set of keys, where each key is associated with a value of a consistent type. This expectation - is not checked at runtime but is only enforced by type checkers. + is not checked at runtime. + Usage:: class Point2D(TypedDict): @@ -814,14 +1014,71 @@ else: The type info can be accessed via the Point2D.__annotations__ dict, and the Point2D.__required_keys__ and Point2D.__optional_keys__ frozensets. - TypedDict supports two additional equivalent forms:: + TypedDict supports an additional equivalent form:: - Point2D = TypedDict('Point2D', x=int, y=int, label=str) Point2D = TypedDict('Point2D', {'x': int, 'y': int, 'label': str}) - The class syntax is only supported in Python 3.6+, while two other - syntax forms work for Python 2.7 and 3.2+ + By default, all keys must be present in a TypedDict. It is possible + to override this by specifying totality:: + + class Point2D(TypedDict, total=False): + x: int + y: int + + This means that a Point2D TypedDict can have any of the keys omitted. A type + checker is only expected to support a literal False or True as the value of + the total argument. True is the default, and makes all items defined in the + class body be required. 
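# A minimal sketch, not part of the vendored file, of the totality semantics
# described above and the Required / NotRequired handling implemented by the
# metaclass earlier in this hunk, using the introspection attributes it
# populates.
from typing_extensions import NotRequired, Required, TypedDict

class Point2D(TypedDict, total=False):
    x: Required[int]        # required even though total=False
    y: int                  # optional because total=False
    label: NotRequired[str]

assert Point2D.__required_keys__ == frozenset({'x'})
assert Point2D.__optional_keys__ == frozenset({'y', 'label'})
assert Point2D.__total__ is False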
+ + The Required and NotRequired special forms can also be used to mark + individual keys as being required or not required:: + + class Point2D(TypedDict): + x: int # the "x" key must always be present (Required is the default) + y: NotRequired[int] # the "y" key can be omitted + + See PEP 655 for more details on Required and NotRequired. """ + if fields is _marker or fields is None: + if fields is _marker: + deprecated_thing = "Failing to pass a value for the 'fields' parameter" + else: + deprecated_thing = "Passing `None` as the 'fields' parameter" + + example = f"`{typename} = TypedDict({typename!r}, {{}})`" + deprecation_msg = ( + f"{deprecated_thing} is deprecated and will be disallowed in " + "Python 3.15. To create a TypedDict class with 0 fields " + "using the functional syntax, pass an empty dictionary, e.g. " + ) + example + "." + warnings.warn(deprecation_msg, DeprecationWarning, stacklevel=2) + if closed is not False and closed is not True: + kwargs["closed"] = closed + closed = False + fields = kwargs + elif kwargs: + raise TypeError("TypedDict takes either a dict or keyword arguments," + " but not both") + if kwargs: + if sys.version_info >= (3, 13): + raise TypeError("TypedDict takes no keyword arguments") + warnings.warn( + "The kwargs-based syntax for TypedDict definitions is deprecated " + "in Python 3.11, will be removed in Python 3.13, and may not be " + "understood by third-party type checkers.", + DeprecationWarning, + stacklevel=2, + ) + + ns = {'__annotations__': dict(fields)} + module = _caller() + if module is not None: + # Setting correct module is necessary to make typed dict classes pickleable. + ns['__module__'] = module + + td = _TypedDictMeta(typename, (), ns, total=total, closed=closed) + td.__orig_bases__ = (TypedDict,) + return td if hasattr(typing, "_TypedDictMeta"): _TYPEDDICT_TYPES = (typing._TypedDictMeta, _TypedDictMeta) @@ -839,14 +1096,17 @@ else: is_typeddict(Film) # => True is_typeddict(Union[list, str]) # => False """ - return isinstance(tp, tuple(_TYPEDDICT_TYPES)) + # On 3.8, this would otherwise return True + if hasattr(typing, "TypedDict") and tp is typing.TypedDict: + return False + return isinstance(tp, _TYPEDDICT_TYPES) if hasattr(typing, "assert_type"): assert_type = typing.assert_type else: - def assert_type(__val, __typ): + def assert_type(val, typ, /): """Assert (to the type checker) that the value is of the given type. When the type checker encounters a call to assert_type(), it @@ -859,15 +1119,12 @@ else: At runtime this returns the first argument unchanged and otherwise does nothing. 
""" - return __val + return val -if hasattr(typing, "Required"): +if hasattr(typing, "Required"): # 3.11+ get_type_hints = typing.get_type_hints -else: - import functools - import types - +else: # <=3.10 # replaces _strip_annotations() def _strip_extras(t): """Strips Annotated, Required and NotRequired from a given type.""" @@ -880,12 +1137,12 @@ else: if stripped_args == t.__args__: return t return t.copy_with(stripped_args) - if hasattr(types, "GenericAlias") and isinstance(t, types.GenericAlias): + if hasattr(_types, "GenericAlias") and isinstance(t, _types.GenericAlias): stripped_args = tuple(_strip_extras(a) for a in t.__args__) if stripped_args == t.__args__: return t - return types.GenericAlias(t.__origin__, stripped_args) - if hasattr(types, "UnionType") and isinstance(t, types.UnionType): + return _types.GenericAlias(t.__origin__, stripped_args) + if hasattr(_types, "UnionType") and isinstance(t, _types.UnionType): stripped_args = tuple(_strip_extras(a) for a in t.__args__) if stripped_args == t.__args__: return t @@ -925,11 +1182,11 @@ else: - If two dict arguments are passed, they specify globals and locals, respectively. """ - if hasattr(typing, "Annotated"): + if hasattr(typing, "Annotated"): # 3.9+ hint = typing.get_type_hints( obj, globalns=globalns, localns=localns, include_extras=True ) - else: + else: # 3.8 hint = typing.get_type_hints(obj, globalns=globalns, localns=localns) if include_extras: return hint @@ -942,7 +1199,7 @@ if hasattr(typing, 'Annotated'): # Not exported and not a public API, but needed for get_origin() and get_args() # to work. _AnnotatedAlias = typing._AnnotatedAlias -# 3.7-3.8 +# 3.8 else: class _AnnotatedAlias(typing._GenericAlias, _root=True): """Runtime representation of an annotated type. @@ -1047,7 +1304,7 @@ else: if sys.version_info[:2] >= (3, 10): get_origin = typing.get_origin get_args = typing.get_args -# 3.7-3.9 +# 3.8-3.9 else: try: # 3.9+ @@ -1112,11 +1369,7 @@ if hasattr(typing, 'TypeAlias'): TypeAlias = typing.TypeAlias # 3.9 elif sys.version_info[:2] >= (3, 9): - class _TypeAliasForm(typing._SpecialForm, _root=True): - def __repr__(self): - return 'typing_extensions.' + self._name - - @_TypeAliasForm + @_ExtensionsSpecialForm def TypeAlias(self, parameters): """Special marker indicating that an assignment should be recognized as a proper type alias definition by type @@ -1129,68 +1382,89 @@ elif sys.version_info[:2] >= (3, 9): It's invalid when used anywhere except as in the example above. """ raise TypeError(f"{self} is not subscriptable") -# 3.7-3.8 +# 3.8 else: - class _TypeAliasForm(typing._SpecialForm, _root=True): - def __repr__(self): - return 'typing_extensions.' + self._name + TypeAlias = _ExtensionsSpecialForm( + 'TypeAlias', + doc="""Special marker indicating that an assignment should + be recognized as a proper type alias definition by type + checkers. - TypeAlias = _TypeAliasForm('TypeAlias', - doc="""Special marker indicating that an assignment should - be recognized as a proper type alias definition by type - checkers. 
+ For example:: - For example:: + Predicate: TypeAlias = Callable[..., bool] - Predicate: TypeAlias = Callable[..., bool] + It's invalid when used anywhere except as in the example + above.""" + ) - It's invalid when used anywhere except as in the example - above.""") + +def _set_default(type_param, default): + if isinstance(default, (tuple, list)): + type_param.__default__ = tuple((typing._type_check(d, "Default must be a type") + for d in default)) + elif default != _marker: + if isinstance(type_param, ParamSpec) and default is ...: # ... not valid <3.11 + type_param.__default__ = default + else: + type_param.__default__ = typing._type_check(default, "Default must be a type") + else: + type_param.__default__ = None + + +def _set_module(typevarlike): + # for pickling: + def_mod = _caller(depth=3) + if def_mod != 'typing_extensions': + typevarlike.__module__ = def_mod class _DefaultMixin: """Mixin for TypeVarLike defaults.""" __slots__ = () + __init__ = _set_default - def __init__(self, default): - if isinstance(default, (tuple, list)): - self.__default__ = tuple((typing._type_check(d, "Default must be a type") - for d in default)) - elif default: - self.__default__ = typing._type_check(default, "Default must be a type") - else: - self.__default__ = None + +# Classes using this metaclass must provide a _backported_typevarlike ClassVar +class _TypeVarLikeMeta(type): + def __instancecheck__(cls, __instance: Any) -> bool: + return isinstance(__instance, cls._backported_typevarlike) # Add default and infer_variance parameters from PEP 696 and 695 -class TypeVar(typing.TypeVar, _DefaultMixin, _root=True): +class TypeVar(metaclass=_TypeVarLikeMeta): """Type variable.""" - __module__ = 'typing' + _backported_typevarlike = typing.TypeVar - def __init__(self, name, *constraints, bound=None, - covariant=False, contravariant=False, - default=None, infer_variance=False): - super().__init__(name, *constraints, bound=bound, covariant=covariant, - contravariant=contravariant) - _DefaultMixin.__init__(self, default) - self.__infer_variance__ = infer_variance + def __new__(cls, name, *constraints, bound=None, + covariant=False, contravariant=False, + default=_marker, infer_variance=False): + if hasattr(typing, "TypeAliasType"): + # PEP 695 implemented (3.12+), can pass infer_variance to typing.TypeVar + typevar = typing.TypeVar(name, *constraints, bound=bound, + covariant=covariant, contravariant=contravariant, + infer_variance=infer_variance) + else: + typevar = typing.TypeVar(name, *constraints, bound=bound, + covariant=covariant, contravariant=contravariant) + if infer_variance and (covariant or contravariant): + raise ValueError("Variance cannot be specified with infer_variance.") + typevar.__infer_variance__ = infer_variance + _set_default(typevar, default) + _set_module(typevar) + return typevar - # for pickling: - try: - def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): - def_mod = None - if def_mod != 'typing_extensions': - self.__module__ = def_mod + def __init_subclass__(cls) -> None: + raise TypeError(f"type '{__name__}.TypeVar' is not an acceptable base type") # Python 3.10+ has PEP 612 if hasattr(typing, 'ParamSpecArgs'): ParamSpecArgs = typing.ParamSpecArgs ParamSpecKwargs = typing.ParamSpecKwargs -# 3.7-3.9 +# 3.8-3.9 else: class _Immutable: """Mixin to indicate that object should not be copied.""" @@ -1251,27 +1525,35 @@ else: # 3.10+ if hasattr(typing, 'ParamSpec'): - # Add default Parameter - PEP 696 - class ParamSpec(typing.ParamSpec, 
_DefaultMixin, _root=True): - """Parameter specification variable.""" + # Add default parameter - PEP 696 + class ParamSpec(metaclass=_TypeVarLikeMeta): + """Parameter specification.""" - __module__ = 'typing' + _backported_typevarlike = typing.ParamSpec - def __init__(self, name, *, bound=None, covariant=False, contravariant=False, - default=None): - super().__init__(name, bound=bound, covariant=covariant, - contravariant=contravariant) - _DefaultMixin.__init__(self, default) + def __new__(cls, name, *, bound=None, + covariant=False, contravariant=False, + infer_variance=False, default=_marker): + if hasattr(typing, "TypeAliasType"): + # PEP 695 implemented, can pass infer_variance to typing.TypeVar + paramspec = typing.ParamSpec(name, bound=bound, + covariant=covariant, + contravariant=contravariant, + infer_variance=infer_variance) + else: + paramspec = typing.ParamSpec(name, bound=bound, + covariant=covariant, + contravariant=contravariant) + paramspec.__infer_variance__ = infer_variance - # for pickling: - try: - def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): - def_mod = None - if def_mod != 'typing_extensions': - self.__module__ = def_mod + _set_default(paramspec, default) + _set_module(paramspec) + return paramspec -# 3.7-3.9 + def __init_subclass__(cls) -> None: + raise TypeError(f"type '{__name__}.ParamSpec' is not an acceptable base type") + +# 3.8-3.9 else: # Inherits from list as a workaround for Callable checks in Python < 3.9.2. @@ -1334,11 +1616,12 @@ else: return ParamSpecKwargs(self) def __init__(self, name, *, bound=None, covariant=False, contravariant=False, - default=None): + infer_variance=False, default=_marker): super().__init__([self]) self.__name__ = name self.__covariant__ = bool(covariant) self.__contravariant__ = bool(contravariant) + self.__infer_variance__ = bool(infer_variance) if bound: self.__bound__ = typing._type_check(bound, 'Bound must be a type.') else: @@ -1346,15 +1629,14 @@ else: _DefaultMixin.__init__(self, default) # for pickling: - try: - def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): - def_mod = None + def_mod = _caller() if def_mod != 'typing_extensions': self.__module__ = def_mod def __repr__(self): - if self.__covariant__: + if self.__infer_variance__: + prefix = '' + elif self.__covariant__: prefix = '+' elif self.__contravariant__: prefix = '-' @@ -1376,7 +1658,7 @@ else: pass -# 3.7-3.9 +# 3.8-3.9 if not hasattr(typing, 'Concatenate'): # Inherits from list as a workaround for Callable checks in Python < 3.9.2. class _ConcatenateGenericAlias(list): @@ -1411,7 +1693,7 @@ if not hasattr(typing, 'Concatenate'): ) -# 3.7-3.9 +# 3.8-3.9 @typing._tp_cache def _concatenate_getitem(self, parameters): if parameters == (): @@ -1429,10 +1711,10 @@ def _concatenate_getitem(self, parameters): # 3.10+ if hasattr(typing, 'Concatenate'): Concatenate = typing.Concatenate - _ConcatenateGenericAlias = typing._ConcatenateGenericAlias # noqa + _ConcatenateGenericAlias = typing._ConcatenateGenericAlias # noqa: F811 # 3.9 elif sys.version_info[:2] >= (3, 9): - @_TypeAliasForm + @_ExtensionsSpecialForm def Concatenate(self, parameters): """Used in conjunction with ``ParamSpec`` and ``Callable`` to represent a higher order function which adds, removes or transforms parameters of a @@ -1445,12 +1727,9 @@ elif sys.version_info[:2] >= (3, 9): See PEP 612 for detailed information. 
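# A minimal sketch, not part of the vendored file, of the ParamSpec /
# Concatenate combination these forms enable (PEP 612): a decorator that
# supplies one leading argument while preserving the rest of the signature
# for type checkers. The names are illustrative only.
from typing import Callable, TypeVar
from typing_extensions import Concatenate, ParamSpec

P = ParamSpec("P")
R = TypeVar("R")

def with_dsn(func: Callable[Concatenate[str, P], R]) -> Callable[P, R]:
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        return func("postgres://localhost", *args, **kwargs)  # inject the leading argument
    return wrapper

@with_dsn
def count_rows(dsn: str, table: str) -> int:
    return len(dsn) + len(table)  # placeholder body

print(count_rows("users"))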
""" return _concatenate_getitem(self, parameters) -# 3.7-8 +# 3.8 else: - class _ConcatenateForm(typing._SpecialForm, _root=True): - def __repr__(self): - return 'typing_extensions.' + self._name - + class _ConcatenateForm(_ExtensionsSpecialForm, _root=True): def __getitem__(self, parameters): return _concatenate_getitem(self, parameters) @@ -1472,11 +1751,7 @@ if hasattr(typing, 'TypeGuard'): TypeGuard = typing.TypeGuard # 3.9 elif sys.version_info[:2] >= (3, 9): - class _TypeGuardForm(typing._SpecialForm, _root=True): - def __repr__(self): - return 'typing_extensions.' + self._name - - @_TypeGuardForm + @_ExtensionsSpecialForm def TypeGuard(self, parameters): """Special typing form used to annotate the return type of a user-defined type guard function. ``TypeGuard`` only accepts a single type argument. @@ -1522,13 +1797,9 @@ elif sys.version_info[:2] >= (3, 9): """ item = typing._type_check(parameters, f'{self} accepts only a single type.') return typing._GenericAlias(self, (item,)) -# 3.7-3.8 +# 3.8 else: - class _TypeGuardForm(typing._SpecialForm, _root=True): - - def __repr__(self): - return 'typing_extensions.' + self._name - + class _TypeGuardForm(_ExtensionsSpecialForm, _root=True): def __getitem__(self, parameters): item = typing._type_check(parameters, f'{self._name} accepts only a single type') @@ -1579,6 +1850,98 @@ else: PEP 647 (User-Defined Type Guards). """) +# 3.13+ +if hasattr(typing, 'TypeIs'): + TypeIs = typing.TypeIs +# 3.9 +elif sys.version_info[:2] >= (3, 9): + @_ExtensionsSpecialForm + def TypeIs(self, parameters): + """Special typing form used to annotate the return type of a user-defined + type narrower function. ``TypeIs`` only accepts a single type argument. + At runtime, functions marked this way should return a boolean. + + ``TypeIs`` aims to benefit *type narrowing* -- a technique used by static + type checkers to determine a more precise type of an expression within a + program's code flow. Usually type narrowing is done by analyzing + conditional code flow and applying the narrowing to a block of code. The + conditional expression here is sometimes referred to as a "type guard". + + Sometimes it would be convenient to use a user-defined boolean function + as a type guard. Such a function should use ``TypeIs[...]`` as its + return type to alert static type checkers to this intention. + + Using ``-> TypeIs`` tells the static type checker that for a given + function: + + 1. The return value is a boolean. + 2. If the return value is ``True``, the type of its argument + is the intersection of the type inside ``TypeGuard`` and the argument's + previously known type. + + For example:: + + def is_awaitable(val: object) -> TypeIs[Awaitable[Any]]: + return hasattr(val, '__await__') + + def f(val: Union[int, Awaitable[int]]) -> int: + if is_awaitable(val): + assert_type(val, Awaitable[int]) + else: + assert_type(val, int) + + ``TypeIs`` also works with type variables. For more information, see + PEP 742 (Narrowing types with TypeIs). + """ + item = typing._type_check(parameters, f'{self} accepts only a single type.') + return typing._GenericAlias(self, (item,)) +# 3.8 +else: + class _TypeIsForm(_ExtensionsSpecialForm, _root=True): + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type') + return typing._GenericAlias(self, (item,)) + + TypeIs = _TypeIsForm( + 'TypeIs', + doc="""Special typing form used to annotate the return type of a user-defined + type narrower function. 
``TypeIs`` only accepts a single type argument. + At runtime, functions marked this way should return a boolean. + + ``TypeIs`` aims to benefit *type narrowing* -- a technique used by static + type checkers to determine a more precise type of an expression within a + program's code flow. Usually type narrowing is done by analyzing + conditional code flow and applying the narrowing to a block of code. The + conditional expression here is sometimes referred to as a "type guard". + + Sometimes it would be convenient to use a user-defined boolean function + as a type guard. Such a function should use ``TypeIs[...]`` as its + return type to alert static type checkers to this intention. + + Using ``-> TypeIs`` tells the static type checker that for a given + function: + + 1. The return value is a boolean. + 2. If the return value is ``True``, the type of its argument + is the intersection of the type inside ``TypeGuard`` and the argument's + previously known type. + + For example:: + + def is_awaitable(val: object) -> TypeIs[Awaitable[Any]]: + return hasattr(val, '__await__') + + def f(val: Union[int, Awaitable[int]]) -> int: + if is_awaitable(val): + assert_type(val, Awaitable[int]) + else: + assert_type(val, int) + + ``TypeIs`` also works with type variables. For more information, see + PEP 742 (Narrowing types with TypeIs). + """) + # Vendored from cpython typing._SpecialFrom class _SpecialForm(typing._Final, _root=True): @@ -1624,7 +1987,7 @@ class _SpecialForm(typing._Final, _root=True): return self._getitem(self, parameters) -if hasattr(typing, "LiteralString"): +if hasattr(typing, "LiteralString"): # 3.11+ LiteralString = typing.LiteralString else: @_SpecialForm @@ -1647,7 +2010,7 @@ else: raise TypeError(f"{self} is not subscriptable") -if hasattr(typing, "Self"): +if hasattr(typing, "Self"): # 3.11+ Self = typing.Self else: @_SpecialForm @@ -1668,7 +2031,7 @@ else: raise TypeError(f"{self} is not subscriptable") -if hasattr(typing, "Never"): +if hasattr(typing, "Never"): # 3.11+ Never = typing.Never else: @_SpecialForm @@ -1698,14 +2061,10 @@ else: raise TypeError(f"{self} is not subscriptable") -if hasattr(typing, 'Required'): +if hasattr(typing, 'Required'): # 3.11+ Required = typing.Required NotRequired = typing.NotRequired -elif sys.version_info[:2] >= (3, 9): - class _ExtensionsSpecialForm(typing._SpecialForm, _root=True): - def __repr__(self): - return 'typing_extensions.' + self._name - +elif sys.version_info[:2] >= (3, 9): # 3.9-3.10 @_ExtensionsSpecialForm def Required(self, parameters): """A special typing construct to mark a key of a total=False TypedDict @@ -1743,11 +2102,8 @@ elif sys.version_info[:2] >= (3, 9): item = typing._type_check(parameters, f'{self._name} accepts only a single type.') return typing._GenericAlias(self, (item,)) -else: - class _RequiredForm(typing._SpecialForm, _root=True): - def __repr__(self): - return 'typing_extensions.' + self._name - +else: # 3.8 + class _RequiredForm(_ExtensionsSpecialForm, _root=True): def __getitem__(self, parameters): item = typing._type_check(parameters, f'{self._name} accepts only a single type.') @@ -1786,59 +2142,129 @@ else: """) -if hasattr(typing, "Unpack"): # 3.11+ +if hasattr(typing, 'ReadOnly'): + ReadOnly = typing.ReadOnly +elif sys.version_info[:2] >= (3, 9): # 3.9-3.12 + @_ExtensionsSpecialForm + def ReadOnly(self, parameters): + """A special typing construct to mark an item of a TypedDict as read-only. 
+ + For example: + + class Movie(TypedDict): + title: ReadOnly[str] + year: int + + def mutate_movie(m: Movie) -> None: + m["year"] = 1992 # allowed + m["title"] = "The Matrix" # typechecker error + + There is no runtime checking for this property. + """ + item = typing._type_check(parameters, f'{self._name} accepts only a single type.') + return typing._GenericAlias(self, (item,)) + +else: # 3.8 + class _ReadOnlyForm(_ExtensionsSpecialForm, _root=True): + def __getitem__(self, parameters): + item = typing._type_check(parameters, + f'{self._name} accepts only a single type.') + return typing._GenericAlias(self, (item,)) + + ReadOnly = _ReadOnlyForm( + 'ReadOnly', + doc="""A special typing construct to mark a key of a TypedDict as read-only. + + For example: + + class Movie(TypedDict): + title: ReadOnly[str] + year: int + + def mutate_movie(m: Movie) -> None: + m["year"] = 1992 # allowed + m["title"] = "The Matrix" # typechecker error + + There is no runtime checking for this property. + """) + + +_UNPACK_DOC = """\ +Type unpack operator. + +The type unpack operator takes the child types from some container type, +such as `tuple[int, str]` or a `TypeVarTuple`, and 'pulls them out'. For +example: + + # For some generic class `Foo`: + Foo[Unpack[tuple[int, str]]] # Equivalent to Foo[int, str] + + Ts = TypeVarTuple('Ts') + # Specifies that `Bar` is generic in an arbitrary number of types. + # (Think of `Ts` as a tuple of an arbitrary number of individual + # `TypeVar`s, which the `Unpack` is 'pulling out' directly into the + # `Generic[]`.) + class Bar(Generic[Unpack[Ts]]): ... + Bar[int] # Valid + Bar[int, str] # Also valid + +From Python 3.11, this can also be done using the `*` operator: + + Foo[*tuple[int, str]] + class Bar(Generic[*Ts]): ... + +The operator can also be used along with a `TypedDict` to annotate +`**kwargs` in a function signature. For instance: + + class Movie(TypedDict): + name: str + year: int + + # This function expects two keyword arguments - *name* of type `str` and + # *year* of type `int`. + def foo(**kwargs: Unpack[Movie]): ... + +Note that there is only some runtime checking of this operator. Not +everything the runtime allows may be accepted by static type checkers. + +For more information, see PEP 646 and PEP 692. +""" + + +if sys.version_info >= (3, 12): # PEP 692 changed the repr of Unpack[] Unpack = typing.Unpack -elif sys.version_info[:2] >= (3, 9): - class _UnpackSpecialForm(typing._SpecialForm, _root=True): - def __repr__(self): - return 'typing_extensions.' + self._name + + def _is_unpack(obj): + return get_origin(obj) is Unpack + +elif sys.version_info[:2] >= (3, 9): # 3.9+ + class _UnpackSpecialForm(_ExtensionsSpecialForm, _root=True): + def __init__(self, getitem): + super().__init__(getitem) + self.__doc__ = _UNPACK_DOC class _UnpackAlias(typing._GenericAlias, _root=True): __class__ = typing.TypeVar @_UnpackSpecialForm def Unpack(self, parameters): - """A special typing construct to unpack a variadic type. For example: - - Shape = TypeVarTuple('Shape') - Batch = NewType('Batch', int) - - def add_batch_axis( - x: Array[Unpack[Shape]] - ) -> Array[Batch, Unpack[Shape]]: ...
- - """ item = typing._type_check(parameters, f'{self._name} accepts only a single type.') return _UnpackAlias(self, (item,)) def _is_unpack(obj): return isinstance(obj, _UnpackAlias) -else: +else: # 3.8 class _UnpackAlias(typing._GenericAlias, _root=True): __class__ = typing.TypeVar - class _UnpackForm(typing._SpecialForm, _root=True): - def __repr__(self): - return 'typing_extensions.' + self._name - + class _UnpackForm(_ExtensionsSpecialForm, _root=True): def __getitem__(self, parameters): item = typing._type_check(parameters, f'{self._name} accepts only a single type.') return _UnpackAlias(self, (item,)) - Unpack = _UnpackForm( - 'Unpack', - doc="""A special typing construct to unpack a variadic type. For example: - - Shape = TypeVarTuple('Shape') - Batch = NewType('Batch', int) - - def add_batch_axis( - x: Array[Unpack[Shape]] - ) -> Array[Batch, Unpack[Shape]]: ... - - """) + Unpack = _UnpackForm('Unpack', doc=_UNPACK_DOC) def _is_unpack(obj): return isinstance(obj, _UnpackAlias) @@ -1846,23 +2272,22 @@ else: if hasattr(typing, "TypeVarTuple"): # 3.11+ - # Add default Parameter - PEP 696 - class TypeVarTuple(typing.TypeVarTuple, _DefaultMixin, _root=True): + # Add default parameter - PEP 696 + class TypeVarTuple(metaclass=_TypeVarLikeMeta): """Type variable tuple.""" - def __init__(self, name, *, default=None): - super().__init__(name) - _DefaultMixin.__init__(self, default) + _backported_typevarlike = typing.TypeVarTuple - # for pickling: - try: - def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): - def_mod = None - if def_mod != 'typing_extensions': - self.__module__ = def_mod + def __new__(cls, name, *, default=_marker): + tvt = typing.TypeVarTuple(name) + _set_default(tvt, default) + _set_module(tvt) + return tvt -else: + def __init_subclass__(self, *args, **kwds): + raise TypeError("Cannot subclass special typing classes") + +else: # <=3.10 class TypeVarTuple(_DefaultMixin): """Type variable tuple. @@ -1913,15 +2338,12 @@ else: def __iter__(self): yield self.__unpacked__ - def __init__(self, name, *, default=None): + def __init__(self, name, *, default=_marker): self.__name__ = name _DefaultMixin.__init__(self, default) # for pickling: - try: - def_mod = sys._getframe(1).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): - def_mod = None + def_mod = _caller() if def_mod != 'typing_extensions': self.__module__ = def_mod @@ -1944,10 +2366,10 @@ else: raise TypeError("Cannot subclass special typing classes") -if hasattr(typing, "reveal_type"): +if hasattr(typing, "reveal_type"): # 3.11+ reveal_type = typing.reveal_type -else: - def reveal_type(__obj: T) -> T: +else: # <=3.10 + def reveal_type(obj: T, /) -> T: """Reveal the inferred type of a variable. When a static type checker encounters a call to ``reveal_type()``, @@ -1963,14 +2385,14 @@ else: argument and returns it unchanged. """ - print(f"Runtime type is {type(__obj).__name__!r}", file=sys.stderr) - return __obj + print(f"Runtime type is {type(obj).__name__!r}", file=sys.stderr) + return obj -if hasattr(typing, "assert_never"): +if hasattr(typing, "assert_never"): # 3.11+ assert_never = typing.assert_never -else: - def assert_never(__arg: Never) -> Never: +else: # <=3.10 + def assert_never(arg: Never, /) -> Never: """Assert to the type checker that a line of code is unreachable. 
Example:: @@ -1993,14 +2415,16 @@ else: raise AssertionError("Expected code to be unreachable") -if hasattr(typing, 'dataclass_transform'): +if sys.version_info >= (3, 12): # 3.12+ + # dataclass_transform exists in 3.11 but lacks the frozen_default parameter dataclass_transform = typing.dataclass_transform -else: +else: # <=3.11 def dataclass_transform( *, eq_default: bool = True, order_default: bool = False, kw_only_default: bool = False, + frozen_default: bool = False, field_specifiers: typing.Tuple[ typing.Union[typing.Type[typing.Any], typing.Callable[..., typing.Any]], ... @@ -2057,6 +2481,8 @@ else: assumed to be True or False if it is omitted by the caller. - ``kw_only_default`` indicates whether the ``kw_only`` parameter is assumed to be True or False if it is omitted by the caller. + - ``frozen_default`` indicates whether the ``frozen`` parameter is + assumed to be True or False if it is omitted by the caller. - ``field_specifiers`` specifies a static list of supported classes or functions that describe fields, similar to ``dataclasses.field()``. @@ -2071,6 +2497,7 @@ else: "eq_default": eq_default, "order_default": order_default, "kw_only_default": kw_only_default, + "frozen_default": frozen_default, "field_specifiers": field_specifiers, "kwargs": kwargs, } @@ -2078,18 +2505,18 @@ else: return decorator -if hasattr(typing, "override"): +if hasattr(typing, "override"): # 3.12+ override = typing.override -else: +else: # <=3.11 _F = typing.TypeVar("_F", bound=typing.Callable[..., typing.Any]) - def override(__arg: _F) -> _F: + def override(arg: _F, /) -> _F: """Indicate that a method is intended to override a method in a base class. Usage: class Base: - def method(self) -> None: ... + def method(self) -> None: pass class Child(Base): @@ -2102,10 +2529,156 @@ else: This helps prevent bugs that may occur when a base class is changed without an equivalent change to a child class. + There is no runtime checking of these properties. The decorator + sets the ``__override__`` attribute to ``True`` on the decorated object + to allow runtime introspection. + See PEP 698 for details. """ - return __arg + try: + arg.__override__ = True + except (AttributeError, TypeError): + # Skip the attribute silently if it is not writable. + # AttributeError happens if the object has __slots__ or a + # read-only property, TypeError if it's a builtin class. + pass + return arg + + +if hasattr(warnings, "deprecated"): + deprecated = warnings.deprecated +else: + _T = typing.TypeVar("_T") + + class deprecated: + """Indicate that a class, function or overload is deprecated. + + When this decorator is applied to an object, the type checker + will generate a diagnostic on usage of the deprecated object. + + Usage: + + @deprecated("Use B instead") + class A: + pass + + @deprecated("Use g instead") + def f(): + pass + + @overload + @deprecated("int support is deprecated") + def g(x: int) -> int: ... + @overload + def g(x: str) -> int: ... + + The warning specified by *category* will be emitted at runtime + on use of deprecated objects. For functions, that happens on calls; + for classes, on instantiation and on creation of subclasses. + If the *category* is ``None``, no warning is emitted at runtime. + The *stacklevel* determines where the + warning is emitted. If it is ``1`` (the default), the warning + is emitted at the direct caller of the deprecated object; if it + is higher, it is emitted further up the stack. + Static type checker behavior is not affected by the *category* + and *stacklevel* arguments. 
+ + The deprecation message passed to the decorator is saved in the + ``__deprecated__`` attribute on the decorated object. + If applied to an overload, the decorator + must be after the ``@overload`` decorator for the attribute to + exist on the overload as returned by ``get_overloads()``. + + See PEP 702 for details. + + """ + def __init__( + self, + message: str, + /, + *, + category: typing.Optional[typing.Type[Warning]] = DeprecationWarning, + stacklevel: int = 1, + ) -> None: + if not isinstance(message, str): + raise TypeError( + "Expected an object of type str for 'message', not " + f"{type(message).__name__!r}" + ) + self.message = message + self.category = category + self.stacklevel = stacklevel + + def __call__(self, arg: _T, /) -> _T: + # Make sure the inner functions created below don't + # retain a reference to self. + msg = self.message + category = self.category + stacklevel = self.stacklevel + if category is None: + arg.__deprecated__ = msg + return arg + elif isinstance(arg, type): + import functools + from types import MethodType + + original_new = arg.__new__ + + @functools.wraps(original_new) + def __new__(cls, *args, **kwargs): + if cls is arg: + warnings.warn(msg, category=category, stacklevel=stacklevel + 1) + if original_new is not object.__new__: + return original_new(cls, *args, **kwargs) + # Mirrors a similar check in object.__new__. + elif cls.__init__ is object.__init__ and (args or kwargs): + raise TypeError(f"{cls.__name__}() takes no arguments") + else: + return original_new(cls) + + arg.__new__ = staticmethod(__new__) + + original_init_subclass = arg.__init_subclass__ + # We need slightly different behavior if __init_subclass__ + # is a bound method (likely if it was implemented in Python) + if isinstance(original_init_subclass, MethodType): + original_init_subclass = original_init_subclass.__func__ + + @functools.wraps(original_init_subclass) + def __init_subclass__(*args, **kwargs): + warnings.warn(msg, category=category, stacklevel=stacklevel + 1) + return original_init_subclass(*args, **kwargs) + + arg.__init_subclass__ = classmethod(__init_subclass__) + # Or otherwise, which likely means it's a builtin such as + # object's implementation of __init_subclass__. + else: + @functools.wraps(original_init_subclass) + def __init_subclass__(*args, **kwargs): + warnings.warn(msg, category=category, stacklevel=stacklevel + 1) + return original_init_subclass(*args, **kwargs) + + arg.__init_subclass__ = __init_subclass__ + + arg.__deprecated__ = __new__.__deprecated__ = msg + __init_subclass__.__deprecated__ = msg + return arg + elif callable(arg): + import functools + + @functools.wraps(arg) + def wrapper(*args, **kwargs): + warnings.warn(msg, category=category, stacklevel=stacklevel + 1) + return arg(*args, **kwargs) + + arg.__deprecated__ = wrapper.__deprecated__ = msg + return wrapper + else: + raise TypeError( + "@deprecated decorator with non-None category must be applied to " + f"a class or callable, not {arg!r}" + ) # We have to do some monkey patching to deal with the dual nature of @@ -2120,18 +2693,14 @@ if not hasattr(typing, "TypeVarTuple"): typing._check_generic = _check_generic -# Backport typing.NamedTuple as it exists in Python 3.11. +# Backport typing.NamedTuple as it exists in Python 3.13. # In 3.11, the ability to define generic `NamedTuple`s was supported. # This was explicitly disallowed in 3.9-3.10, and only half-worked in <=3.8. 
-if sys.version_info >= (3, 11): +# On 3.12, we added __orig_bases__ to call-based NamedTuples +# On 3.13, we deprecated kwargs-based NamedTuples +if sys.version_info >= (3, 13): NamedTuple = typing.NamedTuple else: - def _caller(): - try: - return sys._getframe(2).f_globals.get('__name__', '__main__') - except (AttributeError, ValueError): # For platforms without _getframe() - return None - def _make_nmtuple(name, types, module, defaults=()): fields = [n for n, t in types] annotations = {n: typing._type_check(t, f"field {n} annotation must be a type") @@ -2173,37 +2742,486 @@ else: ) nm_tpl.__bases__ = bases if typing.Generic in bases: - class_getitem = typing.Generic.__class_getitem__.__func__ - nm_tpl.__class_getitem__ = classmethod(class_getitem) + if hasattr(typing, '_generic_class_getitem'): # 3.12+ + nm_tpl.__class_getitem__ = classmethod(typing._generic_class_getitem) + else: + class_getitem = typing.Generic.__class_getitem__.__func__ + nm_tpl.__class_getitem__ = classmethod(class_getitem) # update from user namespace without overriding special namedtuple attributes - for key in ns: + for key, val in ns.items(): if key in _prohibited_namedtuple_fields: raise AttributeError("Cannot overwrite NamedTuple attribute " + key) - elif key not in _special_namedtuple_fields and key not in nm_tpl._fields: - setattr(nm_tpl, key, ns[key]) + elif key not in _special_namedtuple_fields: + if key not in nm_tpl._fields: + setattr(nm_tpl, key, ns[key]) + try: + set_name = type(val).__set_name__ + except AttributeError: + pass + else: + try: + set_name(val, nm_tpl, key) + except BaseException as e: + msg = ( + f"Error calling __set_name__ on {type(val).__name__!r} " + f"instance {key!r} in {typename!r}" + ) + # BaseException.add_note() existed on py311, + # but the __set_name__ machinery didn't start + # using add_note() until py312. + # Making sure exceptions are raised in the same way + # as in "normal" classes seems most important here. + if sys.version_info >= (3, 12): + e.add_note(msg) + raise + else: + raise RuntimeError(msg) from e + if typing.Generic in bases: nm_tpl.__init_subclass__() return nm_tpl - def NamedTuple(__typename, __fields=None, **kwargs): - if __fields is None: - __fields = kwargs.items() - elif kwargs: - raise TypeError("Either list of fields or keywords" - " can be provided to NamedTuple, not both") - return _make_nmtuple(__typename, __fields, module=_caller()) - - NamedTuple.__doc__ = typing.NamedTuple.__doc__ _NamedTuple = type.__new__(_NamedTupleMeta, 'NamedTuple', (), {}) - # On 3.8+, alter the signature so that it matches typing.NamedTuple. - # The signature of typing.NamedTuple on >=3.8 is invalid syntax in Python 3.7, - # so just leave the signature as it is on 3.7. - if sys.version_info >= (3, 8): - NamedTuple.__text_signature__ = '(typename, fields=None, /, **kwargs)' - def _namedtuple_mro_entries(bases): assert NamedTuple in bases return (_NamedTuple,) - NamedTuple.__mro_entries__ = _namedtuple_mro_entries + @_ensure_subclassable(_namedtuple_mro_entries) + def NamedTuple(typename, fields=_marker, /, **kwargs): + """Typed version of namedtuple. + + Usage:: + + class Employee(NamedTuple): + name: str + id: int + + This is equivalent to:: + + Employee = collections.namedtuple('Employee', ['name', 'id']) + + The resulting class has an extra __annotations__ attribute, giving a + dict that maps field names to types. (The field names are also in + the _fields attribute, which is part of the namedtuple API.) 
+ An alternative equivalent functional syntax is also accepted:: + + Employee = NamedTuple('Employee', [('name', str), ('id', int)]) + """ + if fields is _marker: + if kwargs: + deprecated_thing = "Creating NamedTuple classes using keyword arguments" + deprecation_msg = ( + "{name} is deprecated and will be disallowed in Python {remove}. " + "Use the class-based or functional syntax instead." + ) + else: + deprecated_thing = "Failing to pass a value for the 'fields' parameter" + example = f"`{typename} = NamedTuple({typename!r}, [])`" + deprecation_msg = ( + "{name} is deprecated and will be disallowed in Python {remove}. " + "To create a NamedTuple class with 0 fields " + "using the functional syntax, " + "pass an empty list, e.g. " + ) + example + "." + elif fields is None: + if kwargs: + raise TypeError( + "Cannot pass `None` as the 'fields' parameter " + "and also specify fields using keyword arguments" + ) + else: + deprecated_thing = "Passing `None` as the 'fields' parameter" + example = f"`{typename} = NamedTuple({typename!r}, [])`" + deprecation_msg = ( + "{name} is deprecated and will be disallowed in Python {remove}. " + "To create a NamedTuple class with 0 fields " + "using the functional syntax, " + "pass an empty list, e.g. " + ) + example + "." + elif kwargs: + raise TypeError("Either list of fields or keywords" + " can be provided to NamedTuple, not both") + if fields is _marker or fields is None: + warnings.warn( + deprecation_msg.format(name=deprecated_thing, remove="3.15"), + DeprecationWarning, + stacklevel=2, + ) + fields = kwargs.items() + nt = _make_nmtuple(typename, fields, module=_caller()) + nt.__orig_bases__ = (NamedTuple,) + return nt + + +if hasattr(collections.abc, "Buffer"): + Buffer = collections.abc.Buffer +else: + class Buffer(abc.ABC): + """Base class for classes that implement the buffer protocol. + + The buffer protocol allows Python objects to expose a low-level + memory buffer interface. Before Python 3.12, it is not possible + to implement the buffer protocol in pure Python code, or even + to check whether a class implements the buffer protocol. In + Python 3.12 and higher, the ``__buffer__`` method allows access + to the buffer protocol from Python code, and the + ``collections.abc.Buffer`` ABC allows checking whether a class + implements the buffer protocol. + + To indicate support for the buffer protocol in earlier versions, + inherit from this ABC, either in a stub file or at runtime, + or use ABC registration. This ABC provides no methods, because + there is no Python-accessible methods shared by pre-3.12 buffer + classes. It is useful primarily for static checks. + + """ + + # As a courtesy, register the most common stdlib buffer classes. + Buffer.register(memoryview) + Buffer.register(bytearray) + Buffer.register(bytes) + + +# Backport of types.get_original_bases, available on 3.12+ in CPython +if hasattr(_types, "get_original_bases"): + get_original_bases = _types.get_original_bases +else: + def get_original_bases(cls, /): + """Return the class's "original" bases prior to modification by `__mro_entries__`. + + Examples:: + + from typing import TypeVar, Generic + from typing_extensions import NamedTuple, TypedDict + + T = TypeVar("T") + class Foo(Generic[T]): ... + class Bar(Foo[int], float): ... + class Baz(list[str]): ... 
+ Eggs = NamedTuple("Eggs", [("a", int), ("b", str)]) + Spam = TypedDict("Spam", {"a": int, "b": str}) + + assert get_original_bases(Bar) == (Foo[int], float) + assert get_original_bases(Baz) == (list[str],) + assert get_original_bases(Eggs) == (NamedTuple,) + assert get_original_bases(Spam) == (TypedDict,) + assert get_original_bases(int) == (object,) + """ + try: + return cls.__dict__.get("__orig_bases__", cls.__bases__) + except AttributeError: + raise TypeError( + f'Expected an instance of type, not {type(cls).__name__!r}' + ) from None + + +# NewType is a class on Python 3.10+, making it pickleable +# The error message for subclassing instances of NewType was improved on 3.11+ +if sys.version_info >= (3, 11): + NewType = typing.NewType +else: + class NewType: + """NewType creates simple unique types with almost zero + runtime overhead. NewType(name, tp) is considered a subtype of tp + by static type checkers. At runtime, NewType(name, tp) returns + a dummy callable that simply returns its argument. Usage:: + UserId = NewType('UserId', int) + def name_by_id(user_id: UserId) -> str: + ... + UserId('user') # Fails type check + name_by_id(42) # Fails type check + name_by_id(UserId(42)) # OK + num = UserId(5) + 1 # type: int + """ + + def __call__(self, obj, /): + return obj + + def __init__(self, name, tp): + self.__qualname__ = name + if '.' in name: + name = name.rpartition('.')[-1] + self.__name__ = name + self.__supertype__ = tp + def_mod = _caller() + if def_mod != 'typing_extensions': + self.__module__ = def_mod + + def __mro_entries__(self, bases): + # We defined __mro_entries__ to get a better error message + # if a user attempts to subclass a NewType instance. bpo-46170 + supercls_name = self.__name__ + + class Dummy: + def __init_subclass__(cls): + subcls_name = cls.__name__ + raise TypeError( + f"Cannot subclass an instance of NewType. " + f"Perhaps you were looking for: " + f"`{subcls_name} = NewType({subcls_name!r}, {supercls_name})`" + ) + + return (Dummy,) + + def __repr__(self): + return f'{self.__module__}.{self.__qualname__}' + + def __reduce__(self): + return self.__qualname__ + + if sys.version_info >= (3, 10): + # PEP 604 methods + # It doesn't make sense to have these methods on Python <3.10 + + def __or__(self, other): + return typing.Union[self, other] + + def __ror__(self, other): + return typing.Union[other, self] + + +if hasattr(typing, "TypeAliasType"): + TypeAliasType = typing.TypeAliasType +else: + def _is_unionable(obj): + """Corresponds to is_unionable() in unionobject.c in CPython.""" + return obj is None or isinstance(obj, ( + type, + _types.GenericAlias, + _types.UnionType, + TypeAliasType, + )) + + class TypeAliasType: + """Create named, parameterized type aliases. + + This provides a backport of the new `type` statement in Python 3.12: + + type ListOrSet[T] = list[T] | set[T] + + is equivalent to: + + T = TypeVar("T") + ListOrSet = TypeAliasType("ListOrSet", list[T] | set[T], type_params=(T,)) + + The name ListOrSet can then be used as an alias for the type it refers to. + + The type_params argument should contain all the type parameters used + in the value of the type alias. If the alias is not generic, this + argument is omitted. + + Static type checkers should only support type aliases declared using + TypeAliasType that follow these rules: + + - The first argument (the name) must be a string literal. + - The TypeAliasType instance must be immediately assigned to a variable + of the same name. 
(For example, 'X = TypeAliasType("Y", int)' is invalid, + as is 'X, Y = TypeAliasType("X", int), TypeAliasType("Y", int)'). + + """ + + def __init__(self, name: str, value, *, type_params=()): + if not isinstance(name, str): + raise TypeError("TypeAliasType name must be a string") + self.__value__ = value + self.__type_params__ = type_params + + parameters = [] + for type_param in type_params: + if isinstance(type_param, TypeVarTuple): + parameters.extend(type_param) + else: + parameters.append(type_param) + self.__parameters__ = tuple(parameters) + def_mod = _caller() + if def_mod != 'typing_extensions': + self.__module__ = def_mod + # Setting this attribute closes the TypeAliasType from further modification + self.__name__ = name + + def __setattr__(self, name: str, value: object, /) -> None: + if hasattr(self, "__name__"): + self._raise_attribute_error(name) + super().__setattr__(name, value) + + def __delattr__(self, name: str, /) -> Never: + self._raise_attribute_error(name) + + def _raise_attribute_error(self, name: str) -> Never: + # Match the Python 3.12 error messages exactly + if name == "__name__": + raise AttributeError("readonly attribute") + elif name in {"__value__", "__type_params__", "__parameters__", "__module__"}: + raise AttributeError( + f"attribute '{name}' of 'typing.TypeAliasType' objects " + "is not writable" + ) + else: + raise AttributeError( + f"'typing.TypeAliasType' object has no attribute '{name}'" + ) + + def __repr__(self) -> str: + return self.__name__ + + def __getitem__(self, parameters): + if not isinstance(parameters, tuple): + parameters = (parameters,) + parameters = [ + typing._type_check( + item, f'Subscripting {self.__name__} requires a type.' + ) + for item in parameters + ] + return typing._GenericAlias(self, tuple(parameters)) + + def __reduce__(self): + return self.__name__ + + def __init_subclass__(cls, *args, **kwargs): + raise TypeError( + "type 'typing_extensions.TypeAliasType' is not an acceptable base type" + ) + + # The presence of this method convinces typing._type_check + # that TypeAliasTypes are types. + def __call__(self): + raise TypeError("Type alias is not callable") + + if sys.version_info >= (3, 10): + def __or__(self, right): + # For forward compatibility with 3.12, reject Unions + # that are not accepted by the built-in Union. + if not _is_unionable(right): + return NotImplemented + return typing.Union[self, right] + + def __ror__(self, left): + if not _is_unionable(left): + return NotImplemented + return typing.Union[left, self] + + +if hasattr(typing, "is_protocol"): + is_protocol = typing.is_protocol + get_protocol_members = typing.get_protocol_members +else: + def is_protocol(tp: type, /) -> bool: + """Return True if the given type is a Protocol. + + Example:: + + >>> from typing_extensions import Protocol, is_protocol + >>> class P(Protocol): + ... def a(self) -> str: ... + ... b: int + >>> is_protocol(P) + True + >>> is_protocol(int) + False + """ + return ( + isinstance(tp, type) + and getattr(tp, '_is_protocol', False) + and tp is not Protocol + and tp is not typing.Protocol + ) + + def get_protocol_members(tp: type, /) -> typing.FrozenSet[str]: + """Return the set of members defined in a Protocol. + + Example:: + + >>> from typing_extensions import Protocol, get_protocol_members + >>> class P(Protocol): + ... def a(self) -> str: ... + ... b: int + >>> get_protocol_members(P) + frozenset({'a', 'b'}) + + Raise a TypeError for arguments that are not Protocols. 
+ """ + if not is_protocol(tp): + raise TypeError(f'{tp!r} is not a Protocol') + if hasattr(tp, '__protocol_attrs__'): + return frozenset(tp.__protocol_attrs__) + return frozenset(_get_protocol_attrs(tp)) + + +if hasattr(typing, "Doc"): + Doc = typing.Doc +else: + class Doc: + """Define the documentation of a type annotation using ``Annotated``, to be + used in class attributes, function and method parameters, return values, + and variables. + + The value should be a positional-only string literal to allow static tools + like editors and documentation generators to use it. + + This complements docstrings. + + The string value passed is available in the attribute ``documentation``. + + Example:: + + >>> from typing_extensions import Annotated, Doc + >>> def hi(to: Annotated[str, Doc("Who to say hi to")]) -> None: ... + """ + def __init__(self, documentation: str, /) -> None: + self.documentation = documentation + + def __repr__(self) -> str: + return f"Doc({self.documentation!r})" + + def __hash__(self) -> int: + return hash(self.documentation) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Doc): + return NotImplemented + return self.documentation == other.documentation + + +# Aliases for items that have always been in typing. +# Explicitly assign these (rather than using `from typing import *` at the top), +# so that we get a CI error if one of these is deleted from typing.py +# in a future version of Python +AbstractSet = typing.AbstractSet +AnyStr = typing.AnyStr +BinaryIO = typing.BinaryIO +Callable = typing.Callable +Collection = typing.Collection +Container = typing.Container +Dict = typing.Dict +ForwardRef = typing.ForwardRef +FrozenSet = typing.FrozenSet +Generator = typing.Generator +Generic = typing.Generic +Hashable = typing.Hashable +IO = typing.IO +ItemsView = typing.ItemsView +Iterable = typing.Iterable +Iterator = typing.Iterator +KeysView = typing.KeysView +List = typing.List +Mapping = typing.Mapping +MappingView = typing.MappingView +Match = typing.Match +MutableMapping = typing.MutableMapping +MutableSequence = typing.MutableSequence +MutableSet = typing.MutableSet +Optional = typing.Optional +Pattern = typing.Pattern +Reversible = typing.Reversible +Sequence = typing.Sequence +Set = typing.Set +Sized = typing.Sized +TextIO = typing.TextIO +Tuple = typing.Tuple +Union = typing.Union +ValuesView = typing.ValuesView +cast = typing.cast +no_type_check = typing.no_type_check +no_type_check_decorator = typing.no_type_check_decorator diff --git a/lib/zc/lockfile/__init__.py b/lib/zc/lockfile/__init__.py index b541fa2d..f93917b0 100644 --- a/lib/zc/lockfile/__init__.py +++ b/lib/zc/lockfile/__init__.py @@ -11,18 +11,18 @@ # FOR A PARTICULAR PURPOSE # ############################################################################## - -import os -import errno import logging +import os + + logger = logging.getLogger("zc.lockfile") -__metaclass__ = type class LockError(Exception): """Couldn't get a lock """ + try: import fcntl except ImportError: @@ -31,6 +31,7 @@ except ImportError: except ImportError: def _lock_file(file): raise TypeError('No file-locking support on this platform') + def _unlock_file(file): raise TypeError('No file-locking support on this platform') @@ -40,14 +41,14 @@ except ImportError: # Lock just the first byte try: msvcrt.locking(file.fileno(), msvcrt.LK_NBLCK, 1) - except IOError: + except OSError: raise LockError("Couldn't lock %r" % file.name) def _unlock_file(file): try: file.seek(0) msvcrt.locking(file.fileno(), msvcrt.LK_UNLCK, 1) 
- except IOError: + except OSError: raise LockError("Couldn't unlock %r" % file.name) else: @@ -57,14 +58,16 @@ else: def _lock_file(file): try: fcntl.flock(file.fileno(), _flags) - except IOError: + except OSError: raise LockError("Couldn't lock %r" % file.name) def _unlock_file(file): fcntl.flock(file.fileno(), fcntl.LOCK_UN) + class LazyHostName: """Avoid importing socket and calling gethostname() unnecessarily""" + def __str__(self): import socket return socket.gethostname() @@ -79,7 +82,7 @@ class SimpleLockFile: try: # Try to open for writing without truncation: fp = open(path, 'r+') - except IOError: + except OSError: # If the file doesn't exist, we'll get an IO error, try a+ # Note that there may be a race here. Multiple processes # could fail on the r+ open and open the file a+, but only @@ -89,7 +92,7 @@ class SimpleLockFile: try: _lock_file(fp) self._fp = fp - except: + except BaseException: fp.close() raise @@ -114,7 +117,7 @@ class LockFile(SimpleLockFile): def __init__(self, path, content_template='{pid}'): self._content_template = content_template - super(LockFile, self).__init__(path) + super().__init__(path) def _on_lock(self): content = self._content_template.format( diff --git a/lib/zc/lockfile/tests.py b/lib/zc/lockfile/tests.py index 4c890539..ae9ffca5 100644 --- a/lib/zc/lockfile/tests.py +++ b/lib/zc/lockfile/tests.py @@ -11,23 +11,22 @@ # FOR A PARTICULAR PURPOSE. # ############################################################################## -import os, re, sys, unittest, doctest -import zc.lockfile, time, threading -from zope.testing import renormalizing, setupstack +import doctest +import os import tempfile -try: - from unittest.mock import Mock, patch -except ImportError: - from mock import Mock, patch +import threading +import time +import unittest +from unittest.mock import Mock +from unittest.mock import patch + +from zope.testing import setupstack + +import zc.lockfile -checker = renormalizing.RENormalizing([ - # Python 3 adds module path to error class name. - (re.compile("zc\.lockfile\.LockError:"), - r"LockError:"), - ]) def inc(): - while 1: + while True: try: lock = zc.lockfile.LockFile('f.lock') except zc.lockfile.LockError: @@ -43,6 +42,7 @@ def inc(): f.close() lock.close() + def many_threads_read_and_write(): r""" >>> with open('f', 'w+b') as file: @@ -72,6 +72,7 @@ def many_threads_read_and_write(): """ + def pid_in_lockfile(): r""" >>> import os, zc.lockfile @@ -88,7 +89,7 @@ def pid_in_lockfile(): >>> lock = zc.lockfile.LockFile("f.lock") Traceback (most recent call last): ... - LockError: Couldn't lock 'f.lock' + zc.lockfile.LockError: Couldn't lock 'f.lock' >>> f = open("f.lock") >>> _ = f.seek(1) @@ -107,7 +108,8 @@ def hostname_in_lockfile(): >>> import zc.lockfile >>> with patch('socket.gethostname', Mock(return_value='myhostname')): - ... lock = zc.lockfile.LockFile("f.lock", content_template='{hostname}') + ... lock = zc.lockfile.LockFile( + ... "f.lock", content_template='{hostname}') >>> f = open("f.lock") >>> _ = f.seek(1) >>> f.read().rstrip() @@ -119,7 +121,7 @@ def hostname_in_lockfile(): >>> lock = zc.lockfile.LockFile("f.lock", content_template='{hostname}') Traceback (most recent call last): ... 
- LockError: Couldn't lock 'f.lock' + zc.lockfile.LockError: Couldn't lock 'f.lock' >>> f = open("f.lock") >>> _ = f.seek(1) @@ -131,7 +133,7 @@ def hostname_in_lockfile(): """ -class TestLogger(object): +class TestLogger: def __init__(self): self.log_entries = [] @@ -141,6 +143,7 @@ class TestLogger(object): class LockFileLogEntryTestCase(unittest.TestCase): """Tests for logging in case of lock failure""" + def setUp(self): self.here = os.getcwd() self.tmp = tempfile.mkdtemp(prefix='zc.lockfile-test-') @@ -154,8 +157,8 @@ class LockFileLogEntryTestCase(unittest.TestCase): # PID and hostname are parsed and logged from lock file on failure with patch('os.getpid', Mock(return_value=123)): with patch('socket.gethostname', Mock(return_value='myhostname')): - lock = zc.lockfile.LockFile('f.lock', - content_template='{pid}/{hostname}') + lock = zc.lockfile.LockFile( + 'f.lock', content_template='{pid}/{hostname}') with open('f.lock') as f: self.assertEqual(' 123/myhostname\n', f.read()) @@ -191,11 +194,10 @@ class LockFileLogEntryTestCase(unittest.TestCase): def test_suite(): suite = unittest.TestSuite() suite.addTest(doctest.DocFileSuite( - 'README.txt', checker=checker, + 'README.txt', setUp=setupstack.setUpDirectory, tearDown=setupstack.tearDown)) suite.addTest(doctest.DocTestSuite( - setUp=setupstack.setUpDirectory, tearDown=setupstack.tearDown, - checker=checker)) + setUp=setupstack.setUpDirectory, tearDown=setupstack.tearDown)) # Add unittest test cases from this module suite.addTest(unittest.defaultTestLoader.loadTestsFromName(__name__)) return suite