Bump cheroot from 8.6.0 to 9.0.0 (#1903)

* Bump cheroot from 8.6.0 to 9.0.0

Bumps [cheroot](https://github.com/cherrypy/cheroot) from 8.6.0 to 9.0.0.
- [Release notes](https://github.com/cherrypy/cheroot/releases)
- [Changelog](https://github.com/cherrypy/cheroot/blob/main/CHANGES.rst)
- [Commits](https://github.com/cherrypy/cheroot/compare/v8.6.0...v9.0.0)

---
updated-dependencies:
- dependency-name: cheroot
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>

* Update cheroot==9.0.0

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: JonnyWong16 <9099342+JonnyWong16@users.noreply.github.com>

[skip ci]
This commit is contained in:
dependabot[bot] 2022-12-21 15:58:54 -08:00 committed by GitHub
parent 0a5edebea3
commit 3d378eb583
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
33 changed files with 287 additions and 851 deletions

View file

@ -1,15 +1,12 @@
"""High-performance, pure-Python HTTP server used by CherryPy.""" """High-performance, pure-Python HTTP server used by CherryPy."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
try: try:
import pkg_resources from importlib import metadata
except ImportError: except ImportError:
pass import importlib_metadata as metadata # noqa: WPS440
try: try:
__version__ = pkg_resources.get_distribution('cheroot').version __version__ = metadata.version('cheroot')
except Exception: except Exception:
__version__ = 'unknown' __version__ = 'unknown'

View file

@ -1,19 +1,9 @@
# pylint: disable=unused-import # pylint: disable=unused-import
"""Compatibility code for using Cheroot with various versions of Python.""" """Compatibility code for using Cheroot with various versions of Python."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import os import os
import platform import platform
import re
import six
try:
import selectors # lgtm [py/unused-import]
except ImportError:
import selectors2 as selectors # noqa: F401 # lgtm [py/unused-import]
try: try:
import ssl import ssl
@ -22,20 +12,6 @@ try:
except ImportError: except ImportError:
IS_ABOVE_OPENSSL10 = None IS_ABOVE_OPENSSL10 = None
# contextlib.suppress was added in Python 3.4
try:
from contextlib import suppress
except ImportError:
from contextlib import contextmanager
@contextmanager
def suppress(*exceptions):
"""Return a context manager that suppresses the `exceptions`."""
try:
yield
except exceptions:
pass
IS_CI = bool(os.getenv('CI')) IS_CI = bool(os.getenv('CI'))
IS_GITHUB_ACTIONS_WORKFLOW = bool(os.getenv('GITHUB_WORKFLOW')) IS_GITHUB_ACTIONS_WORKFLOW = bool(os.getenv('GITHUB_WORKFLOW'))
@ -53,53 +29,23 @@ PLATFORM_ARCH = platform.machine()
IS_PPC = PLATFORM_ARCH.startswith('ppc') IS_PPC = PLATFORM_ARCH.startswith('ppc')
if not six.PY2: def ntob(n, encoding='ISO-8859-1'):
def ntob(n, encoding='ISO-8859-1'): """Return the native string as bytes in the given encoding."""
"""Return the native string as bytes in the given encoding.""" assert_native(n)
assert_native(n) # In Python 3, the native string type is unicode
# In Python 3, the native string type is unicode return n.encode(encoding)
return n.encode(encoding)
def ntou(n, encoding='ISO-8859-1'):
"""Return the native string as Unicode with the given encoding."""
assert_native(n)
# In Python 3, the native string type is unicode
return n
def bton(b, encoding='ISO-8859-1'): def ntou(n, encoding='ISO-8859-1'):
"""Return the byte string as native string in the given encoding.""" """Return the native string as Unicode with the given encoding."""
return b.decode(encoding) assert_native(n)
else: # In Python 3, the native string type is unicode
# Python 2 return n
def ntob(n, encoding='ISO-8859-1'):
"""Return the native string as bytes in the given encoding."""
assert_native(n)
# In Python 2, the native string type is bytes. Assume it's already
# in the given encoding, which for ISO-8859-1 is almost always what
# was intended.
return n
def ntou(n, encoding='ISO-8859-1'):
"""Return the native string as Unicode with the given encoding."""
assert_native(n)
# In Python 2, the native string type is bytes.
# First, check for the special encoding 'escape'. The test suite uses
# this to signal that it wants to pass a string with embedded \uXXXX
# escapes, but without having to prefix it with u'' for Python 2,
# but no prefix for Python 3.
if encoding == 'escape':
return re.sub(
r'\\u([0-9a-zA-Z]{4})',
lambda m: six.unichr(int(m.group(1), 16)),
n.decode('ISO-8859-1'),
)
# Assume it's already in the given encoding, which for ISO-8859-1
# is almost always what was intended.
return n.decode(encoding)
def bton(b, encoding='ISO-8859-1'): def bton(b, encoding='ISO-8859-1'):
"""Return the byte string as native string in the given encoding.""" """Return the byte string as native string in the given encoding."""
return b return b.decode(encoding)
def assert_native(n): def assert_native(n):
@ -113,17 +59,6 @@ def assert_native(n):
raise TypeError('n must be a native str (got %s)' % type(n).__name__) raise TypeError('n must be a native str (got %s)' % type(n).__name__)
if not six.PY2:
"""Python 3 has :py:class:`memoryview` builtin."""
# Python 2.7 has it backported, but socket.write() does
# str(memoryview(b'0' * 100)) -> <memory at 0x7fb6913a5588>
# instead of accessing it correctly.
memoryview = memoryview
else:
"""Link :py:class:`memoryview` to buffer under Python 2."""
memoryview = buffer # noqa: F821
def extract_bytes(mv): def extract_bytes(mv):
r"""Retrieve bytes out of the given input buffer. r"""Retrieve bytes out of the given input buffer.
@ -138,7 +73,7 @@ def extract_bytes(mv):
or :py:class:`bytes` or :py:class:`bytes`
""" """
if isinstance(mv, memoryview): if isinstance(mv, memoryview):
return bytes(mv) if six.PY2 else mv.tobytes() return mv.tobytes()
if isinstance(mv, bytes): if isinstance(mv, bytes):
return mv return mv

21
lib/cheroot/_compat.pyi Normal file
View file

@ -0,0 +1,21 @@
from typing import Any, ContextManager, Optional, Type, Union
def suppress(*exceptions: Type[BaseException]) -> ContextManager[None]: ...
IS_ABOVE_OPENSSL10: Optional[bool]
IS_CI: bool
IS_GITHUB_ACTIONS_WORKFLOW: bool
IS_PYPY: bool
SYS_PLATFORM: str
IS_WINDOWS: bool
IS_LINUX: bool
IS_MACOS: bool
PLATFORM_ARCH: str
IS_PPC: bool
def ntob(n: str, encoding: str = ...) -> bytes: ...
def ntou(n: str, encoding: str = ...) -> str: ...
def bton(b: bytes, encoding: str = ...) -> str: ...
def assert_native(n: str) -> None: ...
def extract_bytes(mv: Union[memoryview, bytes]) -> bytes: ...

View file

@ -28,18 +28,14 @@ Basic usage:
""" """
import argparse import argparse
from importlib import import_module
import os import os
import sys import sys
import urllib.parse # noqa: WPS301
import six from importlib import import_module
from contextlib import suppress
from . import server from . import server
from . import wsgi from . import wsgi
from ._compat import suppress
__metaclass__ = type
class BindLocation: class BindLocation:
@ -143,7 +139,7 @@ def parse_wsgi_bind_location(bind_addr_string):
return AbstractSocket(bind_addr_string[1:]) return AbstractSocket(bind_addr_string[1:])
# try and match for an IP/hostname and port # try and match for an IP/hostname and port
match = six.moves.urllib.parse.urlparse( match = urllib.parse.urlparse(
'//{addr}'.format(addr=bind_addr_string), '//{addr}'.format(addr=bind_addr_string),
) )
try: try:

View file

@ -1,22 +1,17 @@
"""Utilities to manage open connections.""" """Utilities to manage open connections."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import io import io
import os import os
import socket import socket
import threading import threading
import time import time
import selectors
from contextlib import suppress
from . import errors from . import errors
from ._compat import selectors
from ._compat import suppress
from ._compat import IS_WINDOWS from ._compat import IS_WINDOWS
from .makefile import MakeFile from .makefile import MakeFile
import six
try: try:
import fcntl import fcntl
except ImportError: except ImportError:
@ -310,8 +305,7 @@ class ConnectionManager:
msg, msg,
] ]
sock_to_make = s if not six.PY2 else s._sock wfile = mf(s, 'wb', io.DEFAULT_BUFFER_SIZE)
wfile = mf(sock_to_make, 'wb', io.DEFAULT_BUFFER_SIZE)
try: try:
wfile.write(''.join(buf).encode('ISO-8859-1')) wfile.write(''.join(buf).encode('ISO-8859-1'))
except socket.error as ex: except socket.error as ex:
@ -327,10 +321,7 @@ class ConnectionManager:
conn = self.server.ConnectionClass(self.server, s, mf) conn = self.server.ConnectionClass(self.server, s, mf)
if not isinstance( if not isinstance(self.server.bind_addr, (str, bytes)):
self.server.bind_addr,
(six.text_type, six.binary_type),
):
# optional values # optional values
# Until we do DNS lookups, omit REMOTE_HOST # Until we do DNS lookups, omit REMOTE_HOST
if addr is None: # sometimes this can happen if addr is None: # sometimes this can happen

View file

@ -1,17 +1,14 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
"""Collection of exceptions raised and/or processed by Cheroot.""" """Collection of exceptions raised and/or processed by Cheroot."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import errno import errno
import sys import sys
class MaxSizeExceeded(Exception): class MaxSizeExceeded(Exception):
"""Exception raised when a client sends more data then acceptable within limit. """Exception raised when a client sends more data then allowed under limit.
Depends on ``request.body.maxbytes`` config option if used within CherryPy Depends on ``request.body.maxbytes`` config option if used within CherryPy.
""" """

View file

@ -1,4 +1,4 @@
from typing import Any, List, Set, Tuple from typing import List, Set, Tuple, Type
class MaxSizeExceeded(Exception): ... class MaxSizeExceeded(Exception): ...
class NoSSLError(Exception): ... class NoSSLError(Exception): ...
@ -10,4 +10,4 @@ socket_error_eintr: List[int]
socket_errors_to_ignore: List[int] socket_errors_to_ignore: List[int]
socket_errors_nonblocking: List[int] socket_errors_nonblocking: List[int]
acceptable_sock_shutdown_error_codes: Set[int] acceptable_sock_shutdown_error_codes: Set[int]
acceptable_sock_shutdown_exceptions: Tuple[Exception] acceptable_sock_shutdown_exceptions: Tuple[Type[Exception], ...]

View file

@ -1,21 +1,9 @@
"""Socket file object.""" """Socket file object."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import socket import socket
try: # prefer slower Python-based io module
# prefer slower Python-based io module import _pyio as io
import _pyio as io
except ImportError:
# Python 2.6
import io
import six
from . import errors
from ._compat import extract_bytes, memoryview
# Write only 16K at a time to sockets # Write only 16K at a time to sockets
@ -48,400 +36,41 @@ class BufferedWriter(io.BufferedWriter):
del self._write_buf[:n] del self._write_buf[:n]
class MakeFile_PY2(getattr(socket, '_fileobject', object)): class StreamReader(io.BufferedReader):
"""Faux file object attached to a socket object.""" """Socket stream reader."""
def __init__(self, *args, **kwargs): def __init__(self, sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE):
"""Initialize faux file object.""" """Initialize socket stream reader."""
super().__init__(socket.SocketIO(sock, mode), bufsize)
self.bytes_read = 0 self.bytes_read = 0
def read(self, *args, **kwargs):
"""Capture bytes read."""
val = super().read(*args, **kwargs)
self.bytes_read += len(val)
return val
def has_data(self):
"""Return true if there is buffered data to read."""
return len(self._read_buf) > self._read_pos
class StreamWriter(BufferedWriter):
"""Socket stream writer."""
def __init__(self, sock, mode='w', bufsize=io.DEFAULT_BUFFER_SIZE):
"""Initialize socket stream writer."""
super().__init__(socket.SocketIO(sock, mode), bufsize)
self.bytes_written = 0 self.bytes_written = 0
socket._fileobject.__init__(self, *args, **kwargs)
self._refcount = 0
def _reuse(self): def write(self, val, *args, **kwargs):
self._refcount += 1 """Capture bytes written."""
res = super().write(val, *args, **kwargs)
def _drop(self): self.bytes_written += len(val)
if self._refcount < 0: return res
self.close()
else:
self._refcount -= 1
def write(self, data):
"""Send entire data contents for non-blocking sockets."""
bytes_sent = 0
data_mv = memoryview(data)
payload_size = len(data_mv)
while bytes_sent < payload_size:
try:
bytes_sent += self.send(
data_mv[bytes_sent:bytes_sent + SOCK_WRITE_BLOCKSIZE],
)
except socket.error as e:
if e.args[0] not in errors.socket_errors_nonblocking:
raise
def send(self, data):
"""Send some part of message to the socket."""
bytes_sent = self._sock.send(extract_bytes(data))
self.bytes_written += bytes_sent
return bytes_sent
def flush(self):
"""Write all data from buffer to socket and reset write buffer."""
if self._wbuf:
buffer = ''.join(self._wbuf)
self._wbuf = []
self.write(buffer)
def recv(self, size):
"""Receive message of a size from the socket."""
while True:
try:
data = self._sock.recv(size)
self.bytes_read += len(data)
return data
except socket.error as e:
what = (
e.args[0] not in errors.socket_errors_nonblocking
and e.args[0] not in errors.socket_error_eintr
)
if what:
raise
class FauxSocket:
"""Faux socket with the minimal interface required by pypy."""
def _reuse(self):
pass
_fileobject_uses_str_type = six.PY2 and isinstance(
socket._fileobject(FauxSocket())._rbuf, six.string_types,
)
# FauxSocket is no longer needed
del FauxSocket
if not _fileobject_uses_str_type: # noqa: C901 # FIXME
def read(self, size=-1):
"""Read data from the socket to buffer."""
# Use max, disallow tiny reads in a loop as they are very
# inefficient.
# We never leave read() with any leftover data from a new recv()
# call in our internal buffer.
rbufsize = max(self._rbufsize, self.default_bufsize)
# Our use of StringIO rather than lists of string objects returned
# by recv() minimizes memory usage and fragmentation that occurs
# when rbufsize is large compared to the typical return value of
# recv().
buf = self._rbuf
buf.seek(0, 2) # seek end
if size < 0:
# Read until EOF
# reset _rbuf. we consume it via buf.
self._rbuf = io.BytesIO()
while True:
data = self.recv(rbufsize)
if not data:
break
buf.write(data)
return buf.getvalue()
else:
# Read until size bytes or EOF seen, whichever comes first
buf_len = buf.tell()
if buf_len >= size:
# Already have size bytes in our buffer? Extract and
# return.
buf.seek(0)
rv = buf.read(size)
self._rbuf = io.BytesIO()
self._rbuf.write(buf.read())
return rv
# reset _rbuf. we consume it via buf.
self._rbuf = io.BytesIO()
while True:
left = size - buf_len
# recv() will malloc the amount of memory given as its
# parameter even though it often returns much less data
# than that. The returned data string is short lived
# as we copy it into a StringIO and free it. This avoids
# fragmentation issues on many platforms.
data = self.recv(left)
if not data:
break
n = len(data)
if n == size and not buf_len:
# Shortcut. Avoid buffer data copies when:
# - We have no data in our buffer.
# AND
# - Our call to recv returned exactly the
# number of bytes we were asked to read.
return data
if n == left:
buf.write(data)
del data # explicit free
break
assert n <= left, 'recv(%d) returned %d bytes' % (left, n)
buf.write(data)
buf_len += n
del data # explicit free
# assert buf_len == buf.tell()
return buf.getvalue()
def readline(self, size=-1):
"""Read line from the socket to buffer."""
buf = self._rbuf
buf.seek(0, 2) # seek end
if buf.tell() > 0:
# check if we already have it in our buffer
buf.seek(0)
bline = buf.readline(size)
if bline.endswith('\n') or len(bline) == size:
self._rbuf = io.BytesIO()
self._rbuf.write(buf.read())
return bline
del bline
if size < 0:
# Read until \n or EOF, whichever comes first
if self._rbufsize <= 1:
# Speed up unbuffered case
buf.seek(0)
buffers = [buf.read()]
# reset _rbuf. we consume it via buf.
self._rbuf = io.BytesIO()
data = None
recv = self.recv
while data != '\n':
data = recv(1)
if not data:
break
buffers.append(data)
return ''.join(buffers)
buf.seek(0, 2) # seek end
# reset _rbuf. we consume it via buf.
self._rbuf = io.BytesIO()
while True:
data = self.recv(self._rbufsize)
if not data:
break
nl = data.find('\n')
if nl >= 0:
nl += 1
buf.write(data[:nl])
self._rbuf.write(data[nl:])
del data
break
buf.write(data)
return buf.getvalue()
else:
# Read until size bytes or \n or EOF seen, whichever comes
# first
buf.seek(0, 2) # seek end
buf_len = buf.tell()
if buf_len >= size:
buf.seek(0)
rv = buf.read(size)
self._rbuf = io.BytesIO()
self._rbuf.write(buf.read())
return rv
# reset _rbuf. we consume it via buf.
self._rbuf = io.BytesIO()
while True:
data = self.recv(self._rbufsize)
if not data:
break
left = size - buf_len
# did we just receive a newline?
nl = data.find('\n', 0, left)
if nl >= 0:
nl += 1
# save the excess data to _rbuf
self._rbuf.write(data[nl:])
if buf_len:
buf.write(data[:nl])
break
else:
# Shortcut. Avoid data copy through buf when
# returning a substring of our first recv().
return data[:nl]
n = len(data)
if n == size and not buf_len:
# Shortcut. Avoid data copy through buf when
# returning exactly all of our first recv().
return data
if n >= left:
buf.write(data[:left])
self._rbuf.write(data[left:])
break
buf.write(data)
buf_len += n
# assert buf_len == buf.tell()
return buf.getvalue()
def has_data(self):
"""Return true if there is buffered data to read."""
return bool(self._rbuf.getvalue())
else:
def read(self, size=-1):
"""Read data from the socket to buffer."""
if size < 0:
# Read until EOF
buffers = [self._rbuf]
self._rbuf = ''
if self._rbufsize <= 1:
recv_size = self.default_bufsize
else:
recv_size = self._rbufsize
while True:
data = self.recv(recv_size)
if not data:
break
buffers.append(data)
return ''.join(buffers)
else:
# Read until size bytes or EOF seen, whichever comes first
data = self._rbuf
buf_len = len(data)
if buf_len >= size:
self._rbuf = data[size:]
return data[:size]
buffers = []
if data:
buffers.append(data)
self._rbuf = ''
while True:
left = size - buf_len
recv_size = max(self._rbufsize, left)
data = self.recv(recv_size)
if not data:
break
buffers.append(data)
n = len(data)
if n >= left:
self._rbuf = data[left:]
buffers[-1] = data[:left]
break
buf_len += n
return ''.join(buffers)
def readline(self, size=-1):
"""Read line from the socket to buffer."""
data = self._rbuf
if size < 0:
# Read until \n or EOF, whichever comes first
if self._rbufsize <= 1:
# Speed up unbuffered case
assert data == ''
buffers = []
while data != '\n':
data = self.recv(1)
if not data:
break
buffers.append(data)
return ''.join(buffers)
nl = data.find('\n')
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
return data[:nl]
buffers = []
if data:
buffers.append(data)
self._rbuf = ''
while True:
data = self.recv(self._rbufsize)
if not data:
break
buffers.append(data)
nl = data.find('\n')
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
buffers[-1] = data[:nl]
break
return ''.join(buffers)
else:
# Read until size bytes or \n or EOF seen, whichever comes
# first
nl = data.find('\n', 0, size)
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
return data[:nl]
buf_len = len(data)
if buf_len >= size:
self._rbuf = data[size:]
return data[:size]
buffers = []
if data:
buffers.append(data)
self._rbuf = ''
while True:
data = self.recv(self._rbufsize)
if not data:
break
buffers.append(data)
left = size - buf_len
nl = data.find('\n', 0, left)
if nl >= 0:
nl += 1
self._rbuf = data[nl:]
buffers[-1] = data[:nl]
break
n = len(data)
if n >= left:
self._rbuf = data[left:]
buffers[-1] = data[:left]
break
buf_len += n
return ''.join(buffers)
def has_data(self):
"""Return true if there is buffered data to read."""
return bool(self._rbuf)
if not six.PY2: def MakeFile(sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE):
class StreamReader(io.BufferedReader): """File object attached to a socket object."""
"""Socket stream reader.""" cls = StreamReader if 'r' in mode else StreamWriter
return cls(sock, mode, bufsize)
def __init__(self, sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE):
"""Initialize socket stream reader."""
super().__init__(socket.SocketIO(sock, mode), bufsize)
self.bytes_read = 0
def read(self, *args, **kwargs):
"""Capture bytes read."""
val = super().read(*args, **kwargs)
self.bytes_read += len(val)
return val
def has_data(self):
"""Return true if there is buffered data to read."""
return len(self._read_buf) > self._read_pos
class StreamWriter(BufferedWriter):
"""Socket stream writer."""
def __init__(self, sock, mode='w', bufsize=io.DEFAULT_BUFFER_SIZE):
"""Initialize socket stream writer."""
super().__init__(socket.SocketIO(sock, mode), bufsize)
self.bytes_written = 0
def write(self, val, *args, **kwargs):
"""Capture bytes written."""
res = super().write(val, *args, **kwargs)
self.bytes_written += len(val)
return res
def MakeFile(sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE):
"""File object attached to a socket object."""
cls = StreamReader if 'r' in mode else StreamWriter
return cls(sock, mode, bufsize)
else:
StreamReader = StreamWriter = MakeFile = MakeFile_PY2

View file

@ -5,19 +5,6 @@ SOCK_WRITE_BLOCKSIZE: int
class BufferedWriter(io.BufferedWriter): class BufferedWriter(io.BufferedWriter):
def write(self, b): ... def write(self, b): ...
class MakeFile_PY2:
bytes_read: int
bytes_written: int
def __init__(self, *args, **kwargs) -> None: ...
def write(self, data) -> None: ...
def send(self, data): ...
def flush(self) -> None: ...
def recv(self, size): ...
class FauxSocket: ...
def read(self, size: int = ...): ...
def readline(self, size: int = ...): ...
def has_data(self): ...
class StreamReader(io.BufferedReader): class StreamReader(io.BufferedReader):
bytes_read: int bytes_read: int
def __init__(self, sock, mode: str = ..., bufsize=...) -> None: ... def __init__(self, sock, mode: str = ..., bufsize=...) -> None: ...

View file

@ -65,9 +65,6 @@ And now for a trivial doctest to exercise the test suite
True True
""" """
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import os import os
import io import io
import re import re
@ -78,20 +75,14 @@ import time
import traceback as traceback_ import traceback as traceback_
import logging import logging
import platform import platform
import queue
import contextlib import contextlib
import threading import threading
import urllib.parse
try: from functools import lru_cache
from functools import lru_cache
except ImportError:
from backports.functools_lru_cache import lru_cache
import six
from six.moves import queue
from six.moves import urllib
from . import connections, errors, __version__ from . import connections, errors, __version__
from ._compat import bton, ntou from ._compat import bton
from ._compat import IS_PPC from ._compat import IS_PPC
from .workers import threadpool from .workers import threadpool
from .makefile import MakeFile, StreamWriter from .makefile import MakeFile, StreamWriter
@ -606,8 +597,8 @@ class ChunkedRFile:
def read_trailer_lines(self): def read_trailer_lines(self):
"""Read HTTP headers and yield them. """Read HTTP headers and yield them.
Returns: :yields: CRLF separated lines
Generator: yields CRLF separated lines. :ytype: bytes
""" """
if not self.closed: if not self.closed:
@ -817,10 +808,6 @@ class HTTPRequest:
return False return False
try: try:
if six.PY2: # FIXME: Figure out better way to do this
# Ref: https://stackoverflow.com/a/196392/595220 (like this?)
"""This is a dummy check for unicode in URI."""
ntou(bton(uri, 'ascii'), 'ascii')
scheme, authority, path, qs, fragment = urllib.parse.urlsplit(uri) scheme, authority, path, qs, fragment = urllib.parse.urlsplit(uri)
except UnicodeError: except UnicodeError:
self.simple_response('400 Bad Request', 'Malformed Request-URI') self.simple_response('400 Bad Request', 'Malformed Request-URI')
@ -1120,7 +1107,7 @@ class HTTPRequest:
buf.append(CRLF) buf.append(CRLF)
if msg: if msg:
if isinstance(msg, six.text_type): if isinstance(msg, str):
msg = msg.encode('ISO-8859-1') msg = msg.encode('ISO-8859-1')
buf.append(msg) buf.append(msg)
@ -1422,10 +1409,7 @@ class HTTPConnection:
https://github.com/daveti/tcpSockHack https://github.com/daveti/tcpSockHack
msdn.microsoft.com/en-us/commandline/wsl/release_notes#build-15025 msdn.microsoft.com/en-us/commandline/wsl/release_notes#build-15025
""" """
six.raise_from( # 3.6+: raise RuntimeError from socket_err raise RuntimeError from socket_err
RuntimeError,
socket_err,
)
else: else:
pid, uid, gid = struct.unpack(PEERCRED_STRUCT_DEF, peer_creds) pid, uid, gid = struct.unpack(PEERCRED_STRUCT_DEF, peer_creds)
return pid, uid, gid return pid, uid, gid
@ -1589,7 +1573,7 @@ class HTTPServer:
""" """
keep_alive_conn_limit = 10 keep_alive_conn_limit = 10
"""The maximum number of waiting keep-alive connections that will be kept open. """Maximum number of waiting keep-alive connections that will be kept open.
Default is 10. Set to None to have unlimited connections.""" Default is 10. Set to None to have unlimited connections."""
@ -1762,13 +1746,13 @@ class HTTPServer:
if os.getenv('LISTEN_PID', None): if os.getenv('LISTEN_PID', None):
# systemd socket activation # systemd socket activation
self.socket = socket.fromfd(3, socket.AF_INET, socket.SOCK_STREAM) self.socket = socket.fromfd(3, socket.AF_INET, socket.SOCK_STREAM)
elif isinstance(self.bind_addr, (six.text_type, six.binary_type)): elif isinstance(self.bind_addr, (str, bytes)):
# AF_UNIX socket # AF_UNIX socket
try: try:
self.bind_unix_socket(self.bind_addr) self.bind_unix_socket(self.bind_addr)
except socket.error as serr: except socket.error as serr:
msg = '%s -- (%s: %s)' % (msg, self.bind_addr, serr) msg = '%s -- (%s: %s)' % (msg, self.bind_addr, serr)
six.raise_from(socket.error(msg), serr) raise socket.error(msg) from serr
else: else:
# AF_INET or AF_INET6 socket # AF_INET or AF_INET6 socket
# Get the correct address family for our host (allows IPv6 # Get the correct address family for our host (allows IPv6
@ -2007,10 +1991,7 @@ class HTTPServer:
* https://gavv.github.io/blog/ephemeral-port-reuse/ * https://gavv.github.io/blog/ephemeral-port-reuse/
""" """
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if nodelay and not isinstance( if nodelay and not isinstance(bind_addr, (str, bytes)):
bind_addr,
(six.text_type, six.binary_type),
):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
if ssl_adapter is not None: if ssl_adapter is not None:
@ -2059,7 +2040,7 @@ class HTTPServer:
""" """
return bind_addr[:2] return bind_addr[:2]
if isinstance(bind_addr, six.binary_type): if isinstance(bind_addr, bytes):
bind_addr = bton(bind_addr) bind_addr = bton(bind_addr)
return bind_addr return bind_addr
@ -2109,10 +2090,7 @@ class HTTPServer:
sock = getattr(self, 'socket', None) sock = getattr(self, 'socket', None)
if sock: if sock:
if not isinstance( if not isinstance(self.bind_addr, (str, bytes)):
self.bind_addr,
(six.text_type, six.binary_type),
):
# Touch our own socket to make accept() return immediately. # Touch our own socket to make accept() return immediately.
try: try:
host, port = sock.getsockname()[:2] host, port = sock.getsockname()[:2]
@ -2179,7 +2157,7 @@ ssl_adapters = {
def get_ssl_adapter_class(name='builtin'): def get_ssl_adapter_class(name='builtin'):
"""Return an SSL adapter class for the given name.""" """Return an SSL adapter class for the given name."""
adapter = ssl_adapters[name.lower()] adapter = ssl_adapters[name.lower()]
if isinstance(adapter, six.string_types): if isinstance(adapter, str):
last_dot = adapter.rfind('.') last_dot = adapter.rfind('.')
attr_name = adapter[last_dot + 1:] attr_name = adapter[last_dot + 1:]
mod_path = adapter[:last_dot] mod_path = adapter[:last_dot]

View file

@ -1,15 +1,9 @@
"""Implementation of the SSL adapter base interface.""" """Implementation of the SSL adapter base interface."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from six import add_metaclass
class Adapter(metaclass=ABCMeta):
@add_metaclass(ABCMeta)
class Adapter:
"""Base class for SSL driver library adapters. """Base class for SSL driver library adapters.
Required methods: Required methods:

View file

@ -7,12 +7,10 @@ To use this module, set ``HTTPServer.ssl_adapter`` to an instance of
``BuiltinSSLAdapter``. ``BuiltinSSLAdapter``.
""" """
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import socket import socket
import sys import sys
import threading import threading
from contextlib import suppress
try: try:
import ssl import ssl
@ -27,18 +25,13 @@ except ImportError:
except ImportError: except ImportError:
DEFAULT_BUFFER_SIZE = -1 DEFAULT_BUFFER_SIZE = -1
import six
from . import Adapter from . import Adapter
from .. import errors from .. import errors
from .._compat import IS_ABOVE_OPENSSL10, suppress from .._compat import IS_ABOVE_OPENSSL10
from ..makefile import StreamReader, StreamWriter from ..makefile import StreamReader, StreamWriter
from ..server import HTTPServer from ..server import HTTPServer
if six.PY2: generic_socket_error = OSError
generic_socket_error = socket.error
else:
generic_socket_error = OSError
def _assert_ssl_exc_contains(exc, *msgs): def _assert_ssl_exc_contains(exc, *msgs):

View file

@ -1,7 +1,6 @@
from typing import Any from typing import Any
from . import Adapter from . import Adapter
generic_socket_error: OSError
DEFAULT_BUFFER_SIZE: int DEFAULT_BUFFER_SIZE: int
class BuiltinSSLAdapter(Adapter): class BuiltinSSLAdapter(Adapter):
@ -14,5 +13,5 @@ class BuiltinSSLAdapter(Adapter):
def context(self, context) -> None: ... def context(self, context) -> None: ...
def bind(self, sock): ... def bind(self, sock): ...
def wrap(self, sock): ... def wrap(self, sock): ...
def get_environ(self): ... def get_environ(self, sock): ...
def makefile(self, sock, mode: str = ..., bufsize: int = ...): ... def makefile(self, sock, mode: str = ..., bufsize: int = ...): ...

View file

@ -50,16 +50,11 @@ will be read, and the context will be automatically created from them.
pyopenssl pyopenssl
""" """
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import socket import socket
import sys import sys
import threading import threading
import time import time
import six
try: try:
import OpenSSL.version import OpenSSL.version
from OpenSSL import SSL from OpenSSL import SSL
@ -229,8 +224,7 @@ class SSLConnectionProxyMeta:
return type(name, bases, nmspc) return type(name, bases, nmspc)
@six.add_metaclass(SSLConnectionProxyMeta) class SSLConnection(metaclass=SSLConnectionProxyMeta):
class SSLConnection:
r"""A thread-safe wrapper for an ``SSL.Connection``. r"""A thread-safe wrapper for an ``SSL.Connection``.
:param tuple args: the arguments to create the wrapped \ :param tuple args: the arguments to create the wrapped \

View file

@ -1,9 +1,9 @@
from . import Adapter from . import Adapter
from ..makefile import StreamReader, StreamWriter from ..makefile import StreamReader, StreamWriter
from OpenSSL import SSL from OpenSSL import SSL
from typing import Any from typing import Any, Type
ssl_conn_type: SSL.Connection ssl_conn_type: Type[SSL.Connection]
class SSLFileobjectMixin: class SSLFileobjectMixin:
ssl_timeout: int ssl_timeout: int
@ -13,13 +13,13 @@ class SSLFileobjectMixin:
def sendall(self, *args, **kwargs): ... def sendall(self, *args, **kwargs): ...
def send(self, *args, **kwargs): ... def send(self, *args, **kwargs): ...
class SSLFileobjectStreamReader(SSLFileobjectMixin, StreamReader): ... # type:ignore class SSLFileobjectStreamReader(SSLFileobjectMixin, StreamReader): ... # type:ignore[misc]
class SSLFileobjectStreamWriter(SSLFileobjectMixin, StreamWriter): ... # type:ignore class SSLFileobjectStreamWriter(SSLFileobjectMixin, StreamWriter): ... # type:ignore[misc]
class SSLConnectionProxyMeta: class SSLConnectionProxyMeta:
def __new__(mcl, name, bases, nmspc): ... def __new__(mcl, name, bases, nmspc): ...
class SSLConnection(): class SSLConnection:
def __init__(self, *args) -> None: ... def __init__(self, *args) -> None: ...
class pyOpenSSLAdapter(Adapter): class pyOpenSSLAdapter(Adapter):
@ -28,3 +28,4 @@ class pyOpenSSLAdapter(Adapter):
def wrap(self, sock): ... def wrap(self, sock): ...
def get_environ(self): ... def get_environ(self): ...
def makefile(self, sock, mode: str = ..., bufsize: int = ...): ... def makefile(self, sock, mode: str = ..., bufsize: int = ...): ...
def get_context(self) -> SSL.Context: ...

View file

@ -8,6 +8,7 @@ from __future__ import absolute_import, division, print_function
__metaclass__ = type __metaclass__ = type
import pytest import pytest
import six
pytest_version = tuple(map(int, pytest.__version__.split('.'))) pytest_version = tuple(map(int, pytest.__version__.split('.')))
@ -43,8 +44,17 @@ def pytest_load_initial_conftests(early_config, parser, args):
'<socket.socket fd=-1, family=AF_INET6, ' '<socket.socket fd=-1, family=AF_INET6, '
'type=SocketKind.SOCK_STREAM, proto=.:' 'type=SocketKind.SOCK_STREAM, proto=.:'
'pytest.PytestUnraisableExceptionWarning:_pytest.unraisableexception', 'pytest.PytestUnraisableExceptionWarning:_pytest.unraisableexception',
'ignore:Exception ignored in. ' ))
'<ssl.SSLSocket fd=-1, family=AddressFamily.AF_UNIX, '
'type=SocketKind.SOCK_STREAM, proto=.:' if six.PY2:
'pytest.PytestUnraisableExceptionWarning:_pytest.unraisableexception', return
# NOTE: `ResourceWarning` does not exist under Python 2 and so using
# NOTE: it in warning filters results in an `_OptionError` exception
# NOTE: being raised.
early_config._inicache['filterwarnings'].extend((
# FIXME: Try to figure out what causes this and ensure that the socket
# FIXME: gets closed.
'ignore:unclosed <socket.socket fd=:ResourceWarning',
'ignore:unclosed <ssl.SSLSocket fd=:ResourceWarning',
)) ))

View file

@ -4,14 +4,12 @@ Contains fixtures, which are tightly bound to the Cheroot framework
itself, useless for end-users' app testing. itself, useless for end-users' app testing.
""" """
from __future__ import absolute_import, division, print_function
__metaclass__ = type # pylint: disable=invalid-name
import threading import threading
import time import time
import pytest import pytest
from .._compat import IS_MACOS, IS_WINDOWS # noqa: WPS436
from ..server import Gateway, HTTPServer from ..server import Gateway, HTTPServer
from ..testing import ( # noqa: F401 # pylint: disable=unused-import from ..testing import ( # noqa: F401 # pylint: disable=unused-import
native_server, wsgi_server, native_server, wsgi_server,
@ -19,6 +17,20 @@ from ..testing import ( # noqa: F401 # pylint: disable=unused-import
from ..testing import get_server_client from ..testing import get_server_client
@pytest.fixture
def http_request_timeout():
"""Return a common HTTP request timeout for tests with queries."""
computed_timeout = 0.1
if IS_MACOS:
computed_timeout *= 2
if IS_WINDOWS:
computed_timeout *= 10
return computed_timeout
@pytest.fixture @pytest.fixture
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
def wsgi_server_client(wsgi_server): # noqa: F811 def wsgi_server_client(wsgi_server): # noqa: F811

View file

@ -1,8 +1,5 @@
"""A library of helper functions for the Cheroot test suite.""" """A library of helper functions for the Cheroot test suite."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import datetime import datetime
import logging import logging
import os import os
@ -10,10 +7,7 @@ import sys
import time import time
import threading import threading
import types import types
import http.client
from six.moves import http_client
import six
import cheroot.server import cheroot.server
import cheroot.wsgi import cheroot.wsgi
@ -60,14 +54,14 @@ class CherootWebCase(webtest.WebCase):
cls.scheme = 'http' cls.scheme = 'http'
else: else:
ssl = ' (ssl)' ssl = ' (ssl)'
cls.HTTP_CONN = http_client.HTTPSConnection cls.HTTP_CONN = http.client.HTTPSConnection
cls.scheme = 'https' cls.scheme = 'https'
v = sys.version.split()[0] v = sys.version.split()[0]
log.info('Python version used to run this test script: %s' % v) log.info('Python version used to run this test script: %s', v)
log.info('Cheroot version: %s' % cheroot.__version__) log.info('Cheroot version: %s', cheroot.__version__)
log.info('HTTP server version: %s%s' % (cls.httpserver.protocol, ssl)) log.info('HTTP server version: %s%s', cls.httpserver.protocol, ssl)
log.info('PID: %s' % os.getpid()) log.info('PID: %s', os.getpid())
if hasattr(cls, 'setup_server'): if hasattr(cls, 'setup_server'):
# Clear the wsgi server so that # Clear the wsgi server so that
@ -135,9 +129,9 @@ class Response:
"""Generate iterable response body object.""" """Generate iterable response body object."""
if self.body is None: if self.body is None:
return [] return []
elif isinstance(self.body, six.text_type): elif isinstance(self.body, str):
return [self.body.encode('iso-8859-1')] return [self.body.encode('iso-8859-1')]
elif isinstance(self.body, six.binary_type): elif isinstance(self.body, bytes):
return [self.body] return [self.body]
else: else:
return [x.encode('iso-8859-1') for x in self.body] return [x.encode('iso-8859-1') for x in self.body]

View file

@ -1,13 +1,8 @@
# -*- coding: utf-8 -*-
"""Test suite for cross-python compatibility helpers.""" """Test suite for cross-python compatibility helpers."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pytest import pytest
import six
from cheroot._compat import extract_bytes, memoryview, ntob, ntou, bton from cheroot._compat import extract_bytes, ntob, ntou, bton
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -32,7 +27,7 @@ def test_compat_functions_positive(func, inp, out):
) )
def test_compat_functions_negative_nonnative(func): def test_compat_functions_negative_nonnative(func):
"""Check that compatibility functions fail loudly for incorrect input.""" """Check that compatibility functions fail loudly for incorrect input."""
non_native_test_str = u'bar' if six.PY2 else b'bar' non_native_test_str = b'bar'
with pytest.raises(TypeError): with pytest.raises(TypeError):
func(non_native_test_str, encoding='utf-8') func(non_native_test_str, encoding='utf-8')

View file

@ -4,11 +4,8 @@
cli cli
""" """
# -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 :
import sys import sys
import six
import pytest import pytest
from cheroot.cli import ( from cheroot.cli import (
@ -69,12 +66,6 @@ def wsgi_app(monkeypatch):
app = WSGIAppMock() app = WSGIAppMock()
# patch sys.modules, to include the an instance of WSGIAppMock # patch sys.modules, to include the an instance of WSGIAppMock
# under a specific namespace # under a specific namespace
if six.PY2:
# python2 requires the previous namespaces to be part of sys.modules
# (e.g. for 'a.b.c' we need to insert 'a', 'a.b' and 'a.b.c')
# otherwise it fails, we're setting the same instance on each level,
# we don't really care about those, just the last one.
monkeypatch.setitem(sys.modules, 'mypkg', app)
monkeypatch.setitem(sys.modules, 'mypkg.wsgi', app) monkeypatch.setitem(sys.modules, 'mypkg.wsgi', app)
return app return app

View file

@ -1,18 +1,14 @@
"""Tests for TCP connection handling, including proper and timely close.""" """Tests for TCP connection handling, including proper and timely close."""
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import errno import errno
import socket import socket
import time import time
import logging import logging
import traceback as traceback_ import traceback as traceback_
from collections import namedtuple from collections import namedtuple
import http.client
import urllib.request
from six.moves import range, http_client, urllib
import six
import pytest import pytest
from jaraco.text import trim, unwrap from jaraco.text import trim, unwrap
@ -94,8 +90,6 @@ class Controller(helper.Controller):
WSGI 1.0 is a mess around unicode. Create endpoints WSGI 1.0 is a mess around unicode. Create endpoints
that match the PATH_INFO that it produces. that match the PATH_INFO that it produces.
""" """
if six.PY2:
return string
return string.encode('utf-8').decode('latin-1') return string.encode('utf-8').decode('latin-1')
handlers = { handlers = {
@ -242,7 +236,7 @@ def test_HTTP11_persistent_connections(test_client):
assert header_has_value('Connection', 'close', actual_headers) assert header_has_value('Connection', 'close', actual_headers)
# Make another request on the same connection, which should error. # Make another request on the same connection, which should error.
with pytest.raises(http_client.NotConnected): with pytest.raises(http.client.NotConnected):
test_client.get('/pov', http_conn=http_connection) test_client.get('/pov', http_conn=http_connection)
@ -309,7 +303,7 @@ def test_streaming_11(test_client, set_cl):
# Make another request on the same connection, which should # Make another request on the same connection, which should
# error. # error.
with pytest.raises(http_client.NotConnected): with pytest.raises(http.client.NotConnected):
test_client.get('/pov', http_conn=http_connection) test_client.get('/pov', http_conn=http_connection)
# Try HEAD. # Try HEAD.
@ -324,6 +318,9 @@ def test_streaming_11(test_client, set_cl):
assert actual_resp_body == b'' assert actual_resp_body == b''
assert not header_exists('Transfer-Encoding', actual_headers) assert not header_exists('Transfer-Encoding', actual_headers)
# Prevent the resource warnings:
http_connection.close()
@pytest.mark.parametrize( @pytest.mark.parametrize(
'set_cl', 'set_cl',
@ -389,7 +386,7 @@ def test_streaming_10(test_client, set_cl):
assert not header_exists('Transfer-Encoding', actual_headers) assert not header_exists('Transfer-Encoding', actual_headers)
# Make another request on the same connection, which should error. # Make another request on the same connection, which should error.
with pytest.raises(http_client.NotConnected): with pytest.raises(http.client.NotConnected):
test_client.get( test_client.get(
'/pov', http_conn=http_connection, '/pov', http_conn=http_connection,
protocol='HTTP/1.0', protocol='HTTP/1.0',
@ -397,6 +394,9 @@ def test_streaming_10(test_client, set_cl):
test_client.server_instance.protocol = original_server_protocol test_client.server_instance.protocol = original_server_protocol
# Prevent the resource warnings:
http_connection.close()
@pytest.mark.parametrize( @pytest.mark.parametrize(
'http_server_protocol', 'http_server_protocol',
@ -466,6 +466,9 @@ def test_keepalive(test_client, http_server_protocol):
test_client.server_instance.protocol = original_server_protocol test_client.server_instance.protocol = original_server_protocol
# Prevent the resource warnings:
http_connection.close()
def test_keepalive_conn_management(test_client): def test_keepalive_conn_management(test_client):
"""Test management of Keep-Alive connections.""" """Test management of Keep-Alive connections."""
@ -511,9 +514,9 @@ def test_keepalive_conn_management(test_client):
) )
disconnect_errors = ( disconnect_errors = (
http_client.BadStatusLine, http.client.BadStatusLine,
http_client.CannotSendRequest, http.client.CannotSendRequest,
http_client.NotConnected, http.client.NotConnected,
) )
# Make a new connection. # Make a new connection.
@ -565,6 +568,11 @@ def test_keepalive_conn_management(test_client):
# Restore original timeout. # Restore original timeout.
test_client.server_instance.timeout = timeout test_client.server_instance.timeout = timeout
# Prevent the resource warnings:
c1.close()
c2.close()
c3.close()
@pytest.mark.parametrize( @pytest.mark.parametrize(
('simulated_exception', 'error_number', 'exception_leaks'), ('simulated_exception', 'error_number', 'exception_leaks'),
@ -597,7 +605,6 @@ def test_keepalive_conn_management(test_client):
pytest.param(RuntimeError, 666, True, id='RuntimeError(666)'), pytest.param(RuntimeError, 666, True, id='RuntimeError(666)'),
pytest.param(socket.error, -1, True, id='socket.error(-1)'), pytest.param(socket.error, -1, True, id='socket.error(-1)'),
) + ( ) + (
() if six.PY2 else (
pytest.param( pytest.param(
ConnectionResetError, errno.ECONNRESET, False, ConnectionResetError, errno.ECONNRESET, False,
id='ConnectionResetError(ECONNRESET)', id='ConnectionResetError(ECONNRESET)',
@ -610,7 +617,6 @@ def test_keepalive_conn_management(test_client):
BrokenPipeError, errno.ESHUTDOWN, False, BrokenPipeError, errno.ESHUTDOWN, False,
id='BrokenPipeError(ESHUTDOWN)', id='BrokenPipeError(ESHUTDOWN)',
), ),
)
), ),
) )
def test_broken_connection_during_tcp_fin( def test_broken_connection_during_tcp_fin(
@ -765,7 +771,7 @@ def test_HTTP11_Timeout_after_request(test_client):
response = conn.response_class(conn.sock, method='GET') response = conn.response_class(conn.sock, method='GET')
try: try:
response.begin() response.begin()
except (socket.error, http_client.BadStatusLine): except (socket.error, http.client.BadStatusLine):
pass pass
except Exception as ex: except Exception as ex:
pytest.fail(fail_msg % ex) pytest.fail(fail_msg % ex)
@ -795,7 +801,7 @@ def test_HTTP11_Timeout_after_request(test_client):
response = conn.response_class(conn.sock, method='GET') response = conn.response_class(conn.sock, method='GET')
try: try:
response.begin() response.begin()
except (socket.error, http_client.BadStatusLine): except (socket.error, http.client.BadStatusLine):
pass pass
except Exception as ex: except Exception as ex:
pytest.fail(fail_msg % ex) pytest.fail(fail_msg % ex)
@ -845,8 +851,7 @@ def test_HTTP11_pipelining(test_client):
# ``conn.sock``. Until that bug get's fixed we will # ``conn.sock``. Until that bug get's fixed we will
# monkey patch the ``response`` instance. # monkey patch the ``response`` instance.
# https://bugs.python.org/issue23377 # https://bugs.python.org/issue23377
if not six.PY2: response.fp = conn.sock.makefile('rb', 0)
response.fp = conn.sock.makefile('rb', 0)
response.begin() response.begin()
body = response.read(13) body = response.read(13)
assert response.status == 200 assert response.status == 200
@ -1026,6 +1031,9 @@ def test_No_Message_Body(test_client):
assert actual_resp_body == b'' assert actual_resp_body == b''
assert not header_exists('Connection', actual_headers) assert not header_exists('Connection', actual_headers)
# Prevent the resource warnings:
http_connection.close()
@pytest.mark.xfail( @pytest.mark.xfail(
reason=unwrap( reason=unwrap(

View file

@ -1,16 +1,10 @@
"""Tests for managing HTTP issues (malformed requests, etc).""" """Tests for managing HTTP issues (malformed requests, etc)."""
# -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 :
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import errno import errno
import socket import socket
import urllib.parse # noqa: WPS301
import pytest import pytest
import six
from six.moves import urllib
from cheroot.test import helper from cheroot.test import helper
@ -54,8 +48,6 @@ class HelloController(helper.Controller):
WSGI 1.0 is a mess around unicode. Create endpoints WSGI 1.0 is a mess around unicode. Create endpoints
that match the PATH_INFO that it produces. that match the PATH_INFO that it produces.
""" """
if six.PY2:
return string
return string.encode('utf-8').decode('latin-1') return string.encode('utf-8').decode('latin-1')
handlers = { handlers = {
@ -63,7 +55,13 @@ class HelloController(helper.Controller):
'/no_body': hello, '/no_body': hello,
'/body_required': body_required, '/body_required': body_required,
'/query_string': query_string, '/query_string': query_string,
# FIXME: Unignore the pylint rules in pylint >= 2.15.4.
# Refs:
# * https://github.com/PyCQA/pylint/issues/6592
# * https://github.com/PyCQA/pylint/pull/7395
# pylint: disable-next=too-many-function-args
_munge('/привіт'): hello, _munge('/привіт'): hello,
# pylint: disable-next=too-many-function-args
_munge('/Юххууу'): hello, _munge('/Юххууу'): hello,
'/\xa0Ðblah key 0 900 4 data': hello, '/\xa0Ðblah key 0 900 4 data': hello,
'/*': asterisk, '/*': asterisk,
@ -151,7 +149,6 @@ def test_parse_acceptable_uri(test_client, uri):
assert actual_status == HTTP_OK assert actual_status == HTTP_OK
@pytest.mark.xfail(six.PY2, reason='Fails on Python 2')
def test_parse_uri_unsafe_uri(test_client): def test_parse_uri_unsafe_uri(test_client):
"""Test that malicious URI does not allow HTTP injection. """Test that malicious URI does not allow HTTP injection.
@ -263,6 +260,8 @@ def test_no_content_length(test_client):
assert actual_status == HTTP_OK assert actual_status == HTTP_OK
assert actual_resp_body == b'Hello world!' assert actual_resp_body == b'Hello world!'
c.close() # deal with the resource warning
def test_content_length_required(test_client): def test_content_length_required(test_client):
"""Test POST query with body failing because of missing Content-Length.""" """Test POST query with body failing because of missing Content-Length."""
@ -278,6 +277,8 @@ def test_content_length_required(test_client):
actual_status = response.status actual_status = response.status
assert actual_status == HTTP_LENGTH_REQUIRED assert actual_status == HTTP_LENGTH_REQUIRED
c.close() # deal with the resource warning
@pytest.mark.xfail( @pytest.mark.xfail(
reason='https://github.com/cherrypy/cheroot/issues/106', reason='https://github.com/cherrypy/cheroot/issues/106',
@ -350,6 +351,8 @@ def test_malformed_http_method(test_client):
actual_resp_body = response.read(21) actual_resp_body = response.read(21)
assert actual_resp_body == b'Malformed method name' assert actual_resp_body == b'Malformed method name'
c.close() # deal with the resource warning
def test_malformed_header(test_client): def test_malformed_header(test_client):
"""Check that broken HTTP header results in Bad Request.""" """Check that broken HTTP header results in Bad Request."""
@ -366,6 +369,8 @@ def test_malformed_header(test_client):
actual_resp_body = response.read(20) actual_resp_body = response.read(20)
assert actual_resp_body == b'Illegal header line.' assert actual_resp_body == b'Illegal header line.'
c.close() # deal with the resource warning
def test_request_line_split_issue_1220(test_client): def test_request_line_split_issue_1220(test_client):
"""Check that HTTP request line of exactly 256 chars length is OK.""" """Check that HTTP request line of exactly 256 chars length is OK."""

View file

@ -1,8 +1,4 @@
"""Tests for the HTTP server.""" """Tests for the HTTP server."""
# -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 :
from __future__ import absolute_import, division, print_function
from cheroot.wsgi import PathInfoDispatcher from cheroot.wsgi import PathInfoDispatcher

View file

@ -3,9 +3,6 @@
from cheroot import makefile from cheroot import makefile
__metaclass__ = type
class MockSocket: class MockSocket:
"""A mock socket.""" """A mock socket."""

View file

@ -1,23 +1,18 @@
"""Tests for the HTTP server.""" """Tests for the HTTP server."""
# -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 :
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import os import os
import queue
import socket import socket
import tempfile import tempfile
import threading import threading
import uuid import uuid
import urllib.parse # noqa: WPS301
import pytest import pytest
import requests import requests
import requests_unixsocket import requests_unixsocket
import six
from pypytools.gc.custom import DefaultGc from pypytools.gc.custom import DefaultGc
from six.moves import queue, urllib
from .._compat import bton, ntob from .._compat import bton, ntob
from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS, SYS_PLATFORM from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS, SYS_PLATFORM
@ -259,12 +254,12 @@ def peercreds_enabled_server(http_server, unix_sock_file):
@unix_only_sock_test @unix_only_sock_test
@non_macos_sock_test @non_macos_sock_test
def test_peercreds_unix_sock(peercreds_enabled_server): def test_peercreds_unix_sock(http_request_timeout, peercreds_enabled_server):
"""Check that ``PEERCRED`` lookup works when enabled.""" """Check that ``PEERCRED`` lookup works when enabled."""
httpserver = peercreds_enabled_server httpserver = peercreds_enabled_server
bind_addr = httpserver.bind_addr bind_addr = httpserver.bind_addr
if isinstance(bind_addr, six.binary_type): if isinstance(bind_addr, bytes):
bind_addr = bind_addr.decode() bind_addr = bind_addr.decode()
# pylint: disable=possibly-unused-variable # pylint: disable=possibly-unused-variable
@ -275,11 +270,17 @@ def test_peercreds_unix_sock(peercreds_enabled_server):
expected_peercreds = '|'.join(map(str, expected_peercreds)) expected_peercreds = '|'.join(map(str, expected_peercreds))
with requests_unixsocket.monkeypatch(): with requests_unixsocket.monkeypatch():
peercreds_resp = requests.get(unix_base_uri + PEERCRED_IDS_URI) peercreds_resp = requests.get(
unix_base_uri + PEERCRED_IDS_URI,
timeout=http_request_timeout,
)
peercreds_resp.raise_for_status() peercreds_resp.raise_for_status()
assert peercreds_resp.text == expected_peercreds assert peercreds_resp.text == expected_peercreds
peercreds_text_resp = requests.get(unix_base_uri + PEERCRED_TEXTS_URI) peercreds_text_resp = requests.get(
unix_base_uri + PEERCRED_TEXTS_URI,
timeout=http_request_timeout,
)
assert peercreds_text_resp.status_code == 500 assert peercreds_text_resp.status_code == 500
@ -290,14 +291,17 @@ def test_peercreds_unix_sock(peercreds_enabled_server):
) )
@unix_only_sock_test @unix_only_sock_test
@non_macos_sock_test @non_macos_sock_test
def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server): def test_peercreds_unix_sock_with_lookup(
http_request_timeout,
peercreds_enabled_server,
):
"""Check that ``PEERCRED`` resolution works when enabled.""" """Check that ``PEERCRED`` resolution works when enabled."""
httpserver = peercreds_enabled_server httpserver = peercreds_enabled_server
httpserver.peercreds_resolve_enabled = True httpserver.peercreds_resolve_enabled = True
bind_addr = httpserver.bind_addr bind_addr = httpserver.bind_addr
if isinstance(bind_addr, six.binary_type): if isinstance(bind_addr, bytes):
bind_addr = bind_addr.decode() bind_addr = bind_addr.decode()
# pylint: disable=possibly-unused-variable # pylint: disable=possibly-unused-variable
@ -312,7 +316,10 @@ def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server):
) )
expected_textcreds = '!'.join(map(str, expected_textcreds)) expected_textcreds = '!'.join(map(str, expected_textcreds))
with requests_unixsocket.monkeypatch(): with requests_unixsocket.monkeypatch():
peercreds_text_resp = requests.get(unix_base_uri + PEERCRED_TEXTS_URI) peercreds_text_resp = requests.get(
unix_base_uri + PEERCRED_TEXTS_URI,
timeout=http_request_timeout,
)
peercreds_text_resp.raise_for_status() peercreds_text_resp.raise_for_status()
assert peercreds_text_resp.text == expected_textcreds assert peercreds_text_resp.text == expected_textcreds
@ -363,7 +370,10 @@ def test_high_number_of_file_descriptors(native_server_client, resource_limit):
assert any(fn >= resource_limit for fn in native_process_conn.filenos) assert any(fn >= resource_limit for fn in native_process_conn.filenos)
if not IS_WINDOWS: ISSUE511 = IS_MACOS
if not IS_WINDOWS and not ISSUE511:
test_high_number_of_file_descriptors = pytest.mark.forked( test_high_number_of_file_descriptors = pytest.mark.forked(
test_high_number_of_file_descriptors, test_high_number_of_file_descriptors,
) )

View file

@ -1,9 +1,4 @@
"""Tests for TLS support.""" """Tests for TLS support."""
# -*- coding: utf-8 -*-
# vim: set fileencoding=utf-8 :
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import functools import functools
import json import json
@ -14,11 +9,11 @@ import sys
import threading import threading
import time import time
import traceback import traceback
import http.client
import OpenSSL.SSL import OpenSSL.SSL
import pytest import pytest
import requests import requests
import six
import trustme import trustme
from .._compat import bton, ntob, ntou from .._compat import bton, ntob, ntou
@ -49,9 +44,6 @@ IS_PYOPENSSL_SSL_VERSION_1_0 = (
OpenSSL.SSL.SSLeay_version(OpenSSL.SSL.SSLEAY_VERSION). OpenSSL.SSL.SSLeay_version(OpenSSL.SSL.SSLEAY_VERSION).
startswith(b'OpenSSL 1.0.') startswith(b'OpenSSL 1.0.')
) )
PY27 = sys.version_info[:2] == (2, 7)
PY34 = sys.version_info[:2] == (3, 4)
PY3 = not six.PY2
PY310_PLUS = sys.version_info[:2] >= (3, 10) PY310_PLUS = sys.version_info[:2] >= (3, 10)
@ -64,13 +56,12 @@ _stdlib_to_openssl_verify = {
fails_under_py3 = pytest.mark.xfail( fails_under_py3 = pytest.mark.xfail(
not six.PY2,
reason='Fails under Python 3+', reason='Fails under Python 3+',
) )
fails_under_py3_in_pypy = pytest.mark.xfail( fails_under_py3_in_pypy = pytest.mark.xfail(
not six.PY2 and IS_PYPY, IS_PYPY,
reason='Fails under PyPy3', reason='Fails under PyPy3',
) )
@ -213,6 +204,7 @@ def thread_exceptions():
), ),
) )
def test_ssl_adapters( def test_ssl_adapters(
http_request_timeout,
tls_http_server, adapter_type, tls_http_server, adapter_type,
tls_certificate, tls_certificate,
tls_certificate_chain_pem_path, tls_certificate_chain_pem_path,
@ -241,6 +233,7 @@ def test_ssl_adapters(
resp = requests.get( resp = requests.get(
'https://{host!s}:{port!s}/'.format(host=interface, port=port), 'https://{host!s}:{port!s}/'.format(host=interface, port=port),
timeout=http_request_timeout,
verify=tls_ca_certificate_pem_path, verify=tls_ca_certificate_pem_path,
) )
@ -276,8 +269,9 @@ def test_ssl_adapters(
reason='Fails under PyPy in CI for unknown reason', reason='Fails under PyPy in CI for unknown reason',
strict=False, strict=False,
) )
def test_tls_client_auth( # noqa: C901 # FIXME def test_tls_client_auth( # noqa: C901, WPS213 # FIXME
# FIXME: remove twisted logic, separate tests # FIXME: remove twisted logic, separate tests
http_request_timeout,
mocker, mocker,
tls_http_server, adapter_type, tls_http_server, adapter_type,
ca, ca,
@ -331,6 +325,9 @@ def test_tls_client_auth( # noqa: C901 # FIXME
requests.get, requests.get,
'https://{host!s}:{port!s}/'.format(host=interface, port=port), 'https://{host!s}:{port!s}/'.format(host=interface, port=port),
# Don't wait for the first byte forever:
timeout=http_request_timeout,
# Server TLS certificate verification: # Server TLS certificate verification:
verify=tls_ca_certificate_pem_path, verify=tls_ca_certificate_pem_path,
@ -348,12 +345,13 @@ def test_tls_client_auth( # noqa: C901 # FIXME
and tls_verify_mode == ssl.CERT_REQUIRED and tls_verify_mode == ssl.CERT_REQUIRED
and tls_client_identity == 'localhost' and tls_client_identity == 'localhost'
and is_trusted_cert and is_trusted_cert
) or PY34: ):
pytest.xfail( pytest.xfail(
'OpenSSL 1.0 has problems with verifying client certs', 'OpenSSL 1.0 has problems with verifying client certs',
) )
assert is_req_successful assert is_req_successful
assert resp.text == 'Hello world!' assert resp.text == 'Hello world!'
resp.close()
return return
# xfail some flaky tests # xfail some flaky tests
@ -366,29 +364,16 @@ def test_tls_client_auth( # noqa: C901 # FIXME
if issue_237: if issue_237:
pytest.xfail('Test sometimes fails') pytest.xfail('Test sometimes fails')
expected_ssl_errors = ( expected_ssl_errors = requests.exceptions.SSLError,
requests.exceptions.SSLError,
OpenSSL.SSL.Error,
) if PY34 else (
requests.exceptions.SSLError,
)
if IS_WINDOWS or IS_GITHUB_ACTIONS_WORKFLOW: if IS_WINDOWS or IS_GITHUB_ACTIONS_WORKFLOW:
expected_ssl_errors += requests.exceptions.ConnectionError, expected_ssl_errors += requests.exceptions.ConnectionError,
with pytest.raises(expected_ssl_errors) as ssl_err: with pytest.raises(expected_ssl_errors) as ssl_err:
make_https_request() make_https_request().close()
if PY34 and isinstance(ssl_err, OpenSSL.SSL.Error):
pytest.xfail(
'OpenSSL behaves wierdly under Python 3.4 '
'because of an outdated urllib3',
)
try: try:
err_text = ssl_err.value.args[0].reason.args[0].args[0] err_text = ssl_err.value.args[0].reason.args[0].args[0]
except AttributeError: except AttributeError:
if PY34: if IS_WINDOWS or IS_GITHUB_ACTIONS_WORKFLOW:
pytest.xfail('OpenSSL behaves wierdly under Python 3.4')
elif IS_WINDOWS or IS_GITHUB_ACTIONS_WORKFLOW:
err_text = str(ssl_err.value) err_text = str(ssl_err.value)
else: else:
raise raise
@ -400,9 +385,8 @@ def test_tls_client_auth( # noqa: C901 # FIXME
'sslv3 alert bad certificate' if IS_LIBRESSL_BACKEND 'sslv3 alert bad certificate' if IS_LIBRESSL_BACKEND
else 'tlsv1 alert unknown ca', else 'tlsv1 alert unknown ca',
) )
if not six.PY2: if IS_MACOS and IS_PYPY and adapter_type == 'pyopenssl':
if IS_MACOS and IS_PYPY and adapter_type == 'pyopenssl': expected_substrings = ('tlsv1 alert unknown ca',)
expected_substrings = ('tlsv1 alert unknown ca',)
if ( if (
tls_verify_mode in ( tls_verify_mode in (
ssl.CERT_REQUIRED, ssl.CERT_REQUIRED,
@ -469,9 +453,9 @@ def test_tls_client_auth( # noqa: C901 # FIXME
pytest.param( pytest.param(
'builtin', 'builtin',
marks=pytest.mark.xfail( marks=pytest.mark.xfail(
IS_GITHUB_ACTIONS_WORKFLOW and IS_MACOS and PY310_PLUS, IS_MACOS and PY310_PLUS,
reason='Unclosed TLS resource warnings happen on macOS ' reason='Unclosed TLS resource warnings happen on macOS '
'under Python 3.10', 'under Python 3.10 (#508)',
strict=False, strict=False,
), ),
), ),
@ -492,6 +476,7 @@ def test_ssl_env( # noqa: C901 # FIXME
thread_exceptions, thread_exceptions,
recwarn, recwarn,
mocker, mocker,
http_request_timeout,
tls_http_server, adapter_type, tls_http_server, adapter_type,
ca, tls_verify_mode, tls_certificate, ca, tls_verify_mode, tls_certificate,
tls_certificate_chain_pem_path, tls_certificate_chain_pem_path,
@ -532,13 +517,10 @@ def test_ssl_env( # noqa: C901 # FIXME
resp = requests.get( resp = requests.get(
'https://' + interface + ':' + str(port) + '/env', 'https://' + interface + ':' + str(port) + '/env',
timeout=http_request_timeout,
verify=tls_ca_certificate_pem_path, verify=tls_ca_certificate_pem_path,
cert=cl_pem if use_client_cert else None, cert=cl_pem if use_client_cert else None,
) )
if PY34 and resp.status_code != 200:
pytest.xfail(
'Python 3.4 has problems with verifying client certs',
)
env = json.loads(resp.content.decode('utf-8')) env = json.loads(resp.content.decode('utf-8'))
@ -620,7 +602,7 @@ def test_https_over_http_error(http_server, ip_addr):
httpserver = http_server.send((ip_addr, EPHEMERAL_PORT)) httpserver = http_server.send((ip_addr, EPHEMERAL_PORT))
interface, _host, port = _get_conn_data(httpserver.bind_addr) interface, _host, port = _get_conn_data(httpserver.bind_addr)
with pytest.raises(ssl.SSLError) as ssl_err: with pytest.raises(ssl.SSLError) as ssl_err:
six.moves.http_client.HTTPSConnection( http.client.HTTPSConnection(
'{interface}:{port}'.format( '{interface}:{port}'.format(
interface=interface, interface=interface,
port=port, port=port,
@ -633,20 +615,10 @@ def test_https_over_http_error(http_server, ip_addr):
assert expected_substring in ssl_err.value.args[-1] assert expected_substring in ssl_err.value.args[-1]
http_over_https_error_builtin_marks = []
if IS_WINDOWS and six.PY2:
http_over_https_error_builtin_marks.append(
pytest.mark.flaky(reruns=5, reruns_delay=2),
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
'adapter_type', 'adapter_type',
( (
pytest.param( 'builtin',
'builtin',
marks=http_over_https_error_builtin_marks,
),
'pyopenssl', 'pyopenssl',
), ),
) )
@ -657,7 +629,9 @@ if IS_WINDOWS and six.PY2:
pytest.param(ANY_INTERFACE_IPV6, marks=missing_ipv6), pytest.param(ANY_INTERFACE_IPV6, marks=missing_ipv6),
), ),
) )
@pytest.mark.flaky(reruns=3, reruns_delay=2)
def test_http_over_https_error( def test_http_over_https_error(
http_request_timeout,
tls_http_server, adapter_type, tls_http_server, adapter_type,
ca, ip_addr, ca, ip_addr,
tls_certificate, tls_certificate,
@ -697,36 +671,12 @@ def test_http_over_https_error(
expect_fallback_response_over_plain_http = ( expect_fallback_response_over_plain_http = (
( (
adapter_type == 'pyopenssl' adapter_type == 'pyopenssl'
and (IS_ABOVE_OPENSSL10 or not six.PY2)
) )
or PY27
) or (
IS_GITHUB_ACTIONS_WORKFLOW
and IS_WINDOWS
and six.PY2
and not IS_WIN2016
) )
if (
IS_GITHUB_ACTIONS_WORKFLOW
and IS_WINDOWS
and six.PY2
and IS_WIN2016
and adapter_type == 'builtin'
and ip_addr is ANY_INTERFACE_IPV6
):
expect_fallback_response_over_plain_http = True
if (
IS_GITHUB_ACTIONS_WORKFLOW
and IS_WINDOWS
and six.PY2
and not IS_WIN2016
and adapter_type == 'builtin'
and ip_addr is not ANY_INTERFACE_IPV6
):
expect_fallback_response_over_plain_http = False
if expect_fallback_response_over_plain_http: if expect_fallback_response_over_plain_http:
resp = requests.get( resp = requests.get(
'http://{host!s}:{port!s}/'.format(host=fqdn, port=port), 'http://{host!s}:{port!s}/'.format(host=fqdn, port=port),
timeout=http_request_timeout,
) )
assert resp.status_code == 400 assert resp.status_code == 400
assert resp.text == ( assert resp.text == (
@ -738,6 +688,7 @@ def test_http_over_https_error(
with pytest.raises(requests.exceptions.ConnectionError) as ssl_err: with pytest.raises(requests.exceptions.ConnectionError) as ssl_err:
requests.get( # FIXME: make stdlib ssl behave like PyOpenSSL requests.get( # FIXME: make stdlib ssl behave like PyOpenSSL
'http://{host!s}:{port!s}/'.format(host=fqdn, port=port), 'http://{host!s}:{port!s}/'.format(host=fqdn, port=port),
timeout=http_request_timeout,
) )
if IS_LINUX: if IS_LINUX:

View file

@ -37,6 +37,7 @@ def simple_wsgi_server():
yield locals() yield locals()
@pytest.mark.flaky(reruns=3, reruns_delay=2)
def test_connection_keepalive(simple_wsgi_server): def test_connection_keepalive(simple_wsgi_server):
"""Test the connection keepalive works (duh).""" """Test the connection keepalive works (duh)."""
session = Session(base_url=simple_wsgi_server['url']) session = Session(base_url=simple_wsgi_server['url'])
@ -59,6 +60,7 @@ def test_connection_keepalive(simple_wsgi_server):
] ]
failures = sum(task.result() for task in tasks) failures = sum(task.result() for task in tasks)
session.close()
assert not failures assert not failures

View file

@ -15,9 +15,6 @@ the traceback to stdout, and keep any assertions you have from running
be of further significance to your tests). be of further significance to your tests).
""" """
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import pprint import pprint
import re import re
import socket import socket
@ -29,9 +26,8 @@ import json
import unittest # pylint: disable=deprecated-module,preferred-module import unittest # pylint: disable=deprecated-module,preferred-module
import warnings import warnings
import functools import functools
import http.client
from six.moves import http_client, map, urllib_parse import urllib.parse
import six
from more_itertools.more import always_iterable from more_itertools.more import always_iterable
import jaraco.functools import jaraco.functools
@ -105,7 +101,7 @@ class WebCase(unittest.TestCase):
HOST = '127.0.0.1' HOST = '127.0.0.1'
PORT = 8000 PORT = 8000
HTTP_CONN = http_client.HTTPConnection HTTP_CONN = http.client.HTTPConnection
PROTOCOL = 'HTTP/1.1' PROTOCOL = 'HTTP/1.1'
scheme = 'http' scheme = 'http'
@ -127,7 +123,7 @@ class WebCase(unittest.TestCase):
* from :py:mod:`python:http.client`. * from :py:mod:`python:http.client`.
""" """
cls_name = '{scheme}Connection'.format(scheme=self.scheme.upper()) cls_name = '{scheme}Connection'.format(scheme=self.scheme.upper())
return getattr(http_client, cls_name) return getattr(http.client, cls_name)
def get_conn(self, auto_open=False): def get_conn(self, auto_open=False):
"""Return a connection to our HTTP server.""" """Return a connection to our HTTP server."""
@ -201,9 +197,9 @@ class WebCase(unittest.TestCase):
""" """
ServerError.on = False ServerError.on = False
if isinstance(url, six.text_type): if isinstance(url, str):
url = url.encode('utf-8') url = url.encode('utf-8')
if isinstance(body, six.text_type): if isinstance(body, str):
body = body.encode('utf-8') body = body.encode('utf-8')
# for compatibility, support raise_subcls is None # for compatibility, support raise_subcls is None
@ -386,7 +382,7 @@ class WebCase(unittest.TestCase):
def assertBody(self, value, msg=None): def assertBody(self, value, msg=None):
"""Fail if value != self.body.""" """Fail if value != self.body."""
if isinstance(value, six.text_type): if isinstance(value, str):
value = value.encode(self.encoding) value = value.encode(self.encoding)
if value != self.body: if value != self.body:
if msg is None: if msg is None:
@ -397,7 +393,7 @@ class WebCase(unittest.TestCase):
def assertInBody(self, value, msg=None): def assertInBody(self, value, msg=None):
"""Fail if value not in self.body.""" """Fail if value not in self.body."""
if isinstance(value, six.text_type): if isinstance(value, str):
value = value.encode(self.encoding) value = value.encode(self.encoding)
if value not in self.body: if value not in self.body:
if msg is None: if msg is None:
@ -406,7 +402,7 @@ class WebCase(unittest.TestCase):
def assertNotInBody(self, value, msg=None): def assertNotInBody(self, value, msg=None):
"""Fail if value in self.body.""" """Fail if value in self.body."""
if isinstance(value, six.text_type): if isinstance(value, str):
value = value.encode(self.encoding) value = value.encode(self.encoding)
if value in self.body: if value in self.body:
if msg is None: if msg is None:
@ -415,7 +411,7 @@ class WebCase(unittest.TestCase):
def assertMatchesBody(self, pattern, msg=None, flags=0): def assertMatchesBody(self, pattern, msg=None, flags=0):
"""Fail if value (a regex pattern) is not in self.body.""" """Fail if value (a regex pattern) is not in self.body."""
if isinstance(pattern, six.text_type): if isinstance(pattern, str):
pattern = pattern.encode(self.encoding) pattern = pattern.encode(self.encoding)
if re.search(pattern, self.body, flags) is None: if re.search(pattern, self.body, flags) is None:
if msg is None: if msg is None:
@ -464,25 +460,7 @@ def shb(response):
"""Return status, headers, body the way we like from a response.""" """Return status, headers, body the way we like from a response."""
resp_status_line = '%s %s' % (response.status, response.reason) resp_status_line = '%s %s' % (response.status, response.reason)
if not six.PY2: return resp_status_line, response.getheaders(), response.read()
return resp_status_line, response.getheaders(), response.read()
h = []
key, value = None, None
for line in response.msg.headers:
if line:
if line[0] in ' \t':
value += line.strip()
else:
if key and value:
h.append((key, value))
key, value = line.split(':', 1)
key = key.strip()
value = value.strip()
if key and value:
h.append((key, value))
return resp_status_line, h, response.read()
# def openURL(*args, raise_subcls=(), **kwargs): # def openURL(*args, raise_subcls=(), **kwargs):
@ -514,7 +492,7 @@ def openURL(*args, **kwargs):
def _open_url_once( def _open_url_once(
url, headers=None, method='GET', body=None, url, headers=None, method='GET', body=None,
host='127.0.0.1', port=8000, http_conn=http_client.HTTPConnection, host='127.0.0.1', port=8000, http_conn=http.client.HTTPConnection,
protocol='HTTP/1.1', ssl_context=None, protocol='HTTP/1.1', ssl_context=None,
): ):
"""Open the given HTTP resource and return status, headers, and body.""" """Open the given HTTP resource and return status, headers, and body."""
@ -530,7 +508,7 @@ def _open_url_once(
conn = http_conn(interface(host), port, **kw) conn = http_conn(interface(host), port, **kw)
conn._http_vsn_str = protocol conn._http_vsn_str = protocol
conn._http_vsn = int(''.join([x for x in protocol if x.isdigit()])) conn._http_vsn = int(''.join([x for x in protocol if x.isdigit()]))
if not six.PY2 and isinstance(url, bytes): if isinstance(url, bytes):
url = url.decode() url = url.decode()
conn.putrequest( conn.putrequest(
method.upper(), url, skip_host=True, method.upper(), url, skip_host=True,
@ -572,10 +550,10 @@ def strip_netloc(url):
>>> strip_netloc('/foo/bar?bing#baz') >>> strip_netloc('/foo/bar?bing#baz')
'/foo/bar?bing' '/foo/bar?bing'
""" """
parsed = urllib_parse.urlparse(url) parsed = urllib.parse.urlparse(url)
_scheme, _netloc, path, params, query, _fragment = parsed _scheme, _netloc, path, params, query, _fragment = parsed
stripped = '', '', path, params, query, '' stripped = '', '', path, params, query, ''
return urllib_parse.urlunparse(stripped) return urllib.parse.urlunparse(stripped)
# Add any exceptions which your web framework handles # Add any exceptions which your web framework handles

View file

@ -1,16 +1,13 @@
"""Pytest fixtures and other helpers for doing testing by end-users.""" """Pytest fixtures and other helpers for doing testing by end-users."""
from __future__ import absolute_import, division, print_function from contextlib import closing, contextmanager
__metaclass__ = type
from contextlib import closing
import errno import errno
import socket import socket
import threading import threading
import time import time
import http.client
import pytest import pytest
from six.moves import http_client
import cheroot.server import cheroot.server
from cheroot.test import webtest from cheroot.test import webtest
@ -33,6 +30,7 @@ config = {
} }
@contextmanager
def cheroot_server(server_factory): def cheroot_server(server_factory):
"""Set up and tear down a Cheroot server instance.""" """Set up and tear down a Cheroot server instance."""
conf = config[server_factory].copy() conf = config[server_factory].copy()
@ -64,14 +62,14 @@ def cheroot_server(server_factory):
@pytest.fixture @pytest.fixture
def wsgi_server(): def wsgi_server():
"""Set up and tear down a Cheroot WSGI server instance.""" """Set up and tear down a Cheroot WSGI server instance."""
for srv in cheroot_server(cheroot.wsgi.Server): with cheroot_server(cheroot.wsgi.Server) as srv:
yield srv yield srv
@pytest.fixture @pytest.fixture
def native_server(): def native_server():
"""Set up and tear down a Cheroot HTTP server instance.""" """Set up and tear down a Cheroot HTTP server instance."""
for srv in cheroot_server(cheroot.server.HTTPServer): with cheroot_server(cheroot.server.HTTPServer) as srv:
yield srv yield srv
@ -89,9 +87,9 @@ class _TestClient:
port=self._port, port=self._port,
) )
conn_cls = ( conn_cls = (
http_client.HTTPConnection http.client.HTTPConnection
if self.server_instance.ssl_adapter is None else if self.server_instance.ssl_adapter is None else
http_client.HTTPSConnection http.client.HTTPSConnection
) )
return conn_cls(name) return conn_cls(name)

View file

@ -5,17 +5,12 @@
joinable joinable
""" """
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import collections import collections
import threading import threading
import time import time
import socket import socket
import warnings import warnings
import queue
from six.moves import queue
from jaraco.functools import pass_none from jaraco.functools import pass_none
@ -178,7 +173,7 @@ class ThreadPool:
for worker in self._threads: for worker in self._threads:
worker.name = ( worker.name = (
'CP Server {worker_name!s}'. 'CP Server {worker_name!s}'.
format(worker_name=worker.name), format(worker_name=worker.name)
) )
worker.start() worker.start()
for worker in self._threads: for worker in self._threads:
@ -228,7 +223,7 @@ class ThreadPool:
worker = WorkerThread(self.server) worker = WorkerThread(self.server)
worker.name = ( worker.name = (
'CP Server {worker_name!s}'. 'CP Server {worker_name!s}'.
format(worker_name=worker.name), format(worker_name=worker.name)
) )
worker.start() worker.start()
return worker return worker

View file

@ -25,14 +25,8 @@ as you want in one instance by using a PathInfoDispatcher::
server = wsgi.Server(addr, d) server = wsgi.Server(addr, d)
""" """
from __future__ import absolute_import, division, print_function
__metaclass__ = type
import sys import sys
import six
from six.moves import filter
from . import server from . import server
from .workers import threadpool from .workers import threadpool
from ._compat import ntob, bton from ._compat import ntob, bton
@ -140,7 +134,7 @@ class Gateway(server.Gateway):
response = self.req.server.wsgi_app(self.env, self.start_response) response = self.req.server.wsgi_app(self.env, self.start_response)
try: try:
for chunk in filter(None, response): for chunk in filter(None, response):
if not isinstance(chunk, six.binary_type): if not isinstance(chunk, bytes):
raise ValueError('WSGI Applications must yield bytes') raise ValueError('WSGI Applications must yield bytes')
self.write(chunk) self.write(chunk)
finally: finally:
@ -149,7 +143,7 @@ class Gateway(server.Gateway):
if hasattr(response, 'close'): if hasattr(response, 'close'):
response.close() response.close()
def start_response(self, status, headers, exc_info=None): def start_response(self, status, headers, exc_info=None): # noqa: WPS238
"""WSGI callable to begin the HTTP response.""" """WSGI callable to begin the HTTP response."""
# "The application may call start_response more than once, # "The application may call start_response more than once,
# if and only if the exc_info argument is provided." # if and only if the exc_info argument is provided."
@ -164,10 +158,8 @@ class Gateway(server.Gateway):
# sent, start_response must raise an error, and should raise the # sent, start_response must raise an error, and should raise the
# exc_info tuple." # exc_info tuple."
if self.req.sent_headers: if self.req.sent_headers:
try: value = exc_info[1]
six.reraise(*exc_info) raise value
finally:
exc_info = None
self.req.status = self._encode_status(status) self.req.status = self._encode_status(status)
@ -196,8 +188,6 @@ class Gateway(server.Gateway):
must be of type "str" but are restricted to code points in the must be of type "str" but are restricted to code points in the
"Latin-1" set. "Latin-1" set.
""" """
if six.PY2:
return status
if not isinstance(status, str): if not isinstance(status, str):
raise TypeError('WSGI response status is not of type str.') raise TypeError('WSGI response status is not of type str.')
return status.encode('ISO-8859-1') return status.encode('ISO-8859-1')
@ -273,7 +263,7 @@ class Gateway_10(Gateway):
'wsgi.version': self.version, 'wsgi.version': self.version,
} }
if isinstance(req.server.bind_addr, six.string_types): if isinstance(req.server.bind_addr, str):
# AF_UNIX. This isn't really allowed by WSGI, which doesn't # AF_UNIX. This isn't really allowed by WSGI, which doesn't
# address unix domain sockets. But it's better than nothing. # address unix domain sockets. But it's better than nothing.
env['SERVER_PORT'] = '' env['SERVER_PORT'] = ''
@ -332,10 +322,10 @@ class Gateway_u0(Gateway_10):
"""Return a new environ dict targeting the given wsgi.version.""" """Return a new environ dict targeting the given wsgi.version."""
req = self.req req = self.req
env_10 = super(Gateway_u0, self).get_environ() env_10 = super(Gateway_u0, self).get_environ()
env = dict(map(self._decode_key, env_10.items())) env = dict(env_10.items())
# Request-URI # Request-URI
enc = env.setdefault(six.u('wsgi.url_encoding'), six.u('utf-8')) enc = env.setdefault('wsgi.url_encoding', 'utf-8')
try: try:
env['PATH_INFO'] = req.path.decode(enc) env['PATH_INFO'] = req.path.decode(enc)
env['QUERY_STRING'] = req.qs.decode(enc) env['QUERY_STRING'] = req.qs.decode(enc)
@ -345,25 +335,10 @@ class Gateway_u0(Gateway_10):
env['PATH_INFO'] = env_10['PATH_INFO'] env['PATH_INFO'] = env_10['PATH_INFO']
env['QUERY_STRING'] = env_10['QUERY_STRING'] env['QUERY_STRING'] = env_10['QUERY_STRING']
env.update(map(self._decode_value, env.items())) env.update(env.items())
return env return env
@staticmethod
def _decode_key(item):
k, v = item
if six.PY2:
k = k.decode('ISO-8859-1')
return k, v
@staticmethod
def _decode_value(item):
k, v = item
skip_keys = 'REQUEST_URI', 'wsgi.input'
if not six.PY2 or not isinstance(v, bytes) or k in skip_keys:
return k, v
return k, v.decode('ISO-8859-1')
wsgi_gateways = Gateway.gateway_map() wsgi_gateways = Gateway.gateway_map()

View file

@ -40,3 +40,10 @@ class PathInfoDispatcher:
apps: Any apps: Any
def __init__(self, apps): ... def __init__(self, apps): ...
def __call__(self, environ, start_response): ... def __call__(self, environ, start_response): ...
WSGIServer = Server
WSGIGateway = Gateway
WSGIGateway_u0 = Gateway_u0
WSGIGateway_10 = Gateway_10
WSGIPathInfoDispatcher = PathInfoDispatcher

View file

@ -7,7 +7,7 @@ backports.zoneinfo==0.2.1
beautifulsoup4==4.11.1 beautifulsoup4==4.11.1
bleach==5.0.1 bleach==5.0.1
certifi==2022.9.24 certifi==2022.9.24
cheroot==8.6.0 cheroot==9.0.0
cherrypy==18.8.0 cherrypy==18.8.0
cloudinary==1.30.0 cloudinary==1.30.0
distro==1.8.0 distro==1.8.0