Merge branch 'nightly' into dependabot/pip/nightly/paho-mqtt-2.1.0

This commit is contained in:
JonnyWong16 2024-05-09 22:28:43 -07:00 committed by GitHub
commit a05752030b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
139 changed files with 8367 additions and 3353 deletions

View file

@ -23,7 +23,6 @@ import sys
# Ensure lib added to path, before any other imports # Ensure lib added to path, before any other imports
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lib')) sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lib'))
from future.builtins import str
import argparse import argparse
import datetime import datetime

View file

@ -212,28 +212,6 @@
</div> </div>
</div> </div>
</div> </div>
<% from plexpy.helpers import anon_url %>
<div id="python2-modal" class="modal fade wide" tabindex="-1" role="dialog" aria-labelledby="python2-modal">
<div class="modal-dialog" role="document">
<div class="modal-content">
<div class="modal-header">
<button type="button" class="close" data-dismiss="modal" aria-hidden="true"><i class="fa fa-remove"></i></button>
<h4 class="modal-title">Unable to Update</h4>
</div>
<div class="modal-body" style="text-align: center;">
<p>Tautulli is still running using Python 2 and cannot be updated past v2.6.3.</p>
<p>Python 3 is required to continue receiving updates.</p>
<p>
<strong>Please see the <a href="${anon_url('https://github.com/Tautulli/Tautulli/wiki/Upgrading-to-Python-3-%28Tautulli-v2.5%29')}" target="_blank" rel="noreferrer">wiki</a>
for instructions on how to upgrade to Python 3.</strong>
</p>
</div>
<div class="modal-footer">
<input type="button" class="btn btn-bright" data-dismiss="modal" value="Close">
</div>
</div>
</div>
</div>
% endif % endif
<div class="modal fade" id="ip-info-modal" tabindex="-1" role="dialog" aria-labelledby="ip-info-modal"> <div class="modal fade" id="ip-info-modal" tabindex="-1" role="dialog" aria-labelledby="ip-info-modal">
@ -1067,16 +1045,4 @@
}); });
</script> </script>
% endif % endif
% if _session['user_group'] == 'admin':
<script>
const queryString = window.location.search;
const urlParams = new URLSearchParams(queryString);
if (urlParams.get('update') === 'python2') {
$("#python2-modal").modal({
backdrop: 'static',
keyboard: false
});
}
</script>
% endif
</%def> </%def>

View file

@ -1,5 +1,5 @@
<% <%
from six.moves.urllib.parse import urlencode from urllib.parse import urlencode
%> %>
<!doctype html> <!doctype html>

View file

@ -1 +1 @@
__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore __path__ = __import__("pkgutil").extend_path(__path__, __name__)

View file

@ -1,979 +0,0 @@
# -*- coding: utf-8 -*-
"""A port of Python 3's csv module to Python 2.
The API of the csv module in Python 2 is drastically different from
the csv module in Python 3. This is due, for the most part, to the
difference between str in Python 2 and Python 3.
The semantics of Python 3's version are more useful because they support
unicode natively, while Python 2's csv does not.
"""
from __future__ import unicode_literals, absolute_import
__all__ = [ "QUOTE_MINIMAL", "QUOTE_ALL", "QUOTE_NONNUMERIC", "QUOTE_NONE",
"Error", "Dialect", "__doc__", "excel", "excel_tab",
"field_size_limit", "reader", "writer",
"register_dialect", "get_dialect", "list_dialects", "Sniffer",
"unregister_dialect", "__version__", "DictReader", "DictWriter" ]
import re
import numbers
from io import StringIO
from csv import (
QUOTE_MINIMAL, QUOTE_ALL, QUOTE_NONNUMERIC, QUOTE_NONE,
__version__, __doc__, Error, field_size_limit,
)
# Stuff needed from six
import sys
PY3 = sys.version_info[0] == 3
if PY3:
string_types = str
text_type = str
binary_type = bytes
unichr = chr
else:
string_types = basestring
text_type = unicode
binary_type = str
class QuoteStrategy(object):
quoting = None
def __init__(self, dialect):
if self.quoting is not None:
assert dialect.quoting == self.quoting
self.dialect = dialect
self.setup()
escape_pattern_quoted = r'({quotechar})'.format(
quotechar=re.escape(self.dialect.quotechar or '"'))
escape_pattern_unquoted = r'([{specialchars}])'.format(
specialchars=re.escape(self.specialchars))
self.escape_re_quoted = re.compile(escape_pattern_quoted)
self.escape_re_unquoted = re.compile(escape_pattern_unquoted)
def setup(self):
"""Optional method for strategy-wide optimizations."""
def quoted(self, field=None, raw_field=None, only=None):
"""Determine whether this field should be quoted."""
raise NotImplementedError(
'quoted must be implemented by a subclass')
@property
def specialchars(self):
"""The special characters that need to be escaped."""
raise NotImplementedError(
'specialchars must be implemented by a subclass')
def escape_re(self, quoted=None):
if quoted:
return self.escape_re_quoted
return self.escape_re_unquoted
def escapechar(self, quoted=None):
if quoted and self.dialect.doublequote:
return self.dialect.quotechar
return self.dialect.escapechar
def prepare(self, raw_field, only=None):
field = text_type(raw_field if raw_field is not None else '')
quoted = self.quoted(field=field, raw_field=raw_field, only=only)
escape_re = self.escape_re(quoted=quoted)
escapechar = self.escapechar(quoted=quoted)
if escape_re.search(field):
escapechar = '\\\\' if escapechar == '\\' else escapechar
if not escapechar:
raise Error('No escapechar is set')
escape_replace = r'{escapechar}\1'.format(escapechar=escapechar)
field = escape_re.sub(escape_replace, field)
if quoted:
field = '{quotechar}{field}{quotechar}'.format(
quotechar=self.dialect.quotechar, field=field)
return field
class QuoteMinimalStrategy(QuoteStrategy):
quoting = QUOTE_MINIMAL
def setup(self):
self.quoted_re = re.compile(r'[{specialchars}]'.format(
specialchars=re.escape(self.specialchars)))
@property
def specialchars(self):
return (
self.dialect.lineterminator +
self.dialect.quotechar +
self.dialect.delimiter +
(self.dialect.escapechar or '')
)
def quoted(self, field, only, **kwargs):
if field == self.dialect.quotechar and not self.dialect.doublequote:
# If the only character in the field is the quotechar, and
# doublequote is false, then just escape without outer quotes.
return False
return field == '' and only or bool(self.quoted_re.search(field))
class QuoteAllStrategy(QuoteStrategy):
quoting = QUOTE_ALL
@property
def specialchars(self):
return self.dialect.quotechar
def quoted(self, **kwargs):
return True
class QuoteNonnumericStrategy(QuoteStrategy):
quoting = QUOTE_NONNUMERIC
@property
def specialchars(self):
return (
self.dialect.lineterminator +
self.dialect.quotechar +
self.dialect.delimiter +
(self.dialect.escapechar or '')
)
def quoted(self, raw_field, **kwargs):
return not isinstance(raw_field, numbers.Number)
class QuoteNoneStrategy(QuoteStrategy):
quoting = QUOTE_NONE
@property
def specialchars(self):
return (
self.dialect.lineterminator +
(self.dialect.quotechar or '') +
self.dialect.delimiter +
(self.dialect.escapechar or '')
)
def quoted(self, field, only, **kwargs):
if field == '' and only:
raise Error('single empty field record must be quoted')
return False
class writer(object):
def __init__(self, fileobj, dialect='excel', **fmtparams):
if fileobj is None:
raise TypeError('fileobj must be file-like, not None')
self.fileobj = fileobj
if isinstance(dialect, text_type):
dialect = get_dialect(dialect)
try:
self.dialect = Dialect.combine(dialect, fmtparams)
except Error as e:
raise TypeError(*e.args)
strategies = {
QUOTE_MINIMAL: QuoteMinimalStrategy,
QUOTE_ALL: QuoteAllStrategy,
QUOTE_NONNUMERIC: QuoteNonnumericStrategy,
QUOTE_NONE: QuoteNoneStrategy,
}
self.strategy = strategies[self.dialect.quoting](self.dialect)
def writerow(self, row):
if row is None:
raise Error('row must be an iterable')
row = list(row)
only = len(row) == 1
row = [self.strategy.prepare(field, only=only) for field in row]
line = self.dialect.delimiter.join(row) + self.dialect.lineterminator
return self.fileobj.write(line)
def writerows(self, rows):
for row in rows:
self.writerow(row)
START_RECORD = 0
START_FIELD = 1
ESCAPED_CHAR = 2
IN_FIELD = 3
IN_QUOTED_FIELD = 4
ESCAPE_IN_QUOTED_FIELD = 5
QUOTE_IN_QUOTED_FIELD = 6
EAT_CRNL = 7
AFTER_ESCAPED_CRNL = 8
class reader(object):
def __init__(self, fileobj, dialect='excel', **fmtparams):
self.input_iter = iter(fileobj)
if isinstance(dialect, text_type):
dialect = get_dialect(dialect)
try:
self.dialect = Dialect.combine(dialect, fmtparams)
except Error as e:
raise TypeError(*e.args)
self.fields = None
self.field = None
self.line_num = 0
def parse_reset(self):
self.fields = []
self.field = []
self.state = START_RECORD
self.numeric_field = False
def parse_save_field(self):
field = ''.join(self.field)
self.field = []
if self.numeric_field:
field = float(field)
self.numeric_field = False
self.fields.append(field)
def parse_add_char(self, c):
if len(self.field) >= field_size_limit():
raise Error('field size limit exceeded')
self.field.append(c)
def parse_process_char(self, c):
switch = {
START_RECORD: self._parse_start_record,
START_FIELD: self._parse_start_field,
ESCAPED_CHAR: self._parse_escaped_char,
AFTER_ESCAPED_CRNL: self._parse_after_escaped_crnl,
IN_FIELD: self._parse_in_field,
IN_QUOTED_FIELD: self._parse_in_quoted_field,
ESCAPE_IN_QUOTED_FIELD: self._parse_escape_in_quoted_field,
QUOTE_IN_QUOTED_FIELD: self._parse_quote_in_quoted_field,
EAT_CRNL: self._parse_eat_crnl,
}
return switch[self.state](c)
def _parse_start_record(self, c):
if c == '\0':
return
elif c == '\n' or c == '\r':
self.state = EAT_CRNL
return
self.state = START_FIELD
return self._parse_start_field(c)
def _parse_start_field(self, c):
if c == '\n' or c == '\r' or c == '\0':
self.parse_save_field()
self.state = START_RECORD if c == '\0' else EAT_CRNL
elif (c == self.dialect.quotechar and
self.dialect.quoting != QUOTE_NONE):
self.state = IN_QUOTED_FIELD
elif c == self.dialect.escapechar:
self.state = ESCAPED_CHAR
elif c == ' ' and self.dialect.skipinitialspace:
pass # Ignore space at start of field
elif c == self.dialect.delimiter:
# Save empty field
self.parse_save_field()
else:
# Begin new unquoted field
if self.dialect.quoting == QUOTE_NONNUMERIC:
self.numeric_field = True
self.parse_add_char(c)
self.state = IN_FIELD
def _parse_escaped_char(self, c):
if c == '\n' or c == '\r':
self.parse_add_char(c)
self.state = AFTER_ESCAPED_CRNL
return
if c == '\0':
c = '\n'
self.parse_add_char(c)
self.state = IN_FIELD
def _parse_after_escaped_crnl(self, c):
if c == '\0':
return
return self._parse_in_field(c)
def _parse_in_field(self, c):
# In unquoted field
if c == '\n' or c == '\r' or c == '\0':
# End of line - return [fields]
self.parse_save_field()
self.state = START_RECORD if c == '\0' else EAT_CRNL
elif c == self.dialect.escapechar:
self.state = ESCAPED_CHAR
elif c == self.dialect.delimiter:
self.parse_save_field()
self.state = START_FIELD
else:
# Normal character - save in field
self.parse_add_char(c)
def _parse_in_quoted_field(self, c):
if c == '\0':
pass
elif c == self.dialect.escapechar:
self.state = ESCAPE_IN_QUOTED_FIELD
elif (c == self.dialect.quotechar and
self.dialect.quoting != QUOTE_NONE):
if self.dialect.doublequote:
self.state = QUOTE_IN_QUOTED_FIELD
else:
self.state = IN_FIELD
else:
self.parse_add_char(c)
def _parse_escape_in_quoted_field(self, c):
if c == '\0':
c = '\n'
self.parse_add_char(c)
self.state = IN_QUOTED_FIELD
def _parse_quote_in_quoted_field(self, c):
if (self.dialect.quoting != QUOTE_NONE and
c == self.dialect.quotechar):
# save "" as "
self.parse_add_char(c)
self.state = IN_QUOTED_FIELD
elif c == self.dialect.delimiter:
self.parse_save_field()
self.state = START_FIELD
elif c == '\n' or c == '\r' or c == '\0':
# End of line = return [fields]
self.parse_save_field()
self.state = START_RECORD if c == '\0' else EAT_CRNL
elif not self.dialect.strict:
self.parse_add_char(c)
self.state = IN_FIELD
else:
# illegal
raise Error("{delimiter}' expected after '{quotechar}".format(
delimiter=self.dialect.delimiter,
quotechar=self.dialect.quotechar,
))
def _parse_eat_crnl(self, c):
if c == '\n' or c == '\r':
pass
elif c == '\0':
self.state = START_RECORD
else:
raise Error('new-line character seen in unquoted field - do you '
'need to open the file in universal-newline mode?')
def __iter__(self):
return self
def __next__(self):
self.parse_reset()
while True:
try:
lineobj = next(self.input_iter)
except StopIteration:
if len(self.field) != 0 or self.state == IN_QUOTED_FIELD:
if self.dialect.strict:
raise Error('unexpected end of data')
self.parse_save_field()
if self.fields:
break
raise
if not isinstance(lineobj, text_type):
typ = type(lineobj)
typ_name = 'bytes' if typ == bytes else typ.__name__
err_str = ('iterator should return strings, not {0}'
' (did you open the file in text mode?)')
raise Error(err_str.format(typ_name))
self.line_num += 1
for c in lineobj:
if c == '\0':
raise Error('line contains NULL byte')
self.parse_process_char(c)
self.parse_process_char('\0')
if self.state == START_RECORD:
break
fields = self.fields
self.fields = None
return fields
next = __next__
_dialect_registry = {}
def register_dialect(name, dialect='excel', **fmtparams):
if not isinstance(name, text_type):
raise TypeError('"name" must be a string')
dialect = Dialect.extend(dialect, fmtparams)
try:
Dialect.validate(dialect)
except:
raise TypeError('dialect is invalid')
assert name not in _dialect_registry
_dialect_registry[name] = dialect
def unregister_dialect(name):
try:
_dialect_registry.pop(name)
except KeyError:
raise Error('"{name}" not a registered dialect'.format(name=name))
def get_dialect(name):
try:
return _dialect_registry[name]
except KeyError:
raise Error('Could not find dialect {0}'.format(name))
def list_dialects():
return list(_dialect_registry)
class Dialect(object):
"""Describe a CSV dialect.
This must be subclassed (see csv.excel). Valid attributes are:
delimiter, quotechar, escapechar, doublequote, skipinitialspace,
lineterminator, quoting, strict.
"""
_name = ""
_valid = False
# placeholders
delimiter = None
quotechar = None
escapechar = None
doublequote = None
skipinitialspace = None
lineterminator = None
quoting = None
strict = None
def __init__(self):
self.validate(self)
if self.__class__ != Dialect:
self._valid = True
@classmethod
def validate(cls, dialect):
dialect = cls.extend(dialect)
if not isinstance(dialect.quoting, int):
raise Error('"quoting" must be an integer')
if dialect.delimiter is None:
raise Error('delimiter must be set')
cls.validate_text(dialect, 'delimiter')
if dialect.lineterminator is None:
raise Error('lineterminator must be set')
if not isinstance(dialect.lineterminator, text_type):
raise Error('"lineterminator" must be a string')
if dialect.quoting not in [
QUOTE_NONE, QUOTE_MINIMAL, QUOTE_NONNUMERIC, QUOTE_ALL]:
raise Error('Invalid quoting specified')
if dialect.quoting != QUOTE_NONE:
if dialect.quotechar is None and dialect.escapechar is None:
raise Error('quotechar must be set if quoting enabled')
if dialect.quotechar is not None:
cls.validate_text(dialect, 'quotechar')
@staticmethod
def validate_text(dialect, attr):
val = getattr(dialect, attr)
if not isinstance(val, text_type):
if type(val) == bytes:
raise Error('"{0}" must be string, not bytes'.format(attr))
raise Error('"{0}" must be string, not {1}'.format(
attr, type(val).__name__))
if len(val) != 1:
raise Error('"{0}" must be a 1-character string'.format(attr))
@staticmethod
def defaults():
return {
'delimiter': ',',
'doublequote': True,
'escapechar': None,
'lineterminator': '\r\n',
'quotechar': '"',
'quoting': QUOTE_MINIMAL,
'skipinitialspace': False,
'strict': False,
}
@classmethod
def extend(cls, dialect, fmtparams=None):
if isinstance(dialect, string_types):
dialect = get_dialect(dialect)
if fmtparams is None:
return dialect
defaults = cls.defaults()
if any(param not in defaults for param in fmtparams):
raise TypeError('Invalid fmtparam')
specified = dict(
(attr, getattr(dialect, attr, None))
for attr in cls.defaults()
)
specified.update(fmtparams)
return type(str('ExtendedDialect'), (cls,), specified)
@classmethod
def combine(cls, dialect, fmtparams):
"""Create a new dialect with defaults and added parameters."""
dialect = cls.extend(dialect, fmtparams)
defaults = cls.defaults()
specified = dict(
(attr, getattr(dialect, attr, None))
for attr in defaults
if getattr(dialect, attr, None) is not None or
attr in ['quotechar', 'delimiter', 'lineterminator', 'quoting']
)
defaults.update(specified)
dialect = type(str('CombinedDialect'), (cls,), defaults)
cls.validate(dialect)
return dialect()
def __delattr__(self, attr):
if self._valid:
raise AttributeError('dialect is immutable.')
super(Dialect, self).__delattr__(attr)
def __setattr__(self, attr, value):
if self._valid:
raise AttributeError('dialect is immutable.')
super(Dialect, self).__setattr__(attr, value)
class excel(Dialect):
"""Describe the usual properties of Excel-generated CSV files."""
delimiter = ','
quotechar = '"'
doublequote = True
skipinitialspace = False
lineterminator = '\r\n'
quoting = QUOTE_MINIMAL
register_dialect("excel", excel)
class excel_tab(excel):
"""Describe the usual properties of Excel-generated TAB-delimited files."""
delimiter = '\t'
register_dialect("excel-tab", excel_tab)
class unix_dialect(Dialect):
"""Describe the usual properties of Unix-generated CSV files."""
delimiter = ','
quotechar = '"'
doublequote = True
skipinitialspace = False
lineterminator = '\n'
quoting = QUOTE_ALL
register_dialect("unix", unix_dialect)
class DictReader(object):
def __init__(self, f, fieldnames=None, restkey=None, restval=None,
dialect="excel", *args, **kwds):
self._fieldnames = fieldnames # list of keys for the dict
self.restkey = restkey # key to catch long rows
self.restval = restval # default value for short rows
self.reader = reader(f, dialect, *args, **kwds)
self.dialect = dialect
self.line_num = 0
def __iter__(self):
return self
@property
def fieldnames(self):
if self._fieldnames is None:
try:
self._fieldnames = next(self.reader)
except StopIteration:
pass
self.line_num = self.reader.line_num
return self._fieldnames
@fieldnames.setter
def fieldnames(self, value):
self._fieldnames = value
def __next__(self):
if self.line_num == 0:
# Used only for its side effect.
self.fieldnames
row = next(self.reader)
self.line_num = self.reader.line_num
# unlike the basic reader, we prefer not to return blanks,
# because we will typically wind up with a dict full of None
# values
while row == []:
row = next(self.reader)
d = dict(zip(self.fieldnames, row))
lf = len(self.fieldnames)
lr = len(row)
if lf < lr:
d[self.restkey] = row[lf:]
elif lf > lr:
for key in self.fieldnames[lr:]:
d[key] = self.restval
return d
next = __next__
class DictWriter(object):
def __init__(self, f, fieldnames, restval="", extrasaction="raise",
dialect="excel", *args, **kwds):
self.fieldnames = fieldnames # list of keys for the dict
self.restval = restval # for writing short dicts
if extrasaction.lower() not in ("raise", "ignore"):
raise ValueError("extrasaction (%s) must be 'raise' or 'ignore'"
% extrasaction)
self.extrasaction = extrasaction
self.writer = writer(f, dialect, *args, **kwds)
def writeheader(self):
header = dict(zip(self.fieldnames, self.fieldnames))
self.writerow(header)
def _dict_to_list(self, rowdict):
if self.extrasaction == "raise":
wrong_fields = [k for k in rowdict if k not in self.fieldnames]
if wrong_fields:
raise ValueError("dict contains fields not in fieldnames: "
+ ", ".join([repr(x) for x in wrong_fields]))
return (rowdict.get(key, self.restval) for key in self.fieldnames)
def writerow(self, rowdict):
return self.writer.writerow(self._dict_to_list(rowdict))
def writerows(self, rowdicts):
return self.writer.writerows(map(self._dict_to_list, rowdicts))
# Guard Sniffer's type checking against builds that exclude complex()
try:
complex
except NameError:
complex = float
class Sniffer(object):
'''
"Sniffs" the format of a CSV file (i.e. delimiter, quotechar)
Returns a Dialect object.
'''
def __init__(self):
# in case there is more than one possible delimiter
self.preferred = [',', '\t', ';', ' ', ':']
def sniff(self, sample, delimiters=None):
"""
Returns a dialect (or None) corresponding to the sample
"""
quotechar, doublequote, delimiter, skipinitialspace = \
self._guess_quote_and_delimiter(sample, delimiters)
if not delimiter:
delimiter, skipinitialspace = self._guess_delimiter(sample,
delimiters)
if not delimiter:
raise Error("Could not determine delimiter")
class dialect(Dialect):
_name = "sniffed"
lineterminator = '\r\n'
quoting = QUOTE_MINIMAL
# escapechar = ''
dialect.doublequote = doublequote
dialect.delimiter = delimiter
# _csv.reader won't accept a quotechar of ''
dialect.quotechar = quotechar or '"'
dialect.skipinitialspace = skipinitialspace
return dialect
def _guess_quote_and_delimiter(self, data, delimiters):
"""
Looks for text enclosed between two identical quotes
(the probable quotechar) which are preceded and followed
by the same character (the probable delimiter).
For example:
,'some text',
The quote with the most wins, same with the delimiter.
If there is no quotechar the delimiter can't be determined
this way.
"""
matches = []
for restr in ('(?P<delim>[^\w\n"\'])(?P<space> ?)(?P<quote>["\']).*?(?P=quote)(?P=delim)', # ,".*?",
'(?:^|\n)(?P<quote>["\']).*?(?P=quote)(?P<delim>[^\w\n"\'])(?P<space> ?)', # ".*?",
'(?P<delim>>[^\w\n"\'])(?P<space> ?)(?P<quote>["\']).*?(?P=quote)(?:$|\n)', # ,".*?"
'(?:^|\n)(?P<quote>["\']).*?(?P=quote)(?:$|\n)'): # ".*?" (no delim, no space)
regexp = re.compile(restr, re.DOTALL | re.MULTILINE)
matches = regexp.findall(data)
if matches:
break
if not matches:
# (quotechar, doublequote, delimiter, skipinitialspace)
return ('', False, None, 0)
quotes = {}
delims = {}
spaces = 0
groupindex = regexp.groupindex
for m in matches:
n = groupindex['quote'] - 1
key = m[n]
if key:
quotes[key] = quotes.get(key, 0) + 1
try:
n = groupindex['delim'] - 1
key = m[n]
except KeyError:
continue
if key and (delimiters is None or key in delimiters):
delims[key] = delims.get(key, 0) + 1
try:
n = groupindex['space'] - 1
except KeyError:
continue
if m[n]:
spaces += 1
quotechar = max(quotes, key=quotes.get)
if delims:
delim = max(delims, key=delims.get)
skipinitialspace = delims[delim] == spaces
if delim == '\n': # most likely a file with a single column
delim = ''
else:
# there is *no* delimiter, it's a single column of quoted data
delim = ''
skipinitialspace = 0
# if we see an extra quote between delimiters, we've got a
# double quoted format
dq_regexp = re.compile(
r"((%(delim)s)|^)\W*%(quote)s[^%(delim)s\n]*%(quote)s[^%(delim)s\n]*%(quote)s\W*((%(delim)s)|$)" % \
{'delim':re.escape(delim), 'quote':quotechar}, re.MULTILINE)
if dq_regexp.search(data):
doublequote = True
else:
doublequote = False
return (quotechar, doublequote, delim, skipinitialspace)
def _guess_delimiter(self, data, delimiters):
"""
The delimiter /should/ occur the same number of times on
each row. However, due to malformed data, it may not. We don't want
an all or nothing approach, so we allow for small variations in this
number.
1) build a table of the frequency of each character on every line.
2) build a table of frequencies of this frequency (meta-frequency?),
e.g. 'x occurred 5 times in 10 rows, 6 times in 1000 rows,
7 times in 2 rows'
3) use the mode of the meta-frequency to determine the /expected/
frequency for that character
4) find out how often the character actually meets that goal
5) the character that best meets its goal is the delimiter
For performance reasons, the data is evaluated in chunks, so it can
try and evaluate the smallest portion of the data possible, evaluating
additional chunks as necessary.
"""
data = list(filter(None, data.split('\n')))
ascii = [unichr(c) for c in range(127)] # 7-bit ASCII
# build frequency tables
chunkLength = min(10, len(data))
iteration = 0
charFrequency = {}
modes = {}
delims = {}
start, end = 0, min(chunkLength, len(data))
while start < len(data):
iteration += 1
for line in data[start:end]:
for char in ascii:
metaFrequency = charFrequency.get(char, {})
# must count even if frequency is 0
freq = line.count(char)
# value is the mode
metaFrequency[freq] = metaFrequency.get(freq, 0) + 1
charFrequency[char] = metaFrequency
for char in charFrequency.keys():
items = list(charFrequency[char].items())
if len(items) == 1 and items[0][0] == 0:
continue
# get the mode of the frequencies
if len(items) > 1:
modes[char] = max(items, key=lambda x: x[1])
# adjust the mode - subtract the sum of all
# other frequencies
items.remove(modes[char])
modes[char] = (modes[char][0], modes[char][1]
- sum(item[1] for item in items))
else:
modes[char] = items[0]
# build a list of possible delimiters
modeList = modes.items()
total = float(chunkLength * iteration)
# (rows of consistent data) / (number of rows) = 100%
consistency = 1.0
# minimum consistency threshold
threshold = 0.9
while len(delims) == 0 and consistency >= threshold:
for k, v in modeList:
if v[0] > 0 and v[1] > 0:
if ((v[1]/total) >= consistency and
(delimiters is None or k in delimiters)):
delims[k] = v
consistency -= 0.01
if len(delims) == 1:
delim = list(delims.keys())[0]
skipinitialspace = (data[0].count(delim) ==
data[0].count("%c " % delim))
return (delim, skipinitialspace)
# analyze another chunkLength lines
start = end
end += chunkLength
if not delims:
return ('', 0)
# if there's more than one, fall back to a 'preferred' list
if len(delims) > 1:
for d in self.preferred:
if d in delims.keys():
skipinitialspace = (data[0].count(d) ==
data[0].count("%c " % d))
return (d, skipinitialspace)
# nothing else indicates a preference, pick the character that
# dominates(?)
items = [(v,k) for (k,v) in delims.items()]
items.sort()
delim = items[-1][1]
skipinitialspace = (data[0].count(delim) ==
data[0].count("%c " % delim))
return (delim, skipinitialspace)
def has_header(self, sample):
# Creates a dictionary of types of data in each column. If any
# column is of a single type (say, integers), *except* for the first
# row, then the first row is presumed to be labels. If the type
# can't be determined, it is assumed to be a string in which case
# the length of the string is the determining factor: if all of the
# rows except for the first are the same length, it's a header.
# Finally, a 'vote' is taken at the end for each column, adding or
# subtracting from the likelihood of the first row being a header.
rdr = reader(StringIO(sample), self.sniff(sample))
header = next(rdr) # assume first row is header
columns = len(header)
columnTypes = {}
for i in range(columns): columnTypes[i] = None
checked = 0
for row in rdr:
# arbitrary number of rows to check, to keep it sane
if checked > 20:
break
checked += 1
if len(row) != columns:
continue # skip rows that have irregular number of columns
for col in list(columnTypes.keys()):
for thisType in [int, float, complex]:
try:
thisType(row[col])
break
except (ValueError, OverflowError):
pass
else:
# fallback to length of string
thisType = len(row[col])
if thisType != columnTypes[col]:
if columnTypes[col] is None: # add new column type
columnTypes[col] = thisType
else:
# type is inconsistent, remove column from
# consideration
del columnTypes[col]
# finally, compare results against first row and "vote"
# on whether it's a header
hasHeader = 0
for col, colType in columnTypes.items():
if type(colType) == type(0): # it's a length
if len(header[col]) != colType:
hasHeader += 1
else:
hasHeader -= 1
else: # attempt typecast
try:
colType(header[col])
except (ValueError, TypeError):
hasHeader += 1
else:
hasHeader -= 1
return hasHeader > 0

View file

@ -1,243 +0,0 @@
from __future__ import absolute_import
import functools
from collections import namedtuple
from threading import RLock
_CacheInfo = namedtuple("_CacheInfo", ["hits", "misses", "maxsize", "currsize"])
@functools.wraps(functools.update_wrapper)
def update_wrapper(
wrapper,
wrapped,
assigned=functools.WRAPPER_ASSIGNMENTS,
updated=functools.WRAPPER_UPDATES,
):
"""
Patch two bugs in functools.update_wrapper.
"""
# workaround for http://bugs.python.org/issue3445
assigned = tuple(attr for attr in assigned if hasattr(wrapped, attr))
wrapper = functools.update_wrapper(wrapper, wrapped, assigned, updated)
# workaround for https://bugs.python.org/issue17482
wrapper.__wrapped__ = wrapped
return wrapper
class _HashedSeq(list):
"""This class guarantees that hash() will be called no more than once
per element. This is important because the lru_cache() will hash
the key multiple times on a cache miss.
"""
__slots__ = 'hashvalue'
def __init__(self, tup, hash=hash):
self[:] = tup
self.hashvalue = hash(tup)
def __hash__(self):
return self.hashvalue
def _make_key(
args,
kwds,
typed,
kwd_mark=(object(),),
fasttypes={int, str},
tuple=tuple,
type=type,
len=len,
):
"""Make a cache key from optionally typed positional and keyword arguments
The key is constructed in a way that is flat as possible rather than
as a nested structure that would take more memory.
If there is only a single argument and its data type is known to cache
its hash value, then that argument is returned without a wrapper. This
saves space and improves lookup speed.
"""
# All of code below relies on kwds preserving the order input by the user.
# Formerly, we sorted() the kwds before looping. The new way is *much*
# faster; however, it means that f(x=1, y=2) will now be treated as a
# distinct call from f(y=2, x=1) which will be cached separately.
key = args
if kwds:
key += kwd_mark
for item in kwds.items():
key += item
if typed:
key += tuple(type(v) for v in args)
if kwds:
key += tuple(type(v) for v in kwds.values())
elif len(key) == 1 and type(key[0]) in fasttypes:
return key[0]
return _HashedSeq(key)
def lru_cache(maxsize=128, typed=False):
"""Least-recently-used cache decorator.
If *maxsize* is set to None, the LRU features are disabled and the cache
can grow without bound.
If *typed* is True, arguments of different types will be cached separately.
For example, f(decimal.Decimal("3.0")) and f(3.0) will be treated as
distinct calls with distinct results. Some types such as str and int may
be cached separately even when typed is false.
Arguments to the cached function must be hashable.
View the cache statistics named tuple (hits, misses, maxsize, currsize)
with f.cache_info(). Clear the cache and statistics with f.cache_clear().
Access the underlying function with f.__wrapped__.
See: https://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)
"""
# Users should only access the lru_cache through its public API:
# cache_info, cache_clear, and f.__wrapped__
# The internals of the lru_cache are encapsulated for thread safety and
# to allow the implementation to change (including a possible C version).
if isinstance(maxsize, int):
# Negative maxsize is treated as 0
if maxsize < 0:
maxsize = 0
elif callable(maxsize) and isinstance(typed, bool):
# The user_function was passed in directly via the maxsize argument
user_function, maxsize = maxsize, 128
wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
wrapper.cache_parameters = lambda: {'maxsize': maxsize, 'typed': typed}
return update_wrapper(wrapper, user_function)
elif maxsize is not None:
raise TypeError('Expected first argument to be an integer, a callable, or None')
def decorating_function(user_function):
wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
wrapper.cache_parameters = lambda: {'maxsize': maxsize, 'typed': typed}
return update_wrapper(wrapper, user_function)
return decorating_function
def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo):
# Constants shared by all lru cache instances:
sentinel = object() # unique object used to signal cache misses
make_key = _make_key # build a key from the function arguments
PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields
cache = {}
hits = misses = 0
full = False
cache_get = cache.get # bound method to lookup a key or return None
cache_len = cache.__len__ # get cache size without calling len()
lock = RLock() # because linkedlist updates aren't threadsafe
root = [] # root of the circular doubly linked list
root[:] = [root, root, None, None] # initialize by pointing to self
if maxsize == 0:
def wrapper(*args, **kwds):
# No caching -- just a statistics update
nonlocal misses
misses += 1
result = user_function(*args, **kwds)
return result
elif maxsize is None:
def wrapper(*args, **kwds):
# Simple caching without ordering or size limit
nonlocal hits, misses
key = make_key(args, kwds, typed)
result = cache_get(key, sentinel)
if result is not sentinel:
hits += 1
return result
misses += 1
result = user_function(*args, **kwds)
cache[key] = result
return result
else:
def wrapper(*args, **kwds):
# Size limited caching that tracks accesses by recency
nonlocal root, hits, misses, full
key = make_key(args, kwds, typed)
with lock:
link = cache_get(key)
if link is not None:
# Move the link to the front of the circular queue
link_prev, link_next, _key, result = link
link_prev[NEXT] = link_next
link_next[PREV] = link_prev
last = root[PREV]
last[NEXT] = root[PREV] = link
link[PREV] = last
link[NEXT] = root
hits += 1
return result
misses += 1
result = user_function(*args, **kwds)
with lock:
if key in cache:
# Getting here means that this same key was added to the
# cache while the lock was released. Since the link
# update is already done, we need only return the
# computed result and update the count of misses.
pass
elif full:
# Use the old root to store the new key and result.
oldroot = root
oldroot[KEY] = key
oldroot[RESULT] = result
# Empty the oldest link and make it the new root.
# Keep a reference to the old key and old result to
# prevent their ref counts from going to zero during the
# update. That will prevent potentially arbitrary object
# clean-up code (i.e. __del__) from running while we're
# still adjusting the links.
root = oldroot[NEXT]
oldkey = root[KEY]
root[KEY] = root[RESULT] = None
# Now update the cache dictionary.
del cache[oldkey]
# Save the potentially reentrant cache[key] assignment
# for last, after the root and links have been put in
# a consistent state.
cache[key] = oldroot
else:
# Put result in a new link at the front of the queue.
last = root[PREV]
link = [last, root, key, result]
last[NEXT] = root[PREV] = cache[key] = link
# Use the cache_len bound method instead of the len() function
# which could potentially be wrapped in an lru_cache itself.
full = cache_len() >= maxsize
return result
def cache_info():
"""Report cache statistics"""
with lock:
return _CacheInfo(hits, misses, maxsize, cache_len())
def cache_clear():
"""Clear the cache and cache statistics"""
nonlocal hits, misses, full
with lock:
cache.clear()
root[:] = [root, root, None, None]
hits = misses = 0
full = False
wrapper.cache_info = cache_info
wrapper.cache_clear = cache_clear
return wrapper

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,5 @@
from . import main
if __name__ == '__main__':
main()

View file

@ -0,0 +1,24 @@
import sys
if sys.version_info < (3, 9):
def removesuffix(self, suffix):
# suffix='' should not call self[:-0].
if suffix and self.endswith(suffix):
return self[: -len(suffix)]
else:
return self[:]
def removeprefix(self, prefix):
if self.startswith(prefix):
return self[len(prefix) :]
else:
return self[:]
else:
def removesuffix(self, suffix):
return self.removesuffix(suffix)
def removeprefix(self, prefix):
return self.removeprefix(prefix)

View file

