diff --git a/lib/cheroot/_compat.py b/lib/cheroot/_compat.py index 79899b9d..10dcdefa 100644 --- a/lib/cheroot/_compat.py +++ b/lib/cheroot/_compat.py @@ -1,13 +1,20 @@ +# pylint: disable=unused-import """Compatibility code for using Cheroot with various versions of Python.""" from __future__ import absolute_import, division, print_function __metaclass__ = type +import os import platform import re import six +try: + import selectors # lgtm [py/unused-import] +except ImportError: + import selectors2 as selectors # noqa: F401 # lgtm [py/unused-import] + try: import ssl IS_ABOVE_OPENSSL10 = ssl.OPENSSL_VERSION_INFO >= (1, 1) @@ -15,6 +22,24 @@ try: except ImportError: IS_ABOVE_OPENSSL10 = None +# contextlib.suppress was added in Python 3.4 +try: + from contextlib import suppress +except ImportError: + from contextlib import contextmanager + + @contextmanager + def suppress(*exceptions): + """Return a context manager that suppresses the `exceptions`.""" + try: + yield + except exceptions: + pass + + +IS_CI = bool(os.getenv('CI')) +IS_GITHUB_ACTIONS_WORKFLOW = bool(os.getenv('GITHUB_WORKFLOW')) + IS_PYPY = platform.python_implementation() == 'PyPy' @@ -36,7 +61,7 @@ if not six.PY2: return n.encode(encoding) def ntou(n, encoding='ISO-8859-1'): - """Return the native string as unicode with the given encoding.""" + """Return the native string as Unicode with the given encoding.""" assert_native(n) # In Python 3, the native string type is unicode return n @@ -55,7 +80,7 @@ else: return n def ntou(n, encoding='ISO-8859-1'): - """Return the native string as unicode with the given encoding.""" + """Return the native string as Unicode with the given encoding.""" assert_native(n) # In Python 2, the native string type is bytes. # First, check for the special encoding 'escape'. The test suite uses @@ -78,7 +103,7 @@ else: def assert_native(n): - """Check whether the input is of nativ ``str`` type. + """Check whether the input is of native :py:class:`str` type. 
Raises: TypeError: in case of failed check @@ -89,22 +114,35 @@ def assert_native(n): if not six.PY2: - """Python 3 has memoryview builtin.""" + """Python 3 has :py:class:`memoryview` builtin.""" # Python 2.7 has it backported, but socket.write() does # str(memoryview(b'0' * 100)) -> # instead of accessing it correctly. memoryview = memoryview else: - """Link memoryview to buffer under Python 2.""" + """Link :py:class:`memoryview` to buffer under Python 2.""" memoryview = buffer # noqa: F821 def extract_bytes(mv): - """Retrieve bytes out of memoryview/buffer or bytes.""" + r"""Retrieve bytes out of the given input buffer. + + :param mv: input :py:func:`buffer` + :type mv: memoryview or bytes + + :return: unwrapped bytes + :rtype: bytes + + :raises ValueError: if the input is not one of \ + :py:class:`memoryview`/:py:func:`buffer` \ + or :py:class:`bytes` + """ if isinstance(mv, memoryview): return bytes(mv) if six.PY2 else mv.tobytes() if isinstance(mv, bytes): return mv - raise ValueError + raise ValueError( + 'extract_bytes() only accepts bytes and memoryview/buffer', + ) diff --git a/lib/cheroot/cli.py b/lib/cheroot/cli.py index f46e7dea..4607e226 100644 --- a/lib/cheroot/cli.py +++ b/lib/cheroot/cli.py @@ -1,36 +1,42 @@ """Command line tool for starting a Cheroot WSGI/HTTP server instance. -Basic usage:: +Basic usage: - # Start a server on 127.0.0.1:8000 with the default settings - # for the WSGI app myapp/wsgi.py:application() - cheroot myapp.wsgi +.. 
code-block:: shell-session - # Start a server on 0.0.0.0:9000 with 8 threads - # for the WSGI app myapp/wsgi.py:main_app() - cheroot myapp.wsgi:main_app --bind 0.0.0.0:9000 --threads 8 + $ # Start a server on 127.0.0.1:8000 with the default settings + $ # for the WSGI app myapp/wsgi.py:application() + $ cheroot myapp.wsgi - # Start a server for the cheroot.server.Gateway subclass - # myapp/gateway.py:HTTPGateway - cheroot myapp.gateway:HTTPGateway + $ # Start a server on 0.0.0.0:9000 with 8 threads + $ # for the WSGI app myapp/wsgi.py:main_app() + $ cheroot myapp.wsgi:main_app --bind 0.0.0.0:9000 --threads 8 - # Start a server on the UNIX socket /var/spool/myapp.sock - cheroot myapp.wsgi --bind /var/spool/myapp.sock + $ # Start a server for the cheroot.server.Gateway subclass + $ # myapp/gateway.py:HTTPGateway + $ cheroot myapp.gateway:HTTPGateway - # Start a server on the abstract UNIX socket CherootServer - cheroot myapp.wsgi --bind @CherootServer + $ # Start a server on the UNIX socket /var/spool/myapp.sock + $ cheroot myapp.wsgi --bind /var/spool/myapp.sock + + $ # Start a server on the abstract UNIX socket CherootServer + $ cheroot myapp.wsgi --bind @CherootServer + +.. spelling:: + + cli """ import argparse from importlib import import_module import os import sys -import contextlib import six from . import server from . 
import wsgi +from ._compat import suppress __metaclass__ = type @@ -49,6 +55,7 @@ class TCPSocket(BindLocation): Args: address (str): Host name or IP address port (int): TCP port number + """ self.bind_addr = address, port @@ -64,9 +71,9 @@ class UnixSocket(BindLocation): class AbstractSocket(BindLocation): """AbstractSocket.""" - def __init__(self, addr): + def __init__(self, abstract_socket): """Initialize.""" - self.bind_addr = '\0{}'.format(self.abstract_socket) + self.bind_addr = '\x00{sock_path}'.format(sock_path=abstract_socket) class Application: @@ -77,8 +84,8 @@ class Application: """Read WSGI app/Gateway path string and import application module.""" mod_path, _, app_path = full_path.partition(':') app = getattr(import_module(mod_path), app_path or 'application') - - with contextlib.suppress(TypeError): + # suppress the `TypeError` exception, just in case `app` is not a class + with suppress(TypeError): if issubclass(app, server.Gateway): return GatewayYo(app) @@ -128,8 +135,17 @@ class GatewayYo: def parse_wsgi_bind_location(bind_addr_string): """Convert bind address string to a BindLocation.""" + # if the string begins with an @ symbol, use an abstract socket, + # this is the first condition to verify, otherwise the urlparse + # validation would detect //@ as a valid url with a hostname + # with value: "" and port: None + if bind_addr_string.startswith('@'): + return AbstractSocket(bind_addr_string[1:]) + # try and match for an IP/hostname and port - match = six.moves.urllib.parse.urlparse('//{}'.format(bind_addr_string)) + match = six.moves.urllib.parse.urlparse( + '//{addr}'.format(addr=bind_addr_string), + ) try: addr = match.hostname port = match.port @@ -139,9 +155,6 @@ def parse_wsgi_bind_location(bind_addr_string): pass # else, assume a UNIX socket path - # if the string begins with an @ symbol, use an abstract socket - if bind_addr_string.startswith('@'): - return AbstractSocket(bind_addr_string[1:]) return UnixSocket(path=bind_addr_string) @@ 
-151,70 +164,70 @@ def parse_wsgi_bind_addr(bind_addr_string): _arg_spec = { - '_wsgi_app': dict( - metavar='APP_MODULE', - type=Application.resolve, - help='WSGI application callable or cheroot.server.Gateway subclass', - ), - '--bind': dict( - metavar='ADDRESS', - dest='bind_addr', - type=parse_wsgi_bind_addr, - default='[::1]:8000', - help='Network interface to listen on (default: [::1]:8000)', - ), - '--chdir': dict( - metavar='PATH', - type=os.chdir, - help='Set the working directory', - ), - '--server-name': dict( - dest='server_name', - type=str, - help='Web server name to be advertised via Server HTTP header', - ), - '--threads': dict( - metavar='INT', - dest='numthreads', - type=int, - help='Minimum number of worker threads', - ), - '--max-threads': dict( - metavar='INT', - dest='max', - type=int, - help='Maximum number of worker threads', - ), - '--timeout': dict( - metavar='INT', - dest='timeout', - type=int, - help='Timeout in seconds for accepted connections', - ), - '--shutdown-timeout': dict( - metavar='INT', - dest='shutdown_timeout', - type=int, - help='Time in seconds to wait for worker threads to cleanly exit', - ), - '--request-queue-size': dict( - metavar='INT', - dest='request_queue_size', - type=int, - help='Maximum number of queued connections', - ), - '--accepted-queue-size': dict( - metavar='INT', - dest='accepted_queue_size', - type=int, - help='Maximum number of active requests in queue', - ), - '--accepted-queue-timeout': dict( - metavar='INT', - dest='accepted_queue_timeout', - type=int, - help='Timeout in seconds for putting requests into queue', - ), + '_wsgi_app': { + 'metavar': 'APP_MODULE', + 'type': Application.resolve, + 'help': 'WSGI application callable or cheroot.server.Gateway subclass', + }, + '--bind': { + 'metavar': 'ADDRESS', + 'dest': 'bind_addr', + 'type': parse_wsgi_bind_addr, + 'default': '[::1]:8000', + 'help': 'Network interface to listen on (default: [::1]:8000)', + }, + '--chdir': { + 'metavar': 'PATH', + 'type': 
os.chdir, + 'help': 'Set the working directory', + }, + '--server-name': { + 'dest': 'server_name', + 'type': str, + 'help': 'Web server name to be advertised via Server HTTP header', + }, + '--threads': { + 'metavar': 'INT', + 'dest': 'numthreads', + 'type': int, + 'help': 'Minimum number of worker threads', + }, + '--max-threads': { + 'metavar': 'INT', + 'dest': 'max', + 'type': int, + 'help': 'Maximum number of worker threads', + }, + '--timeout': { + 'metavar': 'INT', + 'dest': 'timeout', + 'type': int, + 'help': 'Timeout in seconds for accepted connections', + }, + '--shutdown-timeout': { + 'metavar': 'INT', + 'dest': 'shutdown_timeout', + 'type': int, + 'help': 'Time in seconds to wait for worker threads to cleanly exit', + }, + '--request-queue-size': { + 'metavar': 'INT', + 'dest': 'request_queue_size', + 'type': int, + 'help': 'Maximum number of queued connections', + }, + '--accepted-queue-size': { + 'metavar': 'INT', + 'dest': 'accepted_queue_size', + 'type': int, + 'help': 'Maximum number of active requests in queue', + }, + '--accepted-queue-timeout': { + 'metavar': 'INT', + 'dest': 'accepted_queue_timeout', + 'type': int, + 'help': 'Timeout in seconds for putting requests into queue', + }, } diff --git a/lib/cheroot/connections.py b/lib/cheroot/connections.py index 943ac65a..7debcbfd 100644 --- a/lib/cheroot/connections.py +++ b/lib/cheroot/connections.py @@ -5,11 +5,13 @@ __metaclass__ = type import io import os -import select import socket +import threading import time from . import errors +from ._compat import selectors +from ._compat import suppress from .makefile import MakeFile import six @@ -47,6 +49,69 @@ else: fcntl.fcntl(fd, fcntl.F_SETFD, old_flags | fcntl.FD_CLOEXEC) +class _ThreadsafeSelector: + """Thread-safe wrapper around a DefaultSelector. 
+ + There are 2 thread contexts in which it may be accessed: + * the selector thread + * one of the worker threads in workers/threadpool.py + + The expected read/write patterns are: + * :py:func:`~iter`: selector thread + * :py:meth:`register`: selector thread and threadpool, + via :py:meth:`~cheroot.workers.threadpool.ThreadPool.put` + * :py:meth:`unregister`: selector thread only + + Notably, this means :py:class:`_ThreadsafeSelector` never needs to worry + that connections will be removed behind its back. + + The lock is held when iterating or modifying the selector but is not + required when :py:meth:`select()ing ` on it. + """ + + def __init__(self): + self._selector = selectors.DefaultSelector() + self._lock = threading.Lock() + + def __len__(self): + with self._lock: + return len(self._selector.get_map() or {}) + + @property + def connections(self): + """Retrieve connections registered with the selector.""" + with self._lock: + mapping = self._selector.get_map() or {} + for _, (_, sock_fd, _, conn) in mapping.items(): + yield (sock_fd, conn) + + def register(self, fileobj, events, data=None): + """Register ``fileobj`` with the selector.""" + with self._lock: + return self._selector.register(fileobj, events, data) + + def unregister(self, fileobj): + """Unregister ``fileobj`` from the selector.""" + with self._lock: + return self._selector.unregister(fileobj) + + def select(self, timeout=None): + """Return socket fd and data pairs from selectors.select call. + + Returns entries ready to read in the form: + (socket_file_descriptor, connection) + """ + return ( + (key.fd, key.data) + for key, _ in self._selector.select(timeout=timeout) + ) + + def close(self): + """Close the selector.""" + with self._lock: + self._selector.close() + + class ConnectionManager: """Class which manages HTTPConnection objects. @@ -60,21 +125,34 @@ class ConnectionManager: server (cheroot.server.HTTPServer): web server object that uses this ConnectionManager instance. 
""" + self._serving = False + self._stop_requested = False + self.server = server - self.connections = [] + self._selector = _ThreadsafeSelector() + + self._selector.register( + server.socket.fileno(), + selectors.EVENT_READ, data=server, + ) def put(self, conn): """Put idle connection into the ConnectionManager to be managed. - Args: - conn (cheroot.server.HTTPConnection): HTTP connection - to be managed. + :param conn: HTTP connection to be managed + :type conn: cheroot.server.HTTPConnection """ conn.last_used = time.time() - conn.ready_with_data = conn.rfile.has_data() - self.connections.append(conn) + # if this conn doesn't have any more data waiting to be read, + # register it with the selector. + if conn.rfile.has_data(): + self.server.process_conn(conn) + else: + self._selector.register( + conn.socket.fileno(), selectors.EVENT_READ, data=conn, + ) - def expire(self): + def _expire(self): """Expire least recently used connections. This happens if there are either too many open connections, or if the @@ -82,107 +160,102 @@ class ConnectionManager: This should be called periodically. """ - if not self.connections: - return + # find any connections still registered with the selector + # that have not been active recently enough. + threshold = time.time() - self.server.timeout + timed_out_connections = [ + (sock_fd, conn) + for (sock_fd, conn) in self._selector.connections + if conn != self.server and conn.last_used < threshold + ] + for sock_fd, conn in timed_out_connections: + self._selector.unregister(sock_fd) + conn.close() - # Look at the first connection - if it can be closed, then do - # that, and wait for get_conn to return it. - conn = self.connections[0] - if conn.closeable: - return + def stop(self): + """Stop the selector loop in run() synchronously. - # Too many connections? - ka_limit = self.server.keep_alive_conn_limit - if ka_limit is not None and len(self.connections) > ka_limit: - conn.closeable = True - return + May take up to half a second. 
+ """ + self._stop_requested = True + while self._serving: + time.sleep(0.01) - # Connection too old? - if (conn.last_used + self.server.timeout) < time.time(): - conn.closeable = True - return - - def get_conn(self, server_socket): - """Return a HTTPConnection object which is ready to be handled. - - A connection returned by this method should be ready for a worker - to handle it. If there are no connections ready, None will be - returned. - - Any connection returned by this method will need to be `put` - back if it should be examined again for another request. + def run(self, expiration_interval): + """Run the connections selector indefinitely. Args: - server_socket (socket.socket): Socket to listen to for new - connections. - Returns: - cheroot.server.HTTPConnection instance, or None. + expiration_interval (float): Interval, in seconds, at which + connections will be checked for expiration. + Connections that are ready to process are submitted via + self.server.process_conn() + + Connections submitted for processing must be `put()` + back if they should be examined again for another request. + + Can be shut down by calling `stop()`. """ - # Grab file descriptors from sockets, but stop if we find a - # connection which is already marked as ready. - socket_dict = {} - for conn in self.connections: - if conn.closeable or conn.ready_with_data: - break - socket_dict[conn.socket.fileno()] = conn - else: - # No ready connection. - conn = None - - # We have a connection ready for use. - if conn: - self.connections.remove(conn) - return conn - - # Will require a select call. - ss_fileno = server_socket.fileno() - socket_dict[ss_fileno] = server_socket + self._serving = True try: - rlist, _, _ = select.select(list(socket_dict), [], [], 0.1) - # No available socket. - if not rlist: - return None - except OSError: - # Mark any connection which no longer appears valid. 
- for fno, conn in list(socket_dict.items()): - # If the server socket is invalid, we'll just ignore it and - # wait to be shutdown. - if fno == ss_fileno: - continue - try: - os.fstat(fno) - except OSError: - # Socket is invalid, close the connection, insert at - # the front. - self.connections.remove(conn) - self.connections.insert(0, conn) - conn.closeable = True + self._run(expiration_interval) + finally: + self._serving = False - # Wait for the next tick to occur. - return None + def _run(self, expiration_interval): + last_expiration_check = time.time() - try: - # See if we have a new connection coming in. - rlist.remove(ss_fileno) - except ValueError: - # No new connection, but reuse existing socket. - conn = socket_dict[rlist.pop()] - else: - conn = server_socket + while not self._stop_requested: + try: + active_list = self._selector.select(timeout=0.01) + except OSError: + self._remove_invalid_sockets() + continue - # All remaining connections in rlist should be marked as ready. - for fno in rlist: - socket_dict[fno].ready_with_data = True + for (sock_fd, conn) in active_list: + if conn is self.server: + # New connection + new_conn = self._from_server_socket(self.server.socket) + if new_conn is not None: + self.server.process_conn(new_conn) + else: + # unregister connection from the selector until the server + # has read from it and returned it via put() + self._selector.unregister(sock_fd) + self.server.process_conn(conn) - # New connection. - if conn is server_socket: - return self._from_server_socket(server_socket) + now = time.time() + if (now - last_expiration_check) > expiration_interval: + self._expire() + last_expiration_check = now - self.connections.remove(conn) - return conn + def _remove_invalid_sockets(self): + """Clean up the resources of any broken connections. 
- def _from_server_socket(self, server_socket): + This method attempts to detect any connections in an invalid state, + unregisters them from the selector and closes the file descriptors of + the corresponding network sockets where possible. + """ + invalid_conns = [] + for sock_fd, conn in self._selector.connections: + if conn is self.server: + continue + + try: + os.fstat(sock_fd) + except OSError: + invalid_conns.append((sock_fd, conn)) + + for sock_fd, conn in invalid_conns: + self._selector.unregister(sock_fd) + # One of the reasons why a socket could cause an error + # is that the socket is already closed; ignore the + # socket error if we try to close it at this point. + # This is equivalent to OSError in Py3 + with suppress(socket.error): + conn.close() + + def _from_server_socket(self, server_socket): # noqa: C901 # FIXME try: s, addr = server_socket.accept() if self.server.stats['Enabled']: @@ -274,6 +347,23 @@ class ConnectionManager: def close(self): """Close all monitored connections.""" - for conn in self.connections[:]: - conn.close() - self.connections = [] + for (_, conn) in self._selector.connections: + if conn is not self.server: # server closes its own socket + conn.close() + self._selector.close() + + @property + def _num_connections(self): + """Return the current number of connections. + + Includes all connections registered with the selector, + minus one for the server socket, which is always registered + with the selector. 
+ """ + return len(self._selector) - 1 + + @property + def can_add_keepalive_connection(self): + """Flag whether it is allowed to add a new keep-alive connection.""" + ka_limit = self.server.keep_alive_conn_limit + return ka_limit is None or self._num_connections < ka_limit diff --git a/lib/cheroot/errors.py b/lib/cheroot/errors.py index 80928731..e00629f8 100644 --- a/lib/cheroot/errors.py +++ b/lib/cheroot/errors.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- """Collection of exceptions raised and/or processed by Cheroot.""" from __future__ import absolute_import, division, print_function @@ -23,14 +24,14 @@ class FatalSSLAlert(Exception): def plat_specific_errors(*errnames): - """Return error numbers for all errors in errnames on this platform. + """Return error numbers for all errors in ``errnames`` on this platform. - The 'errno' module contains different global constants depending on - the specific platform (OS). This function will return the list of - numeric values for a given list of potential names. + The :py:mod:`errno` module contains different global constants + depending on the specific platform (OS). This function will return + the list of numeric values for a given list of potential names. """ - missing_attr = set([None, ]) - unique_nums = set(getattr(errno, k, None) for k in errnames) + missing_attr = {None} + unique_nums = {getattr(errno, k, None) for k in errnames} return list(unique_nums - missing_attr) @@ -56,3 +57,32 @@ socket_errors_nonblocking = plat_specific_errors( if sys.platform == 'darwin': socket_errors_to_ignore.extend(plat_specific_errors('EPROTOTYPE')) socket_errors_nonblocking.extend(plat_specific_errors('EPROTOTYPE')) + + +acceptable_sock_shutdown_error_codes = { + errno.ENOTCONN, + errno.EPIPE, errno.ESHUTDOWN, # corresponds to BrokenPipeError in Python 3 + errno.ECONNRESET, # corresponds to ConnectionResetError in Python 3 +} +"""Errors that may happen during the connection close sequence. 
+ +* ENOTCONN — client is no longer connected +* EPIPE — write on a pipe while the other end has been closed +* ESHUTDOWN — write on a socket which has been shutdown for writing +* ECONNRESET — connection is reset by the peer, we received a TCP RST packet + +Refs: +* https://github.com/cherrypy/cheroot/issues/341#issuecomment-735884889 +* https://bugs.python.org/issue30319 +* https://bugs.python.org/issue30329 +* https://github.com/python/cpython/commit/83a2c28 +* https://github.com/python/cpython/blob/c39b52f/Lib/poplib.py#L297-L302 +* https://docs.microsoft.com/windows/win32/api/winsock/nf-winsock-shutdown +""" + +try: # py3 + acceptable_sock_shutdown_exceptions = ( + BrokenPipeError, ConnectionResetError, + ) +except NameError: # py2 + acceptable_sock_shutdown_exceptions = () diff --git a/lib/cheroot/makefile.py b/lib/cheroot/makefile.py index 8a86b338..1383c658 100644 --- a/lib/cheroot/makefile.py +++ b/lib/cheroot/makefile.py @@ -68,7 +68,7 @@ class MakeFile_PY2(getattr(socket, '_fileobject', object)): self._refcount -= 1 def write(self, data): - """Sendall for non-blocking sockets.""" + """Send entire data contents for non-blocking sockets.""" bytes_sent = 0 data_mv = memoryview(data) payload_size = len(data_mv) @@ -122,7 +122,7 @@ class MakeFile_PY2(getattr(socket, '_fileobject', object)): # FauxSocket is no longer needed del FauxSocket - if not _fileobject_uses_str_type: + if not _fileobject_uses_str_type: # noqa: C901 # FIXME def read(self, size=-1): """Read data from the socket to buffer.""" # Use max, disallow tiny reads in a loop as they are very diff --git a/lib/cheroot/server.py b/lib/cheroot/server.py index 991160de..8b59a33a 100644 --- a/lib/cheroot/server.py +++ b/lib/cheroot/server.py @@ -7,12 +7,13 @@ sticking incoming connections onto a Queue:: server = HTTPServer(...) server.start() - -> while True: - tick() - # This blocks until a request comes in: - child = socket.accept() - conn = HTTPConnection(child, ...) 
- server.requests.put(conn) + -> serve() + while ready: + _connections.run() + while not stop_requested: + child = socket.accept() # blocks until a request comes in + conn = HTTPConnection(child, ...) + server.process_conn(conn) # adds conn to threadpool Worker threads are kept in a pool and poll the Queue, popping off and then handling each connection in turn. Each connection can consist of an arbitrary @@ -58,7 +59,6 @@ And now for a trivial doctest to exercise the test suite >>> 'HTTPServer' in globals() True - """ from __future__ import absolute_import, division, print_function @@ -74,6 +74,8 @@ import time import traceback as traceback_ import logging import platform +import contextlib +import threading try: from functools import lru_cache @@ -93,6 +95,7 @@ from .makefile import MakeFile, StreamWriter __all__ = ( 'HTTPRequest', 'HTTPConnection', 'HTTPServer', + 'HeaderReader', 'DropUnderscoreHeaderReader', 'SizeCheckWrapper', 'KnownLengthRFile', 'ChunkedRFile', 'Gateway', 'get_ssl_adapter_class', ) @@ -156,7 +159,10 @@ EMPTY = b'' ASTERISK = b'*' FORWARD_SLASH = b'/' QUOTED_SLASH = b'%2F' -QUOTED_SLASH_REGEX = re.compile(b'(?i)' + QUOTED_SLASH) +QUOTED_SLASH_REGEX = re.compile(b''.join((b'(?i)', QUOTED_SLASH))) + + +_STOPPING_FOR_INTERRUPT = object() # sentinel used during shutdown comma_separated_headers = [ @@ -179,7 +185,7 @@ class HeaderReader: Interface and default implementation. """ - def __call__(self, rfile, hdict=None): + def __call__(self, rfile, hdict=None): # noqa: C901 # FIXME """ Read headers from the given stream into the given header dict. @@ -248,15 +254,14 @@ class DropUnderscoreHeaderReader(HeaderReader): class SizeCheckWrapper: - """Wraps a file-like object, raising MaxSizeExceeded if too large.""" + """Wraps a file-like object, raising MaxSizeExceeded if too large. 
+ + :param rfile: ``file`` of a limited size + :param int maxlen: maximum length of the file being read + """ def __init__(self, rfile, maxlen): - """Initialize SizeCheckWrapper instance. - - Args: - rfile (file): file of a limited size - maxlen (int): maximum length of the file being read - """ + """Initialize SizeCheckWrapper instance.""" self.rfile = rfile self.maxlen = maxlen self.bytes_read = 0 @@ -266,14 +271,12 @@ class SizeCheckWrapper: raise errors.MaxSizeExceeded() def read(self, size=None): - """Read a chunk from rfile buffer and return it. + """Read a chunk from ``rfile`` buffer and return it. - Args: - size (int): amount of data to read - - Returns: - bytes: Chunk from rfile, limited by size if specified. + :param int size: amount of data to read + :returns: chunk from ``rfile``, limited by size if specified + :rtype: bytes """ data = self.rfile.read(size) self.bytes_read += len(data) @@ -281,14 +284,12 @@ class SizeCheckWrapper: return data def readline(self, size=None): - """Read a single line from rfile buffer and return it. + """Read a single line from ``rfile`` buffer and return it. - Args: - size (int): minimum amount of data to read - - Returns: - bytes: One line from rfile. + :param int size: minimum amount of data to read + :returns: one line from ``rfile`` + :rtype: bytes """ if size is not None: data = self.rfile.readline(size) @@ -309,14 +310,12 @@ class SizeCheckWrapper: return EMPTY.join(res) def readlines(self, sizehint=0): - """Read all lines from rfile buffer and return them. + """Read all lines from ``rfile`` buffer and return them. - Args: - sizehint (int): hint of minimum amount of data to read - - Returns: - list[bytes]: Lines of bytes read from rfile. 
+ :param int sizehint: hint of minimum amount of data to read + :returns: lines of bytes read from ``rfile`` + :rtype: list[bytes] """ # Shamelessly stolen from StringIO total = 0 @@ -331,7 +330,7 @@ class SizeCheckWrapper: return lines def close(self): - """Release resources allocated for rfile.""" + """Release resources allocated for ``rfile``.""" self.rfile.close() def __iter__(self): @@ -349,28 +348,24 @@ class SizeCheckWrapper: class KnownLengthRFile: - """Wraps a file-like object, returning an empty string when exhausted.""" + """Wraps a file-like object, returning an empty string when exhausted. + + :param rfile: ``file`` of a known size + :param int content_length: length of the file being read + """ def __init__(self, rfile, content_length): - """Initialize KnownLengthRFile instance. - - Args: - rfile (file): file of a known size - content_length (int): length of the file being read - - """ + """Initialize KnownLengthRFile instance.""" self.rfile = rfile self.remaining = content_length def read(self, size=None): - """Read a chunk from rfile buffer and return it. + """Read a chunk from ``rfile`` buffer and return it. - Args: - size (int): amount of data to read - - Returns: - bytes: Chunk from rfile, limited by size if specified. + :param int size: amount of data to read + :rtype: bytes + :returns: chunk from ``rfile``, limited by size if specified """ if self.remaining == 0: return b'' @@ -384,14 +379,12 @@ class KnownLengthRFile: return data def readline(self, size=None): - """Read a single line from rfile buffer and return it. + """Read a single line from ``rfile`` buffer and return it. - Args: - size (int): minimum amount of data to read - - Returns: - bytes: One line from rfile. 
+ :param int size: minimum amount of data to read + :returns: one line from ``rfile`` + :rtype: bytes """ if self.remaining == 0: return b'' @@ -405,14 +398,12 @@ class KnownLengthRFile: return data def readlines(self, sizehint=0): - """Read all lines from rfile buffer and return them. + """Read all lines from ``rfile`` buffer and return them. - Args: - sizehint (int): hint of minimum amount of data to read - - Returns: - list[bytes]: Lines of bytes read from rfile. + :param int sizehint: hint of minimum amount of data to read + :returns: lines of bytes read from ``rfile`` + :rtype: list[bytes] """ # Shamelessly stolen from StringIO total = 0 @@ -427,7 +418,7 @@ class KnownLengthRFile: return lines def close(self): - """Release resources allocated for rfile.""" + """Release resources allocated for ``rfile``.""" self.rfile.close() def __iter__(self): @@ -449,16 +440,14 @@ class ChunkedRFile: This class is intended to provide a conforming wsgi.input value for request entities that have been encoded with the 'chunked' transfer encoding. + + :param rfile: file encoded with the 'chunked' transfer encoding + :param int maxlen: maximum length of the file being read + :param int bufsize: size of the buffer used to read the file """ def __init__(self, rfile, maxlen, bufsize=8192): - """Initialize ChunkedRFile instance. - - Args: - rfile (file): file encoded with the 'chunked' transfer encoding - maxlen (int): maximum length of the file being read - bufsize (int): size of the buffer used to read the file - """ + """Initialize ChunkedRFile instance.""" self.rfile = rfile self.maxlen = maxlen self.bytes_read = 0 @@ -484,7 +473,10 @@ class ChunkedRFile: chunk_size = line.pop(0) chunk_size = int(chunk_size, 16) except ValueError: - raise ValueError('Bad chunked transfer size: ' + repr(chunk_size)) + raise ValueError( + 'Bad chunked transfer size: {chunk_size!r}'. 
+ format(chunk_size=chunk_size), + ) if chunk_size <= 0: self.closed = True @@ -507,14 +499,12 @@ class ChunkedRFile: ) def read(self, size=None): - """Read a chunk from rfile buffer and return it. + """Read a chunk from ``rfile`` buffer and return it. - Args: - size (int): amount of data to read - - Returns: - bytes: Chunk from rfile, limited by size if specified. + :param int size: amount of data to read + :returns: chunk from ``rfile``, limited by size if specified + :rtype: bytes """ data = EMPTY @@ -540,14 +530,12 @@ class ChunkedRFile: self.buffer = EMPTY def readline(self, size=None): - """Read a single line from rfile buffer and return it. + """Read a single line from ``rfile`` buffer and return it. - Args: - size (int): minimum amount of data to read - - Returns: - bytes: One line from rfile. + :param int size: minimum amount of data to read + :returns: one line from ``rfile`` + :rtype: bytes """ data = EMPTY @@ -583,14 +571,12 @@ class ChunkedRFile: self.buffer = self.buffer[newline_pos:] def readlines(self, sizehint=0): - """Read all lines from rfile buffer and return them. + """Read all lines from ``rfile`` buffer and return them. - Args: - sizehint (int): hint of minimum amount of data to read - - Returns: - list[bytes]: Lines of bytes read from rfile. + :param int sizehint: hint of minimum amount of data to read + :returns: lines of bytes read from ``rfile`` + :rtype: list[bytes] """ # Shamelessly stolen from StringIO total = 0 @@ -635,7 +621,7 @@ class ChunkedRFile: yield line def close(self): - """Release resources allocated for rfile.""" + """Release resources allocated for ``rfile``.""" self.rfile.close() @@ -744,7 +730,7 @@ class HTTPRequest: self.ready = True - def read_request_line(self): + def read_request_line(self): # noqa: C901 # FIXME """Read and parse first line of the HTTP request. Returns: @@ -845,7 +831,7 @@ class HTTPRequest: # `urlsplit()` above parses "example.com:3128" as path part of URI. 
# this is a workaround, which makes it detect netloc correctly - uri_split = urllib.parse.urlsplit(b'//' + uri) + uri_split = urllib.parse.urlsplit(b''.join((b'//', uri))) _scheme, _authority, _path, _qs, _fragment = uri_split _port = EMPTY try: @@ -975,8 +961,14 @@ class HTTPRequest: return True - def read_request_headers(self): - """Read self.rfile into self.inheaders. Return success.""" + def read_request_headers(self): # noqa: C901 # FIXME + """Read ``self.rfile`` into ``self.inheaders``. + + Ref: :py:attr:`self.inheaders `. + + :returns: success status + :rtype: bool + """ # then all the http headers try: self.header_reader(self.rfile, self.inheaders) @@ -1054,8 +1046,10 @@ class HTTPRequest: # Don't use simple_response here, because it emits headers # we don't want. See # https://github.com/cherrypy/cherrypy/issues/951 - msg = self.server.protocol.encode('ascii') - msg += b' 100 Continue\r\n\r\n' + msg = b''.join(( + self.server.protocol.encode('ascii'), SPACE, b'100 Continue', + CRLF, CRLF, + )) try: self.conn.wfile.write(msg) except socket.error as ex: @@ -1138,10 +1132,11 @@ class HTTPRequest: else: self.conn.wfile.write(chunk) - def send_headers(self): + def send_headers(self): # noqa: C901 # FIXME """Assert, process, and send the HTTP response message-headers. - You must set self.status, and self.outheaders before calling this. + You must set ``self.status``, and :py:attr:`self.outheaders + ` before calling this. """ hkeys = [key.lower() for key, value in self.outheaders] status = int(self.status[:3]) @@ -1168,6 +1163,12 @@ class HTTPRequest: # Closing the conn is the only way to determine len. self.close_connection = True + # Override the decision to not close the connection if the connection + # manager doesn't have space for it. 
+ if not self.close_connection: + can_keep = self.server.can_add_keepalive_connection + self.close_connection = not can_keep + if b'connection' not in hkeys: if self.response_protocol == 'HTTP/1.1': # Both server and client are HTTP/1.1 or better @@ -1178,6 +1179,14 @@ class HTTPRequest: if not self.close_connection: self.outheaders.append((b'Connection', b'Keep-Alive')) + if (b'Connection', b'Keep-Alive') in self.outheaders: + self.outheaders.append(( + b'Keep-Alive', + u'timeout={connection_timeout}'. + format(connection_timeout=self.server.timeout). + encode('ISO-8859-1'), + )) + if (not self.close_connection) and (not self.chunked_read): # Read any remaining request body data on the socket. # "If an origin server receives a request that does not include an @@ -1228,9 +1237,7 @@ class HTTPConnection: peercreds_resolve_enabled = False # Fields set by ConnectionManager. - closeable = False last_used = None - ready_with_data = False def __init__(self, server, sock, makefile=MakeFile): """Initialize HTTPConnection instance. @@ -1259,7 +1266,7 @@ class HTTPConnection: lru_cache(maxsize=1)(self.get_peer_creds) ) - def communicate(self): + def communicate(self): # noqa: C901 # FIXME """Read each request and respond appropriately. Returns true if the connection should be kept open. @@ -1351,6 +1358,9 @@ class HTTPConnection: if not self.linger: self._close_kernel_socket() + # close the socket file descriptor + # (will be closed in the OS if there is no + # other reference to the underlying socket) self.socket.close() else: # On the other hand, sometimes we want to hang around for a bit @@ -1426,12 +1436,12 @@ class HTTPConnection: return gid def resolve_peer_creds(self): # LRU cached on per-instance basis - """Return the username and group tuple of the peercreds if available. + """Look up the username and group tuple of the ``PEERCREDS``. 
- Raises: - NotImplementedError: in case of unsupported OS - RuntimeError: in case of UID/GID lookup unsupported or disabled + :returns: the username and group tuple of the ``PEERCREDS`` + :raises NotImplementedError: if the OS is unsupported + :raises RuntimeError: if UID/GID lookup is unsupported or disabled """ if not IS_UID_GID_RESOLVABLE: raise NotImplementedError( @@ -1462,18 +1472,20 @@ class HTTPConnection: return group def _close_kernel_socket(self): - """Close kernel socket in outdated Python versions. + """Terminate the connection at the transport level.""" + # Honor ``sock_shutdown`` for PyOpenSSL connections. + shutdown = getattr( + self.socket, 'sock_shutdown', + self.socket.shutdown, + ) - On old Python versions, - Python's socket module does NOT call close on the kernel - socket when you call socket.close(). We do so manually here - because we want this server to send a FIN TCP segment - immediately. Note this must be called *before* calling - socket.close(), because the latter drops its reference to - the kernel socket. - """ - if six.PY2 and hasattr(self.socket, '_sock'): - self.socket._sock.close() + try: + shutdown(socket.SHUT_RDWR) # actually send a TCP FIN + except errors.acceptable_sock_shutdown_exceptions: + pass + except socket.error as e: + if e.errno not in errors.acceptable_sock_shutdown_error_codes: + raise class HTTPServer: @@ -1515,7 +1527,12 @@ class HTTPServer: timeout = 10 """The timeout in seconds for accepted connections (default 10).""" - version = 'Cheroot/' + __version__ + expiration_interval = 0.5 + """The interval, in seconds, at which the server checks for + expired connections (default 0.5). + """ + + version = 'Cheroot/{version!s}'.format(version=__version__) """A version string for the HTTPServer.""" software = None @@ -1540,16 +1557,23 @@ class HTTPServer: """The class to use for handling HTTP connections.""" ssl_adapter = None - """An instance of ssl.Adapter (or a subclass). 
+ """An instance of ``ssl.Adapter`` (or a subclass). - You must have the corresponding SSL driver library installed. + Ref: :py:class:`ssl.Adapter `. + + You must have the corresponding TLS driver library installed. """ peercreds_enabled = False - """If True, peer cred lookup can be performed via UNIX domain socket.""" + """ + If :py:data:`True`, peer creds will be looked up via UNIX domain socket. + """ peercreds_resolve_enabled = False - """If True, username/group will be looked up in the OS from peercreds.""" + """ + If :py:data:`True`, username/group will be looked up in the OS from + ``PEERCREDS``-provided IDs. + """ keep_alive_conn_limit = 10 """The maximum number of waiting keep-alive connections that will be kept open. @@ -1577,7 +1601,6 @@ class HTTPServer: self.requests = threadpool.ThreadPool( self, min=minthreads or 1, max=maxthreads, ) - self.connections = connections.ConnectionManager(self) if not server_name: server_name = self.version @@ -1603,25 +1626,29 @@ class HTTPServer: 'Threads Idle': lambda s: getattr(self.requests, 'idle', None), 'Socket Errors': 0, 'Requests': lambda s: (not s['Enabled']) and -1 or sum( - [w['Requests'](w) for w in s['Worker Threads'].values()], 0, + (w['Requests'](w) for w in s['Worker Threads'].values()), 0, ), 'Bytes Read': lambda s: (not s['Enabled']) and -1 or sum( - [w['Bytes Read'](w) for w in s['Worker Threads'].values()], 0, + (w['Bytes Read'](w) for w in s['Worker Threads'].values()), 0, ), 'Bytes Written': lambda s: (not s['Enabled']) and -1 or sum( - [w['Bytes Written'](w) for w in s['Worker Threads'].values()], + (w['Bytes Written'](w) for w in s['Worker Threads'].values()), 0, ), 'Work Time': lambda s: (not s['Enabled']) and -1 or sum( - [w['Work Time'](w) for w in s['Worker Threads'].values()], 0, + (w['Work Time'](w) for w in s['Worker Threads'].values()), 0, ), 'Read Throughput': lambda s: (not s['Enabled']) and -1 or sum( - [w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6) - for w in s['Worker 
Threads'].values()], 0, + ( + w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6) + for w in s['Worker Threads'].values() + ), 0, ), 'Write Throughput': lambda s: (not s['Enabled']) and -1 or sum( - [w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6) - for w in s['Worker Threads'].values()], 0, + ( + w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6) + for w in s['Worker Threads'].values() + ), 0, ), 'Worker Threads': {}, } @@ -1645,17 +1672,27 @@ class HTTPServer: def bind_addr(self): """Return the interface on which to listen for connections. - For TCP sockets, a (host, port) tuple. Host values may be any IPv4 - or IPv6 address, or any valid hostname. The string 'localhost' is a - synonym for '127.0.0.1' (or '::1', if your hosts file prefers IPv6). - The string '0.0.0.0' is a special IPv4 entry meaning "any active - interface" (INADDR_ANY), and '::' is the similar IN6ADDR_ANY for - IPv6. The empty string or None are not allowed. + For TCP sockets, a (host, port) tuple. Host values may be any + :term:`IPv4` or :term:`IPv6` address, or any valid hostname. + The string 'localhost' is a synonym for '127.0.0.1' (or '::1', + if your hosts file prefers :term:`IPv6`). + The string '0.0.0.0' is a special :term:`IPv4` entry meaning + "any active interface" (INADDR_ANY), and '::' is the similar + IN6ADDR_ANY for :term:`IPv6`. + The empty string or :py:data:`None` are not allowed. - For UNIX sockets, supply the filename as a string. + For UNIX sockets, supply the file name as a string. Systemd socket activation is automatic and doesn't require tempering with this variable. + + .. glossary:: + + :abbr:`IPv4 (Internet Protocol version 4)` + Internet Protocol version 4 + + :abbr:`IPv6 (Internet Protocol version 6)` + Internet Protocol version 6 """ return self._bind_addr @@ -1695,7 +1732,7 @@ class HTTPServer: self.stop() raise - def prepare(self): + def prepare(self): # noqa: C901 # FIXME """Prepare server to serving requests. 
It binds a socket's port, setups the socket to ``listen()`` and does @@ -1757,6 +1794,9 @@ class HTTPServer: self.socket.settimeout(1) self.socket.listen(self.request_queue_size) + # must not be accessed once stop() has been called + self._connections = connections.ConnectionManager(self) + # Create worker threads self.requests.start() @@ -1765,23 +1805,24 @@ class HTTPServer: def serve(self): """Serve requests, after invoking :func:`prepare()`.""" - while self.ready: + while self.ready and not self.interrupt: try: - self.tick() + self._connections.run(self.expiration_interval) except (KeyboardInterrupt, SystemExit): raise except Exception: self.error_log( - 'Error in HTTPServer.tick', level=logging.ERROR, + 'Error in HTTPServer.serve', level=logging.ERROR, traceback=True, ) + # raise exceptions reported by any worker threads, + # such that the exception is raised from the serve() thread. + if self.interrupt: + while self._stopping_for_interrupt: + time.sleep(0.1) if self.interrupt: - while self.interrupt is True: - # Wait for self.stop() to complete. See _set_interrupt. - time.sleep(0.1) - if self.interrupt: - raise self.interrupt + raise self.interrupt def start(self): """Run the server forever. 
@@ -1795,6 +1836,31 @@ class HTTPServer: self.prepare() self.serve() + @contextlib.contextmanager + def _run_in_thread(self): + """Context manager for running this server in a thread.""" + self.prepare() + thread = threading.Thread(target=self.serve) + thread.setDaemon(True) + thread.start() + try: + yield thread + finally: + self.stop() + + @property + def can_add_keepalive_connection(self): + """Flag whether it is allowed to add a new keep-alive connection.""" + return self.ready and self._connections.can_add_keepalive_connection + + def put_conn(self, conn): + """Put an idle connection back into the ConnectionManager.""" + if self.ready: + self._connections.put(conn) + else: + # server is shutting down, just close it + conn.close() + def error_log(self, msg='', level=20, traceback=False): """Write error message to log. @@ -1804,7 +1870,7 @@ class HTTPServer: traceback (bool): add traceback to output or not """ # Override this in subclasses as desired - sys.stderr.write(msg + '\n') + sys.stderr.write('{msg!s}\n'.format(msg=msg)) sys.stderr.flush() if traceback: tblines = traceback_.format_exc() @@ -1822,7 +1888,7 @@ class HTTPServer: self.bind_addr = self.resolve_real_bind_addr(sock) return sock - def bind_unix_socket(self, bind_addr): + def bind_unix_socket(self, bind_addr): # noqa: C901 # FIXME """Create (or recreate) a UNIX socket object.""" if IS_WINDOWS: """ @@ -1965,7 +2031,7 @@ class HTTPServer: @staticmethod def resolve_real_bind_addr(socket_): - """Retrieve actual bind addr from bound socket.""" + """Retrieve actual bind address from bound socket.""" # FIXME: keep requested bind_addr separate real bound_addr (port # is different in case of ephemeral port 0) bind_addr = socket_.getsockname() @@ -1985,40 +2051,49 @@ class HTTPServer: return bind_addr - def tick(self): - """Accept a new connection and put it on the Queue.""" - if not self.ready: - return - - conn = self.connections.get_conn(self.socket) - if conn: - try: - self.requests.put(conn) - except 
queue.Full: - # Just drop the conn. TODO: write 503 back? - conn.close() - - self.connections.expire() + def process_conn(self, conn): + """Process an incoming HTTPConnection.""" + try: + self.requests.put(conn) + except queue.Full: + # Just drop the conn. TODO: write 503 back? + conn.close() @property def interrupt(self): """Flag interrupt of the server.""" return self._interrupt + @property + def _stopping_for_interrupt(self): + """Return whether the server is responding to an interrupt.""" + return self._interrupt is _STOPPING_FOR_INTERRUPT + @interrupt.setter def interrupt(self, interrupt): - """Perform the shutdown of this server and save the exception.""" - self._interrupt = True + """Perform the shutdown of this server and save the exception. + + Typically invoked by a worker thread in + :py:mod:`~cheroot.workers.threadpool`, the exception is raised + from the thread running :py:meth:`serve` once :py:meth:`stop` + has completed. + """ + self._interrupt = _STOPPING_FOR_INTERRUPT self.stop() self._interrupt = interrupt - def stop(self): + def stop(self): # noqa: C901 # FIXME """Gracefully shutdown a server that is serving forever.""" + if not self.ready: + return # already stopped + self.ready = False if self._start_time is not None: self._run_time += (time.time() - self._start_time) self._start_time = None + self._connections.stop() + sock = getattr(self, 'socket', None) if sock: if not isinstance( @@ -2060,7 +2135,7 @@ class HTTPServer: sock.close() self.socket = None - self.connections.close() + self._connections.close() self.requests.stop(self.shutdown_timeout) @@ -2108,7 +2183,9 @@ def get_ssl_adapter_class(name='builtin'): try: adapter = getattr(mod, attr_name) except AttributeError: - raise AttributeError("'%s' object has no attribute '%s'" - % (mod_path, attr_name)) + raise AttributeError( + "'%s' object has no attribute '%s'" + % (mod_path, attr_name), + ) return adapter diff --git a/lib/cheroot/ssl/builtin.py b/lib/cheroot/ssl/builtin.py index 
d131b2f4..ff987a71 100644 --- a/lib/cheroot/ssl/builtin.py +++ b/lib/cheroot/ssl/builtin.py @@ -1,7 +1,7 @@ """ -A library for integrating Python's builtin ``ssl`` library with Cheroot. +A library for integrating Python's builtin :py:mod:`ssl` library with Cheroot. -The ssl module must be importable for SSL functionality. +The :py:mod:`ssl` module must be importable for SSL functionality. To use this module, set ``HTTPServer.ssl_adapter`` to an instance of ``BuiltinSSLAdapter``. @@ -10,6 +10,10 @@ To use this module, set ``HTTPServer.ssl_adapter`` to an instance of from __future__ import absolute_import, division, print_function __metaclass__ = type +import socket +import sys +import threading + try: import ssl except ImportError: @@ -27,13 +31,12 @@ import six from . import Adapter from .. import errors -from .._compat import IS_ABOVE_OPENSSL10 +from .._compat import IS_ABOVE_OPENSSL10, suppress from ..makefile import StreamReader, StreamWriter +from ..server import HTTPServer if six.PY2: - import socket generic_socket_error = socket.error - del socket else: generic_socket_error = OSError @@ -49,37 +52,159 @@ def _assert_ssl_exc_contains(exc, *msgs): return any(m.lower() in err_msg_lower for m in msgs) +def _loopback_for_cert_thread(context, server): + """Wrap a socket in ssl and perform the server-side handshake.""" + # As we only care about parsing the certificate, the failure of + # which will cause an exception in ``_loopback_for_cert``, + # we can safely ignore connection and ssl related exceptions. Ref: + # https://github.com/cherrypy/cheroot/issues/302#issuecomment-662592030 + with suppress(ssl.SSLError, OSError): + with context.wrap_socket( + server, do_handshake_on_connect=True, server_side=True, + ) as ssl_sock: + # in TLS 1.3 (Python 3.7+, OpenSSL 1.1.1+), the server + # sends the client session tickets that can be used to + # resume the TLS session on a new connection without + # performing the full handshake again. 
session tickets are + sent as a post-handshake message at some _unspecified_ + time and thus a successful connection may be closed + without the client having received the tickets. + Unfortunately, on Windows (Python 3.8+), this is treated + as an incomplete handshake on the server side and a + ``ConnectionAbortedError`` is raised. + TLS 1.3 support is still incomplete in Python 3.8; + there is no way for the client to wait for tickets. + While not necessary for retrieving the parsed certificate, + we send a tiny bit of data over the connection in an + attempt to give the server a chance to send the session + tickets and close the connection cleanly. + Note that, as this is essentially a race condition, + the error may still occur occasionally. + ssl_sock.send(b'0000') + + +def _loopback_for_cert(certificate, private_key, certificate_chain): + """Create a loopback connection to parse a cert with a private key.""" + context = ssl.create_default_context(cafile=certificate_chain) + context.load_cert_chain(certificate, private_key) + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + # Python 3+ Unix, Python 3.5+ Windows + client, server = socket.socketpair() + try: + # `wrap_socket` will block until the ssl handshake is complete. + # it must be called on both ends at the same time -> thread + # openssl will cache the peer's cert during a successful handshake + # and return it via `getpeercert` even after the socket is closed. + # when `close` is called, the SSL shutdown notice will be sent + # and then python will wait to receive the corollary shutdown. 
+ thread = threading.Thread( + target=_loopback_for_cert_thread, args=(context, server), + ) + try: + thread.start() + with context.wrap_socket( + client, do_handshake_on_connect=True, + server_side=False, + ) as ssl_sock: + ssl_sock.recv(4) + return ssl_sock.getpeercert() + finally: + thread.join() + finally: + client.close() + server.close() + + +def _parse_cert(certificate, private_key, certificate_chain): + """Parse a certificate.""" + # loopback_for_cert uses socket.socketpair which was only + # introduced in Python 3.0 for *nix and 3.5 for Windows + # and requires OS support (AttributeError, OSError) + # it also requires a private key either in its own file + # or combined with the cert (SSLError) + with suppress(AttributeError, ssl.SSLError, OSError): + return _loopback_for_cert(certificate, private_key, certificate_chain) + + # KLUDGE: using an undocumented, private, test method to parse a cert + # unfortunately, it is the only built-in way without a connection + # as a private, undocumented method, it may change at any time + # so be tolerant of *any* possible errors it may raise + with suppress(Exception): + return ssl._ssl._test_decode_cert(certificate) + + return {} + + +def _sni_callback(sock, sni, context): + """Handle the SNI callback to tag the socket with the SNI.""" + sock.sni = sni + # return None to allow the TLS negotiation to continue + + class BuiltinSSLAdapter(Adapter): - """A wrapper for integrating Python's builtin ssl module with Cheroot.""" + """Wrapper for integrating Python's builtin :py:mod:`ssl` with Cheroot.""" certificate = None - """The filename of the server SSL certificate.""" + """The file name of the server SSL certificate.""" private_key = None - """The filename of the server's private key file.""" + """The file name of the server's private key file.""" certificate_chain = None - """The filename of the certificate chain file.""" - - context = None - """The ssl.SSLContext that will be used to wrap sockets.""" + """The file 
name of the certificate chain file.""" ciphers = None """The ciphers list of SSL.""" + # from mod_ssl/pkg.sslmod/ssl_engine_vars.c ssl_var_lookup_ssl_cert CERT_KEY_TO_ENV = { - 'subject': 'SSL_CLIENT_S_DN', - 'issuer': 'SSL_CLIENT_I_DN', + 'version': 'M_VERSION', + 'serialNumber': 'M_SERIAL', + 'notBefore': 'V_START', + 'notAfter': 'V_END', + 'subject': 'S_DN', + 'issuer': 'I_DN', + 'subjectAltName': 'SAN', + # not parsed by the Python standard library + # - A_SIG + # - A_KEY + # not provided by mod_ssl + # - OCSP + # - caIssuers + # - crlDistributionPoints } + # from mod_ssl/pkg.sslmod/ssl_engine_vars.c ssl_var_lookup_ssl_cert_dn_rec CERT_KEY_TO_LDAP_CODE = { 'countryName': 'C', 'stateOrProvinceName': 'ST', + # NOTE: mod_ssl also provides 'stateOrProvinceName' as 'SP' + # for compatibility with SSLeay 'localityName': 'L', 'organizationName': 'O', 'organizationalUnitName': 'OU', 'commonName': 'CN', + 'title': 'T', + 'initials': 'I', + 'givenName': 'G', + 'surname': 'S', + 'description': 'D', + 'userid': 'UID', 'emailAddress': 'Email', + # not provided by mod_ssl + # - dnQualifier: DNQ + # - domainComponent: DC + # - postalCode: PC + # - streetAddress: STREET + # - serialNumber + # - generationQualifier + # - pseudonym + # - jurisdictionCountryName + # - jurisdictionLocalityName + # - jurisdictionStateOrProvince + # - businessCategory } def __init__( @@ -102,6 +227,45 @@ class BuiltinSSLAdapter(Adapter): if self.ciphers is not None: self.context.set_ciphers(ciphers) + self._server_env = self._make_env_cert_dict( + 'SSL_SERVER', + _parse_cert(certificate, private_key, self.certificate_chain), + ) + if not self._server_env: + return + cert = None + with open(certificate, mode='rt') as f: + cert = f.read() + + # strip off any keys by only taking the first certificate + cert_start = cert.find(ssl.PEM_HEADER) + if cert_start == -1: + return + cert_end = cert.find(ssl.PEM_FOOTER, cert_start) + if cert_end == -1: + return + cert_end += len(ssl.PEM_FOOTER) + 
self._server_env['SSL_SERVER_CERT'] = cert[cert_start:cert_end] + + @property + def context(self): + """:py:class:`~ssl.SSLContext` that will be used to wrap sockets.""" + return self._context + + @context.setter + def context(self, context): + """Set the ssl ``context`` to use.""" + self._context = context + # Python 3.7+ + # if a context is provided via `cherrypy.config.update` then + # `self.context` will be set after `__init__` + # use a property to intercept it to add an SNI callback + # but don't override the user's callback + # TODO: chain callbacks + with suppress(AttributeError): + if ssl.HAS_SNI and context.sni_callback is None: + context.sni_callback = _sni_callback + def bind(self, sock): """Wrap and return the given socket.""" return super(BuiltinSSLAdapter, self).bind(sock) @@ -135,6 +299,8 @@ class BuiltinSSLAdapter(Adapter): 'no shared cipher', 'certificate unknown', 'ccs received early', 'certificate verify failed', # client cert w/o trusted CA + 'version too low', # caused by SSL3 connections + 'unsupported protocol', # caused by TLS1 connections ) if _assert_ssl_exc_contains(ex, *_block_errors): # Accepted error, let's pass @@ -148,7 +314,8 @@ class BuiltinSSLAdapter(Adapter): except generic_socket_error as exc: """It is unclear why exactly this happens. - It's reproducible only with openssl>1.0 and stdlib ``ssl`` wrapper. + It's reproducible only with openssl>1.0 and stdlib + :py:mod:`ssl` wrapper. In CherryPy it's triggered by Checker plugin, which connects to the app listening to the socket port in TLS mode via plain HTTP during startup (from the same process). 
@@ -163,7 +330,6 @@ class BuiltinSSLAdapter(Adapter): raise return s, self.get_environ(s) - # TODO: fill this out more with mod ssl env def get_environ(self, sock): """Create WSGI environ entries to be merged into each request.""" cipher = sock.cipher() @@ -172,22 +338,117 @@ class BuiltinSSLAdapter(Adapter): 'HTTPS': 'on', 'SSL_PROTOCOL': cipher[1], 'SSL_CIPHER': cipher[0], - # SSL_VERSION_INTERFACE string The mod_ssl program version - # SSL_VERSION_LIBRARY string The OpenSSL program version + 'SSL_CIPHER_EXPORT': '', + 'SSL_CIPHER_USEKEYSIZE': cipher[2], + 'SSL_VERSION_INTERFACE': '%s Python/%s' % ( + HTTPServer.version, sys.version, + ), + 'SSL_VERSION_LIBRARY': ssl.OPENSSL_VERSION, + 'SSL_CLIENT_VERIFY': 'NONE', + # 'NONE' - client did not provide a cert (overridden below) + } + # Python 3.3+ + with suppress(AttributeError): + compression = sock.compression() + if compression is not None: + ssl_environ['SSL_COMPRESS_METHOD'] = compression + + # Python 3.6+ + with suppress(AttributeError): + ssl_environ['SSL_SESSION_ID'] = sock.session.id.hex() + with suppress(AttributeError): + target_cipher = cipher[:2] + for cip in sock.context.get_ciphers(): + if target_cipher == (cip['name'], cip['protocol']): + ssl_environ['SSL_CIPHER_ALGKEYSIZE'] = cip['alg_bits'] + break + + # Python 3.7+ sni_callback + with suppress(AttributeError): + ssl_environ['SSL_TLS_SNI'] = sock.sni + if self.context and self.context.verify_mode != ssl.CERT_NONE: client_cert = sock.getpeercert() if client_cert: - for cert_key, env_var in self.CERT_KEY_TO_ENV.items(): - ssl_environ.update( - self.env_dn_dict(env_var, client_cert.get(cert_key)), - ) + # builtin ssl **ALWAYS** validates client certificates + # and terminates the connection on failure + ssl_environ['SSL_CLIENT_VERIFY'] = 'SUCCESS' + ssl_environ.update( + self._make_env_cert_dict('SSL_CLIENT', client_cert), + ) + ssl_environ['SSL_CLIENT_CERT'] = ssl.DER_cert_to_PEM_cert( + sock.getpeercert(binary_form=True), + ).strip() + + 
ssl_environ.update(self._server_env) + + # not supplied by the Python standard library (as of 3.8) + # - SSL_SESSION_RESUMED + # - SSL_SECURE_RENEG + # - SSL_CLIENT_CERT_CHAIN_n + # - SRP_USER + # - SRP_USERINFO return ssl_environ - def env_dn_dict(self, env_prefix, cert_value): - """Return a dict of WSGI environment variables for a client cert DN. + def _make_env_cert_dict(self, env_prefix, parsed_cert): + """Return a dict of WSGI environment variables for a certificate. + + E.g. SSL_CLIENT_M_VERSION, SSL_CLIENT_M_SERIAL, etc. + See https://httpd.apache.org/docs/2.4/mod/mod_ssl.html#envvars. + """ + if not parsed_cert: + return {} + + env = {} + for cert_key, env_var in self.CERT_KEY_TO_ENV.items(): + key = '%s_%s' % (env_prefix, env_var) + value = parsed_cert.get(cert_key) + if env_var == 'SAN': + env.update(self._make_env_san_dict(key, value)) + elif env_var.endswith('_DN'): + env.update(self._make_env_dn_dict(key, value)) + else: + env[key] = str(value) + + # mod_ssl 2.1+; Python 3.2+ + # number of days until the certificate expires + if 'notBefore' in parsed_cert: + remain = ssl.cert_time_to_seconds(parsed_cert['notAfter']) + remain -= ssl.cert_time_to_seconds(parsed_cert['notBefore']) + remain /= 60 * 60 * 24 + env['%s_V_REMAIN' % (env_prefix,)] = str(int(remain)) + + return env + + def _make_env_san_dict(self, env_prefix, cert_value): + """Return a dict of WSGI environment variables for a certificate DN. + + E.g. SSL_CLIENT_SAN_Email_0, SSL_CLIENT_SAN_DNS_0, etc. + See SSL_CLIENT_SAN_* at + https://httpd.apache.org/docs/2.4/mod/mod_ssl.html#envvars. 
+ """ + if not cert_value: + return {} + + env = {} + dns_count = 0 + email_count = 0 + for attr_name, val in cert_value: + if attr_name == 'DNS': + env['%s_DNS_%i' % (env_prefix, dns_count)] = val + dns_count += 1 + elif attr_name == 'Email': + env['%s_Email_%i' % (env_prefix, email_count)] = val + email_count += 1 + + # other mod_ssl SAN vars: + # - SAN_OTHER_msUPN_n + return env + + def _make_env_dn_dict(self, env_prefix, cert_value): + """Return a dict of WSGI environment variables for a certificate DN. E.g. SSL_CLIENT_S_DN_CN, SSL_CLIENT_S_DN_C, etc. See SSL_CLIENT_S_DN_x509 at @@ -196,12 +457,26 @@ class BuiltinSSLAdapter(Adapter): if not cert_value: return {} - env = {} + dn = [] + dn_attrs = {} for rdn in cert_value: for attr_name, val in rdn: attr_code = self.CERT_KEY_TO_LDAP_CODE.get(attr_name) - if attr_code: - env['%s_%s' % (env_prefix, attr_code)] = val + dn.append('%s=%s' % (attr_code or attr_name, val)) + if not attr_code: + continue + dn_attrs.setdefault(attr_code, []) + dn_attrs[attr_code].append(val) + + env = { + env_prefix: ','.join(dn), + } + for attr_code, values in dn_attrs.items(): + env['%s_%s' % (env_prefix, attr_code)] = ','.join(values) + if len(values) == 1: + continue + for i, val in enumerate(values): + env['%s_%s_%i' % (env_prefix, attr_code, i)] = val return env def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE): diff --git a/lib/cheroot/ssl/pyopenssl.py b/lib/cheroot/ssl/pyopenssl.py index f51be0d8..adc9a1ba 100644 --- a/lib/cheroot/ssl/pyopenssl.py +++ b/lib/cheroot/ssl/pyopenssl.py @@ -1,46 +1,67 @@ """ -A library for integrating pyOpenSSL with Cheroot. +A library for integrating :doc:`pyOpenSSL ` with Cheroot. -The OpenSSL module must be importable for SSL functionality. -You can obtain it from `here `_. +The :py:mod:`OpenSSL ` module must be importable +for SSL/TLS/HTTPS functionality. +You can obtain it from `here `_. -To use this module, set HTTPServer.ssl_adapter to an instance of -ssl.Adapter. 
There are two ways to use SSL: +To use this module, set :py:attr:`HTTPServer.ssl_adapter +` to an instance of +:py:class:`ssl.Adapter `. +There are two ways to use :abbr:`TLS (Transport-Level Security)`: Method One ---------- - * ``ssl_adapter.context``: an instance of SSL.Context. + * :py:attr:`ssl_adapter.context + `: an instance of + :py:class:`SSL.Context `. -If this is not None, it is assumed to be an SSL.Context instance, -and will be passed to SSL.Connection on bind(). The developer is -responsible for forming a valid Context object. This approach is -to be preferred for more flexibility, e.g. if the cert and key are -streams instead of files, or need decryption, or SSL.SSLv3_METHOD -is desired instead of the default SSL.SSLv23_METHOD, etc. Consult -the pyOpenSSL documentation for complete options. +If this is not None, it is assumed to be an :py:class:`SSL.Context +` instance, and will be passed to +:py:class:`SSL.Connection ` on bind(). +The developer is responsible for forming a valid :py:class:`Context +` object. This +approach is to be preferred for more flexibility, e.g. if the cert and +key are streams instead of files, or need decryption, or +:py:data:`SSL.SSLv3_METHOD ` +is desired instead of the default :py:data:`SSL.SSLv23_METHOD +`, etc. Consult +the :doc:`pyOpenSSL ` documentation for +complete options. Method Two (shortcut) --------------------- - * ``ssl_adapter.certificate``: the filename of the server SSL certificate. - * ``ssl_adapter.private_key``: the filename of the server's private key file. + * :py:attr:`ssl_adapter.certificate + `: the file name + of the server's TLS certificate. + * :py:attr:`ssl_adapter.private_key + `: the file name + of the server's private key file. -Both are None by default. If ssl_adapter.context is None, but .private_key -and .certificate are both given and valid, they will be read, and the -context will be automatically created from them. +Both are :py:data:`None` by default. 
If :py:attr:`ssl_adapter.context +` is :py:data:`None`, +but ``.private_key`` and ``.certificate`` are both given and valid, they +will be read, and the context will be automatically created from them. + +.. spelling:: + + pyopenssl """ from __future__ import absolute_import, division, print_function __metaclass__ = type import socket +import sys import threading import time import six try: + import OpenSSL.version from OpenSSL import SSL from OpenSSL import crypto @@ -57,13 +78,14 @@ from ..makefile import StreamReader, StreamWriter class SSLFileobjectMixin: - """Base mixin for an SSL socket stream.""" + """Base mixin for a TLS socket stream.""" ssl_timeout = 3 ssl_retry = .01 - def _safe_call(self, is_reader, call, *args, **kwargs): - """Wrap the given call with SSL error-trapping. + # FIXME: + def _safe_call(self, is_reader, call, *args, **kwargs): # noqa: C901 + """Wrap the given call with TLS error-trapping. is_reader: if False EOF errors will be raised. If True, EOF errors will return "" (to emulate normal sockets). @@ -209,9 +231,11 @@ class SSLConnectionProxyMeta: @six.add_metaclass(SSLConnectionProxyMeta) class SSLConnection: - """A thread-safe wrapper for an SSL.Connection. + r"""A thread-safe wrapper for an ``SSL.Connection``. - ``*args``: the arguments to create the wrapped ``SSL.Connection(*args)``. + :param tuple args: the arguments to create the wrapped \ + :py:class:`SSL.Connection(*args) \ + ` """ def __init__(self, *args): @@ -224,22 +248,24 @@ class pyOpenSSLAdapter(Adapter): """A wrapper for integrating pyOpenSSL with Cheroot.""" certificate = None - """The filename of the server SSL certificate.""" + """The file name of the server's TLS certificate.""" private_key = None - """The filename of the server's private key file.""" + """The file name of the server's private key file.""" certificate_chain = None - """Optional. The filename of CA's intermediate certificate bundle. + """Optional. The file name of CA's intermediate certificate bundle. 
- This is needed for cheaper "chained root" SSL certificates, and should be - left as None if not required.""" + This is needed for cheaper "chained root" TLS certificates, + and should be left as :py:data:`None` if not required.""" context = None - """An instance of SSL.Context.""" + """ + An instance of :py:class:`SSL.Context `. + """ ciphers = None - """The ciphers list of SSL.""" + """The ciphers list of TLS.""" def __init__( self, certificate, private_key, certificate_chain=None, @@ -265,10 +291,16 @@ class pyOpenSSLAdapter(Adapter): def wrap(self, sock): """Wrap and return the given socket, plus WSGI environ entries.""" + # pyOpenSSL doesn't perform the handshake until the first read/write + # forcing the handshake to complete tends to result in the connection + # closing so we can't reliably access protocol/client cert for the env return sock, self._environ.copy() def get_context(self): - """Return an SSL.Context from self attributes.""" + """Return an ``SSL.Context`` from self attributes. 
+ + Ref: :py:class:`SSL.Context ` + """ # See https://code.activestate.com/recipes/442473/ c = SSL.Context(SSL.SSLv23_METHOD) c.use_privatekey_file(self.private_key) @@ -280,18 +312,25 @@ class pyOpenSSLAdapter(Adapter): def get_environ(self): """Return WSGI environ entries to be merged into each request.""" ssl_environ = { + 'wsgi.url_scheme': 'https', 'HTTPS': 'on', - # pyOpenSSL doesn't provide access to any of these AFAICT - # 'SSL_PROTOCOL': 'SSLv2', - # SSL_CIPHER string The cipher specification name - # SSL_VERSION_INTERFACE string The mod_ssl program version - # SSL_VERSION_LIBRARY string The OpenSSL program version + 'SSL_VERSION_INTERFACE': '%s %s/%s Python/%s' % ( + cheroot_server.HTTPServer.version, + OpenSSL.version.__title__, OpenSSL.version.__version__, + sys.version, + ), + 'SSL_VERSION_LIBRARY': SSL.SSLeay_version( + SSL.SSLEAY_VERSION, + ).decode(), } if self.certificate: # Server certificate attributes - cert = open(self.certificate, 'rb').read() - cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert) + with open(self.certificate, 'rb') as cert_file: + cert = crypto.load_certificate( + crypto.FILETYPE_PEM, cert_file.read(), + ) + ssl_environ.update({ 'SSL_SERVER_M_VERSION': cert.get_version(), 'SSL_SERVER_M_SERIAL': cert.get_serial_number(), diff --git a/lib/cheroot/test/_pytest_plugin.py b/lib/cheroot/test/_pytest_plugin.py new file mode 100644 index 00000000..2bba9aa9 --- /dev/null +++ b/lib/cheroot/test/_pytest_plugin.py @@ -0,0 +1,38 @@ +"""Local pytest plugin. + +Contains hooks, which are tightly bound to the Cheroot framework +itself, useless for end-users' app testing. 
+""" + +from __future__ import absolute_import, division, print_function +__metaclass__ = type + +import pytest + + +pytest_version = tuple(map(int, pytest.__version__.split('.'))) + + +def pytest_load_initial_conftests(early_config, parser, args): + """Drop unfilterable warning ignores.""" + if pytest_version < (6, 2, 0): + return + + # pytest>=6.2.0 under Python 3.8: + # Refs: + # * https://docs.pytest.org/en/stable/usage.html#unraisable + # * https://github.com/pytest-dev/pytest/issues/5299 + early_config._inicache['filterwarnings'].extend(( + 'ignore:Exception in thread CP Server Thread-:' + 'pytest.PytestUnhandledThreadExceptionWarning:_pytest.threadexception', + 'ignore:Exception in thread Thread-:' + 'pytest.PytestUnhandledThreadExceptionWarning:_pytest.threadexception', + 'ignore:Exception ignored in. ' + '` on the server socket when + :py:meth:`socket.shutdown() ` is called. It's + triggered by closing the client socket before the server had a chance + to respond. + + The expectation is that only :py:exc:`RuntimeError` and a + :py:exc:`socket.error` with an unusual error code would leak. + + With the :py:data:`None`-parameter, a real non-simulated + :py:exc:`OSError(107, 'Transport endpoint is not connected') + ` happens. + """ + exc_instance = ( + None if simulated_exception is None + else simulated_exception(error_number, 'Simulated socket error') + ) + old_close_kernel_socket = ( + test_client.server_instance. 
+ ConnectionClass._close_kernel_socket + ) + + def _close_kernel_socket(self): + monkeypatch.setattr( # `socket.shutdown` is read-only otherwise + self, 'socket', + mocker.mock_module.Mock(wraps=self.socket), + ) + if exc_instance is not None: + monkeypatch.setattr( + self.socket, 'shutdown', + mocker.mock_module.Mock(side_effect=exc_instance), + ) + _close_kernel_socket.fin_spy = mocker.spy(self.socket, 'shutdown') + _close_kernel_socket.exception_leaked = True + old_close_kernel_socket(self) + _close_kernel_socket.exception_leaked = False + + monkeypatch.setattr( + test_client.server_instance.ConnectionClass, + '_close_kernel_socket', + _close_kernel_socket, + ) + + conn = test_client.get_connection() + conn.auto_open = False + conn.connect() + conn.send(b'GET /hello HTTP/1.1') + conn.send(('Host: %s' % conn.host).encode('ascii')) + conn.close() + + for _ in range(10): # Let the server attempt TCP shutdown + time.sleep(0.1) + if hasattr(_close_kernel_socket, 'exception_leaked'): + break + + if exc_instance is not None: # simulated by us + assert _close_kernel_socket.fin_spy.spy_exception is exc_instance + else: # real + assert isinstance( + _close_kernel_socket.fin_spy.spy_exception, socket.error, + ) + assert _close_kernel_socket.fin_spy.spy_exception.errno == error_number + + assert _close_kernel_socket.exception_leaked is exception_leaks + + @pytest.mark.parametrize( 'timeout_before_headers', ( @@ -475,7 +694,7 @@ def test_keepalive_conn_management(test_client): def test_HTTP11_Timeout(test_client, timeout_before_headers): """Check timeout without sending any data. - The server will close the conn with a 408. + The server will close the connection with a 408. """ conn = test_client.get_connection() conn.auto_open = False @@ -594,7 +813,7 @@ def test_HTTP11_Timeout_after_request(test_client): def test_HTTP11_pipelining(test_client): """Test HTTP/1.1 pipelining. - httplib doesn't support this directly. + :py:mod:`http.client` doesn't support this directly. 
""" conn = test_client.get_connection() @@ -639,7 +858,7 @@ def test_100_Continue(test_client): conn = test_client.get_connection() # Try a page without an Expect request header first. - # Note that httplib's response.begin automatically ignores + # Note that http.client's response.begin automatically ignores # 100 Continue responses, so we must manually check for it. conn.putrequest('POST', '/upload', skip_host=True) conn.putheader('Host', conn.host) @@ -800,11 +1019,13 @@ def test_No_Message_Body(test_client): @pytest.mark.xfail( - reason='Server does not correctly read trailers/ending of the previous ' - 'HTTP request, thus the second request fails as the server tries ' - r"to parse b'Content-Type: application/json\r\n' as a " - 'Request-Line. This results in HTTP status code 400, instead of 413' - 'Ref: https://github.com/cherrypy/cheroot/issues/69', + reason=unwrap( + trim(""" + Headers from earlier request leak into the request + line for a subsequent request, resulting in 400 + instead of 413. See cherrypy/cheroot#69 for details. + """), + ), ) def test_Chunked_Encoding(test_client): """Test HTTP uploads with chunked transfer-encoding.""" @@ -837,7 +1058,7 @@ def test_Chunked_Encoding(test_client): # Try a chunked request that exceeds server.max_request_body_size. # Note that the delimiters and trailer are included. - body = b'3e3\r\n' + (b'x' * 995) + b'\r\n0\r\n\r\n' + body = b'\r\n'.join((b'3e3', b'x' * 995, b'0', b'', b'')) conn.putrequest('POST', '/upload', skip_host=True) conn.putheader('Host', conn.host) conn.putheader('Transfer-Encoding', 'chunked') @@ -895,7 +1116,7 @@ def test_Content_Length_not_int(test_client): @pytest.mark.parametrize( - 'uri,expected_resp_status,expected_resp_body', + ('uri', 'expected_resp_status', 'expected_resp_body'), ( ( '/wrong_cl_buffered', 500, @@ -929,6 +1150,16 @@ def test_Content_Length_out( conn.close() + # the server logs the exception that we had verified from the + # client perspective. 
Tell the error_log verification that
+    # it can ignore that message.
+    test_client.server_instance.error_log.ignored_msgs.extend((
+        # Python 3.7+:
+        "ValueError('Response body exceeds the declared Content-Length.')",
+        # Python 2.7-3.6 (macOS?):
+        "ValueError('Response body exceeds the declared Content-Length.',)",
+    ))
+

 @pytest.mark.xfail(
     reason='Sometimes this test fails due to low timeout. '
@@ -970,11 +1201,94 @@ def test_No_CRLF(test_client, invalid_terminator):
     # Initialize a persistent HTTP connection
     conn = test_client.get_connection()

-    # (b'%s' % b'') is not supported in Python 3.4, so just use +
-    conn.send(b'GET /hello HTTP/1.1' + invalid_terminator)
+    # (b'%s' % b'') is not supported in Python 3.4, so just use bytes.join()
+    conn.send(b''.join((b'GET /hello HTTP/1.1', invalid_terminator)))
     response = conn.response_class(conn.sock, method='GET')
     response.begin()
     actual_resp_body = response.read()
     expected_resp_body = b'HTTP requires CRLF terminators'
     assert actual_resp_body == expected_resp_body
     conn.close()
+
+
+class FaultySelect:
+    """Mock class to insert errors in the selector.select method."""
+
+    def __init__(self, original_select):
+        """Initialize helper class to wrap the selector.select method."""
+        self.original_select = original_select
+        self.request_served = False
+        self.os_error_triggered = False
+
+    def __call__(self, timeout):
+        """Intercept the calls to selector.select."""
+        if self.request_served:
+            self.os_error_triggered = True
+            raise OSError('Error while selecting the client socket.')
+
+        return self.original_select(timeout)
+
+
+class FaultyGetMap:
+    """Mock class to insert errors in the selector.get_map method."""
+
+    def __init__(self, original_get_map):
+        """Initialize helper class to wrap the selector.get_map method."""
+        self.original_get_map = original_get_map
+        self.sabotage_conn = False
+        self.conn_closed = False
+
+    def __call__(self):
+        """Intercept the calls to selector.get_map."""
+        sabotage_targets = (
+            conn for _, 
(_, _, _, conn) in self.original_get_map().items() + if isinstance(conn, cheroot.server.HTTPConnection) + ) if self.sabotage_conn and not self.conn_closed else () + + for conn in sabotage_targets: + # close the socket to cause OSError + conn.close() + self.conn_closed = True + + return self.original_get_map() + + +def test_invalid_selected_connection(test_client, monkeypatch): + """Test the error handling segment of HTTP connection selection. + + See :py:meth:`cheroot.connections.ConnectionManager.get_conn`. + """ + # patch the select method + faux_select = FaultySelect( + test_client.server_instance._connections._selector.select, + ) + monkeypatch.setattr( + test_client.server_instance._connections._selector, + 'select', + faux_select, + ) + + # patch the get_map method + faux_get_map = FaultyGetMap( + test_client.server_instance._connections._selector._selector.get_map, + ) + + monkeypatch.setattr( + test_client.server_instance._connections._selector._selector, + 'get_map', + faux_get_map, + ) + + # request a page with connection keep-alive to make sure + # we'll have a connection to be modified. 
+ resp_status, resp_headers, resp_body = test_client.request( + '/page1', headers=[('Connection', 'Keep-Alive')], + ) + + assert resp_status == '200 OK' + # trigger the internal errors + faux_get_map.sabotage_conn = faux_select.request_served = True + # give time to make sure the error gets handled + time.sleep(test_client.server_instance.expiration_interval * 2) + assert faux_select.os_error_triggered + assert faux_get_map.conn_closed diff --git a/lib/cheroot/test/test_core.py b/lib/cheroot/test/test_core.py index aad2bb7f..933ff235 100644 --- a/lib/cheroot/test/test_core.py +++ b/lib/cheroot/test/test_core.py @@ -18,6 +18,7 @@ from cheroot.test import helper HTTP_BAD_REQUEST = 400 HTTP_LENGTH_REQUIRED = 411 HTTP_NOT_FOUND = 404 +HTTP_REQUEST_ENTITY_TOO_LARGE = 413 HTTP_OK = 200 HTTP_VERSION_NOT_SUPPORTED = 505 @@ -78,7 +79,7 @@ def _get_http_response(connection, method='GET'): @pytest.fixture def testing_server(wsgi_server_client): - """Attach a WSGI app to the given server and pre-configure it.""" + """Attach a WSGI app to the given server and preconfigure it.""" wsgi_server = wsgi_server_client.server_instance wsgi_server.wsgi_app = HelloController() wsgi_server.max_request_body_size = 30000000 @@ -92,6 +93,21 @@ def test_client(testing_server): return testing_server.server_client +@pytest.fixture +def testing_server_with_defaults(wsgi_server_client): + """Attach a WSGI app to the given server and preconfigure it.""" + wsgi_server = wsgi_server_client.server_instance + wsgi_server.wsgi_app = HelloController() + wsgi_server.server_client = wsgi_server_client + return wsgi_server + + +@pytest.fixture +def test_client_with_defaults(testing_server_with_defaults): + """Get and return a test client out of the given server.""" + return testing_server_with_defaults.server_client + + def test_http_connect_request(test_client): """Check that CONNECT query results in Method Not Allowed status.""" status_line = test_client.connect('/anything')[0] @@ -262,8 +278,30 @@ def 
test_content_length_required(test_client): assert actual_status == HTTP_LENGTH_REQUIRED +@pytest.mark.xfail( + reason='https://github.com/cherrypy/cheroot/issues/106', + strict=False, # sometimes it passes +) +def test_large_request(test_client_with_defaults): + """Test GET query with maliciously large Content-Length.""" + # If the server's max_request_body_size is not set (i.e. is set to 0) + # then this will result in an `OverflowError: Python int too large to + # convert to C ssize_t` in the server. + # We expect that this should instead return that the request is too + # large. + c = test_client_with_defaults.get_connection() + c.putrequest('GET', '/hello') + c.putheader('Content-Length', str(2**64)) + c.endheaders() + + response = c.getresponse() + actual_status = response.status + + assert actual_status == HTTP_REQUEST_ENTITY_TOO_LARGE + + @pytest.mark.parametrize( - 'request_line,status_code,expected_body', + ('request_line', 'status_code', 'expected_body'), ( ( b'GET /', # missing proto @@ -401,7 +439,7 @@ class CloseResponse: @pytest.fixture def testing_server_close(wsgi_server_client): - """Attach a WSGI app to the given server and pre-configure it.""" + """Attach a WSGI app to the given server and preconfigure it.""" wsgi_server = wsgi_server_client.server_instance wsgi_server.wsgi_app = CloseController() wsgi_server.max_request_body_size = 30000000 diff --git a/lib/cheroot/test/test_dispatch.py b/lib/cheroot/test/test_dispatch.py index bc588749..9974fdab 100644 --- a/lib/cheroot/test/test_dispatch.py +++ b/lib/cheroot/test/test_dispatch.py @@ -8,7 +8,7 @@ from cheroot.wsgi import PathInfoDispatcher def wsgi_invoke(app, environ): - """Serve 1 requeset from a WSGI application.""" + """Serve 1 request from a WSGI application.""" response = {} def start_response(status, headers): @@ -25,7 +25,7 @@ def wsgi_invoke(app, environ): def test_dispatch_no_script_name(): - """Despatch despite lack of SCRIPT_NAME in environ.""" + """Dispatch despite lack of 
``SCRIPT_NAME`` in environ.""" # Bare bones WSGI hello world app (from PEP 333). def app(environ, start_response): start_response( diff --git a/lib/cheroot/test/test_errors.py b/lib/cheroot/test/test_errors.py index 34b42d90..469b70a8 100644 --- a/lib/cheroot/test/test_errors.py +++ b/lib/cheroot/test/test_errors.py @@ -8,7 +8,7 @@ from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS @pytest.mark.parametrize( - 'err_names,err_nums', + ('err_names', 'err_nums'), ( (('', 'some-nonsense-name'), []), ( @@ -24,7 +24,7 @@ from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS ), ) def test_plat_specific_errors(err_names, err_nums): - """Test that plat_specific_errors retrieves correct err num list.""" + """Test that ``plat_specific_errors`` gets correct error numbers list.""" actual_err_nums = errors.plat_specific_errors(*err_names) assert len(actual_err_nums) == len(err_nums) assert sorted(actual_err_nums) == sorted(err_nums) diff --git a/lib/cheroot/test/test_makefile.py b/lib/cheroot/test/test_makefile.py index 55db5038..cdded07e 100644 --- a/lib/cheroot/test/test_makefile.py +++ b/lib/cheroot/test/test_makefile.py @@ -1,4 +1,4 @@ -"""self-explanatory.""" +"""Tests for :py:mod:`cheroot.makefile`.""" from cheroot import makefile @@ -7,14 +7,14 @@ __metaclass__ = type class MockSocket: - """Mocks a socket.""" + """A mock socket.""" def __init__(self): - """Initialize.""" + """Initialize :py:class:`MockSocket`.""" self.messages = [] def recv_into(self, buf): - """Simulate recv_into for Python 3.""" + """Simulate ``recv_into`` for Python 3.""" if not self.messages: return 0 msg = self.messages.pop(0) @@ -23,7 +23,7 @@ class MockSocket: return len(msg) def recv(self, size): - """Simulate recv for Python 2.""" + """Simulate ``recv`` for Python 2.""" try: return self.messages.pop(0) except IndexError: @@ -44,7 +44,7 @@ def test_bytes_read(): def test_bytes_written(): - """Writer should capture bytes writtten.""" + """Writer should capture bytes written.""" sock = MockSocket() 
sock.messages.append(b'foo') wfile = makefile.MakeFile(sock, 'w') diff --git a/lib/cheroot/test/test_server.py b/lib/cheroot/test/test_server.py index 30112354..c851f039 100644 --- a/lib/cheroot/test/test_server.py +++ b/lib/cheroot/test/test_server.py @@ -5,10 +5,12 @@ from __future__ import absolute_import, division, print_function __metaclass__ = type +from contextlib import closing import os import socket import tempfile import threading +import time import uuid import pytest @@ -16,14 +18,15 @@ import requests import requests_unixsocket import six +from six.moves import queue, urllib + from .._compat import bton, ntob -from .._compat import IS_LINUX, IS_MACOS, SYS_PLATFORM +from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS, SYS_PLATFORM from ..server import IS_UID_GID_RESOLVABLE, Gateway, HTTPServer from ..testing import ( ANY_INTERFACE_IPV4, ANY_INTERFACE_IPV6, EPHEMERAL_PORT, - get_server_client, ) @@ -42,15 +45,8 @@ non_macos_sock_test = pytest.mark.skipif( @pytest.fixture(params=('abstract', 'file')) def unix_sock_file(request): """Check that bound UNIX socket address is stored in server.""" - if request.param == 'abstract': - yield request.getfixturevalue('unix_abstract_sock') - return - tmp_sock_fh, tmp_sock_fname = tempfile.mkstemp() - - yield tmp_sock_fname - - os.close(tmp_sock_fh) - os.unlink(tmp_sock_fname) + name = 'unix_{request.param}_sock'.format(**locals()) + return request.getfixturevalue(name) @pytest.fixture @@ -67,6 +63,17 @@ def unix_abstract_sock(): )).decode() +@pytest.fixture +def unix_file_sock(): + """Yield a unix file socket.""" + tmp_sock_fh, tmp_sock_fname = tempfile.mkstemp() + + yield tmp_sock_fname + + os.close(tmp_sock_fh) + os.unlink(tmp_sock_fname) + + def test_prepare_makes_server_ready(): """Check that prepare() makes the server ready, and stop() clears it.""" httpserver = HTTPServer( @@ -110,6 +117,77 @@ def test_stop_interrupts_serve(): assert not serve_thread.is_alive() +@pytest.mark.parametrize( + 'exc_cls', + ( + 
IOError, + KeyboardInterrupt, + OSError, + RuntimeError, + ), +) +def test_server_interrupt(exc_cls): + """Check that assigning interrupt stops the server.""" + interrupt_msg = 'should catch {uuid!s}'.format(uuid=uuid.uuid4()) + raise_marker_sentinel = object() + + httpserver = HTTPServer( + bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT), + gateway=Gateway, + ) + + result_q = queue.Queue() + + def serve_thread(): + # ensure we catch the exception on the serve() thread + try: + httpserver.serve() + except exc_cls as e: + if str(e) == interrupt_msg: + result_q.put(raise_marker_sentinel) + + httpserver.prepare() + serve_thread = threading.Thread(target=serve_thread) + serve_thread.start() + + serve_thread.join(0.5) + assert serve_thread.is_alive() + + # this exception is raised on the serve() thread, + # not in the calling context. + httpserver.interrupt = exc_cls(interrupt_msg) + + serve_thread.join(0.5) + assert not serve_thread.is_alive() + assert result_q.get_nowait() is raise_marker_sentinel + + +def test_serving_is_false_and_stop_returns_after_ctrlc(): + """Check that stop() interrupts running of serve().""" + httpserver = HTTPServer( + bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT), + gateway=Gateway, + ) + + httpserver.prepare() + + # Simulate a Ctrl-C on the first call to `run`. + def raise_keyboard_interrupt(*args, **kwargs): + raise KeyboardInterrupt() + + httpserver._connections._selector.select = raise_keyboard_interrupt + + serve_thread = threading.Thread(target=httpserver.serve) + serve_thread.start() + + # The thread should exit right away due to the interrupt. 
+ serve_thread.join(httpserver.expiration_interval * 2) + assert not serve_thread.is_alive() + + assert not httpserver._connections._serving + httpserver.stop() + + @pytest.mark.parametrize( 'ip_addr', ( @@ -135,7 +213,7 @@ def test_bind_addr_unix(http_server, unix_sock_file): @unix_only_sock_test def test_bind_addr_unix_abstract(http_server, unix_abstract_sock): - """Check that bound UNIX abstract sockaddr is stored in server.""" + """Check that bound UNIX abstract socket address is stored in server.""" httpserver = http_server.send(unix_abstract_sock) assert httpserver.bind_addr == unix_abstract_sock @@ -167,27 +245,26 @@ class _TestGateway(Gateway): @pytest.fixture -def peercreds_enabled_server_and_client(http_server, unix_sock_file): - """Construct a test server with `peercreds_enabled`.""" +def peercreds_enabled_server(http_server, unix_sock_file): + """Construct a test server with ``peercreds_enabled``.""" httpserver = http_server.send(unix_sock_file) httpserver.gateway = _TestGateway httpserver.peercreds_enabled = True - return httpserver, get_server_client(httpserver) + return httpserver @unix_only_sock_test @non_macos_sock_test -def test_peercreds_unix_sock(peercreds_enabled_server_and_client): - """Check that peercred lookup works when enabled.""" - httpserver, testclient = peercreds_enabled_server_and_client +def test_peercreds_unix_sock(peercreds_enabled_server): + """Check that ``PEERCRED`` lookup works when enabled.""" + httpserver = peercreds_enabled_server bind_addr = httpserver.bind_addr if isinstance(bind_addr, six.binary_type): bind_addr = bind_addr.decode() - unix_base_uri = 'http+unix://{}'.format( - bind_addr.replace('\0', '%00').replace('/', '%2F'), - ) + quoted = urllib.parse.quote(bind_addr, safe='') + unix_base_uri = 'http+unix://{quoted}'.format(**locals()) expected_peercreds = os.getpid(), os.getuid(), os.getgid() expected_peercreds = '|'.join(map(str, expected_peercreds)) @@ -208,9 +285,9 @@ def 
test_peercreds_unix_sock(peercreds_enabled_server_and_client): ) @unix_only_sock_test @non_macos_sock_test -def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server_and_client): - """Check that peercred resolution works when enabled.""" - httpserver, testclient = peercreds_enabled_server_and_client +def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server): + """Check that ``PEERCRED`` resolution works when enabled.""" + httpserver = peercreds_enabled_server httpserver.peercreds_resolve_enabled = True bind_addr = httpserver.bind_addr @@ -218,9 +295,8 @@ def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server_and_client): if isinstance(bind_addr, six.binary_type): bind_addr = bind_addr.decode() - unix_base_uri = 'http+unix://{}'.format( - bind_addr.replace('\0', '%00').replace('/', '%2F'), - ) + quoted = urllib.parse.quote(bind_addr, safe='') + unix_base_uri = 'http+unix://{quoted}'.format(**locals()) import grp import pwd @@ -233,3 +309,111 @@ def test_peercreds_unix_sock_with_lookup(peercreds_enabled_server_and_client): peercreds_text_resp = requests.get(unix_base_uri + PEERCRED_TEXTS_URI) peercreds_text_resp.raise_for_status() assert peercreds_text_resp.text == expected_textcreds + + +@pytest.mark.skipif( + IS_WINDOWS, + reason='This regression test is for a Linux bug, ' + 'and the resource module is not available on Windows', +) +@pytest.mark.parametrize( + 'resource_limit', + ( + 1024, + 2048, + ), + indirect=('resource_limit',), +) +@pytest.mark.usefixtures('many_open_sockets') +def test_high_number_of_file_descriptors(resource_limit): + """Test the server does not crash with a high file-descriptor value. + + This test shouldn't cause a server crash when trying to access + file-descriptor higher than 1024. + + The earlier implementation used to rely on ``select()`` syscall that + doesn't support file descriptors with numbers higher than 1024. 
+ """ + # We want to force the server to use a file-descriptor with + # a number above resource_limit + + # Create our server + httpserver = HTTPServer( + bind_addr=(ANY_INTERFACE_IPV4, EPHEMERAL_PORT), gateway=Gateway, + ) + + try: + # This will trigger a crash if select() is used in the implementation + with httpserver._run_in_thread(): + # allow server to run long enough to invoke select() + time.sleep(1.0) + except: # noqa: E722 + raise # only needed for `else` to work + else: + # We use closing here for py2-compat + with closing(socket.socket()) as sock: + # Check new sockets created are still above our target number + assert sock.fileno() >= resource_limit + finally: + # Stop our server + httpserver.stop() + + +if not IS_WINDOWS: + test_high_number_of_file_descriptors = pytest.mark.forked( + test_high_number_of_file_descriptors, + ) + + +@pytest.fixture +def resource_limit(request): + """Set the resource limit two times bigger then requested.""" + resource = pytest.importorskip( + 'resource', + reason='The "resource" module is Unix-specific', + ) + + # Get current resource limits to restore them later + soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_NOFILE) + + # We have to increase the nofile limit above 1024 + # Otherwise we see a 'Too many files open' error, instead of + # an error due to the file descriptor number being too high + resource.setrlimit( + resource.RLIMIT_NOFILE, + (request.param * 2, hard_limit), + ) + + try: # noqa: WPS501 + yield request.param + finally: + # Reset the resource limit back to the original soft limit + resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) + + +@pytest.fixture +def many_open_sockets(resource_limit): + """Allocate a lot of file descriptors by opening dummy sockets.""" + # Hoard a lot of file descriptors by opening and storing a lot of sockets + test_sockets = [] + # Open a lot of file descriptors, so the next one the server + # opens is a high number + try: + for i in 
range(resource_limit): + sock = socket.socket() + test_sockets.append(sock) + # NOTE: We used to interrupt the loop early but this doesn't seem + # NOTE: to work well in envs with indeterministic runtimes like + # NOTE: PyPy. It looks like sometimes it frees some file + # NOTE: descriptors in between running this fixture and the actual + # NOTE: test code so the early break has been removed to try + # NOTE: address that. The approach may need to be rethought if the + # NOTE: issue reoccurs. Another approach may be disabling the GC. + # Check we opened enough descriptors to reach a high number + the_highest_fileno = max(sock.fileno() for sock in test_sockets) + assert the_highest_fileno >= resource_limit + yield the_highest_fileno + finally: + # Close our open resources + for test_socket in test_sockets: + test_socket.close() diff --git a/lib/cheroot/test/test_ssl.py b/lib/cheroot/test/test_ssl.py index caa1ae0a..8aa258f4 100644 --- a/lib/cheroot/test/test_ssl.py +++ b/lib/cheroot/test/test_ssl.py @@ -1,4 +1,4 @@ -"""Tests for TLS/SSL support.""" +"""Tests for TLS support.""" # -*- coding: utf-8 -*- # vim: set fileencoding=utf-8 : @@ -6,11 +6,14 @@ from __future__ import absolute_import, division, print_function __metaclass__ = type import functools +import json import os import ssl +import subprocess import sys import threading import time +import traceback import OpenSSL.SSL import pytest @@ -21,7 +24,7 @@ import trustme from .._compat import bton, ntob, ntou from .._compat import IS_ABOVE_OPENSSL10, IS_PYPY from .._compat import IS_LINUX, IS_MACOS, IS_WINDOWS -from ..server import Gateway, HTTPServer, get_ssl_adapter_class +from ..server import HTTPServer, get_ssl_adapter_class from ..testing import ( ANY_INTERFACE_IPV4, ANY_INTERFACE_IPV6, @@ -30,9 +33,17 @@ from ..testing import ( _get_conn_data, _probe_ipv6_sock, ) +from ..wsgi import Gateway_10 IS_GITHUB_ACTIONS_WORKFLOW = bool(os.getenv('GITHUB_WORKFLOW')) +IS_WIN2016 = ( + IS_WINDOWS + # pylint: 
disable=unsupported-membership-test + and b'Microsoft Windows Server 2016 Datacenter' in subprocess.check_output( + ('systeminfo',), + ) +) IS_LIBRESSL_BACKEND = ssl.OPENSSL_VERSION.startswith('LibreSSL') IS_PYOPENSSL_SSL_VERSION_1_0 = ( OpenSSL.SSL.SSLeay_version(OpenSSL.SSL.SSLEAY_VERSION). @@ -40,6 +51,7 @@ IS_PYOPENSSL_SSL_VERSION_1_0 = ( ) PY27 = sys.version_info[:2] == (2, 7) PY34 = sys.version_info[:2] == (3, 4) +PY3 = not six.PY2 _stdlib_to_openssl_verify = { @@ -71,7 +83,7 @@ missing_ipv6 = pytest.mark.skipif( ) -class HelloWorldGateway(Gateway): +class HelloWorldGateway(Gateway_10): """Gateway responding with Hello World to root URI.""" def respond(self): @@ -83,11 +95,21 @@ class HelloWorldGateway(Gateway): req.ensure_headers_sent() req.write(b'Hello world!') return + if req_uri == '/env': + req.status = b'200 OK' + req.ensure_headers_sent() + env = self.get_environ() + # drop files so that it can be json dumped + env.pop('wsgi.errors') + env.pop('wsgi.input') + print(env) + req.write(json.dumps(env).encode('utf-8')) + return return super(HelloWorldGateway, self).respond() def make_tls_http_server(bind_addr, ssl_adapter, request): - """Create and start an HTTP server bound to bind_addr.""" + """Create and start an HTTP server bound to ``bind_addr``.""" httpserver = HTTPServer( bind_addr=bind_addr, gateway=HelloWorldGateway, @@ -128,7 +150,7 @@ def tls_ca_certificate_pem_path(ca): def tls_certificate(ca): """Provide a leaf certificate via fixture.""" interface, host, port = _get_conn_data(ANY_INTERFACE_IPV4) - return ca.issue_server_cert(ntou(interface), ) + return ca.issue_server_cert(ntou(interface)) @pytest.fixture @@ -145,6 +167,43 @@ def tls_certificate_private_key_pem_path(tls_certificate): yield cert_key_pem +def _thread_except_hook(exceptions, args): + """Append uncaught exception ``args`` in threads to ``exceptions``.""" + if issubclass(args.exc_type, SystemExit): + return + # cannot store the exception, it references the thread's stack + 
exceptions.append(( + args.exc_type, + str(args.exc_value), + ''.join( + traceback.format_exception( + args.exc_type, args.exc_value, args.exc_traceback, + ), + ), + )) + + +@pytest.fixture +def thread_exceptions(): + """Provide a list of uncaught exceptions from threads via a fixture. + + Only catches exceptions on Python 3.8+. + The list contains: ``(type, str(value), str(traceback))`` + """ + exceptions = [] + # Python 3.8+ + orig_hook = getattr(threading, 'excepthook', None) + if orig_hook is not None: + threading.excepthook = functools.partial( + _thread_except_hook, exceptions, + ) + try: + yield exceptions + finally: + if orig_hook is not None: + threading.excepthook = orig_hook + + @pytest.mark.parametrize( 'adapter_type', ( @@ -180,7 +239,7 @@ def test_ssl_adapters( ) resp = requests.get( - 'https://' + interface + ':' + str(port) + '/', + 'https://{host!s}:{port!s}/'.format(host=interface, port=port), verify=tls_ca_certificate_pem_path, ) @@ -188,7 +247,7 @@ def test_ssl_adapters( assert resp.text == 'Hello world!' 
-@pytest.mark.parametrize( +@pytest.mark.parametrize( # noqa: C901 # FIXME 'adapter_type', ( 'builtin', @@ -196,7 +255,7 @@ def test_ssl_adapters( ), ) @pytest.mark.parametrize( - 'is_trusted_cert,tls_client_identity', + ('is_trusted_cert', 'tls_client_identity'), ( (True, 'localhost'), (True, '127.0.0.1'), (True, '*.localhost'), (True, 'not_localhost'), @@ -211,7 +270,7 @@ def test_ssl_adapters( ssl.CERT_REQUIRED, # server should validate if client cert CA is OK ), ) -def test_tls_client_auth( +def test_tls_client_auth( # noqa: C901 # FIXME # FIXME: remove twisted logic, separate tests mocker, tls_http_server, adapter_type, @@ -265,7 +324,7 @@ def test_tls_client_auth( make_https_request = functools.partial( requests.get, - 'https://' + interface + ':' + str(port) + '/', + 'https://{host!s}:{port!s}/'.format(host=interface, port=port), # Server TLS certificate verification: verify=tls_ca_certificate_pem_path, @@ -324,36 +383,200 @@ def test_tls_client_auth( except AttributeError: if PY34: pytest.xfail('OpenSSL behaves wierdly under Python 3.4') - elif not six.PY2 and IS_WINDOWS: + elif IS_WINDOWS or IS_GITHUB_ACTIONS_WORKFLOW: err_text = str(ssl_err.value) else: raise + if isinstance(err_text, int): + err_text = str(ssl_err.value) + expected_substrings = ( 'sslv3 alert bad certificate' if IS_LIBRESSL_BACKEND else 'tlsv1 alert unknown ca', ) if not six.PY2: if IS_MACOS and IS_PYPY and adapter_type == 'pyopenssl': - expected_substrings = ('tlsv1 alert unknown ca', ) - if ( - IS_WINDOWS - and tls_verify_mode in ( - ssl.CERT_REQUIRED, - ssl.CERT_OPTIONAL, - ) - and not is_trusted_cert - and tls_client_identity == 'localhost' - ): - expected_substrings += ( - 'bad handshake: ' - "SysCallError(10054, 'WSAECONNRESET')", - "('Connection aborted.', " - 'OSError("(10054, \'WSAECONNRESET\')"))', + expected_substrings = ('tlsv1 alert unknown ca',) + if ( + tls_verify_mode in ( + ssl.CERT_REQUIRED, + ssl.CERT_OPTIONAL, ) + and not is_trusted_cert + and tls_client_identity == 
'localhost' + ): + expected_substrings += ( + 'bad handshake: ' + "SysCallError(10054, 'WSAECONNRESET')", + "('Connection aborted.', " + 'OSError("(10054, \'WSAECONNRESET\')"))', + "('Connection aborted.', " + 'OSError("(10054, \'WSAECONNRESET\')",))', + "('Connection aborted.', " + 'error("(10054, \'WSAECONNRESET\')",))', + "('Connection aborted.', " + 'ConnectionResetError(10054, ' + "'An existing connection was forcibly closed " + "by the remote host', None, 10054, None))", + ) if IS_WINDOWS else ( + "('Connection aborted.', " + 'OSError("(104, \'ECONNRESET\')"))', + "('Connection aborted.', " + 'OSError("(104, \'ECONNRESET\')",))', + "('Connection aborted.', " + 'error("(104, \'ECONNRESET\')",))', + "('Connection aborted.', " + "ConnectionResetError(104, 'Connection reset by peer'))", + "('Connection aborted.', " + "error(104, 'Connection reset by peer'))", + ) if ( + IS_GITHUB_ACTIONS_WORKFLOW + and IS_LINUX + ) else ( + "('Connection aborted.', " + "BrokenPipeError(32, 'Broken pipe'))", + ) assert any(e in err_text for e in expected_substrings) +@pytest.mark.parametrize( # noqa: C901 # FIXME + 'adapter_type', + ( + 'builtin', + 'pyopenssl', + ), +) +@pytest.mark.parametrize( + ('tls_verify_mode', 'use_client_cert'), + ( + (ssl.CERT_NONE, False), + (ssl.CERT_NONE, True), + (ssl.CERT_OPTIONAL, False), + (ssl.CERT_OPTIONAL, True), + (ssl.CERT_REQUIRED, True), + ), +) +def test_ssl_env( # noqa: C901 # FIXME + thread_exceptions, + recwarn, + mocker, + tls_http_server, adapter_type, + ca, tls_verify_mode, tls_certificate, + tls_certificate_chain_pem_path, + tls_certificate_private_key_pem_path, + tls_ca_certificate_pem_path, + use_client_cert, +): + """Test the SSL environment generated by the SSL adapters.""" + interface, _host, port = _get_conn_data(ANY_INTERFACE_IPV4) + + with mocker.mock_module.patch( + 'idna.core.ulabel', + return_value=ntob('127.0.0.1'), + ): + client_cert = ca.issue_cert(ntou('127.0.0.1')) + + with 
client_cert.private_key_and_cert_chain_pem.tempfile() as cl_pem: + tls_adapter_cls = get_ssl_adapter_class(name=adapter_type) + tls_adapter = tls_adapter_cls( + tls_certificate_chain_pem_path, + tls_certificate_private_key_pem_path, + ) + if adapter_type == 'pyopenssl': + tls_adapter.context = tls_adapter.get_context() + tls_adapter.context.set_verify( + _stdlib_to_openssl_verify[tls_verify_mode], + lambda conn, cert, errno, depth, preverify_ok: preverify_ok, + ) + else: + tls_adapter.context.verify_mode = tls_verify_mode + + ca.configure_trust(tls_adapter.context) + tls_certificate.configure_cert(tls_adapter.context) + + tlswsgiserver = tls_http_server((interface, port), tls_adapter) + + interface, _host, port = _get_conn_data(tlswsgiserver.bind_addr) + + resp = requests.get( + 'https://' + interface + ':' + str(port) + '/env', + verify=tls_ca_certificate_pem_path, + cert=cl_pem if use_client_cert else None, + ) + if PY34 and resp.status_code != 200: + pytest.xfail( + 'Python 3.4 has problems with verifying client certs', + ) + + env = json.loads(resp.content.decode('utf-8')) + + # hard coded env + assert env['wsgi.url_scheme'] == 'https' + assert env['HTTPS'] == 'on' + + # ensure these are present + for key in {'SSL_VERSION_INTERFACE', 'SSL_VERSION_LIBRARY'}: + assert key in env + + # pyOpenSSL generates the env before the handshake completes + if adapter_type == 'pyopenssl': + return + + for key in {'SSL_PROTOCOL', 'SSL_CIPHER'}: + assert key in env + + # client certificate env + if tls_verify_mode == ssl.CERT_NONE or not use_client_cert: + assert env['SSL_CLIENT_VERIFY'] == 'NONE' + else: + assert env['SSL_CLIENT_VERIFY'] == 'SUCCESS' + + with open(cl_pem, 'rt') as f: + assert env['SSL_CLIENT_CERT'] in f.read() + + for key in { + 'SSL_CLIENT_M_VERSION', 'SSL_CLIENT_M_SERIAL', + 'SSL_CLIENT_I_DN', 'SSL_CLIENT_S_DN', + }: + assert key in env + + # builtin ssl environment generation may use a loopback socket + # ensure no ResourceWarning was raised during the test 
+ # NOTE: python 2.7 does not emit ResourceWarning for ssl sockets + if IS_PYPY: + # NOTE: PyPy doesn't have ResourceWarning + # Ref: https://doc.pypy.org/en/latest/cpython_differences.html + return + for warn in recwarn: + if not issubclass(warn.category, ResourceWarning): + continue + + # the tests can sporadically generate resource warnings + # due to timing issues + # all of these sporadic warnings appear to be about socket.socket + # and have been observed to come from requests connection pool + msg = str(warn.message) + if 'socket.socket' in msg: + pytest.xfail( + '\n'.join(( + 'Sometimes this test fails due to ' + 'a socket.socket ResourceWarning:', + msg, + )), + ) + pytest.fail(msg) + + # to perform the ssl handshake over that loopback socket, + # the builtin ssl environment generation uses a thread + for _, _, trace in thread_exceptions: + print(trace, file=sys.stderr) + assert not thread_exceptions, ': '.join(( + thread_exceptions[0][0].__name__, + thread_exceptions[0][1], + )) + + @pytest.mark.parametrize( 'ip_addr', ( @@ -382,7 +605,16 @@ def test_https_over_http_error(http_server, ip_addr): @pytest.mark.parametrize( 'adapter_type', ( - 'builtin', + pytest.param( + 'builtin', + marks=pytest.mark.xfail( + IS_WINDOWS and six.PY2, + raises=requests.exceptions.ConnectionError, + reason='Stdlib `ssl` module behaves weirdly ' + 'on Windows under Python 2', + strict=False, + ), + ), 'pyopenssl', ), ) @@ -428,16 +660,41 @@ def test_http_over_https_error( fqdn = interface if ip_addr is ANY_INTERFACE_IPV6: - fqdn = '[{}]'.format(fqdn) + fqdn = '[{fqdn}]'.format(**locals()) expect_fallback_response_over_plain_http = ( - (adapter_type == 'pyopenssl' - and (IS_ABOVE_OPENSSL10 or not six.PY2)) + ( + adapter_type == 'pyopenssl' + and (IS_ABOVE_OPENSSL10 or not six.PY2) + ) or PY27 + ) or ( + IS_GITHUB_ACTIONS_WORKFLOW + and IS_WINDOWS + and six.PY2 + and not IS_WIN2016 ) + if ( + IS_GITHUB_ACTIONS_WORKFLOW + and IS_WINDOWS + and six.PY2 + and IS_WIN2016 + and 
adapter_type == 'builtin' + and ip_addr is ANY_INTERFACE_IPV6 + ): + expect_fallback_response_over_plain_http = True + if ( + IS_GITHUB_ACTIONS_WORKFLOW + and IS_WINDOWS + and six.PY2 + and not IS_WIN2016 + and adapter_type == 'builtin' + and ip_addr is not ANY_INTERFACE_IPV6 + ): + expect_fallback_response_over_plain_http = False if expect_fallback_response_over_plain_http: resp = requests.get( - 'http://' + fqdn + ':' + str(port) + '/', + 'http://{host!s}:{port!s}/'.format(host=fqdn, port=port), ) assert resp.status_code == 400 assert resp.text == ( @@ -448,7 +705,7 @@ def test_http_over_https_error( with pytest.raises(requests.exceptions.ConnectionError) as ssl_err: requests.get( # FIXME: make stdlib ssl behave like PyOpenSSL - 'http://' + fqdn + ':' + str(port) + '/', + 'http://{host!s}:{port!s}/'.format(host=fqdn, port=port), ) if IS_LINUX: @@ -468,7 +725,7 @@ def test_http_over_https_error( underlying_error = ssl_err.value.args[0].args[-1] err_text = str(underlying_error) assert underlying_error.errno == expected_error_code, ( - 'The underlying error is {!r}'. - format(underlying_error) + 'The underlying error is {underlying_error!r}'. 
+ format(**locals()) ) assert expected_error_text in err_text diff --git a/lib/cheroot/test/test_wsgi.py b/lib/cheroot/test/test_wsgi.py new file mode 100644 index 00000000..d3c47ece --- /dev/null +++ b/lib/cheroot/test/test_wsgi.py @@ -0,0 +1,58 @@ +"""Test wsgi.""" + +from concurrent.futures.thread import ThreadPoolExecutor + +import pytest +import portend +import requests +from requests_toolbelt.sessions import BaseUrlSession as Session +from jaraco.context import ExceptionTrap + +from cheroot import wsgi +from cheroot._compat import IS_MACOS, IS_WINDOWS + + +IS_SLOW_ENV = IS_MACOS or IS_WINDOWS + + +@pytest.fixture +def simple_wsgi_server(): + """Fucking simple wsgi server fixture (duh).""" + port = portend.find_available_local_port() + + def app(environ, start_response): + status = '200 OK' + response_headers = [('Content-type', 'text/plain')] + start_response(status, response_headers) + return [b'Hello world!'] + + host = '::' + addr = host, port + server = wsgi.Server(addr, app, timeout=600 if IS_SLOW_ENV else 20) + url = 'http://localhost:{port}/'.format(**locals()) + with server._run_in_thread() as thread: + yield locals() + + +def test_connection_keepalive(simple_wsgi_server): + """Test the connection keepalive works (duh).""" + session = Session(base_url=simple_wsgi_server['url']) + pooled = requests.adapters.HTTPAdapter( + pool_connections=1, pool_maxsize=1000, + ) + session.mount('http://', pooled) + + def do_request(): + with ExceptionTrap(requests.exceptions.ConnectionError) as trap: + resp = session.get('info') + resp.raise_for_status() + return bool(trap) + + with ThreadPoolExecutor(max_workers=10 if IS_SLOW_ENV else 50) as pool: + tasks = [ + pool.submit(do_request) + for n in range(250 if IS_SLOW_ENV else 1000) + ] + failures = sum(task.result() for task in tasks) + + assert not failures diff --git a/lib/cheroot/test/webtest.py b/lib/cheroot/test/webtest.py index 934b2004..cdd340e8 100644 --- a/lib/cheroot/test/webtest.py +++ 
b/lib/cheroot/test/webtest.py @@ -1,12 +1,14 @@ """Extensions to unittest for web frameworks. -Use the WebCase.getPage method to request a page from your HTTP server. +Use the :py:meth:`WebCase.getPage` method to request a page +from your HTTP server. + Framework Integration ===================== If you have control over your server process, you can handle errors in the server-side of the HTTP conversation a bit better. You must run -both the client (your WebCase tests) and the server in the same process -(but in separate threads, obviously). +both the client (your :py:class:`WebCase` tests) and the server in the +same process (but in separate threads, obviously). When an error occurs in the framework, call server_error. It will print the traceback to stdout, and keep any assertions you have from running (the assumption is that, if the server errors, the page output will not @@ -122,7 +124,7 @@ class WebCase(unittest.TestCase): def _Conn(self): """Return HTTPConnection or HTTPSConnection based on self.scheme. - * from http.client. + * from :py:mod:`python:http.client`. """ cls_name = '{scheme}Connection'.format(scheme=self.scheme.upper()) return getattr(http_client, cls_name) @@ -157,7 +159,7 @@ class WebCase(unittest.TestCase): @property def persistent(self): - """Presense of the persistent HTTP connection.""" + """Presence of the persistent HTTP connection.""" return hasattr(self.HTTP_CONN, '__class__') @persistent.setter @@ -176,7 +178,9 @@ class WebCase(unittest.TestCase): self, url, headers=None, method='GET', body=None, protocol=None, raise_subcls=(), ): - """Open the url with debugging support. Return status, headers, body. + """Open the url with debugging support. + + Return status, headers, body. url should be the identifier passed to the server, typically a server-absolute path and query string (sent between method and @@ -184,16 +188,16 @@ class WebCase(unittest.TestCase): enabled in the server. 
If the application under test generates absolute URIs, be sure - to wrap them first with strip_netloc:: + to wrap them first with :py:func:`strip_netloc`:: - class MyAppWebCase(WebCase): - def getPage(url, *args, **kwargs): - super(MyAppWebCase, self).getPage( - cheroot.test.webtest.strip_netloc(url), - *args, **kwargs - ) + >>> class MyAppWebCase(WebCase): + ... def getPage(url, *args, **kwargs): + ... super(MyAppWebCase, self).getPage( + ... cheroot.test.webtest.strip_netloc(url), + ... *args, **kwargs + ... ) - `raise_subcls` is passed through to openURL. + ``raise_subcls`` is passed through to :py:func:`openURL`. """ ServerError.on = False @@ -247,7 +251,7 @@ class WebCase(unittest.TestCase): console_height = 30 - def _handlewebError(self, msg): + def _handlewebError(self, msg): # noqa: C901 # FIXME print('') print(' ERROR: %s' % msg) @@ -487,7 +491,7 @@ def openURL(*args, **kwargs): """ Open a URL, retrying when it fails. - Specify `raise_subcls` (class or tuple of classes) to exclude + Specify ``raise_subcls`` (class or tuple of classes) to exclude those socket.error subclasses from being suppressed and retried. """ raise_subcls = kwargs.pop('raise_subcls', ()) @@ -553,7 +557,7 @@ def strip_netloc(url): server-absolute portion. Useful for wrapping an absolute-URI for which only the - path is expected (such as in calls to getPage). + path is expected (such as in calls to :py:meth:`WebCase.getPage`). 
>>> strip_netloc('https://google.com/foo/bar?bing#baz') '/foo/bar?bing' diff --git a/lib/cheroot/testing.py b/lib/cheroot/testing.py index 94bb7734..c9a6ac99 100644 --- a/lib/cheroot/testing.py +++ b/lib/cheroot/testing.py @@ -61,14 +61,14 @@ def cheroot_server(server_factory): httpserver.stop() # destroy it -@pytest.fixture(scope='module') +@pytest.fixture def wsgi_server(): """Set up and tear down a Cheroot WSGI server instance.""" for srv in cheroot_server(cheroot.wsgi.Server): yield srv -@pytest.fixture(scope='module') +@pytest.fixture def native_server(): """Set up and tear down a Cheroot HTTP server instance.""" for srv in cheroot_server(cheroot.server.HTTPServer): diff --git a/lib/cheroot/workers/threadpool.py b/lib/cheroot/workers/threadpool.py index 8c1d29f7..915934cc 100644 --- a/lib/cheroot/workers/threadpool.py +++ b/lib/cheroot/workers/threadpool.py @@ -1,4 +1,9 @@ -"""A thread-based worker pool.""" +"""A thread-based worker pool. + +.. spelling:: + + joinable +""" from __future__ import absolute_import, division, print_function __metaclass__ = type @@ -111,11 +116,6 @@ class WorkerThread(threading.Thread): if conn is _SHUTDOWNREQUEST: return - # Just close the connection and move on. - if conn.closeable: - conn.close() - continue - self.conn = conn is_stats_enabled = self.server.stats['Enabled'] if is_stats_enabled: @@ -125,7 +125,7 @@ class WorkerThread(threading.Thread): keep_conn_open = conn.communicate() finally: if keep_conn_open: - self.server.connections.put(conn) + self.server.put_conn(conn) else: conn.close() if is_stats_enabled: @@ -176,7 +176,10 @@ class ThreadPool: for i in range(self.min): self._threads.append(WorkerThread(self.server)) for worker in self._threads: - worker.setName('CP Server ' + worker.getName()) + worker.setName( + 'CP Server {worker_name!s}'. + format(worker_name=worker.getName()), + ) worker.start() for worker in self._threads: while not worker.ready: @@ -192,7 +195,7 @@ class ThreadPool: """Put request into queue. 
Args: - obj (cheroot.server.HTTPConnection): HTTP connection + obj (:py:class:`~cheroot.server.HTTPConnection`): HTTP connection waiting to be processed """ self._queue.put(obj, block=True, timeout=self._queue_put_timeout) @@ -223,7 +226,10 @@ class ThreadPool: def _spawn_worker(self): worker = WorkerThread(self.server) - worker.setName('CP Server ' + worker.getName()) + worker.setName( + 'CP Server {worker_name!s}'. + format(worker_name=worker.getName()), + ) worker.start() return worker diff --git a/lib/cheroot/wsgi.py b/lib/cheroot/wsgi.py index 30599b35..6635f528 100644 --- a/lib/cheroot/wsgi.py +++ b/lib/cheroot/wsgi.py @@ -119,10 +119,7 @@ class Gateway(server.Gateway): corresponding class """ - return dict( - (gw.version, gw) - for gw in cls.__subclasses__() - ) + return {gw.version: gw for gw in cls.__subclasses__()} def get_environ(self): """Return a new environ dict targeting the given wsgi.version.""" @@ -195,9 +192,9 @@ class Gateway(server.Gateway): """Cast status to bytes representation of current Python version. According to :pep:`3333`, when using Python 3, the response status - and headers must be bytes masquerading as unicode; that is, they + and headers must be bytes masquerading as Unicode; that is, they must be of type "str" but are restricted to code points in the - "latin-1" set. + "Latin-1" set. """ if six.PY2: return status @@ -300,7 +297,11 @@ class Gateway_10(Gateway): # Request headers env.update( - ('HTTP_' + bton(k).upper().replace('-', '_'), bton(v)) + ( + 'HTTP_{header_name!s}'. + format(header_name=bton(k).upper().replace('-', '_')), + bton(v), + ) for k, v in req.inheaders.items() ) @@ -321,7 +322,7 @@ class Gateway_10(Gateway): class Gateway_u0(Gateway_10): """A Gateway class to interface HTTPServer with WSGI u.0. - WSGI u.0 is an experimental protocol, which uses unicode for keys + WSGI u.0 is an experimental protocol, which uses Unicode for keys and values in both Python 2 and Python 3. 
""" @@ -409,7 +410,7 @@ class PathInfoDispatcher: path = environ['PATH_INFO'] or '/' for p, app in self.apps: # The apps list should be sorted by length, descending. - if path.startswith(p + '/') or path == p: + if path.startswith('{path!s}/'.format(path=p)) or path == p: environ = environ.copy() environ['SCRIPT_NAME'] = environ.get('SCRIPT_NAME', '') + p environ['PATH_INFO'] = path[len(p):]