Revert "Bump cherrypy from 18.8.0 to 18.9.0 (#2266)"

This reverts commit faef9a94c4.
JonnyWong16 2024-03-24 17:40:56 -07:00
parent fcd8ef11f4
commit 2fc618c01f
GPG key ID: B1F1F9807184697A
673 changed files with 11579 additions and 159846 deletions

Binary file not shown.


@@ -1,74 +0,0 @@
"""adodbapi - A python DB API 2.0 (PEP 249) interface to Microsoft ADO
Copyright (C) 2002 Henrik Ekelund, version 2.1 by Vernon Cole
* http://sourceforge.net/projects/adodbapi
"""
import sys
import time
from .adodbapi import Connection, Cursor, __version__, connect, dateconverter
from .apibase import (
BINARY,
DATETIME,
NUMBER,
ROWID,
STRING,
DatabaseError,
DataError,
Error,
FetchFailedError,
IntegrityError,
InterfaceError,
InternalError,
NotSupportedError,
OperationalError,
ProgrammingError,
Warning,
apilevel,
paramstyle,
threadsafety,
)
def Binary(aString):
"""This function constructs an object capable of holding a binary (long) string value."""
return bytes(aString)
def Date(year, month, day):
"This function constructs an object holding a date value."
return dateconverter.Date(year, month, day)
def Time(hour, minute, second):
"This function constructs an object holding a time value."
return dateconverter.Time(hour, minute, second)
def Timestamp(year, month, day, hour, minute, second):
"This function constructs an object holding a time stamp value."
return dateconverter.Timestamp(year, month, day, hour, minute, second)
def DateFromTicks(ticks):
"""This function constructs an object holding a date value from the given ticks value
(number of seconds since the epoch; see the documentation of the standard Python time module for details).
"""
return Date(*time.gmtime(ticks)[:3])
def TimeFromTicks(ticks):
"""This function constructs an object holding a time value from the given ticks value
(number of seconds since the epoch; see the documentation of the standard Python time module for details).
"""
return Time(*time.gmtime(ticks)[3:6])
def TimestampFromTicks(ticks):
"""This function constructs an object holding a time stamp value from the given
ticks value (number of seconds since the epoch;
see the documentation of the standard Python time module for details)."""
return Timestamp(*time.gmtime(ticks)[:6])
version = "adodbapi v" + __version__


@@ -1,281 +0,0 @@
# ADO enumerated constants documented on MSDN:
# http://msdn.microsoft.com/en-us/library/ms678353(VS.85).aspx
# IsolationLevelEnum
adXactUnspecified = -1
adXactBrowse = 0x100
adXactChaos = 0x10
adXactCursorStability = 0x1000
adXactIsolated = 0x100000
adXactReadCommitted = 0x1000
adXactReadUncommitted = 0x100
adXactRepeatableRead = 0x10000
adXactSerializable = 0x100000
# CursorLocationEnum
adUseClient = 3
adUseServer = 2
# CursorTypeEnum
adOpenDynamic = 2
adOpenForwardOnly = 0
adOpenKeyset = 1
adOpenStatic = 3
adOpenUnspecified = -1
# CommandTypeEnum
adCmdText = 1
adCmdStoredProc = 4
adSchemaTables = 20
# ParameterDirectionEnum
adParamInput = 1
adParamInputOutput = 3
adParamOutput = 2
adParamReturnValue = 4
adParamUnknown = 0
directions = {
0: "Unknown",
1: "Input",
2: "Output",
3: "InputOutput",
4: "Return",
}
def ado_direction_name(ado_dir):
try:
return "adParam" + directions[ado_dir]
except:
return "unknown direction (" + str(ado_dir) + ")"
# ObjectStateEnum
adStateClosed = 0
adStateOpen = 1
adStateConnecting = 2
adStateExecuting = 4
adStateFetching = 8
# FieldAttributeEnum
adFldMayBeNull = 0x40
# ConnectModeEnum
adModeUnknown = 0
adModeRead = 1
adModeWrite = 2
adModeReadWrite = 3
adModeShareDenyRead = 4
adModeShareDenyWrite = 8
adModeShareExclusive = 12
adModeShareDenyNone = 16
adModeRecursive = 0x400000
# XactAttributeEnum
adXactCommitRetaining = 131072
adXactAbortRetaining = 262144
ado_error_TIMEOUT = -2147217871
# DataTypeEnum - ADO Data types documented at:
# http://msdn2.microsoft.com/en-us/library/ms675318.aspx
adArray = 0x2000
adEmpty = 0x0
adBSTR = 0x8
adBigInt = 0x14
adBinary = 0x80
adBoolean = 0xB
adChapter = 0x88
adChar = 0x81
adCurrency = 0x6
adDBDate = 0x85
adDBTime = 0x86
adDBTimeStamp = 0x87
adDate = 0x7
adDecimal = 0xE
adDouble = 0x5
adError = 0xA
adFileTime = 0x40
adGUID = 0x48
adIDispatch = 0x9
adIUnknown = 0xD
adInteger = 0x3
adLongVarBinary = 0xCD
adLongVarChar = 0xC9
adLongVarWChar = 0xCB
adNumeric = 0x83
adPropVariant = 0x8A
adSingle = 0x4
adSmallInt = 0x2
adTinyInt = 0x10
adUnsignedBigInt = 0x15
adUnsignedInt = 0x13
adUnsignedSmallInt = 0x12
adUnsignedTinyInt = 0x11
adUserDefined = 0x84
adVarBinary = 0xCC
adVarChar = 0xC8
adVarNumeric = 0x8B
adVarWChar = 0xCA
adVariant = 0xC
adWChar = 0x82
# Additional constants used by introspection but not ADO itself
AUTO_FIELD_MARKER = -1000
adTypeNames = {
adBSTR: "adBSTR",
adBigInt: "adBigInt",
adBinary: "adBinary",
adBoolean: "adBoolean",
adChapter: "adChapter",
adChar: "adChar",
adCurrency: "adCurrency",
adDBDate: "adDBDate",
adDBTime: "adDBTime",
adDBTimeStamp: "adDBTimeStamp",
adDate: "adDate",
adDecimal: "adDecimal",
adDouble: "adDouble",
adEmpty: "adEmpty",
adError: "adError",
adFileTime: "adFileTime",
adGUID: "adGUID",
adIDispatch: "adIDispatch",
adIUnknown: "adIUnknown",
adInteger: "adInteger",
adLongVarBinary: "adLongVarBinary",
adLongVarChar: "adLongVarChar",
adLongVarWChar: "adLongVarWChar",
adNumeric: "adNumeric",
adPropVariant: "adPropVariant",
adSingle: "adSingle",
adSmallInt: "adSmallInt",
adTinyInt: "adTinyInt",
adUnsignedBigInt: "adUnsignedBigInt",
adUnsignedInt: "adUnsignedInt",
adUnsignedSmallInt: "adUnsignedSmallInt",
adUnsignedTinyInt: "adUnsignedTinyInt",
adUserDefined: "adUserDefined",
adVarBinary: "adVarBinary",
adVarChar: "adVarChar",
adVarNumeric: "adVarNumeric",
adVarWChar: "adVarWChar",
adVariant: "adVariant",
adWChar: "adWChar",
}
def ado_type_name(ado_type):
return adTypeNames.get(ado_type, "unknown type (" + str(ado_type) + ")")
# here in decimal, sorted by value
# adEmpty 0 Specifies no value (DBTYPE_EMPTY).
# adSmallInt 2 Indicates a two-byte signed integer (DBTYPE_I2).
# adInteger 3 Indicates a four-byte signed integer (DBTYPE_I4).
# adSingle 4 Indicates a single-precision floating-point value (DBTYPE_R4).
# adDouble 5 Indicates a double-precision floating-point value (DBTYPE_R8).
# adCurrency 6 Indicates a currency value (DBTYPE_CY). Currency is a fixed-point number
# with four digits to the right of the decimal point. It is stored in an eight-byte signed integer scaled by 10,000.
# adDate 7 Indicates a date value (DBTYPE_DATE). A date is stored as a double, the whole part of which is
# the number of days since December 30, 1899, and the fractional part of which is the fraction of a day.
# adBSTR 8 Indicates a null-terminated character string (Unicode) (DBTYPE_BSTR).
# adIDispatch 9 Indicates a pointer to an IDispatch interface on a COM object (DBTYPE_IDISPATCH).
# adError 10 Indicates a 32-bit error code (DBTYPE_ERROR).
# adBoolean 11 Indicates a boolean value (DBTYPE_BOOL).
# adVariant 12 Indicates an Automation Variant (DBTYPE_VARIANT).
# adIUnknown 13 Indicates a pointer to an IUnknown interface on a COM object (DBTYPE_IUNKNOWN).
# adDecimal 14 Indicates an exact numeric value with a fixed precision and scale (DBTYPE_DECIMAL).
# adTinyInt 16 Indicates a one-byte signed integer (DBTYPE_I1).
# adUnsignedTinyInt 17 Indicates a one-byte unsigned integer (DBTYPE_UI1).
# adUnsignedSmallInt 18 Indicates a two-byte unsigned integer (DBTYPE_UI2).
# adUnsignedInt 19 Indicates a four-byte unsigned integer (DBTYPE_UI4).
# adBigInt 20 Indicates an eight-byte signed integer (DBTYPE_I8).
# adUnsignedBigInt 21 Indicates an eight-byte unsigned integer (DBTYPE_UI8).
# adFileTime 64 Indicates a 64-bit value representing the number of 100-nanosecond intervals since
# January 1, 1601 (DBTYPE_FILETIME).
# adGUID 72 Indicates a globally unique identifier (GUID) (DBTYPE_GUID).
# adBinary 128 Indicates a binary value (DBTYPE_BYTES).
# adChar 129 Indicates a string value (DBTYPE_STR).
# adWChar 130 Indicates a null-terminated Unicode character string (DBTYPE_WSTR).
# adNumeric 131 Indicates an exact numeric value with a fixed precision and scale (DBTYPE_NUMERIC).
# adUserDefined 132 Indicates a user-defined variable (DBTYPE_UDT).
# adUserDefined 132 Indicates a user-defined variable (DBTYPE_UDT).
# adDBDate 133 Indicates a date value (yyyymmdd) (DBTYPE_DBDATE).
# adDBTime 134 Indicates a time value (hhmmss) (DBTYPE_DBTIME).
# adDBTimeStamp 135 Indicates a date/time stamp (yyyymmddhhmmss plus a fraction in billionths) (DBTYPE_DBTIMESTAMP).
# adChapter 136 Indicates a four-byte chapter value that identifies rows in a child rowset (DBTYPE_HCHAPTER).
# adPropVariant 138 Indicates an Automation PROPVARIANT (DBTYPE_PROP_VARIANT).
# adVarNumeric 139 Indicates a numeric value (Parameter object only).
# adVarChar 200 Indicates a string value (Parameter object only).
# adLongVarChar 201 Indicates a long string value (Parameter object only).
# adVarWChar 202 Indicates a null-terminated Unicode character string (Parameter object only).
# adLongVarWChar 203 Indicates a long null-terminated Unicode string value (Parameter object only).
# adVarBinary 204 Indicates a binary value (Parameter object only).
# adLongVarBinary 205 Indicates a long binary value (Parameter object only).
# adArray (Does not apply to ADOX.) 0x2000 A flag value, always combined with another data type constant,
# that indicates an array of that other data type.
# Error codes to names
adoErrors = {
0xE7B: "adErrBoundToCommand",
0xE94: "adErrCannotComplete",
0xEA4: "adErrCantChangeConnection",
0xC94: "adErrCantChangeProvider",
0xE8C: "adErrCantConvertvalue",
0xE8D: "adErrCantCreate",
0xEA3: "adErrCatalogNotSet",
0xE8E: "adErrColumnNotOnThisRow",
0xD5D: "adErrDataConversion",
0xE89: "adErrDataOverflow",
0xE9A: "adErrDelResOutOfScope",
0xEA6: "adErrDenyNotSupported",
0xEA7: "adErrDenyTypeNotSupported",
0xCB3: "adErrFeatureNotAvailable",
0xEA5: "adErrFieldsUpdateFailed",
0xC93: "adErrIllegalOperation",
0xCAE: "adErrInTransaction",
0xE87: "adErrIntegrityViolation",
0xBB9: "adErrInvalidArgument",
0xE7D: "adErrInvalidConnection",
0xE7C: "adErrInvalidParamInfo",
0xE82: "adErrInvalidTransaction",
0xE91: "adErrInvalidURL",
0xCC1: "adErrItemNotFound",
0xBCD: "adErrNoCurrentRecord",
0xE83: "adErrNotExecuting",
0xE7E: "adErrNotReentrant",
0xE78: "adErrObjectClosed",
0xD27: "adErrObjectInCollection",
0xD5C: "adErrObjectNotSet",
0xE79: "adErrObjectOpen",
0xBBA: "adErrOpeningFile",
0xE80: "adErrOperationCancelled",
0xE96: "adErrOutOfSpace",
0xE88: "adErrPermissionDenied",
0xE9E: "adErrPropConflicting",
0xE9B: "adErrPropInvalidColumn",
0xE9C: "adErrPropInvalidOption",
0xE9D: "adErrPropInvalidValue",
0xE9F: "adErrPropNotAllSettable",
0xEA0: "adErrPropNotSet",
0xEA1: "adErrPropNotSettable",
0xEA2: "adErrPropNotSupported",
0xBB8: "adErrProviderFailed",
0xE7A: "adErrProviderNotFound",
0xBBB: "adErrReadFile",
0xE93: "adErrResourceExists",
0xE92: "adErrResourceLocked",
0xE97: "adErrResourceOutOfScope",
0xE8A: "adErrSchemaViolation",
0xE8B: "adErrSignMismatch",
0xE81: "adErrStillConnecting",
0xE7F: "adErrStillExecuting",
0xE90: "adErrTreePermissionDenied",
0xE8F: "adErrURLDoesNotExist",
0xE99: "adErrURLNamedRowDoesNotExist",
0xE98: "adErrUnavailable",
0xE84: "adErrUnsafeOperation",
0xE95: "adErrVolumeNotFound",
0xBBC: "adErrWriteFile",
}
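# --- Editor's note (not part of the original diff): a hedged helper sketch for
# looking up the symbolic names defined above. The 0xFFF mask on the HRESULT is
# an assumption for illustration; adoErrors is keyed by the low-order error code.
def ado_error_name(code):
    return adoErrors.get(code & 0xFFF, "unknown ADO error (%s)" % hex(code))

assert ado_error_name(0xE78) == "adErrObjectClosed"
assert ado_type_name(adCurrency) == "adCurrency"
assert ado_direction_name(adParamInput) == "adParamInput"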

File diff suppressed because it is too large


@@ -1,794 +0,0 @@
"""adodbapi.apibase - A python DB API 2.0 (PEP 249) interface to Microsoft ADO
Copyright (C) 2002 Henrik Ekelund, version 2.1 by Vernon Cole
* http://sourceforge.net/projects/pywin32
* http://sourceforge.net/projects/adodbapi
"""
import datetime
import decimal
import numbers
import sys
import time
# noinspection PyUnresolvedReferences
from . import ado_consts as adc
verbose = False # debugging flag
onIronPython = sys.platform == "cli"
if onIronPython: # we need type definitions for odd data we may need to convert
# noinspection PyUnresolvedReferences
from System import DateTime, DBNull
NullTypes = (type(None), DBNull)
else:
DateTime = type(NotImplemented) # should never be seen on win32
NullTypes = type(None)
# --- define objects to smooth out Python3 <-> Python 2.x differences
unicodeType = str
longType = int
StringTypes = str
makeByteBuffer = bytes
memoryViewType = memoryview
_BaseException = Exception
try: # jdhardy -- handle bytes under IronPython & Py3
bytes
except NameError:
bytes = str # define it for old Pythons
# ------- Error handlers ------
def standardErrorHandler(connection, cursor, errorclass, errorvalue):
err = (errorclass, errorvalue)
try:
connection.messages.append(err)
except:
pass
if cursor is not None:
try:
cursor.messages.append(err)
except:
pass
raise errorclass(errorvalue)
# Note: _BaseException is defined differently between Python 2.x and 3.x
class Error(_BaseException):
pass # Exception that is the base class of all other error
# exceptions. You can use this to catch all errors with one
# single 'except' statement. Warnings are not considered
# errors and thus should not use this class as base. It must
# be a subclass of the Python StandardError (defined in the
# module exceptions).
class Warning(_BaseException):
pass
class InterfaceError(Error):
pass
class DatabaseError(Error):
pass
class InternalError(DatabaseError):
pass
class OperationalError(DatabaseError):
pass
class ProgrammingError(DatabaseError):
pass
class IntegrityError(DatabaseError):
pass
class DataError(DatabaseError):
pass
class NotSupportedError(DatabaseError):
pass
class FetchFailedError(OperationalError):
"""
Error is used by RawStoredProcedureQuerySet to determine when a fetch
failed because the connection was closed or no record set was
returned. (Non-standard, added especially for django)
"""
pass
# # # # # ----- Type Objects and Constructors ----- # # # # #
# Many databases need to have the input in a particular format for binding to an operation's input parameters.
# For example, if an input is destined for a DATE column, then it must be bound to the database in a particular
# string format. Similar problems exist for "Row ID" columns or large binary items (e.g. blobs or RAW columns).
# This presents problems for Python since the parameters to the executeXXX() method are untyped.
# When the database module sees a Python string object, it doesn't know if it should be bound as a simple CHAR
# column, as a raw BINARY item, or as a DATE.
#
# To overcome this problem, a module must provide the constructors defined below to create objects that can
# hold special values. When passed to the cursor methods, the module can then detect the proper type of
# the input parameter and bind it accordingly.
# A Cursor Object's description attribute returns information about each of the result columns of a query.
# The type_code must compare equal to one of Type Objects defined below. Type Objects may be equal to more than
# one type code (e.g. DATETIME could be equal to the type codes for date, time and timestamp columns;
# see the Implementation Hints below for details).
# SQL NULL values are represented by the Python None singleton on input and output.
# Note: Usage of Unix ticks for database interfacing can cause troubles because of the limited date range they cover.
# def Date(year,month,day):
# "This function constructs an object holding a date value. "
# return dateconverter.date(year,month,day) #dateconverter.Date(year,month,day)
#
# def Time(hour,minute,second):
# "This function constructs an object holding a time value. "
# return dateconverter.time(hour, minute, second) # dateconverter.Time(hour,minute,second)
#
# def Timestamp(year,month,day,hour,minute,second):
# "This function constructs an object holding a time stamp value. "
# return dateconverter.datetime(year,month,day,hour,minute,second)
#
# def DateFromTicks(ticks):
# """This function constructs an object holding a date value from the given ticks value
# (number of seconds since the epoch; see the documentation of the standard Python time module for details). """
# return Date(*time.gmtime(ticks)[:3])
#
# def TimeFromTicks(ticks):
# """This function constructs an object holding a time value from the given ticks value
# (number of seconds since the epoch; see the documentation of the standard Python time module for details). """
# return Time(*time.gmtime(ticks)[3:6])
#
# def TimestampFromTicks(ticks):
# """This function constructs an object holding a time stamp value from the given
# ticks value (number of seconds since the epoch;
# see the documentation of the standard Python time module for details). """
# return Timestamp(*time.gmtime(ticks)[:6])
#
# def Binary(aString):
# """This function constructs an object capable of holding a binary (long) string value. """
# b = makeByteBuffer(aString)
# return b
# ----- Time converters ----------------------------------------------
class TimeConverter(object): # this is a generic time converter skeleton
def __init__(self): # the details will be filled in by instances
self._ordinal_1899_12_31 = datetime.date(1899, 12, 31).toordinal() - 1
# Use cls.types to compare if an input parameter is a datetime
self.types = {
type(self.Date(2000, 1, 1)),
type(self.Time(12, 1, 1)),
type(self.Timestamp(2000, 1, 1, 12, 1, 1)),
datetime.datetime,
datetime.time,
datetime.date,
}
def COMDate(self, obj):
"""Returns a ComDate from a date-time"""
try: # most likely a datetime
tt = obj.timetuple()
try:
ms = obj.microsecond
except:
ms = 0
return self.ComDateFromTuple(tt, ms)
except: # might be a tuple
try:
return self.ComDateFromTuple(obj)
except: # try an mxdate
try:
return obj.COMDate()
except:
raise ValueError('Cannot convert "%s" to COMdate.' % repr(obj))
def ComDateFromTuple(self, t, microseconds=0):
d = datetime.date(t[0], t[1], t[2])
integerPart = d.toordinal() - self._ordinal_1899_12_31
ms = (t[3] * 3600 + t[4] * 60 + t[5]) * 1000000 + microseconds
fractPart = float(ms) / 86400000000.0
return integerPart + fractPart
def DateObjectFromCOMDate(self, comDate):
"Returns an object of the wanted type from a ComDate"
raise NotImplementedError # "Abstract class"
def Date(self, year, month, day):
"This function constructs an object holding a date value."
raise NotImplementedError # "Abstract class"
def Time(self, hour, minute, second):
"This function constructs an object holding a time value."
raise NotImplementedError # "Abstract class"
def Timestamp(self, year, month, day, hour, minute, second):
"This function constructs an object holding a time stamp value."
raise NotImplementedError # "Abstract class"
# all purpose date to ISO format converter
def DateObjectToIsoFormatString(self, obj):
"This function should return a string in the format 'YYYY-MM-dd HH:MM:SS:ms' (ms optional)"
try: # most likely, a datetime.datetime
s = obj.isoformat(" ")
except (TypeError, AttributeError):
if isinstance(obj, datetime.date):
s = obj.isoformat() + " 00:00:00" # return exact midnight
else:
try: # maybe it has a strftime method, like mx
s = obj.strftime("%Y-%m-%d %H:%M:%S")
except AttributeError:
try: # but may be time.struct_time
s = time.strftime("%Y-%m-%d %H:%M:%S", obj)
except:
raise ValueError('Cannot convert "%s" to isoformat' % repr(obj))
return s
# -- Optional: if mx extensions are installed you may use mxDateTime ----
try:
import mx.DateTime
mxDateTime = True
except:
mxDateTime = False
if mxDateTime:
class mxDateTimeConverter(TimeConverter): # used optionally if installed
def __init__(self):
TimeConverter.__init__(self)
self.types.add(type(mx.DateTime))
def DateObjectFromCOMDate(self, comDate):
return mx.DateTime.DateTimeFromCOMDate(comDate)
def Date(self, year, month, day):
return mx.DateTime.Date(year, month, day)
def Time(self, hour, minute, second):
return mx.DateTime.Time(hour, minute, second)
def Timestamp(self, year, month, day, hour, minute, second):
return mx.DateTime.Timestamp(year, month, day, hour, minute, second)
else:
class mxDateTimeConverter(TimeConverter):
pass # if no mx is installed
class pythonDateTimeConverter(TimeConverter): # standard since Python 2.3
def __init__(self):
TimeConverter.__init__(self)
def DateObjectFromCOMDate(self, comDate):
if isinstance(comDate, datetime.datetime):
odn = comDate.toordinal()
tim = comDate.time()
new = datetime.datetime.combine(datetime.datetime.fromordinal(odn), tim)
return new
# return comDate.replace(tzinfo=None) # make non aware
elif isinstance(comDate, DateTime):
fComDate = comDate.ToOADate() # ironPython clr Date/Time
else:
fComDate = float(comDate) # ComDate is number of days since 1899-12-31
integerPart = int(fComDate)
floatpart = fComDate - integerPart
##if floatpart == 0.0:
## return datetime.date.fromordinal(integerPart + self._ordinal_1899_12_31)
dte = datetime.datetime.fromordinal(
integerPart + self._ordinal_1899_12_31
) + datetime.timedelta(milliseconds=floatpart * 86400000)
# millisecondsperday=86400000 # 24*60*60*1000
return dte
def Date(self, year, month, day):
return datetime.date(year, month, day)
def Time(self, hour, minute, second):
return datetime.time(hour, minute, second)
def Timestamp(self, year, month, day, hour, minute, second):
return datetime.datetime(year, month, day, hour, minute, second)
class pythonTimeConverter(TimeConverter): # the old, *nix-style date and time
def __init__(self): # caution: this class gets confused by timezones and DST
TimeConverter.__init__(self)
self.types.add(time.struct_time)
def DateObjectFromCOMDate(self, comDate):
"Returns ticks since 1970"
if isinstance(comDate, datetime.datetime):
return comDate.timetuple()
elif isinstance(comDate, DateTime): # ironPython clr date/time
fcomDate = comDate.ToOADate()
else:
fcomDate = float(comDate)
secondsperday = 86400 # 24*60*60
# ComDate is number of days since 1899-12-31, gmtime epoch is 1970-1-1 = 25569 days
t = time.gmtime(secondsperday * (fcomDate - 25569.0))
return t # year,month,day,hour,minute,second,weekday,julianday,daylightsaving=t
def Date(self, year, month, day):
return self.Timestamp(year, month, day, 0, 0, 0)
def Time(self, hour, minute, second):
return time.gmtime((hour * 60 + minute) * 60 + second)
def Timestamp(self, year, month, day, hour, minute, second):
return time.localtime(
time.mktime((year, month, day, hour, minute, second, 0, 0, -1))
)
base_dateconverter = pythonDateTimeConverter()
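# --- Editor's note (not part of the original diff): a hedged worked example of
# the COM/OLE date convention used above -- whole days counted from 1899-12-30,
# with the time of day as the fractional part, so 1900-01-01 12:00 is 2.5.
assert base_dateconverter.COMDate(datetime.datetime(1900, 1, 1, 12, 0)) == 2.5
assert base_dateconverter.DateObjectFromCOMDate(2.5) == datetime.datetime(1900, 1, 1, 12, 0)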
# ------ DB API required module attributes ---------------------
threadsafety = 1 # TODO -- find out whether this module is actually BETTER than 1.
apilevel = "2.0" # String constant stating the supported DB API level.
paramstyle = "qmark" # the default parameter style
# ------ control for an extension which may become part of DB API 3.0 ---
accepted_paramstyles = ("qmark", "named", "format", "pyformat", "dynamic")
# ------------------------------------------------------------------------------------------
# define similar types for generic conversion routines
adoIntegerTypes = (
adc.adInteger,
adc.adSmallInt,
adc.adTinyInt,
adc.adUnsignedInt,
adc.adUnsignedSmallInt,
adc.adUnsignedTinyInt,
adc.adBoolean,
adc.adError,
) # max 32 bits
adoRowIdTypes = (adc.adChapter,) # v2.1 Rose
adoLongTypes = (adc.adBigInt, adc.adFileTime, adc.adUnsignedBigInt)
adoExactNumericTypes = (
adc.adDecimal,
adc.adNumeric,
adc.adVarNumeric,
adc.adCurrency,
) # v2.3 Cole
adoApproximateNumericTypes = (adc.adDouble, adc.adSingle) # v2.1 Cole
adoStringTypes = (
adc.adBSTR,
adc.adChar,
adc.adLongVarChar,
adc.adLongVarWChar,
adc.adVarChar,
adc.adVarWChar,
adc.adWChar,
)
adoBinaryTypes = (adc.adBinary, adc.adLongVarBinary, adc.adVarBinary)
adoDateTimeTypes = (adc.adDBTime, adc.adDBTimeStamp, adc.adDate, adc.adDBDate)
adoRemainingTypes = (
adc.adEmpty,
adc.adIDispatch,
adc.adIUnknown,
adc.adPropVariant,
adc.adArray,
adc.adUserDefined,
adc.adVariant,
adc.adGUID,
)
# this class is a trick to determine whether a type is a member of a related group of types. see PEP notes
class DBAPITypeObject(object):
def __init__(self, valuesTuple):
self.values = frozenset(valuesTuple)
def __eq__(self, other):
return other in self.values
def __ne__(self, other):
return other not in self.values
"""This type object is used to describe columns in a database that are string-based (e.g. CHAR). """
STRING = DBAPITypeObject(adoStringTypes)
"""This type object is used to describe (long) binary columns in a database (e.g. LONG, RAW, BLOBs). """
BINARY = DBAPITypeObject(adoBinaryTypes)
"""This type object is used to describe numeric columns in a database. """
NUMBER = DBAPITypeObject(
adoIntegerTypes + adoLongTypes + adoExactNumericTypes + adoApproximateNumericTypes
)
"""This type object is used to describe date/time columns in a database. """
DATETIME = DBAPITypeObject(adoDateTimeTypes)
"""This type object is used to describe the "Row ID" column in a database. """
ROWID = DBAPITypeObject(adoRowIdTypes)
OTHER = DBAPITypeObject(adoRemainingTypes)
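# --- Editor's note (not part of the original diff): a hedged illustration of the
# DBAPITypeObject equality trick -- one type object compares equal to every ADO
# code in its group, as PEP 249 suggests for column type_code comparisons.
assert STRING == adc.adVarWChar       # any string-ish ADO code matches STRING
assert STRING != adc.adInteger
assert NUMBER == adc.adCurrency       # exact and approximate numerics both match NUMBER
assert DATETIME == adc.adDBTimeStamp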
# ------- utilities for translating python data types to ADO data types ---------------------------------
typeMap = {
memoryViewType: adc.adVarBinary,
float: adc.adDouble,
type(None): adc.adEmpty,
str: adc.adBSTR,
bool: adc.adBoolean, # v2.1 Cole
decimal.Decimal: adc.adDecimal,
int: adc.adBigInt,
bytes: adc.adVarBinary,
}
def pyTypeToADOType(d):
tp = type(d)
try:
return typeMap[tp]
except KeyError: # The type was not defined in the pre-computed Type table
from . import dateconverter
if (
tp in dateconverter.types
): # maybe it is one of our supported Date/Time types
return adc.adDate
# otherwise, attempt to discern the type by probing the data object itself -- to handle duck typing
if isinstance(d, StringTypes):
return adc.adBSTR
if isinstance(d, numbers.Integral):
return adc.adBigInt
if isinstance(d, numbers.Real):
return adc.adDouble
raise DataError('cannot convert "%s" (type=%s) to ADO' % (repr(d), tp))
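# --- Editor's note (not part of the original diff): a hedged sketch of the
# Python-to-ADO type mapping performed by pyTypeToADOType for common values.
assert pyTypeToADOType(3.14) == adc.adDouble
assert pyTypeToADOType(42) == adc.adBigInt
assert pyTypeToADOType("text") == adc.adBSTR
assert pyTypeToADOType(decimal.Decimal("9.99")) == adc.adDecimal
assert pyTypeToADOType(None) == adc.adEmpty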
# # # # # # # # # # # # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# functions to convert database values to Python objects
# ------------------------------------------------------------------------
# variant type : function converting variant to Python value
def variantConvertDate(v):
from . import dateconverter # this function only called when adodbapi is running
return dateconverter.DateObjectFromCOMDate(v)
def cvtString(variant): # use to get old action of adodbapi v1 if desired
if onIronPython:
try:
return variant.ToString()
except:
pass
return str(variant)
def cvtDecimal(variant): # better name
return _convertNumberWithCulture(variant, decimal.Decimal)
def cvtNumeric(variant): # older name - don't break old code
return cvtDecimal(variant)
def cvtFloat(variant):
return _convertNumberWithCulture(variant, float)
def _convertNumberWithCulture(variant, f):
try:
return f(variant)
except (ValueError, TypeError, decimal.InvalidOperation):
try:
europeVsUS = str(variant).replace(",", ".")
return f(europeVsUS)
except (ValueError, TypeError, decimal.InvalidOperation):
pass
def cvtInt(variant):
return int(variant)
def cvtLong(variant): # only important in old versions where long and int differ
return int(variant)
def cvtBuffer(variant):
return bytes(variant)
def cvtUnicode(variant):
return str(variant)
def identity(x):
return x
def cvtUnusual(variant):
if verbose > 1:
sys.stderr.write("Conversion called for Unusual data=%s\n" % repr(variant))
if isinstance(variant, DateTime): # COMdate or System.Date
from .adodbapi import ( # this will only be called when adodbapi is in use, and very rarely
dateconverter,
)
return dateconverter.DateObjectFromCOMDate(variant)
return variant # cannot find conversion function -- just give the data to the user
def convert_to_python(variant, func): # convert DB value into Python value
if isinstance(variant, NullTypes): # IronPython Null or None
return None
return func(variant) # call the appropriate conversion function
class MultiMap(dict): # builds a dictionary from {(sequence,of,keys) : function}
"""A dictionary of ado.type : function -- but you can set multiple items by passing a sequence of keys"""
# useful for defining conversion functions for groups of similar data types.
def __init__(self, aDict):
for k, v in list(aDict.items()):
self[k] = v # we must call __setitem__
def __setitem__(self, adoType, cvtFn):
"set a single item, or a whole sequence of items"
try: # user passed us a sequence, set them individually
for type in adoType:
dict.__setitem__(self, type, cvtFn)
except TypeError: # a single value fails attempt to iterate
dict.__setitem__(self, adoType, cvtFn)
# initialize variantConversions dictionary used to convert SQL to Python
# this is the dictionary of default conversion functions, built by the class above.
# this becomes a class attribute for the Connection, and that attribute is used
# to build the list of column conversion functions for the Cursor
variantConversions = MultiMap(
{
adoDateTimeTypes: variantConvertDate,
adoApproximateNumericTypes: cvtFloat,
adoExactNumericTypes: cvtDecimal, # use to force decimal rather than unicode
adoLongTypes: cvtLong,
adoIntegerTypes: cvtInt,
adoRowIdTypes: cvtInt,
adoStringTypes: identity,
adoBinaryTypes: cvtBuffer,
adoRemainingTypes: cvtUnusual,
}
)
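# --- Editor's note (not part of the original diff): a hedged illustration of how
# MultiMap fans each tuple of ADO codes out to individual entries, and how
# convert_to_python applies the chosen converter while passing NULLs through.
assert variantConversions[adc.adSmallInt] is cvtInt
assert variantConversions[adc.adBigInt] is cvtLong
assert variantConversions[adc.adVarWChar] is identity
assert convert_to_python(None, cvtInt) is None
assert convert_to_python("42", cvtInt) == 42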
# # # # # classes to emulate the result of cursor.fetchxxx() as a sequence of sequences # # # # #
# "an ENUM of how my low level records are laid out"
RS_WIN_32, RS_ARRAY, RS_REMOTE = list(range(1, 4))
class SQLrow(object): # a single database row
# class to emulate a sequence, so that a column may be retrieved by either number or name
def __init__(self, rows, index): # "rows" is an _SQLrows object, index is which row
self.rows = rows # parent 'fetch' container object
self.index = index # my row number within parent
def __getattr__(self, name): # used for row.columnName type of value access
try:
return self._getValue(self.rows.columnNames[name.lower()])
except KeyError:
raise AttributeError('Unknown column name "{}"'.format(name))
def _getValue(self, key): # key must be an integer
if (
self.rows.recordset_format == RS_ARRAY
): # retrieve from two-dimensional array
v = self.rows.ado_results[key, self.index]
elif self.rows.recordset_format == RS_REMOTE:
v = self.rows.ado_results[self.index][key]
else: # pywin32 - retrieve from tuple of tuples
v = self.rows.ado_results[key][self.index]
if self.rows.converters is NotImplemented:
return v
return convert_to_python(v, self.rows.converters[key])
def __len__(self):
return self.rows.numberOfColumns
def __getitem__(self, key): # used for row[key] type of value access
if isinstance(key, int): # normal row[1] designation
try:
return self._getValue(key)
except IndexError:
raise
if isinstance(key, slice):
indices = key.indices(self.rows.numberOfColumns)
vl = [self._getValue(i) for i in range(*indices)]
return tuple(vl)
try:
return self._getValue(
self.rows.columnNames[key.lower()]
) # extension row[columnName] designation
except (KeyError, TypeError):
er, st, tr = sys.exc_info()
raise er(
'No such key as "%s" in %s' % (repr(key), self.__repr__())
).with_traceback(tr)
def __iter__(self):
return iter(self.__next__())
def __next__(self):
for n in range(self.rows.numberOfColumns):
yield self._getValue(n)
def __repr__(self): # create a human readable representation
taglist = sorted(list(self.rows.columnNames.items()), key=lambda x: x[1])
s = "<SQLrow={"
for name, i in taglist:
s += name + ":" + repr(self._getValue(i)) + ", "
return s[:-2] + "}>"
def __str__(self): # create a pretty human readable representation
return str(
tuple(str(self._getValue(i)) for i in range(self.rows.numberOfColumns))
)
# TO-DO implement pickling an SQLrow directly
# def __getstate__(self): return self.__dict__
# def __setstate__(self, d): self.__dict__.update(d)
# which basically tell pickle to treat your class just like a normal one,
# taking self.__dict__ as representing the whole of the instance state,
# despite the existence of the __getattr__.
# # # #
class SQLrows(object):
# class to emulate a sequence for multiple rows using a container object
def __init__(self, ado_results, numberOfRows, cursor):
self.ado_results = ado_results # raw result of SQL get
try:
self.recordset_format = cursor.recordset_format
self.numberOfColumns = cursor.numberOfColumns
self.converters = cursor.converters
self.columnNames = cursor.columnNames
except AttributeError:
self.recordset_format = RS_ARRAY
self.numberOfColumns = 0
self.converters = []
self.columnNames = {}
self.numberOfRows = numberOfRows
def __len__(self):
return self.numberOfRows
def __getitem__(self, item): # used for row or row,column access
if not self.ado_results:
return []
if isinstance(item, slice): # will return a list of row objects
indices = item.indices(self.numberOfRows)
return [SQLrow(self, k) for k in range(*indices)]
elif isinstance(item, tuple) and len(item) == 2:
# d = some_rowsObject[i,j] will return a datum from a two-dimension address
i, j = item
if not isinstance(j, int):
try:
j = self.columnNames[j.lower()] # convert named column to numeric
except KeyError:
raise KeyError('adodbapi: no such column name as "%s"' % repr(j))
if self.recordset_format == RS_ARRAY: # retrieve from two-dimensional array
v = self.ado_results[j, i]
elif self.recordset_format == RS_REMOTE:
v = self.ado_results[i][j]
else: # pywin32 - retrieve from tuple of tuples
v = self.ado_results[j][i]
if self.converters is NotImplemented:
return v
return convert_to_python(v, self.converters[j])
else:
row = SQLrow(self, item) # new row descriptor
return row
def __iter__(self):
return iter(self.__next__())
def __next__(self):
for n in range(self.numberOfRows):
row = SQLrow(self, n)
yield row
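# --- Editor's note (not part of the original diff): a hedged sketch of the
# SQLrows/SQLrow wrappers. The fake cursor below is an assumption purely for
# illustration; it supplies the layout metadata a real Cursor would provide.
from types import SimpleNamespace
_fake_cursor = SimpleNamespace(
    recordset_format=RS_REMOTE,        # rows stored as a list of tuples
    numberOfColumns=2,
    converters=NotImplemented,         # skip per-column conversion
    columnNames={"id": 0, "name": 1},  # lower-case column name -> index
)
_rows = SQLrows([(1, "alpha"), (2, "beta")], 2, _fake_cursor)
assert _rows[0].name == "alpha"        # attribute access by column name
assert _rows[1]["id"] == 2             # mapping-style access by name
assert tuple(_rows[0]) == (1, "alpha") # a row also behaves like a sequence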
# # # # #
# # # # # functions to re-format SQL requests to other paramstyle requirements # # # # # # # # # #
def changeNamedToQmark(
op,
): # convert from 'named' paramstyle to ADO required '?'mark parameters
outOp = ""
outparms = []
chunks = op.split(
"'"
) # quote all literals -- odd numbered list results are literals.
inQuotes = False
for chunk in chunks:
if inQuotes: # this is inside a quote
if chunk == "": # double apostrophe to quote one apostrophe
outOp = outOp[:-1] # so take one away
else:
outOp += "'" + chunk + "'" # else pass the quoted string as is.
else: # is SQL code -- look for a :namedParameter
while chunk: # some SQL string remains
sp = chunk.split(":", 1)
outOp += sp[0] # concat the part up to the :
s = ""
try:
chunk = sp[1]
except IndexError:
chunk = None
if chunk: # there was a parameter - parse it out
i = 0
c = chunk[0]
while c.isalnum() or c == "_":
i += 1
try:
c = chunk[i]
except IndexError:
break
s = chunk[:i]
chunk = chunk[i:]
if s:
outparms.append(s) # list the parameters in order
outOp += "?" # put in the Qmark
inQuotes = not inQuotes
return outOp, outparms
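# --- Editor's note (not part of the original diff): a hedged worked example --
# ':name' placeholders become '?' marks and the parameter names are returned in
# positional order, ready for re-ordering the bind values.
_sql, _names = changeNamedToQmark("select * from people where name = :name and dept = :dept")
assert _sql == "select * from people where name = ? and dept = ?"
assert _names == ["name", "dept"]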
def changeFormatToQmark(
op,
): # convert from 'format' paramstyle to ADO required '?'mark parameters
outOp = ""
outparams = []
chunks = op.split(
"'"
) # quote all literals -- odd numbered list results are literals.
inQuotes = False
for chunk in chunks:
if inQuotes:
if (
outOp != "" and chunk == ""
): # he used a double apostrophe to quote one apostrophe
outOp = outOp[:-1] # so take one away
else:
outOp += "'" + chunk + "'" # else pass the quoted string as is.
else: # is SQL code -- look for a %s parameter
if "%(" in chunk: # ugh! pyformat!
while chunk: # some SQL string remains
sp = chunk.split("%(", 1)
outOp += sp[0] # concat the part up to the %
if len(sp) > 1:
try:
s, chunk = sp[1].split(")s", 1) # find the ')s'
except ValueError:
raise ProgrammingError(
'Pyformat SQL has incorrect format near "%s"' % chunk
)
outparams.append(s)
outOp += "?" # put in the Qmark
else:
chunk = None
else: # proper '%s' format
sp = chunk.split("%s") # make each %s
outOp += "?".join(sp) # into ?
inQuotes = not inQuotes # every other chunk is a quoted string
return outOp, outparams
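# --- Editor's note (not part of the original diff): a hedged worked example
# covering both the plain '%s' format style and the '%(name)s' pyformat style.
_sql, _params = changeFormatToQmark("insert into t (a, b) values (%s, %s)")
assert _sql == "insert into t (a, b) values (?, ?)"
assert _params == []                   # plain '%s' carries no parameter names

_sql, _params = changeFormatToQmark("update t set a = %(alpha)s where b = %(beta)s")
assert _sql == "update t set a = ? where b = ?"
assert _params == ["alpha", "beta"]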


@@ -1,72 +0,0 @@
""" db_print.py -- a simple demo for ADO database reads."""
import sys
import adodbapi.ado_consts as adc
cmd_args = ("filename", "table_name")
if "help" in sys.argv:
print("possible settings keywords are:", cmd_args)
sys.exit()
kw_args = {} # pick up filename and proxy address from command line (optionally)
for arg in sys.argv:
s = arg.split("=")
if len(s) > 1:
if s[0] in cmd_args:
kw_args[s[0]] = s[1]
kw_args.setdefault(
"filename", "test.mdb"
) # assumes server is running from examples folder
kw_args.setdefault("table_name", "Products") # the name of the demo table
# the server needs to select the provider based on its Python installation
provider_switch = ["provider", "Microsoft.ACE.OLEDB.12.0", "Microsoft.Jet.OLEDB.4.0"]
# ------------------------ START HERE -------------------------------------
# create the connection
constr = "Provider=%(provider)s;Data Source=%(filename)s"
import adodbapi as db
con = db.connect(constr, kw_args, macro_is64bit=provider_switch)
if kw_args["table_name"] == "?":
print("The tables in your database are:")
for name in con.get_table_names():
print(name)
else:
# make a cursor on the connection
with con.cursor() as c:
# run an SQL statement on the cursor
sql = "select * from %s" % kw_args["table_name"]
print('performing query="%s"' % sql)
c.execute(sql)
# check the results
print(
'result rowcount shows as= %d. (Note: -1 means "not known")' % (c.rowcount,)
)
print("")
print("result data description is:")
print(" NAME Type DispSize IntrnlSz Prec Scale Null?")
for d in c.description:
print(
("%16s %-12s %8s %8d %4d %5d %s")
% (d[0], adc.adTypeNames[d[1]], d[2], d[3], d[4], d[5], bool(d[6]))
)
print("")
print("str() of first five records are...")
# get the results
db = c.fetchmany(5)
# print them
for rec in db:
print(rec)
print("")
print("repr() of next row is...")
print(repr(c.fetchone()))
print("")
con.close()


@@ -1,20 +0,0 @@
""" db_table_names.py -- a simple demo for ADO database table listing."""
import sys
import adodbapi
try:
databasename = sys.argv[1]
except IndexError:
databasename = "test.mdb"
provider = ["prv", "Microsoft.ACE.OLEDB.12.0", "Microsoft.Jet.OLEDB.4.0"]
constr = "Provider=%(prv)s;Data Source=%(db)s"
# create the connection
con = adodbapi.connect(constr, db=databasename, macro_is64bit=provider)
print("Table names in= %s" % databasename)
for table in con.get_table_names():
print(table)


@@ -1,41 +0,0 @@
import sys
import adodbapi
try:
import adodbapi.is64bit as is64bit
is64 = is64bit.Python()
except ImportError:
is64 = False
if is64:
driver = "Microsoft.ACE.OLEDB.12.0"
else:
driver = "Microsoft.Jet.OLEDB.4.0"
extended = 'Extended Properties="Excel 8.0;HDR=Yes;IMEX=1;"'
try: # first command line argument will be xls file name -- default to the one written by xls_write.py
filename = sys.argv[1]
except IndexError:
filename = "xx.xls"
constr = "Provider=%s;Data Source=%s;%s" % (driver, filename, extended)
conn = adodbapi.connect(constr)
try: # second command line argument will be worksheet name -- default to first worksheet
sheet = sys.argv[2]
except IndexError:
# use ADO feature to get the name of the first worksheet
sheet = conn.get_table_names()[0]
print("Shreadsheet=%s Worksheet=%s" % (filename, sheet))
print("------------------------------------------------------------")
crsr = conn.cursor()
sql = "SELECT * from [%s]" % sheet
crsr.execute(sql)
for row in crsr.fetchmany(10):
print(repr(row))
crsr.close()
conn.close()


@@ -1,41 +0,0 @@
import datetime
import adodbapi
try:
import adodbapi.is64bit as is64bit
is64 = is64bit.Python()
except ImportError:
is64 = False # in case the user has an old version of adodbapi
if is64:
driver = "Microsoft.ACE.OLEDB.12.0"
else:
driver = "Microsoft.Jet.OLEDB.4.0"
filename = "xx.xls" # file will be created if it does not exist
extended = 'Extended Properties="Excel 8.0;Readonly=False;"'
constr = "Provider=%s;Data Source=%s;%s" % (driver, filename, extended)
conn = adodbapi.connect(constr)
with conn: # will auto commit if no errors
with conn.cursor() as crsr:
try:
crsr.execute("drop table SheetOne")
except:
pass # just in case there is one already there
# create the sheet and the header row and set the types for the columns
crsr.execute(
"create table SheetOne (Name varchar, Rank varchar, SrvcNum integer, Weight float, Birth date)"
)
sql = "INSERT INTO SheetOne (name, rank , srvcnum, weight, birth) values (?,?,?,?,?)"
data = ("Mike Murphy", "SSG", 123456789, 167.8, datetime.date(1922, 12, 27))
crsr.execute(sql, data) # write the first row of data
crsr.execute(
sql, ["John Jones", "Pvt", 987654321, 140.0, datetime.date(1921, 7, 4)]
) # another row of data
conn.close()
print("Created spreadsheet=%s worksheet=%s" % (filename, "SheetOne"))


@@ -1,41 +0,0 @@
"""is64bit.Python() --> boolean value of detected Python word size. is64bit.os() --> os build version"""
import sys
def Python():
if sys.platform == "cli": # IronPython
import System
return System.IntPtr.Size == 8
else:
try:
return sys.maxsize > 2147483647
except AttributeError:
return sys.maxint > 2147483647
def os():
import platform
pm = platform.machine()
if pm != ".." and pm.endswith("64"): # recent Python (not Iron)
return True
else:
import os
if "PROCESSOR_ARCHITEW6432" in os.environ:
return True # 32 bit program running on 64 bit Windows
try:
return os.environ["PROCESSOR_ARCHITECTURE"].endswith(
"64"
) # 64 bit Windows 64 bit program
except (IndexError, KeyError):
pass # not Windows
try:
return "64" in platform.architecture()[0] # this often works in Linux
except:
return False # is an older version of Python, assume also an older os (best we can guess)
if __name__ == "__main__":
print("is64bit.Python() =", Python(), "is64bit.os() =", os())


@@ -1,506 +0,0 @@
GNU LESSER GENERAL PUBLIC LICENSE
Version 2.1, February 1999
Copyright (C) 1991, 1999 Free Software Foundation, Inc.
59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
[This is the first released version of the Lesser GPL. It also counts
as the successor of the GNU Library Public License, version 2, hence
the version number 2.1.]
Preamble
The licenses for most software are designed to take away your
freedom to share and change it. By contrast, the GNU General Public
Licenses are intended to guarantee your freedom to share and change
free software--to make sure the software is free for all its users.
This license, the Lesser General Public License, applies to some
specially designated software packages--typically libraries--of the
Free Software Foundation and other authors who decide to use it. You
can use it too, but we suggest you first think carefully about whether
this license or the ordinary General Public License is the better
strategy to use in any particular case, based on the explanations below.
When we speak of free software, we are referring to freedom of use,
not price. Our General Public Licenses are designed to make sure that
you have the freedom to distribute copies of free software (and charge
for this service if you wish); that you receive source code or can get
it if you want it; that you can change the software and use pieces of
it in new free programs; and that you are informed that you can do
these things.
To protect your rights, we need to make restrictions that forbid
distributors to deny you these rights or to ask you to surrender these
rights. These restrictions translate to certain responsibilities for
you if you distribute copies of the library or if you modify it.
For example, if you distribute copies of the library, whether gratis
or for a fee, you must give the recipients all the rights that we gave
you. You must make sure that they, too, receive or can get the source
code. If you link other code with the library, you must provide
complete object files to the recipients, so that they can relink them
with the library after making changes to the library and recompiling
it. And you must show them these terms so they know their rights.
We protect your rights with a two-step method: (1) we copyright the
library, and (2) we offer you this license, which gives you legal
permission to copy, distribute and/or modify the library.
To protect each distributor, we want to make it very clear that
there is no warranty for the free library. Also, if the library is
modified by someone else and passed on, the recipients should know
that what they have is not the original version, so that the original
author's reputation will not be affected by problems that might be
introduced by others.
Finally, software patents pose a constant threat to the existence of
any free program. We wish to make sure that a company cannot
effectively restrict the users of a free program by obtaining a
restrictive license from a patent holder. Therefore, we insist that
any patent license obtained for a version of the library must be
consistent with the full freedom of use specified in this license.
Most GNU software, including some libraries, is covered by the
ordinary GNU General Public License. This license, the GNU Lesser
General Public License, applies to certain designated libraries, and
is quite different from the ordinary General Public License. We use
this license for certain libraries in order to permit linking those
libraries into non-free programs.
When a program is linked with a library, whether statically or using
a shared library, the combination of the two is legally speaking a
combined work, a derivative of the original library. The ordinary
General Public License therefore permits such linking only if the
entire combination fits its criteria of freedom. The Lesser General
Public License permits more lax criteria for linking other code with
the library.
We call this license the "Lesser" General Public License because it
does Less to protect the user's freedom than the ordinary General
Public License. It also provides other free software developers Less
of an advantage over competing non-free programs. These disadvantages
are the reason we use the ordinary General Public License for many
libraries. However, the Lesser license provides advantages in certain
special circumstances.
For example, on rare occasions, there may be a special need to
encourage the widest possible use of a certain library, so that it becomes
a de-facto standard. To achieve this, non-free programs must be
allowed to use the library. A more frequent case is that a free
library does the same job as widely used non-free libraries. In this
case, there is little to gain by limiting the free library to free
software only, so we use the Lesser General Public License.
In other cases, permission to use a particular library in non-free
programs enables a greater number of people to use a large body of
free software. For example, permission to use the GNU C Library in
non-free programs enables many more people to use the whole GNU
operating system, as well as its variant, the GNU/Linux operating
system.
Although the Lesser General Public License is Less protective of the
users' freedom, it does ensure that the user of a program that is
linked with the Library has the freedom and the wherewithal to run
that program using a modified version of the Library.
The precise terms and conditions for copying, distribution and
modification follow. Pay close attention to the difference between a
"work based on the library" and a "work that uses the library". The
former contains code derived from the library, whereas the latter must
be combined with the library in order to run.
GNU LESSER GENERAL PUBLIC LICENSE
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
0. This License Agreement applies to any software library or other
program which contains a notice placed by the copyright holder or
other authorized party saying it may be distributed under the terms of
this Lesser General Public License (also called "this License").
Each licensee is addressed as "you".
A "library" means a collection of software functions and/or data
prepared so as to be conveniently linked with application programs
(which use some of those functions and data) to form executables.
The "Library", below, refers to any such software library or work
which has been distributed under these terms. A "work based on the
Library" means either the Library or any derivative work under
copyright law: that is to say, a work containing the Library or a
portion of it, either verbatim or with modifications and/or translated
straightforwardly into another language. (Hereinafter, translation is
included without limitation in the term "modification".)
"Source code" for a work means the preferred form of the work for
making modifications to it. For a library, complete source code means
all the source code for all modules it contains, plus any associated
interface definition files, plus the scripts used to control compilation
and installation of the library.
Activities other than copying, distribution and modification are not
covered by this License; they are outside its scope. The act of
running a program using the Library is not restricted, and output from
such a program is covered only if its contents constitute a work based
on the Library (independent of the use of the Library in a tool for
writing it). Whether that is true depends on what the Library does
and what the program that uses the Library does.
1. You may copy and distribute verbatim copies of the Library's
complete source code as you receive it, in any medium, provided that
you conspicuously and appropriately publish on each copy an
appropriate copyright notice and disclaimer of warranty; keep intact
all the notices that refer to this License and to the absence of any
warranty; and distribute a copy of this License along with the
Library.
You may charge a fee for the physical act of transferring a copy,
and you may at your option offer warranty protection in exchange for a
fee.
2. You may modify your copy or copies of the Library or any portion
of it, thus forming a work based on the Library, and copy and
distribute such modifications or work under the terms of Section 1
above, provided that you also meet all of these conditions:
a) The modified work must itself be a software library.
b) You must cause the files modified to carry prominent notices
stating that you changed the files and the date of any change.
c) You must cause the whole of the work to be licensed at no
charge to all third parties under the terms of this License.
d) If a facility in the modified Library refers to a function or a
table of data to be supplied by an application program that uses
the facility, other than as an argument passed when the facility
is invoked, then you must make a good faith effort to ensure that,
in the event an application does not supply such function or
table, the facility still operates, and performs whatever part of
its purpose remains meaningful.
(For example, a function in a library to compute square roots has
a purpose that is entirely well-defined independent of the
application. Therefore, Subsection 2d requires that any
application-supplied function or table used by this function must
be optional: if the application does not supply it, the square
root function must still compute square roots.)
These requirements apply to the modified work as a whole. If
identifiable sections of that work are not derived from the Library,
and can be reasonably considered independent and separate works in
themselves, then this License, and its terms, do not apply to those
sections when you distribute them as separate works. But when you
distribute the same sections as part of a whole which is a work based
on the Library, the distribution of the whole must be on the terms of
this License, whose permissions for other licensees extend to the
entire whole, and thus to each and every part regardless of who wrote
it.
Thus, it is not the intent of this section to claim rights or contest
your rights to work written entirely by you; rather, the intent is to
exercise the right to control the distribution of derivative or
collective works based on the Library.
In addition, mere aggregation of another work not based on the Library
with the Library (or with a work based on the Library) on a volume of
a storage or distribution medium does not bring the other work under
the scope of this License.
3. You may opt to apply the terms of the ordinary GNU General Public
License instead of this License to a given copy of the Library. To do
this, you must alter all the notices that refer to this License, so
that they refer to the ordinary GNU General Public License, version 2,
instead of to this License. (If a newer version than version 2 of the
ordinary GNU General Public License has appeared, then you can specify
that version instead if you wish.) Do not make any other change in
these notices.
Once this change is made in a given copy, it is irreversible for
that copy, so the ordinary GNU General Public License applies to all
subsequent copies and derivative works made from that copy.
This option is useful when you wish to copy part of the code of
the Library into a program that is not a library.
4. You may copy and distribute the Library (or a portion or
derivative of it, under Section 2) in object code or executable form
under the terms of Sections 1 and 2 above provided that you accompany
it with the complete corresponding machine-readable source code, which
must be distributed under the terms of Sections 1 and 2 above on a
medium customarily used for software interchange.
If distribution of object code is made by offering access to copy
from a designated place, then offering equivalent access to copy the
source code from the same place satisfies the requirement to
distribute the source code, even though third parties are not
compelled to copy the source along with the object code.
5. A program that contains no derivative of any portion of the
Library, but is designed to work with the Library by being compiled or
linked with it, is called a "work that uses the Library". Such a
work, in isolation, is not a derivative work of the Library, and
therefore falls outside the scope of this License.
However, linking a "work that uses the Library" with the Library
creates an executable that is a derivative of the Library (because it
contains portions of the Library), rather than a "work that uses the
library". The executable is therefore covered by this License.
Section 6 states terms for distribution of such executables.
When a "work that uses the Library" uses material from a header file
that is part of the Library, the object code for the work may be a
derivative work of the Library even though the source code is not.
Whether this is true is especially significant if the work can be
linked without the Library, or if the work is itself a library. The
threshold for this to be true is not precisely defined by law.
If such an object file uses only numerical parameters, data
structure layouts and accessors, and small macros and small inline
functions (ten lines or less in length), then the use of the object
file is unrestricted, regardless of whether it is legally a derivative
work. (Executables containing this object code plus portions of the
Library will still fall under Section 6.)
Otherwise, if the work is a derivative of the Library, you may
distribute the object code for the work under the terms of Section 6.
Any executables containing that work also fall under Section 6,
whether or not they are linked directly with the Library itself.
6. As an exception to the Sections above, you may also combine or
link a "work that uses the Library" with the Library to produce a
work containing portions of the Library, and distribute that work
under terms of your choice, provided that the terms permit
modification of the work for the customer's own use and reverse
engineering for debugging such modifications.
You must give prominent notice with each copy of the work that the
Library is used in it and that the Library and its use are covered by
this License. You must supply a copy of this License. If the work
during execution displays copyright notices, you must include the
copyright notice for the Library among them, as well as a reference
directing the user to the copy of this License. Also, you must do one
of these things:
a) Accompany the work with the complete corresponding
machine-readable source code for the Library including whatever
changes were used in the work (which must be distributed under
Sections 1 and 2 above); and, if the work is an executable linked
with the Library, with the complete machine-readable "work that
uses the Library", as object code and/or source code, so that the
user can modify the Library and then relink to produce a modified
executable containing the modified Library. (It is understood
that the user who changes the contents of definitions files in the
Library will not necessarily be able to recompile the application
to use the modified definitions.)
b) Use a suitable shared library mechanism for linking with the
Library. A suitable mechanism is one that (1) uses at run time a
copy of the library already present on the user's computer system,
rather than copying library functions into the executable, and (2)
will operate properly with a modified version of the library, if
the user installs one, as long as the modified version is
interface-compatible with the version that the work was made with.
c) Accompany the work with a written offer, valid for at
least three years, to give the same user the materials
specified in Subsection 6a, above, for a charge no more
than the cost of performing this distribution.
d) If distribution of the work is made by offering access to copy
from a designated place, offer equivalent access to copy the above
specified materials from the same place.
e) Verify that the user has already received a copy of these
materials or that you have already sent this user a copy.
For an executable, the required form of the "work that uses the
Library" must include any data and utility programs needed for
reproducing the executable from it. However, as a special exception,
the materials to be distributed need not include anything that is
normally distributed (in either source or binary form) with the major
components (compiler, kernel, and so on) of the operating system on
which the executable runs, unless that component itself accompanies
the executable.
It may happen that this requirement contradicts the license
restrictions of other proprietary libraries that do not normally
accompany the operating system. Such a contradiction means you cannot
use both them and the Library together in an executable that you
distribute.
7. You may place library facilities that are a work based on the
Library side-by-side in a single library together with other library
facilities not covered by this License, and distribute such a combined
library, provided that the separate distribution of the work based on
the Library and of the other library facilities is otherwise
permitted, and provided that you do these two things:
a) Accompany the combined library with a copy of the same work
based on the Library, uncombined with any other library
facilities. This must be distributed under the terms of the
Sections above.
b) Give prominent notice with the combined library of the fact
that part of it is a work based on the Library, and explaining
where to find the accompanying uncombined form of the same work.
8. You may not copy, modify, sublicense, link with, or distribute
the Library except as expressly provided under this License. Any
attempt otherwise to copy, modify, sublicense, link with, or
distribute the Library is void, and will automatically terminate your
rights under this License. However, parties who have received copies,
or rights, from you under this License will not have their licenses
terminated so long as such parties remain in full compliance.
9. You are not required to accept this License, since you have not
signed it. However, nothing else grants you permission to modify or
distribute the Library or its derivative works. These actions are
prohibited by law if you do not accept this License. Therefore, by
modifying or distributing the Library (or any work based on the
Library), you indicate your acceptance of this License to do so, and
all its terms and conditions for copying, distributing or modifying
the Library or works based on it.
10. Each time you redistribute the Library (or any work based on the
Library), the recipient automatically receives a license from the
original licensor to copy, distribute, link with or modify the Library
subject to these terms and conditions. You may not impose any further
restrictions on the recipients' exercise of the rights granted herein.
You are not responsible for enforcing compliance by third parties with
this License.
11. If, as a consequence of a court judgment or allegation of patent
infringement or for any other reason (not limited to patent issues),
conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot
distribute so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you
may not distribute the Library at all. For example, if a patent
license would not permit royalty-free redistribution of the Library by
all those who receive copies directly or indirectly through you, then
the only way you could satisfy both it and this License would be to
refrain entirely from distribution of the Library.
If any portion of this section is held invalid or unenforceable under any
particular circumstance, the balance of the section is intended to apply,
and the section as a whole is intended to apply in other circumstances.
It is not the purpose of this section to induce you to infringe any
patents or other property right claims or to contest validity of any
such claims; this section has the sole purpose of protecting the
integrity of the free software distribution system which is
implemented by public license practices. Many people have made
generous contributions to the wide range of software distributed
through that system in reliance on consistent application of that
system; it is up to the author/donor to decide if he or she is willing
to distribute software through any other system and a licensee cannot
impose that choice.
This section is intended to make thoroughly clear what is believed to
be a consequence of the rest of this License.
12. If the distribution and/or use of the Library is restricted in
certain countries either by patents or by copyrighted interfaces, the
original copyright holder who places the Library under this License may add
an explicit geographical distribution limitation excluding those countries,
so that distribution is permitted only in or among countries not thus
excluded. In such case, this License incorporates the limitation as if
written in the body of this License.
13. The Free Software Foundation may publish revised and/or new
versions of the Lesser General Public License from time to time.
Such new versions will be similar in spirit to the present version,
but may differ in detail to address new problems or concerns.
Each version is given a distinguishing version number. If the Library
specifies a version number of this License which applies to it and
"any later version", you have the option of following the terms and
conditions either of that version or of any later version published by
the Free Software Foundation. If the Library does not specify a
license version number, you may choose any version ever published by
the Free Software Foundation.
14. If you wish to incorporate parts of the Library into other free
programs whose distribution conditions are incompatible with these,
write to the author to ask for permission. For software which is
copyrighted by the Free Software Foundation, write to the Free
Software Foundation; we sometimes make exceptions for this. Our
decision will be guided by the two goals of preserving the free status
of all derivatives of our free software and of promoting the sharing
and reuse of software generally.
NO WARRANTY
15. BECAUSE THE LIBRARY IS LICENSED FREE OF CHARGE, THERE IS NO
WARRANTY FOR THE LIBRARY, TO THE EXTENT PERMITTED BY APPLICABLE LAW.
EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR
OTHER PARTIES PROVIDE THE LIBRARY "AS IS" WITHOUT WARRANTY OF ANY
KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE
LIBRARY IS WITH YOU. SHOULD THE LIBRARY PROVE DEFECTIVE, YOU ASSUME
THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN
WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY
AND/OR REDISTRIBUTE THE LIBRARY AS PERMITTED ABOVE, BE LIABLE TO YOU
FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR
CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE
LIBRARY (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING
RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A
FAILURE OF THE LIBRARY TO OPERATE WITH ANY OTHER SOFTWARE), EVEN IF
SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
DAMAGES.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Libraries
If you develop a new library, and you want it to be of the greatest
possible use to the public, we recommend making it free software that
everyone can redistribute and change. You can do so by permitting
redistribution under these terms (or, alternatively, under the terms of the
ordinary General Public License).
To apply these terms, attach the following notices to the library. It is
safest to attach them to the start of each source file to most effectively
convey the exclusion of warranty; and each file should have at least the
"copyright" line and a pointer to where the full notice is found.
<one line to give the library's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
Also add information on how to contact you by electronic and paper mail.
You should also get your employer (if you work as a programmer) or your
school, if any, to sign a "copyright disclaimer" for the library, if
necessary. Here is a sample; alter the names:
Yoyodyne, Inc., hereby disclaims all copyright interest in the
library `Frob' (a library for tweaking knobs) written by James Random Hacker.
<signature of Ty Coon>, 1 April 1990
Ty Coon, President of Vice
That's all there is to it!

View file

@ -1,144 +0,0 @@
""" a clumsy attempt at a macro language to let the programmer execute code on the server (ex: determine 64bit)"""
import os  # needed by the "getenv" macro below; was missing at module level
from . import is64bit as is64bit
def macro_call(macro_name, args, kwargs):
"""allow the programmer to perform limited processing on the server by passing macro names and args
:new_key - the key name the macro will create
:args[0] - macro name
:args[1:] - any arguments
:code - the value of the keyword item
:kwargs - the connection keyword dictionary (the macro key itself has already been removed)
--> returns (new_key, value) so the caller can set kwargs[new_key] = value
"""
if isinstance(args, str):
args = [
args
] # the user forgot to pass a sequence, so make a string into args[0]
new_key = args[0]
try:
if macro_name == "is64bit":
if is64bit.Python(): # if on 64 bit Python
return new_key, args[1] # return first argument
else:
try:
return new_key, args[2] # else return second argument (if defined)
except IndexError:
return new_key, "" # else return blank
elif (
macro_name == "getuser"
): # get the name of the user the server is logged in under
if new_key not in kwargs:
import getpass
return new_key, getpass.getuser()
elif macro_name == "getnode": # get the name of the computer running the server
import platform
try:
return new_key, args[1] % platform.node()
except IndexError:
return new_key, platform.node()
elif macro_name == "getenv": # expand the server's environment variable args[1]
try:
dflt = args[2] # if not found, default from args[2]
except IndexError: # or blank
dflt = ""
return new_key, os.environ.get(args[1], dflt)
elif macro_name == "auto_security":
if (
not "user" in kwargs or not kwargs["user"]
): # missing, blank, or Null username
return new_key, "Integrated Security=SSPI"
return new_key, "User ID=%(user)s; Password=%(password)s" % kwargs
elif (
macro_name == "find_temp_test_path"
): # helper function for testing ado operation -- undocumented
import os
import tempfile
return new_key, os.path.join(
tempfile.gettempdir(), "adodbapi_test", args[1]
)
raise ValueError("Unknown connect string macro=%s" % macro_name)
except:
raise ValueError("Error in macro processing %s %s" % (macro_name, repr(args)))
def process(
args, kwargs, expand_macros=False
): # --> connection string with keyword arguments processed.
"""attempts to inject arguments into a connection string using Python "%" operator for strings
co: adodbapi connection object
args: positional parameters from the .connect() call
kvargs: keyword arguments from the .connect() call
"""
try:
dsn = args[0]
except IndexError:
dsn = None
if isinstance(
dsn, dict
): # as a convenience the first argument may be django settings
kwargs.update(dsn)
elif (
dsn
): # the connection string is passed to the connection as part of the keyword dictionary
kwargs["connection_string"] = dsn
try:
a1 = args[1]
except IndexError:
a1 = None
# historically, the second positional argument might be a timeout value
if isinstance(a1, int):
kwargs["timeout"] = a1
# if the second positional argument is a string, then it is user
elif isinstance(a1, str):
kwargs["user"] = a1
# if the second positional argument is a dictionary, use it as keyword arguments, too
elif isinstance(a1, dict):
kwargs.update(a1)
try:
kwargs["password"] = args[2] # the third positional argument is password
kwargs["host"] = args[3] # the fourth positional argument is host name
kwargs["database"] = args[4] # the fifth positional argument is database name
except IndexError:
pass
# make sure connection string is defined somehow
if not "connection_string" in kwargs:
try: # perhaps 'dsn' was defined
kwargs["connection_string"] = kwargs["dsn"]
except KeyError:
try: # as a last effort, use the "host" keyword
kwargs["connection_string"] = kwargs["host"]
except KeyError:
raise TypeError("Must define 'connection_string' for ado connections")
if expand_macros:
for kwarg in list(kwargs.keys()):
if kwarg.startswith("macro_"): # If a key defines a macro
macro_name = kwarg[6:] # name without the "macro_"
macro_code = kwargs.pop(
kwarg
) # we remove the macro_key and get the code to execute
new_key, rslt = macro_call(
macro_name, macro_code, kwargs
) # run the code in the local context
kwargs[new_key] = rslt # put the result back in the keywords dict
# special processing for PyRO IPv6 host address
try:
s = kwargs["proxy_host"]
if ":" in s: # it is an IPv6 address
if s[0] != "[": # is not surrounded by brackets
kwargs["proxy_host"] = s.join(("[", "]")) # put it in brackets
except KeyError:
pass
return kwargs
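# --- Illustrative sketch only (not part of the original module) ---
# A hypothetical call showing how process() handles positional arguments and a
# "macro_" keyword; all names and values below are made up for illustration.
#
#   kwargs = process(
#       ("Provider=%(provider)s;Data Source=%(mdb)s", 30),  # dsn string + timeout
#       {"mdb": "test.mdb",
#        "macro_is64bit": ["provider",
#                          "Microsoft.ACE.OLEDB.12.0",   # chosen on 64-bit Python
#                          "Microsoft.Jet.OLEDB.4.0"]},  # chosen on 32-bit Python
#       expand_macros=True,
#   )
#   # afterwards: kwargs["connection_string"] holds the dsn template,
#   # kwargs["timeout"] == 30, and kwargs["provider"] was filled in by the macro.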

View file

@ -1,92 +0,0 @@
Project
-------
adodbapi
A Python DB-API 2.0 (PEP-249) module that makes it easy to use Microsoft ADO
for connecting with databases and other data sources
using either CPython or IronPython.
Home page: <http://sourceforge.net/projects/adodbapi>
Features:
* 100% DB-API 2.0 (PEP-249) compliant (including most extensions and recommendations).
* Includes pyunit testcases that describe how to use the module.
* Fully implemented in Python -- runs in Python 2.5+, Python 3.0+, and IronPython 2.6+
* Licensed under the LGPL license, which means that it can be used freely even in commercial programs subject to certain restrictions.
* The user can choose between paramstyles: 'qmark', 'named', 'format', 'pyformat', 'dynamic'
* Supports data retrieval by column name e.g.:
for row in myCursor.execute("select name,age from students"):
print("Student", row.name, "is", row.age, "years old.")
* Supports user-definable system-to-Python data conversion functions (selected by ADO data type, or by column)
Prerequisites:
* CPython 2.7, or 3.5 or higher,
and pywin32 (Mark Hammond's Python for Windows extensions)
or
IronPython 2.7 or higher (works in IPy2.0 for all data types except BUFFER)
Installation:
* (C-Python on Windows): Install pywin32 ("pip install pywin32") which includes adodbapi.
* (IronPython on Windows): Download adodbapi from http://sf.net/projects/adodbapi. Unpack the zip.
Open a command window as an administrator. CD to the folder containing the unzipped files.
Run "setup.py install" using the IronPython of your choice.
NOTE: ...........
If you do not like the new default operation of returning Numeric columns as decimal.Decimal,
you can select other options using the user-defined conversion feature.
Try:
adodbapi.apibase.variantConversions[adodbapi.ado_consts.adNumeric] = adodbapi.apibase.cvtString
or:
adodbapi.apibase.variantConversions[adodbapi.ado_consts.adNumeric] = adodbapi.apibase.cvtFloat
or:
adodbapi.apibase.variantConversions[adodbapi.ado_consts.adNumeric] = write_your_own_conversion_function
............
notes for 2.6.2:
The definitive source has been moved to https://github.com/mhammond/pywin32/tree/master/adodbapi.
Remote has proven too hard to configure and test with Pyro4. I am moving it to unsupported status
until I can change to a different connection method.
What's new in version 2.6
A cursor.prepare() method and support for prepared SQL statements.
Lots of refactoring, especially of the Remote and Server modules (still to be treated as Beta code).
The quick start document 'quick_reference.odt' will export as a nice-looking pdf.
Added paramstyles 'pyformat' and 'dynamic'. If your 'paramstyle' is 'named', you _must_ pass a dictionary of
parameters to your .execute() method. If your 'paramstyle' is 'format', 'pyformat', or 'dynamic', you _may_
pass a dictionary of parameters -- provided your SQL operation string is formatted correctly.
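For illustration only (not part of the original notes), a 'named' paramstyle call could look like
this, assuming an already-open connection 'conn' and the 'students' table used in the example above:
    crsr = conn.cursor()
    crsr.execute("select name, age from students where age > :minage", {"minage": 21})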
What's new in version 2.5
Remote module: (works on Linux!) allows a Windows computer to serve ADO databases via PyRO
Server module: PyRO server for ADO. Run using a command like: C:\> python -m adodbapi.server
(server has simple connection string macros: is64bit, getuser, sql_provider, auto_security)
Brief documentation included. See adodbapi/examples folder adodbapi.rtf
New connection method conn.get_table_names() --> list of names of tables in database
Vastly refactored. Data conversion things have been moved to the new adodbapi.apibase module.
Many former module-level attributes are now class attributes. (Should be more thread-safe)
Connection objects are now context managers for transactions and will commit or rollback.
Cursor objects are context managers and will automatically close themselves. (A short usage sketch follows this list.)
Autocommit can be switched on and off.
Keyword and positional arguments on the connect() method work as documented in PEP 249.
Keyword arguments from the connect call can be formatted into the connection string.
New keyword arguments defined, such as: autocommit, paramstyle, remote_proxy, remote_port.
*** Breaking change: variantConversion lookups are simplified: the following will raise KeyError:
oldconverter=adodbapi.variantConversions[adodbapi.adoStringTypes]
Refactor as: oldconverter=adodbapi.variantConversions[adodbapi.adoStringTypes[0]]
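As a rough illustration of the context-manager and get_table_names() features noted above
(a sketch only -- 'connection_string' is a placeholder, not a working connection string):
    import adodbapi
    with adodbapi.connect(connection_string) as conn:  # commits on success, rolls back on error
        print(conn.get_table_names())
        with conn.cursor() as crsr:  # the cursor closes itself on exit
            crsr.execute("select name from students")
            for row in crsr:
                print(row.name)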
License
-------
LGPL, see http://www.opensource.org/licenses/lgpl-license.php
Documentation
-------------
Look at adodbapi/quick_reference.md
http://www.python.org/topics/database/DatabaseAPI-2.0.html
read the examples in adodbapi/examples
and look at the test cases in adodbapi/test directory.
Mailing lists
-------------
The adodbapi mailing lists have been deactivated. Submit comments to the
pywin32 or IronPython mailing lists.
-- the bug tracker on sourceforge.net/projects/adodbapi may be checked (infrequently).
-- please use: https://github.com/mhammond/pywin32/issues

View file

@ -1,634 +0,0 @@
"""adodbapi.remote - A python DB API 2.0 (PEP 249) interface to Microsoft ADO
Copyright (C) 2002 Henrik Ekelund, version 2.1 by Vernon Cole
* http://sourceforge.net/projects/pywin32
* http://sourceforge.net/projects/adodbapi
This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.
This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
django adaptations and refactoring thanks to Adam Vandenberg
DB-API 2.0 specification: http://www.python.org/dev/peps/pep-0249/
This module source should run correctly in CPython versions 2.5 and later,
or IronPython version 2.7 and later,
or, after running through 2to3.py, CPython 3.0 or later.
"""
__version__ = "2.6.0.4"
version = "adodbapi.remote v" + __version__
import array
import datetime
import os
import sys
import time
# Pyro4 is required for server and remote operation --> https://pypi.python.org/pypi/Pyro4/
try:
import Pyro4
except ImportError:
print('* * * Sorry, server operation requires Pyro4. Please "pip install" it.')
exit(11)
import adodbapi
import adodbapi.apibase as api
import adodbapi.process_connect_string
from adodbapi.apibase import ProgrammingError
_BaseException = api._BaseException
sys.excepthook = Pyro4.util.excepthook
Pyro4.config.PREFER_IP_VERSION = 0 # allow system to prefer IPv6
Pyro4.config.COMMTIMEOUT = 40.0 # a bit longer than the default SQL server timeout
Pyro4.config.SERIALIZER = "pickle"
try:
verbose = int(os.environ["ADODBAPI_VERBOSE"])
except:
verbose = False
if verbose:
print(version)
# --- define objects to smooth out Python3 <-> Python 2.x differences
unicodeType = str # Python 2/3 compatibility alias (already the Python 3 form here)
longType = int # Python 2/3 compatibility alias (already the Python 3 form here)
StringTypes = str
makeByteBuffer = bytes
memoryViewType = memoryview
# -----------------------------------------------------------
# conversion functions mandated by PEP 249
Binary = makeByteBuffer # override the function from apibase.py
def Date(year, month, day):
return datetime.date(year, month, day) # dateconverter.Date(year,month,day)
def Time(hour, minute, second):
return datetime.time(hour, minute, second) # dateconverter.Time(hour,minute,second)
def Timestamp(year, month, day, hour, minute, second):
return datetime.datetime(year, month, day, hour, minute, second)
def DateFromTicks(ticks):
return Date(*time.gmtime(ticks)[:3])
def TimeFromTicks(ticks):
return Time(*time.gmtime(ticks)[3:6])
def TimestampFromTicks(ticks):
return Timestamp(*time.gmtime(ticks)[:6])
def connect(*args, **kwargs): # --> a remote db-api connection object
"""Create and open a remote db-api database connection object"""
# process the argument list the programmer gave us
kwargs = adodbapi.process_connect_string.process(args, kwargs)
# the "proxy_xxx" keys tell us where to find the PyRO proxy server
kwargs.setdefault(
"pyro_connection", "PYRO:ado.connection@%(proxy_host)s:%(proxy_port)s"
)
if not "proxy_port" in kwargs:
try:
pport = os.environ["PROXY_PORT"]
except KeyError:
pport = 9099
kwargs["proxy_port"] = pport
if not "proxy_host" in kwargs or not kwargs["proxy_host"]:
try:
phost = os.environ["PROXY_HOST"]
except KeyError:
phost = "[::1]" # '127.0.0.1'
kwargs["proxy_host"] = phost
ado_uri = kwargs["pyro_connection"] % kwargs
# ask PyRO to make us a remote connection object
auto_retry = 3
while auto_retry:
try:
dispatcher = Pyro4.Proxy(ado_uri)
if "comm_timeout" in kwargs:
dispatcher._pyroTimeout = float(kwargs["comm_timeout"])
uri = dispatcher.make_connection()
break
except Pyro4.core.errors.PyroError:
auto_retry -= 1
if auto_retry:
time.sleep(1)
else:
raise api.DatabaseError("Cannot create connection to=%s" % ado_uri)
conn_uri = fix_uri(uri, kwargs) # get a host connection from the proxy server
while auto_retry:
try:
host_conn = Pyro4.Proxy(
conn_uri
) # bring up an exclusive Pyro connection for my ADO connection
break
except Pyro4.core.errors.PyroError:
auto_retry -= 1
if auto_retry:
time.sleep(1)
else:
raise api.DatabaseError(
"Cannot create ADO connection object using=%s" % conn_uri
)
if "comm_timeout" in kwargs:
host_conn._pyroTimeout = float(kwargs["comm_timeout"])
# make a local clone
myConn = Connection()
while auto_retry:
try:
myConn.connect(
kwargs, host_conn
) # call my connect method -- hand him the host connection
break
except Pyro4.core.errors.PyroError:
auto_retry -= 1
if auto_retry:
time.sleep(1)
else:
raise api.DatabaseError(
"Pyro error creating connection to/thru=%s" % repr(kwargs)
)
except _BaseException as e:
raise api.DatabaseError(
"Error creating remote connection to=%s, e=%s, %s"
% (repr(kwargs), repr(e), sys.exc_info()[2])
)
return myConn
def fix_uri(uri, kwargs):
"""convert a generic pyro uri with '0.0.0.0' into the address we actually called"""
u = uri.asString()
s = u.split("[::0]") # IPv6 generic address
if len(s) == 1: # did not find one
s = u.split("0.0.0.0") # IPv4 generic address
if len(s) > 1: # found a generic
return kwargs["proxy_host"].join(s) # fill in our address for the host
return uri
# # # # # ----- the Class that defines a connection ----- # # # # #
class Connection(object):
# include connection attributes required by api definition.
Warning = api.Warning
Error = api.Error
InterfaceError = api.InterfaceError
DataError = api.DataError
DatabaseError = api.DatabaseError
OperationalError = api.OperationalError
IntegrityError = api.IntegrityError
InternalError = api.InternalError
NotSupportedError = api.NotSupportedError
ProgrammingError = api.ProgrammingError
# set up some class attributes
paramstyle = api.paramstyle
@property
def dbapi(self): # a proposed db-api version 3 extension.
"Return a reference to the DBAPI module for this Connection."
return api
def __init__(self):
self.proxy = None
self.kwargs = {}
self.errorhandler = None
self.supportsTransactions = False
self.paramstyle = api.paramstyle
self.timeout = 30
self.cursors = {}
def connect(self, kwargs, connection_maker):
self.kwargs = kwargs
if verbose:
print('%s attempting: "%s"' % (version, repr(kwargs)))
self.proxy = connection_maker
##try:
ret = self.proxy.connect(kwargs) # ask the server to hook us up
##except ImportError, e: # Pyro is trying to import pywinTypes.comerrer
## self._raiseConnectionError(api.DatabaseError, 'Proxy cannot connect using=%s' % repr(kwargs))
if ret is not True:
self._raiseConnectionError(
api.OperationalError, "Proxy returns error message=%s" % repr(ret)
)
self.supportsTransactions = self.getIndexedValue("supportsTransactions")
self.paramstyle = self.getIndexedValue("paramstyle")
self.timeout = self.getIndexedValue("timeout")
if verbose:
print("adodbapi.remote New connection at %X" % id(self))
def _raiseConnectionError(self, errorclass, errorvalue):
eh = self.errorhandler
if eh is None:
eh = api.standardErrorHandler
eh(self, None, errorclass, errorvalue)
def close(self):
"""Close the connection now (rather than whenever __del__ is called).
The connection will be unusable from this point forward;
an Error (or subclass) exception will be raised if any operation is attempted with the connection.
The same applies to all cursor objects trying to use the connection.
"""
for crsr in list(self.cursors.values())[
:
]: # copy the list, then close each one
crsr.close()
try:
"""close the underlying remote Connection object"""
self.proxy.close()
if verbose:
print("adodbapi.remote Closed connection at %X" % id(self))
object.__delattr__(
self, "proxy"
) # future attempts to use closed cursor will be caught by __getattr__
except Exception:
pass
def __del__(self):
try:
self.proxy.close()
except:
pass
def commit(self):
"""Commit any pending transaction to the database.
Note that if the database supports an auto-commit feature,
this must be initially off. An interface method may be provided to turn it back on.
Database modules that do not support transactions should implement this method with void functionality.
"""
if not self.supportsTransactions:
return
result = self.proxy.commit()
if result:
self._raiseConnectionError(
api.OperationalError, "Error during commit: %s" % result
)
def _rollback(self):
"""In case a database does provide transactions this method causes the the database to roll back to
the start of any pending transaction. Closing a connection without committing the changes first will
cause an implicit rollback to be performed.
"""
result = self.proxy.rollback()
if result:
self._raiseConnectionError(
api.OperationalError, "Error during rollback: %s" % result
)
def __setattr__(self, name, value):
if name in ("paramstyle", "timeout", "autocommit"):
if self.proxy:
self.proxy.send_attribute_to_host(name, value)
object.__setattr__(self, name, value) # store attribute locally (too)
def __getattr__(self, item):
if (
item == "rollback"
): # the rollback method only appears if the database supports transactions
if self.supportsTransactions:
return (
self._rollback
) # return the rollback method so the caller can execute it.
else:
raise self.ProgrammingError(
"this data provider does not support Rollback"
)
elif item in (
"dbms_name",
"dbms_version",
"connection_string",
"autocommit",
): # 'messages' ):
return self.getIndexedValue(item)
elif item == "proxy":
raise self.ProgrammingError("Attempting to use closed connection")
else:
raise self.ProgrammingError('No remote access for attribute="%s"' % item)
def getIndexedValue(self, index):
r = self.proxy.get_attribute_for_remote(index)
return r
def cursor(self):
"Return a new Cursor Object using the connection."
myCursor = Cursor(self)
return myCursor
def _i_am_here(self, crsr):
"message from a new cursor proclaiming its existence"
self.cursors[crsr.id] = crsr
def _i_am_closing(self, crsr):
"message from a cursor giving connection a chance to clean up"
try:
del self.cursors[crsr.id]
except:
pass
def __enter__(self): # Connections are context managers
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type:
self._rollback() # automatic rollback on errors
else:
self.commit()
def get_table_names(self):
return self.proxy.get_table_names()
def fixpickle(x):
"""pickle barfs on buffer(x) so we pass as array.array(x) then restore to original form for .execute()"""
if x is None:
return None
if isinstance(x, dict):
# for 'named' paramstyle user will pass a mapping
newargs = {}
for arg, val in list(x.items()):
if isinstance(val, memoryViewType):
newval = array.array("B")
newval.frombytes(val)  # array.fromstring() was removed in Python 3.9
newargs[arg] = newval
else:
newargs[arg] = val
return newargs
# if not a mapping, then a sequence
newargs = []
for arg in x:
if isinstance(arg, memoryViewType):
newarg = array.array("B")
newarg.frombytes(arg)  # array.fromstring() was removed in Python 3.9
newargs.append(newarg)
else:
newargs.append(arg)
return newargs
class Cursor(object):
def __init__(self, connection):
self.command = None
self.errorhandler = None ## was: connection.errorhandler
self.connection = connection
self.proxy = self.connection.proxy
self.rs = None # the fetchable data for this cursor
self.converters = NotImplemented
self.id = connection.proxy.build_cursor()
connection._i_am_here(self)
self.recordset_format = api.RS_REMOTE
if verbose:
print(
"%s New cursor at %X on conn %X"
% (version, id(self), id(self.connection))
)
def prepare(self, operation):
self.command = operation
try:
del self.description
except AttributeError:
pass
self.proxy.crsr_prepare(self.id, operation)
def __iter__(self): # [2.1 Zamarev]
return iter(self.fetchone, None) # [2.1 Zamarev]
def __next__(self):
r = self.fetchone()
if r:
return r
raise StopIteration
def __enter__(self):
"Allow database cursors to be used with context managers."
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"Allow database cursors to be used with context managers."
self.close()
def __getattr__(self, key):
if key == "numberOfColumns":
try:
return len(self.rs[0])
except:
return 0
if key == "description":
try:
self.description = self.proxy.crsr_get_description(self.id)[:]
return self.description
except TypeError:
return None
if key == "columnNames":
try:
r = dict(
self.proxy.crsr_get_columnNames(self.id)
) # copy the remote columns
except TypeError:
r = {}
self.columnNames = r
return r
if key == "remote_cursor":
raise api.OperationalError
try:
return self.proxy.crsr_get_attribute_for_remote(self.id, key)
except AttributeError:
raise api.InternalError(
'Failure getting attribute "%s" from proxy cursor.' % key
)
def __setattr__(self, key, value):
if key == "arraysize":
self.proxy.crsr_set_arraysize(self.id, value)
if key == "paramstyle":
if value in api.accepted_paramstyles:
self.proxy.crsr_set_paramstyle(self.id, value)
else:
self._raiseCursorError(
api.ProgrammingError, 'invalid paramstyle ="%s"' % value
)
object.__setattr__(self, key, value)
def _raiseCursorError(self, errorclass, errorvalue):
eh = self.errorhandler
if eh is None:
eh = api.standardErrorHandler
eh(self.connection, self, errorclass, errorvalue)
def execute(self, operation, parameters=None):
if self.connection is None:
self._raiseCursorError(
ProgrammingError, "Attempted operation on closed cursor"
)
self.command = operation
try:
del self.description
except AttributeError:
pass
try:
del self.columnNames
except AttributeError:
pass
fp = fixpickle(parameters)
if verbose > 2:
print(
(
'%s executing "%s" with params=%s'
% (version, operation, repr(parameters))
)
)
result = self.proxy.crsr_execute(self.id, operation, fp)
if result: # an exception was triggered
self._raiseCursorError(result[0], result[1])
def executemany(self, operation, seq_of_parameters):
if self.connection is None:
self._raiseCursorError(
ProgrammingError, "Attempted operation on closed cursor"
)
self.command = operation
try:
del self.description
except AttributeError:
pass
try:
del self.columnNames
except AttributeError:
pass
sq = [fixpickle(x) for x in seq_of_parameters]
if verbose > 2:
print(
(
'%s executemany "%s" with params=%s'
% (version, operation, repr(seq_of_parameters))
)
)
self.proxy.crsr_executemany(self.id, operation, sq)
def nextset(self):
try:
del self.description
except AttributeError:
pass
try:
del self.columnNames
except AttributeError:
pass
if verbose > 2:
print(("%s nextset" % version))
return self.proxy.crsr_nextset(self.id)
def callproc(self, procname, parameters=None):
if self.connection is None:
self._raiseCursorError(
ProgrammingError, "Attempted operation on closed cursor"
)
self.command = procname
try:
del self.description
except AttributeError:
pass
try:
del self.columnNames
except AttributeError:
pass
fp = fixpickle(parameters)
if verbose > 2:
print(
(
'%s callproc "%s" with params=%s'
% (version, procname, repr(parameters))
)
)
return self.proxy.crsr_callproc(self.id, procname, fp)
def fetchone(self):
try:
f1 = self.proxy.crsr_fetchone(self.id)
except _BaseException as e:
self._raiseCursorError(api.DatabaseError, e)
else:
if f1 is None:
return None
self.rs = [f1]
return api.SQLrows(self.rs, 1, self)[
0
] # new object to hold the results of the fetch
def fetchmany(self, size=None):
try:
self.rs = self.proxy.crsr_fetchmany(self.id, size)
if not self.rs:
return []
r = api.SQLrows(self.rs, len(self.rs), self)
return r
except Exception as e:
self._raiseCursorError(api.DatabaseError, e)
def fetchall(self):
try:
self.rs = self.proxy.crsr_fetchall(self.id)
if not self.rs:
return []
return api.SQLrows(self.rs, len(self.rs), self)
except Exception as e:
self._raiseCursorError(api.DatabaseError, e)
def close(self):
if self.connection is None:
return
self.connection._i_am_closing(self) # take me off the connection's cursors list
try:
self.proxy.crsr_close(self.id)
except:
pass
try:
del self.description
except:
pass
try:
del self.rs # let go of the recordset
except:
pass
self.connection = (
None # this will make all future method calls on me throw an exception
)
self.proxy = None
if verbose:
print("adodbapi.remote Closed cursor at %X" % id(self))
def __del__(self):
try:
self.close()
except:
pass
def setinputsizes(self, sizes):
pass
def setoutputsize(self, size, column=None):
pass

View file

@ -1,15 +0,0 @@
"""call using an open ADO connection --> list of table names"""
from . import adodbapi
def names(connection_object):
ado = connection_object.adoConn
schema = ado.OpenSchema(20) # constant = adSchemaTables
tables = []
while not schema.EOF:
name = adodbapi.getIndexedValue(schema.Fields, "TABLE_NAME").Value
tables.append(name)
schema.MoveNext()
del schema
return tables
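# Illustrative note (not part of the original file): given an open adodbapi
# connection object `conn`, names(conn) walks the ADO OpenSchema(adSchemaTables)
# recordset and returns a plain Python list of table-name strings, for example
# ['customers', 'orders'] -- both names are made up for illustration.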

View file

@ -1,70 +0,0 @@
"""adodbapi -- a pure Python PEP 249 DB-API package using Microsoft ADO
Adodbapi can be run on CPython 3.5 and later,
or IronPython version 2.6 and later (in theory, possibly no longer in practice!)
"""
CLASSIFIERS = """\
Development Status :: 5 - Production/Stable
Intended Audience :: Developers
License :: OSI Approved :: GNU Library or Lesser General Public License (LGPL)
Operating System :: Microsoft :: Windows
Operating System :: POSIX :: Linux
Programming Language :: Python
Programming Language :: Python :: 3
Programming Language :: SQL
Topic :: Software Development
Topic :: Software Development :: Libraries :: Python Modules
Topic :: Database
"""
NAME = "adodbapi"
MAINTAINER = "Vernon Cole"
MAINTAINER_EMAIL = "vernondcole@gmail.com"
DESCRIPTION = (
"""A pure Python package implementing PEP 249 DB-API using Microsoft ADO."""
)
URL = "http://sourceforge.net/projects/adodbapi"
LICENSE = "LGPL"
CLASSIFIERS = list(filter(None, CLASSIFIERS.split("\n")))  # materialize the filter for setup()
AUTHOR = "Henrik Ekelund, Vernon Cole, et.al."
AUTHOR_EMAIL = "vernondcole@gmail.com"
PLATFORMS = ["Windows", "Linux"]
VERSION = None # in case searching for version fails
a = open("adodbapi.py") # find the version string in the source code
for line in a:
if "__version__" in line:
VERSION = line.split("'")[1]
print('adodbapi version="%s"' % VERSION)
break
a.close()
def setup_package():
from distutils.command.build_py import build_py
from distutils.core import setup
setup(
cmdclass={"build_py": build_py},
name=NAME,
maintainer=MAINTAINER,
maintainer_email=MAINTAINER_EMAIL,
description=DESCRIPTION,
url=URL,
keywords="database ado odbc dbapi db-api Microsoft SQL",
## download_url=DOWNLOAD_URL,
long_description=open("README.txt").read(),
license=LICENSE,
classifiers=CLASSIFIERS,
author=AUTHOR,
author_email=AUTHOR_EMAIL,
platforms=PLATFORMS,
version=VERSION,
package_dir={"adodbapi": ""},
packages=["adodbapi"],
)
return
if __name__ == "__main__":
setup_package()

File diff suppressed because it is too large Load diff

View file

@ -1,221 +0,0 @@
# Configure this to _YOUR_ environment in order to run the testcases.
"testADOdbapiConfig.py v 2.6.2.B00"
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# #
# # TESTERS:
# #
# # You will need to make numerous modifications to this file
# # to adapt it to your own testing environment.
# #
# # Skip down to the next "# #" line --
# # -- the things you need to change are below it.
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
import platform
import random
import sys
import is64bit
import setuptestframework
import tryconnection
print("\nPython", sys.version)
node = platform.node()
try:
print(
"node=%s, is64bit.os()= %s, is64bit.Python()= %s"
% (node, is64bit.os(), is64bit.Python())
)
except:
pass
if "--help" in sys.argv:
print(
"""Valid command-line switches are:
--package - create a temporary test package, run 2to3 if needed.
--all - run all possible tests
--time - loop over time format tests (including mxdatetime if present)
--nojet - do not test against an ACCESS database file
--mssql - test against Microsoft SQL server
--pg - test against PostgreSQL
--mysql - test against MariaDB
--remote= - test using remote server at= (experimental)
"""
)
exit()
try:
onWindows = bool(sys.getwindowsversion()) # seems to work on all versions of Python
except:
onWindows = False
# create a random name for temporary table names
_alphabet = (
"PYFGCRLAOEUIDHTNSQJKXBMWVZ" # why, yes, I do happen to use a dvorak keyboard
)
tmp = "".join([random.choice(_alphabet) for x in range(9)])
mdb_name = "xx_" + tmp + ".mdb" # generate a non-colliding name for the temporary .mdb
testfolder = setuptestframework.maketemp()
if "--package" in sys.argv:
# create a new adodbapi module -- running 2to3 if needed.
pth = setuptestframework.makeadopackage(testfolder)
else:
# use the adodbapi module in which this file appears
pth = setuptestframework.find_ado_path()
if pth not in sys.path:
# look here _first_ to find modules
sys.path.insert(1, pth)
proxy_host = None
for arg in sys.argv:
if arg.startswith("--remote="):
proxy_host = arg.split("=")[1]
import adodbapi.remote as remote
break
# function to clean up the temporary folder -- calling program must run this function before exit.
cleanup = setuptestframework.getcleanupfunction()
try:
import adodbapi # will (hopefully) be imported using the "pth" discovered above
except SyntaxError:
print(
'\n* * * Are you trying to run Python2 code using Python3? Re-run this test using the "--package" switch.'
)
sys.exit(11)
try:
print(adodbapi.version) # show version
except:
print('"adodbapi.version" not present or not working.')
print(__doc__)
verbose = False
for a in sys.argv:
if a.startswith("--verbose"):
arg = True
try:
arg = int(a.split("=")[1])
except IndexError:
pass
adodbapi.adodbapi.verbose = arg
verbose = arg
doAllTests = "--all" in sys.argv
doAccessTest = not ("--nojet" in sys.argv)
doSqlServerTest = "--mssql" in sys.argv or doAllTests
doMySqlTest = "--mysql" in sys.argv or doAllTests
doPostgresTest = "--pg" in sys.argv or doAllTests
iterateOverTimeTests = ("--time" in sys.argv or doAllTests) and onWindows
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# # start your environment setup here v v v
SQL_HOST_NODE = "testsql.2txt.us,1430"
try: # If mx extensions are installed, use mxDateTime
import mx.DateTime
doMxDateTimeTest = True
except:
doMxDateTimeTest = False # Requires eGenixMXExtensions
doTimeTest = True # obsolete python time format
if doAccessTest:
if proxy_host: # determine the (probably remote) database file folder
c = {"macro_find_temp_test_path": ["mdb", mdb_name], "proxy_host": proxy_host}
else:
c = {"mdb": setuptestframework.makemdb(testfolder, mdb_name)}
# macro definition for keyword "provider" using macro "is64bit" -- see documentation
# is64bit will return true for 64 bit versions of Python, so the macro will select the ACE provider
# (If running a remote ADO service, this will test the 64-bitedness of the ADO server.)
c["macro_is64bit"] = [
"provider",
"Microsoft.ACE.OLEDB.12.0", # 64 bit provider
"Microsoft.Jet.OLEDB.4.0",
] # 32 bit provider
connStrAccess = "Provider=%(provider)s;Data Source=%(mdb)s" # ;Mode=ReadWrite;Persist Security Info=False;Jet OLEDB:Bypass UserInfo Validation=True"
print(
" ...Testing ACCESS connection to {} file...".format(
c.get("mdb", "remote .mdb")
)
)
doAccessTest, connStrAccess, dbAccessconnect = tryconnection.try_connection(
verbose, connStrAccess, 10, **c
)
if doSqlServerTest:
c = {
"host": SQL_HOST_NODE, # name of computer with SQL Server
"database": "adotest",
"user": "adotestuser", # None implies Windows security
"password": "Sq1234567",
# macro definition for keyword "security" using macro "auto_security"
"macro_auto_security": "security",
"provider": "MSOLEDBSQL; MARS Connection=True",
}
if proxy_host:
c["proxy_host"] = proxy_host
connStr = "Provider=%(provider)s; Initial Catalog=%(database)s; Data Source=%(host)s; %(security)s;"
print(" ...Testing MS-SQL login to {}...".format(c["host"]))
(
doSqlServerTest,
connStrSQLServer,
dbSqlServerconnect,
) = tryconnection.try_connection(verbose, connStr, 30, **c)
if doMySqlTest:
c = {
"host": "testmysql.2txt.us",
"database": "adodbapitest",
"user": "adotest",
"password": "12345678",
"port": "3330", # note the nonstandard port for obfuscation
"driver": "MySQL ODBC 5.1 Driver",
} # or _driver="MySQL ODBC 3.51 Driver"
if proxy_host:
c["proxy_host"] = proxy_host
c["macro_is64bit"] = [
"provider",
"Provider=MSDASQL;",
] # turn on the 64 bit ODBC adapter only if needed
cs = (
"%(provider)sDriver={%(driver)s};Server=%(host)s;Port=3330;"
+ "Database=%(database)s;user=%(user)s;password=%(password)s;Option=3;"
)
print(" ...Testing MySql login to {}...".format(c["host"]))
doMySqlTest, connStrMySql, dbMySqlconnect = tryconnection.try_connection(
verbose, cs, 5, **c
)
if doPostgresTest:
_computername = "testpg.2txt.us"
_databasename = "adotest"
_username = "adotestuser"
_password = "12345678"
kws = {"timeout": 4}
kws["macro_is64bit"] = [
"prov_drv",
"Provider=MSDASQL;Driver={PostgreSQL Unicode(x64)}",
"Driver=PostgreSQL Unicode",
]
# get driver from http://www.postgresql.org/ftp/odbc/versions/
# test using positional and keyword arguments (bad example for real code)
if proxy_host:
kws["proxy_host"] = proxy_host
print(" ...Testing PostgreSQL login to {}...".format(_computername))
doPostgresTest, connStrPostgres, dbPostgresConnect = tryconnection.try_connection(
verbose,
"%(prov_drv)s;Server=%(host)s;Database=%(database)s;uid=%(user)s;pwd=%(password)s;port=5430;", # note nonstandard port
_username,
_password,
_computername,
_databasename,
**kws
)
assert (
doAccessTest or doSqlServerTest or doMySqlTest or doPostgresTest
), "No database engine found for testing"

View file

@ -1,939 +0,0 @@
#!/usr/bin/env python
""" Python DB API 2.0 driver compliance unit test suite.
This software is Public Domain and may be used without restrictions.
"Now we have booze and barflies entering the discussion, plus rumours of
DBAs on drugs... and I won't tell you what flashes through my mind each
time I read the subject line with 'Anal Compliance' in it. All around
this is turning out to be a thoroughly unwholesome unit test."
-- Ian Bicking
"""
__version__ = "$Revision: 1.15.0 $"[11:-2]
__author__ = "Stuart Bishop <stuart@stuartbishop.net>"
import sys
import time
import unittest
if sys.version[0] >= "3": # python 3.x
_BaseException = Exception
def _failUnless(self, expr, msg=None):
self.assertTrue(expr, msg)
else: # python 2.x
from exceptions import Exception as _BaseException
def _failUnless(self, expr, msg=None):
self.failUnless(expr, msg) ## deprecated since Python 2.6
# set this to "True" to follow API 2.0 to the letter
TEST_FOR_NON_IDEMPOTENT_CLOSE = False
# Revision 1.15 2019/11/22 00:50:00 kf7xm
# Make Turn off IDEMPOTENT_CLOSE a proper skipTest
# Revision 1.14 2013/05/20 11:02:05 kf7xm
# Add a literal string to the format insertion test to catch trivial re-format algorithms
# Revision 1.13 2013/05/08 14:31:50 kf7xm
# Quick switch to Turn off IDEMPOTENT_CLOSE test. Also: Silence teardown failure
# Revision 1.12 2009/02/06 03:35:11 kf7xm
# Tested okay with Python 3.0, includes last minute patches from Mark H.
#
# Revision 1.1.1.1.2.1 2008/09/20 19:54:59 rupole
# Include latest changes from main branch
# Updates for py3k
#
# Revision 1.11 2005/01/02 02:41:01 zenzen
# Update author email address
#
# Revision 1.10 2003/10/09 03:14:14 zenzen
# Add test for DB API 2.0 optional extension, where database exceptions
# are exposed as attributes on the Connection object.
#
# Revision 1.9 2003/08/13 01:16:36 zenzen
# Minor tweak from Stefan Fleiter
#
# Revision 1.8 2003/04/10 00:13:25 zenzen
# Changes, as per suggestions by M.-A. Lemburg
# - Add a table prefix, to ensure namespace collisions can always be avoided
#
# Revision 1.7 2003/02/26 23:33:37 zenzen
# Break out DDL into helper functions, as per request by David Rushby
#
# Revision 1.6 2003/02/21 03:04:33 zenzen
# Stuff from Henrik Ekelund:
# added test_None
# added test_nextset & hooks
#
# Revision 1.5 2003/02/17 22:08:43 zenzen
# Implement suggestions and code from Henrik Eklund - test that cursor.arraysize
# defaults to 1 & generic cursor.callproc test added
#
# Revision 1.4 2003/02/15 00:16:33 zenzen
# Changes, as per suggestions and bug reports by M.-A. Lemburg,
# Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar
# - Class renamed
# - Now a subclass of TestCase, to avoid requiring the driver stub
# to use multiple inheritance
# - Reversed the polarity of buggy test in test_description
# - Test exception hierarchy correctly
# - self.populate is now self._populate(), so if a driver stub
# overrides self.ddl1 this change propagates
# - VARCHAR columns now have a width, which will hopefully make the
# DDL even more portable (this will be reversed if it causes more problems)
# - cursor.rowcount being checked after various execute and fetchXXX methods
# - Check for fetchall and fetchmany returning empty lists after results
# are exhausted (already checking for empty lists if select retrieved
# nothing)
# - Fix bugs in test_setoutputsize_basic and test_setinputsizes
#
def str2bytes(sval):
if sys.version_info < (3, 0) and isinstance(sval, str):
sval = sval.decode("latin1")
return sval.encode("latin1") # python 3 make unicode into bytes
class DatabaseAPI20Test(unittest.TestCase):
"""Test a database self.driver for DB API 2.0 compatibility.
This implementation tests Gadfly, but the TestCase
is structured so that other self.drivers can subclass this
test case to ensure compiliance with the DB-API. It is
expected that this TestCase may be expanded in the future
if ambiguities or edge conditions are discovered.
The 'Optional Extensions' are not yet being tested.
Drivers should subclass this test, overriding setUp, tearDown,
self.driver, connect_args and connect_kw_args. Class specification
should be as follows:
import dbapi20
class mytest(dbapi20.DatabaseAPI20Test):
[...]
Don't 'import DatabaseAPI20Test from dbapi20', or you will
confuse the unit tester - just 'import dbapi20'.
"""
# The self.driver module. This should be the module where the 'connect'
# method is to be found
driver = None
connect_args = () # List of arguments to pass to connect
connect_kw_args = {} # Keyword arguments for connect
table_prefix = "dbapi20test_" # If you need to specify a prefix for tables
ddl1 = "create table %sbooze (name varchar(20))" % table_prefix
ddl2 = "create table %sbarflys (name varchar(20), drink varchar(30))" % table_prefix
xddl1 = "drop table %sbooze" % table_prefix
xddl2 = "drop table %sbarflys" % table_prefix
lowerfunc = "lower" # Name of stored procedure to convert string->lowercase
# Some drivers may need to override these helpers, for example adding
# a 'commit' after the execute.
def executeDDL1(self, cursor):
cursor.execute(self.ddl1)
def executeDDL2(self, cursor):
cursor.execute(self.ddl2)
def setUp(self):
"""self.drivers should override this method to perform required setup
if any is necessary, such as creating the database.
"""
pass
def tearDown(self):
"""self.drivers should override this method to perform required cleanup
if any is necessary, such as deleting the test database.
The default drops the tables that may be created.
"""
try:
con = self._connect()
try:
cur = con.cursor()
for ddl in (self.xddl1, self.xddl2):
try:
cur.execute(ddl)
con.commit()
except self.driver.Error:
# Assume table didn't exist. Other tests will check if
# execute is busted.
pass
finally:
con.close()
except _BaseException:
pass
def _connect(self):
try:
r = self.driver.connect(*self.connect_args, **self.connect_kw_args)
except AttributeError:
self.fail("No connect method found in self.driver module")
return r
def test_connect(self):
con = self._connect()
con.close()
def test_apilevel(self):
try:
# Must exist
apilevel = self.driver.apilevel
# Must equal 2.0
self.assertEqual(apilevel, "2.0")
except AttributeError:
self.fail("Driver doesn't define apilevel")
def test_threadsafety(self):
try:
# Must exist
threadsafety = self.driver.threadsafety
# Must be a valid value
_failUnless(self, threadsafety in (0, 1, 2, 3))
except AttributeError:
self.fail("Driver doesn't define threadsafety")
def test_paramstyle(self):
try:
# Must exist
paramstyle = self.driver.paramstyle
# Must be a valid value
_failUnless(
self, paramstyle in ("qmark", "numeric", "named", "format", "pyformat")
)
except AttributeError:
self.fail("Driver doesn't define paramstyle")
def test_Exceptions(self):
# Make sure required exceptions exist, and are in the
# defined hierarchy.
if sys.version[0] == "3": # under Python 3 StardardError no longer exists
self.assertTrue(issubclass(self.driver.Warning, Exception))
self.assertTrue(issubclass(self.driver.Error, Exception))
else:
self.failUnless(issubclass(self.driver.Warning, Exception))
self.failUnless(issubclass(self.driver.Error, Exception))
_failUnless(self, issubclass(self.driver.InterfaceError, self.driver.Error))
_failUnless(self, issubclass(self.driver.DatabaseError, self.driver.Error))
_failUnless(self, issubclass(self.driver.OperationalError, self.driver.Error))
_failUnless(self, issubclass(self.driver.IntegrityError, self.driver.Error))
_failUnless(self, issubclass(self.driver.InternalError, self.driver.Error))
_failUnless(self, issubclass(self.driver.ProgrammingError, self.driver.Error))
_failUnless(self, issubclass(self.driver.NotSupportedError, self.driver.Error))
def test_ExceptionsAsConnectionAttributes(self):
# OPTIONAL EXTENSION
# Test for the optional DB API 2.0 extension, where the exceptions
# are exposed as attributes on the Connection object
# I figure this optional extension will be implemented by any
# driver author who is using this test suite, so it is enabled
# by default.
con = self._connect()
drv = self.driver
_failUnless(self, con.Warning is drv.Warning)
_failUnless(self, con.Error is drv.Error)
_failUnless(self, con.InterfaceError is drv.InterfaceError)
_failUnless(self, con.DatabaseError is drv.DatabaseError)
_failUnless(self, con.OperationalError is drv.OperationalError)
_failUnless(self, con.IntegrityError is drv.IntegrityError)
_failUnless(self, con.InternalError is drv.InternalError)
_failUnless(self, con.ProgrammingError is drv.ProgrammingError)
_failUnless(self, con.NotSupportedError is drv.NotSupportedError)
def test_commit(self):
con = self._connect()
try:
# Commit must work, even if it doesn't do anything
con.commit()
finally:
con.close()
def test_rollback(self):
con = self._connect()
# If rollback is defined, it should either work or throw
# the documented exception
if hasattr(con, "rollback"):
try:
con.rollback()
except self.driver.NotSupportedError:
pass
def test_cursor(self):
con = self._connect()
try:
cur = con.cursor()
finally:
con.close()
def test_cursor_isolation(self):
con = self._connect()
try:
# Make sure cursors created from the same connection have
# the documented transaction isolation level
cur1 = con.cursor()
cur2 = con.cursor()
self.executeDDL1(cur1)
cur1.execute(
"insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix)
)
cur2.execute("select name from %sbooze" % self.table_prefix)
booze = cur2.fetchall()
self.assertEqual(len(booze), 1)
self.assertEqual(len(booze[0]), 1)
self.assertEqual(booze[0][0], "Victoria Bitter")
finally:
con.close()
def test_description(self):
con = self._connect()
try:
cur = con.cursor()
self.executeDDL1(cur)
self.assertEqual(
cur.description,
None,
"cursor.description should be none after executing a "
"statement that can return no rows (such as DDL)",
)
cur.execute("select name from %sbooze" % self.table_prefix)
self.assertEqual(
len(cur.description), 1, "cursor.description describes too many columns"
)
self.assertEqual(
len(cur.description[0]),
7,
"cursor.description[x] tuples must have 7 elements",
)
self.assertEqual(
cur.description[0][0].lower(),
"name",
"cursor.description[x][0] must return column name",
)
self.assertEqual(
cur.description[0][1],
self.driver.STRING,
"cursor.description[x][1] must return column type. Got %r"
% cur.description[0][1],
)
# Make sure self.description gets reset
self.executeDDL2(cur)
self.assertEqual(
cur.description,
None,
"cursor.description not being set to None when executing "
"no-result statements (eg. DDL)",
)
finally:
con.close()
def test_rowcount(self):
con = self._connect()
try:
cur = con.cursor()
self.executeDDL1(cur)
_failUnless(
self,
cur.rowcount in (-1, 0), # Bug #543885
"cursor.rowcount should be -1 or 0 after executing no-result "
"statements",
)
cur.execute(
"insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix)
)
_failUnless(
self,
cur.rowcount in (-1, 1),
"cursor.rowcount should == number or rows inserted, or "
"set to -1 after executing an insert statement",
)
cur.execute("select name from %sbooze" % self.table_prefix)
_failUnless(
self,
cur.rowcount in (-1, 1),
"cursor.rowcount should == number of rows returned, or "
"set to -1 after executing a select statement",
)
self.executeDDL2(cur)
self.assertEqual(
cur.rowcount,
-1,
"cursor.rowcount not being reset to -1 after executing "
"no-result statements",
)
finally:
con.close()
lower_func = "lower"
def test_callproc(self):
con = self._connect()
try:
cur = con.cursor()
if self.lower_func and hasattr(cur, "callproc"):
r = cur.callproc(self.lower_func, ("FOO",))
self.assertEqual(len(r), 1)
self.assertEqual(r[0], "FOO")
r = cur.fetchall()
self.assertEqual(len(r), 1, "callproc produced no result set")
self.assertEqual(len(r[0]), 1, "callproc produced invalid result set")
self.assertEqual(r[0][0], "foo", "callproc produced invalid results")
finally:
con.close()
def test_close(self):
con = self._connect()
try:
cur = con.cursor()
finally:
con.close()
# cursor.execute should raise an Error if called after connection
# closed
self.assertRaises(self.driver.Error, self.executeDDL1, cur)
# connection.commit should raise an Error if called after connection
# is closed.
self.assertRaises(self.driver.Error, con.commit)
# connection.close should raise an Error if called more than once
#!!! reasonable persons differ about the usefulness of this test and this feature !!!
if TEST_FOR_NON_IDEMPOTENT_CLOSE:
self.assertRaises(self.driver.Error, con.close)
else:
self.skipTest(
"Non-idempotent close is considered a bad thing by some people."
)
def test_execute(self):
con = self._connect()
try:
cur = con.cursor()
self._paraminsert(cur)
finally:
con.close()
def _paraminsert(self, cur):
self.executeDDL2(cur)
cur.execute(
"insert into %sbarflys values ('Victoria Bitter', 'thi%%s :may ca%%(u)se? troub:1e')"
% (self.table_prefix)
)
_failUnless(self, cur.rowcount in (-1, 1))
if self.driver.paramstyle == "qmark":
cur.execute(
"insert into %sbarflys values (?, 'thi%%s :may ca%%(u)se? troub:1e')"
% self.table_prefix,
("Cooper's",),
)
elif self.driver.paramstyle == "numeric":
cur.execute(
"insert into %sbarflys values (:1, 'thi%%s :may ca%%(u)se? troub:1e')"
% self.table_prefix,
("Cooper's",),
)
elif self.driver.paramstyle == "named":
cur.execute(
"insert into %sbarflys values (:beer, 'thi%%s :may ca%%(u)se? troub:1e')"
% self.table_prefix,
{"beer": "Cooper's"},
)
elif self.driver.paramstyle == "format":
cur.execute(
"insert into %sbarflys values (%%s, 'thi%%s :may ca%%(u)se? troub:1e')"
% self.table_prefix,
("Cooper's",),
)
elif self.driver.paramstyle == "pyformat":
cur.execute(
"insert into %sbarflys values (%%(beer)s, 'thi%%s :may ca%%(u)se? troub:1e')"
% self.table_prefix,
{"beer": "Cooper's"},
)
else:
self.fail("Invalid paramstyle")
_failUnless(self, cur.rowcount in (-1, 1))
cur.execute("select name, drink from %sbarflys" % self.table_prefix)
res = cur.fetchall()
self.assertEqual(len(res), 2, "cursor.fetchall returned too few rows")
beers = [res[0][0], res[1][0]]
beers.sort()
self.assertEqual(
beers[0],
"Cooper's",
"cursor.fetchall retrieved incorrect data, or data inserted " "incorrectly",
)
self.assertEqual(
beers[1],
"Victoria Bitter",
"cursor.fetchall retrieved incorrect data, or data inserted " "incorrectly",
)
trouble = "thi%s :may ca%(u)se? troub:1e"
self.assertEqual(
res[0][1],
trouble,
"cursor.fetchall retrieved incorrect data, or data inserted "
"incorrectly. Got=%s, Expected=%s" % (repr(res[0][1]), repr(trouble)),
)
self.assertEqual(
res[1][1],
trouble,
"cursor.fetchall retrieved incorrect data, or data inserted "
"incorrectly. Got=%s, Expected=%s" % (repr(res[1][1]), repr(trouble)),
)
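# Illustrative sketch (not part of the original suite): the same single-column
# insert spelled out under each PEP 249 paramstyle, assuming a driver module
# "drv" and an open cursor "cur".
# placeholders = {
#     "qmark": "insert into booze values (?)",
#     "numeric": "insert into booze values (:1)",
#     "named": "insert into booze values (:beer)",
#     "format": "insert into booze values (%s)",
#     "pyformat": "insert into booze values (%(beer)s)",
# }
# if drv.paramstyle in ("named", "pyformat"):
#     cur.execute(placeholders[drv.paramstyle], {"beer": "Cooper's"})
# else:
#     cur.execute(placeholders[drv.paramstyle], ("Cooper's",))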
def test_executemany(self):
con = self._connect()
try:
cur = con.cursor()
self.executeDDL1(cur)
largs = [("Cooper's",), ("Boag's",)]
margs = [{"beer": "Cooper's"}, {"beer": "Boag's"}]
if self.driver.paramstyle == "qmark":
cur.executemany(
"insert into %sbooze values (?)" % self.table_prefix, largs
)
elif self.driver.paramstyle == "numeric":
cur.executemany(
"insert into %sbooze values (:1)" % self.table_prefix, largs
)
elif self.driver.paramstyle == "named":
cur.executemany(
"insert into %sbooze values (:beer)" % self.table_prefix, margs
)
elif self.driver.paramstyle == "format":
cur.executemany(
"insert into %sbooze values (%%s)" % self.table_prefix, largs
)
elif self.driver.paramstyle == "pyformat":
cur.executemany(
"insert into %sbooze values (%%(beer)s)" % (self.table_prefix),
margs,
)
else:
self.fail("Unknown paramstyle")
_failUnless(
self,
cur.rowcount in (-1, 2),
"insert using cursor.executemany set cursor.rowcount to "
"incorrect value %r" % cur.rowcount,
)
cur.execute("select name from %sbooze" % self.table_prefix)
res = cur.fetchall()
self.assertEqual(
len(res), 2, "cursor.fetchall retrieved incorrect number of rows"
)
beers = [res[0][0], res[1][0]]
beers.sort()
self.assertEqual(
beers[0], "Boag's", 'incorrect data "%s" retrieved' % beers[0]
)
self.assertEqual(beers[1], "Cooper's", "incorrect data retrieved")
finally:
con.close()
def test_fetchone(self):
con = self._connect()
try:
cur = con.cursor()
# cursor.fetchone should raise an Error if called before
# executing a select-type query
self.assertRaises(self.driver.Error, cur.fetchone)
# cursor.fetchone should raise an Error if called after
# executing a query that cannot return rows
self.executeDDL1(cur)
self.assertRaises(self.driver.Error, cur.fetchone)
cur.execute("select name from %sbooze" % self.table_prefix)
self.assertEqual(
cur.fetchone(),
None,
"cursor.fetchone should return None if a query retrieves " "no rows",
)
_failUnless(self, cur.rowcount in (-1, 0))
# cursor.fetchone should raise an Error if called after
# executing a query that cannot return rows
cur.execute(
"insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix)
)
self.assertRaises(self.driver.Error, cur.fetchone)
cur.execute("select name from %sbooze" % self.table_prefix)
r = cur.fetchone()
self.assertEqual(
len(r), 1, "cursor.fetchone should have retrieved a single row"
)
self.assertEqual(
r[0], "Victoria Bitter", "cursor.fetchone retrieved incorrect data"
)
self.assertEqual(
cur.fetchone(),
None,
"cursor.fetchone should return None if no more rows available",
)
_failUnless(self, cur.rowcount in (-1, 1))
finally:
con.close()
samples = [
"Carlton Cold",
"Carlton Draft",
"Mountain Goat",
"Redback",
"Victoria Bitter",
"XXXX",
]
def _populate(self):
"""Return a list of sql commands to setup the DB for the fetch
tests.
"""
populate = [
"insert into %sbooze values ('%s')" % (self.table_prefix, s)
for s in self.samples
]
return populate
def test_fetchmany(self):
con = self._connect()
try:
cur = con.cursor()
# cursor.fetchmany should raise an Error if called without
# issuing a query
self.assertRaises(self.driver.Error, cur.fetchmany, 4)
self.executeDDL1(cur)
for sql in self._populate():
cur.execute(sql)
cur.execute("select name from %sbooze" % self.table_prefix)
r = cur.fetchmany()
self.assertEqual(
len(r),
1,
"cursor.fetchmany retrieved incorrect number of rows, "
"default of arraysize is one.",
)
cur.arraysize = 10
r = cur.fetchmany(3) # Should get 3 rows
self.assertEqual(
len(r), 3, "cursor.fetchmany retrieved incorrect number of rows"
)
r = cur.fetchmany(4) # Should get 2 more
self.assertEqual(
len(r), 2, "cursor.fetchmany retrieved incorrect number of rows"
)
r = cur.fetchmany(4) # Should be an empty sequence
self.assertEqual(
len(r),
0,
"cursor.fetchmany should return an empty sequence after "
"results are exhausted",
)
_failUnless(self, cur.rowcount in (-1, 6))
# Same as above, using cursor.arraysize
cur.arraysize = 4
cur.execute("select name from %sbooze" % self.table_prefix)
r = cur.fetchmany() # Should get 4 rows
self.assertEqual(
len(r), 4, "cursor.arraysize not being honoured by fetchmany"
)
r = cur.fetchmany() # Should get 2 more
self.assertEqual(len(r), 2)
r = cur.fetchmany() # Should be an empty sequence
self.assertEqual(len(r), 0)
_failUnless(self, cur.rowcount in (-1, 6))
cur.arraysize = 6
cur.execute("select name from %sbooze" % self.table_prefix)
rows = cur.fetchmany() # Should get all rows
_failUnless(self, cur.rowcount in (-1, 6))
self.assertEqual(len(rows), 6)
self.assertEqual(len(rows), 6)
rows = [r[0] for r in rows]
rows.sort()
# Make sure we get the right data back out
for i in range(0, 6):
self.assertEqual(
rows[i],
self.samples[i],
"incorrect data retrieved by cursor.fetchmany",
)
rows = cur.fetchmany() # Should return an empty list
self.assertEqual(
len(rows),
0,
"cursor.fetchmany should return an empty sequence if "
"called after the whole result set has been fetched",
)
_failUnless(self, cur.rowcount in (-1, 6))
self.executeDDL2(cur)
cur.execute("select name from %sbarflys" % self.table_prefix)
r = cur.fetchmany() # Should get empty sequence
self.assertEqual(
len(r),
0,
"cursor.fetchmany should return an empty sequence if "
"query retrieved no rows",
)
_failUnless(self, cur.rowcount in (-1, 0))
finally:
con.close()
def test_fetchall(self):
con = self._connect()
try:
cur = con.cursor()
# cursor.fetchall should raise an Error if called
# without executing a query that may return rows (such
# as a select)
self.assertRaises(self.driver.Error, cur.fetchall)
self.executeDDL1(cur)
for sql in self._populate():
cur.execute(sql)
# cursor.fetchall should raise an Error if called
# after executing a statement that cannot return rows
self.assertRaises(self.driver.Error, cur.fetchall)
cur.execute("select name from %sbooze" % self.table_prefix)
rows = cur.fetchall()
_failUnless(self, cur.rowcount in (-1, len(self.samples)))
self.assertEqual(
len(rows),
len(self.samples),
"cursor.fetchall did not retrieve all rows",
)
rows = [r[0] for r in rows]
rows.sort()
for i in range(0, len(self.samples)):
self.assertEqual(
rows[i], self.samples[i], "cursor.fetchall retrieved incorrect rows"
)
rows = cur.fetchall()
self.assertEqual(
len(rows),
0,
"cursor.fetchall should return an empty list if called "
"after the whole result set has been fetched",
)
_failUnless(self, cur.rowcount in (-1, len(self.samples)))
self.executeDDL2(cur)
cur.execute("select name from %sbarflys" % self.table_prefix)
rows = cur.fetchall()
_failUnless(self, cur.rowcount in (-1, 0))
self.assertEqual(
len(rows),
0,
"cursor.fetchall should return an empty list if "
"a select query returns no rows",
)
finally:
con.close()
def test_mixedfetch(self):
con = self._connect()
try:
cur = con.cursor()
self.executeDDL1(cur)
for sql in self._populate():
cur.execute(sql)
cur.execute("select name from %sbooze" % self.table_prefix)
rows1 = cur.fetchone()
rows23 = cur.fetchmany(2)
rows4 = cur.fetchone()
rows56 = cur.fetchall()
_failUnless(self, cur.rowcount in (-1, 6))
self.assertEqual(
len(rows23), 2, "fetchmany returned incorrect number of rows"
)
self.assertEqual(
len(rows56), 2, "fetchall returned incorrect number of rows"
)
rows = [rows1[0]]
rows.extend([rows23[0][0], rows23[1][0]])
rows.append(rows4[0])
rows.extend([rows56[0][0], rows56[1][0]])
rows.sort()
for i in range(0, len(self.samples)):
self.assertEqual(
rows[i], self.samples[i], "incorrect data retrieved or inserted"
)
finally:
con.close()
def help_nextset_setUp(self, cur):
"""Should create a procedure called deleteme
that returns two result sets, first the
number of rows in booze then "name from booze"
"""
raise NotImplementedError("Helper not implemented")
# sql="""
# create procedure deleteme as
# begin
# select count(*) from booze
# select name from booze
# end
# """
# cur.execute(sql)
def help_nextset_tearDown(self, cur):
"If cleaning up is needed after nextSetTest"
raise NotImplementedError("Helper not implemented")
# cur.execute("drop procedure deleteme")
def test_nextset(self):
con = self._connect()
try:
cur = con.cursor()
if not hasattr(cur, "nextset"):
return
try:
self.executeDDL1(cur)
for sql in self._populate():
cur.execute(sql)
self.help_nextset_setUp(cur)
cur.callproc("deleteme")
numberofrows = cur.fetchone()
assert numberofrows[0] == len(self.samples)
assert cur.nextset()
names = cur.fetchall()
assert len(names) == len(self.samples)
s = cur.nextset()
assert s == None, "No more return sets, should return None"
finally:
self.help_nextset_tearDown(cur)
finally:
con.close()
def test_nextset(self):
raise NotImplementedError("Drivers need to override this test")
def test_arraysize(self):
# Not much here - rest of the tests for this are in test_fetchmany
con = self._connect()
try:
cur = con.cursor()
_failUnless(
self, hasattr(cur, "arraysize"), "cursor.arraysize must be defined"
)
finally:
con.close()
def test_setinputsizes(self):
con = self._connect()
try:
cur = con.cursor()
cur.setinputsizes((25,))
self._paraminsert(cur) # Make sure cursor still works
finally:
con.close()
def test_setoutputsize_basic(self):
# Basic test is to make sure setoutputsize doesn't blow up
con = self._connect()
try:
cur = con.cursor()
cur.setoutputsize(1000)
cur.setoutputsize(2000, 0)
self._paraminsert(cur) # Make sure the cursor still works
finally:
con.close()
def test_setoutputsize(self):
# Real test for setoutputsize is driver dependent
raise NotImplementedError("Driver needed to override this test")
def test_None(self):
con = self._connect()
try:
cur = con.cursor()
self.executeDDL1(cur)
cur.execute("insert into %sbooze values (NULL)" % self.table_prefix)
cur.execute("select name from %sbooze" % self.table_prefix)
r = cur.fetchall()
self.assertEqual(len(r), 1)
self.assertEqual(len(r[0]), 1)
self.assertEqual(r[0][0], None, "NULL value not returned as None")
finally:
con.close()
def test_Date(self):
d1 = self.driver.Date(2002, 12, 25)
d2 = self.driver.DateFromTicks(time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0)))
# Can we assume this? API doesn't specify, but it seems implied
# self.assertEqual(str(d1),str(d2))
def test_Time(self):
t1 = self.driver.Time(13, 45, 30)
t2 = self.driver.TimeFromTicks(time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0)))
# Can we assume this? API doesn't specify, but it seems implied
# self.assertEqual(str(t1),str(t2))
def test_Timestamp(self):
t1 = self.driver.Timestamp(2002, 12, 25, 13, 45, 30)
t2 = self.driver.TimestampFromTicks(
time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0))
)
# Can we assume this? API doesn't specify, but it seems implied
# self.assertEqual(str(t1),str(t2))
def test_Binary(self):
b = self.driver.Binary(str2bytes("Something"))
b = self.driver.Binary(str2bytes(""))
def test_STRING(self):
_failUnless(
self, hasattr(self.driver, "STRING"), "module.STRING must be defined"
)
def test_BINARY(self):
_failUnless(
self, hasattr(self.driver, "BINARY"), "module.BINARY must be defined."
)
def test_NUMBER(self):
_failUnless(
self, hasattr(self.driver, "NUMBER"), "module.NUMBER must be defined."
)
def test_DATETIME(self):
_failUnless(
self, hasattr(self.driver, "DATETIME"), "module.DATETIME must be defined."
)
def test_ROWID(self):
_failUnless(
self, hasattr(self.driver, "ROWID"), "module.ROWID must be defined."
)
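# Minimal sketch of how a driver's own test module plugs into this suite
# (hypothetical driver module "mydriver" and connection arguments, shown for
# illustration only):
# import unittest
# import dbapi20
# import mydriver
# class test_MyDriver(dbapi20.DatabaseAPI20Test):
#     driver = mydriver
#     connect_args = ("host=localhost dbname=test",)
#     connect_kw_args = {}
# if __name__ == "__main__":
#     unittest.main()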

View file

@ -1,41 +0,0 @@
"""is64bit.Python() --> boolean value of detected Python word size. is64bit.os() --> os build version"""
import sys
def Python():
if sys.platform == "cli": # IronPython
import System
return System.IntPtr.Size == 8
else:
try:
return sys.maxsize > 2147483647
except AttributeError:
return sys.maxint > 2147483647
def os():
import platform
pm = platform.machine()
if pm != ".." and pm.endswith("64"): # recent Python (not Iron)
return True
else:
import os
if "PROCESSOR_ARCHITEW6432" in os.environ:
return True # 32 bit program running on 64 bit Windows
try:
return os.environ["PROCESSOR_ARCHITECTURE"].endswith(
"64"
) # 64 bit Windows 64 bit program
except KeyError:
pass # not Windows
try:
return "64" in platform.architecture()[0] # this often works in Linux
except:
return False # is an older version of Python, assume also an older os (best we can guess)
if __name__ == "__main__":
print("is64bit.Python() =", Python(), "is64bit.os() =", os())

View file

@ -1,134 +0,0 @@
#!/usr/bin/python2
# Configure this in order to run the testcases.
"setuptestframework.py v 2.6.0.8"
import os
import shutil
import sys
import tempfile
try:
OSErrors = (WindowsError, OSError)
except NameError: # not running on Windows
OSErrors = OSError
def maketemp():
temphome = tempfile.gettempdir()
tempdir = os.path.join(temphome, "adodbapi_test")
try:
os.mkdir(tempdir)
except:
pass
return tempdir
def _cleanup_function(testfolder, mdb_name):
try:
os.unlink(os.path.join(testfolder, mdb_name))
except:
pass # mdb database not present
try:
shutil.rmtree(testfolder)
print(" cleaned up folder", testfolder)
except:
pass # test package not present
def getcleanupfunction():
return _cleanup_function
def find_ado_path():
adoName = os.path.normpath(os.getcwd() + "/../../adodbapi.py")
adoPackage = os.path.dirname(adoName)
return adoPackage
# make a new package directory for the test copy of ado
def makeadopackage(testfolder):
adoName = os.path.normpath(os.getcwd() + "/../adodbapi.py")
adoPath = os.path.dirname(adoName)
if os.path.exists(adoName):
newpackage = os.path.join(testfolder, "adodbapi")
try:
os.mkdir(newpackage)
except OSErrors:
print(
"*Note: temporary adodbapi package already exists: may be two versions running?"
)
for f in os.listdir(adoPath):
if f.endswith(".py"):
shutil.copy(os.path.join(adoPath, f), newpackage)
if sys.version_info >= (3, 0): # only when running Py3.n
save = sys.stdout
sys.stdout = None
from lib2to3.main import main # use 2to3 to make test package
main("lib2to3.fixes", args=["-n", "-w", newpackage])
sys.stdout = save
return testfolder
else:
raise EnvironmentError("Connot find source of adodbapi to test.")
def makemdb(testfolder, mdb_name):
# following setup code borrowed from pywin32 odbc test suite
# kindly contributed by Frank Millman.
import os
_accessdatasource = os.path.join(testfolder, mdb_name)
if os.path.isfile(_accessdatasource):
print("using JET database=", _accessdatasource)
else:
try:
from win32com.client import constants
from win32com.client.gencache import EnsureDispatch
win32 = True
except ImportError: # perhaps we are running IronPython
win32 = False # iron Python
try:
from System import Activator, Type
except:
pass
# Create a brand-new database - what is the story with these?
dbe = None
for suffix in (".36", ".35", ".30"):
try:
if win32:
dbe = EnsureDispatch("DAO.DBEngine" + suffix)
else:
type = Type.GetTypeFromProgID("DAO.DBEngine" + suffix)
dbe = Activator.CreateInstance(type)
break
except:
pass
if dbe:
print(" ...Creating ACCESS db at " + _accessdatasource)
if win32:
workspace = dbe.Workspaces(0)
newdb = workspace.CreateDatabase(
_accessdatasource, constants.dbLangGeneral, constants.dbVersion40
)
else:
newdb = dbe.CreateDatabase(
_accessdatasource, ";LANGID=0x0409;CP=1252;COUNTRY=0"
)
newdb.Close()
else:
print(" ...copying test ACCESS db to " + _accessdatasource)
mdbName = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "examples", "test.mdb")
)
import shutil
shutil.copy(mdbName, _accessdatasource)
return _accessdatasource
if __name__ == "__main__":
print("Setting up a Jet database for server to use for remote testing...")
temp = maketemp()
makemdb(temp, "server_test.mdb")

View file

@ -1,200 +0,0 @@
print("This module depends on the dbapi20 compliance tests created by Stuart Bishop")
print("(see db-sig mailing list history for info)")
import platform
import sys
import unittest
import dbapi20
import setuptestframework
testfolder = setuptestframework.maketemp()
if "--package" in sys.argv:
pth = setuptestframework.makeadopackage(testfolder)
sys.argv.remove("--package")
else:
pth = setuptestframework.find_ado_path()
if pth not in sys.path:
sys.path.insert(1, pth)
# function to clean up the temporary folder -- calling program must run this function before exit.
cleanup = setuptestframework.getcleanupfunction()
import adodbapi
import adodbapi.is64bit as is64bit
db = adodbapi
if "--verbose" in sys.argv:
db.adodbapi.verbose = 3
print(adodbapi.version)
print("Tested with dbapi20 %s" % dbapi20.__version__)
try:
onWindows = bool(sys.getwindowsversion()) # seems to work on all versions of Python
except:
onWindows = False
node = platform.node()
conn_kws = {}
host = "testsql.2txt.us,1430" # if None, will use macro to fill in node name
instance = r"%s\SQLEXPRESS"
conn_kws["name"] = "adotest"
conn_kws["user"] = "adotestuser" # None implies Windows security
conn_kws["password"] = "Sq1234567"
# macro definition for keyword "security" using macro "auto_security"
conn_kws["macro_auto_security"] = "security"
if host is None:
conn_kws["macro_getnode"] = ["host", instance]
else:
conn_kws["host"] = host
conn_kws[
"provider"
] = "Provider=MSOLEDBSQL;DataTypeCompatibility=80;MARS Connection=True;"
connStr = "%(provider)s; %(security)s; Initial Catalog=%(name)s;Data Source=%(host)s"
if onWindows and node != "z-PC":
pass # default should make a local SQL Server connection
elif node == "xxx": # try Postgres database
_computername = "25.223.161.222"
_databasename = "adotest"
_username = "adotestuser"
_password = "12345678"
_driver = "PostgreSQL Unicode"
_provider = ""
connStr = "%sDriver={%s};Server=%s;Database=%s;uid=%s;pwd=%s;" % (
_provider,
_driver,
_computername,
_databasename,
_username,
_password,
)
elif node == "yyy": # ACCESS data base is known to fail some tests.
if is64bit.Python():
driver = "Microsoft.ACE.OLEDB.12.0"
else:
driver = "Microsoft.Jet.OLEDB.4.0"
testmdb = setuptestframework.makemdb(testfolder)
connStr = r"Provider=%s;Data Source=%s" % (driver, testmdb)
else: # try a remote connection to an SQL server
conn_kws["proxy_host"] = "25.44.77.176"
import adodbapi.remote
db = adodbapi.remote
print("Using Connection String like=%s" % connStr)
print("Keywords=%s" % repr(conn_kws))
class test_adodbapi(dbapi20.DatabaseAPI20Test):
driver = db
connect_args = (connStr,)
connect_kw_args = conn_kws
def __init__(self, arg):
dbapi20.DatabaseAPI20Test.__init__(self, arg)
def getTestMethodName(self):
return self.id().split(".")[-1]
def setUp(self):
# Call superclass setUp In case this does something in the
# future
dbapi20.DatabaseAPI20Test.setUp(self)
if self.getTestMethodName() == "test_callproc":
con = self._connect()
engine = con.dbms_name
## print('Using database Engine=%s' % engine) ##
if engine != "MS Jet":
sql = """
create procedure templower
@theData varchar(50)
as
select lower(@theData)
"""
else: # Jet
sql = """
create procedure templower
(theData varchar(50))
as
select lower(theData);
"""
cur = con.cursor()
try:
cur.execute(sql)
con.commit()
except:
pass
cur.close()
con.close()
self.lower_func = "templower"
def tearDown(self):
if self.getTestMethodName() == "test_callproc":
con = self._connect()
cur = con.cursor()
try:
cur.execute("drop procedure templower")
except:
pass
con.commit()
dbapi20.DatabaseAPI20Test.tearDown(self)
def help_nextset_setUp(self, cur):
"Should create a procedure called deleteme"
'that returns two result sets, first the number of rows in booze then "name from booze"'
sql = """
create procedure deleteme as
begin
select count(*) from %sbooze
select name from %sbooze
end
""" % (
self.table_prefix,
self.table_prefix,
)
cur.execute(sql)
def help_nextset_tearDown(self, cur):
"If cleaning up is needed after nextSetTest"
try:
cur.execute("drop procedure deleteme")
except:
pass
def test_nextset(self):
con = self._connect()
try:
cur = con.cursor()
stmts = [self.ddl1] + self._populate()
for sql in stmts:
cur.execute(sql)
self.help_nextset_setUp(cur)
cur.callproc("deleteme")
numberofrows = cur.fetchone()
assert numberofrows[0] == 6
assert cur.nextset()
names = cur.fetchall()
assert len(names) == len(self.samples)
s = cur.nextset()
assert s == None, "No more return sets, should return None"
finally:
try:
self.help_nextset_tearDown(cur)
finally:
con.close()
def test_setoutputsize(self):
pass
if __name__ == "__main__":
unittest.main()
cleanup(testfolder, None)

View file

@ -1,33 +0,0 @@
remote = False # automatic testing of remote access has been removed here
def try_connection(verbose, *args, **kwargs):
import adodbapi
dbconnect = adodbapi.connect
try:
s = dbconnect(*args, **kwargs) # connect to server
if verbose:
print("Connected to:", s.connection_string)
print("which has tables:", s.get_table_names())
s.close() # thanks, it worked, goodbye
except adodbapi.DatabaseError as inst:
print(inst.args[0]) # should be the error message
print("***Failed getting connection using=", repr(args), repr(kwargs))
return False, (args, kwargs), None
print(" (successful)")
return True, (args, kwargs, remote), dbconnect
def try_operation_with_expected_exception(
expected_exception_list, some_function, *args, **kwargs
):
try:
some_function(*args, **kwargs)
except expected_exception_list as e:
return True, e
except:
raise # an exception other than the expected occurred
return False, "The expected exception did not occur"

View file

@ -1,396 +0,0 @@
import math
import sys
from dataclasses import dataclass
from datetime import timezone
from typing import TYPE_CHECKING, Any, Callable, Iterator, Optional, SupportsFloat, SupportsIndex, TypeVar, Union
if sys.version_info < (3, 8):
from typing_extensions import Protocol, runtime_checkable
else:
from typing import Protocol, runtime_checkable
if sys.version_info < (3, 9):
from typing_extensions import Annotated, Literal
else:
from typing import Annotated, Literal
if sys.version_info < (3, 10):
EllipsisType = type(Ellipsis)
KW_ONLY = {}
SLOTS = {}
else:
from types import EllipsisType
KW_ONLY = {"kw_only": True}
SLOTS = {"slots": True}
__all__ = (
'BaseMetadata',
'GroupedMetadata',
'Gt',
'Ge',
'Lt',
'Le',
'Interval',
'MultipleOf',
'MinLen',
'MaxLen',
'Len',
'Timezone',
'Predicate',
'LowerCase',
'UpperCase',
'IsDigits',
'IsFinite',
'IsNotFinite',
'IsNan',
'IsNotNan',
'IsInfinite',
'IsNotInfinite',
'doc',
'DocInfo',
'__version__',
)
__version__ = '0.6.0'
T = TypeVar('T')
# arguments that start with __ are considered
# positional only
# see https://peps.python.org/pep-0484/#positional-only-arguments
class SupportsGt(Protocol):
def __gt__(self: T, __other: T) -> bool:
...
class SupportsGe(Protocol):
def __ge__(self: T, __other: T) -> bool:
...
class SupportsLt(Protocol):
def __lt__(self: T, __other: T) -> bool:
...
class SupportsLe(Protocol):
def __le__(self: T, __other: T) -> bool:
...
class SupportsMod(Protocol):
def __mod__(self: T, __other: T) -> T:
...
class SupportsDiv(Protocol):
def __div__(self: T, __other: T) -> T:
...
class BaseMetadata:
"""Base class for all metadata.
This exists mainly so that implementers
can do `isinstance(..., BaseMetadata)` while traversing field annotations.
"""
__slots__ = ()
@dataclass(frozen=True, **SLOTS)
class Gt(BaseMetadata):
"""Gt(gt=x) implies that the value must be greater than x.
It can be used with any type that supports the ``>`` operator,
including numbers, dates and times, strings, sets, and so on.
"""
gt: SupportsGt
@dataclass(frozen=True, **SLOTS)
class Ge(BaseMetadata):
"""Ge(ge=x) implies that the value must be greater than or equal to x.
It can be used with any type that supports the ``>=`` operator,
including numbers, dates and times, strings, sets, and so on.
"""
ge: SupportsGe
@dataclass(frozen=True, **SLOTS)
class Lt(BaseMetadata):
"""Lt(lt=x) implies that the value must be less than x.
It can be used with any type that supports the ``<`` operator,
including numbers, dates and times, strings, sets, and so on.
"""
lt: SupportsLt
@dataclass(frozen=True, **SLOTS)
class Le(BaseMetadata):
"""Le(le=x) implies that the value must be less than or equal to x.
It can be used with any type that supports the ``<=`` operator,
including numbers, dates and times, strings, sets, and so on.
"""
le: SupportsLe
@runtime_checkable
class GroupedMetadata(Protocol):
"""A grouping of multiple BaseMetadata objects.
`GroupedMetadata` on its own is not metadata and has no meaning.
All of the constraints and metadata should be fully expressible
in terms of the `BaseMetadata`'s returned by `GroupedMetadata.__iter__()`.
Concrete implementations should override `GroupedMetadata.__iter__()`
to add their own metadata.
For example:
>>> @dataclass
>>> class Field(GroupedMetadata):
>>> gt: float | None = None
>>> description: str | None = None
...
>>> def __iter__(self) -> Iterable[BaseMetadata]:
>>> if self.gt is not None:
>>> yield Gt(self.gt)
>>> if self.description is not None:
>>> yield Description(self.gt)
Also see the implementation of `Interval` below for an example.
Parsers should recognize this and unpack it so that it can be used
both with and without unpacking:
- `Annotated[int, Field(...)]` (parser must unpack Field)
- `Annotated[int, *Field(...)]` (PEP-646)
""" # noqa: trailing-whitespace
@property
def __is_annotated_types_grouped_metadata__(self) -> Literal[True]:
return True
def __iter__(self) -> Iterator[BaseMetadata]:
...
if not TYPE_CHECKING:
__slots__ = () # allow subclasses to use slots
def __init_subclass__(cls, *args: Any, **kwargs: Any) -> None:
# Basic ABC like functionality without the complexity of an ABC
super().__init_subclass__(*args, **kwargs)
if cls.__iter__ is GroupedMetadata.__iter__:
raise TypeError("Can't subclass GroupedMetadata without implementing __iter__")
def __iter__(self) -> Iterator[BaseMetadata]: # noqa: F811
raise NotImplementedError # more helpful than "None has no attribute..." type errors
@dataclass(frozen=True, **KW_ONLY, **SLOTS)
class Interval(GroupedMetadata):
"""Interval can express inclusive or exclusive bounds with a single object.
It accepts keyword arguments ``gt``, ``ge``, ``lt``, and/or ``le``, which
are interpreted the same way as the single-bound constraints.
"""
gt: Union[SupportsGt, None] = None
ge: Union[SupportsGe, None] = None
lt: Union[SupportsLt, None] = None
le: Union[SupportsLe, None] = None
def __iter__(self) -> Iterator[BaseMetadata]:
"""Unpack an Interval into zero or more single-bounds."""
if self.gt is not None:
yield Gt(self.gt)
if self.ge is not None:
yield Ge(self.ge)
if self.lt is not None:
yield Lt(self.lt)
if self.le is not None:
yield Le(self.le)
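# Illustrative example of the unpacking performed by __iter__ above:
# assert list(Interval(gt=0, le=10)) == [Gt(0), Le(10)]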
@dataclass(frozen=True, **SLOTS)
class MultipleOf(BaseMetadata):
"""MultipleOf(multiple_of=x) might be interpreted in two ways:
1. Python semantics, implying ``value % multiple_of == 0``, or
2. JSONschema semantics, where ``int(value / multiple_of) == value / multiple_of``
We encourage users to be aware of these two common interpretations,
and libraries to carefully document which they implement.
"""
multiple_of: Union[SupportsDiv, SupportsMod]
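# Sketch of the two readings for an integer value (illustrative only; they
# coincide here but can diverge for floats):
# value, mult = 9, 3
# assert value % mult == 0                         # Python semantics
# assert int(value / mult) == value / mult         # JSON Schema semantics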
@dataclass(frozen=True, **SLOTS)
class MinLen(BaseMetadata):
"""
MinLen() implies minimum inclusive length,
e.g. ``len(value) >= min_length``.
"""
min_length: Annotated[int, Ge(0)]
@dataclass(frozen=True, **SLOTS)
class MaxLen(BaseMetadata):
"""
MaxLen() implies maximum inclusive length,
e.g. ``len(value) <= max_length``.
"""
max_length: Annotated[int, Ge(0)]
@dataclass(frozen=True, **SLOTS)
class Len(GroupedMetadata):
"""
Len() implies that ``min_length <= len(value) <= max_length``.
Upper bound may be omitted or ``None`` to indicate no upper length bound.
"""
min_length: Annotated[int, Ge(0)] = 0
max_length: Optional[Annotated[int, Ge(0)]] = None
def __iter__(self) -> Iterator[BaseMetadata]:
"""Unpack a Len into zone or more single-bounds."""
if self.min_length > 0:
yield MinLen(self.min_length)
if self.max_length is not None:
yield MaxLen(self.max_length)
@dataclass(frozen=True, **SLOTS)
class Timezone(BaseMetadata):
"""Timezone(tz=...) requires a datetime to be aware (or ``tz=None``, naive).
``Annotated[datetime, Timezone(None)]`` must be a naive datetime.
``Timezone[...]`` (the ellipsis literal) expresses that the datetime must be
tz-aware but any timezone is allowed.
You may also pass a specific timezone string or timezone object such as
``Timezone(timezone.utc)`` or ``Timezone("Africa/Abidjan")`` to express that
you only allow a specific timezone, though we note that this is often
a symptom of poor design.
"""
tz: Union[str, timezone, EllipsisType, None]
@dataclass(frozen=True, **SLOTS)
class Predicate(BaseMetadata):
"""``Predicate(func: Callable)`` implies `func(value)` is truthy for valid values.
Users should prefer statically inspectable metadata, but if you need the full
power and flexibility of arbitrary runtime predicates... here it is.
We provide a few predefined predicates for common string constraints:
``IsLower = Predicate(str.islower)``, ``IsUpper = Predicate(str.isupper)``, and
``IsDigit = Predicate(str.isdigit)``. Users are encouraged to use methods which
can be given special handling, and avoid indirection like ``lambda s: s.lower()``.
Some libraries might have special logic to handle certain predicates, e.g. by
checking for `str.isdigit` and using its presence to both call custom logic to
enforce digit-only strings, and customise some generated external schema.
We do not specify what behaviour should be expected for predicates that raise
an exception. For example `Annotated[int, Predicate(str.isdigit)]` might silently
skip invalid constraints, or statically raise an error; or it might try calling it
and then propagate or discard the resulting exception.
"""
func: Callable[[Any], bool]
@dataclass
class Not:
func: Callable[[Any], bool]
def __call__(self, __v: Any) -> bool:
return not self.func(__v)
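# Illustrative sketch (not an exported alias): a negated string predicate built
# from Predicate and Not, in the same spirit as the IsNot* numeric aliases below.
# NotDigits = Annotated[str, Predicate(Not(str.isdigit))]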
_StrType = TypeVar("_StrType", bound=str)
LowerCase = Annotated[_StrType, Predicate(str.islower)]
"""
Return True if the string is a lowercase string, False otherwise.
A string is lowercase if all cased characters in the string are lowercase and there is at least one cased character in the string.
""" # noqa: E501
UpperCase = Annotated[_StrType, Predicate(str.isupper)]
"""
Return True if the string is an uppercase string, False otherwise.
A string is uppercase if all cased characters in the string are uppercase and there is at least one cased character in the string.
""" # noqa: E501
IsDigits = Annotated[_StrType, Predicate(str.isdigit)]
"""
Return True if the string is a digit string, False otherwise.
A string is a digit string if all characters in the string are digits and there is at least one character in the string.
""" # noqa: E501
IsAscii = Annotated[_StrType, Predicate(str.isascii)]
"""
Return True if all characters in the string are ASCII, False otherwise.
ASCII characters have code points in the range U+0000-U+007F. Empty string is ASCII too.
"""
_NumericType = TypeVar('_NumericType', bound=Union[SupportsFloat, SupportsIndex])
IsFinite = Annotated[_NumericType, Predicate(math.isfinite)]
"""Return True if x is neither an infinity nor a NaN, and False otherwise."""
IsNotFinite = Annotated[_NumericType, Predicate(Not(math.isfinite))]
"""Return True if x is one of infinity or NaN, and False otherwise"""
IsNan = Annotated[_NumericType, Predicate(math.isnan)]
"""Return True if x is a NaN (not a number), and False otherwise."""
IsNotNan = Annotated[_NumericType, Predicate(Not(math.isnan))]
"""Return True if x is anything but NaN (not a number), and False otherwise."""
IsInfinite = Annotated[_NumericType, Predicate(math.isinf)]
"""Return True if x is a positive or negative infinity, and False otherwise."""
IsNotInfinite = Annotated[_NumericType, Predicate(Not(math.isinf))]
"""Return True if x is neither a positive or negative infinity, and False otherwise."""
try:
from typing_extensions import DocInfo, doc # type: ignore [attr-defined]
except ImportError:
@dataclass(frozen=True, **SLOTS)
class DocInfo: # type: ignore [no-redef]
""" "
The return value of doc(), mainly to be used by tools that want to extract the
Annotated documentation at runtime.
"""
documentation: str
"""The documentation string passed to doc()."""
def doc(
documentation: str,
) -> DocInfo:
"""
Add documentation to a type annotation inside of Annotated.
For example:
>>> def hi(name: Annotated[int, doc("The name of the user")]) -> None: ...
"""
return DocInfo(documentation)

View file

@ -1,147 +0,0 @@
import math
import sys
from datetime import date, datetime, timedelta, timezone
from decimal import Decimal
from typing import Any, Dict, Iterable, Iterator, List, NamedTuple, Set, Tuple
if sys.version_info < (3, 9):
from typing_extensions import Annotated
else:
from typing import Annotated
import annotated_types as at
class Case(NamedTuple):
"""
A test case for `annotated_types`.
"""
annotation: Any
valid_cases: Iterable[Any]
invalid_cases: Iterable[Any]
def cases() -> Iterable[Case]:
# Gt, Ge, Lt, Le
yield Case(Annotated[int, at.Gt(4)], (5, 6, 1000), (4, 0, -1))
yield Case(Annotated[float, at.Gt(0.5)], (0.6, 0.7, 0.8, 0.9), (0.5, 0.0, -0.1))
yield Case(
Annotated[datetime, at.Gt(datetime(2000, 1, 1))],
[datetime(2000, 1, 2), datetime(2000, 1, 3)],
[datetime(2000, 1, 1), datetime(1999, 12, 31)],
)
yield Case(
Annotated[datetime, at.Gt(date(2000, 1, 1))],
[date(2000, 1, 2), date(2000, 1, 3)],
[date(2000, 1, 1), date(1999, 12, 31)],
)
yield Case(
Annotated[datetime, at.Gt(Decimal('1.123'))],
[Decimal('1.1231'), Decimal('123')],
[Decimal('1.123'), Decimal('0')],
)
yield Case(Annotated[int, at.Ge(4)], (4, 5, 6, 1000, 4), (0, -1))
yield Case(Annotated[float, at.Ge(0.5)], (0.5, 0.6, 0.7, 0.8, 0.9), (0.4, 0.0, -0.1))
yield Case(
Annotated[datetime, at.Ge(datetime(2000, 1, 1))],
[datetime(2000, 1, 2), datetime(2000, 1, 3)],
[datetime(1998, 1, 1), datetime(1999, 12, 31)],
)
yield Case(Annotated[int, at.Lt(4)], (0, -1), (4, 5, 6, 1000, 4))
yield Case(Annotated[float, at.Lt(0.5)], (0.4, 0.0, -0.1), (0.5, 0.6, 0.7, 0.8, 0.9))
yield Case(
Annotated[datetime, at.Lt(datetime(2000, 1, 1))],
[datetime(1999, 12, 31), datetime(1999, 12, 31)],
[datetime(2000, 1, 2), datetime(2000, 1, 3)],
)
yield Case(Annotated[int, at.Le(4)], (4, 0, -1), (5, 6, 1000))
yield Case(Annotated[float, at.Le(0.5)], (0.5, 0.0, -0.1), (0.6, 0.7, 0.8, 0.9))
yield Case(
Annotated[datetime, at.Le(datetime(2000, 1, 1))],
[datetime(2000, 1, 1), datetime(1999, 12, 31)],
[datetime(2000, 1, 2), datetime(2000, 1, 3)],
)
# Interval
yield Case(Annotated[int, at.Interval(gt=4)], (5, 6, 1000), (4, 0, -1))
yield Case(Annotated[int, at.Interval(gt=4, lt=10)], (5, 6), (4, 10, 1000, 0, -1))
yield Case(Annotated[float, at.Interval(ge=0.5, le=1)], (0.5, 0.9, 1), (0.49, 1.1))
yield Case(
Annotated[datetime, at.Interval(gt=datetime(2000, 1, 1), le=datetime(2000, 1, 3))],
[datetime(2000, 1, 2), datetime(2000, 1, 3)],
[datetime(2000, 1, 1), datetime(2000, 1, 4)],
)
yield Case(Annotated[int, at.MultipleOf(multiple_of=3)], (0, 3, 9), (1, 2, 4))
yield Case(Annotated[float, at.MultipleOf(multiple_of=0.5)], (0, 0.5, 1, 1.5), (0.4, 1.1))
# lengths
yield Case(Annotated[str, at.MinLen(3)], ('123', '1234', 'x' * 10), ('', '1', '12'))
yield Case(Annotated[str, at.Len(3)], ('123', '1234', 'x' * 10), ('', '1', '12'))
yield Case(Annotated[List[int], at.MinLen(3)], ([1, 2, 3], [1, 2, 3, 4], [1] * 10), ([], [1], [1, 2]))
yield Case(Annotated[List[int], at.Len(3)], ([1, 2, 3], [1, 2, 3, 4], [1] * 10), ([], [1], [1, 2]))
yield Case(Annotated[str, at.MaxLen(4)], ('', '1234'), ('12345', 'x' * 10))
yield Case(Annotated[str, at.Len(0, 4)], ('', '1234'), ('12345', 'x' * 10))
yield Case(Annotated[List[str], at.MaxLen(4)], ([], ['a', 'bcdef'], ['a', 'b', 'c']), (['a'] * 5, ['b'] * 10))
yield Case(Annotated[List[str], at.Len(0, 4)], ([], ['a', 'bcdef'], ['a', 'b', 'c']), (['a'] * 5, ['b'] * 10))
yield Case(Annotated[str, at.Len(3, 5)], ('123', '12345'), ('', '1', '12', '123456', 'x' * 10))
yield Case(Annotated[str, at.Len(3, 3)], ('123',), ('12', '1234'))
yield Case(Annotated[Dict[int, int], at.Len(2, 3)], [{1: 1, 2: 2}], [{}, {1: 1}, {1: 1, 2: 2, 3: 3, 4: 4}])
yield Case(Annotated[Set[int], at.Len(2, 3)], ({1, 2}, {1, 2, 3}), (set(), {1}, {1, 2, 3, 4}))
yield Case(Annotated[Tuple[int, ...], at.Len(2, 3)], ((1, 2), (1, 2, 3)), ((), (1,), (1, 2, 3, 4)))
# Timezone
yield Case(
Annotated[datetime, at.Timezone(None)], [datetime(2000, 1, 1)], [datetime(2000, 1, 1, tzinfo=timezone.utc)]
)
yield Case(
Annotated[datetime, at.Timezone(...)], [datetime(2000, 1, 1, tzinfo=timezone.utc)], [datetime(2000, 1, 1)]
)
yield Case(
Annotated[datetime, at.Timezone(timezone.utc)],
[datetime(2000, 1, 1, tzinfo=timezone.utc)],
[datetime(2000, 1, 1), datetime(2000, 1, 1, tzinfo=timezone(timedelta(hours=6)))],
)
yield Case(
Annotated[datetime, at.Timezone('Europe/London')],
[datetime(2000, 1, 1, tzinfo=timezone(timedelta(0), name='Europe/London'))],
[datetime(2000, 1, 1), datetime(2000, 1, 1, tzinfo=timezone(timedelta(hours=6)))],
)
# predicate types
yield Case(at.LowerCase[str], ['abc', 'foobar'], ['', 'A', 'Boom'])
yield Case(at.UpperCase[str], ['ABC', 'DEFO'], ['', 'a', 'abc', 'AbC'])
yield Case(at.IsDigits[str], ['123'], ['', 'ab', 'a1b2'])
yield Case(at.IsAscii[str], ['123', 'foo bar'], ['£100', '😊', 'whatever 👀'])
yield Case(Annotated[int, at.Predicate(lambda x: x % 2 == 0)], [0, 2, 4], [1, 3, 5])
yield Case(at.IsFinite[float], [1.23], [math.nan, math.inf, -math.inf])
yield Case(at.IsNotFinite[float], [math.nan, math.inf], [1.23])
yield Case(at.IsNan[float], [math.nan], [1.23, math.inf])
yield Case(at.IsNotNan[float], [1.23, math.inf], [math.nan])
yield Case(at.IsInfinite[float], [math.inf], [math.nan, 1.23])
yield Case(at.IsNotInfinite[float], [math.nan, 1.23], [math.inf])
# check stacked predicates
yield Case(at.IsInfinite[Annotated[float, at.Predicate(lambda x: x > 0)]], [math.inf], [-math.inf, 1.23, math.nan])
# doc
yield Case(Annotated[int, at.doc("A number")], [1, 2], [])
# custom GroupedMetadata
class MyCustomGroupedMetadata(at.GroupedMetadata):
def __iter__(self) -> Iterator[at.Predicate]:
yield at.Predicate(lambda x: float(x).is_integer())
yield Case(Annotated[float, MyCustomGroupedMetadata()], [0, 2.0], [0.01, 1.5])

View file

@ -20,7 +20,7 @@ from functools import wraps
from inspect import signature
async def _run_forever_coro(coro, args, kwargs, loop):
def _launch_forever_coro(coro, args, kwargs, loop):
'''
This helper function launches an async main function that was tagged with
forever=True. There are two possibilities:
@ -48,7 +48,7 @@ async def _run_forever_coro(coro, args, kwargs, loop):
# forever=True feature from autoasync at some point in the future.
thing = coro(*args, **kwargs)
if iscoroutine(thing):
await thing
loop.create_task(thing)
def autoasync(coro=None, *, loop=None, forever=False, pass_loop=False):
@ -127,9 +127,7 @@ def autoasync(coro=None, *, loop=None, forever=False, pass_loop=False):
args, kwargs = bound_args.args, bound_args.kwargs
if forever:
local_loop.create_task(_run_forever_coro(
coro, args, kwargs, local_loop
))
_launch_forever_coro(coro, args, kwargs, local_loop)
local_loop.run_forever()
else:
return local_loop.run_until_complete(coro(*args, **kwargs))

View file

@ -452,6 +452,6 @@ class WSGIErrorHandler(logging.Handler):
class LazyRfc3339UtcTime(object):
def __str__(self):
"""Return utcnow() in RFC3339 UTC Format."""
iso_formatted_now = datetime.datetime.utcnow().isoformat('T')
return f'{iso_formatted_now!s}Z'
"""Return now() in RFC3339 UTC Format."""
now = datetime.datetime.now()
return now.isoformat('T') + 'Z'
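# Minimal sketch (illustrative, not part of the patch) of rendering the current
# time as an RFC3339 UTC string:
# import datetime
# now_utc = datetime.datetime.now(datetime.timezone.utc)
# rfc3339 = now_utc.replace(tzinfo=None).isoformat('T') + 'Z'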

View file

@ -622,15 +622,13 @@ def autovary(ignore=None, debug=False):
def convert_params(exception=ValueError, error=400):
"""Convert request params based on function annotations.
"""Convert request params based on function annotations, with error handling.
This function also processes errors that are subclasses of ``exception``.
exception
Exception class to catch.
:param BaseException exception: Exception class to catch.
:type exception: BaseException
:param error: The HTTP status code to return to the client on failure.
:type error: int
status
The HTTP error code to return to the client on failure.
"""
request = cherrypy.serving.request
types = request.handler.callable.__annotations__

View file

@ -47,9 +47,7 @@ try:
import pstats
def new_func_strip_path(func_name):
"""Add ``__init__`` modules' parents.
This makes the profiler output more readable.
"""Make profiler output more readable by adding `__init__` modules' parents
"""
filename, line, name = func_name
if filename.endswith('__init__.py'):

View file

@ -188,7 +188,7 @@ class Parser(configparser.ConfigParser):
def dict_from_file(self, file):
if hasattr(file, 'read'):
self.read_file(file)
self.readfp(file)
else:
self.read(file)
return self.as_dict()

View file

@ -1,18 +1,19 @@
"""Module with helpers for serving static files."""
import mimetypes
import os
import platform
import re
import stat
import unicodedata
import mimetypes
import urllib.parse
import unicodedata
from email.generator import _make_boundary as make_boundary
from io import UnsupportedOperation
import cherrypy
from cherrypy._cpcompat import ntob
from cherrypy.lib import cptools, file_generator_limited, httputil
from cherrypy.lib import cptools, httputil, file_generator_limited
def _setup_mimetypes():
@ -184,10 +185,7 @@ def serve_fileobj(fileobj, content_type=None, disposition=None, name=None,
def _serve_fileobj(fileobj, content_type, content_length, debug=False):
"""Set ``response.body`` to the given file object, perhaps ranged.
Internal helper.
"""
"""Internal. Set response.body to the given file object, perhaps ranged."""
response = cherrypy.serving.response
# HTTP/1.0 didn't have Range/Accept-Ranges headers, or the 206 code

View file

@ -494,7 +494,7 @@ class Bus(object):
"Cannot reconstruct command from '-c'. "
'Ref: https://github.com/cherrypy/cherrypy/issues/1545')
except AttributeError:
"""It looks Py_GetArgcArgv's completely absent in some environments
"""It looks Py_GetArgcArgv is completely absent in some environments
It is known, that there's no Py_GetArgcArgv in MS Windows and
``ctypes`` module is completely absent in Google AppEngine

View file

@ -136,9 +136,6 @@ class HTTPTests(helper.CPWebCase):
self.assertStatus(200)
self.assertBody(b'Hello world!')
response.close()
c.close()
# Now send a message that has no Content-Length, but does send a body.
# Verify that CP times out the socket and responds
# with 411 Length Required.
@ -162,9 +159,6 @@ class HTTPTests(helper.CPWebCase):
self.status = str(response.status)
self.assertStatus(411)
response.close()
c.close()
def test_post_multipart(self):
alphabet = 'abcdefghijklmnopqrstuvwxyz'
# generate file contents for a large post
@ -190,9 +184,6 @@ class HTTPTests(helper.CPWebCase):
parts = ['%s * 65536' % ch for ch in alphabet]
self.assertBody(', '.join(parts))
response.close()
c.close()
def test_post_filename_with_special_characters(self):
"""Testing that we can handle filenames with special characters.
@ -226,9 +217,6 @@ class HTTPTests(helper.CPWebCase):
self.assertStatus(200)
self.assertBody(fname)
response.close()
c.close()
def test_malformed_request_line(self):
if getattr(cherrypy.server, 'using_apache', False):
return self.skip('skipped due to known Apache differences...')
@ -276,9 +264,6 @@ class HTTPTests(helper.CPWebCase):
self.body = response.fp.read(20)
self.assertBody('Illegal header line.')
response.close()
c.close()
def test_http_over_https(self):
if self.scheme != 'https':
return self.skip('skipped (not running HTTPS)... ')

View file

@ -150,8 +150,6 @@ class IteratorTest(helper.CPWebCase):
self.assertStatus(200)
self.assertBody('0')
itr_conn.close()
# Now we do the same check with streaming - some classes will
# be automatically closed, while others cannot.
stream_counts = {}

View file

@ -1,6 +1,5 @@
"""Basic tests for the CherryPy core: request handling."""
import datetime
import logging
from cheroot.test import webtest
@ -198,33 +197,6 @@ def test_custom_log_format(log_tracker, monkeypatch, server):
)
def test_utc_in_timez(monkeypatch):
"""Test that ``LazyRfc3339UtcTime`` is rendered as ``str`` using UTC timestamp."""
utcoffset8_local_time_in_naive_utc = (
datetime.datetime(
year=2020,
month=1,
day=1,
hour=1,
minute=23,
second=45,
tzinfo=datetime.timezone(datetime.timedelta(hours=8)),
)
.astimezone(datetime.timezone.utc)
.replace(tzinfo=None)
)
class mock_datetime:
@classmethod
def utcnow(cls):
return utcoffset8_local_time_in_naive_utc
monkeypatch.setattr('datetime.datetime', mock_datetime)
rfc3339_utc_time = str(cherrypy._cplogging.LazyRfc3339UtcTime())
expected_time = '2019-12-31T17:23:45Z'
assert rfc3339_utc_time == expected_time
def test_timez_log_format(log_tracker, monkeypatch, server):
"""Test a customized access_log_format string, which is a
feature of _cplogging.LogManager.access()."""

View file

@ -3,6 +3,8 @@ inflect: english language inflection
- correctly generate plurals, ordinals, indefinite articles
- convert numbers to words
Copyright (C) 2010 Paul Dyson
Based upon the Perl module
`Lingua::EN::Inflect <https://metacpan.org/pod/Lingua::EN::Inflect>`_.
@ -68,16 +70,11 @@ from typing import (
cast,
Any,
)
from typing_extensions import Literal
from numbers import Number
from pydantic import Field
from typing_extensions import Annotated
from .compat.pydantic1 import validate_call
from .compat.pydantic import same_method
from pydantic import Field, validate_arguments
from pydantic.typing import Annotated
class UnknownClassicalModeError(Exception):
@ -108,6 +105,14 @@ class BadGenderError(Exception):
pass
STDOUT_ON = False
def print3(txt: str) -> None:
if STDOUT_ON:
print(txt)
def enclose(s: str) -> str:
return f"(?:{s})"
@ -1722,44 +1727,66 @@ plverb_irregular_pres = {
"is": "are",
"was": "were",
"were": "were",
"was": "were",
"have": "have",
"have": "have",
"has": "have",
"do": "do",
"do": "do",
"does": "do",
}
plverb_ambiguous_pres = {
"act": "act",
"act": "act",
"acts": "act",
"blame": "blame",
"blame": "blame",
"blames": "blame",
"can": "can",
"can": "can",
"can": "can",
"must": "must",
"must": "must",
"must": "must",
"fly": "fly",
"fly": "fly",
"flies": "fly",
"copy": "copy",
"copy": "copy",
"copies": "copy",
"drink": "drink",
"drink": "drink",
"drinks": "drink",
"fight": "fight",
"fight": "fight",
"fights": "fight",
"fire": "fire",
"fire": "fire",
"fires": "fire",
"like": "like",
"like": "like",
"likes": "like",
"look": "look",
"look": "look",
"looks": "look",
"make": "make",
"make": "make",
"makes": "make",
"reach": "reach",
"reach": "reach",
"reaches": "reach",
"run": "run",
"run": "run",
"runs": "run",
"sink": "sink",
"sink": "sink",
"sinks": "sink",
"sleep": "sleep",
"sleep": "sleep",
"sleeps": "sleep",
"view": "view",
"view": "view",
"views": "view",
}
@ -1827,7 +1854,7 @@ pl_adj_poss_keys = re.compile(fr"^({enclose('|'.join(pl_adj_poss))})$", re.IGNOR
A_abbrev = re.compile(
r"""
^(?! FJO | [HLMNS]Y. | RY[EO] | SQU
(?! FJO | [HLMNS]Y. | RY[EO] | SQU
| ( F[LR]? | [HL] | MN? | N | RH? | S[CHKLMNPTVW]? | X(YL)?) [AEIOU])
[FHLMNRSX][A-Z]
""",
@ -2026,14 +2053,15 @@ Falsish = Any # ideally, falsish would only validate on bool(value) is False
class engine:
def __init__(self) -> None:
self.classical_dict = def_classical.copy()
self.persistent_count: Optional[int] = None
self.mill_count = 0
self.pl_sb_user_defined: List[Optional[Word]] = []
self.pl_v_user_defined: List[Optional[Word]] = []
self.pl_adj_user_defined: List[Optional[Word]] = []
self.si_sb_user_defined: List[Optional[Word]] = []
self.A_a_user_defined: List[Optional[Word]] = []
self.pl_sb_user_defined: List[str] = []
self.pl_v_user_defined: List[str] = []
self.pl_adj_user_defined: List[str] = []
self.si_sb_user_defined: List[str] = []
self.A_a_user_defined: List[str] = []
self.thegender = "neuter"
self.__number_args: Optional[Dict[str, str]] = None
@ -2045,8 +2073,28 @@ class engine:
def _number_args(self, val):
self.__number_args = val
@validate_call
def defnoun(self, singular: Optional[Word], plural: Optional[Word]) -> int:
deprecated_methods = dict(
pl="plural",
plnoun="plural_noun",
plverb="plural_verb",
pladj="plural_adj",
sinoun="single_noun",
prespart="present_participle",
numwords="number_to_words",
plequal="compare",
plnounequal="compare_nouns",
plverbequal="compare_verbs",
pladjequal="compare_adjs",
wordlist="join",
)
def __getattr__(self, meth):
if meth in self.deprecated_methods:
print3(f"{meth}() deprecated, use {self.deprecated_methods[meth]}()")
raise DeprecationWarning
raise AttributeError
def defnoun(self, singular: str, plural: str) -> int:
"""
Set the noun plural of singular to plural.
@ -2057,16 +2105,7 @@ class engine:
self.si_sb_user_defined.extend((plural, singular))
return 1
@validate_call
def defverb(
self,
s1: Optional[Word],
p1: Optional[Word],
s2: Optional[Word],
p2: Optional[Word],
s3: Optional[Word],
p3: Optional[Word],
) -> int:
def defverb(self, s1: str, p1: str, s2: str, p2: str, s3: str, p3: str) -> int:
"""
Set the verb plurals for s1, s2 and s3 to p1, p2 and p3 respectively.
@ -2082,8 +2121,7 @@ class engine:
self.pl_v_user_defined.extend((s1, p1, s2, p2, s3, p3))
return 1
@validate_call
def defadj(self, singular: Optional[Word], plural: Optional[Word]) -> int:
def defadj(self, singular: str, plural: str) -> int:
"""
Set the adjective plural of singular to plural.
@ -2093,8 +2131,7 @@ class engine:
self.pl_adj_user_defined.extend((singular, plural))
return 1
@validate_call
def defa(self, pattern: Optional[Word]) -> int:
def defa(self, pattern: str) -> int:
"""
Define the indefinite article as 'a' for words matching pattern.
@ -2103,8 +2140,7 @@ class engine:
self.A_a_user_defined.extend((pattern, "a"))
return 1
@validate_call
def defan(self, pattern: Optional[Word]) -> int:
def defan(self, pattern: str) -> int:
"""
Define the indefinite article as 'an' for words matching pattern.
@ -2113,7 +2149,7 @@ class engine:
self.A_a_user_defined.extend((pattern, "an"))
return 1
def checkpat(self, pattern: Optional[Word]) -> None:
def checkpat(self, pattern: Optional[str]) -> None:
"""
check for errors in a regex pattern
"""
@ -2122,15 +2158,16 @@ class engine:
try:
re.match(pattern, "")
except re.error:
raise BadUserDefinedPatternError(pattern)
print3(f"\nBad user-defined singular pattern:\n\t{pattern}\n")
raise BadUserDefinedPatternError
def checkpatplural(self, pattern: Optional[Word]) -> None:
def checkpatplural(self, pattern: str) -> None:
"""
check for errors in a regex replace pattern
"""
return
@validate_call
@validate_arguments
def ud_match(self, word: Word, wordlist: Sequence[Optional[Word]]) -> Optional[str]:
for i in range(len(wordlist) - 2, -2, -2): # backwards through even elements
mo = re.search(fr"^{wordlist[i]}$", word, re.IGNORECASE)
@ -2270,7 +2307,7 @@ class engine:
# 0. PERFORM GENERAL INFLECTIONS IN A STRING
@validate_call
@validate_arguments
def inflect(self, text: Word) -> str:
"""
Perform inflections in a string.
@ -2347,7 +2384,7 @@ class engine:
else:
return "", "", ""
@validate_call
@validate_arguments
def plural(self, text: Word, count: Optional[Union[str, int, Any]] = None) -> str:
"""
Return the plural of text.
@ -2371,7 +2408,7 @@ class engine:
)
return f"{pre}{plural}{post}"
@validate_call
@validate_arguments
def plural_noun(
self, text: Word, count: Optional[Union[str, int, Any]] = None
) -> str:
@ -2392,7 +2429,7 @@ class engine:
plural = self.postprocess(word, self._plnoun(word, count))
return f"{pre}{plural}{post}"
@validate_call
@validate_arguments
def plural_verb(
self, text: Word, count: Optional[Union[str, int, Any]] = None
) -> str:
@ -2416,7 +2453,7 @@ class engine:
)
return f"{pre}{plural}{post}"
@validate_call
@validate_arguments
def plural_adj(
self, text: Word, count: Optional[Union[str, int, Any]] = None
) -> str:
@ -2437,7 +2474,7 @@ class engine:
plural = self.postprocess(word, self._pl_special_adjective(word, count) or word)
return f"{pre}{plural}{post}"
@validate_call
@validate_arguments
def compare(self, word1: Word, word2: Word) -> Union[str, bool]:
"""
compare word1 and word2 for equality regardless of plurality
@ -2460,15 +2497,15 @@ class engine:
>>> compare('egg', '')
Traceback (most recent call last):
...
pydantic...ValidationError: ...
...
...at least 1 characters...
pydantic.error_wrappers.ValidationError: 1 validation error for Compare
word2
ensure this value has at least 1 characters...
"""
norms = self.plural_noun, self.plural_verb, self.plural_adj
results = (self._plequal(word1, word2, norm) for norm in norms)
return next(filter(None, results), False)
@validate_call
@validate_arguments
def compare_nouns(self, word1: Word, word2: Word) -> Union[str, bool]:
"""
compare word1 and word2 for equality regardless of plurality
@ -2484,7 +2521,7 @@ class engine:
"""
return self._plequal(word1, word2, self.plural_noun)
@validate_call
@validate_arguments
def compare_verbs(self, word1: Word, word2: Word) -> Union[str, bool]:
"""
compare word1 and word2 for equality regardless of plurality
@ -2500,7 +2537,7 @@ class engine:
"""
return self._plequal(word1, word2, self.plural_verb)
@validate_call
@validate_arguments
def compare_adjs(self, word1: Word, word2: Word) -> Union[str, bool]:
"""
compare word1 and word2 for equality regardless of plurality
@ -2516,13 +2553,13 @@ class engine:
"""
return self._plequal(word1, word2, self.plural_adj)
@validate_call
@validate_arguments
def singular_noun(
self,
text: Word,
count: Optional[Union[int, str, Any]] = None,
gender: Optional[str] = None,
) -> Union[str, Literal[False]]:
) -> Union[str, bool]:
"""
Return the singular of text, where text is a plural noun.
@ -2574,12 +2611,12 @@ class engine:
return "s:p"
self.classical_dict = classval.copy()
if same_method(pl, self.plural) or same_method(pl, self.plural_noun):
if pl == self.plural or pl == self.plural_noun:
if self._pl_check_plurals_N(word1, word2):
return "p:p"
if self._pl_check_plurals_N(word2, word1):
return "p:p"
if same_method(pl, self.plural) or same_method(pl, self.plural_adj):
if pl == self.plural or pl == self.plural_adj:
if self._pl_check_plurals_adj(word1, word2):
return "p:p"
return False
@ -3229,11 +3266,11 @@ class engine:
if words.last in si_sb_irregular_caps:
llen = len(words.last)
return f"{word[:-llen]}{si_sb_irregular_caps[words.last]}"
return "{}{}".format(word[:-llen], si_sb_irregular_caps[words.last])
if words.last.lower() in si_sb_irregular:
llen = len(words.last.lower())
return f"{word[:-llen]}{si_sb_irregular[words.last.lower()]}"
return "{}{}".format(word[:-llen], si_sb_irregular[words.last.lower()])
dash_split = words.lowered.split("-")
if (" ".join(dash_split[-2:])).lower() in si_sb_irregular_compound:
@ -3304,6 +3341,7 @@ class engine:
# HANDLE INCOMPLETELY ASSIMILATED IMPORTS
if self.classical_dict["ancient"]:
if words.lowered[-6:] == "trices":
return word[:-3] + "x"
if words.lowered[-4:] in ("eaux", "ieux"):
@ -3421,6 +3459,7 @@ class engine:
# HANDLE ...o
if words.lowered[-2:] == "os":
if words.last.lower() in si_sb_U_o_os_complete:
return word[:-1]
@ -3450,7 +3489,7 @@ class engine:
# ADJECTIVES
@validate_call
@validate_arguments
def a(self, text: Word, count: Optional[Union[int, str, Any]] = 1) -> str:
"""
Return the appropriate indefinite article followed by text.
@ -3531,7 +3570,7 @@ class engine:
# 2. TRANSLATE ZERO-QUANTIFIED $word TO "no plural($word)"
@validate_call
@validate_arguments
def no(self, text: Word, count: Optional[Union[int, str]] = None) -> str:
"""
If count is 0, no, zero or nil, return 'no' followed by the plural
@ -3569,7 +3608,7 @@ class engine:
# PARTICIPLES
@validate_call
@validate_arguments
def present_participle(self, word: Word) -> str:
"""
Return the present participle for word.
@ -3588,31 +3627,31 @@ class engine:
# NUMERICAL INFLECTIONS
@validate_call(config=dict(arbitrary_types_allowed=True))
def ordinal(self, num: Union[Number, Word]) -> str:
@validate_arguments
def ordinal(self, num: Union[int, Word]) -> str: # noqa: C901
"""
Return the ordinal of num.
>>> ordinal = engine().ordinal
>>> ordinal(1)
'1st'
>>> ordinal('one')
'first'
num can be an integer or text
e.g. ordinal(1) returns '1st'
ordinal('one') returns 'first'
"""
if DIGIT.match(str(num)):
if isinstance(num, (float, int)) and int(num) == num:
if isinstance(num, (int, float)):
n = int(num)
else:
if "." in str(num):
try:
# numbers after decimal,
# so only need last one for ordinal
n = int(str(num)[-1])
n = int(num[-1])
except ValueError: # ends with '.', so need to use whole string
n = int(str(num)[:-1])
n = int(num[:-1])
else:
n = int(num) # type: ignore
n = int(num)
try:
post = nth[n % 100]
except KeyError:
@ -3621,7 +3660,7 @@ class engine:
else:
# Mad props to Damian Conway (?) whose ordinal()
# algorithm is type-bendy enough to foil MyPy
str_num: str = num # type: ignore[assignment]
str_num: str = num # type: ignore[assignment]
mo = ordinal_suff.search(str_num)
if mo:
post = ordinal[mo.group(1)]
@ -3632,6 +3671,7 @@ class engine:
def millfn(self, ind: int = 0) -> str:
if ind > len(mill) - 1:
print3("number out of range")
raise NumOutOfRangeError
return mill[ind]
@ -3747,7 +3787,7 @@ class engine:
num = ONE_DIGIT_WORD.sub(self.unitsub, num, 1)
return num
@validate_call(config=dict(arbitrary_types_allowed=True)) # noqa: C901
@validate_arguments(config=dict(arbitrary_types_allowed=True)) # noqa: C901
def number_to_words( # noqa: C901
self,
num: Union[Number, Word],
@ -3899,7 +3939,7 @@ class engine:
# Join words with commas and a trailing 'and' (when appropriate)...
@validate_call
@validate_arguments
def join(
self,
words: Optional[Sequence[Word]],

View file

@ -1,19 +0,0 @@
class ValidateCallWrapperWrapper:
def __init__(self, wrapped):
self.orig = wrapped
def __eq__(self, other):
return self.raw_function == other.raw_function
@property
def raw_function(self):
return getattr(self.orig, 'raw_function', None) or self.orig
def same_method(m1, m2) -> bool:
"""
Return whether m1 and m2 are the same method.
Workaround for pydantic/pydantic#6390.
"""
return ValidateCallWrapperWrapper(m1) == ValidateCallWrapperWrapper(m2)
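For readers tracing the revert above: the inflect.py hunks swap direct bound-method comparisons (pl == self.plural) for same_method(pl, self.plural) so the check keeps working when a validate_call-style decorator wraps each attribute access in a fresh object. A minimal sketch of what same_method tolerates, using a hypothetical stand-in wrapper (FakeWrapper is not real pydantic) together with same_method from the module above:
# Illustrative sketch only - FakeWrapper is a made-up stand-in that exposes
# the raw_function attribute same_method() relies on.
def plural(text):
    return text + "s"

class FakeWrapper:
    def __init__(self, fn):
        self.raw_function = fn

    def __call__(self, *args, **kwargs):
        return self.raw_function(*args, **kwargs)

a = FakeWrapper(plural)
b = FakeWrapper(plural)       # a second wrapper around the same function
print(a == b)                 # False: two distinct wrapper objects
print(same_method(a, b))      # True: both unwrap to the same raw_function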

View file

@ -1,8 +0,0 @@
try:
from pydantic import validate_call # type: ignore
except ImportError:
# Pydantic 1
from pydantic import validate_arguments as validate_call # type: ignore
__all__ = ['validate_call']
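A hedged usage sketch for this shim (the import path below is illustrative; the module normally lives inside the library's own compat package): decorate a function once and whichever pydantic major version is installed validates the annotated arguments.
from compat.pydantic import validate_call  # hypothetical import path

@validate_call
def repeat(word: str, times: int) -> str:
    return " ".join([word] * times)

print(repeat("spam", 3))  # 'spam spam spam'
# Arguments that fail validation (e.g. times=None) raise pydantic's
# ValidationError under both pydantic 1 and pydantic 2.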

View file

@ -1,7 +0,0 @@
A Python ISAPI extension. Contributed by Phillip Frantz, and is
Copyright 2002-2003 by Blackdog Software Pty Ltd.
See the 'samples' directory, and particularly samples\README.txt
You can find documentation in the PyWin32.chm file that comes with pywin32 -
you can open this from Pythonwin->Help, or from the start menu.

View file

@ -1,39 +0,0 @@
# The Python ISAPI package.
# Exceptions thrown by the DLL framework.
class ISAPIError(Exception):
def __init__(self, errno, strerror=None, funcname=None):
# named attributes match IOError etc.
self.errno = errno
self.strerror = strerror
self.funcname = funcname
Exception.__init__(self, errno, strerror, funcname)
def __str__(self):
if self.strerror is None:
try:
import win32api
self.strerror = win32api.FormatMessage(self.errno).strip()
except:
self.strerror = "no error message is available"
# str() looks like a win32api error.
return str((self.errno, self.strerror, self.funcname))
class FilterError(ISAPIError):
pass
class ExtensionError(ISAPIError):
pass
# A little development aid - a filter or extension callback function can
# raise one of these exceptions, and the handler module will be reloaded.
# This means you can change your code without restarting IIS.
# After a reload, your filter/extension will have the GetFilterVersion/
# GetExtensionVersion function called, but with None as the first arg.
class InternalReloadException(Exception):
pass

View file

@ -1,92 +0,0 @@
<!-- NOTE: This HTML is displayed inside the CHM file - hence some hrefs
will only work in that environment
-->
<HTML>
<BODY>
<TITLE>Introduction to Python ISAPI support</TITLE>
<h2>Introduction to Python ISAPI support</h2>
<h3>See also</h3>
<ul>
<li><a href="/isapi_modules.html">The isapi related modules</a>
</li>
<li><a href="/isapi_objects.html">The isapi related objects</a>
</li>
</ul>
<p><i>Note: if you are viewing this documentation directly from disk,
most links in this document will fail - you can also find this document in the
CHM file that comes with pywin32, where the links will work</i>
<h3>Introduction</h3>
This documents Python support for hosting ISAPI extensions and filters inside
Microsoft Internet Information Server (IIS). It assumes a basic understanding
of the ISAPI filter and extension mechanism.
<p>
In summary, to implement a filter or extension, you provide a Python module
which defines a Filter and/or Extension class. Once your class has been
loaded, IIS/ISAPI will, via an extension DLL, call methods on your class.
<p>
A filter or extension class instance need only provide 3 methods - for filters they
are called <code>GetFilterVersion</code>, <code>HttpFilterProc</code> and
<code>TerminateFilter</code>. For extensions they
are named <code>GetExtensionVersion</code>, <code>HttpExtensionProc</code> and
<code>TerminateExtension</code>. If you are familiar with writing ISAPI
extensions in C/C++, these names and their purpose will be familiar.
<p>
Most of the work is done in the <code>HttpFilterProc</code> and
<code>HttpExtensionProc</code> methods. These both take a single
parameter - an <a href="/HTTP_FILTER_CONTEXT.html">HTTP_FILTER_CONTEXT</a> and
<a href="/EXTENSION_CONTROL_BLOCK.html">EXTENSION_CONTROL_BLOCK</a>
object respectively.
<p>
In addition to these components, there is an 'isapi' package, containing
support facilities (base-classes, exceptions, etc) which can be leveraged
by the extension.
<h4>Base classes</h4>
There are a number of base classes provided to make writing extensions a little
simpler. Of particular note is <code>isapi.threaded_extension.ThreadPoolExtension</code>.
This implements a thread-pool and informs IIS that the request is progressing
in the background. Your sub-class need only provide a <code>Dispatch</code>
method, which is called on one of the worker threads rather than the thread
that the request came in on.
<p>
There is a base-class for a filter in <code>isapi.simple</code>, but there is no
equivalent threaded filter - filters work under a different model, where
background processing is not possible.
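<p>Assuming the isapi package from pywin32 is installed and the loader DLL
described below is in place, a minimal thread-pooled extension might look
roughly like this condensed sketch (the response text is an arbitrary
placeholder; complete versions live in the samples directory):
<pre>
from isapi import isapicon, threaded_extension

class Extension(threaded_extension.ThreadPoolExtension):
    "Minimal sketch of a thread-pooled extension"
    def Dispatch(self, ecb):
        # Called on a worker thread for each request.
        ecb.SendResponseHeaders("200 OK", "Content-Type: text/plain\r\n\r\n", False)
        ecb.WriteClient(b"hello from the Python extension")
        ecb.DoneWithSession()
        return isapicon.HSE_STATUS_SUCCESS

def __ExtensionFactory__():
    return Extension()
</pre>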
<h4>Samples</h4>
Please see the <code>isapi/samples</code> directory for some sample filters
and extensions.
<H3>Implementation</H3>
A Python ISAPI filter extension consists of 2 main components:
<UL>
<LI>A DLL used by ISAPI to interface with Python.</LI>
<LI>A Python script used by that DLL to implement the filter or extension
functionality</LI>
</UL>
<h4>Extension DLL</h4>
The DLL is usually managed automatically by the isapi.install module. As the
Python script for the extension is installed, a generic DLL provided with
the isapi package is installed next to the script, and IIS is configured to
use this DLL.
<p>
The name of the DLL always has the same base name as the Python script, but
with a leading underscore (_), and an extension of .dll. For example, the
sample "redirector.py" will, when installed, have "_redirector.dll" created
in the same directory.
<p/>
The Python script may provide 2 entry points - methods named __FilterFactory__
and __ExtensionFactory__, both taking no arguments and returning a filter or
extension object.
<h3>Using py2exe and the isapi package</h3>
You can instruct py2exe to create a 'frozen' Python ISAPI filter/extension.
In this case, py2exe will create a package with everything you need in one
directory, and the Python source file embedded in the .zip file.
<p>
In general, you will want to build a separate installation executable along
with the ISAPI extension. This executable will be built from the same script.
See the ISAPI sample in the py2exe distribution.

View file

@ -1,815 +0,0 @@
"""Installation utilities for Python ISAPI filters and extensions."""
# this code adapted from "Tomcat JK2 ISAPI redirector", part of Apache
# Created July 2004, Mark Hammond.
import imp
import os
import shutil
import stat
import sys
import traceback
import pythoncom
import win32api
import winerror
from win32com.client import Dispatch, GetObject
from win32com.client.gencache import EnsureDispatch, EnsureModule
_APP_INPROC = 0
_APP_OUTPROC = 1
_APP_POOLED = 2
_IIS_OBJECT = "IIS://LocalHost/W3SVC"
_IIS_SERVER = "IIsWebServer"
_IIS_WEBDIR = "IIsWebDirectory"
_IIS_WEBVIRTUALDIR = "IIsWebVirtualDir"
_IIS_FILTERS = "IIsFilters"
_IIS_FILTER = "IIsFilter"
_DEFAULT_SERVER_NAME = "Default Web Site"
_DEFAULT_HEADERS = "X-Powered-By: Python"
_DEFAULT_PROTECTION = _APP_POOLED
# Default is for 'execute' only access - ie, only the extension
# can be used. This can be overridden via your install script.
_DEFAULT_ACCESS_EXECUTE = True
_DEFAULT_ACCESS_READ = False
_DEFAULT_ACCESS_WRITE = False
_DEFAULT_ACCESS_SCRIPT = False
_DEFAULT_CONTENT_INDEXED = False
_DEFAULT_ENABLE_DIR_BROWSING = False
_DEFAULT_ENABLE_DEFAULT_DOC = False
_extensions = [ext for ext, _, _ in imp.get_suffixes()]
is_debug_build = "_d.pyd" in _extensions
this_dir = os.path.abspath(os.path.dirname(__file__))
class FilterParameters:
Name = None
Description = None
Path = None
Server = None
# Params that control if/how AddExtensionFile is called.
AddExtensionFile = True
AddExtensionFile_Enabled = True
AddExtensionFile_GroupID = None # defaults to Name
AddExtensionFile_CanDelete = True
AddExtensionFile_Description = None # defaults to Description.
def __init__(self, **kw):
self.__dict__.update(kw)
class VirtualDirParameters:
Name = None # Must be provided.
Description = None # defaults to Name
AppProtection = _DEFAULT_PROTECTION
Headers = _DEFAULT_HEADERS
Path = None # defaults to WWW root.
Type = _IIS_WEBVIRTUALDIR
AccessExecute = _DEFAULT_ACCESS_EXECUTE
AccessRead = _DEFAULT_ACCESS_READ
AccessWrite = _DEFAULT_ACCESS_WRITE
AccessScript = _DEFAULT_ACCESS_SCRIPT
ContentIndexed = _DEFAULT_CONTENT_INDEXED
EnableDirBrowsing = _DEFAULT_ENABLE_DIR_BROWSING
EnableDefaultDoc = _DEFAULT_ENABLE_DEFAULT_DOC
DefaultDoc = None # Only set in IIS if not None
ScriptMaps = []
ScriptMapUpdate = "end" # can be 'start', 'end', 'replace'
Server = None
def __init__(self, **kw):
self.__dict__.update(kw)
def is_root(self):
"This virtual directory is a root directory if parent and name are blank"
parent, name = self.split_path()
return not parent and not name
def split_path(self):
return split_path(self.Name)
class ScriptMapParams:
Extension = None
Module = None
Flags = 5
Verbs = ""
# Params that control if/how AddExtensionFile is called.
AddExtensionFile = True
AddExtensionFile_Enabled = True
AddExtensionFile_GroupID = None # defaults to Name
AddExtensionFile_CanDelete = True
AddExtensionFile_Description = None # defaults to Description.
def __init__(self, **kw):
self.__dict__.update(kw)
def __str__(self):
"Format this parameter suitable for IIS"
items = [self.Extension, self.Module, self.Flags]
# IIS gets upset if there is a trailing verb comma, but no verbs
if self.Verbs:
items.append(self.Verbs)
items = [str(item) for item in items]
return ",".join(items)
class ISAPIParameters:
ServerName = _DEFAULT_SERVER_NAME
# Description = None
Filters = []
VirtualDirs = []
def __init__(self, **kw):
self.__dict__.update(kw)
verbose = 1 # The level - 0 is quiet.
def log(level, what):
if verbose >= level:
print(what)
# Convert an ADSI COM exception to the Win32 error code embedded in it.
def _GetWin32ErrorCode(com_exc):
hr = com_exc.hresult
# If we have more details in the 'excepinfo' struct, use it.
if com_exc.excepinfo:
hr = com_exc.excepinfo[-1]
if winerror.HRESULT_FACILITY(hr) != winerror.FACILITY_WIN32:
raise
return winerror.SCODE_CODE(hr)
class InstallationError(Exception):
pass
class ItemNotFound(InstallationError):
pass
class ConfigurationError(InstallationError):
pass
def FindPath(options, server, name):
if name.lower().startswith("iis://"):
return name
else:
if name and name[0] != "/":
name = "/" + name
return FindWebServer(options, server) + "/ROOT" + name
def LocateWebServerPath(description):
"""
Find an IIS web server whose name or comment matches the provided
description (case-insensitive).
>>> LocateWebServerPath('Default Web Site') # doctest: +SKIP
or
>>> LocateWebServerPath('1') #doctest: +SKIP
"""
assert len(description) >= 1, "Server name or comment is required"
iis = GetObject(_IIS_OBJECT)
description = description.lower().strip()
for site in iis:
# Name is generally a number, but no need to assume that.
site_attributes = [
getattr(site, attr, "").lower().strip()
for attr in ("Name", "ServerComment")
]
if description in site_attributes:
return site.AdsPath
msg = "No web sites match the description '%s'" % description
raise ItemNotFound(msg)
def GetWebServer(description=None):
"""
Load the web server instance (COM object) for a given instance
or description.
If None is specified, the default website is retrieved (indicated
by the identifier 1).
"""
description = description or "1"
path = LocateWebServerPath(description)
server = LoadWebServer(path)
return server
def LoadWebServer(path):
try:
server = GetObject(path)
except pythoncom.com_error as details:
msg = details.strerror
if details.excepinfo and details.excepinfo[2]:
msg = details.excepinfo[2]
msg = "WebServer %s: %s" % (path, msg)
raise ItemNotFound(msg)
return server
def FindWebServer(options, server_desc):
"""
Legacy function to allow options to define a .server property
to override the other parameter. Use GetWebServer instead.
"""
# options takes precedence
server_desc = options.server or server_desc
# make sure server_desc is unicode (could be mbcs if passed in
# sys.argv).
if server_desc and not isinstance(server_desc, str):
server_desc = server_desc.decode("mbcs")
# get the server (if server_desc is None, the default site is acquired)
server = GetWebServer(server_desc)
return server.adsPath
def split_path(path):
"""
Get the parent path and basename.
>>> split_path('/')
['', '']
>>> split_path('')
['', '']
>>> split_path('foo')
['', 'foo']
>>> split_path('/foo')
['', 'foo']
>>> split_path('/foo/bar')
['/foo', 'bar']
>>> split_path('foo/bar')
['/foo', 'bar']
"""
if not path.startswith("/"):
path = "/" + path
return path.rsplit("/", 1)
def _CreateDirectory(iis_dir, name, params):
# We used to go to lengths to keep an existing virtual directory
# in place. However, in some cases the existing directories got
# into a bad state, and an update failed to get them working.
# So we nuke it first. If this is a problem, we could consider adding
# a --keep-existing option.
try:
# Also seen the Class change to a generic IISObject - so nuke
# *any* existing object, regardless of Class
assert name.strip("/"), "mustn't delete the root!"
iis_dir.Delete("", name)
log(2, "Deleted old directory '%s'" % (name,))
except pythoncom.com_error:
pass
newDir = iis_dir.Create(params.Type, name)
log(2, "Creating new directory '%s' in %s..." % (name, iis_dir.Name))
friendly = params.Description or params.Name
newDir.AppFriendlyName = friendly
# Note that the new directory won't be visible in the IIS UI
# unless the directory exists on the filesystem.
try:
path = params.Path or iis_dir.Path
newDir.Path = path
except AttributeError:
# If params.Type is IIS_WEBDIRECTORY, an exception is thrown
pass
newDir.AppCreate2(params.AppProtection)
# XXX - note that these Headers only work in IIS6 and earlier. IIS7
# only supports them on the w3svc node - not even on individual sites,
# let alone individual extensions in the site!
if params.Headers:
newDir.HttpCustomHeaders = params.Headers
log(2, "Setting directory options...")
newDir.AccessExecute = params.AccessExecute
newDir.AccessRead = params.AccessRead
newDir.AccessWrite = params.AccessWrite
newDir.AccessScript = params.AccessScript
newDir.ContentIndexed = params.ContentIndexed
newDir.EnableDirBrowsing = params.EnableDirBrowsing
newDir.EnableDefaultDoc = params.EnableDefaultDoc
if params.DefaultDoc is not None:
newDir.DefaultDoc = params.DefaultDoc
newDir.SetInfo()
return newDir
def CreateDirectory(params, options):
_CallHook(params, "PreInstall", options)
if not params.Name:
raise ConfigurationError("No Name param")
parent, name = params.split_path()
target_dir = GetObject(FindPath(options, params.Server, parent))
if not params.is_root():
target_dir = _CreateDirectory(target_dir, name, params)
AssignScriptMaps(params.ScriptMaps, target_dir, params.ScriptMapUpdate)
_CallHook(params, "PostInstall", options, target_dir)
log(1, "Configured Virtual Directory: %s" % (params.Name,))
return target_dir
def AssignScriptMaps(script_maps, target, update="replace"):
"""Updates IIS with the supplied script map information.
script_maps is a list of ScriptMapParameter objects
target is an IIS Virtual Directory to assign the script maps to
update is a string indicating how to update the maps, one of ('start',
'end', or 'replace')
"""
# determine which function to use to assign script maps
script_map_func = "_AssignScriptMaps" + update.capitalize()
try:
script_map_func = eval(script_map_func)
except NameError:
msg = "Unknown ScriptMapUpdate option '%s'" % update
raise ConfigurationError(msg)
# use the str method to format the script maps for IIS
script_maps = [str(s) for s in script_maps]
# call the correct function
script_map_func(target, script_maps)
target.SetInfo()
def get_unique_items(sequence, reference):
"Return items in sequence that can't be found in reference."
return tuple([item for item in sequence if item not in reference])
def _AssignScriptMapsReplace(target, script_maps):
target.ScriptMaps = script_maps
def _AssignScriptMapsEnd(target, script_maps):
unique_new_maps = get_unique_items(script_maps, target.ScriptMaps)
target.ScriptMaps = target.ScriptMaps + unique_new_maps
def _AssignScriptMapsStart(target, script_maps):
unique_new_maps = get_unique_items(script_maps, target.ScriptMaps)
target.ScriptMaps = unique_new_maps + target.ScriptMaps
def CreateISAPIFilter(filterParams, options):
server = FindWebServer(options, filterParams.Server)
_CallHook(filterParams, "PreInstall", options)
try:
filters = GetObject(server + "/Filters")
except pythoncom.com_error as exc:
# Brand new sites don't have the '/Filters' collection - create it.
# Any errors other than 'not found' we shouldn't ignore.
if (
winerror.HRESULT_FACILITY(exc.hresult) != winerror.FACILITY_WIN32
or winerror.HRESULT_CODE(exc.hresult) != winerror.ERROR_PATH_NOT_FOUND
):
raise
server_ob = GetObject(server)
filters = server_ob.Create(_IIS_FILTERS, "Filters")
filters.FilterLoadOrder = ""
filters.SetInfo()
# As for VirtualDir, delete an existing one.
assert filterParams.Name.strip("/"), "mustn't delete the root!"
try:
filters.Delete(_IIS_FILTER, filterParams.Name)
log(2, "Deleted old filter '%s'" % (filterParams.Name,))
except pythoncom.com_error:
pass
newFilter = filters.Create(_IIS_FILTER, filterParams.Name)
log(2, "Created new ISAPI filter...")
assert os.path.isfile(filterParams.Path)
newFilter.FilterPath = filterParams.Path
newFilter.FilterDescription = filterParams.Description
newFilter.SetInfo()
load_order = [b.strip() for b in filters.FilterLoadOrder.split(",") if b]
if filterParams.Name not in load_order:
load_order.append(filterParams.Name)
filters.FilterLoadOrder = ",".join(load_order)
filters.SetInfo()
_CallHook(filterParams, "PostInstall", options, newFilter)
log(1, "Configured Filter: %s" % (filterParams.Name,))
return newFilter
def DeleteISAPIFilter(filterParams, options):
_CallHook(filterParams, "PreRemove", options)
server = FindWebServer(options, filterParams.Server)
ob_path = server + "/Filters"
try:
filters = GetObject(ob_path)
except pythoncom.com_error as details:
# failure to open the filters just means a totally clean IIS install
# (IIS5 at least has no 'Filters' key when freshly installed).
log(2, "ISAPI filter path '%s' did not exist." % (ob_path,))
return
try:
assert filterParams.Name.strip("/"), "mustn't delete the root!"
filters.Delete(_IIS_FILTER, filterParams.Name)
log(2, "Deleted ISAPI filter '%s'" % (filterParams.Name,))
except pythoncom.com_error as details:
rc = _GetWin32ErrorCode(details)
if rc != winerror.ERROR_PATH_NOT_FOUND:
raise
log(2, "ISAPI filter '%s' did not exist." % (filterParams.Name,))
# Remove from the load order
load_order = [b.strip() for b in filters.FilterLoadOrder.split(",") if b]
if filterParams.Name in load_order:
load_order.remove(filterParams.Name)
filters.FilterLoadOrder = ",".join(load_order)
filters.SetInfo()
_CallHook(filterParams, "PostRemove", options)
log(1, "Deleted Filter: %s" % (filterParams.Name,))
def _AddExtensionFile(module, def_groupid, def_desc, params, options):
group_id = params.AddExtensionFile_GroupID or def_groupid
desc = params.AddExtensionFile_Description or def_desc
try:
ob = GetObject(_IIS_OBJECT)
ob.AddExtensionFile(
module,
params.AddExtensionFile_Enabled,
group_id,
params.AddExtensionFile_CanDelete,
desc,
)
log(2, "Added extension file '%s' (%s)" % (module, desc))
except (pythoncom.com_error, AttributeError) as details:
# IIS5 always fails. Probably should upgrade this to
# complain more loudly if IIS6 fails.
log(2, "Failed to add extension file '%s': %s" % (module, details))
def AddExtensionFiles(params, options):
"""Register the modules used by the filters/extensions as a trusted
'extension module' - required by the default IIS6 security settings."""
# Add each module only once.
added = {}
for vd in params.VirtualDirs:
for smp in vd.ScriptMaps:
if smp.Module not in added and smp.AddExtensionFile:
_AddExtensionFile(smp.Module, vd.Name, vd.Description, smp, options)
added[smp.Module] = True
for fd in params.Filters:
if fd.Path not in added and fd.AddExtensionFile:
_AddExtensionFile(fd.Path, fd.Name, fd.Description, fd, options)
added[fd.Path] = True
def _DeleteExtensionFileRecord(module, options):
try:
ob = GetObject(_IIS_OBJECT)
ob.DeleteExtensionFileRecord(module)
log(2, "Deleted extension file record for '%s'" % module)
except (pythoncom.com_error, AttributeError) as details:
log(2, "Failed to remove extension file '%s': %s" % (module, details))
def DeleteExtensionFileRecords(params, options):
deleted = {} # only remove each .dll once.
for vd in params.VirtualDirs:
for smp in vd.ScriptMaps:
if smp.Module not in deleted and smp.AddExtensionFile:
_DeleteExtensionFileRecord(smp.Module, options)
deleted[smp.Module] = True
for filter_def in params.Filters:
if filter_def.Path not in deleted and filter_def.AddExtensionFile:
_DeleteExtensionFileRecord(filter_def.Path, options)
deleted[filter_def.Path] = True
def CheckLoaderModule(dll_name):
suffix = ""
if is_debug_build:
suffix = "_d"
template = os.path.join(this_dir, "PyISAPI_loader" + suffix + ".dll")
if not os.path.isfile(template):
raise ConfigurationError("Template loader '%s' does not exist" % (template,))
# We can't do a simple "is newer" check, as the DLL is specific to the
# Python version. So we check the date-time and size are identical,
# and skip the copy in that case.
src_stat = os.stat(template)
try:
dest_stat = os.stat(dll_name)
except os.error:
same = 0
else:
same = (
src_stat[stat.ST_SIZE] == dest_stat[stat.ST_SIZE]
and src_stat[stat.ST_MTIME] == dest_stat[stat.ST_MTIME]
)
if not same:
log(2, "Updating %s->%s" % (template, dll_name))
shutil.copyfile(template, dll_name)
shutil.copystat(template, dll_name)
else:
log(2, "%s is up to date." % (dll_name,))
def _CallHook(ob, hook_name, options, *extra_args):
func = getattr(ob, hook_name, None)
if func is not None:
args = (ob, options) + extra_args
func(*args)
def Install(params, options):
_CallHook(params, "PreInstall", options)
for vd in params.VirtualDirs:
CreateDirectory(vd, options)
for filter_def in params.Filters:
CreateISAPIFilter(filter_def, options)
AddExtensionFiles(params, options)
_CallHook(params, "PostInstall", options)
def RemoveDirectory(params, options):
if params.is_root():
return
try:
directory = GetObject(FindPath(options, params.Server, params.Name))
except pythoncom.com_error as details:
rc = _GetWin32ErrorCode(details)
if rc != winerror.ERROR_PATH_NOT_FOUND:
raise
log(2, "VirtualDirectory '%s' did not exist" % params.Name)
directory = None
if directory is not None:
# Be robust should IIS get upset about unloading.
try:
directory.AppUnLoad()
except:
exc_val = sys.exc_info()[1]
log(2, "AppUnLoad() for %s failed: %s" % (params.Name, exc_val))
# Continue trying to delete it.
try:
parent = GetObject(directory.Parent)
parent.Delete(directory.Class, directory.Name)
log(1, "Deleted Virtual Directory: %s" % (params.Name,))
except:
exc_val = sys.exc_info()[1]
log(1, "Failed to remove directory %s: %s" % (params.Name, exc_val))
def RemoveScriptMaps(vd_params, options):
"Remove script maps from the already installed virtual directory"
parent, name = vd_params.split_path()
target_dir = GetObject(FindPath(options, vd_params.Server, parent))
installed_maps = list(target_dir.ScriptMaps)
for _map in map(str, vd_params.ScriptMaps):
if _map in installed_maps:
installed_maps.remove(_map)
target_dir.ScriptMaps = installed_maps
target_dir.SetInfo()
def Uninstall(params, options):
_CallHook(params, "PreRemove", options)
DeleteExtensionFileRecords(params, options)
for vd in params.VirtualDirs:
_CallHook(vd, "PreRemove", options)
RemoveDirectory(vd, options)
if vd.is_root():
# if this is installed to the root virtual directory, we can't delete it
# so remove the script maps.
RemoveScriptMaps(vd, options)
_CallHook(vd, "PostRemove", options)
for filter_def in params.Filters:
DeleteISAPIFilter(filter_def, options)
_CallHook(params, "PostRemove", options)
# Patch up any missing module names in the params, replacing them with
# the DLL name that hosts this extension/filter.
def _PatchParamsModule(params, dll_name, file_must_exist=True):
if file_must_exist:
if not os.path.isfile(dll_name):
raise ConfigurationError("%s does not exist" % (dll_name,))
# Patch up all references to the DLL.
for f in params.Filters:
if f.Path is None:
f.Path = dll_name
for d in params.VirtualDirs:
for sm in d.ScriptMaps:
if sm.Module is None:
sm.Module = dll_name
def GetLoaderModuleName(mod_name, check_module=None):
# find the name of the DLL hosting us.
# By default, this is "_{module_base_name}.dll"
if hasattr(sys, "frozen"):
# What to do? The .dll knows its name, but this is likely to be
# executed via a .exe, which does not know.
base, ext = os.path.splitext(mod_name)
path, base = os.path.split(base)
# handle the common case of 'foo.exe'/'foow.exe'
if base.endswith("w"):
base = base[:-1]
# For py2exe, we have '_foo.dll' as the standard pyisapi loader - but
# 'foo.dll' is what we use (it just delegates).
# So no leading '_' on the installed name.
dll_name = os.path.abspath(os.path.join(path, base + ".dll"))
else:
base, ext = os.path.splitext(mod_name)
path, base = os.path.split(base)
dll_name = os.path.abspath(os.path.join(path, "_" + base + ".dll"))
# Check we actually have it.
if check_module is None:
check_module = not hasattr(sys, "frozen")
if check_module:
CheckLoaderModule(dll_name)
return dll_name
# Note the 'log' params to these 'builtin' args - old versions of pywin32
# didn't log at all in this function (by intent; anyone calling this was
# responsible). So existing code that calls this function with the old
# signature (ie, without a 'log' param) still gets the same behaviour as
# before...
def InstallModule(conf_module_name, params, options, log=lambda *args: None):
"Install the extension"
if not hasattr(sys, "frozen"):
conf_module_name = os.path.abspath(conf_module_name)
if not os.path.isfile(conf_module_name):
raise ConfigurationError("%s does not exist" % (conf_module_name,))
loader_dll = GetLoaderModuleName(conf_module_name)
_PatchParamsModule(params, loader_dll)
Install(params, options)
log(1, "Installation complete.")
def UninstallModule(conf_module_name, params, options, log=lambda *args: None):
"Remove the extension"
loader_dll = GetLoaderModuleName(conf_module_name, False)
_PatchParamsModule(params, loader_dll, False)
Uninstall(params, options)
log(1, "Uninstallation complete.")
standard_arguments = {
"install": InstallModule,
"remove": UninstallModule,
}
def build_usage(handler_map):
docstrings = [handler.__doc__ for handler in handler_map.values()]
all_args = dict(zip(iter(handler_map.keys()), docstrings))
arg_names = "|".join(iter(all_args.keys()))
usage_string = "%prog [options] [" + arg_names + "]\n"
usage_string += "commands:\n"
for arg, desc in all_args.items():
usage_string += " %-10s: %s" % (arg, desc) + "\n"
return usage_string[:-1]
def MergeStandardOptions(options, params):
"""
Take an options object generated by the command line and merge
the values into the IISParameters object.
"""
pass
# We support 2 ways of extending our command-line/install support.
# * Many of the installation items allow you to specify "PreInstall",
# "PostInstall", "PreRemove" and "PostRemove" hooks
# All hooks are called with the 'params' object being operated on, and
# the 'optparser' options for this session (ie, the command-line options)
# PostInstall for VirtualDirectories and Filters both have an additional
# param - the ADSI object just created.
# * You can pass your own option parser for us to use, and/or define a map
# with your own custom arg handlers. It is a map of 'arg'->function.
# The function is called with (options, log_fn, arg). The function's
# docstring is used in the usage output.
def HandleCommandLine(
params,
argv=None,
conf_module_name=None,
default_arg="install",
opt_parser=None,
custom_arg_handlers={},
):
"""Perform installation or removal of an ISAPI filter or extension.
This module handles standard command-line options and configuration
information, and installs, removes or updates the configuration of an
ISAPI filter or extension.
You must pass your configuration information in params - all other
arguments are optional, and allow you to configure the installation
process.
"""
global verbose
from optparse import OptionParser
argv = argv or sys.argv
if not conf_module_name:
conf_module_name = sys.argv[0]
# convert to a long name so that if we were somehow registered with
# the "short" version but unregistered with the "long" version we
# still work (that will depend on exactly how the installer was
# started)
try:
conf_module_name = win32api.GetLongPathName(conf_module_name)
except win32api.error as exc:
log(
2,
"Couldn't determine the long name for %r: %s" % (conf_module_name, exc),
)
if opt_parser is None:
# Build our own parser.
parser = OptionParser(usage="")
else:
# The caller is providing their own filter, presumably with their
# own options all setup.
parser = opt_parser
# build a usage string if we don't have one.
if not parser.get_usage():
all_handlers = standard_arguments.copy()
all_handlers.update(custom_arg_handlers)
parser.set_usage(build_usage(all_handlers))
# allow the user to use uninstall as a synonym for remove if it wasn't
# defined by the custom arg handlers.
all_handlers.setdefault("uninstall", all_handlers["remove"])
parser.add_option(
"-q",
"--quiet",
action="store_false",
dest="verbose",
default=True,
help="don't print status messages to stdout",
)
parser.add_option(
"-v",
"--verbosity",
action="count",
dest="verbose",
default=1,
help="increase the verbosity of status messages",
)
parser.add_option(
"",
"--server",
action="store",
help="Specifies the IIS server to install/uninstall on."
" Default is '%s/1'" % (_IIS_OBJECT,),
)
(options, args) = parser.parse_args(argv[1:])
MergeStandardOptions(options, params)
verbose = options.verbose
if not args:
args = [default_arg]
try:
for arg in args:
handler = all_handlers[arg]
handler(conf_module_name, params, options, log)
except (ItemNotFound, InstallationError) as details:
if options.verbose > 1:
traceback.print_exc()
print("%s: %s" % (details.__class__.__name__, details))
except KeyError:
parser.error("Invalid arg '%s'" % arg)

View file

@ -1,120 +0,0 @@
"""Constants needed by ISAPI filters and extensions."""
# ======================================================================
# Copyright 2002-2003 by Blackdog Software Pty Ltd.
#
# All Rights Reserved
#
# Permission to use, copy, modify, and distribute this software and
# its documentation for any purpose and without fee is hereby
# granted, provided that the above copyright notice appear in all
# copies and that both that copyright notice and this permission
# notice appear in supporting documentation, and that the name of
# Blackdog Software not be used in advertising or publicity pertaining to
# distribution of the software without specific, written prior
# permission.
#
# BLACKDOG SOFTWARE DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
# INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN
# NO EVENT SHALL BLACKDOG SOFTWARE BE LIABLE FOR ANY SPECIAL, INDIRECT OR
# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
# ======================================================================
# HTTP reply codes
HTTP_CONTINUE = 100
HTTP_SWITCHING_PROTOCOLS = 101
HTTP_PROCESSING = 102
HTTP_OK = 200
HTTP_CREATED = 201
HTTP_ACCEPTED = 202
HTTP_NON_AUTHORITATIVE = 203
HTTP_NO_CONTENT = 204
HTTP_RESET_CONTENT = 205
HTTP_PARTIAL_CONTENT = 206
HTTP_MULTI_STATUS = 207
HTTP_MULTIPLE_CHOICES = 300
HTTP_MOVED_PERMANENTLY = 301
HTTP_MOVED_TEMPORARILY = 302
HTTP_SEE_OTHER = 303
HTTP_NOT_MODIFIED = 304
HTTP_USE_PROXY = 305
HTTP_TEMPORARY_REDIRECT = 307
HTTP_BAD_REQUEST = 400
HTTP_UNAUTHORIZED = 401
HTTP_PAYMENT_REQUIRED = 402
HTTP_FORBIDDEN = 403
HTTP_NOT_FOUND = 404
HTTP_METHOD_NOT_ALLOWED = 405
HTTP_NOT_ACCEPTABLE = 406
HTTP_PROXY_AUTHENTICATION_REQUIRED = 407
HTTP_REQUEST_TIME_OUT = 408
HTTP_CONFLICT = 409
HTTP_GONE = 410
HTTP_LENGTH_REQUIRED = 411
HTTP_PRECONDITION_FAILED = 412
HTTP_REQUEST_ENTITY_TOO_LARGE = 413
HTTP_REQUEST_URI_TOO_LARGE = 414
HTTP_UNSUPPORTED_MEDIA_TYPE = 415
HTTP_RANGE_NOT_SATISFIABLE = 416
HTTP_EXPECTATION_FAILED = 417
HTTP_UNPROCESSABLE_ENTITY = 422
HTTP_INTERNAL_SERVER_ERROR = 500
HTTP_NOT_IMPLEMENTED = 501
HTTP_BAD_GATEWAY = 502
HTTP_SERVICE_UNAVAILABLE = 503
HTTP_GATEWAY_TIME_OUT = 504
HTTP_VERSION_NOT_SUPPORTED = 505
HTTP_VARIANT_ALSO_VARIES = 506
HSE_STATUS_SUCCESS = 1
HSE_STATUS_SUCCESS_AND_KEEP_CONN = 2
HSE_STATUS_PENDING = 3
HSE_STATUS_ERROR = 4
SF_NOTIFY_SECURE_PORT = 0x00000001
SF_NOTIFY_NONSECURE_PORT = 0x00000002
SF_NOTIFY_READ_RAW_DATA = 0x00008000
SF_NOTIFY_PREPROC_HEADERS = 0x00004000
SF_NOTIFY_AUTHENTICATION = 0x00002000
SF_NOTIFY_URL_MAP = 0x00001000
SF_NOTIFY_ACCESS_DENIED = 0x00000800
SF_NOTIFY_SEND_RESPONSE = 0x00000040
SF_NOTIFY_SEND_RAW_DATA = 0x00000400
SF_NOTIFY_LOG = 0x00000200
SF_NOTIFY_END_OF_REQUEST = 0x00000080
SF_NOTIFY_END_OF_NET_SESSION = 0x00000100
SF_NOTIFY_ORDER_HIGH = 0x00080000
SF_NOTIFY_ORDER_MEDIUM = 0x00040000
SF_NOTIFY_ORDER_LOW = 0x00020000
SF_NOTIFY_ORDER_DEFAULT = SF_NOTIFY_ORDER_LOW
SF_NOTIFY_ORDER_MASK = (
SF_NOTIFY_ORDER_HIGH | SF_NOTIFY_ORDER_MEDIUM | SF_NOTIFY_ORDER_LOW
)
SF_STATUS_REQ_FINISHED = 134217728 # 0x8000000
SF_STATUS_REQ_FINISHED_KEEP_CONN = 134217728 + 1
SF_STATUS_REQ_NEXT_NOTIFICATION = 134217728 + 2
SF_STATUS_REQ_HANDLED_NOTIFICATION = 134217728 + 3
SF_STATUS_REQ_ERROR = 134217728 + 4
SF_STATUS_REQ_READ_NEXT = 134217728 + 5
HSE_IO_SYNC = 0x00000001 # for WriteClient
HSE_IO_ASYNC = 0x00000002 # for WriteClient/TF/EU
HSE_IO_DISCONNECT_AFTER_SEND = 0x00000004 # for TF
HSE_IO_SEND_HEADERS = 0x00000008 # for TF
HSE_IO_NODELAY = 0x00001000 # turn off nagling
# These two are only used by VectorSend
HSE_IO_FINAL_SEND = 0x00000010
HSE_IO_CACHE_RESPONSE = 0x00000020
HSE_EXEC_URL_NO_HEADERS = 0x02
HSE_EXEC_URL_IGNORE_CURRENT_INTERCEPTOR = 0x04
HSE_EXEC_URL_IGNORE_VALIDATION_AND_RANGE = 0x10
HSE_EXEC_URL_DISABLE_CUSTOM_ERROR = 0x20
HSE_EXEC_URL_SSI_CMD = 0x40
HSE_EXEC_URL_HTTP_CACHE_ELIGIBLE = 0x80

View file

@ -1,20 +0,0 @@
In this directory you will find examples of ISAPI filters and extensions.
The filter loading mechanism works like this:
* IIS loads the special Python "loader" DLL. This DLL will generally have a
leading underscore as part of its name.
* This loader DLL looks for a Python module, by removing the first letter of
the DLL base name.
This means that an ISAPI extension module consists of 2 key files - the loader
DLL (eg, "_MyIISModule.dll", and a Python module (which for this example
would be "MyIISModule.py")
When you install an ISAPI extension, the installation code checks to see if
there is a loader DLL for your implementation file - if one does not exist,
or the standard loader is different, it is copied and renamed accordingly.
We use this mechanism to provide the maximum separation between different
Python extensions installed on the same server - otherwise filter order and
other tricky IIS semantics would need to be replicated. Also, each filter
gets its own thread-pool, etc.
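The naming rule can be exercised with the GetLoaderModuleName helper from isapi.install (shown earlier in this changeset). A small sketch, assuming pywin32 is installed on Windows and using a made-up path:
from isapi.install import GetLoaderModuleName

# check_module=False skips the check that the loader DLL already exists.
print(GetLoaderModuleName(r"C:\inetpub\MyIISModule.py", check_module=False))
# -> C:\inetpub\_MyIISModule.dll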

View file

@ -1,218 +0,0 @@
# This extension demonstrates some advanced features of the Python ISAPI
# framework.
# We demonstrate:
# * Reloading your Python module without shutting down IIS (eg, when your
# .py implementation file changes.)
# * Custom command-line handling - both additional options and commands.
# * Using a query string - any part of the URL after a '?' is assumed to
# be "variable names" separated by '&' - we will print the values of
# these server variables.
# * If the tail portion of the URL is "ReportUnhealthy", IIS will be
# notified we are unhealthy via a HSE_REQ_REPORT_UNHEALTHY request.
# Whether this is acted upon depends on if the IIS health-checking
# tools are installed, but you should always see the reason written
# to the Windows event log - see the IIS documentation for more.
import os
import stat
import sys
from isapi import isapicon
from isapi.simple import SimpleExtension
if hasattr(sys, "isapidllhandle"):
import win32traceutil
# Notes on reloading
# If your HttpFilterProc or HttpExtensionProc functions raises
# 'isapi.InternalReloadException', the framework will not treat it
# as an error but instead will terminate your extension, reload your
# extension module, re-initialize the instance, and re-issue the request.
# The Initialize functions are called with None as their param. The
# return code from the terminate function is ignored.
#
# This is all the framework does to help you. It is up to your code
# when you raise this exception. This sample uses a Win32 "find
# notification". Whenever windows tells us one of the files in the
# directory has changed, we check if the time of our source-file has
# changed, and set a flag. Next incoming request, we check the flag and
# raise the special exception if set.
#
# The end result is that the module is automatically reloaded whenever
# the source-file changes - you need take no further action to see your
# changes reflected in the running server.
# The framework only reloads your module - if you have libraries you
# depend on and also want reloaded, you must arrange for this yourself.
# One way of doing this would be to special case the import of these
# modules. Eg:
# --
# try:
# my_module = reload(my_module) # module already imported - reload it
# except NameError:
# import my_module # first time around - import it.
# --
# When your module is imported for the first time, the NameError will
# be raised, and the module imported. When the ISAPI framework reloads
# your module, the existing module will avoid the NameError, and allow
# you to reload that module.
import threading
import win32con
import win32event
import win32file
import winerror
from isapi import InternalReloadException
try:
reload_counter += 1
except NameError:
reload_counter = 0
# A watcher thread that checks for __file__ changing.
# When it detects it, it simply sets "change_detected" to true.
class ReloadWatcherThread(threading.Thread):
def __init__(self):
self.change_detected = False
self.filename = __file__
if self.filename.endswith("c") or self.filename.endswith("o"):
self.filename = self.filename[:-1]
self.handle = win32file.FindFirstChangeNotification(
os.path.dirname(self.filename),
False, # watch tree?
win32con.FILE_NOTIFY_CHANGE_LAST_WRITE,
)
threading.Thread.__init__(self)
def run(self):
last_time = os.stat(self.filename)[stat.ST_MTIME]
while 1:
try:
rc = win32event.WaitForSingleObject(self.handle, win32event.INFINITE)
win32file.FindNextChangeNotification(self.handle)
except win32event.error as details:
# handle closed - thread should terminate.
if details.winerror != winerror.ERROR_INVALID_HANDLE:
raise
break
this_time = os.stat(self.filename)[stat.ST_MTIME]
if this_time != last_time:
print("Detected file change - flagging for reload.")
self.change_detected = True
last_time = this_time
def stop(self):
win32file.FindCloseChangeNotification(self.handle)
# The ISAPI extension - handles requests in our virtual dir, and sends the
# response to the client.
class Extension(SimpleExtension):
"Python advanced sample Extension"
def __init__(self):
self.reload_watcher = ReloadWatcherThread()
self.reload_watcher.start()
def HttpExtensionProc(self, ecb):
# NOTE: If you use a ThreadPoolExtension, you must still perform
# this check in HttpExtensionProc - raising the exception from
# The "Dispatch" method will just cause the exception to be
# rendered to the browser.
if self.reload_watcher.change_detected:
print("Doing reload")
raise InternalReloadException
url = ecb.GetServerVariable("UNICODE_URL")
if url.endswith("ReportUnhealthy"):
ecb.ReportUnhealthy("I'm a little sick")
ecb.SendResponseHeaders("200 OK", "Content-Type: text/html\r\n\r\n", 0)
print("<HTML><BODY>", file=ecb)
qs = ecb.GetServerVariable("QUERY_STRING")
if qs:
queries = qs.split("&")
print("<PRE>", file=ecb)
for q in queries:
val = ecb.GetServerVariable(q, "&lt;no such variable&gt;")
print("%s=%r" % (q, val), file=ecb)
print("</PRE><P/>", file=ecb)
print("This module has been imported", file=ecb)
print("%d times" % (reload_counter,), file=ecb)
print("</BODY></HTML>", file=ecb)
ecb.close()
return isapicon.HSE_STATUS_SUCCESS
def TerminateExtension(self, status):
self.reload_watcher.stop()
# The entry points for the ISAPI extension.
def __ExtensionFactory__():
return Extension()
# Our special command line customization.
# Pre-install hook for our virtual directory.
def PreInstallDirectory(params, options):
# If the user used our special '--description' option,
# then we override our default.
if options.description:
params.Description = options.description
# Post install hook for our entire script
def PostInstall(params, options):
print()
print("The sample has been installed.")
print("Point your browser to /AdvancedPythonSample")
print("If you modify the source file and reload the page,")
print("you should see the reload counter increment")
# Handler for our custom 'status' argument.
def status_handler(options, log, arg):
"Query the status of something"
print("Everything seems to be fine!")
custom_arg_handlers = {"status": status_handler}
if __name__ == "__main__":
# If run from the command-line, install ourselves.
from isapi.install import *
params = ISAPIParameters(PostInstall=PostInstall)
# Setup the virtual directories - this is a list of directories our
# extension uses - in this case only 1.
# Each extension has a "script map" - this is the mapping of ISAPI
# extensions.
sm = [ScriptMapParams(Extension="*", Flags=0)]
vd = VirtualDirParameters(
Name="AdvancedPythonSample",
Description=Extension.__doc__,
ScriptMaps=sm,
ScriptMapUpdate="replace",
# specify the pre-install hook.
PreInstall=PreInstallDirectory,
)
params.VirtualDirs = [vd]
# Setup our custom option parser.
from optparse import OptionParser
parser = OptionParser("") # blank usage, so isapi sets it.
parser.add_option(
"",
"--description",
action="store",
help="custom description to use for the virtual directory",
)
HandleCommandLine(
params, opt_parser=parser, custom_arg_handlers=custom_arg_handlers
)

View file

@ -1,125 +0,0 @@
# This is a sample ISAPI extension written in Python.
#
# Please see README.txt in this directory, and specifically the
# information about the "loader" DLL - installing this sample will create
# "_redirector.dll" in the current directory. The readme explains this.
# Executing this script (or any server config script) will install the extension
# into your web server. As the server executes, the PyISAPI framework will load
# this module and create your Extension and Filter objects.
# This is the simplest possible redirector (or proxy) we can write. The
# extension installs with a mask of '*' in the root of the site.
# As an added bonus though, we optionally show how, on IIS6 and later, we
# can use HSE_REQ_EXEC_URL to ignore certain requests - in IIS5 and earlier
# we can only do this with an ISAPI filter - see redirector_with_filter for
# an example. If this sample is run on IIS5 or earlier it simply ignores
# any excludes.
import sys
from isapi import isapicon, threaded_extension
# urlopen lives in urllib.request on py3k; the old py2 fallback import is
# no longer needed.
from urllib.request import urlopen
import win32api
# sys.isapidllhandle will exist when we are loaded by the IIS framework.
# In this case we redirect our output to the win32traceutil collector.
if hasattr(sys, "isapidllhandle"):
import win32traceutil
# The site we are proxying.
proxy = "http://www.python.org"
# Urls we exclude (ie, allow IIS to handle itself) - all are lowered,
# and these entries exist by default on Vista...
excludes = ["/iisstart.htm", "/welcome.png"]
# An "io completion" function, called when ecb.ExecURL completes...
def io_callback(ecb, url, cbIO, errcode):
# Get the status of our ExecURL
httpstatus, substatus, win32 = ecb.GetExecURLStatus()
print(
"ExecURL of %r finished with http status %d.%d, win32 status %d (%s)"
% (url, httpstatus, substatus, win32, win32api.FormatMessage(win32).strip())
)
# nothing more to do!
ecb.DoneWithSession()
# The ISAPI extension - handles all requests in the site.
class Extension(threaded_extension.ThreadPoolExtension):
"Python sample Extension"
def Dispatch(self, ecb):
# Note that our ThreadPoolExtension base class will catch exceptions
# in our Dispatch method, and write the traceback to the client.
# That is perfect for this sample, so we don't catch our own.
# print 'IIS dispatching "%s"' % (ecb.GetServerVariable("URL"),)
url = ecb.GetServerVariable("URL").decode("ascii")
for exclude in excludes:
if url.lower().startswith(exclude):
print("excluding %s" % url)
if ecb.Version < 0x60000:
print("(but this is IIS5 or earlier - can't do 'excludes')")
else:
ecb.IOCompletion(io_callback, url)
ecb.ExecURL(
None,
None,
None,
None,
None,
isapicon.HSE_EXEC_URL_IGNORE_CURRENT_INTERCEPTOR,
)
return isapicon.HSE_STATUS_PENDING
new_url = proxy + url
print("Opening %s" % new_url)
fp = urlopen(new_url)
headers = fp.info()
# subtle py3k breakage: in py3k, str(headers) has normalized \r\n
# back to \n and also stuck an extra \n term. py2k leaves the
# \r\n from the server intact and finishes with a single term.
if sys.version_info < (3, 0):
header_text = str(headers) + "\r\n"
else:
# take *all* trailing \n off, replace remaining with
# \r\n, then add the 2 trailing \r\n.
header_text = str(headers).rstrip("\n").replace("\n", "\r\n") + "\r\n\r\n"
ecb.SendResponseHeaders("200 OK", header_text, False)
ecb.WriteClient(fp.read())
ecb.DoneWithSession()
print("Returned data from '%s'" % (new_url,))
return isapicon.HSE_STATUS_SUCCESS
# The entry points for the ISAPI extension.
def __ExtensionFactory__():
return Extension()
if __name__ == "__main__":
# If run from the command-line, install ourselves.
from isapi.install import *
params = ISAPIParameters()
# Setup the virtual directories - this is a list of directories our
# extension uses - in this case only 1.
# Each extension has a "script map" - this is the mapping of ISAPI
# extensions.
sm = [ScriptMapParams(Extension="*", Flags=0)]
vd = VirtualDirParameters(
Name="/",
Description=Extension.__doc__,
ScriptMaps=sm,
ScriptMapUpdate="replace",
)
params.VirtualDirs = [vd]
HandleCommandLine(params)

View file

@ -1,85 +0,0 @@
# This is a sample ISAPI extension written in Python.
# This is like the other 'redirector' samples, but uses asynch IO when writing
# back to the client (it does *not* use asynch io talking to the remote
# server!)
import sys
import urllib.error
import urllib.parse
import urllib.request
from isapi import isapicon, threaded_extension
# sys.isapidllhandle will exist when we are loaded by the IIS framework.
# In this case we redirect our output to the win32traceutil collector.
if hasattr(sys, "isapidllhandle"):
import win32traceutil
# The site we are proxying.
proxy = "http://www.python.org"
# We synchronously read chunks of this size then asynchronously write them.
CHUNK_SIZE = 8192
# The callback made when IIS completes the asynch write.
def io_callback(ecb, fp, cbIO, errcode):
print("IO callback", ecb, fp, cbIO, errcode)
chunk = fp.read(CHUNK_SIZE)
if chunk:
ecb.WriteClient(chunk, isapicon.HSE_IO_ASYNC)
# and wait for the next callback to say this chunk is done.
else:
# eof - say we are complete.
fp.close()
ecb.DoneWithSession()
# The ISAPI extension - handles all requests in the site.
class Extension(threaded_extension.ThreadPoolExtension):
"Python sample proxy server - asynch version."
def Dispatch(self, ecb):
print('IIS dispatching "%s"' % (ecb.GetServerVariable("URL"),))
url = ecb.GetServerVariable("URL")
new_url = proxy + url
print("Opening %s" % new_url)
fp = urllib.request.urlopen(new_url)
headers = fp.info()
ecb.SendResponseHeaders("200 OK", str(headers) + "\r\n", False)
# now send the first chunk asynchronously
ecb.ReqIOCompletion(io_callback, fp)
chunk = fp.read(CHUNK_SIZE)
if chunk:
ecb.WriteClient(chunk, isapicon.HSE_IO_ASYNC)
return isapicon.HSE_STATUS_PENDING
# no data - just close things now.
ecb.DoneWithSession()
return isapicon.HSE_STATUS_SUCCESS
# The entry points for the ISAPI extension.
def __ExtensionFactory__():
return Extension()
if __name__ == "__main__":
# If run from the command-line, install ourselves.
from isapi.install import *
params = ISAPIParameters()
# Setup the virtual directories - this is a list of directories our
# extension uses - in this case only 1.
# Each extension has a "script map" - this is the mapping of ISAPI
# extensions.
sm = [ScriptMapParams(Extension="*", Flags=0)]
vd = VirtualDirParameters(
Name="/",
Description=Extension.__doc__,
ScriptMaps=sm,
ScriptMapUpdate="replace",
)
params.VirtualDirs = [vd]
HandleCommandLine(params)

View file

@ -1,161 +0,0 @@
# This is a sample configuration file for an ISAPI filter and extension
# written in Python.
#
# Please see README.txt in this directory, and specifically the
# information about the "loader" DLL - installing this sample will create
# "_redirector_with_filter.dll" in the current directory. The readme explains
# this.
# Executing this script (or any server config script) will install the extension
# into your web server. As the server executes, the PyISAPI framework will load
# this module and create your Extension and Filter objects.
# This sample provides sample redirector:
# It is implemented by a filter and an extension, so that some requests can
# be ignored. Compare with 'redirector_simple' which avoids the filter, but
# is unable to selectively ignore certain requests.
# The process this sample uses is:
# * The filter is installed globally, as all filters are.
# * A Virtual Directory named "python" is setup. This dir has our ISAPI
# extension as the only application, mapped to file-extension '*'. Thus, our
# extension handles *all* requests in this directory.
# The basic process is that the filter does URL rewriting, redirecting every
# URL to our Virtual Directory. Our extension then handles this request,
# forwarding the data from the proxied site.
# For example:
# * URL of "index.html" comes in.
# * Filter rewrites this to "/python/index.html"
# * Our extension sees the full "/python/index.html", removes the leading
# portion, and opens and forwards the remote URL.
# This sample is very small - it avoids most error handling, etc. It is for
# demonstration purposes only.
import sys
import urllib.error
import urllib.parse
import urllib.request
from isapi import isapicon, threaded_extension
from isapi.simple import SimpleFilter
# sys.isapidllhandle will exist when we are loaded by the IIS framework.
# In this case we redirect our output to the win32traceutil collector.
if hasattr(sys, "isapidllhandle"):
import win32traceutil
# The site we are proxying.
proxy = "http://www.python.org"
# The name of the virtual directory we install in, and redirect from.
virtualdir = "/python"
# The key feature of this redirector over the simple redirector is that it
# can choose to ignore certain responses by having the filter not rewrite them
# to our virtual dir. For this sample, we just exclude the IIS help directory.
# The ISAPI extension - handles requests in our virtual dir, and sends the
# response to the client.
class Extension(threaded_extension.ThreadPoolExtension):
"Python sample Extension"
def Dispatch(self, ecb):
# Note that our ThreadPoolExtension base class will catch exceptions
# in our Dispatch method, and write the traceback to the client.
# That is perfect for this sample, so we don't catch our own.
# print 'IIS dispatching "%s"' % (ecb.GetServerVariable("URL"),)
url = ecb.GetServerVariable("URL")
if url.startswith(virtualdir):
new_url = proxy + url[len(virtualdir) :]
print("Opening", new_url)
fp = urllib.request.urlopen(new_url)
headers = fp.info()
ecb.SendResponseHeaders("200 OK", str(headers) + "\r\n", False)
ecb.WriteClient(fp.read())
ecb.DoneWithSession()
print("Returned data from '%s'!" % (new_url,))
else:
# this should never happen - we should only see requests that
# start with our virtual directory name.
print("Not proxying '%s'" % (url,))
# The ISAPI filter.
class Filter(SimpleFilter):
"Sample Python Redirector"
filter_flags = isapicon.SF_NOTIFY_PREPROC_HEADERS | isapicon.SF_NOTIFY_ORDER_DEFAULT
def HttpFilterProc(self, fc):
# print "Filter Dispatch"
nt = fc.NotificationType
if nt != isapicon.SF_NOTIFY_PREPROC_HEADERS:
return isapicon.SF_STATUS_REQ_NEXT_NOTIFICATION
pp = fc.GetData()
url = pp.GetHeader("url")
# print "URL is '%s'" % (url,)
prefix = virtualdir
if not url.startswith(prefix):
new_url = prefix + url
print("New proxied URL is '%s'" % (new_url,))
pp.SetHeader("url", new_url)
# For the sake of demonstration, show how the FilterContext
# attribute is used. It always starts out life as None, and
# any assignments made are automatically decref'd by the
# framework during a SF_NOTIFY_END_OF_NET_SESSION notification.
if fc.FilterContext is None:
fc.FilterContext = 0
fc.FilterContext += 1
print("This is request number %d on this connection" % fc.FilterContext)
return isapicon.SF_STATUS_REQ_HANDLED_NOTIFICATION
else:
print("Filter ignoring URL '%s'" % (url,))
# Some older code that handled SF_NOTIFY_URL_MAP.
# ~ print "Have URL_MAP notify"
# ~ urlmap = fc.GetData()
# ~ print "URI is", urlmap.URL
# ~ print "Path is", urlmap.PhysicalPath
# ~ if urlmap.URL.startswith("/UC/"):
# ~ # Find the /UC/ in the physical path, and nuke it (except
# ~ # as the path is physical, it is \)
# ~ p = urlmap.PhysicalPath
# ~ pos = p.index("\\UC\\")
# ~ p = p[:pos] + p[pos+3:]
# ~ p = r"E:\src\pyisapi\webroot\PyTest\formTest.htm"
# ~ print "New path is", p
# ~ urlmap.PhysicalPath = p
# The entry points for the ISAPI extension.
def __FilterFactory__():
return Filter()
def __ExtensionFactory__():
return Extension()
if __name__ == "__main__":
# If run from the command-line, install ourselves.
from isapi.install import *
params = ISAPIParameters()
# Setup all filters - these are global to the site.
params.Filters = [
FilterParameters(Name="PythonRedirector", Description=Filter.__doc__),
]
# Setup the virtual directories - this is a list of directories our
# extension uses - in this case only 1.
# Each extension has a "script map" - this is the mapping of ISAPI
# extensions.
sm = [ScriptMapParams(Extension="*", Flags=0)]
vd = VirtualDirParameters(
Name=virtualdir[1:],
Description=Extension.__doc__,
ScriptMaps=sm,
ScriptMapUpdate="replace",
)
params.VirtualDirs = [vd]
HandleCommandLine(params)
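The prefix handling that the filter and the extension share above is easy to lose in the ISAPI plumbing. The following standalone sketch (hypothetical helper names, no IIS required) restates the same rewrite/strip string logic in plain Python.

# Standalone sketch of the rewrite/strip logic used by the sample above.
# The helper names are hypothetical; only the string handling mirrors the sample.
PROXY = "http://www.python.org"
VIRTUALDIR = "/python"

def filter_rewrite(url, virtualdir=VIRTUALDIR):
    """What the filter does: prefix requests with the virtual directory."""
    if url.startswith(virtualdir):
        return url  # already rewritten - leave it for the extension
    return virtualdir + url

def extension_target(url, proxy=PROXY, virtualdir=VIRTUALDIR):
    """What the extension does: strip the prefix and build the remote URL."""
    if not url.startswith(virtualdir):
        raise ValueError("request did not come through the filter")
    return proxy + url[len(virtualdir):]

if __name__ == "__main__":
    rewritten = filter_rewrite("/index.html")
    print(rewritten)                    # /python/index.html
    print(extension_target(rewritten))  # http://www.python.org/index.html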

View file

@ -1,195 +0,0 @@
# This extension is used mainly for testing purposes - it is not
# designed to be a simple sample, but instead is a hotch-potch of things
# that attempts to exercise the framework.
import os
import stat
import sys
from isapi import isapicon
from isapi.simple import SimpleExtension
if hasattr(sys, "isapidllhandle"):
import win32traceutil
# We use the same reload support as 'advanced.py' demonstrates.
import threading
import win32con
import win32event
import win32file
import winerror
from isapi import InternalReloadException
# A watcher thread that checks for __file__ changing.
# When it detects it, it simply sets "change_detected" to true.
class ReloadWatcherThread(threading.Thread):
def __init__(self):
self.change_detected = False
self.filename = __file__
if self.filename.endswith("c") or self.filename.endswith("o"):
self.filename = self.filename[:-1]
self.handle = win32file.FindFirstChangeNotification(
os.path.dirname(self.filename),
False, # watch tree?
win32con.FILE_NOTIFY_CHANGE_LAST_WRITE,
)
threading.Thread.__init__(self)
def run(self):
last_time = os.stat(self.filename)[stat.ST_MTIME]
while 1:
try:
rc = win32event.WaitForSingleObject(self.handle, win32event.INFINITE)
win32file.FindNextChangeNotification(self.handle)
except win32event.error as details:
# handle closed - thread should terminate.
if details.winerror != winerror.ERROR_INVALID_HANDLE:
raise
break
this_time = os.stat(self.filename)[stat.ST_MTIME]
if this_time != last_time:
print("Detected file change - flagging for reload.")
self.change_detected = True
last_time = this_time
def stop(self):
win32file.FindCloseChangeNotification(self.handle)
def TransmitFileCallback(ecb, hFile, cbIO, errCode):
print("Transmit complete!")
ecb.close()
# The ISAPI extension - handles requests in our virtual dir, and sends the
# response to the client.
class Extension(SimpleExtension):
"Python test Extension"
def __init__(self):
self.reload_watcher = ReloadWatcherThread()
self.reload_watcher.start()
def HttpExtensionProc(self, ecb):
# NOTE: If you use a ThreadPoolExtension, you must still perform
# this check in HttpExtensionProc - raising the exception from
# The "Dispatch" method will just cause the exception to be
# rendered to the browser.
if self.reload_watcher.change_detected:
print("Doing reload")
raise InternalReloadException
if ecb.GetServerVariable("UNICODE_URL").endswith("test.py"):
file_flags = (
win32con.FILE_FLAG_SEQUENTIAL_SCAN | win32con.FILE_FLAG_OVERLAPPED
)
hfile = win32file.CreateFile(
__file__,
win32con.GENERIC_READ,
0,
None,
win32con.OPEN_EXISTING,
file_flags,
None,
)
flags = (
isapicon.HSE_IO_ASYNC
| isapicon.HSE_IO_DISCONNECT_AFTER_SEND
| isapicon.HSE_IO_SEND_HEADERS
)
# We pass hFile to the callback simply as a way of keeping it alive
# for the duration of the transmission
try:
ecb.TransmitFile(
TransmitFileCallback,
hfile,
int(hfile),
"200 OK",
0,
0,
None,
None,
flags,
)
except:
# Errors keep this source file open!
hfile.Close()
raise
else:
# default response
ecb.SendResponseHeaders("200 OK", "Content-Type: text/html\r\n\r\n", 0)
print("<HTML><BODY>", file=ecb)
print("The root of this site is at", ecb.MapURLToPath("/"), file=ecb)
print("</BODY></HTML>", file=ecb)
ecb.close()
return isapicon.HSE_STATUS_SUCCESS
def TerminateExtension(self, status):
self.reload_watcher.stop()
# The entry points for the ISAPI extension.
def __ExtensionFactory__():
return Extension()
# Our special command line customization.
# Pre-install hook for our virtual directory.
def PreInstallDirectory(params, options):
# If the user used our special '--description' option,
# then we override our default.
if options.description:
params.Description = options.description
# Post install hook for our entire script
def PostInstall(params, options):
print()
print("The sample has been installed.")
print("Point your browser to /PyISAPITest")
# Handler for our custom 'status' argument.
def status_handler(options, log, arg):
"Query the status of something"
print("Everything seems to be fine!")
custom_arg_handlers = {"status": status_handler}
if __name__ == "__main__":
# If run from the command-line, install ourselves.
from isapi.install import *
params = ISAPIParameters(PostInstall=PostInstall)
# Setup the virtual directories - this is a list of directories our
# extension uses - in this case only 1.
# Each extension has a "script map" - this is the mapping of ISAPI
# extensions.
sm = [ScriptMapParams(Extension="*", Flags=0)]
vd = VirtualDirParameters(
Name="PyISAPITest",
Description=Extension.__doc__,
ScriptMaps=sm,
ScriptMapUpdate="replace",
# specify the pre-install hook.
PreInstall=PreInstallDirectory,
)
params.VirtualDirs = [vd]
# Setup our custom option parser.
from optparse import OptionParser
parser = OptionParser("") # blank usage, so isapi sets it.
parser.add_option(
"",
"--description",
action="store",
help="custom description to use for the virtual directory",
)
HandleCommandLine(
params, opt_parser=parser, custom_arg_handlers=custom_arg_handlers
)
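The ReloadWatcherThread above relies on win32 change notifications. As a rough, platform-neutral sketch of the same idea (an assumption for illustration, not the sample's actual mechanism), the change detection can be expressed as mtime polling with nothing but the standard library:

# Minimal, portable sketch of the reload-watcher idea: poll the file's
# mtime and set a flag when it changes. This stands in for the win32
# FindFirstChangeNotification machinery used by the real sample.
import os
import threading
import time

class MTimeWatcher(threading.Thread):
    def __init__(self, filename, interval=1.0):
        super().__init__(daemon=True)
        self.filename = filename
        self.interval = interval
        self.change_detected = False
        self._stop = threading.Event()

    def run(self):
        last = os.stat(self.filename).st_mtime
        while not self._stop.is_set():
            time.sleep(self.interval)
            current = os.stat(self.filename).st_mtime
            if current != last:
                self.change_detected = True
                last = current

    def stop(self):
        self._stop.set()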

View file

@ -1,70 +0,0 @@
"""Simple base-classes for extensions and filters.
None of the filter and extension functions are considered 'optional' by the
framework. These base-classes provide simple implementations for the
Initialize and Terminate functions, allowing you to omit them.
It is not necessary to use these base-classes - but if you don't, you
must ensure each of the required methods is implemented.
"""
class SimpleExtension:
"Base class for a simple ISAPI extension"
def __init__(self):
pass
def GetExtensionVersion(self, vi):
"""Called by the ISAPI framework to get the extension version
The default implementation uses the class's docstring to
set the extension description."""
# nod to our reload capability - vi is None when we are reloaded.
if vi is not None:
vi.ExtensionDesc = self.__doc__
def HttpExtensionProc(self, control_block):
"""Called by the ISAPI framework for each extension request.
sub-classes must provide an implementation for this method.
"""
raise NotImplementedError("sub-classes should override HttpExtensionProc")
def TerminateExtension(self, status):
"""Called by the ISAPI framework as the extension terminates."""
pass
class SimpleFilter:
"Base class for a a simple ISAPI filter"
filter_flags = None
def __init__(self):
pass
def GetFilterVersion(self, fv):
"""Called by the ISAPI framework to get the extension version
The default implementation uses the class's docstring to
set the extension description, and uses the class's
filter_flags attribute to set the ISAPI filter flags - you
must specify filter_flags in your class.
"""
if self.filter_flags is None:
raise RuntimeError("You must specify the filter flags")
# nod to our reload capability - fv is None when we are reloaded.
if fv is not None:
fv.Flags = self.filter_flags
fv.FilterDesc = self.__doc__
def HttpFilterProc(self, fc):
"""Called by the ISAPI framework for each filter request.
sub-classes must provide an implementation for this method.
"""
raise NotImplementedError("sub-classes should override HttpFilterProc")
def TerminateFilter(self, status):
"""Called by the ISAPI framework as the filter terminates."""
pass
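For reference, a subclass of SimpleExtension only needs to provide HttpExtensionProc. The sketch below is a hypothetical minimal subclass, using only the ecb methods that appear elsewhere in the samples in this commit (SendResponseHeaders, WriteClient, DoneWithSession).

# Hypothetical minimal subclass; the ecb calls mirror those used by the
# samples in this commit.
from isapi import isapicon
from isapi.simple import SimpleExtension

class HelloExtension(SimpleExtension):
    "Smallest useful extension"

    def HttpExtensionProc(self, ecb):
        ecb.SendResponseHeaders("200 OK", "Content-Type: text/plain\r\n\r\n", False)
        ecb.WriteClient(b"hello from ISAPI")
        ecb.DoneWithSession()
        return isapicon.HSE_STATUS_SUCCESS

def __ExtensionFactory__():
    return HelloExtension()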

View file

@ -1,3 +0,0 @@
This is a directory for tests of the PyISAPI framework.
For demos, please see the pyisapi 'samples' directory.

View file

@ -1,119 +0,0 @@
# This is an ISAPI extension purely for testing purposes. It is NOT
# a 'demo' (even though it may be useful!)
#
# Install this extension, then point your browser to:
# "http://localhost/pyisapi_test/test1"
# This will execute the method 'test1' below. See below for the list of
# test methods that are acceptable.
import urllib.error
import urllib.parse
import urllib.request
# If we have no console (eg, am running from inside IIS), redirect output
# somewhere useful - in this case, the standard win32 trace collector.
import win32api
import winerror
from isapi import ExtensionError, isapicon, threaded_extension
from isapi.simple import SimpleFilter
try:
win32api.GetConsoleTitle()
except win32api.error:
# No console - redirect
import win32traceutil
# The ISAPI extension - handles requests in our virtual dir, and sends the
# response to the client.
class Extension(threaded_extension.ThreadPoolExtension):
"Python ISAPI Tester"
def Dispatch(self, ecb):
print('Tester dispatching "%s"' % (ecb.GetServerVariable("URL"),))
url = ecb.GetServerVariable("URL")
test_name = url.split("/")[-1]
meth = getattr(self, test_name, None)
if meth is None:
raise AttributeError("No test named '%s'" % (test_name,))
result = meth(ecb)
if result is None:
# This means the test finalized everything
return
ecb.SendResponseHeaders("200 OK", "Content-type: text/html\r\n\r\n", False)
print("<HTML><BODY>Finished running test <i>", test_name, "</i>", file=ecb)
print("<pre>", file=ecb)
print(result, file=ecb)
print("</pre>", file=ecb)
print("</BODY></HTML>", file=ecb)
ecb.DoneWithSession()
def test1(self, ecb):
try:
ecb.GetServerVariable("foo bar")
raise RuntimeError("should have failed!")
except ExtensionError as err:
assert err.errno == winerror.ERROR_INVALID_INDEX, err
return "worked!"
def test_long_vars(self, ecb):
qs = ecb.GetServerVariable("QUERY_STRING")
# Our implementation has a default buffer size of 8k - so we test
# the code that handles an overflow by ensuring there are more
# than 8k worth of chars in the URL.
expected_query = "x" * 8500
if len(qs) == 0:
# Just the URL with no query part - redirect to myself, but with
# a huge query portion.
me = ecb.GetServerVariable("URL")
headers = "Location: " + me + "?" + expected_query + "\r\n\r\n"
ecb.SendResponseHeaders("301 Moved", headers)
ecb.DoneWithSession()
return None
if qs == expected_query:
return "Total length of variable is %d - test worked!" % (len(qs),)
else:
return "Unexpected query portion! Got %d chars, expected %d" % (
len(qs),
len(expected_query),
)
def test_unicode_vars(self, ecb):
# We need to check that we are running IIS6! This seems the only
# effective way from an extension.
ver = float(ecb.GetServerVariable("SERVER_SOFTWARE").split("/")[1])
if ver < 6.0:
return "This is IIS version %g - unicode only works in IIS6 and later" % ver
us = ecb.GetServerVariable("UNICODE_SERVER_NAME")
if not isinstance(us, str):
raise RuntimeError("unexpected type!")
if us != str(ecb.GetServerVariable("SERVER_NAME")):
raise RuntimeError("Unicode and non-unicode values were not the same")
return "worked!"
# The entry points for the ISAPI extension.
def __ExtensionFactory__():
return Extension()
if __name__ == "__main__":
# If run from the command-line, install ourselves.
from isapi.install import *
params = ISAPIParameters()
# Setup the virtual directories - this is a list of directories our
# extension uses - in this case only 1.
# Each extension has a "script map" - this is the mapping of ISAPI
# extensions.
sm = [ScriptMapParams(Extension="*", Flags=0)]
vd = VirtualDirParameters(
Name="pyisapi_test",
Description=Extension.__doc__,
ScriptMaps=sm,
ScriptMapUpdate="replace",
)
params.VirtualDirs = [vd]
HandleCommandLine(params)
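The tester's Dispatch method routes requests by treating the last URL segment as a method name. That getattr-based dispatch is worth isolating; the sketch below is a plain-Python restatement (hypothetical class, no IIS involved).

# Plain restatement of the getattr-based routing used by the tester above.
class Router:
    def handle(self, url):
        test_name = url.split("/")[-1]
        meth = getattr(self, test_name, None)
        if meth is None:
            raise AttributeError("No test named '%s'" % (test_name,))
        return meth()

    def test1(self):
        return "worked!"

assert Router().handle("/pyisapi_test/test1") == "worked!"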

View file

@ -1,189 +0,0 @@
"""An ISAPI extension base class implemented using a thread-pool."""
# $Id$
import sys
import threading
import time
import traceback
from pywintypes import OVERLAPPED
from win32event import INFINITE
from win32file import (
CloseHandle,
CreateIoCompletionPort,
GetQueuedCompletionStatus,
PostQueuedCompletionStatus,
)
from win32security import SetThreadToken
import isapi.simple
from isapi import ExtensionError, isapicon
ISAPI_REQUEST = 1
ISAPI_SHUTDOWN = 2
class WorkerThread(threading.Thread):
def __init__(self, extension, io_req_port):
self.running = False
self.io_req_port = io_req_port
self.extension = extension
threading.Thread.__init__(self)
# We wait 15 seconds for a thread to terminate, but if it fails to,
# we don't want the process to hang at exit waiting for it...
self.setDaemon(True)
def run(self):
self.running = True
while self.running:
errCode, bytes, key, overlapped = GetQueuedCompletionStatus(
self.io_req_port, INFINITE
)
if key == ISAPI_SHUTDOWN and overlapped is None:
break
# Let the parent extension handle the command.
dispatcher = self.extension.dispatch_map.get(key)
if dispatcher is None:
raise RuntimeError("Bad request '%s'" % (key,))
dispatcher(errCode, bytes, key, overlapped)
def call_handler(self, cblock):
self.extension.Dispatch(cblock)
# A generic thread-pool based extension, using IO Completion Ports.
# Sub-classes can override one method to implement a simple extension, or
# may leverage the CompletionPort to queue their own requests, and implement a
# fully asynch extension.
class ThreadPoolExtension(isapi.simple.SimpleExtension):
"Base class for an ISAPI extension based around a thread-pool"
max_workers = 20
worker_shutdown_wait = 15000 # 15 seconds for workers to quit...
def __init__(self):
self.workers = []
# extensible dispatch map, for sub-classes that need to post their
# own requests to the completion port.
# Each of these functions is called with the result of
# GetQueuedCompletionStatus for our port.
self.dispatch_map = {
ISAPI_REQUEST: self.DispatchConnection,
}
def GetExtensionVersion(self, vi):
isapi.simple.SimpleExtension.GetExtensionVersion(self, vi)
# As per Q192800, the CompletionPort should be created with the number
# of processors, even if the number of worker threads is much larger.
# Passing 0 means the system picks the number.
self.io_req_port = CreateIoCompletionPort(-1, None, 0, 0)
# start up the workers
self.workers = []
for i in range(self.max_workers):
worker = WorkerThread(self, self.io_req_port)
worker.start()
self.workers.append(worker)
def HttpExtensionProc(self, control_block):
overlapped = OVERLAPPED()
overlapped.object = control_block
PostQueuedCompletionStatus(self.io_req_port, 0, ISAPI_REQUEST, overlapped)
return isapicon.HSE_STATUS_PENDING
def TerminateExtension(self, status):
for worker in self.workers:
worker.running = False
for worker in self.workers:
PostQueuedCompletionStatus(self.io_req_port, 0, ISAPI_SHUTDOWN, None)
# wait for them to terminate - pity we aren't using 'native' threads
# as then we could do a smart wait - but now we need to poll....
end_time = time.time() + self.worker_shutdown_wait / 1000
alive = self.workers
while alive:
if time.time() > end_time:
# xxx - might be nice to log something here.
break
time.sleep(0.2)
alive = [w for w in alive if w.is_alive()]
self.dispatch_map = {} # break circles
CloseHandle(self.io_req_port)
# This is the one operation the base class supports - a simple
# Connection request. We setup the thread-token, and dispatch to the
# sub-class's 'Dispatch' method.
def DispatchConnection(self, errCode, bytes, key, overlapped):
control_block = overlapped.object
# setup the correct user for this request
hRequestToken = control_block.GetImpersonationToken()
SetThreadToken(None, hRequestToken)
try:
try:
self.Dispatch(control_block)
except:
self.HandleDispatchError(control_block)
finally:
# reset the security context
SetThreadToken(None, None)
def Dispatch(self, ecb):
"""Overridden by the sub-class to handle connection requests.
This class creates a thread-pool using a Windows completion port,
and dispatches requests via this port. Sub-classes can generally
implement each connection request using blocking reads and writes, and
the thread-pool will still provide decent response to the end user.
The sub-class can set a max_workers attribute (default is 20). Note
that this generally does *not* mean 20 threads will all be concurrently
running, via the magic of Windows completion ports.
There is no default implementation - sub-classes must implement this.
"""
raise NotImplementedError("sub-classes should override Dispatch")
def HandleDispatchError(self, ecb):
"""Handles errors in the Dispatch method.
When a Dispatch method call fails, this method is called to handle
the exception. The default implementation formats the traceback
in the browser.
"""
ecb.HttpStatusCode = isapicon.HSE_STATUS_ERROR
# control_block.LogData = "we failed!"
exc_typ, exc_val, exc_tb = sys.exc_info()
limit = None
try:
try:
import cgi
ecb.SendResponseHeaders(
"200 OK", "Content-type: text/html\r\n\r\n", False
)
print(file=ecb)
print("<H3>Traceback (most recent call last):</H3>", file=ecb)
list = traceback.format_tb(
exc_tb, limit
) + traceback.format_exception_only(exc_typ, exc_val)
print(
"<PRE>%s<B>%s</B></PRE>"
% (
cgi.escape("".join(list[:-1])),
cgi.escape(list[-1]),
),
file=ecb,
)
except ExtensionError:
# The client disconnected without reading the error body -
# it's probably not a real browser at the other end, ignore it.
pass
except:
print("FAILED to render the error message!")
traceback.print_exc()
print("ORIGINAL extension error:")
traceback.print_exception(exc_typ, exc_val, exc_tb)
finally:
# holding tracebacks in a local of a frame that may itself be
# part of a traceback used to be evil and cause leaks!
exc_tb = None
ecb.DoneWithSession()
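Outside of Windows, the producer/worker shape of ThreadPoolExtension can be sketched with a standard queue in place of the IO completion port. This is only an analogy for how ISAPI_REQUEST work items are posted to the workers, not the module's actual implementation.

# Analogy only: a queue.Queue stands in for the IO completion port, and a
# sentinel stands in for the ISAPI_SHUTDOWN post.
import queue
import threading

SHUTDOWN = object()

def worker(q, dispatch):
    while True:
        item = q.get()
        if item is SHUTDOWN:
            break
        dispatch(item)

def run_pool(items, dispatch, max_workers=4):
    q = queue.Queue()
    threads = [threading.Thread(target=worker, args=(q, dispatch), daemon=True)
               for _ in range(max_workers)]
    for t in threads:
        t.start()
    for item in items:
        q.put(item)      # like PostQueuedCompletionStatus(..., ISAPI_REQUEST, ...)
    for _ in threads:
        q.put(SHUTDOWN)  # like posting ISAPI_SHUTDOWN once per worker
    for t in threads:
        t.join()

run_pool(range(5), print)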

View file

@ -5,49 +5,23 @@ import itertools
import copy
import functools
import random
from collections.abc import Container, Iterable, Mapping
from typing import Callable, Union
from jaraco.classes.properties import NonDataProperty
import jaraco.text
_Matchable = Union[Callable, Container, Iterable, re.Pattern]
def _dispatch(obj: _Matchable) -> Callable:
# can't rely on singledispatch for Union[Container, Iterable]
# due to ambiguity
# (https://peps.python.org/pep-0443/#abstract-base-classes).
if isinstance(obj, re.Pattern):
return obj.fullmatch
if not isinstance(obj, Callable): # type: ignore
if not isinstance(obj, Container):
obj = set(obj) # type: ignore
obj = obj.__contains__
return obj # type: ignore
class Projection(collections.abc.Mapping):
"""
Project a set of keys over a mapping
>>> sample = {'a': 1, 'b': 2, 'c': 3}
>>> prj = Projection(['a', 'c', 'd'], sample)
>>> dict(prj)
{'a': 1, 'c': 3}
Projection also accepts an iterable or callable or pattern.
>>> iter_prj = Projection(iter('acd'), sample)
>>> call_prj = Projection(lambda k: ord(k) in (97, 99, 100), sample)
>>> pat_prj = Projection(re.compile(r'[acd]'), sample)
>>> prj == iter_prj == call_prj == pat_prj
>>> prj == {'a': 1, 'c': 3}
True
Keys should only appear if they were specified and exist in the space.
Order is retained.
>>> list(prj)
>>> sorted(list(prj.keys()))
['a', 'c']
Attempting to access a key not in the projection
@ -62,58 +36,119 @@ class Projection(collections.abc.Mapping):
>>> target = {'a': 2, 'b': 2}
>>> target.update(prj)
>>> target
{'a': 1, 'b': 2, 'c': 3}
>>> target == {'a': 1, 'b': 2, 'c': 3}
True
Projection keeps a reference to the original dict, so
modifying the original dict may modify the Projection.
Also note that Projection keeps a reference to the original dict, so
if you modify the original dict, that could modify the Projection.
>>> del sample['a']
>>> dict(prj)
{'c': 3}
"""
def __init__(self, keys: _Matchable, space: Mapping):
self._match = _dispatch(keys)
def __init__(self, keys, space):
self._keys = tuple(keys)
self._space = space
def __getitem__(self, key):
if not self._match(key):
if key not in self._keys:
raise KeyError(key)
return self._space[key]
def _keys_resolved(self):
return filter(self._match, self._space)
def __iter__(self):
return self._keys_resolved()
return iter(set(self._keys).intersection(self._space))
def __len__(self):
return len(tuple(self._keys_resolved()))
return len(tuple(iter(self)))
class Mask(Projection):
class DictFilter(collections.abc.Mapping):
"""
The inverse of a :class:`Projection`, masking out keys.
Takes a dict, and simulates a sub-dict based on the keys.
>>> sample = {'a': 1, 'b': 2, 'c': 3}
>>> msk = Mask(['a', 'c', 'd'], sample)
>>> dict(msk)
>>> filtered = DictFilter(sample, ['a', 'c'])
>>> filtered == {'a': 1, 'c': 3}
True
>>> set(filtered.values()) == {1, 3}
True
>>> set(filtered.items()) == {('a', 1), ('c', 3)}
True
One can also filter by a regular expression pattern
>>> sample['d'] = 4
>>> sample['ef'] = 5
Here we filter for only single-character keys
>>> filtered = DictFilter(sample, include_pattern='.$')
>>> filtered == {'a': 1, 'b': 2, 'c': 3, 'd': 4}
True
>>> filtered['e']
Traceback (most recent call last):
...
KeyError: 'e'
>>> 'e' in filtered
False
Pattern is useful for excluding keys with a prefix.
>>> filtered = DictFilter(sample, include_pattern=r'(?![ace])')
>>> dict(filtered)
{'b': 2, 'd': 4}
Also note that DictFilter keeps a reference to the original dict, so
if you modify the original dict, that could modify the filtered dict.
>>> del sample['d']
>>> dict(filtered)
{'b': 2}
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# self._match = compose(operator.not_, self._match)
self._match = lambda key, orig=self._match: not orig(key)
def __init__(self, dict, include_keys=[], include_pattern=None):
self.dict = dict
self.specified_keys = set(include_keys)
if include_pattern is not None:
self.include_pattern = re.compile(include_pattern)
else:
# for performance, replace the pattern_keys property
self.pattern_keys = set()
def get_pattern_keys(self):
keys = filter(self.include_pattern.match, self.dict.keys())
return set(keys)
pattern_keys = NonDataProperty(get_pattern_keys)
@property
def include_keys(self):
return self.specified_keys | self.pattern_keys
def __getitem__(self, i):
if i not in self.include_keys:
raise KeyError(i)
return self.dict[i]
def __iter__(self):
return filter(self.include_keys.__contains__, self.dict.keys())
def __len__(self):
return len(list(self))
def dict_map(function, dictionary):
"""
Return a new dict with function applied to values of dictionary.
dict_map is much like the built-in function map. It takes a dictionary
and applies a function to the values of that dictionary, returning a
new dictionary with the mapped values in the original keys.
>>> dict_map(lambda x: x+1, dict(a=1, b=2))
{'a': 2, 'b': 3}
>>> d = dict_map(lambda x:x+1, dict(a=1, b=2))
>>> d == dict(a=2,b=3)
True
"""
return dict((key, function(value)) for key, value in dictionary.items())
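The hunks above swap the key-list based DictFilter for the predicate based Projection (and its inverse, Mask). Stripped of the Mapping machinery, both reduce to filtering a dict by a membership test; the sketch below uses plain dict comprehensions, not the library's code, to show the behaviour the doctests rely on.

# Plain-dict restatement of the Projection/Mask behaviour shown above;
# an illustration only, not jaraco.collections code.
sample = {'a': 1, 'b': 2, 'c': 3}
keys = {'a', 'c', 'd'}

projected = {k: v for k, v in sample.items() if k in keys}      # Projection
masked = {k: v for k, v in sample.items() if k not in keys}     # Mask

assert projected == {'a': 1, 'c': 3}
assert masked == {'b': 2}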
@ -129,7 +164,7 @@ class RangeMap(dict):
One may supply keyword parameters to be passed to the sort function used
to sort keys (i.e. key, reverse) as sort_params.
Create a map that maps 1-3 -> 'a', 4-6 -> 'b'
Let's create a map that maps 1-3 -> 'a', 4-6 -> 'b'
>>> r = RangeMap({3: 'a', 6: 'b'}) # boy, that was easy
>>> r[1], r[2], r[3], r[4], r[5], r[6]
@ -141,7 +176,7 @@ class RangeMap(dict):
>>> r[4.5]
'b'
Notice that the way rangemap is defined, it must be open-ended
But you'll notice that the way rangemap is defined, it must be open-ended
on one side.
>>> r[0]
@ -244,7 +279,7 @@ class RangeMap(dict):
return (sorted_keys[RangeMap.first_item], sorted_keys[RangeMap.last_item])
# some special values for the RangeMap
undefined_value = type('RangeValueUndefined', (), {})()
undefined_value = type(str('RangeValueUndefined'), (), {})()
class Item(int):
"RangeMap Item"
@ -259,7 +294,7 @@ def __identity(x):
def sorted_items(d, key=__identity, reverse=False):
"""
Return the items of the dictionary sorted by the keys.
Return the items of the dictionary sorted by the keys
>>> sample = dict(foo=20, bar=42, baz=10)
>>> tuple(sorted_items(sample))
@ -272,7 +307,6 @@ def sorted_items(d, key=__identity, reverse=False):
>>> tuple(sorted_items(sample, reverse=True))
(('foo', 20), ('baz', 10), ('bar', 42))
"""
# wrap the key func so it operates on the first element of each item
def pairkey_key(item):
return key(item[0])
@ -441,7 +475,7 @@ class ItemsAsAttributes:
Mix-in class to enable a mapping object to provide items as
attributes.
>>> C = type('C', (dict, ItemsAsAttributes), dict())
>>> C = type(str('C'), (dict, ItemsAsAttributes), dict())
>>> i = C()
>>> i['foo'] = 'bar'
>>> i.foo
@ -470,7 +504,7 @@ class ItemsAsAttributes:
>>> missing_func = lambda self, key: 'missing item'
>>> C = type(
... 'C',
... str('C'),
... (dict, ItemsAsAttributes),
... dict(__missing__ = missing_func),
... )

View file

@ -5,18 +5,10 @@ import functools
import tempfile
import shutil
import operator
import warnings
@contextlib.contextmanager
def pushd(dir):
"""
>>> tmp_path = getfixture('tmp_path')
>>> with pushd(tmp_path):
... assert os.getcwd() == os.fspath(tmp_path)
>>> assert os.getcwd() != os.fspath(tmp_path)
"""
orig = os.getcwd()
os.chdir(dir)
try:
@ -37,8 +29,6 @@ def tarball_context(url, target_dir=None, runner=None, pushd=pushd):
target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '')
if runner is None:
runner = functools.partial(subprocess.check_call, shell=True)
else:
warnings.warn("runner parameter is deprecated", DeprecationWarning)
# In the tar command, use --strip-components=1 to strip the first path and
# then
# use -C to cause the files to be extracted to {target_dir}. This ensures
@ -58,15 +48,6 @@ def tarball_context(url, target_dir=None, runner=None, pushd=pushd):
def infer_compression(url):
"""
Given a URL or filename, infer the compression code for tar.
>>> infer_compression('http://foo/bar.tar.gz')
'z'
>>> infer_compression('http://foo/bar.tgz')
'z'
>>> infer_compression('file.bz')
'j'
>>> infer_compression('file.xz')
'J'
"""
# cheat and just assume it's the last two characters
compression_indicator = url[-2:]
@ -80,12 +61,6 @@ def temp_dir(remover=shutil.rmtree):
"""
Create a temporary directory context. Pass a custom remover
to override the removal behavior.
>>> import pathlib
>>> with temp_dir() as the_dir:
... assert os.path.isdir(the_dir)
... _ = pathlib.Path(the_dir).joinpath('somefile').write_text('contents')
>>> assert not os.path.exists(the_dir)
"""
temp_dir = tempfile.mkdtemp()
try:
@ -115,12 +90,6 @@ def repo_context(url, branch=None, quiet=True, dest_ctx=temp_dir):
@contextlib.contextmanager
def null():
"""
A null context suitable to stand in for a meaningful context.
>>> with null() as value:
... assert value is None
"""
yield
@ -143,10 +112,6 @@ class ExceptionTrap:
... raise ValueError("1 + 1 is not 3")
>>> bool(trap)
True
>>> trap.value
ValueError('1 + 1 is not 3')
>>> trap.tb
<traceback object at ...>
>>> with ExceptionTrap(ValueError) as trap:
... raise Exception()
@ -246,43 +211,3 @@ class suppress(contextlib.suppress, contextlib.ContextDecorator):
... {}['']
>>> key_error()
"""
class on_interrupt(contextlib.ContextDecorator):
"""
Replace a KeyboardInterrupt with SystemExit(1)
>>> def do_interrupt():
... raise KeyboardInterrupt()
>>> on_interrupt('error')(do_interrupt)()
Traceback (most recent call last):
...
SystemExit: 1
>>> on_interrupt('error', code=255)(do_interrupt)()
Traceback (most recent call last):
...
SystemExit: 255
>>> on_interrupt('suppress')(do_interrupt)()
>>> with __import__('pytest').raises(KeyboardInterrupt):
... on_interrupt('ignore')(do_interrupt)()
"""
def __init__(
self,
action='error',
# py3.7 compat
# /,
code=1,
):
self.action = action
self.code = code
def __enter__(self):
return self
def __exit__(self, exctype, excinst, exctb):
if exctype is not KeyboardInterrupt or self.action == 'ignore':
return
elif self.action == 'error':
raise SystemExit(self.code) from excinst
return self.action == 'suppress'

View file

@ -1,4 +1,4 @@
import collections.abc
import collections
import functools
import inspect
import itertools
@ -9,6 +9,11 @@ import warnings
import more_itertools
from typing import Callable, TypeVar
CallableT = TypeVar("CallableT", bound=Callable[..., object])
def compose(*funcs):
"""
@ -34,6 +39,24 @@ def compose(*funcs):
return functools.reduce(compose_two, funcs)
def method_caller(method_name, *args, **kwargs):
"""
Return a function that will call a named method on the
target object with optional positional and keyword
arguments.
>>> lower = method_caller('lower')
>>> lower('MyString')
'mystring'
"""
def call_method(target):
func = getattr(target, method_name)
return func(*args, **kwargs)
return call_method
def once(func):
"""
Decorate func so it's only ever called the first time.
@ -76,7 +99,12 @@ def once(func):
return wrapper
def method_cache(method, cache_wrapper=functools.lru_cache()):
def method_cache(
method: CallableT,
cache_wrapper: Callable[
[CallableT], CallableT
] = functools.lru_cache(), # type: ignore[assignment]
) -> CallableT:
"""
Wrap lru_cache to support storing the cache data in the object instances.
@ -144,17 +172,22 @@ def method_cache(method, cache_wrapper=functools.lru_cache()):
for another implementation and additional justification.
"""
def wrapper(self, *args, **kwargs):
def wrapper(self: object, *args: object, **kwargs: object) -> object:
# it's the first call, replace the method with a cached, bound method
bound_method = types.MethodType(method, self)
bound_method: CallableT = types.MethodType( # type: ignore[assignment]
method, self
)
cached_method = cache_wrapper(bound_method)
setattr(self, method.__name__, cached_method)
return cached_method(*args, **kwargs)
# Support cache clear even before cache has been created.
wrapper.cache_clear = lambda: None
wrapper.cache_clear = lambda: None # type: ignore[attr-defined]
return _special_method_cache(method, cache_wrapper) or wrapper
return (
_special_method_cache(method, cache_wrapper) # type: ignore[return-value]
or wrapper
)
def _special_method_cache(method, cache_wrapper):
@ -170,13 +203,12 @@ def _special_method_cache(method, cache_wrapper):
"""
name = method.__name__
special_names = '__getattr__', '__getitem__'
if name not in special_names:
return None
return
wrapper_name = '__cached' + name
def proxy(self, /, *args, **kwargs):
def proxy(self, *args, **kwargs):
if wrapper_name not in vars(self):
bound = types.MethodType(method, self)
cache = cache_wrapper(bound)
@ -213,7 +245,7 @@ def result_invoke(action):
r"""
Decorate a function with an action function that is
invoked on the results returned from the decorated
function (for its side effect), then return the original
function (for its side-effect), then return the original
result.
>>> @result_invoke(print)
@ -237,7 +269,7 @@ def result_invoke(action):
return wrap
def invoke(f, /, *args, **kwargs):
def invoke(f, *args, **kwargs):
"""
Call a function for its side effect after initialization.
@ -272,15 +304,25 @@ def invoke(f, /, *args, **kwargs):
Use functools.partial to pass parameters to the initial call
>>> @functools.partial(invoke, name='bingo')
... def func(name): print('called with', name)
... def func(name): print("called with", name)
called with bingo
"""
f(*args, **kwargs)
return f
def call_aside(*args, **kwargs):
"""
Deprecated name for invoke.
"""
warnings.warn("call_aside is deprecated, use invoke", DeprecationWarning)
return invoke(*args, **kwargs)
class Throttler:
"""Rate-limit a function (or other callable)."""
"""
Rate-limit a function (or other callable)
"""
def __init__(self, func, max_rate=float('Inf')):
if isinstance(func, Throttler):
@ -297,20 +339,20 @@ class Throttler:
return self.func(*args, **kwargs)
def _wait(self):
"""Ensure at least 1/max_rate seconds from last call."""
"ensure at least 1/max_rate seconds from last call"
elapsed = time.time() - self.last_called
must_wait = 1 / self.max_rate - elapsed
time.sleep(max(0, must_wait))
self.last_called = time.time()
def __get__(self, obj, owner=None):
def __get__(self, obj, type=None):
return first_invoke(self._wait, functools.partial(self.func, obj))
def first_invoke(func1, func2):
"""
Return a function that when invoked will invoke func1 without
any parameters (for its side effect) and then invoke func2
any parameters (for its side-effect) and then invoke func2
with whatever parameters were passed, returning its result.
"""
@ -321,17 +363,6 @@ def first_invoke(func1, func2):
return wrapper
method_caller = first_invoke(
lambda: warnings.warn(
'`jaraco.functools.method_caller` is deprecated, '
'use `operator.methodcaller` instead',
DeprecationWarning,
stacklevel=3,
),
operator.methodcaller,
)
def retry_call(func, cleanup=lambda: None, retries=0, trap=()):
"""
Given a callable func, trap the indicated exceptions
@ -340,7 +371,7 @@ def retry_call(func, cleanup=lambda: None, retries=0, trap=()):
to propagate.
"""
attempts = itertools.count() if retries == float('inf') else range(retries)
for _ in attempts:
for attempt in attempts:
try:
return func()
except trap:
@ -377,7 +408,7 @@ def retry(*r_args, **r_kwargs):
def print_yielded(func):
"""
Convert a generator into a function that prints all yielded elements.
Convert a generator into a function that prints all yielded elements
>>> @print_yielded
... def x():
@ -393,7 +424,7 @@ def print_yielded(func):
def pass_none(func):
"""
Wrap func so it's not called if its first param is None.
Wrap func so it's not called if its first param is None
>>> print_text = pass_none(print)
>>> print_text('text')
@ -402,10 +433,9 @@ def pass_none(func):
"""
@functools.wraps(func)
def wrapper(param, /, *args, **kwargs):
def wrapper(param, *args, **kwargs):
if param is not None:
return func(param, *args, **kwargs)
return None
return wrapper
@ -479,7 +509,7 @@ def save_method_args(method):
args_and_kwargs = collections.namedtuple('args_and_kwargs', 'args kwargs')
@functools.wraps(method)
def wrapper(self, /, *args, **kwargs):
def wrapper(self, *args, **kwargs):
attr_name = '_saved_' + method.__name__
attr = args_and_kwargs(args, kwargs)
setattr(self, attr_name, attr)
@ -529,13 +559,6 @@ def except_(*exceptions, replace=None, use=None):
def identity(x):
"""
Return the argument.
>>> o = object()
>>> identity(o) is o
True
"""
return x
@ -557,7 +580,7 @@ def bypass_when(check, *, _op=identity):
def decorate(func):
@functools.wraps(func)
def wrapper(param, /):
def wrapper(param):
return param if _op(check) else func(param)
return wrapper
@ -581,53 +604,3 @@ def bypass_unless(check):
2
"""
return bypass_when(check, _op=operator.not_)
@functools.singledispatch
def _splat_inner(args, func):
"""Splat args to func."""
return func(*args)
@_splat_inner.register
def _(args: collections.abc.Mapping, func):
"""Splat kargs to func as kwargs."""
return func(**args)
def splat(func):
"""
Wrap func to expect its parameters to be passed positionally in a tuple.
Has a similar effect to that of ``itertools.starmap`` over
simple ``map``.
>>> pairs = [(-1, 1), (0, 2)]
>>> more_itertools.consume(itertools.starmap(print, pairs))
-1 1
0 2
>>> more_itertools.consume(map(splat(print), pairs))
-1 1
0 2
The approach generalizes to other iterators that don't have a "star"
equivalent, such as a "starfilter".
>>> list(filter(splat(operator.add), pairs))
[(0, 2)]
Splat also accepts a mapping argument.
>>> def is_nice(msg, code):
... return "smile" in msg or code == 0
>>> msgs = [
... dict(msg='smile!', code=20),
... dict(msg='error :(', code=1),
... dict(msg='unknown', code=0),
... ]
>>> for msg in filter(splat(is_nice), msgs):
... print(msg)
{'msg': 'smile!', 'code': 20}
{'msg': 'unknown', 'code': 0}
"""
return functools.wraps(func)(functools.partial(_splat_inner, func=func))
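Since Throttler's wait logic is reproduced in full above, a small self-contained sketch can make the rate-limit arithmetic concrete. This re-derives the 1/max_rate spacing in a closure rather than importing the class, so it is an illustration only.

# Self-contained restatement of Throttler._wait: enforce at least
# 1/max_rate seconds between calls to the wrapped function.
import time

def throttled(func, max_rate=2.0):
    last_called = 0.0

    def wrapper(*args, **kwargs):
        nonlocal last_called
        elapsed = time.time() - last_called
        must_wait = 1 / max_rate - elapsed
        time.sleep(max(0, must_wait))
        last_called = time.time()
        return func(*args, **kwargs)

    return wrapper

slow_print = throttled(print, max_rate=2.0)  # at most two calls per second
for i in range(3):
    slow_print(i)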

View file

@ -1,128 +0,0 @@
from collections.abc import Callable, Hashable, Iterator
from functools import partial
from operator import methodcaller
import sys
from typing import (
Any,
Generic,
Protocol,
TypeVar,
overload,
)
if sys.version_info >= (3, 10):
from typing import Concatenate, ParamSpec
else:
from typing_extensions import Concatenate, ParamSpec
_P = ParamSpec('_P')
_R = TypeVar('_R')
_T = TypeVar('_T')
_R1 = TypeVar('_R1')
_R2 = TypeVar('_R2')
_V = TypeVar('_V')
_S = TypeVar('_S')
_R_co = TypeVar('_R_co', covariant=True)
class _OnceCallable(Protocol[_P, _R]):
saved_result: _R
reset: Callable[[], None]
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
class _ProxyMethodCacheWrapper(Protocol[_R_co]):
cache_clear: Callable[[], None]
def __call__(self, *args: Hashable, **kwargs: Hashable) -> _R_co: ...
class _MethodCacheWrapper(Protocol[_R_co]):
def cache_clear(self) -> None: ...
def __call__(self, *args: Hashable, **kwargs: Hashable) -> _R_co: ...
# `compose()` overloads below will cover most use cases.
@overload
def compose(
__func1: Callable[[_R], _T],
__func2: Callable[_P, _R],
/,
) -> Callable[_P, _T]: ...
@overload
def compose(
__func1: Callable[[_R], _T],
__func2: Callable[[_R1], _R],
__func3: Callable[_P, _R1],
/,
) -> Callable[_P, _T]: ...
@overload
def compose(
__func1: Callable[[_R], _T],
__func2: Callable[[_R2], _R],
__func3: Callable[[_R1], _R2],
__func4: Callable[_P, _R1],
/,
) -> Callable[_P, _T]: ...
def once(func: Callable[_P, _R]) -> _OnceCallable[_P, _R]: ...
def method_cache(
method: Callable[..., _R],
cache_wrapper: Callable[[Callable[..., _R]], _MethodCacheWrapper[_R]] = ...,
) -> _MethodCacheWrapper[_R] | _ProxyMethodCacheWrapper[_R]: ...
def apply(
transform: Callable[[_R], _T]
) -> Callable[[Callable[_P, _R]], Callable[_P, _T]]: ...
def result_invoke(
action: Callable[[_R], Any]
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: ...
def invoke(
f: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs
) -> Callable[_P, _R]: ...
def call_aside(
f: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs
) -> Callable[_P, _R]: ...
class Throttler(Generic[_R]):
last_called: float
func: Callable[..., _R]
max_rate: float
def __init__(
self, func: Callable[..., _R] | Throttler[_R], max_rate: float = ...
) -> None: ...
def reset(self) -> None: ...
def __call__(self, *args: Any, **kwargs: Any) -> _R: ...
def __get__(self, obj: Any, owner: type[Any] | None = ...) -> Callable[..., _R]: ...
def first_invoke(
func1: Callable[..., Any], func2: Callable[_P, _R]
) -> Callable[_P, _R]: ...
method_caller: Callable[..., methodcaller]
def retry_call(
func: Callable[..., _R],
cleanup: Callable[..., None] = ...,
retries: int | float = ...,
trap: type[BaseException] | tuple[type[BaseException], ...] = ...,
) -> _R: ...
def retry(
cleanup: Callable[..., None] = ...,
retries: int | float = ...,
trap: type[BaseException] | tuple[type[BaseException], ...] = ...,
) -> Callable[[Callable[..., _R]], Callable[..., _R]]: ...
def print_yielded(func: Callable[_P, Iterator[Any]]) -> Callable[_P, None]: ...
def pass_none(
func: Callable[Concatenate[_T, _P], _R]
) -> Callable[Concatenate[_T, _P], _R]: ...
def assign_params(
func: Callable[..., _R], namespace: dict[str, Any]
) -> partial[_R]: ...
def save_method_args(
method: Callable[Concatenate[_S, _P], _R]
) -> Callable[Concatenate[_S, _P], _R]: ...
def except_(
*exceptions: type[BaseException], replace: Any = ..., use: Any = ...
) -> Callable[[Callable[_P, Any]], Callable[_P, Any]]: ...
def identity(x: _T) -> _T: ...
def bypass_when(
check: _V, *, _op: Callable[[_V], Any] = ...
) -> Callable[[Callable[[_T], _R]], Callable[[_T], _T | _R]]: ...
def bypass_unless(
check: Any,
) -> Callable[[Callable[[_T], _R]], Callable[[_T], _T | _R]]: ...

View file

@ -227,12 +227,10 @@ def unwrap(s):
return '\n'.join(cleaned)
lorem_ipsum: str = (
files(__name__).joinpath('Lorem ipsum.txt').read_text(encoding='utf-8')
)
lorem_ipsum: str = files(__name__).joinpath('Lorem ipsum.txt').read_text()
class Splitter:
class Splitter(object):
"""object that will split a string with the given arguments for each call
>>> s = Splitter(',')
@ -369,7 +367,7 @@ class WordSet(tuple):
return self.trim_left(item).trim_right(item)
def __getitem__(self, item):
result = super().__getitem__(item)
result = super(WordSet, self).__getitem__(item)
if isinstance(item, slice):
result = WordSet(result)
return result
@ -584,7 +582,7 @@ def join_continuation(lines):
['foobarbaz']
Not sure why, but...
The character preceding the backslash is also elided.
The character preceeding the backslash is also elided.
>>> list(join_continuation(['goo\\', 'dly']))
['godly']
@ -609,16 +607,16 @@ def read_newlines(filename, limit=1024):
r"""
>>> tmp_path = getfixture('tmp_path')
>>> filename = tmp_path / 'out.txt'
>>> _ = filename.write_text('foo\n', newline='', encoding='utf-8')
>>> _ = filename.write_text('foo\n', newline='')
>>> read_newlines(filename)
'\n'
>>> _ = filename.write_text('foo\r\n', newline='', encoding='utf-8')
>>> _ = filename.write_text('foo\r\n', newline='')
>>> read_newlines(filename)
'\r\n'
>>> _ = filename.write_text('foo\r\nbar\nbing\r', newline='', encoding='utf-8')
>>> _ = filename.write_text('foo\r\nbar\nbing\r', newline='')
>>> read_newlines(filename)
('\r', '\n', '\r\n')
"""
with open(filename, encoding='utf-8') as fp:
with open(filename) as fp:
fp.read(limit)
return fp.newlines

View file

@ -12,11 +12,11 @@ def report_newlines(filename):
>>> tmp_path = getfixture('tmp_path')
>>> filename = tmp_path / 'out.txt'
>>> _ = filename.write_text('foo\nbar\n', newline='', encoding='utf-8')
>>> _ = filename.write_text('foo\nbar\n', newline='')
>>> report_newlines(filename)
newline is '\n'
>>> filename = tmp_path / 'out.txt'
>>> _ = filename.write_text('foo\nbar\r\n', newline='', encoding='utf-8')
>>> _ = filename.write_text('foo\nbar\r\n', newline='')
>>> report_newlines(filename)
newlines are ('\n', '\r\n')
"""

View file

@ -1,21 +0,0 @@
import sys
import autocommand
from jaraco.text import Stripper
def strip_prefix():
r"""
Strip any common prefix from stdin.
>>> import io, pytest
>>> getfixture('monkeypatch').setattr('sys.stdin', io.StringIO('abcdef\nabc123'))
>>> strip_prefix()
def
123
"""
sys.stdout.writelines(Stripper.strip_prefix(sys.stdin).lines)
autocommand.autocommand(__name__)(strip_prefix)

View file

@ -3,4 +3,4 @@
from .more import * # noqa
from .recipes import * # noqa
__version__ = '10.2.0'
__version__ = '10.1.0'

View file

@ -19,7 +19,7 @@ from itertools import (
zip_longest,
product,
)
from math import exp, factorial, floor, log, perm, comb
from math import exp, factorial, floor, log
from queue import Empty, Queue
from random import random, randrange, uniform
from operator import itemgetter, mul, sub, gt, lt, ge, le
@ -68,10 +68,8 @@ __all__ = [
'divide',
'duplicates_everseen',
'duplicates_justseen',
'classify_unique',
'exactly_n',
'filter_except',
'filter_map',
'first',
'gray_product',
'groupby_transform',
@ -85,7 +83,6 @@ __all__ = [
'is_sorted',
'islice_extended',
'iterate',
'iter_suppress',
'last',
'locate',
'longest_common_prefix',
@ -201,14 +198,15 @@ def first(iterable, default=_marker):
``next(iter(iterable), default)``.
"""
for item in iterable:
return item
if default is _marker:
raise ValueError(
'first() was called on an empty iterable, and no '
'default value was provided.'
)
return default
try:
return next(iter(iterable))
except StopIteration as e:
if default is _marker:
raise ValueError(
'first() was called on an empty iterable, and no '
'default value was provided.'
) from e
return default
def last(iterable, default=_marker):
@ -584,9 +582,6 @@ def strictly_n(iterable, n, too_short=None, too_long=None):
>>> list(strictly_n(iterable, n))
['a', 'b', 'c', 'd']
Note that the returned iterable must be consumed in order for the check to
be made.
By default, *too_short* and *too_long* are functions that raise
``ValueError``.
@ -924,7 +919,7 @@ def substrings_indexes(seq, reverse=False):
class bucket:
"""Wrap *iterable* and return an object that buckets the iterable into
"""Wrap *iterable* and return an object that buckets it iterable into
child iterables based on a *key* function.
>>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
@ -3227,8 +3222,6 @@ class time_limited:
stops if the time elapsed is greater than *limit_seconds*. If your time
limit is 1 second, but it takes 2 seconds to generate the first item from
the iterable, the function will run for 2 seconds and not yield anything.
As a special case, when *limit_seconds* is zero, the iterator never
returns anything.
"""
@ -3244,9 +3237,6 @@ class time_limited:
return self
def __next__(self):
if self.limit_seconds == 0:
self.timed_out = True
raise StopIteration
item = next(self._iterable)
if monotonic() - self._start_time > self.limit_seconds:
self.timed_out = True
@ -3366,7 +3356,7 @@ def iequals(*iterables):
>>> iequals("abc", "acb")
False
Not to be confused with :func:`all_equal`, which checks whether all
Not to be confused with :func:`all_equals`, which checks whether all
elements of iterable are equal to each other.
"""
@ -3863,7 +3853,7 @@ def nth_permutation(iterable, r, index):
elif not 0 <= r < n:
raise ValueError
else:
c = perm(n, r)
c = factorial(n) // factorial(n - r)
if index < 0:
index += c
@ -3908,7 +3898,7 @@ def nth_combination_with_replacement(iterable, r, index):
if (r < 0) or (r > n):
raise ValueError
c = comb(n + r - 1, r)
c = factorial(n + r - 1) // (factorial(r) * factorial(n - 1))
if index < 0:
index += c
@ -3921,7 +3911,9 @@ def nth_combination_with_replacement(iterable, r, index):
while r:
r -= 1
while n >= 0:
num_combs = comb(n + r - 1, r)
num_combs = factorial(n + r - 1) // (
factorial(r) * factorial(n - 1)
)
if index < num_combs:
break
n -= 1
@ -4023,9 +4015,9 @@ def combination_index(element, iterable):
for i, j in enumerate(reversed(indexes), start=1):
j = n - j
if i <= j:
index += comb(j, i)
index += factorial(j) // (factorial(i) * factorial(j - i))
return comb(n + 1, k + 1) - index
return factorial(n + 1) // (factorial(k + 1) * factorial(n - k)) - index
def combination_with_replacement_index(element, iterable):
@ -4065,7 +4057,7 @@ def combination_with_replacement_index(element, iterable):
break
else:
raise ValueError(
'element is not a combination with replacement of iterable'
'element is not a combination with replacment of iterable'
)
n = len(pool)
@ -4074,13 +4066,11 @@ def combination_with_replacement_index(element, iterable):
occupations[p] += 1
index = 0
cumulative_sum = 0
for k in range(1, n):
cumulative_sum += occupations[k - 1]
j = l + n - 1 - k - cumulative_sum
j = l + n - 1 - k - sum(occupations[:k])
i = n - k
if i <= j:
index += comb(j, i)
index += factorial(j) // (factorial(i) * factorial(j - i))
return index
@ -4306,7 +4296,7 @@ def duplicates_everseen(iterable, key=None):
>>> list(duplicates_everseen('AaaBbbCccAaa', str.lower))
['a', 'a', 'b', 'b', 'c', 'c', 'A', 'a', 'a']
This function is analogous to :func:`unique_everseen` and is subject to
This function is analagous to :func:`unique_everseen` and is subject to
the same performance considerations.
"""
@ -4336,54 +4326,12 @@ def duplicates_justseen(iterable, key=None):
>>> list(duplicates_justseen('AaaBbbCccAaa', str.lower))
['a', 'a', 'b', 'b', 'c', 'c', 'a', 'a']
This function is analogous to :func:`unique_justseen`.
This function is analagous to :func:`unique_justseen`.
"""
return flatten(g for _, g in groupby(iterable, key) for _ in g)
def classify_unique(iterable, key=None):
"""Classify each element in terms of its uniqueness.
For each element in the input iterable, return a 3-tuple consisting of:
1. The element itself
2. ``False`` if the element is equal to the one preceding it in the input,
``True`` otherwise (i.e. the equivalent of :func:`unique_justseen`)
3. ``False`` if this element has been seen anywhere in the input before,
``True`` otherwise (i.e. the equivalent of :func:`unique_everseen`)
>>> list(classify_unique('otto')) # doctest: +NORMALIZE_WHITESPACE
[('o', True, True),
('t', True, True),
('t', False, False),
('o', True, False)]
This function is analogous to :func:`unique_everseen` and is subject to
the same performance considerations.
"""
seen_set = set()
seen_list = []
use_key = key is not None
previous = None
for i, element in enumerate(iterable):
k = key(element) if use_key else element
is_unique_justseen = not i or previous != k
previous = k
is_unique_everseen = False
try:
if k not in seen_set:
seen_set.add(k)
is_unique_everseen = True
except TypeError:
if k not in seen_list:
seen_list.append(k)
is_unique_everseen = True
yield element, is_unique_justseen, is_unique_everseen
def minmax(iterable_or_value, *others, key=None, default=_marker):
"""Returns both the smallest and largest items in an iterable
or the largest of two or more arguments.
@ -4581,8 +4529,10 @@ def takewhile_inclusive(predicate, iterable):
:func:`takewhile` would return ``[1, 4]``.
"""
for x in iterable:
yield x
if not predicate(x):
if predicate(x):
yield x
else:
yield x
break
@ -4617,40 +4567,3 @@ def outer_product(func, xs, ys, *args, **kwargs):
starmap(lambda x, y: func(x, y, *args, **kwargs), product(xs, ys)),
n=len(ys),
)
def iter_suppress(iterable, *exceptions):
"""Yield each of the items from *iterable*. If the iteration raises one of
the specified *exceptions*, that exception will be suppressed and iteration
will stop.
>>> from itertools import chain
>>> def breaks_at_five(x):
... while True:
... if x >= 5:
... raise RuntimeError
... yield x
... x += 1
>>> it_1 = iter_suppress(breaks_at_five(1), RuntimeError)
>>> it_2 = iter_suppress(breaks_at_five(2), RuntimeError)
>>> list(chain(it_1, it_2))
[1, 2, 3, 4, 2, 3, 4]
"""
try:
yield from iterable
except exceptions:
return
def filter_map(func, iterable):
"""Apply *func* to every element of *iterable*, yielding only those which
are not ``None``.
>>> elems = ['1', 'a', '2', 'b', '3']
>>> list(filter_map(lambda s: int(s) if s.isnumeric() else None, elems))
[1, 2, 3]
"""
for x in iterable:
y = func(x)
if y is not None:
yield y
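Both helpers removed by this hunk are shown in full, so a short standalone check (re-declaring them locally rather than importing a particular more-itertools version) illustrates how they compose:

# Local copies of the two helpers shown above, plus a tiny composition check.
def iter_suppress(iterable, *exceptions):
    try:
        yield from iterable
    except exceptions:
        return

def filter_map(func, iterable):
    for x in iterable:
        y = func(x)
        if y is not None:
            yield y

def numbers_then_boom():
    yield from ['1', 'x', '2']
    raise RuntimeError

safe = iter_suppress(numbers_then_boom(), RuntimeError)
print(list(filter_map(lambda s: int(s) if s.isnumeric() else None, safe)))
# [1, 2]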

View file

@ -29,7 +29,7 @@ _U = TypeVar('_U')
_V = TypeVar('_V')
_W = TypeVar('_W')
_T_co = TypeVar('_T_co', covariant=True)
_GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[Any]])
_GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[object]])
_Raisable = BaseException | Type[BaseException]
@type_check_only
@ -74,7 +74,7 @@ class peekable(Generic[_T], Iterator[_T]):
def __getitem__(self, index: slice) -> list[_T]: ...
def consumer(func: _GenFn) -> _GenFn: ...
def ilen(iterable: Iterable[_T]) -> int: ...
def ilen(iterable: Iterable[object]) -> int: ...
def iterate(func: Callable[[_T], _T], start: _T) -> Iterator[_T]: ...
def with_iter(
context_manager: ContextManager[Iterable[_T]],
@ -116,7 +116,7 @@ class bucket(Generic[_T, _U], Container[_U]):
self,
iterable: Iterable[_T],
key: Callable[[_T], _U],
validator: Callable[[_U], object] | None = ...,
validator: Callable[[object], object] | None = ...,
) -> None: ...
def __contains__(self, value: object) -> bool: ...
def __iter__(self) -> Iterator[_U]: ...
@ -383,7 +383,7 @@ def mark_ends(
iterable: Iterable[_T],
) -> Iterable[tuple[bool, bool, _T]]: ...
def locate(
iterable: Iterable[_T],
iterable: Iterable[object],
pred: Callable[..., Any] = ...,
window_size: int | None = ...,
) -> Iterator[int]: ...
@ -618,9 +618,6 @@ def duplicates_everseen(
def duplicates_justseen(
iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
) -> Iterator[_T]: ...
def classify_unique(
iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
) -> Iterator[tuple[_T, bool, bool]]: ...
class _SupportsLessThan(Protocol):
def __lt__(self, __other: Any) -> bool: ...
@ -665,9 +662,9 @@ def minmax(
def longest_common_prefix(
iterables: Iterable[Iterable[_T]],
) -> Iterator[_T]: ...
def iequals(*iterables: Iterable[Any]) -> bool: ...
def iequals(*iterables: Iterable[object]) -> bool: ...
def constrained_batches(
iterable: Iterable[_T],
iterable: Iterable[object],
max_size: int,
max_count: int | None = ...,
get_len: Callable[[_T], object] = ...,
@ -685,11 +682,3 @@ def outer_product(
*args: Any,
**kwargs: Any,
) -> Iterator[tuple[_V, ...]]: ...
def iter_suppress(
iterable: Iterable[_T],
*exceptions: Type[BaseException],
) -> Iterator[_T]: ...
def filter_map(
func: Callable[[_T], _V | None],
iterable: Iterable[_T],
) -> Iterator[_V]: ...

View file

@ -28,7 +28,6 @@ from itertools import (
zip_longest,
)
from random import randrange, sample, choice
from sys import hexversion
__all__ = [
'all_equal',
@ -57,7 +56,6 @@ __all__ = [
'powerset',
'prepend',
'quantify',
'reshape',
'random_combination_with_replacement',
'random_combination',
'random_permutation',
@ -71,7 +69,6 @@ __all__ = [
'tabulate',
'tail',
'take',
'totient',
'transpose',
'triplewise',
'unique_everseen',
@ -495,7 +492,7 @@ def unique_everseen(iterable, key=None):
>>> list(unique_everseen(iterable, key=tuple)) # Faster
[[1, 2], [2, 3]]
Similarly, you may want to convert unhashable ``set`` objects with
Similary, you may want to convert unhashable ``set`` objects with
``key=frozenset``. For ``dict`` objects,
``key=lambda x: frozenset(x.items())`` can be used.
@ -527,9 +524,6 @@ def unique_justseen(iterable, key=None):
['A', 'B', 'C', 'A', 'D']
"""
if key is None:
return map(operator.itemgetter(0), groupby(iterable))
return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
@ -823,34 +817,35 @@ def polynomial_from_roots(roots):
return list(reduce(convolve, factors, [1]))
def iter_index(iterable, value, start=0, stop=None):
def iter_index(iterable, value, start=0):
"""Yield the index of each place in *iterable* that *value* occurs,
beginning with index *start* and ending before index *stop*.
beginning with index *start*.
See :func:`locate` for a more general means of finding the indexes
associated with particular values.
>>> list(iter_index('AABCADEAF', 'A'))
[0, 1, 4, 7]
>>> list(iter_index('AABCADEAF', 'A', 1)) # start index is inclusive
[1, 4, 7]
>>> list(iter_index('AABCADEAF', 'A', 1, 7)) # stop index is not inclusive
[1, 4]
"""
seq_index = getattr(iterable, 'index', None)
if seq_index is None:
try:
seq_index = iterable.index
except AttributeError:
# Slow path for general iterables
it = islice(iterable, start, stop)
for i, element in enumerate(it, start):
if element is value or element == value:
yield i
else:
# Fast path for sequences
stop = len(iterable) if stop is None else stop
it = islice(iterable, start, None)
i = start - 1
try:
while True:
yield (i := seq_index(value, i + 1, stop))
i = i + operator.indexOf(it, value) + 1
yield i
except ValueError:
pass
else:
# Fast path for sequences
i = start - 1
try:
while True:
i = seq_index(value, i + 1)
yield i
except ValueError:
pass
@ -861,52 +856,47 @@ def sieve(n):
>>> list(sieve(30))
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
"""
if n > 2:
yield 2
start = 3
data = bytearray((0, 1)) * (n // 2)
data[:3] = 0, 0, 0
limit = math.isqrt(n) + 1
for p in iter_index(data, 1, start, limit):
yield from iter_index(data, 1, start, p * p)
for p in compress(range(limit), data):
data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p)))
start = p * p
yield from iter_index(data, 1, start)
data[2] = 1
return iter_index(data, 1) if n > 2 else iter([])
def _batched(iterable, n, *, strict=False):
"""Batch data into tuples of length *n*. If the number of items in
*iterable* is not divisible by *n*:
* The last batch will be shorter if *strict* is ``False``.
* :exc:`ValueError` will be raised if *strict* is ``True``.
def _batched(iterable, n):
"""Batch data into lists of length *n*. The last batch may be shorter.
>>> list(batched('ABCDEFG', 3))
[('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
On Python 3.13 and above, this is an alias for :func:`itertools.batched`.
On Python 3.12 and above, this is an alias for :func:`itertools.batched`.
"""
if n < 1:
raise ValueError('n must be at least one')
it = iter(iterable)
while batch := tuple(islice(it, n)):
if strict and len(batch) != n:
raise ValueError('batched(): incomplete batch')
while True:
batch = tuple(islice(it, n))
if not batch:
break
yield batch
if hexversion >= 0x30D00A2:
try:
from itertools import batched as itertools_batched
def batched(iterable, n, *, strict=False):
return itertools_batched(iterable, n, strict=strict)
else:
except ImportError:
batched = _batched
else:
def batched(iterable, n):
return itertools_batched(iterable, n)
batched.__doc__ = _batched.__doc__
def transpose(it):
"""Swap the rows and columns of the input matrix.
"""Swap the rows and columns of the input.
>>> list(transpose([(1, 2, 3), (11, 22, 33)]))
[(1, 11), (2, 22), (3, 33)]
@ -917,20 +907,8 @@ def transpose(it):
return _zip_strict(*it)
def reshape(matrix, cols):
"""Reshape the 2-D input *matrix* to have a column count given by *cols*.
>>> matrix = [(0, 1), (2, 3), (4, 5)]
>>> cols = 3
>>> list(reshape(matrix, cols))
[(0, 1, 2), (3, 4, 5)]
"""
return batched(chain.from_iterable(matrix), cols)
def matmul(m1, m2):
"""Multiply two matrices.
>>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
[(49, 80), (41, 60)]
@ -943,12 +921,13 @@ def matmul(m1, m2):
def factor(n):
"""Yield the prime factors of n.
>>> list(factor(360))
[2, 2, 2, 3, 3, 5]
"""
for prime in sieve(math.isqrt(n) + 1):
while not n % prime:
while True:
if n % prime:
break
yield prime
n //= prime
if n == 1:
@ -996,17 +975,3 @@ def polynomial_derivative(coefficients):
n = len(coefficients)
powers = reversed(range(1, n))
return list(map(operator.mul, coefficients, powers))
def totient(n):
"""Return the count of natural numbers up to *n* that are coprime with *n*.
>>> totient(9)
6
>>> totient(12)
4
"""
for p in unique_justseen(factor(n)):
n = n // p * (p - 1)
return n
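The update `n = n // p * (p - 1)` applies Euler's product formula, φ(n) = n · ∏(1 − 1/p) over the distinct prime factors p of n. A self-contained check against a brute-force count of coprimes (both helpers below are illustrative, not part of the recipes):

```py
# Sanity check for the totient recipe above; the helpers are illustrative.
from math import gcd


def distinct_prime_factors(n):
    p, found = 2, []
    while p * p <= n:
        if n % p == 0:
            found.append(p)
            while n % p == 0:
                n //= p
        p += 1
    if n > 1:
        found.append(n)
    return found


def formula_totient(n):
    for p in distinct_prime_factors(n):
        n = n // p * (p - 1)  # same step as the recipe: multiply by (1 - 1/p)
    return n


def brute_totient(n):
    return sum(1 for k in range(1, n + 1) if gcd(k, n) == 1)


assert all(formula_totient(n) == brute_totient(n) for n in range(1, 200))
print(formula_totient(9), formula_totient(12))  # 6 4
```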

View file

@ -14,8 +14,6 @@ from typing import (
# Type and type variable definitions
_T = TypeVar('_T')
_T1 = TypeVar('_T1')
_T2 = TypeVar('_T2')
_U = TypeVar('_U')
def take(n: int, iterable: Iterable[_T]) -> list[_T]: ...
@ -28,14 +26,14 @@ def consume(iterator: Iterable[_T], n: int | None = ...) -> None: ...
def nth(iterable: Iterable[_T], n: int) -> _T | None: ...
@overload
def nth(iterable: Iterable[_T], n: int, default: _U) -> _T | _U: ...
def all_equal(iterable: Iterable[_T]) -> bool: ...
def all_equal(iterable: Iterable[object]) -> bool: ...
def quantify(
iterable: Iterable[_T], pred: Callable[[_T], bool] = ...
) -> int: ...
def pad_none(iterable: Iterable[_T]) -> Iterator[_T | None]: ...
def padnone(iterable: Iterable[_T]) -> Iterator[_T | None]: ...
def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]: ...
def dotproduct(vec1: Iterable[_T1], vec2: Iterable[_T2]) -> Any: ...
def dotproduct(vec1: Iterable[object], vec2: Iterable[object]) -> object: ...
def flatten(listOfLists: Iterable[Iterable[_T]]) -> Iterator[_T]: ...
def repeatfunc(
func: Callable[..., _U], times: int | None = ..., *args: Any
@ -105,24 +103,20 @@ def sliding_window(
def subslices(iterable: Iterable[_T]) -> Iterator[list[_T]]: ...
def polynomial_from_roots(roots: Sequence[_T]) -> list[_T]: ...
def iter_index(
iterable: Iterable[_T],
iterable: Iterable[object],
value: Any,
start: int | None = ...,
stop: int | None = ...,
) -> Iterator[int]: ...
def sieve(n: int) -> Iterator[int]: ...
def batched(
iterable: Iterable[_T], n: int, *, strict: bool = False
iterable: Iterable[_T],
n: int,
) -> Iterator[tuple[_T]]: ...
def transpose(
it: Iterable[Iterable[_T]],
) -> Iterator[tuple[_T, ...]]: ...
def reshape(
matrix: Iterable[Iterable[_T]], cols: int
) -> Iterator[tuple[_T, ...]]: ...
def matmul(m1: Sequence[_T], m2: Sequence[_T]) -> Iterator[tuple[_T]]: ...
def factor(n: int) -> Iterator[int]: ...
def polynomial_eval(coefficients: Sequence[_T], x: _U) -> _U: ...
def sum_of_squares(it: Iterable[_T]) -> _T: ...
def polynomial_derivative(coefficients: Sequence[_T]) -> list[_T]: ...
def totient(n: int) -> int: ...

View file

@ -1,114 +1,56 @@
import typing
from ._migration import getattr_migration
from .version import VERSION
if typing.TYPE_CHECKING:
# import of virtually everything is supported via `__getattr__` below,
# but we need them here for type checking and IDE support
import pydantic_core
from pydantic_core.core_schema import (
FieldSerializationInfo,
SerializationInfo,
SerializerFunctionWrapHandler,
ValidationInfo,
ValidatorFunctionWrapHandler,
)
from . import dataclasses
from ._internal._generate_schema import GenerateSchema as GenerateSchema
from .aliases import AliasChoices, AliasGenerator, AliasPath
from .annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
from .config import ConfigDict
from .errors import *
from .fields import Field, PrivateAttr, computed_field
from .functional_serializers import (
PlainSerializer,
SerializeAsAny,
WrapSerializer,
field_serializer,
model_serializer,
)
from .functional_validators import (
AfterValidator,
BeforeValidator,
InstanceOf,
PlainValidator,
SkipValidation,
WrapValidator,
field_validator,
model_validator,
)
from .json_schema import WithJsonSchema
from .main import *
from .networks import *
from .type_adapter import TypeAdapter
from .types import *
from .validate_call_decorator import validate_call
from .warnings import PydanticDeprecatedSince20, PydanticDeprecatedSince26, PydanticDeprecationWarning
# this encourages pycharm to import `ValidationError` from here, not pydantic_core
ValidationError = pydantic_core.ValidationError
from .deprecated.class_validators import root_validator, validator
from .deprecated.config import BaseConfig, Extra
from .deprecated.tools import *
from .root_model import RootModel
# flake8: noqa
from . import dataclasses
from .annotated_types import create_model_from_namedtuple, create_model_from_typeddict
from .class_validators import root_validator, validator
from .config import BaseConfig, ConfigDict, Extra
from .decorator import validate_arguments
from .env_settings import BaseSettings
from .error_wrappers import ValidationError
from .errors import *
from .fields import Field, PrivateAttr, Required
from .main import *
from .networks import *
from .parse import Protocol
from .tools import *
from .types import *
from .version import VERSION, compiled
__version__ = VERSION
__all__ = (
# WARNING __all__ from .errors is not included here, it will be removed as an export here in v2
# please use "from pydantic.errors import ..." instead
__all__ = [
# annotated types utils
'create_model_from_namedtuple',
'create_model_from_typeddict',
# dataclasses
'dataclasses',
# functional validators
'field_validator',
'model_validator',
'AfterValidator',
'BeforeValidator',
'PlainValidator',
'WrapValidator',
'SkipValidation',
'InstanceOf',
# JSON Schema
'WithJsonSchema',
# deprecated V1 functional validators, these are imported via `__getattr__` below
# class_validators
'root_validator',
'validator',
# functional serializers
'field_serializer',
'model_serializer',
'PlainSerializer',
'SerializeAsAny',
'WrapSerializer',
# config
'ConfigDict',
# deprecated V1 config, these are imported via `__getattr__` below
'BaseConfig',
'ConfigDict',
'Extra',
# validate_call
'validate_call',
# errors
'PydanticErrorCodes',
'PydanticUserError',
'PydanticSchemaGenerationError',
'PydanticImportError',
'PydanticUndefinedAnnotation',
'PydanticInvalidForJsonSchema',
# decorator
'validate_arguments',
# env_settings
'BaseSettings',
# error_wrappers
'ValidationError',
# fields
'Field',
'computed_field',
'PrivateAttr',
# alias
'AliasChoices',
'AliasGenerator',
'AliasPath',
'Required',
# main
'BaseModel',
'create_model',
'validate_model',
# network
'AnyUrl',
'AnyHttpUrl',
'FileUrl',
'HttpUrl',
'UrlConstraints',
'stricturl',
'EmailStr',
'NameEmail',
'IPvAnyAddress',
@ -120,38 +62,48 @@ __all__ = (
'RedisDsn',
'MongoDsn',
'KafkaDsn',
'NatsDsn',
'MySQLDsn',
'MariaDBDsn',
'validate_email',
# root_model
'RootModel',
# deprecated tools, these are imported via `__getattr__` below
# parse
'Protocol',
# tools
'parse_file_as',
'parse_obj_as',
'parse_raw_as',
'schema_of',
'schema_json_of',
# types
'Strict',
'NoneStr',
'NoneBytes',
'StrBytes',
'NoneStrBytes',
'StrictStr',
'ConstrainedBytes',
'conbytes',
'ConstrainedList',
'conlist',
'ConstrainedSet',
'conset',
'ConstrainedFrozenSet',
'confrozenset',
'ConstrainedStr',
'constr',
'StringConstraints',
'ImportString',
'PyObject',
'ConstrainedInt',
'conint',
'PositiveInt',
'NegativeInt',
'NonNegativeInt',
'NonPositiveInt',
'ConstrainedFloat',
'confloat',
'PositiveFloat',
'NegativeFloat',
'NonNegativeFloat',
'NonPositiveFloat',
'FiniteFloat',
'ConstrainedDecimal',
'condecimal',
'ConstrainedDate',
'condate',
'UUID1',
'UUID3',
@ -159,8 +111,9 @@ __all__ = (
'UUID5',
'FilePath',
'DirectoryPath',
'NewPath',
'Json',
'JsonWrapper',
'SecretField',
'SecretStr',
'SecretBytes',
'StrictBool',
@ -168,221 +121,11 @@ __all__ = (
'StrictInt',
'StrictFloat',
'PaymentCardNumber',
'PrivateAttr',
'ByteSize',
'PastDate',
'FutureDate',
'PastDatetime',
'FutureDatetime',
'AwareDatetime',
'NaiveDatetime',
'AllowInfNan',
'EncoderProtocol',
'EncodedBytes',
'EncodedStr',
'Base64Encoder',
'Base64Bytes',
'Base64Str',
'Base64UrlBytes',
'Base64UrlStr',
'GetPydanticSchema',
'Tag',
'Discriminator',
'JsonValue',
# type_adapter
'TypeAdapter',
# version
'__version__',
'compiled',
'VERSION',
# warnings
'PydanticDeprecatedSince20',
'PydanticDeprecatedSince26',
'PydanticDeprecationWarning',
# annotated handlers
'GetCoreSchemaHandler',
'GetJsonSchemaHandler',
# generate schema from ._internal
'GenerateSchema',
# pydantic_core
'ValidationError',
'ValidationInfo',
'SerializationInfo',
'ValidatorFunctionWrapHandler',
'FieldSerializationInfo',
'SerializerFunctionWrapHandler',
'OnErrorOmit',
)
# A mapping of {<member name>: (package, <module name>)} defining dynamic imports
_dynamic_imports: 'dict[str, tuple[str, str]]' = {
'dataclasses': (__package__, '__module__'),
# functional validators
'field_validator': (__package__, '.functional_validators'),
'model_validator': (__package__, '.functional_validators'),
'AfterValidator': (__package__, '.functional_validators'),
'BeforeValidator': (__package__, '.functional_validators'),
'PlainValidator': (__package__, '.functional_validators'),
'WrapValidator': (__package__, '.functional_validators'),
'SkipValidation': (__package__, '.functional_validators'),
'InstanceOf': (__package__, '.functional_validators'),
# JSON Schema
'WithJsonSchema': (__package__, '.json_schema'),
# functional serializers
'field_serializer': (__package__, '.functional_serializers'),
'model_serializer': (__package__, '.functional_serializers'),
'PlainSerializer': (__package__, '.functional_serializers'),
'SerializeAsAny': (__package__, '.functional_serializers'),
'WrapSerializer': (__package__, '.functional_serializers'),
# config
'ConfigDict': (__package__, '.config'),
# validate call
'validate_call': (__package__, '.validate_call_decorator'),
# errors
'PydanticErrorCodes': (__package__, '.errors'),
'PydanticUserError': (__package__, '.errors'),
'PydanticSchemaGenerationError': (__package__, '.errors'),
'PydanticImportError': (__package__, '.errors'),
'PydanticUndefinedAnnotation': (__package__, '.errors'),
'PydanticInvalidForJsonSchema': (__package__, '.errors'),
# fields
'Field': (__package__, '.fields'),
'computed_field': (__package__, '.fields'),
'PrivateAttr': (__package__, '.fields'),
# alias
'AliasChoices': (__package__, '.aliases'),
'AliasGenerator': (__package__, '.aliases'),
'AliasPath': (__package__, '.aliases'),
# main
'BaseModel': (__package__, '.main'),
'create_model': (__package__, '.main'),
# network
'AnyUrl': (__package__, '.networks'),
'AnyHttpUrl': (__package__, '.networks'),
'FileUrl': (__package__, '.networks'),
'HttpUrl': (__package__, '.networks'),
'UrlConstraints': (__package__, '.networks'),
'EmailStr': (__package__, '.networks'),
'NameEmail': (__package__, '.networks'),
'IPvAnyAddress': (__package__, '.networks'),
'IPvAnyInterface': (__package__, '.networks'),
'IPvAnyNetwork': (__package__, '.networks'),
'PostgresDsn': (__package__, '.networks'),
'CockroachDsn': (__package__, '.networks'),
'AmqpDsn': (__package__, '.networks'),
'RedisDsn': (__package__, '.networks'),
'MongoDsn': (__package__, '.networks'),
'KafkaDsn': (__package__, '.networks'),
'NatsDsn': (__package__, '.networks'),
'MySQLDsn': (__package__, '.networks'),
'MariaDBDsn': (__package__, '.networks'),
'validate_email': (__package__, '.networks'),
# root_model
'RootModel': (__package__, '.root_model'),
# types
'Strict': (__package__, '.types'),
'StrictStr': (__package__, '.types'),
'conbytes': (__package__, '.types'),
'conlist': (__package__, '.types'),
'conset': (__package__, '.types'),
'confrozenset': (__package__, '.types'),
'constr': (__package__, '.types'),
'StringConstraints': (__package__, '.types'),
'ImportString': (__package__, '.types'),
'conint': (__package__, '.types'),
'PositiveInt': (__package__, '.types'),
'NegativeInt': (__package__, '.types'),
'NonNegativeInt': (__package__, '.types'),
'NonPositiveInt': (__package__, '.types'),
'confloat': (__package__, '.types'),
'PositiveFloat': (__package__, '.types'),
'NegativeFloat': (__package__, '.types'),
'NonNegativeFloat': (__package__, '.types'),
'NonPositiveFloat': (__package__, '.types'),
'FiniteFloat': (__package__, '.types'),
'condecimal': (__package__, '.types'),
'condate': (__package__, '.types'),
'UUID1': (__package__, '.types'),
'UUID3': (__package__, '.types'),
'UUID4': (__package__, '.types'),
'UUID5': (__package__, '.types'),
'FilePath': (__package__, '.types'),
'DirectoryPath': (__package__, '.types'),
'NewPath': (__package__, '.types'),
'Json': (__package__, '.types'),
'SecretStr': (__package__, '.types'),
'SecretBytes': (__package__, '.types'),
'StrictBool': (__package__, '.types'),
'StrictBytes': (__package__, '.types'),
'StrictInt': (__package__, '.types'),
'StrictFloat': (__package__, '.types'),
'PaymentCardNumber': (__package__, '.types'),
'ByteSize': (__package__, '.types'),
'PastDate': (__package__, '.types'),
'FutureDate': (__package__, '.types'),
'PastDatetime': (__package__, '.types'),
'FutureDatetime': (__package__, '.types'),
'AwareDatetime': (__package__, '.types'),
'NaiveDatetime': (__package__, '.types'),
'AllowInfNan': (__package__, '.types'),
'EncoderProtocol': (__package__, '.types'),
'EncodedBytes': (__package__, '.types'),
'EncodedStr': (__package__, '.types'),
'Base64Encoder': (__package__, '.types'),
'Base64Bytes': (__package__, '.types'),
'Base64Str': (__package__, '.types'),
'Base64UrlBytes': (__package__, '.types'),
'Base64UrlStr': (__package__, '.types'),
'GetPydanticSchema': (__package__, '.types'),
'Tag': (__package__, '.types'),
'Discriminator': (__package__, '.types'),
'JsonValue': (__package__, '.types'),
'OnErrorOmit': (__package__, '.types'),
# type_adapter
'TypeAdapter': (__package__, '.type_adapter'),
# warnings
'PydanticDeprecatedSince20': (__package__, '.warnings'),
'PydanticDeprecatedSince26': (__package__, '.warnings'),
'PydanticDeprecationWarning': (__package__, '.warnings'),
# annotated handlers
'GetCoreSchemaHandler': (__package__, '.annotated_handlers'),
'GetJsonSchemaHandler': (__package__, '.annotated_handlers'),
# generate schema from ._internal
'GenerateSchema': (__package__, '._internal._generate_schema'),
# pydantic_core stuff
'ValidationError': ('pydantic_core', '.'),
'ValidationInfo': ('pydantic_core', '.core_schema'),
'SerializationInfo': ('pydantic_core', '.core_schema'),
'ValidatorFunctionWrapHandler': ('pydantic_core', '.core_schema'),
'FieldSerializationInfo': ('pydantic_core', '.core_schema'),
'SerializerFunctionWrapHandler': ('pydantic_core', '.core_schema'),
# deprecated, mostly not included in __all__
'root_validator': (__package__, '.deprecated.class_validators'),
'validator': (__package__, '.deprecated.class_validators'),
'BaseConfig': (__package__, '.deprecated.config'),
'Extra': (__package__, '.deprecated.config'),
'parse_obj_as': (__package__, '.deprecated.tools'),
'schema_of': (__package__, '.deprecated.tools'),
'schema_json_of': (__package__, '.deprecated.tools'),
'FieldValidationInfo': ('pydantic_core', '.core_schema'),
}
_getattr_migration = getattr_migration(__name__)
def __getattr__(attr_name: str) -> object:
dynamic_attr = _dynamic_imports.get(attr_name)
if dynamic_attr is None:
return _getattr_migration(attr_name)
package, module_name = dynamic_attr
from importlib import import_module
if module_name == '__module__':
return import_module(f'.{attr_name}', package=package)
else:
module = import_module(module_name, package=package)
return getattr(module, attr_name)
def __dir__() -> 'list[str]':
return list(__all__)
]
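The `_dynamic_imports` table plus the module-level `__getattr__` above is the PEP 562 lazy-import pattern: an exported name is resolved to its defining submodule only on first attribute access. A minimal, hypothetical `mypkg/__init__.py` showing the same mechanism (the package and submodule names are placeholders, not pydantic's):

```py
# Hypothetical mypkg/__init__.py -- a sketch of the PEP 562 lazy-import
# pattern used above; all names here are placeholders.
from importlib import import_module

# {exported name: submodule that actually defines it}
_dynamic_imports = {
    'Widget': '.widgets',
    'tighten': '.helpers',
}

__all__ = list(_dynamic_imports)


def __getattr__(attr_name: str) -> object:
    try:
        module_name = _dynamic_imports[attr_name]
    except KeyError:
        raise AttributeError(f'module {__name__!r} has no attribute {attr_name!r}') from None
    # The submodule is imported only when the attribute is first requested.
    module = import_module(module_name, package=__package__)
    return getattr(module, attr_name)


def __dir__() -> 'list[str]':
    return list(__all__)
```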

View file

@ -10,7 +10,7 @@ Pydantic is installed. See also:
https://hypothesis.readthedocs.io/en/latest/strategies.html#registering-strategies-via-setuptools-entry-points
https://hypothesis.readthedocs.io/en/latest/data.html#hypothesis.strategies.register_type_strategy
https://hypothesis.readthedocs.io/en/latest/strategies.html#interaction-with-pytest-cov
https://docs.pydantic.dev/usage/types/#pydantic-types
https://pydantic-docs.helpmanual.io/usage/types/#pydantic-types
Note that because our motivation is to *improve user experience*, the strategies
are always sound (never generate invalid data) but sacrifice completeness for
@ -46,7 +46,7 @@ from pydantic.utils import lenient_issubclass
#
# conlist() and conset() are unsupported for now, because the workarounds for
# Cython and Hypothesis to handle parametrized generic types are incompatible.
# We are rethinking Hypothesis compatibility in Pydantic v2.
# Once Cython can support 'normal' generics we'll revisit this.
# Emails
try:
@ -168,11 +168,6 @@ st.register_type_strategy(pydantic.StrictBool, st.booleans())
st.register_type_strategy(pydantic.StrictStr, st.text())
# FutureDate, PastDate
st.register_type_strategy(pydantic.FutureDate, st.dates(min_value=datetime.date.today() + datetime.timedelta(days=1)))
st.register_type_strategy(pydantic.PastDate, st.dates(max_value=datetime.date.today() - datetime.timedelta(days=1)))
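`st.register_type_strategy` is the Hypothesis hook this whole plugin is built on: once a strategy is registered for a type, `st.from_type` (and everything layered on it) generates instances with that strategy. A standalone example using a made-up class rather than a pydantic type:

```py
# Standalone illustration of the Hypothesis hook used throughout this plugin;
# EvenInt is a made-up type, not part of pydantic.
import hypothesis.strategies as st
from hypothesis import given


class EvenInt(int):
    """Toy type: an int that we want Hypothesis to always make even."""


st.register_type_strategy(EvenInt, st.integers().map(lambda i: EvenInt(i * 2)))


@given(st.from_type(EvenInt))
def test_even(x):
    assert isinstance(x, EvenInt) and x % 2 == 0


test_even()  # runs the property; passes because every generated value is even
```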
# Constrained-type resolver functions
#
# For these ones, we actually want to inspect the type in order to work out a

View file

@ -1,322 +0,0 @@
from __future__ import annotations as _annotations
import warnings
from contextlib import contextmanager
from typing import (
TYPE_CHECKING,
Any,
Callable,
cast,
)
from pydantic_core import core_schema
from typing_extensions import (
Literal,
Self,
)
from ..aliases import AliasGenerator
from ..config import ConfigDict, ExtraValues, JsonDict, JsonEncoder, JsonSchemaExtraCallable
from ..errors import PydanticUserError
from ..warnings import PydanticDeprecatedSince20
if not TYPE_CHECKING:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
# and https://youtrack.jetbrains.com/issue/PY-51428
DeprecationWarning = PydanticDeprecatedSince20
if TYPE_CHECKING:
from .._internal._schema_generation_shared import GenerateSchema
DEPRECATION_MESSAGE = 'Support for class-based `config` is deprecated, use ConfigDict instead.'
class ConfigWrapper:
"""Internal wrapper for Config which exposes ConfigDict items as attributes."""
__slots__ = ('config_dict',)
config_dict: ConfigDict
# all annotations are copied directly from ConfigDict, and should be kept up to date, a test will fail if they
# stop matching
title: str | None
str_to_lower: bool
str_to_upper: bool
str_strip_whitespace: bool
str_min_length: int
str_max_length: int | None
extra: ExtraValues | None
frozen: bool
populate_by_name: bool
use_enum_values: bool
validate_assignment: bool
arbitrary_types_allowed: bool
from_attributes: bool
# whether to use the actual key provided in the data (e.g. alias or first alias for "field required" errors) instead of field_names
# to construct error `loc`s, default `True`
loc_by_alias: bool
alias_generator: Callable[[str], str] | AliasGenerator | None
ignored_types: tuple[type, ...]
allow_inf_nan: bool
json_schema_extra: JsonDict | JsonSchemaExtraCallable | None
json_encoders: dict[type[object], JsonEncoder] | None
# new in V2
strict: bool
# whether instances of models and dataclasses (including subclass instances) should re-validate, default 'never'
revalidate_instances: Literal['always', 'never', 'subclass-instances']
ser_json_timedelta: Literal['iso8601', 'float']
ser_json_bytes: Literal['utf8', 'base64']
ser_json_inf_nan: Literal['null', 'constants']
# whether to validate default values during validation, default False
validate_default: bool
validate_return: bool
protected_namespaces: tuple[str, ...]
hide_input_in_errors: bool
defer_build: bool
plugin_settings: dict[str, object] | None
schema_generator: type[GenerateSchema] | None
json_schema_serialization_defaults_required: bool
json_schema_mode_override: Literal['validation', 'serialization', None]
coerce_numbers_to_str: bool
regex_engine: Literal['rust-regex', 'python-re']
validation_error_cause: bool
def __init__(self, config: ConfigDict | dict[str, Any] | type[Any] | None, *, check: bool = True):
if check:
self.config_dict = prepare_config(config)
else:
self.config_dict = cast(ConfigDict, config)
@classmethod
def for_model(cls, bases: tuple[type[Any], ...], namespace: dict[str, Any], kwargs: dict[str, Any]) -> Self:
"""Build a new `ConfigWrapper` instance for a `BaseModel`.
The config wrapper is built based on (in descending order of priority):
- options from `kwargs`
- options from the `namespace`
- options from the base classes (`bases`)
Args:
bases: A tuple of base classes.
namespace: The namespace of the class being created.
kwargs: The kwargs passed to the class being created.
Returns:
A `ConfigWrapper` instance for `BaseModel`.
"""
config_new = ConfigDict()
for base in bases:
config = getattr(base, 'model_config', None)
if config:
config_new.update(config.copy())
config_class_from_namespace = namespace.get('Config')
config_dict_from_namespace = namespace.get('model_config')
if config_class_from_namespace and config_dict_from_namespace:
raise PydanticUserError('"Config" and "model_config" cannot be used together', code='config-both')
config_from_namespace = config_dict_from_namespace or prepare_config(config_class_from_namespace)
config_new.update(config_from_namespace)
for k in list(kwargs.keys()):
if k in config_keys:
config_new[k] = kwargs.pop(k)
return cls(config_new)
# we don't show `__getattr__` to type checkers so missing attributes cause errors
if not TYPE_CHECKING: # pragma: no branch
def __getattr__(self, name: str) -> Any:
try:
return self.config_dict[name]
except KeyError:
try:
return config_defaults[name]
except KeyError:
raise AttributeError(f'Config has no attribute {name!r}') from None
def core_config(self, obj: Any) -> core_schema.CoreConfig:
"""Create a pydantic-core config, `obj` is just used to populate `title` if not set in config.
Pass `obj=None` if you do not want to attempt to infer the `title`.
We don't use getattr here since we don't want to populate with defaults.
Args:
obj: An object used to populate `title` if not set in config.
Returns:
A `CoreConfig` object created from config.
"""
def dict_not_none(**kwargs: Any) -> Any:
return {k: v for k, v in kwargs.items() if v is not None}
core_config = core_schema.CoreConfig(
**dict_not_none(
title=self.config_dict.get('title') or (obj and obj.__name__),
extra_fields_behavior=self.config_dict.get('extra'),
allow_inf_nan=self.config_dict.get('allow_inf_nan'),
populate_by_name=self.config_dict.get('populate_by_name'),
str_strip_whitespace=self.config_dict.get('str_strip_whitespace'),
str_to_lower=self.config_dict.get('str_to_lower'),
str_to_upper=self.config_dict.get('str_to_upper'),
strict=self.config_dict.get('strict'),
ser_json_timedelta=self.config_dict.get('ser_json_timedelta'),
ser_json_bytes=self.config_dict.get('ser_json_bytes'),
ser_json_inf_nan=self.config_dict.get('ser_json_inf_nan'),
from_attributes=self.config_dict.get('from_attributes'),
loc_by_alias=self.config_dict.get('loc_by_alias'),
revalidate_instances=self.config_dict.get('revalidate_instances'),
validate_default=self.config_dict.get('validate_default'),
str_max_length=self.config_dict.get('str_max_length'),
str_min_length=self.config_dict.get('str_min_length'),
hide_input_in_errors=self.config_dict.get('hide_input_in_errors'),
coerce_numbers_to_str=self.config_dict.get('coerce_numbers_to_str'),
regex_engine=self.config_dict.get('regex_engine'),
validation_error_cause=self.config_dict.get('validation_error_cause'),
)
)
return core_config
def __repr__(self):
c = ', '.join(f'{k}={v!r}' for k, v in self.config_dict.items())
return f'ConfigWrapper({c})'
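From the user's side, the priority order `for_model` implements (class keyword arguments over the class namespace over base classes) is easiest to see on a `BaseModel`. A small sketch, assuming pydantic v2; the merged `model_config` is what ends up wrapped in a `ConfigWrapper`:

```py
# A sketch of the merge order implemented by ConfigWrapper.for_model,
# assuming pydantic v2: base classes -> class namespace -> class kwargs.
from pydantic import BaseModel, ConfigDict


class Base(BaseModel):
    model_config = ConfigDict(str_strip_whitespace=True)


class Child(Base, frozen=True):                    # class kwargs win
    model_config = ConfigDict(str_to_lower=True)   # namespace beats the bases


print(Child.model_config)
# expected: {'str_strip_whitespace': True, 'str_to_lower': True, 'frozen': True}
```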
class ConfigWrapperStack:
"""A stack of `ConfigWrapper` instances."""
def __init__(self, config_wrapper: ConfigWrapper):
self._config_wrapper_stack: list[ConfigWrapper] = [config_wrapper]
@property
def tail(self) -> ConfigWrapper:
return self._config_wrapper_stack[-1]
@contextmanager
def push(self, config_wrapper: ConfigWrapper | ConfigDict | None):
if config_wrapper is None:
yield
return
if not isinstance(config_wrapper, ConfigWrapper):
config_wrapper = ConfigWrapper(config_wrapper, check=False)
self._config_wrapper_stack.append(config_wrapper)
try:
yield
finally:
self._config_wrapper_stack.pop()
config_defaults = ConfigDict(
title=None,
str_to_lower=False,
str_to_upper=False,
str_strip_whitespace=False,
str_min_length=0,
str_max_length=None,
# let the model / dataclass decide how to handle it
extra=None,
frozen=False,
populate_by_name=False,
use_enum_values=False,
validate_assignment=False,
arbitrary_types_allowed=False,
from_attributes=False,
loc_by_alias=True,
alias_generator=None,
ignored_types=(),
allow_inf_nan=True,
json_schema_extra=None,
strict=False,
revalidate_instances='never',
ser_json_timedelta='iso8601',
ser_json_bytes='utf8',
ser_json_inf_nan='null',
validate_default=False,
validate_return=False,
protected_namespaces=('model_',),
hide_input_in_errors=False,
json_encoders=None,
defer_build=False,
plugin_settings=None,
schema_generator=None,
json_schema_serialization_defaults_required=False,
json_schema_mode_override=None,
coerce_numbers_to_str=False,
regex_engine='rust-regex',
validation_error_cause=False,
)
def prepare_config(config: ConfigDict | dict[str, Any] | type[Any] | None) -> ConfigDict:
"""Create a `ConfigDict` instance from an existing dict, a class (e.g. old class-based config) or None.
Args:
config: The input config.
Returns:
A ConfigDict object created from config.
"""
if config is None:
return ConfigDict()
if not isinstance(config, dict):
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning)
config = {k: getattr(config, k) for k in dir(config) if not k.startswith('__')}
config_dict = cast(ConfigDict, config)
check_deprecated(config_dict)
return config_dict
config_keys = set(ConfigDict.__annotations__.keys())
V2_REMOVED_KEYS = {
'allow_mutation',
'error_msg_templates',
'fields',
'getter_dict',
'smart_union',
'underscore_attrs_are_private',
'json_loads',
'json_dumps',
'copy_on_model_validation',
'post_init_call',
}
V2_RENAMED_KEYS = {
'allow_population_by_field_name': 'populate_by_name',
'anystr_lower': 'str_to_lower',
'anystr_strip_whitespace': 'str_strip_whitespace',
'anystr_upper': 'str_to_upper',
'keep_untouched': 'ignored_types',
'max_anystr_length': 'str_max_length',
'min_anystr_length': 'str_min_length',
'orm_mode': 'from_attributes',
'schema_extra': 'json_schema_extra',
'validate_all': 'validate_default',
}
def check_deprecated(config_dict: ConfigDict) -> None:
"""Check for deprecated config keys and warn the user.
Args:
config_dict: The input config.
"""
deprecated_removed_keys = V2_REMOVED_KEYS & config_dict.keys()
deprecated_renamed_keys = V2_RENAMED_KEYS.keys() & config_dict.keys()
if deprecated_removed_keys or deprecated_renamed_keys:
renamings = {k: V2_RENAMED_KEYS[k] for k in sorted(deprecated_renamed_keys)}
renamed_bullets = [f'* {k!r} has been renamed to {v!r}' for k, v in renamings.items()]
removed_bullets = [f'* {k!r} has been removed' for k in sorted(deprecated_removed_keys)]
message = '\n'.join(['Valid config keys have changed in V2:'] + renamed_bullets + removed_bullets)
warnings.warn(message, UserWarning)
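A short demonstration of what these warnings look like in practice, assuming pydantic v2: a V1-style key passed through a class-based `Config` triggers both the `prepare_config` deprecation warning and the renamed-key warning from `check_deprecated`.

```py
# Assuming pydantic v2; the exact warning text is paraphrased in the comments.
import warnings

from pydantic import BaseModel

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')

    class LegacyModel(BaseModel):
        class Config:        # deprecated class-based config
            orm_mode = True  # V1 name for from_attributes

for w in caught:
    print(w.category.__name__, '->', w.message)
# Expect a PydanticDeprecatedSince20 warning about class-based config and a
# UserWarning starting with "Valid config keys have changed in V2:".
```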

View file

@ -1,92 +0,0 @@
from __future__ import annotations as _annotations
import typing
from typing import Any
import typing_extensions
if typing.TYPE_CHECKING:
from ._schema_generation_shared import (
CoreSchemaOrField as CoreSchemaOrField,
)
from ._schema_generation_shared import (
GetJsonSchemaFunction,
)
class CoreMetadata(typing_extensions.TypedDict, total=False):
"""A `TypedDict` for holding the metadata dict of the schema.
Attributes:
pydantic_js_functions: List of JSON schema functions.
pydantic_js_prefer_positional_arguments: Whether the JSON schema generator will
prefer positional over keyword arguments for an 'arguments' schema.
"""
pydantic_js_functions: list[GetJsonSchemaFunction]
pydantic_js_annotation_functions: list[GetJsonSchemaFunction]
# If `pydantic_js_prefer_positional_arguments` is True, the JSON schema generator will
# prefer positional over keyword arguments for an 'arguments' schema.
pydantic_js_prefer_positional_arguments: bool | None
pydantic_typed_dict_cls: type[Any] | None # TODO: Consider moving this into the pydantic-core TypedDictSchema
class CoreMetadataHandler:
"""Because the metadata field in pydantic_core is of type `Any`, we can't assume much about its contents.
This class is used to interact with the metadata field on a CoreSchema object in a consistent
way throughout pydantic.
"""
__slots__ = ('_schema',)
def __init__(self, schema: CoreSchemaOrField):
self._schema = schema
metadata = schema.get('metadata')
if metadata is None:
schema['metadata'] = CoreMetadata()
elif not isinstance(metadata, dict):
raise TypeError(f'CoreSchema metadata should be a dict; got {metadata!r}.')
@property
def metadata(self) -> CoreMetadata:
"""Retrieves the metadata dict from the schema, initializing it to a dict if it is None
and raising an error if it is not a dict.
"""
metadata = self._schema.get('metadata')
if metadata is None:
self._schema['metadata'] = metadata = CoreMetadata()
if not isinstance(metadata, dict):
raise TypeError(f'CoreSchema metadata should be a dict; got {metadata!r}.')
return metadata
def build_metadata_dict(
*, # force keyword arguments to make it easier to modify this signature in a backwards-compatible way
js_functions: list[GetJsonSchemaFunction] | None = None,
js_annotation_functions: list[GetJsonSchemaFunction] | None = None,
js_prefer_positional_arguments: bool | None = None,
typed_dict_cls: type[Any] | None = None,
initial_metadata: Any | None = None,
) -> Any:
"""Builds a dict to use as the metadata field of a CoreSchema object in a manner that is consistent
with the CoreMetadataHandler class.
"""
if initial_metadata is not None and not isinstance(initial_metadata, dict):
raise TypeError(f'CoreSchema metadata should be a dict; got {initial_metadata!r}.')
metadata = CoreMetadata(
pydantic_js_functions=js_functions or [],
pydantic_js_annotation_functions=js_annotation_functions or [],
pydantic_js_prefer_positional_arguments=js_prefer_positional_arguments,
pydantic_typed_dict_cls=typed_dict_cls,
)
metadata = {k: v for k, v in metadata.items() if v is not None}
if initial_metadata is not None:
metadata = {**initial_metadata, **metadata}
return metadata

View file

@ -1,570 +0,0 @@
from __future__ import annotations
import os
from collections import defaultdict
from typing import (
Any,
Callable,
Hashable,
TypeVar,
Union,
)
from pydantic_core import CoreSchema, core_schema
from pydantic_core import validate_core_schema as _validate_core_schema
from typing_extensions import TypeAliasType, TypeGuard, get_args, get_origin
from . import _repr
from ._typing_extra import is_generic_alias
AnyFunctionSchema = Union[
core_schema.AfterValidatorFunctionSchema,
core_schema.BeforeValidatorFunctionSchema,
core_schema.WrapValidatorFunctionSchema,
core_schema.PlainValidatorFunctionSchema,
]
FunctionSchemaWithInnerSchema = Union[
core_schema.AfterValidatorFunctionSchema,
core_schema.BeforeValidatorFunctionSchema,
core_schema.WrapValidatorFunctionSchema,
]
CoreSchemaField = Union[
core_schema.ModelField, core_schema.DataclassField, core_schema.TypedDictField, core_schema.ComputedField
]
CoreSchemaOrField = Union[core_schema.CoreSchema, CoreSchemaField]
_CORE_SCHEMA_FIELD_TYPES = {'typed-dict-field', 'dataclass-field', 'model-field', 'computed-field'}
_FUNCTION_WITH_INNER_SCHEMA_TYPES = {'function-before', 'function-after', 'function-wrap'}
_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'set', 'frozenset'}
_DEFINITIONS_CACHE_METADATA_KEY = 'pydantic.definitions_cache'
TAGGED_UNION_TAG_KEY = 'pydantic.internal.tagged_union_tag'
"""
Used in a `Tag` schema to specify the tag used for a discriminated union.
"""
HAS_INVALID_SCHEMAS_METADATA_KEY = 'pydantic.internal.invalid'
"""Used to mark a schema that is invalid because it refers to a definition that was not yet defined when the
schema was first encountered.
"""
def is_core_schema(
schema: CoreSchemaOrField,
) -> TypeGuard[CoreSchema]:
return schema['type'] not in _CORE_SCHEMA_FIELD_TYPES
def is_core_schema_field(
schema: CoreSchemaOrField,
) -> TypeGuard[CoreSchemaField]:
return schema['type'] in _CORE_SCHEMA_FIELD_TYPES
def is_function_with_inner_schema(
schema: CoreSchemaOrField,
) -> TypeGuard[FunctionSchemaWithInnerSchema]:
return schema['type'] in _FUNCTION_WITH_INNER_SCHEMA_TYPES
def is_list_like_schema_with_items_schema(
schema: CoreSchema,
) -> TypeGuard[core_schema.ListSchema | core_schema.SetSchema | core_schema.FrozenSetSchema]:
return schema['type'] in _LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES
def get_type_ref(type_: type[Any], args_override: tuple[type[Any], ...] | None = None) -> str:
"""Produces the ref to be used for this type by pydantic_core's core schemas.
This `args_override` argument was added for the purpose of creating valid recursive references
when creating generic models without needing to create a concrete class.
"""
origin = get_origin(type_) or type_
args = get_args(type_) if is_generic_alias(type_) else (args_override or ())
generic_metadata = getattr(type_, '__pydantic_generic_metadata__', None)
if generic_metadata:
origin = generic_metadata['origin'] or origin
args = generic_metadata['args'] or args
module_name = getattr(origin, '__module__', '<No __module__>')
if isinstance(origin, TypeAliasType):
type_ref = f'{module_name}.{origin.__name__}:{id(origin)}'
else:
try:
qualname = getattr(origin, '__qualname__', f'<No __qualname__: {origin}>')
except Exception:
qualname = getattr(origin, '__qualname__', '<No __qualname__>')
type_ref = f'{module_name}.{qualname}:{id(origin)}'
arg_refs: list[str] = []
for arg in args:
if isinstance(arg, str):
# Handle string literals as a special case; we may be able to remove this special handling if we
# wrap them in a ForwardRef at some point.
arg_ref = f'{arg}:str-{id(arg)}'
else:
arg_ref = f'{_repr.display_as_type(arg)}:{id(arg)}'
arg_refs.append(arg_ref)
if arg_refs:
type_ref = f'{type_ref}[{",".join(arg_refs)}]'
return type_ref
def get_ref(s: core_schema.CoreSchema) -> None | str:
"""Get the ref from the schema if it has one.
This exists just for type checking to work correctly.
"""
return s.get('ref', None)
def collect_definitions(schema: core_schema.CoreSchema) -> dict[str, core_schema.CoreSchema]:
defs: dict[str, CoreSchema] = {}
def _record_valid_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
ref = get_ref(s)
if ref:
defs[ref] = s
return recurse(s, _record_valid_refs)
walk_core_schema(schema, _record_valid_refs)
return defs
def define_expected_missing_refs(
schema: core_schema.CoreSchema, allowed_missing_refs: set[str]
) -> core_schema.CoreSchema | None:
if not allowed_missing_refs:
# in this case, there are no missing refs to potentially substitute, so there's no need to walk the schema
# this is a common case (will be hit for all non-generic models), so it's worth optimizing for
return None
refs = collect_definitions(schema).keys()
expected_missing_refs = allowed_missing_refs.difference(refs)
if expected_missing_refs:
definitions: list[core_schema.CoreSchema] = [
# TODO: Replace this with a (new) CoreSchema that, if present at any level, makes validation fail
# Issue: https://github.com/pydantic/pydantic-core/issues/619
core_schema.none_schema(ref=ref, metadata={HAS_INVALID_SCHEMAS_METADATA_KEY: True})
for ref in expected_missing_refs
]
return core_schema.definitions_schema(schema, definitions)
return None
def collect_invalid_schemas(schema: core_schema.CoreSchema) -> bool:
invalid = False
def _is_schema_valid(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
nonlocal invalid
if 'metadata' in s:
metadata = s['metadata']
if HAS_INVALID_SCHEMAS_METADATA_KEY in metadata:
invalid = metadata[HAS_INVALID_SCHEMAS_METADATA_KEY]
return s
return recurse(s, _is_schema_valid)
walk_core_schema(schema, _is_schema_valid)
return invalid
T = TypeVar('T')
Recurse = Callable[[core_schema.CoreSchema, 'Walk'], core_schema.CoreSchema]
Walk = Callable[[core_schema.CoreSchema, Recurse], core_schema.CoreSchema]
# TODO: Should we move _WalkCoreSchema into pydantic_core proper?
# Issue: https://github.com/pydantic/pydantic-core/issues/615
class _WalkCoreSchema:
def __init__(self):
self._schema_type_to_method = self._build_schema_type_to_method()
def _build_schema_type_to_method(self) -> dict[core_schema.CoreSchemaType, Recurse]:
mapping: dict[core_schema.CoreSchemaType, Recurse] = {}
key: core_schema.CoreSchemaType
for key in get_args(core_schema.CoreSchemaType):
method_name = f"handle_{key.replace('-', '_')}_schema"
mapping[key] = getattr(self, method_name, self._handle_other_schemas)
return mapping
def walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
return f(schema, self._walk)
def _walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
schema = self._schema_type_to_method[schema['type']](schema.copy(), f)
ser_schema: core_schema.SerSchema | None = schema.get('serialization') # type: ignore
if ser_schema:
schema['serialization'] = self._handle_ser_schemas(ser_schema, f)
return schema
def _handle_other_schemas(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
sub_schema = schema.get('schema', None)
if sub_schema is not None:
schema['schema'] = self.walk(sub_schema, f) # type: ignore
return schema
def _handle_ser_schemas(self, ser_schema: core_schema.SerSchema, f: Walk) -> core_schema.SerSchema:
schema: core_schema.CoreSchema | None = ser_schema.get('schema', None)
if schema is not None:
ser_schema['schema'] = self.walk(schema, f) # type: ignore
return_schema: core_schema.CoreSchema | None = ser_schema.get('return_schema', None)
if return_schema is not None:
ser_schema['return_schema'] = self.walk(return_schema, f) # type: ignore
return ser_schema
def handle_definitions_schema(self, schema: core_schema.DefinitionsSchema, f: Walk) -> core_schema.CoreSchema:
new_definitions: list[core_schema.CoreSchema] = []
for definition in schema['definitions']:
if 'schema_ref' in definition and 'ref' in definition:
# This indicates a purposely indirect reference
# We want to keep such references around for implications related to JSON schema, etc.:
new_definitions.append(definition)
# However, we still need to walk the referenced definition:
self.walk(definition, f)
continue
updated_definition = self.walk(definition, f)
if 'ref' in updated_definition:
# If the updated definition schema doesn't have a 'ref', it shouldn't go in the definitions
# This is most likely to happen due to replacing something with a definition reference, in
# which case it should certainly not go in the definitions list
new_definitions.append(updated_definition)
new_inner_schema = self.walk(schema['schema'], f)
if not new_definitions and len(schema) == 3:
# This means we'd be returning a "trivial" definitions schema that just wrapped the inner schema
return new_inner_schema
new_schema = schema.copy()
new_schema['schema'] = new_inner_schema
new_schema['definitions'] = new_definitions
return new_schema
def handle_list_schema(self, schema: core_schema.ListSchema, f: Walk) -> core_schema.CoreSchema:
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema
def handle_set_schema(self, schema: core_schema.SetSchema, f: Walk) -> core_schema.CoreSchema:
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema
def handle_frozenset_schema(self, schema: core_schema.FrozenSetSchema, f: Walk) -> core_schema.CoreSchema:
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema
def handle_generator_schema(self, schema: core_schema.GeneratorSchema, f: Walk) -> core_schema.CoreSchema:
items_schema = schema.get('items_schema')
if items_schema is not None:
schema['items_schema'] = self.walk(items_schema, f)
return schema
def handle_tuple_schema(self, schema: core_schema.TupleSchema, f: Walk) -> core_schema.CoreSchema:
schema['items_schema'] = [self.walk(v, f) for v in schema['items_schema']]
return schema
def handle_dict_schema(self, schema: core_schema.DictSchema, f: Walk) -> core_schema.CoreSchema:
keys_schema = schema.get('keys_schema')
if keys_schema is not None:
schema['keys_schema'] = self.walk(keys_schema, f)
values_schema = schema.get('values_schema')
if values_schema:
schema['values_schema'] = self.walk(values_schema, f)
return schema
def handle_function_schema(self, schema: AnyFunctionSchema, f: Walk) -> core_schema.CoreSchema:
if not is_function_with_inner_schema(schema):
return schema
schema['schema'] = self.walk(schema['schema'], f)
return schema
def handle_union_schema(self, schema: core_schema.UnionSchema, f: Walk) -> core_schema.CoreSchema:
new_choices: list[CoreSchema | tuple[CoreSchema, str]] = []
for v in schema['choices']:
if isinstance(v, tuple):
new_choices.append((self.walk(v[0], f), v[1]))
else:
new_choices.append(self.walk(v, f))
schema['choices'] = new_choices
return schema
def handle_tagged_union_schema(self, schema: core_schema.TaggedUnionSchema, f: Walk) -> core_schema.CoreSchema:
new_choices: dict[Hashable, core_schema.CoreSchema] = {}
for k, v in schema['choices'].items():
new_choices[k] = v if isinstance(v, (str, int)) else self.walk(v, f)
schema['choices'] = new_choices
return schema
def handle_chain_schema(self, schema: core_schema.ChainSchema, f: Walk) -> core_schema.CoreSchema:
schema['steps'] = [self.walk(v, f) for v in schema['steps']]
return schema
def handle_lax_or_strict_schema(self, schema: core_schema.LaxOrStrictSchema, f: Walk) -> core_schema.CoreSchema:
schema['lax_schema'] = self.walk(schema['lax_schema'], f)
schema['strict_schema'] = self.walk(schema['strict_schema'], f)
return schema
def handle_json_or_python_schema(self, schema: core_schema.JsonOrPythonSchema, f: Walk) -> core_schema.CoreSchema:
schema['json_schema'] = self.walk(schema['json_schema'], f)
schema['python_schema'] = self.walk(schema['python_schema'], f)
return schema
def handle_model_fields_schema(self, schema: core_schema.ModelFieldsSchema, f: Walk) -> core_schema.CoreSchema:
extras_schema = schema.get('extras_schema')
if extras_schema is not None:
schema['extras_schema'] = self.walk(extras_schema, f)
replaced_fields: dict[str, core_schema.ModelField] = {}
replaced_computed_fields: list[core_schema.ComputedField] = []
for computed_field in schema.get('computed_fields', ()):
replaced_field = computed_field.copy()
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
replaced_computed_fields.append(replaced_field)
if replaced_computed_fields:
schema['computed_fields'] = replaced_computed_fields
for k, v in schema['fields'].items():
replaced_field = v.copy()
replaced_field['schema'] = self.walk(v['schema'], f)
replaced_fields[k] = replaced_field
schema['fields'] = replaced_fields
return schema
def handle_typed_dict_schema(self, schema: core_schema.TypedDictSchema, f: Walk) -> core_schema.CoreSchema:
extras_schema = schema.get('extras_schema')
if extras_schema is not None:
schema['extras_schema'] = self.walk(extras_schema, f)
replaced_computed_fields: list[core_schema.ComputedField] = []
for computed_field in schema.get('computed_fields', ()):
replaced_field = computed_field.copy()
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
replaced_computed_fields.append(replaced_field)
if replaced_computed_fields:
schema['computed_fields'] = replaced_computed_fields
replaced_fields: dict[str, core_schema.TypedDictField] = {}
for k, v in schema['fields'].items():
replaced_field = v.copy()
replaced_field['schema'] = self.walk(v['schema'], f)
replaced_fields[k] = replaced_field
schema['fields'] = replaced_fields
return schema
def handle_dataclass_args_schema(self, schema: core_schema.DataclassArgsSchema, f: Walk) -> core_schema.CoreSchema:
replaced_fields: list[core_schema.DataclassField] = []
replaced_computed_fields: list[core_schema.ComputedField] = []
for computed_field in schema.get('computed_fields', ()):
replaced_field = computed_field.copy()
replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
replaced_computed_fields.append(replaced_field)
if replaced_computed_fields:
schema['computed_fields'] = replaced_computed_fields
for field in schema['fields']:
replaced_field = field.copy()
replaced_field['schema'] = self.walk(field['schema'], f)
replaced_fields.append(replaced_field)
schema['fields'] = replaced_fields
return schema
def handle_arguments_schema(self, schema: core_schema.ArgumentsSchema, f: Walk) -> core_schema.CoreSchema:
replaced_arguments_schema: list[core_schema.ArgumentsParameter] = []
for param in schema['arguments_schema']:
replaced_param = param.copy()
replaced_param['schema'] = self.walk(param['schema'], f)
replaced_arguments_schema.append(replaced_param)
schema['arguments_schema'] = replaced_arguments_schema
if 'var_args_schema' in schema:
schema['var_args_schema'] = self.walk(schema['var_args_schema'], f)
if 'var_kwargs_schema' in schema:
schema['var_kwargs_schema'] = self.walk(schema['var_kwargs_schema'], f)
return schema
def handle_call_schema(self, schema: core_schema.CallSchema, f: Walk) -> core_schema.CoreSchema:
schema['arguments_schema'] = self.walk(schema['arguments_schema'], f)
if 'return_schema' in schema:
schema['return_schema'] = self.walk(schema['return_schema'], f)
return schema
_dispatch = _WalkCoreSchema().walk
def walk_core_schema(schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
"""Recursively traverse a CoreSchema.
Args:
schema (core_schema.CoreSchema): The CoreSchema to process, it will not be modified.
f (Walk): A function to apply. This function takes two arguments:
1. The current CoreSchema that is being processed
(not the same one you passed into this function, one level down).
2. The "next" `f` to call. This lets you for example use `f=functools.partial(some_method, some_context)`
to pass data down the recursive calls without using globals or other mutable state.
Returns:
core_schema.CoreSchema: A processed CoreSchema.
"""
return f(schema.copy(), _dispatch)
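The `f(schema, recurse)` contract is easiest to see in isolation: the walker hands your function the current node plus a `recurse` callback, and the function decides whether and how to descend. A toy, dict-based version of the same pattern (no pydantic_core involved):

```py
# Toy, dict-based illustration of the walk/recurse contract above;
# 'items_schema' stands in for any nested schema slot.
def walk(schema, f):
    return f(schema, _walk)


def _walk(schema, f):
    # Mimic the dispatcher: descend into the one nested slot we know about.
    inner = schema.get('items_schema')
    if inner is not None:
        schema = {**schema, 'items_schema': walk(inner, f)}
    return schema


def tag_ints(s, recurse):
    # Transform 'int' nodes, then let the walker keep descending.
    if s.get('type') == 'int':
        s = {**s, 'tagged': True}
    return recurse(s, tag_ints)


schema = {'type': 'list', 'items_schema': {'type': 'list', 'items_schema': {'type': 'int'}}}
print(walk(schema, tag_ints))
# {'type': 'list', 'items_schema': {'type': 'list',
#                                   'items_schema': {'type': 'int', 'tagged': True}}}
```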
def simplify_schema_references(schema: core_schema.CoreSchema) -> core_schema.CoreSchema: # noqa: C901
definitions: dict[str, core_schema.CoreSchema] = {}
ref_counts: dict[str, int] = defaultdict(int)
involved_in_recursion: dict[str, bool] = {}
current_recursion_ref_count: dict[str, int] = defaultdict(int)
def collect_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if s['type'] == 'definitions':
for definition in s['definitions']:
ref = get_ref(definition)
assert ref is not None
if ref not in definitions:
definitions[ref] = definition
recurse(definition, collect_refs)
return recurse(s['schema'], collect_refs)
else:
ref = get_ref(s)
if ref is not None:
new = recurse(s, collect_refs)
new_ref = get_ref(new)
if new_ref:
definitions[new_ref] = new
return core_schema.definition_reference_schema(schema_ref=ref)
else:
return recurse(s, collect_refs)
schema = walk_core_schema(schema, collect_refs)
def count_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if s['type'] != 'definition-ref':
return recurse(s, count_refs)
ref = s['schema_ref']
ref_counts[ref] += 1
if ref_counts[ref] >= 2:
# If this model is involved in a recursion, this should be detected
# on its second encounter; we can safely stop the walk here.
if current_recursion_ref_count[ref] != 0:
involved_in_recursion[ref] = True
return s
current_recursion_ref_count[ref] += 1
recurse(definitions[ref], count_refs)
current_recursion_ref_count[ref] -= 1
return s
schema = walk_core_schema(schema, count_refs)
assert all(c == 0 for c in current_recursion_ref_count.values()), 'this is a bug! please report it'
def can_be_inlined(s: core_schema.DefinitionReferenceSchema, ref: str) -> bool:
if ref_counts[ref] > 1:
return False
if involved_in_recursion.get(ref, False):
return False
if 'serialization' in s:
return False
if 'metadata' in s:
metadata = s['metadata']
for k in (
'pydantic_js_functions',
'pydantic_js_annotation_functions',
'pydantic.internal.union_discriminator',
):
if k in metadata:
# we need to keep this as a ref
return False
return True
def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if s['type'] == 'definition-ref':
ref = s['schema_ref']
# Check if the reference is only used once, not involved in recursion and does not have
# any extra keys (like 'serialization')
if can_be_inlined(s, ref):
# Inline the reference by replacing the reference with the actual schema
new = definitions.pop(ref)
ref_counts[ref] -= 1 # because we just replaced it!
# put all other keys that were on the def-ref schema into the inlined version
# in particular this is needed for `serialization`
if 'serialization' in s:
new['serialization'] = s['serialization']
s = recurse(new, inline_refs)
return s
else:
return recurse(s, inline_refs)
else:
return recurse(s, inline_refs)
schema = walk_core_schema(schema, inline_refs)
def_values = [v for v in definitions.values() if ref_counts[v['ref']] > 0] # type: ignore
if def_values:
schema = core_schema.definitions_schema(schema=schema, definitions=def_values)
return schema
def _strip_metadata(schema: CoreSchema) -> CoreSchema:
def strip_metadata(s: CoreSchema, recurse: Recurse) -> CoreSchema:
s = s.copy()
s.pop('metadata', None)
if s['type'] == 'model-fields':
s = s.copy()
s['fields'] = {k: v.copy() for k, v in s['fields'].items()}
for field_name, field_schema in s['fields'].items():
field_schema.pop('metadata', None)
s['fields'][field_name] = field_schema
computed_fields = s.get('computed_fields', None)
if computed_fields:
s['computed_fields'] = [cf.copy() for cf in computed_fields]
for cf in computed_fields:
cf.pop('metadata', None)
else:
s.pop('computed_fields', None)
elif s['type'] == 'model':
# remove some defaults
if s.get('custom_init', True) is False:
s.pop('custom_init')
if s.get('root_model', True) is False:
s.pop('root_model')
if {'title'}.issuperset(s.get('config', {}).keys()):
s.pop('config', None)
return recurse(s, strip_metadata)
return walk_core_schema(schema, strip_metadata)
def pretty_print_core_schema(
schema: CoreSchema,
include_metadata: bool = False,
) -> None:
"""Pretty print a CoreSchema using rich.
This is intended for debugging purposes.
Args:
schema: The CoreSchema to print.
include_metadata: Whether to include metadata in the output. Defaults to `False`.
"""
from rich import print # type: ignore # install it manually in your dev env
if not include_metadata:
schema = _strip_metadata(schema)
return print(schema)
def validate_core_schema(schema: CoreSchema) -> CoreSchema:
if 'PYDANTIC_SKIP_VALIDATING_CORE_SCHEMAS' in os.environ:
return schema
return _validate_core_schema(schema)

View file

@ -1,225 +0,0 @@
"""Private logic for creating pydantic dataclasses."""
from __future__ import annotations as _annotations
import dataclasses
import typing
import warnings
from functools import partial, wraps
from typing import Any, Callable, ClassVar
from pydantic_core import (
ArgsKwargs,
SchemaSerializer,
SchemaValidator,
core_schema,
)
from typing_extensions import TypeGuard
from ..errors import PydanticUndefinedAnnotation
from ..fields import FieldInfo
from ..plugin._schema_validator import create_schema_validator
from ..warnings import PydanticDeprecatedSince20
from . import _config, _decorators, _typing_extra
from ._fields import collect_dataclass_fields
from ._generate_schema import GenerateSchema
from ._generics import get_standard_typevars_map
from ._mock_val_ser import set_dataclass_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._signature import generate_pydantic_signature
if typing.TYPE_CHECKING:
from ..config import ConfigDict
class StandardDataclass(typing.Protocol):
__dataclass_fields__: ClassVar[dict[str, Any]]
__dataclass_params__: ClassVar[Any] # in reality `dataclasses._DataclassParams`
__post_init__: ClassVar[Callable[..., None]]
def __init__(self, *args: object, **kwargs: object) -> None:
pass
class PydanticDataclass(StandardDataclass, typing.Protocol):
"""A protocol containing attributes only available once a class has been decorated as a Pydantic dataclass.
Attributes:
__pydantic_config__: Pydantic-specific configuration settings for the dataclass.
__pydantic_complete__: Whether dataclass building is completed, or if there are still undefined fields.
__pydantic_core_schema__: The pydantic-core schema used to build the SchemaValidator and SchemaSerializer.
__pydantic_decorators__: Metadata containing the decorators defined on the dataclass.
__pydantic_fields__: Metadata about the fields defined on the dataclass.
__pydantic_serializer__: The pydantic-core SchemaSerializer used to dump instances of the dataclass.
__pydantic_validator__: The pydantic-core SchemaValidator used to validate instances of the dataclass.
"""
__pydantic_config__: ClassVar[ConfigDict]
__pydantic_complete__: ClassVar[bool]
__pydantic_core_schema__: ClassVar[core_schema.CoreSchema]
__pydantic_decorators__: ClassVar[_decorators.DecoratorInfos]
__pydantic_fields__: ClassVar[dict[str, FieldInfo]]
__pydantic_serializer__: ClassVar[SchemaSerializer]
__pydantic_validator__: ClassVar[SchemaValidator]
else:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
# and https://youtrack.jetbrains.com/issue/PY-51428
DeprecationWarning = PydanticDeprecatedSince20
def set_dataclass_fields(cls: type[StandardDataclass], types_namespace: dict[str, Any] | None = None) -> None:
"""Collect and set `cls.__pydantic_fields__`.
Args:
cls: The class.
types_namespace: The types namespace, defaults to `None`.
"""
typevars_map = get_standard_typevars_map(cls)
fields = collect_dataclass_fields(cls, types_namespace, typevars_map=typevars_map)
cls.__pydantic_fields__ = fields # type: ignore
def complete_dataclass(
cls: type[Any],
config_wrapper: _config.ConfigWrapper,
*,
raise_errors: bool = True,
types_namespace: dict[str, Any] | None,
) -> bool:
"""Finish building a pydantic dataclass.
This logic is called on a class which has already been wrapped in `dataclasses.dataclass()`.
This is somewhat analogous to `pydantic._internal._model_construction.complete_model_class`.
Args:
cls: The class.
config_wrapper: The config wrapper instance.
raise_errors: Whether to raise errors, defaults to `True`.
types_namespace: The types namespace.
Returns:
`True` if building a pydantic dataclass is successfully completed, `False` otherwise.
Raises:
PydanticUndefinedAnnotation: If `raise_errors` is `True` and there is an undefined annotation.
"""
if hasattr(cls, '__post_init_post_parse__'):
warnings.warn(
'Support for `__post_init_post_parse__` has been dropped, the method will not be called', DeprecationWarning
)
if types_namespace is None:
types_namespace = _typing_extra.get_cls_types_namespace(cls)
set_dataclass_fields(cls, types_namespace)
typevars_map = get_standard_typevars_map(cls)
gen_schema = GenerateSchema(
config_wrapper,
types_namespace,
typevars_map,
)
# This needs to be called before we change the __init__
sig = generate_pydantic_signature(
init=cls.__init__,
fields=cls.__pydantic_fields__, # type: ignore
config_wrapper=config_wrapper,
is_dataclass=True,
)
# dataclass.__init__ must be defined here so its `__qualname__` can be changed since functions can't be copied.
def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) -> None:
__tracebackhide__ = True
s = __dataclass_self__
s.__pydantic_validator__.validate_python(ArgsKwargs(args, kwargs), self_instance=s)
__init__.__qualname__ = f'{cls.__qualname__}.__init__'
cls.__init__ = __init__ # type: ignore
cls.__pydantic_config__ = config_wrapper.config_dict # type: ignore
cls.__signature__ = sig # type: ignore
get_core_schema = getattr(cls, '__get_pydantic_core_schema__', None)
try:
if get_core_schema:
schema = get_core_schema(
cls,
CallbackGetCoreSchemaHandler(
partial(gen_schema.generate_schema, from_dunder_get_core_schema=False),
gen_schema,
ref_mode='unpack',
),
)
else:
schema = gen_schema.generate_schema(cls, from_dunder_get_core_schema=False)
except PydanticUndefinedAnnotation as e:
if raise_errors:
raise
set_dataclass_mocks(cls, cls.__name__, f'`{e.name}`')
return False
core_config = config_wrapper.core_config(cls)
try:
schema = gen_schema.clean_schema(schema)
except gen_schema.CollectedInvalid:
set_dataclass_mocks(cls, cls.__name__, 'all referenced types')
return False
# We are about to set all the remaining required properties expected for this cast;
# __pydantic_decorators__ and __pydantic_fields__ should already be set
cls = typing.cast('type[PydanticDataclass]', cls)
# debug(schema)
cls.__pydantic_core_schema__ = schema
cls.__pydantic_validator__ = validator = create_schema_validator(
schema, cls, cls.__module__, cls.__qualname__, 'dataclass', core_config, config_wrapper.plugin_settings
)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
if config_wrapper.validate_assignment:
@wraps(cls.__setattr__)
def validated_setattr(instance: Any, __field: str, __value: str) -> None:
validator.validate_assignment(instance, __field, __value)
cls.__setattr__ = validated_setattr.__get__(None, cls) # type: ignore
return True
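The `validated_setattr` wrapper installed at the end of `complete_dataclass` is what routes attribute assignment back through the validator. A small usage sketch, assuming pydantic v2 is installed:

```py
# Assuming pydantic v2: validate_assignment=True routes attribute assignment
# through the generated validator built above.
import pydantic


@pydantic.dataclasses.dataclass(config=pydantic.ConfigDict(validate_assignment=True))
class Point:
    x: int
    y: int


p = Point(x='1', y=2)   # '1' is coerced to 1 by the generated validator
p.x = '3'               # assignment is validated (and coerced) as well
print(p)                # Point(x=3, y=2)

try:
    p.y = 'not a number'
except pydantic.ValidationError as exc:
    print(exc.error_count(), 'validation error')
```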
def is_builtin_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
"""Returns True if a class is a stdlib dataclass and *not* a pydantic dataclass.
We check that
- `_cls` is a dataclass
- `_cls` does not inherit from a processed pydantic dataclass (and thus does not have a `__pydantic_validator__`)
- `_cls` does not have any annotations that are not dataclass fields
e.g.
```py
import dataclasses
import pydantic.dataclasses
@dataclasses.dataclass
class A:
x: int
@pydantic.dataclasses.dataclass
class B(A):
y: int
```
In this case, when we first check `B`, we make an extra check and look at the annotations ('y'),
which won't be a superset of all the dataclass fields (only the stdlib fields, i.e. 'x').
Args:
cls: The class.
Returns:
`True` if the class is a stdlib dataclass, `False` otherwise.
"""
return (
dataclasses.is_dataclass(_cls)
and not hasattr(_cls, '__pydantic_validator__')
and set(_cls.__dataclass_fields__).issuperset(set(getattr(_cls, '__annotations__', {})))
)
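# A minimal illustrative sketch (added for clarity, not part of the original source):
# how the check above behaves. `PlainPoint` is a hypothetical stdlib dataclass.
def _demo_is_builtin_dataclass() -> None:
    import dataclasses as _dataclasses

    @_dataclasses.dataclass
    class PlainPoint:
        x: int
        y: int

    # A stdlib dataclass: no __pydantic_validator__, and __dataclass_fields__
    # ({'x', 'y'}) covers every annotation, so the check passes.
    assert is_builtin_dataclass(PlainPoint)

    # Simulating a pydantic-processed dataclass by attaching the marker attribute
    # makes the check fail.
    PlainPoint.__pydantic_validator__ = object()  # type: ignore[attr-defined]
    assert not is_builtin_dataclass(PlainPoint)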

View file

@@ -1,791 +0,0 @@
"""Logic related to validators applied to models etc. via the `@field_validator` and `@model_validator` decorators."""
from __future__ import annotations as _annotations
from collections import deque
from dataclasses import dataclass, field
from functools import cached_property, partial, partialmethod
from inspect import Parameter, Signature, isdatadescriptor, ismethoddescriptor, signature
from itertools import islice
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Iterable, TypeVar, Union
from pydantic_core import PydanticUndefined, core_schema
from typing_extensions import Literal, TypeAlias, is_typeddict
from ..errors import PydanticUserError
from ._core_utils import get_type_ref
from ._internal_dataclass import slots_true
from ._typing_extra import get_function_type_hints
if TYPE_CHECKING:
from ..fields import ComputedFieldInfo
from ..functional_validators import FieldValidatorModes
@dataclass(**slots_true)
class ValidatorDecoratorInfo:
"""A container for data from `@validator` so that we can access it
while building the pydantic-core schema.
Attributes:
decorator_repr: A class variable representing the decorator string, '@validator'.
fields: A tuple of field names the validator should be called on.
mode: The proposed validator mode.
each_item: For complex objects (sets, lists etc.) whether to validate individual
elements rather than the whole object.
always: Whether this method and other validators should be called even if the value is missing.
check_fields: Whether to check that the fields actually exist on the model.
"""
decorator_repr: ClassVar[str] = '@validator'
fields: tuple[str, ...]
mode: Literal['before', 'after']
each_item: bool
always: bool
check_fields: bool | None
@dataclass(**slots_true)
class FieldValidatorDecoratorInfo:
"""A container for data from `@field_validator` so that we can access it
while building the pydantic-core schema.
Attributes:
decorator_repr: A class variable representing the decorator string, '@field_validator'.
fields: A tuple of field names the validator should be called on.
mode: The proposed validator mode.
check_fields: Whether to check that the fields actually exist on the model.
"""
decorator_repr: ClassVar[str] = '@field_validator'
fields: tuple[str, ...]
mode: FieldValidatorModes
check_fields: bool | None
@dataclass(**slots_true)
class RootValidatorDecoratorInfo:
"""A container for data from `@root_validator` so that we can access it
while building the pydantic-core schema.
Attributes:
decorator_repr: A class variable representing the decorator string, '@root_validator'.
mode: The proposed validator mode.
"""
decorator_repr: ClassVar[str] = '@root_validator'
mode: Literal['before', 'after']
@dataclass(**slots_true)
class FieldSerializerDecoratorInfo:
"""A container for data from `@field_serializer` so that we can access it
while building the pydantic-core schema.
Attributes:
decorator_repr: A class variable representing the decorator string, '@field_serializer'.
fields: A tuple of field names the serializer should be called on.
mode: The proposed serializer mode.
return_type: The type of the serializer's return value.
when_used: The serialization condition. Accepts a string with values `'always'`, `'unless-none'`, `'json'`,
and `'json-unless-none'`.
check_fields: Whether to check that the fields actually exist on the model.
"""
decorator_repr: ClassVar[str] = '@field_serializer'
fields: tuple[str, ...]
mode: Literal['plain', 'wrap']
return_type: Any
when_used: core_schema.WhenUsed
check_fields: bool | None
@dataclass(**slots_true)
class ModelSerializerDecoratorInfo:
"""A container for data from `@model_serializer` so that we can access it
while building the pydantic-core schema.
Attributes:
decorator_repr: A class variable representing the decorator string, '@model_serializer'.
mode: The proposed serializer mode.
return_type: The type of the serializer's return value.
when_used: The serialization condition. Accepts a string with values `'always'`, `'unless-none'`, `'json'`,
and `'json-unless-none'`.
"""
decorator_repr: ClassVar[str] = '@model_serializer'
mode: Literal['plain', 'wrap']
return_type: Any
when_used: core_schema.WhenUsed
@dataclass(**slots_true)
class ModelValidatorDecoratorInfo:
"""A container for data from `@model_validator` so that we can access it
while building the pydantic-core schema.
Attributes:
decorator_repr: A class variable representing the decorator string, '@model_validator'.
mode: The proposed validator mode.
"""
decorator_repr: ClassVar[str] = '@model_validator'
mode: Literal['wrap', 'before', 'after']
DecoratorInfo: TypeAlias = """Union[
ValidatorDecoratorInfo,
FieldValidatorDecoratorInfo,
RootValidatorDecoratorInfo,
FieldSerializerDecoratorInfo,
ModelSerializerDecoratorInfo,
ModelValidatorDecoratorInfo,
ComputedFieldInfo,
]"""
ReturnType = TypeVar('ReturnType')
DecoratedType: TypeAlias = (
'Union[classmethod[Any, Any, ReturnType], staticmethod[Any, ReturnType], Callable[..., ReturnType], property]'
)
@dataclass # can't use slots here since we set attributes on `__post_init__`
class PydanticDescriptorProxy(Generic[ReturnType]):
"""Wrap a classmethod, staticmethod, property or unbound function
and act as a descriptor that allows us to detect decorated items
from the class' attributes.
This class' __get__ returns the wrapped item's __get__ result,
which makes it transparent for classmethods and staticmethods.
Attributes:
wrapped: The decorator that has to be wrapped.
decorator_info: The decorator info.
shim: A wrapper function to wrap V1 style function.
"""
wrapped: DecoratedType[ReturnType]
decorator_info: DecoratorInfo
shim: Callable[[Callable[..., Any]], Callable[..., Any]] | None = None
def __post_init__(self):
for attr in 'setter', 'deleter':
if hasattr(self.wrapped, attr):
f = partial(self._call_wrapped_attr, name=attr)
setattr(self, attr, f)
def _call_wrapped_attr(self, func: Callable[[Any], None], *, name: str) -> PydanticDescriptorProxy[ReturnType]:
self.wrapped = getattr(self.wrapped, name)(func)
return self
def __get__(self, obj: object | None, obj_type: type[object] | None = None) -> PydanticDescriptorProxy[ReturnType]:
try:
return self.wrapped.__get__(obj, obj_type)
except AttributeError:
# not a descriptor, e.g. a partial object
return self.wrapped # type: ignore[return-value]
def __set_name__(self, instance: Any, name: str) -> None:
if hasattr(self.wrapped, '__set_name__'):
self.wrapped.__set_name__(instance, name) # pyright: ignore[reportFunctionMemberAccess]
def __getattr__(self, __name: str) -> Any:
"""Forward checks for __isabstractmethod__ and such."""
return getattr(self.wrapped, __name)
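# A minimal illustrative sketch (added for clarity, not part of the original source):
# the proxy stays transparent for classmethods, so a wrapped validator can still be
# called like a normal class method. `Model` and `check` below are hypothetical.
def _demo_descriptor_proxy() -> None:
    def check(cls, v):
        return v

    info = FieldValidatorDecoratorInfo(fields=('x',), mode='after', check_fields=None)
    proxy = PydanticDescriptorProxy(classmethod(check), info)

    class Model:
        pass

    Model.check = proxy  # type: ignore[attr-defined]
    # Attribute access is forwarded to the wrapped classmethod's __get__.
    assert Model.check('value') == 'value'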
DecoratorInfoType = TypeVar('DecoratorInfoType', bound=DecoratorInfo)
@dataclass(**slots_true)
class Decorator(Generic[DecoratorInfoType]):
"""A generic container class to join together the decorator metadata
(metadata from decorator itself, which we have when the
decorator is called but not when we are building the core-schema)
and the bound function (which we have after the class itself is created).
Attributes:
cls_ref: The class ref.
cls_var_name: The decorated function name.
func: The decorated function.
shim: A wrapper function to wrap V1 style function.
info: The decorator info.
"""
cls_ref: str
cls_var_name: str
func: Callable[..., Any]
shim: Callable[[Any], Any] | None
info: DecoratorInfoType
@staticmethod
def build(
cls_: Any,
*,
cls_var_name: str,
shim: Callable[[Any], Any] | None,
info: DecoratorInfoType,
) -> Decorator[DecoratorInfoType]:
"""Build a new decorator.
Args:
cls_: The class.
cls_var_name: The decorated function name.
shim: A wrapper function to wrap V1 style function.
info: The decorator info.
Returns:
The new decorator instance.
"""
func = get_attribute_from_bases(cls_, cls_var_name)
if shim is not None:
func = shim(func)
func = unwrap_wrapped_function(func, unwrap_partial=False)
if not callable(func):
# This branch will get hit for classmethod properties
attribute = get_attribute_from_base_dicts(cls_, cls_var_name) # prevents the binding call to `__get__`
if isinstance(attribute, PydanticDescriptorProxy):
func = unwrap_wrapped_function(attribute.wrapped)
return Decorator(
cls_ref=get_type_ref(cls_),
cls_var_name=cls_var_name,
func=func,
shim=shim,
info=info,
)
def bind_to_cls(self, cls: Any) -> Decorator[DecoratorInfoType]:
"""Bind the decorator to a class.
Args:
cls: the class.
Returns:
The new decorator instance.
"""
return self.build(
cls,
cls_var_name=self.cls_var_name,
shim=self.shim,
info=self.info,
)
def get_bases(tp: type[Any]) -> tuple[type[Any], ...]:
"""Get the base classes of a class or typeddict.
Args:
tp: The type or class to get the bases.
Returns:
The base classes.
"""
if is_typeddict(tp):
return tp.__orig_bases__ # type: ignore
try:
return tp.__bases__
except AttributeError:
return ()
def mro(tp: type[Any]) -> tuple[type[Any], ...]:
"""Calculate the Method Resolution Order of bases using the C3 algorithm.
See https://www.python.org/download/releases/2.3/mro/
"""
# try to use the existing mro, for performance mainly
# but also because it helps verify the implementation below
if not is_typeddict(tp):
try:
return tp.__mro__
except AttributeError:
# GenericAlias and some other cases
pass
bases = get_bases(tp)
return (tp,) + mro_for_bases(bases)
def mro_for_bases(bases: tuple[type[Any], ...]) -> tuple[type[Any], ...]:
def merge_seqs(seqs: list[deque[type[Any]]]) -> Iterable[type[Any]]:
while True:
non_empty = [seq for seq in seqs if seq]
if not non_empty:
# Nothing left to process, we're done.
return
candidate: type[Any] | None = None
for seq in non_empty: # Find merge candidates among seq heads.
candidate = seq[0]
not_head = [s for s in non_empty if candidate in islice(s, 1, None)]
if not_head:
# Reject the candidate.
candidate = None
else:
break
if not candidate:
raise TypeError('Inconsistent hierarchy, no C3 MRO is possible')
yield candidate
for seq in non_empty:
# Remove candidate.
if seq[0] == candidate:
seq.popleft()
seqs = [deque(mro(base)) for base in bases] + [deque(bases)]
return tuple(merge_seqs(seqs))
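# A minimal illustrative sketch (added for clarity, not part of the original source):
# for ordinary classes the C3 merge above matches Python's own MRO. The diamond
# hierarchy below is hypothetical.
def _demo_c3_mro() -> None:
    class A: ...
    class B(A): ...
    class C(A): ...
    class D(B, C): ...

    # `mro` short-circuits to the real __mro__ for regular classes ...
    assert mro(D) == D.__mro__ == (D, B, C, A, object)
    # ... and the explicit merge over the bases yields the same tail.
    assert mro_for_bases((B, C)) == D.__mro__[1:]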
_sentinel = object()
def get_attribute_from_bases(tp: type[Any] | tuple[type[Any], ...], name: str) -> Any:
"""Get the attribute from the next class in the MRO that has it,
aiming to simulate calling the method on the actual class.
The reason for iterating over the mro instead of just getting
the attribute (which would do that for us) is to support TypedDict,
which lacks a real __mro__, but can have a virtual one constructed
from its bases (as done here).
Args:
tp: The type or class to search for the attribute. If a tuple, this is treated as a set of base classes.
name: The name of the attribute to retrieve.
Returns:
Any: The attribute value, if found.
Raises:
AttributeError: If the attribute is not found in any class in the MRO.
"""
if isinstance(tp, tuple):
for base in mro_for_bases(tp):
attribute = base.__dict__.get(name, _sentinel)
if attribute is not _sentinel:
attribute_get = getattr(attribute, '__get__', None)
if attribute_get is not None:
return attribute_get(None, tp)
return attribute
raise AttributeError(f'{name} not found in {tp}')
else:
try:
return getattr(tp, name)
except AttributeError:
return get_attribute_from_bases(mro(tp), name)
def get_attribute_from_base_dicts(tp: type[Any], name: str) -> Any:
"""Get an attribute out of the `__dict__` following the MRO.
This prevents the call to `__get__` on the descriptor, and allows
us to get the original function for classmethod properties.
Args:
tp: The type or class to search for the attribute.
name: The name of the attribute to retrieve.
Returns:
Any: The attribute value, if found.
Raises:
KeyError: If the attribute is not found in any class's `__dict__` in the MRO.
"""
for base in reversed(mro(tp)):
if name in base.__dict__:
return base.__dict__[name]
return tp.__dict__[name] # raise the error
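# A minimal illustrative sketch (added for clarity, not part of the original source):
# reading from __dict__ skips the descriptor protocol, so a classmethod comes back as
# the raw descriptor object instead of a bound method. `Base`/`Child` are hypothetical.
def _demo_attribute_lookup() -> None:
    class Base:
        @classmethod
        def make(cls) -> str:
            return cls.__name__

    class Child(Base):
        pass

    # Normal attribute access triggers __get__ and returns a bound method.
    assert Child.make() == 'Child'
    # The raw descriptor is recovered when walking the MRO's __dict__s.
    assert isinstance(get_attribute_from_base_dicts(Child, 'make'), classmethod)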
@dataclass(**slots_true)
class DecoratorInfos:
"""Mapping of name in the class namespace to decorator info.
Note that the name in the class namespace is the function or attribute name,
not the field name!
"""
validators: dict[str, Decorator[ValidatorDecoratorInfo]] = field(default_factory=dict)
field_validators: dict[str, Decorator[FieldValidatorDecoratorInfo]] = field(default_factory=dict)
root_validators: dict[str, Decorator[RootValidatorDecoratorInfo]] = field(default_factory=dict)
field_serializers: dict[str, Decorator[FieldSerializerDecoratorInfo]] = field(default_factory=dict)
model_serializers: dict[str, Decorator[ModelSerializerDecoratorInfo]] = field(default_factory=dict)
model_validators: dict[str, Decorator[ModelValidatorDecoratorInfo]] = field(default_factory=dict)
computed_fields: dict[str, Decorator[ComputedFieldInfo]] = field(default_factory=dict)
@staticmethod
def build(model_dc: type[Any]) -> DecoratorInfos: # noqa: C901 (ignore complexity)
"""We want to collect all DecFunc instances that exist as
attributes in the namespace of the class (a BaseModel or dataclass)
that called us.
But we want to collect these in the order of the bases.
So instead of getting them all from the leaf class (the class that called us),
we traverse the bases from root (the oldest ancestor class) to leaf
and collect all of the instances as we go, taking care to replace
any duplicate ones with the last one we see to mimic how function overriding
works with inheritance.
If we do replace any functions we put the replacement into the position
the replaced function was in; that is, we maintain the order.
"""
# reminder: dicts are ordered and replacement does not alter the order
res = DecoratorInfos()
for base in reversed(mro(model_dc)[1:]):
existing: DecoratorInfos | None = base.__dict__.get('__pydantic_decorators__')
if existing is None:
existing = DecoratorInfos.build(base)
res.validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.validators.items()})
res.field_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_validators.items()})
res.root_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.root_validators.items()})
res.field_serializers.update({k: v.bind_to_cls(model_dc) for k, v in existing.field_serializers.items()})
res.model_serializers.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_serializers.items()})
res.model_validators.update({k: v.bind_to_cls(model_dc) for k, v in existing.model_validators.items()})
res.computed_fields.update({k: v.bind_to_cls(model_dc) for k, v in existing.computed_fields.items()})
to_replace: list[tuple[str, Any]] = []
for var_name, var_value in vars(model_dc).items():
if isinstance(var_value, PydanticDescriptorProxy):
info = var_value.decorator_info
if isinstance(info, ValidatorDecoratorInfo):
res.validators[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
)
elif isinstance(info, FieldValidatorDecoratorInfo):
res.field_validators[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
)
elif isinstance(info, RootValidatorDecoratorInfo):
res.root_validators[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
)
elif isinstance(info, FieldSerializerDecoratorInfo):
# check whether a serializer function is already registered for fields
for field_serializer_decorator in res.field_serializers.values():
# check that each field has at most one serializer function.
# serializer functions for the same field in subclasses are allowed,
# and are treated as overrides
if field_serializer_decorator.cls_var_name == var_name:
continue
for f in info.fields:
if f in field_serializer_decorator.info.fields:
raise PydanticUserError(
'Multiple field serializer functions were defined '
f'for field {f!r}, this is not allowed.',
code='multiple-field-serializers',
)
res.field_serializers[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
)
elif isinstance(info, ModelValidatorDecoratorInfo):
res.model_validators[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
)
elif isinstance(info, ModelSerializerDecoratorInfo):
res.model_serializers[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=var_value.shim, info=info
)
else:
from ..fields import ComputedFieldInfo
isinstance(var_value, ComputedFieldInfo)
res.computed_fields[var_name] = Decorator.build(
model_dc, cls_var_name=var_name, shim=None, info=info
)
to_replace.append((var_name, var_value.wrapped))
if to_replace:
# If we can save `__pydantic_decorators__` on the class we'll be able to check for it above
# so then we don't need to re-process the type, which means we can discard our descriptor wrappers
# and replace them with the thing they are wrapping (see the other setattr call below)
# which allows validator class methods to also function as regular class methods
setattr(model_dc, '__pydantic_decorators__', res)
for name, value in to_replace:
setattr(model_dc, name, value)
return res
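# A minimal illustrative sketch (added for clarity, not part of the original source):
# once a model class is built, the metadata collected by `DecoratorInfos.build` is
# available on `__pydantic_decorators__`, keyed by function name. `Model` is a
# hypothetical example.
def _demo_decorator_infos() -> None:
    from pydantic import BaseModel, field_validator

    class Model(BaseModel):
        x: int

        @field_validator('x')
        @classmethod
        def double(cls, v: int) -> int:
            return v * 2

    # The descriptor proxy has been replaced by the plain classmethod, and its
    # decorator info recorded under the name 'double'.
    assert 'double' in Model.__pydantic_decorators__.field_validators
    assert Model(x=3).x == 6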
def inspect_validator(validator: Callable[..., Any], mode: FieldValidatorModes) -> bool:
"""Look at a field or model validator function and determine whether it takes an info argument.
An error is raised if the function has an invalid signature.
Args:
validator: The validator function to inspect.
mode: The proposed validator mode.
Returns:
Whether the validator takes an info argument.
"""
try:
sig = signature(validator)
except ValueError:
# builtins and some C extensions don't have signatures
# assume that they don't take an info argument and only take a single argument
# e.g. `str.strip` or `datetime.datetime`
return False
n_positional = count_positional_params(sig)
if mode == 'wrap':
if n_positional == 3:
return True
elif n_positional == 2:
return False
else:
assert mode in {'before', 'after', 'plain'}, f"invalid mode: {mode!r}, expected 'before', 'after' or 'plain'"
if n_positional == 2:
return True
elif n_positional == 1:
return False
raise PydanticUserError(
f'Unrecognized field_validator function signature for {validator} with `mode={mode}`:{sig}',
code='validator-signature',
)
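# A minimal illustrative sketch (added for clarity, not part of the original source):
# the positional-parameter count decides whether a validator receives an info argument.
# Both validators below are hypothetical 'after' validators.
def _demo_inspect_validator() -> None:
    def no_info(value: Any) -> Any:  # one positional parameter -> no info
        return value

    def with_info(value: Any, info: Any) -> Any:  # two positionals -> takes info
        return value

    assert inspect_validator(no_info, mode='after') is False
    assert inspect_validator(with_info, mode='after') is True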
def inspect_field_serializer(
serializer: Callable[..., Any], mode: Literal['plain', 'wrap'], computed_field: bool = False
) -> tuple[bool, bool]:
"""Look at a field serializer function and determine if it is a field serializer,
and whether it takes an info argument.
An error is raised if the function has an invalid signature.
Args:
serializer: The serializer function to inspect.
mode: The serializer mode, either 'plain' or 'wrap'.
computed_field: Whether the serializer is applied to a computed field; if so, an
info argument is not required.
Returns:
Tuple of (is_field_serializer, info_arg).
"""
sig = signature(serializer)
first = next(iter(sig.parameters.values()), None)
is_field_serializer = first is not None and first.name == 'self'
n_positional = count_positional_params(sig)
if is_field_serializer:
# -1 to correct for self parameter
info_arg = _serializer_info_arg(mode, n_positional - 1)
else:
info_arg = _serializer_info_arg(mode, n_positional)
if info_arg is None:
raise PydanticUserError(
f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
code='field-serializer-signature',
)
if info_arg and computed_field:
raise PydanticUserError(
'field_serializer on computed_field does not use info signature', code='field-serializer-signature'
)
else:
return is_field_serializer, info_arg
def inspect_annotated_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
"""Look at a serializer function used via `Annotated` and determine whether it takes an info argument.
An error is raised if the function has an invalid signature.
Args:
serializer: The serializer function to check.
mode: The serializer mode, either 'plain' or 'wrap'.
Returns:
info_arg
"""
sig = signature(serializer)
info_arg = _serializer_info_arg(mode, count_positional_params(sig))
if info_arg is None:
raise PydanticUserError(
f'Unrecognized field_serializer function signature for {serializer} with `mode={mode}`:{sig}',
code='field-serializer-signature',
)
else:
return info_arg
def inspect_model_serializer(serializer: Callable[..., Any], mode: Literal['plain', 'wrap']) -> bool:
"""Look at a model serializer function and determine whether it takes an info argument.
An error is raised if the function has an invalid signature.
Args:
serializer: The serializer function to check.
mode: The serializer mode, either 'plain' or 'wrap'.
Returns:
`info_arg` - whether the function expects an info argument.
"""
if isinstance(serializer, (staticmethod, classmethod)) or not is_instance_method_from_sig(serializer):
raise PydanticUserError(
'`@model_serializer` must be applied to instance methods', code='model-serializer-instance-method'
)
sig = signature(serializer)
info_arg = _serializer_info_arg(mode, count_positional_params(sig))
if info_arg is None:
raise PydanticUserError(
f'Unrecognized model_serializer function signature for {serializer} with `mode={mode}`:{sig}',
code='model-serializer-signature',
)
else:
return info_arg
def _serializer_info_arg(mode: Literal['plain', 'wrap'], n_positional: int) -> bool | None:
if mode == 'plain':
if n_positional == 1:
# (__input_value: Any) -> Any
return False
elif n_positional == 2:
# (__model: Any, __input_value: Any) -> Any
return True
else:
assert mode == 'wrap', f"invalid mode: {mode!r}, expected 'plain' or 'wrap'"
if n_positional == 2:
# (__input_value: Any, __serializer: SerializerFunctionWrapHandler) -> Any
return False
elif n_positional == 3:
# (__input_value: Any, __serializer: SerializerFunctionWrapHandler, __info: SerializationInfo) -> Any
return True
return None
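# A minimal illustrative sketch (added for clarity, not part of the original source):
# the mapping encoded by `_serializer_info_arg` for the signatures listed above.
def _demo_serializer_info_arg() -> None:
    assert _serializer_info_arg('plain', 1) is False  # (value)
    assert _serializer_info_arg('plain', 2) is True   # (value, info)
    assert _serializer_info_arg('wrap', 2) is False   # (value, handler)
    assert _serializer_info_arg('wrap', 3) is True    # (value, handler, info)
    assert _serializer_info_arg('plain', 3) is None   # unrecognized signature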
AnyDecoratorCallable: TypeAlias = (
'Union[classmethod[Any, Any, Any], staticmethod[Any, Any], partialmethod[Any], Callable[..., Any]]'
)
def is_instance_method_from_sig(function: AnyDecoratorCallable) -> bool:
"""Whether the function is an instance method.
It will consider a function as instance method if the first parameter of
function is `self`.
Args:
function: The function to check.
Returns:
`True` if the function is an instance method, `False` otherwise.
"""
sig = signature(unwrap_wrapped_function(function))
first = next(iter(sig.parameters.values()), None)
if first and first.name == 'self':
return True
return False
def ensure_classmethod_based_on_signature(function: AnyDecoratorCallable) -> Any:
"""Apply the `@classmethod` decorator on the function.
Args:
function: The function to apply the decorator on.
Return:
The `@classmethod` decorator applied function.
"""
if not isinstance(
unwrap_wrapped_function(function, unwrap_class_static_method=False), classmethod
) and _is_classmethod_from_sig(function):
return classmethod(function) # type: ignore[arg-type]
return function
def _is_classmethod_from_sig(function: AnyDecoratorCallable) -> bool:
sig = signature(unwrap_wrapped_function(function))
first = next(iter(sig.parameters.values()), None)
if first and first.name == 'cls':
return True
return False
def unwrap_wrapped_function(
func: Any,
*,
unwrap_partial: bool = True,
unwrap_class_static_method: bool = True,
) -> Any:
"""Recursively unwraps a wrapped function until the underlying function is reached.
This handles property, functools.partial, functools.partialmethod, staticmethod and classmethod.
Args:
func: The function to unwrap.
unwrap_partial: If True (default), unwrap partial and partialmethod decorators; otherwise don't.
unwrap_class_static_method: If True (default), also unwrap classmethod and staticmethod
decorators. If False, only unwrap partial and partialmethod decorators.
Returns:
The underlying function of the wrapped function.
"""
all: set[Any] = {property, cached_property}
if unwrap_partial:
all.update({partial, partialmethod})
if unwrap_class_static_method:
all.update({staticmethod, classmethod})
while isinstance(func, tuple(all)):
if unwrap_class_static_method and isinstance(func, (classmethod, staticmethod)):
func = func.__func__
elif isinstance(func, (partial, partialmethod)):
func = func.func
elif isinstance(func, property):
func = func.fget # arbitrary choice, convenient for computed fields
else:
# Make coverage happy as it can only get here in the last possible case
assert isinstance(func, cached_property)
func = func.func # type: ignore
return func
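# A minimal illustrative sketch (added for clarity, not part of the original source):
# peeling stacked wrappers off a hypothetical helper function.
def _demo_unwrap_wrapped_function() -> None:
    def plain(x: int) -> int:
        return x

    wrapped = staticmethod(partial(plain, 1))
    # Both the staticmethod and the partial layers are removed by default.
    assert unwrap_wrapped_function(wrapped) is plain
    # With unwrap_partial=False only the static/class method layer is stripped.
    assert isinstance(unwrap_wrapped_function(wrapped, unwrap_partial=False), partial)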
def get_function_return_type(
func: Any, explicit_return_type: Any, types_namespace: dict[str, Any] | None = None
) -> Any:
"""Get the function return type.
It gets the return type from the type annotation if `explicit_return_type` is `None`.
Otherwise, it returns `explicit_return_type`.
Args:
func: The function to get its return type.
explicit_return_type: The explicit return type.
types_namespace: The types namespace, defaults to `None`.
Returns:
The function return type.
"""
if explicit_return_type is PydanticUndefined:
# try to get it from the type annotation
hints = get_function_type_hints(
unwrap_wrapped_function(func), include_keys={'return'}, types_namespace=types_namespace
)
return hints.get('return', PydanticUndefined)
else:
return explicit_return_type
def count_positional_params(sig: Signature) -> int:
return sum(1 for param in sig.parameters.values() if can_be_positional(param))
def can_be_positional(param: Parameter) -> bool:
return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
def ensure_property(f: Any) -> Any:
"""Ensure that a function is a `property` or `cached_property`, or is a valid descriptor.
Args:
f: The function to check.
Returns:
The function, or a `property` or `cached_property` instance wrapping the function.
"""
if ismethoddescriptor(f) or isdatadescriptor(f):
return f
else:
return property(f)

View file

@@ -1,181 +0,0 @@
"""Logic for V1 validators, e.g. `@validator` and `@root_validator`."""
from __future__ import annotations as _annotations
from inspect import Parameter, signature
from typing import Any, Dict, Tuple, Union, cast
from pydantic_core import core_schema
from typing_extensions import Protocol
from ..errors import PydanticUserError
from ._decorators import can_be_positional
class V1OnlyValueValidator(Protocol):
"""A simple validator, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any) -> Any:
...
class V1ValidatorWithValues(Protocol):
"""A validator with `values` argument, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any, values: dict[str, Any]) -> Any:
...
class V1ValidatorWithValuesKwOnly(Protocol):
"""A validator with keyword only `values` argument, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any, *, values: dict[str, Any]) -> Any:
...
class V1ValidatorWithKwargs(Protocol):
"""A validator with `kwargs` argument, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any, **kwargs: Any) -> Any:
...
class V1ValidatorWithValuesAndKwargs(Protocol):
"""A validator with `values` and `kwargs` arguments, supported for V1 validators and V2 validators."""
def __call__(self, __value: Any, values: dict[str, Any], **kwargs: Any) -> Any:
...
V1Validator = Union[
V1ValidatorWithValues, V1ValidatorWithValuesKwOnly, V1ValidatorWithKwargs, V1ValidatorWithValuesAndKwargs
]
def can_be_keyword(param: Parameter) -> bool:
return param.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY)
def make_generic_v1_field_validator(validator: V1Validator) -> core_schema.WithInfoValidatorFunction:
"""Wrap a V1 style field validator for V2 compatibility.
Args:
validator: The V1 style field validator.
Returns:
A wrapped V2 style field validator.
Raises:
PydanticUserError: If the signature is not supported or the parameters are
not available in Pydantic V2.
"""
sig = signature(validator)
needs_values_kw = False
for param_num, (param_name, parameter) in enumerate(sig.parameters.items()):
if can_be_keyword(parameter) and param_name in ('field', 'config'):
raise PydanticUserError(
'The `field` and `config` parameters are not available in Pydantic V2, '
'please use the `info` parameter instead.',
code='validator-field-config-info',
)
if parameter.kind is Parameter.VAR_KEYWORD:
needs_values_kw = True
elif can_be_keyword(parameter) and param_name == 'values':
needs_values_kw = True
elif can_be_positional(parameter) and param_num == 0:
# value
continue
elif parameter.default is Parameter.empty: # ignore params with defaults e.g. bound by functools.partial
raise PydanticUserError(
f'Unsupported signature for V1 style validator {validator}: {sig} is not supported.',
code='validator-v1-signature',
)
if needs_values_kw:
# (v, **kwargs), (v, values, **kwargs), (v, *, values, **kwargs) or (v, *, values)
val1 = cast(V1ValidatorWithValues, validator)
def wrapper1(value: Any, info: core_schema.ValidationInfo) -> Any:
return val1(value, values=info.data)
return wrapper1
else:
val2 = cast(V1OnlyValueValidator, validator)
def wrapper2(value: Any, _: core_schema.ValidationInfo) -> Any:
return val2(value)
return wrapper2
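# A minimal illustrative sketch (added for clarity, not part of the original source):
# adapting a hypothetical V1-style validator that expects `values` to the V2
# `(value, info)` calling convention. `FakeInfo` stands in for ValidationInfo.
def _demo_make_generic_v1_field_validator() -> None:
    def v1_validator(value: Any, values: dict[str, Any]) -> Any:
        return value.upper() if 'shout' in values else value

    wrapped = make_generic_v1_field_validator(v1_validator)

    class FakeInfo:
        data = {'shout': True}

    # The wrapper forwards `info.data` as the V1 `values` mapping.
    assert wrapped('hi', FakeInfo()) == 'HI'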
RootValidatorValues = Dict[str, Any]
# technically tuple[model_dict, model_extra, fields_set] | tuple[dataclass_dict, init_vars]
RootValidatorFieldsTuple = Tuple[Any, ...]
class V1RootValidatorFunction(Protocol):
"""A simple root validator, supported for V1 validators and V2 validators."""
def __call__(self, __values: RootValidatorValues) -> RootValidatorValues:
...
class V2CoreBeforeRootValidator(Protocol):
"""V2 validator with mode='before'."""
def __call__(self, __values: RootValidatorValues, __info: core_schema.ValidationInfo) -> RootValidatorValues:
...
class V2CoreAfterRootValidator(Protocol):
"""V2 validator with mode='after'."""
def __call__(
self, __fields_tuple: RootValidatorFieldsTuple, __info: core_schema.ValidationInfo
) -> RootValidatorFieldsTuple:
...
def make_v1_generic_root_validator(
validator: V1RootValidatorFunction, pre: bool
) -> V2CoreBeforeRootValidator | V2CoreAfterRootValidator:
"""Wrap a V1 style root validator for V2 compatibility.
Args:
validator: The V1 style field validator.
pre: Whether the validator is a pre validator.
Returns:
A wrapped V2 style validator.
"""
if pre is True:
# mode='before' for pydantic-core
def _wrapper1(values: RootValidatorValues, _: core_schema.ValidationInfo) -> RootValidatorValues:
return validator(values)
return _wrapper1
# mode='after' for pydantic-core
def _wrapper2(fields_tuple: RootValidatorFieldsTuple, _: core_schema.ValidationInfo) -> RootValidatorFieldsTuple:
if len(fields_tuple) == 2:
# dataclass, this is easy
values, init_vars = fields_tuple
values = validator(values)
return values, init_vars
else:
# ugly hack: to match v1 behaviour, we merge values and model_extra, then split them up based on fields
# afterwards
model_dict, model_extra, fields_set = fields_tuple
if model_extra:
fields = set(model_dict.keys())
model_dict.update(model_extra)
model_dict_new = validator(model_dict)
for k in list(model_dict_new.keys()):
if k not in fields:
model_extra[k] = model_dict_new.pop(k)
else:
model_dict_new = validator(model_dict)
return model_dict_new, model_extra, fields_set
return _wrapper2
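# A minimal illustrative sketch (added for clarity, not part of the original source):
# a hypothetical V1 `pre=True` root validator adapted to pydantic-core's 'before'
# signature; the extra ValidationInfo argument is simply ignored.
def _demo_make_v1_generic_root_validator() -> None:
    def strip_keys(values: RootValidatorValues) -> RootValidatorValues:
        return {k.strip(): v for k, v in values.items()}

    before = make_v1_generic_root_validator(strip_keys, pre=True)
    assert before({' a ': 1}, None) == {'a': 1}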

View file

@@ -1,506 +0,0 @@
from __future__ import annotations as _annotations
from typing import TYPE_CHECKING, Any, Hashable, Sequence
from pydantic_core import CoreSchema, core_schema
from ..errors import PydanticUserError
from . import _core_utils
from ._core_utils import (
CoreSchemaField,
collect_definitions,
simplify_schema_references,
)
if TYPE_CHECKING:
from ..types import Discriminator
CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator'
class MissingDefinitionForUnionRef(Exception):
"""Raised when applying a discriminated union discriminator to a schema
requires a definition that is not yet defined
"""
def __init__(self, ref: str) -> None:
self.ref = ref
super().__init__(f'Missing definition for ref {self.ref!r}')
def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None:
schema.setdefault('metadata', {})
metadata = schema.get('metadata')
assert metadata is not None
metadata[CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY] = discriminator
def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
definitions: dict[str, CoreSchema] | None = None
def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema:
nonlocal definitions
s = recurse(s, inner)
if s['type'] == 'tagged-union':
return s
metadata = s.get('metadata', {})
discriminator = metadata.pop(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
if discriminator is not None:
if definitions is None:
definitions = collect_definitions(schema)
s = apply_discriminator(s, discriminator, definitions)
return s
return simplify_schema_references(_core_utils.walk_core_schema(schema, inner))
def apply_discriminator(
schema: core_schema.CoreSchema,
discriminator: str | Discriminator,
definitions: dict[str, core_schema.CoreSchema] | None = None,
) -> core_schema.CoreSchema:
"""Applies the discriminator and returns a new core schema.
Args:
schema: The input schema.
discriminator: The name of the field which will serve as the discriminator.
definitions: A mapping of schema ref to schema.
Returns:
The new core schema.
Raises:
TypeError:
- If `discriminator` is used with an invalid union variant.
- If `discriminator` is used with a `Union` type that has only one variant.
- If a `discriminator` value is mapped to multiple choices.
MissingDefinitionForUnionRef:
If the definition for a ref is missing.
PydanticUserError:
- If a model in the union doesn't have a discriminator field.
- If the discriminator field has a non-string alias.
- If the discriminator fields have different aliases.
- If the discriminator field is not of type `Literal`.
"""
from ..types import Discriminator
if isinstance(discriminator, Discriminator):
if isinstance(discriminator.discriminator, str):
discriminator = discriminator.discriminator
else:
return discriminator._convert_schema(schema)
return _ApplyInferredDiscriminator(discriminator, definitions or {}).apply(schema)
class _ApplyInferredDiscriminator:
"""This class is used to convert an input schema containing a union schema into one where that union is
replaced with a tagged-union, with all the associated debugging and performance benefits.
This is done by:
* Validating that the input schema is compatible with the provided discriminator
* Introspecting the schema to determine which discriminator values should map to which union choices
* Handling various edge cases such as 'definitions', 'default', 'nullable' schemas, and more
I have chosen to implement the conversion algorithm in this class, rather than a function,
to make it easier to maintain state while recursively walking the provided CoreSchema.
"""
def __init__(self, discriminator: str, definitions: dict[str, core_schema.CoreSchema]):
# `discriminator` should be the name of the field which will serve as the discriminator.
# It must be the python name of the field, and *not* the field's alias. Note that as of now,
# all members of a discriminated union _must_ use a field with the same name as the discriminator.
# This may change if/when we expose a way to manually specify the TaggedUnionSchema's choices.
self.discriminator = discriminator
# `definitions` should contain a mapping of schema ref to schema for all schemas which might
# be referenced by some choice
self.definitions = definitions
# `_discriminator_alias` will hold the value, if present, of the alias for the discriminator
#
# Note: following the v1 implementation, we currently disallow the use of different aliases
# for different choices. This is not a limitation of pydantic_core, but if we try to handle
# this, the inference logic gets complicated very quickly, and could result in confusing
# debugging challenges for users making subtle mistakes.
#
# Rather than trying to do the most powerful inference possible, I think we should eventually
# expose a way to more-manually control the way the TaggedUnionSchema is constructed through
# the use of a new type which would be placed as an Annotation on the Union type. This would
# provide the full flexibility/power of pydantic_core's TaggedUnionSchema where necessary for
# more complex cases, without over-complicating the inference logic for the common cases.
self._discriminator_alias: str | None = None
# `_should_be_nullable` indicates whether the converted union has `None` as an allowed value.
# If `None` is an acceptable value of the (possibly-wrapped) union, we ignore it while
# constructing the TaggedUnionSchema, but set the `_should_be_nullable` attribute to True.
# Once we have constructed the TaggedUnionSchema, if `_should_be_nullable` is True, we ensure
# that the final schema gets wrapped as a NullableSchema. This has the same semantics on the
# python side, but resolves the issue that `None` cannot correspond to any discriminator values.
self._should_be_nullable = False
# `_is_nullable` is used to track if the final produced schema will definitely be nullable;
# we set it to True if the input schema is wrapped in a nullable schema that we know will be preserved
# as an indication that, even if None is discovered as one of the union choices, we will not need to wrap
# the final value in another nullable schema.
#
# This is more complicated than just checking for the final outermost schema having type 'nullable' thanks
# to the possible presence of other wrapper schemas such as DefinitionsSchema, WithDefaultSchema, etc.
self._is_nullable = False
# `_choices_to_handle` serves as a stack of choices to add to the tagged union. Initially, choices
# from the union in the wrapped schema will be appended to this list, and the recursive choice-handling
# algorithm may add more choices to this stack as (nested) unions are encountered.
self._choices_to_handle: list[core_schema.CoreSchema] = []
# `_tagged_union_choices` is built during the call to `apply`, and will hold the choices to be included
# in the output TaggedUnionSchema that will replace the union from the input schema
self._tagged_union_choices: dict[Hashable, core_schema.CoreSchema] = {}
# `_used` is changed to True after applying the discriminator to prevent accidental re-use
self._used = False
def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
"""Return a new CoreSchema based on `schema` that uses a tagged-union with the discriminator provided
to this class.
Args:
schema: The input schema.
Returns:
The new core schema.
Raises:
TypeError:
- If `discriminator` is used with an invalid union variant.
- If `discriminator` is used with a `Union` type that has only one variant.
- If a `discriminator` value is mapped to multiple choices.
ValueError:
If the definition for a ref is missing.
PydanticUserError:
- If a model in the union doesn't have a discriminator field.
- If the discriminator field has a non-string alias.
- If the discriminator fields have different aliases.
- If the discriminator field is not of type `Literal`.
"""
self.definitions.update(collect_definitions(schema))
assert not self._used
schema = self._apply_to_root(schema)
if self._should_be_nullable and not self._is_nullable:
schema = core_schema.nullable_schema(schema)
self._used = True
new_defs = collect_definitions(schema)
missing_defs = self.definitions.keys() - new_defs.keys()
if missing_defs:
schema = core_schema.definitions_schema(schema, [self.definitions[ref] for ref in missing_defs])
return schema
def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
"""This method handles the outer-most stage of recursion over the input schema:
unwrapping nullable or definitions schemas, and calling the `_handle_choice`
method iteratively on the choices extracted (recursively) from the possibly-wrapped union.
"""
if schema['type'] == 'nullable':
self._is_nullable = True
wrapped = self._apply_to_root(schema['schema'])
nullable_wrapper = schema.copy()
nullable_wrapper['schema'] = wrapped
return nullable_wrapper
if schema['type'] == 'definitions':
wrapped = self._apply_to_root(schema['schema'])
definitions_wrapper = schema.copy()
definitions_wrapper['schema'] = wrapped
return definitions_wrapper
if schema['type'] != 'union':
# If the schema is not a union, it probably means it just had a single member and
# was flattened by pydantic_core.
# However, it still may make sense to apply the discriminator to this schema,
# as a way to get discriminated-union-style error messages, so we allow this here.
schema = core_schema.union_schema([schema])
# Reverse the choices list before extending the stack so that they get handled in the order they occur
choices_schemas = [v[0] if isinstance(v, tuple) else v for v in schema['choices'][::-1]]
self._choices_to_handle.extend(choices_schemas)
while self._choices_to_handle:
choice = self._choices_to_handle.pop()
self._handle_choice(choice)
if self._discriminator_alias is not None and self._discriminator_alias != self.discriminator:
# * We need to annotate `discriminator` as a union here to handle both branches of this conditional
# * We need to annotate `discriminator` as list[list[str | int]] and not list[list[str]] due to the
# invariance of list, and because list[list[str | int]] is the type of the discriminator argument
# to tagged_union_schema below
# * See the docstring of pydantic_core.core_schema.tagged_union_schema for more details about how to
# interpret the value of the discriminator argument to tagged_union_schema. (The list[list[str]] here
# is the appropriate way to provide a list of fallback attributes to check for a discriminator value.)
discriminator: str | list[list[str | int]] = [[self.discriminator], [self._discriminator_alias]]
else:
discriminator = self.discriminator
return core_schema.tagged_union_schema(
choices=self._tagged_union_choices,
discriminator=discriminator,
custom_error_type=schema.get('custom_error_type'),
custom_error_message=schema.get('custom_error_message'),
custom_error_context=schema.get('custom_error_context'),
strict=False,
from_attributes=True,
ref=schema.get('ref'),
metadata=schema.get('metadata'),
serialization=schema.get('serialization'),
)
def _handle_choice(self, choice: core_schema.CoreSchema) -> None:
"""This method handles the "middle" stage of recursion over the input schema.
Specifically, it is responsible for handling each choice of the outermost union
(and any "coalesced" choices obtained from inner unions).
Here, "handling" entails:
* Coalescing nested unions and compatible tagged-unions
* Tracking the presence of 'none' and 'nullable' schemas occurring as choices
* Validating that each allowed discriminator value maps to a unique choice
* Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema.
"""
if choice['type'] == 'definition-ref':
if choice['schema_ref'] not in self.definitions:
raise MissingDefinitionForUnionRef(choice['schema_ref'])
if choice['type'] == 'none':
self._should_be_nullable = True
elif choice['type'] == 'definitions':
self._handle_choice(choice['schema'])
elif choice['type'] == 'nullable':
self._should_be_nullable = True
self._handle_choice(choice['schema']) # unwrap the nullable schema
elif choice['type'] == 'union':
# Reverse the choices list before extending the stack so that they get handled in the order they occur
choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]]
self._choices_to_handle.extend(choices_schemas)
elif choice['type'] not in {
'model',
'typed-dict',
'tagged-union',
'lax-or-strict',
'dataclass',
'dataclass-args',
'definition-ref',
} and not _core_utils.is_function_with_inner_schema(choice):
# We should eventually handle 'definition-ref' as well
raise TypeError(
f'{choice["type"]!r} is not a valid discriminated union variant;'
' should be a `BaseModel` or `dataclass`'
)
else:
if choice['type'] == 'tagged-union' and self._is_discriminator_shared(choice):
# In this case, this inner tagged-union is compatible with the outer tagged-union,
# and its choices can be coalesced into the outer TaggedUnionSchema.
subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
# Reverse the choices list before extending the stack so that they get handled in the order they occur
self._choices_to_handle.extend(subchoices[::-1])
return
inferred_discriminator_values = self._infer_discriminator_values_for_choice(choice, source_name=None)
self._set_unique_choice_for_values(choice, inferred_discriminator_values)
def _is_discriminator_shared(self, choice: core_schema.TaggedUnionSchema) -> bool:
"""This method returns a boolean indicating whether the discriminator for the `choice`
is the same as that being used for the outermost tagged union. This is used to
determine whether this TaggedUnionSchema choice should be "coalesced" into the top level,
or whether it should be treated as a separate (nested) choice.
"""
inner_discriminator = choice['discriminator']
return inner_discriminator == self.discriminator or (
isinstance(inner_discriminator, list)
and (self.discriminator in inner_discriminator or [self.discriminator] in inner_discriminator)
)
def _infer_discriminator_values_for_choice( # noqa C901
self, choice: core_schema.CoreSchema, source_name: str | None
) -> list[str | int]:
"""This function recurses over `choice`, extracting all discriminator values that should map to this choice.
`source_name` is accepted for the purpose of producing useful error messages.
"""
if choice['type'] == 'definitions':
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
elif choice['type'] == 'function-plain':
raise TypeError(
f'{choice["type"]!r} is not a valid discriminated union variant;'
' should be a `BaseModel` or `dataclass`'
)
elif _core_utils.is_function_with_inner_schema(choice):
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name)
elif choice['type'] == 'lax-or-strict':
return sorted(
set(
self._infer_discriminator_values_for_choice(choice['lax_schema'], source_name=None)
+ self._infer_discriminator_values_for_choice(choice['strict_schema'], source_name=None)
)
)
elif choice['type'] == 'tagged-union':
values: list[str | int] = []
# Ignore str/int "choices" since these are just references to other choices
subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))]
for subchoice in subchoices:
subchoice_values = self._infer_discriminator_values_for_choice(subchoice, source_name=None)
values.extend(subchoice_values)
return values
elif choice['type'] == 'union':
values = []
for subchoice in choice['choices']:
subchoice_schema = subchoice[0] if isinstance(subchoice, tuple) else subchoice
subchoice_values = self._infer_discriminator_values_for_choice(subchoice_schema, source_name=None)
values.extend(subchoice_values)
return values
elif choice['type'] == 'nullable':
self._should_be_nullable = True
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=None)
elif choice['type'] == 'model':
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)
elif choice['type'] == 'dataclass':
return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__)
elif choice['type'] == 'model-fields':
return self._infer_discriminator_values_for_model_choice(choice, source_name=source_name)
elif choice['type'] == 'dataclass-args':
return self._infer_discriminator_values_for_dataclass_choice(choice, source_name=source_name)
elif choice['type'] == 'typed-dict':
return self._infer_discriminator_values_for_typed_dict_choice(choice, source_name=source_name)
elif choice['type'] == 'definition-ref':
schema_ref = choice['schema_ref']
if schema_ref not in self.definitions:
raise MissingDefinitionForUnionRef(schema_ref)
return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name)
else:
raise TypeError(
f'{choice["type"]!r} is not a valid discriminated union variant;'
' should be a `BaseModel` or `dataclass`'
)
def _infer_discriminator_values_for_typed_dict_choice(
self, choice: core_schema.TypedDictSchema, source_name: str | None = None
) -> list[str | int]:
"""This method just extracts the _infer_discriminator_values_for_choice logic specific to TypedDictSchema
for the sake of readability.
"""
source = 'TypedDict' if source_name is None else f'TypedDict {source_name!r}'
field = choice['fields'].get(self.discriminator)
if field is None:
raise PydanticUserError(
f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
)
return self._infer_discriminator_values_for_field(field, source)
def _infer_discriminator_values_for_model_choice(
self, choice: core_schema.ModelFieldsSchema, source_name: str | None = None
) -> list[str | int]:
source = 'ModelFields' if source_name is None else f'Model {source_name!r}'
field = choice['fields'].get(self.discriminator)
if field is None:
raise PydanticUserError(
f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
)
return self._infer_discriminator_values_for_field(field, source)
def _infer_discriminator_values_for_dataclass_choice(
self, choice: core_schema.DataclassArgsSchema, source_name: str | None = None
) -> list[str | int]:
source = 'DataclassArgs' if source_name is None else f'Dataclass {source_name!r}'
for field in choice['fields']:
if field['name'] == self.discriminator:
break
else:
raise PydanticUserError(
f'{source} needs a discriminator field for key {self.discriminator!r}', code='discriminator-no-field'
)
return self._infer_discriminator_values_for_field(field, source)
def _infer_discriminator_values_for_field(self, field: CoreSchemaField, source: str) -> list[str | int]:
if field['type'] == 'computed-field':
# This should never occur as a discriminator, as it is only relevant to serialization
return []
alias = field.get('validation_alias', self.discriminator)
if not isinstance(alias, str):
raise PydanticUserError(
f'Alias {alias!r} is not supported in a discriminated union', code='discriminator-alias-type'
)
if self._discriminator_alias is None:
self._discriminator_alias = alias
elif self._discriminator_alias != alias:
raise PydanticUserError(
f'Aliases for discriminator {self.discriminator!r} must be the same '
f'(got {alias}, {self._discriminator_alias})',
code='discriminator-alias',
)
return self._infer_discriminator_values_for_inner_schema(field['schema'], source)
def _infer_discriminator_values_for_inner_schema(
self, schema: core_schema.CoreSchema, source: str
) -> list[str | int]:
"""When inferring discriminator values for a field, we typically extract the expected values from a literal
schema. This function does that, but also handles nested unions and defaults.
"""
if schema['type'] == 'literal':
return schema['expected']
elif schema['type'] == 'union':
# Generally when multiple values are allowed they should be placed in a single `Literal`, but
# we add this case to handle the situation where a field is annotated as a `Union` of `Literal`s.
# For example, this lets us handle `Union[Literal['key'], Union[Literal['Key'], Literal['KEY']]]`
values: list[Any] = []
for choice in schema['choices']:
choice_schema = choice[0] if isinstance(choice, tuple) else choice
choice_values = self._infer_discriminator_values_for_inner_schema(choice_schema, source)
values.extend(choice_values)
return values
elif schema['type'] == 'default':
# This will happen if the field has a default value; we ignore it while extracting the discriminator values
return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)
elif schema['type'] == 'function-after':
# After validators don't affect the discriminator values
return self._infer_discriminator_values_for_inner_schema(schema['schema'], source)
elif schema['type'] in {'function-before', 'function-wrap', 'function-plain'}:
validator_type = repr(schema['type'].split('-')[1])
raise PydanticUserError(
f'Cannot use a mode={validator_type} validator in the'
f' discriminator field {self.discriminator!r} of {source}',
code='discriminator-validator',
)
else:
raise PydanticUserError(
f'{source} needs field {self.discriminator!r} to be of type `Literal`',
code='discriminator-needs-literal',
)
def _set_unique_choice_for_values(self, choice: core_schema.CoreSchema, values: Sequence[str | int]) -> None:
"""This method updates `self.tagged_union_choices` so that all provided (discriminator) `values` map to the
provided `choice`, validating that none of these values already map to another (different) choice.
"""
for discriminator_value in values:
if discriminator_value in self._tagged_union_choices:
# It is okay if `value` is already in tagged_union_choices as long as it maps to the same value.
# Because tagged_union_choices may map values to other values, we need to walk the choices dict
# until we get to a "real" choice, and confirm that is equal to the one assigned.
existing_choice = self._tagged_union_choices[discriminator_value]
if existing_choice != choice:
raise TypeError(
f'Value {discriminator_value!r} for discriminator '
f'{self.discriminator!r} mapped to multiple choices'
)
else:
self._tagged_union_choices[discriminator_value] = choice
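# A minimal illustrative sketch (added for clarity, not part of the original source):
# the user-facing behaviour this module implements. A `Union` with a discriminator is
# validated through the tagged-union schema built above. `Cat`, `Dog` and `Owner` are
# hypothetical models.
def _demo_discriminated_union() -> None:
    from typing import Literal, Union

    from pydantic import BaseModel, Field

    class Cat(BaseModel):
        kind: Literal['cat'] = 'cat'

    class Dog(BaseModel):
        kind: Literal['dog'] = 'dog'

    class Owner(BaseModel):
        pet: Union[Cat, Dog] = Field(discriminator='kind')

    # The 'kind' value selects the union member directly instead of trying each choice.
    assert isinstance(Owner(pet={'kind': 'dog'}).pet, Dog)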

View file

@@ -1,319 +0,0 @@
"""Private logic related to fields (the `Field()` function and `FieldInfo` class), and arguments to `Annotated`."""
from __future__ import annotations as _annotations
import dataclasses
import sys
import warnings
from copy import copy
from functools import lru_cache
from typing import TYPE_CHECKING, Any
from pydantic_core import PydanticUndefined
from pydantic.errors import PydanticUserError
from . import _typing_extra
from ._config import ConfigWrapper
from ._repr import Representation
from ._typing_extra import get_cls_type_hints_lenient, get_type_hints, is_classvar, is_finalvar
if TYPE_CHECKING:
from annotated_types import BaseMetadata
from ..fields import FieldInfo
from ..main import BaseModel
from ._dataclasses import StandardDataclass
from ._decorators import DecoratorInfos
def get_type_hints_infer_globalns(
obj: Any,
localns: dict[str, Any] | None = None,
include_extras: bool = False,
) -> dict[str, Any]:
"""Gets type hints for an object by inferring the global namespace.
It uses `typing.get_type_hints`; the only thing we do here is fetch the
global namespace from `obj.__module__` if it is not `None`.
Args:
obj: The object to get its type hints.
localns: The local namespaces.
include_extras: Whether to recursively include annotation metadata.
Returns:
The object type hints.
"""
module_name = getattr(obj, '__module__', None)
globalns: dict[str, Any] | None = None
if module_name:
try:
globalns = sys.modules[module_name].__dict__
except KeyError:
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
pass
return get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)
class PydanticMetadata(Representation):
"""Base class for annotation markers like `Strict`."""
__slots__ = ()
def pydantic_general_metadata(**metadata: Any) -> BaseMetadata:
"""Create a new `_PydanticGeneralMetadata` class with the given metadata.
Args:
**metadata: The metadata to add.
Returns:
The new `_PydanticGeneralMetadata` class.
"""
return _general_metadata_cls()(metadata) # type: ignore
@lru_cache(maxsize=None)
def _general_metadata_cls() -> type[BaseMetadata]:
"""Do it this way to avoid importing `annotated_types` at import time."""
from annotated_types import BaseMetadata
class _PydanticGeneralMetadata(PydanticMetadata, BaseMetadata):
"""Pydantic general metadata like `max_digits`."""
def __init__(self, metadata: Any):
self.__dict__ = metadata
return _PydanticGeneralMetadata # type: ignore
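# A minimal illustrative sketch (added for clarity, not part of the original source):
# the lazily created metadata class simply stores the given keyword arguments on its
# instance `__dict__`, so they are readable as plain attributes.
def _demo_pydantic_general_metadata() -> None:
    meta = pydantic_general_metadata(max_digits=6, decimal_places=2)
    assert meta.max_digits == 6
    assert meta.decimal_places == 2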
def collect_model_fields( # noqa: C901
cls: type[BaseModel],
bases: tuple[type[Any], ...],
config_wrapper: ConfigWrapper,
types_namespace: dict[str, Any] | None,
*,
typevars_map: dict[Any, Any] | None = None,
) -> tuple[dict[str, FieldInfo], set[str]]:
"""Collect the fields of a nascent pydantic model.
Also collect the names of any ClassVars present in the type hints.
The returned value is a tuple of two items: the fields dict, and the set of ClassVar names.
Args:
cls: BaseModel or dataclass.
bases: Parents of the class, generally `cls.__bases__`.
config_wrapper: The config wrapper instance.
types_namespace: Optional extra namespace to look for types in.
typevars_map: A dictionary mapping type variables to their concrete types.
Returns:
A tuple containing the fields dict and the set of class variable names.
Raises:
NameError:
- If there is a conflict between a field name and protected namespaces.
- If there is a field other than `root` in `RootModel`.
- If a field shadows an attribute in the parent model.
"""
from ..fields import FieldInfo
type_hints = get_cls_type_hints_lenient(cls, types_namespace)
# https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
# annotations is only used for finding fields in parent classes
annotations = cls.__dict__.get('__annotations__', {})
fields: dict[str, FieldInfo] = {}
class_vars: set[str] = set()
for ann_name, ann_type in type_hints.items():
if ann_name == 'model_config':
# We never want to treat `model_config` as a field
# Note: we may need to change this logic if/when we introduce a `BareModel` class with no
# protected namespaces (where `model_config` might be allowed as a field name)
continue
for protected_namespace in config_wrapper.protected_namespaces:
if ann_name.startswith(protected_namespace):
for b in bases:
if hasattr(b, ann_name):
from ..main import BaseModel
if not (issubclass(b, BaseModel) and ann_name in b.model_fields):
raise NameError(
f'Field "{ann_name}" conflicts with member {getattr(b, ann_name)}'
f' of protected namespace "{protected_namespace}".'
)
else:
valid_namespaces = tuple(
x for x in config_wrapper.protected_namespaces if not ann_name.startswith(x)
)
warnings.warn(
f'Field "{ann_name}" has conflict with protected namespace "{protected_namespace}".'
'\n\nYou may be able to resolve this warning by setting'
f" `model_config['protected_namespaces'] = {valid_namespaces}`.",
UserWarning,
)
if is_classvar(ann_type):
class_vars.add(ann_name)
continue
if _is_finalvar_with_default_val(ann_type, getattr(cls, ann_name, PydanticUndefined)):
class_vars.add(ann_name)
continue
if not is_valid_field_name(ann_name):
continue
if cls.__pydantic_root_model__ and ann_name != 'root':
raise NameError(
f"Unexpected field with name {ann_name!r}; only 'root' is allowed as a field of a `RootModel`"
)
# when building a generic model with `MyModel[int]`, the generic_origin check makes sure we don't get
# "... shadows an attribute" errors
generic_origin = getattr(cls, '__pydantic_generic_metadata__', {}).get('origin')
for base in bases:
dataclass_fields = {
field.name for field in (dataclasses.fields(base) if dataclasses.is_dataclass(base) else ())
}
if hasattr(base, ann_name):
if base is generic_origin:
# Don't error when "shadowing" of attributes in parametrized generics
continue
if ann_name in dataclass_fields:
# Don't error when inheriting stdlib dataclasses whose fields are "shadowed" by defaults being set
# on the class instance.
continue
warnings.warn(
f'Field name "{ann_name}" shadows an attribute in parent "{base.__qualname__}"; ',
UserWarning,
)
try:
default = getattr(cls, ann_name, PydanticUndefined)
if default is PydanticUndefined:
raise AttributeError
except AttributeError:
if ann_name in annotations:
field_info = FieldInfo.from_annotation(ann_type)
else:
# if field has no default value and is not in __annotations__ this means that it is
# defined in a base class and we can take it from there
model_fields_lookup: dict[str, FieldInfo] = {}
for x in cls.__bases__[::-1]:
model_fields_lookup.update(getattr(x, 'model_fields', {}))
if ann_name in model_fields_lookup:
# The field was present on one of the (possibly multiple) base classes
# copy the field to make sure typevar substitutions don't cause issues with the base classes
field_info = copy(model_fields_lookup[ann_name])
else:
# The field was not found on any base classes; this seems to be caused by fields not getting
# generated thanks to models not being fully defined while initializing recursive models.
# Nothing stops us from just creating a new FieldInfo for this type hint, so we do this.
field_info = FieldInfo.from_annotation(ann_type)
else:
field_info = FieldInfo.from_annotated_attribute(ann_type, default)
# attributes which are fields are removed from the class namespace:
# 1. To match the behaviour of annotation-only fields
# 2. To avoid false positives in the NameError check above
try:
delattr(cls, ann_name)
except AttributeError:
pass # indicates the attribute was on a parent class
# Use cls.__dict__['__pydantic_decorators__'] instead of cls.__pydantic_decorators__
# to make sure the decorators have already been built for this exact class
decorators: DecoratorInfos = cls.__dict__['__pydantic_decorators__']
if ann_name in decorators.computed_fields:
raise ValueError("you can't override a field with a computed field")
fields[ann_name] = field_info
if typevars_map:
for field in fields.values():
field.apply_typevars_map(typevars_map, types_namespace)
return fields, class_vars
def _is_finalvar_with_default_val(type_: type[Any], val: Any) -> bool:
from ..fields import FieldInfo
if not is_finalvar(type_):
return False
elif val is PydanticUndefined:
return False
elif isinstance(val, FieldInfo) and (val.default is PydanticUndefined and val.default_factory is None):
return False
else:
return True
def collect_dataclass_fields(
cls: type[StandardDataclass], types_namespace: dict[str, Any] | None, *, typevars_map: dict[Any, Any] | None = None
) -> dict[str, FieldInfo]:
"""Collect the fields of a dataclass.
Args:
cls: dataclass.
types_namespace: Optional extra namespace to look for types in.
typevars_map: A dictionary mapping type variables to their concrete types.
Returns:
The dataclass fields.
"""
from ..fields import FieldInfo
fields: dict[str, FieldInfo] = {}
dataclass_fields: dict[str, dataclasses.Field] = cls.__dataclass_fields__
cls_localns = dict(vars(cls)) # this matches get_cls_type_hints_lenient, but all tests pass with `= None` instead
source_module = sys.modules.get(cls.__module__)
if source_module is not None:
types_namespace = {**source_module.__dict__, **(types_namespace or {})}
for ann_name, dataclass_field in dataclass_fields.items():
ann_type = _typing_extra.eval_type_lenient(dataclass_field.type, types_namespace, cls_localns)
if is_classvar(ann_type):
continue
if (
not dataclass_field.init
and dataclass_field.default == dataclasses.MISSING
and dataclass_field.default_factory == dataclasses.MISSING
):
# TODO: We should probably do something with this so that validate_assignment behaves properly
# Issue: https://github.com/pydantic/pydantic/issues/5470
continue
if isinstance(dataclass_field.default, FieldInfo):
if dataclass_field.default.init_var:
if dataclass_field.default.init is False:
raise PydanticUserError(
f'Dataclass field {ann_name} has init=False and init_var=True, but these are mutually exclusive.',
code='clashing-init-and-init-var',
)
# TODO: same note as above re validate_assignment
continue
field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field.default)
else:
field_info = FieldInfo.from_annotated_attribute(ann_type, dataclass_field)
fields[ann_name] = field_info
if field_info.default is not PydanticUndefined and isinstance(getattr(cls, ann_name, field_info), FieldInfo):
# We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo
setattr(cls, ann_name, field_info.default)
if typevars_map:
for field in fields.values():
field.apply_typevars_map(typevars_map, types_namespace)
return fields
def is_valid_field_name(name: str) -> bool:
return not name.startswith('_')
def is_valid_privateattr_name(name: str) -> bool:
return name.startswith('_') and not name.startswith('__')
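
An illustrative sketch, not from the reverted file, of what the field-collection logic above produces for a small model; it assumes pydantic v2's public `model_fields` attribute:

```py
from typing import ClassVar

from pydantic import BaseModel


class User(BaseModel):
    kind: ClassVar[str] = 'user'  # collected as a class variable, not a field
    id: int
    name: str = 'Jane'


print(list(User.model_fields))
#> ['id', 'name']
print(User.model_fields['name'].default)
#> Jane
```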

View file

@@ -1,23 +0,0 @@
from __future__ import annotations as _annotations
from dataclasses import dataclass
from typing import Union
@dataclass
class PydanticRecursiveRef:
type_ref: str
__name__ = 'PydanticRecursiveRef'
__hash__ = object.__hash__
def __call__(self) -> None:
"""Defining __call__ is necessary for the `typing` module to let you use an instance of
this class as the result of resolving a standard ForwardRef.
"""
def __or__(self, other):
return Union[self, other] # type: ignore
def __ror__(self, other):
return Union[other, self] # type: ignore
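
A brief sketch, not from the reverted file, of the self-referencing models that `PydanticRecursiveRef` exists to support; it assumes pydantic v2 resolves the forward reference when the class is completed:

```py
from __future__ import annotations

from pydantic import BaseModel


class Node(BaseModel):
    value: int
    children: list[Node] = []  # refers back to the class being defined


tree = Node(value=1, children=[{'value': 2}])
print(tree.children[0].value)
#> 2
```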

File diff suppressed because it is too large

View file

@@ -1,517 +0,0 @@
from __future__ import annotations
import sys
import types
import typing
from collections import ChainMap
from contextlib import contextmanager
from contextvars import ContextVar
from types import prepare_class
from typing import TYPE_CHECKING, Any, Iterator, List, Mapping, MutableMapping, Tuple, TypeVar
from weakref import WeakValueDictionary
import typing_extensions
from ._core_utils import get_type_ref
from ._forward_ref import PydanticRecursiveRef
from ._typing_extra import TypeVarType, typing_base
from ._utils import all_identical, is_model_class
if sys.version_info >= (3, 10):
from typing import _UnionGenericAlias # type: ignore[attr-defined]
if TYPE_CHECKING:
from ..main import BaseModel
GenericTypesCacheKey = Tuple[Any, Any, Tuple[Any, ...]]
# Note: We want to remove LimitedDict, but to do this, we'd need to improve the handling of generics caching.
# Right now, to handle recursive generics, some types must remain cached for brief periods without references.
# By chaining the WeakValueDictionary with a LimitedDict, we have a way to retain caching for all types with references,
# while also retaining a limited number of types even without references. This is generally enough to build
# specific recursive generic models without losing required items out of the cache.
KT = TypeVar('KT')
VT = TypeVar('VT')
_LIMITED_DICT_SIZE = 100
if TYPE_CHECKING:
class LimitedDict(dict, MutableMapping[KT, VT]):
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE):
...
else:
class LimitedDict(dict):
"""Limit the size/length of a dict used for caching to avoid unlimited increase in memory usage.
Since the dict is ordered, and we always remove elements from the beginning, this is effectively a FIFO cache.
"""
def __init__(self, size_limit: int = _LIMITED_DICT_SIZE):
self.size_limit = size_limit
super().__init__()
def __setitem__(self, __key: Any, __value: Any) -> None:
super().__setitem__(__key, __value)
if len(self) > self.size_limit:
excess = len(self) - self.size_limit + self.size_limit // 10
to_remove = list(self.keys())[:excess]
for key in to_remove:
del self[key]
# weak dictionaries allow the dynamically created parametrized versions of generic models to get collected
# once they are no longer referenced by the caller.
if sys.version_info >= (3, 9): # Typing for weak dictionaries available at 3.9
GenericTypesCache = WeakValueDictionary[GenericTypesCacheKey, 'type[BaseModel]']
else:
GenericTypesCache = WeakValueDictionary
if TYPE_CHECKING:
class DeepChainMap(ChainMap[KT, VT]): # type: ignore
...
else:
class DeepChainMap(ChainMap):
"""Variant of ChainMap that allows direct updates to inner scopes.
Taken from https://docs.python.org/3/library/collections.html#collections.ChainMap,
with some light modifications for this use case.
"""
def clear(self) -> None:
for mapping in self.maps:
mapping.clear()
def __setitem__(self, key: KT, value: VT) -> None:
for mapping in self.maps:
mapping[key] = value
def __delitem__(self, key: KT) -> None:
hit = False
for mapping in self.maps:
if key in mapping:
del mapping[key]
hit = True
if not hit:
raise KeyError(key)
# Despite the fact that LimitedDict _seems_ no longer necessary, I'm very nervous to actually remove it
# and discover later on that we need to re-add all this infrastructure...
# _GENERIC_TYPES_CACHE = DeepChainMap(GenericTypesCache(), LimitedDict())
_GENERIC_TYPES_CACHE = GenericTypesCache()
class PydanticGenericMetadata(typing_extensions.TypedDict):
origin: type[BaseModel] | None # analogous to typing._GenericAlias.__origin__
args: tuple[Any, ...] # analogous to typing._GenericAlias.__args__
parameters: tuple[type[Any], ...] # analogous to typing.Generic.__parameters__
def create_generic_submodel(
model_name: str, origin: type[BaseModel], args: tuple[Any, ...], params: tuple[Any, ...]
) -> type[BaseModel]:
"""Dynamically create a submodel of a provided (generic) BaseModel.
This is used when producing concrete parametrizations of generic models. This function
only *creates* the new subclass; the schema/validators/serialization must be updated to
reflect a concrete parametrization elsewhere.
Args:
model_name: The name of the newly created model.
origin: The base class for the new model to inherit from.
args: A tuple of generic metadata arguments.
params: A tuple of generic metadata parameters.
Returns:
The created submodel.
"""
namespace: dict[str, Any] = {'__module__': origin.__module__}
bases = (origin,)
meta, ns, kwds = prepare_class(model_name, bases)
namespace.update(ns)
created_model = meta(
model_name,
bases,
namespace,
__pydantic_generic_metadata__={
'origin': origin,
'args': args,
'parameters': params,
},
__pydantic_reset_parent_namespace__=False,
**kwds,
)
model_module, called_globally = _get_caller_frame_info(depth=3)
if called_globally: # create global reference and therefore allow pickling
object_by_reference = None
reference_name = model_name
reference_module_globals = sys.modules[created_model.__module__].__dict__
while object_by_reference is not created_model:
object_by_reference = reference_module_globals.setdefault(reference_name, created_model)
reference_name += '_'
return created_model
def _get_caller_frame_info(depth: int = 2) -> tuple[str | None, bool]:
"""Used inside a function to check whether it was called globally.
Args:
depth: The depth to get the frame.
Returns:
A tuple containing `module_name` and `called_globally`.
Raises:
RuntimeError: If the function is not called inside a function.
"""
try:
previous_caller_frame = sys._getframe(depth)
except ValueError as e:
raise RuntimeError('This function must be used inside another function') from e
except AttributeError: # sys module does not have _getframe function, so there's nothing we can do about it
return None, False
frame_globals = previous_caller_frame.f_globals
return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals
DictValues: type[Any] = {}.values().__class__
def iter_contained_typevars(v: Any) -> Iterator[TypeVarType]:
"""Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found.
This is intended as an alternative to directly accessing the `__parameters__` attribute of a GenericAlias,
since __parameters__ of (nested) generic BaseModel subclasses won't show up in that list.
"""
if isinstance(v, TypeVar):
yield v
elif is_model_class(v):
yield from v.__pydantic_generic_metadata__['parameters']
elif isinstance(v, (DictValues, list)):
for var in v:
yield from iter_contained_typevars(var)
else:
args = get_args(v)
for arg in args:
yield from iter_contained_typevars(arg)
def get_args(v: Any) -> Any:
pydantic_generic_metadata: PydanticGenericMetadata | None = getattr(v, '__pydantic_generic_metadata__', None)
if pydantic_generic_metadata:
return pydantic_generic_metadata.get('args')
return typing_extensions.get_args(v)
def get_origin(v: Any) -> Any:
pydantic_generic_metadata: PydanticGenericMetadata | None = getattr(v, '__pydantic_generic_metadata__', None)
if pydantic_generic_metadata:
return pydantic_generic_metadata.get('origin')
return typing_extensions.get_origin(v)
def get_standard_typevars_map(cls: type[Any]) -> dict[TypeVarType, Any] | None:
"""Package a generic type's typevars and parametrization (if present) into a dictionary compatible with the
`replace_types` function. Specifically, this works with standard typing generics and typing._GenericAlias.
"""
origin = get_origin(cls)
if origin is None:
return None
if not hasattr(origin, '__parameters__'):
return None
# In this case, we know that cls is a _GenericAlias, and origin is the generic type
# So it is safe to access cls.__args__ and origin.__parameters__
args: tuple[Any, ...] = cls.__args__ # type: ignore
parameters: tuple[TypeVarType, ...] = origin.__parameters__
return dict(zip(parameters, args))
def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVarType, Any] | None:
"""Package a generic BaseModel's typevars and concrete parametrization (if present) into a dictionary compatible
with the `replace_types` function.
Since BaseModel.__class_getitem__ does not produce a typing._GenericAlias, and the BaseModel generic info is
stored in the __pydantic_generic_metadata__ attribute, we need special handling here.
"""
# TODO: This could be unified with `get_standard_typevars_map` if we stored the generic metadata
# in the __origin__, __args__, and __parameters__ attributes of the model.
generic_metadata = cls.__pydantic_generic_metadata__
origin = generic_metadata['origin']
args = generic_metadata['args']
return dict(zip(iter_contained_typevars(origin), args))
def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
"""Return type with all occurrences of `type_map` keys recursively replaced with their values.
Args:
type_: The class or generic alias.
type_map: Mapping from `TypeVar` instance to concrete types.
Returns:
A new type representing the basic structure of `type_` with all
`typevar_map` keys recursively replaced.
Example:
```py
from typing import List, Tuple, Union
from pydantic._internal._generics import replace_types
replace_types(Tuple[str, Union[List[str], float]], {str: int})
#> Tuple[int, Union[List[int], float]]
```
"""
if not type_map:
return type_
type_args = get_args(type_)
origin_type = get_origin(type_)
if origin_type is typing_extensions.Annotated:
annotated_type, *annotations = type_args
annotated = replace_types(annotated_type, type_map)
for annotation in annotations:
annotated = typing_extensions.Annotated[annotated, annotation]
return annotated
# Having type args is a good indicator that this is a typing module
# class instantiation or a generic alias of some sort.
if type_args:
resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
if all_identical(type_args, resolved_type_args):
# If all arguments are the same, there is no need to modify the
# type or create a new object at all
return type_
if (
origin_type is not None
and isinstance(type_, typing_base)
and not isinstance(origin_type, typing_base)
and getattr(type_, '_name', None) is not None
):
# In python < 3.9 generic aliases don't exist so any of these like `list`,
# `type` or `collections.abc.Callable` need to be translated.
# See: https://www.python.org/dev/peps/pep-0585
origin_type = getattr(typing, type_._name)
assert origin_type is not None
# PEP-604 syntax (Ex.: list | str) is represented with a types.UnionType object that does not have __getitem__.
# We also cannot use isinstance() since we have to compare types.
if sys.version_info >= (3, 10) and origin_type is types.UnionType:
return _UnionGenericAlias(origin_type, resolved_type_args)
# NotRequired[T] and Required[T] don't support tuple type resolved_type_args, hence the condition below
return origin_type[resolved_type_args[0] if len(resolved_type_args) == 1 else resolved_type_args]
# We handle pydantic generic models separately as they don't have the same
# semantics as "typing" classes or generic aliases
if not origin_type and is_model_class(type_):
parameters = type_.__pydantic_generic_metadata__['parameters']
if not parameters:
return type_
resolved_type_args = tuple(replace_types(t, type_map) for t in parameters)
if all_identical(parameters, resolved_type_args):
return type_
return type_[resolved_type_args]
# Handle special case for typehints that can have lists as arguments.
# `typing.Callable[[int, str], int]` is an example for this.
if isinstance(type_, (List, list)):
resolved_list = list(replace_types(element, type_map) for element in type_)
if all_identical(type_, resolved_list):
return type_
return resolved_list
# If all else fails, we try to resolve the type directly and otherwise just
# return the input with no modifications.
return type_map.get(type_, type_)
def has_instance_in_type(type_: Any, isinstance_target: Any) -> bool:
"""Checks if the type, or any of its arbitrary nested args, satisfy
`isinstance(<type>, isinstance_target)`.
"""
if isinstance(type_, isinstance_target):
return True
type_args = get_args(type_)
origin_type = get_origin(type_)
if origin_type is typing_extensions.Annotated:
annotated_type, *annotations = type_args
return has_instance_in_type(annotated_type, isinstance_target)
# Having type args is a good indicator that this is a typing module
# class instantiation or a generic alias of some sort.
if any(has_instance_in_type(a, isinstance_target) for a in type_args):
return True
# Handle special case for typehints that can have lists as arguments.
# `typing.Callable[[int, str], int]` is an example for this.
if isinstance(type_, (List, list)) and not isinstance(type_, typing_extensions.ParamSpec):
if any(has_instance_in_type(element, isinstance_target) for element in type_):
return True
return False
def check_parameters_count(cls: type[BaseModel], parameters: tuple[Any, ...]) -> None:
"""Check the generic model parameters count is equal.
Args:
cls: The generic model.
parameters: A tuple of passed parameters to the generic model.
Raises:
TypeError: If the passed parameter count does not equal the generic model's parameter count.
"""
actual = len(parameters)
expected = len(cls.__pydantic_generic_metadata__['parameters'])
if actual != expected:
description = 'many' if actual > expected else 'few'
raise TypeError(f'Too {description} parameters for {cls}; actual {actual}, expected {expected}')
_generic_recursion_cache: ContextVar[set[str] | None] = ContextVar('_generic_recursion_cache', default=None)
@contextmanager
def generic_recursion_self_type(
origin: type[BaseModel], args: tuple[Any, ...]
) -> Iterator[PydanticRecursiveRef | None]:
"""This contextmanager should be placed around the recursive calls used to build a generic type,
and accept as arguments the generic origin type and the type arguments being passed to it.
If the same origin and arguments are observed twice, it implies that a self-reference placeholder
can be used while building the core schema, and will produce a schema_ref that will be valid in the
final parent schema.
"""
previously_seen_type_refs = _generic_recursion_cache.get()
if previously_seen_type_refs is None:
previously_seen_type_refs = set()
token = _generic_recursion_cache.set(previously_seen_type_refs)
else:
token = None
try:
type_ref = get_type_ref(origin, args_override=args)
if type_ref in previously_seen_type_refs:
self_type = PydanticRecursiveRef(type_ref=type_ref)
yield self_type
else:
previously_seen_type_refs.add(type_ref)
yield None
finally:
if token:
_generic_recursion_cache.reset(token)
def recursively_defined_type_refs() -> set[str]:
visited = _generic_recursion_cache.get()
if not visited:
return set() # not in a generic recursion, so there are no types
return visited.copy() # don't allow modifications
def get_cached_generic_type_early(parent: type[BaseModel], typevar_values: Any) -> type[BaseModel] | None:
"""The use of a two-stage cache lookup approach was necessary to have the highest performance possible for
repeated calls to `__class_getitem__` on generic types (which may happen in tighter loops during runtime),
while still ensuring that certain alternative parametrizations ultimately resolve to the same type.
As a concrete example, this approach was necessary to make Model[List[T]][int] equal to Model[List[int]].
The approach could be modified to not use two different cache keys at different points, but the
_early_cache_key is optimized to be as quick to compute as possible (for repeated-access speed), and the
_late_cache_key is optimized to be as "correct" as possible, so that two types that will ultimately be the
same after resolving the type arguments will always produce cache hits.
If we wanted to move to only using a single cache key per type, we would either need to always use the
slower/more computationally intensive logic associated with _late_cache_key, or would need to accept
that Model[List[T]][int] is a different type than Model[List[int]]. Because we rely on subclass relationships
during validation, I think it is worthwhile to ensure that types that are functionally equivalent are actually
equal.
"""
return _GENERIC_TYPES_CACHE.get(_early_cache_key(parent, typevar_values))
def get_cached_generic_type_late(
parent: type[BaseModel], typevar_values: Any, origin: type[BaseModel], args: tuple[Any, ...]
) -> type[BaseModel] | None:
"""See the docstring of `get_cached_generic_type_early` for more information about the two-stage cache lookup."""
cached = _GENERIC_TYPES_CACHE.get(_late_cache_key(origin, args, typevar_values))
if cached is not None:
set_cached_generic_type(parent, typevar_values, cached, origin, args)
return cached
def set_cached_generic_type(
parent: type[BaseModel],
typevar_values: tuple[Any, ...],
type_: type[BaseModel],
origin: type[BaseModel] | None = None,
args: tuple[Any, ...] | None = None,
) -> None:
"""See the docstring of `get_cached_generic_type_early` for more information about why items are cached with
two different keys.
"""
_GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values)] = type_
if len(typevar_values) == 1:
_GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values[0])] = type_
if origin and args:
_GENERIC_TYPES_CACHE[_late_cache_key(origin, args, typevar_values)] = type_
def _union_orderings_key(typevar_values: Any) -> Any:
"""This is intended to help differentiate between Union types with the same arguments in different order.
Thanks to caching internal to the `typing` module, it is not possible to distinguish between
List[Union[int, float]] and List[Union[float, int]] (and similarly for other "parent" origins besides List)
because `typing` considers Union[int, float] to be equal to Union[float, int].
However, you _can_ distinguish between (top-level) Union[int, float] vs. Union[float, int].
Because we parse items as the first Union type that is successful, we get slightly more consistent behavior
if we make an effort to distinguish the ordering of items in a union. It would be best if we could _always_
get the exact-correct order of items in the union, but that would require a change to the `typing` module itself.
(See https://github.com/python/cpython/issues/86483 for reference.)
"""
if isinstance(typevar_values, tuple):
args_data = []
for value in typevar_values:
args_data.append(_union_orderings_key(value))
return tuple(args_data)
elif typing_extensions.get_origin(typevar_values) is typing.Union:
return get_args(typevar_values)
else:
return ()
def _early_cache_key(cls: type[BaseModel], typevar_values: Any) -> GenericTypesCacheKey:
"""This is intended for minimal computational overhead during lookups of cached types.
Note that this is overly simplistic, and it's possible that two different cls/typevar_values
inputs would ultimately result in the same type being created in BaseModel.__class_getitem__.
To handle this, we have a fallback _late_cache_key that is checked later if the _early_cache_key
lookup fails, and should result in a cache hit _precisely_ when the inputs to __class_getitem__
would result in the same type.
"""
return cls, typevar_values, _union_orderings_key(typevar_values)
def _late_cache_key(origin: type[BaseModel], args: tuple[Any, ...], typevar_values: Any) -> GenericTypesCacheKey:
"""This is intended for use later in the process of creating a new type, when we have more information
about the exact args that will be passed. If it turns out that a different set of inputs to
__class_getitem__ resulted in the same inputs to the generic type creation process, we can still
return the cached type, and update the cache with the _early_cache_key as well.
"""
# The _union_orderings_key is placed at the start here to ensure there cannot be a collision with an
# _early_cache_key, as that function will always produce a BaseModel subclass as the first item in the key,
# whereas this function will always produce a tuple as the first item in the key.
return _union_orderings_key(typevar_values), origin, args
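
An illustrative sketch, not from the reverted file, of the parametrization and caching behaviour implemented above, assuming pydantic v2 generic models:

```py
from typing import Generic, List, TypeVar

from pydantic import BaseModel

T = TypeVar('T')


class Box(BaseModel, Generic[T]):
    content: T


IntBox = Box[int]
assert IntBox is Box[int]  # repeated parametrizations come out of the generics cache
assert Box[List[T]][int] is Box[List[int]]  # the two-stage cache key unifies these
print(IntBox(content='3').content)
#> 3
```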

View file

@@ -1,26 +0,0 @@
"""Git utilities, adopted from mypy's git utilities (https://github.com/python/mypy/blob/master/mypy/git.py)."""
from __future__ import annotations
import os
import subprocess
def is_git_repo(dir: str) -> bool:
"""Is the given directory version-controlled with git?"""
return os.path.exists(os.path.join(dir, '.git'))
def have_git() -> bool:
"""Can we run the git executable?"""
try:
subprocess.check_output(['git', '--help'])
return True
except subprocess.CalledProcessError:
return False
except OSError:
return False
def git_revision(dir: str) -> str:
"""Get the SHA-1 of the HEAD of a git repository."""
return subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'], cwd=dir).decode('utf-8').strip()
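
A usage sketch, not from the reverted file, wiring the three helpers above together; the `pydantic._internal._git` import path is an assumption about where this module lives:

```py
import os

# assumed import path for the helpers defined above
from pydantic._internal._git import git_revision, have_git, is_git_repo

repo_dir = os.getcwd()
if have_git() and is_git_repo(repo_dir):
    print(f'checkout: {git_revision(repo_dir)}')
else:
    print('not a git checkout')
```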

View file

@@ -1,10 +0,0 @@
import sys
from typing import Any, Dict
dataclass_kwargs: Dict[str, Any]
# `slots` is available on Python >= 3.10
if sys.version_info >= (3, 10):
slots_true = {'slots': True}
else:
slots_true = {}
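
A small sketch, not from the reverted file, showing how a `slots_true` shim like the one above is typically consumed so dataclass definitions stay compatible with Python < 3.10:

```py
import dataclasses
import sys
from typing import Any, Dict

# same idea as the shim above: only pass slots=True where dataclasses supports it
slots_true: Dict[str, Any] = {'slots': True} if sys.version_info >= (3, 10) else {}


@dataclasses.dataclass(**slots_true)
class Point:
    x: int
    y: int


print(Point(1, 2))
#> Point(x=1, y=2)
```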

View file

@@ -1,410 +0,0 @@
from __future__ import annotations
from collections import defaultdict
from copy import copy
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Iterable
from pydantic_core import CoreSchema, PydanticCustomError, to_jsonable_python
from pydantic_core import core_schema as cs
from ._fields import PydanticMetadata
if TYPE_CHECKING:
from ..annotated_handlers import GetJsonSchemaHandler
STRICT = {'strict'}
SEQUENCE_CONSTRAINTS = {'min_length', 'max_length'}
INEQUALITY = {'le', 'ge', 'lt', 'gt'}
NUMERIC_CONSTRAINTS = {'multiple_of', 'allow_inf_nan', *INEQUALITY}
STR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT, 'strip_whitespace', 'to_lower', 'to_upper', 'pattern'}
BYTES_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
LIST_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
TUPLE_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
SET_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
DICT_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
GENERATOR_CONSTRAINTS = {*SEQUENCE_CONSTRAINTS, *STRICT}
FLOAT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
INT_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
BOOL_CONSTRAINTS = STRICT
UUID_CONSTRAINTS = STRICT
DATE_TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
TIMEDELTA_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
TIME_CONSTRAINTS = {*NUMERIC_CONSTRAINTS, *STRICT}
LAX_OR_STRICT_CONSTRAINTS = STRICT
UNION_CONSTRAINTS = {'union_mode'}
URL_CONSTRAINTS = {
'max_length',
'allowed_schemes',
'host_required',
'default_host',
'default_port',
'default_path',
}
TEXT_SCHEMA_TYPES = ('str', 'bytes', 'url', 'multi-host-url')
SEQUENCE_SCHEMA_TYPES = ('list', 'tuple', 'set', 'frozenset', 'generator', *TEXT_SCHEMA_TYPES)
NUMERIC_SCHEMA_TYPES = ('float', 'int', 'date', 'time', 'timedelta', 'datetime')
CONSTRAINTS_TO_ALLOWED_SCHEMAS: dict[str, set[str]] = defaultdict(set)
for constraint in STR_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(TEXT_SCHEMA_TYPES)
for constraint in BYTES_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('bytes',))
for constraint in LIST_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('list',))
for constraint in TUPLE_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('tuple',))
for constraint in SET_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('set', 'frozenset'))
for constraint in DICT_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('dict',))
for constraint in GENERATOR_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('generator',))
for constraint in FLOAT_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('float',))
for constraint in INT_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('int',))
for constraint in DATE_TIME_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('date', 'time', 'datetime'))
for constraint in TIMEDELTA_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('timedelta',))
for constraint in TIME_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('time',))
for schema_type in (*TEXT_SCHEMA_TYPES, *SEQUENCE_SCHEMA_TYPES, *NUMERIC_SCHEMA_TYPES, 'typed-dict', 'model'):
CONSTRAINTS_TO_ALLOWED_SCHEMAS['strict'].add(schema_type)
for constraint in UNION_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('union',))
for constraint in URL_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('url', 'multi-host-url'))
for constraint in BOOL_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('bool',))
for constraint in UUID_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('uuid',))
for constraint in LAX_OR_STRICT_CONSTRAINTS:
CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint].update(('lax-or-strict',))
def add_js_update_schema(s: cs.CoreSchema, f: Callable[[], dict[str, Any]]) -> None:
def update_js_schema(s: cs.CoreSchema, handler: GetJsonSchemaHandler) -> dict[str, Any]:
js_schema = handler(s)
js_schema.update(f())
return js_schema
if 'metadata' in s:
metadata = s['metadata']
if 'pydantic_js_functions' in s:
metadata['pydantic_js_functions'].append(update_js_schema)
else:
metadata['pydantic_js_functions'] = [update_js_schema]
else:
s['metadata'] = {'pydantic_js_functions': [update_js_schema]}
def as_jsonable_value(v: Any) -> Any:
if type(v) not in (int, str, float, bytes, bool, type(None)):
return to_jsonable_python(v)
return v
def expand_grouped_metadata(annotations: Iterable[Any]) -> Iterable[Any]:
"""Expand the annotations.
Args:
annotations: An iterable of annotations.
Returns:
An iterable of expanded annotations.
Example:
```py
from annotated_types import Ge, Len
from pydantic._internal._known_annotated_metadata import expand_grouped_metadata
print(list(expand_grouped_metadata([Ge(4), Len(5)])))
#> [Ge(ge=4), MinLen(min_length=5)]
```
"""
import annotated_types as at
from pydantic.fields import FieldInfo # circular import
for annotation in annotations:
if isinstance(annotation, at.GroupedMetadata):
yield from annotation
elif isinstance(annotation, FieldInfo):
yield from annotation.metadata
# this is a bit problematic in that it results in duplicate metadata
# all of our "consumers" can handle it, but it is not ideal
# we probably should split up FieldInfo into:
# - annotated types metadata
# - individual metadata known only to Pydantic
annotation = copy(annotation)
annotation.metadata = []
yield annotation
else:
yield annotation
def apply_known_metadata(annotation: Any, schema: CoreSchema) -> CoreSchema | None: # noqa: C901
"""Apply `annotation` to `schema` if it is an annotation we know about (Gt, Le, etc.).
Otherwise return `None`.
This does not handle all known annotations. If / when it does, it can always
return a CoreSchema and return the unmodified schema if the annotation should be ignored.
Assumes that GroupedMetadata has already been expanded via `expand_grouped_metadata`.
Args:
annotation: The annotation.
schema: The schema.
Returns:
An updated schema with annotation if it is an annotation we know about, `None` otherwise.
Raises:
PydanticCustomError: If `Predicate` fails.
"""
import annotated_types as at
from . import _validators
schema = schema.copy()
schema_update, other_metadata = collect_known_metadata([annotation])
schema_type = schema['type']
for constraint, value in schema_update.items():
if constraint not in CONSTRAINTS_TO_ALLOWED_SCHEMAS:
raise ValueError(f'Unknown constraint {constraint}')
allowed_schemas = CONSTRAINTS_TO_ALLOWED_SCHEMAS[constraint]
if schema_type in allowed_schemas:
if constraint == 'union_mode' and schema_type == 'union':
schema['mode'] = value # type: ignore # schema is UnionSchema
else:
schema[constraint] = value
continue
if constraint == 'allow_inf_nan' and value is False:
return cs.no_info_after_validator_function(
_validators.forbid_inf_nan_check,
schema,
)
elif constraint == 'pattern':
# insert a str schema to make sure the regex engine matches
return cs.chain_schema(
[
schema,
cs.str_schema(pattern=value),
]
)
elif constraint == 'gt':
s = cs.no_info_after_validator_function(
partial(_validators.greater_than_validator, gt=value),
schema,
)
add_js_update_schema(s, lambda: {'gt': as_jsonable_value(value)})
return s
elif constraint == 'ge':
return cs.no_info_after_validator_function(
partial(_validators.greater_than_or_equal_validator, ge=value),
schema,
)
elif constraint == 'lt':
return cs.no_info_after_validator_function(
partial(_validators.less_than_validator, lt=value),
schema,
)
elif constraint == 'le':
return cs.no_info_after_validator_function(
partial(_validators.less_than_or_equal_validator, le=value),
schema,
)
elif constraint == 'multiple_of':
return cs.no_info_after_validator_function(
partial(_validators.multiple_of_validator, multiple_of=value),
schema,
)
elif constraint == 'min_length':
s = cs.no_info_after_validator_function(
partial(_validators.min_length_validator, min_length=value),
schema,
)
add_js_update_schema(s, lambda: {'minLength': (as_jsonable_value(value))})
return s
elif constraint == 'max_length':
s = cs.no_info_after_validator_function(
partial(_validators.max_length_validator, max_length=value),
schema,
)
add_js_update_schema(s, lambda: {'maxLength': (as_jsonable_value(value))})
return s
elif constraint == 'strip_whitespace':
return cs.chain_schema(
[
schema,
cs.str_schema(strip_whitespace=True),
]
)
elif constraint == 'to_lower':
return cs.chain_schema(
[
schema,
cs.str_schema(to_lower=True),
]
)
elif constraint == 'to_upper':
return cs.chain_schema(
[
schema,
cs.str_schema(to_upper=True),
]
)
elif constraint == 'min_length':
return cs.no_info_after_validator_function(
partial(_validators.min_length_validator, min_length=annotation.min_length),
schema,
)
elif constraint == 'max_length':
return cs.no_info_after_validator_function(
partial(_validators.max_length_validator, max_length=annotation.max_length),
schema,
)
else:
raise RuntimeError(f'Unable to apply constraint {constraint} to schema {schema_type}')
for annotation in other_metadata:
if isinstance(annotation, at.Gt):
return cs.no_info_after_validator_function(
partial(_validators.greater_than_validator, gt=annotation.gt),
schema,
)
elif isinstance(annotation, at.Ge):
return cs.no_info_after_validator_function(
partial(_validators.greater_than_or_equal_validator, ge=annotation.ge),
schema,
)
elif isinstance(annotation, at.Lt):
return cs.no_info_after_validator_function(
partial(_validators.less_than_validator, lt=annotation.lt),
schema,
)
elif isinstance(annotation, at.Le):
return cs.no_info_after_validator_function(
partial(_validators.less_than_or_equal_validator, le=annotation.le),
schema,
)
elif isinstance(annotation, at.MultipleOf):
return cs.no_info_after_validator_function(
partial(_validators.multiple_of_validator, multiple_of=annotation.multiple_of),
schema,
)
elif isinstance(annotation, at.MinLen):
return cs.no_info_after_validator_function(
partial(_validators.min_length_validator, min_length=annotation.min_length),
schema,
)
elif isinstance(annotation, at.MaxLen):
return cs.no_info_after_validator_function(
partial(_validators.max_length_validator, max_length=annotation.max_length),
schema,
)
elif isinstance(annotation, at.Predicate):
predicate_name = f'{annotation.func.__qualname__} ' if hasattr(annotation.func, '__qualname__') else ''
def val_func(v: Any) -> Any:
# annotation.func may also raise an exception, let it pass through
if not annotation.func(v):
raise PydanticCustomError(
'predicate_failed',
f'Predicate {predicate_name}failed', # type: ignore
)
return v
return cs.no_info_after_validator_function(val_func, schema)
# ignore any other unknown metadata
return None
return schema
def collect_known_metadata(annotations: Iterable[Any]) -> tuple[dict[str, Any], list[Any]]:
"""Split `annotations` into known metadata and unknown annotations.
Args:
annotations: An iterable of annotations.
Returns:
A tuple containing a dict of known metadata and a list of unknown annotations.
Example:
```py
from annotated_types import Gt, Len
from pydantic._internal._known_annotated_metadata import collect_known_metadata
print(collect_known_metadata([Gt(1), Len(42), ...]))
#> ({'gt': 1, 'min_length': 42}, [Ellipsis])
```
"""
import annotated_types as at
annotations = expand_grouped_metadata(annotations)
res: dict[str, Any] = {}
remaining: list[Any] = []
for annotation in annotations:
# isinstance(annotation, PydanticMetadata) also covers ._fields:_PydanticGeneralMetadata
if isinstance(annotation, PydanticMetadata):
res.update(annotation.__dict__)
# we don't use dataclasses.asdict because that recursively calls asdict on the field values
elif isinstance(annotation, at.MinLen):
res.update({'min_length': annotation.min_length})
elif isinstance(annotation, at.MaxLen):
res.update({'max_length': annotation.max_length})
elif isinstance(annotation, at.Gt):
res.update({'gt': annotation.gt})
elif isinstance(annotation, at.Ge):
res.update({'ge': annotation.ge})
elif isinstance(annotation, at.Lt):
res.update({'lt': annotation.lt})
elif isinstance(annotation, at.Le):
res.update({'le': annotation.le})
elif isinstance(annotation, at.MultipleOf):
res.update({'multiple_of': annotation.multiple_of})
elif isinstance(annotation, type) and issubclass(annotation, PydanticMetadata):
# also support PydanticMetadata classes being used without initialisation,
# e.g. `Annotated[int, Strict]` as well as `Annotated[int, Strict()]`
res.update({k: v for k, v in vars(annotation).items() if not k.startswith('_')})
else:
remaining.append(annotation)
# Nones can sneak in but pydantic-core will reject them
# it'd be nice to clean things up so we don't put in None (we probably don't _need_ to, it was just easier)
# but this is simple enough to kick that can down the road
res = {k: v for k, v in res.items() if v is not None}
return res, remaining
def check_metadata(metadata: dict[str, Any], allowed: Iterable[str], source_type: Any) -> None:
"""A small utility function to validate that the given metadata can be applied to the target.
More than saving lines of code, this gives us a consistent error message for all of our internal implementations.
Args:
metadata: A dict of metadata.
allowed: An iterable of allowed metadata.
source_type: The source type.
Raises:
TypeError: If there is metadata that cannot be applied to the source type.
"""
unknown = metadata.keys() - set(allowed)
if unknown:
raise TypeError(
f'The following constraints cannot be applied to {source_type!r}: {", ".join([f"{k!r}" for k in unknown])}'
)
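
An illustrative sketch, not from the reverted file, of how the constraint mapping above surfaces through the public API, assuming pydantic v2's `TypeAdapter` and the `annotated_types` markers:

```py
from typing import Annotated

from annotated_types import Gt, MaxLen
from pydantic import TypeAdapter, ValidationError

PositiveInt = Annotated[int, Gt(0)]
ShortName = Annotated[str, MaxLen(5)]

print(TypeAdapter(PositiveInt).validate_python(3))
#> 3
try:
    TypeAdapter(ShortName).validate_python('toolong')
except ValidationError as exc:
    print(exc.errors()[0]['type'])
    #> string_too_long
```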

View file

@@ -1,140 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Callable, Generic, TypeVar
from pydantic_core import SchemaSerializer, SchemaValidator
from typing_extensions import Literal
from ..errors import PydanticErrorCodes, PydanticUserError
if TYPE_CHECKING:
from ..dataclasses import PydanticDataclass
from ..main import BaseModel
ValSer = TypeVar('ValSer', SchemaValidator, SchemaSerializer)
class MockValSer(Generic[ValSer]):
"""Mocker for `pydantic_core.SchemaValidator` or `pydantic_core.SchemaSerializer` which optionally attempts to
rebuild the thing it's mocking when one of its methods is accessed and raises an error if that fails.
"""
__slots__ = '_error_message', '_code', '_val_or_ser', '_attempt_rebuild'
def __init__(
self,
error_message: str,
*,
code: PydanticErrorCodes,
val_or_ser: Literal['validator', 'serializer'],
attempt_rebuild: Callable[[], ValSer | None] | None = None,
) -> None:
self._error_message = error_message
self._val_or_ser = SchemaValidator if val_or_ser == 'validator' else SchemaSerializer
self._code: PydanticErrorCodes = code
self._attempt_rebuild = attempt_rebuild
def __getattr__(self, item: str) -> None:
__tracebackhide__ = True
if self._attempt_rebuild:
val_ser = self._attempt_rebuild()
if val_ser is not None:
return getattr(val_ser, item)
# raise an AttributeError if `item` doesn't exist
getattr(self._val_or_ser, item)
raise PydanticUserError(self._error_message, code=self._code)
def rebuild(self) -> ValSer | None:
if self._attempt_rebuild:
val_ser = self._attempt_rebuild()
if val_ser is not None:
return val_ser
else:
raise PydanticUserError(self._error_message, code=self._code)
return None
def set_model_mocks(cls: type[BaseModel], cls_name: str, undefined_name: str = 'all referenced types') -> None:
"""Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a model.
Args:
cls: The model class to set the mocks on
cls_name: Name of the model class, used in error messages
undefined_name: Name of the undefined thing, used in error messages
"""
undefined_type_error_message = (
f'`{cls_name}` is not fully defined; you should define {undefined_name},'
f' then call `{cls_name}.model_rebuild()`.'
)
def attempt_rebuild_validator() -> SchemaValidator | None:
if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
return cls.__pydantic_validator__
else:
return None
cls.__pydantic_validator__ = MockValSer( # type: ignore[assignment]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='validator',
attempt_rebuild=attempt_rebuild_validator,
)
def attempt_rebuild_serializer() -> SchemaSerializer | None:
if cls.model_rebuild(raise_errors=False, _parent_namespace_depth=5) is not False:
return cls.__pydantic_serializer__
else:
return None
cls.__pydantic_serializer__ = MockValSer( # type: ignore[assignment]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='serializer',
attempt_rebuild=attempt_rebuild_serializer,
)
def set_dataclass_mocks(
cls: type[PydanticDataclass], cls_name: str, undefined_name: str = 'all referenced types'
) -> None:
"""Set `__pydantic_validator__` and `__pydantic_serializer__` to `MockValSer`s on a dataclass.
Args:
cls: The model class to set the mocks on
cls_name: Name of the model class, used in error messages
undefined_name: Name of the undefined thing, used in error messages
"""
from ..dataclasses import rebuild_dataclass
undefined_type_error_message = (
f'`{cls_name}` is not fully defined; you should define {undefined_name},'
f' then call `pydantic.dataclasses.rebuild_dataclass({cls_name})`.'
)
def attempt_rebuild_validator() -> SchemaValidator | None:
if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5) is not False:
return cls.__pydantic_validator__
else:
return None
cls.__pydantic_validator__ = MockValSer( # type: ignore[assignment]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='validator',
attempt_rebuild=attempt_rebuild_validator,
)
def attempt_rebuild_serializer() -> SchemaSerializer | None:
if rebuild_dataclass(cls, raise_errors=False, _parent_namespace_depth=5) is not False:
return cls.__pydantic_serializer__
else:
return None
cls.__pydantic_serializer__ = MockValSer( # type: ignore[assignment]
undefined_type_error_message,
code='class-not-fully-defined',
val_or_ser='serializer',
attempt_rebuild=attempt_rebuild_serializer,
)
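
A sketch, not from the reverted file, of the behaviour these mocks provide: a model with an unresolved reference raises the 'class-not-fully-defined' error until a rebuild succeeds (pydantic v2 public API assumed):

```py
from pydantic import BaseModel, PydanticUserError


class Foo(BaseModel):
    bar: 'Bar'  # 'Bar' is undefined, so Foo gets MockValSer placeholders


try:
    Foo(bar={'x': 1})
except PydanticUserError as exc:
    print(exc.code)
    #> class-not-fully-defined


class Bar(BaseModel):
    x: int


Foo.model_rebuild()  # the attempt_rebuild path can now succeed
print(Foo(bar={'x': 1}).bar.x)
#> 1
```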

View file

@@ -1,637 +0,0 @@
"""Private logic for creating models."""
from __future__ import annotations as _annotations
import operator
import typing
import warnings
import weakref
from abc import ABCMeta
from functools import partial
from types import FunctionType
from typing import Any, Callable, Generic
import typing_extensions
from pydantic_core import PydanticUndefined, SchemaSerializer
from typing_extensions import dataclass_transform, deprecated
from ..errors import PydanticUndefinedAnnotation, PydanticUserError
from ..plugin._schema_validator import create_schema_validator
from ..warnings import GenericBeforeBaseModelWarning, PydanticDeprecatedSince20
from ._config import ConfigWrapper
from ._decorators import DecoratorInfos, PydanticDescriptorProxy, get_attribute_from_bases
from ._fields import collect_model_fields, is_valid_field_name, is_valid_privateattr_name
from ._generate_schema import GenerateSchema
from ._generics import PydanticGenericMetadata, get_model_typevars_map
from ._mock_val_ser import MockValSer, set_model_mocks
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._signature import generate_pydantic_signature
from ._typing_extra import get_cls_types_namespace, is_annotated, is_classvar, parent_frame_namespace
from ._utils import ClassAttribute, SafeGetItemProxy
from ._validate_call import ValidateCallWrapper
if typing.TYPE_CHECKING:
from ..fields import Field as PydanticModelField
from ..fields import FieldInfo, ModelPrivateAttr
from ..main import BaseModel
else:
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
# and https://youtrack.jetbrains.com/issue/PY-51428
DeprecationWarning = PydanticDeprecatedSince20
PydanticModelField = object()
object_setattr = object.__setattr__
class _ModelNamespaceDict(dict):
"""A dictionary subclass that intercepts attribute setting on model classes and
warns about overriding of decorators.
"""
def __setitem__(self, k: str, v: object) -> None:
existing: Any = self.get(k, None)
if existing and v is not existing and isinstance(existing, PydanticDescriptorProxy):
warnings.warn(f'`{k}` overrides an existing Pydantic `{existing.decorator_info.decorator_repr}` decorator')
return super().__setitem__(k, v)
@dataclass_transform(kw_only_default=True, field_specifiers=(PydanticModelField,))
class ModelMetaclass(ABCMeta):
def __new__(
mcs,
cls_name: str,
bases: tuple[type[Any], ...],
namespace: dict[str, Any],
__pydantic_generic_metadata__: PydanticGenericMetadata | None = None,
__pydantic_reset_parent_namespace__: bool = True,
_create_model_module: str | None = None,
**kwargs: Any,
) -> type:
"""Metaclass for creating Pydantic models.
Args:
cls_name: The name of the class to be created.
bases: The base classes of the class to be created.
namespace: The attribute dictionary of the class to be created.
__pydantic_generic_metadata__: Metadata for generic models.
__pydantic_reset_parent_namespace__: Reset parent namespace.
_create_model_module: The module of the class to be created, if created by `create_model`.
**kwargs: Catch-all for any other keyword arguments.
Returns:
The new class created by the metaclass.
"""
# Note `ModelMetaclass` refers to `BaseModel`, but is also used to *create* `BaseModel`, so we rely on the fact
# that `BaseModel` itself won't have any bases, but any subclass of it will, to determine whether the `__new__`
# call we're in the middle of is for the `BaseModel` class.
if bases:
base_field_names, class_vars, base_private_attributes = mcs._collect_bases_data(bases)
config_wrapper = ConfigWrapper.for_model(bases, namespace, kwargs)
namespace['model_config'] = config_wrapper.config_dict
private_attributes = inspect_namespace(
namespace, config_wrapper.ignored_types, class_vars, base_field_names
)
if private_attributes:
original_model_post_init = get_model_post_init(namespace, bases)
if original_model_post_init is not None:
# if there are private_attributes and a model_post_init function, we handle both
def wrapped_model_post_init(self: BaseModel, __context: Any) -> None:
"""We need to both initialize private attributes and call the user-defined model_post_init
method.
"""
init_private_attributes(self, __context)
original_model_post_init(self, __context)
namespace['model_post_init'] = wrapped_model_post_init
else:
namespace['model_post_init'] = init_private_attributes
namespace['__class_vars__'] = class_vars
namespace['__private_attributes__'] = {**base_private_attributes, **private_attributes}
cls: type[BaseModel] = super().__new__(mcs, cls_name, bases, namespace, **kwargs) # type: ignore
from ..main import BaseModel
mro = cls.__mro__
if Generic in mro and mro.index(Generic) < mro.index(BaseModel):
warnings.warn(
GenericBeforeBaseModelWarning(
'Classes should inherit from `BaseModel` before generic classes (e.g. `typing.Generic[T]`) '
'for pydantic generics to work properly.'
),
stacklevel=2,
)
cls.__pydantic_custom_init__ = not getattr(cls.__init__, '__pydantic_base_init__', False)
cls.__pydantic_post_init__ = None if cls.model_post_init is BaseModel.model_post_init else 'model_post_init'
cls.__pydantic_decorators__ = DecoratorInfos.build(cls)
# Use the getattr below to grab the __parameters__ from the `typing.Generic` parent class
if __pydantic_generic_metadata__:
cls.__pydantic_generic_metadata__ = __pydantic_generic_metadata__
else:
parent_parameters = getattr(cls, '__pydantic_generic_metadata__', {}).get('parameters', ())
parameters = getattr(cls, '__parameters__', None) or parent_parameters
if parameters and parent_parameters and not all(x in parameters for x in parent_parameters):
combined_parameters = parent_parameters + tuple(x for x in parameters if x not in parent_parameters)
parameters_str = ', '.join([str(x) for x in combined_parameters])
generic_type_label = f'typing.Generic[{parameters_str}]'
error_message = (
f'All parameters must be present on typing.Generic;'
f' you should inherit from {generic_type_label}.'
)
if Generic not in bases: # pragma: no cover
# We raise an error here not because it is desirable, but because some cases are mishandled.
# It would be nice to remove this error and still have things behave as expected, it's just
# challenging because we are using a custom `__class_getitem__` to parametrize generic models,
# and not returning a typing._GenericAlias from it.
bases_str = ', '.join([x.__name__ for x in bases] + [generic_type_label])
error_message += (
f' Note: `typing.Generic` must go last: `class {cls.__name__}({bases_str}): ...`)'
)
raise TypeError(error_message)
cls.__pydantic_generic_metadata__ = {
'origin': None,
'args': (),
'parameters': parameters,
}
cls.__pydantic_complete__ = False # Ensure this specific class gets completed
# preserve `__set_name__` protocol defined in https://peps.python.org/pep-0487
# for attributes not in `new_namespace` (e.g. private attributes)
for name, obj in private_attributes.items():
obj.__set_name__(cls, name)
if __pydantic_reset_parent_namespace__:
cls.__pydantic_parent_namespace__ = build_lenient_weakvaluedict(parent_frame_namespace())
parent_namespace = getattr(cls, '__pydantic_parent_namespace__', None)
if isinstance(parent_namespace, dict):
parent_namespace = unpack_lenient_weakvaluedict(parent_namespace)
types_namespace = get_cls_types_namespace(cls, parent_namespace)
set_model_fields(cls, bases, config_wrapper, types_namespace)
if config_wrapper.frozen and '__hash__' not in namespace:
set_default_hash_func(cls, bases)
complete_model_class(
cls,
cls_name,
config_wrapper,
raise_errors=False,
types_namespace=types_namespace,
create_model_module=_create_model_module,
)
# If this is placed before the complete_model_class call above,
# the generic computed fields return type is set to PydanticUndefined
cls.model_computed_fields = {k: v.info for k, v in cls.__pydantic_decorators__.computed_fields.items()}
# using super(cls, cls) on the next line ensures we only call the parent class's __pydantic_init_subclass__
# I believe the `type: ignore` is only necessary because mypy doesn't realize that this code branch is
# only hit for _proper_ subclasses of BaseModel
super(cls, cls).__pydantic_init_subclass__(**kwargs) # type: ignore[misc]
return cls
else:
# this is the BaseModel class itself being created, no logic required
return super().__new__(mcs, cls_name, bases, namespace, **kwargs)
if not typing.TYPE_CHECKING: # pragma: no branch
# We put `__getattr__` in a non-TYPE_CHECKING block because otherwise, mypy allows arbitrary attribute access
def __getattr__(self, item: str) -> Any:
"""This is necessary to keep attribute access working for class attribute access."""
private_attributes = self.__dict__.get('__private_attributes__')
if private_attributes and item in private_attributes:
return private_attributes[item]
if item == '__pydantic_core_schema__':
# This means the class didn't get a schema generated for it, likely because there was an undefined reference
maybe_mock_validator = getattr(self, '__pydantic_validator__', None)
if isinstance(maybe_mock_validator, MockValSer):
rebuilt_validator = maybe_mock_validator.rebuild()
if rebuilt_validator is not None:
# In this case, a validator was built, and so `__pydantic_core_schema__` should now be set
return getattr(self, '__pydantic_core_schema__')
raise AttributeError(item)
@classmethod
def __prepare__(cls, *args: Any, **kwargs: Any) -> dict[str, object]:
return _ModelNamespaceDict()
def __instancecheck__(self, instance: Any) -> bool:
"""Avoid calling ABC _abc_subclasscheck unless we're pretty sure.
See #3829 and python/cpython#92810
"""
return hasattr(instance, '__pydantic_validator__') and super().__instancecheck__(instance)
@staticmethod
def _collect_bases_data(bases: tuple[type[Any], ...]) -> tuple[set[str], set[str], dict[str, ModelPrivateAttr]]:
from ..main import BaseModel
field_names: set[str] = set()
class_vars: set[str] = set()
private_attributes: dict[str, ModelPrivateAttr] = {}
for base in bases:
if issubclass(base, BaseModel) and base is not BaseModel:
# model_fields might not be defined yet in the case of generics, so we use getattr here:
field_names.update(getattr(base, 'model_fields', {}).keys())
class_vars.update(base.__class_vars__)
private_attributes.update(base.__private_attributes__)
return field_names, class_vars, private_attributes
@property
@deprecated('The `__fields__` attribute is deprecated, use `model_fields` instead.', category=None)
def __fields__(self) -> dict[str, FieldInfo]:
warnings.warn(
'The `__fields__` attribute is deprecated, use `model_fields` instead.', PydanticDeprecatedSince20
)
return self.model_fields # type: ignore
def __dir__(self) -> list[str]:
attributes = list(super().__dir__())
if '__fields__' in attributes:
attributes.remove('__fields__')
return attributes
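# A minimal sketch of what the metaclass machinery above produces, assuming pydantic v2's
# public `create_model` API (which goes through `ModelMetaclass.__new__`); the model name
# `M` and the field `x` are made up for the illustration.
from pydantic import create_model
M = create_model('M', x=(int, 1))
assert M.__pydantic_complete__ is True  # complete_model_class() built the validator and serializer
assert list(M.model_fields) == ['x']    # set_model_fields() collected the field definitions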
def init_private_attributes(self: BaseModel, __context: Any) -> None:
"""This function is meant to behave like a BaseModel method to initialise private attributes.
It takes context as an argument since that's what pydantic-core passes when calling it.
Args:
self: The BaseModel instance.
__context: The context.
"""
if getattr(self, '__pydantic_private__', None) is None:
pydantic_private = {}
for name, private_attr in self.__private_attributes__.items():
default = private_attr.get_default()
if default is not PydanticUndefined:
pydantic_private[name] = default
object_setattr(self, '__pydantic_private__', pydantic_private)
def get_model_post_init(namespace: dict[str, Any], bases: tuple[type[Any], ...]) -> Callable[..., Any] | None:
"""Get the `model_post_init` method from the namespace or the class bases, or `None` if not defined."""
if 'model_post_init' in namespace:
return namespace['model_post_init']
from ..main import BaseModel
model_post_init = get_attribute_from_bases(bases, 'model_post_init')
if model_post_init is not BaseModel.model_post_init:
return model_post_init
def inspect_namespace(  # noqa: C901
namespace: dict[str, Any],
ignored_types: tuple[type[Any], ...],
base_class_vars: set[str],
base_class_fields: set[str],
) -> dict[str, ModelPrivateAttr]:
"""Iterate over the namespace and:
* gather private attributes
* check for items which look like fields but are not (e.g. have no annotation) and warn.
Args:
namespace: The attribute dictionary of the class to be created.
ignored_types: A tuple of types to ignore.
base_class_vars: A set of base class class variables.
base_class_fields: A set of base class fields.
Returns:
A dict containing the private attributes found in the namespace.
Raises:
TypeError: If there is a `__root__` field in the model.
NameError: If a private attribute name is invalid.
PydanticUserError:
- If a field does not have a type annotation.
- If a field on base class was overridden by a non-annotated attribute.
"""
from ..fields import FieldInfo, ModelPrivateAttr, PrivateAttr
all_ignored_types = ignored_types + default_ignored_types()
private_attributes: dict[str, ModelPrivateAttr] = {}
raw_annotations = namespace.get('__annotations__', {})
if '__root__' in raw_annotations or '__root__' in namespace:
raise TypeError("To define root models, use `pydantic.RootModel` rather than a field called '__root__'")
ignored_names: set[str] = set()
for var_name, value in list(namespace.items()):
if var_name == 'model_config':
continue
elif (
isinstance(value, type)
and value.__module__ == namespace['__module__']
and value.__qualname__.startswith(namespace['__qualname__'])
):
# `value` is a nested type defined in this namespace; don't error
continue
elif isinstance(value, all_ignored_types) or value.__class__.__module__ == 'functools':
ignored_names.add(var_name)
continue
elif isinstance(value, ModelPrivateAttr):
if var_name.startswith('__'):
raise NameError(
'Private attributes must not use dunder names;'
f' use a single underscore prefix instead of {var_name!r}.'
)
elif is_valid_field_name(var_name):
raise NameError(
'Private attributes must not use valid field names;'
f' use sunder names, e.g. {"_" + var_name!r} instead of {var_name!r}.'
)
private_attributes[var_name] = value
del namespace[var_name]
elif isinstance(value, FieldInfo) and not is_valid_field_name(var_name):
suggested_name = var_name.lstrip('_') or 'my_field' # don't suggest '' for all-underscore name
raise NameError(
f'Fields must not use names with leading underscores;'
f' e.g., use {suggested_name!r} instead of {var_name!r}.'
)
elif var_name.startswith('__'):
continue
elif is_valid_privateattr_name(var_name):
if var_name not in raw_annotations or not is_classvar(raw_annotations[var_name]):
private_attributes[var_name] = PrivateAttr(default=value)
del namespace[var_name]
elif var_name in base_class_vars:
continue
elif var_name not in raw_annotations:
if var_name in base_class_fields:
raise PydanticUserError(
f'Field {var_name!r} defined on a base class was overridden by a non-annotated attribute. '
f'All field definitions, including overrides, require a type annotation.',
code='model-field-overridden',
)
elif isinstance(value, FieldInfo):
raise PydanticUserError(
f'Field {var_name!r} requires a type annotation', code='model-field-missing-annotation'
)
else:
raise PydanticUserError(
f'A non-annotated attribute was detected: `{var_name} = {value!r}`. All model fields require a '
f'type annotation; if `{var_name}` is not meant to be a field, you may be able to resolve this '
f"error by annotating it as a `ClassVar` or updating `model_config['ignored_types']`.",
code='model-field-missing-annotation',
)
for ann_name, ann_type in raw_annotations.items():
if (
is_valid_privateattr_name(ann_name)
and ann_name not in private_attributes
and ann_name not in ignored_names
and not is_classvar(ann_type)
and ann_type not in all_ignored_types
and getattr(ann_type, '__module__', None) != 'functools'
):
if is_annotated(ann_type):
_, *metadata = typing_extensions.get_args(ann_type)
private_attr = next((v for v in metadata if isinstance(v, ModelPrivateAttr)), None)
if private_attr is not None:
private_attributes[ann_name] = private_attr
continue
private_attributes[ann_name] = PrivateAttr()
return private_attributes
def set_default_hash_func(cls: type[BaseModel], bases: tuple[type[Any], ...]) -> None:
base_hash_func = get_attribute_from_bases(bases, '__hash__')
new_hash_func = make_hash_func(cls)
if base_hash_func in {None, object.__hash__} or getattr(base_hash_func, '__code__', None) == new_hash_func.__code__:
# If `__hash__` is some default, we generate a hash function.
# It will be `None` if not overridden from BaseModel.
# It may be `object.__hash__` if there is another
# parent class earlier in the bases which doesn't override `__hash__` (e.g. `typing.Generic`).
# It may be a value set by `set_default_hash_func` if `cls` is a subclass of another frozen model.
# In the last case we still need a new hash function to account for new `model_fields`.
cls.__hash__ = new_hash_func
def make_hash_func(cls: type[BaseModel]) -> Any:
getter = operator.itemgetter(*cls.model_fields.keys()) if cls.model_fields else lambda _: 0
def hash_func(self: Any) -> int:
try:
return hash(getter(self.__dict__))
except KeyError:
# In rare cases (such as when using the deprecated copy method), the __dict__ may not contain
# all model fields, which is how we can get here.
# getter(self.__dict__) is much faster than any 'safe' method that accounts for missing keys,
# and wrapping it in a `try` doesn't slow things down much in the common case.
return hash(getter(SafeGetItemProxy(self.__dict__)))
return hash_func
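# A standalone sketch of the itemgetter-based hashing above, with a plain dict standing in
# for `self.__dict__`; the field names are hypothetical.
import operator
field_names = ('x', 'y')
getter = operator.itemgetter(*field_names) if field_names else lambda _: 0
instance_dict = {'x': 1, 'y': 2}
assert getter(instance_dict) == (1, 2)              # itemgetter returns the field values as a tuple
assert hash(getter(instance_dict)) == hash((1, 2))  # so the model hash is just the tuple hash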
def set_model_fields(
cls: type[BaseModel], bases: tuple[type[Any], ...], config_wrapper: ConfigWrapper, types_namespace: dict[str, Any]
) -> None:
"""Collect and set `cls.model_fields` and `cls.__class_vars__`.
Args:
cls: BaseModel or dataclass.
bases: Parents of the class, generally `cls.__bases__`.
config_wrapper: The config wrapper instance.
types_namespace: Optional extra namespace to look for types in.
"""
typevars_map = get_model_typevars_map(cls)
fields, class_vars = collect_model_fields(cls, bases, config_wrapper, types_namespace, typevars_map=typevars_map)
cls.model_fields = fields
cls.__class_vars__.update(class_vars)
for k in class_vars:
# Class vars should not be private attributes
# We remove them _here_ and not earlier because we rely on inspecting the class to determine its classvars,
# but private attributes are determined by inspecting the namespace _prior_ to class creation.
# In the case that a classvar with a leading-'_' is defined via a ForwardRef (e.g., when using
# `__future__.annotations`), we want to remove the private attribute which was detected _before_ we knew it
# evaluated to a classvar
value = cls.__private_attributes__.pop(k, None)
if value is not None and value.default is not PydanticUndefined:
setattr(cls, k, value.default)
def complete_model_class(
cls: type[BaseModel],
cls_name: str,
config_wrapper: ConfigWrapper,
*,
raise_errors: bool = True,
types_namespace: dict[str, Any] | None,
create_model_module: str | None = None,
) -> bool:
"""Finish building a model class.
This logic must be called after class has been created since validation functions must be bound
and `get_type_hints` requires a class object.
Args:
cls: BaseModel or dataclass.
cls_name: The model or dataclass name.
config_wrapper: The config wrapper instance.
raise_errors: Whether to raise errors.
types_namespace: Optional extra namespace to look for types in.
create_model_module: The module of the class to be created, if created by `create_model`.
Returns:
`True` if the model is successfully completed, else `False`.
Raises:
PydanticUndefinedAnnotation: If `PydanticUndefinedAnnotation` occurs in `__get_pydantic_core_schema__`
and `raise_errors=True`.
"""
typevars_map = get_model_typevars_map(cls)
gen_schema = GenerateSchema(
config_wrapper,
types_namespace,
typevars_map,
)
handler = CallbackGetCoreSchemaHandler(
partial(gen_schema.generate_schema, from_dunder_get_core_schema=False),
gen_schema,
ref_mode='unpack',
)
if config_wrapper.defer_build:
set_model_mocks(cls, cls_name)
return False
try:
schema = cls.__get_pydantic_core_schema__(cls, handler)
except PydanticUndefinedAnnotation as e:
if raise_errors:
raise
set_model_mocks(cls, cls_name, f'`{e.name}`')
return False
core_config = config_wrapper.core_config(cls)
try:
schema = gen_schema.clean_schema(schema)
except gen_schema.CollectedInvalid:
set_model_mocks(cls, cls_name)
return False
# debug(schema)
cls.__pydantic_core_schema__ = schema
cls.__pydantic_validator__ = create_schema_validator(
schema,
cls,
create_model_module or cls.__module__,
cls.__qualname__,
'create_model' if create_model_module else 'BaseModel',
core_config,
config_wrapper.plugin_settings,
)
cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config)
cls.__pydantic_complete__ = True
# set __signature__ attr only for model class, but not for its instances
cls.__signature__ = ClassAttribute(
'__signature__',
generate_pydantic_signature(init=cls.__init__, fields=cls.model_fields, config_wrapper=config_wrapper),
)
return True
class _PydanticWeakRef:
"""Wrapper for `weakref.ref` that enables `pickle` serialization.
Cloudpickle fails to serialize `weakref.ref` objects due to an arcane error related
to abstract base classes (`abc.ABC`). This class works around the issue by wrapping
`weakref.ref` instead of subclassing it.
See https://github.com/pydantic/pydantic/issues/6763 for context.
Semantics:
- If not pickled, behaves the same as a `weakref.ref`.
- If pickled along with the referenced object, the same `weakref.ref` behavior
will be maintained between them after unpickling.
- If pickled without the referenced object, after unpickling the underlying
reference will be cleared (`__call__` will always return `None`).
"""
def __init__(self, obj: Any):
if obj is None:
# The object will be `None` upon deserialization if the serialized weakref
# had lost its underlying object.
self._wr = None
else:
self._wr = weakref.ref(obj)
def __call__(self) -> Any:
if self._wr is None:
return None
else:
return self._wr()
def __reduce__(self) -> tuple[Callable, tuple[weakref.ReferenceType | None]]:
return _PydanticWeakRef, (self(),)
def build_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | None:
"""Takes an input dictionary, and produces a new value that (invertibly) replaces the values with weakrefs.
We can't just use a WeakValueDictionary because many types (including int, str, etc.) can't be stored as values
in a WeakValueDictionary.
The `unpack_lenient_weakvaluedict` function can be used to reverse this operation.
"""
if d is None:
return None
result = {}
for k, v in d.items():
try:
proxy = _PydanticWeakRef(v)
except TypeError:
proxy = v
result[k] = proxy
return result
def unpack_lenient_weakvaluedict(d: dict[str, Any] | None) -> dict[str, Any] | None:
"""Inverts the transform performed by `build_lenient_weakvaluedict`."""
if d is None:
return None
result = {}
for k, v in d.items():
if isinstance(v, _PydanticWeakRef):
v = v()
if v is not None:
result[k] = v
else:
result[k] = v
return result
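# A small round-trip sketch for the two helpers above, assuming they are in scope; `Dummy`
# is a hypothetical weak-referenceable class used only for the demonstration.
Dummy = type('Dummy', (), {})
obj = Dummy()
packed = build_lenient_weakvaluedict({'obj': obj, 'answer': 42})  # 42 can't be weak-referenced, so it is kept as-is
unpacked = unpack_lenient_weakvaluedict(packed)
assert unpacked['obj'] is obj and unpacked['answer'] == 42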
def default_ignored_types() -> tuple[type[Any], ...]:
from ..fields import ComputedFieldInfo
return (
FunctionType,
property,
classmethod,
staticmethod,
PydanticDescriptorProxy,
ComputedFieldInfo,
ValidateCallWrapper,
)

View file

@ -1,117 +0,0 @@
"""Tools to provide pretty/human-readable display of objects."""
from __future__ import annotations as _annotations
import types
import typing
from typing import Any
import typing_extensions
from . import _typing_extra
if typing.TYPE_CHECKING:
ReprArgs: typing_extensions.TypeAlias = 'typing.Iterable[tuple[str | None, Any]]'
RichReprResult: typing_extensions.TypeAlias = (
'typing.Iterable[Any | tuple[Any] | tuple[str, Any] | tuple[str, Any, Any]]'
)
class PlainRepr(str):
"""String class where repr doesn't include quotes. Useful with Representation when you want to return a string
representation of something that is valid (or pseudo-valid) python.
"""
def __repr__(self) -> str:
return str(self)
class Representation:
# Mixin to provide `__str__`, `__repr__`, and `__pretty__` and `__rich_repr__` methods.
# `__pretty__` is used by [devtools](https://python-devtools.helpmanual.io/).
# `__rich_repr__` is used by [rich](https://rich.readthedocs.io/en/stable/pretty.html).
# (this is not a docstring to avoid adding a docstring to classes which inherit from Representation)
# we don't want to use a type annotation here as it can break get_type_hints
__slots__ = tuple() # type: typing.Collection[str]
def __repr_args__(self) -> ReprArgs:
"""Returns the attributes to show in __str__, __repr__, and __pretty__ this is generally overridden.
Can either return:
* name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]`
* or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]`
"""
attrs_names = self.__slots__
if not attrs_names and hasattr(self, '__dict__'):
attrs_names = self.__dict__.keys()
attrs = ((s, getattr(self, s)) for s in attrs_names)
return [(a, v) for a, v in attrs if v is not None]
def __repr_name__(self) -> str:
"""Name of the instance's class, used in __repr__."""
return self.__class__.__name__
def __repr_str__(self, join_str: str) -> str:
return join_str.join(repr(v) if a is None else f'{a}={v!r}' for a, v in self.__repr_args__())
def __pretty__(self, fmt: typing.Callable[[Any], Any], **kwargs: Any) -> typing.Generator[Any, None, None]:
"""Used by devtools (https://python-devtools.helpmanual.io/) to pretty print objects."""
yield self.__repr_name__() + '('
yield 1
for name, value in self.__repr_args__():
if name is not None:
yield name + '='
yield fmt(value)
yield ','
yield 0
yield -1
yield ')'
def __rich_repr__(self) -> RichReprResult:
"""Used by Rich (https://rich.readthedocs.io/en/stable/pretty.html) to pretty print objects."""
for name, field_repr in self.__repr_args__():
if name is None:
yield field_repr
else:
yield name, field_repr
def __str__(self) -> str:
return self.__repr_str__(' ')
def __repr__(self) -> str:
return f'{self.__repr_name__()}({self.__repr_str__(", ")})'
def display_as_type(obj: Any) -> str:
"""Pretty representation of a type, should be as close as possible to the original type definition string.
Takes some logic from `typing._type_repr`.
"""
if isinstance(obj, types.FunctionType):
return obj.__name__
elif obj is ...:
return '...'
elif isinstance(obj, Representation):
return repr(obj)
elif isinstance(obj, typing_extensions.TypeAliasType):
return str(obj)
if not isinstance(obj, (_typing_extra.typing_base, _typing_extra.WithArgsTypes, type)):
obj = obj.__class__
if _typing_extra.origin_is_union(typing_extensions.get_origin(obj)):
args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
return f'Union[{args}]'
elif isinstance(obj, _typing_extra.WithArgsTypes):
if typing_extensions.get_origin(obj) == typing_extensions.Literal:
args = ', '.join(map(repr, typing_extensions.get_args(obj)))
else:
args = ', '.join(map(display_as_type, typing_extensions.get_args(obj)))
try:
return f'{obj.__qualname__}[{args}]'
except AttributeError:
return str(obj) # handles TypeAliasType in 3.12
elif isinstance(obj, type):
return obj.__qualname__
else:
return repr(obj).replace('typing.', '').replace('typing_extensions.', '')
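# A few illustrative calls, assuming the `display_as_type` defined above is in scope;
# the expected strings follow directly from the branches of the function.
from typing import Union
assert display_as_type(int) == 'int'                          # plain classes use __qualname__
assert display_as_type(...) == '...'                          # Ellipsis is special-cased
assert display_as_type(Union[int, str]) == 'Union[int, str]'  # unions are rebuilt recursively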

View file

@ -1,124 +0,0 @@
"""Types and utility functions used by various other internal tools."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable
from pydantic_core import core_schema
from typing_extensions import Literal
from ..annotated_handlers import GetCoreSchemaHandler, GetJsonSchemaHandler
if TYPE_CHECKING:
from ..json_schema import GenerateJsonSchema, JsonSchemaValue
from ._core_utils import CoreSchemaOrField
from ._generate_schema import GenerateSchema
GetJsonSchemaFunction = Callable[[CoreSchemaOrField, GetJsonSchemaHandler], JsonSchemaValue]
HandlerOverride = Callable[[CoreSchemaOrField], JsonSchemaValue]
class GenerateJsonSchemaHandler(GetJsonSchemaHandler):
"""JsonSchemaHandler implementation that doesn't do ref unwrapping by default.
This is used for any Annotated metadata so that we don't end up with conflicting
modifications to the definition schema.
Used internally by Pydantic; please do not rely on this implementation.
See `GetJsonSchemaHandler` for the handler API.
"""
def __init__(self, generate_json_schema: GenerateJsonSchema, handler_override: HandlerOverride | None) -> None:
self.generate_json_schema = generate_json_schema
self.handler = handler_override or generate_json_schema.generate_inner
self.mode = generate_json_schema.mode
def __call__(self, __core_schema: CoreSchemaOrField) -> JsonSchemaValue:
return self.handler(__core_schema)
def resolve_ref_schema(self, maybe_ref_json_schema: JsonSchemaValue) -> JsonSchemaValue:
"""Resolves `$ref` in the json schema.
This returns the input json schema if there is no `$ref` in json schema.
Args:
maybe_ref_json_schema: The input json schema that may contain `$ref`.
Returns:
Resolved json schema.
Raises:
LookupError: If it can't find the definition for `$ref`.
"""
if '$ref' not in maybe_ref_json_schema:
return maybe_ref_json_schema
ref = maybe_ref_json_schema['$ref']
json_schema = self.generate_json_schema.get_schema_from_definitions(ref)
if json_schema is None:
raise LookupError(
f'Could not find a ref for {ref}.'
' Maybe you tried to call resolve_ref_schema from within a recursive model?'
)
return json_schema
class CallbackGetCoreSchemaHandler(GetCoreSchemaHandler):
"""Wrapper to use an arbitrary function as a `GetCoreSchemaHandler`.
Used internally by Pydantic; please do not rely on this implementation.
See `GetCoreSchemaHandler` for the handler API.
"""
def __init__(
self,
handler: Callable[[Any], core_schema.CoreSchema],
generate_schema: GenerateSchema,
ref_mode: Literal['to-def', 'unpack'] = 'to-def',
) -> None:
self._handler = handler
self._generate_schema = generate_schema
self._ref_mode = ref_mode
def __call__(self, __source_type: Any) -> core_schema.CoreSchema:
schema = self._handler(__source_type)
ref = schema.get('ref')
if self._ref_mode == 'to-def':
if ref is not None:
self._generate_schema.defs.definitions[ref] = schema
return core_schema.definition_reference_schema(ref)
return schema
else:  # ref_mode == 'unpack'
return self.resolve_ref_schema(schema)
def _get_types_namespace(self) -> dict[str, Any] | None:
return self._generate_schema._types_namespace
def generate_schema(self, __source_type: Any) -> core_schema.CoreSchema:
return self._generate_schema.generate_schema(__source_type)
@property
def field_name(self) -> str | None:
return self._generate_schema.field_name_stack.get()
def resolve_ref_schema(self, maybe_ref_schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
"""Resolves reference in the core schema.
Args:
maybe_ref_schema: The input core schema that may contain a reference.
Returns:
Resolved core schema.
Raises:
LookupError: If it can't find the definition for reference.
"""
if maybe_ref_schema['type'] == 'definition-ref':
ref = maybe_ref_schema['schema_ref']
if ref not in self._generate_schema.defs.definitions:
raise LookupError(
f'Could not find a ref for {ref}.'
' Maybe you tried to call resolve_ref_schema from within a recursive model?'
)
return self._generate_schema.defs.definitions[ref]
elif maybe_ref_schema['type'] == 'definitions':
return self.resolve_ref_schema(maybe_ref_schema['schema'])
return maybe_ref_schema
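# A tiny sketch of the 'to-def' bookkeeping described above, using only `pydantic_core.core_schema`;
# a plain dict stands in for `self._generate_schema.defs.definitions`.
from pydantic_core import core_schema
definitions = {}
schema = core_schema.int_schema(ref='my-int')                        # a schema carrying a 'ref'
definitions[schema['ref']] = schema                                  # store the full schema under its ref ...
ref_schema = core_schema.definition_reference_schema(schema['ref'])  # ... and hand back a reference to it
assert ref_schema['schema_ref'] == 'my-int' and definitions['my-int']['type'] == 'int'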

View file

@ -1,164 +0,0 @@
from __future__ import annotations
import dataclasses
from inspect import Parameter, Signature, signature
from typing import TYPE_CHECKING, Any, Callable
from pydantic_core import PydanticUndefined
from ._config import ConfigWrapper
from ._utils import is_valid_identifier
if TYPE_CHECKING:
from ..fields import FieldInfo
def _field_name_for_signature(field_name: str, field_info: FieldInfo) -> str:
"""Extract the correct name to use for the field when generating a signature.
If the field has an alias that is a valid Python identifier, that alias is returned; otherwise the field name is returned.
Priority is given to the alias, then the validation_alias, then the field name.
Args:
field_name: The name of the field
field_info: The corresponding FieldInfo object.
Returns:
The correct name to use when generating a signature.
"""
def _alias_if_valid(x: Any) -> str | None:
"""Return the alias if it is a valid alias and identifier, else None."""
return x if isinstance(x, str) and is_valid_identifier(x) else None
return _alias_if_valid(field_info.alias) or _alias_if_valid(field_info.validation_alias) or field_name
def _process_param_defaults(param: Parameter) -> Parameter:
"""Modify the signature for a parameter in a dataclass where the default value is a FieldInfo instance.
Args:
param (Parameter): The parameter
Returns:
Parameter: The custom processed parameter
"""
from ..fields import FieldInfo
param_default = param.default
if isinstance(param_default, FieldInfo):
annotation = param.annotation
# Replace the annotation if appropriate
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main; we don't want that here
if annotation == 'Any':
annotation = Any
# Replace the field default
default = param_default.default
if default is PydanticUndefined:
if param_default.default_factory is PydanticUndefined:
default = Signature.empty
else:
# this is used by dataclasses to indicate a factory exists:
default = dataclasses._HAS_DEFAULT_FACTORY # type: ignore
return param.replace(
annotation=annotation, name=_field_name_for_signature(param.name, param_default), default=default
)
return param
def _generate_signature_parameters( # noqa: C901 (ignore complexity, could use a refactor)
init: Callable[..., None],
fields: dict[str, FieldInfo],
config_wrapper: ConfigWrapper,
) -> dict[str, Parameter]:
"""Generate a mapping of parameter names to Parameter objects for a pydantic BaseModel or dataclass."""
from itertools import islice
present_params = signature(init).parameters.values()
merged_params: dict[str, Parameter] = {}
var_kw = None
use_var_kw = False
for param in islice(present_params, 1, None): # skip self arg
# inspect does "clever" things to show annotations as strings because we have
# `from __future__ import annotations` in main; we don't want that here
if fields.get(param.name):
# exclude params with init=False
if getattr(fields[param.name], 'init', True) is False:
continue
param = param.replace(name=_field_name_for_signature(param.name, fields[param.name]))
if param.annotation == 'Any':
param = param.replace(annotation=Any)
if param.kind is param.VAR_KEYWORD:
var_kw = param
continue
merged_params[param.name] = param
if var_kw: # if custom init has no var_kw, fields which are not declared in it cannot be passed through
allow_names = config_wrapper.populate_by_name
for field_name, field in fields.items():
# when alias is a str it should be used for signature generation
param_name = _field_name_for_signature(field_name, field)
if field_name in merged_params or param_name in merged_params:
continue
if not is_valid_identifier(param_name):
if allow_names:
param_name = field_name
else:
use_var_kw = True
continue
kwargs = {} if field.is_required() else {'default': field.get_default(call_default_factory=False)}
merged_params[param_name] = Parameter(
param_name, Parameter.KEYWORD_ONLY, annotation=field.rebuild_annotation(), **kwargs
)
if config_wrapper.extra == 'allow':
use_var_kw = True
if var_kw and use_var_kw:
# Make sure the parameter for extra kwargs
# does not have the same name as a field
default_model_signature = [
('self', Parameter.POSITIONAL_ONLY),
('data', Parameter.VAR_KEYWORD),
]
if [(p.name, p.kind) for p in present_params] == default_model_signature:
# if this is the standard model signature, use extra_data as the extra args name
var_kw_name = 'extra_data'
else:
# else start from var_kw
var_kw_name = var_kw.name
# generate a name that's definitely unique
while var_kw_name in fields:
var_kw_name += '_'
merged_params[var_kw_name] = var_kw.replace(name=var_kw_name)
return merged_params
def generate_pydantic_signature(
init: Callable[..., None], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper, is_dataclass: bool = False
) -> Signature:
"""Generate signature for a pydantic BaseModel or dataclass.
Args:
init: The class init.
fields: The model fields.
config_wrapper: The config wrapper instance.
is_dataclass: Whether the model is a dataclass.
Returns:
The dataclass/BaseModel subclass signature.
"""
merged_params = _generate_signature_parameters(init, fields, config_wrapper)
if is_dataclass:
merged_params = {k: _process_param_defaults(v) for k, v in merged_params.items()}
return Signature(parameters=list(merged_params.values()), return_annotation=None)
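# A hedged usage sketch, assuming pydantic v2's public API: the generated `__signature__`
# uses the alias when it is a valid identifier, as implemented above; the model and fields
# are made up for the example.
import inspect
from pydantic import Field, create_model
M = create_model('M', x=(int, ...), y=(str, Field('hi', alias='why')))
assert list(inspect.signature(M).parameters) == ['x', 'why']  # the alias takes the place of the field name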

View file

@ -1,714 +0,0 @@
"""Logic for generating pydantic-core schemas for standard library types.
Import of this module is deferred since it contains imports of many standard library modules.
"""
from __future__ import annotations as _annotations
import collections
import collections.abc
import dataclasses
import decimal
import inspect
import os
import typing
from enum import Enum
from functools import partial
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from typing import Any, Callable, Iterable, TypeVar
import typing_extensions
from pydantic_core import (
CoreSchema,
MultiHostUrl,
PydanticCustomError,
PydanticOmit,
Url,
core_schema,
)
from typing_extensions import get_args, get_origin
from pydantic.errors import PydanticSchemaGenerationError
from pydantic.fields import FieldInfo
from pydantic.types import Strict
from ..config import ConfigDict
from ..json_schema import JsonSchemaValue, update_json_schema
from . import _known_annotated_metadata, _typing_extra, _validators
from ._core_utils import get_type_ref
from ._internal_dataclass import slots_true
from ._schema_generation_shared import GetCoreSchemaHandler, GetJsonSchemaHandler
if typing.TYPE_CHECKING:
from ._generate_schema import GenerateSchema
StdSchemaFunction = Callable[[GenerateSchema, type[Any]], core_schema.CoreSchema]
@dataclasses.dataclass(**slots_true)
class SchemaTransformer:
get_core_schema: Callable[[Any, GetCoreSchemaHandler], CoreSchema]
get_json_schema: Callable[[CoreSchema, GetJsonSchemaHandler], JsonSchemaValue]
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
return self.get_core_schema(source_type, handler)
def __get_pydantic_json_schema__(self, schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
return self.get_json_schema(schema, handler)
def get_enum_core_schema(enum_type: type[Enum], config: ConfigDict) -> CoreSchema:
cases: list[Any] = list(enum_type.__members__.values())
enum_ref = get_type_ref(enum_type)
description = None if not enum_type.__doc__ else inspect.cleandoc(enum_type.__doc__)
if description == 'An enumeration.': # This is the default value provided by enum.EnumMeta.__new__; don't use it
description = None
updates = {'title': enum_type.__name__, 'description': description}
updates = {k: v for k, v in updates.items() if v is not None}
def get_json_schema(_, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
json_schema = handler(core_schema.literal_schema([x.value for x in cases], ref=enum_ref))
original_schema = handler.resolve_ref_schema(json_schema)
update_json_schema(original_schema, updates)
return json_schema
if not cases:
# Use an isinstance check for enums with no cases.
# The most important use case for this is creating TypeVar bounds for generics that should
# be restricted to enums. This is more consistent than it might seem at first, since you can only
# subclass enum.Enum (or subclasses of enum.Enum) if all parent classes have no cases.
# We use the get_json_schema function when an Enum subclass has been declared with no cases
# so that we can still generate a valid json schema.
return core_schema.is_instance_schema(enum_type, metadata={'pydantic_js_functions': [get_json_schema]})
use_enum_values = config.get('use_enum_values', False)
if len(cases) == 1:
expected = repr(cases[0].value)
else:
expected = ', '.join([repr(case.value) for case in cases[:-1]]) + f' or {cases[-1].value!r}'
def to_enum(__input_value: Any) -> Enum:
try:
enum_field = enum_type(__input_value)
if use_enum_values:
return enum_field.value
return enum_field
except ValueError:
# The type: ignore on the next line is to ignore the requirement of LiteralString
raise PydanticCustomError('enum', f'Input should be {expected}', {'expected': expected}) # type: ignore
strict_python_schema = core_schema.is_instance_schema(enum_type)
if use_enum_values:
strict_python_schema = core_schema.chain_schema(
[strict_python_schema, core_schema.no_info_plain_validator_function(lambda x: x.value)]
)
to_enum_validator = core_schema.no_info_plain_validator_function(to_enum)
if issubclass(enum_type, int):
# this handles `IntEnum`, and also `Foobar(int, Enum)`
updates['type'] = 'integer'
lax = core_schema.chain_schema([core_schema.int_schema(), to_enum_validator])
# Disallow float from JSON due to strict mode
strict = core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.int_schema()),
python_schema=strict_python_schema,
)
elif issubclass(enum_type, str):
# this handles `StrEnum` (3.11 only), and also `Foobar(str, Enum)`
updates['type'] = 'string'
lax = core_schema.chain_schema([core_schema.str_schema(), to_enum_validator])
strict = core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.str_schema()),
python_schema=strict_python_schema,
)
elif issubclass(enum_type, float):
updates['type'] = 'numeric'
lax = core_schema.chain_schema([core_schema.float_schema(), to_enum_validator])
strict = core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(to_enum, core_schema.float_schema()),
python_schema=strict_python_schema,
)
else:
lax = to_enum_validator
strict = core_schema.json_or_python_schema(json_schema=to_enum_validator, python_schema=strict_python_schema)
return core_schema.lax_or_strict_schema(
lax_schema=lax, strict_schema=strict, ref=enum_ref, metadata={'pydantic_js_functions': [get_json_schema]}
)
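# A usage sketch of the enum handling above via public pydantic v2 API; the `Color` enum is
# made up for the example.
import enum
from pydantic import TypeAdapter
Color = enum.IntEnum('Color', {'RED': 1, 'GREEN': 2})
ta = TypeAdapter(Color)
assert ta.validate_python(2) is Color.GREEN        # lax mode: plain ints go through `to_enum`
assert ta.validate_python(Color.RED) is Color.RED  # enum members pass the is_instance check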
@dataclasses.dataclass(**slots_true)
class InnerSchemaValidator:
"""Use a fixed CoreSchema, avoiding interference from outward annotations."""
core_schema: CoreSchema
js_schema: JsonSchemaValue | None = None
js_core_schema: CoreSchema | None = None
js_schema_update: JsonSchemaValue | None = None
def __get_pydantic_json_schema__(self, _schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
if self.js_schema is not None:
return self.js_schema
js_schema = handler(self.js_core_schema or self.core_schema)
if self.js_schema_update is not None:
js_schema.update(self.js_schema_update)
return js_schema
def __get_pydantic_core_schema__(self, _source_type: Any, _handler: GetCoreSchemaHandler) -> CoreSchema:
return self.core_schema
def decimal_prepare_pydantic_annotations(
source: Any, annotations: Iterable[Any], config: ConfigDict
) -> tuple[Any, list[Any]] | None:
if source is not decimal.Decimal:
return None
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
config_allow_inf_nan = config.get('allow_inf_nan')
if config_allow_inf_nan is not None:
metadata.setdefault('allow_inf_nan', config_allow_inf_nan)
_known_annotated_metadata.check_metadata(
metadata, {*_known_annotated_metadata.FLOAT_CONSTRAINTS, 'max_digits', 'decimal_places'}, decimal.Decimal
)
return source, [InnerSchemaValidator(core_schema.decimal_schema(**metadata)), *remaining_annotations]
def datetime_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
import datetime
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
if source_type is datetime.date:
sv = InnerSchemaValidator(core_schema.date_schema(**metadata))
elif source_type is datetime.datetime:
sv = InnerSchemaValidator(core_schema.datetime_schema(**metadata))
elif source_type is datetime.time:
sv = InnerSchemaValidator(core_schema.time_schema(**metadata))
elif source_type is datetime.timedelta:
sv = InnerSchemaValidator(core_schema.timedelta_schema(**metadata))
else:
return None
# check now that we know the source type is correct
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.DATE_TIME_CONSTRAINTS, source_type)
return (source_type, [sv, *remaining_annotations])
def uuid_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
# UUIDs have no constraints: they are fixed length, and constructing a UUID instance checks the length
from uuid import UUID
if source_type is not UUID:
return None
return (source_type, [InnerSchemaValidator(core_schema.uuid_schema()), *annotations])
def path_schema_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
import pathlib
if source_type not in {
os.PathLike,
pathlib.Path,
pathlib.PurePath,
pathlib.PosixPath,
pathlib.PurePosixPath,
pathlib.PureWindowsPath,
}:
return None
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.STR_CONSTRAINTS, source_type)
construct_path = pathlib.PurePath if source_type is os.PathLike else source_type
def path_validator(input_value: str) -> os.PathLike[Any]:
try:
return construct_path(input_value)
except TypeError as e:
raise PydanticCustomError('path_type', 'Input is not a valid path') from e
constrained_str_schema = core_schema.str_schema(**metadata)
instance_schema = core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(path_validator, constrained_str_schema),
python_schema=core_schema.is_instance_schema(source_type),
)
strict: bool | None = None
for annotation in annotations:
if isinstance(annotation, Strict):
strict = annotation.strict
schema = core_schema.lax_or_strict_schema(
lax_schema=core_schema.union_schema(
[
instance_schema,
core_schema.no_info_after_validator_function(path_validator, constrained_str_schema),
],
custom_error_type='path_type',
custom_error_message='Input is not a valid path',
strict=True,
),
strict_schema=instance_schema,
serialization=core_schema.to_string_ser_schema(),
strict=strict,
)
return (
source_type,
[
InnerSchemaValidator(schema, js_core_schema=constrained_str_schema, js_schema_update={'format': 'path'}),
*remaining_annotations,
],
)
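# A usage sketch of the lax path handling above, assuming public pydantic v2 API only.
import pathlib
from pydantic import TypeAdapter
p = TypeAdapter(pathlib.Path).validate_python('/tmp/example.txt')
assert isinstance(p, pathlib.Path)  # strings are coerced through `path_validator` in lax mode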
def dequeue_validator(
input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, maxlen: None | int
) -> collections.deque[Any]:
if isinstance(input_value, collections.deque):
maxlens = [v for v in (input_value.maxlen, maxlen) if v is not None]
if maxlens:
maxlen = min(maxlens)
return collections.deque(handler(input_value), maxlen=maxlen)
else:
return collections.deque(handler(input_value), maxlen=maxlen)
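# A standalone illustration of the maxlen handling above: when both the input deque and the
# annotation carry a maxlen, the smaller one wins; `handler` is treated as the identity here.
import collections
input_value = collections.deque([1, 2, 3, 4], maxlen=4)
annotated_maxlen = 2
maxlen = min(v for v in (input_value.maxlen, annotated_maxlen) if v is not None)
assert list(collections.deque(input_value, maxlen=maxlen)) == [3, 4]  # a deque keeps the newest items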
@dataclasses.dataclass(**slots_true)
class SequenceValidator:
mapped_origin: type[Any]
item_source_type: type[Any]
min_length: int | None = None
max_length: int | None = None
strict: bool = False
def serialize_sequence_via_list(
self, v: Any, handler: core_schema.SerializerFunctionWrapHandler, info: core_schema.SerializationInfo
) -> Any:
items: list[Any] = []
for index, item in enumerate(v):
try:
v = handler(item, index)
except PydanticOmit:
pass
else:
items.append(v)
if info.mode_is_json():
return items
else:
return self.mapped_origin(items)
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
if self.item_source_type is Any:
items_schema = None
else:
items_schema = handler.generate_schema(self.item_source_type)
metadata = {'min_length': self.min_length, 'max_length': self.max_length, 'strict': self.strict}
if self.mapped_origin in (list, set, frozenset):
if self.mapped_origin is list:
constrained_schema = core_schema.list_schema(items_schema, **metadata)
elif self.mapped_origin is set:
constrained_schema = core_schema.set_schema(items_schema, **metadata)
else:
assert self.mapped_origin is frozenset # safety check in case we forget to add a case
constrained_schema = core_schema.frozenset_schema(items_schema, **metadata)
schema = constrained_schema
else:
# safety check in case we forget to add a case
assert self.mapped_origin in (collections.deque, collections.Counter)
if self.mapped_origin is collections.deque:
# if we have a MaxLen annotation, we might as well set it as the default maxlen on the deque;
# this lets us reuse existing metadata annotations to let users set the maxlen on a deque
# that e.g. comes from JSON
coerce_instance_wrap = partial(
core_schema.no_info_wrap_validator_function,
partial(dequeue_validator, maxlen=metadata.get('max_length', None)),
)
else:
coerce_instance_wrap = partial(core_schema.no_info_after_validator_function, self.mapped_origin)
constrained_schema = core_schema.list_schema(items_schema, **metadata)
check_instance = core_schema.json_or_python_schema(
json_schema=core_schema.list_schema(),
python_schema=core_schema.is_instance_schema(self.mapped_origin),
)
serialization = core_schema.wrap_serializer_function_ser_schema(
self.serialize_sequence_via_list, schema=items_schema or core_schema.any_schema(), info_arg=True
)
strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)])
if metadata.get('strict', False):
schema = strict
else:
lax = coerce_instance_wrap(constrained_schema)
schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict)
schema['serialization'] = serialization
return schema
SEQUENCE_ORIGIN_MAP: dict[Any, Any] = {
typing.Deque: collections.deque,
collections.deque: collections.deque,
list: list,
typing.List: list,
set: set,
typing.AbstractSet: set,
typing.Set: set,
frozenset: frozenset,
typing.FrozenSet: frozenset,
typing.Sequence: list,
typing.MutableSequence: list,
typing.MutableSet: set,
# this doesn't handle subclasses of these
# parametrized typing.Set creates one of these
collections.abc.MutableSet: set,
collections.abc.Set: frozenset,
}
def identity(s: CoreSchema) -> CoreSchema:
return s
def sequence_like_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
origin: Any = get_origin(source_type)
mapped_origin = SEQUENCE_ORIGIN_MAP.get(origin, None) if origin else SEQUENCE_ORIGIN_MAP.get(source_type, None)
if mapped_origin is None:
return None
args = get_args(source_type)
if not args:
args = (Any,)
elif len(args) != 1:
raise ValueError('Expected sequence to have exactly 1 generic parameter')
item_source_type = args[0]
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type)
return (source_type, [SequenceValidator(mapped_origin, item_source_type, **metadata), *remaining_annotations])
MAPPING_ORIGIN_MAP: dict[Any, Any] = {
typing.DefaultDict: collections.defaultdict,
collections.defaultdict: collections.defaultdict,
collections.OrderedDict: collections.OrderedDict,
typing_extensions.OrderedDict: collections.OrderedDict,
dict: dict,
typing.Dict: dict,
collections.Counter: collections.Counter,
typing.Counter: collections.Counter,
# this doesn't handle subclasses of these
typing.Mapping: dict,
typing.MutableMapping: dict,
# parametrized typing.{Mutable}Mapping creates one of these
collections.abc.MutableMapping: dict,
collections.abc.Mapping: dict,
}
def defaultdict_validator(
input_value: Any, handler: core_schema.ValidatorFunctionWrapHandler, default_default_factory: Callable[[], Any]
) -> collections.defaultdict[Any, Any]:
if isinstance(input_value, collections.defaultdict):
default_factory = input_value.default_factory
return collections.defaultdict(default_factory, handler(input_value))
else:
return collections.defaultdict(default_default_factory, handler(input_value))
def get_defaultdict_default_default_factory(values_source_type: Any) -> Callable[[], Any]:
def infer_default() -> Callable[[], Any]:
allowed_default_types: dict[Any, Any] = {
typing.Tuple: tuple,
tuple: tuple,
collections.abc.Sequence: tuple,
collections.abc.MutableSequence: list,
typing.List: list,
list: list,
typing.Sequence: list,
typing.Set: set,
set: set,
typing.MutableSet: set,
collections.abc.MutableSet: set,
collections.abc.Set: frozenset,
typing.MutableMapping: dict,
typing.Mapping: dict,
collections.abc.Mapping: dict,
collections.abc.MutableMapping: dict,
float: float,
int: int,
str: str,
bool: bool,
}
values_type_origin = get_origin(values_source_type) or values_source_type
instructions = 'set using `DefaultDict[..., Annotated[..., Field(default_factory=...)]]`'
if isinstance(values_type_origin, TypeVar):
def type_var_default_factory() -> None:
raise RuntimeError(
'Generic defaultdict cannot be used without a concrete value type or an'
' explicit default factory, ' + instructions
)
return type_var_default_factory
elif values_type_origin not in allowed_default_types:
# a somewhat subjective set of types that have reasonable default values
allowed_msg = ', '.join([t.__name__ for t in set(allowed_default_types.values())])
raise PydanticSchemaGenerationError(
f'Unable to infer a default factory for values of type {values_source_type}.'
f' Only {allowed_msg} are supported, other types require an explicit default factory'
' ' + instructions
)
return allowed_default_types[values_type_origin]
# Assume Annotated[..., Field(...)]
if _typing_extra.is_annotated(values_source_type):
field_info = next((v for v in get_args(values_source_type) if isinstance(v, FieldInfo)), None)
else:
field_info = None
if field_info and field_info.default_factory:
default_default_factory = field_info.default_factory
else:
default_default_factory = infer_default()
return default_default_factory
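# A hedged usage sketch of the inference above via public pydantic v2 API: for
# `DefaultDict[str, int]` the value type is `int`, so `int` becomes the default factory.
from typing import DefaultDict
from pydantic import TypeAdapter
dd = TypeAdapter(DefaultDict[str, int]).validate_python({'a': 1})
assert dd['missing'] == 0  # the rebuilt defaultdict uses `int` as its default factory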
@dataclasses.dataclass(**slots_true)
class MappingValidator:
mapped_origin: type[Any]
keys_source_type: type[Any]
values_source_type: type[Any]
min_length: int | None = None
max_length: int | None = None
strict: bool = False
def serialize_mapping_via_dict(self, v: Any, handler: core_schema.SerializerFunctionWrapHandler) -> Any:
return handler(v)
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
if self.keys_source_type is Any:
keys_schema = None
else:
keys_schema = handler.generate_schema(self.keys_source_type)
if self.values_source_type is Any:
values_schema = None
else:
values_schema = handler.generate_schema(self.values_source_type)
metadata = {'min_length': self.min_length, 'max_length': self.max_length, 'strict': self.strict}
if self.mapped_origin is dict:
schema = core_schema.dict_schema(keys_schema, values_schema, **metadata)
else:
constrained_schema = core_schema.dict_schema(keys_schema, values_schema, **metadata)
check_instance = core_schema.json_or_python_schema(
json_schema=core_schema.dict_schema(),
python_schema=core_schema.is_instance_schema(self.mapped_origin),
)
if self.mapped_origin is collections.defaultdict:
default_default_factory = get_defaultdict_default_default_factory(self.values_source_type)
coerce_instance_wrap = partial(
core_schema.no_info_wrap_validator_function,
partial(defaultdict_validator, default_default_factory=default_default_factory),
)
else:
coerce_instance_wrap = partial(core_schema.no_info_after_validator_function, self.mapped_origin)
serialization = core_schema.wrap_serializer_function_ser_schema(
self.serialize_mapping_via_dict,
schema=core_schema.dict_schema(
keys_schema or core_schema.any_schema(), values_schema or core_schema.any_schema()
),
info_arg=False,
)
strict = core_schema.chain_schema([check_instance, coerce_instance_wrap(constrained_schema)])
if metadata.get('strict', False):
schema = strict
else:
lax = coerce_instance_wrap(constrained_schema)
schema = core_schema.lax_or_strict_schema(lax_schema=lax, strict_schema=strict)
schema['serialization'] = serialization
return schema
def mapping_like_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
origin: Any = get_origin(source_type)
mapped_origin = MAPPING_ORIGIN_MAP.get(origin, None) if origin else MAPPING_ORIGIN_MAP.get(source_type, None)
if mapped_origin is None:
return None
args = get_args(source_type)
if not args:
args = (Any, Any)
elif mapped_origin is collections.Counter:
# a single generic
if len(args) != 1:
raise ValueError('Expected Counter to have exactly 1 generic parameter')
args = (args[0], int)  # the count values are always an int
elif len(args) != 2:
raise ValueError('Expected mapping to have exactly 2 generic parameters')
keys_source_type, values_source_type = args
metadata, remaining_annotations = _known_annotated_metadata.collect_known_metadata(annotations)
_known_annotated_metadata.check_metadata(metadata, _known_annotated_metadata.SEQUENCE_CONSTRAINTS, source_type)
return (
source_type,
[
MappingValidator(mapped_origin, keys_source_type, values_source_type, **metadata),
*remaining_annotations,
],
)
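# An illustration of the Counter special case above via public pydantic v2 API: the single
# generic parameter describes the keys, and the counts are always validated as int.
from typing import Counter
from pydantic import TypeAdapter
counts = TypeAdapter(Counter[str]).validate_python({'a': '2'})
assert counts == {'a': 2} and isinstance(counts['a'], int)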
def ip_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
def make_strict_ip_schema(tp: type[Any]) -> CoreSchema:
return core_schema.json_or_python_schema(
json_schema=core_schema.no_info_after_validator_function(tp, core_schema.str_schema()),
python_schema=core_schema.is_instance_schema(tp),
)
if source_type is IPv4Address:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.lax_or_strict_schema(
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_address_validator),
strict_schema=make_strict_ip_schema(IPv4Address),
serialization=core_schema.to_string_ser_schema(),
),
lambda _1, _2: {'type': 'string', 'format': 'ipv4'},
),
*annotations,
]
if source_type is IPv4Network:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.lax_or_strict_schema(
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_network_validator),
strict_schema=make_strict_ip_schema(IPv4Network),
serialization=core_schema.to_string_ser_schema(),
),
lambda _1, _2: {'type': 'string', 'format': 'ipv4network'},
),
*annotations,
]
if source_type is IPv4Interface:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.lax_or_strict_schema(
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v4_interface_validator),
strict_schema=make_strict_ip_schema(IPv4Interface),
serialization=core_schema.to_string_ser_schema(),
),
lambda _1, _2: {'type': 'string', 'format': 'ipv4interface'},
),
*annotations,
]
if source_type is IPv6Address:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.lax_or_strict_schema(
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_address_validator),
strict_schema=make_strict_ip_schema(IPv6Address),
serialization=core_schema.to_string_ser_schema(),
),
lambda _1, _2: {'type': 'string', 'format': 'ipv6'},
),
*annotations,
]
if source_type is IPv6Network:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.lax_or_strict_schema(
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_network_validator),
strict_schema=make_strict_ip_schema(IPv6Network),
serialization=core_schema.to_string_ser_schema(),
),
lambda _1, _2: {'type': 'string', 'format': 'ipv6network'},
),
*annotations,
]
if source_type is IPv6Interface:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.lax_or_strict_schema(
lax_schema=core_schema.no_info_plain_validator_function(_validators.ip_v6_interface_validator),
strict_schema=make_strict_ip_schema(IPv6Interface),
serialization=core_schema.to_string_ser_schema(),
),
lambda _1, _2: {'type': 'string', 'format': 'ipv6interface'},
),
*annotations,
]
return None
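# A usage sketch of the IP handling above via public pydantic v2 API; serialization goes
# through `to_string_ser_schema`, so addresses round-trip as strings.
from ipaddress import IPv4Address
from pydantic import TypeAdapter
ta = TypeAdapter(IPv4Address)
assert ta.validate_python('127.0.0.1') == IPv4Address('127.0.0.1')
assert ta.dump_json(IPv4Address('127.0.0.1')) == b'"127.0.0.1"'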
def url_prepare_pydantic_annotations(
source_type: Any, annotations: Iterable[Any], _config: ConfigDict
) -> tuple[Any, list[Any]] | None:
if source_type is Url:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.url_schema(),
lambda cs, handler: handler(cs),
),
*annotations,
]
if source_type is MultiHostUrl:
return source_type, [
SchemaTransformer(
lambda _1, _2: core_schema.multi_host_url_schema(),
lambda cs, handler: handler(cs),
),
*annotations,
]
PREPARE_METHODS: tuple[Callable[[Any, Iterable[Any], ConfigDict], tuple[Any, list[Any]] | None], ...] = (
decimal_prepare_pydantic_annotations,
sequence_like_prepare_pydantic_annotations,
datetime_prepare_pydantic_annotations,
uuid_prepare_pydantic_annotations,
path_schema_prepare_pydantic_annotations,
mapping_like_prepare_pydantic_annotations,
ip_prepare_pydantic_annotations,
url_prepare_pydantic_annotations,
)

View file

@ -1,469 +0,0 @@
"""Logic for interacting with type annotations, mostly extensions, shims and hacks to wrap python's typing module."""
from __future__ import annotations as _annotations
import dataclasses
import sys
import types
import typing
from collections.abc import Callable
from functools import partial
from types import GetSetDescriptorType
from typing import TYPE_CHECKING, Any, Final
from typing_extensions import Annotated, Literal, TypeAliasType, TypeGuard, get_args, get_origin
if TYPE_CHECKING:
from ._dataclasses import StandardDataclass
try:
from typing import _TypingBase # type: ignore[attr-defined]
except ImportError:
from typing import _Final as _TypingBase # type: ignore[attr-defined]
typing_base = _TypingBase
if sys.version_info < (3, 9):
# python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
TypingGenericAlias = ()
else:
from typing import GenericAlias as TypingGenericAlias # type: ignore
if sys.version_info < (3, 11):
from typing_extensions import NotRequired, Required
else:
from typing import NotRequired, Required # noqa: F401
if sys.version_info < (3, 10):
def origin_is_union(tp: type[Any] | None) -> bool:
return tp is typing.Union
WithArgsTypes = (TypingGenericAlias,)
else:
def origin_is_union(tp: type[Any] | None) -> bool:
return tp is typing.Union or tp is types.UnionType
WithArgsTypes = typing._GenericAlias, types.GenericAlias, types.UnionType # type: ignore[attr-defined]
if sys.version_info < (3, 10):
NoneType = type(None)
EllipsisType = type(Ellipsis)
else:
from types import NoneType as NoneType
LITERAL_TYPES: set[Any] = {Literal}
if hasattr(typing, 'Literal'):
LITERAL_TYPES.add(typing.Literal) # type: ignore
NONE_TYPES: tuple[Any, ...] = (None, NoneType, *(tp[None] for tp in LITERAL_TYPES))
TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type
def is_none_type(type_: Any) -> bool:
return type_ in NONE_TYPES
def is_callable_type(type_: type[Any]) -> bool:
return type_ is Callable or get_origin(type_) is Callable
def is_literal_type(type_: type[Any]) -> bool:
return Literal is not None and get_origin(type_) in LITERAL_TYPES
def literal_values(type_: type[Any]) -> tuple[Any, ...]:
return get_args(type_)
def all_literal_values(type_: type[Any]) -> list[Any]:
"""This method is used to retrieve all Literal values as
Literal can be used recursively (see https://www.python.org/dev/peps/pep-0586)
e.g. `Literal[Literal[Literal[1, 2, 3], "foo"], 5, None]`.
"""
if not is_literal_type(type_):
return [type_]
values = literal_values(type_)
return list(x for value in values for x in all_literal_values(value))
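# Quick sanity checks, assuming the helpers defined above are in scope; nested Literals are
# flattened recursively as described in the docstring.
from typing_extensions import Literal
assert is_literal_type(Literal[1, 2]) is True
assert all_literal_values(Literal[Literal[Literal[1, 2, 3], 'foo'], 5, None]) == [1, 2, 3, 'foo', 5, None]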
def is_annotated(ann_type: Any) -> bool:
from ._utils import lenient_issubclass
origin = get_origin(ann_type)
return origin is not None and lenient_issubclass(origin, Annotated)
def is_namedtuple(type_: type[Any]) -> bool:
"""Check if a given class is a named tuple.
It can be either a `typing.NamedTuple` or `collections.namedtuple`.
"""
from ._utils import lenient_issubclass
return lenient_issubclass(type_, tuple) and hasattr(type_, '_fields')
test_new_type = typing.NewType('test_new_type', str)
def is_new_type(type_: type[Any]) -> bool:
"""Check whether type_ was created using typing.NewType.
Can't use isinstance because it fails <3.10.
"""
return isinstance(type_, test_new_type.__class__) and hasattr(type_, '__supertype__') # type: ignore[arg-type]
def _check_classvar(v: type[Any] | None) -> bool:
if v is None:
return False
return v.__class__ == typing.ClassVar.__class__ and getattr(v, '_name', None) == 'ClassVar'
def is_classvar(ann_type: type[Any]) -> bool:
if _check_classvar(ann_type) or _check_classvar(get_origin(ann_type)):
return True
# this is an ugly workaround for class vars that contain forward references and are therefore themselves
# forward references, see #3679
if ann_type.__class__ == typing.ForwardRef and ann_type.__forward_arg__.startswith('ClassVar['): # type: ignore
return True
return False
def _check_finalvar(v: type[Any] | None) -> bool:
"""Check if a given type is a `typing.Final` type."""
if v is None:
return False
return v.__class__ == Final.__class__ and (sys.version_info < (3, 8) or getattr(v, '_name', None) == 'Final')
def is_finalvar(ann_type: Any) -> bool:
return _check_finalvar(ann_type) or _check_finalvar(get_origin(ann_type))
def parent_frame_namespace(*, parent_depth: int = 2) -> dict[str, Any] | None:
"""We allow use of items in parent namespace to get around the issue with `get_type_hints` only looking in the
global module namespace. See https://github.com/pydantic/pydantic/issues/2678#issuecomment-1008139014 -> Scope
and suggestion at the end of the next comment by @gvanrossum.
WARNING 1: it matters exactly where this is called. By default, this function will build a namespace from the
parent of where it is called.
WARNING 2: this only looks in the parent namespace, not other parents since (AFAIK) there's no way to collect a
dict of exactly what's in scope. Using `f_back` would work sometimes but would be very wrong and confusing in many
other cases. See https://discuss.python.org/t/is-there-a-way-to-access-parent-nested-namespaces/20659.
"""
frame = sys._getframe(parent_depth)
# if f_back is None, it's the global module namespace and we don't need to include it here
if frame.f_back is None:
return None
else:
return frame.f_locals
def add_module_globals(obj: Any, globalns: dict[str, Any] | None = None) -> dict[str, Any]:
module_name = getattr(obj, '__module__', None)
if module_name:
try:
module_globalns = sys.modules[module_name].__dict__
except KeyError:
# happens occasionally, see https://github.com/pydantic/pydantic/issues/2363
pass
else:
if globalns:
return {**module_globalns, **globalns}
else:
# copy module globals to make sure it can't be updated later
return module_globalns.copy()
return globalns or {}
def get_cls_types_namespace(cls: type[Any], parent_namespace: dict[str, Any] | None = None) -> dict[str, Any]:
ns = add_module_globals(cls, parent_namespace)
ns[cls.__name__] = cls
return ns
def get_cls_type_hints_lenient(obj: Any, globalns: dict[str, Any] | None = None) -> dict[str, Any]:
"""Collect annotations from a class, including those from parent classes.
Unlike `typing.get_type_hints`, this function will not error if a forward reference is not resolvable.
"""
hints = {}
for base in reversed(obj.__mro__):
ann = base.__dict__.get('__annotations__')
localns = dict(vars(base))
if ann is not None and ann is not GetSetDescriptorType:
for name, value in ann.items():
hints[name] = eval_type_lenient(value, globalns, localns)
return hints
def eval_type_lenient(value: Any, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None) -> Any:
"""Behaves like typing._eval_type, except it won't raise an error if a forward reference can't be resolved."""
if value is None:
value = NoneType
elif isinstance(value, str):
value = _make_forward_ref(value, is_argument=False, is_class=True)
try:
return eval_type_backport(value, globalns, localns)
except NameError:
# the point of this function is to be tolerant to this case
return value
def eval_type_backport(
value: Any, globalns: dict[str, Any] | None = None, localns: dict[str, Any] | None = None
) -> Any:
"""Like `typing._eval_type`, but falls back to the `eval_type_backport` package if it's
installed to let older Python versions use newer typing features.
Specifically, this transforms `X | Y` into `typing.Union[X, Y]`
and `list[X]` into `typing.List[X]` etc. (for all the types made generic in PEP 585)
if the original syntax is not supported in the current Python version.
"""
try:
return typing._eval_type( # type: ignore
value, globalns, localns
)
except TypeError as e:
if not (isinstance(value, typing.ForwardRef) and is_backport_fixable_error(e)):
raise
try:
from eval_type_backport import eval_type_backport
except ImportError:
raise TypeError(
f'You have a type annotation {value.__forward_arg__!r} '
f'which makes use of newer typing features than are supported in your version of Python. '
f'To handle this error, you should either remove the use of new syntax '
f'or install the `eval_type_backport` package.'
) from e
return eval_type_backport(value, globalns, localns, try_default=False)
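# Editor's sketch (not part of the original module): resolving a forward reference written
# with PEP 604 syntax. On Python 3.10+ `typing._eval_type` handles `int | None` natively;
# on older versions this demo only succeeds if the optional `eval_type_backport` package
# is installed, otherwise the TypeError above is raised.
def _example_eval_type_backport() -> None:
    ref = _make_forward_ref('int | None', is_argument=False, is_class=True)
    assert eval_type_backport(ref, globalns={'int': int}) == typing.Optional[int]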
def is_backport_fixable_error(e: TypeError) -> bool:
msg = str(e)
return msg.startswith('unsupported operand type(s) for |: ') or "' object is not subscriptable" in msg
def get_function_type_hints(
function: Callable[..., Any], *, include_keys: set[str] | None = None, types_namespace: dict[str, Any] | None = None
) -> dict[str, Any]:
"""Like `typing.get_type_hints`, but doesn't convert `X` to `Optional[X]` if the default value is `None`, also
copes with `partial`.
"""
if isinstance(function, partial):
annotations = function.func.__annotations__
else:
annotations = function.__annotations__
globalns = add_module_globals(function)
type_hints = {}
for name, value in annotations.items():
if include_keys is not None and name not in include_keys:
continue
if value is None:
value = NoneType
elif isinstance(value, str):
value = _make_forward_ref(value)
type_hints[name] = eval_type_backport(value, globalns, types_namespace)
return type_hints
if sys.version_info < (3, 9, 8) or (3, 10) <= sys.version_info < (3, 10, 1):
def _make_forward_ref(
arg: Any,
is_argument: bool = True,
*,
is_class: bool = False,
) -> typing.ForwardRef:
"""Wrapper for ForwardRef that accounts for the `is_class` argument missing in older versions.
The `module` argument is omitted as it breaks <3.9.8, =3.10.0 and isn't used in the calls below.
See https://github.com/python/cpython/pull/28560 for some background.
The backport happened on 3.9.8, see:
https://github.com/pydantic/pydantic/discussions/6244#discussioncomment-6275458,
and on 3.10.1 for the 3.10 branch, see:
https://github.com/pydantic/pydantic/issues/6912
Implemented as EAFP with memory.
"""
return typing.ForwardRef(arg, is_argument)
else:
_make_forward_ref = typing.ForwardRef
if sys.version_info >= (3, 10):
get_type_hints = typing.get_type_hints
else:
"""
For older versions of python, we have a custom implementation of `get_type_hints` which is as close as possible to
the implementation in CPython 3.10.8.
"""
@typing.no_type_check
def get_type_hints( # noqa: C901
obj: Any,
globalns: dict[str, Any] | None = None,
localns: dict[str, Any] | None = None,
include_extras: bool = False,
) -> dict[str, Any]: # pragma: no cover
"""Taken verbatim from python 3.10.8 unchanged, except:
* type annotations of the function definition above.
* prefixing `typing.` where appropriate
* Use `_make_forward_ref` instead of `typing.ForwardRef` to handle the `is_class` argument.
https://github.com/python/cpython/blob/aaaf5174241496afca7ce4d4584570190ff972fe/Lib/typing.py#L1773-L1875
DO NOT CHANGE THIS METHOD UNLESS ABSOLUTELY NECESSARY.
======================================================
Return type hints for an object.
This is often the same as obj.__annotations__, but it handles
forward references encoded as string literals, adds Optional[t] if a
default value equal to None is set and recursively replaces all
'Annotated[T, ...]' with 'T' (unless 'include_extras=True').
The argument may be a module, class, method, or function. The annotations
are returned as a dictionary. For classes, annotations include also
inherited members.
TypeError is raised if the argument is not of a type that can contain
annotations, and an empty dictionary is returned if no annotations are
present.
BEWARE -- the behavior of globalns and localns is counterintuitive
(unless you are familiar with how eval() and exec() work). The
search order is locals first, then globals.
- If no dict arguments are passed, an attempt is made to use the
globals from obj (or the respective module's globals for classes),
and these are also used as the locals. If the object does not appear
to have globals, an empty dictionary is used. For classes, the search
order is globals first then locals.
- If one dict argument is passed, it is used for both globals and
locals.
- If two dict arguments are passed, they specify globals and
locals, respectively.
"""
if getattr(obj, '__no_type_check__', None):
return {}
# Classes require a special treatment.
if isinstance(obj, type):
hints = {}
for base in reversed(obj.__mro__):
if globalns is None:
base_globals = getattr(sys.modules.get(base.__module__, None), '__dict__', {})
else:
base_globals = globalns
ann = base.__dict__.get('__annotations__', {})
if isinstance(ann, types.GetSetDescriptorType):
ann = {}
base_locals = dict(vars(base)) if localns is None else localns
if localns is None and globalns is None:
# This is surprising, but required. Before Python 3.10,
# get_type_hints only evaluated the globalns of
# a class. To maintain backwards compatibility, we reverse
# the globalns and localns order so that eval() looks into
# *base_globals* first rather than *base_locals*.
# This only affects ForwardRefs.
base_globals, base_locals = base_locals, base_globals
for name, value in ann.items():
if value is None:
value = type(None)
if isinstance(value, str):
value = _make_forward_ref(value, is_argument=False, is_class=True)
value = eval_type_backport(value, base_globals, base_locals)
hints[name] = value
if not include_extras and hasattr(typing, '_strip_annotations'):
return {
k: typing._strip_annotations(t) # type: ignore
for k, t in hints.items()
}
else:
return hints
if globalns is None:
if isinstance(obj, types.ModuleType):
globalns = obj.__dict__
else:
nsobj = obj
# Find globalns for the unwrapped object.
while hasattr(nsobj, '__wrapped__'):
nsobj = nsobj.__wrapped__
globalns = getattr(nsobj, '__globals__', {})
if localns is None:
localns = globalns
elif localns is None:
localns = globalns
hints = getattr(obj, '__annotations__', None)
if hints is None:
# Return empty annotations for something that _could_ have them.
if isinstance(obj, typing._allowed_types): # type: ignore
return {}
else:
raise TypeError(f'{obj!r} is not a module, class, method, ' 'or function.')
defaults = typing._get_defaults(obj) # type: ignore
hints = dict(hints)
for name, value in hints.items():
if value is None:
value = type(None)
if isinstance(value, str):
# class-level forward refs were handled above, this must be either
# a module-level annotation or a function argument annotation
value = _make_forward_ref(
value,
is_argument=not isinstance(obj, types.ModuleType),
is_class=False,
)
value = eval_type_backport(value, globalns, localns)
if name in defaults and defaults[name] is None:
value = typing.Optional[value]
hints[name] = value
return hints if include_extras else {k: typing._strip_annotations(t) for k, t in hints.items()} # type: ignore
def is_dataclass(_cls: type[Any]) -> TypeGuard[type[StandardDataclass]]:
# The dataclasses.is_dataclass function doesn't seem to provide TypeGuard functionality,
# so I created this convenience function
return dataclasses.is_dataclass(_cls)
def origin_is_type_alias_type(origin: Any) -> TypeGuard[TypeAliasType]:
return isinstance(origin, TypeAliasType)
if sys.version_info >= (3, 10):
def is_generic_alias(type_: type[Any]) -> bool:
return isinstance(type_, (types.GenericAlias, typing._GenericAlias)) # type: ignore[attr-defined]
else:
def is_generic_alias(type_: type[Any]) -> bool:
return isinstance(type_, typing._GenericAlias) # type: ignore

View file

@@ -1,362 +0,0 @@
"""Bucket of reusable internal utilities.
This should be reduced as much as possible, with functions that are only used in one place moved to that place.
"""
from __future__ import annotations as _annotations
import dataclasses
import keyword
import typing
import weakref
from collections import OrderedDict, defaultdict, deque
from copy import deepcopy
from itertools import zip_longest
from types import BuiltinFunctionType, CodeType, FunctionType, GeneratorType, LambdaType, ModuleType
from typing import Any, Mapping, TypeVar
from typing_extensions import TypeAlias, TypeGuard
from . import _repr, _typing_extra
if typing.TYPE_CHECKING:
MappingIntStrAny: TypeAlias = 'typing.Mapping[int, Any] | typing.Mapping[str, Any]'
AbstractSetIntStr: TypeAlias = 'typing.AbstractSet[int] | typing.AbstractSet[str]'
from ..main import BaseModel
# these are types that are returned unchanged by deepcopy
IMMUTABLE_NON_COLLECTIONS_TYPES: set[type[Any]] = {
int,
float,
complex,
str,
bool,
bytes,
type,
_typing_extra.NoneType,
FunctionType,
BuiltinFunctionType,
LambdaType,
weakref.ref,
CodeType,
# note: including ModuleType will differ from the behaviour of deepcopy by not producing an error.
# It might not be a good idea in general, but considering that this function is used only internally
# against default values of fields, this allows a field to actually have a module as its default value
ModuleType,
NotImplemented.__class__,
Ellipsis.__class__,
}
# these are types that if empty, might be copied with simple copy() instead of deepcopy()
BUILTIN_COLLECTIONS: set[type[Any]] = {
list,
set,
tuple,
frozenset,
dict,
OrderedDict,
defaultdict,
deque,
}
def sequence_like(v: Any) -> bool:
return isinstance(v, (list, tuple, set, frozenset, GeneratorType, deque))
def lenient_isinstance(o: Any, class_or_tuple: type[Any] | tuple[type[Any], ...] | None) -> bool: # pragma: no cover
try:
return isinstance(o, class_or_tuple) # type: ignore[arg-type]
except TypeError:
return False
def lenient_issubclass(cls: Any, class_or_tuple: Any) -> bool: # pragma: no cover
try:
return isinstance(cls, type) and issubclass(cls, class_or_tuple)
except TypeError:
if isinstance(cls, _typing_extra.WithArgsTypes):
return False
raise # pragma: no cover
def is_model_class(cls: Any) -> TypeGuard[type[BaseModel]]:
"""Returns true if cls is a _proper_ subclass of BaseModel, and provides proper type-checking,
unlike raw calls to lenient_issubclass.
"""
from ..main import BaseModel
return lenient_issubclass(cls, BaseModel) and cls is not BaseModel
def is_valid_identifier(identifier: str) -> bool:
"""Checks that a string is a valid identifier and not a Python keyword.
:param identifier: The identifier to test.
:return: True if the identifier is valid.
"""
return identifier.isidentifier() and not keyword.iskeyword(identifier)
KeyType = TypeVar('KeyType')
def deep_update(mapping: dict[KeyType, Any], *updating_mappings: dict[KeyType, Any]) -> dict[KeyType, Any]:
updated_mapping = mapping.copy()
for updating_mapping in updating_mappings:
for k, v in updating_mapping.items():
if k in updated_mapping and isinstance(updated_mapping[k], dict) and isinstance(v, dict):
updated_mapping[k] = deep_update(updated_mapping[k], v)
else:
updated_mapping[k] = v
return updated_mapping
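# Editor's illustration (not part of the original module): `deep_update` merges nested
# dictionaries recursively instead of replacing whole sub-dicts, and leaves its inputs untouched.
def _example_deep_update() -> None:
    base = {'db': {'host': 'localhost', 'port': 5432}, 'debug': False}
    merged = deep_update(base, {'db': {'port': 5433}}, {'debug': True})
    assert merged == {'db': {'host': 'localhost', 'port': 5433}, 'debug': True}
    assert base['db']['port'] == 5432  # the original mapping is not mutated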
def update_not_none(mapping: dict[Any, Any], **update: Any) -> None:
mapping.update({k: v for k, v in update.items() if v is not None})
T = TypeVar('T')
def unique_list(
input_list: list[T] | tuple[T, ...],
*,
name_factory: typing.Callable[[T], str] = str,
) -> list[T]:
"""Make a list unique while maintaining order.
We update the list if another one with the same name is set
(e.g. model validator overridden in subclass).
"""
result: list[T] = []
result_names: list[str] = []
for v in input_list:
v_name = name_factory(v)
if v_name not in result_names:
result_names.append(v_name)
result.append(v)
else:
result[result_names.index(v_name)] = v
return result
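# Editor's illustration (not part of the original module): `unique_list` keeps the first-seen
# position of each name but stores the last value registered under it, which is how an
# overriding validator replaces the one defined in a parent class.
def _example_unique_list() -> None:
    assert unique_list(['a', 'b', 'a', 'c']) == ['a', 'b', 'c']
    pairs = [('v', 1), ('w', 2), ('v', 3)]
    assert unique_list(pairs, name_factory=lambda p: p[0]) == [('v', 3), ('w', 2)]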
class ValueItems(_repr.Representation):
"""Class for more convenient calculation of excluded or included fields on values."""
__slots__ = ('_items', '_type')
def __init__(self, value: Any, items: AbstractSetIntStr | MappingIntStrAny) -> None:
items = self._coerce_items(items)
if isinstance(value, (list, tuple)):
items = self._normalize_indexes(items, len(value)) # type: ignore
self._items: MappingIntStrAny = items # type: ignore
def is_excluded(self, item: Any) -> bool:
"""Check if item is fully excluded.
:param item: key or index of a value
"""
return self.is_true(self._items.get(item))
def is_included(self, item: Any) -> bool:
"""Check if value is contained in self._items.
:param item: key or index of value
"""
return item in self._items
def for_element(self, e: int | str) -> AbstractSetIntStr | MappingIntStrAny | None:
""":param e: key or index of element on value
:return: raw values for element if self._items is dict and contain needed element
"""
item = self._items.get(e) # type: ignore
return item if not self.is_true(item) else None
def _normalize_indexes(self, items: MappingIntStrAny, v_length: int) -> dict[int | str, Any]:
""":param items: dict or set of indexes which will be normalized
:param v_length: length of the sequence whose indexes will be normalized
>>> self._normalize_indexes({0: True, -2: True, -1: True}, 4)
{0: True, 2: True, 3: True}
>>> self._normalize_indexes({'__all__': True}, 4)
{0: True, 1: True, 2: True, 3: True}
"""
normalized_items: dict[int | str, Any] = {}
all_items = None
for i, v in items.items():
if not (isinstance(v, typing.Mapping) or isinstance(v, typing.AbstractSet) or self.is_true(v)):
raise TypeError(f'Unexpected type of exclude value for index "{i}" {v.__class__}')
if i == '__all__':
all_items = self._coerce_value(v)
continue
if not isinstance(i, int):
raise TypeError(
'Excluding fields from a sequence of sub-models or dicts must be performed index-wise: '
'expected integer keys or keyword "__all__"'
)
normalized_i = v_length + i if i < 0 else i
normalized_items[normalized_i] = self.merge(v, normalized_items.get(normalized_i))
if not all_items:
return normalized_items
if self.is_true(all_items):
for i in range(v_length):
normalized_items.setdefault(i, ...)
return normalized_items
for i in range(v_length):
normalized_item = normalized_items.setdefault(i, {})
if not self.is_true(normalized_item):
normalized_items[i] = self.merge(all_items, normalized_item)
return normalized_items
@classmethod
def merge(cls, base: Any, override: Any, intersect: bool = False) -> Any:
"""Merge a `base` item with an `override` item.
Both `base` and `override` are converted to dictionaries if possible.
Sets are converted to dictionaries with the set's entries as keys and
Ellipsis as values.
Each key-value pair existing in `base` is merged with `override`,
while the rest of the key-value pairs are updated recursively with this function.
Merging takes place based on the "union" of keys if `intersect` is
set to `False` (default) and on the intersection of keys if
`intersect` is set to `True`.
"""
override = cls._coerce_value(override)
base = cls._coerce_value(base)
if override is None:
return base
if cls.is_true(base) or base is None:
return override
if cls.is_true(override):
return base if intersect else override
# intersection or union of keys while preserving ordering:
if intersect:
merge_keys = [k for k in base if k in override] + [k for k in override if k in base]
else:
merge_keys = list(base) + [k for k in override if k not in base]
merged: dict[int | str, Any] = {}
for k in merge_keys:
merged_item = cls.merge(base.get(k), override.get(k), intersect=intersect)
if merged_item is not None:
merged[k] = merged_item
return merged
@staticmethod
def _coerce_items(items: AbstractSetIntStr | MappingIntStrAny) -> MappingIntStrAny:
if isinstance(items, typing.Mapping):
pass
elif isinstance(items, typing.AbstractSet):
items = dict.fromkeys(items, ...) # type: ignore
else:
class_name = getattr(items, '__class__', '???')
raise TypeError(f'Unexpected type of exclude value {class_name}')
return items # type: ignore
@classmethod
def _coerce_value(cls, value: Any) -> Any:
if value is None or cls.is_true(value):
return value
return cls._coerce_items(value)
@staticmethod
def is_true(v: Any) -> bool:
return v is True or v is ...
def __repr_args__(self) -> _repr.ReprArgs:
return [(None, self._items)]
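# Editor's illustration (not part of the original module): `ValueItems.merge` coerces sets to
# `{key: ...}` dictionaries, then unions keys by default and intersects them when `intersect=True`.
def _example_value_items_merge() -> None:
    assert ValueItems.merge({'a', 'b'}, {'b': {'c'}}) == {'a': ..., 'b': {'c': ...}}
    assert ValueItems.merge({'a': ..., 'b': ...}, {'b': ...}, intersect=True) == {'b': ...}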
if typing.TYPE_CHECKING:
def ClassAttribute(name: str, value: T) -> T:
...
else:
class ClassAttribute:
"""Hide class attribute from its instances."""
__slots__ = 'name', 'value'
def __init__(self, name: str, value: Any) -> None:
self.name = name
self.value = value
def __get__(self, instance: Any, owner: type[Any]) -> None:
if instance is None:
return self.value
raise AttributeError(f'{self.name!r} attribute of {owner.__name__!r} is class-only')
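# Editor's illustration (not part of the original module): a `ClassAttribute` is readable on
# the class itself but hidden from its instances.
def _example_class_attribute() -> None:
    class Demo:
        marker = ClassAttribute('marker', 'class-only')
    assert Demo.marker == 'class-only'
    try:
        Demo().marker
    except AttributeError:
        pass
    else:
        raise AssertionError('instance access should have raised AttributeError')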
Obj = TypeVar('Obj')
def smart_deepcopy(obj: Obj) -> Obj:
"""Return type as is for immutable built-in types
Use obj.copy() for built-in empty collections
Use copy.deepcopy() for non-empty collections and unknown objects.
"""
obj_type = obj.__class__
if obj_type in IMMUTABLE_NON_COLLECTIONS_TYPES:
return obj # fastest case: obj is immutable and not collection therefore will not be copied anyway
try:
if not obj and obj_type in BUILTIN_COLLECTIONS:
# faster way for empty collections, no need to copy its members
return obj if obj_type is tuple else obj.copy() # tuple doesn't have copy method # type: ignore
except (TypeError, ValueError, RuntimeError):
# do we really dare to catch ALL errors? Seems a bit risky
pass
return deepcopy(obj) # slowest way when we actually might need a deepcopy
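# Editor's illustration (not part of the original module): immutable values come back as-is,
# empty built-in collections are shallow-copied, and everything else falls through to deepcopy.
def _example_smart_deepcopy() -> None:
    s = 'abc'
    assert smart_deepcopy(s) is s
    empty = []
    assert smart_deepcopy(empty) == [] and smart_deepcopy(empty) is not empty
    nested = [[1, 2]]
    copied = smart_deepcopy(nested)
    assert copied == nested and copied[0] is not nested[0]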
_SENTINEL = object()
def all_identical(left: typing.Iterable[Any], right: typing.Iterable[Any]) -> bool:
"""Check that the items of `left` are the same objects as those in `right`.
>>> a, b = object(), object()
>>> all_identical([a, b, a], [a, b, a])
True
>>> all_identical([a, b, [a]], [a, b, [a]]) # new list object, while "equal" is not "identical"
False
"""
for left_item, right_item in zip_longest(left, right, fillvalue=_SENTINEL):
if left_item is not right_item:
return False
return True
@dataclasses.dataclass(frozen=True)
class SafeGetItemProxy:
"""Wrapper redirecting `__getitem__` to `get` with a sentinel value as default
This makes it safe to use in `operator.itemgetter` when some keys may be missing
"""
# Define __slots__ manually for performance
# @dataclasses.dataclass() only supports slots=True in python>=3.10
__slots__ = ('wrapped',)
wrapped: Mapping[str, Any]
def __getitem__(self, __key: str) -> Any:
return self.wrapped.get(__key, _SENTINEL)
# required to pass the object to operator.itemgetter() instances due to a quirk of typeshed
# https://github.com/python/mypy/issues/13713
# https://github.com/python/typeshed/pull/8785
# Since this is typing-only, hide it in a typing.TYPE_CHECKING block
if typing.TYPE_CHECKING:
def __contains__(self, __key: str) -> bool:
return self.wrapped.__contains__(__key)
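# Editor's illustration (not part of the original module): the proxy lets `operator.itemgetter`
# tolerate missing keys by substituting the module-level `_SENTINEL`.
def _example_safe_getitem_proxy() -> None:
    import operator
    getter = operator.itemgetter('x', 'y')
    assert getter(SafeGetItemProxy({'x': 1})) == (1, _SENTINEL)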

View file

@@ -1,84 +0,0 @@
from __future__ import annotations as _annotations
import inspect
from functools import partial
from typing import Any, Awaitable, Callable
import pydantic_core
from ..config import ConfigDict
from ..plugin._schema_validator import create_schema_validator
from . import _generate_schema, _typing_extra
from ._config import ConfigWrapper
class ValidateCallWrapper:
"""This is a wrapper around a function that validates the arguments passed to it, and optionally the return value."""
__slots__ = (
'__pydantic_validator__',
'__name__',
'__qualname__',
'__annotations__',
'__dict__', # required for __module__
)
def __init__(self, function: Callable[..., Any], config: ConfigDict | None, validate_return: bool):
if isinstance(function, partial):
func = function.func
schema_type = func
self.__name__ = f'partial({func.__name__})'
self.__qualname__ = f'partial({func.__qualname__})'
self.__module__ = func.__module__
else:
schema_type = function
self.__name__ = function.__name__
self.__qualname__ = function.__qualname__
self.__module__ = function.__module__
namespace = _typing_extra.add_module_globals(function, None)
config_wrapper = ConfigWrapper(config)
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
schema = gen_schema.clean_schema(gen_schema.generate_schema(function))
core_config = config_wrapper.core_config(self)
self.__pydantic_validator__ = create_schema_validator(
schema,
schema_type,
self.__module__,
self.__qualname__,
'validate_call',
core_config,
config_wrapper.plugin_settings,
)
if validate_return:
signature = inspect.signature(function)
return_type = signature.return_annotation if signature.return_annotation is not signature.empty else Any
gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace)
schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
validator = create_schema_validator(
schema,
schema_type,
self.__module__,
self.__qualname__,
'validate_call',
core_config,
config_wrapper.plugin_settings,
)
if inspect.iscoroutinefunction(function):
async def return_val_wrapper(aw: Awaitable[Any]) -> None:
return validator.validate_python(await aw)
self.__return_pydantic_validator__ = return_val_wrapper
else:
self.__return_pydantic_validator__ = validator.validate_python
else:
self.__return_pydantic_validator__ = None
def __call__(self, *args: Any, **kwargs: Any) -> Any:
res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
if self.__return_pydantic_validator__:
return self.__return_pydantic_validator__(res)
return res

View file

@@ -1,278 +0,0 @@
"""Validator functions for standard library types.
Import of this module is deferred since it contains imports of many standard library modules.
"""
from __future__ import annotations as _annotations
import math
import re
import typing
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
from typing import Any
from pydantic_core import PydanticCustomError, core_schema
from pydantic_core._pydantic_core import PydanticKnownError
def sequence_validator(
__input_value: typing.Sequence[Any],
validator: core_schema.ValidatorFunctionWrapHandler,
) -> typing.Sequence[Any]:
"""Validator for `Sequence` types, isinstance(v, Sequence) has already been called."""
value_type = type(__input_value)
# We don't accept any plain string as a sequence
# Relevant issue: https://github.com/pydantic/pydantic/issues/5595
if issubclass(value_type, (str, bytes)):
raise PydanticCustomError(
'sequence_str',
"'{type_name}' instances are not allowed as a Sequence value",
{'type_name': value_type.__name__},
)
v_list = validator(__input_value)
# the rest of the logic is just re-creating the original type from `v_list`
if value_type == list:
return v_list
elif issubclass(value_type, range):
# return the list as we probably can't re-create the range
return v_list
else:
# best guess at how to re-create the original type, more custom construction logic might be required
return value_type(v_list) # type: ignore[call-arg]
def import_string(value: Any) -> Any:
if isinstance(value, str):
try:
return _import_string_logic(value)
except ImportError as e:
raise PydanticCustomError('import_error', 'Invalid python path: {error}', {'error': str(e)}) from e
else:
# otherwise we just return the value and let the next validator do the rest of the work
return value
def _import_string_logic(dotted_path: str) -> Any:
"""Inspired by uvicorn — dotted paths should include a colon before the final item if that item is not a module.
(This is necessary to distinguish between a submodule and an attribute when there is a conflict.)
If the dotted path does not include a colon and the final item is not a valid module, importing as an attribute
rather than a submodule will be attempted automatically.
So, for example, the following values of `dotted_path` result in the following returned values:
* 'collections': <module 'collections'>
* 'collections.abc': <module 'collections.abc'>
* 'collections.abc:Mapping': <class 'collections.abc.Mapping'>
* `collections.abc.Mapping`: <class 'collections.abc.Mapping'> (though this is a bit slower than the previous line)
An error will be raised under any of the following scenarios:
* `dotted_path` contains more than one colon (e.g., 'collections:abc:Mapping')
* the substring of `dotted_path` before the colon is not a valid module in the environment (e.g., '123:Mapping')
* the substring of `dotted_path` after the colon is not an attribute of the module (e.g., 'collections:abc123')
"""
from importlib import import_module
components = dotted_path.strip().split(':')
if len(components) > 2:
raise ImportError(f"Import strings should have at most one ':'; received {dotted_path!r}")
module_path = components[0]
if not module_path:
raise ImportError(f'Import strings should have a nonempty module name; received {dotted_path!r}')
try:
module = import_module(module_path)
except ModuleNotFoundError as e:
if '.' in module_path:
# Check if it would be valid if the final item was separated from its module with a `:`
maybe_module_path, maybe_attribute = dotted_path.strip().rsplit('.', 1)
try:
return _import_string_logic(f'{maybe_module_path}:{maybe_attribute}')
except ImportError:
pass
raise ImportError(f'No module named {module_path!r}') from e
raise e
if len(components) > 1:
attribute = components[1]
try:
return getattr(module, attribute)
except AttributeError as e:
raise ImportError(f'cannot import name {attribute!r} from {module_path!r}') from e
else:
return module
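# Editor's illustration (not part of the original module): the colon form and the dotted-only
# form resolve to the same attribute, matching the docstring above.
def _example_import_string_logic() -> None:
    import collections.abc
    assert _import_string_logic('collections.abc:Mapping') is collections.abc.Mapping
    assert _import_string_logic('collections.abc.Mapping') is collections.abc.Mapping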
def pattern_either_validator(__input_value: Any) -> typing.Pattern[Any]:
if isinstance(__input_value, typing.Pattern):
return __input_value
elif isinstance(__input_value, (str, bytes)):
# todo strict mode
return compile_pattern(__input_value) # type: ignore
else:
raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
def pattern_str_validator(__input_value: Any) -> typing.Pattern[str]:
if isinstance(__input_value, typing.Pattern):
if isinstance(__input_value.pattern, str):
return __input_value
else:
raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
elif isinstance(__input_value, str):
return compile_pattern(__input_value)
elif isinstance(__input_value, bytes):
raise PydanticCustomError('pattern_str_type', 'Input should be a string pattern')
else:
raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
def pattern_bytes_validator(__input_value: Any) -> typing.Pattern[bytes]:
if isinstance(__input_value, typing.Pattern):
if isinstance(__input_value.pattern, bytes):
return __input_value
else:
raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
elif isinstance(__input_value, bytes):
return compile_pattern(__input_value)
elif isinstance(__input_value, str):
raise PydanticCustomError('pattern_bytes_type', 'Input should be a bytes pattern')
else:
raise PydanticCustomError('pattern_type', 'Input should be a valid pattern')
PatternType = typing.TypeVar('PatternType', str, bytes)
def compile_pattern(pattern: PatternType) -> typing.Pattern[PatternType]:
try:
return re.compile(pattern)
except re.error:
raise PydanticCustomError('pattern_regex', 'Input should be a valid regular expression')
def ip_v4_address_validator(__input_value: Any) -> IPv4Address:
if isinstance(__input_value, IPv4Address):
return __input_value
try:
return IPv4Address(__input_value)
except ValueError:
raise PydanticCustomError('ip_v4_address', 'Input is not a valid IPv4 address')
def ip_v6_address_validator(__input_value: Any) -> IPv6Address:
if isinstance(__input_value, IPv6Address):
return __input_value
try:
return IPv6Address(__input_value)
except ValueError:
raise PydanticCustomError('ip_v6_address', 'Input is not a valid IPv6 address')
def ip_v4_network_validator(__input_value: Any) -> IPv4Network:
"""Assume IPv4Network initialised with a default `strict` argument.
See more:
https://docs.python.org/library/ipaddress.html#ipaddress.IPv4Network
"""
if isinstance(__input_value, IPv4Network):
return __input_value
try:
return IPv4Network(__input_value)
except ValueError:
raise PydanticCustomError('ip_v4_network', 'Input is not a valid IPv4 network')
def ip_v6_network_validator(__input_value: Any) -> IPv6Network:
"""Assume IPv6Network initialised with a default `strict` argument.
See more:
https://docs.python.org/library/ipaddress.html#ipaddress.IPv6Network
"""
if isinstance(__input_value, IPv6Network):
return __input_value
try:
return IPv6Network(__input_value)
except ValueError:
raise PydanticCustomError('ip_v6_network', 'Input is not a valid IPv6 network')
def ip_v4_interface_validator(__input_value: Any) -> IPv4Interface:
if isinstance(__input_value, IPv4Interface):
return __input_value
try:
return IPv4Interface(__input_value)
except ValueError:
raise PydanticCustomError('ip_v4_interface', 'Input is not a valid IPv4 interface')
def ip_v6_interface_validator(__input_value: Any) -> IPv6Interface:
if isinstance(__input_value, IPv6Interface):
return __input_value
try:
return IPv6Interface(__input_value)
except ValueError:
raise PydanticCustomError('ip_v6_interface', 'Input is not a valid IPv6 interface')
def greater_than_validator(x: Any, gt: Any) -> Any:
if not (x > gt):
raise PydanticKnownError('greater_than', {'gt': gt})
return x
def greater_than_or_equal_validator(x: Any, ge: Any) -> Any:
if not (x >= ge):
raise PydanticKnownError('greater_than_equal', {'ge': ge})
return x
def less_than_validator(x: Any, lt: Any) -> Any:
if not (x < lt):
raise PydanticKnownError('less_than', {'lt': lt})
return x
def less_than_or_equal_validator(x: Any, le: Any) -> Any:
if not (x <= le):
raise PydanticKnownError('less_than_equal', {'le': le})
return x
def multiple_of_validator(x: Any, multiple_of: Any) -> Any:
if not (x % multiple_of == 0):
raise PydanticKnownError('multiple_of', {'multiple_of': multiple_of})
return x
def min_length_validator(x: Any, min_length: Any) -> Any:
if not (len(x) >= min_length):
raise PydanticKnownError(
'too_short',
{'field_type': 'Value', 'min_length': min_length, 'actual_length': len(x)},
)
return x
def max_length_validator(x: Any, max_length: Any) -> Any:
if len(x) > max_length:
raise PydanticKnownError(
'too_long',
{'field_type': 'Value', 'max_length': max_length, 'actual_length': len(x)},
)
return x
def forbid_inf_nan_check(x: Any) -> Any:
if not math.isfinite(x):
raise PydanticKnownError('finite_number')
return x

View file

@@ -1,308 +0,0 @@
import sys
from typing import Any, Callable, Dict
from .version import version_short
MOVED_IN_V2 = {
'pydantic.utils:version_info': 'pydantic.version:version_info',
'pydantic.error_wrappers:ValidationError': 'pydantic:ValidationError',
'pydantic.utils:to_camel': 'pydantic.alias_generators:to_pascal',
'pydantic.utils:to_lower_camel': 'pydantic.alias_generators:to_camel',
'pydantic:PyObject': 'pydantic.types:ImportString',
'pydantic.types:PyObject': 'pydantic.types:ImportString',
'pydantic.generics:GenericModel': 'pydantic.BaseModel',
}
DEPRECATED_MOVED_IN_V2 = {
'pydantic.tools:schema_of': 'pydantic.deprecated.tools:schema_of',
'pydantic.tools:parse_obj_as': 'pydantic.deprecated.tools:parse_obj_as',
'pydantic.tools:schema_json_of': 'pydantic.deprecated.tools:schema_json_of',
'pydantic.json:pydantic_encoder': 'pydantic.deprecated.json:pydantic_encoder',
'pydantic:validate_arguments': 'pydantic.deprecated.decorator:validate_arguments',
'pydantic.json:custom_pydantic_encoder': 'pydantic.deprecated.json:custom_pydantic_encoder',
'pydantic.json:timedelta_isoformat': 'pydantic.deprecated.json:timedelta_isoformat',
'pydantic.decorator:validate_arguments': 'pydantic.deprecated.decorator:validate_arguments',
'pydantic.class_validators:validator': 'pydantic.deprecated.class_validators:validator',
'pydantic.class_validators:root_validator': 'pydantic.deprecated.class_validators:root_validator',
'pydantic.config:BaseConfig': 'pydantic.deprecated.config:BaseConfig',
'pydantic.config:Extra': 'pydantic.deprecated.config:Extra',
}
REDIRECT_TO_V1 = {
f'pydantic.utils:{obj}': f'pydantic.v1.utils:{obj}'
for obj in (
'deep_update',
'GetterDict',
'lenient_issubclass',
'lenient_isinstance',
'is_valid_field',
'update_not_none',
'import_string',
'Representation',
'ROOT_KEY',
'smart_deepcopy',
'sequence_like',
)
}
REMOVED_IN_V2 = {
'pydantic:ConstrainedBytes',
'pydantic:ConstrainedDate',
'pydantic:ConstrainedDecimal',
'pydantic:ConstrainedFloat',
'pydantic:ConstrainedFrozenSet',
'pydantic:ConstrainedInt',
'pydantic:ConstrainedList',
'pydantic:ConstrainedSet',
'pydantic:ConstrainedStr',
'pydantic:JsonWrapper',
'pydantic:NoneBytes',
'pydantic:NoneStr',
'pydantic:NoneStrBytes',
'pydantic:Protocol',
'pydantic:Required',
'pydantic:StrBytes',
'pydantic:compiled',
'pydantic.config:get_config',
'pydantic.config:inherit_config',
'pydantic.config:prepare_config',
'pydantic:create_model_from_namedtuple',
'pydantic:create_model_from_typeddict',
'pydantic.dataclasses:create_pydantic_model_from_dataclass',
'pydantic.dataclasses:make_dataclass_validator',
'pydantic.dataclasses:set_validation',
'pydantic.datetime_parse:parse_date',
'pydantic.datetime_parse:parse_time',
'pydantic.datetime_parse:parse_datetime',
'pydantic.datetime_parse:parse_duration',
'pydantic.error_wrappers:ErrorWrapper',
'pydantic.errors:AnyStrMaxLengthError',
'pydantic.errors:AnyStrMinLengthError',
'pydantic.errors:ArbitraryTypeError',
'pydantic.errors:BoolError',
'pydantic.errors:BytesError',
'pydantic.errors:CallableError',
'pydantic.errors:ClassError',
'pydantic.errors:ColorError',
'pydantic.errors:ConfigError',
'pydantic.errors:DataclassTypeError',
'pydantic.errors:DateError',
'pydantic.errors:DateNotInTheFutureError',
'pydantic.errors:DateNotInThePastError',
'pydantic.errors:DateTimeError',
'pydantic.errors:DecimalError',
'pydantic.errors:DecimalIsNotFiniteError',
'pydantic.errors:DecimalMaxDigitsError',
'pydantic.errors:DecimalMaxPlacesError',
'pydantic.errors:DecimalWholeDigitsError',
'pydantic.errors:DictError',
'pydantic.errors:DurationError',
'pydantic.errors:EmailError',
'pydantic.errors:EnumError',
'pydantic.errors:EnumMemberError',
'pydantic.errors:ExtraError',
'pydantic.errors:FloatError',
'pydantic.errors:FrozenSetError',
'pydantic.errors:FrozenSetMaxLengthError',
'pydantic.errors:FrozenSetMinLengthError',
'pydantic.errors:HashableError',
'pydantic.errors:IPv4AddressError',
'pydantic.errors:IPv4InterfaceError',
'pydantic.errors:IPv4NetworkError',
'pydantic.errors:IPv6AddressError',
'pydantic.errors:IPv6InterfaceError',
'pydantic.errors:IPv6NetworkError',
'pydantic.errors:IPvAnyAddressError',
'pydantic.errors:IPvAnyInterfaceError',
'pydantic.errors:IPvAnyNetworkError',
'pydantic.errors:IntEnumError',
'pydantic.errors:IntegerError',
'pydantic.errors:InvalidByteSize',
'pydantic.errors:InvalidByteSizeUnit',
'pydantic.errors:InvalidDiscriminator',
'pydantic.errors:InvalidLengthForBrand',
'pydantic.errors:JsonError',
'pydantic.errors:JsonTypeError',
'pydantic.errors:ListError',
'pydantic.errors:ListMaxLengthError',
'pydantic.errors:ListMinLengthError',
'pydantic.errors:ListUniqueItemsError',
'pydantic.errors:LuhnValidationError',
'pydantic.errors:MissingDiscriminator',
'pydantic.errors:MissingError',
'pydantic.errors:NoneIsAllowedError',
'pydantic.errors:NoneIsNotAllowedError',
'pydantic.errors:NotDigitError',
'pydantic.errors:NotNoneError',
'pydantic.errors:NumberNotGeError',
'pydantic.errors:NumberNotGtError',
'pydantic.errors:NumberNotLeError',
'pydantic.errors:NumberNotLtError',
'pydantic.errors:NumberNotMultipleError',
'pydantic.errors:PathError',
'pydantic.errors:PathNotADirectoryError',
'pydantic.errors:PathNotAFileError',
'pydantic.errors:PathNotExistsError',
'pydantic.errors:PatternError',
'pydantic.errors:PyObjectError',
'pydantic.errors:PydanticTypeError',
'pydantic.errors:PydanticValueError',
'pydantic.errors:SequenceError',
'pydantic.errors:SetError',
'pydantic.errors:SetMaxLengthError',
'pydantic.errors:SetMinLengthError',
'pydantic.errors:StrError',
'pydantic.errors:StrRegexError',
'pydantic.errors:StrictBoolError',
'pydantic.errors:SubclassError',
'pydantic.errors:TimeError',
'pydantic.errors:TupleError',
'pydantic.errors:TupleLengthError',
'pydantic.errors:UUIDError',
'pydantic.errors:UUIDVersionError',
'pydantic.errors:UrlError',
'pydantic.errors:UrlExtraError',
'pydantic.errors:UrlHostError',
'pydantic.errors:UrlHostTldError',
'pydantic.errors:UrlPortError',
'pydantic.errors:UrlSchemeError',
'pydantic.errors:UrlSchemePermittedError',
'pydantic.errors:UrlUserInfoError',
'pydantic.errors:WrongConstantError',
'pydantic.main:validate_model',
'pydantic.networks:stricturl',
'pydantic:parse_file_as',
'pydantic:parse_raw_as',
'pydantic:stricturl',
'pydantic.tools:parse_file_as',
'pydantic.tools:parse_raw_as',
'pydantic.types:ConstrainedBytes',
'pydantic.types:ConstrainedDate',
'pydantic.types:ConstrainedDecimal',
'pydantic.types:ConstrainedFloat',
'pydantic.types:ConstrainedFrozenSet',
'pydantic.types:ConstrainedInt',
'pydantic.types:ConstrainedList',
'pydantic.types:ConstrainedSet',
'pydantic.types:ConstrainedStr',
'pydantic.types:JsonWrapper',
'pydantic.types:NoneBytes',
'pydantic.types:NoneStr',
'pydantic.types:NoneStrBytes',
'pydantic.types:StrBytes',
'pydantic.typing:evaluate_forwardref',
'pydantic.typing:AbstractSetIntStr',
'pydantic.typing:AnyCallable',
'pydantic.typing:AnyClassMethod',
'pydantic.typing:CallableGenerator',
'pydantic.typing:DictAny',
'pydantic.typing:DictIntStrAny',
'pydantic.typing:DictStrAny',
'pydantic.typing:IntStr',
'pydantic.typing:ListStr',
'pydantic.typing:MappingIntStrAny',
'pydantic.typing:NoArgAnyCallable',
'pydantic.typing:NoneType',
'pydantic.typing:ReprArgs',
'pydantic.typing:SetStr',
'pydantic.typing:StrPath',
'pydantic.typing:TupleGenerator',
'pydantic.typing:WithArgsTypes',
'pydantic.typing:all_literal_values',
'pydantic.typing:display_as_type',
'pydantic.typing:get_all_type_hints',
'pydantic.typing:get_args',
'pydantic.typing:get_origin',
'pydantic.typing:get_sub_types',
'pydantic.typing:is_callable_type',
'pydantic.typing:is_classvar',
'pydantic.typing:is_finalvar',
'pydantic.typing:is_literal_type',
'pydantic.typing:is_namedtuple',
'pydantic.typing:is_new_type',
'pydantic.typing:is_none_type',
'pydantic.typing:is_typeddict',
'pydantic.typing:is_typeddict_special',
'pydantic.typing:is_union',
'pydantic.typing:new_type_supertype',
'pydantic.typing:resolve_annotations',
'pydantic.typing:typing_base',
'pydantic.typing:update_field_forward_refs',
'pydantic.typing:update_model_forward_refs',
'pydantic.utils:ClassAttribute',
'pydantic.utils:DUNDER_ATTRIBUTES',
'pydantic.utils:PyObjectStr',
'pydantic.utils:ValueItems',
'pydantic.utils:almost_equal_floats',
'pydantic.utils:get_discriminator_alias_and_values',
'pydantic.utils:get_model',
'pydantic.utils:get_unique_discriminator_alias',
'pydantic.utils:in_ipython',
'pydantic.utils:is_valid_identifier',
'pydantic.utils:path_type',
'pydantic.utils:validate_field_name',
'pydantic:validate_model',
}
def getattr_migration(module: str) -> Callable[[str], Any]:
"""Implement PEP 562 for objects that were either moved or removed on the migration
to V2.
Args:
module: The module name.
Returns:
A callable that will raise an error if the object is not found.
"""
# This avoids circular import with errors.py.
from .errors import PydanticImportError
def wrapper(name: str) -> object:
"""Raise an error if the object is not found, or warn if it was moved.
In case it was moved, it still returns the object.
Args:
name: The object name.
Returns:
The object.
"""
if name == '__path__':
raise AttributeError(f'module {module!r} has no attribute {name!r}')
import warnings
from ._internal._validators import import_string
import_path = f'{module}:{name}'
if import_path in MOVED_IN_V2.keys():
new_location = MOVED_IN_V2[import_path]
warnings.warn(f'`{import_path}` has been moved to `{new_location}`.')
return import_string(MOVED_IN_V2[import_path])
if import_path in DEPRECATED_MOVED_IN_V2:
# skip the warning here because a deprecation warning will be raised elsewhere
return import_string(DEPRECATED_MOVED_IN_V2[import_path])
if import_path in REDIRECT_TO_V1:
new_location = REDIRECT_TO_V1[import_path]
warnings.warn(
f'`{import_path}` has been removed. We are importing from `{new_location}` instead.'
'See the migration guide for more details: https://docs.pydantic.dev/latest/migration/'
)
return import_string(REDIRECT_TO_V1[import_path])
if import_path == 'pydantic:BaseSettings':
raise PydanticImportError(
'`BaseSettings` has been moved to the `pydantic-settings` package. '
f'See https://docs.pydantic.dev/{version_short()}/migration/#basesettings-has-moved-to-pydantic-settings '
'for more details.'
)
if import_path in REMOVED_IN_V2:
raise PydanticImportError(f'`{import_path}` has been removed in V2.')
globals: Dict[str, Any] = sys.modules[module].__dict__
if name in globals:
return globals[name]
raise AttributeError(f'module {module!r} has no attribute {name!r}')
return wrapper
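# Editor's sketch (not part of the original module): the returned wrapper is meant to be
# installed as a module-level `__getattr__` hook (PEP 562), e.g. `__getattr__ = getattr_migration(__name__)`.
# Calling it directly shows the redirect behaviour; this assumes the bundled `pydantic.v1`
# package is importable.
def _example_getattr_migration() -> None:
    wrapper = getattr_migration('pydantic.utils')
    redirected = wrapper('deep_update')  # warns and returns `pydantic.v1.utils.deep_update`
    assert callable(redirected)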

View file

@@ -1,50 +0,0 @@
"""Alias generators for converting between different capitalization conventions."""
import re
__all__ = ('to_pascal', 'to_camel', 'to_snake')
def to_pascal(snake: str) -> str:
"""Convert a snake_case string to PascalCase.
Args:
snake: The string to convert.
Returns:
The PascalCase string.
"""
camel = snake.title()
return re.sub('([0-9A-Za-z])_(?=[0-9A-Z])', lambda m: m.group(1), camel)
def to_camel(snake: str) -> str:
"""Convert a snake_case string to camelCase.
Args:
snake: The string to convert.
Returns:
The converted camelCase string.
"""
camel = to_pascal(snake)
return re.sub('(^_*[A-Z])', lambda m: m.group(1).lower(), camel)
def to_snake(camel: str) -> str:
"""Convert a PascalCase or camelCase string to snake_case.
Args:
camel: The string to convert.
Returns:
The converted string in snake_case.
"""
# Handle the sequence of uppercase letters followed by a lowercase letter
snake = re.sub(r'([A-Z]+)([A-Z][a-z])', lambda m: f'{m.group(1)}_{m.group(2)}', camel)
# Insert an underscore between a lowercase letter and an uppercase letter
snake = re.sub(r'([a-z])([A-Z])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
# Insert an underscore between a digit and an uppercase letter
snake = re.sub(r'([0-9])([A-Z])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
# Insert an underscore between a lowercase letter and a digit
snake = re.sub(r'([a-z])([0-9])', lambda m: f'{m.group(1)}_{m.group(2)}', snake)
return snake.lower()
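# Editor's illustration (not part of the original module): converting one field name between
# the three supported conventions.
def _example_alias_generators() -> None:
    assert to_pascal('http_response_code') == 'HttpResponseCode'
    assert to_camel('http_response_code') == 'httpResponseCode'
    assert to_snake('HttpResponseCode') == 'http_response_code'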

View file

@@ -1,112 +0,0 @@
"""Support for alias configurations."""
from __future__ import annotations
import dataclasses
from typing import Callable, Literal
from ._internal import _internal_dataclass
__all__ = ('AliasGenerator', 'AliasPath', 'AliasChoices')
@dataclasses.dataclass(**_internal_dataclass.slots_true)
class AliasPath:
"""Usage docs: https://docs.pydantic.dev/2.6/concepts/alias#aliaspath-and-aliaschoices
A data class used by `validation_alias` as a convenience to create aliases.
Attributes:
path: A list of string or integer aliases.
"""
path: list[int | str]
def __init__(self, first_arg: str, *args: str | int) -> None:
self.path = [first_arg] + list(args)
def convert_to_aliases(self) -> list[str | int]:
"""Converts arguments to a list of string or integer aliases.
Returns:
The list of aliases.
"""
return self.path
@dataclasses.dataclass(**_internal_dataclass.slots_true)
class AliasChoices:
"""Usage docs: https://docs.pydantic.dev/2.6/concepts/alias#aliaspath-and-aliaschoices
A data class used by `validation_alias` as a convenience to create aliases.
Attributes:
choices: A list containing a string or `AliasPath`.
"""
choices: list[str | AliasPath]
def __init__(self, first_choice: str | AliasPath, *choices: str | AliasPath) -> None:
self.choices = [first_choice] + list(choices)
def convert_to_aliases(self) -> list[list[str | int]]:
"""Converts arguments to a list of lists containing string or integer aliases.
Returns:
The list of aliases.
"""
aliases: list[list[str | int]] = []
for c in self.choices:
if isinstance(c, AliasPath):
aliases.append(c.convert_to_aliases())
else:
aliases.append([c])
return aliases
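# Editor's illustration (not part of the original module): how the two containers flatten into
# the alias structures consumed during validation.
def _example_alias_containers() -> None:
    assert AliasPath('names', 0).convert_to_aliases() == ['names', 0]
    choices = AliasChoices('first_name', AliasPath('names', 0))
    assert choices.convert_to_aliases() == [['first_name'], ['names', 0]]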
@dataclasses.dataclass(**_internal_dataclass.slots_true)
class AliasGenerator:
"""Usage docs: https://docs.pydantic.dev/2.6/concepts/alias#using-an-aliasgenerator
A data class used by `alias_generator` as a convenience to create various aliases.
Attributes:
alias: A callable that takes a field name and returns an alias for it.
validation_alias: A callable that takes a field name and returns a validation alias for it.
serialization_alias: A callable that takes a field name and returns a serialization alias for it.
"""
alias: Callable[[str], str] | None = None
validation_alias: Callable[[str], str | AliasPath | AliasChoices] | None = None
serialization_alias: Callable[[str], str] | None = None
def _generate_alias(
self,
alias_kind: Literal['alias', 'validation_alias', 'serialization_alias'],
allowed_types: tuple[type[str] | type[AliasPath] | type[AliasChoices], ...],
field_name: str,
) -> str | AliasPath | AliasChoices | None:
"""Generate an alias of the specified kind. Returns None if the alias generator is None.
Raises:
TypeError: If the alias generator produces an invalid type.
"""
alias = None
if alias_generator := getattr(self, alias_kind):
alias = alias_generator(field_name)
if alias and not isinstance(alias, allowed_types):
raise TypeError(
f'Invalid `{alias_kind}` type. `{alias_kind}` generator must produce one of `{allowed_types}`'
)
return alias
def generate_aliases(self, field_name: str) -> tuple[str | None, str | AliasPath | AliasChoices | None, str | None]:
"""Generate `alias`, `validation_alias`, and `serialization_alias` for a field.
Returns:
A tuple of three aliases: `alias`, `validation_alias`, and `serialization_alias`, in that order.
"""
alias = self._generate_alias('alias', (str,), field_name)
validation_alias = self._generate_alias('validation_alias', (str, AliasChoices, AliasPath), field_name)
serialization_alias = self._generate_alias('serialization_alias', (str,), field_name)
return alias, validation_alias, serialization_alias # type: ignore
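# Editor's sketch (not part of the original module): an `AliasGenerator` built from the
# `to_camel`/`to_pascal` helpers in `pydantic.alias_generators`; the field name below is arbitrary.
def _example_alias_generator() -> None:
    from pydantic.alias_generators import to_camel, to_pascal
    generator = AliasGenerator(validation_alias=to_camel, serialization_alias=to_pascal)
    alias, validation_alias, serialization_alias = generator.generate_aliases('created_at')
    assert alias is None
    assert validation_alias == 'createdAt'
    assert serialization_alias == 'CreatedAt'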

View file

@@ -1,120 +0,0 @@
"""Type annotations to use with `__get_pydantic_core_schema__` and `__get_pydantic_json_schema__`."""
from __future__ import annotations as _annotations
from typing import TYPE_CHECKING, Any, Union
from pydantic_core import core_schema
if TYPE_CHECKING:
from .json_schema import JsonSchemaMode, JsonSchemaValue
CoreSchemaOrField = Union[
core_schema.CoreSchema,
core_schema.ModelField,
core_schema.DataclassField,
core_schema.TypedDictField,
core_schema.ComputedField,
]
__all__ = 'GetJsonSchemaHandler', 'GetCoreSchemaHandler'
class GetJsonSchemaHandler:
"""Handler to call into the next JSON schema generation function.
Attributes:
mode: Json schema mode, can be `validation` or `serialization`.
"""
mode: JsonSchemaMode
def __call__(self, __core_schema: CoreSchemaOrField) -> JsonSchemaValue:
"""Call the inner handler and get the JsonSchemaValue it returns.
This will call the next JSON schema modifying function up until it calls
into `pydantic.json_schema.GenerateJsonSchema`, which will raise a
`pydantic.errors.PydanticInvalidForJsonSchema` error if it cannot generate
a JSON schema.
Args:
__core_schema: A `pydantic_core.core_schema.CoreSchema`.
Returns:
JsonSchemaValue: The JSON schema generated by the inner JSON schema modify
functions.
"""
raise NotImplementedError
def resolve_ref_schema(self, __maybe_ref_json_schema: JsonSchemaValue) -> JsonSchemaValue:
"""Get the real schema for a `{"$ref": ...}` schema.
If the schema given is not a `$ref` schema, it will be returned as is.
This means you don't have to check before calling this function.
Args:
__maybe_ref_json_schema: A JsonSchemaValue which may be a `$ref` schema.
Raises:
LookupError: If the ref is not found.
Returns:
JsonSchemaValue: A JsonSchemaValue that has no `$ref`.
"""
raise NotImplementedError
class GetCoreSchemaHandler:
"""Handler to call into the next CoreSchema schema generation function."""
def __call__(self, __source_type: Any) -> core_schema.CoreSchema:
"""Call the inner handler and get the CoreSchema it returns.
This will call the next CoreSchema modifying function up until it calls
into Pydantic's internal schema generation machinery, which will raise a
`pydantic.errors.PydanticSchemaGenerationError` error if it cannot generate
a CoreSchema for the given source type.
Args:
__source_type: The input type.
Returns:
CoreSchema: The `pydantic-core` CoreSchema generated.
"""
raise NotImplementedError
def generate_schema(self, __source_type: Any) -> core_schema.CoreSchema:
"""Generate a schema unrelated to the current context.
Use this function if e.g. you are handling schema generation for a sequence
and want to generate a schema for its items.
Otherwise, you may end up doing something like applying a `min_length` constraint
that was intended for the sequence itself to its items!
Args:
__source_type: The input type.
Returns:
CoreSchema: The `pydantic-core` CoreSchema generated.
"""
raise NotImplementedError
def resolve_ref_schema(self, __maybe_ref_schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
"""Get the real schema for a `definition-ref` schema.
If the schema given is not a `definition-ref` schema, it will be returned as is.
This means you don't have to check before calling this function.
Args:
__maybe_ref_schema: A `CoreSchema`, `ref`-based or not.
Raises:
LookupError: If the `ref` is not found.
Returns:
A concrete `CoreSchema`.
"""
raise NotImplementedError
@property
def field_name(self) -> str | None:
"""Get the name of the closest field to this validator."""
raise NotImplementedError
def _get_types_namespace(self) -> dict[str, Any] | None:
"""Internal method used during type resolution for serializer annotations."""
raise NotImplementedError
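# Editor's sketch (not part of the original module): a handler instance is what a custom type
# receives in `__get_pydantic_core_schema__`. The class below is a made-up example that
# delegates to the handler for `str` and lower-cases the validated value; annotating a field
# with it (e.g. `x: _EditorExampleLowerStr`) would then yield lower-cased strings.
class _EditorExampleLowerStr:
    @classmethod
    def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
        return core_schema.no_info_after_validator_function(str.lower, handler(str))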

Some files were not shown because too many files have changed in this diff.