diff --git a/lib/backports/__init__.py b/lib/backports/__init__.py index 0d1f7edf..8db66d3d 100644 --- a/lib/backports/__init__.py +++ b/lib/backports/__init__.py @@ -1 +1 @@ -__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore +__path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/lib/backports/csv.py b/lib/backports/csv.py deleted file mode 100644 index 4694a28e..00000000 --- a/lib/backports/csv.py +++ /dev/null @@ -1,979 +0,0 @@ -# -*- coding: utf-8 -*- -"""A port of Python 3's csv module to Python 2. - -The API of the csv module in Python 2 is drastically different from -the csv module in Python 3. This is due, for the most part, to the -difference between str in Python 2 and Python 3. - -The semantics of Python 3's version are more useful because they support -unicode natively, while Python 2's csv does not. -""" -from __future__ import unicode_literals, absolute_import - -__all__ = [ "QUOTE_MINIMAL", "QUOTE_ALL", "QUOTE_NONNUMERIC", "QUOTE_NONE", - "Error", "Dialect", "__doc__", "excel", "excel_tab", - "field_size_limit", "reader", "writer", - "register_dialect", "get_dialect", "list_dialects", "Sniffer", - "unregister_dialect", "__version__", "DictReader", "DictWriter" ] - -import re -import numbers -from io import StringIO -from csv import ( - QUOTE_MINIMAL, QUOTE_ALL, QUOTE_NONNUMERIC, QUOTE_NONE, - __version__, __doc__, Error, field_size_limit, -) - -# Stuff needed from six -import sys -PY3 = sys.version_info[0] == 3 -if PY3: - string_types = str - text_type = str - binary_type = bytes - unichr = chr -else: - string_types = basestring - text_type = unicode - binary_type = str - - -class QuoteStrategy(object): - quoting = None - - def __init__(self, dialect): - if self.quoting is not None: - assert dialect.quoting == self.quoting - self.dialect = dialect - self.setup() - - escape_pattern_quoted = r'({quotechar})'.format( - quotechar=re.escape(self.dialect.quotechar or '"')) - escape_pattern_unquoted = r'([{specialchars}])'.format( - specialchars=re.escape(self.specialchars)) - - self.escape_re_quoted = re.compile(escape_pattern_quoted) - self.escape_re_unquoted = re.compile(escape_pattern_unquoted) - - def setup(self): - """Optional method for strategy-wide optimizations.""" - - def quoted(self, field=None, raw_field=None, only=None): - """Determine whether this field should be quoted.""" - raise NotImplementedError( - 'quoted must be implemented by a subclass') - - @property - def specialchars(self): - """The special characters that need to be escaped.""" - raise NotImplementedError( - 'specialchars must be implemented by a subclass') - - def escape_re(self, quoted=None): - if quoted: - return self.escape_re_quoted - return self.escape_re_unquoted - - def escapechar(self, quoted=None): - if quoted and self.dialect.doublequote: - return self.dialect.quotechar - return self.dialect.escapechar - - def prepare(self, raw_field, only=None): - field = text_type(raw_field if raw_field is not None else '') - quoted = self.quoted(field=field, raw_field=raw_field, only=only) - - escape_re = self.escape_re(quoted=quoted) - escapechar = self.escapechar(quoted=quoted) - - if escape_re.search(field): - escapechar = '\\\\' if escapechar == '\\' else escapechar - if not escapechar: - raise Error('No escapechar is set') - escape_replace = r'{escapechar}\1'.format(escapechar=escapechar) - field = escape_re.sub(escape_replace, field) - - if quoted: - field = '{quotechar}{field}{quotechar}'.format( - quotechar=self.dialect.quotechar, field=field) - 
- return field - - -class QuoteMinimalStrategy(QuoteStrategy): - quoting = QUOTE_MINIMAL - - def setup(self): - self.quoted_re = re.compile(r'[{specialchars}]'.format( - specialchars=re.escape(self.specialchars))) - - @property - def specialchars(self): - return ( - self.dialect.lineterminator + - self.dialect.quotechar + - self.dialect.delimiter + - (self.dialect.escapechar or '') - ) - - def quoted(self, field, only, **kwargs): - if field == self.dialect.quotechar and not self.dialect.doublequote: - # If the only character in the field is the quotechar, and - # doublequote is false, then just escape without outer quotes. - return False - return field == '' and only or bool(self.quoted_re.search(field)) - - -class QuoteAllStrategy(QuoteStrategy): - quoting = QUOTE_ALL - - @property - def specialchars(self): - return self.dialect.quotechar - - def quoted(self, **kwargs): - return True - - -class QuoteNonnumericStrategy(QuoteStrategy): - quoting = QUOTE_NONNUMERIC - - @property - def specialchars(self): - return ( - self.dialect.lineterminator + - self.dialect.quotechar + - self.dialect.delimiter + - (self.dialect.escapechar or '') - ) - - def quoted(self, raw_field, **kwargs): - return not isinstance(raw_field, numbers.Number) - - -class QuoteNoneStrategy(QuoteStrategy): - quoting = QUOTE_NONE - - @property - def specialchars(self): - return ( - self.dialect.lineterminator + - (self.dialect.quotechar or '') + - self.dialect.delimiter + - (self.dialect.escapechar or '') - ) - - def quoted(self, field, only, **kwargs): - if field == '' and only: - raise Error('single empty field record must be quoted') - return False - - -class writer(object): - def __init__(self, fileobj, dialect='excel', **fmtparams): - if fileobj is None: - raise TypeError('fileobj must be file-like, not None') - - self.fileobj = fileobj - - if isinstance(dialect, text_type): - dialect = get_dialect(dialect) - - try: - self.dialect = Dialect.combine(dialect, fmtparams) - except Error as e: - raise TypeError(*e.args) - - strategies = { - QUOTE_MINIMAL: QuoteMinimalStrategy, - QUOTE_ALL: QuoteAllStrategy, - QUOTE_NONNUMERIC: QuoteNonnumericStrategy, - QUOTE_NONE: QuoteNoneStrategy, - } - self.strategy = strategies[self.dialect.quoting](self.dialect) - - def writerow(self, row): - if row is None: - raise Error('row must be an iterable') - - row = list(row) - only = len(row) == 1 - row = [self.strategy.prepare(field, only=only) for field in row] - - line = self.dialect.delimiter.join(row) + self.dialect.lineterminator - return self.fileobj.write(line) - - def writerows(self, rows): - for row in rows: - self.writerow(row) - - -START_RECORD = 0 -START_FIELD = 1 -ESCAPED_CHAR = 2 -IN_FIELD = 3 -IN_QUOTED_FIELD = 4 -ESCAPE_IN_QUOTED_FIELD = 5 -QUOTE_IN_QUOTED_FIELD = 6 -EAT_CRNL = 7 -AFTER_ESCAPED_CRNL = 8 - - -class reader(object): - def __init__(self, fileobj, dialect='excel', **fmtparams): - self.input_iter = iter(fileobj) - - if isinstance(dialect, text_type): - dialect = get_dialect(dialect) - - try: - self.dialect = Dialect.combine(dialect, fmtparams) - except Error as e: - raise TypeError(*e.args) - - self.fields = None - self.field = None - self.line_num = 0 - - def parse_reset(self): - self.fields = [] - self.field = [] - self.state = START_RECORD - self.numeric_field = False - - def parse_save_field(self): - field = ''.join(self.field) - self.field = [] - if self.numeric_field: - field = float(field) - self.numeric_field = False - self.fields.append(field) - - def parse_add_char(self, c): - if len(self.field) >= 
field_size_limit(): - raise Error('field size limit exceeded') - self.field.append(c) - - def parse_process_char(self, c): - switch = { - START_RECORD: self._parse_start_record, - START_FIELD: self._parse_start_field, - ESCAPED_CHAR: self._parse_escaped_char, - AFTER_ESCAPED_CRNL: self._parse_after_escaped_crnl, - IN_FIELD: self._parse_in_field, - IN_QUOTED_FIELD: self._parse_in_quoted_field, - ESCAPE_IN_QUOTED_FIELD: self._parse_escape_in_quoted_field, - QUOTE_IN_QUOTED_FIELD: self._parse_quote_in_quoted_field, - EAT_CRNL: self._parse_eat_crnl, - } - return switch[self.state](c) - - def _parse_start_record(self, c): - if c == '\0': - return - elif c == '\n' or c == '\r': - self.state = EAT_CRNL - return - - self.state = START_FIELD - return self._parse_start_field(c) - - def _parse_start_field(self, c): - if c == '\n' or c == '\r' or c == '\0': - self.parse_save_field() - self.state = START_RECORD if c == '\0' else EAT_CRNL - elif (c == self.dialect.quotechar and - self.dialect.quoting != QUOTE_NONE): - self.state = IN_QUOTED_FIELD - elif c == self.dialect.escapechar: - self.state = ESCAPED_CHAR - elif c == ' ' and self.dialect.skipinitialspace: - pass # Ignore space at start of field - elif c == self.dialect.delimiter: - # Save empty field - self.parse_save_field() - else: - # Begin new unquoted field - if self.dialect.quoting == QUOTE_NONNUMERIC: - self.numeric_field = True - self.parse_add_char(c) - self.state = IN_FIELD - - def _parse_escaped_char(self, c): - if c == '\n' or c == '\r': - self.parse_add_char(c) - self.state = AFTER_ESCAPED_CRNL - return - if c == '\0': - c = '\n' - self.parse_add_char(c) - self.state = IN_FIELD - - def _parse_after_escaped_crnl(self, c): - if c == '\0': - return - return self._parse_in_field(c) - - def _parse_in_field(self, c): - # In unquoted field - if c == '\n' or c == '\r' or c == '\0': - # End of line - return [fields] - self.parse_save_field() - self.state = START_RECORD if c == '\0' else EAT_CRNL - elif c == self.dialect.escapechar: - self.state = ESCAPED_CHAR - elif c == self.dialect.delimiter: - self.parse_save_field() - self.state = START_FIELD - else: - # Normal character - save in field - self.parse_add_char(c) - - def _parse_in_quoted_field(self, c): - if c == '\0': - pass - elif c == self.dialect.escapechar: - self.state = ESCAPE_IN_QUOTED_FIELD - elif (c == self.dialect.quotechar and - self.dialect.quoting != QUOTE_NONE): - if self.dialect.doublequote: - self.state = QUOTE_IN_QUOTED_FIELD - else: - self.state = IN_FIELD - else: - self.parse_add_char(c) - - def _parse_escape_in_quoted_field(self, c): - if c == '\0': - c = '\n' - - self.parse_add_char(c) - self.state = IN_QUOTED_FIELD - - def _parse_quote_in_quoted_field(self, c): - if (self.dialect.quoting != QUOTE_NONE and - c == self.dialect.quotechar): - # save "" as " - self.parse_add_char(c) - self.state = IN_QUOTED_FIELD - elif c == self.dialect.delimiter: - self.parse_save_field() - self.state = START_FIELD - elif c == '\n' or c == '\r' or c == '\0': - # End of line = return [fields] - self.parse_save_field() - self.state = START_RECORD if c == '\0' else EAT_CRNL - elif not self.dialect.strict: - self.parse_add_char(c) - self.state = IN_FIELD - else: - # illegal - raise Error("{delimiter}' expected after '{quotechar}".format( - delimiter=self.dialect.delimiter, - quotechar=self.dialect.quotechar, - )) - - def _parse_eat_crnl(self, c): - if c == '\n' or c == '\r': - pass - elif c == '\0': - self.state = START_RECORD - else: - raise Error('new-line character seen in unquoted field - 
do you ' - 'need to open the file in universal-newline mode?') - - - def __iter__(self): - return self - - def __next__(self): - self.parse_reset() - - while True: - try: - lineobj = next(self.input_iter) - except StopIteration: - if len(self.field) != 0 or self.state == IN_QUOTED_FIELD: - if self.dialect.strict: - raise Error('unexpected end of data') - self.parse_save_field() - if self.fields: - break - raise - - if not isinstance(lineobj, text_type): - typ = type(lineobj) - typ_name = 'bytes' if typ == bytes else typ.__name__ - err_str = ('iterator should return strings, not {0}' - ' (did you open the file in text mode?)') - raise Error(err_str.format(typ_name)) - - self.line_num += 1 - for c in lineobj: - if c == '\0': - raise Error('line contains NULL byte') - self.parse_process_char(c) - - self.parse_process_char('\0') - - if self.state == START_RECORD: - break - - fields = self.fields - self.fields = None - return fields - - next = __next__ - - -_dialect_registry = {} -def register_dialect(name, dialect='excel', **fmtparams): - if not isinstance(name, text_type): - raise TypeError('"name" must be a string') - - dialect = Dialect.extend(dialect, fmtparams) - - try: - Dialect.validate(dialect) - except: - raise TypeError('dialect is invalid') - - assert name not in _dialect_registry - _dialect_registry[name] = dialect - -def unregister_dialect(name): - try: - _dialect_registry.pop(name) - except KeyError: - raise Error('"{name}" not a registered dialect'.format(name=name)) - -def get_dialect(name): - try: - return _dialect_registry[name] - except KeyError: - raise Error('Could not find dialect {0}'.format(name)) - -def list_dialects(): - return list(_dialect_registry) - - -class Dialect(object): - """Describe a CSV dialect. - This must be subclassed (see csv.excel). Valid attributes are: - delimiter, quotechar, escapechar, doublequote, skipinitialspace, - lineterminator, quoting, strict. 
- """ - _name = "" - _valid = False - # placeholders - delimiter = None - quotechar = None - escapechar = None - doublequote = None - skipinitialspace = None - lineterminator = None - quoting = None - strict = None - - def __init__(self): - self.validate(self) - if self.__class__ != Dialect: - self._valid = True - - @classmethod - def validate(cls, dialect): - dialect = cls.extend(dialect) - - if not isinstance(dialect.quoting, int): - raise Error('"quoting" must be an integer') - - if dialect.delimiter is None: - raise Error('delimiter must be set') - cls.validate_text(dialect, 'delimiter') - - if dialect.lineterminator is None: - raise Error('lineterminator must be set') - if not isinstance(dialect.lineterminator, text_type): - raise Error('"lineterminator" must be a string') - - if dialect.quoting not in [ - QUOTE_NONE, QUOTE_MINIMAL, QUOTE_NONNUMERIC, QUOTE_ALL]: - raise Error('Invalid quoting specified') - - if dialect.quoting != QUOTE_NONE: - if dialect.quotechar is None and dialect.escapechar is None: - raise Error('quotechar must be set if quoting enabled') - if dialect.quotechar is not None: - cls.validate_text(dialect, 'quotechar') - - @staticmethod - def validate_text(dialect, attr): - val = getattr(dialect, attr) - if not isinstance(val, text_type): - if type(val) == bytes: - raise Error('"{0}" must be string, not bytes'.format(attr)) - raise Error('"{0}" must be string, not {1}'.format( - attr, type(val).__name__)) - - if len(val) != 1: - raise Error('"{0}" must be a 1-character string'.format(attr)) - - @staticmethod - def defaults(): - return { - 'delimiter': ',', - 'doublequote': True, - 'escapechar': None, - 'lineterminator': '\r\n', - 'quotechar': '"', - 'quoting': QUOTE_MINIMAL, - 'skipinitialspace': False, - 'strict': False, - } - - @classmethod - def extend(cls, dialect, fmtparams=None): - if isinstance(dialect, string_types): - dialect = get_dialect(dialect) - - if fmtparams is None: - return dialect - - defaults = cls.defaults() - - if any(param not in defaults for param in fmtparams): - raise TypeError('Invalid fmtparam') - - specified = dict( - (attr, getattr(dialect, attr, None)) - for attr in cls.defaults() - ) - - specified.update(fmtparams) - return type(str('ExtendedDialect'), (cls,), specified) - - @classmethod - def combine(cls, dialect, fmtparams): - """Create a new dialect with defaults and added parameters.""" - dialect = cls.extend(dialect, fmtparams) - defaults = cls.defaults() - specified = dict( - (attr, getattr(dialect, attr, None)) - for attr in defaults - if getattr(dialect, attr, None) is not None or - attr in ['quotechar', 'delimiter', 'lineterminator', 'quoting'] - ) - - defaults.update(specified) - dialect = type(str('CombinedDialect'), (cls,), defaults) - cls.validate(dialect) - return dialect() - - def __delattr__(self, attr): - if self._valid: - raise AttributeError('dialect is immutable.') - super(Dialect, self).__delattr__(attr) - - def __setattr__(self, attr, value): - if self._valid: - raise AttributeError('dialect is immutable.') - super(Dialect, self).__setattr__(attr, value) - - -class excel(Dialect): - """Describe the usual properties of Excel-generated CSV files.""" - delimiter = ',' - quotechar = '"' - doublequote = True - skipinitialspace = False - lineterminator = '\r\n' - quoting = QUOTE_MINIMAL -register_dialect("excel", excel) - -class excel_tab(excel): - """Describe the usual properties of Excel-generated TAB-delimited files.""" - delimiter = '\t' -register_dialect("excel-tab", excel_tab) - -class unix_dialect(Dialect): - 
"""Describe the usual properties of Unix-generated CSV files.""" - delimiter = ',' - quotechar = '"' - doublequote = True - skipinitialspace = False - lineterminator = '\n' - quoting = QUOTE_ALL -register_dialect("unix", unix_dialect) - - -class DictReader(object): - def __init__(self, f, fieldnames=None, restkey=None, restval=None, - dialect="excel", *args, **kwds): - self._fieldnames = fieldnames # list of keys for the dict - self.restkey = restkey # key to catch long rows - self.restval = restval # default value for short rows - self.reader = reader(f, dialect, *args, **kwds) - self.dialect = dialect - self.line_num = 0 - - def __iter__(self): - return self - - @property - def fieldnames(self): - if self._fieldnames is None: - try: - self._fieldnames = next(self.reader) - except StopIteration: - pass - self.line_num = self.reader.line_num - return self._fieldnames - - @fieldnames.setter - def fieldnames(self, value): - self._fieldnames = value - - def __next__(self): - if self.line_num == 0: - # Used only for its side effect. - self.fieldnames - row = next(self.reader) - self.line_num = self.reader.line_num - - # unlike the basic reader, we prefer not to return blanks, - # because we will typically wind up with a dict full of None - # values - while row == []: - row = next(self.reader) - d = dict(zip(self.fieldnames, row)) - lf = len(self.fieldnames) - lr = len(row) - if lf < lr: - d[self.restkey] = row[lf:] - elif lf > lr: - for key in self.fieldnames[lr:]: - d[key] = self.restval - return d - - next = __next__ - - -class DictWriter(object): - def __init__(self, f, fieldnames, restval="", extrasaction="raise", - dialect="excel", *args, **kwds): - self.fieldnames = fieldnames # list of keys for the dict - self.restval = restval # for writing short dicts - if extrasaction.lower() not in ("raise", "ignore"): - raise ValueError("extrasaction (%s) must be 'raise' or 'ignore'" - % extrasaction) - self.extrasaction = extrasaction - self.writer = writer(f, dialect, *args, **kwds) - - def writeheader(self): - header = dict(zip(self.fieldnames, self.fieldnames)) - self.writerow(header) - - def _dict_to_list(self, rowdict): - if self.extrasaction == "raise": - wrong_fields = [k for k in rowdict if k not in self.fieldnames] - if wrong_fields: - raise ValueError("dict contains fields not in fieldnames: " - + ", ".join([repr(x) for x in wrong_fields])) - return (rowdict.get(key, self.restval) for key in self.fieldnames) - - def writerow(self, rowdict): - return self.writer.writerow(self._dict_to_list(rowdict)) - - def writerows(self, rowdicts): - return self.writer.writerows(map(self._dict_to_list, rowdicts)) - -# Guard Sniffer's type checking against builds that exclude complex() -try: - complex -except NameError: - complex = float - -class Sniffer(object): - ''' - "Sniffs" the format of a CSV file (i.e. delimiter, quotechar) - Returns a Dialect object. 
- ''' - def __init__(self): - # in case there is more than one possible delimiter - self.preferred = [',', '\t', ';', ' ', ':'] - - - def sniff(self, sample, delimiters=None): - """ - Returns a dialect (or None) corresponding to the sample - """ - - quotechar, doublequote, delimiter, skipinitialspace = \ - self._guess_quote_and_delimiter(sample, delimiters) - if not delimiter: - delimiter, skipinitialspace = self._guess_delimiter(sample, - delimiters) - - if not delimiter: - raise Error("Could not determine delimiter") - - class dialect(Dialect): - _name = "sniffed" - lineterminator = '\r\n' - quoting = QUOTE_MINIMAL - # escapechar = '' - - dialect.doublequote = doublequote - dialect.delimiter = delimiter - # _csv.reader won't accept a quotechar of '' - dialect.quotechar = quotechar or '"' - dialect.skipinitialspace = skipinitialspace - - return dialect - - - def _guess_quote_and_delimiter(self, data, delimiters): - """ - Looks for text enclosed between two identical quotes - (the probable quotechar) which are preceded and followed - by the same character (the probable delimiter). - For example: - ,'some text', - The quote with the most wins, same with the delimiter. - If there is no quotechar the delimiter can't be determined - this way. - """ - - matches = [] - for restr in ('(?P[^\w\n"\'])(?P ?)(?P["\']).*?(?P=quote)(?P=delim)', # ,".*?", - '(?:^|\n)(?P["\']).*?(?P=quote)(?P[^\w\n"\'])(?P ?)', # ".*?", - '(?P>[^\w\n"\'])(?P ?)(?P["\']).*?(?P=quote)(?:$|\n)', # ,".*?" - '(?:^|\n)(?P["\']).*?(?P=quote)(?:$|\n)'): # ".*?" (no delim, no space) - regexp = re.compile(restr, re.DOTALL | re.MULTILINE) - matches = regexp.findall(data) - if matches: - break - - if not matches: - # (quotechar, doublequote, delimiter, skipinitialspace) - return ('', False, None, 0) - quotes = {} - delims = {} - spaces = 0 - groupindex = regexp.groupindex - for m in matches: - n = groupindex['quote'] - 1 - key = m[n] - if key: - quotes[key] = quotes.get(key, 0) + 1 - try: - n = groupindex['delim'] - 1 - key = m[n] - except KeyError: - continue - if key and (delimiters is None or key in delimiters): - delims[key] = delims.get(key, 0) + 1 - try: - n = groupindex['space'] - 1 - except KeyError: - continue - if m[n]: - spaces += 1 - - quotechar = max(quotes, key=quotes.get) - - if delims: - delim = max(delims, key=delims.get) - skipinitialspace = delims[delim] == spaces - if delim == '\n': # most likely a file with a single column - delim = '' - else: - # there is *no* delimiter, it's a single column of quoted data - delim = '' - skipinitialspace = 0 - - # if we see an extra quote between delimiters, we've got a - # double quoted format - dq_regexp = re.compile( - r"((%(delim)s)|^)\W*%(quote)s[^%(delim)s\n]*%(quote)s[^%(delim)s\n]*%(quote)s\W*((%(delim)s)|$)" % \ - {'delim':re.escape(delim), 'quote':quotechar}, re.MULTILINE) - - - - if dq_regexp.search(data): - doublequote = True - else: - doublequote = False - - return (quotechar, doublequote, delim, skipinitialspace) - - - def _guess_delimiter(self, data, delimiters): - """ - The delimiter /should/ occur the same number of times on - each row. However, due to malformed data, it may not. We don't want - an all or nothing approach, so we allow for small variations in this - number. - 1) build a table of the frequency of each character on every line. - 2) build a table of frequencies of this frequency (meta-frequency?), - e.g. 
'x occurred 5 times in 10 rows, 6 times in 1000 rows, - 7 times in 2 rows' - 3) use the mode of the meta-frequency to determine the /expected/ - frequency for that character - 4) find out how often the character actually meets that goal - 5) the character that best meets its goal is the delimiter - For performance reasons, the data is evaluated in chunks, so it can - try and evaluate the smallest portion of the data possible, evaluating - additional chunks as necessary. - """ - - data = list(filter(None, data.split('\n'))) - - ascii = [unichr(c) for c in range(127)] # 7-bit ASCII - - # build frequency tables - chunkLength = min(10, len(data)) - iteration = 0 - charFrequency = {} - modes = {} - delims = {} - start, end = 0, min(chunkLength, len(data)) - while start < len(data): - iteration += 1 - for line in data[start:end]: - for char in ascii: - metaFrequency = charFrequency.get(char, {}) - # must count even if frequency is 0 - freq = line.count(char) - # value is the mode - metaFrequency[freq] = metaFrequency.get(freq, 0) + 1 - charFrequency[char] = metaFrequency - - for char in charFrequency.keys(): - items = list(charFrequency[char].items()) - if len(items) == 1 and items[0][0] == 0: - continue - # get the mode of the frequencies - if len(items) > 1: - modes[char] = max(items, key=lambda x: x[1]) - # adjust the mode - subtract the sum of all - # other frequencies - items.remove(modes[char]) - modes[char] = (modes[char][0], modes[char][1] - - sum(item[1] for item in items)) - else: - modes[char] = items[0] - - # build a list of possible delimiters - modeList = modes.items() - total = float(chunkLength * iteration) - # (rows of consistent data) / (number of rows) = 100% - consistency = 1.0 - # minimum consistency threshold - threshold = 0.9 - while len(delims) == 0 and consistency >= threshold: - for k, v in modeList: - if v[0] > 0 and v[1] > 0: - if ((v[1]/total) >= consistency and - (delimiters is None or k in delimiters)): - delims[k] = v - consistency -= 0.01 - - if len(delims) == 1: - delim = list(delims.keys())[0] - skipinitialspace = (data[0].count(delim) == - data[0].count("%c " % delim)) - return (delim, skipinitialspace) - - # analyze another chunkLength lines - start = end - end += chunkLength - - if not delims: - return ('', 0) - - # if there's more than one, fall back to a 'preferred' list - if len(delims) > 1: - for d in self.preferred: - if d in delims.keys(): - skipinitialspace = (data[0].count(d) == - data[0].count("%c " % d)) - return (d, skipinitialspace) - - # nothing else indicates a preference, pick the character that - # dominates(?) - items = [(v,k) for (k,v) in delims.items()] - items.sort() - delim = items[-1][1] - - skipinitialspace = (data[0].count(delim) == - data[0].count("%c " % delim)) - return (delim, skipinitialspace) - - - def has_header(self, sample): - # Creates a dictionary of types of data in each column. If any - # column is of a single type (say, integers), *except* for the first - # row, then the first row is presumed to be labels. If the type - # can't be determined, it is assumed to be a string in which case - # the length of the string is the determining factor: if all of the - # rows except for the first are the same length, it's a header. - # Finally, a 'vote' is taken at the end for each column, adding or - # subtracting from the likelihood of the first row being a header. 
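        # Illustrative sketch (not from the original source): for a sample
        # like "name,age\nalice,30\nbrian,25\n", the age column is
        # consistently int-typed and the name column has a consistent
        # string length of 5; "age" fails the int() cast and "name" has
        # length 4, so both columns vote +1 and the first row is judged
        # to be a header.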
- - rdr = reader(StringIO(sample), self.sniff(sample)) - - header = next(rdr) # assume first row is header - - columns = len(header) - columnTypes = {} - for i in range(columns): columnTypes[i] = None - - checked = 0 - for row in rdr: - # arbitrary number of rows to check, to keep it sane - if checked > 20: - break - checked += 1 - - if len(row) != columns: - continue # skip rows that have irregular number of columns - - for col in list(columnTypes.keys()): - - for thisType in [int, float, complex]: - try: - thisType(row[col]) - break - except (ValueError, OverflowError): - pass - else: - # fallback to length of string - thisType = len(row[col]) - - if thisType != columnTypes[col]: - if columnTypes[col] is None: # add new column type - columnTypes[col] = thisType - else: - # type is inconsistent, remove column from - # consideration - del columnTypes[col] - - # finally, compare results against first row and "vote" - # on whether it's a header - hasHeader = 0 - for col, colType in columnTypes.items(): - if type(colType) == type(0): # it's a length - if len(header[col]) != colType: - hasHeader += 1 - else: - hasHeader -= 1 - else: # attempt typecast - try: - colType(header[col]) - except (ValueError, TypeError): - hasHeader += 1 - else: - hasHeader -= 1 - - return hasHeader > 0 diff --git a/lib/backports/functools_lru_cache.py b/lib/backports/functools_lru_cache.py deleted file mode 100644 index e372cff3..00000000 --- a/lib/backports/functools_lru_cache.py +++ /dev/null @@ -1,243 +0,0 @@ -from __future__ import absolute_import - -import functools -from collections import namedtuple -from threading import RLock - -_CacheInfo = namedtuple("_CacheInfo", ["hits", "misses", "maxsize", "currsize"]) - - -@functools.wraps(functools.update_wrapper) -def update_wrapper( - wrapper, - wrapped, - assigned=functools.WRAPPER_ASSIGNMENTS, - updated=functools.WRAPPER_UPDATES, -): - """ - Patch two bugs in functools.update_wrapper. - """ - # workaround for http://bugs.python.org/issue3445 - assigned = tuple(attr for attr in assigned if hasattr(wrapped, attr)) - wrapper = functools.update_wrapper(wrapper, wrapped, assigned, updated) - # workaround for https://bugs.python.org/issue17482 - wrapper.__wrapped__ = wrapped - return wrapper - - -class _HashedSeq(list): - """This class guarantees that hash() will be called no more than once - per element. This is important because the lru_cache() will hash - the key multiple times on a cache miss. - - """ - - __slots__ = 'hashvalue' - - def __init__(self, tup, hash=hash): - self[:] = tup - self.hashvalue = hash(tup) - - def __hash__(self): - return self.hashvalue - - -def _make_key( - args, - kwds, - typed, - kwd_mark=(object(),), - fasttypes={int, str}, - tuple=tuple, - type=type, - len=len, -): - """Make a cache key from optionally typed positional and keyword arguments - - The key is constructed in a way that is flat as possible rather than - as a nested structure that would take more memory. - - If there is only a single argument and its data type is known to cache - its hash value, then that argument is returned without a wrapper. This - saves space and improves lookup speed. - - """ - # All of code below relies on kwds preserving the order input by the user. - # Formerly, we sorted() the kwds before looping. The new way is *much* - # faster; however, it means that f(x=1, y=2) will now be treated as a - # distinct call from f(y=2, x=1) which will be cached separately. 
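    # Illustrative sketch (not from the original source): with typed=False,
    # a call f(1, 2, x=3) yields the flat key
    #     (1, 2, kwd_mark[0], 'x', 3)
    # wrapped in _HashedSeq so hash() runs only once; with typed=True the
    # argument types (int, int, int) are appended to the key as well.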
- key = args - if kwds: - key += kwd_mark - for item in kwds.items(): - key += item - if typed: - key += tuple(type(v) for v in args) - if kwds: - key += tuple(type(v) for v in kwds.values()) - elif len(key) == 1 and type(key[0]) in fasttypes: - return key[0] - return _HashedSeq(key) - - -def lru_cache(maxsize=128, typed=False): - """Least-recently-used cache decorator. - - If *maxsize* is set to None, the LRU features are disabled and the cache - can grow without bound. - - If *typed* is True, arguments of different types will be cached separately. - For example, f(decimal.Decimal("3.0")) and f(3.0) will be treated as - distinct calls with distinct results. Some types such as str and int may - be cached separately even when typed is false. - - Arguments to the cached function must be hashable. - - View the cache statistics named tuple (hits, misses, maxsize, currsize) - with f.cache_info(). Clear the cache and statistics with f.cache_clear(). - Access the underlying function with f.__wrapped__. - - See: https://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU) - - """ - - # Users should only access the lru_cache through its public API: - # cache_info, cache_clear, and f.__wrapped__ - # The internals of the lru_cache are encapsulated for thread safety and - # to allow the implementation to change (including a possible C version). - - if isinstance(maxsize, int): - # Negative maxsize is treated as 0 - if maxsize < 0: - maxsize = 0 - elif callable(maxsize) and isinstance(typed, bool): - # The user_function was passed in directly via the maxsize argument - user_function, maxsize = maxsize, 128 - wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo) - wrapper.cache_parameters = lambda: {'maxsize': maxsize, 'typed': typed} - return update_wrapper(wrapper, user_function) - elif maxsize is not None: - raise TypeError('Expected first argument to be an integer, a callable, or None') - - def decorating_function(user_function): - wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo) - wrapper.cache_parameters = lambda: {'maxsize': maxsize, 'typed': typed} - return update_wrapper(wrapper, user_function) - - return decorating_function - - -def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo): - # Constants shared by all lru cache instances: - sentinel = object() # unique object used to signal cache misses - make_key = _make_key # build a key from the function arguments - PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields - - cache = {} - hits = misses = 0 - full = False - cache_get = cache.get # bound method to lookup a key or return None - cache_len = cache.__len__ # get cache size without calling len() - lock = RLock() # because linkedlist updates aren't threadsafe - root = [] # root of the circular doubly linked list - root[:] = [root, root, None, None] # initialize by pointing to self - - if maxsize == 0: - - def wrapper(*args, **kwds): - # No caching -- just a statistics update - nonlocal misses - misses += 1 - result = user_function(*args, **kwds) - return result - - elif maxsize is None: - - def wrapper(*args, **kwds): - # Simple caching without ordering or size limit - nonlocal hits, misses - key = make_key(args, kwds, typed) - result = cache_get(key, sentinel) - if result is not sentinel: - hits += 1 - return result - misses += 1 - result = user_function(*args, **kwds) - cache[key] = result - return result - - else: - - def wrapper(*args, **kwds): - # Size limited caching that tracks accesses by recency - 
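            # Illustrative note (not from the original source): each cache
            # entry is a list [PREV, NEXT, KEY, RESULT] spliced into a
            # circular doubly linked list anchored at `root`; the most
            # recently used entry sits at root[PREV] and the oldest at
            # root[NEXT]. A hit re-splices its link to just before root;
            # when the cache is full, the old root slot stores the new
            # entry and root moves to the oldest link, whose key is evicted.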
nonlocal root, hits, misses, full - key = make_key(args, kwds, typed) - with lock: - link = cache_get(key) - if link is not None: - # Move the link to the front of the circular queue - link_prev, link_next, _key, result = link - link_prev[NEXT] = link_next - link_next[PREV] = link_prev - last = root[PREV] - last[NEXT] = root[PREV] = link - link[PREV] = last - link[NEXT] = root - hits += 1 - return result - misses += 1 - result = user_function(*args, **kwds) - with lock: - if key in cache: - # Getting here means that this same key was added to the - # cache while the lock was released. Since the link - # update is already done, we need only return the - # computed result and update the count of misses. - pass - elif full: - # Use the old root to store the new key and result. - oldroot = root - oldroot[KEY] = key - oldroot[RESULT] = result - # Empty the oldest link and make it the new root. - # Keep a reference to the old key and old result to - # prevent their ref counts from going to zero during the - # update. That will prevent potentially arbitrary object - # clean-up code (i.e. __del__) from running while we're - # still adjusting the links. - root = oldroot[NEXT] - oldkey = root[KEY] - root[KEY] = root[RESULT] = None - # Now update the cache dictionary. - del cache[oldkey] - # Save the potentially reentrant cache[key] assignment - # for last, after the root and links have been put in - # a consistent state. - cache[key] = oldroot - else: - # Put result in a new link at the front of the queue. - last = root[PREV] - link = [last, root, key, result] - last[NEXT] = root[PREV] = cache[key] = link - # Use the cache_len bound method instead of the len() function - # which could potentially be wrapped in an lru_cache itself. - full = cache_len() >= maxsize - return result - - def cache_info(): - """Report cache statistics""" - with lock: - return _CacheInfo(hits, misses, maxsize, cache_len()) - - def cache_clear(): - """Clear the cache and cache statistics""" - nonlocal hits, misses, full - with lock: - cache.clear() - root[:] = [root, root, None, None] - hits = misses = 0 - full = False - - wrapper.cache_info = cache_info - wrapper.cache_clear = cache_clear - return wrapper diff --git a/lib/backports/tarfile/__init__.py b/lib/backports/tarfile/__init__.py new file mode 100644 index 00000000..6dd498dc --- /dev/null +++ b/lib/backports/tarfile/__init__.py @@ -0,0 +1,2902 @@ +#!/usr/bin/env python3 +#------------------------------------------------------------------- +# tarfile.py +#------------------------------------------------------------------- +# Copyright (C) 2002 Lars Gustaebel +# All rights reserved. +# +# Permission is hereby granted, free of charge, to any person +# obtaining a copy of this software and associated documentation +# files (the "Software"), to deal in the Software without +# restriction, including without limitation the rights to use, +# copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the +# Software is furnished to do so, subject to the following +# conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +# NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +# OTHER DEALINGS IN THE SOFTWARE. +# +"""Read from and write to tar format archives. +""" + +version = "0.9.0" +__author__ = "Lars Gust\u00e4bel (lars@gustaebel.de)" +__credits__ = "Gustavo Niemeyer, Niels Gust\u00e4bel, Richard Townsend." + +#--------- +# Imports +#--------- +from builtins import open as bltn_open +import sys +import os +import io +import shutil +import stat +import time +import struct +import copy +import re +import warnings + +from .compat.py38 import removesuffix + +try: + import pwd +except ImportError: + pwd = None +try: + import grp +except ImportError: + grp = None + +# os.symlink on Windows prior to 6.0 raises NotImplementedError +# OSError (winerror=1314) will be raised if the caller does not hold the +# SeCreateSymbolicLinkPrivilege privilege +symlink_exception = (AttributeError, NotImplementedError, OSError) + +# from tarfile import * +__all__ = ["TarFile", "TarInfo", "is_tarfile", "TarError", "ReadError", + "CompressionError", "StreamError", "ExtractError", "HeaderError", + "ENCODING", "USTAR_FORMAT", "GNU_FORMAT", "PAX_FORMAT", + "DEFAULT_FORMAT", "open","fully_trusted_filter", "data_filter", + "tar_filter", "FilterError", "AbsoluteLinkError", + "OutsideDestinationError", "SpecialFileError", "AbsolutePathError", + "LinkOutsideDestinationError"] + + +#--------------------------------------------------------- +# tar constants +#--------------------------------------------------------- +NUL = b"\0" # the null character +BLOCKSIZE = 512 # length of processing blocks +RECORDSIZE = BLOCKSIZE * 20 # length of records +GNU_MAGIC = b"ustar \0" # magic gnu tar string +POSIX_MAGIC = b"ustar\x0000" # magic posix tar string + +LENGTH_NAME = 100 # maximum length of a filename +LENGTH_LINK = 100 # maximum length of a linkname +LENGTH_PREFIX = 155 # maximum length of the prefix field + +REGTYPE = b"0" # regular file +AREGTYPE = b"\0" # regular file +LNKTYPE = b"1" # link (inside tarfile) +SYMTYPE = b"2" # symbolic link +CHRTYPE = b"3" # character special device +BLKTYPE = b"4" # block special device +DIRTYPE = b"5" # directory +FIFOTYPE = b"6" # fifo special device +CONTTYPE = b"7" # contiguous file + +GNUTYPE_LONGNAME = b"L" # GNU tar longname +GNUTYPE_LONGLINK = b"K" # GNU tar longlink +GNUTYPE_SPARSE = b"S" # GNU tar sparse file + +XHDTYPE = b"x" # POSIX.1-2001 extended header +XGLTYPE = b"g" # POSIX.1-2001 global header +SOLARIS_XHDTYPE = b"X" # Solaris extended header + +USTAR_FORMAT = 0 # POSIX.1-1988 (ustar) format +GNU_FORMAT = 1 # GNU tar format +PAX_FORMAT = 2 # POSIX.1-2001 (pax) format +DEFAULT_FORMAT = PAX_FORMAT + +#--------------------------------------------------------- +# tarfile constants +#--------------------------------------------------------- +# File types that tarfile supports: +SUPPORTED_TYPES = (REGTYPE, AREGTYPE, LNKTYPE, + SYMTYPE, DIRTYPE, FIFOTYPE, + CONTTYPE, CHRTYPE, BLKTYPE, + GNUTYPE_LONGNAME, GNUTYPE_LONGLINK, + GNUTYPE_SPARSE) + +# File types that will be treated as a regular file. +REGULAR_TYPES = (REGTYPE, AREGTYPE, + CONTTYPE, GNUTYPE_SPARSE) + +# File types that are part of the GNU tar format. +GNU_TYPES = (GNUTYPE_LONGNAME, GNUTYPE_LONGLINK, + GNUTYPE_SPARSE) + +# Fields from a pax header that override a TarInfo attribute. 
+PAX_FIELDS = ("path", "linkpath", "size", "mtime", + "uid", "gid", "uname", "gname") + +# Fields from a pax header that are affected by hdrcharset. +PAX_NAME_FIELDS = {"path", "linkpath", "uname", "gname"} + +# Fields in a pax header that are numbers, all other fields +# are treated as strings. +PAX_NUMBER_FIELDS = { + "atime": float, + "ctime": float, + "mtime": float, + "uid": int, + "gid": int, + "size": int +} + +#--------------------------------------------------------- +# initialization +#--------------------------------------------------------- +if os.name == "nt": + ENCODING = "utf-8" +else: + ENCODING = sys.getfilesystemencoding() + +#--------------------------------------------------------- +# Some useful functions +#--------------------------------------------------------- + +def stn(s, length, encoding, errors): + """Convert a string to a null-terminated bytes object. + """ + if s is None: + raise ValueError("metadata cannot contain None") + s = s.encode(encoding, errors) + return s[:length] + (length - len(s)) * NUL + +def nts(s, encoding, errors): + """Convert a null-terminated bytes object to a string. + """ + p = s.find(b"\0") + if p != -1: + s = s[:p] + return s.decode(encoding, errors) + +def nti(s): + """Convert a number field to a python number. + """ + # There are two possible encodings for a number field, see + # itn() below. + if s[0] in (0o200, 0o377): + n = 0 + for i in range(len(s) - 1): + n <<= 8 + n += s[i + 1] + if s[0] == 0o377: + n = -(256 ** (len(s) - 1) - n) + else: + try: + s = nts(s, "ascii", "strict") + n = int(s.strip() or "0", 8) + except ValueError: + raise InvalidHeaderError("invalid header") + return n + +def itn(n, digits=8, format=DEFAULT_FORMAT): + """Convert a python number to a number field. + """ + # POSIX 1003.1-1988 requires numbers to be encoded as a string of + # octal digits followed by a null-byte, this allows values up to + # (8**(digits-1))-1. GNU tar allows storing numbers greater than + # that if necessary. A leading 0o200 or 0o377 byte indicate this + # particular encoding, the following digits-1 bytes are a big-endian + # base-256 representation. This allows values up to (256**(digits-1))-1. + # A 0o200 byte indicates a positive number, a 0o377 byte a negative + # number. + original_n = n + n = int(n) + if 0 <= n < 8 ** (digits - 1): + s = bytes("%0*o" % (digits - 1, n), "ascii") + NUL + elif format == GNU_FORMAT and -256 ** (digits - 1) <= n < 256 ** (digits - 1): + if n >= 0: + s = bytearray([0o200]) + else: + s = bytearray([0o377]) + n = 256 ** digits + n + + for i in range(digits - 1): + s.insert(1, n & 0o377) + n >>= 8 + else: + raise ValueError("overflow in number field") + + return s + +def calc_chksums(buf): + """Calculate the checksum for a member's header by summing up all + characters except for the chksum field which is treated as if + it was filled with spaces. According to the GNU tar sources, + some tars (Sun and NeXT) calculate chksum with signed char, + which will be different if there are chars in the buffer with + the high bit set. So we calculate two checksums, unsigned and + signed. + """ + unsigned_chksum = 256 + sum(struct.unpack_from("148B8x356B", buf)) + signed_chksum = 256 + sum(struct.unpack_from("148b8x356b", buf)) + return unsigned_chksum, signed_chksum + +def copyfileobj(src, dst, length=None, exception=OSError, bufsize=None): + """Copy length bytes from fileobj src to fileobj dst. + If length is None, copy the entire content. 
+ """ + bufsize = bufsize or 16 * 1024 + if length == 0: + return + if length is None: + shutil.copyfileobj(src, dst, bufsize) + return + + blocks, remainder = divmod(length, bufsize) + for b in range(blocks): + buf = src.read(bufsize) + if len(buf) < bufsize: + raise exception("unexpected end of data") + dst.write(buf) + + if remainder != 0: + buf = src.read(remainder) + if len(buf) < remainder: + raise exception("unexpected end of data") + dst.write(buf) + return + +def _safe_print(s): + encoding = getattr(sys.stdout, 'encoding', None) + if encoding is not None: + s = s.encode(encoding, 'backslashreplace').decode(encoding) + print(s, end=' ') + + +class TarError(Exception): + """Base exception.""" + pass +class ExtractError(TarError): + """General exception for extract errors.""" + pass +class ReadError(TarError): + """Exception for unreadable tar archives.""" + pass +class CompressionError(TarError): + """Exception for unavailable compression methods.""" + pass +class StreamError(TarError): + """Exception for unsupported operations on stream-like TarFiles.""" + pass +class HeaderError(TarError): + """Base exception for header errors.""" + pass +class EmptyHeaderError(HeaderError): + """Exception for empty headers.""" + pass +class TruncatedHeaderError(HeaderError): + """Exception for truncated headers.""" + pass +class EOFHeaderError(HeaderError): + """Exception for end of file headers.""" + pass +class InvalidHeaderError(HeaderError): + """Exception for invalid headers.""" + pass +class SubsequentHeaderError(HeaderError): + """Exception for missing and invalid extended headers.""" + pass + +#--------------------------- +# internal stream interface +#--------------------------- +class _LowLevelFile: + """Low-level file object. Supports reading and writing. + It is used instead of a regular file object for streaming + access. + """ + + def __init__(self, name, mode): + mode = { + "r": os.O_RDONLY, + "w": os.O_WRONLY | os.O_CREAT | os.O_TRUNC, + }[mode] + if hasattr(os, "O_BINARY"): + mode |= os.O_BINARY + self.fd = os.open(name, mode, 0o666) + + def close(self): + os.close(self.fd) + + def read(self, size): + return os.read(self.fd, size) + + def write(self, s): + os.write(self.fd, s) + +class _Stream: + """Class that serves as an adapter between TarFile and + a stream-like object. The stream-like object only + needs to have a read() or write() method that works with bytes, + and the method is accessed blockwise. + Use of gzip or bzip2 compression is possible. + A stream-like object could be for example: sys.stdin.buffer, + sys.stdout.buffer, a socket, a tape device etc. + + _Stream is intended to be used only internally. + """ + + def __init__(self, name, mode, comptype, fileobj, bufsize, + compresslevel): + """Construct a _Stream object. 
+ """ + self._extfileobj = True + if fileobj is None: + fileobj = _LowLevelFile(name, mode) + self._extfileobj = False + + if comptype == '*': + # Enable transparent compression detection for the + # stream interface + fileobj = _StreamProxy(fileobj) + comptype = fileobj.getcomptype() + + self.name = name or "" + self.mode = mode + self.comptype = comptype + self.fileobj = fileobj + self.bufsize = bufsize + self.buf = b"" + self.pos = 0 + self.closed = False + + try: + if comptype == "gz": + try: + import zlib + except ImportError: + raise CompressionError("zlib module is not available") from None + self.zlib = zlib + self.crc = zlib.crc32(b"") + if mode == "r": + self.exception = zlib.error + self._init_read_gz() + else: + self._init_write_gz(compresslevel) + + elif comptype == "bz2": + try: + import bz2 + except ImportError: + raise CompressionError("bz2 module is not available") from None + if mode == "r": + self.dbuf = b"" + self.cmp = bz2.BZ2Decompressor() + self.exception = OSError + else: + self.cmp = bz2.BZ2Compressor(compresslevel) + + elif comptype == "xz": + try: + import lzma + except ImportError: + raise CompressionError("lzma module is not available") from None + if mode == "r": + self.dbuf = b"" + self.cmp = lzma.LZMADecompressor() + self.exception = lzma.LZMAError + else: + self.cmp = lzma.LZMACompressor() + + elif comptype != "tar": + raise CompressionError("unknown compression type %r" % comptype) + + except: + if not self._extfileobj: + self.fileobj.close() + self.closed = True + raise + + def __del__(self): + if hasattr(self, "closed") and not self.closed: + self.close() + + def _init_write_gz(self, compresslevel): + """Initialize for writing with gzip compression. + """ + self.cmp = self.zlib.compressobj(compresslevel, + self.zlib.DEFLATED, + -self.zlib.MAX_WBITS, + self.zlib.DEF_MEM_LEVEL, + 0) + timestamp = struct.pack(" self.bufsize: + self.fileobj.write(self.buf[:self.bufsize]) + self.buf = self.buf[self.bufsize:] + + def close(self): + """Close the _Stream object. No operation should be + done on it afterwards. + """ + if self.closed: + return + + self.closed = True + try: + if self.mode == "w" and self.comptype != "tar": + self.buf += self.cmp.flush() + + if self.mode == "w" and self.buf: + self.fileobj.write(self.buf) + self.buf = b"" + if self.comptype == "gz": + self.fileobj.write(struct.pack("= 0: + blocks, remainder = divmod(pos - self.pos, self.bufsize) + for i in range(blocks): + self.read(self.bufsize) + self.read(remainder) + else: + raise StreamError("seeking backwards is not allowed") + return self.pos + + def read(self, size): + """Return the next size number of bytes from the stream.""" + assert size is not None + buf = self._read(size) + self.pos += len(buf) + return buf + + def _read(self, size): + """Return size bytes from the stream. + """ + if self.comptype == "tar": + return self.__read(size) + + c = len(self.dbuf) + t = [self.dbuf] + while c < size: + # Skip underlying buffer to avoid unaligned double buffering. + if self.buf: + buf = self.buf + self.buf = b"" + else: + buf = self.fileobj.read(self.bufsize) + if not buf: + break + try: + buf = self.cmp.decompress(buf) + except self.exception as e: + raise ReadError("invalid compressed data") from e + t.append(buf) + c += len(buf) + t = b"".join(t) + self.dbuf = t[size:] + return t[:size] + + def __read(self, size): + """Return size bytes from stream. If internal buffer is empty, + read another block from the stream. 
+ """ + c = len(self.buf) + t = [self.buf] + while c < size: + buf = self.fileobj.read(self.bufsize) + if not buf: + break + t.append(buf) + c += len(buf) + t = b"".join(t) + self.buf = t[size:] + return t[:size] +# class _Stream + +class _StreamProxy(object): + """Small proxy class that enables transparent compression + detection for the Stream interface (mode 'r|*'). + """ + + def __init__(self, fileobj): + self.fileobj = fileobj + self.buf = self.fileobj.read(BLOCKSIZE) + + def read(self, size): + self.read = self.fileobj.read + return self.buf + + def getcomptype(self): + if self.buf.startswith(b"\x1f\x8b\x08"): + return "gz" + elif self.buf[0:3] == b"BZh" and self.buf[4:10] == b"1AY&SY": + return "bz2" + elif self.buf.startswith((b"\x5d\x00\x00\x80", b"\xfd7zXZ")): + return "xz" + else: + return "tar" + + def close(self): + self.fileobj.close() +# class StreamProxy + +#------------------------ +# Extraction file object +#------------------------ +class _FileInFile(object): + """A thin wrapper around an existing file object that + provides a part of its data as an individual file + object. + """ + + def __init__(self, fileobj, offset, size, name, blockinfo=None): + self.fileobj = fileobj + self.offset = offset + self.size = size + self.position = 0 + self.name = name + self.closed = False + + if blockinfo is None: + blockinfo = [(0, size)] + + # Construct a map with data and zero blocks. + self.map_index = 0 + self.map = [] + lastpos = 0 + realpos = self.offset + for offset, size in blockinfo: + if offset > lastpos: + self.map.append((False, lastpos, offset, None)) + self.map.append((True, offset, offset + size, realpos)) + realpos += size + lastpos = offset + size + if lastpos < self.size: + self.map.append((False, lastpos, self.size, None)) + + def flush(self): + pass + + def readable(self): + return True + + def writable(self): + return False + + def seekable(self): + return self.fileobj.seekable() + + def tell(self): + """Return the current file position. + """ + return self.position + + def seek(self, position, whence=io.SEEK_SET): + """Seek to a position in the file. + """ + if whence == io.SEEK_SET: + self.position = min(max(position, 0), self.size) + elif whence == io.SEEK_CUR: + if position < 0: + self.position = max(self.position + position, 0) + else: + self.position = min(self.position + position, self.size) + elif whence == io.SEEK_END: + self.position = max(min(self.size + position, self.size), 0) + else: + raise ValueError("Invalid argument") + return self.position + + def read(self, size=None): + """Read data from the file. 
+ """ + if size is None: + size = self.size - self.position + else: + size = min(size, self.size - self.position) + + buf = b"" + while size > 0: + while True: + data, start, stop, offset = self.map[self.map_index] + if start <= self.position < stop: + break + else: + self.map_index += 1 + if self.map_index == len(self.map): + self.map_index = 0 + length = min(size, stop - self.position) + if data: + self.fileobj.seek(offset + (self.position - start)) + b = self.fileobj.read(length) + if len(b) != length: + raise ReadError("unexpected end of data") + buf += b + else: + buf += NUL * length + size -= length + self.position += length + return buf + + def readinto(self, b): + buf = self.read(len(b)) + b[:len(buf)] = buf + return len(buf) + + def close(self): + self.closed = True +#class _FileInFile + +class ExFileObject(io.BufferedReader): + + def __init__(self, tarfile, tarinfo): + fileobj = _FileInFile(tarfile.fileobj, tarinfo.offset_data, + tarinfo.size, tarinfo.name, tarinfo.sparse) + super().__init__(fileobj) +#class ExFileObject + + +#----------------------------- +# extraction filters (PEP 706) +#----------------------------- + +class FilterError(TarError): + pass + +class AbsolutePathError(FilterError): + def __init__(self, tarinfo): + self.tarinfo = tarinfo + super().__init__(f'member {tarinfo.name!r} has an absolute path') + +class OutsideDestinationError(FilterError): + def __init__(self, tarinfo, path): + self.tarinfo = tarinfo + self._path = path + super().__init__(f'{tarinfo.name!r} would be extracted to {path!r}, ' + + 'which is outside the destination') + +class SpecialFileError(FilterError): + def __init__(self, tarinfo): + self.tarinfo = tarinfo + super().__init__(f'{tarinfo.name!r} is a special file') + +class AbsoluteLinkError(FilterError): + def __init__(self, tarinfo): + self.tarinfo = tarinfo + super().__init__(f'{tarinfo.name!r} is a link to an absolute path') + +class LinkOutsideDestinationError(FilterError): + def __init__(self, tarinfo, path): + self.tarinfo = tarinfo + self._path = path + super().__init__(f'{tarinfo.name!r} would link to {path!r}, ' + + 'which is outside the destination') + +def _get_filtered_attrs(member, dest_path, for_data=True): + new_attrs = {} + name = member.name + dest_path = os.path.realpath(dest_path) + # Strip leading / (tar's directory separator) from filenames. + # Include os.sep (target OS directory separator) as well. + if name.startswith(('/', os.sep)): + name = new_attrs['name'] = member.path.lstrip('/' + os.sep) + if os.path.isabs(name): + # Path is absolute even after stripping. + # For example, 'C:/foo' on Windows. 
+ raise AbsolutePathError(member) + # Ensure we stay in the destination + target_path = os.path.realpath(os.path.join(dest_path, name)) + if os.path.commonpath([target_path, dest_path]) != dest_path: + raise OutsideDestinationError(member, target_path) + # Limit permissions (no high bits, and go-w) + mode = member.mode + if mode is not None: + # Strip high bits & group/other write bits + mode = mode & 0o755 + if for_data: + # For data, handle permissions & file types + if member.isreg() or member.islnk(): + if not mode & 0o100: + # Clear executable bits if not executable by user + mode &= ~0o111 + # Ensure owner can read & write + mode |= 0o600 + elif member.isdir() or member.issym(): + # Ignore mode for directories & symlinks + mode = None + else: + # Reject special files + raise SpecialFileError(member) + if mode != member.mode: + new_attrs['mode'] = mode + if for_data: + # Ignore ownership for 'data' + if member.uid is not None: + new_attrs['uid'] = None + if member.gid is not None: + new_attrs['gid'] = None + if member.uname is not None: + new_attrs['uname'] = None + if member.gname is not None: + new_attrs['gname'] = None + # Check link destination for 'data' + if member.islnk() or member.issym(): + if os.path.isabs(member.linkname): + raise AbsoluteLinkError(member) + if member.issym(): + target_path = os.path.join(dest_path, + os.path.dirname(name), + member.linkname) + else: + target_path = os.path.join(dest_path, + member.linkname) + target_path = os.path.realpath(target_path) + if os.path.commonpath([target_path, dest_path]) != dest_path: + raise LinkOutsideDestinationError(member, target_path) + return new_attrs + +def fully_trusted_filter(member, dest_path): + return member + +def tar_filter(member, dest_path): + new_attrs = _get_filtered_attrs(member, dest_path, False) + if new_attrs: + return member.replace(**new_attrs, deep=False) + return member + +def data_filter(member, dest_path): + new_attrs = _get_filtered_attrs(member, dest_path, True) + if new_attrs: + return member.replace(**new_attrs, deep=False) + return member + +_NAMED_FILTERS = { + "fully_trusted": fully_trusted_filter, + "tar": tar_filter, + "data": data_filter, +} + +#------------------ +# Exported Classes +#------------------ + +# Sentinel for replace() defaults, meaning "don't change the attribute" +_KEEP = object() + +class TarInfo(object): + """Informational class which holds the details about an + archive member given by a tar header block. + TarInfo objects are returned by TarFile.getmember(), + TarFile.getmembers() and TarFile.gettarinfo() and are + usually created internally. + """ + + __slots__ = dict( + name = 'Name of the archive member.', + mode = 'Permission bits.', + uid = 'User ID of the user who originally stored this member.', + gid = 'Group ID of the user who originally stored this member.', + size = 'Size in bytes.', + mtime = 'Time of last modification.', + chksum = 'Header checksum.', + type = ('File type. 
type is usually one of these constants: ' + 'REGTYPE, AREGTYPE, LNKTYPE, SYMTYPE, DIRTYPE, FIFOTYPE, ' + 'CONTTYPE, CHRTYPE, BLKTYPE, GNUTYPE_SPARSE.'), + linkname = ('Name of the target file name, which is only present ' + 'in TarInfo objects of type LNKTYPE and SYMTYPE.'), + uname = 'User name.', + gname = 'Group name.', + devmajor = 'Device major number.', + devminor = 'Device minor number.', + offset = 'The tar header starts here.', + offset_data = "The file's data starts here.", + pax_headers = ('A dictionary containing key-value pairs of an ' + 'associated pax extended header.'), + sparse = 'Sparse member information.', + tarfile = None, + _sparse_structs = None, + _link_target = None, + ) + + def __init__(self, name=""): + """Construct a TarInfo object. name is the optional name + of the member. + """ + self.name = name # member name + self.mode = 0o644 # file permissions + self.uid = 0 # user id + self.gid = 0 # group id + self.size = 0 # file size + self.mtime = 0 # modification time + self.chksum = 0 # header checksum + self.type = REGTYPE # member type + self.linkname = "" # link name + self.uname = "" # user name + self.gname = "" # group name + self.devmajor = 0 # device major number + self.devminor = 0 # device minor number + + self.offset = 0 # the tar header starts here + self.offset_data = 0 # the file's data starts here + + self.sparse = None # sparse member information + self.pax_headers = {} # pax header information + + @property + def path(self): + 'In pax headers, "name" is called "path".' + return self.name + + @path.setter + def path(self, name): + self.name = name + + @property + def linkpath(self): + 'In pax headers, "linkname" is called "linkpath".' + return self.linkname + + @linkpath.setter + def linkpath(self, linkname): + self.linkname = linkname + + def __repr__(self): + return "<%s %r at %#x>" % (self.__class__.__name__,self.name,id(self)) + + def replace(self, *, + name=_KEEP, mtime=_KEEP, mode=_KEEP, linkname=_KEEP, + uid=_KEEP, gid=_KEEP, uname=_KEEP, gname=_KEEP, + deep=True, _KEEP=_KEEP): + """Return a deep copy of self with the given attributes replaced. + """ + if deep: + result = copy.deepcopy(self) + else: + result = copy.copy(self) + if name is not _KEEP: + result.name = name + if mtime is not _KEEP: + result.mtime = mtime + if mode is not _KEEP: + result.mode = mode + if linkname is not _KEEP: + result.linkname = linkname + if uid is not _KEEP: + result.uid = uid + if gid is not _KEEP: + result.gid = gid + if uname is not _KEEP: + result.uname = uname + if gname is not _KEEP: + result.gname = gname + return result + + def get_info(self): + """Return the TarInfo's attributes as a dictionary. + """ + if self.mode is None: + mode = None + else: + mode = self.mode & 0o7777 + info = { + "name": self.name, + "mode": mode, + "uid": self.uid, + "gid": self.gid, + "size": self.size, + "mtime": self.mtime, + "chksum": self.chksum, + "type": self.type, + "linkname": self.linkname, + "uname": self.uname, + "gname": self.gname, + "devmajor": self.devmajor, + "devminor": self.devminor + } + + if info["type"] == DIRTYPE and not info["name"].endswith("/"): + info["name"] += "/" + + return info + + def tobuf(self, format=DEFAULT_FORMAT, encoding=ENCODING, errors="surrogateescape"): + """Return a tar header as a string of 512 byte blocks. 
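+
+           Depending on ``format``, the result may span more than one block:
+           GNU long name/link members and pax extended records are emitted
+           as extra headers in front of the regular header block. For a
+           short-named regular file a single block suffices, e.g.
+           len(TarInfo("foo").tobuf()) == 512.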
+ """ + info = self.get_info() + for name, value in info.items(): + if value is None: + raise ValueError("%s may not be None" % name) + + if format == USTAR_FORMAT: + return self.create_ustar_header(info, encoding, errors) + elif format == GNU_FORMAT: + return self.create_gnu_header(info, encoding, errors) + elif format == PAX_FORMAT: + return self.create_pax_header(info, encoding) + else: + raise ValueError("invalid format") + + def create_ustar_header(self, info, encoding, errors): + """Return the object as a ustar header block. + """ + info["magic"] = POSIX_MAGIC + + if len(info["linkname"].encode(encoding, errors)) > LENGTH_LINK: + raise ValueError("linkname is too long") + + if len(info["name"].encode(encoding, errors)) > LENGTH_NAME: + info["prefix"], info["name"] = self._posix_split_name(info["name"], encoding, errors) + + return self._create_header(info, USTAR_FORMAT, encoding, errors) + + def create_gnu_header(self, info, encoding, errors): + """Return the object as a GNU header block sequence. + """ + info["magic"] = GNU_MAGIC + + buf = b"" + if len(info["linkname"].encode(encoding, errors)) > LENGTH_LINK: + buf += self._create_gnu_long_header(info["linkname"], GNUTYPE_LONGLINK, encoding, errors) + + if len(info["name"].encode(encoding, errors)) > LENGTH_NAME: + buf += self._create_gnu_long_header(info["name"], GNUTYPE_LONGNAME, encoding, errors) + + return buf + self._create_header(info, GNU_FORMAT, encoding, errors) + + def create_pax_header(self, info, encoding): + """Return the object as a ustar header block. If it cannot be + represented this way, prepend a pax extended header sequence + with supplement information. + """ + info["magic"] = POSIX_MAGIC + pax_headers = self.pax_headers.copy() + + # Test string fields for values that exceed the field length or cannot + # be represented in ASCII encoding. + for name, hname, length in ( + ("name", "path", LENGTH_NAME), ("linkname", "linkpath", LENGTH_LINK), + ("uname", "uname", 32), ("gname", "gname", 32)): + + if hname in pax_headers: + # The pax header has priority. + continue + + # Try to encode the string as ASCII. + try: + info[name].encode("ascii", "strict") + except UnicodeEncodeError: + pax_headers[hname] = info[name] + continue + + if len(info[name]) > length: + pax_headers[hname] = info[name] + + # Test number fields for values that exceed the field limit or values + # that like to be stored as float. + for name, digits in (("uid", 8), ("gid", 8), ("size", 12), ("mtime", 12)): + needs_pax = False + + val = info[name] + val_is_float = isinstance(val, float) + val_int = round(val) if val_is_float else val + if not 0 <= val_int < 8 ** (digits - 1): + # Avoid overflow. + info[name] = 0 + needs_pax = True + elif val_is_float: + # Put rounded value in ustar header, and full + # precision value in pax header. + info[name] = val_int + needs_pax = True + + # The existing pax header has priority. + if needs_pax and name not in pax_headers: + pax_headers[name] = str(val) + + # Create a pax extended header if necessary. + if pax_headers: + buf = self._create_pax_generic_header(pax_headers, XHDTYPE, encoding) + else: + buf = b"" + + return buf + self._create_header(info, USTAR_FORMAT, "ascii", "replace") + + @classmethod + def create_pax_global_header(cls, pax_headers): + """Return the object as a pax global header block sequence. 
+ """ + return cls._create_pax_generic_header(pax_headers, XGLTYPE, "utf-8") + + def _posix_split_name(self, name, encoding, errors): + """Split a name longer than 100 chars into a prefix + and a name part. + """ + components = name.split("/") + for i in range(1, len(components)): + prefix = "/".join(components[:i]) + name = "/".join(components[i:]) + if len(prefix.encode(encoding, errors)) <= LENGTH_PREFIX and \ + len(name.encode(encoding, errors)) <= LENGTH_NAME: + break + else: + raise ValueError("name is too long") + + return prefix, name + + @staticmethod + def _create_header(info, format, encoding, errors): + """Return a header block. info is a dictionary with file + information, format must be one of the *_FORMAT constants. + """ + has_device_fields = info.get("type") in (CHRTYPE, BLKTYPE) + if has_device_fields: + devmajor = itn(info.get("devmajor", 0), 8, format) + devminor = itn(info.get("devminor", 0), 8, format) + else: + devmajor = stn("", 8, encoding, errors) + devminor = stn("", 8, encoding, errors) + + # None values in metadata should cause ValueError. + # itn()/stn() do this for all fields except type. + filetype = info.get("type", REGTYPE) + if filetype is None: + raise ValueError("TarInfo.type must not be None") + + parts = [ + stn(info.get("name", ""), 100, encoding, errors), + itn(info.get("mode", 0) & 0o7777, 8, format), + itn(info.get("uid", 0), 8, format), + itn(info.get("gid", 0), 8, format), + itn(info.get("size", 0), 12, format), + itn(info.get("mtime", 0), 12, format), + b" ", # checksum field + filetype, + stn(info.get("linkname", ""), 100, encoding, errors), + info.get("magic", POSIX_MAGIC), + stn(info.get("uname", ""), 32, encoding, errors), + stn(info.get("gname", ""), 32, encoding, errors), + devmajor, + devminor, + stn(info.get("prefix", ""), 155, encoding, errors) + ] + + buf = struct.pack("%ds" % BLOCKSIZE, b"".join(parts)) + chksum = calc_chksums(buf[-BLOCKSIZE:])[0] + buf = buf[:-364] + bytes("%06o\0" % chksum, "ascii") + buf[-357:] + return buf + + @staticmethod + def _create_payload(payload): + """Return the string payload filled with zero bytes + up to the next 512 byte border. + """ + blocks, remainder = divmod(len(payload), BLOCKSIZE) + if remainder > 0: + payload += (BLOCKSIZE - remainder) * NUL + return payload + + @classmethod + def _create_gnu_long_header(cls, name, type, encoding, errors): + """Return a GNUTYPE_LONGNAME or GNUTYPE_LONGLINK sequence + for name. + """ + name = name.encode(encoding, errors) + NUL + + info = {} + info["name"] = "././@LongLink" + info["type"] = type + info["size"] = len(name) + info["magic"] = GNU_MAGIC + + # create extended header + name blocks. + return cls._create_header(info, USTAR_FORMAT, encoding, errors) + \ + cls._create_payload(name) + + @classmethod + def _create_pax_generic_header(cls, pax_headers, type, encoding): + """Return a POSIX.1-2008 extended or global header sequence + that contains a list of keyword, value pairs. The values + must be strings. + """ + # Check if one of the fields contains surrogate characters and thereby + # forces hdrcharset=BINARY, see _proc_pax() for more information. + binary = False + for keyword, value in pax_headers.items(): + try: + value.encode("utf-8", "strict") + except UnicodeEncodeError: + binary = True + break + + records = b"" + if binary: + # Put the hdrcharset field at the beginning of the header. 
+ records += b"21 hdrcharset=BINARY\n" + + for keyword, value in pax_headers.items(): + keyword = keyword.encode("utf-8") + if binary: + # Try to restore the original byte representation of `value'. + # Needless to say, that the encoding must match the string. + value = value.encode(encoding, "surrogateescape") + else: + value = value.encode("utf-8") + + l = len(keyword) + len(value) + 3 # ' ' + '=' + '\n' + n = p = 0 + while True: + n = l + len(str(p)) + if n == p: + break + p = n + records += bytes(str(p), "ascii") + b" " + keyword + b"=" + value + b"\n" + + # We use a hardcoded "././@PaxHeader" name like star does + # instead of the one that POSIX recommends. + info = {} + info["name"] = "././@PaxHeader" + info["type"] = type + info["size"] = len(records) + info["magic"] = POSIX_MAGIC + + # Create pax header + record blocks. + return cls._create_header(info, USTAR_FORMAT, "ascii", "replace") + \ + cls._create_payload(records) + + @classmethod + def frombuf(cls, buf, encoding, errors): + """Construct a TarInfo object from a 512 byte bytes object. + """ + if len(buf) == 0: + raise EmptyHeaderError("empty header") + if len(buf) != BLOCKSIZE: + raise TruncatedHeaderError("truncated header") + if buf.count(NUL) == BLOCKSIZE: + raise EOFHeaderError("end of file header") + + chksum = nti(buf[148:156]) + if chksum not in calc_chksums(buf): + raise InvalidHeaderError("bad checksum") + + obj = cls() + obj.name = nts(buf[0:100], encoding, errors) + obj.mode = nti(buf[100:108]) + obj.uid = nti(buf[108:116]) + obj.gid = nti(buf[116:124]) + obj.size = nti(buf[124:136]) + obj.mtime = nti(buf[136:148]) + obj.chksum = chksum + obj.type = buf[156:157] + obj.linkname = nts(buf[157:257], encoding, errors) + obj.uname = nts(buf[265:297], encoding, errors) + obj.gname = nts(buf[297:329], encoding, errors) + obj.devmajor = nti(buf[329:337]) + obj.devminor = nti(buf[337:345]) + prefix = nts(buf[345:500], encoding, errors) + + # Old V7 tar format represents a directory as a regular + # file with a trailing slash. + if obj.type == AREGTYPE and obj.name.endswith("/"): + obj.type = DIRTYPE + + # The old GNU sparse format occupies some of the unused + # space in the buffer for up to 4 sparse structures. + # Save them for later processing in _proc_sparse(). + if obj.type == GNUTYPE_SPARSE: + pos = 386 + structs = [] + for i in range(4): + try: + offset = nti(buf[pos:pos + 12]) + numbytes = nti(buf[pos + 12:pos + 24]) + except ValueError: + break + structs.append((offset, numbytes)) + pos += 24 + isextended = bool(buf[482]) + origsize = nti(buf[483:495]) + obj._sparse_structs = (structs, isextended, origsize) + + # Remove redundant slashes from directories. + if obj.isdir(): + obj.name = obj.name.rstrip("/") + + # Reconstruct a ustar longname. + if prefix and obj.type not in GNU_TYPES: + obj.name = prefix + "/" + obj.name + return obj + + @classmethod + def fromtarfile(cls, tarfile): + """Return the next TarInfo object from TarFile object + tarfile. + """ + buf = tarfile.fileobj.read(BLOCKSIZE) + obj = cls.frombuf(buf, tarfile.encoding, tarfile.errors) + obj.offset = tarfile.fileobj.tell() - BLOCKSIZE + return obj._proc_member(tarfile) + + #-------------------------------------------------------------------------- + # The following are methods that are called depending on the type of a + # member. The entry point is _proc_member() which can be overridden in a + # subclass to add custom _proc_*() methods. A _proc_*() method MUST + # implement the following + # operations: + # 1. 
Set self.offset_data to the position where the data blocks begin, + # if there is data that follows. + # 2. Set tarfile.offset to the position where the next member's header will + # begin. + # 3. Return self or another valid TarInfo object. + def _proc_member(self, tarfile): + """Choose the right processing method depending on + the type and call it. + """ + if self.type in (GNUTYPE_LONGNAME, GNUTYPE_LONGLINK): + return self._proc_gnulong(tarfile) + elif self.type == GNUTYPE_SPARSE: + return self._proc_sparse(tarfile) + elif self.type in (XHDTYPE, XGLTYPE, SOLARIS_XHDTYPE): + return self._proc_pax(tarfile) + else: + return self._proc_builtin(tarfile) + + def _proc_builtin(self, tarfile): + """Process a builtin type or an unknown type which + will be treated as a regular file. + """ + self.offset_data = tarfile.fileobj.tell() + offset = self.offset_data + if self.isreg() or self.type not in SUPPORTED_TYPES: + # Skip the following data blocks. + offset += self._block(self.size) + tarfile.offset = offset + + # Patch the TarInfo object with saved global + # header information. + self._apply_pax_info(tarfile.pax_headers, tarfile.encoding, tarfile.errors) + + # Remove redundant slashes from directories. This is to be consistent + # with frombuf(). + if self.isdir(): + self.name = self.name.rstrip("/") + + return self + + def _proc_gnulong(self, tarfile): + """Process the blocks that hold a GNU longname + or longlink member. + """ + buf = tarfile.fileobj.read(self._block(self.size)) + + # Fetch the next header and process it. + try: + next = self.fromtarfile(tarfile) + except HeaderError as e: + raise SubsequentHeaderError(str(e)) from None + + # Patch the TarInfo object from the next header with + # the longname information. + next.offset = self.offset + if self.type == GNUTYPE_LONGNAME: + next.name = nts(buf, tarfile.encoding, tarfile.errors) + elif self.type == GNUTYPE_LONGLINK: + next.linkname = nts(buf, tarfile.encoding, tarfile.errors) + + # Remove redundant slashes from directories. This is to be consistent + # with frombuf(). + if next.isdir(): + next.name = removesuffix(next.name, "/") + + return next + + def _proc_sparse(self, tarfile): + """Process a GNU sparse header plus extra headers. + """ + # We already collected some sparse structures in frombuf(). + structs, isextended, origsize = self._sparse_structs + del self._sparse_structs + + # Collect sparse structures from extended header blocks. + while isextended: + buf = tarfile.fileobj.read(BLOCKSIZE) + pos = 0 + for i in range(21): + try: + offset = nti(buf[pos:pos + 12]) + numbytes = nti(buf[pos + 12:pos + 24]) + except ValueError: + break + if offset and numbytes: + structs.append((offset, numbytes)) + pos += 24 + isextended = bool(buf[504]) + self.sparse = structs + + self.offset_data = tarfile.fileobj.tell() + tarfile.offset = self.offset_data + self._block(self.size) + self.size = origsize + return self + + def _proc_pax(self, tarfile): + """Process an extended or global header as described in + POSIX.1-2008. + """ + # Read the header information. + buf = tarfile.fileobj.read(self._block(self.size)) + + # A pax header stores supplemental information for either + # the following file (extended) or all following files + # (global). + if self.type == XGLTYPE: + pax_headers = tarfile.pax_headers + else: + pax_headers = tarfile.pax_headers.copy() + + # Check if the pax header contains a hdrcharset field. This tells us + # the encoding of the path, linkpath, uname and gname fields. 
Normally, + # these fields are UTF-8 encoded but since POSIX.1-2008 tar + # implementations are allowed to store them as raw binary strings if + # the translation to UTF-8 fails. + match = re.search(br"\d+ hdrcharset=([^\n]+)\n", buf) + if match is not None: + pax_headers["hdrcharset"] = match.group(1).decode("utf-8") + + # For the time being, we don't care about anything other than "BINARY". + # The only other value that is currently allowed by the standard is + # "ISO-IR 10646 2000 UTF-8" in other words UTF-8. + hdrcharset = pax_headers.get("hdrcharset") + if hdrcharset == "BINARY": + encoding = tarfile.encoding + else: + encoding = "utf-8" + + # Parse pax header information. A record looks like that: + # "%d %s=%s\n" % (length, keyword, value). length is the size + # of the complete record including the length field itself and + # the newline. keyword and value are both UTF-8 encoded strings. + regex = re.compile(br"(\d+) ([^=]+)=") + pos = 0 + while match := regex.match(buf, pos): + length, keyword = match.groups() + length = int(length) + if length == 0: + raise InvalidHeaderError("invalid header") + value = buf[match.end(2) + 1:match.start(1) + length - 1] + + # Normally, we could just use "utf-8" as the encoding and "strict" + # as the error handler, but we better not take the risk. For + # example, GNU tar <= 1.23 is known to store filenames it cannot + # translate to UTF-8 as raw strings (unfortunately without a + # hdrcharset=BINARY header). + # We first try the strict standard encoding, and if that fails we + # fall back on the user's encoding and error handler. + keyword = self._decode_pax_field(keyword, "utf-8", "utf-8", + tarfile.errors) + if keyword in PAX_NAME_FIELDS: + value = self._decode_pax_field(value, encoding, tarfile.encoding, + tarfile.errors) + else: + value = self._decode_pax_field(value, "utf-8", "utf-8", + tarfile.errors) + + pax_headers[keyword] = value + pos += length + + # Fetch the next header. + try: + next = self.fromtarfile(tarfile) + except HeaderError as e: + raise SubsequentHeaderError(str(e)) from None + + # Process GNU sparse information. + if "GNU.sparse.map" in pax_headers: + # GNU extended sparse format version 0.1. + self._proc_gnusparse_01(next, pax_headers) + + elif "GNU.sparse.size" in pax_headers: + # GNU extended sparse format version 0.0. + self._proc_gnusparse_00(next, pax_headers, buf) + + elif pax_headers.get("GNU.sparse.major") == "1" and pax_headers.get("GNU.sparse.minor") == "0": + # GNU extended sparse format version 1.0. + self._proc_gnusparse_10(next, pax_headers, tarfile) + + if self.type in (XHDTYPE, SOLARIS_XHDTYPE): + # Patch the TarInfo object with the extended header info. + next._apply_pax_info(pax_headers, tarfile.encoding, tarfile.errors) + next.offset = self.offset + + if "size" in pax_headers: + # If the extended header replaces the size field, + # we need to recalculate the offset where the next + # header starts. + offset = next.offset_data + if next.isreg() or next.type not in SUPPORTED_TYPES: + offset += next._block(next.size) + tarfile.offset = offset + + return next + + def _proc_gnusparse_00(self, next, pax_headers, buf): + """Process a GNU tar extended sparse header, version 0.0. 
+ """ + offsets = [] + for match in re.finditer(br"\d+ GNU.sparse.offset=(\d+)\n", buf): + offsets.append(int(match.group(1))) + numbytes = [] + for match in re.finditer(br"\d+ GNU.sparse.numbytes=(\d+)\n", buf): + numbytes.append(int(match.group(1))) + next.sparse = list(zip(offsets, numbytes)) + + def _proc_gnusparse_01(self, next, pax_headers): + """Process a GNU tar extended sparse header, version 0.1. + """ + sparse = [int(x) for x in pax_headers["GNU.sparse.map"].split(",")] + next.sparse = list(zip(sparse[::2], sparse[1::2])) + + def _proc_gnusparse_10(self, next, pax_headers, tarfile): + """Process a GNU tar extended sparse header, version 1.0. + """ + fields = None + sparse = [] + buf = tarfile.fileobj.read(BLOCKSIZE) + fields, buf = buf.split(b"\n", 1) + fields = int(fields) + while len(sparse) < fields * 2: + if b"\n" not in buf: + buf += tarfile.fileobj.read(BLOCKSIZE) + number, buf = buf.split(b"\n", 1) + sparse.append(int(number)) + next.offset_data = tarfile.fileobj.tell() + next.sparse = list(zip(sparse[::2], sparse[1::2])) + + def _apply_pax_info(self, pax_headers, encoding, errors): + """Replace fields with supplemental information from a previous + pax extended or global header. + """ + for keyword, value in pax_headers.items(): + if keyword == "GNU.sparse.name": + setattr(self, "path", value) + elif keyword == "GNU.sparse.size": + setattr(self, "size", int(value)) + elif keyword == "GNU.sparse.realsize": + setattr(self, "size", int(value)) + elif keyword in PAX_FIELDS: + if keyword in PAX_NUMBER_FIELDS: + try: + value = PAX_NUMBER_FIELDS[keyword](value) + except ValueError: + value = 0 + if keyword == "path": + value = value.rstrip("/") + setattr(self, keyword, value) + + self.pax_headers = pax_headers.copy() + + def _decode_pax_field(self, value, encoding, fallback_encoding, fallback_errors): + """Decode a single field from a pax record. + """ + try: + return value.decode(encoding, "strict") + except UnicodeDecodeError: + return value.decode(fallback_encoding, fallback_errors) + + def _block(self, count): + """Round up a byte count by BLOCKSIZE and return it, + e.g. _block(834) => 1024. + """ + blocks, remainder = divmod(count, BLOCKSIZE) + if remainder: + blocks += 1 + return blocks * BLOCKSIZE + + def isreg(self): + 'Return True if the Tarinfo object is a regular file.' + return self.type in REGULAR_TYPES + + def isfile(self): + 'Return True if the Tarinfo object is a regular file.' + return self.isreg() + + def isdir(self): + 'Return True if it is a directory.' + return self.type == DIRTYPE + + def issym(self): + 'Return True if it is a symbolic link.' + return self.type == SYMTYPE + + def islnk(self): + 'Return True if it is a hard link.' + return self.type == LNKTYPE + + def ischr(self): + 'Return True if it is a character device.' + return self.type == CHRTYPE + + def isblk(self): + 'Return True if it is a block device.' + return self.type == BLKTYPE + + def isfifo(self): + 'Return True if it is a FIFO.' + return self.type == FIFOTYPE + + def issparse(self): + return self.sparse is not None + + def isdev(self): + 'Return True if it is one of character device, block device or FIFO.' + return self.type in (CHRTYPE, BLKTYPE, FIFOTYPE) +# class TarInfo + +class TarFile(object): + """The TarFile Class provides an interface to tar archives. + """ + + debug = 0 # May be set from 0 (no msgs) to 3 (all msgs) + + dereference = False # If true, add content of linked file to the + # tar file, else the link. 
+ + ignore_zeros = False # If true, skips empty or invalid blocks and + # continues processing. + + errorlevel = 1 # If 0, fatal errors only appear in debug + # messages (if debug >= 0). If > 0, errors + # are passed to the caller as exceptions. + + format = DEFAULT_FORMAT # The format to use when creating an archive. + + encoding = ENCODING # Encoding for 8-bit character strings. + + errors = None # Error handler for unicode conversion. + + tarinfo = TarInfo # The default TarInfo class to use. + + fileobject = ExFileObject # The file-object for extractfile(). + + extraction_filter = None # The default filter for extraction. + + def __init__(self, name=None, mode="r", fileobj=None, format=None, + tarinfo=None, dereference=None, ignore_zeros=None, encoding=None, + errors="surrogateescape", pax_headers=None, debug=None, + errorlevel=None, copybufsize=None): + """Open an (uncompressed) tar archive `name'. `mode' is either 'r' to + read from an existing archive, 'a' to append data to an existing + file or 'w' to create a new file overwriting an existing one. `mode' + defaults to 'r'. + If `fileobj' is given, it is used for reading or writing data. If it + can be determined, `mode' is overridden by `fileobj's mode. + `fileobj' is not closed, when TarFile is closed. + """ + modes = {"r": "rb", "a": "r+b", "w": "wb", "x": "xb"} + if mode not in modes: + raise ValueError("mode must be 'r', 'a', 'w' or 'x'") + self.mode = mode + self._mode = modes[mode] + + if not fileobj: + if self.mode == "a" and not os.path.exists(name): + # Create nonexistent files in append mode. + self.mode = "w" + self._mode = "wb" + fileobj = bltn_open(name, self._mode) + self._extfileobj = False + else: + if (name is None and hasattr(fileobj, "name") and + isinstance(fileobj.name, (str, bytes))): + name = fileobj.name + if hasattr(fileobj, "mode"): + self._mode = fileobj.mode + self._extfileobj = True + self.name = os.path.abspath(name) if name else None + self.fileobj = fileobj + + # Init attributes. + if format is not None: + self.format = format + if tarinfo is not None: + self.tarinfo = tarinfo + if dereference is not None: + self.dereference = dereference + if ignore_zeros is not None: + self.ignore_zeros = ignore_zeros + if encoding is not None: + self.encoding = encoding + self.errors = errors + + if pax_headers is not None and self.format == PAX_FORMAT: + self.pax_headers = pax_headers + else: + self.pax_headers = {} + + if debug is not None: + self.debug = debug + if errorlevel is not None: + self.errorlevel = errorlevel + + # Init datastructures. + self.copybufsize = copybufsize + self.closed = False + self.members = [] # list of members as TarInfo objects + self._loaded = False # flag if all members have been read + self.offset = self.fileobj.tell() + # current position in the archive file + self.inodes = {} # dictionary caching the inodes of + # archive members already added + + try: + if self.mode == "r": + self.firstmember = None + self.firstmember = self.next() + + if self.mode == "a": + # Move to the end of the archive, + # before the first empty block. 
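+                # (One member header is parsed per loop iteration below;
+                # EOFHeaderError marks the first empty block, which is where
+                # appended members will be written.)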
+ while True: + self.fileobj.seek(self.offset) + try: + tarinfo = self.tarinfo.fromtarfile(self) + self.members.append(tarinfo) + except EOFHeaderError: + self.fileobj.seek(self.offset) + break + except HeaderError as e: + raise ReadError(str(e)) from None + + if self.mode in ("a", "w", "x"): + self._loaded = True + + if self.pax_headers: + buf = self.tarinfo.create_pax_global_header(self.pax_headers.copy()) + self.fileobj.write(buf) + self.offset += len(buf) + except: + if not self._extfileobj: + self.fileobj.close() + self.closed = True + raise + + #-------------------------------------------------------------------------- + # Below are the classmethods which act as alternate constructors to the + # TarFile class. The open() method is the only one that is needed for + # public use; it is the "super"-constructor and is able to select an + # adequate "sub"-constructor for a particular compression using the mapping + # from OPEN_METH. + # + # This concept allows one to subclass TarFile without losing the comfort of + # the super-constructor. A sub-constructor is registered and made available + # by adding it to the mapping in OPEN_METH. + + @classmethod + def open(cls, name=None, mode="r", fileobj=None, bufsize=RECORDSIZE, **kwargs): + r"""Open a tar archive for reading, writing or appending. Return + an appropriate TarFile class. + + mode: + 'r' or 'r:\*' open for reading with transparent compression + 'r:' open for reading exclusively uncompressed + 'r:gz' open for reading with gzip compression + 'r:bz2' open for reading with bzip2 compression + 'r:xz' open for reading with lzma compression + 'a' or 'a:' open for appending, creating the file if necessary + 'w' or 'w:' open for writing without compression + 'w:gz' open for writing with gzip compression + 'w:bz2' open for writing with bzip2 compression + 'w:xz' open for writing with lzma compression + + 'x' or 'x:' create a tarfile exclusively without compression, raise + an exception if the file is already created + 'x:gz' create a gzip compressed tarfile, raise an exception + if the file is already created + 'x:bz2' create a bzip2 compressed tarfile, raise an exception + if the file is already created + 'x:xz' create an lzma compressed tarfile, raise an exception + if the file is already created + + 'r|\*' open a stream of tar blocks with transparent compression + 'r|' open an uncompressed stream of tar blocks for reading + 'r|gz' open a gzip compressed stream of tar blocks + 'r|bz2' open a bzip2 compressed stream of tar blocks + 'r|xz' open an lzma compressed stream of tar blocks + 'w|' open an uncompressed stream for writing + 'w|gz' open a gzip compressed stream for writing + 'w|bz2' open a bzip2 compressed stream for writing + 'w|xz' open an lzma compressed stream for writing + """ + + if not name and not fileobj: + raise ValueError("nothing to open") + + if mode in ("r", "r:*"): + # Find out which *open() is appropriate for opening the file. 
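+            # The compressed openers get the first chance to claim the file:
+            # not_compressed() is False for them, so sorted() moves plain
+            # taropen to the end of the candidate list.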
+ def not_compressed(comptype): + return cls.OPEN_METH[comptype] == 'taropen' + error_msgs = [] + for comptype in sorted(cls.OPEN_METH, key=not_compressed): + func = getattr(cls, cls.OPEN_METH[comptype]) + if fileobj is not None: + saved_pos = fileobj.tell() + try: + return func(name, "r", fileobj, **kwargs) + except (ReadError, CompressionError) as e: + error_msgs.append(f'- method {comptype}: {e!r}') + if fileobj is not None: + fileobj.seek(saved_pos) + continue + error_msgs_summary = '\n'.join(error_msgs) + raise ReadError(f"file could not be opened successfully:\n{error_msgs_summary}") + + elif ":" in mode: + filemode, comptype = mode.split(":", 1) + filemode = filemode or "r" + comptype = comptype or "tar" + + # Select the *open() function according to + # given compression. + if comptype in cls.OPEN_METH: + func = getattr(cls, cls.OPEN_METH[comptype]) + else: + raise CompressionError("unknown compression type %r" % comptype) + return func(name, filemode, fileobj, **kwargs) + + elif "|" in mode: + filemode, comptype = mode.split("|", 1) + filemode = filemode or "r" + comptype = comptype or "tar" + + if filemode not in ("r", "w"): + raise ValueError("mode must be 'r' or 'w'") + + compresslevel = kwargs.pop("compresslevel", 9) + stream = _Stream(name, filemode, comptype, fileobj, bufsize, + compresslevel) + try: + t = cls(name, filemode, stream, **kwargs) + except: + stream.close() + raise + t._extfileobj = False + return t + + elif mode in ("a", "w", "x"): + return cls.taropen(name, mode, fileobj, **kwargs) + + raise ValueError("undiscernible mode") + + @classmethod + def taropen(cls, name, mode="r", fileobj=None, **kwargs): + """Open uncompressed tar archive name for reading or writing. + """ + if mode not in ("r", "a", "w", "x"): + raise ValueError("mode must be 'r', 'a', 'w' or 'x'") + return cls(name, mode, fileobj, **kwargs) + + @classmethod + def gzopen(cls, name, mode="r", fileobj=None, compresslevel=9, **kwargs): + """Open gzip compressed tar archive name for reading or writing. + Appending is not allowed. + """ + if mode not in ("r", "w", "x"): + raise ValueError("mode must be 'r', 'w' or 'x'") + + try: + from gzip import GzipFile + except ImportError: + raise CompressionError("gzip module is not available") from None + + try: + fileobj = GzipFile(name, mode + "b", compresslevel, fileobj) + except OSError as e: + if fileobj is not None and mode == 'r': + raise ReadError("not a gzip file") from e + raise + + try: + t = cls.taropen(name, mode, fileobj, **kwargs) + except OSError as e: + fileobj.close() + if mode == 'r': + raise ReadError("not a gzip file") from e + raise + except: + fileobj.close() + raise + t._extfileobj = False + return t + + @classmethod + def bz2open(cls, name, mode="r", fileobj=None, compresslevel=9, **kwargs): + """Open bzip2 compressed tar archive name for reading or writing. + Appending is not allowed. 
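+
+           The archive is accessed through a bz2.BZ2File wrapper; the
+           ``compresslevel`` argument only affects the 'w' and 'x' modes.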
+ """ + if mode not in ("r", "w", "x"): + raise ValueError("mode must be 'r', 'w' or 'x'") + + try: + from bz2 import BZ2File + except ImportError: + raise CompressionError("bz2 module is not available") from None + + fileobj = BZ2File(fileobj or name, mode, compresslevel=compresslevel) + + try: + t = cls.taropen(name, mode, fileobj, **kwargs) + except (OSError, EOFError) as e: + fileobj.close() + if mode == 'r': + raise ReadError("not a bzip2 file") from e + raise + except: + fileobj.close() + raise + t._extfileobj = False + return t + + @classmethod + def xzopen(cls, name, mode="r", fileobj=None, preset=None, **kwargs): + """Open lzma compressed tar archive name for reading or writing. + Appending is not allowed. + """ + if mode not in ("r", "w", "x"): + raise ValueError("mode must be 'r', 'w' or 'x'") + + try: + from lzma import LZMAFile, LZMAError + except ImportError: + raise CompressionError("lzma module is not available") from None + + fileobj = LZMAFile(fileobj or name, mode, preset=preset) + + try: + t = cls.taropen(name, mode, fileobj, **kwargs) + except (LZMAError, EOFError) as e: + fileobj.close() + if mode == 'r': + raise ReadError("not an lzma file") from e + raise + except: + fileobj.close() + raise + t._extfileobj = False + return t + + # All *open() methods are registered here. + OPEN_METH = { + "tar": "taropen", # uncompressed tar + "gz": "gzopen", # gzip compressed tar + "bz2": "bz2open", # bzip2 compressed tar + "xz": "xzopen" # lzma compressed tar + } + + #-------------------------------------------------------------------------- + # The public methods which TarFile provides: + + def close(self): + """Close the TarFile. In write-mode, two finishing zero blocks are + appended to the archive. + """ + if self.closed: + return + + self.closed = True + try: + if self.mode in ("a", "w", "x"): + self.fileobj.write(NUL * (BLOCKSIZE * 2)) + self.offset += (BLOCKSIZE * 2) + # fill up the end with zero-blocks + # (like option -b20 for tar does) + blocks, remainder = divmod(self.offset, RECORDSIZE) + if remainder > 0: + self.fileobj.write(NUL * (RECORDSIZE - remainder)) + finally: + if not self._extfileobj: + self.fileobj.close() + + def getmember(self, name): + """Return a TarInfo object for member ``name``. If ``name`` can not be + found in the archive, KeyError is raised. If a member occurs more + than once in the archive, its last occurrence is assumed to be the + most up-to-date version. + """ + tarinfo = self._getmember(name.rstrip('/')) + if tarinfo is None: + raise KeyError("filename %r not found" % name) + return tarinfo + + def getmembers(self): + """Return the members of the archive as a list of TarInfo objects. The + list has the same order as the members in the archive. + """ + self._check() + if not self._loaded: # if we want to obtain a list of + self._load() # all members, we first have to + # scan the whole archive. + return self.members + + def getnames(self): + """Return the members of the archive as a list of their names. It has + the same order as the list returned by getmembers(). + """ + return [tarinfo.name for tarinfo in self.getmembers()] + + def gettarinfo(self, name=None, arcname=None, fileobj=None): + """Create a TarInfo object from the result of os.stat or equivalent + on an existing file. The file is either named by ``name``, or + specified as a file object ``fileobj`` with a file descriptor. 
If
+           given, ``arcname`` specifies an alternative name for the file in the
+           archive; otherwise, the name is taken from the 'name' attribute of
+           'fileobj', or the 'name' argument. The name should be a text
+           string.
+        """
+        self._check("awx")
+
+        # When fileobj is given, replace name by
+        # fileobj's real name.
+        if fileobj is not None:
+            name = fileobj.name
+
+        # Building the name of the member in the archive.
+        # Backward slashes are converted to forward slashes, and
+        # absolute paths are turned into relative paths.
+        if arcname is None:
+            arcname = name
+        drv, arcname = os.path.splitdrive(arcname)
+        arcname = arcname.replace(os.sep, "/")
+        arcname = arcname.lstrip("/")
+
+        # Now, fill the TarInfo object with
+        # information specific for the file.
+        tarinfo = self.tarinfo()
+        tarinfo.tarfile = self  # Not needed
+
+        # Use os.stat or os.lstat, depending on whether symlinks
+        # are to be resolved.
+        if fileobj is None:
+            if not self.dereference:
+                statres = os.lstat(name)
+            else:
+                statres = os.stat(name)
+        else:
+            statres = os.fstat(fileobj.fileno())
+        linkname = ""
+
+        stmd = statres.st_mode
+        if stat.S_ISREG(stmd):
+            inode = (statres.st_ino, statres.st_dev)
+            if not self.dereference and statres.st_nlink > 1 and \
+                    inode in self.inodes and arcname != self.inodes[inode]:
+                # Is it a hardlink to an already
+                # archived file?
+                type = LNKTYPE
+                linkname = self.inodes[inode]
+            else:
+                # The inode is added only if it's valid.
+                # For win32 it is always 0.
+                type = REGTYPE
+                if inode[0]:
+                    self.inodes[inode] = arcname
+        elif stat.S_ISDIR(stmd):
+            type = DIRTYPE
+        elif stat.S_ISFIFO(stmd):
+            type = FIFOTYPE
+        elif stat.S_ISLNK(stmd):
+            type = SYMTYPE
+            linkname = os.readlink(name)
+        elif stat.S_ISCHR(stmd):
+            type = CHRTYPE
+        elif stat.S_ISBLK(stmd):
+            type = BLKTYPE
+        else:
+            return None
+
+        # Fill the TarInfo object with all
+        # information we can get.
+        tarinfo.name = arcname
+        tarinfo.mode = stmd
+        tarinfo.uid = statres.st_uid
+        tarinfo.gid = statres.st_gid
+        if type == REGTYPE:
+            tarinfo.size = statres.st_size
+        else:
+            tarinfo.size = 0
+        tarinfo.mtime = statres.st_mtime
+        tarinfo.type = type
+        tarinfo.linkname = linkname
+        if pwd:
+            try:
+                tarinfo.uname = pwd.getpwuid(tarinfo.uid)[0]
+            except KeyError:
+                pass
+        if grp:
+            try:
+                tarinfo.gname = grp.getgrgid(tarinfo.gid)[0]
+            except KeyError:
+                pass
+
+        if type in (CHRTYPE, BLKTYPE):
+            if hasattr(os, "major") and hasattr(os, "minor"):
+                tarinfo.devmajor = os.major(statres.st_rdev)
+                tarinfo.devminor = os.minor(statres.st_rdev)
+        return tarinfo
+
+    def list(self, verbose=True, *, members=None):
+        """Print a table of contents to sys.stdout. If ``verbose`` is False, only
+           the names of the members are printed. If it is True, an `ls -l'-like
+           output is produced. ``members`` is optional and must be a subset of the
+           list returned by getmembers().
+        """
+        self._check()
+
+        if members is None:
+            members = self
+        for tarinfo in members:
+            if verbose:
+                if tarinfo.mode is None:
+                    _safe_print("??????????")
+                else:
+                    _safe_print(stat.filemode(tarinfo.mode))
+                _safe_print("%s/%s" % (tarinfo.uname or tarinfo.uid,
+                                       tarinfo.gname or tarinfo.gid))
+                if tarinfo.ischr() or tarinfo.isblk():
+                    _safe_print("%10s" %
+                            ("%d,%d" % (tarinfo.devmajor, tarinfo.devminor)))
+                else:
+                    _safe_print("%10d" % tarinfo.size)
+                if tarinfo.mtime is None:
+                    _safe_print("????-??-?? 
??:??:??") + else: + _safe_print("%d-%02d-%02d %02d:%02d:%02d" \ + % time.localtime(tarinfo.mtime)[:6]) + + _safe_print(tarinfo.name + ("/" if tarinfo.isdir() else "")) + + if verbose: + if tarinfo.issym(): + _safe_print("-> " + tarinfo.linkname) + if tarinfo.islnk(): + _safe_print("link to " + tarinfo.linkname) + print() + + def add(self, name, arcname=None, recursive=True, *, filter=None): + """Add the file ``name`` to the archive. ``name`` may be any type of file + (directory, fifo, symbolic link, etc.). If given, ``arcname`` + specifies an alternative name for the file in the archive. + Directories are added recursively by default. This can be avoided by + setting ``recursive`` to False. ``filter`` is a function + that expects a TarInfo object argument and returns the changed + TarInfo object, if it returns None the TarInfo object will be + excluded from the archive. + """ + self._check("awx") + + if arcname is None: + arcname = name + + # Skip if somebody tries to archive the archive... + if self.name is not None and os.path.abspath(name) == self.name: + self._dbg(2, "tarfile: Skipped %r" % name) + return + + self._dbg(1, name) + + # Create a TarInfo object from the file. + tarinfo = self.gettarinfo(name, arcname) + + if tarinfo is None: + self._dbg(1, "tarfile: Unsupported type %r" % name) + return + + # Change or exclude the TarInfo object. + if filter is not None: + tarinfo = filter(tarinfo) + if tarinfo is None: + self._dbg(2, "tarfile: Excluded %r" % name) + return + + # Append the tar header and data to the archive. + if tarinfo.isreg(): + with bltn_open(name, "rb") as f: + self.addfile(tarinfo, f) + + elif tarinfo.isdir(): + self.addfile(tarinfo) + if recursive: + for f in sorted(os.listdir(name)): + self.add(os.path.join(name, f), os.path.join(arcname, f), + recursive, filter=filter) + + else: + self.addfile(tarinfo) + + def addfile(self, tarinfo, fileobj=None): + """Add the TarInfo object ``tarinfo`` to the archive. If ``fileobj`` is + given, it should be a binary file, and tarinfo.size bytes are read + from it and added to the archive. You can create TarInfo objects + directly, or by using gettarinfo(). + """ + self._check("awx") + + tarinfo = copy.copy(tarinfo) + + buf = tarinfo.tobuf(self.format, self.encoding, self.errors) + self.fileobj.write(buf) + self.offset += len(buf) + bufsize=self.copybufsize + # If there's data to follow, append it. + if fileobj is not None: + copyfileobj(fileobj, self.fileobj, tarinfo.size, bufsize=bufsize) + blocks, remainder = divmod(tarinfo.size, BLOCKSIZE) + if remainder > 0: + self.fileobj.write(NUL * (BLOCKSIZE - remainder)) + blocks += 1 + self.offset += blocks * BLOCKSIZE + + self.members.append(tarinfo) + + def _get_filter_function(self, filter): + if filter is None: + filter = self.extraction_filter + if filter is None: + warnings.warn( + 'Python 3.14 will, by default, filter extracted tar ' + + 'archives and reject files or modify their metadata. ' + + 'Use the filter argument to control this behavior.', + DeprecationWarning) + return fully_trusted_filter + if isinstance(filter, str): + raise TypeError( + 'String names are not supported for ' + + 'TarFile.extraction_filter. 
Use a function such as ' + + 'tarfile.data_filter directly.') + return filter + if callable(filter): + return filter + try: + return _NAMED_FILTERS[filter] + except KeyError: + raise ValueError(f"filter {filter!r} not found") from None + + def extractall(self, path=".", members=None, *, numeric_owner=False, + filter=None): + """Extract all members from the archive to the current working + directory and set owner, modification time and permissions on + directories afterwards. `path' specifies a different directory + to extract to. `members' is optional and must be a subset of the + list returned by getmembers(). If `numeric_owner` is True, only + the numbers for user/group names are used and not the names. + + The `filter` function will be called on each member just + before extraction. + It can return a changed TarInfo or None to skip the member. + String names of common filters are accepted. + """ + directories = [] + + filter_function = self._get_filter_function(filter) + if members is None: + members = self + + for member in members: + tarinfo = self._get_extract_tarinfo(member, filter_function, path) + if tarinfo is None: + continue + if tarinfo.isdir(): + # For directories, delay setting attributes until later, + # since permissions can interfere with extraction and + # extracting contents can reset mtime. + directories.append(tarinfo) + self._extract_one(tarinfo, path, set_attrs=not tarinfo.isdir(), + numeric_owner=numeric_owner) + + # Reverse sort directories. + directories.sort(key=lambda a: a.name, reverse=True) + + # Set correct owner, mtime and filemode on directories. + for tarinfo in directories: + dirpath = os.path.join(path, tarinfo.name) + try: + self.chown(tarinfo, dirpath, numeric_owner=numeric_owner) + self.utime(tarinfo, dirpath) + self.chmod(tarinfo, dirpath) + except ExtractError as e: + self._handle_nonfatal_error(e) + + def extract(self, member, path="", set_attrs=True, *, numeric_owner=False, + filter=None): + """Extract a member from the archive to the current working directory, + using its full name. Its file information is extracted as accurately + as possible. `member' may be a filename or a TarInfo object. You can + specify a different directory using `path'. File attributes (owner, + mtime, mode) are set unless `set_attrs' is False. If `numeric_owner` + is True, only the numbers for user/group names are used and not + the names. + + The `filter` function will be called before extraction. + It can return a changed TarInfo or None to skip the member. + String names of common filters are accepted. + """ + filter_function = self._get_filter_function(filter) + tarinfo = self._get_extract_tarinfo(member, filter_function, path) + if tarinfo is not None: + self._extract_one(tarinfo, path, set_attrs, numeric_owner) + + def _get_extract_tarinfo(self, member, filter_function, path): + """Get filtered TarInfo (or None) from member, which might be a str""" + if isinstance(member, str): + tarinfo = self.getmember(member) + else: + tarinfo = member + + unfiltered = tarinfo + try: + tarinfo = filter_function(tarinfo, path) + except (OSError, FilterError) as e: + self._handle_fatal_error(e) + except ExtractError as e: + self._handle_nonfatal_error(e) + if tarinfo is None: + self._dbg(2, "tarfile: Excluded %r" % unfiltered.name) + return None + # Prepare the link target for makelink(). 
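+        # (Hard links are resolved against the extraction root here, so that
+        # makelink() can either link to the already-extracted file or fall
+        # back to extracting the target member itself.)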
+        if tarinfo.islnk():
+            tarinfo = copy.copy(tarinfo)
+            tarinfo._link_target = os.path.join(path, tarinfo.linkname)
+        return tarinfo
+
+    def _extract_one(self, tarinfo, path, set_attrs, numeric_owner):
+        """Extract from filtered tarinfo to disk"""
+        self._check("r")
+
+        try:
+            self._extract_member(tarinfo, os.path.join(path, tarinfo.name),
+                                 set_attrs=set_attrs,
+                                 numeric_owner=numeric_owner)
+        except OSError as e:
+            self._handle_fatal_error(e)
+        except ExtractError as e:
+            self._handle_nonfatal_error(e)
+
+    def _handle_nonfatal_error(self, e):
+        """Handle non-fatal error (ExtractError) according to errorlevel"""
+        if self.errorlevel > 1:
+            raise
+        else:
+            self._dbg(1, "tarfile: %s" % e)
+
+    def _handle_fatal_error(self, e):
+        """Handle "fatal" error according to self.errorlevel"""
+        if self.errorlevel > 0:
+            raise
+        elif isinstance(e, OSError):
+            if e.filename is None:
+                self._dbg(1, "tarfile: %s" % e.strerror)
+            else:
+                self._dbg(1, "tarfile: %s %r" % (e.strerror, e.filename))
+        else:
+            self._dbg(1, "tarfile: %s %s" % (type(e).__name__, e))
+
+    def extractfile(self, member):
+        """Extract a member from the archive as a file object. ``member`` may be
+           a filename or a TarInfo object. If ``member`` is a regular file or
+           a link, an io.BufferedReader object is returned. For all other
+           existing members, None is returned. If ``member`` does not appear
+           in the archive, KeyError is raised.
+        """
+        self._check("r")
+
+        if isinstance(member, str):
+            tarinfo = self.getmember(member)
+        else:
+            tarinfo = member
+
+        if tarinfo.isreg() or tarinfo.type not in SUPPORTED_TYPES:
+            # Members with unknown types are treated as regular files.
+            return self.fileobject(self, tarinfo)
+
+        elif tarinfo.islnk() or tarinfo.issym():
+            if isinstance(self.fileobj, _Stream):
+                # A small but ugly workaround for the case that someone tries
+                # to extract a (sym)link as a file-object from a non-seekable
+                # stream of tar blocks.
+                raise StreamError("cannot extract (sym)link as file object")
+            else:
+                # A (sym)link's file object is its target's file object.
+                return self.extractfile(self._find_link_target(tarinfo))
+        else:
+            # If there's no data associated with the member (directory, chrdev,
+            # blkdev, etc.), return None instead of a file object.
+            return None
+
+    def _extract_member(self, tarinfo, targetpath, set_attrs=True,
+                        numeric_owner=False):
+        """Extract the TarInfo object tarinfo to a physical
+           file called targetpath.
+        """
+        # Fetch the TarInfo object for the given name
+        # and build the destination pathname, replacing
+        # forward slashes with platform-specific separators.
+        targetpath = targetpath.rstrip("/")
+        targetpath = targetpath.replace("/", os.sep)
+
+        # Create all upper directories.
+        upperdirs = os.path.dirname(targetpath)
+        if upperdirs and not os.path.exists(upperdirs):
+            # Create directories that are not part of the archive with
+            # default permissions.
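+            # (Such implicit parent directories are not archive members, so
+            # extractall() will not reset their owner, mode or mtime later.)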
+ os.makedirs(upperdirs) + + if tarinfo.islnk() or tarinfo.issym(): + self._dbg(1, "%s -> %s" % (tarinfo.name, tarinfo.linkname)) + else: + self._dbg(1, tarinfo.name) + + if tarinfo.isreg(): + self.makefile(tarinfo, targetpath) + elif tarinfo.isdir(): + self.makedir(tarinfo, targetpath) + elif tarinfo.isfifo(): + self.makefifo(tarinfo, targetpath) + elif tarinfo.ischr() or tarinfo.isblk(): + self.makedev(tarinfo, targetpath) + elif tarinfo.islnk() or tarinfo.issym(): + self.makelink(tarinfo, targetpath) + elif tarinfo.type not in SUPPORTED_TYPES: + self.makeunknown(tarinfo, targetpath) + else: + self.makefile(tarinfo, targetpath) + + if set_attrs: + self.chown(tarinfo, targetpath, numeric_owner) + if not tarinfo.issym(): + self.chmod(tarinfo, targetpath) + self.utime(tarinfo, targetpath) + + #-------------------------------------------------------------------------- + # Below are the different file methods. They are called via + # _extract_member() when extract() is called. They can be replaced in a + # subclass to implement other functionality. + + def makedir(self, tarinfo, targetpath): + """Make a directory called targetpath. + """ + try: + if tarinfo.mode is None: + # Use the system's default mode + os.mkdir(targetpath) + else: + # Use a safe mode for the directory, the real mode is set + # later in _extract_member(). + os.mkdir(targetpath, 0o700) + except FileExistsError: + if not os.path.isdir(targetpath): + raise + + def makefile(self, tarinfo, targetpath): + """Make a file called targetpath. + """ + source = self.fileobj + source.seek(tarinfo.offset_data) + bufsize = self.copybufsize + with bltn_open(targetpath, "wb") as target: + if tarinfo.sparse is not None: + for offset, size in tarinfo.sparse: + target.seek(offset) + copyfileobj(source, target, size, ReadError, bufsize) + target.seek(tarinfo.size) + target.truncate() + else: + copyfileobj(source, target, tarinfo.size, ReadError, bufsize) + + def makeunknown(self, tarinfo, targetpath): + """Make a file from a TarInfo object with an unknown type + at targetpath. + """ + self.makefile(tarinfo, targetpath) + self._dbg(1, "tarfile: Unknown file type %r, " \ + "extracted as regular file." % tarinfo.type) + + def makefifo(self, tarinfo, targetpath): + """Make a fifo called targetpath. + """ + if hasattr(os, "mkfifo"): + os.mkfifo(targetpath) + else: + raise ExtractError("fifo not supported by system") + + def makedev(self, tarinfo, targetpath): + """Make a character or block device called targetpath. + """ + if not hasattr(os, "mknod") or not hasattr(os, "makedev"): + raise ExtractError("special devices not supported by system") + + mode = tarinfo.mode + if mode is None: + # Use mknod's default + mode = 0o600 + if tarinfo.isblk(): + mode |= stat.S_IFBLK + else: + mode |= stat.S_IFCHR + + os.mknod(targetpath, mode, + os.makedev(tarinfo.devmajor, tarinfo.devminor)) + + def makelink(self, tarinfo, targetpath): + """Make a (symbolic) link called targetpath. If it cannot be created + (platform limitation), we try to make a copy of the referenced file + instead of a link. + """ + try: + # For systems that support symbolic and hard links. + if tarinfo.issym(): + if os.path.lexists(targetpath): + # Avoid FileExistsError on following os.symlink. 
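+                    # (lexists() also catches dangling symlinks that a plain
+                    # exists() check would miss.)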
+ os.unlink(targetpath) + os.symlink(tarinfo.linkname, targetpath) + else: + if os.path.exists(tarinfo._link_target): + os.link(tarinfo._link_target, targetpath) + else: + self._extract_member(self._find_link_target(tarinfo), + targetpath) + except symlink_exception: + try: + self._extract_member(self._find_link_target(tarinfo), + targetpath) + except KeyError: + raise ExtractError("unable to resolve link inside archive") from None + + def chown(self, tarinfo, targetpath, numeric_owner): + """Set owner of targetpath according to tarinfo. If numeric_owner + is True, use .gid/.uid instead of .gname/.uname. If numeric_owner + is False, fall back to .gid/.uid when the search based on name + fails. + """ + if hasattr(os, "geteuid") and os.geteuid() == 0: + # We have to be root to do so. + g = tarinfo.gid + u = tarinfo.uid + if not numeric_owner: + try: + if grp and tarinfo.gname: + g = grp.getgrnam(tarinfo.gname)[2] + except KeyError: + pass + try: + if pwd and tarinfo.uname: + u = pwd.getpwnam(tarinfo.uname)[2] + except KeyError: + pass + if g is None: + g = -1 + if u is None: + u = -1 + try: + if tarinfo.issym() and hasattr(os, "lchown"): + os.lchown(targetpath, u, g) + else: + os.chown(targetpath, u, g) + except OSError as e: + raise ExtractError("could not change owner") from e + + def chmod(self, tarinfo, targetpath): + """Set file permissions of targetpath according to tarinfo. + """ + if tarinfo.mode is None: + return + try: + os.chmod(targetpath, tarinfo.mode) + except OSError as e: + raise ExtractError("could not change mode") from e + + def utime(self, tarinfo, targetpath): + """Set modification time of targetpath according to tarinfo. + """ + mtime = tarinfo.mtime + if mtime is None: + return + if not hasattr(os, 'utime'): + return + try: + os.utime(targetpath, (mtime, mtime)) + except OSError as e: + raise ExtractError("could not change modification time") from e + + #-------------------------------------------------------------------------- + def next(self): + """Return the next member of the archive as a TarInfo object, when + TarFile is opened for reading. Return None if there is no more + available. + """ + self._check("ra") + if self.firstmember is not None: + m = self.firstmember + self.firstmember = None + return m + + # Advance the file pointer. + if self.offset != self.fileobj.tell(): + if self.offset == 0: + return None + self.fileobj.seek(self.offset - 1) + if not self.fileobj.read(1): + raise ReadError("unexpected end of data") + + # Read the next block. 
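+        # Loop until a complete, valid header has been read. With
+        # ignore_zeros set, empty and invalid blocks are skipped one
+        # BLOCKSIZE at a time; without it, they end the member list (or
+        # raise ReadError when they appear at offset 0).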
+ tarinfo = None + while True: + try: + tarinfo = self.tarinfo.fromtarfile(self) + except EOFHeaderError as e: + if self.ignore_zeros: + self._dbg(2, "0x%X: %s" % (self.offset, e)) + self.offset += BLOCKSIZE + continue + except InvalidHeaderError as e: + if self.ignore_zeros: + self._dbg(2, "0x%X: %s" % (self.offset, e)) + self.offset += BLOCKSIZE + continue + elif self.offset == 0: + raise ReadError(str(e)) from None + except EmptyHeaderError: + if self.offset == 0: + raise ReadError("empty file") from None + except TruncatedHeaderError as e: + if self.offset == 0: + raise ReadError(str(e)) from None + except SubsequentHeaderError as e: + raise ReadError(str(e)) from None + except Exception as e: + try: + import zlib + if isinstance(e, zlib.error): + raise ReadError(f'zlib error: {e}') from None + else: + raise e + except ImportError: + raise e + break + + if tarinfo is not None: + self.members.append(tarinfo) + else: + self._loaded = True + + return tarinfo + + #-------------------------------------------------------------------------- + # Little helper methods: + + def _getmember(self, name, tarinfo=None, normalize=False): + """Find an archive member by name from bottom to top. + If tarinfo is given, it is used as the starting point. + """ + # Ensure that all members have been loaded. + members = self.getmembers() + + # Limit the member search list up to tarinfo. + skipping = False + if tarinfo is not None: + try: + index = members.index(tarinfo) + except ValueError: + # The given starting point might be a (modified) copy. + # We'll later skip members until we find an equivalent. + skipping = True + else: + # Happy fast path + members = members[:index] + + if normalize: + name = os.path.normpath(name) + + for member in reversed(members): + if skipping: + if tarinfo.offset == member.offset: + skipping = False + continue + if normalize: + member_name = os.path.normpath(member.name) + else: + member_name = member.name + + if name == member_name: + return member + + if skipping: + # Starting point was not found + raise ValueError(tarinfo) + + def _load(self): + """Read through the entire archive file and look for readable + members. + """ + while self.next() is not None: + pass + self._loaded = True + + def _check(self, mode=None): + """Check if TarFile is still open, and if the operation's mode + corresponds to TarFile's mode. + """ + if self.closed: + raise OSError("%s is closed" % self.__class__.__name__) + if mode is not None and self.mode not in mode: + raise OSError("bad operation for mode %r" % self.mode) + + def _find_link_target(self, tarinfo): + """Find the target member of a symlink or hardlink member in the + archive. + """ + if tarinfo.issym(): + # Always search the entire archive. + linkname = "/".join(filter(None, (os.path.dirname(tarinfo.name), tarinfo.linkname))) + limit = None + else: + # Search the archive before the link, because a hard link is + # just a reference to an already archived file. + linkname = tarinfo.linkname + limit = tarinfo + + member = self._getmember(linkname, tarinfo=limit, normalize=True) + if member is None: + raise KeyError("linkname %r not found" % linkname) + return member + + def __iter__(self): + """Provide an iterator object. + """ + if self._loaded: + yield from self.members + return + + # Yield items using TarFile's next() method. + # When all members have been read, set TarFile as _loaded. 
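+        # (Already-parsed members are served from self.members by index, so
+        # iteration stays consistent even when getmembers() fills the list
+        # in the middle of a loop.)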
+        index = 0
+        # Fix for SF #1100429: Under rare circumstances it can
+        # happen that getmembers() is called during iteration,
+        # which will have already exhausted the next() method.
+        if self.firstmember is not None:
+            tarinfo = self.next()
+            index += 1
+            yield tarinfo
+
+        while True:
+            if index < len(self.members):
+                tarinfo = self.members[index]
+            elif not self._loaded:
+                tarinfo = self.next()
+                if not tarinfo:
+                    self._loaded = True
+                    return
+            else:
+                return
+            index += 1
+            yield tarinfo
+
+    def _dbg(self, level, msg):
+        """Write debugging output to sys.stderr.
+        """
+        if level <= self.debug:
+            print(msg, file=sys.stderr)
+
+    def __enter__(self):
+        self._check()
+        return self
+
+    def __exit__(self, type, value, traceback):
+        if type is None:
+            self.close()
+        else:
+            # An exception occurred. We must not call close() because
+            # it would try to write end-of-archive blocks and padding.
+            if not self._extfileobj:
+                self.fileobj.close()
+            self.closed = True
+
+#--------------------
+# exported functions
+#--------------------
+
+def is_tarfile(name):
+    """Return True if name points to a tar archive that we
+       are able to handle, else return False.
+
+       'name' should be a string, file, or file-like object.
+    """
+    try:
+        if hasattr(name, "read"):
+            pos = name.tell()
+            t = open(fileobj=name)
+            name.seek(pos)
+        else:
+            t = open(name)
+        t.close()
+        return True
+    except TarError:
+        return False
+
+open = TarFile.open
+
+
+def main():
+    import argparse
+
+    description = 'A simple command-line interface for tarfile module.'
+    parser = argparse.ArgumentParser(description=description)
+    parser.add_argument('-v', '--verbose', action='store_true', default=False,
+                        help='Verbose output')
+    parser.add_argument('--filter', metavar='<filtername>',
+                        choices=_NAMED_FILTERS,
+                        help='Filter for extraction')
+
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument('-l', '--list', metavar='<tarfile>',
+                       help='Show listing of a tarfile')
+    group.add_argument('-e', '--extract', nargs='+',
+                       metavar=('<tarfile>', '<output_dir>'),
+                       help='Extract tarfile into target dir')
+    group.add_argument('-c', '--create', nargs='+',
+                       metavar=('<name>', '<file>'),
+                       help='Create tarfile from sources')
+    group.add_argument('-t', '--test', metavar='<tarfile>',
+                       help='Test if a tarfile is valid')
+
+    args = parser.parse_args()
+
+    if args.filter and args.extract is None:
+        parser.exit(1, '--filter is only valid for extraction\n')
+
+    if args.test is not None:
+        src = args.test
+        if is_tarfile(src):
+            with open(src, 'r') as tar:
+                tar.getmembers()
+                print(tar.getmembers(), file=sys.stderr)
+            if args.verbose:
+                print('{!r} is a tar archive.'.format(src))
+        else:
+            parser.exit(1, '{!r} is not a tar archive.\n'.format(src))
+
+    elif args.list is not None:
+        src = args.list
+        if is_tarfile(src):
+            with TarFile.open(src, 'r:*') as tf:
+                tf.list(verbose=args.verbose)
+        else:
+            parser.exit(1, '{!r} is not a tar archive.\n'.format(src))
+
+    elif args.extract is not None:
+        if len(args.extract) == 1:
+            src = args.extract[0]
+            curdir = os.curdir
+        elif len(args.extract) == 2:
+            src, curdir = args.extract
+        else:
+            parser.exit(1, parser.format_help())
+
+        if is_tarfile(src):
+            with TarFile.open(src, 'r:*') as tf:
+                tf.extractall(path=curdir, filter=args.filter)
+            if args.verbose:
+                if curdir == '.':
+                    msg = '{!r} file is extracted.'.format(src)
+                else:
+                    msg = ('{!r} file is extracted '
+                           'into {!r} directory.').format(src, curdir)
+                print(msg)
+        else:
+            parser.exit(1, '{!r} is not a tar archive.\n'.format(src))
+
+    elif args.create is not None:
+        tar_name = args.create.pop(0)
+        _, ext = os.path.splitext(tar_name)
+        compressions = {
+            # gz
+            '.gz': 'gz',
+            '.tgz': 'gz',
+            # xz
+            '.xz': 'xz',
+            '.txz': 'xz',
+            # bz2
+            '.bz2': 'bz2',
+            '.tbz': 'bz2',
+            '.tbz2': 'bz2',
+            '.tb2': 'bz2',
+        }
+        tar_mode = 'w:' + compressions[ext] if ext in compressions else 'w'
+        tar_files = args.create
+
+        with TarFile.open(tar_name, tar_mode) as tf:
+            for file_name in tar_files:
+                tf.add(file_name)
+
+        if args.verbose:
+            print('{!r} file created.'.format(tar_name))
+
+if __name__ == '__main__':
+    main()
diff --git a/lib/backports/tarfile/__main__.py b/lib/backports/tarfile/__main__.py
new file mode 100644
index 00000000..daf55090
--- /dev/null
+++ b/lib/backports/tarfile/__main__.py
@@ -0,0 +1,5 @@
+from . import main
+
+
+if __name__ == '__main__':
+    main()
diff --git a/lib/inflect/compat/__init__.py b/lib/backports/tarfile/compat/__init__.py
similarity index 100%
rename from lib/inflect/compat/__init__.py
rename to lib/backports/tarfile/compat/__init__.py
diff --git a/lib/backports/tarfile/compat/py38.py b/lib/backports/tarfile/compat/py38.py
new file mode 100644
index 00000000..20fbbfc1
--- /dev/null
+++ b/lib/backports/tarfile/compat/py38.py
@@ -0,0 +1,24 @@
+import sys
+
+
+if sys.version_info < (3, 9):
+
+    def removesuffix(self, suffix):
+        # suffix='' should not call self[:-0].
+        if suffix and self.endswith(suffix):
+            return self[: -len(suffix)]
+        else:
+            return self[:]
+
+    def removeprefix(self, prefix):
+        if self.startswith(prefix):
+            return self[len(prefix) :]
+        else:
+            return self[:]
+else:
+
+    def removesuffix(self, suffix):
+        return self.removesuffix(suffix)
+
+    def removeprefix(self, prefix):
+        return self.removeprefix(prefix)
diff --git a/lib/future/__init__.py b/lib/future/__init__.py
index b609299a..b097fd81 100644
--- a/lib/future/__init__.py
+++ b/lib/future/__init__.py
@@ -52,7 +52,7 @@ Automatic conversion
 --------------------
 
 An included script called `futurize
-<http://python-future.org/automatic_conversion.html>`_ aids in converting
+<https://python-future.org/automatic_conversion.html>`_ aids in converting
 code (from either Python 2 or Python 3) to code compatible with both
 platforms. It is similar to ``python-modernize`` but goes further in
 providing Python 3 compatibility through the use of the backported types
@@ -62,21 +62,20 @@ and builtin functions in ``future``.
 
 Documentation
 -------------
 
-See: http://python-future.org
+See: https://python-future.org
 
 Credits
 -------
 
 :Author: Ed Schofield, Jordan M. Adler, et al
-:Sponsor: Python Charmers Pty Ltd, Australia, and Python Charmers Pte
-          Ltd, Singapore. http://pythoncharmers.com
-:Others: See docs/credits.rst or http://python-future.org/credits.html
+:Sponsor: Python Charmers: https://pythoncharmers.com
+:Others: See docs/credits.rst or https://python-future.org/credits.html
 
 Licensing
 ---------
-Copyright 2013-2019 Python Charmers Pty Ltd, Australia.
+Copyright 2013-2024 Python Charmers, Australia.
 The software is distributed under an MIT licence. See LICENSE.txt.
 
 """
@@ -84,10 +83,10 @@
__title__ = 'future' __author__ = 'Ed Schofield' __license__ = 'MIT' -__copyright__ = 'Copyright 2013-2019 Python Charmers Pty Ltd' -__ver_major__ = 0 -__ver_minor__ = 18 -__ver_patch__ = 3 +__copyright__ = 'Copyright 2013-2024 Python Charmers (https://pythoncharmers.com)' +__ver_major__ = 1 +__ver_minor__ = 0 +__ver_patch__ = 0 __ver_sub__ = '' __version__ = "%d.%d.%d%s" % (__ver_major__, __ver_minor__, __ver_patch__, __ver_sub__) diff --git a/lib/future/backports/datetime.py b/lib/future/backports/datetime.py index 3261014e..8cd62ddf 100644 --- a/lib/future/backports/datetime.py +++ b/lib/future/backports/datetime.py @@ -689,7 +689,7 @@ class date(object): @classmethod def fromordinal(cls, n): - """Contruct a date from a proleptic Gregorian ordinal. + """Construct a date from a proleptic Gregorian ordinal. January 1 of year 1 is day 1. Only the year, month and day are non-zero in the result. diff --git a/lib/future/backports/email/_header_value_parser.py b/lib/future/backports/email/_header_value_parser.py index 43957edc..59b1b318 100644 --- a/lib/future/backports/email/_header_value_parser.py +++ b/lib/future/backports/email/_header_value_parser.py @@ -2867,7 +2867,7 @@ def parse_content_type_header(value): _find_mime_parameters(ctype, value) return ctype ctype.append(token) - # XXX: If we really want to follow the formal grammer we should make + # XXX: If we really want to follow the formal grammar we should make # mantype and subtype specialized TokenLists here. Probably not worth it. if not value or value[0] != '/': ctype.defects.append(errors.InvalidHeaderDefect( diff --git a/lib/future/backports/email/parser.py b/lib/future/backports/email/parser.py index df1c6e28..79f0e5a3 100644 --- a/lib/future/backports/email/parser.py +++ b/lib/future/backports/email/parser.py @@ -26,7 +26,7 @@ class Parser(object): textual representation of the message. The string must be formatted as a block of RFC 2822 headers and header - continuation lines, optionally preceeded by a `Unix-from' header. The + continuation lines, optionally preceded by a `Unix-from' header. The header block is terminated either by the end of the string or by a blank line. @@ -92,7 +92,7 @@ class BytesParser(object): textual representation of the message. The input must be formatted as a block of RFC 2822 headers and header - continuation lines, optionally preceeded by a `Unix-from' header. The + continuation lines, optionally preceded by a `Unix-from' header. The header block is terminated either by the end of the input or by a blank line. diff --git a/lib/future/backports/http/cookiejar.py b/lib/future/backports/http/cookiejar.py index 0ad80a02..a39242c0 100644 --- a/lib/future/backports/http/cookiejar.py +++ b/lib/future/backports/http/cookiejar.py @@ -1851,7 +1851,7 @@ def lwp_cookie_str(cookie): class LWPCookieJar(FileCookieJar): """ The LWPCookieJar saves a sequence of "Set-Cookie3" lines. - "Set-Cookie3" is the format used by the libwww-perl libary, not known + "Set-Cookie3" is the format used by the libwww-perl library, not known to be compatible with any browser, but which is easy to read and doesn't lose information about RFC 2965 cookies. 
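
The cookiejar hunk above touches LWPCookieJar, whose "Set-Cookie3" format round-trips RFC 2965 cookie data. A minimal usage sketch, assuming the stdlib-compatible API that this backport mirrors ('cookies.lwp' is a hypothetical path):

from http.cookiejar import LWPCookieJar  # the backport mirrors this stdlib API
from urllib.request import HTTPCookieProcessor, build_opener

jar = LWPCookieJar('cookies.lwp')             # hypothetical file name
opener = build_opener(HTTPCookieProcessor(jar))
opener.open('https://example.com/')           # responses populate the jar
jar.save(ignore_discard=True)                 # written out as "Set-Cookie3" lines
jar.load('cookies.lwp')                       # reads them back without data loss
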
diff --git a/lib/future/backports/test/support.py b/lib/future/backports/test/support.py
index 1999e208..6639372b 100644
--- a/lib/future/backports/test/support.py
+++ b/lib/future/backports/test/support.py
@@ -28,7 +28,6 @@ import importlib
 # import collections.abc  # not present on Py2.7
 import re
 import subprocess
-import imp
 import time
 try:
     import sysconfig
@@ -341,37 +340,6 @@ def rmtree(path):
         if error.errno != errno.ENOENT:
             raise
 
-def make_legacy_pyc(source):
-    """Move a PEP 3147 pyc/pyo file to its legacy pyc/pyo location.
-
-    The choice of .pyc or .pyo extension is done based on the __debug__ flag
-    value.
-
-    :param source: The file system path to the source file. The source file
-        does not need to exist, however the PEP 3147 pyc file must exist.
-    :return: The file system path to the legacy pyc file.
-    """
-    pyc_file = imp.cache_from_source(source)
-    up_one = os.path.dirname(os.path.abspath(source))
-    legacy_pyc = os.path.join(up_one, source + ('c' if __debug__ else 'o'))
-    os.rename(pyc_file, legacy_pyc)
-    return legacy_pyc
-
-def forget(modname):
-    """'Forget' a module was ever imported.
-
-    This removes the module from sys.modules and deletes any PEP 3147 or
-    legacy .pyc and .pyo files.
-    """
-    unload(modname)
-    for dirname in sys.path:
-        source = os.path.join(dirname, modname + '.py')
-        # It doesn't matter if they exist or not, unlink all possible
-        # combinations of PEP 3147 and legacy pyc and pyo files.
-        unlink(source + 'c')
-        unlink(source + 'o')
-        unlink(imp.cache_from_source(source, debug_override=True))
-        unlink(imp.cache_from_source(source, debug_override=False))
 
 # On some platforms, should not run gui test even if it is allowed
 # in `use_resources'.
diff --git a/lib/future/backports/xmlrpc/client.py b/lib/future/backports/xmlrpc/client.py
index b78e5bad..0838f61a 100644
--- a/lib/future/backports/xmlrpc/client.py
+++ b/lib/future/backports/xmlrpc/client.py
@@ -134,10 +134,11 @@ from __future__ import (absolute_import, division, print_function,
 from future.builtins import bytes, dict, int, range, str
 
 import base64
-# Py2.7 compatibility hack
-base64.encodebytes = base64.encodestring
-base64.decodebytes = base64.decodestring
 import sys
+if sys.version_info < (3, 9):
+    # Py2.7 compatibility hack
+    base64.encodebytes = base64.encodestring
+    base64.decodebytes = base64.decodestring
 import time
 from datetime import datetime
 from future.backports.http import client as http_client
@@ -1251,7 +1252,7 @@ class Transport(object):
     # Send HTTP request.
     #
     # @param host Host descriptor (URL or (URL, x509 info) tuple).
-    # @param handler Targer RPC handler (a path relative to host)
+    # @param handler Target RPC handler (a path relative to host)
     # @param request_body The XML-RPC request body
     # @param debug Enable debugging if debug is true.
     # @return An HTTPConnection.
diff --git a/lib/future/builtins/__init__.py b/lib/future/builtins/__init__.py
index 8bc1649d..1734cd45 100644
--- a/lib/future/builtins/__init__.py
+++ b/lib/future/builtins/__init__.py
@@ -2,7 +2,7 @@
 A module that brings in equivalents of the new and modified Python 3
 builtins into Py2. Has no effect on Py3.
 
-See the docs `here <http://python-future.org/what-else.html>`_
+See the docs `here <https://python-future.org/what-else.html>`_
 (``docs/what-else.rst``) for more information.
 
""" diff --git a/lib/future/moves/_dummy_thread.py b/lib/future/moves/_dummy_thread.py index 688d249b..6633f42e 100644 --- a/lib/future/moves/_dummy_thread.py +++ b/lib/future/moves/_dummy_thread.py @@ -1,8 +1,13 @@ from __future__ import absolute_import -from future.utils import PY3 +from future.utils import PY3, PY39_PLUS -if PY3: - from _dummy_thread import * + +if PY39_PLUS: + # _dummy_thread and dummy_threading modules were both deprecated in + # Python 3.7 and removed in Python 3.9 + from _thread import * +elif PY3: + from _dummy_thread import * else: __future_module__ = True from dummy_thread import * diff --git a/lib/future/moves/multiprocessing.py b/lib/future/moves/multiprocessing.py new file mode 100644 index 00000000..a871b676 --- /dev/null +++ b/lib/future/moves/multiprocessing.py @@ -0,0 +1,7 @@ +from __future__ import absolute_import +from future.utils import PY3 + +from multiprocessing import * +if not PY3: + __future_module__ = True + from multiprocessing.queues import SimpleQueue diff --git a/lib/future/moves/test/support.py b/lib/future/moves/test/support.py index e9aa0f48..f70c9d7d 100644 --- a/lib/future/moves/test/support.py +++ b/lib/future/moves/test/support.py @@ -1,9 +1,18 @@ from __future__ import absolute_import + +import sys + from future.standard_library import suspend_hooks from future.utils import PY3 if PY3: from test.support import * + if sys.version_info[:2] >= (3, 10): + from test.support.os_helper import ( + EnvironmentVarGuard, + TESTFN, + ) + from test.support.warnings_helper import check_warnings else: __future_module__ = True with suspend_hooks(): diff --git a/lib/future/standard_library/__init__.py b/lib/future/standard_library/__init__.py index cff02f95..d467aaf4 100644 --- a/lib/future/standard_library/__init__.py +++ b/lib/future/standard_library/__init__.py @@ -17,7 +17,7 @@ And then these normal Py3 imports work on both Py3 and Py2:: import socketserver import winreg # on Windows only import test.support - import html, html.parser, html.entites + import html, html.parser, html.entities import http, http.client, http.server import http.cookies, http.cookiejar import urllib.parse, urllib.request, urllib.response, urllib.error, urllib.robotparser @@ -33,6 +33,7 @@ And then these normal Py3 imports work on both Py3 and Py2:: from collections import OrderedDict, Counter, ChainMap # even on Py2.6 from subprocess import getoutput, getstatusoutput from subprocess import check_output # even on Py2.6 + from multiprocessing import SimpleQueue (The renamed modules and functions are still available under their old names on Python 2.) 
@@ -62,9 +63,12 @@ from __future__ import absolute_import, division, print_function
 
 import sys
 import logging
-import imp
+# imp was deprecated in python 3.6
+if sys.version_info >= (3, 6):
+    import importlib as imp
+else:
+    import imp
 import contextlib
-import types
 import copy
 import os
 
@@ -108,6 +112,7 @@ RENAMES = {
     'future.moves.socketserver': 'socketserver',
     'ConfigParser': 'configparser',
     'repr': 'reprlib',
+    'multiprocessing.queues': 'multiprocessing',
     # 'FileDialog': 'tkinter.filedialog',
     # 'tkFileDialog': 'tkinter.filedialog',
     # 'SimpleDialog': 'tkinter.simpledialog',
@@ -125,7 +130,7 @@ RENAMES = {
     # 'Tkinter': 'tkinter',
     '_winreg': 'winreg',
     'thread': '_thread',
-    'dummy_thread': '_dummy_thread',
+    'dummy_thread': '_dummy_thread' if sys.version_info < (3, 9) else '_thread',
     # 'anydbm': 'dbm',   # causes infinite import loop
     # 'whichdb': 'dbm',  # causes infinite import loop
     # anydbm and whichdb are handled by fix_imports2
@@ -184,6 +189,7 @@ MOVES = [('collections', 'UserList', 'UserList', 'UserList'),
          ('itertools', 'filterfalse','itertools', 'ifilterfalse'),
          ('itertools', 'zip_longest','itertools', 'izip_longest'),
          ('sys', 'intern','__builtin__', 'intern'),
+         ('multiprocessing', 'SimpleQueue', 'multiprocessing.queues', 'SimpleQueue'),
          # The re module has no ASCII flag in Py2, but this is the default.
          # Set re.ASCII to a zero constant. stat.ST_MODE just happens to be one
          # (and it exists on Py2.6+).
diff --git a/lib/future/types/newint.py b/lib/future/types/newint.py
index 04a411e9..ebc5715e 100644
--- a/lib/future/types/newint.py
+++ b/lib/future/types/newint.py
@@ -223,9 +223,11 @@ class newint(with_metaclass(BaseNewInt, long)):
 
     def __rpow__(self, other):
         value = super(newint, self).__rpow__(other)
-        if value is NotImplemented:
+        if isint(value):
+            return newint(value)
+        elif value is NotImplemented:
             return other ** long(self)
-        return newint(value)
+        return value
 
     def __lshift__(self, other):
         if not isint(other):
@@ -318,7 +320,7 @@ class newint(with_metaclass(BaseNewInt, long)):
             bits = length * 8
             num = (2**bits) + self
             if num <= 0:
-                raise OverflowError("int too smal to convert")
+                raise OverflowError("int too small to convert")
         else:
             if self < 0:
                 raise OverflowError("can't convert negative int to unsigned")
diff --git a/lib/future/types/newrange.py b/lib/future/types/newrange.py
index 6d4ebe2f..dc5eb802 100644
--- a/lib/future/types/newrange.py
+++ b/lib/future/types/newrange.py
@@ -105,7 +105,7 @@ class newrange(Sequence):
             raise ValueError('%r is not in range' % value)
 
     def count(self, value):
-        """Return the number of ocurrences of integer `value`
+        """Return the number of occurrences of integer `value`
         in the sequence this range represents."""
         # a value can occur exactly zero or one times
         return int(value in self)
diff --git a/lib/inflect/__init__.py b/lib/inflect/__init__.py
index b638c6b8..d0eded16 100644
--- a/lib/inflect/__init__.py
+++ b/lib/inflect/__init__.py
@@ -3,6 +3,8 @@ inflect: english language inflection
  - correctly generate plurals, ordinals, indefinite articles
  - convert numbers to words
 
+Copyright (C) 2010 Paul Dyson
+
 Based upon the Perl module
 `Lingua::EN::Inflect <https://metacpan.org/pod/Lingua::EN::Inflect>`_.
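
The inflect hunks that follow replace pydantic-based validation with typeguard and extend compound-noun handling. For orientation, a usage sketch of the public engine API (expected outputs shown as comments, per inflect's documented behavior):

import inflect

p = inflect.engine()
p.plural('pair of scissors')    # 'pairs of scissors' (long-compound handling added below)
p.ordinal(3)                    # '3rd'
p.number_to_words('1234')       # 'one thousand, two hundred and thirty-four'
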
@@ -50,34 +52,33 @@ Exceptions: """ +from __future__ import annotations + import ast -import re -import functools import collections import contextlib +import functools +import itertools +import re +from numbers import Number from typing import ( + TYPE_CHECKING, + Any, + Callable, Dict, - Union, - Optional, Iterable, List, Match, - Tuple, - Callable, + Optional, Sequence, + Tuple, + Union, cast, - Any, ) -from typing_extensions import Literal -from numbers import Number - -from pydantic import Field -from typing_extensions import Annotated - - -from .compat.pydantic1 import validate_call -from .compat.pydantic import same_method +from more_itertools import windowed_complete +from typeguard import typechecked +from typing_extensions import Annotated, Literal class UnknownClassicalModeError(Exception): @@ -258,9 +259,9 @@ si_sb_irregular_compound = {v: k for (k, v) in pl_sb_irregular_compound.items()} for k in list(si_sb_irregular_compound): if "|" in k: k1, k2 = k.split("|") - si_sb_irregular_compound[k1] = si_sb_irregular_compound[ - k2 - ] = si_sb_irregular_compound[k] + si_sb_irregular_compound[k1] = si_sb_irregular_compound[k2] = ( + si_sb_irregular_compound[k] + ) del si_sb_irregular_compound[k] # si_sb_irregular_keys = enclose('|'.join(si_sb_irregular.keys())) @@ -1597,7 +1598,7 @@ pl_prep_bysize = bysize(pl_prep_list_da) pl_prep = enclose("|".join(pl_prep_list_da)) -pl_sb_prep_dual_compound = fr"(.*?)((?:-|\s+)(?:{pl_prep})(?:-|\s+))a(?:-|\s+)(.*)" +pl_sb_prep_dual_compound = rf"(.*?)((?:-|\s+)(?:{pl_prep})(?:-|\s+))a(?:-|\s+)(.*)" singular_pronoun_genders = { @@ -1764,7 +1765,7 @@ plverb_ambiguous_pres = { } plverb_ambiguous_pres_keys = re.compile( - fr"^({enclose('|'.join(plverb_ambiguous_pres))})((\s.*)?)$", re.IGNORECASE + rf"^({enclose('|'.join(plverb_ambiguous_pres))})((\s.*)?)$", re.IGNORECASE ) @@ -1804,7 +1805,7 @@ pl_count_one = ("1", "a", "an", "one", "each", "every", "this", "that") pl_adj_special = {"a": "some", "an": "some", "this": "these", "that": "those"} pl_adj_special_keys = re.compile( - fr"^({enclose('|'.join(pl_adj_special))})$", re.IGNORECASE + rf"^({enclose('|'.join(pl_adj_special))})$", re.IGNORECASE ) pl_adj_poss = { @@ -1816,7 +1817,7 @@ pl_adj_poss = { "their": "their", } -pl_adj_poss_keys = re.compile(fr"^({enclose('|'.join(pl_adj_poss))})$", re.IGNORECASE) +pl_adj_poss_keys = re.compile(rf"^({enclose('|'.join(pl_adj_poss))})$", re.IGNORECASE) # 2. 
INDEFINITE ARTICLES
@@ -1883,7 +1884,7 @@ ordinal = dict(
     twelve="twelfth",
 )
 
-ordinal_suff = re.compile(fr"({'|'.join(ordinal)})\Z")
+ordinal_suff = re.compile(rf"({'|'.join(ordinal)})\Z")
 
 
 # NUMBERS
@@ -1948,13 +1949,13 @@ DOLLAR_DIGITS = re.compile(r"\$(\d+)")
 FUNCTION_CALL = re.compile(r"((\w+)\([^)]*\)*)", re.IGNORECASE)
 PARTITION_WORD = re.compile(r"\A(\s*)(.+?)(\s*)\Z")
 PL_SB_POSTFIX_ADJ_STEMS_RE = re.compile(
-    fr"^(?:{pl_sb_postfix_adj_stems})$", re.IGNORECASE
+    rf"^(?:{pl_sb_postfix_adj_stems})$", re.IGNORECASE
 )
 PL_SB_PREP_DUAL_COMPOUND_RE = re.compile(
-    fr"^(?:{pl_sb_prep_dual_compound})$", re.IGNORECASE
+    rf"^(?:{pl_sb_prep_dual_compound})$", re.IGNORECASE
 )
 DENOMINATOR = re.compile(r"(?P<denominator>.+)( (per|a) .+)")
-PLVERB_SPECIAL_S_RE = re.compile(fr"^({plverb_special_s})$")
+PLVERB_SPECIAL_S_RE = re.compile(rf"^({plverb_special_s})$")
 WHITESPACE = re.compile(r"\s")
 ENDS_WITH_S = re.compile(r"^(.*[^s])s$", re.IGNORECASE)
 ENDS_WITH_APOSTROPHE_S = re.compile(r"^(.*)'s?$")
@@ -2020,10 +2021,25 @@ class Words(str):
         self.last = self.split_[-1]
 
 
-Word = Annotated[str, Field(min_length=1)]
 Falsish = Any  # ideally, falsish would only validate on bool(value) is False
 
 
+_STATIC_TYPE_CHECKING = TYPE_CHECKING
+# ^-- Workaround for typeguard AST manipulation:
+# https://github.com/agronholm/typeguard/issues/353#issuecomment-1556306554
+
+if _STATIC_TYPE_CHECKING:  # pragma: no cover
+    Word = Annotated[str, "String with at least 1 character"]
+else:
+
+    class _WordMeta(type):  # Too dynamic to be supported by mypy...
+        def __instancecheck__(self, instance: Any) -> bool:
+            return isinstance(instance, str) and len(instance) >= 1
+
+    class Word(metaclass=_WordMeta):  # type: ignore[no-redef]
+        """String with at least 1 character"""
+
+
 class engine:
     def __init__(self) -> None:
         self.classical_dict = def_classical.copy()
@@ -2045,7 +2061,7 @@ class engine:
     def _number_args(self, val):
         self.__number_args = val
 
-    @validate_call
+    @typechecked
     def defnoun(self, singular: Optional[Word], plural: Optional[Word]) -> int:
         """
         Set the noun plural of singular to plural.
@@ -2057,7 +2073,7 @@ class engine:
         self.si_sb_user_defined.extend((plural, singular))
         return 1
 
-    @validate_call
+    @typechecked
     def defverb(
         self,
         s1: Optional[Word],
@@ -2082,7 +2098,7 @@ class engine:
         self.pl_v_user_defined.extend((s1, p1, s2, p2, s3, p3))
         return 1
 
-    @validate_call
+    @typechecked
     def defadj(self, singular: Optional[Word], plural: Optional[Word]) -> int:
         """
         Set the adjective plural of singular to plural.
@@ -2093,7 +2109,7 @@ class engine:
         self.pl_adj_user_defined.extend((singular, plural))
         return 1
 
-    @validate_call
+    @typechecked
     def defa(self, pattern: Optional[Word]) -> int:
         """
         Define the indefinite article as 'a' for words matching pattern.
@@ -2103,7 +2119,7 @@ class engine:
         self.A_a_user_defined.extend((pattern, "a"))
         return 1
 
-    @validate_call
+    @typechecked
     def defan(self, pattern: Optional[Word]) -> int:
         """
         Define the indefinite article as 'an' for words matching pattern.
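
The swap from ``@validate_call`` to ``@typechecked`` above leaves the user-facing ``def*`` helpers behaviorally unchanged; e.g. (a sketch, with a hypothetical user-defined pattern):

import inflect

p = inflect.engine()
p.defan('honou?r.*')    # hypothetical pattern: matching words take 'an'
p.a('honor roll')       # -> 'an honor roll'
p.a('cat')              # -> 'a cat', unaffected
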
@@ -2121,8 +2137,8 @@ class engine: return try: re.match(pattern, "") - except re.error: - raise BadUserDefinedPatternError(pattern) + except re.error as err: + raise BadUserDefinedPatternError(pattern) from err def checkpatplural(self, pattern: Optional[Word]) -> None: """ @@ -2130,10 +2146,10 @@ class engine: """ return - @validate_call + @typechecked def ud_match(self, word: Word, wordlist: Sequence[Optional[Word]]) -> Optional[str]: for i in range(len(wordlist) - 2, -2, -2): # backwards through even elements - mo = re.search(fr"^{wordlist[i]}$", word, re.IGNORECASE) + mo = re.search(rf"^{wordlist[i]}$", word, re.IGNORECASE) if mo: if wordlist[i + 1] is None: return None @@ -2191,8 +2207,8 @@ class engine: if count is not None: try: self.persistent_count = int(count) - except ValueError: - raise BadNumValueError + except ValueError as err: + raise BadNumValueError from err if (show is None) or show: return str(count) else: @@ -2270,7 +2286,7 @@ class engine: # 0. PERFORM GENERAL INFLECTIONS IN A STRING - @validate_call + @typechecked def inflect(self, text: Word) -> str: """ Perform inflections in a string. @@ -2347,7 +2363,7 @@ class engine: else: return "", "", "" - @validate_call + @typechecked def plural(self, text: Word, count: Optional[Union[str, int, Any]] = None) -> str: """ Return the plural of text. @@ -2371,7 +2387,7 @@ class engine: ) return f"{pre}{plural}{post}" - @validate_call + @typechecked def plural_noun( self, text: Word, count: Optional[Union[str, int, Any]] = None ) -> str: @@ -2392,7 +2408,7 @@ class engine: plural = self.postprocess(word, self._plnoun(word, count)) return f"{pre}{plural}{post}" - @validate_call + @typechecked def plural_verb( self, text: Word, count: Optional[Union[str, int, Any]] = None ) -> str: @@ -2416,7 +2432,7 @@ class engine: ) return f"{pre}{plural}{post}" - @validate_call + @typechecked def plural_adj( self, text: Word, count: Optional[Union[str, int, Any]] = None ) -> str: @@ -2437,7 +2453,7 @@ class engine: plural = self.postprocess(word, self._pl_special_adjective(word, count) or word) return f"{pre}{plural}{post}" - @validate_call + @typechecked def compare(self, word1: Word, word2: Word) -> Union[str, bool]: """ compare word1 and word2 for equality regardless of plurality @@ -2460,15 +2476,13 @@ class engine: >>> compare('egg', '') Traceback (most recent call last): ... - pydantic...ValidationError: ... - ... - ...at least 1 characters... 
+ typeguard.TypeCheckError:...is not an instance of inflect.Word """ norms = self.plural_noun, self.plural_verb, self.plural_adj results = (self._plequal(word1, word2, norm) for norm in norms) return next(filter(None, results), False) - @validate_call + @typechecked def compare_nouns(self, word1: Word, word2: Word) -> Union[str, bool]: """ compare word1 and word2 for equality regardless of plurality @@ -2484,7 +2498,7 @@ class engine: """ return self._plequal(word1, word2, self.plural_noun) - @validate_call + @typechecked def compare_verbs(self, word1: Word, word2: Word) -> Union[str, bool]: """ compare word1 and word2 for equality regardless of plurality @@ -2500,7 +2514,7 @@ class engine: """ return self._plequal(word1, word2, self.plural_verb) - @validate_call + @typechecked def compare_adjs(self, word1: Word, word2: Word) -> Union[str, bool]: """ compare word1 and word2 for equality regardless of plurality @@ -2516,7 +2530,7 @@ class engine: """ return self._plequal(word1, word2, self.plural_adj) - @validate_call + @typechecked def singular_noun( self, text: Word, @@ -2574,18 +2588,18 @@ class engine: return "s:p" self.classical_dict = classval.copy() - if same_method(pl, self.plural) or same_method(pl, self.plural_noun): + if pl == self.plural or pl == self.plural_noun: if self._pl_check_plurals_N(word1, word2): return "p:p" if self._pl_check_plurals_N(word2, word1): return "p:p" - if same_method(pl, self.plural) or same_method(pl, self.plural_adj): + if pl == self.plural or pl == self.plural_adj: if self._pl_check_plurals_adj(word1, word2): return "p:p" return False def _pl_reg_plurals(self, pair: str, stems: str, end1: str, end2: str) -> bool: - pattern = fr"({stems})({end1}\|\1{end2}|{end2}\|\1{end1})" + pattern = rf"({stems})({end1}\|\1{end2}|{end2}\|\1{end1})" return bool(re.search(pattern, pair)) def _pl_check_plurals_N(self, word1: str, word2: str) -> bool: @@ -2679,6 +2693,8 @@ class engine: word = Words(word) if word.last.lower() in pl_sb_uninflected_complete: + if len(word.split_) >= 3: + return self._handle_long_compounds(word, count=2) or word return word if word in pl_sb_uninflected_caps: @@ -2707,13 +2723,9 @@ class engine: ) if len(word.split_) >= 3: - for numword in range(1, len(word.split_) - 1): - if word.split_[numword] in pl_prep_list_da: - return " ".join( - word.split_[: numword - 1] - + [self._plnoun(word.split_[numword - 1], 2)] - + word.split_[numword:] - ) + handled_words = self._handle_long_compounds(word, count=2) + if handled_words is not None: + return handled_words # only pluralize denominators in units mo = DENOMINATOR.search(word.lowered) @@ -2972,6 +2984,30 @@ class engine: parts[: pivot - 1] + [sep.join([transformed, parts[pivot], ''])] ) + " ".join(parts[(pivot + 1) :]) + def _handle_long_compounds(self, word: Words, count: int) -> Union[str, None]: + """ + Handles the plural and singular for compound `Words` that + have three or more words, based on the given count. 
+ + >>> engine()._handle_long_compounds(Words("pair of scissors"), 2) + 'pairs of scissors' + >>> engine()._handle_long_compounds(Words("men beyond hills"), 1) + 'man beyond hills' + """ + inflection = self._sinoun if count == 1 else self._plnoun + solutions = ( # type: ignore + " ".join( + itertools.chain( + leader, + [inflection(cand, count), prep], # type: ignore + trailer, + ) + ) + for leader, (cand, prep), trailer in windowed_complete(word.split_, 2) + if prep in pl_prep_list_da # type: ignore + ) + return next(solutions, None) + @staticmethod def _find_pivot(words, candidates): pivots = ( @@ -2980,7 +3016,7 @@ class engine: try: return next(pivots) except StopIteration: - raise ValueError("No pivot found") + raise ValueError("No pivot found") from None def _pl_special_verb( # noqa: C901 self, word: str, count: Optional[Union[str, int]] = None @@ -3145,8 +3181,8 @@ class engine: gender = self.thegender elif gender not in singular_pronoun_genders: raise BadGenderError - except (TypeError, IndexError): - raise BadGenderError + except (TypeError, IndexError) as err: + raise BadGenderError from err # HANDLE USER-DEFINED NOUNS @@ -3165,6 +3201,8 @@ class engine: words = Words(word) if words.last.lower() in pl_sb_uninflected_complete: + if len(words.split_) >= 3: + return self._handle_long_compounds(words, count=1) or word return word if word in pl_sb_uninflected_caps: @@ -3450,7 +3488,7 @@ class engine: # ADJECTIVES - @validate_call + @typechecked def a(self, text: Word, count: Optional[Union[int, str, Any]] = 1) -> str: """ Return the appropriate indefinite article followed by text. @@ -3531,7 +3569,7 @@ class engine: # 2. TRANSLATE ZERO-QUANTIFIED $word TO "no plural($word)" - @validate_call + @typechecked def no(self, text: Word, count: Optional[Union[int, str]] = None) -> str: """ If count is 0, no, zero or nil, return 'no' followed by the plural @@ -3569,7 +3607,7 @@ class engine: # PARTICIPLES - @validate_call + @typechecked def present_participle(self, word: Word) -> str: """ Return the present participle for word. @@ -3588,7 +3626,7 @@ class engine: # NUMERICAL INFLECTIONS - @validate_call(config=dict(arbitrary_types_allowed=True)) + @typechecked def ordinal(self, num: Union[Number, Word]) -> str: """ Return the ordinal of num. @@ -3619,16 +3657,7 @@ class engine: post = nth[n % 10] return f"{num}{post}" else: - # Mad props to Damian Conway (?) whose ordinal() - # algorithm is type-bendy enough to foil MyPy - str_num: str = num # type: ignore[assignment] - mo = ordinal_suff.search(str_num) - if mo: - post = ordinal[mo.group(1)] - rval = ordinal_suff.sub(post, str_num) - else: - rval = f"{str_num}th" - return rval + return self._sub_ord(num) def millfn(self, ind: int = 0) -> str: if ind > len(mill) - 1: @@ -3747,7 +3776,36 @@ class engine: num = ONE_DIGIT_WORD.sub(self.unitsub, num, 1) return num - @validate_call(config=dict(arbitrary_types_allowed=True)) # noqa: C901 + @staticmethod + def _sub_ord(val): + new = ordinal_suff.sub(lambda match: ordinal[match.group(1)], val) + return new + "th" * (new == val) + + @classmethod + def _chunk_num(cls, num, decimal, group): + if decimal: + max_split = -1 if group != 0 else 1 + chunks = num.split(".", max_split) + else: + chunks = [num] + return cls._remove_last_blank(chunks) + + @staticmethod + def _remove_last_blank(chunks): + """ + Remove the last item from chunks if it's a blank string. + + Return the resultant chunks and whether the last item was removed. 
+ """ + removed = chunks[-1] == "" + result = chunks[:-1] if removed else chunks + return result, removed + + @staticmethod + def _get_sign(num): + return {'+': 'plus', '-': 'minus'}.get(num.lstrip()[0], '') + + @typechecked def number_to_words( # noqa: C901 self, num: Union[Number, Word], @@ -3794,13 +3852,8 @@ class engine: if group < 0 or group > 3: raise BadChunkingOptionError - nowhite = num.lstrip() - if nowhite[0] == "+": - sign = "plus" - elif nowhite[0] == "-": - sign = "minus" - else: - sign = "" + + sign = self._get_sign(num) if num in nth_suff: num = zero @@ -3808,34 +3861,21 @@ class engine: myord = num[-2:] in nth_suff if myord: num = num[:-2] - finalpoint = False - if decimal: - if group != 0: - chunks = num.split(".") - else: - chunks = num.split(".", 1) - if chunks[-1] == "": # remove blank string if nothing after decimal - chunks = chunks[:-1] - finalpoint = True # add 'point' to end of output - else: - chunks = [num] - first: Union[int, str, bool] = 1 - loopstart = 0 + chunks, finalpoint = self._chunk_num(num, decimal, group) - if chunks[0] == "": - first = 0 - if len(chunks) > 1: - loopstart = 1 + loopstart = chunks[0] == "" + first: bool | None = not loopstart + + def _handle_chunk(chunk): + nonlocal first - for i in range(loopstart, len(chunks)): - chunk = chunks[i] # remove all non numeric \D chunk = NON_DIGIT.sub("", chunk) if chunk == "": chunk = "0" - if group == 0 and (first == 0 or first == ""): + if group == 0 and not first: chunk = self.enword(chunk, 1) else: chunk = self.enword(chunk, group) @@ -3850,20 +3890,17 @@ class engine: # chunk = re.sub(r"(\A\s|\s\Z)", self.blankfn, chunk) chunk = chunk.strip() if first: - first = "" - chunks[i] = chunk + first = None + return chunk + + chunks[loopstart:] = map(_handle_chunk, chunks[loopstart:]) numchunks = [] if first != 0: numchunks = chunks[0].split(f"{comma} ") if myord and numchunks: - # TODO: can this be just one re as it is in perl? - mo = ordinal_suff.search(numchunks[-1]) - if mo: - numchunks[-1] = ordinal_suff.sub(ordinal[mo.group(1)], numchunks[-1]) - else: - numchunks[-1] += "th" + numchunks[-1] = self._sub_ord(numchunks[-1]) for chunk in chunks[1:]: numchunks.append(decimal) @@ -3872,34 +3909,30 @@ class engine: if finalpoint: numchunks.append(decimal) - # wantlist: Perl list context. can explicitly specify in Python if wantlist: - if sign: - numchunks = [sign] + numchunks - return numchunks - elif group: - signout = f"{sign} " if sign else "" - return f"{signout}{', '.join(numchunks)}" - else: - signout = f"{sign} " if sign else "" - num = f"{signout}{numchunks.pop(0)}" - if decimal is None: - first = True - else: - first = not num.endswith(decimal) - for nc in numchunks: - if nc == decimal: - num += f" {nc}" - first = 0 - elif first: - num += f"{comma} {nc}" - else: - num += f" {nc}" - return num + return [sign] * bool(sign) + numchunks - # Join words with commas and a trailing 'and' (when appropriate)... 
+ signout = f"{sign} " if sign else "" + valout = ( + ', '.join(numchunks) + if group + else ''.join(self._render(numchunks, decimal, comma)) + ) + return signout + valout - @validate_call + @staticmethod + def _render(chunks, decimal, comma): + first_item = chunks.pop(0) + yield first_item + first = decimal is None or not first_item.endswith(decimal) + for nc in chunks: + if nc == decimal: + first = False + elif first: + yield comma + yield f" {nc}" + + @typechecked def join( self, words: Optional[Sequence[Word]], diff --git a/lib/inflect/compat/pydantic.py b/lib/inflect/compat/pydantic.py deleted file mode 100644 index d777564a..00000000 --- a/lib/inflect/compat/pydantic.py +++ /dev/null @@ -1,19 +0,0 @@ -class ValidateCallWrapperWrapper: - def __init__(self, wrapped): - self.orig = wrapped - - def __eq__(self, other): - return self.raw_function == other.raw_function - - @property - def raw_function(self): - return getattr(self.orig, 'raw_function') or self.orig - - -def same_method(m1, m2) -> bool: - """ - Return whether m1 and m2 are the same method. - - Workaround for pydantic/pydantic#6390. - """ - return ValidateCallWrapperWrapper(m1) == ValidateCallWrapperWrapper(m2) diff --git a/lib/inflect/compat/pydantic1.py b/lib/inflect/compat/pydantic1.py deleted file mode 100644 index 8262fdcf..00000000 --- a/lib/inflect/compat/pydantic1.py +++ /dev/null @@ -1,8 +0,0 @@ -try: - from pydantic import validate_call # type: ignore -except ImportError: - # Pydantic 1 - from pydantic import validate_arguments as validate_call # type: ignore - - -__all__ = ['validate_call'] diff --git a/lib/jaraco/classes/ancestry.py b/lib/jaraco/classes/ancestry.py deleted file mode 100644 index dd9b2e92..00000000 --- a/lib/jaraco/classes/ancestry.py +++ /dev/null @@ -1,68 +0,0 @@ -""" -Routines for obtaining the class names -of an object and its parent classes. -""" - -from more_itertools import unique_everseen - - -def all_bases(c): - """ - return a tuple of all base classes the class c has as a parent. - >>> object in all_bases(list) - True - """ - return c.mro()[1:] - - -def all_classes(c): - """ - return a tuple of all classes to which c belongs - >>> list in all_classes(list) - True - """ - return c.mro() - - -# borrowed from -# http://code.activestate.com/recipes/576949-find-all-subclasses-of-a-given-class/ - - -def iter_subclasses(cls): - """ - Generator over all subclasses of a given class, in depth-first order. - - >>> bool in list(iter_subclasses(int)) - True - >>> class A(object): pass - >>> class B(A): pass - >>> class C(A): pass - >>> class D(B,C): pass - >>> class E(D): pass - >>> - >>> for cls in iter_subclasses(A): - ... print(cls.__name__) - B - D - E - C - >>> # get ALL classes currently defined - >>> res = [cls.__name__ for cls in iter_subclasses(object)] - >>> 'type' in res - True - >>> 'tuple' in res - True - >>> len(res) > 100 - True - """ - return unique_everseen(_iter_all_subclasses(cls)) - - -def _iter_all_subclasses(cls): - try: - subs = cls.__subclasses__() - except TypeError: # fails only when cls is type - subs = cls.__subclasses__(cls) - for sub in subs: - yield sub - yield from iter_subclasses(sub) diff --git a/lib/jaraco/classes/meta.py b/lib/jaraco/classes/meta.py deleted file mode 100644 index bd41a1d9..00000000 --- a/lib/jaraco/classes/meta.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -meta.py - -Some useful metaclasses. -""" - - -class LeafClassesMeta(type): - """ - A metaclass for classes that keeps track of all of them that - aren't base classes. 
-
-    >>> Parent = LeafClassesMeta('MyParentClass', (), {})
-    >>> Parent in Parent._leaf_classes
-    True
-    >>> Child = LeafClassesMeta('MyChildClass', (Parent,), {})
-    >>> Child in Parent._leaf_classes
-    True
-    >>> Parent in Parent._leaf_classes
-    False
-
-    >>> Other = LeafClassesMeta('OtherClass', (), {})
-    >>> Parent in Other._leaf_classes
-    False
-    >>> len(Other._leaf_classes)
-    1
-    """
-
-    def __init__(cls, name, bases, attrs):
-        if not hasattr(cls, '_leaf_classes'):
-            cls._leaf_classes = set()
-        leaf_classes = getattr(cls, '_leaf_classes')
-        leaf_classes.add(cls)
-        # remove any base classes
-        leaf_classes -= set(bases)
-
-
-class TagRegistered(type):
-    """
-    As classes of this metaclass are created, they keep a registry in the
-    base class of all classes by a class attribute, indicated by attr_name.
-
-    >>> FooObject = TagRegistered('FooObject', (), dict(tag='foo'))
-    >>> FooObject._registry['foo'] is FooObject
-    True
-    >>> BarObject = TagRegistered('Barobject', (FooObject,), dict(tag='bar'))
-    >>> FooObject._registry is BarObject._registry
-    True
-    >>> len(FooObject._registry)
-    2
-
-    '...' below should be 'jaraco.classes' but for pytest-dev/pytest#3396
-    >>> FooObject._registry['bar']
-    <class '....meta.Barobject'>
-    """
-
-    attr_name = 'tag'
-
-    def __init__(cls, name, bases, namespace):
-        super(TagRegistered, cls).__init__(name, bases, namespace)
-        if not hasattr(cls, '_registry'):
-            cls._registry = {}
-        meta = cls.__class__
-        attr = getattr(cls, meta.attr_name, None)
-        if attr:
-            cls._registry[attr] = cls
diff --git a/lib/jaraco/classes/properties.py b/lib/jaraco/classes/properties.py
deleted file mode 100644
index 62f9e200..00000000
--- a/lib/jaraco/classes/properties.py
+++ /dev/null
@@ -1,170 +0,0 @@
-class NonDataProperty:
-    """Much like the property builtin, but only implements __get__,
-    making it a non-data property, and can be subsequently reset.
-
-    See http://users.rcn.com/python/download/Descriptor.htm for more
-    information.
-
-    >>> class X(object):
-    ...   @NonDataProperty
-    ...   def foo(self):
-    ...       return 3
-    >>> x = X()
-    >>> x.foo
-    3
-    >>> x.foo = 4
-    >>> x.foo
-    4
-
-    '...' below should be 'jaraco.classes' but for pytest-dev/pytest#3396
-    >>> X.foo
-    <....properties.NonDataProperty object at ...>
-    """
-
-    def __init__(self, fget):
-        assert fget is not None, "fget cannot be none"
-        assert callable(fget), "fget must be callable"
-        self.fget = fget
-
-    def __get__(self, obj, objtype=None):
-        if obj is None:
-            return self
-        return self.fget(obj)
-
-
-class classproperty:
-    """
-    Like @property but applies at the class level.
-
-
-    >>> class X(metaclass=classproperty.Meta):
-    ...   val = None
-    ...   @classproperty
-    ...   def foo(cls):
-    ...       return cls.val
-    ...   @foo.setter
-    ...   def foo(cls, val):
-    ...       cls.val = val
-    >>> X.foo
-    >>> X.foo = 3
-    >>> X.foo
-    3
-    >>> x = X()
-    >>> x.foo
-    3
-    >>> X.foo = 4
-    >>> x.foo
-    4
-
-    Setting the property on an instance affects the class.
-
-    >>> x.foo = 5
-    >>> x.foo
-    5
-    >>> X.foo
-    5
-    >>> vars(x)
-    {}
-    >>> X().foo
-    5
-
-    Attempting to set an attribute where no setter was defined
-    results in an AttributeError:
-
-    >>> class GetOnly(metaclass=classproperty.Meta):
-    ...   @classproperty
-    ...   def foo(cls):
-    ...       return 'bar'
-    >>> GetOnly.foo = 3
-    Traceback (most recent call last):
-    ...
-    AttributeError: can't set attribute
-
-    It is also possible to wrap a classmethod or staticmethod in
-    a classproperty.
-
-    >>> class Static(metaclass=classproperty.Meta):
-    ...   @classproperty
-    ...   @classmethod
-    ...   def foo(cls):
-    ...       return 'foo'
-    ...   
@classproperty - ... @staticmethod - ... def bar(): - ... return 'bar' - >>> Static.foo - 'foo' - >>> Static.bar - 'bar' - - *Legacy* - - For compatibility, if the metaclass isn't specified, the - legacy behavior will be invoked. - - >>> class X: - ... val = None - ... @classproperty - ... def foo(cls): - ... return cls.val - ... @foo.setter - ... def foo(cls, val): - ... cls.val = val - >>> X.foo - >>> X.foo = 3 - >>> X.foo - 3 - >>> x = X() - >>> x.foo - 3 - >>> X.foo = 4 - >>> x.foo - 4 - - Note, because the metaclass was not specified, setting - a value on an instance does not have the intended effect. - - >>> x.foo = 5 - >>> x.foo - 5 - >>> X.foo # should be 5 - 4 - >>> vars(x) # should be empty - {'foo': 5} - >>> X().foo # should be 5 - 4 - """ - - class Meta(type): - def __setattr__(self, key, value): - obj = self.__dict__.get(key, None) - if type(obj) is classproperty: - return obj.__set__(self, value) - return super().__setattr__(key, value) - - def __init__(self, fget, fset=None): - self.fget = self._ensure_method(fget) - self.fset = fset - fset and self.setter(fset) - - def __get__(self, instance, owner=None): - return self.fget.__get__(None, owner)() - - def __set__(self, owner, value): - if not self.fset: - raise AttributeError("can't set attribute") - if type(owner) is not classproperty.Meta: - owner = type(owner) - return self.fset.__get__(None, owner)(value) - - def setter(self, fset): - self.fset = self._ensure_method(fset) - return self - - @classmethod - def _ensure_method(cls, fn): - """ - Ensure fn is a classmethod or staticmethod. - """ - needs_method = not isinstance(fn, (classmethod, staticmethod)) - return classmethod(fn) if needs_method else fn diff --git a/lib/jaraco/collections/__init__.py b/lib/jaraco/collections/__init__.py index abedf002..5c276da9 100644 --- a/lib/jaraco/collections/__init__.py +++ b/lib/jaraco/collections/__init__.py @@ -1,16 +1,17 @@ -import re -import operator +from __future__ import annotations + import collections.abc -import itertools import copy import functools +import itertools +import operator import random +import re from collections.abc import Container, Iterable, Mapping -from typing import Callable, Union +from typing import Any, Callable, Union import jaraco.text - _Matchable = Union[Callable, Container, Iterable, re.Pattern] @@ -199,7 +200,12 @@ class RangeMap(dict): """ - def __init__(self, source, sort_params={}, key_match_comparator=operator.le): + def __init__( + self, + source, + sort_params: Mapping[str, Any] = {}, + key_match_comparator=operator.le, + ): dict.__init__(self, source) self.sort_params = sort_params self.match = key_match_comparator @@ -291,7 +297,7 @@ class KeyTransformingDict(dict): return key def __init__(self, *args, **kargs): - super(KeyTransformingDict, self).__init__() + super().__init__() # build a dictionary using the default constructs d = dict(*args, **kargs) # build this dictionary using transformed keys. 
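
The jaraco.collections hunks here only modernize ``super()`` calls; ``KeyTransformingDict`` still routes every access through its ``transform_key`` hook. A behavior sketch (``LowerDict`` is illustrative, not part of the library):

from jaraco.collections import KeyTransformingDict

class LowerDict(KeyTransformingDict):
    @staticmethod
    def transform_key(key):
        return key.lower()

d = LowerDict(Foo=1)      # the constructor transforms keys too
assert d['FOO'] == 1      # lookups are normalized through transform_key
assert 'fOo' in d
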
@@ -300,31 +306,31 @@ class KeyTransformingDict(dict): def __setitem__(self, key, val): key = self.transform_key(key) - super(KeyTransformingDict, self).__setitem__(key, val) + super().__setitem__(key, val) def __getitem__(self, key): key = self.transform_key(key) - return super(KeyTransformingDict, self).__getitem__(key) + return super().__getitem__(key) def __contains__(self, key): key = self.transform_key(key) - return super(KeyTransformingDict, self).__contains__(key) + return super().__contains__(key) def __delitem__(self, key): key = self.transform_key(key) - return super(KeyTransformingDict, self).__delitem__(key) + return super().__delitem__(key) def get(self, key, *args, **kwargs): key = self.transform_key(key) - return super(KeyTransformingDict, self).get(key, *args, **kwargs) + return super().get(key, *args, **kwargs) def setdefault(self, key, *args, **kwargs): key = self.transform_key(key) - return super(KeyTransformingDict, self).setdefault(key, *args, **kwargs) + return super().setdefault(key, *args, **kwargs) def pop(self, key, *args, **kwargs): key = self.transform_key(key) - return super(KeyTransformingDict, self).pop(key, *args, **kwargs) + return super().pop(key, *args, **kwargs) def matching_key_for(self, key): """ @@ -333,8 +339,8 @@ class KeyTransformingDict(dict): """ try: return next(e_key for e_key in self.keys() if e_key == key) - except StopIteration: - raise KeyError(key) + except StopIteration as err: + raise KeyError(key) from err class FoldedCaseKeyedDict(KeyTransformingDict): @@ -483,7 +489,7 @@ class ItemsAsAttributes: def __getattr__(self, key): try: - return getattr(super(ItemsAsAttributes, self), key) + return getattr(super(), key) except AttributeError as e: # attempt to get the value from the mapping (return self[key]) # but be careful not to lose the original exception context. 
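
The next hunk touches ``BijectiveMap``, which stores each pair in both directions inside a single dict. A short behavior sketch, assuming the semantics visible in the surrounding code:

from jaraco.collections import BijectiveMap

m = BijectiveMap()
m['a'] = 1
assert m[1] == 'a'    # the reverse mapping is maintained automatically
assert len(m) == 1    # each pair is counted once (see __len__ below)
del m['a']            # pop() removes both directions
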
@@ -677,7 +683,7 @@ class BijectiveMap(dict): """ def __init__(self, *args, **kwargs): - super(BijectiveMap, self).__init__() + super().__init__() self.update(*args, **kwargs) def __setitem__(self, item, value): @@ -691,19 +697,19 @@ class BijectiveMap(dict): ) if overlap: raise ValueError("Key/Value pairs may not overlap") - super(BijectiveMap, self).__setitem__(item, value) - super(BijectiveMap, self).__setitem__(value, item) + super().__setitem__(item, value) + super().__setitem__(value, item) def __delitem__(self, item): self.pop(item) def __len__(self): - return super(BijectiveMap, self).__len__() // 2 + return super().__len__() // 2 def pop(self, key, *args, **kwargs): mirror = self[key] - super(BijectiveMap, self).__delitem__(mirror) - return super(BijectiveMap, self).pop(key, *args, **kwargs) + super().__delitem__(mirror) + return super().pop(key, *args, **kwargs) def update(self, *args, **kwargs): # build a dictionary using the default constructs @@ -769,7 +775,7 @@ class FrozenDict(collections.abc.Mapping, collections.abc.Hashable): __slots__ = ['__data'] def __new__(cls, *args, **kwargs): - self = super(FrozenDict, cls).__new__(cls) + self = super().__new__(cls) self.__data = dict(*args, **kwargs) return self @@ -844,7 +850,7 @@ class Enumeration(ItemsAsAttributes, BijectiveMap): names = names.split() if codes is None: codes = itertools.count() - super(Enumeration, self).__init__(zip(names, codes)) + super().__init__(zip(names, codes)) @property def names(self): diff --git a/lib/jaraco/context.py b/lib/jaraco/context.py index b0d1ef37..61b27135 100644 --- a/lib/jaraco/context.py +++ b/lib/jaraco/context.py @@ -1,15 +1,26 @@ -import os -import subprocess +from __future__ import annotations + import contextlib import functools -import tempfile -import shutil import operator +import os +import shutil +import subprocess +import sys +import tempfile +import urllib.request import warnings +from typing import Iterator + + +if sys.version_info < (3, 12): + from backports import tarfile +else: + import tarfile @contextlib.contextmanager -def pushd(dir): +def pushd(dir: str | os.PathLike) -> Iterator[str | os.PathLike]: """ >>> tmp_path = getfixture('tmp_path') >>> with pushd(tmp_path): @@ -26,33 +37,88 @@ def pushd(dir): @contextlib.contextmanager -def tarball_context(url, target_dir=None, runner=None, pushd=pushd): +def tarball( + url, target_dir: str | os.PathLike | None = None +) -> Iterator[str | os.PathLike]: """ - Get a tarball, extract it, change to that directory, yield, then - clean up. - `runner` is the function to invoke commands. - `pushd` is a context manager for changing the directory. + Get a tarball, extract it, yield, then clean up. + + >>> import urllib.request + >>> url = getfixture('tarfile_served') + >>> target = getfixture('tmp_path') / 'out' + >>> tb = tarball(url, target_dir=target) + >>> import pathlib + >>> with tb as extracted: + ... contents = pathlib.Path(extracted, 'contents.txt').read_text(encoding='utf-8') + >>> assert not os.path.exists(extracted) """ if target_dir is None: target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '') - if runner is None: - runner = functools.partial(subprocess.check_call, shell=True) - else: - warnings.warn("runner parameter is deprecated", DeprecationWarning) # In the tar command, use --strip-components=1 to strip the first path and # then # use -C to cause the files to be extracted to {target_dir}. This ensures # that we always know where the files were extracted. 
- runner('mkdir {target_dir}'.format(**vars())) + os.mkdir(target_dir) try: - getter = 'wget {url} -O -' - extract = 'tar x{compression} --strip-components=1 -C {target_dir}' - cmd = ' | '.join((getter, extract)) - runner(cmd.format(compression=infer_compression(url), **vars())) - with pushd(target_dir): - yield target_dir + req = urllib.request.urlopen(url) + with tarfile.open(fileobj=req, mode='r|*') as tf: + tf.extractall(path=target_dir, filter=strip_first_component) + yield target_dir finally: - runner('rm -Rf {target_dir}'.format(**vars())) + shutil.rmtree(target_dir) + + +def strip_first_component( + member: tarfile.TarInfo, + path, +) -> tarfile.TarInfo: + _, member.name = member.name.split('/', 1) + return member + + +def _compose(*cmgrs): + """ + Compose any number of dependent context managers into a single one. + + The last, innermost context manager may take arbitrary arguments, but + each successive context manager should accept the result from the + previous as a single parameter. + + Like :func:`jaraco.functools.compose`, behavior works from right to + left, so the context manager should be indicated from outermost to + innermost. + + Example, to create a context manager to change to a temporary + directory: + + >>> temp_dir_as_cwd = _compose(pushd, temp_dir) + >>> with temp_dir_as_cwd() as dir: + ... assert os.path.samefile(os.getcwd(), dir) + """ + + def compose_two(inner, outer): + def composed(*args, **kwargs): + with inner(*args, **kwargs) as saved, outer(saved) as res: + yield res + + return contextlib.contextmanager(composed) + + return functools.reduce(compose_two, reversed(cmgrs)) + + +tarball_cwd = _compose(pushd, tarball) + + +@contextlib.contextmanager +def tarball_context(*args, **kwargs): + warnings.warn( + "tarball_context is deprecated. Use tarball or tarball_cwd instead.", + DeprecationWarning, + stacklevel=2, + ) + pushd_ctx = kwargs.pop('pushd', pushd) + with tarball(*args, **kwargs) as tball, pushd_ctx(tball) as dir: + yield dir def infer_compression(url): @@ -68,6 +134,11 @@ def infer_compression(url): >>> infer_compression('file.xz') 'J' """ + warnings.warn( + "infer_compression is deprecated with no replacement", + DeprecationWarning, + stacklevel=2, + ) # cheat and just assume it's the last two characters compression_indicator = url[-2:] mapping = dict(gz='z', bz='j', xz='J') @@ -84,7 +155,7 @@ def temp_dir(remover=shutil.rmtree): >>> import pathlib >>> with temp_dir() as the_dir: ... assert os.path.isdir(the_dir) - ... _ = pathlib.Path(the_dir).joinpath('somefile').write_text('contents') + ... _ = pathlib.Path(the_dir).joinpath('somefile').write_text('contents', encoding='utf-8') >>> assert not os.path.exists(the_dir) """ temp_dir = tempfile.mkdtemp() @@ -113,15 +184,23 @@ def repo_context(url, branch=None, quiet=True, dest_ctx=temp_dir): yield repo_dir -@contextlib.contextmanager def null(): """ A null context suitable to stand in for a meaningful context. >>> with null() as value: ... assert value is None + + This context is most useful when dealing with two or more code + branches but only some need a context. Wrap the others in a null + context to provide symmetry across all options. """ - yield + warnings.warn( + "null is deprecated. Use contextlib.nullcontext", + DeprecationWarning, + stacklevel=2, + ) + return contextlib.nullcontext() class ExceptionTrap: @@ -267,13 +346,7 @@ class on_interrupt(contextlib.ContextDecorator): ... 
on_interrupt('ignore')(do_interrupt)() """ - def __init__( - self, - action='error', - # py3.7 compat - # /, - code=1, - ): + def __init__(self, action='error', /, code=1): self.action = action self.code = code diff --git a/lib/jaraco/functools/__init__.pyi b/lib/jaraco/functools/__init__.pyi index c2b9ab17..19191bf9 100644 --- a/lib/jaraco/functools/__init__.pyi +++ b/lib/jaraco/functools/__init__.pyi @@ -74,9 +74,6 @@ def result_invoke( def invoke( f: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs ) -> Callable[_P, _R]: ... -def call_aside( - f: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs -) -> Callable[_P, _R]: ... class Throttler(Generic[_R]): last_called: float diff --git a/lib/libfuturize/fixer_util.py b/lib/libfuturize/fixer_util.py index 48e4689d..b5c123f6 100644 --- a/lib/libfuturize/fixer_util.py +++ b/lib/libfuturize/fixer_util.py @@ -9,11 +9,11 @@ python-modernize licence: BSD (from python-modernize/LICENSE) """ from lib2to3.fixer_util import (FromImport, Newline, is_import, - find_root, does_tree_import, Comma) + find_root, does_tree_import, + Call, Name, Comma) from lib2to3.pytree import Leaf, Node -from lib2to3.pygram import python_symbols as syms, python_grammar +from lib2to3.pygram import python_symbols as syms from lib2to3.pygram import token -from lib2to3.fixer_util import (Node, Call, Name, syms, Comma, Number) import re @@ -116,7 +116,7 @@ def suitify(parent): """ for node in parent.children: if node.type == syms.suite: - # already in the prefered format, do nothing + # already in the preferred format, do nothing return # One-liners have no suite node, we have to fake one up @@ -390,6 +390,7 @@ def touch_import_top(package, name_to_import, node): break insert_pos = idx + children_hooks = [] if package is None: import_ = Node(syms.import_name, [ Leaf(token.NAME, u"import"), @@ -413,8 +414,6 @@ def touch_import_top(package, name_to_import, node): ] ) children_hooks = [install_hooks, Newline()] - else: - children_hooks = [] # FromImport(package, [Leaf(token.NAME, name_to_import, prefix=u" ")]) @@ -448,7 +447,6 @@ def check_future_import(node): else: node = node.children[3] # now node is the import_as_name[s] - # print(python_grammar.number2symbol[node.type]) # breaks sometimes if node.type == syms.import_as_names: result = set() for n in node.children: diff --git a/lib/libfuturize/fixes/fix_metaclass.py b/lib/libfuturize/fixes/fix_metaclass.py index 2ac41c97..a7eee40d 100644 --- a/lib/libfuturize/fixes/fix_metaclass.py +++ b/lib/libfuturize/fixes/fix_metaclass.py @@ -37,7 +37,7 @@ from lib2to3.fixer_util import Name, syms, Node, Leaf, touch_import, Call, \ def has_metaclass(parent): """ we have to check the cls_node without changing it. - There are two possiblities: + There are two possibilities: 1) clsdef => suite => simple_stmt => expr_stmt => Leaf('__meta') 2) clsdef => simple_stmt => expr_stmt => Leaf('__meta') """ @@ -63,7 +63,7 @@ def fixup_parse_tree(cls_node): # already in the preferred format, do nothing return - # !%@#! oneliners have no suite node, we have to fake one up + # !%@#! 
one-liners have no suite node, we have to fake one up for i, node in enumerate(cls_node.children): if node.type == token.COLON: break diff --git a/lib/libpasteurize/fixes/fix_imports.py b/lib/libpasteurize/fixes/fix_imports.py index 2d6718f1..b18ecf3d 100644 --- a/lib/libpasteurize/fixes/fix_imports.py +++ b/lib/libpasteurize/fixes/fix_imports.py @@ -16,6 +16,7 @@ MAPPING = {u"reprlib": u"repr", u"winreg": u"_winreg", u"configparser": u"ConfigParser", u"copyreg": u"copy_reg", + u"multiprocessing.SimpleQueue": u"multiprocessing.queues.SimpleQueue", u"queue": u"Queue", u"socketserver": u"SocketServer", u"_markupbase": u"markupbase", diff --git a/lib/libpasteurize/fixes/fix_unpacking.py b/lib/libpasteurize/fixes/fix_unpacking.py index c2d3207a..6e839e6b 100644 --- a/lib/libpasteurize/fixes/fix_unpacking.py +++ b/lib/libpasteurize/fixes/fix_unpacking.py @@ -18,8 +18,12 @@ def assignment_source(num_pre, num_post, LISTNAME, ITERNAME): Returns a source fit for Assign() from fixer_util """ children = [] - pre = unicode(num_pre) - post = unicode(num_post) + try: + pre = unicode(num_pre) + post = unicode(num_post) + except NameError: + pre = str(num_pre) + post = str(num_post) # This code builds the assignment source from lib2to3 tree primitives. # It's not very readable, but it seems like the most correct way to do it. if num_pre > 0: diff --git a/lib/past/__init__.py b/lib/past/__init__.py index 14713039..54619e0a 100644 --- a/lib/past/__init__.py +++ b/lib/past/__init__.py @@ -75,12 +75,12 @@ Credits ------- :Author: Ed Schofield, Jordan M. Adler, et al -:Sponsor: Python Charmers Pty Ltd, Australia: http://pythoncharmers.com +:Sponsor: Python Charmers: https://pythoncharmers.com Licensing --------- -Copyright 2013-2019 Python Charmers Pty Ltd, Australia. +Copyright 2013-2024 Python Charmers, Australia. The software is distributed under an MIT licence. See LICENSE.txt. """ diff --git a/lib/past/builtins/misc.py b/lib/past/builtins/misc.py index 3600695c..0b8e6a98 100644 --- a/lib/past/builtins/misc.py +++ b/lib/past/builtins/misc.py @@ -1,11 +1,13 @@ from __future__ import unicode_literals import inspect +import sys import math import numbers from future.utils import PY2, PY3, exec_ + if PY2: from collections import Mapping else: @@ -103,13 +105,12 @@ if PY3: return '0' + builtins.oct(number)[2:] raw_input = input - - try: + # imp was deprecated in python 3.6 + if sys.version_info >= (3, 6): from importlib import reload - except ImportError: + else: # for python2, python3 <= 3.4 from imp import reload - unicode = str unichr = chr xrange = range diff --git a/lib/past/translation/__init__.py b/lib/past/translation/__init__.py index 7c678866..ae6c0d90 100644 --- a/lib/past/translation/__init__.py +++ b/lib/past/translation/__init__.py @@ -32,17 +32,31 @@ Author: Ed Schofield. Inspired by and based on ``uprefix`` by Vinay M. Sajip. 
""" -import imp -import logging -import marshal -import os import sys +# imp was deprecated in python 3.6 +if sys.version_info >= (3, 6): + import importlib as imp +else: + import imp +import logging +import os import copy from lib2to3.pgen2.parse import ParseError from lib2to3.refactor import RefactoringTool from libfuturize import fixes +try: + from importlib.machinery import ( + PathFinder, + SourceFileLoader, + ) +except ImportError: + PathFinder = None + SourceFileLoader = object + +if sys.version_info[:2] < (3, 4): + import imp logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -225,6 +239,81 @@ def detect_python2(source, pathname): return False +def transform(source, pathname): + # This implementation uses lib2to3, + # you can override and use something else + # if that's better for you + + # lib2to3 likes a newline at the end + RTs.setup() + source += '\n' + try: + tree = RTs._rt.refactor_string(source, pathname) + except ParseError as e: + if e.msg != 'bad input' or e.value != '=': + raise + tree = RTs._rtp.refactor_string(source, pathname) + # could optimise a bit for only doing str(tree) if + # getattr(tree, 'was_changed', False) returns True + return str(tree)[:-1] # remove added newline + + +class PastSourceFileLoader(SourceFileLoader): + exclude_paths = [] + include_paths = [] + + def _convert_needed(self): + fullname = self.name + if any(fullname.startswith(path) for path in self.exclude_paths): + convert = False + elif any(fullname.startswith(path) for path in self.include_paths): + convert = True + else: + convert = False + return convert + + def _exec_transformed_module(self, module): + source = self.get_source(self.name) + pathname = self.path + if detect_python2(source, pathname): + source = transform(source, pathname) + code = compile(source, pathname, "exec") + exec(code, module.__dict__) + + # For Python 3.3 + def load_module(self, fullname): + logger.debug("Running load_module for %s", fullname) + if fullname in sys.modules: + mod = sys.modules[fullname] + else: + if self._convert_needed(): + logger.debug("Autoconverting %s", fullname) + mod = imp.new_module(fullname) + sys.modules[fullname] = mod + + # required by PEP 302 + mod.__file__ = self.path + mod.__loader__ = self + if self.is_package(fullname): + mod.__path__ = [] + mod.__package__ = fullname + else: + mod.__package__ = fullname.rpartition('.')[0] + self._exec_transformed_module(mod) + else: + mod = super().load_module(fullname) + return mod + + # For Python >=3.4 + def exec_module(self, module): + logger.debug("Running exec_module for %s", module) + if self._convert_needed(): + logger.debug("Autoconverting %s", self.name) + self._exec_transformed_module(module) + else: + super().exec_module(module) + + class Py2Fixer(object): """ An import hook class that uses lib2to3 for source-to-source translation of @@ -258,151 +347,30 @@ class Py2Fixer(object): """ self.exclude_paths += paths + # For Python 3.3 def find_module(self, fullname, path=None): - logger.debug('Running find_module: {0}...'.format(fullname)) - if '.' in fullname: - parent, child = fullname.rsplit('.', 1) - if path is None: - loader = self.find_module(parent, path) - mod = loader.load_module(parent) - path = mod.__path__ - fullname = child - - # Perhaps we should try using the new importlib functionality in Python - # 3.3: something like this? 
- # thing = importlib.machinery.PathFinder.find_module(fullname, path) - try: - self.found = imp.find_module(fullname, path) - except Exception as e: - logger.debug('Py2Fixer could not find {0}') - logger.debug('Exception was: {0})'.format(fullname, e)) + logger.debug("Running find_module: (%s, %s)", fullname, path) + loader = PathFinder.find_module(fullname, path) + if not loader: + logger.debug("Py2Fixer could not find %s", fullname) return None - self.kind = self.found[-1][-1] - if self.kind == imp.PKG_DIRECTORY: - self.pathname = os.path.join(self.found[1], '__init__.py') - elif self.kind == imp.PY_SOURCE: - self.pathname = self.found[1] - return self + loader.__class__ = PastSourceFileLoader + loader.exclude_paths = self.exclude_paths + loader.include_paths = self.include_paths + return loader - def transform(self, source): - # This implementation uses lib2to3, - # you can override and use something else - # if that's better for you + # For Python >=3.4 + def find_spec(self, fullname, path=None, target=None): + logger.debug("Running find_spec: (%s, %s, %s)", fullname, path, target) + spec = PathFinder.find_spec(fullname, path, target) + if not spec: + logger.debug("Py2Fixer could not find %s", fullname) + return None + spec.loader.__class__ = PastSourceFileLoader + spec.loader.exclude_paths = self.exclude_paths + spec.loader.include_paths = self.include_paths + return spec - # lib2to3 likes a newline at the end - RTs.setup() - source += '\n' - try: - tree = RTs._rt.refactor_string(source, self.pathname) - except ParseError as e: - if e.msg != 'bad input' or e.value != '=': - raise - tree = RTs._rtp.refactor_string(source, self.pathname) - # could optimise a bit for only doing str(tree) if - # getattr(tree, 'was_changed', False) returns True - return str(tree)[:-1] # remove added newline - - def load_module(self, fullname): - logger.debug('Running load_module for {0}...'.format(fullname)) - if fullname in sys.modules: - mod = sys.modules[fullname] - else: - if self.kind in (imp.PY_COMPILED, imp.C_EXTENSION, imp.C_BUILTIN, - imp.PY_FROZEN): - convert = False - # elif (self.pathname.startswith(_stdlibprefix) - # and 'site-packages' not in self.pathname): - # # We assume it's a stdlib package in this case. Is this too brittle? - # # Please file a bug report at https://github.com/PythonCharmers/python-future - # # if so. - # convert = False - # in theory, other paths could be configured to be excluded here too - elif any([fullname.startswith(path) for path in self.exclude_paths]): - convert = False - elif any([fullname.startswith(path) for path in self.include_paths]): - convert = True - else: - convert = False - if not convert: - logger.debug('Excluded {0} from translation'.format(fullname)) - mod = imp.load_module(fullname, *self.found) - else: - logger.debug('Autoconverting {0} ...'.format(fullname)) - mod = imp.new_module(fullname) - sys.modules[fullname] = mod - - # required by PEP 302 - mod.__file__ = self.pathname - mod.__name__ = fullname - mod.__loader__ = self - - # This: - # mod.__package__ = '.'.join(fullname.split('.')[:-1]) - # seems to result in "SystemError: Parent module '' not loaded, - # cannot perform relative import" for a package's __init__.py - # file. We use the approach below. Another option to try is the - # minimal load_module pattern from the PEP 302 text instead. - - # Is the test in the next line more or less robust than the - # following one? Presumably less ... 
- # ispkg = self.pathname.endswith('__init__.py') - - if self.kind == imp.PKG_DIRECTORY: - mod.__path__ = [ os.path.dirname(self.pathname) ] - mod.__package__ = fullname - else: - #else, regular module - mod.__path__ = [] - mod.__package__ = fullname.rpartition('.')[0] - - try: - cachename = imp.cache_from_source(self.pathname) - if not os.path.exists(cachename): - update_cache = True - else: - sourcetime = os.stat(self.pathname).st_mtime - cachetime = os.stat(cachename).st_mtime - update_cache = cachetime < sourcetime - # # Force update_cache to work around a problem with it being treated as Py3 code??? - # update_cache = True - if not update_cache: - with open(cachename, 'rb') as f: - data = f.read() - try: - code = marshal.loads(data) - except Exception: - # pyc could be corrupt. Regenerate it - update_cache = True - if update_cache: - if self.found[0]: - source = self.found[0].read() - elif self.kind == imp.PKG_DIRECTORY: - with open(self.pathname) as f: - source = f.read() - - if detect_python2(source, self.pathname): - source = self.transform(source) - - code = compile(source, self.pathname, 'exec') - - dirname = os.path.dirname(cachename) - try: - if not os.path.exists(dirname): - os.makedirs(dirname) - with open(cachename, 'wb') as f: - data = marshal.dumps(code) - f.write(data) - except Exception: # could be write-protected - pass - exec(code, mod.__dict__) - except Exception as e: - # must remove module from sys.modules - del sys.modules[fullname] - raise # keep it simple - - if self.found[0]: - self.found[0].close() - return mod _hook = Py2Fixer() diff --git a/lib/typeguard/__init__.py b/lib/typeguard/__init__.py new file mode 100644 index 00000000..6781cad0 --- /dev/null +++ b/lib/typeguard/__init__.py @@ -0,0 +1,48 @@ +import os +from typing import Any + +from ._checkers import TypeCheckerCallable as TypeCheckerCallable +from ._checkers import TypeCheckLookupCallback as TypeCheckLookupCallback +from ._checkers import check_type_internal as check_type_internal +from ._checkers import checker_lookup_functions as checker_lookup_functions +from ._checkers import load_plugins as load_plugins +from ._config import CollectionCheckStrategy as CollectionCheckStrategy +from ._config import ForwardRefPolicy as ForwardRefPolicy +from ._config import TypeCheckConfiguration as TypeCheckConfiguration +from ._decorators import typechecked as typechecked +from ._decorators import typeguard_ignore as typeguard_ignore +from ._exceptions import InstrumentationWarning as InstrumentationWarning +from ._exceptions import TypeCheckError as TypeCheckError +from ._exceptions import TypeCheckWarning as TypeCheckWarning +from ._exceptions import TypeHintWarning as TypeHintWarning +from ._functions import TypeCheckFailCallback as TypeCheckFailCallback +from ._functions import check_type as check_type +from ._functions import warn_on_error as warn_on_error +from ._importhook import ImportHookManager as ImportHookManager +from ._importhook import TypeguardFinder as TypeguardFinder +from ._importhook import install_import_hook as install_import_hook +from ._memo import TypeCheckMemo as TypeCheckMemo +from ._suppression import suppress_type_checks as suppress_type_checks +from ._utils import Unset as Unset + +# Re-export imports so they look like they live directly in this package +for value in list(locals().values()): + if getattr(value, "__module__", "").startswith(f"{__name__}."): + value.__module__ = __name__ + + +config: TypeCheckConfiguration + + +def __getattr__(name: str) -> Any: + if name == 
"config": + from ._config import global_config + + return global_config + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +# Automatically load checker lookup functions unless explicitly disabled +if "TYPEGUARD_DISABLE_PLUGIN_AUTOLOAD" not in os.environ: + load_plugins() diff --git a/lib/typeguard/_checkers.py b/lib/typeguard/_checkers.py new file mode 100644 index 00000000..2f8de6f3 --- /dev/null +++ b/lib/typeguard/_checkers.py @@ -0,0 +1,910 @@ +from __future__ import annotations + +import collections.abc +import inspect +import sys +import types +import typing +import warnings +from enum import Enum +from inspect import Parameter, isclass, isfunction +from io import BufferedIOBase, IOBase, RawIOBase, TextIOBase +from textwrap import indent +from typing import ( + IO, + AbstractSet, + Any, + BinaryIO, + Callable, + Dict, + ForwardRef, + List, + Mapping, + MutableMapping, + NewType, + Optional, + Sequence, + Set, + TextIO, + Tuple, + Type, + TypeVar, + Union, +) +from unittest.mock import Mock + +try: + import typing_extensions +except ImportError: + typing_extensions = None # type: ignore[assignment] + +from ._config import ForwardRefPolicy +from ._exceptions import TypeCheckError, TypeHintWarning +from ._memo import TypeCheckMemo +from ._utils import evaluate_forwardref, get_stacklevel, get_type_name, qualified_name + +if sys.version_info >= (3, 13): + from typing import is_typeddict +else: + from typing_extensions import is_typeddict + +if sys.version_info >= (3, 11): + from typing import ( + Annotated, + NotRequired, + TypeAlias, + get_args, + get_origin, + ) + + SubclassableAny = Any +else: + from typing_extensions import ( + Annotated, + NotRequired, + TypeAlias, + get_args, + get_origin, + ) + from typing_extensions import Any as SubclassableAny + +if sys.version_info >= (3, 10): + from importlib.metadata import entry_points + from typing import ParamSpec +else: + from importlib_metadata import entry_points + from typing_extensions import ParamSpec + +TypeCheckerCallable: TypeAlias = Callable[ + [Any, Any, Tuple[Any, ...], TypeCheckMemo], Any +] +TypeCheckLookupCallback: TypeAlias = Callable[ + [Any, Tuple[Any, ...], Tuple[Any, ...]], Optional[TypeCheckerCallable] +] + +checker_lookup_functions: list[TypeCheckLookupCallback] = [] +generic_alias_types: tuple[type, ...] 
= (type(List), type(List[Any])) +if sys.version_info >= (3, 9): + generic_alias_types += (types.GenericAlias,) + + +# Sentinel +_missing = object() + +# Lifted from mypy.sharedparse +BINARY_MAGIC_METHODS = { + "__add__", + "__and__", + "__cmp__", + "__divmod__", + "__div__", + "__eq__", + "__floordiv__", + "__ge__", + "__gt__", + "__iadd__", + "__iand__", + "__idiv__", + "__ifloordiv__", + "__ilshift__", + "__imatmul__", + "__imod__", + "__imul__", + "__ior__", + "__ipow__", + "__irshift__", + "__isub__", + "__itruediv__", + "__ixor__", + "__le__", + "__lshift__", + "__lt__", + "__matmul__", + "__mod__", + "__mul__", + "__ne__", + "__or__", + "__pow__", + "__radd__", + "__rand__", + "__rdiv__", + "__rfloordiv__", + "__rlshift__", + "__rmatmul__", + "__rmod__", + "__rmul__", + "__ror__", + "__rpow__", + "__rrshift__", + "__rshift__", + "__rsub__", + "__rtruediv__", + "__rxor__", + "__sub__", + "__truediv__", + "__xor__", +} + + +def check_callable( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if not callable(value): + raise TypeCheckError("is not callable") + + if args: + try: + signature = inspect.signature(value) + except (TypeError, ValueError): + return + + argument_types = args[0] + if isinstance(argument_types, list) and not any( + type(item) is ParamSpec for item in argument_types + ): + # The callable must not have keyword-only arguments without defaults + unfulfilled_kwonlyargs = [ + param.name + for param in signature.parameters.values() + if param.kind == Parameter.KEYWORD_ONLY + and param.default == Parameter.empty + ] + if unfulfilled_kwonlyargs: + raise TypeCheckError( + f"has mandatory keyword-only arguments in its declaration: " + f'{", ".join(unfulfilled_kwonlyargs)}' + ) + + num_positional_args = num_mandatory_pos_args = 0 + has_varargs = False + for param in signature.parameters.values(): + if param.kind in ( + Parameter.POSITIONAL_ONLY, + Parameter.POSITIONAL_OR_KEYWORD, + ): + num_positional_args += 1 + if param.default is Parameter.empty: + num_mandatory_pos_args += 1 + elif param.kind == Parameter.VAR_POSITIONAL: + has_varargs = True + + if num_mandatory_pos_args > len(argument_types): + raise TypeCheckError( + f"has too many mandatory positional arguments in its declaration; " + f"expected {len(argument_types)} but {num_mandatory_pos_args} " + f"mandatory positional argument(s) declared" + ) + elif not has_varargs and num_positional_args < len(argument_types): + raise TypeCheckError( + f"has too few arguments in its declaration; expected " + f"{len(argument_types)} but {num_positional_args} argument(s) " + f"declared" + ) + + +def check_mapping( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if origin_type is Dict or origin_type is dict: + if not isinstance(value, dict): + raise TypeCheckError("is not a dict") + if origin_type is MutableMapping or origin_type is collections.abc.MutableMapping: + if not isinstance(value, collections.abc.MutableMapping): + raise TypeCheckError("is not a mutable mapping") + elif not isinstance(value, collections.abc.Mapping): + raise TypeCheckError("is not a mapping") + + if args: + key_type, value_type = args + if key_type is not Any or value_type is not Any: + samples = memo.config.collection_check_strategy.iterate_samples( + value.items() + ) + for k, v in samples: + try: + check_type_internal(k, key_type, memo) + except TypeCheckError as exc: + exc.append_path_element(f"key {k!r}") + raise + + try: + check_type_internal(v, value_type, 
memo) + except TypeCheckError as exc: + exc.append_path_element(f"value of key {k!r}") + raise + + +def check_typed_dict( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if not isinstance(value, dict): + raise TypeCheckError("is not a dict") + + declared_keys = frozenset(origin_type.__annotations__) + if hasattr(origin_type, "__required_keys__"): + required_keys = set(origin_type.__required_keys__) + else: # py3.8 and lower + required_keys = set(declared_keys) if origin_type.__total__ else set() + + existing_keys = set(value) + extra_keys = existing_keys - declared_keys + if extra_keys: + keys_formatted = ", ".join(f'"{key}"' for key in sorted(extra_keys, key=repr)) + raise TypeCheckError(f"has unexpected extra key(s): {keys_formatted}") + + # Detect NotRequired fields which are hidden by get_type_hints() + type_hints: dict[str, type] = {} + for key, annotation in origin_type.__annotations__.items(): + if isinstance(annotation, ForwardRef): + annotation = evaluate_forwardref(annotation, memo) + if get_origin(annotation) is NotRequired: + required_keys.discard(key) + annotation = get_args(annotation)[0] + + type_hints[key] = annotation + + missing_keys = required_keys - existing_keys + if missing_keys: + keys_formatted = ", ".join(f'"{key}"' for key in sorted(missing_keys, key=repr)) + raise TypeCheckError(f"is missing required key(s): {keys_formatted}") + + for key, argtype in type_hints.items(): + argvalue = value.get(key, _missing) + if argvalue is not _missing: + try: + check_type_internal(argvalue, argtype, memo) + except TypeCheckError as exc: + exc.append_path_element(f"value of key {key!r}") + raise + + +def check_list( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if not isinstance(value, list): + raise TypeCheckError("is not a list") + + if args and args != (Any,): + samples = memo.config.collection_check_strategy.iterate_samples(value) + for i, v in enumerate(samples): + try: + check_type_internal(v, args[0], memo) + except TypeCheckError as exc: + exc.append_path_element(f"item {i}") + raise + + +def check_sequence( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if not isinstance(value, collections.abc.Sequence): + raise TypeCheckError("is not a sequence") + + if args and args != (Any,): + samples = memo.config.collection_check_strategy.iterate_samples(value) + for i, v in enumerate(samples): + try: + check_type_internal(v, args[0], memo) + except TypeCheckError as exc: + exc.append_path_element(f"item {i}") + raise + + +def check_set( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if origin_type is frozenset: + if not isinstance(value, frozenset): + raise TypeCheckError("is not a frozenset") + elif not isinstance(value, AbstractSet): + raise TypeCheckError("is not a set") + + if args and args != (Any,): + samples = memo.config.collection_check_strategy.iterate_samples(value) + for v in samples: + try: + check_type_internal(v, args[0], memo) + except TypeCheckError as exc: + exc.append_path_element(f"[{v}]") + raise + + +def check_tuple( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + # Specialized check for NamedTuples + if field_types := getattr(origin_type, "__annotations__", None): + if not isinstance(value, origin_type): + raise TypeCheckError( + f"is not a named tuple of type {qualified_name(origin_type)}" + ) + + for name, 
field_type in field_types.items(): + try: + check_type_internal(getattr(value, name), field_type, memo) + except TypeCheckError as exc: + exc.append_path_element(f"attribute {name!r}") + raise + + return + elif not isinstance(value, tuple): + raise TypeCheckError("is not a tuple") + + if args: + use_ellipsis = args[-1] is Ellipsis + tuple_params = args[: -1 if use_ellipsis else None] + else: + # Unparametrized Tuple or plain tuple + return + + if use_ellipsis: + element_type = tuple_params[0] + samples = memo.config.collection_check_strategy.iterate_samples(value) + for i, element in enumerate(samples): + try: + check_type_internal(element, element_type, memo) + except TypeCheckError as exc: + exc.append_path_element(f"item {i}") + raise + elif tuple_params == ((),): + if value != (): + raise TypeCheckError("is not an empty tuple") + else: + if len(value) != len(tuple_params): + raise TypeCheckError( + f"has wrong number of elements (expected {len(tuple_params)}, got " + f"{len(value)} instead)" + ) + + for i, (element, element_type) in enumerate(zip(value, tuple_params)): + try: + check_type_internal(element, element_type, memo) + except TypeCheckError as exc: + exc.append_path_element(f"item {i}") + raise + + +def check_union( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + errors: dict[str, TypeCheckError] = {} + try: + for type_ in args: + try: + check_type_internal(value, type_, memo) + return + except TypeCheckError as exc: + errors[get_type_name(type_)] = exc + + formatted_errors = indent( + "\n".join(f"{key}: {error}" for key, error in errors.items()), " " + ) + finally: + del errors # avoid creating ref cycle + raise TypeCheckError(f"did not match any element in the union:\n{formatted_errors}") + + +def check_uniontype( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + errors: dict[str, TypeCheckError] = {} + for type_ in args: + try: + check_type_internal(value, type_, memo) + return + except TypeCheckError as exc: + errors[get_type_name(type_)] = exc + + formatted_errors = indent( + "\n".join(f"{key}: {error}" for key, error in errors.items()), " " + ) + raise TypeCheckError(f"did not match any element in the union:\n{formatted_errors}") + + +def check_class( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if not isclass(value) and not isinstance(value, generic_alias_types): + raise TypeCheckError("is not a class") + + if not args: + return + + if isinstance(args[0], ForwardRef): + expected_class = evaluate_forwardref(args[0], memo) + else: + expected_class = args[0] + + if expected_class is Any: + return + elif getattr(expected_class, "_is_protocol", False): + check_protocol(value, expected_class, (), memo) + elif isinstance(expected_class, TypeVar): + check_typevar(value, expected_class, (), memo, subclass_check=True) + elif get_origin(expected_class) is Union: + errors: dict[str, TypeCheckError] = {} + for arg in get_args(expected_class): + if arg is Any: + return + + try: + check_class(value, type, (arg,), memo) + return + except TypeCheckError as exc: + errors[get_type_name(arg)] = exc + else: + formatted_errors = indent( + "\n".join(f"{key}: {error}" for key, error in errors.items()), " " + ) + raise TypeCheckError( + f"did not match any element in the union:\n{formatted_errors}" + ) + elif not issubclass(value, expected_class): # type: ignore[arg-type] + raise TypeCheckError(f"is not a subclass of {qualified_name(expected_class)}") 
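+
+# Editor's note: a brief usage sketch (not part of the vendored file) showing
+# how check_class above is reached through the public API re-exported in
+# lib/typeguard/__init__.py:
+#
+#     from typing import Type, Union
+#     from typeguard import check_type, TypeCheckError
+#
+#     check_type(int, Type[Union[int, str]])    # passes: int is in the union
+#     check_type(bytes, Type[Union[int, str]])  # raises TypeCheckError:
+#     # "class bytes did not match any element in the union: ..."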
+ + +def check_newtype( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + check_type_internal(value, origin_type.__supertype__, memo) + + +def check_instance( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if not isinstance(value, origin_type): + raise TypeCheckError(f"is not an instance of {qualified_name(origin_type)}") + + +def check_typevar( + value: Any, + origin_type: TypeVar, + args: tuple[Any, ...], + memo: TypeCheckMemo, + *, + subclass_check: bool = False, +) -> None: + if origin_type.__bound__ is not None: + annotation = ( + Type[origin_type.__bound__] if subclass_check else origin_type.__bound__ + ) + check_type_internal(value, annotation, memo) + elif origin_type.__constraints__: + for constraint in origin_type.__constraints__: + annotation = Type[constraint] if subclass_check else constraint + try: + check_type_internal(value, annotation, memo) + except TypeCheckError: + pass + else: + break + else: + formatted_constraints = ", ".join( + get_type_name(constraint) for constraint in origin_type.__constraints__ + ) + raise TypeCheckError( + f"does not match any of the constraints " f"({formatted_constraints})" + ) + + +if typing_extensions is None: + + def _is_literal_type(typ: object) -> bool: + return typ is typing.Literal + +else: + + def _is_literal_type(typ: object) -> bool: + return typ is typing.Literal or typ is typing_extensions.Literal + + +def check_literal( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + def get_literal_args(literal_args: tuple[Any, ...]) -> tuple[Any, ...]: + retval: list[Any] = [] + for arg in literal_args: + if _is_literal_type(get_origin(arg)): + retval.extend(get_literal_args(arg.__args__)) + elif arg is None or isinstance(arg, (int, str, bytes, bool, Enum)): + retval.append(arg) + else: + raise TypeError( + f"Illegal literal value: {arg}" + ) # TypeError here is deliberate + + return tuple(retval) + + final_args = tuple(get_literal_args(args)) + try: + index = final_args.index(value) + except ValueError: + pass + else: + if type(final_args[index]) is type(value): + return + + formatted_args = ", ".join(repr(arg) for arg in final_args) + raise TypeCheckError(f"is not any of ({formatted_args})") from None + + +def check_literal_string( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + check_type_internal(value, str, memo) + + +def check_typeguard( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + check_type_internal(value, bool, memo) + + +def check_none( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if value is not None: + raise TypeCheckError("is not None") + + +def check_number( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if origin_type is complex and not isinstance(value, (complex, float, int)): + raise TypeCheckError("is neither complex, float or int") + elif origin_type is float and not isinstance(value, (float, int)): + raise TypeCheckError("is neither float or int") + + +def check_io( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if origin_type is TextIO or (origin_type is IO and args == (str,)): + if not isinstance(value, TextIOBase): + raise TypeCheckError("is not a text based I/O object") + elif origin_type is BinaryIO or (origin_type is IO and args == 
(bytes,)): + if not isinstance(value, (RawIOBase, BufferedIOBase)): + raise TypeCheckError("is not a binary I/O object") + elif not isinstance(value, IOBase): + raise TypeCheckError("is not an I/O object") + + +def check_protocol( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + # TODO: implement proper compatibility checking and support non-runtime protocols + if getattr(origin_type, "_is_runtime_protocol", False): + if not isinstance(value, origin_type): + raise TypeCheckError( + f"is not compatible with the {origin_type.__qualname__} protocol" + ) + else: + warnings.warn( + f"Typeguard cannot check the {origin_type.__qualname__} protocol because " + f"it is a non-runtime protocol. If you would like to type check this " + f"protocol, please use @typing.runtime_checkable", + stacklevel=get_stacklevel(), + ) + + +def check_byteslike( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if not isinstance(value, (bytearray, bytes, memoryview)): + raise TypeCheckError("is not bytes-like") + + +def check_self( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if memo.self_type is None: + raise TypeCheckError("cannot be checked against Self outside of a method call") + + if isclass(value): + if not issubclass(value, memo.self_type): + raise TypeCheckError( + f"is not an instance of the self type " + f"({qualified_name(memo.self_type)})" + ) + elif not isinstance(value, memo.self_type): + raise TypeCheckError( + f"is not an instance of the self type ({qualified_name(memo.self_type)})" + ) + + +def check_paramspec( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + pass # No-op for now + + +def check_instanceof( + value: Any, + origin_type: Any, + args: tuple[Any, ...], + memo: TypeCheckMemo, +) -> None: + if not isinstance(value, origin_type): + raise TypeCheckError(f"is not an instance of {qualified_name(origin_type)}") + + +def check_type_internal( + value: Any, + annotation: Any, + memo: TypeCheckMemo, +) -> None: + """ + Check that the given object is compatible with the given type annotation. + + This function should only be used by type checker callables. Applications should use + :func:`~.check_type` instead. + + :param value: the value to check + :param annotation: the type annotation to check against + :param memo: a memo object containing configuration and information necessary for + looking up forward references + """ + + if isinstance(annotation, ForwardRef): + try: + annotation = evaluate_forwardref(annotation, memo) + except NameError: + if memo.config.forward_ref_policy is ForwardRefPolicy.ERROR: + raise + elif memo.config.forward_ref_policy is ForwardRefPolicy.WARN: + warnings.warn( + f"Cannot resolve forward reference {annotation.__forward_arg__!r}", + TypeHintWarning, + stacklevel=get_stacklevel(), + ) + + return + + if annotation is Any or annotation is SubclassableAny or isinstance(value, Mock): + return + + # Skip type checks if value is an instance of a class that inherits from Any + if not isclass(value) and SubclassableAny in type(value).__bases__: + return + + extras: tuple[Any, ...] 
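+    # The block below unwraps typing.Annotated annotations: the metadata ends
+    # up in `extras` (for checker lookup functions to inspect) while only the
+    # wrapped annotation is actually checked. For instance (editor's sketch),
+    # Annotated[List[int], "meta"] is reduced to origin_type=list, args=(int,)
+    # and extras=("meta",).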
+ origin_type = get_origin(annotation) + if origin_type is Annotated: + annotation, *extras_ = get_args(annotation) + extras = tuple(extras_) + origin_type = get_origin(annotation) + else: + extras = () + + if origin_type is not None: + args = get_args(annotation) + + # Compatibility hack to distinguish between unparametrized and empty tuple + # (tuple[()]), necessary due to https://github.com/python/cpython/issues/91137 + if origin_type in (tuple, Tuple) and annotation is not Tuple and not args: + args = ((),) + else: + origin_type = annotation + args = () + + for lookup_func in checker_lookup_functions: + checker = lookup_func(origin_type, args, extras) + if checker: + checker(value, origin_type, args, memo) + return + + if isclass(origin_type): + if not isinstance(value, origin_type): + raise TypeCheckError(f"is not an instance of {qualified_name(origin_type)}") + elif type(origin_type) is str: # noqa: E721 + warnings.warn( + f"Skipping type check against {origin_type!r}; this looks like a " + f"string-form forward reference imported from another module", + TypeHintWarning, + stacklevel=get_stacklevel(), + ) + + +# Equality checks are applied to these +origin_type_checkers = { + bytes: check_byteslike, + AbstractSet: check_set, + BinaryIO: check_io, + Callable: check_callable, + collections.abc.Callable: check_callable, + complex: check_number, + dict: check_mapping, + Dict: check_mapping, + float: check_number, + frozenset: check_set, + IO: check_io, + list: check_list, + List: check_list, + typing.Literal: check_literal, + Mapping: check_mapping, + MutableMapping: check_mapping, + None: check_none, + collections.abc.Mapping: check_mapping, + collections.abc.MutableMapping: check_mapping, + Sequence: check_sequence, + collections.abc.Sequence: check_sequence, + collections.abc.Set: check_set, + set: check_set, + Set: check_set, + TextIO: check_io, + tuple: check_tuple, + Tuple: check_tuple, + type: check_class, + Type: check_class, + Union: check_union, +} +if sys.version_info >= (3, 10): + origin_type_checkers[types.UnionType] = check_uniontype + origin_type_checkers[typing.TypeGuard] = check_typeguard +if sys.version_info >= (3, 11): + origin_type_checkers.update( + {typing.LiteralString: check_literal_string, typing.Self: check_self} + ) +if typing_extensions is not None: + # On some Python versions, these may simply be re-exports from typing, + # but exactly which Python versions is subject to change, + # so it's best to err on the safe side + # and update the dictionary on all Python versions + # if typing_extensions is installed + origin_type_checkers[typing_extensions.Literal] = check_literal + origin_type_checkers[typing_extensions.LiteralString] = check_literal_string + origin_type_checkers[typing_extensions.Self] = check_self + origin_type_checkers[typing_extensions.TypeGuard] = check_typeguard + + +def builtin_checker_lookup( + origin_type: Any, args: tuple[Any, ...], extras: tuple[Any, ...] 
+) -> TypeCheckerCallable | None:
+    checker = origin_type_checkers.get(origin_type)
+    if checker is not None:
+        return checker
+    elif is_typeddict(origin_type):
+        return check_typed_dict
+    elif isclass(origin_type) and issubclass(
+        origin_type, Tuple  # type: ignore[arg-type]
+    ):
+        # NamedTuple
+        return check_tuple
+    elif getattr(origin_type, "_is_protocol", False):
+        return check_protocol
+    elif isinstance(origin_type, ParamSpec):
+        return check_paramspec
+    elif isinstance(origin_type, TypeVar):
+        return check_typevar
+    elif origin_type.__class__ is NewType:
+        # typing.NewType on Python 3.10+
+        return check_newtype
+    elif (
+        isfunction(origin_type)
+        and getattr(origin_type, "__module__", None) == "typing"
+        and getattr(origin_type, "__qualname__", "").startswith("NewType.")
+        and hasattr(origin_type, "__supertype__")
+    ):
+        # typing.NewType on Python 3.9 and below
+        return check_newtype
+
+    return None
+
+
+checker_lookup_functions.append(builtin_checker_lookup)
+
+
+def load_plugins() -> None:
+    """
+    Load all type checker lookup functions from entry points.
+
+    All entry points from the ``typeguard.checker_lookup`` group are loaded, and the
+    returned lookup functions are added to :data:`typeguard.checker_lookup_functions`.
+
+    .. note:: This function is called implicitly on import, unless the
+        ``TYPEGUARD_DISABLE_PLUGIN_AUTOLOAD`` environment variable is present.
+    """
+
+    for ep in entry_points(group="typeguard.checker_lookup"):
+        try:
+            plugin = ep.load()
+        except Exception as exc:
+            warnings.warn(
+                f"Failed to load plugin {ep.name!r}: " f"{qualified_name(exc)}: {exc}",
+                stacklevel=2,
+            )
+            continue
+
+        if not callable(plugin):
+            warnings.warn(
+                f"Plugin {ep} returned a non-callable object: {plugin!r}", stacklevel=2
+            )
+            continue
+
+        checker_lookup_functions.insert(0, plugin)
diff --git a/lib/typeguard/_config.py b/lib/typeguard/_config.py
new file mode 100644
index 00000000..36efad53
--- /dev/null
+++ b/lib/typeguard/_config.py
@@ -0,0 +1,108 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import TYPE_CHECKING, TypeVar
+
+if TYPE_CHECKING:
+    from ._functions import TypeCheckFailCallback
+
+T = TypeVar("T")
+
+
+class ForwardRefPolicy(Enum):
+    """
+    Defines how unresolved forward references are handled.
+
+    Members:
+
+    * ``ERROR``: propagate the :exc:`NameError` when the forward reference lookup fails
+    * ``WARN``: emit a :class:`~.TypeHintWarning` if the forward reference lookup fails
+    * ``IGNORE``: silently skip checks for unresolvable forward references
+    """
+
+    ERROR = auto()
+    WARN = auto()
+    IGNORE = auto()
+
+
+class CollectionCheckStrategy(Enum):
+    """
+    Specifies how thoroughly the contents of collections are type checked.
+
+    This has an effect on the following built-in checkers:
+
+    * ``AbstractSet``
+    * ``Dict``
+    * ``List``
+    * ``Mapping``
+    * ``Set``
+    * ``Tuple[<type>, ...]`` (arbitrarily sized tuples)
+
+    Members:
+
+    * ``FIRST_ITEM``: check only the first item
+    * ``ALL_ITEMS``: check all items
+    """
+
+    FIRST_ITEM = auto()
+    ALL_ITEMS = auto()
+
+    def iterate_samples(self, collection: Iterable[T]) -> Iterable[T]:
+        if self is CollectionCheckStrategy.FIRST_ITEM:
+            try:
+                return [next(iter(collection))]
+            except StopIteration:
+                return ()
+        else:
+            return collection
+
+
+@dataclass
+class TypeCheckConfiguration:
+    """
+    You can change Typeguard's behavior with these settings.
+
+    .. attribute:: typecheck_fail_callback
+        :type: Callable[[TypeCheckError, TypeCheckMemo], Any]
+
+        Callable that is called when type checking fails.
+
+        Default: ``None`` (the :exc:`~.TypeCheckError` is raised directly)
+
+    .. attribute:: forward_ref_policy
+        :type: ForwardRefPolicy
+
+        Specifies what to do when a forward reference fails to resolve.
+
+        Default: ``WARN``
+
+    .. attribute:: collection_check_strategy
+        :type: CollectionCheckStrategy
+
+        Specifies how thoroughly the contents of collections (list, dict, etc.) are
+        type checked.
+
+        Default: ``FIRST_ITEM``
+
+    .. attribute:: debug_instrumentation
+        :type: bool
+
+        If set to ``True``, the code of modules or functions instrumented by typeguard
+        is printed to ``sys.stderr`` after the instrumentation is done
+
+        Requires Python 3.9 or newer.
+
+        Default: ``False``
+    """
+
+    forward_ref_policy: ForwardRefPolicy = ForwardRefPolicy.WARN
+    typecheck_fail_callback: TypeCheckFailCallback | None = None
+    collection_check_strategy: CollectionCheckStrategy = (
+        CollectionCheckStrategy.FIRST_ITEM
+    )
+    debug_instrumentation: bool = False
+
+
+global_config = TypeCheckConfiguration()
diff --git a/lib/typeguard/_decorators.py b/lib/typeguard/_decorators.py
new file mode 100644
index 00000000..cf325335
--- /dev/null
+++ b/lib/typeguard/_decorators.py
@@ -0,0 +1,235 @@
+from __future__ import annotations
+
+import ast
+import inspect
+import sys
+from collections.abc import Sequence
+from functools import partial
+from inspect import isclass, isfunction
+from types import CodeType, FrameType, FunctionType
+from typing import TYPE_CHECKING, Any, Callable, ForwardRef, TypeVar, cast, overload
+from warnings import warn
+
+from ._config import CollectionCheckStrategy, ForwardRefPolicy, global_config
+from ._exceptions import InstrumentationWarning
+from ._functions import TypeCheckFailCallback
+from ._transformer import TypeguardTransformer
+from ._utils import Unset, function_name, get_stacklevel, is_method_of, unset
+
+if TYPE_CHECKING:
+    from typeshed.stdlib.types import _Cell
+
+    _F = TypeVar("_F")
+
+    def typeguard_ignore(f: _F) -> _F:
+        """This decorator is a noop during static type-checking."""
+        return f
+
+else:
+    from typing import no_type_check as typeguard_ignore  # noqa: F401
+
+T_CallableOrType = TypeVar("T_CallableOrType", bound=Callable[..., Any])
+
+
+def make_cell(value: object) -> _Cell:
+    return (lambda: value).__closure__[0]  # type: ignore[index]
+
+
+def find_target_function(
+    new_code: CodeType, target_path: Sequence[str], firstlineno: int
+) -> CodeType | None:
+    target_name = target_path[0]
+    for const in new_code.co_consts:
+        if isinstance(const, CodeType):
+            if const.co_name == target_name:
+                if const.co_firstlineno == firstlineno:
+                    return const
+                elif len(target_path) > 1:
+                    target_code = find_target_function(
+                        const, target_path[1:], firstlineno
+                    )
+                    if target_code:
+                        return target_code
+
+    return None
+
+
+def instrument(f: T_CallableOrType) -> FunctionType | str:
+    if not getattr(f, "__code__", None):
+        return "no code associated"
+    elif not getattr(f, "__module__", None):
+        return "__module__ attribute is not set"
+    elif f.__code__.co_filename == "<stdin>":
+        return "cannot instrument functions defined in a REPL"
+    elif hasattr(f, "__wrapped__"):
+        return (
+            "@typechecked only supports instrumenting functions wrapped with "
+            "@classmethod, @staticmethod or @property"
+        )
+
+    target_path = [item for item in f.__qualname__.split(".") if item != "<locals>"]
+    module_source = inspect.getsource(sys.modules[f.__module__])
+    module_ast = ast.parse(module_source)
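+    # Editor's note: the overall flow here is to re-parse the whole module,
+    # let TypeguardTransformer inject the type checks into the AST, recompile,
+    # and then locate the replacement code object via the function's qualname
+    # path and original first line number (see find_target_function above).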
+    instrumentor = TypeguardTransformer(target_path, f.__code__.co_firstlineno)
+    instrumentor.visit(module_ast)
+
+    if not instrumentor.target_node or instrumentor.target_lineno is None:
+        return "instrumentor did not find the target function"
+
+    module_code = compile(module_ast, f.__code__.co_filename, "exec", dont_inherit=True)
+    new_code = find_target_function(
+        module_code, target_path, instrumentor.target_lineno
+    )
+    if not new_code:
+        return "cannot find the target function in the AST"
+
+    if global_config.debug_instrumentation and sys.version_info >= (3, 9):
+        # Find the matching AST node, then unparse it to source and print to stderr
+        print(
+            f"Source code of {f.__qualname__}() after instrumentation:"
+            "\n----------------------------------------------",
+            file=sys.stderr,
+        )
+        print(ast.unparse(instrumentor.target_node), file=sys.stderr)
+        print(
+            "----------------------------------------------",
+            file=sys.stderr,
+        )
+
+    closure = f.__closure__
+    if new_code.co_freevars != f.__code__.co_freevars:
+        # Create a new closure and find values for the new free variables
+        frame = cast(FrameType, inspect.currentframe())
+        frame = cast(FrameType, frame.f_back)
+        frame_locals = cast(FrameType, frame.f_back).f_locals
+        cells: list[_Cell] = []
+        for key in new_code.co_freevars:
+            if key in instrumentor.names_used_in_annotations:
+                # Find the value and make a new cell from it
+                value = frame_locals.get(key) or ForwardRef(key)
+                cells.append(make_cell(value))
+            else:
+                # Reuse the cell from the existing closure
+                assert f.__closure__
+                cells.append(f.__closure__[f.__code__.co_freevars.index(key)])
+
+        closure = tuple(cells)
+
+    new_function = FunctionType(new_code, f.__globals__, f.__name__, closure=closure)
+    new_function.__module__ = f.__module__
+    new_function.__name__ = f.__name__
+    new_function.__qualname__ = f.__qualname__
+    new_function.__annotations__ = f.__annotations__
+    new_function.__doc__ = f.__doc__
+    new_function.__defaults__ = f.__defaults__
+    new_function.__kwdefaults__ = f.__kwdefaults__
+    return new_function
+
+
+@overload
+def typechecked(
+    *,
+    forward_ref_policy: ForwardRefPolicy | Unset = unset,
+    typecheck_fail_callback: TypeCheckFailCallback | Unset = unset,
+    collection_check_strategy: CollectionCheckStrategy | Unset = unset,
+    debug_instrumentation: bool | Unset = unset,
+) -> Callable[[T_CallableOrType], T_CallableOrType]: ...
+
+
+@overload
+def typechecked(target: T_CallableOrType) -> T_CallableOrType: ...
+
+
+def typechecked(
+    target: T_CallableOrType | None = None,
+    *,
+    forward_ref_policy: ForwardRefPolicy | Unset = unset,
+    typecheck_fail_callback: TypeCheckFailCallback | Unset = unset,
+    collection_check_strategy: CollectionCheckStrategy | Unset = unset,
+    debug_instrumentation: bool | Unset = unset,
+) -> Any:
+    """
+    Instrument the target function to perform run-time type checking.
+
+    This decorator recompiles the target function, injecting code to type check
+    arguments, return values, yield values (excluding ``yield from``) and assignments to
+    annotated local variables.
+
+    This can also be used as a class decorator. This will instrument all type annotated
+    methods, including :func:`@classmethod <classmethod>`,
+    :func:`@staticmethod <staticmethod>`, and :class:`@property <property>` decorated
+    methods in the class.
+
+    .. note:: When Python is run in optimized mode (``-O`` or ``-OO``), this decorator
+        is a no-op. 
This is a feature meant for selectively introducing type checking + into a code base where the checks aren't meant to be run in production. + + :param target: the function or class to enable type checking for + :param forward_ref_policy: override for + :attr:`.TypeCheckConfiguration.forward_ref_policy` + :param typecheck_fail_callback: override for + :attr:`.TypeCheckConfiguration.typecheck_fail_callback` + :param collection_check_strategy: override for + :attr:`.TypeCheckConfiguration.collection_check_strategy` + :param debug_instrumentation: override for + :attr:`.TypeCheckConfiguration.debug_instrumentation` + + """ + if target is None: + return partial( + typechecked, + forward_ref_policy=forward_ref_policy, + typecheck_fail_callback=typecheck_fail_callback, + collection_check_strategy=collection_check_strategy, + debug_instrumentation=debug_instrumentation, + ) + + if not __debug__: + return target + + if isclass(target): + for key, attr in target.__dict__.items(): + if is_method_of(attr, target): + retval = instrument(attr) + if isfunction(retval): + setattr(target, key, retval) + elif isinstance(attr, (classmethod, staticmethod)): + if is_method_of(attr.__func__, target): + retval = instrument(attr.__func__) + if isfunction(retval): + wrapper = attr.__class__(retval) + setattr(target, key, wrapper) + elif isinstance(attr, property): + kwargs: dict[str, Any] = dict(doc=attr.__doc__) + for name in ("fset", "fget", "fdel"): + property_func = kwargs[name] = getattr(attr, name) + if is_method_of(property_func, target): + retval = instrument(property_func) + if isfunction(retval): + kwargs[name] = retval + + setattr(target, key, attr.__class__(**kwargs)) + + return target + + # Find either the first Python wrapper or the actual function + wrapper_class: ( + type[classmethod[Any, Any, Any]] | type[staticmethod[Any, Any]] | None + ) = None + if isinstance(target, (classmethod, staticmethod)): + wrapper_class = target.__class__ + target = target.__func__ + + retval = instrument(target) + if isinstance(retval, str): + warn( + f"{retval} -- not typechecking {function_name(target)}", + InstrumentationWarning, + stacklevel=get_stacklevel(), + ) + return target + + if wrapper_class is None: + return retval + else: + return wrapper_class(retval) diff --git a/lib/typeguard/_exceptions.py b/lib/typeguard/_exceptions.py new file mode 100644 index 00000000..625437a6 --- /dev/null +++ b/lib/typeguard/_exceptions.py @@ -0,0 +1,42 @@ +from collections import deque +from typing import Deque + + +class TypeHintWarning(UserWarning): + """ + A warning that is emitted when a type hint in string form could not be resolved to + an actual type. + """ + + +class TypeCheckWarning(UserWarning): + """Emitted by typeguard's type checkers when a type mismatch is detected.""" + + def __init__(self, message: str): + super().__init__(message) + + +class InstrumentationWarning(UserWarning): + """Emitted when there's a problem with instrumenting a function for type checks.""" + + def __init__(self, message: str): + super().__init__(message) + + +class TypeCheckError(Exception): + """ + Raised by typeguard's type checkers when a type mismatch is detected. 
+ """ + + def __init__(self, message: str): + super().__init__(message) + self._path: Deque[str] = deque() + + def append_path_element(self, element: str) -> None: + self._path.append(element) + + def __str__(self) -> str: + if self._path: + return " of ".join(self._path) + " " + str(self.args[0]) + else: + return str(self.args[0]) diff --git a/lib/typeguard/_functions.py b/lib/typeguard/_functions.py new file mode 100644 index 00000000..28497856 --- /dev/null +++ b/lib/typeguard/_functions.py @@ -0,0 +1,308 @@ +from __future__ import annotations + +import sys +import warnings +from typing import Any, Callable, NoReturn, TypeVar, Union, overload + +from . import _suppression +from ._checkers import BINARY_MAGIC_METHODS, check_type_internal +from ._config import ( + CollectionCheckStrategy, + ForwardRefPolicy, + TypeCheckConfiguration, +) +from ._exceptions import TypeCheckError, TypeCheckWarning +from ._memo import TypeCheckMemo +from ._utils import get_stacklevel, qualified_name + +if sys.version_info >= (3, 11): + from typing import Literal, Never, TypeAlias +else: + from typing_extensions import Literal, Never, TypeAlias + +T = TypeVar("T") +TypeCheckFailCallback: TypeAlias = Callable[[TypeCheckError, TypeCheckMemo], Any] + + +@overload +def check_type( + value: object, + expected_type: type[T], + *, + forward_ref_policy: ForwardRefPolicy = ..., + typecheck_fail_callback: TypeCheckFailCallback | None = ..., + collection_check_strategy: CollectionCheckStrategy = ..., +) -> T: ... + + +@overload +def check_type( + value: object, + expected_type: Any, + *, + forward_ref_policy: ForwardRefPolicy = ..., + typecheck_fail_callback: TypeCheckFailCallback | None = ..., + collection_check_strategy: CollectionCheckStrategy = ..., +) -> Any: ... + + +def check_type( + value: object, + expected_type: Any, + *, + forward_ref_policy: ForwardRefPolicy = TypeCheckConfiguration().forward_ref_policy, + typecheck_fail_callback: TypeCheckFailCallback | None = ( + TypeCheckConfiguration().typecheck_fail_callback + ), + collection_check_strategy: CollectionCheckStrategy = ( + TypeCheckConfiguration().collection_check_strategy + ), +) -> Any: + """ + Ensure that ``value`` matches ``expected_type``. + + The types from the :mod:`typing` module do not support :func:`isinstance` or + :func:`issubclass` so a number of type specific checks are required. This function + knows which checker to call for which type. + + This function wraps :func:`~.check_type_internal` in the following ways: + + * Respects type checking suppression (:func:`~.suppress_type_checks`) + * Forms a :class:`~.TypeCheckMemo` from the current stack frame + * Calls the configured type check fail callback if the check fails + + Note that this function is independent of the globally shared configuration in + :data:`typeguard.config`. This means that usage within libraries is safe from being + affected configuration changes made by other libraries or by the integrating + application. Instead, configuration options have the same default values as their + corresponding fields in :class:`TypeCheckConfiguration`. 
+
+    :param value: value to be checked against ``expected_type``
+    :param expected_type: a class or generic type instance, or a tuple of such things
+    :param forward_ref_policy: see :attr:`TypeCheckConfiguration.forward_ref_policy`
+    :param typecheck_fail_callback:
+        see :attr:`TypeCheckConfiguration.typecheck_fail_callback`
+    :param collection_check_strategy:
+        see :attr:`TypeCheckConfiguration.collection_check_strategy`
+    :return: ``value``, unmodified
+    :raises TypeCheckError: if there is a type mismatch
+
+    """
+    if type(expected_type) is tuple:
+        expected_type = Union[expected_type]
+
+    config = TypeCheckConfiguration(
+        forward_ref_policy=forward_ref_policy,
+        typecheck_fail_callback=typecheck_fail_callback,
+        collection_check_strategy=collection_check_strategy,
+    )
+
+    if _suppression.type_checks_suppressed or expected_type is Any:
+        return value
+
+    frame = sys._getframe(1)
+    memo = TypeCheckMemo(frame.f_globals, frame.f_locals, config=config)
+    try:
+        check_type_internal(value, expected_type, memo)
+    except TypeCheckError as exc:
+        exc.append_path_element(qualified_name(value, add_class_prefix=True))
+        if config.typecheck_fail_callback:
+            config.typecheck_fail_callback(exc, memo)
+        else:
+            raise
+
+    return value
+
+
+def check_argument_types(
+    func_name: str,
+    arguments: dict[str, tuple[Any, Any]],
+    memo: TypeCheckMemo,
+) -> Literal[True]:
+    if _suppression.type_checks_suppressed:
+        return True
+
+    for argname, (value, annotation) in arguments.items():
+        if annotation is NoReturn or annotation is Never:
+            exc = TypeCheckError(
+                f"{func_name}() was declared never to be called but it was"
+            )
+            if memo.config.typecheck_fail_callback:
+                memo.config.typecheck_fail_callback(exc, memo)
+            else:
+                raise exc
+
+        try:
+            check_type_internal(value, annotation, memo)
+        except TypeCheckError as exc:
+            qualname = qualified_name(value, add_class_prefix=True)
+            exc.append_path_element(f'argument "{argname}" ({qualname})')
+            if memo.config.typecheck_fail_callback:
+                memo.config.typecheck_fail_callback(exc, memo)
+            else:
+                raise
+
+    return True
+
+
+def check_return_type(
+    func_name: str,
+    retval: T,
+    annotation: Any,
+    memo: TypeCheckMemo,
+) -> T:
+    if _suppression.type_checks_suppressed:
+        return retval
+
+    if annotation is NoReturn or annotation is Never:
+        exc = TypeCheckError(f"{func_name}() was declared never to return but it did")
+        if memo.config.typecheck_fail_callback:
+            memo.config.typecheck_fail_callback(exc, memo)
+        else:
+            raise exc
+
+    try:
+        check_type_internal(retval, annotation, memo)
+    except TypeCheckError as exc:
+        # Allow NotImplemented if this is a binary magic method (__eq__() et al)
+        if retval is NotImplemented and annotation is bool:
+            # This does not (and cannot) check if it's actually a method
+            func_name = func_name.rsplit(".", 1)[-1]
+            if func_name in BINARY_MAGIC_METHODS:
+                return retval
+
+        qualname = qualified_name(retval, add_class_prefix=True)
+        exc.append_path_element(f"the return value ({qualname})")
+        if memo.config.typecheck_fail_callback:
+            memo.config.typecheck_fail_callback(exc, memo)
+        else:
+            raise
+
+    return retval
+
+
+def check_send_type(
+    func_name: str,
+    sendval: T,
+    annotation: Any,
+    memo: TypeCheckMemo,
+) -> T:
+    if _suppression.type_checks_suppressed:
+        return sendval
+
+    if annotation is NoReturn or annotation is Never:
+        exc = TypeCheckError(
+            f"{func_name}() was declared never to be sent a value to but it was"
+        )
+        if memo.config.typecheck_fail_callback:
+            memo.config.typecheck_fail_callback(exc, memo)
+        else:
+            raise exc
+
+    try:
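+        # Editor's note: check_argument_types, check_return_type, check_send_type
+        # and check_yield_type all follow this pattern: delegate to
+        # check_type_internal() and, on a mismatch, tag the error with where the
+        # offending value came from (here, the value sent to a generator) before
+        # raising it or handing it to the configured typecheck_fail_callback.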
+ check_type_internal(sendval, annotation, memo) + except TypeCheckError as exc: + qualname = qualified_name(sendval, add_class_prefix=True) + exc.append_path_element(f"the value sent to generator ({qualname})") + if memo.config.typecheck_fail_callback: + memo.config.typecheck_fail_callback(exc, memo) + else: + raise + + return sendval + + +def check_yield_type( + func_name: str, + yieldval: T, + annotation: Any, + memo: TypeCheckMemo, +) -> T: + if _suppression.type_checks_suppressed: + return yieldval + + if annotation is NoReturn or annotation is Never: + exc = TypeCheckError(f"{func_name}() was declared never to yield but it did") + if memo.config.typecheck_fail_callback: + memo.config.typecheck_fail_callback(exc, memo) + else: + raise exc + + try: + check_type_internal(yieldval, annotation, memo) + except TypeCheckError as exc: + qualname = qualified_name(yieldval, add_class_prefix=True) + exc.append_path_element(f"the yielded value ({qualname})") + if memo.config.typecheck_fail_callback: + memo.config.typecheck_fail_callback(exc, memo) + else: + raise + + return yieldval + + +def check_variable_assignment( + value: object, varname: str, annotation: Any, memo: TypeCheckMemo +) -> Any: + if _suppression.type_checks_suppressed: + return value + + try: + check_type_internal(value, annotation, memo) + except TypeCheckError as exc: + qualname = qualified_name(value, add_class_prefix=True) + exc.append_path_element(f"value assigned to {varname} ({qualname})") + if memo.config.typecheck_fail_callback: + memo.config.typecheck_fail_callback(exc, memo) + else: + raise + + return value + + +def check_multi_variable_assignment( + value: Any, targets: list[dict[str, Any]], memo: TypeCheckMemo +) -> Any: + if max(len(target) for target in targets) == 1: + iterated_values = [value] + else: + iterated_values = list(value) + + if not _suppression.type_checks_suppressed: + for expected_types in targets: + value_index = 0 + for ann_index, (varname, expected_type) in enumerate( + expected_types.items() + ): + if varname.startswith("*"): + varname = varname[1:] + keys_left = len(expected_types) - 1 - ann_index + next_value_index = len(iterated_values) - keys_left + obj: object = iterated_values[value_index:next_value_index] + value_index = next_value_index + else: + obj = iterated_values[value_index] + value_index += 1 + + try: + check_type_internal(obj, expected_type, memo) + except TypeCheckError as exc: + qualname = qualified_name(obj, add_class_prefix=True) + exc.append_path_element(f"value assigned to {varname} ({qualname})") + if memo.config.typecheck_fail_callback: + memo.config.typecheck_fail_callback(exc, memo) + else: + raise + + return iterated_values[0] if len(iterated_values) == 1 else iterated_values + + +def warn_on_error(exc: TypeCheckError, memo: TypeCheckMemo) -> None: + """ + Emit a warning on a type mismatch. + + This is intended to be used as an error handler in + :attr:`TypeCheckConfiguration.typecheck_fail_callback`. 
+ + """ + warnings.warn(TypeCheckWarning(str(exc)), stacklevel=get_stacklevel()) diff --git a/lib/typeguard/_importhook.py b/lib/typeguard/_importhook.py new file mode 100644 index 00000000..8590540a --- /dev/null +++ b/lib/typeguard/_importhook.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +import ast +import sys +import types +from collections.abc import Callable, Iterable +from importlib.abc import MetaPathFinder +from importlib.machinery import ModuleSpec, SourceFileLoader +from importlib.util import cache_from_source, decode_source +from inspect import isclass +from os import PathLike +from types import CodeType, ModuleType, TracebackType +from typing import Sequence, TypeVar +from unittest.mock import patch + +from ._config import global_config +from ._transformer import TypeguardTransformer + +if sys.version_info >= (3, 12): + from collections.abc import Buffer +else: + from typing_extensions import Buffer + +if sys.version_info >= (3, 11): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + +if sys.version_info >= (3, 10): + from importlib.metadata import PackageNotFoundError, version +else: + from importlib_metadata import PackageNotFoundError, version + +try: + OPTIMIZATION = "typeguard" + "".join(version("typeguard").split(".")[:3]) +except PackageNotFoundError: + OPTIMIZATION = "typeguard" + +P = ParamSpec("P") +T = TypeVar("T") + + +# The name of this function is magical +def _call_with_frames_removed( + f: Callable[P, T], *args: P.args, **kwargs: P.kwargs +) -> T: + return f(*args, **kwargs) + + +def optimized_cache_from_source(path: str, debug_override: bool | None = None) -> str: + return cache_from_source(path, debug_override, optimization=OPTIMIZATION) + + +class TypeguardLoader(SourceFileLoader): + @staticmethod + def source_to_code( + data: Buffer | str | ast.Module | ast.Expression | ast.Interactive, + path: Buffer | str | PathLike[str] = "", + ) -> CodeType: + if isinstance(data, (ast.Module, ast.Expression, ast.Interactive)): + tree = data + else: + if isinstance(data, str): + source = data + else: + source = decode_source(data) + + tree = _call_with_frames_removed( + ast.parse, + source, + path, + "exec", + ) + + tree = TypeguardTransformer().visit(tree) + ast.fix_missing_locations(tree) + + if global_config.debug_instrumentation and sys.version_info >= (3, 9): + print( + f"Source code of {path!r} after instrumentation:\n" + "----------------------------------------------", + file=sys.stderr, + ) + print(ast.unparse(tree), file=sys.stderr) + print("----------------------------------------------", file=sys.stderr) + + return _call_with_frames_removed( + compile, tree, path, "exec", 0, dont_inherit=True + ) + + def exec_module(self, module: ModuleType) -> None: + # Use a custom optimization marker – the import lock should make this monkey + # patch safe + with patch( + "importlib._bootstrap_external.cache_from_source", + optimized_cache_from_source, + ): + super().exec_module(module) + + +class TypeguardFinder(MetaPathFinder): + """ + Wraps another path finder and instruments the module with + :func:`@typechecked ` if :meth:`should_instrument` returns + ``True``. + + Should not be used directly, but rather via :func:`~.install_import_hook`. + + .. 
versionadded:: 2.6 + """ + + def __init__(self, packages: list[str] | None, original_pathfinder: MetaPathFinder): + self.packages = packages + self._original_pathfinder = original_pathfinder + + def find_spec( + self, + fullname: str, + path: Sequence[str] | None, + target: types.ModuleType | None = None, + ) -> ModuleSpec | None: + if self.should_instrument(fullname): + spec = self._original_pathfinder.find_spec(fullname, path, target) + if spec is not None and isinstance(spec.loader, SourceFileLoader): + spec.loader = TypeguardLoader(spec.loader.name, spec.loader.path) + return spec + + return None + + def should_instrument(self, module_name: str) -> bool: + """ + Determine whether the module with the given name should be instrumented. + + :param module_name: full name of the module that is about to be imported (e.g. + ``xyz.abc``) + + """ + if self.packages is None: + return True + + for package in self.packages: + if module_name == package or module_name.startswith(package + "."): + return True + + return False + + +class ImportHookManager: + """ + A handle that can be used to uninstall the Typeguard import hook. + """ + + def __init__(self, hook: MetaPathFinder): + self.hook = hook + + def __enter__(self) -> None: + pass + + def __exit__( + self, + exc_type: type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, + ) -> None: + self.uninstall() + + def uninstall(self) -> None: + """Uninstall the import hook.""" + try: + sys.meta_path.remove(self.hook) + except ValueError: + pass # already removed + + +def install_import_hook( + packages: Iterable[str] | None = None, + *, + cls: type[TypeguardFinder] = TypeguardFinder, +) -> ImportHookManager: + """ + Install an import hook that instruments functions for automatic type checking. + + This only affects modules loaded **after** this hook has been installed. + + :param packages: an iterable of package names to instrument, or ``None`` to + instrument all packages + :param cls: a custom meta path finder class + :return: a context manager that uninstalls the hook on exit (or when you call + ``.uninstall()``) + + .. versionadded:: 2.6 + + """ + if packages is None: + target_packages: list[str] | None = None + elif isinstance(packages, str): + target_packages = [packages] + else: + target_packages = list(packages) + + for finder in sys.meta_path: + if ( + isclass(finder) + and finder.__name__ == "PathFinder" + and hasattr(finder, "find_spec") + ): + break + else: + raise RuntimeError("Cannot find a PathFinder in sys.meta_path") + + hook = cls(target_packages, finder) + sys.meta_path.insert(0, hook) + return ImportHookManager(hook) diff --git a/lib/typeguard/_memo.py b/lib/typeguard/_memo.py new file mode 100644 index 00000000..1d0d80c6 --- /dev/null +++ b/lib/typeguard/_memo.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from typing import Any + +from typeguard._config import TypeCheckConfiguration, global_config + + +class TypeCheckMemo: + """ + Contains information necessary for type checkers to do their work. + + .. attribute:: globals + :type: dict[str, Any] + + Dictionary of global variables to use for resolving forward references. + + .. attribute:: locals + :type: dict[str, Any] + + Dictionary of local variables to use for resolving forward references. + + .. attribute:: self_type + :type: type | None + + When running type checks within an instance method or class method, this is the + class object that the first argument (usually named ``self`` or ``cls``) refers + to. + + .. 
attribute:: config + :type: TypeCheckConfiguration + + Contains the configuration for a particular set of type checking operations. + """ + + __slots__ = "globals", "locals", "self_type", "config" + + def __init__( + self, + globals: dict[str, Any], + locals: dict[str, Any], + *, + self_type: type | None = None, + config: TypeCheckConfiguration = global_config, + ): + self.globals = globals + self.locals = locals + self.self_type = self_type + self.config = config diff --git a/lib/typeguard/_pytest_plugin.py b/lib/typeguard/_pytest_plugin.py new file mode 100644 index 00000000..7bca9c26 --- /dev/null +++ b/lib/typeguard/_pytest_plugin.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import sys +import warnings +from typing import Any, Literal + +from pytest import Config, Parser + +from typeguard._config import CollectionCheckStrategy, ForwardRefPolicy, global_config +from typeguard._exceptions import InstrumentationWarning +from typeguard._importhook import install_import_hook +from typeguard._utils import qualified_name, resolve_reference + + +def pytest_addoption(parser: Parser) -> None: + def add_ini_option( + opt_type: ( + Literal["string", "paths", "pathlist", "args", "linelist", "bool"] | None + ) + ) -> None: + parser.addini( + group.options[-1].names()[0][2:], + group.options[-1].attrs()["help"], + opt_type, + ) + + group = parser.getgroup("typeguard") + group.addoption( + "--typeguard-packages", + action="store", + help="comma separated name list of packages and modules to instrument for " + "type checking, or :all: to instrument all modules loaded after typeguard", + ) + add_ini_option("linelist") + + group.addoption( + "--typeguard-debug-instrumentation", + action="store_true", + help="print all instrumented code to stderr", + ) + add_ini_option("bool") + + group.addoption( + "--typeguard-typecheck-fail-callback", + action="store", + help=( + "a module:varname (e.g. 
typeguard:warn_on_error) reference to a function " + "that is called (with the exception and memo object as arguments) to " + "handle a TypeCheckError" + ), + ) + add_ini_option("string") + + group.addoption( + "--typeguard-forward-ref-policy", + action="store", + choices=list(ForwardRefPolicy.__members__), + help=( + "determines how to deal with unresolvable forward references in type " + "annotations" + ), + ) + add_ini_option("string") + + group.addoption( + "--typeguard-collection-check-strategy", + action="store", + choices=list(CollectionCheckStrategy.__members__), + help="determines how thoroughly to check collections (list, dict, etc)", + ) + add_ini_option("string") + + +def pytest_configure(config: Config) -> None: + def getoption(name: str) -> Any: + return config.getoption(name.replace("-", "_")) or config.getini(name) + + packages: list[str] | None = [] + if packages_option := config.getoption("typeguard_packages"): + packages = [pkg.strip() for pkg in packages_option.split(",")] + elif packages_ini := config.getini("typeguard-packages"): + packages = packages_ini + + if packages: + if packages == [":all:"]: + packages = None + else: + already_imported_packages = sorted( + package for package in packages if package in sys.modules + ) + if already_imported_packages: + warnings.warn( + f"typeguard cannot check these packages because they are already " + f"imported: {', '.join(already_imported_packages)}", + InstrumentationWarning, + stacklevel=1, + ) + + install_import_hook(packages=packages) + + debug_option = getoption("typeguard-debug-instrumentation") + if debug_option: + global_config.debug_instrumentation = True + + fail_callback_option = getoption("typeguard-typecheck-fail-callback") + if fail_callback_option: + callback = resolve_reference(fail_callback_option) + if not callable(callback): + raise TypeError( + f"{fail_callback_option} ({qualified_name(callback.__class__)}) is not " + f"a callable" + ) + + global_config.typecheck_fail_callback = callback + + forward_ref_policy_option = getoption("typeguard-forward-ref-policy") + if forward_ref_policy_option: + forward_ref_policy = ForwardRefPolicy.__members__[forward_ref_policy_option] + global_config.forward_ref_policy = forward_ref_policy + + collection_check_strategy_option = getoption("typeguard-collection-check-strategy") + if collection_check_strategy_option: + collection_check_strategy = CollectionCheckStrategy.__members__[ + collection_check_strategy_option + ] + global_config.collection_check_strategy = collection_check_strategy diff --git a/lib/typeguard/_suppression.py b/lib/typeguard/_suppression.py new file mode 100644 index 00000000..f6899a9f --- /dev/null +++ b/lib/typeguard/_suppression.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import sys +from collections.abc import Callable, Generator +from contextlib import contextmanager +from functools import update_wrapper +from threading import Lock +from typing import ContextManager, TypeVar, overload + +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + +P = ParamSpec("P") +T = TypeVar("T") + +type_checks_suppressed = 0 +type_checks_suppress_lock = Lock() + + +@overload +def suppress_type_checks(func: Callable[P, T]) -> Callable[P, T]: ... + + +@overload +def suppress_type_checks() -> ContextManager[None]: ... + + +def suppress_type_checks( + func: Callable[P, T] | None = None +) -> Callable[P, T] | ContextManager[None]: + """ + Temporarily suppress all type checking.
+ + This function has two operating modes, based on how it's used: + + #. as a context manager (``with suppress_type_checks(): ...``) + #. as a decorator (``@suppress_type_checks``) + + When used as a context manager, :func:`check_type` and any automatically + instrumented functions skip the actual type checking. These context managers can be + nested. + + When used as a decorator, all type checking is suppressed while the function is + running. + + Type checking will resume once no more context managers are active and no decorated + functions are running. + + Both operating modes are thread-safe. + + """ + + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + global type_checks_suppressed + + with type_checks_suppress_lock: + type_checks_suppressed += 1 + + assert func is not None + try: + return func(*args, **kwargs) + finally: + with type_checks_suppress_lock: + type_checks_suppressed -= 1 + + def cm() -> Generator[None, None, None]: + global type_checks_suppressed + + with type_checks_suppress_lock: + type_checks_suppressed += 1 + + try: + yield + finally: + with type_checks_suppress_lock: + type_checks_suppressed -= 1 + + if func is None: + # Context manager mode + return contextmanager(cm)() + else: + # Decorator mode + update_wrapper(wrapper, func) + return wrapper diff --git a/lib/typeguard/_transformer.py b/lib/typeguard/_transformer.py new file mode 100644 index 00000000..13ac3630 --- /dev/null +++ b/lib/typeguard/_transformer.py @@ -0,0 +1,1229 @@ +from __future__ import annotations + +import ast +import builtins +import sys +import typing +from ast import ( + AST, + Add, + AnnAssign, + Assign, + AsyncFunctionDef, + Attribute, + AugAssign, + BinOp, + BitAnd, + BitOr, + BitXor, + Call, + ClassDef, + Constant, + Dict, + Div, + Expr, + Expression, + FloorDiv, + FunctionDef, + If, + Import, + ImportFrom, + Index, + List, + Load, + LShift, + MatMult, + Mod, + Module, + Mult, + Name, + NamedExpr, + NodeTransformer, + NodeVisitor, + Pass, + Pow, + Return, + RShift, + Starred, + Store, + Sub, + Subscript, + Tuple, + Yield, + YieldFrom, + alias, + copy_location, + expr, + fix_missing_locations, + keyword, + walk, +) +from collections import defaultdict +from collections.abc import Generator, Sequence +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import dataclass, field +from typing import Any, ClassVar, cast, overload + +generator_names = ( + "typing.Generator", + "collections.abc.Generator", + "typing.Iterator", + "collections.abc.Iterator", + "typing.Iterable", + "collections.abc.Iterable", + "typing.AsyncIterator", + "collections.abc.AsyncIterator", + "typing.AsyncIterable", + "collections.abc.AsyncIterable", + "typing.AsyncGenerator", + "collections.abc.AsyncGenerator", +) +anytype_names = ( + "typing.Any", + "typing_extensions.Any", +) +literal_names = ( + "typing.Literal", + "typing_extensions.Literal", +) +annotated_names = ( + "typing.Annotated", + "typing_extensions.Annotated", +) +ignore_decorators = ( + "typing.no_type_check", + "typeguard.typeguard_ignore", +) +aug_assign_functions = { + Add: "iadd", + Sub: "isub", + Mult: "imul", + MatMult: "imatmul", + Div: "itruediv", + FloorDiv: "ifloordiv", + Mod: "imod", + Pow: "ipow", + LShift: "ilshift", + RShift: "irshift", + BitAnd: "iand", + BitXor: "ixor", + BitOr: "ior", +} + + +@dataclass +class TransformMemo: + node: Module | ClassDef | FunctionDef | AsyncFunctionDef | None + parent: TransformMemo | None + path: tuple[str, ...] 
+ joined_path: Constant = field(init=False) + return_annotation: expr | None = None + yield_annotation: expr | None = None + send_annotation: expr | None = None + is_async: bool = False + local_names: set[str] = field(init=False, default_factory=set) + imported_names: dict[str, str] = field(init=False, default_factory=dict) + ignored_names: set[str] = field(init=False, default_factory=set) + load_names: defaultdict[str, dict[str, Name]] = field( + init=False, default_factory=lambda: defaultdict(dict) + ) + has_yield_expressions: bool = field(init=False, default=False) + has_return_expressions: bool = field(init=False, default=False) + memo_var_name: Name | None = field(init=False, default=None) + should_instrument: bool = field(init=False, default=True) + variable_annotations: dict[str, expr] = field(init=False, default_factory=dict) + configuration_overrides: dict[str, Any] = field(init=False, default_factory=dict) + code_inject_index: int = field(init=False, default=0) + + def __post_init__(self) -> None: + elements: list[str] = [] + memo = self + while isinstance(memo.node, (ClassDef, FunctionDef, AsyncFunctionDef)): + elements.insert(0, memo.node.name) + if not memo.parent: + break + + memo = memo.parent + if isinstance(memo.node, (FunctionDef, AsyncFunctionDef)): + elements.insert(0, "") + + self.joined_path = Constant(".".join(elements)) + + # Figure out where to insert instrumentation code + if self.node: + for index, child in enumerate(self.node.body): + if isinstance(child, ImportFrom) and child.module == "__future__": + # (module only) __future__ imports must come first + continue + elif ( + isinstance(child, Expr) + and isinstance(child.value, Constant) + and isinstance(child.value.value, str) + ): + continue # docstring + + self.code_inject_index = index + break + + def get_unused_name(self, name: str) -> str: + memo: TransformMemo | None = self + while memo is not None: + if name in memo.local_names: + memo = self + name += "_" + else: + memo = memo.parent + + self.local_names.add(name) + return name + + def is_ignored_name(self, expression: expr | Expr | None) -> bool: + top_expression = ( + expression.value if isinstance(expression, Expr) else expression + ) + + if isinstance(top_expression, Attribute) and isinstance( + top_expression.value, Name + ): + name = top_expression.value.id + elif isinstance(top_expression, Name): + name = top_expression.id + else: + return False + + memo: TransformMemo | None = self + while memo is not None: + if name in memo.ignored_names: + return True + + memo = memo.parent + + return False + + def get_memo_name(self) -> Name: + if not self.memo_var_name: + self.memo_var_name = Name(id="memo", ctx=Load()) + + return self.memo_var_name + + def get_import(self, module: str, name: str) -> Name: + if module in self.load_names and name in self.load_names[module]: + return self.load_names[module][name] + + qualified_name = f"{module}.{name}" + if name in self.imported_names and self.imported_names[name] == qualified_name: + return Name(id=name, ctx=Load()) + + alias = self.get_unused_name(name) + node = self.load_names[module][name] = Name(id=alias, ctx=Load()) + self.imported_names[name] = qualified_name + return node + + def insert_imports(self, node: Module | FunctionDef | AsyncFunctionDef) -> None: + """Insert imports needed by injected code.""" + if not self.load_names: + return + + # Insert imports after any "from __future__ ..." 
imports and any docstring + for modulename, names in self.load_names.items(): + aliases = [ + alias(orig_name, new_name.id if orig_name != new_name.id else None) + for orig_name, new_name in sorted(names.items()) + ] + node.body.insert(self.code_inject_index, ImportFrom(modulename, aliases, 0)) + + def name_matches(self, expression: expr | Expr | None, *names: str) -> bool: + if expression is None: + return False + + path: list[str] = [] + top_expression = ( + expression.value if isinstance(expression, Expr) else expression + ) + + if isinstance(top_expression, Subscript): + top_expression = top_expression.value + elif isinstance(top_expression, Call): + top_expression = top_expression.func + + while isinstance(top_expression, Attribute): + path.insert(0, top_expression.attr) + top_expression = top_expression.value + + if not isinstance(top_expression, Name): + return False + + if top_expression.id in self.imported_names: + translated = self.imported_names[top_expression.id] + elif hasattr(builtins, top_expression.id): + translated = "builtins." + top_expression.id + else: + translated = top_expression.id + + path.insert(0, translated) + joined_path = ".".join(path) + if joined_path in names: + return True + elif self.parent: + return self.parent.name_matches(expression, *names) + else: + return False + + def get_config_keywords(self) -> list[keyword]: + if self.parent and isinstance(self.parent.node, ClassDef): + overrides = self.parent.configuration_overrides.copy() + else: + overrides = {} + + overrides.update(self.configuration_overrides) + return [keyword(key, value) for key, value in overrides.items()] + + +class NameCollector(NodeVisitor): + def __init__(self) -> None: + self.names: set[str] = set() + + def visit_Import(self, node: Import) -> None: + for name in node.names: + self.names.add(name.asname or name.name) + + def visit_ImportFrom(self, node: ImportFrom) -> None: + for name in node.names: + self.names.add(name.asname or name.name) + + def visit_Assign(self, node: Assign) -> None: + for target in node.targets: + if isinstance(target, Name): + self.names.add(target.id) + + def visit_NamedExpr(self, node: NamedExpr) -> Any: + if isinstance(node.target, Name): + self.names.add(node.target.id) + + def visit_FunctionDef(self, node: FunctionDef) -> None: + pass + + def visit_ClassDef(self, node: ClassDef) -> None: + pass + + +class GeneratorDetector(NodeVisitor): + """Detects if a function node is a generator function.""" + + contains_yields: bool = False + in_root_function: bool = False + + def visit_Yield(self, node: Yield) -> Any: + self.contains_yields = True + + def visit_YieldFrom(self, node: YieldFrom) -> Any: + self.contains_yields = True + + def visit_ClassDef(self, node: ClassDef) -> Any: + pass + + def visit_FunctionDef(self, node: FunctionDef | AsyncFunctionDef) -> Any: + if not self.in_root_function: + self.in_root_function = True + self.generic_visit(node) + self.in_root_function = False + + def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> Any: + self.visit_FunctionDef(node) + + +class AnnotationTransformer(NodeTransformer): + type_substitutions: ClassVar[dict[str, tuple[str, str]]] = { + "builtins.dict": ("typing", "Dict"), + "builtins.list": ("typing", "List"), + "builtins.tuple": ("typing", "Tuple"), + "builtins.set": ("typing", "Set"), + "builtins.frozenset": ("typing", "FrozenSet"), + } + + def __init__(self, transformer: TypeguardTransformer): + self.transformer = transformer + self._memo = transformer._memo + self._level = 0 + + def visit(self, 
node: AST) -> Any: + # Don't process Literals + if isinstance(node, expr) and self._memo.name_matches(node, *literal_names): + return node + + self._level += 1 + new_node = super().visit(node) + self._level -= 1 + + if isinstance(new_node, Expression) and not hasattr(new_node, "body"): + return None + + # Return None if this new node matches a variation of typing.Any + if ( + self._level == 0 + and isinstance(new_node, expr) + and self._memo.name_matches(new_node, *anytype_names) + ): + return None + + return new_node + + def visit_BinOp(self, node: BinOp) -> Any: + self.generic_visit(node) + + if isinstance(node.op, BitOr): + # If either branch of the BinOp has been transformed to `None`, it means + # that a type in the union was ignored, so the entire annotation should be + # ignored + if not hasattr(node, "left") or not hasattr(node, "right"): + return None + + # Return Any if either side is Any + if self._memo.name_matches(node.left, *anytype_names): + return node.left + elif self._memo.name_matches(node.right, *anytype_names): + return node.right + + if sys.version_info < (3, 10): + union_name = self.transformer._get_import("typing", "Union") + return Subscript( + value=union_name, + slice=Index( + Tuple(elts=[node.left, node.right], ctx=Load()), ctx=Load() + ), + ctx=Load(), + ) + + return node + + def visit_Attribute(self, node: Attribute) -> Any: + if self._memo.is_ignored_name(node): + return None + + return node + + def visit_Subscript(self, node: Subscript) -> Any: + if self._memo.is_ignored_name(node.value): + return None + + # The subscript of typing(_extensions).Literal can be any arbitrary string, so + # don't try to evaluate it as code + if node.slice: + if isinstance(node.slice, Index): + # Python 3.8 + slice_value = node.slice.value # type: ignore[attr-defined] + else: + slice_value = node.slice + + if isinstance(slice_value, Tuple): + if self._memo.name_matches(node.value, *annotated_names): + # Only treat the first argument to typing.Annotated as a potential + # forward reference + items = cast( + typing.List[expr], + [self.visit(slice_value.elts[0])] + slice_value.elts[1:], + ) + else: + items = cast( + typing.List[expr], + [self.visit(item) for item in slice_value.elts], + ) + + # If this is a Union and any of the items is Any, erase the entire + # annotation + if self._memo.name_matches(node.value, "typing.Union") and any( + item is None + or ( + isinstance(item, expr) + and self._memo.name_matches(item, *anytype_names) + ) + for item in items + ): + return None + + # If all items in the subscript were Any, erase the subscript entirely + if all(item is None for item in items): + return node.value + + for index, item in enumerate(items): + if item is None: + items[index] = self.transformer._get_import("typing", "Any") + + slice_value.elts = items + else: + self.generic_visit(node) + + # If the transformer erased the slice entirely, just return the node + # value without the subscript (unless it's Optional, in which case erase + # the node entirely) + if self._memo.name_matches( + node.value, "typing.Optional" + ) and not hasattr(node, "slice"): + return None + if sys.version_info >= (3, 9) and not hasattr(node, "slice"): + return node.value + elif sys.version_info < (3, 9) and not hasattr(node.slice, "value"): + return node.value + + return node + + def visit_Name(self, node: Name) -> Any: + if self._memo.is_ignored_name(node): + return None + + if sys.version_info < (3, 9): + for typename, substitute in self.type_substitutions.items(): + if
self._memo.name_matches(node, typename): + new_node = self.transformer._get_import(*substitute) + return copy_location(new_node, node) + + return node + + def visit_Call(self, node: Call) -> Any: + # Don't recurse into calls + return node + + def visit_Constant(self, node: Constant) -> Any: + if isinstance(node.value, str): + expression = ast.parse(node.value, mode="eval") + new_node = self.visit(expression) + if new_node: + return copy_location(new_node.body, node) + else: + return None + + return node + + +class TypeguardTransformer(NodeTransformer): + def __init__( + self, target_path: Sequence[str] | None = None, target_lineno: int | None = None + ) -> None: + self._target_path = tuple(target_path) if target_path else None + self._memo = self._module_memo = TransformMemo(None, None, ()) + self.names_used_in_annotations: set[str] = set() + self.target_node: FunctionDef | AsyncFunctionDef | None = None + self.target_lineno = target_lineno + + def generic_visit(self, node: AST) -> AST: + has_non_empty_body_initially = bool(getattr(node, "body", None)) + initial_type = type(node) + + node = super().generic_visit(node) + + if ( + type(node) is initial_type + and has_non_empty_body_initially + and hasattr(node, "body") + and not node.body + ): + # If we still have the same node type after transformation + # but we've optimised its body away, we add a `pass` statement. + node.body = [Pass()] + + return node + + @contextmanager + def _use_memo( + self, node: ClassDef | FunctionDef | AsyncFunctionDef + ) -> Generator[None, Any, None]: + new_memo = TransformMemo(node, self._memo, self._memo.path + (node.name,)) + old_memo = self._memo + self._memo = new_memo + + if isinstance(node, (FunctionDef, AsyncFunctionDef)): + new_memo.should_instrument = ( + self._target_path is None or new_memo.path == self._target_path + ) + if new_memo.should_instrument: + # Check if the function is a generator function + detector = GeneratorDetector() + detector.visit(node) + + # Extract yield, send and return types where possible from a subscripted + # annotation like Generator[int, str, bool] + return_annotation = deepcopy(node.returns) + if detector.contains_yields and new_memo.name_matches( + return_annotation, *generator_names + ): + if isinstance(return_annotation, Subscript): + annotation_slice = return_annotation.slice + + # Python < 3.9 + if isinstance(annotation_slice, Index): + annotation_slice = ( + annotation_slice.value # type: ignore[attr-defined] + ) + + if isinstance(annotation_slice, Tuple): + items = annotation_slice.elts + else: + items = [annotation_slice] + + if len(items) > 0: + new_memo.yield_annotation = self._convert_annotation( + items[0] + ) + + if len(items) > 1: + new_memo.send_annotation = self._convert_annotation( + items[1] + ) + + if len(items) > 2: + new_memo.return_annotation = self._convert_annotation( + items[2] + ) + else: + new_memo.return_annotation = self._convert_annotation( + return_annotation + ) + + if isinstance(node, AsyncFunctionDef): + new_memo.is_async = True + + yield + self._memo = old_memo + + def _get_import(self, module: str, name: str) -> Name: + memo = self._memo if self._target_path else self._module_memo + return memo.get_import(module, name) + + @overload + def _convert_annotation(self, annotation: None) -> None: ... + + @overload + def _convert_annotation(self, annotation: expr) -> expr: ...
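+ + # Effect sketch (hypothetical inputs): on Python < 3.10 an annotation like + # ``int | str`` is rewritten to ``Union[int, str]``, and a string annotation + # such as ``"MyClass"`` is parsed back into the expression ``MyClass`` by the + # AnnotationTransformer above.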
+ + def _convert_annotation(self, annotation: expr | None) -> expr | None: + if annotation is None: + return None + + # Convert PEP 604 unions (x | y) and generic built-in collections where + # necessary, and undo forward references + new_annotation = cast(expr, AnnotationTransformer(self).visit(annotation)) + if isinstance(new_annotation, expr): + new_annotation = ast.copy_location(new_annotation, annotation) + + # Store names used in the annotation + names = {node.id for node in walk(new_annotation) if isinstance(node, Name)} + self.names_used_in_annotations.update(names) + + return new_annotation + + def visit_Name(self, node: Name) -> Name: + self._memo.local_names.add(node.id) + return node + + def visit_Module(self, node: Module) -> Module: + self._module_memo = self._memo = TransformMemo(node, None, ()) + self.generic_visit(node) + self._module_memo.insert_imports(node) + + fix_missing_locations(node) + return node + + def visit_Import(self, node: Import) -> Import: + for name in node.names: + self._memo.local_names.add(name.asname or name.name) + self._memo.imported_names[name.asname or name.name] = name.name + + return node + + def visit_ImportFrom(self, node: ImportFrom) -> ImportFrom: + for name in node.names: + if name.name != "*": + alias = name.asname or name.name + self._memo.local_names.add(alias) + self._memo.imported_names[alias] = f"{node.module}.{name.name}" + + return node + + def visit_ClassDef(self, node: ClassDef) -> ClassDef | None: + self._memo.local_names.add(node.name) + + # Eliminate top level classes not belonging to the target path + if ( + self._target_path is not None + and not self._memo.path + and node.name != self._target_path[0] + ): + return None + + with self._use_memo(node): + for decorator in node.decorator_list.copy(): + if self._memo.name_matches(decorator, "typeguard.typechecked"): + # Remove the decorator to prevent duplicate instrumentation + node.decorator_list.remove(decorator) + + # Store any configuration overrides + if isinstance(decorator, Call) and decorator.keywords: + self._memo.configuration_overrides.update( + {kw.arg: kw.value for kw in decorator.keywords if kw.arg} + ) + + self.generic_visit(node) + return node + + def visit_FunctionDef( + self, node: FunctionDef | AsyncFunctionDef + ) -> FunctionDef | AsyncFunctionDef | None: + """ + Injects type checks for function arguments, and for a return of None if the + function is annotated to return something other than Any or None, and the body + ends without an explicit "return".
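+ + Sketch of the effect on a hypothetical function: ``def f() -> int:`` whose + body can fall off the end gains a trailing + ``return check_return_type('f', None, int, memo)``, which then reports the + missing return value.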
+ + """ + self._memo.local_names.add(node.name) + + # Eliminate top level functions not belonging to the target path + if ( + self._target_path is not None + and not self._memo.path + and node.name != self._target_path[0] + ): + return None + + # Skip instrumentation if we're instrumenting the whole module and the function + # contains either @no_type_check or @typeguard_ignore + if self._target_path is None: + for decorator in node.decorator_list: + if self._memo.name_matches(decorator, *ignore_decorators): + return node + + with self._use_memo(node): + arg_annotations: dict[str, Any] = {} + if self._target_path is None or self._memo.path == self._target_path: + # Find line number we're supposed to match against + if node.decorator_list: + first_lineno = node.decorator_list[0].lineno + else: + first_lineno = node.lineno + + for decorator in node.decorator_list.copy(): + if self._memo.name_matches(decorator, "typing.overload"): + # Remove overloads entirely + return None + elif self._memo.name_matches(decorator, "typeguard.typechecked"): + # Remove the decorator to prevent duplicate instrumentation + node.decorator_list.remove(decorator) + + # Store any configuration overrides + if isinstance(decorator, Call) and decorator.keywords: + self._memo.configuration_overrides = { + kw.arg: kw.value for kw in decorator.keywords if kw.arg + } + + if self.target_lineno == first_lineno: + assert self.target_node is None + self.target_node = node + if node.decorator_list: + self.target_lineno = node.decorator_list[0].lineno + else: + self.target_lineno = node.lineno + + all_args = node.args.args + node.args.kwonlyargs + node.args.posonlyargs + + # Ensure that any type shadowed by the positional or keyword-only + # argument names are ignored in this function + for arg in all_args: + self._memo.ignored_names.add(arg.arg) + + # Ensure that any type shadowed by the variable positional argument name + # (e.g. "args" in *args) is ignored this function + if node.args.vararg: + self._memo.ignored_names.add(node.args.vararg.arg) + + # Ensure that any type shadowed by the variable keywrod argument name + # (e.g. 
"kwargs" in *kwargs) is ignored this function + if node.args.kwarg: + self._memo.ignored_names.add(node.args.kwarg.arg) + + for arg in all_args: + annotation = self._convert_annotation(deepcopy(arg.annotation)) + if annotation: + arg_annotations[arg.arg] = annotation + + if node.args.vararg: + annotation_ = self._convert_annotation(node.args.vararg.annotation) + if annotation_: + if sys.version_info >= (3, 9): + container = Name("tuple", ctx=Load()) + else: + container = self._get_import("typing", "Tuple") + + subscript_slice: Tuple | Index = Tuple( + [ + annotation_, + Constant(Ellipsis), + ], + ctx=Load(), + ) + if sys.version_info < (3, 9): + subscript_slice = Index(subscript_slice, ctx=Load()) + + arg_annotations[node.args.vararg.arg] = Subscript( + container, subscript_slice, ctx=Load() + ) + + if node.args.kwarg: + annotation_ = self._convert_annotation(node.args.kwarg.annotation) + if annotation_: + if sys.version_info >= (3, 9): + container = Name("dict", ctx=Load()) + else: + container = self._get_import("typing", "Dict") + + subscript_slice = Tuple( + [ + Name("str", ctx=Load()), + annotation_, + ], + ctx=Load(), + ) + if sys.version_info < (3, 9): + subscript_slice = Index(subscript_slice, ctx=Load()) + + arg_annotations[node.args.kwarg.arg] = Subscript( + container, subscript_slice, ctx=Load() + ) + + if arg_annotations: + self._memo.variable_annotations.update(arg_annotations) + + self.generic_visit(node) + + if arg_annotations: + annotations_dict = Dict( + keys=[Constant(key) for key in arg_annotations.keys()], + values=[ + Tuple([Name(key, ctx=Load()), annotation], ctx=Load()) + for key, annotation in arg_annotations.items() + ], + ) + func_name = self._get_import( + "typeguard._functions", "check_argument_types" + ) + args = [ + self._memo.joined_path, + annotations_dict, + self._memo.get_memo_name(), + ] + node.body.insert( + self._memo.code_inject_index, Expr(Call(func_name, args, [])) + ) + + # Add a checked "return None" to the end if there's no explicit return + # Skip if the return annotation is None or Any + if ( + self._memo.return_annotation + and (not self._memo.is_async or not self._memo.has_yield_expressions) + and not isinstance(node.body[-1], Return) + and ( + not isinstance(self._memo.return_annotation, Constant) + or self._memo.return_annotation.value is not None + ) + ): + func_name = self._get_import( + "typeguard._functions", "check_return_type" + ) + return_node = Return( + Call( + func_name, + [ + self._memo.joined_path, + Constant(None), + self._memo.return_annotation, + self._memo.get_memo_name(), + ], + [], + ) + ) + + # Replace a placeholder "pass" at the end + if isinstance(node.body[-1], Pass): + copy_location(return_node, node.body[-1]) + del node.body[-1] + + node.body.append(return_node) + + # Insert code to create the call memo, if it was ever needed for this + # function + if self._memo.memo_var_name: + memo_kwargs: dict[str, Any] = {} + if self._memo.parent and isinstance(self._memo.parent.node, ClassDef): + for decorator in node.decorator_list: + if ( + isinstance(decorator, Name) + and decorator.id == "staticmethod" + ): + break + elif ( + isinstance(decorator, Name) + and decorator.id == "classmethod" + ): + memo_kwargs["self_type"] = Name( + id=node.args.args[0].arg, ctx=Load() + ) + break + else: + if node.args.args: + if node.name == "__new__": + memo_kwargs["self_type"] = Name( + id=node.args.args[0].arg, ctx=Load() + ) + else: + memo_kwargs["self_type"] = Attribute( + Name(id=node.args.args[0].arg, ctx=Load()), + "__class__", + 
ctx=Load(), + ) + + # Construct the function reference + # Nested functions get special treatment: the function name is added + # to free variables (and the closure of the resulting function) + names: list[str] = [node.name] + memo = self._memo.parent + while memo: + if isinstance(memo.node, (FunctionDef, AsyncFunctionDef)): + # This is a nested function. Use the function name as-is. + del names[:-1] + break + elif not isinstance(memo.node, ClassDef): + break + + names.insert(0, memo.node.name) + memo = memo.parent + + config_keywords = self._memo.get_config_keywords() + if config_keywords: + memo_kwargs["config"] = Call( + self._get_import("dataclasses", "replace"), + [self._get_import("typeguard._config", "global_config")], + config_keywords, + ) + + self._memo.memo_var_name.id = self._memo.get_unused_name("memo") + memo_store_name = Name(id=self._memo.memo_var_name.id, ctx=Store()) + globals_call = Call(Name(id="globals", ctx=Load()), [], []) + locals_call = Call(Name(id="locals", ctx=Load()), [], []) + memo_expr = Call( + self._get_import("typeguard", "TypeCheckMemo"), + [globals_call, locals_call], + [keyword(key, value) for key, value in memo_kwargs.items()], + ) + node.body.insert( + self._memo.code_inject_index, + Assign([memo_store_name], memo_expr), + ) + + self._memo.insert_imports(node) + + # Special case the __new__() method to create a local alias from the + # class name to the first argument (usually "cls") + if ( + isinstance(node, FunctionDef) + and node.args + and self._memo.parent is not None + and isinstance(self._memo.parent.node, ClassDef) + and node.name == "__new__" + ): + first_args_expr = Name(node.args.args[0].arg, ctx=Load()) + cls_name = Name(self._memo.parent.node.name, ctx=Store()) + node.body.insert( + self._memo.code_inject_index, + Assign([cls_name], first_args_expr), + ) + + # Remove any placeholder "pass" at the end + if isinstance(node.body[-1], Pass): + del node.body[-1] + + return node + + def visit_AsyncFunctionDef( + self, node: AsyncFunctionDef + ) -> FunctionDef | AsyncFunctionDef | None: + return self.visit_FunctionDef(node) + + def visit_Return(self, node: Return) -> Return: + """This injects type checks into "return" statements.""" + self.generic_visit(node) + if ( + self._memo.return_annotation + and self._memo.should_instrument + and not self._memo.is_ignored_name(self._memo.return_annotation) + ): + func_name = self._get_import("typeguard._functions", "check_return_type") + old_node = node + retval = old_node.value or Constant(None) + node = Return( + Call( + func_name, + [ + self._memo.joined_path, + retval, + self._memo.return_annotation, + self._memo.get_memo_name(), + ], + [], + ) + ) + copy_location(node, old_node) + + return node + + def visit_Yield(self, node: Yield) -> Yield | Call: + """ + This injects type checks into "yield" expressions, checking both the yielded + value and the value sent back to the generator, when appropriate.
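+ + Sketch for a hypothetical ``Generator[int, str, None]`` function: + ``received = yield produced`` becomes, roughly, + ``received = check_send_type(..., (yield check_yield_type(..., produced, ..., memo)), ..., memo)``.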
+ + """ + self._memo.has_yield_expressions = True + self.generic_visit(node) + + if ( + self._memo.yield_annotation + and self._memo.should_instrument + and not self._memo.is_ignored_name(self._memo.yield_annotation) + ): + func_name = self._get_import("typeguard._functions", "check_yield_type") + yieldval = node.value or Constant(None) + node.value = Call( + func_name, + [ + self._memo.joined_path, + yieldval, + self._memo.yield_annotation, + self._memo.get_memo_name(), + ], + [], + ) + + if ( + self._memo.send_annotation + and self._memo.should_instrument + and not self._memo.is_ignored_name(self._memo.send_annotation) + ): + func_name = self._get_import("typeguard._functions", "check_send_type") + old_node = node + call_node = Call( + func_name, + [ + self._memo.joined_path, + old_node, + self._memo.send_annotation, + self._memo.get_memo_name(), + ], + [], + ) + copy_location(call_node, old_node) + return call_node + + return node + + def visit_AnnAssign(self, node: AnnAssign) -> Any: + """ + This injects a type check into a local variable annotation-assignment within a + function body. + + """ + self.generic_visit(node) + + if ( + isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)) + and node.annotation + and isinstance(node.target, Name) + ): + self._memo.ignored_names.add(node.target.id) + annotation = self._convert_annotation(deepcopy(node.annotation)) + if annotation: + self._memo.variable_annotations[node.target.id] = annotation + if node.value: + func_name = self._get_import( + "typeguard._functions", "check_variable_assignment" + ) + node.value = Call( + func_name, + [ + node.value, + Constant(node.target.id), + annotation, + self._memo.get_memo_name(), + ], + [], + ) + + return node + + def visit_Assign(self, node: Assign) -> Any: + """ + This injects a type check into a local variable assignment within a function + body. The variable must have been annotated earlier in the function body. 
+ + """ + self.generic_visit(node) + + # Only instrument function-local assignments + if isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)): + targets: list[dict[Constant, expr | None]] = [] + check_required = False + for target in node.targets: + elts: Sequence[expr] + if isinstance(target, Name): + elts = [target] + elif isinstance(target, Tuple): + elts = target.elts + else: + continue + + annotations_: dict[Constant, expr | None] = {} + for exp in elts: + prefix = "" + if isinstance(exp, Starred): + exp = exp.value + prefix = "*" + + if isinstance(exp, Name): + self._memo.ignored_names.add(exp.id) + name = prefix + exp.id + annotation = self._memo.variable_annotations.get(exp.id) + if annotation: + annotations_[Constant(name)] = annotation + check_required = True + else: + annotations_[Constant(name)] = None + + targets.append(annotations_) + + if check_required: + # Replace missing annotations with typing.Any + for item in targets: + for key, expression in item.items(): + if expression is None: + item[key] = self._get_import("typing", "Any") + + if len(targets) == 1 and len(targets[0]) == 1: + func_name = self._get_import( + "typeguard._functions", "check_variable_assignment" + ) + target_varname = next(iter(targets[0])) + node.value = Call( + func_name, + [ + node.value, + target_varname, + targets[0][target_varname], + self._memo.get_memo_name(), + ], + [], + ) + elif targets: + func_name = self._get_import( + "typeguard._functions", "check_multi_variable_assignment" + ) + targets_arg = List( + [ + Dict(keys=list(target), values=list(target.values())) + for target in targets + ], + ctx=Load(), + ) + node.value = Call( + func_name, + [node.value, targets_arg, self._memo.get_memo_name()], + [], + ) + + return node + + def visit_NamedExpr(self, node: NamedExpr) -> Any: + """This injects a type check into an assignment expression (a := foo()).""" + self.generic_visit(node) + + # Only instrument function-local assignments + if isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)) and isinstance( + node.target, Name + ): + self._memo.ignored_names.add(node.target.id) + + # Bail out if no matching annotation is found + annotation = self._memo.variable_annotations.get(node.target.id) + if annotation is None: + return node + + func_name = self._get_import( + "typeguard._functions", "check_variable_assignment" + ) + node.value = Call( + func_name, + [ + node.value, + Constant(node.target.id), + annotation, + self._memo.get_memo_name(), + ], + [], + ) + + return node + + def visit_AugAssign(self, node: AugAssign) -> Any: + """ + This injects a type check into an augmented assignment expression (a += 1). + + """ + self.generic_visit(node) + + # Only instrument function-local assignments + if isinstance(self._memo.node, (FunctionDef, AsyncFunctionDef)) and isinstance( + node.target, Name + ): + # Bail out if no matching annotation is found + annotation = self._memo.variable_annotations.get(node.target.id) + if annotation is None: + return node + + # Bail out if the operator is not found (newer Python version?) 
+ try: + operator_func_name = aug_assign_functions[node.op.__class__] + except KeyError: + return node + + operator_func = self._get_import("operator", operator_func_name) + operator_call = Call( + operator_func, [Name(node.target.id, ctx=Load()), node.value], [] + ) + check_call = Call( + self._get_import("typeguard._functions", "check_variable_assignment"), + [ + operator_call, + Constant(node.target.id), + annotation, + self._memo.get_memo_name(), + ], + [], + ) + return Assign(targets=[node.target], value=check_call) + + return node + + def visit_If(self, node: If) -> Any: + """ + This blocks names from being collected from a module-level + "if typing.TYPE_CHECKING:" block, so that they won't be type checked. + + """ + self.generic_visit(node) + + if ( + self._memo is self._module_memo + and isinstance(node.test, Name) + and self._memo.name_matches(node.test, "typing.TYPE_CHECKING") + ): + collector = NameCollector() + collector.visit(node) + self._memo.ignored_names.update(collector.names) + + return node diff --git a/lib/typeguard/_union_transformer.py b/lib/typeguard/_union_transformer.py new file mode 100644 index 00000000..19617e6a --- /dev/null +++ b/lib/typeguard/_union_transformer.py @@ -0,0 +1,55 @@ +""" +Transforms lazily evaluated PEP 604 unions into typing.Unions, for compatibility with +Python versions older than 3.10. +""" + +from __future__ import annotations + +from ast import ( + BinOp, + BitOr, + Index, + Load, + Name, + NodeTransformer, + Subscript, + fix_missing_locations, + parse, +) +from ast import Tuple as ASTTuple +from types import CodeType +from typing import Any, Dict, FrozenSet, List, Set, Tuple, Union + +type_substitutions = { + "dict": Dict, + "list": List, + "tuple": Tuple, + "set": Set, + "frozenset": FrozenSet, + "Union": Union, +} + + +class UnionTransformer(NodeTransformer): + def __init__(self, union_name: Name | None = None): + self.union_name = union_name or Name(id="Union", ctx=Load()) + + def visit_BinOp(self, node: BinOp) -> Any: + self.generic_visit(node) + if isinstance(node.op, BitOr): + return Subscript( + value=self.union_name, + slice=Index( + ASTTuple(elts=[node.left, node.right], ctx=Load()), ctx=Load() + ), + ctx=Load(), + ) + + return node + + +def compile_type_hint(hint: str) -> CodeType: + parsed = parse(hint, "<string>", "eval") + UnionTransformer().visit(parsed) + fix_missing_locations(parsed) + return compile(parsed, "<string>", "eval", flags=0) diff --git a/lib/typeguard/_utils.py b/lib/typeguard/_utils.py new file mode 100644 index 00000000..96818fd2 --- /dev/null +++ b/lib/typeguard/_utils.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +import inspect +import sys +from importlib import import_module +from inspect import currentframe +from types import CodeType, FrameType, FunctionType +from typing import TYPE_CHECKING, Any, Callable, ForwardRef, Union, cast, final +from weakref import WeakValueDictionary + +if TYPE_CHECKING: + from ._memo import TypeCheckMemo + +if sys.version_info >= (3, 10): + from typing import get_args, get_origin + + def evaluate_forwardref(forwardref: ForwardRef, memo: TypeCheckMemo) -> Any: + return forwardref._evaluate(memo.globals, memo.locals, frozenset()) + +else: + from typing_extensions import get_args, get_origin + + evaluate_extra_args: tuple[frozenset[Any], ...]
= ( + (frozenset(),) if sys.version_info >= (3, 9) else () + ) + + def evaluate_forwardref(forwardref: ForwardRef, memo: TypeCheckMemo) -> Any: + from ._union_transformer import compile_type_hint, type_substitutions + + if not forwardref.__forward_evaluated__: + forwardref.__forward_code__ = compile_type_hint(forwardref.__forward_arg__) + + try: + return forwardref._evaluate(memo.globals, memo.locals, *evaluate_extra_args) + except NameError: + if sys.version_info < (3, 10): + # Try again, with the type substitutions (list -> List etc.) in place + new_globals = memo.globals.copy() + new_globals.setdefault("Union", Union) + if sys.version_info < (3, 9): + new_globals.update(type_substitutions) + + return forwardref._evaluate( + new_globals, memo.locals or new_globals, *evaluate_extra_args + ) + + raise + + +_functions_map: WeakValueDictionary[CodeType, FunctionType] = WeakValueDictionary() + + +def get_type_name(type_: Any) -> str: + name: str + for attrname in "__name__", "_name", "__forward_arg__": + candidate = getattr(type_, attrname, None) + if isinstance(candidate, str): + name = candidate + break + else: + origin = get_origin(type_) + candidate = getattr(origin, "_name", None) + if candidate is None: + candidate = type_.__class__.__name__.strip("_") + + if isinstance(candidate, str): + name = candidate + else: + return "(unknown)" + + args = get_args(type_) + if args: + if name == "Literal": + formatted_args = ", ".join(repr(arg) for arg in args) + else: + formatted_args = ", ".join(get_type_name(arg) for arg in args) + + name += f"[{formatted_args}]" + + module = getattr(type_, "__module__", None) + if module and module not in (None, "typing", "typing_extensions", "builtins"): + name = module + "." + name + + return name + + +def qualified_name(obj: Any, *, add_class_prefix: bool = False) -> str: + """ + Return the qualified name (e.g. package.module.Type) for the given object. + + Builtins and types from the :mod:`typing` package get special treatment by having + the module name stripped from the generated name. + + """ + if obj is None: + return "None" + elif inspect.isclass(obj): + prefix = "class " if add_class_prefix else "" + type_ = obj + else: + prefix = "" + type_ = type(obj) + + module = type_.__module__ + qualname = type_.__qualname__ + name = qualname if module in ("typing", "builtins") else f"{module}.{qualname}" + return prefix + name + + +def function_name(func: Callable[..., Any]) -> str: + """ + Return the qualified name of the given function. + + Builtins and types from the :mod:`typing` package get special treatment by having + the module name stripped from the generated name. 
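+ + For example (illustrative): ``function_name(len)`` returns ``'len'``, while a + function ``f`` defined in a module ``mypkg.mod`` yields ``'mypkg.mod.f'``.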
+ + """ + # For partial functions and objects with __call__ defined, __qualname__ does not + # exist + module = getattr(func, "__module__", "") + qualname = (module + ".") if module not in ("builtins", "") else "" + return qualname + getattr(func, "__qualname__", repr(func)) + + +def resolve_reference(reference: str) -> Any: + modulename, varname = reference.partition(":")[::2] + if not modulename or not varname: + raise ValueError(f"{reference!r} is not a module:varname reference") + + obj = import_module(modulename) + for attr in varname.split("."): + obj = getattr(obj, attr) + + return obj + + +def is_method_of(obj: object, cls: type) -> bool: + return ( + inspect.isfunction(obj) + and obj.__module__ == cls.__module__ + and obj.__qualname__.startswith(cls.__qualname__ + ".") + ) + + +def get_stacklevel() -> int: + level = 1 + frame = cast(FrameType, currentframe()).f_back + while frame and frame.f_globals.get("__name__", "").startswith("typeguard."): + level += 1 + frame = frame.f_back + + return level + + +@final +class Unset: + __slots__ = () + + def __repr__(self) -> str: + return "" + + +unset = Unset() diff --git a/lib/typeguard/py.typed b/lib/typeguard/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/lib/typing_extensions.py b/lib/typing_extensions.py index f3132ea4..9ccd519c 100644 --- a/lib/typing_extensions.py +++ b/lib/typing_extensions.py @@ -147,27 +147,6 @@ class _Sentinel: _marker = _Sentinel() -def _check_generic(cls, parameters, elen=_marker): - """Check correct count for parameters of a generic cls (internal helper). - This gives a nice error message in case of count mismatch. - """ - if not elen: - raise TypeError(f"{cls} is not a generic class") - if elen is _marker: - if not hasattr(cls, "__parameters__") or not cls.__parameters__: - raise TypeError(f"{cls} is not a generic class") - elen = len(cls.__parameters__) - alen = len(parameters) - if alen != elen: - if hasattr(cls, "__parameters__"): - parameters = [p for p in cls.__parameters__ if not _is_unpack(p)] - num_tv_tuples = sum(isinstance(p, TypeVarTuple) for p in parameters) - if (num_tv_tuples > 0) and (alen >= elen - num_tv_tuples): - return - raise TypeError(f"Too {'many' if alen > elen else 'few'} parameters for {cls};" - f" actual {alen}, expected {elen}") - - if sys.version_info >= (3, 10): def _should_collect_from_parameters(t): return isinstance( @@ -181,27 +160,6 @@ else: return isinstance(t, typing._GenericAlias) and not t._special -def _collect_type_vars(types, typevar_types=None): - """Collect all type variable contained in types in order of - first appearance (lexicographic order). For example:: - - _collect_type_vars((T, List[S, T])) == (T, S) - """ - if typevar_types is None: - typevar_types = typing.TypeVar - tvars = [] - for t in types: - if ( - isinstance(t, typevar_types) and - t not in tvars and - not _is_unpack(t) - ): - tvars.append(t) - if _should_collect_from_parameters(t): - tvars.extend([t for t in t.__parameters__ if t not in tvars]) - return tuple(tvars) - - NoReturn = typing.NoReturn # Some unconstrained type variables. These are used by the container types. @@ -834,7 +792,11 @@ def _ensure_subclassable(mro_entries): return inner -if hasattr(typing, "ReadOnly"): +# Update this to something like >=3.13.0b1 if and when +# PEP 728 is implemented in CPython +_PEP_728_IMPLEMENTED = False + +if _PEP_728_IMPLEMENTED: # The standard library TypedDict in Python 3.8 does not store runtime information # about which (if any) keys are optional. 
See https://bugs.python.org/issue38834 # The standard library TypedDict in Python 3.9.0/1 does not honour the "total" @@ -845,7 +807,8 @@ if hasattr(typing, "ReadOnly"): # Aaaand on 3.12 we add __orig_bases__ to TypedDict # to enable better runtime introspection. # On 3.13 we deprecate some odd ways of creating TypedDicts. - # PEP 705 proposes adding the ReadOnly[] qualifier. + # Also on 3.13, PEP 705 adds the ReadOnly[] qualifier. + # PEP 728 (still pending) makes more changes. TypedDict = typing.TypedDict _TypedDictMeta = typing._TypedDictMeta is_typeddict = typing.is_typeddict @@ -1122,15 +1085,15 @@ else: return val -if hasattr(typing, "Required"): # 3.11+ +if hasattr(typing, "ReadOnly"): # 3.13+ get_type_hints = typing.get_type_hints -else: # <=3.10 +else: # <=3.13 # replaces _strip_annotations() def _strip_extras(t): """Strips Annotated, Required and NotRequired from a given type.""" if isinstance(t, _AnnotatedAlias): return _strip_extras(t.__origin__) - if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired): + if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired, ReadOnly): return _strip_extras(t.__args__[0]) if isinstance(t, typing._GenericAlias): stripped_args = tuple(_strip_extras(a) for a in t.__args__) @@ -2689,9 +2652,151 @@ else: # counting generic parameters, so that when we subscript a generic, # the runtime doesn't try to substitute the Unpack with the subscripted type. if not hasattr(typing, "TypeVarTuple"): - typing._collect_type_vars = _collect_type_vars - typing._check_generic = _check_generic + def _check_generic(cls, parameters, elen=_marker): + """Check correct count for parameters of a generic cls (internal helper). + This gives a nice error message in case of count mismatch. + """ + if not elen: + raise TypeError(f"{cls} is not a generic class") + if elen is _marker: + if not hasattr(cls, "__parameters__") or not cls.__parameters__: + raise TypeError(f"{cls} is not a generic class") + elen = len(cls.__parameters__) + alen = len(parameters) + if alen != elen: + expect_val = elen + if hasattr(cls, "__parameters__"): + parameters = [p for p in cls.__parameters__ if not _is_unpack(p)] + num_tv_tuples = sum(isinstance(p, TypeVarTuple) for p in parameters) + if (num_tv_tuples > 0) and (alen >= elen - num_tv_tuples): + return + + # deal with TypeVarLike defaults + # required TypeVarLikes cannot appear after a defaulted one. + if alen < elen: + # since we validate TypeVarLike default in _collect_type_vars + # or _collect_parameters we can safely check parameters[alen] + if getattr(parameters[alen], '__default__', None) is not None: + return + + num_default_tv = sum(getattr(p, '__default__', None) + is not None for p in parameters) + + elen -= num_default_tv + + expect_val = f"at least {elen}" + + things = "arguments" if sys.version_info >= (3, 10) else "parameters" + raise TypeError(f"Too {'many' if alen > elen else 'few'} {things}" + f" for {cls}; actual {alen}, expected {expect_val}") +else: + # Python 3.11+ + + def _check_generic(cls, parameters, elen): + """Check correct count for parameters of a generic cls (internal helper). + + This gives a nice error message in case of count mismatch. + """ + if not elen: + raise TypeError(f"{cls} is not a generic class") + alen = len(parameters) + if alen != elen: + expect_val = elen + if hasattr(cls, "__parameters__"): + parameters = [p for p in cls.__parameters__ if not _is_unpack(p)] + + # deal with TypeVarLike defaults + # required TypeVarLikes cannot appear after a defaulted one. 
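+ # Sketch (hypothetical): with T required and U declared as + # TypeVar("U", default=int), Cls[str] is accepted here (alen=1, elen=2) + # because parameters[1].__default__ is not None.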
+ if alen < elen: + # since we validate TypeVarLike default in _collect_type_vars + # or _collect_parameters we can safely check parameters[alen] + if getattr(parameters[alen], '__default__', None) is not None: + return + + num_default_tv = sum(getattr(p, '__default__', None) + is not None for p in parameters) + + elen -= num_default_tv + + expect_val = f"at least {elen}" + + raise TypeError(f"Too {'many' if alen > elen else 'few'} arguments" + f" for {cls}; actual {alen}, expected {expect_val}") + +typing._check_generic = _check_generic + +# Python 3.11+ _collect_type_vars was renamed to _collect_parameters +if hasattr(typing, '_collect_type_vars'): + def _collect_type_vars(types, typevar_types=None): + """Collect all type variable contained in types in order of + first appearance (lexicographic order). For example:: + + _collect_type_vars((T, List[S, T])) == (T, S) + """ + if typevar_types is None: + typevar_types = typing.TypeVar + tvars = [] + # required TypeVarLike cannot appear after TypeVarLike with default + default_encountered = False + for t in types: + if ( + isinstance(t, typevar_types) and + t not in tvars and + not _is_unpack(t) + ): + if getattr(t, '__default__', None) is not None: + default_encountered = True + elif default_encountered: + raise TypeError(f'Type parameter {t!r} without a default' + ' follows type parameter with a default') + + tvars.append(t) + if _should_collect_from_parameters(t): + tvars.extend([t for t in t.__parameters__ if t not in tvars]) + return tuple(tvars) + + typing._collect_type_vars = _collect_type_vars +else: + def _collect_parameters(args): + """Collect all type variables and parameter specifications in args + in order of first appearance (lexicographic order). + + For example:: + + assert _collect_parameters((T, Callable[P, T])) == (T, P) + """ + parameters = [] + # required TypeVarLike cannot appear after TypeVarLike with default + default_encountered = False + for t in args: + if isinstance(t, type): + # We don't want __parameters__ descriptor of a bare Python class. + pass + elif isinstance(t, tuple): + # `t` might be a tuple, when `ParamSpec` is substituted with + # `[T, int]`, or `[int, *Ts]`, etc. + for x in t: + for collected in _collect_parameters([x]): + if collected not in parameters: + parameters.append(collected) + elif hasattr(t, '__typing_subst__'): + if t not in parameters: + if getattr(t, '__default__', None) is not None: + default_encountered = True + elif default_encountered: + raise TypeError(f'Type parameter {t!r} without a default' + ' follows type parameter with a default') + + parameters.append(t) + else: + for x in getattr(t, '__parameters__', ()): + if x not in parameters: + parameters.append(x) + + return tuple(parameters) + + typing._collect_parameters = _collect_parameters # Backport typing.NamedTuple as it exists in Python 3.13. # In 3.11, the ability to define generic `NamedTuple`s was supported. diff --git a/requirements.txt b/requirements.txt index 75a0cdcf..2d8c8448 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,11 +5,12 @@ beautifulsoup4==4.12.3 bleach==6.1.0 certifi==2024.2.2 cheroot==10.0.0 -cloudinary==1.39.1 cherrypy==18.9.0 +cloudinary==1.39.1 distro==1.9.0 dnspython==2.6.1 facebook-sdk==3.1.0 +future==1.0.0 ga4mp==2.0.4 gntp==1.0.3 html5lib==1.1 @@ -36,8 +37,9 @@ pytz==2024.1 requests==2.31.0 requests-oauthlib==2.0.0 rumps==0.4.0; platform_system == "Darwin" -tempora==5.5.1 simplejson==3.19.2 +six==1.16.0 +tempora==5.5.1 tokenize-rt==5.2.0 tzdata==2024.1 tzlocal==5.0.1