@ -292,7 +292,20 @@ class ConnectionManager:
if self.server.ssl_adapter is not None: if self.server.ssl_adapter is not None:
try: try:
s, ssl_env = self.server.ssl_adapter.wrap(s) s, ssl_env = self.server.ssl_adapter.wrap(s)
except errors.NoSSLError: except errors.FatalSSLAlert as tls_connection_drop_error:
self.server.error_log(
f'Client {addr !s} lost — peer dropped the TLS '
'connection suddenly, during handshake: '
f'{tls_connection_drop_error !s}',
)
return
except errors.NoSSLError as http_over_https_err:
self.server.error_log(
f'Client {addr !s} attempted to speak plain HTTP into '
'a TCP connection configured for TLS-only traffic — '
'trying to send back a plain HTTP error response: '
f'{http_over_https_err !s}',
)
msg = ( msg = (
'The client sent a plain HTTP request, but ' 'The client sent a plain HTTP request, but '
'this server only speaks HTTPS on this port.' 'this server only speaks HTTPS on this port.'
@ -311,8 +324,6 @@ class ConnectionManager:
if ex.args[0] not in errors.socket_errors_to_ignore: if ex.args[0] not in errors.socket_errors_to_ignore:
raise raise
return return
if not s:
return
mf = self.server.ssl_adapter.makefile mf = self.server.ssl_adapter.makefile
# Re-apply our timeout since we may have a new socket object # Re-apply our timeout since we may have a new socket object
if hasattr(s, 'settimeout'): if hasattr(s, 'settimeout'):

View file

@ -157,7 +157,7 @@ QUOTED_SLASH = b'%2F'
QUOTED_SLASH_REGEX = re.compile(b''.join((b'(?i)', QUOTED_SLASH))) QUOTED_SLASH_REGEX = re.compile(b''.join((b'(?i)', QUOTED_SLASH)))
_STOPPING_FOR_INTERRUPT = object() # sentinel used during shutdown _STOPPING_FOR_INTERRUPT = Exception() # sentinel used during shutdown
comma_separated_headers = [ comma_separated_headers = [
@ -209,7 +209,11 @@ class HeaderReader:
if not line.endswith(CRLF): if not line.endswith(CRLF):
raise ValueError('HTTP requires CRLF terminators') raise ValueError('HTTP requires CRLF terminators')
if line[0] in (SPACE, TAB): if line[:1] in (SPACE, TAB):
# NOTE: `type(line[0]) is int` and `type(line[:1]) is bytes`.
# NOTE: The former causes a the following warning:
# NOTE: `BytesWarning('Comparison between bytes and int')`
# NOTE: The latter is equivalent and does not.
# It's a continuation line. # It's a continuation line.
v = line.strip() v = line.strip()
else: else:
@ -1725,16 +1729,16 @@ class HTTPServer:
"""Run the server forever, and stop it cleanly on exit.""" """Run the server forever, and stop it cleanly on exit."""
try: try:
self.start() self.start()
except (KeyboardInterrupt, IOError): except KeyboardInterrupt as kb_intr_exc:
# The time.sleep call might raise underlying_interrupt = self.interrupt
# "IOError: [Errno 4] Interrupted function call" on KBInt. if not underlying_interrupt:
self.error_log('Keyboard Interrupt: shutting down') self.interrupt = kb_intr_exc
self.stop() raise kb_intr_exc from underlying_interrupt
raise except SystemExit as sys_exit_exc:
except SystemExit: underlying_interrupt = self.interrupt
self.error_log('SystemExit raised: shutting down') if not underlying_interrupt:
self.stop() self.interrupt = sys_exit_exc
raise raise sys_exit_exc from underlying_interrupt
def prepare(self): # noqa: C901 # FIXME def prepare(self): # noqa: C901 # FIXME
"""Prepare server to serving requests. """Prepare server to serving requests.
@ -2111,6 +2115,13 @@ class HTTPServer:
has completed. has completed.
""" """
self._interrupt = _STOPPING_FOR_INTERRUPT self._interrupt = _STOPPING_FOR_INTERRUPT
if isinstance(interrupt, KeyboardInterrupt):
self.error_log('Keyboard Interrupt: shutting down')
if isinstance(interrupt, SystemExit):
self.error_log('SystemExit raised: shutting down')
self.stop() self.stop()
self._interrupt = interrupt self._interrupt = interrupt

View file

@ -27,12 +27,9 @@ except ImportError:
from . import Adapter from . import Adapter
from .. import errors from .. import errors
from .._compat import IS_ABOVE_OPENSSL10
from ..makefile import StreamReader, StreamWriter from ..makefile import StreamReader, StreamWriter
from ..server import HTTPServer from ..server import HTTPServer
generic_socket_error = OSError
def _assert_ssl_exc_contains(exc, *msgs): def _assert_ssl_exc_contains(exc, *msgs):
"""Check whether SSL exception contains either of messages provided.""" """Check whether SSL exception contains either of messages provided."""
@ -265,62 +262,35 @@ class BuiltinSSLAdapter(Adapter):
def wrap(self, sock): def wrap(self, sock):
"""Wrap and return the given socket, plus WSGI environ entries.""" """Wrap and return the given socket, plus WSGI environ entries."""
EMPTY_RESULT = None, {}
try: try:
s = self.context.wrap_socket( s = self.context.wrap_socket(
sock, do_handshake_on_connect=True, server_side=True, sock, do_handshake_on_connect=True, server_side=True,
) )
except ssl.SSLError as ex: except (
if ex.errno == ssl.SSL_ERROR_EOF: ssl.SSLEOFError,
# This is almost certainly due to the cherrypy engine ssl.SSLZeroReturnError,
# 'pinging' the socket to assert it's connectable; ) as tls_connection_drop_error:
# the 'ping' isn't SSL. raise errors.FatalSSLAlert(
return EMPTY_RESULT *tls_connection_drop_error.args,
elif ex.errno == ssl.SSL_ERROR_SSL: ) from tls_connection_drop_error
if _assert_ssl_exc_contains(ex, 'http request'): except ssl.SSLError as generic_tls_error:
# The client is speaking HTTP to an HTTPS server. peer_speaks_plain_http_over_https = (
raise errors.NoSSLError generic_tls_error.errno == ssl.SSL_ERROR_SSL and
_assert_ssl_exc_contains(generic_tls_error, 'http request')
# Check if it's one of the known errors
# Errors that are caught by PyOpenSSL, but thrown by
# built-in ssl
_block_errors = (
'unknown protocol', 'unknown ca', 'unknown_ca',
'unknown error',
'https proxy request', 'inappropriate fallback',
'wrong version number',
'no shared cipher', 'certificate unknown',
'ccs received early',
'certificate verify failed', # client cert w/o trusted CA
'version too low', # caused by SSL3 connections
'unsupported protocol', # caused by TLS1 connections
) )
if _assert_ssl_exc_contains(ex, *_block_errors): if peer_speaks_plain_http_over_https:
# Accepted error, let's pass reraised_connection_drop_exc_cls = errors.NoSSLError
return EMPTY_RESULT else:
elif _assert_ssl_exc_contains(ex, 'handshake operation timed out'): reraised_connection_drop_exc_cls = errors.FatalSSLAlert
# This error is thrown by builtin SSL after a timeout
# when client is speaking HTTP to an HTTPS server.
# The connection can safely be dropped.
return EMPTY_RESULT
raise
except generic_socket_error as exc:
"""It is unclear why exactly this happens.
It's reproducible only with openssl>1.0 and stdlib raise reraised_connection_drop_exc_cls(
:py:mod:`ssl` wrapper. *generic_tls_error.args,
In CherryPy it's triggered by Checker plugin, which connects ) from generic_tls_error
to the app listening to the socket port in TLS mode via plain except OSError as tcp_connection_drop_error:
HTTP during startup (from the same process). raise errors.FatalSSLAlert(
*tcp_connection_drop_error.args,
) from tcp_connection_drop_error
Ref: https://github.com/cherrypy/cherrypy/issues/1618
"""
is_error0 = exc.args == (0, 'Error')
if is_error0 and IS_ABOVE_OPENSSL10:
return EMPTY_RESULT
raise
return s, self.get_environ(s) return s, self.get_environ(s)
def get_environ(self, sock): def get_environ(self, sock):

View file

@ -150,7 +150,7 @@ class SSLFileobjectMixin:
return self._safe_call( return self._safe_call(
False, False,
super(SSLFileobjectMixin, self).sendall, super(SSLFileobjectMixin, self).sendall,
*args, **kwargs *args, **kwargs,
) )
def send(self, *args, **kwargs): def send(self, *args, **kwargs):
@ -158,7 +158,7 @@ class SSLFileobjectMixin:
return self._safe_call( return self._safe_call(
False, False,
super(SSLFileobjectMixin, self).send, super(SSLFileobjectMixin, self).send,
*args, **kwargs *args, **kwargs,
) )
@ -196,6 +196,7 @@ class SSLConnectionProxyMeta:
def lock_decorator(method): def lock_decorator(method):
"""Create a proxy method for a new class.""" """Create a proxy method for a new class."""
def proxy_wrapper(self, *args): def proxy_wrapper(self, *args):
self._lock.acquire() self._lock.acquire()
try: try:
@ -212,6 +213,7 @@ class SSLConnectionProxyMeta:
def make_property(property_): def make_property(property_):
"""Create a proxy method for a new class.""" """Create a proxy method for a new class."""
def proxy_prop_wrapper(self): def proxy_prop_wrapper(self):
return getattr(self._ssl_conn, property_) return getattr(self._ssl_conn, property_)
proxy_prop_wrapper.__name__ = property_ proxy_prop_wrapper.__name__ = property_

View file

@ -12,7 +12,10 @@ import pytest
from .._compat import IS_MACOS, IS_WINDOWS # noqa: WPS436 from .._compat import IS_MACOS, IS_WINDOWS # noqa: WPS436
from ..server import Gateway, HTTPServer from ..server import Gateway, HTTPServer
from ..testing import ( # noqa: F401 # pylint: disable=unused-import from ..testing import ( # noqa: F401 # pylint: disable=unused-import
native_server, wsgi_server, native_server,
thread_and_wsgi_server,
thread_and_native_server,
wsgi_server,
) )
from ..testing import get_server_client from ..testing import get_server_client
@ -31,6 +34,28 @@ def http_request_timeout():
return computed_timeout return computed_timeout
@pytest.fixture
# pylint: disable=redefined-outer-name
def wsgi_server_thread(thread_and_wsgi_server): # noqa: F811
"""Set up and tear down a Cheroot WSGI server instance.
This exposes the server thread.
"""
server_thread, _srv = thread_and_wsgi_server
return server_thread
@pytest.fixture
# pylint: disable=redefined-outer-name
def native_server_thread(thread_and_native_server): # noqa: F811
"""Set up and tear down a Cheroot HTTP server instance.
This exposes the server thread.
"""
server_thread, _srv = thread_and_native_server
return server_thread
@pytest.fixture @pytest.fixture
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
def wsgi_server_client(wsgi_server): # noqa: F811 def wsgi_server_client(wsgi_server): # noqa: F811

View file

@ -1,7 +1,9 @@
"""Tests for TCP connection handling, including proper and timely close.""" """Tests for TCP connection handling, including proper and timely close."""
import errno import errno
from re import match as _matches_pattern
import socket import socket
import sys
import time import time
import logging import logging
import traceback as traceback_ import traceback as traceback_
@ -17,6 +19,7 @@ from cheroot._compat import IS_CI, IS_MACOS, IS_PYPY, IS_WINDOWS
import cheroot.server import cheroot.server
IS_PY36 = sys.version_info[:2] == (3, 6)
IS_SLOW_ENV = IS_MACOS or IS_WINDOWS IS_SLOW_ENV = IS_MACOS or IS_WINDOWS
@ -53,7 +56,8 @@ class Controller(helper.Controller):
"'POST' != request.method %r" % "'POST' != request.method %r" %
req.environ['REQUEST_METHOD'], req.environ['REQUEST_METHOD'],
) )
return "thanks for '%s'" % req.environ['wsgi.input'].read() input_contents = req.environ['wsgi.input'].read().decode('utf-8')
return f"thanks for '{input_contents !s}'"
def custom_204(req, resp): def custom_204(req, resp):
"""Render response with status 204.""" """Render response with status 204."""
@ -699,6 +703,275 @@ def test_broken_connection_during_tcp_fin(
assert _close_kernel_socket.exception_leaked is exception_leaks assert _close_kernel_socket.exception_leaked is exception_leaks
def test_broken_connection_during_http_communication_fallback( # noqa: WPS118
monkeypatch,
test_client,
testing_server,
wsgi_server_thread,
):
"""Test that unhandled internal error cascades into shutdown."""
def _raise_connection_reset(*_args, **_kwargs):
raise ConnectionResetError(666)
def _read_request_line(self):
monkeypatch.setattr(self.conn.rfile, 'close', _raise_connection_reset)
monkeypatch.setattr(self.conn.wfile, 'write', _raise_connection_reset)
_raise_connection_reset()
monkeypatch.setattr(
test_client.server_instance.ConnectionClass.RequestHandlerClass,
'read_request_line',
_read_request_line,
)
test_client.get_connection().send(b'GET / HTTP/1.1')
wsgi_server_thread.join() # no extra logs upon server termination
actual_log_entries = testing_server.error_log.calls[:]
testing_server.error_log.calls.clear() # prevent post-test assertions
expected_log_entries = (
(logging.WARNING, r'^socket\.error 666$'),
(
logging.INFO,
'^Got a connection error while handling a connection '
r'from .*:\d{1,5} \(666\)',
),
(
logging.CRITICAL,
r'A fatal exception happened\. Setting the server interrupt flag '
r'to ConnectionResetError\(666,?\) and giving up\.\n\nPlease, '
'report this on the Cheroot tracker at '
r'<https://github\.com/cherrypy/cheroot/issues/new/choose>, '
'providing a full reproducer with as much context and details '
r'as possible\.$',
),
)
assert len(actual_log_entries) == len(expected_log_entries)
for ( # noqa: WPS352
(expected_log_level, expected_msg_regex),
(actual_msg, actual_log_level, _tb),
) in zip(expected_log_entries, actual_log_entries):
assert expected_log_level == actual_log_level
assert _matches_pattern(expected_msg_regex, actual_msg) is not None, (
f'{actual_msg !r} does not match {expected_msg_regex !r}'
)
def test_kb_int_from_http_handler(
test_client,
testing_server,
wsgi_server_thread,
):
"""Test that a keyboard interrupt from HTTP handler causes shutdown."""
def _trigger_kb_intr(_req, _resp):
raise KeyboardInterrupt('simulated test handler keyboard interrupt')
testing_server.wsgi_app.handlers['/kb_intr'] = _trigger_kb_intr
http_conn = test_client.get_connection()
http_conn.putrequest('GET', '/kb_intr', skip_host=True)
http_conn.putheader('Host', http_conn.host)
http_conn.endheaders()
wsgi_server_thread.join() # no extra logs upon server termination
actual_log_entries = testing_server.error_log.calls[:]
testing_server.error_log.calls.clear() # prevent post-test assertions
expected_log_entries = (
(
logging.DEBUG,
'^Got a server shutdown request while handling a connection '
r'from .*:\d{1,5} \(simulated test handler keyboard interrupt\)$',
),
(
logging.DEBUG,
'^Setting the server interrupt flag to KeyboardInterrupt'
r"\('simulated test handler keyboard interrupt',?\)$",
),
(
logging.INFO,
'^Keyboard Interrupt: shutting down$',
),
)
assert len(actual_log_entries) == len(expected_log_entries)
for ( # noqa: WPS352
(expected_log_level, expected_msg_regex),
(actual_msg, actual_log_level, _tb),
) in zip(expected_log_entries, actual_log_entries):
assert expected_log_level == actual_log_level
assert _matches_pattern(expected_msg_regex, actual_msg) is not None, (
f'{actual_msg !r} does not match {expected_msg_regex !r}'
)
@pytest.mark.xfail(
IS_CI and IS_PYPY and IS_PY36 and not IS_SLOW_ENV,
reason='Fails under PyPy 3.6 under Ubuntu 20.04 in CI for unknown reason',
# NOTE: Actually covers any Linux
strict=False,
)
def test_unhandled_exception_in_request_handler(
mocker,
monkeypatch,
test_client,
testing_server,
wsgi_server_thread,
):
"""Ensure worker threads are resilient to in-handler exceptions."""
class SillyMistake(BaseException): # noqa: WPS418, WPS431
"""A simulated crash within an HTTP handler."""
def _trigger_scary_exc(_req, _resp):
raise SillyMistake('simulated unhandled exception 💣 in test handler')
testing_server.wsgi_app.handlers['/scary_exc'] = _trigger_scary_exc
server_connection_close_spy = mocker.spy(
test_client.server_instance.ConnectionClass,
'close',
)
http_conn = test_client.get_connection()
http_conn.putrequest('GET', '/scary_exc', skip_host=True)
http_conn.putheader('Host', http_conn.host)
http_conn.endheaders()
# NOTE: This spy ensure the log entry gets recorded before we're testing
# NOTE: them and before server shutdown, preserving their order and making
# NOTE: the log entry presence non-flaky.
while not server_connection_close_spy.called: # noqa: WPS328
pass
assert len(testing_server.requests._threads) == 10
while testing_server.requests.idle < 10: # noqa: WPS328
pass
assert len(testing_server.requests._threads) == 10
testing_server.interrupt = SystemExit('test requesting shutdown')
assert not testing_server.requests._threads
wsgi_server_thread.join() # no extra logs upon server termination
actual_log_entries = testing_server.error_log.calls[:]
testing_server.error_log.calls.clear() # prevent post-test assertions
expected_log_entries = (
(
logging.ERROR,
'^Unhandled error while processing an incoming connection '
'SillyMistake'
r"\('simulated unhandled exception 💣 in test handler',?\)$",
),
(
logging.INFO,
'^SystemExit raised: shutting down$',
),
)
assert len(actual_log_entries) == len(expected_log_entries)
for ( # noqa: WPS352
(expected_log_level, expected_msg_regex),
(actual_msg, actual_log_level, _tb),
) in zip(expected_log_entries, actual_log_entries):
assert expected_log_level == actual_log_level
assert _matches_pattern(expected_msg_regex, actual_msg) is not None, (
f'{actual_msg !r} does not match {expected_msg_regex !r}'
)
@pytest.mark.xfail(
IS_CI and IS_PYPY and IS_PY36 and not IS_SLOW_ENV,
reason='Fails under PyPy 3.6 under Ubuntu 20.04 in CI for unknown reason',
# NOTE: Actually covers any Linux
strict=False,
)
def test_remains_alive_post_unhandled_exception(
mocker,
monkeypatch,
test_client,
testing_server,
wsgi_server_thread,
):
"""Ensure worker threads are resilient to unhandled exceptions."""
class ScaryCrash(BaseException): # noqa: WPS418, WPS431
"""A simulated crash during HTTP parsing."""
_orig_read_request_line = (
test_client.server_instance.
ConnectionClass.RequestHandlerClass.
read_request_line
)
def _read_request_line(self):
_orig_read_request_line(self)
raise ScaryCrash(666)
monkeypatch.setattr(
test_client.server_instance.ConnectionClass.RequestHandlerClass,
'read_request_line',
_read_request_line,
)
server_connection_close_spy = mocker.spy(
test_client.server_instance.ConnectionClass,
'close',
)
# NOTE: The initial worker thread count is 10.
assert len(testing_server.requests._threads) == 10
test_client.get_connection().send(b'GET / HTTP/1.1')
# NOTE: This spy ensure the log entry gets recorded before we're testing
# NOTE: them and before server shutdown, preserving their order and making
# NOTE: the log entry presence non-flaky.
while not server_connection_close_spy.called: # noqa: WPS328
pass
# NOTE: This checks for whether there's any crashed threads
while testing_server.requests.idle < 10: # noqa: WPS328
pass
assert len(testing_server.requests._threads) == 10
assert all(
worker_thread.is_alive()
for worker_thread in testing_server.requests._threads
)
testing_server.interrupt = SystemExit('test requesting shutdown')
assert not testing_server.requests._threads
wsgi_server_thread.join() # no extra logs upon server termination
actual_log_entries = testing_server.error_log.calls[:]
testing_server.error_log.calls.clear() # prevent post-test assertions
expected_log_entries = (
(
logging.ERROR,
'^Unhandled error while processing an incoming connection '
r'ScaryCrash\(666,?\)$',
),
(
logging.INFO,
'^SystemExit raised: shutting down$',
),
)
assert len(actual_log_entries) == len(expected_log_entries)
for ( # noqa: WPS352
(expected_log_level, expected_msg_regex),
(actual_msg, actual_log_level, _tb),
) in zip(expected_log_entries, actual_log_entries):
assert expected_log_level == actual_log_level
assert _matches_pattern(expected_msg_regex, actual_msg) is not None, (
f'{actual_msg !r} does not match {expected_msg_regex !r}'
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
'timeout_before_headers', 'timeout_before_headers',
( (
@ -917,7 +1190,7 @@ def test_100_Continue(test_client):
status_line, _actual_headers, actual_resp_body = webtest.shb(response) status_line, _actual_headers, actual_resp_body = webtest.shb(response)
actual_status = int(status_line[:3]) actual_status = int(status_line[:3])
assert actual_status == 200 assert actual_status == 200
expected_resp_body = ("thanks for '%s'" % body).encode() expected_resp_body = f"thanks for '{body.decode() !s}'".encode()
assert actual_resp_body == expected_resp_body assert actual_resp_body == expected_resp_body
conn.close() conn.close()
@ -987,7 +1260,7 @@ def test_readall_or_close(test_client, max_request_body_size):
status_line, actual_headers, actual_resp_body = webtest.shb(response) status_line, actual_headers, actual_resp_body = webtest.shb(response)
actual_status = int(status_line[:3]) actual_status = int(status_line[:3])
assert actual_status == 200 assert actual_status == 200
expected_resp_body = ("thanks for '%s'" % body).encode() expected_resp_body = f"thanks for '{body.decode() !s}'".encode()
assert actual_resp_body == expected_resp_body assert actual_resp_body == expected_resp_body
conn.close() conn.close()

View file

@ -134,7 +134,7 @@ def test_query_string_request(test_client):
'/hello', # plain '/hello', # plain
'/query_string?test=True', # query '/query_string?test=True', # query
'/{0}?{1}={2}'.format( # quoted unicode '/{0}?{1}={2}'.format( # quoted unicode
*map(urllib.parse.quote, ('Юххууу', 'ї', 'йо')) *map(urllib.parse.quote, ('Юххууу', 'ї', 'йо')),
), ),
), ),
) )

View file

@ -31,7 +31,7 @@ config = {
@contextmanager @contextmanager
def cheroot_server(server_factory): def cheroot_server(server_factory): # noqa: WPS210
"""Set up and tear down a Cheroot server instance.""" """Set up and tear down a Cheroot server instance."""
conf = config[server_factory].copy() conf = config[server_factory].copy()
bind_port = conf.pop('bind_addr')[-1] bind_port = conf.pop('bind_addr')[-1]
@ -41,7 +41,7 @@ def cheroot_server(server_factory):
actual_bind_addr = (interface, bind_port) actual_bind_addr = (interface, bind_port)
httpserver = server_factory( # create it httpserver = server_factory( # create it
bind_addr=actual_bind_addr, bind_addr=actual_bind_addr,
**conf **conf,
) )
except OSError: except OSError:
pass pass
@ -50,27 +50,52 @@ def cheroot_server(server_factory):
httpserver.shutdown_timeout = 0 # Speed-up tests teardown httpserver.shutdown_timeout = 0 # Speed-up tests teardown
threading.Thread(target=httpserver.safe_start).start() # spawn it # FIXME: Expose this thread through a fixture so that it
# FIXME: could be awaited in tests.
server_thread = threading.Thread(target=httpserver.safe_start)
server_thread.start() # spawn it
while not httpserver.ready: # wait until fully initialized and bound while not httpserver.ready: # wait until fully initialized and bound
time.sleep(0.1) time.sleep(0.1)
yield httpserver try:
yield server_thread, httpserver
finally:
httpserver.stop() # destroy it httpserver.stop() # destroy it
server_thread.join() # wait for the thread to be turn down
@pytest.fixture @pytest.fixture
def wsgi_server(): def thread_and_wsgi_server():
"""Set up and tear down a Cheroot WSGI server instance.
This emits a tuple of a thread and a server instance.
"""
with cheroot_server(cheroot.wsgi.Server) as (server_thread, srv):
yield server_thread, srv
@pytest.fixture
def thread_and_native_server():
"""Set up and tear down a Cheroot HTTP server instance.
This emits a tuple of a thread and a server instance.
"""
with cheroot_server(cheroot.server.HTTPServer) as (server_thread, srv):
yield server_thread, srv
@pytest.fixture
def wsgi_server(thread_and_wsgi_server): # noqa: WPS442
"""Set up and tear down a Cheroot WSGI server instance.""" """Set up and tear down a Cheroot WSGI server instance."""
with cheroot_server(cheroot.wsgi.Server) as srv: _server_thread, srv = thread_and_wsgi_server
yield srv return srv
@pytest.fixture @pytest.fixture
def native_server(): def native_server(thread_and_native_server): # noqa: WPS442
"""Set up and tear down a Cheroot HTTP server instance.""" """Set up and tear down a Cheroot HTTP server instance."""
with cheroot_server(cheroot.server.HTTPServer) as srv: _server_thread, srv = thread_and_native_server
yield srv return srv
class _TestClient: class _TestClient:

View file

@ -6,6 +6,7 @@
""" """
import collections import collections
import logging
import threading import threading
import time import time
import socket import socket
@ -30,7 +31,7 @@ class TrueyZero:
trueyzero = TrueyZero() trueyzero = TrueyZero()
_SHUTDOWNREQUEST = None _SHUTDOWNREQUEST = object()
class WorkerThread(threading.Thread): class WorkerThread(threading.Thread):
@ -99,13 +100,58 @@ class WorkerThread(threading.Thread):
threading.Thread.__init__(self) threading.Thread.__init__(self)
def run(self): def run(self):
"""Process incoming HTTP connections. """Set up incoming HTTP connection processing loop.
Retrieves incoming connections from thread pool. This is the thread's entry-point. It performs lop-layer
exception handling and interrupt processing.
:exc:`KeyboardInterrupt` and :exc:`SystemExit` bubbling up
from the inner-layer code constitute a global server interrupt
request. When they happen, the worker thread exits.
:raises BaseException: when an unexpected non-interrupt
exception leaks from the inner layers
# noqa: DAR401 KeyboardInterrupt SystemExit
""" """
self.server.stats['Worker Threads'][self.name] = self.stats self.server.stats['Worker Threads'][self.name] = self.stats
try:
self.ready = True self.ready = True
try:
self._process_connections_until_interrupted()
except (KeyboardInterrupt, SystemExit) as interrupt_exc:
interrupt_cause = interrupt_exc.__cause__ or interrupt_exc
self.server.error_log(
f'Setting the server interrupt flag to {interrupt_cause !r}',
level=logging.DEBUG,
)
self.server.interrupt = interrupt_cause
except BaseException as underlying_exc: # noqa: WPS424
# NOTE: This is the last resort logging with the last dying breath
# NOTE: of the worker. It is only reachable when exceptions happen
# NOTE: in the `finally` branch of the internal try/except block.
self.server.error_log(
'A fatal exception happened. Setting the server interrupt flag'
f' to {underlying_exc !r} and giving up.'
'\N{NEW LINE}\N{NEW LINE}'
'Please, report this on the Cheroot tracker at '
'<https://github.com/cherrypy/cheroot/issues/new/choose>, '
'providing a full reproducer with as much context and details as possible.',
level=logging.CRITICAL,
traceback=True,
)
self.server.interrupt = underlying_exc
raise
finally:
self.ready = False
def _process_connections_until_interrupted(self):
"""Process incoming HTTP connections in an infinite loop.
Retrieves incoming connections from thread pool, processing
them one by one.
:raises SystemExit: on the internal requests to stop the
server instance
"""
while True: while True:
conn = self.server.requests.get() conn = self.server.requests.get()
if conn is _SHUTDOWNREQUEST: if conn is _SHUTDOWNREQUEST:
@ -118,20 +164,63 @@ class WorkerThread(threading.Thread):
keep_conn_open = False keep_conn_open = False
try: try:
keep_conn_open = conn.communicate() keep_conn_open = conn.communicate()
except ConnectionError as connection_error:
keep_conn_open = False # Drop the connection cleanly
self.server.error_log(
'Got a connection error while handling a '
f'connection from {conn.remote_addr !s}:'
f'{conn.remote_port !s} ({connection_error !s})',
level=logging.INFO,
)
continue
except (KeyboardInterrupt, SystemExit) as shutdown_request:
# Shutdown request
keep_conn_open = False # Drop the connection cleanly
self.server.error_log(
'Got a server shutdown request while handling a '
f'connection from {conn.remote_addr !s}:'
f'{conn.remote_port !s} ({shutdown_request !s})',
level=logging.DEBUG,
)
raise SystemExit(
str(shutdown_request),
) from shutdown_request
except BaseException as unhandled_error: # noqa: WPS424
# NOTE: Only a shutdown request should bubble up to the
# NOTE: external cleanup code. Otherwise, this thread dies.
# NOTE: If this were to happen, the threadpool would still
# NOTE: list a dead thread without knowing its state. And
# NOTE: the calling code would fail to schedule processing
# NOTE: of new requests.
self.server.error_log(
'Unhandled error while processing an incoming '
f'connection {unhandled_error !r}',
level=logging.ERROR,
traceback=True,
)
continue # Prevent the thread from dying
finally: finally:
# NOTE: Any exceptions coming from within `finally` may
# NOTE: kill the thread, causing the threadpool to only
# NOTE: contain references to dead threads rendering the
# NOTE: server defunct, effectively meaning a DoS.
# NOTE: Ideally, things called here should process
# NOTE: everything recoverable internally. Any unhandled
# NOTE: errors will bubble up into the outer try/except
# NOTE: block. They will be treated as fatal and turned
# NOTE: into server shutdown requests and then reraised
# NOTE: unconditionally.
if keep_conn_open: if keep_conn_open:
self.server.put_conn(conn) self.server.put_conn(conn)
else: else:
conn.close() conn.close()
if is_stats_enabled: if is_stats_enabled:
self.requests_seen += self.conn.requests_seen self.requests_seen += conn.requests_seen
self.bytes_read += self.conn.rfile.bytes_read self.bytes_read += conn.rfile.bytes_read
self.bytes_written += self.conn.wfile.bytes_written self.bytes_written += conn.wfile.bytes_written
self.work_time += time.time() - self.start_time self.work_time += time.time() - self.start_time
self.start_time = None self.start_time = None
self.conn = None self.conn = None
except (KeyboardInterrupt, SystemExit) as ex:
self.server.interrupt = ex
class ThreadPool: class ThreadPool:

View file

@ -52,7 +52,7 @@ Automatic conversion
-------------------- --------------------
An included script called `futurize An included script called `futurize
<http://python-future.org/automatic_conversion.html>`_ aids in converting <https://python-future.org/automatic_conversion.html>`_ aids in converting
code (from either Python 2 or Python 3) to code compatible with both code (from either Python 2 or Python 3) to code compatible with both
platforms. It is similar to ``python-modernize`` but goes further in platforms. It is similar to ``python-modernize`` but goes further in
providing Python 3 compatibility through the use of the backported types providing Python 3 compatibility through the use of the backported types
@ -62,21 +62,20 @@ and builtin functions in ``future``.
Documentation Documentation
------------- -------------
See: http://python-future.org See: https://python-future.org
Credits Credits
------- -------
:Author: Ed Schofield, Jordan M. Adler, et al :Author: Ed Schofield, Jordan M. Adler, et al
:Sponsor: Python Charmers Pty Ltd, Australia, and Python Charmers Pte :Sponsor: Python Charmers: https://pythoncharmers.com
Ltd, Singapore. http://pythoncharmers.com :Others: See docs/credits.rst or https://python-future.org/credits.html
:Others: See docs/credits.rst or http://python-future.org/credits.html
Licensing Licensing
--------- ---------
Copyright 2013-2019 Python Charmers Pty Ltd, Australia. Copyright 2013-2024 Python Charmers, Australia.
The software is distributed under an MIT licence. See LICENSE.txt. The software is distributed under an MIT licence. See LICENSE.txt.
""" """
@ -84,10 +83,10 @@ The software is distributed under an MIT licence. See LICENSE.txt.
__title__ = 'future' __title__ = 'future'
__author__ = 'Ed Schofield' __author__ = 'Ed Schofield'
__license__ = 'MIT' __license__ = 'MIT'
__copyright__ = 'Copyright 2013-2019 Python Charmers Pty Ltd' __copyright__ = 'Copyright 2013-2024 Python Charmers (https://pythoncharmers.com)'
__ver_major__ = 0 __ver_major__ = 1
__ver_minor__ = 18 __ver_minor__ = 0
__ver_patch__ = 3 __ver_patch__ = 0
__ver_sub__ = '' __ver_sub__ = ''
__version__ = "%d.%d.%d%s" % (__ver_major__, __ver_minor__, __version__ = "%d.%d.%d%s" % (__ver_major__, __ver_minor__,
__ver_patch__, __ver_sub__) __ver_patch__, __ver_sub__)

View file

@ -689,7 +689,7 @@ class date(object):
@classmethod @classmethod
def fromordinal(cls, n): def fromordinal(cls, n):
"""Contruct a date from a proleptic Gregorian ordinal. """Construct a date from a proleptic Gregorian ordinal.
January 1 of year 1 is day 1. Only the year, month and day are January 1 of year 1 is day 1. Only the year, month and day are
non-zero in the result. non-zero in the result.

View file

