Merge branch 'nightly' into dependabot/pip/nightly/paho-mqtt-2.1.0

Commit a05752030b by JonnyWong16, 2024-05-09 22:28:43 -07:00, committed by GitHub.
139 changed files with 8367 additions and 3353 deletions


@@ -23,7 +23,6 @@ import sys
# Ensure lib added to path, before any other imports
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'lib'))
from future.builtins import str
import argparse
import datetime


@@ -212,28 +212,6 @@
</div>
</div>
</div>
<% from plexpy.helpers import anon_url %>
<div id="python2-modal" class="modal fade wide" tabindex="-1" role="dialog" aria-labelledby="python2-modal">
<div class="modal-dialog" role="document">
<div class="modal-content">
<div class="modal-header">
<button type="button" class="close" data-dismiss="modal" aria-hidden="true"><i class="fa fa-remove"></i></button>
<h4 class="modal-title">Unable to Update</h4>
</div>
<div class="modal-body" style="text-align: center;">
<p>Tautulli is still running using Python 2 and cannot be updated past v2.6.3.</p>
<p>Python 3 is required to continue receiving updates.</p>
<p>
<strong>Please see the <a href="${anon_url('https://github.com/Tautulli/Tautulli/wiki/Upgrading-to-Python-3-%28Tautulli-v2.5%29')}" target="_blank" rel="noreferrer">wiki</a>
for instructions on how to upgrade to Python 3.</strong>
</p>
</div>
<div class="modal-footer">
<input type="button" class="btn btn-bright" data-dismiss="modal" value="Close">
</div>
</div>
</div>
</div>
% endif
<div class="modal fade" id="ip-info-modal" tabindex="-1" role="dialog" aria-labelledby="ip-info-modal">
@@ -1067,16 +1045,4 @@
});
</script>
% endif
% if _session['user_group'] == 'admin':
<script>
const queryString = window.location.search;
const urlParams = new URLSearchParams(queryString);
if (urlParams.get('update') === 'python2') {
$("#python2-modal").modal({
backdrop: 'static',
keyboard: false
});
}
</script>
% endif
</%def>


@@ -1,5 +1,5 @@
<%
from six.moves.urllib.parse import urlencode
from urllib.parse import urlencode
%>
<!doctype html>


@@ -1 +1 @@
__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore
__path__ = __import__("pkgutil").extend_path(__path__, __name__)


@@ -1,979 +0,0 @@
# -*- coding: utf-8 -*-
"""A port of Python 3's csv module to Python 2.
The API of the csv module in Python 2 is drastically different from
the csv module in Python 3. This is due, for the most part, to the
difference between str in Python 2 and Python 3.
The semantics of Python 3's version are more useful because they support
unicode natively, while Python 2's csv does not.
"""
from __future__ import unicode_literals, absolute_import
__all__ = [ "QUOTE_MINIMAL", "QUOTE_ALL", "QUOTE_NONNUMERIC", "QUOTE_NONE",
"Error", "Dialect", "__doc__", "excel", "excel_tab",
"field_size_limit", "reader", "writer",
"register_dialect", "get_dialect", "list_dialects", "Sniffer",
"unregister_dialect", "__version__", "DictReader", "DictWriter" ]
import re
import numbers
from io import StringIO
from csv import (
QUOTE_MINIMAL, QUOTE_ALL, QUOTE_NONNUMERIC, QUOTE_NONE,
__version__, __doc__, Error, field_size_limit,
)
# Stuff needed from six
import sys
PY3 = sys.version_info[0] == 3
if PY3:
string_types = str
text_type = str
binary_type = bytes
unichr = chr
else:
string_types = basestring
text_type = unicode
binary_type = str
class QuoteStrategy(object):
quoting = None
def __init__(self, dialect):
if self.quoting is not None:
assert dialect.quoting == self.quoting
self.dialect = dialect
self.setup()
escape_pattern_quoted = r'({quotechar})'.format(
quotechar=re.escape(self.dialect.quotechar or '"'))
escape_pattern_unquoted = r'([{specialchars}])'.format(
specialchars=re.escape(self.specialchars))
self.escape_re_quoted = re.compile(escape_pattern_quoted)
self.escape_re_unquoted = re.compile(escape_pattern_unquoted)
def setup(self):
"""Optional method for strategy-wide optimizations."""
def quoted(self, field=None, raw_field=None, only=None):
"""Determine whether this field should be quoted."""
raise NotImplementedError(
'quoted must be implemented by a subclass')
@property
def specialchars(self):
"""The special characters that need to be escaped."""
raise NotImplementedError(
'specialchars must be implemented by a subclass')
def escape_re(self, quoted=None):
if quoted:
return self.escape_re_quoted
return self.escape_re_unquoted
def escapechar(self, quoted=None):
if quoted and self.dialect.doublequote:
return self.dialect.quotechar
return self.dialect.escapechar
def prepare(self, raw_field, only=None):
field = text_type(raw_field if raw_field is not None else '')
quoted = self.quoted(field=field, raw_field=raw_field, only=only)
escape_re = self.escape_re(quoted=quoted)
escapechar = self.escapechar(quoted=quoted)
if escape_re.search(field):
escapechar = '\\\\' if escapechar == '\\' else escapechar
if not escapechar:
raise Error('No escapechar is set')
escape_replace = r'{escapechar}\1'.format(escapechar=escapechar)
field = escape_re.sub(escape_replace, field)
if quoted:
field = '{quotechar}{field}{quotechar}'.format(
quotechar=self.dialect.quotechar, field=field)
return field
class QuoteMinimalStrategy(QuoteStrategy):
quoting = QUOTE_MINIMAL
def setup(self):
self.quoted_re = re.compile(r'[{specialchars}]'.format(
specialchars=re.escape(self.specialchars)))
@property
def specialchars(self):
return (
self.dialect.lineterminator +
self.dialect.quotechar +
self.dialect.delimiter +
(self.dialect.escapechar or '')
)
def quoted(self, field, only, **kwargs):
if field == self.dialect.quotechar and not self.dialect.doublequote:
# If the only character in the field is the quotechar, and
# doublequote is false, then just escape without outer quotes.
return False
return field == '' and only or bool(self.quoted_re.search(field))
class QuoteAllStrategy(QuoteStrategy):
quoting = QUOTE_ALL
@property
def specialchars(self):
return self.dialect.quotechar
def quoted(self, **kwargs):
return True
class QuoteNonnumericStrategy(QuoteStrategy):
quoting = QUOTE_NONNUMERIC
@property
def specialchars(self):
return (
self.dialect.lineterminator +
self.dialect.quotechar +
self.dialect.delimiter +
(self.dialect.escapechar or '')
)
def quoted(self, raw_field, **kwargs):
return not isinstance(raw_field, numbers.Number)
class QuoteNoneStrategy(QuoteStrategy):
quoting = QUOTE_NONE
@property
def specialchars(self):
return (
self.dialect.lineterminator +
(self.dialect.quotechar or '') +
self.dialect.delimiter +
(self.dialect.escapechar or '')
)
def quoted(self, field, only, **kwargs):
if field == '' and only:
raise Error('single empty field record must be quoted')
return False
class writer(object):
def __init__(self, fileobj, dialect='excel', **fmtparams):
if fileobj is None:
raise TypeError('fileobj must be file-like, not None')
self.fileobj = fileobj
if isinstance(dialect, text_type):
dialect = get_dialect(dialect)
try:
self.dialect = Dialect.combine(dialect, fmtparams)
except Error as e:
raise TypeError(*e.args)
strategies = {
QUOTE_MINIMAL: QuoteMinimalStrategy,
QUOTE_ALL: QuoteAllStrategy,
QUOTE_NONNUMERIC: QuoteNonnumericStrategy,
QUOTE_NONE: QuoteNoneStrategy,
}
self.strategy = strategies[self.dialect.quoting](self.dialect)
def writerow(self, row):
if row is None:
raise Error('row must be an iterable')
row = list(row)
only = len(row) == 1
row = [self.strategy.prepare(field, only=only) for field in row]
line = self.dialect.delimiter.join(row) + self.dialect.lineterminator
return self.fileobj.write(line)
def writerows(self, rows):
for row in rows:
self.writerow(row)
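# A minimal usage sketch (illustrative, not from the original module)
# exercising the writer defined above with this module's own names:
# QUOTE_MINIMAL quotes a field only when it contains a special
# character, and doublequote doubles any embedded quotechar.
sink = StringIO()
writer(sink).writerow(['a', 'b,c', 'd"e'])
assert sink.getvalue() == 'a,"b,c","d""e"\r\n'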
START_RECORD = 0
START_FIELD = 1
ESCAPED_CHAR = 2
IN_FIELD = 3
IN_QUOTED_FIELD = 4
ESCAPE_IN_QUOTED_FIELD = 5
QUOTE_IN_QUOTED_FIELD = 6
EAT_CRNL = 7
AFTER_ESCAPED_CRNL = 8
class reader(object):
def __init__(self, fileobj, dialect='excel', **fmtparams):
self.input_iter = iter(fileobj)
if isinstance(dialect, text_type):
dialect = get_dialect(dialect)
try:
self.dialect = Dialect.combine(dialect, fmtparams)
except Error as e:
raise TypeError(*e.args)
self.fields = None
self.field = None
self.line_num = 0
def parse_reset(self):
self.fields = []
self.field = []
self.state = START_RECORD
self.numeric_field = False
def parse_save_field(self):
field = ''.join(self.field)
self.field = []
if self.numeric_field:
field = float(field)
self.numeric_field = False
self.fields.append(field)
def parse_add_char(self, c):
if len(self.field) >= field_size_limit():
raise Error('field size limit exceeded')
self.field.append(c)
def parse_process_char(self, c):
switch = {
START_RECORD: self._parse_start_record,
START_FIELD: self._parse_start_field,
ESCAPED_CHAR: self._parse_escaped_char,
AFTER_ESCAPED_CRNL: self._parse_after_escaped_crnl,
IN_FIELD: self._parse_in_field,
IN_QUOTED_FIELD: self._parse_in_quoted_field,
ESCAPE_IN_QUOTED_FIELD: self._parse_escape_in_quoted_field,
QUOTE_IN_QUOTED_FIELD: self._parse_quote_in_quoted_field,
EAT_CRNL: self._parse_eat_crnl,
}
return switch[self.state](c)
def _parse_start_record(self, c):
if c == '\0':
return
elif c == '\n' or c == '\r':
self.state = EAT_CRNL
return
self.state = START_FIELD
return self._parse_start_field(c)
def _parse_start_field(self, c):
if c == '\n' or c == '\r' or c == '\0':
self.parse_save_field()
self.state = START_RECORD if c == '\0' else EAT_CRNL
elif (c == self.dialect.quotechar and
self.dialect.quoting != QUOTE_NONE):
self.state = IN_QUOTED_FIELD
elif c == self.dialect.escapechar:
self.state = ESCAPED_CHAR
elif c == ' ' and self.dialect.skipinitialspace:
pass # Ignore space at start of field
elif c == self.dialect.delimiter:
# Save empty field
self.parse_save_field()
else:
# Begin new unquoted field
if self.dialect.quoting == QUOTE_NONNUMERIC:
self.numeric_field = True
self.parse_add_char(c)
self.state = IN_FIELD
def _parse_escaped_char(self, c):
if c == '\n' or c == '\r':
self.parse_add_char(c)
self.state = AFTER_ESCAPED_CRNL
return
if c == '\0':
c = '\n'
self.parse_add_char(c)
self.state = IN_FIELD
def _parse_after_escaped_crnl(self, c):
if c == '\0':
return
return self._parse_in_field(c)
def _parse_in_field(self, c):
# In unquoted field
if c == '\n' or c == '\r' or c == '\0':
# End of line - return [fields]
self.parse_save_field()
self.state = START_RECORD if c == '\0' else EAT_CRNL
elif c == self.dialect.escapechar:
self.state = ESCAPED_CHAR
elif c == self.dialect.delimiter:
self.parse_save_field()
self.state = START_FIELD
else:
# Normal character - save in field
self.parse_add_char(c)
def _parse_in_quoted_field(self, c):
if c == '\0':
pass
elif c == self.dialect.escapechar:
self.state = ESCAPE_IN_QUOTED_FIELD
elif (c == self.dialect.quotechar and
self.dialect.quoting != QUOTE_NONE):
if self.dialect.doublequote:
self.state = QUOTE_IN_QUOTED_FIELD
else:
self.state = IN_FIELD
else:
self.parse_add_char(c)
def _parse_escape_in_quoted_field(self, c):
if c == '\0':
c = '\n'
self.parse_add_char(c)
self.state = IN_QUOTED_FIELD
def _parse_quote_in_quoted_field(self, c):
if (self.dialect.quoting != QUOTE_NONE and
c == self.dialect.quotechar):
# save "" as "
self.parse_add_char(c)
self.state = IN_QUOTED_FIELD
elif c == self.dialect.delimiter:
self.parse_save_field()
self.state = START_FIELD
elif c == '\n' or c == '\r' or c == '\0':
# End of line = return [fields]
self.parse_save_field()
self.state = START_RECORD if c == '\0' else EAT_CRNL
elif not self.dialect.strict:
self.parse_add_char(c)
self.state = IN_FIELD
else:
# illegal
raise Error("{delimiter}' expected after '{quotechar}".format(
delimiter=self.dialect.delimiter,
quotechar=self.dialect.quotechar,
))
def _parse_eat_crnl(self, c):
if c == '\n' or c == '\r':
pass
elif c == '\0':
self.state = START_RECORD
else:
raise Error('new-line character seen in unquoted field - do you '
'need to open the file in universal-newline mode?')
def __iter__(self):
return self
def __next__(self):
self.parse_reset()
while True:
try:
lineobj = next(self.input_iter)
except StopIteration:
if len(self.field) != 0 or self.state == IN_QUOTED_FIELD:
if self.dialect.strict:
raise Error('unexpected end of data')
self.parse_save_field()
if self.fields:
break
raise
if not isinstance(lineobj, text_type):
typ = type(lineobj)
typ_name = 'bytes' if typ == bytes else typ.__name__
err_str = ('iterator should return strings, not {0}'
' (did you open the file in text mode?)')
raise Error(err_str.format(typ_name))
self.line_num += 1
for c in lineobj:
if c == '\0':
raise Error('line contains NULL byte')
self.parse_process_char(c)
self.parse_process_char('\0')
if self.state == START_RECORD:
break
fields = self.fields
self.fields = None
return fields
next = __next__
_dialect_registry = {}
def register_dialect(name, dialect='excel', **fmtparams):
if not isinstance(name, text_type):
raise TypeError('"name" must be a string')
dialect = Dialect.extend(dialect, fmtparams)
try:
Dialect.validate(dialect)
except:
raise TypeError('dialect is invalid')
assert name not in _dialect_registry
_dialect_registry[name] = dialect
def unregister_dialect(name):
try:
_dialect_registry.pop(name)
except KeyError:
raise Error('"{name}" not a registered dialect'.format(name=name))
def get_dialect(name):
try:
return _dialect_registry[name]
except KeyError:
raise Error('Could not find dialect {0}'.format(name))
def list_dialects():
return list(_dialect_registry)
class Dialect(object):
"""Describe a CSV dialect.
This must be subclassed (see csv.excel). Valid attributes are:
delimiter, quotechar, escapechar, doublequote, skipinitialspace,
lineterminator, quoting, strict.
"""
_name = ""
_valid = False
# placeholders
delimiter = None
quotechar = None
escapechar = None
doublequote = None
skipinitialspace = None
lineterminator = None
quoting = None
strict = None
def __init__(self):
self.validate(self)
if self.__class__ != Dialect:
self._valid = True
@classmethod
def validate(cls, dialect):
dialect = cls.extend(dialect)
if not isinstance(dialect.quoting, int):
raise Error('"quoting" must be an integer')
if dialect.delimiter is None:
raise Error('delimiter must be set')
cls.validate_text(dialect, 'delimiter')
if dialect.lineterminator is None:
raise Error('lineterminator must be set')
if not isinstance(dialect.lineterminator, text_type):
raise Error('"lineterminator" must be a string')
if dialect.quoting not in [
QUOTE_NONE, QUOTE_MINIMAL, QUOTE_NONNUMERIC, QUOTE_ALL]:
raise Error('Invalid quoting specified')
if dialect.quoting != QUOTE_NONE:
if dialect.quotechar is None and dialect.escapechar is None:
raise Error('quotechar must be set if quoting enabled')
if dialect.quotechar is not None:
cls.validate_text(dialect, 'quotechar')
@staticmethod
def validate_text(dialect, attr):
val = getattr(dialect, attr)
if not isinstance(val, text_type):
if type(val) == bytes:
raise Error('"{0}" must be string, not bytes'.format(attr))
raise Error('"{0}" must be string, not {1}'.format(
attr, type(val).__name__))
if len(val) != 1:
raise Error('"{0}" must be a 1-character string'.format(attr))
@staticmethod
def defaults():
return {
'delimiter': ',',
'doublequote': True,
'escapechar': None,
'lineterminator': '\r\n',
'quotechar': '"',
'quoting': QUOTE_MINIMAL,
'skipinitialspace': False,
'strict': False,
}
@classmethod
def extend(cls, dialect, fmtparams=None):
if isinstance(dialect, string_types):
dialect = get_dialect(dialect)
if fmtparams is None:
return dialect
defaults = cls.defaults()
if any(param not in defaults for param in fmtparams):
raise TypeError('Invalid fmtparam')
specified = dict(
(attr, getattr(dialect, attr, None))
for attr in cls.defaults()
)
specified.update(fmtparams)
return type(str('ExtendedDialect'), (cls,), specified)
@classmethod
def combine(cls, dialect, fmtparams):
"""Create a new dialect with defaults and added parameters."""
dialect = cls.extend(dialect, fmtparams)
defaults = cls.defaults()
specified = dict(
(attr, getattr(dialect, attr, None))
for attr in defaults
if getattr(dialect, attr, None) is not None or
attr in ['quotechar', 'delimiter', 'lineterminator', 'quoting']
)
defaults.update(specified)
dialect = type(str('CombinedDialect'), (cls,), defaults)
cls.validate(dialect)
return dialect()
def __delattr__(self, attr):
if self._valid:
raise AttributeError('dialect is immutable.')
super(Dialect, self).__delattr__(attr)
def __setattr__(self, attr, value):
if self._valid:
raise AttributeError('dialect is immutable.')
super(Dialect, self).__setattr__(attr, value)
class excel(Dialect):
"""Describe the usual properties of Excel-generated CSV files."""
delimiter = ','
quotechar = '"'
doublequote = True
skipinitialspace = False
lineterminator = '\r\n'
quoting = QUOTE_MINIMAL
register_dialect("excel", excel)
class excel_tab(excel):
"""Describe the usual properties of Excel-generated TAB-delimited files."""
delimiter = '\t'
register_dialect("excel-tab", excel_tab)
class unix_dialect(Dialect):
"""Describe the usual properties of Unix-generated CSV files."""
delimiter = ','
quotechar = '"'
doublequote = True
skipinitialspace = False
lineterminator = '\n'
quoting = QUOTE_ALL
register_dialect("unix", unix_dialect)
class DictReader(object):
def __init__(self, f, fieldnames=None, restkey=None, restval=None,
dialect="excel", *args, **kwds):
self._fieldnames = fieldnames # list of keys for the dict
self.restkey = restkey # key to catch long rows
self.restval = restval # default value for short rows
self.reader = reader(f, dialect, *args, **kwds)
self.dialect = dialect
self.line_num = 0
def __iter__(self):
return self
@property
def fieldnames(self):
if self._fieldnames is None:
try:
self._fieldnames = next(self.reader)
except StopIteration:
pass
self.line_num = self.reader.line_num
return self._fieldnames
@fieldnames.setter
def fieldnames(self, value):
self._fieldnames = value
def __next__(self):
if self.line_num == 0:
# Used only for its side effect.
self.fieldnames
row = next(self.reader)
self.line_num = self.reader.line_num
# unlike the basic reader, we prefer not to return blanks,
# because we will typically wind up with a dict full of None
# values
while row == []:
row = next(self.reader)
d = dict(zip(self.fieldnames, row))
lf = len(self.fieldnames)
lr = len(row)
if lf < lr:
d[self.restkey] = row[lf:]
elif lf > lr:
for key in self.fieldnames[lr:]:
d[key] = self.restval
return d
next = __next__
class DictWriter(object):
def __init__(self, f, fieldnames, restval="", extrasaction="raise",
dialect="excel", *args, **kwds):
self.fieldnames = fieldnames # list of keys for the dict
self.restval = restval # for writing short dicts
if extrasaction.lower() not in ("raise", "ignore"):
raise ValueError("extrasaction (%s) must be 'raise' or 'ignore'"
% extrasaction)
self.extrasaction = extrasaction
self.writer = writer(f, dialect, *args, **kwds)
def writeheader(self):
header = dict(zip(self.fieldnames, self.fieldnames))
self.writerow(header)
def _dict_to_list(self, rowdict):
if self.extrasaction == "raise":
wrong_fields = [k for k in rowdict if k not in self.fieldnames]
if wrong_fields:
raise ValueError("dict contains fields not in fieldnames: "
+ ", ".join([repr(x) for x in wrong_fields]))
return (rowdict.get(key, self.restval) for key in self.fieldnames)
def writerow(self, rowdict):
return self.writer.writerow(self._dict_to_list(rowdict))
def writerows(self, rowdicts):
return self.writer.writerows(map(self._dict_to_list, rowdicts))
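# A minimal round trip (illustrative, not from the original module):
# like Python 3's csv, this port returns every field as text, so the
# integer written below comes back as the string '1'.
sink = StringIO()
dict_writer = DictWriter(sink, fieldnames=['id', 'name'])
dict_writer.writeheader()
dict_writer.writerow({'id': 1, 'name': 'Ada'})
assert list(DictReader(StringIO(sink.getvalue()))) == [
    {'id': '1', 'name': 'Ada'},
]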
# Guard Sniffer's type checking against builds that exclude complex()
try:
complex
except NameError:
complex = float
class Sniffer(object):
'''
"Sniffs" the format of a CSV file (i.e. delimiter, quotechar)
Returns a Dialect object.
'''
def __init__(self):
# in case there is more than one possible delimiter
self.preferred = [',', '\t', ';', ' ', ':']
def sniff(self, sample, delimiters=None):
"""
Returns a dialect (or None) corresponding to the sample
"""
quotechar, doublequote, delimiter, skipinitialspace = \
self._guess_quote_and_delimiter(sample, delimiters)
if not delimiter:
delimiter, skipinitialspace = self._guess_delimiter(sample,
delimiters)
if not delimiter:
raise Error("Could not determine delimiter")
class dialect(Dialect):
_name = "sniffed"
lineterminator = '\r\n'
quoting = QUOTE_MINIMAL
# escapechar = ''
dialect.doublequote = doublequote
dialect.delimiter = delimiter
# _csv.reader won't accept a quotechar of ''
dialect.quotechar = quotechar or '"'
dialect.skipinitialspace = skipinitialspace
return dialect
def _guess_quote_and_delimiter(self, data, delimiters):
"""
Looks for text enclosed between two identical quotes
(the probable quotechar) which are preceded and followed
by the same character (the probable delimiter).
For example:
,'some text',
The quote with the most wins, same with the delimiter.
If there is no quotechar the delimiter can't be determined
this way.
"""
matches = []
for restr in ('(?P<delim>[^\w\n"\'])(?P<space> ?)(?P<quote>["\']).*?(?P=quote)(?P=delim)', # ,".*?",
'(?:^|\n)(?P<quote>["\']).*?(?P=quote)(?P<delim>[^\w\n"\'])(?P<space> ?)', # ".*?",
'(?P<delim>[^\w\n"\'])(?P<space> ?)(?P<quote>["\']).*?(?P=quote)(?:$|\n)', # ,".*?"
'(?:^|\n)(?P<quote>["\']).*?(?P=quote)(?:$|\n)'): # ".*?" (no delim, no space)
regexp = re.compile(restr, re.DOTALL | re.MULTILINE)
matches = regexp.findall(data)
if matches:
break
if not matches:
# (quotechar, doublequote, delimiter, skipinitialspace)
return ('', False, None, 0)
quotes = {}
delims = {}
spaces = 0
groupindex = regexp.groupindex
for m in matches:
n = groupindex['quote'] - 1
key = m[n]
if key:
quotes[key] = quotes.get(key, 0) + 1
try:
n = groupindex['delim'] - 1
key = m[n]
except KeyError:
continue
if key and (delimiters is None or key in delimiters):
delims[key] = delims.get(key, 0) + 1
try:
n = groupindex['space'] - 1
except KeyError:
continue
if m[n]:
spaces += 1
quotechar = max(quotes, key=quotes.get)
if delims:
delim = max(delims, key=delims.get)
skipinitialspace = delims[delim] == spaces
if delim == '\n': # most likely a file with a single column
delim = ''
else:
# there is *no* delimiter, it's a single column of quoted data
delim = ''
skipinitialspace = 0
# if we see an extra quote between delimiters, we've got a
# double quoted format
dq_regexp = re.compile(
r"((%(delim)s)|^)\W*%(quote)s[^%(delim)s\n]*%(quote)s[^%(delim)s\n]*%(quote)s\W*((%(delim)s)|$)" % \
{'delim':re.escape(delim), 'quote':quotechar}, re.MULTILINE)
if dq_regexp.search(data):
doublequote = True
else:
doublequote = False
return (quotechar, doublequote, delim, skipinitialspace)
def _guess_delimiter(self, data, delimiters):
"""
The delimiter /should/ occur the same number of times on
each row. However, due to malformed data, it may not. We don't want
an all or nothing approach, so we allow for small variations in this
number.
1) build a table of the frequency of each character on every line.
2) build a table of frequencies of this frequency (meta-frequency?),
e.g. 'x occurred 5 times in 10 rows, 6 times in 1000 rows,
7 times in 2 rows'
3) use the mode of the meta-frequency to determine the /expected/
frequency for that character
4) find out how often the character actually meets that goal
5) the character that best meets its goal is the delimiter
For performance reasons, the data is evaluated in chunks, so it can
try and evaluate the smallest portion of the data possible, evaluating
additional chunks as necessary.
"""
data = list(filter(None, data.split('\n')))
ascii = [unichr(c) for c in range(127)] # 7-bit ASCII
# build frequency tables
chunkLength = min(10, len(data))
iteration = 0
charFrequency = {}
modes = {}
delims = {}
start, end = 0, min(chunkLength, len(data))
while start < len(data):
iteration += 1
for line in data[start:end]:
for char in ascii:
metaFrequency = charFrequency.get(char, {})
# must count even if frequency is 0
freq = line.count(char)
# value is the mode
metaFrequency[freq] = metaFrequency.get(freq, 0) + 1
charFrequency[char] = metaFrequency
for char in charFrequency.keys():
items = list(charFrequency[char].items())
if len(items) == 1 and items[0][0] == 0:
continue
# get the mode of the frequencies
if len(items) > 1:
modes[char] = max(items, key=lambda x: x[1])
# adjust the mode - subtract the sum of all
# other frequencies
items.remove(modes[char])
modes[char] = (modes[char][0], modes[char][1]
- sum(item[1] for item in items))
else:
modes[char] = items[0]
# build a list of possible delimiters
modeList = modes.items()
total = float(chunkLength * iteration)
# (rows of consistent data) / (number of rows) = 100%
consistency = 1.0
# minimum consistency threshold
threshold = 0.9
while len(delims) == 0 and consistency >= threshold:
for k, v in modeList:
if v[0] > 0 and v[1] > 0:
if ((v[1]/total) >= consistency and
(delimiters is None or k in delimiters)):
delims[k] = v
consistency -= 0.01
if len(delims) == 1:
delim = list(delims.keys())[0]
skipinitialspace = (data[0].count(delim) ==
data[0].count("%c " % delim))
return (delim, skipinitialspace)
# analyze another chunkLength lines
start = end
end += chunkLength
if not delims:
return ('', 0)
# if there's more than one, fall back to a 'preferred' list
if len(delims) > 1:
for d in self.preferred:
if d in delims.keys():
skipinitialspace = (data[0].count(d) ==
data[0].count("%c " % d))
return (d, skipinitialspace)
# nothing else indicates a preference, pick the character that
# dominates(?)
items = [(v,k) for (k,v) in delims.items()]
items.sort()
delim = items[-1][1]
skipinitialspace = (data[0].count(delim) ==
data[0].count("%c " % delim))
return (delim, skipinitialspace)
def has_header(self, sample):
# Creates a dictionary of types of data in each column. If any
# column is of a single type (say, integers), *except* for the first
# row, then the first row is presumed to be labels. If the type
# can't be determined, it is assumed to be a string in which case
# the length of the string is the determining factor: if all of the
# rows except for the first are the same length, it's a header.
# Finally, a 'vote' is taken at the end for each column, adding or
# subtracting from the likelihood of the first row being a header.
rdr = reader(StringIO(sample), self.sniff(sample))
header = next(rdr) # assume first row is header
columns = len(header)
columnTypes = {}
for i in range(columns): columnTypes[i] = None
checked = 0
for row in rdr:
# arbitrary number of rows to check, to keep it sane
if checked > 20:
break
checked += 1
if len(row) != columns:
continue # skip rows that have irregular number of columns
for col in list(columnTypes.keys()):
for thisType in [int, float, complex]:
try:
thisType(row[col])
break
except (ValueError, OverflowError):
pass
else:
# fallback to length of string
thisType = len(row[col])
if thisType != columnTypes[col]:
if columnTypes[col] is None: # add new column type
columnTypes[col] = thisType
else:
# type is inconsistent, remove column from
# consideration
del columnTypes[col]
# finally, compare results against first row and "vote"
# on whether it's a header
hasHeader = 0
for col, colType in columnTypes.items():
if type(colType) == type(0): # it's a length
if len(header[col]) != colType:
hasHeader += 1
else:
hasHeader -= 1
else: # attempt typecast
try:
colType(header[col])
except (ValueError, TypeError):
hasHeader += 1
else:
hasHeader -= 1
return hasHeader > 0
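# A sniffing sketch (illustrative, not from the original module),
# assuming the Python 3 csv semantics this port mirrors: sniff()
# infers the dialect and has_header() votes on whether the first row
# is a label row.
sample = 'id;name\r\n1;Ada\r\n2;Grace\r\n'
dialect = Sniffer().sniff(sample, delimiters=';,')
assert dialect.delimiter == ';'
assert Sniffer().has_header(sample)
assert list(reader(StringIO(sample), dialect)) == [
    ['id', 'name'], ['1', 'Ada'], ['2', 'Grace'],
]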


@@ -1,243 +0,0 @@
from __future__ import absolute_import
import functools
from collections import namedtuple
from threading import RLock
_CacheInfo = namedtuple("_CacheInfo", ["hits", "misses", "maxsize", "currsize"])
@functools.wraps(functools.update_wrapper)
def update_wrapper(
wrapper,
wrapped,
assigned=functools.WRAPPER_ASSIGNMENTS,
updated=functools.WRAPPER_UPDATES,
):
"""
Patch two bugs in functools.update_wrapper.
"""
# workaround for http://bugs.python.org/issue3445
assigned = tuple(attr for attr in assigned if hasattr(wrapped, attr))
wrapper = functools.update_wrapper(wrapper, wrapped, assigned, updated)
# workaround for https://bugs.python.org/issue17482
wrapper.__wrapped__ = wrapped
return wrapper
class _HashedSeq(list):
"""This class guarantees that hash() will be called no more than once
per element. This is important because the lru_cache() will hash
the key multiple times on a cache miss.
"""
__slots__ = 'hashvalue'
def __init__(self, tup, hash=hash):
self[:] = tup
self.hashvalue = hash(tup)
def __hash__(self):
return self.hashvalue
def _make_key(
args,
kwds,
typed,
kwd_mark=(object(),),
fasttypes={int, str},
tuple=tuple,
type=type,
len=len,
):
"""Make a cache key from optionally typed positional and keyword arguments
The key is constructed in a way that is flat as possible rather than
as a nested structure that would take more memory.
If there is only a single argument and its data type is known to cache
its hash value, then that argument is returned without a wrapper. This
saves space and improves lookup speed.
"""
# All of code below relies on kwds preserving the order input by the user.
# Formerly, we sorted() the kwds before looping. The new way is *much*
# faster; however, it means that f(x=1, y=2) will now be treated as a
# distinct call from f(y=2, x=1) which will be cached separately.
key = args
if kwds:
key += kwd_mark
for item in kwds.items():
key += item
if typed:
key += tuple(type(v) for v in args)
if kwds:
key += tuple(type(v) for v in kwds.values())
elif len(key) == 1 and type(key[0]) in fasttypes:
return key[0]
return _HashedSeq(key)
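# A worked illustration (hypothetical values, not from the original
# module) of the flat key layout described in the docstring above:
mark = object()
key = _make_key((1, 2), {'z': 3}, typed=False, kwd_mark=(mark,))
assert list(key) == [1, 2, mark, 'z', 3]  # flat, not nested
assert _make_key((1,), {}, typed=False) == 1  # lone fast-type arg, unwrapped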
def lru_cache(maxsize=128, typed=False):
"""Least-recently-used cache decorator.
If *maxsize* is set to None, the LRU features are disabled and the cache
can grow without bound.
If *typed* is True, arguments of different types will be cached separately.
For example, f(decimal.Decimal("3.0")) and f(3.0) will be treated as
distinct calls with distinct results. Some types such as str and int may
be cached separately even when typed is false.
Arguments to the cached function must be hashable.
View the cache statistics named tuple (hits, misses, maxsize, currsize)
with f.cache_info(). Clear the cache and statistics with f.cache_clear().
Access the underlying function with f.__wrapped__.
See: https://en.wikipedia.org/wiki/Cache_replacement_policies#Least_recently_used_(LRU)
"""
# Users should only access the lru_cache through its public API:
# cache_info, cache_clear, and f.__wrapped__
# The internals of the lru_cache are encapsulated for thread safety and
# to allow the implementation to change (including a possible C version).
if isinstance(maxsize, int):
# Negative maxsize is treated as 0
if maxsize < 0:
maxsize = 0
elif callable(maxsize) and isinstance(typed, bool):
# The user_function was passed in directly via the maxsize argument
user_function, maxsize = maxsize, 128
wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
wrapper.cache_parameters = lambda: {'maxsize': maxsize, 'typed': typed}
return update_wrapper(wrapper, user_function)
elif maxsize is not None:
raise TypeError('Expected first argument to be an integer, a callable, or None')
def decorating_function(user_function):
wrapper = _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo)
wrapper.cache_parameters = lambda: {'maxsize': maxsize, 'typed': typed}
return update_wrapper(wrapper, user_function)
return decorating_function
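# A usage sketch (illustrative, not from the original module); the
# backport mirrors the standard functools.lru_cache API, so a memoized
# Fibonacci exercises the documented hits/misses accounting:
@lru_cache(maxsize=None)
def fib(n):
    return n if n < 2 else fib(n - 1) + fib(n - 2)

fib(30)
assert fib.cache_info() == _CacheInfo(hits=28, misses=31, maxsize=None, currsize=31)
fib.cache_clear()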
def _lru_cache_wrapper(user_function, maxsize, typed, _CacheInfo):
# Constants shared by all lru cache instances:
sentinel = object() # unique object used to signal cache misses
make_key = _make_key # build a key from the function arguments
PREV, NEXT, KEY, RESULT = 0, 1, 2, 3 # names for the link fields
cache = {}
hits = misses = 0
full = False
cache_get = cache.get # bound method to lookup a key or return None
cache_len = cache.__len__ # get cache size without calling len()
lock = RLock() # because linkedlist updates aren't threadsafe
root = [] # root of the circular doubly linked list
root[:] = [root, root, None, None] # initialize by pointing to self
if maxsize == 0:
def wrapper(*args, **kwds):
# No caching -- just a statistics update
nonlocal misses
misses += 1
result = user_function(*args, **kwds)
return result
elif maxsize is None:
def wrapper(*args, **kwds):
# Simple caching without ordering or size limit
nonlocal hits, misses
key = make_key(args, kwds, typed)
result = cache_get(key, sentinel)
if result is not sentinel:
hits += 1
return result
misses += 1
result = user_function(*args, **kwds)
cache[key] = result
return result
else:
def wrapper(*args, **kwds):
# Size limited caching that tracks accesses by recency
nonlocal root, hits, misses, full
key = make_key(args, kwds, typed)
with lock:
link = cache_get(key)
if link is not None:
# Move the link to the front of the circular queue
link_prev, link_next, _key, result = link
link_prev[NEXT] = link_next
link_next[PREV] = link_prev
last = root[PREV]
last[NEXT] = root[PREV] = link
link[PREV] = last
link[NEXT] = root
hits += 1
return result
misses += 1
result = user_function(*args, **kwds)
with lock:
if key in cache:
# Getting here means that this same key was added to the
# cache while the lock was released. Since the link
# update is already done, we need only return the
# computed result and update the count of misses.
pass
elif full:
# Use the old root to store the new key and result.
oldroot = root
oldroot[KEY] = key
oldroot[RESULT] = result
# Empty the oldest link and make it the new root.
# Keep a reference to the old key and old result to
# prevent their ref counts from going to zero during the
# update. That will prevent potentially arbitrary object
# clean-up code (i.e. __del__) from running while we're
# still adjusting the links.
root = oldroot[NEXT]
oldkey = root[KEY]
root[KEY] = root[RESULT] = None
# Now update the cache dictionary.
del cache[oldkey]
# Save the potentially reentrant cache[key] assignment
# for last, after the root and links have been put in
# a consistent state.
cache[key] = oldroot
else:
# Put result in a new link at the front of the queue.
last = root[PREV]
link = [last, root, key, result]
last[NEXT] = root[PREV] = cache[key] = link
# Use the cache_len bound method instead of the len() function
# which could potentially be wrapped in an lru_cache itself.
full = cache_len() >= maxsize
return result
def cache_info():
"""Report cache statistics"""
with lock:
return _CacheInfo(hits, misses, maxsize, cache_len())
def cache_clear():
"""Clear the cache and cache statistics"""
nonlocal hits, misses, full
with lock:
cache.clear()
root[:] = [root, root, None, None]
hits = misses = 0
full = False
wrapper.cache_info = cache_info
wrapper.cache_clear = cache_clear
return wrapper

File diff suppressed because it is too large.


@@ -0,0 +1,5 @@
from . import main
if __name__ == '__main__':
main()


@@ -0,0 +1,24 @@
import sys
if sys.version_info < (3, 9):
def removesuffix(self, suffix):
# suffix='' should not call self[:-0].
if suffix and self.endswith(suffix):
return self[: -len(suffix)]
else:
return self[:]
def removeprefix(self, prefix):
if self.startswith(prefix):
return self[len(prefix) :]
else:
return self[:]
else:
def removesuffix(self, suffix):
return self.removesuffix(suffix)
def removeprefix(self, prefix):
return self.removeprefix(prefix)
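# A behavior sketch (illustrative, not part of the diff): both branches
# match the str.removeprefix()/str.removesuffix() semantics added in
# Python 3.9, including the empty-suffix guard against self[:-0].
assert removeprefix('tautulli.log', 'tautulli') == '.log'
assert removesuffix('tautulli.log', '.log') == 'tautulli'
assert removesuffix('tautulli.log', '') == 'tautulli.log'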


@@ -292,7 +292,20 @@ class ConnectionManager:
if self.server.ssl_adapter is not None:
try:
s, ssl_env = self.server.ssl_adapter.wrap(s)
except errors.NoSSLError:
except errors.FatalSSLAlert as tls_connection_drop_error:
self.server.error_log(
f'Client {addr !s} lost — peer dropped the TLS '
'connection suddenly, during handshake: '
f'{tls_connection_drop_error !s}',
)
return
except errors.NoSSLError as http_over_https_err:
self.server.error_log(
f'Client {addr !s} attempted to speak plain HTTP into '
'a TCP connection configured for TLS-only traffic — '
'trying to send back a plain HTTP error response: '
f'{http_over_https_err !s}',
)
msg = (
'The client sent a plain HTTP request, but '
'this server only speaks HTTPS on this port.'
@@ -311,8 +324,6 @@ class ConnectionManager:
if ex.args[0] not in errors.socket_errors_to_ignore:
raise
return
if not s:
return
mf = self.server.ssl_adapter.makefile
# Re-apply our timeout since we may have a new socket object
if hasattr(s, 'settimeout'):


@@ -157,7 +157,7 @@ QUOTED_SLASH = b'%2F'
QUOTED_SLASH_REGEX = re.compile(b''.join((b'(?i)', QUOTED_SLASH)))
_STOPPING_FOR_INTERRUPT = object() # sentinel used during shutdown
_STOPPING_FOR_INTERRUPT = Exception() # sentinel used during shutdown
comma_separated_headers = [
@@ -209,7 +209,11 @@ class HeaderReader:
if not line.endswith(CRLF):
raise ValueError('HTTP requires CRLF terminators')
if line[0] in (SPACE, TAB):
if line[:1] in (SPACE, TAB):
# NOTE: `type(line[0]) is int` and `type(line[:1]) is bytes`.
# NOTE: The former causes the following warning:
# NOTE: `BytesWarning('Comparison between bytes and int')`
# NOTE: The latter is equivalent and does not.
# It's a continuation line.
v = line.strip()
else:
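# A standalone demonstration (illustrative, not part of the diff) of
# the indexing difference the NOTE above describes: on bytes, [0]
# yields an int while [:1] yields bytes, so only the slice form
# compares cleanly against byte-string constants.
line = b'\tfolded-header-continuation\r\n'
SPACE, TAB = b' ', b'\t'
assert isinstance(line[0], int) and isinstance(line[:1], bytes)
assert line[0] not in (SPACE, TAB)  # int vs bytes never match (BytesWarning under -b)
assert line[:1] in (SPACE, TAB)     # the equivalent, warning-free check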
@@ -1725,16 +1729,16 @@ class HTTPServer:
"""Run the server forever, and stop it cleanly on exit."""
try:
self.start()
except (KeyboardInterrupt, IOError):
# The time.sleep call might raise
# "IOError: [Errno 4] Interrupted function call" on KBInt.
self.error_log('Keyboard Interrupt: shutting down')
self.stop()
raise
except SystemExit:
self.error_log('SystemExit raised: shutting down')
self.stop()
raise
except KeyboardInterrupt as kb_intr_exc:
underlying_interrupt = self.interrupt
if not underlying_interrupt:
self.interrupt = kb_intr_exc
raise kb_intr_exc from underlying_interrupt
except SystemExit as sys_exit_exc:
underlying_interrupt = self.interrupt
if not underlying_interrupt:
self.interrupt = sys_exit_exc
raise sys_exit_exc from underlying_interrupt
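# A standalone illustration (not diff code) of the `raise ... from ...`
# chaining used above: the original interrupt survives as __cause__,
# which the worker layer later unwraps via `exc.__cause__ or exc`.
try:
    try:
        raise KeyboardInterrupt('first interrupt')
    except KeyboardInterrupt as kb_intr:
        raise SystemExit('shutting down') from kb_intr
except SystemExit as sys_exit:
    assert isinstance(sys_exit.__cause__, KeyboardInterrupt)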
def prepare(self): # noqa: C901 # FIXME
"""Prepare server to serving requests.
@@ -2111,6 +2115,13 @@ class HTTPServer:
has completed.
"""
self._interrupt = _STOPPING_FOR_INTERRUPT
if isinstance(interrupt, KeyboardInterrupt):
self.error_log('Keyboard Interrupt: shutting down')
if isinstance(interrupt, SystemExit):
self.error_log('SystemExit raised: shutting down')
self.stop()
self._interrupt = interrupt


@@ -27,12 +27,9 @@ except ImportError:
from . import Adapter
from .. import errors
from .._compat import IS_ABOVE_OPENSSL10
from ..makefile import StreamReader, StreamWriter
from ..server import HTTPServer
generic_socket_error = OSError
def _assert_ssl_exc_contains(exc, *msgs):
"""Check whether SSL exception contains either of messages provided."""
@@ -265,62 +262,35 @@ class BuiltinSSLAdapter(Adapter):
def wrap(self, sock):
"""Wrap and return the given socket, plus WSGI environ entries."""
EMPTY_RESULT = None, {}
try:
s = self.context.wrap_socket(
sock, do_handshake_on_connect=True, server_side=True,
)
except ssl.SSLError as ex:
if ex.errno == ssl.SSL_ERROR_EOF:
# This is almost certainly due to the cherrypy engine
# 'pinging' the socket to assert it's connectable;
# the 'ping' isn't SSL.
return EMPTY_RESULT
elif ex.errno == ssl.SSL_ERROR_SSL:
if _assert_ssl_exc_contains(ex, 'http request'):
# The client is speaking HTTP to an HTTPS server.
raise errors.NoSSLError
except (
ssl.SSLEOFError,
ssl.SSLZeroReturnError,
) as tls_connection_drop_error:
raise errors.FatalSSLAlert(
*tls_connection_drop_error.args,
) from tls_connection_drop_error
except ssl.SSLError as generic_tls_error:
peer_speaks_plain_http_over_https = (
generic_tls_error.errno == ssl.SSL_ERROR_SSL and
_assert_ssl_exc_contains(generic_tls_error, 'http request')
)
if peer_speaks_plain_http_over_https:
reraised_connection_drop_exc_cls = errors.NoSSLError
else:
reraised_connection_drop_exc_cls = errors.FatalSSLAlert
# Check if it's one of the known errors
# Errors that are caught by PyOpenSSL, but thrown by
# built-in ssl
_block_errors = (
'unknown protocol', 'unknown ca', 'unknown_ca',
'unknown error',
'https proxy request', 'inappropriate fallback',
'wrong version number',
'no shared cipher', 'certificate unknown',
'ccs received early',
'certificate verify failed', # client cert w/o trusted CA
'version too low', # caused by SSL3 connections
'unsupported protocol', # caused by TLS1 connections
)
if _assert_ssl_exc_contains(ex, *_block_errors):
# Accepted error, let's pass
return EMPTY_RESULT
elif _assert_ssl_exc_contains(ex, 'handshake operation timed out'):
# This error is thrown by builtin SSL after a timeout
# when client is speaking HTTP to an HTTPS server.
# The connection can safely be dropped.
return EMPTY_RESULT
raise
except generic_socket_error as exc:
"""It is unclear why exactly this happens.
It's reproducible only with openssl>1.0 and stdlib
:py:mod:`ssl` wrapper.
In CherryPy it's triggered by Checker plugin, which connects
to the app listening to the socket port in TLS mode via plain
HTTP during startup (from the same process).
Ref: https://github.com/cherrypy/cherrypy/issues/1618
"""
is_error0 = exc.args == (0, 'Error')
if is_error0 and IS_ABOVE_OPENSSL10:
return EMPTY_RESULT
raise
raise reraised_connection_drop_exc_cls(
*generic_tls_error.args,
) from generic_tls_error
except OSError as tcp_connection_drop_error:
raise errors.FatalSSLAlert(
*tcp_connection_drop_error.args,
) from tcp_connection_drop_error
return s, self.get_environ(s)
def get_environ(self, sock):


@@ -150,7 +150,7 @@ class SSLFileobjectMixin:
return self._safe_call(
False,
super(SSLFileobjectMixin, self).sendall,
*args, **kwargs
*args, **kwargs,
)
def send(self, *args, **kwargs):
@@ -158,7 +158,7 @@ class SSLFileobjectMixin:
return self._safe_call(
False,
super(SSLFileobjectMixin, self).send,
*args, **kwargs
*args, **kwargs,
)
@@ -196,6 +196,7 @@ class SSLConnectionProxyMeta:
def lock_decorator(method):
"""Create a proxy method for a new class."""
def proxy_wrapper(self, *args):
self._lock.acquire()
try:
@@ -212,6 +213,7 @@ class SSLConnectionProxyMeta:
def make_property(property_):
"""Create a proxy method for a new class."""
def proxy_prop_wrapper(self):
return getattr(self._ssl_conn, property_)
proxy_prop_wrapper.__name__ = property_


@@ -12,7 +12,10 @@ import pytest
from .._compat import IS_MACOS, IS_WINDOWS # noqa: WPS436
from ..server import Gateway, HTTPServer
from ..testing import ( # noqa: F401 # pylint: disable=unused-import
native_server, wsgi_server,
native_server,
thread_and_wsgi_server,
thread_and_native_server,
wsgi_server,
)
from ..testing import get_server_client
@@ -31,6 +34,28 @@ def http_request_timeout():
return computed_timeout
@pytest.fixture
# pylint: disable=redefined-outer-name
def wsgi_server_thread(thread_and_wsgi_server): # noqa: F811
"""Set up and tear down a Cheroot WSGI server instance.
This exposes the server thread.
"""
server_thread, _srv = thread_and_wsgi_server
return server_thread
@pytest.fixture
# pylint: disable=redefined-outer-name
def native_server_thread(thread_and_native_server): # noqa: F811
"""Set up and tear down a Cheroot HTTP server instance.
This exposes the server thread.
"""
server_thread, _srv = thread_and_native_server
return server_thread
@pytest.fixture
# pylint: disable=redefined-outer-name
def wsgi_server_client(wsgi_server): # noqa: F811


@@ -1,7 +1,9 @@
"""Tests for TCP connection handling, including proper and timely close."""
import errno
from re import match as _matches_pattern
import socket
import sys
import time
import logging
import traceback as traceback_
@@ -17,6 +19,7 @@ from cheroot._compat import IS_CI, IS_MACOS, IS_PYPY, IS_WINDOWS
import cheroot.server
IS_PY36 = sys.version_info[:2] == (3, 6)
IS_SLOW_ENV = IS_MACOS or IS_WINDOWS
@@ -53,7 +56,8 @@ class Controller(helper.Controller):
"'POST' != request.method %r" %
req.environ['REQUEST_METHOD'],
)
return "thanks for '%s'" % req.environ['wsgi.input'].read()
input_contents = req.environ['wsgi.input'].read().decode('utf-8')
return f"thanks for '{input_contents !s}'"
def custom_204(req, resp):
"""Render response with status 204."""
@@ -605,18 +609,18 @@ def test_keepalive_conn_management(test_client):
pytest.param(RuntimeError, 666, True, id='RuntimeError(666)'),
pytest.param(socket.error, -1, True, id='socket.error(-1)'),
) + (
pytest.param(
ConnectionResetError, errno.ECONNRESET, False,
id='ConnectionResetError(ECONNRESET)',
),
pytest.param(
BrokenPipeError, errno.EPIPE, False,
id='BrokenPipeError(EPIPE)',
),
pytest.param(
BrokenPipeError, errno.ESHUTDOWN, False,
id='BrokenPipeError(ESHUTDOWN)',
),
pytest.param(
ConnectionResetError, errno.ECONNRESET, False,
id='ConnectionResetError(ECONNRESET)',
),
pytest.param(
BrokenPipeError, errno.EPIPE, False,
id='BrokenPipeError(EPIPE)',
),
pytest.param(
BrokenPipeError, errno.ESHUTDOWN, False,
id='BrokenPipeError(ESHUTDOWN)',
),
),
)
def test_broken_connection_during_tcp_fin(
@@ -699,6 +703,275 @@ def test_broken_connection_during_tcp_fin(
assert _close_kernel_socket.exception_leaked is exception_leaks
def test_broken_connection_during_http_communication_fallback( # noqa: WPS118
monkeypatch,
test_client,
testing_server,
wsgi_server_thread,
):
"""Test that unhandled internal error cascades into shutdown."""
def _raise_connection_reset(*_args, **_kwargs):
raise ConnectionResetError(666)
def _read_request_line(self):
monkeypatch.setattr(self.conn.rfile, 'close', _raise_connection_reset)
monkeypatch.setattr(self.conn.wfile, 'write', _raise_connection_reset)
_raise_connection_reset()
monkeypatch.setattr(
test_client.server_instance.ConnectionClass.RequestHandlerClass,
'read_request_line',
_read_request_line,
)
test_client.get_connection().send(b'GET / HTTP/1.1')
wsgi_server_thread.join() # no extra logs upon server termination
actual_log_entries = testing_server.error_log.calls[:]
testing_server.error_log.calls.clear() # prevent post-test assertions
expected_log_entries = (
(logging.WARNING, r'^socket\.error 666$'),
(
logging.INFO,
'^Got a connection error while handling a connection '
r'from .*:\d{1,5} \(666\)',
),
(
logging.CRITICAL,
r'A fatal exception happened\. Setting the server interrupt flag '
r'to ConnectionResetError\(666,?\) and giving up\.\n\nPlease, '
'report this on the Cheroot tracker at '
r'<https://github\.com/cherrypy/cheroot/issues/new/choose>, '
'providing a full reproducer with as much context and details '
r'as possible\.$',
),
)
assert len(actual_log_entries) == len(expected_log_entries)
for ( # noqa: WPS352
(expected_log_level, expected_msg_regex),
(actual_msg, actual_log_level, _tb),
) in zip(expected_log_entries, actual_log_entries):
assert expected_log_level == actual_log_level
assert _matches_pattern(expected_msg_regex, actual_msg) is not None, (
f'{actual_msg !r} does not match {expected_msg_regex !r}'
)
def test_kb_int_from_http_handler(
test_client,
testing_server,
wsgi_server_thread,
):
"""Test that a keyboard interrupt from HTTP handler causes shutdown."""
def _trigger_kb_intr(_req, _resp):
raise KeyboardInterrupt('simulated test handler keyboard interrupt')
testing_server.wsgi_app.handlers['/kb_intr'] = _trigger_kb_intr
http_conn = test_client.get_connection()
http_conn.putrequest('GET', '/kb_intr', skip_host=True)
http_conn.putheader('Host', http_conn.host)
http_conn.endheaders()
wsgi_server_thread.join() # no extra logs upon server termination
actual_log_entries = testing_server.error_log.calls[:]
testing_server.error_log.calls.clear() # prevent post-test assertions
expected_log_entries = (
(
logging.DEBUG,
'^Got a server shutdown request while handling a connection '
r'from .*:\d{1,5} \(simulated test handler keyboard interrupt\)$',
),
(
logging.DEBUG,
'^Setting the server interrupt flag to KeyboardInterrupt'
r"\('simulated test handler keyboard interrupt',?\)$",
),
(
logging.INFO,
'^Keyboard Interrupt: shutting down$',
),
)
assert len(actual_log_entries) == len(expected_log_entries)
for ( # noqa: WPS352
(expected_log_level, expected_msg_regex),
(actual_msg, actual_log_level, _tb),
) in zip(expected_log_entries, actual_log_entries):
assert expected_log_level == actual_log_level
assert _matches_pattern(expected_msg_regex, actual_msg) is not None, (
f'{actual_msg !r} does not match {expected_msg_regex !r}'
)
@pytest.mark.xfail(
IS_CI and IS_PYPY and IS_PY36 and not IS_SLOW_ENV,
reason='Fails under PyPy 3.6 under Ubuntu 20.04 in CI for unknown reason',
# NOTE: Actually covers any Linux
strict=False,
)
def test_unhandled_exception_in_request_handler(
mocker,
monkeypatch,
test_client,
testing_server,
wsgi_server_thread,
):
"""Ensure worker threads are resilient to in-handler exceptions."""
class SillyMistake(BaseException): # noqa: WPS418, WPS431
"""A simulated crash within an HTTP handler."""
def _trigger_scary_exc(_req, _resp):
raise SillyMistake('simulated unhandled exception 💣 in test handler')
testing_server.wsgi_app.handlers['/scary_exc'] = _trigger_scary_exc
server_connection_close_spy = mocker.spy(
test_client.server_instance.ConnectionClass,
'close',
)
http_conn = test_client.get_connection()
http_conn.putrequest('GET', '/scary_exc', skip_host=True)
http_conn.putheader('Host', http_conn.host)
http_conn.endheaders()
# NOTE: This spy ensures the log entries get recorded before we test
# NOTE: them and before server shutdown, preserving their order and
# NOTE: making the log entry presence non-flaky.
while not server_connection_close_spy.called: # noqa: WPS328
pass
assert len(testing_server.requests._threads) == 10
while testing_server.requests.idle < 10: # noqa: WPS328
pass
assert len(testing_server.requests._threads) == 10
testing_server.interrupt = SystemExit('test requesting shutdown')
assert not testing_server.requests._threads
wsgi_server_thread.join() # no extra logs upon server termination
actual_log_entries = testing_server.error_log.calls[:]
testing_server.error_log.calls.clear() # prevent post-test assertions
expected_log_entries = (
(
logging.ERROR,
'^Unhandled error while processing an incoming connection '
'SillyMistake'
r"\('simulated unhandled exception 💣 in test handler',?\)$",
),
(
logging.INFO,
'^SystemExit raised: shutting down$',
),
)
assert len(actual_log_entries) == len(expected_log_entries)
for ( # noqa: WPS352
(expected_log_level, expected_msg_regex),
(actual_msg, actual_log_level, _tb),
) in zip(expected_log_entries, actual_log_entries):
assert expected_log_level == actual_log_level
assert _matches_pattern(expected_msg_regex, actual_msg) is not None, (
f'{actual_msg !r} does not match {expected_msg_regex !r}'
)
@pytest.mark.xfail(
IS_CI and IS_PYPY and IS_PY36 and not IS_SLOW_ENV,
reason='Fails under PyPy 3.6 under Ubuntu 20.04 in CI for unknown reason',
# NOTE: Actually covers any Linux
strict=False,
)
def test_remains_alive_post_unhandled_exception(
mocker,
monkeypatch,
test_client,
testing_server,
wsgi_server_thread,
):
"""Ensure worker threads are resilient to unhandled exceptions."""
class ScaryCrash(BaseException): # noqa: WPS418, WPS431
"""A simulated crash during HTTP parsing."""
_orig_read_request_line = (
test_client.server_instance.
ConnectionClass.RequestHandlerClass.
read_request_line
)
def _read_request_line(self):
_orig_read_request_line(self)
raise ScaryCrash(666)
monkeypatch.setattr(
test_client.server_instance.ConnectionClass.RequestHandlerClass,
'read_request_line',
_read_request_line,
)
server_connection_close_spy = mocker.spy(
test_client.server_instance.ConnectionClass,
'close',
)
# NOTE: The initial worker thread count is 10.
assert len(testing_server.requests._threads) == 10
test_client.get_connection().send(b'GET / HTTP/1.1')
# NOTE: This spy ensures the log entries get recorded before we test
# NOTE: them and before server shutdown, preserving their order and
# NOTE: making the log entry presence non-flaky.
while not server_connection_close_spy.called: # noqa: WPS328
pass
# NOTE: This checks for whether there's any crashed threads
while testing_server.requests.idle < 10: # noqa: WPS328
pass
assert len(testing_server.requests._threads) == 10
assert all(
worker_thread.is_alive()
for worker_thread in testing_server.requests._threads
)
testing_server.interrupt = SystemExit('test requesting shutdown')
assert not testing_server.requests._threads
wsgi_server_thread.join() # no extra logs upon server termination
actual_log_entries = testing_server.error_log.calls[:]
testing_server.error_log.calls.clear() # prevent post-test assertions
expected_log_entries = (
(
logging.ERROR,
'^Unhandled error while processing an incoming connection '
r'ScaryCrash\(666,?\)$',
),
(
logging.INFO,
'^SystemExit raised: shutting down$',
),
)
assert len(actual_log_entries) == len(expected_log_entries)
for ( # noqa: WPS352
(expected_log_level, expected_msg_regex),
(actual_msg, actual_log_level, _tb),
) in zip(expected_log_entries, actual_log_entries):
assert expected_log_level == actual_log_level
assert _matches_pattern(expected_msg_regex, actual_msg) is not None, (
f'{actual_msg !r} does not match {expected_msg_regex !r}'
)
@pytest.mark.parametrize(
'timeout_before_headers',
(
@@ -917,7 +1190,7 @@ def test_100_Continue(test_client):
status_line, _actual_headers, actual_resp_body = webtest.shb(response)
actual_status = int(status_line[:3])
assert actual_status == 200
expected_resp_body = ("thanks for '%s'" % body).encode()
expected_resp_body = f"thanks for '{body.decode() !s}'".encode()
assert actual_resp_body == expected_resp_body
conn.close()
@@ -987,7 +1260,7 @@ def test_readall_or_close(test_client, max_request_body_size):
status_line, actual_headers, actual_resp_body = webtest.shb(response)
actual_status = int(status_line[:3])
assert actual_status == 200
expected_resp_body = ("thanks for '%s'" % body).encode()
expected_resp_body = f"thanks for '{body.decode() !s}'".encode()
assert actual_resp_body == expected_resp_body
conn.close()


@@ -134,7 +134,7 @@ def test_query_string_request(test_client):
'/hello', # plain
'/query_string?test=True', # query
'/{0}?{1}={2}'.format( # quoted unicode
*map(urllib.parse.quote, ('Юххууу', 'ї', 'йо'))
*map(urllib.parse.quote, ('Юххууу', 'ї', 'йо')),
),
),
)


@@ -31,7 +31,7 @@ config = {
@contextmanager
def cheroot_server(server_factory):
def cheroot_server(server_factory): # noqa: WPS210
"""Set up and tear down a Cheroot server instance."""
conf = config[server_factory].copy()
bind_port = conf.pop('bind_addr')[-1]
@@ -41,7 +41,7 @@ def cheroot_server(server_factory):
actual_bind_addr = (interface, bind_port)
httpserver = server_factory( # create it
bind_addr=actual_bind_addr,
**conf
**conf,
)
except OSError:
pass
@@ -50,27 +50,52 @@ def cheroot_server(server_factory):
httpserver.shutdown_timeout = 0 # Speed-up tests teardown
threading.Thread(target=httpserver.safe_start).start() # spawn it
# FIXME: Expose this thread through a fixture so that it
# FIXME: could be awaited in tests.
server_thread = threading.Thread(target=httpserver.safe_start)
server_thread.start() # spawn it
while not httpserver.ready: # wait until fully initialized and bound
time.sleep(0.1)
yield httpserver
httpserver.stop() # destroy it
try:
yield server_thread, httpserver
finally:
httpserver.stop() # destroy it
server_thread.join() # wait for the thread to be torn down
@pytest.fixture
def wsgi_server():
def thread_and_wsgi_server():
"""Set up and tear down a Cheroot WSGI server instance.
This emits a tuple of a thread and a server instance.
"""
with cheroot_server(cheroot.wsgi.Server) as (server_thread, srv):
yield server_thread, srv
@pytest.fixture
def thread_and_native_server():
"""Set up and tear down a Cheroot HTTP server instance.
This emits a tuple of a thread and a server instance.
"""
with cheroot_server(cheroot.server.HTTPServer) as (server_thread, srv):
yield server_thread, srv
@pytest.fixture
def wsgi_server(thread_and_wsgi_server): # noqa: WPS442
"""Set up and tear down a Cheroot WSGI server instance."""
with cheroot_server(cheroot.wsgi.Server) as srv:
yield srv
_server_thread, srv = thread_and_wsgi_server
return srv
@pytest.fixture
def native_server():
def native_server(thread_and_native_server): # noqa: WPS442
"""Set up and tear down a Cheroot HTTP server instance."""
with cheroot_server(cheroot.server.HTTPServer) as srv:
yield srv
_server_thread, srv = thread_and_native_server
return srv
class _TestClient:


@@ -6,6 +6,7 @@
"""
import collections
import logging
import threading
import time
import socket
@@ -30,7 +31,7 @@ class TrueyZero:
trueyzero = TrueyZero()
_SHUTDOWNREQUEST = None
_SHUTDOWNREQUEST = object()
class WorkerThread(threading.Thread):
@@ -99,39 +100,127 @@ class WorkerThread(threading.Thread):
threading.Thread.__init__(self)
def run(self):
"""Process incoming HTTP connections.
"""Set up incoming HTTP connection processing loop.
Retrieves incoming connections from thread pool.
This is the thread's entry-point. It performs top-layer
exception handling and interrupt processing.
:exc:`KeyboardInterrupt` and :exc:`SystemExit` bubbling up
from the inner-layer code constitute a global server interrupt
request. When they happen, the worker thread exits.
:raises BaseException: when an unexpected non-interrupt
exception leaks from the inner layers
# noqa: DAR401 KeyboardInterrupt SystemExit
"""
self.server.stats['Worker Threads'][self.name] = self.stats
self.ready = True
try:
self.ready = True
while True:
conn = self.server.requests.get()
if conn is _SHUTDOWNREQUEST:
return
self._process_connections_until_interrupted()
except (KeyboardInterrupt, SystemExit) as interrupt_exc:
interrupt_cause = interrupt_exc.__cause__ or interrupt_exc
self.server.error_log(
f'Setting the server interrupt flag to {interrupt_cause !r}',
level=logging.DEBUG,
)
self.server.interrupt = interrupt_cause
except BaseException as underlying_exc: # noqa: WPS424
# NOTE: This is the last resort logging with the last dying breath
# NOTE: of the worker. It is only reachable when exceptions happen
# NOTE: in the `finally` branch of the internal try/except block.
self.server.error_log(
'A fatal exception happened. Setting the server interrupt flag'
f' to {underlying_exc !r} and giving up.'
'\N{NEW LINE}\N{NEW LINE}'
'Please, report this on the Cheroot tracker at '
'<https://github.com/cherrypy/cheroot/issues/new/choose>, '
'providing a full reproducer with as much context and detail as possible.',
level=logging.CRITICAL,
traceback=True,
)
self.server.interrupt = underlying_exc
raise
finally:
self.ready = False
self.conn = conn
is_stats_enabled = self.server.stats['Enabled']
def _process_connections_until_interrupted(self):
"""Process incoming HTTP connections in an infinite loop.
Retrieves incoming connections from the thread pool, processing
them one by one.
:raises SystemExit: on the internal requests to stop the
server instance
"""
while True:
conn = self.server.requests.get()
if conn is _SHUTDOWNREQUEST:
return
self.conn = conn
is_stats_enabled = self.server.stats['Enabled']
if is_stats_enabled:
self.start_time = time.time()
keep_conn_open = False
try:
keep_conn_open = conn.communicate()
except ConnectionError as connection_error:
keep_conn_open = False # Drop the connection cleanly
self.server.error_log(
'Got a connection error while handling a '
f'connection from {conn.remote_addr !s}:'
f'{conn.remote_port !s} ({connection_error !s})',
level=logging.INFO,
)
continue
except (KeyboardInterrupt, SystemExit) as shutdown_request:
# Shutdown request
keep_conn_open = False # Drop the connection cleanly
self.server.error_log(
'Got a server shutdown request while handling a '
f'connection from {conn.remote_addr !s}:'
f'{conn.remote_port !s} ({shutdown_request !s})',
level=logging.DEBUG,
)
raise SystemExit(
str(shutdown_request),
) from shutdown_request
except BaseException as unhandled_error: # noqa: WPS424
# NOTE: Only a shutdown request should bubble up to the
# NOTE: external cleanup code. Otherwise, this thread dies.
# NOTE: If this were to happen, the threadpool would still
# NOTE: list a dead thread without knowing its state. And
# NOTE: the calling code would fail to schedule processing
# NOTE: of new requests.
self.server.error_log(
'Unhandled error while processing an incoming '
f'connection {unhandled_error !r}',
level=logging.ERROR,
traceback=True,
)
continue # Prevent the thread from dying
finally:
# NOTE: Any exceptions coming from within `finally` may
# NOTE: kill the thread, causing the threadpool to only
# NOTE: contain references to dead threads rendering the
# NOTE: server defunct, effectively meaning a DoS.
# NOTE: Ideally, things called here should process
# NOTE: everything recoverable internally. Any unhandled
# NOTE: errors will bubble up into the outer try/except
# NOTE: block. They will be treated as fatal and turned
# NOTE: into server shutdown requests and then reraised
# NOTE: unconditionally.
if keep_conn_open:
self.server.put_conn(conn)
else:
conn.close()
if is_stats_enabled:
self.start_time = time.time()
keep_conn_open = False
try:
keep_conn_open = conn.communicate()
finally:
if keep_conn_open:
self.server.put_conn(conn)
else:
conn.close()
if is_stats_enabled:
self.requests_seen += self.conn.requests_seen
self.bytes_read += self.conn.rfile.bytes_read
self.bytes_written += self.conn.wfile.bytes_written
self.work_time += time.time() - self.start_time
self.start_time = None
self.conn = None
except (KeyboardInterrupt, SystemExit) as ex:
self.server.interrupt = ex
self.requests_seen += conn.requests_seen
self.bytes_read += conn.rfile.bytes_read
self.bytes_written += conn.wfile.bytes_written
self.work_time += time.time() - self.start_time
self.start_time = None
self.conn = None
class ThreadPool:
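Replacing the None sentinel with a fresh object() means a shutdown request can only be recognized by identity, so no legitimate queue payload (including None) can be mistaken for it. The pattern in isolation:

    # The identity-sentinel pattern adopted above, in isolation.
    import queue

    _SHUTDOWN = object()   # unique object; `is` checks cannot collide with data
    q = queue.Queue()
    q.put(None)            # an ordinary payload, no longer conflated with shutdown
    q.put(_SHUTDOWN)

    while True:
        item = q.get()
        if item is _SHUTDOWN:
            break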

View file

@ -52,7 +52,7 @@ Automatic conversion
--------------------
An included script called `futurize
<http://python-future.org/automatic_conversion.html>`_ aids in converting
<https://python-future.org/automatic_conversion.html>`_ aids in converting
code (from either Python 2 or Python 3) to code compatible with both
platforms. It is similar to ``python-modernize`` but goes further in
providing Python 3 compatibility through the use of the backported types
@ -62,21 +62,20 @@ and builtin functions in ``future``.
Documentation
-------------
See: http://python-future.org
See: https://python-future.org
Credits
-------
:Author: Ed Schofield, Jordan M. Adler, et al
:Sponsor: Python Charmers Pty Ltd, Australia, and Python Charmers Pte
Ltd, Singapore. http://pythoncharmers.com
:Others: See docs/credits.rst or http://python-future.org/credits.html
:Sponsor: Python Charmers: https://pythoncharmers.com
:Others: See docs/credits.rst or https://python-future.org/credits.html
Licensing
---------
Copyright 2013-2019 Python Charmers Pty Ltd, Australia.
Copyright 2013-2024 Python Charmers, Australia.
The software is distributed under an MIT licence. See LICENSE.txt.
"""
@ -84,10 +83,10 @@ The software is distributed under an MIT licence. See LICENSE.txt.
__title__ = 'future'
__author__ = 'Ed Schofield'
__license__ = 'MIT'
__copyright__ = 'Copyright 2013-2019 Python Charmers Pty Ltd'
__ver_major__ = 0
__ver_minor__ = 18
__ver_patch__ = 3
__copyright__ = 'Copyright 2013-2024 Python Charmers (https://pythoncharmers.com)'
__ver_major__ = 1
__ver_minor__ = 0
__ver_patch__ = 0
__ver_sub__ = ''
__version__ = "%d.%d.%d%s" % (__ver_major__, __ver_minor__,
__ver_patch__, __ver_sub__)

View file

@ -689,7 +689,7 @@ class date(object):
@classmethod
def fromordinal(cls, n):
"""Contruct a date from a proleptic Gregorian ordinal.
"""Construct a date from a proleptic Gregorian ordinal.
January 1 of year 1 is day 1. Only the year, month and day are
non-zero in the result.
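For reference, the corrected docstring matches the stdlib behaviour exactly: day 1 of the proleptic Gregorian calendar is 0001-01-01, and toordinal/fromordinal are inverses.

    from datetime import date

    assert date.fromordinal(1) == date(1, 1, 1)
    assert date.fromordinal(date(2024, 5, 9).toordinal()) == date(2024, 5, 9)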

View file

@ -2867,7 +2867,7 @@ def parse_content_type_header(value):
_find_mime_parameters(ctype, value)
return ctype
ctype.append(token)
# XXX: If we really want to follow the formal grammer we should make
# XXX: If we really want to follow the formal grammar we should make
# mantype and subtype specialized TokenLists here. Probably not worth it.
if not value or value[0] != '/':
ctype.defects.append(errors.InvalidHeaderDefect(

View file

@ -26,7 +26,7 @@ class Parser(object):
textual representation of the message.
The string must be formatted as a block of RFC 2822 headers and header
continuation lines, optionally preceeded by a `Unix-from' header. The
continuation lines, optionally preceded by a `Unix-from' header. The
header block is terminated either by the end of the string or by a
blank line.
@ -92,7 +92,7 @@ class BytesParser(object):
textual representation of the message.
The input must be formatted as a block of RFC 2822 headers and header
continuation lines, optionally preceeded by a `Unix-from' header. The
continuation lines, optionally preceded by a `Unix-from' header. The
header block is terminated either by the end of the input or by a
blank line.

View file

@ -1851,7 +1851,7 @@ def lwp_cookie_str(cookie):
class LWPCookieJar(FileCookieJar):
"""
The LWPCookieJar saves a sequence of "Set-Cookie3" lines.
"Set-Cookie3" is the format used by the libwww-perl libary, not known
"Set-Cookie3" is the format used by the libwww-perl library, not known
to be compatible with any browser, but which is easy to read and
doesn't lose information about RFC 2965 cookies.
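The Set-Cookie3 format named here is the plain-text serialization used by the stdlib class of the same name; a small sketch (the filename is illustrative):

    # Saving cookies in the Set-Cookie3 text format.
    from http.cookiejar import LWPCookieJar

    jar = LWPCookieJar('cookies.lwp')
    jar.save(ignore_discard=True)   # writes a '#LWP-Cookies-2.0' header line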

View file

@ -28,7 +28,6 @@ import importlib
# import collections.abc # not present on Py2.7
import re
import subprocess
import imp
import time
try:
import sysconfig
@ -341,37 +340,6 @@ def rmtree(path):
if error.errno != errno.ENOENT:
raise
def make_legacy_pyc(source):
"""Move a PEP 3147 pyc/pyo file to its legacy pyc/pyo location.
The choice of .pyc or .pyo extension is done based on the __debug__ flag
value.
:param source: The file system path to the source file. The source file
does not need to exist, however the PEP 3147 pyc file must exist.
:return: The file system path to the legacy pyc file.
"""
pyc_file = imp.cache_from_source(source)
up_one = os.path.dirname(os.path.abspath(source))
legacy_pyc = os.path.join(up_one, source + ('c' if __debug__ else 'o'))
os.rename(pyc_file, legacy_pyc)
return legacy_pyc
def forget(modname):
"""'Forget' a module was ever imported.
This removes the module from sys.modules and deletes any PEP 3147 or
legacy .pyc and .pyo files.
"""
unload(modname)
for dirname in sys.path:
source = os.path.join(dirname, modname + '.py')
# It doesn't matter if they exist or not, unlink all possible
# combinations of PEP 3147 and legacy pyc and pyo files.
unlink(source + 'c')
unlink(source + 'o')
unlink(imp.cache_from_source(source, debug_override=True))
unlink(imp.cache_from_source(source, debug_override=False))
# On some platforms, should not run gui test even if it is allowed
# in `use_resources'.
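The removed helpers were the last users of the deprecated imp module in this file; importlib.util provides the same cache-path mapping on modern Pythons:

    # importlib.util equivalents of the removed imp.cache_from_source calls.
    import importlib.util

    pyc = importlib.util.cache_from_source('pkg/mod.py')
    # e.g. 'pkg/__pycache__/mod.cpython-312.pyc' (tag varies per interpreter)
    assert importlib.util.source_from_cache(pyc).endswith('mod.py')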

View file

@ -134,10 +134,11 @@ from __future__ import (absolute_import, division, print_function,
from future.builtins import bytes, dict, int, range, str
import base64
# Py2.7 compatibility hack
base64.encodebytes = base64.encodestring
base64.decodebytes = base64.decodestring
import sys
if sys.version_info < (3, 9):
# Py2.7 compatibility hack
base64.encodebytes = base64.encodestring
base64.decodebytes = base64.decodestring
import time
from datetime import datetime
from future.backports.http import client as http_client
@ -1251,7 +1252,7 @@ class Transport(object):
# Send HTTP request.
#
# @param host Host descriptor (URL or (URL, x509 info) tuple).
# @param handler Targer RPC handler (a path relative to host)
# @param handler Target RPC handler (a path relative to host)
# @param request_body The XML-RPC request body
# @param debug Enable debugging if debug is true.
# @return An HTTPConnection.
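base64.encodestring/decodestring were removed outright in Python 3.9, which is what the new version gate guards against. For reference, the canonical spellings it falls back to:

    import base64

    assert base64.encodebytes(b'abc') == b'YWJj\n'
    assert base64.decodebytes(b'YWJj\n') == b'abc'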

View file

@ -2,7 +2,7 @@
A module that brings in equivalents of the new and modified Python 3
builtins into Py2. Has no effect on Py3.
See the docs `here <http://python-future.org/what-else.html>`_
See the docs `here <https://python-future.org/what-else.html>`_
(``docs/what-else.rst``) for more information.
"""

View file

@ -1,8 +1,13 @@
from __future__ import absolute_import
from future.utils import PY3
from future.utils import PY3, PY39_PLUS
if PY3:
from _dummy_thread import *
if PY39_PLUS:
# _dummy_thread and dummy_threading modules were both deprecated in
# Python 3.7 and removed in Python 3.9
from _thread import *
elif PY3:
from _dummy_thread import *
else:
__future_module__ = True
from dummy_thread import *
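PY39_PLUS is assumed here to be the obvious version predicate exported by future.utils alongside PY3; it is not shown in this diff:

    # Assumed shape of the flag consulted above.
    import sys

    PY39_PLUS = sys.version_info[:2] >= (3, 9)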

View file

@ -0,0 +1,7 @@
from __future__ import absolute_import
from future.utils import PY3
from multiprocessing import *
if not PY3:
__future_module__ = True
from multiprocessing.queues import SimpleQueue
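On Python 3 the name this shim re-exports is already importable from the package top level, so code written against the shim keeps working unchanged. A quick sketch:

    from multiprocessing import SimpleQueue

    q = SimpleQueue()
    q.put('hello')
    assert q.get() == 'hello'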

View file

@ -1,9 +1,18 @@
from __future__ import absolute_import
import sys
from future.standard_library import suspend_hooks
from future.utils import PY3
if PY3:
from test.support import *
if sys.version_info[:2] >= (3, 10):
from test.support.os_helper import (
EnvironmentVarGuard,
TESTFN,
)
from test.support.warnings_helper import check_warnings
else:
__future_module__ = True
with suspend_hooks():

View file

@ -17,7 +17,7 @@ And then these normal Py3 imports work on both Py3 and Py2::
import socketserver
import winreg # on Windows only
import test.support
import html, html.parser, html.entites
import html, html.parser, html.entities
import http, http.client, http.server
import http.cookies, http.cookiejar
import urllib.parse, urllib.request, urllib.response, urllib.error, urllib.robotparser
@ -33,6 +33,7 @@ And then these normal Py3 imports work on both Py3 and Py2::
from collections import OrderedDict, Counter, ChainMap # even on Py2.6
from subprocess import getoutput, getstatusoutput
from subprocess import check_output # even on Py2.6
from multiprocessing import SimpleQueue
(The renamed modules and functions are still available under their old
names on Python 2.)
@ -62,9 +63,12 @@ from __future__ import absolute_import, division, print_function
import sys
import logging
import imp
# imp was deprecated in python 3.6
if sys.version_info >= (3, 6):
import importlib as imp
else:
import imp
import contextlib
import types
import copy
import os
@ -108,6 +112,7 @@ RENAMES = {
'future.moves.socketserver': 'socketserver',
'ConfigParser': 'configparser',
'repr': 'reprlib',
'multiprocessing.queues': 'multiprocessing',
# 'FileDialog': 'tkinter.filedialog',
# 'tkFileDialog': 'tkinter.filedialog',
# 'SimpleDialog': 'tkinter.simpledialog',
@ -125,7 +130,7 @@ RENAMES = {
# 'Tkinter': 'tkinter',
'_winreg': 'winreg',
'thread': '_thread',
'dummy_thread': '_dummy_thread',
'dummy_thread': '_dummy_thread' if sys.version_info < (3, 9) else '_thread',
# 'anydbm': 'dbm', # causes infinite import loop
# 'whichdb': 'dbm', # causes infinite import loop
# anydbm and whichdb are handled by fix_imports2
@ -184,6 +189,7 @@ MOVES = [('collections', 'UserList', 'UserList', 'UserList'),
('itertools', 'filterfalse','itertools', 'ifilterfalse'),
('itertools', 'zip_longest','itertools', 'izip_longest'),
('sys', 'intern','__builtin__', 'intern'),
('multiprocessing', 'SimpleQueue', 'multiprocessing.queues', 'SimpleQueue'),
# The re module has no ASCII flag in Py2, but this is the default.
# Set re.ASCII to a zero constant. stat.ST_MODE just happens to be one
# (and it exists on Py2.6+).
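These tables feed future.standard_library's import hooks; from user code the documented entry point stays the same, and the new MOVES row is what makes the last import below resolve on Python 2. A sketch:

    # Exercising the new MOVES entry through the documented API.
    from future import standard_library
    standard_library.install_aliases()

    from multiprocessing import SimpleQueue  # routed via the table above on Py2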

View file

@ -223,9 +223,11 @@ class newint(with_metaclass(BaseNewInt, long)):
def __rpow__(self, other):
value = super(newint, self).__rpow__(other)
if value is NotImplemented:
if isint(value):
return newint(value)
elif value is NotImplemented:
return other ** long(self)
return newint(value)
return value
def __lshift__(self, other):
if not isint(other):
@ -318,7 +320,7 @@ class newint(with_metaclass(BaseNewInt, long)):
bits = length * 8
num = (2**bits) + self
if num <= 0:
raise OverflowError("int too smal to convert")
raise OverflowError("int too small to convert")
else:
if self < 0:
raise OverflowError("can't convert negative int to unsigned")
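The __rpow__ fix wraps an int result in newint and only falls back to other ** long(self) when the superclass genuinely returned NotImplemented. The reflected-power protocol it participates in, in miniature:

    # Sketch of Python's reflected-power dispatch.
    class Two:
        def __rpow__(self, base):
            if not isinstance(base, int):
                return NotImplemented   # let Python raise TypeError instead
            return base ** 2

    assert 3 ** Two() == 9   # int.__pow__ declines, Two.__rpow__ answers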

View file

@ -105,7 +105,7 @@ class newrange(Sequence):
raise ValueError('%r is not in range' % value)
def count(self, value):
"""Return the number of ocurrences of integer `value`
"""Return the number of occurrences of integer `value`
in the sequence this range represents."""
# a value can occur exactly zero or one times
return int(value in self)

View file

@ -3,6 +3,8 @@ inflect: english language inflection
- correctly generate plurals, ordinals, indefinite articles
- convert numbers to words
Copyright (C) 2010 Paul Dyson
Based upon the Perl module
`Lingua::EN::Inflect <https://metacpan.org/pod/Lingua::EN::Inflect>`_.
@ -50,34 +52,33 @@ Exceptions:
"""
from __future__ import annotations
import ast
import re
import functools
import collections
import contextlib
import functools
import itertools
import re
from numbers import Number
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Union,
Optional,
Iterable,
List,
Match,
Tuple,
Callable,
Optional,
Sequence,
Tuple,
Union,
cast,
Any,
)
from typing_extensions import Literal
from numbers import Number
from pydantic import Field
from typing_extensions import Annotated
from .compat.pydantic1 import validate_call
from .compat.pydantic import same_method
from more_itertools import windowed_complete
from typeguard import typechecked
from typing_extensions import Annotated, Literal
class UnknownClassicalModeError(Exception):
@ -258,9 +259,9 @@ si_sb_irregular_compound = {v: k for (k, v) in pl_sb_irregular_compound.items()}
for k in list(si_sb_irregular_compound):
if "|" in k:
k1, k2 = k.split("|")
si_sb_irregular_compound[k1] = si_sb_irregular_compound[
k2
] = si_sb_irregular_compound[k]
si_sb_irregular_compound[k1] = si_sb_irregular_compound[k2] = (
si_sb_irregular_compound[k]
)
del si_sb_irregular_compound[k]
# si_sb_irregular_keys = enclose('|'.join(si_sb_irregular.keys()))
@ -1597,7 +1598,7 @@ pl_prep_bysize = bysize(pl_prep_list_da)
pl_prep = enclose("|".join(pl_prep_list_da))
pl_sb_prep_dual_compound = fr"(.*?)((?:-|\s+)(?:{pl_prep})(?:-|\s+))a(?:-|\s+)(.*)"
pl_sb_prep_dual_compound = rf"(.*?)((?:-|\s+)(?:{pl_prep})(?:-|\s+))a(?:-|\s+)(.*)"
singular_pronoun_genders = {
@ -1764,7 +1765,7 @@ plverb_ambiguous_pres = {
}
plverb_ambiguous_pres_keys = re.compile(
fr"^({enclose('|'.join(plverb_ambiguous_pres))})((\s.*)?)$", re.IGNORECASE
rf"^({enclose('|'.join(plverb_ambiguous_pres))})((\s.*)?)$", re.IGNORECASE
)
@ -1804,7 +1805,7 @@ pl_count_one = ("1", "a", "an", "one", "each", "every", "this", "that")
pl_adj_special = {"a": "some", "an": "some", "this": "these", "that": "those"}
pl_adj_special_keys = re.compile(
fr"^({enclose('|'.join(pl_adj_special))})$", re.IGNORECASE
rf"^({enclose('|'.join(pl_adj_special))})$", re.IGNORECASE
)
pl_adj_poss = {
@ -1816,7 +1817,7 @@ pl_adj_poss = {
"their": "their",
}
pl_adj_poss_keys = re.compile(fr"^({enclose('|'.join(pl_adj_poss))})$", re.IGNORECASE)
pl_adj_poss_keys = re.compile(rf"^({enclose('|'.join(pl_adj_poss))})$", re.IGNORECASE)
# 2. INDEFINITE ARTICLES
@ -1883,7 +1884,7 @@ ordinal = dict(
twelve="twelfth",
)
ordinal_suff = re.compile(fr"({'|'.join(ordinal)})\Z")
ordinal_suff = re.compile(rf"({'|'.join(ordinal)})\Z")
# NUMBERS
@ -1948,13 +1949,13 @@ DOLLAR_DIGITS = re.compile(r"\$(\d+)")
FUNCTION_CALL = re.compile(r"((\w+)\([^)]*\)*)", re.IGNORECASE)
PARTITION_WORD = re.compile(r"\A(\s*)(.+?)(\s*)\Z")
PL_SB_POSTFIX_ADJ_STEMS_RE = re.compile(
fr"^(?:{pl_sb_postfix_adj_stems})$", re.IGNORECASE
rf"^(?:{pl_sb_postfix_adj_stems})$", re.IGNORECASE
)
PL_SB_PREP_DUAL_COMPOUND_RE = re.compile(
fr"^(?:{pl_sb_prep_dual_compound})$", re.IGNORECASE
rf"^(?:{pl_sb_prep_dual_compound})$", re.IGNORECASE
)
DENOMINATOR = re.compile(r"(?P<denominator>.+)( (per|a) .+)")
PLVERB_SPECIAL_S_RE = re.compile(fr"^({plverb_special_s})$")
PLVERB_SPECIAL_S_RE = re.compile(rf"^({plverb_special_s})$")
WHITESPACE = re.compile(r"\s")
ENDS_WITH_S = re.compile(r"^(.*[^s])s$", re.IGNORECASE)
ENDS_WITH_APOSTROPHE_S = re.compile(r"^(.*)'s?$")
@ -2020,10 +2021,25 @@ class Words(str):
self.last = self.split_[-1]
Word = Annotated[str, Field(min_length=1)]
Falsish = Any # ideally, falsish would only validate on bool(value) is False
_STATIC_TYPE_CHECKING = TYPE_CHECKING
# ^-- Workaround for typeguard AST manipulation:
# https://github.com/agronholm/typeguard/issues/353#issuecomment-1556306554
if _STATIC_TYPE_CHECKING: # pragma: no cover
Word = Annotated[str, "String with at least 1 character"]
else:
class _WordMeta(type): # Too dynamic to be supported by mypy...
def __instancecheck__(self, instance: Any) -> bool:
return isinstance(instance, str) and len(instance) >= 1
class Word(metaclass=_WordMeta): # type: ignore[no-redef]
"""String with at least 1 character"""
class engine:
def __init__(self) -> None:
self.classical_dict = def_classical.copy()
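The metaclass route keeps isinstance(value, Word) meaningful at runtime now that typeguard, rather than pydantic, enforces the annotation. The same trick in a self-contained sketch:

    # A metaclass-driven isinstance check, mirroring inflect's Word above.
    class _NonEmptyStrMeta(type):
        def __instancecheck__(cls, instance):
            return isinstance(instance, str) and len(instance) >= 1

    class NonEmptyStr(metaclass=_NonEmptyStrMeta):
        """String with at least 1 character."""

    assert isinstance('word', NonEmptyStr)
    assert not isinstance('', NonEmptyStr)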
@ -2045,7 +2061,7 @@ class engine:
def _number_args(self, val):
self.__number_args = val
@validate_call
@typechecked
def defnoun(self, singular: Optional[Word], plural: Optional[Word]) -> int:
"""
Set the noun plural of singular to plural.
@ -2057,7 +2073,7 @@ class engine:
self.si_sb_user_defined.extend((plural, singular))
return 1
@validate_call
@typechecked
def defverb(
self,
s1: Optional[Word],
@ -2082,7 +2098,7 @@ class engine:
self.pl_v_user_defined.extend((s1, p1, s2, p2, s3, p3))
return 1
@validate_call
@typechecked
def defadj(self, singular: Optional[Word], plural: Optional[Word]) -> int:
"""
Set the adjective plural of singular to plural.
@ -2093,7 +2109,7 @@ class engine:
self.pl_adj_user_defined.extend((singular, plural))
return 1
@validate_call
@typechecked
def defa(self, pattern: Optional[Word]) -> int:
"""
Define the indefinite article as 'a' for words matching pattern.
@ -2103,7 +2119,7 @@ class engine:
self.A_a_user_defined.extend((pattern, "a"))
return 1
@validate_call
@typechecked
def defan(self, pattern: Optional[Word]) -> int:
"""
Define the indefinite article as 'an' for words matching pattern.
@ -2121,8 +2137,8 @@ class engine:
return
try:
re.match(pattern, "")
except re.error:
raise BadUserDefinedPatternError(pattern)
except re.error as err:
raise BadUserDefinedPatternError(pattern) from err
def checkpatplural(self, pattern: Optional[Word]) -> None:
"""
@ -2130,10 +2146,10 @@ class engine:
"""
return
@validate_call
@typechecked
def ud_match(self, word: Word, wordlist: Sequence[Optional[Word]]) -> Optional[str]:
for i in range(len(wordlist) - 2, -2, -2): # backwards through even elements
mo = re.search(fr"^{wordlist[i]}$", word, re.IGNORECASE)
mo = re.search(rf"^{wordlist[i]}$", word, re.IGNORECASE)
if mo:
if wordlist[i + 1] is None:
return None
@ -2191,8 +2207,8 @@ class engine:
if count is not None:
try:
self.persistent_count = int(count)
except ValueError:
raise BadNumValueError
except ValueError as err:
raise BadNumValueError from err
if (show is None) or show:
return str(count)
else:
@ -2270,7 +2286,7 @@ class engine:
# 0. PERFORM GENERAL INFLECTIONS IN A STRING
@validate_call
@typechecked
def inflect(self, text: Word) -> str:
"""
Perform inflections in a string.
@ -2347,7 +2363,7 @@ class engine:
else:
return "", "", ""
@validate_call
@typechecked
def plural(self, text: Word, count: Optional[Union[str, int, Any]] = None) -> str:
"""
Return the plural of text.
@ -2371,7 +2387,7 @@ class engine:
)
return f"{pre}{plural}{post}"
@validate_call
@typechecked
def plural_noun(
self, text: Word, count: Optional[Union[str, int, Any]] = None
) -> str:
@ -2392,7 +2408,7 @@ class engine:
plural = self.postprocess(word, self._plnoun(word, count))
return f"{pre}{plural}{post}"
@validate_call
@typechecked
def plural_verb(
self, text: Word, count: Optional[Union[str, int, Any]] = None
) -> str:
@ -2416,7 +2432,7 @@ class engine:
)
return f"{pre}{plural}{post}"
@validate_call
@typechecked
def plural_adj(
self, text: Word, count: Optional[Union[str, int, Any]] = None
) -> str:
@ -2437,7 +2453,7 @@ class engine:
plural = self.postprocess(word, self._pl_special_adjective(word, count) or word)
return f"{pre}{plural}{post}"
@validate_call
@typechecked
def compare(self, word1: Word, word2: Word) -> Union[str, bool]:
"""
compare word1 and word2 for equality regardless of plurality
@ -2460,15 +2476,13 @@ class engine:
>>> compare('egg', '')
Traceback (most recent call last):
...
pydantic...ValidationError: ...
...
...at least 1 characters...
typeguard.TypeCheckError:...is not an instance of inflect.Word
"""
norms = self.plural_noun, self.plural_verb, self.plural_adj
results = (self._plequal(word1, word2, norm) for norm in norms)
return next(filter(None, results), False)
@validate_call
@typechecked
def compare_nouns(self, word1: Word, word2: Word) -> Union[str, bool]:
"""
compare word1 and word2 for equality regardless of plurality
@ -2484,7 +2498,7 @@ class engine:
"""
return self._plequal(word1, word2, self.plural_noun)
@validate_call
@typechecked
def compare_verbs(self, word1: Word, word2: Word) -> Union[str, bool]:
"""
compare word1 and word2 for equality regardless of plurality
@ -2500,7 +2514,7 @@ class engine:
"""
return self._plequal(word1, word2, self.plural_verb)
@validate_call
@typechecked
def compare_adjs(self, word1: Word, word2: Word) -> Union[str, bool]:
"""
compare word1 and word2 for equality regardless of plurality
@ -2516,7 +2530,7 @@ class engine:
"""
return self._plequal(word1, word2, self.plural_adj)
@validate_call
@typechecked
def singular_noun(
self,
text: Word,
@ -2574,18 +2588,18 @@ class engine:
return "s:p"
self.classical_dict = classval.copy()
if same_method(pl, self.plural) or same_method(pl, self.plural_noun):
if pl == self.plural or pl == self.plural_noun:
if self._pl_check_plurals_N(word1, word2):
return "p:p"
if self._pl_check_plurals_N(word2, word1):
return "p:p"
if same_method(pl, self.plural) or same_method(pl, self.plural_adj):
if pl == self.plural or pl == self.plural_adj:
if self._pl_check_plurals_adj(word1, word2):
return "p:p"
return False
def _pl_reg_plurals(self, pair: str, stems: str, end1: str, end2: str) -> bool:
pattern = fr"({stems})({end1}\|\1{end2}|{end2}\|\1{end1})"
pattern = rf"({stems})({end1}\|\1{end2}|{end2}\|\1{end1})"
return bool(re.search(pattern, pair))
def _pl_check_plurals_N(self, word1: str, word2: str) -> bool:
@ -2679,6 +2693,8 @@ class engine:
word = Words(word)
if word.last.lower() in pl_sb_uninflected_complete:
if len(word.split_) >= 3:
return self._handle_long_compounds(word, count=2) or word
return word
if word in pl_sb_uninflected_caps:
@ -2707,13 +2723,9 @@ class engine:
)
if len(word.split_) >= 3:
for numword in range(1, len(word.split_) - 1):
if word.split_[numword] in pl_prep_list_da:
return " ".join(
word.split_[: numword - 1]
+ [self._plnoun(word.split_[numword - 1], 2)]
+ word.split_[numword:]
)
handled_words = self._handle_long_compounds(word, count=2)
if handled_words is not None:
return handled_words
# only pluralize denominators in units
mo = DENOMINATOR.search(word.lowered)
@ -2972,6 +2984,30 @@ class engine:
parts[: pivot - 1] + [sep.join([transformed, parts[pivot], ''])]
) + " ".join(parts[(pivot + 1) :])
def _handle_long_compounds(self, word: Words, count: int) -> Union[str, None]:
"""
Handles the plural and singular for compound `Words` that
have three or more words, based on the given count.
>>> engine()._handle_long_compounds(Words("pair of scissors"), 2)
'pairs of scissors'
>>> engine()._handle_long_compounds(Words("men beyond hills"), 1)
'man beyond hills'
"""
inflection = self._sinoun if count == 1 else self._plnoun
solutions = ( # type: ignore
" ".join(
itertools.chain(
leader,
[inflection(cand, count), prep], # type: ignore
trailer,
)
)
for leader, (cand, prep), trailer in windowed_complete(word.split_, 2)
if prep in pl_prep_list_da # type: ignore
)
return next(solutions, None)
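windowed_complete(seq, n) yields (beginning, middle, end) triples for every length-n window, which is what lets the method above test each adjacent (candidate, preposition) pair with its surroundings intact:

    from more_itertools import windowed_complete

    for leader, middle, trailer in windowed_complete('pair of scissors'.split(), 2):
        print(leader, middle, trailer)
    # () ('pair', 'of') ('scissors',)
    # ('pair',) ('of', 'scissors') ()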
@staticmethod
def _find_pivot(words, candidates):
pivots = (
@ -2980,7 +3016,7 @@ class engine:
try:
return next(pivots)
except StopIteration:
raise ValueError("No pivot found")
raise ValueError("No pivot found") from None
def _pl_special_verb( # noqa: C901
self, word: str, count: Optional[Union[str, int]] = None
@ -3145,8 +3181,8 @@ class engine:
gender = self.thegender
elif gender not in singular_pronoun_genders:
raise BadGenderError
except (TypeError, IndexError):
raise BadGenderError
except (TypeError, IndexError) as err:
raise BadGenderError from err
# HANDLE USER-DEFINED NOUNS
@ -3165,6 +3201,8 @@ class engine:
words = Words(word)
if words.last.lower() in pl_sb_uninflected_complete:
if len(words.split_) >= 3:
return self._handle_long_compounds(words, count=1) or word
return word
if word in pl_sb_uninflected_caps:
@ -3450,7 +3488,7 @@ class engine:
# ADJECTIVES
@validate_call
@typechecked
def a(self, text: Word, count: Optional[Union[int, str, Any]] = 1) -> str:
"""
Return the appropriate indefinite article followed by text.
@ -3531,7 +3569,7 @@ class engine:
# 2. TRANSLATE ZERO-QUANTIFIED $word TO "no plural($word)"
@validate_call
@typechecked
def no(self, text: Word, count: Optional[Union[int, str]] = None) -> str:
"""
If count is 0, no, zero or nil, return 'no' followed by the plural
@ -3569,7 +3607,7 @@ class engine:
# PARTICIPLES
@validate_call
@typechecked
def present_participle(self, word: Word) -> str:
"""
Return the present participle for word.
@ -3588,7 +3626,7 @@ class engine:
# NUMERICAL INFLECTIONS
@validate_call(config=dict(arbitrary_types_allowed=True))
@typechecked
def ordinal(self, num: Union[Number, Word]) -> str:
"""
Return the ordinal of num.
@ -3619,16 +3657,7 @@ class engine:
post = nth[n % 10]
return f"{num}{post}"
else:
# Mad props to Damian Conway (?) whose ordinal()
# algorithm is type-bendy enough to foil MyPy
str_num: str = num # type: ignore[assignment]
mo = ordinal_suff.search(str_num)
if mo:
post = ordinal[mo.group(1)]
rval = ordinal_suff.sub(post, str_num)
else:
rval = f"{str_num}th"
return rval
return self._sub_ord(num)
def millfn(self, ind: int = 0) -> str:
if ind > len(mill) - 1:
@ -3747,7 +3776,36 @@ class engine:
num = ONE_DIGIT_WORD.sub(self.unitsub, num, 1)
return num
@validate_call(config=dict(arbitrary_types_allowed=True)) # noqa: C901
@staticmethod
def _sub_ord(val):
new = ordinal_suff.sub(lambda match: ordinal[match.group(1)], val)
return new + "th" * (new == val)
@classmethod
def _chunk_num(cls, num, decimal, group):
if decimal:
max_split = -1 if group != 0 else 1
chunks = num.split(".", max_split)
else:
chunks = [num]
return cls._remove_last_blank(chunks)
@staticmethod
def _remove_last_blank(chunks):
"""
Remove the last item from chunks if it's a blank string.
Return the resultant chunks and whether the last item was removed.
"""
removed = chunks[-1] == ""
result = chunks[:-1] if removed else chunks
return result, removed
@staticmethod
def _get_sign(num):
return {'+': 'plus', '-': 'minus'}.get(num.lstrip()[0], '')
@typechecked
def number_to_words( # noqa: C901
self,
num: Union[Number, Word],
@ -3794,13 +3852,8 @@ class engine:
if group < 0 or group > 3:
raise BadChunkingOptionError
nowhite = num.lstrip()
if nowhite[0] == "+":
sign = "plus"
elif nowhite[0] == "-":
sign = "minus"
else:
sign = ""
sign = self._get_sign(num)
if num in nth_suff:
num = zero
@ -3808,34 +3861,21 @@ class engine:
myord = num[-2:] in nth_suff
if myord:
num = num[:-2]
finalpoint = False
if decimal:
if group != 0:
chunks = num.split(".")
else:
chunks = num.split(".", 1)
if chunks[-1] == "": # remove blank string if nothing after decimal
chunks = chunks[:-1]
finalpoint = True # add 'point' to end of output
else:
chunks = [num]
first: Union[int, str, bool] = 1
loopstart = 0
chunks, finalpoint = self._chunk_num(num, decimal, group)
if chunks[0] == "":
first = 0
if len(chunks) > 1:
loopstart = 1
loopstart = chunks[0] == ""
first: bool | None = not loopstart
def _handle_chunk(chunk):
nonlocal first
for i in range(loopstart, len(chunks)):
chunk = chunks[i]
# remove all non numeric \D
chunk = NON_DIGIT.sub("", chunk)
if chunk == "":
chunk = "0"
if group == 0 and (first == 0 or first == ""):
if group == 0 and not first:
chunk = self.enword(chunk, 1)
else:
chunk = self.enword(chunk, group)
@ -3850,20 +3890,17 @@ class engine:
# chunk = re.sub(r"(\A\s|\s\Z)", self.blankfn, chunk)
chunk = chunk.strip()
if first:
first = ""
chunks[i] = chunk
first = None
return chunk
chunks[loopstart:] = map(_handle_chunk, chunks[loopstart:])
numchunks = []
if first != 0:
numchunks = chunks[0].split(f"{comma} ")
if myord and numchunks:
# TODO: can this be just one re as it is in perl?
mo = ordinal_suff.search(numchunks[-1])
if mo:
numchunks[-1] = ordinal_suff.sub(ordinal[mo.group(1)], numchunks[-1])
else:
numchunks[-1] += "th"
numchunks[-1] = self._sub_ord(numchunks[-1])
for chunk in chunks[1:]:
numchunks.append(decimal)
@ -3872,34 +3909,30 @@ class engine:
if finalpoint:
numchunks.append(decimal)
# wantlist: Perl list context. can explicitly specify in Python
if wantlist:
if sign:
numchunks = [sign] + numchunks
return numchunks
elif group:
signout = f"{sign} " if sign else ""
return f"{signout}{', '.join(numchunks)}"
else:
signout = f"{sign} " if sign else ""
num = f"{signout}{numchunks.pop(0)}"
if decimal is None:
first = True
else:
first = not num.endswith(decimal)
for nc in numchunks:
if nc == decimal:
num += f" {nc}"
first = 0
elif first:
num += f"{comma} {nc}"
else:
num += f" {nc}"
return num
return [sign] * bool(sign) + numchunks
# Join words with commas and a trailing 'and' (when appropriate)...
signout = f"{sign} " if sign else ""
valout = (
', '.join(numchunks)
if group
else ''.join(self._render(numchunks, decimal, comma))
)
return signout + valout
@validate_call
@staticmethod
def _render(chunks, decimal, comma):
first_item = chunks.pop(0)
yield first_item
first = decimal is None or not first_item.endswith(decimal)
for nc in chunks:
if nc == decimal:
first = False
elif first:
yield comma
yield f" {nc}"
@typechecked
def join(
self,
words: Optional[Sequence[Word]],

View file

@ -1,19 +0,0 @@
class ValidateCallWrapperWrapper:
def __init__(self, wrapped):
self.orig = wrapped
def __eq__(self, other):
return self.raw_function == other.raw_function
@property
def raw_function(self):
return getattr(self.orig, 'raw_function') or self.orig
def same_method(m1, m2) -> bool:
"""
Return whether m1 and m2 are the same method.
Workaround for pydantic/pydantic#6390.
"""
return ValidateCallWrapperWrapper(m1) == ValidateCallWrapperWrapper(m2)

View file

@ -1,8 +0,0 @@
try:
from pydantic import validate_call # type: ignore
except ImportError:
# Pydantic 1
from pydantic import validate_arguments as validate_call # type: ignore
__all__ = ['validate_call']

View file

@ -1,68 +0,0 @@
"""
Routines for obtaining the class names
of an object and its parent classes.
"""
from more_itertools import unique_everseen
def all_bases(c):
"""
return a tuple of all base classes the class c has as a parent.
>>> object in all_bases(list)
True
"""
return c.mro()[1:]
def all_classes(c):
"""
return a tuple of all classes to which c belongs
>>> list in all_classes(list)
True
"""
return c.mro()
# borrowed from
# http://code.activestate.com/recipes/576949-find-all-subclasses-of-a-given-class/
def iter_subclasses(cls):
"""
Generator over all subclasses of a given class, in depth-first order.
>>> bool in list(iter_subclasses(int))
True
>>> class A(object): pass
>>> class B(A): pass
>>> class C(A): pass
>>> class D(B,C): pass
>>> class E(D): pass
>>>
>>> for cls in iter_subclasses(A):
... print(cls.__name__)
B
D
E
C
>>> # get ALL classes currently defined
>>> res = [cls.__name__ for cls in iter_subclasses(object)]
>>> 'type' in res
True
>>> 'tuple' in res
True
>>> len(res) > 100
True
"""
return unique_everseen(_iter_all_subclasses(cls))
def _iter_all_subclasses(cls):
try:
subs = cls.__subclasses__()
except TypeError: # fails only when cls is type
subs = cls.__subclasses__(cls)
for sub in subs:
yield sub
yield from iter_subclasses(sub)

View file

@ -1,66 +0,0 @@
"""
meta.py
Some useful metaclasses.
"""
class LeafClassesMeta(type):
"""
A metaclass for classes that keeps track of all of them that
aren't base classes.
>>> Parent = LeafClassesMeta('MyParentClass', (), {})
>>> Parent in Parent._leaf_classes
True
>>> Child = LeafClassesMeta('MyChildClass', (Parent,), {})
>>> Child in Parent._leaf_classes
True
>>> Parent in Parent._leaf_classes
False
>>> Other = LeafClassesMeta('OtherClass', (), {})
>>> Parent in Other._leaf_classes
False
>>> len(Other._leaf_classes)
1
"""
def __init__(cls, name, bases, attrs):
if not hasattr(cls, '_leaf_classes'):
cls._leaf_classes = set()
leaf_classes = getattr(cls, '_leaf_classes')
leaf_classes.add(cls)
# remove any base classes
leaf_classes -= set(bases)
class TagRegistered(type):
"""
As classes of this metaclass are created, they keep a registry in the
base class of all classes by a class attribute, indicated by attr_name.
>>> FooObject = TagRegistered('FooObject', (), dict(tag='foo'))
>>> FooObject._registry['foo'] is FooObject
True
>>> BarObject = TagRegistered('Barobject', (FooObject,), dict(tag='bar'))
>>> FooObject._registry is BarObject._registry
True
>>> len(FooObject._registry)
2
'...' below should be 'jaraco.classes' but for pytest-dev/pytest#3396
>>> FooObject._registry['bar']
<class '....meta.Barobject'>
"""
attr_name = 'tag'
def __init__(cls, name, bases, namespace):
super(TagRegistered, cls).__init__(name, bases, namespace)
if not hasattr(cls, '_registry'):
cls._registry = {}
meta = cls.__class__
attr = getattr(cls, meta.attr_name, None)
if attr:
cls._registry[attr] = cls

View file

@ -1,170 +0,0 @@
class NonDataProperty:
"""Much like the property builtin, but only implements __get__,
making it a non-data property, and can be subsequently reset.
See http://users.rcn.com/python/download/Descriptor.htm for more
information.
>>> class X(object):
... @NonDataProperty
... def foo(self):
... return 3
>>> x = X()
>>> x.foo
3
>>> x.foo = 4
>>> x.foo
4
'...' below should be 'jaraco.classes' but for pytest-dev/pytest#3396
>>> X.foo
<....properties.NonDataProperty object at ...>
"""
def __init__(self, fget):
assert fget is not None, "fget cannot be none"
assert callable(fget), "fget must be callable"
self.fget = fget
def __get__(self, obj, objtype=None):
if obj is None:
return self
return self.fget(obj)
class classproperty:
"""
Like @property but applies at the class level.
>>> class X(metaclass=classproperty.Meta):
... val = None
... @classproperty
... def foo(cls):
... return cls.val
... @foo.setter
... def foo(cls, val):
... cls.val = val
>>> X.foo
>>> X.foo = 3
>>> X.foo
3
>>> x = X()
>>> x.foo
3
>>> X.foo = 4
>>> x.foo
4
Setting the property on an instance affects the class.
>>> x.foo = 5
>>> x.foo
5
>>> X.foo
5
>>> vars(x)
{}
>>> X().foo
5
Attempting to set an attribute where no setter was defined
results in an AttributeError:
>>> class GetOnly(metaclass=classproperty.Meta):
... @classproperty
... def foo(cls):
... return 'bar'
>>> GetOnly.foo = 3
Traceback (most recent call last):
...
AttributeError: can't set attribute
It is also possible to wrap a classmethod or staticmethod in
a classproperty.
>>> class Static(metaclass=classproperty.Meta):
... @classproperty
... @classmethod
... def foo(cls):
... return 'foo'
... @classproperty
... @staticmethod
... def bar():
... return 'bar'
>>> Static.foo
'foo'
>>> Static.bar
'bar'
*Legacy*
For compatibility, if the metaclass isn't specified, the
legacy behavior will be invoked.
>>> class X:
... val = None
... @classproperty
... def foo(cls):
... return cls.val
... @foo.setter
... def foo(cls, val):
... cls.val = val
>>> X.foo
>>> X.foo = 3
>>> X.foo
3
>>> x = X()
>>> x.foo
3
>>> X.foo = 4
>>> x.foo
4
Note, because the metaclass was not specified, setting
a value on an instance does not have the intended effect.
>>> x.foo = 5
>>> x.foo
5
>>> X.foo # should be 5
4
>>> vars(x) # should be empty
{'foo': 5}
>>> X().foo # should be 5
4
"""
class Meta(type):
def __setattr__(self, key, value):
obj = self.__dict__.get(key, None)
if type(obj) is classproperty:
return obj.__set__(self, value)
return super().__setattr__(key, value)
def __init__(self, fget, fset=None):
self.fget = self._ensure_method(fget)
self.fset = fset
fset and self.setter(fset)
def __get__(self, instance, owner=None):
return self.fget.__get__(None, owner)()
def __set__(self, owner, value):
if not self.fset:
raise AttributeError("can't set attribute")
if type(owner) is not classproperty.Meta:
owner = type(owner)
return self.fset.__get__(None, owner)(value)
def setter(self, fset):
self.fset = self._ensure_method(fset)
return self
@classmethod
def _ensure_method(cls, fn):
"""
Ensure fn is a classmethod or staticmethod.
"""
needs_method = not isinstance(fn, (classmethod, staticmethod))
return classmethod(fn) if needs_method else fn

View file

@ -1,16 +1,17 @@
import re
import operator
from __future__ import annotations
import collections.abc
import itertools
import copy
import functools
import itertools
import operator
import random
import re
from collections.abc import Container, Iterable, Mapping
from typing import Callable, Union
from typing import Any, Callable, Union
import jaraco.text
_Matchable = Union[Callable, Container, Iterable, re.Pattern]
@ -199,7 +200,12 @@ class RangeMap(dict):
"""
def __init__(self, source, sort_params={}, key_match_comparator=operator.le):
def __init__(
self,
source,
sort_params: Mapping[str, Any] = {},
key_match_comparator=operator.le,
):
dict.__init__(self, source)
self.sort_params = sort_params
self.match = key_match_comparator
@ -291,7 +297,7 @@ class KeyTransformingDict(dict):
return key
def __init__(self, *args, **kargs):
super(KeyTransformingDict, self).__init__()
super().__init__()
# build a dictionary using the default constructs
d = dict(*args, **kargs)
# build this dictionary using transformed keys.
@ -300,31 +306,31 @@ class KeyTransformingDict(dict):
def __setitem__(self, key, val):
key = self.transform_key(key)
super(KeyTransformingDict, self).__setitem__(key, val)
super().__setitem__(key, val)
def __getitem__(self, key):
key = self.transform_key(key)
return super(KeyTransformingDict, self).__getitem__(key)
return super().__getitem__(key)
def __contains__(self, key):
key = self.transform_key(key)
return super(KeyTransformingDict, self).__contains__(key)
return super().__contains__(key)
def __delitem__(self, key):
key = self.transform_key(key)
return super(KeyTransformingDict, self).__delitem__(key)
return super().__delitem__(key)
def get(self, key, *args, **kwargs):
key = self.transform_key(key)
return super(KeyTransformingDict, self).get(key, *args, **kwargs)
return super().get(key, *args, **kwargs)
def setdefault(self, key, *args, **kwargs):
key = self.transform_key(key)
return super(KeyTransformingDict, self).setdefault(key, *args, **kwargs)
return super().setdefault(key, *args, **kwargs)
def pop(self, key, *args, **kwargs):
key = self.transform_key(key)
return super(KeyTransformingDict, self).pop(key, *args, **kwargs)
return super().pop(key, *args, **kwargs)
def matching_key_for(self, key):
"""
@ -333,8 +339,8 @@ class KeyTransformingDict(dict):
"""
try:
return next(e_key for e_key in self.keys() if e_key == key)
except StopIteration:
raise KeyError(key)
except StopIteration as err:
raise KeyError(key) from err
class FoldedCaseKeyedDict(KeyTransformingDict):
@ -483,7 +489,7 @@ class ItemsAsAttributes:
def __getattr__(self, key):
try:
return getattr(super(ItemsAsAttributes, self), key)
return getattr(super(), key)
except AttributeError as e:
# attempt to get the value from the mapping (return self[key])
# but be careful not to lose the original exception context.
@ -677,7 +683,7 @@ class BijectiveMap(dict):
"""
def __init__(self, *args, **kwargs):
super(BijectiveMap, self).__init__()
super().__init__()
self.update(*args, **kwargs)
def __setitem__(self, item, value):
@ -691,19 +697,19 @@ class BijectiveMap(dict):
)
if overlap:
raise ValueError("Key/Value pairs may not overlap")
super(BijectiveMap, self).__setitem__(item, value)
super(BijectiveMap, self).__setitem__(value, item)
super().__setitem__(item, value)
super().__setitem__(value, item)
def __delitem__(self, item):
self.pop(item)
def __len__(self):
return super(BijectiveMap, self).__len__() // 2
return super().__len__() // 2
def pop(self, key, *args, **kwargs):
mirror = self[key]
super(BijectiveMap, self).__delitem__(mirror)
return super(BijectiveMap, self).pop(key, *args, **kwargs)
super().__delitem__(mirror)
return super().pop(key, *args, **kwargs)
def update(self, *args, **kwargs):
# build a dictionary using the default constructs
@ -769,7 +775,7 @@ class FrozenDict(collections.abc.Mapping, collections.abc.Hashable):
__slots__ = ['__data']
def __new__(cls, *args, **kwargs):
self = super(FrozenDict, cls).__new__(cls)
self = super().__new__(cls)
self.__data = dict(*args, **kwargs)
return self
@ -844,7 +850,7 @@ class Enumeration(ItemsAsAttributes, BijectiveMap):
names = names.split()
if codes is None:
codes = itertools.count()
super(Enumeration, self).__init__(zip(names, codes))
super().__init__(zip(names, codes))
@property
def names(self):

View file

@ -1,15 +1,26 @@
import os
import subprocess
from __future__ import annotations
import contextlib
import functools
import tempfile
import shutil
import operator
import os
import shutil
import subprocess
import sys
import tempfile
import urllib.request
import warnings
from typing import Iterator
if sys.version_info < (3, 12):
from backports import tarfile
else:
import tarfile
@contextlib.contextmanager
def pushd(dir):
def pushd(dir: str | os.PathLike) -> Iterator[str | os.PathLike]:
"""
>>> tmp_path = getfixture('tmp_path')
>>> with pushd(tmp_path):
@ -26,33 +37,88 @@ def pushd(dir):
@contextlib.contextmanager
def tarball_context(url, target_dir=None, runner=None, pushd=pushd):
def tarball(
url, target_dir: str | os.PathLike | None = None
) -> Iterator[str | os.PathLike]:
"""
Get a tarball, extract it, change to that directory, yield, then
clean up.
`runner` is the function to invoke commands.
`pushd` is a context manager for changing the directory.
Get a tarball, extract it, yield, then clean up.
>>> import urllib.request
>>> url = getfixture('tarfile_served')
>>> target = getfixture('tmp_path') / 'out'
>>> tb = tarball(url, target_dir=target)
>>> import pathlib
>>> with tb as extracted:
... contents = pathlib.Path(extracted, 'contents.txt').read_text(encoding='utf-8')
>>> assert not os.path.exists(extracted)
"""
if target_dir is None:
target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '')
if runner is None:
runner = functools.partial(subprocess.check_call, shell=True)
else:
warnings.warn("runner parameter is deprecated", DeprecationWarning)
# In the tar command, use --strip-components=1 to strip the first path and
# then
# use -C to cause the files to be extracted to {target_dir}. This ensures
# that we always know where the files were extracted.
runner('mkdir {target_dir}'.format(**vars()))
os.mkdir(target_dir)
try:
getter = 'wget {url} -O -'
extract = 'tar x{compression} --strip-components=1 -C {target_dir}'
cmd = ' | '.join((getter, extract))
runner(cmd.format(compression=infer_compression(url), **vars()))
with pushd(target_dir):
yield target_dir
req = urllib.request.urlopen(url)
with tarfile.open(fileobj=req, mode='r|*') as tf:
tf.extractall(path=target_dir, filter=strip_first_component)
yield target_dir
finally:
runner('rm -Rf {target_dir}'.format(**vars()))
shutil.rmtree(target_dir)
def strip_first_component(
member: tarfile.TarInfo,
path,
) -> tarfile.TarInfo:
_, member.name = member.name.split('/', 1)
return member
def _compose(*cmgrs):
"""
Compose any number of dependent context managers into a single one.
The last, innermost context manager may take arbitrary arguments, but
each successive context manager should accept the result from the
previous as a single parameter.
Like :func:`jaraco.functools.compose`, behavior works from right to
left, so the context manager should be indicated from outermost to
innermost.
Example, to create a context manager to change to a temporary
directory:
>>> temp_dir_as_cwd = _compose(pushd, temp_dir)
>>> with temp_dir_as_cwd() as dir:
... assert os.path.samefile(os.getcwd(), dir)
"""
def compose_two(inner, outer):
def composed(*args, **kwargs):
with inner(*args, **kwargs) as saved, outer(saved) as res:
yield res
return contextlib.contextmanager(composed)
return functools.reduce(compose_two, reversed(cmgrs))
tarball_cwd = _compose(pushd, tarball)
@contextlib.contextmanager
def tarball_context(*args, **kwargs):
warnings.warn(
"tarball_context is deprecated. Use tarball or tarball_cwd instead.",
DeprecationWarning,
stacklevel=2,
)
pushd_ctx = kwargs.pop('pushd', pushd)
with tarball(*args, **kwargs) as tball, pushd_ctx(tball) as dir:
yield dir
def infer_compression(url):
@ -68,6 +134,11 @@ def infer_compression(url):
>>> infer_compression('file.xz')
'J'
"""
warnings.warn(
"infer_compression is deprecated with no replacement",
DeprecationWarning,
stacklevel=2,
)
# cheat and just assume it's the last two characters
compression_indicator = url[-2:]
mapping = dict(gz='z', bz='j', xz='J')
@ -84,7 +155,7 @@ def temp_dir(remover=shutil.rmtree):
>>> import pathlib
>>> with temp_dir() as the_dir:
... assert os.path.isdir(the_dir)
... _ = pathlib.Path(the_dir).joinpath('somefile').write_text('contents')
... _ = pathlib.Path(the_dir).joinpath('somefile').write_text('contents', encoding='utf-8')
>>> assert not os.path.exists(the_dir)
"""
temp_dir = tempfile.mkdtemp()
@ -113,15 +184,23 @@ def repo_context(url, branch=None, quiet=True, dest_ctx=temp_dir):
yield repo_dir
@contextlib.contextmanager
def null():
"""
A null context suitable to stand in for a meaningful context.
>>> with null() as value:
... assert value is None
This context is most useful when dealing with two or more code
branches but only some need a context. Wrap the others in a null
context to provide symmetry across all options.
"""
yield
warnings.warn(
"null is deprecated. Use contextlib.nullcontext",
DeprecationWarning,
stacklevel=2,
)
return contextlib.nullcontext()
class ExceptionTrap:
@ -267,13 +346,7 @@ class on_interrupt(contextlib.ContextDecorator):
... on_interrupt('ignore')(do_interrupt)()
"""
def __init__(
self,
action='error',
# py3.7 compat
# /,
code=1,
):
def __init__(self, action='error', /, code=1):
self.action = action
self.code = code
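Dropping the Py3.7 compatibility comment lets on_interrupt use a real positional-only marker; everything left of the slash can no longer be passed by keyword. In isolation:

    # Positional-only parameters, as in the new on_interrupt signature.
    def configure(action='error', /, code=1):
        return action, code

    configure('ignore')           # ok
    configure('ignore', code=2)   # ok; code stays keyword-capable
    # configure(action='ignore')  --> TypeError: positional-only argument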

View file

@ -74,9 +74,6 @@ def result_invoke(
def invoke(
f: Callable[_P, _R], /, *args: _P.args, **kwargs: _P.kwargs
) -> Callable[_P, _R]: ...
def call_aside(
f: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs
) -> Callable[_P, _R]: ...
class Throttler(Generic[_R]):
last_called: float

View file

@ -9,11 +9,11 @@ python-modernize licence: BSD (from python-modernize/LICENSE)
"""
from lib2to3.fixer_util import (FromImport, Newline, is_import,
find_root, does_tree_import, Comma)
find_root, does_tree_import,
Call, Name, Comma)
from lib2to3.pytree import Leaf, Node
from lib2to3.pygram import python_symbols as syms, python_grammar
from lib2to3.pygram import python_symbols as syms
from lib2to3.pygram import token
from lib2to3.fixer_util import (Node, Call, Name, syms, Comma, Number)
import re
@ -116,7 +116,7 @@ def suitify(parent):
"""
for node in parent.children:
if node.type == syms.suite:
# already in the prefered format, do nothing
# already in the preferred format, do nothing
return
# One-liners have no suite node, we have to fake one up
@ -390,6 +390,7 @@ def touch_import_top(package, name_to_import, node):
break
insert_pos = idx
children_hooks = []
if package is None:
import_ = Node(syms.import_name, [
Leaf(token.NAME, u"import"),
@ -413,8 +414,6 @@ def touch_import_top(package, name_to_import, node):
]
)
children_hooks = [install_hooks, Newline()]
else:
children_hooks = []
# FromImport(package, [Leaf(token.NAME, name_to_import, prefix=u" ")])
@ -448,7 +447,6 @@ def check_future_import(node):
else:
node = node.children[3]
# now node is the import_as_name[s]
# print(python_grammar.number2symbol[node.type]) # breaks sometimes
if node.type == syms.import_as_names:
result = set()
for n in node.children:

View file

@ -37,7 +37,7 @@ from lib2to3.fixer_util import Name, syms, Node, Leaf, touch_import, Call, \
def has_metaclass(parent):
""" we have to check the cls_node without changing it.
There are two possiblities:
There are two possibilities:
1) clsdef => suite => simple_stmt => expr_stmt => Leaf('__meta')
2) clsdef => simple_stmt => expr_stmt => Leaf('__meta')
"""
@ -63,7 +63,7 @@ def fixup_parse_tree(cls_node):
# already in the preferred format, do nothing
return
# !%@#! oneliners have no suite node, we have to fake one up
# !%@#! one-liners have no suite node, we have to fake one up
for i, node in enumerate(cls_node.children):
if node.type == token.COLON:
break

View file

@ -16,6 +16,7 @@ MAPPING = {u"reprlib": u"repr",
u"winreg": u"_winreg",
u"configparser": u"ConfigParser",
u"copyreg": u"copy_reg",
u"multiprocessing.SimpleQueue": u"multiprocessing.queues.SimpleQueue",
u"queue": u"Queue",
u"socketserver": u"SocketServer",
u"_markupbase": u"markupbase",

View file

@ -18,8 +18,12 @@ def assignment_source(num_pre, num_post, LISTNAME, ITERNAME):
Returns a source fit for Assign() from fixer_util
"""
children = []
pre = unicode(num_pre)
post = unicode(num_post)
try:
pre = unicode(num_pre)
post = unicode(num_post)
except NameError:
pre = str(num_pre)
post = str(num_post)
# This code builds the assignment source from lib2to3 tree primitives.
# It's not very readable, but it seems like the most correct way to do it.
if num_pre > 0:

View file

@ -75,12 +75,12 @@ Credits
-------
:Author: Ed Schofield, Jordan M. Adler, et al
:Sponsor: Python Charmers Pty Ltd, Australia: http://pythoncharmers.com
:Sponsor: Python Charmers: https://pythoncharmers.com
Licensing
---------
Copyright 2013-2019 Python Charmers Pty Ltd, Australia.
Copyright 2013-2024 Python Charmers, Australia.
The software is distributed under an MIT licence. See LICENSE.txt.
"""

View file

@ -1,11 +1,13 @@
from __future__ import unicode_literals
import inspect
import sys
import math
import numbers
from future.utils import PY2, PY3, exec_
if PY2:
from collections import Mapping
else:
@ -103,13 +105,12 @@ if PY3:
return '0' + builtins.oct(number)[2:]
raw_input = input
try:
# imp was deprecated in python 3.6
if sys.version_info >= (3, 6):
from importlib import reload
except ImportError:
else:
# for python2, python3 <= 3.4
from imp import reload
unicode = str
unichr = chr
xrange = range
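For reference, importlib.reload has been the supported spelling since Python 3.4, so the gate only matters for very old interpreters:

    import importlib
    import json

    json = importlib.reload(json)   # re-executes the module and returns it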

View file

@ -32,17 +32,31 @@ Author: Ed Schofield.
Inspired by and based on ``uprefix`` by Vinay M. Sajip.
"""
import imp
import logging
import marshal
import os
import sys
# imp was deprecated in python 3.6
if sys.version_info >= (3, 6):
import importlib as imp
else:
import imp
import logging
import os
import copy
from lib2to3.pgen2.parse import ParseError
from lib2to3.refactor import RefactoringTool
from libfuturize import fixes
try:
from importlib.machinery import (
PathFinder,
SourceFileLoader,
)
except ImportError:
PathFinder = None
SourceFileLoader = object
if sys.version_info[:2] < (3, 4):
import imp
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@ -225,6 +239,81 @@ def detect_python2(source, pathname):
return False
def transform(source, pathname):
# This implementation uses lib2to3,
# you can override and use something else
# if that's better for you
# lib2to3 likes a newline at the end
RTs.setup()
source += '\n'
try:
tree = RTs._rt.refactor_string(source, pathname)
except ParseError as e:
if e.msg != 'bad input' or e.value != '=':
raise
tree = RTs._rtp.refactor_string(source, pathname)
# could optimise a bit for only doing str(tree) if
# getattr(tree, 'was_changed', False) returns True
return str(tree)[:-1] # remove added newline
class PastSourceFileLoader(SourceFileLoader):
exclude_paths = []
include_paths = []
def _convert_needed(self):
fullname = self.name
if any(fullname.startswith(path) for path in self.exclude_paths):
convert = False
elif any(fullname.startswith(path) for path in self.include_paths):
convert = True
else:
convert = False
return convert
def _exec_transformed_module(self, module):
source = self.get_source(self.name)
pathname = self.path
if detect_python2(source, pathname):
source = transform(source, pathname)
code = compile(source, pathname, "exec")
exec(code, module.__dict__)
# For Python 3.3
def load_module(self, fullname):
logger.debug("Running load_module for %s", fullname)
if fullname in sys.modules:
mod = sys.modules[fullname]
else:
if self._convert_needed():
logger.debug("Autoconverting %s", fullname)
mod = imp.new_module(fullname)
sys.modules[fullname] = mod
# required by PEP 302
mod.__file__ = self.path
mod.__loader__ = self
if self.is_package(fullname):
mod.__path__ = []
mod.__package__ = fullname
else:
mod.__package__ = fullname.rpartition('.')[0]
self._exec_transformed_module(mod)
else:
mod = super().load_module(fullname)
return mod
# For Python >=3.4
def exec_module(self, module):
logger.debug("Running exec_module for %s", module)
if self._convert_needed():
logger.debug("Autoconverting %s", self.name)
self._exec_transformed_module(module)
else:
super().exec_module(module)
class Py2Fixer(object):
"""
An import hook class that uses lib2to3 for source-to-source translation of
@ -258,151 +347,30 @@ class Py2Fixer(object):
"""
self.exclude_paths += paths
# For Python 3.3
def find_module(self, fullname, path=None):
logger.debug('Running find_module: {0}...'.format(fullname))
if '.' in fullname:
parent, child = fullname.rsplit('.', 1)
if path is None:
loader = self.find_module(parent, path)
mod = loader.load_module(parent)
path = mod.__path__
fullname = child
# Perhaps we should try using the new importlib functionality in Python
# 3.3: something like this?
# thing = importlib.machinery.PathFinder.find_module(fullname, path)
try:
self.found = imp.find_module(fullname, path)
except Exception as e:
logger.debug('Py2Fixer could not find {0}')
logger.debug('Exception was: {0})'.format(fullname, e))
logger.debug("Running find_module: (%s, %s)", fullname, path)
loader = PathFinder.find_module(fullname, path)
if not loader:
logger.debug("Py2Fixer could not find %s", fullname)
return None
self.kind = self.found[-1][-1]
if self.kind == imp.PKG_DIRECTORY:
self.pathname = os.path.join(self.found[1], '__init__.py')
elif self.kind == imp.PY_SOURCE:
self.pathname = self.found[1]
return self
loader.__class__ = PastSourceFileLoader
loader.exclude_paths = self.exclude_paths
loader.include_paths = self.include_paths
return loader
def transform(self, source):
# This implementation uses lib2to3,
# you can override and use something else
# if that's better for you
# For Python >=3.4
def find_spec(self, fullname, path=None, target=None):
logger.debug("Running find_spec: (%s, %s, %s)", fullname, path, target)
spec = PathFinder.find_spec(fullname, path, target)
if not spec:
logger.debug("Py2Fixer could not find %s", fullname)
return None
spec.loader.__class__ = PastSourceFileLoader
spec.loader.exclude_paths = self.exclude_paths
spec.loader.include_paths = self.include_paths
return spec
# lib2to3 likes a newline at the end
RTs.setup()
source += '\n'
try:
tree = RTs._rt.refactor_string(source, self.pathname)
except ParseError as e:
if e.msg != 'bad input' or e.value != '=':
raise
tree = RTs._rtp.refactor_string(source, self.pathname)
# could optimise a bit by only doing str(tree) if
# getattr(tree, 'was_changed', False) returns True
return str(tree)[:-1] # remove added newline
def load_module(self, fullname):
logger.debug('Running load_module for {0}...'.format(fullname))
if fullname in sys.modules:
mod = sys.modules[fullname]
else:
if self.kind in (imp.PY_COMPILED, imp.C_EXTENSION, imp.C_BUILTIN,
imp.PY_FROZEN):
convert = False
# elif (self.pathname.startswith(_stdlibprefix)
# and 'site-packages' not in self.pathname):
# # We assume it's a stdlib package in this case. Is this too brittle?
# # Please file a bug report at https://github.com/PythonCharmers/python-future
# # if so.
# convert = False
# in theory, other paths could be configured to be excluded here too
elif any([fullname.startswith(path) for path in self.exclude_paths]):
convert = False
elif any([fullname.startswith(path) for path in self.include_paths]):
convert = True
else:
convert = False
if not convert:
logger.debug('Excluded {0} from translation'.format(fullname))
mod = imp.load_module(fullname, *self.found)
else:
logger.debug('Autoconverting {0} ...'.format(fullname))
mod = imp.new_module(fullname)
sys.modules[fullname] = mod
# required by PEP 302
mod.__file__ = self.pathname
mod.__name__ = fullname
mod.__loader__ = self
# This:
# mod.__package__ = '.'.join(fullname.split('.')[:-1])
# seems to result in "SystemError: Parent module '' not loaded,
# cannot perform relative import" for a package's __init__.py
# file. We use the approach below. Another option to try is the
# minimal load_module pattern from the PEP 302 text instead.
# Is the test in the next line more or less robust than the
# following one? Presumably less ...
# ispkg = self.pathname.endswith('__init__.py')
if self.kind == imp.PKG_DIRECTORY:
mod.__path__ = [ os.path.dirname(self.pathname) ]
mod.__package__ = fullname
else:
#else, regular module
mod.__path__ = []
mod.__package__ = fullname.rpartition('.')[0]
try:
cachename = imp.cache_from_source(self.pathname)
if not os.path.exists(cachename):
update_cache = True
else:
sourcetime = os.stat(self.pathname).st_mtime
cachetime = os.stat(cachename).st_mtime
update_cache = cachetime < sourcetime
# # Force update_cache to work around a problem with it being treated as Py3 code???
# update_cache = True
if not update_cache:
with open(cachename, 'rb') as f:
data = f.read()
try:
code = marshal.loads(data)
except Exception:
# pyc could be corrupt. Regenerate it
update_cache = True
if update_cache:
if self.found[0]:
source = self.found[0].read()
elif self.kind == imp.PKG_DIRECTORY:
with open(self.pathname) as f:
source = f.read()
if detect_python2(source, self.pathname):
source = self.transform(source)
code = compile(source, self.pathname, 'exec')
dirname = os.path.dirname(cachename)
try:
if not os.path.exists(dirname):
os.makedirs(dirname)
with open(cachename, 'wb') as f:
data = marshal.dumps(code)
f.write(data)
except Exception: # could be write-protected
pass
exec(code, mod.__dict__)
except Exception as e:
# must remove module from sys.modules
del sys.modules[fullname]
raise # keep it simple
if self.found[0]:
self.found[0].close()
return mod
_hook = Py2Fixer()
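For context, a minimal usage sketch of the hook defined above. Setting the include/exclude attributes directly is an assumption based on the attributes the code copies onto the loader; the package names are hypothetical.

import sys

# Hedged sketch: mark which top-level packages get auto-translated,
# then register the hook ahead of the default path finders.
_hook.include_paths = ['legacy_pkg']            # hypothetical Python-2-only package
_hook.exclude_paths = ['legacy_pkg.vendored']   # leave this subtree untouched
sys.meta_path.insert(0, _hook)

import legacy_pkg  # its source is run through lib2to3 before being executed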

View file

@ -1,6 +1,8 @@
"""
Utilities for determining application-specific dirs. See <https://github.com/platformdirs/platformdirs> for details and
usage.
Utilities for determining application-specific dirs.
See <https://github.com/platformdirs/platformdirs> for details and usage.
"""
from __future__ import annotations
@ -20,22 +22,22 @@ if TYPE_CHECKING:
def _set_platform_dir_class() -> type[PlatformDirsABC]:
if sys.platform == "win32":
from platformdirs.windows import Windows as Result
from platformdirs.windows import Windows as Result # noqa: PLC0415
elif sys.platform == "darwin":
from platformdirs.macos import MacOS as Result
from platformdirs.macos import MacOS as Result # noqa: PLC0415
else:
from platformdirs.unix import Unix as Result
from platformdirs.unix import Unix as Result # noqa: PLC0415
if os.getenv("ANDROID_DATA") == "/data" and os.getenv("ANDROID_ROOT") == "/system":
if os.getenv("SHELL") or os.getenv("PREFIX"):
return Result
from platformdirs.android import _android_folder
from platformdirs.android import _android_folder # noqa: PLC0415
if _android_folder() is not None:
from platformdirs.android import Android
from platformdirs.android import Android # noqa: PLC0415
return Android # return to avoid redefinition of result
return Android # return to avoid redefinition of Result
return Result
@ -507,7 +509,7 @@ def user_log_path(
def user_documents_path() -> Path:
""":returns: documents path tied to the user"""
""":returns: documents a path tied to the user"""
return PlatformDirs().user_documents_path
@ -585,41 +587,41 @@ def site_runtime_path(
__all__ = [
"AppDirs",
"PlatformDirs",
"PlatformDirsABC",
"__version__",
"__version_info__",
"PlatformDirs",
"AppDirs",
"PlatformDirsABC",
"user_data_dir",
"user_config_dir",
"user_cache_dir",
"user_state_dir",
"user_log_dir",
"user_documents_dir",
"user_downloads_dir",
"user_pictures_dir",
"user_videos_dir",
"user_music_dir",
"user_desktop_dir",
"user_runtime_dir",
"site_data_dir",
"site_config_dir",
"site_cache_dir",
"site_runtime_dir",
"user_data_path",
"user_config_path",
"user_cache_path",
"user_state_path",
"user_log_path",
"user_documents_path",
"user_downloads_path",
"user_pictures_path",
"user_videos_path",
"user_music_path",
"user_desktop_path",
"user_runtime_path",
"site_data_path",
"site_config_path",
"site_cache_path",
"site_config_dir",
"site_config_path",
"site_data_dir",
"site_data_path",
"site_runtime_dir",
"site_runtime_path",
"user_cache_dir",
"user_cache_path",
"user_config_dir",
"user_config_path",
"user_data_dir",
"user_data_path",
"user_desktop_dir",
"user_desktop_path",
"user_documents_dir",
"user_documents_path",
"user_downloads_dir",
"user_downloads_path",
"user_log_dir",
"user_log_path",
"user_music_dir",
"user_music_path",
"user_pictures_dir",
"user_pictures_path",
"user_runtime_dir",
"user_runtime_path",
"user_state_dir",
"user_state_path",
"user_videos_dir",
"user_videos_path",
]
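For orientation, each ``*_dir`` accessor above returns a string while its ``*_path`` twin returns a ``pathlib.Path``; a brief, platform-dependent illustration (appname/appauthor values are examples):

from platformdirs import user_cache_dir, user_cache_path

# Output differs per OS.
print(user_cache_dir("MyApp", "MyCompany"))   # e.g. ~/.cache/MyApp on Linux
print(user_cache_path("MyApp", "MyCompany"))  # the same location as a Path object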

View file

@ -24,7 +24,7 @@ PROPS = (
def main() -> None:
"""Run main entry point."""
"""Run the main entry point."""
app_name = "MyApp"
app_author = "MyCompany"

View file

@ -13,10 +13,11 @@ from .api import PlatformDirsABC
class Android(PlatformDirsABC):
"""
Follows the guidance `from here <https://android.stackexchange.com/a/216132>`_. Makes use of the
`appname <platformdirs.api.PlatformDirsABC.appname>`,
`version <platformdirs.api.PlatformDirsABC.version>`,
`ensure_exists <platformdirs.api.PlatformDirsABC.ensure_exists>`.
Follows the guidance `from here <https://android.stackexchange.com/a/216132>`_.
Makes use of the `appname <platformdirs.api.PlatformDirsABC.appname>`, `version
<platformdirs.api.PlatformDirsABC.version>`, `ensure_exists <platformdirs.api.PlatformDirsABC.ensure_exists>`.
"""
@property
@ -44,7 +45,7 @@ class Android(PlatformDirsABC):
@property
def user_cache_dir(self) -> str:
""":return: cache directory tied to the user, e.g. e.g. ``/data/user/<userid>/<packagename>/cache/<AppName>``"""
""":return: cache directory tied to the user, e.g.,``/data/user/<userid>/<packagename>/cache/<AppName>``"""
return self._append_app_name_and_version(cast(str, _android_folder()), "cache")
@property
@ -119,13 +120,13 @@ class Android(PlatformDirsABC):
def _android_folder() -> str | None:
""":return: base folder for the Android OS or None if it cannot be found"""
try:
# First try to get path to android app via pyjnius
from jnius import autoclass
# First try to get the path to the android app via pyjnius
from jnius import autoclass # noqa: PLC0415
context = autoclass("android.content.Context")
result: str | None = context.getFilesDir().getParentFile().getAbsolutePath()
except Exception: # noqa: BLE001
# if fails find an android folder looking path on the sys.path
# if that fails, look for an android-folder-like path on sys.path
pattern = re.compile(r"/data/(data|user/\d+)/(.+)/files")
for path in sys.path:
if pattern.match(path):
@ -141,7 +142,7 @@ def _android_documents_folder() -> str:
""":return: documents folder for the Android OS"""
# Get directories with pyjnius
try:
from jnius import autoclass
from jnius import autoclass # noqa: PLC0415
context = autoclass("android.content.Context")
environment = autoclass("android.os.Environment")
@ -157,7 +158,7 @@ def _android_downloads_folder() -> str:
""":return: downloads folder for the Android OS"""
# Get directories with pyjnius
try:
from jnius import autoclass
from jnius import autoclass # noqa: PLC0415
context = autoclass("android.content.Context")
environment = autoclass("android.os.Environment")
@ -173,7 +174,7 @@ def _android_pictures_folder() -> str:
""":return: pictures folder for the Android OS"""
# Get directories with pyjnius
try:
from jnius import autoclass
from jnius import autoclass # noqa: PLC0415
context = autoclass("android.content.Context")
environment = autoclass("android.os.Environment")
@ -189,7 +190,7 @@ def _android_videos_folder() -> str:
""":return: videos folder for the Android OS"""
# Get directories with pyjnius
try:
from jnius import autoclass
from jnius import autoclass # noqa: PLC0415
context = autoclass("android.content.Context")
environment = autoclass("android.os.Environment")
@ -205,7 +206,7 @@ def _android_music_folder() -> str:
""":return: music folder for the Android OS"""
# Get directories with pyjnius
try:
from jnius import autoclass
from jnius import autoclass # noqa: PLC0415
context = autoclass("android.content.Context")
environment = autoclass("android.os.Environment")

View file

@ -11,10 +11,10 @@ if TYPE_CHECKING:
from typing import Iterator, Literal
class PlatformDirsABC(ABC):
class PlatformDirsABC(ABC): # noqa: PLR0904
"""Abstract base class for platform directories."""
def __init__( # noqa: PLR0913
def __init__( # noqa: PLR0913, PLR0917
self,
appname: str | None = None,
appauthor: str | None | Literal[False] = None,
@ -34,34 +34,47 @@ class PlatformDirsABC(ABC):
:param multipath: See `multipath`.
:param opinion: See `opinion`.
:param ensure_exists: See `ensure_exists`.
"""
self.appname = appname #: The name of application.
self.appauthor = appauthor
"""
The name of the app author or distributing body for this application. Typically, it is the owning company name.
Defaults to `appname`. You may pass ``False`` to disable it.
The name of the app author or distributing body for this application.
Typically, it is the owning company name. Defaults to `appname`. You may pass ``False`` to disable it.
"""
self.version = version
"""
An optional version path element to append to the path. You might want to use this if you want multiple versions
of your app to be able to run independently. If used, this would typically be ``<major>.<minor>``.
An optional version path element to append to the path.
You might want to use this if you want multiple versions of your app to be able to run independently. If used,
this would typically be ``<major>.<minor>``.
"""
self.roaming = roaming
"""
Whether to use the roaming appdata directory on Windows. That means that for users on a Windows network setup
for roaming profiles, this user data will be synced on login (see
`here <http://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>`_).
Whether to use the roaming appdata directory on Windows.
That means that for users on a Windows network setup for roaming profiles, this user data will be synced on
login (see
`here <https://technet.microsoft.com/en-us/library/cc766489(WS.10).aspx>`_).
"""
self.multipath = multipath
"""
An optional parameter which indicates that the entire list of data dirs should be returned.
By default, the first item would only be returned.
"""
self.opinion = opinion #: A flag to indicating to use opinionated values.
self.ensure_exists = ensure_exists
"""
Optionally create the directory (and any missing parents) upon access if it does not exist.
By default, no directories are created.
"""
def _append_app_name_and_version(self, *base: str) -> str:
@ -200,7 +213,7 @@ class PlatformDirsABC(ABC):
@property
def user_documents_path(self) -> Path:
""":return: documents path tied to the user"""
""":return: documents a path tied to the user"""
return Path(self.user_documents_dir)
@property

View file

@ -10,11 +10,14 @@ from .api import PlatformDirsABC
class MacOS(PlatformDirsABC):
"""
Platform directories for the macOS operating system. Follows the guidance from `Apple documentation
<https://developer.apple.com/library/archive/documentation/FileManagement/Conceptual/FileSystemProgrammingGuide/MacOSXDirectories/MacOSXDirectories.html>`_.
Platform directories for the macOS operating system.
Follows the guidance from
`Apple documentation <https://developer.apple.com/library/archive/documentation/FileManagement/Conceptual/FileSystemProgrammingGuide/MacOSXDirectories/MacOSXDirectories.html>`_.
Makes use of the `appname <platformdirs.api.PlatformDirsABC.appname>`,
`version <platformdirs.api.PlatformDirsABC.version>`,
`ensure_exists <platformdirs.api.PlatformDirsABC.ensure_exists>`.
"""
@property
@ -28,7 +31,7 @@ class MacOS(PlatformDirsABC):
:return: data directory shared by users, e.g. ``/Library/Application Support/$appname/$version``.
If we're using a Python binary managed by `Homebrew <https://brew.sh>`_, the directory
will be under the Homebrew prefix, e.g. ``/opt/homebrew/share/$appname/$version``.
If `multipath <platformdirs.api.PlatformDirsABC.multipath>` is enabled and we're in Homebrew,
If `multipath <platformdirs.api.PlatformDirsABC.multipath>` is enabled, and we're in Homebrew,
the response is a multi-path string separated by ":", e.g.
``/opt/homebrew/share/$appname/$version:/Library/Application Support/$appname/$version``
"""
@ -60,7 +63,7 @@ class MacOS(PlatformDirsABC):
:return: cache directory shared by users, e.g. ``/Library/Caches/$appname/$version``.
If we're using a Python binary managed by `Homebrew <https://brew.sh>`_, the directory
will be under the Homebrew prefix, e.g. ``/opt/homebrew/var/cache/$appname/$version``.
If `multipath <platformdirs.api.PlatformDirsABC.multipath>` is enabled and we're in Homebrew,
If `multipath <platformdirs.api.PlatformDirsABC.multipath>` is enabled, and we're in Homebrew,
the response is a multi-path string separated by ":", e.g.
``/opt/homebrew/var/cache/$appname/$version:/Library/Caches/$appname/$version``
"""

View file

@ -6,13 +6,13 @@ import os
import sys
from configparser import ConfigParser
from pathlib import Path
from typing import Iterator
from typing import Iterator, NoReturn
from .api import PlatformDirsABC
if sys.platform == "win32":
def getuid() -> int:
def getuid() -> NoReturn:
msg = "should only be used on Unix"
raise RuntimeError(msg)
@ -20,17 +20,17 @@ else:
from os import getuid
class Unix(PlatformDirsABC):
class Unix(PlatformDirsABC): # noqa: PLR0904
"""
On Unix/Linux, we follow the
`XDG Basedir Spec <https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html>`_. The spec allows
overriding directories with environment variables. The examples show are the default values, alongside the name of
the environment variable that overrides them. Makes use of the
`appname <platformdirs.api.PlatformDirsABC.appname>`,
`version <platformdirs.api.PlatformDirsABC.version>`,
`multipath <platformdirs.api.PlatformDirsABC.multipath>`,
`opinion <platformdirs.api.PlatformDirsABC.opinion>`,
`ensure_exists <platformdirs.api.PlatformDirsABC.ensure_exists>`.
On Unix/Linux, we follow the
`XDG Basedir Spec <https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html>`_.
The spec allows overriding directories with environment variables. The examples shown are the default values,
alongside the name of the environment variable that overrides them. Makes use of the `appname
<platformdirs.api.PlatformDirsABC.appname>`, `version <platformdirs.api.PlatformDirsABC.version>`, `multipath
<platformdirs.api.PlatformDirsABC.multipath>`, `opinion <platformdirs.api.PlatformDirsABC.opinion>`, `ensure_exists
<platformdirs.api.PlatformDirsABC.ensure_exists>`.
"""
@property
@ -205,17 +205,17 @@ class Unix(PlatformDirsABC):
@property
def site_data_path(self) -> Path:
""":return: data path shared by users. Only return first item, even if ``multipath`` is set to ``True``"""
""":return: data path shared by users. Only return the first item, even if ``multipath`` is set to ``True``"""
return self._first_item_as_path_if_multipath(self.site_data_dir)
@property
def site_config_path(self) -> Path:
""":return: config path shared by the users. Only return first item, even if ``multipath`` is set to ``True``"""
""":return: config path shared by the users, returns the first item, even if ``multipath`` is set to ``True``"""
return self._first_item_as_path_if_multipath(self.site_config_dir)
@property
def site_cache_path(self) -> Path:
""":return: cache path shared by users. Only return first item, even if ``multipath`` is set to ``True``"""
""":return: cache path shared by users. Only return the first item, even if ``multipath`` is set to ``True``"""
return self._first_item_as_path_if_multipath(self.site_cache_dir)
def _first_item_as_path_if_multipath(self, directory: str) -> Path:
@ -246,7 +246,12 @@ def _get_user_media_dir(env_var: str, fallback_tilde_path: str) -> str:
def _get_user_dirs_folder(key: str) -> str | None:
"""Return directory from user-dirs.dirs config file. See https://freedesktop.org/wiki/Software/xdg-user-dirs/."""
"""
Return directory from user-dirs.dirs config file.
See https://freedesktop.org/wiki/Software/xdg-user-dirs/.
"""
user_dirs_config_path = Path(Unix().user_config_dir) / "user-dirs.dirs"
if user_dirs_config_path.exists():
parser = ConfigParser()
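To illustrate the XDG override mechanism described in the class docstring above, a small sketch (Linux semantics; the environment value is an example):

import os
from platformdirs import user_data_dir

os.environ["XDG_DATA_HOME"] = "/tmp/xdg-data"  # example override
print(user_data_dir("MyApp"))                  # -> /tmp/xdg-data/MyApp on Linux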

View file

@ -12,5 +12,5 @@ __version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
__version__ = version = '4.2.0'
__version_tuple__ = version_tuple = (4, 2, 0)
__version__ = version = '4.2.1'
__version_tuple__ = version_tuple = (4, 2, 1)

View file

@ -2,7 +2,6 @@
from __future__ import annotations
import ctypes
import os
import sys
from functools import lru_cache
@ -16,15 +15,13 @@ if TYPE_CHECKING:
class Windows(PlatformDirsABC):
"""
`MSDN on where to store app data files
<http://support.microsoft.com/default.aspx?scid=kb;en-us;310294#XSLTH3194121123120121120120>`_.
Makes use of the
`appname <platformdirs.api.PlatformDirsABC.appname>`,
`appauthor <platformdirs.api.PlatformDirsABC.appauthor>`,
`version <platformdirs.api.PlatformDirsABC.version>`,
`roaming <platformdirs.api.PlatformDirsABC.roaming>`,
`opinion <platformdirs.api.PlatformDirsABC.opinion>`,
`ensure_exists <platformdirs.api.PlatformDirsABC.ensure_exists>`.
`MSDN on where to store app data files <https://learn.microsoft.com/en-us/windows/win32/shell/knownfolderid>`_.
Makes use of the `appname <platformdirs.api.PlatformDirsABC.appname>`, `appauthor
<platformdirs.api.PlatformDirsABC.appauthor>`, `version <platformdirs.api.PlatformDirsABC.version>`, `roaming
<platformdirs.api.PlatformDirsABC.roaming>`, `opinion <platformdirs.api.PlatformDirsABC.opinion>`, `ensure_exists
<platformdirs.api.PlatformDirsABC.ensure_exists>`.
"""
@property
@ -165,7 +162,7 @@ def get_win_folder_from_env_vars(csidl_name: str) -> str:
def get_win_folder_if_csidl_name_not_env_var(csidl_name: str) -> str | None:
"""Get folder for a CSIDL name that does not exist as an environment variable."""
"""Get a folder for a CSIDL name that does not exist as an environment variable."""
if csidl_name == "CSIDL_PERSONAL":
return os.path.join(os.path.normpath(os.environ["USERPROFILE"]), "Documents") # noqa: PTH118
@ -189,6 +186,7 @@ def get_win_folder_from_registry(csidl_name: str) -> str:
This is a fallback technique at best. I'm not sure if using the registry for these guarantees us the correct answer
for all CSIDL_* names.
"""
shell_folder_name = {
"CSIDL_APPDATA": "AppData",
@ -205,7 +203,7 @@ def get_win_folder_from_registry(csidl_name: str) -> str:
raise ValueError(msg)
if sys.platform != "win32": # only needed for mypy type checker to know that this code runs only on Windows
raise NotImplementedError
import winreg
import winreg # noqa: PLC0415
key = winreg.OpenKey(winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders")
directory, _ = winreg.QueryValueEx(key, shell_folder_name)
@ -218,6 +216,8 @@ def get_win_folder_via_ctypes(csidl_name: str) -> str:
# Use 'CSIDL_PROFILE' (40) and append the default folder 'Downloads' instead.
# https://learn.microsoft.com/en-us/windows/win32/shell/knownfolderid
import ctypes # noqa: PLC0415
csidl_const = {
"CSIDL_APPDATA": 26,
"CSIDL_COMMON_APPDATA": 35,
@ -250,10 +250,15 @@ def get_win_folder_via_ctypes(csidl_name: str) -> str:
def _pick_get_win_folder() -> Callable[[str], str]:
if hasattr(ctypes, "windll"):
return get_win_folder_via_ctypes
try:
import winreg # noqa: F401
import ctypes # noqa: PLC0415
except ImportError:
pass
else:
if hasattr(ctypes, "windll"):
return get_win_folder_via_ctypes
try:
import winreg # noqa: PLC0415, F401
except ImportError:
return get_win_folder_from_env_vars
else:

View file

@ -170,7 +170,16 @@ class PlexObject:
elem = ElementTree.fromstring(xml)
return self._buildItemOrNone(elem, cls)
def fetchItems(self, ekey, cls=None, container_start=None, container_size=None, maxresults=None, **kwargs):
def fetchItems(
self,
ekey,
cls=None,
container_start=None,
container_size=None,
maxresults=None,
params=None,
**kwargs,
):
""" Load the specified key to find and build all items with the specified tag
and attrs.
@ -186,6 +195,7 @@ class PlexObject:
container_start (None, int): offset to get a subset of the data
container_size (None, int): How many items in data
maxresults (int, optional): Only return the specified number of results.
params (dict, optional): Any additional params to add to the request.
**kwargs (dict): Optionally add XML attribute to filter the items.
See the details below for more info.
@ -268,7 +278,7 @@ class PlexObject:
headers['X-Plex-Container-Start'] = str(container_start)
headers['X-Plex-Container-Size'] = str(container_size)
data = self._server.query(ekey, headers=headers)
data = self._server.query(ekey, headers=headers, params=params)
subresults = self.findItems(data, cls, ekey, **kwargs)
total_size = utils.cast(int, data.attrib.get('totalSize') or data.attrib.get('size')) or len(subresults)
@ -283,6 +293,11 @@ class PlexObject:
results.extend(subresults)
container_start += container_size
if container_start > total_size:
break
wanted_number_of_items = total_size - offset
if maxresults is not None:
wanted_number_of_items = min(maxresults, wanted_number_of_items)
@ -291,11 +306,6 @@ class PlexObject:
if wanted_number_of_items <= len(results):
break
container_start += container_size
if container_start > total_size:
break
return results
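A hedged sketch of how the new ``params`` argument flows through ``fetchItems``; the server object, key, and query values below are illustrative:

# `plex` is assumed to be a connected plexapi.server.PlexServer instance.
movies = plex.fetchItems(
    '/library/sections/1/all',
    container_size=200,           # page size per request
    maxresults=500,               # stop once this many items are collected
    params={'includeGuids': 1},   # forwarded to query() as URL parameters
)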
def fetchItem(self, ekey, cls=None, **kwargs):
@ -337,7 +347,7 @@ class PlexObject:
kwargs['type'] = cls.TYPE
# rtag to iter on a specific root tag using breadth-first search
if rtag:
data = next(utils.iterXMLBFS(data, rtag), [])
data = next(utils.iterXMLBFS(data, rtag), Element('Empty'))
# loop through all data elements to find matches
items = MediaContainer[cls](self._server, data, initpath=initpath) if data.tag == 'MediaContainer' else []
for elem in data:

View file

@ -4,6 +4,6 @@
# Library version
MAJOR_VERSION = 4
MINOR_VERSION = 15
PATCH_VERSION = 11
PATCH_VERSION = 12
__short_version__ = f"{MAJOR_VERSION}.{MINOR_VERSION}"
__version__ = f"{__short_version__}.{PATCH_VERSION}"

View file

@ -746,7 +746,7 @@ class PlexServer(PlexObject):
""" Returns list of all :class:`~plexapi.media.TranscodeJob` objects running or paused on server. """
return self.fetchItems('/status/sessions/background')
def query(self, key, method=None, headers=None, timeout=None, **kwargs):
def query(self, key, method=None, headers=None, params=None, timeout=None, **kwargs):
""" Main method used to handle HTTPS requests to the Plex server. This method helps
by encoding the response to utf-8 and parsing the returned XML into an
ElementTree object. Returns None if no data exists in the response.
@ -756,7 +756,7 @@ class PlexServer(PlexObject):
timeout = timeout or self._timeout
log.debug('%s %s', method.__name__.upper(), url)
headers = self._headers(**headers or {})
response = method(url, headers=headers, timeout=timeout, **kwargs)
response = method(url, headers=headers, params=params, timeout=timeout, **kwargs)
if response.status_code not in (200, 201, 204):
codename = codes.get(response.status_code)[0]
errtext = response.text.replace('\n', ' ')

lib/typeguard/__init__.py Normal file
View file

@ -0,0 +1,48 @@
import os
from typing import Any
from ._checkers import TypeCheckerCallable as TypeCheckerCallable
from ._checkers import TypeCheckLookupCallback as TypeCheckLookupCallback
from ._checkers import check_type_internal as check_type_internal
from ._checkers import checker_lookup_functions as checker_lookup_functions
from ._checkers import load_plugins as load_plugins
from ._config import CollectionCheckStrategy as CollectionCheckStrategy
from ._config import ForwardRefPolicy as ForwardRefPolicy
from ._config import TypeCheckConfiguration as TypeCheckConfiguration
from ._decorators import typechecked as typechecked
from ._decorators import typeguard_ignore as typeguard_ignore
from ._exceptions import InstrumentationWarning as InstrumentationWarning
from ._exceptions import TypeCheckError as TypeCheckError
from ._exceptions import TypeCheckWarning as TypeCheckWarning
from ._exceptions import TypeHintWarning as TypeHintWarning
from ._functions import TypeCheckFailCallback as TypeCheckFailCallback
from ._functions import check_type as check_type
from ._functions import warn_on_error as warn_on_error
from ._importhook import ImportHookManager as ImportHookManager
from ._importhook import TypeguardFinder as TypeguardFinder
from ._importhook import install_import_hook as install_import_hook
from ._memo import TypeCheckMemo as TypeCheckMemo
from ._suppression import suppress_type_checks as suppress_type_checks
from ._utils import Unset as Unset
# Re-export imports so they look like they live directly in this package
for value in list(locals().values()):
if getattr(value, "__module__", "").startswith(f"{__name__}."):
value.__module__ = __name__
config: TypeCheckConfiguration
def __getattr__(name: str) -> Any:
if name == "config":
from ._config import global_config
return global_config
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
# Automatically load checker lookup functions unless explicitly disabled
if "TYPEGUARD_DISABLE_PLUGIN_AUTOLOAD" not in os.environ:
load_plugins()
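As a usage note, the module-level ``__getattr__`` above makes the global configuration reachable as ``typeguard.config``; a minimal sketch:

import typeguard

# Tighten collection checking globally (the attribute matches _config below).
typeguard.config.collection_check_strategy = typeguard.CollectionCheckStrategy.ALL_ITEMS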

lib/typeguard/_checkers.py Normal file
View file

@ -0,0 +1,910 @@
from __future__ import annotations
import collections.abc
import inspect
import sys
import types
import typing
import warnings
from enum import Enum
from inspect import Parameter, isclass, isfunction
from io import BufferedIOBase, IOBase, RawIOBase, TextIOBase
from textwrap import indent
from typing import (
IO,
AbstractSet,
Any,
BinaryIO,
Callable,
Dict,
ForwardRef,
List,
Mapping,
MutableMapping,
NewType,
Optional,
Sequence,
Set,
TextIO,
Tuple,
Type,
TypeVar,
Union,
)
from unittest.mock import Mock
try:
import typing_extensions
except ImportError:
typing_extensions = None # type: ignore[assignment]
from ._config import ForwardRefPolicy
from ._exceptions import TypeCheckError, TypeHintWarning
from ._memo import TypeCheckMemo
from ._utils import evaluate_forwardref, get_stacklevel, get_type_name, qualified_name
if sys.version_info >= (3, 13):
from typing import is_typeddict
else:
from typing_extensions import is_typeddict
if sys.version_info >= (3, 11):
from typing import (
Annotated,
NotRequired,
TypeAlias,
get_args,
get_origin,
)
SubclassableAny = Any
else:
from typing_extensions import (
Annotated,
NotRequired,
TypeAlias,
get_args,
get_origin,
)
from typing_extensions import Any as SubclassableAny
if sys.version_info >= (3, 10):
from importlib.metadata import entry_points
from typing import ParamSpec
else:
from importlib_metadata import entry_points
from typing_extensions import ParamSpec
TypeCheckerCallable: TypeAlias = Callable[
[Any, Any, Tuple[Any, ...], TypeCheckMemo], Any
]
TypeCheckLookupCallback: TypeAlias = Callable[
[Any, Tuple[Any, ...], Tuple[Any, ...]], Optional[TypeCheckerCallable]
]
checker_lookup_functions: list[TypeCheckLookupCallback] = []
generic_alias_types: tuple[type, ...] = (type(List), type(List[Any]))
if sys.version_info >= (3, 9):
generic_alias_types += (types.GenericAlias,)
# Sentinel
_missing = object()
# Lifted from mypy.sharedparse
BINARY_MAGIC_METHODS = {
"__add__",
"__and__",
"__cmp__",
"__divmod__",
"__div__",
"__eq__",
"__floordiv__",
"__ge__",
"__gt__",
"__iadd__",
"__iand__",
"__idiv__",
"__ifloordiv__",
"__ilshift__",
"__imatmul__",
"__imod__",
"__imul__",
"__ior__",
"__ipow__",
"__irshift__",
"__isub__",
"__itruediv__",
"__ixor__",
"__le__",
"__lshift__",
"__lt__",
"__matmul__",
"__mod__",
"__mul__",
"__ne__",
"__or__",
"__pow__",
"__radd__",
"__rand__",
"__rdiv__",
"__rfloordiv__",
"__rlshift__",
"__rmatmul__",
"__rmod__",
"__rmul__",
"__ror__",
"__rpow__",
"__rrshift__",
"__rshift__",
"__rsub__",
"__rtruediv__",
"__rxor__",
"__sub__",
"__truediv__",
"__xor__",
}
def check_callable(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if not callable(value):
raise TypeCheckError("is not callable")
if args:
try:
signature = inspect.signature(value)
except (TypeError, ValueError):
return
argument_types = args[0]
if isinstance(argument_types, list) and not any(
type(item) is ParamSpec for item in argument_types
):
# The callable must not have keyword-only arguments without defaults
unfulfilled_kwonlyargs = [
param.name
for param in signature.parameters.values()
if param.kind == Parameter.KEYWORD_ONLY
and param.default == Parameter.empty
]
if unfulfilled_kwonlyargs:
raise TypeCheckError(
f"has mandatory keyword-only arguments in its declaration: "
f'{", ".join(unfulfilled_kwonlyargs)}'
)
num_positional_args = num_mandatory_pos_args = 0
has_varargs = False
for param in signature.parameters.values():
if param.kind in (
Parameter.POSITIONAL_ONLY,
Parameter.POSITIONAL_OR_KEYWORD,
):
num_positional_args += 1
if param.default is Parameter.empty:
num_mandatory_pos_args += 1
elif param.kind == Parameter.VAR_POSITIONAL:
has_varargs = True
if num_mandatory_pos_args > len(argument_types):
raise TypeCheckError(
f"has too many mandatory positional arguments in its declaration; "
f"expected {len(argument_types)} but {num_mandatory_pos_args} "
f"mandatory positional argument(s) declared"
)
elif not has_varargs and num_positional_args < len(argument_types):
raise TypeCheckError(
f"has too few arguments in its declaration; expected "
f"{len(argument_types)} but {num_positional_args} argument(s) "
f"declared"
)
def check_mapping(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if origin_type is Dict or origin_type is dict:
if not isinstance(value, dict):
raise TypeCheckError("is not a dict")
if origin_type is MutableMapping or origin_type is collections.abc.MutableMapping:
if not isinstance(value, collections.abc.MutableMapping):
raise TypeCheckError("is not a mutable mapping")
elif not isinstance(value, collections.abc.Mapping):
raise TypeCheckError("is not a mapping")
if args:
key_type, value_type = args
if key_type is not Any or value_type is not Any:
samples = memo.config.collection_check_strategy.iterate_samples(
value.items()
)
for k, v in samples:
try:
check_type_internal(k, key_type, memo)
except TypeCheckError as exc:
exc.append_path_element(f"key {k!r}")
raise
try:
check_type_internal(v, value_type, memo)
except TypeCheckError as exc:
exc.append_path_element(f"value of key {k!r}")
raise
def check_typed_dict(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if not isinstance(value, dict):
raise TypeCheckError("is not a dict")
declared_keys = frozenset(origin_type.__annotations__)
if hasattr(origin_type, "__required_keys__"):
required_keys = set(origin_type.__required_keys__)
else: # py3.8 and lower
required_keys = set(declared_keys) if origin_type.__total__ else set()
existing_keys = set(value)
extra_keys = existing_keys - declared_keys
if extra_keys:
keys_formatted = ", ".join(f'"{key}"' for key in sorted(extra_keys, key=repr))
raise TypeCheckError(f"has unexpected extra key(s): {keys_formatted}")
# Detect NotRequired fields which are hidden by get_type_hints()
type_hints: dict[str, type] = {}
for key, annotation in origin_type.__annotations__.items():
if isinstance(annotation, ForwardRef):
annotation = evaluate_forwardref(annotation, memo)
if get_origin(annotation) is NotRequired:
required_keys.discard(key)
annotation = get_args(annotation)[0]
type_hints[key] = annotation
missing_keys = required_keys - existing_keys
if missing_keys:
keys_formatted = ", ".join(f'"{key}"' for key in sorted(missing_keys, key=repr))
raise TypeCheckError(f"is missing required key(s): {keys_formatted}")
for key, argtype in type_hints.items():
argvalue = value.get(key, _missing)
if argvalue is not _missing:
try:
check_type_internal(argvalue, argtype, memo)
except TypeCheckError as exc:
exc.append_path_element(f"value of key {key!r}")
raise
def check_list(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if not isinstance(value, list):
raise TypeCheckError("is not a list")
if args and args != (Any,):
samples = memo.config.collection_check_strategy.iterate_samples(value)
for i, v in enumerate(samples):
try:
check_type_internal(v, args[0], memo)
except TypeCheckError as exc:
exc.append_path_element(f"item {i}")
raise
def check_sequence(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if not isinstance(value, collections.abc.Sequence):
raise TypeCheckError("is not a sequence")
if args and args != (Any,):
samples = memo.config.collection_check_strategy.iterate_samples(value)
for i, v in enumerate(samples):
try:
check_type_internal(v, args[0], memo)
except TypeCheckError as exc:
exc.append_path_element(f"item {i}")
raise
def check_set(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if origin_type is frozenset:
if not isinstance(value, frozenset):
raise TypeCheckError("is not a frozenset")
elif not isinstance(value, AbstractSet):
raise TypeCheckError("is not a set")
if args and args != (Any,):
samples = memo.config.collection_check_strategy.iterate_samples(value)
for v in samples:
try:
check_type_internal(v, args[0], memo)
except TypeCheckError as exc:
exc.append_path_element(f"[{v}]")
raise
def check_tuple(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
# Specialized check for NamedTuples
if field_types := getattr(origin_type, "__annotations__", None):
if not isinstance(value, origin_type):
raise TypeCheckError(
f"is not a named tuple of type {qualified_name(origin_type)}"
)
for name, field_type in field_types.items():
try:
check_type_internal(getattr(value, name), field_type, memo)
except TypeCheckError as exc:
exc.append_path_element(f"attribute {name!r}")
raise
return
elif not isinstance(value, tuple):
raise TypeCheckError("is not a tuple")
if args:
use_ellipsis = args[-1] is Ellipsis
tuple_params = args[: -1 if use_ellipsis else None]
else:
# Unparametrized Tuple or plain tuple
return
if use_ellipsis:
element_type = tuple_params[0]
samples = memo.config.collection_check_strategy.iterate_samples(value)
for i, element in enumerate(samples):
try:
check_type_internal(element, element_type, memo)
except TypeCheckError as exc:
exc.append_path_element(f"item {i}")
raise
elif tuple_params == ((),):
if value != ():
raise TypeCheckError("is not an empty tuple")
else:
if len(value) != len(tuple_params):
raise TypeCheckError(
f"has wrong number of elements (expected {len(tuple_params)}, got "
f"{len(value)} instead)"
)
for i, (element, element_type) in enumerate(zip(value, tuple_params)):
try:
check_type_internal(element, element_type, memo)
except TypeCheckError as exc:
exc.append_path_element(f"item {i}")
raise
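Both tuple forms handled above can be exercised through the public ``check_type`` wrapper; a small sketch (Python 3.9+ for the builtin generic syntax):

from typeguard import check_type

check_type((1, 2, 3), tuple[int, ...])   # variadic form: every item checked as int
check_type((1, "a"), tuple[int, str])    # fixed-arity form: element-wise types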
def check_union(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
errors: dict[str, TypeCheckError] = {}
try:
for type_ in args:
try:
check_type_internal(value, type_, memo)
return
except TypeCheckError as exc:
errors[get_type_name(type_)] = exc
formatted_errors = indent(
"\n".join(f"{key}: {error}" for key, error in errors.items()), " "
)
finally:
del errors # avoid creating ref cycle
raise TypeCheckError(f"did not match any element in the union:\n{formatted_errors}")
def check_uniontype(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
errors: dict[str, TypeCheckError] = {}
for type_ in args:
try:
check_type_internal(value, type_, memo)
return
except TypeCheckError as exc:
errors[get_type_name(type_)] = exc
formatted_errors = indent(
"\n".join(f"{key}: {error}" for key, error in errors.items()), " "
)
raise TypeCheckError(f"did not match any element in the union:\n{formatted_errors}")
def check_class(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if not isclass(value) and not isinstance(value, generic_alias_types):
raise TypeCheckError("is not a class")
if not args:
return
if isinstance(args[0], ForwardRef):
expected_class = evaluate_forwardref(args[0], memo)
else:
expected_class = args[0]
if expected_class is Any:
return
elif getattr(expected_class, "_is_protocol", False):
check_protocol(value, expected_class, (), memo)
elif isinstance(expected_class, TypeVar):
check_typevar(value, expected_class, (), memo, subclass_check=True)
elif get_origin(expected_class) is Union:
errors: dict[str, TypeCheckError] = {}
for arg in get_args(expected_class):
if arg is Any:
return
try:
check_class(value, type, (arg,), memo)
return
except TypeCheckError as exc:
errors[get_type_name(arg)] = exc
else:
formatted_errors = indent(
"\n".join(f"{key}: {error}" for key, error in errors.items()), " "
)
raise TypeCheckError(
f"did not match any element in the union:\n{formatted_errors}"
)
elif not issubclass(value, expected_class): # type: ignore[arg-type]
raise TypeCheckError(f"is not a subclass of {qualified_name(expected_class)}")
def check_newtype(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
check_type_internal(value, origin_type.__supertype__, memo)
def check_instance(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if not isinstance(value, origin_type):
raise TypeCheckError(f"is not an instance of {qualified_name(origin_type)}")
def check_typevar(
value: Any,
origin_type: TypeVar,
args: tuple[Any, ...],
memo: TypeCheckMemo,
*,
subclass_check: bool = False,
) -> None:
if origin_type.__bound__ is not None:
annotation = (
Type[origin_type.__bound__] if subclass_check else origin_type.__bound__
)
check_type_internal(value, annotation, memo)
elif origin_type.__constraints__:
for constraint in origin_type.__constraints__:
annotation = Type[constraint] if subclass_check else constraint
try:
check_type_internal(value, annotation, memo)
except TypeCheckError:
pass
else:
break
else:
formatted_constraints = ", ".join(
get_type_name(constraint) for constraint in origin_type.__constraints__
)
raise TypeCheckError(
f"does not match any of the constraints " f"({formatted_constraints})"
)
if typing_extensions is None:
def _is_literal_type(typ: object) -> bool:
return typ is typing.Literal
else:
def _is_literal_type(typ: object) -> bool:
return typ is typing.Literal or typ is typing_extensions.Literal
def check_literal(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
def get_literal_args(literal_args: tuple[Any, ...]) -> tuple[Any, ...]:
retval: list[Any] = []
for arg in literal_args:
if _is_literal_type(get_origin(arg)):
retval.extend(get_literal_args(arg.__args__))
elif arg is None or isinstance(arg, (int, str, bytes, bool, Enum)):
retval.append(arg)
else:
raise TypeError(
f"Illegal literal value: {arg}"
) # TypeError here is deliberate
return tuple(retval)
final_args = tuple(get_literal_args(args))
try:
index = final_args.index(value)
except ValueError:
pass
else:
if type(final_args[index]) is type(value):
return
formatted_args = ", ".join(repr(arg) for arg in final_args)
raise TypeCheckError(f"is not any of ({formatted_args})") from None
def check_literal_string(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
check_type_internal(value, str, memo)
def check_typeguard(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
check_type_internal(value, bool, memo)
def check_none(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if value is not None:
raise TypeCheckError("is not None")
def check_number(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if origin_type is complex and not isinstance(value, (complex, float, int)):
raise TypeCheckError("is neither complex, float or int")
elif origin_type is float and not isinstance(value, (float, int)):
raise TypeCheckError("is neither float or int")
def check_io(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if origin_type is TextIO or (origin_type is IO and args == (str,)):
if not isinstance(value, TextIOBase):
raise TypeCheckError("is not a text based I/O object")
elif origin_type is BinaryIO or (origin_type is IO and args == (bytes,)):
if not isinstance(value, (RawIOBase, BufferedIOBase)):
raise TypeCheckError("is not a binary I/O object")
elif not isinstance(value, IOBase):
raise TypeCheckError("is not an I/O object")
def check_protocol(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
# TODO: implement proper compatibility checking and support non-runtime protocols
if getattr(origin_type, "_is_runtime_protocol", False):
if not isinstance(value, origin_type):
raise TypeCheckError(
f"is not compatible with the {origin_type.__qualname__} protocol"
)
else:
warnings.warn(
f"Typeguard cannot check the {origin_type.__qualname__} protocol because "
f"it is a non-runtime protocol. If you would like to type check this "
f"protocol, please use @typing.runtime_checkable",
stacklevel=get_stacklevel(),
)
def check_byteslike(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if not isinstance(value, (bytearray, bytes, memoryview)):
raise TypeCheckError("is not bytes-like")
def check_self(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if memo.self_type is None:
raise TypeCheckError("cannot be checked against Self outside of a method call")
if isclass(value):
if not issubclass(value, memo.self_type):
raise TypeCheckError(
f"is not an instance of the self type "
f"({qualified_name(memo.self_type)})"
)
elif not isinstance(value, memo.self_type):
raise TypeCheckError(
f"is not an instance of the self type ({qualified_name(memo.self_type)})"
)
def check_paramspec(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
pass # No-op for now
def check_instanceof(
value: Any,
origin_type: Any,
args: tuple[Any, ...],
memo: TypeCheckMemo,
) -> None:
if not isinstance(value, origin_type):
raise TypeCheckError(f"is not an instance of {qualified_name(origin_type)}")
def check_type_internal(
value: Any,
annotation: Any,
memo: TypeCheckMemo,
) -> None:
"""
Check that the given object is compatible with the given type annotation.
This function should only be used by type checker callables. Applications should use
:func:`~.check_type` instead.
:param value: the value to check
:param annotation: the type annotation to check against
:param memo: a memo object containing configuration and information necessary for
looking up forward references
"""
if isinstance(annotation, ForwardRef):
try:
annotation = evaluate_forwardref(annotation, memo)
except NameError:
if memo.config.forward_ref_policy is ForwardRefPolicy.ERROR:
raise
elif memo.config.forward_ref_policy is ForwardRefPolicy.WARN:
warnings.warn(
f"Cannot resolve forward reference {annotation.__forward_arg__!r}",
TypeHintWarning,
stacklevel=get_stacklevel(),
)
return
if annotation is Any or annotation is SubclassableAny or isinstance(value, Mock):
return
# Skip type checks if value is an instance of a class that inherits from Any
if not isclass(value) and SubclassableAny in type(value).__bases__:
return
extras: tuple[Any, ...]
origin_type = get_origin(annotation)
if origin_type is Annotated:
annotation, *extras_ = get_args(annotation)
extras = tuple(extras_)
origin_type = get_origin(annotation)
else:
extras = ()
if origin_type is not None:
args = get_args(annotation)
# Compatibility hack to distinguish between unparametrized and empty tuple
# (tuple[()]), necessary due to https://github.com/python/cpython/issues/91137
if origin_type in (tuple, Tuple) and annotation is not Tuple and not args:
args = ((),)
else:
origin_type = annotation
args = ()
for lookup_func in checker_lookup_functions:
checker = lookup_func(origin_type, args, extras)
if checker:
checker(value, origin_type, args, memo)
return
if isclass(origin_type):
if not isinstance(value, origin_type):
raise TypeCheckError(f"is not an instance of {qualified_name(origin_type)}")
elif type(origin_type) is str: # noqa: E721
warnings.warn(
f"Skipping type check against {origin_type!r}; this looks like a "
f"string-form forward reference imported from another module",
TypeHintWarning,
stacklevel=get_stacklevel(),
)
# Equality checks are applied to these
origin_type_checkers = {
bytes: check_byteslike,
AbstractSet: check_set,
BinaryIO: check_io,
Callable: check_callable,
collections.abc.Callable: check_callable,
complex: check_number,
dict: check_mapping,
Dict: check_mapping,
float: check_number,
frozenset: check_set,
IO: check_io,
list: check_list,
List: check_list,
typing.Literal: check_literal,
Mapping: check_mapping,
MutableMapping: check_mapping,
None: check_none,
collections.abc.Mapping: check_mapping,
collections.abc.MutableMapping: check_mapping,
Sequence: check_sequence,
collections.abc.Sequence: check_sequence,
collections.abc.Set: check_set,
set: check_set,
Set: check_set,
TextIO: check_io,
tuple: check_tuple,
Tuple: check_tuple,
type: check_class,
Type: check_class,
Union: check_union,
}
if sys.version_info >= (3, 10):
origin_type_checkers[types.UnionType] = check_uniontype
origin_type_checkers[typing.TypeGuard] = check_typeguard
if sys.version_info >= (3, 11):
origin_type_checkers.update(
{typing.LiteralString: check_literal_string, typing.Self: check_self}
)
if typing_extensions is not None:
# On some Python versions, these may simply be re-exports from typing,
# but exactly which Python versions is subject to change,
# so it's best to err on the safe side
# and update the dictionary on all Python versions
# if typing_extensions is installed
origin_type_checkers[typing_extensions.Literal] = check_literal
origin_type_checkers[typing_extensions.LiteralString] = check_literal_string
origin_type_checkers[typing_extensions.Self] = check_self
origin_type_checkers[typing_extensions.TypeGuard] = check_typeguard
def builtin_checker_lookup(
origin_type: Any, args: tuple[Any, ...], extras: tuple[Any, ...]
) -> TypeCheckerCallable | None:
checker = origin_type_checkers.get(origin_type)
if checker is not None:
return checker
elif is_typeddict(origin_type):
return check_typed_dict
elif isclass(origin_type) and issubclass(
origin_type, Tuple # type: ignore[arg-type]
):
# NamedTuple
return check_tuple
elif getattr(origin_type, "_is_protocol", False):
return check_protocol
elif isinstance(origin_type, ParamSpec):
return check_paramspec
elif isinstance(origin_type, TypeVar):
return check_typevar
elif origin_type.__class__ is NewType:
# typing.NewType on Python 3.10+
return check_newtype
elif (
isfunction(origin_type)
and getattr(origin_type, "__module__", None) == "typing"
and getattr(origin_type, "__qualname__", "").startswith("NewType.")
and hasattr(origin_type, "__supertype__")
):
# typing.NewType on Python 3.9 and below
return check_newtype
return None
checker_lookup_functions.append(builtin_checker_lookup)
def load_plugins() -> None:
"""
Load all type checker lookup functions from entry points.
All entry points from the ``typeguard.checker_lookup`` group are loaded, and the
returned lookup functions are added to :data:`typeguard.checker_lookup_functions`.
.. note:: This function is called implicitly on import, unless the
``TYPEGUARD_DISABLE_PLUGIN_AUTOLOAD`` environment variable is present.
"""
for ep in entry_points(group="typeguard.checker_lookup"):
try:
plugin = ep.load()
except Exception as exc:
warnings.warn(
f"Failed to load plugin {ep.name!r}: " f"{qualified_name(exc)}: {exc}",
stacklevel=2,
)
continue
if not callable(plugin):
warnings.warn(
f"Plugin {ep} returned a non-callable object: {plugin!r}", stacklevel=2
)
continue
checker_lookup_functions.insert(0, plugin)
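Beyond entry-point plugins, a lookup function can also be registered directly. A hedged sketch with illustrative names, following the ``TypeCheckLookupCallback`` and ``TypeCheckerCallable`` signatures defined above:

from typing import Any, Optional, Tuple

from typeguard import TypeCheckError, checker_lookup_functions

class EvenInt(int):
    """Marker type used only for this example."""

def check_even(value: Any, origin_type: Any, args: Tuple[Any, ...], memo: Any) -> None:
    if not isinstance(value, int) or value % 2:
        raise TypeCheckError("is not an even integer")

def even_lookup(origin_type: Any, args: Tuple[Any, ...], extras: Tuple[Any, ...]):
    return check_even if origin_type is EvenInt else None

checker_lookup_functions.append(even_lookup)
# check_type(3, EvenInt) would now raise TypeCheckError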

lib/typeguard/_config.py Normal file
View file

@ -0,0 +1,108 @@
from __future__ import annotations
from collections.abc import Iterable
from dataclasses import dataclass
from enum import Enum, auto
from typing import TYPE_CHECKING, TypeVar
if TYPE_CHECKING:
from ._functions import TypeCheckFailCallback
T = TypeVar("T")
class ForwardRefPolicy(Enum):
"""
Defines how unresolved forward references are handled.
Members:
* ``ERROR``: propagate the :exc:`NameError` when the forward reference lookup fails
* ``WARN``: emit a :class:`~.TypeHintWarning` if the forward reference lookup fails
* ``IGNORE``: silently skip checks for unresolvable forward references
"""
ERROR = auto()
WARN = auto()
IGNORE = auto()
class CollectionCheckStrategy(Enum):
"""
Specifies how thoroughly the contents of collections are type checked.
This has an effect on the following built-in checkers:
* ``AbstractSet``
* ``Dict``
* ``List``
* ``Mapping``
* ``Set``
* ``Tuple[<type>, ...]`` (arbitrarily sized tuples)
Members:
* ``FIRST_ITEM``: check only the first item
* ``ALL_ITEMS``: check all items
"""
FIRST_ITEM = auto()
ALL_ITEMS = auto()
def iterate_samples(self, collection: Iterable[T]) -> Iterable[T]:
if self is CollectionCheckStrategy.FIRST_ITEM:
try:
return [next(iter(collection))]
except StopIteration:
return ()
else:
return collection
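The difference between the two strategies can be seen via ``check_type``, which accepts this setting as an override; a sketch (Python 3.9+ for ``list[int]``):

from typeguard import CollectionCheckStrategy, check_type

check_type([1, 2, "x"], list[int])   # FIRST_ITEM default: only 1 is inspected, passes
check_type(
    [1, 2, "x"], list[int],
    collection_check_strategy=CollectionCheckStrategy.ALL_ITEMS,
)  # raises TypeCheckError on "x"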
@dataclass
class TypeCheckConfiguration:
"""
You can change Typeguard's behavior with these settings.
.. attribute:: typecheck_fail_callback
:type: Callable[[TypeCheckError, TypeCheckMemo], Any]
Callable that is called when type checking fails.
Default: ``None`` (the :exc:`~.TypeCheckError` is raised directly)
.. attribute:: forward_ref_policy
:type: ForwardRefPolicy
Specifies what to do when a forward reference fails to resolve.
Default: ``WARN``
.. attribute:: collection_check_strategy
:type: CollectionCheckStrategy
Specifies how thoroughly the contents of collections (list, dict, etc.) are
type checked.
Default: ``FIRST_ITEM``
.. attribute:: debug_instrumentation
:type: bool
If set to ``True``, the code of modules or functions instrumented by typeguard
is printed to ``sys.stderr`` after the instrumentation is done
Requires Python 3.9 or newer.
Default: ``False``
"""
forward_ref_policy: ForwardRefPolicy = ForwardRefPolicy.WARN
typecheck_fail_callback: TypeCheckFailCallback | None = None
collection_check_strategy: CollectionCheckStrategy = (
CollectionCheckStrategy.FIRST_ITEM
)
debug_instrumentation: bool = False
global_config = TypeCheckConfiguration()

View file

@ -0,0 +1,235 @@
from __future__ import annotations
import ast
import inspect
import sys
from collections.abc import Sequence
from functools import partial
from inspect import isclass, isfunction
from types import CodeType, FrameType, FunctionType
from typing import TYPE_CHECKING, Any, Callable, ForwardRef, TypeVar, cast, overload
from warnings import warn
from ._config import CollectionCheckStrategy, ForwardRefPolicy, global_config
from ._exceptions import InstrumentationWarning
from ._functions import TypeCheckFailCallback
from ._transformer import TypeguardTransformer
from ._utils import Unset, function_name, get_stacklevel, is_method_of, unset
if TYPE_CHECKING:
from typeshed.stdlib.types import _Cell
_F = TypeVar("_F")
def typeguard_ignore(f: _F) -> _F:
"""This decorator is a noop during static type-checking."""
return f
else:
from typing import no_type_check as typeguard_ignore # noqa: F401
T_CallableOrType = TypeVar("T_CallableOrType", bound=Callable[..., Any])
def make_cell(value: object) -> _Cell:
return (lambda: value).__closure__[0] # type: ignore[index]
def find_target_function(
new_code: CodeType, target_path: Sequence[str], firstlineno: int
) -> CodeType | None:
target_name = target_path[0]
for const in new_code.co_consts:
if isinstance(const, CodeType):
if const.co_name == target_name:
if const.co_firstlineno == firstlineno:
return const
elif len(target_path) > 1:
target_code = find_target_function(
const, target_path[1:], firstlineno
)
if target_code:
return target_code
return None
def instrument(f: T_CallableOrType) -> FunctionType | str:
if not getattr(f, "__code__", None):
return "no code associated"
elif not getattr(f, "__module__", None):
return "__module__ attribute is not set"
elif f.__code__.co_filename == "<stdin>":
return "cannot instrument functions defined in a REPL"
elif hasattr(f, "__wrapped__"):
return (
"@typechecked only supports instrumenting functions wrapped with "
"@classmethod, @staticmethod or @property"
)
target_path = [item for item in f.__qualname__.split(".") if item != "<locals>"]
module_source = inspect.getsource(sys.modules[f.__module__])
module_ast = ast.parse(module_source)
instrumentor = TypeguardTransformer(target_path, f.__code__.co_firstlineno)
instrumentor.visit(module_ast)
if not instrumentor.target_node or instrumentor.target_lineno is None:
return "instrumentor did not find the target function"
module_code = compile(module_ast, f.__code__.co_filename, "exec", dont_inherit=True)
new_code = find_target_function(
module_code, target_path, instrumentor.target_lineno
)
if not new_code:
return "cannot find the target function in the AST"
if global_config.debug_instrumentation and sys.version_info >= (3, 9):
# Find the matching AST node, then unparse it to source and print to stderr
print(
f"Source code of {f.__qualname__}() after instrumentation:"
"\n----------------------------------------------",
file=sys.stderr,
)
print(ast.unparse(instrumentor.target_node), file=sys.stderr)
print(
"----------------------------------------------",
file=sys.stderr,
)
closure = f.__closure__
if new_code.co_freevars != f.__code__.co_freevars:
# Create a new closure and find values for the new free variables
frame = cast(FrameType, inspect.currentframe())
frame = cast(FrameType, frame.f_back)
frame_locals = cast(FrameType, frame.f_back).f_locals
cells: list[_Cell] = []
for key in new_code.co_freevars:
if key in instrumentor.names_used_in_annotations:
# Find the value and make a new cell from it
value = frame_locals.get(key) or ForwardRef(key)
cells.append(make_cell(value))
else:
# Reuse the cell from the existing closure
assert f.__closure__
cells.append(f.__closure__[f.__code__.co_freevars.index(key)])
closure = tuple(cells)
new_function = FunctionType(new_code, f.__globals__, f.__name__, closure=closure)
new_function.__module__ = f.__module__
new_function.__name__ = f.__name__
new_function.__qualname__ = f.__qualname__
new_function.__annotations__ = f.__annotations__
new_function.__doc__ = f.__doc__
new_function.__defaults__ = f.__defaults__
new_function.__kwdefaults__ = f.__kwdefaults__
return new_function
@overload
def typechecked(
*,
forward_ref_policy: ForwardRefPolicy | Unset = unset,
typecheck_fail_callback: TypeCheckFailCallback | Unset = unset,
collection_check_strategy: CollectionCheckStrategy | Unset = unset,
debug_instrumentation: bool | Unset = unset,
) -> Callable[[T_CallableOrType], T_CallableOrType]: ...
@overload
def typechecked(target: T_CallableOrType) -> T_CallableOrType: ...
def typechecked(
target: T_CallableOrType | None = None,
*,
forward_ref_policy: ForwardRefPolicy | Unset = unset,
typecheck_fail_callback: TypeCheckFailCallback | Unset = unset,
collection_check_strategy: CollectionCheckStrategy | Unset = unset,
debug_instrumentation: bool | Unset = unset,
) -> Any:
"""
Instrument the target function to perform run-time type checking.
This decorator recompiles the target function, injecting code to type check
arguments, return values, yield values (excluding ``yield from``) and assignments to
annotated local variables.
This can also be used as a class decorator. This will instrument all type annotated
methods, including :func:`@classmethod <classmethod>`,
:func:`@staticmethod <staticmethod>`, and :class:`@property <property>` decorated
methods in the class.
.. note:: When Python is run in optimized mode (``-O`` or ``-OO``), this decorator
    is a no-op. This is a feature meant for selectively introducing type checking
into a code base where the checks aren't meant to be run in production.
:param target: the function or class to enable type checking for
:param forward_ref_policy: override for
:attr:`.TypeCheckConfiguration.forward_ref_policy`
:param typecheck_fail_callback: override for
:attr:`.TypeCheckConfiguration.typecheck_fail_callback`
:param collection_check_strategy: override for
:attr:`.TypeCheckConfiguration.collection_check_strategy`
:param debug_instrumentation: override for
:attr:`.TypeCheckConfiguration.debug_instrumentation`
"""
if target is None:
return partial(
typechecked,
forward_ref_policy=forward_ref_policy,
typecheck_fail_callback=typecheck_fail_callback,
collection_check_strategy=collection_check_strategy,
debug_instrumentation=debug_instrumentation,
)
if not __debug__:
return target
if isclass(target):
for key, attr in target.__dict__.items():
if is_method_of(attr, target):
retval = instrument(attr)
if isfunction(retval):
setattr(target, key, retval)
elif isinstance(attr, (classmethod, staticmethod)):
if is_method_of(attr.__func__, target):
retval = instrument(attr.__func__)
if isfunction(retval):
wrapper = attr.__class__(retval)
setattr(target, key, wrapper)
elif isinstance(attr, property):
kwargs: dict[str, Any] = dict(doc=attr.__doc__)
for name in ("fset", "fget", "fdel"):
property_func = kwargs[name] = getattr(attr, name)
if is_method_of(property_func, target):
retval = instrument(property_func)
if isfunction(retval):
kwargs[name] = retval
setattr(target, key, attr.__class__(**kwargs))
return target
# Find either the first Python wrapper or the actual function
wrapper_class: (
type[classmethod[Any, Any, Any]] | type[staticmethod[Any, Any]] | None
) = None
if isinstance(target, (classmethod, staticmethod)):
wrapper_class = target.__class__
target = target.__func__
retval = instrument(target)
if isinstance(retval, str):
warn(
f"{retval} -- not typechecking {function_name(target)}",
InstrumentationWarning,
stacklevel=get_stacklevel(),
)
return target
if wrapper_class is None:
return retval
else:
return wrapper_class(retval)
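For illustration, a minimal usage sketch of the decorator above (not part of the diff; the function name is made up). A mismatched argument makes the injected checks raise TypeCheckError:

from typeguard import TypeCheckError, typechecked

@typechecked
def greet(name: str) -> str:
    return f"Hello, {name}"

greet("world")  # passes the injected argument and return checks
try:
    greet(123)  # wrong argument type
except TypeCheckError as exc:
    print(exc)  # argument "name" (int) is not an instance of str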

View file

@ -0,0 +1,42 @@
from collections import deque
from typing import Deque
class TypeHintWarning(UserWarning):
"""
A warning that is emitted when a type hint in string form could not be resolved to
an actual type.
"""
class TypeCheckWarning(UserWarning):
"""Emitted by typeguard's type checkers when a type mismatch is detected."""
def __init__(self, message: str):
super().__init__(message)
class InstrumentationWarning(UserWarning):
"""Emitted when there's a problem with instrumenting a function for type checks."""
def __init__(self, message: str):
super().__init__(message)
class TypeCheckError(Exception):
"""
Raised by typeguard's type checkers when a type mismatch is detected.
"""
def __init__(self, message: str):
super().__init__(message)
self._path: Deque[str] = deque()
def append_path_element(self, element: str) -> None:
self._path.append(element)
def __str__(self) -> str:
if self._path:
return " of ".join(self._path) + " " + str(self.args[0])
else:
return str(self.args[0])
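A short sketch of how the appended path elements compose into the final message (values are illustrative):

from typeguard._exceptions import TypeCheckError

exc = TypeCheckError("is not an instance of str")
exc.append_path_element("item 0")
exc.append_path_element('argument "names" (list)')
print(exc)  # item 0 of argument "names" (list) is not an instance of str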

lib/typeguard/_functions.py Normal file
View file

@ -0,0 +1,308 @@
from __future__ import annotations
import sys
import warnings
from typing import Any, Callable, NoReturn, TypeVar, Union, overload
from . import _suppression
from ._checkers import BINARY_MAGIC_METHODS, check_type_internal
from ._config import (
CollectionCheckStrategy,
ForwardRefPolicy,
TypeCheckConfiguration,
)
from ._exceptions import TypeCheckError, TypeCheckWarning
from ._memo import TypeCheckMemo
from ._utils import get_stacklevel, qualified_name
if sys.version_info >= (3, 11):
from typing import Literal, Never, TypeAlias
else:
from typing_extensions import Literal, Never, TypeAlias
T = TypeVar("T")
TypeCheckFailCallback: TypeAlias = Callable[[TypeCheckError, TypeCheckMemo], Any]
@overload
def check_type(
value: object,
expected_type: type[T],
*,
forward_ref_policy: ForwardRefPolicy = ...,
typecheck_fail_callback: TypeCheckFailCallback | None = ...,
collection_check_strategy: CollectionCheckStrategy = ...,
) -> T: ...
@overload
def check_type(
value: object,
expected_type: Any,
*,
forward_ref_policy: ForwardRefPolicy = ...,
typecheck_fail_callback: TypeCheckFailCallback | None = ...,
collection_check_strategy: CollectionCheckStrategy = ...,
) -> Any: ...
def check_type(
value: object,
expected_type: Any,
*,
forward_ref_policy: ForwardRefPolicy = TypeCheckConfiguration().forward_ref_policy,
typecheck_fail_callback: TypeCheckFailCallback | None = (
TypeCheckConfiguration().typecheck_fail_callback
),
collection_check_strategy: CollectionCheckStrategy = (
TypeCheckConfiguration().collection_check_strategy
),
) -> Any:
"""
Ensure that ``value`` matches ``expected_type``.
The types from the :mod:`typing` module do not support :func:`isinstance` or
:func:`issubclass`, so a number of type-specific checks are required. This function
knows which checker to call for which type.
This function wraps :func:`~.check_type_internal` in the following ways:
* Respects type checking suppression (:func:`~.suppress_type_checks`)
* Forms a :class:`~.TypeCheckMemo` from the current stack frame
* Calls the configured type check fail callback if the check fails
Note that this function is independent of the globally shared configuration in
:data:`typeguard.config`. This means that usage within libraries is safe from being
affected by configuration changes made by other libraries or by the integrating
application. Instead, configuration options have the same default values as their
corresponding fields in :class:`TypeCheckConfiguration`.
:param value: value to be checked against ``expected_type``
:param expected_type: a class or generic type instance, or a tuple of such things
:param forward_ref_policy: see :attr:`TypeCheckConfiguration.forward_ref_policy`
:param typecheck_fail_callback:
see :attr:`TypeCheckConfiguration.typecheck_fail_callback`
:param collection_check_strategy:
see :attr:`TypeCheckConfiguration.collection_check_strategy`
:return: ``value``, unmodified
:raises TypeCheckError: if there is a type mismatch
"""
if type(expected_type) is tuple:
expected_type = Union[expected_type]
config = TypeCheckConfiguration(
forward_ref_policy=forward_ref_policy,
typecheck_fail_callback=typecheck_fail_callback,
collection_check_strategy=collection_check_strategy,
)
if _suppression.type_checks_suppressed or expected_type is Any:
return value
frame = sys._getframe(1)
memo = TypeCheckMemo(frame.f_globals, frame.f_locals, config=config)
try:
check_type_internal(value, expected_type, memo)
except TypeCheckError as exc:
exc.append_path_element(qualified_name(value, add_class_prefix=True))
if config.typecheck_fail_callback:
config.typecheck_fail_callback(exc, memo)
else:
raise
return value
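A minimal sketch of check_type() in action (assuming Python 3.9+ for the builtin generic syntax):

from typeguard import TypeCheckError, check_type

check_type([1, 2, 3], list[int])  # returns the value unmodified
try:
    check_type("not a number", int)
except TypeCheckError as exc:
    print(exc)  # str is not an instance of int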
def check_argument_types(
func_name: str,
arguments: dict[str, tuple[Any, Any]],
memo: TypeCheckMemo,
) -> Literal[True]:
if _suppression.type_checks_suppressed:
return True
for argname, (value, annotation) in arguments.items():
if annotation is NoReturn or annotation is Never:
exc = TypeCheckError(
f"{func_name}() was declared never to be called but it was"
)
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise exc
try:
check_type_internal(value, annotation, memo)
except TypeCheckError as exc:
qualname = qualified_name(value, add_class_prefix=True)
exc.append_path_element(f'argument "{argname}" ({qualname})')
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise
return True
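This helper is normally invoked by instrumented code rather than called directly; a sketch of the calling convention, which maps each argument name to a (value, annotation) pair ("greet" and "name" are made-up names):

from typeguard._functions import check_argument_types
from typeguard._memo import TypeCheckMemo

memo = TypeCheckMemo(globals(), locals())
check_argument_types("greet", {"name": ("world", str)}, memo)  # returns True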
def check_return_type(
func_name: str,
retval: T,
annotation: Any,
memo: TypeCheckMemo,
) -> T:
if _suppression.type_checks_suppressed:
return retval
if annotation is NoReturn or annotation is Never:
exc = TypeCheckError(f"{func_name}() was declared never to return but it did")
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise exc
try:
check_type_internal(retval, annotation, memo)
except TypeCheckError as exc:
# Allow NotImplemented if this is a binary magic method (__eq__() et al)
if retval is NotImplemented and annotation is bool:
# This does not (and cannot) check whether it's actually a method
func_name = func_name.rsplit(".", 1)[-1]
if func_name in BINARY_MAGIC_METHODS:
return retval
qualname = qualified_name(retval, add_class_prefix=True)
exc.append_path_element(f"the return value ({qualname})")
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise
return retval
def check_send_type(
func_name: str,
sendval: T,
annotation: Any,
memo: TypeCheckMemo,
) -> T:
if _suppression.type_checks_suppressed:
return sendval
if annotation is NoReturn or annotation is Never:
exc = TypeCheckError(
f"{func_name}() was declared never to be sent a value to but it was"
)
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise exc
try:
check_type_internal(sendval, annotation, memo)
except TypeCheckError as exc:
qualname = qualified_name(sendval, add_class_prefix=True)
exc.append_path_element(f"the value sent to generator ({qualname})")
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise
return sendval
def check_yield_type(
func_name: str,
yieldval: T,
annotation: Any,
memo: TypeCheckMemo,
) -> T:
if _suppression.type_checks_suppressed:
return yieldval
if annotation is NoReturn or annotation is Never:
exc = TypeCheckError(f"{func_name}() was declared never to yield but it did")
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise exc
try:
check_type_internal(yieldval, annotation, memo)
except TypeCheckError as exc:
qualname = qualified_name(yieldval, add_class_prefix=True)
exc.append_path_element(f"the yielded value ({qualname})")
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise
return yieldval
def check_variable_assignment(
value: object, varname: str, annotation: Any, memo: TypeCheckMemo
) -> Any:
if _suppression.type_checks_suppressed:
return value
try:
check_type_internal(value, annotation, memo)
except TypeCheckError as exc:
qualname = qualified_name(value, add_class_prefix=True)
exc.append_path_element(f"value assigned to {varname} ({qualname})")
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise
return value
def check_multi_variable_assignment(
value: Any, targets: list[dict[str, Any]], memo: TypeCheckMemo
) -> Any:
if max(len(target) for target in targets) == 1:
iterated_values = [value]
else:
iterated_values = list(value)
if not _suppression.type_checks_suppressed:
for expected_types in targets:
value_index = 0
for ann_index, (varname, expected_type) in enumerate(
expected_types.items()
):
if varname.startswith("*"):
varname = varname[1:]
keys_left = len(expected_types) - 1 - ann_index
next_value_index = len(iterated_values) - keys_left
obj: object = iterated_values[value_index:next_value_index]
value_index = next_value_index
else:
obj = iterated_values[value_index]
value_index += 1
try:
check_type_internal(obj, expected_type, memo)
except TypeCheckError as exc:
qualname = qualified_name(obj, add_class_prefix=True)
exc.append_path_element(f"value assigned to {varname} ({qualname})")
if memo.config.typecheck_fail_callback:
memo.config.typecheck_fail_callback(exc, memo)
else:
raise
return iterated_values[0] if len(iterated_values) == 1 else iterated_values
def warn_on_error(exc: TypeCheckError, memo: TypeCheckMemo) -> None:
"""
Emit a warning on a type mismatch.
This is intended to be used as an error handler in
:attr:`TypeCheckConfiguration.typecheck_fail_callback`.
"""
warnings.warn(TypeCheckWarning(str(exc)), stacklevel=get_stacklevel())
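A sketch of plugging warn_on_error into check_type() so a mismatch warns instead of raising:

from typeguard import check_type, warn_on_error

# emits TypeCheckWarning rather than raising TypeCheckError
check_type(1, str, typecheck_fail_callback=warn_on_error)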

View file

@ -0,0 +1,213 @@
from __future__ import annotations
import ast
import sys
import types
from collections.abc import Callable, Iterable
from importlib.abc import MetaPathFinder
from importlib.machinery import ModuleSpec, SourceFileLoader
from importlib.util import cache_from_source, decode_source
from inspect import isclass
from os import PathLike
from types import CodeType, ModuleType, TracebackType
from typing import Sequence, TypeVar
from unittest.mock import patch
from ._config import global_config
from ._transformer import TypeguardTransformer
if sys.version_info >= (3, 12):
from collections.abc import Buffer
else:
from typing_extensions import Buffer
if sys.version_info >= (3, 11):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
if sys.version_info >= (3, 10):
from importlib.metadata import PackageNotFoundError, version
else:
from importlib_metadata import PackageNotFoundError, version
try:
OPTIMIZATION = "typeguard" + "".join(version("typeguard").split(".")[:3])
except PackageNotFoundError:
OPTIMIZATION = "typeguard"
P = ParamSpec("P")
T = TypeVar("T")
# The name of this function is magical
def _call_with_frames_removed(
f: Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> T:
return f(*args, **kwargs)
def optimized_cache_from_source(path: str, debug_override: bool | None = None) -> str:
return cache_from_source(path, debug_override, optimization=OPTIMIZATION)
class TypeguardLoader(SourceFileLoader):
@staticmethod
def source_to_code(
data: Buffer | str | ast.Module | ast.Expression | ast.Interactive,
path: Buffer | str | PathLike[str] = "<string>",
) -> CodeType:
if isinstance(data, (ast.Module, ast.Expression, ast.Interactive)):
tree = data
else:
if isinstance(data, str):
source = data
else:
source = decode_source(data)
tree = _call_with_frames_removed(
ast.parse,
source,
path,
"exec",
)
tree = TypeguardTransformer().visit(tree)
ast.fix_missing_locations(tree)
if global_config.debug_instrumentation and sys.version_info >= (3, 9):
print(
f"Source code of {path!r} after instrumentation:\n"
"----------------------------------------------",
file=sys.stderr,
)
print(ast.unparse(tree), file=sys.stderr)
print("----------------------------------------------", file=sys.stderr)
return _call_with_frames_removed(
compile, tree, path, "exec", 0, dont_inherit=True
)
def exec_module(self, module: ModuleType) -> None:
# Use a custom optimization marker; the import lock should make this monkey
# patch safe
with patch(
"importlib._bootstrap_external.cache_from_source",
optimized_cache_from_source,
):
super().exec_module(module)
class TypeguardFinder(MetaPathFinder):
"""
Wraps another path finder and instruments the module with
:func:`@typechecked <typeguard.typechecked>` if :meth:`should_instrument` returns
``True``.
Should not be used directly, but rather via :func:`~.install_import_hook`.
.. versionadded:: 2.6
"""
def __init__(self, packages: list[str] | None, original_pathfinder: MetaPathFinder):
self.packages = packages
self._original_pathfinder = original_pathfinder
def find_spec(
self,
fullname: str,
path: Sequence[str] | None,
target: types.ModuleType | None = None,
) -> ModuleSpec | None:
if self.should_instrument(fullname):
spec = self._original_pathfinder.find_spec(fullname, path, target)
if spec is not None and isinstance(spec.loader, SourceFileLoader):
spec.loader = TypeguardLoader(spec.loader.name, spec.loader.path)
return spec
return None
def should_instrument(self, module_name: str) -> bool:
"""
Determine whether the module with the given name should be instrumented.
:param module_name: full name of the module that is about to be imported (e.g.
``xyz.abc``)
"""
if self.packages is None:
return True
for package in self.packages:
if module_name == package or module_name.startswith(package + "."):
return True
return False
class ImportHookManager:
"""
A handle that can be used to uninstall the Typeguard import hook.
"""
def __init__(self, hook: MetaPathFinder):
self.hook = hook
def __enter__(self) -> None:
pass
def __exit__(
self,
exc_type: type[BaseException],
exc_val: BaseException,
exc_tb: TracebackType,
) -> None:
self.uninstall()
def uninstall(self) -> None:
"""Uninstall the import hook."""
try:
sys.meta_path.remove(self.hook)
except ValueError:
pass # already removed
def install_import_hook(
packages: Iterable[str] | None = None,
*,
cls: type[TypeguardFinder] = TypeguardFinder,
) -> ImportHookManager:
"""
Install an import hook that instruments functions for automatic type checking.
This only affects modules loaded **after** this hook has been installed.
:param packages: an iterable of package names to instrument, or ``None`` to
instrument all packages
:param cls: a custom meta path finder class
:return: a context manager that uninstalls the hook on exit (or when you call
``.uninstall()``)
.. versionadded:: 2.6
"""
if packages is None:
target_packages: list[str] | None = None
elif isinstance(packages, str):
target_packages = [packages]
else:
target_packages = list(packages)
for finder in sys.meta_path:
if (
isclass(finder)
and finder.__name__ == "PathFinder"
and hasattr(finder, "find_spec")
):
break
else:
raise RuntimeError("Cannot find a PathFinder in sys.meta_path")
hook = cls(target_packages, finder)
sys.meta_path.insert(0, hook)
return ImportHookManager(hook)
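A usage sketch; "mypkg" is a hypothetical package name:

from typeguard import install_import_hook

with install_import_hook(packages=["mypkg"]):
    import mypkg  # hypothetical; instrumented on import while the hook is installed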

lib/typeguard/_memo.py Normal file
View file

@ -0,0 +1,48 @@
from __future__ import annotations
from typing import Any
from typeguard._config import TypeCheckConfiguration, global_config
class TypeCheckMemo:
"""
Contains information necessary for type checkers to do their work.
.. attribute:: globals
:type: dict[str, Any]
Dictionary of global variables to use for resolving forward references.
.. attribute:: locals
:type: dict[str, Any]
Dictionary of local variables to use for resolving forward references.
.. attribute:: self_type
:type: type | None
When running type checks within an instance method or class method, this is the
class object that the first argument (usually named ``self`` or ``cls``) refers
to.
.. attribute:: config
:type: TypeCheckConfiguration
Contains the configuration for a particular set of type checking operations.
"""
__slots__ = "globals", "locals", "self_type", "config"
def __init__(
self,
globals: dict[str, Any],
locals: dict[str, Any],
*,
self_type: type | None = None,
config: TypeCheckConfiguration = global_config,
):
self.globals = globals
self.locals = locals
self.self_type = self_type
self.config = config
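A sketch of forming a memo from the current frame, mirroring what check_type() does internally:

import sys

from typeguard._memo import TypeCheckMemo

frame = sys._getframe(0)
memo = TypeCheckMemo(frame.f_globals, frame.f_locals)
print(memo.config.collection_check_strategy)  # the global default strategy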

View file

@ -0,0 +1,126 @@
from __future__ import annotations
import sys
import warnings
from typing import Any, Literal
from pytest import Config, Parser
from typeguard._config import CollectionCheckStrategy, ForwardRefPolicy, global_config
from typeguard._exceptions import InstrumentationWarning
from typeguard._importhook import install_import_hook
from typeguard._utils import qualified_name, resolve_reference
def pytest_addoption(parser: Parser) -> None:
def add_ini_option(
opt_type: (
Literal["string", "paths", "pathlist", "args", "linelist", "bool"] | None
)
) -> None:
parser.addini(
group.options[-1].names()[0][2:],
group.options[-1].attrs()["help"],
opt_type,
)
group = parser.getgroup("typeguard")
group.addoption(
"--typeguard-packages",
action="store",
help="comma separated name list of packages and modules to instrument for "
"type checking, or :all: to instrument all modules loaded after typeguard",
)
add_ini_option("linelist")
group.addoption(
"--typeguard-debug-instrumentation",
action="store_true",
help="print all instrumented code to stderr",
)
add_ini_option("bool")
group.addoption(
"--typeguard-typecheck-fail-callback",
action="store",
help=(
"a module:varname (e.g. typeguard:warn_on_error) reference to a function "
"that is called (with the exception, and memo object as arguments) to "
"handle a TypeCheckError"
),
)
add_ini_option("string")
group.addoption(
"--typeguard-forward-ref-policy",
action="store",
choices=list(ForwardRefPolicy.__members__),
help=(
"determines how to deal with unresolveable forward references in type "
"annotations"
),
)
add_ini_option("string")
group.addoption(
"--typeguard-collection-check-strategy",
action="store",
choices=list(CollectionCheckStrategy.__members__),
help="determines how thoroughly to check collections (list, dict, etc)",
)
add_ini_option("string")
def pytest_configure(config: Config) -> None:
def getoption(name: str) -> Any:
return config.getoption(name.replace("-", "_")) or config.getini(name)
packages: list[str] | None = []
if packages_option := config.getoption("typeguard_packages"):
packages = [pkg.strip() for pkg in packages_option.split(",")]
elif packages_ini := config.getini("typeguard-packages"):
packages = packages_ini
if packages:
if packages == [":all:"]:
packages = None
else:
already_imported_packages = sorted(
package for package in packages if package in sys.modules
)
if already_imported_packages:
warnings.warn(
f"typeguard cannot check these packages because they are already "
f"imported: {', '.join(already_imported_packages)}",
InstrumentationWarning,
stacklevel=1,
)
install_import_hook(packages=packages)
debug_option = getoption("typeguard-debug-instrumentation")
if debug_option:
global_config.debug_instrumentation = True
fail_callback_option = getoption("typeguard-typecheck-fail-callback")
if fail_callback_option:
callback = resolve_reference(fail_callback_option)
if not callable(callback):
raise TypeError(
f"{fail_callback_option} ({qualified_name(callback.__class__)}) is not "
f"a callable"
)
global_config.typecheck_fail_callback = callback
forward_ref_policy_option = getoption("typeguard-forward-ref-policy")
if forward_ref_policy_option:
forward_ref_policy = ForwardRefPolicy.__members__[forward_ref_policy_option]
global_config.forward_ref_policy = forward_ref_policy
collection_check_strategy_option = getoption("typeguard-collection-check-strategy")
if collection_check_strategy_option:
collection_check_strategy = CollectionCheckStrategy.__members__[
collection_check_strategy_option
]
global_config.collection_check_strategy = collection_check_strategy

View file

@ -0,0 +1,86 @@
from __future__ import annotations
import sys
from collections.abc import Callable, Generator
from contextlib import contextmanager
from functools import update_wrapper
from threading import Lock
from typing import ContextManager, TypeVar, overload
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
P = ParamSpec("P")
T = TypeVar("T")
type_checks_suppressed = 0
type_checks_suppress_lock = Lock()
@overload
def suppress_type_checks(func: Callable[P, T]) -> Callable[P, T]: ...
@overload
def suppress_type_checks() -> ContextManager[None]: ...
def suppress_type_checks(
func: Callable[P, T] | None = None
) -> Callable[P, T] | ContextManager[None]:
"""
Temporarily suppress all type checking.
This function has two operating modes, based on how it's used:
#. as a context manager (``with suppress_type_checks(): ...``)
#. as a decorator (``@suppress_type_checks``)
When used as a context manager, :func:`check_type` and any automatically
instrumented functions skip the actual type checking. These context managers can be
nested.
When used as a decorator, all type checking is suppressed while the function is
running.
Type checking will resume once no more context managers are active and no decorated
functions are running.
Both operating modes are thread-safe.
"""
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
global type_checks_suppressed
with type_checks_suppress_lock:
type_checks_suppressed += 1
assert func is not None
try:
return func(*args, **kwargs)
finally:
with type_checks_suppress_lock:
type_checks_suppressed -= 1
def cm() -> Generator[None, None, None]:
global type_checks_suppressed
with type_checks_suppress_lock:
type_checks_suppressed += 1
try:
yield
finally:
with type_checks_suppress_lock:
type_checks_suppressed -= 1
if func is None:
# Context manager mode
return contextmanager(cm)()
else:
# Decorator mode
update_wrapper(wrapper, func)
return wrapper
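A sketch of both operating modes:

from typeguard import check_type, suppress_type_checks

with suppress_type_checks():
    check_type("not an int", int)  # no error while suppression is active

@suppress_type_checks
def unchecked() -> None:
    check_type("still not an int", int)  # also skipped

unchecked()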

File diff suppressed because it is too large

View file

@ -0,0 +1,55 @@
"""
Transforms lazily evaluated PEP 604 unions into typing.Union subscripts, for compatibility with
Python versions older than 3.10.
"""
from __future__ import annotations
from ast import (
BinOp,
BitOr,
Index,
Load,
Name,
NodeTransformer,
Subscript,
fix_missing_locations,
parse,
)
from ast import Tuple as ASTTuple
from types import CodeType
from typing import Any, Dict, FrozenSet, List, Set, Tuple, Union
type_substitutions = {
"dict": Dict,
"list": List,
"tuple": Tuple,
"set": Set,
"frozenset": FrozenSet,
"Union": Union,
}
class UnionTransformer(NodeTransformer):
def __init__(self, union_name: Name | None = None):
self.union_name = union_name or Name(id="Union", ctx=Load())
def visit_BinOp(self, node: BinOp) -> Any:
self.generic_visit(node)
if isinstance(node.op, BitOr):
return Subscript(
value=self.union_name,
slice=Index(
ASTTuple(elts=[node.left, node.right], ctx=Load()), ctx=Load()
),
ctx=Load(),
)
return node
def compile_type_hint(hint: str) -> CodeType:
parsed = parse(hint, "<string>", "eval")
UnionTransformer().visit(parsed)
fix_missing_locations(parsed)
return compile(parsed, "<string>", "eval", flags=0)
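A sketch of the transformation on an older interpreter; the compiled code references the name "Union", which must be supplied (or already present) in the evaluation globals:

from typing import Union

from typeguard._union_transformer import compile_type_hint

code = compile_type_hint("int | str")
print(eval(code, {"Union": Union}))  # typing.Union[int, str]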

lib/typeguard/_utils.py Normal file
View file

@ -0,0 +1,163 @@
from __future__ import annotations
import inspect
import sys
from importlib import import_module
from inspect import currentframe
from types import CodeType, FrameType, FunctionType
from typing import TYPE_CHECKING, Any, Callable, ForwardRef, Union, cast, final
from weakref import WeakValueDictionary
if TYPE_CHECKING:
from ._memo import TypeCheckMemo
if sys.version_info >= (3, 10):
from typing import get_args, get_origin
def evaluate_forwardref(forwardref: ForwardRef, memo: TypeCheckMemo) -> Any:
return forwardref._evaluate(memo.globals, memo.locals, frozenset())
else:
from typing_extensions import get_args, get_origin
evaluate_extra_args: tuple[frozenset[Any], ...] = (
(frozenset(),) if sys.version_info >= (3, 9) else ()
)
def evaluate_forwardref(forwardref: ForwardRef, memo: TypeCheckMemo) -> Any:
from ._union_transformer import compile_type_hint, type_substitutions
if not forwardref.__forward_evaluated__:
forwardref.__forward_code__ = compile_type_hint(forwardref.__forward_arg__)
try:
return forwardref._evaluate(memo.globals, memo.locals, *evaluate_extra_args)
except NameError:
if sys.version_info < (3, 10):
# Try again, with the type substitutions (list -> List etc.) in place
new_globals = memo.globals.copy()
new_globals.setdefault("Union", Union)
if sys.version_info < (3, 9):
new_globals.update(type_substitutions)
return forwardref._evaluate(
new_globals, memo.locals or new_globals, *evaluate_extra_args
)
raise
_functions_map: WeakValueDictionary[CodeType, FunctionType] = WeakValueDictionary()
def get_type_name(type_: Any) -> str:
name: str
for attrname in "__name__", "_name", "__forward_arg__":
candidate = getattr(type_, attrname, None)
if isinstance(candidate, str):
name = candidate
break
else:
origin = get_origin(type_)
candidate = getattr(origin, "_name", None)
if candidate is None:
candidate = type_.__class__.__name__.strip("_")
if isinstance(candidate, str):
name = candidate
else:
return "(unknown)"
args = get_args(type_)
if args:
if name == "Literal":
formatted_args = ", ".join(repr(arg) for arg in args)
else:
formatted_args = ", ".join(get_type_name(arg) for arg in args)
name += f"[{formatted_args}]"
module = getattr(type_, "__module__", None)
if module and module not in (None, "typing", "typing_extensions", "builtins"):
name = module + "." + name
return name
def qualified_name(obj: Any, *, add_class_prefix: bool = False) -> str:
"""
Return the qualified name (e.g. package.module.Type) for the given object.
Builtins and types from the :mod:`typing` package get special treatment by having
the module name stripped from the generated name.
"""
if obj is None:
return "None"
elif inspect.isclass(obj):
prefix = "class " if add_class_prefix else ""
type_ = obj
else:
prefix = ""
type_ = type(obj)
module = type_.__module__
qualname = type_.__qualname__
name = qualname if module in ("typing", "builtins") else f"{module}.{qualname}"
return prefix + name
def function_name(func: Callable[..., Any]) -> str:
"""
Return the qualified name of the given function.
Builtins and types from the :mod:`typing` package get special treatment by having
the module name stripped from the generated name.
"""
# For partial functions and objects with __call__ defined, __qualname__ does not
# exist
module = getattr(func, "__module__", "")
qualname = (module + ".") if module not in ("builtins", "") else ""
return qualname + getattr(func, "__qualname__", repr(func))
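A few illustrative calls:

from typeguard._utils import function_name, qualified_name

print(qualified_name([1, 2]))                       # list
print(qualified_name(list, add_class_prefix=True))  # class list
print(function_name(len))                           # len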
def resolve_reference(reference: str) -> Any:
modulename, varname = reference.partition(":")[::2]
if not modulename or not varname:
raise ValueError(f"{reference!r} is not a module:varname reference")
obj = import_module(modulename)
for attr in varname.split("."):
obj = getattr(obj, attr)
return obj
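A sketch of the module:varname syntax:

from typeguard._utils import resolve_reference

join = resolve_reference("os.path:join")
print(join("a", "b"))  # a/b (a\b on Windows)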
def is_method_of(obj: object, cls: type) -> bool:
return (
inspect.isfunction(obj)
and obj.__module__ == cls.__module__
and obj.__qualname__.startswith(cls.__qualname__ + ".")
)
def get_stacklevel() -> int:
level = 1
frame = cast(FrameType, currentframe()).f_back
while frame and frame.f_globals.get("__name__", "").startswith("typeguard."):
level += 1
frame = frame.f_back
return level
@final
class Unset:
__slots__ = ()
def __repr__(self) -> str:
return "<unset>"
unset = Unset()

lib/typeguard/py.typed Normal file
View file

View file

@ -147,27 +147,6 @@ class _Sentinel:
_marker = _Sentinel()
def _check_generic(cls, parameters, elen=_marker):
"""Check correct count for parameters of a generic cls (internal helper).
This gives a nice error message in case of count mismatch.
"""
if not elen:
raise TypeError(f"{cls} is not a generic class")
if elen is _marker:
if not hasattr(cls, "__parameters__") or not cls.__parameters__:
raise TypeError(f"{cls} is not a generic class")
elen = len(cls.__parameters__)
alen = len(parameters)
if alen != elen:
if hasattr(cls, "__parameters__"):
parameters = [p for p in cls.__parameters__ if not _is_unpack(p)]
num_tv_tuples = sum(isinstance(p, TypeVarTuple) for p in parameters)
if (num_tv_tuples > 0) and (alen >= elen - num_tv_tuples):
return
raise TypeError(f"Too {'many' if alen > elen else 'few'} parameters for {cls};"
f" actual {alen}, expected {elen}")
if sys.version_info >= (3, 10):
def _should_collect_from_parameters(t):
return isinstance(
@ -181,27 +160,6 @@ else:
return isinstance(t, typing._GenericAlias) and not t._special
def _collect_type_vars(types, typevar_types=None):
"""Collect all type variable contained in types in order of
first appearance (lexicographic order). For example::
_collect_type_vars((T, List[S, T])) == (T, S)
"""
if typevar_types is None:
typevar_types = typing.TypeVar
tvars = []
for t in types:
if (
isinstance(t, typevar_types) and
t not in tvars and
not _is_unpack(t)
):
tvars.append(t)
if _should_collect_from_parameters(t):
tvars.extend([t for t in t.__parameters__ if t not in tvars])
return tuple(tvars)
NoReturn = typing.NoReturn
# Some unconstrained type variables. These are used by the container types.
@ -834,7 +792,11 @@ def _ensure_subclassable(mro_entries):
return inner
if hasattr(typing, "ReadOnly"):
# Update this to something like >=3.13.0b1 if and when
# PEP 728 is implemented in CPython
_PEP_728_IMPLEMENTED = False
if _PEP_728_IMPLEMENTED:
# The standard library TypedDict in Python 3.8 does not store runtime information
# about which (if any) keys are optional. See https://bugs.python.org/issue38834
# The standard library TypedDict in Python 3.9.0/1 does not honour the "total"
@ -845,7 +807,8 @@ if hasattr(typing, "ReadOnly"):
# Aaaand on 3.12 we add __orig_bases__ to TypedDict
# to enable better runtime introspection.
# On 3.13 we deprecate some odd ways of creating TypedDicts.
# PEP 705 proposes adding the ReadOnly[] qualifier.
# Also on 3.13, PEP 705 adds the ReadOnly[] qualifier.
# PEP 728 (still pending) makes more changes.
TypedDict = typing.TypedDict
_TypedDictMeta = typing._TypedDictMeta
is_typeddict = typing.is_typeddict
@ -1122,15 +1085,15 @@ else:
return val
if hasattr(typing, "Required"): # 3.11+
if hasattr(typing, "ReadOnly"): # 3.13+
get_type_hints = typing.get_type_hints
else: # <=3.10
else: # <=3.13
# replaces _strip_annotations()
def _strip_extras(t):
"""Strips Annotated, Required and NotRequired from a given type."""
if isinstance(t, _AnnotatedAlias):
return _strip_extras(t.__origin__)
if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired):
if hasattr(t, "__origin__") and t.__origin__ in (Required, NotRequired, ReadOnly):
return _strip_extras(t.__args__[0])
if isinstance(t, typing._GenericAlias):
stripped_args = tuple(_strip_extras(a) for a in t.__args__)
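An aside on the ReadOnly addition above: with a typing_extensions release that supports PEP 705, ReadOnly is stripped from resolved hints just like Required and NotRequired (a sketch):

from typing_extensions import ReadOnly, TypedDict, get_type_hints

class Movie(TypedDict):
    title: ReadOnly[str]

print(get_type_hints(Movie))  # {'title': <class 'str'>}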
@ -2689,9 +2652,151 @@ else:
# counting generic parameters, so that when we subscript a generic,
# the runtime doesn't try to substitute the Unpack with the subscripted type.
if not hasattr(typing, "TypeVarTuple"):
typing._collect_type_vars = _collect_type_vars
typing._check_generic = _check_generic
def _check_generic(cls, parameters, elen=_marker):
"""Check correct count for parameters of a generic cls (internal helper).
This gives a nice error message in case of count mismatch.
"""
if not elen:
raise TypeError(f"{cls} is not a generic class")
if elen is _marker:
if not hasattr(cls, "__parameters__") or not cls.__parameters__:
raise TypeError(f"{cls} is not a generic class")
elen = len(cls.__parameters__)
alen = len(parameters)
if alen != elen:
expect_val = elen
if hasattr(cls, "__parameters__"):
parameters = [p for p in cls.__parameters__ if not _is_unpack(p)]
num_tv_tuples = sum(isinstance(p, TypeVarTuple) for p in parameters)
if (num_tv_tuples > 0) and (alen >= elen - num_tv_tuples):
return
# deal with TypeVarLike defaults
# required TypeVarLikes cannot appear after a defaulted one.
if alen < elen:
# since we validate TypeVarLike default in _collect_type_vars
# or _collect_parameters we can safely check parameters[alen]
if getattr(parameters[alen], '__default__', None) is not None:
return
num_default_tv = sum(getattr(p, '__default__', None)
is not None for p in parameters)
elen -= num_default_tv
expect_val = f"at least {elen}"
things = "arguments" if sys.version_info >= (3, 10) else "parameters"
raise TypeError(f"Too {'many' if alen > elen else 'few'} {things}"
f" for {cls}; actual {alen}, expected {expect_val}")
else:
# Python 3.11+
def _check_generic(cls, parameters, elen):
"""Check correct count for parameters of a generic cls (internal helper).
This gives a nice error message in case of count mismatch.
"""
if not elen:
raise TypeError(f"{cls} is not a generic class")
alen = len(parameters)
if alen != elen:
expect_val = elen
if hasattr(cls, "__parameters__"):
parameters = [p for p in cls.__parameters__ if not _is_unpack(p)]
# deal with TypeVarLike defaults
# required TypeVarLikes cannot appear after a defaulted one.
if alen < elen:
# since we validate TypeVarLike default in _collect_type_vars
# or _collect_parameters we can safely check parameters[alen]
if getattr(parameters[alen], '__default__', None) is not None:
return
num_default_tv = sum(getattr(p, '__default__', None)
is not None for p in parameters)
elen -= num_default_tv
expect_val = f"at least {elen}"
raise TypeError(f"Too {'many' if alen > elen else 'few'} arguments"
f" for {cls}; actual {alen}, expected {expect_val}")
typing._check_generic = _check_generic
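An illustration of the count check this patched helper enforces:

import typing

T = typing.TypeVar("T")
S = typing.TypeVar("S")

class Pair(typing.Generic[T, S]):
    pass

try:
    Pair[int]  # one parameter where two are expected
except TypeError as exc:
    print(exc)  # Too few arguments for ...; actual 1, expected 2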
# Python 3.11+ _collect_type_vars was renamed to _collect_parameters
if hasattr(typing, '_collect_type_vars'):
def _collect_type_vars(types, typevar_types=None):
"""Collect all type variable contained in types in order of
first appearance (lexicographic order). For example::
_collect_type_vars((T, List[S, T])) == (T, S)
"""
if typevar_types is None:
typevar_types = typing.TypeVar
tvars = []
# required TypeVarLike cannot appear after TypeVarLike with default
default_encountered = False
for t in types:
if (
isinstance(t, typevar_types) and
t not in tvars and
not _is_unpack(t)
):
if getattr(t, '__default__', None) is not None:
default_encountered = True
elif default_encountered:
raise TypeError(f'Type parameter {t!r} without a default'
' follows type parameter with a default')
tvars.append(t)
if _should_collect_from_parameters(t):
tvars.extend([t for t in t.__parameters__ if t not in tvars])
return tuple(tvars)
typing._collect_type_vars = _collect_type_vars
else:
def _collect_parameters(args):
"""Collect all type variables and parameter specifications in args
in order of first appearance (lexicographic order).
For example::
assert _collect_parameters((T, Callable[P, T])) == (T, P)
"""
parameters = []
# required TypeVarLike cannot appear after TypeVarLike with default
default_encountered = False
for t in args:
if isinstance(t, type):
# We don't want __parameters__ descriptor of a bare Python class.
pass
elif isinstance(t, tuple):
# `t` might be a tuple, when `ParamSpec` is substituted with
# `[T, int]`, or `[int, *Ts]`, etc.
for x in t:
for collected in _collect_parameters([x]):
if collected not in parameters:
parameters.append(collected)
elif hasattr(t, '__typing_subst__'):
if t not in parameters:
if getattr(t, '__default__', None) is not None:
default_encountered = True
elif default_encountered:
raise TypeError(f'Type parameter {t!r} without a default'
' follows type parameter with a default')
parameters.append(t)
else:
for x in getattr(t, '__parameters__', ()):
if x not in parameters:
parameters.append(x)
return tuple(parameters)
typing._collect_parameters = _collect_parameters
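An illustration of the ordering rule enforced above, using typing_extensions' TypeVar with a PEP 696 default:

from typing import Generic

from typing_extensions import TypeVar

T = TypeVar("T", default=int)
S = TypeVar("S")

try:
    class Bad(Generic[T, S]):  # required S follows defaulted T
        pass
except TypeError as exc:
    print(exc)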
# Backport typing.NamedTuple as it exists in Python 3.13.
# In 3.11, the ability to define generic `NamedTuple`s was supported.

View file

@ -2,7 +2,7 @@
__init__.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -17,10 +17,10 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from ._abnf import *
from ._app import WebSocketApp, setReconnect
from ._app import WebSocketApp as WebSocketApp, setReconnect as setReconnect
from ._core import *
from ._exceptions import *
from ._logging import *
from ._socket import *
__version__ = "1.7.0"
__version__ = "1.8.0"

View file

@ -5,14 +5,14 @@ import sys
from threading import Lock
from typing import Callable, Optional, Union
from ._exceptions import *
from ._exceptions import WebSocketPayloadException, WebSocketProtocolException
from ._utils import validate_utf8
"""
_abnf.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View file

@ -13,13 +13,14 @@ from ._exceptions import (
WebSocketException,
WebSocketTimeoutException,
)
from ._ssl_compat import SSLEOFError
from ._url import parse_url
"""
_app.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -165,6 +166,7 @@ class WebSocketApp:
url: str,
header: Union[list, dict, Callable, None] = None,
on_open: Optional[Callable[[WebSocket], None]] = None,
on_reconnect: Optional[Callable[[WebSocket], None]] = None,
on_message: Optional[Callable[[WebSocket, Any], None]] = None,
on_error: Optional[Callable[[WebSocket, Any], None]] = None,
on_close: Optional[Callable[[WebSocket, Any, Any], None]] = None,
@ -194,6 +196,10 @@ class WebSocketApp:
Callback object which is called at opening websocket.
on_open has one argument.
The 1st argument is this class object.
on_reconnect: function
Callback object which is called when the websocket reconnects.
on_reconnect has one argument.
The 1st argument is this class object.
on_message: function
Callback object which is called when received data.
on_message has 2 arguments.
@ -244,6 +250,7 @@ class WebSocketApp:
self.cookie = cookie
self.on_open = on_open
self.on_reconnect = on_reconnect
self.on_message = on_message
self.on_data = on_data
self.on_error = on_error
@ -424,6 +431,7 @@ class WebSocketApp:
self.ping_interval = ping_interval
self.ping_timeout = ping_timeout
self.ping_payload = ping_payload
self.has_done_teardown = False
self.keep_running = True
def teardown(close_frame: ABNF = None):
@ -495,7 +503,10 @@ class WebSocketApp:
if self.ping_interval:
self._start_ping_thread()
self._callback(self.on_open)
if reconnecting and self.on_reconnect:
self._callback(self.on_reconnect)
else:
self._callback(self.on_open)
dispatcher.read(self.sock.sock, read, check)
except (
@ -516,9 +527,10 @@ class WebSocketApp:
except (
WebSocketConnectionClosedException,
KeyboardInterrupt,
SSLEOFError,
) as e:
if custom_dispatcher:
return handleDisconnect(e)
return handleDisconnect(e, bool(reconnect))
else:
raise e
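A sketch of the new on_reconnect callback; the endpoint URL is a made-up example, and run_forever(reconnect=...) enables automatic reconnection:

import websocket

def on_reconnect(wsapp):
    print("reconnected")

app = websocket.WebSocketApp(
    "wss://echo.example.com",  # hypothetical endpoint
    on_reconnect=on_reconnect,
)
app.run_forever(reconnect=5)  # retry every 5 seconds after a dropped connection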

View file

@ -5,7 +5,7 @@ from typing import Optional
_cookiejar.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -23,13 +23,13 @@ limitations under the License.
class SimpleCookieJar:
def __init__(self) -> None:
self.jar: dict = dict()
self.jar: dict = {}
def add(self, set_cookie: Optional[str]) -> None:
if set_cookie:
simpleCookie = http.cookies.SimpleCookie(set_cookie)
simple_cookie = http.cookies.SimpleCookie(set_cookie)
for k, v in simpleCookie.items():
for v in simple_cookie.values():
if domain := v.get("domain"):
if not domain.startswith("."):
domain = f".{domain}"
@ -38,25 +38,25 @@ class SimpleCookieJar:
if self.jar.get(domain)
else http.cookies.SimpleCookie()
)
cookie.update(simpleCookie)
cookie.update(simple_cookie)
self.jar[domain.lower()] = cookie
def set(self, set_cookie: str) -> None:
if set_cookie:
simpleCookie = http.cookies.SimpleCookie(set_cookie)
simple_cookie = http.cookies.SimpleCookie(set_cookie)
for k, v in simpleCookie.items():
for v in simple_cookie.values():
if domain := v.get("domain"):
if not domain.startswith("."):
domain = f".{domain}"
self.jar[domain.lower()] = simpleCookie
self.jar[domain.lower()] = simple_cookie
def get(self, host: str) -> str:
if not host:
return ""
cookies = []
for domain, simpleCookie in self.jar.items():
for domain, _ in self.jar.items():
host = host.lower()
if host.endswith(domain) or host == domain[1:]:
cookies.append(self.jar.get(domain))
@ -66,7 +66,7 @@ class SimpleCookieJar:
None,
sorted(
[
"%s=%s" % (k, v.value)
f"{k}={v.value}"
for cookie in filter(None, cookies)
for k, v in cookie.items()
]
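A quick sketch of the jar's domain matching:

from websocket._cookiejar import SimpleCookieJar

jar = SimpleCookieJar()
jar.add("session=abc; domain=example.com")
print(jar.get("sub.example.com"))  # session=abc (subdomains match too)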

View file

@ -5,20 +5,20 @@ import time
from typing import Optional, Union
# websocket modules
from ._abnf import *
from ._exceptions import *
from ._handshake import *
from ._http import *
from ._logging import *
from ._socket import *
from ._ssl_compat import *
from ._utils import *
from ._abnf import ABNF, STATUS_NORMAL, continuous_frame, frame_buffer
from ._exceptions import WebSocketProtocolException, WebSocketConnectionClosedException
from ._handshake import SUPPORTED_REDIRECT_STATUSES, handshake
from ._http import connect, proxy_info
from ._logging import debug, error, trace, isEnabledForError, isEnabledForTrace
from ._socket import getdefaulttimeout, recv, send, sock_opt
from ._ssl_compat import ssl
from ._utils import NoLock
"""
_core.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -259,7 +259,7 @@ class WebSocket:
try:
self.handshake_response = handshake(self.sock, url, *addrs, **options)
for attempt in range(options.pop("redirect_limit", 3)):
for _ in range(options.pop("redirect_limit", 3)):
if self.handshake_response.status in SUPPORTED_REDIRECT_STATUSES:
url = self.handshake_response.headers["location"]
self.sock.close()

View file

@ -2,7 +2,7 @@
_exceptions.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View file

@ -2,7 +2,7 @@
_handshake.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -23,10 +23,10 @@ from base64 import encodebytes as base64encode
from http import HTTPStatus
from ._cookiejar import SimpleCookieJar
from ._exceptions import *
from ._http import *
from ._logging import *
from ._socket import *
from ._exceptions import WebSocketException, WebSocketBadStatusException
from ._http import read_headers
from ._logging import dump, error
from ._socket import send
__all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]

View file

@ -2,7 +2,7 @@
_http.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -21,11 +21,15 @@ import os
import socket
from base64 import encodebytes as base64encode
from ._exceptions import *
from ._logging import *
from ._socket import *
from ._ssl_compat import *
from ._url import *
from ._exceptions import (
WebSocketAddressException,
WebSocketException,
WebSocketProxyException,
)
from ._logging import debug, dump, trace
from ._socket import DEFAULT_SOCKET_OPTION, recv_line, send
from ._ssl_compat import HAVE_SSL, ssl
from ._url import get_proxy_info, parse_url
__all__ = ["proxy_info", "connect", "read_headers"]
@ -283,22 +287,22 @@ def _wrap_sni_socket(sock: socket.socket, sslopt: dict, hostname, check_hostname
def _ssl_socket(sock: socket.socket, user_sslopt: dict, hostname):
sslopt: dict = dict(cert_reqs=ssl.CERT_REQUIRED)
sslopt: dict = {"cert_reqs": ssl.CERT_REQUIRED}
sslopt.update(user_sslopt)
certPath = os.environ.get("WEBSOCKET_CLIENT_CA_BUNDLE")
cert_path = os.environ.get("WEBSOCKET_CLIENT_CA_BUNDLE")
if (
certPath
and os.path.isfile(certPath)
cert_path
and os.path.isfile(cert_path)
and user_sslopt.get("ca_certs", None) is None
):
sslopt["ca_certs"] = certPath
sslopt["ca_certs"] = cert_path
elif (
certPath
and os.path.isdir(certPath)
cert_path
and os.path.isdir(cert_path)
and user_sslopt.get("ca_cert_path", None) is None
):
sslopt["ca_cert_path"] = certPath
sslopt["ca_cert_path"] = cert_path
if sslopt.get("server_hostname", None):
hostname = sslopt["server_hostname"]
@ -327,7 +331,7 @@ def _tunnel(sock: socket.socket, host, port: int, auth) -> socket.socket:
send(sock, connect_header)
try:
status, resp_headers, status_message = read_headers(sock)
status, _, _ = read_headers(sock)
except Exception as e:
raise WebSocketProxyException(str(e))

View file

@ -4,7 +4,7 @@ import logging
_logging.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View file

@ -3,15 +3,18 @@ import selectors
import socket
from typing import Union
from ._exceptions import *
from ._ssl_compat import *
from ._utils import *
from ._exceptions import (
WebSocketConnectionClosedException,
WebSocketTimeoutException,
)
from ._ssl_compat import SSLError, SSLWantReadError, SSLWantWriteError
from ._utils import extract_error_code, extract_err_message
"""
_socket.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View file

@ -2,7 +2,7 @@
_ssl_compat.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -16,11 +16,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
__all__ = ["HAVE_SSL", "ssl", "SSLError", "SSLWantReadError", "SSLWantWriteError"]
__all__ = [
"HAVE_SSL",
"ssl",
"SSLError",
"SSLEOFError",
"SSLWantReadError",
"SSLWantWriteError",
]
try:
import ssl
from ssl import SSLError, SSLWantReadError, SSLWantWriteError
from ssl import SSLError, SSLEOFError, SSLWantReadError, SSLWantWriteError
HAVE_SSL = True
except ImportError:
@ -28,6 +35,9 @@ except ImportError:
class SSLError(Exception):
pass
class SSLEOFError(Exception):
pass
class SSLWantReadError(Exception):
pass

View file

@ -3,12 +3,13 @@ import socket
import struct
from typing import Optional
from urllib.parse import unquote, urlparse
from ._exceptions import WebSocketProxyException
"""
_url.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -167,6 +168,8 @@ def get_proxy_info(
return None, 0, None
if proxy_host:
if not proxy_port:
raise WebSocketProxyException("Cannot use port 0 when proxy_host specified")
port = proxy_port
auth = proxy_auth
return proxy_host, port, auth
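A sketch of the stricter behavior: explicit proxy settings round-trip, while proxy_port=0 now raises WebSocketProxyException ("proxy.local" is a made-up host):

from websocket._url import get_proxy_info

print(get_proxy_info("example.com", False, proxy_host="proxy.local", proxy_port=3128))
# ('proxy.local', 3128, None)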

View file

@ -4,7 +4,7 @@ from typing import Union
_url.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View file

@ -4,7 +4,7 @@
wsdump.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View file

@ -10,7 +10,7 @@ import websockets
LOCAL_WS_SERVER_PORT = int(os.environ.get("LOCAL_WS_SERVER_PORT", "8765"))
async def echo(websocket, path):
async def echo(websocket):
async for message in websocket:
await websocket.send(message)

View file

@ -2,14 +2,14 @@
#
import unittest
import websocket as ws
from websocket._abnf import *
from websocket._abnf import ABNF, frame_buffer
from websocket._exceptions import WebSocketProtocolException
"""
test_abnf.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -26,7 +26,7 @@ limitations under the License.
class ABNFTest(unittest.TestCase):
def testInit(self):
def test_init(self):
a = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_PING)
self.assertEqual(a.fin, 0)
self.assertEqual(a.rsv1, 0)
@ -38,28 +38,28 @@ class ABNFTest(unittest.TestCase):
self.assertEqual(a_bad.rsv1, 1)
self.assertEqual(a_bad.opcode, 77)
def testValidate(self):
def test_validate(self):
a_invalid_ping = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_PING)
self.assertRaises(
ws._exceptions.WebSocketProtocolException,
WebSocketProtocolException,
a_invalid_ping.validate,
skip_utf8_validation=False,
)
a_bad_rsv_value = ABNF(0, 1, 0, 0, opcode=ABNF.OPCODE_TEXT)
self.assertRaises(
ws._exceptions.WebSocketProtocolException,
WebSocketProtocolException,
a_bad_rsv_value.validate,
skip_utf8_validation=False,
)
a_bad_opcode = ABNF(0, 0, 0, 0, opcode=77)
self.assertRaises(
ws._exceptions.WebSocketProtocolException,
WebSocketProtocolException,
a_bad_opcode.validate,
skip_utf8_validation=False,
)
a_bad_close_frame = ABNF(0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x01")
self.assertRaises(
ws._exceptions.WebSocketProtocolException,
WebSocketProtocolException,
a_bad_close_frame.validate,
skip_utf8_validation=False,
)
@ -67,7 +67,7 @@ class ABNFTest(unittest.TestCase):
0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x01\x8a\xaa\xff\xdd"
)
self.assertRaises(
ws._exceptions.WebSocketProtocolException,
WebSocketProtocolException,
a_bad_close_frame_2.validate,
skip_utf8_validation=False,
)
@ -75,12 +75,12 @@ class ABNFTest(unittest.TestCase):
0, 0, 0, 0, opcode=ABNF.OPCODE_CLOSE, data=b"\x03\xe7"
)
self.assertRaises(
ws._exceptions.WebSocketProtocolException,
WebSocketProtocolException,
a_bad_close_frame_3.validate,
skip_utf8_validation=True,
)
def testMask(self):
def test_mask(self):
abnf_none_data = ABNF(
0, 0, 0, 0, opcode=ABNF.OPCODE_PING, mask_value=1, data=None
)
@ -91,7 +91,7 @@ class ABNFTest(unittest.TestCase):
)
self.assertEqual(abnf_str_data._get_masked(bytes_val), b"aaaa\x00")
def testFormat(self):
def test_format(self):
abnf_bad_rsv_bits = ABNF(2, 0, 0, 0, opcode=ABNF.OPCODE_TEXT)
self.assertRaises(ValueError, abnf_bad_rsv_bits.format)
abnf_bad_opcode = ABNF(0, 0, 0, 0, opcode=5)
@ -110,7 +110,7 @@ class ABNFTest(unittest.TestCase):
)
self.assertEqual(b"\x01\x03\x01\x8a\xcc", abnf_no_mask.format())
def testFrameBuffer(self):
def test_frame_buffer(self):
fb = frame_buffer(0, True)
self.assertEqual(fb.recv, 0)
self.assertEqual(fb.skip_utf8_validation, True)

View file

@ -12,7 +12,7 @@ import websocket as ws
test_app.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -53,10 +53,13 @@ class WebSocketAppTest(unittest.TestCase):
WebSocketAppTest.get_mask_key_id = WebSocketAppTest.NotSetYet()
WebSocketAppTest.on_error_data = WebSocketAppTest.NotSetYet()
def close(self):
pass
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def testKeepRunning(self):
def test_keep_running(self):
"""A WebSocketApp should keep running as long as its self.keep_running
is not False (in the boolean context).
"""
@ -69,7 +72,7 @@ class WebSocketAppTest(unittest.TestCase):
WebSocketAppTest.keep_running_open = self.keep_running
self.keep_running = False
def on_message(wsapp, message):
def on_message(_, message):
print(message)
self.close()
@ -87,7 +90,7 @@ class WebSocketAppTest(unittest.TestCase):
# @unittest.skipUnless(TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled")
@unittest.skipUnless(False, "Test disabled for now (requires rel)")
def testRunForeverDispatcher(self):
def test_run_forever_dispatcher(self):
"""A WebSocketApp should keep running as long as its self.keep_running
is not False (in the boolean context).
"""
@ -98,7 +101,7 @@ class WebSocketAppTest(unittest.TestCase):
self.recv()
self.send("goodbye!")
def on_message(wsapp, message):
def on_message(_, message):
print(message)
self.close()
@ -115,7 +118,7 @@ class WebSocketAppTest(unittest.TestCase):
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def testRunForeverTeardownCleanExit(self):
def test_run_forever_teardown_clean_exit(self):
"""The WebSocketApp.run_forever() method should return `False` when the application ends gracefully."""
app = ws.WebSocketApp(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
threading.Timer(interval=0.2, function=app.close).start()
@ -123,7 +126,7 @@ class WebSocketAppTest(unittest.TestCase):
self.assertEqual(teardown, False)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testSockMaskKey(self):
def test_sock_mask_key(self):
"""A WebSocketApp should forward the received mask_key function down
to the actual socket.
"""
@ -140,14 +143,14 @@ class WebSocketAppTest(unittest.TestCase):
self.assertEqual(id(app.get_mask_key), id(my_mask_key_func))
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testInvalidPingIntervalPingTimeout(self):
def test_invalid_ping_interval_ping_timeout(self):
"""Test exception handling if ping_interval < ping_timeout"""
def on_ping(app, msg):
def on_ping(app, _):
print("Got a ping!")
app.close()
def on_pong(app, msg):
def on_pong(app, _):
print("Got a pong! No need to respond")
app.close()
@ -163,14 +166,14 @@ class WebSocketAppTest(unittest.TestCase):
)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testPingInterval(self):
def test_ping_interval(self):
"""Test WebSocketApp proper ping functionality"""
def on_ping(app, msg):
def on_ping(app, _):
print("Got a ping!")
app.close()
def on_pong(app, msg):
def on_pong(app, _):
print("Got a pong! No need to respond")
app.close()
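
These ping tests lean on run_forever(): it sends a ping every ping_interval seconds, fails the connection if no pong arrives within ping_timeout, and rejects configurations where ping_timeout >= ping_interval (the case covered by test_invalid_ping_interval_ping_timeout). A sketch, reusing the endpoint from the tests:

# Sketch: automatic ping/pong; ping_timeout must be smaller than ping_interval.
import websocket

app = websocket.WebSocketApp(
    "wss://api-pub.bitfinex.com/ws/1",
    on_pong=lambda wsapp, msg: wsapp.close(),  # stop after the first pong
)
app.run_forever(ping_interval=5, ping_timeout=2)
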
@@ -182,7 +185,7 @@ class WebSocketAppTest(unittest.TestCase):
)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testOpcodeClose(self):
def test_opcode_close(self):
"""Test WebSocketApp close opcode"""
app = ws.WebSocketApp("wss://tsock.us1.twilio.com/v3/wsconnect")
@@ -197,7 +200,7 @@ class WebSocketAppTest(unittest.TestCase):
# app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload")
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testBadPingInterval(self):
def test_bad_ping_interval(self):
"""A WebSocketApp handling of negative ping_interval"""
app = ws.WebSocketApp("wss://api-pub.bitfinex.com/ws/1")
self.assertRaises(
@@ -208,7 +211,7 @@ class WebSocketAppTest(unittest.TestCase):
)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testBadPingTimeout(self):
def test_bad_ping_timeout(self):
"""A WebSocketApp handling of negative ping_timeout"""
app = ws.WebSocketApp("wss://api-pub.bitfinex.com/ws/1")
self.assertRaises(
@@ -219,7 +222,7 @@ class WebSocketAppTest(unittest.TestCase):
)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testCloseStatusCode(self):
def test_close_status_code(self):
"""Test extraction of close frame status code and close reason in WebSocketApp"""
def on_close(wsapp, close_status_code, close_msg):
@@ -249,7 +252,7 @@ class WebSocketAppTest(unittest.TestCase):
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def testCallbackFunctionException(self):
def test_callback_function_exception(self):
"""Test callback function exception handling"""
exc = None
@@ -264,7 +267,7 @@ class WebSocketAppTest(unittest.TestCase):
nonlocal exc
exc = err
def on_pong(app, msg):
def on_pong(app, _):
app.close()
app = ws.WebSocketApp(
@@ -282,7 +285,7 @@ class WebSocketAppTest(unittest.TestCase):
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def testCallbackMethodException(self):
def test_callback_method_exception(self):
"""Test callback method exception handling"""
class Callbacks:
@@ -297,14 +300,14 @@ class WebSocketAppTest(unittest.TestCase):
)
self.app.run_forever(ping_interval=2, ping_timeout=1)
def on_open(self, app):
def on_open(self, _):
raise RuntimeError("Callback failed")
def on_error(self, app, err):
self.passed_app = app
self.exc = err
def on_pong(self, app, msg):
def on_pong(self, app, _):
app.close()
callbacks = Callbacks()
@@ -316,16 +319,16 @@ class WebSocketAppTest(unittest.TestCase):
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def testReconnect(self):
def test_reconnect(self):
"""Test reconnect"""
pong_count = 0
exc = None
def on_error(app, err):
def on_error(_, err):
nonlocal exc
exc = err
def on_pong(app, msg):
def on_pong(app, _):
nonlocal pong_count
pong_count += 1
if pong_count == 1:

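test_reconnect exercises run_forever()'s reconnect argument. Assuming a websocket-client release that supports it, reconnect is the delay in seconds before re-dialing a dropped connection; a sketch with illustrative values:

# Sketch: re-dial 5 seconds after the connection drops (placeholder URL).
import websocket

app = websocket.WebSocketApp("wss://example.invalid/ws")
app.run_forever(ping_interval=10, ping_timeout=5, reconnect=5)
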

@@ -6,7 +6,7 @@ from websocket._cookiejar import SimpleCookieJar
test_cookiejar.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@ limitations under the License.
class CookieJarTest(unittest.TestCase):
def testAdd(self):
def test_add(self):
cookie_jar = SimpleCookieJar()
cookie_jar.add("")
self.assertFalse(
@@ -67,7 +67,7 @@ class CookieJarTest(unittest.TestCase):
self.assertEqual(cookie_jar.get("xyz"), "e=f")
self.assertEqual(cookie_jar.get("something"), "")
def testSet(self):
def test_set(self):
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b")
self.assertFalse(
@@ -104,7 +104,7 @@ class CookieJarTest(unittest.TestCase):
self.assertEqual(cookie_jar.get("xyz"), "e=f")
self.assertEqual(cookie_jar.get("something"), "")
def testGet(self):
def test_get(self):
cookie_jar = SimpleCookieJar()
cookie_jar.set("a=b; c=d; domain=abc.com")
self.assertEqual(cookie_jar.get("abc.com"), "a=b; c=d")


@@ -7,7 +7,7 @@ import ssl
import unittest
import websocket
import websocket as ws
from websocket._exceptions import WebSocketProxyException, WebSocketException
from websocket._http import (
_get_addrinfo_list,
_start_proxied_socket,
@@ -15,13 +15,14 @@ from websocket._http import (
connect,
proxy_info,
read_headers,
HAVE_PYTHON_SOCKS,
)
"""
test_http.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -93,20 +94,18 @@ class OptsList:
class HttpTest(unittest.TestCase):
def testReadHeader(self):
status, header, status_message = read_headers(
HeaderSockMock("data/header01.txt")
)
def test_read_header(self):
status, header, _ = read_headers(HeaderSockMock("data/header01.txt"))
self.assertEqual(status, 101)
self.assertEqual(header["connection"], "Upgrade")
# header02.txt is intentionally malformed
self.assertRaises(
ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt")
WebSocketException, read_headers, HeaderSockMock("data/header02.txt")
)
def testTunnel(self):
def test_tunnel(self):
self.assertRaises(
ws.WebSocketProxyException,
WebSocketProxyException,
_tunnel,
HeaderSockMock("data/header01.txt"),
"example.com",
@@ -114,7 +113,7 @@ class HttpTest(unittest.TestCase):
("username", "password"),
)
self.assertRaises(
ws.WebSocketProxyException,
WebSocketProxyException,
_tunnel,
HeaderSockMock("data/header02.txt"),
"example.com",
@@ -123,9 +122,9 @@ class HttpTest(unittest.TestCase):
)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testConnect(self):
def test_connect(self):
# Not currently testing an actual proxy connection, so just check whether proxy errors are raised. This requires internet for a DNS lookup
if ws._http.HAVE_PYTHON_SOCKS:
if HAVE_PYTHON_SOCKS:
# Need this check, otherwise case where python_socks is not installed triggers
# websocket._exceptions.WebSocketException: Python Socks is needed for SOCKS proxying but is not available
self.assertRaises(
@@ -244,7 +243,7 @@ class HttpTest(unittest.TestCase):
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def testProxyConnect(self):
def test_proxy_connect(self):
ws = websocket.WebSocket()
ws.connect(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
@@ -289,7 +288,7 @@ class HttpTest(unittest.TestCase):
# TODO: Test SOCKS4 and SOCKS5 proxies with unit tests
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testSSLopt(self):
def test_sslopt(self):
ssloptions = {
"check_hostname": False,
"server_hostname": "ServerName",
@@ -315,7 +314,7 @@ class HttpTest(unittest.TestCase):
ws_ssl2.connect("wss://api.bitfinex.com/ws/2")
ws_ssl2.close
def testProxyInfo(self):
def test_proxy_info(self):
self.assertEqual(
proxy_info(
http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http"


@@ -9,12 +9,13 @@ from websocket._url import (
get_proxy_info,
parse_url,
)
from websocket._exceptions import WebSocketProxyException
"""
test_url.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -36,7 +37,7 @@ class UrlTest(unittest.TestCase):
self.assertTrue(_is_address_in_network("127.1.0.1", "127.0.0.0/8"))
self.assertFalse(_is_address_in_network("127.1.0.1", "127.0.0.0/24"))
def testParseUrl(self):
def test_parse_url(self):
p = parse_url("ws://www.example.com/r")
self.assertEqual(p[0], "www.example.com")
self.assertEqual(p[1], 80)
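
The p[0]/p[1] indexing works because parse_url() returns a (host, port, resource, is_secure) tuple; a sketch grounded in the same assertions:

# Sketch: unpacking parse_url(); ws:// defaults to port 80, wss:// to 443.
from websocket._url import parse_url

host, port, resource, is_secure = parse_url("ws://www.example.com/r")
assert (host, port, resource, is_secure) == ("www.example.com", 80, "/r", False)
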
@@ -130,9 +131,13 @@ class IsNoProxyHostTest(unittest.TestCase):
elif "no_proxy" in os.environ:
del os.environ["no_proxy"]
def testMatchAll(self):
def test_match_all(self):
self.assertTrue(_is_no_proxy_host("any.websocket.org", ["*"]))
self.assertTrue(_is_no_proxy_host("192.168.0.1", ["*"]))
self.assertFalse(_is_no_proxy_host("192.168.0.1", ["192.168.1.1"]))
self.assertFalse(
_is_no_proxy_host("any.websocket.org", ["other.websocket.org"])
)
self.assertTrue(
_is_no_proxy_host("any.websocket.org", ["other.websocket.org", "*"])
)
@@ -142,7 +147,7 @@ class IsNoProxyHostTest(unittest.TestCase):
os.environ["no_proxy"] = "other.websocket.org, *"
self.assertTrue(_is_no_proxy_host("any.websocket.org", None))
def testIpAddress(self):
def test_ip_address(self):
self.assertTrue(_is_no_proxy_host("127.0.0.1", ["127.0.0.1"]))
self.assertFalse(_is_no_proxy_host("127.0.0.2", ["127.0.0.1"]))
self.assertTrue(
@@ -158,7 +163,7 @@ class IsNoProxyHostTest(unittest.TestCase):
self.assertTrue(_is_no_proxy_host("127.0.0.1", None))
self.assertFalse(_is_no_proxy_host("127.0.0.2", None))
def testIpAddressInRange(self):
def test_ip_address_in_range(self):
self.assertTrue(_is_no_proxy_host("127.0.0.1", ["127.0.0.0/8"]))
self.assertTrue(_is_no_proxy_host("127.0.0.2", ["127.0.0.0/8"]))
self.assertFalse(_is_no_proxy_host("127.1.0.1", ["127.0.0.0/24"]))
@@ -168,7 +173,7 @@ class IsNoProxyHostTest(unittest.TestCase):
os.environ["no_proxy"] = "127.0.0.0/24"
self.assertFalse(_is_no_proxy_host("127.1.0.1", None))
def testHostnameMatch(self):
def test_hostname_match(self):
self.assertTrue(_is_no_proxy_host("my.websocket.org", ["my.websocket.org"]))
self.assertTrue(
_is_no_proxy_host(
@@ -182,7 +187,7 @@ class IsNoProxyHostTest(unittest.TestCase):
os.environ["no_proxy"] = "other.websocket.org, my.websocket.org"
self.assertTrue(_is_no_proxy_host("my.websocket.org", None))
def testHostnameMatchDomain(self):
def test_hostname_match_domain(self):
self.assertTrue(_is_no_proxy_host("any.websocket.org", [".websocket.org"]))
self.assertTrue(_is_no_proxy_host("my.other.websocket.org", [".websocket.org"]))
self.assertTrue(
@@ -227,10 +232,13 @@ class ProxyInfoTest(unittest.TestCase):
elif "no_proxy" in os.environ:
del os.environ["no_proxy"]
def testProxyFromArgs(self):
self.assertEqual(
get_proxy_info("echo.websocket.events", False, proxy_host="localhost"),
("localhost", 0, None),
def test_proxy_from_args(self):
self.assertRaises(
WebSocketProxyException,
get_proxy_info,
"echo.websocket.events",
False,
proxy_host="localhost",
)
self.assertEqual(
get_proxy_info(
@@ -238,10 +246,6 @@ class ProxyInfoTest(unittest.TestCase):
),
("localhost", 3128, None),
)
self.assertEqual(
get_proxy_info("echo.websocket.events", True, proxy_host="localhost"),
("localhost", 0, None),
)
self.assertEqual(
get_proxy_info(
"echo.websocket.events", True, proxy_host="localhost", proxy_port=3128
@@ -254,9 +258,10 @@ class ProxyInfoTest(unittest.TestCase):
"echo.websocket.events",
False,
proxy_host="localhost",
proxy_port=9001,
proxy_auth=("a", "b"),
),
("localhost", 0, ("a", "b")),
("localhost", 9001, ("a", "b")),
)
self.assertEqual(
get_proxy_info(
@@ -273,9 +278,10 @@ class ProxyInfoTest(unittest.TestCase):
"echo.websocket.events",
True,
proxy_host="localhost",
proxy_port=8765,
proxy_auth=("a", "b"),
),
("localhost", 0, ("a", "b")),
("localhost", 8765, ("a", "b")),
)
self.assertEqual(
get_proxy_info(
@@ -311,7 +317,18 @@ class ProxyInfoTest(unittest.TestCase):
(None, 0, None),
)
def testProxyFromEnv(self):
self.assertEqual(
get_proxy_info(
"echo.websocket.events",
True,
proxy_host="localhost",
proxy_port=3128,
no_proxy=[".websocket.events"],
),
(None, 0, None),
)
def test_proxy_from_env(self):
os.environ["http_proxy"] = "http://localhost/"
self.assertEqual(
get_proxy_info("echo.websocket.events", False), ("localhost", None, None)


@@ -7,6 +7,7 @@ import unittest
from base64 import decodebytes as base64decode
import websocket as ws
from websocket._exceptions import WebSocketBadStatusException, WebSocketAddressException
from websocket._handshake import _create_sec_websocket_key
from websocket._handshake import _validate as _validate_header
from websocket._http import read_headers
@@ -16,7 +17,7 @@ from websocket._utils import validate_utf8
test_websocket.py
websocket - WebSocket client library for Python
Copyright 2023 engn33r
Copyright 2024 engn33r
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -33,7 +34,6 @@ limitations under the License.
try:
import ssl
from ssl import SSLError
except ImportError:
# dummy class of SSLError for ssl none-support environment.
class SSLError(Exception):
@@ -95,24 +95,24 @@ class WebSocketTest(unittest.TestCase):
def tearDown(self):
pass
def testDefaultTimeout(self):
def test_default_timeout(self):
self.assertEqual(ws.getdefaulttimeout(), None)
ws.setdefaulttimeout(10)
self.assertEqual(ws.getdefaulttimeout(), 10)
ws.setdefaulttimeout(None)
def testWSKey(self):
def test_ws_key(self):
key = _create_sec_websocket_key()
self.assertTrue(key != 24)
self.assertTrue("¥n" not in key)
def testNonce(self):
def test_nonce(self):
"""WebSocket key should be a random 16-byte nonce."""
key = _create_sec_websocket_key()
nonce = base64decode(key.encode("utf-8"))
self.assertEqual(16, len(nonce))
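
RFC 6455 requires the Sec-WebSocket-Key header to be 16 random bytes, base64-encoded, which is all this test checks. A sketch of an equivalent generator (not necessarily the library's exact implementation):

# Sketch: a spec-compliant Sec-WebSocket-Key is the base64 encoding of
# 16 random bytes, i.e. always a 24-character ASCII string.
import os
from base64 import b64encode

key = b64encode(os.urandom(16)).decode("utf-8")
assert len(key) == 24
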
def testWsUtils(self):
def test_ws_utils(self):
key = "c6b8hTg4EeGb2gQMztV1/g=="
required_header = {
"upgrade": "websocket",
@@ -157,16 +157,12 @@ class WebSocketTest(unittest.TestCase):
# This case will print out a logging error using the error() function, but that is expected
self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (False, None))
def testReadHeader(self):
status, header, status_message = read_headers(
HeaderSockMock("data/header01.txt")
)
def test_read_header(self):
status, header, _ = read_headers(HeaderSockMock("data/header01.txt"))
self.assertEqual(status, 101)
self.assertEqual(header["connection"], "Upgrade")
status, header, status_message = read_headers(
HeaderSockMock("data/header03.txt")
)
status, header, _ = read_headers(HeaderSockMock("data/header03.txt"))
self.assertEqual(status, 101)
self.assertEqual(header["connection"], "Upgrade, Keep-Alive")
@@ -175,7 +171,7 @@ class WebSocketTest(unittest.TestCase):
ws.WebSocketException, read_headers, HeaderSockMock("data/header02.txt")
)
def testSend(self):
def test_send(self):
# TODO: add longer frame data
sock = ws.WebSocket()
sock.set_mask_key(create_mask_key)
@@ -194,7 +190,7 @@ class WebSocketTest(unittest.TestCase):
self.assertEqual(sock.send_binary(b"1111111111101"), 19)
def testRecv(self):
def test_recv(self):
# TODO: add longer frame data
sock = ws.WebSocket()
s = sock.sock = SockMock()
@@ -210,7 +206,7 @@ class WebSocketTest(unittest.TestCase):
self.assertEqual(data, "Hello")
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testIter(self):
def test_iter(self):
count = 2
s = ws.create_connection("wss://api.bitfinex.com/ws/2")
s.send('{"event": "subscribe", "channel": "ticker"}')
@@ -220,11 +216,11 @@ class WebSocketTest(unittest.TestCase):
break
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testNext(self):
def test_next(self):
sock = ws.create_connection("wss://api.bitfinex.com/ws/2")
self.assertEqual(str, type(next(sock)))
def testInternalRecvStrict(self):
def test_internal_recv_strict(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
s.add_packet(b"foo")
@@ -241,7 +237,7 @@ class WebSocketTest(unittest.TestCase):
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.frame_buffer.recv_strict(1)
def testRecvTimeout(self):
def test_recv_timeout(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
s.add_packet(b"\x81")
@@ -258,7 +254,7 @@ class WebSocketTest(unittest.TestCase):
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
def testRecvWithSimpleFragmentation(self):
def test_recv_with_simple_fragmentation(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Brevity is "
@@ -270,7 +266,7 @@ class WebSocketTest(unittest.TestCase):
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
def testRecvWithFireEventOfFragmentation(self):
def test_recv_with_fire_event_of_fragmentation(self):
sock = ws.WebSocket(fire_cont_frame=True)
s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Brevity is "
@@ -296,7 +292,7 @@ class WebSocketTest(unittest.TestCase):
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
def testClose(self):
def test_close(self):
sock = ws.WebSocket()
sock.connected = True
sock.close
@@ -308,14 +304,14 @@ class WebSocketTest(unittest.TestCase):
sock.recv()
self.assertEqual(sock.connected, False)
def testRecvContFragmentation(self):
def test_recv_cont_fragmentation(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
# OPCODE=CONT, FIN=1, MSG="the soul of wit"
s.add_packet(b"\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")
self.assertRaises(ws.WebSocketException, sock.recv)
def testRecvWithProlongedFragmentation(self):
def test_recv_with_prolonged_fragmentation(self):
sock = ws.WebSocket()
s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, "
@@ -331,7 +327,7 @@ class WebSocketTest(unittest.TestCase):
with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv()
def testRecvWithFragmentationAndControlFrame(self):
def test_recv_with_fragmentation_and_control_frame(self):
sock = ws.WebSocket()
sock.set_mask_key(create_mask_key)
s = sock.sock = SockMock()
@@ -352,7 +348,7 @@ class WebSocketTest(unittest.TestCase):
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def testWebSocket(self):
def test_websocket(self):
s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
self.assertNotEqual(s, None)
s.send("Hello, World")
@@ -369,7 +365,7 @@ class WebSocketTest(unittest.TestCase):
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def testPingPong(self):
def test_ping_pong(self):
s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
self.assertNotEqual(s, None)
s.ping("Hello")
@@ -377,17 +373,13 @@ class WebSocketTest(unittest.TestCase):
s.close()
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testSupportRedirect(self):
def test_support_redirect(self):
s = ws.WebSocket()
self.assertRaises(
ws._exceptions.WebSocketBadStatusException, s.connect, "ws://google.com/"
)
self.assertRaises(WebSocketBadStatusException, s.connect, "ws://google.com/")
# Need to find a URL that has a redirect code leading to a websocket
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testSecureWebSocket(self):
import ssl
def test_secure_websocket(self):
s = ws.create_connection("wss://api.bitfinex.com/ws/2")
self.assertNotEqual(s, None)
self.assertTrue(isinstance(s.sock, ssl.SSLSocket))
@@ -401,7 +393,7 @@ class WebSocketTest(unittest.TestCase):
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def testWebSocketWithCustomHeader(self):
def test_websocket_with_custom_header(self):
s = ws.create_connection(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}",
headers={"User-Agent": "PythonWebsocketClient"},
@@ -417,7 +409,7 @@ class WebSocketTest(unittest.TestCase):
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def testAfterClose(self):
def test_after_close(self):
s = ws.create_connection(f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}")
self.assertNotEqual(s, None)
s.close()
@@ -429,7 +421,7 @@ class SockOptTest(unittest.TestCase):
@unittest.skipUnless(
TEST_WITH_LOCAL_SERVER, "Tests using local websocket server are disabled"
)
def testSockOpt(self):
def test_sockopt(self):
sockopt = ((socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),)
s = ws.create_connection(
f"ws://127.0.0.1:{LOCAL_WS_SERVER_PORT}", sockopt=sockopt
@@ -441,7 +433,7 @@ class SockOptTest(unittest.TestCase):
class UtilsTest(unittest.TestCase):
def testUtf8Validator(self):
def test_utf8_validator(self):
state = validate_utf8(b"\xf0\x90\x80\x80")
self.assertEqual(state, True)
state = validate_utf8(
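
The validator's contract is simple enough to show inline; a sketch:

# Sketch: validate_utf8() returns True for well-formed UTF-8 byte strings
# and False otherwise (e.g. for a stray continuation byte).
from websocket._utils import validate_utf8

assert validate_utf8(b"\xf0\x90\x80\x80")  # U+10000, a valid 4-byte sequence
assert not validate_utf8(b"\x80")          # lone continuation byte -> invalid
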
@@ -454,7 +446,7 @@ class UtilsTest(unittest.TestCase):
class HandshakeTest(unittest.TestCase):
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_http_SSL(self):
def test_http_ssl(self):
websock1 = ws.WebSocket(
sslopt={"cert_chain": ssl.get_default_verify_paths().capath},
enable_multithread=False,
@@ -466,7 +458,7 @@ class HandshakeTest(unittest.TestCase):
)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testManualHeaders(self):
def test_manual_headers(self):
websock3 = ws.WebSocket(
sslopt={
"ca_certs": ssl.get_default_verify_paths().cafile,
@@ -474,7 +466,7 @@ class HandshakeTest(unittest.TestCase):
}
)
self.assertRaises(
ws._exceptions.WebSocketBadStatusException,
WebSocketBadStatusException,
websock3.connect,
"wss://api.bitfinex.com/ws/2",
cookie="chocolate",
@@ -490,16 +482,14 @@ class HandshakeTest(unittest.TestCase):
},
)
def testIPv6(self):
def test_ipv6(self):
websock2 = ws.WebSocket()
self.assertRaises(ValueError, websock2.connect, "2001:4860:4860::8888")
def testBadURLs(self):
def test_bad_urls(self):
websock3 = ws.WebSocket()
self.assertRaises(ValueError, websock3.connect, "ws//example.com")
self.assertRaises(
ws.WebSocketAddressException, websock3.connect, "ws://example"
)
self.assertRaises(WebSocketAddressException, websock3.connect, "ws://example")
self.assertRaises(ValueError, websock3.connect, "example.com")


@@ -13,14 +13,10 @@
# You should have received a copy of the GNU General Public License
# along with Tautulli. If not, see <http://www.gnu.org/licenses/>.
from __future__ import unicode_literals
from future.builtins import range
from future.builtins import str
import ctypes
import datetime
import os
import future.moves.queue as queue
import queue
import sqlite3
import sys
import subprocess
@@ -39,52 +35,27 @@ from apscheduler.triggers.interval import IntervalTrigger
from ga4mp import GtagMP
import pytz
PYTHON2 = sys.version_info[0] == 2
if PYTHON2:
import activity_handler
import activity_pinger
import common
import database
import datafactory
import exporter
import helpers
import libraries
import logger
import mobile_app
import newsletters
import newsletter_handler
import notification_handler
import notifiers
import plex
import plextv
import users
import versioncheck
import web_socket
import webstart
import config
else:
from plexpy import activity_handler
from plexpy import activity_pinger
from plexpy import common
from plexpy import database
from plexpy import datafactory
from plexpy import exporter
from plexpy import helpers
from plexpy import libraries
from plexpy import logger
from plexpy import mobile_app
from plexpy import newsletters
from plexpy import newsletter_handler
from plexpy import notification_handler
from plexpy import notifiers
from plexpy import plex
from plexpy import plextv
from plexpy import users
from plexpy import versioncheck
from plexpy import web_socket
from plexpy import webstart
from plexpy import config
from plexpy import activity_handler
from plexpy import activity_pinger
from plexpy import common
from plexpy import database
from plexpy import datafactory
from plexpy import exporter
from plexpy import helpers
from plexpy import libraries
from plexpy import logger
from plexpy import mobile_app
from plexpy import newsletters
from plexpy import newsletter_handler
from plexpy import notification_handler
from plexpy import notifiers
from plexpy import plex
from plexpy import plextv
from plexpy import users
from plexpy import versioncheck
from plexpy import web_socket
from plexpy import webstart
from plexpy import config
PROG_DIR = None
@@ -214,11 +185,10 @@ def initialize(config_file):
logger.initLogger(console=not QUIET, log_dir=CONFIG.LOG_DIR if log_writable else None,
verbose=VERBOSE)
if not PYTHON2:
os.environ['PLEXAPI_CONFIG_PATH'] = os.path.join(DATA_DIR, 'plexapi.config.ini')
os.environ['PLEXAPI_LOG_PATH'] = os.path.join(CONFIG.LOG_DIR, 'plexapi.log')
os.environ['PLEXAPI_LOG_LEVEL'] = 'DEBUG'
plex.initialize_plexapi()
os.environ['PLEXAPI_CONFIG_PATH'] = os.path.join(DATA_DIR, 'plexapi.config.ini')
os.environ['PLEXAPI_LOG_PATH'] = os.path.join(CONFIG.LOG_DIR, 'plexapi.log')
os.environ['PLEXAPI_LOG_LEVEL'] = 'DEBUG'
plex.initialize_plexapi()
if DOCKER:
build = '[Docker] '


@@ -13,10 +13,6 @@
# You should have received a copy of the GNU General Public License
# along with Tautulli. If not, see <http://www.gnu.org/licenses/>.
from __future__ import unicode_literals
from future.builtins import str
from future.builtins import object
import datetime
import os
import time
@@ -25,22 +21,13 @@ from apscheduler.triggers.date import DateTrigger
import pytz
import plexpy
if plexpy.PYTHON2:
import activity_processor
import common
import datafactory
import helpers
import logger
import notification_handler
import pmsconnect
else:
from plexpy import activity_processor
from plexpy import common
from plexpy import datafactory
from plexpy import helpers
from plexpy import logger
from plexpy import notification_handler
from plexpy import pmsconnect
from plexpy import activity_processor
from plexpy import common
from plexpy import datafactory
from plexpy import helpers
from plexpy import logger
from plexpy import notification_handler
from plexpy import pmsconnect
ACTIVITY_SCHED = None


@@ -13,34 +13,18 @@
# You should have received a copy of the GNU General Public License
# along with Tautulli. If not, see <http://www.gnu.org/licenses/>.
from __future__ import unicode_literals
from future.builtins import str
import threading
import plexpy
if plexpy.PYTHON2:
import activity_handler
import activity_processor
import database
import helpers
import libraries
import logger
import notification_handler
import plextv
import pmsconnect
import web_socket
else:
from plexpy import activity_handler
from plexpy import activity_processor
from plexpy import database
from plexpy import helpers
from plexpy import libraries
from plexpy import logger
from plexpy import notification_handler
from plexpy import plextv
from plexpy import pmsconnect
from plexpy import web_socket
from plexpy import activity_handler
from plexpy import activity_processor
from plexpy import database
from plexpy import helpers
from plexpy import logger
from plexpy import notification_handler
from plexpy import plextv
from plexpy import pmsconnect
from plexpy import web_socket
monitor_lock = threading.Lock()


@@ -13,28 +13,16 @@
# You should have received a copy of the GNU General Public License
# along with Tautulli. If not, see <http://www.gnu.org/licenses/>.
from __future__ import unicode_literals
from future.builtins import str
from future.builtins import object
from collections import defaultdict
import json
import plexpy
if plexpy.PYTHON2:
import database
import helpers
import libraries
import logger
import pmsconnect
import users
else:
from plexpy import database
from plexpy import helpers
from plexpy import libraries
from plexpy import logger
from plexpy import pmsconnect
from plexpy import users
from plexpy import database
from plexpy import helpers
from plexpy import libraries
from plexpy import logger
from plexpy import pmsconnect
from plexpy import users
class ActivityProcessor(object):

Some files were not shown because too many files have changed in this diff.