@ -2867,7 +2867,7 @@ def parse_content_type_header(value):
_find_mime_parameters(ctype, value) _find_mime_parameters(ctype, value)
return ctype return ctype
ctype.append(token) ctype.append(token)
# XXX: If we really want to follow the formal grammer we should make # XXX: If we really want to follow the formal grammar we should make
# mantype and subtype specialized TokenLists here. Probably not worth it. # mantype and subtype specialized TokenLists here. Probably not worth it.
if not value or value[0] != '/': if not value or value[0] != '/':
ctype.defects.append(errors.InvalidHeaderDefect( ctype.defects.append(errors.InvalidHeaderDefect(

View file

@ -26,7 +26,7 @@ class Parser(object):
textual representation of the message. textual representation of the message.
The string must be formatted as a block of RFC 2822 headers and header The string must be formatted as a block of RFC 2822 headers and header
continuation lines, optionally preceeded by a `Unix-from' header. The continuation lines, optionally preceded by a `Unix-from' header. The
header block is terminated either by the end of the string or by a header block is terminated either by the end of the string or by a
blank line. blank line.
@ -92,7 +92,7 @@ class BytesParser(object):
textual representation of the message. textual representation of the message.
The input must be formatted as a block of RFC 2822 headers and header The input must be formatted as a block of RFC 2822 headers and header
continuation lines, optionally preceeded by a `Unix-from' header. The continuation lines, optionally preceded by a `Unix-from' header. The
header block is terminated either by the end of the input or by a header block is terminated either by the end of the input or by a
blank line. blank line.

View file

@ -1851,7 +1851,7 @@ def lwp_cookie_str(cookie):
class LWPCookieJar(FileCookieJar): class LWPCookieJar(FileCookieJar):
""" """
The LWPCookieJar saves a sequence of "Set-Cookie3" lines. The LWPCookieJar saves a sequence of "Set-Cookie3" lines.
"Set-Cookie3" is the format used by the libwww-perl libary, not known "Set-Cookie3" is the format used by the libwww-perl library, not known
to be compatible with any browser, but which is easy to read and to be compatible with any browser, but which is easy to read and
doesn't lose information about RFC 2965 cookies. doesn't lose information about RFC 2965 cookies.

View file

@ -28,7 +28,6 @@ import importlib
# import collections.abc # not present on Py2.7 # import collections.abc # not present on Py2.7
import re import re
import subprocess import subprocess
import imp
import time import time
try: try:
import sysconfig import sysconfig
@ -341,37 +340,6 @@ def rmtree(path):
if error.errno != errno.ENOENT: if error.errno != errno.ENOENT:
raise raise
def make_legacy_pyc(source):
"""Move a PEP 3147 pyc/pyo file to its legacy pyc/pyo location.
The choice of .pyc or .pyo extension is done based on the __debug__ flag
value.
:param source: The file system path to the source file. The source file
does not need to exist, however the PEP 3147 pyc file must exist.
:return: The file system path to the legacy pyc file.
"""
pyc_file = imp.cache_from_source(source)
up_one = os.path.dirname(os.path.abspath(source))
legacy_pyc = os.path.join(up_one, source + ('c' if __debug__ else 'o'))
os.rename(pyc_file, legacy_pyc)
return legacy_pyc
def forget(modname):
"""'Forget' a module was ever imported.
This removes the module from sys.modules and deletes any PEP 3147 or
legacy .pyc and .pyo files.
"""
unload(modname)
for dirname in sys.path:
source = os.path.join(dirname, modname + '.py')
# It doesn't matter if they exist or not, unlink all possible
# combinations of PEP 3147 and legacy pyc and pyo files.
unlink(source + 'c')
unlink(source + 'o')
unlink(imp.cache_from_source(source, debug_override=True))
unlink(imp.cache_from_source(source, debug_override=False))
# On some platforms, should not run gui test even if it is allowed # On some platforms, should not run gui test even if it is allowed
# in `use_resources'. # in `use_resources'.

View file

@ -134,10 +134,11 @@ from __future__ import (absolute_import, division, print_function,
from future.builtins import bytes, dict, int, range, str from future.builtins import bytes, dict, int, range, str
import base64 import base64
import sys
if sys.version_info < (3, 9):
# Py2.7 compatibility hack # Py2.7 compatibility hack
base64.encodebytes = base64.encodestring base64.encodebytes = base64.encodestring
base64.decodebytes = base64.decodestring base64.decodebytes = base64.decodestring
import sys
import time import time
from datetime import datetime from datetime import datetime
from future.backports.http import client as http_client from future.backports.http import client as http_client
@ -1251,7 +1252,7 @@ class Transport(object):
# Send HTTP request. # Send HTTP request.
# #
# @param host Host descriptor (URL or (URL, x509 info) tuple). # @param host Host descriptor (URL or (URL, x509 info) tuple).
# @param handler Targer RPC handler (a path relative to host) # @param handler Target RPC handler (a path relative to host)
# @param request_body The XML-RPC request body # @param request_body The XML-RPC request body
# @param debug Enable debugging if debug is true. # @param debug Enable debugging if debug is true.
# @return An HTTPConnection. # @return An HTTPConnection.

View file

@ -2,7 +2,7 @@
A module that brings in equivalents of the new and modified Python 3 A module that brings in equivalents of the new and modified Python 3
builtins into Py2. Has no effect on Py3. builtins into Py2. Has no effect on Py3.
See the docs `here <http://python-future.org/what-else.html>`_ See the docs `here <https://python-future.org/what-else.html>`_
(``docs/what-else.rst``) for more information. (``docs/what-else.rst``) for more information.
""" """

View file

@ -1,7 +1,12 @@
from __future__ import absolute_import from __future__ import absolute_import
from future.utils import PY3 from future.utils import PY3, PY39_PLUS
if PY3:
if PY39_PLUS:
# _dummy_thread and dummy_threading modules were both deprecated in
# Python 3.7 and removed in Python 3.9
from _thread import *
elif PY3:
from _dummy_thread import * from _dummy_thread import *
else: else:
__future_module__ = True __future_module__ = True

View file

@ -0,0 +1,7 @@
from __future__ import absolute_import
from future.utils import PY3
from multiprocessing import *
if not PY3:
__future_module__ = True
from multiprocessing.queues import SimpleQueue

View file

@ -1,9 +1,18 @@
from __future__ import absolute_import from __future__ import absolute_import
import sys
from future.standard_library import suspend_hooks from future.standard_library import suspend_hooks
from future.utils import PY3 from future.utils import PY3
if PY3: if PY3:
from test.support import * from test.support import *
if sys.version_info[:2] >= (3, 10):
from test.support.os_helper import (
EnvironmentVarGuard,
TESTFN,
)
from test.support.warnings_helper import check_warnings
else: else:
__future_module__ = True __future_module__ = True
with suspend_hooks(): with suspend_hooks():

View file

@ -17,7 +17,7 @@ And then these normal Py3 imports work on both Py3 and Py2::
import socketserver import socketserver
import winreg # on Windows only import winreg # on Windows only
import test.support import test.support
import html, html.parser, html.entites import html, html.parser, html.entities
import http, http.client, http.server import http, http.client, http.server
import http.cookies, http.cookiejar import http.cookies, http.cookiejar
import urllib.parse, urllib.request, urllib.response, urllib.error, urllib.robotparser import urllib.parse, urllib.request, urllib.response, urllib.error, urllib.robotparser
@ -33,6 +33,7 @@ And then these normal Py3 imports work on both Py3 and Py2::
from collections import OrderedDict, Counter, ChainMap # even on Py2.6 from collections import OrderedDict, Counter, ChainMap # even on Py2.6
from subprocess import getoutput, getstatusoutput from subprocess import getoutput, getstatusoutput
from subprocess import check_output # even on Py2.6 from subprocess import check_output # even on Py2.6
from multiprocessing import SimpleQueue
(The renamed modules and functions are still available under their old (The renamed modules and functions are still available under their old
names on Python 2.) names on Python 2.)
@ -62,9 +63,12 @@ from __future__ import absolute_import, division, print_function
import sys import sys
import logging import logging
# imp was deprecated in python 3.6
if sys.version_info >= (3, 6):
import importlib as imp
else:
import imp import imp
import contextlib import contextlib
import types
import copy import copy
import os import os
@ -108,6 +112,7 @@ RENAMES = {
'future.moves.socketserver': 'socketserver', 'future.moves.socketserver': 'socketserver',
'ConfigParser': 'configparser', 'ConfigParser': 'configparser',
'repr': 'reprlib', 'repr': 'reprlib',
'multiprocessing.queues': 'multiprocessing',
# 'FileDialog': 'tkinter.filedialog', # 'FileDialog': 'tkinter.filedialog',
# 'tkFileDialog': 'tkinter.filedialog', # 'tkFileDialog': 'tkinter.filedialog',
# 'SimpleDialog': 'tkinter.simpledialog', # 'SimpleDialog': 'tkinter.simpledialog',
@ -125,7 +130,7 @@ RENAMES = {
# 'Tkinter': 'tkinter', # 'Tkinter': 'tkinter',
'_winreg': 'winreg', '_winreg': 'winreg',
'thread': '_thread', 'thread': '_thread',
'dummy_thread': '_dummy_thread', 'dummy_thread': '_dummy_thread' if sys.version_info < (3, 9) else '_thread',
# 'anydbm': 'dbm', # causes infinite import loop # 'anydbm': 'dbm', # causes infinite import loop
# 'whichdb': 'dbm', # causes infinite import loop # 'whichdb': 'dbm', # causes infinite import loop
# anydbm and whichdb are handled by fix_imports2 # anydbm and whichdb are handled by fix_imports2
@ -184,6 +189,7 @@ MOVES = [('collections', 'UserList', 'UserList', 'UserList'),
('itertools', 'filterfalse','itertools', 'ifilterfalse'), ('itertools', 'filterfalse','itertools', 'ifilterfalse'),
('itertools', 'zip_longest','itertools', 'izip_longest'), ('itertools', 'zip_longest','itertools', 'izip_longest'),
('sys', 'intern','__builtin__', 'intern'), ('sys', 'intern','__builtin__', 'intern'),
('multiprocessing', 'SimpleQueue', 'multiprocessing.queues', 'SimpleQueue'),
# The re module has no ASCII flag in Py2, but this is the default. # The re module has no ASCII flag in Py2, but this is the default.
# Set re.ASCII to a zero constant. stat.ST_MODE just happens to be one # Set re.ASCII to a zero constant. stat.ST_MODE just happens to be one
# (and it exists on Py2.6+). # (and it exists on Py2.6+).

View file

@ -223,9 +223,11 @@ class newint(with_metaclass(BaseNewInt, long)):
def __rpow__(self, other): def __rpow__(self, other):
value = super(newint, self).__rpow__(other) value = super(newint, self).__rpow__(other)
if value is NotImplemented: if isint(value):
return other ** long(self)
return newint(value) return newint(value)
elif value is NotImplemented:
return other ** long(self)
return value
def __lshift__(self, other): def __lshift__(self, other):
if not isint(other): if not isint(other):
@ -318,7 +320,7 @@ class newint(with_metaclass(BaseNewInt, long)):
bits = length * 8 bits = length * 8
num = (2**bits) + self num = (2**bits) + self
if num <= 0: if num <= 0:
raise OverflowError("int too smal to convert") raise OverflowError("int too small to convert")
else: else:
if self < 0: if self < 0:
raise OverflowError("can't convert negative int to unsigned") raise OverflowError("can't convert negative int to unsigned")

View file

@ -105,7 +105,7 @@ class newrange(Sequence):
raise ValueError('%r is not in range' % value) raise ValueError('%r is not in range' % value)
def count(self, value): def count(self, value):
"""Return the number of ocurrences of integer `value` """Return the number of occurrences of integer `value`
in the sequence this range represents.""" in the sequence this range represents."""
# a value can occur exactly zero or one times # a value can occur exactly zero or one times
return int(value in self) return int(value in self)

View file

@ -3,6 +3,8 @@ inflect: english language inflection
- correctly generate plurals, ordinals, indefinite articles - correctly generate plurals, ordinals, indefinite articles
- convert numbers to words - convert numbers to words
Copyright (C) 2010 Paul Dyson
Based upon the Perl module Based upon the Perl module
`Lingua::EN::Inflect <https://metacpan.org/pod/Lingua::EN::Inflect>`_. `Lingua::EN::Inflect <https://metacpan.org/pod/Lingua::EN::Inflect>`_.
@ -50,34 +52,33 @@ Exceptions:
""" """
from __future__ import annotations
import ast import ast
import re
import functools
import collections import collections
import contextlib import contextlib
import functools
import itertools
import re
from numbers import Number
from typing import ( from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict, Dict,
Union,
Optional,
Iterable, Iterable,
List, List,
Match, Match,
Tuple, Optional,
Callable,
Sequence, Sequence,
Tuple,
Union,
cast, cast,
Any,
) )
from typing_extensions import Literal
from numbers import Number
from more_itertools import windowed_complete
from pydantic import Field from typeguard import typechecked
from typing_extensions import Annotated from typing_extensions import Annotated, Literal
from .compat.pydantic1 import validate_call
from .compat.pydantic import same_method
class UnknownClassicalModeError(Exception): class UnknownClassicalModeError(Exception):
@ -258,9 +259,9 @@ si_sb_irregular_compound = {v: k for (k, v) in pl_sb_irregular_compound.items()}
for k in list(si_sb_irregular_compound): for k in list(si_sb_irregular_compound):
if "|" in k: if "|" in k:
k1, k2 = k.split("|") k1, k2 = k.split("|")
si_sb_irregular_compound[k1] = si_sb_irregular_compound[ si_sb_irregular_compound[k1] = si_sb_irregular_compound[k2] = (
k2 si_sb_irregular_compound[k]
] = si_sb_irregular_compound[k] )
del si_sb_irregular_compound[k] del si_sb_irregular_compound[k]
# si_sb_irregular_keys = enclose('|'.join(si_sb_irregular.keys())) # si_sb_irregular_keys = enclose('|'.join(si_sb_irregular.keys()))
@ -1597,7 +1598,7 @@ pl_prep_bysize = bysize(pl_prep_list_da)
pl_prep = enclose("|".join(pl_prep_list_da)) pl_prep = enclose("|".join(pl_prep_list_da))
pl_sb_prep_dual_compound = fr"(.*?)((?:-|\s+)(?:{pl_prep})(?:-|\s+))a(?:-|\s+)(.*)" pl_sb_prep_dual_compound = rf"(.*?)((?:-|\s+)(?:{pl_prep})(?:-|\s+))a(?:-|\s+)(.*)"
singular_pronoun_genders = { singular_pronoun_genders = {
@ -1764,7 +1765,7 @@ plverb_ambiguous_pres = {
} }
plverb_ambiguous_pres_keys = re.compile( plverb_ambiguous_pres_keys = re.compile(
fr"^({enclose('|'.join(plverb_ambiguous_pres))})((\s.*)?)$", re.IGNORECASE rf"^({enclose('|'.join(plverb_ambiguous_pres))})((\s.*)?)$", re.IGNORECASE
) )
@ -1804,7 +1805,7 @@ pl_count_one = ("1", "a", "an", "one", "each", "every", "this", "that")
pl_adj_special = {"a": "some", "an": "some", "this": "these", "that": "those"} pl_adj_special = {"a": "some", "an": "some", "this": "these", "that": "those"}
pl_adj_special_keys = re.compile( pl_adj_special_keys = re.compile(
fr"^({enclose('|'.join(pl_adj_special))})$", re.IGNORECASE rf"^({enclose('|'.join(pl_adj_special))})$", re.IGNORECASE
) )
pl_adj_poss = { pl_adj_poss = {
@ -1816,7 +1817,7 @@ pl_adj_poss = {
"their": "their", "their": "their",
} }
pl_adj_poss_keys = re.compile(fr"^({enclose('|'.join(pl_adj_poss))})$", re.IGNORECASE) pl_adj_poss_keys = re.compile(rf"^({enclose('|'.join(pl_adj_poss))})$", re.IGNORECASE)
# 2. INDEFINITE ARTICLES # 2. INDEFINITE ARTICLES
@ -1883,7 +1884,7 @@ ordinal = dict(
twelve="twelfth", twelve="twelfth",
) )
ordinal_suff = re.compile(fr"({'|'.join(ordinal)})\Z") ordinal_suff = re.compile(rf"({'|'.join(ordinal)})\Z")
# NUMBERS # NUMBERS
@ -1948,13 +1949,13 @@ DOLLAR_DIGITS = re.compile(r"\$(\d+)")
FUNCTION_CALL = re.compile(r"((\w+)\([^)]*\)*)", re.IGNORECASE) FUNCTION_CALL = re.compile(r"((\w+)\([^)]*\)*)", re.IGNORECASE)
PARTITION_WORD = re.compile(r"\A(\s*)(.+?)(\s*)\Z") PARTITION_WORD = re.compile(r"\A(\s*)(.+?)(\s*)\Z")
PL_SB_POSTFIX_ADJ_STEMS_RE = re.compile( PL_SB_POSTFIX_ADJ_STEMS_RE = re.compile(
fr"^(?:{pl_sb_postfix_adj_stems})$", re.IGNORECASE rf"^(?:{pl_sb_postfix_adj_stems})$", re.IGNORECASE
) )
PL_SB_PREP_DUAL_COMPOUND_RE = re.compile( PL_SB_PREP_DUAL_COMPOUND_RE = re.compile(
fr"^(?:{pl_sb_prep_dual_compound})$", re.IGNORECASE rf"^(?:{pl_sb_prep_dual_compound})$", re.IGNORECASE
) )
DENOMINATOR = re.compile(r"(?P<denominator>.+)( (per|a) .+)") DENOMINATOR = re.compile(r"(?P<denominator>.+)( (per|a) .+)")
PLVERB_SPECIAL_S_RE = re.compile(fr"^({plverb_special_s})$") PLVERB_SPECIAL_S_RE = re.compile(rf"^({plverb_special_s})$")
WHITESPACE = re.compile(r"\s") WHITESPACE = re.compile(r"\s")
ENDS_WITH_S = re.compile(r"^(.*[^s])s$", re.IGNORECASE) ENDS_WITH_S = re.compile(r"^(.*[^s])s$", re.IGNORECASE)
ENDS_WITH_APOSTROPHE_S = re.compile(r"^(.*)'s?$") ENDS_WITH_APOSTROPHE_S = re.compile(r"^(.*)'s?$")
@ -2020,10 +2021,25 @@ class Words(str):
self.last = self.split_[-1] self.last = self.split_[-1]
Word = Annotated[str, Field(min_length=1)]
Falsish = Any # ideally, falsish would only validate on bool(value) is False Falsish = Any # ideally, falsish would only validate on bool(value) is False
_STATIC_TYPE_CHECKING = TYPE_CHECKING
# ^-- Workaround for typeguard AST manipulation:
# https://github.com/agronholm/typeguard/issues/353#issuecomment-1556306554
if _STATIC_TYPE_CHECKING: # pragma: no cover
Word = Annotated[str, "String with at least 1 character"]
else:
class _WordMeta(type): # Too dynamic to be supported by mypy...
def __instancecheck__(self, instance: Any) -> bool:
return isinstance(instance, str) and len(instance) >= 1
class Word(metaclass=_WordMeta): # type: ignore[no-redef]
"""String with at least 1 character"""
class engine: class engine:
def __init__(self) -> None: def __init__(self) -> None:
self.classical_dict = def_classical.copy() self.classical_dict = def_classical.copy()
@ -2045,7 +2061,7 @@ class engine:
def _number_args(self, val): def _number_args(self, val):
self.__number_args = val self.__number_args = val
@validate_call @typechecked
def defnoun(self, singular: Optional[Word], plural: Optional[Word]) -> int: def defnoun(self, singular: Optional[Word], plural: Optional[Word]) -> int:
""" """
Set the noun plural of singular to plural. Set the noun plural of singular to plural.
@ -2057,7 +2073,7 @@ class engine:
self.si_sb_user_defined.extend((plural, singular)) self.si_sb_user_defined.extend((plural, singular))
return 1 return 1
@validate_call @typechecked
def defverb( def defverb(
self, self,
s1: Optional[Word], s1: Optional[Word],
@ -2082,7 +2098,7 @@ class engine:
self.pl_v_user_defined.extend((s1, p1, s2, p2, s3, p3)) self.pl_v_user_defined.extend((s1, p1, s2, p2, s3, p3))
return 1 return 1
@validate_call @typechecked
def defadj(self, singular: Optional[Word], plural: Optional[Word]) -> int: def defadj(self, singular: Optional[Word], plural: Optional[Word]) -> int:
""" """
Set the adjective plural of singular to plural. Set the adjective plural of singular to plural.
@ -2093,7 +2109,7 @@ class engine:
self.pl_adj_user_defined.extend((singular, plural)) self.pl_adj_user_defined.extend((singular, plural))
return 1 return 1
@validate_call @typechecked
def defa(self, pattern: Optional[Word]) -> int: def defa(self, pattern: Optional[Word]) -> int:
""" """
Define the indefinite article as 'a' for words matching pattern. Define the indefinite article as 'a' for words matching pattern.
@ -2103,7 +2119,7 @@ class engine:
self.A_a_user_defined.extend((pattern, "a")) self.A_a_user_defined.extend((pattern, "a"))
return 1 return 1
@validate_call @typechecked
def defan(self, pattern: Optional[Word]) -> int: def defan(self, pattern: Optional[Word]) -> int:
""" """
Define the indefinite article as 'an' for words matching pattern. Define the indefinite article as 'an' for words matching pattern.
@ -2121,8 +2137,8 @@ class engine:
return return
try: try:
re.match(pattern, "") re.match(pattern, "")
except re.error: except re.error as err:
raise BadUserDefinedPatternError(pattern) raise BadUserDefinedPatternError(pattern) from err
def checkpatplural(self, pattern: Optional[Word]) -> None: def checkpatplural(self, pattern: Optional[Word]) -> None:
""" """
@ -2130,10 +2146,10 @@ class engine:
""" """
return return
@validate_call @typechecked
def ud_match(self, word: Word, wordlist: Sequence[Optional[Word]]) -> Optional[str]: def ud_match(self, word: Word, wordlist: Sequence[Optional[Word]]) -> Optional[str]:
for i in range(len(wordlist) - 2, -2, -2): # backwards through even elements for i in range(len(wordlist) - 2, -2, -2): # backwards through even elements
mo = re.search(fr"^{wordlist[i]}$", word, re.IGNORECASE) mo = re.search(rf"^{wordlist[i]}$", word, re.IGNORECASE)
if mo: if mo:
if wordlist[i + 1] is None: if wordlist[i + 1] is None:
return None return None
@ -2191,8 +2207,8 @@ class engine:
if count is not None: if count is not None:
try: try:
self.persistent_count = int(count) self.persistent_count = int(count)
except ValueError: except ValueError as err:
raise BadNumValueError raise BadNumValueError from err
if (show is None) or show: if (show is None) or show:
return str(count) return str(count)
else: else:
@ -2270,7 +2286,7 @@ class engine:
# 0. PERFORM GENERAL INFLECTIONS IN A STRING # 0. PERFORM GENERAL INFLECTIONS IN A STRING
@validate_call @typechecked
def inflect(self, text: Word) -> str: def inflect(self, text: Word) -> str:
""" """
Perform inflections in a string. Perform inflections in a string.
@ -2347,7 +2363,7 @@ class engine:
else: else:
return "", "", "" return "", "", ""
@validate_call @typechecked
def plural(self, text: Word, count: Optional[Union[str, int, Any]] = None) -> str: def plural(self, text: Word, count: Optional[Union[str, int, Any]] = None) -> str:
""" """
Return the plural of text. Return the plural of text.
@ -2371,7 +2387,7 @@ class engine:
) )
return f"{pre}{plural}{post}" return f"{pre}{plural}{post}"
@validate_call @typechecked
def plural_noun( def plural_noun(
self, text: Word, count: Optional[Union[str, int, Any]] = None self, text: Word, count: Optional[Union[str, int, Any]] = None
) -> str: ) -> str:
@ -2392,7 +2408,7 @@ class engine:
plural = self.postprocess(word, self._plnoun(word, count)) plural = self.postprocess(word, self._plnoun(word, count))
return f"{pre}{plural}{post}" return f"{pre}{plural}{post}"
@validate_call @typechecked
def plural_verb( def plural_verb(
self, text: Word, count: Optional[Union[str, int, Any]] = None self, text: Word, count: Optional[Union[str, int, Any]] = None
) -> str: ) -> str:
@ -2416,7 +2432,7 @@ class engine:
) )
return f"{pre}{plural}{post}" return f"{pre}{plural}{post}"
@validate_call @typechecked
def plural_adj( def plural_adj(
self, text: Word, count: Optional[Union[str, int, Any]] = None self, text: Word, count: Optional[Union[str, int, Any]] = None
) -> str: ) -> str:
@ -2437,7 +2453,7 @@ class engine:
plural = self.postprocess(word, self._pl_special_adjective(word, count) or word) plural = self.postprocess(word, self._pl_special_adjective(word, count) or word)
return f"{pre}{plural}{post}" return f"{pre}{plural}{post}"
@validate_call @typechecked
def compare(self, word1: Word, word2: Word) -> Union[str, bool]: def compare(self, word1: Word, word2: Word) -> Union[str, bool]:
""" """
compare word1 and word2 for equality regardless of plurality compare word1 and word2 for equality regardless of plurality
@ -2460,15 +2476,13 @@ class engine:
>>> compare('egg', '') >>> compare('egg', '')
Traceback (most recent call last): Traceback (most recent call last):
... ...
pydantic...ValidationError: ... typeguard.TypeCheckError:...is not an instance of inflect.Word
...
...at least 1 characters...
""" """
norms = self.plural_noun, self.plural_verb, self.plural_adj norms = self.plural_noun, self.plural_verb, self.plural_adj
results = (self._plequal(word1, word2, norm) for norm in norms) results = (self._plequal(word1, word2, norm) for norm in norms)
return next(filter(None, results), False) return next(filter(None, results), False)
@validate_call @typechecked
def compare_nouns(self, word1: Word, word2: Word) -> Union[str, bool]: def compare_nouns(self, word1: Word, word2: Word) -> Union[str, bool]:
""" """
compare word1 and word2 for equality regardless of plurality compare word1 and word2 for equality regardless of plurality
@ -2484,7 +2498,7 @@ class engine:
""" """
return self._plequal(word1, word2, self.plural_noun) return self._plequal(word1, word2, self.plural_noun)
@validate_call @typechecked
def compare_verbs(self, word1: Word, word2: Word) -> Union[str, bool]: def compare_verbs(self, word1: Word, word2: Word) -> Union[str, bool]:
""" """
compare word1 and word2 for equality regardless of plurality compare word1 and word2 for equality regardless of plurality
@ -2500,7 +2514,7 @@ class engine:
""" """
return self._plequal(word1, word2, self.plural_verb) return self._plequal(word1, word2, self.plural_verb)
@validate_call @typechecked
def compare_adjs(self, word1: Word, word2: Word) -> Union[str, bool]: def compare_adjs(self, word1: Word, word2: Word) -> Union[str, bool]:
""" """
compare word1 and word2 for equality regardless of plurality compare word1 and word2 for equality regardless of plurality
@ -2516,7 +2530,7 @@ class engine:
""" """
return self._plequal(word1, word2, self.plural_adj) return self._plequal(word1, word2, self.plural_adj)
@validate_call @typechecked
def singular_noun( def singular_noun(
self, self,
text: Word, text: Word,
@ -2574,18 +2588,18 @@ class engine:
return "s:p" return "s:p"
self.classical_dict = classval.copy() self.classical_dict = classval.copy()
if same_method(pl, self.plural) or same_method(pl, self.plural_noun): if pl == self.plural or pl == self.plural_noun:
if self._pl_check_plurals_N(word1, word2): if self._pl_check_plurals_N(word1, word2):
return "p:p" return "p:p"
if self._pl_check_plurals_N(word2, word1): if self._pl_check_plurals_N(word2, word1):
return "p:p" return "p:p"
if same_method(pl, self.plural) or same_method(pl, self.plural_adj): if pl == self.plural or pl == self.plural_adj:
if self._pl_check_plurals_adj(word1, word2): if self._pl_check_plurals_adj(word1, word2):
return "p:p" return "p:p"
return False return False
def _pl_reg_plurals(self, pair: str, stems: str, end1: str, end2: str) -> bool: def _pl_reg_plurals(self, pair: str, stems: str, end1: str, end2: str) -> bool:
pattern = fr"({stems})({end1}\|\1{end2}|{end2}\|\1{end1})" pattern = rf"({stems})({end1}\|\1{end2}|{end2}\|\1{end1})"
return bool(re.search(pattern, pair)) return bool(re.search(pattern, pair))
def _pl_check_plurals_N(self, word1: str, word2: str) -> bool: def _pl_check_plurals_N(self, word1: str, word2: str) -> bool:
@ -2679,6 +2693,8 @@ class engine:
word = Words(word) word = Words(word)
if word.last.lower() in pl_sb_uninflected_complete: if word.last.lower() in pl_sb_uninflected_complete:
if len(word.split_) >= 3:
return self._handle_long_compounds(word, count=2) or word
return word return word
if word in pl_sb_uninflected_caps: if word in pl_sb_uninflected_caps:
@ -2707,13 +2723,9 @@ class engine:
) )
if len(word.split_) >= 3: if len(word.split_) >= 3:
for numword in range(1, len(word.split_) - 1): handled_words = self._handle_long_compounds(word, count=2)
if word.split_[numword] in pl_prep_list_da: if handled_words is not None:
return " ".join( return handled_words
word.split_[: numword - 1]
+ [self._plnoun(word.split_[numword - 1], 2)]
+ word.split_[numword:]
)
# only pluralize denominators in units # only pluralize denominators in units
mo = DENOMINATOR.search(word.lowered) mo = DENOMINATOR.search(word.lowered)
@ -2972,6 +2984,30 @@ class engine:
parts[: pivot - 1] + [sep.join([transformed, parts[pivot], ''])] parts[: pivot - 1] + [sep.join([transformed, parts[pivot], ''])]
) + " ".join(parts[(pivot + 1) :]) ) + " ".join(parts[(pivot + 1) :])
def _handle_long_compounds(self, word: Words, count: int) -> Union[str, None]:
"""
Handles the plural and singular for compound `Words` that
have three or more words, based on the given count.
>>> engine()._handle_long_compounds(Words("pair of scissors"), 2)
'pairs of scissors'
>>> engine()._handle_long_compounds(Words("men beyond hills"), 1)
'man beyond hills'
"""
inflection = self._sinoun if count == 1 else self._plnoun
solutions = ( # type: ignore
" ".join(
itertools.chain(
leader,
[inflection(cand, count), prep], # type: ignore
trailer,
)
)
for leader, (cand, prep), trailer in windowed_complete(word.split_, 2)
if prep in pl_prep_list_da # type: ignore
)
return next(solutions, None)
@staticmethod @staticmethod
def _find_pivot(words, candidates): def _find_pivot(words, candidates):
pivots = ( pivots = (
@ -2980,7 +3016,7 @@ class engine:
try: try:
return next(pivots) return next(pivots)
except StopIteration: except StopIteration:
raise ValueError("No pivot found") raise ValueError("No pivot found") from None
def _pl_special_verb( # noqa: C901 def _pl_special_verb( # noqa: C901
self, word: str, count: Optional[Union[str, int]] = None self, word: str, count: Optional[Union[str, int]] = None
@ -3145,8 +3181,8 @@ class engine:
gender = self.thegender gender = self.thegender
elif gender not in singular_pronoun_genders: elif gender not in singular_pronoun_genders:
raise BadGenderError raise BadGenderError
except (TypeError, IndexError): except (TypeError, IndexError) as err:
raise BadGenderError raise BadGenderError from err
# HANDLE USER-DEFINED NOUNS # HANDLE USER-DEFINED NOUNS
@ -3165,6 +3201,8 @@ class engine:
words = Words(word) words = Words(word)
if words.last.lower() in pl_sb_uninflected_complete: if words.last.lower() in pl_sb_uninflected_complete:
if len(words.split_) >= 3:
return self._handle_long_compounds(words, count=1) or word
return word return word
if word in pl_sb_uninflected_caps: if word in pl_sb_uninflected_caps:
@ -3450,7 +3488,7 @@ class engine:
# ADJECTIVES # ADJECTIVES
@validate_call @typechecked
def a(self, text: Word, count: Optional[Union[int, str, Any]] = 1) -> str: def a(self, text: Word, count: Optional[Union[int, str, Any]] = 1) -> str:
""" """
Return the appropriate indefinite article followed by text. Return the appropriate indefinite article followed by text.
@ -3531,7 +3569,7 @@ class engine:
# 2. TRANSLATE ZERO-QUANTIFIED $word TO "no plural($word)" # 2. TRANSLATE ZERO-QUANTIFIED $word TO "no plural($word)"
@validate_call @typechecked
def no(self, text: Word, count: Optional[Union[int, str]] = None) -> str: def no(self, text: Word, count: Optional[Union[int, str]] = None) -> str:
""" """
If count is 0, no, zero or nil, return 'no' followed by the plural If count is 0, no, zero or nil, return 'no' followed by the plural
@ -3569,7 +3607,7 @@ class engine:
# PARTICIPLES # PARTICIPLES
@validate_call @typechecked
def present_participle(self, word: Word) -> str: def present_participle(self, word: Word) -> str:
""" """
Return the present participle for word. Return the present participle for word.
@ -3588,7 +3626,7 @@ class engine:
# NUMERICAL INFLECTIONS # NUMERICAL INFLECTIONS
@validate_call(config=dict(arbitrary_types_allowed=True)) @typechecked
def ordinal(self, num: Union[Number, Word]) -> str: def ordinal(self, num: Union[Number, Word]) -> str:
""" """
Return the ordinal of num. Return the ordinal of num.
@ -3619,16 +3657,7 @@ class engine:
post = nth[n % 10] post = nth[n % 10]
return f"{num}{post}" return f"{num}{post}"
else: else:
# Mad props to Damian Conway (?) whose ordinal() return self._sub_ord(num)
# algorithm is type-bendy enough to foil MyPy
str_num: str = num # type: ignore[assignment]
mo = ordinal_suff.search(str_num)
if mo:
post = ordinal[mo.group(1)]
rval = ordinal_suff.sub(post, str_num)
else:
rval = f"{str_num}th"
return rval
def millfn(self, ind: int = 0) -> str: def millfn(self, ind: int = 0) -> str:
if ind > len(mill) - 1: if ind > len(mill) - 1:
@ -3747,7 +3776,36 @@ class engine:
num = ONE_DIGIT_WORD.sub(self.unitsub, num, 1) num = ONE_DIGIT_WORD.sub(self.unitsub, num, 1)
return num return num
@validate_call(config=dict(arbitrary_types_allowed=True)) # noqa: C901 @staticmethod
def _sub_ord(val):
new = ordinal_suff.sub(lambda match: ordinal[match.group(1)], val)
return new + "th" * (new == val)
@classmethod
def _chunk_num(cls, num, decimal, group):
if decimal:
max_split = -1 if group != 0 else 1
chunks = num.split(".", max_split)
else:
chunks = [num]
return cls._remove_last_blank(chunks)
@staticmethod
def _remove_last_blank(chunks):
"""
Remove the last item from chunks if it's a blank string.
Return the resultant chunks and whether the last item was removed.
"""
removed = chunks[-1] == ""
result = chunks[:-1] if removed else chunks
return result, removed
@staticmethod
def _get_sign(num):
return {'+': 'plus', '-': 'minus'}.get(num.lstrip()[0], '')
@typechecked
def number_to_words( # noqa: C901 def number_to_words( # noqa: C901
self, self,
num: Union[Number, Word], num: Union[Number, Word],
@ -3794,13 +3852,8 @@ class engine:
if group < 0 or group > 3: if group < 0 or group > 3:
raise BadChunkingOptionError raise BadChunkingOptionError
nowhite = num.lstrip()
if nowhite[0] == "+": sign = self._get_sign(num)
sign = "plus"
elif nowhite[0] == "-":
sign = "minus"
else:
sign = ""
if num in nth_suff: if num in nth_suff:
num = zero num = zero
@ -3808,34 +3861,21 @@ class engine:
myord = num[-2:] in nth_suff myord = num[-2:] in nth_suff
if myord: if myord:
num = num[:-2] num = num[:-2]
finalpoint = False
if decimal:
if group != 0:
chunks = num.split(".")
else:
chunks = num.split(".", 1)
if chunks[-1] == "": # remove blank string if nothing after decimal
chunks = chunks[:-1]
finalpoint = True # add 'point' to end of output
else:
chunks = [num]
first: Union[int, str, bool] = 1 chunks, finalpoint = self._chunk_num(num, decimal, group)
loopstart = 0
if chunks[0] == "": loopstart = chunks[0] == ""
first = 0 first: bool | None = not loopstart
if len(chunks) > 1:
loopstart = 1 def _handle_chunk(chunk):
nonlocal first
for i in range(loopstart, len(chunks)):
chunk = chunks[i]
# remove all non numeric \D # remove all non numeric \D
chunk = NON_DIGIT.sub("", chunk) chunk = NON_DIGIT.sub("", chunk)
if chunk == "": if chunk == "":
chunk = "0" chunk = "0"
if group == 0 and (first == 0 or first == ""): if group == 0 and not first:
chunk = self.enword(chunk, 1) chunk = self.enword(chunk, 1)
else: else:
chunk = self.enword(chunk, group) chunk = self.enword(chunk, group)
@ -3850,20 +3890,17 @@ class engine:
# chunk = re.sub(r"(\A\s|\s\Z)", self.blankfn, chunk) # chunk = re.sub(r"(\A\s|\s\Z)", self.blankfn, chunk)
chunk = chunk.strip() chunk = chunk.strip()
if first: if first:
first = "" first = None
chunks[i] = chunk return chunk
chunks[loopstart:] = map(_handle_chunk, chunks[loopstart:])
numchunks = [] numchunks = []
if first != 0: if first != 0:
numchunks = chunks[0].split(f"{comma} ") numchunks = chunks[0].split(f"{comma} ")
if myord and numchunks: if myord and numchunks:
# TODO: can this be just one re as it is in perl? numchunks[-1] = self._sub_ord(numchunks[-1])
mo = ordinal_suff.search(numchunks[-1])
if mo:
numchunks[-1] = ordinal_suff.sub(ordinal[mo.group(1)], numchunks[-1])
else:
numchunks[-1] += "th"
for chunk in chunks[1:]: for chunk in chunks[1:]:
numchunks.append(decimal) numchunks.append(decimal)
@ -3872,34 +3909,30 @@ class engine:
if finalpoint: if finalpoint:
numchunks.append(decimal) numchunks.append(decimal)
# wantlist: Perl list context. can explicitly specify in Python
if wantlist: if wantlist:
if sign: return [sign] * bool(sign) + numchunks
numchunks = [sign] + numchunks
return numchunks
elif group:
signout = f"{sign} " if sign else "" signout = f"{sign} " if sign else ""
return f"{signout}{', '.join(numchunks)}" valout = (
else: ', '.join(numchunks)
signout = f"{sign} " if sign else "" if group
num = f"{signout}{numchunks.pop(0)}" else ''.join(self._render(numchunks, decimal, comma))
if decimal is None: )
first = True return signout + valout
else:
first = not num.endswith(decimal) @staticmethod
for nc in numchunks: def _render(chunks, decimal, comma):
first_item = chunks.pop(0)
yield first_item
first = decimal is None or not first_item.endswith(decimal)
for nc in chunks:
if nc == decimal: if nc == decimal:
num += f" {nc}" first = False
first = 0
elif first: elif first:
num += f"{comma} {nc}" yield comma
else: yield f" {nc}"
num += f" {nc}"
return num
# Join words with commas and a trailing 'and' (when appropriate)... @typechecked
@validate_call
def join( def join(
self, self,
words: Optional[Sequence[Word]], words: Optional[Sequence[Word]],

View file

@ -1,19 +0,0 @@
class ValidateCallWrapperWrapper:
def __init__(self, wrapped):
self.orig = wrapped
def __eq__(self, other):
return self.raw_function == other.raw_function
@property
def raw_function(self):
return getattr(self.orig, 'raw_function') or self.orig
def same_method(m1, m2) -> bool:
"""
Return whether m1 and m2 are the same method.
Workaround for pydantic/pydantic#6390.
"""
return ValidateCallWrapperWrapper(m1) == ValidateCallWrapperWrapper(m2)

View file

@ -1,8 +0,0 @@
try:
from pydantic import validate_call # type: ignore
except ImportError:
# Pydantic 1
from pydantic import validate_arguments as validate_call # type: ignore
__all__ = ['validate_call']

View file

@ -1,68 +0,0 @@
"""
Routines for obtaining the class names
of an object and its parent classes.
"""
from more_itertools import unique_everseen
def all_bases(c):
"""
return a tuple of all base classes the class c has as a parent.
>>> object in all_bases(list)
True
"""
return c.mro()[1:]
def all_classes(c):
"""
return a tuple of all classes to which c belongs
>>> list in all_classes(list)
True
"""
return c.mro()
# borrowed from
# http://code.activestate.com/recipes/576949-find-all-subclasses-of-a-given-class/
def iter_subclasses(cls):
"""
Generator over all subclasses of a given class, in depth-first order.
>>> bool in list(iter_subclasses(int))
True
>>> class A(object): pass
>>> class B(A): pass
>>> class C(A): pass
>>> class D(B,C): pass
>>> class E(D): pass
>>>
>>> for cls in iter_subclasses(A):
... print(cls.__name__)
B
D
E
C
>>> # get ALL classes currently defined
>>> res = [cls.__name__ for cls in iter_subclasses(object)]
>>> 'type' in res
True
>>> 'tuple' in res
True
>>> len(res) > 100
True
"""
return unique_everseen(_iter_all_subclasses(cls))
def _iter_all_subclasses(cls):
try:
subs = cls.__subclasses__()
except TypeError: # fails only when cls is type
subs = cls.__subclasses__(cls)
for sub in subs:
yield sub
yield from iter_subclasses(sub)

View file

@ -1,66 +0,0 @@
"""
meta.py
Some useful metaclasses.
"""
class LeafClassesMeta(type):
"""
A metaclass for classes that keeps track of all of them that
aren't base classes.
>>> Parent = LeafClassesMeta('MyParentClass', (), {})
>>> Parent in Parent._leaf_classes
True
>>> Child = LeafClassesMeta('MyChildClass', (Parent,), {})
>>> Child in Parent._leaf_classes
True
>>> Parent in Parent._leaf_classes
False
>>> Other = LeafClassesMeta('OtherClass', (), {})
>>> Parent in Other._leaf_classes
False
>>> len(Other._leaf_classes)
1
"""
def __init__(cls, name, bases, attrs):
if not hasattr(cls, '_leaf_classes'):
cls._leaf_classes = set()
leaf_classes = getattr(cls, '_leaf_classes')
leaf_classes.add(cls)
# remove any base classes
leaf_classes -= set(bases)
class TagRegistered(type):
"""
As classes of this metaclass are created, they keep a registry in the
base class of all classes by a class attribute, indicated by attr_name.
>>> FooObject = TagRegistered('FooObject', (), dict(tag='foo'))
>>> FooObject._registry['foo'] is FooObject
True
>>> BarObject = TagRegistered('Barobject', (FooObject,), dict(tag='bar'))
>>> FooObject._registry is BarObject._registry
True
>>> len(FooObject._registry)
2
'...' below should be 'jaraco.classes' but for pytest-dev/pytest#3396
>>> FooObject._registry['bar']
<class '....meta.Barobject'>
"""
attr_name = 'tag'
def __init__(cls, name, bases, namespace):
super(TagRegistered, cls).__init__(name, bases, namespace)
if not hasattr(cls, '_registry'):
cls._registry = {}
meta = cls.__class__
attr = getattr(cls, meta.attr_name, None)
if attr:
cls._registry[attr] = cls

View file

@ -1,170 +0,0 @@
class NonDataProperty:
"""Much like the property builtin, but only implements __get__,
making it a non-data property, and can be subsequently reset.
See http://users.rcn.com/python/download/Descriptor.htm for more
information.
>>> class X(object):
... @NonDataProperty
... def foo(self):
... return 3
>>> x = X()
>>> x.foo
3
>>> x.foo = 4
>>> x.foo
4
'...' below should be 'jaraco.classes' but for pytest-dev/pytest#3396
>>> X.foo
<....properties.NonDataProperty object at ...>
"""
def __init__(self, fget):
assert fget is not None, "fget cannot be none"
assert callable(fget), "fget must be callable"
self.fget = fget
def __get__(self, obj, objtype=None):
if obj is None:
return self
return self.fget(obj)
class classproperty:
"""
Like @property but applies at the class level.
>>> class X(metaclass=classproperty.Meta):
... val = None
... @classproperty
... def foo(cls):
... return cls.val
... @foo.setter
... def foo(cls, val):
... cls.val = val
>>> X.foo
>>> X.foo = 3
>>> X.foo
3
>>> x = X()
>>> x.foo
3
>>> X.foo = 4
>>> x.foo
4
Setting the property on an instance affects the class.
>>> x.foo = 5
>>> x.foo
5
>>> X.foo
5
>>> vars(x)
{}
>>> X().foo
5
Attempting to set an attribute where no setter was defined
results in an AttributeError:
>>> class GetOnly(metaclass=classproperty.Meta):
... @classproperty
... def foo(cls):
... return 'bar'
>>> GetOnly.foo = 3
Traceback (most recent call last):
...
AttributeError: can't set attribute
It is also possible to wrap a classmethod or staticmethod in
a classproperty.
>>> class Static(metaclass=classproperty.Meta):
... @classproperty
... @classmethod
... def foo(cls):
... return 'foo'
... @classproperty
... @staticmethod
... def bar():
... return 'bar'
>>> Static.foo
'foo'
>>> Static.bar
'bar'
*Legacy*
For compatibility, if the metaclass isn't specified, the
legacy behavior will be invoked.
>>> class X:
... val = None
... @classproperty
... def foo(cls):
... return cls.val
... @foo.setter
... def foo(cls, val):
... cls.val = val
>>> X.foo
>>> X.foo = 3
>>> X.foo
3
>>> x = X()
>>> x.foo
3
>>> X.foo = 4
>>> x.foo
4
Note, because the metaclass was not specified, setting
a value on an instance does not have the intended effect.
>>> x.foo = 5
>>> x.foo
5
>>> X.foo # should be 5
4
>>> vars(x) # should be empty
{'foo': 5}
>>> X().foo # should be 5
4
"""
class Meta(type):
def __setattr__(self, key, value):
obj = self.__dict__.get(key, None)
if type(obj) is classproperty:
return obj.__set__(self, value)
return super().__setattr__(key, value)
def __init__(self, fget, fset=None):
self.fget = self._ensure_method(fget)
self.fset = fset
fset and self.setter(fset)
def __get__(self, instance, owner=None):
return self.fget.__get__(None, owner)()
def __set__(self, owner, value):
if not self.fset:
raise AttributeError("can't set attribute")
if type(owner) is not classproperty.Meta:
owner = type(owner)
return self.fset.__get__(None, owner)(value)
def setter(self, fset):
self.fset = self._ensure_method(fset)
return self
@classmethod
def _ensure_method(cls, fn):
"""
Ensure fn is a classmethod or staticmethod.
"""
needs_method = not isinstance(fn, (classmethod, staticmethod))
return classmethod(fn) if needs_method else fn

View file

@ -1,16 +1,17 @@
import re from __future__ import annotations
import operator
import collections.abc import collections.abc
import itertools
import copy import copy
import functools import functools
import itertools
import operator
import random import random
import re
from collections.abc import Container, Iterable, Mapping from collections.abc import Container, Iterable, Mapping
from typing import Callable, Union from typing import Any, Callable, Union
import jaraco.text import jaraco.text
_Matchable = Union[Callable, Container, Iterable, re.Pattern] _Matchable = Union[Callable, Container, Iterable, re.Pattern]
@ -199,7 +200,12 @@ class RangeMap(dict):
""" """
def __init__(self, source, sort_params={}, key_match_comparator=operator.le): def __init__(
self,
source,
sort_params: Mapping[str, Any] = {},
key_match_comparator=operator.le,
):
dict.__init__(self, source) dict.__init__(self, source)
self.sort_params = sort_params self.sort_params = sort_params
self.match = key_match_comparator self.match = key_match_comparator
@ -291,7 +297,7 @@ class KeyTransformingDict(dict):
return key return key
def __init__(self, *args, **kargs): def __init__(self, *args, **kargs):
super(KeyTransformingDict, self).__init__() super().__init__()
# build a dictionary using the default constructs # build a dictionary using the default constructs
d = dict(*args, **kargs) d = dict(*args, **kargs)
# build this dictionary using transformed keys. # build this dictionary using transformed keys.
@ -300,31 +306,31 @@ class KeyTransformingDict(dict):
def __setitem__(self, key, val): def __setitem__(self, key, val):
key = self.transform_key(key) key = self.transform_key(key)
super(KeyTransformingDict, self).__setitem__(key, val) super().__setitem__(key, val)
def __getitem__(self, key): def __getitem__(self, key):
key = self.transform_key(key) key = self.transform_key(key)
return super(KeyTransformingDict, self).__getitem__(key) return super().__getitem__(key)
def __contains__(self, key): def __contains__(self, key):
key = self.transform_key(key) key = self.transform_key(key)
return super(KeyTransformingDict, self).__contains__(key) return super().__contains__(key)
def __delitem__(self, key): def __delitem__(self, key):
key = self.transform_key(key) key = self.transform_key(key)
return super(KeyTransformingDict, self).__delitem__(key) return super().__delitem__(key)
def get(self, key, *args, **kwargs): def get(self, key, *args, **kwargs):
key = self.transform_key(key) key = self.transform_key(key)
return super(KeyTransformingDict, self).get(key, *args, **kwargs) return super().get(key, *args, **kwargs)
def setdefault(self, key, *args, **kwargs): def setdefault(self, key, *args, **kwargs):
key = self.transform_key(key) key = self.transform_key(key)
return super(KeyTransformingDict, self).setdefault(key, *args, **kwargs) return super().setdefault(key, *args, **kwargs)
def pop(self, key, *args, **kwargs): def pop(self, key, *args, **kwargs):
key = self.transform_key(key) key = self.transform_key(key)
return super(KeyTransformingDict, self).pop(key, *args, **kwargs) return super().pop(key, *args, **kwargs)
def matching_key_for(self, key): def matching_key_for(self, key):
""" """
@ -333,8 +339,8 @@ class KeyTransformingDict(dict):
""" """
try: try:
return next(e_key for e_key in self.keys() if e_key == key) return next(e_key for e_key in self.keys() if e_key == key)
except StopIteration: except StopIteration as err:
raise KeyError(key) raise KeyError(key) from err
class FoldedCaseKeyedDict(KeyTransformingDict): class FoldedCaseKeyedDict(KeyTransformingDict):
@ -483,7 +489,7 @@ class ItemsAsAttributes:
def __getattr__(self, key): def __getattr__(self, key):
try: try:
return getattr(super(ItemsAsAttributes, self), key) return getattr(super(), key)
except AttributeError as e: except AttributeError as e:
# attempt to get the value from the mapping (return self[key]) # attempt to get the value from the mapping (return self[key])
# but be careful not to lose the original exception context. # but be careful not to lose the original exception context.
@ -677,7 +683,7 @@ class BijectiveMap(dict):
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(BijectiveMap, self).__init__() super().__init__()
self.update(*args, **kwargs) self.update(*args, **kwargs)
def __setitem__(self, item, value): def __setitem__(self, item, value):
@ -691,19 +697,19 @@ class BijectiveMap(dict):
) )
if overlap: if overlap:
raise ValueError("Key/Value pairs may not overlap") raise ValueError("Key/Value pairs may not overlap")
super(BijectiveMap, self).__setitem__(item, value) super().__setitem__(item, value)
super(BijectiveMap, self).__setitem__(value, item) super().__setitem__(value, item)
def __delitem__(self, item): def __delitem__(self, item):
self.pop(item) self.pop(item)
def __len__(self): def __len__(self):
return super(BijectiveMap, self).__len__() // 2 return super().__len__() // 2
def pop(self, key, *args, **kwargs): def pop(self, key, *args, **kwargs):
mirror = self[key] mirror = self[key]
super(BijectiveMap, self).__delitem__(mirror) super().__delitem__(mirror)
return super(BijectiveMap, self).pop(key, *args, **kwargs) return super().pop(key, *args, **kwargs)
def update(self, *args, **kwargs): def update(self, *args, **kwargs):
# build a dictionary using the default constructs # build a dictionary using the default constructs
@ -769,7 +775,7 @@ class FrozenDict(collections.abc.Mapping, collections.abc.Hashable):
__slots__ = ['__data'] __slots__ = ['__data']
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
self = super(FrozenDict, cls).__new__(cls) self = super().__new__(cls)
self.__data = dict(*args, **kwargs) self.__data = dict(*args, **kwargs)
return self return self
@ -844,7 +850,7 @@ class Enumeration(ItemsAsAttributes, BijectiveMap):
names = names.split() names = names.split()
if codes is None: if codes is None:
codes = itertools.count() codes = itertools.count()
super(Enumeration, self).__init__(zip(names, codes)) super().__init__(zip(names, codes))
@property @property
def names(self): def names(self):

View file

@ -1,15 +1,26 @@
import os from __future__ import annotations
import subprocess
import contextlib import contextlib
import functools import functools
import tempfile
import shutil
import operator import operator
import os
import shutil
import subprocess
import sys
import tempfile
import urllib.request
import warnings import warnings
from typing import Iterator
if sys.version_info < (3, 12):
from backports import tarfile
else:
import tarfile
@contextlib.contextmanager @contextlib.contextmanager
def pushd(dir): def pushd(dir: str | os.PathLike) -> Iterator[str | os.PathLike]:
""" """
>>> tmp_path = getfixture('tmp_path') >>> tmp_path = getfixture('tmp_path')
>>> with pushd(tmp_path): >>> with pushd(tmp_path):
@ -26,33 +37,88 @@ def pushd(dir):
@contextlib.contextmanager @contextlib.contextmanager
def tarball_context(url, target_dir=None, runner=None, pushd=pushd): def tarball(
url, target_dir: str | os.PathLike | None = None
) -> Iterator[str | os.PathLike]:
""" """
Get a tarball, extract it, change to that directory, yield, then Get a tarball, extract it, yield, then clean up.
clean up.
`runner` is the function to invoke commands. >>> import urllib.request
`pushd` is a context manager for changing the directory. >>> url = getfixture('tarfile_served')
>>> target = getfixture('tmp_path') / 'out'
>>> tb = tarball(url, target_dir=target)
>>> import pathlib
>>> with tb as extracted:
... contents = pathlib.Path(extracted, 'contents.txt').read_text(encoding='utf-8')
>>> assert not os.path.exists(extracted)
""" """
if target_dir is None: if target_dir is None:
target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '') target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '')
if runner is None:
runner = functools.partial(subprocess.check_call, shell=True)
else:
warnings.warn("runner parameter is deprecated", DeprecationWarning)
# In the tar command, use --strip-components=1 to strip the first path and # In the tar command, use --strip-components=1 to strip the first path and
# then # then
# use -C to cause the files to be extracted to {target_dir}. This ensures # use -C to cause the files to be extracted to {target_dir}. This ensures
# that we always know where the files were extracted. # that we always know where the files were extracted.
runner('mkdir {target_dir}'.format(**vars())) os.mkdir(target_dir)
try: try:
getter = 'wget {url} -O -' req = urllib.request.urlopen(url)
extract = 'tar x{compression} --strip-components=1 -C {target_dir}' with tarfile.open(fileobj=req, mode='r|*') as tf:
cmd = ' | '.join((getter, extract)) tf.extractall(path=target_dir, filter=strip_first_component)
runner(cmd.format(compression=infer_compression(url), **vars()))
with pushd(target_dir):
yield target_dir yield target_dir
finally: finally:
runner('rm -Rf {target_dir}'.format(**vars())) shutil.rmtree(target_dir)
def strip_first_component(
member: tarfile.TarInfo,
path,
) -> tarfile.TarInfo:
_, member.name = member.name.split('/', 1)
return member
def _compose(*cmgrs):
"""
Compose any number of dependent context managers into a single one.
The last, innermost context manager may take arbitrary arguments, but
each successive context manager should accept the result from the
previous as a single parameter.
Like :func:`jaraco.functools.compose`, behavior works from right to
left, so the context manager should be indicated from outermost to
innermost.
Example, to create a context manager to change to a temporary
directory:
>>> temp_dir_as_cwd = _compose(pushd, temp_dir)
>>> with temp_dir_as_cwd() as dir:
... assert os.path.samefile(os.getcwd(), dir)
"""
def compose_two(inner, outer):
def composed(*args, **kwargs):
with inner(*args, **kwargs) as saved, outer(saved) as res:
yield res
return contextlib.contextmanager(composed)
return functools.reduce(compose_two, reversed(cmgrs))
tarball_cwd = _compose(pushd, tarball)
@contextlib.contextmanager
def tarball_context(*args, **kwargs):
warnings.warn(
"tarball_context is deprecated. Use tarball or tarball_cwd instead.",
DeprecationWarning,
stacklevel=2,
)
pushd_ctx = kwargs.pop('pushd', pushd)
with tarball(*args, **kwargs) as tball, pushd_ctx(tball) as dir:
yield dir
def infer_compression(url): def infer_compression(url):
@ -68,6 +134,11 @@ def infer_compression(url):
>>> infer_compression('file.xz') >>> infer_compression('file.xz')
'J' 'J'
""" """
warnings.warn(
"infer_compression is deprecated with no replacement",
DeprecationWarning,
stacklevel=2,
)
# cheat and just assume it's the last two characters # cheat and just assume it's the last two characters
compression_indicator = url[-2:] compression_indicator = url[-2:]
mapping = dict(gz='z', bz='j', xz='J') mapping = dict(gz='z', bz='j', xz='J')
@ -84,7 +155,7 @@ def temp_dir(remover=shutil.rmtree):
>>> import pathlib >>> import pathlib
>>> with temp_dir() as the_dir: >>> with temp_dir() as the_dir:
... assert os.path.isdir(the_dir) ... assert os.path.isdir(the_dir)
... _ = pathlib.Path(the_dir).joinpath('somefile').write_text('contents') ... _ = pathlib.Path(the_dir).joinpath('somefile').write_text('contents', encoding='utf-8')
>>> assert not os.path.exists(the_dir) >>> assert not os.path.exists(the_dir)
""" """
temp_dir = tempfile.mkdtemp() temp_dir = tempfile.mkdtemp()
@ -113,15 +184,23 @@ def repo_context(url, branch=None, quiet=True, dest_ctx=temp_dir):
yield repo_dir yield repo_dir
@contextlib.contextmanager
def null(): def null():
""" """
A null context suitable to stand in for a meaningful context. A null context suitable to stand in for a meaningful context.
>>> with null() as value: >>> with null() as value:
... assert value is None ... assert value is None
This context is most useful when dealing with two or more code
branches but only some need a context. Wrap the others in a null
context to provide symmetry across all options.
""" """
yield warnings.warn(
"null is deprecated. Use contextlib.nullcontext",
DeprecationWarning,
stacklevel=2,
)
return contextlib.nullcontext()
class ExceptionTrap: class ExceptionTrap:
@ -267,13 +346,7 @@ class on_interrupt(contextlib.ContextDecorator):
... on_interrupt('ignore')(do_interrupt)() ... on_interrupt('ignore')(do_interrupt)()
""" """
def __init__( def __init__(self, action='error', /, code=1):
self,
action='error',
# py3.7 compat
# /,
code=1,
):
self.action = action self.action = action
self.code = code self.code = code

View file

@ -74,9 +74,6 @@ def result_invoke(
def invoke( def invoke(
f: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs f: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs
) -> Callable[_P, _R]: ... ) -> Callable[_P, _R]: ...
def call_aside(
f: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs
) -> Callable[_P, _R]: ...
class Throttler(Generic[_R]): class Throttler(Generic[_R]):
last_called: float last_called: float

View file

@ -9,11 +9,11 @@ python-modernize licence: BSD (from python-modernize/LICENSE)
""" """
from lib2to3.fixer_util import (FromImport, Newline, is_import, from lib2to3.fixer_util import (FromImport, Newline, is_import,
find_root, does_tree_import, Comma) find_root, does_tree_import,
Call, Name, Comma)
from lib2to3.pytree import Leaf, Node from lib2to3.pytree import Leaf, Node
from lib2to3.pygram import python_symbols as syms, python_grammar from lib2to3.pygram import python_symbols as syms
from lib2to3.pygram import token from lib2to3.pygram import token
from lib2to3.fixer_util import (Node, Call, Name, syms, Comma, Number)
import re import re
@ -116,7 +116,7 @@ def suitify(parent):
""" """
for node in parent.children: for node in parent.children:
if node.type == syms.suite: if node.type == syms.suite:
# already in the prefered format, do nothing # already in the preferred format, do nothing
return return
# One-liners have no suite node, we have to fake one up # One-liners have no suite node, we have to fake one up
@ -390,6 +390,7 @@ def touch_import_top(package, name_to_import, node):
break break
insert_pos = idx insert_pos = idx
children_hooks = []
if package is None: if package is None:
import_ = Node(syms.import_name, [ import_ = Node(syms.import_name, [
Leaf(token.NAME, u"import"), Leaf(token.NAME, u"import"),
@ -413,8 +414,6 @@ def touch_import_top(package, name_to_import, node):
] ]
) )
children_hooks = [install_hooks, Newline()] children_hooks = [install_hooks, Newline()]
else:
children_hooks = []
# FromImport(package, [Leaf(token.NAME, name_to_import, prefix=u" ")]) # FromImport(package, [Leaf(token.NAME, name_to_import, prefix=u" ")])
@ -448,7 +447,6 @@ def check_future_import(node):
else: else:
node = node.children[3] node = node.children[3]
# now node is the import_as_name[s] # now node is the import_as_name[s]
# print(python_grammar.number2symbol[node.type]) # breaks sometimes
if node.type == syms.import_as_names: if node.type == syms.import_as_names:
result = set() result = set()
for n in node.children: for n in node.children:

View file

@ -37,7 +37,7 @@ from lib2to3.fixer_util import Name, syms, Node, Leaf, touch_import, Call, \
def has_metaclass(parent): def has_metaclass(parent):
""" we have to check the cls_node without changing it. """ we have to check the cls_node without changing it.
There are two possiblities: There are two possibilities:
1) clsdef => suite => simple_stmt => expr_stmt => Leaf('__meta') 1) clsdef => suite => simple_stmt => expr_stmt => Leaf('__meta')
2) clsdef => simple_stmt => expr_stmt => Leaf('__meta') 2) clsdef => simple_stmt => expr_stmt => Leaf('__meta')
""" """
@ -63,7 +63,7 @@ def fixup_parse_tree(cls_node):
# already in the preferred format, do nothing # already in the preferred format, do nothing
return return
# !%@#! oneliners have no suite node, we have to fake one up # !%@#! one-liners have no suite node, we have to fake one up
for i, node in enumerate(cls_node.children): for i, node in enumerate(cls_node.children):
if node.type == token.COLON: if node.type == token.COLON:
break break

View file

@ -16,6 +16,7 @@ MAPPING = {u"reprlib": u"repr",
u"winreg": u"_winreg", u"winreg": u"_winreg",
u"configparser": u"ConfigParser", u"configparser": u"ConfigParser",
u"copyreg": u"copy_reg", u"copyreg": u"copy_reg",
u"multiprocessing.SimpleQueue": u"multiprocessing.queues.SimpleQueue",
u"queue": u"Queue", u"queue": u"Queue",
u"socketserver": u"SocketServer", u"socketserver": u"SocketServer",
u"_markupbase": u"markupbase", u"_markupbase": u"markupbase",

View file

@ -18,8 +18,12 @@ def assignment_source(num_pre, num_post, LISTNAME, ITERNAME):
Returns a source fit for Assign() from fixer_util Returns a source fit for Assign() from fixer_util
""" """
children = [] children = []
try:
pre = unicode(num_pre) pre = unicode(num_pre)
post = unicode(num_post) post = unicode(num_post)
except NameError:
pre = str(num_pre)
post = str(num_post)
# This code builds the assignment source from lib2to3 tree primitives. # This code builds the assignment source from lib2to3 tree primitives.
# It's not very readable, but it seems like the most correct way to do it. # It's not very readable, but it seems like the most correct way to do it.
if num_pre > 0: if num_pre > 0:

View file

@ -75,12 +75,12 @@ Credits
------- -------
:Author: Ed Schofield, Jordan M. Adler, et al :Author: Ed Schofield, Jordan M. Adler, et al
:Sponsor: Python Charmers Pty Ltd, Australia: http://pythoncharmers.com :Sponsor: Python Charmers: https://pythoncharmers.com
Licensing Licensing
--------- ---------
Copyright 2013-2019 Python Charmers Pty Ltd, Australia. Copyright 2013-2024 Python Charmers, Australia.
The software is distributed under an MIT licence. See LICENSE.txt. The software is distributed under an MIT licence. See LICENSE.txt.
""" """

View file

@ -1,11 +1,13 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import inspect import inspect
import sys
import math import math
import numbers import numbers
from future.utils import PY2, PY3, exec_ from future.utils import PY2, PY3, exec_
if PY2: if PY2:
from collections import Mapping from collections import Mapping
else: else:
@ -103,13 +105,12 @@ if PY3:
return '0' + builtins.oct(number)[2:] return '0' + builtins.oct(number)[2:]
raw_input = input raw_input = input
# imp was deprecated in python 3.6
try: if sys.version_info >= (3, 6):
from importlib import reload from importlib import reload
except ImportError: else:
# for python2, python3 <= 3.4 # for python2, python3 <= 3.4
from imp import reload from imp import reload
unicode = str unicode = str
unichr = chr unichr = chr
xrange = range xrange = range

View file

@ -32,17 +32,31 @@ Author: Ed Schofield.
Inspired by and based on ``uprefix`` by Vinay M. Sajip. Inspired by and based on ``uprefix`` by Vinay M. Sajip.
""" """
import sys
# imp was deprecated in python 3.6
if sys.version_info >= (3, 6):
import importlib as imp
else:
import imp import imp
import logging import logging
import marshal
import os import os
import sys
import copy import copy
from lib2to3.pgen2.parse import ParseError from lib2to3.pgen2.parse import ParseError
from lib2to3.refactor import RefactoringTool from lib2to3.refactor import RefactoringTool
from libfuturize import fixes from libfuturize import fixes
try:
from importlib.machinery import (
PathFinder,
SourceFileLoader,
)
except ImportError:
PathFinder = None
SourceFileLoader = object
if sys.version_info[:2] < (3, 4):
import imp
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@ -225,6 +239,81 @@ def detect_python2(source, pathname):
return False return False
def transform(source, pathname):
# This implementation uses lib2to3,
# you can override and use something else
# if that's better for you
# lib2to3 likes a newline at the end
RTs.setup()
source += '\n'
try:
tree = RTs._rt.refactor_string(source, pathname)
except ParseError as e:
if e.msg != 'bad input' or e.value != '=':
raise
tree = RTs._rtp.refactor_string(source, pathname)
# could optimise a bit for only doing str(tree) if
# getattr(tree, 'was_changed', False) returns True
return str(tree)[:-1] # remove added newline
class PastSourceFileLoader(SourceFileLoader):
exclude_paths = []
include_paths = []
def _convert_needed(self):
fullname = self.name
if any(fullname.startswith(path) for path in self.exclude_paths):
convert = False
elif any(fullname.startswith(path) for path in self.include_paths):
convert = True
else:
convert = False
return convert
def _exec_transformed_module(self, module):
source = self.get_source(self.name)
pathname = self.path
if detect_python2(source, pathname):
source = transform(source, pathname)
code = compile(source, pathname, "exec")
exec(code, module.__dict__)
# For Python 3.3
def load_module(self, fullname):
logger.debug("Running load_module for %s", fullname)
if fullname in sys.modules:
mod = sys.modules[fullname]
else:
if self._convert_needed():
logger.debug("Autoconverting %s", fullname)
mod = imp.new_module(fullname)
sys.modules[fullname] = mod
# required by PEP 302
mod.__file__ = self.path
mod.__loader__ = self
if self.is_package(fullname):
mod.__path__ = []
mod.__package__ = fullname
else:
mod.__package__ = fullname.rpartition('.')[0]
self._exec_transformed_module(mod)
else:
mod = super().load_module(fullname)
return mod
# For Python >=3.4
def exec_module(self, module):
logger.debug("Running exec_module for %s", module)
if self._convert_needed():
logger.debug("Autoconverting %s", self.name)
self._exec_transformed_module(module)
else:
super().exec_module(module)
class Py2Fixer(object): class Py2Fixer(object):
""" """
An import hook class that uses lib2to3 for source-to-source translation of An import hook class that uses lib2to3 for source-to-source translation of
@ -258,151 +347,30 @@ class Py2Fixer(object):
""" """
self.exclude_paths += paths self.exclude_paths += paths
# For Python 3.3
def find_module(self, fullname, path=None): def find_module(self, fullname, path=None):
logger.debug('Running find_module: {0}...'.format(fullname)) logger.debug("Running find_module: (%s, %s)", fullname, path)
if '.' in fullname: loader = PathFinder.find_module(fullname, path)
parent, child = fullname.rsplit('.', 1) if not loader:
if path is None: logger.debug("Py2Fixer could not find %s", fullname)
loader = self.find_module(parent, path)
mod = loader.load_module(parent)
path = mod.__path__
fullname = child
# Perhaps we should try using the new importlib functionality in Python
# 3.3: something like this?
# thing = importlib.machinery.PathFinder.find_module(fullname, path)
try:
self.found = imp.find_module(fullname, path)
except Exception as e:
logger.debug('Py2Fixer could not find {0}')
logger.debug('Exception was: {0})'.format(fullname, e))
return None return None
self.kind = self.found[-1][-1] loader.__class__ = PastSourceFileLoader
if self.kind == imp.PKG_DIRECTORY: loader.exclude_paths = self.exclude_paths
self.pathname = os.path.join(self.found[1], '__init__.py') loader.include_paths = self.include_paths
elif self.kind == imp.PY_SOURCE: return loader
self.pathname = self.found[1]
return self
def transform(self, source): # For Python >=3.4
# This implementation uses lib2to3, def find_spec(self, fullname, path=None, target=None):
# you can override and use something else logger.debug("Running find_spec: (%s, %s, %s)", fullname, path, target)
# if that's better for you spec = PathFinder.find_spec(fullname, path, target)
if not spec:
logger.debug("Py2Fixer could not find %s", fullname)
return None
spec.loader.__class__ = PastSourceFileLoader
spec.loader.exclude_paths = self.exclude_paths
spec.loader.include_paths = self.include_paths
return spec
# lib2to3 likes a newline at the end
RTs.setup()
source += '\n'
try:
tree = RTs._rt.refactor_string(source, self.pathname)
except ParseError as e:
if e.msg != 'bad input' or e.value != '=':
raise
tree = RTs._rtp.refactor_string(source, self.pathname)
# could optimise a bit for only doing str(tree) if
# getattr(tree, 'was_changed', False) returns True
return str(tree)[:-1] # remove added newline
def load_module(self, fullname):
logger.debug('Running load_module for {0}...'.format(fullname))
if fullname in sys.modules:
mod = sys.modules[fullname]
else:
if self.kind in (imp.PY_COMPILED, imp.C_EXTENSION, imp.C_BUILTIN,
imp.PY_FROZEN):
convert = False
# elif (self.pathname.startswith(_stdlibprefix)
# and 'site-packages' not in self.pathname):
# # We assume it's a stdlib package in this case. Is this too brittle?
# # Please file a bug report at https://github.com/PythonCharmers/python-future
# # if so.
# convert = False
# in theory, other paths could be configured to be excluded here too
elif any([fullname.startswith(path) for path in self.exclude_paths]):
convert = False
elif any([fullname.startswith(path) for path in self.include_paths]):
convert = True
else:
convert = False
if not convert:
logger.debug('Excluded {0} from translation'.format(fullname))
mod = imp.load_module(fullname, *self.found)
else:
logger.debug('Autoconverting {0} ...'.format(fullname))
mod = imp.new_module(fullname)
sys.modules[fullname] = mod
# required by PEP 302
mod.__file__ = self.pathname
mod.__name__ = fullname
mod.__loader__ = self
# This:
# mod.__package__ = '.'.join(fullname.split('.')[:-1])
# seems to result in "SystemError: Parent module '' not loaded,
# cannot perform relative import" for a package's __init__.py
# file. We use the approach below. Another option to try is the
# minimal load_module pattern from the PEP 302 text instead.
# Is the test in the next line more or less robust than the
# following one? Presumably less ...
# ispkg = self.pathname.endswith('__init__.py')
if self.kind == imp.PKG_DIRECTORY:
mod.__path__ = [ os.path.dirname(self.pathname) ]
mod.__package__ = fullname
else:
#else, regular module
mod.__path__ = []
mod.__package__ = fullname.rpartition('.')[0]
try:
cachename = imp.cache_from_source(self.pathname)
if not os.path.exists(cachename):
update_cache = True
else:
sourcetime = os.stat(self.pathname).st_mtime
cachetime = os.stat(cachename).st_mtime
update_cache = cachetime < sourcetime
# # Force update_cache to work around a problem with it being treated as Py3 code???
# update_cache = True
if not update_cache:
with open(cachename, 'rb') as f:
data = f.read()
try:
code = marshal.loads(data)
except Exception:
# pyc could be corrupt. Regenerate it
update_cache = True
if update_cache:
if self.found[0]:
source = self.found[0].read()
elif self.kind == imp.PKG_DIRECTORY:
with open(self.pathname) as f:
source = f.read()
if detect_python2(source, self.pathname):
source = self.transform(source)
code = compile(source, self.pathname, 'exec')
dirname = os.path.dirname(cachename)
try:
if not os.path.exists(dirname):
os.makedirs(dirname)
with open(cachename, 'wb') as f:
data = marshal.dumps(code)
f.write(data)
except Exception: # could be write-protected
pass
exec(code, mod.__dict__)
except Exception as e:
# must remove module from sys.modules
del sys.modules[fullname]
raise # keep it simple
if self.found[0]:
self.found[0].close()
return mod
_hook = Py2Fixer() _hook = Py2Fixer()

View file

@ -1,6 +1,8 @@
""" """
Utilities for determining application-specific dirs. See <https://github.com/platformdirs/platformdirs> for details and Utilities for determining application-specific dirs.
usage.
See <https://github.com/platformdirs/platformdirs> for details and usage.
""" """
from __future__ import annotations from __future__ import annotations
@ -20,22 +22,22 @@ if TYPE_CHECKING:
def _set_platform_dir_class() -> type[PlatformDirsABC]: def _set_platform_dir_class() -> type[PlatformDirsABC]:
if sys.platform == "win32": if sys.platform == "win32":
from platformdirs.windows import Windows as Result from platformdirs.windows import Windows as Result # noqa: PLC0415
elif sys.platform == "darwin": elif sys.platform == "darwin":
from platformdirs.macos import MacOS as Result from platformdirs.macos import MacOS as Result # noqa: PLC0415
else: else:
from platformdirs.unix import Unix as Result from platformdirs.unix import Unix as Result # noqa: PLC0415
if os.getenv("ANDROID_DATA") == "/data" and os.getenv("ANDROID_ROOT") == "/system": if os.getenv("ANDROID_DATA") == "/data" and os.getenv("ANDROID_ROOT") == "/system":
if os.getenv("SHELL") or os.getenv("PREFIX"): if os.getenv("SHELL") or os.getenv("PREFIX"):
return Result return Result
from platformdirs.android import _android_folder from platformdirs.android import _android_folder # noqa: PLC0415
if _android_folder() is not None: if _android_folder() is not None:
from platformdirs.android import Android from platformdirs.android import Android # noqa: PLC0415
return Android # return to avoid redefinition of result return Android # return to avoid redefinition of a result
return Result return Result
@ -507,7 +509,7 @@ def user_log_path(
def user_documents_path() -> Path: def user_documents_path() -> Path:
""":returns: documents path tied to the user""" """:returns: documents a path tied to the user"""
return PlatformDirs().user_documents_path return PlatformDirs().user_documents_path
@ -585,41 +587,41 @@ def site_runtime_path(
__all__ = [ __all__ = [
"AppDirs",
"PlatformDirs",
"PlatformDirsABC",
"__version__", "__version__",
"__version_info__", "__version_info__",
"PlatformDirs",
"AppDirs",
"PlatformDirsABC",
"user_data_dir",
"user_config_dir",
"user_cache_dir",
"user_state_dir",
"user_log_dir",
"user_documents_dir",
"user_downloads_dir",
"user_pictures_dir",
"user_videos_dir",
"user_music_dir",
"user_desktop_dir",
"user_runtime_dir",
"site_data_dir",
"site_config_dir",
"site_cache_dir", "site_cache_dir",
"site_runtime_dir",
"user_data_path",
"user_config_path",
"user_cache_path",
"user_state_path",
"user_log_path",
"user_documents_path",
"user_downloads_path",
"user_pictures_path",
"user_videos_path",
"user_music_path",
"user_desktop_path",
"user_runtime_path",
"site_data_path",
"site_config_path",
"site_cache_path", "site_cache_path",
"site_config_dir",
"site_config_path",
"site_data_dir",
"site_data_path",
"site_runtime_dir",
"site_runtime_path", "site_runtime_path",
"user_cache_dir",
"user_cache_path",
"user_config_dir",
"user_config_path",
"user_data_dir",
"user_data_path",
"user_desktop_dir",
"user_desktop_path",
"user_documents_dir",
"user_documents_path",
"user_downloads_dir",
"user_downloads_path",
"user_log_dir",
"user_log_path",
"user_music_dir",
"user_music_path",
"user_pictures_dir",
"user_pictures_path",
"user_runtime_dir",
"user_runtime_path",
"user_state_dir",
"user_state_path",
"user_videos_dir",
"user_videos_path",
] ]

View file

@ -24,7 +24,7 @@ PROPS = (
def main() -> None: def main() -> None:
"""Run main entry point.""" """Run the main entry point."""
app_name = "MyApp" app_name = "MyApp"
app_author = "MyCompany" app_author = "MyCompany"

View file

@ -13,10 +13,11 @@ from .api import PlatformDirsABC
class Android(PlatformDirsABC): class Android(PlatformDirsABC):
""" """
Follows the guidance `from here <https://android.stackexchange.com/a/216132>`_. Makes use of the Follows the guidance `from here <https://android.stackexchange.com/a/216132>`_.
`appname <platformdirs.api.PlatformDirsABC.appname>`,
`version <platformdirs.api.PlatformDirsABC.version>`, Makes use of the `appname <platformdirs.api.PlatformDirsABC.appname>`, `version
`ensure_exists <platformdirs.api.PlatformDirsABC.ensure_exists>`. <platformdirs.api.PlatformDirsABC.version>`, `ensure_exists <platformdirs.api.PlatformDirsABC.ensure_exists>`.
""" """
@property @property
@ -44,7 +45,7 @@ class Android(PlatformDirsABC):
@property @property
def user_cache_dir(self) -> str: def user_cache_dir(self) -> str:
""":return: cache directory tied to the user, e.g. e.g. ``/data/user/<userid>/<packagename>/cache/<AppName>``""" """:return: cache directory tied to the user, e.g.,``/data/user/<userid>/<packagename>/cache/<AppName>``"""
return self._append_app_name_and_version(cast(str, _android_folder()), "cache") return self._append_app_name_and_version(cast(str, _android_folder()), "cache")
@property @property
@ -119,13 +120,13 @@ class Android(PlatformDirsABC):
def _android_folder() -> str | None: def _android_folder() -> str | None:
""":return: base folder for the Android OS or None if it cannot be found""" """:return: base folder for the Android OS or None if it cannot be found"""
try: try:
# First try to get path to android app via pyjnius # First try to get a path to android app via pyjnius
from jnius import autoclass from jnius import autoclass # noqa: PLC0415
context = autoclass("android.content.Context") context = autoclass("android.content.Context")
result: str | None = context.getFilesDir().getParentFile().getAbsolutePath() result: str | None = context.getFilesDir().getParentFile().getAbsolutePath()
except Exception: # noqa: BLE001 except Exception: # noqa: BLE001
# if fails find an android folder looking path on the sys.path # if fails find an android folder looking a path on the sys.path
pattern = re.compile(r"/data/(data|user/\d+)/(.+)/files") pattern = re.compile(r"/data/(data|user/\d+)/(.+)/files")
for path in sys.path: for path in sys.path:
if pattern.match(path): if pattern.match(path):
@ -141,7 +142,7 @@ def _android_documents_folder() -> str:
""":return: documents folder for the Android OS""" """:return: documents folder for the Android OS"""
# Get directories with pyjnius # Get directories with pyjnius
try: try:
from jnius import autoclass from jnius import autoclass # noqa: PLC0415
context = autoclass("android.content.Context") context = autoclass("android.content.Context")
environment = autoclass("android.os.Environment") environment = autoclass("android.os.Environment")
@ -157,7 +158,7 @@ def _android_downloads_folder() -> str:
""":return: downloads folder for the Android OS""" """:return: downloads folder for the Android OS"""
# Get directories with pyjnius # Get directories with pyjnius
try: try:
from jnius import autoclass from jnius import autoclass # noqa: PLC0415
context = autoclass("android.content.Context") context = autoclass("android.content.Context")
environment = autoclass("android.os.Environment") environment = autoclass("android.os.Environment")
@ -173,7 +174,7 @@ def _android_pictures_folder() -> str:
""":return: pictures folder for the Android OS""" """:return: pictures folder for the Android OS"""
# Get directories with pyjnius # Get directories with pyjnius
try: try:
from jnius import autoclass from jnius import autoclass # noqa: PLC0415
context = autoclass("android.content.Context") context = autoclass("android.content.Context")
environment = autoclass("android.os.Environment") environment = autoclass("android.os.Environment")
@ -189,7 +190,7 @@ def _android_videos_folder() -> str:
""":return: videos folder for the Android OS""" """:return: videos folder for the Android OS"""
# Get directories with pyjnius # Get directories with pyjnius
try: try:
from jnius import autoclass from jnius import autoclass # noqa: PLC0415
context = autoclass("android.content.Context") context = autoclass("android.content.Context")
environment = autoclass("android.os.Environment") environment = autoclass("android.os.Environment")
@ -205,7 +206,7 @@ def _android_music_folder() -> str:
""":return: music folder for the Android OS""" """:return: music folder for the Android OS"""
# Get directories with pyjnius # Get directories with pyjnius
try: try:
from jnius import autoclass from jnius import autoclass # noqa: PLC0415
context = autoclass("android.content.Context") context = autoclass("android.content.Context")
environment = autoclass("android.os.Environment") environment = autoclass("android.os.Environment")

View file

@ -11,10 +11,10 @@ if TYPE_CHECKING:
from typing import Iterator, Literal from typing import Iterator, Literal
class PlatformDirsABC(ABC): class PlatformDirsABC(ABC): # noqa: PLR0904
"""Abstract base class for platform directories.""" """Abstract base class for platform directories."""
def __init__( # noqa: PLR0913 def __init__( # noqa: PLR0913, PLR0917
self, self,
appname: str | None = None, appname: str | None = None,
appauthor: str | None | Literal[False] = None, appauthor: str | None | Literal[False] = None,
@ -34,34 +34,47 @@ class PlatformDirsABC(ABC):
:param multipath: See `multipath`. :param multipath: See `multipath`.
:param opinion: See `opinion`. :param opinion: See `opinion`.
:param ensure_exists: See `ensure_exists`. :param ensure_exists: See `ensure_exists`.
""" """
self.appname = appname #: The name of application. self.appname = appname #: The name of application.
self.appauthor = appauthor self.appauthor = appauthor
""" """
The name of the app author or distributing body for this application. Typically, it is the owning company name. The name of the app author or distributing body for this application.
Defaults to `appname`. You may pass ``False`` to disable it.
Typically, it is the owning company name. Defaults to `appname`. You may pass ``False`` to disable it.
""" """
self.version = version self.version = version
""" """
An optional version path element to append to the path. You might want to use this if you want multiple versions An optional version path element to append to the path.
of your app to be able to run independently. If used, this would typically be ``<major>.<minor>``.
You might want to use this if you want multiple versions of your app to be able to run independently. If used,
this would typically be ``<major>.<minor>``.
""" """
self.roaming = roaming self.roaming = roaming
""" """
Whether to use the roaming appdata directory on Windows. That means that for users on a Windows network setup Whether to use the roaming appdata directory on Windows.
for roaming profiles, this user data will be synced on login (see
`here <http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>`_). That means that for users on a Windows network setup for roaming profiles, this user data will be synced on
login (see
`here <https://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>`_).
""" """
self.multipath = multipath self.multipath = multipath
""" """
An optional parameter which indicates that the entire list of data dirs should be returned. An optional parameter which indicates that the entire list of data dirs should be returned.
By default, the first item would only be returned. By default, the first item would only be returned.
""" """
self.opinion = opinion #: A flag to indicating to use opinionated values. self.opinion = opinion #: A flag to indicating to use opinionated values.
self.ensure_exists = ensure_exists self.ensure_exists = ensure_exists
""" """
Optionally create the directory (and any missing parents) upon access if it does not exist. Optionally create the directory (and any missing parents) upon access if it does not exist.
By default, no directories are created. By default, no directories are created.
""" """
def _append_app_name_and_version(self, *base: str) -> str: def _append_app_name_and_version(self, *base: str) -> str:
@ -200,7 +213,7 @@ class PlatformDirsABC(ABC):
@property @property
def user_documents_path(self) -> Path: def user_documents_path(self) -> Path:
""":return: documents path tied to the user""" """:return: documents a path tied to the user"""
return Path(self.user_documents_dir) return Path(self.user_documents_dir)
@property @property

View file

@ -10,11 +10,14 @@ from .api import PlatformDirsABC
class MacOS(PlatformDirsABC): class MacOS(PlatformDirsABC):
""" """
Platform directories for the macOS operating system. Follows the guidance from `Apple documentation Platform directories for the macOS operating system.
<https://developer.apple.com/library/archive/documentation/FileManagement/Conceptual/FileSystemProgrammingGuide/MacOSXDirectories/MacOSXDirectories.html>`_.
Follows the guidance from
`Apple documentation <https://developer.apple.com/library/archive/documentation/FileManagement/Conceptual/FileSystemProgrammingGuide/MacOSXDirectories/MacOSXDirectories.html>`_.
Makes use of the `appname <platformdirs.api.PlatformDirsABC.appname>`, Makes use of the `appname <platformdirs.api.PlatformDirsABC.appname>`,
`version <platformdirs.api.PlatformDirsABC.version>`, `version <platformdirs.api.PlatformDirsABC.version>`,
`ensure_exists <platformdirs.api.PlatformDirsABC.ensure_exists>`. `ensure_exists <platformdirs.api.PlatformDirsABC.ensure_exists>`.
""" """
@property @property
@ -28,7 +31,7 @@ class MacOS(PlatformDirsABC):
:return: data directory shared by users, e.g. ``/Library/Application Support/$appname/$version``. :return: data directory shared by users, e.g. ``/Library/Application Support/$appname/$version``.
If we're using a Python binary managed by `Homebrew <https://brew.sh>`_, the directory If we're using a Python binary managed by `Homebrew <https://brew.sh>`_, the directory
will be under the Homebrew prefix, e.g. ``/opt/homebrew/share/$appname/$version``. will be under the Homebrew prefix, e.g. ``/opt/homebrew/share/$appname/$version``.
If `multipath <platformdirs.api.PlatformDirsABC.multipath>` is enabled and we're in Homebrew, If `multipath <platformdirs.api.PlatformDirsABC.multipath>` is enabled, and we're in Homebrew,
the response is a multi-path string separated by ":", e.g. the response is a multi-path string separated by ":", e.g.
``/opt/homebrew/share/$appname/$version:/Library/Application Support/$appname/$version`` ``/opt/homebrew/share/$appname/$version:/Library/Application Support/$appname/$version``
""" """
@ -60,7 +63,7 @@ class MacOS(PlatformDirsABC):
:return: cache directory shared by users, e.g. ``/Library/Caches/$appname/$version``. :return: cache directory shared by users, e.g. ``/Library/Caches/$appname/$version``.
If we're using a Python binary managed by `Homebrew <https://brew.sh>`_, the directory If we're using a Python binary managed by `Homebrew <https://brew.sh>`_, the directory
will be under the Homebrew prefix, e.g. ``/opt/homebrew/var/cache/$appname/$version``. will be under the Homebrew prefix, e.g. ``/opt/homebrew/var/cache/$appname/$version``.
If `multipath <platformdirs.api.PlatformDirsABC.multipath>` is enabled and we're in Homebrew, If `multipath <platformdirs.api.PlatformDirsABC.multipath>` is enabled, and we're in Homebrew,
the response is a multi-path string separated by ":", e.g. the response is a multi-path string separated by ":", e.g.
``/opt/homebrew/var/cache/$appname/$version:/Library/Caches/$appname/$version`` ``/opt/homebrew/var/cache/$appname/$version:/Library/Caches/$appname/$version``
""" """

View file

@ -6,13 +6,13 @@ import os
import sys import sys
from configparser import ConfigParser from configparser import ConfigParser
from pathlib import Path from pathlib import Path
from typing import Iterator from typing import Iterator, NoReturn
from .api import PlatformDirsABC from .api import PlatformDirsABC
if sys.platform == "win32": if sys.platform == "win32":
def getuid() -> int: def getuid() -> NoReturn:
msg = "should only be used on Unix" msg = "should only be used on Unix"
raise RuntimeError(msg) raise RuntimeError(msg)
@ -20,17 +20,17 @@ else:
from os import getuid from os import getuid
class Unix(PlatformDirsABC): class Unix(PlatformDirsABC): # noqa: PLR0904
""" """
On Unix/Linux, we follow the On Unix/Linux, we follow the `XDG Basedir Spec <https://specifications.freedesktop.org/basedir-spec/basedir-spec-
`XDG Basedir Spec <https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html>`_. The spec allows latest.html>`_.
overriding directories with environment variables. The examples show are the default values, alongside the name of
the environment variable that overrides them. Makes use of the The spec allows overriding directories with environment variables. The examples shown are the default values,
`appname <platformdirs.api.PlatformDirsABC.appname>`, alongside the name of the environment variable that overrides them. Makes use of the `appname
`version <platformdirs.api.PlatformDirsABC.version>`, <platformdirs.api.PlatformDirsABC.appname>`, `version <platformdirs.api.PlatformDirsABC.version>`, `multipath
`multipath <platformdirs.api.PlatformDirsABC.multipath>`, <platformdirs.api.PlatformDirsABC.multipath>`, `opinion <platformdirs.api.PlatformDirsABC.opinion>`, `ensure_exists
`opinion <platformdirs.api.PlatformDirsABC.opinion>`, <platformdirs.api.PlatformDirsABC.ensure_exists>`.
`ensure_exists <platformdirs.api.PlatformDirsABC.ensure_exists>`.
""" """
@property @property
@ -205,17 +205,17 @@ class Unix(PlatformDirsABC):
@property @property
def site_data_path(self) -> Path: def site_data_path(self) -> Path:
""":return: data path shared by users. Only return first item, even if ``multipath`` is set to ``True``""" """:return: data path shared by users. Only return the first item, even if ``multipath`` is set to ``True``"""
return self._first_item_as_path_if_multipath(self.site_data_dir) return self._first_item_as_path_if_multipath(self.site_data_dir)
@property @property
def site_config_path(self) -> Path: def site_config_path(self) -> Path:
""":return: config path shared by the users. Only return first item, even if ``multipath`` is set to ``True``""" """:return: config path shared by the users, returns the first item, even if ``multipath`` is set to ``True``"""
return self._first_item_as_path_if_multipath(self.site_config_dir) return self._first_item_as_path_if_multipath(self.site_config_dir)
@property @property
def site_cache_path(self) -> Path: def site_cache_path(self) -> Path:
""":return: cache path shared by users. Only return first item, even if ``multipath`` is set to ``True``""" """:return: cache path shared by users. Only return the first item, even if ``multipath`` is set to ``True``"""
return self._first_item_as_path_if_multipath(self.site_cache_dir) return self._first_item_as_path_if_multipath(self.site_cache_dir)
def _first_item_as_path_if_multipath(self, directory: str) -> Path: def _first_item_as_path_if_multipath(self, directory: str) -> Path:
@ -246,7 +246,12 @@ def _get_user_media_dir(env_var: str, fallback_tilde_path: str) -> str:
def _get_user_dirs_folder(key: str) -> str | None: def _get_user_dirs_folder(key: str) -> str | None:
"""Return directory from user-dirs.dirs config file. See https://freedesktop.org/wiki/Software/xdg-user-dirs/.""" """
Return directory from user-dirs.dirs config file.
See https://freedesktop.org/wiki/Software/xdg-user-dirs/.
"""
user_dirs_config_path = Path(Unix().user_config_dir) / "user-dirs.dirs" user_dirs_config_path = Path(Unix().user_config_dir) / "user-dirs.dirs"
if user_dirs_config_path.exists(): if user_dirs_config_path.exists():
parser = ConfigParser() parser = ConfigParser()

View file

@ -12,5 +12,5 @@ __version__: str
__version_tuple__: VERSION_TUPLE __version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE version_tuple: VERSION_TUPLE
__version__ = version = '4.2.0' __version__ = version = '4.2.1'
__version_tuple__ = version_tuple = (4, 2, 0) __version_tuple__ = version_tuple = (4, 2, 1)

View file

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
import ctypes
import os import os
import sys import sys
from functools import lru_cache from functools import lru_cache
@ -16,15 +15,13 @@ if TYPE_CHECKING:
class Windows(PlatformDirsABC): class Windows(PlatformDirsABC):
""" """
`MSDN on where to store app data files `MSDN on where to store app data files <https://learn.microsoft.com/en-us/windows/win32/shell/knownfolderid>`_.
<http://support.microsoft.com/default.aspx?scid=kb;en-us;310294#XSLTH3194121123120121120120>`_.
Makes use of the Makes use of the `appname <platformdirs.api.PlatformDirsABC.appname>`, `appauthor
`appname <platformdirs.api.PlatformDirsABC.appname>`, <platformdirs.api.PlatformDirsABC.appauthor>`, `version <platformdirs.api.PlatformDirsABC.version>`, `roaming
`appauthor <platformdirs.api.PlatformDirsABC.appauthor>`, <platformdirs.api.PlatformDirsABC.roaming>`, `opinion <platformdirs.api.PlatformDirsABC.opinion>`, `ensure_exists
`version <platformdirs.api.PlatformDirsABC.version>`, <platformdirs.api.PlatformDirsABC.ensure_exists>`.
`roaming <platformdirs.api.PlatformDirsABC.roaming>`,
`opinion <platformdirs.api.PlatformDirsABC.opinion>`,
`ensure_exists <platformdirs.api.PlatformDirsABC.ensure_exists>`.
""" """
@property @property
@ -165,7 +162,7 @@ def get_win_folder_from_env_vars(csidl_name: str) -> str:
def get_win_folder_if_csidl_name_not_env_var(csidl_name: str) -> str | None: def get_win_folder_if_csidl_name_not_env_var(csidl_name: str) -> str | None:
"""Get folder for a CSIDL name that does not exist as an environment variable.""" """Get a folder for a CSIDL name that does not exist as an environment variable."""
if csidl_name == "CSIDL_PERSONAL": if csidl_name == "CSIDL_PERSONAL":
return os.path.join(os.path.normpath(os.environ["USERPROFILE"]), "Documents") # noqa: PTH118 return os.path.join(os.path.normpath(os.environ["USERPROFILE"]), "Documents") # noqa: PTH118
@ -189,6 +186,7 @@ def get_win_folder_from_registry(csidl_name: str) -> str:
This is a fallback technique at best. I'm not sure if using the registry for these guarantees us the correct answer This is a fallback technique at best. I'm not sure if using the registry for these guarantees us the correct answer
for all CSIDL_* names. for all CSIDL_* names.
""" """
shell_folder_name = { shell_folder_name = {
"CSIDL_APPDATA": "AppData", "CSIDL_APPDATA": "AppData",
@ -205,7 +203,7 @@ def get_win_folder_from_registry(csidl_name: str) -> str:
raise ValueError(msg) raise ValueError(msg)
if sys.platform != "win32": # only needed for mypy type checker to know that this code runs only on Windows if sys.platform != "win32": # only needed for mypy type checker to know that this code runs only on Windows
raise NotImplementedError raise NotImplementedError
import winreg import winreg # noqa: PLC0415
key = winreg.OpenKey(winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders") key = winreg.OpenKey(winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders")
directory, _ = winreg.QueryValueEx(key, shell_folder_name) directory, _ = winreg.QueryValueEx(key, shell_folder_name)
@ -218,6 +216,8 @@ def get_win_folder_via_ctypes(csidl_name: str) -> str:
# Use 'CSIDL_PROFILE' (40) and append the default folder 'Downloads' instead. # Use 'CSIDL_PROFILE' (40) and append the default folder 'Downloads' instead.
# https://learn.microsoft.com/en-us/windows/win32/shell/knownfolderid # https://learn.microsoft.com/en-us/windows/win32/shell/knownfolderid
import ctypes # noqa: PLC0415
csidl_const = { csidl_const = {
"CSIDL_APPDATA": 26, "CSIDL_APPDATA": 26,
"CSIDL_COMMON_APPDATA": 35, "CSIDL_COMMON_APPDATA": 35,
@ -250,10 +250,15 @@ def get_win_folder_via_ctypes(csidl_name: str) -> str:
def _pick_get_win_folder() -> Callable[[str], str]: def _pick_get_win_folder() -> Callable[[str], str]:
try:
import ctypes # noqa: PLC0415
except ImportError:
pass
else:
if hasattr(ctypes, "windll"): if hasattr(ctypes, "windll"):
return get_win_folder_via_ctypes return get_win_folder_via_ctypes
try: try:
import winreg # noqa: F401 import winreg # noqa: PLC0415, F401
except ImportError: except ImportError:
return get_win_folder_from_env_vars return get_win_folder_from_env_vars
else: else:

View file

@ -170,7 +170,16 @@ class PlexObject:
elem = ElementTree.fromstring(xml) elem = ElementTree.fromstring(xml)
return self._buildItemOrNone(elem, cls) return self._buildItemOrNone(elem, cls)
def fetchItems(self, ekey, cls=None, container_start=None, container_size=None, maxresults=None, **kwargs): def fetchItems(
self,
ekey,
cls=None,
container_start=None,
container_size=None,
maxresults=None,
params=None,
**kwargs,
):
""" Load the specified key to find and build all items with the specified tag """ Load the specified key to find and build all items with the specified tag
and attrs. and attrs.
@ -186,6 +195,7 @@ class PlexObject:
container_start (None, int): offset to get a subset of the data container_start (None, int): offset to get a subset of the data
container_size (None, int): How many items in data container_size (None, int): How many items in data
maxresults (int, optional): Only return the specified number of results. maxresults (int, optional): Only return the specified number of results.
params (dict, optional): Any additional params to add to the request.
**kwargs (dict): Optionally add XML attribute to filter the items. **kwargs (dict): Optionally add XML attribute to filter the items.
See the details below for more info. See the details below for more info.
@ -268,7 +278,7 @@ class PlexObject:
headers['X-Plex-Container-Start'] = str(container_start) headers['X-Plex-Container-Start'] = str(container_start)
headers['X-Plex-Container-Size'] = str(container_size) headers['X-Plex-Container-Size'] = str(container_size)
data = self._server.query(ekey, headers=headers) data = self._server.query(ekey, headers=headers, params=params)
subresults = self.findItems(data, cls, ekey, **kwargs) subresults = self.findItems(data, cls, ekey, **kwargs)
total_size = utils.cast(int, data.attrib.get('totalSize') or data.attrib.get('size')) or len(subresults) total_size = utils.cast(int, data.attrib.get('totalSize') or data.attrib.get('size')) or len(subresults)
@ -283,6 +293,11 @@ class PlexObject:
results.extend(subresults) results.extend(subresults)
container_start += container_size
if container_start > total_size:
break
wanted_number_of_items = total_size - offset wanted_number_of_items = total_size - offset
if maxresults is not None: if maxresults is not None:
wanted_number_of_items = min(maxresults, wanted_number_of_items) wanted_number_of_items = min(maxresults, wanted_number_of_items)
@ -291,11 +306,6 @@ class PlexObject:
if wanted_number_of_items <= len(results): if wanted_number_of_items <= len(results):
break break
container_start += container_size
if container_start > total_size:
break
return results return results
def fetchItem(self, ekey, cls=None, **kwargs): def fetchItem(self, ekey, cls=None, **kwargs):
@ -337,7 +347,7 @@ class PlexObject:
kwargs['type'] = cls.TYPE kwargs['type'] = cls.TYPE
# rtag to iter on a specific root tag using breadth-first search # rtag to iter on a specific root tag using breadth-first search
if rtag: if rtag:
data = next(utils.iterXMLBFS(data, rtag), []) data = next(utils.iterXMLBFS(data, rtag), Element('Empty'))
# loop through all data elements to find matches # loop through all data elements to find matches
items = MediaContainer[cls](self._server, data, initpath=initpath) if data.tag == 'MediaContainer' else [] items = MediaContainer[cls](self._server, data, initpath=initpath) if data.tag == 'MediaContainer' else []
for elem in data: for elem in data:

View file

@ -4,6 +4,6 @@
# Library version # Library version
MAJOR_VERSION = 4 MAJOR_VERSION = 4
MINOR_VERSION = 15 MINOR_VERSION = 15
PATCH_VERSION = 11 PATCH_VERSION = 12
__short_version__ = f"{MAJOR_VERSION}.{MINOR_VERSION}" __short_version__ = f"{MAJOR_VERSION}.{MINOR_VERSION}"
__version__ = f"{__short_version__}.{PATCH_VERSION}" __version__ = f"{__short_version__}.{PATCH_VERSION}"

View file

@ -746,7 +746,7 @@ class PlexServer(PlexObject):
""" Returns list of all :class:`~plexapi.media.TranscodeJob` objects running or paused on server. """ """ Returns list of all :class:`~plexapi.media.TranscodeJob` objects running or paused on server. """
return self.fetchItems('/status/sessions/background') return self.fetchItems('/status/sessions/background')
def query(self, key, method=None, headers=None, timeout=None, **kwargs): def query(self, key, method=None, headers=None, params=None, timeout=None, **kwargs):
""" Main method used to handle HTTPS requests to the Plex server. This method helps """ Main method used to handle HTTPS requests to the Plex server. This method helps
by encoding the response to utf-8 and parsing the returned XML into and by encoding the response to utf-8 and parsing the returned XML into and
ElementTree object. Returns None if no data exists in the response. ElementTree object. Returns None if no data exists in the response.
@ -756,7 +756,7 @@ class PlexServer(PlexObject):
timeout = timeout or self._timeout timeout = timeout or self._timeout
log.debug('%s %s', method.__name__.upper(), url) log.debug('%s %s', method.__name__.upper(), url)
headers = self._headers(**headers or {}) headers = self._headers(**headers or {})
response = method(url, headers=headers, timeout=timeout, **kwargs) response = method(url, headers=headers, params=params, timeout=timeout, **kwargs)
if response.status_code not in (200, 201, 204): if response.status_code not in (200, 201, 204):
codename = codes.get(response.status_code)[0] codename = codes.get(response.status_code)[0]
errtext = response.text.replace('\n', ' ') errtext = response.text.replace('\n', ' ')

48
lib/typeguard/__init__.py Normal file
View file

@ -0,0 +1,48 @@
import os
from typing import Any
from ._checkers import TypeCheckerCallable as TypeCheckerCallable
from ._checkers import TypeCheckLookupCallback as TypeCheckLookupCallback
from ._checkers import check_type_internal as check_type_internal
from ._checkers import checker_lookup_functions as checker_lookup_functions
from ._checkers import load_plugins as load_plugins
from ._config import CollectionCheckStrategy as CollectionCheckStrategy
from ._config import ForwardRefPolicy as ForwardRefPolicy
from ._config import TypeCheckConfiguration as TypeCheckConfiguration
from ._decorators import typechecked as typechecked
from ._decorators import typeguard_ignore as typeguard_ignore
from ._exceptions import InstrumentationWarning as InstrumentationWarning
from ._exceptions import TypeCheckError as TypeCheckError
from ._exceptions import TypeCheckWarning as TypeCheckWarning
from ._exceptions import TypeHintWarning as TypeHintWarning
from ._functions import TypeCheckFailCallback as TypeCheckFailCallback
from ._functions import check_type as check_type
from ._functions import warn_on_error as warn_on_error
from ._importhook import ImportHookManager as ImportHookManager
from ._importhook import TypeguardFinder as TypeguardFinder
from ._importhook import install_import_hook as install_import_hook
from ._memo import TypeCheckMemo as TypeCheckMemo
from ._suppression import suppress_type_checks as suppress_type_checks
from ._utils import Unset as Unset
# Re-export imports so they look like they live directly in this package
for value in list(locals().values()):
if getattr(value, "__module__", "").startswith(f"{__name__}."):
value.__module__ = __name__
config: TypeCheckConfiguration
def __getattr__(name: str) -> Any:
if name == "config":
from ._config import global_config
return global_config
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
# Automatically load checker lookup functions unless explicitly disabled
if "TYPEGUARD_DISABLE_PLUGIN_AUTOLOAD" not in os.environ:
load_plugins()

910
lib/typeguard/_checkers.py Normal file
View file

@ -0,0 +1,910 @@
from __future__ import annotations
import collections.abc
import inspect
import sys
import types
import typing
import warnings
from enum import Enum
from inspect import Parameter, isclass, isfunction
from io import BufferedIOBase, IOBase, RawIOBase, TextIOBase
from textwrap import indent
from typing import (
IO,
AbstractSet,
Any,
BinaryIO,
Callable,
Dict,
ForwardRef,
List,
Mapping,
MutableMapping,
NewType,
Optional,
Sequence,
Set,
TextIO,
Tuple,
Type,
TypeVar,
Union,
)
from unittest.mock import Mock
try:
import typing_extensions
except ImportError:
typing_extensions = None # type: ignore[assignment]
from ._config import ForwardRefPolicy
from ._exceptions import TypeCheckError, TypeHintWarning
from ._memo import TypeCheckMemo
from ._utils import evaluate_forwardref, get_stacklevel, get_type_name, qualified_name
if sys.version_info >= (3, 13):
from typing import is_typeddict
else:
from typing_extensions import is_typeddict
if sys.version_info >= (3, 11):
from typing import (
Annotated,
NotRequired,
TypeAlias,
get_args,
get_origin,
)
SubclassableAny = Any
else:
from typing_extensions import (
Annotated,
NotRequired,
TypeAlias,
get_args,
get_origin,
)
from typing_extensions import Any as SubclassableAny
if sys.version_info >= (3, 10):
from importlib.metadata import entry_points
from typing import ParamSpec
else:
from importlib_metadata import entry_points
from typing_extensions import ParamSpec
TypeCheckerCallable: TypeAlias = Callable[
[Any, Any, Tuple[Any, ...], TypeCheckMemo], Any
]
TypeCheckLookupCallback: TypeAlias = Callable[
[Any, Tuple[Any, ...], Tuple[Any, ...]], Optional[TypeCheckerCallable]
]
checker_lookup_functions: list[TypeCheckLookupCallback] = []
generic_alias_types: tuple[type, ...] = (type(List), type(List[Any]))
if sys.version_info >= (3, 9):
generic_alias_types += (types.GenericAlias,)
# Sentinel
_missing = object()
# Lifted from mypy.sharedparse
BINARY_MAGIC_METHODS = {
"__add__",
"__and__",
"__cmp__",
"__divmod__",
"__div__",
"__eq__",
"__floordiv__",
"__ge__",
"__gt__",
"__iadd__",
"__iand__",
"__idiv__",
"__ifloordiv__",
"__ilshift__",
"__imatmul__",
"__imod__",
"__imul__",
"__ior__",
"__ipow__",
"__irshift__",
"__isub__",
"__itruediv__",
"__ixor__",
"__le__",
"__lshift__",
"__lt__",
"__matmul__",
"__mod__",
"__mul__",
"__ne__",
"__or__",
"__pow__",
"__radd__",
"__rand__",
"__rdiv__",
"__rfloordiv__",
"__rlshift__",
"__rmatmul__",
"__rmod__",
"__rmul__",
"__ror__",
"__rpow__",
"__rrshift__",
"__rshift__",
"__rsub__",
"__rtruediv__",
"__rxor__",
"__sub__",
"__truediv__",
"__xor__",
}
def check_callable(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if not callable(value):
raise TypeCheckError("is not callable")
if args:
try:
signature = inspect.signature(value)
except (TypeError, ValueError):
return
argument_types = args[0]
if isinstance(argument_types, list) and not any(
type(item) is ParamSpec for item in argument_types
):
# The callable must not have keyword-only arguments without defaults
unfulfilled_kwonlyargs = [
param.name
for param in signature.parameters.values()
if param.kind == Parameter.KEYWORD_ONLY
and param.default == Parameter.empty
]
if unfulfilled_kwonlyargs:
raise TypeCheckError(
f"has mandatory keyword-only arguments in its declaration: "
f'{", ".join(unfulfilled_kwonlyargs)}'
)
num_positional_args = num_mandatory_pos_args = 0
has_varargs = False
for param in signature.parameters.values():
if param.kind in (
Parameter.POSITIONAL_ONLY,
Parameter.POSITIONAL_OR_KEYWORD,
):
num_positional_args += 1
if param.default is Parameter.empty:
num_mandatory_pos_args += 1
elif param.kind == Parameter.VAR_POSITIONAL:
has_varargs = True
if num_mandatory_pos_args > len(argument_types):
raise TypeCheckError(
f"has too many mandatory positional arguments in its declaration; "
f"expected {len(argument_types)} but {num_mandatory_pos_args} "
f"mandatory positional argument(s) declared"
)
elif not has_varargs and num_positional_args < len(argument_types):
raise TypeCheckError(
f"has too few arguments in its declaration; expected "
f"{len(argument_types)} but {num_positional_args} argument(s) "
f"declared"
)
def check_mapping(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if origin_type is Dict or origin_type is dict:
if not isinstance(value, dict):
raise TypeCheckError("is not a dict")
if origin_type is MutableMapping or origin_type is collections.abc.MutableMapping:
if not isinstance(value, collections.abc.MutableMapping):
raise TypeCheckError("is not a mutable mapping")
elif not isinstance(value, collections.abc.Mapping):
raise TypeCheckError("is not a mapping")
if args:
key_type, value_type = args
if key_type is not Any or value_type is not Any:
samples = memo.config.collection_check_strategy.iterate_samples(
value.items()
)
for k, v in samples:
try:
check_type_internal(k, key_type, memo)
except TypeCheckError as exc:
exc.append_path_element(f"key {k!r}")
raise
try:
check_type_internal(v, value_type, memo)
except TypeCheckError as exc:
exc.append_path_element(f"value of key {k!r}")
raise
def check_typed_dict(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if not isinstance(value, dict):
raise TypeCheckError("is not a dict")
declared_keys = frozenset(origin_type.__annotations__)
if hasattr(origin_type, "__required_keys__"):
required_keys = set(origin_type.__required_keys__)
else: # py3.8 and lower
required_keys = set(declared_keys) if origin_type.__total__ else set()
existing_keys = set(value)
extra_keys = existing_keys - declared_keys
if extra_keys:
keys_formatted = ", ".join(f'"{key}"' for key in sorted(extra_keys, key=repr))
raise TypeCheckError(f"has unexpected extra key(s): {keys_formatted}")
# Detect NotRequired fields which are hidden by get_type_hints()
type_hints: dict[str, type] = {}
for key, annotation in origin_type.__annotations__.items():
if isinstance(annotation, ForwardRef):
annotation = evaluate_forwardref(annotation, memo)
if get_origin(annotation) is NotRequired:
required_keys.discard(key)
annotation = get_args(annotation)[0]
type_hints[key] = annotation
missing_keys = required_keys - existing_keys
if missing_keys:
keys_formatted = ", ".join(f'"{key}"' for key in sorted(missing_keys, key=repr))
raise TypeCheckError(f"is missing required key(s): {keys_formatted}")
for key, argtype in type_hints.items():
argvalue = value.get(key, _missing)
if argvalue is not _missing:
try:
check_type_internal(argvalue, argtype, memo)
except TypeCheckError as exc:
exc.append_path_element(f"value of key {key!r}")
raise
def check_list(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if not isinstance(value, list):
raise TypeCheckError("is not a list")
if args and args != (Any,):
samples = memo.config.collection_check_strategy.iterate_samples(value)
for i, v in enumerate(samples):
try:
check_type_internal(v, args[0], memo)
except TypeCheckError as exc:
exc.append_path_element(f"item {i}")
raise
def check_sequence(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if not isinstance(value, collections.abc.Sequence):
raise TypeCheckError("is not a sequence")
if args and args != (Any,):
samples = memo.config.collection_check_strategy.iterate_samples(value)
for i, v in enumerate(samples):
try:
check_type_internal(v, args[0], memo)
except TypeCheckError as exc:
exc.append_path_element(f"item {i}")
raise
def check_set(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if origin_type is frozenset:
if not isinstance(value, frozenset):
raise TypeCheckError("is not a frozenset")
elif not isinstance(value, AbstractSet):
raise TypeCheckError("is not a set")
if args and args != (Any,):
samples = memo.config.collection_check_strategy.iterate_samples(value)
for v in samples:
try:
check_type_internal(v, args[0], memo)
except TypeCheckError as exc:
exc.append_path_element(f"[{v}]")
raise
def check_tuple(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
# Specialized check for NamedTuples
if field_types := getattr(origin_type, "__annotations__", None):
if not isinstance(value, origin_type):
raise TypeCheckError(
f"is not a named tuple of type {qualified_name(origin_type)}"
)
for name, field_type in field_types.items():
try:
check_type_internal(getattr(value, name), field_type, memo)
except TypeCheckError as exc:
exc.append_path_element(f"attribute {name!r}")
raise
return
elif not isinstance(value, tuple):
raise TypeCheckError("is not a tuple")
if args:
use_ellipsis = args[-1] is Ellipsis
tuple_params = args[: -1 if use_ellipsis else None]
else:
# Unparametrized Tuple or plain tuple
return
if use_ellipsis:
element_type = tuple_params[0]
samples = memo.config.collection_check_strategy.iterate_samples(value)
for i, element in enumerate(samples):
try:
check_type_internal(element, element_type, memo)
except TypeCheckError as exc:
exc.append_path_element(f"item {i}")
raise
elif tuple_params == ((),):
if value != ():
raise TypeCheckError("is not an empty tuple")
else:
if len(value) != len(tuple_params):
raise TypeCheckError(
f"has wrong number of elements (expected {len(tuple_params)}, got "
f"{len(value)} instead)"
)
for i, (element, element_type) in enumerate(zip(value, tuple_params)):
try:
check_type_internal(element, element_type, memo)
except TypeCheckError as exc:
exc.append_path_element(f"item {i}")
raise
def check_union(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
errors: dict[str, TypeCheckError] = {}
try:
for type_ in args:
try:
check_type_internal(value, type_, memo)
return
except TypeCheckError as exc:
errors[get_type_name(type_)] = exc
formatted_errors = indent(
"\n".join(f"{key}: {error}" for key, error in errors.items()), " "
)
finally:
del errors # avoid creating ref cycle
raise TypeCheckError(f"did not match any element in the union:\n{formatted_errors}")
def check_uniontype(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
errors: dict[str, TypeCheckError] = {}
for type_ in args:
try:
check_type_internal(value, type_, memo)
return
except TypeCheckError as exc:
errors[get_type_name(type_)] = exc
formatted_errors = indent(
"\n".join(f"{key}: {error}" for key, error in errors.items()), " "
)
raise TypeCheckError(f"did not match any element in the union:\n{formatted_errors}")
def check_class(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if not isclass(value) and not isinstance(value, generic_alias_types):
raise TypeCheckError("is not a class")
if not args:
return
if isinstance(args[0], ForwardRef):
expected_class = evaluate_forwardref(args[0], memo)
else:
expected_class = args[0]
if expected_class is Any:
return
elif getattr(expected_class, "_is_protocol", False):
check_protocol(value, expected_class, (), memo)
elif isinstance(expected_class, TypeVar):
check_typevar(value, expected_class, (), memo, subclass_check=True)
elif get_origin(expected_class) is Union:
errors: dict[str, TypeCheckError] = {}
for arg in get_args(expected_class):
if arg is Any:
return
try:
check_class(value, type, (arg,), memo)
return
except TypeCheckError as exc:
errors[get_type_name(arg)] = exc
else:
formatted_errors = indent(
"\n".join(f"{key}: {error}" for key, error in errors.items()), " "
)
raise TypeCheckError(
f"did not match any element in the union:\n{formatted_errors}"
)
elif not issubclass(value, expected_class): # type: ignore[arg-type]
raise TypeCheckError(f"is not a subclass of {qualified_name(expected_class)}")
def check_newtype(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
check_type_internal(value, origin_type.__supertype__, memo)
def check_instance(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if not isinstance(value, origin_type):
raise TypeCheckError(f"is not an instance of {qualified_name(origin_type)}")
def check_typevar(
value: Any,
origin_type: TypeVar,
args: tuple[Any, ...],
memo: TypeCheckMemo,
*,
subclass_check: bool = False,
) -> None:
if origin_type.__bound__ is not None:
annotation = (
Type[origin_type.__bound__] if subclass_check else origin_type.__bound__
)
check_type_internal(value, annotation, memo)
elif origin_type.__constraints__:
for constraint in origin_type.__constraints__:
annotation = Type[constraint] if subclass_check else constraint
try:
check_type_internal(value, annotation, memo)
except TypeCheckError:
pass
else:
break
else:
formatted_constraints = ", ".join(
get_type_name(constraint) for constraint in origin_type.__constraints__
)
raise TypeCheckError(
f"does not match any of the constraints " f"({formatted_constraints})"
)
if typing_extensions is None:
def _is_literal_type(typ: object) -> bool:
return typ is typing.Literal
else:
def _is_literal_type(typ: object) -> bool:
return typ is typing.Literal or typ is typing_extensions.Literal
def check_literal(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
def get_literal_args(literal_args: tuple[Any, ...]) -> tuple[Any, ...]:
retval: list[Any] = []
for arg in literal_args:
if _is_literal_type(get_origin(arg)):
retval.extend(get_literal_args(arg.__args__))
elif arg is None or isinstance(arg, (int, str, bytes, bool, Enum)):
retval.append(arg)
else:
raise TypeError(
f"Illegal literal value: {arg}"
) # TypeError here is deliberate
return tuple(retval)
final_args = tuple(get_literal_args(args))
try:
index = final_args.index(value)
except ValueError:
pass
else:
if type(final_args[index]) is type(value):
return
formatted_args = ", ".join(repr(arg) for arg in final_args)
raise TypeCheckError(f"is not any of ({formatted_args})") from None
def check_literal_string(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
check_type_internal(value, str, memo)
def check_typeguard(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
check_type_internal(value, bool, memo)
def check_none(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if value is not None:
raise TypeCheckError("is not None")
def check_number(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if origin_type is complex and not isinstance(value, (complex, float, int)):
raise TypeCheckError("is neither complex, float or int")
elif origin_type is float and not isinstance(value, (float, int)):
raise TypeCheckError("is neither float or int")
def check_io(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if origin_type is TextIO or (origin_type is IO and args == (str,)):
if not isinstance(value, TextIOBase):
raise TypeCheckError("is not a text based I/O object")
elif origin_type is BinaryIO or (origin_type is IO and args == (bytes,)):
if not isinstance(value, (RawIOBase, BufferedIOBase)):
raise TypeCheckError("is not a binary I/O object")
elif not isinstance(value, IOBase):
raise TypeCheckError("is not an I/O object")
def check_protocol(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
# TODO: implement proper compatibility checking and support non-runtime protocols
if getattr(origin_type, "_is_runtime_protocol", False):
if not isinstance(value, origin_type):
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} protocol"
)
else:
warnings.warn(
f"Typeguard cannot check the {origin_type.__qualname__} protocol because "
f"it is a non-runtime protocol. If you would like to type check this "
f"protocol, please use @typing.runtime_checkable",
stacklevel=get_stacklevel(),
)
def check_byteslike(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if not isinstance(value, (bytearray, bytes, memoryview)):
raise TypeCheckError("is not bytes-like")
def check_self(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if memo.self_type is None:
raise TypeCheckError("cannot be checked against Self outside of a method call")
if isclass(value):
if not issubclass(value, memo.self_type):
raise TypeCheckError(
f"is not an instance of the self type "
f"({qualified_name(memo.self_type)})"
)
elif not isinstance(value, memo.self_type):
raise TypeCheckError(
f"is not an instance of the self type ({qualified_name(memo.self_type)})"
)
def check_paramspec(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
pass # No-op for now
def check_instanceof(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if not isinstance(value, origin_type):
raise TypeCheckError(f"is not an instance of {qualified_name(origin_type)}")
def check_type_internal(
value: Any,
annotation: Any,
memo: TypeCheckMemo,
) -> None:
"""
Check that the given object is compatible with the given type annotation.
This function should only be used by type checker callables. Applications should use
:func:`~.check_type` instead.
:param value: the value to check
:param annotation: the type annotation to check against
:param memo: a memo object containing configuration and information necessary for
looking up forward references
"""
if isinstance(annotation, ForwardRef):
try:
annotation = evaluate_forwardref(annotation, memo)
except NameError:
if memo.config.forward_ref_policy is ForwardRefPolicy.ERROR:
raise
elif memo.config.forward_ref_policy is ForwardRefPolicy.WARN:
warnings.warn(
f"Cannot resolve forward reference {annotation.__forward_arg__!r}",
TypeHintWarning,
stacklevel=get_stacklevel(),
)
return
if annotation is Any or annotation is SubclassableAny or isinstance(value, Mock):
return
# Skip type checks if value is an instance of a class that inherits from Any
if not isclass(value) and SubclassableAny in type(value).__bases__:
return
extras: tuple[Any, ...]
origin_type = get_origin(annotation)
if origin_type is Annotated:
annotation, *extras_ = get_args(annotation)
extras = tuple(extras_)
origin_type = get_origin(annotation)
else:
extras = ()
if origin_type is not None:
args = get_args(annotation)
# Compatibility hack to distinguish between unparametrized and empty tuple
# (tuple[()]), necessary due to https://github.com/python/cpython/issues/91137
if origin_type in (tuple, Tuple) and annotation is not Tuple and not args:
args = ((),)
else:
origin_type = annotation
args = ()
for lookup_func in checker_lookup_functions:
checker = lookup_func(origin_type, args, extras)
if checker:
checker(value, origin_type, args, memo)
return
if isclass(origin_type):
if not isinstance(value, origin_type):
raise TypeCheckError(f"is not an instance of {qualified_name(origin_type)}")
elif type(origin_type) is str: # noqa: E721
warnings.warn(
f"Skipping type check against {origin_type!r}; this looks like a "
f"string-form forward reference imported from another module",
TypeHintWarning,
stacklevel=get_stacklevel(),
)
# Equality checks are applied to these
origin_type_checkers = {
bytes: check_byteslike,
AbstractSet: check_set,
BinaryIO: check_io,
Callable: check_callable,
collections.abc.Callable: check_callable,
complex: check_number,
dict: check_mapping,
Dict: check_mapping,
float: check_number,
frozenset: check_set,
IO: check_io,
list: check_list,
List: check_list,
typing.Literal: check_literal,
Mapping: check_mapping,
MutableMapping: check_mapping,
None: check_none,
collections.abc.Mapping: check_mapping,
collections.abc.MutableMapping: check_mapping,
Sequence: check_sequence,
collections.abc.Sequence: check_sequence,
collections.abc.Set: check_set,
set: check_set,
Set: check_set,
TextIO: check_io,
tuple: check_tuple,
Tuple: check_tuple,
type: check_class,
Type: check_class,
Union: check_union,
}
if sys.version_info >= (3, 10):
origin_type_checkers[types.UnionType] = check_uniontype
origin_type_checkers[typing.TypeGuard] = check_typeguard
if sys.version_info >= (3, 11):
origin_type_checkers.update(
{typing.LiteralString: check_literal_string, typing.Self: check_self}
)
if typing_extensions is not None:
# On some Python versions, these may simply be re-exports from typing,
# but exactly which Python versions is subject to change,
# so it's best to err on the safe side
# and update the dictionary on all Python versions
# if typing_extensions is installed
origin_type_checkers[typing_extensions.Literal] = check_literal
origin_type_checkers[typing_extensions.LiteralString] = check_literal_string
origin_type_checkers[typing_extensions.Self] = check_self
origin_type_checkers[typing_extensions.TypeGuard] = check_typeguard
def builtin_checker_lookup(
origin_type: Any, args: tuple[Any, ...], extras: tuple[Any, ...]
) -> TypeCheckerCallable | None:
checker = origin_type_checkers.get(origin_type)
if checker is not None:
return checker
elif is_typeddict(origin_type):
return check_typed_dict
elif isclass(origin_type) and issubclass(
origin_type, Tuple # type: ignore[arg-type]
):
# NamedTuple
return check_tuple
elif getattr(origin_type, "_is_protocol", False):
return check_protocol
elif isinstance(origin_type, ParamSpec):
return check_paramspec
elif isinstance(origin_type, TypeVar):
return check_typevar
elif origin_type.__class__ is NewType:
# typing.NewType on Python 3.10+
return check_newtype
elif (
isfunction(origin_type)
and getattr(origin_type, "__module__", None) == "typing"
and getattr(origin_type, "__qualname__", "").startswith("NewType.")
and hasattr(origin_type, "__supertype__")
):
# typing.NewType on Python 3.9 and below
return check_newtype
return None
checker_lookup_functions.append(builtin_checker_lookup)
def load_plugins() -> None:
"""
Load all type checker lookup functions from entry points.
All entry points from the ``typeguard.checker_lookup`` group are loaded, and the
returned lookup functions are added to :data:`typeguard.checker_lookup_functions`.
.. note:: This function is called implicitly on import, unless the
``TYPEGUARD_DISABLE_PLUGIN_AUTOLOAD`` environment variable is present.
"""
for ep in entry_points(group="typeguard.checker_lookup"):
try:
plugin = ep.load()
except Exception as exc:
warnings.warn(
f"Failed to load plugin {ep.name!r}: " f"{qualified_name(exc)}: {exc}",
stacklevel=2,
)
continue
if not callable(plugin):
warnings.warn(
f"Plugin {ep} returned a non-callable object: {plugin!r}", stacklevel=2
)
continue
checker_lookup_functions.insert(0, plugin)

108
lib/typeguard/_config.py Normal file
View file

@ -0,0 +1,108 @@
from __future__ import annotations
from collections.abc import Iterable
from dataclasses import dataclass
from enum import Enum, auto
from typing import TYPE_CHECKING, TypeVar
if TYPE_CHECKING:
from ._functions import TypeCheckFailCallback
T = TypeVar("T")
class ForwardRefPolicy(Enum):
"""
Defines how unresolved forward references are handled.
Members:
* ``ERROR``: propagate the :exc:`NameError` when the forward reference lookup fails
* ``WARN``: emit a :class:`~.TypeHintWarning` if the forward reference lookup fails
* ``IGNORE``: silently skip checks for unresolveable forward references
"""
ERROR = auto()
WARN = auto()
IGNORE = auto()
class CollectionCheckStrategy(Enum):
"""
Specifies how thoroughly the contents of collections are type checked.
This has an effect on the following built-in checkers:
* ``AbstractSet``
* ``Dict``
* ``List``
* ``Mapping``
* ``Set``
* ``Tuple[<type>, ...]`` (arbitrarily sized tuples)
Members:
* ``FIRST_ITEM``: check only the first item
* ``ALL_ITEMS``: check all items
"""
FIRST_ITEM = auto()
ALL_ITEMS = auto()
def iterate_samples(self, collection: Iterable[T]) -> Iterable[T]:
if self is CollectionCheckStrategy.FIRST_ITEM:
try:
return [next(iter(collection))]
except StopIteration:
return ()
else:
return collection
@dataclass
class TypeCheckConfiguration:
"""
You can change Typeguard's behavior with these settings.
.. attribute:: typecheck_fail_callback
:type: Callable[[TypeCheckError, TypeCheckMemo], Any]
Callable that is called when type checking fails.
Default: ``None`` (the :exc:`~.TypeCheckError` is raised directly)
.. attribute:: forward_ref_policy
:type: ForwardRefPolicy
Specifies what to do when a forward reference fails to resolve.
Default: ``WARN``
.. attribute:: collection_check_strategy
:type: CollectionCheckStrategy
Specifies how thoroughly the contents of collections (list, dict, etc.) are
type checked.
Default: ``FIRST_ITEM``
.. attribute:: debug_instrumentation
:type: bool
If set to ``True``, the code of modules or functions instrumented by typeguard
is printed to ``sys.stderr`` after the instrumentation is done
Requires Python 3.9 or newer.
Default: ``False``
"""
forward_ref_policy: ForwardRefPolicy = ForwardRefPolicy.WARN
typecheck_fail_callback: TypeCheckFailCallback | None = None
collection_check_strategy: CollectionCheckStrategy = (
CollectionCheckStrategy.FIRST_ITEM
)
debug_instrumentation: bool = False
global_config = TypeCheckConfiguration()

View file

@ -0,0 +1,235 @@
from __future__ import annotations
import ast
import inspect
import sys
from collections.abc import Sequence
from functools import partial
from inspect import isclass, isfunction
from types import CodeType, FrameType, FunctionType
from typing import TYPE_CHECKING, Any, Callable, ForwardRef, TypeVar, cast, overload
from warnings import warn
from ._config import CollectionCheckStrategy, ForwardRefPolicy, global_config
from ._exceptions import InstrumentationWarning
from ._functions import TypeCheckFailCallback
from ._transformer import TypeguardTransformer
from ._utils import Unset, function_name, get_stacklevel, is_method_of, unset
if TYPE_CHECKING:
from typeshed.stdlib.types import _Cell
_F = TypeVar("_F")
def typeguard_ignore(f: _F) -> _F:
"""This decorator is a noop during static type-checking."""
return f
else:
from typing import no_type_check as typeguard_ignore # noqa: F401
T_CallableOrType = TypeVar("T_CallableOrType", bound=Callable[..., Any])
def make_cell(value: object) -> _Cell:
return (lambda: value).__closure__[0] # type: ignore[index]
def find_target_function(
new_code: CodeType, target_path: Sequence[str], firstlineno: int
) -> CodeType | None:
target_name = target_path[0]
for const in new_code.co_consts:
if isinstance(const, CodeType):
if const.co_name == target_name:
if const.co_firstlineno == firstlineno:
return const
elif len(target_path) > 1:
target_code = find_target_function(
const, target_path[1:], firstlineno
)
if target_code:
return target_code
return None
def instrument(f: T_CallableOrType) -> FunctionType | str:
if not getattr(f, "__code__", None):
return "no code associated"
elif not getattr(f, "__module__", None):
return "__module__ attribute is not set"
elif f.__code__.co_filename == "<stdin>":
return "cannot instrument functions defined in a REPL"
elif hasattr(f, "__wrapped__"):
return (
"@typechecked only supports instrumenting functions wrapped with "
"@classmethod, @staticmethod or @property"
)
target_path = [item for item in f.__qualname__.split(".") if item != "<locals>"]
module_source = inspect.getsource(sys.modules[f.__module__])
module_ast = ast.parse(module_source)
instrumentor = TypeguardTransformer(target_path, f.__code__.co_firstlineno)
instrumentor.visit(module_ast)
if not instrumentor.target_node or instrumentor.target_lineno is None:
return "instrumentor did not find the target function"
module_code = compile(module_ast, f.__code__.co_filename, "exec", dont_inherit=True)
new_code = find_target_function(
module_code, target_path, instrumentor.target_lineno
)
if not new_code:
return "cannot find the target function in the AST"
if global_config.debug_instrumentation and sys.version_info >= (3, 9):
# Find the matching AST node, then unparse it to source and print to stdout
print(
f"Source code of {f.__qualname__}() after instrumentation:"
"\n----------------------------------------------",
file=sys.stderr,
)
print(ast.unparse(instrumentor.target_node), file=sys.stderr)
print(
"----------------------------------------------",
file=sys.stderr,
)
closure = f.__closure__
if new_code.co_freevars != f.__code__.co_freevars:
# Create a new closure and find values for the new free variables
frame = cast(FrameType, inspect.currentframe())
frame = cast(FrameType, frame.f_back)
frame_locals = cast(FrameType, frame.f_back).f_locals
cells: list[_Cell] = []
for key in new_code.co_freevars:
if key in instrumentor.names_used_in_annotations:
# Find the value and make a new cell from it
value = frame_locals.get(key) or ForwardRef(key)
cells.append(make_cell(value))
else:
# Reuse the cell from the existing closure
assert f.__closure__
cells.append(f.__closure__[f.__code__.co_freevars.index(key)])
closure = tuple(cells)
new_function = FunctionType(new_code, f.__globals__, f.__name__, closure=closure)
new_function.__module__ = f.__module__
new_function.__name__ = f.__name__
new_function.__qualname__ = f.__qualname__
new_function.__annotations__ = f.__annotations__
new_function.__doc__ = f.__doc__
new_function.__defaults__ = f.__defaults__
new_function.__kwdefaults__ = f.__kwdefaults__
return new_function
@overload
def typechecked(
*,
forward_ref_policy: ForwardRefPolicy | Unset = unset,
typecheck_fail_callback: TypeCheckFailCallback | Unset = unset,
collection_check_strategy: CollectionCheckStrategy | Unset = unset,
debug_instrumentation: bool | Unset = unset,
) -> Callable[[T_CallableOrType], T_CallableOrType]: ...
@overload
def typechecked(target: T_CallableOrType) -> T_CallableOrType: ...
def typechecked(
target: T_CallableOrType | None = None,
*,
forward_ref_policy: ForwardRefPolicy | Unset = unset,
typecheck_fail_callback: TypeCheckFailCallback | Unset = unset,
collection_check_strategy: CollectionCheckStrategy | Unset = unset,
debug_instrumentation: bool | Unset = unset,
) -> Any:
"""
Instrument the target function to perform run-time type checking.
This decorator recompiles the target function, injecting code to type check
arguments, return values, yield values (excluding ``yield from``) and assignments to
annotated local variables.
This can also be used as a class decorator. This will instrument all type annotated
methods, including :func:`@classmethod <classmethod>`,
:func:`@staticmethod <staticmethod>`, and :class:`@property <property>` decorated
methods in the class.
.. note:: When Python is run in optimized mode (``-O`` or ``-OO``, this decorator
is a no-op). This is a feature meant for selectively introducing type checking
into a code base where the checks aren't meant to be run in production.
:param target: the function or class to enable type checking for
:param forward_ref_policy: override for
:attr:`.TypeCheckConfiguration.forward_ref_policy`
:param typecheck_fail_callback: override for
:attr:`.TypeCheckConfiguration.typecheck_fail_callback`
:param collection_check_strategy: override for
:attr:`.TypeCheckConfiguration.collection_check_strategy`
:param debug_instrumentation: override for
:attr:`.TypeCheckConfiguration.debug_instrumentation`
"""
if target is None:
return partial(
typechecked,
forward_ref_policy=forward_ref_policy,
typecheck_fail_callback=typecheck_fail_callback,
collection_check_strategy=collection_check_strategy,
debug_instrumentation=debug_instrumentation,
)
if not __debug__:
return target
if isclass(target):
for key, attr in target.__dict__.items():
if is_method_of(attr, target):
retval = instrument(attr)
if isfunction(retval):
setattr(target, key, retval)
elif isinstance(attr, (classmethod, staticmethod)):
if is_method_of(attr.__func__, target):
retval = instrument(attr.__func__)
if isfunction(retval):
wrapper = attr.__class__(retval)
setattr(target, key, wrapper)
elif isinstance(attr, property):
kwargs: dict[str, Any] = dict(doc=attr.__doc__)
for name in ("fset", "fget", "fdel"):
property_func = kwargs[name] = getattr(attr, name)
if is_method_of(property_func, target):
retval = instrument(property_func)
if isfunction(retval):
kwargs[name] = retval
setattr(target, key, attr.__class__(**kwargs))
return target
# Find either the first Python wrapper or the actual function
wrapper_class: (
type[classmethod[Any, Any, Any]] | type[staticmethod[Any, Any]] | None
) = None
if isinstance(target, (classmethod, staticmethod)):
wrapper_class = target.__class__
target = target.__func__
retval = instrument(target)
if isinstance(retval, str):
warn(
f"{retval} -- not typechecking {function_name(target)}",
InstrumentationWarning,
stacklevel=get_stacklevel(),
)
return target
if wrapper_class is None:
return retval
else:
return wrapper_class(retval)

View file

@ -0,0 +1,42 @@
from collections import deque
from typing import Deque
class TypeHintWarning(UserWarning):
"""
A warning that is emitted when a type hint in string form could not be resolved to
an actual type.
"""
class TypeCheckWarning(UserWarning):
"""Emitted by typeguard's type checkers when a type mismatch is detected."""
def __init__(self, message: str):
super().__init__(message)
class InstrumentationWarning(UserWarning):
"""Emitted when there's a problem with instrumenting a function for type checks."""
def __init__(self, message: str):
super().__init__(message)
class TypeCheckError(Exception):
"""
Raised by typeguard's type checkers when a type mismatch is detected.
"""
def __init__(self, message: str):
super().__init__(message)
self._path: Deque[str] = deque()
def append_path_element(self, element: str) -> None:
self._path.append(element)
def __str__(self) -> str:
if self._path:
return " of ".join(self._path) + " " + str(self.args[0])
else:
return str(self.args[0])

308
lib/typeguard/_functions.py Normal file
View file

@ -0,0 +1,308 @@
from __future__ import annotations
import sys
import warnings
from typing import Any, Callable, NoReturn, TypeVar, Union, overload
from . import _suppression
from ._checkers import BINARY_MAGIC_METHODS, check_type_internal
from ._config import (
CollectionCheckStrategy,
ForwardRefPolicy,
TypeCheckConfiguration,
)
from ._exceptions import TypeCheckError, TypeCheckWarning
from ._memo import TypeCheckMemo
from ._utils import get_stacklevel, qualified_name
if sys.version_info >= (3, 11):
from typing import Literal, Never, TypeAlias
else:
from typing_extensions import Literal, Never, TypeAlias
T = TypeVar("T")
TypeCheckFailCallback: TypeAlias = Callable[[TypeCheckError, TypeCheckMemo], Any]
@overload
def check_type(
value: object,
expected_type: type[T],
*,
forward_ref_policy: ForwardRefPolicy = ...,
typecheck_fail_callback: TypeCheckFailCallback | None = ...,
collection_check_strategy: CollectionCheckStrategy = ...,
) -> T: ...
@overload
def check_type(
value: object,
expected_type: Any,
*,
forward_ref_policy: ForwardRefPolicy = ...,
typecheck_fail_callback: TypeCheckFailCallback | None = ...,
collection_check_strategy: CollectionCheckStrategy = ...,
) -> Any: ...
def check_type(
value: object,
expected_type: Any,
*,
forward_ref_policy: ForwardRefPolicy = TypeCheckConfiguration().forward_ref_policy,
typecheck_fail_callback: TypeCheckFailCallback | None = (
TypeCheckConfiguration().typecheck_fail_callback
),
collection_check_strategy: CollectionCheckStrategy = (
TypeCheckConfiguration().collection_check_strategy
),
) -> Any:
"""
Ensure that ``value`` matches ``expected_type``.
The types from the :mod:`typing` module do not support :func:`isinstance` or
:func:`issubclass` so a number of type specific checks are required. This function
knows which checker to call for which type.
This function wraps :func:`~.check_type_internal` in the following ways:
* Respects type checking suppression (:func:`~.suppress_type_checks`)
* Forms a :class:`~.TypeCheckMemo` from the current stack frame
* Calls the configured type check fail callback if the check fails
Note that this function is independent of the globally shared configuration in
:data:`typeguard.config`. This means that usage within libraries is safe from being
affected configuration changes made by other libraries or by the integrating
application. Instead, configuration options have the same default values as their
corresponding fields in :class:`TypeCheckConfiguration`.
:param value: value to be checked against ``expected_type``
:param expected_type: a class or generic type instance, or a tuple of such things
:param forward_ref_policy: see :attr:`TypeCheckConfiguration.forward_ref_policy`
:param typecheck_fail_callback:
see :attr`TypeCheckConfiguration.typecheck_fail_callback`
:param collection_check_strategy:
see :attr:`TypeCheckConfiguration.collection_check_strategy`
:return: ``value``, unmodified
:raises TypeCheckError: if there is a type mismatch
"""
if type(expected_type) is tuple:
expected_type = Union[expected_type]
config = TypeCheckConfiguration(
forward_ref_policy=forward_ref_policy,
typecheck_fail_callback=typecheck_fail_callback,
collection_check_strategy=collection_check_strategy,
)
if _suppression.type_checks_suppressed or expected_type is Any:
return value
frame = sys._getframe(1)
memo = TypeCheckMemo(frame.f_globals, frame.f_locals, config=config)
try:
check_type_internal(value, expected_type, memo)
except TypeCheckError as exc:
exc.append_path_element(qualified_name(value, add_class_prefix=True))
if config.typecheck_fail_callback:
config.typecheck_fail_callback(exc, memo)
else:
raise
return value
def check_argument_types(
func_name: str,
arguments: dict[str, tuple[Any, Any]],
memo: TypeCheckMemo,
) -> Literal[True]:
if _suppression.type_checks_suppressed:
return True
for argname, (value, annotation) in arguments.items():
if annotation is NoReturn or annotation is Never:
exc = TypeCheckError(
f"{func_name}() was declared never to be called but it was"
)
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise exc
try:
check_type_internal(value, annotation, memo)
except TypeCheckError as exc:
qualname = qualified_name(value, add_class_prefix=True)
exc.append_path_element(f'argument "{argname}" ({qualname})')
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise
return True
def check_return_type(
func_name: str,
retval: T,
annotation: Any,
memo: TypeCheckMemo,
) -> T:
if _suppression.type_checks_suppressed:
return retval
if annotation is NoReturn or annotation is Never:
exc = TypeCheckError(f"{func_name}() was declared never to return but it did")
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise exc
try:
check_type_internal(retval, annotation, memo)
except TypeCheckError as exc:
# Allow NotImplemented if this is a binary magic method (__eq__() et al)
if retval is NotImplemented and annotation is bool:
# This does (and cannot) not check if it's actually a method
func_name = func_name.rsplit(".", 1)[-1]
if func_name in BINARY_MAGIC_METHODS:
return retval
qualname = qualified_name(retval, add_class_prefix=True)
exc.append_path_element(f"the return value ({qualname})")
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise
return retval
def check_send_type(
func_name: str,
sendval: T,
annotation: Any,
memo: TypeCheckMemo,
) -> T:
if _suppression.type_checks_suppressed:
return sendval
if annotation is NoReturn or annotation is Never:
exc = TypeCheckError(
f"{func_name}() was declared never to be sent a value to but it was"
)
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise exc
try:
check_type_internal(sendval, annotation, memo)
except TypeCheckError as exc:
qualname = qualified_name(sendval, add_class_prefix=True)
exc.append_path_element(f"the value sent to generator ({qualname})")
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise
return sendval
def check_yield_type(
func_name: str,
yieldval: T,
annotation: Any,
memo: TypeCheckMemo,
) -> T:
if _suppression.type_checks_suppressed:
return yieldval
if annotation is NoReturn or annotation is Never:
exc = TypeCheckError(f"{func_name}() was declared never to yield but it did")
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise exc
try:
check_type_internal(yieldval, annotation, memo)
except TypeCheckError as exc:
qualname = qualified_name(yieldval, add_class_prefix=True)
exc.append_path_element(f"the yielded value ({qualname})")
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise
return yieldval
def check_variable_assignment(
value: object, varname: str, annotation: Any, memo: TypeCheckMemo
) -> Any:
if _suppression.type_checks_suppressed:
return value
try:
check_type_internal(value, annotation, memo)
except TypeCheckError as exc:
qualname = qualified_name(value, add_class_prefix=True)
exc.append_path_element(f"value assigned to {varname} ({qualname})")
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise
return value
def check_multi_variable_assignment(
value: Any, targets: list[dict[str, Any]], memo: TypeCheckMemo
) -> Any:
if max(len(target) for target in targets) == 1:
iterated_values = [value]
else:
iterated_values = list(value)
if not _suppression.type_checks_suppressed:
for expected_types in targets:
value_index = 0
for ann_index, (varname, expected_type) in enumerate(
expected_types.items()
):
if varname.startswith("*"):
varname = varname[1:]
keys_left = len(expected_types) - 1 - ann_index
next_value_index = len(iterated_values) - keys_left
obj: object = iterated_values[value_index:next_value_index]
value_index = next_value_index
else:
obj = iterated_values[value_index]
value_index += 1
try:
check_type_internal(obj, expected_type, memo)
except TypeCheckError as exc:
qualname = qualified_name(obj, add_class_prefix=True)
exc.append_path_element(f"value assigned to {varname} ({qualname})")
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise
return iterated_values[0] if len(iterated_values) == 1 else iterated_values
def warn_on_error(exc: TypeCheckError, memo: TypeCheckMemo) -> None:
"""
Emit a warning on a type mismatch.
This is intended to be used as an error handler in
:attr:`TypeCheckConfiguration.typecheck_fail_callback`.
"""
warnings.warn(TypeCheckWarning(str(exc)), stacklevel=get_stacklevel())

View file

@ -0,0 +1,213 @@
from __future__ import annotations
import ast
import sys
import types
from collections.abc import Callable, Iterable
from importlib.abc import MetaPathFinder
from importlib.machinery import ModuleSpec, SourceFileLoader
from importlib.util import cache_from_source, decode_source
from inspect import isclass
from os import PathLike
from types import CodeType, ModuleType, TracebackType
from typing import Sequence, TypeVar
from unittest.mock import patch
from ._config import global_config
from ._transformer import TypeguardTransformer
if sys.version_info >= (3, 12):
from collections.abc import Buffer
else:
from typing_extensions import Buffer
if sys.version_info >= (3, 11):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
if sys.version_info >= (3, 10):
from importlib.metadata import PackageNotFoundError, version
else:
from importlib_metadata import PackageNotFoundError, version
try:
OPTIMIZATION = "typeguard" + "".join(version("typeguard").split(".")[:3])
except PackageNotFoundError:
OPTIMIZATION = "typeguard"
P = ParamSpec("P")
T = TypeVar("T")
# The name of this function is magical
def _call_with_frames_removed(
f: Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> T:
return f(*args, **kwargs)
def optimized_cache_from_source(path: str, debug_override: bool | None = None) -> str:
return cache_from_source(path, debug_override, optimization=OPTIMIZATION)
class TypeguardLoader(SourceFileLoader):
@staticmethod
def source_to_code(
data: Buffer | str | ast.Module | ast.Expression | ast.Interactive,
path: Buffer | str | PathLike[str] = "<string>",
) -> CodeType:
if isinstance(data, (ast.Module, ast.Expression, ast.Interactive)):
tree = data
else:
if isinstance(data, str):
source = data
else:
source = decode_source(data)
tree = _call_with_frames_removed(
ast.parse,
source,
path,
"exec",
)
tree = TypeguardTransformer().visit(tree)
ast.fix_missing_locations(tree)
if global_config.debug_instrumentation and sys.version_info >= (3, 9):
print(
f"Source code of {path!r} after instrumentation:\n"
"----------------------------------------------",
file=sys.stderr,
)
print(ast.unparse(tree), file=sys.stderr)
print("----------------------------------------------", file=sys.stderr)
return _call_with_frames_removed(
compile, tree, path, "exec", 0, dont_inherit=True
)
def exec_module(self, module: ModuleType) -> None:
# Use a custom optimization marker the import lock should make this monkey
# patch safe
with patch(
"importlib._bootstrap_external.cache_from_source",
optimized_cache_from_source,
):
super().exec_module(module)
class TypeguardFinder(MetaPathFinder):
"""
Wraps another path finder and instruments the module with
:func:`@typechecked <typeguard.typechecked>` if :meth:`should_instrument` returns
``True``.
Should not be used directly, but rather via :func:`~.install_import_hook`.
.. versionadded:: 2.6
"""
def __init__(self, packages: list[str] | None, original_pathfinder: MetaPathFinder):
self.packages = packages
self._original_pathfinder = original_pathfinder
def find_spec(
self,
fullname: str,
path: Sequence[str] | None,
target: types.ModuleType | None = None,
) -> ModuleSpec | None:
if self.should_instrument(fullname):
spec = self._original_pathfinder.find_spec(fullname, path, target)
if spec is not None and isinstance(spec.loader, SourceFileLoader):
spec.loader = TypeguardLoader(spec.loader.name, spec.loader.path)
return spec
return None
def should_instrument(self, module_name: str) -> bool:
"""
Determine whether the module with the given name should be instrumented.
:param module_name: full name of the module that is about to be imported (e.g.
``xyz.abc``)
"""
if self.packages is None:
return True
for package in self.packages:
if module_name == package or module_name.startswith(package + "."):
return True
return False
class ImportHookManager:
"""
A handle that can be used to uninstall the Typeguard import hook.
"""
def __init__(self, hook: MetaPathFinder):
self.hook = hook
def __enter__(self) -> None:
pass
def __exit__(
self,
exc_type: type[BaseException],
exc_val: BaseException,
exc_tb: TracebackType,
) -> None:
self.uninstall()
def uninstall(self) -> None:
"""Uninstall the import hook."""
try:
sys.meta_path.remove(self.hook)
except ValueError:
pass # already removed
def install_import_hook(
packages: Iterable[str] | None = None,
*,
cls: type[TypeguardFinder] = TypeguardFinder,
) -> ImportHookManager:
"""
Install an import hook that instruments functions for automatic type checking.
This only affects modules loaded **after** this hook has been installed.
:param packages: an iterable of package names to instrument, or ``None`` to
instrument all packages
:param cls: a custom meta path finder class
:return: a context manager that uninstalls the hook on exit (or when you call
``.uninstall()``)
.. versionadded:: 2.6
"""
if packages is None:
target_packages: list[str] | None = None
elif isinstance(packages, str):
target_packages = [packages]
else:
target_packages = list(packages)
for finder in sys.meta_path:
if (
isclass(finder)
and finder.__name__ == "PathFinder"
and hasattr(finder, "find_spec")
):
break
else:
raise RuntimeError("Cannot find a PathFinder in sys.meta_path")
hook = cls(target_packages, finder)
sys.meta_path.insert(0, hook)
return ImportHookManager(hook)

48
lib/typeguard/_memo.py Normal file
View file

@ -0,0 +1,48 @@
from __future__ import annotations
from typing import Any
from typeguard._config import TypeCheckConfiguration, global_config
class TypeCheckMemo:
"""
Contains information necessary for type checkers to do their work.
.. attribute:: globals
:type: dict[str, Any]
Dictionary of global variables to use for resolving forward references.
.. attribute:: locals
:type: dict[str, Any]
Dictionary of local variables to use for resolving forward references.
.. attribute:: self_type
:type: type | None
When running type checks within an instance method or class method, this is the
class object that the first argument (usually named ``self`` or ``cls``) refers
to.
.. attribute:: config
:type: TypeCheckConfiguration
Contains the configuration for a particular set of type checking operations.
"""
__slots__ = "globals", "locals", "self_type", "config"
def __init__(
self,
globals: dict[str, Any],
locals: dict[str, Any],
*,
self_type: type | None = None,
config: TypeCheckConfiguration = global_config,
):
self.globals = globals
self.locals = locals
self.self_type = self_type
self.config = config

View file

@ -0,0 +1,126 @@
from __future__ import annotations
import sys
import warnings
from typing import Any, Literal
from pytest import Config, Parser
from typeguard._config import CollectionCheckStrategy, ForwardRefPolicy, global_config
from typeguard._exceptions import InstrumentationWarning
from typeguard._importhook import install_import_hook
from typeguard._utils import qualified_name, resolve_reference
def pytest_addoption(parser: Parser) -> None:
def add_ini_option(
opt_type: (
Literal["string", "paths", "pathlist", "args", "linelist", "bool"] | None
)
) -> None:
parser.addini(
group.options[-1].names()[0][2:],
group.options[-1].attrs()["help"],
opt_type,
)
group = parser.getgroup("typeguard")
group.addoption(
"--typeguard-packages",
action="store",
help="comma separated name list of packages and modules to instrument for "
"type checking, or :all: to instrument all modules loaded after typeguard",
)
add_ini_option("linelist")
group.addoption(
"--typeguard-debug-instrumentation",
action="store_true",
help="print all instrumented code to stderr",
)
add_ini_option("bool")
group.addoption(
"--typeguard-typecheck-fail-callback",
action="store",
help=(
"a module:varname (e.g. typeguard:warn_on_error) reference to a function "
"that is called (with the exception, and memo object as arguments) to "
"handle a TypeCheckError"
),
)
add_ini_option("string")
group.addoption(
"--typeguard-forward-ref-policy",
action="store",
choices=list(ForwardRefPolicy.__members__),
help=(
"determines how to deal with unresolveable forward references in type "
"annotations"
),
)
add_ini_option("string")
group.addoption(
"--typeguard-collection-check-strategy",
action="store",
choices=list(CollectionCheckStrategy.__members__),
help="determines how thoroughly to check collections (list, dict, etc)",
)
add_ini_option("string")
def pytest_configure(config: Config) -> None:
def getoption(name: str) -> Any:
return config.getoption(name.replace("-", "_")) or config.getini(name)
packages: list[str] | None = []
if packages_option := config.getoption("typeguard_packages"):
packages = [pkg.strip() for pkg in packages_option.split(",")]
elif packages_ini := config.getini("typeguard-packages"):
packages = packages_ini
if packages:
if packages == [":all:"]:
packages = None
else:
already_imported_packages = sorted(
package for package in packages if package in sys.modules
)
if already_imported_packages:
warnings.warn(
f"typeguard cannot check these packages because they are already "
f"imported: {', '.join(already_imported_packages)}",
InstrumentationWarning,
stacklevel=1,
)
install_import_hook(packages=packages)
debug_option = getoption("typeguard-debug-instrumentation")
if debug_option:
global_config.debug_instrumentation = True
fail_callback_option = getoption("typeguard-typecheck-fail-callback")
if fail_callback_option:
callback = resolve_reference(fail_callback_option)
if not callable(callback):
raise TypeError(
f"{fail_callback_option} ({qualified_name(callback.__class__)}) is not "
f"a callable"
)
global_config.typecheck_fail_callback = callback
forward_ref_policy_option = getoption("typeguard-forward-ref-policy")
if forward_ref_policy_option:
forward_ref_policy = ForwardRefPolicy.__members__[forward_ref_policy_option]
global_config.forward_ref_policy = forward_ref_policy
collection_check_strategy_option = getoption("typeguard-collection-check-strategy")
if collection_check_strategy_option:
collection_check_strategy = CollectionCheckStrategy.__members__[
collection_check_strategy_option
]
global_config.collection_check_strategy = collection_check_strategy

View file

@ -0,0 +1,86 @@
from __future__ import annotations
import sys
from collections.abc import Callable, Generator
from contextlib import contextmanager
from functools import update_wrapper
from threading import Lock
from typing import ContextManager, TypeVar, overload
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
P = ParamSpec("P")
T = TypeVar("T")
type_checks_suppressed = 0
type_checks_suppress_lock = Lock()
@overload
def suppress_type_checks(func: Callable[P, T]) -> Callable[P, T]: ...
@overload
def suppress_type_checks() -> ContextManager[None]: ...
def suppress_type_checks(
func: Callable[P, T] | None = None
) -> Callable[P, T] | ContextManager[None]:
"""
Temporarily suppress all type checking.
This function has two operating modes, based on how it's used:
#. as a context manager (``with suppress_type_checks(): ...``)
#. as a decorator (``@suppress_type_checks``)
When used as a context manager, :func:`check_type` and any automatically
instrumented functions skip the actual type checking. These context managers can be
nested.
When used as a decorator, all type checking is suppressed while the function is
running.
Type checking will resume once no more context managers are active and no decorated
functions are running.
Both operating modes are thread-safe.
"""
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
global type_checks_suppressed
with type_checks_suppress_lock:
type_checks_suppressed += 1
assert func is not None
try:
return func(*args, **kwargs)
finally:
with type_checks_suppress_lock:
type_checks_suppressed -= 1
def cm() -> Generator[None, None, None]:
global type_checks_suppressed
with type_checks_suppress_lock:
type_checks_suppressed += 1
try:
yield
finally:
with type_checks_suppress_lock:
type_checks_suppressed -= 1
if func is None:
# Context manager mode
return contextmanager(cm)()
else:
# Decorator mode
update_wrapper(wrapper, func)
return wrapper

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,55 @@
"""
Transforms lazily evaluated PEP 604 unions into typing.Unions, for compatibility with
Python versions older than 3.10.
"""
from __future__ import annotations
from ast import (
BinOp,
BitOr,
Index,
Load,
Name,
NodeTransformer,
Subscript,
fix_missing_locations,
parse,
)
from ast import Tuple as ASTTuple
from types import CodeType
from typing import Any, Dict, FrozenSet, List, Set, Tuple, Union
type_substitutions = {
"dict": Dict,
"list": List,
"tuple": Tuple,
"set": Set,
"frozenset": FrozenSet,
"Union": Union,
}
class UnionTransformer(NodeTransformer):
def __init__(self, union_name: Name | None = None):
self.union_name = union_name or Name(id="Union", ctx=Load())
def visit_BinOp(self, node: BinOp) -> Any:
self.generic_visit(node)
if isinstance(node.op, BitOr):
return Subscript(
value=self.union_name,
slice=Index(
ASTTuple(elts=[node.left, node.right], ctx=Load()), ctx=Load()
),
ctx=Load(),
)
return node
def compile_type_hint(hint: str) -> CodeType:
parsed = parse(hint, "<string>", "eval")
UnionTransformer().visit(parsed)
fix_missing_locations(parsed)
return compile(parsed, "<string>", "eval", flags=0)

163
lib/typeguard/_utils.py Normal file
View file

@ -0,0 +1,163 @@
from __future__ import annotations
import inspect
import sys
from importlib import import_module
from inspect import currentframe
from types import CodeType, FrameType, FunctionType
from typing import TYPE_CHECKING, Any, Callable, ForwardRef, Union, cast, final
from weakref import WeakValueDictionary
if TYPE_CHECKING:
from ._memo import TypeCheckMemo
if sys.version_info >= (3, 10):
from typing import get_args, get_origin
def evaluate_forwardref(forwardref: ForwardRef, memo: TypeCheckMemo) -> Any:
return forwardref._evaluate(memo.globals, memo.locals, frozenset())
else:
from typing_extensions import get_args, get_origin
evaluate_extra_args: tuple[frozenset[Any], ...] = (
(frozenset(),) if sys.version_info >= (3, 9) else ()
)
def evaluate_forwardref(forwardref: ForwardRef, memo: TypeCheckMemo) -> Any:
from ._union_transformer import compile_type_hint, type_substitutions
if not forwardref.__forward_evaluated__:
forwardref.__forward_code__ = compile_type_hint(forwardref.__forward_arg__)
try:
return forwardref._evaluate(memo.globals, memo.locals, *evaluate_extra_args)
except NameError:
if sys.version_info < (3, 10):
# Try again, with the type substitutions (list -> List etc.) in place
new_globals = memo.globals.copy()
new_globals.setdefault("Union", Union)
if sys.version_info < (3, 9):
new_globals.update(type_substitutions)
return forwardref._evaluate(
new_globals, memo.locals or new_globals, *evaluate_extra_args
)
raise
_functions_map: WeakValueDictionary[CodeType, FunctionType] = WeakValueDictionary()
def get_type_name(type_: Any) -> str:
name: str
for attrname in "__name__", "_name", "__forward_arg__":
candidate = getattr(type_, attrname, None)
if isinstance(candidate, str):
name = candidate
break
else:
origin = get_origin(type_)
candidate = getattr(origin, "_name", None)
if candidate is None:
candidate = type_.__class__.__name__.strip("_")
if isinstance(candidate, str):
name = candidate
else:
return "(unknown)"
args = get_args(type_)
if args:
if name == "Literal":
formatted_args = ", ".join(repr(arg) for arg in args)
else:
formatted_args = ", ".join(get_type_name(arg) for arg in args)
name += f"[{formatted_args}]"
module = getattr(type_, "__module__", None)
if module and module not in (None, "typing", "typing_extensions", "builtins"):
name = module + "." + name
return name
def qualified_name(obj: Any, *, add_class_prefix: bool = False) -> str:
"""
Return the qualified name (e.g. package.module.Type) for the given object.
Builtins and types from the :mod:`typing` package get special treatment by having
the module name stripped from the generated name.
"""
if obj is None:
return "None"
elif inspect.isclass(obj):
prefix = "class " if add_class_prefix else ""
type_ = obj
else:
prefix = ""
type_ = type(obj)
module = type_.__module__
qualname = type_.__qualname__
name = qualname if module in ("typing", "builtins") else f"{module}.{qualname}"
return prefix + name
def function_name(func: Callable[..., Any]) -> str:
"""
Return the qualified name of the given function.
Builtins and types from the :mod:`typing` package get special treatment by having
the module name stripped from the generated name.
"""
# For partial functions and objects with __call__ defined, __qualname__ does not
# exist
module = getattr(func, "__module__", "")
qualname = (module + ".") if module not in ("builtins", "") else ""
return qualname + getattr(func, "__qualname__", repr(func))
def resolve_reference(reference: str) -> Any:
modulename, varname = reference.partition(":")[::2]
if not modulename or not varname:
raise ValueError(f"{reference!r} is not a module:varname reference")
obj = import_module(modulename)
for attr in varname.split("."):
obj = getattr(obj, attr)
return obj
def is_method_of(obj: object, cls: type) -> bool:
return (
inspect.isfunction(obj)
and obj.__module__ == cls.__module__
and obj.__qualname__.startswith(cls.__qualname__ + ".")
)
def get_stacklevel() -> int:
level = 1
frame = cast(FrameType, currentframe()).f_back
while frame and frame.f_globals.get("__name__", "").startswith("typeguard."):
level += 1
frame = frame.f_back
return level
@final
class Unset:
__slots__ = ()
def __repr__(self) -> str:
return "<unset>"
unset = Unset()

0
lib/typeguard/py.typed Normal file
View file

View file

@ -147,27 +147,6 @@ class _Sentinel:
_marker = _Sentinel() _marker = _Sentinel()
def _check_generic(cls, parameters, elen=_marker):
"""Check correct count for parameters of a generic cls (internal helper).
This gives a nice error message in case of count mismatch.
"""
if not elen:
raise TypeError(f"{cls} is not a generic class")
if elen is _marker:
if not hasattr(cls, "__parameters__") or not cls.__parameters__:
raise TypeError(f"{cls} is not a generic class")
elen = len(cls.__parameters__)
alen = len(parameters)
if alen != elen:
if hasattr(cls, "__parameters__"):
parameters = [p for p in cls.__parameters__ if not _is_unpack(p)]
num_tv_tuples = sum(isinstance(p, TypeVarTuple) for p in parameters)
if (num_tv_tuples > 0) and (alen >= elen - num_tv_tuples):
return
raise TypeError(f"Too {'many' if alen > elen else 'few'} parameters for {cls};"
f" actual {alen}, expected {elen}")
if sys.version_info >= (3, 10): if sys.version_info >= (3, 10):
def _should_collect_from_parameters(t): def _should_collect_from_parameters(t):
return isinstance( return isinstance(
@ -181,27 +160,6 @@ else:
return isinstance(t, typing._GenericAlias) and not t._special return isinstance(t, typing._GenericAlias) and not t._special
def _collect_type_vars(types, typevar_types=None):
"""Collect all type variable contained in types in order of
first appearance (lexicographic order). For example::
_collect_type_vars((T, List[S, T])) == (T, S)
"""
if typevar_types is None:
typevar_types = typing.TypeVar
tvars = []
for t in types:
if (
isinstance(t, typevar_types) and
t not in tvars and
not _is_unpack(t)
):
tvars.append(t)
if _should_collect_from_parameters(t):
tvars.extend([t for t in t.__parameters__ if t not in tvars])
return tuple(tvars)
NoReturn = typing.NoReturn NoReturn = typing.NoReturn
# Some unconstrained type variables. These are used by the container types. # Some unconstrained type variables. These are used by the container types.
@ -834,7 +792,11 @@ def _ensure_subclassable(mro_entries):
return inner return inner
if hasattr(typing, "ReadOnly"): # Update this to something like >=3.13.0b1 if and when
# PEP 728 is implemented in CPython
_PEP_728_IMPLEMENTED = False
if _PEP_728_IMPLEMENTED:
# The standard library TypedDict in Python 3.8 does not store runtime information # The standard library TypedDict in Python 3.8 does not store runtime information
# about which (if any) keys are optional. See https://bugs.python.org/issue38834 # about which (if any) keys are optional. See https://bugs.python.org/issue38834
# The standard library TypedDict in Python 3.9.0/1 does not honour the "total" # The standard library TypedDict in Python 3.9.0/1 does not honour the "total"
@ -845,7 +807,8 @@ if hasattr(typing, "ReadOnly"):
# Aaaand on 3.12 we add __orig_bases__ to TypedDict # Aaaand on 3.12 we add __orig_bases__ to TypedDict
# to enable better runtime introspection. # to enable better runtime introspection.
# On 3.13 we deprecate some odd ways of creating TypedDicts. # On 3.13 we deprecate some odd ways of creating TypedDicts.
# PEP 705 proposes adding the ReadOnly[] qualifier. # Also on 3.13, PEP 705 adds the ReadOnly[] qualifier.
# PEP 728 (still pending) makes more changes.
TypedDict = typing.TypedDict TypedDict = typing.TypedDict
_TypedDictMeta = typing._TypedDictMeta _TypedDictMeta = typing._TypedDictMeta
is_typeddict = typing.is_typeddict is_typeddict = typing.is_typeddict
@ -1122,15 +1085,15 @@ else:
return val return val
if hasattr(typing, "Required"): # 3.11+ if hasattr(typing, "ReadOnly"): # 3.13+
get_type_hints = typing.get_type_hints get_type_hints = typing.get_type_hints
else: # <=3.10 else: # <=3.13
# replaces _strip_annotations() # replaces _strip_annotations()
def _strip_extras(t): def _strip_extras(t):
"""Strips Annotated, Required and NotRequired from a given type.""" """Strips Annotated, Required and NotRequired from a given type."""
if isinstance(t, _AnnotatedAlias): if isinstance(t, _AnnotatedAlias):
return _strip_extras(t.__origin__) return _strip_extras(t.__origin__)
if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired): if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired, ReadOnly):
return _strip_extras(t.__args__[0]) return _strip_extras(t.__args__[0])
if isinstance(t, typing._GenericAlias): if isinstance(t, typing._GenericAlias):
stripped_args = tuple(_strip_extras(a) for a in t.__args__) stripped_args = tuple(_strip_extras(a) for a in t.__args__)
@ -2689,9 +2652,151 @@ else:
# counting generic parameters, so that when we subscript a generic, # counting generic parameters, so that when we subscript a generic,
# the runtime doesn't try to substitute the Unpack with the subscripted type. # the runtime doesn't try to substitute the Unpack with the subscripted type.
if not hasattr(typing, "TypeVarTuple"): if not hasattr(typing, "TypeVarTuple"):
typing._collect_type_vars = _collect_type_vars def _check_generic(cls, parameters, elen=_marker):
"""Check correct count for parameters of a generic cls (internal helper).
This gives a nice error message in case of count mismatch.
"""
if not elen:
raise TypeError(f"{cls} is not a generic class")
if elen is _marker:
if not hasattr(cls, "__parameters__") or not cls.__parameters__:
raise TypeError(f"{cls} is not a generic class")
elen = len(cls.__parameters__)
alen = len(parameters)
if alen != elen:
expect_val = elen
if hasattr(cls, "__parameters__"):
parameters = [p for p in cls.__parameters__ if not _is_unpack(p)]
num_tv_tuples = sum(isinstance(p, TypeVarTuple) for p in parameters)
if (num_tv_tuples > 0) and (alen >= elen - num_tv_tuples):
return
# deal with TypeVarLike defaults
# required TypeVarLikes cannot appear after a defaulted one.
if alen < elen:
# since we validate TypeVarLike default in _collect_type_vars
# or _collect_parameters we can safely check parameters[alen]
if getattr(parameters[alen], '__default__', None) is not None:
return
num_default_tv = sum(getattr(p, '__default__', None)
is not None for p in parameters)
elen -= num_default_tv
expect_val = f"at least {elen}"
things = "arguments" if sys.version_info >= (3, 10) else "parameters"
raise TypeError(f"Too {'many' if alen > elen else 'few'} {things}"
f" for {cls}; actual {alen}, expected {expect_val}")
else:
# Python 3.11+
def _check_generic(cls, parameters, elen):
"""Check correct count for parameters of a generic cls (internal helper).
This gives a nice error message in case of count mismatch.
"""
if not elen:
raise TypeError(f"{cls} is not a generic class")
alen = len(parameters)
if alen != elen:
expect_val = elen
if hasattr(cls, "__parameters__"):
parameters = [p for p in cls.__parameters__ if not _is_unpack(p)]
# deal with TypeVarLike defaults
# required TypeVarLikes cannot appear after a defaulted one.
if alen < elen:
# since we validate TypeVarLike default in _collect_type_vars
# or _collect_parameters we can safely check parameters[alen]
if getattr(parameters[alen], '__default__', None) is not None:
return
num_default_tv = sum(getattr(p, '__default__', None)
is not None for p in parameters)
elen -= num_default_tv
expect_val = f"at least {elen}"
raise TypeError(f"Too {'many' if alen > elen else 'few'} arguments"
f" for {cls}; actual {alen}, expected {expect_val}")
typing._check_generic = _check_generic typing._check_generic = _check_generic
# Python 3.11+ _collect_type_vars was renamed to _collect_parameters
if hasattr(typing, '_collect_type_vars'):
def _collect_type_vars(types, typevar_types=None):
"""Collect all type variable contained in types in order of
first appearance (lexicographic order). For example::
_collect_type_vars((T, List[S, T])) == (T, S)
"""
if typevar_types is None:
typevar_types = typing.TypeVar
tvars = []
# required TypeVarLike cannot appear after TypeVarLike with default
default_encountered = False
for t in types:
if (
isinstance(t, typevar_types) and
t not in tvars and
not _is_unpack(t)
):
if getattr(t, '__default__', None) is not None:
default_encountered = True
elif default_encountered:
raise TypeError(f'Type parameter {t!r} without a default'
' follows type parameter with a default')
tvars.append(t)
if _should_collect_from_parameters(t):
tvars.extend([t for t in t.__parameters__ if t not in tvars])
return tuple(tvars)
typing._collect_type_vars = _collect_type_vars
else:
def _collect_parameters(args):
"""Collect all type variables and parameter specifications in args
in order of first appearance (lexicographic order).
For example::
assert _collect_parameters((T, Callable[P, T])) == (T, P)
"""
parameters = []
# required TypeVarLike cannot appear after TypeVarLike with default
default_encountered = False
for t in args:
if isinstance(t, type):
# We don't want __parameters__ descriptor of a bare Python class.
pass
elif isinstance(t, tuple):
# `t` might be a tuple, when `ParamSpec` is substituted with
# `[T, int]`, or `[int, *Ts]`, etc.
for x in t:
for collected in _collect_parameters([x]):
if collected not in parameters:
parameters.append(collected)
elif hasattr(t, '__typing_subst__'):
if t not in parameters:
if getattr(t, '__default__', None) is not None:
default_encountered = True
elif default_encountered:
raise TypeError(f'Type parameter {t!r} without a default'
' follows type parameter with a default')
parameters.append(t)
else:
for x in getattr(t, '__parameters__', ()):
if x not in parameters:
parameters.append(x)
return tuple(parameters)
typing._collect_parameters = _collect_parameters
# Backport typing.NamedTuple as it exists in Python 3.13. # Backport typing.NamedTuple as it exists in Python 3.13.
# In 3.11, the ability to define generic `NamedTuple`s was supported. # In 3.11, the ability to define generic `NamedTuple`s was supported.

View file

@ -2,7 +2,7 @@
__init__.py __init__.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -17,10 +17,10 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
from ._abnf import * from ._abnf import *
from ._app import WebSocketApp, setReconnect from ._app import WebSocketApp as WebSocketApp, setReconnect as setReconnect
from ._core import * from ._core import *
from ._exceptions import * from ._exceptions import *
from ._logging import * from ._logging import *
from ._socket import * from ._socket import *
__version__ = "1.7.0" __version__ = "1.8.0"

View file

@ -5,14 +5,14 @@ import sys
from threading import Lock from threading import Lock
from typing import Callable, Optional, Union from typing import Callable, Optional, Union
from ._exceptions import * from ._exceptions import WebSocketPayloadException, WebSocketProtocolException
from ._utils import validate_utf8 from ._utils import validate_utf8
""" """
_abnf.py _abnf.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.

View file

@ -13,13 +13,14 @@ from ._exceptions import (
WebSocketException, WebSocketException,
WebSocketTimeoutException, WebSocketTimeoutException,
) )
from ._ssl_compat import SSLEOFError
from ._url import parse_url from ._url import parse_url
""" """
_app.py _app.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -165,6 +166,7 @@ class WebSocketApp:
url: str, url: str,
header: Union[list, dict, Callable, None] = None, header: Union[list, dict, Callable, None] = None,
on_open: Optional[Callable[[WebSocket], None]] = None, on_open: Optional[Callable[[WebSocket], None]] = None,
on_reconnect: Optional[Callable[[WebSocket], None]] = None,
on_message: Optional[Callable[[WebSocket, Any], None]] = None, on_message: Optional[Callable[[WebSocket, Any], None]] = None,
on_error: Optional[Callable[[WebSocket, Any], None]] = None, on_error: Optional[Callable[[WebSocket, Any], None]] = None,
on_close: Optional[Callable[[WebSocket, Any, Any], None]] = None, on_close: Optional[Callable[[WebSocket, Any, Any], None]] = None,
@ -194,6 +196,10 @@ class WebSocketApp:
Callback object which is called at opening websocket. Callback object which is called at opening websocket.
on_open has one argument. on_open has one argument.
The 1st argument is this class object. The 1st argument is this class object.
on_reconnect: function
Callback object which is called at reconnecting websocket.
on_reconnect has one argument.
The 1st argument is this class object.
on_message: function on_message: function
Callback object which is called when received data. Callback object which is called when received data.
on_message has 2 arguments. on_message has 2 arguments.
@ -244,6 +250,7 @@ class WebSocketApp:
self.cookie = cookie self.cookie = cookie
self.on_open = on_open self.on_open = on_open
self.on_reconnect = on_reconnect
self.on_message = on_message self.on_message = on_message
self.on_data = on_data self.on_data = on_data
self.on_error = on_error self.on_error = on_error
@ -424,6 +431,7 @@ class WebSocketApp:
self.ping_interval = ping_interval self.ping_interval = ping_interval
self.ping_timeout = ping_timeout self.ping_timeout = ping_timeout
self.ping_payload = ping_payload self.ping_payload = ping_payload
self.has_done_teardown = False
self.keep_running = True self.keep_running = True
def teardown(close_frame: ABNF = None): def teardown(close_frame: ABNF = None):
@ -495,6 +503,9 @@ class WebSocketApp:
if self.ping_interval: if self.ping_interval:
self._start_ping_thread() self._start_ping_thread()
if reconnecting and self.on_reconnect:
self._callback(self.on_reconnect)
else:
self._callback(self.on_open) self._callback(self.on_open)
dispatcher.read(self.sock.sock, read, check) dispatcher.read(self.sock.sock, read, check)
@ -516,9 +527,10 @@ class WebSocketApp:
except ( except (
WebSocketConnectionClosedException, WebSocketConnectionClosedException,
KeyboardInterrupt, KeyboardInterrupt,
SSLEOFError,
) as e: ) as e:
if custom_dispatcher: if custom_dispatcher:
return handleDisconnect(e) return handleDisconnect(e, bool(reconnect))
else: else:
raise e raise e

View file

@ -5,7 +5,7 @@ from typing import Optional
_cookiejar.py _cookiejar.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -23,13 +23,13 @@ limitations under the License.
class SimpleCookieJar: class SimpleCookieJar:
def __init__(self) -> None: def __init__(self) -> None:
self.jar: dict = dict() self.jar: dict = {}
def add(self, set_cookie: Optional[str]) -> None: def add(self, set_cookie: Optional[str]) -> None:
if set_cookie: if set_cookie:
simpleCookie = http.cookies.SimpleCookie(set_cookie) simple_cookie = http.cookies.SimpleCookie(set_cookie)
for k, v in simpleCookie.items(): for v in simple_cookie.values():
if domain := v.get("domain"): if domain := v.get("domain"):
if not domain.startswith("."): if not domain.startswith("."):
domain = f".{domain}" domain = f".{domain}"
@ -38,25 +38,25 @@ class SimpleCookieJar:
if self.jar.get(domain) if self.jar.get(domain)
else http.cookies.SimpleCookie() else http.cookies.SimpleCookie()
) )
cookie.update(simpleCookie) cookie.update(simple_cookie)
self.jar[domain.lower()] = cookie self.jar[domain.lower()] = cookie
def set(self, set_cookie: str) -> None: def set(self, set_cookie: str) -> None:
if set_cookie: if set_cookie:
simpleCookie = http.cookies.SimpleCookie(set_cookie) simple_cookie = http.cookies.SimpleCookie(set_cookie)
for k, v in simpleCookie.items(): for v in simple_cookie.values():
if domain := v.get("domain"): if domain := v.get("domain"):
if not domain.startswith("."): if not domain.startswith("."):
domain = f".{domain}" domain = f".{domain}"
self.jar[domain.lower()] = simpleCookie self.jar[domain.lower()] = simple_cookie
def get(self, host: str) -> str: def get(self, host: str) -> str:
if not host: if not host:
return "" return ""
cookies = [] cookies = []
for domain, simpleCookie in self.jar.items(): for domain, _ in self.jar.items():
host = host.lower() host = host.lower()
if host.endswith(domain) or host == domain[1:]: if host.endswith(domain) or host == domain[1:]:
cookies.append(self.jar.get(domain)) cookies.append(self.jar.get(domain))
@ -66,7 +66,7 @@ class SimpleCookieJar:
None, None,
sorted( sorted(
[ [
"%s=%s" % (k, v.value) f"{k}={v.value}"
for cookie in filter(None, cookies) for cookie in filter(None, cookies)
for k, v in cookie.items() for k, v in cookie.items()
] ]

View file

@ -5,20 +5,20 @@ import time
from typing import Optional, Union from typing import Optional, Union
# websocket modules # websocket modules
from ._abnf import * from ._abnf import ABNF, STATUS_NORMAL, continuous_frame, frame_buffer
from ._exceptions import * from ._exceptions import WebSocketProtocolException, WebSocketConnectionClosedException
from ._handshake import * from ._handshake import SUPPORTED_REDIRECT_STATUSES, handshake
from ._http import * from ._http import connect, proxy_info
from ._logging import * from ._logging import debug, error, trace, isEnabledForError, isEnabledForTrace
from ._socket import * from ._socket import getdefaulttimeout, recv, send, sock_opt
from ._ssl_compat import * from ._ssl_compat import ssl
from ._utils import * from ._utils import NoLock
""" """
_core.py _core.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -259,7 +259,7 @@ class WebSocket:
try: try:
self.handshake_response = handshake(self.sock, url, *addrs, **options) self.handshake_response = handshake(self.sock, url, *addrs, **options)
for attempt in range(options.pop("redirect_limit", 3)): for _ in range(options.pop("redirect_limit", 3)):
if self.handshake_response.status in SUPPORTED_REDIRECT_STATUSES: if self.handshake_response.status in SUPPORTED_REDIRECT_STATUSES:
url = self.handshake_response.headers["location"] url = self.handshake_response.headers["location"]
self.sock.close() self.sock.close()

View file

@ -2,7 +2,7 @@
_exceptions.py _exceptions.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.

View file

@ -2,7 +2,7 @@
_handshake.py _handshake.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -23,10 +23,10 @@ from base64 import encodebytes as base64encode
from http import HTTPStatus from http import HTTPStatus
from ._cookiejar import SimpleCookieJar from ._cookiejar import SimpleCookieJar
from ._exceptions import * from ._exceptions import WebSocketException, WebSocketBadStatusException
from ._http import * from ._http import read_headers
from ._logging import * from ._logging import dump, error
from ._socket import * from ._socket import send
__all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"] __all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]

View file

@ -2,7 +2,7 @@
_http.py _http.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -21,11 +21,15 @@ import os
import socket import socket
from base64 import encodebytes as base64encode from base64 import encodebytes as base64encode
from ._exceptions import * from ._exceptions import (
from ._logging import * WebSocketAddressException,
from ._socket import * WebSocketException,
from ._ssl_compat import * WebSocketProxyException,
from ._url import * )
from ._logging import debug, dump, trace
from ._socket import DEFAULT_SOCKET_OPTION, recv_line, send
from ._ssl_compat import HAVE_SSL, ssl
from ._url import get_proxy_info, parse_url
__all__ = ["proxy_info", "connect", "read_headers"] __all__ = ["proxy_info", "connect", "read_headers"]
@ -283,22 +287,22 @@ def _wrap_sni_socket(sock: socket.socket, sslopt: dict, hostname, check_hostname
def _ssl_socket(sock: socket.socket, user_sslopt: dict, hostname): def _ssl_socket(sock: socket.socket, user_sslopt: dict, hostname):
sslopt: dict = dict(cert_reqs=ssl.CERT_REQUIRED) sslopt: dict = {"cert_reqs": ssl.CERT_REQUIRED}
sslopt.update(user_sslopt) sslopt.update(user_sslopt)
certPath = os.environ.get("WEBSOCKET_CLIENT_CA_BUNDLE") cert_path = os.environ.get("WEBSOCKET_CLIENT_CA_BUNDLE")
if ( if (
certPath cert_path
and os.path.isfile(certPath) and os.path.isfile(cert_path)
and user_sslopt.get("ca_certs", None) is None and user_sslopt.get("ca_certs", None) is None
): ):
sslopt["ca_certs"] = certPath sslopt["ca_certs"] = cert_path
elif ( elif (
certPath cert_path
and os.path.isdir(certPath) and os.path.isdir(cert_path)
and user_sslopt.get("ca_cert_path", None) is None and user_sslopt.get("ca_cert_path", None) is None
): ):
sslopt["ca_cert_path"] = certPath sslopt["ca_cert_path"] = cert_path
if sslopt.get("server_hostname", None): if sslopt.get("server_hostname", None):
hostname = sslopt["server_hostname"] hostname = sslopt["server_hostname"]
@ -327,7 +331,7 @@ def _tunnel(sock: socket.socket, host, port: int, auth) -> socket.socket:
send(sock, connect_header) send(sock, connect_header)
try: try:
status, resp_headers, status_message = read_headers(sock) status, _, _ = read_headers(sock)
except Exception as e: except Exception as e:
raise WebSocketProxyException(str(e)) raise WebSocketProxyException(str(e))

View file

@ -4,7 +4,7 @@ import logging
_logging.py _logging.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.

View file

@ -3,15 +3,18 @@ import selectors
import socket import socket
from typing import Union from typing import Union
from ._exceptions import * from ._exceptions import (
from ._ssl_compat import * WebSocketConnectionClosedException,
from ._utils import * WebSocketTimeoutException,
)
from ._ssl_compat import SSLError, SSLWantReadError, SSLWantWriteError
from ._utils import extract_error_code, extract_err_message
""" """
_socket.py _socket.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.

View file

@ -2,7 +2,7 @@
_ssl_compat.py _ssl_compat.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -16,11 +16,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
__all__ = ["HAVE_SSL", "ssl", "SSLError", "SSLWantReadError", "SSLWantWriteError"] __all__ = [
"HAVE_SSL",
"ssl",
"SSLError",
"SSLEOFError",
"SSLWantReadError",
"SSLWantWriteError",
]
try: try:
import ssl import ssl
from ssl import SSLError, SSLWantReadError, SSLWantWriteError from ssl import SSLError, SSLEOFError, SSLWantReadError, SSLWantWriteError
HAVE_SSL = True HAVE_SSL = True
except ImportError: except ImportError:
@ -28,6 +35,9 @@ except ImportError:
class SSLError(Exception): class SSLError(Exception):
pass pass
class SSLEOFError(Exception):
pass
class SSLWantReadError(Exception): class SSLWantReadError(Exception):
pass pass

View file

@ -3,12 +3,13 @@ import socket
import struct import struct
from typing import Optional from typing import Optional
from urllib.parse import unquote, urlparse from urllib.parse import unquote, urlparse
from ._exceptions import WebSocketProxyException
""" """
_url.py _url.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -167,6 +168,8 @@ def get_proxy_info(
return None, 0, None return None, 0, None
if proxy_host: if proxy_host:
if not proxy_port:
raise WebSocketProxyException("Cannot use port 0 when proxy_host specified")
port = proxy_port port = proxy_port
auth = proxy_auth auth = proxy_auth
return proxy_host, port, auth return proxy_host, port, auth

View file

@ -4,7 +4,7 @@ from typing import Union
_url.py _url.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.

View file

@ -4,7 +4,7 @@
wsdump.py wsdump.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.

View file

@ -10,7 +10,7 @@ import websockets
LOCAL_WS_SERVER_PORT = int(os.environ.get("LOCAL_WS_SERVER_PORT", "8765")) LOCAL_WS_SERVER_PORT = int(os.environ.get("LOCAL_WS_SERVER_PORT", "8765"))
async def echo(websocket, path): async def echo(websocket):
async for message in websocket: async for message in websocket:
await websocket.send(message) await websocket.send(message)

View file

@ -2,14 +2,14 @@
# #
import unittest import unittest
import websocket as ws from websocket._abnf import ABNF, frame_buffer
from websocket._abnf import * from websocket._exceptions import WebSocketProtocolException
""" """
test_abnf.py test_abnf.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -26,7 +26,7 @@ limitations under the License.
class ABNFTest(unittest.TestCase): class ABNFTest(unittest.TestCase):
def testInit(self): def test_init(self):
a = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_PING) a = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_PING)
self.assertEqual(a.fin, 0) self.assertEqual(a.fin, 0)
self.assertEqual(a.rsv1, 0) self.assertEqual(a.rsv1, 0)
@ -38,28 +38,28 @@ class ABNFTest(unittest.TestCase):
self.assertEqual(a_bad.rsv1, 1) self.assertEqual(a_bad.rsv1, 1)
self.assertEqual(a_bad.opcode, 77) self.assertEqual(a_bad.opcode, 77)
def testValidate(self): def test_validate(self):
a_invalid_ping = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_PING) a_invalid_ping = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_PING)
self.assertRaises( self.assertRaises(
ws._exceptions.WebSocketProtocolException, WebSocketProtocolException,
a_invalid_ping.validate, a_invalid_ping.validate,
skip_utf8_validation=False, skip_utf8_validation=False,
) )
a_bad_rsv_value = ABNF(0, 1, 0, 0, opcode=ABNF.OPCODE_TEXT) a_bad_rsv_value = ABNF(0, 1, 0, 0, opcode=ABNF.OPCODE_TEXT)
self.assertRaises( self.assertRaises(
ws._exceptions.WebSocketProtocolException, WebSocketProtocolException,
a_bad_rsv_value.validate, a_bad_rsv_value.validate,
skip_utf8_validation=False, skip_utf8_validation=False,
) )
a_bad_opcode = ABNF(0, 0, 0, 0, opcode=77) a_bad_opcode = ABNF(0, 0, 0, 0, opcode=77)
self.assertRaises( self.assertRaises(
ws._exceptions.WebSocketProtocolException, WebSocketProtocolException,
a_bad_opcode.validate, a_bad_opcode.validate,
skip_utf8_validation=False, skip_utf8_validation=False,
) )
a_bad_close_frame = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x01") a_bad_close_frame = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x01")
self.assertRaises( self.assertRaises(
ws._exceptions.WebSocketProtocolException, WebSocketProtocolException,
a_bad_close_frame.validate, a_bad_close_frame.validate,
skip_utf8_validation=False, skip_utf8_validation=False,
) )
@ -67,7 +67,7 @@ class ABNFTest(unittest.TestCase):
0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x01\x8a\xaa\xff\xdd" 0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x01\x8a\xaa\xff\xdd"
) )
self.assertRaises( self.assertRaises(
ws._exceptions.WebSocketProtocolException, WebSocketProtocolException,
a_bad_close_frame_2.validate, a_bad_close_frame_2.validate,
skip_utf8_validation=False, skip_utf8_validation=False,
) )
@ -75,12 +75,12 @@ class ABNFTest(unittest.TestCase):
0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x03\xe7" 0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x03\xe7"
) )
self.assertRaises( self.assertRaises(
ws._exceptions.WebSocketProtocolException, WebSocketProtocolException,
a_bad_close_frame_3.validate, a_bad_close_frame_3.validate,
skip_utf8_validation=True, skip_utf8_validation=True,
) )
def testMask(self): def test_mask(self):
abnf_none_data = ABNF( abnf_none_data = ABNF(
0, 0, 0, 0, opcode=ABNF.OPCODE_PING, mask_value=1, data=None 0, 0, 0, 0, opcode=ABNF.OPCODE_PING, mask_value=1, data=None
) )
@ -91,7 +91,7 @@ class ABNFTest(unittest.TestCase):
) )
self.assertEqual(abnf_str_data._get_masked(bytes_val), b"aaaa\x00") self.assertEqual(abnf_str_data._get_masked(bytes_val), b"aaaa\x00")
def testFormat(self): def test_format(self):
abnf_bad_rsv_bits = ABNF(2, 0, 0, 0, opcode=ABNF.OPCODE_TEXT) abnf_bad_rsv_bits = ABNF(2, 0, 0, 0, opcode=ABNF.OPCODE_TEXT)
self.assertRaises(ValueError, abnf_bad_rsv_bits.format) self.assertRaises(ValueError, abnf_bad_rsv_bits.format)
abnf_bad_opcode = ABNF(0, 0, 0, 0, opcode=5) abnf_bad_opcode = ABNF(0, 0, 0, 0, opcode=5)
@ -110,7 +110,7 @@ class ABNFTest(unittest.TestCase):
) )
self.assertEqual(b"\x01\x03\x01\x8a\xcc", abnf_no_mask.format()) self.assertEqual(b"\x01\x03\x01\x8a\xcc", abnf_no_mask.format())
def testFrameBuffer(self): def test_frame_buffer(self):
fb = frame_buffer(0, True) fb = frame_buffer(0, True)
self.assertEqual(fb.recv, 0) self.assertEqual(fb.recv, 0)
self.assertEqual(fb.skip_utf8_validation, True) self.assertEqual(fb.skip_utf8_validation, True)

View file

@ -12,7 +12,7 @@ import websocket as ws
test_app.py test_app.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -53,10 +53,13 @@ class WebSocketAppTest(unittest.TestCase):
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet() WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
WebSocketAppTest.on_error_data = WebSocketAppTest.NotSetYet() WebSocketAppTest.on_error_data = WebSocketAppTest.NotSetYet()
def close(self):
pass
@unittest.skipUnless( @unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
) )
def testKeepRunning(self): def test_keep_running(self):
"""A WebSocketApp should keep running as long as its self.keep_running """A WebSocketApp should keep running as long as its self.keep_running
is not False (in the boolean context). is not False (in the boolean context).
""" """
@ -69,7 +72,7 @@ class WebSocketAppTest(unittest.TestCase):
WebSocketAppTest.keep_running_open = self.keep_running WebSocketAppTest.keep_running_open = self.keep_running
self.keep_running = False self.keep_running = False
def on_message(wsapp, message): def on_message(_, message):
print(message) print(message)
self.close() self.close()
@ -87,7 +90,7 @@ class WebSocketAppTest(unittest.TestCase):
# @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled") # @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled")
@unittest.skipUnless(False, "Test disabled for now (requires rel)") @unittest.skipUnless(False, "Test disabled for now (requires rel)")
def testRunForeverDispatcher(self): def test_run_forever_dispatcher(self):
"""A WebSocketApp should keep running as long as its self.keep_running """A WebSocketApp should keep running as long as its self.keep_running
is not False (in the boolean context). is not False (in the boolean context).
""" """
@ -98,7 +101,7 @@ class WebSocketAppTest(unittest.TestCase):
self.recv() self.recv()
self.send("goodbye!") self.send("goodbye!")
def on_message(wsapp, message): def on_message(_, message):
print(message) print(message)
self.close() self.close()
@ -115,7 +118,7 @@ class WebSocketAppTest(unittest.TestCase):
@unittest.skipUnless( @unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
) )
def testRunForeverTeardownCleanExit(self): def test_run_forever_teardown_clean_exit(self):
"""The WebSocketApp.run_forever() method should return `False` when the application ends gracefully.""" """The WebSocketApp.run_forever() method should return `False` when the application ends gracefully."""
app = ws.WebSocketApp(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}") app = ws.WebSocketApp(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
threading.Timer(interval=0.2, function=app.close).start() threading.Timer(interval=0.2, function=app.close).start()
@ -123,7 +126,7 @@ class WebSocketAppTest(unittest.TestCase):
self.assertEqual(teardown, False) self.assertEqual(teardown, False)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testSockMaskKey(self): def test_sock_mask_key(self):
"""A WebSocketApp should forward the received mask_key function down """A WebSocketApp should forward the received mask_key function down
to the actual socket. to the actual socket.
""" """
@ -140,14 +143,14 @@ class WebSocketAppTest(unittest.TestCase):
self.assertEqual(id(app.get_mask_key), id(my_mask_key_func)) self.assertEqual(id(app.get_mask_key), id(my_mask_key_func))
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testInvalidPingIntervalPingTimeout(self): def test_invalid_ping_interval_ping_timeout(self):
"""Test exception handling if ping_interval < ping_timeout""" """Test exception handling if ping_interval < ping_timeout"""
def on_ping(app, msg): def on_ping(app, _):
print("Got a ping!") print("Got a ping!")
app.close() app.close()
def on_pong(app, msg): def on_pong(app, _):
print("Got a pong! No need to respond") print("Got a pong! No need to respond")
app.close() app.close()
@ -163,14 +166,14 @@ class WebSocketAppTest(unittest.TestCase):
) )
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testPingInterval(self): def test_ping_interval(self):
"""Test WebSocketApp proper ping functionality""" """Test WebSocketApp proper ping functionality"""
def on_ping(app, msg): def on_ping(app, _):
print("Got a ping!") print("Got a ping!")
app.close() app.close()
def on_pong(app, msg): def on_pong(app, _):
print("Got a pong! No need to respond") print("Got a pong! No need to respond")
app.close() app.close()
@ -182,7 +185,7 @@ class WebSocketAppTest(unittest.TestCase):
) )
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testOpcodeClose(self): def test_opcode_close(self):
"""Test WebSocketApp close opcode""" """Test WebSocketApp close opcode"""
app = ws.WebSocketApp("wss://tsock.us1.twilio.com/v3/wsconnect") app = ws.WebSocketApp("wss://tsock.us1.twilio.com/v3/wsconnect")
@ -197,7 +200,7 @@ class WebSocketAppTest(unittest.TestCase):
# app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload") # app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload")
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testBadPingInterval(self): def test_bad_ping_interval(self):
"""A WebSocketApp handling of negative ping_interval""" """A WebSocketApp handling of negative ping_interval"""
app = ws.WebSocketApp("wss://api-pub.bitfinex.com/ws/1") app = ws.WebSocketApp("wss://api-pub.bitfinex.com/ws/1")
self.assertRaises( self.assertRaises(
@ -208,7 +211,7 @@ class WebSocketAppTest(unittest.TestCase):
) )
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testBadPingTimeout(self): def test_bad_ping_timeout(self):
"""A WebSocketApp handling of negative ping_timeout""" """A WebSocketApp handling of negative ping_timeout"""
app = ws.WebSocketApp("wss://api-pub.bitfinex.com/ws/1") app = ws.WebSocketApp("wss://api-pub.bitfinex.com/ws/1")
self.assertRaises( self.assertRaises(
@ -219,7 +222,7 @@ class WebSocketAppTest(unittest.TestCase):
) )
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testCloseStatusCode(self): def test_close_status_code(self):
"""Test extraction of close frame status code and close reason in WebSocketApp""" """Test extraction of close frame status code and close reason in WebSocketApp"""
def on_close(wsapp, close_status_code, close_msg): def on_close(wsapp, close_status_code, close_msg):
@ -249,7 +252,7 @@ class WebSocketAppTest(unittest.TestCase):
@unittest.skipUnless( @unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
) )
def testCallbackFunctionException(self): def test_callback_function_exception(self):
"""Test callback function exception handling""" """Test callback function exception handling"""
exc = None exc = None
@ -264,7 +267,7 @@ class WebSocketAppTest(unittest.TestCase):
nonlocal exc nonlocal exc
exc = err exc = err
def on_pong(app, msg): def on_pong(app, _):
app.close() app.close()
app = ws.WebSocketApp( app = ws.WebSocketApp(
@ -282,7 +285,7 @@ class WebSocketAppTest(unittest.TestCase):
@unittest.skipUnless( @unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
) )
def testCallbackMethodException(self): def test_callback_method_exception(self):
"""Test callback method exception handling""" """Test callback method exception handling"""
class Callbacks: class Callbacks:
@ -297,14 +300,14 @@ class WebSocketAppTest(unittest.TestCase):
) )
self.app.run_forever(ping_interval=2, ping_timeout=1) self.app.run_forever(ping_interval=2, ping_timeout=1)
def on_open(self, app): def on_open(self, _):
raise RuntimeError("Callback failed") raise RuntimeError("Callback failed")
def on_error(self, app, err): def on_error(self, app, err):
self.passed_app = app self.passed_app = app
self.exc = err self.exc = err
def on_pong(self, app, msg): def on_pong(self, app, _):
app.close() app.close()
callbacks = Callbacks() callbacks = Callbacks()
@ -316,16 +319,16 @@ class WebSocketAppTest(unittest.TestCase):
@unittest.skipUnless( @unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
) )
def testReconnect(self): def test_reconnect(self):
"""Test reconnect""" """Test reconnect"""
pong_count = 0 pong_count = 0
exc = None exc = None
def on_error(app, err): def on_error(_, err):
nonlocal exc nonlocal exc
exc = err exc = err
def on_pong(app, msg): def on_pong(app, _):
nonlocal pong_count nonlocal pong_count
pong_count += 1 pong_count += 1
if pong_count == 1: if pong_count == 1:

View file

@ -6,7 +6,7 @@ from websocket._cookiejar import SimpleCookieJar
test_cookiejar.py test_cookiejar.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -23,7 +23,7 @@ limitations under the License.
class CookieJarTest(unittest.TestCase): class CookieJarTest(unittest.TestCase):
def testAdd(self): def test_add(self):
cookie_jar = SimpleCookieJar() cookie_jar = SimpleCookieJar()
cookie_jar.add("") cookie_jar.add("")
self.assertFalse( self.assertFalse(
@ -67,7 +67,7 @@ class CookieJarTest(unittest.TestCase):
self.assertEqual(cookie_jar.get("xyz"), "e=f") self.assertEqual(cookie_jar.get("xyz"), "e=f")
self.assertEqual(cookie_jar.get("something"), "") self.assertEqual(cookie_jar.get("something"), "")
def testSet(self): def test_set(self):
cookie_jar = SimpleCookieJar() cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b") cookie_jar.set("a=b")
self.assertFalse( self.assertFalse(
@ -104,7 +104,7 @@ class CookieJarTest(unittest.TestCase):
self.assertEqual(cookie_jar.get("xyz"), "e=f") self.assertEqual(cookie_jar.get("xyz"), "e=f")
self.assertEqual(cookie_jar.get("something"), "") self.assertEqual(cookie_jar.get("something"), "")
def testGet(self): def test_get(self):
cookie_jar = SimpleCookieJar() cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc.com") cookie_jar.set("a=b; c=d; domain=abc.com")
self.assertEqual(cookie_jar.get("abc.com"), "a=b; c=d") self.assertEqual(cookie_jar.get("abc.com"), "a=b; c=d")

View file

@ -7,7 +7,7 @@ import ssl
import unittest import unittest
import websocket import websocket
import websocket as ws from websocket._exceptions import WebSocketProxyException, WebSocketException
from websocket._http import ( from websocket._http import (
_get_addrinfo_list, _get_addrinfo_list,
_start_proxied_socket, _start_proxied_socket,
@ -15,13 +15,14 @@ from websocket._http import (
connect, connect,
proxy_info, proxy_info,
read_headers, read_headers,
HAVE_PYTHON_SOCKS,
) )
""" """
test_http.py test_http.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -93,20 +94,18 @@ class OptsList:
class HttpTest(unittest.TestCase): class HttpTest(unittest.TestCase):
def testReadHeader(self): def test_read_header(self):
status, header, status_message = read_headers( status, header, _ = read_headers(HeaderSockMock("data/header01.txt"))
HeaderSockMock("data/header01.txt")
)
self.assertEqual(status, 101) self.assertEqual(status, 101)
self.assertEqual(header["connection"], "Upgrade") self.assertEqual(header["connection"], "Upgrade")
# header02.txt is intentionally malformed # header02.txt is intentionally malformed
self.assertRaises( self.assertRaises(
ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt") WebSocketException, read_headers, HeaderSockMock("data/header02.txt")
) )
def testTunnel(self): def test_tunnel(self):
self.assertRaises( self.assertRaises(
ws.WebSocketProxyException, WebSocketProxyException,
_tunnel, _tunnel,
HeaderSockMock("data/header01.txt"), HeaderSockMock("data/header01.txt"),
"example.com", "example.com",
@ -114,7 +113,7 @@ class HttpTest(unittest.TestCase):
("username", "password"), ("username", "password"),
) )
self.assertRaises( self.assertRaises(
ws.WebSocketProxyException, WebSocketProxyException,
_tunnel, _tunnel,
HeaderSockMock("data/header02.txt"), HeaderSockMock("data/header02.txt"),
"example.com", "example.com",
@ -123,9 +122,9 @@ class HttpTest(unittest.TestCase):
) )
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testConnect(self): def test_connect(self):
# Not currently testing an actual proxy connection, so just check whether proxy errors are raised. This requires internet for a DNS lookup # Not currently testing an actual proxy connection, so just check whether proxy errors are raised. This requires internet for a DNS lookup
if ws._http.HAVE_PYTHON_SOCKS: if HAVE_PYTHON_SOCKS:
# Need this check, otherwise case where python_socks is not installed triggers # Need this check, otherwise case where python_socks is not installed triggers
# websocket._exceptions.WebSocketException: Python Socks is needed for SOCKS proxying but is not available # websocket._exceptions.WebSocketException: Python Socks is needed for SOCKS proxying but is not available
self.assertRaises( self.assertRaises(
@ -244,7 +243,7 @@ class HttpTest(unittest.TestCase):
@unittest.skipUnless( @unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
) )
def testProxyConnect(self): def test_proxy_connect(self):
ws = websocket.WebSocket() ws = websocket.WebSocket()
ws.connect( ws.connect(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
@ -289,7 +288,7 @@ class HttpTest(unittest.TestCase):
# TODO: Test SOCKS4 and SOCK5 proxies with unit tests # TODO: Test SOCKS4 and SOCK5 proxies with unit tests
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testSSLopt(self): def test_sslopt(self):
ssloptions = { ssloptions = {
"check_hostname": False, "check_hostname": False,
"server_hostname": "ServerName", "server_hostname": "ServerName",
@ -315,7 +314,7 @@ class HttpTest(unittest.TestCase):
ws_ssl2.connect("wss://api.bitfinex.com/ws/2") ws_ssl2.connect("wss://api.bitfinex.com/ws/2")
ws_ssl2.close ws_ssl2.close
def testProxyInfo(self): def test_proxy_info(self):
self.assertEqual( self.assertEqual(
proxy_info( proxy_info(
http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http" http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http"

View file

@ -9,12 +9,13 @@ from websocket._url import (
get_proxy_info, get_proxy_info,
parse_url, parse_url,
) )
from websocket._exceptions import WebSocketProxyException
""" """
test_url.py test_url.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -36,7 +37,7 @@ class UrlTest(unittest.TestCase):
self.assertTrue(_is_address_in_network("127.1.0.1", "127.0.0.0/8")) self.assertTrue(_is_address_in_network("127.1.0.1", "127.0.0.0/8"))
self.assertFalse(_is_address_in_network("127.1.0.1", "127.0.0.0/24")) self.assertFalse(_is_address_in_network("127.1.0.1", "127.0.0.0/24"))
def testParseUrl(self): def test_parse_url(self):
p = parse_url("ws://www.example.com/r") p = parse_url("ws://www.example.com/r")
self.assertEqual(p[0], "www.example.com") self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80) self.assertEqual(p[1], 80)
@ -130,9 +131,13 @@ class IsNoProxyHostTest(unittest.TestCase):
elif "no_proxy" in os.environ: elif "no_proxy" in os.environ:
del os.environ["no_proxy"] del os.environ["no_proxy"]
def testMatchAll(self): def test_match_all(self):
self.assertTrue(_is_no_proxy_host("any.websocket.org", ["*"])) self.assertTrue(_is_no_proxy_host("any.websocket.org", ["*"]))
self.assertTrue(_is_no_proxy_host("192.168.0.1", ["*"])) self.assertTrue(_is_no_proxy_host("192.168.0.1", ["*"]))
self.assertFalse(_is_no_proxy_host("192.168.0.1", ["192.168.1.1"]))
self.assertFalse(
_is_no_proxy_host("any.websocket.org", ["other.websocket.org"])
)
self.assertTrue( self.assertTrue(
_is_no_proxy_host("any.websocket.org", ["other.websocket.org", "*"]) _is_no_proxy_host("any.websocket.org", ["other.websocket.org", "*"])
) )
@ -142,7 +147,7 @@ class IsNoProxyHostTest(unittest.TestCase):
os.environ["no_proxy"] = "other.websocket.org, *" os.environ["no_proxy"] = "other.websocket.org, *"
self.assertTrue(_is_no_proxy_host("any.websocket.org", None)) self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
def testIpAddress(self): def test_ip_address(self):
self.assertTrue(_is_no_proxy_host("127.0.0.1", ["127.0.0.1"])) self.assertTrue(_is_no_proxy_host("127.0.0.1", ["127.0.0.1"]))
self.assertFalse(_is_no_proxy_host("127.0.0.2", ["127.0.0.1"])) self.assertFalse(_is_no_proxy_host("127.0.0.2", ["127.0.0.1"]))
self.assertTrue( self.assertTrue(
@ -158,7 +163,7 @@ class IsNoProxyHostTest(unittest.TestCase):
self.assertTrue(_is_no_proxy_host("127.0.0.1", None)) self.assertTrue(_is_no_proxy_host("127.0.0.1", None))
self.assertFalse(_is_no_proxy_host("127.0.0.2", None)) self.assertFalse(_is_no_proxy_host("127.0.0.2", None))
def testIpAddressInRange(self): def test_ip_address_in_range(self):
self.assertTrue(_is_no_proxy_host("127.0.0.1", ["127.0.0.0/8"])) self.assertTrue(_is_no_proxy_host("127.0.0.1", ["127.0.0.0/8"]))
self.assertTrue(_is_no_proxy_host("127.0.0.2", ["127.0.0.0/8"])) self.assertTrue(_is_no_proxy_host("127.0.0.2", ["127.0.0.0/8"]))
self.assertFalse(_is_no_proxy_host("127.1.0.1", ["127.0.0.0/24"])) self.assertFalse(_is_no_proxy_host("127.1.0.1", ["127.0.0.0/24"]))
@ -168,7 +173,7 @@ class IsNoProxyHostTest(unittest.TestCase):
os.environ["no_proxy"] = "127.0.0.0/24" os.environ["no_proxy"] = "127.0.0.0/24"
self.assertFalse(_is_no_proxy_host("127.1.0.1", None)) self.assertFalse(_is_no_proxy_host("127.1.0.1", None))
def testHostnameMatch(self): def test_hostname_match(self):
self.assertTrue(_is_no_proxy_host("my.websocket.org", ["my.websocket.org"])) self.assertTrue(_is_no_proxy_host("my.websocket.org", ["my.websocket.org"]))
self.assertTrue( self.assertTrue(
_is_no_proxy_host( _is_no_proxy_host(
@ -182,7 +187,7 @@ class IsNoProxyHostTest(unittest.TestCase):
os.environ["no_proxy"] = "other.websocket.org, my.websocket.org" os.environ["no_proxy"] = "other.websocket.org, my.websocket.org"
self.assertTrue(_is_no_proxy_host("my.websocket.org", None)) self.assertTrue(_is_no_proxy_host("my.websocket.org", None))
def testHostnameMatchDomain(self): def test_hostname_match_domain(self):
self.assertTrue(_is_no_proxy_host("any.websocket.org", [".websocket.org"])) self.assertTrue(_is_no_proxy_host("any.websocket.org", [".websocket.org"]))
self.assertTrue(_is_no_proxy_host("my.other.websocket.org", [".websocket.org"])) self.assertTrue(_is_no_proxy_host("my.other.websocket.org", [".websocket.org"]))
self.assertTrue( self.assertTrue(
@ -227,10 +232,13 @@ class ProxyInfoTest(unittest.TestCase):
elif "no_proxy" in os.environ: elif "no_proxy" in os.environ:
del os.environ["no_proxy"] del os.environ["no_proxy"]
def testProxyFromArgs(self): def test_proxy_from_args(self):
self.assertEqual( self.assertRaises(
get_proxy_info("echo.websocket.events", False, proxy_host="localhost"), WebSocketProxyException,
("localhost", 0, None), get_proxy_info,
"echo.websocket.events",
False,
proxy_host="localhost",
) )
self.assertEqual( self.assertEqual(
get_proxy_info( get_proxy_info(
@ -238,10 +246,6 @@ class ProxyInfoTest(unittest.TestCase):
), ),
("localhost", 3128, None), ("localhost", 3128, None),
) )
self.assertEqual(
get_proxy_info("echo.websocket.events", True, proxy_host="localhost"),
("localhost", 0, None),
)
self.assertEqual( self.assertEqual(
get_proxy_info( get_proxy_info(
"echo.websocket.events", True, proxy_host="localhost", proxy_port=3128 "echo.websocket.events", True, proxy_host="localhost", proxy_port=3128
@ -254,9 +258,10 @@ class ProxyInfoTest(unittest.TestCase):
"echo.websocket.events", "echo.websocket.events",
False, False,
proxy_host="localhost", proxy_host="localhost",
proxy_port=9001,
proxy_auth=("a", "b"), proxy_auth=("a", "b"),
), ),
("localhost", 0, ("a", "b")), ("localhost", 9001, ("a", "b")),
) )
self.assertEqual( self.assertEqual(
get_proxy_info( get_proxy_info(
@ -273,9 +278,10 @@ class ProxyInfoTest(unittest.TestCase):
"echo.websocket.events", "echo.websocket.events",
True, True,
proxy_host="localhost", proxy_host="localhost",
proxy_port=8765,
proxy_auth=("a", "b"), proxy_auth=("a", "b"),
), ),
("localhost", 0, ("a", "b")), ("localhost", 8765, ("a", "b")),
) )
self.assertEqual( self.assertEqual(
get_proxy_info( get_proxy_info(
@ -311,7 +317,18 @@ class ProxyInfoTest(unittest.TestCase):
(None, 0, None), (None, 0, None),
) )
def testProxyFromEnv(self): self.assertEqual(
get_proxy_info(
"echo.websocket.events",
True,
proxy_host="localhost",
proxy_port=3128,
no_proxy=[".websocket.events"],
),
(None, 0, None),
)
def test_proxy_from_env(self):
os.environ["http_proxy"] = "http://localhost/" os.environ["http_proxy"] = "http://localhost/"
self.assertEqual( self.assertEqual(
get_proxy_info("echo.websocket.events", False), ("localhost", None, None) get_proxy_info("echo.websocket.events", False), ("localhost", None, None)

View file

@ -7,6 +7,7 @@ import unittest
from base64 import decodebytes as base64decode from base64 import decodebytes as base64decode
import websocket as ws import websocket as ws
from websocket._exceptions import WebSocketBadStatusException, WebSocketAddressException
from websocket._handshake import _create_sec_websocket_key from websocket._handshake import _create_sec_websocket_key
from websocket._handshake import _validate as _validate_header from websocket._handshake import _validate as _validate_header
from websocket._http import read_headers from websocket._http import read_headers
@ -16,7 +17,7 @@ from websocket._utils import validate_utf8
test_websocket.py test_websocket.py
websocket - WebSocket client library for Python websocket - WebSocket client library for Python
Copyright 2023 engn33r Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
@ -33,7 +34,6 @@ limitations under the License.
try: try:
import ssl import ssl
from ssl import SSLError
except ImportError: except ImportError:
# dummy class of SSLError for ssl none-support environment. # dummy class of SSLError for ssl none-support environment.
class SSLError(Exception): class SSLError(Exception):
@ -95,24 +95,24 @@ class WebSocketTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
pass pass
def testDefaultTimeout(self): def test_default_timeout(self):
self.assertEqual(ws.getdefaulttimeout(), None) self.assertEqual(ws.getdefaulttimeout(), None)
ws.setdefaulttimeout(10) ws.setdefaulttimeout(10)
self.assertEqual(ws.getdefaulttimeout(), 10) self.assertEqual(ws.getdefaulttimeout(), 10)
ws.setdefaulttimeout(None) ws.setdefaulttimeout(None)
def testWSKey(self): def test_ws_key(self):
key = _create_sec_websocket_key() key = _create_sec_websocket_key()
self.assertTrue(key != 24) self.assertTrue(key != 24)
self.assertTrue("¥n" not in key) self.assertTrue("¥n" not in key)
def testNonce(self): def test_nonce(self):
"""WebSocket key should be a random 16-byte nonce.""" """WebSocket key should be a random 16-byte nonce."""
key = _create_sec_websocket_key() key = _create_sec_websocket_key()
nonce = base64decode(key.encode("utf-8")) nonce = base64decode(key.encode("utf-8"))
self.assertEqual(16, len(nonce)) self.assertEqual(16, len(nonce))
def testWsUtils(self): def test_ws_utils(self):
key = "c6b8hTg4EeGb2gQMztV1/g==" key = "c6b8hTg4EeGb2gQMztV1/g=="
required_header = { required_header = {
"upgrade": "websocket", "upgrade": "websocket",
@ -157,16 +157,12 @@ class WebSocketTest(unittest.TestCase):
# This case will print out a logging error using the error() function, but that is expected # This case will print out a logging error using the error() function, but that is expected
self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (False, None)) self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (False, None))
def testReadHeader(self): def test_read_header(self):
status, header, status_message = read_headers( status, header, _ = read_headers(HeaderSockMock("data/header01.txt"))
HeaderSockMock("data/header01.txt")
)
self.assertEqual(status, 101) self.assertEqual(status, 101)
self.assertEqual(header["connection"], "Upgrade") self.assertEqual(header["connection"], "Upgrade")
status, header, status_message = read_headers( status, header, _ = read_headers(HeaderSockMock("data/header03.txt"))
HeaderSockMock("data/header03.txt")
)
self.assertEqual(status, 101) self.assertEqual(status, 101)
self.assertEqual(header["connection"], "Upgrade, Keep-Alive") self.assertEqual(header["connection"], "Upgrade, Keep-Alive")
@ -175,7 +171,7 @@ class WebSocketTest(unittest.TestCase):
ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt") ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt")
) )
def testSend(self): def test_send(self):
# TODO: add longer frame data # TODO: add longer frame data
sock = ws.WebSocket() sock = ws.WebSocket()
sock.set_mask_key(create_mask_key) sock.set_mask_key(create_mask_key)
@ -194,7 +190,7 @@ class WebSocketTest(unittest.TestCase):
self.assertEqual(sock.send_binary(b"1111111111101"), 19) self.assertEqual(sock.send_binary(b"1111111111101"), 19)
def testRecv(self): def test_recv(self):
# TODO: add longer frame data # TODO: add longer frame data
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
@ -210,7 +206,7 @@ class WebSocketTest(unittest.TestCase):
self.assertEqual(data, "Hello") self.assertEqual(data, "Hello")
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testIter(self): def test_iter(self):
count = 2 count = 2
s = ws.create_connection("wss://api.bitfinex.com/ws/2") s = ws.create_connection("wss://api.bitfinex.com/ws/2")
s.send('{"event": "subscribe", "channel": "ticker"}') s.send('{"event": "subscribe", "channel": "ticker"}')
@ -220,11 +216,11 @@ class WebSocketTest(unittest.TestCase):
break break
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testNext(self): def test_next(self):
sock = ws.create_connection("wss://api.bitfinex.com/ws/2") sock = ws.create_connection("wss://api.bitfinex.com/ws/2")
self.assertEqual(str, type(next(sock))) self.assertEqual(str, type(next(sock)))
def testInternalRecvStrict(self): def test_internal_recv_strict(self):
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
s.add_packet(b"foo") s.add_packet(b"foo")
@ -241,7 +237,7 @@ class WebSocketTest(unittest.TestCase):
with self.assertRaises(ws.WebSocketConnectionClosedException): with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.frame_buffer.recv_strict(1) sock.frame_buffer.recv_strict(1)
def testRecvTimeout(self): def test_recv_timeout(self):
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
s.add_packet(b"\x81") s.add_packet(b"\x81")
@ -258,7 +254,7 @@ class WebSocketTest(unittest.TestCase):
with self.assertRaises(ws.WebSocketConnectionClosedException): with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv() sock.recv()
def testRecvWithSimpleFragmentation(self): def test_recv_with_simple_fragmentation(self):
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Brevity is " # OPCODE=TEXT, FIN=0, MSG="Brevity is "
@ -270,7 +266,7 @@ class WebSocketTest(unittest.TestCase):
with self.assertRaises(ws.WebSocketConnectionClosedException): with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv() sock.recv()
def testRecvWithFireEventOfFragmentation(self): def test_recv_with_fire_event_of_fragmentation(self):
sock = ws.WebSocket(fire_cont_frame=True) sock = ws.WebSocket(fire_cont_frame=True)
s = sock.sock = SockMock() s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Brevity is " # OPCODE=TEXT, FIN=0, MSG="Brevity is "
@ -296,7 +292,7 @@ class WebSocketTest(unittest.TestCase):
with self.assertRaises(ws.WebSocketConnectionClosedException): with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv() sock.recv()
def testClose(self): def test_close(self):
sock = ws.WebSocket() sock = ws.WebSocket()
sock.connected = True sock.connected = True
sock.close sock.close
@ -308,14 +304,14 @@ class WebSocketTest(unittest.TestCase):
sock.recv() sock.recv()
self.assertEqual(sock.connected, False) self.assertEqual(sock.connected, False)
def testRecvContFragmentation(self): def test_recv_cont_fragmentation(self):
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
# OPCODE=CONT, FIN=1, MSG="the soul of wit" # OPCODE=CONT, FIN=1, MSG="the soul of wit"
s.add_packet(b"\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17") s.add_packet(b"\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")
self.assertRaises(ws.WebSocketException, sock.recv) self.assertRaises(ws.WebSocketException, sock.recv)
def testRecvWithProlongedFragmentation(self): def test_recv_with_prolonged_fragmentation(self):
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, " # OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, "
@ -331,7 +327,7 @@ class WebSocketTest(unittest.TestCase):
with self.assertRaises(ws.WebSocketConnectionClosedException): with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv() sock.recv()
def testRecvWithFragmentationAndControlFrame(self): def test_recv_with_fragmentation_and_control_frame(self):
sock = ws.WebSocket() sock = ws.WebSocket()
sock.set_mask_key(create_mask_key) sock.set_mask_key(create_mask_key)
s = sock.sock = SockMock() s = sock.sock = SockMock()
@ -352,7 +348,7 @@ class WebSocketTest(unittest.TestCase):
@unittest.skipUnless( @unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
) )
def testWebSocket(self): def test_websocket(self):
s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}") s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
self.assertNotEqual(s, None) self.assertNotEqual(s, None)
s.send("Hello, World") s.send("Hello, World")
@ -369,7 +365,7 @@ class WebSocketTest(unittest.TestCase):
@unittest.skipUnless( @unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
) )
def testPingPong(self): def test_ping_pong(self):
s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}") s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
self.assertNotEqual(s, None) self.assertNotEqual(s, None)
s.ping("Hello") s.ping("Hello")
@ -377,17 +373,13 @@ class WebSocketTest(unittest.TestCase):
s.close() s.close()
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testSupportRedirect(self): def test_support_redirect(self):
s = ws.WebSocket() s = ws.WebSocket()
self.assertRaises( self.assertRaises(WebSocketBadStatusException, s.connect, "ws://google.com/")
ws._exceptions.WebSocketBadStatusException, s.connect, "ws://google.com/"
)
# Need to find a URL that has a redirect code leading to a websocket # Need to find a URL that has a redirect code leading to a websocket
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testSecureWebSocket(self): def test_secure_websocket(self):
import ssl
s = ws.create_connection("wss://api.bitfinex.com/ws/2") s = ws.create_connection("wss://api.bitfinex.com/ws/2")
self.assertNotEqual(s, None) self.assertNotEqual(s, None)
self.assertTrue(isinstance(s.sock, ssl.SSLSocket)) self.assertTrue(isinstance(s.sock, ssl.SSLSocket))
@ -401,7 +393,7 @@ class WebSocketTest(unittest.TestCase):
@unittest.skipUnless( @unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
) )
def testWebSocketWithCustomHeader(self): def test_websocket_with_custom_header(self):
s = ws.create_connection( s = ws.create_connection(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
headers={"User-Agent": "PythonWebsocketClient"}, headers={"User-Agent": "PythonWebsocketClient"},
@ -417,7 +409,7 @@ class WebSocketTest(unittest.TestCase):
@unittest.skipUnless( @unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
) )
def testAfterClose(self): def test_after_close(self):
s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}") s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
self.assertNotEqual(s, None) self.assertNotEqual(s, None)
s.close() s.close()
@ -429,7 +421,7 @@ class SockOptTest(unittest.TestCase):
@unittest.skipUnless( @unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled" TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
) )
def testSockOpt(self): def test_sockopt(self):
sockopt = ((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),) sockopt = ((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),)
s = ws.create_connection( s = ws.create_connection(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", sockopt=sockopt f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", sockopt=sockopt
@ -441,7 +433,7 @@ class SockOptTest(unittest.TestCase):
class UtilsTest(unittest.TestCase): class UtilsTest(unittest.TestCase):
def testUtf8Validator(self): def test_utf8_validator(self):
state = validate_utf8(b"\xf0\x90\x80\x80") state = validate_utf8(b"\xf0\x90\x80\x80")
self.assertEqual(state, True) self.assertEqual(state, True)
state = validate_utf8( state = validate_utf8(
@ -454,7 +446,7 @@ class UtilsTest(unittest.TestCase):
class HandshakeTest(unittest.TestCase): class HandshakeTest(unittest.TestCase):
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_http_SSL(self): def test_http_ssl(self):
websock1 = ws.WebSocket( websock1 = ws.WebSocket(
sslopt={"cert_chain": ssl.get_default_verify_paths().capath}, sslopt={"cert_chain": ssl.get_default_verify_paths().capath},
enable_multithread=False, enable_multithread=False,
@ -466,7 +458,7 @@ class HandshakeTest(unittest.TestCase):
) )
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testManualHeaders(self): def test_manual_headers(self):
websock3 = ws.WebSocket( websock3 = ws.WebSocket(
sslopt={ sslopt={
"ca_certs": ssl.get_default_verify_paths().cafile, "ca_certs": ssl.get_default_verify_paths().cafile,
@ -474,7 +466,7 @@ class HandshakeTest(unittest.TestCase):
} }
) )
self.assertRaises( self.assertRaises(
ws._exceptions.WebSocketBadStatusException, WebSocketBadStatusException,
websock3.connect, websock3.connect,
"wss://api.bitfinex.com/ws/2", "wss://api.bitfinex.com/ws/2",
cookie="chocolate", cookie="chocolate",
@ -490,16 +482,14 @@ class HandshakeTest(unittest.TestCase):
}, },
) )
def testIPv6(self): def test_ipv6(self):
websock2 = ws.WebSocket() websock2 = ws.WebSocket()
self.assertRaises(ValueError, websock2.connect, "2001:4860:4860::8888") self.assertRaises(ValueError, websock2.connect, "2001:4860:4860::8888")
def testBadURLs(self): def test_bad_urls(self):
websock3 = ws.WebSocket() websock3 = ws.WebSocket()
self.assertRaises(ValueError, websock3.connect, "ws//example.com") self.assertRaises(ValueError, websock3.connect, "ws//example.com")
self.assertRaises( self.assertRaises(WebSocketAddressException, websock3.connect, "ws://example")
ws.WebSocketAddressException, websock3.connect, "ws://example"
)
self.assertRaises(ValueError, websock3.connect, "example.com") self.assertRaises(ValueError, websock3.connect, "example.com")

View file

@ -13,14 +13,10 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Tautulli. If not, see <http://www.gnu.org/licenses/>. # along with Tautulli. If not, see <http://www.gnu.org/licenses/>.
from __future__ import unicode_literals
from future.builtins import range
from future.builtins import str
import ctypes import ctypes
import datetime import datetime
import os import os
import future.moves.queue as queue import queue
import sqlite3 import sqlite3
import sys import sys
import subprocess import subprocess
@ -39,31 +35,6 @@ from apscheduler.triggers.interval import IntervalTrigger
from ga4mp import GtagMP from ga4mp import GtagMP
import pytz import pytz
PYTHON2 = sys.version_info[0] == 2
if PYTHON2:
import activity_handler
import activity_pinger
import common
import database
import datafactory
import exporter
import helpers
import libraries
import logger
import mobile_app
import newsletters
import newsletter_handler
import notification_handler
import notifiers
import plex
import plextv
import users
import versioncheck
import web_socket
import webstart
import config
else:
from plexpy import activity_handler from plexpy import activity_handler
from plexpy import activity_pinger from plexpy import activity_pinger
from plexpy import common from plexpy import common
@ -214,7 +185,6 @@ def initialize(config_file):
logger.initLogger(console=not QUIET, log_dir=CONFIG.LOG_DIR if log_writable else None, logger.initLogger(console=not QUIET, log_dir=CONFIG.LOG_DIR if log_writable else None,
verbose=VERBOSE) verbose=VERBOSE)
if not PYTHON2:
os.environ['PLEXAPI_CONFIG_PATH'] = os.path.join(DATA_DIR, 'plexapi.config.ini') os.environ['PLEXAPI_CONFIG_PATH'] = os.path.join(DATA_DIR, 'plexapi.config.ini')
os.environ['PLEXAPI_LOG_PATH'] = os.path.join(CONFIG.LOG_DIR, 'plexapi.log') os.environ['PLEXAPI_LOG_PATH'] = os.path.join(CONFIG.LOG_DIR, 'plexapi.log')
os.environ['PLEXAPI_LOG_LEVEL'] = 'DEBUG' os.environ['PLEXAPI_LOG_LEVEL'] = 'DEBUG'

View file

@ -13,10 +13,6 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Tautulli. If not, see <http://www.gnu.org/licenses/>. # along with Tautulli. If not, see <http://www.gnu.org/licenses/>.
from __future__ import unicode_literals
from future.builtins import str
from future.builtins import object
import datetime import datetime
import os import os
import time import time
@ -25,15 +21,6 @@ from apscheduler.triggers.date import DateTrigger
import pytz import pytz
import plexpy import plexpy
if plexpy.PYTHON2:
import activity_processor
import common
import datafactory
import helpers
import logger
import notification_handler
import pmsconnect
else:
from plexpy import activity_processor from plexpy import activity_processor
from plexpy import common from plexpy import common
from plexpy import datafactory from plexpy import datafactory

View file

@ -13,29 +13,13 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Tautulli. If not, see <http://www.gnu.org/licenses/>. # along with Tautulli. If not, see <http://www.gnu.org/licenses/>.
from __future__ import unicode_literals
from future.builtins import str
import threading import threading
import plexpy import plexpy
if plexpy.PYTHON2:
import activity_handler
import activity_processor
import database
import helpers
import libraries
import logger
import notification_handler
import plextv
import pmsconnect
import web_socket
else:
from plexpy import activity_handler from plexpy import activity_handler
from plexpy import activity_processor from plexpy import activity_processor
from plexpy import database from plexpy import database
from plexpy import helpers from plexpy import helpers
from plexpy import libraries
from plexpy import logger from plexpy import logger
from plexpy import notification_handler from plexpy import notification_handler
from plexpy import plextv from plexpy import plextv

View file

@ -13,22 +13,10 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with Tautulli. If not, see <http://www.gnu.org/licenses/>. # along with Tautulli. If not, see <http://www.gnu.org/licenses/>.
from __future__ import unicode_literals
from future.builtins import str
from future.builtins import object
from collections import defaultdict from collections import defaultdict
import json import json
import plexpy import plexpy
if plexpy.PYTHON2:
import database
import helpers
import libraries
import logger
import pmsconnect
import users
else:
from plexpy import database from plexpy import database
from plexpy import helpers from plexpy import helpers
from plexpy import libraries from plexpy import libraries

Some files were not shown because too many files have changed in this diff Show more